chat_due / app.py
AleRive's picture
Update app.py
690d3d8 verified
raw
history blame
1.5 kB
import os
import threading

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, FileResponse
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
app = FastAPI()

# Put the Hugging Face cache in a writable directory.
CACHE_DIR = "/tmp/huggingface"
# NOTE(review): huggingface_hub reads HF_HOME when it is first imported, which
# already happened through the `transformers` import at the top of the file, so
# this assignment most likely does not change the hub's defaults. The explicit
# `cache_dir=` arguments below are what actually redirect the downloads; set
# HF_HOME before importing transformers (or in the environment) if it matters.
os.environ["HF_HOME"] = CACHE_DIR
os.makedirs(CACHE_DIR, exist_ok=True)

# Load the BlenderBot 3B chat model (not DialoGPT, despite the old comment).
# BlenderBot is an encoder-decoder checkpoint, so it is loaded with the seq2seq
# auto class: AutoModelForCausalLM would build a decoder-only model and discard
# the encoder weights, which is not how the checkpoint was trained to be used.
model_name = "facebook/blenderbot-3B"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=CACHE_DIR)

# Serialize generation: the tokenizer and model are shared module-level
# objects, and running several 3B-parameter generations at once would multiply
# memory use. Requests still no longer block the event loop while they wait.
_generation_lock = threading.Lock()


@app.get("/")
async def serve_index():
    """Serve the chat front-end page."""
    return FileResponse("static/index.html")


@app.post("/chat")
async def chat(request: Request):
    """Return a model reply for a JSON body of the form ``{"prompt": "..."}``.

    A missing ``prompt`` key is treated as an empty string.

    Returns:
        ``{"response": <reply text>}`` on success, or ``{"error": ...}`` with
        HTTP 400 when the body is not valid JSON, is not a JSON object, or its
        ``prompt`` is not a string.
    """
    try:
        data = await request.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError both subclass it
        return JSONResponse({"error": "Request body must be valid JSON."}, status_code=400)

    if not isinstance(data, dict):
        return JSONResponse({"error": "Request body must be a JSON object."}, status_code=400)

    prompt = data.get("prompt", "")
    if not isinstance(prompt, str):
        return JSONResponse({"error": "'prompt' must be a string."}, status_code=400)

    # model.generate is blocking CPU/GPU work; calling it directly inside this
    # async handler would stall the event loop (and every other request,
    # including GET /) for the whole generation.
    response_text = await run_in_threadpool(_generate_reply, prompt)
    return JSONResponse({"response": response_text})


def _generate_reply(prompt: str) -> str:
    """Tokenize *prompt*, generate a reply with the seq2seq model, and decode it."""
    # Keep the encoder input within the model's position-embedding budget so an
    # overly long prompt is truncated instead of crashing the request. When the
    # config has no such limit, the tokenizer falls back to its own maximum.
    # BlenderBot's tokenizer appends the end-of-sequence token itself, so it is
    # not concatenated by hand here.
    max_input_tokens = getattr(model.config, "max_position_embeddings", None)
    with _generation_lock:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_input_tokens,
        )
        # `inputs` carries both input_ids and attention_mask.
        output_ids = model.generate(**inputs, max_length=100, num_return_sequences=1)
    # For an encoder-decoder model, generate() returns decoder tokens only, so
    # the whole sequence is the reply: there is no prompt prefix to slice off.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
if __name__ == "__main__":
    # Bind on every interface (0.0.0.0) rather than localhost only.
    # NOTE(review): 7860 is presumably the port the hosting Space routes
    # traffic to; confirm before changing it.
    uvicorn.run(app, port=7860, host="0.0.0.0")