import os
from typing import List, Dict
import logging
import dotenv
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse
from pydantic import BaseModel
from llama_cpp import Llama
from huggingface_hub import hf_hub_download, login
import uvicorn
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    pipeline,
)
from UofTearsBot import UofTearsBot
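
# Quantized GGUF build of Mistral-7B-Instruct, served through llama-cpp-python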
MODEL_REPO = "bartowski/Mistral-7B-Instruct-v0.3-GGUF"
MODEL_FILE = "Mistral-7B-Instruct-v0.3-Q4_K_M.gguf"
CHAT_FORMAT = "mistral-instruct"
dotenv.load_dotenv()
login(token=os.getenv("HF_TOKEN"))
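
# Fetch the quantized weights at startup (stored under the Space's writable /tmp)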
MODEL_PATH = hf_hub_download(
    repo_id=MODEL_REPO,
    filename=MODEL_FILE,
    local_dir="/tmp/models",
    local_dir_use_symlinks=False,
)
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=int(os.getenv("N_CTX", "1024")),    # context window size (tokens)
    n_threads=os.cpu_count() or 4,            # use all available CPU cores
    n_batch=int(os.getenv("N_BATCH", "32")),  # prompt-processing batch size
    chat_format=CHAT_FORMAT,
)
# Create the FastAPI app
app = FastAPI()
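# In-memory map of user_id -> chat session; conversation state is lost on restart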
chatbots: Dict[str, UofTearsBot] = {}
class ChatRequest(BaseModel):
    user_id: str
    user_text: str
@app.post("/chat")
async def chat(request: ChatRequest):
try:
if request.user_id not in chatbots:
chatbots[request.user_id] = UofTearsBot(llm)
current_bot = chatbots[request.user_id]
def token_generator():
print("[INFO] Model is streaming response...", flush=True)
for token in current_bot.converse(request.user_text):
yield token
print("[INFO] Model finished streaming β
", flush=True)
return StreamingResponse(token_generator(), media_type="text/plain")
except Exception as e:
import traceback
traceback.print_exc() # logs to HF logs
return JSONResponse(
status_code=500,
content={"error": str(e)}
)
@app.get("/", response_class=HTMLResponse)
async def home():
return "<h1>App is running π</h1>"
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)  # default Hugging Face Spaces port
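
# --- Example client (sketch) --------------------------------------------------
# A minimal sketch of how another process could consume the streaming /chat
# endpoint, assuming the `requests` package is available and SPACE_URL
# (a hypothetical placeholder) points at the running app:
#
#   import requests
#
#   SPACE_URL = "http://localhost:7860"  # replace with the deployed Space URL
#   with requests.post(
#       f"{SPACE_URL}/chat",
#       json={"user_id": "demo-user", "user_text": "Hello!"},
#       stream=True,
#   ) as resp:
#       resp.raise_for_status()
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)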