from typing import Any, List, Union, Dict, Tuple

from typing_extensions import NotRequired, TypedDict

from langchain_core.agents import (
    AgentAction,
    AgentFinish,
)
from langchain_core.messages import BaseMessage, AIMessage
from langgraph.graph import END, StateGraph

from runnables import Answerer, Agent
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE

# State key under which the final response payload is stored.
OUTPUT_KEY = "response"


def _get_graph(
    agent_model_name: str = "gpt-4-turbo",
    agent_system_template: str = _AGENT_SYSTEM_TEMPLATE,
    agent_temperature: float = 0.0,
    answerer_model_name: str = "gpt-4-turbo",
    answerer_system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    answerer_temperature: float = 0.0,
    collection_index: int = 0,
    use_doctrines: bool = True,
    search_type: str = "similarity",
    similarity_threshold: float = 0.0,
    k: int = 15,
):
    """Build and compile the agent -> (tool) -> finish state graph.

    The ``agent`` node runs the ``Agent`` runnable. If it yields an
    ``AgentFinish``, the graph goes straight to ``finish``. Otherwise the
    ``tools`` node runs the ``Answerer`` runnable on the action's
    ``standalone_question``. Both paths end in ``finish``, which stores the
    response under ``OUTPUT_KEY``.

    Args:
        agent_model_name: Forwarded to ``Agent(model_name=...)``.
        agent_system_template: Forwarded to ``Agent(system_template=...)``.
        agent_temperature: Forwarded to ``Agent(temperature=...)``.
        answerer_model_name: Forwarded to ``Answerer(model_name=...)``.
        answerer_system_template: Forwarded to
            ``Answerer(system_template=...)``.
        answerer_temperature: Forwarded to ``Answerer(temperature=...)``.
        collection_index: Forwarded to ``Answerer(collection_index=...)``.
        use_doctrines: Forwarded to ``Answerer(use_doctrines=...)``.
        search_type: Forwarded to ``Answerer(search_type=...)``.
        similarity_threshold: Forwarded to
            ``Answerer(similarity_threshold=...)``.
        k: Forwarded to ``Answerer(k=...)``.

    Returns:
        The object returned by ``StateGraph.compile()``.
    """
    agent = Agent(
        model_name=agent_model_name,
        system_template=agent_system_template,
        temperature=agent_temperature,
    )
    agent_runnable = agent.runnable

    answerer = Answerer(
        model_name=answerer_model_name,
        system_template=answerer_system_template,
        temperature=answerer_temperature,
        collection_index=collection_index,
        use_doctrines=use_doctrines,
        search_type=search_type,
        similarity_threshold=similarity_threshold,
        k=k,
    )
    answerer_runnable = answerer.runnable

    # GRAPH

    class GraphState(TypedDict):
        # The user's query.
        query: str
        # Set by the agent node. Either an AgentFinish, a single AgentAction,
        # or a list of AgentActions (only the first action is acted upon).
        agent_outcome: NotRequired[
            Union[AgentAction, List[AgentAction], AgentFinish]
        ]
        chat_history: List[BaseMessage]
        # Set by the tools node and/or the finish node.
        response: NotRequired[
            Dict[
                str,
                Union[
                    str,
                    List[int],
                    List[Dict[str, Union[int, str]]],
                ],
            ]
        ]

    ## Helpers

    def _first_action(agent_outcome: Any) -> Any:
        """Return the action to act on from an agent outcome.

        A list outcome yields its first element. Any other value is
        returned unchanged.

        Raises:
            ValueError: If the outcome is an empty list.
        """
        if isinstance(agent_outcome, list):
            if not agent_outcome:
                raise ValueError("Agent returned an empty list of actions.")
            return agent_outcome[0]
        return agent_outcome

    ## Nodes' functions

    async def execute_agent(
        state: GraphState,
        config: Dict,
    ) -> Dict[str, Any]:
        """Run the agent runnable on the current state.

        Args:
            state: The current graph state, passed (shallow-copied) as the
                agent's input.
            config: Run configuration, forwarded to the agent runnable.

        Returns:
            A state update of the form ``{"agent_outcome": <agent output>}``.
        """
        inputs = state.copy()
        agent_outcome = await agent_runnable.with_config(
            {"run_name": "agent_node"}
        ).ainvoke(inputs, config=config)
        return {"agent_outcome": agent_outcome}

    def execute_tool(
        state: GraphState,
        config: Dict,
    ) -> Dict[str, Any]:
        """Run the Answerer (retrieval) runnable for the agent's action.

        Args:
            state: The current graph state. It must contain an
                ``agent_outcome`` holding an AgentAction, or a list of them.
            config: Run configuration, forwarded to the answerer runnable.

        Returns:
            A state update ``{OUTPUT_KEY: <answerer output>}``.
        """
        # Accept both a bare AgentAction and a list of them, matching `parse`.
        action = _first_action(state["agent_outcome"])
        # NOTE(review): assumes tool_input is a mapping with a
        # "standalone_question" entry; confirm against the agent's tool schema.
        tool_input = action.tool_input
        tool_output = answerer_runnable.invoke(
            {"query": tool_input["standalone_question"]},
            config=config,
        )
        return {OUTPUT_KEY: tool_output}

    def finish(
        state: GraphState,
    ) -> Dict[str, Any]:
        """Produce the final response payload.

        If the tools node already stored a response, it is passed through.
        Otherwise the response is built from the agent's AgentFinish output,
        with no documents and no standalone question.

        Args:
            state: The current graph state.

        Returns:
            A state update ``{OUTPUT_KEY: <response dict>}``.
        """
        # `response` is NotRequired and is never set when the agent finishes
        # directly, so use .get() instead of indexing (avoids KeyError).
        response = state.get(OUTPUT_KEY)
        if response is None:
            response = {
                "answer": AIMessage(
                    content=state["agent_outcome"].return_values["output"]
                ),
                "docs": [],
                "standalone_question": None,
            }
        return {OUTPUT_KEY: response}

    ## Edges' functions

    def parse(
        state: GraphState,
    ) -> str:
        """Route on the agent's outcome.

        Args:
            state: The current graph state.

        Returns:
            ``"end"`` if the agent returned an AgentFinish, or ``"use_tool"``
            if its (first) action names a tool.

        Raises:
            ValueError: If the outcome is neither an AgentFinish nor an action
                with a tool. Returning anything else would not match the
                conditional-edge mapping.
        """
        agent_outcome = state["agent_outcome"]
        if isinstance(agent_outcome, AgentFinish):
            return "end"
        action = _first_action(agent_outcome)
        if getattr(action, "tool", None) is not None:
            return "use_tool"
        raise ValueError(
            "Cannot route agent outcome of type "
            f"{type(agent_outcome).__name__!r}: expected AgentFinish or an "
            "AgentAction with a tool."
        )

    ## Graph

    graph = StateGraph(GraphState)

    # Define the nodes
    graph.add_node("agent", execute_agent)
    graph.add_node("tools", execute_tool)
    graph.add_node("finish", finish)

    # Start from the agent node
    graph.set_entry_point("agent")

    # Decide whether to use a tool or end the process
    graph.add_conditional_edges(
        "agent",
        parse,
        {
            "use_tool": "tools",
            "end": "finish",
        },
    )

    # After the tool produces a response with context, finish
    graph.add_edge("tools", "finish")
    graph.add_edge("finish", END)

    # Compile
    compiled_graph = graph.compile()

    return compiled_graph