# test-interface / graph.py
# (uploaded via huggingface_hub by tommasodelorenzo, commit d46cc41 verified)
from typing import List, Union, Dict, Tuple
from typing_extensions import NotRequired, TypedDict
from langchain_core.agents import (
AgentAction,
AgentFinish
)
from langchain_core.messages import BaseMessage, AIMessage
from langgraph.graph import END, StateGraph
from runnables import Answerer, Agent
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE
# State key under which the final structured response (answer, docs,
# standalone question) is stored and returned by the graph.
OUTPUT_KEY = "response"
def _get_graph(
    agent_model_name: str = "gpt-4-turbo",
    agent_system_template: str = _AGENT_SYSTEM_TEMPLATE,
    agent_temperature: float = 0.0,
    answerer_model_name: str = "gpt-4-turbo",
    answerer_system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    answerer_temperature: float = 0.0,
    collection_index: int = 0,
    use_doctrines: bool = True,
    search_type: str = "similarity",
    similarity_threshold: float = 0.0,
    k: int = 15,
):
    """Build and compile the agent/retrieval LangGraph state machine.

    The compiled graph routes a user query through an agent node.  If the
    agent decides to call the retrieval tool, the Answerer runnable is
    invoked to produce a context-grounded response; otherwise the agent's
    own final answer is wrapped into the same response schema.  Either way
    execution terminates at the "finish" node.

    Args:
        agent_model_name: Chat model used by the routing agent.
        agent_system_template: System prompt template for the agent.
        agent_temperature: Sampling temperature for the agent model.
        answerer_model_name: Chat model used by the answering chain.
        answerer_system_template: System prompt template for the answerer.
        answerer_temperature: Sampling temperature for the answerer model.
        collection_index: Index of the document collection to search.
        use_doctrines: Whether doctrine documents are included in retrieval.
        search_type: Retrieval mode passed to the Answerer (e.g. "similarity").
        similarity_threshold: Minimum similarity score for retrieved docs.
        k: Number of documents to retrieve.

    Returns:
        The compiled LangGraph graph, ready to be invoked with a
        ``{"query": ..., "chat_history": [...]}`` state.
    """
    agent = Agent(
        model_name=agent_model_name,
        system_template=agent_system_template,
        temperature=agent_temperature,
    )
    agent_runnable = agent.runnable

    answerer = Answerer(
        model_name=answerer_model_name,
        system_template=answerer_system_template,
        temperature=answerer_temperature,
        collection_index=collection_index,
        use_doctrines=use_doctrines,
        search_type=search_type,
        similarity_threshold=similarity_threshold,
        k=k,
    )
    answerer_runnable = answerer.runnable

    # GRAPH
    class GraphState(TypedDict):
        # query: the user's current question.
        query: str
        # agent_outcome: set by the agent node; either a tool call
        # (AgentAction, possibly wrapped in a list by the runnable) or a
        # final answer (AgentFinish).  Absent before the agent runs.
        agent_outcome: NotRequired[
            Union[AgentAction, AgentFinish]
        ]
        # chat_history: prior conversation turns fed to the agent.
        chat_history: List[BaseMessage]
        # response: final payload — answer message, retrieved docs, and the
        # standalone question.  Absent until the tools or finish node sets it.
        response: NotRequired[
            Dict[
                str,
                Union[
                    str,
                    List[int],
                    List[Dict[str, Union[int, str]]]
                ]
            ]
        ]

    ## Nodes' functions
    async def execute_agent(
        state: GraphState,
        config: Dict,
    ) -> Dict:
        """
        Invoke the agent model on the current state.

        Args:
            state: The current graph state.
            config: Runnable config forwarded to the agent.

        Returns:
            dict: State update carrying the new ``agent_outcome``.
        """
        inputs = state.copy()
        agent_outcome = await agent_runnable \
            .with_config({"run_name": "agent_node"}) \
            .ainvoke(inputs, config=config)
        return {"agent_outcome": agent_outcome}

    def execute_tool(
        state: GraphState,
        config: Dict,
    ) -> Dict:
        """
        Execute the Retrieve tool via the Answerer runnable.

        Args:
            state: The current graph state; ``agent_outcome`` is expected to
                be a non-empty list of AgentActions here (the shape the agent
                runnable emits for tool calls).
            config: Runnable config forwarded to the answerer.

        Returns:
            dict: State update carrying the final response under OUTPUT_KEY.
        """
        inputs = state["agent_outcome"][0].tool_input
        tool_output = answerer_runnable.invoke(
            {"query": inputs["standalone_question"]},
            config=config
        )
        return {OUTPUT_KEY: tool_output}

    def finish(
        state: GraphState
    ) -> Dict:
        """
        Produce the final response.

        If the tools node already set a response, pass it through; otherwise
        wrap the agent's direct answer (AgentFinish) into the same schema.
        """
        # BUG FIX: ``response`` is NotRequired, so on the direct agent-finish
        # path the key may be missing entirely — state[OUTPUT_KEY] would
        # raise KeyError.  .get() handles both "missing" and "None".
        response = state.get(OUTPUT_KEY)
        if response is None:
            response = {
                "answer": AIMessage(state['agent_outcome'].return_values['output']),
                "docs": [],
                "standalone_question": None
            }
        return {OUTPUT_KEY: response}

    ## Edges' functions
    def parse(
        state: GraphState
    ) -> str:
        """
        Route based on the previous agent outcome.

        Ends the process when the agent produced an AgentFinish; otherwise
        dispatches to the tool node.

        Args:
            state: The current graph state.

        Returns:
            str: Either "end" or "use_tool".

        Raises:
            ValueError: If the agent outcome is an action with no tool set
                (previously this fell through and returned None, which made
                ``add_conditional_edges`` fail with an opaque error).
        """
        agent_outcome = state["agent_outcome"]
        if isinstance(agent_outcome, AgentFinish):
            return "end"
        # The agent runnable may wrap tool calls in a list; unwrap the first.
        if isinstance(agent_outcome, List):
            agent_outcome = agent_outcome[0]
        if agent_outcome.tool is not None:
            return "use_tool"
        raise ValueError(
            f"Unexpected agent outcome with no tool: {agent_outcome!r}"
        )

    ## Graph
    graph = StateGraph(GraphState)
    # Define the nodes
    graph.add_node("agent", execute_agent)
    graph.add_node("tools", execute_tool)
    graph.add_node("finish", finish)
    # Start from the agent node
    graph.set_entry_point("agent")
    # Decide whether to use a tool or end the process
    graph.add_conditional_edges(
        "agent",
        parse,
        {
            "use_tool": "tools",
            "end": "finish",
        },
    )
    # It ends after generating a response with context
    graph.add_edge("tools", "finish")
    graph.add_edge("finish", END)
    # Compile
    compiled_graph = graph.compile()
    return compiled_graph