# app.py — FastAPI + Gradio demo serving microsoft/Phi-3.5-mini-instruct.
# NOTE(review): the original paste included Hugging Face Space page chrome
# ("Test / app.py", author line, commit 0eadccb, raw/history/blame links,
# "2.77 kB") which is not Python — converted to this comment header.
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import asyncio
app = FastAPI()
# Chargement du modèle
model_name = "microsoft/Phi-3.5-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto"
)
# Fonction de génération avec streaming
async def generate_stream(prompt: str):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
for _ in range(512): # Limite de tokens
outputs = model.generate(
**inputs,
max_new_tokens=1,
do_sample=True,
temperature=0.7,
top_p=0.9
)
new_token = tokenizer.decode(outputs[0][-1], skip_special_tokens=True)
yield f"data: {new_token}\n\n"
await asyncio.sleep(0.05)
inputs = {"input_ids": outputs}
# Interface Gradio standard
def generate_text(prompt: str):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Page web de streaming
@app.get("/", response_class=HTMLResponse)
async def web_interface(request: Request):
return """
<!DOCTYPE html>
<html>
<head>
<title>Chat Streaming</title>
<script>
async function startStream() {
const prompt = document.getElementById("prompt").value;
const output = document.getElementById("output");
output.innerHTML = "";
const eventSource = new EventSource(`/stream?prompt=${encodeURIComponent(prompt)}`);
eventSource.onmessage = (event) => {
output.innerHTML += event.data;
output.scrollTop = output.scrollHeight;
};
}
</script>
</head>
<body>
<h1>Chat en temps réel</h1>
<textarea id="prompt" rows="4"></textarea>
<button onclick="startStream()">Envoyer</button>
<div id="output" style="white-space: pre-wrap; margin-top: 20px;"></div>
</body>
</html>
"""
# Endpoint de streaming
@app.get("/stream")
async def stream_response(prompt: str):
return StreamingResponse(
generate_stream(prompt),
media_type="text/event-stream"
)
# Interface Gradio (accessible via /gradio)
demo = gr.Interface(
fn=generate_text,
inputs="text",
outputs="text",
title="Phi-3 Chat"
)
app = gr.mount_gradio_app(app, demo, path="/gradio")