Song commited on
Commit
ee31585
·
1 Parent(s): 5f83bd3
Files changed (1) hide show
  1. app.py +75 -33
app.py CHANGED
@@ -59,8 +59,6 @@ def _require_llm_config():
59
  for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
60
  _require_env(k)
61
 
62
- _require_llm_config()
63
-
64
  CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
65
  FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
66
  SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
@@ -192,10 +190,19 @@ class RagPipeline:
192
  self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
193
  )
194
  self.drug_name_to_ids = self.df_csv.groupby('drug_name_norm_normalized')['drug_id'].unique().apply(list).to_dict()
 
 
 
 
 
 
195
  self._load_drug_name_vocabulary()
196
 
197
  log.info("載入 FAISS 索引與句子資料...")
198
  self.state.index = faiss.read_index(FAISS_INDEX)
 
 
 
199
  with open(SENTENCES_PKL, "rb") as f:
200
  data = pickle.load(f)
201
  self.state.sentences = data["sentences"]
@@ -204,6 +211,8 @@ class RagPipeline:
204
  log.info("載入 BM25 索引...")
205
  with open(BM25_PKL, "rb") as f:
206
  self.state.bm25 = pickle.load(f)
 
 
207
  except (FileNotFoundError, KeyError) as e:
208
  log.exception(f"資料或索引檔案載入失敗: {e}")
209
  raise RuntimeError(f"資料初始化失敗,請檢查檔案路徑與內容: {e}")
@@ -217,10 +226,19 @@ class RagPipeline:
217
  for part in parts:
218
  if re.search(r'[\u4e00-\u9fff]', part):
219
  self.drug_vocab["zh"].add(part)
 
 
 
 
220
  else:
221
  self.drug_vocab["en"].add(part)
222
  for alias in DRUG_NAME_MAPPING:
223
  self.drug_vocab["en"].add(alias.lower())
 
 
 
 
 
224
 
225
  @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
226
  def _llm_call(self, messages, **kwargs) -> str:
@@ -273,9 +291,20 @@ class RagPipeline:
273
 
274
  @lru_cache(maxsize=128)
275
  def _find_drug_ids_from_name(self, query: str) -> List[str]:
276
- candidates = extract_drug_candidates_from_query(query.lower(), self.drug_vocab)
277
-
278
  drug_ids = set()
 
 
 
 
 
 
 
 
 
 
 
279
  for alias in candidates:
280
  # [MODIFIED] 英文藥名比對使用詞邊界,避免子字串誤判
281
  is_english = not re.search(r'[\u4e00-\u9fff]', alias)
@@ -291,7 +320,6 @@ class RagPipeline:
291
  drug_ids.update(ids)
292
  return list(drug_ids)
293
 
294
-
295
  def _analyze_query(self, query: str) -> Dict[str, Any]:
296
  prompt = PROMPT_TEMPLATES["analyze_query"].format(
297
  options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
@@ -311,25 +339,31 @@ class RagPipeline:
311
  expanded_q = self._expand_query_with_llm(sub_q, tuple(intents))
312
 
313
  q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
314
- faiss.normalize_L2(q_emb)
 
315
  distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
316
 
317
  tokenized_query = list(jieba.cut(expanded_q))
318
 
319
- # [MODIFIED] 改為獲取真實 BM25 分數,而非使用排名
320
  bm25_scores = self.state.bm25.get_scores(tokenized_query)
321
- top_bm25_indices = np.argsort(bm25_scores)[::-1][:PRE_RERANK_K]
322
- doc_to_bm25_score = {int(i): float(bm25_scores[i]) for i in top_bm25_indices}
 
 
323
 
324
  candidate_scores: Dict[int, Dict[str, float]] = {}
325
 
326
- # [MODIFIED] FAISS L2 距離轉為相似度 (分數越高越好)
327
- def dist_to_sim(d: float) -> float:
328
- return 1.0 / (1.0 + d)
 
 
 
329
 
330
  for i, dist in zip(sim_indices[0], distances[0]):
331
  if i in relevant_indices:
332
- similarity = dist_to_sim(dist)
333
  candidate_scores[int(i)] = {"sem": float(similarity), "bm": 0.0}
334
 
335
  for i, score in doc_to_bm25_score.items():
@@ -408,16 +442,14 @@ class RagPipeline:
408
  )
