pendar02 committed
Commit d0820e9 · verified · 1 Parent(s): 805eb33

Update app.py

Files changed (1)
  1. app.py +60 -196
app.py CHANGED
@@ -36,6 +36,7 @@ def load_model(model_type):
     device = "cpu" # Force CPU usage
 
     if model_type == "summarize":
+        # Load the new fine-tuned model directly
         model = AutoModelForSeq2SeqLM.from_pretrained(
            "pendar02/bart-large-pubmedd",
            cache_dir="./models",
@@ -150,219 +151,65 @@ def post_process_summary(summary):
     return cleaned_summary
 
 def improve_summary_generation(text, model, tokenizer):
-    """Enhanced version of summary generation optimized for biomedical papers"""
-    if not isinstance(text, str) or not text.strip():
-        return "No abstract available to summarize."
-
-    # Don't summarize if text is too short
-    word_count = len(text.split())
-    if word_count < 100: # Increased minimum length for medical texts
-        return text
-
-    # Preprocess text
-    formatted_text = preprocess_text(text)
-
-    # Prepare inputs
-    inputs = tokenizer(
-        formatted_text,
-        return_tensors="pt",
-        max_length=1024,
-        truncation=True,
-        padding=True
+    # Add a more specific prompt
+    formatted_text = (
+        "Summarize the following medical research paper, focusing on: "
+        "1) Study objectives 2) Methods 3) Key findings 4) Main conclusions. "
+        "Text: " + preprocess_text(text)
     )
+
+    # Adjust generation parameters
+    inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
-    # Generate summary with parameters tuned for biomedical text
     with torch.no_grad():
         summary_ids = model.generate(
             **{
                 "input_ids": inputs["input_ids"],
                 "attention_mask": inputs["attention_mask"],
-                "max_length": 300, # Increased for medical summaries
-                "min_length": 100, # Increased to ensure comprehensive coverage
-                "num_beams": 4,
-                "length_penalty": 2.0, # Encourage slightly longer summaries
+                "max_length": 200,
+                "min_length": 50,
+                "num_beams": 5,
+                "length_penalty": 1.5,
                 "no_repeat_ngram_size": 3,
-                "early_stopping": True,
-                "do_sample": True, # Enable sampling
-                "top_p": 0.95, # Nucleus sampling
-                "temperature": 0.85, # Slightly higher temperature for medical terms
-                "repetition_penalty": 1.5 # Increased to avoid repeated stats/numbers
+                "temperature": 0.7,
+                "repetition_penalty": 1.5 # Increased to reduce repetition
             }
         )
 
     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
-    # Enhanced post-processing for medical text
-    summary = post_process_medical_summary(summary)
-
-    return summary
-
-def post_process_medical_summary(summary):
-    """Special post-processing for medical/scientific summaries"""
-    if not summary:
-        return summary
-
-    # Fix common medical text issues
-    summary = (summary
-        .replace(" p =", " p=") # Fix p-value spacing
-        .replace(" n =", " n=") # Fix sample size spacing
-        .replace("( ", "(") # Fix parentheses spacing
-        .replace(" )", ")")
-        .replace("vs.", "versus") # Expand abbreviations
-        .replace("..", ".") # Fix double periods
-    )
-
-    # Ensure statistical significance symbols are correct
-    summary = (summary
-        .replace("p < ", "p<")
-        .replace("p > ", "p>")
-        .replace("P < ", "p<")
-        .replace("P > ", "p>")
-    )
-
-    # Fix number formatting
-    summary = (summary
-        .replace(" +/- ", "±")
-        .replace(" ± ", "±")
-    )
-
-    # Split into sentences and process each
-    sentences = [s.strip() for s in summary.split('.')]
-    processed_sentences = []
-
-    for sentence in sentences:
-        if sentence:
-            # Capitalize first letter
-            sentence = sentence[0].upper() + sentence[1:] if sentence else sentence
-
-            # Fix common medical abbreviations spacing
-            sentence = (sentence
-                .replace(" et al ", " et al. ")
-                .replace("et al.", "et al.") # Fix double period
-            )
-
-            processed_sentences.append(sentence)
-
-    # Join sentences
-    summary = '. '.join(processed_sentences)
-
-    # Ensure proper ending
-    if summary and not summary.endswith('.'):
-        summary += '.'
-
-    return summary
-
-def post_process_medical_summary(summary):
-    """Special post-processing for medical/scientific summaries"""
-    if not summary:
-        return summary
-
-    # Fix common medical text issues
-    summary = (summary
-        .replace(" p =", " p=") # Fix p-value spacing
-        .replace(" n =", " n=") # Fix sample size spacing
-        .replace("( ", "(") # Fix parentheses spacing
-        .replace(" )", ")")
-        .replace("vs.", "versus") # Expand abbreviations
-        .replace("..", ".") # Fix double periods
-    )
-
-    # Ensure statistical significance symbols are correct
-    summary = (summary
-        .replace("p < ", "p<")
-        .replace("p > ", "p>")
-        .replace("P < ", "p<")
-        .replace("P > ", "p>")
-    )
-
-    # Fix number formatting
-    summary = (summary
-        .replace(" +/- ", "±")
-        .replace(" ± ", "±")
-    )
-
-    # Split into sentences and process each
-    sentences = [s.strip() for s in summary.split('.')]
-    processed_sentences = []
-
-    for sentence in sentences:
-        if sentence:
-            # Capitalize first letter
-            sentence = sentence[0].upper() + sentence[1:] if sentence else sentence
-
-            # Fix common medical abbreviations spacing
-            sentence = (sentence
-                .replace(" et al ", " et al. ")
-                .replace("et al.", "et al.") # Fix double period
-            )
-
-            processed_sentences.append(sentence)
-
-    # Join sentences
-    summary = '. '.join(processed_sentences)
-
-    # Ensure proper ending
-    if summary and not summary.endswith('.'):
-        summary += '.'
-
-    return summary
-
-
-def post_process_medical_summary(summary):
-    """Special post-processing for medical/scientific summaries"""
+def post_process_summary(summary):
+    """Enhanced post-processing to catch common errors"""
     if not summary:
         return summary
+
+    # Remove contradictory age statements
+    age_statements = []
+    lines = summary.split('.')
+    cleaned_lines = []
+    for line in lines:
+        if "age" not in line.lower():
+            cleaned_lines.append(line)
+        elif not age_statements: # Only keep first age statement
+            age_statements.append(line)
+            cleaned_lines.append(line)
+
+    # Remove redundant statements
+    seen_content = set()
+    unique_lines = []
+    for line in cleaned_lines:
+        line_core = ' '.join(sorted(line.lower().split())) # Normalize for comparison
+        if line_core not in seen_content:
+            seen_content.add(line_core)
+            unique_lines.append(line)
 
-    # Fix common medical text issues
-    summary = (summary
-        .replace(" p =", " p=") # Fix p-value spacing
-        .replace(" n =", " n=") # Fix sample size spacing
-        .replace("( ", "(") # Fix parentheses spacing
-        .replace(" )", ")")
-        .replace("vs.", "versus") # Expand abbreviations
-        .replace("..", ".") # Fix double periods
-    )
-
-    # Ensure statistical significance symbols are correct
-    summary = (summary
-        .replace("p < ", "p<")
-        .replace("p > ", "p>")
-        .replace("P < ", "p<")
-        .replace("P > ", "p>")
-    )
-
-    # Fix number formatting
-    summary = (summary
-        .replace(" +/- ", "±")
-        .replace(" ± ", "±")
-    )
-
-    # Split into sentences and process each
-    sentences = [s.strip() for s in summary.split('.')]
-    processed_sentences = []
-
-    for sentence in sentences:
-        if sentence:
-            # Capitalize first letter
-            sentence = sentence[0].upper() + sentence[1:] if sentence else sentence
-
-            # Fix common medical abbreviations spacing
-            sentence = (sentence
-                .replace(" et al ", " et al. ")
-                .replace("et al.", "et al.") # Fix double period
-            )
-
-            processed_sentences.append(sentence)
-
-    # Join sentences
-    summary = '. '.join(processed_sentences)
-
-    # Ensure proper ending
-    if summary and not summary.endswith('.'):
-        summary += '.'
-
-    return summary
+    # Join sentences with proper spacing and punctuation
+    cleaned_summary = '. '.join(s.strip() for s in unique_lines if s.strip())
+    if cleaned_summary and not cleaned_summary.endswith('.'):
+        cleaned_summary += '.'
+
+    return cleaned_summary
 
 def generate_focused_summary(question, abstracts, model, tokenizer):
     """Generate focused summary based on question"""
@@ -388,6 +235,23 @@ def generate_focused_summary(question, abstracts, model, tokenizer):
 
     return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
+
+def validate_summary(summary, original_text):
+    """Validate summary content against original text"""
+    # Check for age inconsistencies
+    age_mentions = re.findall(r'(\d+\.?\d*)\s*years?', summary.lower())
+    if len(age_mentions) > 1: # Multiple age mentions
+        return False
+
+    # Check for repetitive sentences
+    sentences = summary.split('.')
+    unique_sentences = set(s.strip().lower() for s in sentences if s.strip())
+    if len(sentences) - len(unique_sentences) > 1: # More than one duplicate
+        return False
+
+    return True
+
+
 def main():
     st.title("🔬 Biomedical Papers Analysis")
 
257