File size: 5,190 Bytes
d46cc41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from typing import List, Union, Dict, Tuple
from typing_extensions import NotRequired, TypedDict

from langchain_core.agents import (
    AgentAction, 
    AgentFinish
)
from langchain_core.messages import BaseMessage, AIMessage
from langgraph.graph import END, StateGraph

from runnables import Answerer, Agent
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE

# State key under which the graph nodes store the final response.
# It must stay in sync with the `response` field declared on `GraphState`.
OUTPUT_KEY = "response"

def _get_graph(
        agent_model_name: str = "gpt-4-turbo",
        agent_system_template: str = _AGENT_SYSTEM_TEMPLATE,
        agent_temperature: float = 0.0,
        answerer_model_name: str = "gpt-4-turbo",
        answerer_system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
        answerer_temperature: float = 0.0,
        collection_index: int = 0,
        use_doctrines: bool = True,
        search_type: str = "similarity",
        similarity_threshold: float = 0.0,
        k: int = 15,
    ):
    """
    Build and compile the agent -> (tool) -> finish state graph.

    The graph starts at the ``agent`` node. Depending on the agent outcome it
    either goes straight to ``finish`` or first runs the ``tools`` node (the
    answerer runnable) and then ``finish``.

    Args:
        agent_model_name: Model name forwarded to ``Agent``.
        agent_system_template: System template forwarded to ``Agent``.
        agent_temperature: Temperature forwarded to ``Agent``.
        answerer_model_name: Model name forwarded to ``Answerer``.
        answerer_system_template: System template forwarded to ``Answerer``.
        answerer_temperature: Temperature forwarded to ``Answerer``.
        collection_index: Forwarded unchanged to ``Answerer``.
        use_doctrines: Forwarded unchanged to ``Answerer``.
        search_type: Forwarded unchanged to ``Answerer``.
        similarity_threshold: Forwarded unchanged to ``Answerer``.
        k: Forwarded unchanged to ``Answerer``.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """

    agent = Agent(
        model_name=agent_model_name,
        system_template=agent_system_template,
        temperature=agent_temperature,
    )
    agent_runnable = agent.runnable

    answerer = Answerer(
        model_name=answerer_model_name,
        system_template=answerer_system_template,
        temperature=answerer_temperature,
        collection_index=collection_index,
        use_doctrines=use_doctrines,
        search_type=search_type,
        similarity_threshold=similarity_threshold,
        k=k,
    )
    answerer_runnable = answerer.runnable

    # GRAPH
    class GraphState(TypedDict):
        query: str

        # Either a final answer or the requested action(s); `parse` and
        # `execute_tool` handle the list form by looking at its first element.
        agent_outcome: NotRequired[
            Union[AgentAction, AgentFinish, List[AgentAction]]
        ]

        chat_history: List[BaseMessage]

        response: NotRequired[
            Dict[
                str,
                Union[
                    str,
                    List[int],
                    List[Dict[str, Union[int, str]]]
                ]
            ]
        ]

    ## Nodes' functions
    async def execute_agent(
            state: GraphState,
            config: Dict,
        ) -> Dict:
        """
        Invoke the agent runnable on a copy of the current state.

        Args:
            state (GraphState): The current state of the graph.
            config (Dict): Run configuration passed through to the runnable.

        Returns:
            dict: State update holding the new outcome under "agent_outcome".
        """

        inputs = state.copy()

        agent_outcome = await (
            agent_runnable
            .with_config({"run_name": "agent_node"})
            .ainvoke(inputs, config=config)
        )

        return {"agent_outcome": agent_outcome}

    def execute_tool(
            state: GraphState,
            config: Dict,
        ) -> Dict:
        """
        Run the answerer runnable for the first action requested by the agent.

        Args:
            state (GraphState): The current state of the graph. Its
                "agent_outcome" is expected to be a list whose first element
                carries a "standalone_question" tool input.
            config (Dict): Run configuration passed through to the runnable.

        Returns:
            dict: State update holding the answerer output under OUTPUT_KEY.
        """

        inputs = state["agent_outcome"][0].tool_input

        tool_output = answerer_runnable.invoke(
            {"query": inputs["standalone_question"]},
            config=config,
        )

        return {OUTPUT_KEY: tool_output}

    def finish(
            state: GraphState
        ) -> Dict:
        """
        Produce the final response of the graph.

        If a previous node already stored a response it is passed through
        unchanged. Otherwise the agent finished directly, and the response is
        built from the agent's final output with no documents attached.

        Args:
            state (GraphState): The current state of the graph.

        Returns:
            dict: State update holding the final response under OUTPUT_KEY.
        """

        # `response` is NotRequired: when the agent finishes without calling
        # the tool the key may be missing (or None), so avoid a KeyError.
        response = state.get(OUTPUT_KEY)

        if response is None:
            response = {
                "answer": AIMessage(state['agent_outcome'].return_values['output']),
                "docs": [],
                "standalone_question": None
            }

        return {OUTPUT_KEY: response}

    ## Edges' functions
    def parse(
            state: GraphState
        ) -> str:
        """
        Router based on the previous agent outcome.

        Returns "end" when the agent produced an AgentFinish, and "use_tool"
        when it produced a list whose first action names a tool.

        Args:
            state (GraphState): The current state of the graph.

        Returns:
            str: Either "end" or "use_tool".

        Raises:
            ValueError: If the agent outcome matches neither case, since there
                is no edge to follow.
        """

        agent_outcome = state["agent_outcome"]

        if isinstance(agent_outcome, AgentFinish):
            return "end"

        if (
            isinstance(agent_outcome, list)
            and agent_outcome
            and agent_outcome[0].tool is not None
        ):
            return "use_tool"

        # Fail loudly instead of implicitly returning None, which is not a
        # key of the conditional-edge mapping below.
        raise ValueError(
            f"Unexpected agent outcome, cannot route: {agent_outcome!r}"
        )

    ## Graph
    graph = StateGraph(GraphState)

    # Define the nodes
    graph.add_node("agent", execute_agent)
    graph.add_node("tools", execute_tool)
    graph.add_node("finish", finish)

    # Start from the agent node
    graph.set_entry_point("agent")

    # Decide whether to use a tool or end the process
    graph.add_conditional_edges(
        "agent",
        parse,
        {
            "use_tool": "tools",
            "end": "finish",
        },
    )

    # It ends after generating a response with context
    graph.add_edge("tools", "finish")
    graph.add_edge("finish", END)

    # Compile
    compiled_graph = graph.compile()

    return compiled_graph