#!/usr/bin/env python3
"""
PDF Document Ingestion Script
This script processes complex PDF documents (like medical textbooks), extracts text and images,
chunks them intelligently, generates vector embeddings using state-of-the-art local models,
and stores them in a local ChromaDB vector database.
Author: Expert Python Developer
Python Version: 3.9+
"""
import os
import uuid
import hashlib
import base64
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import logging
# Third-party imports
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
from unstructured.partition.pdf import partition_pdf
from PIL import Image
import io
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION SECTION
# =============================================================================
# Input/Output Paths
SOURCE_DIRECTORY = "/home/tony/pdf_docs" # Directory containing PDF files to process
DB_PATH = "/home/tony/chromadb" # Path for persistent ChromaDB database
IMAGE_OUTPUT_DIRECTORY = "/home/tony/extracted_images" # Path for storing extracted images
# Model Configuration
TEXT_EMBEDDING_MODEL = "BAAI/bge-m3" # State-of-the-art text embedding model
IMAGE_EMBEDDING_MODEL = "clip-ViT-B-32" # CLIP model for image embeddings
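# Note on embedding dimensionality: BAAI/bge-m3 produces 1024-dimensional dense vectors, while
# clip-ViT-B-32 produces 512-dimensional vectors. A ChromaDB collection holds vectors of a
# single dimensionality (fixed by the first add), so text and image embeddings as configured
# here cannot be stored together in COLLECTION_NAME; see the sketch after process_batch()
# below for one way to route image chunks into a separate collection.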
# Database Configuration
COLLECTION_NAME = "medical_library" # ChromaDB collection name
# Processing Configuration
BATCH_SIZE = 100 # Number of chunks to process in each batch
MAX_CHUNK_SIZE = 1000 # Maximum characters per text chunk (not applied automatically; see the split_long_text sketch below)
# =============================================================================
# INITIALIZATION FUNCTIONS
# =============================================================================
def initialize_chromadb() -> Tuple[chromadb.Client, chromadb.Collection]:
"""
Initialize and return the ChromaDB client and collection.
Returns:
Tuple[chromadb.Client, chromadb.Collection]: The client and collection objects
"""
try:
# Ensure database directory exists
os.makedirs(DB_PATH, exist_ok=True)
# Initialize ChromaDB client with persistent storage
client = chromadb.PersistentClient(
path=DB_PATH,
settings=Settings(
anonymized_telemetry=False,
allow_reset=True
)
)
        # Get or create the collection. get_or_create_collection is used instead of a
        # get/except/create dance because the exception raised by get_collection for a
        # missing collection differs across chromadb versions.
        collection = client.get_or_create_collection(
            name=COLLECTION_NAME,
            metadata={"description": "Medical textbook PDF content with embeddings"}
        )
        logger.info(f"Using collection: {COLLECTION_NAME}")
return client, collection
except Exception as e:
logger.error(f"Failed to initialize ChromaDB: {e}")
raise
def initialize_models() -> Tuple[SentenceTransformer, SentenceTransformer]:
"""
Load and return the text and image embedding models.
Returns:
Tuple[SentenceTransformer, SentenceTransformer]: Text and image models
"""
try:
logger.info("Loading text embedding model...")
text_model = SentenceTransformer(TEXT_EMBEDDING_MODEL)
logger.info("Loading image embedding model...")
image_model = SentenceTransformer(IMAGE_EMBEDDING_MODEL)
logger.info("Models loaded successfully!")
return text_model, image_model
except Exception as e:
logger.error(f"Failed to load models: {e}")
raise
def ensure_directories() -> None:
"""
Ensure all required directories exist.
"""
try:
os.makedirs(SOURCE_DIRECTORY, exist_ok=True)
os.makedirs(IMAGE_OUTPUT_DIRECTORY, exist_ok=True)
os.makedirs(DB_PATH, exist_ok=True)
logger.info("All directories verified/created successfully")
except Exception as e:
logger.error(f"Failed to create directories: {e}")
raise
# =============================================================================
# DEDUPLICATION FUNCTIONS
# =============================================================================
def calculate_file_hash(file_path: str) -> str:
"""
Calculate SHA-256 hash of a file for deduplication.
Args:
file_path (str): Path to the file
Returns:
str: SHA-256 hash of the file
"""
hash_sha256 = hashlib.sha256()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
def is_pdf_already_processed(pdf_path: str, collection: chromadb.Collection) -> bool:
"""
Check if a PDF has already been processed by checking its hash in the database.
Args:
pdf_path (str): Path to the PDF file
collection (chromadb.Collection): ChromaDB collection
Returns:
bool: True if already processed, False otherwise
"""
try:
file_hash = calculate_file_hash(pdf_path)
# Query the collection for any document with this file hash
result = collection.get(where={"file_hash": file_hash}, limit=1)
if len(result['ids']) > 0:
pdf_filename = Path(pdf_path).name
logger.info(f"PDF {pdf_filename} already processed (hash: {file_hash[:12]}...). Skipping.")
return True
return False
except Exception as e:
logger.warning(f"Error checking if PDF is already processed: {e}")
return False
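# Note on re-runs: is_pdf_already_processed() treats a file as processed as soon as any chunk
# with its hash exists, so a run that fails partway through a PDF leaves that PDF partially
# ingested and skipped on subsequent runs. Removing that file's chunks from the collection
# (e.g. by filtering on file_hash) forces it to be reprocessed.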
# =============================================================================
# DOCUMENT PROCESSING FUNCTIONS
# =============================================================================
def process_pdf(
pdf_path: str,
text_model: SentenceTransformer,
image_model: SentenceTransformer,
collection: chromadb.Collection
) -> None:
"""
Process a single PDF file and store chunks in ChromaDB.
Args:
pdf_path (str): Path to the PDF file
text_model (SentenceTransformer): Text embedding model
image_model (SentenceTransformer): Image embedding model
collection (chromadb.Collection): ChromaDB collection
"""
try:
pdf_filename = Path(pdf_path).name
logger.info(f"Processing PDF: {pdf_filename}")
# Calculate file hash for deduplication
file_hash = calculate_file_hash(pdf_path)
        # Parse PDF with unstructured. extract_image_block_to_payload asks recent unstructured
        # releases to attach extracted image data to element.metadata as base64, which
        # create_chunks_from_elements() reads below; older releases that only honour
        # extract_images_in_pdf write image files to disk instead.
        elements = partition_pdf(
            filename=pdf_path,
            strategy="hi_res",
            extract_images_in_pdf=True,
            extract_image_block_to_payload=True,
            infer_table_structure=True
        )
if not elements:
logger.warning(f"No elements extracted from {pdf_filename}")
return
# Generate chunks from elements
chunks = create_chunks_from_elements(elements, pdf_filename, file_hash)
if not chunks:
logger.warning(f"No chunks created from {pdf_filename}")
return
# Process chunks in batches
process_chunks_in_batches(chunks, text_model, image_model, collection)
logger.info(f"Successfully processed {pdf_filename}: {len(chunks)} chunks")
except Exception as e:
logger.error(f"Error processing PDF {pdf_path}: {e}")
raise
def create_chunks_from_elements(elements: List, pdf_filename: str, file_hash: str) -> List[Dict[str, Any]]:
"""
Create chunks from unstructured elements (let unstructured handle the intelligent parsing).
Args:
elements (List): List of unstructured elements
pdf_filename (str): Name of the source PDF file
file_hash (str): SHA-256 hash of the PDF file for deduplication
Returns:
List[Dict[str, Any]]: List of chunk dictionaries
"""
chunks = []
for i, element in enumerate(elements):
try:
            element_type = element.category
            # page_number can be absent or None on some elements; fall back to 1 so the value
            # stored in ChromaDB metadata is always an int.
            page_number = getattr(element.metadata, 'page_number', None) or 1
            # Handle image elements. unstructured does not reliably expose a raw image_bytes
            # attribute; depending on the version, image data arrives as base64 on
            # element.metadata.image_base64 or as a file referenced by element.metadata.image_path.
            if element_type == "Image":
                image_bytes = getattr(element, 'image_bytes', None)
                if image_bytes is None:
                    image_b64 = getattr(element.metadata, 'image_base64', None)
                    if image_b64:
                        image_bytes = base64.b64decode(image_b64)
                if image_bytes:
                    image_path = save_image(image_bytes, pdf_filename, i)
                else:
                    # Fall back to an image file unstructured may have written to disk.
                    image_path = getattr(element.metadata, 'image_path', None)
                if image_path:
                    chunks.append({
                        'id': f"{pdf_filename}_img_{i}",
                        'content': image_path,
                        'type': 'image',
                        'metadata': {
                            'source_file': pdf_filename,
                            'page_number': page_number,
                            'element_type': element_type,
                            'image_path': image_path,
                            'file_hash': file_hash
                        }
                    })
# Handle all text elements as individual chunks (unstructured already did the intelligent parsing)
else:
text_content = str(element).strip()
if text_content and len(text_content) > 20: # Skip very short fragments
chunks.append({
'id': f"{pdf_filename}_text_{i}",
'content': text_content,
'type': 'text',
'metadata': {
'source_file': pdf_filename,
'page_number': page_number,
'element_type': element_type,
'file_hash': file_hash
}
})
except Exception as e:
logger.warning(f"Error processing element {i}: {e}")
continue
return chunks
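# The pipeline above stores each unstructured element as a single chunk and does not enforce
# MAX_CHUNK_SIZE. A minimal sketch of how an over-long element could be split before embedding
# (hypothetical helper; not called anywhere in this script):
def split_long_text(text: str, max_chars: int = MAX_CHUNK_SIZE) -> List[str]:
    """
    Split a long text into pieces of at most max_chars characters, breaking on whitespace
    where possible so words are not cut in half. Sketch only.
    """
    pieces = []
    remaining = text.strip()
    while len(remaining) > max_chars:
        # Prefer the last whitespace before the limit as the break point.
        break_at = remaining.rfind(' ', 0, max_chars)
        if break_at <= 0:
            break_at = max_chars
        pieces.append(remaining[:break_at].strip())
        remaining = remaining[break_at:].strip()
    if remaining:
        pieces.append(remaining)
    return pieces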
def save_image(image_bytes: bytes, pdf_filename: str, chunk_index: int) -> Optional[str]:
"""
Save image bytes to file and return the path.
Args:
image_bytes (bytes): Raw image data
pdf_filename (str): Source PDF filename
chunk_index (int): Index of the chunk
Returns:
Optional[str]: Path to saved image or None if failed
"""
try:
# Create unique filename
image_filename = f"{Path(pdf_filename).stem}_{chunk_index}_{uuid.uuid4().hex[:8]}.png"
image_path = os.path.join(IMAGE_OUTPUT_DIRECTORY, image_filename)
# Convert and save image
image = Image.open(io.BytesIO(image_bytes))
image.save(image_path, format='PNG')
return image_path
except Exception as e:
logger.warning(f"Failed to save image: {e}")
return None
def process_chunks_in_batches(
chunks: List[Dict[str, Any]],
text_model: SentenceTransformer,
image_model: SentenceTransformer,
collection: chromadb.Collection
) -> None:
"""
Process chunks in batches and store in ChromaDB.
Args:
chunks (List[Dict[str, Any]]): List of chunks to process
text_model (SentenceTransformer): Text embedding model
image_model (SentenceTransformer): Image embedding model
collection (chromadb.Collection): ChromaDB collection
"""
for i in range(0, len(chunks), BATCH_SIZE):
batch = chunks[i:i + BATCH_SIZE]
try:
process_batch(batch, text_model, image_model, collection)
except Exception as e:
logger.error(f"Error processing batch {i//BATCH_SIZE + 1}: {e}")
# Continue with next batch instead of failing completely
continue
def process_batch(
batch: List[Dict[str, Any]],
text_model: SentenceTransformer,
image_model: SentenceTransformer,
collection: chromadb.Collection
) -> None:
"""
Process a single batch of chunks.
Args:
batch (List[Dict[str, Any]]): Batch of chunks to process
text_model (SentenceTransformer): Text embedding model
image_model (SentenceTransformer): Image embedding model
collection (chromadb.Collection): ChromaDB collection
"""
ids = []
embeddings = []
metadatas = []
documents = []
for chunk in batch:
try:
chunk_id = chunk['id']
content = chunk['content']
chunk_type = chunk['type']
metadata = chunk['metadata']
# Generate embedding based on type
if chunk_type == 'text':
embedding = text_model.encode(content).tolist()
document = content
elif chunk_type == 'image':
# For images, encode the image file
if os.path.exists(content):
embedding = image_model.encode(Image.open(content)).tolist()
document = f"Image from {metadata['source_file']} page {metadata['page_number']}"
else:
logger.warning(f"Image file not found: {content}")
continue
else:
logger.warning(f"Unknown chunk type: {chunk_type}")
continue
ids.append(chunk_id)
embeddings.append(embedding)
metadatas.append(metadata)
documents.append(document)
except Exception as e:
logger.warning(f"Error processing chunk {chunk.get('id', 'unknown')}: {e}")
continue
# Add batch to collection
if ids:
try:
collection.add(
ids=ids,
embeddings=embeddings,
metadatas=metadatas,
documents=documents
)
logger.debug(f"Added batch of {len(ids)} chunks to database")
except Exception as e:
logger.error(f"Error adding batch to database: {e}")
raise
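# Sketch: keeping image embeddings out of the text collection. Because the CLIP image
# embeddings (512-d) and bge-m3 text embeddings (1024-d) have different dimensionalities,
# adding both to COLLECTION_NAME will fail. Shown here only as a hedged sketch (the
# "medical_library_images" collection name is an assumption, and this helper is not wired
# into the pipeline above), one option is to route image chunks into their own collection:
def add_image_chunks_to_image_collection(
    image_chunks: List[Dict[str, Any]],
    image_model: SentenceTransformer,
    client: chromadb.Client
) -> None:
    """
    Store image chunks in a dedicated collection so their CLIP embeddings do not collide
    with the text collection's dimensionality. Sketch only; not called by main().
    """
    image_collection = client.get_or_create_collection(
        name="medical_library_images",  # assumed name, adjust as needed
        metadata={"description": "CLIP image embeddings extracted from PDFs"}
    )
    ids, embeddings, metadatas, documents = [], [], [], []
    for chunk in image_chunks:
        image_path = chunk['content']
        if not os.path.exists(image_path):
            continue
        ids.append(chunk['id'])
        embeddings.append(image_model.encode(Image.open(image_path)).tolist())
        metadatas.append(chunk['metadata'])
        documents.append(
            f"Image from {chunk['metadata']['source_file']} page {chunk['metadata']['page_number']}"
        )
    if ids:
        image_collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas, documents=documents)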
# =============================================================================
# MAIN EXECUTION
# =============================================================================
def main():
"""
Main execution function.
"""
try:
logger.info("Starting PDF ingestion process...")
# Ensure directories exist
ensure_directories()
# Initialize models and database
logger.info("Initializing models and database...")
text_model, image_model = initialize_models()
client, collection = initialize_chromadb()
# Get list of PDF files
pdf_files = []
if os.path.exists(SOURCE_DIRECTORY):
pdf_files = [f for f in os.listdir(SOURCE_DIRECTORY) if f.lower().endswith('.pdf')]
if not pdf_files:
logger.warning(f"No PDF files found in {SOURCE_DIRECTORY}")
logger.info("Please add PDF files to the source directory and run again.")
return
logger.info(f"Found {len(pdf_files)} PDF files to process")
# Process each PDF file with progress bar
with tqdm(pdf_files, desc="Processing PDFs") as pbar:
for pdf_file in pbar:
pdf_path = os.path.join(SOURCE_DIRECTORY, pdf_file)
pbar.set_description(f"Processing {pdf_file}")
# Check if this PDF has already been processed
if is_pdf_already_processed(pdf_path, collection):
continue
try:
process_pdf(pdf_path, text_model, image_model, collection)
except Exception as e:
logger.error(f"Failed to process {pdf_file}: {e}")
continue
# Get final statistics
try:
count = collection.count()
logger.info(f"Ingestion complete! Total chunks in database: {count}")
except Exception as e:
logger.warning(f"Could not get final count: {e}")
logger.info("PDF ingestion process completed successfully!")
except Exception as e:
logger.error(f"Fatal error in main execution: {e}")
raise
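# Sketch: querying the populated collection. A minimal example of how ingested text chunks
# could be retrieved later, assuming the same bge-m3 model is used at query time; this helper
# is illustrative only and is not called by this script.
def query_library(query_text: str, n_results: int = 5) -> List[Dict[str, Any]]:
    """
    Embed a query with the text model and return the closest text chunks. Sketch only.
    """
    client = chromadb.PersistentClient(path=DB_PATH)
    collection = client.get_collection(name=COLLECTION_NAME)
    text_model = SentenceTransformer(TEXT_EMBEDDING_MODEL)
    query_embedding = text_model.encode(query_text).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=n_results)
    hits = []
    for doc, meta, distance in zip(results['documents'][0], results['metadatas'][0], results['distances'][0]):
        hits.append({'document': doc, 'metadata': meta, 'distance': distance})
    return hits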
if __name__ == "__main__":
main()