初始化
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
vocab/
|
||||
download/
|
||||
data/
|
||||
@@ -0,0 +1,2 @@
|
||||
vocab文件太大了,没添加到git里,请去release下载。
|
||||
原始语料数据清洗得不是很干净,加上有隐私顾虑,就不上传了。
|
||||
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"_name_or_path": "model/epoch29",
|
||||
"activation_function": "gelu_new",
|
||||
"architectures": [
|
||||
"GPT2LMHeadModel"
|
||||
],
|
||||
"attn_pdrop": 0.1,
|
||||
"bos_token_id": 50256,
|
||||
"embd_pdrop": 0.1,
|
||||
"eos_token_id": 50256,
|
||||
"gradient_checkpointing": false,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_epsilon": 1e-05,
|
||||
"model_type": "gpt2",
|
||||
"n_ctx": 1024,
|
||||
"n_embd": 768,
|
||||
"n_head": 12,
|
||||
"n_inner": null,
|
||||
"n_layer": 12,
|
||||
"n_positions": 1024,
|
||||
"output_past": true,
|
||||
"resid_pdrop": 0.1,
|
||||
"summary_activation": null,
|
||||
"summary_first_dropout": 0.1,
|
||||
"summary_proj_to_labels": true,
|
||||
"summary_type": "cls_index",
|
||||
"summary_use_proj": true,
|
||||
"task_specific_params": {
|
||||
"text-generation": {
|
||||
"do_sample": true,
|
||||
"max_length": 400
|
||||
}
|
||||
},
|
||||
"tokenizer_class": "BertTokenizer",
|
||||
"transformers_version": "4.2.0",
|
||||
"use_cache": true,
|
||||
"vocab_size": 13317
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
import torch
|
||||
import os
|
||||
from transformers import GPT2LMHeadModel
|
||||
from transformers import BertTokenizerFast
|
||||
import torch.nn.functional as F
|
||||
|
||||
import random
|
||||
import asyncio
|
||||
import websockets
|
||||
import json
|
||||
|
||||
from cqhttp_settings import *
|
||||
from interact import set_args, create_logger, top_k_top_p_filtering
|
||||
|
||||
PAD = '[PAD]'
|
||||
pad_id = 0
|
||||
|
||||
|
||||
def init_model():
|
||||
args = set_args()
|
||||
logger = create_logger(args)
|
||||
# 当用户使用GPU,并且GPU可用时
|
||||
args.cuda = torch.cuda.is_available() and not args.no_cuda
|
||||
device = 'cuda' if args.cuda else 'cpu'
|
||||
logger.info('using device:{}'.format(device))
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
|
||||
tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
|
||||
# tokenizer = BertTokenizer(vocab_file=args.voca_path)
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_path)
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
return args, device, tokenizer, model
|
||||
|
||||
|
||||
def process(text, args, device, tokenizer, model):
|
||||
text_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
history.append(text_ids)
|
||||
input_ids = [tokenizer.cls_token_id] # 每个input以[CLS]为开头
|
||||
|
||||
for history_id, history_utr in enumerate(history[-args.max_history_len:]):
|
||||
input_ids.extend(history_utr)
|
||||
input_ids.append(tokenizer.sep_token_id)
|
||||
input_ids = torch.tensor(input_ids).long().to(device)
|
||||
input_ids = input_ids.unsqueeze(0)
|
||||
response = [] # 根据context,生成的response
|
||||
# 最多生成max_len个token
|
||||
for _ in range(args.max_len):
|
||||
outputs = model(input_ids=input_ids)
|
||||
logits = outputs.logits
|
||||
next_token_logits = logits[0, -1, :]
|
||||
# 对于已生成的结果generated中的每个token添加一个重复惩罚项,降低其生成概率
|
||||
for id in set(response):
|
||||
next_token_logits[id] /= args.repetition_penalty
|
||||
next_token_logits = next_token_logits / args.temperature
|
||||
# 对于[UNK]的概率设为无穷小,也就是说模型的预测结果不可能是[UNK]这个token
|
||||
next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
|
||||
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
|
||||
# torch.multinomial表示从候选集合中无放回地进行抽取num_samples个元素,权重越高,抽到的几率越高,返回元素的下标
|
||||
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
||||
if next_token == tokenizer.sep_token_id: # 遇到[SEP]则表明response生成结束
|
||||
break
|
||||
response.append(next_token.item())
|
||||
input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)
|
||||
# his_text = tokenizer.convert_ids_to_tokens(curr_input_tensor.tolist())
|
||||
# print("his_text:{}".format(his_text))
|
||||
history.append(response)
|
||||
|
||||
return "".join(tokenizer.convert_ids_to_tokens(response))
|
||||
|
||||
|
||||
# 存储聊天记录,每个utterance以token的id的形式进行存储
|
||||
history = []
|
||||
|
||||
|
||||
def handle_event(evjson):
|
||||
event = json.loads(evjson)
|
||||
from_id = ""
|
||||
msg_type = ""
|
||||
req = ""
|
||||
# print(event)
|
||||
rand_response_possibility = random.randint(0, 100)
|
||||
if 'message_type' in event and 'raw_message' in event:
|
||||
msg_type = event['message_type']
|
||||
msg_recv = event['raw_message']
|
||||
print("收到消息: ", msg_recv)
|
||||
if event['message_type'] == "group" and event['group_id'] in event_response_settings_group_enabled:
|
||||
# 群消息
|
||||
print("回复几率(?): {} < {}".format(rand_response_possibility, event_response_settings_group_rate))
|
||||
if rand_response_possibility < event_response_settings_group_rate:
|
||||
from_id = event['group_id']
|
||||
req = msg_recv
|
||||
|
||||
elif event['message_type'] == "private":
|
||||
# 私聊消息
|
||||
print("回复几率(?): {} < {}".format(rand_response_possibility, event_response_settings_private_rate))
|
||||
if rand_response_possibility < event_response_settings_private_rate:
|
||||
from_id = event['user_id']
|
||||
req = msg_recv
|
||||
|
||||
return from_id, msg_type, req
|
||||
|
||||
|
||||
async def send_msg(cqhttp, source, msg_type, text):
|
||||
print("Msg to {}({}): {}".format(source, msg_type, text))
|
||||
data_send = {}
|
||||
if msg_type == "group":
|
||||
# 群消息
|
||||
data_send = {
|
||||
'action': "send_group_msg",
|
||||
'params': {
|
||||
'group_id': source,
|
||||
'message': text,
|
||||
},
|
||||
}
|
||||
elif msg_type == "private":
|
||||
# 私聊消息
|
||||
data_send = {
|
||||
'action': "send_private_msg",
|
||||
'params': {
|
||||
'user_id': source,
|
||||
'message': text,
|
||||
},
|
||||
}
|
||||
|
||||
await cqhttp.send(json.dumps(data_send))
|
||||
|
||||
|
||||
async def init_cqhttp_ws(args, device, tokenizer, model):
|
||||
print("Initializing CQHTTP WebSocket...")
|
||||
async with websockets.connect(cqhttp_ws_addr + "?access_token=" + cqhttp_ws_access_token) as cqhttp:
|
||||
while True:
|
||||
evjson = await cqhttp.recv()
|
||||
|
||||
source, msg_type, req = handle_event(evjson)
|
||||
|
||||
print(req)
|
||||
|
||||
if req:
|
||||
res = process(req, args, device, tokenizer, model)
|
||||
if res:
|
||||
print("Q: {}\nA: {}\n".format(req, res))
|
||||
await send_msg(cqhttp, source, msg_type, res)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args, device, tokenizer, model = init_model()
|
||||
asyncio.get_event_loop().run_until_complete(init_cqhttp_ws(args, device, tokenizer, model))
|
||||
@@ -0,0 +1,6 @@
|
||||
cqhttp_ws_addr = "ws://127.0.0.1:16700"
|
||||
cqhttp_ws_access_token = "WPlJfomObZADDQBiUneH7nv1HfyReDcY"
|
||||
|
||||
event_response_settings_group_rate = 100 # 0~100
|
||||
event_response_settings_group_enabled = [1060806176]
|
||||
event_response_settings_private_rate = 100 # 0~100
|
||||
@@ -0,0 +1,100 @@
|
||||
from torch.nn.parallel import DataParallel
|
||||
import torch
|
||||
from torch.nn.parallel._functions import Scatter
|
||||
from torch.nn.parallel.parallel_apply import parallel_apply
|
||||
|
||||
|
||||
def scatter(inputs, target_gpus, chunk_sizes, dim=0):
|
||||
r"""
|
||||
Slices tensors into approximately equal chunks and
|
||||
distributes them across given GPUs. Duplicates
|
||||
references to objects that are not tensors.
|
||||
"""
|
||||
def scatter_map(obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
try:
|
||||
return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
|
||||
except:
|
||||
print('obj', obj.size())
|
||||
print('dim', dim)
|
||||
print('chunk_sizes', chunk_sizes)
|
||||
quit()
|
||||
if isinstance(obj, tuple) and len(obj) > 0:
|
||||
return list(zip(*map(scatter_map, obj)))
|
||||
if isinstance(obj, list) and len(obj) > 0:
|
||||
return list(map(list, zip(*map(scatter_map, obj))))
|
||||
if isinstance(obj, dict) and len(obj) > 0:
|
||||
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
|
||||
return [obj for targets in target_gpus]
|
||||
|
||||
# After scatter_map is called, a scatter_map cell will exist. This cell
|
||||
# has a reference to the actual function scatter_map, which has references
|
||||
# to a closure that has a reference to the scatter_map cell (because the
|
||||
# fn is recursive). To avoid this reference cycle, we set the function to
|
||||
# None, clearing the cell
|
||||
try:
|
||||
return scatter_map(inputs)
|
||||
finally:
|
||||
scatter_map = None
|
||||
|
||||
|
||||
def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
|
||||
r"""Scatter with support for kwargs dictionary"""
|
||||
inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
|
||||
kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
|
||||
if len(inputs) < len(kwargs):
|
||||
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
|
||||
elif len(kwargs) < len(inputs):
|
||||
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
|
||||
inputs = tuple(inputs)
|
||||
kwargs = tuple(kwargs)
|
||||
return inputs, kwargs
|
||||
|
||||
|
||||
class BalancedDataParallel(DataParallel):
|
||||
def __init__(self, gpu0_bsz, *args, **kwargs):
|
||||
self.gpu0_bsz = gpu0_bsz
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
if not self.device_ids:
|
||||
return self.module(*inputs, **kwargs)
|
||||
if self.gpu0_bsz == 0:
|
||||
device_ids = self.device_ids[1:]
|
||||
else:
|
||||
device_ids = self.device_ids
|
||||
inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
|
||||
# print('len(inputs)1: ', str(len(inputs)))
|
||||
# print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))
|
||||
if len(self.device_ids) == 1:
|
||||
return self.module(*inputs[0], **kwargs[0])
|
||||
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
|
||||
if self.gpu0_bsz == 0:
|
||||
replicas = replicas[1:]
|
||||
outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
|
||||
return self.gather(outputs, self.output_device)
|
||||
|
||||
def parallel_apply(self, replicas, device_ids, inputs, kwargs):
|
||||
return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])
|
||||
|
||||
def scatter(self, inputs, kwargs, device_ids):
|
||||
bsz = inputs[0].size(self.dim)
|
||||
num_dev = len(self.device_ids)
|
||||
gpu0_bsz = self.gpu0_bsz
|
||||
bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
|
||||
if gpu0_bsz < bsz_unit:
|
||||
chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
|
||||
delta = bsz - sum(chunk_sizes)
|
||||
for i in range(delta):
|
||||
chunk_sizes[i + 1] += 1
|
||||
if gpu0_bsz == 0:
|
||||
chunk_sizes = chunk_sizes[1:]
|
||||
else:
|
||||
return super().scatter(inputs, kwargs, device_ids)
|
||||
|
||||
# print('bsz: ', bsz)
|
||||
# print('num_dev: ', num_dev)
|
||||
# print('gpu0_bsz: ', gpu0_bsz)
|
||||
# print('bsz_unit: ', bsz_unit)
|
||||
# print('chunk_sizes: ', chunk_sizes)
|
||||
return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
|
||||
|
||||
class MyDataset(Dataset):
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, input_list, max_len):
|
||||
self.input_list = input_list
|
||||
self.max_len = max_len
|
||||
|
||||
def __getitem__(self, index):
|
||||
input_ids = self.input_list[index]
|
||||
input_ids = input_ids[:self.max_len]
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
||||
return input_ids
|
||||
|
||||
def __len__(self):
|
||||
return len(self.input_list)
|
||||
@@ -0,0 +1,66 @@
|
||||
import argparse
|
||||
from os.path import join
|
||||
from collections import Counter
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.pyplot import MultipleLocator
|
||||
|
||||
|
||||
def generate_subset():
|
||||
"""
|
||||
用于生成训练子集
|
||||
:return:
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--raw_data_path', default='data/train.txt', type=str, required=False, help='原始训练语料')
|
||||
parser.add_argument('--subset_size', default=1000000, type=int, required=False, help='要获取的对话数据子集的规模')
|
||||
parser.add_argument('--subset_data_path', default='data', type=str, required=False,
|
||||
help='数据子集文件路径,指定文件的父目录')
|
||||
args = parser.parse_args()
|
||||
with open(args.raw_data_path, "r", encoding="utf8") as f:
|
||||
data = f.read()
|
||||
dialogues = data.split("\n\n")
|
||||
subset_size = min(len(dialogues), args.subset_size)
|
||||
|
||||
with open(join(args.subset_data_path, "train_{}w.txt".format(int(subset_size / 10000))), "w", encoding="utf8") as f:
|
||||
print("generating subset,please wait a few minutes")
|
||||
for dialogue_index, dialogue in enumerate(dialogues):
|
||||
if dialogue_index >= subset_size:
|
||||
break
|
||||
for utterance in dialogue.split("\n"):
|
||||
f.writelines(utterance + "\n")
|
||||
f.writelines("\n")
|
||||
|
||||
|
||||
def compute_dialogue_length():
|
||||
"""
|
||||
查看聊天语料中的dialogue的长度分布
|
||||
:return:
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--raw_data_path', default='data/train.txt', type=str, required=False, help='原始训练语料')
|
||||
args = parser.parse_args()
|
||||
with open(args.raw_data_path, "r", encoding="utf8") as f:
|
||||
data = f.read()
|
||||
dialogues = data.split("\n\n")
|
||||
# 统计各个dialogue的长度
|
||||
dialogues_lengths = [len(dialogue.replace("\n", "")) for dialogue in dialogues]
|
||||
counter = Counter(dialogues_lengths) # {label:sum(label)}
|
||||
dialogue_length_arr = list(counter)
|
||||
num_arr = [counter[element] for element in list(counter)]
|
||||
print(counter[300])
|
||||
|
||||
x_major_locator = MultipleLocator(100) # MultipleLocator用于设置刻度间隔
|
||||
# y_major_locator = MultipleLocator(20000)
|
||||
ax = plt.gca() # ax为两条坐标轴的实例
|
||||
ax.xaxis.set_major_locator(x_major_locator) # 把x轴的主刻度设置为10的倍数
|
||||
# ax.yaxis.set_major_locator(y_major_locator)
|
||||
|
||||
plt.xlabel('dialogue length')
|
||||
plt.ylabel('number of dialogue')
|
||||
# plt.plot(dialogue_length_arr, num_arr, c='green')
|
||||
plt.scatter(dialogue_length_arr, num_arr)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
generate_subset()
|
||||
+167
@@ -0,0 +1,167 @@
|
||||
import torch
|
||||
import os
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from transformers import GPT2LMHeadModel
|
||||
from transformers import BertTokenizerFast
|
||||
import torch.nn.functional as F
|
||||
|
||||
PAD = '[PAD]'
|
||||
pad_id = 0
|
||||
|
||||
|
||||
def set_args():
|
||||
"""
|
||||
Sets up the arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--device', default='0', type=str, required=False, help='生成设备')
|
||||
parser.add_argument('--temperature', default=1, type=float, required=False, help='生成的temperature')
|
||||
parser.add_argument('--topk', default=8, type=int, required=False, help='最高k选1')
|
||||
parser.add_argument('--topp', default=0, type=float, required=False, help='最高积累概率')
|
||||
# parser.add_argument('--model_config', default='config/model_config_dialogue_small.json', type=str, required=False,
|
||||
# help='模型参数')
|
||||
parser.add_argument('--log_path', default='data/interact.log', type=str, required=False, help='interact日志存放位置')
|
||||
parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False, help='选择词库')
|
||||
parser.add_argument('--model_path', default='model/epoch40', type=str, required=False, help='对话模型路径')
|
||||
parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径")
|
||||
parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False,
|
||||
help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
|
||||
# parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
|
||||
parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
|
||||
parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
|
||||
parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def create_logger(args):
|
||||
"""
|
||||
将日志输出到日志文件和控制台
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# 创建一个handler,用于写入日志文件
|
||||
file_handler = logging.FileHandler(
|
||||
filename=args.log_path)
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 创建一个handler,用于将日志输出到控制台
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.DEBUG)
|
||||
console.setFormatter(formatter)
|
||||
logger.addHandler(console)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
|
||||
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
Args:
|
||||
logits: logits distribution shape (vocab size)
|
||||
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
||||
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
||||
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""
|
||||
assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear
|
||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||
if top_k > 0:
|
||||
# Remove all tokens with a probability less than the last token of the top-k
|
||||
# torch.topk()返回最后一维最大的top_k个元素,返回值为二维(values,indices)
|
||||
# ...表示其他维度由计算机自行推断
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = filter_value # 对于topk之外的其他元素的logits值设为负无穷
|
||||
|
||||
if top_p > 0.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True) # 对logits进行递减排序
|
||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
|
||||
# Remove tokens with cumulative probability above the threshold
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
# Shift the indices to the right to keep also the first token above the threshold
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||
logits[indices_to_remove] = filter_value
|
||||
return logits
|
||||
|
||||
|
||||
def main():
|
||||
args = set_args()
|
||||
logger = create_logger(args)
|
||||
# 当用户使用GPU,并且GPU可用时
|
||||
args.cuda = torch.cuda.is_available() and not args.no_cuda
|
||||
device = 'cuda' if args.cuda else 'cpu'
|
||||
logger.info('using device:{}'.format(device))
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
|
||||
tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
|
||||
# tokenizer = BertTokenizer(vocab_file=args.voca_path)
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_path)
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
if args.save_samples_path:
|
||||
if not os.path.exists(args.save_samples_path):
|
||||
os.makedirs(args.save_samples_path)
|
||||
samples_file = open(args.save_samples_path + '/samples.txt', 'a', encoding='utf8')
|
||||
samples_file.write("聊天记录{}:\n".format(datetime.now()))
|
||||
# 存储聊天记录,每个utterance以token的id的形式进行存储
|
||||
history = []
|
||||
print('开始和chatbot聊天,输入CTRL + Z以退出')
|
||||
|
||||
while True:
|
||||
try:
|
||||
text = input("user:")
|
||||
# text = "你好"
|
||||
if args.save_samples_path:
|
||||
samples_file.write("user:{}\n".format(text))
|
||||
text_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
history.append(text_ids)
|
||||
input_ids = [tokenizer.cls_token_id] # 每个input以[CLS]为开头
|
||||
|
||||
for history_id, history_utr in enumerate(history[-args.max_history_len:]):
|
||||
input_ids.extend(history_utr)
|
||||
input_ids.append(tokenizer.sep_token_id)
|
||||
input_ids = torch.tensor(input_ids).long().to(device)
|
||||
input_ids = input_ids.unsqueeze(0)
|
||||
response = [] # 根据context,生成的response
|
||||
# 最多生成max_len个token
|
||||
for _ in range(args.max_len):
|
||||
outputs = model(input_ids=input_ids)
|
||||
logits = outputs.logits
|
||||
next_token_logits = logits[0, -1, :]
|
||||
# 对于已生成的结果generated中的每个token添加一个重复惩罚项,降低其生成概率
|
||||
for id in set(response):
|
||||
next_token_logits[id] /= args.repetition_penalty
|
||||
next_token_logits = next_token_logits / args.temperature
|
||||
# 对于[UNK]的概率设为无穷小,也就是说模型的预测结果不可能是[UNK]这个token
|
||||
next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
|
||||
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
|
||||
# torch.multinomial表示从候选集合中无放回地进行抽取num_samples个元素,权重越高,抽到的几率越高,返回元素的下标
|
||||
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
|
||||
if next_token == tokenizer.sep_token_id: # 遇到[SEP]则表明response生成结束
|
||||
break
|
||||
response.append(next_token.item())
|
||||
input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)
|
||||
# his_text = tokenizer.convert_ids_to_tokens(curr_input_tensor.tolist())
|
||||
# print("his_text:{}".format(his_text))
|
||||
history.append(response)
|
||||
text = tokenizer.convert_ids_to_tokens(response)
|
||||
print("chatbot:" + "".join(text))
|
||||
if args.save_samples_path:
|
||||
samples_file.write("chatbot:{}\n".format("".join(text)))
|
||||
except KeyboardInterrupt:
|
||||
if args.save_samples_path:
|
||||
samples_file.close()
|
||||
break
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,96 @@
|
||||
from transformers import BertTokenizerFast
|
||||
import argparse
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
|
||||
def create_logger(log_path):
|
||||
"""
|
||||
将日志输出到日志文件和控制台
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# 创建一个handler,用于写入日志文件
|
||||
file_handler = logging.FileHandler(
|
||||
filename=log_path)
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 创建一个handler,用于将日志输出到控制台
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.DEBUG)
|
||||
console.setFormatter(formatter)
|
||||
logger.addHandler(console)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def preprocess():
|
||||
"""
|
||||
对原始语料进行tokenize,将每段对话处理成如下形式:"[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]"
|
||||
"""
|
||||
# 设置参数
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False,
|
||||
help='词表路径')
|
||||
parser.add_argument('--log_path', default='data/preprocess.log', type=str, required=False, help='训练日志存放位置')
|
||||
parser.add_argument('--train_path', default='data/train.txt', type=str, required=False, help='训练日志存放位置')
|
||||
parser.add_argument('--save_path', default='data/train.pkl', type=str, required=False, help='tokenize的训练数据集')
|
||||
args = parser.parse_args()
|
||||
|
||||
# 初始化日志对象
|
||||
logger = create_logger(args.log_path)
|
||||
|
||||
# 初始化tokenizer
|
||||
tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
|
||||
sep_id = tokenizer.sep_token_id
|
||||
cls_id = tokenizer.cls_token_id
|
||||
logger.info("preprocessing data,data path:{}, save path:{}".format(args.train_path, args.save_path))
|
||||
|
||||
# 读取训练数据集
|
||||
with open(args.train_path, 'rb') as f:
|
||||
data = f.read().decode("utf-8")
|
||||
|
||||
# 需要区分linux和windows环境下的换行符
|
||||
if "\r\n" in data:
|
||||
train_data = data.split("\r\n\r\n")
|
||||
else:
|
||||
train_data = data.split("\n\n")
|
||||
logger.info("there are {} dialogue in dataset".format(len(train_data)))
|
||||
|
||||
# 开始进行tokenize
|
||||
# 保存所有的对话数据,每条数据的格式为:"[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]"
|
||||
dialogue_len = [] # 记录所有对话tokenize之后的长度,用于统计中位数与均值
|
||||
dialogue_list = []
|
||||
with open(args.save_path, "w", encoding="utf-8") as f:
|
||||
for index, dialogue in enumerate(tqdm(train_data)):
|
||||
if "\r\n" in data:
|
||||
utterances = dialogue.split("\r\n")
|
||||
else:
|
||||
utterances = dialogue.split("\n")
|
||||
|
||||
input_ids = [cls_id] # 每个dialogue以[CLS]开头
|
||||
for utterance in utterances:
|
||||
input_ids += tokenizer.encode(utterance, add_special_tokens=False)
|
||||
input_ids.append(sep_id) # 每个utterance之后添加[SEP],表示utterance结束
|
||||
dialogue_len.append(len(input_ids))
|
||||
dialogue_list.append(input_ids)
|
||||
len_mean = np.mean(dialogue_len)
|
||||
len_median = np.median(dialogue_len)
|
||||
len_max = np.max(dialogue_len)
|
||||
with open(args.save_path, "wb") as f:
|
||||
pickle.dump(dialogue_list, f)
|
||||
logger.info("finish preprocessing data,the result is stored in {}".format(args.save_path))
|
||||
logger.info("mean of dialogue len:{},median of dialogue len:{},max len:{}".format(len_mean, len_median, len_max))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
preprocess()
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EarlyStopping:
|
||||
"""Early stops the training if validation loss doesn't improve after a given patience."""
|
||||
|
||||
def __init__(self, patience=7, verbose=False, delta=0, save_path="."):
|
||||
"""
|
||||
Args:
|
||||
patience (int): How long to wait after last time validation loss improved.
|
||||
Default: 7
|
||||
verbose (bool): If True, prints a message for each validation loss improvement.
|
||||
Default: False
|
||||
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
|
||||
Default: 0
|
||||
"""
|
||||
self.patience = patience
|
||||
self.verbose = verbose
|
||||
self.counter = 0
|
||||
self.best_score = None
|
||||
self.early_stop = False
|
||||
self.val_loss_min = np.Inf
|
||||
self.delta = delta
|
||||
self.save_path = save_path
|
||||
|
||||
def __call__(self, val_loss, model):
|
||||
|
||||
score = -val_loss
|
||||
|
||||
if self.best_score is None:
|
||||
self.best_score = score
|
||||
self.save_checkpoint(val_loss, model)
|
||||
elif score < self.best_score + self.delta:
|
||||
self.counter += 1
|
||||
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
|
||||
if self.counter >= self.patience:
|
||||
self.early_stop = True
|
||||
else:
|
||||
self.best_score = score
|
||||
self.save_checkpoint(val_loss, model)
|
||||
self.counter = 0
|
||||
|
||||
def save_checkpoint(self, val_loss, model):
|
||||
'''Saves model when validation loss decrease.'''
|
||||
if self.verbose:
|
||||
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
|
||||
# save_path = join(self.save_path, "best_model")
|
||||
# if not os.path.exists(save_path):
|
||||
# os.mkdir(save_path)
|
||||
# model_to_save = model.module if hasattr(model, 'module') else model
|
||||
# model_to_save.save_pretrained(save_path)
|
||||
self.val_loss_min = val_loss
|
||||
@@ -0,0 +1,40 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
files = os.walk("download/QQMsg")
|
||||
|
||||
outputFile = open("data/train_qq.txt", "w", encoding="utf-8")
|
||||
|
||||
for path, dir_list, file_list in files:
|
||||
for file_name in file_list:
|
||||
print(os.path.join(path, file_name))
|
||||
f = open(os.path.join(path, file_name), "r", encoding="utf-8")
|
||||
lines = f.readlines()
|
||||
stat = 0 # 0: ready to parse time / 1: ready to parse log
|
||||
lastTime = datetime.strptime("1970-1-1 00:00:00", "%Y-%m-%d %H:%M:%S")
|
||||
for i in range(8, len(lines)):
|
||||
raw = lines[i].replace("\r\n", "").replace("\n", "")
|
||||
|
||||
# 这一行是时间
|
||||
timeStrs = raw.split(' ', 2)
|
||||
try:
|
||||
# 这一行是时间
|
||||
if timeStrs[0][0] == '2':
|
||||
tsStr = timeStrs[0] + " " + timeStrs[1]
|
||||
else:
|
||||
tsStr = timeStrs[1] + " " + timeStrs[2]
|
||||
ts = datetime.strptime(tsStr, "%Y-%m-%d %H:%M:%S")
|
||||
if ((ts - lastTime).seconds > 120) or ((ts - lastTime).seconds < 0):
|
||||
# 间隔2分钟以上,认为是不同的对话
|
||||
outputFile.write("\n")
|
||||
lastTime = ts
|
||||
except (IndexError, ValueError) as e:
|
||||
# 这一行是消息
|
||||
msg = raw.replace("[图片]", "").replace("[表情]", "").replace("[合并转发]请使用手机QQ最新版本查看", "")
|
||||
if msg != "":
|
||||
# 是有效行
|
||||
outputFile.write(msg + "\n")
|
||||
|
||||
f.close()
|
||||
|
||||
outputFile.close()
|
||||
@@ -0,0 +1,6 @@
|
||||
transformers~=4.6.1
|
||||
torch~=1.8.1
|
||||
tqdm~=4.61.0
|
||||
numpy~=1.20.3
|
||||
matplotlib~=3.4.2
|
||||
websockets~=9.1
|
||||
@@ -0,0 +1,417 @@
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import os
|
||||
from torch.utils.data import DataLoader
|
||||
from os.path import join
|
||||
from torch.nn import DataParallel
|
||||
import transformers
|
||||
import pickle
|
||||
from pytorchtools import EarlyStopping
|
||||
from transformers import GPT2LMHeadModel, GPT2Config
|
||||
from transformers import BertTokenizerFast
|
||||
import torch.nn.utils.rnn as rnn_utils
|
||||
from dataset import MyDataset
|
||||
|
||||
|
||||
def set_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--device', default='3', type=str, required=False, help='设置使用哪些显卡')
|
||||
parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行训练')
|
||||
parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False,
|
||||
help='词表路径')
|
||||
parser.add_argument('--model_config', default='config/config.json', type=str, required=False,
|
||||
help='设置模型参数')
|
||||
parser.add_argument('--train_path', default='data/train.pkl', type=str, required=False, help='训练集路径')
|
||||
parser.add_argument('--max_len', default=150, type=int, required=False, help='训练时,输入数据的最大长度')
|
||||
|
||||
parser.add_argument('--log_path', default='data/train.log', type=str, required=False, help='训练日志存放位置')
|
||||
parser.add_argument('--log', default=True, help="是否记录日志")
|
||||
parser.add_argument('--ignore_index', default=-100, type=int, required=False, help='对于ignore_index的label token不计算梯度')
|
||||
# parser.add_argument('--input_len', default=200, type=int, required=False, help='输入的长度')
|
||||
parser.add_argument('--epochs', default=100, type=int, required=False, help='训练的最大轮次')
|
||||
parser.add_argument('--batch_size', default=4, type=int, required=False, help='训练的batch size')
|
||||
parser.add_argument('--gpu0_bsz', default=10, type=int, required=False, help='0号卡的batch size')
|
||||
parser.add_argument('--lr', default=2.6e-5, type=float, required=False, help='学习率')
|
||||
parser.add_argument('--eps', default=1.0e-09, type=float, required=False, help='衰减率')
|
||||
parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss')
|
||||
parser.add_argument('--gradient_accumulation_steps', default=4, type=int, required=False, help='梯度积累')
|
||||
parser.add_argument('--max_grad_norm', default=2.0, type=float, required=False)
|
||||
parser.add_argument('--save_model_path', default='model', type=str, required=False,
|
||||
help='模型输出路径')
|
||||
parser.add_argument('--pretrained_model', default='', type=str, required=False,
|
||||
help='预训练的模型的路径')
|
||||
# parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
|
||||
parser.add_argument('--num_workers', type=int, default=0, help="dataloader加载数据时使用的线程数量")
|
||||
parser.add_argument('--patience', type=int, default=0, help="用于early stopping,设为0时,不进行early stopping.early stop得到的模型的生成效果不一定会更好。")
|
||||
parser.add_argument('--warmup_steps', type=int, default=4000, help='warm up步数')
|
||||
# parser.add_argument('--label_smoothing', default=True, action='store_true', help='是否进行标签平滑')
|
||||
parser.add_argument('--val_num', type=int, default=8000, help='验证集大小')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def create_logger(args):
|
||||
"""
|
||||
将日志输出到日志文件和控制台
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
# 创建一个handler,用于写入日志文件
|
||||
file_handler = logging.FileHandler(
|
||||
filename=args.log_path)
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(logging.INFO)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
# 创建一个handler,用于将日志输出到控制台
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.DEBUG)
|
||||
console.setFormatter(formatter)
|
||||
logger.addHandler(console)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
input_ids = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=0)
|
||||
labels = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=-100)
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
# def padding_batch(data_list, pad_id):
|
||||
# """
|
||||
# 使用pad_id将data_list的每条数据,填充至data_list中最长的长度
|
||||
# :param data_list:
|
||||
# :param pad_id:
|
||||
# :return:
|
||||
# """
|
||||
# # 统计data_list中的最大长度
|
||||
# max_len = 0
|
||||
# for data in data_list:
|
||||
# max_len = max_len if max_len > len(data) else len(data)
|
||||
#
|
||||
# # 对数据进行padding
|
||||
# new_data_list = []
|
||||
# for data in data_list:
|
||||
# new_data = data + [pad_id] * (max_len - len(data))
|
||||
# new_data_list.append(new_data)
|
||||
# return new_data_list
|
||||
|
||||
|
||||
def load_dataset(logger, args):
|
||||
"""
|
||||
加载训练集和验证集
|
||||
"""
|
||||
logger.info("loading training dataset and validating dataset")
|
||||
train_path = args.train_path
|
||||
|
||||
with open(train_path, "rb") as f:
|
||||
input_list = pickle.load(f)
|
||||
|
||||
# 划分训练集与验证集
|
||||
val_num = args.val_num
|
||||
input_list_train = input_list[val_num:]
|
||||
input_list_val = input_list[:val_num]
|
||||
# test
|
||||
# input_list_train = input_list_train[:24]
|
||||
# input_list_val = input_list_val[:24]
|
||||
|
||||
train_dataset = MyDataset(input_list_train, args.max_len)
|
||||
val_dataset = MyDataset(input_list_val, args.max_len)
|
||||
|
||||
return train_dataset, val_dataset
|
||||
|
||||
|
||||
def train_epoch(model, train_dataloader, optimizer, scheduler, logger,
|
||||
epoch, args):
|
||||
model.train()
|
||||
device = args.device
|
||||
# pad_id = args.pad_id
|
||||
# sep_id = args.sep_id
|
||||
ignore_index = args.ignore_index
|
||||
epoch_start_time = datetime.now()
|
||||
total_loss = 0 # 记录下整个epoch的loss的总和
|
||||
|
||||
# epoch_correct_num:每个epoch中,output预测正确的word的数量
|
||||
# epoch_total_num: 每个epoch中,output预测的word的总数量
|
||||
epoch_correct_num, epoch_total_num = 0, 0
|
||||
|
||||
for batch_idx, (input_ids, labels) in enumerate(train_dataloader):
|
||||
# 捕获cuda out of memory exception
|
||||
try:
|
||||
input_ids = input_ids.to(device)
|
||||
labels = labels.to(device)
|
||||
outputs = model.forward(input_ids, labels=labels)
|
||||
logits = outputs.logits
|
||||
loss = outputs.loss
|
||||
loss = loss.mean()
|
||||
|
||||
# 统计该batch的预测token的正确数与总数
|
||||
batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)
|
||||
# 统计该epoch的预测token的正确数与总数
|
||||
epoch_correct_num += batch_correct_num
|
||||
epoch_total_num += batch_total_num
|
||||
# 计算该batch的accuracy
|
||||
batch_acc = batch_correct_num / batch_total_num
|
||||
|
||||
total_loss += loss.item()
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
|
||||
loss.backward()
|
||||
# 梯度裁剪
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
# 进行一定step的梯度累计之后,更新参数
|
||||
if (batch_idx + 1) % args.gradient_accumulation_steps == 0:
|
||||
# 更新参数
|
||||
optimizer.step()
|
||||
# 更新学习率
|
||||
scheduler.step()
|
||||
# 清空梯度信息
|
||||
optimizer.zero_grad()
|
||||
|
||||
if (batch_idx + 1) % args.log_step == 0:
|
||||
logger.info(
|
||||
"batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(
|
||||
batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))
|
||||
|
||||
del input_ids, outputs
|
||||
|
||||
except RuntimeError as exception:
|
||||
if "out of memory" in str(exception):
|
||||
logger.info("WARNING: ran out of memory")
|
||||
if hasattr(torch.cuda, 'empty_cache'):
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
logger.info(str(exception))
|
||||
raise exception
|
||||
|
||||
# 记录当前epoch的平均loss与accuracy
|
||||
epoch_mean_loss = total_loss / len(train_dataloader)
|
||||
epoch_mean_acc = epoch_correct_num / epoch_total_num
|
||||
logger.info(
|
||||
"epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))
|
||||
|
||||
# save model
|
||||
logger.info('saving model for epoch {}'.format(epoch + 1))
|
||||
model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))
|
||||
if not os.path.exists(model_path):
|
||||
os.mkdir(model_path)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
model_to_save.save_pretrained(model_path)
|
||||
logger.info('epoch {} finished'.format(epoch + 1))
|
||||
epoch_finish_time = datetime.now()
|
||||
logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))
|
||||
|
||||
return epoch_mean_loss
|
||||
|
||||
|
||||
def validate_epoch(model, validate_dataloader, logger, epoch, args):
|
||||
logger.info("start validating")
|
||||
model.eval()
|
||||
device = args.device
|
||||
# pad_id = args.pad_id
|
||||
# sep_id = args.sep_id
|
||||
ignore_index = args.ignore_index
|
||||
epoch_start_time = datetime.now()
|
||||
total_loss = 0
|
||||
# 捕获cuda out of memory exception
|
||||
try:
|
||||
with torch.no_grad():
|
||||
for batch_idx, (input_ids, labels) in enumerate(validate_dataloader):
|
||||
input_ids = input_ids.to(device)
|
||||
labels = labels.to(device)
|
||||
outputs = model.forward(input_ids, labels=labels)
|
||||
logits = outputs.logits
|
||||
loss = outputs.loss
|
||||
loss = loss.mean()
|
||||
|
||||
total_loss += loss.item()
|
||||
del input_ids, outputs
|
||||
|
||||
# 记录当前epoch的平均loss
|
||||
epoch_mean_loss = total_loss / len(validate_dataloader)
|
||||
logger.info(
|
||||
"validate epoch {}: loss {}".format(epoch+1, epoch_mean_loss))
|
||||
epoch_finish_time = datetime.now()
|
||||
logger.info('time for validating one epoch: {}'.format(epoch_finish_time - epoch_start_time))
|
||||
return epoch_mean_loss
|
||||
except RuntimeError as exception:
|
||||
if "out of memory" in str(exception):
|
||||
logger.info("WARNING: ran out of memory")
|
||||
if hasattr(torch.cuda, 'empty_cache'):
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
logger.info(str(exception))
|
||||
raise exception
|
||||
|
||||
|
||||
def train(model, logger, train_dataset, validate_dataset, args):
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn,
|
||||
drop_last=True
|
||||
)
|
||||
validate_dataloader = DataLoader(validate_dataset, batch_size=args.batch_size, shuffle=True,
|
||||
num_workers=args.num_workers, collate_fn=collate_fn, drop_last=True)
|
||||
early_stopping = EarlyStopping(args.patience, verbose=True, save_path=args.save_model_path)
|
||||
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs
|
||||
optimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps)
|
||||
# scheduler = transformers.WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
|
||||
scheduler = transformers.get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
|
||||
logger.info('starting training')
|
||||
|
||||
# 用于记录每个epoch训练和验证的loss
|
||||
train_losses, validate_losses = [], []
|
||||
# 记录验证集的最小loss
|
||||
best_val_loss = 10000
|
||||
# 开始训练
|
||||
for epoch in range(args.epochs):
|
||||
# ========== train ========== #
|
||||
train_loss = train_epoch(
|
||||
model=model, train_dataloader=train_dataloader,
|
||||
optimizer=optimizer, scheduler=scheduler,
|
||||
logger=logger, epoch=epoch, args=args)
|
||||
train_losses.append(train_loss)
|
||||
|
||||
# ========== validate ========== #
|
||||
validate_loss = validate_epoch(
|
||||
model=model, validate_dataloader=validate_dataloader,
|
||||
logger=logger, epoch=epoch, args=args)
|
||||
validate_losses.append(validate_loss)
|
||||
|
||||
# 保存当前困惑度最低的模型,困惑度低,模型的生成效果不一定会越好
|
||||
if validate_loss < best_val_loss:
|
||||
best_val_loss = validate_loss
|
||||
logger.info('saving current best model for epoch {}'.format(epoch + 1))
|
||||
model_path = join(args.save_model_path, 'min_ppl_model'.format(epoch + 1))
|
||||
if not os.path.exists(model_path):
|
||||
os.mkdir(model_path)
|
||||
model_to_save = model.module if hasattr(model, 'module') else model
|
||||
model_to_save.save_pretrained(model_path)
|
||||
|
||||
# 如果patience=0,则不进行early stopping
|
||||
if args.patience == 0:
|
||||
continue
|
||||
early_stopping(validate_loss, model)
|
||||
if early_stopping.early_stop:
|
||||
logger.info("Early stopping")
|
||||
break
|
||||
logger.info('training finished')
|
||||
logger.info("train_losses:{}".format(train_losses))
|
||||
logger.info("validate_losses:{}".format(validate_losses))
|
||||
|
||||
|
||||
def caculate_loss(logit, target, pad_idx, smoothing=True):
|
||||
if smoothing:
|
||||
logit = logit[..., :-1, :].contiguous().view(-1, logit.size(2))
|
||||
target = target[..., 1:].contiguous().view(-1)
|
||||
|
||||
eps = 0.1
|
||||
n_class = logit.size(-1)
|
||||
|
||||
one_hot = torch.zeros_like(logit).scatter(1, target.view(-1, 1), 1)
|
||||
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
|
||||
log_prb = F.log_softmax(logit, dim=1)
|
||||
|
||||
non_pad_mask = target.ne(pad_idx)
|
||||
loss = -(one_hot * log_prb).sum(dim=1)
|
||||
loss = loss.masked_select(non_pad_mask).mean() # average later
|
||||
else:
|
||||
# loss = F.cross_entropy(predict_logit, target, ignore_index=pad_idx)
|
||||
logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1))
|
||||
labels = target[..., 1:].contiguous().view(-1)
|
||||
loss = F.cross_entropy(logit, labels, ignore_index=pad_idx)
|
||||
return loss
|
||||
|
||||
|
||||
def calculate_acc(logit, labels, ignore_index=-100):
|
||||
logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1))
|
||||
labels = labels[..., 1:].contiguous().view(-1)
|
||||
|
||||
_, logit = logit.max(dim=-1) # 对于每条数据,返回最大的index
|
||||
# 进行非运算,返回一个tensor,若labels的第i个位置为pad_id,则置为0,否则为1
|
||||
non_pad_mask = labels.ne(ignore_index)
|
||||
n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item()
|
||||
n_word = non_pad_mask.sum().item()
|
||||
return n_correct, n_word
|
||||
|
||||
|
||||
def main():
|
||||
# 初始化参数
|
||||
args = set_args()
|
||||
|
||||
# 设置使用哪些显卡进行训练
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
|
||||
|
||||
args.cuda = not args.no_cuda
|
||||
|
||||
if args.batch_size < 2048 and args.warmup_steps <= 4000:
|
||||
print('[Warning] The warmup steps may be not enough.\n'
|
||||
'(sz_b, warmup) = (2048, 4000) is the official setting.\n'
|
||||
'Using smaller batch w/o longer warmup may cause '
|
||||
'the warmup stage ends with only little data trained.')
|
||||
|
||||
# 创建日志对象
|
||||
logger = create_logger(args)
|
||||
# 当用户使用GPU,并且GPU可用时
|
||||
args.cuda = torch.cuda.is_available() and not args.no_cuda
|
||||
device = 'cuda:0' if args.cuda else 'cpu'
|
||||
args.device = device
|
||||
logger.info('using device:{}'.format(device))
|
||||
|
||||
# 初始化tokenizer
|
||||
tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
|
||||
args.sep_id = tokenizer.sep_token_id
|
||||
args.pad_id = tokenizer.pad_token_id
|
||||
args.cls_id = tokenizer.cls_token_id
|
||||
|
||||
# 创建模型的输出目录
|
||||
if not os.path.exists(args.save_model_path):
|
||||
os.mkdir(args.save_model_path)
|
||||
|
||||
# 创建模型
|
||||
if args.pretrained_model: # 加载预训练模型
|
||||
model = GPT2LMHeadModel.from_pretrained(args.pretrained_model)
|
||||
else: # 初始化模型
|
||||
model_config = GPT2Config.from_json_file(args.model_config)
|
||||
model = GPT2LMHeadModel(config=model_config)
|
||||
model = model.to(device)
|
||||
logger.info('model config:\n{}'.format(model.config.to_json_string()))
|
||||
assert model.config.vocab_size == tokenizer.vocab_size
|
||||
|
||||
# 并行训练模型
|
||||
if args.cuda and torch.cuda.device_count() > 1:
|
||||
model = DataParallel(model).cuda()
|
||||
# model = BalancedDataParallel(args.gpu0_bsz, model, dim=0).cuda()
|
||||
logger.info("use GPU {} to train".format(args.device))
|
||||
|
||||
# 计算模型参数数量
|
||||
num_parameters = 0
|
||||
parameters = model.parameters()
|
||||
for parameter in parameters:
|
||||
num_parameters += parameter.numel()
|
||||
logger.info('number of model parameters: {}'.format(num_parameters))
|
||||
|
||||
# 记录参数设置
|
||||
logger.info("args:{}".format(args))
|
||||
|
||||
# 加载训练集和验证集
|
||||
# ========= Loading Dataset ========= #
|
||||
train_dataset, validate_dataset = load_dataset(logger, args)
|
||||
|
||||
train(model, logger, train_dataset, validate_dataset, args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user