Update app.py
app.py CHANGED

@@ -711,17 +711,24 @@ Provide clear, actionable advice while emphasizing the need for professional med
 
     def _generate_response(self, query: str, context: str) -> str:
         """Enhanced response generation using model.generate() to avoid DynamicCache errors"""
-
-
-
-
-
-
-
-
-
-
+        if self.llm is None or self.tokenizer is None:
+            return self._generate_fallback_response(query, context)
+
+        # 🧠 Build prompt (this was in the wrong place before)
+        prompt = f"""{self.system_prompt}
+
+MEDICAL KNOWLEDGE CONTEXT:
+{context}
+
+PATIENT QUESTION: {query}
+
+RESPONSE (provide practical, Gaza-appropriate medical guidance):"""
+
+        try:
+            # ✅ Tokenize and move to correct device
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.llm.device)
 
+            # ✅ Generate the response
             outputs = self.llm.generate(
                 **inputs,
                 max_new_tokens=800,
@@ -730,11 +737,22 @@ Provide clear, actionable advice while emphasizing the need for professional med
                 do_sample=True,
                 repetition_penalty=1.15,
             )
-
-
+
+            # ✅ Decode and clean up
+            response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            lines = response_text.split('\n')
+            unique_lines = []
+            for line in lines:
+                line = line.strip()
+                if line and line not in unique_lines:
+                    unique_lines.append(line)
+            return '\n'.join(unique_lines)
+
+        except Exception as e:
             logger.error(f"Error in LLM generate(): {e}")
             return self._generate_fallback_response(query, context)
 
+
         # Decode and clean up
         response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
         lines = response_text.split('\n')
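
For context, the pattern this commit settles on (tokenize the prompt, call model.generate() directly, decode, then strip empty and duplicate lines) can be sketched as a standalone function. This is a minimal illustration under stated assumptions, not the Space's actual code: the "gpt2" checkpoint, the generate_answer name, and the example prompt are stand-ins; only the generate / decode / de-duplicate flow mirrors the patched _generate_response.

# Minimal sketch of the generate-then-clean pattern from the patch.
# Assumptions: "gpt2" as a stand-in checkpoint (the Space loads its own
# model into self.llm / self.tokenizer) and a hypothetical generate_answer helper.
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

def generate_answer(prompt: str) -> str:
    # Tokenize and move the tensors to the model's device, as the patch does.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Call generate() directly with the same sampling settings as the diff.
    outputs = model.generate(
        **inputs,
        max_new_tokens=800,
        do_sample=True,
        repetition_penalty=1.15,
    )

    # Decode, then drop empty and repeated lines.
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    unique_lines = []
    for line in text.split("\n"):
        line = line.strip()
        if line and line not in unique_lines:
            unique_lines.append(line)
    return "\n".join(unique_lines)

print(generate_answer("PATIENT QUESTION: How should a small cut be cleaned?"))

Wrapping the call in try/except and falling back to _generate_fallback_response, as the diff does, keeps the handler responsive even when generate() raises (for example, the DynamicCache error mentioned in the docstring).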