# HuggingFace Spaces demo: FastAPI + Gradio streaming chat around Phi-3.5-mini.
import asyncio

import gradio as gr
import torch
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
app = FastAPI() | |
# Chargement du modèle | |
model_name = "microsoft/Phi-3.5-mini-instruct" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
torch_dtype=torch.float16, | |
device_map="auto" | |
) | |
# Fonction de génération avec streaming | |
async def generate_stream(prompt: str): | |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
for _ in range(512): # Limite de tokens | |
outputs = model.generate( | |
**inputs, | |
max_new_tokens=1, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.9 | |
) | |
new_token = tokenizer.decode(outputs[0][-1], skip_special_tokens=True) | |
yield f"data: {new_token}\n\n" | |
await asyncio.sleep(0.05) | |
inputs = {"input_ids": outputs} | |
# Interface Gradio standard | |
def generate_text(prompt: str): | |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
outputs = model.generate(**inputs, max_new_tokens=512) | |
return tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# Page web de streaming | |
async def web_interface(request: Request): | |
return """ | |
<!DOCTYPE html> | |
<html> | |
<head> | |
<title>Chat Streaming</title> | |
<script> | |
async function startStream() { | |
const prompt = document.getElementById("prompt").value; | |
const output = document.getElementById("output"); | |
output.innerHTML = ""; | |
const eventSource = new EventSource(`/stream?prompt=${encodeURIComponent(prompt)}`); | |
eventSource.onmessage = (event) => { | |
output.innerHTML += event.data; | |
output.scrollTop = output.scrollHeight; | |
}; | |
} | |
</script> | |
</head> | |
<body> | |
<h1>Chat en temps réel</h1> | |
<textarea id="prompt" rows="4"></textarea> | |
<button onclick="startStream()">Envoyer</button> | |
<div id="output" style="white-space: pre-wrap; margin-top: 20px;"></div> | |
</body> | |
</html> | |
""" | |
# Endpoint de streaming | |
async def stream_response(prompt: str): | |
return StreamingResponse( | |
generate_stream(prompt), | |
media_type="text/event-stream" | |
) | |
# Interface Gradio (accessible via /gradio) | |
demo = gr.Interface( | |
fn=generate_text, | |
inputs="text", | |
outputs="text", | |
title="Phi-3 Chat" | |
) | |
app = gr.mount_gradio_app(app, demo, path="/gradio") |