This repository has been archived on 2026-03-12. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
nlp-chatbot/cqhttp.py
T
2026-03-12 11:09:11 +08:00

149 lines
5.5 KiB
Python

import torch
import os
from transformers import GPT2LMHeadModel
from transformers import BertTokenizerFast
import torch.nn.functional as F
import random
import asyncio
import websockets
import json
from cqhttp_settings import *
from interact import set_args, create_logger, top_k_top_p_filtering
# Padding token string and the id this module pairs with it.
# NOTE(review): neither name is referenced by the code visible here — confirm usage.
PAD = "[PAD]"
pad_id = 0
def init_model():
    """Parse arguments, then load the tokenizer and GPT-2 model for replying.

    Returns:
        A tuple ``(args, device, tokenizer, model)``:
        - ``args`` comes from ``interact.set_args()``; ``args.cuda`` is set here.
        - ``device`` is ``'cuda'`` or ``'cpu'``.
        - ``tokenizer`` is a ``BertTokenizerFast`` built from ``args.vocab_path``.
        - ``model`` is loaded from ``args.model_path``, moved to ``device`` and
          switched to eval mode.
    """
    args = set_args()
    logger = create_logger(args)
    # CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is first
    # initialised. torch.cuda.is_available() may perform that initialisation,
    # so export the variable first; setting it afterwards has no effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # Use the GPU only when one is available and the user has not disabled it.
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model = model.to(device)
    model.eval()
    return args, device, tokenizer, model
def _apply_repetition_penalty(logits, token_ids, penalty):
    """Lower, in place, the logits of every token id in ``token_ids``.

    CTRL-style rule: positive logits are divided by ``penalty`` and negative
    logits are multiplied by it, so a penalty > 1 always makes a repeated token
    less likely. Plain division would *raise* a negative logit.
    """
    for token_id in set(token_ids):
        if logits[token_id] < 0:
            logits[token_id] *= penalty
        else:
            logits[token_id] /= penalty


def process(text, args, device, tokenizer, model):
    """Generate a reply to ``text`` and record both turns in ``history``.

    The user's utterance is appended to the module-level ``history`` list, which
    every caller shares. The last ``args.max_history_len`` utterances are then
    fed to the model as ``[CLS] u1 [SEP] u2 [SEP] ...``. Tokens are sampled one
    at a time until ``[SEP]`` is produced or ``args.max_len`` tokens exist.
    Sampling applies the repetition penalty, temperature, a ban on ``[UNK]``
    and top-k/top-p filtering. The generated token ids are appended to
    ``history`` too.

    Returns:
        The generated tokens concatenated with no separator. This may be empty
        if the model emits ``[SEP]`` straight away.
    """
    history.append(tokenizer.encode(text, add_special_tokens=False))

    input_ids = [tokenizer.cls_token_id]  # every model input starts with [CLS]
    for utterance in history[-args.max_history_len:]:
        input_ids.extend(utterance)
        input_ids.append(tokenizer.sep_token_id)
    input_ids = torch.tensor(input_ids).long().to(device).unsqueeze(0)

    unk_id = tokenizer.convert_tokens_to_ids('[UNK]')
    response = []  # token ids generated for this reply
    # Inference only: no_grad avoids building an autograd graph on every step.
    with torch.no_grad():
        for _ in range(args.max_len):
            logits = model(input_ids=input_ids).logits
            next_token_logits = logits[0, -1, :]
            # Discourage tokens that already appear in this reply.
            _apply_repetition_penalty(next_token_logits, response, args.repetition_penalty)
            next_token_logits = next_token_logits / args.temperature
            # Never allow the model to emit [UNK].
            next_token_logits[unk_id] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp)
            # Draw one token id, weighted by the filtered probabilities.
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            if next_token.item() == tokenizer.sep_token_id:  # [SEP] ends the reply
                break
            response.append(next_token.item())
            input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)
    history.append(response)
    return "".join(tokenizer.convert_ids_to_tokens(response))
