Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
DrugQA (ZH) — 優化版 FastAPI LINE Webhook | |
整合 RAG 邏輯,包含 LLM 意圖偵測、子查詢分解、Intent-aware 檢索與 Rerank。 | |
""" | |
# ---------- Environment & cache setup (kept ahead of the third-party imports) ----------
import os
import pathlib

# Point the Hugging Face / sentence-transformers / XDG caches at writable
# /tmp locations unless the deployment already set them, and make sure each
# directory exists. This runs before the model libraries are imported below,
# presumably so they pick up these locations.
for _var, _default in (
    ("HF_HOME", "/tmp/hf"),
    ("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers"),
    ("XDG_CACHE_HOME", "/tmp/.cache"),
):
    os.environ.setdefault(_var, _default)
    pathlib.Path(os.environ[_var]).mkdir(parents=True, exist_ok=True)
# ---------- Python 標準函式庫 ---------- | |
import re | |
import hmac | |
import base64 | |
import hashlib | |
import pickle | |
import logging | |
import json | |
from typing import List, Dict, Any, Optional, Tuple, Union | |
from functools import lru_cache | |
import time | |
import textwrap | |
# ---------- 第三方函式庫 ---------- | |
import numpy as np | |
import pandas as pd | |
from fastapi import FastAPI, Request, Response, HTTPException, status | |
import uvicorn | |
import jieba | |
from rank_bm25 import BM25Okapi | |
from sentence_transformers import SentenceTransformer, CrossEncoder | |
import faiss | |
import torch | |
from openai import OpenAI | |
from tenacity import retry, stop_after_attempt, wait_fixed | |
import requests | |
from starlette.concurrency import run_in_threadpool | |
# ==== CONFIG (read from environment variables, falling back to defaults) ====
# Default data/index file paths are relative to the current working directory.
CSV_PATH, FAISS_INDEX, SENTENCES_PKL, BM25_PKL = (
    os.getenv(env_name, fallback)
    for env_name, fallback in (
        ("CSV_PATH", "cleaned_combined.csv"),
        ("FAISS_INDEX", "drug_sentences.index"),
        ("SENTENCES_PKL", "drug_sentences.pkl"),
        ("BM25_PKL", "bm25.pkl"),
    )
)

# Integer retrieval sizes (reranker output size, pre-rerank pool factor,
# reranker input cap).
TOP_K_SENTENCES, PRE_RERANK_K, MAX_RERANK_CANDIDATES = (
    int(os.getenv(env_name, fallback))
    for env_name, fallback in (
        ("TOP_K_SENTENCES", 30),
        ("PRE_RERANK_K", 30),
        ("MAX_RERANK_CANDIDATES", 50),
    )
)

# Hugging Face model identifiers.
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
RERANKER_MODEL = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")

_SENT_SPLIT_RE = re.compile(r"[。!?\n]")

# Dosage-form words ignored as drug-name candidates: three single characters
# plus four two-character words.
DRUG_STOPWORDS = set("藥劑錠") | {"膠囊", "糖漿", "乳膏", "貼片"}

# Base retrieval weight per package-insert section; everything is 1.0 except
# the two sections boosted below (insertion order is preserved by update()).
SECTION_WEIGHTS = dict.fromkeys(
    (
        "用法及用量",
        "病人使用須知",
        "儲存條件",
        "警語及注意事項",
        "禁忌",
        "副作用",
        "藥物交互作用",
        "其他",
    ),
    1.0,
)
SECTION_WEIGHTS.update({"警語及注意事項": 2.0, "禁忌": 1.5})

RERANK_THRESHOLD = float(os.getenv("RERANK_THRESHOLD", 0.5))

# Lower-case phrase found in a user query -> drug name to search for.
# The trailing names map to themselves.
DRUG_NAME_MAPPING = {
    "fentanyl patch": "fentanyl",
    "spiriva respimat": "spiriva",
    "augmentin for syrup": "augmentin syrup",
    "nitrostat": "nitroglycerin",
    **{
        name: name
        for name in ("ozempic", "niflec", "fosamax", "humira", "premarin", "smecta")
    },
}

# OpenAI-compatible (LiteLLM) endpoint settings.
LLM_API_CONFIG = {
    key: os.getenv(env_name)
    for key, env_name in (
        ("base_url", "LITELLM_BASE_URL"),
        ("api_key", "LITELLM_API_KEY"),
        ("model", "LM_MODEL"),
    )
}

# Generation settings shared by all LLM calls.
LLM_MODEL_CONFIG = dict(
    max_context_chars=int(os.getenv("MAX_CONTEXT_CHARS", 12000)),
    max_tokens=int(os.getenv("MAX_TOKENS", 2048)),
    temperature=float(os.getenv("TEMPERATURE", 0.0)),
    top_p=float(os.getenv("TOP_P", 0.95)),
    stop_tokens=["==="],
)

# --- Intent categories the LLM chooses from; each label is "中文 (English)".
INTENT_CATEGORIES = [
    f"{zh} ({en})"
    for zh, en in (
        ("操作", "Administration"),
        ("保存/攜帶", "Storage & Handling"),
        ("副作用/異常", "Side Effects / Issues"),
        ("劑型相關", "Dosage Form Concerns"),
        ("時間/併用", "Timing & Interaction"),
        ("劑量調整", "Dosage Adjustment"),
        ("禁忌症/適應症", "Contraindications/Indications"),
    )
]

# Appended to every user-facing reply.
DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
# ---------- Logging ----------
# Root logging at INFO with timestamped lines; the module logs through `log`.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
log = logging.getLogger(__name__)
# ---------- Core RAG logic (wrapped in a class) ----------
class RagPipeline:
    """Question answering over drug package-insert sentences (hybrid retrieval + LLM).

    Construction checks that the CSV exists and loads the embedding model and
    the cross-encoder reranker. ``_load_data()`` must then be called once to
    load the CSV, the FAISS index, the sentence/meta pickle and the BM25
    index; after that ``answer_question`` can be called per user message.
    """

    # Retry policy for LLM API calls: total attempts, and fixed wait in
    # seconds between attempts.
    LLM_MAX_ATTEMPTS = 3
    LLM_RETRY_WAIT_SECONDS = 2.0

    def __init__(self, config):
        """Create the pipeline.

        Args:
            config: Application config object, stored as ``self.config``.

        Raises:
            FileNotFoundError: ``CSV_PATH`` does not exist. This is checked
                before the (slow) model loading so a missing file fails fast.
        """
        self.config = config
        # Plain attribute bag for index/corpus objects; filled by _load_data().
        self.state = type('state', (), {})()
        self.df_csv = None
        self.csv_path = self._ensure_csv_path(CSV_PATH)
        self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
        self.embedding_model = self._load_embedding_model()
        self.reranker = self._load_reranker_model()

    # ----- model / data loading -----

    @staticmethod
    def _load_model_on_best_device(factory, model_name: str, label: str):
        """Build ``factory(model_name, device=...)``, preferring CUDA when available.

        If loading on CUDA fails, a warning is logged and the load is retried
        once on CPU. If the first attempt was already on CPU, the error is
        re-raised instead of repeating the identical call.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log.info(f"載入 {label} 模型:{model_name} 至 {device}...")
        try:
            return factory(model_name, device=device)
        except Exception as e:
            if device == "cpu":
                raise
            log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
            return factory(model_name, device="cpu")

    def _load_embedding_model(self):
        """Load the SentenceTransformer used to embed queries."""
        return self._load_model_on_best_device(SentenceTransformer, EMBEDDING_MODEL, "embedding")

    def _load_reranker_model(self):
        """Load the CrossEncoder used to rerank (question, sentence) pairs."""
        return self._load_model_on_best_device(CrossEncoder, RERANKER_MODEL, "reranker")

    def _ensure_csv_path(self, path: str) -> str:
        """Return *path* unchanged if it exists, else raise FileNotFoundError."""
        if os.path.exists(path):
            return path
        raise FileNotFoundError(f"未找到 CSV 檔案: {path}")

    @staticmethod
    def _bm25_doc_count(bm25) -> int:
        """Return the number of documents indexed by a BM25 object.

        rank_bm25 indexes record this as ``corpus_size``. A ``corpus`` list is
        used as a fallback in case the pickled index carries one.
        """
        size = getattr(bm25, "corpus_size", None)
        if size is None:
            size = len(bm25.corpus)
        return int(size)

    def _load_data(self):
        """Load the CSV, FAISS index, sentence corpus and BM25 index (call once at startup).

        Raises:
            FileNotFoundError: the CSV or an index file is missing.
            ValueError: the CSV lacks a required column.
            RuntimeError: the loaded artefacts are corrupt or not aligned.
        """
        log.info("開始載入資料與模型...")
        if not os.path.exists(self.csv_path):
            raise FileNotFoundError(f"找不到 CSV 檔案於 {self.csv_path}")
        self.df_csv = pd.read_csv(self.csv_path, dtype=str).fillna('')
        required_cols = {"drug_id", "drug_name_norm", "section"}
        missing_cols = required_cols - set(self.df_csv.columns)
        if missing_cols:
            raise ValueError(f"CSV 缺少必要欄位: {missing_cols}")
        # Lower-cased, punctuation-free copy of the name, used for substring matching.
        self.df_csv['drug_name_norm_normalized'] = (
            self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
        )
        log.info(f"成功載入 CSV: {self.csv_path} (rows={len(self.df_csv)})")

        self.state.index, self.state.sentences, self.state.meta = self._load_or_build_sentence_index()
        self.state.bm25 = self._ensure_bm25_index()

        # FAISS ids, BM25 document positions and meta entries are all used as
        # indexes into `sentences`, so the sizes must agree.
        sent_n = len(self.state.sentences)
        meta_n = len(self.state.meta)
        if meta_n != sent_n:
            raise RuntimeError(f"meta 數量 ({meta_n}) 與 sentences ({sent_n}) 不一致,請重新生成索引。")
        faiss_n = self.state.index.ntotal
        if faiss_n != sent_n:
            raise RuntimeError(f"FAISS 向量數 ({faiss_n}) 與 sentences ({sent_n}) 不一致,請重新生成索引。")
        bm_n = self._bm25_doc_count(self.state.bm25)
        if bm_n != sent_n:
            raise RuntimeError(f"BM25 文件數 ({bm_n}) 與 sentences ({sent_n}) 不一致,請重新生成索引。")
        log.info("所有模型與資料載入完成。")

    def _load_or_build_sentence_index(self):
        """Load the prebuilt FAISS index and the sentence/meta pickle.

        Returns:
            ``(index, sentences, meta)`` where ``meta[i]`` describes ``sentences[i]``.

        Raises:
            RuntimeError: either file is missing; the indexes are built by a
                separate script, not by this service.
        """
        if os.path.exists(FAISS_INDEX) and os.path.exists(SENTENCES_PKL):
            log.info("載入已存在的索引...")
            index = faiss.read_index(FAISS_INDEX)
            # NOTE: pickle.load can execute arbitrary code; SENTENCES_PKL must
            # be a trusted, locally built file.
            with open(SENTENCES_PKL, "rb") as f:
                data = pickle.load(f)
            return index, data["sentences"], data["meta"]
        log.error("索引檔案不存在 (FAISS_INDEX / SENTENCES_PKL)。")
        raise RuntimeError("FAISS 和句子 PKL 檔案未找到,請先執行索引生成腳本。")

    def _ensure_bm25_index(self):
        """Load the pickled BM25 index (a bare object or a dict with key "bm25").

        Raises:
            FileNotFoundError: ``BM25_PKL`` does not exist.
            RuntimeError: the file cannot be unpickled or holds no usable index
                (the original error is chained as ``__cause__``).
        """
        if not os.path.exists(BM25_PKL):
            raise FileNotFoundError(f"找不到 BM25 索引檔案於 {BM25_PKL},請先執行索引生成腳本。")
        try:
            # NOTE: pickle.load can execute arbitrary code; BM25_PKL must be a
            # trusted, locally built file.
            with open(BM25_PKL, "rb") as f:
                data = pickle.load(f)
            bm25 = data.get("bm25") if isinstance(data, dict) else data
            if not hasattr(bm25, 'get_scores'):
                raise ValueError("載入的 BM25 索引無效。")
            doc_count = self._bm25_doc_count(bm25)
        except Exception as e:
            log.error(f"載入 BM25 索引失敗 ({e}),請檢查檔案格式。")
            raise RuntimeError("BM25 索引檔案損壞或格式不符。") from e
        log.info(f"成功載入 BM25 索引,包含 {doc_count} 篇文件。")
        return bm25

    # ----- LLM access -----

    def _llm_call(self, messages, temperature=None, max_tokens=None):
        """Call the chat-completions API, retrying failed attempts.

        Args:
            messages: OpenAI-style chat messages.
            temperature: Sampling temperature; defaults to
                ``LLM_MODEL_CONFIG["temperature"]`` (read at call time).
            max_tokens: Completion limit; defaults to
                ``LLM_MODEL_CONFIG["max_tokens"]``.

        Returns:
            The completion text, or "" when the API returned no content.

        Raises:
            The last exception, after ``LLM_MAX_ATTEMPTS`` failed attempts.
        """
        if temperature is None:
            temperature = LLM_MODEL_CONFIG["temperature"]
        if max_tokens is None:
            max_tokens = LLM_MODEL_CONFIG["max_tokens"]
        attempts = max(1, self.LLM_MAX_ATTEMPTS)
        last_error = None
        for attempt in range(1, attempts + 1):
            try:
                response = self.llm_client.chat.completions.create(
                    model=LLM_API_CONFIG["model"],
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    stop=LLM_MODEL_CONFIG.get("stop_tokens") or None,
                )
                # `content` can be None; callers concatenate/strip the result.
                return response.choices[0].message.content or ""
            except Exception as e:
                last_error = e
                log.error(f"LLM API 呼叫失敗 (第 {attempt}/{attempts} 次): {e}")
                if attempt < attempts:
                    time.sleep(self.LLM_RETRY_WAIT_SECONDS)
        raise last_error

    # ----- main entry point -----

    def answer_question(self, q_orig: str) -> str:
        """Answer one user question end to end (blocking; run it off the event loop).

        Steps: find the drug(s) named in the question, let the LLM split it
        into sub-queries and classify its intent, retrieve and rerank insert
        sentences of those drugs for every sub-query, then ask the LLM to
        answer strictly from the selected sentences.

        Returns:
            The reply text with ``DISCLAIMER`` appended. Failures are logged
            and turned into a user-facing message; nothing is raised.
        """
        start_time = time.time()
        log.info(f"===== 處理新查詢: '{q_orig}' =====")
        try:
            log.info("步驟 1/5: 辨識藥品名稱...")
            drug_ids = self._find_drug_ids_from_name(q_orig, self.df_csv)
            if not drug_ids:
                log.warning("未找到對應藥品,直接回覆。")
                return f"未在資料庫中找到該藥品,請檢查名稱或諮詢醫師/藥師。{DISCLAIMER}"
            log.info(f"找到對應藥品 ID: {drug_ids}")

            log.info("步驟 2/5: 透過 LLM 進行查詢分析 (分解與意圖偵測)...")
            analyzed_result = self._analyze_query(q_orig)
            sub_queries = analyzed_result.get("sub_queries") or [q_orig]
            intents = analyzed_result.get("intents", [])
            log.info(f"分析結果 - 子問題: {sub_queries}")
            log.info(f"分析結果 - 意圖: {intents}")

            log.info("步驟 3/5: 檢索與重排序...")
            # CSV ids are strings (read with dtype=str); compare as strings so
            # ids stored with another type in the pickled meta still match.
            wanted_ids = {str(d) for d in drug_ids}
            relevant_indices = {
                i for i, m in enumerate(self.state.meta) if str(m.get("drug_id")) in wanted_ids
            }
            if not relevant_indices:
                log.error("找不到與藥品相關的語料。")
                return f"找不到 drug_id {drug_ids} 對應的任何 chunks。{DISCLAIMER}"

            # Section weights depend only on the intents, so compute them once.
            weights = self._adjust_section_weights(intents)
            log.info(f"根據意圖調整章節權重: {weights}")

            all_reranked_results = []
            seen_indices = set()
            for sub_q in sub_queries:
                for r in self._retrieve_for_subquery(q_orig, sub_q, intents, weights, relevant_indices):
                    # A sentence found by several sub-queries is kept once (first hit).
                    if r['idx'] not in seen_indices:
                        seen_indices.add(r['idx'])
                        all_reranked_results.append(r)
            all_reranked_results.sort(key=lambda x: x['rerank_score'], reverse=True)
            log.info(f"Reranker 最終選出 {len(all_reranked_results)} 個高品質候選。")
            # Only serialize when DEBUG is enabled; default=str guards against
            # non-JSON values (e.g. numpy scalars) inside meta.
            if log.isEnabledFor(logging.DEBUG):
                log.debug(
                    "所有重排序結果:\n"
                    + json.dumps(all_reranked_results, indent=2, ensure_ascii=False, default=str)
                )

            if not all_reranked_results:
                # Nothing passed the rerank threshold: refuse rather than let the
                # LLM answer a medication question from an empty context.
                log.warning("沒有候選通過 rerank 門檻,回覆無法回答。")
                return self._format_final_answer("根據提供的資料,無法回答此問題。", DISCLAIMER)

            log.info("步驟 4/5: 建立上下文並生成最終答案...")
            context = self._build_context(all_reranked_results, LLM_MODEL_CONFIG["max_context_chars"])
            prompt = self._make_prompt(q_orig, context, intents)
            messages = [
                {"role": "system", "content": "你是嚴謹的台灣藥師。"},
                {"role": "user", "content": prompt},
            ]
            log.debug("傳送給 LLM 的最終 Prompt:\n%s", prompt)
            answer = self._llm_call(messages)
            log.info("成功從 LLM 獲得回答。")
            log.debug("LLM 原始回答:\n%s", answer)

            log.info("步驟 5/5: 格式化答案並回覆。")
            final_answer_formatted = self._format_final_answer(answer, DISCLAIMER)
            log.debug("最終回覆內容:\n%s", final_answer_formatted)
            log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====")
            return final_answer_formatted
        except Exception as e:
            log.error(f"處理查詢 '{q_orig}' 時發生錯誤: {e}", exc_info=True)
            return f"處理時發生錯誤,請檢查日志。{DISCLAIMER}"

    # ----- retrieval helpers -----

    def _retrieve_for_subquery(
        self,
        q_orig: str,
        sub_q: str,
        intents: List[str],
        weights: Dict[str, float],
        relevant_indices: set,
    ) -> List[Dict]:
        """Retrieve and rerank sentences of the target drug(s) for one sub-query.

        The sub-query is LLM-expanded, searched with FAISS and BM25, the hits
        restricted to ``relevant_indices`` are fused, and the fused list is
        reranked against the ORIGINAL question ``q_orig``.
        """
        expanded_q = self._expand_query_with_llm(sub_q, intents)
        log.info(f"擴展後的查詢: '{expanded_q}'")
        pool_size = PRE_RERANK_K * 5

        sim_indices, sim_scores = self._semantic_search(
            self.state.index, expanded_q, pool_size, self.embedding_model
        )
        log.info(f"語意檢索找到 {len(sim_indices)} 個候選。")

        tokenized_query = list(jieba.cut(expanded_q))
        bm25_scores = (
            self.state.bm25.get_scores(tokenized_query)
            if self.state.bm25
            else np.zeros(len(self.state.sentences))
        )
        log.info(f"關鍵字檢索找到 {len(bm25_scores)} 個分數。")

        fused_candidates = self._fuse_candidates(
            sim_indices, sim_scores, bm25_scores, relevant_indices, weights, pool_size
        )
        if not fused_candidates:
            return []
        return self._rerank_with_crossencoder(
            q_orig, fused_candidates, self.state.sentences, self.reranker,
            TOP_K_SENTENCES, self.state.meta, RERANK_THRESHOLD,
        )

    @staticmethod
    def _minmax(x):
        """Min-max scale a numpy array to [0, 1]; a constant array maps to zeros."""
        rng = x.max() - x.min()
        return (x - x.min()) / (rng + 1e-8)

    def _fuse_candidates(
        self,
        sim_indices,
        sim_scores,
        bm25_scores,
        relevant_indices: set,
        weights: Dict[str, float],
        bm25_top_k: int,
    ) -> List[Tuple[int, float, float, float]]:
        """Merge semantic and BM25 hits into section-weighted fused scores.

        The candidate pool is the semantic hits plus the ``bm25_top_k``
        highest BM25 documents, both restricted to ``relevant_indices``.
        Every candidate uses its real BM25 score (``bm25_scores`` covers the
        whole corpus). Candidates missing a semantic score get the lowest
        observed one, which stays correct whether the metric yields positive
        similarities or negated L2 distances.

        Returns:
            ``(idx, fused, sem, bm)`` tuples of Python numbers, sorted by
            ``fused`` descending, where
            ``fused = (0.5 * norm(sem) + 0.4 * norm(bm)) * section_weight``.
        """
        sem_by_idx = {}
        for i, score in zip(sim_indices, sim_scores):
            i = int(i)
            if i in relevant_indices:
                sem_by_idx[i] = float(score)

        bm25_scores = np.asarray(bm25_scores, dtype=np.float64)
        candidate_ids = list(sem_by_idx)
        for i in np.argsort(bm25_scores)[::-1][:bm25_top_k]:
            i = int(i)
            if i in relevant_indices and i not in sem_by_idx:
                candidate_ids.append(i)
        if not candidate_ids:
            return []

        sem_floor = min(sem_by_idx.values()) if sem_by_idx else 0.0
        sem_vals = np.array([sem_by_idx.get(i, sem_floor) for i in candidate_ids], dtype=np.float32)
        bm_vals = np.array([bm25_scores[i] for i in candidate_ids], dtype=np.float32)
        sem_n = self._minmax(sem_vals)
        bm_n = self._minmax(bm_vals)

        fused_candidates = []
        for pos, i in enumerate(candidate_ids):
            section_name = self.state.meta[i].get("section", "其他")
            section_weight = weights.get(section_name, 1.0)
            fused_score = (sem_n[pos] * 0.5 + bm_n[pos] * 0.4) * section_weight
            fused_candidates.append((i, float(fused_score), float(sem_vals[pos]), float(bm_vals[pos])))
        fused_candidates.sort(key=lambda c: c[1], reverse=True)
        return fused_candidates

    # ----- LLM-driven query analysis -----

    @staticmethod
    def _match_intent(raw_intent, categories) -> Optional[str]:
        """Map the LLM's intent label onto one of *categories*.

        An exact match wins. Otherwise the first category whose Chinese label
        (the text before " (") occurs in the reply, or that contains the reply,
        is used. Unmatched non-empty labels are returned stripped, so they can
        still guide query expansion. Returns None for empty or non-string
        labels; for a list, its first element is used.
        """
        if isinstance(raw_intent, list):
            raw_intent = raw_intent[0] if raw_intent else None
        if not isinstance(raw_intent, str) or not raw_intent.strip():
            return None
        label = raw_intent.strip()
        if label in categories:
            return label
        for category in categories:
            short = category.split(" (", 1)[0]
            if (short and short in label) or label in category:
                return category
        return label

    @staticmethod
    def _parse_analysis_response(response: str, query: str, categories) -> Dict[str, Any]:
        """Parse the JSON reply of ``_analyze_query`` into sub-queries and intents.

        Surrounding prose or Markdown code fences are tolerated by parsing the
        span from the first "{" to the last "}". ``sub_queries`` falls back to
        ``[query]`` when missing, empty or malformed; a bare string is wrapped
        in a list.

        Raises:
            ValueError: (including json.JSONDecodeError) no JSON object found.
        """
        text = (response or "").strip()
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            text = text[start:end + 1]
        result = json.loads(text)
        if not isinstance(result, dict):
            raise ValueError(f"預期 JSON 物件,實際為 {type(result).__name__}")

        raw_subs = result.get("sub_queries", [])
        if isinstance(raw_subs, str):
            raw_subs = [raw_subs]
        if not isinstance(raw_subs, list):
            raw_subs = []
        sub_queries = [s.strip() for s in raw_subs if isinstance(s, str) and s.strip()]

        intent = RagPipeline._match_intent(result.get("intent"), categories)
        return {"sub_queries": sub_queries or [query], "intents": [intent] if intent else []}

    def _analyze_query(self, query: str) -> Dict[str, Any]:
        """Ask the LLM, in one call, for 1-3 sub-questions and the intent category.

        Returns:
            ``{"sub_queries": [...], "intents": [...]}``; on any failure,
            ``{"sub_queries": [query], "intents": []}``.
        """
        options = "\n".join(f"- {c}" for c in INTENT_CATEGORIES)
        prompt = "\n".join([
            "請分析以下使用者問題,並完成以下兩個任務:",
            "1. 將問題分解為1-3個子問題。",
            "2. 判斷問題的意圖,從清單中選擇最貼近的分類。",
            "請以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 'intent' (字串) 兩個鍵。",
            '範例: {"sub_queries": ["子問題一", "子問題二"], "intent": "分類名稱"}',
            "清單:",
            options,
            f"使用者問題:{query}",
        ])
        messages = [{"role": "user", "content": prompt}]
        response = ""
        try:
            response = self._llm_call(messages, temperature=0.2)
            return self._parse_analysis_response(response, query, INTENT_CATEGORIES)
        except Exception as e:
            log.error(f"分析查詢時發生錯誤,LLM回覆: '{response}',錯誤: {e}", exc_info=True)
            return {"sub_queries": [query], "intents": []}

    def _find_drug_ids_from_name(self, query: str, df: pd.DataFrame) -> List[str]:
        """Return (sorted) drug_ids whose normalized name contains any name alias found in *query*.

        Returns [] when *df* is None (e.g. before ``_load_data``).
        """
        if df is None:
            return []
        # Matching is case-insensitive, so lower-cased aliases suffice; sorting
        # makes the scan order deterministic.
        aliases = sorted({a.lower() for a in expand_aliases(extract_drug_candidates_from_query(query))})
        names = df['drug_name_norm_normalized']
        drug_ids = set()
        for alias in aliases:
            try:
                # regex=False: aliases come from user text, so match them literally.
                mask = names.str.contains(alias, case=False, regex=False, na=False)
            except Exception as e:
                log.warning(f"Failed to match '{alias}': {e}. Skipping this alias.")
                continue
            drug_ids.update(df.loc[mask, 'drug_id'].dropna().unique().tolist())
        return sorted(drug_ids)

    def _expand_query_with_llm(self, query: str, intents: List[str]) -> str:
        """Let the LLM add synonyms/related terms to *query*; return *query* itself on failure or empty output."""
        prompt = "\n".join([
            f"請根據以下意圖:{intents},擴展原始查詢,加入相關同義詞、相關術語和不同的說法。",
            f"原始查詢:{query}",
            "請僅輸出擴展後的查詢,不需任何額外的解釋或格式。",
        ])
        try:
            expanded = self._llm_call([{"role": "user", "content": prompt}])
        except Exception as e:
            log.error(f"擴展查詢失敗: {e}", exc_info=True)
            return query  # fall back to the original query
        # An empty completion would make retrieval search for nothing.
        return expanded.strip() or query

    def _adjust_section_weights(self, intents: List[str]) -> Dict[str, float]:
        """Return a copy of SECTION_WEIGHTS boosted for the first intent.

        NOTE(review): the contraindication/indication intent has no branch, so
        it keeps the default weights; confirm whether "禁忌" should be boosted.
        """
        weights = SECTION_WEIGHTS.copy()
        if not intents:
            return weights
        intent = intents[0]
        if intent in ["操作 (Administration)", "劑型相關 (Dosage Form Concerns)"]:
            weights["用法及用量"] *= 1.5
            weights["病人使用須知"] *= 2.0
        elif intent == "保存/攜帶 (Storage & Handling)":
            weights["儲存條件"] *= 2.0
        elif intent == "副作用/異常 (Side Effects / Issues)":
            weights["警語及注意事項"] *= 3.0
            weights["副作用"] *= 1.5
        elif intent == "時間/併用 (Timing & Interaction)":
            weights["用法及用量"] *= 1.5
            weights["藥物交互作用"] *= 2.0
        elif intent == "劑量調整 (Dosage Adjustment)":
            weights["用法及用量"] *= 2.0
        return weights

    def _semantic_search(self, index, query: str, top_k: int, embedding_model) -> Tuple[List[int], List[float]]:
        """Search the FAISS index with the L2-normalized query embedding.

        Returns:
            ``(ids, scores)`` with "higher is better" scores: raw values for
            inner-product indexes, negated distances for ``METRIC_L2``. The
            ``-1`` placeholder ids FAISS uses for missing results are dropped.
        """
        if not query:
            return [], []
        q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
        # Uses the module-level faiss import. A function-local `import faiss`
        # here would make `faiss` local and break this call (UnboundLocalError).
        faiss.normalize_L2(q_emb)
        distances, indices = index.search(q_emb, top_k)
        is_l2 = getattr(index, "metric_type", None) == faiss.METRIC_L2
        ids, scores = [], []
        for i, d in zip(indices[0], distances[0]):
            if i < 0:
                continue
            ids.append(int(i))
            scores.append(-float(d) if is_l2 else float(d))
        return ids, scores

    def _rerank_with_crossencoder(self, query: str, candidates: List[Tuple], sentences: List[str], reranker, top_k: int, meta: List[Dict], threshold: float) -> List[Dict]:
        """Rerank the best ``MAX_RERANK_CANDIDATES`` fused candidates with the cross-encoder.

        Returns:
            Up to *top_k* dicts (idx, rerank/fused/sem/bm scores, meta, text)
            whose rerank score is >= *threshold*, best first. All numbers are
            plain Python ints/floats so the results are JSON-serializable.
        """
        if not candidates:
            return []
        limited_candidates = candidates[:MAX_RERANK_CANDIDATES]
        pairs = [(query, sentences[i]) for i, _, _, _ in limited_candidates]
        scores = reranker.predict(pairs)
        reranked = [
            {
                "idx": int(i),
                "rerank_score": float(rerank_score),
                "fused_score": float(fused_score),
                "sem_score": float(sem_score),
                "bm_score": float(bm_score),
                "meta": meta[i],
                "text": sentences[i],
            }
            for (i, fused_score, sem_score, bm_score), rerank_score in zip(limited_candidates, scores)
            if rerank_score >= threshold
        ]
        reranked.sort(key=lambda x: x['rerank_score'], reverse=True)
        return reranked[:top_k]

    # ----- prompt assembly -----

    def _build_context(self, reranked_results: List[Dict], max_chars: int) -> str:
        """Join distinct result texts with blank lines, in order, while the total stays within *max_chars*.

        Stops at the first text that would overflow the budget.
        """
        separator = "\n\n"
        parts: List[str] = []
        seen = set()
        used = 0
        for res in reranked_results:
            text = res['text']
            if text in seen:
                continue
            cost = len(text) + (len(separator) if parts else 0)
            if used + cost > max_chars:
                break
            parts.append(text)
            seen.add(text)
            used += cost
        return separator.join(parts).strip()

    def _make_prompt(self, query: str, context: str, intents: List[str]) -> str:
        """Build the final answering prompt.

        Rule 4 (prescription label overrides the insert) is only included for
        dosage-adjustment or timing/interaction intents, so no empty rule line
        is emitted otherwise.
        """
        rules = [
            "1) 完全依據參考資料,不得捏造或引用外部知識。",
            "2) 以清楚的段落或條列回覆,不要使用 Markdown 符號(如 *, -, #)。",
            "3) 如果資料不足,請回覆:「根據提供的資料,無法回答此問題。」",
        ]
        if "劑量調整 (Dosage Adjustment)" in intents or "時間/併用 (Timing & Interaction)" in intents:
            rules.append("4) 在回答用藥劑量和時間時,務必提醒使用者,醫師開立的藥袋醫囑優先於仿單的一般建議。")
        lines = [
            "你是一位專業且謹慎的台灣藥師。請嚴格根據「參考資料」回答使用者問題,使用繁體中文。",
            "規則:",
            *rules,
            "參考資料:",
            "---",
            context,
            "---",
            f"使用者問題:{query}",
            "請輸出最終答案:",
        ]
        return "\n".join(lines) + "\n"

    def _format_final_answer(self, answer: str, disclaimer: str) -> str:
        """Append *disclaimer* after a blank line."""
        return f"{answer}\n\n{disclaimer}"
