import os
import traceback
from typing import Dict

import dotenv
import uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download, login
from llama_cpp import Llama
from pydantic import BaseModel

from UofTearsBot import UofTearsBot
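
# Model configuration: a 4-bit quantized Mistral-7B-Instruct build in GGUF format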
MODEL_REPO = "bartowski/Mistral-7B-Instruct-v0.3-GGUF"
MODEL_FILE = "Mistral-7B-Instruct-v0.3-Q4_K_M.gguf"
CHAT_FORMAT = "mistral-instruct"
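
# Load environment variables and authenticate with the Hugging Face Hub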
dotenv.load_dotenv()
# Skip login when no token is set; login(token=None) would attempt an interactive prompt
if os.getenv("HF_TOKEN"):
    login(token=os.getenv("HF_TOKEN"))
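
# Download the GGUF weights to local disk at startup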
MODEL_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=MODEL_FILE,
    local_dir="/tmp/models",
    local_dir_use_symlinks=False,
)
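
# Load the model with llama.cpp; the context window and batch size are tunable
# via the N_CTX and N_BATCH environment variables, and the thread count
# defaults to the available CPU cores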
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=int(os.getenv("N_CTX", "1024")),
    n_threads=os.cpu_count() or 4,
    n_batch=int(os.getenv("N_BATCH", "32")),
    chat_format=CHAT_FORMAT,
)

# Create the FastAPI app
app = FastAPI()
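# One chatbot session per user_id, so each user keeps their own conversation state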
chatbots: Dict[str, UofTearsBot] = {}

class ChatRequest(BaseModel):
    user_id: str
    user_text: str

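# POST /chat streams the model's reply back token by token as plain text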
@app.post("/chat")
async def chat(request: ChatRequest):
    try:
        # Lazily create a dedicated bot instance per user so conversations stay separate
        if request.user_id not in chatbots:
            chatbots[request.user_id] = UofTearsBot(llm)
        current_bot = chatbots[request.user_id]

        def token_generator():
            # Yield tokens to the client as the model produces them
            print("[INFO] Model is streaming response...", flush=True)
            for token in current_bot.converse(request.user_text):
                yield token
            print("[INFO] Model finished streaming", flush=True)

        return StreamingResponse(token_generator(), media_type="text/plain")
    except Exception as e:
        traceback.print_exc()  # full stack trace shows up in the Space logs
        return JSONResponse(status_code=500, content={"error": str(e)})
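
# Example request (hypothetical values; assumes a local run on the default port):
#   curl -N -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"user_id": "demo-user", "user_text": "Hello!"}'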

@app.get("/", response_class=HTMLResponse)
async def home():
    return "<h1>App is running</h1>"

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)  # Hugging Face Spaces serves on port 7860