"""Node postprocessor.""" import re from abc import abstractmethod from typing import Dict, List, Optional, cast from pydantic import BaseModel, Field, validator import logging from gpt_index.indices.query.schema import QueryBundle from gpt_index.indices.service_context import ServiceContext from gpt_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt from gpt_index.docstore import DocumentStore from gpt_index.data_structs.node_v2 import Node, DocumentRelationship from gpt_index.indices.postprocessor.base import BasePostprocessor from gpt_index.indices.query.embedding_utils import SimilarityTracker from gpt_index.indices.response.builder import ResponseBuilder, TextChunk logger = logging.getLogger(__name__) class BaseNodePostprocessor(BasePostprocessor, BaseModel): """Node postprocessor.""" @abstractmethod def postprocess_nodes( self, nodes: List[Node], extra_info: Optional[Dict] = None ) -> List[Node]: """Postprocess nodes.""" class KeywordNodePostprocessor(BaseNodePostprocessor): """Keyword-based Node processor.""" required_keywords: List[str] = Field(default_factory=list) exclude_keywords: List[str] = Field(default_factory=list) def postprocess_nodes( self, nodes: List[Node], extra_info: Optional[Dict] = None ) -> List[Node]: """Postprocess nodes.""" new_nodes = [] for node in nodes: words = re.findall(r"\w+", node.get_text()) should_use_node = True if self.required_keywords is not None: for w in self.required_keywords: if w not in words: should_use_node = False if self.exclude_keywords is not None: for w in self.exclude_keywords: if w in words: should_use_node = False if should_use_node: new_nodes.append(node) return new_nodes class SimilarityPostprocessor(BaseNodePostprocessor): """Similarity-based Node processor.""" similarity_cutoff: float = Field(default=None) def postprocess_nodes( self, nodes: List[Node], extra_info: Optional[Dict] = None ) -> List[Node]: """Postprocess nodes.""" extra_info = extra_info or {} similarity_tracker = extra_info.get("similarity_tracker", None) if similarity_tracker is None: return nodes sim_cutoff_exists = ( similarity_tracker is not None and self.similarity_cutoff is not None ) new_nodes = [] for node in nodes: should_use_node = True if sim_cutoff_exists: similarity = cast(SimilarityTracker, similarity_tracker).find(node) if similarity is None: should_use_node = False if cast(float, similarity) < cast(float, self.similarity_cutoff): should_use_node = False if should_use_node: new_nodes.append(node) return new_nodes def get_forward_nodes( node: Node, num_nodes: int, docstore: DocumentStore ) -> Dict[str, Node]: """Get forward nodes.""" nodes: Dict[str, Node] = {node.get_doc_id(): node} cur_count = 0 # get forward nodes in an iterative manner while cur_count < num_nodes: if DocumentRelationship.NEXT not in node.relationships: break next_node_id = node.relationships[DocumentRelationship.NEXT] next_node = docstore.get_node(next_node_id) if next_node is None: break nodes[next_node.get_doc_id()] = next_node node = next_node cur_count += 1 return nodes def get_backward_nodes( node: Node, num_nodes: int, docstore: DocumentStore ) -> Dict[str, Node]: """Get backward nodes.""" # get backward nodes in an iterative manner nodes: Dict[str, Node] = {node.get_doc_id(): node} cur_count = 0 while cur_count < num_nodes: if DocumentRelationship.PREVIOUS not in node.relationships: break prev_node_id = node.relationships[DocumentRelationship.PREVIOUS] prev_node = docstore.get_node(prev_node_id) if prev_node is None: break nodes[prev_node.get_doc_id()] = prev_node node = prev_node cur_count += 1 return nodes class PrevNextNodePostprocessor(BaseNodePostprocessor): """Previous/Next Node post-processor. Allows users to fetch additional nodes from the document store, based on the relationships of the nodes. NOTE: this is a beta feature. Args: docstore (DocumentStore): The document store. num_nodes (int): The number of nodes to return (default: 1) mode (str): The mode of the post-processor. Can be "previous", "next", or "both. """ docstore: DocumentStore num_nodes: int = Field(default=1) mode: str = Field(default="next") def _get_backward_nodes(self, node: Node) -> Dict[str, Node]: """Get backward nodes.""" # get backward nodes in an iterative manner nodes: Dict[str, Node] = {node.get_doc_id(): node} cur_count = 0 while cur_count < self.num_nodes: if DocumentRelationship.PREVIOUS not in node.relationships: break prev_node_id = node.relationships[DocumentRelationship.PREVIOUS] prev_node = self.docstore.get_node(prev_node_id) if prev_node is None: break nodes[prev_node.get_doc_id()] = prev_node node = prev_node cur_count += 1 return nodes @validator("mode") def _validate_mode(cls, v: str) -> str: """Validate mode.""" if v not in ["next", "previous", "both"]: raise ValueError(f"Invalid mode: {v}") return v def postprocess_nodes( self, nodes: List[Node], extra_info: Optional[Dict] = None ) -> List[Node]: """Postprocess nodes.""" all_nodes: Dict[str, Node] = {} for node in nodes: all_nodes[node.get_doc_id()] = node if self.mode == "next": all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore)) elif self.mode == "previous": all_nodes.update( get_backward_nodes(node, self.num_nodes, self.docstore) ) elif self.mode == "both": all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore)) all_nodes.update( get_backward_nodes(node, self.num_nodes, self.docstore) ) else: raise ValueError(f"Invalid mode: {self.mode}") sorted_nodes = sorted(all_nodes.values(), key=lambda x: x.get_doc_id()) return list(sorted_nodes) DEFAULT_INFER_PREV_NEXT_TMPL = ( "The current context information is provided. \n" "A question is also provided. \n" "You are a retrieval agent deciding whether to search the " "document store for additional prior context or future context. \n" "Given the context and question, return PREVIOUS or NEXT or NONE. \n" "Examples: \n\n" "Context: Describes the author's experience at Y Combinator." "Question: What did the author do after his time at Y Combinator? \n" "Answer: NEXT \n\n" "Context: Describes the author's experience at Y Combinator." "Question: What did the author do before his time at Y Combinator? \n" "Answer: PREVIOUS \n\n" "Context: Describe the author's experience at Y Combinator." "Question: What did the author do at Y Combinator? \n" "Answer: NONE \n\n" "Context: {context_str}\n" "Question: {query_str}\n" "Answer: " ) DEFAULT_REFINE_INFER_PREV_NEXT_TMPL = ( "The current context information is provided. \n" "A question is also provided. \n" "An existing answer is also provided.\n" "You are a retrieval agent deciding whether to search the " "document store for additional prior context or future context. \n" "Given the context, question, and previous answer, " "return PREVIOUS or NEXT or NONE.\n" "Examples: \n\n" "Context: {context_msg}\n" "Question: {query_str}\n" "Existing Answer: {existing_answer}\n" "Answer: " ) class AutoPrevNextNodePostprocessor(BaseNodePostprocessor): """Previous/Next Node post-processor. Allows users to fetch additional nodes from the document store, based on the prev/next relationships of the nodes. NOTE: difference with PrevNextPostprocessor is that this infers forward/backwards direction. NOTE: this is a beta feature. Args: docstore (DocumentStore): The document store. llm_predictor (LLMPredictor): The LLM predictor. num_nodes (int): The number of nodes to return (default: 1) infer_prev_next_tmpl (str): The template to use for inference. Required fields are {context_str} and {query_str}. """ docstore: DocumentStore service_context: ServiceContext num_nodes: int = Field(default=1) infer_prev_next_tmpl: str = Field(default=DEFAULT_INFER_PREV_NEXT_TMPL) refine_prev_next_tmpl: str = Field(default=DEFAULT_REFINE_INFER_PREV_NEXT_TMPL) verbose: bool = Field(default=False) class Config: """Configuration for this pydantic object.""" arbitrary_types_allowed = True def _parse_prediction(self, raw_pred: str) -> str: """Parse prediction.""" pred = raw_pred.strip().lower() if "previous" in pred: return "previous" elif "next" in pred: return "next" elif "none" in pred: return "none" raise ValueError(f"Invalid prediction: {raw_pred}") def postprocess_nodes( self, nodes: List[Node], extra_info: Optional[Dict] = None ) -> List[Node]: """Postprocess nodes.""" if extra_info is None or "query_bundle" not in extra_info: raise ValueError("Missing query bundle in extra info.") query_bundle = cast(QueryBundle, extra_info["query_bundle"]) infer_prev_next_prompt = QuestionAnswerPrompt( self.infer_prev_next_tmpl, ) refine_infer_prev_next_prompt = RefinePrompt(self.refine_prev_next_tmpl) all_nodes: Dict[str, Node] = {} for node in nodes: all_nodes[node.get_doc_id()] = node # use response builder instead of llm_predictor directly # to be more robust to handling long context response_builder = ResponseBuilder( self.service_context, infer_prev_next_prompt, refine_infer_prev_next_prompt, ) response_builder.add_text_chunks([TextChunk(node.get_text())]) raw_pred = response_builder.get_response( query_str=query_bundle.query_str, response_mode="tree_summarize", ) raw_pred = cast(str, raw_pred) mode = self._parse_prediction(raw_pred) logger.debug(f"> Postprocessor Predicted mode: {mode}") if self.verbose: print(f"> Postprocessor Predicted mode: {mode}") if mode == "next": all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore)) elif mode == "previous": all_nodes.update( get_backward_nodes(node, self.num_nodes, self.docstore) ) elif mode == "none": pass else: raise ValueError(f"Invalid mode: {mode}") sorted_nodes = sorted(all_nodes.values(), key=lambda x: x.get_doc_id()) return list(sorted_nodes)