"""FastAPI service exposing a LangGraph chat workflow backed by SmolLM2-1.7B-Instruct.

Conversation history is kept per ``thread_id`` by an in-memory LangGraph
checkpointer, so each ``/generate`` call continues the conversation of its thread.
"""
import os
import threading

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

load_dotenv()

# Initialize the model and tokenizer (runs once, at import time).
print("Cargando modelo y tokenizer...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if device == "cuda":
        print("Using GPU for the model...")
        # BF16 on GPU for lower memory use; device_map="auto" lets transformers
        # decide where the weights are placed.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
    else:
        print("Using CPU for the model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map={"": device},
            torch_dtype=torch.float32,
        )
    print(f"Model loaded successfully on: {device}")
except Exception as e:
    print(f"Error loading the model: {str(e)}")
    raise

# The /generate endpoint runs in FastAPI's threadpool (it is a plain ``def``), so
# several requests may reach call_model concurrently. Serialize generation so only
# one request uses the model at a time.
_generation_lock = threading.Lock()

SYSTEM_PROMPT = (
    "You are a friendly Chatbot. "
    "Always reply in the language in which the user is writing to you."
)


def call_model(state: MessagesState):
    """Run the chat model over the conversation history and produce a reply.

    Args:
        state: LangGraph ``MessagesState``; its ``"messages"`` list holds the
            conversation so far, including the newest user message.

    Returns:
        dict: ``{"messages": [AIMessage]}``, a partial state update holding only
        the new assistant reply. ``MessagesState``'s reducer appends it to the
        stored history.
    """
    # Convert LangChain messages to the chat-template format. Message types other
    # than HumanMessage / AIMessage are skipped.
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            messages.append({"role": "assistant", "content": msg.content})

    # add_generation_prompt=True appends the assistant-turn header so the model
    # answers as the assistant instead of continuing the previous turn.
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The rendered template already contains its special tokens; don't add them
    # again. Calling the tokenizer (vs. .encode) also yields an attention mask.
    inputs = tokenizer(
        input_text, return_tensors="pt", add_special_tokens=False
    ).to(device)

    with _generation_lock:
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,  # Allow longer responses
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion; decode only the newly generated
    # tokens so the prompt/history is not echoed back as part of the reply.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    return {"messages": [AIMessage(content=response)]}


# Define the graph: a single "model" node that runs on every invocation.
workflow = StateGraph(state_schema=MessagesState)
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# In-process checkpointer: history is stored per thread_id in memory.
# NOTE(review): history is never trimmed, so very long threads will eventually
# exceed the model's context window.
memory = MemorySaver()
graph_app = workflow.compile(checkpointer=memory)


class QueryRequest(BaseModel):
    """Request body for ``/generate``."""

    query: str  # The user's message
    thread_id: str = "default"  # Conversation key for the checkpointer


# Create the FastAPI application
app = FastAPI(
    title="LangChain FastAPI",
    description="API to generate text using LangChain and LangGraph",
)


@app.get("/")
async def api_home():
    """Welcome endpoint."""
    return {"detail": "Welcome to FastAPI, Langchain, Docker tutorial"}


@app.post("/generate")
def generate(request: QueryRequest):
    """Generate a reply to ``request.query`` within conversation ``request.thread_id``.

    Declared as a plain ``def`` (not ``async def``) on purpose: model generation
    is blocking, and FastAPI runs sync endpoints in a threadpool, so the event
    loop stays free to serve other requests.

    Args:
        request: ``QueryRequest`` with ``query`` and optional ``thread_id``.

    Returns:
        dict: ``{"generated_text": str, "thread_id": str}``.

    Raises:
        HTTPException: status 500 if the graph invocation fails.
    """
    try:
        # The thread ID selects which stored conversation to continue.
        config = {"configurable": {"thread_id": request.thread_id}}
        input_messages = [HumanMessage(content=request.query)]
        output = graph_app.invoke({"messages": input_messages}, config)
        # The last message in the updated state is the model's reply.
        response = output["messages"][-1].content
        return {
            "generated_text": response,
            "thread_id": request.thread_id,
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error al generar texto: {str(e)}"
        ) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)