File size: 8,750 Bytes
35b22df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""Composability graphs."""

import json
from typing import Any, Dict, List, Optional, Type, Union

from gpt_index.data_structs.data_structs import IndexStruct
from gpt_index.data_structs.struct_type import IndexStructType
from gpt_index.docstore import DocumentStore
from gpt_index.embeddings.base import BaseEmbedding
from gpt_index.embeddings.openai import OpenAIEmbedding
from gpt_index.indices.base import BaseGPTIndex
from gpt_index.indices.empty.base import GPTEmptyIndex
from gpt_index.indices.keyword_table.base import GPTKeywordTableIndex
from gpt_index.indices.knowledge_graph.base import GPTKnowledgeGraphIndex
from gpt_index.indices.list.base import GPTListIndex
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.query.query_runner import QueryRunner
from gpt_index.indices.query.schema import QueryBundle, QueryConfig
from gpt_index.indices.registry import IndexRegistry
from gpt_index.indices.struct_store.sql import GPTSQLStructStoreIndex
from gpt_index.indices.tree.base import GPTTreeIndex
from gpt_index.indices.vector_store.base import GPTVectorStoreIndex
from gpt_index.indices.vector_store.vector_indices import (
    GPTChromaIndex,
    GPTFaissIndex,
    GPTPineconeIndex,
    GPTQdrantIndex,
    GPTSimpleVectorIndex,
    GPTWeaviateIndex,
)
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.response.schema import Response

# TMP: refactor query config type
QUERY_CONFIG_TYPE = Union[Dict, QueryConfig]


# this is a map from type to outer index class
# we extract the type_to_struct and type_to_query
# fields from the index class
DEFAULT_INDEX_REGISTRY_MAP: Dict[IndexStructType, Type[BaseGPTIndex]] = {
    IndexStructType.TREE: GPTTreeIndex,
    IndexStructType.LIST: GPTListIndex,
    IndexStructType.KEYWORD_TABLE: GPTKeywordTableIndex,
    IndexStructType.SIMPLE_DICT: GPTSimpleVectorIndex,
    IndexStructType.DICT: GPTFaissIndex,
    IndexStructType.WEAVIATE: GPTWeaviateIndex,
    IndexStructType.PINECONE: GPTPineconeIndex,
    IndexStructType.QDRANT: GPTQdrantIndex,
    IndexStructType.CHROMA: GPTChromaIndex,
    IndexStructType.VECTOR_STORE: GPTVectorStoreIndex,
    IndexStructType.SQL: GPTSQLStructStoreIndex,
    IndexStructType.KG: GPTKnowledgeGraphIndex,
    IndexStructType.EMPTY: GPTEmptyIndex,
}


def _get_default_index_registry() -> IndexRegistry:
    """Get default index registry."""
    index_registry = IndexRegistry()
    for index_type, index_class in DEFAULT_INDEX_REGISTRY_MAP.items():
        index_registry.type_to_struct[index_type] = index_class.index_struct_cls
        index_registry.type_to_query[index_type] = index_class.get_query_map()
    return index_registry


def _safe_get_index_struct(
    docstore: DocumentStore, index_struct_id: str
) -> IndexStruct:
    """Try get index struct."""
    index_struct = docstore.get_document(index_struct_id)
    if not isinstance(index_struct, IndexStruct):
        raise ValueError("Invalid `index_struct_id` - must be an IndexStruct")
    return index_struct


class ComposableGraph:
    """Composable graph."""

    def __init__(
        self,
        docstore: DocumentStore,
        index_registry: IndexRegistry,
        index_struct: IndexStruct,
        llm_predictor: Optional[LLMPredictor] = None,
        prompt_helper: Optional[PromptHelper] = None,
        embed_model: Optional[BaseEmbedding] = None,
        chunk_size_limit: Optional[int] = None,
    ) -> None:
        """Init params."""
        self._docstore = docstore
        self._index_registry = index_registry
        # this represents the "root" index struct
        self._index_struct = index_struct

        self._llm_predictor = llm_predictor or LLMPredictor()
        self._prompt_helper = prompt_helper or PromptHelper.from_llm_predictor(
            self._llm_predictor, chunk_size_limit=chunk_size_limit
        )
        self._embed_model = embed_model or OpenAIEmbedding()

    @classmethod
    def build_from_index(self, index: BaseGPTIndex) -> "ComposableGraph":
        """Build from index."""
        return ComposableGraph(
            index.docstore,
            index.index_registry,
            # this represents the "root" index struct
            index.index_struct,
            llm_predictor=index.llm_predictor,
            prompt_helper=index.prompt_helper,
            embed_model=index.embed_model,
        )

    def query(
        self,
        query_str: Union[str, QueryBundle],
        query_configs: Optional[List[QUERY_CONFIG_TYPE]] = None,
        llm_predictor: Optional[LLMPredictor] = None,
    ) -> Response:
        """Query the index."""
        # go over all the indices and create a registry
        llm_predictor = llm_predictor or self._llm_predictor
        query_runner = QueryRunner(
            llm_predictor,
            self._prompt_helper,
            self._embed_model,
            self._docstore,
            self._index_registry,
            query_configs=query_configs,
            recursive=True,
        )
        return query_runner.query(query_str, self._index_struct)

    async def aquery(
        self,
        query_str: Union[str, QueryBundle],
        query_configs: Optional[List[QUERY_CONFIG_TYPE]] = None,
        llm_predictor: Optional[LLMPredictor] = None,
    ) -> Response:
        """Query the index."""
        # go over all the indices and create a registry
        llm_predictor = llm_predictor or self._llm_predictor
        query_runner = QueryRunner(
            llm_predictor,
            self._prompt_helper,
            self._embed_model,
            self._docstore,
            self._index_registry,
            query_configs=query_configs,
            recursive=True,
        )
        return await query_runner.aquery(query_str, self._index_struct)

    def get_index(
        self, index_struct_id: str, index_cls: Type[BaseGPTIndex], **kwargs: Any
    ) -> BaseGPTIndex:
        """Get index."""
        index_struct = _safe_get_index_struct(self._docstore, index_struct_id)
        return index_cls(
            index_struct=index_struct,
            docstore=self._docstore,
            index_registry=self._index_registry,
            **kwargs
        )

    @classmethod
    def load_from_string(cls, index_string: str, **kwargs: Any) -> "ComposableGraph":
        """Load index from string (in JSON-format).

        This method loads the index from a JSON string. The index data
        structure itself is preserved completely. If the index is defined over
        subindices, those subindices will also be preserved (and subindices of
        those subindices, etc.).

        Args:
            save_path (str): The save_path of the file.

        Returns:
            BaseGPTIndex: The loaded index.

        """
        result_dict = json.loads(index_string)
        # TODO: this is hardcoded for now, allow it to be specified by the user
        index_registry = _get_default_index_registry()
        docstore = DocumentStore.load_from_dict(
            result_dict["docstore"], index_registry.type_to_struct
        )
        index_struct = _safe_get_index_struct(docstore, result_dict["index_struct_id"])
        return cls(docstore, index_registry, index_struct, **kwargs)

    @classmethod
    def load_from_disk(cls, save_path: str, **kwargs: Any) -> "ComposableGraph":
        """Load index from disk.

        This method loads the index from a JSON file stored on disk. The index data
        structure itself is preserved completely. If the index is defined over
        subindices, those subindices will also be preserved (and subindices of
        those subindices, etc.).

        Args:
            save_path (str): The save_path of the file.

        Returns:
            BaseGPTIndex: The loaded index.

        """
        with open(save_path, "r") as f:
            file_contents = f.read()
            return cls.load_from_string(file_contents, **kwargs)

    def save_to_string(self, **save_kwargs: Any) -> str:
        """Save to string.

        This method stores the index into a JSON file stored on disk.

        Args:
            save_path (str): The save_path of the file.

        """
        out_dict: Dict[str, Any] = {
            "index_struct_id": self._index_struct.get_doc_id(),
            "docstore": self._docstore.serialize_to_dict(),
        }
        return json.dumps(out_dict)

    def save_to_disk(self, save_path: str, **save_kwargs: Any) -> None:
        """Save to file.

        This method stores the index into a JSON file stored on disk.

        Args:
            save_path (str): The save_path of the file.

        """
        index_string = self.save_to_string(**save_kwargs)
        with open(save_path, "w") as f:
            f.write(index_string)