PercivalFletcher committed on
Commit 5abe5ee · verified · 1 Parent(s): 02728c5

Upload 7 files

Files changed (7)
  1. chunking_parent.py +79 -0
  2. embedding.py +40 -0
  3. generation.py +57 -0
  4. ingestion_router.py +129 -0
  5. main.py +158 -0
  6. pdf_parallel_parser.py +78 -0
  7. retrieval_parent.py +193 -0
chunking_parent.py ADDED
@@ -0,0 +1,79 @@
+ # file: chunking_parent.py
+ import uuid
+ from typing import List, Tuple, Dict, Any
+ from langchain_core.documents import Document
+ from langchain.storage import InMemoryStore
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+ # --- Configuration for Parent-Child Splitting ---
+ # Parent chunks are the larger documents passed to the LLM for context.
+ PARENT_CHUNK_SIZE = 2000
+ PARENT_CHUNK_OVERLAP = 200
+
+ # Child chunks are the smaller, more granular documents used for retrieval.
+ CHILD_CHUNK_SIZE = 400
+ CHILD_CHUNK_OVERLAP = 100
+
+ def create_parent_child_chunks(
+     full_text: str
+ ) -> Tuple[List[Document], InMemoryStore, Dict[str, str]]:
+     """
+     Implements the Parent Document strategy for chunking.
+
+     1. Splits the document into larger "parent" chunks.
+     2. Splits the parent chunks into smaller "child" chunks.
+     3. The child chunks are used for retrieval, while the parent chunks
+        are used to provide context to the LLM.
+
+     Args:
+         full_text: The entire text content of the document.
+
+     Returns:
+         A tuple containing:
+         - A list of the small "child" documents for the vector store.
+         - An in-memory store mapping parent document IDs to the parent documents.
+         - A dictionary mapping child document IDs to their parent's ID.
+     """
+     if not full_text:
+         print("Warning: Input text for chunking is empty.")
+         return [], InMemoryStore(), {}
+
+     print("Creating parent and child chunks...")
+
+     # This splitter creates the large documents that will be stored.
+     parent_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=PARENT_CHUNK_SIZE,
+         chunk_overlap=PARENT_CHUNK_OVERLAP,
+     )
+
+     # This splitter creates the small, granular chunks for retrieval.
+     child_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=CHILD_CHUNK_SIZE,
+         chunk_overlap=CHILD_CHUNK_OVERLAP,
+     )
+
+     parent_documents = parent_splitter.create_documents([full_text])
+
+     docstore = InMemoryStore()
+     child_documents = []
+     child_to_parent_id_map = {}
+
+     # Generate unique IDs for each parent document and add them to the store.
+     parent_ids = [str(uuid.uuid4()) for _ in parent_documents]
+     docstore.mset(list(zip(parent_ids, parent_documents)))
+
+     # Split each parent document into smaller child documents.
+     for i, p_doc in enumerate(parent_documents):
+         parent_id = parent_ids[i]
+         _child_docs = child_splitter.split_documents([p_doc])
+
+         for _child_doc in _child_docs:
+             child_id = str(uuid.uuid4())
+             _child_doc.metadata["parent_id"] = parent_id
+             _child_doc.metadata["child_id"] = child_id
+             child_to_parent_id_map[child_id] = parent_id
+
+         child_documents.extend(_child_docs)
+
+     print(f"Created {len(parent_documents)} parent chunks and {len(child_documents)} child chunks.")
+     return child_documents, docstore, child_to_parent_id_map
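A minimal usage sketch for this module, assuming it is imported as chunking_parent; the sample text and prints below are illustrative only, not part of the upload:

# example_chunking.py (illustrative)
from chunking_parent import create_parent_child_chunks

sample_text = "Section 1. Coverage begins on the policy start date. " * 200
child_docs, docstore, child_to_parent = create_parent_child_chunks(sample_text)

# Each child chunk carries the ID of the parent chunk it was split from.
first_child = child_docs[0]
parent_doc = docstore.mget([first_child.metadata["parent_id"]])[0]
print(first_child.metadata["parent_id"] in child_to_parent.values())       # True
print(len(parent_doc.page_content) >= len(first_child.page_content))       # parents are the larger context chunks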
embedding.py ADDED
@@ -0,0 +1,40 @@
+ # file: embedding.py
+
+ import torch
+ from sentence_transformers import SentenceTransformer
+ from typing import List
+
+ # --- Configuration ---
+ EMBEDDING_MODEL_NAME = "sentence-transformers/stsb-xlm-r-multilingual"
+
+ class EmbeddingClient:
+     """A client for generating text embeddings using a local sentence transformer model."""
+
+     def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = SentenceTransformer(model_name, device=self.device)
+         print(f"EmbeddingClient initialized with model '{model_name}' on device '{self.device}'.")
+
+     def create_embeddings(self, texts: List[str]) -> torch.Tensor:
+         """
+         Generates embeddings for a list of text chunks.
+
+         Args:
+             texts: A list of strings to be embedded.
+
+         Returns:
+             A torch.Tensor containing the generated embeddings.
+         """
+         if not texts:
+             return torch.tensor([])
+
+         print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
+         try:
+             embeddings = self.model.encode(
+                 texts, convert_to_tensor=True, show_progress_bar=False
+             )
+             print("Embeddings generated successfully.")
+             return embeddings
+         except Exception as e:
+             print(f"An error occurred during embedding generation: {e}")
+             raise
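A short usage sketch; the sentences below are placeholders, and the model weights are downloaded on first use:

# example_embedding.py (illustrative)
from embedding import EmbeddingClient

client = EmbeddingClient()  # picks CUDA if available, otherwise CPU
vectors = client.create_embeddings(["What is the waiting period?", "Coverage starts after 30 days."])
print(vectors.shape)  # (2, embedding_dim) for this model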
generation.py ADDED
@@ -0,0 +1,57 @@
+ # file: generation.py
+ from groq import AsyncGroq
+ from typing import List, Dict
+
+ # --- Configuration ---
+ GROQ_MODEL_NAME = "llama3-8b-8192"
+
+ async def generate_answer(query: str, context_chunks: List[Dict], groq_api_key: str) -> str:
+     """
+     Generates a final answer using the Groq API based on the query and retrieved context.
+
+     Args:
+         query: The user's original question.
+         context_chunks: A list of the most relevant, reranked document chunks.
+         groq_api_key: The API key for the Groq service.
+
+     Returns:
+         A string containing the generated answer.
+     """
+     if not groq_api_key:
+         return "Error: Groq API key is not set."
+     if not context_chunks:
+         return "I do not have enough information to answer this question based on the provided document."
+
+     print("Generating final answer with Groq...")
+     client = AsyncGroq(api_key=groq_api_key)
+
+     # Format the context for the prompt
+     context_str = "\n\n---\n\n".join(
+         [f"Context Chunk:\n{chunk['content']}" for chunk in context_chunks]
+     )
+
+     prompt = (
+         "You are an expert Q&A system. Your task is to extract information with 100% accuracy from the provided text. Provide a brief and direct answer."
+         " Do not mention the context in your response. Answer *only* using the information from the provided document."
+         " Do not infer, add, or assume any information that is not explicitly written in the source text. If the answer is not in the document, state that the information is not available."
+         " When the question involves numbers, percentages, or monetary values, extract the exact figures from the text."
+         " Double-check that the value corresponds to the correct plan or condition mentioned in the question."
+         "\n\n"
+         f"CONTEXT:\n{context_str}\n\n"
+         f"QUESTION:\n{query}\n\n"
+         "ANSWER:"
+     )
+
+     try:
+         chat_completion = await client.chat.completions.create(
+             messages=[{"role": "user", "content": prompt}],
+             model=GROQ_MODEL_NAME,
+             temperature=0.2,  # Lower temperature for more factual answers
+             max_tokens=500,
+         )
+         answer = chat_completion.choices[0].message.content
+         print("Answer generated successfully.")
+         return answer
+     except Exception as e:
+         print(f"An error occurred during Groq API call: {e}")
+         return "Could not generate an answer due to an API error."
ingestion_router.py ADDED
@@ -0,0 +1,129 @@
+ # file: ingestion_router.py
+ import os
+ import time
+ import httpx
+ import zipfile
+ import io
+ import asyncio
+ from PIL import Image
+ from pathlib import Path
+ from urllib.parse import urlparse, unquote
+ from pydantic import HttpUrl
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+
+ # --- Import our custom parsers ---
+ from pdf_parallel_parser import process_pdf_with_hybrid_parallel_sync
+ from complex_parser import process_image_element
+
+ # --- A simple parser for generic files (DOCX, etc.) using unstructured ---
+ from unstructured.partition.auto import partition
+
+ # --- Configuration ---
+ LOCAL_STORAGE_DIR = "data/"
+
+ # --- Synchronous, CPU-Bound Parsing Functions ---
+
+ def _process_generic_file_sync(file_content: bytes, filename: str) -> str:
+     """Fallback parser for standard files like DOCX, PPTX, etc., using unstructured."""
+     print(f"Processing '{filename}' with unstructured (standard)...")
+     try:
+         elements = partition(file=io.BytesIO(file_content), file_filename=filename)
+         return "\n\n".join([el.text for el in elements])
+     except Exception as e:
+         print(f"Unstructured failed for {filename}: {e}")
+         return ""
+
+ def _process_zip_file_in_parallel(zip_content: bytes, temp_dir: Path) -> str:
+     """Extracts and processes files from a ZIP archive in parallel."""
+     print("Initiating parallel processing of ZIP archive...")
+     all_extracted_texts = []
+
+     def process_single_zipped_file(zf: zipfile.ZipFile, file_info: zipfile.ZipInfo) -> str:
+         file_content = zf.read(file_info.filename)
+         file_extension = Path(file_info.filename).suffix.lower()
+
+         # Route to the appropriate synchronous parser
+         if file_extension == '.pdf':
+             temp_file_path = temp_dir / Path(file_info.filename).name
+             temp_file_path.parent.mkdir(parents=True, exist_ok=True)
+             temp_file_path.write_bytes(file_content)
+             return process_pdf_with_hybrid_parallel_sync(temp_file_path)
+         elif file_extension in ['.png', '.jpg', '.jpeg']:
+             return process_image_element(Image.open(io.BytesIO(file_content)))
+         elif file_extension in ['.docx', '.pptx', '.html']:
+             return _process_generic_file_sync(file_content, file_info.filename)
+         else:
+             print(f"Skipping unsupported file in ZIP: {file_info.filename}")
+             return ""
+
+     with zipfile.ZipFile(io.BytesIO(zip_content)) as zf:
+         file_list = [info for info in zf.infolist() if not info.is_dir()]
+         with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as executor:
+             future_to_file = {executor.submit(process_single_zipped_file, zf, file_info): file_info for file_info in file_list}
+             for future in as_completed(future_to_file):
+                 try:
+                     text = future.result()
+                     if text:
+                         all_extracted_texts.append(f"--- Content from: {future_to_file[future].filename} ---\n{text}")
+                 except Exception as e:
+                     print(f"Error processing file '{future_to_file[future].filename}' from ZIP: {e}")
+
+     return "\n\n".join(all_extracted_texts)
+
+ # --- Main Asynchronous Ingestion and Routing Function ---
+
+ async def ingest_and_parse_document(doc_url: HttpUrl) -> str:
+     """Asynchronously downloads and parses a document using the optimal strategy."""
+     print(f"Initiating processing for URL: {doc_url}")
+     os.makedirs(LOCAL_STORAGE_DIR, exist_ok=True)
+     start_time = time.perf_counter()
+
+     try:
+         async with httpx.AsyncClient() as client:
+             response = await client.get(str(doc_url), timeout=120.0, follow_redirects=True)
+             response.raise_for_status()
+             doc_bytes = response.content
+             print("Download successful.")
+
+         filename = unquote(os.path.basename(urlparse(str(doc_url)).path)) or "downloaded_file"
+         local_file_path = Path(LOCAL_STORAGE_DIR) / filename
+         file_extension = local_file_path.suffix.lower()
+
+         # Run the appropriate CPU-bound parsing function in a separate thread
+         if file_extension == '.pdf':
+             local_file_path.write_bytes(doc_bytes)
+             doc_text = await asyncio.to_thread(process_pdf_with_hybrid_parallel_sync, local_file_path)
+         elif file_extension == '.zip':
+             doc_text = await asyncio.to_thread(_process_zip_file_in_parallel, doc_bytes, Path(LOCAL_STORAGE_DIR))
+         elif file_extension in ['.png', '.jpg', '.jpeg']:
+             image = Image.open(io.BytesIO(doc_bytes))
+             doc_text = await asyncio.to_thread(process_image_element, image)
+         elif file_extension in ['.docx', '.pptx', '.html']:
+             doc_text = await asyncio.to_thread(_process_generic_file_sync, doc_bytes, filename)
+         else:
+             raise ValueError(f"Unsupported file type: {file_extension}")
+
+         elapsed_time = time.perf_counter() - start_time
+         print(f"Total processing time: {elapsed_time:.4f} seconds.")
+         if not doc_text.strip():
+             raise ValueError("Document parsing yielded no content.")
+
+         return doc_text
+
+     except Exception as e:
+         print(f"An unexpected error occurred: {e}")
+         raise
+
+ # # Example of how to run the pipeline
+ # async def main():
+ #     # Example URL pointing to a PDF with tables
+ #     pdf_url = HttpUrl("https://www.w3.org/WAI/WCAG21/working-examples/pdf-table-linearized/table.pdf")
+ #     try:
+ #         content = await ingest_and_parse_document(pdf_url)
+ #         print("\n--- FINAL EXTRACTED CONTENT ---")
+ #         print(content)
+ #     except Exception as e:
+ #         print(f"Pipeline failed: {e}")
+
+ # if __name__ == "__main__":
+ #     asyncio.run(main())
main.py ADDED
@@ -0,0 +1,158 @@
+ # file: main.py
+ import time
+ import os
+ import asyncio
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel, HttpUrl
+ from typing import List, Dict, Any
+ from dotenv import load_dotenv
+
+ # Assuming 'ingestion_router.py' is in the same directory and contains the function
+ from ingestion_router import ingest_and_parse_document
+ from chunking_parent import create_parent_child_chunks
+ from embedding import EmbeddingClient
+ from retrieval_parent import Retriever
+ from generation import generate_answer
+
+ load_dotenv()
+
+ app = FastAPI(
+     title="Modular RAG API",
+     description="A modular API for Retrieval-Augmented Generation with Parent-Child Retrieval.",
+     version="2.3.0",
+ )
+
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
+ embedding_client = EmbeddingClient()
+ retriever = Retriever(embedding_client=embedding_client)
+
+ # --- Pydantic Models ---
+ class RunRequest(BaseModel):
+     documents: HttpUrl
+     questions: List[str]
+
+ class RunResponse(BaseModel):
+     answers: List[str]
+
+ class TestRequest(BaseModel):
+     documents: HttpUrl
+
+ # --- Test Endpoint for Ingestion and Parsing ---
+ @app.post("/test/ingestion", response_model=Dict[str, Any], tags=["Testing"])
+ async def test_ingestion_endpoint(request: TestRequest):
+     """
+     Tests the complete ingestion and parsing pipeline.
+     Downloads a document from a URL, processes it using the modular
+     parsing strategy (e.g., parallel for PDF, standard for DOCX),
+     and returns the extracted Markdown content and time taken.
+     """
+     print("--- Running Document Ingestion & Parsing Test ---")
+     start_time = time.perf_counter()
+     try:
+         # Step 1: Call the main ingestion function from the router
+         markdown_content = await ingest_and_parse_document(request.documents)
+
+         end_time = time.perf_counter()
+         duration = end_time - start_time
+         print(f"--- Ingestion and Parsing took {duration:.2f} seconds ---")
+
+         if not markdown_content:
+             raise HTTPException(
+                 status_code=404,
+                 detail="Document processed, but no content was extracted."
+             )
+
+         return {
+             "total_time_seconds": duration,
+             "character_count": len(markdown_content),
+             "extracted_content": markdown_content,
+         }
+     except HTTPException:
+         # Re-raise deliberate HTTP errors (e.g. the 404 above) instead of wrapping them as 500.
+         raise
+     except Exception as e:
+         # Catch potential download errors, parsing errors, or unsupported file types
+         raise HTTPException(status_code=500, detail=f"An error occurred during ingestion test: {str(e)}")
+
+
+ # --- Test Endpoint for Parent-Child Chunking ---
+ @app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
+ async def test_chunking_endpoint(request: TestRequest):
+     """
+     Tests the parent-child chunking strategy.
+     Returns parent chunks, child chunks, and the time taken.
+     """
+     print("--- Running Parent-Child Chunking Test ---")
+     start_time = time.perf_counter()
+
+     try:
+         # Step 1: Parse the document to get raw text
+         markdown_content = await ingest_and_parse_document(request.documents)
+
+         # Step 2: Create parent and child chunks
+         child_documents, docstore, _ = create_parent_child_chunks(markdown_content)
+
+         end_time = time.perf_counter()
+         duration = end_time - start_time
+         print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")
+
+         # Convert Document objects to a JSON-serializable list for the response
+         child_chunk_results = [
+             {"page_content": doc.page_content, "metadata": doc.metadata}
+             for doc in child_documents
+         ]
+
+         # Retrieve parent documents from the in-memory store
+         parent_docs = docstore.mget(list(docstore.store.keys()))
+         parent_chunk_results = [
+             {"page_content": doc.page_content, "metadata": doc.metadata}
+             for doc in parent_docs if doc
+         ]
+
+         return {
+             "total_time_seconds": duration,
+             "parent_chunk_count": len(parent_chunk_results),
+             "child_chunk_count": len(child_chunk_results),
+             "parent_chunks": parent_chunk_results,
+             "child_chunks": child_chunk_results,
+         }
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"An error occurred during chunking test: {str(e)}")
+
+
+ @app.post("/hackrx/run", response_model=RunResponse)
+ async def run_rag_pipeline(request: RunRequest):
+     try:
+         print("--- Kicking off RAG Pipeline with Parent-Child Strategy ---")
+
+         # --- STAGE 1: DOCUMENT INGESTION ---
+         markdown_content = await ingest_and_parse_document(request.documents)
+
+         # --- STAGE 2: PARENT-CHILD CHUNKING ---
+         child_documents, docstore, _ = create_parent_child_chunks(markdown_content)
+
+         if not child_documents:
+             raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")
+
+         # --- STAGE 3: INDEXING ---
+         retriever.index(child_documents, docstore)
+
+         # --- STAGE 4: CONCURRENT RETRIEVAL & GENERATION ---
+         print("Starting retrieval for all questions...")
+         retrieval_tasks = [
+             retriever.retrieve(q, GROQ_API_KEY)
+             for q in request.questions
+         ]
+         all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)
+         print("Retrieval complete. Starting final answer generation...")
+
+         answer_tasks = [
+             generate_answer(q, chunks, GROQ_API_KEY)
+             for q, chunks in zip(request.questions, all_retrieved_chunks)
+         ]
+         final_answers = await asyncio.gather(*answer_tasks)
+
+         print("--- RAG Pipeline Completed Successfully ---")
+         return RunResponse(answers=final_answers)
+
+     except HTTPException:
+         # Re-raise deliberate HTTP errors (e.g. the 400 above) instead of wrapping them as 500.
+         raise
+     except Exception as e:
+         raise HTTPException(
+             status_code=500, detail=f"An internal server error occurred: {str(e)}"
+         )
pdf_parallel_parser.py ADDED
@@ -0,0 +1,78 @@
+ # file: pdf_parallel_parser.py
+
+ import fitz  # PyMuPDF
+ from PIL import Image
+ import io
+ import os
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from pathlib import Path
+
+ # Import the specialized parsers from our other module
+ from complex_parser import process_table_element, process_image_element
+
+ def _is_bbox_contained(inner_bbox, outer_bbox):
+     """Check if inner_bbox is fully inside outer_bbox."""
+     return (inner_bbox[0] >= outer_bbox[0] and
+             inner_bbox[1] >= outer_bbox[1] and
+             inner_bbox[2] <= outer_bbox[2] and
+             inner_bbox[3] <= outer_bbox[3])
+
+ def _process_page(page: fitz.Page) -> str:
+     """
+     Processes a single PDF page to extract text, tables, and images.
+     - Tables are found and processed with the complex_parser.
+     - Plain text is extracted, excluding any text already inside a processed table.
+     """
+     page_content = []
+
+     # 1. Find and process tables first
+     table_bboxes = []
+     try:
+         tables = page.find_tables()
+         pix = page.get_pixmap(dpi=200)
+         page_image = Image.open(io.BytesIO(pix.tobytes("png")))
+         zoom = 200 / 72  # pixels per PDF point at the 200-dpi render
+
+         print(f"Page {page.number}: Found {len(tables.tables)} potential tables.")
+         for table in tables.tables:
+             table_bboxes.append(table.bbox)
+             # table.bbox is in PDF points; scale it to the pixel grid of the rendered image before cropping.
+             x0, y0, x1, y1 = table.bbox
+             table_image = page_image.crop((int(x0 * zoom), int(y0 * zoom), int(x1 * zoom), int(y1 * zoom)))
+             markdown_table = process_table_element(table_image)
+             page_content.append(markdown_table)
+     except Exception as e:
+         print(f"Could not process tables on page {page.number}: {e}")
+
+     # 2. Extract text blocks, excluding those within table bounding boxes
+     text_blocks = page.get_text("blocks")
+     for block in text_blocks:
+         block_bbox = block[:4]
+         # Check if this text block is inside any of the tables we just processed
+         is_in_table = any(_is_bbox_contained(block_bbox, table_bbox) for table_bbox in table_bboxes)
+         if not is_in_table:
+             page_content.append(block[4].strip())
+
+     # Note: Image extraction can be added here if needed, similar to table extraction.
+
+     return "\n".join(page_content)
+
+ def process_pdf_with_hybrid_parallel_sync(file_path: Path) -> str:
+     """
+     Processes a PDF file in parallel using PyMuPDF and the complex_parser.
+     """
+     print(f"Processing PDF '{file_path.name}' with parallel page-by-page strategy...")
+     all_page_texts = []
+     doc = fitz.open(file_path)
+
+     with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as executor:
+         futures = {executor.submit(_process_page, page): page.number for page in doc}
+
+         # Collect results in page order
+         results = ["" for _ in range(len(doc))]
+         for future in as_completed(futures):
+             page_num = futures[future]
+             try:
+                 results[page_num] = future.result()
+             except Exception as e:
+                 print(f"Error processing page {page_num}: {e}")
+         all_page_texts = results
+
+     return "\n\n--- Page Break ---\n\n".join(all_page_texts)
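A quick way to try the parser on its own, assuming a local PDF at an illustrative path and that the complex_parser module (not part of this upload) is importable:

# example_pdf_parse.py (illustrative)
from pathlib import Path
from pdf_parallel_parser import process_pdf_with_hybrid_parallel_sync

markdown_text = process_pdf_with_hybrid_parallel_sync(Path("data/sample.pdf"))
print(markdown_text[:500])  # pages are joined with "--- Page Break ---" separators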
retrieval_parent.py ADDED
@@ -0,0 +1,193 @@
+ # file: retrieval_parent.py
+
+ import time
+ import asyncio
+ import numpy as np
+ import torch
+ import json
+ from groq import AsyncGroq
+ from rank_bm25 import BM25Okapi
+ from sentence_transformers import CrossEncoder
+ from sklearn.preprocessing import MinMaxScaler
+ from torch.nn.functional import cosine_similarity
+ from typing import List, Dict, Tuple
+ from langchain.storage import InMemoryStore
+
+ from embedding import EmbeddingClient
+ from langchain_core.documents import Document
+
+ # --- Configuration ---
+ GENERATION_MODEL = "llama3-8b-8192"
+ RERANKER_MODEL = 'cross-encoder/stsb-distilroberta-base'
+ INITIAL_K_CANDIDATES = 20
+ TOP_K_CHUNKS = 10
+
+ async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
+     """Generates a hypothetical document to answer the query (HyDE)."""
+     if not groq_api_key:
+         print("Groq API key not set. Skipping HyDE generation.")
+         return ""
+
+     print(f"Starting HyDE generation for query: '{query}'...")
+     client = AsyncGroq(api_key=groq_api_key)
+     prompt = (
+         f"Write a brief, formal passage that directly answers the following question. "
+         f"This passage will be used to find similar documents. "
+         f"Do not include the question or any conversational text.\n\n"
+         f"Question: {query}\n\n"
+         f"Hypothetical Passage:"
+     )
+
+     start_time = time.perf_counter()
+     try:
+         chat_completion = await client.chat.completions.create(
+             messages=[{"role": "user", "content": prompt}],
+             model=GENERATION_MODEL,
+             temperature=0.7,
+             max_tokens=500,
+         )
+         end_time = time.perf_counter()
+         print(f"--- HyDE generation took {end_time - start_time:.4f} seconds ---")
+         return chat_completion.choices[0].message.content
+     except Exception as e:
+         print(f"An error occurred during HyDE generation: {e}")
+         return ""
+
+ async def generate_expanded_terms(query: str, groq_api_key: str) -> List[str]:
+     """Generates semantically related search terms for a query."""
+     if not groq_api_key:
+         print("Groq API key not set. Skipping Semantic Expansion.")
+         return [query]
+
+     print(f"Starting Semantic Expansion for query: '{query}'...")
+     client = AsyncGroq(api_key=groq_api_key)
+     prompt = (
+         f"You are a search query expansion expert. Based on the following query, generate up to 4 additional, "
+         f"semantically related search terms. The terms should be relevant for finding information in technical documents. "
+         f"Return the original query plus the new terms as a single JSON list of strings.\n\n"
+         f"Query: \"{query}\"\n\n"
+         f"JSON List:"
+     )
+
+     start_time = time.perf_counter()
+     try:
+         chat_completion = await client.chat.completions.create(
+             messages=[{"role": "user", "content": prompt}],
+             model=GENERATION_MODEL,
+             temperature=0.4,
+             max_tokens=200,
+             response_format={"type": "json_object"},
+         )
+         end_time = time.perf_counter()
+         print(f"--- Semantic Expansion took {end_time - start_time:.4f} seconds ---")
+
+         result_text = chat_completion.choices[0].message.content
+         terms = json.loads(result_text)
+
+         # JSON-object mode may wrap the list under a key, e.g. {"terms": [...]}.
+         if isinstance(terms, dict) and 'terms' in terms:
+             return terms['terms']
+         return terms if isinstance(terms, list) else [query]
+     except Exception as e:
+         print(f"An error occurred during Semantic Expansion: {e}")
+         return [query]
+
+
+ class Retriever:
+     """Manages hybrid search with parent-child retrieval."""
+
+     def __init__(self, embedding_client: EmbeddingClient):
+         self.embedding_client = embedding_client
+         self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
+         self.bm25 = None
+         self.document_chunks = []
+         self.chunk_embeddings = None
+         self.docstore = InMemoryStore()
+         print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")
+
+     def index(self, child_documents: List[Document], docstore: InMemoryStore):
+         """Builds the search index from child documents and stores parent documents."""
+         self.document_chunks = child_documents
+         self.docstore = docstore
+
+         corpus = [doc.page_content for doc in child_documents]
+         if not corpus:
+             print("No documents to index.")
+             return
+
+         print("Indexing child documents for retrieval...")
+         tokenized_corpus = [doc.split(" ") for doc in corpus]
+         self.bm25 = BM25Okapi(tokenized_corpus)
+         self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
+         print("Indexing complete.")
+
+     def _hybrid_search(self, query: str, hyde_doc: str, expanded_terms: List[str]) -> List[Tuple[int, float]]:
+         """Performs a hybrid search using expanded terms for BM25 and a HyDE doc for dense search."""
+         if self.bm25 is None or self.chunk_embeddings is None:
+             raise ValueError("Retriever has not been indexed. Call index() first.")
+
+         print(f"Running BM25 with expanded terms: {expanded_terms}")
+         # BM25 expects a tokenized query, so multi-word expansion terms are split into individual tokens.
+         tokenized_query = " ".join(expanded_terms).split(" ")
+         bm25_scores = self.bm25.get_scores(tokenized_query)
+
+         enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
+         query_embedding = self.embedding_client.create_embeddings([enhanced_query])
+         dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()
+
+         scaler = MinMaxScaler()
+         norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
+         norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
+         combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense
+
+         top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
+         return [(idx, combined_scores[idx]) for idx in top_indices]
+
+     async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
+         """Reranks candidates using a CrossEncoder model."""
+         if not candidates:
+             return []
+
+         print(f"Reranking {len(candidates)} candidates...")
+         rerank_input = [[query, chunk["content"]] for chunk in candidates]
+
+         rerank_scores = await asyncio.to_thread(
+             self.reranker.predict, rerank_input, show_progress_bar=False
+         )
+
+         for candidate, score in zip(candidates, rerank_scores):
+             candidate['rerank_score'] = score
+
+         candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
+         return candidates[:TOP_K_CHUNKS]
+
+     async def retrieve(self, query: str, groq_api_key: str) -> List[Dict]:
+         """Executes the full retrieval pipeline: expansion, HyDE, hybrid search, and reranking."""
+         print(f"\n--- Retrieving documents for query: '{query}' ---")
+
+         hyde_task = generate_hypothetical_document(query, groq_api_key)
+         expansion_task = generate_expanded_terms(query, groq_api_key)
+         hyde_doc, expanded_terms = await asyncio.gather(hyde_task, expansion_task)
+
+         initial_candidates_info = self._hybrid_search(query, hyde_doc, expanded_terms)
+
+         retrieved_child_docs = [{
+             "content": self.document_chunks[idx].page_content,
+             "metadata": self.document_chunks[idx].metadata,
+         } for idx, score in initial_candidates_info]
+
+         reranked_child_docs = await self._rerank(query, retrieved_child_docs)
+
+         # Collect parent IDs in rank order, without duplicates.
+         parent_ids = []
+         for doc in reranked_child_docs:
+             parent_id = doc["metadata"]["parent_id"]
+             if parent_id not in parent_ids:
+                 parent_ids.append(parent_id)
+
+         retrieved_parents = self.docstore.mget(parent_ids)
+         final_parent_docs = [doc for doc in retrieved_parents if doc is not None]
+
+         final_context = [{
+             "content": doc.page_content,
+             "metadata": doc.metadata
+         } for doc in final_parent_docs]
+
+         print(f"Retrieved {len(final_context)} final parent chunks for context.")
+         return final_context
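A sketch of how the pieces fit together outside the API, assuming a GROQ_API_KEY environment variable (HyDE and expansion are skipped gracefully without it) and already-parsed document text; the sample text and question are illustrative:

# example_retrieval.py (illustrative)
import os
import asyncio
from chunking_parent import create_parent_child_chunks
from embedding import EmbeddingClient
from retrieval_parent import Retriever

async def main():
    text = "The policy covers hospitalization expenses up to the sum insured. " * 100
    child_docs, docstore, _ = create_parent_child_chunks(text)
    retriever = Retriever(embedding_client=EmbeddingClient())
    retriever.index(child_docs, docstore)
    context = await retriever.retrieve("What does the policy cover?", os.environ.get("GROQ_API_KEY", ""))
    for chunk in context:
        print(chunk["metadata"], chunk["content"][:80])

asyncio.run(main())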