Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
from peft import PeftModel | |
from text_processing import TextProcessor | |
import gc | |
import time | |
from pathlib import Path | |
# Configure the browser tab (title, icon) and use the wide page layout.
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    # Restores the microscope emoji; the literal had been mangled into the
    # two-character sequence "π¬" by a wrong text decoding.
    page_icon="🔬",
    layout="wide",
)

# Seed the per-session slots read by main() so that attribute access on
# st.session_state never hits a missing key. Existing values are preserved.
for _state_key in ("processed_data", "summaries", "text_processor"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
def load_model(model_type):
    """Load the adapter-wrapped seq2seq model and its tokenizer for a task.

    Args:
        model_type: ``"summarize"`` selects the ``facebook/bart-large-cnn``
            base with the ``pendar02/results`` adapter. Any other value
            selects the ``GanjinZero/biobart-base`` base with the
            ``pendar02/biobart-finetune`` adapter.

    Returns:
        A ``(model, tokenizer)`` tuple; the tokenizer comes from the same
        base checkpoint as the model.
    """
    if model_type == "summarize":
        base_checkpoint = "facebook/bart-large-cnn"
        adapter_checkpoint = "pendar02/results"
    else:  # question_focused
        base_checkpoint = "GanjinZero/biobart-base"
        adapter_checkpoint = "pendar02/biobart-finetune"

    backbone = AutoModelForSeq2SeqLM.from_pretrained(base_checkpoint)
    model = PeftModel.from_pretrained(backbone, adapter_checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
    return model, tokenizer
def process_excel(uploaded_file):
    """Read an uploaded workbook and keep only the columns the app uses.

    Args:
        uploaded_file: Input handed straight to ``pd.read_excel``.

    Returns:
        A DataFrame restricted to the required columns, in the order listed
        below, or ``None`` when a required column is absent or reading
        fails. Both failure cases are reported to the user via ``st.error``.
    """
    required_columns = ['Abstract', 'Article Title', 'Authors',
                        'Source Title', 'Publication Year', 'DOI']
    try:
        sheet = pd.read_excel(uploaded_file)
        absent = [name for name in required_columns if name not in sheet.columns]
        if not absent:
            return sheet[required_columns]
        st.error(f"Missing required columns: {', '.join(absent)}")
        return None
    except Exception as exc:
        st.error(f"Error processing file: {str(exc)}")
        return None
def generate_summary(text, model, tokenizer):
    """Summarise a single abstract.

    Args:
        text: Abstract text; tokenised with truncation at 1024 tokens.
        model: Object exposing ``generate(**kwargs)``.
        tokenizer: Callable tokenizer that also exposes ``decode``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    encoded = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=150,
            min_length=50,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True,
        )

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def generate_focused_summary(question, abstracts, model, tokenizer):
    """Summarise several abstracts with respect to a question.

    The question and the abstracts are packed into one prompt of the form
    ``Question: <question> Abstracts: <a1> [SEP] <a2> ...``, which is
    tokenised with truncation at 1024 tokens.

    Args:
        question: The research question to focus on.
        abstracts: Iterable of abstract strings.
        model: Object exposing ``generate(**kwargs)``.
        tokenizer: Callable tokenizer that also exposes ``decode``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    joined_abstracts = " [SEP] ".join(abstracts)
    prompt = f"Question: {question} Abstracts: " + joined_abstracts
    encoded = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=200,
            min_length=50,
            num_beams=4,
            length_penalty=2.0,
            early_stopping=True,
        )

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
def _release_model_memory():
    """Reclaim memory after the caller has dropped its model references."""
    # Collect first so unreachable tensors are freed before the CUDA cache
    # is emptied.
    gc.collect()
    torch.cuda.empty_cache()


def _summarize_abstracts(abstracts):
    """Return one summary per entry of ``abstracts``, updating a progress bar.

    Entries that are not non-blank strings (for example empty spreadsheet
    cells) get an empty summary instead of aborting the whole batch. The
    summarization model is released even when generation raises.
    """
    model = tokenizer = None
    try:
        model, tokenizer = load_model("summarize")
        progress_bar = st.progress(0)
        total = len(abstracts)
        summaries = []
        for i, abstract in enumerate(abstracts):
            if isinstance(abstract, str) and abstract.strip():
                summaries.append(generate_summary(abstract, model, tokenizer))
            else:
                summaries.append("")
            progress_bar.progress((i + 1) / total)
        return summaries
    finally:
        # Rebind rather than ``del`` so this also works when load_model failed.
        model = tokenizer = None
        _release_model_memory()


def _answer_question(question, abstracts):
    """Generate a question-focused summary, always releasing the model."""
    model = tokenizer = None
    try:
        model, tokenizer = load_model("question_focused")
        return generate_focused_summary(question, abstracts, model, tokenizer)
    finally:
        model = tokenizer = None
        _release_model_memory()


def _render_individual_summaries(df):
    """Render the per-paper summary section: generate button and sortable table."""
    # NOTE(review): the leading "π" / "β" glyphs in some labels of this app
    # look like mis-decoded emoji; the intended characters cannot be
    # determined from this file, so those strings are kept unchanged.
    st.header("π Individual Paper Summaries")

    # The button is only offered while no summaries are cached for this file.
    if st.session_state.summaries is None and st.button("Generate Individual Summaries"):
        try:
            with st.spinner("Generating summaries..."):
                st.session_state.summaries = _summarize_abstracts(df['Abstract'])
        except Exception as e:
            st.error(f"Error generating summaries: {str(e)}")

    if st.session_state.summaries is None:
        return

    # Sorting controls
    col1, col2 = st.columns(2)
    with col1:
        sort_column = st.selectbox("Sort by:", df.columns)
    with col2:
        ascending = st.checkbox("Ascending order", True)

    # Work on a copy so the cached frame never gains a 'Summary' column.
    display_df = df.copy()
    display_df['Summary'] = st.session_state.summaries
    sorted_df = display_df.sort_values(by=sort_column, ascending=ascending)

    st.dataframe(
        sorted_df,
        column_config={
            "Abstract": st.column_config.TextColumn(
                "Abstract",
                width="medium",
                help="Original abstract text"
            ),
            "Summary": st.column_config.TextColumn(
                "Summary",
                width="medium",
                help="Generated summary"
            )
        },
        hide_index=True
    )


def _render_focused_summary(df):
    """Render the question-driven section: retrieval, summary and sources."""
    st.header("β Question-focused Summary")
    question = st.text_input("Enter your research question:")

    if not (question and st.button("Generate Focused Summary")):
        return

    try:
        with st.spinner("Analyzing relevant papers..."):
            # Rank abstracts against the question.
            results = st.session_state.text_processor.find_most_relevant_abstracts(
                question,
                df['Abstract'].tolist(),
                top_k=5
            )
            processed_question = results['processed_question']
            top_indices = results['top_indices']

            # Surface the corrected wording when it differs from the input.
            if processed_question['original'] != processed_question['corrected']:
                st.info(f"Did you mean: {processed_question['corrected']}?")

            relevant_abstracts = df['Abstract'].iloc[top_indices].tolist()
            focused_summary = _answer_question(
                processed_question['corrected'],
                relevant_abstracts
            )

        st.subheader("Summary")
        st.write(focused_summary)

        st.subheader("Most Relevant Papers")
        # Copy before adding a column: assigning into a slice of df triggers
        # pandas' SettingWithCopyWarning and is not guaranteed to stick.
        relevant_papers = df.iloc[top_indices][
            ['Article Title', 'Authors', 'Publication Year', 'DOI']
        ].copy()
        relevant_papers['Relevance Score'] = results['scores']
        st.dataframe(relevant_papers, hide_index=True)

        st.subheader("Identified Medical Terms")
        st.write(", ".join(processed_question['medical_entities']))
    except Exception as e:
        st.error(f"Error generating focused summary: {str(e)}")


def main():
    """Render the app: upload a workbook, summarise papers, answer questions.

    Parsed data and generated summaries are cached in ``st.session_state``
    across reruns and are discarded when a different file (by name and size)
    is uploaded, so results never refer to a previous upload.
    """
    st.title("🔬 Biomedical Papers Analysis")

    # Sidebar
    st.sidebar.header("About")
    st.sidebar.info(
        "This app analyzes biomedical research papers. Upload an Excel file "
        "containing paper details and abstracts to:"
        "\n- Generate individual summaries"
        "\n- Get question-focused insights"
    )

    # Build the text processor once per session.
    if st.session_state.text_processor is None:
        with st.spinner("Loading NLP models..."):
            st.session_state.text_processor = TextProcessor()

    # File upload section
    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )
    if uploaded_file is None:
        return

    # Invalidate cached results when the upload changes; otherwise the first
    # file's data and summaries would be shown for every later upload.
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("uploaded_file_key") != file_key:
        st.session_state.processed_data = None
        st.session_state.summaries = None
        st.session_state["uploaded_file_key"] = file_key

    # Process Excel file
    if st.session_state.processed_data is None:
        with st.spinner("Processing file..."):
            df = process_excel(uploaded_file)
            if df is not None:
                st.session_state.processed_data = df

    df = st.session_state.processed_data
    if df is None:
        return

    st.write(f"π Loaded {len(df)} papers")
    _render_individual_summaries(df)
    _render_focused_summary(df)
# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()