Spaces:
Runtime error
Runtime error
File size: 2,306 Bytes
35b22df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
"""Query transform."""
from typing import Optional
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.prompts.base import Prompt
from gpt_index.prompts.default_prompts import DEFAULT_HYDE_PROMPT
class BaseQueryTransform:
    """Pass-through query transform that serves as the base class.

    A query transform turns a raw query string into a ``QueryBundle``,
    pairing the query with the strings supplied as
    ``custom_embedding_strs``. This base version supplies only the
    unmodified query string itself.
    """

    def __call__(self, query_str: str) -> QueryBundle:
        """Wrap ``query_str`` in a bundle without altering it.

        Args:
            query_str (str): The raw query text.

        Returns:
            QueryBundle: Bundle whose ``query_str`` is the input and whose
                ``custom_embedding_strs`` is a one-element list holding
                that same input.
        """
        passthrough_strs = [query_str]
        return QueryBundle(
            query_str=query_str,
            custom_embedding_strs=passthrough_strs,
        )
class HyDEQueryTransform(BaseQueryTransform):
    """Query transform based on Hypothetical Document Embeddings (HyDE).

    An LLM is asked to write a hypothetical answer to the query, and the
    generated text is supplied as an embedding string in the resulting
    query bundle.

    The technique is described in `Precise Zero-Shot Dense Retrieval without
    Relevance Labels <https://arxiv.org/abs/2212.10496>`_.
    """

    def __init__(
        self,
        llm_predictor: Optional[LLMPredictor] = None,
        hyde_prompt: Optional[Prompt] = None,
        include_original: bool = True,
    ) -> None:
        """Store the predictor, prompt and embedding-string policy.

        Args:
            llm_predictor (Optional[LLMPredictor]): LLM used to generate the
                hypothetical document. A falsy value selects a
                default-constructed ``LLMPredictor``.
            hyde_prompt (Optional[Prompt]): Prompt used for generation. A
                falsy value selects ``DEFAULT_HYDE_PROMPT``.
            include_original (bool): Whether the original query string is
                added to the embedding strings alongside the generated text.
        """
        super().__init__()
        # Fall back to the defaults for any collaborator that was not given.
        if not llm_predictor:
            llm_predictor = LLMPredictor()
        if not hyde_prompt:
            hyde_prompt = DEFAULT_HYDE_PROMPT
        self._llm_predictor = llm_predictor
        self._hyde_prompt = hyde_prompt
        self._include_original = include_original

    def __call__(self, query_str: str) -> QueryBundle:
        """Generate a hypothetical document and bundle it with the query.

        Args:
            query_str (str): The raw query text; handed to the prompt as
                ``context_str``.

        Returns:
            QueryBundle: Bundle whose ``query_str`` is the input and whose
                ``custom_embedding_strs`` starts with the generated text,
                followed by the original query when ``include_original``
                was set.
        """
        # TODO: support generating multiple hypothetical docs
        # predict() returns a pair; only its first element is used here.
        generated_text, _ = self._llm_predictor.predict(
            self._hyde_prompt, context_str=query_str
        )
        if self._include_original:
            strs_to_embed = [generated_text, query_str]
        else:
            strs_to_embed = [generated_text]
        return QueryBundle(
            query_str=query_str,
            custom_embedding_strs=strs_to_embed,
        )
|