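# Redirect Hugging Face caches to /tmp (e.g. so the app can run in environments
# where the default cache location is not writable)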
import os

os.environ["HF_HOME"] = "/tmp/hf"
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf/datasets"
os.environ["HF_METRICS_CACHE"] = "/tmp/hf/metrics"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from fastapi.responses import StreamingResponse
import torch
import threading

app = FastAPI()

# Load the Phi-2 model and tokenizer (uses the model from the Hugging Face Hub)
model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Request body schema
class ChatRequest(BaseModel):
    message: str

@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
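    # Build a short Spanish-language prompt around the user's message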
    prompt = f"""Responde en español de forma clara y breve como un asistente IA.
Usuario: {request.message}
IA:"""

    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Streamer that yields generated tokens incrementally
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Start generation in a separate thread so tokens can be streamed as they arrive
    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=48,  # Adjust this value for more/fewer tokens
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id
    )
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # StreamingResponse expects an iterator that yields text. A plain (sync)
    # generator is used here so Starlette iterates it in a threadpool and the
    # blocking streamer does not stall the event loop.
    def event_generator():
        for new_text in streamer:
            yield new_text

    return StreamingResponse(event_generator(), media_type="text/plain")
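

# Example usage (a sketch: the filename "app.py" and port are illustrative
# assumptions, and uvicorn must be installed separately):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -N -X POST http://localhost:8000/chat/stream \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hola, ¿qué puedes hacer?"}'
#
# The -N flag disables curl's output buffering so the streamed tokens appear
# as they are generated.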