|
import os |
|
from dotenv import load_dotenv |
|
from langgraph.graph import START, StateGraph, MessagesState |
|
from langgraph.prebuilt import tools_condition, ToolNode |
|
from langchain_openai import ChatOpenAI |
|
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint |
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage |
|
from omegaconf import OmegaConf |
|
from .tools import * |
|
|
|
|
|
def load_config(config_path: str):
    """Read a YAML configuration file with OmegaConf.

    Args:
        config_path: Path of the YAML file to load.

    Returns:
        The configuration object produced by ``OmegaConf.load``.
    """
    return OmegaConf.load(config_path)
|
|
|
|
|
# Populate os.environ from a .env file *before* touching the config.
# OmegaConf resolves interpolations (e.g. ``${oc.env:...}``) lazily at access
# time, and SYSTEM_PROMPT below is accessed immediately at import, so the
# environment must already be complete. By default load_dotenv() does not
# override variables that are already set in the real environment.
load_dotenv()

# Path is relative to the current working directory (presumably the project
# root -- TODO confirm).
CONFIG = load_config("config.yaml")

# System prompt text for the agent, read from the ``system_prompt.custom`` key.
SYSTEM_PROMPT = CONFIG["system_prompt"]["custom"]
|
|
|
|
|
class LangGraphAgent4GAIA:
    """Tool-using LangGraph agent built as an assistant <-> tools loop.

    The compiled graph is exposed as ``self.graph``. Execution starts at the
    ``assistant`` node. ``tools_condition`` routes to the ``tools`` node when
    the model's reply requests tool calls, and the tool results are fed back
    to ``assistant``. Otherwise the run ends.
    """

    def __init__(self, model_provider: str, model_name: str):
        """Create the system prompt message and compile the agent graph.

        Args:
            model_provider: ``"openai"`` or ``"huggingface"``.
            model_name: OpenAI model name, or Hugging Face Hub repo id.

        Raises:
            ValueError: If ``model_provider`` is not supported.
        """
        self.sys_prompt = SystemMessage(content=SYSTEM_PROMPT)
        self.graph = self.get_agent(model_provider, model_name)

    def assistant(self, state: MessagesState):
        """Assistant node: run the tool-bound LLM on system prompt + history.

        Returns a partial state update holding the model's single reply; the
        ``messages`` channel of ``MessagesState`` appends it to the history.
        """
        return {"messages": [self.llm_with_tools.invoke([self.sys_prompt] + state["messages"])]}

    @staticmethod
    def _get_tools() -> list:
        """Return the tools offered to the model (presumably all from ``.tools``)."""
        return [
            multiply,
            add,
            add_list,
            subtract,
            divide,
            modulo,
            web_search,
            arxiv_search,
            wiki_search,
            read_xlsx_file,
            get_python_file,
        ]

    @staticmethod
    def _build_llm(provider: str, model_name: str):
        """Instantiate the chat model for ``provider``; raise ValueError if unknown."""
        if provider == "openai":
            return ChatOpenAI(
                model=model_name,
                temperature=0,
                max_retries=2,
                # None if the variable is unset.
                api_key=os.getenv("OPENAI_API_KEY"),
            )
        if provider == "huggingface":
            # No token is passed here; authentication is presumably taken from
            # the environment (e.g. HUGGINGFACEHUB_API_TOKEN) -- TODO confirm.
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id=model_name,
                    task="text-generation",
                    max_new_tokens=1024,
                    do_sample=False,
                    repetition_penalty=1.03,
                    temperature=0,
                ),
                verbose=True,
            )
        raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")

    def get_agent(self, provider: str, model_name: str):
        """Create the LLM, bind the tools to it and compile the agent graph.

        Side effect: sets ``self.llm_with_tools``, which ``assistant`` invokes.

        Args:
            provider: ``"openai"`` or ``"huggingface"``.
            model_name: Model name (OpenAI) or Hub repo id (Hugging Face).

        Returns:
            The compiled ``StateGraph``.

        Raises:
            ValueError: If ``provider`` is not supported.
        """
        tools = self._get_tools()
        llm = self._build_llm(provider, model_name)

        bind_kwargs = {}
        if provider == "openai":
            # Restrict the model to at most one tool call per turn. Only the
            # OpenAI API understands this option: ChatHuggingFace forwards extra
            # bind kwargs to InferenceClient.chat_completion, which has no
            # ``parallel_tool_calls`` parameter and would reject the call.
            bind_kwargs["parallel_tool_calls"] = False
        self.llm_with_tools = llm.bind_tools(tools, **bind_kwargs)

        builder = StateGraph(MessagesState)
        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(tools))
        builder.add_edge(START, "assistant")
        # Go to "tools" when the last AI message carries tool calls, else finish.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        # Hand tool results back to the model for the next reasoning step.
        builder.add_edge("tools", "assistant")

        return builder.compile()
|
|
|
|
|
if __name__ == "__main__":
    # Only needed for rendering the graph diagram, so imported lazily here.
    from langchain_core.runnables.graph import MermaidDrawMethod

    # Build the agent from the provider/model configured in config.yaml.
    model_cfg = CONFIG["model"]
    agent_manager = LangGraphAgent4GAIA(model_cfg["provider"], model_cfg["name"])

    # Render the compiled graph through the remote Mermaid API and save the PNG
    # to agentic/graph.png (relative to the current working directory).
    png_bytes = agent_manager.graph.get_graph().draw_mermaid_png(
        draw_method=MermaidDrawMethod.API
    )
    with open('agentic/graph.png', "wb") as png_file:
        png_file.write(png_bytes)

    # Push one demo question through the agent and print the full transcript.
    question = "What is the capital of Spain?"
    run_config = {"recursion_limit": 50}
    final_state = agent_manager.graph.invoke(
        {"messages": [HumanMessage(content=question)]}, run_config
    )
    for message in final_state["messages"]:
        message.pretty_print()