File size: 3,605 Bytes
d46cc41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from utils_app import _update_session, supabase_client, _get_session_messages, _add_footnote_description
from supabase_memory import SupabaseChatMessageHistory
from graph import _get_graph
from langchain_core.messages import AIMessage, HumanMessage
import os
import pandas as pd
from prompts import _AGENT_SYSTEM_TEMPLATE, _ANSWERER_SYSTEM_TEMPLATE

async def _run_graph(
        session_id:str,
        input:str,
        agent_model_name:str = "gpt-4o",
        agent_temperature:float = 0.0, 
        answerer_model_name:str = "claude-3-5-sonnet-20240620", 
        answerer_temperature:float = 0.0, 
        collection_index:int = 0, 
        use_doctrines:bool = True, 
        search_type:str = "similarity", 
        k:int = 10, 
        similarity_threshold:float = 0.65, 
        agent_system_prompt_template:str = _AGENT_SYSTEM_TEMPLATE, 
        answerer_system_prompt_template:str = _ANSWERER_SYSTEM_TEMPLATE,
    ) :
    """Run one user turn through the agent/answerer graph and persist the exchange.

    The query and a placeholder answer are written to the session's message
    store up front; on success the placeholder is overwritten with the graph's
    answer (footnotes and retrieval metadata attached), and on failure the
    error is logged onto the placeholder instead.

    Args:
        session_id: Identifier of the chat session to read from and write to.
        input: The user's raw query text for this turn.
        agent_model_name / agent_temperature: LLM settings for the agent node.
        answerer_model_name / answerer_temperature: LLM settings for the
            answerer node.
        collection_index: Index of the document collection to search.
        use_doctrines: Whether doctrine documents participate in retrieval.
        search_type: Retrieval strategy name (e.g. "similarity").
        k: Number of documents to retrieve.
        similarity_threshold: Minimum similarity score for retrieved docs.
        agent_system_prompt_template / answerer_system_prompt_template:
            System prompt templates for the two graph nodes.

    Returns:
        The full session transcript from ``_get_session_messages``; on error,
        the transcript plus one extra ``(query, error text)`` pair.
    """
    # Supabase-backed message store for this session.
    history = SupabaseChatMessageHistory(
        session_id = session_id,
        table_name = os.environ["MESSAGES_TABLE_NAME"],
        session_name = "chat",
        client = supabase_client,
    )

    # Record this run's full configuration on the session row so the
    # transcript can later be tied to the exact settings that produced it.
    run_config = {
        "agent_model_name": agent_model_name,
        "agent_temperature": agent_temperature,
        "answerer_model_name": answerer_model_name,
        "answerer_temperature": answerer_temperature,
        "collection_index": collection_index,
        "use_doctrines": use_doctrines,
        "search_type": search_type,
        "k": k,
        "similarity_threshold": similarity_threshold,
        "agent_system_prompt_template": agent_system_prompt_template,
        "answerer_system_prompt_template": answerer_system_prompt_template,
    }
    _update_session(session_id, metadata = run_config)

    # Build the graph with the same configuration.
    workflow = _get_graph(
        agent_model_name = agent_model_name,
        agent_system_template = agent_system_prompt_template,
        agent_temperature = agent_temperature,
        answerer_model_name = answerer_model_name,
        answerer_system_template = answerer_system_prompt_template,
        answerer_temperature = answerer_temperature,
        collection_index = collection_index,
        use_doctrines = use_doctrines,
        search_type = search_type,
        similarity_threshold = similarity_threshold,
        k = k,
    )

    # Snapshot the conversation BEFORE this turn is appended, so the graph
    # does not see the current query duplicated in its chat history.
    prior_messages = history.messages

    # Persist the query now, plus an empty AI placeholder linked to it;
    # the placeholder is filled in (or error-annotated) below.
    query_msg_id = history.add_message(
        message = HumanMessage(input)
    )
    answer_msg_id = history.add_message(
        message = AIMessage(""),
        query_id = query_msg_id
    )

    try:
        state = await workflow.ainvoke(
            input = {
                "query": input,
                "chat_history": prior_messages,
            }
        )

        result = state["response"]
        answer = result["answer"]

        # Attach retrieval provenance: docs arrive as (document, score)
        # pairs — keep only each document's metadata.
        answer.response_metadata["docs"] = [pair[0].metadata for pair in result["docs"]]
        answer.response_metadata["standalone_question"] = result["standalone_question"]

        # Expand footnote markers in the answer using the doc metadata.
        answer.content = _add_footnote_description(answer.content, answer.response_metadata["docs"])

        history.update_message(
            message = answer,
            message_id = answer_msg_id
        )

        return _get_session_messages(session_id)

    except Exception as exc:
        # Boundary handler: record the failure on the placeholder message
        # and surface it to the caller instead of raising.
        history.update_message(
            message_id = answer_msg_id,
            error_log = str(exc)
        )

        return _get_session_messages(session_id) + [(input, f"Oops! An error occurred: {str(exc)}")]