pendar02 committed on
Commit
7262cba
·
verified ·
1 Parent(s): 0d0c8c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -143
app.py CHANGED
@@ -117,173 +117,122 @@ def preprocess_text(text):
117
 
118
  return formatted_text
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  def improve_summary_generation(text, model, tokenizer):
122
- """Generate improved summary with better prompt engineering and validation"""
123
  if not isinstance(text, str) or not text.strip():
124
  return "No abstract available to summarize."
125
 
126
- # Create a more structured prompt that enforces accurate reporting
127
  formatted_text = (
128
- "Summarize this medical research paper accurately and concisely. "
129
- "Include only factual information from the text. "
130
- "Structure the summary as follows:\n"
131
- "1. OBJECTIVE: State the main purpose and study population\n"
132
- "2. METHODS: Describe key methodological elements\n"
133
- "3. RESULTS: Report specific findings with exact numbers/percentages\n"
134
- "4. CONCLUSION: State main implications\n\n"
135
  "Original text: " + preprocess_text(text)
136
  )
137
 
138
- # First attempt with conservative parameters
139
- summary = generate_summary_attempt(formatted_text, model, tokenizer,
140
- conservative_params=True)
141
-
142
- # Validate the generated summary
143
- if not validate_summary(summary, text):
144
- # If validation fails, try again with different parameters
145
- summary = generate_summary_attempt(formatted_text, model, tokenizer,
146
- conservative_params=False)
147
-
148
- return post_process_summary(summary)
149
-
150
- def generate_summary_attempt(formatted_text, model, tokenizer, conservative_params=True):
151
- """Generate a summary with specified parameters"""
152
  inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
153
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
154
 
155
- params = {
156
- "input_ids": inputs["input_ids"],
157
- "attention_mask": inputs["attention_mask"],
158
- "max_length": 250, # Increased for better coverage
159
- "min_length": 100, # Increased to ensure comprehensive summary
160
- "early_stopping": True,
161
- "no_repeat_ngram_size": 3,
162
- }
163
-
164
- if conservative_params:
165
- params.update({
166
- "num_beams": 5,
167
- "length_penalty": 1.5,
168
- "temperature": 0.7,
169
- "top_p": 0.9,
170
- "repetition_penalty": 1.5
171
- })
172
- else:
173
- params.update({
174
- "num_beams": 4,
175
- "length_penalty": 2.0,
176
- "temperature": 0.8,
177
- "top_p": 0.95,
178
- "repetition_penalty": 2.0
179
- })
180
-
181
  with torch.no_grad():
182
- summary_ids = model.generate(**params)
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
- return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  def validate_summary(summary, original_text):
187
- """Enhanced validation of summary content"""
188
- if not summary or not original_text:
 
 
189
  return False
190
-
191
- # Extract numerical values from both texts
192
- original_numbers = set(re.findall(r'(\d+(?:\.\d+)?)\s*%', original_text))
193
- summary_numbers = set(re.findall(r'(\d+(?:\.\d+)?)\s*%', summary))
194
 
195
- # Check if key percentages are preserved
196
- if not summary_numbers.issubset(original_numbers):
 
 
197
  return False
198
 
199
- # Check for contradictions in methodology statements
200
- methods_original = extract_methods(original_text)
201
- methods_summary = extract_methods(summary)
202
- if methods_summary and not any(m in original_text.lower() for m in methods_summary):
203
  return False
204
 
205
- # Verify no hallucinated content
206
- sentences = summary.split('.')
207
- for sentence in sentences:
208
- # Check if key claims in summary are supported by original
209
- if sentence.strip() and not is_supported_by_original(sentence, original_text):
210
- return False
211
-
212
- return True
213
-
214
- def extract_methods(text):
215
- """Extract methodology-related terms"""
216
- method_keywords = ['study', 'survey', 'analysis', 'trial', 'experiment']
217
- methods = []
218
- for keyword in method_keywords:
219
- pattern = fr'{keyword}\s+\w+'
220
- matches = re.findall(pattern, text.lower())
221
- methods.extend(matches)
222
- return methods
223
-
224
- def is_supported_by_original(claim, original):
225
- """Check if a claim from summary is supported by original text"""
226
- # Remove common filler phrases
227
- claim = re.sub(r'(this study|the study|results show|we found that)', '', claim.lower()).strip()
228
-
229
- # Split into key phrases
230
- key_phrases = [p.strip() for p in claim.split(' and ')]
231
-
232
- # Check if each key phrase has supporting evidence
233
- for phrase in key_phrases:
234
- if phrase and not has_supporting_evidence(phrase, original.lower()):
235
- return False
236
  return True
237
 
