from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

app = FastAPI()

# Earlier revisions of this app served other models via a transformers
# pipeline; the commented names are kept for reference.
# MODEL = "google/flan-t5-small"
# MODEL = "jingyaogong/minimind-v1-small"
MODEL = "tclh123/minimind-v1-small"

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# MiniMind ships custom model/tokenizer code, so trust_remote_code is required.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL, trust_remote_code=True)
model = model.to(device).eval()
def query(message, max_seq_len=512, temperature=0.7, top_k=16):
    # "请问," ("May I ask,") prefixes the user message as a polite question
    # opener; the model is trained on Chinese chat data.
    prompt = '请问,' + message
    messages = [{"role": "user", "content": prompt}]
    stream = True

    # Render the chat template, then left-truncate the string (by characters,
    # not tokens) so the prompt fits within max_seq_len.
    new_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )[-(max_seq_len - 1):]

    x = tokenizer(new_prompt)['input_ids']
    x = torch.tensor(x, dtype=torch.long, device=device)[None, ...]

    # MiniMind's custom generate() takes the eos token id positionally and,
    # with stream=True, returns a generator that yields partial sequences.
    res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=max_seq_len,
                           temperature=temperature, top_k=top_k, stream=stream)
    try:
        y = next(res_y)
    except StopIteration:
        return ""

    ret = []
    history_idx = 0
    while y is not None:
        answer = tokenizer.decode(y[0].tolist())
        # A trailing '�' means the stream stopped mid-way through a multi-byte
        # character; fetch more tokens before emitting anything.
        if answer and answer[-1] == '�':
            try:
                y = next(res_y)
            except StopIteration:
                break
            continue
        if not answer:
            try:
                y = next(res_y)
            except StopIteration:
                break
            continue
        # Append only the suffix that hasn't been emitted yet.
        ret.append(answer[history_idx:])
        try:
            y = next(res_y)
        except StopIteration:
            break
        history_idx = len(answer)
        if not stream:
            break
    ret.append('\n')
    return ''.join(ret)
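
# Usage sketch: query() consumes the stream internally and returns the full
# decoded reply as one string, e.g.
#   reply = query("你是谁?")  # "Who are you?"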
@app.get("/infer_t5")
def t5(input):
# output = pipe_flan(input)
# return {"output": output[0]["generated_text"]}
output = query(input)
return {"output": output}
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")

# Mount static files last: a mount at "/" matches every path, so any route
# registered after it would never be reached.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
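
# Local run sketch (assumptions: this file is named main.py, uvicorn is
# installed, and ./static plus /app/static/index.html exist as in the image):
#   uvicorn main:app --host 0.0.0.0 --port 7860
#   curl "http://localhost:7860/infer_t5?input=你好"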