File size: 11,855 Bytes
b699122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
"""Node postprocessor."""

import re
from abc import abstractmethod
from typing import Dict, List, Optional, cast

from pydantic import BaseModel, Field, validator

import logging
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.indices.service_context import ServiceContext
from gpt_index.prompts.prompts import QuestionAnswerPrompt, RefinePrompt
from gpt_index.docstore import DocumentStore
from gpt_index.data_structs.node_v2 import Node, DocumentRelationship
from gpt_index.indices.postprocessor.base import BasePostprocessor
from gpt_index.indices.query.embedding_utils import SimilarityTracker
from gpt_index.indices.response.builder import ResponseBuilder, TextChunk

logger = logging.getLogger(__name__)


class BaseNodePostprocessor(BasePostprocessor, BaseModel):
    """Abstract base for postprocessors that operate on a list of nodes.

    Concrete subclasses are pydantic models (via ``BaseModel``) and must
    implement :meth:`postprocess_nodes`.
    """

    @abstractmethod
    def postprocess_nodes(
        self,
        nodes: List[Node],
        extra_info: Optional[Dict] = None,
    ) -> List[Node]:
        """Return a postprocessed version of ``nodes``.

        Args:
            nodes: The nodes to postprocess.
            extra_info: Optional extra context; which keys are read is up to
                each concrete subclass.
        """


class KeywordNodePostprocessor(BaseNodePostprocessor):
    """Keyword-based Node processor.

    Keeps a node only if its text contains every keyword in
    ``required_keywords`` and none of the keywords in ``exclude_keywords``.
    Keywords are compared (case-sensitively) against the whole word tokens
    extracted from the node text, so multi-word keywords never match.

    Args:
        required_keywords (List[str]): Keywords that must all appear.
        exclude_keywords (List[str]): Keywords that must not appear.

    """

    required_keywords: List[str] = Field(default_factory=list)
    exclude_keywords: List[str] = Field(default_factory=list)

    def postprocess_nodes(
        self, nodes: List[Node], extra_info: Optional[Dict] = None
    ) -> List[Node]:
        """Filter ``nodes`` by the required / excluded keyword lists.

        Args:
            nodes: Candidate nodes.
            extra_info: Unused; accepted for interface compatibility.

        Returns:
            The nodes that pass both keyword filters, in their original order.
        """
        # Treat None like an empty list (no constraint), as the original
        # `is not None` guards did.
        required = self.required_keywords or []
        excluded = self.exclude_keywords or []

        new_nodes = []
        for node in nodes:
            # Tokenize once into a set so each keyword lookup is O(1)
            # instead of a linear scan over the token list.
            words = set(re.findall(r"\w+", node.get_text()))
            if not all(w in words for w in required):
                continue
            if any(w in words for w in excluded):
                continue
            new_nodes.append(node)

        return new_nodes


class SimilarityPostprocessor(BaseNodePostprocessor):
    """Similarity-based Node processor.

    Drops nodes whose similarity score, looked up in the
    ``"similarity_tracker"`` entry of ``extra_info``, is missing or below
    ``similarity_cutoff``.

    Args:
        similarity_cutoff (Optional[float]): Minimum similarity required to
            keep a node. If None, no node is filtered out.

    """

    similarity_cutoff: Optional[float] = Field(default=None)

    def postprocess_nodes(
        self, nodes: List[Node], extra_info: Optional[Dict] = None
    ) -> List[Node]:
        """Filter ``nodes`` by similarity score.

        Args:
            nodes: Candidate nodes.
            extra_info: Optional dict; its ``"similarity_tracker"`` entry (if
                present) is used to look up each node's similarity.

        Returns:
            ``nodes`` unchanged if no tracker is supplied; otherwise a new list
            with the nodes whose similarity is known and >= the cutoff (or all
            nodes when no cutoff is configured).
        """
        extra_info = extra_info or {}
        similarity_tracker = extra_info.get("similarity_tracker", None)
        if similarity_tracker is None:
            return nodes
        if self.similarity_cutoff is None:
            # No cutoff configured: keep everything (as a fresh list).
            return list(nodes)

        tracker = cast(SimilarityTracker, similarity_tracker)
        cutoff = self.similarity_cutoff

        new_nodes = []
        for node in nodes:
            similarity = tracker.find(node)
            # A node without a recorded score is dropped. Short-circuit here:
            # comparing None against the float cutoff would raise TypeError.
            if similarity is None or similarity < cutoff:
                continue
            new_nodes.append(node)

        return new_nodes


def _get_related_nodes(
    node: Node,
    num_nodes: int,
    docstore: DocumentStore,
    relationship: DocumentRelationship,
) -> Dict[str, Node]:
    """Follow ``relationship`` from ``node`` for up to ``num_nodes`` hops.

    Returns a dict keyed by doc id holding ``node`` itself plus every node
    reached. The walk stops early when the current node has no such
    relationship or the docstore returns None for the related id.
    """
    nodes: Dict[str, Node] = {node.get_doc_id(): node}
    for _ in range(num_nodes):
        if relationship not in node.relationships:
            break
        related_node = docstore.get_node(node.relationships[relationship])
        if related_node is None:
            break
        nodes[related_node.get_doc_id()] = related_node
        node = related_node
    return nodes


