import os

# Point Hugging Face downloads at a local cache directory before any HF libraries are used
os.environ["HF_HOME"] = "./cache"

from fastapi import FastAPI
from pydantic import BaseModel, Field
from typing import List
from llama_cpp import Llama


# Download (if not already cached) and load the GGUF model from the Hugging Face Hub
llm = Llama.from_pretrained(
    repo_id="bartowski/Llama-3.2-3B-Instruct-GGUF",  # Hugging Face repo hosting the GGUF weights
    filename="Llama-3.2-3B-Instruct-Q8_0.gguf",      # Quantized weights file; swap for another quantization if needed
    n_ctx=4096,           # Context window size in tokens
    cache_dir="./cache",  # Local directory for downloaded model files
    n_threads=2,          # CPU threads used for inference
)
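
# Optional sanity check (a sketch, not required for serving): call the model once
# at startup to confirm the weights loaded and a completion can be produced.
#
#   print(llm.create_chat_completion(
#       messages=[{"role": "user", "content": "Say hello."}],
#       max_tokens=16,
#   )["choices"][0]["message"]["content"])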

# Define Pydantic models for request validation
class Message(BaseModel):
    role: str     # "system", "user", or "assistant"
    content: str  # The text of the message

class Validation(BaseModel):
    messages: List[Message] = Field(default_factory=list)  # Conversation history, oldest message first
    max_tokens: int = 1024     # Maximum number of tokens to generate in the response
    temperature: float = 0.01  # Sampling temperature; values near 0 give nearly deterministic output
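
# Example request body matching the Validation schema (values are illustrative):
#
#   {
#       "messages": [
#           {"role": "user", "content": "What is the capital of France?"}
#       ],
#       "max_tokens": 256,
#       "temperature": 0.01
#   }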

# Initialize the FastAPI application
app = FastAPI()

# Define the endpoint for generating responses
@app.post("/generate_response")
def generate_response(item: Validation):
    # A plain (non-async) handler lets FastAPI run the blocking llama.cpp call
    # in a worker thread instead of stalling the event loop.
    response = llm.create_chat_completion(
        messages=[{"role": msg.role, "content": msg.content} for msg in item.messages],
        max_tokens=item.max_tokens,
        temperature=item.temperature,
    )

    # Extract and return the generated message text
    return {"response": response["choices"][0]["message"]["content"]}