From f3448cabc11e0319c57fab00b08508096c075100 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=93=9D=E5=86=B0=E8=AE=B0=E5=BF=86?= Date: Thu, 12 Mar 2026 11:09:11 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 + README.md | 2 + config/config.json | 38 ++++ cqhttp.py | 148 +++++++++++++ cqhttp_settings.py | 6 + data_parallel.py | 100 +++++++++ dataset.py | 21 ++ generate_dialogue_subset.py | 66 ++++++ interact.py | 167 +++++++++++++++ preprocess.py | 96 +++++++++ pytorchtools.py | 52 +++++ qqmsg_process.py | 40 ++++ requirements.txt | 6 + train.py | 417 ++++++++++++++++++++++++++++++++++++ 14 files changed, 1162 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 config/config.json create mode 100644 cqhttp.py create mode 100644 cqhttp_settings.py create mode 100644 data_parallel.py create mode 100644 dataset.py create mode 100644 generate_dialogue_subset.py create mode 100644 interact.py create mode 100644 preprocess.py create mode 100644 pytorchtools.py create mode 100644 qqmsg_process.py create mode 100644 requirements.txt create mode 100644 train.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..51a32b3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +vocab/ +download/ +data/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..ef1fec9 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +vocab文件太大了,没添加到git里,请去release下载。 +原始语料数据清洗得不是很干净,加上有隐私顾虑,就不上传了。 diff --git a/config/config.json b/config/config.json new file mode 100644 index 0000000..b06c51f --- /dev/null +++ b/config/config.json @@ -0,0 +1,38 @@ +{ + "_name_or_path": "model/epoch29", + "activation_function": "gelu_new", + "architectures": [ + "GPT2LMHeadModel" + ], + "attn_pdrop": 0.1, + "bos_token_id": 50256, + "embd_pdrop": 0.1, + "eos_token_id": 50256, + "gradient_checkpointing": false, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 768, + "n_head": 12, + "n_inner": null, + "n_layer": 12, + "n_positions": 1024, + "output_past": true, + "resid_pdrop": 0.1, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 400 + } + }, + "tokenizer_class": "BertTokenizer", + "transformers_version": "4.2.0", + "use_cache": true, + "vocab_size": 13317 +} diff --git a/cqhttp.py b/cqhttp.py new file mode 100644 index 0000000..1394f58 --- /dev/null +++ b/cqhttp.py @@ -0,0 +1,148 @@ +import torch +import os +from transformers import GPT2LMHeadModel +from transformers import BertTokenizerFast +import torch.nn.functional as F + +import random +import asyncio +import websockets +import json + +from cqhttp_settings import * +from interact import set_args, create_logger, top_k_top_p_filtering + +PAD = '[PAD]' +pad_id = 0 + + +def init_model(): + args = set_args() + logger = create_logger(args) + # 当用户使用GPU,并且GPU可用时 + args.cuda = torch.cuda.is_available() and not args.no_cuda + device = 'cuda' if args.cuda else 'cpu' + logger.info('using device:{}'.format(device)) + os.environ["CUDA_VISIBLE_DEVICES"] = args.device + tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]") + # tokenizer = BertTokenizer(vocab_file=args.voca_path) + model = GPT2LMHeadModel.from_pretrained(args.model_path) + model = model.to(device) + model.eval() + + return args, device, tokenizer, model + + +def process(text, args, device, tokenizer, model): + text_ids = tokenizer.encode(text, add_special_tokens=False) + history.append(text_ids) + input_ids = [tokenizer.cls_token_id] # 每个input以[CLS]为开头 + + for history_id, history_utr in enumerate(history[-args.max_history_len:]): + input_ids.extend(history_utr) + input_ids.append(tokenizer.sep_token_id) + input_ids = torch.tensor(input_ids).long().to(device) + input_ids = input_ids.unsqueeze(0) + response = [] # 根据context,生成的response + # 最多生成max_len个token + for _ in range(args.max_len): + outputs = model(input_ids=input_ids) + logits = outputs.logits + next_token_logits = logits[0, -1, :] + # 对于已生成的结果generated中的每个token添加一个重复惩罚项,降低其生成概率 + for id in set(response): + next_token_logits[id] /= args.repetition_penalty + next_token_logits = next_token_logits / args.temperature + # 对于[UNK]的概率设为无穷小,也就是说模型的预测结果不可能是[UNK]这个token + next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf') + filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp) + # torch.multinomial表示从候选集合中无放回地进行抽取num_samples个元素,权重越高,抽到的几率越高,返回元素的下标 + next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + if next_token == tokenizer.sep_token_id: # 遇到[SEP]则表明response生成结束 + break + response.append(next_token.item()) + input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1) + # his_text = tokenizer.convert_ids_to_tokens(curr_input_tensor.tolist()) + # print("his_text:{}".format(his_text)) + history.append(response) + + return "".join(tokenizer.convert_ids_to_tokens(response)) + + +# 存储聊天记录,每个utterance以token的id的形式进行存储 +history = [] + + +def handle_event(evjson): + event = json.loads(evjson) + from_id = "" + msg_type = "" + req = "" + # print(event) + rand_response_possibility = random.randint(0, 100) + if 'message_type' in event and 'raw_message' in event: + msg_type = event['message_type'] + msg_recv = event['raw_message'] + print("收到消息: ", msg_recv) + if event['message_type'] == "group" and event['group_id'] in event_response_settings_group_enabled: + # 群消息 + print("回复几率(?): {} < {}".format(rand_response_possibility, event_response_settings_group_rate)) + if rand_response_possibility < event_response_settings_group_rate: + from_id = event['group_id'] + req = msg_recv + + elif event['message_type'] == "private": + # 私聊消息 + print("回复几率(?): {} < {}".format(rand_response_possibility, event_response_settings_private_rate)) + if rand_response_possibility < event_response_settings_private_rate: + from_id = event['user_id'] + req = msg_recv + + return from_id, msg_type, req + + +async def send_msg(cqhttp, source, msg_type, text): + print("Msg to {}({}): {}".format(source, msg_type, text)) + data_send = {} + if msg_type == "group": + # 群消息 + data_send = { + 'action': "send_group_msg", + 'params': { + 'group_id': source, + 'message': text, + }, + } + elif msg_type == "private": + # 私聊消息 + data_send = { + 'action': "send_private_msg", + 'params': { + 'user_id': source, + 'message': text, + }, + } + + await cqhttp.send(json.dumps(data_send)) + + +async def init_cqhttp_ws(args, device, tokenizer, model): + print("Initializing CQHTTP WebSocket...") + async with websockets.connect(cqhttp_ws_addr + "?access_token=" + cqhttp_ws_access_token) as cqhttp: + while True: + evjson = await cqhttp.recv() + + source, msg_type, req = handle_event(evjson) + + print(req) + + if req: + res = process(req, args, device, tokenizer, model) + if res: + print("Q: {}\nA: {}\n".format(req, res)) + await send_msg(cqhttp, source, msg_type, res) + + +if __name__ == '__main__': + args, device, tokenizer, model = init_model() + asyncio.get_event_loop().run_until_complete(init_cqhttp_ws(args, device, tokenizer, model)) diff --git a/cqhttp_settings.py b/cqhttp_settings.py new file mode 100644 index 0000000..84b1eae --- /dev/null +++ b/cqhttp_settings.py @@ -0,0 +1,6 @@ +cqhttp_ws_addr = "ws://127.0.0.1:16700" +cqhttp_ws_access_token = "WPlJfomObZADDQBiUneH7nv1HfyReDcY" + +event_response_settings_group_rate = 100 # 0~100 +event_response_settings_group_enabled = [1060806176] +event_response_settings_private_rate = 100 # 0~100 diff --git a/data_parallel.py b/data_parallel.py new file mode 100644 index 0000000..6e55b60 --- /dev/null +++ b/data_parallel.py @@ -0,0 +1,100 @@ +from torch.nn.parallel import DataParallel +import torch +from torch.nn.parallel._functions import Scatter +from torch.nn.parallel.parallel_apply import parallel_apply + + +def scatter(inputs, target_gpus, chunk_sizes, dim=0): + r""" + Slices tensors into approximately equal chunks and + distributes them across given GPUs. Duplicates + references to objects that are not tensors. + """ + def scatter_map(obj): + if isinstance(obj, torch.Tensor): + try: + return Scatter.apply(target_gpus, chunk_sizes, dim, obj) + except: + print('obj', obj.size()) + print('dim', dim) + print('chunk_sizes', chunk_sizes) + quit() + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(scatter_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + return list(map(list, zip(*map(scatter_map, obj)))) + if isinstance(obj, dict) and len(obj) > 0: + return list(map(type(obj), zip(*map(scatter_map, obj.items())))) + return [obj for targets in target_gpus] + + # After scatter_map is called, a scatter_map cell will exist. This cell + # has a reference to the actual function scatter_map, which has references + # to a closure that has a reference to the scatter_map cell (because the + # fn is recursive). To avoid this reference cycle, we set the function to + # None, clearing the cell + try: + return scatter_map(inputs) + finally: + scatter_map = None + + +def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): + r"""Scatter with support for kwargs dictionary""" + inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] + kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] + if len(inputs) < len(kwargs): + inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) + elif len(kwargs) < len(inputs): + kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) + inputs = tuple(inputs) + kwargs = tuple(kwargs) + return inputs, kwargs + + +class BalancedDataParallel(DataParallel): + def __init__(self, gpu0_bsz, *args, **kwargs): + self.gpu0_bsz = gpu0_bsz + super().__init__(*args, **kwargs) + + def forward(self, *inputs, **kwargs): + if not self.device_ids: + return self.module(*inputs, **kwargs) + if self.gpu0_bsz == 0: + device_ids = self.device_ids[1:] + else: + device_ids = self.device_ids + inputs, kwargs = self.scatter(inputs, kwargs, device_ids) + # print('len(inputs)1: ', str(len(inputs))) + # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)])) + if len(self.device_ids) == 1: + return self.module(*inputs[0], **kwargs[0]) + replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) + if self.gpu0_bsz == 0: + replicas = replicas[1:] + outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) + return self.gather(outputs, self.output_device) + + def parallel_apply(self, replicas, device_ids, inputs, kwargs): + return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)]) + + def scatter(self, inputs, kwargs, device_ids): + bsz = inputs[0].size(self.dim) + num_dev = len(self.device_ids) + gpu0_bsz = self.gpu0_bsz + bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) + if gpu0_bsz < bsz_unit: + chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) + delta = bsz - sum(chunk_sizes) + for i in range(delta): + chunk_sizes[i + 1] += 1 + if gpu0_bsz == 0: + chunk_sizes = chunk_sizes[1:] + else: + return super().scatter(inputs, kwargs, device_ids) + + # print('bsz: ', bsz) + # print('num_dev: ', num_dev) + # print('gpu0_bsz: ', gpu0_bsz) + # print('bsz_unit: ', bsz_unit) + # print('chunk_sizes: ', chunk_sizes) + return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..1194149 --- /dev/null +++ b/dataset.py @@ -0,0 +1,21 @@ +from torch.utils.data import Dataset +import torch + + +class MyDataset(Dataset): + """ + + """ + + def __init__(self, input_list, max_len): + self.input_list = input_list + self.max_len = max_len + + def __getitem__(self, index): + input_ids = self.input_list[index] + input_ids = input_ids[:self.max_len] + input_ids = torch.tensor(input_ids, dtype=torch.long) + return input_ids + + def __len__(self): + return len(self.input_list) diff --git a/generate_dialogue_subset.py b/generate_dialogue_subset.py new file mode 100644 index 0000000..23f5fd2 --- /dev/null +++ b/generate_dialogue_subset.py @@ -0,0 +1,66 @@ +import argparse +from os.path import join +from collections import Counter +import matplotlib.pyplot as plt +from matplotlib.pyplot import MultipleLocator + + +def generate_subset(): + """ + 用于生成训练子集 + :return: + """ + parser = argparse.ArgumentParser() + parser.add_argument('--raw_data_path', default='data/train.txt', type=str, required=False, help='原始训练语料') + parser.add_argument('--subset_size', default=1000000, type=int, required=False, help='要获取的对话数据子集的规模') + parser.add_argument('--subset_data_path', default='data', type=str, required=False, + help='数据子集文件路径,指定文件的父目录') + args = parser.parse_args() + with open(args.raw_data_path, "r", encoding="utf8") as f: + data = f.read() + dialogues = data.split("\n\n") + subset_size = min(len(dialogues), args.subset_size) + + with open(join(args.subset_data_path, "train_{}w.txt".format(int(subset_size / 10000))), "w", encoding="utf8") as f: + print("generating subset,please wait a few minutes") + for dialogue_index, dialogue in enumerate(dialogues): + if dialogue_index >= subset_size: + break + for utterance in dialogue.split("\n"): + f.writelines(utterance + "\n") + f.writelines("\n") + + +def compute_dialogue_length(): + """ + 查看聊天语料中的dialogue的长度分布 + :return: + """ + parser = argparse.ArgumentParser() + parser.add_argument('--raw_data_path', default='data/train.txt', type=str, required=False, help='原始训练语料') + args = parser.parse_args() + with open(args.raw_data_path, "r", encoding="utf8") as f: + data = f.read() + dialogues = data.split("\n\n") + # 统计各个dialogue的长度 + dialogues_lengths = [len(dialogue.replace("\n", "")) for dialogue in dialogues] + counter = Counter(dialogues_lengths) # {label:sum(label)} + dialogue_length_arr = list(counter) + num_arr = [counter[element] for element in list(counter)] + print(counter[300]) + + x_major_locator = MultipleLocator(100) # MultipleLocator用于设置刻度间隔 + # y_major_locator = MultipleLocator(20000) + ax = plt.gca() # ax为两条坐标轴的实例 + ax.xaxis.set_major_locator(x_major_locator) # 把x轴的主刻度设置为10的倍数 + # ax.yaxis.set_major_locator(y_major_locator) + + plt.xlabel('dialogue length') + plt.ylabel('number of dialogue') + # plt.plot(dialogue_length_arr, num_arr, c='green') + plt.scatter(dialogue_length_arr, num_arr) + plt.show() + + +if __name__ == '__main__': + generate_subset() diff --git a/interact.py b/interact.py new file mode 100644 index 0000000..740131e --- /dev/null +++ b/interact.py @@ -0,0 +1,167 @@ +import torch +import os +import argparse +from datetime import datetime +import logging +from transformers import GPT2LMHeadModel +from transformers import BertTokenizerFast +import torch.nn.functional as F + +PAD = '[PAD]' +pad_id = 0 + + +def set_args(): + """ + Sets up the arguments. + """ + parser = argparse.ArgumentParser() + parser.add_argument('--device', default='0', type=str, required=False, help='生成设备') + parser.add_argument('--temperature', default=1, type=float, required=False, help='生成的temperature') + parser.add_argument('--topk', default=8, type=int, required=False, help='最高k选1') + parser.add_argument('--topp', default=0, type=float, required=False, help='最高积累概率') + # parser.add_argument('--model_config', default='config/model_config_dialogue_small.json', type=str, required=False, + # help='模型参数') + parser.add_argument('--log_path', default='data/interact.log', type=str, required=False, help='interact日志存放位置') + parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False, help='选择词库') + parser.add_argument('--model_path', default='model/epoch40', type=str, required=False, help='对话模型路径') + parser.add_argument('--save_samples_path', default="sample/", type=str, required=False, help="保存聊天记录的文件路径") + parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False, + help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数") + # parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的') + parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断') + parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度") + parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测') + return parser.parse_args() + + +def create_logger(args): + """ + 将日志输出到日志文件和控制台 + """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + + formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s') + + # 创建一个handler,用于写入日志文件 + file_handler = logging.FileHandler( + filename=args.log_path) + file_handler.setFormatter(formatter) + file_handler.setLevel(logging.INFO) + logger.addHandler(file_handler) + + # 创建一个handler,用于将日志输出到控制台 + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + console.setFormatter(formatter) + logger.addHandler(console) + + return logger + + +def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): + """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (vocab size) + top_k > 0: keep only top k tokens with highest probability (top-k filtering). + top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear + top_k = min(top_k, logits.size(-1)) # Safety check + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + # torch.topk()返回最后一维最大的top_k个元素,返回值为二维(values,indices) + # ...表示其他维度由计算机自行推断 + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value # 对于topk之外的其他元素的logits值设为负无穷 + + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) # 对logits进行递减排序 + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = filter_value + return logits + + +def main(): + args = set_args() + logger = create_logger(args) + # 当用户使用GPU,并且GPU可用时 + args.cuda = torch.cuda.is_available() and not args.no_cuda + device = 'cuda' if args.cuda else 'cpu' + logger.info('using device:{}'.format(device)) + os.environ["CUDA_VISIBLE_DEVICES"] = args.device + tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]") + # tokenizer = BertTokenizer(vocab_file=args.voca_path) + model = GPT2LMHeadModel.from_pretrained(args.model_path) + model = model.to(device) + model.eval() + if args.save_samples_path: + if not os.path.exists(args.save_samples_path): + os.makedirs(args.save_samples_path) + samples_file = open(args.save_samples_path + '/samples.txt', 'a', encoding='utf8') + samples_file.write("聊天记录{}:\n".format(datetime.now())) + # 存储聊天记录,每个utterance以token的id的形式进行存储 + history = [] + print('开始和chatbot聊天,输入CTRL + Z以退出') + + while True: + try: + text = input("user:") + # text = "你好" + if args.save_samples_path: + samples_file.write("user:{}\n".format(text)) + text_ids = tokenizer.encode(text, add_special_tokens=False) + history.append(text_ids) + input_ids = [tokenizer.cls_token_id] # 每个input以[CLS]为开头 + + for history_id, history_utr in enumerate(history[-args.max_history_len:]): + input_ids.extend(history_utr) + input_ids.append(tokenizer.sep_token_id) + input_ids = torch.tensor(input_ids).long().to(device) + input_ids = input_ids.unsqueeze(0) + response = [] # 根据context,生成的response + # 最多生成max_len个token + for _ in range(args.max_len): + outputs = model(input_ids=input_ids) + logits = outputs.logits + next_token_logits = logits[0, -1, :] + # 对于已生成的结果generated中的每个token添加一个重复惩罚项,降低其生成概率 + for id in set(response): + next_token_logits[id] /= args.repetition_penalty + next_token_logits = next_token_logits / args.temperature + # 对于[UNK]的概率设为无穷小,也就是说模型的预测结果不可能是[UNK]这个token + next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf') + filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp) + # torch.multinomial表示从候选集合中无放回地进行抽取num_samples个元素,权重越高,抽到的几率越高,返回元素的下标 + next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) + if next_token == tokenizer.sep_token_id: # 遇到[SEP]则表明response生成结束 + break + response.append(next_token.item()) + input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1) + # his_text = tokenizer.convert_ids_to_tokens(curr_input_tensor.tolist()) + # print("his_text:{}".format(his_text)) + history.append(response) + text = tokenizer.convert_ids_to_tokens(response) + print("chatbot:" + "".join(text)) + if args.save_samples_path: + samples_file.write("chatbot:{}\n".format("".join(text))) + except KeyboardInterrupt: + if args.save_samples_path: + samples_file.close() + break + + +if __name__ == '__main__': + main() diff --git a/preprocess.py b/preprocess.py new file mode 100644 index 0000000..93fbd9c --- /dev/null +++ b/preprocess.py @@ -0,0 +1,96 @@ +from transformers import BertTokenizerFast +import argparse +import pickle +from tqdm import tqdm +import logging +import numpy as np + + +def create_logger(log_path): + """ + 将日志输出到日志文件和控制台 + """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + + formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s') + + # 创建一个handler,用于写入日志文件 + file_handler = logging.FileHandler( + filename=log_path) + file_handler.setFormatter(formatter) + file_handler.setLevel(logging.INFO) + logger.addHandler(file_handler) + + # 创建一个handler,用于将日志输出到控制台 + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + console.setFormatter(formatter) + logger.addHandler(console) + + return logger + + +def preprocess(): + """ + 对原始语料进行tokenize,将每段对话处理成如下形式:"[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]" + """ + # 设置参数 + parser = argparse.ArgumentParser() + parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False, + help='词表路径') + parser.add_argument('--log_path', default='data/preprocess.log', type=str, required=False, help='训练日志存放位置') + parser.add_argument('--train_path', default='data/train.txt', type=str, required=False, help='训练日志存放位置') + parser.add_argument('--save_path', default='data/train.pkl', type=str, required=False, help='tokenize的训练数据集') + args = parser.parse_args() + + # 初始化日志对象 + logger = create_logger(args.log_path) + + # 初始化tokenizer + tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]") + sep_id = tokenizer.sep_token_id + cls_id = tokenizer.cls_token_id + logger.info("preprocessing data,data path:{}, save path:{}".format(args.train_path, args.save_path)) + + # 读取训练数据集 + with open(args.train_path, 'rb') as f: + data = f.read().decode("utf-8") + + # 需要区分linux和windows环境下的换行符 + if "\r\n" in data: + train_data = data.split("\r\n\r\n") + else: + train_data = data.split("\n\n") + logger.info("there are {} dialogue in dataset".format(len(train_data))) + + # 开始进行tokenize + # 保存所有的对话数据,每条数据的格式为:"[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]" + dialogue_len = [] # 记录所有对话tokenize之后的长度,用于统计中位数与均值 + dialogue_list = [] + with open(args.save_path, "w", encoding="utf-8") as f: + for index, dialogue in enumerate(tqdm(train_data)): + if "\r\n" in data: + utterances = dialogue.split("\r\n") + else: + utterances = dialogue.split("\n") + + input_ids = [cls_id] # 每个dialogue以[CLS]开头 + for utterance in utterances: + input_ids += tokenizer.encode(utterance, add_special_tokens=False) + input_ids.append(sep_id) # 每个utterance之后添加[SEP],表示utterance结束 + dialogue_len.append(len(input_ids)) + dialogue_list.append(input_ids) + len_mean = np.mean(dialogue_len) + len_median = np.median(dialogue_len) + len_max = np.max(dialogue_len) + with open(args.save_path, "wb") as f: + pickle.dump(dialogue_list, f) + logger.info("finish preprocessing data,the result is stored in {}".format(args.save_path)) + logger.info("mean of dialogue len:{},median of dialogue len:{},max len:{}".format(len_mean, len_median, len_max)) + + +if __name__ == '__main__': + preprocess() + diff --git a/pytorchtools.py b/pytorchtools.py new file mode 100644 index 0000000..2b71f3c --- /dev/null +++ b/pytorchtools.py @@ -0,0 +1,52 @@ +import numpy as np + + +class EarlyStopping: + """Early stops the training if validation loss doesn't improve after a given patience.""" + + def __init__(self, patience=7, verbose=False, delta=0, save_path="."): + """ + Args: + patience (int): How long to wait after last time validation loss improved. + Default: 7 + verbose (bool): If True, prints a message for each validation loss improvement. + Default: False + delta (float): Minimum change in the monitored quantity to qualify as an improvement. + Default: 0 + """ + self.patience = patience + self.verbose = verbose + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = np.Inf + self.delta = delta + self.save_path = save_path + + def __call__(self, val_loss, model): + + score = -val_loss + + if self.best_score is None: + self.best_score = score + self.save_checkpoint(val_loss, model) + elif score < self.best_score + self.delta: + self.counter += 1 + print(f'EarlyStopping counter: {self.counter} out of {self.patience}') + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_score = score + self.save_checkpoint(val_loss, model) + self.counter = 0 + + def save_checkpoint(self, val_loss, model): + '''Saves model when validation loss decrease.''' + if self.verbose: + print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') + # save_path = join(self.save_path, "best_model") + # if not os.path.exists(save_path): + # os.mkdir(save_path) + # model_to_save = model.module if hasattr(model, 'module') else model + # model_to_save.save_pretrained(save_path) + self.val_loss_min = val_loss diff --git a/qqmsg_process.py b/qqmsg_process.py new file mode 100644 index 0000000..503c7c9 --- /dev/null +++ b/qqmsg_process.py @@ -0,0 +1,40 @@ +import os +from datetime import datetime + +files = os.walk("download/QQMsg") + +outputFile = open("data/train_qq.txt", "w", encoding="utf-8") + +for path, dir_list, file_list in files: + for file_name in file_list: + print(os.path.join(path, file_name)) + f = open(os.path.join(path, file_name), "r", encoding="utf-8") + lines = f.readlines() + stat = 0 # 0: ready to parse time / 1: ready to parse log + lastTime = datetime.strptime("1970-1-1 00:00:00", "%Y-%m-%d %H:%M:%S") + for i in range(8, len(lines)): + raw = lines[i].replace("\r\n", "").replace("\n", "") + + # 这一行是时间 + timeStrs = raw.split(' ', 2) + try: + # 这一行是时间 + if timeStrs[0][0] == '2': + tsStr = timeStrs[0] + " " + timeStrs[1] + else: + tsStr = timeStrs[1] + " " + timeStrs[2] + ts = datetime.strptime(tsStr, "%Y-%m-%d %H:%M:%S") + if ((ts - lastTime).seconds > 120) or ((ts - lastTime).seconds < 0): + # 间隔2分钟以上,认为是不同的对话 + outputFile.write("\n") + lastTime = ts + except (IndexError, ValueError) as e: + # 这一行是消息 + msg = raw.replace("[图片]", "").replace("[表情]", "").replace("[合并转发]请使用手机QQ最新版本查看", "") + if msg != "": + # 是有效行 + outputFile.write(msg + "\n") + + f.close() + +outputFile.close() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..97b3c13 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +transformers~=4.6.1 +torch~=1.8.1 +tqdm~=4.61.0 +numpy~=1.20.3 +matplotlib~=3.4.2 +websockets~=9.1 \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..f952562 --- /dev/null +++ b/train.py @@ -0,0 +1,417 @@ +import argparse +import torch +import torch.nn.functional as F +import logging +from datetime import datetime +import os +from torch.utils.data import DataLoader +from os.path import join +from torch.nn import DataParallel +import transformers +import pickle +from pytorchtools import EarlyStopping +from transformers import GPT2LMHeadModel, GPT2Config +from transformers import BertTokenizerFast +import torch.nn.utils.rnn as rnn_utils +from dataset import MyDataset + + +def set_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--device', default='3', type=str, required=False, help='设置使用哪些显卡') + parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行训练') + parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False, + help='词表路径') + parser.add_argument('--model_config', default='config/config.json', type=str, required=False, + help='设置模型参数') + parser.add_argument('--train_path', default='data/train.pkl', type=str, required=False, help='训练集路径') + parser.add_argument('--max_len', default=150, type=int, required=False, help='训练时,输入数据的最大长度') + + parser.add_argument('--log_path', default='data/train.log', type=str, required=False, help='训练日志存放位置') + parser.add_argument('--log', default=True, help="是否记录日志") + parser.add_argument('--ignore_index', default=-100, type=int, required=False, help='对于ignore_index的label token不计算梯度') + # parser.add_argument('--input_len', default=200, type=int, required=False, help='输入的长度') + parser.add_argument('--epochs', default=100, type=int, required=False, help='训练的最大轮次') + parser.add_argument('--batch_size', default=4, type=int, required=False, help='训练的batch size') + parser.add_argument('--gpu0_bsz', default=10, type=int, required=False, help='0号卡的batch size') + parser.add_argument('--lr', default=2.6e-5, type=float, required=False, help='学习率') + parser.add_argument('--eps', default=1.0e-09, type=float, required=False, help='衰减率') + parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss') + parser.add_argument('--gradient_accumulation_steps', default=4, type=int, required=False, help='梯度积累') + parser.add_argument('--max_grad_norm', default=2.0, type=float, required=False) + parser.add_argument('--save_model_path', default='model', type=str, required=False, + help='模型输出路径') + parser.add_argument('--pretrained_model', default='', type=str, required=False, + help='预训练的模型的路径') + # parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的') + parser.add_argument('--num_workers', type=int, default=0, help="dataloader加载数据时使用的线程数量") + parser.add_argument('--patience', type=int, default=0, help="用于early stopping,设为0时,不进行early stopping.early stop得到的模型的生成效果不一定会更好。") + parser.add_argument('--warmup_steps', type=int, default=4000, help='warm up步数') + # parser.add_argument('--label_smoothing', default=True, action='store_true', help='是否进行标签平滑') + parser.add_argument('--val_num', type=int, default=8000, help='验证集大小') + args = parser.parse_args() + return args + + +def create_logger(args): + """ + 将日志输出到日志文件和控制台 + """ + logger = logging.getLogger(__name__) + logger.setLevel(logging.INFO) + + formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s') + + # 创建一个handler,用于写入日志文件 + file_handler = logging.FileHandler( + filename=args.log_path) + file_handler.setFormatter(formatter) + file_handler.setLevel(logging.INFO) + logger.addHandler(file_handler) + + # 创建一个handler,用于将日志输出到控制台 + console = logging.StreamHandler() + console.setLevel(logging.DEBUG) + console.setFormatter(formatter) + logger.addHandler(console) + + return logger + + +def collate_fn(batch): + input_ids = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=0) + labels = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=-100) + return input_ids, labels + + +# def padding_batch(data_list, pad_id): +# """ +# 使用pad_id将data_list的每条数据,填充至data_list中最长的长度 +# :param data_list: +# :param pad_id: +# :return: +# """ +# # 统计data_list中的最大长度 +# max_len = 0 +# for data in data_list: +# max_len = max_len if max_len > len(data) else len(data) +# +# # 对数据进行padding +# new_data_list = [] +# for data in data_list: +# new_data = data + [pad_id] * (max_len - len(data)) +# new_data_list.append(new_data) +# return new_data_list + + +def load_dataset(logger, args): + """ + 加载训练集和验证集 + """ + logger.info("loading training dataset and validating dataset") + train_path = args.train_path + + with open(train_path, "rb") as f: + input_list = pickle.load(f) + + # 划分训练集与验证集 + val_num = args.val_num + input_list_train = input_list[val_num:] + input_list_val = input_list[:val_num] + # test + # input_list_train = input_list_train[:24] + # input_list_val = input_list_val[:24] + + train_dataset = MyDataset(input_list_train, args.max_len) + val_dataset = MyDataset(input_list_val, args.max_len) + + return train_dataset, val_dataset + + +def train_epoch(model, train_dataloader, optimizer, scheduler, logger, + epoch, args): + model.train() + device = args.device + # pad_id = args.pad_id + # sep_id = args.sep_id + ignore_index = args.ignore_index + epoch_start_time = datetime.now() + total_loss = 0 # 记录下整个epoch的loss的总和 + + # epoch_correct_num:每个epoch中,output预测正确的word的数量 + # epoch_total_num: 每个epoch中,output预测的word的总数量 + epoch_correct_num, epoch_total_num = 0, 0 + + for batch_idx, (input_ids, labels) in enumerate(train_dataloader): + # 捕获cuda out of memory exception + try: + input_ids = input_ids.to(device) + labels = labels.to(device) + outputs = model.forward(input_ids, labels=labels) + logits = outputs.logits + loss = outputs.loss + loss = loss.mean() + + # 统计该batch的预测token的正确数与总数 + batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index) + # 统计该epoch的预测token的正确数与总数 + epoch_correct_num += batch_correct_num + epoch_total_num += batch_total_num + # 计算该batch的accuracy + batch_acc = batch_correct_num / batch_total_num + + total_loss += loss.item() + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + loss.backward() + # 梯度裁剪 + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + # 进行一定step的梯度累计之后,更新参数 + if (batch_idx + 1) % args.gradient_accumulation_steps == 0: + # 更新参数 + optimizer.step() + # 更新学习率 + scheduler.step() + # 清空梯度信息 + optimizer.zero_grad() + + if (batch_idx + 1) % args.log_step == 0: + logger.info( + "batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format( + batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr())) + + del input_ids, outputs + + except RuntimeError as exception: + if "out of memory" in str(exception): + logger.info("WARNING: ran out of memory") + if hasattr(torch.cuda, 'empty_cache'): + torch.cuda.empty_cache() + else: + logger.info(str(exception)) + raise exception + + # 记录当前epoch的平均loss与accuracy + epoch_mean_loss = total_loss / len(train_dataloader) + epoch_mean_acc = epoch_correct_num / epoch_total_num + logger.info( + "epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc)) + + # save model + logger.info('saving model for epoch {}'.format(epoch + 1)) + model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1)) + if not os.path.exists(model_path): + os.mkdir(model_path) + model_to_save = model.module if hasattr(model, 'module') else model + model_to_save.save_pretrained(model_path) + logger.info('epoch {} finished'.format(epoch + 1)) + epoch_finish_time = datetime.now() + logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time)) + + return epoch_mean_loss + + +def validate_epoch(model, validate_dataloader, logger, epoch, args): + logger.info("start validating") + model.eval() + device = args.device + # pad_id = args.pad_id + # sep_id = args.sep_id + ignore_index = args.ignore_index + epoch_start_time = datetime.now() + total_loss = 0 + # 捕获cuda out of memory exception + try: + with torch.no_grad(): + for batch_idx, (input_ids, labels) in enumerate(validate_dataloader): + input_ids = input_ids.to(device) + labels = labels.to(device) + outputs = model.forward(input_ids, labels=labels) + logits = outputs.logits + loss = outputs.loss + loss = loss.mean() + + total_loss += loss.item() + del input_ids, outputs + + # 记录当前epoch的平均loss + epoch_mean_loss = total_loss / len(validate_dataloader) + logger.info( + "validate epoch {}: loss {}".format(epoch+1, epoch_mean_loss)) + epoch_finish_time = datetime.now() + logger.info('time for validating one epoch: {}'.format(epoch_finish_time - epoch_start_time)) + return epoch_mean_loss + except RuntimeError as exception: + if "out of memory" in str(exception): + logger.info("WARNING: ran out of memory") + if hasattr(torch.cuda, 'empty_cache'): + torch.cuda.empty_cache() + else: + logger.info(str(exception)) + raise exception + + +def train(model, logger, train_dataset, validate_dataset, args): + train_dataloader = DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn, + drop_last=True + ) + validate_dataloader = DataLoader(validate_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.num_workers, collate_fn=collate_fn, drop_last=True) + early_stopping = EarlyStopping(args.patience, verbose=True, save_path=args.save_model_path) + t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs + optimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps) + # scheduler = transformers.WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = transformers.get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total + ) + + logger.info('starting training') + + # 用于记录每个epoch训练和验证的loss + train_losses, validate_losses = [], [] + # 记录验证集的最小loss + best_val_loss = 10000 + # 开始训练 + for epoch in range(args.epochs): + # ========== train ========== # + train_loss = train_epoch( + model=model, train_dataloader=train_dataloader, + optimizer=optimizer, scheduler=scheduler, + logger=logger, epoch=epoch, args=args) + train_losses.append(train_loss) + + # ========== validate ========== # + validate_loss = validate_epoch( + model=model, validate_dataloader=validate_dataloader, + logger=logger, epoch=epoch, args=args) + validate_losses.append(validate_loss) + + # 保存当前困惑度最低的模型,困惑度低,模型的生成效果不一定会越好 + if validate_loss < best_val_loss: + best_val_loss = validate_loss + logger.info('saving current best model for epoch {}'.format(epoch + 1)) + model_path = join(args.save_model_path, 'min_ppl_model'.format(epoch + 1)) + if not os.path.exists(model_path): + os.mkdir(model_path) + model_to_save = model.module if hasattr(model, 'module') else model + model_to_save.save_pretrained(model_path) + + # 如果patience=0,则不进行early stopping + if args.patience == 0: + continue + early_stopping(validate_loss, model) + if early_stopping.early_stop: + logger.info("Early stopping") + break + logger.info('training finished') + logger.info("train_losses:{}".format(train_losses)) + logger.info("validate_losses:{}".format(validate_losses)) + + +def caculate_loss(logit, target, pad_idx, smoothing=True): + if smoothing: + logit = logit[..., :-1, :].contiguous().view(-1, logit.size(2)) + target = target[..., 1:].contiguous().view(-1) + + eps = 0.1 + n_class = logit.size(-1) + + one_hot = torch.zeros_like(logit).scatter(1, target.view(-1, 1), 1) + one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) + log_prb = F.log_softmax(logit, dim=1) + + non_pad_mask = target.ne(pad_idx) + loss = -(one_hot * log_prb).sum(dim=1) + loss = loss.masked_select(non_pad_mask).mean() # average later + else: + # loss = F.cross_entropy(predict_logit, target, ignore_index=pad_idx) + logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1)) + labels = target[..., 1:].contiguous().view(-1) + loss = F.cross_entropy(logit, labels, ignore_index=pad_idx) + return loss + + +def calculate_acc(logit, labels, ignore_index=-100): + logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1)) + labels = labels[..., 1:].contiguous().view(-1) + + _, logit = logit.max(dim=-1) # 对于每条数据,返回最大的index + # 进行非运算,返回一个tensor,若labels的第i个位置为pad_id,则置为0,否则为1 + non_pad_mask = labels.ne(ignore_index) + n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item() + n_word = non_pad_mask.sum().item() + return n_correct, n_word + + +def main(): + # 初始化参数 + args = set_args() + + # 设置使用哪些显卡进行训练 + os.environ["CUDA_VISIBLE_DEVICES"] = args.device + + args.cuda = not args.no_cuda + + if args.batch_size < 2048 and args.warmup_steps <= 4000: + print('[Warning] The warmup steps may be not enough.\n' + '(sz_b, warmup) = (2048, 4000) is the official setting.\n' + 'Using smaller batch w/o longer warmup may cause ' + 'the warmup stage ends with only little data trained.') + + # 创建日志对象 + logger = create_logger(args) + # 当用户使用GPU,并且GPU可用时 + args.cuda = torch.cuda.is_available() and not args.no_cuda + device = 'cuda:0' if args.cuda else 'cpu' + args.device = device + logger.info('using device:{}'.format(device)) + + # 初始化tokenizer + tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]") + args.sep_id = tokenizer.sep_token_id + args.pad_id = tokenizer.pad_token_id + args.cls_id = tokenizer.cls_token_id + + # 创建模型的输出目录 + if not os.path.exists(args.save_model_path): + os.mkdir(args.save_model_path) + + # 创建模型 + if args.pretrained_model: # 加载预训练模型 + model = GPT2LMHeadModel.from_pretrained(args.pretrained_model) + else: # 初始化模型 + model_config = GPT2Config.from_json_file(args.model_config) + model = GPT2LMHeadModel(config=model_config) + model = model.to(device) + logger.info('model config:\n{}'.format(model.config.to_json_string())) + assert model.config.vocab_size == tokenizer.vocab_size + + # 并行训练模型 + if args.cuda and torch.cuda.device_count() > 1: + model = DataParallel(model).cuda() + # model = BalancedDataParallel(args.gpu0_bsz, model, dim=0).cuda() + logger.info("use GPU {} to train".format(args.device)) + + # 计算模型参数数量 + num_parameters = 0 + parameters = model.parameters() + for parameter in parameters: + num_parameters += parameter.numel() + logger.info('number of model parameters: {}'.format(num_parameters)) + + # 记录参数设置 + logger.info("args:{}".format(args)) + + # 加载训练集和验证集 + # ========= Loading Dataset ========= # + train_dataset, validate_dataset = load_dataset(logger, args) + + train(model, logger, train_dataset, validate_dataset, args) + + +if __name__ == '__main__': + main()