def get_forward_nodes(
    node: Node, num_nodes: int, docstore: DocumentStore
) -> Dict[str, Node]:
    """Get forward nodes.

    Returns ``node`` plus up to ``num_nodes`` nodes reached by repeatedly
    following its NEXT relationship, keyed by doc id.
    """
    return _get_related_nodes(node, num_nodes, docstore, DocumentRelationship.NEXT)


def get_backward_nodes(
    node: Node, num_nodes: int, docstore: DocumentStore
) -> Dict[str, Node]:
    """Get backward nodes.

    Returns ``node`` plus up to ``num_nodes`` nodes reached by repeatedly
    following its PREVIOUS relationship, keyed by doc id.
    """
    return _get_related_nodes(
        node, num_nodes, docstore, DocumentRelationship.PREVIOUS
    )


class PrevNextNodePostprocessor(BaseNodePostprocessor):
    """Previous/Next Node post-processor.

    Allows users to fetch additional nodes from the document store,
    based on the relationships of the nodes.

    NOTE: this is a beta feature.

    Args:
        docstore (DocumentStore): The document store.
        num_nodes (int): The maximum number of nodes to fetch in each
            requested direction, per input node (default: 1).
        mode (str): The mode of the post-processor.
            Can be "previous", "next", or "both".

    """

    docstore: DocumentStore
    num_nodes: int = Field(default=1)
    mode: str = Field(default="next")

    class Config:
        """Configuration for this pydantic object."""

        # Allow the non-pydantic ``docstore`` field type; this matches the
        # config used by AutoPrevNextNodePostprocessor for the same field.
        arbitrary_types_allowed = True

    def _get_backward_nodes(self, node: Node) -> Dict[str, Node]:
        """Get backward nodes.

        Kept for backward compatibility; delegates to the module-level
        ``get_backward_nodes`` using this postprocessor's settings.
        """
        return get_backward_nodes(node, self.num_nodes, self.docstore)

    @validator("mode")
    def _validate_mode(cls, v: str) -> str:
        """Validate mode."""
        if v not in ["next", "previous", "both"]:
            raise ValueError(f"Invalid mode: {v}")
        return v

    def postprocess_nodes(
        self, nodes: List[Node], extra_info: Optional[Dict] = None
    ) -> List[Node]:
        """Expand ``nodes`` with their neighbours from the docstore.

        Args:
            nodes: The retrieved nodes.
            extra_info: Unused; accepted for interface compatibility.

        Returns:
            The de-duplicated union of the input nodes and the fetched
            neighbours, sorted by doc id.

        Raises:
            ValueError: If ``mode`` is not one of the supported values (it can
                be reassigned after validation).
        """
        all_nodes: Dict[str, Node] = {}
        for node in nodes:
            all_nodes[node.get_doc_id()] = node
            if self.mode == "next":
                all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore))
            elif self.mode == "previous":
                all_nodes.update(
                    get_backward_nodes(node, self.num_nodes, self.docstore)
                )
            elif self.mode == "both":
                all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore))
                all_nodes.update(
                    get_backward_nodes(node, self.num_nodes, self.docstore)
                )
            else:
                raise ValueError(f"Invalid mode: {self.mode}")

        # Sorting by doc id gives a deterministic order.
        # NOTE(review): this matches document order only if doc ids happen to
        # sort that way; confirm whether relationship order was intended.
        sorted_nodes = sorted(all_nodes.values(), key=lambda x: x.get_doc_id())
        return list(sorted_nodes)


# Prompt asking the LLM whether more context should be fetched before (PREVIOUS)
# or after (NEXT) the given context, or not at all (NONE).
# Required fields: {context_str}, {query_str}.
DEFAULT_INFER_PREV_NEXT_TMPL = (
    "The current context information is provided. \n"
    "A question is also provided. \n"
    "You are a retrieval agent deciding whether to search the "
    "document store for additional prior context or future context. \n"
    "Given the context and question, return PREVIOUS or NEXT or NONE. \n"
    "Examples: \n\n"
    # Each example's "Context:" line must end with a newline; otherwise it
    # runs straight into the "Question:" line.
    "Context: Describes the author's experience at Y Combinator.\n"
    "Question: What did the author do after his time at Y Combinator? \n"
    "Answer: NEXT \n\n"
    "Context: Describes the author's experience at Y Combinator.\n"
    "Question: What did the author do before his time at Y Combinator? \n"
    "Answer: PREVIOUS \n\n"
    "Context: Describes the author's experience at Y Combinator.\n"
    "Question: What did the author do at Y Combinator? \n"
    "Answer: NONE \n\n"
    "Context: {context_str}\n"
    "Question: {query_str}\n"
    "Answer: "
)


