"""Common classes for structured operations.""" | |
import logging | |
from abc import abstractmethod | |
from typing import Any, Callable, Dict, List, Optional, Sequence, cast | |
from gpt_index.data_structs.table import StructDatapoint | |
from gpt_index.indices.prompt_helper import PromptHelper | |
from gpt_index.indices.response.builder import ResponseBuilder, TextChunk | |
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor | |
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase | |
from gpt_index.langchain_helpers.text_splitter import TextSplitter | |
from gpt_index.prompts.default_prompt_selectors import ( | |
DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL, | |
) | |
from gpt_index.prompts.default_prompts import ( | |
DEFAULT_TABLE_CONTEXT_PROMPT, | |
DEFAULT_TABLE_CONTEXT_QUERY, | |
) | |
from gpt_index.prompts.prompts import ( | |
QuestionAnswerPrompt, | |
RefinePrompt, | |
RefineTableContextPrompt, | |
SchemaExtractPrompt, | |
TableContextPrompt, | |
) | |
from gpt_index.schema import BaseDocument | |
from gpt_index.utils import truncate_text | |


class SQLDocumentContextBuilder:
    """Builder that builds context for a given set of SQL tables.

    Args:
        sql_database (SQLDatabase): SQL database to use (required).
        llm_predictor (Optional[LLMPredictor]): LLM Predictor to use.
        prompt_helper (Optional[PromptHelper]): Prompt Helper to use.
        text_splitter (Optional[TextSplitter]): Text Splitter to use.
        table_context_prompt (Optional[TableContextPrompt]): A
            Table Context Prompt (see :ref:`Prompt-Templates`).
        refine_table_context_prompt (Optional[RefineTableContextPrompt]):
            A Refine Table Context Prompt (see :ref:`Prompt-Templates`).
        table_context_task (Optional[str]): The query to perform
            on the table context. A default query string is used
            if none is provided by the user.

    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        llm_predictor: Optional[LLMPredictor] = None,
        prompt_helper: Optional[PromptHelper] = None,
        text_splitter: Optional[TextSplitter] = None,
        table_context_prompt: Optional[TableContextPrompt] = None,
        refine_table_context_prompt: Optional[RefineTableContextPrompt] = None,
        table_context_task: Optional[str] = None,
    ) -> None:
        """Initialize params."""
        # TODO: take in an entire index instead of forming a response builder
        if sql_database is None:
            raise ValueError("sql_database must be provided.")
        self._sql_database = sql_database
        self._llm_predictor = llm_predictor or LLMPredictor()
        self._prompt_helper = prompt_helper or PromptHelper.from_llm_predictor(
            self._llm_predictor
        )
        self._text_splitter = text_splitter
        self._table_context_prompt = (
            table_context_prompt or DEFAULT_TABLE_CONTEXT_PROMPT
        )
        self._refine_table_context_prompt = (
            refine_table_context_prompt or DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL
        )
        self._table_context_task = table_context_task or DEFAULT_TABLE_CONTEXT_QUERY

    def build_all_context_from_documents(
        self,
        documents_dict: Dict[str, List[BaseDocument]],
    ) -> Dict[str, str]:
        """Build context for all tables in the database.

        `documents_dict` must contain a document list for every table
        name returned by the database.

        """
        context_dict = {}
        for table_name in self._sql_database.get_table_names():
            context_dict[table_name] = self.build_table_context_from_documents(
                documents_dict[table_name], table_name
            )
        return context_dict

    def build_table_context_from_documents(
        self,
        documents: Sequence[BaseDocument],
        table_name: str,
    ) -> str:
        """Build context from documents for a single table."""
        schema = self._sql_database.get_single_table_info(table_name)
        prompt_with_schema = QuestionAnswerPrompt.from_prompt(
            self._table_context_prompt.partial_format(schema=schema)
        )
        refine_prompt_with_schema = RefinePrompt.from_prompt(
            self._refine_table_context_prompt.partial_format(schema=schema)
        )
        text_splitter = (
            self._text_splitter
            or self._prompt_helper.get_text_splitter_given_prompt(
                prompt_with_schema, 1
            )
        )
        # we use the ResponseBuilder to iteratively go through all texts
        response_builder = ResponseBuilder(
            self._prompt_helper,
            self._llm_predictor,
            prompt_with_schema,
            refine_prompt_with_schema,
        )
        for doc in documents:
            text_chunks = text_splitter.split_text(doc.get_text())
            for text_chunk in text_chunks:
                response_builder.add_text_chunks([TextChunk(text_chunk)])
        # feed in the "query_str", i.e. the table context task
        table_context = response_builder.get_response(self._table_context_task)
        return cast(str, table_context)
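

# Example usage (a minimal sketch; `engine` and `city_docs` are placeholders
# for a SQLAlchemy engine and a list of BaseDocument objects, not names
# defined in this module):
#
#     sql_database = SQLDatabase(engine)
#     builder = SQLDocumentContextBuilder(sql_database)
#     context_dict = builder.build_all_context_from_documents(
#         {"city_stats": city_docs}  # one document list per table name
#     )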


OUTPUT_PARSER_TYPE = Callable[[str], Optional[Dict[str, Any]]]
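
# An output parser maps the raw LLM response to a field dict, or None on
# failure. A minimal sketch matching OUTPUT_PARSER_TYPE (assumes the LLM is
# prompted to answer with a JSON object; this is an illustration, not a
# parser shipped by gpt_index):
#
#     import json
#
#     def json_output_parser(response: str) -> Optional[Dict[str, Any]]:
#         try:
#             fields = json.loads(response)
#         except json.JSONDecodeError:
#             return None
#         return fields if isinstance(fields, dict) else None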


class BaseStructDatapointExtractor:
    """Extracts datapoints from a structured document."""

    def __init__(
        self,
        llm_predictor: LLMPredictor,
        text_splitter: TextSplitter,
        schema_extract_prompt: SchemaExtractPrompt,
        output_parser: OUTPUT_PARSER_TYPE,
    ) -> None:
        """Initialize params."""
        self._llm_predictor = llm_predictor
        self._text_splitter = text_splitter
        self._schema_extract_prompt = schema_extract_prompt
        self._output_parser = output_parser

    def _clean_and_validate_fields(self, fields: Dict[str, Any]) -> Dict[str, Any]:
        """Validate fields with col_types_map."""
        new_fields = {}
        col_types_map = self._get_col_types_map()
        for field, value in fields.items():
            clean_value = value
            # drop fields that aren't part of the schema
            if field not in col_types_map:
                continue
            # if expected type is int or float, try to convert value to int or float
            expected_type = col_types_map[field]
            if expected_type == int:
                try:
                    clean_value = int(value)
                except ValueError:
                    continue
            elif expected_type == float:
                try:
                    clean_value = float(value)
                except ValueError:
                    continue
            else:
                # drop empty values and values of the wrong type
                if len(value) == 0:
                    continue
                if not isinstance(value, expected_type):
                    continue
            new_fields[field] = clean_value
        return new_fields
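
    # For example, with a col_types_map of {"year": int, "city": str}
    # (a hypothetical schema), the input
    # {"year": "1871", "city": "Paris", "extra": "x"} is cleaned to
    # {"year": 1871, "city": "Paris"}: "year" is coerced to int, and
    # "extra" is dropped because it is not in the schema.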

    @abstractmethod
    def _insert_datapoint(self, datapoint: StructDatapoint) -> None:
        """Insert datapoint into index."""

    @abstractmethod
    def _get_col_types_map(self) -> Dict[str, type]:
        """Get col types map for schema."""

    @abstractmethod
    def _get_schema_text(self) -> str:
        """Get schema text for extracting relevant info from unstructured text."""

    def insert_datapoint_from_document(self, document: BaseDocument) -> None:
        """Extract datapoint from a document and insert it."""
        text_chunks = self._text_splitter.split_text(document.get_text())
        fields = {}
        for i, text_chunk in enumerate(text_chunks):
            fmt_text_chunk = truncate_text(text_chunk, 50)
            logging.info(f"> Adding chunk {i}: {fmt_text_chunk}")
            schema_text = self._get_schema_text()
            response_str, _ = self._llm_predictor.predict(
                self._schema_extract_prompt,
                text=text_chunk,
                schema=schema_text,
            )
            cur_fields = self._output_parser(response_str)
            if cur_fields is None:
                continue
            # validate fields with col_types_map
            new_cur_fields = self._clean_and_validate_fields(cur_fields)
            fields.update(new_cur_fields)
        # aggregate fields extracted across all chunks into one datapoint
        struct_datapoint = StructDatapoint(fields)
        self._insert_datapoint(struct_datapoint)
        logging.debug(f"> Added datapoint: {fields}")