238
- def has_supporting_evidence(phrase, original):
239
- """Check if there's supporting evidence for a phrase"""
240
- # Convert to word sets for flexible matching
241
- phrase_words = set(phrase.split())
242
- original_sentences = [set(s.split()) for s in original.split('.')]
243
-
244
- # Check if any sentence contains most of the phrase words
245
- return any(len(phrase_words.intersection(sent)) >= len(phrase_words) * 0.7
246
- for sent in original_sentences)
247
-
248
- def post_process_summary(summary):
249
- """Enhanced post-processing of generated summary"""
250
- if not summary:
251
- return summary
252
-
253
- # Split into sections based on the structured format
254
- sections = []
255
- current_section = []
256
-
257
- for line in summary.split('\n'):
258
- line = line.strip()
259
- if any(marker in line.upper() for marker in ['OBJECTIVE:', 'METHODS:', 'RESULTS:', 'CONCLUSION:']):
260
- if current_section:
261
- sections.append(' '.join(current_section))
262
- current_section = [line]
263
- elif line:
264
- current_section.append(line)
265
-
266
- if current_section:
267
- sections.append(' '.join(current_section))
268
-
269
- # Clean up each section
270
- cleaned_sections = []
271
- for section in sections:
272
- # Fix common issues
273
- section = re.sub(r'\s+', ' ', section) # Remove multiple spaces
274
- section = re.sub(r'(\d+)\s*%', r'\1%', section) # Fix percentage formatting
275
- section = re.sub(r'(\.|,)\s*(\d)', r'\1 \2', section) # Fix number spacing
276
- cleaned_sections.append(section)
277
-
278
- # Join sections with proper spacing
279
- final_summary = '\n'.join(cleaned_sections)
280
-
281
- # Ensure proper ending
282
- if final_summary and not final_summary.endswith('.'):
283
- final_summary += '.'
284
-
285
- return final_summary
286
-
287
  def generate_focused_summary(question, abstracts, model, tokenizer):
288
  """Generate focused summary based on question"""
289
  # Preprocess each abstract
 
117
 
118
  return formatted_text
119
 
120
+ def post_process_summary(summary):
121
+ """Clean up and improve summary coherence"""
122
+ if not summary:
123
+ return summary
124
+
125
+ # Split into sentences
126
+ sentences = [s.strip() for s in summary.split('.')]
127
+ sentences = [s for s in sentences if s] # Remove empty sentences
128
+
129
+ # Fix common issues
130
+ processed_sentences = []
131
+ for i, sentence in enumerate(sentences):
132
+ # Remove redundant words/phrases
133
+ sentence = sentence.replace(" and and ", " and ")
134
+ sentence = sentence.replace("appointment and appointment", "appointment")
135
+
136
+ # Fix common grammatical issues
137
+ sentence = sentence.replace("Cancers distress", "Cancer distress")
138
+ sentence = sentence.replace("  ", " ") # Remove double spaces
139
+
140
+ # Capitalize first letter of each sentence
141
+ sentence = sentence.capitalize()
142
+
143
+ # Add to processed sentences if not empty
144
+ if sentence.strip():
145
+ processed_sentences.append(sentence)
146
+
147
+ # Join sentences with proper spacing and punctuation
148
+ cleaned_summary = '. '.join(processed_sentences)
149
+ if cleaned_summary and not cleaned_summary.endswith('.'):
150
+ cleaned_summary += '.'
151
+
152
+ return cleaned_summary
153
 
154
  def improve_summary_generation(text, model, tokenizer):
155
+ """Generate improved summary with better prompt and validation"""
156
  if not isinstance(text, str) or not text.strip():
157
  return "No abstract available to summarize."
158
 
159
+ # Add a more specific prompt
160
  formatted_text = (
161
+ "Summarize this medical research paper following this structure exactly:\n"
162
+ "1. Background and objectives\n"
163
+ "2. Methods\n"
164
+ "3. Key findings with specific numbers/percentages\n"
165
+ "4. Main conclusions\n"
 
 
166
  "Original text: " + preprocess_text(text)
167
  )
168
 
169
+ # Adjust generation parameters
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
171
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  with torch.no_grad():
174
+ summary_ids = model.generate(
175
+ **{
176
+ "input_ids": inputs["input_ids"],
177
+ "attention_mask": inputs["attention_mask"],
178
+ "max_length": 200,
179
+ "min_length": 50,
180
+ "num_beams": 5,
181
+ "length_penalty": 1.5,
182
+ "no_repeat_ngram_size": 3,
183
+ "temperature": 0.7,
184
+ "repetition_penalty": 1.5
185
+ }
186
+ )
187
 
188
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
189
+
190
+ # Post-process the summary
191
+ processed_summary = post_process_summary(summary)
192
+
193
+ # Validate the summary
194
+ if not validate_summary(processed_summary, text):
195
+ # If validation fails, try one more time with different parameters
196
+ with torch.no_grad():
197
+ summary_ids = model.generate(
198
+ **{
199
+ "input_ids": inputs["input_ids"],
200
+ "attention_mask": inputs["attention_mask"],
201
+ "max_length": 200,
202
+ "min_length": 50,
203
+ "num_beams": 4,
204
+ "length_penalty": 2.0,
205
+ "no_repeat_ngram_size": 4,
206
+ "temperature": 0.8,
207
+ "repetition_penalty": 2.0
208
+ }
209
+ )
210
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
211
+ processed_summary = post_process_summary(summary)
212
+
213
+ return processed_summary
214
 
215
  def validate_summary(summary, original_text):
216
+ """Validate summary content against original text"""
217
+ # Check for age inconsistencies
218
+ age_mentions = re.findall(r'(\d+\.?\d*)\s*years?', summary.lower())
219
+ if len(age_mentions) > 1: # Multiple age mentions
220
  return False
 
 
 
 
221
 
222
+ # Check for repetitive sentences
223
+ sentences = summary.split('.')
224
+ unique_sentences = set(s.strip().lower() for s in sentences if s.strip())
225
+ if len(sentences) - len(unique_sentences) > 1: # More than one duplicate
226
  return False
227
 
228
+ # Check summary isn't too long or too short compared to original
229
+ summary_words = len(summary.split())
230
+ original_words = len(original_text.split())
231
+ if summary_words < 20 or summary_words > original_words * 0.8:
232
  return False
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  return True
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  def generate_focused_summary(question, abstracts, model, tokenizer):
237
  """Generate focused summary based on question"""
238
  # Preprocess each abstract