Spaces:
Running
Running
import logging
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional

import pytz
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)
from langchain_core.utils import get_from_env
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# The supabase ``Client`` type is only referenced in (string) annotations,
# so it is imported for static type checkers only, never at runtime.
if TYPE_CHECKING:
    from supabase import Client
class SupabaseChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a Supabase project database.

    Each message is one row of ``table_name``. Rows belonging to this history
    are selected by the ``f"{session_name}_id"`` column, which must equal
    ``session_id``. Besides that session column, this class reads or writes
    the columns ``id``, ``query_id``, ``message``, ``error_log``,
    ``created_at`` and ``updated_at``.
    """

    def __init__(
        self,
        session_id: str,
        table_name: str = "message_store",
        session_name: str = "session",
        client: Optional["Client"] = None,
        supabase_url: Optional[str] = None,
        supabase_key: Optional[str] = None,
    ):
        """Create a history bound to a single session.

        Args:
            session_id: Value matched against the ``{session_name}_id``
                column. Must be non-empty.
            table_name: Name of the table holding the messages.
            session_name: Prefix of the session column; the column used is
                ``f"{session_name}_id"``.
            client: An existing Supabase client. When given, ``supabase_url``
                and ``supabase_key`` are not used.
            supabase_url: Project URL used to build a client when ``client``
                is None. Falls back to the ``SUPABASE_URL`` environment
                variable when not given.
            supabase_key: API key used to build a client when ``client`` is
                None. Falls back to the ``SUPABASE_KEY`` environment variable
                when not given.

        Raises:
            ImportError: If the ``supabase`` package is not installed.
            ValueError: If ``session_id`` is empty, or if no client is given
                and the URL or key is found neither in the arguments nor in
                the environment (raised by ``get_from_env``).
        """
        try:
            from supabase import create_client
        except ImportError as err:
            raise ImportError(
                "Could not import supabase python package. "
                "Please install it with `pip install supabase`."
            ) from err

        # Make sure session id is not null
        if not session_id:
            raise ValueError("Please ensure that the session_id parameter is provided")

        if client is None:
            # Explicit arguments must take precedence over the environment.
            # get_from_env() consults the environment variable *before* its
            # default, so passing the argument as the default would let
            # SUPABASE_URL / SUPABASE_KEY silently override it.
            supabase_url = supabase_url or get_from_env("url", "SUPABASE_URL")
            supabase_key = supabase_key or get_from_env("key", "SUPABASE_KEY")
            client = create_client(
                supabase_url=supabase_url,
                supabase_key=supabase_key,
            )

        self.client = client
        self.session_id = session_id
        self.table_name = table_name
        self.session_name = session_name

    def messages(self) -> List[BaseMessage]:
        """Return this session's messages, oldest first, minus failed exchanges.

        A row counts as failed when its message content is the empty string
        or its ``error_log`` is not None. A failed row is dropped together
        with the row whose ``id`` equals the failed row's ``query_id``.

        NOTE(review): ``BaseChatMessageHistory`` declares ``messages`` as an
        attribute, but here it is a regular method, so it must be called as
        ``history.messages()``; reading ``history.messages`` yields a bound
        method, not a list. Kept as a method for compatibility with existing
        callers.
        """
        response = (
            self.client.table(self.table_name)
            .select("id", "query_id", "message", "error_log")
            .eq(f"{self.session_name}_id", self.session_id)
            .order("created_at", desc=False)
            .execute()
        )
        records = response.data

        # A set keeps the per-row membership test below O(1).
        failed_ids = set()
        for record in records:
            is_failed = (
                record["message"]["data"]["content"] == ""
                or record["error_log"] is not None
            )
            if is_failed:
                # Drop the failed row and the row it answers (query_id).
                failed_ids.add(record["id"])
                failed_ids.add(record["query_id"])

        items = [record["message"] for record in records if record["id"] not in failed_ids]
        return messages_from_dict(items)

    def add_message(self, message: BaseMessage, query_id: Optional[str] = None) -> Any:
        """Insert ``message`` as a new row of this session.

        Args:
            message: The message to store, serialized with ``message_to_dict``.
            query_id: Optional ``id`` of a related row, stored in the
                ``query_id`` column. ``messages()`` drops that related row as
                well when this one is considered failed.

        Returns:
            The ``id`` of the inserted row, taken from the insert response.
        """
        response = self.client.table(self.table_name).insert(
            {
                f"{self.session_name}_id": self.session_id,
                "message": message_to_dict(message),
                "query_id": query_id,
            }
        ).execute()
        return response.data[0]["id"]

    def update_message(
        self,
        message_id: str,
        message: Optional[BaseMessage] = None,
        error_log: Optional[dict] = None,
    ) -> None:
        """Update an existing row in place.

        ``updated_at`` is always set to the current UTC time; ``message`` and
        ``error_log`` are only written when they are not None.

        NOTE(review): the row is selected by ``id`` alone, not additionally by
        this history's session column.

        Args:
            message_id: Value of the ``id`` column of the row to update.
            message: Replacement message, serialized with ``message_to_dict``.
            error_log: Value stored in the ``error_log`` column. A non-None
                value makes ``messages()`` treat the row as failed.
        """
        updated_dict = {"updated_at": datetime.now(pytz.utc).isoformat()}
        if message is not None:
            updated_dict["message"] = message_to_dict(message)
        if error_log is not None:
            updated_dict["error_log"] = error_log

        (
            self.client.table(self.table_name)
            .update(updated_dict)
            .eq("id", message_id)
            .execute()
        )

    def clear(self) -> None:
        """Delete every row belonging to this session from the table."""
        (
            self.client.table(self.table_name)
            .delete()
            .eq(f"{self.session_name}_id", self.session_id)
            .execute()
        )