Spaces:
Running
Running
Song
commited on
Commit
·
7b2e5cd
1
Parent(s):
8e9c857
0906
Browse files
app.py
CHANGED
@@ -1,21 +1,16 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
# -*- coding: utf-8 -*-
|
|
|
3 |
"""
|
4 |
DrugQA (ZH) — 優化版 FastAPI LINE Webhook (最終版)
|
5 |
整合 RAG 邏輯,包含 LLM 意圖偵測、子查詢分解、Intent-aware 檢索與 Rerank。
|
6 |
-
|
|
|
7 |
"""
|
8 |
|
9 |
-
# ---------- 環境與快取設定
|
10 |
import os
|
11 |
import pathlib
|
12 |
-
os.environ.setdefault("HF_HOME", "/tmp/hf")
|
13 |
-
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers")
|
14 |
-
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache")
|
15 |
-
for d in (os.getenv("HF_HOME"), os.getenv("SENTENCE_TRANSFORMERS_HOME"), os.getenv("XDG_CACHE_HOME")):
|
16 |
-
pathlib.Path(d).mkdir(parents=True, exist_ok=True)
|
17 |
-
|
18 |
-
# ---------- Python 標準函式庫 ----------
|
19 |
import re
|
20 |
import hmac
|
21 |
import base64
|
@@ -31,85 +26,131 @@ from functools import lru_cache
|
|
31 |
from dataclasses import dataclass, field
|
32 |
from contextlib import asynccontextmanager
|
33 |
import unicodedata
|
|
|
|
|
|
|
34 |
|
35 |
-
#
|
36 |
import numpy as np
|
37 |
import pandas as pd
|
38 |
-
from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
|
39 |
-
import uvicorn
|
40 |
import jieba
|
41 |
from rank_bm25 import BM25Okapi
|
42 |
-
from sentence_transformers import SentenceTransformer
|
43 |
import faiss
|
44 |
import torch
|
45 |
from openai import OpenAI
|
46 |
from tenacity import retry, stop_after_attempt, wait_fixed
|
47 |
import requests
|
|
|
|
|
48 |
|
49 |
-
#
|
50 |
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1")))
|
51 |
|
52 |
-
#
|
53 |
-
# [MODIFIED] 新增環境變數健檢函式
|
54 |
def _require_env(var: str) -> str:
|
|
|
55 |
v = os.getenv(var)
|
56 |
if not v:
|
57 |
raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
|
58 |
return v
|
59 |
|
60 |
-
|
61 |
def _require_llm_config():
|
62 |
for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
|
63 |
_require_env(k)
|
64 |
|
|
|
|
|
65 |
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
|
66 |
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
|
67 |
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
|
68 |
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
|
69 |
|
70 |
-
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES",
|
71 |
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
|
72 |
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
|
73 |
|
74 |
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
|
75 |
-
RERANKER_MODEL = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
|
76 |
|
77 |
LLM_API_CONFIG = {
|
78 |
-
"base_url":
|
79 |
-
"api_key":
|
80 |
-
"model":
|
81 |
}
|
82 |
|
83 |
LLM_MODEL_CONFIG = {
|
84 |
"max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 10000)),
|
85 |
-
"
|
|
|
86 |
"temperature": float(os.getenv("TEMPERATURE", 0.0)),
|
87 |
}
|
88 |
|
89 |
INTENT_CATEGORIES = [
|
90 |
-
"操作 (Administration)",
|
91 |
-
"
|
92 |
-
"
|
|
|
|
|
|
|
|
|
93 |
]
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
DRUG_NAME_MAPPING = {
|
96 |
-
"fentanyl patch": "fentanyl",
|
97 |
-
"
|
98 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
}
|
100 |
-
DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務務必諮詢您的醫師或藥師。"
|
101 |
|
102 |
PROMPT_TEMPLATES = {
|
103 |
"analyze_query": """
|
104 |
-
|
105 |
-
1. 將問題分解為1-3
|
106 |
2. 從清單中選擇所有相關的意圖分類。
|
|
|
107 |
|
108 |
-
請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 '
|
109 |
-
範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"]}}
|
110 |
|
111 |
意圖分類清單:
|
112 |
-
{options}
|
113 |
|
114 |
使用者問題:{query}
|
115 |
""",
|
@@ -119,20 +160,13 @@ PROMPT_TEMPLATES = {
|
|
119 |
請僅輸出擴展後的查詢,不需任何額外的解釋或格式。
|
120 |
""",
|
121 |
"final_answer": """
|
122 |
-
|
123 |
-
|
124 |
-
規則:
|
125 |
-
|
126 |
-
所有回答內容必須嚴格依據提供的參考資料,禁止任何形式的捏造或引用外部資訊。
|
127 |
-
若資料不足以回答,請回覆:「根據提供的資料,無法回答您的問題。」
|
128 |
-
針對原始查詢,以專業、友善的口吻,提供簡潔但資訊完整的中文繁體回答。
|
129 |
-
回答字數限制在120字以內。
|
130 |
-
|
131 |
-
排版格式:
|
132 |
-
|
133 |
-
使用條列式分行呈現,排版需適合LINE對話框顯示。
|
134 |
-
回覆結尾必須加上指定提醒語句:「如有不適請立即就醫。」
|
135 |
|
|
|
|
|
|
|
|
|
|
|
136 |
{additional_instruction}
|
137 |
|
138 |
---
|
@@ -143,19 +177,52 @@ PROMPT_TEMPLATES = {
|
|
143 |
使用者問題:{query}
|
144 |
|
145 |
請直接輸出最終的答案:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
"""
|
147 |
}
|
148 |
|
149 |
# ---------- 日誌設定 ----------
|
150 |
-
logging.basicConfig(
|
|
|
|
|
|
|
151 |
log = logging.getLogger(__name__)
|
152 |
|
153 |
-
#
|
154 |
def _norm(s: str) -> str:
|
155 |
-
"""統一化字串:NFKC
|
156 |
s = unicodedata.normalize("NFKC", s)
|
157 |
return re.sub(r"[^\w\s]", "", s.lower()).strip()
|
158 |
|
|
|
159 |
@dataclass
|
160 |
class FusedCandidate:
|
161 |
idx: int
|
@@ -163,6 +230,7 @@ class FusedCandidate:
|
|
163 |
sem_score: float
|
164 |
bm_score: float
|
165 |
|
|
|
166 |
@dataclass
|
167 |
class RerankResult:
|
168 |
idx: int
|
@@ -170,406 +238,459 @@ class RerankResult:
|
|
170 |
text: str
|
171 |
meta: Dict[str, Any] = field(default_factory=dict)
|
172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
# ---------- 核心 RAG 邏輯 ----------
|
174 |
class RagPipeline:
|
175 |
def __init__(self):
|
176 |
-
# [MODIFIED] 不再傳入 AppConfig,直接引用
|
177 |
if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
|
178 |
-
|
179 |
-
self.llm_client = OpenAI(
|
180 |
-
|
|
|
|
|
181 |
self.model_name = LLM_API_CONFIG["model"]
|
182 |
-
self.embedding_model = self._load_model(
|
183 |
-
|
184 |
-
|
185 |
self.drug_name_to_ids: Dict[str, List[str]] = {}
|
186 |
self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
|
187 |
-
self.state = type(
|
|
|
188 |
|
189 |
def _load_model(self, model_class, model_name: str, model_type: str):
|
190 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
191 |
-
log.info(f"載入 {model_type} 模型:{model_name}
|
192 |
try:
|
193 |
return model_class(model_name, device=device)
|
194 |
except Exception as e:
|
195 |
-
log.warning(f"
|
196 |
try:
|
197 |
return model_class(model_name, device="cpu")
|
198 |
except Exception as e_cpu:
|
199 |
-
log.error(f"切換至 CPU 仍無法載入模型: {model_name}
|
200 |
raise RuntimeError(f"模型載入失敗: {model_name}")
|
201 |
|
202 |
def load_data(self):
|
203 |
log.info("開始載入資料與模型...")
|
204 |
-
# [MODIFIED] 增加檔案存在性檢查
|
205 |
for path in [CSV_PATH, FAISS_INDEX, SENTENCES_PKL, BM25_PKL]:
|
206 |
if not pathlib.Path(path).exists():
|
207 |
raise FileNotFoundError(f"必要的資料檔案不存在: {path}")
|
208 |
|
209 |
try:
|
210 |
-
self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna(
|
211 |
-
|
212 |
-
for col in
|
213 |
if col not in self.df_csv.columns:
|
214 |
raise KeyError(f"CSV 檔案 '{CSV_PATH}' 中缺少必要欄位: {col}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
-
# [MODIFIED] 新增更強大的藥名詞典建立邏輯
|
217 |
-
self.drug_name_to_ids = self._build_drug_name_to_ids()
|
218 |
-
self._load_drug_name_vocabulary()
|
219 |
-
|
220 |
-
log.info("載入 FAISS 索引與句子資料...")
|
221 |
-
self.state.index = faiss.read_index(FAISS_INDEX)
|
222 |
-
self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2)
|
223 |
-
if hasattr(self.state.index, "nprobe"):
|
224 |
-
self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
|
225 |
-
# [新增] 檢查 FAISS 指標類型,若為 IP 則提示
|
226 |
-
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
|
227 |
-
log.info("FAISS 索引使用內積 (IP) 指標,檢索時將自動進行 L2 正規化以實現餘弦相似度。")
|
228 |
-
|
229 |
-
with open(SENTENCES_PKL, "rb") as f:
|
230 |
-
data = pickle.load(f)
|
231 |
-
self.state.sentences = data["sentences"]
|
232 |
-
self.state.meta = data["meta"]
|
233 |
-
|
234 |
-
log.info("載入 BM25 索引...")
|
235 |
-
with open(BM25_PKL, "rb") as f:
|
236 |
-
# 載入整個字典,然後取 'bm25' 這個鍵
|
237 |
-
bm25_data = pickle.load(f)
|
238 |
-
self.state.bm25 = bm25_data["bm25"]
|
239 |
-
if not isinstance(self.state.bm25, BM25Okapi):
|
240 |
-
raise ValueError("Loaded BM25 is not a BM25Okapi instance.")
|
241 |
-
|
242 |
-
except (FileNotFoundError, KeyError) as e:
|
243 |
-
log.exception(f"資料或索引檔案載入失敗: {e}")
|
244 |
-
raise RuntimeError(f"資料初始化失敗,請檢查檔案路徑與內容: {e}")
|
245 |
-
|
246 |
log.info("所有模型與資料載入完成。")
|
247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
def _find_drug_ids_from_name(self, query: str) -> List[str]:
|
249 |
-
|
250 |
-
|
|
|
251 |
drug_ids = set()
|
252 |
-
|
253 |
for part in q_norm_parts:
|
254 |
if part in self.drug_name_to_ids:
|
255 |
drug_ids.update(self.drug_name_to_ids[part])
|
256 |
-
|
257 |
-
|
|
|
|
|
258 |
|
259 |
def _build_drug_name_to_ids(self) -> Dict[str, List[str]]:
|
260 |
-
|
261 |
for _, row in self.df_csv.iterrows():
|
262 |
-
drug_id = row[
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
all_parts = set(zh_parts + en_parts + norm_parts)
|
271 |
for part in all_parts:
|
272 |
part = part.strip()
|
273 |
if part and len(part) > 1:
|
274 |
-
|
275 |
-
|
276 |
-
# 將 DRUG_NAME_MAPPING 中的別名也加入
|
277 |
for alias, canonical_name in DRUG_NAME_MAPPING.items():
|
278 |
-
if _norm(canonical_name) in _norm(row[
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
return mapping
|
285 |
|
286 |
def _load_drug_name_vocabulary(self):
|
287 |
log.info("建立藥名詞庫...")
|
288 |
for _, row in self.df_csv.iterrows():
|
289 |
-
norm_name = row[
|
290 |
-
words =
|
291 |
for word in words:
|
292 |
-
if re.search(r
|
293 |
self.drug_vocab["zh"].add(word)
|
294 |
else:
|
295 |
self.drug_vocab["en"].add(word)
|
296 |
-
|
297 |
for alias in DRUG_NAME_MAPPING:
|
298 |
-
if re.search(r
|
299 |
self.drug_vocab["zh"].add(alias)
|
300 |
else:
|
301 |
self.drug_vocab["en"].add(alias)
|
302 |
-
|
303 |
for word in self.drug_vocab["zh"]:
|
304 |
try:
|
305 |
if word not in jieba.dt.FREQ:
|
306 |
jieba.add_word(word, freq=2_000_000)
|
307 |
except Exception:
|
308 |
pass
|
309 |
-
|
310 |
@tenacity.retry(
|
311 |
wait=tenacity.wait_fixed(2),
|
312 |
-
stop=tenacity.stop_after_attempt(
|
313 |
-
retry=tenacity.retry_if_exception_type(ValueError),
|
314 |
-
before_sleep=tenacity.before_sleep_log(log, logging.WARNING),
|
315 |
-
after=tenacity.after_log(log, logging.INFO)
|
316 |
)
|
317 |
-
def _llm_call(
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
try:
|
325 |
response = self.llm_client.chat.completions.create(
|
326 |
model=self.model_name,
|
327 |
messages=messages,
|
328 |
-
max_tokens=max_tokens,
|
329 |
-
temperature=temperature,
|
330 |
)
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
log.info(f"LLM 收到完整回應: {response.model_dump_json(indent=2)}")
|
335 |
-
|
336 |
-
# --- 修正處:當回傳內容為空時,直接回傳空字串,而非拋出 ValueError ---
|
337 |
-
if not response.choices or not response.choices[0].message.content:
|
338 |
-
log.warning("LLM 呼叫成功 (200 OK),但回傳內容為空。將回傳空字串。")
|
339 |
-
return ""
|
340 |
-
# --- 修正結束 ---
|
341 |
-
|
342 |
-
content = response.choices[0].message.content
|
343 |
-
log.info(f"LLM 呼叫完成,耗時: {end_time - start_time:.2f} 秒。內容長度: {len(content)} 字。")
|
344 |
-
return content
|
345 |
-
|
346 |
except Exception as e:
|
347 |
-
log.error(f"LLM API
|
348 |
raise
|
349 |
|
350 |
-
def answer_question(self, q_orig: str) -> str:
|
351 |
start_time = time.time()
|
352 |
-
log.info(f"=====
|
|
|
|
|
|
|
353 |
try:
|
354 |
drug_ids = self._find_drug_ids_from_name(q_orig)
|
355 |
-
|
356 |
if not drug_ids:
|
357 |
-
log.info("
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
]
|
|
|
380 |
else:
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
# [新增] 根據意圖,將內容進行排序優化
|
388 |
-
prioritized_results = self._prioritize_context(reranked_results, intents)
|
389 |
-
|
390 |
-
context = self._build_context(prioritized_results)
|
391 |
-
if not context:
|
392 |
-
log.info("沒有足夠的上下文來回答問題。")
|
393 |
-
return f"根據提供的資料,無法回答您的問題。{DISCLAIMER}"
|
394 |
-
|
395 |
prompt = self._make_final_prompt(q_orig, context, intents)
|
396 |
-
answer = self._llm_call(
|
397 |
-
|
398 |
-
|
|
|
399 |
if not answer:
|
400 |
-
|
401 |
-
return f"根據提供的資料,無法回答您的問題。{DISCLAIMER}"
|
402 |
-
# --- 處理結束 ---
|
403 |
-
|
404 |
-
final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
|
405 |
-
log.info(f"步驟 5/5: 答案生成完成。答案長度: {len(answer.strip())} 字。耗時: {time.time() - step_start:.2f} 秒")
|
406 |
-
log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====")
|
407 |
-
return final_answer
|
408 |
|
|
|
|
|
|
|
|
|
|
|
409 |
except Exception as e:
|
410 |
-
log.error(f"
|
411 |
-
return f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
|
413 |
-
|
414 |
-
|
415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
416 |
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
421 |
)
|
422 |
-
response_str = self._llm_call([{"role": "user", "content": prompt}]
|
423 |
-
|
|
|
|
|
|
|
424 |
|
425 |
-
def
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
431 |
|
432 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
433 |
|
434 |
all_fused_candidates: Dict[int, FusedCandidate] = {}
|
435 |
-
|
436 |
for sub_q in sub_queries:
|
437 |
-
expanded_q = self._expand_query_with_llm(sub_q,
|
438 |
-
|
439 |
q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
|
440 |
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
|
441 |
faiss.normalize_L2(q_emb)
|
442 |
-
distances,
|
443 |
|
444 |
tokenized_query = list(jieba.cut(expanded_q))
|
445 |
-
|
446 |
bm25_scores = self.state.bm25.get_scores(tokenized_query)
|
447 |
-
rel_idx = np.fromiter(relevant_indices, dtype=
|
448 |
rel_scores = bm25_scores[rel_idx]
|
449 |
top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
|
450 |
-
doc_to_bm25_score
|
451 |
-
|
|
|
452 |
candidate_scores: Dict[int, Dict[str, float]] = {}
|
453 |
-
|
454 |
def to_similarity(d: float) -> float:
|
455 |
-
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT
|
456 |
-
|
457 |
-
else:
|
458 |
-
return 1.0 / (1.0 + float(d))
|
459 |
-
|
460 |
-
for i, dist in zip(sim_indices[0], distances[0]):
|
461 |
if i in relevant_indices:
|
462 |
-
|
463 |
-
candidate_scores[int(i)] = {"sem": float(similarity), "bm": 0.0}
|
464 |
-
|
465 |
for i, score in doc_to_bm25_score.items():
|
466 |
if i in relevant_indices:
|
467 |
candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
keys = list(candidate_scores.keys())
|
472 |
-
sem_scores = np.array([candidate_scores[k][
|
473 |
-
bm_scores = np.array([candidate_scores[k][
|
474 |
-
|
475 |
-
def norm(x):
|
476 |
-
|
477 |
-
return (x - x.min()) / (rng + 1e-8) if rng > 0 else np.zeros_like(x)
|
478 |
|
479 |
sem_n, bm_n = norm(sem_scores), norm(bm_scores)
|
480 |
-
|
481 |
for idx, k in enumerate(keys):
|
482 |
fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
|
483 |
-
|
484 |
if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
|
485 |
all_fused_candidates[k] = FusedCandidate(
|
486 |
idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
|
487 |
)
|
488 |
-
|
489 |
return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
|
490 |
|
491 |
-
def _expand_query_with_llm(self, query: str, intents:
|
492 |
if not intents:
|
493 |
return query
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
if expanded_query and expanded_query.strip():
|
500 |
-
log.info(f"查詢擴展成功。原始: '{query}', 擴展後: '{expanded_query}'")
|
501 |
-
return expanded_query
|
502 |
-
else:
|
503 |
-
log.warning(f"查詢擴展回傳空內容。原始查詢: '{query}'。將使用原始查詢。")
|
504 |
-
return query
|
505 |
-
except Exception as e:
|
506 |
-
log.error(f"查詢擴展失敗: {e}。原始查詢: '{query}'。將使用原始查詢。")
|
507 |
-
return query
|
508 |
-
|
509 |
-
def _rerank_with_crossencoder(self, query: str, candidates: List[FusedCandidate]) -> List[RerankResult]:
|
510 |
-
if not candidates: return []
|
511 |
-
|
512 |
-
top_candidates = candidates[:MAX_RERANK_CANDIDATES]
|
513 |
-
pairs = [(query, self.state.sentences[c.idx]) for c in top_candidates]
|
514 |
-
scores = self.reranker.predict(pairs, show_progress_bar=False)
|
515 |
-
|
516 |
-
results = [
|
517 |
-
RerankResult(idx=c.idx, rerank_score=float(score), text=self.state.sentences[c.idx], meta=self.state.meta[c.idx])
|
518 |
-
for c, score in zip(top_candidates, scores)
|
519 |
-
]
|
520 |
-
|
521 |
-
return sorted(results, key=lambda x: x.rerank_score, reverse=True)[:TOP_K_SENTENCES]
|
522 |
|
523 |
def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]:
|
524 |
-
if
|
525 |
return results
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
prioritized =
|
532 |
-
|
533 |
-
|
|
|
|
|
|
|
|
|
534 |
|
535 |
def _build_context(self, reranked_results: List[RerankResult]) -> str:
|
536 |
context = ""
|
537 |
for res in reranked_results:
|
538 |
-
if len(context) + len(res.text) > LLM_MODEL_CONFIG["max_context_chars"]:
|
|
|
539 |
context += res.text + "\n\n"
|
540 |
return context.strip()
|
541 |
|
542 |
def _make_final_prompt(self, query: str, context: str, intents: List[str]) -> str:
|
543 |
add_instr = ""
|
544 |
-
if any(
|
545 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
546 |
if "保存/攜帶 (Storage & Handling)" in intents:
|
547 |
-
add_instr +=
|
548 |
-
|
|
|
|
|
|
|
|
|
549 |
return PROMPT_TEMPLATES["final_answer"].format(
|
550 |
additional_instruction=add_instr, context=context, query=query
|
551 |
)
|
552 |
-
|
553 |
def _safe_json_parse(self, s: str, default: Any = None) -> Any:
|
554 |
try:
|
555 |
return json.loads(s)
|
556 |
except json.JSONDecodeError:
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
try:
|
561 |
return json.loads(m.group(0))
|
562 |
-
|
563 |
-
|
564 |
return default
|
565 |
|
|
|
566 |
# ---------- FastAPI 事件與路由 ----------
|
567 |
class AppConfig:
|
568 |
CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
|
569 |
CHANNEL_SECRET = _require_env("CHANNEL_SECRET")
|
570 |
|
|
|
571 |
rag_pipeline: Optional[RagPipeline] = None
|
572 |
|
|
|
573 |
@asynccontextmanager
|
574 |
async def lifespan(app: FastAPI):
|
575 |
_require_llm_config()
|
@@ -580,21 +701,20 @@ async def lifespan(app: FastAPI):
|
|
580 |
yield
|
581 |
log.info("服務關閉中。")
|
582 |
|
|
|
583 |
app = FastAPI(lifespan=lifespan)
|
584 |
|
|
|
585 |
@app.post("/webhook")
|
586 |
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|
587 |
signature = request.headers.get("X-Line-Signature")
|
588 |
if not signature:
|
589 |
-
raise HTTPException(status_code=400, detail="Missing X-Line-Signature")
|
590 |
-
if not AppConfig.CHANNEL_SECRET:
|
591 |
-
log.error("CHANNEL_SECRET is not configured.")
|
592 |
-
raise HTTPException(status_code=500, detail="Server configuration error")
|
593 |
|
594 |
body = await request.body()
|
595 |
try:
|
596 |
-
|
597 |
-
expected_signature = base64.b64encode(
|
598 |
except Exception as e:
|
599 |
log.error(f"Failed to generate signature: {e}")
|
600 |
raise HTTPException(status_code=500, detail="Signature generation error")
|
@@ -603,73 +723,97 @@ async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|
|
603 |
raise HTTPException(status_code=403, detail="Invalid signature")
|
604 |
|
605 |
try:
|
606 |
-
data = json.loads(body.decode(
|
607 |
except json.JSONDecodeError:
|
608 |
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
609 |
|
610 |
for event in data.get("events", []):
|
611 |
-
if
|
612 |
-
|
|
|
|
|
613 |
user_text = event.get("message", {}).get("text", "").strip()
|
614 |
source = event.get("source", {})
|
615 |
stype = source.get("type")
|
616 |
-
target_id =
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
background_tasks.add_task(
|
621 |
-
|
|
|
622 |
return Response(status_code=status.HTTP_200_OK)
|
623 |
|
624 |
-
|
|
|
625 |
try:
|
626 |
-
if rag_pipeline:
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
|
|
631 |
except Exception as e:
|
632 |
log.error(f"背景處理 target_id={target_id} 發生錯誤: {e}", exc_info=True)
|
633 |
-
line_push_generic(
|
|
|
|
|
|
|
|
|
|
|
634 |
|
635 |
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
636 |
-
def line_api_call(endpoint: str, data: Dict):
|
637 |
headers = {
|
638 |
"Content-Type": "application/json",
|
639 |
-
"Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}"
|
640 |
}
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
655 |
data = {"to": target_id, "messages": messages}
|
656 |
-
line_api_call(
|
|
|
657 |
|
658 |
-
def extract_drug_candidates_from_query(query: str, drug_vocab: dict) ->
|
659 |
candidates = set()
|
660 |
q_norm = _norm(query)
|
661 |
-
|
662 |
for word in re.findall(r"[a-z0-9]+", q_norm):
|
663 |
if word in drug_vocab["en"]:
|
664 |
candidates.add(word)
|
665 |
-
|
666 |
for token in jieba.cut(q_norm):
|
667 |
if token in drug_vocab["zh"]:
|
668 |
candidates.add(token)
|
669 |
-
|
|
|
|
|
670 |
return list(candidates)
|
671 |
|
|
|
672 |
# ---------- 執行 ----------
|
673 |
if __name__ == "__main__":
|
674 |
-
port = int(os.getenv("PORT", 7860))
|
675 |
uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
"""
|
5 |
DrugQA (ZH) — 優化版 FastAPI LINE Webhook (最終版)
|
6 |
整合 RAG 邏輯,包含 LLM 意圖偵測、子查詢分解、Intent-aware 檢索與 Rerank。
|
7 |
+
新增動態字數調整、多次互動邏輯與對話狀態管理,提升使用者體驗。
|
8 |
+
僅支援十種藥物。
|
9 |
"""
|
10 |
|
11 |
+
# ---------- 環境與快取設定 ----------
|
12 |
import os
|
13 |
import pathlib
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
import re
|
15 |
import hmac
|
16 |
import base64
|
|
|
26 |
from dataclasses import dataclass, field
|
27 |
from contextlib import asynccontextmanager
|
28 |
import unicodedata
|
29 |
+
from collections import defaultdict
|
30 |
+
import asyncio
|
31 |
+
import aiohttp # 新增:導入 aiohttp 用於異步 HTTP 請求
|
32 |
|
33 |
+
# ------------ 第三方函式庫 -------------
|
34 |
import numpy as np
|
35 |
import pandas as pd
|
|
|
|
|
36 |
import jieba
|
37 |
from rank_bm25 import BM25Okapi
|
38 |
+
from sentence_transformers import SentenceTransformer
|
39 |
import faiss
|
40 |
import torch
|
41 |
from openai import OpenAI
|
42 |
from tenacity import retry, stop_after_attempt, wait_fixed
|
43 |
import requests
|
44 |
+
import uvicorn
|
45 |
+
from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
|
46 |
|
47 |
+
# ---- 限制 PyTorch 執行緒數量,避免 CPU 環境下過度佔用資源 ----
|
48 |
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1")))
|
49 |
|
50 |
+
# ===== CONFIG =====
|
|
|
51 |
def _require_env(var: str) -> str:
|
52 |
+
"""Return the value of an environment variable or raise an error."""
|
53 |
v = os.getenv(var)
|
54 |
if not v:
|
55 |
raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
|
56 |
return v
|
57 |
|
58 |
+
|
59 |
def _require_llm_config():
|
60 |
for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
|
61 |
_require_env(k)
|
62 |
|
63 |
+
|
64 |
+
# --------- 路徑設定 ------------
|
65 |
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
|
66 |
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
|
67 |
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
|
68 |
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
|
69 |
|
70 |
+
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 20))
|
71 |
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
|
72 |
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
|
73 |
|
74 |
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
|
|
|
75 |
|
76 |
LLM_API_CONFIG = {
|
77 |
+
"base_url": _require_env("LITELLM_BASE_URL"),
|
78 |
+
"api_key": _require_env("LITELLM_API_KEY"),
|
79 |
+
"model": _require_env("LM_MODEL"),
|
80 |
}
|
81 |
|
82 |
LLM_MODEL_CONFIG = {
|
83 |
"max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 10000)),
|
84 |
+
"max_tokens_simple": int(os.getenv("MAX_TOKENS_SIMPLE", 256)),
|
85 |
+
"max_tokens_complex": int(os.getenv("MAX_TOKENS_COMPLEX", 1024)),
|
86 |
"temperature": float(os.getenv("TEMPERATURE", 0.0)),
|
87 |
}
|
88 |
|
89 |
INTENT_CATEGORIES = [
|
90 |
+
"操作 (Administration)",
|
91 |
+
"保存/攜帶 (Storage & Handling)",
|
92 |
+
"副作用/異常 (Side Effects / Issues)",
|
93 |
+
"劑型相關 (Dosage Form Concerns)",
|
94 |
+
"時間/併用 (Timing & Interaction)",
|
95 |
+
"劑量調整 (Dosage Adjustment)",
|
96 |
+
"禁忌症/適應症 (Contraindications/Indications)",
|
97 |
]
|
98 |
|
99 |
+
INTENT_TO_SECTION = {
|
100 |
+
"操作 (Administration)": ["用法用量", "病人使用須知"],
|
101 |
+
"保存/攜帶 (Storage & Handling)": ["包裝及儲存"],
|
102 |
+
"副作用/異常 (Side Effects / Issues)": ["不良反應", "警語與注意事項"],
|
103 |
+
"劑型相關 (Dosage Form Concerns)": ["劑型相關", "藥品外觀"],
|
104 |
+
"時間/併用 (Timing & Interaction)": ["用法用量"],
|
105 |
+
"劑量調整 (Dosage Adjustment)": ["用法用量"],
|
106 |
+
"禁忌症/適應症 (Contraindications/Indications)": ["適應症", "禁忌", "警語與注意事項"],
|
107 |
+
}
|
108 |
+
|
109 |
DRUG_NAME_MAPPING = {
|
110 |
+
"fentanyl patch": "fentanyl",
|
111 |
+
"spiriva respimat": "spiriva",
|
112 |
+
"augmentin for syrup": "augmentin syrup",
|
113 |
+
"nitrostat": "nitroglycerin",
|
114 |
+
"ozempic": "ozempic",
|
115 |
+
"niflec": "niflec",
|
116 |
+
"fosamax": "fosamax",
|
117 |
+
"humira": "humira",
|
118 |
+
"premarin": "premarin",
|
119 |
+
"smecta": "smecta",
|
120 |
+
}
|
121 |
+
SUPPORTED_DRUGS = list(DRUG_NAME_MAPPING.keys())
|
122 |
+
DISCLAIMER = (
|
123 |
+
"本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
|
124 |
+
)
|
125 |
+
|
126 |
+
REFERENCE_MAPPING = {
|
127 |
+
"如何用藥?": "病人使用須知、用法用量",
|
128 |
+
"如何保存與攜帶?": "包裝及儲存",
|
129 |
+
"可能的副作用?": "警語與注意事項、不良反應",
|
130 |
+
"每次劑量多少?": "用法用量、藥袋上的醫囑",
|
131 |
+
"用藥時間?": "用法用量、藥袋上的醫囑",
|
132 |
+
}
|
133 |
+
|
134 |
+
REFERENCE_TO_INTENT = {
|
135 |
+
"如何用藥?": ["操作 (Administration)"],
|
136 |
+
"如何保存與攜帶?": ["保存/攜帶 (Storage & Handling)"],
|
137 |
+
"可能的副作用?": ["副作用/異常 (Side Effects / Issues)"],
|
138 |
+
"每次劑量多少?": ["劑量調整 (Dosage Adjustment)"],
|
139 |
+
"用藥時間?": ["時間/併用 (Timing & Interaction)"],
|
140 |
}
|
|
|
141 |
|
142 |
PROMPT_TEMPLATES = {
|
143 |
"analyze_query": """
|
144 |
+
請分析以下使用者問題,並完成以下三個任務:
|
145 |
+
1. 將問題分解為 1-3 個核心子問題。
|
146 |
2. 從清單中選擇所有相關的意圖分類。
|
147 |
+
3. 評估問題複雜度,返回 'simple'(單一問題或簡單意圖)或 'complex'(多子問題或複雜意圖,如副作用、劑量調整)。
|
148 |
|
149 |
+
請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列)、'intents' (字串陣列) 和 'complexity' (字串) 三個鍵。
|
150 |
+
範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"], "complexity": "simple"}}
|
151 |
|
152 |
意圖分類清單:
|
153 |
+
{options}。
|
154 |
|
155 |
使用者問題:{query}
|
156 |
""",
|
|
|
160 |
請僅輸出擴展後的查詢,不需任何額外的解釋或格式。
|
161 |
""",
|
162 |
"final_answer": """
|
163 |
+
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
+
一、 回覆規範:
|
166 |
+
- 回覆語言:使用繁體中文,口語化且易懂,避免專業術語或解釋之。
|
167 |
+
- 結構:先以「簡答:」標記提供簡短總結答案(50-100字),然後以「詳答:」標記提供詳細解釋,最後提醒使用者諮詢醫師。
|
168 |
+
- 長度:簡答控制在50-100字,詳答根據問題複雜度調整,簡單問題約100-200字,複雜問題(如多步驟的裝置安裝或藥品使用)可達300-500字。
|
169 |
+
- 態度:親切、專業、關懷,避免驚嚇使用者。
|
170 |
{additional_instruction}
|
171 |
|
172 |
---
|
|
|
177 |
使用者問題:{query}
|
178 |
|
179 |
請直接輸出最終的答案:
|
180 |
+
""",
|
181 |
+
"analyze_reference": """
|
182 |
+
從以下清單選擇最匹配的使用者問題分類,如果沒有匹配,返回 'none'。
|
183 |
+
|
184 |
+
分類清單:
|
185 |
+
{options}
|
186 |
+
|
187 |
+
使用者問題:{query}
|
188 |
+
|
189 |
+
請僅輸出分類名稱或 'none',不需任何額外的解釋或格式。
|
190 |
+
""",
|
191 |
+
"clarification": """
|
192 |
+
請根據以下使用者問題,生成一個簡潔、禮貌的澄清性提問,以幫助我更精準地回答。問題應引導使用者提供更多細節,例如具體藥名、使用情境,並附上範例問題。請在回覆中明確告知使用者,目前僅支援以下藥物詢問:
|
193 |
+
- Fentanyl patch
|
194 |
+
- Spiriva Respimat
|
195 |
+
- NITROSTAT
|
196 |
+
- AUGMENTIN FOR SYRUP
|
197 |
+
- Ozempic
|
198 |
+
- NIFLEC
|
199 |
+
- Fosamax
|
200 |
+
- Humira
|
201 |
+
- PREMARIN
|
202 |
+
- SMECTA
|
203 |
+
|
204 |
+
範例:
|
205 |
+
使用者問題:這個藥會怎麼樣?
|
206 |
+
澄清提問:您好,請問您指的藥物是下列哪一種?目前僅支援以下藥物詢問:Fentanyl patch、Spiriva Respimat...等。例如,您可以問:「Fentanyl patch 的副作用有哪些?」請確認藥名或提供更多細節。
|
207 |
+
|
208 |
+
使用者問題:{query}
|
209 |
"""
|
210 |
}
|
211 |
|
212 |
# ---------- 日誌設定 ----------
|
213 |
+
logging.basicConfig(
|
214 |
+
level=logging.INFO,
|
215 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
216 |
+
)
|
217 |
log = logging.getLogger(__name__)
|
218 |
|
219 |
+
# ---------- 字串正規化 ----------
|
220 |
def _norm(s: str) -> str:
|
221 |
+
"""統一化字串:NFKC 正規化、轉小寫、移除標點與空格。"""
|
222 |
s = unicodedata.normalize("NFKC", s)
|
223 |
return re.sub(r"[^\w\s]", "", s.lower()).strip()
|
224 |
|
225 |
+
|
226 |
@dataclass
|
227 |
class FusedCandidate:
|
228 |
idx: int
|
|
|
230 |
sem_score: float
|
231 |
bm_score: float
|
232 |
|
233 |
+
|
234 |
@dataclass
|
235 |
class RerankResult:
|
236 |
idx: int
|
|
|
238 |
text: str
|
239 |
meta: Dict[str, Any] = field(default_factory=dict)
|
240 |
|
241 |
+
|
242 |
+
@dataclass
|
243 |
+
class ConversationState:
|
244 |
+
query_history: List[str] = field(default_factory=list)
|
245 |
+
drug_ids: List[str] = field(default_factory=list)
|
246 |
+
intents: List[str] = field(default_factory=list)
|
247 |
+
complexity: str = "simple"
|
248 |
+
last_answer: Optional[str] = None
|
249 |
+
clarification_count: int = 0
|
250 |
+
|
251 |
+
|
252 |
# ---------- 核心 RAG 邏輯 ----------
|
253 |
class RagPipeline:
|
254 |
def __init__(self):
|
|
|
255 |
if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
|
256 |
+
raise ValueError("LLM API Key or Base URL is not configured.")
|
257 |
+
self.llm_client = OpenAI(
|
258 |
+
api_key=LLM_API_CONFIG["api_key"],
|
259 |
+
base_url=LLM_API_CONFIG["base_url"],
|
260 |
+
)
|
261 |
self.model_name = LLM_API_CONFIG["model"]
|
262 |
+
self.embedding_model = self._load_model(
|
263 |
+
SentenceTransformer, EMBEDDING_MODEL, "embedding"
|
264 |
+
)
|
265 |
self.drug_name_to_ids: Dict[str, List[str]] = {}
|
266 |
self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
|
267 |
+
self.state = type("state", (), {})()
|
268 |
+
self.conversations: Dict[str, ConversationState] = defaultdict(ConversationState)
|
269 |
|
270 |
def _load_model(self, model_class, model_name: str, model_type: str):
|
271 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
272 |
+
log.info(f"載入 {model_type} 模型:{model_name} 於 {device}")
|
273 |
try:
|
274 |
return model_class(model_name, device=device)
|
275 |
except Exception as e:
|
276 |
+
log.warning(f"載入 {model_type} 至 {device} 失敗: {e},嘗試切換至 CPU。")
|
277 |
try:
|
278 |
return model_class(model_name, device="cpu")
|
279 |
except Exception as e_cpu:
|
280 |
+
log.error(f"切換至 CPU 仍無法載入模型: {model_name}。錯誤訊息: {e_cpu}")
|
281 |
raise RuntimeError(f"模型載入失敗: {model_name}")
|
282 |
|
283 |
def load_data(self):
|
284 |
log.info("開始載入資料與模型...")
|
|
|
285 |
for path in [CSV_PATH, FAISS_INDEX, SENTENCES_PKL, BM25_PKL]:
|
286 |
if not pathlib.Path(path).exists():
|
287 |
raise FileNotFoundError(f"必要的資料檔案不存在: {path}")
|
288 |
|
289 |
try:
|
290 |
+
self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna("")
|
291 |
+
required_cols = ["drug_name_norm", "drug_id", "section", "content"]
|
292 |
+
for col in required_cols:
|
293 |
if col not in self.df_csv.columns:
|
294 |
raise KeyError(f"CSV 檔案 '{CSV_PATH}' 中缺少必要欄位: {col}")
|
295 |
+
except Exception as e:
|
296 |
+
log.error(f"讀取 CSV 失敗: {e}")
|
297 |
+
raise
|
298 |
+
|
299 |
+
self.drug_name_to_ids = self._build_drug_name_to_ids()
|
300 |
+
self._load_drug_name_vocabulary()
|
301 |
+
|
302 |
+
log.info("載入 FAISS 索引與句子資料...")
|
303 |
+
self.state.index = faiss.read_index(FAISS_INDEX)
|
304 |
+
self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2)
|
305 |
+
if hasattr(self.state.index, "nprobe"):
|
306 |
+
self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
|
307 |
+
|
308 |
+
with open(SENTENCES_PKL, "rb") as f:
|
309 |
+
data = pickle.load(f)
|
310 |
+
self.state.sentences = data["sentences"]
|
311 |
+
self.state.meta = data["meta"]
|
312 |
+
|
313 |
+
log.info("載入 BM25 索引...")
|
314 |
+
with open(BM25_PKL, "rb") as f:
|
315 |
+
bm25_data = pickle.load(f)
|
316 |
+
self.state.bm25 = bm25_data["bm25"]
|
317 |
+
if not isinstance(self.state.bm25, BM25Okapi):
|
318 |
+
raise ValueError("Loaded BM25 is not a BM25Okapi instance.")
|
319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
log.info("所有模型與資料載入完成。")
|
321 |
|
322 |
+
@lru_cache(maxsize=128)
|
323 |
+
def _get_drug_name_by_id(self, drug_id: str) -> Optional[str]:
|
324 |
+
row = self.df_csv[self.df_csv["drug_id"] == drug_id]
|
325 |
+
if not row.empty:
|
326 |
+
return row.iloc[0]["drug_name_norm"]
|
327 |
+
return None
|
328 |
+
|
329 |
def _find_drug_ids_from_name(self, query: str) -> List[str]:
|
330 |
+
q_norm_parts = set(
|
331 |
+
re.findall(r"[a-z0-9]+|[\u4e00-\u9fff]+", _norm(query))
|
332 |
+
)
|
333 |
drug_ids = set()
|
|
|
334 |
for part in q_norm_parts:
|
335 |
if part in self.drug_name_to_ids:
|
336 |
drug_ids.update(self.drug_name_to_ids[part])
|
337 |
+
for drug_name, ids in self.drug_name_to_ids.items():
|
338 |
+
if drug_name in _norm(query):
|
339 |
+
drug_ids.update(ids)
|
340 |
+
return sorted(drug_ids)
|
341 |
|
342 |
def _build_drug_name_to_ids(self) -> Dict[str, List[str]]:
|
343 |
+
self.drug_name_to_ids = {}
|
344 |
for _, row in self.df_csv.iterrows():
|
345 |
+
drug_id = row["drug_id"]
|
346 |
+
zh_parts = list(jieba.cut(row["drug_name_zh"]))
|
347 |
+
en_parts = re.findall(
|
348 |
+
r"[a-zA-Z0-9]+", row["drug_name_en"].lower() if row["drug_name_en"] else ""
|
349 |
+
)
|
350 |
+
norm_parts = re.findall(
|
351 |
+
r"[a-z0-9]+|[\u4e00-\u9fff]+", _norm(row["drug_name_norm"])
|
352 |
+
)
|
353 |
all_parts = set(zh_parts + en_parts + norm_parts)
|
354 |
for part in all_parts:
|
355 |
part = part.strip()
|
356 |
if part and len(part) > 1:
|
357 |
+
self.drug_name_to_ids.setdefault(part, []).append(drug_id)
|
|
|
|
|
358 |
for alias, canonical_name in DRUG_NAME_MAPPING.items():
|
359 |
+
if _norm(canonical_name) in _norm(row["drug_name_norm"]):
|
360 |
+
self.drug_name_to_ids.setdefault(_norm(alias), []).append(drug_id)
|
361 |
+
for key in self.drug_name_to_ids:
|
362 |
+
self.drug_name_to_ids[key] = sorted(set(self.drug_name_to_ids[key]))
|
363 |
+
return self.drug_name_to_ids
|
|
|
|
|
364 |
|
365 |
def _load_drug_name_vocabulary(self):
|
366 |
log.info("建立藥名詞庫...")
|
367 |
for _, row in self.df_csv.iterrows():
|
368 |
+
norm_name = row["drug_name_norm"]
|
369 |
+
words = re.findall(r"[a-z0-9]+|[\u4e00-\u9fff]+", norm_name)
|
370 |
for word in words:
|
371 |
+
if re.search(r"[\u4e00-\u9fff]", word):
|
372 |
self.drug_vocab["zh"].add(word)
|
373 |
else:
|
374 |
self.drug_vocab["en"].add(word)
|
|
|
375 |
for alias in DRUG_NAME_MAPPING:
|
376 |
+
if re.search(r"[\u4e00-\u9fff]", alias):
|
377 |
self.drug_vocab["zh"].add(alias)
|
378 |
else:
|
379 |
self.drug_vocab["en"].add(alias)
|
|
|
380 |
for word in self.drug_vocab["zh"]:
|
381 |
try:
|
382 |
if word not in jieba.dt.FREQ:
|
383 |
jieba.add_word(word, freq=2_000_000)
|
384 |
except Exception:
|
385 |
pass
|
386 |
+
|
387 |
@tenacity.retry(
|
388 |
wait=tenacity.wait_fixed(2),
|
389 |
+
stop=tenacity.stop_after_attempt(5)
|
|
|
|
|
|
|
390 |
)
|
391 |
+
def _llm_call(
|
392 |
+
self,
|
393 |
+
messages: List[Dict[str, str]],
|
394 |
+
max_tokens: Optional[int] = None,
|
395 |
+
temperature: Optional[float] = None,
|
396 |
+
) -> str:
|
397 |
+
log.info(f"LLM 呼叫開始. 模型: {self.model_name}")
|
398 |
try:
|
399 |
response = self.llm_client.chat.completions.create(
|
400 |
model=self.model_name,
|
401 |
messages=messages,
|
402 |
+
max_tokens=max_tokens or LLM_MODEL_CONFIG["max_tokens_simple"],
|
403 |
+
temperature=temperature or LLM_MODEL_CONFIG["temperature"],
|
404 |
)
|
405 |
+
content = response.choices[0].message.content or ""
|
406 |
+
return content.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
except Exception as e:
|
408 |
+
log.error(f"LLM API 調用失敗: {e}")
|
409 |
raise
|
410 |
|
411 |
+
async def answer_question(self, target_id: str, q_orig: str) -> Tuple[str, List[str]]:
|
412 |
start_time = time.time()
|
413 |
+
log.info(f"===== 處理查詢: '{q_orig}' (target_id: {target_id}) =====")
|
414 |
+
conv_state = self.conversations[target_id]
|
415 |
+
conv_state.query_history.append(q_orig)
|
416 |
+
|
417 |
try:
|
418 |
drug_ids = self._find_drug_ids_from_name(q_orig)
|
|
|
419 |
if not drug_ids:
|
420 |
+
log.info(f"未找到匹配藥物,查詢:{q_orig}")
|
421 |
+
conv_state.clarification_count += 1
|
422 |
+
if conv_state.clarification_count > 3:
|
423 |
+
return "抱歉,多次無法識別您的問題,請確認藥物名稱或聯繫醫師。\n" + DISCLAIMER, []
|
424 |
+
clarification = self._generate_clarification_query(q_orig)
|
425 |
+
conv_state.last_answer = clarification
|
426 |
+
return f"{clarification}\n\n{DISCLAIMER}", []
|
427 |
+
|
428 |
+
conv_state.drug_ids = drug_ids
|
429 |
+
ref_key = self._match_reference_key(q_orig)
|
430 |
+
complexity = "simple" # 預設為簡單
|
431 |
+
context = ""
|
432 |
+
intents = []
|
433 |
+
|
434 |
+
if ref_key != 'none' and ref_key in REFERENCE_MAPPING:
|
435 |
+
sections_str = REFERENCE_MAPPING[ref_key]
|
436 |
+
sections = [s.strip() for s in sections_str.split('、') if s.strip() and s != '藥袋上的醫囑']
|
437 |
+
intents = REFERENCE_TO_INTENT.get(ref_key, [])
|
438 |
+
context = self._build_context_from_csv(drug_ids, sections)
|
439 |
+
# 根據參考資料判斷複雜度
|
440 |
+
if any(sec in ["用法用量", "病人使用須知", "劑型相關"] for sec in sections):
|
441 |
+
complexity = "complex" # 多步驟的裝置安裝或藥品使用
|
442 |
+
elif any(sec in ["不良反應", "警語與注意事項"] for sec in sections):
|
443 |
+
complexity = "simple" # 副作用問題
|
444 |
else:
|
445 |
+
return await self._fallback_rag(target_id, q_orig, drug_ids)
|
446 |
+
|
447 |
+
conv_state.intents = intents
|
448 |
+
conv_state.complexity = complexity
|
449 |
+
|
450 |
+
max_tokens = LLM_MODEL_CONFIG["max_tokens_complex"] if complexity == "complex" else LLM_MODEL_CONFIG["max_tokens_simple"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
prompt = self._make_final_prompt(q_orig, context, intents)
|
452 |
+
answer = self._llm_call(
|
453 |
+
[{"role": "user", "content": prompt}],
|
454 |
+
max_tokens=max_tokens
|
455 |
+
)
|
456 |
if not answer:
|
457 |
+
return f"無法回答您的問題。\n{DISCLAIMER}", drug_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
458 |
|
459 |
+
answer = answer.replace("*", "")
|
460 |
+
conv_state.last_answer = answer
|
461 |
+
final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
|
462 |
+
log.info(f"查詢處理完成,耗時: {time.time() - start_time:.2f}秒")
|
463 |
+
return final_answer, drug_ids
|
464 |
except Exception as e:
|
465 |
+
log.error(f"處理查詢時發生錯誤: {e}", exc_info=True)
|
466 |
+
return f"處理時發生內部錯誤,請稍後再試。\n{DISCLAIMER}", []
|
467 |
+
|
468 |
+
async def _fallback_rag(self, target_id: str, q_orig: str, drug_ids: List[str]) -> Tuple[str, List[str]]:
|
469 |
+
conv_state = self.conversations[target_id]
|
470 |
+
analysis = self._analyze_query(q_orig)
|
471 |
+
sub_queries = analysis.get("sub_queries", [q_orig])
|
472 |
+
intents = analysis.get("intents", [])
|
473 |
+
complexity = "simple" # 預設為簡單
|
474 |
+
sections = []
|
475 |
+
for intent in intents:
|
476 |
+
sections.extend(INTENT_TO_SECTION.get(intent, []))
|
477 |
+
if any(sec in ["用法用量", "病人使用須知", "劑型相關"] for sec in sections):
|
478 |
+
complexity = "complex"
|
479 |
+
elif any(sec in ["不良反應", "警語與注意事項"] for sec in sections):
|
480 |
+
complexity = "simple"
|
481 |
+
conv_state.intents = intents
|
482 |
+
conv_state.complexity = complexity
|
483 |
|
484 |
+
if not intents:
|
485 |
+
log.info(f"無明確意圖,查詢:{q_orig}")
|
486 |
+
conv_state.clarification_count += 1
|
487 |
+
if conv_state.clarification_count > 3:
|
488 |
+
return "抱歉,多次無法識別您的問題,請確認藥物名稱或聯繫醫師。\n" + DISCLAIMER, drug_ids
|
489 |
+
clarification = self._generate_clarification_query(q_orig)
|
490 |
+
conv_state.last_answer = clarification
|
491 |
+
return f"{clarification}\n\n{DISCLAIMER}", drug_ids
|
492 |
+
|
493 |
+
all_candidates = self._retrieve_candidates_for_all_queries(
|
494 |
+
drug_ids, sub_queries, intents
|
495 |
+
)
|
496 |
+
final_candidates = all_candidates[:TOP_K_SENTENCES]
|
497 |
+
|
498 |
+
reranked_results = [
|
499 |
+
RerankResult(
|
500 |
+
idx=c.idx,
|
501 |
+
rerank_score=c.fused_score,
|
502 |
+
text=self.state.sentences[c.idx],
|
503 |
+
meta=self.state.meta[c.idx],
|
504 |
+
)
|
505 |
+
for c in final_candidates
|
506 |
+
]
|
507 |
+
prioritized = self._prioritize_context(reranked_results, intents)
|
508 |
+
context = self._build_context(prioritized)
|
509 |
|
510 |
+
if not context:
|
511 |
+
return f"無法回答您的問題,請參閱原始內容。\n{DISCLAIMER}", drug_ids
|
512 |
+
|
513 |
+
max_tokens = LLM_MODEL_CONFIG["max_tokens_complex"] if complexity == "complex" else LLM_MODEL_CONFIG["max_tokens_simple"]
|
514 |
+
prompt = self._make_final_prompt(q_orig, context, intents)
|
515 |
+
answer = self._llm_call(
|
516 |
+
[{"role": "user", "content": prompt}],
|
517 |
+
max_tokens=max_tokens
|
518 |
+
)
|
519 |
+
if not answer:
|
520 |
+
return f"無法回答您的問題。\n{DISCLAIMER}", drug_ids
|
521 |
+
|
522 |
+
answer = answer.replace("*", "")
|
523 |
+
conv_state.last_answer = answer
|
524 |
+
final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
|
525 |
+
return final_answer, drug_ids
|
526 |
+
|
527 |
+
def _match_reference_key(self, query: str) -> str:
|
528 |
+
options = "\n".join(f"- {k}" for k in REFERENCE_MAPPING.keys())
|
529 |
+
prompt = PROMPT_TEMPLATES["analyze_reference"].format(
|
530 |
+
options=options, query=query
|
531 |
)
|
532 |
+
response_str = self._llm_call([{"role": "user", "content": prompt}])
|
533 |
+
ref_key = response_str.strip().replace('"', '')
|
534 |
+
if ref_key in REFERENCE_MAPPING:
|
535 |
+
return ref_key
|
536 |
+
return 'none'
|
537 |
|
538 |
+
def _build_context_from_csv(self, drug_ids: List[str], sections: List[str]) -> str:
|
539 |
+
context = ""
|
540 |
+
for drug_id in drug_ids:
|
541 |
+
drug_df = self.df_csv[self.df_csv['drug_id'] == drug_id]
|
542 |
+
for sec in sections:
|
543 |
+
sec_row = drug_df[drug_df['section'].str.contains(sec, na=False)]
|
544 |
+
if not sec_row.empty:
|
545 |
+
content = sec_row.iloc[0]['content']
|
546 |
+
if len(context) + len(content) > LLM_MODEL_CONFIG["max_context_chars"]:
|
547 |
+
return context.strip()
|
548 |
+
context += content + "\n\n"
|
549 |
+
return context.strip()
|
550 |
|
551 |
+
def _analyze_query(self, query: str) -> Dict[str, Any]:
|
552 |
+
options = "\n".join(f"- {c}" for c in INTENT_CATEGORIES)
|
553 |
+
prompt = PROMPT_TEMPLATES["analyze_query"].format(
|
554 |
+
options=options, query=query
|
555 |
+
)
|
556 |
+
response_str = self._llm_call([{"role": "user", "content": prompt}])
|
557 |
+
return self._safe_json_parse(response_str, default={"sub_queries": [query], "intents": [], "complexity": "simple"})
|
558 |
+
|
559 |
+
def _generate_clarification_query(self, query: str) -> str:
|
560 |
+
prompt = PROMPT_TEMPLATES["clarification"].format(query=query)
|
561 |
+
return self._llm_call([{"role": "user", "content": prompt}])
|
562 |
+
|
563 |
+
def _retrieve_candidates_for_all_queries(
|
564 |
+
self, drug_ids: List[str], sub_queries: List[str], intents: List[str]
|
565 |
+
) -> List[FusedCandidate]:
|
566 |
+
drug_ids_set, relevant_indices = set(map(str, drug_ids)), (
|
567 |
+
{i for i, m in enumerate(self.state.meta) if str(m.get("drug_id")) in drug_ids_set}
|
568 |
+
if drug_ids_set
|
569 |
+
else set(range(len(self.state.meta)))
|
570 |
+
)
|
571 |
+
if not relevant_indices:
|
572 |
+
return []
|
573 |
|
574 |
all_fused_candidates: Dict[int, FusedCandidate] = {}
|
|
|
575 |
for sub_q in sub_queries:
|
576 |
+
expanded_q = self._expand_query_with_llm(sub_q, intents)
|
|
|
577 |
q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
|
578 |
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
|
579 |
faiss.normalize_L2(q_emb)
|
580 |
+
distances, sem_indices = self.state.index.search(q_emb, PRE_RERANK_K)
|
581 |
|
582 |
tokenized_query = list(jieba.cut(expanded_q))
|
|
|
583 |
bm25_scores = self.state.bm25.get_scores(tokenized_query)
|
584 |
+
rel_idx = np.fromiter(relevant_indices, dtype=np.int64)
|
585 |
rel_scores = bm25_scores[rel_idx]
|
586 |
top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
|
587 |
+
doc_to_bm25_score: Dict[int, float] = {
|
588 |
+
int(i): float(bm25_scores[i]) for i in top_rel
|
589 |
+
}
|
590 |
candidate_scores: Dict[int, Dict[str, float]] = {}
|
|
|
591 |
def to_similarity(d: float) -> float:
|
592 |
+
return float(d) if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT else 1.0 / (1.0 + float(d))
|
593 |
+
for i, dist in zip(sem_indices[0], distances[0]):
|
|
|
|
|
|
|
|
|
594 |
if i in relevant_indices:
|
595 |
+
candidate_scores[i] = {"sem": to_similarity(dist), "bm": 0.0}
|
|
|
|
|
596 |
for i, score in doc_to_bm25_score.items():
|
597 |
if i in relevant_indices:
|
598 |
candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
|
599 |
+
if not candidate_scores:
|
600 |
+
continue
|
|
|
601 |
keys = list(candidate_scores.keys())
|
602 |
+
sem_scores = np.array([candidate_scores[k]["sem"] for k in keys])
|
603 |
+
bm_scores = np.array([candidate_scores[k]["bm"] for k in keys])
|
604 |
+
|
605 |
+
def norm(x):
|
606 |
+
return (x - x.min()) / (x.max() - x.min() + 1e-8) if x.max() - x.min() > 0 else np.zeros_like(x)
|
|
|
607 |
|
608 |
sem_n, bm_n = norm(sem_scores), norm(bm_scores)
|
|
|
609 |
for idx, k in enumerate(keys):
|
610 |
fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
|
|
|
611 |
if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
|
612 |
all_fused_candidates[k] = FusedCandidate(
|
613 |
idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
|
614 |
)
|
|
|
615 |
return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
|
616 |
|
617 |
+
def _expand_query_with_llm(self, query: str, intents: List[str]) -> str:
|
618 |
if not intents:
|
619 |
return query
|
620 |
+
prompt = PROMPT_TEMPLATES["expand_query"].format(intents=intents, query=query)
|
621 |
+
expanded = self._llm_call(
|
622 |
+
[{"role": "user", "content": prompt}]
|
623 |
+
)
|
624 |
+
return expanded.strip() if expanded else query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
625 |
|
626 |
def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]:
|
627 |
+
if not intents:
|
628 |
return results
|
629 |
+
prioritized_sections = set()
|
630 |
+
for intent in intents:
|
631 |
+
prioritized_sections.update(INTENT_TO_SECTION.get(intent, []))
|
632 |
+
if not prioritized_sections:
|
633 |
+
return results
|
634 |
+
prioritized, other = [], []
|
635 |
+
for res in results:
|
636 |
+
if res.meta.get("section") in prioritized_sections:
|
637 |
+
prioritized.append(res)
|
638 |
+
else:
|
639 |
+
other.append(res)
|
640 |
+
return prioritized + other
|
641 |
|
642 |
def _build_context(self, reranked_results: List[RerankResult]) -> str:
|
643 |
context = ""
|
644 |
for res in reranked_results:
|
645 |
+
if len(context) + len(res.text) > LLM_MODEL_CONFIG["max_context_chars"]:
|
646 |
+
break
|
647 |
context += res.text + "\n\n"
|
648 |
return context.strip()
|
649 |
|
650 |
def _make_final_prompt(self, query: str, context: str, intents: List[str]) -> str:
|
651 |
add_instr = ""
|
652 |
+
if any(
|
653 |
+
i
|
654 |
+
in intents
|
655 |
+
for i in
|
656 |
+
("劑量調整 (Dosage Adjustment)", "時間/併用 (Timing & Interaction)")
|
657 |
+
):
|
658 |
+
add_instr = (
|
659 |
+
"在回答用藥劑量和時間時,務必提醒使用者,醫師開立的藥袋醫囑優先於仿單的一般建議。"
|
660 |
+
)
|
661 |
if "保存/攜帶 (Storage & Handling)" in intents:
|
662 |
+
add_instr += (
|
663 |
+
" 在回答保存與攜帶問題時,除了仿單內容,請根據常識加入實際情境的提醒,例如提醒需冷藏藥品要用保冷袋攜帶。"
|
664 |
+
)
|
665 |
+
add_instr += "\n請根據以下問題與參考資料對應回答:"
|
666 |
+
for q, refs in REFERENCE_MAPPING.items():
|
667 |
+
add_instr += f"\n- {q}: {refs}"
|
668 |
return PROMPT_TEMPLATES["final_answer"].format(
|
669 |
additional_instruction=add_instr, context=context, query=query
|
670 |
)
|
671 |
+
|
672 |
def _safe_json_parse(self, s: str, default: Any = None) -> Any:
|
673 |
try:
|
674 |
return json.loads(s)
|
675 |
except json.JSONDecodeError:
|
676 |
+
try:
|
677 |
+
m = re.search(r"\{.*?\}", s, re.DOTALL)
|
678 |
+
if m:
|
|
|
679 |
return json.loads(m.group(0))
|
680 |
+
except json.JSONDecodeError:
|
681 |
+
pass
|
682 |
return default
|
683 |
|
684 |
+
|
685 |
# ---------- FastAPI 事件與路由 ----------
|
686 |
class AppConfig:
|
687 |
CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
|
688 |
CHANNEL_SECRET = _require_env("CHANNEL_SECRET")
|
689 |
|
690 |
+
|
691 |
rag_pipeline: Optional[RagPipeline] = None
|
692 |
|
693 |
+
|
694 |
@asynccontextmanager
|
695 |
async def lifespan(app: FastAPI):
|
696 |
_require_llm_config()
|
|
|
701 |
yield
|
702 |
log.info("服務關閉中。")
|
703 |
|
704 |
+
|
705 |
app = FastAPI(lifespan=lifespan)
|
706 |
|
707 |
+
|
708 |
@app.post("/webhook")
|
709 |
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|
710 |
signature = request.headers.get("X-Line-Signature")
|
711 |
if not signature:
|
712 |
+
raise HTTPException(status_code=400, detail="Missing LINE X-Line-Signature header")
|
|
|
|
|
|
|
713 |
|
714 |
body = await request.body()
|
715 |
try:
|
716 |
+
hash_obj = hmac.new(AppConfig.CHANNEL_SECRET.encode("utf-8"), body, hashlib.sha256)
|
717 |
+
expected_signature = base64.b64encode(hash_obj.digest()).decode("utf-8")
|
718 |
except Exception as e:
|
719 |
log.error(f"Failed to generate signature: {e}")
|
720 |
raise HTTPException(status_code=500, detail="Signature generation error")
|
|
|
723 |
raise HTTPException(status_code=403, detail="Invalid signature")
|
724 |
|
725 |
try:
|
726 |
+
data = json.loads(body.decode("utf-8"))
|
727 |
except json.JSONDecodeError:
|
728 |
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
729 |
|
730 |
for event in data.get("events", []):
|
731 |
+
if (
|
732 |
+
event.get("type") == "message"
|
733 |
+
and event.get("message", {}).get("type") == "text"
|
734 |
+
):
|
735 |
user_text = event.get("message", {}).get("text", "").strip()
|
736 |
source = event.get("source", {})
|
737 |
stype = source.get("type")
|
738 |
+
target_id = (
|
739 |
+
source.get("userId") or source.get("groupId") or source.get("roomId")
|
740 |
+
)
|
741 |
+
if user_text and target_id:
|
742 |
+
background_tasks.add_task(
|
743 |
+
process_user_query, stype, target_id, user_text
|
744 |
+
)
|
745 |
return Response(status_code=status.HTTP_200_OK)
|
746 |
|
747 |
+
|
748 |
+
async def process_user_query(source_type: str, target_id: str, user_text: str):
|
749 |
try:
|
750 |
+
if not rag_pipeline:
|
751 |
+
await line_push_generic(source_type, target_id,
|
752 |
+
"系統正在啟動中,請稍後再試。")
|
753 |
+
return
|
754 |
+
answer, drug_ids = await rag_pipeline.answer_question(target_id, user_text)
|
755 |
+
await line_push_generic(source_type, target_id, answer)
|
756 |
except Exception as e:
|
757 |
log.error(f"背景處理 target_id={target_id} 發生錯誤: {e}", exc_info=True)
|
758 |
+
await line_push_generic(
|
759 |
+
source_type,
|
760 |
+
target_id,
|
761 |
+
f"抱歉,處理時發生未預期的錯誤。\n{DISCLAIMER}",
|
762 |
+
)
|
763 |
+
|
764 |
|
765 |
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
766 |
+
async def line_api_call(endpoint: str, data: Dict):
|
767 |
headers = {
|
768 |
"Content-Type": "application/json",
|
769 |
+
"Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}",
|
770 |
}
|
771 |
+
async with aiohttp.ClientSession() as session:
|
772 |
+
async with session.post(
|
773 |
+
f"https://api.line.me/v2/bot/message/{endpoint}",
|
774 |
+
headers=headers,
|
775 |
+
json=data,
|
776 |
+
timeout=10,
|
777 |
+
) as response:
|
778 |
+
response.raise_for_status()
|
779 |
+
|
780 |
+
|
781 |
+
async def line_reply(reply_token: str, text: str):
|
782 |
+
messages = [
|
783 |
+
{"type": "text", "text": chunk}
|
784 |
+
for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]
|
785 |
+
]
|
786 |
+
await line_api_call("reply", {"replyToken": reply_token, "messages": messages})
|
787 |
+
|
788 |
+
|
789 |
+
async def line_push_generic(source_type: str, target_id: str, text: str):
|
790 |
+
messages = [
|
791 |
+
{"type": "text", "text": chunk}
|
792 |
+
for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]
|
793 |
+
]
|
794 |
+
if "目前僅支援以下藥物詢問" in text:
|
795 |
+
drug_list = "\n".join(f"- {drug}" for drug in SUPPORTED_DRUGS)
|
796 |
+
messages.append({"type": "text", "text": f"支援的藥物清單:\n{drug_list}"})
|
797 |
data = {"to": target_id, "messages": messages}
|
798 |
+
await line_api_call("push", data)
|
799 |
+
|
800 |
|
801 |
+
def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> List[str]:
|
802 |
candidates = set()
|
803 |
q_norm = _norm(query)
|
|
|
804 |
for word in re.findall(r"[a-z0-9]+", q_norm):
|
805 |
if word in drug_vocab["en"]:
|
806 |
candidates.add(word)
|
|
|
807 |
for token in jieba.cut(q_norm):
|
808 |
if token in drug_vocab["zh"]:
|
809 |
candidates.add(token)
|
810 |
+
supported_drugs = set(DRUG_NAME_MAPPING.keys()).union(DRUG_NAME_MAPPING.values())
|
811 |
+
if not candidates.issubset(supported_drugs):
|
812 |
+
candidates = set()
|
813 |
return list(candidates)
|
814 |
|
815 |
+
|
816 |
# ---------- 執行 ----------
|
817 |
if __name__ == "__main__":
|
818 |
+
port = int(os.getenv("PORT", "7860"))
|
819 |
uvicorn.run(app, host="0.0.0.0", port=port)
|