pendar02 committed
Commit 054584c · verified · 1 Parent(s): bb88c28

Update app.py

Files changed (1):
  1. app.py +208 -81
app.py CHANGED
@@ -101,6 +101,69 @@ def process_excel(uploaded_file):
         st.error(f"Error processing file: {str(e)}")
         return None
 
+def verify_facts(summary, original_text):
+    """Verify that key facts in the summary match the original text"""
+    # Extract numbers and percentages
+    def extract_numbers(text):
+        return set(re.findall(r'(\d+\.?\d*)%?', text))
+
+    original_numbers = extract_numbers(original_text)
+    summary_numbers = extract_numbers(summary)
+
+    # Check if all numbers from original are in summary
+    missing_numbers = original_numbers - summary_numbers
+
+    # Extract key phrases indicating relationships
+    relationship_patterns = [
+        r'associated with',
+        r'predicted',
+        r'correlated with',
+        r'relationship between',
+        r'linked to'
+    ]
+
+    def extract_relationships(text):
+        relationships = []
+        for pattern in relationship_patterns:
+            matches = re.finditer(pattern, text.lower())
+            for match in matches:
+                # Get surrounding context
+                start = max(0, match.start() - 50)
+                end = min(len(text), match.end() + 50)
+                relationships.append(text[start:end].strip())
+        return set(relationships)
+
+    original_relationships = extract_relationships(original_text)
+    summary_relationships = extract_relationships(summary)
+
+    # Check for contradictions
+    def find_contradictions(summary, original):
+        contradictions = []
+        # Common contradiction patterns
+        neg_patterns = [
+            (r'no association', r'associated with'),
+            (r'did not predict', r'predicted'),
+            (r'was not significant', r'was significant'),
+            (r'decreased', r'increased'),
+            (r'lower', r'higher')
+        ]
+
+        for pos, neg in neg_patterns:
+            if (re.search(pos, summary.lower()) and re.search(neg, original.lower())) or \
+               (re.search(neg, summary.lower()) and re.search(pos, original.lower())):
+                contradictions.append(f"Contradiction found: {pos} vs {neg}")
+
+        return contradictions
+
+    contradictions = find_contradictions(summary, original_text)
+
+    return {
+        'missing_numbers': missing_numbers,
+        'missing_relationships': original_relationships - summary_relationships,
+        'contradictions': contradictions,
+        'is_valid': len(missing_numbers) == 0 and len(contradictions) == 0
+    }
+
 def preprocess_text(text):
     """Preprocess text to add appropriate formatting before summarization"""
     if not isinstance(text, str) or not text.strip():
@@ -109,8 +172,8 @@ def preprocess_text(text):
     # Split text into sentences (basic implementation)
     sentences = [s.strip() for s in text.replace('. ', '.\n').split('\n')]
 
-    # Remove empty sentences
-    sentences = [s for s in sentences if s]
+    # Remove empty sentences and extra whitespace
+    sentences = [re.sub(r'\s+', ' ', s).strip() for s in sentences if s.strip()]
 
     # Join with proper line breaks
     formatted_text = '\n'.join(sentences)
@@ -118,115 +181,179 @@ def preprocess_text(text):
     return formatted_text
 
 def post_process_summary(summary):
-    """Clean up and improve summary coherence."""
+    """Enhanced post-processing for better structure and completeness"""
     if not summary:
         return summary
-
-    # Split into sentences
-    sentences = [s.strip() for s in summary.split('.')]
-    sentences = [s for s in sentences if s]  # Remove empty sentences
-
-    # Correct common issues
-    processed_sentences = []
-    for sentence in sentences:
-        # Remove redundant phrases
-        sentence = re.sub(r"\b(and and|appointment and appointment)\b", "and", sentence)
+
+    # Split into sections
+    sections = summary.split('\n')
+    processed_sections = []
+
+    for section in sections:
+        if not section.strip():
+            continue
+
+        # Remove redundant section headers
+        section = re.sub(r'^(Background and objectives|Methods|Results|Conclusions):\s*', '', section)
+
+        # Split into sentences
+        sentences = [s.strip() for s in section.split('.')]
+        sentences = [s for s in sentences if s]
 
-        # Ensure first letter capitalization
-        sentence = sentence.capitalize()
+        processed_sentences = []
+        for i, sentence in enumerate(sentences):
+            # Fix common issues
+            sentence = re.sub(r'\s+', ' ', sentence)  # Fix spacing
+            sentence = re.sub(r'(\d+)\s*%', r'\1%', sentence)  # Fix percentage formatting
+            sentence = re.sub(r'\(\s*([Nn])\s*=\s*(\d+)\s*\)', r'(n=\2)', sentence)  # Fix sample size formatting
+
+            # Fix common phrase issues
+            sentence = sentence.replace(" and and ", " and ")
+            sentence = sentence.replace("appointment and appointment", "appointment")
+            sentence = sentence.replace("Cancers distress", "Cancer distress")
+
+            # Remove redundant phrases
+            sentence = re.sub(r'(?i)the aim of (the|this) study was to', '', sentence)
+            sentence = re.sub(r'(?i)this study aimed to', '', sentence)
+
+            # Capitalize first letter
+            sentence = sentence.capitalize()
+
+            if sentence.strip():
+                processed_sentences.append(sentence)
 
-        # Avoid duplicates
-        if sentence not in processed_sentences:
-            processed_sentences.append(sentence)
+        if processed_sentences:
+            section = '. '.join(processed_sentences)
+            if not section.endswith('.'):
+                section += '.'
+            processed_sections.append(section)
 
-    # Join sentences with proper punctuation
-    cleaned_summary = '. '.join(processed_sentences)
-    return cleaned_summary if cleaned_summary.endswith('.') else cleaned_summary + '.'
-
+    # Ensure key sections are present
+    required_sections = ['Background and objectives', 'Methods', 'Key findings', 'Conclusions']
+    final_sections = []
+
+    for i, section in enumerate(processed_sections):
+        if i < len(required_sections):
+            final_sections.append(f"{required_sections[i]}: {section}")
+        else:
+            final_sections.append(section)
+
+    return '\n\n'.join(final_sections)
 
 def improve_summary_generation(text, model, tokenizer):
-    """Generate improved summary with better prompt and validation."""
+    """Generate improved summary with better prompt and validation"""
    if not isinstance(text, str) or not text.strip():
         return "No abstract available to summarize."
 
-    # Add a structured prompt for summarization
+    # Add a more specific prompt with strict guidelines
     formatted_text = (
-        "Summarize this biomedical research abstract into the following structure:\n"
-        "1. Background and Objectives\n"
-        "2. Methods\n"
-        "3. Key Findings (include any percentages or numbers)\n"
-        "4. Conclusions\n"
-        f"Abstract:\n{text.strip()}"
+        "Generate a precise summary of this medical research paper following these strict guidelines:\n"
+        "1. Background and objectives: State ONLY the actual study purpose and population - no assumptions\n"
+        "2. Methods: Include ONLY methods explicitly mentioned in the text\n"
+        "3. Key findings: Report ALL numerical results and statistical relationships\n"
+        "4. Conclusions: State ONLY conclusions directly supported by the reported results\n\n"
+        "Requirements:\n"
+        "- Include ALL percentages and numbers from the original text\n"
+        "- Do not repeat section headers\n"
+        "- Do not make claims beyond what's explicitly stated\n"
+        "- Maintain the original meaning without contradiction\n"
+        "- Do not introduce new information\n\n"
+        "Original text: " + preprocess_text(text)
     )
 
-    # Prepare input tokens
+    # Tokenize input
     inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
-    # Generate summary with adjusted parameters
-    try:
+    def generate_attempt(temperature, num_beams, length_penalty):
         with torch.no_grad():
-            summary_ids = model.generate(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                max_length=300,  # Increased for more detailed summaries
-                min_length=100,  # Ensure summaries are not too short
-                num_beams=5,
-                length_penalty=1.5,
-                no_repeat_ngram_size=3,
-                temperature=0.7,
-                repetition_penalty=1.3,
+            return model.generate(
+                **{
+                    "input_ids": inputs["input_ids"],
+                    "attention_mask": inputs["attention_mask"],
+                    "max_length": 300,  # Increased to ensure all facts are included
+                    "min_length": 100,  # Increased to encourage more complete summaries
+                    "num_beams": num_beams,
+                    "length_penalty": length_penalty,
+                    "no_repeat_ngram_size": 3,
+                    "temperature": temperature,
+                    "repetition_penalty": 2.0,  # Increased to reduce repetition
+                    "do_sample": True  # Enable sampling for more diverse outputs
+                }
             )
-        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-    except Exception as e:
-        return f"Error in generation: {str(e)}"
 
-    # Post-process the summary
-    return post_process_summary(summary)
-
-    # Validate the summary
-    if not validate_summary(processed_summary, text):
-        # Retry with alternate generation parameters
-        with torch.no_grad():
-            summary_ids = model.generate(
-                input_ids=inputs["input_ids"],
-                attention_mask=inputs["attention_mask"],
-                max_length=250,
-                min_length=50,
-                num_beams=4,
-                length_penalty=2.0,
-                no_repeat_ngram_size=4,
-                temperature=0.8,
-                repetition_penalty=1.5,
-            )
+    # Try different parameter combinations until we get a valid summary
+    parameter_combinations = [
+        {"temperature": 0.7, "num_beams": 5, "length_penalty": 1.5},
+        {"temperature": 0.5, "num_beams": 8, "length_penalty": 2.0},
+        {"temperature": 0.3, "num_beams": 10, "length_penalty": 2.5}
+    ]
 
+    best_summary = None
+    best_verification = None
+
+    for params in parameter_combinations:
+        summary_ids = generate_attempt(**params)
         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
         processed_summary = post_process_summary(summary)
-
+
+        # Verify facts in the summary
+        verification = verify_facts(processed_summary, text)
+
+        if verification['is_valid']:
+            return processed_summary
+
+        # Keep track of best attempt
+        if best_verification is None or \
+           len(verification['missing_numbers']) < len(best_verification['missing_numbers']):
+            best_summary = processed_summary
+            best_verification = verification
 
-    return processed_summary
+    # If no perfect summary was generated, use the best attempt
+    # Add missing information if necessary
+    if best_verification and best_verification['missing_numbers']:
+        # Attempt to add missing numerical information
+        additional_info = []
+        original_sentences = text.split('.')
+        for num in best_verification['missing_numbers']:
+            # Find sentences containing the missing number
+            for sentence in original_sentences:
+                if str(num) in sentence:
+                    additional_info.append(sentence.strip())
+                    break
+
+        if additional_info:
+            best_summary += "\n\nAdditional key findings: " + ". ".join(additional_info) + "."
+
+    return best_summary
 
 def validate_summary(summary, original_text):
-    """Validate summary content against original text."""
-    # Check for common validation points
-    if not summary or len(summary.split()) < 20:
-        return False  # Too short
-    if len(summary.split()) > len(original_text.split()) * 0.8:
-        return False  # Too long
-
-    # Ensure structure is maintained (e.g., headings are present)
-    required_sections = ["background and objectives", "methods", "key findings", "conclusions"]
-    if not all(section.lower() in summary.lower() for section in required_sections):
+    """Validate summary content against original text"""
+    # Perform fact verification
+    verification = verify_facts(summary, original_text)
+
+    if not verification['is_valid']:
         return False
-
-    # Ensure no repetitive sentences
+
+    # Check for age inconsistencies
+    age_mentions = re.findall(r'(\d+\.?\d*)\s*years?', summary.lower())
+    if len(age_mentions) > 1:  # Multiple age mentions
+        return False
+
+    # Check for repetitive sentences
     sentences = summary.split('.')
-    if len(sentences) != len(set(sentences)):
+    unique_sentences = set(s.strip().lower() for s in sentences if s.strip())
+    if len(sentences) - len(unique_sentences) > 1:  # More than one duplicate
         return False
-
+
+    # Check summary isn't too long or too short compared to original
+    summary_words = len(summary.split())
+    original_words = len(original_text.split())
+    if summary_words < 20 or summary_words > original_words * 0.8:
+        return False
+
     return True
 
-
 def generate_focused_summary(question, abstracts, model, tokenizer):
     """Generate focused summary based on question"""
     # Preprocess each abstract
 
  def preprocess_text(text):
168
  """Preprocess text to add appropriate formatting before summarization"""
169
  if not isinstance(text, str) or not text.strip():
 
172
  # Split text into sentences (basic implementation)
173
  sentences = [s.strip() for s in text.replace('. ', '.\n').split('\n')]
174
 
175
+ # Remove empty sentences and extra whitespace
176
+ sentences = [re.sub(r'\s+', ' ', s).strip() for s in sentences if s.strip()]
177
 
178
  # Join with proper line breaks
179
  formatted_text = '\n'.join(sentences)
 
181
  return formatted_text
182
 
183
  def post_process_summary(summary):
184
+ """Enhanced post-processing for better structure and completeness"""
185
  if not summary:
186
  return summary
187
+
188
+ # Split into sections
189
+ sections = summary.split('\n')
190
+ processed_sections = []
191
+
192
+ for section in sections:
193
+ if not section.strip():
194
+ continue
195
+
196
+ # Remove redundant section headers
197
+ section = re.sub(r'^(Background and objectives|Methods|Results|Conclusions):\s*', '', section)
198
+
199
+ # Split into sentences
200
+ sentences = [s.strip() for s in section.split('.')]
201
+ sentences = [s for s in sentences if s]
202
 
203
+ processed_sentences = []
204
+ for i, sentence in enumerate(sentences):
205
+ # Fix common issues
206
+ sentence = re.sub(r'\s+', ' ', sentence) # Fix spacing
207
+ sentence = re.sub(r'(\d+)\s*%', r'\1%', sentence) # Fix percentage formatting
208
+ sentence = re.sub(r'\(\s*([Nn])\s*=\s*(\d+)\s*\)', r'(n=\2)', sentence) # Fix sample size formatting
209
+
210
+ # Fix common phrase issues
211
+ sentence = sentence.replace(" and and ", " and ")
212
+ sentence = sentence.replace("appointment and appointment", "appointment")
213
+ sentence = sentence.replace("Cancers distress", "Cancer distress")
214
+
215
+ # Remove redundant phrases
216
+ sentence = re.sub(r'(?i)the aim of (the|this) study was to', '', sentence)
217
+ sentence = re.sub(r'(?i)this study aimed to', '', sentence)
218
+
219
+ # Capitalize first letter
220
+ sentence = sentence.capitalize()
221
+
222
+ if sentence.strip():
223
+ processed_sentences.append(sentence)
224
 
225
+ if processed_sentences:
226
+ section = '. '.join(processed_sentences)
227
+ if not section.endswith('.'):
228
+ section += '.'
229
+ processed_sections.append(section)
230
 
231
+ # Ensure key sections are present
232
+ required_sections = ['Background and objectives', 'Methods', 'Key findings', 'Conclusions']
233
+ final_sections = []
234
+
235
+ for i, section in enumerate(processed_sections):
236
+ if i < len(required_sections):
237
+ final_sections.append(f"{required_sections[i]}: {section}")
238
+ else:
239
+ final_sections.append(section)
240
+
241
+ return '\n\n'.join(final_sections)
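A rough illustration of the reworked post_process_summary: the regex passes normalize spacing, percentages, and sample sizes, and the first sections are labeled positionally from required_sections (invented input; assumes the function above is in scope):

    raw = ("distress was high ( N = 120 ) and and screening took 4  %  of visits\n"
           "screening was feasible")
    print(post_process_summary(raw))
    # Background and objectives: Distress was high (n=120) and screening took 4% of visits.
    #
    # Methods: Screening was feasible.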
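Stripped of the model-specific details, the control flow added to improve_summary_generation is a generate-verify-retry pattern. A self-contained skeleton of that pattern (names are illustrative, not part of the commit):

    def generate_with_fallback(attempts, generate, verify):
        """Try parameter sets in order; return the first output that verifies,
        otherwise the attempt with the fewest missing numbers."""
        best, best_report = None, None
        for params in attempts:
            candidate = generate(**params)
            report = verify(candidate)
            if report['is_valid']:
                return candidate
            if best_report is None or \
               len(report['missing_numbers']) < len(best_report['missing_numbers']):
                best, best_report = candidate, report
        return best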
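And a small sanity check of the stricter validate_summary (invented strings; assumes the functions above are in scope). Note the word-count bounds: anything under 20 words, or over 80% of the original's length, is rejected as well.

    original = "Mean age was 54 years and adherence reached 72%."
    summary = "Mean age was 54 years. Participants were 54 years old on average. Adherence reached 72%."

    # verify_facts passes (54 and 72 both appear), but the age regex
    # finds more than one "NN years" mention, so validation fails.
    print(validate_summary(summary, original))  # False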