import base64
import os
import re
import time

import streamlit as st
import torch
from langchain.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from streamlit_tags import st_tags
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

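# Workflow: load the uploaded PDF with PyMuPDFLoader, clean and filter the extracted
# text, split it with RecursiveCharacterTextSplitter, summarize it with a Hugging Face
# "summarization" pipeline (LaMini-Flan-T5-77M or lsg-bart-base-16384-pubmed), and
# post-process the summary before displaying it in Streamlit.
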
def file_preprocessing(
    file, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
):
    """Extract, clean, and filter the PDF text, then split it into chunks."""
    loader = PyMuPDFLoader(file)
    pages = loader.load_and_split()

    # Optionally drop the first page (e.g., a cover page) and/or the last page
    if skipfirst:
        del pages[0]
    if skiplast:
        del pages[-1]

    input_text = ""
    for page in pages:
        input_text = input_text + page.page_content
    # Rejoin words hyphenated across line breaks, then flatten remaining newlines
    input_text = re.sub(r"-\n", "", input_text)
    input_text = re.sub(r"\n", " ", input_text)

    # Drop any sentence containing a word the user chose to exclude
    sentences = re.split(r"(?<=[.!?])\s+", input_text)
    valid_sentences = [
        sentence
        for sentence in sentences
        if not any(word in sentence for word in exclude_words)
    ]
    final_input_text = " ".join(valid_sentences)

    print("\n############## New article ##############\n")
    print("Cleaned and formatted input text:\n")
    print(final_input_text)
    print("\nExcluded words: " + str(exclude_words))
    print("\nChunking input text...\n")
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        separators=["\n\n", "\n", " ", ""],
    )
    text_chunks = text_splitter.split_text(final_input_text)
    print("Number of chunks: " + str(len(text_chunks)))
    for chunk in text_chunks:
        print("\n" + chunk)
    return final_input_text, text_chunks


def preprocessing_word_count(
    filepath, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
):
    """Preprocess the PDF and report the word count of the cleaned input text."""
    final_input_text, text_chunks = file_preprocessing(
        filepath, skipfirst, skiplast, chunk_size, chunk_overlap, exclude_words
    )
    text_length = len(re.findall(r"\w+", final_input_text))
    print(f"\nInput word count: {text_length:,}")
    print(f"Chunk size: {chunk_size:,}")
    print(f"Chunk overlap: {chunk_overlap}")
    return final_input_text, text_chunks, text_length


def llm_pipeline(
    tokenizer, base_model, final_input_text, model_source, minimum_token_number
):
    """Summarize the cleaned input text with the selected model."""
    summarizer = pipeline(
        task="summarization",
        model=base_model,
        tokenizer=tokenizer,
        truncation=True,
    )
    print("Model source: %s" % model_source)
    print("Summarizing...\n")
    result = summarizer(
        final_input_text,
        min_length=minimum_token_number,
        max_length=tokenizer.model_max_length,
    )
    summary = result[0]["summary_text"]
    print("Summary text:\n")
    print(summary)
    return summary


def postprocessing_word_count(summary):
    """Report the word count of the generated summary."""
    text_length = len(re.findall(r"\w+", summary))
    print(f"\nSummary word count: {text_length:,}")
    return text_length


def clean_summary_text(summary):
    """Normalize whitespace and spacing around punctuation in the summary.

    For example, "the model ( bart ) was used ." becomes "the model (bart) was used."
    """
    # Remove newline-plus-indentation artifacts and surrounding whitespace
    cleaned = re.sub(r"\n\s+", "", summary).strip()
    # Remove spaces before closing punctuation, e.g. "word ." -> "word."
    cleaned = re.sub(r"\s+([.,;:)!?](?:\s|$))", r"\1", cleaned)
    # Remove spaces after an opening parenthesis
    cleaned = re.sub(r"\(\s", r"(", cleaned)
    # Remove spaces between a closing parenthesis and following punctuation
    cleaned = re.sub(r"(\))\s+([,.:;?!])", r"\1\2", cleaned)
    return cleaned


def convert_to_sentence_case(summary):
    """Capitalize the first letter of each sentence.

    Note that str.capitalize() also lowercases the rest of each sentence.
    """
    sentences = re.split(r"(?<=[.!?])\s+", summary)
    formatted_sentences = [sentence.capitalize() for sentence in sentences]
    return " ".join(formatted_sentences)


def remove_duplicate_sentences(summary):
    """Drop repeated sentences while preserving their original order."""
    sentences = re.split(r"(?<=[.!?])\s+", summary)
    unique_sentences = set()
    valid_sentences = []
    for sentence in sentences:
        if sentence not in unique_sentences:
            unique_sentences.add(sentence)
            valid_sentences.append(sentence)
    return " ".join(valid_sentences)


def remove_incomplete_last_sentence(summary):
    """Drop a trailing fragment that does not end with terminal punctuation."""
    sentences = re.split(r"(?<=[.!?])\s+", summary)
    if (
        sentences
        and sentences[-1].strip()
        and not sentences[-1].strip().endswith((".", "!", "?"))
    ):
        sentences.pop()
    return " ".join(sentences)


@st.cache_data(ttl=60 * 60)
def displayPDF(file):
    """Embed the uploaded PDF in an iframe viewer."""
    with open(file, "rb") as f:
        base64_pdf = base64.b64encode(f.read()).decode("utf-8")
    pdf_display = (
        f'<iframe src="data:application/pdf;base64,{base64_pdf}" '
        'width="100%" height="600" type="application/pdf"></iframe>'
    )
    st.markdown(pdf_display, unsafe_allow_html=True)


# Page configuration must be set before any other Streamlit element is rendered
st.set_page_config(layout="wide")


def main():
    st.title("RASA: Research Article Summarization App")
    uploaded_file = st.file_uploader("Upload your PDF file", type=["pdf"])
    if uploaded_file is not None:
        st.subheader("Options")
        col1, col2, col3, col4 = st.columns([1, 1, 1, 2])
        with col1:
            model_source_names = ["Cached model", "Download model"]
            model_source = st.radio(
                "For development:",
                model_source_names,
                help="Defaults to a cached model; downloading will take longer",
            )
        with col2:
            model_names = [
                "T5-Small",
                "BART",
            ]
            selected_model = st.radio(
                "Select a model to use:",
                model_names,
            )
            if selected_model == "BART":
                chunk_size = 800
                chunk_overlap = 80
                checkpoint = "ccdv/lsg-bart-base-16384-pubmed"
                tokenizer = AutoTokenizer.from_pretrained(
                    checkpoint,
                    truncation=True,
                    model_max_length=512,
                    trust_remote_code=True,
                )
                if model_source == "Download model":
                    base_model = AutoModelForSeq2SeqLM.from_pretrained(
                        checkpoint,
                        torch_dtype=torch.float32,
                        trust_remote_code=True,
                    )
                else:
                    base_model = "model_cache/models--ccdv--lsg-bart-base-16384-pubmed/snapshots/4072bc1a7a94e2b4fd860a5fdf1b71d0487dcf15"
            else:
                chunk_size = 1000
                chunk_overlap = 100
                checkpoint = "MBZUAI/LaMini-Flan-T5-77M"
                tokenizer = AutoTokenizer.from_pretrained(
                    checkpoint,
                    truncation=True,
                    legacy=False,
                    model_max_length=512,
                )
                if model_source == "Download model":
                    base_model = AutoModelForSeq2SeqLM.from_pretrained(
                        checkpoint,
                        torch_dtype=torch.float32,
                    )
                else:
                    base_model = "model_cache/models--MBZUAI--LaMini-Flan-T5-77M/snapshots/c5b12d50a2616b9670a57189be20055d1357b474"
        with col3:
            st.write("Skip any pages?")
            skipfirst = st.checkbox(
                "Skip first page", help="Select if your PDF has a cover page"
            )
            skiplast = st.checkbox("Skip last page")
        with col4:
            st.write("Background information (links open in a new window)")
            st.write(
                "Model class: [T5-Small](https://huggingface.co/docs/transformers/main/en/model_doc/t5)"
                " | Model: [LaMini-Flan-T5-77M](https://huggingface.co/MBZUAI/LaMini-Flan-T5-77M)"
            )
            st.write(
                "Model class: [BART](https://huggingface.co/docs/transformers/main/en/model_doc/bart)"
                " | Model: [lsg-bart-base-16384-pubmed](https://huggingface.co/ccdv/lsg-bart-base-16384-pubmed)"
            )
        exclude_words = st_tags(
            label="Enter word(s) to exclude from the summary:",
            text="Press enter to add",
        )
        col1, col2, col3 = st.columns([1, 1, 5])
        with col1:
            minimum_token_number = st.number_input(
                "Minimum number of tokens",
                value=200,
                step=25,
                min_value=0,
                max_value=512,
                help="Use a larger number of tokens to increase summary length",
            )
        with col3:
            st.subheader("Notes")
            st.write(
                "To remove content from the summary, copy and paste the word(s) "
                "to exclude into the box above and summarize again."
            )
            st.write(
                "To lengthen or shorten the summary, increase or decrease the "
                "minimum number of tokens to the left and summarize again."
            )
        if st.button("Summarize"):
            col1, col2 = st.columns(2)
            # Save the upload locally so the loaders can read it from disk
            os.makedirs("data", exist_ok=True)
            filepath = "data/" + uploaded_file.name
            with open(filepath, "wb") as temp_file:
                temp_file.write(uploaded_file.read())
            with col1:
                (
                    final_input_text,
                    text_chunks,
                    preprocessing_text_length,
                ) = preprocessing_word_count(
                    filepath,
                    skipfirst,
                    skiplast,
                    chunk_size,
                    chunk_overlap,
                    exclude_words,
                )
                st.info(
                    "Uploaded PDF | Number of words: "
                    f"{preprocessing_text_length:,}"
                )
                displayPDF(filepath)
            with col2:
                start = time.time()
                with st.spinner("Summarizing..."):
                    summary = llm_pipeline(
                        tokenizer,
                        base_model,
                        final_input_text,
                        model_source,
                        minimum_token_number,
                    )
                postprocessing_text_length = postprocessing_word_count(summary)
                end = time.time()
                duration = end - start
                print(f"Duration: {duration:.0f} seconds")
                st.info(
                    "PDF Summary | Number of words: "
                    f"{postprocessing_text_length:,}"
                    " | Summarization time: "
                    f"{duration:.0f} seconds"
                )
                if selected_model == "BART":
                    # BART output gets extra cleanup: punctuation spacing, sentence
                    # case, de-duplication, and removal of a trailing fragment
                    summary_cleaned = clean_summary_text(summary)
                    summary_cleaned_sentence_case = convert_to_sentence_case(
                        summary_cleaned
                    )
                    summary_cleaned_sentence_case_dedup = remove_duplicate_sentences(
                        summary_cleaned_sentence_case
                    )
                    summary_cleaned_final = remove_incomplete_last_sentence(
                        summary_cleaned_sentence_case_dedup
                    )
                    st.success(summary_cleaned_final)
                    with st.expander("Unformatted output"):
                        st.write(summary)
                else:
                    summary_dedup = remove_duplicate_sentences(summary)
                    summary_final = remove_incomplete_last_sentence(summary_dedup)
                    st.success(summary_final)
                    with st.expander("Unformatted output"):
                        st.write(summary)
            url = "https://dev.to/eteimz/understanding-langchains-recursivecharactertextsplitter-2846"
            st.info("Additional information")
            input_ids = tokenizer.encode(
                final_input_text, add_special_tokens=True, truncation=True
            )
            st.write(
                "Maximum number of tokens generated for inputs into the model: %s"
                % f"{len(input_ids):,}"
            )
            st.write("First 10 tokens:")
            st.write(tokenizer.convert_ids_to_tokens(input_ids[:10]))
            st.write("First 500 tokens:")
            st.write(tokenizer.convert_ids_to_tokens(input_ids[:500]))
            st.write("[RecursiveCharacterTextSplitter](%s) parameters used:" % url)
            st.write(" chunk_size=%s" % chunk_size)
            st.write(" chunk_overlap=%s" % chunk_overlap)
            st.write(" length_function=len")
            st.write("Number of input text chunks: " + str(len(text_chunks)))
            st.write("First three chunks:")
            for chunk in text_chunks[:3]:
                st.write(chunk)
                st.write("")
            st.write(
                "Extracted and cleaned text, less sentences containing excluded words:"
            )
            st.write(final_input_text)


# Minor CSS overrides for widget label font size/weight, spacing, and link underlines
st.markdown(
    """<style>
    div[class*="stRadio"] > label > div[data-testid="stMarkdownContainer"] > p {
        font-size: 1rem;
        font-weight: 400;
    }
    div[class*="stMarkdown"] > div[data-testid="stMarkdownContainer"] > p {
        margin-bottom: -15px;
    }
    div[class*="stCheckbox"] > label[data-baseweb="checkbox"] {
        margin-bottom: -15px;
    }
    div[class*="stNumberInput"] > label > div[data-testid="stMarkdownContainer"] > p {
        font-size: 1rem;
        font-weight: 400;
    }
    body > a {
        text-decoration: underline;
    }
    </style>
    """,
    unsafe_allow_html=True,
)


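# To run the app locally (assuming this file is saved as app.py):
#   streamlit run app.py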
if __name__ == "__main__":
    main()