Upload 7 files
- chunking_parent.py +79 -0
- embedding.py +40 -0
- generation.py +57 -0
- ingestion_router.py +129 -0
- main.py +158 -0
- pdf_parallel_parser.py +78 -0
- retrieval_parent.py +193 -0
chunking_parent.py
ADDED
@@ -0,0 +1,79 @@
# file: chunking_parent.py
import uuid
from typing import List, Tuple, Dict, Any
from langchain_core.documents import Document
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter

# --- Configuration for Parent-Child Splitting ---
# Parent chunks are the larger documents passed to the LLM for context.
PARENT_CHUNK_SIZE = 2000
PARENT_CHUNK_OVERLAP = 200

# Child chunks are the smaller, more granular documents used for retrieval.
CHILD_CHUNK_SIZE = 400
CHILD_CHUNK_OVERLAP = 100

def create_parent_child_chunks(
    full_text: str
) -> Tuple[List[Document], InMemoryStore, Dict[str, str]]:
    """
    Implements the Parent Document strategy for chunking.

    1. Splits the document into larger "parent" chunks.
    2. Splits the parent chunks into smaller "child" chunks.
    3. The child chunks are used for retrieval, while the parent chunks
       are used to provide context to the LLM.

    Args:
        full_text: The entire text content of the document.

    Returns:
        A tuple containing:
        - A list of the small "child" documents for the vector store.
        - An in-memory store mapping parent document IDs to the parent documents.
        - A dictionary mapping child document IDs to their parent's ID.
    """
    if not full_text:
        print("Warning: Input text for chunking is empty.")
        return [], InMemoryStore(), {}

    print("Creating parent and child chunks...")

    # This splitter creates the large documents that will be stored.
    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=PARENT_CHUNK_SIZE,
        chunk_overlap=PARENT_CHUNK_OVERLAP,
    )

    # This splitter creates the small, granular chunks for retrieval.
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHILD_CHUNK_SIZE,
        chunk_overlap=CHILD_CHUNK_OVERLAP,
    )

    parent_documents = parent_splitter.create_documents([full_text])

    docstore = InMemoryStore()
    child_documents = []
    child_to_parent_id_map = {}

    # Generate unique IDs for each parent document and add them to the store
    parent_ids = [str(uuid.uuid4()) for _ in parent_documents]
    docstore.mset(list(zip(parent_ids, parent_documents)))

    # Split each parent document into smaller child documents
    for i, p_doc in enumerate(parent_documents):
        parent_id = parent_ids[i]
        _child_docs = child_splitter.split_documents([p_doc])

        for _child_doc in _child_docs:
            child_id = str(uuid.uuid4())
            _child_doc.metadata["parent_id"] = parent_id
            _child_doc.metadata["child_id"] = child_id
            child_to_parent_id_map[child_id] = parent_id

        child_documents.extend(_child_docs)

    print(f"Created {len(parent_documents)} parent chunks and {len(child_documents)} child chunks.")
    return child_documents, docstore, child_to_parent_id_map
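A minimal usage sketch (not part of the upload) that feeds a placeholder string through create_parent_child_chunks and resolves one child chunk back to its parent via the returned docstore; sample_text is an arbitrary stand-in for parsed document text.

# Usage sketch (illustrative only): link a retrieved child chunk back to its parent.
from chunking_parent import create_parent_child_chunks

sample_text = "The policy has a thirty day grace period for premium payment. " * 200
children, store, child_to_parent = create_parent_child_chunks(sample_text)

first_child = children[0]
parent_id = first_child.metadata["parent_id"]   # same id recorded in child_to_parent
parent_doc = store.mget([parent_id])[0]         # the larger chunk later handed to the LLM
print(len(first_child.page_content), len(parent_doc.page_content))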
embedding.py
ADDED
@@ -0,0 +1,40 @@
# file: embedding.py

import torch
from sentence_transformers import SentenceTransformer
from typing import List

# --- Configuration ---
EMBEDDING_MODEL_NAME = "sentence-transformers/stsb-xlm-r-multilingual"

class EmbeddingClient:
    """A client for generating text embeddings using a local sentence transformer model."""

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"EmbeddingClient initialized with model '{model_name}' on device '{self.device}'.")

    def create_embeddings(self, texts: List[str]) -> torch.Tensor:
        """
        Generates embeddings for a list of text chunks.

        Args:
            texts: A list of strings to be embedded.

        Returns:
            A torch.Tensor containing the generated embeddings.
        """
        if not texts:
            return torch.tensor([])

        print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
        try:
            embeddings = self.model.encode(
                texts, convert_to_tensor=True, show_progress_bar=False
            )
            print("Embeddings generated successfully.")
            return embeddings
        except Exception as e:
            print(f"An error occurred during embedding generation: {e}")
            raise
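A short sketch, assuming the sentence-transformers model can be downloaded locally, showing how EmbeddingClient is called and what it returns.

# Usage sketch (illustrative only): embed a few chunks and inspect the result.
from embedding import EmbeddingClient

client = EmbeddingClient()  # downloads the model on first use
vectors = client.create_embeddings(
    ["grace period for premium payment", "waiting period for cataract surgery"]
)
print(vectors.shape)  # (2, embedding_dim) as a torch.Tensor on CPU or GPU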
generation.py
ADDED
@@ -0,0 +1,57 @@
# file: generation.py
from groq import AsyncGroq
from typing import List, Dict

# --- Configuration ---
GROQ_MODEL_NAME = "llama3-8b-8192"

async def generate_answer(query: str, context_chunks: List[Dict], groq_api_key: str) -> str:
    """
    Generates a final answer using the Groq API based on the query and retrieved context.

    Args:
        query: The user's original question.
        context_chunks: A list of the most relevant, reranked document chunks.
        groq_api_key: The API key for the Groq service.

    Returns:
        A string containing the generated answer.
    """
    if not groq_api_key:
        return "Error: Groq API key is not set."
    if not context_chunks:
        return "I do not have enough information to answer this question based on the provided document."

    print("Generating final answer with Groq...")
    client = AsyncGroq(api_key=groq_api_key)

    # Format the context for the prompt
    context_str = "\n\n---\n\n".join(
        [f"Context Chunk:\n{chunk['content']}" for chunk in context_chunks]
    )

    # Note: trailing spaces keep the concatenated instruction sentences separated.
    prompt = (
        "You are an expert Q&A system. Your task is to extract information with 100% accuracy from the provided text. Provide a brief and direct answer. "
        "Do not mention the context in your response. Answer *only* using the information from the provided document. "
        "Do not infer, add, or assume any information that is not explicitly written in the source text. If the answer is not in the document, state that the information is not available. "
        "When the question involves numbers, percentages, or monetary values, extract the exact figures from the text. "
        "Double-check that the value corresponds to the correct plan or condition mentioned in the question."
        "\n\n"
        f"CONTEXT:\n{context_str}\n\n"
        f"QUESTION:\n{query}\n\n"
        "ANSWER:"
    )

    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=GROQ_MODEL_NAME,
            temperature=0.2,  # Lower temperature for more factual answers
            max_tokens=500,
        )
        answer = chat_completion.choices[0].message.content
        print("Answer generated successfully.")
        return answer
    except Exception as e:
        print(f"An error occurred during Groq API call: {e}")
        return "Could not generate an answer due to an API error."
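A hedged sketch of calling generate_answer outside FastAPI; the context chunk below is fabricated for illustration and GROQ_API_KEY is read from the environment.

# Usage sketch (illustrative only): call generate_answer directly.
import os
import asyncio
from generation import generate_answer

async def demo():
    chunks = [{"content": "The grace period for premium payment is thirty days."}]
    answer = await generate_answer(
        "What is the grace period?", chunks, os.environ.get("GROQ_API_KEY", "")
    )
    print(answer)

asyncio.run(demo())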
ingestion_router.py
ADDED
@@ -0,0 +1,129 @@
# file: ingestion_router.py
import os
import time
import httpx
import zipfile
import io
import asyncio
from PIL import Image
from pathlib import Path
from urllib.parse import urlparse, unquote
from pydantic import HttpUrl
from concurrent.futures import ThreadPoolExecutor, as_completed

# --- Import our custom parsers ---
from pdf_parallel_parser import process_pdf_with_hybrid_parallel_sync
from complex_parser import process_image_element

# --- A simple parser for generic files (DOCX, etc.) using unstructured ---
from unstructured.partition.auto import partition

# --- Configuration ---
LOCAL_STORAGE_DIR = "data/"

# --- Synchronous, CPU-Bound Parsing Functions ---

def _process_generic_file_sync(file_content: bytes, filename: str) -> str:
    """Fallback parser for standard files like DOCX, PPTX, etc., using unstructured."""
    print(f"Processing '{filename}' with unstructured (standard)...")
    try:
        elements = partition(file=io.BytesIO(file_content), file_filename=filename)
        return "\n\n".join([el.text for el in elements])
    except Exception as e:
        print(f"Unstructured failed for {filename}: {e}")
        return ""

def _process_zip_file_in_parallel(zip_content: bytes, temp_dir: Path) -> str:
    """Extracts and processes files from a ZIP archive in parallel."""
    print("Initiating parallel processing of ZIP archive...")
    all_extracted_texts = []

    def process_single_zipped_file(zf: zipfile.ZipFile, file_info: zipfile.ZipInfo) -> str:
        file_content = zf.read(file_info.filename)
        file_extension = Path(file_info.filename).suffix.lower()

        # Route to the appropriate synchronous parser
        if file_extension == '.pdf':
            temp_file_path = temp_dir / Path(file_info.filename).name
            temp_file_path.parent.mkdir(parents=True, exist_ok=True)
            temp_file_path.write_bytes(file_content)
            return process_pdf_with_hybrid_parallel_sync(temp_file_path)
        elif file_extension in ['.png', '.jpg', '.jpeg']:
            return process_image_element(Image.open(io.BytesIO(file_content)))
        elif file_extension in ['.docx', '.pptx', '.html']:
            return _process_generic_file_sync(file_content, file_info.filename)
        else:
            print(f"Skipping unsupported file in ZIP: {file_info.filename}")
            return ""

    with zipfile.ZipFile(io.BytesIO(zip_content)) as zf:
        file_list = [info for info in zf.infolist() if not info.is_dir()]
        with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as executor:
            future_to_file = {executor.submit(process_single_zipped_file, zf, file_info): file_info for file_info in file_list}
            for future in as_completed(future_to_file):
                try:
                    text = future.result()
                    if text:
                        all_extracted_texts.append(f"--- Content from: {future_to_file[future].filename} ---\n{text}")
                except Exception as e:
                    print(f"Error processing file '{future_to_file[future].filename}' from ZIP: {e}")

    return "\n\n".join(all_extracted_texts)

# --- Main Asynchronous Ingestion and Routing Function ---

async def ingest_and_parse_document(doc_url: HttpUrl) -> str:
    """Asynchronously downloads and parses a document using the optimal strategy."""
    print(f"Initiating processing for URL: {doc_url}")
    os.makedirs(LOCAL_STORAGE_DIR, exist_ok=True)
    start_time = time.perf_counter()

    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(str(doc_url), timeout=120.0, follow_redirects=True)
            response.raise_for_status()
            doc_bytes = response.content
        print("Download successful.")

        filename = unquote(os.path.basename(urlparse(str(doc_url)).path)) or "downloaded_file"
        local_file_path = Path(LOCAL_STORAGE_DIR) / filename
        file_extension = local_file_path.suffix.lower()

        # Run the appropriate CPU-bound parsing function in a separate thread
        if file_extension == '.pdf':
            local_file_path.write_bytes(doc_bytes)
            doc_text = await asyncio.to_thread(process_pdf_with_hybrid_parallel_sync, local_file_path)
        elif file_extension == '.zip':
            doc_text = await asyncio.to_thread(_process_zip_file_in_parallel, doc_bytes, Path(LOCAL_STORAGE_DIR))
        elif file_extension in ['.png', '.jpg', '.jpeg']:
            image = Image.open(io.BytesIO(doc_bytes))
            doc_text = await asyncio.to_thread(process_image_element, image)
        elif file_extension in ['.docx', '.pptx', '.html']:
            doc_text = await asyncio.to_thread(_process_generic_file_sync, doc_bytes, filename)
        else:
            raise ValueError(f"Unsupported file type: {file_extension}")

        elapsed_time = time.perf_counter() - start_time
        print(f"Total processing time: {elapsed_time:.4f} seconds.")
        if not doc_text.strip():
            raise ValueError("Document parsing yielded no content.")

        return doc_text

    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        raise

# # Example of how to run the pipeline
# async def main():
#     # Example URL pointing to a PDF with tables
#     pdf_url = HttpUrl("https://www.w3.org/WAI/WCAG21/working-examples/pdf-table-linearized/table.pdf")
#     try:
#         content = await ingest_and_parse_document(pdf_url)
#         print("\n--- FINAL EXTRACTED CONTENT ---")
#         print(content)
#     except Exception as e:
#         print(f"Pipeline failed: {e}")

# if __name__ == "__main__":
#     asyncio.run(main())
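Both ingestion_router.py and pdf_parallel_parser.py import from complex_parser, which is not among the seven uploaded files. The stand-in below only sketches the interface implied by those call sites (a PIL image in, extracted text out) so the other modules can be imported and exercised; the real module presumably performs the table-to-Markdown and image OCR work.

# file: complex_parser.py (stand-in sketch; the real module is not included in this upload)
# Assumed interface: both functions take a PIL.Image and return extracted text.
from PIL import Image

def process_table_element(table_image: Image.Image) -> str:
    """Placeholder: the real implementation converts a cropped table image to Markdown."""
    return ""

def process_image_element(image: Image.Image) -> str:
    """Placeholder: the real implementation runs OCR / captioning on the image."""
    return ""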
main.py
ADDED
@@ -0,0 +1,158 @@
# file: main.py
import time
import os
import asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl
from typing import List, Dict, Any
from dotenv import load_dotenv

# Assuming 'ingestion_router.py' is in the same directory and contains the function
from ingestion_router import ingest_and_parse_document
from chunking_parent import create_parent_child_chunks
from embedding import EmbeddingClient
from retrieval_parent import Retriever
from generation import generate_answer

load_dotenv()

app = FastAPI(
    title="Modular RAG API",
    description="A modular API for Retrieval-Augmented Generation with Parent-Child Retrieval.",
    version="2.3.0",  # Updated version
)

GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
embedding_client = EmbeddingClient()
retriever = Retriever(embedding_client=embedding_client)

# --- Pydantic Models ---
class RunRequest(BaseModel):
    documents: HttpUrl
    questions: List[str]

class RunResponse(BaseModel):
    answers: List[str]

class TestRequest(BaseModel):
    documents: HttpUrl

# --- NEW: Test Endpoint for Ingestion and Parsing ---
@app.post("/test/ingestion", response_model=Dict[str, Any], tags=["Testing"])
async def test_ingestion_endpoint(request: TestRequest):
    """
    Tests the complete ingestion and parsing pipeline.
    Downloads a document from a URL, processes it using the modular
    parsing strategy (e.g., parallel for PDF, standard for DOCX),
    and returns the extracted Markdown content and time taken.
    """
    print("--- Running Document Ingestion & Parsing Test ---")
    start_time = time.perf_counter()
    try:
        # Step 1: Call the main ingestion function from your router
        markdown_content = await ingest_and_parse_document(request.documents)

        end_time = time.perf_counter()
        duration = end_time - start_time
        print(f"--- Ingestion and Parsing took {duration:.2f} seconds ---")

        if not markdown_content:
            raise HTTPException(
                status_code=404,
                detail="Document processed, but no content was extracted."
            )

        return {
            "total_time_seconds": duration,
            "character_count": len(markdown_content),
            "extracted_content": markdown_content,
        }
    except Exception as e:
        # Catch potential download errors, parsing errors, or unsupported file types
        raise HTTPException(status_code=500, detail=f"An error occurred during ingestion test: {str(e)}")


# --- Test Endpoint for Parent-Child Chunking ---
@app.post("/test/chunk", response_model=Dict[str, Any], tags=["Testing"])
async def test_chunking_endpoint(request: TestRequest):
    """
    Tests the parent-child chunking strategy.
    Returns parent chunks, child chunks, and the time taken.
    """
    print("--- Running Parent-Child Chunking Test ---")
    start_time = time.perf_counter()

    try:
        # Step 1: Parse the document to get raw text
        markdown_content = await ingest_and_parse_document(request.documents)

        # Step 2: Create parent and child chunks
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)

        end_time = time.perf_counter()
        duration = end_time - start_time
        print(f"--- Parsing and Chunking took {duration:.2f} seconds ---")

        # Convert Document objects to a JSON-serializable list for the response
        child_chunk_results = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in child_documents
        ]

        # Retrieve parent documents from the in-memory store
        parent_docs = docstore.mget(list(docstore.store.keys()))
        parent_chunk_results = [
            {"page_content": doc.page_content, "metadata": doc.metadata}
            for doc in parent_docs if doc
        ]

        return {
            "total_time_seconds": duration,
            "parent_chunk_count": len(parent_chunk_results),
            "child_chunk_count": len(child_chunk_results),
            "parent_chunks": parent_chunk_results,
            "child_chunks": child_chunk_results,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred during chunking test: {str(e)}")


@app.post("/hackrx/run", response_model=RunResponse)
async def run_rag_pipeline(request: RunRequest):
    try:
        print("--- Kicking off RAG Pipeline with Parent-Child Strategy ---")

        # --- STAGE 1: DOCUMENT INGESTION ---
        markdown_content = await ingest_and_parse_document(request.documents)

        # --- STAGE 2: PARENT-CHILD CHUNKING ---
        child_documents, docstore, _ = create_parent_child_chunks(markdown_content)

        if not child_documents:
            raise HTTPException(status_code=400, detail="Document could not be processed into chunks.")

        # --- STAGE 3: INDEXING ---
        retriever.index(child_documents, docstore)

        # --- STAGE 4: CONCURRENT RETRIEVAL & GENERATION ---
        print("Starting retrieval for all questions...")
        retrieval_tasks = [
            retriever.retrieve(q, GROQ_API_KEY)
            for q in request.questions
        ]
        all_retrieved_chunks = await asyncio.gather(*retrieval_tasks)
        print("Retrieval complete. Starting final answer generation...")

        answer_tasks = [
            generate_answer(q, chunks, GROQ_API_KEY)
            for q, chunks in zip(request.questions, all_retrieved_chunks)
        ]
        final_answers = await asyncio.gather(*answer_tasks)

        print("--- RAG Pipeline Completed Successfully ---")
        return RunResponse(answers=final_answers)

    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An internal server error occurred: {str(e)}"
        )
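A request sketch against the /hackrx/run endpoint, assuming the app is served locally (e.g. `uvicorn main:app --port 8000`); the document URL reuses the example already present in ingestion_router.py.

# Request sketch (illustrative only): exercise the /hackrx/run endpoint with httpx.
import httpx

payload = {
    "documents": "https://www.w3.org/WAI/WCAG21/working-examples/pdf-table-linearized/table.pdf",
    "questions": ["What does the table describe?"],
}
response = httpx.post("http://localhost:8000/hackrx/run", json=payload, timeout=300.0)
print(response.json())  # expected shape: {"answers": ["..."]}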
pdf_parallel_parser.py
ADDED
@@ -0,0 +1,78 @@
# file: pdf_parallel_parser.py

import fitz  # PyMuPDF
from PIL import Image
import io
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

# Import the specialized parsers from our other module
from complex_parser import process_table_element, process_image_element

def _is_bbox_contained(inner_bbox, outer_bbox):
    """Check if inner_bbox is fully inside outer_bbox."""
    return (inner_bbox[0] >= outer_bbox[0] and
            inner_bbox[1] >= outer_bbox[1] and
            inner_bbox[2] <= outer_bbox[2] and
            inner_bbox[3] <= outer_bbox[3])

def _process_page(page: fitz.Page) -> str:
    """
    Processes a single PDF page to extract text, tables, and images.
    - Tables are found and processed with the complex_parser.
    - Plain text is extracted, excluding any text already inside a processed table.
    """
    page_content = []

    # 1. Find and process tables first
    table_bboxes = []
    try:
        tables = page.find_tables()
        pix = page.get_pixmap(dpi=200)
        page_image = Image.open(io.BytesIO(pix.tobytes("png")))
        # Table bboxes are in PDF points (72 dpi); scale them to the 200-dpi pixmap before cropping.
        scale = 200 / 72

        print(f"Page {page.number}: Found {len(tables.tables)} potential tables.")
        for i, table in enumerate(tables):
            table_bboxes.append(table.bbox)
            table_image = page_image.crop(tuple(coord * scale for coord in table.bbox))
            markdown_table = process_table_element(table_image)
            page_content.append(markdown_table)
    except Exception as e:
        print(f"Could not process tables on page {page.number}: {e}")

    # 2. Extract text blocks, excluding those within table bounding boxes
    text_blocks = page.get_text("blocks")
    for block in text_blocks:
        block_bbox = block[:4]
        # Check if this text block is inside any of the tables we just processed
        is_in_table = any(_is_bbox_contained(block_bbox, table_bbox) for table_bbox in table_bboxes)
        if not is_in_table:
            page_content.append(block[4].strip())

    # Note: Image extraction can be added here if needed, similar to table extraction.

    return "\n".join(page_content)

def process_pdf_with_hybrid_parallel_sync(file_path: Path) -> str:
    """
    Processes a PDF file in parallel using PyMuPDF and the complex_parser.
    """
    print(f"Processing PDF '{file_path.name}' with parallel page-by-page strategy...")
    all_page_texts = []
    doc = fitz.open(file_path)

    with ThreadPoolExecutor(max_workers=os.cpu_count() or 4) as executor:
        futures = {executor.submit(_process_page, page): page.number for page in doc}

        # Collect results in page order
        results = ["" for _ in range(len(doc))]
        for future in as_completed(futures):
            page_num = futures[future]
            try:
                results[page_num] = future.result()
            except Exception as e:
                print(f"Error processing page {page_num}: {e}")
        all_page_texts = results

    return "\n\n--- Page Break ---\n\n".join(all_page_texts)
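A standalone sketch for running the PDF parser outside the API; data/sample.pdf is a placeholder path and the complex_parser module must be importable.

# Usage sketch (illustrative only): parse a local PDF directly.
from pathlib import Path
from pdf_parallel_parser import process_pdf_with_hybrid_parallel_sync

text = process_pdf_with_hybrid_parallel_sync(Path("data/sample.pdf"))
print(text[:500])  # pages are joined with "--- Page Break ---" separators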
retrieval_parent.py
ADDED
@@ -0,0 +1,193 @@
# file: retrieval_parent.py

import time
import asyncio
import numpy as np
import torch
import json
from groq import AsyncGroq
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.preprocessing import MinMaxScaler
from torch.nn.functional import cosine_similarity
from typing import List, Dict, Tuple
from langchain.storage import InMemoryStore

from embedding import EmbeddingClient
from langchain_core.documents import Document

# --- Configuration ---
GENERATION_MODEL = "llama3-8b-8192"
RERANKER_MODEL = 'cross-encoder/stsb-distilroberta-base'
INITIAL_K_CANDIDATES = 20
TOP_K_CHUNKS = 10

async def generate_hypothetical_document(query: str, groq_api_key: str) -> str:
    """Generates a hypothetical document to answer the query (HyDE)."""
    if not groq_api_key:
        print("Groq API key not set. Skipping HyDE generation.")
        return ""

    print(f"Starting HyDE generation for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        f"Write a brief, formal passage that directly answers the following question. "
        f"This passage will be used to find similar documents. "
        f"Do not include the question or any conversational text.\n\n"
        f"Question: {query}\n\n"
        f"Hypothetical Passage:"
    )

    start_time = time.perf_counter()
    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=GENERATION_MODEL,
            temperature=0.7,
            max_tokens=500,
        )
        end_time = time.perf_counter()
        print(f"--- HyDE generation took {end_time - start_time:.4f} seconds ---")
        return chat_completion.choices[0].message.content
    except Exception as e:
        print(f"An error occurred during HyDE generation: {e}")
        return ""

async def generate_expanded_terms(query: str, groq_api_key: str) -> List[str]:
    """Generates semantically related search terms for a query."""
    if not groq_api_key:
        print("Groq API key not set. Skipping Semantic Expansion.")
        return [query]

    print(f"Starting Semantic Expansion for query: '{query}'...")
    client = AsyncGroq(api_key=groq_api_key)
    prompt = (
        f"You are a search query expansion expert. Based on the following query, generate up to 4 additional, "
        f"semantically related search terms. The terms should be relevant for finding information in technical documents. "
        f"Return the original query plus the new terms as a single JSON list of strings.\n\n"
        f"Query: \"{query}\"\n\n"
        f"JSON List:"
    )

    start_time = time.perf_counter()
    try:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=GENERATION_MODEL,
            temperature=0.4,
            max_tokens=200,
            response_format={"type": "json_object"},
        )
        end_time = time.perf_counter()
        print(f"--- Semantic Expansion took {end_time - start_time:.4f} seconds ---")

        result_text = chat_completion.choices[0].message.content
        terms = json.loads(result_text)

        if isinstance(terms, dict) and 'terms' in terms:
            return terms['terms']
        return terms
    except Exception as e:
        print(f"An error occurred during Semantic Expansion: {e}")
        return [query]


class Retriever:
    """Manages hybrid search with parent-child retrieval."""

    def __init__(self, embedding_client: EmbeddingClient):
        self.embedding_client = embedding_client
        self.reranker = CrossEncoder(RERANKER_MODEL, device=self.embedding_client.device)
        self.bm25 = None
        self.document_chunks = []
        self.chunk_embeddings = None
        self.docstore = InMemoryStore()
        print(f"Retriever initialized with reranker '{RERANKER_MODEL}'.")

    def index(self, child_documents: List[Document], docstore: InMemoryStore):
        """Builds the search index from child documents and stores parent documents."""
        self.document_chunks = child_documents
        self.docstore = docstore

        corpus = [doc.page_content for doc in child_documents]
        if not corpus:
            print("No documents to index.")
            return

        print("Indexing child documents for retrieval...")
        tokenized_corpus = [doc.split(" ") for doc in corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)
        self.chunk_embeddings = self.embedding_client.create_embeddings(corpus)
        print("Indexing complete.")

    def _hybrid_search(self, query: str, hyde_doc: str, expanded_terms: List[str]) -> List[Tuple[int, float]]:
        """Performs a hybrid search using expanded terms for BM25 and a HyDE doc for dense search."""
        if self.bm25 is None or self.chunk_embeddings is None:
            raise ValueError("Retriever has not been indexed. Call index() first.")

        print(f"Running BM25 with expanded terms: {expanded_terms}")
        # Tokenize the expanded terms so multi-word phrases match the word-level BM25 corpus.
        tokenized_query = " ".join(expanded_terms).split(" ")
        bm25_scores = self.bm25.get_scores(tokenized_query)

        enhanced_query = f"{query}\n\n{hyde_doc}" if hyde_doc else query
        query_embedding = self.embedding_client.create_embeddings([enhanced_query])
        dense_scores = cosine_similarity(query_embedding, self.chunk_embeddings).cpu().numpy().flatten()

        scaler = MinMaxScaler()
        norm_bm25 = scaler.fit_transform(bm25_scores.reshape(-1, 1)).flatten()
        norm_dense = scaler.fit_transform(dense_scores.reshape(-1, 1)).flatten()
        combined_scores = 0.5 * norm_bm25 + 0.5 * norm_dense

        top_indices = np.argsort(combined_scores)[::-1][:INITIAL_K_CANDIDATES]
        return [(idx, combined_scores[idx]) for idx in top_indices]

    async def _rerank(self, query: str, candidates: List[dict]) -> List[dict]:
        """Reranks candidates using a CrossEncoder model."""
        if not candidates:
            return []

        print(f"Reranking {len(candidates)} candidates...")
        rerank_input = [[query, chunk["content"]] for chunk in candidates]

        rerank_scores = await asyncio.to_thread(
            self.reranker.predict, rerank_input, show_progress_bar=False
        )

        for candidate, score in zip(candidates, rerank_scores):
            candidate['rerank_score'] = score

        candidates.sort(key=lambda x: x['rerank_score'], reverse=True)
        return candidates[:TOP_K_CHUNKS]

    async def retrieve(self, query: str, groq_api_key: str) -> List[Dict]:
        """Executes the full retrieval pipeline: expansion, HyDE, hybrid search, and reranking."""
        print(f"\n--- Retrieving documents for query: '{query}' ---")

        hyde_task = generate_hypothetical_document(query, groq_api_key)
        expansion_task = generate_expanded_terms(query, groq_api_key)
        hyde_doc, expanded_terms = await asyncio.gather(hyde_task, expansion_task)

        initial_candidates_info = self._hybrid_search(query, hyde_doc, expanded_terms)

        retrieved_child_docs = [{
            "content": self.document_chunks[idx].page_content,
            "metadata": self.document_chunks[idx].metadata,
        } for idx, score in initial_candidates_info]

        reranked_child_docs = await self._rerank(query, retrieved_child_docs)

        parent_ids = []
        for doc in reranked_child_docs:
            parent_id = doc["metadata"]["parent_id"]
            if parent_id not in parent_ids:
                parent_ids.append(parent_id)

        retrieved_parents = self.docstore.mget(parent_ids)
        final_parent_docs = [doc for doc in retrieved_parents if doc is not None]

        final_context = [{
            "content": doc.page_content,
            "metadata": doc.metadata
        } for doc in final_parent_docs]

        print(f"Retrieved {len(final_context)} final parent chunks for context.")
        return final_context
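An end-to-end retrieval sketch, assuming GROQ_API_KEY is set (HyDE and semantic expansion degrade gracefully without it); the document string is a placeholder standing in for ingest_and_parse_document output.

# Usage sketch (illustrative only): index child chunks and run one retrieval end to end.
import os
import asyncio
from chunking_parent import create_parent_child_chunks
from embedding import EmbeddingClient
from retrieval_parent import Retriever

async def demo():
    # Placeholder text; in the API this comes from ingest_and_parse_document.
    children, store, _ = create_parent_child_chunks(
        "The grace period for premium payment is thirty days. " * 100
    )
    retriever = Retriever(embedding_client=EmbeddingClient())
    retriever.index(children, store)
    context = await retriever.retrieve("What is the grace period?", os.environ.get("GROQ_API_KEY", ""))
    print([c["content"][:80] for c in context])  # parent chunks, not child chunks

asyncio.run(demo())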