# file: retrieval.py
import asyncio
import numpy as np
from groq import AsyncGroq
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity
from typing import List, Dict, Tuple
from embedding import EmbeddingClient
from langchain_core.documents import Document

# --- Configuration ---
HYDE_MODEL = "llama-3.1-8b-instant"
RERANKER_MODEL = 'cross-encoder/ms-marco-MiniLM-L6-v2'
INITIAL_K_CANDIDATES = 20
TOP_K_CHUNKS = 10


async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generates a hypothetical document (HyDE) to enhance search."""
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""
    print(f"Starting HyDE generation for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        f"Write a brief, formal passage that answers the following question. "
        f"Use specific terminology as if it were from a larger document. "
        f"Do not include the question or conversational text.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Passage:"
    )
    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=HYDE_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        print(f"An error occurred during HyDE generation: {e}")
        return ""
class Retriever:
    """Manages hybrid search, combining BM25, dense search, and a reranker."""
    def __init__(self, embedding_client: EmbeddingClient):
        self.embedding_client = embedding_client
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None
        self.document_chunks = []
        self.chunk_embeddings = None
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")
    def index(self, documents: List[Document]):
        """Builds the search index from document chunks."""
        self.document_chunks = documents
        corpus = [doc.page_content for doc in documents]
        if not corpus:
            print("No documents to index.")
            return
        print("Indexing documents for retrieval...")
        # 1. Initialize BM25 model
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        # 2. Compute and store dense embeddings
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")
    def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
        """Performs the initial hybrid search to get candidate chunks."""
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")
        # Enhance query with hypothetical document
        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        # BM25 (keyword) search runs on the raw query; the HyDE text is
        # reserved for the dense leg, where it helps most
        tokenized_query = query.split(" ")
        bm25_scores = self.bm25.get_scores(tokenized_query)
        # Dense (semantic) search over the HyDE-enhanced query; embeddings are
        # assumed to be torch tensors on the embedding client's device
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()
        # Normalize and combine scores
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense
        
        # Get top initial candidates
        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(idx, combined_scores[idx]) for idx in top_indices]
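    # Worked fusion example (illustrative numbers): with norm_bm25 = [0.0, 1.0]
    # and norm_dense = [1.0, 0.2], combined_scores = 0.5 * norm_bm25 +
    # 0.5 * norm_dense = [0.5, 0.6], so the second chunk ranks first despite
    # its weaker keyword match.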
    async def _rerank(self, query: str, candidates: List[Dict]) -> List[Dict]:
        """Reranks the candidate chunks using a CrossEncoder model."""
        if not candidates:
            return []
        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]
        
        # Run synchronous prediction in a separate thread
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )
        # Combine candidates with their new scores and sort
        for candidate, score in zip(candidates, rerank_scores):
            candidate['rerank_score'] = score
        
        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:TOP_K_CHUNKS]

    async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
        """Executes the full retrieval pipeline: hybrid search followed by reranking."""
        print(f"Retrieving documents for query: '{query}'")
        # 1. Get initial candidates from hybrid search
        initial_candidates_info = self._hybrid_search(query, hyde_doc)
        
        retrieved_candidates = [{
            "content": self.document_chunks[idx].page_content,
            "metadata": self.document_chunks[idx].metadata,
            "initial_score": score
        } for idx, score in initial_candidates_info]
        # 2. Rerank the candidates to get the final list
        final_chunks = await self._rerank(query, retrieved_candidates)
        print(f"Retrieved and reranked {len(final_chunks)} final chunks.")
        return final_chunks
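# --- Usage sketch ---
# A minimal end-to-end run, assuming EmbeddingClient() takes no arguments and
# a GROQ_API_KEY environment variable is set; adapt both to your actual setup.
if __name__ == "__main__":
    import os

    async def _demo():
        docs = [
            Document(
                page_content="BM25 is a bag-of-words ranking function based on term frequency.",
                metadata={"id": 1},
            ),
            Document(
                page_content="Dense retrieval embeds queries and passages into a shared vector space.",
                metadata={"id": 2},
            ),
        ]
        retriever = Retriever(EmbeddingClient())  # assumed no-arg constructor
        retriever.index(docs)
        query = "How does keyword-based ranking work?"
        hyde_doc = await generate_hypothetical_document(query, os.environ.get("GROQ_API_KEY", ""))
        for chunk in await retriever.retrieve(query, hyde_doc):
            print(f"{chunk['rerank_score']:.3f}  {chunk['content'][:60]}")

    asyncio.run(_demo())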
