149 lines
5.5 KiB
Python
149 lines
5.5 KiB
Python
import torch
|
|
import os
|
|
from transformers import GPT2LMHeadModel
|
|
from transformers import BertTokenizerFast
|
|
import torch.nn.functional as F
|
|
|
|
import random
|
|
import asyncio
|
|
import websockets
|
|
import json
|
|
|
|
from cqhttp_settings import *
|
|
from interact import set_args, create_logger, top_k_top_p_filtering
|
|
|
|
# Padding token text and the id paired with it (presumably its index in the
# vocab file — confirm against the vocab). Neither name is read by the code below.
PAD, pad_id = '[PAD]', 0
|
|
|
|
|
|
def init_model():
    """Parse options, choose the device, and load the tokenizer and GPT-2 model.

    Returns:
        (args, device, tokenizer, model): the parsed options (with ``args.cuda``
        filled in), ``'cuda'`` or ``'cpu'``, a ``BertTokenizerFast`` built from
        ``args.vocab_path``, and the ``GPT2LMHeadModel`` loaded from
        ``args.model_path``, moved to *device* and switched to eval mode.
    """
    args = set_args()
    logger = create_logger(args)

    # CUDA reads CUDA_VISIBLE_DEVICES only when it initialises, and
    # torch.cuda.is_available() may trigger that initialisation, so the
    # variable has to be set before the availability check to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # Use the GPU only when one is available and the user has not disabled it.
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))

    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model = model.to(device)
    model.eval()

    return args, device, tokenizer, model
|
|
|
|
|
|
def process(text, args, device, tokenizer, model):
    """Generate a reply to *text* with the GPT-2 model and record both turns.

    The new utterance's token ids are appended to the module-level ``history``.
    The model input is ``[CLS] u1 [SEP] u2 [SEP] ...`` built from the last
    ``args.max_history_len`` utterances. Up to ``args.max_len`` tokens are then
    sampled one at a time, in this order:

    1. apply a repetition penalty;
    2. apply temperature;
    3. ban ``[UNK]``;
    4. apply top-k/top-p filtering.

    Sampling stops early when ``[SEP]`` is drawn. The reply's ids are appended
    to ``history`` as well.

    NOTE: ``history`` is a single module-level list with no per-chat key, so
    context from different chats is mixed together.

    Args:
        text: the incoming message text.
        args: parsed options; reads ``max_history_len``, ``max_len``,
            ``repetition_penalty``, ``temperature``, ``topk`` and ``topp``.
        device: torch device string the model lives on.
        tokenizer: tokenizer providing ``encode``, ``cls_token_id``,
            ``sep_token_id`` and the ``convert_*`` helpers.
        model: causal LM whose output exposes ``.logits``.

    Returns:
        The reply as the concatenation of its token strings (possibly empty).
    """
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    history.append(text_ids)

    # Every input starts with [CLS]; each past utterance is followed by [SEP].
    input_ids = [tokenizer.cls_token_id]
    for history_utr in history[-args.max_history_len:]:
        input_ids.extend(history_utr)
        input_ids.append(tokenizer.sep_token_id)
    input_ids = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)

    # Loop-invariant values, looked up once instead of once per generated token.
    unk_id = tokenizer.convert_tokens_to_ids('[UNK]')
    penalty = args.repetition_penalty

    response = []  # token ids generated so far for this reply
    # Inference only: without no_grad every forward pass would record an autograd graph.
    with torch.no_grad():
        # Generate at most max_len tokens.
        for _ in range(args.max_len):
            outputs = model(input_ids=input_ids)
            next_token_logits = outputs.logits[0, -1, :]
            # Repetition penalty: make already-generated tokens less likely.
            # Dividing a negative logit would move it toward zero and *raise* its
            # probability, so negative logits are multiplied instead.
            for token_id in set(response):
                if next_token_logits[token_id] < 0:
                    next_token_logits[token_id] *= penalty
                else:
                    next_token_logits[token_id] /= penalty
            next_token_logits = next_token_logits / args.temperature
            # An -inf logit means [UNK] can never be sampled.
            next_token_logits[unk_id] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
            # Draw one token id with probability given by the softmax of the filtered logits.
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            next_id = next_token.item()
            if next_id == tokenizer.sep_token_id:  # [SEP] ends the reply
                break
            response.append(next_id)
            input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)

    history.append(response)
    # Only the last max_history_len entries are ever read (the slice above runs
    # after the next message is appended), so older entries can be dropped to
    # keep memory bounded without changing the model input.
    if args.max_history_len > 0:
        del history[:-args.max_history_len]

    return "".join(tokenizer.convert_ids_to_tokens(response))
