File size: 5,496 Bytes
35b22df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Query runner."""

from typing import Any, Dict, List, Optional, Union, cast

from gpt_index.data_structs.data_structs import IndexStruct
from gpt_index.docstore import DocumentStore
from gpt_index.embeddings.base import BaseEmbedding
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.query.base import BaseGPTIndexQuery, BaseQueryRunner
from gpt_index.indices.query.query_transform import BaseQueryTransform
from gpt_index.indices.query.schema import QueryBundle, QueryConfig, QueryMode
from gpt_index.indices.registry import IndexRegistry
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.response.schema import Response

# TMP: refactor query config type
# A query config may be supplied either as a ``QueryConfig`` instance or as a
# plain dict; ``QueryRunner`` turns dicts into ``QueryConfig`` objects via
# ``QueryConfig.from_dict``.
QUERY_CONFIG_TYPE = Union[Dict, QueryConfig]


class QueryRunner(BaseQueryRunner):
    """Tool to take in a query request and perform a query with the right classes.

    Higher-level wrapper over a given query. For each index struct it picks a
    ``QueryConfig`` (matched by index struct id first, then by index struct
    type, else a default config), instantiates the query class registered for
    that type and query mode, and runs the query.

    """

    def __init__(
        self,
        llm_predictor: LLMPredictor,
        prompt_helper: PromptHelper,
        embed_model: BaseEmbedding,
        docstore: DocumentStore,
        index_registry: IndexRegistry,
        query_configs: Optional[List[QUERY_CONFIG_TYPE]] = None,
        query_transform: Optional[BaseQueryTransform] = None,
        recursive: bool = False,
        use_async: bool = False,
    ) -> None:
        """Init params.

        Args:
            llm_predictor: Default LLM predictor handed to each query object
                unless the config's ``query_kwargs`` already provides one.
            prompt_helper: Default prompt helper (same override rule).
            embed_model: Default embedding model (same override rule).
            docstore: Document store passed to every query object.
            index_registry: Registry whose ``type_to_query`` maps an index
                struct type and query mode to a query class.
            query_configs: Optional list of query configs. Each entry may be a
                ``QueryConfig`` or a dict accepted by ``QueryConfig.from_dict``;
                the two forms may be mixed within one list.
            query_transform: Transform applied to raw query strings. Defaults
                to ``BaseQueryTransform()``.
            recursive: Forwarded to every query object.
            use_async: Forwarded to every query object.

        """
        type_to_config_dict: Dict[str, QueryConfig] = {}
        id_to_config_dict: Dict[str, QueryConfig] = {}
        for qc in self._normalize_query_configs(query_configs):
            # A later config for the same type / id overwrites an earlier one.
            type_to_config_dict[qc.index_struct_type] = qc
            if qc.index_struct_id is not None:
                id_to_config_dict[qc.index_struct_id] = qc

        self._type_to_config_dict = type_to_config_dict
        self._id_to_config_dict = id_to_config_dict
        self._llm_predictor = llm_predictor
        self._prompt_helper = prompt_helper
        self._embed_model = embed_model
        self._docstore = docstore
        self._index_registry = index_registry
        self._query_transform = query_transform or BaseQueryTransform()
        self._recursive = recursive
        self._use_async = use_async

    @staticmethod
    def _normalize_query_configs(
        query_configs: Optional[List[QUERY_CONFIG_TYPE]],
    ) -> List[QueryConfig]:
        """Convert every entry to a ``QueryConfig``, checking each one separately.

        Each element is inspected on its own (rather than only the first) so
        that lists mixing dicts and ``QueryConfig`` objects, which the declared
        element type allows, are handled correctly.

        """
        return [
            QueryConfig.from_dict(cast(Dict, qc))
            if isinstance(qc, dict)
            else cast(QueryConfig, qc)
            for qc in (query_configs or [])
        ]

    def _get_query_kwargs(self, config: QueryConfig) -> Dict[str, Any]:
        """Get query kwargs.

        Returns a shallow copy of ``config.query_kwargs`` (the config itself is
        not mutated), with the runner's ``prompt_helper``, ``llm_predictor`` and
        ``embed_model`` filled in for any of those keys that are absent.

        """
        query_kwargs: Dict[str, Any] = dict(config.query_kwargs)
        query_kwargs.setdefault("prompt_helper", self._prompt_helper)
        query_kwargs.setdefault("llm_predictor", self._llm_predictor)
        query_kwargs.setdefault("embed_model", self._embed_model)
        return query_kwargs

    def _get_query_obj(
        self,
        index_struct: IndexStruct,
    ) -> BaseGPTIndexQuery:
        """Build the query object for ``index_struct``.

        Config lookup order: a config registered for this struct's doc id, then
        one registered for its type, otherwise a default ``QueryConfig`` with
        ``QueryMode.DEFAULT``. The query class is looked up in the registry by
        the struct's own type and the chosen config's query mode.

        """
        index_struct_id = index_struct.get_doc_id()
        index_struct_type = index_struct.get_type()
        if index_struct_id in self._id_to_config_dict:
            config = self._id_to_config_dict[index_struct_id]
        elif index_struct_type in self._type_to_config_dict:
            config = self._type_to_config_dict[index_struct_type]
        else:
            config = QueryConfig(
                index_struct_type=index_struct_type, query_mode=QueryMode.DEFAULT
            )
        mode = config.query_mode

        query_cls = self._index_registry.type_to_query[index_struct_type][mode]
        query_kwargs = self._get_query_kwargs(config)
        # ``self`` is always passed as ``query_runner``; ``recursive`` is
        # forwarded separately. NOTE(review): presumably the query object only
        # uses ``query_runner`` for nested structs when ``recursive`` is True —
        # confirm in BaseGPTIndexQuery.
        # NOTE(review): if ``query_kwargs`` itself contains ``query_runner``,
        # ``docstore``, ``recursive`` or ``use_async``, this call raises
        # TypeError (duplicate keyword argument).
        query_obj = query_cls(
            index_struct,
            **query_kwargs,
            query_runner=self,
            docstore=self._docstore,
            recursive=self._recursive,
            use_async=self._use_async,
        )

        return query_obj

    def _to_query_bundle(
        self, query_str_or_bundle: Union[str, QueryBundle]
    ) -> QueryBundle:
        """Apply the query transform to a raw string; pass a bundle through.

        NOTE: Currently, query transform is only run once — an input that is
        already a ``QueryBundle`` is used as-is.
        TODO: Consider refactor to support index-specific query transform

        """
        if isinstance(query_str_or_bundle, str):
            return self._query_transform(query_str_or_bundle)
        return query_str_or_bundle

    def query(
        self,
        query_str_or_bundle: Union[str, QueryBundle],
        index_struct: IndexStruct,
    ) -> Response:
        """Run query synchronously against ``index_struct``."""
        query_bundle = self._to_query_bundle(query_str_or_bundle)
        query_obj = self._get_query_obj(index_struct)
        return query_obj.query(query_bundle)

    async def aquery(
        self,
        query_str_or_bundle: Union[str, QueryBundle],
        index_struct: IndexStruct,
    ) -> Response:
        """Run query asynchronously against ``index_struct``."""
        query_bundle = self._to_query_bundle(query_str_or_bundle)
        query_obj = self._get_query_obj(index_struct)
        return await query_obj.aquery(query_bundle)