import os

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_openai import ChatOpenAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from omegaconf import OmegaConf

from .tools import *


def load_config(config_path: str):
    """Load a YAML configuration file with OmegaConf.

    Args:
        config_path: Path to the YAML file (relative paths are resolved
            against the current working directory).

    Returns:
        The config object returned by ``OmegaConf.load``.
    """
    return OmegaConf.load(config_path)


# Load environment variables from the .env file *before* the config is read,
# so that any environment-dependent config values (e.g. OmegaConf
# ``${oc.env:...}`` interpolations, resolved on access) see the .env contents.
load_dotenv()

# --- Constants ---
CONFIG = load_config("config.yaml")
SYSTEM_PROMPT = CONFIG["system_prompt"]["custom"]


class LangGraphAgent4GAIA:
    """Tool-calling agent built as a LangGraph ``StateGraph``.

    The graph has two nodes: ``assistant`` (a chat model with tools bound)
    and ``tools`` (executes the tool calls the model requests). After the
    tools run, control returns to ``assistant``.
    """

    def __init__(self, model_provider: str, model_name: str):
        """Build and compile the agent graph.

        Args:
            model_provider: Either ``"openai"`` or ``"huggingface"``.
            model_name: OpenAI model name, or Hugging Face repo id.

        Raises:
            ValueError: If ``model_provider`` is not supported.
        """
        self.sys_prompt = SystemMessage(content=SYSTEM_PROMPT)
        self.graph = self.get_agent(model_provider, model_name)

    def assistant(self, state: MessagesState):
        """Assistant node: call the tool-bound model on the conversation.

        The system prompt is prepended to the messages in ``state`` on every
        call; it is not stored in the graph state.

        Returns:
            A ``{"messages": [response]}`` state update holding the model reply.
        """
        response = self.llm_with_tools.invoke([self.sys_prompt] + state["messages"])
        return {"messages": [response]}

    @staticmethod
    def _build_llm(provider: str, model_name: str):
        """Instantiate the chat model for the given provider.

        Raises:
            ValueError: If ``provider`` is neither ``"openai"`` nor ``"huggingface"``.
        """
        if provider == "openai":
            return ChatOpenAI(
                model=model_name,
                temperature=0,
                max_retries=2,
                api_key=os.getenv("OPENAI_API_KEY"),
            )
        if provider == "huggingface":
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id=model_name,
                    task="text-generation",
                    max_new_tokens=1024,
                    do_sample=False,
                    repetition_penalty=1.03,
                    # NOTE(review): some Hugging Face inference backends reject
                    # temperature=0 as "must be strictly positive"; with
                    # do_sample=False decoding is greedy regardless — confirm
                    # against the target endpoint.
                    temperature=0,
                ),
                verbose=True,
            )
        raise ValueError(
            f"Invalid provider {provider!r}. Choose 'openai' or 'huggingface'."
        )

    def get_agent(self, provider: str, model_name: str):
        """Create the tool-bound model and compile the agent graph.

        Side effect: sets ``self.llm_with_tools``, which ``assistant`` uses.

        Args:
            provider: Either ``"openai"`` or ``"huggingface"``.
            model_name: OpenAI model name, or Hugging Face repo id.

        Returns:
            The compiled graph from ``StateGraph.compile()``.

        Raises:
            ValueError: If ``provider`` is not supported.
        """
        # All tool callables come from the local ``tools`` module (star import).
        tools = [
            multiply,
            add,
            add_list,
            subtract,
            divide,
            modulo,
            web_search,
            arxiv_search,
            wiki_search,
            read_xlsx_file,
            get_python_file,
        ]

        # 1. Build the chat model.
        llm = self._build_llm(provider, model_name)

        # 2. Bind tools to the model; parallel tool calls are disabled so the
        #    model requests tools one at a time.
        self.llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)

        # 3. Wire the graph: START -> assistant, assistant -(conditional)-> tools,
        #    tools -> assistant.
        builder = StateGraph(MessagesState)
        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(tools))
        builder.add_edge(START, "assistant")
        # LangGraph's prebuilt ``tools_condition`` picks the next hop after the
        # assistant: the "tools" node when tool calls were requested, else end.
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")

        return builder.compile()


if __name__ == "__main__":
    # Because of the relative ``from .tools import *`` import above, run this
    # file as a module (``python -m <package>.<module>``), not as a script.
    from langchain_core.runnables.graph import MermaidDrawMethod

    question = "What is the capital of Spain?"

    # Build the graph
    agent_manager = LangGraphAgent4GAIA(CONFIG["model"]["provider"], CONFIG["model"]["name"])

    # Render the graph diagram (via the Mermaid API draw method) and save it.
    img_data = agent_manager.graph.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    os.makedirs("agentic", exist_ok=True)
    with open("agentic/graph.png", "wb") as f:
        f.write(img_data)

    # Run the graph
    messages = [HumanMessage(content=question)]
    result = agent_manager.graph.invoke({"messages": messages}, {"recursion_limit": 50})
    for m in result["messages"]:
        m.pretty_print()