"""SQL Container builder.""" from typing import Any, Dict, List, Optional, Type, Union from gpt_index.indices.base import BaseGPTIndex from gpt_index.indices.common.struct_store.base import SQLDocumentContextBuilder from gpt_index.indices.common.struct_store.schema import SQLContextContainer from gpt_index.indices.query.schema import QueryBundle from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase from gpt_index.readers.base import Document from gpt_index.schema import BaseDocument DEFAULT_CONTEXT_QUERY_TMPL = ( "Please return the relevant tables (including the full schema) " "for the following query: {orig_query_str}" ) class SQLContextContainerBuilder: """SQLContextContainerBuilder. Build a SQLContextContainer that can be passed to the SQL index during index construction or during queryt-time. NOTE: if context_str is specified, that will be used as context instead of context_dict Args: sql_database (SQLDatabase): SQL database context_dict (Optional[Dict[str, str]]): context dict """ def __init__( self, sql_database: SQLDatabase, context_dict: Optional[Dict[str, str]] = None, context_str: Optional[str] = None, ): """Initialize params.""" self.sql_database = sql_database # if context_dict provided, validate that all keys are valid table names if context_dict is not None: # validate context_dict keys are valid table names context_keys = set(context_dict.keys()) if not context_keys.issubset(set(self.sql_database.get_table_names())): raise ValueError( "Invalid context table names: " f"{context_keys - set(self.sql_database.get_table_names())}" ) self.context_dict = context_dict or {} # build full context from sql_database self.full_context_dict = self._build_context_from_sql_database( self.sql_database, current_context=self.context_dict ) self.context_str = context_str @classmethod def from_documents( cls, documents_dict: Dict[str, List[BaseDocument]], sql_database: SQLDatabase, **context_builder_kwargs: Any, ) -> "SQLContextContainerBuilder": """Build context from documents.""" context_builder = SQLDocumentContextBuilder( sql_database, **context_builder_kwargs ) context_dict = context_builder.build_all_context_from_documents(documents_dict) return SQLContextContainerBuilder(sql_database, context_dict=context_dict) def _build_context_from_sql_database( self, sql_database: SQLDatabase, current_context: Optional[Dict[str, str]] = None, ) -> Dict[str, str]: """Get tables schema + optional context as a single string.""" current_context = current_context or {} result_context = {} for table_name in sql_database.get_table_names(): table_desc = sql_database.get_single_table_info(table_name) table_text = f"Schema of table {table_name}:\n" f"{table_desc}\n" if table_name in current_context: table_text += f"Context of table {table_name}:\n" table_text += current_context[table_name] result_context[table_name] = table_text return result_context def _get_context_dict(self, ignore_db_schema: bool) -> Dict[str, str]: """Get full context dict.""" if ignore_db_schema: return self.context_dict else: return self.full_context_dict def derive_index_from_context( self, index_cls: Type[BaseGPTIndex], ignore_db_schema: bool = False, **index_kwargs: Any, ) -> BaseGPTIndex: """Derive index from context.""" full_context_dict = self._get_context_dict(ignore_db_schema) context_docs = [] for table_name, context_str in full_context_dict.items(): doc = Document(context_str, extra_info={"table_name": table_name}) context_docs.append(doc) index = index_cls( documents=context_docs, **index_kwargs, ) return index def query_index_for_context( self, index: BaseGPTIndex, query_str: Union[str, QueryBundle], query_tmpl: Optional[str] = DEFAULT_CONTEXT_QUERY_TMPL, store_context_str: bool = True, **index_kwargs: Any, ) -> str: """Query index for context. A simple wrapper around the index.query call which injects a query template to specifically fetch table information, and can store a context_str. Args: index (BaseGPTIndex): index data structure query_str (Union[str, QueryBundle]): query string query_tmpl (Optional[str]): query template store_context_str (bool): store context_str """ if query_tmpl is None: context_query_str = query_str else: context_query_str = query_tmpl.format(orig_query_str=query_str) response = index.query(context_query_str, **index_kwargs) context_str = str(response.response) if store_context_str: self.context_str = context_str return context_str def build_context_container( self, ignore_db_schema: bool = False ) -> SQLContextContainer: """Build index structure.""" full_context_dict = self._get_context_dict(ignore_db_schema) return SQLContextContainer( context_str=self.context_str, context_dict=full_context_dict, )