File size: 16,934 Bytes
35b22df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
"""Base query classes."""

import logging
import re
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generator, Generic, List, Optional, Tuple, TypeVar, cast

from gpt_index.data_structs.data_structs import IndexStruct, Node
from gpt_index.docstore import DocumentStore
from gpt_index.embeddings.base import BaseEmbedding
from gpt_index.embeddings.openai import OpenAIEmbedding
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.query.embedding_utils import SimilarityTracker
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.indices.response.builder import (
    RESPONSE_TEXT_TYPE,
    ResponseBuilder,
    ResponseMode,
    TextChunk,
)
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.prompts.default_prompt_selectors import DEFAULT_REFINE_PROMPT_SEL
from gpt_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT
from gpt_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt
from gpt_index.response.schema import RESPONSE_TYPE, Response, StreamingResponse
from gpt_index.token_counter.token_counter import llm_token_counter
from gpt_index.utils import truncate_text

# Type variable for the IndexStruct subclass a given query class operates on.
IS = TypeVar("IS", bound=IndexStruct)


@dataclass
class BaseQueryRunner:
    """Interface for objects that can run a query against an index struct.

    Only ``query`` is declared here. The class does not derive from ``abc.ABC``,
    so ``@abstractmethod`` is not enforced when instantiating; a ``query`` that
    has not been overridden fails when it is called instead.
    """

    @abstractmethod
    def query(self, query_bundle: QueryBundle, index_struct: IndexStruct) -> Response:
        """Run ``query_bundle`` against ``index_struct`` and return the response.

        Args:
            query_bundle (QueryBundle): The query to run.
            index_struct (IndexStruct): The index struct to run it against.

        Raises:
            NotImplementedError: Always, in this base implementation.

        """
        message = "Not implemented yet."
        raise NotImplementedError(message)


