# NOTE: "Spaces: Running" is page-header residue from where this file was copied; it is not code.
from typing import Any, Dict, List, Tuple, Union

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.messages import AIMessage, BaseMessage
from langgraph.graph import END, StateGraph
from typing_extensions import NotRequired, TypedDict

from prompts import _AGENT_SYSTEM_TEMPLATE, _ANSWERER_SYSTEM_TEMPLATE
from runnables import Agent, Answerer
# Graph-state key under which the final payload is stored (written by the
# "tools" and "finish" nodes of the graph built in `_get_graph`).
OUTPUT_KEY = "response"
def _get_graph(
    agent_model_name: str = "gpt-4-turbo",
    agent_system_template: str = _AGENT_SYSTEM_TEMPLATE,
    agent_temperature: float = 0.0,
    answerer_model_name: str = "gpt-4-turbo",
    answerer_system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    answerer_temperature: float = 0.0,
    collection_index: int = 0,
    use_doctrines: bool = True,
    search_type: str = "similarity",
    similarity_threshold: float = 0.0,
    k: int = 15,
):
    """Build and compile the agent -> (tools) -> finish LangGraph workflow.

    The graph starts at the "agent" node. If the agent returns an
    ``AgentFinish`` the run goes straight to "finish"; otherwise the first
    requested action is executed by the "tools" node (the ``Answerer``
    runnable) and the run then goes to "finish". The "finish" node writes the
    final payload under ``OUTPUT_KEY`` and the graph ends.

    Args:
        agent_model_name: Forwarded to ``Agent(model_name=...)``.
        agent_system_template: Forwarded to ``Agent(system_template=...)``.
        agent_temperature: Forwarded to ``Agent(temperature=...)``.
        answerer_model_name: Forwarded to ``Answerer(model_name=...)``.
        answerer_system_template: Forwarded to ``Answerer(system_template=...)``.
        answerer_temperature: Forwarded to ``Answerer(temperature=...)``.
        collection_index: Forwarded to ``Answerer``.
        use_doctrines: Forwarded to ``Answerer``.
        search_type: Forwarded to ``Answerer``.
        similarity_threshold: Forwarded to ``Answerer``.
        k: Forwarded to ``Answerer`` (presumably the retrieval top-k; confirm
            in ``runnables``).

    Returns:
        The compiled graph produced by ``StateGraph.compile()``. Because the
        agent node is ``async``, the graph is meant to be run with the async
        API (``ainvoke`` / ``astream``).
    """
    agent_runnable = Agent(
        model_name=agent_model_name,
        system_template=agent_system_template,
        temperature=agent_temperature,
    ).runnable

    answerer_runnable = Answerer(
        model_name=answerer_model_name,
        system_template=answerer_system_template,
        temperature=answerer_temperature,
        collection_index=collection_index,
        use_doctrines=use_doctrines,
        search_type=search_type,
        similarity_threshold=similarity_threshold,
        k=k,
    ).runnable

    # ------------------------------------------------------------------ state
    class GraphState(TypedDict):
        query: str
        # Set by the "agent" node: either a finish signal or the requested
        # tool call(s); a list of actions is accepted (only the first runs).
        agent_outcome: NotRequired[
            Union[AgentAction, List[AgentAction], AgentFinish]
        ]
        chat_history: List[BaseMessage]
        # Final payload. When no tool ran, "finish" builds it with the keys
        # "answer" (an AIMessage), "docs" and "standalone_question"; the shape
        # of the tool output is defined by the Answerer runnable.
        response: NotRequired[Dict[str, Any]]

    # ----------------------------------------------------------------- helpers
    def _first_action(agent_outcome: Any) -> AgentAction:
        """Return the single action the "tools" node should execute.

        The agent may emit one action or a list of actions; only the first
        one is executed by this graph.

        Raises:
            ValueError: If the list is empty or the outcome has no ``tool``.
        """
        if isinstance(agent_outcome, list):
            if not agent_outcome:
                raise ValueError("Agent returned an empty list of actions.")
            agent_outcome = agent_outcome[0]
        if getattr(agent_outcome, "tool", None) is None:
            raise ValueError(
                "Unexpected agent outcome of type "
                f"{type(agent_outcome).__name__!r}: expected an action with a tool."
            )
        return agent_outcome

    # ------------------------------------------------------------------- nodes
    async def execute_agent(
        state: GraphState,
        config: Dict,
    ) -> Dict[str, Any]:
        """Run the agent on a copy of the current state.

        Args:
            state: The current graph state.
            config: Run config supplied by LangGraph, forwarded to the agent.

        Returns:
            dict: ``{"agent_outcome": <AgentFinish | action(s)>}``.
        """
        agent_outcome = await agent_runnable.with_config(
            {"run_name": "agent_node"}
        ).ainvoke(state.copy(), config=config)
        return {"agent_outcome": agent_outcome}

    def execute_tool(
        state: GraphState,
        config: Dict,
    ) -> Dict[str, Any]:
        """Execute the retrieval tool (the Answerer) for the agent's action.

        Args:
            state: The current graph state; ``agent_outcome`` must be an
                action (or a list of actions) whose ``tool_input`` carries a
                ``"standalone_question"`` entry.
            config: Run config supplied by LangGraph, forwarded to the tool.

        Returns:
            dict: ``{OUTPUT_KEY: <answerer output>}``.
        """
        action = _first_action(state["agent_outcome"])
        tool_output = answerer_runnable.invoke(
            {"query": action.tool_input["standalone_question"]},
            config=config,
        )
        return {OUTPUT_KEY: tool_output}

    def finish(state: GraphState) -> Dict[str, Any]:
        """Produce the final payload under ``OUTPUT_KEY``.

        If a tool already wrote a response it is passed through unchanged.
        Otherwise the agent finished on its own and its ``output`` is wrapped
        in an ``AIMessage`` with no documents and no standalone question.
        """
        # `response` is NotRequired: it is absent when the agent finished
        # without calling a tool, so use .get() rather than indexing.
        response = state.get(OUTPUT_KEY)
        if response is None:
            response = {
                "answer": AIMessage(state["agent_outcome"].return_values["output"]),
                "docs": [],
                "standalone_question": None,
            }
        return {OUTPUT_KEY: response}

    # ------------------------------------------------------------------- edges
    def parse(state: GraphState) -> str:
        """Route after the agent node.

        Returns:
            str: ``"end"`` when the agent returned ``AgentFinish``, otherwise
            ``"use_tool"``.

        Raises:
            ValueError: If the outcome is neither a finish nor a usable action
                (previously this silently returned ``None`` and broke routing).
        """
        agent_outcome = state["agent_outcome"]
        if isinstance(agent_outcome, AgentFinish):
            return "end"
        _first_action(agent_outcome)  # validate the outcome before routing
        return "use_tool"

    # ------------------------------------------------------------------- graph
    graph = StateGraph(GraphState)

    graph.add_node("agent", execute_agent)
    graph.add_node("tools", execute_tool)
    graph.add_node("finish", finish)

    # Every run starts at the agent node.
    graph.set_entry_point("agent")

    # Decide whether to call the tool or go straight to the finish node.
    graph.add_conditional_edges(
        "agent",
        parse,
        {
            "use_tool": "tools",
            "end": "finish",
        },
    )

    # A single tool call is enough: after it, the run always finishes.
    graph.add_edge("tools", "finish")
    graph.add_edge("finish", END)

    return graph.compile()