# file: retrieval_parent.py
"""Hybrid parent-child retrieval.

Pipeline implemented by :class:`Retriever`:

1. Query expansion and HyDE (hypothetical document) generation via Groq,
   run concurrently.
2. Hybrid search over *child* chunks: BM25 on the expanded terms plus dense
   cosine similarity on the HyDE-augmented query, min-max normalised and
   averaged.
3. Cross-encoder reranking of the top candidates.
4. Lookup of the de-duplicated *parent* documents of the best child chunks.
"""
import asyncio
import json
import time
from typing import Dict, List, Tuple

import numpy as np
import torch
from groq import AsyncGroq
from langchain.storage import InMemoryStore
from langchain_core.documents import Document
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity

from embedding import EmbeddingClient

# --- Configuration ---
GENERATION_MODEL = "llama-3.1-8b-instant"
RERANKER_MODEL = 'BAAI/bge-reranker-v2-m3'
INITIAL_K_CANDIDATES = 20  # candidates kept after hybrid search
TOP_K_CHUNKS = 10  # child chunks kept after reranking


async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generate a hypothetical document that answers the query (HyDE).

    Args:
        query: The user question.
        groq_api_key: Groq API key; when falsy, generation is skipped.

    Returns:
        The generated passage, or ``""`` when the key is missing, the call
        fails, or the model returns no content.
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        "Write a brief, formal passage that directly answers the following question. "
        "This passage will be used to find similar documents. "
        "Do not include the question or any conversational text.\n\n"
        f"Question: {query}\n\n"
        "Hypothetical Passage:"
    )

    start_time = time.perf_counter()
    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=GENERATION_MODEL,
            temperature=0.3,
            max_tokens=500,
        )
        end_time = time.perf_counter()
        print(f"--- HyDE generation took {end_time - start_time:.4f} seconds ---")
        # ``content`` can be None; honour the ``-> str`` contract.
        return chat_completion.choices[0].message.content or ""
    except Exception as e:
        # Best-effort step: retrieval still works without a HyDE passage.
        print(f"An error occurred during HyDE generation: {e}")
        return ""


def _coerce_expanded_terms(payload: object, query: str) -> List[str]:
    """Normalise a decoded JSON payload into a non-empty list of search terms.

    Accepts a bare list, a dict with a ``"terms"`` list, or a dict whose
    first list-valued entry holds the terms. Non-string and blank entries
    are dropped. The original query is always present in the result; if
    nothing usable is found the result is ``[query]``.
    """
    if isinstance(payload, dict):
        if isinstance(payload.get("terms"), list):
            payload = payload["terms"]
        else:
            # The model may pick any key name; take the first list value.
            payload = next(
                (value for value in payload.values() if isinstance(value, list)),
                None,
            )

    if not isinstance(payload, list):
        return [query]

    terms = [
        term.strip()
        for term in payload
        if isinstance(term, str) and term.strip()
    ]
    if query not in terms:
        terms.insert(0, query)
    return terms


async def generate_expanded_terms(query: str, groq_api_key: str) -> List[str]:
    """Generate semantically related search terms for a query.

    Args:
        query: The user question.
        groq_api_key: Groq API key; when falsy, expansion is skipped.

    Returns:
        A non-empty list of strings containing the original query and any
        generated terms. Falls back to ``[query]`` when the key is missing,
        the call fails, or the response cannot be parsed.
    """
    if not groq_api_key:
        print("Groq API key not set. Skipping Semantic Expansion.")
        return [query]

    print(f"Starting Semantic Expansion for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    # NOTE(review): the prompt asks for a JSON *list* while the request uses
    # ``json_object`` mode, so the reply may arrive wrapped in an object under
    # an arbitrary key. _coerce_expanded_terms handles both shapes.
    prompt = (
        "You are a search query expansion expert. Based on the following query, generate up to 4 additional, "
        "semantically related search terms. The terms should be relevant for finding information in technical documents. "
        "Return the original query plus the new terms as a single JSON list of strings.\n\n"
        f"Query: \"{query}\"\n\n"
        "JSON List:"
    )

    start_time = time.perf_counter()
    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=GENERATION_MODEL,
            temperature=0.4,
            max_tokens=200,
            response_format={"type": "json_object"},
        )
        end_time = time.perf_counter()
        print(f"--- Semantic Expansion took {end_time - start_time:.4f} seconds ---")

        result_text = chat_completion.choices[0].message.content
        return _coerce_expanded_terms(json.loads(result_text), query)
    except Exception as e:
        # Best-effort step: fall back to searching with the raw query only.
        print(f"An error occurred during Semantic Expansion: {e}")
        return [query]


class Retriever:
    """Manages hybrid search with parent-child retrieval."""

    def __init__(self, embedding_client: EmbeddingClient):
        """Load the reranker on the embedding client's device; start unindexed."""
        self.embedding_client = embedding_client
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None
        self.document_chunks = []
        self.chunk_embeddings = None
        self.docstore = InMemoryStore()
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Lower-case and whitespace-split text; shared by corpus and query."""
        return text.lower().split()

    def index(self, child_documents: List[Document], docstore: InMemoryStore):
        """Build the search index from child documents and store the parents.

        Args:
            child_documents: Child chunks to search over.
            docstore: Store that ``retrieve`` queries with each chunk's
                ``metadata["parent_id"]``.
        """
        self.document_chunks = child_documents
        self.docstore = docstore

        corpus = [doc.page_content for doc in child_documents]
        if not corpus:
            # Drop any previous index so stale scores can never be mapped
            # onto the (now empty) chunk list.
            self.bm25 = None
            self.chunk_embeddings = None
            print("No documents to index.")
            return

        print("Indexing child documents for retrieval...")
        tokenized_corpus = [self._tokenize(doc) for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str, expanded_terms: List[str]) -> List[Tuple[int, float]]:
        """Run hybrid search: expanded terms for BM25, HyDE doc for dense search.

        Args:
            query: The user question.
            hyde_doc: Hypothetical passage appended to the query for the
                dense search; ignored when falsy.
            expanded_terms: Search phrases used for the BM25 query.

        Returns:
            Up to ``INITIAL_K_CANDIDATES`` ``(chunk_index, combined_score)``
            pairs, best first.

        Raises:
            ValueError: If ``index()`` has not built an index yet.
        """
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")

        print(f"Running BM25 with expanded terms: {expanded_terms}")
        # BM25 expects a flat list of tokens. The expanded terms are phrases,
        # so split them with the same tokenizer used for the corpus; passing
        # whole phrases would never match the single-word corpus tokens.
        # dict.fromkeys de-duplicates while preserving order.
        query_tokens = list(dict.fromkeys(
            token
            for term in expanded_terms
            if isinstance(term, str)
            for token in self._tokenize(term)
        ))
        if not query_tokens:
            query_tokens = self._tokenize(query)
        bm25_scores = self.bm25.get_scores(query_tokens)

        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()

        # Put both score sets on a [0, 1] scale before averaging them.
        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense

        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        # Convert numpy scalars so the result matches the annotated types.
        return [(int(idx), float(combined_scores[idx])) for idx in top_indices]

    async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
        """Rerank candidates using the CrossEncoder model.

        Each candidate dict gains a ``'rerank_score'`` key and the list is
        sorted in place, best first.

        Args:
            query: The user question.
            candidates: Dicts with at least a ``"content"`` key.

        Returns:
            The top ``TOP_K_CHUNKS`` candidates; ``[]`` for empty input.
        """
        if not candidates:
            return []

        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]

        # Run the model call in a worker thread so the event loop stays free.
        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )

        for candidate, score in zip(candidates, rerank_scores):
            candidate['rerank_score'] = float(score)

        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:TOP_K_CHUNKS]

    async def retrieve(self, query: str, groq_api_key: str) -> List[Dict]:
        """Execute the full pipeline: expansion, HyDE, hybrid search, reranking.

        Args:
            query: The user question.
            groq_api_key: Groq API key passed to the HyDE and expansion steps.

        Returns:
            One ``{"content", "metadata"}`` dict per distinct parent document
            of the top reranked child chunks, in rerank order. Parents that
            the docstore does not return are omitted.

        Raises:
            ValueError: If ``index()`` has not built an index yet.
            KeyError: If a retrieved chunk's metadata has no ``"parent_id"``.
        """
        print(f"\n--- Retrieving documents for query: '{query}' ---")

        # Fail fast with the same error _hybrid_search would raise, before
        # spending two LLM calls on a query that cannot be answered.
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")

        hyde_doc, expanded_terms = await asyncio.gather(
            generate_hypothetical_document(query, groq_api_key),
            generate_expanded_terms(query, groq_api_key),
        )

        initial_candidates_info = self._hybrid_search(query, hyde_doc, expanded_terms)

        retrieved_child_docs = [
            {
                "content": self.document_chunks[idx].page_content,
                "metadata": self.document_chunks[idx].metadata,
            }
            for idx, _score in initial_candidates_info
        ]

        reranked_child_docs = await self._rerank(query, retrieved_child_docs)

        # Several child chunks can share a parent; keep each parent once,
        # ordered by its best-ranked child.
        parent_ids = list(dict.fromkeys(
            doc["metadata"]["parent_id"] for doc in reranked_child_docs
        ))

        retrieved_parents = self.docstore.mget(parent_ids)
        final_context = [
            {"content": doc.page_content, "metadata": doc.metadata}
            for doc in retrieved_parents
            if doc is not None
        ]

        print(f"Retrieved {len(final_context)} final parent chunks for context.")
        return final_context