409
 
410
  # [MODIFIED] 增強 JSON 解析的穩健性,從字串中提取 JSON 物件
411
- def _safe_json_parse(self, json_str: str, default: Any = None) -> Any:
412
- # Find the JSON object within the string
413
- match = re.search(r'\{.*\}', json_str, re.DOTALL)
414
- if match:
415
- json_str = match.group(0)
416
-
417
  try:
418
- return json.loads(json_str)
419
- except json.JSONDecodeError:
420
- log.warning(f"無法解析 LLM 回傳的 JSON: {json_str}")
421
  return default
422
 
423
  # ---------- FastAPI 事件與路由 ----------
@@ -432,6 +464,7 @@ class AppConfig:
432
  @app.on_event("startup")
433
  async def startup_event():
434
  global rag_pipeline
 
435
  rag_pipeline = RagPipeline(AppConfig)
436
  rag_pipeline.load_data()
437
  log.info("啟動完成,服務準備就緒。")
@@ -457,35 +490,41 @@ async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
457
  if not hmac.compare_digest(expected_signature, signature):
458
  raise HTTPException(status_code=403, detail="Invalid signature")
459
 
460
- data = json.loads(body.decode('utf-8'))
 
 
 
 
461
  for event in data.get("events", []):
462
  if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
463
  reply_token = event.get("replyToken")
464
  user_text = event.get("message", {}).get("text", "").strip()
465
- # [MODIFIED] 安全地獲取 userId,應對群組/聊天室中可能不存在的情況
466
  source = event.get("source", {})
467
- user_id = source.get("userId")
 
468
 
469
- if reply_token and user_id and user_text:
470
  # [MODIFIED] 更改回覆策略:立即回覆處理中訊息,避免 replyToken 逾時
471
  line_reply(reply_token, "收到您的問題,正在查詢資料庫,請稍候...")
472
  # 將耗時的任務交給背景處理,使用 push message 回覆最終答案
473
- background_tasks.add_task(process_user_query, user_id, user_text)
474
 
475
  return Response(status_code=status.HTTP_200_OK)
476
 
477
  # [MODIFIED] 調整函式簽名,只接收 user_id 和 text,並使用 push message
478
- def process_user_query(user_id: str, user_text: str):
479
  try:
480
  if rag_pipeline:
481
  answer = rag_pipeline.answer_question(user_text)
482
  else:
483
  answer = "系統正在啟動中,請稍後再試。"
484
- line_push(user_id, answer)
485
  except Exception as e:
486
- log.error(f"背景處理 user_id={user_id} 發生錯誤: {e}", exc_info=True)
487
- line_push(user_id, f"抱歉,處理時發生未預期的錯誤。{DISCLAIMER}")
488
 
 
489
  def line_api_call(endpoint: str, data: Dict):
490
  headers = {
491
  "Content-Type": "application/json",
@@ -496,14 +535,17 @@ def line_api_call(endpoint: str, data: Dict):
496
  response.raise_for_status()
497
  except requests.exceptions.RequestException as e:
498
  log.error(f"LINE API ({endpoint}) 呼叫失敗: {e} | Response: {e.response.text if e.response else 'N/A'}")
 
499
 
500
  def line_reply(reply_token: str, text: str):
501
  messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
502
  line_api_call("reply", {"replyToken": reply_token, "messages": messages})
503
 
504
- def line_push(user_id: str, text: str):
505
  messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
506
- line_api_call("push", {"to": user_id, "messages": messages})
 
 
507
 
508
  # [MODIFIED] 改善藥名提取的正則表達式
509
  def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
 
59
  for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
60
  _require_env(k)
61
 
 
 
62
  CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
63
  FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
64
  SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
 
190
  self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
191
  )
192
  self.drug_name_to_ids = self.df_csv.groupby('drug_name_norm_normalized')['drug_id'].unique().apply(list).to_dict()
