"""SQL StructDatapointExtractor.""" from typing import Any, Dict, Optional, cast from sqlalchemy import Table from gpt_index.data_structs.table import StructDatapoint from gpt_index.indices.common.struct_store.base import ( OUTPUT_PARSER_TYPE, BaseStructDatapointExtractor, ) from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase from gpt_index.langchain_helpers.text_splitter import TextSplitter from gpt_index.prompts.prompts import SchemaExtractPrompt class SQLStructDatapointExtractor(BaseStructDatapointExtractor): """Extracts datapoints from a structured document for a SQL db.""" def __init__( self, llm_predictor: LLMPredictor, text_splitter: TextSplitter, schema_extract_prompt: SchemaExtractPrompt, output_parser: OUTPUT_PARSER_TYPE, sql_database: SQLDatabase, table_name: Optional[str] = None, table: Optional[Table] = None, ref_doc_id_column: Optional[str] = None, ) -> None: """Initialize params.""" super().__init__( llm_predictor, text_splitter, schema_extract_prompt, output_parser ) self._sql_database = sql_database # currently the user must specify a table info if table_name is None and table is None: raise ValueError("table_name must be specified") self._table_name = table_name or cast(Table, table).name if table is None: table = self._sql_database.metadata_obj.tables[table_name] # if ref_doc_id_column is specified, then we need to check that # it is a valid column in the table col_names = [c.name for c in table.c] if ref_doc_id_column is not None and ref_doc_id_column not in col_names: raise ValueError( f"ref_doc_id_column {ref_doc_id_column} not in table {table_name}" ) self.ref_doc_id_column = ref_doc_id_column # then store python types of each column self._col_types_map: Dict[str, type] = { c.name: table.c[c.name].type.python_type for c in table.c } def _get_col_types_map(self) -> Dict[str, type]: """Get col types map for schema.""" return self._col_types_map def _get_schema_text(self) -> str: """Insert datapoint into index.""" return self._sql_database.get_single_table_info(self._table_name) def _insert_datapoint(self, datapoint: StructDatapoint) -> None: """Insert datapoint into index.""" datapoint_dict = datapoint.to_dict()["fields"] self._sql_database.insert_into_table( self._table_name, cast(Dict[Any, Any], datapoint_dict) )