# file: retrieval.py
import asyncio
from typing import List, Dict, Tuple

import numpy as np
from groq import AsyncGroq
from langchain_core.documents import Document
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity

from embedding import EmbeddingClient
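
# NOTE: EmbeddingClient (local module) is assumed to expose a `device`
# attribute and a `create_embeddings(texts: List[str])` method returning a
# 2-D tensor with one row per input text; this interface is inferred from
# the call sites below rather than from documentation.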

# --- Configuration ---
HYDE_MODEL = "llama3-8b-8192"
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L6-v2"
INITIAL_K_CANDIDATES = 20
TOP_K_CHUNKS = 10


async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generates a hypothetical document (HyDE) to enhance search."""
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        "Write a brief, formal passage that answers the following question. "
        "Use specific terminology as if it were from a larger document. "
        "Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        "Hypothetical Passage:"
    )
    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        print(f"An error occurred during HyDE generation: {e}")
        return ""


class Retriever:
    """Manages hybrid search, combining BM25, dense search, and a reranker."""

    def __init__(self, embedding_client: EmbeddingClient):
        self.embedding_client = embedding_client
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None
        self.document_chunks = []
        self.chunk_embeddings = None
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    def index(self, documents: List[Document]):
        """Builds the search index from document chunks."""
        self.document_chunks = documents
        corpus = [doc.page_content for doc in documents]
        if not corpus:
            print("No documents to index.")
            return

        print("Indexing documents for retrieval...")
        # 1. Initialize the BM25 model over a whitespace-tokenized corpus
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        # 2. Compute and store dense embeddings
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Performs the initial hybrid search to get candidate chunks."""
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")

        # Enhance the dense-search query with the hypothetical document; BM25
        # keeps the raw query so keyword matching is not diluted by generated text.
        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query

        # BM25 (keyword) search
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25.get_scores(tokenized_query)

        # Dense (semantic) search
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()

        # Normalize both score sets to [0, 1] and combine with equal weights
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense
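        # Worked example of the fusion above (illustration only): BM25 scores
        # [2.0, 0.0, 1.0] and cosine scores [0.9, 0.5, 0.7] both min-max scale
        # to [1.0, 0.0, 0.5], so the equal-weight blend is [1.0, 0.0, 0.5];
        # scaling stops either signal from dominating just because its raw
        # range is wider.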

        # Keep the top-scoring candidates for reranking; cast to plain Python
        # types so the return value matches the declared List[Tuple[int, float]]
        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(int(idx), float(combined_scores[idx])) for idx in top_indices]

    async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
        """Reranks the candidate chunks using a CrossEncoder model."""
        if not candidates:
            return []

        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]

        # Run the synchronous CrossEncoder prediction in a worker thread so
        # the event loop is not blocked
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )

        # Attach the new scores (as plain floats), sort descending, and keep
        # the top chunks
        for candidate, score in zip(candidates, rerank_scores):
            candidate["rerank_score"] = float(score)
        candidates.sort(key=lambda x: x["rerank_score"], reverse=True)
        return candidates[:TOP_K_CHUNKS]

    async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
        """Executes the full retrieval pipeline: hybrid search followed by reranking."""
        print(f"Retrieving documents for query: '{query}'")

        # 1. Get initial candidates from hybrid search
        initial_candidates_info = self._hybrid_search(query, hyde_doc)
        retrieved_candidates = [{
            "content": self.document_chunks[idx].page_content,
            "metadata": self.document_chunks[idx].metadata,
            "initial_score": score,
        } for idx, score in initial_candidates_info]

        # 2. Rerank the candidates to get the final list
        final_chunks = await self._rerank(query, retrieved_candidates)
        print(f"Retrieved and reranked {len(final_chunks)} final chunks.")
        return final_chunks
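

# --- Usage sketch ---
# A minimal end-to-end example of the pipeline above. Two things here are
# assumptions, not facts from this file: that EmbeddingClient() takes no
# constructor arguments, and that the Groq key is read from the GROQ_API_KEY
# environment variable. Adapt both to your setup.
if __name__ == "__main__":
    import os

    async def _demo() -> None:
        docs = [
            Document(page_content="BM25 is a bag-of-words ranking function.", metadata={"source": "a"}),
            Document(page_content="Dense retrieval embeds queries and passages into vectors.", metadata={"source": "b"}),
        ]
        retriever = Retriever(EmbeddingClient())  # assumption: no-arg constructor
        retriever.index(docs)

        query = "How does keyword ranking work?"
        # With no key set, HyDE is skipped and the raw query is used alone
        hyde_doc = await generate_hypothetical_document(query, os.environ.get("GROQ_API_KEY", ""))
        for chunk in await retriever.retrieve(query, hyde_doc):
            print(f"{chunk['rerank_score']:.3f} {chunk['metadata']} {chunk['content'][:60]}")

    asyncio.run(_demo())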