# chatmlTest/model/infer.py
import os
from threading import Thread
from typing import Union

import torch
from transformers import TextIteratorStreamer, PreTrainedTokenizerFast
from safetensors.torch import load_model
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# import custom classes and functions
from model.chat_model import TextToTextModel
from utils.functions import get_T5_config
from config import InferConfig, T5ModelConfig


class ChatBot:
    def __init__(self, infer_config: InferConfig) -> None:
        '''
        Load the tokenizer and model described by `infer_config`, then move
        the model to the available device.
        '''
        self.infer_config = infer_config

        # initialize the tokenizer
        tokenizer = PreTrainedTokenizerFast.from_pretrained(infer_config.model_dir)
        self.tokenizer = tokenizer
        self.encode = tokenizer.encode_plus
        self.batch_decode = tokenizer.batch_decode
        self.batch_encode_plus = tokenizer.batch_encode_plus

        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        try:
            model = TextToTextModel(t5_config)
            if os.path.isdir(infer_config.model_dir):
                # load a transformers-style checkpoint directory
                model = model.from_pretrained(infer_config.model_dir)
            elif infer_config.model_dir.endswith('.safetensors'):
                # load a single safetensors file
                load_model(model, infer_config.model_dir)
            else:
                # load a plain torch checkpoint (state dict)
                model.load_state_dict(torch.load(infer_config.model_dir))
            self.model = model
        except Exception as e:
            print(str(e), 'loading with transformers/pytorch failed, trying the accelerate loader.')
            with init_empty_weights():
                empty_model = TextToTextModel(t5_config)
            self.model = load_checkpoint_and_dispatch(
                model=empty_model,
                checkpoint=infer_config.model_dir,
                device_map='auto',
                dtype=torch.float16,
            )

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.streamer = TextIteratorStreamer(tokenizer=tokenizer, clean_up_tokenization_spaces=True, skip_special_tokens=True)
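
    # Construction sketch for the three checkpoint layouts handled above
    # (the paths and the `model_dir` field are illustrative assumptions):
    #   ChatBot(InferConfig(model_dir='./model_save/'))        # from_pretrained dir
    #   ChatBot(InferConfig(model_dir='./model.safetensors'))  # safetensors file
    #   ChatBot(InferConfig(model_dir='./model.bin'))          # torch state dict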

    def stream_chat(self, input_txt: str) -> TextIteratorStreamer:
        '''
        Streaming chat: generation runs in a background thread; iterate over
        the returned streamer to receive text as it is produced. Only greedy
        search is supported.
        '''
        encoded = self.encode(input_txt + '[EOS]')

        input_ids = torch.LongTensor([encoded.input_ids]).to(self.device)
        attention_mask = torch.LongTensor([encoded.attention_mask]).to(self.device)

        generation_kwargs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'max_seq_len': self.infer_config.max_seq_len,
            'streamer': self.streamer,
            'search_type': 'greedy',
        }

        thread = Thread(target=self.model.my_generate, kwargs=generation_kwargs)
        thread.start()

        return self.streamer
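
    # Minimal consumption sketch (assumes a constructed `bot`; the background
    # thread keeps feeding the streamer until generation reaches EOS):
    #   for new_text in bot.stream_chat('你好'):
    #       print(new_text, end='', flush=True)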

    def chat(self, input_txt: Union[str, list[str]]) -> Union[str, list[str]]:
        '''
        Non-streaming generation; text can also be generated with methods such
        as beam search or beam sampling.
        '''
        if isinstance(input_txt, str):
            input_txt = [input_txt]
        elif not isinstance(input_txt, list):
            raise TypeError('input_txt must be a str or list[str]')

        # add EOS token
        input_txts = [f"{txt}[EOS]" for txt in input_txt]
        encoded = self.batch_encode_plus(input_txts, padding=True)
        input_ids = torch.LongTensor(encoded.input_ids).to(self.device)
        attention_mask = torch.LongTensor(encoded.attention_mask).to(self.device)

        outputs = self.model.my_generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_seq_len=self.infer_config.max_seq_len,
            search_type='greedy',
        )

        outputs = self.batch_decode(outputs.cpu().numpy(), clean_up_tokenization_spaces=True, skip_special_tokens=True)

        # fall back to a fixed apology ("I'm a small model with limited
        # knowledge; please try another question") when the output is empty
        note = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
        outputs = [item if len(item) != 0 else note for item in outputs]

        return outputs[0] if len(outputs) == 1 else outputs
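

# A minimal end-to-end usage sketch, not part of the original file; it assumes
# the default InferConfig points at a valid checkpoint:
if __name__ == '__main__':
    bot = ChatBot(InferConfig())

    # non-streaming: returns a str for a single prompt
    print(bot.chat('你好'))

    # streaming: print text as the background thread generates it
    for new_text in bot.stream_chat('你好'):
        print(new_text, end='', flush=True)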