# Spaces:
# Runtime error
# Runtime error
import os | |
from threading import Thread | |
import platform | |
from typing import Union | |
import torch | |
from transformers import TextIteratorStreamer,PreTrainedTokenizerFast | |
from safetensors.torch import load_model | |
from accelerate import init_empty_weights, load_checkpoint_and_dispatch | |
# import 自定义类和函数 | |
from model.chat_model import TextToTextModel | |
from utils.functions import get_T5_config | |
from config import InferConfig, T5ModelConfig | |
class ChatBot:
    """Inference wrapper around a T5-style ``TextToTextModel``.

    Loads a tokenizer and model weights from ``infer_config.model_dir`` and
    exposes streaming generation (``stream_chat``) and non-streaming batch
    generation (``chat``).
    """

    def __init__(self, infer_config: InferConfig) -> None:
        '''
        Load the tokenizer and model described by ``infer_config``.

        Weight loading is attempted in this order:
          1. ``from_pretrained`` when ``model_dir`` is a directory,
          2. safetensors when the path ends with ``.safetensors``,
          3. a plain torch ``state_dict`` checkpoint,
        falling back to accelerate's ``load_checkpoint_and_dispatch`` if all
        of the above raise.
        '''
        self.infer_config = infer_config

        # Tokenizer plus convenience aliases for its bound methods.
        tokenizer = PreTrainedTokenizerFast.from_pretrained(infer_config.model_dir)
        self.tokenizer = tokenizer
        self.encode = tokenizer.encode_plus
        self.batch_decode = tokenizer.batch_decode
        self.batch_encode_plus = tokenizer.batch_encode_plus

        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        try:
            model = TextToTextModel(t5_config)
            if os.path.isdir(infer_config.model_dir):
                # Hugging Face checkpoint directory.
                model = model.from_pretrained(infer_config.model_dir)
            elif infer_config.model_dir.endswith('.safetensors'):
                # Single safetensors weight file.
                load_model(model, infer_config.model_dir)
            else:
                # Plain torch checkpoint. map_location='cpu' prevents a crash
                # when the checkpoint was saved on GPU but this host is
                # CPU-only; weights are moved to the target device below.
                model.load_state_dict(
                    torch.load(infer_config.model_dir, map_location='cpu')
                )
            self.model = model
        except Exception as e:
            # Deliberate best-effort fallback: let accelerate materialize and
            # dispatch the weights (e.g. sharded / large checkpoints).
            print(str(e), 'transformers and pytorch load fail, try accelerate load function.')
            with init_empty_weights():
                empty_model = TextToTextModel(t5_config)
            self.model = load_checkpoint_and_dispatch(
                model=empty_model,
                checkpoint=infer_config.model_dir,
                device_map='auto',
                dtype=torch.float16,
            )

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): calling .to() on a model already dispatched by
        # accelerate with device_map='auto' may conflict with its device map —
        # harmless on the common single-device path, but worth confirming.
        self.model.to(self.device)
        # Inference only: disable dropout for deterministic generation.
        self.model.eval()

        self.streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )

    def stream_chat(self, input_txt: str) -> TextIteratorStreamer:
        '''
        Streaming chat: generation runs in a background thread and this method
        returns immediately; iterate the returned streamer to consume the
        generated text. Only greedy search is supported.
        '''
        encoded = self.encode(input_txt + '[EOS]')

        input_ids = torch.LongTensor([encoded.input_ids]).to(self.device)
        attention_mask = torch.LongTensor([encoded.attention_mask]).to(self.device)

        generation_kwargs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'max_seq_len': self.infer_config.max_seq_len,
            'streamer': self.streamer,
            'search_type': 'greedy',
        }

        thread = Thread(target=self.model.my_generate, kwargs=generation_kwargs)
        thread.start()

        return self.streamer

    def chat(self, input_txt: Union[str, list[str]]) -> Union[str, list[str]]:
        '''
        Non-streaming generation using greedy search.

        Accepts a single prompt or a batch of prompts; returns a single string
        for single input, a list of strings for batched input. Empty
        generations are replaced by a fixed fallback notice.

        Raises:
            TypeError: if ``input_txt`` is neither ``str`` nor ``list[str]``.
        '''
        if isinstance(input_txt, str):
            input_txt = [input_txt]
        elif not isinstance(input_txt, list):
            raise TypeError('input_txt must be a str or list[str]')

        # Append the EOS marker expected by the model's training format.
        input_txts = [f"{txt}[EOS]" for txt in input_txt]
        encoded = self.batch_encode_plus(input_txts, padding=True)

        input_ids = torch.LongTensor(encoded.input_ids).to(self.device)
        attention_mask = torch.LongTensor(encoded.attention_mask).to(self.device)

        outputs = self.model.my_generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_seq_len=self.infer_config.max_seq_len,
            search_type='greedy',
        )

        outputs = self.batch_decode(
            outputs.cpu().numpy(),
            clean_up_tokenization_spaces=True,
            skip_special_tokens=True,
        )

        # Small models sometimes emit nothing; substitute a friendly notice.
        note = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
        outputs = [item if len(item) != 0 else note for item in outputs]

        return outputs[0] if len(outputs) == 1 else outputs