""" | |
Semantic Chunker Module for RAG Systems | |
====================================== | |
A drop-in replacement for RecursiveCharacterTextSplitter that uses semantic similarity | |
to create more coherent chunks. Designed to work seamlessly with existing LangChain | |
and Streamlit RAG systems. | |
Author: AI Assistant | |
Compatible with: LangChain, BGE embeddings, OpenAI embeddings, Streamlit | |
""" | |
import logging
import re
from typing import Any, Dict, List, Optional

import numpy as np
import streamlit as st
from langchain.schema import Document
from sklearn.metrics.pairwise import cosine_similarity

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SemanticChunker:
    """
    Advanced semantic document chunker that creates coherent chunks based on
    semantic similarity rather than fixed character counts.

    Well suited to university documents, research papers, and policy documents
    where maintaining semantic coherence is crucial.
    """

    def __init__(self,
                 embeddings_model,
                 chunk_size: int = 4,
                 overlap: int = 1,
                 similarity_threshold: float = 0.75,
                 min_chunk_size: int = 150,
                 max_chunk_size: int = 1500,
                 sentence_split_pattern: Optional[str] = None,
                 debug: bool = False):
        """
        Initialize the semantic chunker.

        Args:
            embeddings_model: Your existing embeddings model (BGE, OpenAI, etc.)
            chunk_size: Base number of sentences per chunk (default: 4)
            overlap: Number of sentences to overlap between chunks (default: 1)
            similarity_threshold: Cosine similarity threshold for extending chunks (0.0-1.0)
            min_chunk_size: Minimum characters per chunk (smaller chunks are skipped)
            max_chunk_size: Maximum characters per chunk (prevents overly large chunks)
            sentence_split_pattern: Custom regex pattern for sentence splitting
            debug: Enable debug logging and statistics
        """
        self.embeddings_model = embeddings_model
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.similarity_threshold = similarity_threshold
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size
        self.debug = debug

        # Default sentence splitting pattern optimized for academic/university documents
        self.sentence_pattern = sentence_split_pattern or r'[.!?]+\s+'
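        # For example, the default pattern splits
        #   "Tuition is due May 1. Late fees apply! See the bursar."
        # into ["Tuition is due May 1", "Late fees apply", "See the bursar."]
        # (terminal punctuation is consumed except at the very end of the text).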

        # Statistics tracking
        self.stats = {
            "total_documents": 0,
            "total_chunks": 0,
            "avg_chunk_size": 0,
            "chunking_methods": {},
            "embedding_errors": 0
        }

        if self.debug:
            logger.info(f"Initialized SemanticChunker with threshold={similarity_threshold}")

    def _detect_embedding_model_type(self) -> str:
        """Detect the type of embedding model being used."""
        if hasattr(self.embeddings_model, 'model'):
            # Likely sentence-transformers model (BGE, etc.)
            model_name = getattr(self.embeddings_model.model, 'model_name', 'sentence-transformers')
            return f"sentence-transformers ({model_name})"
        elif hasattr(self.embeddings_model, 'client'):
            # Likely OpenAI
            return "OpenAI"
        else:
            return "Unknown"

    def _preprocess_text_for_splitting(self, text: str) -> str:
        """
        Preprocess text to handle common formatting issues in university documents.
        """
        # Fix common formatting issues
        fixes = [
            # Add space after periods before capital letters
            (r'([a-z])\.([A-Z])', r'\1. \2'),
            # Add space after numbers with periods
            (r'([0-9]+)\.([A-Z])', r'\1. \2'),
            # Fix missing spaces after question/exclamation marks
            (r'([a-z])\?([A-Z])', r'\1? \2'),
            (r'([a-z])\!([A-Z])', r'\1! \2'),
            # Clean up multiple spaces
            (r'\s+', ' '),
            # Fix bullet points
            (r'•\s*([A-Z])', r'• \1'),
            (r'-\s*([A-Z])', r'- \1'),
        ]

        processed_text = text
        for pattern, replacement in fixes:
            processed_text = re.sub(pattern, replacement, processed_text)

        return processed_text.strip()

    def _split_into_sentences(self, text: str) -> List[str]:
        """
        Advanced sentence splitting optimized for academic documents.
        """
        # Preprocess text
        text = self._preprocess_text_for_splitting(text)

        # Split on sentence boundaries
        raw_sentences = re.split(self.sentence_pattern, text)

        # Clean and filter sentences
        sentences = []
        for sentence in raw_sentences:
            sentence = sentence.strip()
            # Filter out very short sentences, pure numbers, or empty strings
            if len(sentence) >= 10 and not sentence.isdigit() and not re.match(r'^[^\w]*$', sentence):
                sentences.append(sentence)

        if self.debug:
            logger.info(f"Split text into {len(sentences)} sentences")

        return sentences

    def _get_embeddings(self, texts: List[str]) -> Optional[np.ndarray]:
        """
        Get embeddings from the provided model with error handling.
        """
        try:
            if hasattr(self.embeddings_model, 'model'):
                # sentence-transformers model (BGE, etc.)
                embeddings = self.embeddings_model.model.encode(texts)
                return np.array(embeddings)
            elif hasattr(self.embeddings_model, 'embed_documents'):
                # OpenAI or similar API-based embeddings
                embeddings = self.embeddings_model.embed_documents(texts)
                return np.array(embeddings)
            else:
                # Try calling the model directly
                embeddings = self.embeddings_model(texts)
                return np.array(embeddings)
        except Exception as e:
            self.stats["embedding_errors"] += 1
            if self.debug:
                logger.error(f"Error generating embeddings: {e}")
            # Show a warning in Streamlit if available
            try:
                st.warning(f"⚠️ Embedding error, falling back to simple chunking: {str(e)[:100]}...")
            except Exception:
                pass  # Streamlit not available
            return None

    def _calculate_semantic_boundaries(self, embeddings: np.ndarray, sentences: List[str]) -> List[int]:
        """
        Find natural semantic boundaries in the text based on embedding similarities.
        """
        boundaries = [0]  # Always start with first sentence

        # Calculate similarities between consecutive sentences
        similarities = []
        for i in range(len(embeddings) - 1):
            sim = cosine_similarity(
                embeddings[i:i+1],
                embeddings[i+1:i+2]
            )[0][0]
            similarities.append(sim)

        # Find significant drops in similarity (topic boundaries)
        if len(similarities) > 1:
            mean_sim = np.mean(similarities)
            std_sim = np.std(similarities)
            threshold = mean_sim - (0.5 * std_sim)  # Adaptive threshold
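            # Worked example: consecutive similarities [0.90, 0.85, 0.40, 0.88]
            # give mean ~0.76 and std ~0.21, so the threshold is ~0.65; only the
            # drop to 0.40 falls below it, placing a boundary before the sentence
            # that follows that drop.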
            for i, sim in enumerate(similarities):
                if sim < threshold:
                    boundaries.append(i + 1)

        boundaries.append(len(sentences))  # Always end with last sentence
        return sorted(list(set(boundaries)))  # Remove duplicates and sort

    def _create_chunks_from_boundaries(self, sentences: List[str], boundaries: List[int],
                                       embeddings: Optional[np.ndarray], metadata: Dict[str, Any]) -> List[Document]:
        """
        Create document chunks based on semantic boundaries.
        """
        chunks = []

        for i in range(len(boundaries) - 1):
            start_idx = boundaries[i]
            end_idx = boundaries[i + 1]

            # Create base chunk
            chunk_sentences = sentences[start_idx:end_idx]

            # Try to extend chunk if semantically similar
            if embeddings is not None and end_idx < len(sentences):
                current_embedding = np.mean(embeddings[start_idx:end_idx], axis=0, keepdims=True)

                # Check if we can extend the chunk
                extended_end = end_idx
                while extended_end < len(sentences):
                    next_sentence_embedding = embeddings[extended_end:extended_end+1]
                    similarity = cosine_similarity(current_embedding, next_sentence_embedding)[0][0]

                    if similarity > self.similarity_threshold:
                        # Check size limit
                        test_chunk = ' '.join(sentences[start_idx:extended_end+1])
                        if len(test_chunk) <= self.max_chunk_size:
                            extended_end += 1
                            # Update current embedding
                            current_embedding = np.mean(embeddings[start_idx:extended_end], axis=0, keepdims=True)
                        else:
                            break
                    else:
                        break

                # Use extended chunk if we found extensions
                if extended_end > end_idx:
                    chunk_sentences = sentences[start_idx:extended_end]

            # Create chunk text
            chunk_text = ' '.join(chunk_sentences)

            # Only add chunks that meet minimum size requirement
            if len(chunk_text) >= self.min_chunk_size:
                chunk_metadata = metadata.copy()
                chunk_metadata.update({
                    "chunk_index": len(chunks),
                    "sentence_count": len(chunk_sentences),
                    "start_sentence": start_idx,
                    "end_sentence": start_idx + len(chunk_sentences) - 1,
                    "chunking_method": "semantic_boundary",
                    "similarity_threshold": self.similarity_threshold,
                    "chunk_size_chars": len(chunk_text)
                })
                chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))

        return chunks

    def _create_simple_chunks(self, sentences: List[str], metadata: Dict[str, Any]) -> List[Document]:
        """
        Fallback to simple sentence-based chunking when embeddings are unavailable.
        """
        chunks = []
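        # Slide a window of `chunk_size` sentences with a stride of
        # max(1, chunk_size - overlap); with the defaults (4, 1) the chunks
        # cover sentences [0:4], [3:7], [6:10], and so on.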
        for i in range(0, len(sentences), max(1, self.chunk_size - self.overlap)):
            chunk_sentences = sentences[i:i + self.chunk_size]
            chunk_text = ' '.join(chunk_sentences)

            if len(chunk_text) >= self.min_chunk_size:
                chunk_metadata = metadata.copy()
                chunk_metadata.update({
                    "chunk_index": len(chunks),
                    "sentence_count": len(chunk_sentences),
                    "start_sentence": i,
                    "end_sentence": i + len(chunk_sentences) - 1,
                    "chunking_method": "simple_fallback",
                    "chunk_size_chars": len(chunk_text)
                })
                chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))

        return chunks

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """
        Main method: Split documents into semantically coherent chunks.

        Args:
            documents: List of LangChain Document objects

        Returns:
            List of Document objects with semantic chunks
        """
        all_chunks = []
        self.stats["total_documents"] = len(documents)

        for doc_idx, doc in enumerate(documents):
            try:
                # Split document into sentences
                sentences = self._split_into_sentences(doc.page_content)

                if not sentences:
                    if self.debug:
                        logger.warning(f"No sentences found in document {doc_idx}")
                    continue

                # Handle very short documents
                if len(sentences) < self.chunk_size:
                    chunk_text = ' '.join(sentences)
                    if len(chunk_text) >= self.min_chunk_size:
                        chunk_metadata = doc.metadata.copy()
                        chunk_metadata.update({
                            "chunk_index": 0,
                            "total_chunks": 1,
                            "sentence_count": len(sentences),
                            "chunking_method": "single_chunk",
                            "chunk_size_chars": len(chunk_text)
                        })
                        all_chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))
                    continue

                # Generate embeddings
                embeddings = self._get_embeddings(sentences)

                if embeddings is not None:
                    # Find semantic boundaries and build chunks around them
                    boundaries = self._calculate_semantic_boundaries(embeddings, sentences)
                    chunks = self._create_chunks_from_boundaries(sentences, boundaries, embeddings, doc.metadata)
                    method = "semantic"
                else:
                    # Fallback to simple chunking
                    chunks = self._create_simple_chunks(sentences, doc.metadata)
                    method = "simple_fallback"

                # Update statistics
                self.stats["chunking_methods"][method] = self.stats["chunking_methods"].get(method, 0) + 1

                # Update total chunks count in each chunk's metadata
                for chunk in chunks:
                    chunk.metadata["total_chunks"] = len(chunks)
                    chunk.metadata["source_document_index"] = doc_idx

                all_chunks.extend(chunks)

                if self.debug:
                    logger.info(f"Document {doc_idx}: {len(sentences)} sentences → {len(chunks)} chunks ({method})")

            except Exception as e:
                logger.error(f"Error processing document {doc_idx}: {e}")
                if self.debug:
                    st.error(f"Error processing document {doc_idx}: {e}")

        # Update final statistics
        self.stats["total_chunks"] = len(all_chunks)
        if all_chunks:
            chunk_sizes = [len(chunk.page_content) for chunk in all_chunks]
            self.stats["avg_chunk_size"] = sum(chunk_sizes) / len(chunk_sizes)

        if self.debug:
            logger.info(f"Created {len(all_chunks)} total chunks from {len(documents)} documents")

        return all_chunks

    def get_statistics(self) -> Dict[str, Any]:
        """Get chunking statistics for analysis."""
        return self.stats.copy()
def display_statistics(self): | |
"""Display chunking statistics in Streamlit (if available).""" | |
try: | |
with st.expander("π Semantic Chunking Statistics"): | |
col1, col2 = st.columns(2) | |
with col1: | |
st.metric("Total Documents", self.stats["total_documents"]) | |
st.metric("Total Chunks", self.stats["total_chunks"]) | |
with col2: | |
st.metric("Avg Chunk Size", f"{self.stats['avg_chunk_size']:.0f} chars") | |
st.metric("Embedding Errors", self.stats["embedding_errors"]) | |
if self.stats["chunking_methods"]: | |
st.write("**Chunking Methods Used:**") | |
for method, count in self.stats["chunking_methods"].items(): | |
percentage = (count / self.stats["total_documents"]) * 100 if self.stats["total_documents"] > 0 else 0 | |
st.write(f" - {method}: {count} documents ({percentage:.1f}%)") | |
st.write("**Configuration:**") | |
st.json({ | |
"chunk_size": self.chunk_size, | |
"overlap": self.overlap, | |
"similarity_threshold": self.similarity_threshold, | |
"min_chunk_size": self.min_chunk_size, | |
"max_chunk_size": self.max_chunk_size, | |
"embedding_model": self._detect_embedding_model_type() | |
}) | |
except ImportError: | |
# Streamlit not available, print to console | |
print("\n=== Semantic Chunking Statistics ===") | |
print(f"Documents processed: {self.stats['total_documents']}") | |
print(f"Chunks created: {self.stats['total_chunks']}") | |
print(f"Average chunk size: {self.stats['avg_chunk_size']:.0f} characters") | |
print(f"Embedding errors: {self.stats['embedding_errors']}") | |
print(f"Chunking methods: {self.stats['chunking_methods']}") | |


def create_semantic_chunker(embeddings_model, **kwargs) -> SemanticChunker:
    """
    Convenience function to create a semantic chunker with sensible defaults.

    Args:
        embeddings_model: Your existing embeddings model
        **kwargs: Additional parameters to pass to SemanticChunker

    Returns:
        SemanticChunker instance ready to use
    """
    return SemanticChunker(embeddings_model=embeddings_model, **kwargs)
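

# Minimal usage sketch, for illustration only. `_BGEEmbeddings` is a hypothetical
# wrapper standing in for whatever embeddings object your RAG app already has:
# the chunker only needs `.model.encode(texts)` (sentence-transformers style) or
# an `embed_documents(texts)` method (OpenAI / LangChain style).
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    class _BGEEmbeddings:
        """Hypothetical stand-in exposing a sentence-transformers model as `.model`."""

        def __init__(self, name: str = "BAAI/bge-small-en-v1.5"):
            self.model = SentenceTransformer(name)

    chunker = create_semantic_chunker(_BGEEmbeddings(), similarity_threshold=0.8, debug=True)

    docs = [Document(
        page_content=(
            "Tuition is due on the first day of term. Payment plans are available "
            "through the bursar's office. Late payments accrue a monthly fee. "
            "Refunds follow the university's withdrawal policy. Housing charges "
            "are billed separately each semester."
        ),
        metadata={"source": "fees_policy.pdf"},
    )]

    # Drop-in replacement for RecursiveCharacterTextSplitter.split_documents
    for chunk in chunker.split_documents(docs):
        print(chunk.metadata["chunking_method"], len(chunk.page_content), "chars")

    chunker.display_statistics()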