pendar02 committed
Commit b7bd5a2 · verified · 1 Parent(s): 86deaaa

Update app.py

Files changed (1): app.py +183 -173
app.py CHANGED
@@ -1,11 +1,10 @@
 import streamlit as st
 import pandas as pd
 import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from peft import PeftModel
 from text_processing import TextProcessor
 import gc
-import time
 from pathlib import Path

 # Configure page
@@ -22,40 +21,31 @@ if 'summaries' not in st.session_state:
     st.session_state.summaries = None
 if 'text_processor' not in st.session_state:
     st.session_state.text_processor = None
-
-def manage_resources():
-    """Clear memory and ensure resources are available"""
-    # Force garbage collection
-    gc.collect()
-
-    # Clear CUDA cache if available
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-    # Set torch to use CPU
-    torch.set_num_threads(8)  # Use half of available CPU threads for each model

 def load_model(model_type):
-    """Load appropriate model based on type with resource management"""
-    manage_resources()
-
     try:
         if model_type == "summarize":
             base_model = AutoModelForSeq2SeqLM.from_pretrained(
                 "facebook/bart-large-cnn",
                 cache_dir="./models",
-                device_map=None,  # Explicitly set to None for CPU
                 torch_dtype=torch.float32
-            ).to("cpu")  # Force CPU
-
             model = PeftModel.from_pretrained(
                 base_model,
                 "pendar02/results",
-                device_map=None,  # Explicitly set to None for CPU
-                torch_dtype=torch.float32,
-                is_trainable=False  # Set to inference mode
-            ).to("cpu")  # Force CPU
-
             tokenizer = AutoTokenizer.from_pretrained(
                 "facebook/bart-large-cnn",
                 cache_dir="./models"
@@ -64,36 +54,43 @@ def load_model(model_type):
             base_model = AutoModelForSeq2SeqLM.from_pretrained(
                 "GanjinZero/biobart-base",
                 cache_dir="./models",
-                device_map=None,  # Explicitly set to None for CPU
                 torch_dtype=torch.float32
-            ).to("cpu")  # Force CPU
-
             model = PeftModel.from_pretrained(
                 base_model,
                 "pendar02/biobart-finetune",
-                device_map=None,  # Explicitly set to None for CPU
-                torch_dtype=torch.float32,
-                is_trainable=False  # Set to inference mode
-            ).to("cpu")  # Force CPU
-
             tokenizer = AutoTokenizer.from_pretrained(
                 "GanjinZero/biobart-base",
                 cache_dir="./models"
             )

-        model.eval()  # Set to evaluation mode
         return model, tokenizer
     except Exception as e:
         st.error(f"Error loading model: {str(e)}")
         raise

 @st.cache_data
 def process_excel(uploaded_file):
     """Process uploaded Excel file"""
     try:
         df = pd.read_excel(uploaded_file)
         required_columns = ['Abstract', 'Article Title', 'Authors',
-                            'Source Title', 'Publication Year', 'DOI']

         # Check required columns
         missing_columns = [col for col in required_columns if col not in df.columns]
@@ -127,9 +124,18 @@ def generate_summary(text, model, tokenizer):
     if not isinstance(text, str) or not text.strip():
         return "No abstract available to summarize."

     # Preprocess the text first
     formatted_text = preprocess_text(text)

     inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)

     with torch.no_grad():
@@ -137,15 +143,22 @@ def generate_summary(text, model, tokenizer):
             **{
                 "input_ids": inputs["input_ids"],
                 "attention_mask": inputs["attention_mask"],
-                "max_length": 150,
-                "min_length": 50,
                 "num_beams": 4,
                 "length_penalty": 2.0,
-                "early_stopping": True
             }
         )

-    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

 def generate_focused_summary(question, abstracts, model, tokenizer):
     """Generate focused summary based on question"""
@@ -173,11 +186,6 @@ def generate_focused_summary(question, abstracts, model, tokenizer):
 def main():
     st.title("🔬 Biomedical Papers Analysis")

-    # Initialize text processor if not already done
-    if st.session_state.text_processor is None:
-        with st.spinner("Loading NLP models..."):
-            st.session_state.text_processor = TextProcessor()
-
     # File upload section
     uploaded_file = st.file_uploader(
         "Upload Excel file containing papers",
@@ -185,6 +193,10 @@ def main():
         help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
     )

     if uploaded_file is not None:
         # Process Excel file
         if st.session_state.processed_data is None:
@@ -192,146 +204,144 @@ def main():
             df = process_excel(uploaded_file)
             if df is not None:
                 st.session_state.processed_data = df.dropna(subset=["Abstract"])
-
-    if st.session_state.processed_data is not None:
-        df = st.session_state.processed_data
-        st.write(f"📊 Loaded {len(df)} papers")

-        # Question input before the unified generate button
-        st.header("❓ Question-focused Summary (Optional)")
-        question = st.text_input("Enter your research question (optional):")
-
-        # Unified generate button
-        if st.button("Generate Analysis"):
-            try:
-                # Step 1: Generate Individual Summaries
                 if st.session_state.summaries is None:
-                    with st.spinner("Generating individual summaries..."):
-                        model, tokenizer = load_model("summarize")
-
-                        progress_text = st.empty()
-                        progress_bar = st.progress(0)
-
-                        # Create a table for live updates
-                        summary_table = st.empty()
-                        summaries = []
-                        table_data = []
-
-                        for i, (_, row) in enumerate(df.iterrows()):
-                            progress_text.text(f"Processing paper {i+1} of {len(df)}")
-                            progress_bar.progress((i + 1) / len(df))

-                            summary = generate_summary(row['Abstract'], model, tokenizer)
-                            summaries.append(summary)

-                            # Update table data
-                            table_data.append({
-                                "PAPER": f"{row['Article Title']}\n{row['Authors']}\nDOI: {row['DOI']}",
-                                "SUMMARY": summary
-                            })
-                            summary_table.dataframe(
-                                pd.DataFrame(table_data),
-                                column_config={
-                                    "PAPER": st.column_config.TextColumn("PAPER", width=300),
-                                    "SUMMARY": st.column_config.TextColumn("SUMMARY", width="medium")
-                                },
-                                hide_index=True
-                            )
-
-                        st.session_state.summaries = summaries
-
-                        # Clear memory after individual summaries
-                        del model
-                        del tokenizer
-                        torch.cuda.empty_cache()
-                        gc.collect()

-                # Step 2: Generate Question-Focused Summary (only if question is provided)
                 if question.strip():
-                    with st.spinner("Generating question-focused summary..."):
-                        # Clear memory before question processing
-                        torch.cuda.empty_cache()
-                        gc.collect()
-
-                        # Find relevant abstracts
-                        results = st.session_state.text_processor.find_most_relevant_abstracts(
-                            question,
-                            df['Abstract'].tolist(),
-                            top_k=5
-                        )
-
-                        # Load question model
-                        model, tokenizer = load_model("question_focused")
-
-                        relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
-                        focused_summary = generate_focused_summary(
-                            question,
-                            relevant_abstracts,
-                            model,
-                            tokenizer
-                        )

-                        st.subheader("Question-Focused Summary")
-                        st.write(focused_summary)

                         st.subheader("Most Relevant Papers")
-                        relevant_papers = df.iloc[results['top_indices']][
                             ['Article Title', 'Authors', 'Publication Year', 'DOI']
-                        ]
-                        relevant_papers['Relevance Score'] = results['scores']
                         relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
-
-                        st.dataframe(
-                            relevant_papers,
-                            column_config={
-                                'Publication Year': st.column_config.NumberColumn('Year', format="%d"),
-                                'Relevance Score': st.column_config.NumberColumn('Relevance', format="%.3f")
-                            },
-                            hide_index=True
-                        )
-
-                        # Clear memory after question processing
-                        del model
-                        del tokenizer
-                        torch.cuda.empty_cache()
-                        gc.collect()
-
-            except Exception as e:
-                st.error(f"Error in analysis: {str(e)}")
-
-        # Display sorted summaries if they exist
-        if st.session_state.summaries is not None:
-            st.header("📝 Individual Paper Summaries")
-            col1, col2 = st.columns([2, 1])
-            with col1:
-                sort_by = st.selectbox(
-                    "Sort By",
-                    ["Article Title", "Publication Year"],
-                    key="sort_summaries"
-                )
-            with col2:
-                ascending = st.checkbox("Ascending order", True, key="sort_order")
-
-            # Create display dataframe
-            display_df = df.copy()
-            display_df['PAPER'] = display_df.apply(
-                lambda x: f"{x['Article Title']}\n{x['Authors']}\nDOI: {x['DOI']}",
-                axis=1
-            )
-            display_df['SUMMARY'] = st.session_state.summaries
-
-            # Sort the dataframe
-            sorted_df = display_df.sort_values(by=sort_by, ascending=ascending)
-
-            # Display the table
-            st.dataframe(
-                sorted_df[['PAPER', 'SUMMARY']],
-                column_config={
-                    "PAPER": st.column_config.TextColumn("PAPER", width=300),
-                    "SUMMARY": st.column_config.TextColumn("SUMMARY", width="medium")
-                },
-                hide_index=True
-            )

 if __name__ == "__main__":
     main()
 
@@ -1,11 +1,10 @@
 import streamlit as st
 import pandas as pd
 import torch
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from peft import PeftModel
 from text_processing import TextProcessor
 import gc
 from pathlib import Path

 # Configure page
 
@@ -22,40 +21,31 @@ if 'summaries' not in st.session_state:
     st.session_state.summaries = None
 if 'text_processor' not in st.session_state:
     st.session_state.text_processor = None
+if 'processing_started' not in st.session_state:
+    st.session_state.processing_started = False
+if 'focused_summary_generated' not in st.session_state:
+    st.session_state.focused_summary_generated = False

 def load_model(model_type):
+    """Load appropriate model based on type with proper memory management"""
     try:
+        # Clear any existing cached data
+        torch.cuda.empty_cache()
+        gc.collect()
+
         if model_type == "summarize":
             base_model = AutoModelForSeq2SeqLM.from_pretrained(
                 "facebook/bart-large-cnn",
                 cache_dir="./models",
+                low_cpu_mem_usage=True,
                 torch_dtype=torch.float32
+            )
             model = PeftModel.from_pretrained(
                 base_model,
                 "pendar02/results",
+                device_map="auto",
+                torch_dtype=torch.float32
+            )
             tokenizer = AutoTokenizer.from_pretrained(
                 "facebook/bart-large-cnn",
                 cache_dir="./models"
@@ -64,36 +54,43 @@ def load_model(model_type):
             base_model = AutoModelForSeq2SeqLM.from_pretrained(
                 "GanjinZero/biobart-base",
                 cache_dir="./models",
+                low_cpu_mem_usage=True,
                 torch_dtype=torch.float32
+            )
             model = PeftModel.from_pretrained(
                 base_model,
                 "pendar02/biobart-finetune",
+                device_map="auto",
+                torch_dtype=torch.float32
+            )
             tokenizer = AutoTokenizer.from_pretrained(
                 "GanjinZero/biobart-base",
                 cache_dir="./models"
             )

+        model.eval()
         return model, tokenizer
     except Exception as e:
         st.error(f"Error loading model: {str(e)}")
         raise

+def cleanup_model(model, tokenizer):
+    """Properly cleanup model resources"""
+    try:
+        del model
+        del tokenizer
+        torch.cuda.empty_cache()
+        gc.collect()
+    except Exception:
+        pass
+
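Taken together, load_model and cleanup_model bracket each model's lifetime. A minimal smoke test of that pairing, runnable outside Streamlit; the model type comes from this file, while the sample text and generation settings below are illustrative assumptions:

# Hypothetical standalone check of the load/cleanup lifecycle (not part of app.py)
model, tokenizer = load_model("summarize")
try:
    batch = tokenizer("Sample abstract text.", return_tensors="pt")
    ids = model.generate(**batch, max_length=60, num_beams=4)
    print(tokenizer.decode(ids[0], skip_special_tokens=True))
finally:
    cleanup_model(model, tokenizer)  # drops references, then empties caches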
 @st.cache_data
 def process_excel(uploaded_file):
     """Process uploaded Excel file"""
     try:
         df = pd.read_excel(uploaded_file)
         required_columns = ['Abstract', 'Article Title', 'Authors',
+                            'Source Title', 'Publication Year', 'DOI']

         # Check required columns
         missing_columns = [col for col in required_columns if col not in df.columns]
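The column check itself is plain pandas; with a toy frame (hypothetical data, not from the app) it behaves like this:

import pandas as pd

toy = pd.DataFrame({"Abstract": ["..."], "Article Title": ["T1"], "Authors": ["A. B."]})
required = ['Abstract', 'Article Title', 'Authors',
            'Source Title', 'Publication Year', 'DOI']
missing = [col for col in required if col not in toy.columns]
print(missing)  # ['Source Title', 'Publication Year', 'DOI']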
 
@@ -127,9 +124,18 @@ def generate_summary(text, model, tokenizer):
     if not isinstance(text, str) or not text.strip():
         return "No abstract available to summarize."

+    # Check if abstract is too short
+    word_count = len(text.split())
+    if word_count < 50:  # Threshold for "short" abstracts
+        return text  # Return original text for very short abstracts
+
     # Preprocess the text first
     formatted_text = preprocess_text(text)

+    # Adjust generation parameters based on input length
+    max_length = min(150, word_count + 50)  # Dynamic max length
+    min_length = min(50, word_count)  # Dynamic min length
+
     inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)

     with torch.no_grad():
 
@@ -137,15 +143,22 @@ def generate_summary(text, model, tokenizer):
             **{
                 "input_ids": inputs["input_ids"],
                 "attention_mask": inputs["attention_mask"],
+                "max_length": max_length,
+                "min_length": min_length,
                 "num_beams": 4,
                 "length_penalty": 2.0,
+                "early_stopping": True,
+                "no_repeat_ngram_size": 3  # Prevent repetition of phrases
             }
         )

+    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+    # Post-process summary
+    if summary.lower() == text.lower() or len(summary.split()) / word_count > 0.9:
+        return text  # Return original if summary is too similar
+
+    return summary
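Worked through on a few assumed abstract lengths, the dynamic length rule above gives:

# 30-word abstract: below the 50-word threshold, returned unchanged
# 60-word abstract: max_length = min(150, 60 + 50) = 110, min_length = min(50, 60) = 50
# 500-word abstract: max_length capped at 150, min_length capped at 50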
 
 def generate_focused_summary(question, abstracts, model, tokenizer):
     """Generate focused summary based on question"""
 
@@ -173,11 +186,6 @@ def generate_focused_summary(question, abstracts, model, tokenizer):
 def main():
     st.title("🔬 Biomedical Papers Analysis")

     # File upload section
     uploaded_file = st.file_uploader(
         "Upload Excel file containing papers",
@@ -185,6 +193,10 @@ def main():
         help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
     )

+    # Question input - moved up but hidden initially
+    question_container = st.empty()
+    question = ""
+
     if uploaded_file is not None:
         # Process Excel file
         if st.session_state.processed_data is None:
@@ -192,146 +204,144 @@ def main():
             df = process_excel(uploaded_file)
             if df is not None:
                 st.session_state.processed_data = df.dropna(subset=["Abstract"])

+    if st.session_state.processed_data is not None:
+        df = st.session_state.processed_data
+        st.write(f"📊 Loaded {len(df)} papers with abstracts")
+
+        # Get question before processing
+        with question_container:
+            question = st.text_input(
+                "Enter your research question (optional):",
+                help="If provided, a question-focused summary will be generated after individual summaries"
+            )
+
+        # Single button for both processes
+        if not st.session_state.get('processing_started', False):
+            if st.button("Start Analysis"):
+                st.session_state.processing_started = True
+
+        # Show processing status and results
+        if st.session_state.get('processing_started', False):
+            # Individual Summaries Section
+            st.header("📝 Individual Paper Summaries")
+
             if st.session_state.summaries is None:
+                try:
+                    with st.spinner("Generating summaries..."):
+                        # Load summarization model
+                        model, tokenizer = load_model("summarize")

+                        # Process abstracts with real-time updates
+                        summaries = []
+                        progress_bar = st.progress(0)
+                        summary_display = st.empty()

+                        for i, (_, row) in enumerate(df.iterrows()):
+                            summary = generate_summary(row['Abstract'], model, tokenizer)
+                            summaries.append(summary)
+
+                            # Update progress and show current summary
+                            progress = (i + 1) / len(df)
+                            progress_bar.progress(progress)
+                            summary_display.write(f"Processing paper {i+1}/{len(df)}:\n{row['Article Title']}")
+
+                        st.session_state.summaries = summaries
+
+                        # Cleanup first model
+                        cleanup_model(model, tokenizer)
+
+                except Exception as e:
+                    st.error(f"Error generating summaries: {str(e)}")

+            # Display summaries with improved sorting
+            if st.session_state.summaries is not None:
+                col1, col2 = st.columns(2)
+                with col1:
+                    sort_options = ['Article Title', 'Authors', 'Publication Year', 'Source Title']
+                    sort_column = st.selectbox("Sort by:", sort_options)
+                with col2:
+                    ascending = st.checkbox("Ascending order", True)
+
+                # Create display dataframe with formatted year
+                display_df = df.copy()
+                display_df['Summary'] = st.session_state.summaries
+                display_df['Publication Year'] = display_df['Publication Year'].astype(int)
+                sorted_df = display_df.sort_values(by=sort_column, ascending=ascending)
+
+                # Apply custom formatting
+                st.markdown("""
+                    <style>
+                    .stDataFrame {
+                        font-size: 16px;
+                    }
+                    .stDataFrame td {
+                        white-space: normal !important;
+                        padding: 8px !important;
+                    }
+                    </style>
+                """, unsafe_allow_html=True)
+
+                st.dataframe(
+                    sorted_df[['Article Title', 'Authors', 'Source Title',
+                               'Publication Year', 'DOI', 'Summary']],
+                    hide_index=True
+                )
+
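The processing_started flag is a standard Streamlit gating idiom: the button click flips a session flag, and every later rerun re-renders stored results instead of re-running the work. A stripped-down sketch of the same pattern (hypothetical standalone script, not from this commit):

import streamlit as st

if 'started' not in st.session_state:
    st.session_state.started = False

if not st.session_state.started and st.button("Start"):
    st.session_state.started = True  # survives subsequent reruns

if st.session_state.started:
    # guarded work: compute once, stash in session_state, re-render afterwards
    if 'result' not in st.session_state:
        st.session_state.result = "expensive result"
    st.write(st.session_state.result)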
+            # Question-focused Summary Section (only if question provided)
             if question.strip():
+                st.header("❓ Question-focused Summary")
+
+                if not st.session_state.get('focused_summary_generated', False):
+                    try:
+                        with st.spinner("Analyzing relevant papers..."):
+                            # Initialize text processor if needed
+                            if st.session_state.text_processor is None:
+                                st.session_state.text_processor = TextProcessor()
+
+                            # Find relevant abstracts
+                            results = st.session_state.text_processor.find_most_relevant_abstracts(
+                                question,
+                                df['Abstract'].tolist(),
+                                top_k=5
+                            )
+
+                            # Load question-focused model
+                            model, tokenizer = load_model("question_focused")
+
+                            # Generate focused summary
+                            relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
+                            focused_summary = generate_focused_summary(
+                                question,
+                                relevant_abstracts,
+                                model,
+                                tokenizer
+                            )
+
+                            # Store results
+                            st.session_state.focused_summary = focused_summary
+                            st.session_state.relevant_papers = df.iloc[results['top_indices']]
+                            st.session_state.relevance_scores = results['scores']
+                            st.session_state.focused_summary_generated = True
+
+                            # Cleanup second model
+                            cleanup_model(model, tokenizer)

+                    except Exception as e:
+                        st.error(f"Error generating focused summary: {str(e)}")
+
+                # Display focused summary results
+                if st.session_state.get('focused_summary_generated', False):
+                    st.subheader("Summary")
+                    st.write(st.session_state.focused_summary)

                     st.subheader("Most Relevant Papers")
+                    relevant_papers = st.session_state.relevant_papers[
                         ['Article Title', 'Authors', 'Publication Year', 'DOI']
+                    ].copy()
+                    relevant_papers['Relevance Score'] = st.session_state.relevance_scores
                     relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
+                    st.dataframe(relevant_papers, hide_index=True)

 if __name__ == "__main__":
     main()
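find_most_relevant_abstracts is defined in text_processing.py, which this commit does not touch; judging only from how its result is consumed above ('top_indices' passed to df.iloc, 'scores' attached as a column), the return value presumably resembles:

# Inferred shape only -- the actual implementation lives in text_processing.py
results = {
    "top_indices": [12, 3, 41, 7, 28],          # positions into the abstracts list
    "scores": [0.91, 0.87, 0.82, 0.79, 0.75],   # one relevance score per index
}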