# AbeerTrial's picture
# Duplicate from AbeerTrial/SOAPAssist
# 35b22df
"""Common classes for structured operations."""
import logging
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Optional, Sequence, cast
from gpt_index.data_structs.table import StructDatapoint
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.response.builder import ResponseBuilder, TextChunk
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
from gpt_index.langchain_helpers.text_splitter import TextSplitter
from gpt_index.prompts.default_prompt_selectors import (
DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL,
)
from gpt_index.prompts.default_prompts import (
DEFAULT_TABLE_CONTEXT_PROMPT,
DEFAULT_TABLE_CONTEXT_QUERY,
)
from gpt_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
RefineTableContextPrompt,
SchemaExtractPrompt,
TableContextPrompt,
)
from gpt_index.schema import BaseDocument
from gpt_index.utils import truncate_text
class SQLDocumentContextBuilder:
    """Builder that builds context for a given set of SQL tables.

    For each table, the table schema is filled into the table-context prompts
    and the supplied documents are fed chunk by chunk through a
    ``ResponseBuilder`` to produce a single context string.

    Args:
        sql_database (SQLDatabase): SQL database to use. Required; passing
            ``None`` raises ``ValueError``.
        llm_predictor (Optional[LLMPredictor]): LLM Predictor to use.
            Defaults to a fresh ``LLMPredictor()``.
        prompt_helper (Optional[PromptHelper]): Prompt Helper to use.
            Defaults to one derived from the LLM predictor.
        text_splitter (Optional[TextSplitter]): Text Splitter to use.
            When omitted, a splitter is obtained from the prompt helper
            for the schema-filled prompt at build time.
        table_context_prompt (Optional[TableContextPrompt]): A
            Table Context Prompt (see :ref:`Prompt-Templates`).
        refine_table_context_prompt (Optional[RefineTableContextPrompt]):
            A Refine Table Context Prompt (see :ref:`Prompt-Templates`).
        table_context_task (Optional[str]): The query to perform
            on the table context. A default query string is used
            if none is provided by the user.

    Raises:
        ValueError: If ``sql_database`` is ``None``.

    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        llm_predictor: Optional[LLMPredictor] = None,
        prompt_helper: Optional[PromptHelper] = None,
        text_splitter: Optional[TextSplitter] = None,
        table_context_prompt: Optional[TableContextPrompt] = None,
        refine_table_context_prompt: Optional[RefineTableContextPrompt] = None,
        table_context_task: Optional[str] = None,
    ) -> None:
        """Initialize params."""
        # TODO: take in an entire index instead of forming a response builder
        if sql_database is None:
            raise ValueError("sql_database must be provided.")
        self._sql_database = sql_database
        self._llm_predictor = llm_predictor or LLMPredictor()
        self._prompt_helper = prompt_helper or PromptHelper.from_llm_predictor(
            self._llm_predictor
        )
        # May stay None; a splitter is then resolved per table at build time.
        self._text_splitter = text_splitter
        self._table_context_prompt = (
            table_context_prompt or DEFAULT_TABLE_CONTEXT_PROMPT
        )
        self._refine_table_context_prompt = (
            refine_table_context_prompt or DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL
        )
        self._table_context_task = table_context_task or DEFAULT_TABLE_CONTEXT_QUERY

    def build_all_context_from_documents(
        self,
        documents_dict: Dict[str, List[BaseDocument]],
    ) -> Dict[str, str]:
        """Build context for all tables in the database.

        Args:
            documents_dict (Dict[str, List[BaseDocument]]): Documents to
                use for each table, keyed by table name. Every table name
                reported by the database must be present as a key.

        Returns:
            Dict[str, str]: The built context string for each table name.

        Raises:
            KeyError: If a table reported by the database has no entry
                in ``documents_dict``.

        """
        return {
            table_name: self.build_table_context_from_documents(
                documents_dict[table_name], table_name
            )
            for table_name in self._sql_database.get_table_names()
        }

    def build_table_context_from_documents(
        self,
        documents: Sequence[BaseDocument],
        table_name: str,
    ) -> str:
        """Build context from documents for a single table.

        Args:
            documents (Sequence[BaseDocument]): Documents whose text is
                used to build the context.
            table_name (str): Name of the table to build context for.

        Returns:
            str: The response produced by the response builder for the
                configured table context task.

        """
        # Bake the table schema into both the initial and the refine prompt.
        schema = self._sql_database.get_single_table_info(table_name)
        prompt_with_schema = QuestionAnswerPrompt.from_prompt(
            self._table_context_prompt.partial_format(schema=schema)
        )
        refine_prompt_with_schema = RefinePrompt.from_prompt(
            self._refine_table_context_prompt.partial_format(schema=schema)
        )

        # Prefer the user-supplied splitter; otherwise ask the prompt helper
        # for one that fits the schema-filled prompt.
        text_splitter = (
            self._text_splitter
            or self._prompt_helper.get_text_splitter_given_prompt(prompt_with_schema, 1)
        )

        # we use the ResponseBuilder to iteratively go through all texts
        response_builder = ResponseBuilder(
            self._prompt_helper,
            self._llm_predictor,
            prompt_with_schema,
            refine_prompt_with_schema,
        )
        for doc in documents:
            for text_chunk in text_splitter.split_text(doc.get_text()):
                response_builder.add_text_chunks([TextChunk(text_chunk)])

        # feed in the "query_str" or the task
        table_context = response_builder.get_response(self._table_context_task)
        return cast(str, table_context)
# Type of an output parser: a callable that takes a string and returns either
# a dict of field names to values, or None.
OUTPUT_PARSER_TYPE = Callable[[str], Optional[Dict[str, Any]]]
class BaseStructDatapointExtractor:
    """Extracts datapoints from a structured document.

    A document is split into text chunks, each chunk is sent to the LLM with
    the schema-extract prompt, the response is parsed into fields, and the
    validated fields from all chunks are merged into one datapoint that is
    handed to ``_insert_datapoint``.

    Subclasses are expected to implement ``_insert_datapoint``,
    ``_get_col_types_map`` and ``_get_schema_text``.

    Args:
        llm_predictor (LLMPredictor): LLM Predictor used for extraction.
        text_splitter (TextSplitter): Splitter used to chunk document text.
        schema_extract_prompt (SchemaExtractPrompt): Prompt passed to the
            predictor together with the chunk text and the schema text.
        output_parser (OUTPUT_PARSER_TYPE): Callable that turns the raw
            LLM response into a dict of fields, or ``None``.

    """

    def __init__(
        self,
        llm_predictor: LLMPredictor,
        text_splitter: TextSplitter,
        schema_extract_prompt: SchemaExtractPrompt,
        output_parser: OUTPUT_PARSER_TYPE,
    ) -> None:
        """Initialize params."""
        self._llm_predictor = llm_predictor
        self._text_splitter = text_splitter
        self._schema_extract_prompt = schema_extract_prompt
        self._output_parser = output_parser

    def _clean_and_validate_fields(self, fields: Dict[str, Any]) -> Dict[str, Any]:
        """Validate fields with col_types_map.

        Fields not present in the column type map are dropped. Values for
        ``int`` / ``float`` columns are converted with ``int()`` / ``float()``
        and dropped if the conversion fails. Values for any other column type
        are dropped when they are not an instance of that type or are empty.

        Args:
            fields (Dict[str, Any]): Raw field values keyed by field name.

        Returns:
            Dict[str, Any]: The subset of fields that passed validation,
                with numeric values converted.

        """
        new_fields: Dict[str, Any] = {}
        col_types_map = self._get_col_types_map()
        for field, value in fields.items():
            if field not in col_types_map:
                continue
            expected_type = col_types_map[field]
            if expected_type in (int, float):
                # if expected type is int or float, try to convert value.
                # TypeError covers None / containers, OverflowError covers
                # int(inf); all of them mean "not usable", so skip the field.
                try:
                    clean_value = expected_type(value)
                except (TypeError, ValueError, OverflowError):
                    continue
            else:
                # Check the type before the length so that unsized values
                # (None, bool, numbers) never reach len().
                if not isinstance(value, expected_type):
                    continue
                if hasattr(value, "__len__") and len(value) == 0:
                    continue
                clean_value = value
            new_fields[field] = clean_value
        return new_fields

    @abstractmethod
    def _insert_datapoint(self, datapoint: StructDatapoint) -> None:
        """Insert datapoint into index."""

    @abstractmethod
    def _get_col_types_map(self) -> Dict[str, type]:
        """Get col types map for schema."""

    @abstractmethod
    def _get_schema_text(self) -> str:
        """Get schema text for extracting relevant info from unstructured text."""

    def insert_datapoint_from_document(self, document: BaseDocument) -> None:
        """Extract datapoint from a document and insert it.

        Fields extracted from later chunks overwrite same-named fields from
        earlier chunks. The datapoint is inserted even when no field was
        extracted.

        Args:
            document (BaseDocument): Document to extract the datapoint from.

        """
        text_chunks = self._text_splitter.split_text(document.get_text())
        # Fetched once up front; assumed not to vary between chunks of the
        # same document.
        schema_text = self._get_schema_text()
        fields: Dict[str, Any] = {}
        for i, text_chunk in enumerate(text_chunks):
            fmt_text_chunk = truncate_text(text_chunk, 50)
            logging.info(f"> Adding chunk {i}: {fmt_text_chunk}")
            response_str, _ = self._llm_predictor.predict(
                self._schema_extract_prompt,
                text=text_chunk,
                schema=schema_text,
            )
            cur_fields = self._output_parser(response_str)
            if cur_fields is None:
                continue
            # validate fields with col_types_map
            fields.update(self._clean_and_validate_fields(cur_fields))

        self._insert_datapoint(StructDatapoint(fields))
        logging.debug(f"> Added datapoint: {fields}")