Song committed on
Commit 92ee3c2 · 1 Parent(s): 7b2e5cd
Files changed (2)
  1. app.py +105 -143
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,14 +1,3 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- """
5
- DrugQA (ZH) — 優化版 FastAPI LINE Webhook (最終版)
6
- 整合 RAG 邏輯,包含 LLM 意圖偵測、子查詢分解、Intent-aware 檢索與 Rerank。
7
- 新增動態字數調整、多次互動邏輯與對話狀態管理,提升使用者體驗。
8
- 僅支援十種藥物。
9
- """
10
-
11
- # ---------- 環境與快取設定 ----------
12
  import os
13
  import pathlib
14
  import re
@@ -28,9 +17,8 @@ from contextlib import asynccontextmanager
28
  import unicodedata
29
  from collections import defaultdict
30
  import asyncio
31
- import aiohttp # 新增:導入 aiohttp 用於異步 HTTP 請求
32
 
33
- # ------------ 第三方函式庫 -------------
34
  import numpy as np
35
  import pandas as pd
36
  import jieba
@@ -44,7 +32,7 @@ import requests
44
  import uvicorn
45
  from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
46
 
47
- # ---- 限制 PyTorch 執行緒數量,避免 CPU 環境下過度佔用資源 ----
48
  torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1")))
49
 
50
  # ===== CONFIG =====
@@ -55,30 +43,24 @@ def _require_env(var: str) -> str:
55
  raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
56
  return v
57
 
58
-
59
  def _require_llm_config():
60
  for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
61
  _require_env(k)
62
 
63
-
64
  # --------- 路徑設定 ------------
65
  CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
66
  FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
67
  SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
68
  BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
69
-
70
  TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 20))
71
  PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
72
  MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
73
-
74
  EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
75
-
76
  LLM_API_CONFIG = {
77
  "base_url": _require_env("LITELLM_BASE_URL"),
78
  "api_key": _require_env("LITELLM_API_KEY"),
79
  "model": _require_env("LM_MODEL"),
80
  }
81
-
82
  LLM_MODEL_CONFIG = {
83
  "max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 10000)),
84
  "max_tokens_simple": int(os.getenv("MAX_TOKENS_SIMPLE", 256)),
@@ -95,7 +77,6 @@ INTENT_CATEGORIES = [
95
  "劑量調整 (Dosage Adjustment)",
96
  "禁忌症/適應症 (Contraindications/Indications)",
97
  ]
98
-
99
  INTENT_TO_SECTION = {
100
  "操作 (Administration)": ["用法用量", "病人使用須知"],
101
  "保存/攜帶 (Storage & Handling)": ["包裝及儲存"],
@@ -105,7 +86,6 @@ INTENT_TO_SECTION = {
105
  "劑量調整 (Dosage Adjustment)": ["用法用量"],
106
  "禁忌症/適應症 (Contraindications/Indications)": ["適應症", "禁忌", "警語與注意事項"],
107
  }
108
-
109
  DRUG_NAME_MAPPING = {
110
  "fentanyl patch": "fentanyl",
111
  "spiriva respimat": "spiriva",
@@ -122,7 +102,6 @@ SUPPORTED_DRUGS = list(DRUG_NAME_MAPPING.keys())
122
  DISCLAIMER = (
123
  "本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
124
  )
125
-
126
  REFERENCE_MAPPING = {
127
  "如何用藥?": "病人使用須知、用法用量",
128
  "如何保存與攜帶?": "包裝及儲存",
@@ -130,7 +109,6 @@ REFERENCE_MAPPING = {
130
  "每次劑量多少?": "用法用量、藥袋上的醫囑",
131
  "用藥時間?": "用法用量、藥袋上的醫囑",
132
  }
133
-
134
  REFERENCE_TO_INTENT = {
135
  "如何用藥?": ["操作 (Administration)"],
136
  "如何保存與攜帶?": ["保存/攜帶 (Storage & Handling)"],
@@ -138,20 +116,16 @@ REFERENCE_TO_INTENT = {
138
  "每次劑量多少?": ["劑量調整 (Dosage Adjustment)"],
139
  "用藥時間?": ["時間/併用 (Timing & Interaction)"],
140
  }
141
-
142
  PROMPT_TEMPLATES = {
143
  "analyze_query": """
144
  請分析以下使用者問題,並完成以下三個任務:
145
- 1. 將問題分解為 1-3 個核心子問題。
146
- 2. 從清單中選擇所有相關的意圖分類。
147
- 3. 評估問題複雜度,返回 'simple'(單一問題或簡單意圖)或 'complex'(多子問題或複雜意圖,如副作用、劑量調整)。
148
-
149
  請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列)、'intents' (字串陣列) 和 'complexity' (字串) 三個鍵。
150
  範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"], "complexity": "simple"}}
151
-
152
  意圖分類清單:
153
  {options}。
154
-
155
  使用者問題:{query}
156
  """,
157
  "expand_query": """
@@ -161,54 +135,41 @@ PROMPT_TEMPLATES = {
161
  """,
162
  "final_answer": """
163
  您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆:
164
-
165
  一、 回覆規範:
166
- - 回覆語言:使用繁體中文,口語化且易懂,避免專業術語或解釋之。
167
- - 結構:先以「簡答:」標記提供簡短總結答案(50-100字),然後以「詳答:」標記提供詳細解釋,最後提醒使用者諮詢醫師。
168
- - 長度:簡答控制在50-100字,詳答根據問題複雜度調整,簡單問題約100-200字,複雜問題(如多步驟的裝置安裝或藥品使用)可達300-500字。
169
- - 態度:親切、專業、關懷,避免驚嚇使用者。
170
- {additional_instruction}
171
-
172
- ---
173
- 參考資料:
174
- {context}
175
- ---
176
 
 
 
 
 
 
177
  使用者問題:{query}
178
-
179
  請直接輸出最終的答案:
180
  """,
181
  "analyze_reference": """
182
  從以下清單選擇最匹配的使用者問題分類,如果沒有匹配,返回 'none'。
183
-
184
  分類清單:
185
  {options}
186
-
187
  使用者問題:{query}
188
-
189
  請僅輸出分類名稱或 'none',不需任何額外的解釋或格式。
190
  """,
191
  "clarification": """
192
  請根據以下使用者問題,生成一個簡潔、禮貌的澄清性提問,以幫助我更精準地回答。問題應引導使用者提供更多細節,例如具體藥名、使用情境,並附上範例問題。請在回覆中明確告知使用者,目前僅支援以下藥物詢問:
193
- - Fentanyl patch
194
- - Spiriva Respimat
195
- - NITROSTAT
196
- - AUGMENTIN FOR SYRUP
197
- - Ozempic
198
- - NIFLEC
199
- - Fosamax
200
- - Humira
201
- - PREMARIN
202
- - SMECTA
203
-
204
  範例:
205
  使用者問題:這個藥會怎麼樣?
206
  澄清提問:您好,請問您指的藥物是下列哪一種?目前僅支援以下藥物詢問:Fentanyl patch、Spiriva Respimat...等。例如,您可以問:「Fentanyl patch 的副作用有哪些?」請確認藥名或提供更多細節。
207
-
208
  使用者問題:{query}
209
  """
210
  }
211
-
212
  # ---------- 日誌設定 ----------
213
  logging.basicConfig(
214
  level=logging.INFO,
@@ -222,7 +183,6 @@ def _norm(s: str) -> str:
222
  s = unicodedata.normalize("NFKC", s)
223
  return re.sub(r"[^\w\s]", "", s.lower()).strip()
224
 
225
-
226
  @dataclass
227
  class FusedCandidate:
228
  idx: int
@@ -230,7 +190,6 @@ class FusedCandidate:
230
  sem_score: float
231
  bm_score: float
232
 
233
-
234
  @dataclass
235
  class RerankResult:
236
  idx: int
@@ -238,7 +197,6 @@ class RerankResult:
238
  text: str
239
  meta: Dict[str, Any] = field(default_factory=dict)
240
 
241
-
242
  @dataclass
243
  class ConversationState:
244
  query_history: List[str] = field(default_factory=list)
@@ -248,7 +206,6 @@ class ConversationState:
248
  last_answer: Optional[str] = None
249
  clarification_count: int = 0
250
 
251
-
252
  # ---------- 核心 RAG 邏輯 ----------
253
  class RagPipeline:
254
  def __init__(self):
@@ -314,8 +271,8 @@ class RagPipeline:
314
  with open(BM25_PKL, "rb") as f:
315
  bm25_data = pickle.load(f)
316
  self.state.bm25 = bm25_data["bm25"]
317
- if not isinstance(self.state.bm25, BM25Okapi):
318
- raise ValueError("Loaded BM25 is not a BM25Okapi instance.")
319
 
320
  log.info("所有模型與資料載入完成。")
321
 
@@ -334,9 +291,11 @@ class RagPipeline:
334
  for part in q_norm_parts:
335
  if part in self.drug_name_to_ids:
336
  drug_ids.update(self.drug_name_to_ids[part])
 
337
  for drug_name, ids in self.drug_name_to_ids.items():
338
  if drug_name in _norm(query):
339
  drug_ids.update(ids)
 
340
  return sorted(drug_ids)
341
 
342
  def _build_drug_name_to_ids(self) -> Dict[str, List[str]]:
@@ -355,11 +314,14 @@ class RagPipeline:
355
  part = part.strip()
356
  if part and len(part) > 1:
357
  self.drug_name_to_ids.setdefault(part, []).append(drug_id)
 
358
  for alias, canonical_name in DRUG_NAME_MAPPING.items():
359
  if _norm(canonical_name) in _norm(row["drug_name_norm"]):
360
  self.drug_name_to_ids.setdefault(_norm(alias), []).append(drug_id)
 
361
  for key in self.drug_name_to_ids:
362
  self.drug_name_to_ids[key] = sorted(set(self.drug_name_to_ids[key]))
 
363
  return self.drug_name_to_ids
364
 
365
  def _load_drug_name_vocabulary(self):
@@ -372,17 +334,19 @@ class RagPipeline:
372
  self.drug_vocab["zh"].add(word)
373
  else:
374
  self.drug_vocab["en"].add(word)
375
- for alias in DRUG_NAME_MAPPING:
376
- if re.search(r"[\u4e00-\u9fff]", alias):
377
- self.drug_vocab["zh"].add(alias)
378
- else:
379
- self.drug_vocab["en"].add(alias)
380
- for word in self.drug_vocab["zh"]:
381
- try:
382
- if word not in jieba.dt.FREQ:
383
- jieba.add_word(word, freq=2_000_000)
384
- except Exception:
385
- pass
 
 
386
 
387
  @tenacity.retry(
388
  wait=tenacity.wait_fixed(2),
@@ -421,6 +385,7 @@ class RagPipeline:
421
  conv_state.clarification_count += 1
422
  if conv_state.clarification_count > 3:
423
  return "抱歉,多次無法識別您的問題,請確認藥物名稱或聯繫醫師。\n" + DISCLAIMER, []
 
424
  clarification = self._generate_clarification_query(q_orig)
425
  conv_state.last_answer = clarification
426
  return f"{clarification}\n\n{DISCLAIMER}", []
@@ -436,31 +401,37 @@ class RagPipeline:
436
  sections = [s.strip() for s in sections_str.split('、') if s.strip() and s != '藥袋上的醫囑']
437
  intents = REFERENCE_TO_INTENT.get(ref_key, [])
438
  context = self._build_context_from_csv(drug_ids, sections)
 
439
  # 根據參考資料判斷複雜度
440
  if any(sec in ["用法用量", "病人使用須知", "劑型相關"] for sec in sections):
441
  complexity = "complex" # 多步驟的裝置安裝或藥品使用
442
  elif any(sec in ["不良反應", "警語與注意事項"] for sec in sections):
443
  complexity = "simple" # 副作用問題
 
 
444
  else:
445
- return await self._fallback_rag(target_id, q_orig, drug_ids)
 
446
 
447
  conv_state.intents = intents
448
  conv_state.complexity = complexity
449
-
450
  max_tokens = LLM_MODEL_CONFIG["max_tokens_complex"] if complexity == "complex" else LLM_MODEL_CONFIG["max_tokens_simple"]
451
  prompt = self._make_final_prompt(q_orig, context, intents)
452
  answer = self._llm_call(
453
  [{"role": "user", "content": prompt}],
454
  max_tokens=max_tokens
455
  )
 
456
  if not answer:
457
  return f"無法回答您的問題。\n{DISCLAIMER}", drug_ids
458
 
459
  answer = answer.replace("*", "")
460
  conv_state.last_answer = answer
461
  final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
 
462
  log.info(f"查詢處理完成,耗時: {time.time() - start_time:.2f}秒")
463
  return final_answer, drug_ids
 
464
  except Exception as e:
465
  log.error(f"處理查詢時發生錯誤: {e}", exc_info=True)
466
  return f"處理時發生內部錯誤,請稍後再試。\n{DISCLAIMER}", []
@@ -471,13 +442,16 @@ class RagPipeline:
471
  sub_queries = analysis.get("sub_queries", [q_orig])
472
  intents = analysis.get("intents", [])
473
  complexity = "simple" # 預設為簡單
 
474
  sections = []
475
  for intent in intents:
476
  sections.extend(INTENT_TO_SECTION.get(intent, []))
 
477
  if any(sec in ["用法用量", "病人使用須知", "劑型相關"] for sec in sections):
478
  complexity = "complex"
479
  elif any(sec in ["不良反應", "警語與注意事項"] for sec in sections):
480
  complexity = "simple"
 
481
  conv_state.intents = intents
482
  conv_state.complexity = complexity
483
 
@@ -486,6 +460,7 @@ class RagPipeline:
486
  conv_state.clarification_count += 1
487
  if conv_state.clarification_count > 3:
488
  return "抱歉,多次無法識別您的問題,請確認藥物名稱或聯繫醫師。\n" + DISCLAIMER, drug_ids
 
489
  clarification = self._generate_clarification_query(q_orig)
490
  conv_state.last_answer = clarification
491
  return f"{clarification}\n\n{DISCLAIMER}", drug_ids
@@ -494,7 +469,6 @@ class RagPipeline:
494
  drug_ids, sub_queries, intents
495
  )
496
  final_candidates = all_candidates[:TOP_K_SENTENCES]
497
-
498
  reranked_results = [
499
  RerankResult(
500
  idx=c.idx,
@@ -504,6 +478,7 @@ class RagPipeline:
504
  )
505
  for c in final_candidates
506
  ]
 
507
  prioritized = self._prioritize_context(reranked_results, intents)
508
  context = self._build_context(prioritized)
509
 
@@ -516,6 +491,7 @@ class RagPipeline:
516
  [{"role": "user", "content": prompt}],
517
  max_tokens=max_tokens
518
  )
 
519
  if not answer:
520
  return f"無法回答您的問題。\n{DISCLAIMER}", drug_ids
521
 
@@ -540,9 +516,9 @@ class RagPipeline:
540
  for drug_id in drug_ids:
541
  drug_df = self.df_csv[self.df_csv['drug_id'] == drug_id]
542
  for sec in sections:
543
- sec_row = drug_df[drug_df['section'].str.contains(sec, na=False)]
544
- if not sec_row.empty:
545
- content = sec_row.iloc[0]['content']
546
  if len(context) + len(content) > LLM_MODEL_CONFIG["max_context_chars"]:
547
  return context.strip()
548
  context += content + "\n\n"
@@ -572,32 +548,42 @@ class RagPipeline:
572
  return []
573
 
574
  all_fused_candidates: Dict[int, FusedCandidate] = {}
 
575
  for sub_q in sub_queries:
576
  expanded_q = self._expand_query_with_llm(sub_q, intents)
577
  q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
 
578
  if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
579
  faiss.normalize_L2(q_emb)
 
580
  distances, sem_indices = self.state.index.search(q_emb, PRE_RERANK_K)
581
 
582
  tokenized_query = list(jieba.cut(expanded_q))
583
  bm25_scores = self.state.bm25.get_scores(tokenized_query)
 
584
  rel_idx = np.fromiter(relevant_indices, dtype=np.int64)
585
  rel_scores = bm25_scores[rel_idx]
586
  top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
587
  doc_to_bm25_score: Dict[int, float] = {
588
  int(i): float(bm25_scores[i]) for i in top_rel
589
  }
 
590
  candidate_scores: Dict[int, Dict[str, float]] = {}
 
591
  def to_similarity(d: float) -> float:
592
  return float(d) if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT else 1.0 / (1.0 + float(d))
 
593
  for i, dist in zip(sem_indices[0], distances[0]):
594
  if i in relevant_indices:
595
  candidate_scores[i] = {"sem": to_similarity(dist), "bm": 0.0}
 
596
  for i, score in doc_to_bm25_score.items():
597
  if i in relevant_indices:
598
  candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
 
599
  if not candidate_scores:
600
  continue
 
601
  keys = list(candidate_scores.keys())
602
  sem_scores = np.array([candidate_scores[k]["sem"] for k in keys])
603
  bm_scores = np.array([candidate_scores[k]["bm"] for k in keys])
@@ -606,12 +592,14 @@ class RagPipeline:
606
  return (x - x.min()) / (x.max() - x.min() + 1e-8) if x.max() - x.min() > 0 else np.zeros_like(x)
607
 
608
  sem_n, bm_n = norm(sem_scores), norm(bm_scores)
 
609
  for idx, k in enumerate(keys):
610
  fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
611
  if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
612
  all_fused_candidates[k] = FusedCandidate(
613
  idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
614
  )
 
615
  return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
616
 
617
  def _expand_query_with_llm(self, query: str, intents: List[str]) -> str:
@@ -626,11 +614,14 @@ class RagPipeline:
626
  def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]:
627
  if not intents:
628
  return results
 
629
  prioritized_sections = set()
630
  for intent in intents:
631
  prioritized_sections.update(INTENT_TO_SECTION.get(intent, []))
 
632
  if not prioritized_sections:
633
  return results
 
634
  prioritized, other = [], []
635
  for res in results:
636
  if res.meta.get("section") in prioritized_sections:
@@ -665,6 +656,7 @@ class RagPipeline:
665
  add_instr += "\n請根據以下問題與參考資料對應回答:"
666
  for q, refs in REFERENCE_MAPPING.items():
667
  add_instr += f"\n- {q}: {refs}"
 
668
  return PROMPT_TEMPLATES["final_answer"].format(
669
  additional_instruction=add_instr, context=context, query=query
670
  )
@@ -674,22 +666,18 @@ class RagPipeline:
674
  return json.loads(s)
675
  except json.JSONDecodeError:
676
  try:
677
- m = re.search(r"\{.*?\}", s, re.DOTALL)
678
  if m:
679
  return json.loads(m.group(0))
680
  except json.JSONDecodeError:
681
  pass
682
- return default
683
-
684
 
685
  # ---------- FastAPI 事件與路由 ----------
686
  class AppConfig:
687
  CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
688
  CHANNEL_SECRET = _require_env("CHANNEL_SECRET")
689
-
690
-
691
- rag_pipeline: Optional[RagPipeline] = None
692
-
693
 
694
  @asynccontextmanager
695
  async def lifespan(app: FastAPI):
@@ -701,10 +689,8 @@ async def lifespan(app: FastAPI):
701
  yield
702
  log.info("服務關閉中。")
703
 
704
-
705
  app = FastAPI(lifespan=lifespan)
706
 
707
-
708
  @app.post("/webhook")
709
  async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
710
  signature = request.headers.get("X-Line-Signature")
@@ -712,6 +698,7 @@ async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
712
  raise HTTPException(status_code=400, detail="Missing LINE X-Line-Signature header")
713
 
714
  body = await request.body()
 
715
  try:
716
  hash_obj = hmac.new(AppConfig.CHANNEL_SECRET.encode("utf-8"), body, hashlib.sha256)
717
  expected_signature = base64.b64encode(hash_obj.digest()).decode("utf-8")
@@ -728,65 +715,55 @@ async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
728
  raise HTTPException(status_code=400, detail="Invalid JSON body")
729
 
730
  for event in data.get("events", []):
731
- if (
732
- event.get("type") == "message"
733
- and event.get("message", {}).get("type") == "text"
734
- ):
735
- user_text = event.get("message", {}).get("text", "").strip()
736
  source = event.get("source", {})
737
  stype = source.get("type")
738
  target_id = (
739
  source.get("userId") or source.get("groupId") or source.get("roomId")
740
  )
741
- if user_text and target_id:
742
- background_tasks.add_task(
743
- process_user_query, stype, target_id, user_text
744
- )
745
- return Response(status_code=status.HTTP_200_OK)
746
 
 
 
 
 
 
 
 
747
 
748
- async def process_user_query(source_type: str, target_id: str, user_text: str):
749
  try:
750
  if not rag_pipeline:
751
- await line_push_generic(source_type, target_id,
752
  "系統正在啟動中,請稍後再試。")
753
  return
754
- answer, drug_ids = await rag_pipeline.answer_question(target_id, user_text)
755
- await line_push_generic(source_type, target_id, answer)
 
 
756
  except Exception as e:
757
  log.error(f"背景處理 target_id={target_id} 發生錯誤: {e}", exc_info=True)
758
- await line_push_generic(
759
  source_type,
760
  target_id,
761
  f"抱歉,處理時發生未預期的錯誤。\n{DISCLAIMER}",
762
  )
763
 
764
-
765
  @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
766
- async def line_api_call(endpoint: str, data: Dict):
767
  headers = {
768
  "Content-Type": "application/json",
769
  "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}",
770
  }
771
- async with aiohttp.ClientSession() as session:
772
- async with session.post(
773
- f"https://api.line.me/v2/bot/message/{endpoint}",
774
- headers=headers,
775
- json=data,
776
- timeout=10,
777
- ) as response:
778
- response.raise_for_status()
779
-
780
-
781
- async def line_reply(reply_token: str, text: str):
782
- messages = [
783
- {"type": "text", "text": chunk}
784
- for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]
785
- ]
786
- await line_api_call("reply", {"replyToken": reply_token, "messages": messages})
787
-
788
 
789
- async def line_push_generic(source_type: str, target_id: str, text: str):
790
  messages = [
791
  {"type": "text", "text": chunk}
792
  for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]
@@ -794,24 +771,9 @@ async def line_push_generic(source_type: str, target_id: str, text: str):
794
  if "目前僅支援以下藥物詢問" in text:
795
  drug_list = "\n".join(f"- {drug}" for drug in SUPPORTED_DRUGS)
796
  messages.append({"type": "text", "text": f"支援的藥物清單:\n{drug_list}"})
797
- data = {"to": target_id, "messages": messages}
798
- await line_api_call("push", data)
799
-
800
-
801
- def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> List[str]:
802
- candidates = set()
803
- q_norm = _norm(query)
804
- for word in re.findall(r"[a-z0-9]+", q_norm):
805
- if word in drug_vocab["en"]:
806
- candidates.add(word)
807
- for token in jieba.cut(q_norm):
808
- if token in drug_vocab["zh"]:
809
- candidates.add(token)
810
- supported_drugs = set(DRUG_NAME_MAPPING.keys()).union(DRUG_NAME_MAPPING.values())
811
- if not candidates.issubset(supported_drugs):
812
- candidates = set()
813
- return list(candidates)
814
 
 
 
815
 
816
  # ---------- 執行 ----------
817
  if __name__ == "__main__":

1
  import os
2
  import pathlib
3
  import re
 
17
  import unicodedata
18
  from collections import defaultdict
19
  import asyncio
 
20
 
21
+ # 第三方函式庫
22
  import numpy as np
23
  import pandas as pd
24
  import jieba
 
32
  import uvicorn
33
  from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
34
 
35
+ # 限制 PyTorch 執行緒數量,避免 CPU 環境下過度佔用資源
36
  torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1")))
37
 
38
  # ===== CONFIG =====
 
43
  raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
44
  return v
45
 
 
46
  def _require_llm_config():
47
  for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
48
  _require_env(k)
49
 
 
50
  # --------- 路徑設定 ------------
51
  CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
52
  FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
53
  SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
54
  BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
 
55
  TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 20))
56
  PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
57
  MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
 
58
  EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
 
59
  LLM_API_CONFIG = {
60
  "base_url": _require_env("LITELLM_BASE_URL"),
61
  "api_key": _require_env("LITELLM_API_KEY"),
62
  "model": _require_env("LM_MODEL"),
63
  }
 
64
  LLM_MODEL_CONFIG = {
65
  "max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 10000)),
66
  "max_tokens_simple": int(os.getenv("MAX_TOKENS_SIMPLE", 256)),
 
77
  "劑量調整 (Dosage Adjustment)",
78
  "禁忌症/適應症 (Contraindications/Indications)",
79
  ]
 
80
  INTENT_TO_SECTION = {
81
  "操作 (Administration)": ["用法用量", "病人使用須知"],
82
  "保存/攜帶 (Storage & Handling)": ["包裝及儲存"],
 
86
  "劑量調整 (Dosage Adjustment)": ["用法用量"],
87
  "禁忌症/適應症 (Contraindications/Indications)": ["適應症", "禁忌", "警語與注意事項"],
88
  }
 
89
  DRUG_NAME_MAPPING = {
90
  "fentanyl patch": "fentanyl",
91
  "spiriva respimat": "spiriva",
 
102
  DISCLAIMER = (
103
  "本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
104
  )
 
105
  REFERENCE_MAPPING = {
106
  "如何用藥?": "病人使用須知、用法用量",
107
  "如何保存與攜帶?": "包裝及儲存",
 
109
  "每次劑量多少?": "用法用量、藥袋上的醫囑",
110
  "用藥時間?": "用法用量、藥袋上的醫囑",
111
  }
 
112
  REFERENCE_TO_INTENT = {
113
  "如何用藥?": ["操作 (Administration)"],
114
  "如何保存與攜帶?": ["保存/攜帶 (Storage & Handling)"],
 
116
  "每次劑量多少?": ["劑量調整 (Dosage Adjustment)"],
117
  "用藥時間?": ["時間/併用 (Timing & Interaction)"],
118
  }
 
119
  PROMPT_TEMPLATES = {
120
  "analyze_query": """
121
  請分析以下使用者問題,並完成以下三個任務:
122
+ 將問題分解為 1-3 個核心子問題。
123
+ 從清單中選擇所有相關的意圖分類。
124
+ 評估問題複雜度,返回 'simple'(單一問題或簡單意圖)或 'complex'(多子問題或複雜意圖,如副作用、劑量調整)。
 
125
  請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列)、'intents' (字串陣列) 和 'complexity' (字串) 三個鍵。
126
  範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"], "complexity": "simple"}}
 
127
  意圖分類清單:
128
  {options}。
 
129
  使用者問題:{query}
130
  """,
131
  "expand_query": """
 
135
  """,
136
  "final_answer": """
137
  您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆:
 
138
  一、 回覆規範:
139
 
140
+ 回覆語言:使用繁體中文,口語化且易懂,避免專業術語或解釋之。
141
+ 結構:先以「簡答:」標記提供簡短總結答案(50-100字),然後以「詳答:」標記提供詳細解釋,最後提醒使用者諮詢醫師。
142
+ 長度:簡答控制在50-100字,詳答根據問題複雜度調整,簡單問題約100-200字,複雜問題(如多步驟的裝置安裝或藥品使用)可達300-500字。
143
+ 態度:親切、專業、關懷,避免驚嚇使用者。 {additional_instruction}
144
+ 參考資料: {context}
145
  使用者問題:{query}
 
146
  請直接輸出最終的答案:
147
  """,
148
  "analyze_reference": """
149
  從以下清單選擇最匹配的使用者問題分類,如果沒有匹配,返回 'none'。
 
150
  分類清單:
151
  {options}
 
152
  使用者問題:{query}
 
153
  請僅輸出分類名稱或 'none',不需任何額外的解釋或格式。
154
  """,
155
  "clarification": """
156
  請根據以下使用者問題,生成一個簡潔、禮貌的澄清性提問,以幫助我更精準地回答。問題應引導使用者提供更多細節,例如具體藥名、使用情境,並附上範例問題。請在回覆中明確告知使用者,目前僅支援以下藥物詢問:
157
+ Fentanyl patch
158
+ Spiriva Respimat
159
+ NITROSTAT
160
+ AUGMENTIN FOR SYRUP
161
+ Ozempic
162
+ NIFLEC
163
+ Fosamax
164
+ Humira
165
+ PREMARIN
166
+ SMECTA
 
167
  範例:
168
  使用者問題:這個藥會怎麼樣?
169
  澄清提問:您好,請問您指的藥物是下列哪一種?目前僅支援以下藥物詢問:Fentanyl patch、Spiriva Respimat...等。例如,您可以問:「Fentanyl patch 的副作用有哪些?」請確認藥名或提供更多細節。
 
170
  使用者問題:{query}
171
  """
172
  }
 
173
  # ---------- 日誌設定 ----------
174
  logging.basicConfig(
175
  level=logging.INFO,
 
183
  s = unicodedata.normalize("NFKC", s)
184
  return re.sub(r"[^\w\s]", "", s.lower()).strip()
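A quick illustration of the normalization above (the inputs are hypothetical examples): NFKC folds full-width characters, the text is lower-cased, and punctuation is stripped, which is what lets written variants of a drug name line up during matching.
# Sketch: same transform as _norm, applied to two made-up inputs.
import re, unicodedata

def norm_demo(s: str) -> str:
    s = unicodedata.normalize("NFKC", s)                  # fold full-width characters
    return re.sub(r"[^\w\s]", "", s.lower()).strip()      # lower-case, drop punctuation

assert norm_demo("Ｆｅｎｔａｎｙｌ　Ｐａｔｃｈ") == "fentanyl patch"
assert norm_demo("SPIRIVA-Respimat!") == "spirivarespimat"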
185
 
 
186
  @dataclass
187
  class FusedCandidate:
188
  idx: int
 
190
  sem_score: float
191
  bm_score: float
192
 
 
193
  @dataclass
194
  class RerankResult:
195
  idx: int
 
197
  text: str
198
  meta: Dict[str, Any] = field(default_factory=dict)
199
 
 
200
  @dataclass
201
  class ConversationState:
202
  query_history: List[str] = field(default_factory=list)
 
206
  last_answer: Optional[str] = None
207
  clarification_count: int = 0
208
 
 
209
  # ---------- 核心 RAG 邏輯 ----------
210
  class RagPipeline:
211
  def __init__(self):
 
271
  with open(BM25_PKL, "rb") as f:
272
  bm25_data = pickle.load(f)
273
  self.state.bm25 = bm25_data["bm25"]
274
+ if not isinstance(self.state.bm25, BM25Okapi):
275
+ raise ValueError("Loaded BM25 is not a BM25Okapi instance.")
276
 
277
  log.info("所有模型與資料載入完成。")
278
 
 
291
  for part in q_norm_parts:
292
  if part in self.drug_name_to_ids:
293
  drug_ids.update(self.drug_name_to_ids[part])
294
+
295
  for drug_name, ids in self.drug_name_to_ids.items():
296
  if drug_name in _norm(query):
297
  drug_ids.update(ids)
298
+
299
  return sorted(drug_ids)
300
 
301
  def _build_drug_name_to_ids(self) -> Dict[str, List[str]]:
 
314
  part = part.strip()
315
  if part and len(part) > 1:
316
  self.drug_name_to_ids.setdefault(part, []).append(drug_id)
317
+
318
  for alias, canonical_name in DRUG_NAME_MAPPING.items():
319
  if _norm(canonical_name) in _norm(row["drug_name_norm"]):
320
  self.drug_name_to_ids.setdefault(_norm(alias), []).append(drug_id)
321
+
322
  for key in self.drug_name_to_ids:
323
  self.drug_name_to_ids[key] = sorted(set(self.drug_name_to_ids[key]))
324
+
325
  return self.drug_name_to_ids
326
 
327
  def _load_drug_name_vocabulary(self):
 
334
  self.drug_vocab["zh"].add(word)
335
  else:
336
  self.drug_vocab["en"].add(word)
337
+
338
+ for alias in DRUG_NAME_MAPPING:
339
+ if re.search(r"[\u4e00-\u9fff]", alias):
340
+ self.drug_vocab["zh"].add(alias)
341
+ else:
342
+ self.drug_vocab["en"].add(alias)
343
+
344
+ for word in self.drug_vocab["zh"]:
345
+ try:
346
+ if word not in jieba.dt.FREQ:
347
+ jieba.add_word(word, freq=2_000_000)
348
+ except Exception:
349
+ pass
350
 
351
  @tenacity.retry(
352
  wait=tenacity.wait_fixed(2),
 
385
  conv_state.clarification_count += 1
386
  if conv_state.clarification_count > 3:
387
  return "抱歉,多次無法識別您的問題,請確認藥物名稱或聯繫醫師。\n" + DISCLAIMER, []
388
+
389
  clarification = self._generate_clarification_query(q_orig)
390
  conv_state.last_answer = clarification
391
  return f"{clarification}\n\n{DISCLAIMER}", []
 
401
  sections = [s.strip() for s in sections_str.split('、') if s.strip() and s != '藥袋上的醫囑']
402
  intents = REFERENCE_TO_INTENT.get(ref_key, [])
403
  context = self._build_context_from_csv(drug_ids, sections)
404
+
405
  # 根據參考資料判斷複雜度
406
  if any(sec in ["用法用量", "病人使用須知", "劑型相關"] for sec in sections):
407
  complexity = "complex" # 多步驟的裝置安裝或藥品使用
408
  elif any(sec in ["不良反應", "警語與注意事項"] for sec in sections):
409
  complexity = "simple" # 副作用問題
410
+ else:
411
+ return await self._fallback_rag(target_id, q_orig, drug_ids)
412
  else:
413
+ # If no direct reference mapping, use fallback RAG
414
+ return await self._fallback_rag(target_id, q_orig, drug_ids)
415
 
416
  conv_state.intents = intents
417
  conv_state.complexity = complexity
 
418
  max_tokens = LLM_MODEL_CONFIG["max_tokens_complex"] if complexity == "complex" else LLM_MODEL_CONFIG["max_tokens_simple"]
419
  prompt = self._make_final_prompt(q_orig, context, intents)
420
  answer = self._llm_call(
421
  [{"role": "user", "content": prompt}],
422
  max_tokens=max_tokens
423
  )
424
+
425
  if not answer:
426
  return f"無法回答您的問題。\n{DISCLAIMER}", drug_ids
427
 
428
  answer = answer.replace("*", "")
429
  conv_state.last_answer = answer
430
  final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
431
+
432
  log.info(f"查詢處理完成,耗時: {time.time() - start_time:.2f}秒")
433
  return final_answer, drug_ids
434
+
435
  except Exception as e:
436
  log.error(f"處理查詢時發生錯誤: {e}", exc_info=True)
437
  return f"處理時發生內部錯誤,請稍後再試。\n{DISCLAIMER}", []
 
442
  sub_queries = analysis.get("sub_queries", [q_orig])
443
  intents = analysis.get("intents", [])
444
  complexity = "simple" # 預設為簡單
445
+
446
  sections = []
447
  for intent in intents:
448
  sections.extend(INTENT_TO_SECTION.get(intent, []))
449
+
450
  if any(sec in ["用法用量", "病人使用須知", "劑型相關"] for sec in sections):
451
  complexity = "complex"
452
  elif any(sec in ["不良反應", "警語與注意事項"] for sec in sections):
453
  complexity = "simple"
454
+
455
  conv_state.intents = intents
456
  conv_state.complexity = complexity
457
 
 
460
  conv_state.clarification_count += 1
461
  if conv_state.clarification_count > 3:
462
  return "抱歉,多次無法識別您的問題,請確認藥物名稱或聯繫醫師。\n" + DISCLAIMER, drug_ids
463
+
464
  clarification = self._generate_clarification_query(q_orig)
465
  conv_state.last_answer = clarification
466
  return f"{clarification}\n\n{DISCLAIMER}", drug_ids
 
469
  drug_ids, sub_queries, intents
470
  )
471
  final_candidates = all_candidates[:TOP_K_SENTENCES]
 
472
  reranked_results = [
473
  RerankResult(
474
  idx=c.idx,
 
478
  )
479
  for c in final_candidates
480
  ]
481
+
482
  prioritized = self._prioritize_context(reranked_results, intents)
483
  context = self._build_context(prioritized)
484
 
 
491
  [{"role": "user", "content": prompt}],
492
  max_tokens=max_tokens
493
  )
494
+
495
  if not answer:
496
  return f"無法回答您的問題。\n{DISCLAIMER}", drug_ids
497
 
 
516
  for drug_id in drug_ids:
517
  drug_df = self.df_csv[self.df_csv['drug_id'] == drug_id]
518
  for sec in sections:
519
+ sec_rows = drug_df[drug_df['section'].str.contains(sec, na=False)]
520
+ for _, row in sec_rows.iterrows():
521
+ content = row['content']
522
  if len(context) + len(content) > LLM_MODEL_CONFIG["max_context_chars"]:
523
  return context.strip()
524
  context += content + "\n\n"
 
548
  return []
549
 
550
  all_fused_candidates: Dict[int, FusedCandidate] = {}
551
+
552
  for sub_q in sub_queries:
553
  expanded_q = self._expand_query_with_llm(sub_q, intents)
554
  q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
555
+
556
  if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
557
  faiss.normalize_L2(q_emb)
558
+
559
  distances, sem_indices = self.state.index.search(q_emb, PRE_RERANK_K)
560
 
561
  tokenized_query = list(jieba.cut(expanded_q))
562
  bm25_scores = self.state.bm25.get_scores(tokenized_query)
563
+
564
  rel_idx = np.fromiter(relevant_indices, dtype=np.int64)
565
  rel_scores = bm25_scores[rel_idx]
566
  top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
567
  doc_to_bm25_score: Dict[int, float] = {
568
  int(i): float(bm25_scores[i]) for i in top_rel
569
  }
570
+
571
  candidate_scores: Dict[int, Dict[str, float]] = {}
572
+
573
  def to_similarity(d: float) -> float:
574
  return float(d) if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT else 1.0 / (1.0 + float(d))
575
+
576
  for i, dist in zip(sem_indices[0], distances[0]):
577
  if i in relevant_indices:
578
  candidate_scores[i] = {"sem": to_similarity(dist), "bm": 0.0}
579
+
580
  for i, score in doc_to_bm25_score.items():
581
  if i in relevant_indices:
582
  candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
583
+
584
  if not candidate_scores:
585
  continue
586
+
587
  keys = list(candidate_scores.keys())
588
  sem_scores = np.array([candidate_scores[k]["sem"] for k in keys])
589
  bm_scores = np.array([candidate_scores[k]["bm"] for k in keys])
 
592
  return (x - x.min()) / (x.max() - x.min() + 1e-8) if x.max() - x.min() > 0 else np.zeros_like(x)
593
 
594
  sem_n, bm_n = norm(sem_scores), norm(bm_scores)
595
+
596
  for idx, k in enumerate(keys):
597
  fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
598
  if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
599
  all_fused_candidates[k] = FusedCandidate(
600
  idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
601
  )
602
+
603
  return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
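A small numeric sketch of the fusion step above, using made-up scores: for each sub-query, FAISS similarities and BM25 scores are min-max normalized and combined with the fixed 0.6/0.4 weights before the global sort.
# Illustration only; the three candidate scores below are invented.
import numpy as np

sem = np.array([0.82, 0.40, 0.10])   # hypothetical semantic similarities
bm = np.array([12.0, 3.0, 0.0])      # hypothetical BM25 scores

def minmax(x: np.ndarray) -> np.ndarray:
    # Same scaling as the norm() helper above; a constant vector maps to zeros.
    return (x - x.min()) / (x.max() - x.min() + 1e-8) if x.max() - x.min() > 0 else np.zeros_like(x)

fused = minmax(sem) * 0.6 + minmax(bm) * 0.4
# fused ≈ [1.0, 0.35, 0.0], so the first candidate ranks highest on both signals.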
604
 
605
  def _expand_query_with_llm(self, query: str, intents: List[str]) -> str:
 
614
  def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]:
615
  if not intents:
616
  return results
617
+
618
  prioritized_sections = set()
619
  for intent in intents:
620
  prioritized_sections.update(INTENT_TO_SECTION.get(intent, []))
621
+
622
  if not prioritized_sections:
623
  return results
624
+
625
  prioritized, other = [], []
626
  for res in results:
627
  if res.meta.get("section") in prioritized_sections:
 
656
  add_instr += "\n請根據以下問題與參考資料對應回答:"
657
  for q, refs in REFERENCE_MAPPING.items():
658
  add_instr += f"\n- {q}: {refs}"
659
+
660
  return PROMPT_TEMPLATES["final_answer"].format(
661
  additional_instruction=add_instr, context=context, query=query
662
  )
 
666
  return json.loads(s)
667
  except json.JSONDecodeError:
668
  try:
669
+ m = re.search(r"{.*?}", s, re.DOTALL)
670
  if m:
671
  return json.loads(m.group(0))
672
  except json.JSONDecodeError:
673
  pass
674
+ return default
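For context, a condensed sketch of what the JSON fallback above recovers when the model wraps its answer in extra prose (the reply string is hypothetical):
# json.loads on the whole reply fails, so the first {...} span is extracted.
# The non-greedy {.*?} stops at the first '}', which suffices because the
# expected object nests arrays ([]) but no inner braces.
import json, re

reply = '分析結果如下:{"sub_queries": ["如何使用"], "intents": ["操作 (Administration)"], "complexity": "simple"} 供參考。'
m = re.search(r"{.*?}", reply, re.DOTALL)
parsed = json.loads(m.group(0)) if m else {}
assert parsed["complexity"] == "simple"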
 
675
 
676
  # ---------- FastAPI 事件與路由 ----------
677
  class AppConfig:
678
  CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
679
  CHANNEL_SECRET = _require_env("CHANNEL_SECRET")
680
+ rag_pipeline: Optional[RagPipeline] = None
 
 
 
681
 
682
  @asynccontextmanager
683
  async def lifespan(app: FastAPI):
 
689
  yield
690
  log.info("服務關閉中。")
691
 
 
692
  app = FastAPI(lifespan=lifespan)
693
 
 
694
  @app.post("/webhook")
695
  async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
696
  signature = request.headers.get("X-Line-Signature")
 
698
  raise HTTPException(status_code=400, detail="Missing LINE X-Line-Signature header")
699
 
700
  body = await request.body()
701
+
702
  try:
703
  hash_obj = hmac.new(AppConfig.CHANNEL_SECRET.encode("utf-8"), body, hashlib.sha256)
704
  expected_signature = base64.b64encode(hash_obj.digest()).decode("utf-8")
 
715
  raise HTTPException(status_code=400, detail="Invalid JSON body")
716
 
717
  for event in data.get("events", []):
718
+ if event.get("type") == "message":
719
+ msg = event.get("message", {})
 
 
 
720
  source = event.get("source", {})
721
  stype = source.get("type")
722
  target_id = (
723
  source.get("userId") or source.get("groupId") or source.get("roomId")
724
  )
 
 
 
 
 
725
 
726
+ if msg.get("type") == "text" and target_id:
727
+ user_text = msg.get("text", "").strip()
728
+ if user_text:
729
+ background_tasks.add_task(
730
+ process_user_query, stype, target_id, user_text
731
+ )
732
+ return Response(status_code=status.HTTP_200_OK)
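A standalone sketch of the X-Line-Signature check this handler performs; the helper name and arguments are illustrative, not part of the codebase:
# LINE sends base64(HMAC-SHA256(channel_secret, raw_body)) in X-Line-Signature;
# the handler recomputes it over the raw request body before parsing the JSON.
import base64, hashlib, hmac

def signature_is_valid(channel_secret: str, raw_body: bytes, signature: str) -> bool:
    digest = hmac.new(channel_secret.encode("utf-8"), raw_body, hashlib.sha256).digest()
    expected = base64.b64encode(digest).decode("utf-8")
    return hmac.compare_digest(expected, signature)  # constant-time compare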
733
 
734
+ async def process_user_query(source_type: str, target_id: str, input_data: str):
735
  try:
736
  if not rag_pipeline:
737
+ line_push_generic(source_type, target_id,
738
  "系統正在啟動中,請稍後再試。")
739
  return
740
+
741
+ answer, drug_ids = await rag_pipeline.answer_question(target_id, input_data)
742
+ line_push_generic(source_type, target_id, answer)
743
+
744
  except Exception as e:
745
  log.error(f"背景處理 target_id={target_id} 發生錯誤: {e}", exc_info=True)
746
+ line_push_generic(
747
  source_type,
748
  target_id,
749
  f"抱歉,處理時發生未預期的錯誤。\n{DISCLAIMER}",
750
  )
751
 
 
752
  @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
753
+ def line_api_call(endpoint: str, data: Dict):
754
  headers = {
755
  "Content-Type": "application/json",
756
  "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}",
757
  }
758
+ response = requests.post(
759
+ f"https://api.line.me/v2/bot/message/{endpoint}",
760
+ headers=headers,
761
+ json=data,
762
+ timeout=10,
763
+ )
764
+ response.raise_for_status()
765
 
766
+ def line_push_generic(source_type: str, target_id: str, text: str):
767
  messages = [
768
  {"type": "text", "text": chunk}
769
  for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]
 
771
  if "目前僅支援以下藥物詢問" in text:
772
  drug_list = "\n".join(f"- {drug}" for drug in SUPPORTED_DRUGS)
773
  messages.append({"type": "text", "text": f"支援的藥物清單:\n{drug_list}"})
774
 
775
+ data = {"to": target_id, "messages": messages}
776
+ line_api_call("push", data)
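A brief sketch of how the chunking above behaves on a long answer (the input below is synthetic): text is wrapped into chunks of at most 4800 characters and only the first five are kept, presumably to stay within LINE's per-message text size and its cap of five message objects per push.
# Synthetic ~11,000-character answer; textwrap.wrap splits on whitespace.
import textwrap

long_answer = ("簡答:請依醫囑使用。 " * 1000).strip()
chunks = textwrap.wrap(long_answer, 4800, replace_whitespace=False)[:5]
messages = [{"type": "text", "text": c} for c in chunks]
assert all(len(c) <= 4800 for c in chunks) and len(messages) <= 5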
777
 
778
  # ---------- 執行 ----------
779
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -12,4 +12,5 @@ torch
12
  # LLM 呼叫相關
13
  openai
14
  tenacity
15
- requests
 
 
12
  # LLM 呼叫相關
13
  openai
14
  tenacity
15
+ requests
16
+ aiohttp