Spaces:
Runtime error
Runtime error
File size: 8,234 Bytes
35b22df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
"""Common classes for structured operations."""
import logging
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Optional, Sequence, cast
from gpt_index.data_structs.table import StructDatapoint
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.response.builder import ResponseBuilder, TextChunk
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
from gpt_index.langchain_helpers.text_splitter import TextSplitter
from gpt_index.prompts.default_prompt_selectors import (
DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL,
)
from gpt_index.prompts.default_prompts import (
DEFAULT_TABLE_CONTEXT_PROMPT,
DEFAULT_TABLE_CONTEXT_QUERY,
)
from gpt_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
RefineTableContextPrompt,
SchemaExtractPrompt,
TableContextPrompt,
)
from gpt_index.schema import BaseDocument
from gpt_index.utils import truncate_text
class SQLDocumentContextBuilder:
    """Builder that builds context for a given set of SQL tables.

    Args:
        sql_database (SQLDatabase): SQL database to use. Required; a
            ValueError is raised if None is passed.
        llm_predictor (Optional[LLMPredictor]): LLM Predictor to use.
        prompt_helper (Optional[PromptHelper]): Prompt Helper to use.
        text_splitter (Optional[TextSplitter]): Text Splitter to use.
        table_context_prompt (Optional[TableContextPrompt]): A
            Table Context Prompt (see :ref:`Prompt-Templates`).
        refine_table_context_prompt (Optional[RefineTableContextPrompt]):
            A Refine Table Context Prompt (see :ref:`Prompt-Templates`).
        table_context_task (Optional[str]): The query to perform
            on the table context. A default query string is used
            if none is provided by the user.
    """

    def __init__(
        self,
        sql_database: SQLDatabase,
        llm_predictor: Optional[LLMPredictor] = None,
        prompt_helper: Optional[PromptHelper] = None,
        text_splitter: Optional[TextSplitter] = None,
        table_context_prompt: Optional[TableContextPrompt] = None,
        refine_table_context_prompt: Optional[RefineTableContextPrompt] = None,
        table_context_task: Optional[str] = None,
    ) -> None:
        """Initialize params."""
        # TODO: take in an entire index instead of forming a response builder
        if sql_database is None:
            raise ValueError("sql_database must be provided.")
        self._sql_database = sql_database
        self._llm_predictor = llm_predictor or LLMPredictor()
        self._prompt_helper = prompt_helper or PromptHelper.from_llm_predictor(
            self._llm_predictor
        )
        self._text_splitter = text_splitter
        self._table_context_prompt = (
            table_context_prompt or DEFAULT_TABLE_CONTEXT_PROMPT
        )
        self._refine_table_context_prompt = (
            refine_table_context_prompt or DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL
        )
        self._table_context_task = table_context_task or DEFAULT_TABLE_CONTEXT_QUERY

    def build_all_context_from_documents(
        self,
        documents_dict: Dict[str, List[BaseDocument]],
    ) -> Dict[str, str]:
        """Build context for all tables in the database.

        Args:
            documents_dict: Mapping from table name to the documents
                describing that table. Every table returned by the
                database's ``get_table_names()`` must be present as a key.

        Returns:
            Mapping from table name to its generated context string.

        Raises:
            KeyError: If a table in the database has no entry in
                ``documents_dict``.
        """
        context_dict = {}
        for table_name in self._sql_database.get_table_names():
            # raise an informative KeyError instead of the opaque
            # KeyError(table_name) that a bare dict lookup would produce
            if table_name not in documents_dict:
                raise KeyError(
                    f"documents_dict is missing documents for table "
                    f"'{table_name}'."
                )
            context_dict[table_name] = self.build_table_context_from_documents(
                documents_dict[table_name], table_name
            )
        return context_dict

    def build_table_context_from_documents(
        self,
        documents: Sequence[BaseDocument],
        table_name: str,
    ) -> str:
        """Build context from documents for a single table.

        Formats the table schema into the (refine) context prompts, splits
        each document into prompt-sized chunks, and iteratively refines a
        context string over all chunks via the ResponseBuilder.
        """
        schema = self._sql_database.get_single_table_info(table_name)
        prompt_with_schema = QuestionAnswerPrompt.from_prompt(
            self._table_context_prompt.partial_format(schema=schema)
        )
        refine_prompt_with_schema = RefinePrompt.from_prompt(
            self._refine_table_context_prompt.partial_format(schema=schema)
        )
        text_splitter = (
            self._text_splitter
            or self._prompt_helper.get_text_splitter_given_prompt(prompt_with_schema, 1)
        )
        # we use the ResponseBuilder to iteratively go through all texts
        response_builder = ResponseBuilder(
            self._prompt_helper,
            self._llm_predictor,
            prompt_with_schema,
            refine_prompt_with_schema,
        )
        for doc in documents:
            text_chunks = text_splitter.split_text(doc.get_text())
            for text_chunk in text_chunks:
                response_builder.add_text_chunks([TextChunk(text_chunk)])
        # feed in the "query_str" or the task
        table_context = response_builder.get_response(self._table_context_task)
        return cast(str, table_context)
