Song committed b4be37b · "hi"
Parent(s): 84e38c0
app.py
CHANGED
@@ -124,7 +124,8 @@ class RagPipeline:
         self.embedding_model = self._load_embedding_model()
         self.reranker = self._load_reranker_model()
         self.csv_path = self._ensure_csv_path(CSV_PATH)
-        self.drug_name_to_ids = {}
+        self.drug_name_to_ids = {}
+        self.drug_vocab = {"zh": set(), "en": set()}

     def _load_embedding_model(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -161,7 +162,7 @@ class RagPipeline:
             raise FileNotFoundError(f"找不到 CSV 檔案於 {self.csv_path}")

         self.df_csv = pd.read_csv(self.csv_path, dtype=str).fillna('')
-        required_cols = {"drug_id", "
+        required_cols = {"drug_id", "drug_name_zh", "drug_name_en", "section"}
         missing_cols = required_cols - set(self.df_csv.columns)
         if missing_cols:
             raise ValueError(f"CSV 缺少必要欄位: {missing_cols}")
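Note: the tightened `required_cols` guard validates the CSV schema with plain set difference, so the `ValueError` names exactly the columns that are absent. A quick illustration (the CSV header here is hypothetical):

```python
required_cols = {"drug_id", "drug_name_zh", "drug_name_en", "section"}
csv_cols = {"drug_id", "drug_name_zh"}      # hypothetical CSV header
missing_cols = required_cols - csv_cols
# missing_cols == {"drug_name_en", "section"}, so the guard raises and names both
```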
@@ -172,6 +173,8 @@ class RagPipeline:
         self.drug_name_to_ids = self.df_csv.groupby('drug_name_norm_normalized')['drug_id'].unique().apply(list).to_dict()
         log.info(f"成功載入 CSV: {self.csv_path} (rows={len(self.df_csv)})")

+        self._load_drug_name_vocabulary()
+
         self.state.index, self.state.sentences, self.state.meta = self._load_or_build_sentence_index()
         self.state.bm25 = self._ensure_bm25_index()

@@ -187,6 +190,31 @@ class RagPipeline:

         log.info("所有模型與資料載入完成。")

+    def _load_drug_name_vocabulary(self):
+        """Load every Chinese and English drug name from the CSV and build the vocabulary."""
+        log.info("建立藥名詞庫...")
+        zh_names = self.df_csv['drug_name_zh'].dropna().unique()
+        en_names = self.df_csv['drug_name_en'].dropna().unique()
+
+        for name in zh_names:
+            # Strip punctuation and whitespace, keeping only CJK characters
+            clean_name = re.sub(r'[^\u4e00-\u9fff]', '', str(name)).strip()
+            if clean_name:
+                self.drug_vocab["zh"].add(clean_name)
+
+        for name in en_names:
+            clean_name = str(name).lower().replace(' ', '').strip()
+            if clean_name:
+                self.drug_vocab["en"].add(clean_name)
+
+        # Fold in the known aliases
+        for _, dataset_name in DRUG_NAME_MAPPING.items():
+            clean_name = dataset_name.lower().replace(' ', '').strip()
+            if clean_name:
+                self.drug_vocab["en"].add(clean_name)
+
+        log.info(f"藥名詞庫建立完成。中文詞彙數: {len(self.drug_vocab['zh'])}, 英文詞彙數: {len(self.drug_vocab['en'])}")
+
     def _load_or_build_sentence_index(self):
         if os.path.exists(FAISS_INDEX) and os.path.exists(SENTENCES_PKL):
             log.info("載入已存在的索引...")
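Note that the two normalizers are intentionally asymmetric: Chinese names keep only CJK characters (dropping dosage strings, punctuation, and Latin letters), while English names are lower-cased with internal spaces removed. A minimal sketch of the effect, with hypothetical inputs standing in for real `drug_name_zh` / `drug_name_en` values:

```python
import re

zh_raw = "普拿疼 500mg 膜衣錠"     # hypothetical Chinese name with dosage noise
en_raw = "Panadol Extra"           # hypothetical English name

zh_clean = re.sub(r'[^\u4e00-\u9fff]', '', zh_raw).strip()   # '普拿疼膜衣錠'
en_clean = en_raw.lower().replace(' ', '').strip()           # 'panadolextra'
```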
@@ -246,7 +274,6 @@ class RagPipeline:

         try:
             log.info("步驟 1/5: 辨識藥品名稱...")
-            # Fix: drop the unnecessary self.df_csv argument
             drug_ids = self._find_drug_ids_from_name(q_orig)
             if not drug_ids:
                 log.warning("未找到對應藥品,直接回覆。")
@@ -272,7 +299,7 @@ class RagPipeline:
             return f"找不到 drug_id {drug_ids} 對應的任何 chunks。{DISCLAIMER}"

         for sub_q in sub_queries:
-            expanded_q = self._expand_query_with_llm(sub_q, intents)
+            expanded_q = self._expand_query_with_llm(sub_q, tuple(intents))
             log.info(f"擴展後的查詢: '{expanded_q}'")

             weights = self._adjust_section_weights(intents)
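The `tuple(intents)` change is load-bearing: `_expand_query_with_llm` is decorated with `@lru_cache`, which hashes every argument to build the cache key, so passing a `list` fails before the function body ever runs. A minimal reproduction:

```python
from functools import lru_cache

@lru_cache(maxsize=128)
def expand(query: str, intents: tuple) -> str:
    return f"{query} | {', '.join(intents)}"

expand("飯後服用嗎", ("usage",))    # fine: tuples are hashable
expand("飯後服用嗎", ["usage"])     # TypeError: unhashable type: 'list'
```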
@@ -363,7 +390,7 @@ class RagPipeline:

     @lru_cache(maxsize=128)
     def _find_drug_ids_from_name(self, query: str) -> List[str]:
-        candidates = extract_drug_candidates_from_query(query)
+        candidates = extract_drug_candidates_from_query(query, self.drug_vocab)
         expanded = expand_aliases(candidates)

         drug_ids = set()
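One caveat worth knowing about `@lru_cache` here: on a cache hit it returns the *same* `list` object every time, so if any caller mutates the result of `_find_drug_ids_from_name`, the cached value is silently corrupted. A sketch of the failure mode (toy function, not the app's code):

```python
from functools import lru_cache

@lru_cache(maxsize=128)
def find_ids(query: str) -> list:
    return ["A123", "B456"]

ids = find_ids("panadol")
ids.append("oops")           # mutates the object held inside the cache
print(find_ids("panadol"))   # ['A123', 'B456', 'oops']; the mutation leaks
```

Returning a tuple, or copying at the call site, avoids this.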
@@ -374,9 +401,41 @@ class RagPipeline:
         except Exception as e:
             log.warning(f"Failed to match '{alias}': {e}. Skipping this alias.")
         return list(drug_ids)
+
+    def _analyze_query(self, query: str) -> Dict[str, Any]:
+        """Single LLM call that extracts both the sub-queries and the intent."""
+        options = "\n".join(f"- {c}" for c in INTENT_CATEGORIES)
+        prompt = f"""
+請分析以下使用者問題,並完成以下兩個任務:
+1. 將問題分解為1-3個子問題。
+2. 判斷問題的意圖,從清單中選擇最貼近的分類。
+
+請以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 'intent' (字串) 兩個鍵。
+範例: {{"sub_queries": ["子問題一", "子問題二"], "intent": "分類名稱"}}
+清單:
+{options}
+使用者問題:{query}
+"""
+        messages = [{"role": "user", "content": prompt}]
+        response = ""
+        try:
+            response = self._llm_call(messages, temperature=0.2)
+            result = json.loads(response)
+
+            sub_queries = result.get("sub_queries", [])
+            intent = result.get("intent", None)
+
+            if not sub_queries:
+                sub_queries = [query]
+
+            return {"sub_queries": sub_queries, "intents": [intent] if intent else []}

+        except Exception as e:
+            log.error(f"分析查詢時發生錯誤,LLM回覆: '{response}',錯誤: {e}", exc_info=True)
+            return {"sub_queries": [query], "intents": []}
+
     @lru_cache(maxsize=128)
-    def _expand_query_with_llm(self, query: str, intents:
+    def _expand_query_with_llm(self, query: str, intents: tuple) -> str:
         prompt = f"""請根據以下意圖:{list(intents)},擴展原始查詢,加入相關同義詞、相關術語和不同的說法。
 原始查詢:{query}
 請僅輸出擴展後的查詢,不需任何額外的解釋或格式。"""
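`_analyze_query` hands the raw completion straight to `json.loads` and relies on the `except` fallback when parsing fails. Models frequently wrap JSON in markdown fences, so a tolerant extraction step is a common hardening; a sketch (helper name hypothetical, not part of this commit):

```python
import json
import re

def parse_llm_json(response: str) -> dict:
    """Best effort: strip markdown fences, then parse the outermost {...} span."""
    text = re.sub(r'^```(?:json)?\s*|\s*```$', '', response.strip())
    start, end = text.find('{'), text.rfind('}')
    if start == -1 or end == -1:
        raise ValueError("no JSON object in LLM response")
    return json.loads(text[start:end + 1])

parse_llm_json('```json\n{"sub_queries": ["用量"], "intent": "dosage"}\n```')
```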
@@ -582,22 +641,28 @@ def line_reply(reply_token: str, text: str):
         log.error(f"LINE API 回覆失敗: {e}")

 # ---- Additional utility functions ----
-def extract_drug_candidates_from_query(query: str) -> list:
-    query = 
+def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
+    query = query.lower()
     candidates = set()
-
-
-
-
-
-
-
+
+    # Step 1: handle colons; only the part before the colon is scanned for drug names
+    drug_part = query.split(':', 1)[0]
+
+    # Step 2: handle Chinese and English separately, matching against the vocabulary
+    # English tokens
+    words = re.findall(r'[a-z0-9]+', drug_part)
+    for word in words:
+        if word in drug_vocab["en"]:
+            candidates.add(word)
+
+    # Chinese tokens
+    for token in jieba.cut(drug_part):
+        clean_token = re.sub(r'[^\u4e00-\u9fff]', '', token).strip()
+        if clean_token and clean_token in drug_vocab["zh"] and clean_token not in DRUG_STOPWORDS:
             candidates.add(clean_token)
-
-
-
-    candidates.add(dataset_name)
-    return [c for c in candidates if len(c) > 1]
+
+    return list(candidates)
+

 def expand_aliases(candidates: list) -> list:
     out = set()
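The rewritten extractor is vocabulary-gated: a token survives only if it already exists in `drug_vocab`, which is why the old `len(c) > 1` length filter could be dropped. A usage sketch with a toy vocabulary (run in app.py's module context, where `re`, `jieba`, and `DRUG_STOPWORDS` are available):

```python
vocab = {"zh": {"阿斯匹靈"}, "en": {"aspirin"}}

extract_drug_candidates_from_query("Aspirin: 可以空腹吃嗎?", vocab)
# -> ['aspirin']   (query is lower-cased; only the text before ':' is scanned)

extract_drug_candidates_from_query("頭痛怎麼辦", vocab)
# -> []            (no vocabulary hit, so no candidates)
```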