class BaseGPTIndexQuery(Generic[IS]):
    """Base LlamaIndex Query.

    Helper class that is used to query an index. Can be called within `query`
    method of a BaseGPTIndex object, or instantiated independently.

    Args:
        llm_predictor (LLMPredictor): Optional LLMPredictor object. If not provided,
            will use the default LLMPredictor (text-davinci-003)
        prompt_helper (PromptHelper): Optional PromptHelper object. If not provided,
            will use the default PromptHelper.
        required_keywords (List[str]): Optional list of keywords that must be present
            in nodes. Can be used to query most indices (tree index is an exception).
        exclude_keywords (List[str]): Optional list of keywords that must not be
            present in nodes. Can be used to query most indices (tree index is an
            exception).
        response_mode (ResponseMode): Optional ResponseMode. If not provided, will
            use the default ResponseMode.
        text_qa_template (QuestionAnswerPrompt): Optional QuestionAnswerPrompt object.
            If not provided, will use the default QuestionAnswerPrompt.
        refine_template (RefinePrompt): Optional RefinePrompt object. If not provided,
            will use the default RefinePrompt.
        include_summary (bool): Optional bool. If True, will also use the summary
            text of the index when generating a response (the summary text can be set
            through `index.set_text("<text>")`).
        similarity_cutoff (float): Optional float. If set, will filter out nodes with
            similarity below this cutoff threshold when computing the response
        streaming (bool): Optional bool. If True, will return a StreamingResponse
            object. If False, will return a Response object.

    """

    def __init__(
        self,
        index_struct: IS,
        # TODO: pass from superclass
        llm_predictor: Optional[LLMPredictor] = None,
        prompt_helper: Optional[PromptHelper] = None,
        embed_model: Optional[BaseEmbedding] = None,
        docstore: Optional[DocumentStore] = None,
        query_runner: Optional[BaseQueryRunner] = None,
        required_keywords: Optional[List[str]] = None,
        exclude_keywords: Optional[List[str]] = None,
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        include_summary: bool = False,
        response_kwargs: Optional[Dict] = None,
        similarity_cutoff: Optional[float] = None,
        use_async: bool = True,
        recursive: bool = False,
        streaming: bool = False,
        doc_ids: Optional[List[str]] = None,
    ) -> None:
        """Validate the inputs and store the query configuration.

        Args:
            index_struct (IS): Index struct to query. It is passed to
                `_validate_index_struct` before being stored.
            llm_predictor (Optional[LLMPredictor]): Falls back to a default
                `LLMPredictor()` when not given.
            prompt_helper (Optional[PromptHelper]): Required in practice; the
                signature allows None but None is rejected below.
            embed_model (Optional[BaseEmbedding]): Falls back to a default
                `OpenAIEmbedding()` when not given.
            docstore (Optional[DocumentStore]): Consulted by
                `_get_text_from_node` when a recursive query is attempted.
            query_runner (Optional[BaseQueryRunner]): Runs the sub-query when a
                node resolves to another index struct.
            required_keywords (Optional[List[str]]): Keywords a node must
                contain to be used.
            exclude_keywords (Optional[List[str]]): Keywords a node must not
                contain to be used.
            response_mode (ResponseMode): Normalized through `ResponseMode(...)`.
            text_qa_template (Optional[QuestionAnswerPrompt]): Falls back to
                `DEFAULT_TEXT_QA_PROMPT`.
            refine_template (Optional[RefinePrompt]): Falls back to
                `DEFAULT_REFINE_PROMPT_SEL`.
            include_summary (bool): Whether `query`/`aquery` also refine the
                answer with the index struct's own text.
            response_kwargs (Optional[Dict]): Extra keyword arguments forwarded
                to the response builder when generating a response.
            similarity_cutoff (Optional[float]): Nodes whose tracked similarity
                is below this value are filtered out.
            use_async (bool): Forwarded to the `ResponseBuilder`.
            recursive (bool): Whether nodes may be answered by sub-queries.
            streaming (bool): Forwarded to the `ResponseBuilder`.
            doc_ids (Optional[List[str]]): Stored as `_doc_ids`; not read by
                this base class.

        Raises:
            ValueError: If `index_struct` or `prompt_helper` is None.

        """
        if index_struct is None:
            raise ValueError("index_struct must be provided.")
        self._validate_index_struct(index_struct)
        self._index_struct = index_struct

        # Fall back to default collaborators when none are supplied.
        self._llm_predictor = llm_predictor if llm_predictor else LLMPredictor()
        # NOTE: the embed_model isn't used in all indices
        self._embed_model = embed_model if embed_model else OpenAIEmbedding()

        # TODO: make this a required param
        if prompt_helper is None:
            raise ValueError("prompt_helper must be provided.")
        self._prompt_helper = cast(PromptHelper, prompt_helper)

        self._response_mode = ResponseMode(response_mode)

        # Prompt templates, with the library defaults as fallback.
        self.text_qa_template = (
            text_qa_template if text_qa_template else DEFAULT_TEXT_QA_PROMPT
        )
        self.refine_template = (
            refine_template if refine_template else DEFAULT_REFINE_PROMPT_SEL
        )

        self.response_builder = ResponseBuilder(
            self._prompt_helper,
            self._llm_predictor,
            self.text_qa_template,
            self.refine_template,
            use_async=use_async,
            streaming=streaming,
        )

        # Plain configuration values, stored as given.
        self._docstore = docstore
        self._query_runner = query_runner
        self._required_keywords = required_keywords
        self._exclude_keywords = exclude_keywords
        self._include_summary = include_summary
        self._response_kwargs = response_kwargs if response_kwargs else {}
        self._use_async = use_async
        self.similarity_cutoff = similarity_cutoff
        self._recursive = recursive
        self._streaming = streaming
        self._doc_ids = doc_ids

    def _should_use_node(
        self, node: Node, similarity_tracker: Optional[SimilarityTracker] = None
    ) -> bool:
        """Decide whether ``node`` passes the keyword and similarity filters.

        The node text is split into word tokens (runs of regex word characters)
        and keywords are compared against those tokens exactly, so matching is
        case-sensitive and a keyword containing a non-word character (such as a
        space) can never equal a token.

        Args:
            node (Node): Candidate node.
            similarity_tracker (Optional[SimilarityTracker]): Source of the
                node's similarity score. The cutoff is applied only when both
                this tracker and ``self.similarity_cutoff`` are set.

        Returns:
            bool: False if a required keyword is missing, an excluded keyword
            is present, or the similarity check fails; True otherwise.

        """
        # Tokenize once into a set so each keyword lookup is a hash lookup
        # rather than a linear scan over the token list.
        words = set(re.findall(r"\w+", node.get_text()))

        required = self._required_keywords
        if required is not None and not words.issuperset(required):
            return False

        excluded = self._exclude_keywords
        if excluded is not None and not words.isdisjoint(excluded):
            return False

        if similarity_tracker is not None and self.similarity_cutoff is not None:
            similarity = similarity_tracker.find(node)
            # A node the tracker has no score for is rejected along with
            # nodes that score below the cutoff.
            if similarity is None or similarity < self.similarity_cutoff:
                return False

        return True

    def _get_text_from_node(
        self,
        query_bundle: QueryBundle,
        node: Node,
        level: Optional[int] = None,
    ) -> Tuple[TextChunk, Optional[Response]]:
        """Resolve a node to the text chunk that feeds the response builder.

        When the node's referenced document is itself an index struct (and
        recursion is enabled), that index is queried and its answer becomes the
        chunk. Otherwise the node's own text is used.

        Args:
            query_bundle (QueryBundle): Query forwarded to a sub-query, if any.
            node (Node): Node to resolve.
            level (Optional[int]): Only used to prefix the debug log message.

        Returns:
            Tuple[TextChunk, Optional[Response]]: The chunk, plus the sub-query
            response when one was run (None otherwise).

        """
        level_str = f"[Level {level}]" if level is not None else ""
        fmt_text_chunk = truncate_text(node.get_text(), 50)
        logging.debug(f">{level_str} Searching in chunk: {fmt_text_chunk}")

        # A recursive query is attempted only when recursion was requested, a
        # query runner and a docstore are available, and the node references a
        # document. To never recurse, leave _query_runner as None.
        if (
            self._recursive
            and self._query_runner is not None
            and node.ref_doc_id is not None
            and self._docstore is not None
        ):
            ref_doc = self._docstore.get_document(node.ref_doc_id, raise_error=False)
            # NOTE: old version of the docstore contain both documents and
            # index_struct, whereas new versions of the docstore only contain
            # the index struct
            if isinstance(ref_doc, IndexStruct):
                runner = cast(BaseQueryRunner, self._query_runner)
                sub_response = runner.query(query_bundle, ref_doc)
                return TextChunk(str(sub_response), is_answer=True), sub_response

        return TextChunk(node.get_text()), None

    @property
    def index_struct(self) -> IS:
        """Return the index struct stored on this query object."""
        return self._index_struct

    def _validate_index_struct(self, index_struct: IS) -> None:
        """Validate the index struct.

        Called from `__init__` before the index struct is stored. This base
        implementation accepts anything; presumably subclasses override it to
        reject index structs they cannot query (confirm against subclasses).

        Args:
            index_struct (IS): The index struct passed to the constructor.

        """
        pass

    def _give_response_for_nodes(self, query_str: str) -> RESPONSE_TEXT_TYPE:
        """Ask the response builder to answer ``query_str``.

        The configured response mode and the extra ``response_kwargs`` are
        forwarded to the builder's ``get_response``; its result is returned
        unchanged.
        """
        builder = self.response_builder
        return builder.get_response(
            query_str, mode=self._response_mode, **self._response_kwargs
        )

    async def _agive_response_for_nodes(self, query_str: str) -> RESPONSE_TEXT_TYPE:
        """Asynchronously ask the response builder to answer ``query_str``.

        The configured response mode and the extra ``response_kwargs`` are
        forwarded to the builder's ``aget_response``; the awaited result is
        returned unchanged.
        """
        builder = self.response_builder
        return await builder.aget_response(
            query_str, mode=self._response_mode, **self._response_kwargs
        )

    def get_nodes_and_similarities_for_response(
        self, query_bundle: QueryBundle
    ) -> List[Tuple[Node, Optional[float]]]:
        """Retrieve the nodes for a query, filter them, and pair each with a score.

        A fresh SimilarityTracker is handed to `_get_nodes_for_response`, the
        returned nodes are filtered with `_should_use_node`, and the survivors
        are zipped with the tracker's values.

        Returns:
            List[Tuple[Node, Optional[float]]]: One tuple per kept node. The
            first element is the node; the second is the value the tracker
            holds for it, or None when not applicable.

        """
        tracker = SimilarityTracker()
        candidates = self._get_nodes_for_response(
            query_bundle, similarity_tracker=tracker
        )

        def passes_filters(candidate: Node) -> bool:
            return self._should_use_node(candidate, tracker)

        kept = list(filter(passes_filters, candidates))

        # TODO: create a `display` method to allow subclasses to print the Node
        return tracker.get_zipped_nodes(kept)

    @abstractmethod
    def _get_nodes_for_response(
        self,
        query_bundle: QueryBundle,
        similarity_tracker: Optional[SimilarityTracker] = None,
    ) -> List[Node]:
        """Get nodes for response.

        Retrieval hook with no body here; subclasses are expected to override
        it. `get_nodes_and_similarities_for_response` calls it with a fresh
        SimilarityTracker and afterwards reads per-node values back from that
        tracker, so implementations that compute similarities presumably
        record them there (confirm against subclasses).

        Args:
            query_bundle (QueryBundle): The query to retrieve nodes for.
            similarity_tracker (Optional[SimilarityTracker]): Tracker supplied
                by the caller, if any.

        Returns:
            List[Node]: Candidate nodes, before keyword/similarity filtering.

        """

    def _get_extra_info_for_response(
        self,
        nodes: List[Node],
    ) -> Optional[Dict[str, Any]]:
        """Get extra info for response.

        `_prepare_response_output` passes the nodes used for the answer and
        attaches the return value to the response as `extra_info`. This base
        implementation ignores `nodes` and returns None.

        Args:
            nodes (List[Node]): Nodes that were used to build the response.

        Returns:
            Optional[Dict[str, Any]]: Always None here.

        """
        return None

    def _prepare_response_builder(
        self,
        response_builder: ResponseBuilder,
        query_bundle: QueryBundle,
        tuples: List[Tuple[Node, Optional[float]]],
    ) -> None:
        """Reset the builder and load it with sources and text for this query.

        For every (node, similarity) pair the node is resolved to a text chunk
        via `_get_text_from_node`, the node is registered as a source, any
        source nodes of a sub-query response are registered too, and the chunk
        is appended to the builder.

        Args:
            response_builder (ResponseBuilder): Builder to reset and fill.
            query_bundle (QueryBundle): Query used for sub-queries, if any.
            tuples (List[Tuple[Node, Optional[float]]]): Nodes with their
                similarity values.

        """
        response_builder.reset()
        for node, similarity in tuples:
            chunk, sub_response = self._get_text_from_node(query_bundle, node)
            response_builder.add_node_as_source(node, similarity=similarity)
            # these are source nodes from within this node (when it's an index)
            nested_sources = [] if sub_response is None else sub_response.source_nodes
            for nested in nested_sources:
                response_builder.add_source_node(nested)
            response_builder.add_text_chunks([chunk])

    def _prepare_response_output(
        self,
        response_str: Optional[RESPONSE_TEXT_TYPE],
        tuples: List[Tuple[Node, Optional[float]]],
    ) -> RESPONSE_TYPE:
        """Wrap the builder's output in a Response or StreamingResponse.

        Args:
            response_str (Optional[RESPONSE_TEXT_TYPE]): The generated answer:
                a string, a generator (streaming), or None when no text was
                generated.
            tuples (List[Tuple[Node, Optional[float]]]): Nodes (with similarity
                values) used for the answer; the nodes are passed to
                `_get_extra_info_for_response`.

        Returns:
            RESPONSE_TYPE: A `Response` for None/str input, a
            `StreamingResponse` for generator input. Both carry the builder's
            current sources and the extra info.

        Raises:
            ValueError: If `response_str` is neither None, a str, nor a
                generator.

        """
        response_extra_info = self._get_extra_info_for_response(
            [node for node, _ in tuples]
        )

        # None (no text generated) and plain strings both map to Response.
        if response_str is None or isinstance(response_str, str):
            return Response(
                response_str,
                source_nodes=self.response_builder.get_sources(),
                extra_info=response_extra_info,
            )

        # None was already handled above, so only the generator check remains.
        if isinstance(response_str, Generator):
            return StreamingResponse(
                response_str,
                source_nodes=self.response_builder.get_sources(),
                extra_info=response_extra_info,
            )

        raise ValueError("Response must be a string or a generator.")

    def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Retrieve nodes, load the response builder, and synthesize an answer.

        No text is generated when the response mode is `ResponseMode.NO_TEXT`;
        the response is then built from a None answer.
        """
        # TODO: remove _query and just use query
        pairs = self.get_nodes_and_similarities_for_response(query_bundle)

        # prepare response builder
        self._prepare_response_builder(self.response_builder, query_bundle, pairs)

        response_str: Optional[RESPONSE_TEXT_TYPE] = None
        if self._response_mode != ResponseMode.NO_TEXT:
            response_str = self._give_response_for_nodes(query_bundle.query_str)

        return self._prepare_response_output(response_str, pairs)

    async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Asynchronous counterpart of `_query`.

        Node retrieval and builder preparation run synchronously; only the
        answer generation is awaited. No text is generated when the response
        mode is `ResponseMode.NO_TEXT`.
        """
        # TODO: remove _query and just use query
        pairs = self.get_nodes_and_similarities_for_response(query_bundle)

        # prepare response builder
        self._prepare_response_builder(self.response_builder, query_bundle, pairs)

        response_str: Optional[RESPONSE_TEXT_TYPE] = None
        if self._response_mode != ResponseMode.NO_TEXT:
            response_str = await self._agive_response_for_nodes(query_bundle.query_str)

        return self._prepare_response_output(response_str, pairs)

    @llm_token_counter("query")
    def query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query, optionally refining the answer with the index summary.

        The base answer comes from `_query`. When `include_summary` is set and
        the response mode is not `ResponseMode.NO_TEXT`, a second response
        builder seeded with the index struct's own text refines that answer,
        and the refined result replaces the one on the response object.

        Args:
            query_bundle (QueryBundle): The query to answer.

        Returns:
            RESPONSE_TYPE: The (possibly refined) `Response` or
            `StreamingResponse`.

        Raises:
            ValueError: If `_query` returned neither a `Response` nor a
                `StreamingResponse` while a summary refinement was requested.

        """
        response = self._query(query_bundle)
        # if include_summary is True, then include summary text in answer
        # summary text is set through `set_text` on the underlying index.
        # TODO: refactor response builder to be in the __init__
        if self._response_mode != ResponseMode.NO_TEXT and self._include_summary:
            response_builder = ResponseBuilder(
                self._prompt_helper,
                self._llm_predictor,
                self.text_qa_template,
                self.refine_template,
                texts=[TextChunk(self._index_struct.get_text())],
                streaming=self._streaming,
            )
            if isinstance(response, Response):
                # NOTE: use create and refine for now (default response mode)
                response_str = response_builder.get_response(
                    query_bundle.query_str,
                    mode=self._response_mode,
                    prev_response=response.response,
                )
                response.response = cast(str, response_str)
            elif isinstance(response, StreamingResponse):
                # The refine step needs the text of the first answer. str() on
                # the generator would only produce its repr
                # ("<generator object ...>"), so drain the stream into a string.
                # The drained generator is replaced on the response below.
                prev_response = "".join(
                    str(token) for token in (response.response_gen or ())
                )
                response_gen = response_builder.get_response(
                    query_bundle.query_str,
                    mode=self._response_mode,
                    prev_response=prev_response,
                )
                response.response_gen = cast(Generator, response_gen)
            else:
                raise ValueError("Response must be a string or a generator.")

        return response

    @llm_token_counter("query")
    async def aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
        """Answer a query asynchronously, optionally refining with the summary.

        The base answer comes from `_aquery`. When `include_summary` is set and
        the response mode is not `ResponseMode.NO_TEXT`, a second response
        builder seeded with the index struct's own text refines that answer,
        and the refined result replaces the one on the response object.

        Args:
            query_bundle (QueryBundle): The query to answer.

        Returns:
            RESPONSE_TYPE: The (possibly refined) `Response` or
            `StreamingResponse`.

        Raises:
            ValueError: If `_aquery` returned neither a `Response` nor a
                `StreamingResponse` while a summary refinement was requested.

        """
        response = await self._aquery(query_bundle)
        # if include_summary is True, then include summary text in answer
        # summary text is set through `set_text` on the underlying index.
        # TODO: refactor response builder to be in the __init__
        if self._response_mode != ResponseMode.NO_TEXT and self._include_summary:
            response_builder = ResponseBuilder(
                self._prompt_helper,
                self._llm_predictor,
                self.text_qa_template,
                self.refine_template,
                texts=[TextChunk(self._index_struct.get_text())],
                streaming=self._streaming,
            )
            if isinstance(response, Response):
                # NOTE: use create and refine for now (default response mode)
                response_str = await response_builder.aget_response(
                    query_bundle.query_str,
                    mode=self._response_mode,
                    prev_response=response.response,
                )
                response.response = cast(str, response_str)
            elif isinstance(response, StreamingResponse):
                # The refine step needs the text of the first answer. str() on
                # the generator would only produce its repr
                # ("<generator object ...>"), so drain the stream into a string.
                # The drained generator is replaced on the response below.
                prev_response = "".join(
                    str(token) for token in (response.response_gen or ())
                )
                response_gen = await response_builder.aget_response(
                    query_bundle.query_str,
                    mode=self._response_mode,
                    prev_response=prev_response,
                )
                response.response_gen = cast(Generator, response_gen)
            else:
                raise ValueError("Response must be a string or a generator.")

        return response