# Conversation history: one entry per utterance, each a list of token ids.
# A single module-level list is used for every chat the bot answers.
history = list()
def handle_event(evjson):
    """Decide whether an incoming CQHTTP event should get a reply.

    Args:
        evjson: Raw JSON text of one frame received from the websocket.

    Returns:
        A tuple ``(from_id, msg_type, req)``:
        - ``req`` is the message text to answer. It is ``""`` when the event
          lacks ``message_type``/``raw_message``, comes from a group not in
          ``event_response_settings_group_enabled``, has another message type,
          or loses the random roll.
        - ``from_id`` is the group id (group messages) or user id (private
          messages) to reply to, or ``""`` when no reply is due.
        - ``msg_type`` is the event's ``message_type`` whenever the event
          carries both keys, and ``""`` otherwise.
    """
    event = json.loads(evjson)
    from_id = ""
    msg_type = ""
    req = ""
    # NOTE: assumes the configured *_rate settings are percentages (0-100).
    # randrange(100) yields 0..99, so `roll < rate` succeeds with probability
    # rate/100. randint(0, 100) would have 101 outcomes, and a rate of 100
    # would then reply only 100/101 of the time.
    rand_response_possibility = random.randrange(100)
    if 'message_type' not in event or 'raw_message' not in event:
        return from_id, msg_type, req
    msg_type = event['message_type']
    msg_recv = event['raw_message']
    print("收到消息: ", msg_recv)
    if msg_type == "group" and event['group_id'] in event_response_settings_group_enabled:
        # Group message from an enabled group.
        print("回复几率(?): {} < {}".format(rand_response_possibility, event_response_settings_group_rate))
        if rand_response_possibility < event_response_settings_group_rate:
            from_id = event['group_id']
            req = msg_recv
    elif msg_type == "private":
        # Private (direct) message.
        print("回复几率(?): {} < {}".format(rand_response_possibility, event_response_settings_private_rate))
        if rand_response_possibility < event_response_settings_private_rate:
            from_id = event['user_id']
            req = msg_recv
    return from_id, msg_type, req
async def send_msg(cqhttp, source, msg_type, text):
    """Send ``text`` back over the CQHTTP websocket as a JSON action request.

    Args:
        cqhttp: The open websocket connection (anything with an async ``send``).
        source: Group id for ``"group"`` messages, user id for ``"private"``.
        msg_type: ``"group"`` or ``"private"``. Any other value is logged and
            nothing is sent.
        text: Message body.
    """
    print("Msg to {}({}): {}".format(source, msg_type, text))
    # msg_type -> (action name, key under 'params' that identifies the target)
    actions = {
        "group": ("send_group_msg", "group_id"),
        "private": ("send_private_msg", "user_id"),
    }
    if msg_type not in actions:
        # An empty {} payload would carry no 'action', so there is nothing
        # meaningful to send.
        print("Unsupported message type, not sending: {}".format(msg_type))
        return
    action, target_key = actions[msg_type]
    data_send = {
        'action': action,
        'params': {
            target_key: source,
            'message': text,
        },
    }
    await cqhttp.send(json.dumps(data_send))
async def init_cqhttp_ws(args, device, tokenizer, model):
    """Connect to the CQHTTP websocket and answer incoming messages.

    Each received frame goes through ``handle_event``. When that yields a
    request, a reply is generated with ``process`` and sent back with
    ``send_msg``. The loop runs until an exception propagates out of it, for
    example when the connection is closed.
    """
    print("Initializing CQHTTP WebSocket...")
    loop = asyncio.get_running_loop()
    async with websockets.connect(cqhttp_ws_addr + "?access_token=" + cqhttp_ws_access_token) as cqhttp:
        while True:
            evjson = await cqhttp.recv()
            source, msg_type, req = handle_event(evjson)
            print(req)
            if not req:
                continue
            # process() is slow, blocking model inference. Run it in a worker
            # thread so the event loop keeps servicing the websocket (e.g. the
            # keepalive pings websockets sends by default) in the meantime.
            # Events are still handled one at a time because we await here.
            res = await loop.run_in_executor(None, process, req, args, device, tokenizer, model)
            if res:
                print("Q: {}\nA: {}\n".format(req, res))
                await send_msg(cqhttp, source, msg_type, res)
if __name__ == '__main__':
    args, device, tokenizer, model = init_model()
    # asyncio.run creates, runs and closes its own event loop. Calling
    # asyncio.get_event_loop() with no running loop is deprecated and raises
    # RuntimeError from Python 3.14.
    asyncio.run(init_cqhttp_ws(args, device, tokenizer, model))