AbeerTrial's picture
Duplicate from AbeerTrial/SOAPAssist
35b22df
"""SQL StructDatapointExtractor."""
from typing import Any, Dict, Optional, cast
from sqlalchemy import Table
from gpt_index.data_structs.table import StructDatapoint
from gpt_index.indices.common.struct_store.base import (
OUTPUT_PARSER_TYPE,
BaseStructDatapointExtractor,
)
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
from gpt_index.langchain_helpers.text_splitter import TextSplitter
from gpt_index.prompts.prompts import SchemaExtractPrompt
class SQLStructDatapointExtractor(BaseStructDatapointExtractor):
    """Extracts datapoints from a structured document for a SQL db."""

    def __init__(
        self,
        llm_predictor: LLMPredictor,
        text_splitter: TextSplitter,
        schema_extract_prompt: SchemaExtractPrompt,
        output_parser: OUTPUT_PARSER_TYPE,
        sql_database: SQLDatabase,
        table_name: Optional[str] = None,
        table: Optional[Table] = None,
        ref_doc_id_column: Optional[str] = None,
    ) -> None:
        """Initialize params.

        Args:
            llm_predictor: Forwarded unchanged to the base extractor.
            text_splitter: Forwarded unchanged to the base extractor.
            schema_extract_prompt: Forwarded unchanged to the base extractor.
            output_parser: Forwarded unchanged to the base extractor.
            sql_database: Database wrapper used to resolve the table, to
                fetch its schema text, and to insert extracted rows.
            table_name: Name of the target table. When omitted, the name
                is taken from ``table``.
            table: SQLAlchemy table to read column metadata from. When
                omitted, it is looked up by name in
                ``sql_database.metadata_obj.tables``.
            ref_doc_id_column: Optional column name; if given, it must be
                one of the table's column names.

        Raises:
            ValueError: If neither ``table_name`` nor ``table`` is given,
                or if ``ref_doc_id_column`` is not a column of the table.
        """
        super().__init__(
            llm_predictor, text_splitter, schema_extract_prompt, output_parser
        )
        self._sql_database = sql_database

        # currently the user must specify a table info
        if table_name is None and table is None:
            raise ValueError("table_name must be specified")
        self._table_name = table_name or cast(Table, table).name

        if table is None:
            # On this branch ``table_name`` was supplied, so the resolved
            # name is the same string; using it keeps the key a plain ``str``.
            # NOTE(review): presumably raises KeyError for an unknown table
            # name -- confirm against the SQLDatabase wrapper.
            table = self._sql_database.metadata_obj.tables[self._table_name]

        # if ref_doc_id_column is specified, then we need to check that
        # it is a valid column in the table
        col_names = [c.name for c in table.c]
        if ref_doc_id_column is not None and ref_doc_id_column not in col_names:
            # Report the resolved table name: the ``table_name`` argument is
            # None when the caller passed only ``table``.
            raise ValueError(
                f"ref_doc_id_column {ref_doc_id_column} "
                f"not in table {self._table_name}"
            )
        self.ref_doc_id_column = ref_doc_id_column

        # then store python types of each column; read the type from the
        # column object itself rather than re-indexing ``table.c`` by name,
        # since that collection is keyed by column key, not column name.
        self._col_types_map: Dict[str, type] = {
            c.name: c.type.python_type for c in table.c
        }

    def _get_col_types_map(self) -> Dict[str, type]:
        """Get col types map for schema."""
        return self._col_types_map

    def _get_schema_text(self) -> str:
        """Get schema text for the target table from the SQL database."""
        return self._sql_database.get_single_table_info(self._table_name)

    def _insert_datapoint(self, datapoint: StructDatapoint) -> None:
        """Insert the datapoint's ``fields`` dict as a row in the table."""
        datapoint_dict = datapoint.to_dict()["fields"]
        self._sql_database.insert_into_table(
            self._table_name, cast(Dict[Any, Any], datapoint_dict)
        )