# Extractive_Text_Summarization / utils / extractive_summarization_textrank.py
# Author: blacksmithop — "Update utils/extractive_summarization_textrank.py" (commit 1ee348e, verified)
import spacy
import pytextrank
import tiktoken
from nltk import download, sent_tokenize
from math import ceil, log
from typing import Literal
"""
Ensure Spacy, pytextrank, nltk, and tiktoken are installed.
Download spacy model:
spacy download en_core_web_sm
"""
# Install nltk tokenizer data (sentence splitter used by sent_tokenize)
download('punkt_tab')
# Load the spaCy model and add the pytextrank pipeline
# (exposes doc._.textrank on processed documents)
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("textrank")
# Constants
MAX_RETRY_COUNT = 5  # Max attempts to hit the token target before giving up
SENTENCE_BUFFER_INCREMENT = 5  # Step size when growing/shrinking the summary
ALLOWED_TOKEN_DEVIATION = 250  # Acceptable |target - actual| token difference
# Initialize the token encoder for the specified model
encoder = tiktoken.encoding_for_model("gpt-4o") # Change model name if needed
def get_tokens_with_count(text: str):
    """Encode *text* with the module-level tiktoken encoder.

    Returns:
        tuple: The encoded token ids and how many there are.
    """
    encoded = encoder.encode(text)
    token_total = len(encoded)
    return encoded, token_total
def calculate_avg_tokens_per_sentence(text: str):
    """
    Calculates the average number of tokens per sentence in the given text.
    Args:
    text (str): The input text.
    Returns:
    float: The average number of tokens per sentence (0 for empty input).
    """
    sentence_list = sent_tokenize(text)
    # Guard clause: empty or malformed text has no sentences, so report 0
    # instead of dividing by zero.
    if not sentence_list:
        return 0
    _, token_total = get_tokens_with_count(text)
    return token_total / len(sentence_list)
def calculate_summary_length(num_sentences: int):
    """Map a source sentence count to a target summary sentence count.

    Small inputs (<= 12 sentences) keep roughly two thirds of the text;
    larger inputs are compressed by a logarithmically growing divisor,
    with a floor of 15 sentences.

    Raises:
        ValueError: If *num_sentences* is not positive.
    """
    if num_sentences <= 0:
        raise ValueError("Number of sentences must be greater than zero.")
    # Small corpus: keep about two thirds, but never fewer than one sentence.
    if num_sentences <= 12:
        return max(1, ceil(num_sentences * 2 / 3))
    # Large corpus: compression grows with log2 of the corpus size.
    divisor = 10 + log(num_sentences, 2)
    return max(15, ceil(num_sentences / divisor))
def get_sentence_count(text: str, token_count: int):
    """
    Dynamically calculates the initial number of sentences for the summary based on the token limit.
    Args:
    text (str): The input text.
    token_count (int): Target token limit for the summary.
    Returns:
    int: Suggested number of sentences for the summary.
    """
    avg_tokens = calculate_avg_tokens_per_sentence(text)
    # Degenerate average (empty text): fall back to a default of 15 tokens.
    if avg_tokens < 1:
        avg_tokens = 15
    total_sentences = len(sent_tokenize(text))
    # How many average-length sentences fit in the token budget,
    # clamped to at least 1 and at most the number of sentences available.
    fitting = max(1, token_count // avg_tokens)
    estimated = min(total_sentences, fitting)
    return calculate_summary_length(num_sentences=estimated)
def summarize(tr, sentence_count: int, level: str):
    """Generates a summary string from the pytextrank results.

    Args:
        tr: The ``doc._.textrank`` object produced by the pipeline.
        sentence_count (int): Maximum number of sentences to include.
        level (str): Granularity passed through to pytextrank
            ("sentence" or "paragraph").

    Returns:
        str: The selected sentences joined in document order.
    """
    selected = tr.summary(
        limit_sentences=sentence_count, preserve_order=True, level=level
    )
    # Each extracted sentence already ends with its own punctuation, so join
    # with a plain space — ". ".join would produce doubled periods like
    # "First sentence.. Second sentence."
    return " ".join(str(sent) for sent in selected)
def get_textrank_summary(
    text: str,
    token_count: int = 100,
    level: Literal["sentence", "paragraph"] = "sentence",
    verbose: bool = True,
):
    """
    Generates a textrank-based summary close to the specified token limit.

    Starts from an estimated sentence count (plus a buffer), then retries up
    to MAX_RETRY_COUNT times, shrinking or growing the sentence count by
    SENTENCE_BUFFER_INCREMENT until the summary's token count is within
    ALLOWED_TOKEN_DEVIATION of the target.

    Args:
        text (str): The input text.
        token_count (int): Desired token limit for the summary.
        level (Literal["sentence", "paragraph"]): Granularity of the summary.
        verbose (bool): Whether to also return retry/diagnostic information.
    Returns:
        tuple: ``(summary, verbose_message)`` — the generated summary plus a
        markdown diagnostics string when ``verbose`` is True, otherwise ``""``.
    """
    # Analyze the text with spaCy and extract textrank data
    doc = nlp(text)
    tr = doc._.textrank
    # Determine initial sentence count; pad it so the first attempt tends to
    # overshoot the target and the loop converges by shrinking.
    sentence_count = get_sentence_count(text=text, token_count=token_count)
    sentence_count += SENTENCE_BUFFER_INCREMENT
    retry_count = 0
    summary_content = ""
    # The loop body always runs at least once, so summary_token_count and
    # deviation are guaranteed to be bound when referenced after the loop.
    while retry_count <= MAX_RETRY_COUNT:
        summary_content = summarize(tr=tr, sentence_count=sentence_count, level=level)
        _, summary_token_count = get_tokens_with_count(text=summary_content)
        deviation = abs(token_count - summary_token_count)
        if deviation <= ALLOWED_TOKEN_DEVIATION:
            break  # Close enough to the token target — accept this summary
        elif summary_token_count > token_count:
            # Summary too long: drop sentences, but never go below one
            sentence_count = max(1, sentence_count - SENTENCE_BUFFER_INCREMENT)
        else:
            # Summary too short: add more sentences
            sentence_count += SENTENCE_BUFFER_INCREMENT
        retry_count += 1
    if retry_count > MAX_RETRY_COUNT:
        print("Warning: Max retries reached. Summary may not meet token requirements.")
    if verbose:
        verbose_message = f"**Token Count:** {token_count} | **Sentence Count:** {sentence_count} | **Summary Token Count:** {summary_token_count} | **Token count deviation:** {deviation}"
        return summary_content, verbose_message
    return summary_content, ""