Spaces:
Runtime error
Runtime error
File size: 5,745 Bytes
35b22df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
"""SQL Container builder."""
from typing import Any, Dict, List, Optional, Type, Union
from gpt_index.indices.base import BaseGPTIndex
from gpt_index.indices.common.struct_store.base import SQLDocumentContextBuilder
from gpt_index.indices.common.struct_store.schema import SQLContextContainer
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
from gpt_index.readers.base import Document
from gpt_index.schema import BaseDocument
# Default prompt template for asking an index which tables matter for a query.
# ``{orig_query_str}`` is the single placeholder, filled in via ``str.format``.
DEFAULT_CONTEXT_QUERY_TMPL = (
    "Please return the relevant tables "
    "(including the full schema) for the following query: "
    "{orig_query_str}"
)
class SQLContextContainerBuilder:
    """SQLContextContainerBuilder.

    Build a SQLContextContainer that can be passed to the SQL index
    during index construction or during query-time.

    NOTE: if context_str is specified, that will be used as context
    instead of context_dict

    Args:
        sql_database (SQLDatabase): SQL database
        context_dict (Optional[Dict[str, str]]): context dict, keyed by
            table name
        context_str (Optional[str]): context string

    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        context_dict: Optional[Dict[str, str]] = None,
        context_str: Optional[str] = None,
    ):
        """Initialize params.

        Raises:
            ValueError: if context_dict contains keys that are not table
                names of sql_database.

        """
        self.sql_database = sql_database
        # if context_dict provided, validate that all keys are valid table names
        if context_dict is not None:
            # query the table names once and reuse them for the error message
            table_names = set(self.sql_database.get_table_names())
            invalid_names = set(context_dict.keys()) - table_names
            if invalid_names:
                raise ValueError(f"Invalid context table names: {invalid_names}")
        self.context_dict = context_dict or {}
        # build full context from sql_database
        self.full_context_dict = self._build_context_from_sql_database(
            self.sql_database, current_context=self.context_dict
        )
        self.context_str = context_str

    @classmethod
    def from_documents(
        cls,
        documents_dict: Dict[str, List[BaseDocument]],
        sql_database: SQLDatabase,
        **context_builder_kwargs: Any,
    ) -> "SQLContextContainerBuilder":
        """Build context from documents.

        Args:
            documents_dict (Dict[str, List[BaseDocument]]): documents to
                derive context from, passed to SQLDocumentContextBuilder
            sql_database (SQLDatabase): SQL database
            **context_builder_kwargs (Any): extra keyword arguments for
                SQLDocumentContextBuilder

        Returns:
            SQLContextContainerBuilder: a builder (of type ``cls``) whose
                context_dict is the context built from the documents.

        """
        context_builder = SQLDocumentContextBuilder(
            sql_database, **context_builder_kwargs
        )
        context_dict = context_builder.build_all_context_from_documents(documents_dict)
        # instantiate via cls so that subclasses get an instance of themselves
        return cls(sql_database, context_dict=context_dict)

    def _build_context_from_sql_database(
        self,
        sql_database: SQLDatabase,
        current_context: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """Get table schema + optional context, as one string per table.

        Returns:
            Dict[str, str]: mapping of every table name in sql_database to
                its schema text, followed by the table's entry from
                current_context when there is one.

        """
        current_context = current_context or {}
        result_context = {}
        for table_name in sql_database.get_table_names():
            table_desc = sql_database.get_single_table_info(table_name)
            table_text = f"Schema of table {table_name}:\n" f"{table_desc}\n"
            if table_name in current_context:
                table_text += f"Context of table {table_name}:\n"
                table_text += current_context[table_name]
            result_context[table_name] = table_text
        return result_context

    def _get_context_dict(self, ignore_db_schema: bool) -> Dict[str, str]:
        """Get full context dict.

        Returns the user-supplied context only if ignore_db_schema is True,
        otherwise the context that also includes each table's schema.

        """
        if ignore_db_schema:
            return self.context_dict
        else:
            return self.full_context_dict

    def derive_index_from_context(
        self,
        index_cls: Type[BaseGPTIndex],
        ignore_db_schema: bool = False,
        **index_kwargs: Any,
    ) -> BaseGPTIndex:
        """Derive index from context.

        Args:
            index_cls (Type[BaseGPTIndex]): index class to instantiate
            ignore_db_schema (bool): if True, only use the user-supplied
                context and leave out the table schemas
            **index_kwargs (Any): extra keyword arguments for index_cls

        Returns:
            BaseGPTIndex: index built over one document per table, each
                tagged with its table name in extra_info.

        """
        full_context_dict = self._get_context_dict(ignore_db_schema)
        context_docs = [
            Document(context_str, extra_info={"table_name": table_name})
            for table_name, context_str in full_context_dict.items()
        ]
        index = index_cls(
            documents=context_docs,
            **index_kwargs,
        )
        return index

    def query_index_for_context(
        self,
        index: BaseGPTIndex,
        query_str: Union[str, QueryBundle],
        query_tmpl: Optional[str] = DEFAULT_CONTEXT_QUERY_TMPL,
        store_context_str: bool = True,
        **index_kwargs: Any,
    ) -> str:
        """Query index for context.

        A simple wrapper around the index.query call which
        injects a query template to specifically fetch table information,
        and can store a context_str.

        Args:
            index (BaseGPTIndex): index data structure
            query_str (Union[str, QueryBundle]): query string
            query_tmpl (Optional[str]): query template. If None, query_str
                is passed to the index unchanged.
            store_context_str (bool): store context_str
            **index_kwargs (Any): extra keyword arguments for index.query

        Returns:
            str: the ``response`` attribute of the query result, as a string.

        """
        context_query_str: Union[str, QueryBundle]
        if query_tmpl is None:
            context_query_str = query_str
        else:
            # Formatting a QueryBundle directly would embed its repr in the
            # prompt, so template on the bundle's query text instead. Only
            # the text is forwarded in this case (a plain string is queried).
            if isinstance(query_str, QueryBundle):
                orig_query_str = query_str.query_str
            else:
                orig_query_str = query_str
            context_query_str = query_tmpl.format(orig_query_str=orig_query_str)
        response = index.query(context_query_str, **index_kwargs)
        context_str = str(response.response)
        if store_context_str:
            self.context_str = context_str
        return context_str

    def build_context_container(
        self, ignore_db_schema: bool = False
    ) -> SQLContextContainer:
        """Build the SQLContextContainer.

        Args:
            ignore_db_schema (bool): if True, only use the user-supplied
                context and leave out the table schemas

        Returns:
            SQLContextContainer: container holding the current context_str
                and the selected context dict.

        """
        full_context_dict = self._get_context_dict(ignore_db_schema)
        return SQLContextContainer(
            context_str=self.context_str,
            context_dict=full_context_dict,
        )
|