# st-chat / app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import logging
import gradio as gr
import uvicorn
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the tokenizer and model once at startup.
MODEL_ID = "tugstugi/Qwen2.5-Coder-0.5B-QwQ-draft"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
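# Optional: on GPU hardware, loading in half precision roughly halves memory use.
# A hedged variant (assuming this checkpoint is stable in fp16):
#     model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to(device)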
class ChatMessage(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: list[ChatMessage]

class ChatResponse(BaseModel):
    response: str
    status: str = "success"
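# Illustrative request/response shapes for the /api/chat endpoint defined below:
#   request:  {"messages": [{"role": "user", "content": "Hello"}]}
#   response: {"response": "...", "status": "success"}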
def build_prompt(messages):
    # Assemble a ChatML-style prompt from the conversation history.
    # Note: roles other than "user" and "assistant" (e.g. "system") are skipped.
    prompt = ""
    for message in messages:
        if message["role"] == "user":
            prompt += f"<|im_start|>user\n{message['content']}<|im_end|>\n"
        elif message["role"] == "assistant":
            prompt += f"<|im_start|>assistant\n{message['content']}<|im_end|>\n"
    # Open an assistant turn so the model generates the reply.
    prompt += "<|im_start|>assistant\n"
    return prompt
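# Note: recent transformers tokenizers can build this ChatML prompt from the model's
# own chat template; a minimal sketch (assuming this checkpoint ships a chat_template):
#     prompt_text = tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )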
def generate_response(conversation_history, max_new_tokens=1500):
    prompt_text = build_prompt(conversation_history)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(device)
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.8,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, skipping the prompt.
    generated_text = tokenizer.decode(
        generated_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
    )
    return generated_text.strip()
@app.post("/api/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    try:
        conversation = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        response_text = generate_response(conversation)
        return ChatResponse(response=response_text)
    except Exception as e:
        logger.error(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
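# Example call (illustrative) once the server is running on port 7860:
#     curl -X POST http://localhost:7860/api/chat \
#          -H "Content-Type: application/json" \
#          -d '{"messages": [{"role": "user", "content": "Write a haiku about code."}]}'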
@app.get("/api/health")
async def health_check():
    return {"status": "healthy"}
# Gradio UI for quick manual testing, mounted at the root path.
iface = gr.Interface(
    fn=lambda user_message: generate_response([{"role": "user", "content": user_message}]),
    inputs="text",
    outputs="text",
)
app = gr.mount_gradio_app(app, iface, path="/")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)