File size: 6,873 Bytes
8a58cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""GPT Tree Index inserter."""

from typing import Optional

from gpt_index.data_structs.data_structs import IndexGraph, Node
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.utils import extract_numbers_given_response, get_sorted_node_list
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.langchain_helpers.text_splitter import TextSplitter
from gpt_index.prompts.base import Prompt
from gpt_index.prompts.default_prompts import (
    DEFAULT_INSERT_PROMPT,
    DEFAULT_SUMMARY_PROMPT,
)
from gpt_index.schema import BaseDocument


class GPTIndexInserter:
    """Insert new document text into an existing tree-structured index graph.

    Each text chunk is routed from the root of ``index_graph`` down to a layer
    of leaf nodes, using the LLM (via ``insert_prompt``) to pick which summary
    node to descend into. When a parent ends up with more than
    ``num_children`` children, those children are split in half and replaced
    by two freshly summarized intermediate nodes.

    """

    def __init__(
        self,
        index_graph: IndexGraph,
        llm_predictor: LLMPredictor,
        prompt_helper: PromptHelper,
        text_splitter: TextSplitter,
        num_children: int = 10,
        insert_prompt: Prompt = DEFAULT_INSERT_PROMPT,
        summary_prompt: Prompt = DEFAULT_SUMMARY_PROMPT,
    ) -> None:
        """Initialize with params.

        Args:
            index_graph: Graph that new nodes are inserted into (mutated in
                place).
            llm_predictor: Predictor used both to summarize nodes and to
                choose the insertion branch.
            prompt_helper: Helper that renders a list of nodes into prompt
                text.
            text_splitter: Splitter used to cut a document's text into chunks.
            num_children: Maximum number of children a parent may hold before
                its children are consolidated.
            insert_prompt: Prompt used to ask which child a chunk belongs
                under.
            summary_prompt: Prompt used to summarize a group of nodes.

        Raises:
            ValueError: If ``num_children`` is less than 2.

        """
        if num_children < 2:
            raise ValueError("Invalid number of children.")
        self.num_children = num_children
        self.summary_prompt = summary_prompt
        self.insert_prompt = insert_prompt
        self.index_graph = index_graph
        self._llm_predictor = llm_predictor
        self._prompt_helper = prompt_helper
        self._text_splitter = text_splitter

    def _summarize_nodes(self, nodes: list) -> str:
        """Return an LLM-generated summary of ``nodes`` using the summary prompt."""
        context_str = self._prompt_helper.get_text_from_nodes(
            nodes, prompt=self.summary_prompt
        )
        # predict() returns a pair; only the first element (the response text)
        # is used here.
        summary, _ = self._llm_predictor.predict(
            self.summary_prompt, context_str=context_str
        )
        return summary

    def _insert_under_parent_and_consolidate(
        self, text_chunk: str, doc: BaseDocument, parent_node: Optional[Node]
    ) -> None:
        """Insert node under parent and consolidate.

        Consolidation will happen by dividing up child nodes, and creating a new
        intermediate layer of nodes.

        Args:
            text_chunk: Text of the new node.
            doc: Source document; supplies the doc id, embedding and extra
                info stored on the new node.
            parent_node: Node to insert under, or ``None`` to insert at the
                root level of the graph.

        """
        # perform insertion
        text_node = Node(
            text=text_chunk,
            index=self.index_graph.size,
            ref_doc_id=doc.get_doc_id(),
            embedding=doc.embedding,
            extra_info=doc.extra_info,
        )
        self.index_graph.insert_under_parent(text_node, parent_node)

        # if under num_children limit, then we're fine
        children = self.index_graph.get_children(parent_node)
        if len(children) <= self.num_children:
            return

        # perform consolidation: split the children in half and put one
        # summary node on top of each half
        # TODO: do better splitting (with a GPT prompt etc.)
        sorted_children = get_sorted_node_list(children)
        midpoint = len(children) // 2
        halves = (sorted_children[:midpoint], sorted_children[midpoint:])

        # the two summary nodes take the next two free indices in the graph
        first_new_index = self.index_graph.size
        summary_nodes = [
            Node(
                text=self._summarize_nodes(half),
                index=first_new_index + offset,
                child_indices={n.index for n in half},
            )
            for offset, half in enumerate(halves)
        ]

        # detach the old children from the parent (or from the root level),
        # then attach the two summary nodes in their place
        if parent_node is not None:
            parent_node.child_indices = set()
        else:
            self.index_graph.root_nodes = {}
        for summary_node in summary_nodes:
            self.index_graph.insert_under_parent(summary_node, parent_node)

    def _insert_node(
        self, text_chunk: str, doc: BaseDocument, parent_node: Optional[Node]
    ) -> None:
        """Insert node.

        Recursively descends from ``parent_node`` until it reaches a layer
        that is empty or made of leaf nodes, inserts the chunk there, and
        refreshes the summary text of every non-root node on the way back up.

        Args:
            text_chunk: Text to insert.
            doc: Source document the chunk came from.
            parent_node: Node whose children are examined, or ``None`` to
                start from the root level.

        """
        cur_graph_nodes = self.index_graph.get_children(parent_node)
        cur_graph_node_list = get_sorted_node_list(cur_graph_nodes)

        # Insert directly under the parent when the layer is empty (e.g. an
        # empty graph) or when it consists of leaf nodes; only the first
        # sorted child is inspected to decide whether this is a leaf layer.
        if not cur_graph_nodes or not cur_graph_node_list[0].child_indices:
            self._insert_under_parent_and_consolidate(text_chunk, doc, parent_node)
        # else try to find the right summary node to insert under
        else:
            numbered_text = self._prompt_helper.get_numbered_text_from_nodes(
                cur_graph_node_list, prompt=self.insert_prompt
            )
            response, _ = self._llm_predictor.predict(
                self.insert_prompt,
                new_chunk_text=text_chunk,
                num_chunks=len(cur_graph_node_list),
                context_list=numbered_text,
            )
            numbers = extract_numbers_given_response(response)

            selected_node = None
            if numbers:
                choice = int(numbers[0])
                # Choices are 1-based. Reject anything below 1 as well as
                # anything above the child count: without the lower bound a
                # reply of 0 would index [-1] and silently pick the last child.
                if 1 <= choice <= len(cur_graph_node_list):
                    selected_node = cur_graph_node_list[choice - 1]

            if selected_node is None:
                # NOTE: if we can't extract a number, or the number is out of
                # range, then we just insert under parent
                self._insert_under_parent_and_consolidate(text_chunk, doc, parent_node)
            else:
                self._insert_node(text_chunk, doc, selected_node)

        # now we need to update summary for parent node, since we
        # need to bubble updated summaries up the tree
        if parent_node is not None:
            # refetch children, since the insertion above may have changed them
            refreshed_children = get_sorted_node_list(
                self.index_graph.get_children(parent_node)
            )
            parent_node.text = self._summarize_nodes(refreshed_children)

    def insert(self, doc: BaseDocument) -> None:
        """Insert into index_graph.

        The document text is split into chunks with the configured text
        splitter, and each chunk is inserted starting from the root level.

        Args:
            doc: Document to insert.

        """
        for text_chunk in self._text_splitter.split_text(doc.get_text()):
            self._insert_node(text_chunk, doc, None)