import asyncio
import os
import threading

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Read environment variables from a local .env file, if present (python-dotenv),
# before the model and the application are set up below.
load_dotenv()

# Load the tokenizer and the causal LM once, at import time
print("Cargando modelo y tokenizer...")
model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Choose the loading options for the detected device:
    # BF16 weights with device_map="auto" on GPU, FP32 weights mapped to the CPU otherwise.
    if device != "cuda":
        print("Using CPU for the model...")
        _load_options = {
            "device_map": {"": device},
            "torch_dtype": torch.float32,
        }
    else:
        print("Using GPU for the model...")
        _load_options = {
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
            "low_cpu_mem_usage": True,
        }

    model = AutoModelForCausalLM.from_pretrained(model_name, **_load_options)

    print(f"Model loaded successfully on: {device}")
except Exception as e:
    # Report the failure, then let it propagate so the module fails to import
    print(f"Error loading the model: {str(e)}")
    raise

# Define the function that calls the model
def call_model(state: MessagesState):
    """
    Generate the assistant's next reply for the conversation held in ``state``.

    Args:
        state: MessagesState whose "messages" entry holds the conversation so
            far. Only HumanMessage and AIMessage entries are forwarded to the
            model; any other message type is skipped.

    Returns:
        dict: ``{"messages": [...]}`` containing the incoming messages followed
        by a new AIMessage whose content is the generated reply.
    """
    # Convert LangChain messages to chat format, starting with the fixed system prompt
    messages = [
        {"role": "system", "content": "You are a friendly Chatbot. Always reply in the language in which the user is writing to you."}
    ]

    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            messages.append({"role": "assistant", "content": msg.content})

    # Render the prompt with the chat template. add_generation_prompt=True
    # appends the header that opens the assistant turn, so the model writes the
    # assistant's reply instead of continuing the user's last message.
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Generate response
    outputs = model.generate(
        inputs,
        max_new_tokens=512,  # Increase the number of tokens for longer responses
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

    # The generated sequence starts with the prompt tokens. Decode only what
    # comes after them, rather than searching the decoded text for a marker:
    # the chat template does not emit a literal "Assistant:" string, so a
    # text split would return the whole prompt along with the reply.
    prompt_length = inputs.shape[-1]
    generated_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # Convert the response to LangChain format
    ai_message = AIMessage(content=response)
    return {"messages": state["messages"] + [ai_message]}

# Build the conversation graph on top of MessagesState
workflow = StateGraph(MessagesState)

# Register the single "model" node and route the graph's entry point to it
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Compile with a MemorySaver checkpointer; the /generate endpoint supplies a
# thread_id in the invoke config
memory = MemorySaver()
graph_app = workflow.compile(checkpointer=memory)

# Define the data model for the request
# (request body accepted by the /generate endpoint)
class QueryRequest(BaseModel):
    # User message; the endpoint wraps it in a HumanMessage before invoking the graph
    query: str
    # Conversation key; the endpoint forwards it as configurable.thread_id in the graph config
    thread_id: str = "default"

# Instantiate the FastAPI application that the routes below are registered on
app = FastAPI(
    title="LangChain FastAPI",
    description="API to generate text using LangChain and LangGraph",
)

# Root route: static greeting
@app.get("/")
async def api_home():
    """Return the fixed welcome payload served at the API root."""
    greeting = "Welcome to FastAPI, Langchain, Docker tutorial"
    return {"detail": greeting}

# Model generation is serialized behind this lock: only one graph invocation
# runs at a time, as was the case when the handler blocked the event loop.
_generation_lock = threading.Lock()


def _invoke_graph(payload, config):
    """Run the compiled graph synchronously while holding the generation lock."""
    with _generation_lock:
        return graph_app.invoke(payload, config)


# Generate endpoint
@app.post("/generate")
async def generate(request: QueryRequest):
    """
    Endpoint to generate text using the language model

    Args:
        request: QueryRequest carrying ``query`` (the user's message) and
            ``thread_id`` (forwarded to the graph as ``configurable.thread_id``;
            defaults to "default").

    Returns:
        dict: A dictionary with "generated_text" (the content of the last
        message in the graph output) and "thread_id" (echoed from the request).

    Raises:
        HTTPException: With status 500 if building the input or invoking the
            graph raises; the original exception is attached as the cause.
    """
    try:
        # Configure the thread ID
        config = {"configurable": {"thread_id": request.thread_id}}

        # Create the input message
        input_messages = [HumanMessage(content=request.query)]

        # graph_app.invoke is synchronous and runs the model. Calling it
        # directly inside this coroutine would block the event loop for the
        # whole generation, so it is executed in a worker thread instead.
        output = await asyncio.to_thread(
            _invoke_graph, {"messages": input_messages}, config
        )

        # Get the model response
        response = output["messages"][-1].content

        return {
            "generated_text": response,
            "thread_id": request.thread_id
        }
    except Exception as e:
        # Chain the cause so the original traceback is preserved in server logs
        raise HTTPException(status_code=500, detail=f"Error al generar texto: {str(e)}") from e

if __name__ == "__main__":
    import uvicorn

    # Serve the app directly when this file is executed as a script
    _listen = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **_listen)