"""Common classes/functions for tree index operations.""" import asyncio import logging from typing import Dict, List, Optional, Sequence, Tuple from gpt_index.async_utils import run_async_tasks from gpt_index.data_structs.data_structs_v2 import IndexGraph from gpt_index.data_structs.node_v2 import Node from gpt_index.docstore import DocumentStore from gpt_index.indices.service_context import ServiceContext from gpt_index.indices.utils import get_sorted_node_list, truncate_text from gpt_index.prompts.prompts import SummaryPrompt logger = logging.getLogger(__name__) class GPTTreeIndexBuilder: """GPT tree index builder. Helper class to build the tree-structured index, or to synthesize an answer. """ def __init__( self, num_children: int, summary_prompt: SummaryPrompt, service_context: ServiceContext, docstore: Optional[DocumentStore] = None, use_async: bool = False, ) -> None: """Initialize with params.""" if num_children < 2: raise ValueError("Invalid number of children.") self.num_children = num_children self.summary_prompt = summary_prompt self._service_context = service_context self._use_async = use_async self._docstore = docstore or DocumentStore() @property def docstore(self) -> DocumentStore: """Return docstore.""" return self._docstore def build_from_nodes( self, nodes: Sequence[Node], build_tree: bool = True, ) -> IndexGraph: """Build from text. Returns: IndexGraph: graph object consisting of all_nodes, root_nodes """ index_graph = IndexGraph() for node in nodes: index_graph.insert(node) if build_tree: return self.build_index_from_nodes( index_graph, index_graph.all_nodes, index_graph.all_nodes, level=0 ) else: return index_graph def _prepare_node_and_text_chunks( self, cur_node_ids: Dict[int, str] ) -> Tuple[List[int], List[List[Node]], List[str]]: """Prepare node and text chunks.""" cur_nodes = { index: self._docstore.get_node(node_id) for index, node_id in cur_node_ids.items() } cur_node_list = get_sorted_node_list(cur_nodes) logger.info( f"> Building index from nodes: {len(cur_nodes) // self.num_children} chunks" ) indices, cur_nodes_chunks, text_chunks = [], [], [] for i in range(0, len(cur_node_list), self.num_children): cur_nodes_chunk = cur_node_list[i : i + self.num_children] text_chunk = self._service_context.prompt_helper.get_text_from_nodes( cur_nodes_chunk, prompt=self.summary_prompt ) indices.append(i) cur_nodes_chunks.append(cur_nodes_chunk) text_chunks.append(text_chunk) return indices, cur_nodes_chunks, text_chunks def _construct_parent_nodes( self, index_graph: IndexGraph, indices: List[int], cur_nodes_chunks: List[List[Node]], summaries: List[str], ) -> Dict[int, str]: """Construct parent nodes. Save nodes to docstore. """ new_node_dict = {} for i, cur_nodes_chunk, new_summary in zip( indices, cur_nodes_chunks, summaries ): logger.debug( f"> {i}/{len(cur_nodes_chunk)}, " f"summary: {truncate_text(new_summary, 50)}" ) new_node = Node( text=new_summary, ) index_graph.insert(new_node, children_nodes=cur_nodes_chunk) index = index_graph.get_index(new_node) new_node_dict[index] = new_node.get_doc_id() self._docstore.add_documents([new_node], allow_update=False) return new_node_dict def build_index_from_nodes( self, index_graph: IndexGraph, cur_node_ids: Dict[int, str], all_node_ids: Dict[int, str], level: int = 0, ) -> IndexGraph: """Consolidates chunks recursively, in a bottoms-up fashion.""" if len(cur_node_ids) <= self.num_children: index_graph.root_nodes = cur_node_ids return index_graph indices, cur_nodes_chunks, text_chunks = self._prepare_node_and_text_chunks( cur_node_ids ) if self._use_async: tasks = [ self._service_context.llm_predictor.apredict( self.summary_prompt, context_str=text_chunk ) for text_chunk in text_chunks ] outputs: List[Tuple[str, str]] = run_async_tasks(tasks) summaries = [output[0] for output in outputs] else: summaries = [ self._service_context.llm_predictor.predict( self.summary_prompt, context_str=text_chunk )[0] for text_chunk in text_chunks ] self._service_context.llama_logger.add_log( {"summaries": summaries, "level": level} ) new_node_dict = self._construct_parent_nodes( index_graph, indices, cur_nodes_chunks, summaries ) all_node_ids.update(new_node_dict) index_graph.root_nodes = new_node_dict if len(new_node_dict) <= self.num_children: return index_graph else: return self.build_index_from_nodes( index_graph, new_node_dict, all_node_ids, level=level + 1 ) async def abuild_index_from_nodes( self, index_graph: IndexGraph, cur_node_ids: Dict[int, str], all_node_ids: Dict[int, str], level: int = 0, ) -> IndexGraph: """Consolidates chunks recursively, in a bottoms-up fashion.""" if len(cur_node_ids) <= self.num_children: index_graph.root_nodes = cur_node_ids return index_graph indices, cur_nodes_chunks, text_chunks = self._prepare_node_and_text_chunks( cur_node_ids ) tasks = [ self._service_context.llm_predictor.apredict( self.summary_prompt, context_str=text_chunk ) for text_chunk in text_chunks ] outputs: List[Tuple[str, str]] = await asyncio.gather(*tasks) summaries = [output[0] for output in outputs] self._service_context.llama_logger.add_log( {"summaries": summaries, "level": level} ) new_node_dict = self._construct_parent_nodes( index_graph, indices, cur_nodes_chunks, summaries ) all_node_ids.update(new_node_dict) index_graph.root_nodes = new_node_dict if len(new_node_dict) <= self.num_children: return index_graph else: return await self.abuild_index_from_nodes( index_graph, new_node_dict, all_node_ids, level=level + 1 )