"""
This module provides functions for generating a highlighted PDF with important sentences.
The main function, `generate_highlighted_pdf`, takes an input PDF file and a pre-trained
sentence embedding model as input.
It splits the text of the PDF into sentences, computes sentence embeddings, and builds a
graph based on the cosine similarity between embeddings and at the same time split the
sentences to different clusters using clustering.
The sentences are then ranked using PageRank scores and a the middle of the cluster,
and important sentences are selected based on a threshold and clustering.
Finally, the selected sentences are highlighted in the PDF and the highlighted PDF content
is returned.
Other utility functions in this module include functions for loading a sentence embedding
model, encoding sentences, computing similarity matrices,building graphs, ranking sentences,
clustering sentence embeddings, and splitting text into sentences.
Note: This module requires the PyMuPDF, networkx, numpy, torch, sentence_transformers, and
sklearn libraries to be installed.
"""
import logging
from typing import BinaryIO, List, Optional, Tuple, Union
import fitz  # PyMuPDF
import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
# Constants
MAX_PAGE = 40  # Maximum number of pages processed per PDF
MAX_SENTENCES = 2000  # Maximum number of sentences processed per PDF
PAGERANK_THRESHOLD_RATIO = 0.15  # Fraction of top-ranked sentences to highlight
NUM_CLUSTERS_RATIO = 0.05  # Number of K-means clusters as a fraction of sentence count
MIN_WORDS = 10  # Minimum word count for a sentence to be kept
# Logger configuration
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
def load_sentence_model(revision: Optional[str] = None) -> SentenceTransformer:
"""
Load a pre-trained sentence embedding model.
Args:
        revision (Optional[str]): Optional model revision to load.
Returns:
SentenceTransformer: A pre-trained sentence embedding model.
"""
return SentenceTransformer("avsolatorio/GIST-Embedding-v0", revision=revision)
def encode_sentence(
    model: SentenceTransformer, sentences: Union[str, List[str]]
) -> torch.Tensor:
    """
    Encode one sentence or a list of sentences into fixed-dimensional vector
    representations.
    Args:
        model (SentenceTransformer): A pre-trained sentence embedding model.
        sentences (Union[str, List[str]]): Input sentence or list of sentences.
    Returns:
        torch.Tensor: Encoded sentence embeddings.
    """
    model.eval()  # Set the model to evaluation mode
    # Encode on the GPU when one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():  # Disable gradient tracking
        return model.encode(sentences, convert_to_tensor=True, device=device)
def compute_similarity_matrix(embeddings: torch.Tensor) -> np.ndarray:
    """
    Compute a row-normalized cosine similarity matrix between sentence embeddings.
    Args:
        embeddings (torch.Tensor): Sentence embeddings.
    Returns:
        np.ndarray: Cosine similarity matrix with each row normalized to sum to 1.
    """
    scores = F.cosine_similarity(
        embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=-1
    )
    similarity_matrix = scores.cpu().numpy()
    # Normalize each row so the matrix can serve as transition weights for PageRank
    normalized_adjacency_matrix = similarity_matrix / similarity_matrix.sum(
        axis=1, keepdims=True
    )
    return normalized_adjacency_matrix
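# Illustration (hypothetical shapes and values, not part of the pipeline): for
# three sentences the result is a (3, 3) matrix whose rows each sum to 1, e.g.
#   m = compute_similarity_matrix(torch.randn(3, 8))
#   assert m.shape == (3, 3) and np.allclose(m.sum(axis=1), 1.0)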
def build_graph(normalized_adjacency_matrix: np.ndarray) -> nx.DiGraph:
"""
Build a directed graph from a normalized adjacency matrix.
Args:
normalized_adjacency_matrix (np.ndarray): Normalized adjacency matrix.
Returns:
nx.DiGraph: Directed graph.
"""
return nx.DiGraph(normalized_adjacency_matrix)
def rank_sentences(graph: nx.DiGraph, sentences: List[str]) -> List[Tuple[str, float]]:
"""
Rank sentences based on PageRank scores.
Args:
graph (nx.DiGraph): Directed graph.
sentences (List[str]): List of sentences.
Returns:
List[Tuple[str, float]]: Ranked sentences with their PageRank scores.
"""
    pagerank_scores = nx.pagerank(graph)
    # Node i of the graph corresponds to sentences[i], so look up scores by
    # index rather than relying on the dict's iteration order
    ranked_sentences = sorted(
        zip(sentences, (pagerank_scores[i] for i in range(len(sentences)))),
        key=lambda x: x[1],
        reverse=True,
    )
    return ranked_sentences
def cluster_sentences(
embeddings: torch.Tensor, num_clusters: int
) -> Tuple[np.ndarray, np.ndarray]:
"""
Cluster sentence embeddings using K-means clustering.
Args:
embeddings (torch.Tensor): Sentence embeddings.
num_clusters (int): Number of clusters.
Returns:
Tuple[np.ndarray, np.ndarray]: Cluster assignments and cluster centers.
"""
    # n_init is pinned to the historical default for reproducibility across
    # scikit-learn versions
    kmeans = KMeans(n_clusters=num_clusters, n_init=10, random_state=42)
    cluster_assignments = kmeans.fit_predict(embeddings.cpu().numpy())
cluster_centers = kmeans.cluster_centers_
return cluster_assignments, cluster_centers
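# Illustration (hypothetical shapes): clustering 20 embeddings of dimension 8
# into 3 groups yields a (20,) label array and a (3, 8) array of centers:
#   labels, centers = cluster_sentences(torch.randn(20, 8), 3)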
def get_middle_sentence(cluster_indices: np.ndarray, sentences: List[str]) -> List[str]:
"""
Get the middle sentence from each cluster.
Args:
cluster_indices (np.ndarray): Cluster assignments.
sentences (List[str]): List of sentences.
Returns:
List[str]: Middle sentences from each cluster.
"""
    middle_indices = []
    for i in range(int(cluster_indices.max()) + 1):
        members = np.where(cluster_indices == i)[0]
        # Take the member at the median position so the chosen index is
        # guaranteed to belong to cluster i (the plain median of an
        # even-sized index set can fall between two members)
        middle_indices.append(int(members[len(members) // 2]))
middle_sentences = [sentences[i] for i in middle_indices]
return middle_sentences
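# Illustration (hypothetical values): with cluster_indices = [0, 1, 0, 1, 0],
# cluster 0 has members [0, 2, 4] and contributes sentences[2], while cluster 1
# has members [1, 3] and contributes sentences[3].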
def split_text_into_sentences(text: str, min_words: int = MIN_WORDS) -> List[str]:
"""
Split text into sentences.
Args:
text (str): Input text.
min_words (int): Minimum number of words for a valid sentence.
Returns:
List[str]: List of sentences.
"""
    sentences = []
    # Naive split on periods; abbreviations and decimal numbers are not
    # handled specially
    for s in text.split("."):
        s = s.strip()
        # Filter out short sentences and sentences where 40% or more of the
        # characters are digits (e.g. rows of numbers from tables)
        if (
            s
            and len(s.split()) >= min_words
            and (sum(c.isdigit() for c in s) / len(s)) < 0.4
        ):
            sentences.append(s)
    return sentences
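# Illustration (hypothetical input): fragments shorter than MIN_WORDS words are
# dropped, full sentences are kept:
#   >>> split_text_into_sentences(
#   ...     "Intro. This sentence has enough words to pass the ten word filter easily."
#   ... )
#   ['This sentence has enough words to pass the ten word filter easily']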
def extract_text_from_pages(doc: fitz.Document):
    """Yield the text of each page in turn, keeping memory use low for large PDFs."""
    for page in doc:
        yield page.get_text()
def generate_highlighted_pdf(
    input_pdf_file: BinaryIO, model: Optional[SentenceTransformer] = None
) -> bytes:
    """
    Generate a highlighted PDF with important sentences.
    Args:
        input_pdf_file: Input PDF file object.
        model (SentenceTransformer): Pre-trained sentence embedding model.
            Loaded on demand if not provided.
    Returns:
        bytes: Highlighted PDF content, or an error message string if the
            input exceeds the page or sentence limits.
    """
    if model is None:
        # Load lazily here instead of in the function signature, where the
        # default would be evaluated once at import time
        model = load_sentence_model()
    with fitz.open(stream=input_pdf_file.read(), filetype="pdf") as doc:
        num_pages = doc.page_count
        if num_pages > MAX_PAGE:
            # Error message shown to the user
            return f"The PDF file exceeds the maximum limit of {MAX_PAGE} pages."
        sentences = []
        for page_text in extract_text_from_pages(doc):  # Memory efficient
            sentences.extend(split_text_into_sentences(page_text))
        len_sentences = len(sentences)
        logger.info("Number of sentences: %d", len_sentences)
        if len_sentences > MAX_SENTENCES:
            # Error message shown to the user
            return (
                f"The PDF file exceeds the maximum limit of {MAX_SENTENCES} sentences."
            )
        embeddings = encode_sentence(model, sentences)
        similarity_matrix = compute_similarity_matrix(embeddings)
        graph = build_graph(similarity_matrix)
        ranked_sentences = rank_sentences(graph, sentences)
        pagerank_threshold = int(len(ranked_sentences) * PAGERANK_THRESHOLD_RATIO) + 1
        top_pagerank_sentences = [
            sentence[0] for sentence in ranked_sentences[:pagerank_threshold]
        ]
        num_clusters = int(len_sentences * NUM_CLUSTERS_RATIO) + 1
        cluster_assignments, _ = cluster_sentences(embeddings, num_clusters)
        center_sentences = get_middle_sentence(cluster_assignments, sentences)
        important_sentences = list(set(top_pagerank_sentences + center_sentences))
        colors = (fitz.pdfcolor["yellow"], fitz.pdfcolor["green"])
        for page_num in range(num_pages):
            try:
                page = doc[page_num]
                for sentence in important_sentences:
                    rects = page.search_for(sentence)
                    # Alternate highlight colors between consecutive rectangles;
                    # use a separate loop variable so the page index is not
                    # shadowed (the original reused `i` here)
                    for j, rect in enumerate(rects):
                        annot = page.add_highlight_annot(rect)
                        annot.set_colors(stroke=colors[j % 2])
                        annot.update()
            except Exception as e:
                logger.error(f"Error processing page {page_num}: {e}")
        output_pdf = doc.write()
    return output_pdf
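# ---------------------------------------------------------------------------
# Minimal usage sketch. The file names "example.pdf" and
# "example_highlighted.pdf" are assumptions for illustration; point them at
# any local PDF within the page and sentence limits.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with open("example.pdf", "rb") as f:
        result = generate_highlighted_pdf(f)
    if isinstance(result, bytes):
        with open("example_highlighted.pdf", "wb") as out:
            out.write(result)
    else:
        # On failure the function returns a user-facing error message string
        print(result)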