import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_openai import ChatOpenAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from omegaconf import OmegaConf
from .tools import *


def load_config(config_path: str):
    """Load a YAML configuration file with OmegaConf."""
    config = OmegaConf.load(config_path)
    return config


# --- Constants ---
CONFIG = load_config("config.yaml")
SYSTEM_PROMPT = CONFIG["system_prompt"]["custom"]
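
# The code reads CONFIG["system_prompt"]["custom"] here and CONFIG["model"]["provider"] /
# CONFIG["model"]["name"] in the __main__ block, so config.yaml is assumed to look roughly
# like this (illustrative values only, not taken from the repo):
#
#   model:
#     provider: openai          # or "huggingface"
#     name: gpt-4o-mini         # or a Hub repo_id when provider is "huggingface"
#   system_prompt:
#     custom: |
#       You are a helpful assistant ...
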
# Load environment variables from .env file
load_dotenv()
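
# load_dotenv() above expects a local .env file. OPENAI_API_KEY is read explicitly below;
# HuggingFaceEndpoint typically picks up HUGGINGFACEHUB_API_TOKEN from the environment
# (assumed here, it is not referenced in this file). A minimal sketch:
#
#   OPENAI_API_KEY=sk-...
#   HUGGINGFACEHUB_API_TOKEN=hf_...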


class LangGraphAgent4GAIA:
    def __init__(self, model_provider: str, model_name: str):
        self.sys_prompt = SystemMessage(content=SYSTEM_PROMPT)
        self.graph = self.get_agent(model_provider, model_name)

    def assistant(self, state: MessagesState):
        """Assistant node: invoke the tool-bound LLM on the system prompt plus the conversation so far."""
        return {"messages": [self.llm_with_tools.invoke([self.sys_prompt] + state["messages"])]}

    def get_agent(self, provider: str, model_name: str):
        """Build and compile the LangGraph agent for the given model provider."""
        # Tools imported from .tools and exposed to the LLM
        tools = [
            multiply,
            add,
            add_list,
            subtract,
            divide,
            modulo,
            web_search,
            arxiv_search,
            wiki_search,
            read_xlsx_file,
            get_python_file,
        ]

        # 1. Instantiate the chat model
        if provider == "openai":
            llm = ChatOpenAI(
                model=model_name,
                temperature=0,
                max_retries=2,
                api_key=os.getenv("OPENAI_API_KEY"),
            )
        elif provider == "huggingface":
            llm = ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id=model_name,
                    task="text-generation",
                    max_new_tokens=1024,
                    do_sample=False,
                    repetition_penalty=1.03,
                    temperature=0,
                ),
                verbose=True,
            )
        else:
            raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")

        # 2. Bind tools to the LLM
        self.llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)

        # 3. Build the graph: assistant node, tool node, and a conditional edge that
        #    routes to the tools whenever the assistant emits a tool call
        builder = StateGraph(MessagesState)
        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(tools))
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")

        # Compile graph
        return builder.compile()


if __name__ == "__main__":
    from langchain_core.runnables.graph import MermaidDrawMethod

    question = "What is the capital of Spain?"

    # Build the graph
    agent_manager = LangGraphAgent4GAIA(CONFIG["model"]["provider"], CONFIG["model"]["name"])

    # Render the compiled graph to a PNG via the Mermaid API
    img_data = agent_manager.graph.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    with open("agentic/graph.png", "wb") as f:
        f.write(img_data)

    # Run the graph
    messages = [HumanMessage(content=question)]
    messages = agent_manager.graph.invoke({"messages": messages}, {"recursion_limit": 50})
    for m in messages["messages"]:
        m.pretty_print()
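
# Note: because of the relative import `from .tools import *` and the "agentic/graph.png"
# output path, this file is assumed to live in an "agentic" package and to be run as a module
# from the repo root, e.g. `python -m agentic.agent` (module name assumed, not confirmed).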