File size: 1,607 Bytes
732e177
 
 
 
b57108e
071ac3e
732e177
c660b8d
 
b57108e
 
c660b8d
 
 
732e177
 
 
 
b57108e
 
 
 
 
071ac3e
f701bc1
071ac3e
c660b8d
732e177
 
 
 
 
 
 
c660b8d
732e177
b57108e
 
 
 
071ac3e
b57108e
 
12eddd2
 
071ac3e
b57108e
 
 
12eddd2
b57108e
 
12eddd2
071ac3e
13993db
c660b8d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import logging
import threading

from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from transformers import pipeline

from .config import settings

# Application instance; title/description/version are the metadata FastAPI
# publishes in its generated OpenAPI schema / docs pages.
app = FastAPI(
    title="DeepSeek Chat",
    description="A chat API using DeepSeek model",
    version="1.0.0"
)

# Mount static files and templates.
# Both directories are relative paths, resolved against the process's current
# working directory — presumably the server is launched from the directory
# that contains app/ (TODO confirm for every deployment target).
app.mount("/static", StaticFiles(directory="app/static"), name="static")
templates = Jinja2Templates(directory="app/templates")

# Initialize pipeline.
# This runs at import time: importing this module loads the model (and may
# download it if it is not cached) before the app can serve any request.
# A single shared `pipe` instance is used by the /chat endpoint.
# NOTE(review): trust_remote_code=True allows executing Python code shipped
# with the model repository; only acceptable if settings.MODEL_NAME points at
# a trusted repo — confirm.
print("Loading model pipeline...")
pipe = pipeline(
    "text-generation",
    model=settings.MODEL_NAME,
    token=settings.HUGGINGFACE_TOKEN,
    trust_remote_code=True
)

class ChatMessage(BaseModel):
    """Request body for the POST /chat endpoint: one user message."""

    message: str  # the user's text; sent to the model as the "user" turn

@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the chat UI page.

    Renders the ``chat.html`` template. The only context value passed is the
    incoming request, stored under the ``"request"`` key (the key Starlette's
    template responses look for).
    """
    context = {"request": request}
    page = templates.TemplateResponse("chat.html", context)
    return page

logger = logging.getLogger(__name__)

# Serializes use of the single shared pipeline. Inference runs in worker
# threads (see chat()), and one pipeline/tokenizer instance is not guaranteed
# to be safe for concurrent calls, so generations still happen one at a time,
# as they effectively did when pipe() ran directly on the event loop.
_PIPE_LOCK = threading.Lock()


def _generate(messages):
    """Run the shared text-generation pipeline on ``messages`` under the lock."""
    with _PIPE_LOCK:
        return pipe(messages)


def _extract_reply(response):
    """Return the reply text from a text-generation pipeline result.

    Accepts either a list of per-input results (only the first is used, since
    a single conversation is submitted) or a single result dict.

    Returns:
        The reply string, or ``None`` when the result has no recognizable
        reply (empty list, missing/empty ``generated_text``, unexpected types).
    """
    if isinstance(response, list):
        if not response:
            return None
        response = response[0]
    if not isinstance(response, dict):
        return None

    generated = response.get("generated_text")

    # Chat-format input: generated_text is the conversation as a list of
    # messages, with the newly generated turn last.
    if isinstance(generated, list):
        if not generated:
            return None
        last = generated[-1]
        if isinstance(last, dict) and isinstance(last.get("content"), str):
            return last["content"]
        return None

    # Plain-prompt input: generated_text is already a string.
    if isinstance(generated, str):
        return generated
    return None


@app.post("/chat")
async def chat(message: ChatMessage):
    """Generate a model reply to a single user message.

    Args:
        message: Request body carrying the user's text.

    Returns:
        ``{"response": <reply text>}``.

    Raises:
        HTTPException: status 500 if the pipeline output contains no
            recognizable reply.
    """
    # Single-turn conversation in chat-message format.
    messages = [{"role": "user", "content": message.message}]

    # pipe() is synchronous and slow. Calling it directly in this coroutine
    # would block the event loop, and with it every other request, for the
    # whole generation, so it runs in the threadpool instead.
    response = await run_in_threadpool(_generate, messages)
    logger.debug("Raw pipeline output: %r", response)

    reply = _extract_reply(response)
    if reply is None:
        logger.error("Unrecognized pipeline output: %r", response)
        raise HTTPException(
            status_code=500, detail="Model returned no usable response"
        )

    return {"response": reply}

if __name__ == "__main__":
    # Direct-execution entry point: serve `app` with uvicorn on every
    # interface, port 7860. uvicorn is imported here so it is only needed
    # when the module is run as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)