File size: 2,619 Bytes
b699122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""SQL StructDatapointExtractor."""

from typing import Any, Dict, Optional, cast

from sqlalchemy import Table

from gpt_index.data_structs.table import StructDatapoint
from gpt_index.indices.common.struct_store.base import (
    OUTPUT_PARSER_TYPE,
    BaseStructDatapointExtractor,
)
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
from gpt_index.prompts.prompts import SchemaExtractPrompt


class SQLStructDatapointExtractor(BaseStructDatapointExtractor):
    """Extracts datapoints from a structured document for a SQL db.

    The target table is identified either by ``table_name`` (looked up in
    the database's metadata) or by an explicit SQLAlchemy ``table`` object.
    The Python type of every column is recorded at construction time and
    exposed through ``_get_col_types_map``.
    """

    def __init__(
        self,
        llm_predictor: LLMPredictor,
        schema_extract_prompt: SchemaExtractPrompt,
        output_parser: OUTPUT_PARSER_TYPE,
        sql_database: SQLDatabase,
        table_name: Optional[str] = None,
        table: Optional[Table] = None,
        ref_doc_id_column: Optional[str] = None,
    ) -> None:
        """Initialize params.

        Args:
            llm_predictor: Forwarded unchanged to the base extractor.
            schema_extract_prompt: Forwarded unchanged to the base extractor.
            output_parser: Forwarded unchanged to the base extractor.
            sql_database: Database wrapper used to resolve the table, to
                fetch its schema text and to insert rows.
            table_name: Name of the target table. Required when ``table``
                is not given.
            table: SQLAlchemy table object. When omitted, it is looked up
                by ``table_name`` in ``sql_database.metadata_obj.tables``.
            ref_doc_id_column: Optional column name; when given it must be
                one of the table's columns. It is stored on the instance
                as ``ref_doc_id_column``.

        Raises:
            ValueError: If neither ``table_name`` nor ``table`` is given, or
                if ``ref_doc_id_column`` is not a column of the table.
            KeyError: If ``table_name`` is not present in the database
                metadata.
        """
        super().__init__(llm_predictor, schema_extract_prompt, output_parser)
        self._sql_database = sql_database

        # currently the user must specify a table info
        if table is None:
            if table_name is None:
                raise ValueError("table_name must be specified")
            table = self._sql_database.metadata_obj.tables[table_name]
        # An explicit (non-empty) table_name takes precedence over the name
        # carried by the table object.
        self._table_name: str = table_name or table.name

        # if ref_doc_id_column is specified, then we need to check that
        # it is a valid column in the table
        if ref_doc_id_column is not None:
            col_names = {c.name for c in table.c}
            if ref_doc_id_column not in col_names:
                # Report the resolved name: `table_name` is None when the
                # caller supplied only a `table` object.
                raise ValueError(
                    f"ref_doc_id_column {ref_doc_id_column} "
                    f"not in table {self._table_name}"
                )
        self.ref_doc_id_column = ref_doc_id_column

        # then store python types of each column
        self._col_types_map: Dict[str, type] = {
            c.name: c.type.python_type for c in table.c
        }

    def _get_col_types_map(self) -> Dict[str, type]:
        """Get col types map for schema."""
        return self._col_types_map

    def _get_schema_text(self) -> str:
        """Get the schema text of the target table from the SQL database."""
        return self._sql_database.get_single_table_info(self._table_name)

    def _insert_datapoint(self, datapoint: StructDatapoint) -> None:
        """Insert the datapoint's fields as a row into the target table."""
        datapoint_dict = datapoint.to_dict()["fields"]
        self._sql_database.insert_into_table(
            self._table_name, cast(Dict[Any, Any], datapoint_dict)
        )