rivapereira123 committed
Commit a12c950 · verified · 1 Parent(s): b4c0f0a

Update app.py

Files changed (1)
  1. app.py +43 -31
app.py CHANGED
@@ -480,52 +480,64 @@ Provide clear, actionable advice while emphasizing the need for professional med
     return "\n\n".join(context_parts)

     def _generate_response(self, query: str, context: str) -> str:
-        """Enhanced response generation using model.generate() to avoid DynamicCache errors"""
+        """Generate response using T5-style seq2seq model with Gaza-specific context"""
         if self.llm is None or self.tokenizer is None:
             return self._generate_fallback_response(query, context)
-
-        # Build prompt with Gaza-specific context
-        prompt = f"""{self.system_prompt}
-
+        prompt = f"""{self.system_prompt}
 MEDICAL KNOWLEDGE CONTEXT:
 {context}

 PATIENT QUESTION: {query}

 RESPONSE (provide practical, Gaza-appropriate medical guidance):"""
-
         try:
-            # Tokenize and move to correct device
-            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
-            if hasattr(self.llm, 'device'):
-                inputs = inputs.to(self.llm.device)
-
-            # Generate the response
-            with torch.no_grad():
-                outputs = self.generation_pipeline(prompt, max_new_tokens=300, temperature=0.3, repetition_penalty=1.15, no_repeat_ngram_size=3)
-                response_text = outputs[0]["generated_text"]
-
-            # Decode and clean up
-            response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-            # Extract only the generated part
-            if "RESPONSE (provide practical, Gaza-appropriate medical guidance):" in response_text:
-                response_text = response_text.split("RESPONSE (provide practical, Gaza-appropriate medical guidance):")[1]
-
-            # Clean up the response
-            lines = response_text.split('\n')
-            unique_lines = []
-            unique_lines.append(line)
-            final_response = '\n'.join(unique_lines)
-            logger.info(f"🧪 Final cleaned response:\n{final_response}")
-
-            return final_response
-
-        except Exception as e:
-            logger.error(f"❌ Error in LLM generate(): {e}")
-            return self._generate_fallback_response(query, context)
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512,
+                padding="max_length"
+            )
+            input_ids = inputs["input_ids"]
+            attention_mask = inputs["attention_mask"]
+            device = self.llm.device if hasattr(self.llm, "device") else "cpu"
+            input_ids = input_ids.to(device)
+            attention_mask = attention_mask.to(device)
+
+            # Generate output
+            with torch.no_grad():
+                outputs = self.llm.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=256,
+                    temperature=0.3,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    do_sample=True,
+                    repetition_penalty=1.15,
+                    no_repeat_ngram_size=3
+                )
+
+            # Decode result
+            response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            # Clean and filter output
+            lines = response_text.split('\n')
+            unique_lines = []
+            for line in lines:
+                line = line.strip()
+                if line and line not in unique_lines and len(line) > 10:
+                    unique_lines.append(line)
+
+            final_response = '\n'.join(unique_lines)
+            logger.info(f"🧪 Final cleaned response:\n{final_response}")
+
+            return final_response
+
+        except Exception as e:
+            logger.error(f"❌ Error in LLM generate(): {e}")
+            return self._generate_fallback_response(query, context)

     def _generate_fallback_response(self, query: str, context: str) -> str:
         """Enhanced fallback response with Gaza-specific guidance"""