pendar02 committed
Commit 1229bf2 · verified · 1 Parent(s): 2c2de78

Update app.py

Files changed (1)
app.py +122 -77
app.py CHANGED
@@ -23,20 +23,43 @@ if 'summaries' not in st.session_state:
 if 'text_processor' not in st.session_state:
     st.session_state.text_processor = None
 
+def manage_resources():
+    """Clear memory and ensure resources are available"""
+    # Force garbage collection
+    gc.collect()
+
+    # Clear CUDA cache if available
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # Set torch to use CPU
+    torch.set_num_threads(8)  # Use half of available CPU threads for each model
+
 def load_model(model_type):
-    """Load appropriate model based on type"""
+    """Load appropriate model based on type with resource management"""
+    manage_resources()
+
     try:
+        # Set lower precision to reduce memory usage
+        torch_dtype = torch.float32
+        if torch.cuda.is_available():
+            device = "cuda"
+        else:
+            device = "cpu"
+            torch_dtype = torch.float32  # Use float32 for CPU
+
         if model_type == "summarize":
             base_model = AutoModelForSeq2SeqLM.from_pretrained(
                 "facebook/bart-large-cnn",
-                cache_dir="./models"
+                cache_dir="./models",
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True
             )
             model = PeftModel.from_pretrained(
                 base_model,
                 "pendar02/results",
-                load_in_8bit=False,
-                device_map="auto",
-                torch_dtype=torch.float32
+                device_map=device,
+                torch_dtype=torch_dtype
             )
             tokenizer = AutoTokenizer.from_pretrained(
                 "facebook/bart-large-cnn",
@@ -45,14 +68,15 @@ def load_model(model_type):
         else: # question_focused
             base_model = AutoModelForSeq2SeqLM.from_pretrained(
                 "GanjinZero/biobart-base",
-                cache_dir="./models"
+                cache_dir="./models",
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True
             )
             model = PeftModel.from_pretrained(
                 base_model,
                 "pendar02/biobart-finetune",
-                load_in_8bit=False,
-                device_map="auto",
-                torch_dtype=torch.float32
+                device_map=device,
+                torch_dtype=torch_dtype
             )
             tokenizer = AutoTokenizer.from_pretrained(
                 "GanjinZero/biobart-base",
@@ -148,23 +172,10 @@ def generate_focused_summary(question, abstracts, model, tokenizer):
 
     return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
+
 def main():
     st.title("🔬 Biomedical Papers Analysis")
 
-    # Sidebar
-    st.sidebar.header("About")
-    st.sidebar.info(
-        "This app analyzes biomedical research papers. Upload an Excel file "
-        "containing paper details and abstracts to:"
-        "\n- Generate individual summaries"
-        "\n- Get question-focused insights"
-    )
-
-    # Initialize text processor if not already done
-    if st.session_state.text_processor is None:
-        with st.spinner("Loading NLP models..."):
-            st.session_state.text_processor = TextProcessor()
-
     # File upload section
     uploaded_file = st.file_uploader(
         "Upload Excel file containing papers",
@@ -179,74 +190,66 @@ def main():
         df = process_excel(uploaded_file)
         if df is not None:
             st.session_state.processed_data = df.dropna(subset=["Abstract"])
+
+    if st.session_state.processed_data is not None:
+        df = st.session_state.processed_data
+        st.write(f"📊 Loaded {len(df)} papers")
 
-    if st.session_state.processed_data is not None:
-        df = st.session_state.processed_data
-        st.write(f"📊 Loaded {len(df)} papers with abstracts")
-
-        # Individual Summaries Section
-        st.header("📝 Individual Paper Summaries")
-
-        if st.session_state.summaries is None and st.button("Generate Individual Summaries"):
-            try:
-                with st.spinner("Generating summaries..."):
-                    # Load summarization model
+        # Individual Summaries Section
+        st.header("📝 Individual Paper Summaries")
+
+        # Question input before the unified generate button
+        st.header("❓ Question-focused Summary (Optional)")
+        question = st.text_input("Enter your research question (optional):")
+
+        # Unified generate button
+        if st.button("Generate Analysis"):
+            try:
+                # Step 1: Generate Individual Summaries
+                if st.session_state.summaries is None:
+                    with st.spinner("Generating individual summaries..."):
                         model, tokenizer = load_model("summarize")
 
-                    # Process abstracts
+                        progress_text = st.empty()
                         progress_bar = st.progress(0)
-                    summaries = []
+                        summary_display = st.container()
 
-                    for i, abstract in enumerate(df['Abstract']):
-                        summary = generate_summary(abstract, model, tokenizer)
-                        summaries.append(summary)
+                        summaries = []
+                        for i, (_, row) in enumerate(df.iterrows()):
+                            progress_text.text(f"Processing paper {i+1} of {len(df)}")
                             progress_bar.progress((i + 1) / len(df))
+
+                            summary = generate_summary(row['Abstract'], model, tokenizer)
+                            summaries.append(summary)
+
+                            with summary_display:
+                                st.write(f"**Paper {i+1}:** {row['Article Title']}")
+                                st.write(summary)
+                                st.divider()
 
                         st.session_state.summaries = summaries
 
-                    # Clear GPU memory
+                        # Clear memory after individual summaries
                         del model
                         del tokenizer
                         torch.cuda.empty_cache()
                         gc.collect()
 
-            except Exception as e:
-                st.error(f"Error generating summaries: {str(e)}")
-
-        if st.session_state.summaries is not None:
-            # Display summaries with sorting options
-            col1, col2 = st.columns(2)
-            with col1:
-                sort_column = st.selectbox("Sort by:", df.columns)
-            with col2:
-                ascending = st.checkbox("Ascending order", True)
-
-            # Create display dataframe
-            display_df = df.copy()
-            display_df['Summary'] = st.session_state.summaries
-            sorted_df = display_df.sort_values(by=sort_column, ascending=ascending)
-
-            # Show interactive table
-            st.dataframe(sorted_df, hide_index=True)
-
-        # Question-focused Summary Section
-        st.header("❓ Question-focused Summary")
-        question = st.text_input("Enter your research question:")
-
-        if question and st.button("Generate Focused Summary"):
-            try:
-                with st.spinner("Analyzing relevant papers..."):
-                    # Find relevant abstracts
+                # Step 2: Generate Question-Focused Summary (only if question is provided)
+                if question.strip():
+                    with st.spinner("Generating question-focused summary..."):
+                        # Clear memory before question processing
+                        torch.cuda.empty_cache()
+                        gc.collect()
+
                         results = st.session_state.text_processor.find_most_relevant_abstracts(
                             question,
                             df['Abstract'].tolist(),
                             top_k=5
                         )
 
-                    # Load question-focused model
                         model, tokenizer = load_model("question_focused")
 
-                    # Get relevant abstracts and generate summary
                         relevant_abstracts = df['Abstract'].iloc[results['top_indices']].tolist()
                         focused_summary = generate_focused_summary(
                             question,
@@ -255,26 +258,68 @@
                             tokenizer
                         )
 
-                    # Display results
-                    st.subheader("Summary")
+                        st.subheader("Question-Focused Summary")
                         st.write(focused_summary)
 
-                    # Show relevant papers
                         st.subheader("Most Relevant Papers")
                         relevant_papers = df.iloc[results['top_indices']][
                             ['Article Title', 'Authors', 'Publication Year', 'DOI']
                         ]
                         relevant_papers['Relevance Score'] = results['scores']
-                    st.dataframe(relevant_papers, hide_index=True)
+                        relevant_papers['Publication Year'] = relevant_papers['Publication Year'].astype(int)
 
-                    # Clear GPU memory
+                        st.dataframe(
+                            relevant_papers,
+                            column_config={
+                                'Publication Year': st.column_config.NumberColumn('Year', format="%d"),
+                                'Relevance Score': st.column_config.NumberColumn('Relevance', format="%.3f")
+                            },
+                            hide_index=True
+                        )
+
+                        # Clear memory after question processing
                         del model
                         del tokenizer
                         torch.cuda.empty_cache()
                         gc.collect()
-
-            except Exception as e:
-                st.error(f"Error generating focused summary: {str(e)}")
+
+            except Exception as e:
+                st.error(f"Error in analysis: {str(e)}")
+
+            # Display sorted summaries if they exist
+            if st.session_state.summaries is not None:
+                st.subheader("All Paper Summaries")
+                sort_options = {
+                    'Article Title': 'Article Title',
+                    'Authors': 'Authors',
+                    'Publication Year': 'Publication Year',
+                    'Source Title': 'Source Title'
+                }
+
+                col1, col2 = st.columns(2)
+                with col1:
+                    sort_column = st.selectbox("Sort by:", list(sort_options.keys()))
+                with col2:
+                    ascending = st.checkbox("Ascending order", True)
+
+                display_df = df.copy()
+                display_df['Summary'] = st.session_state.summaries
+                display_df['Publication Year'] = display_df['Publication Year'].astype(int)
+                sorted_df = display_df.sort_values(by=sort_options[sort_column], ascending=ascending)
+
+                st.dataframe(
+                    sorted_df[['Article Title', 'Authors', 'Source Title',
+                               'Publication Year', 'DOI', 'Summary']],
+                    column_config={
+                        'Article Title': st.column_config.TextColumn('Article Title', width='medium'),
+                        'Authors': st.column_config.TextColumn('Authors', width='medium'),
+                        'Source Title': st.column_config.TextColumn('Source Title', width='medium'),
+                        'Publication Year': st.column_config.NumberColumn('Year', format="%d"),
+                        'DOI': st.column_config.TextColumn('DOI', width='small'),
+                        'Summary': st.column_config.TextColumn('Summary', width='large'),
+                    },
+                    hide_index=True
+                )
 
 if __name__ == "__main__":
     main()
 
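The loading pattern the reworked load_model converges on is the standard peft flow: load the base seq2seq checkpoint, then attach the fine-tuned adapter from the Hub, and drop all references once generation is done. A minimal standalone sketch of that flow using the model IDs from the diff (the load_summarizer wrapper and the CPU-only default are illustrative assumptions, not part of the commit):

import gc

import torch
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def load_summarizer(device: str = "cpu"):
    # Load the base checkpoint first; low_cpu_mem_usage avoids materialising
    # a second full copy of the weights while loading.
    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/bart-large-cnn",
        torch_dtype=torch.float32,  # float32 keeps the CPU path simple
        low_cpu_mem_usage=True,
    )
    # Attach the fine-tuned PEFT adapter on top of the base weights
    # (same call shape as in the diff).
    model = PeftModel.from_pretrained(
        base_model,
        "pendar02/results",
        device_map=device,
        torch_dtype=torch.float32,
    )
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
    return model, tokenizer


model, tokenizer = load_summarizer()
# ... run summarization ...
del model, tokenizer  # drop references before reclaiming memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()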