418 lines
17 KiB
Python
418 lines
17 KiB
Python
import argparse
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import logging
|
|
from datetime import datetime
|
|
import os
|
|
from torch.utils.data import DataLoader
|
|
from os.path import join
|
|
from torch.nn import DataParallel
|
|
import transformers
|
|
import pickle
|
|
from pytorchtools import EarlyStopping
|
|
from transformers import GPT2LMHeadModel, GPT2Config
|
|
from transformers import BertTokenizerFast
|
|
import torch.nn.utils.rnn as rnn_utils
|
|
from dataset import MyDataset
|
|
|
|
|
|
def set_args():
    """Declare and parse the command-line options of the training script.

    Reads ``sys.argv`` through ``argparse``.

    Returns:
        argparse.Namespace: one attribute per option declared in the table below.
    """
    # (flag, keyword arguments forwarded verbatim to ``add_argument``); the
    # order of this table is the order the options appear in ``--help``.
    option_table = [
        ('--device', dict(default='3', type=str, required=False, help='设置使用哪些显卡')),
        ('--no_cuda', dict(action='store_true', help='不使用GPU进行训练')),
        ('--vocab_path', dict(default='vocab/vocab.txt', type=str, required=False,
                              help='词表路径')),
        ('--model_config', dict(default='config/config.json', type=str, required=False,
                                help='设置模型参数')),
        ('--train_path', dict(default='data/train.pkl', type=str, required=False, help='训练集路径')),
        ('--max_len', dict(default=150, type=int, required=False, help='训练时,输入数据的最大长度')),
        ('--log_path', dict(default='data/train.log', type=str, required=False, help='训练日志存放位置')),
        # NOTE(review): no ``type`` is given, so a value passed on the command
        # line arrives as a (always truthy) string.
        ('--log', dict(default=True, help="是否记录日志")),
        ('--ignore_index', dict(default=-100, type=int, required=False,
                                help='对于ignore_index的label token不计算梯度')),
        ('--epochs', dict(default=100, type=int, required=False, help='训练的最大轮次')),
        ('--batch_size', dict(default=4, type=int, required=False, help='训练的batch size')),
        ('--gpu0_bsz', dict(default=10, type=int, required=False, help='0号卡的batch size')),
        ('--lr', dict(default=2.6e-5, type=float, required=False, help='学习率')),
        # Passed to the optimizer as its ``eps`` term by train().
        ('--eps', dict(default=1.0e-09, type=float, required=False, help='衰减率')),
        ('--log_step', dict(default=1, type=int, required=False, help='多少步汇报一次loss')),
        ('--gradient_accumulation_steps', dict(default=4, type=int, required=False, help='梯度积累')),
        ('--max_grad_norm', dict(default=2.0, type=float, required=False)),
        ('--save_model_path', dict(default='model', type=str, required=False,
                                   help='模型输出路径')),
        ('--pretrained_model', dict(default='', type=str, required=False,
                                    help='预训练的模型的路径')),
        ('--num_workers', dict(type=int, default=0, help="dataloader加载数据时使用的线程数量")),
        ('--patience', dict(type=int, default=0,
                            help="用于early stopping,设为0时,不进行early stopping.early stop得到的模型的生成效果不一定会更好。")),
        ('--warmup_steps', dict(type=int, default=4000, help='warm up步数')),
        ('--val_num', dict(type=int, default=8000, help='验证集大小')),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
|
|
def create_logger(args):
    """Create a logger that writes both to ``args.log_path`` and to the console.

    Args:
        args: namespace providing ``log_path``, the file that log records are
            appended to. Its parent directory is created if missing.

    Returns:
        logging.Logger: the module logger (``logging.getLogger(__name__)``) at
        INFO level, with exactly one file handler and one console handler.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # getLogger(__name__) returns the same object on every call, so without
    # this a second call would attach a second pair of handlers and every
    # record would be emitted twice.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # Handler that writes to the log file. FileHandler does not create missing
    # directories, and an explicit UTF-8 encoding keeps non-ASCII text
    # (paths, arguments) loggable regardless of the platform locale.
    log_dir = os.path.dirname(args.log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    file_handler = logging.FileHandler(
        filename=args.log_path, encoding='utf-8')
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    logger.addHandler(file_handler)

    # Handler that echoes records to the console (stderr).
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    logger.addHandler(console)

    return logger
|
|
|
|
|
|
def collate_fn(batch):
    """Pad a list of 1-D token-id tensors into model inputs and LM labels.

    Args:
        batch: list of 1-D tensors of possibly different lengths.

    Returns:
        (input_ids, labels): two ``(len(batch), max_len)`` tensors built from
        the same sequences. input_ids is right-padded with 0 and labels with
        -100 (the value the loss/accuracy code treats as "ignore").
    """
    def _pad_with(fill_value):
        return rnn_utils.pad_sequence(batch, batch_first=True, padding_value=fill_value)

    return _pad_with(0), _pad_with(-100)
|
|
|
|
|
|
# NOTE: legacy helper, kept commented out for reference; collate_fn does the
# padding with torch.nn.utils.rnn.pad_sequence.
# def padding_batch(data_list, pad_id):
#     """
#     Pad every sequence in data_list with pad_id up to the length of the longest one.
#     :param data_list:
#     :param pad_id:
#     :return:
#     """
#     # find the maximum length in data_list
#     max_len = 0
#     for data in data_list:
#         max_len = max_len if max_len > len(data) else len(data)
#
#     # pad every sequence
#     new_data_list = []
#     for data in data_list:
#         new_data = data + [pad_id] * (max_len - len(data))
#         new_data_list.append(new_data)
#     return new_data_list
|
|
|
|
|
|
def load_dataset(logger, args):
    """Load the pickled corpus and split it into training and validation sets.

    The first ``args.val_num`` samples become the validation set and the
    remaining ones the training set (the list is not shuffled before the split).

    Args:
        logger: logger used for progress messages.
        args: namespace providing ``train_path``, ``val_num`` and ``max_len``.

    Returns:
        (train_dataset, val_dataset): two ``MyDataset`` instances.

    Raises:
        ValueError: if ``val_num`` does not leave at least one sample in each split.
    """
    logger.info("loading training dataset and validating dataset")
    train_path = args.train_path

    # NOTE(security): pickle.load can execute arbitrary code; only load files
    # produced by a trusted preprocessing step.
    with open(train_path, "rb") as f:
        input_list = pickle.load(f)

    # Split into training and validation sets. An empty split (or a negative
    # count, which would silently slice from the end) would otherwise only
    # surface later, e.g. as a division by zero when averaging epoch losses.
    val_num = args.val_num
    if not 0 < val_num < len(input_list):
        raise ValueError(
            "val_num must satisfy 0 < val_num < {} (number of samples), got {}".format(
                len(input_list), val_num))
    input_list_train = input_list[val_num:]
    input_list_val = input_list[:val_num]
    logger.info("train samples: {}, validate samples: {}".format(
        len(input_list_train), len(input_list_val)))

    train_dataset = MyDataset(input_list_train, args.max_len)
    val_dataset = MyDataset(input_list_val, args.max_len)

    return train_dataset, val_dataset
|
|
|
|
|
|
def train_epoch(model, train_dataloader, optimizer, scheduler, logger,
                epoch, args):
    """Run one training epoch, then save a checkpoint to ``save_model_path/epoch{N}``.

    Gradients are accumulated over ``args.gradient_accumulation_steps`` batches.
    The accumulated gradient is clipped to ``args.max_grad_norm`` once, right
    before each optimizer/scheduler step. A batch whose RuntimeError message
    contains "out of memory" is logged and skipped; any other RuntimeError is
    re-raised.

    Args:
        model: model called as ``model(input_ids, labels=labels)`` returning an
            object with ``.logits`` and ``.loss``; may be wrapped in DataParallel.
        train_dataloader: yields ``(input_ids, labels)`` tensor pairs.
        optimizer: optimizer over ``model.parameters()``.
        scheduler: LR scheduler stepped once per optimizer update.
        logger: logger for progress messages.
        epoch: zero-based epoch index (logged/saved as ``epoch + 1``).
        args: namespace providing ``device``, ``ignore_index``,
            ``gradient_accumulation_steps``, ``max_grad_norm``, ``log_step`` and
            ``save_model_path``.

    Returns:
        float: sum of per-batch losses divided by ``len(train_dataloader)``
        (skipped batches still count in the denominator).
    """
    model.train()
    device = args.device
    ignore_index = args.ignore_index
    accumulation_steps = args.gradient_accumulation_steps
    epoch_start_time = datetime.now()
    total_loss = 0  # sum of the per-batch losses over the whole epoch

    # epoch_correct_num: label tokens predicted correctly during this epoch
    # epoch_total_num: label tokens scored (not equal to ignore_index) this epoch
    epoch_correct_num, epoch_total_num = 0, 0

    for batch_idx, (input_ids, labels) in enumerate(train_dataloader):
        # catch CUDA out-of-memory errors so a single oversized batch is skipped
        try:
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            # Call the module (not .forward) so registered hooks run.
            outputs = model(input_ids, labels=labels)
            logits = outputs.logits
            # .mean() reduces the per-replica losses gathered by DataParallel;
            # it is a no-op for an already scalar loss.
            loss = outputs.loss.mean()

            # token-level accuracy of this batch, accumulated into the epoch totals
            batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)
            epoch_correct_num += batch_correct_num
            epoch_total_num += batch_total_num
            batch_acc = batch_correct_num / max(batch_total_num, 1)

            total_loss += loss.item()
            if accumulation_steps > 1:
                loss = loss / accumulation_steps
            loss.backward()

            # After accumulating gradients for enough batches, update the parameters.
            if (batch_idx + 1) % accumulation_steps == 0:
                # Clip the fully accumulated gradient once, just before the update.
                # Clipping after every micro-batch rescales partial sums and
                # distorts the effective gradient.
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (batch_idx + 1) % args.log_step == 0:
                # get_last_lr() reports the LR actually in use; get_lr() is
                # deprecated for external calls.
                logger.info(
                    "batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(
                        batch_idx + 1, epoch + 1, loss.item() * accumulation_steps, batch_acc,
                        scheduler.get_last_lr()))

            del input_ids, outputs

        except RuntimeError as exception:
            if "out of memory" in str(exception):
                logger.info("WARNING: ran out of memory")
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                logger.info(str(exception))
                raise exception

    # average loss and token accuracy of this epoch
    epoch_mean_loss = total_loss / len(train_dataloader)
    epoch_mean_acc = epoch_correct_num / max(epoch_total_num, 1)
    logger.info(
        "epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))

    # save a checkpoint for this epoch
    logger.info('saving model for epoch {}'.format(epoch + 1))
    model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))
    os.makedirs(model_path, exist_ok=True)
    # unwrap DataParallel so the checkpoint holds the bare model
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(model_path)
    logger.info('epoch {} finished'.format(epoch + 1))
    epoch_finish_time = datetime.now()
    logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))

    return epoch_mean_loss