193
+ # [MODIFIED] 把別名也變成可查鍵
194
+ for alias, canonical in DRUG_NAME_MAPPING.items():
195
+ alias_key = re.sub(r'[^\w\s]', '', alias.lower()).strip()
196
+ canonical_key = re.sub(r'[^\w\s]', '', canonical.lower()).strip()
197
+ if canonical_key in self.drug_name_to_ids:
198
+ self.drug_name_to_ids[alias_key] = self.drug_name_to_ids[canonical_key]
199
  self._load_drug_name_vocabulary()
200
 
201
  log.info("載入 FAISS 索引與句子資料...")
202
  self.state.index = faiss.read_index(FAISS_INDEX)
203
+ self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2)
204
+ if hasattr(self.state.index, "nprobe"):
205
+ self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
206
  with open(SENTENCES_PKL, "rb") as f:
207
  data = pickle.load(f)
208
  self.state.sentences = data["sentences"]
 
211
  log.info("載入 BM25 索引...")
212
  with open(BM25_PKL, "rb") as f:
213
  self.state.bm25 = pickle.load(f)
214
+ if not isinstance(self.state.bm25, BM25Okapi):
215
+ raise ValueError("Loaded BM25 is not a BM25Okapi instance.")
216
  except (FileNotFoundError, KeyError) as e:
217
  log.exception(f"資料或索引檔案載入失敗: {e}")
218
  raise RuntimeError(f"資料初始化失敗,請檢查檔案路徑與內容: {e}")
 
226
  for part in parts:
227
  if re.search(r'[\u4e00-\u9fff]', part):
228
  self.drug_vocab["zh"].add(part)
229
+ try:
230
+ jieba.add_word(part, freq=2_000_000)
231
+ except Exception:
232
+ pass
233
  else:
234
  self.drug_vocab["en"].add(part)
235
  for alias in DRUG_NAME_MAPPING:
236
  self.drug_vocab["en"].add(alias.lower())
237
+ if re.search(r'[\u4e00-\u9fff]', alias):
238
+ try:
239
+ jieba.add_word(alias, freq=2_000_000)
240
+ except Exception:
241
+ pass
242
 
243
  @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
244
  def _llm_call(self, messages, **kwargs) -> str:
 
291
 
292
  @lru_cache(maxsize=128)
293
  def _find_drug_ids_from_name(self, query: str) -> List[str]:
294
+ q = query.lower()
295
+ candidates = extract_drug_candidates_from_query(q, self.drug_vocab)
296
  drug_ids = set()
297
+
298
+ # 英文:詞邊界;中文:也做子字串掃描
299
+ for k, ids in self.drug_name_to_ids.items():
300
+ if re.search(r'[\u4e00-\u9fff]', k):
301
+ if k in q:
302
+ drug_ids.update(ids)
303
+ else:
304
+ if re.search(rf"\b{re.escape(k)}\b", q):
305
+ drug_ids.update(ids)
306
+
307
+ # 仍保留舊的候選詞路徑(補強)
308
  for alias in candidates:
309
  # [MODIFIED] 英文藥名比對使用詞邊界,避免子字串誤判
310
  is_english = not re.search(r'[\u4e00-\u9fff]', alias)
 
320
  drug_ids.update(ids)
321
  return list(drug_ids)
322
 
 
323
  def _analyze_query(self, query: str) -> Dict[str, Any]:
324
  prompt = PROMPT_TEMPLATES["analyze_query"].format(
325
  options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
 
339
  expanded_q = self._expand_query_with_llm(sub_q, tuple(intents))
340
 
341
  q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
342
+ if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
343
+ faiss.normalize_L2(q_emb)
344
  distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
345
 
346
  tokenized_query = list(jieba.cut(expanded_q))
347
 
348
+ # [MODIFIED] 先過濾 relevant_indices 再取 TopK
349
  bm25_scores = self.state.bm25.get_scores(tokenized_query)
350
+ rel_idx = np.fromiter(relevant_indices, dtype=int)
351
+ rel_scores = bm25_scores[rel_idx]
352
+ top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
353
+ doc_to_bm25_score = {int(i): float(bm25_scores[i]) for i in top_rel}
354
 
355
  candidate_scores: Dict[int, Dict[str, float]] = {}
356
 
357
+ # [MODIFIED] distance 轉成「越大越好的相似度」
358
+ def to_similarity(d: float) -> float:
359
+ if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
360
+ return float(d) # IP 越大越好
361
+ else: # METRIC_L2(多半是平方 L2)
362
+ return 1.0 / (1.0 + float(d))
363
 
364
  for i, dist in zip(sim_indices[0], distances[0]):
365
  if i in relevant_indices:
366
+ similarity = to_similarity(dist)
367
  candidate_scores[int(i)] = {"sem": float(similarity), "bm": 0.0}
368
 
369
  for i, score in doc_to_bm25_score.items():
 
442
  )
