# NOTE(review): removed paste/merge residue that preceded the code
# (a "File size" banner, git commit hashes, and stray line numbers 1-90).
# It was not Python and made the file unparseable.
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import asyncio

app = FastAPI()

# Model loading — runs once at import time (downloads weights on first run,
# which can take a while and significant disk space).
model_name = "microsoft/Phi-3.5-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # half precision: ~2x less memory than float32
    device_map="auto"  # NOTE(review): presumably lets accelerate place layers on available devices — confirm accelerate is installed
)

# Token-by-token generation for the SSE endpoint.
async def generate_stream(prompt: str):
    """Asynchronously yield generated tokens as server-sent events.

    Each yielded string is one SSE frame (``data: <token>\\n\\n``). Fixes over
    the original loop:
    - keeps the attention mask across iterations (the original rebuilt the
      inputs as ``{"input_ids": outputs}``, silently dropping the mask);
    - stops at the model's EOS token instead of always emitting 512 tokens;
    - escapes embedded newlines so a multi-line token cannot break SSE framing;
    - runs the blocking ``model.generate`` in a worker thread so the event
      loop is not stalled between tokens;
    - wraps generation in ``torch.inference_mode()``.

    Note: this re-runs ``generate`` per token (quadratic in output length) —
    kept for simplicity; a ``TextIteratorStreamer`` would be the faster route.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    eos_id = tokenizer.eos_token_id
    loop = asyncio.get_running_loop()

    for _ in range(512):  # hard cap on generated tokens
        def _step(ids=input_ids, mask=attention_mask):
            with torch.inference_mode():
                return model.generate(
                    input_ids=ids,
                    attention_mask=mask,
                    max_new_tokens=1,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=eos_id,  # silences the pad-token warning
                )

        # model.generate is CPU/GPU-blocking; keep the event loop responsive.
        outputs = await loop.run_in_executor(None, _step)
        token_id = outputs[0, -1].item()
        if eos_id is not None and token_id == eos_id:
            break
        token_text = tokenizer.decode([token_id], skip_special_tokens=True)
        # SSE frames are newline-delimited: escape newlines inside the token.
        yield "data: " + token_text.replace("\n", "\\n") + "\n\n"
        await asyncio.sleep(0.05)
        # Feed the full sequence back in, extending the mask to match.
        input_ids = outputs
        attention_mask = torch.ones_like(outputs)

# Plain (non-streaming) generation used by the Gradio interface.
def generate_text(prompt: str) -> str:
    """Generate a completion for *prompt* and return only the new text.

    Fixes over the original: the prompt is no longer echoed back (the
    original decoded the full sequence, prompt included), and generation
    runs under ``torch.inference_mode()`` to skip autograd bookkeeping.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=512)
    # Slice off the prompt tokens so only the continuation is returned.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

# Streaming web page (plain HTML + EventSource client).
@app.get("/", response_class=HTMLResponse)
async def web_interface(request: Request):
    """Serve the minimal chat page that consumes the ``/stream`` SSE endpoint.

    Fixes over the original page:
    - the ``EventSource`` is closed in ``onerror``; without this the browser
      auto-reconnects when the server ends the stream, restarting generation
      in an endless loop;
    - tokens are appended via ``textContent`` instead of ``innerHTML``, so
      model output cannot inject markup into the page.
    """
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Chat Streaming</title>
        <script>
            async function startStream() {
                const prompt = document.getElementById("prompt").value;
                const output = document.getElementById("output");
                output.textContent = "";

                const eventSource = new EventSource(`/stream?prompt=${encodeURIComponent(prompt)}`);

                eventSource.onmessage = (event) => {
                    // Append as plain text so generated markup stays inert.
                    output.textContent += event.data;
                    output.scrollTop = output.scrollHeight;
                };
                // The stream ending fires onerror; close to stop the browser
                // from reconnecting and re-running generation forever.
                eventSource.onerror = () => eventSource.close();
            }
        </script>
    </head>
    <body>
        <h1>Chat en temps réel</h1>
        <textarea id="prompt" rows="4"></textarea>
        <button onclick="startStream()">Envoyer</button>
        <div id="output" style="white-space: pre-wrap; margin-top: 20px;"></div>
    </body>
    </html>
    """

# SSE streaming endpoint consumed by the page served at "/".
@app.get("/stream")
async def stream_response(prompt: str):
    """Stream generated tokens for *prompt* as server-sent events."""
    token_events = generate_stream(prompt)
    return StreamingResponse(token_events, media_type="text/event-stream")

# Gradio interface (reachable at /gradio) — wraps the blocking generator.
demo = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="Phi-3 Chat"
)

# Mount Gradio onto the FastAPI app; the "/" and "/stream" routes above
# keep working alongside it.
app = gr.mount_gradio_app(app, demo, path="/gradio")