File size: 3,987 Bytes
d46cc41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import logging
from typing import List, TYPE_CHECKING, Optional
from datetime import datetime
import pytz

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)
from langchain_core.utils import get_from_env

if TYPE_CHECKING:
    from supabase import Client

logger = logging.getLogger(__name__)

class SupabaseChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a Supabase project database.

    Each message is one row of ``table_name``. The code reads and writes the
    columns ``id``, ``query_id``, ``message``, ``error_log``, ``created_at``,
    ``updated_at`` and ``<session_name>_id`` (the session foreign key).
    """

    def __init__(
        self,
        session_id: str,
        table_name: str = "message_store",
        session_name: str = "session",
        client: Optional['Client'] = None,
        supabase_url: Optional[str] = None,
        supabase_key: Optional[str] = None,
    ):
        """Bind the history to one session, creating a client if needed.

        Args:
            session_id: Value matched against the ``<session_name>_id`` column.
            table_name: Table that holds the message rows.
            session_name: Prefix of the session column (``<session_name>_id``).
            client: Ready-to-use Supabase client. When given, ``supabase_url``
                and ``supabase_key`` are ignored.
            supabase_url: Project URL; falls back to ``SUPABASE_URL``.
            supabase_key: Project key; falls back to ``SUPABASE_KEY``.

        Raises:
            ValueError: If ``session_id`` is empty, or if a credential is
                neither passed in nor present in the environment.
            ImportError: If a client has to be created and the ``supabase``
                package is not installed.
        """
        # Make sure session id is not null
        if not session_id:
            raise ValueError("Please ensure that the session_id parameter is provided")

        if client is None:
            # The package is only needed when we build the client ourselves.
            try:
                from supabase import create_client
            except ImportError as exc:
                raise ImportError(
                    "Could not import supabase python package. "
                    "Please install it with `pip install supabase`."
                ) from exc

            # Explicit arguments take precedence. get_from_env() prefers the
            # environment over its `default`, so passing the argument as the
            # default would let the env var silently override the caller.
            supabase_url = supabase_url or get_from_env("url", "SUPABASE_URL")
            supabase_key = supabase_key or get_from_env("key", "SUPABASE_KEY")

            client = create_client(
                supabase_url=supabase_url,
                supabase_key=supabase_key,
            )

        self.client = client
        self.session_id = session_id
        self.table_name = table_name
        self.session_name = session_name

    @property
    def messages(self) -> List[BaseMessage]:
        """Retrieve the session's messages, oldest first.

        A row counts as failed when its message content is an empty string or
        its ``error_log`` is set. Failed rows are dropped, and so is the row
        whose ``id`` equals the failed row's ``query_id``.
        """
        response = (
            self.client.table(self.table_name)
            .select("id", "query_id", "message", "error_log")
            .eq(f"{self.session_name}_id", self.session_id)
            .order("created_at", desc=False)
            .execute()
        )
        records = response.data

        # A set keeps the membership test below O(1) per row.
        failed_ids = set()
        for record in records:
            if (
                record["message"]["data"]["content"] == ""
                or record["error_log"] is not None
            ):
                failed_ids.add(record["id"])
                failed_ids.add(record["query_id"])

        return messages_from_dict(
            [
                record["message"]
                for record in records
                if record["id"] not in failed_ids
            ]
        )

    def add_message(self, message: BaseMessage, query_id: Optional[str] = None) -> str:
        """Insert the message as a new row for this session.

        Args:
            message: Message to store (serialized with ``message_to_dict``).
            query_id: Optional id of a related row, stored in ``query_id``.

        Returns:
            The ``id`` of the inserted row.
            NOTE(review): annotated ``str`` to match ``update_message``'s
            ``message_id``; confirm against the table's id column type.
        """
        response = self.client.table(self.table_name).insert(
            {
                f"{self.session_name}_id": self.session_id,
                "message": message_to_dict(message),
                "query_id": query_id,
            }
        ).execute()

        return response.data[0]["id"]

    def update_message(
        self,
        message_id: str,
        message: Optional[BaseMessage] = None,
        error_log: Optional[dict] = None,
    ) -> None:
        """Update an existing message row in the Supabase project database.

        ``updated_at`` is always set to the current UTC time; ``message`` and
        ``error_log`` are only written when provided.

        Args:
            message_id: ``id`` of the row to update.
            message: Replacement message, if any.
            error_log: Error details to attach to the row, if any.
        """
        updated_dict = {
            "updated_at": datetime.now(pytz.utc).isoformat()
        }

        if message is not None:
            updated_dict["message"] = message_to_dict(message)

        if error_log is not None:
            updated_dict["error_log"] = error_log

        (
            self.client.table(self.table_name)
            .update(updated_dict)
            .eq("id", message_id)
            .execute()
        )

    def clear(self) -> None:
        """Clear session memory from the Supabase project database"""
        (
            self.client.table(self.table_name)
            .delete()
            .eq(f"{self.session_name}_id", self.session_id)
            .execute()
        )