|
|
|
|
|
|
def validate_epoch(model, validate_dataloader, logger, epoch, args):
    """Compute the mean validation loss for one epoch without updating weights.

    Every batch is evaluated under ``torch.no_grad()`` with the model in eval
    mode. A batch whose RuntimeError message contains "out of memory" is logged
    and skipped. Previously a single such error aborted the whole pass and the
    function returned None. Any other RuntimeError is re-raised.

    Args:
        model: model called as ``model(input_ids, labels=labels)`` returning an
            object with a ``.loss`` tensor; may be wrapped in DataParallel.
        validate_dataloader: iterable of ``(input_ids, labels)`` tensor pairs.
        logger: logger for progress messages.
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        args: namespace providing ``device``.

    Returns:
        float: mean of the per-batch losses over the batches that were
        evaluated, or ``float('inf')`` when no batch could be evaluated.
    """
    logger.info("start validating")
    model.eval()
    device = args.device
    epoch_start_time = datetime.now()
    total_loss = 0
    evaluated_batches = 0

    with torch.no_grad():
        for input_ids, labels in validate_dataloader:
            # catch CUDA out-of-memory errors per batch so one batch cannot abort validation
            try:
                input_ids = input_ids.to(device)
                labels = labels.to(device)
                outputs = model(input_ids, labels=labels)
                # .mean() reduces the per-replica losses gathered by DataParallel
                total_loss += outputs.loss.mean().item()
                evaluated_batches += 1
                del input_ids, outputs
            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    logger.info("WARNING: ran out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    logger.info(str(exception))
                    raise exception

    # average validation loss of this epoch
    if evaluated_batches == 0:
        # inf never beats a recorded best loss, so callers treat it as "no improvement"
        logger.info("WARNING: no validation batch was evaluated")
        epoch_mean_loss = float('inf')
    else:
        epoch_mean_loss = total_loss / evaluated_batches
    logger.info(
        "validate epoch {}: loss {}".format(epoch+1, epoch_mean_loss))
    epoch_finish_time = datetime.now()
    logger.info('time for validating one epoch: {}'.format(epoch_finish_time - epoch_start_time))
    return epoch_mean_loss
|
|
|
|
|
|
def train(model, logger, train_dataset, validate_dataset, args):
    """Train for up to ``args.epochs`` epochs, validating after each one.

    After each epoch, if the validation loss is the lowest seen so far, the model
    is saved to ``save_model_path/min_ppl_model``. When ``args.patience`` is
    non-zero, an EarlyStopping monitor may end training early. Per-epoch train
    and validation losses are logged at the end.

    Args:
        model: model to train (may be wrapped in DataParallel).
        logger: logger for progress messages.
        train_dataset: dataset of 1-D token-id tensors for training.
        validate_dataset: dataset of 1-D token-id tensors for validation.
        args: parsed command-line arguments (see set_args).
    """
    train_dataloader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn,
        drop_last=True
    )
    # Order is irrelevant for a mean validation loss, so no shuffling. Keeping the
    # last partial batch means a validation set smaller than batch_size still
    # produces at least one batch instead of an empty loader.
    validate_dataloader = DataLoader(validate_dataset, batch_size=args.batch_size, shuffle=False,
                                     num_workers=args.num_workers, collate_fn=collate_fn, drop_last=False)
    early_stopping = EarlyStopping(args.patience, verbose=True, save_path=args.save_model_path)
    # total number of optimizer updates, used by the linear LR schedule
    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs
    # transformers.AdamW is deprecated in favour of torch.optim.AdamW. torch's
    # weight_decay defaults to 0.01, so pin it to 0.0 (the transformers default)
    # to keep the optimizer unchanged.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, eps=args.eps, weight_decay=0.0)
    scheduler = transformers.get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    logger.info('starting training')

    # per-epoch training and validation losses
    train_losses, validate_losses = [], []
    # lowest validation loss seen so far
    best_val_loss = float('inf')

    for epoch in range(args.epochs):
        # ========== train ========== #
        train_loss = train_epoch(
            model=model, train_dataloader=train_dataloader,
            optimizer=optimizer, scheduler=scheduler,
            logger=logger, epoch=epoch, args=args)
        train_losses.append(train_loss)

        # ========== validate ========== #
        validate_loss = validate_epoch(
            model=model, validate_dataloader=validate_dataloader,
            logger=logger, epoch=epoch, args=args)
        validate_losses.append(validate_loss)

        # Save the model with the lowest validation loss (perplexity) so far;
        # a lower perplexity does not necessarily mean better generations.
        if validate_loss < best_val_loss:
            best_val_loss = validate_loss
            logger.info('saving current best model for epoch {}'.format(epoch + 1))
            model_path = join(args.save_model_path, 'min_ppl_model')
            os.makedirs(model_path, exist_ok=True)
            # unwrap DataParallel so the checkpoint holds the bare model
            model_to_save = model.module if hasattr(model, 'module') else model
            model_to_save.save_pretrained(model_path)

        # patience == 0 disables early stopping
        if args.patience == 0:
            continue
        early_stopping(validate_loss, model)
        if early_stopping.early_stop:
            logger.info("Early stopping")
            break

    logger.info('training finished')
    logger.info("train_losses:{}".format(train_losses))
    logger.info("validate_losses:{}".format(validate_losses))
