""" Semantic Chunker Module for RAG Systems ====================================== A drop-in replacement for RecursiveCharacterTextSplitter that uses semantic similarity to create more coherent chunks. Designed to work seamlessly with existing LangChain and Streamlit RAG systems. Author: AI Assistant Compatible with: LangChain, BGE embeddings, OpenAI embeddings, Streamlit """ import numpy as np import re from typing import List, Dict, Any, Optional, Union from langchain.schema import Document import streamlit as st from sklearn.metrics.pairwise import cosine_similarity import logging # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class SemanticChunker: """ Advanced semantic document chunker that creates coherent chunks based on semantic similarity rather than fixed character counts. Perfect for university documents, research papers, and policy documents where maintaining semantic coherence is crucial. """ def __init__(self, embeddings_model, chunk_size: int = 4, overlap: int = 1, similarity_threshold: float = 0.75, min_chunk_size: int = 150, max_chunk_size: int = 1500, sentence_split_pattern: Optional[str] = None, debug: bool = False): """ Initialize the semantic chunker. Args: embeddings_model: Your existing embeddings model (BGE, OpenAI, etc.) chunk_size: Base number of sentences per chunk (default: 4) overlap: Number of sentences to overlap between chunks (default: 1) similarity_threshold: Cosine similarity threshold for extending chunks (0.0-1.0) min_chunk_size: Minimum characters per chunk (skip smaller chunks) max_chunk_size: Maximum characters per chunk (prevent overly large chunks) sentence_split_pattern: Custom regex pattern for sentence splitting debug: Enable debug logging and statistics """ self.embeddings_model = embeddings_model self.chunk_size = chunk_size self.overlap = overlap self.similarity_threshold = similarity_threshold self.min_chunk_size = min_chunk_size self.max_chunk_size = max_chunk_size self.debug = debug # Default sentence splitting pattern optimized for academic/university documents self.sentence_pattern = sentence_split_pattern or r'[.!?]+\s+' # Statistics tracking self.stats = { "total_documents": 0, "total_chunks": 0, "avg_chunk_size": 0, "chunking_methods": {}, "embedding_errors": 0 } if self.debug: logger.info(f"Initialized SemanticChunker with threshold={similarity_threshold}") def _detect_embedding_model_type(self) -> str: """Detect the type of embedding model being used.""" if hasattr(self.embeddings_model, 'model'): # Likely sentence-transformers model (BGE, etc.) model_name = getattr(self.embeddings_model.model, 'model_name', 'sentence-transformers') return f"sentence-transformers ({model_name})" elif hasattr(self.embeddings_model, 'client'): # Likely OpenAI return "OpenAI" else: return "Unknown" def _preprocess_text_for_splitting(self, text: str) -> str: """ Preprocess text to handle common formatting issues in university documents. """ # Fix common formatting issues fixes = [ # Add space after periods before capital letters (r'([a-z])\.([A-Z])', r'\1. \2'), # Add space after numbers with periods (r'([0-9]+)\.([A-Z])', r'\1. \2'), # Fix missing spaces after question/exclamation marks (r'([a-z])\?([A-Z])', r'\1? \2'), (r'([a-z])\!([A-Z])', r'\1! \2'), # Clean up multiple spaces (r'\s+', ' '), # Fix bullet points (r'•\s*([A-Z])', r'• \1'), (r'-\s*([A-Z])', r'- \1'), ] processed_text = text for pattern, replacement in fixes: processed_text = re.sub(pattern, replacement, processed_text) return processed_text.strip() def _split_into_sentences(self, text: str) -> List[str]: """ Advanced sentence splitting optimized for academic documents. """ # Preprocess text text = self._preprocess_text_for_splitting(text) # Split on sentence boundaries raw_sentences = re.split(self.sentence_pattern, text) # Clean and filter sentences sentences = [] for sentence in raw_sentences: sentence = sentence.strip() # Filter out very short sentences, pure numbers, or empty strings if len(sentence) >= 10 and not sentence.isdigit() and not re.match(r'^[^\w]*$', sentence): sentences.append(sentence) if self.debug: logger.info(f"Split text into {len(sentences)} sentences") return sentences def _get_embeddings(self, texts: List[str]) -> Optional[np.ndarray]: """ Get embeddings from the provided model with error handling. """ try: if hasattr(self.embeddings_model, 'model'): # sentence-transformers model (BGE, etc.) embeddings = self.embeddings_model.model.encode(texts) return np.array(embeddings) elif hasattr(self.embeddings_model, 'embed_documents'): # OpenAI or similar API-based embeddings embeddings = self.embeddings_model.embed_documents(texts) return np.array(embeddings) else: # Try direct call embeddings = self.embeddings_model(texts) return np.array(embeddings) except Exception as e: self.stats["embedding_errors"] += 1 if self.debug: logger.error(f"Error generating embeddings: {e}") # Show warning in Streamlit if available try: st.warning(f"⚠️ Embedding error, falling back to simple chunking: {str(e)[:100]}...") except: pass # Streamlit not available return None def _calculate_semantic_boundaries(self, embeddings: np.ndarray, sentences: List[str]) -> List[int]: """ Find natural semantic boundaries in the text based on embedding similarities. """ boundaries = [0] # Always start with first sentence # Calculate similarities between consecutive sentences similarities = [] for i in range(len(embeddings) - 1): sim = cosine_similarity( embeddings[i:i+1], embeddings[i+1:i+2] )[0][0] similarities.append(sim) # Find significant drops in similarity (topic boundaries) if len(similarities) > 1: mean_sim = np.mean(similarities) std_sim = np.std(similarities) threshold = mean_sim - (0.5 * std_sim) # Adaptive threshold for i, sim in enumerate(similarities): if sim < threshold: boundaries.append(i + 1) boundaries.append(len(sentences)) # Always end with last sentence return sorted(list(set(boundaries))) # Remove duplicates and sort def _create_chunks_from_boundaries(self, sentences: List[str], boundaries: List[int], embeddings: Optional[np.ndarray], metadata: Dict[str, Any]) -> List[Document]: """ Create document chunks based on semantic boundaries. """ chunks = [] for i in range(len(boundaries) - 1): start_idx = boundaries[i] end_idx = boundaries[i + 1] # Create base chunk chunk_sentences = sentences[start_idx:end_idx] # Try to extend chunk if semantically similar if embeddings is not None and end_idx < len(sentences): current_embedding = np.mean(embeddings[start_idx:end_idx], axis=0, keepdims=True) # Check if we can extend the chunk extended_end = end_idx while extended_end < len(sentences): next_sentence_embedding = embeddings[extended_end:extended_end+1] similarity = cosine_similarity(current_embedding, next_sentence_embedding)[0][0] if similarity > self.similarity_threshold: # Check size limit test_chunk = ' '.join(sentences[start_idx:extended_end+1]) if len(test_chunk) <= self.max_chunk_size: extended_end += 1 # Update current embedding current_embedding = np.mean(embeddings[start_idx:extended_end], axis=0, keepdims=True) else: break else: break # Use extended chunk if we found extensions if extended_end > end_idx: chunk_sentences = sentences[start_idx:extended_end] # Create chunk text chunk_text = ' '.join(chunk_sentences) # Only add chunks that meet minimum size requirement if len(chunk_text) >= self.min_chunk_size: chunk_metadata = metadata.copy() chunk_metadata.update({ "chunk_index": len(chunks), "sentence_count": len(chunk_sentences), "start_sentence": start_idx, "end_sentence": start_idx + len(chunk_sentences) - 1, "chunking_method": "semantic_boundary", "similarity_threshold": self.similarity_threshold, "chunk_size_chars": len(chunk_text) }) chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata)) return chunks def _create_simple_chunks(self, sentences: List[str], metadata: Dict[str, Any]) -> List[Document]: """ Fallback to simple sentence-based chunking when embeddings are unavailable. """ chunks = [] for i in range(0, len(sentences), max(1, self.chunk_size - self.overlap)): chunk_sentences = sentences[i:i + self.chunk_size] chunk_text = ' '.join(chunk_sentences) if len(chunk_text) >= self.min_chunk_size: chunk_metadata = metadata.copy() chunk_metadata.update({ "chunk_index": len(chunks), "sentence_count": len(chunk_sentences), "start_sentence": i, "end_sentence": i + len(chunk_sentences) - 1, "chunking_method": "simple_fallback", "chunk_size_chars": len(chunk_text) }) chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata)) return chunks def split_documents(self, documents: List[Document]) -> List[Document]: """ Main method: Split documents into semantically coherent chunks. Args: documents: List of LangChain Document objects Returns: List of Document objects with semantic chunks """ all_chunks = [] self.stats["total_documents"] = len(documents) for doc_idx, doc in enumerate(documents): try: # Split document into sentences sentences = self._split_into_sentences(doc.page_content) if not sentences: if self.debug: logger.warning(f"No sentences found in document {doc_idx}") continue # Handle very short documents if len(sentences) < self.chunk_size: chunk_text = ' '.join(sentences) if len(chunk_text) >= self.min_chunk_size: chunk_metadata = doc.metadata.copy() chunk_metadata.update({ "chunk_index": 0, "total_chunks": 1, "sentence_count": len(sentences), "chunking_method": "single_chunk", "chunk_size_chars": len(chunk_text) }) all_chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata)) continue # Generate embeddings embeddings = self._get_embeddings(sentences) if embeddings is not None: # Create semantic chunks chunks = self._create_chunks_from_boundaries(sentences, [0, len(sentences)], embeddings, doc.metadata) method = "semantic" else: # Fallback to simple chunking chunks = self._create_simple_chunks(sentences, doc.metadata) method = "simple_fallback" # Update statistics self.stats["chunking_methods"][method] = self.stats["chunking_methods"].get(method, 0) + 1 # Update total chunks count in each chunk's metadata for chunk in chunks: chunk.metadata["total_chunks"] = len(chunks) chunk.metadata["source_document_index"] = doc_idx all_chunks.extend(chunks) if self.debug: logger.info(f"Document {doc_idx}: {len(sentences)} sentences → {len(chunks)} chunks ({method})") except Exception as e: logger.error(f"Error processing document {doc_idx}: {e}") if self.debug: st.error(f"Error processing document {doc_idx}: {e}") # Update final statistics self.stats["total_chunks"] = len(all_chunks) if all_chunks: chunk_sizes = [len(chunk.page_content) for chunk in all_chunks] self.stats["avg_chunk_size"] = sum(chunk_sizes) / len(chunk_sizes) if self.debug: logger.info(f"Created {len(all_chunks)} total chunks from {len(documents)} documents") return all_chunks def get_statistics(self) -> Dict[str, Any]: """Get chunking statistics for analysis.""" return self.stats.copy() def display_statistics(self): """Display chunking statistics in Streamlit (if available).""" try: with st.expander("📊 Semantic Chunking Statistics"): col1, col2 = st.columns(2) with col1: st.metric("Total Documents", self.stats["total_documents"]) st.metric("Total Chunks", self.stats["total_chunks"]) with col2: st.metric("Avg Chunk Size", f"{self.stats['avg_chunk_size']:.0f} chars") st.metric("Embedding Errors", self.stats["embedding_errors"]) if self.stats["chunking_methods"]: st.write("**Chunking Methods Used:**") for method, count in self.stats["chunking_methods"].items(): percentage = (count / self.stats["total_documents"]) * 100 if self.stats["total_documents"] > 0 else 0 st.write(f" - {method}: {count} documents ({percentage:.1f}%)") st.write("**Configuration:**") st.json({ "chunk_size": self.chunk_size, "overlap": self.overlap, "similarity_threshold": self.similarity_threshold, "min_chunk_size": self.min_chunk_size, "max_chunk_size": self.max_chunk_size, "embedding_model": self._detect_embedding_model_type() }) except ImportError: # Streamlit not available, print to console print("\n=== Semantic Chunking Statistics ===") print(f"Documents processed: {self.stats['total_documents']}") print(f"Chunks created: {self.stats['total_chunks']}") print(f"Average chunk size: {self.stats['avg_chunk_size']:.0f} characters") print(f"Embedding errors: {self.stats['embedding_errors']}") print(f"Chunking methods: {self.stats['chunking_methods']}") def create_semantic_chunker(embeddings_model, **kwargs) -> SemanticChunker: """ Convenience function to create a semantic chunker with sensible defaults. Args: embeddings_model: Your existing embeddings model **kwargs: Additional parameters to pass to SemanticChunker Returns: SemanticChunker instance ready to use """ return SemanticChunker(embeddings_model=embeddings_model, **kwargs)