GueuleDange committed
Commit 0eadccb · verified · 1 Parent(s): 754f340

Update app.py

Files changed (1):
  1. app.py +73 -52
app.py CHANGED
@@ -1,69 +1,90 @@
  from fastapi import FastAPI, Request
- from fastapi.responses import StreamingResponse, HTMLResponse
- from fastapi.templating import Jinja2Templates
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
  import asyncio
- import gradio as gr

- # FastAPI initialization
  app = FastAPI()
- templates = Jinja2Templates(directory="templates")

- # Load the model (with error handling)
- try:
-     tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
-     model = AutoModelForCausalLM.from_pretrained(
-         "microsoft/Phi-3.5-mini-instruct",
-         torch_dtype=torch.float16,
-         device_map="auto"
-     )
- except Exception as e:
-     print(f"Erreur de chargement du modèle: {str(e)}")
-     raise

- # Shared generation function
- async def generate_tokens(prompt: str):
      inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     with torch.no_grad():
-         for _ in range(512):
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=1,
-                 do_sample=True,
-                 temperature=0.7,
-                 top_p=0.9
-             )
-             new_token = tokenizer.decode(outputs[0][-1], skip_special_tokens=True)
-             yield new_token
-             inputs = {"input_ids": outputs}

- # FastAPI route for the website
  @app.get("/stream")
- async def stream(prompt: str):
      return StreamingResponse(
-         generate_tokens(prompt),
          media_type="text/event-stream"
      )

- # Gradio interface
- def gradio_interface(prompt: str):
-     full_response = ""
-     for token in generate_tokens(prompt):
-         full_response += token
-     return full_response
-
- gradio_app = gr.Interface(
-     fn=gradio_interface,
-     inputs=gr.Textbox(label="Votre message"),
-     outputs=gr.Textbox(label="Réponse", interactive=False),
-     title="Chat avec (Gradio)"
  )

- # Mount the two apps
- app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
-
- # Root route (can redirect to Gradio or your own site)
- @app.get("/", response_class=HTMLResponse)
- async def home(request: Request):
-     return templates.TemplateResponse("index.html", {"request": request})
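Note on the removed generator: it calls model.generate once per token, re-encoding the whole growing sequence on every step, and re-feeds outputs as input_ids without an attention_mask; gradio_interface also iterates the async generator with a plain for loop, which raises a TypeError at runtime. The new version below keeps the per-token loop but makes the Gradio handler synchronous. For comparison, a minimal sketch of the streamer idiom from transformers, assuming the tokenizer and model defined above (the helper name generate_stream_sse is illustrative; the sampling parameters are copied from the code):

from threading import Thread
from transformers import TextIteratorStreamer

def generate_stream_sse(prompt: str):
    # Sync generator: FastAPI's StreamingResponse iterates it in a threadpool,
    # so the blocking generate() call does not stall the event loop.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() runs to completion in a worker thread while the streamer
    # hands out decoded text chunks as they are produced.
    Thread(target=model.generate,
           kwargs=dict(**inputs, streamer=streamer, max_new_tokens=512,
                       do_sample=True, temperature=0.7, top_p=0.9)).start()
    for text in streamer:
        yield f"data: {text}\n\n"  # SSE framing expected by EventSource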
 
  from fastapi import FastAPI, Request
+ from fastapi.responses import HTMLResponse, StreamingResponse
+ import gradio as gr
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
  import asyncio

  app = FastAPI()

+ # Load the model
+ model_name = "microsoft/Phi-3.5-mini-instruct"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+
+ # Streaming generation function
+ async def generate_stream(prompt: str):
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     for _ in range(512):  # token limit
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=1,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9
+         )
+         new_token = tokenizer.decode(outputs[0][-1], skip_special_tokens=True)
+         yield f"data: {new_token}\n\n"
+         await asyncio.sleep(0.05)
+         inputs = {"input_ids": outputs}

+ # Standard Gradio interface
+ def generate_text(prompt: str):
      inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(**inputs, max_new_tokens=512)
+     return tokenizer.decode(outputs[0], skip_special_tokens=True)

+ # Streaming web page
+ @app.get("/", response_class=HTMLResponse)
+ async def web_interface(request: Request):
+     return """
+     <!DOCTYPE html>
+     <html>
+     <head>
+         <title>Chat Streaming</title>
+         <script>
+             async function startStream() {
+                 const prompt = document.getElementById("prompt").value;
+                 const output = document.getElementById("output");
+                 output.innerHTML = "";
+
+                 const eventSource = new EventSource(`/stream?prompt=${encodeURIComponent(prompt)}`);
+
+                 eventSource.onmessage = (event) => {
+                     output.innerHTML += event.data;
+                     output.scrollTop = output.scrollHeight;
+                 };
+             }
+         </script>
+     </head>
+     <body>
+         <h1>Chat en temps réel</h1>
+         <textarea id="prompt" rows="4"></textarea>
+         <button onclick="startStream()">Envoyer</button>
+         <div id="output" style="white-space: pre-wrap; margin-top: 20px;"></div>
+     </body>
+     </html>
+     """
+
+ # Streaming endpoint
  @app.get("/stream")
+ async def stream_response(prompt: str):
      return StreamingResponse(
+         generate_stream(prompt),
          media_type="text/event-stream"
      )

+ # Gradio interface (accessible at /gradio)
+ demo = gr.Interface(
+     fn=generate_text,
+     inputs="text",
+     outputs="text",
+     title="Phi-3 Chat"
  )

+ app = gr.mount_gradio_app(app, demo, path="/gradio")
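For reference, a sketch of a client for the /stream endpoint, assuming the merged app is served locally (for example with uvicorn app:app --host 0.0.0.0 --port 7860; the host and port are illustrative, 7860 being the usual Spaces port):

import requests

# Illustrative URL and prompt; adjust to wherever the app is served.
url = "http://localhost:7860/stream"
with requests.get(url, params={"prompt": "Bonjour"}, stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        # SSE frames arrive as "data: <chunk>" lines separated by blank lines
        if line and line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)

The Gradio UI mounted at /gradio uses the non-streaming generate_text handler, so it returns the whole completion in one piece, prompt included, since outputs[0] is decoded in full.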