"""
Semantic Chunker Module for RAG Systems
========================================

A drop-in replacement for RecursiveCharacterTextSplitter that uses semantic similarity
to create more coherent chunks. Designed to work seamlessly with existing LangChain
and Streamlit RAG systems.

Author: AI Assistant
Compatible with: LangChain, BGE embeddings, OpenAI embeddings, Streamlit
"""

import numpy as np
import re
from typing import List, Dict, Any, Optional, Union
from langchain.schema import Document
import streamlit as st
from sklearn.metrics.pairwise import cosine_similarity
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SemanticChunker:
    """
    Semantic document chunker that creates coherent chunks based on
    semantic similarity rather than fixed character counts.

    Well suited to university documents, research papers, and policy documents
    where maintaining semantic coherence is crucial.
    """

    def __init__(self,
                 embeddings_model,
                 chunk_size: int = 4,
                 overlap: int = 1,
                 similarity_threshold: float = 0.75,
                 min_chunk_size: int = 150,
                 max_chunk_size: int = 1500,
                 sentence_split_pattern: Optional[str] = None,
                 debug: bool = False):
        """
        Initialize the semantic chunker.

        Args:
            embeddings_model: Your existing embeddings model (BGE, OpenAI, etc.)
            chunk_size: Base number of sentences per chunk (default: 4)
            overlap: Number of sentences to overlap between chunks (default: 1)
            similarity_threshold: Cosine similarity threshold for extending chunks (0.0-1.0)
            min_chunk_size: Minimum characters per chunk (smaller chunks are skipped)
            max_chunk_size: Maximum characters per chunk (prevents overly large chunks)
            sentence_split_pattern: Custom regex pattern for sentence splitting
            debug: Enable debug logging and statistics
        """
        self.embeddings_model = embeddings_model
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.similarity_threshold = similarity_threshold
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size
        self.debug = debug

        # Default boundary: one or more sentence-ending punctuation marks followed by whitespace.
        self.sentence_pattern = sentence_split_pattern or r'[.!?]+\s+'

        # Running statistics, updated by split_documents().
        self.stats = {
            "total_documents": 0,
            "total_chunks": 0,
            "avg_chunk_size": 0,
            "chunking_methods": {},
            "embedding_errors": 0
        }

        if self.debug:
            logger.info(f"Initialized SemanticChunker with threshold={similarity_threshold}")

    def _detect_embedding_model_type(self) -> str:
        """Detect the type of embedding model being used."""
        if hasattr(self.embeddings_model, 'model'):
            # Sentence-transformers style wrapper (e.g. BGE) exposing an underlying model.
            model_name = getattr(self.embeddings_model.model, 'model_name', 'sentence-transformers')
            return f"sentence-transformers ({model_name})"
        elif hasattr(self.embeddings_model, 'client'):
            # OpenAI embeddings expose a client attribute.
            return "OpenAI"
        else:
            return "Unknown"

    def _preprocess_text_for_splitting(self, text: str) -> str:
        """
        Preprocess text to handle common formatting issues in university documents.
        """
        fixes = [
            # Missing space after a sentence-ending period, e.g. "policy.The" -> "policy. The"
            (r'([a-z])\.([A-Z])', r'\1. \2'),
            # Missing space after a number followed by a period, e.g. "2023.Students" -> "2023. Students"
            (r'([0-9]+)\.([A-Z])', r'\1. \2'),
            # Missing space after question and exclamation marks
            (r'([a-z])\?([A-Z])', r'\1? \2'),
            (r'([a-z])\!([A-Z])', r'\1! \2'),
            # Collapse runs of whitespace
            (r'\s+', ' '),
            # Normalize spacing after bullet and dash list markers
            (r'•\s*([A-Z])', r'• \1'),
            (r'-\s*([A-Z])', r'- \1'),
        ]

        processed_text = text
        for pattern, replacement in fixes:
            processed_text = re.sub(pattern, replacement, processed_text)

        return processed_text.strip()

    def _split_into_sentences(self, text: str) -> List[str]:
        """
        Sentence splitting tuned for academic documents.
        """
        text = self._preprocess_text_for_splitting(text)

        # Split on the configured sentence boundary pattern.
        raw_sentences = re.split(self.sentence_pattern, text)

        # Keep only substantive sentences: drop very short fragments, bare numbers,
        # and strings with no word characters at all.
        sentences = []
        for sentence in raw_sentences:
            sentence = sentence.strip()
            if len(sentence) >= 10 and not sentence.isdigit() and not re.match(r'^[^\w]*$', sentence):
                sentences.append(sentence)

        if self.debug:
            logger.info(f"Split text into {len(sentences)} sentences")

        return sentences

    def _get_embeddings(self, texts: List[str]) -> Optional[np.ndarray]:
        """
        Get embeddings from the provided model with error handling.

        Returns None if embedding generation fails, signalling the caller to fall
        back to simple chunking.
        """
        try:
            if hasattr(self.embeddings_model, 'model'):
                # Sentence-transformers style wrapper (e.g. BGE): call encode() directly.
                embeddings = self.embeddings_model.model.encode(texts)
                return np.array(embeddings)
            elif hasattr(self.embeddings_model, 'embed_documents'):
                # LangChain embeddings interface (OpenAI, HuggingFace, etc.).
                embeddings = self.embeddings_model.embed_documents(texts)
                return np.array(embeddings)
            else:
                # Fall back to treating the model itself as a callable.
                embeddings = self.embeddings_model(texts)
                return np.array(embeddings)

        except Exception as e:
            self.stats["embedding_errors"] += 1
            if self.debug:
                logger.error(f"Error generating embeddings: {e}")

            # Surface a warning in Streamlit when a session is available.
            try:
                st.warning(f"⚠️ Embedding error, falling back to simple chunking: {str(e)[:100]}...")
            except Exception:
                pass

            return None

    def _calculate_semantic_boundaries(self, embeddings: np.ndarray, sentences: List[str]) -> List[int]:
        """
        Find natural semantic boundaries in the text based on embedding similarities.
        """
        boundaries = [0]

        # Cosine similarity between each pair of adjacent sentences.
        similarities = []
        for i in range(len(embeddings) - 1):
            sim = cosine_similarity(
                embeddings[i:i+1],
                embeddings[i+1:i+2]
            )[0][0]
            similarities.append(sim)

        # Adaptive threshold: cut where similarity drops well below the document's own average.
        if len(similarities) > 1:
            mean_sim = np.mean(similarities)
            std_sim = np.std(similarities)
            threshold = mean_sim - (0.5 * std_sim)
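            # Illustrative numbers (not from real data): adjacent similarities
            # [0.82, 0.79, 0.41, 0.85] give mean ≈ 0.72 and std ≈ 0.18, so the
            # threshold is ≈ 0.63 and only the dip to 0.41 introduces a boundary.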

            for i, sim in enumerate(similarities):
                if sim < threshold:
                    boundaries.append(i + 1)

        boundaries.append(len(sentences))

        return sorted(list(set(boundaries)))

    def _create_chunks_from_boundaries(self, sentences: List[str], boundaries: List[int],
                                       embeddings: Optional[np.ndarray], metadata: Dict[str, Any]) -> List[Document]:
        """
        Create document chunks based on semantic boundaries.
        """
        chunks = []

        for i in range(len(boundaries) - 1):
            start_idx = boundaries[i]
            end_idx = boundaries[i + 1]

            chunk_sentences = sentences[start_idx:end_idx]

            # Optionally extend the chunk past its boundary while the following
            # sentences remain semantically similar and the size limit allows it.
            # Extended sentences may also appear at the start of the next chunk,
            # which adds a semantic overlap between neighbouring chunks.
            if embeddings is not None and end_idx < len(sentences):
                current_embedding = np.mean(embeddings[start_idx:end_idx], axis=0, keepdims=True)

                extended_end = end_idx
                while extended_end < len(sentences):
                    next_sentence_embedding = embeddings[extended_end:extended_end+1]
                    similarity = cosine_similarity(current_embedding, next_sentence_embedding)[0][0]

                    if similarity > self.similarity_threshold:
                        # Only extend while the chunk stays within max_chunk_size.
                        test_chunk = ' '.join(sentences[start_idx:extended_end+1])
                        if len(test_chunk) <= self.max_chunk_size:
                            extended_end += 1
                            # Recompute the running mean embedding for the grown chunk.
                            current_embedding = np.mean(embeddings[start_idx:extended_end], axis=0, keepdims=True)
                        else:
                            break
                    else:
                        break

                if extended_end > end_idx:
                    chunk_sentences = sentences[start_idx:extended_end]

            chunk_text = ' '.join(chunk_sentences)

            # Skip chunks that are too small to be useful on their own.
            if len(chunk_text) >= self.min_chunk_size:
                chunk_metadata = metadata.copy()
                chunk_metadata.update({
                    "chunk_index": len(chunks),
                    "sentence_count": len(chunk_sentences),
                    "start_sentence": start_idx,
                    "end_sentence": start_idx + len(chunk_sentences) - 1,
                    "chunking_method": "semantic_boundary",
                    "similarity_threshold": self.similarity_threshold,
                    "chunk_size_chars": len(chunk_text)
                })

                chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))

        return chunks

    def _create_simple_chunks(self, sentences: List[str], metadata: Dict[str, Any]) -> List[Document]:
        """
        Fallback to simple sentence-based chunking when embeddings are unavailable.
        """
        chunks = []

        # Slide over the sentences in fixed-size windows with the configured overlap.
        for i in range(0, len(sentences), max(1, self.chunk_size - self.overlap)):
            chunk_sentences = sentences[i:i + self.chunk_size]
            chunk_text = ' '.join(chunk_sentences)

            if len(chunk_text) >= self.min_chunk_size:
                chunk_metadata = metadata.copy()
                chunk_metadata.update({
                    "chunk_index": len(chunks),
                    "sentence_count": len(chunk_sentences),
                    "start_sentence": i,
                    "end_sentence": i + len(chunk_sentences) - 1,
                    "chunking_method": "simple_fallback",
                    "chunk_size_chars": len(chunk_text)
                })

                chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))

        return chunks

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """
        Main method: split documents into semantically coherent chunks.

        Args:
            documents: List of LangChain Document objects

        Returns:
            List of Document objects with semantic chunks
        """
        all_chunks = []
        self.stats["total_documents"] = len(documents)

        for doc_idx, doc in enumerate(documents):
            try:
                # Split the document into candidate sentences.
                sentences = self._split_into_sentences(doc.page_content)

                if not sentences:
                    if self.debug:
                        logger.warning(f"No sentences found in document {doc_idx}")
                    continue

                # Short documents become a single chunk.
                if len(sentences) < self.chunk_size:
                    chunk_text = ' '.join(sentences)
                    if len(chunk_text) >= self.min_chunk_size:
                        chunk_metadata = doc.metadata.copy()
                        chunk_metadata.update({
                            "chunk_index": 0,
                            "total_chunks": 1,
                            "sentence_count": len(sentences),
                            "chunking_method": "single_chunk",
                            "chunk_size_chars": len(chunk_text)
                        })
                        all_chunks.append(Document(page_content=chunk_text, metadata=chunk_metadata))
                    continue

                # Embed each sentence; fall back to simple chunking on failure.
                embeddings = self._get_embeddings(sentences)

                if embeddings is not None:
                    # Find semantic boundaries, then build chunks around them.
                    boundaries = self._calculate_semantic_boundaries(embeddings, sentences)
                    chunks = self._create_chunks_from_boundaries(sentences, boundaries, embeddings, doc.metadata)
                    method = "semantic"
                else:
                    chunks = self._create_simple_chunks(sentences, doc.metadata)
                    method = "simple_fallback"

                self.stats["chunking_methods"][method] = self.stats["chunking_methods"].get(method, 0) + 1

                # Annotate each chunk with document-level bookkeeping.
                for chunk in chunks:
                    chunk.metadata["total_chunks"] = len(chunks)
                    chunk.metadata["source_document_index"] = doc_idx

                all_chunks.extend(chunks)

                if self.debug:
                    logger.info(f"Document {doc_idx}: {len(sentences)} sentences → {len(chunks)} chunks ({method})")

            except Exception as e:
                logger.error(f"Error processing document {doc_idx}: {e}")
                if self.debug:
                    st.error(f"Error processing document {doc_idx}: {e}")

        # Aggregate statistics across all documents.
        self.stats["total_chunks"] = len(all_chunks)
        if all_chunks:
            chunk_sizes = [len(chunk.page_content) for chunk in all_chunks]
            self.stats["avg_chunk_size"] = sum(chunk_sizes) / len(chunk_sizes)

        if self.debug:
            logger.info(f"Created {len(all_chunks)} total chunks from {len(documents)} documents")

        return all_chunks

    def get_statistics(self) -> Dict[str, Any]:
        """Get chunking statistics for analysis."""
        return self.stats.copy()

    def display_statistics(self):
        """Display chunking statistics in Streamlit, falling back to console output."""
        try:
            with st.expander("📊 Semantic Chunking Statistics"):
                col1, col2 = st.columns(2)

                with col1:
                    st.metric("Total Documents", self.stats["total_documents"])
                    st.metric("Total Chunks", self.stats["total_chunks"])

                with col2:
                    st.metric("Avg Chunk Size", f"{self.stats['avg_chunk_size']:.0f} chars")
                    st.metric("Embedding Errors", self.stats["embedding_errors"])

                if self.stats["chunking_methods"]:
                    st.write("**Chunking Methods Used:**")
                    for method, count in self.stats["chunking_methods"].items():
                        percentage = (count / self.stats["total_documents"]) * 100 if self.stats["total_documents"] > 0 else 0
                        st.write(f"  - {method}: {count} documents ({percentage:.1f}%)")

                st.write("**Configuration:**")
                st.json({
                    "chunk_size": self.chunk_size,
                    "overlap": self.overlap,
                    "similarity_threshold": self.similarity_threshold,
                    "min_chunk_size": self.min_chunk_size,
                    "max_chunk_size": self.max_chunk_size,
                    "embedding_model": self._detect_embedding_model_type()
                })

        except Exception:
            # Streamlit is unavailable or not running in a script context; print instead.
            print("\n=== Semantic Chunking Statistics ===")
            print(f"Documents processed: {self.stats['total_documents']}")
            print(f"Chunks created: {self.stats['total_chunks']}")
            print(f"Average chunk size: {self.stats['avg_chunk_size']:.0f} characters")
            print(f"Embedding errors: {self.stats['embedding_errors']}")
            print(f"Chunking methods: {self.stats['chunking_methods']}")


def create_semantic_chunker(embeddings_model, **kwargs) -> SemanticChunker:
    """
    Convenience function to create a semantic chunker with sensible defaults.

    Args:
        embeddings_model: Your existing embeddings model
        **kwargs: Additional parameters to pass to SemanticChunker

    Returns:
        SemanticChunker instance ready to use
    """
    return SemanticChunker(embeddings_model=embeddings_model, **kwargs)
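

if __name__ == "__main__":
    # Minimal smoke test (illustrative only). _ToyEmbeddings is a stand-in so the
    # module can be exercised without downloading model weights; in a real pipeline
    # you would pass your existing BGE/OpenAI embeddings object instead.
    class _ToyEmbeddings:
        """Deterministic bag-of-characters embeddings; not semantically meaningful."""

        def embed_documents(self, texts):
            vectors = []
            for text in texts:
                vec = np.zeros(64)
                for char in text.lower():
                    vec[ord(char) % 64] += 1.0
                norm = np.linalg.norm(vec)
                vectors.append(vec / norm if norm > 0 else vec)
            return vectors

    demo_doc = Document(
        page_content=(
            "The library is open from 8am to 10pm on weekdays. "
            "Students may borrow up to ten books at a time. "
            "Late returns incur a small daily fine. "
            "The cafeteria serves breakfast from 7am. "
            "Lunch specials change every day of the week."
        ),
        metadata={"source": "demo"},
    )

    chunker = SemanticChunker(embeddings_model=_ToyEmbeddings(), min_chunk_size=20, debug=True)
    for chunk in chunker.split_documents([demo_doc]):
        print(chunk.metadata["chunking_method"], "|", chunk.page_content[:60])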