from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging
import gradio as gr
import uvicorn
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()

# Allow cross-origin requests from any host (typical for a public demo Space).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
MODEL_ID = "tugstugi/Qwen2.5-Coder-0.5B-QwQ-draft"

# Load tokenizer and model once at startup; prefer GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
model.eval()  # inference only; disables dropout etc.
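# Optional alternative (an assumption, not part of this app's config): when a
# GPU is present, loading the weights in half precision roughly halves memory use.
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID, torch_dtype=torch.float16
# ).to(device)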
class ChatMessage(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: list[ChatMessage]

class ChatResponse(BaseModel):
    response: str
    status: str = "success"
def build_prompt(messages):
    """Build a ChatML-style prompt from a list of {"role", "content"} dicts."""
    prompt = ""
    for message in messages:
        if message["role"] == "system":
            prompt += f"<|im_start|>system\n{message['content']}<|im_end|>\n"
        elif message["role"] == "user":
            prompt += f"<|im_start|>user\n{message['content']}<|im_end|>\n"
        elif message["role"] == "assistant":
            prompt += f"<|im_start|>assistant\n{message['content']}<|im_end|>\n"
    # Open an assistant turn so the model continues as the assistant.
    prompt += "<|im_start|>assistant\n"
    return prompt
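# Illustration: for [{"role": "user", "content": "hi"}], build_prompt returns
#   "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n"
# i.e. the ChatML format Qwen models are trained on, ending mid-assistant-turn
# so that generation produces the assistant's reply.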
def generate_response(conversation_history, max_new_tokens=1500):
    prompt_text = build_prompt(conversation_history)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(device)
    # Sample a completion; no_grad avoids building an autograd graph at inference.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.8,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated_text = tokenizer.decode(
        generated_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
    )
    return generated_text.strip()
# REST endpoint for multi-turn chat (the route path here is illustrative).
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    try:
        conversation = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        response_text = generate_response(conversation)
        return ChatResponse(response=response_text)
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Simple liveness probe.
@app.get("/health")
async def health_check():
    return {"status": "healthy"}
# Gradio setup: a minimal single-turn demo UI mounted at the root path.
iface = gr.Interface(
    fn=lambda text: generate_response([{"role": "user", "content": text}]),
    inputs="text",
    outputs="text",
)
app = gr.mount_gradio_app(app, iface, path="/")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
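# Example client call, assuming the server runs locally on port 7860 and the
# /chat route defined above (both illustrative defaults, not confirmed by the
# original file):
#
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Write a hello-world in C."}]}'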