# Python source, 97 lines, 3.7 KiB.
from transformers import BertTokenizerFast
|
|
import argparse
|
|
import pickle
|
|
from tqdm import tqdm
|
|
import logging
|
|
import numpy as np
|
|
|
|
|
|
def create_logger(log_path):
    """
    Return the module logger, writing records to a log file and the console.

    The logger is obtained with ``logging.getLogger(__name__)``, so every call
    returns the same object. Handlers are only attached when an equivalent one
    is not already present, which keeps repeated calls from duplicating every
    log line.

    Args:
        log_path: Path of the log file that records are appended to.

    Returns:
        The configured ``logging.Logger`` with its level set to INFO.
    """
    import os  # local import: os is not among this file's top-level imports

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # FileHandler stores its target as an absolute path in ``baseFilename``,
    # so compare against the absolute form of the requested path.
    target = os.path.abspath(log_path)
    has_file_handler = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == target
        for handler in logger.handlers
    )
    if not has_file_handler:
        # Handler that writes records to the log file. UTF-8 is pinned so
        # non-ASCII text does not depend on the platform default encoding.
        file_handler = logging.FileHandler(
            filename=log_path, encoding="utf-8")
        file_handler.setFormatter(formatter)
        file_handler.setLevel(logging.INFO)
        logger.addHandler(file_handler)

    # FileHandler subclasses StreamHandler, so exclude it explicitly when
    # looking for an existing console handler.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in logger.handlers
    )
    if not has_console_handler:
        # Handler that echoes records to the console.
        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        console.setFormatter(formatter)
        logger.addHandler(console)

    return logger
|
|
|
|
|
|
def _split_dialogues(data):
    """
    Split raw corpus text into dialogues of non-empty utterances.

    Dialogues are separated by a blank line and contain one utterance per
    line. Windows (CRLF) line endings are normalized to LF first, so both
    conventions are handled by a single pass. Blank lines left over from a
    trailing newline or extra spacing are dropped instead of being turned
    into empty utterances.

    Args:
        data: Decoded text of the whole corpus.

    Returns:
        A list of dialogues, each a non-empty list of utterance strings.
    """
    normalized = data.replace("\r\n", "\n")
    dialogues = []
    for block in normalized.split("\n\n"):
        utterances = [line for line in block.split("\n") if line.strip()]
        if utterances:
            dialogues.append(utterances)
    return dialogues


def preprocess():
    """
    Tokenize the raw corpus and pickle the resulting token id lists.

    Each dialogue is encoded in the form
    "[CLS]utterance1[SEP]utterance2[SEP]utterance3[SEP]". Paths for the
    vocabulary, log file, input corpus and output pickle come from the
    command line. The mean, median and maximum encoded length are logged
    once the data has been written.
    """
    # Command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab_path', default='vocab/vocab.txt', type=str, required=False,
                        help='词表路径')
    parser.add_argument('--log_path', default='data/preprocess.log', type=str, required=False, help='训练日志存放位置')
    # The help text used to be a copy of the --log_path help; it now describes the corpus path.
    parser.add_argument('--train_path', default='data/train.txt', type=str, required=False, help='原始训练语料路径')
    parser.add_argument('--save_path', default='data/train.pkl', type=str, required=False, help='tokenize的训练数据集')
    args = parser.parse_args()

    # Logger
    logger = create_logger(args.log_path)

    # Tokenizer
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    sep_id = tokenizer.sep_token_id
    cls_id = tokenizer.cls_token_id
    logger.info("preprocessing data,data path:{}, save path:{}".format(args.train_path, args.save_path))

    # Read the raw corpus
    with open(args.train_path, 'rb') as f:
        data = f.read().decode("utf-8")

    # Line endings are normalized once here rather than re-scanning the whole
    # corpus for "\r\n" on every dialogue.
    train_data = _split_dialogues(data)
    logger.info("there are {} dialogue in dataset".format(len(train_data)))

    # Tokenize every dialogue into "[CLS]utterance1[SEP]utterance2[SEP]..."
    dialogue_len = []  # encoded length of each dialogue, used for the statistics below
    dialogue_list = []
    for utterances in tqdm(train_data):
        input_ids = [cls_id]  # every dialogue starts with [CLS]
        for utterance in utterances:
            input_ids += tokenizer.encode(utterance, add_special_tokens=False)
            input_ids.append(sep_id)  # [SEP] marks the end of each utterance
        dialogue_len.append(len(input_ids))
        dialogue_list.append(input_ids)

    # The output file is opened only here, in binary mode, once the data is ready.
    with open(args.save_path, "wb") as f:
        pickle.dump(dialogue_list, f)
    logger.info("finish preprocessing data,the result is stored in {}".format(args.save_path))

    if dialogue_len:
        len_mean = np.mean(dialogue_len)
        len_median = np.median(dialogue_len)
        len_max = np.max(dialogue_len)
        logger.info("mean of dialogue len:{},median of dialogue len:{},max len:{}".format(len_mean, len_median, len_max))
    else:
        # np.max raises on an empty sequence, so report the empty corpus instead.
        logger.warning("no dialogue found in {}".format(args.train_path))
|
|
|
|
|
|
# Run the preprocessing pipeline only when this file is executed as a script.
if __name__ == "__main__":
    preprocess()
|
|
|