|
|
|
|
|
|
# Chat history: one entry per utterance, each entry a list of token ids.
# process() appends both the incoming message and the generated reply here.
history = list()
|
|
|
|
|
|
def handle_event(evjson):
    """Parse one CQHTTP frame and decide whether the bot should answer it.

    Args:
        evjson: the raw JSON text of one frame received from the websocket.

    Returns:
        ``(from_id, msg_type, req)``.

        - ``req``: the message text to answer.
        - ``from_id``: the group id (group chat) or user id (private chat) to
          reply to.
        - ``msg_type``: the event's ``message_type`` whenever the frame carries
          both ``message_type`` and ``raw_message``, otherwise ``""``.

        ``req`` and ``from_id`` stay ``""`` in any of these cases: the frame is
        not such a message, it comes from a group not in
        ``event_response_settings_group_enabled``, it has another message type,
        or the random reply-rate draw fails.
    """
    event = json.loads(evjson)
    from_id = ""
    msg_type = ""
    req = ""
    # Uniform draw in 0..99, so ``draw < rate`` succeeds with probability
    # rate/100 (the rate settings are presumably percentages). randint(0, 100)
    # would give 101 outcomes, so even a rate of 100 could fail.
    rand_response_possibility = random.randrange(100)

    if 'message_type' in event and 'raw_message' in event:
        msg_type = event['message_type']
        msg_recv = event['raw_message']
        print("收到消息: ", msg_recv)

        if msg_type == "group" and event['group_id'] in event_response_settings_group_enabled:
            # Group message from an enabled group.
            print("回复几率(?): {} < {}".format(rand_response_possibility, event_response_settings_group_rate))
            if rand_response_possibility < event_response_settings_group_rate:
                from_id = event['group_id']
                req = msg_recv

        elif msg_type == "private":
            # Private (direct) message.
            print("回复几率(?): {} < {}".format(rand_response_possibility, event_response_settings_private_rate))
            if rand_response_possibility < event_response_settings_private_rate:
                from_id = event['user_id']
                req = msg_recv

    return from_id, msg_type, req
|
|
|
|
|
|
async def send_msg(cqhttp, source, msg_type, text):
    """Send *text* back to a chat through the CQHTTP websocket.

    Args:
        cqhttp: open websocket connection to CQHTTP (anything with an async ``send``).
        source: group id for ``"group"`` messages, user id for ``"private"`` ones.
        msg_type: ``"group"`` or ``"private"``. Any other value is logged and
            nothing is sent, because there is no matching API action for it.
        text: the message body.
    """
    print("Msg to {}({}): {}".format(source, msg_type, text))

    if msg_type == "group":
        # Group message.
        action, target_key = "send_group_msg", 'group_id'
    elif msg_type == "private":
        # Private message.
        action, target_key = "send_private_msg", 'user_id'
    else:
        # An empty payload is not a valid API call; skip instead of sending "{}".
        print("Unsupported message type, not sending: {}".format(msg_type))
        return

    data_send = {
        'action': action,
        'params': {
            target_key: source,
            'message': text,
        },
    }
    await cqhttp.send(json.dumps(data_send))
|
|
|
|
|
|
async def init_cqhttp_ws(args, device, tokenizer, model):
    """Connect to CQHTTP over a websocket and answer incoming messages forever.

    Each received frame is passed to ``handle_event``. When that yields text to
    answer, ``process`` generates a reply and ``send_msg`` sends it back to the
    same chat. Errors raised while handling one frame are printed and skipped.
    Errors from the connection itself (``recv``) end the coroutine.

    Args:
        args, device, tokenizer, model: the values returned by ``init_model``,
            forwarded unchanged to ``process``.
    """
    import traceback
    from urllib.parse import quote

    print("Initializing CQHTTP WebSocket...")
    loop = asyncio.get_running_loop()
    # Percent-encode the token so characters such as '&', '#' or '+' survive
    # the query string intact.
    url = cqhttp_ws_addr + "?access_token=" + quote(cqhttp_ws_access_token, safe='')

    async with websockets.connect(url) as cqhttp:
        while True:
            evjson = await cqhttp.recv()
            try:
                source, msg_type, req = handle_event(evjson)
                print(req)
                if not req:
                    continue
                # process() runs the model synchronously. Running it in a worker
                # thread keeps the event loop free to service the connection
                # (e.g. websocket keepalive pings) during long generations.
                # Awaiting it here still handles one message at a time, which
                # matters because process() mutates a module-level history list.
                res = await loop.run_in_executor(None, process, req, args, device, tokenizer, model)
            except Exception:
                # One malformed frame or failed generation should not stop the bot.
                traceback.print_exc()
                continue

            if res:
                print("Q: {}\nA: {}\n".format(req, res))
                await send_msg(cqhttp, source, msg_type, res)
|
|
|
|
|
|
if __name__ == '__main__':
    args, device, tokenizer, model = init_model()
    # asyncio.run creates, runs and closes its own event loop. Calling
    # asyncio.get_event_loop() with no running loop is deprecated (Python 3.10+)
    # and raises RuntimeError from Python 3.14 when no loop is set.
    asyncio.run(init_cqhttp_ws(args, device, tokenizer, model))
|