GueuleDange commited on
Commit
2ad44a0
·
verified ·
1 Parent(s): fbb1dd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -24
app.py CHANGED
@@ -4,32 +4,28 @@ from fastapi.templating import Jinja2Templates
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import torch
6
  import asyncio
7
- import os
8
-
9
- # Créer le dossier static s'il n'existe pas
10
- os.makedirs("static", exist_ok=True)
11
 
 
12
  app = FastAPI()
13
  templates = Jinja2Templates(directory="templates")
14
 
15
- # Configuration simplifiée pour Hugging Face Spaces
16
- model_name = "microsoft/Phi-3.5-mini-instruct"
17
-
18
  try:
19
- tokenizer = AutoTokenizer.from_pretrained(model_name)
20
  model = AutoModelForCausalLM.from_pretrained(
21
- model_name,
22
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
23
  device_map="auto"
24
  )
25
  except Exception as e:
26
  print(f"Erreur de chargement du modèle: {str(e)}")
27
  raise
28
 
29
- async def generate_response(prompt: str):
30
- try:
31
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
32
-
33
  for _ in range(512):
34
  outputs = model.generate(
35
  **inputs,
@@ -39,19 +35,35 @@ async def generate_response(prompt: str):
39
  top_p=0.9
40
  )
41
  new_token = tokenizer.decode(outputs[0][-1], skip_special_tokens=True)
42
- yield f"data: {new_token}\n\n"
43
- await asyncio.sleep(0.05)
44
  inputs = {"input_ids": outputs}
45
- except Exception as e:
46
- yield f"data: [ERREUR: {str(e)}]\n\n"
47
-
48
- @app.get("/", response_class=HTMLResponse)
49
- async def home(request: Request):
50
- return templates.TemplateResponse("index.html", {"request": request})
51
 
 
52
  @app.get("/stream")
53
  async def stream(prompt: str):
54
  return StreamingResponse(
55
- generate_response(prompt),
56
  media_type="text/event-stream"
57
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import torch
6
  import asyncio
7
+ import gradio as gr
 
 
 
8
 
9
+ # Initialisation FastAPI
10
  app = FastAPI()
11
  templates = Jinja2Templates(directory="templates")
12
 
13
+ # Chargement du modèle (avec gestion d'erreur)
 
 
14
  try:
15
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
16
  model = AutoModelForCausalLM.from_pretrained(
17
+ "microsoft/Phi-3.5-mini-instruct",
18
+ torch_dtype=torch.float16,
19
  device_map="auto"
20
  )
21
  except Exception as e:
22
  print(f"Erreur de chargement du modèle: {str(e)}")
23
  raise
24
 
25
+ # Fonction de génération commune
26
+ async def generate_tokens(prompt: str):
27
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
28
+ with torch.no_grad():
29
  for _ in range(512):
30
  outputs = model.generate(
31
  **inputs,
 
35
  top_p=0.9
36
  )
37
  new_token = tokenizer.decode(outputs[0][-1], skip_special_tokens=True)
38
+ yield new_token
 
39
  inputs = {"input_ids": outputs}
 
 
 
 
 
 
40
 
41
+ # Route FastAPI pour le site web
42
  @app.get("/stream")
43
  async def stream(prompt: str):
44
  return StreamingResponse(
45
+ generate_tokens(prompt),
46
  media_type="text/event-stream"
47
+ )
48
+
49
+ # Interface Gradio
50
+ def gradio_interface(prompt: str):
51
+ full_response = ""
52
+ for token in generate_tokens(prompt):
53
+ full_response += token
54
+ return full_response
55
+
56
+ gradio_app = gr.Interface(
57
+ fn=gradio_interface,
58
+ inputs=gr.Textbox(label="Votre message"),
59
+ outputs=gr.Textbox(label="Réponse", interactive=False),
60
+ title="Chat avec (Gradio)"
61
+ )
62
+
63
+ # Montage des deux apps
64
+ app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
65
+
66
+ # Route racine (peut rediriger vers Gradio ou votre site)
67
+ @app.get("/", response_class=HTMLResponse)
68
+ async def home(request: Request):
69
+ return templates.TemplateResponse("index.html", {"request": request})