"""
Semantic Chunker Module for RAG Systems
======================================
A drop-in replacement for RecursiveCharacterTextSplitter that uses semantic similarity
to create more coherent chunks. Designed to work seamlessly with existing LangChain
and Streamlit RAG systems.
Author: AI Assistant
Compatible with: LangChain, BGE embeddings, OpenAI embeddings, Streamlit
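
Typical swap (illustrative sketch; `embeddings_model` and `documents` stand in for
your existing objects):

    # before: splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splitter = create_semantic_chunker(embeddings_model)
    chunks = splitter.split_documents(documents)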
"""
import numpy as np
import re
from typing import List, Dict, Any, Optional, Union
from langchain.schema import Document
# Streamlit is optional: it is only used for warning banners and the statistics panel.
try:
    import streamlit as st
except ImportError:
    st = None
from sklearn.metrics.pairwise import cosine_similarity
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class SemanticChunker:
"""
Advanced semantic document chunker that creates coherent chunks based on
semantic similarity rather than fixed character counts.
Perfect for university documents, research papers, and policy documents
where maintaining semantic coherence is crucial.
"""
def __init__(self,
embeddings_model,
chunk_size: int = 4,
overlap: int = 1,
similarity_threshold: float = 0.75,
min_chunk_size: int = 150,
max_chunk_size: int = 1500,
sentence_split_pattern: Optional[str] = None,
debug: bool = False):
"""
Initialize the semantic chunker.
Args:
embeddings_model: Your existing embeddings model (BGE, OpenAI, etc.)
chunk_size: Base number of sentences per chunk (default: 4)
overlap: Number of sentences to overlap between chunks (default: 1)
similarity_threshold: Cosine similarity threshold for extending chunks (0.0-1.0)
min_chunk_size: Minimum characters per chunk (skip smaller chunks)
max_chunk_size: Maximum characters per chunk (prevent overly large chunks)
sentence_split_pattern: Custom regex pattern for sentence splitting
debug: Enable debug logging and statistics
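        Example (illustrative; `bge_embeddings` is a placeholder for your existing
        embeddings object):
            chunker = SemanticChunker(
                embeddings_model=bge_embeddings,
                similarity_threshold=0.8,
                max_chunk_size=1200,
            )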
"""
self.embeddings_model = embeddings_model
self.chunk_size = chunk_size
self.overlap = overlap
self.similarity_threshold = similarity_threshold
self.min_chunk_size = min_chunk_size
self.max_chunk_size = max_chunk_size
self.debug = debug
        # Default sentence-splitting pattern: one or more of . ! ? followed by whitespace
self.sentence_pattern = sentence_split_pattern or r'[.!?]+\s+'
# Statistics tracking
self.stats = {
"total_documents": 0,
"total_chunks": 0,
"avg_chunk_size": 0,
"chunking_methods": {},
"embedding_errors": 0
}
if self.debug:
logger.info(f"Initialized SemanticChunker with threshold={similarity_threshold}")
def _detect_embedding_model_type(self) -> str:
"""Detect the type of embedding model being used."""
if hasattr(self.embeddings_model, 'model'):
# Likely sentence-transformers model (BGE, etc.)
model_name = getattr(self.embeddings_model.model, 'model_name', 'sentence-transformers')
return f"sentence-transformers ({model_name})"
elif hasattr(self.embeddings_model, 'client'):
# Likely OpenAI
return "OpenAI"
else:
return "Unknown"
def _preprocess_text_for_splitting(self, text: str) -> str:
"""
Preprocess text to handle common formatting issues in university documents.
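        Example: "spring term.Registration opens" becomes "spring term. Registration opens".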
"""
# Fix common formatting issues
fixes = [
# Add space after periods before capital letters
(r'([a-z])\.([A-Z])', r'\1. \2'),
# Add space after numbers with periods
(r'([0-9]+)\.([A-Z])', r'\1. \2'),
# Fix missing spaces after question/exclamation marks
(r'([a-z])\?([A-Z])', r'\1? \2'),
(r'([a-z])\!([A-Z])', r'\1! \2'),
# Clean up multiple spaces
(r'\s+', ' '),
# Fix bullet points
            (r'•\s*([A-Z])', r'• \1'),
(r'-\s*([A-Z])', r'- \1'),
]
processed_text = text
for pattern, replacement in fixes:
processed_text = re.sub(pattern, replacement, processed_text)
return processed_text.strip()
def _split_into_sentences(self, text: str) -> List[str]:
"""
        Split preprocessed text into sentences, dropping fragments that are very short,
        purely numeric, or contain no word characters.
"""
# Preprocess text
text = self._preprocess_text_for_splitting(text)
# Split on sentence boundaries
raw_sentences = re.split(self.sentence_pattern, text)
# Clean and filter sentences
sentences = []
for sentence in raw_sentences:
sentence = sentence.strip()
# Filter out very short sentences, pure numbers, or empty strings
if len(sentence) >= 10 and not sentence.isdigit() and not re.match(r'^[^\w]*$', sentence):
sentences.append(sentence)
if self.debug:
logger.info(f"Split text into {len(sentences)} sentences")
return sentences
def _get_embeddings(self, texts: List[str]) -> Optional[np.ndarray]:
"""
Get embeddings from the provided model with error handling.
"""
try:
if hasattr(self.embeddings_model, 'model'):
# sentence-transformers model (BGE, etc.)
embeddings = self.embeddings_model.model.encode(texts)
return np.array(embeddings)
elif hasattr(self.embeddings_model, 'embed_documents'):
# OpenAI or similar API-based embeddings
embeddings = self.embeddings_model.embed_documents(texts)
return np.array(embeddings)
else:
# Try direct call
embeddings = self.embeddings_model(texts)
return np.array(embeddings)
except Exception as e:
self.stats["embedding_errors"] += 1
if self.debug:
logger.error(f"Error generating embeddings: {e}")
            # Show a warning in the Streamlit UI when it is available
            if st is not None:
                st.warning(f"⚠️ Embedding error, falling back to simple chunking: {str(e)[:100]}...")
return None
def _calculate_semantic_boundaries(self, embeddings: np.ndarray, sentences: List[str]) -> List[int]:
"""
Find natural semantic boundaries in the text based on embedding similarities.
"""
boundaries = [0] # Always start with first sentence
# Calculate similarities between consecutive sentences
similarities = []
for i in range(len(embeddings) - 1):
sim = cosine_similarity(
embeddings[i:i+1],
embeddings[i+1:i+2]
)[0][0]
similarities.append(sim)
# Find significant drops in similarity (topic boundaries)
if len(similarities) > 1:
mean_sim = np.mean(similarities)
std_sim = np.std(similarities)
threshold = mean_sim - (0.5 * std_sim) # Adaptive threshold
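            # Worked example (illustrative numbers): similarities [0.90, 0.85, 0.40, 0.88]
            # give mean ~0.76 and std ~0.21, so the adaptive threshold is ~0.65; only the
            # 0.40 dip falls below it, so a boundary is added before that sentence pair's
            # second sentence.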
for i, sim in enumerate(similarities):
if sim < threshold:
boundaries.append(i + 1)
boundaries.append(len(sentences)) # Always end with last sentence
return sorted(list(set(boundaries))) # Remove duplicates and sort
def _create_chunks_from_boundaries(self, sentences: List[str], boundaries: List[int],
embeddings: Optional[np.ndarray], metadata: Dict[str, Any]) -> List[Document]:
"""
Create document chunks based on semantic boundaries.
"""
chunks = []
for i in range(len(boundaries) - 1):
start_idx = boundaries[i]
end_idx = boundaries[i + 1]
# Create base chunk
chunk_sentences = sentences[start_idx:end_idx]
# Try to extend chunk if semantically similar
if embeddings is not None and end_idx < len(sentences):
current_embedding = np.mean(embeddings[start_idx:end_idx], axis=0, keepdims=True)
# Check if we can extend the chunk
extended_end = end_idx
while extended_end < len(sentences):
next_sentence_embedding = embeddings[extended_end:extended_end+1]
similarity = cosine_similarity(current_embedding, next_sentence_embedding)[0][0]
if similarity > self.similarity_threshold:
# Check size limit
test_chunk = ' '.join(sentences[start_idx:extended_end+1])
if len(test_chunk) <= self.max_chunk_size:
extended_end += 1
# Update current embedding
current_embedding = np.mean(embeddings[start_idx:extended_end], axis=0, keepdims=True)
else:
break
else:
break
# Use extended chunk if we found extensions
if extended_end > end_idx:
chunk_sentences = sentences[start_idx:extended_end]
# Create chunk text
chunk_text = ' '.join(chunk_sentences)
# Only add chunks that meet minimum size requirement
if len(chunk_text) >= self.min_chunk_size:
chunk_metadata = metadata.copy()
chunk_metadata.update({
"chunk_index": len(chunks),
"sentence_count": len(chunk_sentences),
"start_sentence": start_idx,
"end_sentence": start_idx + len(chunk_sentences) - 1,
"chunking_method": "semantic_boundary",
"similarity_threshold": self.similarity_threshold,
"chunk_size_chars": len(chunk_text)
})
chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))
return chunks
def _create_simple_chunks(self, sentences: List[str], metadata: Dict[str, Any]) -> List[Document]:
"""
Fallback to simple sentence-based chunking when embeddings are unavailable.
"""
chunks = []
for i in range(0, len(sentences), max(1, self.chunk_size - self.overlap)):
chunk_sentences = sentences[i:i + self.chunk_size]
chunk_text = ' '.join(chunk_sentences)
if len(chunk_text) >= self.min_chunk_size:
chunk_metadata = metadata.copy()
chunk_metadata.update({
"chunk_index": len(chunks),
"sentence_count": len(chunk_sentences),
"start_sentence": i,
"end_sentence": i + len(chunk_sentences) - 1,
"chunking_method": "simple_fallback",
"chunk_size_chars": len(chunk_text)
})
chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))
return chunks
def split_documents(self, documents: List[Document]) -> List[Document]:
"""
Main method: Split documents into semantically coherent chunks.
Args:
documents: List of LangChain Document objects
Returns:
List of Document objects with semantic chunks
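        Example (illustrative; `docs` is a placeholder list of Documents):
            chunks = chunker.split_documents(docs)
            for c in chunks[:3]:
                print(c.metadata["chunking_method"], c.metadata["chunk_size_chars"])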
"""
all_chunks = []
self.stats["total_documents"] = len(documents)
for doc_idx, doc in enumerate(documents):
try:
# Split document into sentences
sentences = self._split_into_sentences(doc.page_content)
if not sentences:
if self.debug:
logger.warning(f"No sentences found in document {doc_idx}")
continue
# Handle very short documents
if len(sentences) < self.chunk_size:
chunk_text = ' '.join(sentences)
if len(chunk_text) >= self.min_chunk_size:
chunk_metadata = doc.metadata.copy()
chunk_metadata.update({
"chunk_index": 0,
"total_chunks": 1,
"sentence_count": len(sentences),
"chunking_method": "single_chunk",
"chunk_size_chars": len(chunk_text)
})
all_chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))
continue
# Generate embeddings
embeddings = self._get_embeddings(sentences)
if embeddings is not None:
# Create semantic chunks
                    boundaries = self._calculate_semantic_boundaries(embeddings, sentences)
                    chunks = self._create_chunks_from_boundaries(sentences, boundaries, embeddings, doc.metadata)
method = "semantic"
else:
# Fallback to simple chunking
chunks = self._create_simple_chunks(sentences, doc.metadata)
method = "simple_fallback"
# Update statistics
self.stats["chunking_methods"][method] = self.stats["chunking_methods"].get(method, 0) + 1
# Update total chunks count in each chunk's metadata
for chunk in chunks:
chunk.metadata["total_chunks"] = len(chunks)
chunk.metadata["source_document_index"] = doc_idx
all_chunks.extend(chunks)
if self.debug:
logger.info(f"Document {doc_idx}: {len(sentences)} sentences β†’ {len(chunks)} chunks ({method})")
except Exception as e:
logger.error(f"Error processing document {doc_idx}: {e}")
                if self.debug and st is not None:
                    st.error(f"Error processing document {doc_idx}: {e}")
# Update final statistics
self.stats["total_chunks"] = len(all_chunks)
if all_chunks:
chunk_sizes = [len(chunk.page_content) for chunk in all_chunks]
self.stats["avg_chunk_size"] = sum(chunk_sizes) / len(chunk_sizes)
if self.debug:
logger.info(f"Created {len(all_chunks)} total chunks from {len(documents)} documents")
return all_chunks
def get_statistics(self) -> Dict[str, Any]:
"""Get chunking statistics for analysis."""
return self.stats.copy()
def display_statistics(self):
"""Display chunking statistics in Streamlit (if available)."""
try:
            with st.expander("📊 Semantic Chunking Statistics"):
col1, col2 = st.columns(2)
with col1:
st.metric("Total Documents", self.stats["total_documents"])
st.metric("Total Chunks", self.stats["total_chunks"])
with col2:
st.metric("Avg Chunk Size", f"{self.stats['avg_chunk_size']:.0f} chars")
st.metric("Embedding Errors", self.stats["embedding_errors"])
if self.stats["chunking_methods"]:
st.write("**Chunking Methods Used:**")
for method, count in self.stats["chunking_methods"].items():
percentage = (count / self.stats["total_documents"]) * 100 if self.stats["total_documents"] > 0 else 0
st.write(f" - {method}: {count} documents ({percentage:.1f}%)")
st.write("**Configuration:**")
st.json({
"chunk_size": self.chunk_size,
"overlap": self.overlap,
"similarity_threshold": self.similarity_threshold,
"min_chunk_size": self.min_chunk_size,
"max_chunk_size": self.max_chunk_size,
"embedding_model": self._detect_embedding_model_type()
})
        except (ImportError, AttributeError):
            # Streamlit not installed (or imported as None above); print to the console instead.
print("\n=== Semantic Chunking Statistics ===")
print(f"Documents processed: {self.stats['total_documents']}")
print(f"Chunks created: {self.stats['total_chunks']}")
print(f"Average chunk size: {self.stats['avg_chunk_size']:.0f} characters")
print(f"Embedding errors: {self.stats['embedding_errors']}")
print(f"Chunking methods: {self.stats['chunking_methods']}")
def create_semantic_chunker(embeddings_model, **kwargs) -> SemanticChunker:
"""
Convenience function to create a semantic chunker with sensible defaults.
Args:
embeddings_model: Your existing embeddings model
**kwargs: Additional parameters to pass to SemanticChunker
Returns:
SemanticChunker instance ready to use
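    Example (illustrative; `my_embeddings` is a placeholder):
        chunker = create_semantic_chunker(my_embeddings, similarity_threshold=0.8)
        chunks = chunker.split_documents(documents)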
"""
return SemanticChunker(embeddings_model=embeddings_model, **kwargs)
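# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). This demo builds a toy bag-of-words embedder
# exposing the `embed_documents` interface that `_get_embeddings` can call, plus a
# short sample document; both are placeholders rather than part of the real setup.
# In an actual RAG system, pass your existing BGE/OpenAI embeddings object instead.
if __name__ == "__main__":
    class _ToyEmbeddings:
        """Deterministic bag-of-words embedder, just enough to exercise the chunker."""
        def __init__(self, reference_text: str):
            # Fixed vocabulary taken from the demo text.
            self.vocab = sorted({w.lower() for w in re.findall(r"\w+", reference_text)})

        def embed_documents(self, texts: List[str]) -> List[List[float]]:
            # One raw term-count vector per input text.
            vectors = []
            for text in texts:
                words = [w.lower() for w in re.findall(r"\w+", text)]
                vectors.append([float(words.count(term)) for term in self.vocab])
            return vectors

    sample = Document(
        page_content=(
            "The university library is open from 8am to 10pm on weekdays. "
            "Students can borrow up to ten books at a time from the library. "
            "Library fines are charged for items returned after the due date. "
            "Tuition fees for the academic year must be paid before enrolment. "
            "Fee payment plans are available for students facing financial hardship."
        ),
        metadata={"source": "demo"},
    )

    chunker = create_semantic_chunker(
        _ToyEmbeddings(sample.page_content),
        similarity_threshold=0.3,
        min_chunk_size=50,
        debug=True,
    )
    for chunk in chunker.split_documents([sample]):
        print(f"[{chunk.metadata['chunking_method']}] {chunk.page_content[:70]}...")
    print(chunker.get_statistics())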