"""GPT Tree Index inserter."""
from typing import Optional
from gpt_index.data_structs.data_structs import IndexGraph, Node
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.utils import (
    extract_numbers_given_response,
    get_sorted_node_list,
)
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.langchain_helpers.text_splitter import TextSplitter
from gpt_index.prompts.base import Prompt
from gpt_index.prompts.default_prompts import (
DEFAULT_INSERT_PROMPT,
DEFAULT_SUMMARY_PROMPT,
)
from gpt_index.schema import BaseDocument


class GPTIndexInserter:
    """Inserter for a GPT tree index (IndexGraph)."""

def __init__(
self,
index_graph: IndexGraph,
llm_predictor: LLMPredictor,
prompt_helper: PromptHelper,
text_splitter: TextSplitter,
num_children: int = 10,
insert_prompt: Prompt = DEFAULT_INSERT_PROMPT,
summary_prompt: Prompt = DEFAULT_SUMMARY_PROMPT,
) -> None:
"""Initialize with params."""
        if num_children < 2:
            raise ValueError(
                f"num_children must be at least 2, got {num_children}."
            )
self.num_children = num_children
self.summary_prompt = summary_prompt
self.insert_prompt = insert_prompt
self.index_graph = index_graph
self._llm_predictor = llm_predictor
self._prompt_helper = prompt_helper
self._text_splitter = text_splitter

    def _insert_under_parent_and_consolidate(
self, text_chunk: str, doc: BaseDocument, parent_node: Optional[Node]
) -> None:
"""Insert node under parent and consolidate.
Consolidation will happen by dividing up child nodes, and creating a new
intermediate layer of nodes.
"""
# perform insertion
text_node = Node(
text=text_chunk,
index=self.index_graph.size,
ref_doc_id=doc.get_doc_id(),
embedding=doc.embedding,
extra_info=doc.extra_info,
)
self.index_graph.insert_under_parent(text_node, parent_node)
# if under num_children limit, then we're fine
if len(self.index_graph.get_children(parent_node)) <= self.num_children:
return
else:
# perform consolidation
cur_graph_nodes = self.index_graph.get_children(parent_node)
cur_graph_node_list = get_sorted_node_list(cur_graph_nodes)
# this layer is all leaf nodes, consolidate and split leaf nodes
cur_node_index = self.index_graph.size
# consolidate and split leaf nodes in half
# TODO: do better splitting (with a GPT prompt etc.)
            half1 = cur_graph_node_list[: len(cur_graph_node_list) // 2]
            half2 = cur_graph_node_list[len(cur_graph_node_list) // 2 :]
text_chunk1 = self._prompt_helper.get_text_from_nodes(
half1, prompt=self.summary_prompt
)
summary1, _ = self._llm_predictor.predict(
self.summary_prompt, context_str=text_chunk1
)
node1 = Node(
text=summary1,
index=cur_node_index,
child_indices={n.index for n in half1},
)
text_chunk2 = self._prompt_helper.get_text_from_nodes(
half2, prompt=self.summary_prompt
)
summary2, _ = self._llm_predictor.predict(
self.summary_prompt, context_str=text_chunk2
)
node2 = Node(
text=summary2,
index=cur_node_index + 1,
child_indices={n.index for n in half2},
)
# insert half1 and half2 as new children of parent_node
# first remove child indices from parent node
if parent_node is not None:
parent_node.child_indices = set()
else:
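                # a parent_node of None means these children are at the root
                # level, so clear root_nodes before re-inserting the summaries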
self.index_graph.root_nodes = {}
self.index_graph.insert_under_parent(node1, parent_node)
self.index_graph.insert_under_parent(node2, parent_node)

    def _insert_node(
self, text_chunk: str, doc: BaseDocument, parent_node: Optional[Node]
) -> None:
"""Insert node."""
cur_graph_nodes = self.index_graph.get_children(parent_node)
cur_graph_node_list = get_sorted_node_list(cur_graph_nodes)
# if cur_graph_nodes is empty (start with empty graph), then insert under
# parent (insert new root node)
if len(cur_graph_nodes) == 0:
self._insert_under_parent_and_consolidate(text_chunk, doc, parent_node)
# check if leaf nodes, then just insert under parent
elif len(cur_graph_node_list[0].child_indices) == 0:
self._insert_under_parent_and_consolidate(text_chunk, doc, parent_node)
# else try to find the right summary node to insert under
else:
numbered_text = self._prompt_helper.get_numbered_text_from_nodes(
cur_graph_node_list, prompt=self.insert_prompt
)
response, _ = self._llm_predictor.predict(
self.insert_prompt,
new_chunk_text=text_chunk,
num_chunks=len(cur_graph_node_list),
context_list=numbered_text,
)
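            # the response is expected to contain the 1-based number of the
            # child summary node that the new chunk should be inserted under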
numbers = extract_numbers_given_response(response)
if numbers is None or len(numbers) == 0:
# NOTE: if we can't extract a number, then we just insert under parent
self._insert_under_parent_and_consolidate(text_chunk, doc, parent_node)
elif int(numbers[0]) > len(cur_graph_node_list):
# NOTE: if number is out of range, then we just insert under parent
self._insert_under_parent_and_consolidate(text_chunk, doc, parent_node)
else:
selected_node = cur_graph_node_list[int(numbers[0]) - 1]
self._insert_node(text_chunk, doc, selected_node)
# now we need to update summary for parent node, since we
# need to bubble updated summaries up the tree
if parent_node is not None:
# refetch children
cur_graph_nodes = self.index_graph.get_children(parent_node)
cur_graph_node_list = get_sorted_node_list(cur_graph_nodes)
text_chunk = self._prompt_helper.get_text_from_nodes(
cur_graph_node_list, prompt=self.summary_prompt
)
new_summary, _ = self._llm_predictor.predict(
self.summary_prompt, context_str=text_chunk
)
parent_node.text = new_summary

    def insert(self, doc: BaseDocument) -> None:
        """Insert a document into the index graph, one text chunk at a time."""
text_chunks = self._text_splitter.split_text(doc.get_text())
for text_chunk in text_chunks:
self._insert_node(text_chunk, doc, None)
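

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library. The import paths and
    # constructor arguments below are assumptions that may differ between
    # gpt_index versions, and running this requires an OpenAI API key for
    # the default LLMPredictor.
    from gpt_index.langchain_helpers.text_splitter import TokenTextSplitter
    from gpt_index.readers.schema.base import Document

    inserter = GPTIndexInserter(
        index_graph=IndexGraph(),
        llm_predictor=LLMPredictor(),
        prompt_helper=PromptHelper(
            max_input_size=4096, num_output=256, max_chunk_overlap=20
        ),
        text_splitter=TokenTextSplitter(),
        num_children=10,
    )
    inserter.insert(Document("Some new text to add to the tree index."))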