# Refine variant of the direction-inference prompt.
# Required fields: {context_msg}, {query_str}, {existing_answer}.
# The "Examples:" header was removed: no examples followed it, so the actual
# query was being presented to the LLM as if it were an example.
DEFAULT_REFINE_INFER_PREV_NEXT_TMPL = (
    "The current context information is provided. \n"
    "A question is also provided. \n"
    "An existing answer is also provided.\n"
    "You are a retrieval agent deciding whether to search the "
    "document store for additional prior context or future context. \n"
    "Given the context, question, and previous answer, "
    "return PREVIOUS or NEXT or NONE.\n\n"
    "Context: {context_msg}\n"
    "Question: {query_str}\n"
    "Existing Answer: {existing_answer}\n"
    "Answer: "
)


class AutoPrevNextNodePostprocessor(BaseNodePostprocessor):
    """Previous/Next Node post-processor.

    Allows users to fetch additional nodes from the document store,
    based on the prev/next relationships of the nodes.

    NOTE: difference with PrevNextNodePostprocessor is that
    this infers the forward/backward direction for each node by running
    an inference prompt over the node's text.

    NOTE: this is a beta feature.

    Args:
        docstore (DocumentStore): The document store.
        service_context (ServiceContext): The service context used to build
            the ResponseBuilder that runs the inference prompts.
        num_nodes (int): The maximum number of nodes to fetch in the
            inferred direction, per input node (default: 1).
        infer_prev_next_tmpl (str): The template to use for inference.
            Required fields are {context_str} and {query_str}.
        refine_prev_next_tmpl (str): The refine template used for inference.
            Required fields are {context_msg}, {query_str} and
            {existing_answer}.
        verbose (bool): Whether to print the predicted mode (default: False).

    """

    docstore: DocumentStore
    service_context: ServiceContext
    num_nodes: int = Field(default=1)
    infer_prev_next_tmpl: str = Field(default=DEFAULT_INFER_PREV_NEXT_TMPL)
    refine_prev_next_tmpl: str = Field(default=DEFAULT_REFINE_INFER_PREV_NEXT_TMPL)
    verbose: bool = Field(default=False)

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def _parse_prediction(self, raw_pred: str) -> str:
        """Map a raw prediction to "previous", "next" or "none".

        Matching is a case-insensitive substring check, tried in the order
        previous -> next -> none.

        Raises:
            ValueError: If none of the three keywords appears.
        """
        pred = raw_pred.strip().lower()
        if "previous" in pred:
            return "previous"
        elif "next" in pred:
            return "next"
        elif "none" in pred:
            return "none"
        raise ValueError(f"Invalid prediction: {raw_pred}")

    def postprocess_nodes(
        self, nodes: List[Node], extra_info: Optional[Dict] = None
    ) -> List[Node]:
        """Expand each node with neighbours in its inferred direction.

        Args:
            nodes: The retrieved nodes.
            extra_info: Must contain a ``"query_bundle"`` entry whose
                ``query_str`` is used in the inference prompt.

        Returns:
            The de-duplicated union of the input nodes and the fetched
            neighbours, sorted by doc id.

        Raises:
            ValueError: If ``"query_bundle"`` is missing from ``extra_info``,
                or if a prediction cannot be parsed.
        """
        if extra_info is None or "query_bundle" not in extra_info:
            raise ValueError("Missing query bundle in extra info.")

        query_bundle = cast(QueryBundle, extra_info["query_bundle"])

        infer_prev_next_prompt = QuestionAnswerPrompt(
            self.infer_prev_next_tmpl,
        )
        refine_infer_prev_next_prompt = RefinePrompt(self.refine_prev_next_tmpl)

        all_nodes: Dict[str, Node] = {}
        for node in nodes:
            all_nodes[node.get_doc_id()] = node
            # use response builder instead of llm_predictor directly
            # to be more robust to handling long context
            # (a fresh builder per node so only this node's text is used)
            response_builder = ResponseBuilder(
                self.service_context,
                infer_prev_next_prompt,
                refine_infer_prev_next_prompt,
            )
            response_builder.add_text_chunks([TextChunk(node.get_text())])
            raw_pred = response_builder.get_response(
                query_str=query_bundle.query_str,
                response_mode="tree_summarize",
            )
            raw_pred = cast(str, raw_pred)
            mode = self._parse_prediction(raw_pred)

            # Lazy %-formatting: the message is only built if DEBUG is enabled.
            logger.debug("> Postprocessor Predicted mode: %s", mode)
            if self.verbose:
                print(f"> Postprocessor Predicted mode: {mode}")

            if mode == "next":
                all_nodes.update(get_forward_nodes(node, self.num_nodes, self.docstore))
            elif mode == "previous":
                all_nodes.update(
                    get_backward_nodes(node, self.num_nodes, self.docstore)
                )
            elif mode == "none":
                pass
            else:
                raise ValueError(f"Invalid mode: {mode}")

        # Sorting by doc id gives a deterministic order.
        # NOTE(review): this matches document order only if doc ids happen to
        # sort that way; confirm whether relationship order was intended.
        sorted_nodes = sorted(all_nodes.values(), key=lambda x: x.get_doc_id())
        return list(sorted_nodes)