"""Node recency post-processor.""" | |
from gpt_index.indices.postprocessor.node import BaseNodePostprocessor | |
from gpt_index.indices.service_context import ServiceContext | |
from gpt_index.data_structs.node_v2 import Node | |
from pydantic import Field | |
from typing import Optional, Dict, List, Set | |
import pandas as pd | |
import numpy as np | |

# NOTE: currently not being used
# DEFAULT_INFER_RECENCY_TMPL = (
#     "A question is provided.\n"
#     "The goal is to determine whether the question requires finding the most recent "
#     "context.\n"
#     "Please respond with YES or NO.\n"
#     "Question: What is the current status of the patient?\n"
#     "Answer: YES\n"
#     "Question: What happened in the Battle of Yorktown?\n"
#     "Answer: NO\n"
#     "Question: What are the most recent changes to the project?\n"
#     "Answer: YES\n"
#     "Question: How did Harry defeat Voldemort in the Battle of Hogwarts?\n"
#     "Answer: NO\n"
#     "Question: {query_str}\n"
#     "Answer: "
# )


# def parse_recency_pred(pred: str) -> bool:
#     """Parse recency prediction."""
#     if "YES" in pred:
#         return True
#     elif "NO" in pred:
#         return False
#     else:
#         raise ValueError(f"Invalid recency prediction: {pred}.")


class FixedRecencyPostprocessor(BaseNodePostprocessor):
    """Recency post-processor.

    This post-processor does the following steps:

    - Decides if we need to use the post-processor given the query
      (is it temporal-related?)
    - If yes, sorts nodes by date (newest first).
    - Takes the first k nodes (by default 1) and uses them to synthesize an answer.

    """

    service_context: ServiceContext
    top_k: int = 1
    # infer_recency_tmpl: str = Field(default=DEFAULT_INFER_RECENCY_TMPL)
    date_key: str = "date"
    # if False, look up the date in node_info instead of extra_info
    in_extra_info: bool = True

    def postprocess_nodes(
        self, nodes: List[Node], extra_info: Optional[Dict] = None
    ) -> List[Node]:
        """Postprocess nodes."""
        if extra_info is None or "query_bundle" not in extra_info:
            raise ValueError("Missing query bundle in extra info.")

        # query_bundle = cast(QueryBundle, extra_info["query_bundle"])
        # infer_recency_prompt = SimpleInputPrompt(self.infer_recency_tmpl)
        # raw_pred, _ = self.service_context.llm_predictor.predict(
        #     prompt=infer_recency_prompt,
        #     query_str=query_bundle.query_str,
        # )
        # pred = parse_recency_pred(raw_pred)
        # # if no need to use recency post-processor, return nodes as is
        # if not pred:
        #     return nodes

        # sort nodes by date, newest first (argsort is ascending, so flip it)
        info_dict_attr = "extra_info" if self.in_extra_info else "node_info"
        node_dates = pd.to_datetime(
            [getattr(node, info_dict_attr)[self.date_key] for node in nodes]
        )
        sorted_node_idxs = np.flip(node_dates.argsort())
        sorted_nodes = [nodes[idx] for idx in sorted_node_idxs]

        # keep only the top_k most recent nodes
        return sorted_nodes[: self.top_k]
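

# Usage sketch (illustrative only): assumes each node in `nodes` carries a date
# string under extra_info["date"], and that `service_context`, `query_bundle`,
# and `nodes` are constructed elsewhere in the application.
#
#   postprocessor = FixedRecencyPostprocessor(
#       service_context=service_context, top_k=1, date_key="date"
#   )
#   recent_nodes = postprocessor.postprocess_nodes(
#       nodes, extra_info={"query_bundle": query_bundle}
#   )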


DEFAULT_QUERY_EMBEDDING_TMPL = (
    "The current document is provided.\n"
    "----------------\n"
    "{context_str}\n"
    "----------------\n"
    "Given the document, we wish to find documents that contain \n"
    "similar context. Note that these documents are older "
    "than the current document, meaning that certain details may be changed. \n"
    "However, the high-level context should be similar.\n"
)
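
# For reference, this template is filled with a node's text at query time to
# build the "query" string that gets embedded; the context below is made up:
#
#   query_text = DEFAULT_QUERY_EMBEDDING_TMPL.format(
#       context_str="2023-04-02 status report: migrated the search API to v2."
#   )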


class EmbeddingRecencyPostprocessor(BaseNodePostprocessor):
    """Recency post-processor.

    This post-processor does the following steps:

    - Decides if we need to use the post-processor given the query
      (is it temporal-related?)
    - If yes, sorts nodes by date (newest first).
    - For each node, looks at subsequent (older) nodes and filters out those
      with high embedding similarity to the current node
      (because high similarity suggests the older node is an out-of-date
      version of the same content).

    """

    service_context: ServiceContext
    # infer_recency_tmpl: str = Field(default=DEFAULT_INFER_RECENCY_TMPL)
    date_key: str = "date"
    # if False, look up the date in node_info instead of extra_info
    in_extra_info: bool = True
    similarity_cutoff: float = Field(default=0.7)
    query_embedding_tmpl: str = Field(default=DEFAULT_QUERY_EMBEDDING_TMPL)

    def postprocess_nodes(
        self, nodes: List[Node], extra_info: Optional[Dict] = None
    ) -> List[Node]:
        """Postprocess nodes."""
        if extra_info is None or "query_bundle" not in extra_info:
            raise ValueError("Missing query bundle in extra info.")

        # query_bundle = cast(QueryBundle, extra_info["query_bundle"])
        # infer_recency_prompt = SimpleInputPrompt(self.infer_recency_tmpl)
        # raw_pred, _ = self.service_context.llm_predictor.predict(
        #     prompt=infer_recency_prompt,
        #     query_str=query_bundle.query_str,
        # )
        # pred = parse_recency_pred(raw_pred)
        # # if no need to use recency post-processor, return nodes as is
        # if not pred:
        #     return nodes

        # sort nodes by date, newest first (argsort is ascending, so flip it)
        info_dict_attr = "extra_info" if self.in_extra_info else "node_info"
        node_dates = pd.to_datetime(
            [getattr(node, info_dict_attr)[self.date_key] for node in nodes]
        )
        sorted_node_idxs = np.flip(node_dates.argsort())
        sorted_nodes: List[Node] = [nodes[idx] for idx in sorted_node_idxs]

        # get text embeddings for all nodes in one batch
        # NOTE: method name spelling ("embeddding") matches the embed model API
        embed_model = self.service_context.embed_model
        for node in sorted_nodes:
            embed_model.queue_text_for_embeddding(node.get_doc_id(), node.get_text())
        _, text_embeddings = embed_model.get_queued_text_embeddings()

        node_ids_to_skip: Set[str] = set()
        for idx, node in enumerate(sorted_nodes):
            if node.get_doc_id() in node_ids_to_skip:
                continue
            # get query embedding for the "query" node
            # NOTE: not the same as the text embedding because
            # we want to optimize for retrieval results
            query_text = self.query_embedding_tmpl.format(
                context_str=node.get_text(),
            )
            query_embedding = embed_model.get_query_embedding(query_text)

            # mark older nodes that are too similar to the current node;
            # they are treated as out-of-date versions of the same content
            for idx2 in range(idx + 1, len(sorted_nodes)):
                if sorted_nodes[idx2].get_doc_id() in node_ids_to_skip:
                    continue
                node2 = sorted_nodes[idx2]
                if (
                    np.dot(query_embedding, text_embeddings[idx2])
                    > self.similarity_cutoff
                ):
                    node_ids_to_skip.add(node2.get_doc_id())

        return [
            node for node in sorted_nodes if node.get_doc_id() not in node_ids_to_skip
        ]
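

# Usage sketch (illustrative only): same assumptions as above, plus a
# `service_context` whose embed_model supports the queued-embedding API used
# here. Note that np.dot approximates cosine similarity only when the embed
# model returns unit-normalized vectors (true for OpenAI embeddings; an
# assumption for other models).
#
#   postprocessor = EmbeddingRecencyPostprocessor(
#       service_context=service_context,
#       date_key="date",
#       similarity_cutoff=0.7,
#   )
#   deduped_nodes = postprocessor.postprocess_nodes(
#       nodes, extra_info={"query_bundle": query_bundle}
#   )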