File size: 5,745 Bytes
35b22df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""SQL Container builder."""


from typing import Any, Dict, List, Optional, Type, Union

from gpt_index.indices.base import BaseGPTIndex
from gpt_index.indices.common.struct_store.base import SQLDocumentContextBuilder
from gpt_index.indices.common.struct_store.schema import SQLContextContainer
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
from gpt_index.readers.base import Document
from gpt_index.schema import BaseDocument

# Default prompt template for asking an index which tables are relevant to a
# query. The ``{orig_query_str}`` placeholder is filled in with ``str.format``.
DEFAULT_CONTEXT_QUERY_TMPL = (
    "Please return the relevant tables "
    "(including the full schema) for the following query: "
    "{orig_query_str}"
)


class SQLContextContainerBuilder:
    """SQLContextContainerBuilder.

    Build a SQLContextContainer that can be passed to the SQL index
    during index construction or during query-time.

    NOTE: if context_str is specified, that will be used as context
    instead of context_dict

    Args:
        sql_database (SQLDatabase): SQL database
        context_dict (Optional[Dict[str, str]]): context dict, keyed by
            table name. Every key must be a table name of sql_database.
        context_str (Optional[str]): context string, stored as-is and
            handed to the container in build_context_container.

    Raises:
        ValueError: if context_dict has keys that are not table names
            of sql_database.

    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        context_dict: Optional[Dict[str, str]] = None,
        context_str: Optional[str] = None,
    ):
        """Initialize params."""
        self.sql_database = sql_database

        # if context_dict provided, validate that all keys are valid table names
        if context_dict is not None:
            # compute the offending keys once; used for both check and message
            unknown_tables = set(context_dict) - set(
                self.sql_database.get_table_names()
            )
            if unknown_tables:
                raise ValueError(f"Invalid context table names: {unknown_tables}")
        self.context_dict = context_dict or {}
        # build full context (schema + optional user context) from sql_database
        self.full_context_dict = self._build_context_from_sql_database(
            self.sql_database, current_context=self.context_dict
        )
        self.context_str = context_str

    @classmethod
    def from_documents(
        cls,
        documents_dict: Dict[str, List[BaseDocument]],
        sql_database: SQLDatabase,
        **context_builder_kwargs: Any,
    ) -> "SQLContextContainerBuilder":
        """Build context from documents.

        Args:
            documents_dict (Dict[str, List[BaseDocument]]): documents that
                are handed to SQLDocumentContextBuilder to derive context.
            sql_database (SQLDatabase): SQL database
            **context_builder_kwargs (Any): forwarded to
                SQLDocumentContextBuilder.

        Returns:
            A builder (instance of the class this is called on) whose
            context_dict is the context derived from the documents.

        """
        context_builder = SQLDocumentContextBuilder(
            sql_database, **context_builder_kwargs
        )
        context_dict = context_builder.build_all_context_from_documents(documents_dict)
        # use cls (not the hard-coded class) so subclasses get their own type
        return cls(sql_database, context_dict=context_dict)

    def _build_context_from_sql_database(
        self,
        sql_database: SQLDatabase,
        current_context: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """Get each table's schema + optional context as a single string.

        Returns a dict with one entry per table of sql_database. Each value
        holds the table schema text, followed by the entry from
        current_context for that table when one exists.
        """
        current_context = current_context or {}
        result_context: Dict[str, str] = {}
        for table_name in sql_database.get_table_names():
            table_desc = sql_database.get_single_table_info(table_name)
            table_text = f"Schema of table {table_name}:\n{table_desc}\n"
            if table_name in current_context:
                table_text += f"Context of table {table_name}:\n"
                table_text += current_context[table_name]
            result_context[table_name] = table_text
        return result_context

    def _get_context_dict(self, ignore_db_schema: bool) -> Dict[str, str]:
        """Get context dict.

        Returns the user-provided context_dict when ignore_db_schema is
        True, otherwise the full context dict that includes table schemas.
        """
        return self.context_dict if ignore_db_schema else self.full_context_dict

    def derive_index_from_context(
        self,
        index_cls: Type[BaseGPTIndex],
        ignore_db_schema: bool = False,
        **index_kwargs: Any,
    ) -> BaseGPTIndex:
        """Derive index from context.

        Wraps every table's context string in a Document (tagged with the
        table name in extra_info) and builds an index of type index_cls
        over those documents.

        Args:
            index_cls (Type[BaseGPTIndex]): index class to instantiate
            ignore_db_schema (bool): if True, only index the user-provided
                context_dict and leave out the database schema text
            **index_kwargs (Any): forwarded to the index_cls constructor

        """
        context_by_table = self._get_context_dict(ignore_db_schema)
        context_docs = [
            Document(table_context, extra_info={"table_name": table_name})
            for table_name, table_context in context_by_table.items()
        ]
        return index_cls(
            documents=context_docs,
            **index_kwargs,
        )

    def query_index_for_context(
        self,
        index: BaseGPTIndex,
        query_str: Union[str, QueryBundle],
        query_tmpl: Optional[str] = DEFAULT_CONTEXT_QUERY_TMPL,
        store_context_str: bool = True,
        **index_kwargs: Any,
    ) -> str:
        """Query index for context.

        A simple wrapper around the index.query call which
        injects a query template to specifically fetch table information,
        and can store a context_str.

        Args:
            index (BaseGPTIndex): index data structure
            query_str (Union[str, QueryBundle]): query string
            query_tmpl (Optional[str]): query template containing an
                ``{orig_query_str}`` placeholder. If None, query_str is
                passed to the index unchanged.
            store_context_str (bool): store context_str
            **index_kwargs (Any): forwarded to index.query

        Returns:
            The response text of the index query, as a string.

        """
        context_query_str: Union[str, QueryBundle]
        if query_tmpl is None:
            context_query_str = query_str
        else:
            # Interpolate the raw query text; formatting a QueryBundle
            # directly would embed its dataclass repr in the prompt.
            if isinstance(query_str, QueryBundle):
                orig_query_str = query_str.query_str
            else:
                orig_query_str = query_str
            context_query_str = query_tmpl.format(orig_query_str=orig_query_str)
        response = index.query(context_query_str, **index_kwargs)
        context_str = str(response.response)
        if store_context_str:
            self.context_str = context_str
        return context_str

    def build_context_container(
        self, ignore_db_schema: bool = False
    ) -> SQLContextContainer:
        """Build the SQLContextContainer.

        The container receives the currently stored context_str together
        with the context dict selected by ignore_db_schema.
        """
        return SQLContextContainer(
            context_str=self.context_str,
            context_dict=self._get_context_dict(ignore_db_schema),
        )