from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi import FastAPI, HTTPException
import logging

from llm_backend import chat_with_model, stream_with_model
from schema import ChatRequest

"""
uvicorn api:app --reload --port 5723
fastapi dev api.py --port 5723
"""

app = FastAPI()

logger = logging.getLogger("uvicorn.error")


@app.get("/")
def index():
    logger.info("this is a debug message")
    return {"Hello": "world"}


@app.post("/chat_stream")
def chat_stream(request: ChatRequest):
    kwargs = {
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "stream": True,
        "top_p": request.top_p,
        "min_p": request.min_p,
        "typical_p": request.typical_p,
        "frequency_penalty": request.frequency_penalty,
        "presence_penalty": request.presence_penalty,
        "repeat_penalty": request.repeat_penalty,
        "top_k": request.top_k,
        "seed": request.seed,
        "tfs_z": request.tfs_z,
        "mirostat_mode": request.mirostat_mode,
        "mirostat_tau": request.mirostat_tau,
        "mirostat_eta": request.mirostat_eta,
    }
    try:
        token_generator = stream_with_model(request.chat_history, request.model, kwargs)
        return StreamingResponse(token_generator, media_type="text/plain")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/chat")
def chat(request: ChatRequest):
    kwargs = {
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "stream": False,
        "top_p": request.top_p,
        "min_p": request.min_p,
        "typical_p": request.typical_p,
        "frequency_penalty": request.frequency_penalty,
        "presence_penalty": request.presence_penalty,
        "repeat_penalty": request.repeat_penalty,
        "top_k": request.top_k,
        "seed": request.seed,
        "tfs_z": request.tfs_z,
        "mirostat_mode": request.mirostat_mode,
        "mirostat_tau": request.mirostat_tau,
        "mirostat_eta": request.mirostat_eta,
    }
    try:
        output = chat_with_model(request.chat_history, request.model, kwargs)
        return {"response": output}
        # return HTMLResponse(output, media_type="text/plain")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))