Commit 115b95d · Parent: 2afe3f5
Enhance agentic memory handler with dynamic LTM/STM
app.py
CHANGED
@@ -246,7 +246,7 @@ class RAGMedicalChatbot:
             parts.append(
                 "A user medical image is diagnosed by our VLM agent:\n"
                 f"{image_diagnosis}\n\n"
-                "
+                "Please incorporate the above findings in your response if medically relevant.\n\n"
             )
 
         # Append contextual chunks from hybrid approach
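Note: a minimal sketch of the prompt fragment this hunk now appends when a VLM diagnosis is present. The `image_diagnosis` value is hypothetical, and the code that later joins `parts` into the full prompt is outside this diff.

# Illustrative only — hypothetical image_diagnosis value.
image_diagnosis = "Chest X-ray: opacity in the right lower lobe, consistent with pneumonia."
parts = []
parts.append(
    "A user medical image is diagnosed by our VLM agent:\n"
    f"{image_diagnosis}\n\n"
    "Please incorporate the above findings in your response if medically relevant.\n\n"
)
print("".join(parts))  # fragment that joins the chatbot's final prompt elsewhere in app.py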
memory.py
CHANGED
@@ -18,10 +18,14 @@ api_key = os.getenv("FlashAPI")
 client = genai.Client(api_key=api_key)
 
 class MemoryManager:
-    def __init__(self, max_users=1000, history_per_user=
+    def __init__(self, max_users=1000, history_per_user=20, max_chunks=60):
+        # STM: recent conversation summaries (topic + summary), up to 5 entries
+        self.stm_summaries = defaultdict(lambda: deque(maxlen=history_per_user))  # deque of {topic,text,vec,timestamp,used}
+        # Legacy raw cache (kept for compatibility if needed)
         self.text_cache = defaultdict(lambda: deque(maxlen=history_per_user))
+        # LTM: semantic chunk store (approx 3 chunks x 20 rounds)
         self.chunk_index = defaultdict(self._new_index)  # user_id -> faiss index
-        self.chunk_meta = defaultdict(list)  #
+        self.chunk_meta = defaultdict(list)  # '' -> list[{text,tag,vec,timestamp,used}]
         self.user_queue = deque(maxlen=max_users)  # LRU of users
         self.max_chunks = max_chunks  # hard cap per user
         self.chunk_cache = {}  # hash(query+resp) -> [chunks]
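Note: a sketch of the per-user records the new constructor manages, inferred from the inline comments in this hunk. The field values are hypothetical; the vectors would come from `self._embed`.

# Hypothetical STM entry (stm_summaries[user_id] holds a deque of these):
stm_entry = {
    "topic": "Adult with productive cough started on a 7-day amoxicillin course",
    "text": "Amoxicillin 500 mg three times daily for 7 days; review if fever persists.",
    "vec": None,            # L2-normalised embedding from self._embed(text)
    "timestamp": 1717000000.0,
    "used": 0,
}
# Hypothetical LTM entry (chunk_meta[user_id] holds a list of these, mirrored in the FAISS index):
ltm_entry = {"tag": stm_entry["topic"], "text": stm_entry["text"],
             "vec": None, "timestamp": 1717000000.0, "used": 0}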
@@ -29,6 +33,7 @@ class MemoryManager:
     # ---------- Public API ----------
     def add_exchange(self, user_id: str, query: str, response: str, lang: str = "EN"):
         self._touch_user(user_id)
+        # Keep raw record (optional)
         self.text_cache[user_id].append(((query or "").strip(), (response or "").strip()))
         if not response: return []
         # Avoid re-chunking identical response
@@ -36,26 +41,14 @@ class MemoryManager:
         if cache_key in self.chunk_cache:
             chunks = self.chunk_cache[cache_key]
         else:
-            chunks = self.chunk_response(response, lang)
+            chunks = self.chunk_response(response, lang, question=query)
             self.chunk_cache[cache_key] = chunks
-
-        # Store chunks → faiss
+        # Update STM with merging/deduplication
         for chunk in chunks:
-
-
-
-
-            # Store each chunk's vector once and reuse it
-            chunk_with_vec = {
-                **chunk,
-                "vec": vec,
-                "timestamp": time.time(),  # store creation time
-                "used": 0  # track usage
-            }
-            self.chunk_meta[user_id].append(chunk_with_vec)
-            # Trim to max_chunks to keep latency O(1)
-            if len(self.chunk_meta[user_id]) > self.max_chunks:
-                self._rebuild_index(user_id, keep_last=self.max_chunks)
+            self._upsert_stm(user_id, chunk, lang)
+        # Update LTM with merging/deduplication
+        self._upsert_ltm(user_id, chunks, lang)
+        return chunks
 
     def get_relevant_chunks(self, user_id: str, query: str, top_k: int = 3, min_sim: float = 0.30) -> List[str]:
         """Return texts of chunks whose cosine similarity ≥ min_sim."""
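Note: a hypothetical usage sketch of the reworked add_exchange flow. It assumes the FlashAPI key is configured, since chunk_response still calls Gemini before the STM/LTM upserts run.

from memory import MemoryManager  # assumes memory.py is importable and FlashAPI is set

mm = MemoryManager()
chunks = mm.add_exchange(
    user_id="user-123",
    query="What antibiotic dose was recommended?",
    response="Amoxicillin 500 mg three times daily for 7 days. Return if fever persists.",
    lang="EN",
)
# add_exchange now returns the summarized chunks after upserting them into
# STM (mm.stm_summaries[user_id]) and LTM (mm.chunk_meta[user_id] + FAISS index).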
@@ -70,7 +63,7 @@ class MemoryManager:
             if idx < len(self.chunk_meta[user_id]) and sim >= min_sim:
                 chunk = self.chunk_meta[user_id][idx]
                 chunk["used"] += 1  # increment usage
-                # Decay function
+                # Decay function
                 age_sec = time.time() - chunk["timestamp"]
                 decay = 1.0 / (1.0 + age_sec / 300)  # 5-min half-life
                 score = sim * decay * (1 + 0.1 * chunk["used"])
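Note: a worked example of the retrieval score above, with illustrative numbers.

sim, age_sec, used = 0.80, 300.0, 2        # cosine similarity, chunk age (s), prior retrievals
decay = 1.0 / (1.0 + age_sec / 300)        # = 0.5 once a chunk is 5 minutes old
score = sim * decay * (1 + 0.1 * used)     # = 0.80 * 0.5 * 1.2 = 0.48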
@@ -81,29 +74,27 @@ class MemoryManager:
         # logger.info(f"[Memory] RAG Retrieved Topic: {results}")  # Inspect vector data
         return [f"### Topic: {c['tag']}\n{c['text']}" for _, c in results]
 
-    def get_recent_chat_history(self, user_id: str, num_turns: int =
+    def get_recent_chat_history(self, user_id: str, num_turns: int = 5) -> List[Dict]:
         """
-        Get the most recent
-        Returns:
+        Get the most recent short-term memory summaries.
+        Returns: a list of entries containing only the summarized bot context.
         """
-        if user_id not in self.
+        if user_id not in self.stm_summaries:
             return []
-
-
-
-
-
-
-                "
-                "bot": response,
-                "timestamp": time.time()  # We could store actual timestamps if needed
+        recent = list(self.stm_summaries[user_id])[-num_turns:]
+        formatted = []
+        for entry in recent:
+            formatted.append({
+                "user": "",
+                "bot": f"Topic: {entry['topic']}\n{entry['text']}",
+                "timestamp": entry.get("timestamp", time.time())
             })
-
-        return formatted_history
+        return formatted
 
-    def get_context(self, user_id: str, num_turns: int =
-
-
+    def get_context(self, user_id: str, num_turns: int = 5) -> str:
+        # Prefer STM summaries
+        history = self.get_recent_chat_history(user_id, num_turns=num_turns)
+        return "\n".join(h["bot"] for h in history)
 
     def get_contextual_chunks(self, user_id: str, current_query: str, lang: str = "EN") -> str:
         """
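Note: a sketch of what the STM-backed accessors now return; the entry below is hypothetical.

history = [{
    "user": "",
    "bot": "Topic: Adult with suspected pneumonia advised a 7-day amoxicillin course\n"
           "Amoxicillin 500 mg three times daily; review if fever persists beyond 48 hours.",
    "timestamp": 1717000000.0,
}]
context = "\n".join(h["bot"] for h in history)  # what get_context() would hand to the LLM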
@@ -111,7 +102,7 @@ class MemoryManager:
         This ensures conversational continuity while providing a concise summary for the main LLM.
         """
         # Get both types of context
-        recent_history = self.get_recent_chat_history(user_id, num_turns=
+        recent_history = self.get_recent_chat_history(user_id, num_turns=5)
         rag_chunks = self.get_relevant_chunks(user_id, current_query, top_k=3)
 
         logger.info(f"[Contextual] Retrieved {len(recent_history)} recent history items")
@@ -133,7 +124,7 @@ class MemoryManager:
         # Add RAG chunks
         if rag_chunks:
             rag_text = "\n".join(rag_chunks)
-            context_parts.append(f"Semantically relevant medical information:\n{rag_text}")
+            context_parts.append(f"Semantically relevant historical medical information:\n{rag_text}")
 
         # Build summarization prompt
         summarization_prompt = f"""
@@ -232,7 +223,7 @@ class MemoryManager:
         # L2 normalise for cosine on IndexFlatIP
         return vec / (np.linalg.norm(vec) + 1e-9)
 
-    def chunk_response(self, response: str, lang: str) -> List[Dict]:
+    def chunk_response(self, response: str, lang: str, question: str = "") -> List[Dict]:
         """
         Calls Gemini to:
         - Translate (if needed)
@@ -245,15 +236,17 @@ class MemoryManager:
         instructions = []
         # if lang.upper() != "EN":
         #     instructions.append("- Translate the response to English.")
-        instructions.append("- Break the translated (or original) text into semantically distinct parts, grouped by medical topic or
-        instructions.append("- For each part, generate a clear, concise summary. The summary may vary in length depending on the complexity of the topic — do not omit key clinical instructions.")
-        instructions.append("- At the start of each part, write `Topic: <
+        instructions.append("- Break the translated (or original) text into semantically distinct parts, grouped by medical topic, symptom, assessment, plan, or instruction (exclude disclaimer section).")
+        instructions.append("- For each part, generate a clear, concise summary. The summary may vary in length depending on the complexity of the topic — do not omit key clinical instructions and exact medication names/doses if present.")
+        instructions.append("- At the start of each part, write `Topic: <concise but specific sentence (10-20 words) capturing patient context, condition, and action>`.")
         instructions.append("- Separate each part using three dashes `---` on a new line.")
         # if lang.upper() != "EN":
         #     instructions.append(f"Below is the user-provided medical response written in `{lang}`")
         # Gemini prompt
         prompt = f"""
         You are a medical assistant helping organize and condense a clinical response.
+        If helpful, use the user's latest question for context to craft specific topics.
+        User's latest question (context): {question}
         ------------------------
         {response}
         ------------------------
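Note: a hypothetical example of the output format these instructions request from Gemini. The parsing inside chunk_response is outside this hunk; only the `Topic:` prefix and `---` separator conventions are illustrated.

example_model_output = (
    "Topic: Adult with productive cough started on amoxicillin 500 mg, 7-day course\n"
    "Amoxicillin 500 mg three times daily for 7 days; review if fever persists.\n"
    "---\n"
    "Topic: Safety-netting advice for worsening breathing or new chest pain\n"
    "Seek urgent care if breathing worsens or chest pain develops."
)
parts = [p.strip() for p in example_model_output.split("---") if p.strip()]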
@@ -296,4 +289,138 @@ class MemoryManager:
             return line.strip().rstrip(":")
         return " ".join(chunk.split()[:3]).rstrip(":.,")
 
+    # ---------- New merging/dedup logic ----------
+    def _upsert_stm(self, user_id: str, chunk: Dict, lang: str):
+        """Insert or merge a summarized chunk into STM with semantic dedup/merge.
+        Identical: replace the older with new. Partially similar: merge extra details from older into newer.
+        """
+        topic = self._enrich_topic(chunk.get("tag", ""), chunk.get("text", ""))
+        text = chunk.get("text", "").strip()
+        vec = self._embed(text)
+        now = time.time()
+        entry = {"topic": topic, "text": text, "vec": vec, "timestamp": now, "used": 0}
+        stm = self.stm_summaries[user_id]
+        if not stm:
+            stm.append(entry)
+            return
+        # find best match
+        best_idx = -1
+        best_sim = -1.0
+        for i, e in enumerate(stm):
+            sim = float(np.dot(vec, e["vec"]))
+            if sim > best_sim:
+                best_sim = sim
+                best_idx = i
+        if best_sim >= 0.92:  # nearly identical
+            # replace older with current
+            stm.rotate(-best_idx)
+            stm.popleft()
+            stm.rotate(best_idx)
+            stm.append(entry)
+        elif best_sim >= 0.75:  # partially similar → merge
+            base = stm[best_idx]
+            merged_text = self._merge_texts(new_text=text, old_text=base["text"])  # add bits from old not in new
+            merged_topic = base["topic"] if len(base["topic"]) > len(topic) else topic
+            merged_vec = self._embed(merged_text)
+            merged_entry = {"topic": merged_topic, "text": merged_text, "vec": merged_vec, "timestamp": now, "used": base.get("used", 0)}
+            stm.rotate(-best_idx)
+            stm.popleft()
+            stm.rotate(best_idx)
+            stm.append(merged_entry)
+        else:
+            stm.append(entry)
+
+    def _upsert_ltm(self, user_id: str, chunks: List[Dict], lang: str):
+        """Insert or merge chunks into LTM with semantic dedup/merge, then rebuild index.
+        Keeps only the most recent self.max_chunks entries.
+        """
+        current_list = self.chunk_meta[user_id]
+        for chunk in chunks:
+            text = chunk.get("text", "").strip()
+            if not text:
+                continue
+            vec = self._embed(text)
+            topic = self._enrich_topic(chunk.get("tag", ""), text)
+            now = time.time()
+            new_entry = {"tag": topic, "text": text, "vec": vec, "timestamp": now, "used": 0}
+            if not current_list:
+                current_list.append(new_entry)
+                continue
+            # find best similar entry
+            best_idx = -1
+            best_sim = -1.0
+            for i, e in enumerate(current_list):
+                sim = float(np.dot(vec, e["vec"]))
+                if sim > best_sim:
+                    best_sim = sim
+                    best_idx = i
+            if best_sim >= 0.92:
+                # replace older with new
+                current_list[best_idx] = new_entry
+            elif best_sim >= 0.75:
+                # merge details
+                base = current_list[best_idx]
+                merged_text = self._merge_texts(new_text=text, old_text=base["text"])  # add unique sentences from old
+                merged_topic = base["tag"] if len(base["tag"]) > len(topic) else topic
+                merged_vec = self._embed(merged_text)
+                current_list[best_idx] = {"tag": merged_topic, "text": merged_text, "vec": merged_vec, "timestamp": now, "used": base.get("used", 0)}
+            else:
+                current_list.append(new_entry)
+        # Trim and rebuild index
+        if len(current_list) > self.max_chunks:
+            current_list[:] = current_list[-self.max_chunks:]
+        self._rebuild_index(user_id, keep_last=self.max_chunks)
+
+    @staticmethod
+    def _split_sentences(text: str) -> List[str]:
+        # naive sentence splitter by ., !, ?
+        parts = re.split(r"(?<=[\.!?])\s+", text.strip())
+        return [p.strip() for p in parts if p.strip()]
+
+    def _merge_texts(self, new_text: str, old_text: str) -> str:
+        """Append sentences from old_text that are not already contained in new_text (by fuzzy match)."""
+        new_sents = self._split_sentences(new_text)
+        old_sents = self._split_sentences(old_text)
+        new_set = set(s.lower() for s in new_sents)
+        merged = list(new_sents)
+        for s in old_sents:
+            s_norm = s.lower()
+            # consider present if significant overlap with any existing sentence
+            if s_norm in new_set:
+                continue
+            # simple containment check
+            if any(self._overlap_ratio(s_norm, t.lower()) > 0.8 for t in merged):
+                continue
+            merged.append(s)
+        return " ".join(merged)
+
+    @staticmethod
+    def _overlap_ratio(a: str, b: str) -> float:
+        """Compute token overlap ratio between two sentences."""
+        ta = set(re.findall(r"\w+", a))
+        tb = set(re.findall(r"\w+", b))
+        if not ta or not tb:
+            return 0.0
+        inter = len(ta & tb)
+        union = len(ta | tb)
+        return inter / union
+
+    @staticmethod
+    def _enrich_topic(topic: str, text: str) -> str:
+        """Make topic more descriptive if it's too short by using the first sentence of the text.
+        Does not call LLM to keep latency low.
+        """
+        topic = (topic or "").strip()
+        if len(topic.split()) < 5 or len(topic) < 20:
+            sents = re.split(r"(?<=[\.!?])\s+", text.strip())
+            if sents:
+                first = sents[0]
+                # cap to ~16 words
+                words = first.split()
+                if len(words) > 16:
+                    first = " ".join(words[:16])
+                # ensure capitalized
+                return first.strip().rstrip(':')
+        return topic
+
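Note: a standalone sketch of the decision logic the new upsert helpers apply (thresholds 0.92/0.75 taken from this diff), plus a quick check of the token-overlap heuristic that _merge_texts relies on. The real code compares L2-normalised Gemini embeddings for the similarity value; this snippet runs without an API key.

import re

def overlap_ratio(a: str, b: str) -> float:
    # token Jaccard overlap, mirroring MemoryManager._overlap_ratio
    ta, tb = set(re.findall(r"\w+", a.lower())), set(re.findall(r"\w+", b.lower()))
    return len(ta & tb) / len(ta | tb) if ta and tb else 0.0

def upsert_decision(sim: float) -> str:
    # same branching as _upsert_stm / _upsert_ltm
    if sim >= 0.92:
        return "replace the older entry"
    if sim >= 0.75:
        return "merge missing sentences from the older entry into the newer one"
    return "append as a new entry"

print(upsert_decision(0.95))  # replace the older entry
print(upsert_decision(0.80))  # merge missing sentences from the older entry into the newer one
print(upsert_decision(0.40))  # append as a new entry
print(round(overlap_ratio("take amoxicillin 500 mg daily",
                          "take amoxicillin 500 mg twice daily"), 2))  # 0.83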