LiamKhoaLe committed on
Commit 115b95d (parent: 2afe3f5)

Enhance agentic memory handler with dynamic LTM/STM

Files changed (2)
  1. app.py +1 -1
  2. memory.py +172 -45
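
For orientation, a minimal sketch of how the reworked memory flow is meant to be driven (the import path and the sample strings are assumed; only the MemoryManager methods shown in the memory.py diff below are taken from the commit):

```python
# Hypothetical wiring; illustrates call order only, not app.py's actual code.
from memory import MemoryManager   # assumed module path

memory = MemoryManager()           # per-user STM summaries + LTM chunk index
user_id = "demo-user"              # placeholder id

# After each bot turn: chunk/summarize the response, then upsert into STM and LTM.
memory.add_exchange(
    user_id,
    query="What ibuprofen dose is safe for an adult?",
    response="For adults, 200-400 mg every 4-6 hours is typical; do not exceed 1200 mg/day OTC.",
    lang="EN",
)

# On the next turn: STM gives recent summarized context, LTM gives semantic recall.
stm_context = memory.get_context(user_id, num_turns=5)
ltm_chunks = memory.get_relevant_chunks(user_id, "Can I take it with food?", top_k=3)
combined = memory.get_contextual_chunks(user_id, "Can I take it with food?", lang="EN")
```

Note that add_exchange and get_contextual_chunks call Gemini internally, so the sketch only runs with the FlashAPI key configured.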
app.py CHANGED
@@ -246,7 +246,7 @@ class RAGMedicalChatbot:
                 parts.append(
                     "A user medical image is diagnosed by our VLM agent:\n"
                     f"{image_diagnosis}\n\n"
-                    "➡️ Please incorporate the above findings in your response if medically relevant.\n\n"
+                    "Please incorporate the above findings in your response if medically relevant.\n\n"
                 )
 
             # Append contextual chunks from hybrid approach
memory.py CHANGED
@@ -18,10 +18,14 @@ api_key = os.getenv("FlashAPI")
18
  client = genai.Client(api_key=api_key)
19
 
20
  class MemoryManager:
21
- def __init__(self, max_users=1000, history_per_user=10, max_chunks=30):
 
 
 
22
  self.text_cache = defaultdict(lambda: deque(maxlen=history_per_user))
 
23
  self.chunk_index = defaultdict(self._new_index) # user_id -> faiss index
24
- self.chunk_meta = defaultdict(list) # '' -> list[{text,tag}]
25
  self.user_queue = deque(maxlen=max_users) # LRU of users
26
  self.max_chunks = max_chunks # hard cap per user
27
  self.chunk_cache = {} # hash(query+resp) -> [chunks]
@@ -29,6 +33,7 @@ class MemoryManager:
     # ---------- Public API ----------
     def add_exchange(self, user_id: str, query: str, response: str, lang: str = "EN"):
         self._touch_user(user_id)
+        # Keep raw record (optional)
         self.text_cache[user_id].append(((query or "").strip(), (response or "").strip()))
         if not response: return []
         # Avoid re-chunking identical response
@@ -36,26 +41,14 @@ class MemoryManager:
         if cache_key in self.chunk_cache:
             chunks = self.chunk_cache[cache_key]
         else:
-            chunks = self.chunk_response(response, lang)
+            chunks = self.chunk_response(response, lang, question=query)
             self.chunk_cache[cache_key] = chunks
-        text_set = set(c["text"] for c in self.chunk_meta[user_id])  # Set list of metadata for deduplication
-        # Store chunks → faiss
+        # Update STM with merging/deduplication
         for chunk in chunks:
-            if chunk["text"] in text_set:
-                continue  # skip duplicate
-            vec = self._embed(chunk["text"])
-            self.chunk_index[user_id].add(np.array([vec]))
-            # Store each chunk's vector once and reuse it
-            chunk_with_vec = {
-                **chunk,
-                "vec": vec,
-                "timestamp": time.time(),  # store creation time
-                "used": 0                  # track usage
-            }
-            self.chunk_meta[user_id].append(chunk_with_vec)
-        # Trim to max_chunks to keep latency O(1)
-        if len(self.chunk_meta[user_id]) > self.max_chunks:
-            self._rebuild_index(user_id, keep_last=self.max_chunks)
+            self._upsert_stm(user_id, chunk, lang)
+        # Update LTM with merging/deduplication
+        self._upsert_ltm(user_id, chunks, lang)
+        return chunks
 
     def get_relevant_chunks(self, user_id: str, query: str, top_k: int = 3, min_sim: float = 0.30) -> List[str]:
         """Return texts of chunks whose cosine similarity ≥ min_sim."""
@@ -70,7 +63,7 @@ class MemoryManager:
             if idx < len(self.chunk_meta[user_id]) and sim >= min_sim:
                 chunk = self.chunk_meta[user_id][idx]
                 chunk["used"] += 1  # increment usage
-                # Decay function (you can tweak)
+                # Decay function
                 age_sec = time.time() - chunk["timestamp"]
                 decay = 1.0 / (1.0 + age_sec / 300)  # 5-min half-life
                 score = sim * decay * (1 + 0.1 * chunk["used"])
@@ -81,29 +74,27 @@ class MemoryManager:
         # logger.info(f"[Memory] RAG Retrieved Topic: {results}")  # Inspect vector data
         return [f"### Topic: {c['tag']}\n{c['text']}" for _, c in results]
 
-    def get_recent_chat_history(self, user_id: str, num_turns: int = 3) -> List[Dict]:
+    def get_recent_chat_history(self, user_id: str, num_turns: int = 5) -> List[Dict]:
         """
-        Get the most recent chat history with both user questions and bot responses.
-        Returns: [{"user": "question", "bot": "response", "timestamp": time}, ...]
+        Get the most recent short-term memory summaries.
+        Returns: a list of entries containing only the summarized bot context.
         """
-        if user_id not in self.text_cache:
+        if user_id not in self.stm_summaries:
             return []
-        # Get the most recent chat history
-        recent_history = list(self.text_cache[user_id])[-num_turns:]
-        formatted_history = []
-        # Format the history
-        for query, response in recent_history:
-            formatted_history.append({
-                "user": query,
-                "bot": response,
-                "timestamp": time.time()  # We could store actual timestamps if needed
+        recent = list(self.stm_summaries[user_id])[-num_turns:]
+        formatted = []
+        for entry in recent:
+            formatted.append({
+                "user": "",
+                "bot": f"Topic: {entry['topic']}\n{entry['text']}",
+                "timestamp": entry.get("timestamp", time.time())
             })
-
-        return formatted_history
+        return formatted
 
-    def get_context(self, user_id: str, num_turns: int = 3) -> str:
-        history = list(self.text_cache.get(user_id, []))[-num_turns:]
-        return "\n".join(f"User: {q}\nBot: {r}" for q, r in history)
+    def get_context(self, user_id: str, num_turns: int = 5) -> str:
+        # Prefer STM summaries
+        history = self.get_recent_chat_history(user_id, num_turns=num_turns)
+        return "\n".join(h["bot"] for h in history)
 
     def get_contextual_chunks(self, user_id: str, current_query: str, lang: str = "EN") -> str:
         """
@@ -111,7 +102,7 @@ class MemoryManager:
         This ensures conversational continuity while providing a concise summary for the main LLM.
         """
         # Get both types of context
-        recent_history = self.get_recent_chat_history(user_id, num_turns=3)
+        recent_history = self.get_recent_chat_history(user_id, num_turns=5)
        rag_chunks = self.get_relevant_chunks(user_id, current_query, top_k=3)
 
        logger.info(f"[Contextual] Retrieved {len(recent_history)} recent history items")
@@ -133,7 +124,7 @@ class MemoryManager:
         # Add RAG chunks
         if rag_chunks:
             rag_text = "\n".join(rag_chunks)
-            context_parts.append(f"Semantically relevant medical information:\n{rag_text}")
+            context_parts.append(f"Semantically relevant historical medical information:\n{rag_text}")
 
         # Build summarization prompt
         summarization_prompt = f"""
@@ -232,7 +223,7 @@ class MemoryManager:
         # L2 normalise for cosine on IndexFlatIP
         return vec / (np.linalg.norm(vec) + 1e-9)
 
-    def chunk_response(self, response: str, lang: str) -> List[Dict]:
+    def chunk_response(self, response: str, lang: str, question: str = "") -> List[Dict]:
         """
         Calls Gemini to:
         - Translate (if needed)
245
  instructions = []
246
  # if lang.upper() != "EN":
247
  # instructions.append("- Translate the response to English.")
248
- instructions.append("- Break the translated (or original) text into semantically distinct parts, grouped by medical topic or symptom.")
249
- instructions.append("- For each part, generate a clear, concise summary. The summary may vary in length depending on the complexity of the topic — do not omit key clinical instructions.")
250
- instructions.append("- At the start of each part, write `Topic: <one line description>`.")
251
  instructions.append("- Separate each part using three dashes `---` on a new line.")
252
  # if lang.upper() != "EN":
253
  # instructions.append(f"Below is the user-provided medical response written in `{lang}`")
254
  # Gemini prompt
255
  prompt = f"""
256
  You are a medical assistant helping organize and condense a clinical response.
 
 
257
  ------------------------
258
  {response}
259
  ------------------------
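
The instructions above ask Gemini to return parts separated by `---`, each opening with a `Topic:` line. The parser that consumes this format sits outside the hunk, so the snippet below is only a hypothetical shape check for the expected output, not the code in memory.py:

```python
# Hypothetical parser for the `---`-separated, `Topic:`-prefixed format requested above.
def parse_parts(raw: str):
    parts = []
    for block in raw.split("---"):
        lines = [l.strip() for l in block.strip().splitlines() if l.strip()]
        if not lines:
            continue
        if lines[0].lower().startswith("topic:"):
            tag, body = lines[0][len("Topic:"):].strip(), " ".join(lines[1:])
        else:
            tag, body = "", " ".join(lines)
        parts.append({"tag": tag, "text": body})
    return parts

sample = (
    "Topic: Adult ibuprofen dosing for knee pain\n"
    "Take 200-400 mg every 4-6 hours with food.\n"
    "---\n"
    "Topic: When to seek in-person care\n"
    "See a doctor if pain persists beyond a week."
)
print(parse_parts(sample))  # two {"tag": ..., "text": ...} dicts
```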
@@ -296,4 +289,138 @@ class MemoryManager:
                 return line.strip().rstrip(":")
         return " ".join(chunk.split()[:3]).rstrip(":.,")
 
+    # ---------- New merging/dedup logic ----------
+    def _upsert_stm(self, user_id: str, chunk: Dict, lang: str):
+        """Insert or merge a summarized chunk into STM with semantic dedup/merge.
+        Identical: replace the older with new. Partially similar: merge extra details from older into newer.
+        """
+        topic = self._enrich_topic(chunk.get("tag", ""), chunk.get("text", ""))
+        text = chunk.get("text", "").strip()
+        vec = self._embed(text)
+        now = time.time()
+        entry = {"topic": topic, "text": text, "vec": vec, "timestamp": now, "used": 0}
+        stm = self.stm_summaries[user_id]
+        if not stm:
+            stm.append(entry)
+            return
+        # find best match
+        best_idx = -1
+        best_sim = -1.0
+        for i, e in enumerate(stm):
+            sim = float(np.dot(vec, e["vec"]))
+            if sim > best_sim:
+                best_sim = sim
+                best_idx = i
+        if best_sim >= 0.92:  # nearly identical
+            # replace older with current
+            stm.rotate(-best_idx)
+            stm.popleft()
+            stm.rotate(best_idx)
+            stm.append(entry)
+        elif best_sim >= 0.75:  # partially similar → merge
+            base = stm[best_idx]
+            merged_text = self._merge_texts(new_text=text, old_text=base["text"])  # add bits from old not in new
+            merged_topic = base["topic"] if len(base["topic"]) > len(topic) else topic
+            merged_vec = self._embed(merged_text)
+            merged_entry = {"topic": merged_topic, "text": merged_text, "vec": merged_vec, "timestamp": now, "used": base.get("used", 0)}
+            stm.rotate(-best_idx)
+            stm.popleft()
+            stm.rotate(best_idx)
+            stm.append(merged_entry)
+        else:
+            stm.append(entry)
+
+    def _upsert_ltm(self, user_id: str, chunks: List[Dict], lang: str):
+        """Insert or merge chunks into LTM with semantic dedup/merge, then rebuild index.
+        Keeps only the most recent self.max_chunks entries.
+        """
+        current_list = self.chunk_meta[user_id]
+        for chunk in chunks:
+            text = chunk.get("text", "").strip()
+            if not text:
+                continue
+            vec = self._embed(text)
+            topic = self._enrich_topic(chunk.get("tag", ""), text)
+            now = time.time()
+            new_entry = {"tag": topic, "text": text, "vec": vec, "timestamp": now, "used": 0}
+            if not current_list:
+                current_list.append(new_entry)
+                continue
+            # find best similar entry
+            best_idx = -1
+            best_sim = -1.0
+            for i, e in enumerate(current_list):
+                sim = float(np.dot(vec, e["vec"]))
+                if sim > best_sim:
+                    best_sim = sim
+                    best_idx = i
+            if best_sim >= 0.92:
+                # replace older with new
+                current_list[best_idx] = new_entry
+            elif best_sim >= 0.75:
+                # merge details
+                base = current_list[best_idx]
+                merged_text = self._merge_texts(new_text=text, old_text=base["text"])  # add unique sentences from old
+                merged_topic = base["tag"] if len(base["tag"]) > len(topic) else topic
+                merged_vec = self._embed(merged_text)
+                current_list[best_idx] = {"tag": merged_topic, "text": merged_text, "vec": merged_vec, "timestamp": now, "used": base.get("used", 0)}
+            else:
+                current_list.append(new_entry)
+        # Trim and rebuild index
+        if len(current_list) > self.max_chunks:
+            current_list[:] = current_list[-self.max_chunks:]
+        self._rebuild_index(user_id, keep_last=self.max_chunks)
+
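
Both _upsert_stm and _upsert_ltm apply the same thresholds to cosine similarity over L2-normalised embeddings: at or above 0.92 the new entry replaces the old one, between 0.75 and 0.92 the two are merged, otherwise the entry is appended. A toy illustration with hand-made unit vectors standing in for _embed output:

```python
import numpy as np
from typing import List

def upsert_decision(new_vec: np.ndarray, existing: List[np.ndarray]) -> str:
    """Mirror of the threshold logic above, on toy 2-D unit vectors."""
    if not existing:
        return "append"
    best = max(float(np.dot(new_vec, v)) for v in existing)  # cosine, vectors already unit-norm
    if best >= 0.92:
        return "replace"   # nearly identical: keep only the newest wording
    if best >= 0.75:
        return "merge"     # partially similar: fold old details into the new text
    return "append"

a = np.array([1.0, 0.0])
print(upsert_decision(np.array([0.97, 0.24]), [a]))  # replace (dot ≈ 0.97)
print(upsert_decision(np.array([0.80, 0.60]), [a]))  # merge   (dot = 0.80)
print(upsert_decision(np.array([0.60, 0.80]), [a]))  # append  (dot = 0.60)
```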
+    @staticmethod
+    def _split_sentences(text: str) -> List[str]:
+        # naive sentence splitter by ., !, ?
+        parts = re.split(r"(?<=[\.!?])\s+", text.strip())
+        return [p.strip() for p in parts if p.strip()]
+
+    def _merge_texts(self, new_text: str, old_text: str) -> str:
+        """Append sentences from old_text that are not already contained in new_text (by fuzzy match)."""
+        new_sents = self._split_sentences(new_text)
+        old_sents = self._split_sentences(old_text)
+        new_set = set(s.lower() for s in new_sents)
+        merged = list(new_sents)
+        for s in old_sents:
+            s_norm = s.lower()
+            # consider present if significant overlap with any existing sentence
+            if s_norm in new_set:
+                continue
+            # simple containment check
+            if any(self._overlap_ratio(s_norm, t.lower()) > 0.8 for t in merged):
+                continue
+            merged.append(s)
+        return " ".join(merged)
+
+    @staticmethod
+    def _overlap_ratio(a: str, b: str) -> float:
+        """Compute token overlap ratio between two sentences."""
+        ta = set(re.findall(r"\w+", a))
+        tb = set(re.findall(r"\w+", b))
+        if not ta or not tb:
+            return 0.0
+        inter = len(ta & tb)
+        union = len(ta | tb)
+        return inter / union
+
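
_overlap_ratio is a token-level Jaccard index; _merge_texts treats an old sentence as already covered when this ratio against any kept sentence exceeds 0.8. A small worked example:

```python
import re

def overlap_ratio(a: str, b: str) -> float:
    # Same token Jaccard as MemoryManager._overlap_ratio
    ta, tb = set(re.findall(r"\w+", a)), set(re.findall(r"\w+", b))
    return len(ta & tb) / len(ta | tb) if ta and tb else 0.0

kept = "take 400 mg ibuprofen with food"
print(round(overlap_ratio(kept, "take 400 mg of ibuprofen with food"), 2))
# 0.86 > 0.8: treated as duplicate and skipped
print(round(overlap_ratio(kept, "take 400 mg ibuprofen with food twice daily"), 2))
# 0.75, not above 0.8: the older sentence is appended, so "twice daily" survives the merge
```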
+    @staticmethod
+    def _enrich_topic(topic: str, text: str) -> str:
+        """Make topic more descriptive if it's too short by using the first sentence of the text.
+        Does not call LLM to keep latency low.
+        """
+        topic = (topic or "").strip()
+        if len(topic.split()) < 5 or len(topic) < 20:
+            sents = re.split(r"(?<=[\.!?])\s+", text.strip())
+            if sents:
+                first = sents[0]
+                # cap to ~16 words
+                words = first.split()
+                if len(words) > 16:
+                    first = " ".join(words[:16])
+                # ensure capitalized
+                return first.strip().rstrip(':')
+        return topic
+