"""Streamlit app for analysing biomedical papers exported to an Excel file.

Workflow: upload a spreadsheet of papers, generate one summary per abstract
with a fine-tuned BART model, browse/sort/filter the results, and optionally
ask a research question to get a summary of the most relevant abstracts
produced by a PEFT-adapted BioBART model.
"""

import gc
import html
import re
from pathlib import Path

import pandas as pd
import streamlit as st
import torch
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from text_processing import TextProcessor

# Columns the uploaded spreadsheet must provide.
REQUIRED_COLUMNS = [
    'Abstract',
    'Article Title',
    'Authors',
    'Source Title',
    'Publication Year',
    'DOI',
    'Times Cited, All Databases',
]
SORT_OPTIONS = ['Article Title', 'Authors', 'Publication Year', 'Source Title', 'Times Cited']

MODEL_CACHE_DIR = "./models"
SUMMARY_MODEL_ID = "pendar02/bart-large-pubmedd"
QUESTION_BASE_MODEL_ID = "GanjinZero/biobart-base"
QUESTION_ADAPTER_ID = "pendar02/biobart-finetune"

# Model inputs longer than this many tokens are truncated by the tokenizer.
MAX_INPUT_TOKENS = 1024

_SUMMARY_PROMPT = (
    "Summarize this medical research paper following this structure exactly:\n"
    "1. Background and objectives\n"
    "2. Methods\n"
    "3. Key findings with specific numbers/percentages\n"
    "4. Main conclusions\n"
    "Original text: "
)

# Generation settings for the first summary attempt and for the retry used when
# validate_summary() rejects the first result.
# NOTE: ``temperature`` only takes effect when sampling is enabled
# (do_sample=True, e.g. through the model's generation config); with plain beam
# search transformers ignores it.
_PRIMARY_SUMMARY_GENERATION = {
    "max_length": 200,
    "min_length": 50,
    "num_beams": 5,
    "length_penalty": 1.5,
    "no_repeat_ngram_size": 3,
    "temperature": 0.7,
    "repetition_penalty": 1.5,
}
_RETRY_SUMMARY_GENERATION = {
    "max_length": 200,
    "min_length": 50,
    "num_beams": 4,
    "length_penalty": 2.0,
    "no_repeat_ngram_size": 4,
    "temperature": 0.8,
    "repetition_penalty": 2.0,
}
_FOCUSED_SUMMARY_GENERATION = {
    "max_length": 200,
    "min_length": 50,
    "num_beams": 4,
    "length_penalty": 2.0,
    "early_stopping": True,
}

# A period ends a sentence only when followed by whitespace or the end of the
# text, so decimals such as "0.05" or "2.5%" are not split apart.
_SENTENCE_END = re.compile(r'\.(?=\s|$)')
_AGE_MENTION = re.compile(r'(\d+\.?\d*)\s*years?')

# Minimal styling for the custom markup emitted by _render_paper().
_PAPER_CSS = """
<style>
.section-title {font-weight: 700; letter-spacing: 0.05em; margin-bottom: 0.5rem;}
.paper-title {font-weight: 600; margin-bottom: 0.5rem;}
.paper-info div, .paper-summary {line-height: 1.6;}
.paper-spacer {height: 1.5rem;}
</style>
"""

# Configure page
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    page_icon="🔬",
    layout="wide",
)