# ---------- FastAPI app, shared pipeline and LINE credentials ----------
app: FastAPI = FastAPI()

# Global pipeline instance; assigned by startup_event(), None until then.
rag_pipeline: Optional["RagPipeline"] = None


class AppConfig:
    """LINE Messaging API credentials, read from the environment when this module is imported."""

    # Bearer token for the LINE reply API (used by line_reply()).
    CHANNEL_ACCESS_TOKEN = os.getenv("LINE_CHANNEL_ACCESS_TOKEN")
    # Channel secret for X-Line-Signature verification (used by handle_webhook()).
    CHANNEL_SECRET = os.getenv("LINE_CHANNEL_SECRET")
@app.on_event("startup")
async def startup_event():
    """Validate required configuration and build the global RAG pipeline.

    Registered as the FastAPI startup hook; without registration it never ran
    and `rag_pipeline` stayed None. The global is only assigned after
    `_load_data()` succeeds, so a half-initialised pipeline is never exposed.

    Raises:
        RuntimeError: one or more required environment variables are unset
            (all missing names are listed).
    """
    global rag_pipeline
    log.info("===== Application Startup =====")
    required = {
        "LINE_CHANNEL_ACCESS_TOKEN": AppConfig.CHANNEL_ACCESS_TOKEN,
        "LINE_CHANNEL_SECRET": AppConfig.CHANNEL_SECRET,
        "LITELLM_API_KEY": LLM_API_CONFIG.get("api_key"),
        "LITELLM_BASE_URL": LLM_API_CONFIG.get("base_url"),
        "LM_MODEL": LLM_API_CONFIG.get("model"),
    }
    missing = [name for name, value in required.items() if not value]
    if missing:
        log.error(f"缺少必要環境變數:{missing}")
        raise RuntimeError(f"Missing required environment variables: {missing}")
    pipeline = RagPipeline(AppConfig)
    pipeline._load_data()
    rag_pipeline = pipeline
    log.info("啟動檢查完成。")
