Song committed on
Commit 9cf2751 · 1 Parent(s): aa4568e
Files changed (2)
  1. app.py +43 -43
  2. requirements.txt +6 -4
app.py CHANGED
@@ -30,6 +30,7 @@ from typing import List, Dict, Any, Optional, Tuple, Union
 from functools import lru_cache
 from dataclasses import dataclass, field
 from contextlib import asynccontextmanager
+import unicodedata  # [ADDED]

 # ---------- Third-party libraries ----------
 import numpy as np
@@ -149,6 +150,12 @@ PROMPT_TEMPLATES = {
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 log = logging.getLogger(__name__)

+# [ADDED] Unified string normalization helper
+def _norm(s: str) -> str:
+    """Normalize a string: NFKC normalization, lowercase, remove punctuation, strip whitespace."""
+    s = unicodedata.normalize("NFKC", s)
+    return re.sub(r"[^\w\s]", "", s.lower()).strip()
+
 @dataclass
 class FusedCandidate:
     idx: int
@@ -170,6 +177,8 @@ class RagPipeline:
         if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
             raise ValueError("LLM API Key or Base URL is not configured.")
         self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
+        # [FIXED] Add the missing model_name attribute
+        self.model_name = LLM_API_CONFIG["model"]
         self.embedding_model = self._load_model(SentenceTransformer, EMBEDDING_MODEL, "embedding")
         self.reranker = self._load_model(CrossEncoder, RERANKER_MODEL, "reranker")

@@ -204,14 +213,13 @@ class RagPipeline:
             if col not in self.df_csv.columns:
                 raise KeyError(f"CSV file '{CSV_PATH}' is missing required column: {col}")

-        self.df_csv['drug_name_norm_normalized'] = (
-            self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
-        )
+        # [MODIFIED] Normalize consistently via the shared _norm helper
+        self.df_csv['drug_name_norm_normalized'] = self.df_csv['drug_name_norm'].apply(_norm)
         self.drug_name_to_ids = self.df_csv.groupby('drug_name_norm_normalized')['drug_id'].unique().apply(list).to_dict()
-        # [MODIFIED] Register aliases as lookup keys as well
+        # [MODIFIED] Register aliases as lookup keys as well, using the unified normalization
         for alias, canonical in DRUG_NAME_MAPPING.items():
-            alias_key = re.sub(r'[^\w\s]', '', alias.lower()).strip()
-            canonical_key = re.sub(r'[^\w\s]', '', canonical.lower()).strip()
+            alias_key = _norm(alias)
+            canonical_key = _norm(canonical)
             if canonical_key in self.drug_name_to_ids:
                 self.drug_name_to_ids[alias_key] = self.drug_name_to_ids[canonical_key]
         self._load_drug_name_vocabulary()
@@ -221,6 +229,10 @@ class RagPipeline:
         self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2)
         if hasattr(self.state.index, "nprobe"):
             self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
+        # [ADDED] Check the FAISS metric type and log a note when it is inner product (IP)
+        if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
+            log.info("FAISS index uses the inner-product (IP) metric; vectors are L2-normalized at retrieval time to obtain cosine similarity.")
+
         with open(SENTENCES_PKL, "rb") as f:
             data = pickle.load(f)
         self.state.sentences = data["sentences"]
@@ -256,7 +268,7 @@ class RagPipeline:
             else:
                 self.drug_vocab["en"].add(part)
         for alias in DRUG_NAME_MAPPING:
-            self.drug_vocab["en"].add(alias.lower())
+            self.drug_vocab["en"].add(_norm(alias))  # [MODIFIED]
             if re.search(r'[\u4e00-\u9fff]', alias):
                 if alias not in jieba.dt.FREQ:
                     try:
@@ -273,6 +285,7 @@ class RagPipeline:
     )
     def _llm_call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None, temperature: Optional[float] = None) -> str:
        """Safely call the LLM API and handle a possibly empty response."""
+        # [FIXED] Correct self.client to self.llm_client
        log.info(f"LLM call started. Model: {self.model_name}, max_tokens: {max_tokens}, temperature: {temperature}")

        # [DEBUG] Log the full LLM prompt content for debugging
@@ -280,7 +293,8 @@ class RagPipeline:

        start_time = time.time()
        try:
-            response = self.client.chat.completions.create(
+            # [FIXED] Correct self.client to self.llm_client
+            response = self.llm_client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=max_tokens,
@@ -312,9 +326,7 @@ class RagPipeline:
        log.info(f"===== Processing new query: '{q_orig}' =====")
        try:
            drug_ids = self._find_drug_ids_from_name(q_orig)
-            if not drug_ids:
-                log.info("No drug ID found; cannot answer.")
-                return f"Sorry, this drug could not be found in the database. Please check the drug name, or consult a physician/pharmacist. {DISCLAIMER}"
+            # [MODIFIED] Remove the early return when no drug ID is found, so the RAG flow continues and can handle general questions without an explicit drug name.
            log.info(f"Step 1/5: found drug IDs: {drug_ids}, elapsed: {time.time() - start_time:.2f} s")
            step_start = time.time()

@@ -345,9 +357,10 @@ class RagPipeline:
            context = self._build_context(reranked_results)
            if not context:
                log.info("Not enough context to answer the question.")
-                return f"No specific information related to your question was found. Please consult a physician or pharmacist for the most accurate information. {DISCLAIMER}"
+                return f"Your question cannot be answered from the provided material. {DISCLAIMER}"

            prompt = self._make_final_prompt(q_orig, context, intents)
+            # [FIXED] Correct self.client to self.llm_client
            answer = self._llm_call([{"role": "user", "content": prompt}])

            final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
@@ -365,33 +378,15 @@ class RagPipeline:

    @lru_cache(maxsize=128)
    def _find_drug_ids_from_name(self, query: str) -> List[str]:
-        q = query.lower()
-        candidates = extract_drug_candidates_from_query(q, self.drug_vocab)
+        # [MODIFIED] Use the unified normalization helper
+        q_norm = _norm(query)
        drug_ids = set()

-        # English: word-boundary match; Chinese: substring scan as well
+        # Match the normalized query string directly against the normalized drug-name keys
        for k, ids in self.drug_name_to_ids.items():
-            if re.search(r'[\u4e00-\u9fff]', k):
-                if k in q:
-                    drug_ids.update(ids)
-            else:
-                if re.search(rf"\b{re.escape(k)}\b", q):
-                    drug_ids.update(ids)
-
-        # Keep the old candidate-term path as a fallback
-        for alias in candidates:
-            # [MODIFIED] Use word boundaries for English drug names to avoid substring false positives
-            is_english = not re.search(r'[\u4e00-\u9fff]', alias)
-            for drug_name_norm, ids in self.drug_name_to_ids.items():
-                match = False
-                if is_english:
-                    if re.search(rf"\b{re.escape(alias)}\b", drug_name_norm):
-                        match = True
-                elif alias in drug_name_norm:
-                    match = True
-
-                if match:
-                    drug_ids.update(ids)
+            if k in q_norm:
+                drug_ids.update(ids)
+
        return list(drug_ids)

    def _analyze_query(self, query: str) -> Dict[str, Any]:
@@ -399,12 +394,18 @@ class RagPipeline:
            options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
            query=query
        )
+        # [FIXED] Correct self.client to self.llm_client
        response_str = self._llm_call([{"role": "user", "content": prompt}], temperature=0.1)
        return self._safe_json_parse(response_str, default={"sub_queries": [query], "intents": []})

    def _retrieve_candidates_for_all_queries(self, drug_ids: List[str], sub_queries: List[str], intents: List[str]) -> List[FusedCandidate]:
        drug_ids_set = set(map(str, drug_ids))
-        relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id", "")) in drug_ids_set}
+        # [MODIFIED] If drug_ids is empty, relevant_indices should cover all indices
+        if drug_ids_set:
+            relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id", "")) in drug_ids_set}
+        else:
+            relevant_indices = set(range(len(self.state.meta)))
+
        if not relevant_indices: return []

        all_fused_candidates: Dict[int, FusedCandidate] = {}
@@ -632,17 +633,16 @@ def line_push_generic(source_type: str, target_id: str, text: str):
    data = {"to": target_id, "messages": messages}
    line_api_call(endpoint, data)

-# [MODIFIED] Improve the drug-name extraction regex
+# [MODIFIED] Improve drug-name extraction and use the unified normalization helper
def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
    candidates = set()
-    q_lower = query.lower()
-    # Allow symbols such as -, /, . inside drug names
-    words = re.findall(r"[a-z0-9][a-z0-9+\-/\.]*", q_lower)
-    for word in words:
+    q_norm = _norm(query)  # [MODIFIED]
+
+    for word in re.findall(r"[a-z0-9]+", q_norm):  # [MODIFIED] plain alphanumeric tokens; punctuation is already removed by _norm
        if word in drug_vocab["en"]:
            candidates.add(word)

-    for token in jieba.cut(q_lower):
+    for token in jieba.cut(q_norm):  # [MODIFIED]
        if token in drug_vocab["zh"]:
            candidates.add(token)

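For context, the core behavioural change in app.py is that drug-name keys and incoming queries now pass through a single `_norm` helper (NFKC, lowercase, punctuation stripping), and `_find_drug_ids_from_name` does a plain substring match of each normalized key against the normalized query. Below is a minimal standalone sketch of that lookup; the drug names and IDs are hypothetical placeholders, not taken from the project's CSV.

```python
import re
import unicodedata

def _norm(s: str) -> str:
    """NFKC-normalize, lowercase, drop punctuation, strip whitespace (mirrors the commit's helper)."""
    s = unicodedata.normalize("NFKC", s)
    return re.sub(r"[^\w\s]", "", s.lower()).strip()

# Hypothetical normalized-name -> drug-id mapping, standing in for drug_name_to_ids.
drug_name_to_ids = {
    _norm("Panadol"): ["D001"],   # illustrative English brand name
    _norm("普拿疼"): ["D001"],     # illustrative Chinese alias
}

def find_drug_ids(query: str) -> list:
    """Substring-match every normalized key against the normalized query, as the new code does."""
    q_norm = _norm(query)
    ids = set()
    for key, key_ids in drug_name_to_ids.items():
        if key in q_norm:
            ids.update(key_ids)
    return list(ids)

print(find_drug_ids("請問 Panadol 500mg 怎麼吃?"))  # ['D001'] — punctuation and fullwidth characters are normalized away
print(find_drug_ids("普拿疼可以配咖啡嗎"))            # ['D001']
```

Note that substring matching is broader than the word-boundary matching it replaces, so very short normalized keys could in principle match unrelated queries.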
requirements.txt CHANGED
@@ -1,13 +1,15 @@
-numpy
-pandas
+# Core packages
 fastapi
 uvicorn
+pandas
+numpy
 jieba
 rank-bm25
+faiss-cpu  # vector search engine
 sentence-transformers
-# Pick one for your hardware: faiss-gpu if you have an NVIDIA GPU, otherwise faiss-cpu
-faiss-cpu
+# Keep the torch version compatible with faiss
 torch
+# LLM client dependencies
 openai
 tenacity
 requests
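Since faiss-cpu remains in the dependency list and app.py now logs a note when the index uses the inner-product metric, here is background on why that matters: over L2-normalized vectors, inner product equals cosine similarity. A small self-contained illustration with toy random vectors (not the project's index or embedding model):

```python
import faiss
import numpy as np

dim = 4
docs = np.random.rand(6, dim).astype("float32")   # toy document embeddings
query = np.random.rand(1, dim).astype("float32")  # toy query embedding

# L2-normalize both sides so that inner product == cosine similarity.
faiss.normalize_L2(docs)
faiss.normalize_L2(query)

index = faiss.IndexFlatIP(dim)   # index built with METRIC_INNER_PRODUCT
index.add(docs)

scores, idxs = index.search(query, 3)
print(scores, idxs)  # top-3 cosine similarities and their row indices
```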