# file: retrieval.py

import asyncio

import numpy as np
from groq import AsyncGroq
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity
from typing import List, Dict, Tuple

from embedding import EmbeddingClient
from langchain_core.documents import Document

# --- Configuration ---
HYDE_MODEL = "llama-3.1-8b-instant"
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L6-v2"
INITIAL_K_CANDIDATES = 20
TOP_K_CHUNKS = 10


async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generates a hypothetical document (HyDE) to enhance search."""
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        "Write a brief, formal passage that answers the following question. "
        "Use specific terminology as if it were from a larger document. "
        "Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        "Hypothetical Passage:"
    )
    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        print(f"An error occurred during HyDE generation: {e}")
        return ""


class Retriever:
    """Manages hybrid search, combining BM25, dense search, and a reranker."""

    def __init__(self, embedding_client: EmbeddingClient):
        self.embedding_client = embedding_client
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None
        self.document_chunks = []
        self.chunk_embeddings = None
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    def index(self, documents: List[Document]):
        """Builds the search index from document chunks."""
        self.document_chunks = documents
        corpus = [doc.page_content for doc in documents]
        if not corpus:
            print("No documents to index.")
            return

        print("Indexing documents for retrieval...")
        # 1. Build the BM25 (sparse keyword) index
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)

        # 2. Compute and store dense embeddings for the whole corpus
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Performs the initial hybrid search to get candidate chunks."""
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")
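        # Design note: only the dense query below is enriched with the HyDE
        # passage. The intuition behind HyDE is that an answer-shaped passage
        # embeds closer to relevant chunks than the bare question does, while
        # BM25 keeps the raw query so generated wording cannot skew exact
        # keyword matching.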
        # Enhance the dense-search query with the hypothetical document
        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query

        # BM25 (keyword) search over the raw query
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25.get_scores(tokenized_query)

        # Dense (semantic) search; embeddings are expected as torch tensors
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = (
            cosine_similarity(query_embedding, self.chunk_embeddings)
            .cpu()
            .numpy()
            .flatten()
        )

        # Min-max normalize both score sets, then combine with equal weights
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense

        # Take the top candidates for reranking
        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(int(idx), float(combined_scores[idx])) for idx in top_indices]

    async def _rerank(self, query: str, candidates: List[Dict]) -> List[Dict]:
        """Reranks the candidate chunks using a CrossEncoder model."""
        if not candidates:
            return []

        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]

        # CrossEncoder.predict is synchronous; run it in a worker thread so
        # the event loop is not blocked
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )

        # Attach the new scores, then sort best-first
        for candidate, score in zip(candidates, rerank_scores):
            candidate["rerank_score"] = float(score)
        candidates.sort(key=lambda x: x["rerank_score"], reverse=True)
        return candidates[:TOP_K_CHUNKS]

    async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
        """Executes the full retrieval pipeline: hybrid search followed by reranking."""
        print(f"Retrieving documents for query: '{query}'")
        # 1. Get initial candidates from hybrid search
        initial_candidates_info = self._hybrid_search(query, hyde_doc)
        retrieved_candidates = [
            {
                "content": self.document_chunks[idx].page_content,
                "metadata": self.document_chunks[idx].metadata,
                "initial_score": score,
            }
            for idx, score in initial_candidates_info
        ]

        # 2. Rerank the candidates to get the final list
        final_chunks = await self._rerank(query, retrieved_candidates)
        print(f"Retrieved and reranked {len(final_chunks)} final chunks.")
        return final_chunks
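
# --- Usage sketch ---
# A minimal, hedged example of wiring the pieces above together: generate a
# HyDE passage, index a couple of chunks, then retrieve. The zero-argument
# EmbeddingClient() constructor and the GROQ_API_KEY environment variable are
# assumptions for illustration, not part of this module's contract.

async def _demo() -> None:
    import os

    embedding_client = EmbeddingClient()  # assumption: no-arg constructor
    retriever = Retriever(embedding_client)

    retriever.index([
        Document(
            page_content="BM25 is a bag-of-words ranking function used for keyword search.",
            metadata={"source": "notes.md"},
        ),
        Document(
            page_content="Cross-encoders score query-document pairs jointly for reranking.",
            metadata={"source": "notes.md"},
        ),
    ])

    query = "How does hybrid retrieval combine keyword and semantic search?"
    hyde_doc = await generate_hypothetical_document(
        query, os.environ.get("GROQ_API_KEY", "")
    )
    for chunk in await retriever.retrieve(query, hyde_doc):
        print(f"{chunk['rerank_score']:.4f}  {chunk['metadata']}")


if __name__ == "__main__":
    asyncio.run(_demo())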