|
|
|
|
|
|
def caculate_loss(logit, target, pad_idx, smoothing=True, eps=0.1):
    """Next-token loss over shifted logits/targets, optionally label-smoothed.

    Position t of ``logit`` is scored against token t+1 of ``target``.

    Args:
        logit: logits of shape (..., seq_len, vocab); leading dimensions are
            flattened after the shift.
        target: token ids of shape (..., seq_len), matching ``logit``.
        pad_idx: target value excluded from the loss. It may be negative, e.g.
            -100, the value collate_fn pads labels with.
        smoothing: if True, use label smoothing; otherwise plain cross-entropy.
        eps: probability mass spread uniformly over the non-target classes
            when ``smoothing`` is True (previously hard-coded to 0.1).

    Returns:
        Scalar tensor: mean loss over positions whose shifted target != pad_idx.
    """
    vocab_size = logit.size(-1)
    logit = logit[..., :-1, :].contiguous().view(-1, vocab_size)
    target = target[..., 1:].contiguous().view(-1)

    if not smoothing:
        return F.cross_entropy(logit, target, ignore_index=pad_idx)

    non_pad_mask = target.ne(pad_idx)
    # scatter() needs valid class indices, so a negative pad_idx such as -100
    # would raise here. Padded positions are masked out of the loss below, so
    # any in-range index (0) is a safe placeholder for them.
    scatter_target = target.masked_fill(~non_pad_mask, 0)
    one_hot = torch.zeros_like(logit).scatter(1, scatter_target.view(-1, 1), 1)
    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (vocab_size - 1)
    log_prb = F.log_softmax(logit, dim=1)

    loss = -(one_hot * log_prb).sum(dim=1)
    return loss.masked_select(non_pad_mask).mean()
