# FastAPI + Gradio demo: streams completions from microsoft/Phi-3.5-mini-instruct
# via Server-Sent Events (/stream) and also exposes a classic Gradio UI (/gradio).
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import asyncio
# FastAPI application that hosts the HTML page, the SSE endpoint and the Gradio UI.
app = FastAPI()
# Load tokenizer and model once at import time (downloads weights on first run).
# fp16 halves memory; device_map="auto" places layers on the available GPU(s)/CPU.
model_name = "microsoft/Phi-3.5-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)
# Fonction de génération avec streaming
async def generate_stream(prompt: str):
    """Generate a completion of *prompt* one token at a time as SSE frames.

    Args:
        prompt: User text to complete.

    Yields:
        Server-Sent-Events frames of the form ``"data: <token>\\n\\n"``,
        one per generated token, with a small delay between frames so the
        client renders progressively.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    eos_id = tokenizer.eos_token_id
    for _ in range(512):  # hard cap on the number of generated tokens
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=1,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
        token_id = outputs[0, -1]
        # Stop as soon as the model emits its end-of-sequence token instead of
        # always grinding out all 512 steps (the original had no EOS check).
        if eos_id is not None and token_id.item() == eos_id:
            break
        new_token = tokenizer.decode(token_id, skip_special_tokens=True)
        yield f"data: {new_token}\n\n"
        await asyncio.sleep(0.05)
        # Feed the whole sequence (prompt + generated) back for the next step,
        # keeping the attention mask in sync — the original dropped the mask
        # by replacing the inputs dict with {"input_ids": outputs}.
        input_ids = outputs
        attention_mask = torch.ones_like(outputs)
# Interface Gradio standard
def generate_text(prompt: str) -> str:
    """Generate a full completion of *prompt* in one shot (Gradio handler).

    Args:
        prompt: User text to complete.

    Returns:
        The decoded sequence (prompt plus up to 512 new tokens).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # no_grad: inference only — without it, generate() builds an autograd
    # graph and wastes GPU memory for no benefit.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=512)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Page web de streaming
@app.get("/", response_class=HTMLResponse)
async def web_interface(request: Request):
    """Serve the minimal chat page that consumes the /stream SSE endpoint.

    The page opens an EventSource on /stream and appends each token to the
    output div. When the server finishes and closes the stream, EventSource
    would normally auto-reconnect — which would silently re-run the whole
    generation — so onerror closes the connection explicitly.
    """
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Chat Streaming</title>
        <script>
        async function startStream() {
            const prompt = document.getElementById("prompt").value;
            const output = document.getElementById("output");
            output.innerHTML = "";
            const eventSource = new EventSource(`/stream?prompt=${encodeURIComponent(prompt)}`);
            eventSource.onmessage = (event) => {
                output.innerHTML += event.data;
                output.scrollTop = output.scrollHeight;
            };
            // The server closing the stream fires onerror; close instead of
            // letting EventSource auto-reconnect and regenerate forever.
            eventSource.onerror = () => eventSource.close();
        }
        </script>
    </head>
    <body>
        <h1>Chat en temps réel</h1>
        <textarea id="prompt" rows="4"></textarea>
        <button onclick="startStream()">Envoyer</button>
        <div id="output" style="white-space: pre-wrap; margin-top: 20px;"></div>
    </body>
    </html>
    """
# Endpoint de streaming
@app.get("/stream")
async def stream_response(prompt: str):
    """SSE endpoint: stream the model's completion of *prompt* token by token."""
    token_events = generate_stream(prompt)
    return StreamingResponse(token_events, media_type="text/event-stream")
# Interface Gradio (accessible via /gradio)
# Classic (non-streaming) Gradio interface, reachable at /gradio.
demo = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="Phi-3 Chat",
)
# Mount Gradio onto the FastAPI app; rebinding `app` keeps the ASGI entry point
# valid for uvicorn. (Removed a stray trailing "|" — scrape residue that made
# this line a SyntaxError.)
app = gr.mount_gradio_app(app, demo, path="/gradio")