"""Node recency post-processor."""
from typing import Dict, List, Optional, Set

import numpy as np
import pandas as pd
from pydantic import Field

from gpt_index.data_structs.node_v2 import Node
from gpt_index.indices.postprocessor.node import BaseNodePostprocessor
from gpt_index.indices.service_context import ServiceContext


# NOTE: currently not being used
# DEFAULT_INFER_RECENCY_TMPL = (
#     "A question is provided.\n"
#     "The goal is to determine whether the question requires finding the most recent "
#     "context.\n"
#     "Please respond with YES or NO.\n"
#     "Question: What is the current status of the patient?\n"
#     "Answer: YES\n"
#     "Question: What happened in the Battle of Yorktown?\n"
#     "Answer: NO\n"
#     "Question: What are the most recent changes to the project?\n"
#     "Answer: YES\n"
#     "Question: How did Harry defeat Voldemort in the Battle of Hogwarts?\n"
#     "Answer: NO\n"
#     "Question: {query_str}\n"
#     "Answer: "
# )


# def parse_recency_pred(pred: str) -> bool:
#     """Parse recency prediction."""
#     if "YES" in pred:
#         return True
#     elif "NO" in pred:
#         return False
#     else:
#         raise ValueError(f"Invalid recency prediction: {pred}.")


class FixedRecencyPostprocessor(BaseNodePostprocessor):
    """Recency post-processor.

    This post-processor does the following steps:

    - Sorts nodes by date, most recent first.
    - Takes the first k nodes (by default 1), which are then used to
      synthesize an answer.

    NOTE: the step that decides whether the post-processor is needed for a
    given query (i.e. whether the query is temporal) is an LLM call that is
    currently disabled; see the commented-out template above. A usage sketch
    follows the class definition.
    """

    service_context: ServiceContext
    top_k: int = 1
    # infer_recency_tmpl: str = Field(default=DEFAULT_INFER_RECENCY_TMPL)
    date_key: str = "date"
    # if False, look up the date under node_info instead of extra_info
    in_extra_info: bool = True

    def postprocess_nodes(
        self, nodes: List[Node], extra_info: Optional[Dict] = None
    ) -> List[Node]:
        """Postprocess nodes."""
        if extra_info is None or "query_bundle" not in extra_info:
            raise ValueError("Missing query bundle in extra info.")

        # query_bundle = cast(QueryBundle, extra_info["query_bundle"])
        # infer_recency_prompt = SimpleInputPrompt(self.infer_recency_tmpl)
        # raw_pred, _ = self.service_context.llm_predictor.predict(
        #     prompt=infer_recency_prompt,
        #     query_str=query_bundle.query_str,
        # )
        # pred = parse_recency_pred(raw_pred)
        # # if no need to use recency post-processor, return nodes as is
        # if not pred:
        #     return nodes

        # sort nodes by date
        info_dict_attr = "extra_info" if self.in_extra_info else "node_info"
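        # NOTE: dates are parsed with pd.to_datetime, so any format pandas
        # understands (e.g. ISO-8601 strings like "2023-01-15") works here.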
        node_dates = pd.to_datetime(
            [getattr(node, info_dict_attr)[self.date_key] for node in nodes]
        )
        sorted_node_idxs = np.flip(node_dates.argsort())
        sorted_nodes = [nodes[idx] for idx in sorted_node_idxs]

        return sorted_nodes[: self.top_k]
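

# Minimal usage sketch for FixedRecencyPostprocessor (hypothetical names, kept
# as a comment so nothing runs on import): assumes each node stores an
# ISO-formatted date under extra_info["date"], and that the caller passes the
# query bundle through extra_info.
#
#   postprocessor = FixedRecencyPostprocessor(
#       service_context=service_context, top_k=1, date_key="date"
#   )
#   recent_nodes = postprocessor.postprocess_nodes(
#       nodes, extra_info={"query_bundle": query_bundle}
#   )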


DEFAULT_QUERY_EMBEDDING_TMPL = (
    "The current document is provided.\n"
    "----------------\n"
    "{context_str}\n"
    "----------------\n"
    "Given the document, we wish to find documents that contain \n"
    "similar context. Note that these documents are older "
    "than the current document, meaning that certain details may be changed. \n"
    "However, the high-level context should be similar.\n"
)
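
# NOTE: the template above is formatted with a node's text and embedded as a
# *query* (via get_query_embedding) rather than as plain text, so similarity
# is measured the same way as at retrieval time, e.g.:
#
#   query_text = DEFAULT_QUERY_EMBEDDING_TMPL.format(context_str=node.get_text())
#   query_embedding = embed_model.get_query_embedding(query_text)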


class EmbeddingRecencyPostprocessor(BaseNodePostprocessor):
    """Recency post-processor.

    This post-processor does the following steps:

    - Sorts nodes by date, most recent first.
    - For each node, looks at the subsequent (older) nodes and filters out any
      with high embedding similarity to the current node, because this likely
      means the older node is a stale duplicate of the more recent one.

    NOTE: as in FixedRecencyPostprocessor, the LLM call deciding whether the
    query is temporal is currently disabled.
    """

    service_context: ServiceContext
    # infer_recency_tmpl: str = Field(default=DEFAULT_INFER_RECENCY_TMPL)
    date_key: str = "date"
    # if False, look up the date under node_info instead of extra_info
    in_extra_info: bool = True
    similarity_cutoff: float = Field(default=0.7)
    query_embedding_tmpl: str = Field(default=DEFAULT_QUERY_EMBEDDING_TMPL)

    def postprocess_nodes(
        self, nodes: List[Node], extra_info: Optional[Dict] = None
    ) -> List[Node]:
        """Postprocess nodes."""
        if extra_info is None or "query_bundle" not in extra_info:
            raise ValueError("Missing query bundle in extra info.")

        # query_bundle = cast(QueryBundle, extra_info["query_bundle"])
        # infer_recency_prompt = SimpleInputPrompt(self.infer_recency_tmpl)
        # raw_pred, _ = self.service_context.llm_predictor.predict(
        #     prompt=infer_recency_prompt,
        #     query_str=query_bundle.query_str,
        # )
        # pred = parse_recency_pred(raw_pred)
        # # if no need to use recency post-processor, return nodes as is
        # if not pred:
        #     return nodes

        # sort nodes by date
        info_dict_attr = "extra_info" if self.in_extra_info else "node_info"
        node_dates = pd.to_datetime(
            [getattr(node, info_dict_attr)[self.date_key] for node in nodes]
        )
        sorted_node_idxs = np.flip(node_dates.argsort())
        sorted_nodes: List[Node] = [nodes[idx] for idx in sorted_node_idxs]

        # get embeddings for each node
        embed_model = self.service_context.embed_model
        for node in sorted_nodes:
            embed_model.queue_text_for_embedding(node.get_doc_id(), node.get_text())
        _, text_embeddings = embed_model.get_queued_text_embeddings()
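
        # NOTE: the np.dot below equals cosine similarity only when the
        # embeddings are unit-normalized (true of e.g. OpenAI embeddings);
        # if your embedding model returns un-normalized vectors, normalize
        # them before comparing against similarity_cutoff.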
        node_ids_to_skip: Set[str] = set()
        for idx, node in enumerate(sorted_nodes):
            if node.get_doc_id() in node_ids_to_skip:
                continue
            # get query embedding for the "query" node
            # NOTE: not the same as the text embedding because
            # we want to optimize for retrieval results
            query_text = self.query_embedding_tmpl.format(
                context_str=node.get_text(),
            )
            query_embedding = embed_model.get_query_embedding(query_text)

            for idx2 in range(idx + 1, len(sorted_nodes)):
                if sorted_nodes[idx2].get_doc_id() in node_ids_to_skip:
                    continue
                node2 = sorted_nodes[idx2]
                if (
                    np.dot(query_embedding, text_embeddings[idx2])
                    > self.similarity_cutoff
                ):
                    node_ids_to_skip.add(node2.get_doc_id())

        return [
            node for node in sorted_nodes if node.get_doc_id() not in node_ids_to_skip
        ]
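

# Minimal usage sketch for EmbeddingRecencyPostprocessor (hypothetical names,
# kept as a comment so nothing runs on import): with similarity_cutoff=0.7,
# any older node whose text embedding has similarity > 0.7 with a newer
# node's query embedding is dropped as a likely stale duplicate.
#
#   postprocessor = EmbeddingRecencyPostprocessor(
#       service_context=service_context, similarity_cutoff=0.7
#   )
#   deduped_nodes = postprocessor.postprocess_nodes(
#       nodes, extra_info={"query_bundle": query_bundle}
#   )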