# FastAPI service that streams chat completions from a quantized
# Mistral-7B-Instruct model running on llama.cpp.
import os
from typing import Dict

import dotenv
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import hf_hub_download, login

from UofTearsBot import UofTearsBot

# GGUF build of Mistral-7B-Instruct-v0.3, served on CPU through llama.cpp.
MODEL_REPO = "bartowski/Mistral-7B-Instruct-v0.3-GGUF"
MODEL_FILE = "Mistral-7B-Instruct-v0.3-Q4_K_M.gguf"
CHAT_FORMAT = "mistral-instruct"

# Load HF_TOKEN (and any other secrets) from a local .env file if present.
dotenv.load_dotenv()
# Only authenticate when a token is configured; the GGUF repo above is public.
if os.getenv("HF_TOKEN"):
    login(token=os.getenv("HF_TOKEN"))

# Download the quantized weights once at startup; cached under /tmp/models.
MODEL_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=MODEL_FILE,
    local_dir="/tmp/models",
    local_dir_use_symlinks=False,
)

# CPU-only llama.cpp instance; context size and batch size are tunable
# through environment variables.
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=int(os.getenv("N_CTX", "1024")),    # prompt + response context window
    n_threads=os.cpu_count() or 4,            # use all available CPU cores
    n_batch=int(os.getenv("N_BATCH", "32")),  # tokens evaluated per batch
    chat_format=CHAT_FORMAT,
)

# FastAPI app plus a per-user cache of chatbot sessions, keyed by user_id.
app = FastAPI()
chatbots: Dict[str, UofTearsBot] = {}

class ChatRequest(BaseModel):
    """Request body for POST /chat."""
    user_id: str
    user_text: str


@app.post("/chat")
async def chat(request: ChatRequest):
    try:
        if request.user_id not in chatbots:
            chatbots[request.user_id] = UofTearsBot(llm)
        current_bot = chatbots[request.user_id]

        def token_generator():
            print("[INFO] Model is streaming response...", flush=True)
            for token in current_bot.converse(request.user_text):
                yield token
            print("[INFO] Model finished streaming βœ…", flush=True)

        return StreamingResponse(token_generator(), media_type="text/plain")

    except Exception as e:
        import traceback
        traceback.print_exc()  # logs to HF logs
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )


@app.get("/", response_class=HTMLResponse)
async def home():
    return "<h1>App is running πŸš€</h1>"


if __name__ == "__main__":
    # 7860 is the default port exposed by Hugging Face Spaces.
    uvicorn.run(app, host="0.0.0.0", port=7860)
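
# ---------------------------------------------------------------------------
# Example client (illustrative sketch only, not executed by the app).
# Assumes the server is running locally on port 7860; the user_id below is a
# placeholder. It streams the plain-text tokens returned by POST /chat using
# the `requests` library.
#
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/chat",
#       json={"user_id": "demo-user", "user_text": "Hello!"},
#       stream=True,
#   ) as resp:
#       for chunk in resp.iter_content(chunk_size=None):
#           print(chunk.decode("utf-8", errors="replace"), end="", flush=True)
# ---------------------------------------------------------------------------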