import streamlit as st
import pandas as pd
import torch
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel
import gc
from pathlib import Path
import concurrent.futures
import time
import nltk
from nltk.tokenize import sent_tokenize
from concurrent.futures import ThreadPoolExecutor

nltk.download('punkt')

# Configure page
st.set_page_config(
    page_title="Biomedical Papers Analysis",
    page_icon="🔬",
    layout="wide"
)

# Initialize session state
if 'processed_data' not in st.session_state:
    st.session_state.processed_data = None
if 'summaries' not in st.session_state:
    st.session_state.summaries = None
if 'text_processor' not in st.session_state:
    st.session_state.text_processor = None
if 'processing_started' not in st.session_state:
    st.session_state.processing_started = False
if 'focused_summary_generated' not in st.session_state:
    st.session_state.focused_summary_generated = False
if 'current_model' not in st.session_state:
    st.session_state.current_model = None
if 'current_tokenizer' not in st.session_state:
    st.session_state.current_tokenizer = None
if 'model_type' not in st.session_state:
    st.session_state.model_type = None

# TextProcessor import with a lightweight fallback if the module is unavailable
try:
    from text_processing import TextProcessor
except ImportError:
    class TextProcessor:
        def find_most_relevant_abstracts(self, question, abstracts, top_k=5):
            # Fallback: return the first top_k abstracts with uniform scores
            return {
                'top_indices': list(range(min(top_k, len(abstracts)))),
                'scores': [1.0] * min(top_k, len(abstracts))
            }


def load_model(model_type):
    """Load appropriate model based on type with proper memory management"""
    try:
        # Clear any existing cached data
        gc.collect()
        torch.cuda.empty_cache()
        device = "cpu"  # Force CPU usage

        if model_type == "summarize":
            # Load the fine-tuned summarization model directly
            model = AutoModelForSeq2SeqLM.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                "pendar02/bart-large-pubmedd",
                cache_dir="./models"
            )
        else:  # question_focused
            base_model = AutoModelForSeq2SeqLM.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models",
                torch_dtype=torch.float32
            ).to(device)
            model = PeftModel.from_pretrained(
                base_model,
                "pendar02/biobart-finetune",
                is_trainable=False
            ).to(device)
            tokenizer = AutoTokenizer.from_pretrained(
                "GanjinZero/biobart-base",
                cache_dir="./models"
            )

        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise


def get_model(model_type):
    """Get model from session state or load if needed"""
    try:
        if (st.session_state.current_model is None or
                st.session_state.model_type != model_type):
            # Clean up existing model
            if st.session_state.current_model is not None:
                cleanup_model(st.session_state.current_model,
                              st.session_state.current_tokenizer)
            # Load new model
            model, tokenizer = load_model(model_type)
            st.session_state.current_model = model
            st.session_state.current_tokenizer = tokenizer
            st.session_state.model_type = model_type
        return st.session_state.current_model, st.session_state.current_tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.session_state.processing_started = False
        return None, None


def cleanup_model(model, tokenizer):
    """Properly cleanup model resources"""
    try:
        del model
        del tokenizer
        torch.cuda.empty_cache()
        gc.collect()
    except Exception:
        pass
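
# reset_processing_state() is called from main() below but is not defined in this
# file as provided. A minimal sketch of the assumed behavior, based on the
# session-state keys initialized above:
def reset_processing_state():
    """Reset processing-related session state so analysis can be retried after an error."""
    st.session_state.processing_started = False
    st.session_state.summaries = None
    st.session_state.focused_summary_generated = False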

@st.cache_data
def process_excel(uploaded_file):
    """Process uploaded Excel file"""
    try:
        df = pd.read_excel(uploaded_file)
        required_columns = ['Abstract', 'Article Title', 'Authors',
                            'Source Title', 'Publication Year', 'DOI',
                            'Times Cited, All Databases']

        # Check required columns first
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            st.error("❌ Missing required columns: " + ", ".join(missing_columns))
            st.error("Please ensure your Excel file contains all required columns.")
            return None

        # Only proceed with validation if all required columns exist
        if len(df) > 5:
            st.error("❌ Your file contains more than 5 papers. Please upload a file with a maximum of 5 papers.")
            return None

        # Now safe to validate structure as we know columns exist
        is_valid, messages = validate_excel_structure(df)
        if not is_valid:
            for msg in messages:
                st.error(f"❌ {msg}")
            return None

        return df[required_columns]
    except Exception as e:
        st.error(f"❌ Error reading file: {str(e)}")
        st.error("Please check if your file is in the correct Excel format (.xlsx or .xls)")
        return None


def validate_excel_structure(df):
    """Validate the structure and content of the Excel file"""
    validation_messages = []

    # Check for minimum content
    if len(df) == 0:
        validation_messages.append("File contains no data")
        return False, validation_messages

    try:
        # Check publication year format - this is useful for sorting/filtering
        df['Publication Year'] = pd.to_numeric(df['Publication Year'], errors='coerce')
        if df['Publication Year'].isna().any():
            validation_messages.append(
                "Some publication years are invalid. Please ensure all years are in numeric format (e.g., 2024)")
        else:
            years = df['Publication Year'].dropna()
            if len(years) > 0:
                if years.min() < 1900 or years.max() > 2025:
                    validation_messages.append("Publication years must be between 1900 and 2025")

        # For short abstracts - just show a warning
        short_abstracts = df['Abstract'].fillna('').astype(str).str.len() < 50
        if short_abstracts.any():
            st.warning("ℹ️ Some abstracts are quite short, but will still be processed")
    except Exception as e:
        validation_messages.append(f"Error checking data format: {str(e)}")

    return len(validation_messages) == 0, validation_messages
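
# Illustrative only (assumed example, not part of the original app): a minimal
# DataFrame with the columns process_excel() expects, useful for exercising the
# validation logic without a real Excel export.
#
#   example_df = pd.DataFrame({
#       'Abstract': ['Background: an example abstract of reasonable length ...'],
#       'Article Title': ['An Example Paper'],
#       'Authors': ['Doe, J.; Smith, A.'],
#       'Source Title': ['Journal of Examples'],
#       'Publication Year': [2024],
#       'DOI': ['10.1000/xyz123'],
#       'Times Cited, All Databases': [3],
#   })
#   is_valid, messages = validate_excel_structure(example_df)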

def preprocess_text(text):
    """Enhanced text preprocessing with improved header and list handling"""
    if not isinstance(text, str) or not text.strip():
        return text

    # Initial cleanup
    text = re.sub(r'\s+', ' ', text.strip())

    # Standardize case for all-caps terms (e.g., PRIME -> Prime)
    text = re.sub(r'\b([A-Z]{2,})\b', lambda m: m.group(1).title(), text)

    # Fix spacing around punctuation and parentheses
    text = re.sub(r'\s*:\s*', ': ', text)
    text = re.sub(r'\s*,\s*', ', ', text)
    text = re.sub(r'\(\s*([ivx\d]+)\s*\)', r'(\1)', text)

    # Convert numbered lists to consistent format
    text = re.sub(r'(?m)^\s*(\d+)\.\s*', r'(\1) ', text)

    # Normalize section headers (using comprehensive patterns)
    section_patterns = {
        r'\b(?:Introduction|Background|Objectives|Purpose|Context)\s*:': 'Background and Objectives: ',
        r'\b(?:Methods|Materials and Methods|Approach|Study Design|Experimental Design)\s*:': 'Methods: ',
        r'\b(?:Results|Findings|Observations|Key Findings)\s*:': 'Results: ',
        r'\b(?:Discussion|Analysis|Implications|Interpretation)\s*:': 'Discussion: ',
        r'\b(?:Conclusion|Conclusions|Summary|Final Remarks)\s*:': 'Conclusions: '
    }

    # Remove nested headers
    nested_header_pattern = r'\d+\.\s*(?:Background|Objectives|Methods|Results|Discussion|Conclusions)\s*:'
    text = re.sub(nested_header_pattern, '', text)

    # Standardize section headers
    for pattern, replacement in section_patterns.items():
        text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)

    # Split merged section headers
    text = re.sub(r'(?i)Results\s+and\s+Conclusions:', 'Results: ', text)

    # Handle special characters and normalize spacing
    text = re.sub(r'[“”]', '"', text)   # Normalize curly double quotes
    text = re.sub(r"[‘’]", "'", text)   # Normalize curly single quotes
    text = re.sub(r'\s*-\s*', '-', text)

    # Tokenize and capitalize sentences
    sentences = re.split(r'(?<=\w[.!?])\s+|\n(?=\d+\.|\(\w+\)|-)', text)
    formatted_sentences = [s.strip().capitalize() for s in sentences if s.strip()]

    return ' '.join(formatted_sentences)


def post_process_summary(summary):
    """Enhanced summary post-processing with improved formatting."""
    if not summary:
        return summary

    # Step 1: Remove section headers (including empty or redundant ones)
    summary = re.sub(r'\b(?:Background|Objectives|Methods|Results|Conclusions)\s*:\s*\.?\s*', '', summary)

    # Step 2: Fix spacing issues in lists and parentheses
    summary = re.sub(r'\(\s*([ivx\d]+)\s*\)', r'(\1)', summary)   # Fix space inside parentheses
    summary = re.sub(r'\s*,\s*(\([ivx\d]+\))', r', \1', summary)  # Fix spacing before list items

    # Step 3: Ensure proper punctuation and spacing
    summary = re.sub(r'(?<=[.!?])\s*([A-Z])', r' \1', summary)    # Add space after punctuation
    summary = re.sub(r'\s*:\s*', ': ', summary)                   # Fix spacing around colons

    # Step 4: Remove sections with too little content
    sections = [s.strip() for s in summary.split('\n') if len(s.split()) > 3]
    summary = ' '.join(sections)

    # Step 5: Remove multiple periods
    summary = re.sub(r'\.\.+', '.', summary)

    # Step 6: Ensure summary ends with a single period
    summary = summary.strip()
    if not summary.endswith('.'):
        summary += '.'

    return summary
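
# Illustrative only (assumed example, not from the original source): given a raw
# model output such as
#   "Results: . The treatment reduced symptoms.."
# post_process_summary() strips the section header, collapses the repeated periods,
# and returns "The treatment reduced symptoms."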

def generate_focused_summary(question, abstracts, model, tokenizer):
    """Generate a structured summary based on the given question and abstracts."""
    # Preprocess and clean abstracts
    formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts if abstract.strip()]
    if not formatted_abstracts:
        raise ValueError("Abstracts list is empty or improperly formatted.")

    # Join abstracts with separator
    abstracts_content = " [SEP] ".join(formatted_abstracts)

    # Create the prompt
    prompt = f"""
    Generate a structured summary based on the given abstracts and the question. Follow these rules STRICTLY:

    **QUESTION:** {question}

    **SECTION FORMATTING RULES:**
    1. Each section MUST start with the section name followed by ": " (e.g., "Background: ").
    2. Each section MUST end with a period.
    3. Write complete, grammatically correct sentences.
    4. Do not use bullet points, lists, or combined section headers.
    5. Maintain the exact order of sections: Background, Objectives, Methods, Results, Conclusions.
    6. Avoid redundancies, incomplete thoughts, and cutting sentences mid-way.
    7. Use transition words (e.g., "Additionally," "Furthermore," "Moreover") to connect ideas naturally.

    **REQUIRED SECTIONS AND CONTENT:**
    1. **Background**:
       - Provide the context and motivation for the study.
       - Do not mention objectives, methods, or results in this section.
    2. **Objectives**:
       - Clearly state the aim(s) of the study.
       - Avoid referencing any methods or findings.
    3. **Methods**:
       - Describe the approach, tools, and procedures used.
       - Do not include any findings or results in this section.
    4. **Results**:
       - Summarize the key findings, including relevant statistics and outcomes.
       - Mention implications only if explicitly stated in the abstracts.
    5. **Conclusions**:
       - Highlight the overall interpretation of findings.
       - Emphasize the significance and implications of the study.

    **CRITICAL FORMAT RULES:**
    1. Each section title must be followed by a colon and a space.
    2. All sentences must be grammatically complete and coherent.
    3. Avoid bullet points, lists, and repeated sections.
    4. End each section with a period.

    **INPUT ABSTRACTS:**
    {abstracts_content}
    """

    # Tokenize the prompt
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=280,
            min_length=100,
            num_beams=4,
            length_penalty=2.0,
            no_repeat_ngram_size=2,
            temperature=0.7,
            do_sample=False
        )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return post_process_summary(summary)


def process_papers_in_batches(df, model, tokenizer, batch_size=2):
    """Process papers concurrently for better efficiency"""
    abstracts = df['Abstract'].tolist()
    summaries = []

    with ThreadPoolExecutor(max_workers=4) as executor:
        # Submit one summarization task per abstract; iterating over the dict in
        # insertion order keeps summaries aligned with the input rows
        future_to_batch = {
            executor.submit(
                generate_focused_summary,
                "Focus on key findings and methods.",
                [abstract],
                model,
                tokenizer
            ): abstract
            for abstract in abstracts
        }
        for future in future_to_batch:
            summaries.append(future.result())

    return summaries


def create_filter_controls(df, sort_column):
    """Create appropriate filter controls based on the selected column"""
    filtered_df = df.copy()

    if sort_column == 'Publication Year':
        # Year range inputs
        year_min = int(df['Publication Year'].min())
        year_max = int(df['Publication Year'].max())
        col1, col2 = st.columns(2)
        with col1:
            start_year = st.number_input('From Year',
                                         min_value=year_min,
                                         max_value=year_max,
                                         value=year_min)
        with col2:
            end_year = st.number_input('To Year',
                                       min_value=year_min,
                                       max_value=year_max,
                                       value=year_max)
        filtered_df = filtered_df[
            (filtered_df['Publication Year'] >= start_year) &
            (filtered_df['Publication Year'] <= end_year)
        ]

    elif sort_column == 'Authors':
        # Multi-select for authors
        unique_authors = sorted(set(
            author.strip()
            for authors in df['Authors'].dropna()
            for author in authors.split(';')
        ))
        selected_authors = st.multiselect('Select Authors', unique_authors)
        if selected_authors:
            filtered_df = filtered_df[
                filtered_df['Authors'].apply(
                    lambda x: any(author in str(x) for author in selected_authors)
                )
            ]

    elif sort_column == 'Source Title':
        # Multi-select for source titles
        unique_sources = sorted(df['Source Title'].unique())
        selected_sources = st.multiselect('Select Sources', unique_sources)
        if selected_sources:
            filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]

    elif sort_column == 'Article Title':
        # Only alphabetical sorting, no filtering
        pass

    return filtered_df
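
# Illustrative only (assumed usage, not part of the original app): the summarization
# pipeline above can be exercised outside the Streamlit UI roughly like this:
#
#   papers = pd.read_excel("papers.xlsx")          # hypothetical file name
#   model, tokenizer = load_model("summarize")     # downloads weights to ./models
#   summaries = process_papers_in_batches(papers, model, tokenizer)
#
# load_model() forces CPU execution with float32 weights, so this runs without a GPU.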

def main():
    st.title("🔬 Biomedical Papers Analysis")
    st.info("""
    **📋 File Upload Requirements:**
    - Excel file (.xlsx or .xls) with **maximum 5 papers**
    - Must contain these columns:
      • Abstract
      • Article Title
      • Authors
      • Source Title
      • Publication Year
      • DOI
      • Times Cited, All Databases
    """)

    # File upload section
    uploaded_file = st.file_uploader(
        "Upload Excel file containing papers (max 5 papers)",
        type=['xlsx', 'xls'],
        help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
    )

    # Question input - moved up but hidden initially
    question_container = st.empty()
    question = ""

    if uploaded_file is not None:
        # Process Excel file
        if st.session_state.processed_data is None:
            with st.spinner("Processing file..."):
                df = process_excel(uploaded_file)
                if df is not None:
                    df = df.dropna(subset=["Abstract"])
                    if len(df) > 0:
                        st.session_state.processed_data = df
                        st.success(f"✅ Successfully loaded {len(df)} papers with abstracts")
                    else:
                        st.error("❌ No valid papers found after processing. Please check your file.")

        if st.session_state.processed_data is not None:
            df = st.session_state.processed_data
            st.write(f"📊 Loaded {len(df)} papers with abstracts")

            # Get question before processing
            with question_container:
                question = st.text_input(
                    "Enter your research question (optional):",
                    help="If provided, a question-focused summary will be generated after individual summaries"
                )

            # Single button for both processes
            if not st.session_state.get('processing_started', False):
                if st.button("Start Analysis"):
                    st.session_state.processing_started = True

            # Show processing status and results
            if st.session_state.get('processing_started', False):
                # Individual Summaries Section
                st.header("📝 Individual Paper Summaries")

                # Generate summaries if not already done
                if st.session_state.summaries is None:
                    try:
                        with st.spinner("Generating individual paper summaries..."):
                            model, tokenizer = get_model("summarize")
                            if model is None or tokenizer is None:
                                reset_processing_state()
                                return
                            start_time = time.time()
                            st.session_state.summaries = process_papers_in_batches(
                                df, model, tokenizer, batch_size=2
                            )
                            end_time = time.time()
                            st.write(f"Processing time: {end_time - start_time:.2f} seconds")
                    except Exception as e:
                        st.error(f"Error generating summaries: {str(e)}")
                        reset_processing_state()

                # Display summaries with improved sorting and filtering
                if st.session_state.summaries is not None:
                    col1, col2 = st.columns(2)
                    with col1:
                        sort_options = ['Article Title', 'Authors', 'Publication Year',
                                        'Source Title', 'Times Cited']
                        sort_column = st.selectbox("Sort/Filter by:", sort_options)
                    with col2:
                        if sort_column == 'Article Title':
                            ascending = st.radio(
                                "Sort order",
                                ["A to Z", "Z to A"],
                                horizontal=True
                            ) == "A to Z"
                        elif sort_column == 'Times Cited':
                            ascending = st.radio(
                                "Sort order",
                                ["Most cited first", "Least cited first"],
                                horizontal=True
                            ) == "Least cited first"
                        else:
                            ascending = True  # Default for other columns

                    # Create display dataframe
                    display_df = df.copy()
                    display_df['Summary'] = st.session_state.summaries
                    display_df['Publication Year'] = display_df['Publication Year'].astype(int)
                    display_df.rename(columns={'Times Cited, All Databases': 'Times Cited'}, inplace=True)
                    display_df['Times Cited'] = display_df['Times Cited'].fillna(0).astype(int)

                    # Apply filters
                    filtered_df = create_filter_controls(display_df, sort_column)

                    # Apply sorting
                    if sort_column == 'Times Cited':
                        sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending)
                    elif sort_column == 'Article Title':
                        sorted_df = filtered_df.sort_values(by=sort_column, ascending=ascending)
                    else:
                        sorted_df = filtered_df

                    # Show number of filtered results
                    if len(sorted_df) != len(display_df):
                        st.write(f"Showing {len(sorted_df)} of {len(display_df)} papers")

                    # Apply custom styling
                    st.markdown("""
                    """, unsafe_allow_html=True)

                    # Display papers using the filtered and sorted dataframe
                    for _, row in sorted_df.iterrows():
                        paper_info_cols = st.columns([1, 1])
                        with paper_info_cols[0]:  # PAPER column
                            st.markdown('