Spaces:
Sleeping
Sleeping
Commit
·
9382e01
1
Parent(s):
8bc48fc
Update RAG pipeline for symptom-based diagnosis retrieval
Browse files
app.py
CHANGED
|
@@ -94,6 +94,9 @@ except Exception as e:
|
|
| 94 |
logger.error(f"❌ Model Loading Failed: {e}")
|
| 95 |
exit(1)
|
| 96 |
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
# ✅ Setup MongoDB Connection
|
| 99 |
# QA data
|
|
@@ -104,6 +107,9 @@ qa_collection = db["qa_data"]
|
|
| 104 |
iclient = MongoClient(index_uri)
|
| 105 |
idb = iclient["MedicalChatbotDB"]
|
| 106 |
index_collection = idb["faiss_index_files"]
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
# ✅ Load FAISS Index (Lazy Load)
|
| 109 |
import gridfs
|
|
@@ -142,6 +148,26 @@ def retrieve_medical_info(query, k=5, min_sim=0.6): # Min similarity between que
|
|
| 142 |
results.append(doc.get("Doctor", "No answer available."))
|
| 143 |
return results if results else ["No relevant medical entries found."]
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
# ✅ Gemini Flash API Call
|
| 147 |
def gemini_flash_completion(prompt, model, temperature=0.7):
|
|
@@ -161,8 +187,11 @@ class RAGMedicalChatbot:
|
|
| 161 |
|
| 162 |
def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
|
| 163 |
# 1. Fetch knowledge
|
|
|
|
| 164 |
retrieved_info = self.retrieve(user_query)
|
| 165 |
knowledge_base = "\n".join(retrieved_info)
|
|
|
|
|
|
|
| 166 |
|
| 167 |
# 2. Use relevant chunks from short-term memory FAISS index (nearest 3 chunks)
|
| 168 |
context = memory.get_relevant_chunks(user_id, user_query, top_k=3)
|
|
@@ -177,6 +206,9 @@ class RAGMedicalChatbot:
|
|
| 177 |
# Load up guideline
|
| 178 |
if knowledge_base:
|
| 179 |
parts.append(f"Medical knowledge (256,916 medical scenario): {knowledge_base}")
|
|
|
|
|
|
|
|
|
|
| 180 |
parts.append(f"Question: {user_query}")
|
| 181 |
parts.append(f"Language: {lang}")
|
| 182 |
prompt = "\n\n".join(parts)
|
|
|
|
| 94 |
logger.error(f"❌ Model Loading Failed: {e}")
|
| 95 |
exit(1)
|
| 96 |
|
| 97 |
# Lazily-populated in-memory cache of the symptom-diagnosis embeddings.
# Optional optimization — only worthwhile for collections under ~10k rows.
SYMPTOM_VECTORS = SYMPTOM_DOCS = None
|
| 100 |
|
| 101 |
# ✅ Setup MongoDB Connection
|
| 102 |
# QA data
|
|
|
|
| 107 |
iclient = MongoClient(index_uri)
|
| 108 |
idb = iclient["MedicalChatbotDB"]
|
| 109 |
index_collection = idb["faiss_index_files"]
|
| 110 |
# Symptom Diagnosis data
symptom_client = MongoClient(mongo_uri)
_symptom_db = symptom_client["MedicalChatbotDB"]
symptom_col = _symptom_db["symptom_diagnosis"]
|
| 113 |
|
| 114 |
# ✅ Load FAISS Index (Lazy Load)
|
| 115 |
import gridfs
|
|
|
|
| 148 |
results.append(doc.get("Doctor", "No answer available."))
|
| 149 |
return results if results else ["No relevant medical entries found."]
|
| 150 |
|
| 151 |
# ✅ Retrieve Sym-Dia Info (4962 scenario)
def retrieve_diagnosis_from_symptoms(symptom_text, top_k=5, min_sim=0.4):
    """Return up to ``top_k`` diagnosis answers for a free-text symptom query.

    Embeds ``symptom_text`` with the module-level ``embedding_model``, scores
    it by cosine similarity against the cached symptom embeddings, and keeps
    only matches with similarity >= ``min_sim`` (best-first order).

    The full embedding matrix is lazily loaded from the ``symptom_diagnosis``
    Mongo collection on first call and cached in module globals.

    Returns an empty list when nothing clears the threshold or the
    collection is empty.
    """
    global SYMPTOM_VECTORS, SYMPTOM_DOCS
    # Lazy load: pull every doc once and cache (collection is small, ~5k rows).
    if SYMPTOM_VECTORS is None:
        all_docs = list(
            symptom_col.find(
                # Skip docs without a stored vector — they would otherwise
                # raise KeyError while building the matrix below.
                {"embedding": {"$exists": True}},
                {"embedding": 1, "answer": 1, "question": 1},
            )
        )
        SYMPTOM_DOCS = all_docs
        SYMPTOM_VECTORS = np.array(
            [doc["embedding"] for doc in all_docs], dtype=np.float32
        )
        if SYMPTOM_VECTORS.size:
            # Row-normalize so the dot product below is a true cosine score.
            # Idempotent: a no-op if stored embeddings are already unit length.
            norms = np.linalg.norm(SYMPTOM_VECTORS, axis=1, keepdims=True)
            SYMPTOM_VECTORS = SYMPTOM_VECTORS / np.maximum(norms, 1e-9)
    # Empty collection → nothing to match (avoids a shape error in the matmul).
    if SYMPTOM_VECTORS.size == 0:
        return []
    # Embed input and unit-normalize it.
    qvec = embedding_model.encode(symptom_text, convert_to_numpy=True)
    qvec = qvec / (np.linalg.norm(qvec) + 1e-9)
    # Similarity compute
    sims = SYMPTOM_VECTORS @ qvec  # cosine
    sorted_idx = np.argsort(sims)[-top_k:][::-1]
    # Final: best-first, thresholded.
    return [
        SYMPTOM_DOCS[i]["answer"]
        for i in sorted_idx
        if sims[i] >= min_sim
    ]
|
| 171 |
|
| 172 |
# ✅ Gemini Flash API Call
|
| 173 |
def gemini_flash_completion(prompt, model, temperature=0.7):
|
|
|
|
| 187 |
|
| 188 |
def chat(self, user_id: str, user_query: str, lang: str = "EN") -> str:
|
| 189 |
# 1. Fetch knowledge
|
| 190 |
+
## a. KB for generic QA retrieval
|
| 191 |
retrieved_info = self.retrieve(user_query)
|
| 192 |
knowledge_base = "\n".join(retrieved_info)
|
| 193 |
+
## b. Diagnosis RAG from symptom query
|
| 194 |
+
diagnosis_guides = retrieve_diagnosis_from_symptoms(user_query) # smart matcher
|
| 195 |
|
| 196 |
# 2. Use relevant chunks from short-term memory FAISS index (nearest 3 chunks)
|
| 197 |
context = memory.get_relevant_chunks(user_id, user_query, top_k=3)
|
|
|
|
| 206 |
# Load up guideline
|
| 207 |
if knowledge_base:
|
| 208 |
parts.append(f"Medical knowledge (256,916 medical scenario): {knowledge_base}")
|
| 209 |
+
# Symptom-Diagnosis prediction RAG
|
| 210 |
+
if diagnosis_guides:
|
| 211 |
+
parts.append("Symptom-based diagnosis guidance:\n" + "\n".join(diagnosis_guides))
|
| 212 |
parts.append(f"Question: {user_query}")
|
| 213 |
parts.append(f"Language: {lang}")
|
| 214 |
prompt = "\n\n".join(parts)
|