# Initialize session state keys that do not exist yet.
_SESSION_DEFAULTS = {
    'processed_data': None,
    'summaries': None,
    'text_processor': None,
    'processing_started': False,
    'focused_summary_generated': False,
    # Question the stored focused summary was generated for.
    'focused_question': None,
    # (name, size) of the upload that the stored results belong to.
    'source_file_key': None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


def _free_memory():
    """Run the garbage collector and release cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


def load_model(model_type):
    """Load a model/tokenizer pair on the CPU.

    Args:
        model_type: "summarize" loads the fine-tuned BART summarizer; any other
            value loads the BioBART base model wrapped with the
            question-focused PEFT adapter.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.

    Raises:
        Any loading error, re-raised after it has been reported in the UI.
    """
    try:
        # Reclaim any previously released model before loading the next one.
        _free_memory()
        device = "cpu"  # Force CPU usage

        if model_type == "summarize":
            model = AutoModelForSeq2SeqLM.from_pretrained(
                SUMMARY_MODEL_ID,
                cache_dir=MODEL_CACHE_DIR,
                torch_dtype=torch.float32,
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(SUMMARY_MODEL_ID, cache_dir=MODEL_CACHE_DIR)
        else:  # question_focused
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                QUESTION_BASE_MODEL_ID,
                cache_dir=MODEL_CACHE_DIR,
                torch_dtype=torch.float32,
            ).to(device)
            model = PeftModel.from_pretrained(
                base_model,
                QUESTION_ADAPTER_ID,
                is_trainable=False,
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                QUESTION_BASE_MODEL_ID, cache_dir=MODEL_CACHE_DIR
            )

        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise


def cleanup_model(model, tokenizer):
    """Drop this function's references to a model/tokenizer and free memory.

    ``del`` only removes the local names. The memory is reclaimed once the
    caller's references are gone too. The helpers below ensure this by loading
    each model inside a function that returns only plain results.
    """
    del model, tokenizer
    try:
        _free_memory()
    except RuntimeError:
        pass  # Best effort: cleanup must never break the app.


@st.cache_data
def process_excel(uploaded_file):
    """Read the uploaded Excel file and keep only ``REQUIRED_COLUMNS``.

    Returns:
        A DataFrame with the required columns, or ``None`` (after showing an
        error) when the file cannot be read or lacks a required column.
    """
    try:
        df = pd.read_excel(uploaded_file)
    except Exception as e:  # readers raise many unrelated error types
        st.error(f"Error processing file: {str(e)}")
        return None

    missing_columns = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing_columns:
        st.error(f"Missing required columns: {', '.join(missing_columns)}")
        return None
    return df[REQUIRED_COLUMNS]


def preprocess_text(text):
    """Put each sentence of *text* on its own line before model input.

    Non-string or blank input is returned unchanged.
    """
    if not isinstance(text, str) or not text.strip():
        return text

    # Basic splitting: every ". " is treated as a sentence break.
    sentences = [s.strip() for s in text.replace('. ', '.\n').split('\n')]
    return '\n'.join(s for s in sentences if s)


def _split_sentences(text):
    """Return the stripped, non-empty sentences of *text* (see ``_SENTENCE_END``)."""
    return [s.strip() for s in _SENTENCE_END.split(text) if s.strip()]


def post_process_summary(summary):
    """Clean up a generated summary.

    Collapses whitespace, removes a few known repetition/grammar artefacts,
    upper-cases the first letter of each sentence and makes sure the result
    ends with sentence punctuation. Empty input is returned unchanged.
    """
    if not summary:
        return summary

    processed_sentences = []
    for sentence in _split_sentences(summary):
        sentence = " ".join(sentence.split())  # collapse runs of whitespace
        # Remove redundant words/phrases
        sentence = sentence.replace(" and and ", " and ")
        sentence = sentence.replace("appointment and appointment", "appointment")
        # Fix common grammatical issues
        sentence = sentence.replace("Cancers distress", "Cancer distress")
        # Upper-case only the first character: str.capitalize() would
        # lower-case the rest and mangle acronyms such as "DNA" or "HIV".
        sentence = sentence[:1].upper() + sentence[1:]
        if sentence:
            processed_sentences.append(sentence)

    cleaned_summary = '. '.join(processed_sentences)
    if cleaned_summary and not cleaned_summary.endswith(('.', '!', '?')):
        cleaned_summary += '.'
    return cleaned_summary


def _encode(text, model, tokenizer):
    """Tokenize *text* (truncated to ``MAX_INPUT_TOKENS``) onto the model's device."""
    inputs = tokenizer(text, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True)
    return {name: tensor.to(model.device) for name, tensor in inputs.items()}


def _generate(inputs, model, tokenizer, settings):
    """Run ``model.generate`` with *settings* and decode the first sequence."""
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **settings,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def improve_summary_generation(text, model, tokenizer):
    """Summarize one abstract, retrying once with other settings if invalid.

    Args:
        text: The abstract. Non-string or blank values yield a placeholder message.
        model, tokenizer: A pair returned by ``load_model("summarize")``.

    Returns:
        The post-processed summary. The retry result is returned even if it
        also fails ``validate_summary``.
    """
    if not isinstance(text, str) or not text.strip():
        return "No abstract available to summarize."

    inputs = _encode(_SUMMARY_PROMPT + preprocess_text(text), model, tokenizer)
    summary = post_process_summary(
        _generate(inputs, model, tokenizer, _PRIMARY_SUMMARY_GENERATION)
    )
    if not validate_summary(summary, text):
        summary = post_process_summary(
            _generate(inputs, model, tokenizer, _RETRY_SUMMARY_GENERATION)
        )
    return summary


def validate_summary(summary, original_text):
    """Heuristically check a summary against its source abstract.

    Rejects summaries that mention "N year(s)" more than once, contain more
    than one repeated sentence, have fewer than 20 words, or have more than
    80% of the original's word count.
    """
    # Multiple age mentions are treated as a likely inconsistency.
    if len(_AGE_MENTION.findall(summary.lower())) > 1:
        return False

    # Compare only real sentences; empty split pieces are not duplicates.
    sentences = [s.lower() for s in _split_sentences(summary)]
    if len(sentences) - len(set(sentences)) > 1:
        return False

    summary_words = len(summary.split())
    original_words = len(original_text.split())
    return 20 <= summary_words <= original_words * 0.8


def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate one summary of *abstracts* focused on *question*.

    The abstracts are joined with " [SEP] " after the question. Input beyond
    ``MAX_INPUT_TOKENS`` tokens is truncated, so later abstracts may be cut.
    """
    formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts]
    combined_input = f"Question: {question} Abstracts: " + " [SEP] ".join(formatted_abstracts)
    inputs = _encode(combined_input, model, tokenizer)
    return _generate(inputs, model, tokenizer, _FOCUSED_SUMMARY_GENERATION)


def _split_authors(value):
    """Split a ';'-separated author cell into stripped, non-empty names."""
    if not isinstance(value, str):
        return []
    return [name.strip() for name in value.split(';') if name.strip()]


def create_filter_controls(df, sort_column):
    """Render filter widgets for *sort_column* and return the filtered rows.

    Expects the display frame built by ``main`` (with a 'Times Cited' column).
    The input frame is not modified.
    """
    filtered_df = df.copy()

    if sort_column == 'Publication Year':
        years = df['Publication Year'].dropna()
        if not years.empty:  # nothing to filter when no paper has a year
            year_min, year_max = int(years.min()), int(years.max())
            col1, col2 = st.columns(2)
            with col1:
                start_year = st.number_input(
                    'From Year', min_value=year_min, max_value=year_max, value=year_min
                )
            with col2:
                end_year = st.number_input(
                    'To Year', min_value=year_min, max_value=year_max, value=year_max
                )
            # Papers without a year cannot fall in the range and are excluded.
            in_range = filtered_df['Publication Year'].between(start_year, end_year)
            filtered_df = filtered_df[in_range.fillna(False).astype(bool)]

    elif sort_column == 'Authors':
        unique_authors = sorted({
            author
            for authors in df['Authors'].dropna()
            for author in _split_authors(authors)
        })
        selected_authors = st.multiselect('Select Authors', unique_authors)
        if selected_authors:
            wanted = set(selected_authors)
            # Match whole author names, not substrings ("Li, J" != "Li, Jun").
            filtered_df = filtered_df[
                filtered_df['Authors'].apply(lambda x: not wanted.isdisjoint(_split_authors(x)))
            ]

    elif sort_column == 'Source Title':
        # dropna() + key=str: NaN or mixed cell types would break sorting.
        unique_sources = sorted(df['Source Title'].dropna().unique(), key=str)
        selected_sources = st.multiselect('Select Sources', unique_sources)
        if selected_sources:
            filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]

    elif sort_column == 'Article Title':
        # Only alphabetical sorting, no filtering
        pass

    elif sort_column == 'Times Cited':
        cited_min = int(df['Times Cited'].min())
        cited_max = int(df['Times Cited'].max())
        col1, col2 = st.columns(2)
        with col1:
            start_cited = st.number_input(
                'From Cited Count', min_value=cited_min, max_value=cited_max, value=cited_min
            )
        with col2:
            end_cited = st.number_input(
                'To Cited Count', min_value=cited_min, max_value=cited_max, value=cited_max
            )
        filtered_df = filtered_df[
            (filtered_df['Times Cited'] >= start_cited)
            & (filtered_df['Times Cited'] <= end_cited)
        ]

    return filtered_df


def _to_nullable_int(series):
    """Convert *series* to pandas' nullable ``Int64`` dtype.

    Blank or non-numeric cells become ``<NA>`` instead of raising, as plain
    ``astype(int)`` does.
    """
    return pd.to_numeric(series, errors='coerce').round().astype('Int64')


def _html_cell(value):
    """Return *value* as HTML-escaped text, or 'None' when it is missing.

    Cell values come from the uploaded file or from the model. They are escaped
    because they are embedded in markup rendered with ``unsafe_allow_html``.
    """
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return 'None'
    return html.escape(str(value))


def _reset_analysis_state():
    """Forget every result derived from a previously uploaded file."""
    st.session_state.processed_data = None
    st.session_state.summaries = None
    st.session_state.processing_started = False
    st.session_state.focused_summary_generated = False
    st.session_state.focused_question = None


def _generate_all_summaries(abstracts):
    """Summarize every abstract with the summarization model, showing progress.

    The model is loaded and referenced only inside this function. Once it
    returns, nothing keeps the model alive for the next ``load_model`` call.
    """
    model, tokenizer = load_model("summarize")
    progress_bar = st.progress(0)
    try:
        summaries = []
        total = len(abstracts)
        for done, abstract in enumerate(abstracts, start=1):
            summaries.append(improve_summary_generation(abstract, model, tokenizer))
            progress_bar.progress(done / total)
        return summaries
    finally:
        progress_bar.empty()
        cleanup_model(model, tokenizer)


def _build_display_df(df, summaries):
    """Return a copy of *df* with summaries and display-ready numeric columns."""
    display_df = df.rename(columns={'Times Cited, All Databases': 'Times Cited'})
    display_df['Summary'] = summaries
    display_df['Publication Year'] = _to_nullable_int(display_df['Publication Year'])
    display_df['Times Cited'] = (
        pd.to_numeric(display_df['Times Cited'], errors='coerce').fillna(0).astype(int)
    )
    return display_df


def _sort_controls():
    """Render the sort/filter selector and return ``(sort_column, ascending)``."""
    col1, col2 = st.columns(2)
    with col1:
        sort_column = st.selectbox("Sort/Filter by:", SORT_OPTIONS)
    with col2:
        # Only the title and citation-count views offer a sort-order toggle.
        if sort_column == 'Article Title':
            ascending = st.radio(
                "Sort order", ["A to Z", "Z to A"], horizontal=True
            ) == "A to Z"
        elif sort_column == 'Times Cited':
            ascending = st.radio(
                "Sort order", ["Most cited", "Least cited"], horizontal=True
            ) == "Least cited"
        else:
            ascending = True  # Default for other columns
    return sort_column, ascending


def _sort_papers(df, sort_column, ascending):
    """Sort by title or citation count; other views keep the file's order."""
    if sort_column in ('Article Title', 'Times Cited'):
        # A stable sort keeps file order among ties (e.g. equal citation counts).
        return df.sort_values(by=sort_column, ascending=ascending, kind='stable')
    return df


def _render_paper(row):
    """Render one paper's metadata next to its generated summary."""
    paper_col, summary_col = st.columns([1, 1])
    with paper_col:
        st.markdown('<div class="section-title">PAPER</div>', unsafe_allow_html=True)
        # Built as a single line so Markdown does not treat indented HTML as code.
        st.markdown(
            '<div class="paper-info">'
            f'<div class="paper-title">{_html_cell(row["Article Title"])}</div>'
            f'<div><strong>Authors:</strong> {_html_cell(row["Authors"])}</div>'
            f'<div><strong>Source:</strong> {_html_cell(row["Source Title"])}</div>'
            f'<div><strong>Publication Year:</strong> {_html_cell(row["Publication Year"])}</div>'
            f'<div><strong>Times Cited:</strong> {_html_cell(row["Times Cited"])}</div>'
            f'<div><strong>DOI:</strong> {_html_cell(row["DOI"])}</div>'
            '</div>',
            unsafe_allow_html=True,
        )
    with summary_col:
        st.markdown('<div class="section-title">SUMMARY</div>', unsafe_allow_html=True)
        st.markdown(
            f'<div class="paper-summary">{_html_cell(row["Summary"])}</div>',
            unsafe_allow_html=True,
        )
    # Add spacing between papers
    st.markdown('<div class="paper-spacer"></div>', unsafe_allow_html=True)


def _render_summaries(df, summaries):
    """Show the sort/filter controls and every matching paper with its summary."""
    sort_column, ascending = _sort_controls()
    display_df = _build_display_df(df, summaries)
    filtered_df = create_filter_controls(display_df, sort_column)
    sorted_df = _sort_papers(filtered_df, sort_column, ascending)

    if len(sorted_df) != len(display_df):
        st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers")

    st.markdown(_PAPER_CSS, unsafe_allow_html=True)
    for _, row in sorted_df.iterrows():
        _render_paper(row)


def _generate_question_summary(question, df):
    """Select the abstracts most relevant to *question* and summarize them.

    Returns:
        ``(summary, relevant_papers, scores)``. The model is released before
        this function returns.
    """
    if st.session_state.text_processor is None:
        st.session_state.text_processor = TextProcessor()

    # Assumes 'top_indices' are positional indices into the abstract list and
    # 'scores' aligns with them (inferred from the usage here; confirm in TextProcessor).
    results = st.session_state.text_processor.find_most_relevant_abstracts(
        question, df['Abstract'].tolist(), top_k=5
    )
    relevant_papers = df.iloc[results['top_indices']]

    model, tokenizer = load_model("question_focused")
    try:
        summary = generate_focused_summary(
            question, relevant_papers['Abstract'].tolist(), model, tokenizer
        )
    finally:
        cleanup_model(model, tokenizer)
    return summary, relevant_papers, results['scores']


def _render_question_section(question, df):
    """Generate (when needed) and display the question-focused summary."""
    st.header("❓ Question-focused Summary")

    # Regenerate whenever the question changes, not only the first time.
    if (not st.session_state.focused_summary_generated
            or st.session_state.focused_question != question):
        st.session_state.focused_summary_generated = False
        try:
            with st.spinner("Analyzing relevant papers..."):
                summary, papers, scores = _generate_question_summary(question, df)
            st.session_state.focused_summary = summary
            st.session_state.relevant_papers = papers
            st.session_state.relevance_scores = scores
            st.session_state.focused_question = question
            st.session_state.focused_summary_generated = True
        except Exception as e:
            st.error(f"Error generating focused summary: {str(e)}")

    if st.session_state.focused_summary_generated:
        st.subheader("Summary")
        st.write(st.session_state.focused_summary)

        st.subheader("Most Relevant Papers")
        relevant_papers = st.session_state.relevant_papers[
            ['Article Title', 'Authors', 'Publication Year', 'DOI']
        ].copy()
        relevant_papers['Relevance Score'] = st.session_state.relevance_scores
        relevant_papers['Publication Year'] = _to_nullable_int(relevant_papers['Publication Year'])
        st.dataframe(relevant_papers, hide_index=True)


def main():
    """Streamlit entry point: upload, summarize, browse and question papers."""
    st.title("🔬 Biomedical Papers Analysis")

    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers",
        type=['xlsx', 'xls'],
        help="File must contain: " + ", ".join(REQUIRED_COLUMNS),
    )

    # Placeholder so the question input appears directly below the uploader.
    question_container = st.empty()

    if uploaded_file is None:
        return

    # A different upload invalidates everything computed for the previous file.
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.source_file_key != file_key:
        _reset_analysis_state()
        st.session_state.source_file_key = file_key

    if st.session_state.processed_data is None:
        with st.spinner("Processing file..."):
            df = process_excel(uploaded_file)
        if df is not None:
            st.session_state.processed_data = df.dropna(subset=["Abstract"])

    df = st.session_state.processed_data
    if df is None:
        return

    st.write(f"📊 Loaded {len(df)} papers with abstracts")
    if df.empty:
        st.warning("No papers with abstracts were found in the file.")
        return

    # Get question before processing
    with question_container:
        question = st.text_input(
            "Enter your research question (optional):",
            help="If provided, a question-focused summary will be generated after individual summaries",
        )

    # Single button for both processes
    if not st.session_state.processing_started:
        if st.button("Start Analysis"):
            st.session_state.processing_started = True

    if not st.session_state.processing_started:
        return

    st.header("📝 Individual Paper Summaries")
    if st.session_state.summaries is None:
        try:
            with st.spinner("Generating individual paper summaries..."):
                st.session_state.summaries = _generate_all_summaries(df['Abstract'].tolist())
        except Exception as e:
            st.error(f"Error generating summaries: {str(e)}")
            st.session_state.processing_started = False

    if st.session_state.summaries is not None:
        _render_summaries(df, st.session_state.summaries)

    # Question-focused Summary Section (only if question provided)
    if question.strip():
        _render_question_section(question.strip(), df)


if __name__ == "__main__":
    main()