OUTPUT_PARSER_TYPE = Callable[[str], Optional[Dict[str, Any]]]
class BaseStructDatapointExtractor:
    """Extracts datapoints from a structured document.

    Subclasses implement ``_insert_datapoint``, ``_get_col_types_map``,
    and ``_get_schema_text`` to define the target schema and storage.
    """

    def __init__(
        self,
        llm_predictor: LLMPredictor,
        text_splitter: TextSplitter,
        schema_extract_prompt: SchemaExtractPrompt,
        output_parser: OUTPUT_PARSER_TYPE,
    ) -> None:
        """Initialize params."""
        self._llm_predictor = llm_predictor
        self._text_splitter = text_splitter
        self._schema_extract_prompt = schema_extract_prompt
        self._output_parser = output_parser

    def _clean_and_validate_fields(self, fields: Dict[str, Any]) -> Dict[str, Any]:
        """Validate fields with col_types_map.

        Fields not present in the schema are dropped. Values for int/float
        columns are coerced; values that cannot be coerced are dropped
        rather than raising. Other columns must be non-empty and match
        the expected type.
        """
        new_fields = {}
        col_types_map = self._get_col_types_map()
        for field, value in fields.items():
            if field not in col_types_map:
                continue
            clean_value = value
            # if expected type is int or float, try to convert value
            expected_type = col_types_map[field]
            if expected_type is int:
                try:
                    clean_value = int(value)
                except (ValueError, TypeError):
                    # e.g. int("abc") -> ValueError, int(None) -> TypeError;
                    # either way the value is unusable, so skip the field
                    continue
            elif expected_type is float:
                try:
                    clean_value = float(value)
                except (ValueError, TypeError):
                    continue
            else:
                # non-numeric columns: drop empty values and type mismatches
                if len(value) == 0:
                    continue
                if not isinstance(value, expected_type):
                    continue
            new_fields[field] = clean_value
        return new_fields

    @abstractmethod
    def _insert_datapoint(self, datapoint: StructDatapoint) -> None:
        """Insert datapoint into index."""

    @abstractmethod
    def _get_col_types_map(self) -> Dict[str, type]:
        """Get col types map for schema."""

    @abstractmethod
    def _get_schema_text(self) -> str:
        """Get schema text for extracting relevant info from unstructured text."""

    def insert_datapoint_from_document(self, document: BaseDocument) -> None:
        """Extract datapoint from a document and insert it.

        Splits the document into chunks, runs each chunk through the
        schema-extraction prompt, and merges the parsed fields (later
        chunks overwrite earlier ones on key collision) into one datapoint.
        """
        text_chunks = self._text_splitter.split_text(document.get_text())
        fields: Dict[str, Any] = {}
        # schema text is chunk-invariant; compute it once, not per chunk
        schema_text = self._get_schema_text()
        for i, text_chunk in enumerate(text_chunks):
            fmt_text_chunk = truncate_text(text_chunk, 50)
            logging.info("> Adding chunk %s: %s", i, fmt_text_chunk)
            response_str, _ = self._llm_predictor.predict(
                self._schema_extract_prompt,
                text=text_chunk,
                schema=schema_text,
            )
            cur_fields = self._output_parser(response_str)
            if cur_fields is None:
                continue
            # validate fields with col_types_map
            new_cur_fields = self._clean_and_validate_fields(cur_fields)
            fields.update(new_cur_fields)
        # StructDatapoint(...) always yields an instance, so the previous
        # `is not None` guard was dead code; insert unconditionally
        self._insert_datapoint(StructDatapoint(fields))
        logging.debug("> Added datapoint: %s", fields)
|