443
 
444
  # [MODIFIED] 增強 JSON 解析的穩健性,從字串中提取 JSON 物件
445
+ def _safe_json_parse(self, s: str, default: Any = None) -> Any:
446
+ m = re.search(r'\{.*?\}', s, re.DOTALL) # 非貪婪
447
+ if m:
448
+ s = m.group(0)
 
 
449
  try:
450
+ return json.loads(s)
451
+ except Exception:
452
+ log.warning(f"無法解析 LLM 回傳的 JSON: {s[:200]}...")
453
  return default
454
 
455
  # ---------- FastAPI 事件與路由 ----------
 
464
  @app.on_event("startup")
465
  async def startup_event():
466
  global rag_pipeline
467
+ _require_llm_config()
468
  rag_pipeline = RagPipeline(AppConfig)
469
  rag_pipeline.load_data()
470
  log.info("啟動完成,服務準備就緒。")
 
490
  if not hmac.compare_digest(expected_signature, signature):
491
  raise HTTPException(status_code=403, detail="Invalid signature")
492
 
493
+ try:
494
+ data = json.loads(body.decode('utf-8'))
495
+ except json.JSONDecodeError:
496
+ raise HTTPException(status_code=400, detail="Invalid JSON body")
497
+
498
  for event in data.get("events", []):
499
  if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
500
  reply_token = event.get("replyToken")
501
  user_text = event.get("message", {}).get("text", "").strip()
502
+ # [MODIFIED] 擷取 target
503
  source = event.get("source", {})
504
+ stype = source.get("type") # "user" | "group" | "room"
505
+ target_id = source.get("userId") or source.get("groupId") or source.get("roomId")
506
 
507
+ if reply_token and user_text and target_id:
508
  # [MODIFIED] 更改回覆策略:立即回覆處理中訊息,避免 replyToken 逾時
509
  line_reply(reply_token, "收到您的問題,正在查詢資料庫,請稍候...")
510
  # 將耗時的任務交給背景處理,使用 push message 回覆最終答案
511
+ background_tasks.add_task(process_user_query, stype, target_id, user_text)
512
 
513
  return Response(status_code=status.HTTP_200_OK)
514
 
515
  # [MODIFIED] 調整函式簽名,只接收 user_id 和 text,並使用 push message
516
+ def process_user_query(source_type: str, target_id: str, user_text: str):
517
  try:
518
  if rag_pipeline:
519
  answer = rag_pipeline.answer_question(user_text)
520
  else:
521
  answer = "系統正在啟動中,請稍後再試。"
522
+ line_push_generic(source_type, target_id, answer)
523
  except Exception as e:
524
+ log.error(f"背景處理 target_id={target_id} 發生錯誤: {e}", exc_info=True)
525
+ line_push_generic(source_type, target_id, f"抱歉,處理時發生未預期的錯誤。{DISCLAIMER}")
526
 
527
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
528
  def line_api_call(endpoint: str, data: Dict):
529
  headers = {
530
  "Content-Type": "application/json",
 
535
  response.raise_for_status()
536
  except requests.exceptions.RequestException as e:
537
  log.error(f"LINE API ({endpoint}) 呼叫失敗: {e} | Response: {e.response.text if e.response else 'N/A'}")
538
+ raise
539
 
540
  def line_reply(reply_token: str, text: str):
541
  messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
542
  line_api_call("reply", {"replyToken": reply_token, "messages": messages})
543
 
544
+ def line_push_generic(source_type: str, target_id: str, text: str):
545
  messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
546
+ endpoint = "push"
547
+ data = {"to": target_id, "messages": messages}
548
+ line_api_call(endpoint, data)
549
 
550
  # [MODIFIED] 改善藥名提取的正則表達式
551
  def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list: