PercivalFletcher committed on
Commit a19a241 · verified · 1 Parent(s): 17aa266

Upload 6 files

Files changed (6)
  1. chunking.py +68 -0
  2. document_processor.py +88 -0
  3. embedding.py +40 -0
  4. generation.py +57 -0
  5. main.py +149 -0
  6. retrieval.py +139 -0
chunking.py ADDED
@@ -0,0 +1,68 @@
+ # file: chunking.py
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_core.documents import Document
+ from typing import List
+ from unstructured.partition.md import partition_md
+ from unstructured.documents.elements import Header, Footer, PageBreak, Table, NarrativeText
+
+
+
+ # --- Configuration ---
+ CHUNK_SIZE = 1000
+ CHUNK_OVERLAP = 200
+
+ def process_and_chunk(raw_text: str) -> List[Document]:
+     """
+     Partitions raw text from the document processor using 'unstructured',
+     interpreting it as markdown so that table structures are preserved,
+     and then chunks the remaining text content.
+
+     Args:
+         raw_text: The raw string content of the document (expected to be markdown).
+
+     Returns:
+         A list of Document objects, including structured tables and chunked text.
+     """
+     if not raw_text:
+         print("Warning: Input text for chunking is empty.")
+         return []
+
+     print(f"Processing raw text of length {len(raw_text)} with 'unstructured' markdown parser.")
+
+     # Parse with unstructured's specialized markdown parser, which correctly
+     # handles tables and other structures present in the markdown produced by
+     # the PyMuPDF extraction step.
+     elements = partition_md(text=raw_text)
+
+     documents = []
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=CHUNK_SIZE,
+         chunk_overlap=CHUNK_OVERLAP,
+         length_function=len,
+         is_separator_regex=False,
+     )
+
+     for element in elements:
+         if isinstance(element, (Header, Footer, PageBreak)):
+             continue
+         # Process tables: keep each table whole, with its HTML representation
+         if isinstance(element, Table):
+             table_html = element.metadata.text_as_html
+             table_metadata = element.metadata.to_dict()
+             table_metadata['content_type'] = 'table'
+             documents.append(Document(page_content=table_html, metadata=table_metadata))
+         # Process and chunk narrative text
+         elif isinstance(element, NarrativeText):
+             chunks = text_splitter.split_text(element.text)
+             for chunk in chunks:
+                 chunk_metadata = element.metadata.to_dict()
+                 chunk_metadata['content_type'] = 'text'
+                 documents.append(Document(page_content=chunk, metadata=chunk_metadata))
+         # Handle other elements directly
+         else:
+             general_metadata = element.metadata.to_dict()
+             general_metadata['content_type'] = 'other'
+             documents.append(Document(page_content=element.text, metadata=general_metadata))
+
+     print(f"Created {len(documents)} documents (chunks and tables).")
+     return documents
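
As a quick sanity check, the chunker can be exercised on a small markdown string. The sketch below is illustrative only (it is not part of the committed files), and it assumes chunking.py is importable from the working directory and that the installed unstructured version populates text_as_html for markdown tables.

```python
# Sketch: exercise process_and_chunk on a small markdown sample.
from chunking import process_and_chunk

sample_md = (
    "# Policy Overview\n\n"
    "The plan covers outpatient treatment up to the annual limit.\n\n"
    "| Plan | Annual Limit |\n"
    "|------|--------------|\n"
    "| A    | $5,000       |\n"
    "| B    | $10,000      |\n"
)

docs = process_and_chunk(sample_md)
for doc in docs:
    # Each Document carries a 'content_type' of 'table', 'text', or 'other'.
    preview = (doc.page_content or "")[:60].replace("\n", " ")
    print(doc.metadata.get("content_type"), "->", preview)
```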
document_processor.py ADDED
@@ -0,0 +1,88 @@
+ # file: document_processor.py
+
+ import os
+ import time
+ import httpx
+ from pathlib import Path
+ from urllib.parse import urlparse, unquote
+ from llama_index.readers.file import PyMuPDFReader
+ from llama_index.core import Document as LlamaDocument
+ from concurrent.futures import ThreadPoolExecutor
+ from pydantic import HttpUrl
+ from typing import List
+
+ # Define the batch size for parallel processing
+ BATCH_SIZE = 25
+
+ def _process_page_batch(documents_batch: List[LlamaDocument]) -> str:
+     """
+     Helper function to extract content from a batch of LlamaIndex Document objects
+     and join them into a single string.
+     """
+     return "\n\n".join([d.get_content() for d in documents_batch])
+
+ async def ingest_and_parse_document(doc_url: HttpUrl) -> str:
+     """
+     Asynchronously downloads a document, saves it locally, and parses it to
+     Markdown text using PyMuPDFReader with parallel processing.
+
+     Args:
+         doc_url: The Pydantic-validated URL of the document to process.
+
+     Returns:
+         A single string containing the document's extracted text.
+     """
+     print(f"Initiating download from: {doc_url}")
+     LOCAL_STORAGE_DIR = "data/"
+     os.makedirs(LOCAL_STORAGE_DIR, exist_ok=True)
+
+     try:
+         # Asynchronously download the document
+         async with httpx.AsyncClient() as client:
+             response = await client.get(str(doc_url), timeout=30.0, follow_redirects=True)
+             response.raise_for_status()
+             doc_bytes = response.content
+         print("Download successful.")
+
+         # Determine a valid local filename
+         parsed_path = urlparse(str(doc_url)).path
+         filename = unquote(os.path.basename(parsed_path)) or "downloaded_document.pdf"
+         local_file_path = Path(os.path.join(LOCAL_STORAGE_DIR, filename))
+
+         # Save the document locally
+         with open(local_file_path, "wb") as f:
+             f.write(doc_bytes)
+         print(f"Document saved locally at: {local_file_path}")
+
+         # Parse the document using LlamaIndex's PyMuPDFReader
+         print("Parsing document with PyMuPDFReader...")
+         loader = PyMuPDFReader()
+         docs_from_loader = loader.load_data(file_path=local_file_path)
+
+         # Parallelize the extraction of text from loaded pages
+         start_time = time.perf_counter()
+         all_extracted_texts = []
+         with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as executor:
+             futures = [
+                 executor.submit(_process_page_batch, docs_from_loader[i:i + BATCH_SIZE])
+                 for i in range(0, len(docs_from_loader), BATCH_SIZE)
+             ]
+             for future in futures:  # submission order keeps pages in sequence
+                 all_extracted_texts.append(future.result())
+
+         doc_text = "\n\n".join(all_extracted_texts)
+         elapsed_time = time.perf_counter() - start_time
+         print(f"Time taken for parallel text extraction: {elapsed_time:.4f} seconds.")
+
+         if not doc_text:
+             raise ValueError("Document parsing yielded no content.")
+
+         print(f"Parsing complete. Extracted {len(doc_text)} characters.")
+         return doc_text
+
+     except httpx.HTTPStatusError as e:
+         print(f"Error downloading document: {e}")
+         raise
+     except Exception as e:
+         print(f"An unexpected error occurred during document processing: {e}")
+         raise
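
For local testing, ingest_and_parse_document can be driven directly with asyncio.run. The sketch below is not part of the commit; the URL is a placeholder, and it assumes network access and that the linked file is a PDF.

```python
# Sketch: run the ingestion step standalone.
import asyncio
from document_processor import ingest_and_parse_document

async def main() -> None:
    # Placeholder URL: substitute a real, reachable PDF link.
    url = "https://example.com/sample-policy.pdf"
    text = await ingest_and_parse_document(url)
    print(f"Extracted {len(text)} characters; preview:\n{text[:300]}")

if __name__ == "__main__":
    asyncio.run(main())
```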
embedding.py ADDED
@@ -0,0 +1,40 @@
+ # file: embedding.py
+
+ import torch
+ from sentence_transformers import SentenceTransformer
+ from typing import List
+
+ # --- Configuration ---
+ EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
+
+ class EmbeddingClient:
+     """A client for generating text embeddings using a local sentence transformer model."""
+
+     def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = SentenceTransformer(model_name, device=self.device)
+         print(f"EmbeddingClient initialized with model '{model_name}' on device '{self.device}'.")
+
+     def create_embeddings(self, texts: List[str]) -> torch.Tensor:
+         """
+         Generates embeddings for a list of text chunks.
+
+         Args:
+             texts: A list of strings to be embedded.
+
+         Returns:
+             A torch.Tensor containing the generated embeddings.
+         """
+         if not texts:
+             return torch.tensor([])
+
+         print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
+         try:
+             embeddings = self.model.encode(
+                 texts, convert_to_tensor=True, show_progress_bar=False
+             )
+             print("Embeddings generated successfully.")
+             return embeddings
+         except Exception as e:
+             print(f"An error occurred during embedding generation: {e}")
+             raise
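
A minimal check of EmbeddingClient, assuming the sentence-transformers model downloads on first use. This sketch is not part of the commit; cosine similarity is computed directly with PyTorch here rather than through the retrieval module.

```python
# Sketch: embed two sentences and compare them.
from torch.nn.functional import cosine_similarity
from embedding import EmbeddingClient

client = EmbeddingClient()  # defaults to all-MiniLM-L6-v2 on CPU or CUDA
embeddings = client.create_embeddings([
    "The annual premium is payable in advance.",
    "Premiums for the year must be paid up front.",
])
# Shape (2, 384); compare the two rows.
score = cosine_similarity(embeddings[0].unsqueeze(0), embeddings[1].unsqueeze(0)).item()
print(f"cosine similarity: {score:.3f}")
```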
generation.py ADDED
@@ -0,0 +1,57 @@
+ # file: generation.py
+ from groq import AsyncGroq
+ from typing import List, Dict
+
+ # --- Configuration ---
+ GROQ_MODEL_NAME = "llama3-8b-8192"
+
+ async def generate_answer(query: str, context_chunks: List[Dict], groq_api_key: str) -> str:
+     """
+     Generates a final answer using the Groq API based on the query and retrieved context.
+
+     Args:
+         query: The user's original question.
+         context_chunks: A list of the most relevant, reranked document chunks.
+         groq_api_key: The API key for the Groq service.
+
+     Returns:
+         A string containing the generated answer.
+     """
+     if not groq_api_key:
+         return "Error: Groq API key is not set."
+     if not context_chunks:
+         return "I do not have enough information to answer this question based on the provided document."
+
+     print("Generating final answer with Groq...")
+     client = AsyncGroq(api_key=groq_api_key)
+
+     # Format the context for the prompt
+     context_str = "\n\n---\n\n".join(
+         [f"Context Chunk:\n{chunk['content']}" for chunk in context_chunks]
+     )
+
+     prompt = (
+         "You are an expert Q&A system. Your task is to extract information with 100% accuracy from the provided text. Provide a brief and direct answer. "
+         "Do not mention the context in your response. Answer *only* using the information from the provided document. "
+         "Do not infer, add, or assume any information that is not explicitly written in the source text. If the answer is not in the document, state that the information is not available. "
+         "When the question involves numbers, percentages, or monetary values, extract the exact figures from the text. "
+         "Double-check that the value corresponds to the correct plan or condition mentioned in the question."
+         "\n\n"
+         f"CONTEXT:\n{context_str}\n\n"
+         f"QUESTION:\n{query}\n\n"
+         "ANSWER:"
+     )
+
+     try:
+         chat_completion = await client.chat.completions.create(
+             messages=[{"role": "user", "content": prompt}],
+             model=GROQ_MODEL_NAME,
+             temperature=0.2,  # Lower temperature for more factual answers
+             max_tokens=500,
+         )
+         answer = chat_completion.choices[0].message.content
+         print("Answer generated successfully.")
+         return answer
+     except Exception as e:
+         print(f"An error occurred during Groq API call: {e}")
+         return "Could not generate an answer due to an API error."
main.py ADDED
@@ -0,0 +1,149 @@
+ # file: main.py
+ import time
+ import os
+ import asyncio
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel, HttpUrl
+ from typing import List, Dict, Any
+ from dotenv import load_dotenv
+
+ # Import functions and classes from the new modular files
+ from document_processor import ingest_and_parse_document
+ from chunking import process_and_chunk
+ from embedding import EmbeddingClient
+ from retrieval import Retriever, generate_hypothetical_document
+ from generation import generate_answer
+
+ load_dotenv()
+
+ # --- FastAPI App Initialization ---
+ app = FastAPI(
+     title="Modular RAG API",
+     description="A modular API for Retrieval-Augmented Generation from documents.",
+     version="2.0.0",
+ )
+
+ # --- Global Clients and API Keys ---
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
+ embedding_client = EmbeddingClient()
+ retriever = Retriever(embedding_client=embedding_client)
+
+
+ # --- Pydantic Models ---
+ class RunRequest(BaseModel):
+     document_url: HttpUrl
+     questions: List[str]
+
+ class RunResponse(BaseModel):
+     answers: List[str]
+
+ class TestRequest(BaseModel):
+     document_url: HttpUrl
+
+ # --- Endpoints ---
+ # Test endpoint for the parsing stage
+ @app.post("/test/parse", response_model=Dict[str, Any], tags=["Testing"])
+ async def test_parsing_endpoint(request: TestRequest):
+     """
+     Tests the document ingestion and parsing phase.
+     Returns the full markdown content and the time taken.
+     """
+     print("--- Running Parsing Test ---")
+     start_time = time.perf_counter()
+
+     try:
+         markdown_content = await ingest_and_parse_document(request.document_url)
+
+         end_time = time.perf_counter()
+         duration = end_time - start_time
+         print(f"--- Parsing took {duration:.2f} seconds ---")
+
+         return {
+             "parsing_time_seconds": duration,
+             "character_count": len(markdown_content),
+             "content": markdown_content
+         }
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"An error occurred during parsing: {str(e)}")
+
+ @app.post("/hackrx/run", response_model=RunResponse)
+ async def run_rag_pipeline(request: RunRequest):
+     """
+     Runs the full RAG pipeline for a given document URL and a list of questions.
+     """
+     try:
+         # --- STAGE 1 & 2: DOCUMENT INGESTION AND CHUNKING ---
+         print("--- Kicking off RAG Pipeline ---")
+         markdown_content = await ingest_and_parse_document(request.document_url)
+         documents = process_and_chunk(markdown_content)
+
+         if not documents:
+             raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")
+
+         # --- STAGE 3: INDEXING (Embedding + BM25) ---
+         # This step builds the search index for the current document.
+         retriever.index(documents)
+
+         # --- CONCURRENT WORKFLOW FOR ALL QUESTIONS ---
+
+         # Step A: Concurrently generate hypothetical documents for all questions
+         hyde_tasks = [generate_hypothetical_document(q, GROQ_API_KEY) for q in request.questions]
+         all_hyde_docs = await asyncio.gather(*hyde_tasks)
+
+         # Step B: Concurrently retrieve relevant chunks for all questions
+         retrieval_tasks = [
+             retriever.retrieve(q, hyde_doc)
+             for q, hyde_doc in zip(request.questions, all_hyde_docs)
+         ]
+         all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)
+
+         # Step C: Concurrently generate final answers for all questions
+         answer_tasks = [
+             generate_answer(q, chunks, GROQ_API_KEY)
+             for q, chunks in zip(request.questions, all_retrieved_chunks)
+         ]
+         final_answers = await asyncio.gather(*answer_tasks)
+
+         print("--- RAG Pipeline Completed Successfully ---")
+         return RunResponse(answers=final_answers)
+
+     except Exception as e:
+         print(f"An unhandled error occurred in the pipeline: {e}")
+         # Re-raising as a 500 error for the client
+         raise HTTPException(
+             status_code=500, detail=f"An internal server error occurred: {str(e)}"
+         )
+
+ @app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
+ async def test_chunking_endpoint(request: TestRequest):
+     """
+     Tests both the parsing and chunking phases together.
+     Returns the final list of chunks and the total time taken.
+     """
+     print("--- Running Parsing and Chunking Test ---")
+     start_time = time.perf_counter()
+
+     try:
+         # Step 1: Parse the document
+         markdown_content = await ingest_and_parse_document(request.document_url)
+
+         # Step 2: Chunk the parsed content
+         documents = process_and_chunk(markdown_content)
+
+         end_time = time.perf_counter()
+         duration = end_time - start_time
+         print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")
+
+         # Convert Document objects to a JSON-serializable list
+         chunk_results = [
+             {"page_content": doc.page_content, "metadata": doc.metadata}
+             for doc in documents
+         ]
+
+         return {
+             "total_time_seconds": duration,
+             "chunk_count": len(chunk_results),
+             "chunks": chunk_results
+         }
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"An error occurred during chunking: {str(e)}")
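
Once the app is running (for example with `uvicorn main:app --reload`), the /hackrx/run endpoint can be called with a small client like the one below. The document URL and questions are placeholders, and this sketch is not part of the commit.

```python
# Sketch: call the running API from a client script.
import httpx

payload = {
    "document_url": "https://example.com/sample-policy.pdf",  # placeholder
    "questions": [
        "What is the grace period for premium payment?",
        "Does the policy cover outpatient dental treatment?",
    ],
}

# Generation can take a while for long documents, so use a generous timeout.
response = httpx.post("http://127.0.0.1:8000/hackrx/run", json=payload, timeout=300.0)
response.raise_for_status()
for question, answer in zip(payload["questions"], response.json()["answers"]):
    print(f"Q: {question}\nA: {answer}\n")
```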
retrieval.py ADDED
@@ -0,0 +1,139 @@
+ # file: retrieval.py
+
+ import time
+ import asyncio
+ import numpy as np
+ import torch
+ from groq import AsyncGroq
+ from rank_bm25 import BM25Okapi
+ from sentence_transformers import CrossEncoder
+ from sklearn.preprocessing import MinMaxScaler
+ from torch.nn.functional import cosine_similarity
+ from typing import List, Dict, Tuple
+
+ from embedding import EmbeddingClient
+ from langchain_core.documents import Document
+
+ # --- Configuration ---
+ HYDE_MODEL = "llama3-8b-8192"
+ RERANKER_MODEL = 'cross-encoder/ms-marco-MiniLM-L6-v2'
+ INITIAL_K_CANDIDATES = 20
+ TOP_K_CHUNKS = 10
+
+ async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
+     """Generates a hypothetical document (HyDE) to enhance search."""
+     if not groq_api_key:
+         print("Groq API key not set. Skipping HyDE generation.")
+         return ""
+
+     print(f"Starting HyDE generation for query: '{query}'...")
+     client = AsyncGroq(api_key=groq_api_key)
+     prompt = (
+         f"Write a brief, formal passage that answers the following question. "
+         f"Use specific terminology as if it were from a larger document. "
+         f"Do not include the question or conversational text.\n\n"
+         f"Question: {query}\n\n"
+         f"Hypothetical Passage:"
+     )
+
+     try:
+         chat_completion = await client.chat.completions.create(
+             messages=[{"role": "user", "content": prompt}],
+             model=HYDE_MODEL,
+             temperature=0.7,
+             max_tokens=500,
+         )
+         return chat_completion.choices[0].message.content
+     except Exception as e:
+         print(f"An error occurred during HyDE generation: {e}")
+         return ""
+
+ class Retriever:
+     """Manages hybrid search, combining BM25, dense search, and a reranker."""
+
+     def __init__(self, embedding_client: EmbeddingClient):
+         self.embedding_client = embedding_client
+         self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
+         self.bm25 = None
+         self.document_chunks = []
+         self.chunk_embeddings = None
+         print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")
+
+     def index(self, documents: List[Document]):
+         """Builds the search index from document chunks."""
+         self.document_chunks = documents
+         corpus = [doc.page_content for doc in documents]
+         if not corpus:
+             print("No documents to index.")
+             return
+
+         print("Indexing documents for retrieval...")
+         # 1. Initialize BM25 model
+         tokenized_corpus = [doc.split(" ") for doc in corpus]
+         self.bm25 = BM25Okapi(tokenized_corpus)
+         # 2. Compute and store dense embeddings
+         self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
+         print("Indexing complete.")
+
+     def _hybrid_search(self, query: str, hyde_doc: str) -> List[Tuple[int, float]]:
+         """Performs the initial hybrid search to get candidate chunks."""
+         if self.bm25 is None or self.chunk_embeddings is None:
+             raise ValueError("Retriever has not been indexed. Call index() first.")
+
+         # Enhance query with hypothetical document
+         enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
+
+         # BM25 (keyword) search
+         tokenized_query = query.split(" ")
+         bm25_scores = self.bm25.get_scores(tokenized_query)
+
+         # Dense (semantic) search
+         query_embedding = self.embedding_client.create_embeddings([enhanced_query])
+         dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()
+
+         # Normalize and combine scores
+         scaler = MinMaxScaler()
+         norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
+         norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
+         combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense
+
+         # Get top initial candidates
+         top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
+         return [(idx, combined_scores[idx]) for idx in top_indices]
+
+     async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
+         """Reranks the candidate chunks using a CrossEncoder model."""
+         if not candidates:
+             return []
+
+         print(f"Reranking {len(candidates)} candidates...")
+         rerank_input = [[query, chunk["content"]] for chunk in candidates]
+
+         # Run synchronous prediction in a separate thread
+         rerank_scores = await asyncio.to_thread(
+             self.reranker.predict, rerank_input, show_progress_bar=False
+         )
+
+         # Combine candidates with their new scores and sort
+         for candidate, score in zip(candidates, rerank_scores):
+             candidate['rerank_score'] = score
+
+         candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
+         return candidates[:TOP_K_CHUNKS]
+
+     async def retrieve(self, query: str, hyde_doc: str) -> List[Dict]:
+         """Executes the full retrieval pipeline: hybrid search followed by reranking."""
+         print(f"Retrieving documents for query: '{query}'")
+         # 1. Get initial candidates from hybrid search
+         initial_candidates_info = self._hybrid_search(query, hyde_doc)
+
+         retrieved_candidates = [{
+             "content": self.document_chunks[idx].page_content,
+             "metadata": self.document_chunks[idx].metadata,
+             "initial_score": score
+         } for idx, score in initial_candidates_info]
+
+         # 2. Rerank the candidates to get the final list
+         final_chunks = await self._rerank(query, retrieved_candidates)
+         print(f"Retrieved and reranked {len(final_chunks)} final chunks.")
+         return final_chunks
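
The Retriever can also be exercised without the FastAPI layer. The sketch below is not part of the commit; it indexes two toy chunks and runs the hybrid-search-plus-rerank path, skipping HyDE by passing an empty hypothetical document.

```python
# Sketch: index a few chunks and retrieve against them.
import asyncio
from langchain_core.documents import Document
from embedding import EmbeddingClient
from retrieval import Retriever

async def main() -> None:
    retriever = Retriever(embedding_client=EmbeddingClient())
    retriever.index([
        Document(page_content="The grace period for premium payment is 30 days.",
                 metadata={"content_type": "text"}),
        Document(page_content="Room rent is capped at 1% of the sum insured per day.",
                 metadata={"content_type": "text"}),
    ])
    # An empty hyde_doc makes the dense search fall back to the plain query.
    chunks = await retriever.retrieve("How long is the grace period?", hyde_doc="")
    for chunk in chunks:
        print(round(float(chunk["rerank_score"]), 3), chunk["content"])

if __name__ == "__main__":
    asyncio.run(main())
```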