Commit 44a2e1d · 0 parent(s)
Initial commit - Pregnancy RAG Chatbot
Files changed:
- .gitattributes +35 -0
- README.md +14 -0
- app.py +477 -0
- rag_functions.py +246 -0
- requirements.txt +0 -0
- utils.py +184 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
---
title: Pregnancy RAG Chatbot
emoji: 🤰
colorFrom: indigo
colorTo: pink
sdk: gradio
sdk_version: 5.35.0
app_file: app.py
pinned: false
license: apache-2.0
short_description: Pregnancy Risk Assessment AI Chatbot
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,477 @@
import gradio as gr
import os
import sys
from datetime import datetime
import traceback


sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


from rag_functions import get_direct_answer, get_answer_with_query_engine
from utils import get_index
print("✅ Successfully imported RAG functions")

class PregnancyRiskAgent:
    def __init__(self):
        self.conversation_history = []
        self.current_symptoms = {}
        self.risk_assessment_done = False
        self.user_context = {}
        self.last_user_query = ""


        self.symptom_questions = [
            "Are you currently experiencing any unusual bleeding or discharge?",
            "How would you describe your baby's movements today compared to yesterday?",
            "Have you had any headaches that won't go away or that affect your vision?",
            "Do you feel any pressure or pain in your pelvis or lower back?",
            "Are you experiencing any other symptoms? (If yes, please describe briefly)"
        ]

        self.current_question_index = 0
        self.waiting_for_first_response = True

    def add_to_conversation_history(self, role, message):
        self.conversation_history.append({
            "role": role,
            "message": message,
            "timestamp": datetime.now().isoformat()
        })


        if len(self.conversation_history) > 20:
            self.conversation_history = self.conversation_history[-20:]

    def get_conversation_context(self):
        context_parts = []

        recent_history = self.conversation_history[-10:]

        for entry in recent_history:
            if entry["role"] == "user":
                context_parts.append(f"User: {entry['message']}")
            else:
                context_parts.append(f"Assistant: {entry['message'][:200]}...")

        return "\n".join(context_parts)

    def is_follow_up_question(self, user_input):
        follow_up_indicators = [
            "what about", "can you explain", "what does", "why", "how",
            "tell me more", "what should i", "is it normal", "should i be worried",
            "what if", "when should", "how long", "what causes", "is this"
        ]

        user_lower = user_input.lower()
        return any(indicator in user_lower for indicator in follow_up_indicators)

    def process_user_input(self, user_input, chat_history):
        try:
            self.last_user_query = user_input
            self.add_to_conversation_history("user", user_input)


            if self.waiting_for_first_response:
                self.current_symptoms[f"question_0"] = user_input
                self.waiting_for_first_response = False
                self.current_question_index = 1

                if self.current_question_index < len(self.symptom_questions):
                    bot_response = f"{self.symptom_questions[self.current_question_index]}"
                else:
                    bot_response = self.provide_risk_assessment()
                    self.risk_assessment_done = True

                self.add_to_conversation_history("assistant", bot_response)
                return bot_response


            elif self.current_question_index < len(self.symptom_questions) and not self.risk_assessment_done:
                self.current_symptoms[f"question_{self.current_question_index}"] = user_input
                self.current_question_index += 1

                if self.current_question_index < len(self.symptom_questions):
                    bot_response = f"{self.symptom_questions[self.current_question_index]}"
                else:
                    bot_response = self.provide_risk_assessment()
                    self.risk_assessment_done = True

                self.add_to_conversation_history("assistant", bot_response)
                return bot_response


            else:
                bot_response = self.handle_follow_up_conversation(user_input)
                self.add_to_conversation_history("assistant", bot_response)
                return bot_response

        except Exception as e:
            print(f"❌ Error in process_user_input: {e}")
            traceback.print_exc()
            error_response = "I encountered an error. Please try again or consult your healthcare provider."
            self.add_to_conversation_history("assistant", error_response)
            return error_response

    def handle_follow_up_conversation(self, user_input):
        try:
            print(f"🔍 Processing follow-up question: {user_input}")

            symptom_summary = self.create_symptom_summary()
            conversation_context = self.get_conversation_context()

            if any(word in user_input.lower() for word in ["last", "previous", "what did i ask", "my question"]):
                if self.last_user_query:
                    return f"Your last question was: \"{self.last_user_query}\"\n\nWould you like me to elaborate on that topic or do you have a different question?"
                else:
                    return "I don't have a record of your previous question. Could you please rephrase what you'd like to know?"

            rag_response = get_direct_answer(user_input, symptom_summary, conversation_context=conversation_context, is_risk_assessment=False)

            if "Error" in rag_response or len(rag_response) < 50:
                print("🔄 Trying alternative method...")
                rag_response = get_answer_with_query_engine(user_input)

            bot_response = f"""Based on your symptoms and medical literature:

{rag_response}"""

            return bot_response

        except Exception as e:
            print(f"❌ Error in follow-up conversation: {e}")
            return "I encountered an error processing your question. Could you please rephrase it or consult your healthcare provider?"

    def create_symptom_summary(self):
        if not self.current_symptoms:
            return "No specific symptoms reported yet"

        summary_parts = []
        for i, (key, response) in enumerate(self.current_symptoms.items()):
            if i < len(self.symptom_questions):
                question = self.symptom_questions[i]
                summary_parts.append(f"{question}: {response}")
        return "\n".join(summary_parts)

    def parse_risk_level(self, text):
        import re

        patterns = [
            r'\*\*Risk Level:\*\*\s*(Low|Medium|High)',
            r'Risk Level:\s*\*\*(Low|Medium|High)\*\*',
            r'Risk Level:\s*(Low|Medium|High)',
            r'\*\*Risk Level:\*\*\s*<(Low|Medium|High)>',
            r'Risk Level.*?<(Low|Medium|High)>',
        ]

        for pattern in patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                risk_level = match.group(1).capitalize()
                print(f"✅ Successfully parsed risk level: {risk_level}")
                return risk_level

        print(f"❌ Could not parse risk level from: {text[:200]}...")
        return None

    def provide_risk_assessment(self):
        all_symptoms = self.create_symptom_summary()

        rag_query = f"Analyze these pregnancy symptoms for risk assessment:\n{all_symptoms}\n\nProvide risk level and medical recommendations."
        detailed_analysis = get_direct_answer(rag_query, all_symptoms, is_risk_assessment=True)

        print(f"📊 RAG Response: {detailed_analysis[:300]}...")

        llm_risk_level = self.parse_risk_level(detailed_analysis)

        if llm_risk_level:
            risk_level = llm_risk_level

            if risk_level == "Low":
                action = "✅ Continue routine prenatal care and self-monitoring"
            elif risk_level == "Medium":
                action = "⚠️ Contact your doctor within 24 hours"
            elif risk_level == "High":
                action = "🚨 Immediate visit to ER or OB emergency care required"
        else:
            print("⚠️ RAG assessment failed, using fallback")
            risk_level = "Medium"
            action = "⚠️ Contact your doctor within 24 hours"

        symptom_list = []
        for i, (key, symptom) in enumerate(self.current_symptoms.items()):
            question = self.symptom_questions[i] if i < len(self.symptom_questions) else f"Question {i+1}"
            symptom_list.append(f"• **{question}**: {symptom}")

        assessment = f"""
## 🏥 **Risk Assessment Complete**

**Risk Level: {risk_level}**
**Recommended Action: {action}**

### 📋 **Your Reported Symptoms:**
{chr(10).join(symptom_list)}

### 🔬 **Medical Analysis:**
{detailed_analysis}

### 💡 **Next Steps:**
- Follow the recommended action above
- Keep monitoring your symptoms
- Contact your healthcare provider if symptoms worsen
- Feel free to ask me any follow-up questions about pregnancy health

"""
        return assessment

    def reset_conversation(self):
        self.conversation_history = []
        self.current_symptoms = {}
        self.current_question_index = 0
        self.risk_assessment_done = False
        self.waiting_for_first_response = True
        self.user_context = {}
        self.last_user_query = ""
        return get_welcome_message()

def get_welcome_message():
    return """Hello! I'm here to help assess pregnancy-related symptoms and provide risk insights based on medical literature.

I'll ask you a few important questions about your current symptoms, then provide a risk assessment and recommendations. After that, feel free to ask any follow-up questions!

**To get started, please tell me:**
Are you currently experiencing any unusual bleeding or discharge?

---
⚠️ **Important**: This tool is for informational purposes only and should not replace professional medical care. In case of emergency, contact your healthcare provider immediately."""


def create_new_agent():

    return PregnancyRiskAgent()


agent = create_new_agent()

def chat_interface_with_reset(user_input, history):
    global agent

    if user_input.lower() in ["reset", "restart", "new assessment"]:
        agent = create_new_agent()
        return get_welcome_message()

    response = agent.process_user_input(user_input, history)
    return response

def reset_chat():
    global agent
    agent = create_new_agent()
    return [{"role": "assistant", "content": get_welcome_message()}], ""



custom_css = """
body, .gradio-container {
    color: yellow !important;
}

.header {
    background: linear-gradient(135deg, #ff9a9e 0%, #fecfef 100%);
    padding: 2rem;
    border-radius: 1rem;
    text-align: center;
    margin-bottom: 2rem;
    box-shadow: 0 4px 12px rgba(0,0,0,0.1);
}

.header h1 {
    color: black !important;
    margin-bottom: 0.5rem;
    font-size: 2.5rem;
}

.header p {
    color: black !important;
    font-size: 1.1rem;
    margin: 0.5rem 0;
}

.warning {
    background-color: #fff4e6;
    border-left: 6px solid #ff7f00;
    padding: 15px;
    border-radius: 5px;
    margin: 10px 0;
}

.warning h3 {
    color: black !important;
    margin-top: 0;
}

.warning p {
    color: black !important;
    line-height: 1.6;
}

div[style*="background-color: #e8f5e8"] {
    color: black !important;
}

div[style*="background-color: #e8f5e8"] h3 {
    color: black !important;
}

div[style*="background-color: #e8f5e8"] li {
    color: black !important;
}

.chatbot {
    color: black !important;
}

.message {
    color: black !important;
}

/* Hide Gradio footer elements */
.footer {
    display: none !important;
}

.gradio-container .footer {
    display: none !important;
}

footer {
    display: none !important;
}

.api-docs {
    display: none !important;
}

.built-with {
    display: none !important;
}

.gradio-container > .built-with {
    display: none !important;
}

.settings {
    display: none !important;
}

div[class*="footer"] {
    display: none !important;
}

div[class*="built"] {
    display: none !important;
}

/* Note: :contains() is not standard CSS; browsers ignore these rules. */
*:contains("Built with Gradio") {
    display: none !important;
}

*:contains("Use via API") {
    display: none !important;
}

*:contains("Settings") {
    display: none !important;
}
"""


with gr.Blocks(css=custom_css) as demo:
    gr.HTML("""
    <div class="header">
        <h1>🤱 Pregnancy RAG Chatbot</h1>
        <p><strong style="color: black !important;">Proactive RAG-powered pregnancy risk management</strong></p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("""
            <div class="warning">
                <h3>⚠️ Medical Disclaimer</h3>
                <p>This AI assistant provides information based on medical literature but is NOT a substitute for professional medical advice, diagnosis, or treatment.</p>
                <p><strong style="color: black !important;">In emergencies, call emergency services immediately.</strong></p>
            </div>
            """)


    chatbot = gr.ChatInterface(
        fn=chat_interface_with_reset,
        chatbot=gr.Chatbot(
            value=[{"role": "assistant", "content": get_welcome_message()}],
            show_label=False,
            type='messages'
        ),
        textbox=gr.Textbox(
            placeholder="Type your response here...",
            show_label=False,
            max_length=1000,
            submit_btn=True
        )
    )

    with gr.Row():
        reset_btn = gr.Button("🔄 Start New Assessment", variant="secondary")

    reset_btn.click(
        fn=reset_chat,
        outputs=[chatbot.chatbot, chatbot.textbox],
        show_progress=False
    )


def check_groq_connection():
    try:
        from utils import llm  # the shared Groq client defined in utils.py
        test_response = llm.complete("Hello")
        print("✅ Groq connection successful")
        return True
    except Exception as e:
        print(f"❌ Groq connection failed: {e}")
        return False


def refresh_page():
    """Force a complete page refresh"""
    return None



if __name__ == "__main__":
    print("🚀 Starting GraviLog Pregnancy Risk Assessment Agent...")
    check_groq_connection()


    is_hf_space = os.getenv('SPACE_ID') is not None

    if is_hf_space:
        print("🌐 Running on Hugging Face Spaces")
        print("🔄 Each page refresh will start a new conversation")
        demo.queue().launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            debug=False
        )
    else:
        print("🏠 Running locally")
        print("🤖 Using Groq API for LLM processing")
        print("🔑 Make sure your GROQ_API_KEY is set in environment variables")
        print("📌 Make sure your Pinecone index is set up and populated")

        demo.queue().launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=True,
            debug=True,
            show_error=True
        )
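
Note: PregnancyRiskAgent is driven entirely through process_user_input, which walks the five symptom_questions in order, emits the risk assessment after the last answer, and routes everything afterwards to handle_follow_up_conversation. Below is a minimal console sketch of that flow; it is a hypothetical smoke test, not part of this commit, and it assumes GROQ_API_KEY, PINECONE_API_KEY, and PINECONE_INDEX are configured, since importing app initializes the index and retrievers at module load.

# Hypothetical smoke test (not in the commit): answer the five symptom questions in order.
from app import create_new_agent

agent = create_new_agent()
answers = [
    "no bleeding",            # Q1: bleeding/discharge
    "movements feel normal",  # Q2: baby's movements
    "a mild headache",        # Q3: headaches/vision
    "no pelvic pain",         # Q4: pelvic/back pressure
    "no other symptoms",      # Q5: anything else
]
for answer in answers:
    reply = agent.process_user_input(answer, chat_history=[])
    print(f"You: {answer}\nBot: {reply}\n")
# The reply to the fifth answer is the formatted risk assessment; any later
# input is routed through handle_follow_up_conversation.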
rag_functions.py ADDED
@@ -0,0 +1,246 @@
import os
import Stemmer
import requests
from utils import get_and_chunk_documents, llm, embed_model, get_index
from utils import Settings
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.settings import Settings
from llama_index.core import VectorStoreIndex
from llama_index.core.llms import ChatMessage
from llama_index.core.retrievers import QueryFusionRetriever
import json


Settings.llm = llm
Settings.embed_model = embed_model


index = get_index()
hybrid_retriever = None
vector_retriever = None
bm25_retriever = None

if index:
    try:

        vector_retriever = index.as_retriever(similarity_top_k=15)
        print("✅ Vector retriever initialized successfully")


        all_nodes = index.docstore.docs
        if len(all_nodes) == 0:
            print("⚠️ Warning: No documents found in index, skipping BM25 retriever")
            hybrid_retriever = vector_retriever
        else:

            has_text_content = False
            for node_id, node in all_nodes.items():
                if hasattr(node, 'text') and node.text and node.text.strip():
                    has_text_content = True
                    break

            if not has_text_content:
                print("⚠️ Warning: No text content found in documents, skipping BM25 retriever")
                hybrid_retriever = vector_retriever
            else:
                try:

                    print("🔄 Creating BM25 retriever...")
                    bm25_retriever = BM25Retriever.from_defaults(
                        docstore=index.docstore,
                        similarity_top_k=15,
                        verbose=False
                    )
                    print("✅ BM25 retriever initialized successfully")


                    hybrid_retriever = QueryFusionRetriever(
                        retrievers=[vector_retriever, bm25_retriever],
                        similarity_top_k=20,
                        num_queries=1,
                        mode="reciprocal_rerank",
                        use_async=False,
                    )
                    print("✅ Hybrid retriever initialized successfully")

                except Exception as e:
                    print(f"❌ Warning: Could not initialize BM25 retriever: {e}")
                    print("🔄 Falling back to vector-only retrieval")
                    hybrid_retriever = vector_retriever

    except Exception as e:
        print(f"❌ Warning: Could not initialize retrievers: {e}")
        hybrid_retriever = None
        vector_retriever = None
        bm25_retriever = None
else:
    print("❌ Warning: Could not initialize retrievers - index is None")

def call_groq_api(prompt):
    """Call Groq API instead of LM Studio"""
    try:

        response = Settings.llm.complete(prompt)
        return str(response)
    except Exception as e:
        print(f"❌ Groq API call failed: {e}")
        raise e

def get_direct_answer(question, symptom_summary, conversation_context="", max_context_nodes=8, is_risk_assessment=True):
    """Get answer using hybrid retriever with retrieved context"""

    print(f"🎯 Processing question: {question}")

    if not hybrid_retriever:
        return "Error: Retriever not available. Please check if documents are properly loaded in the index."

    try:

        print("🔍 Retrieving with available retrieval method...")
        retrieved_nodes = hybrid_retriever.retrieve(question)
        print(f"📚 Retrieved {len(retrieved_nodes)} nodes")

    except Exception as e:
        print(f"❌ Retrieval failed: {e}")
        return f"Error during document retrieval: {e}. Please check your document index."

    if not retrieved_nodes:
        return "No relevant documents found for this question. Please ensure your medical knowledge base is properly loaded and consult your healthcare provider for medical advice."


    try:
        reranker = SentenceTransformerRerank(
            model='cross-encoder/ms-marco-MiniLM-L-2-v2',
            top_n=max_context_nodes,
        )

        reranked_nodes = reranker.postprocess_nodes(retrieved_nodes, query_str=question)
        print(f"🎯 After reranking: {len(reranked_nodes)} nodes")

    except Exception as e:
        print(f"❌ Reranking failed: {e}, using original nodes")
        reranked_nodes = retrieved_nodes[:max_context_nodes]


    filtered_nodes = []
    pregnancy_keywords = ['pregnancy', 'preeclampsia', 'gestational', 'trimester', 'fetal', 'bleeding', 'contractions', 'prenatal']

    for node in reranked_nodes:
        node_text = node.get_text().lower()
        if any(keyword in node_text for keyword in pregnancy_keywords):
            filtered_nodes.append(node)

    if filtered_nodes:
        reranked_nodes = filtered_nodes[:max_context_nodes]
        print(f"🔍 After pregnancy keyword filtering: {len(reranked_nodes)} nodes")
    else:
        print("⚠️ No pregnancy-related content found, using original nodes")


    context_chunks = []
    total_chars = 0
    max_context_chars = 6000

    for node in reranked_nodes:
        node_text = node.get_text()
        if total_chars + len(node_text) <= max_context_chars:
            context_chunks.append(node_text)
            total_chars += len(node_text)
        else:
            remaining_chars = max_context_chars - total_chars
            if remaining_chars > 100:
                context_chunks.append(node_text[:remaining_chars] + "...")
            break

    context_text = "\n\n---\n\n".join(context_chunks)


    if is_risk_assessment:
        prompt = f"""You are the GraviLog Pregnancy Risk Assessment Agent. Use ONLY the context below—do not invent or add any new medical facts.

SYMPTOM RESPONSES:
{symptom_summary}

MEDICAL KNOWLEDGE:
{context_text}

Respond ONLY in this exact format (no extra text):

🏥 Risk Assessment Complete
**Risk Level:** <Low/Medium/High>
**Recommended Action:** <from KB's Risk Output Labels>

🔬 Rationale:
<One or two sentences citing which bullet(s) from the KB triggered your risk level.>"""

    else:

        prompt = f"""You are a pregnancy health assistant. Based on the medical knowledge below, answer the user's question about pregnancy symptoms and conditions.

USER QUESTION: {question}

CONVERSATION CONTEXT:
{conversation_context}

CURRENT SYMPTOMS REPORTED:
{symptom_summary}

MEDICAL KNOWLEDGE:
{context_text}

Provide a clear, informative answer based on the medical knowledge. Always mention if symptoms require medical attention and provide risk level (Low/Medium/High) when relevant."""

    try:
        print("🤖 Generating response with Groq API...")
        response_text = call_groq_api(prompt)
        return response_text

    except Exception as e:
        print(f"❌ LLM response failed: {e}")
        import traceback
        traceback.print_exc()
        return f"Error generating response: {e}"

def get_answer_with_query_engine(question):
    """Alternative approach using LlamaIndex query engine with hybrid retrieval"""
    try:
        print(f"🎯 Processing question with query engine: {question}")

        if index is None:
            return "Error: Could not load index"


        if hybrid_retriever:
            query_engine = RetrieverQueryEngine.from_args(
                retriever=hybrid_retriever,
                response_synthesizer=get_response_synthesizer(
                    response_mode="compact",
                    use_async=False
                ),
                node_postprocessors=[
                    SentenceTransformerRerank(
                        model='cross-encoder/ms-marco-MiniLM-L-2-v2',
                        top_n=5
                    )
                ]
            )
        else:

            query_engine = index.as_query_engine(
                similarity_top_k=10,
                response_mode="compact"
            )

        print("🤖 Querying with engine...")
        response = query_engine.query(question)

        return str(response)

    except Exception as e:
        print(f"❌ Query engine failed: {e}")
        import traceback
        traceback.print_exc()
        return f"Error with query engine: {e}. Please check your setup and try again."
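
Note: get_direct_answer is the primary answer path (hybrid retrieve, rerank, pregnancy-keyword filter, prompt assembly), and get_answer_with_query_engine is the fallback that app.py invokes when the direct path errors or comes back very short. A hypothetical standalone call under the same environment assumptions as the app (populated Pinecone index, API keys set); this sketch is not part of the commit:

# Hypothetical standalone call (not in the commit).
from rag_functions import get_direct_answer, get_answer_with_query_engine

question = "Is light spotting in the second trimester a concern?"
summary = "Are you currently experiencing any unusual bleeding or discharge?: light spotting"

answer = get_direct_answer(
    question,
    symptom_summary=summary,
    is_risk_assessment=False,  # free-form answer rather than the fixed risk format
)
if "Error" in answer or len(answer) < 50:  # same fallback condition app.py uses
    answer = get_answer_with_query_engine(question)
print(answer)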
requirements.txt ADDED
Binary file (8.69 kB).
utils.py ADDED
@@ -0,0 +1,184 @@
import os
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec
from llama_index.core import (SimpleDirectoryReader, Document, VectorStoreIndex, StorageContext, load_index_from_storage)
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.readers.file import CSVReader
from llama_index.vector_stores.pinecone import PineconeVectorStore
from llama_index.core.settings import Settings
from llama_index.llms.groq import Groq



load_dotenv()


embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
llm = Groq(
    model="llama-3.1-8b-instant",
    api_key=os.getenv("GROQ_API_KEY"),
    max_tokens=500,
    temperature=0.1
)


Settings.embed_model = embed_model
Settings.llm = llm


pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index_name = os.getenv("PINECONE_INDEX")

def get_vector_store():

    pinecone_index = pc.Index(index_name)
    return PineconeVectorStore(pinecone_index=pinecone_index)

def get_storage_context(for_rebuild=False):

    vector_store = get_vector_store()
    persist_dir = "./storage"

    if for_rebuild or not os.path.exists(persist_dir):

        return StorageContext.from_defaults(vector_store=vector_store)
    else:

        return StorageContext.from_defaults(
            vector_store=vector_store,
            persist_dir=persist_dir
        )





def get_and_chunk_documents():

    try:

        file_extractor = {".csv": CSVReader()}


        documents = SimpleDirectoryReader(
            "../knowledge_base",
            file_extractor=file_extractor
        ).load_data()

        print(f"📚 Loaded {len(documents)} documents")

        node_parser = SemanticSplitterNodeParser(
            buffer_size=1,
            breakpoint_percentile_threshold=95,
            embed_model=embed_model
        )

        nodes = node_parser.get_nodes_from_documents(documents)
        print(f"📄 Created {len(nodes)} document chunks")
        return nodes

    except Exception as e:
        print(f"❌ Error loading documents: {e}")
        return []


def get_index():

    try:
        storage_context = get_storage_context()

        return load_index_from_storage(storage_context)
    except Exception as e:
        print(f"⚠️ Local storage not found, creating index from existing Pinecone data...")
        try:

            vector_store = get_vector_store()
            storage_context = get_storage_context()
            index = VectorStoreIndex.from_vector_store(
                vector_store=vector_store,
                storage_context=storage_context
            )
            return index
        except Exception as e2:
            print(f"❌ Error creating index from vector store: {e2}")
            return None

def check_index_status():

    try:
        pinecone_index = pc.Index(index_name)
        stats = pinecone_index.describe_index_stats()
        vector_count = stats.get('total_vector_count', 0)

        if vector_count > 0:
            print(f"✅ Index found with {vector_count} vectors")
            return True
        else:
            print("❌ Index exists but is empty")
            return False
    except Exception as e:
        print(f"❌ Error checking index: {e}")
        return False



def clear_pinecone_index():
    """Delete all vectors from Pinecone index"""
    try:
        pinecone_index = pc.Index(index_name)


        stats = pinecone_index.describe_index_stats()
        vector_count = stats.get('total_vector_count', 0)
        print(f"🗑️ Current vectors in index: {vector_count}")

        if vector_count > 0:

            pinecone_index.delete(delete_all=True)
            print("✅ All vectors deleted from Pinecone index")
        else:
            print("ℹ️ Index is already empty")

        return True

    except Exception as e:
        print(f"❌ Error clearing index: {e}")
        return False

def rebuild_index():
    """Clear old data and rebuild index with new CSV processing"""
    try:
        print("🔄 Starting index rebuild process...")


        if not clear_pinecone_index():
            print("❌ Failed to clear index, aborting rebuild")
            return None


        import shutil
        if os.path.exists("./storage"):
            shutil.rmtree("./storage")
            print("🗑️ Cleared local storage")


        nodes = get_and_chunk_documents()

        if not nodes:
            print("❌ No nodes created, cannot rebuild index")
            return None


        storage_context = get_storage_context(for_rebuild=True)
        index = VectorStoreIndex(nodes, storage_context=storage_context)


        index.storage_context.persist(persist_dir="./storage")

        print(f"✅ Index rebuilt successfully with {len(nodes)} nodes")
        return index

    except Exception as e:
        print(f"❌ Error rebuilding index: {e}")
        return None
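
Note: utils.py reads GROQ_API_KEY, PINECONE_API_KEY, and PINECONE_INDEX via python-dotenv, and get_and_chunk_documents loads from a ../knowledge_base directory that is not part of this commit. A hypothetical one-off rebuild script under those assumptions:

# Hypothetical one-off script (not in the commit). Assumes a .env file
# providing GROQ_API_KEY, PINECONE_API_KEY, and PINECONE_INDEX, plus a
# ../knowledge_base directory containing the source documents.
from utils import check_index_status, rebuild_index

if not check_index_status():  # index missing or empty
    rebuild_index()           # chunks ../knowledge_base, upserts vectors, persists ./storage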