This repository has been archived on 2026-03-12. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
nlp-chatbot/preprocess.py
T
2026-03-12 11:09:11 +08:00

97 lines
3.7 KiB
Python

from transformers import BertTokenizerFast
import argparse
import pickle
from tqdm import tqdm
import logging
import numpy as np
def create_logger(log_path):
    """
    Build the module logger that writes to both a log file and the console.

    ``logging.getLogger`` hands back the same logger object for the same name,
    so handlers attached by an earlier call are closed and removed before the
    new ones are attached. Without that, every repeated call would stack
    another pair of handlers and each record would be written several times.

    Args:
        log_path: Path of the log file, passed to ``logging.FileHandler``
            (opened as UTF-8).

    Returns:
        The ``logging.Logger`` named after this module, set to INFO, carrying
        exactly one file handler and one console handler.

    Raises:
        OSError: If the log file cannot be opened.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # Handler that writes records to the log file. Open it before touching the
    # existing handlers so a bad path does not leave the logger handler-less.
    # UTF-8 is forced so non-ASCII text is written the same way on every OS.
    file_handler = logging.FileHandler(
        filename=log_path, encoding="utf-8")
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)

    # Handler that echoes records to the console. The handler itself accepts
    # DEBUG, but the logger-level INFO threshold above is applied first.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)

    # Drop handlers left behind by a previous call to avoid duplicate output.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    logger.addHandler(file_handler)
    logger.addHandler(console)
    return logger
def preprocess():
    """
    Tokenize the raw corpus and pickle the resulting token-id lists.

    The training file is read as UTF-8. Dialogues are separated by a blank
    line and each line inside a dialogue is one utterance. Every dialogue is
    converted to a single list of token ids laid out as
    "[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]", and the list of all
    dialogues is pickled to ``--save_path``. Mean, median and maximum
    tokenized length are logged at the end.

    Command-line arguments (all optional):
        --vocab_path: vocabulary file for ``BertTokenizerFast``
            (default ``vocab/vocab.txt``).
        --log_path: log file location (default ``data/preprocess.log``).
        --train_path: raw training corpus (default ``data/train.txt``).
        --save_path: output pickle (default ``data/train.pkl``).

    Returns:
        None. The result is written to ``--save_path``.
    """
    # Command-line parameters.
    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False,
                        help='词表路径')
    parser.add_argument('--log_path', default='data/preprocess.log', type=str, required=False,
                        help='训练日志存放位置')
    # The help text used to repeat the log-path description; it now describes
    # the corpus file this option actually points at.
    parser.add_argument('--train_path', default='data/train.txt', type=str, required=False,
                        help='原始训练语料路径')
    parser.add_argument('--save_path', default='data/train.pkl', type=str, required=False,
                        help='tokenize的训练数据集')
    args = parser.parse_args()

    # Logger writing to both the log file and the console.
    logger = create_logger(args.log_path)

    # Tokenizer and the ids of the special tokens used to frame a dialogue.
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    sep_id = tokenizer.sep_token_id
    cls_id = tokenizer.cls_token_id
    logger.info("preprocessing data,data path:{}, save path:{}".format(args.train_path, args.save_path))

    # Read the whole training corpus.
    with open(args.train_path, 'rb') as f:
        data = f.read().decode("utf-8")

    # Normalize Windows line endings once, up front. This replaces a
    # whole-corpus "\r\n" scan that was repeated for every dialogue and also
    # makes files with mixed line endings split correctly.
    data = data.replace("\r\n", "\n")

    # Blank-line separated dialogues; chunks that are only whitespace (trailing
    # newlines, runs of blank lines) are not dialogues and are dropped.
    train_data = [chunk for chunk in data.split("\n\n") if chunk.strip()]
    logger.info("there are {} dialogue in dataset".format(len(train_data)))

    # Tokenize. Each entry of dialogue_list has the layout
    # "[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]".
    dialogue_len = []  # tokenized length of every dialogue, for the statistics below
    dialogue_list = []
    for dialogue in tqdm(train_data):
        input_ids = [cls_id]  # every dialogue starts with [CLS]
        for utterance in dialogue.split("\n"):
            # An empty line would otherwise contribute a lone [SEP].
            if not utterance.strip():
                continue
            input_ids += tokenizer.encode(utterance, add_special_tokens=False)
            input_ids.append(sep_id)  # [SEP] marks the end of each utterance
        dialogue_len.append(len(input_ids))
        dialogue_list.append(input_ids)

    if dialogue_len:
        len_mean = np.mean(dialogue_len)
        len_median = np.median(dialogue_len)
        len_max = np.max(dialogue_len)
    else:
        # np.max raises on an empty sequence; report zeros for an empty corpus.
        logger.warning("no dialogue found in {}".format(args.train_path))
        len_mean = len_median = len_max = 0

    # The output is opened exactly once, in binary mode, as pickle requires.
    with open(args.save_path, "wb") as f:
        pickle.dump(dialogue_list, f)
    logger.info("finish preprocessing data,the result is stored in {}".format(args.save_path))
    logger.info("mean of dialogue len:{},median of dialogue len:{},max len:{}".format(len_mean, len_median, len_max))
# Run the preprocessing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    preprocess()