|
|
|
|
|
|
def calculate_acc(logit, labels, ignore_index=-100):
    """Count correct next-token predictions.

    Position t of ``logit`` predicts token t+1 of ``labels``; positions whose
    shifted label equals ``ignore_index`` are not counted.

    Returns:
        (n_correct, n_word): ints — correctly predicted tokens and scored tokens.
    """
    vocab_size = logit.size(-1)
    shifted_logits = logit[..., :-1, :].contiguous().view(-1, vocab_size)
    shifted_labels = labels[..., 1:].contiguous().view(-1)

    # index of the highest-scoring token at every position
    predictions = shifted_logits.max(dim=-1).indices
    # True where the label is a real token, False where it is ignore_index
    scored = shifted_labels != ignore_index

    hits = (predictions == shifted_labels)[scored]
    return hits.sum().item(), scored.sum().item()
|
|
|
|
|
|
def main():
    """Script entry point: parse arguments, build tokenizer and model, load data, train."""
    args = set_args()

    # Restrict which GPUs are visible. This must happen before CUDA is first
    # used; remember the requested ids because args.device is overwritten below.
    visible_devices = args.device
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

    if args.batch_size < 2048 and args.warmup_steps <= 4000:
        print('[Warning] The warmup steps may be not enough.\n'
              '(sz_b, warmup) = (2048, 4000) is the official setting.\n'
              'Using smaller batch w/o longer warmup may cause '
              'the warmup stage ends with only little data trained.')

    logger = create_logger(args)
    # Use the GPU only when requested and actually available.
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda:0' if args.cuda else 'cpu'
    args.device = device
    logger.info('using device:{}'.format(device))

    # tokenizer and the special-token ids stored on args
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    args.sep_id = tokenizer.sep_token_id
    args.pad_id = tokenizer.pad_token_id
    args.cls_id = tokenizer.cls_token_id

    # output directory for checkpoints (intermediate directories included)
    os.makedirs(args.save_model_path, exist_ok=True)

    # build the model: load a pretrained checkpoint or initialise from a config
    if args.pretrained_model:
        model = GPT2LMHeadModel.from_pretrained(args.pretrained_model)
    else:
        model_config = GPT2Config.from_json_file(args.model_config)
        model = GPT2LMHeadModel(config=model_config)
    model = model.to(device)
    logger.info('model config:\n{}'.format(model.config.to_json_string()))
    # Explicit check rather than assert, which is stripped under `python -O`.
    if model.config.vocab_size != tokenizer.vocab_size:
        raise ValueError(
            'model vocab_size ({}) does not match tokenizer vocab_size ({})'.format(
                model.config.vocab_size, tokenizer.vocab_size))

    # data-parallel training across all visible GPUs
    if args.cuda and torch.cuda.device_count() > 1:
        model = DataParallel(model).cuda()
        # model = BalancedDataParallel(args.gpu0_bsz, model, dim=0).cuda()
        logger.info("use GPU {} to train".format(visible_devices))

    num_parameters = sum(parameter.numel() for parameter in model.parameters())
    logger.info('number of model parameters: {}'.format(num_parameters))

    # record the full argument set
    logger.info("args:{}".format(args))

    # ========= Loading Dataset ========= #
    train_dataset, validate_dataset = load_dataset(logger, args)

    train(model, logger, train_dataset, validate_dataset, args)
|
|
|
|
|
|
# Start training only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
|