@app.get("/")
@app.get("/health")
async def health_check():
    """Liveness probe for the Docker HEALTHCHECK; always returns {"status": "ok"}.

    The handler was never registered as a route. It is now served on both "/"
    and "/health".
    NOTE(review): confirm which path the HEALTHCHECK actually probes.
    """
    return {"status": "ok"}
# NOTE(review): route path assumed; it must match the webhook URL configured
# in the LINE Developers console.
@app.post("/webhook")
async def handle_webhook(request: Request, response: Response):
    """Handle a LINE Messaging API webhook delivery.

    Verifies X-Line-Signature (base64 HMAC-SHA256 of the raw body keyed with
    the channel secret), then answers every non-empty text-message event that
    carries a reply token with the RAG pipeline and sends the reply.

    Raises:
        HTTPException: 400 for a missing signature or malformed body, 403 for
            an invalid signature, 500 if the signature cannot be computed, and
            503 if a text event arrives before the pipeline is initialised.
    """
    signature = request.headers.get("X-Line-Signature")
    if not signature:
        raise HTTPException(status_code=400, detail="X-Line-Signature header missing")
    body = await request.body()
    try:
        digest = hmac.new(AppConfig.CHANNEL_SECRET.encode("utf-8"), body, hashlib.sha256).digest()
        expected = base64.b64encode(digest)
    except Exception as e:
        log.error(f"簽名驗證失敗: {e}")
        raise HTTPException(status_code=500, detail="Signature verification error") from e
    # Compare as bytes: compare_digest raises TypeError for str arguments that
    # contain non-ASCII characters.
    if not hmac.compare_digest(expected, signature.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid signature")

    try:
        data = json.loads(body.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        raise HTTPException(status_code=400, detail="Invalid JSON body") from e

    events = data.get("events", []) if isinstance(data, dict) else []
    for event in events:
        if not isinstance(event, dict) or event.get("type") != "message":
            continue
        message = event.get("message") or {}
        if message.get("type") != "text":
            continue
        user_text = (message.get("text") or "").strip()
        reply_token = event.get("replyToken")
        if not user_text or not reply_token:
            continue
        if rag_pipeline is None:
            log.error("RAG pipeline 尚未初始化,無法處理訊息。")
            raise HTTPException(status_code=503, detail="Service not ready")
        # Both calls block (models, LLM API, HTTP); keep them off the event loop.
        answer = await run_in_threadpool(rag_pipeline.answer_question, user_text)
        await run_in_threadpool(line_reply, reply_token, answer)
    return {"status": "ok"}
def line_reply(reply_token: str, text: str): | |
"""透過 LINE Message API 回覆訊息,並進行分塊以避免長度限制""" | |
headers = { | |
"Content-Type": "application/json", | |
"Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}" | |
} | |
# LINE 文本長度上限約為 5000 字元 | |
chunks = textwrap.wrap(text, 4900) | |
messages = [{"type": "text", "text": c} for c in chunks] or [{"type": "text", "text": text[:4900]}] | |
data = {"replyToken": reply_token, "messages": messages} | |
try: | |
r = requests.post("https://api.line.me/v2/bot/message/reply", headers=headers, json=data, timeout=10) | |
if r.status_code >= 300: | |
log.error(f"LINE API 回覆失敗: {r.status_code} {r.text}") | |
except Exception as e: | |
log.error(f"LINE API 回覆失敗: {e}") | |
# ---- Helper functions ----
def extract_drug_candidates_from_query(query: str) -> list:
    """Extract strings from a user question that might be drug names.

    ASCII letters are lower-cased first. From the text before the first ":",
    it collects runs of 3+ letters and the non-alphanumeric remainder of each
    whitespace/comma/slash/parenthesis-separated token (skipping
    DRUG_STOPWORDS). Every DRUG_NAME_MAPPING key found anywhere in the query
    adds its mapped name. Only candidates longer than one character are
    returned; the order is unspecified (built from a set).
    """
    lowered = re.sub(r"[A-Za-z]+", lambda m: m.group(0).lower(), query)
    head = lowered.split(":", 1)[0]

    found = {m.group(0) for m in re.finditer(r"[a-zA-Z]{3,}", head)}

    for piece in re.split(r"[\s,/()()]+", head):
        residue = re.sub(r'[a-zA-Z0-9\s]+', '', piece).strip()
        if residue and residue.lower() not in DRUG_STOPWORDS:
            found.add(residue)

    haystack = lowered.lower()
    found.update(target for alias, target in DRUG_NAME_MAPPING.items() if alias in haystack)

    return [c for c in found if len(c) > 1]
def expand_aliases(candidates: list) -> list:
    """Return the distinct spelling variants of every candidate.

    Each stripped, non-empty candidate contributes itself, a copy with every
    character outside 0-9 / A-Z / a-z / CJK removed, and its lower- and
    upper-case forms. Empty variants are dropped; the order is unspecified.
    """
    variants = set()
    for raw in candidates:
        base = raw.strip()
        if base:
            variants.update((
                base,
                re.sub(r"[^0-9A-Za-z\u4e00-\u9fff]+", "", base),
                base.lower(),
                base.upper(),
            ))
    return [v for v in variants if v]
# ---------- Entry point (local testing) ----------
if __name__ == "__main__":
    # PORT defaults to 7860; bind on all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))