# NOTE(review): "Spaces: / Running / Running" below the title was Hugging Face
# Spaces page residue captured with the file, not Python code — preserved here
# as a comment so the module parses.
from utils_app import _update_session, supabase_client, _get_session_messages, _add_footnote_description | |
from supabase_memory import SupabaseChatMessageHistory | |
from graph import _get_graph | |
from langchain_core.messages import AIMessage, HumanMessage | |
import os | |
import pandas as pd | |
from prompts import _AGENT_SYSTEM_TEMPLATE, _ANSWERER_SYSTEM_TEMPLATE | |
async def _run_graph(
    session_id:str,
    input:str,
    agent_model_name:str = "gpt-4o",
    agent_temperature:float = 0.0,
    answerer_model_name:str = "claude-3-5-sonnet-20240620",
    answerer_temperature:float = 0.0,
    collection_index:int = 0,
    use_doctrines:bool = True,
    search_type:str = "similarity",
    k:int = 10,
    similarity_threshold:float = 0.65,
    agent_system_prompt_template:str = _AGENT_SYSTEM_TEMPLATE,
    answerer_system_prompt_template:str = _ANSWERER_SYSTEM_TEMPLATE,
) :
    """Run one user turn through the agent/answerer graph for a chat session.

    Persists the user question and a placeholder assistant message to the
    Supabase-backed history, invokes the retrieval graph, then updates the
    placeholder with the final answer (including retrieved-doc metadata and
    footnote descriptions). On failure the placeholder is annotated with an
    error log and a user-facing error tuple is appended to the returned
    transcript instead of raising.

    Args:
        session_id: Identifier of the chat session to read/write.
        input: The user's question for this turn. (Name shadows the builtin,
            but is kept for backward compatibility with keyword callers.)
        agent_model_name / agent_temperature: LLM settings for the agent step.
        answerer_model_name / answerer_temperature: LLM settings for the
            answering step.
        collection_index: Which document collection the retriever targets —
            TODO confirm semantics against _get_graph.
        use_doctrines: Whether doctrine documents are included in retrieval.
        search_type: Retrieval mode passed through to the graph
            (e.g. "similarity").
        k: Number of documents to retrieve.
        similarity_threshold: Minimum similarity score for retrieved docs.
        agent_system_prompt_template / answerer_system_prompt_template:
            System prompt templates for the two LLM roles.

    Returns:
        The session transcript from _get_session_messages, with an extra
        (input, error-text) tuple appended when the run failed.
    """
    # Persistent per-session chat history backed by Supabase.
    memory = SupabaseChatMessageHistory(
        session_id = session_id,
        table_name = os.environ["MESSAGES_TABLE_NAME"],
        session_name = "chat",
        client = supabase_client,
    )
    # Record the full run configuration on the session row for auditability.
    _update_session(
        session_id,
        metadata = {
            "agent_model_name": agent_model_name,
            "agent_temperature": agent_temperature,
            "answerer_model_name": answerer_model_name,
            "answerer_temperature": answerer_temperature,
            "collection_index": collection_index,
            "use_doctrines": use_doctrines,
            "search_type": search_type,
            "k": k,
            "similarity_threshold": similarity_threshold,
            "agent_system_prompt_template": agent_system_prompt_template,
            "answerer_system_prompt_template": answerer_system_prompt_template,
        }
    )
    graph = _get_graph(
        agent_model_name = agent_model_name,
        agent_system_template = agent_system_prompt_template,
        agent_temperature = agent_temperature,
        answerer_model_name = answerer_model_name,
        answerer_system_template = answerer_system_prompt_template,
        answerer_temperature = answerer_temperature,
        collection_index = collection_index,
        use_doctrines = use_doctrines,
        search_type = search_type,
        similarity_threshold = similarity_threshold,
        k = k,
    )
    # Snapshot the history BEFORE persisting this turn, so the graph does not
    # see the current question duplicated in chat_history.
    chat_history = memory.messages
    input_message_id = memory.add_message(
        message = HumanMessage(input)
    )
    # Placeholder assistant row, updated in place once the graph finishes
    # (or annotated with an error log if it fails).
    output_message_id = memory.add_message(
        message = AIMessage(""),
        query_id = input_message_id
    )
    try:
        final_state = await graph.ainvoke(
            input = {
                "query": input,
                "chat_history": chat_history,
            }
        )
        response_message = final_state["response"]["answer"]
        # Keep retrieval provenance alongside the message: doc metadata and
        # the standalone reformulation of the user's question.
        # NOTE(review): doc[0] assumes each entry is a (Document, score) pair
        # — confirm against the graph's "docs" output shape.
        response_message.response_metadata["docs"] = [doc[0].metadata for doc in final_state["response"]["docs"]]
        response_message.response_metadata["standalone_question"] = final_state["response"]["standalone_question"]
        response_message.content = _add_footnote_description(response_message.content, response_message.response_metadata["docs"])
        memory.update_message(
            message = response_message,
            message_id = output_message_id
        )
    except Exception as e:
        # Boundary handler: attach the failure to the placeholder row and
        # surface a friendly message instead of crashing the caller/UI.
        memory.update_message(
            message_id = output_message_id,
            error_log = str(e)
        )
        return _get_session_messages(session_id) + [(input, f"Oops! An error occurred: {str(e)}")]
    # BUG FIX: the final transcript read now happens OUTSIDE the try block.
    # Previously a failure in _get_session_messages AFTER a successful save
    # would fall into the except handler and overwrite the already-saved
    # answer's row with an error_log.
    return _get_session_messages(session_id)