Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import torch | |
import re | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
from peft import PeftModel | |
from text_processing import TextProcessor | |
import gc | |
from pathlib import Path | |
# Page-level Streamlit settings: browser tab title, tab icon and a wide layout.
# NOTE(review): the page_icon glyph looks like a mis-decoded emoji; confirm it
# against the original file before changing it.
st.set_page_config(page_title="Biomedical Papers Analysis", page_icon="π¬", layout="wide")
# Seed every session-state key the app relies on, without overwriting values
# that an earlier script run already stored.
_SESSION_DEFAULTS = {
    'processed_data': None,
    'summaries': None,
    'text_processor': None,
    'processing_started': False,
    'focused_summary_generated': False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def load_model(model_type):
    """Load the model/tokenizer pair for the requested task on the CPU.

    Args:
        model_type: "summarize" selects the fine-tuned BART summarizer; any
            other value selects the BioBART base model wrapped with the
            question-focused PEFT adapter.

    Returns:
        A ``(model, tokenizer)`` tuple with the model switched to eval mode.

    Raises:
        Exception: any loading error is shown with ``st.error`` and re-raised.
    """
    models_dir = "./models"
    try:
        # Free whatever an earlier model left behind before loading a new one.
        gc.collect()
        torch.cuda.empty_cache()

        device = "cpu"  # Force CPU usage

        if model_type == "summarize":
            checkpoint = "pendar02/bart-large-pubmedd"
            model = AutoModelForSeq2SeqLM.from_pretrained(
                checkpoint,
                cache_dir=models_dir,
                torch_dtype=torch.float32,
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(checkpoint, cache_dir=models_dir)
        else:  # question_focused
            base_checkpoint = "GanjinZero/biobart-base"
            backbone = AutoModelForSeq2SeqLM.from_pretrained(
                base_checkpoint,
                cache_dir=models_dir,
                torch_dtype=torch.float32,
            ).to(device)
            # Attach the fine-tuned adapter weights in inference-only mode.
            model = PeftModel.from_pretrained(
                backbone,
                "pendar02/biobart-finetune",
                is_trainable=False,
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(base_checkpoint, cache_dir=models_dir)

        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise
def cleanup_model(model, tokenizer):
    """Release model resources on a best-effort basis.

    ``del`` only drops this function's own references; the caller must also
    let go of its ``model``/``tokenizer`` names before the memory can be
    reclaimed.

    Args:
        model: the model object that is no longer needed.
        tokenizer: the tokenizer object that is no longer needed.

    Returns:
        None. Cleanup failures are deliberately ignored so that a failed
        cleanup never breaks the page.
    """
    try:
        del model
        del tokenizer
        # Collect garbage first (same order as load_model) so that memory held
        # by unreachable objects is freed before the CUDA cache is emptied.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        # Best-effort: cleanup is an optimisation, not a correctness step.
        pass
def process_excel(uploaded_file):
    """Read the uploaded Excel file and keep only the columns the app uses.

    Args:
        uploaded_file: the file object handed to ``pd.read_excel``.

    Returns:
        A DataFrame restricted to the required columns, or ``None`` when the
        file cannot be read or a required column is missing (the problem is
        reported through ``st.error``).
    """
    required_columns = ['Abstract', 'Article Title', 'Authors',
                        'Source Title', 'Publication Year', 'DOI', 'Times Cited, All Databases']
    try:
        frame = pd.read_excel(uploaded_file)

        # Collect absent columns in the order they are required.
        absent = []
        for name in required_columns:
            if name not in frame.columns:
                absent.append(name)

        if not absent:
            return frame[required_columns]

        st.error(f"Missing required columns: {', '.join(absent)}")
        return None
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return None
def preprocess_text(text):
    """Put each sentence of `text` on its own line before summarization.

    Non-string or blank input is returned unchanged. Sentences are detected
    naively: every ". " is treated as a sentence boundary.
    """
    is_usable = isinstance(text, str) and bool(text.strip())
    if not is_usable:
        return text

    # Turn sentence boundaries into line breaks, then trim each line.
    lines = text.replace('. ', '.\n').split('\n')
    trimmed = (line.strip() for line in lines)

    # Drop lines that are empty after trimming and glue the rest back together.
    return '\n'.join(line for line in trimmed if line)
def post_process_summary(summary):
    """Clean up a generated summary and normalise its sentence formatting.

    The text is split into sentences, known generation artefacts are removed,
    whitespace runs are collapsed, each sentence gets an upper-case first
    character, and the sentences are re-joined with ". " and a final period.

    Args:
        summary: the raw summary text; falsy values are returned unchanged.

    Returns:
        The cleaned summary string.
    """
    if not summary:
        return summary

    # Split only on periods that end a sentence (followed by whitespace or the
    # end of the text) so decimals such as "3.5%" stay in one piece.
    raw_sentences = re.split(r'\.(?:\s+|$)', summary)

    processed_sentences = []
    for sentence in raw_sentences:
        # Collapse every whitespace run to a single space and trim the ends.
        sentence = " ".join(sentence.split())

        # Remove redundant words/phrases
        sentence = sentence.replace(" and and ", " and ")
        sentence = sentence.replace("appointment and appointment", "appointment")

        # Fix common grammatical issues
        sentence = sentence.replace("Cancers distress", "Cancer distress")

        # Drop leftover trailing periods (e.g. from "..."); one is re-added below.
        sentence = sentence.rstrip('. ')
        if not sentence:
            continue

        # Upper-case only the first character: str.capitalize() would also
        # lower-case the rest and destroy acronyms such as "DNA".
        processed_sentences.append(sentence[0].upper() + sentence[1:])

    # Join sentences with proper spacing and punctuation
    cleaned_summary = '. '.join(processed_sentences)
    if cleaned_summary and not cleaned_summary.endswith('.'):
        cleaned_summary += '.'
    return cleaned_summary
def _summarize_once(model, tokenizer, inputs, **generation_options):
    """Run one generation pass and return the post-processed summary text."""
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **generation_options,
        )
    decoded = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return post_process_summary(decoded)


def improve_summary_generation(text, model, tokenizer):
    """Summarize one abstract, retrying once if the first summary fails validation.

    Args:
        text: the abstract; non-string or blank input yields a placeholder message.
        model: the sequence-to-sequence model used for generation.
        tokenizer: the tokenizer matching `model`.

    Returns:
        The post-processed summary. The retry result is returned as-is; it is
        not validated again.
    """
    if not (isinstance(text, str) and text.strip()):
        return "No abstract available to summarize."

    # Structured prompt placed in front of the sentence-per-line abstract.
    instructions = (
        "Summarize this medical research paper following this structure exactly:\n"
        "1. Background and objectives\n"
        "2. Methods\n"
        "3. Key findings with specific numbers/percentages\n"
        "4. Main conclusions\n"
        "Original text: "
    )
    prompt = instructions + preprocess_text(text)

    encoded = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    first_attempt = _summarize_once(
        model, tokenizer, inputs,
        max_length=200,
        min_length=50,
        num_beams=5,
        length_penalty=1.5,
        no_repeat_ngram_size=3,
        temperature=0.7,
        repetition_penalty=1.5,
    )
    if validate_summary(first_attempt, text):
        return first_attempt

    # Validation failed: try one more time with different parameters.
    return _summarize_once(
        model, tokenizer, inputs,
        max_length=200,
        min_length=50,
        num_beams=4,
        length_penalty=2.0,
        no_repeat_ngram_size=4,
        temperature=0.8,
        repetition_penalty=2.0,
    )
def validate_summary(summary, original_text):
    """Run heuristic sanity checks on a generated summary.

    Args:
        summary: the generated summary text.
        original_text: the text the summary was generated from.

    Returns:
        False when the summary mentions "<number> year(s)" more than once,
        repeats sentences more than once, has fewer than 20 words, or is
        longer than 80% of the original's word count; True otherwise.
    """
    # Check for age inconsistencies
    age_mentions = re.findall(r'(\d+\.?\d*)\s*years?', summary.lower())
    if len(age_mentions) > 1:  # Multiple age mentions
        return False

    # Check for repetitive sentences. Empty pieces (e.g. after the final
    # period) are dropped before counting so they are not mistaken for
    # duplicates.
    sentences = [s.strip().lower() for s in summary.split('.') if s.strip()]
    if len(sentences) - len(set(sentences)) > 1:  # More than one duplicate
        return False

    # Check summary isn't too long or too short compared to original
    summary_words = len(summary.split())
    original_words = len(original_text.split())
    if summary_words < 20 or summary_words > original_words * 0.8:
        return False

    return True
def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate one summary of `abstracts` that is focused on `question`.

    The abstracts are preprocessed, joined with " [SEP] ", prefixed with the
    question, truncated to 1024 tokens and passed to beam-search generation.

    Returns:
        The decoded text of the first generated sequence.
    """
    cleaned_abstracts = list(map(preprocess_text, abstracts))
    prompt = f"Question: {question} Abstracts: " + " [SEP] ".join(cleaned_abstracts)

    encoded = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    # Move every input tensor to the device the model lives on.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=200,
            min_length=50,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True,
        )

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def _split_authors(cell):
    """Return the individual, trimmed author names in one 'Authors' cell."""
    if not isinstance(cell, str):
        return []
    return [name.strip() for name in cell.split(';') if name.strip()]


def _numeric_range_filter(df, column, from_label, to_label):
    """Show From/To number inputs for `column` and return the rows in that range."""
    values = df[column].dropna()
    if values.empty:
        # Nothing to build a range from (empty frame or all values missing).
        return df
    lower = int(values.min())
    upper = int(values.max())
    left, right = st.columns(2)
    with left:
        start = st.number_input(from_label, min_value=lower, max_value=upper, value=lower)
    with right:
        end = st.number_input(to_label, min_value=lower, max_value=upper, value=upper)
    return df[(df[column] >= start) & (df[column] <= end)]


def create_filter_controls(df, sort_column):
    """Create the filter widgets for `sort_column` and return the filtered rows.

    Args:
        df: the display DataFrame; it is not modified.
        sort_column: the column chosen in the "Sort/Filter by" selector.
            'Publication Year' and 'Times Cited' get a numeric range filter,
            'Authors' and 'Source Title' get a multi-select, and any other
            value (including 'Article Title') applies no filter.

    Returns:
        A DataFrame holding the rows of `df` that pass the selected filter.
    """
    filtered_df = df.copy()

    if sort_column == 'Publication Year':
        return _numeric_range_filter(filtered_df, 'Publication Year', 'From Year', 'To Year')

    if sort_column == 'Times Cited':
        return _numeric_range_filter(
            filtered_df, 'Times Cited', 'From Cited Count', 'To Cited Count'
        )

    if sort_column == 'Authors':
        # Multi-select over the individual names found in the ';'-separated cells.
        unique_authors = sorted(
            {name for cell in df['Authors'] for name in _split_authors(cell)}
        )
        selected_authors = st.multiselect('Select Authors', unique_authors)
        if selected_authors:
            wanted = set(selected_authors)
            # Compare whole names: a substring test would let "Li, J" also
            # match "Li, Jia".
            keep = filtered_df['Authors'].apply(
                lambda cell: not wanted.isdisjoint(_split_authors(cell))
            )
            filtered_df = filtered_df[keep]
    elif sort_column == 'Source Title':
        # Missing titles are dropped first: sorting NaN together with strings raises.
        unique_sources = sorted(df['Source Title'].dropna().unique())
        selected_sources = st.multiselect('Select Sources', unique_sources)
        if selected_sources:
            filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]

    # 'Article Title' (and any unknown column): only sorting, no filtering.
    return filtered_df
_PAPER_CARD_CSS = """
<style>
.paper-info {
border: 1px solid #ddd;
padding: 15px;
margin-bottom: 20px;
border-radius: 5px;
}
.paper-section {
margin-bottom: 10px;
}
.section-header {
font-weight: bold;
color: #555;
margin-bottom: 8px;
}
.paper-title {
margin-top: 5px;
margin-bottom: 10px;
}
.paper-meta {
font-size: 0.9em;
color: #666;
}
.doi-link {
color: #0366d6;
}
</style>
"""


def _reset_analysis_state():
    """Discard every result derived from a previously uploaded file."""
    st.session_state.processed_data = None
    st.session_state.summaries = None
    st.session_state.processing_started = False
    st.session_state.focused_summary_generated = False


def _generate_individual_summaries(df):
    """Summarize every abstract in `df` and return the summaries as a list."""
    model, tokenizer = load_model("summarize")
    progress_bar = st.progress(0)
    try:
        total = len(df)
        summaries = []
        for done, abstract in enumerate(df['Abstract'], start=1):
            summaries.append(improve_summary_generation(abstract, model, tokenizer))
            progress_bar.progress(done / total)
        return summaries
    finally:
        # Runs on failure too, so the model and the progress bar never linger.
        progress_bar.empty()
        cleanup_model(model, tokenizer)


def _render_paper_cards(papers):
    """Render one PAPER / SUMMARY row for each row of `papers`."""
    import html  # local import: only this helper needs it

    def esc(value):
        # Cell values come from the uploaded file and summaries from the
        # model; escape both before placing them inside raw HTML.
        return html.escape(str(value))

    for _, row in papers.iterrows():
        doi = row['DOI'] if pd.notna(row['DOI']) else 'None'
        paper_col, summary_col = st.columns([1, 1])
        with paper_col:  # PAPER column
            st.markdown('<div class="paper-section"><div class="section-header">PAPER</div>', unsafe_allow_html=True)
            st.markdown(
                '<div class="paper-info">\n'
                f'<div class="paper-title">{esc(row["Article Title"])}</div>\n'
                '<div class="paper-meta">\n'
                f'<strong>Authors:</strong> {esc(row["Authors"])}<br>\n'
                f'<strong>Source:</strong> {esc(row["Source Title"])}<br>\n'
                f'<strong>Publication Year:</strong> {esc(row["Publication Year"])}<br>\n'
                f'<strong>Times Cited:</strong> {esc(row["Times Cited"])}<br>\n'
                f'<strong>DOI:</strong> {esc(doi)}\n'
                '</div>\n'
                '</div>',
                unsafe_allow_html=True,
            )
        with summary_col:  # SUMMARY column
            st.markdown('<div class="paper-section"><div class="section-header">SUMMARY</div>', unsafe_allow_html=True)
            st.markdown(
                '<div class="paper-info">\n'
                f'{esc(row["Summary"])}\n'
                '</div>',
                unsafe_allow_html=True,
            )
        # Add spacing between papers
        st.markdown("<div style='margin-bottom: 20px;'></div>", unsafe_allow_html=True)


def _render_summaries_section(df):
    """Show the sort/filter controls and the per-paper summaries."""
    col1, col2 = st.columns(2)
    with col1:
        sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title', 'Times Cited']
        sort_column = st.selectbox("Sort/Filter by:", sort_options)
    with col2:
        if sort_column == 'Article Title':
            ascending = st.radio(
                "Sort order",
                ["A to Z", "Z to A"],
                horizontal=True
            ) == "A to Z"
        elif sort_column == 'Times Cited':
            ascending = st.radio(
                "Sort order",
                ["Most cited", "Least cited"],
                horizontal=True
            ) == "Least cited"
        else:
            ascending = True  # Default for other columns

    # Create display dataframe
    display_df = df.copy()
    display_df['Summary'] = st.session_state.summaries
    display_df['Publication Year'] = display_df['Publication Year'].astype(int)
    display_df.rename(columns={'Times Cited, All Databases': 'Times Cited'}, inplace=True)
    display_df['Times Cited'] = display_df['Times Cited'].fillna(0).astype(int)

    # Apply filters
    filtered_df = create_filter_controls(display_df, sort_column)

    if sort_column in ('Article Title', 'Times Cited'):
        # These two columns offer a sort-order radio, so honour its choice.
        sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending)
    else:
        # Keep original order for other columns after filtering
        sorted_df = filtered_df

    # Show number of filtered results
    if len(sorted_df) != len(display_df):
        st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers")

    st.markdown(_PAPER_CARD_CSS, unsafe_allow_html=True)
    _render_paper_cards(sorted_df)


def _render_focused_summary(question, df):
    """Generate (when needed) and display the question-focused summary."""
    st.header("β Question-focused Summary")

    # Regenerate when nothing has been generated yet or when the stored
    # summary answers a different question than the one currently entered.
    is_current = (
        st.session_state.get('focused_summary_generated', False)
        and st.session_state.get('focused_question') == question
    )
    if not is_current:
        st.session_state.focused_summary_generated = False
        try:
            with st.spinner("Analyzing relevant papers..."):
                # Initialize text processor if needed
                if st.session_state.text_processor is None:
                    st.session_state.text_processor = TextProcessor()

                # Find relevant abstracts
                results = st.session_state.text_processor.find_most_relevant_abstracts(
                    question,
                    df['Abstract'].tolist(),
                    top_k=5
                )

                # Load question-focused model
                model, tokenizer = load_model("question_focused")
                try:
                    relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
                    focused_summary = generate_focused_summary(
                        question,
                        relevant_abstracts,
                        model,
                        tokenizer
                    )
                finally:
                    # Cleanup second model, even when generation fails.
                    cleanup_model(model, tokenizer)

                # Store results
                st.session_state.focused_summary = focused_summary
                st.session_state.relevant_papers = df.iloc[results['top_indices']]
                st.session_state.relevance_scores = results['scores']
                st.session_state.focused_question = question
                st.session_state.focused_summary_generated = True
        except Exception as e:
            st.error(f"Error generating focused summary: {str(e)}")

    # Display focused summary results
    if st.session_state.get('focused_summary_generated', False):
        st.subheader("Summary")
        st.write(st.session_state.focused_summary)

        st.subheader("Most Relevant Papers")
        relevant_papers = st.session_state.relevant_papers[
            ['Article Title', 'Authors', 'Publication Year', 'DOI']
        ].copy()
        relevant_papers['Relevance Score'] = st.session_state.relevance_scores
        relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
        st.dataframe(relevant_papers, hide_index=True)


def main():
    """Run the Streamlit page: upload, per-paper summaries, focused summary.

    Flow: the user uploads an Excel file, optionally enters a research
    question and presses "Start Analysis". Individual summaries are generated
    once and kept in session state; when a question is present, a
    question-focused summary of the most relevant abstracts follows.

    NOTE(review): the glyphs at the start of several labels ("π¬", "π", "β")
    look like mis-decoded emoji; they are kept verbatim here. Confirm against
    the original file.
    """
    st.title("π¬ Biomedical Papers Analysis")

    # File upload section
    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )

    # Question input - moved up but hidden initially
    question_container = st.empty()
    question = ""

    if uploaded_file is None:
        return

    # A different file invalidates everything computed for the previous one.
    file_key = (getattr(uploaded_file, 'name', None), getattr(uploaded_file, 'size', None))
    if st.session_state.get('uploaded_file_key') != file_key:
        _reset_analysis_state()
        st.session_state.uploaded_file_key = file_key

    # Process Excel file
    if st.session_state.processed_data is None:
        with st.spinner("Processing file..."):
            df = process_excel(uploaded_file)
            if df is not None:
                st.session_state.processed_data = df.dropna(subset=["Abstract"])

    if st.session_state.processed_data is None:
        return

    df = st.session_state.processed_data
    st.write(f"π Loaded {len(df)} papers with abstracts")

    # Get question before processing
    with question_container:
        question = st.text_input(
            "Enter your research question (optional):",
            help="If provided, a question-focused summary will be generated after individual summaries"
        )

    # Single button for both processes
    if not st.session_state.get('processing_started', False):
        if st.button("Start Analysis"):
            st.session_state.processing_started = True

    if not st.session_state.get('processing_started', False):
        return

    # Individual Summaries Section
    st.header("π Individual Paper Summaries")

    # Generate summaries if not already done
    if st.session_state.summaries is None:
        try:
            with st.spinner("Generating individual paper summaries..."):
                st.session_state.summaries = _generate_individual_summaries(df)
        except Exception as e:
            st.error(f"Error generating summaries: {str(e)}")
            st.session_state.processing_started = False

    if st.session_state.summaries is None:
        return

    # Display summaries with sorting and filtering
    _render_summaries_section(df)

    # Question-focused Summary Section (only if question provided)
    if question.strip():
        _render_focused_summary(question, df)


if __name__ == "__main__":
    main()