Spaces:
Runtime error
Runtime error
File size: 5,500 Bytes
35b22df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
"""Default query for GPTFaissIndex."""
import logging
from typing import Any, Optional
from gpt_index.data_structs.table import SQLStructTable
from gpt_index.indices.common.struct_store.schema import SQLContextContainer
from gpt_index.indices.query.base import BaseGPTIndexQuery
from gpt_index.indices.query.schema import QueryBundle, QueryMode
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
from gpt_index.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from gpt_index.prompts.prompts import TextToSQLPrompt
from gpt_index.response.schema import Response
from gpt_index.token_counter.token_counter import llm_token_counter
class GPTSQLStructStoreIndexQuery(BaseGPTIndexQuery[SQLStructTable]):
"""GPT SQL query over a structured database.
Runs raw SQL over a GPTSQLStructStoreIndex. No LLM calls are made here.
NOTE: this query cannot work with composed indices - if the index
contains subindices, those subindices will not be queried.
.. code-block:: python
response = index.query("<query_str>", mode="sql")
"""
def __init__(
self,
index_struct: SQLStructTable,
sql_database: Optional[SQLDatabase] = None,
sql_context_container: Optional[SQLContextContainer] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
super().__init__(index_struct=index_struct, **kwargs)
if sql_database is None:
raise ValueError("sql_database must be provided.")
self._sql_database = sql_database
@llm_token_counter("query")
def query(self, query_bundle: QueryBundle) -> Response:
"""Answer a query."""
# NOTE: override query method in order to fetch the right results.
# NOTE: since the query_str is a SQL query, it doesn't make sense
# to use ResponseBuilder anywhere.
response_str, extra_info = self._sql_database.run_sql(query_bundle.query_str)
response = Response(response=response_str, extra_info=extra_info)
return response
class GPTNLStructStoreIndexQuery(BaseGPTIndexQuery[SQLStructTable]):
"""GPT natural language query over a structured database.
Given a natural language query, we will extract the query to SQL.
Runs raw SQL over a GPTSQLStructStoreIndex. No LLM calls are made here.
NOTE: this query cannot work with composed indices - if the index
contains subindices, those subindices will not be queried.
.. code-block:: python
response = index.query("<query_str>", mode="sql")
"""
def __init__(
self,
index_struct: SQLStructTable,
sql_database: Optional[SQLDatabase] = None,
sql_context_container: Optional[SQLContextContainer] = None,
ref_doc_id_column: Optional[str] = None,
text_to_sql_prompt: Optional[TextToSQLPrompt] = None,
context_query_mode: QueryMode = QueryMode.DEFAULT,
context_query_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> None:
"""Initialize params."""
super().__init__(index_struct=index_struct, **kwargs)
if sql_database is None:
raise ValueError("sql_database must be provided.")
self._sql_database = sql_database
if sql_context_container is None:
raise ValueError("sql_context_container must be provided.")
self._sql_context_container = sql_context_container
self._ref_doc_id_column = ref_doc_id_column
self._text_to_sql_prompt = text_to_sql_prompt or DEFAULT_TEXT_TO_SQL_PROMPT
self._context_query_mode = context_query_mode
self._context_query_kwargs = context_query_kwargs or {}
def _parse_response_to_sql(self, response: str) -> str:
"""Parse response to SQL."""
result_response = response.strip()
return result_response
def _get_table_context(self, query_bundle: QueryBundle) -> str:
"""Get table context.
Get tables schema + optional context as a single string. Taken from
SQLContextContainer.
"""
if self._sql_context_container.context_str is not None:
tables_desc_str = self._sql_context_container.context_str
else:
table_desc_list = []
context_dict = self._sql_context_container.context_dict
if context_dict is None:
raise ValueError(
"context_dict must be provided. There is currently no "
"table context."
)
for table_desc in context_dict.values():
table_desc_list.append(table_desc)
tables_desc_str = "\n\n".join(table_desc_list)
return tables_desc_str
def _query(self, query_bundle: QueryBundle) -> Response:
"""Answer a query."""
table_desc_str = self._get_table_context(query_bundle)
logging.info(f"> Table desc str: {table_desc_str}")
response_str, _ = self._llm_predictor.predict(
self._text_to_sql_prompt,
query_str=query_bundle.query_str,
schema=table_desc_str,
)
sql_query_str = self._parse_response_to_sql(response_str)
# assume that it's a valid SQL query
logging.debug(f"> Predicted SQL query: {sql_query_str}")
response_str, extra_info = self._sql_database.run_sql(sql_query_str)
extra_info["sql_query"] = sql_query_str
response = Response(response=response_str, extra_info=extra_info)
return response
|