# SmolLM2_backend/app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
import os
from dotenv import load_dotenv
load_dotenv()
# HuggingFace token (read from the environment or a local .env file)
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
# Initialize the HuggingFace Inference API client
model = InferenceClient(
    model="Qwen/Qwen2.5-72B-Instruct",
    api_key=HUGGINGFACE_TOKEN
)
# Define the function that calls the model
def call_model(state: MessagesState):
"""
Call the model with the given messages
Args:
state: MessagesState
Returns:
dict: A dictionary containing the generated text and the thread ID
"""
    # Convert LangChain messages to the HuggingFace chat format
    hf_messages = []
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            hf_messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            hf_messages.append({"role": "assistant", "content": msg.content})

    # Call the Inference API
    response = model.chat_completion(
        messages=hf_messages,
        temperature=0.5,
        max_tokens=64,
        top_p=0.7
    )

    # Convert the response back to LangChain format
    ai_message = AIMessage(content=response.choices[0].message.content)
    return {"messages": state["messages"] + [ai_message]}
# Define the graph
workflow = StateGraph(state_schema=MessagesState)
# Define the entry edge and the single model node in the graph
workflow.add_edge(START, "model")
workflow.add_node("model", call_model)
# Add memory
memory = MemorySaver()
graph_app = workflow.compile(checkpointer=memory)
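# Note: MemorySaver keeps each thread's conversation history in process memory,
# so the stored state is lost whenever the server restarts.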
# Define the data model for the request
class QueryRequest(BaseModel):
    query: str
    thread_id: str = "default"
# Create the FastAPI application
app = FastAPI(title="LangChain FastAPI", description="API to generate text using LangChain and LangGraph")
# Welcome endpoint
@app.get("/")
async def api_home():
"""Welcome endpoint"""
return {"detail": "Welcome to FastAPI, Langchain, Docker tutorial"}
# Generate endpoint
@app.post("/generate")
async def generate(request: QueryRequest):
"""
Endpoint to generate text using the language model
Args:
request: QueryRequest
query: str
thread_id: str = "default"
Returns:
dict: A dictionary containing the generated text and the thread ID
"""
    try:
        # Configure the thread ID
        config = {"configurable": {"thread_id": request.thread_id}}

        # Create the input message
        input_messages = [HumanMessage(content=request.query)]

        # Invoke the graph
        output = graph_app.invoke({"messages": input_messages}, config)

        # Get the model response
        response = output["messages"][-1].content

        return {
            "generated_text": response,
            "thread_id": request.thread_id
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
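
# Example client call (a sketch, assuming the server is running locally on port 7860
# and the `requests` package is available):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"query": "Hello, who are you?", "thread_id": "default"},
#   )
#   print(resp.json()["generated_text"])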