pendar02 committed
Commit 7bd75d7 · verified · 1 Parent(s): 06d0182

Update app.py

Files changed (1): app.py +141 -167
app.py CHANGED
@@ -27,61 +27,17 @@ if 'processing_started' not in st.session_state:
 if 'focused_summary_generated' not in st.session_state:
     st.session_state.focused_summary_generated = False
 
-def preprocess_text(text):
-    """Preprocess text for summarization"""
-    if not isinstance(text, str) or not text.strip():
-        return text
-
-    # Clean up whitespace
-    text = re.sub(r'\s+', ' ', text)
-    text = text.strip()
-
-    # Fix common formatting issues
-    text = re.sub(r'(\d+)\s*%', r'\1%', text)  # Fix percentage format
-    text = re.sub(r'\(\s*([Nn])\s*=\s*(\d+)\s*\)', r'(n=\2)', text)  # Fix sample size format
-    text = re.sub(r'([Pp])\s*([<>])\s*(\d)', r'\1\2\3', text)  # Fix p-value format
-
-    return text
-
-def verify_facts(summary, original_text):
-    """Verify key facts between summary and original text"""
-    # Extract numbers and percentages
-    def extract_numbers(text):
-        return set(re.findall(r'(\d+\.?\d*)%?', text))
-
-    # Extract relationships
-    def extract_relationships(text):
-        patterns = [
-            r'associated with', r'predicted', r'correlated',
-            r'increased', r'decreased', r'significant'
-        ]
-        found = []
-        for pattern in patterns:
-            if re.search(pattern, text.lower()):
-                found.append(pattern)
-        return set(found)
-
-    # Get facts from both texts
-    original_numbers = extract_numbers(original_text)
-    summary_numbers = extract_numbers(summary)
-    original_relations = extract_relationships(original_text)
-    summary_relations = extract_relationships(summary)
-
-    return {
-        'is_valid': summary_numbers.issubset(original_numbers) and
-                    summary_relations.issubset(original_relations),
-        'missing_numbers': original_numbers - summary_numbers,
-        'missing_relations': original_relations - summary_relations
-    }
-
 def load_model(model_type):
     """Load appropriate model based on type with proper memory management"""
     try:
+        # Clear any existing cached data
         gc.collect()
         torch.cuda.empty_cache()
-        device = "cpu"
+
+        device = "cpu"  # Force CPU usage
 
         if model_type == "summarize":
+            # Load the new fine-tuned model directly
             model = AutoModelForSeq2SeqLM.from_pretrained(
                 "pendar02/bart-large-pubmedd",
                 cache_dir="./models",
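Note on the hunk above: `torch.cuda.empty_cache()` is harmless when CUDA has never been touched, but since this commit forces `device = "cpu"` anyway, guarding the call makes the intent explicit and avoids surprises on CPU-only builds. A minimal sketch of that cleanup pattern (the `free_memory` name is illustrative, not part of this commit):

```python
import gc

import torch

def free_memory():
    """Collect Python garbage, then release cached CUDA blocks if a GPU exists."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```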
@@ -92,7 +48,7 @@ def load_model(model_type):
                 "pendar02/bart-large-pubmedd",
                 cache_dir="./models"
             )
-        else:
+        else:  # question_focused
             base_model = AutoModelForSeq2SeqLM.from_pretrained(
                 "GanjinZero/biobart-base",
                 cache_dir="./models",
@@ -117,6 +73,7 @@
         raise
 
 def cleanup_model(model, tokenizer):
+    """Properly cleanup model resources"""
     try:
         del model
         del tokenizer
@@ -125,12 +82,15 @@ def cleanup_model(model, tokenizer):
     except Exception:
         pass
 
+@st.cache_data
 def process_excel(uploaded_file):
+    """Process uploaded Excel file"""
     try:
         df = pd.read_excel(uploaded_file)
         required_columns = ['Abstract', 'Article Title', 'Authors',
                             'Source Title', 'Publication Year', 'DOI', 'Times Cited, All Databases']
 
+        # Check required columns
         missing_columns = [col for col in required_columns if col not in df.columns]
         if missing_columns:
             st.error(f"Missing required columns: {', '.join(missing_columns)}")
@@ -141,111 +101,119 @@ def process_excel(uploaded_file):
         st.error(f"Error processing file: {str(e)}")
         return None
 
+def preprocess_text(text):
+    """Preprocess text to add appropriate formatting before summarization"""
+    if not isinstance(text, str) or not text.strip():
+        return text
+
+    # Split text into sentences (basic implementation)
+    sentences = [s.strip() for s in text.replace('. ', '.\n').split('\n')]
+
+    # Remove empty sentences
+    sentences = [s for s in sentences if s]
+
+    # Join with proper line breaks
+    formatted_text = '\n'.join(sentences)
+
+    return formatted_text
+
+def post_process_summary(summary):
+    """Clean up and improve summary coherence"""
+    if not summary:
+        return summary
+
+    # Split into sentences
+    sentences = [s.strip() for s in summary.split('.')]
+    sentences = [s for s in sentences if s]  # Remove empty sentences
+
+    # Fix common issues
+    processed_sentences = []
+    for i, sentence in enumerate(sentences):
+        # Remove redundant words/phrases
+        sentence = sentence.replace(" and and ", " and ")
+        sentence = sentence.replace("appointment and appointment", "appointment")
+
+        # Fix common grammatical issues
+        sentence = sentence.replace("Cancers distress", "Cancer distress")
+        sentence = sentence.replace("  ", " ")  # Remove double spaces
+
+        # Capitalize first letter of each sentence
+        sentence = sentence.capitalize()
+
+        # Add to processed sentences if not empty
+        if sentence.strip():
+            processed_sentences.append(sentence)
+
+    # Join sentences with proper spacing and punctuation
+    cleaned_summary = '. '.join(processed_sentences)
+    if cleaned_summary and not cleaned_summary.endswith('.'):
+        cleaned_summary += '.'
+
+    return cleaned_summary
+
 def improve_summary_generation(text, model, tokenizer):
     """Generate improved summary with better prompt and validation"""
     if not isinstance(text, str) or not text.strip():
         return "No abstract available to summarize."
 
-    try:
-        # Simplified prompt
-        formatted_text = (
-            "Summarize this biomedical abstract into four sections:\n"
-            "1. Background/Objectives: State the main purpose and population\n"
-            "2. Methods: Describe what was done\n"
-            "3. Key findings: Include ALL numerical results and statistical relationships\n"
-            "4. Conclusions: State main implications\n\n"
-            "Important: Preserve all numbers, measurements, and statistical findings.\n\n"
-            "Text: " + preprocess_text(text)
-        )
-
-        inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
-        inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-        # Single generation attempt with optimized parameters
+    # Add a more specific prompt
+    formatted_text = (
+        "Summarize this medical research paper following this structure exactly:\n"
+        "1. Background and objectives\n"
+        "2. Methods\n"
+        "3. Key findings with specific numbers/percentages\n"
+        "4. Main conclusions\n"
+        "Original text: " + preprocess_text(text)
+    )
+
+    # Adjust generation parameters
+    inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        summary_ids = model.generate(
+            **{
+                "input_ids": inputs["input_ids"],
+                "attention_mask": inputs["attention_mask"],
+                "max_length": 200,
+                "min_length": 50,
+                "num_beams": 5,
+                "length_penalty": 1.5,
+                "no_repeat_ngram_size": 3,
+                "temperature": 0.7,
+                "repetition_penalty": 1.5
+            }
+        )
+
+    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+    # Post-process the summary
+    processed_summary = post_process_summary(summary)
+
+    # Validate the summary
+    if not validate_summary(processed_summary, text):
+        # If validation fails, try one more time with different parameters
         with torch.no_grad():
             summary_ids = model.generate(
                 **{
                     "input_ids": inputs["input_ids"],
                     "attention_mask": inputs["attention_mask"],
-                    "max_length": 300,
-                    "min_length": 100,
-                    "num_beams": 5,
+                    "max_length": 200,
+                    "min_length": 50,
+                    "num_beams": 4,
                     "length_penalty": 2.0,
-                    "no_repeat_ngram_size": 3,
-                    "temperature": 0.3,
-                    "repetition_penalty": 2.5
+                    "no_repeat_ngram_size": 4,
+                    "temperature": 0.8,
+                    "repetition_penalty": 2.0
                 }
             )
-
         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-        if not summary:
-            return "Error: Could not generate summary."
+        processed_summary = post_process_summary(summary)
 
-        return post_process_summary(summary)
-
-    except Exception as e:
-        print(f"Error in summary generation: {str(e)}")
-        return "Error generating summary."
-
-def post_process_summary(summary):
-    """Enhanced post-processing focused on maintaining structure and removing artifacts"""
-    if not summary:
-        return summary
-
-    # Clean up section headers
-    header_mappings = {
-        r'(?i)background.*objectives?:?': 'Background and objectives:',
-        r'(?i)(materials?\s*and\s*)?methods?:?': 'Methods:',
-        r'(?i)(key\s*)?findings?:?|results?:?': 'Key findings:',
-        r'(?i)conclusions?:?': 'Conclusions:',
-        r'(?i)(study\s*)?aims?:?|goals?:?|purpose:?': '',
-        r'(?i)objectives?:?': '',
-        r'(?i)outcomes?:?': '',
-        r'(?i)discussion:?': ''
-    }
-
-    for pattern, replacement in header_mappings.items():
-        summary = re.sub(pattern, replacement, summary)
-
-    # Split into sections and clean
-    sections = re.split(r'(?i)(Background and objectives:|Methods:|Key findings:|Conclusions:)', summary)
-    sections = [s.strip() for s in sections if s.strip()]
-
-    # Reorganize sections
-    organized_sections = {
-        'Background and objectives': '',
-        'Methods': '',
-        'Key findings': '',
-        'Conclusions': ''
-    }
-
-    current_section = None
-    for item in sections:
-        if item in organized_sections:
-            current_section = item
-        elif current_section:
-            # Clean up content
-            content = re.sub(r'\s+', ' ', item)  # Fix spacing
-            content = re.sub(r'\.+', '.', content)  # Fix multiple periods
-            content = content.strip('.: ')  # Remove trailing periods and spaces
-            organized_sections[current_section] = content
-
-    # Build final summary
-    final_sections = []
-    for section, content in organized_sections.items():
-        if content:
-            final_sections.append(f"{section} {content}.")
-
-    return '\n\n'.join(final_sections)
+    return processed_summary
 
 def validate_summary(summary, original_text):
     """Validate summary content against original text"""
-    # Perform fact verification
-    verification = verify_facts(summary, original_text)
-
-    if not verification.get('is_valid', False):
-        return False
-
     # Check for age inconsistencies
     age_mentions = re.findall(r'(\d+\.?\d*)\s*years?', summary.lower())
     if len(age_mentions) > 1:  # Multiple age mentions
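One caveat on the generation parameters above: in Hugging Face `transformers`, `temperature` only applies when sampling is enabled, so with beam search and the default `do_sample=False` the `temperature=0.7` / `0.8` values have no effect (newer library versions warn about exactly this). If sampled output were actually intended, the call would need something like the following sketch, which reuses `model` and `inputs` from the hunk above (`top_p` is an illustrative addition, not from this commit):

```python
with torch.no_grad():
    summary_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=200,
        min_length=50,
        do_sample=True,       # required for temperature to take effect
        temperature=0.7,
        top_p=0.9,            # illustrative nucleus-sampling cutoff
        no_repeat_ngram_size=3,
        repetition_penalty=1.5,
    )
```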
@@ -267,40 +235,34 @@
 
 def generate_focused_summary(question, abstracts, model, tokenizer):
     """Generate focused summary based on question"""
-    try:
-        # Preprocess each abstract
-        formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts]
-        combined_input = f"Question: {question}\nSummarize these abstracts to answer the question:\n" + \
-                         "\n---\n".join(formatted_abstracts)
-
-        inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
-        inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-        with torch.no_grad():
-            summary_ids = model.generate(
-                **{
-                    "input_ids": inputs["input_ids"],
-                    "attention_mask": inputs["attention_mask"],
-                    "max_length": 300,
-                    "min_length": 100,
-                    "num_beams": 5,
-                    "length_penalty": 2.0,
-                    "temperature": 0.3,
-                    "repetition_penalty": 2.5
-                }
-            )
-
-        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-    except Exception as e:
-        print(f"Error in focused summary generation: {str(e)}")
-        return "Error generating focused summary."
+    # Preprocess each abstract
+    formatted_abstracts = [preprocess_text(abstract) for abstract in abstracts]
+    combined_input = f"Question: {question} Abstracts: " + " [SEP] ".join(formatted_abstracts)
+
+    inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        summary_ids = model.generate(
+            **{
+                "input_ids": inputs["input_ids"],
+                "attention_mask": inputs["attention_mask"],
+                "max_length": 200,
+                "min_length": 50,
+                "num_beams": 4,
+                "length_penalty": 2.0,
+                "early_stopping": True
+            }
+        )
+
+    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
 def create_filter_controls(df, sort_column):
     """Create appropriate filter controls based on the selected column"""
     filtered_df = df.copy()
 
     if sort_column == 'Publication Year':
+        # Year range slider
         year_min = int(df['Publication Year'].min())
         year_max = int(df['Publication Year'].max())
         col1, col2 = st.columns(2)
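In `generate_focused_summary` above, `" [SEP] ".join(...)` inserts the literal text " [SEP] " between abstracts; BART-family tokenizers such as BioBART use `</s>` as their separator, so `[SEP]` is tokenized as ordinary words rather than a special token. A sketch of how the tokenizer's own separator could be used instead (reusing `question`, `formatted_abstracts`, and `tokenizer` from the hunk):

```python
# Fall back to the literal marker if the tokenizer defines no separator.
sep = tokenizer.sep_token or "[SEP]"  # "</s>" for BART-family tokenizers
combined_input = f"Question: {question} Abstracts: " + f" {sep} ".join(formatted_abstracts)
```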
@@ -320,6 +282,7 @@
         ]
 
     elif sort_column == 'Authors':
+        # Multi-select for authors
         unique_authors = sorted(set(
             author.strip()
             for authors in df['Authors'].dropna()
@@ -337,6 +300,7 @@
         ]
 
     elif sort_column == 'Source Title':
+        # Multi-select for source titles
         unique_sources = sorted(df['Source Title'].unique())
         selected_sources = st.multiselect(
             'Select Sources',
@@ -345,7 +309,13 @@
         if selected_sources:
             filtered_df = filtered_df[filtered_df['Source Title'].isin(selected_sources)]
 
+    elif sort_column == 'Article Title':
+        # Only alphabetical sorting, no filtering
+        pass
+
+
     elif sort_column == 'Times Cited':
+        # Cited count range slider
         cited_min = int(df['Times Cited'].min())
         cited_max = int(df['Times Cited'].max())
         col1, col2 = st.columns(2)
@@ -369,16 +339,19 @@
 def main():
     st.title("🔬 Biomedical Papers Analysis")
 
+    # File upload section
     uploaded_file = st.file_uploader(
         "Upload Excel file containing papers",
         type=['xlsx', 'xls'],
         help="File must contain: Abstract, Article Title, Authors, Source Title, Publication Year, DOI"
     )
 
+    # Question input - moved up but hidden initially
     question_container = st.empty()
     question = ""
 
     if uploaded_file is not None:
+        # Process Excel file
        if st.session_state.processed_data is None:
            with st.spinner("Processing file..."):
                df = process_excel(uploaded_file)
@@ -389,16 +362,17 @@ def main():
         df = st.session_state.processed_data
         st.write(f"📊 Loaded {len(df)} papers with abstracts")
 
+        # Get question before processing
         with question_container:
             question = st.text_input(
                 "Enter your research question (optional):",
-                help="If provided, a focused summary will be generated after individual summaries"
+                help="If provided, a question-focused summary will be generated after individual summaries"
             )
 
         # Single button for both processes
-        if not st.session_state.get('processing_started', False):
-            if st.button("Start Analysis"):
-                st.session_state.processing_started = True
+        if not st.session_state.get('processing_started', False):
+            if st.button("Start Analysis"):
+                st.session_state.processing_started = True
 
         # Show processing status and results
         if st.session_state.get('processing_started', False):
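The button-plus-flag change at the end of the diff is the usual Streamlit idiom for a one-shot action: the click sets `processing_started` in `st.session_state`, and every subsequent rerun branches on the flag instead of the button. A minimal self-contained sketch of that gate (`st.rerun()` is an optional extra, not part of this commit):

```python
import streamlit as st

if "processing_started" not in st.session_state:
    st.session_state.processing_started = False

if not st.session_state.processing_started:
    if st.button("Start Analysis"):
        st.session_state.processing_started = True
        st.rerun()  # rerun immediately so the gated branch below executes

if st.session_state.processing_started:
    st.write("Processing...")
```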
 