Song committed on
Commit b4be37b · 1 Parent(s): 84e38c0
Files changed (1)
  1. app.py +85 -20
app.py CHANGED
@@ -124,7 +124,8 @@ class RagPipeline:
         self.embedding_model = self._load_embedding_model()
         self.reranker = self._load_reranker_model()
         self.csv_path = self._ensure_csv_path(CSV_PATH)
-        self.drug_name_to_ids = {}  # 預建 drug_name 到 drug_id 的映射
+        self.drug_name_to_ids = {}
+        self.drug_vocab = {"zh": set(), "en": set()}
 
     def _load_embedding_model(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -161,7 +162,7 @@ class RagPipeline:
             raise FileNotFoundError(f"找不到 CSV 檔案於 {self.csv_path}")
 
         self.df_csv = pd.read_csv(self.csv_path, dtype=str).fillna('')
-        required_cols = {"drug_id", "drug_name_norm", "section"}
+        required_cols = {"drug_id", "drug_name_zh", "drug_name_en", "section"}
         missing_cols = required_cols - set(self.df_csv.columns)
         if missing_cols:
             raise ValueError(f"CSV 缺少必要欄位: {missing_cols}")
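The stricter column check now expects separate Chinese and English name columns. A frame with just the validated columns passes it; the values below are illustrative only, and the real CSV has additional columns:

import pandas as pd

# Illustrative only: the minimum columns the new check validates.
df = pd.DataFrame([{
    "drug_id": "A012345100",        # made-up id
    "drug_name_zh": "保栓通膜衣錠",   # made-up Chinese name
    "drug_name_en": "Plavix",        # made-up English name
    "section": "用法用量",
}])
required_cols = {"drug_id", "drug_name_zh", "drug_name_en", "section"}
assert not (required_cols - set(df.columns))  # no missing columns, so the check passes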
@@ -172,6 +173,8 @@ class RagPipeline:
         self.drug_name_to_ids = self.df_csv.groupby('drug_name_norm_normalized')['drug_id'].unique().apply(list).to_dict()
         log.info(f"成功載入 CSV: {self.csv_path} (rows={len(self.df_csv)})")
 
+        self._load_drug_name_vocabulary()
+
         self.state.index, self.state.sentences, self.state.meta = self._load_or_build_sentence_index()
         self.state.bm25 = self._ensure_bm25_index()
 
@@ -187,6 +190,31 @@ class RagPipeline:
 
         log.info("所有模型與資料載入完成。")
 
+    def _load_drug_name_vocabulary(self):
+        """從 CSV 載入所有中英文藥名,建立詞庫"""
+        log.info("建立藥名詞庫...")
+        zh_names = self.df_csv['drug_name_zh'].dropna().unique()
+        en_names = self.df_csv['drug_name_en'].dropna().unique()
+
+        for name in zh_names:
+            # 去除標點符號和空格
+            clean_name = re.sub(r'[^\u4e00-\u9fff]', '', str(name)).strip()
+            if clean_name:
+                self.drug_vocab["zh"].add(clean_name)
+
+        for name in en_names:
+            clean_name = str(name).lower().replace(' ', '').strip()
+            if clean_name:
+                self.drug_vocab["en"].add(clean_name)
+
+        # 加入別名
+        for _, dataset_name in DRUG_NAME_MAPPING.items():
+            clean_name = dataset_name.lower().replace(' ', '').strip()
+            if clean_name:
+                self.drug_vocab["en"].add(clean_name)
+
+        log.info(f"藥名詞庫建立完成。中文詞彙數: {len(self.drug_vocab['zh'])}, 英文詞彙數: {len(self.drug_vocab['en'])}")
+
     def _load_or_build_sentence_index(self):
         if os.path.exists(FAISS_INDEX) and os.path.exists(SENTENCES_PKL):
             log.info("載入已存在的索引...")
@@ -246,7 +274,6 @@ class RagPipeline:
 
         try:
             log.info("步驟 1/5: 辨識藥品名稱...")
-            # 修正:移除不必要的 self.df_csv 參數
             drug_ids = self._find_drug_ids_from_name(q_orig)
             if not drug_ids:
                 log.warning("未找到對應藥品,直接回覆。")
@@ -272,7 +299,7 @@ class RagPipeline:
             return f"找不到 drug_id {drug_ids} 對應的任何 chunks。{DISCLAIMER}"
 
         for sub_q in sub_queries:
-            expanded_q = self._expand_query_with_llm(sub_q, intents)
+            expanded_q = self._expand_query_with_llm(sub_q, tuple(intents))
             log.info(f"擴展後的查詢: '{expanded_q}'")
 
             weights = self._adjust_section_weights(intents)
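The call site now converts intents to a tuple before passing it: _expand_query_with_llm (later in this diff) is wrapped in @lru_cache, and lru_cache arguments must be hashable, which a list is not. A minimal sketch of the failure this conversion avoids, not the app's code:

from functools import lru_cache

@lru_cache(maxsize=128)
def expand(query: str, intents: tuple) -> str:
    return f"{query} | {list(intents)}"

expand("保栓通怎麼吃", ("用法用量",))      # hashable tuple: cached normally
try:
    expand("保栓通怎麼吃", ["用法用量"])   # a list is unhashable
except TypeError as e:
    print(e)                               # "unhashable type: 'list'"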
@@ -363,7 +390,7 @@ class RagPipeline:
 
     @lru_cache(maxsize=128)
     def _find_drug_ids_from_name(self, query: str) -> List[str]:
-        candidates = extract_drug_candidates_from_query(query)
+        candidates = extract_drug_candidates_from_query(query, self.drug_vocab)
         expanded = expand_aliases(candidates)
 
         drug_ids = set()
@@ -374,9 +401,41 @@ class RagPipeline:
             except Exception as e:
                 log.warning(f"Failed to match '{alias}': {e}. Skipping this alias.")
         return list(drug_ids)
+
+    def _analyze_query(self, query: str) -> Dict[str, Any]:
+        """一次性呼叫 LLM,同時獲取子問題和意圖。"""
+        options = "\n".join(f"- {c}" for c in INTENT_CATEGORIES)
+        prompt = f"""
+        請分析以下使用者問題,並完成以下兩個任務:
+        1. 將問題分解為1-3個子問題。
+        2. 判斷問題的意圖,從清單中選擇最貼近的分類。
+
+        請以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 'intent' (字串) 兩個鍵。
+        範例: {{"sub_queries": ["子問題一", "子問題二"], "intent": "分類名稱"}}
+        清單:
+        {options}
+        使用者問題:{query}
+        """
+        messages = [{"role": "user", "content": prompt}]
+        response = ""
+        try:
+            response = self._llm_call(messages, temperature=0.2)
+            result = json.loads(response)
+
+            sub_queries = result.get("sub_queries", [])
+            intent = result.get("intent", None)
+
+            if not sub_queries:
+                sub_queries = [query]
+
+            return {"sub_queries": sub_queries, "intents": [intent] if intent else []}
 
+        except Exception as e:
+            log.error(f"分析查詢時發生錯誤,LLM回覆: '{response}',錯誤: {e}", exc_info=True)
+            return {"sub_queries": [query], "intents": []}
+
     @lru_cache(maxsize=128)
-    def _expand_query_with_llm(self, query: str, intents: Tuple[str]) -> str:
+    def _expand_query_with_llm(self, query: str, intents: tuple) -> str:
         prompt = f"""請根據以下意圖:{list(intents)},擴展原始查詢,加入相關同義詞、相關術語和不同的說法。
         原始查詢:{query}
         請僅輸出擴展後的查詢,不需任何額外的解釋或格式。"""
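The new _analyze_query expects the LLM to answer with the JSON shape given in its own prompt; any reply that fails json.loads falls back to the original query with no intents. A sketch of the happy path, with a made-up reply string:

import json

reply = '{"sub_queries": ["保栓通的用法", "保栓通的副作用"], "intent": "用法用量"}'  # made-up LLM reply
result = json.loads(reply)
sub_queries = result.get("sub_queries", []) or ["<原始問題>"]   # an empty list falls back to the raw query
intent = result.get("intent", None)
print({"sub_queries": sub_queries, "intents": [intent] if intent else []})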
@@ -582,22 +641,28 @@ def line_reply(reply_token: str, text: str):
         log.error(f"LINE API 回覆失敗: {e}")
 
 # ---- 額外工具函式 ----
-def extract_drug_candidates_from_query(query: str) -> list:
-    query = re.sub(r"[A-Za-z]+", lambda m: m.group(0).lower(), query)
+def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
+    query = query.lower()
     candidates = set()
-    parts = query.split(":", 1)
-    drug_part = parts[0]
-    for m in re.finditer(r"[a-zA-Z]{3,}", drug_part):
-        candidates.add(m.group(0))
-    for token in re.split(r"[\s,/()()]+", drug_part):
-        clean_token = re.sub(r'[a-zA-Z0-9\s]+', '', token).strip()
-        if clean_token and clean_token.lower() not in DRUG_STOPWORDS:
+
+    # 步驟 1: 處理冒號,僅對冒號前的部分進行藥名提取
+    drug_part = query.split(':', 1)[0]
+
+    # 步驟 2: 獨立處理中英文,並與詞庫比對
+    # 處理英文
+    words = re.findall(r'[a-z0-9]+', drug_part)
+    for word in words:
+        if word in drug_vocab["en"]:
+            candidates.add(word)
+
+    # 處理中文
+    for token in jieba.cut(drug_part):
+        clean_token = re.sub(r'[^\u4e00-\u9fff]', '', token).strip()
+        if clean_token and clean_token in drug_vocab["zh"] and clean_token not in DRUG_STOPWORDS:
             candidates.add(clean_token)
-
-    for query_name, dataset_name in DRUG_NAME_MAPPING.items():
-        if query_name in query.lower():
-            candidates.add(dataset_name)
-    return [c for c in candidates if len(c) > 1]
+
+    return list(candidates)
+
 
 def expand_aliases(candidates: list) -> list:
     out = set()
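Putting the pieces together, the rewritten extractor only returns tokens that already exist in the CSV-derived vocabulary. A usage sketch; it assumes app.py is importable under that name, and the vocabulary below is a stand-in for what _load_drug_name_vocabulary builds:

from app import extract_drug_candidates_from_query, expand_aliases  # assumed import path

drug_vocab = {"zh": {"普拿疼"}, "en": {"metformin", "panadol"}}      # stand-in vocabulary
cands = extract_drug_candidates_from_query("Metformin 飯後吃還是飯前吃?", drug_vocab)
print(cands)                  # ['metformin'], the English hit from the vocabulary
print(expand_aliases(cands))  # plus whatever aliases expand_aliases knows about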
 