# Maximofn's picture
# Update model loading with English language support and system prompt
# 6a068d1
import asyncio
import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
load_dotenv()

# Load the tokenizer and model once, at import time.
print("Cargando modelo y tokenizer...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if device != "cuda":
        # CPU path: full-precision weights, every module pinned to the CPU.
        print("Using CPU for the model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map={"": device},
        )
    else:
        # GPU path: BF16 weights, with device placement left to device_map="auto".
        print("Using GPU for the model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            low_cpu_mem_usage=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
    print(f"Model loaded successfully on: {device}")
except Exception as load_error:
    # Report the failure, then re-raise so the module import fails with the
    # original exception.
    print(f"Error loading the model: {load_error}")
    raise
# Define the function that calls the model
def call_model(state: MessagesState):
    """
    Generate the assistant's next reply for the conversation held in ``state``.

    Args:
        state: MessagesState whose "messages" entry holds the conversation so far.

    Returns:
        dict: ``{"messages": [...]}`` containing the messages from ``state``
        followed by a new AIMessage with the generated reply.
    """
    # Convert LangChain messages to the role/content chat format. Message
    # types other than HumanMessage and AIMessage are skipped.
    messages = [
        {"role": "system", "content": "You are a friendly Chatbot. Always reply in the language in which the user is writing to you."}
    ]
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            messages.append({"role": "assistant", "content": msg.content})

    # add_generation_prompt=True appends the assistant-turn header, so the
    # model writes a reply instead of continuing the last user turn.
    input_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The templated text already contains the special tokens it needs;
    # add_special_tokens=False keeps the tokenizer from adding them again.
    inputs = tokenizer.encode(
        input_text,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    # Generate response
    outputs = model.generate(
        inputs,
        max_new_tokens=512,  # Upper bound on the length of the reply
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # generate() returns the prompt tokens followed by the new tokens. Decode
    # only the new ones: splitting the decoded text on a textual marker such
    # as "Assistant:" is unreliable because the marker depends on the chat
    # template and may not appear in the output at all.
    prompt_length = inputs.shape[-1]
    generated_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # Convert the response to LangChain format
    ai_message = AIMessage(content=response)
    return {"messages": state["messages"] + [ai_message]}
# Build the conversation graph: one "model" node, entered directly from START.
workflow = StateGraph(state_schema=MessagesState)
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Compile the graph with a MemorySaver checkpointer; /generate selects the
# conversation through the thread_id it passes in the invoke config.
memory = MemorySaver()
graph_app = workflow.compile(checkpointer=memory)
# Request schema for the /generate endpoint
class QueryRequest(BaseModel):
    """Body of a /generate request."""

    # Text of the user's message.
    query: str
    # Conversation identifier; requests that omit it use "default".
    thread_id: str = "default"
# FastAPI application object that the endpoints below are registered on
app = FastAPI(
    title="LangChain FastAPI",
    description="API to generate text using LangChain and LangGraph",
)
# Root endpoint
@app.get("/")
async def api_home():
    """Return a static welcome payload."""
    welcome = {"detail": "Welcome to FastAPI, Langchain, Docker tutorial"}
    return welcome
# Generate endpoint
@app.post("/generate")
async def generate(request: QueryRequest):
    """
    Endpoint to generate text using the language model

    Args:
        request: QueryRequest
            query: str
            thread_id: str = "default"

    Returns:
        dict: A dictionary containing the generated text and the thread ID

    Raises:
        HTTPException: status 500 carrying the error text when generation fails.
    """
    try:
        # Configure the thread ID
        config = {"configurable": {"thread_id": request.thread_id}}

        # Create the input message
        input_messages = [HumanMessage(content=request.query)]

        # graph_app.invoke is a blocking call that runs the model. Calling it
        # directly inside this coroutine would stall the event loop (and every
        # other request) until generation finishes, so run it in a worker
        # thread and await the result.
        output = await asyncio.to_thread(
            graph_app.invoke,
            {"messages": input_messages},
            config,
        )

        # The last message in the returned state is the model's reply
        response = output["messages"][-1].content

        return {
            "generated_text": response,
            "thread_id": request.thread_id
        }
    except Exception as e:
        # Chain the original exception so its traceback is kept.
        raise HTTPException(status_code=500, detail=f"Error al generar texto: {str(e)}") from e
if __name__ == "__main__":
    # uvicorn is imported only when the file is executed as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")