Commit 9cf2751 by Song · Parent: aa4568e
Commit message: hi
Browse files:
- app.py (+43 −43)
- requirements.txt (+6 −4)
app.py
CHANGED
@@ -30,6 +30,7 @@ from typing import List, Dict, Any, Optional, Tuple, Union
 from functools import lru_cache
 from dataclasses import dataclass, field
 from contextlib import asynccontextmanager
+import unicodedata  # [new]
 
 # ---------- Third-party libraries ----------
 import numpy as np
@@ -149,6 +150,12 @@ PROMPT_TEMPLATES = {
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 log = logging.getLogger(__name__)
 
+# [new] Unified string-normalization helper
+def _norm(s: str) -> str:
+    """Normalize a string: NFKC-normalize, lowercase, strip punctuation and whitespace."""
+    s = unicodedata.normalize("NFKC", s)
+    return re.sub(r"[^\w\s]", "", s.lower()).strip()
+
 @dataclass
 class FusedCandidate:
     idx: int
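For reference, a quick self-contained check of what the new _norm helper does (inputs are made up for illustration):

    import re
    import unicodedata

    def _norm(s: str) -> str:
        # same helper as in the diff above
        s = unicodedata.normalize("NFKC", s)
        return re.sub(r"[^\w\s]", "", s.lower()).strip()

    print(_norm("Ｐａｎａｄｏｌ®"))  # -> "panadol": NFKC folds full-width letters, the symbol is stripped
    print(_norm("普拿疼 500mg"))    # -> "普拿疼 500mg": CJK characters and digits count as \w, so they survive

Because Python's \w is Unicode-aware, Chinese drug names pass through unchanged; only punctuation and symbols are removed.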
@@ -170,6 +177,8 @@ class RagPipeline:
         if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
             raise ValueError("LLM API Key or Base URL is not configured.")
         self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
+        # [FIXED] Add the missing model_name attribute
+        self.model_name = LLM_API_CONFIG["model"]
         self.embedding_model = self._load_model(SentenceTransformer, EMBEDDING_MODEL, "embedding")
         self.reranker = self._load_model(CrossEncoder, RERANKER_MODEL, "reranker")
 
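The fix assumes LLM_API_CONFIG carries a "model" entry alongside the key and base URL. A minimal sketch of that config shape (the env-var names here are assumptions, not taken from the repo):

    import os

    # Hypothetical config shape implied by the diff; adjust names to the actual app.py.
    LLM_API_CONFIG = {
        "api_key": os.getenv("LLM_API_KEY", ""),
        "base_url": os.getenv("LLM_BASE_URL", ""),
        "model": os.getenv("LLM_MODEL", ""),  # read by self.model_name above
    }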
@@ -204,14 +213,13 @@ class RagPipeline:
             if col not in self.df_csv.columns:
                 raise KeyError(f"CSV file '{CSV_PATH}' is missing required column: {col}")
 
-
-
-        )
+        # [MODIFIED] Use the _norm helper for consistent normalization
+        self.df_csv['drug_name_norm_normalized'] = self.df_csv['drug_name_norm'].apply(_norm)
         self.drug_name_to_ids = self.df_csv.groupby('drug_name_norm_normalized')['drug_id'].unique().apply(list).to_dict()
-        # [MODIFIED]
+        # [MODIFIED] Register aliases as lookup keys too, using the unified normalization
         for alias, canonical in DRUG_NAME_MAPPING.items():
-            alias_key =
-            canonical_key =
+            alias_key = _norm(alias)
+            canonical_key = _norm(canonical)
             if canonical_key in self.drug_name_to_ids:
                 self.drug_name_to_ids[alias_key] = self.drug_name_to_ids[canonical_key]
         self._load_drug_name_vocabulary()
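To see what drug_name_to_ids ends up looking like, here is a toy run with made-up rows (pandas only; the _norm sketch from above is assumed in scope):

    import pandas as pd

    df = pd.DataFrame({
        "drug_name_norm": ["Ｐａｎａｄｏｌ", "panadol", "Aspirin"],
        "drug_id": ["D001", "D001", "D002"],
    })
    # Full-width and mixed-case spellings collapse onto one normalized key.
    df["drug_name_norm_normalized"] = df["drug_name_norm"].apply(_norm)
    drug_name_to_ids = df.groupby("drug_name_norm_normalized")["drug_id"].unique().apply(list).to_dict()
    print(drug_name_to_ids)  # {'aspirin': ['D002'], 'panadol': ['D001']}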
@@ -221,6 +229,10 @@ class RagPipeline:
         self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2)
         if hasattr(self.state.index, "nprobe"):
             self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
+        # [new] Check the FAISS metric type and note when it is IP
+        if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
+            log.info("FAISS index uses the inner-product (IP) metric; L2 normalization will be applied at retrieval time to achieve cosine similarity.")
+
         with open(SENTENCES_PKL, "rb") as f:
             data = pickle.load(f)
         self.state.sentences = data["sentences"]
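The log message refers to the standard cosine-via-inner-product trick. A minimal, self-contained sketch under an assumed embedding size (the data is random, purely for illustration):

    import faiss
    import numpy as np

    dim = 384                                   # assumed embedding size
    index = faiss.IndexFlatIP(dim)              # inner-product metric
    vecs = np.random.rand(1000, dim).astype("float32")
    faiss.normalize_L2(vecs)                    # unit-length rows: IP == cosine
    index.add(vecs)

    query = np.random.rand(1, dim).astype("float32")
    faiss.normalize_L2(query)
    scores, ids = index.search(query, 5)        # top-5 cosine neighbours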
@@ -256,7 +268,7 @@ class RagPipeline:
             else:
                 self.drug_vocab["en"].add(part)
         for alias in DRUG_NAME_MAPPING:
-            self.drug_vocab["en"].add(alias)
+            self.drug_vocab["en"].add(_norm(alias))  # [modified]
             if re.search(r'[\u4e00-\u9fff]', alias):
                 if alias not in jieba.dt.FREQ:
                     try:
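The surrounding block registers Chinese aliases with jieba so they tokenize as single words. A hedged, standalone sketch of that mechanism (the alias is invented):

    import jieba

    alias = "普拿疼"                      # hypothetical Chinese drug alias
    if alias not in jieba.dt.FREQ:        # not already in jieba's dictionary
        jieba.add_word(alias)             # keep the alias intact during segmentation
    print(list(jieba.cut("請問普拿疼怎麼吃")))  # e.g. ['請問', '普拿疼', '怎麼', '吃']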
@@ -273,6 +285,7 @@ class RagPipeline:
         )
     def _llm_call(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None, temperature: Optional[float] = None) -> str:
         """Safely call the LLM API and handle a possibly empty response."""
+        # [FIXED] Replace self.client with self.llm_client
         log.info(f"LLM call started. model: {self.model_name}, max_tokens: {max_tokens}, temperature: {temperature}")
 
         # [DEBUG] Log the full LLM prompt content for debugging
@@ -280,7 +293,8 @@ class RagPipeline:
 
         start_time = time.time()
         try:
-            response = self.client.chat.completions.create(
+            # [FIXED] Replace self.client with self.llm_client
+            response = self.llm_client.chat.completions.create(
                 model=self.model_name,
                 messages=messages,
                 max_tokens=max_tokens,
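For orientation, a minimal end-to-end version of this call pattern with the OpenAI SDK (key, URL, and model names are placeholders):

    from openai import OpenAI

    client = OpenAI(api_key="sk-...", base_url="https://example.com/v1")  # placeholders
    response = client.chat.completions.create(
        model="some-model",
        messages=[{"role": "user", "content": "hello"}],
        max_tokens=64,
    )
    print(response.choices[0].message.content)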
@@ -312,9 +326,7 @@ class RagPipeline:
         log.info(f"===== Processing new query: '{q_orig}' =====")
         try:
             drug_ids = self._find_drug_ids_from_name(q_orig)
-            if not drug_ids:
-                log.info("No drug ID found; cannot answer.")
-                return f"Sorry, that drug could not be found in the database. Please verify the drug name, or consult a physician or pharmacist directly. {DISCLAIMER}"
+            # [MODIFIED] Remove the early return when no drug ID is found, so the RAG flow continues and can handle generic questions with no explicit drug name.
             log.info(f"Step 1/5: found drug IDs: {drug_ids}, elapsed: {time.time() - start_time:.2f} s")
             step_start = time.time()
 
@@ -345,9 +357,10 @@ class RagPipeline:
             context = self._build_context(reranked_results)
             if not context:
                 log.info("Not enough context to answer the question.")
-                return f"
+                return f"Based on the information provided, your question cannot be answered. {DISCLAIMER}"
 
             prompt = self._make_final_prompt(q_orig, context, intents)
+            # [FIXED] Replace self.client with self.llm_client
             answer = self._llm_call([{"role": "user", "content": prompt}])
 
             final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
@@ -365,33 +378,15 @@ class RagPipeline:
 
     @lru_cache(maxsize=128)
     def _find_drug_ids_from_name(self, query: str) -> List[str]:
-
-
+        # [modified] Use the unified normalization helper
+        q_norm = _norm(query)
         drug_ids = set()
 
-        #
+        # Match the normalized drug-name keys directly against the normalized query string
         for k, ids in self.drug_name_to_ids.items():
-            if
-
-
-            else:
-                if re.search(rf"\b{re.escape(k)}\b", q):
-                    drug_ids.update(ids)
-
-        # Still keep the old candidate-word path (as reinforcement)
-        for alias in candidates:
-            # [MODIFIED] Match English drug names on word boundaries to avoid substring false positives
-            is_english = not re.search(r'[\u4e00-\u9fff]', alias)
-            for drug_name_norm, ids in self.drug_name_to_ids.items():
-                match = False
-                if is_english:
-                    if re.search(rf"\b{re.escape(alias)}\b", drug_name_norm):
-                        match = True
-                elif alias in drug_name_norm:
-                    match = True
-
-                if match:
-                    drug_ids.update(ids)
+            if k in q_norm:
+                drug_ids.update(ids)
+
         return list(drug_ids)
 
     def _analyze_query(self, query: str) -> Dict[str, Any]:
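The rewritten lookup is a plain substring test, so a key matches anywhere inside the normalized query. A toy run (dictionary contents invented), reusing the _norm sketch from earlier:

    drug_name_to_ids = {"panadol": ["D001"], "aspirin": ["D002"]}
    q_norm = _norm("請問 Ｐａｎａｄｏｌ 的副作用")   # -> "請問 panadol 的副作用"
    drug_ids = {i for k, ids in drug_name_to_ids.items() if k in q_norm for i in ids}
    print(drug_ids)  # {'D001'}

Note the tradeoff against the removed word-boundary path: very short keys can now match inside unrelated words, which the old word-boundary regexes guarded against.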
@@ -399,12 +394,18 @@ class RagPipeline:
             options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
             query=query
         )
+        # [FIXED] Replace self.client with self.llm_client
         response_str = self._llm_call([{"role": "user", "content": prompt}], temperature=0.1)
         return self._safe_json_parse(response_str, default={"sub_queries": [query], "intents": []})
 
     def _retrieve_candidates_for_all_queries(self, drug_ids: List[str], sub_queries: List[str], intents: List[str]) -> List[FusedCandidate]:
         drug_ids_set = set(map(str, drug_ids))
-        relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id", "")) in drug_ids_set}
+        # [MODIFIED] If drug_ids is empty, relevant_indices should include every index
+        if drug_ids_set:
+            relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id", "")) in drug_ids_set}
+        else:
+            relevant_indices = set(range(len(self.state.meta)))
+
         if not relevant_indices: return []
 
         all_fused_candidates: Dict[int, FusedCandidate] = {}
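A toy illustration of the new fallback (meta rows invented): with drug IDs the filter narrows the index set, and with none the whole corpus stays searchable.

    meta = [{"drug_id": "D001"}, {"drug_id": "D002"}, {"drug_id": "D001"}]

    def relevant(drug_ids):
        ids = set(map(str, drug_ids))
        if ids:
            return {i for i, m in enumerate(meta) if str(m.get("drug_id", "")) in ids}
        return set(range(len(meta)))

    print(relevant(["D001"]))  # {0, 2}
    print(relevant([]))        # {0, 1, 2}: generic questions search everything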
@@ -632,17 +633,16 @@ def line_push_generic(source_type: str, target_id: str, text: str):
     data = {"to": target_id, "messages": messages}
     line_api_call(endpoint, data)
 
-# [MODIFIED]
+# [MODIFIED] Improve the drug-name extraction regex and use the unified normalization helper
 def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
     candidates = set()
-
-
-
-    for word in words:
+    q_norm = _norm(query)  # [modified]
+
+    for word in re.findall(r"[a-z0-9]+", q_norm):  # [modified] allow digits
         if word in drug_vocab["en"]:
             candidates.add(word)
 
-    for token in jieba.cut(query):
+    for token in jieba.cut(q_norm):  # [modified]
         if token in drug_vocab["zh"]:
             candidates.add(token)
 
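A hypothetical call to the updated extractor (vocabulary contents invented; assumes _norm and the jieba setup sketched above):

    drug_vocab = {"en": {"panadol", "aspirin"}, "zh": {"普拿疼"}}
    print(extract_drug_candidates_from_query("普拿疼和 Panadol 500 可以一起吃嗎?", drug_vocab))
    # e.g. ['panadol', '普拿疼'] (set order varies); "500" is scanned but is not in the vocab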
requirements.txt
CHANGED
@@ -1,13 +1,15 @@
-
-pandas
+# Core packages
 fastapi
 uvicorn
+pandas
+numpy
 jieba
 rank-bm25
+faiss-cpu  # vector search engine
 sentence-transformers
-#
-faiss-cpu
+# Keep the torch version compatible with faiss
 torch
+# LLM client dependencies
 openai
 tenacity
 requests