GueuleDange committed on
Commit 3796eb5 · verified · 1 Parent(s): e41d8d9

Update app.py

Files changed (1)
  1. app.py +35 -60
app.py CHANGED
@@ -1,26 +1,28 @@
- from fastapi import FastAPI, Request
- from fastapi.responses import HTMLResponse, StreamingResponse
- import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
  import asyncio

  app = FastAPI()

- # Load the model
  model_name = "microsoft/Phi-3.5-mini-instruct"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForCausalLM.from_pretrained(
      model_name,
      torch_dtype=torch.float16,
-     device_map="auto"
  )

- # Streaming generation function
- async def generate_stream(prompt: str):
      inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

-     for _ in range(512):  # Token limit
          outputs = model.generate(
              **inputs,
              max_new_tokens=1,
@@ -30,61 +32,34 @@ async def generate_stream(prompt: str):
          )
          new_token = tokenizer.decode(outputs[0][-1], skip_special_tokens=True)
          yield f"data: {new_token}\n\n"
-         await asyncio.sleep(0.05)

          inputs = {"input_ids": outputs}

- # Standard Gradio interface
- def generate_text(prompt: str):
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(**inputs, max_new_tokens=512)
-     return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
- # Streaming web page
  @app.get("/", response_class=HTMLResponse)
- async def web_interface(request: Request):
-     return """
-     <!DOCTYPE html>
-     <html>
-     <head>
-         <title>Chat Streaming</title>
-         <script>
-             async function startStream() {
-                 const prompt = document.getElementById("prompt").value;
-                 const output = document.getElementById("output");
-                 output.innerHTML = "";
-
-                 const eventSource = new EventSource(`/stream?prompt=${encodeURIComponent(prompt)}`);
-
-                 eventSource.onmessage = (event) => {
-                     output.innerHTML += event.data;
-                     output.scrollTop = output.scrollHeight;
-                 };
-             }
-         </script>
-     </head>
-     <body>
-         <h1>Real-time chat</h1>
-         <textarea id="prompt" rows="4"></textarea>
-         <button onclick="startStream()">Send</button>
-         <div id="output" style="white-space: pre-wrap; margin-top: 20px;"></div>
-     </body>
-     </html>
-     """

- # Streaming endpoint
  @app.get("/stream")
  async def stream_response(prompt: str):
-     return StreamingResponse(
-         generate_stream(prompt),
-         media_type="text/event-stream"
-     )
-
- # Gradio interface (available at /gradio)
- demo = gr.Interface(
-     fn=generate_text,
-     inputs="text",
-     outputs="text",
-     title="Phi-3 Chat"
- )

- app = gr.mount_gradio_app(app, demo, path="/gradio")

+ from fastapi import FastAPI, Request, HTTPException
+ from fastapi.responses import StreamingResponse, HTMLResponse
+ from fastapi.templating import Jinja2Templates
  import torch
  import asyncio
+ from transformers import AutoTokenizer, AutoModelForCausalLM

  app = FastAPI()
+ templates = Jinja2Templates(directory="templates")

+ # Model configuration (optimized for 2000 tokens)
  model_name = "microsoft/Phi-3.5-mini-instruct"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForCausalLM.from_pretrained(
      model_name,
      torch_dtype=torch.float16,
+     device_map="auto",
+     low_cpu_mem_usage=True  # Critical for long contexts
  )
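The new loading options assume a CUDA device is available (float16 weights, plus the torch.cuda.empty_cache() calls further down). As a hedged sketch only, not part of this commit, a load that also works on a CPU-only Space might guard the dtype and device map; the float32 fallback is an assumption:

# Sketch (not in this commit): fall back to float32 when no GPU is present.
# Uses the same imports and model_name as above.
use_gpu = torch.cuda.is_available()
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if use_gpu else torch.float32,
    device_map="auto" if use_gpu else None,
    low_cpu_mem_usage=True,
)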
 
+ async def generate_stream(prompt: str, max_tokens: int = 2000):
      inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     generated_count = 0

+     while generated_count < max_tokens:
          outputs = model.generate(
              **inputs,
              max_new_tokens=1,
              # ... (lines 29-31 unchanged, not shown in the diff)
          )
          new_token = tokenizer.decode(outputs[0][-1], skip_special_tokens=True)
          yield f"data: {new_token}\n\n"
+         generated_count += 1
+
+         # Memory optimization
+         if generated_count % 50 == 0:
+             await asyncio.sleep(0.01)  # Eases pressure on the GPU
+             torch.cuda.empty_cache()  # Memory cleanup
+
          inputs = {"input_ids": outputs}
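As a side note on this loop (not a change to the commit): each iteration calls model.generate() on the full sequence to produce a single token, and inputs = {"input_ids": outputs} drops the attention mask, so generation cost grows quadratically with response length. A common alternative, sketched here on the assumption it would replace generate_stream entirely, is transformers' TextIteratorStreamer, which streams text from a single generate() call running in a background thread; Starlette's StreamingResponse also accepts a plain synchronous generator like this one:

# Sketch (not in this commit): one generate() call, streamed via TextIteratorStreamer.
from threading import Thread
from transformers import TextIteratorStreamer

def generate_stream(prompt: str, max_tokens: int = 2000):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # The blocking generate() call runs in a worker thread and feeds the streamer queue.
    Thread(
        target=model.generate,
        kwargs=dict(**inputs, max_new_tokens=max_tokens, streamer=streamer),
    ).start()
    for text in streamer:
        yield f"data: {text}\n\n"  # same SSE frame format as above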
  @app.get("/", response_class=HTMLResponse)
+ async def chat_page(request: Request):
+     return templates.TemplateResponse("stream.html", {"request": request})
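The templates/stream.html file referenced here is not included in this commit. A minimal sketch of what it could contain, adapted from the inline page that the previous version returned, might be:

<!DOCTYPE html>
<html>
<head>
  <title>Chat Streaming</title>
  <script>
    function startStream() {
      const prompt = document.getElementById("prompt").value;
      const output = document.getElementById("output");
      output.innerHTML = "";
      // Consume the SSE endpoint defined below in app.py.
      const eventSource = new EventSource(`/stream?prompt=${encodeURIComponent(prompt)}`);
      eventSource.onmessage = (event) => {
        output.innerHTML += event.data;
        output.scrollTop = output.scrollHeight;
      };
    }
  </script>
</head>
<body>
  <h1>Real-time chat</h1>
  <textarea id="prompt" rows="4"></textarea>
  <button onclick="startStream()">Send</button>
  <div id="output" style="white-space: pre-wrap; margin-top: 20px;"></div>
</body>
</html>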
  @app.get("/stream")
  async def stream_response(prompt: str):
+     try:
+         return StreamingResponse(
+             generate_stream(prompt),
+             media_type="text/event-stream",
+             headers={
+                 "Cache-Control": "no-cache",
+                 "Connection": "keep-alive",
+                 "X-Accel-Buffering": "no"  # Critical for long streams
+             }
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
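One caveat, noted here rather than changed in the commit: the try/except only covers constructing the StreamingResponse. Once streaming has started the status code is already 200, so exceptions raised inside generate_stream are not turned into an HTTPException; they would have to be handled inside the generator, roughly like this sketch:

# Sketch (not in this commit): report generator failures in-band as an SSE event.
async def generate_stream(prompt: str, max_tokens: int = 2000):
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # ... token-by-token generation loop from the commit, yielding "data: ..." frames ...
    except Exception as e:
        # The response status is already 200 once streaming begins; surface the error in-band.
        yield f"event: error\ndata: {e}\n\n"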
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
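Once the app is running on port 7860, the stream can be spot-checked with a small client; this is only a usage sketch and assumes the requests package is installed:

# Sketch: read the SSE stream frame by frame.
import requests

with requests.get(
    "http://localhost:7860/stream",
    params={"prompt": "Hello"},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)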