"""GPT Tree Index inserter.""" | |
from typing import Optional | |
from gpt_index.data_structs.data_structs import IndexGraph, Node | |
from gpt_index.indices.prompt_helper import PromptHelper | |
from gpt_index.indices.utils import extract_numbers_given_response, get_sorted_node_list | |
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor | |
from gpt_index.langchain_helpers.text_splitter import TextSplitter | |
from gpt_index.prompts.base import Prompt | |
from gpt_index.prompts.default_prompts import ( | |
DEFAULT_INSERT_PROMPT, | |
DEFAULT_SUMMARY_PROMPT, | |
) | |
from gpt_index.schema import BaseDocument | |


class GPTIndexInserter:
    """LlamaIndex inserter."""

    def __init__(
        self,
        index_graph: IndexGraph,
        llm_predictor: LLMPredictor,
        prompt_helper: PromptHelper,
        text_splitter: TextSplitter,
        num_children: int = 10,
        insert_prompt: Prompt = DEFAULT_INSERT_PROMPT,
        summary_prompt: Prompt = DEFAULT_SUMMARY_PROMPT,
    ) -> None:
        """Initialize with params."""
        if num_children < 2:
            raise ValueError("Invalid number of children; must be at least 2.")
        self.num_children = num_children
        self.summary_prompt = summary_prompt
        self.insert_prompt = insert_prompt
        self.index_graph = index_graph
        self._llm_predictor = llm_predictor
        self._prompt_helper = prompt_helper
        self._text_splitter = text_splitter

    def _insert_under_parent_and_consolidate(
        self, text_chunk: str, doc: BaseDocument, parent_node: Optional[Node]
    ) -> None:
        """Insert node under parent and consolidate.

        Consolidation will happen by dividing up child nodes, and creating a new
        intermediate layer of nodes.
        """
        # perform insertion
        text_node = Node(
            text=text_chunk,
            index=self.index_graph.size,
            ref_doc_id=doc.get_doc_id(),
            embedding=doc.embedding,
            extra_info=doc.extra_info,
        )
        self.index_graph.insert_under_parent(text_node, parent_node)

        # if under num_children limit, then we're fine
        if len(self.index_graph.get_children(parent_node)) <= self.num_children:
            return
        else:
            # perform consolidation
            cur_graph_nodes = self.index_graph.get_children(parent_node)
            cur_graph_node_list = get_sorted_node_list(cur_graph_nodes)
            # this layer is all leaf nodes, consolidate and split leaf nodes
            cur_node_index = self.index_graph.size
            # consolidate and split leaf nodes in half
            # TODO: do better splitting (with a GPT prompt etc.)
            half1 = cur_graph_node_list[: len(cur_graph_nodes) // 2]
            half2 = cur_graph_node_list[len(cur_graph_nodes) // 2 :]

            text_chunk1 = self._prompt_helper.get_text_from_nodes(
                half1, prompt=self.summary_prompt
            )
            summary1, _ = self._llm_predictor.predict(
                self.summary_prompt, context_str=text_chunk1
            )
            node1 = Node(
                text=summary1,
                index=cur_node_index,
                child_indices={n.index for n in half1},
            )

            text_chunk2 = self._prompt_helper.get_text_from_nodes(
                half2, prompt=self.summary_prompt
            )
            summary2, _ = self._llm_predictor.predict(
                self.summary_prompt, context_str=text_chunk2
            )
            node2 = Node(
                text=summary2,
                index=cur_node_index + 1,
                child_indices={n.index for n in half2},
            )

            # insert half1 and half2 as new children of parent_node
            # first remove child indices from parent node
            if parent_node is not None:
                parent_node.child_indices = set()
            else:
                self.index_graph.root_nodes = {}
            self.index_graph.insert_under_parent(node1, parent_node)
            self.index_graph.insert_under_parent(node2, parent_node)

    def _insert_node(
        self, text_chunk: str, doc: BaseDocument, parent_node: Optional[Node]
    ) -> None:
        """Insert node."""
        cur_graph_nodes = self.index_graph.get_children(parent_node)
        cur_graph_node_list = get_sorted_node_list(cur_graph_nodes)
        # if cur_graph_nodes is empty (start with empty graph), then insert under
        # parent (insert new root node)
        if len(cur_graph_nodes) == 0:
            self._insert_under_parent_and_consolidate(text_chunk, doc, parent_node)
        # check if leaf nodes, then just insert under parent
        elif len(cur_graph_node_list[0].child_indices) == 0:
            self._insert_under_parent_and_consolidate(text_chunk, doc, parent_node)
        # else try to find the right summary node to insert under
        else:
            numbered_text = self._prompt_helper.get_numbered_text_from_nodes(
                cur_graph_node_list, prompt=self.insert_prompt
            )
            response, _ = self._llm_predictor.predict(
                self.insert_prompt,
                new_chunk_text=text_chunk,
                num_chunks=len(cur_graph_node_list),
                context_list=numbered_text,
            )
            numbers = extract_numbers_given_response(response)
            if numbers is None or len(numbers) == 0:
                # NOTE: if we can't extract a number, then we just insert under parent
                self._insert_under_parent_and_consolidate(text_chunk, doc, parent_node)
            elif int(numbers[0]) > len(cur_graph_node_list):
                # NOTE: if number is out of range, then we just insert under parent
                self._insert_under_parent_and_consolidate(text_chunk, doc, parent_node)
            else:
                selected_node = cur_graph_node_list[int(numbers[0]) - 1]
                self._insert_node(text_chunk, doc, selected_node)

        # now we need to update summary for parent node, since we
        # need to bubble updated summaries up the tree
        if parent_node is not None:
            # refetch children
            cur_graph_nodes = self.index_graph.get_children(parent_node)
            cur_graph_node_list = get_sorted_node_list(cur_graph_nodes)
            text_chunk = self._prompt_helper.get_text_from_nodes(
                cur_graph_node_list, prompt=self.summary_prompt
            )
            new_summary, _ = self._llm_predictor.predict(
                self.summary_prompt, context_str=text_chunk
            )
            parent_node.text = new_summary

    def insert(self, doc: BaseDocument) -> None:
        """Insert into index_graph."""
        text_chunks = self._text_splitter.split_text(doc.get_text())
        for text_chunk in text_chunks:
            self._insert_node(text_chunk, doc, None)
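

# Example usage (a minimal sketch; how ``index_graph``, ``llm_predictor``,
# ``prompt_helper``, and ``text_splitter`` are constructed is omitted, and the
# variable names below are placeholders, not part of this module):
#
#     inserter = GPTIndexInserter(
#         index_graph=index_graph,
#         llm_predictor=llm_predictor,
#         prompt_helper=prompt_helper,
#         text_splitter=text_splitter,
#         num_children=10,
#     )
#     inserter.insert(new_document)  # ``new_document`` is any BaseDocument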