AleRive commited on
Commit
1b5ee57
·
verified ·
1 Parent(s): 8e62ac4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -22
app.py CHANGED
@@ -1,45 +1,46 @@
1
  import os
2
-
3
- # Imposta la cache dei modelli in una cartella scrivibile all'interno della home dell'utente
4
- os.environ["HF_HOME"] = "/tmp/huggingface"
5
-
6
  from fastapi import FastAPI, Request
7
  from fastapi.responses import JSONResponse, FileResponse
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
9
  import uvicorn
10
 
11
  app = FastAPI()
12
 
13
- # Crea la cartella di cache se non esiste
 
14
  os.makedirs("/tmp/huggingface", exist_ok=True)
15
 
16
- # Carica il modello Hugging Face
17
- model_name = "microsoft/DialoGPT-medium"
18
- tokenizer = AutoTokenizer.from_pretrained(model_name)
19
- model = AutoModelForCausalLM.from_pretrained(model_name)
20
 
21
- # Servire il frontend statico
22
  @app.get("/")
23
  async def serve_index():
24
  return FileResponse("static/index.html")
25
 
26
- # API per la chat
27
  @app.post("/chat")
28
  async def chat(request: Request):
29
  data = await request.json()
30
  prompt = data.get("prompt", "")
31
 
32
- # Tokenizzazione e generazione della risposta
33
- inputs = tokenizer(prompt, return_tensors="pt")
34
- outputs = model.generate(
35
- inputs["input_ids"],
36
- max_length=50,
37
- pad_token_id=tokenizer.eos_token_id, # Aggiunto per evitare warning
38
- attention_mask=inputs["attention_mask"] # Aggiunto per maggiore stabilità
39
- )
40
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
41
-
42
- return JSONResponse({"response": response})
 
 
 
 
 
43
 
44
  if __name__ == "__main__":
45
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
 
 
 
 
2
  from fastapi import FastAPI, Request
3
  from fastapi.responses import JSONResponse, FileResponse
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ import torch
6
  import uvicorn
7
 
8
  app = FastAPI()
9
 
10
+ # Imposta la cache per Hugging Face in una directory scrivibile
11
+ os.environ["HF_HOME"] = "/tmp/huggingface"
12
  os.makedirs("/tmp/huggingface", exist_ok=True)
13
 
14
+ # Carica il modello DialoGPT
15
+ model_name = "microsoft/DialoGPT-small"
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/huggingface")
17
+ model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="/tmp/huggingface")
18
 
 
19
  @app.get("/")
20
  async def serve_index():
21
  return FileResponse("static/index.html")
22
 
 
23
  @app.post("/chat")
24
  async def chat(request: Request):
25
  data = await request.json()
26
  prompt = data.get("prompt", "")
27
 
28
+ # Tokenizzazione del prompt
29
+ input_ids = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors="pt")
30
+
31
+ # Generazione della risposta
32
+ response_ids = model.generate(
33
+ input_ids,
34
+ max_length=100,
35
+ num_return_sequences=1,
36
+ pad_token_id=tokenizer.eos_token_id,
37
+ attention_mask=torch.ones(input_ids.shape, dtype=torch.long) # Aggiunto per correggere l'errore
38
+ )
39
+
40
+ # Decodifica della risposta
41
+ response_text = tokenizer.decode(response_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
42
+
43
+ return JSONResponse({"response": response_text})
44
 
45
  if __name__ == "__main__":
46
  uvicorn.run(app, host="0.0.0.0", port=7860)