from typing import List, Union, Dict, Tuple
from typing_extensions import NotRequired, TypedDict
from langchain_core.agents import (
AgentAction,
AgentFinish
)
from langchain_core.messages import BaseMessage, AIMessage
from langgraph.graph import END, StateGraph
from runnables import Answerer, Agent
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE
OUTPUT_KEY = "response"
def _get_graph(
    agent_model_name: str = "gpt-4-turbo",
    agent_system_template: str = _AGENT_SYSTEM_TEMPLATE,
    agent_temperature: float = 0.0,
    answerer_model_name: str = "gpt-4-turbo",
    answerer_system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    answerer_temperature: float = 0.0,
    collection_index: int = 0,
    use_doctrines: bool = True,
    search_type: str = "similarity",
    similarity_threshold: float = 0.0,
    k: int = 15,
):
    """
    Build and compile the agent -> tool -> finish LangGraph workflow.

    The graph starts at the agent node. The agent either finishes directly
    (its plain answer is wrapped into the response schema by ``finish``) or
    routes to the retrieval tool, whose output becomes the final response.

    Args:
        agent_model_name: Chat model used by the routing agent.
        agent_system_template: System prompt for the agent.
        agent_temperature: Sampling temperature for the agent model.
        answerer_model_name: Chat model used by the answerer.
        answerer_system_template: System prompt for the answerer.
        answerer_temperature: Sampling temperature for the answerer model.
        collection_index: Index of the retrieval collection to query.
        use_doctrines: Whether the answerer should include doctrines.
        search_type: Retriever search strategy (e.g. "similarity").
        similarity_threshold: Minimum similarity score for retrieved docs.
        k: Number of documents to retrieve.

    Returns:
        The compiled LangGraph, ready to be invoked.
    """
    agent = Agent(
        model_name=agent_model_name,
        system_template=agent_system_template,
        temperature=agent_temperature,
    )
    agent_runnable = agent.runnable

    answerer = Answerer(
        model_name=answerer_model_name,
        system_template=answerer_system_template,
        temperature=answerer_temperature,
        collection_index=collection_index,
        use_doctrines=use_doctrines,
        search_type=search_type,
        similarity_threshold=similarity_threshold,
        k=k,
    )
    answerer_runnable = answerer.runnable

    # GRAPH
    class GraphState(TypedDict):
        query: str
        # Outcome of the last agent step; absent before the first step runs.
        agent_outcome: NotRequired[
            Union[AgentAction, AgentFinish]
        ]
        chat_history: List[BaseMessage]
        # Final payload; absent until the tool or the finish node sets it.
        response: NotRequired[
            Dict[
                str,
                Union[
                    str,
                    List[int],
                    List[Dict[str, Union[int, str]]]
                ]
            ]
        ]

    ## Nodes' functions
    async def execute_agent(
        state: GraphState,
        config: Dict,
    ) -> Dict[str, Union[AgentAction, AgentFinish]]:
        """
        Invoke the agent model on the current conversation state.

        Args:
            state: The current graph state.
            config: Runnable config forwarded to the model call.

        Returns:
            dict: State update carrying the new "agent_outcome".
        """
        inputs = state.copy()
        agent_outcome = await agent_runnable \
            .with_config({"run_name": "agent_node"}) \
            .ainvoke(inputs, config=config)
        return {"agent_outcome": agent_outcome}

    def execute_tool(
        state: GraphState,
        config: Dict,
    ) -> Dict[str, Dict]:
        """
        Execute the Retrieve tool with the agent's standalone question.

        Args:
            state: The current graph state (must contain "agent_outcome").
            config: Runnable config forwarded to the tool call.

        Returns:
            dict: State update carrying the tool output under OUTPUT_KEY.
        """
        outcome = state["agent_outcome"]
        # The agent may emit a bare AgentAction or a one-element list of
        # actions (see parse below); normalize before reading tool_input.
        action = outcome[0] if isinstance(outcome, list) else outcome
        tool_output = answerer_runnable.invoke(
            {"query": action.tool_input["standalone_question"]},
            config=config,
        )
        return {OUTPUT_KEY: tool_output}

    def finish(
        state: GraphState,
    ) -> Dict[str, Dict]:
        """
        Produce the final response.

        If the tool already produced a response, pass it through unchanged;
        otherwise wrap the agent's direct answer in the response schema.
        """
        # "response" is NotRequired: when the agent finishes without calling
        # the tool, the key is absent entirely — .get() avoids the KeyError
        # that plain indexing would raise on that path.
        response = state.get(OUTPUT_KEY)
        if response is None:
            response = {
                "answer": AIMessage(state['agent_outcome'].return_values['output']),
                "docs": [],
                "standalone_question": None,
            }
        return {OUTPUT_KEY: response}

    ## Edges' functions
    def parse(
        state: GraphState,
    ) -> str:
        """
        Router based on the previous agent outcome.

        Checks the agent outcome to determine whether the agent decided to
        finish the conversation; in that case the process ends, otherwise
        the retrieval tool is called.

        Args:
            state: The current graph state.

        Returns:
            str: "end" to finish, "use_tool" to call the retrieval tool.

        Raises:
            ValueError: If the outcome matches neither routing case.
        """
        agent_outcome = state["agent_outcome"]
        if isinstance(agent_outcome, AgentFinish):
            return "end"
        if isinstance(agent_outcome, list):
            agent_outcome = agent_outcome[0]
        if agent_outcome.tool is not None:
            return "use_tool"
        # The original fell through returning None, which would crash the
        # conditional edge with an opaque lookup error; fail loudly instead.
        raise ValueError(f"Unroutable agent outcome: {agent_outcome!r}")

    ## Graph
    graph = StateGraph(GraphState)
    # Define the nodes
    graph.add_node("agent", execute_agent)
    graph.add_node("tools", execute_tool)
    graph.add_node("finish", finish)
    # Start from the agent node
    graph.set_entry_point("agent")
    # Decide whether to use a tool or end the process
    graph.add_conditional_edges(
        "agent",
        parse,
        {
            "use_tool": "tools",
            "end": "finish",
        },
    )
    # It ends after generating a response with context
    graph.add_edge("tools", "finish")
    graph.add_edge("finish", END)
    # Compile
    compiled_graph = graph.compile()
    return compiled_graph