Spaces:
Sleeping
Sleeping
Song
commited on
Commit
·
9e85da1
1
Parent(s):
b20c534
hi
Browse files
app.py
CHANGED
@@ -27,6 +27,7 @@ for d in (os.getenv("HF_HOME"), os.getenv("SENTENCE_TRANSFORMERS_HOME"), os.gete
|
|
27 |
# ---------- Imports ----------
|
28 |
import re, hmac, base64, hashlib, pickle, logging, time, json
|
29 |
from typing import List, Dict, Any, Optional, Tuple, Union
|
|
|
30 |
|
31 |
import numpy as np
|
32 |
import pandas as pd
|
@@ -38,659 +39,323 @@ except Exception:
|
|
38 |
torch = None
|
39 |
|
40 |
try:
|
41 |
-
import faiss
|
42 |
-
except Exception
|
43 |
-
|
44 |
|
45 |
try:
|
46 |
-
from sentence_transformers import SentenceTransformer
|
47 |
except Exception:
|
48 |
SentenceTransformer = None
|
49 |
|
50 |
try:
|
51 |
-
from rank_bm25 import BM25Okapi
|
52 |
except Exception:
|
53 |
BM25Okapi = None
|
54 |
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
BM25_WEIGHT = 0.8
|
87 |
-
SEM_WEIGHT = 0.2
|
88 |
-
EMBEDDING_MODEL_ID= "DMetaSoul/Dmeta-embedding-zh"
|
89 |
-
RERANKER_MODEL_ID = "BAAI/bge-reranker-base"
|
90 |
-
USE_CPU = True # HF 預設 CPU
|
91 |
-
RERANK_THRESHOLD = 0.5
|
92 |
-
MAX_CONTEXT_CHARS = 8000
|
93 |
-
DISCLAIMER = "此回覆僅供參考,請遵循醫師/藥師指示。"
|
94 |
-
|
95 |
-
# 藥名映射與停用詞
|
96 |
-
DRUG_NAME_MAPPING = {
|
97 |
-
"fentanyl patch": "fentanyl",
|
98 |
-
"spiriva respimat": "spiriva",
|
99 |
-
"augmentin for syrup": "augmentin syrup",
|
100 |
-
"nitrostat": "nitroglycerin",
|
101 |
-
"ozempic": "ozempic",
|
102 |
-
"niflec": "niflec",
|
103 |
-
"fosamax": "alendronate",
|
104 |
-
"humira": "adalimumab",
|
105 |
-
"premarin": "premarin",
|
106 |
-
"smecta": "smecta",
|
107 |
-
"duragesic": "fentanyl",
|
108 |
-
"芬太尼貼片": "fentanyl",
|
109 |
-
"透皮止痛貼片": "fentanyl",
|
110 |
-
}
|
111 |
|
112 |
-
DRUG_STOPWORDS = {"藥", "劑", "錠", "膠囊", "糖漿", "乳膏", "貼片", "含錠", "膜衣錠", "緩釋錠", "滴劑", "懸液", "注射液",
|
113 |
-
"吸入劑", "噴霧", "噴霧劑", "吸入器", "注射筆", "藥水", "小袋", "條", "包", "瓶", "外用", "口服"}
|
114 |
-
|
115 |
-
# 意圖分類(改用字典提升匹配率)
|
116 |
-
INTENT_KEYWORDS = {
|
117 |
-
"如何用藥 (Administration)": ["操作", "使用", "怎麼用", "怎麼吃", "怎麼貼", "怎麼喝", "怎麼注射", "服用", "組裝", "安裝", "用藥方式"],
|
118 |
-
"保存/攜帶 (Storage & Handling)": ["保存", "儲存", "攜帶", "冷藏", "室溫", "潮濕", "保冰袋", "旅遊"],
|
119 |
-
"副作用/異常 (Side Effects / Issues)": ["副作用", "異常", "不良反應", "頭暈", "拉肚子", "噁心", "想吐", "過敏", "問題"],
|
120 |
-
"劑量調整 (Dosage Adjustment)": ["劑量", "幾顆", "調整", "忘記吃", "上限", "幾次", "劑量多少"],
|
121 |
-
"用藥時間 (Timing)": ["時間", "多久", "間隔", "飯前", "飯後", "隨餐", "睡前", "什麼時候"],
|
122 |
-
"禁忌症/適應症 (Contraindications/Indications)": ["禁忌", "適應症", "不能用", "不適合", "誰不能吃", "適合"],
|
123 |
-
}
|
124 |
-
|
125 |
-
# 章節權重
|
126 |
SECTION_NORMALIZE = {
|
127 |
-
|
128 |
-
|
129 |
-
"警語注意事項": "警語及注意事項",
|
130 |
-
"交互作用": "藥物交互作用",
|
131 |
-
"包裝及儲存": "儲存條件",
|
132 |
-
"儲存條件": "儲存條件"
|
133 |
-
}
|
134 |
-
|
135 |
-
SECTION_WEIGHTS = {
|
136 |
-
"用法及用量": 1.0,
|
137 |
-
"病人使用須知": 1.0,
|
138 |
-
"儲存條件": 1.0,
|
139 |
-
"警語及注意事項": 1.0,
|
140 |
-
"禁忌": 1.0,
|
141 |
-
"副作用": 1.0,
|
142 |
-
"藥物交互作用": 1.0,
|
143 |
-
"其他": 1.0,
|
144 |
-
"包裝及儲存": 1.0,
|
145 |
-
"不良反應": 1.0,
|
146 |
}
|
147 |
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
pathlib.Path(preferred_path).parent.mkdir(parents=True, exist_ok=True)
|
164 |
-
with open(preferred_path, "wb") as f:
|
165 |
-
pickle.dump(obj, f)
|
166 |
-
return preferred_path
|
167 |
-
except OSError as e:
|
168 |
-
if e.errno == errno.EACCES:
|
169 |
-
alt = os.path.join("/tmp", os.path.basename(preferred_path))
|
170 |
-
try:
|
171 |
-
with open(alt, "wb") as f:
|
172 |
-
pickle.dump(obj, f)
|
173 |
-
log.warning("No write permission for %s, saved to %s instead.", preferred_path, alt)
|
174 |
-
return alt
|
175 |
-
except Exception as ee:
|
176 |
-
log.warning("Failed to save to /tmp as well: %s", ee)
|
177 |
-
else:
|
178 |
-
log.warning("pickle dump failed: %s", e)
|
179 |
-
except Exception as e:
|
180 |
-
log.warning("pickle dump failed: %s", e)
|
181 |
-
return ""
|
182 |
-
|
183 |
-
def safe_faiss_write(index, preferred_path: str) -> str:
|
184 |
-
try:
|
185 |
-
pathlib.Path(preferred_path).parent.mkdir(parents=True, exist_ok=True)
|
186 |
-
faiss.write_index(index, preferred_path)
|
187 |
-
return preferred_path
|
188 |
-
except OSError as e:
|
189 |
-
if e.errno == errno.EACCES:
|
190 |
-
alt = os.path.join("/tmp", os.path.basename(preferred_path))
|
191 |
-
try:
|
192 |
-
faiss.write_index(index, alt)
|
193 |
-
log.warning("No write permission for %s, saved FAISS to %s instead.", preferred_path, alt)
|
194 |
-
return alt
|
195 |
-
except Exception as ee:
|
196 |
-
log.warning("Failed to save FAISS to /tmp as well: %s", ee)
|
197 |
-
else:
|
198 |
-
log.warning("faiss write failed: %s", e)
|
199 |
-
except Exception as e:
|
200 |
-
log.warning("faiss write failed: %s", e)
|
201 |
-
return ""
|
202 |
-
|
203 |
-
# ---------- 檔案路徑(優先專案根目錄,其次 /app,最後 /tmp) ----------
|
204 |
-
CWD = os.getcwd()
|
205 |
-
SENTENCES_PKL = pick_existing_or_tmp([
|
206 |
-
os.path.join(CWD, "drug_sentences.pkl"),
|
207 |
-
"/app/drug_sentences.pkl",
|
208 |
-
"/tmp/drug_sentences.pkl",
|
209 |
-
])
|
210 |
-
FAISS_INDEX = pick_existing_or_tmp([
|
211 |
-
os.path.join(CWD, "drug_sentences.index"),
|
212 |
-
"/app/drug_sentences.index",
|
213 |
-
"/tmp/drug_sentences.index",
|
214 |
-
])
|
215 |
-
BM25_PKL = pick_existing_or_tmp([
|
216 |
-
os.path.join(CWD, "bm25.pkl"),
|
217 |
-
"/app/bm25.pkl",
|
218 |
-
"/tmp/bm25.pkl",
|
219 |
-
])
|
220 |
-
CSV_PATH = pick_existing_or_tmp([
|
221 |
-
os.path.join(CWD, "cleaned_combined.csv"),
|
222 |
-
"/app/cleaned_combined.csv",
|
223 |
-
"/tmp/cleaned_combined.csv",
|
224 |
-
])
|
225 |
-
|
226 |
-
# ---------- FastAPI ----------
|
227 |
-
app = FastAPI(title="DrugQA (ZH) — LINE Webhook Only")
|
228 |
-
|
229 |
-
# ---------- Helpers ----------
|
230 |
-
_ZH_SPLIT_RE = re.compile(r"[。!?\n]")
|
231 |
-
|
232 |
-
def split_sentences(text: str) -> List[str]:
|
233 |
-
if not isinstance(text, str): return []
|
234 |
-
sents = [s.strip() for s in _ZH_SPLIT_RE.split(text) if s.strip()]
|
235 |
-
return [s for s in sents if len(s) > 6]
|
236 |
-
|
237 |
-
def tokenize_zh(s: str) -> List[str]:
|
238 |
-
if not isinstance(s, str) or not s: return []
|
239 |
-
if jieba is None: return s.strip().split()
|
240 |
-
return [t for t in jieba.lcut(s) if t.strip() and t not in DRUG_STOPWORDS]
|
241 |
-
|
242 |
-
def detect_intent(query: str) -> List[str]:
|
243 |
-
"""Detects user intent based on keywords."""
|
244 |
-
detected = []
|
245 |
-
query_lower = query.lower().replace(" ", "")
|
246 |
-
for intent, keywords in INTENT_KEYWORDS.items():
|
247 |
-
if any(k in query_lower for k in keywords):
|
248 |
-
detected.append(intent)
|
249 |
-
return detected
|
250 |
-
|
251 |
-
class State:
|
252 |
-
sentences: List[str] = []
|
253 |
-
meta: List[Dict[str, Any]] = []
|
254 |
-
emb_model: Optional[Any] = None
|
255 |
-
reranker_model: Optional[Any] = None
|
256 |
-
faiss_index: Optional[Any] = None
|
257 |
-
bm25: Optional[Any] = None
|
258 |
-
df_csv: Optional[pd.DataFrame] = None
|
259 |
-
user_sessions: Dict[str, Dict[str, Any]] = {}
|
260 |
-
query_cache: Dict[str, Dict[str, Any]] = {}
|
261 |
-
|
262 |
-
STATE = State()
|
263 |
-
|
264 |
-
# ---------- 載入與建立 ----------
|
265 |
-
def ensure_sentences_meta() -> Tuple[List[str], List[Dict[str, Any]]]:
|
266 |
if os.path.exists(SENTENCES_PKL):
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
|
|
|
280 |
def load_embedding_model(model_id: str):
|
281 |
if SentenceTransformer is None:
|
282 |
-
log.
|
283 |
return None
|
284 |
-
device = "cpu" if (USE_CPU or (torch is None)) else ("cuda" if torch.cuda.is_available() else "cpu")
|
285 |
-
log.info("Load SentenceTransformer: %s on %s", model_id, device)
|
286 |
try:
|
287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
except Exception as e:
|
289 |
-
log.
|
290 |
return None
|
291 |
|
|
|
292 |
def load_reranker_model(model_id: str):
|
293 |
-
if
|
294 |
-
log.
|
295 |
return None
|
296 |
-
device = "cpu" if (USE_CPU or (torch is None)) else ("cuda" if torch.cuda.is_available() else "cpu")
|
297 |
-
log.info("Load CrossEncoder: %s on %s", model_id, device)
|
298 |
try:
|
299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
except Exception as e:
|
301 |
-
log.
|
302 |
return None
|
303 |
|
304 |
-
def ensure_faiss(index_path: str, sentences: List[str])
|
305 |
-
if
|
306 |
-
|
307 |
-
try:
|
308 |
-
idx = faiss.read_index(index_path)
|
309 |
-
if idx.ntotal == len(sentences):
|
310 |
-
log.info("Loaded FAISS: %s (d=%d n=%d)", index_path, idx.d, idx.ntotal)
|
311 |
-
return idx
|
312 |
-
else:
|
313 |
-
log.warning("FAISS ntotal mismatch (%d != %d). Rebuilding.", idx.ntotal, len(sentences))
|
314 |
-
except Exception as e:
|
315 |
-
log.warning("Failed to load FAISS (%s): %s", index_path, e)
|
316 |
-
if STATE.emb_model is None:
|
317 |
-
log.warning("No emb_model; skip FAISS build.")
|
318 |
return None
|
319 |
-
log.info("Building FAISS (n=%d)...", len(sentences))
|
320 |
-
embeds = STATE.emb_model.encode(sentences, normalize_embeddings=True, show_progress_bar=True)
|
321 |
-
dim = embeds.shape[1]
|
322 |
-
idx = faiss.IndexFlatIP(dim)
|
323 |
-
idx.add(embeds.astype(np.float32))
|
324 |
-
safe_faiss_write(idx, index_path)
|
325 |
-
return idx
|
326 |
-
|
327 |
-
def ensure_bm25(pkl_path: str, sentences: List[str]) -> Optional[Any]:
|
328 |
-
if BM25Okapi is None: return None
|
329 |
-
if os.path.exists(pkl_path):
|
330 |
-
try:
|
331 |
-
with open(pkl_path, "rb") as f:
|
332 |
-
bm = pickle.load(f)
|
333 |
-
# BM25 has corpus, not corpus_size attribute
|
334 |
-
n_bm = len(bm.corpus) if hasattr(bm, 'corpus') else 0
|
335 |
-
if n_bm == len(sentences):
|
336 |
-
log.info("Loaded BM25: %s (n=%d)", pkl_path, n_bm)
|
337 |
-
return bm
|
338 |
-
else:
|
339 |
-
log.warning("BM25 corpus size mismatch (%d != %d). Rebuilding.", n_bm, len(sentences))
|
340 |
-
except Exception as e:
|
341 |
-
log.warning("Failed to load BM25 (%s): %s", pkl_path, e)
|
342 |
-
log.info("Building BM25 (n=%d)...", len(sentences))
|
343 |
-
tokenized_corpus = [tokenize_zh(s) for s in sentences]
|
344 |
-
bm = BM25Okapi(tokenized_corpus)
|
345 |
-
safe_pickle_dump(bm, pkl_path)
|
346 |
-
return bm
|
347 |
-
|
348 |
-
# ---------- 資訊解析與藥名處理 (簡化) ----------
|
349 |
-
|
350 |
-
# 1. parse_user_message: 簡化為只比對藥名
|
351 |
-
def parse_user_message(query: str, df: pd.DataFrame) -> Dict[str, Any]:
|
352 |
-
"""
|
353 |
-
MODIFIED: 只比對 drug_name_norm,找最佳藥品。
|
354 |
-
"""
|
355 |
-
best_drug = None
|
356 |
-
best_row = None
|
357 |
-
max_score = 0
|
358 |
-
|
359 |
-
if not fuzz:
|
360 |
-
log.warning("fuzzywuzzy not available; skipping fuzzy match.")
|
361 |
-
return {
|
362 |
-
"drug_name": None,
|
363 |
-
"drug_id": None,
|
364 |
-
"question": query,
|
365 |
-
}
|
366 |
-
|
367 |
-
# Use a pre-tokenized and normalized list for faster fuzzy matching
|
368 |
-
# In a real app, this should be pre-computed and stored for efficiency
|
369 |
-
unique_drugs = df.drop_duplicates(subset=['drug_id'])
|
370 |
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
best_row = row
|
387 |
-
|
388 |
-
if best_drug is None or max_score < 80: # 設定一個閾值來避免不相關的匹配
|
389 |
-
log.warning(f"No confident drug match found (score: {max_score})")
|
390 |
-
return {
|
391 |
-
"drug_name": None,
|
392 |
-
"drug_id": None,
|
393 |
-
"question": query,
|
394 |
-
}
|
395 |
-
|
396 |
-
log.info(f"Parsed user message (best match): {best_drug}, score: {max_score}")
|
397 |
-
|
398 |
-
return {
|
399 |
-
"drug_name": best_drug,
|
400 |
-
"drug_id": best_row["drug_id"],
|
401 |
-
"question": query
|
402 |
-
}
|
403 |
-
|
404 |
-
# 2. find_drug_candidates: 簡化為單純 fuzzy 比對
|
405 |
-
def find_drug_candidates(parsed_info: Dict[str, Any], df: pd.DataFrame, top_k: int = 5) -> List[Dict[str, Any]]:
|
406 |
-
"""
|
407 |
-
MODIFIED: 單純對 drug_name_norm 做 fuzzy 比對,並回傳前 top_k 候選。
|
408 |
-
"""
|
409 |
-
query_text = parsed_info.get("question", "").lower()
|
410 |
-
if df is None or df.empty or not query_text:
|
411 |
-
return []
|
412 |
|
413 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
return []
|
415 |
|
416 |
-
|
417 |
-
|
|
|
|
|
418 |
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
candidates_list.append({
|
424 |
-
"drug_id": row["drug_id"],
|
425 |
-
"drug_name": drug_norm,
|
426 |
-
"score": score
|
427 |
-
})
|
428 |
-
|
429 |
-
# 依 score 排序並回傳前 top_k
|
430 |
-
sorted_candidates = sorted(candidates_list, key=lambda x: x['score'], reverse=True)
|
431 |
-
|
432 |
-
log.info(f"Found drug candidates: {sorted_candidates[:top_k]}")
|
433 |
-
return sorted_candidates[:top_k]
|
434 |
-
|
435 |
-
# 3. answer_pipeline: 簡化流程
|
436 |
-
async def answer_pipeline(query: str, user_id: str) -> str:
|
437 |
-
log.info("Pipeline start for user_id: %s, query: %s", user_id, query[:50])
|
438 |
-
if not query or not isinstance(query, str):
|
439 |
-
return handle_error("INVALID_QUERY")
|
440 |
-
if not STATE.sentences or not STATE.df_csv:
|
441 |
-
return handle_error("NO_CORPUS")
|
442 |
-
|
443 |
-
# 1. 解析使用者輸入並找到最佳藥品
|
444 |
-
best_drug_info = parse_user_message(query, STATE.df_csv)
|
445 |
|
446 |
-
|
447 |
-
|
448 |
-
return make_clarify_message()
|
449 |
|
450 |
-
#
|
451 |
-
|
452 |
|
453 |
-
#
|
454 |
-
|
455 |
-
second_score = drug_candidates[1]['score'] if len(drug_candidates) > 1 else 0
|
456 |
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
# 4. 檢索相關內文 (fuse_and_select)
|
466 |
-
idxs = fuse_and_select(
|
467 |
-
query=best_drug_info["question"],
|
468 |
-
sentences=STATE.sentences,
|
469 |
-
meta=STATE.meta,
|
470 |
-
bm25=STATE.bm25,
|
471 |
-
index=STATE.faiss_index,
|
472 |
-
emb_model=STATE.emb_model,
|
473 |
-
reranker=STATE.reranker_model,
|
474 |
-
top_k=TOP_K_SENTENCES,
|
475 |
-
drug_id=drug_choice['drug_id'],
|
476 |
-
# 移除 parsed_info
|
477 |
-
)
|
478 |
-
|
479 |
-
if not idxs:
|
480 |
-
return handle_error("NO_CONTEXT")
|
481 |
-
|
482 |
-
# 5. 建立上下文和 Prompt (build_prompt)
|
483 |
-
context = build_context(idxs, STATE.sentences, STATE.meta)
|
484 |
-
prompt = build_prompt(best_drug_info, context, drug_choice)
|
485 |
-
log.info("Generated Prompt:\n%s", prompt)
|
486 |
-
|
487 |
-
# 6. 呼叫 LLM 生成答案
|
488 |
-
answer = call_llm(prompt)
|
489 |
-
if not answer:
|
490 |
-
return handle_error("LLM_ERROR")
|
491 |
-
|
492 |
-
return f"{answer}\n\n{DISCLAIMER}"
|
493 |
-
|
494 |
-
# 4. build_prompt: 簡化提示詞
|
495 |
-
def build_prompt(parsed_info: Dict[str, Any], contexts: str, drug_choice: Dict[str, Any]) -> str:
|
496 |
-
"""
|
497 |
-
MODIFIED: 簡化為只包含藥品名稱、使用者問題、參考片段。
|
498 |
-
"""
|
499 |
-
return (
|
500 |
-
"你是一位專業、有同理心的藥師。請根據提供的「參考片段」,簡潔地回答使用者的「問題」。\n"
|
501 |
-
"---限制---\n"
|
502 |
-
"- 絕對忠於「參考片段」,不可捏造或過度推論。你的知識僅限於提供的片段。\n"
|
503 |
-
"- 回覆少於 120 字,並使用繁體中文條列式 2-4 點說明。\n"
|
504 |
-
"- 語氣親切、精簡、專業。\n"
|
505 |
-
"- 若片段中無足夠資訊回答,必須回覆:「根據提供的資料,我無法找到關於您問題的明確答案,建議您諮詢醫師或藥師。」\n"
|
506 |
-
"---輸入資訊---\n"
|
507 |
-
f"藥物名稱: {drug_choice.get('drug_name')}\n"
|
508 |
-
f"問題: {parsed_info.get('question')}\n\n"
|
509 |
-
f"參考片段:\n{contexts}\n"
|
510 |
-
"---你的回答---"
|
511 |
)
|
512 |
-
|
513 |
-
|
514 |
-
try:
|
515 |
-
from openai import OpenAI
|
516 |
-
except Exception as e:
|
517 |
-
log.warning("openai client 不可用:%s", e)
|
518 |
-
return None
|
519 |
-
if not (LITELLM_API_KEY and LM_MODEL and LITELLM_BASE_URL):
|
520 |
-
log.warning("LLM 未完整設定;略過生成。")
|
521 |
-
return None
|
522 |
-
client = OpenAI(base_url=LITELLM_BASE_URL, api_key=LITELLM_API_KEY)
|
523 |
try:
|
524 |
-
|
525 |
-
resp = client.chat.completions.create(
|
526 |
model=LM_MODEL,
|
527 |
-
messages=[
|
528 |
-
|
529 |
-
|
530 |
-
|
|
|
531 |
)
|
532 |
-
|
533 |
-
log.info("LLM ok (%.2fs)", used)
|
534 |
-
return (resp.choices[0].message.content or "").strip()
|
535 |
except Exception as e:
|
536 |
-
log.
|
537 |
-
return
|
538 |
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
tokenized_query = tokenize_zh(clean_query)
|
568 |
-
scores = {}
|
569 |
-
|
570 |
-
# BM25 lexical search
|
571 |
-
if bm25:
|
572 |
-
bm_scores = bm25.get_scores(tokenized_query)
|
573 |
-
bm_scores_np = np.array(bm_scores)
|
574 |
-
if np.max(bm_scores_np) > np.min(bm_scores_np):
|
575 |
-
scores_norm = (bm_scores_np - np.min(bm_scores_np)) / (np.max(bm_scores_np) - np.min(bm_scores_np))
|
576 |
-
else:
|
577 |
-
scores_norm = bm_scores_np
|
578 |
-
|
579 |
-
for i, s_norm in enumerate(scores_norm):
|
580 |
-
if 0 <= i < len(meta) and (not drug_id or meta[i].get("drug_id") == drug_id):
|
581 |
-
scores[i] = scores.get(i, 0.0) + BM25_WEIGHT * s_norm
|
582 |
-
|
583 |
-
# FAISS semantic search
|
584 |
-
if emb_model and index:
|
585 |
-
q_emb = emb_model.encode([clean_query], normalize_embeddings=True).astype(np.float32)
|
586 |
-
_, idxs = index.search(q_emb, top_k * 8)
|
587 |
-
for rank, i in enumerate(idxs[0].tolist()):
|
588 |
-
if 0 <= i < len(meta) and (not drug_id or meta[i].get("drug_id") == drug_id):
|
589 |
-
scores[i] = scores.get(i, 0.0) + SEM_WEIGHT * (1.0 / (1 + rank))
|
590 |
-
|
591 |
-
# Apply boosts
|
592 |
-
for i in list(scores.keys()): # Iterate over a copy of keys
|
593 |
-
meta_item = meta[i]
|
594 |
-
|
595 |
-
# Section weight boost
|
596 |
-
sec = meta_item.get("section", "其他")
|
597 |
-
scores[i] *= SECTION_WEIGHTS.get(sec, 1.0)
|
598 |
-
|
599 |
-
# Boost based on detected intent
|
600 |
-
detected_intents = detect_intent(clean_query)
|
601 |
-
for i in list(scores.keys()):
|
602 |
-
meta_item = meta[i]
|
603 |
-
sec = meta_item.get("section", "其他")
|
604 |
-
|
605 |
-
if any(intent in detected_intents for intent in ["如何用藥 (Administration)", "用藥時間 (Timing)", "劑量調整 (Dosage Adjustment)"]) and sec in ["用法及用量", "病人使用須知"]:
|
606 |
-
scores[i] *= 1.5
|
607 |
-
elif any(intent in detected_intents for intent in ["保存/攜帶 (Storage & Handling)"]) and sec in ["儲存條件", "包裝及儲存"]:
|
608 |
-
scores[i] *= 1.5
|
609 |
-
elif any(intent in detected_intents for intent in ["副作用/異常 (Side Effects / Issues)"]) and sec in ["不良反應", "警語及注意事項"]:
|
610 |
-
scores[i] *= 1.5
|
611 |
-
|
612 |
-
# Inject important sections if they are missing
|
613 |
-
for sec in IMPORTANT_SECTIONS:
|
614 |
-
sec_idx = next((i for i, m in enumerate(meta) if (m.get("drug_id") == drug_id) and m.get("section") == sec), None)
|
615 |
-
if sec_idx is not None and sec_idx not in scores:
|
616 |
-
scores[sec_idx] = 1.0 # Give it a moderate score to ensure inclusion before reranking
|
617 |
-
|
618 |
-
# Prepare for reranking
|
619 |
-
candidates = [(i, sc, 0.0, 0.0) for i, sc in scores.items()]
|
620 |
-
|
621 |
-
reranked = rerank_results(clean_query, candidates, sentences, reranker, top_k, RERANK_THRESHOLD)
|
622 |
-
idxs = [r["idx"] for r in reranked]
|
623 |
-
|
624 |
-
STATE.query_cache[cache_key] = {'idxs': idxs, 'time': time.time()}
|
625 |
-
return idxs
|
626 |
-
|
627 |
-
def build_context(idxs: List[int], sentences: List[str], meta: List[Dict[str, Any]]) -> str:
|
628 |
-
ctx_lines, total_len, seen = [], 0, set()
|
629 |
-
for i in idxs:
|
630 |
-
if i < 0: continue
|
631 |
-
text = sentences[i]
|
632 |
-
if text in seen: continue
|
633 |
-
chunk_id = meta[i].get("chunk_id", "None")
|
634 |
-
section = meta[i].get("section", "未知章節")
|
635 |
-
line = f"[{section}]: {text}"
|
636 |
-
if total_len + len(line) > MAX_CONTEXT_CHARS: break
|
637 |
-
ctx_lines.append(line)
|
638 |
-
total_len += len(line) + 1
|
639 |
-
seen.add(text)
|
640 |
-
return "\n".join(ctx_lines) or "[未知章節]: 沒有找到相關資料,請諮詢醫師或藥師。"
|
641 |
-
|
642 |
-
# ---------- LINE 驗簽與回覆 ----------
|
643 |
-
def verify_line_signature(body_bytes: bytes, signature: str) -> bool:
|
644 |
-
if not CHANNEL_SECRET:
|
645 |
-
log.warning("CHANNEL_SECRET 未設定;跳過簽章驗證(僅供測試)。")
|
646 |
-
return True
|
647 |
-
try:
|
648 |
-
mac = hmac.new(CHANNEL_SECRET.encode("utf-8"), body_bytes, hashlib.sha256).digest()
|
649 |
-
expected = base64.b64encode(mac).decode("utf-8")
|
650 |
-
return hmac.compare_digest(expected, signature)
|
651 |
-
except Exception as e:
|
652 |
-
log.warning("簽章驗證錯誤:%s", e)
|
653 |
-
return False
|
654 |
-
|
655 |
-
def line_reply(reply_token: str, text: str) -> None:
|
656 |
-
if not CHANNEL_ACCESS_TOKEN or requests is None:
|
657 |
-
log.warning("缺少 CHANNEL_ACCESS_TOKEN 或 requests;略過回覆。")
|
658 |
-
return
|
659 |
-
url = "https://api.line.me/v2/bot/message/reply"
|
660 |
headers = {
|
661 |
-
"Content-Type": "application/json",
|
662 |
-
"Authorization": f"Bearer {
|
|
|
|
|
|
|
|
|
663 |
}
|
664 |
-
data = {"replyToken": reply_token, "messages": [{"type": "text", "text": text[:4900]}]}
|
665 |
try:
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
except Exception as e:
|
670 |
-
log.
|
|
|
|
|
|
|
|
|
|
|
671 |
|
672 |
-
# ---------- 只有這一條路由:POST /webhook ----------
|
673 |
@app.post("/webhook")
|
674 |
-
async def webhook
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
689 |
try:
|
690 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
691 |
except Exception as e:
|
692 |
log.warning("Pipeline 失敗:%s", e)
|
693 |
-
answer = "
|
694 |
if reply_token:
|
695 |
line_reply(reply_token, answer)
|
696 |
return {"ok": True}
|
@@ -704,25 +369,29 @@ async def _startup():
|
|
704 |
log.info("PyTorch version %s available.", torch.__version__)
|
705 |
except Exception:
|
706 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
707 |
# 載入語料與索引
|
708 |
STATE.sentences, STATE.meta = ensure_sentences_meta()
|
709 |
STATE.emb_model = load_embedding_model(EMBEDDING_MODEL_ID)
|
710 |
STATE.reranker_model = load_reranker_model(RERANKER_MODEL_ID)
|
711 |
STATE.faiss_index = ensure_faiss(FAISS_INDEX, STATE.sentences)
|
712 |
STATE.bm25 = ensure_bm25(BM25_PKL, STATE.sentences)
|
|
|
713 |
for m in STATE.meta:
|
714 |
sec = m.get("section", "其他")
|
715 |
m["section"] = SECTION_NORMALIZE.get(sec, sec)
|
716 |
-
|
717 |
-
|
718 |
-
log.info("
|
719 |
-
log.info("Startup complete.")
|
720 |
-
|
721 |
-
@app.get("/")
|
722 |
-
async def health():
|
723 |
-
return {"status": "healthy"}
|
724 |
|
|
|
725 |
if __name__ == "__main__":
|
726 |
import uvicorn
|
727 |
-
|
728 |
-
uvicorn.run("app:app", host="0.0.0.0", port=port, log_level=LOG_LEVEL.lower(), reload=False)
|
|
|
27 |
# ---------- Imports ----------
|
28 |
import re, hmac, base64, hashlib, pickle, logging, time, json
|
29 |
from typing import List, Dict, Any, Optional, Tuple, Union
|
30 |
+
from functools import lru_cache
|
31 |
|
32 |
import numpy as np
|
33 |
import pandas as pd
|
|
|
39 |
torch = None
|
40 |
|
41 |
try:
|
42 |
+
import faiss # 僅用於檢查 faiss 是否安裝
|
43 |
+
except Exception:
|
44 |
+
faiss = None
|
45 |
|
46 |
try:
|
47 |
+
from sentence_transformers import SentenceTransformer
|
48 |
except Exception:
|
49 |
SentenceTransformer = None
|
50 |
|
51 |
try:
|
52 |
+
from rank_bm25 import BM25Okapi
|
53 |
except Exception:
|
54 |
BM25Okapi = None
|
55 |
|
56 |
+
from litellm import completion
|
57 |
+
from fastapi import FastAPI, Request
|
58 |
+
from pydantic import BaseModel
|
59 |
+
from pydantic_settings import BaseSettings
|
60 |
+
|
61 |
+
# ---------- Logging Setup ----------
|
62 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
63 |
+
log = logging.getLogger(__name__)
|
64 |
+
|
65 |
+
# ---------- Constants & Paths ----------
|
66 |
+
CONFIG_PATH = os.path.join(os.path.dirname(__file__), ".env")
|
67 |
+
SENTENCES_PKL = os.path.join(os.path.dirname(__file__), "drug_sentences.pkl")
|
68 |
+
FAISS_INDEX = os.path.join(os.path.dirname(__file__), "drug_sentences.index")
|
69 |
+
BM25_PKL = os.path.join(os.path.dirname(__file__), "bm25.pkl")
|
70 |
+
CSV_PATH = os.path.join(os.path.dirname(__file__), "cleaned_combined.csv")
|
71 |
+
|
72 |
+
# fallback to /tmp
|
73 |
+
for p in (SENTENCES_PKL, FAISS_INDEX, BM25_PKL):
|
74 |
+
if not os.path.exists(p):
|
75 |
+
log.warning("File not found: %s, fallback to /tmp", p)
|
76 |
+
# 允許重建,但只寫到 /tmp
|
77 |
+
SENTENCES_PKL = os.path.join("/tmp", os.path.basename(SENTENCES_PKL))
|
78 |
+
FAISS_INDEX = os.path.join("/tmp", os.path.basename(FAISS_INDEX))
|
79 |
+
BM25_PKL = os.path.join("/tmp", os.path.basename(BM25_PKL))
|
80 |
+
break
|
81 |
+
|
82 |
+
# https://docs.huggingface.co/huggingface_hub/package_reference/environment_variables#general-purpose
|
83 |
+
# https://docs.sentence-transformers.com/en/latest/quickstart.html
|
84 |
+
EMBEDDING_MODEL_ID = os.getenv("EMBEDDING_MODEL_ID", "AI-infinity/bge-m3-zh-tw")
|
85 |
+
RERANKER_MODEL_ID = os.getenv("RERANKER_MODEL_ID", "AI-infinity/bge-reranker-base-zh-tw")
|
86 |
+
LM_MODEL = os.getenv("LM_MODEL", "azure/gpt-4o")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
SECTION_NORMALIZE = {
|
89 |
+
'適應症': '適應症', '用法用量': '用法用量', '副作用': '副作用',
|
90 |
+
'禁忌': '禁忌', '警語注意事項': '警語及注意事項', '藥品外觀': '藥品外觀'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
}
|
92 |
|
93 |
+
# ---------- State Management ----------
|
94 |
+
class AppState:
|
95 |
+
def __init__(self):
|
96 |
+
self.sentences: List[str] = []
|
97 |
+
self.meta: List[Dict[str, Any]] = []
|
98 |
+
self.emb_model: Optional[SentenceTransformer] = None
|
99 |
+
self.reranker_model: Optional[SentenceTransformer] = None
|
100 |
+
self.faiss_index: Optional[faiss.Index] = None
|
101 |
+
self.bm25: Optional[BM25Okapi] = None
|
102 |
+
self.df_csv: Optional[pd.DataFrame] = None
|
103 |
+
|
104 |
+
STATE = AppState()
|
105 |
+
|
106 |
+
# ---------- Helper Functions (RAG pipeline) ----------
|
107 |
+
def ensure_sentences_meta():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
if os.path.exists(SENTENCES_PKL):
|
109 |
+
log.info("Loading sentences from %s", SENTENCES_PKL)
|
110 |
+
with open(SENTENCES_PKL, "rb") as f:
|
111 |
+
sentences, meta = pickle.load(f)
|
112 |
+
return sentences, meta
|
113 |
+
|
114 |
+
if not os.path.exists(CSV_PATH):
|
115 |
+
log.error("CSV file not found: %s", CSV_PATH)
|
116 |
+
return [], []
|
117 |
+
|
118 |
+
df = pd.read_csv(CSV_PATH)
|
119 |
+
sentences = (df["section"] + ":" + df["content"]).tolist()
|
120 |
+
meta = df.to_dict(orient="records")
|
121 |
+
|
122 |
+
log.info("Building sentences from CSV (total=%d)", len(sentences))
|
123 |
+
with open(SENTENCES_PKL, "wb") as f:
|
124 |
+
pickle.dump((sentences, meta), f)
|
125 |
+
|
126 |
+
return sentences, meta
|
127 |
|
128 |
+
@lru_cache(maxsize=1)
|
129 |
def load_embedding_model(model_id: str):
|
130 |
if SentenceTransformer is None:
|
131 |
+
log.error("SentenceTransformer not installed. Please install it.")
|
132 |
return None
|
|
|
|
|
133 |
try:
|
134 |
+
if torch is not None and torch.cuda.is_available():
|
135 |
+
log.info("Using GPU for embedding model: %s", model_id)
|
136 |
+
device = "cuda"
|
137 |
+
else:
|
138 |
+
log.info("Using CPU for embedding model: %s", model_id)
|
139 |
+
device = "cpu"
|
140 |
+
model = SentenceTransformer(model_id, device=device)
|
141 |
+
return model
|
142 |
except Exception as e:
|
143 |
+
log.error("Failed to load embedding model: %s, error: %s", model_id, e)
|
144 |
return None
|
145 |
|
146 |
+
@lru_cache(maxsize=1)
|
147 |
def load_reranker_model(model_id: str):
|
148 |
+
if SentenceTransformer is None:
|
149 |
+
log.error("SentenceTransformer not installed. Please install it.")
|
150 |
return None
|
|
|
|
|
151 |
try:
|
152 |
+
if torch is not None and torch.cuda.is_available():
|
153 |
+
log.info("Using GPU for reranker model: %s", model_id)
|
154 |
+
device = "cuda"
|
155 |
+
else:
|
156 |
+
log.info("Using CPU for reranker model: %s", model_id)
|
157 |
+
device = "cpu"
|
158 |
+
model = SentenceTransformer(model_id, device=device)
|
159 |
+
return model
|
160 |
except Exception as e:
|
161 |
+
log.error("Failed to load reranker model: %s, error: %s", model_id, e)
|
162 |
return None
|
163 |
|
164 |
+
def ensure_faiss(index_path: str, sentences: List[str]):
|
165 |
+
if faiss is None:
|
166 |
+
log.error("FAISS not installed. Please install it.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
+
if os.path.exists(index_path):
|
170 |
+
log.info("Loading FAISS index from %s", index_path)
|
171 |
+
return faiss.read_index(index_path)
|
172 |
+
|
173 |
+
log.info("Building FAISS index (total=%d)", len(sentences))
|
174 |
+
embeddings = STATE.emb_model.encode(sentences, show_progress_bar=True)
|
175 |
+
index = faiss.IndexFlatL2(embeddings.shape[1])
|
176 |
+
index.add(embeddings)
|
177 |
+
faiss.write_index(index, index_path)
|
178 |
+
return index
|
179 |
+
|
180 |
+
def ensure_bm25(bm25_path: str, sentences: List[str]):
|
181 |
+
if BM25Okapi is None:
|
182 |
+
log.error("BM25Okapi not installed. Please install it.")
|
183 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
+
if os.path.exists(bm25_path):
|
186 |
+
log.info("Loading BM25 index from %s", bm25_path)
|
187 |
+
with open(bm25_path, 'rb') as f:
|
188 |
+
bm25 = pickle.load(f)
|
189 |
+
return bm25
|
190 |
+
|
191 |
+
log.info("Building BM25 index (total=%d)", len(sentences))
|
192 |
+
tokenized_corpus = [s.split(" ") for s in sentences]
|
193 |
+
bm25 = BM25Okapi(tokenized_corpus)
|
194 |
+
with open(bm25_path, 'wb') as f:
|
195 |
+
pickle.dump(bm25, f)
|
196 |
+
return bm25
|
197 |
+
|
198 |
+
def retrieve_top_k_passages(query: str, k: int=10):
|
199 |
+
if STATE.emb_model is None or STATE.faiss_index is None or STATE.bm25 is None:
|
200 |
return []
|
201 |
|
202 |
+
# FAISS retrieval
|
203 |
+
query_embedding = STATE.emb_model.encode([query])
|
204 |
+
_, faiss_indices = STATE.faiss_index.search(query_embedding, k)
|
205 |
+
faiss_passages = [STATE.sentences[i] for i in faiss_indices[0]]
|
206 |
|
207 |
+
# BM25 retrieval
|
208 |
+
bm25_scores = STATE.bm25.get_scores(query.split(" "))
|
209 |
+
bm25_indices = np.argsort(bm25_scores)[::-1][:k]
|
210 |
+
bm25_passages = [STATE.sentences[i] for i in bm25_indices]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
+
# Combine and deduplicate
|
213 |
+
combined_passages = list(dict.fromkeys(faiss_passages + bm25_passages))
|
|
|
214 |
|
215 |
+
# Reranking
|
216 |
+
reranker_scores = STATE.reranker_model.rank(query, combined_passages)
|
217 |
|
218 |
+
# Sort and return top k
|
219 |
+
reranked_passages = [p for _, p in sorted(zip(reranker_scores, combined_passages), reverse=True)][:k]
|
|
|
220 |
|
221 |
+
return reranked_passages
|
222 |
+
|
223 |
+
def generate_answer(query: str, passages: List[str]):
|
224 |
+
context = "\n".join(passages)
|
225 |
+
system_prompt = (
|
226 |
+
"你是一位專業藥師,請根據以下提供的藥品資訊內容,使用繁體中文簡潔且準確地回答問題。 "
|
227 |
+
"你的回覆應清楚、易懂,且只引用提供的內容。若提供的內容無法回答問題,請直接說「抱歉,我無法從現有資料中找到相關資訊。」"
|
228 |
+
"回答時請依據不同段落整理成條列式或分段式,並在回覆最後加上免責聲明「本資訊僅供參考,實際用藥請諮詢專業醫師或藥師。」。"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
)
|
230 |
+
user_prompt = f"問題: {query}\n\n藥品資訊:\n{context}\n\n回覆:"
|
231 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
try:
|
233 |
+
response = completion(
|
|
|
234 |
model=LM_MODEL,
|
235 |
+
messages=[
|
236 |
+
{"role": "system", "content": system_prompt},
|
237 |
+
{"role": "user", "content": user_prompt}
|
238 |
+
],
|
239 |
+
temperature=0.0
|
240 |
)
|
241 |
+
return response.choices[0].message.content
|
|
|
|
|
242 |
except Exception as e:
|
243 |
+
log.error("LLM completion failed: %s", e)
|
244 |
+
return "抱歉,系統暫時無法回覆。請稍後再試。"
|
245 |
|
246 |
+
# ---------- FastAPI App & Webhook Logic ----------
|
247 |
+
class Settings(BaseSettings):
|
248 |
+
channel_access_token: str
|
249 |
+
channel_secret: str
|
250 |
+
|
251 |
+
class Config:
|
252 |
+
env_file = ".env"
|
253 |
+
env_file_encoding = "utf-8"
|
254 |
+
|
255 |
+
try:
|
256 |
+
settings = Settings()
|
257 |
+
except Exception as e:
|
258 |
+
log.error("Failed to load environment variables from .env: %s", e)
|
259 |
+
log.warning("Trying to load from system environment variables...")
|
260 |
+
settings = Settings(_env_file=None) # fall back to system env
|
261 |
+
|
262 |
+
app = FastAPI()
|
263 |
+
|
264 |
+
class LineEvent(BaseModel):
|
265 |
+
replyToken: str
|
266 |
+
message: Dict[str, str]
|
267 |
+
|
268 |
+
class LineWebhook(BaseModel):
|
269 |
+
events: List[LineEvent]
|
270 |
+
destination: str
|
271 |
+
|
272 |
+
def line_reply(reply_token: str, message: str):
|
273 |
+
import requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
headers = {
|
275 |
+
"Content-Type": "application/json; charset=UTF-8",
|
276 |
+
"Authorization": f"Bearer {settings.channel_access_token}"
|
277 |
+
}
|
278 |
+
data = {
|
279 |
+
"replyToken": reply_token,
|
280 |
+
"messages": [{"type": "text", "text": message}]
|
281 |
}
|
|
|
282 |
try:
|
283 |
+
response = requests.post("https://api.line.me/v2/bot/message/reply", headers=headers, json=data, timeout=5)
|
284 |
+
response.raise_for_status()
|
285 |
+
log.info("LINE reply success: %s", response.status_code)
|
286 |
except Exception as e:
|
287 |
+
log.error("LINE reply failed: %s", e)
|
288 |
+
|
289 |
+
# 為了快速比對藥名,預先正規化
|
290 |
+
# REVISED: 新增程式碼區塊
|
291 |
+
if STATE.df_csv is not None:
|
292 |
+
STATE.df_csv['drug_name_norm_normalized'] = STATE.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
|
293 |
|
|
|
294 |
@app.post("/webhook")
|
295 |
+
async def handle_webhook(webhook: LineWebhook, request: Request):
|
296 |
+
# 驗證簽章
|
297 |
+
signature = request.headers.get("x-line-signature")
|
298 |
+
if signature is None:
|
299 |
+
log.warning("Signature is missing.")
|
300 |
+
return {"ok": False}
|
301 |
+
|
302 |
+
body = (await request.body()).decode("utf-8")
|
303 |
+
hash_body = hmac.new(settings.channel_secret.encode('utf-8'), body.encode('utf-8'), hashlib.sha256).digest()
|
304 |
+
if not hmac.compare_digest(hash_body, base64.b64decode(signature)):
|
305 |
+
log.warning("Invalid signature.")
|
306 |
+
return {"ok": False}
|
307 |
+
|
308 |
+
for event in webhook.events:
|
309 |
+
reply_token = event.replyToken
|
310 |
+
msg_type = event.message.get("type")
|
311 |
+
if msg_type == "text":
|
312 |
+
user_text = event.message.get("text", "")
|
313 |
+
log.info("Received message: %s", user_text)
|
314 |
+
|
315 |
+
answer = "抱歉,系統暫時無法回覆。"
|
316 |
try:
|
317 |
+
# REVISED: 優先處理單一藥名精確比對
|
318 |
+
if STATE.df_csv is not None:
|
319 |
+
normalized_user_text = user_text.lower().replace(r'[^\w\s]', '').strip()
|
320 |
+
|
321 |
+
# 檢查是否為單一藥名查詢
|
322 |
+
matches = STATE.df_csv[STATE.df_csv['drug_name_norm_normalized'] == normalized_user_text]
|
323 |
+
|
324 |
+
if not matches.empty:
|
325 |
+
unique_drug_ids = matches['drug_id'].nunique()
|
326 |
+
if unique_drug_ids == 1:
|
327 |
+
log.info("Exact match found for single drug name: %s", user_text)
|
328 |
+
# 獲取該藥品的所有資訊
|
329 |
+
target_drug_id = matches['drug_id'].iloc[0]
|
330 |
+
relevant_rows = STATE.df_csv[STATE.df_csv['drug_id'] == target_drug_id]
|
331 |
+
|
332 |
+
# 組合所有相關段落內容
|
333 |
+
combined_content = []
|
334 |
+
for section in relevant_rows['section'].unique():
|
335 |
+
section_content = "\n".join(relevant_rows[relevant_rows['section'] == section]['content'].tolist())
|
336 |
+
combined_content.append(f"{section}:{section_content}")
|
337 |
+
|
338 |
+
# 傳給 LLM 進行摘要與回覆生成
|
339 |
+
passages_for_llm = combined_content
|
340 |
+
answer = generate_answer(query=user_text, passages=passages_for_llm)
|
341 |
+
|
342 |
+
# 不符合單一藥名比對條件,或有多個 drug_id,則走原本的 RAG 流程
|
343 |
+
else:
|
344 |
+
log.info("Multiple drugs or no exact match, falling back to RAG pipeline.")
|
345 |
+
passages = retrieve_top_k_passages(user_text)
|
346 |
+
answer = generate_answer(query=user_text, passages=passages)
|
347 |
+
else:
|
348 |
+
# 找不到精確匹配,走原本的 RAG 流程
|
349 |
+
passages = retrieve_top_k_passages(user_text)
|
350 |
+
answer = generate_answer(query=user_text, passages=passages)
|
351 |
+
else:
|
352 |
+
# 如果 CSV 沒有載入,也走 RAG 流程
|
353 |
+
passages = retrieve_top_k_passages(user_text)
|
354 |
+
answer = generate_answer(query=user_text, passages=passages)
|
355 |
+
|
356 |
except Exception as e:
|
357 |
log.warning("Pipeline 失敗:%s", e)
|
358 |
+
answer = "抱歉,系統暫時無法回覆。請稍後再試。"
|
359 |
if reply_token:
|
360 |
line_reply(reply_token, answer)
|
361 |
return {"ok": True}
|
|
|
369 |
log.info("PyTorch version %s available.", torch.__version__)
|
370 |
except Exception:
|
371 |
pass
|
372 |
+
|
373 |
+
# 載入 CSV
|
374 |
+
if os.path.exists(CSV_PATH):
|
375 |
+
STATE.df_csv = pd.read_csv(CSV_PATH, dtype=str)
|
376 |
+
log.info("Loaded CSV: %s (rows=%d)", CSV_PATH, len(STATE.df_csv))
|
377 |
+
else:
|
378 |
+
log.warning("CSV not found: %s", CSV_PATH)
|
379 |
+
|
380 |
# 載入語料與索引
|
381 |
STATE.sentences, STATE.meta = ensure_sentences_meta()
|
382 |
STATE.emb_model = load_embedding_model(EMBEDDING_MODEL_ID)
|
383 |
STATE.reranker_model = load_reranker_model(RERANKER_MODEL_ID)
|
384 |
STATE.faiss_index = ensure_faiss(FAISS_INDEX, STATE.sentences)
|
385 |
STATE.bm25 = ensure_bm25(BM25_PKL, STATE.sentences)
|
386 |
+
|
387 |
for m in STATE.meta:
|
388 |
sec = m.get("section", "其他")
|
389 |
m["section"] = SECTION_NORMALIZE.get(sec, sec)
|
390 |
+
|
391 |
+
log.info("LLM Model: %s", LM_MODEL)
|
392 |
+
log.info("===== Application Ready =====")
|
|
|
|
|
|
|
|
|
|
|
393 |
|
394 |
+
# This is a sample code to test the API locally.
|
395 |
if __name__ == "__main__":
|
396 |
import uvicorn
|
397 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|