Update app.py
app.py
CHANGED
@@ -4,32 +4,28 @@ from fastapi.templating import Jinja2Templates
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import asyncio
-import os
+import gradio as gr
 
-# Create the static folder if it doesn't exist
-os.makedirs("static", exist_ok=True)
-
+# FastAPI initialization
 app = FastAPI()
 templates = Jinja2Templates(directory="templates")
 
-#
-model_name = "microsoft/Phi-3.5-mini-instruct"
-
+# Load the model (with error handling)
 try:
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
     model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype=torch.float16
+        "microsoft/Phi-3.5-mini-instruct",
+        torch_dtype=torch.float16,
         device_map="auto"
     )
 except Exception as e:
     print(f"Erreur de chargement du modèle: {str(e)}")
     raise
 
-async def generate_response(prompt: str):
-    try:
-        inputs = tokenizer(prompt, return_tensors="pt")
-
+# Shared generation function
+async def generate_tokens(prompt: str):
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    with torch.no_grad():
        for _ in range(512):
            outputs = model.generate(
                **inputs,
@@ -39,19 +35,35 @@ async def generate_response(prompt: str):
                top_p=0.9
            )
            new_token = tokenizer.decode(outputs[0][-1], skip_special_tokens=True)
-           yield f"data: {new_token}\n\n"
-           await asyncio.sleep(0.05)
+           yield new_token
            inputs = {"input_ids": outputs}
-    except Exception as e:
-        yield f"data: [ERREUR: {str(e)}]\n\n"
-
-@app.get("/", response_class=HTMLResponse)
-async def home(request: Request):
-    return templates.TemplateResponse("index.html", {"request": request})
 
+# FastAPI route for the website
 @app.get("/stream")
 async def stream(prompt: str):
     return StreamingResponse(
-        generate_response(prompt),
+        generate_tokens(prompt),
         media_type="text/event-stream"
-        )
+    )
+
+# Gradio interface
+async def gradio_interface(prompt: str):
+    full_response = ""
+    async for token in generate_tokens(prompt):
+        full_response += token
+    return full_response
+
+gradio_app = gr.Interface(
+    fn=gradio_interface,
+    inputs=gr.Textbox(label="Votre message"),
+    outputs=gr.Textbox(label="Réponse", interactive=False),
+    title="Chat avec (Gradio)"
+)
+
+# Mount both apps
+app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
+
+# Root route (can redirect to Gradio or your site)
+@app.get("/", response_class=HTMLResponse)
+async def home(request: Request):
+    return templates.TemplateResponse("index.html", {"request": request})
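
A note on the /stream route: the removed generate_response emitted SSE-framed chunks (data: ...\n\n), while the new generate_tokens yields raw text fragments even though the response still declares media_type="text/event-stream", so a browser EventSource will not parse the stream. A minimal client sketch under that assumption; the base URL and the use of requests are illustrative, not part of this commit:

# Minimal client sketch for the /stream route (the base URL and the use
# of `requests` are assumptions, not part of this commit).
import requests

with requests.get(
    "http://localhost:7860/stream",
    params={"prompt": "Bonjour"},
    stream=True,
) as resp:
    resp.raise_for_status()
    # generate_tokens yields raw text fragments, not "data: ...\n\n" SSE
    # events, so read the body as plain chunks rather than via EventSource.
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)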
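
Separately, the generation loop re-invokes model.generate on the whole accumulated sequence for every new token, which makes total work roughly quadratic in output length. A common alternative is transformers' TextIteratorStreamer, which streams text from a single generate call; this is a sketch assuming the same model and tokenizer as above (the 512-token cap mirrors the loop in the diff), not what this commit implements:

# Sketch: single-call streaming with TextIteratorStreamer (an alternative,
# not what this commit implements).
from threading import Thread
from transformers import TextIteratorStreamer

def generate_tokens_streamed(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so run it in a worker thread and drain the streamer here.
    Thread(
        target=model.generate,
        kwargs={**inputs, "streamer": streamer, "max_new_tokens": 512},
    ).start()
    for text in streamer:
        yield text

StreamingResponse accepts both sync and async iterables, so this generator could be passed to it in place of generate_tokens(prompt).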
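
With gr.mount_gradio_app, the Gradio UI is served under /gradio on the same server, while / continues to serve the Jinja template and /stream keeps serving the streaming endpoint.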