from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
import torch
# from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM

app = FastAPI()

# MODEL = "google/flan-t5-small"
# MODEL = "jingyaogong/minimind-v1-small"
MODEL = "tclh123/minimind-v1-small"

# pipe_flan = pipeline("text2text-generation", model=MODEL, trust_remote_code=True)

# Load the MiniMind tokenizer and model. trust_remote_code=True is required
# because the checkpoint ships its own model and generation code.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL, trust_remote_code=True)
model = model.to(device)
model = model.eval()
def query(message, max_seq_len=512, temperature=0.7, top_k=16):
    # Prefix the user message with "请问," ("May I ask,").
    prompt = '请问,' + message
    messages = [{"role": "user", "content": prompt}]
    stream = True
    # Render the chat template, then keep only the last max_seq_len - 1
    # characters so the prompt fits the model's context window.
    new_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )[-(max_seq_len - 1):]
    x = tokenizer(new_prompt).data['input_ids']
    x = torch.tensor(x, dtype=torch.long, device=device)[None, ...]
    # MiniMind's custom generate() yields successive partial token sequences
    # when stream=True, so it is consumed as a generator below.
    res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=max_seq_len,
                           temperature=temperature, top_k=top_k, stream=stream)
    try:
        y = next(res_y)
    except StopIteration:
        # The model produced no output at all.
        return ""

    ret = []
    history_idx = 0
    while y is not None:
        answer = tokenizer.decode(y[0].tolist())
        # A trailing replacement character means the chunk ends mid-way
        # through a multi-byte character; wait for the next chunk.
        if answer and answer[-1] == '�':
            try:
                y = next(res_y)
            except StopIteration:
                break
            continue
        if not len(answer):
            try:
                y = next(res_y)
            except StopIteration:
                break
            continue
        # Append only the text produced since the previous chunk.
        ret.append(answer[history_idx:])
        try:
            y = next(res_y)
        except StopIteration:
            break
        history_idx = len(answer)
        if not stream:
            break
    ret.append('\n')
    return ''.join(ret)
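
# A minimal sanity-check sketch for query(), runnable in a Python shell once
# the model has loaded. The prompt is illustrative and the reply depends on
# the checkpoint, so no expected output is shown:
#
#   print(query("什么是大语言模型?"))  # "What is a large language model?"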
# Route decorator restored; the /infer_t5 path is assumed from the stock
# FastAPI + Flan-T5 Space template this file was adapted from.
@app.get("/infer_t5")
def t5(input):
    # output = pipe_flan(input)
    # return {"output": output[0]["generated_text"]}
    output = query(input)
    return {"output": output}
app.mount("/", StaticFiles(directory="static", html=True), name="static")

@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")
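
# A client-side sketch for exercising the inference endpoint. It assumes the
# Space listens on port 7860 (the Docker Spaces default) and uses the
# /infer_t5 route restored above; both are assumptions, not confirmed by
# this file:
#
#   import requests
#
#   resp = requests.get(
#       "http://localhost:7860/infer_t5",
#       params={"input": "你好"},  # "Hello"
#   )
#   print(resp.json()["output"])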