#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
DrugQA (ZH) — 優化版 FastAPI LINE Webhook
整合 RAG 邏輯,包含 LLM 意圖偵測、子查詢分解、Intent-aware 檢索與 Rerank。
"""
# ---------- Environment & cache setup (must come before model imports) ----------
import os
import pathlib
os.environ.setdefault("HF_HOME", "/tmp/hf")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/sentence_transformers")
os.environ.setdefault("XDG_CACHE_HOME", "/tmp/.cache")
for d in (os.getenv("HF_HOME"), os.getenv("SENTENCE_TRANSFORMERS_HOME"), os.getenv("XDG_CACHE_HOME")):
    pathlib.Path(d).mkdir(parents=True, exist_ok=True)
# ---------- Python standard library ----------
import re
import hmac
import base64
import hashlib
import pickle
import logging
import json
from types import SimpleNamespace
from typing import List, Dict, Any, Optional, Tuple, Union
from functools import lru_cache
import time
import textwrap
# ---------- Third-party libraries ----------
import numpy as np
import pandas as pd
from fastapi import FastAPI, Request, Response, HTTPException, status
import uvicorn
import jieba
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import torch
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_fixed
import requests
from starlette.concurrency import run_in_threadpool
# ==== CONFIG (loaded from environment variables, with defaults) ====
# Default paths point at the current directory, matching the bundled data files.
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 30))
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 50))
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
RERANKER_MODEL = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
_SENT_SPLIT_RE = re.compile(r"[。!?\n]")
DRUG_STOPWORDS = {"藥", "劑", "錠", "膠囊", "糖漿", "乳膏", "貼片"}
SECTION_WEIGHTS = {
    "用法及用量": 1.0,
    "病人使用須知": 1.0,
    "儲存條件": 1.0,
    "警語及注意事項": 2.0,
    "禁忌": 1.5,
    "副作用": 1.0,
    "藥物交互作用": 1.0,
    "其他": 1.0,
}
RERANK_THRESHOLD = float(os.getenv("RERANK_THRESHOLD", 0.5))
DRUG_NAME_MAPPING = {
    "fentanyl patch": "fentanyl",
    "spiriva respimat": "spiriva",
    "augmentin for syrup": "augmentin syrup",
    "nitrostat": "nitroglycerin",
    "ozempic": "ozempic",
    "niflec": "niflec",
    "fosamax": "fosamax",
    "humira": "humira",
    "premarin": "premarin",
    "smecta": "smecta",
}
LLM_API_CONFIG = {
    "base_url": os.getenv("LITELLM_BASE_URL"),
    "api_key": os.getenv("LITELLM_API_KEY"),
    "model": os.getenv("LM_MODEL"),
}
LLM_MODEL_CONFIG = {
    "max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 12000)),
    "max_tokens": int(os.getenv("MAX_TOKENS", 2048)),
    "temperature": float(os.getenv("TEMPERATURE", 0.0)),
    "top_p": float(os.getenv("TOP_P", 0.95)),
    "stop_tokens": ["==="],
}
# --- Intent classification categories
INTENT_CATEGORIES = [
    "操作 (Administration)",
    "保存/攜帶 (Storage & Handling)",
    "副作用/異常 (Side Effects / Issues)",
    "劑型相關 (Dosage Form Concerns)",
    "時間/併用 (Timing & Interaction)",
    "劑量調整 (Dosage Adjustment)",
    "禁忌症/適應症 (Contraindications/Indications)",
]
DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
# ---------- Logging setup ----------
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
# ---------- Core RAG logic (encapsulated in a class) ----------
class RagPipeline:
    def __init__(self, config):
        self.config = config
        self.state = SimpleNamespace()  # holds index, sentences, meta and bm25 after _load_data()
        self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
        self.embedding_model = self._load_embedding_model()
        self.reranker = self._load_reranker_model()
        self.csv_path = self._ensure_csv_path(CSV_PATH)
    def _load_embedding_model(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log.info(f"載入 embedding 模型:{EMBEDDING_MODEL} (device: {device})...")
        try:
            model = SentenceTransformer(EMBEDDING_MODEL, device=device)
        except Exception as e:
            log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
            device = "cpu"
            model = SentenceTransformer(EMBEDDING_MODEL, device=device)
        return model
    def _load_reranker_model(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log.info(f"載入 reranker 模型:{RERANKER_MODEL} (device: {device})...")
        try:
            model = CrossEncoder(RERANKER_MODEL, device=device)
        except Exception as e:
            log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
            device = "cpu"
            model = CrossEncoder(RERANKER_MODEL, device=device)
        return model
    def _ensure_csv_path(self, path: str) -> str:
        if os.path.exists(path):
            return path
        raise FileNotFoundError(f"未找到 CSV 檔案: {path}")
    def _load_data(self):
        """Load all required models and data at startup."""
        log.info("開始載入資料與模型...")
        # Load CSV and check for required columns
        if not os.path.exists(self.csv_path):
            raise FileNotFoundError(f"找不到 CSV 檔案於 {self.csv_path}")
        self.df_csv = pd.read_csv(self.csv_path, dtype=str).fillna('')
        required_cols = {"drug_id", "drug_name_norm", "section"}
        missing_cols = required_cols - set(self.df_csv.columns)
        if missing_cols:
            raise ValueError(f"CSV 缺少必要欄位: {missing_cols}")
        self.df_csv['drug_name_norm_normalized'] = (
            self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
        )
        log.info(f"成功載入 CSV: {self.csv_path} (rows={len(self.df_csv)})")
        # Load corpus and index
        self.state.index, self.state.sentences, self.state.meta = self._load_or_build_sentence_index()
        self.state.bm25 = self._ensure_bm25_index()
        # Check that the BM25 index is aligned with the sentence corpus
        bm_n = self.state.bm25.corpus_size  # rank_bm25 exposes corpus_size, not the corpus itself
        sent_n = len(self.state.sentences)
        if bm_n != sent_n:
            raise RuntimeError(f"BM25 文件數 ({bm_n}) 與 sentences ({sent_n}) 不一致,請重新生成索引。")
        log.info("所有模型與資料載入完成。")
    def _load_or_build_sentence_index(self):
        if os.path.exists(FAISS_INDEX) and os.path.exists(SENTENCES_PKL):
            log.info("載入已存在的索引...")
            index = faiss.read_index(FAISS_INDEX)
            with open(SENTENCES_PKL, "rb") as f:
                data = pickle.load(f)
            sentences = data["sentences"]
            meta = data["meta"]
            return index, sentences, meta
        log.error("索引檔案不存在。")
        # Index building is handled by a separate offline script, not in this service.
        raise RuntimeError("FAISS 和句子 PKL 檔案未找到,請先執行索引生成腳本。")
    def _ensure_bm25_index(self):
        """Load the prebuilt BM25 index; fail fast if it is missing."""
        if os.path.exists(BM25_PKL):
            try:
                with open(BM25_PKL, "rb") as f:
                    data = pickle.load(f)
                bm25 = data.get("bm25") if isinstance(data, dict) else data
                if not hasattr(bm25, 'get_scores'):
                    raise ValueError("載入的 BM25 索引無效。")
                log.info(f"成功載入 BM25 索引,包含 {bm25.corpus_size} 篇文件。")
                return bm25
            except Exception as e:
                log.error(f"載入 BM25 索引失敗 ({e}),請檢查檔案格式。")
                raise RuntimeError("BM25 索引檔案損壞或格式不符。")
        else:
            raise FileNotFoundError(f"找不到 BM25 索引檔案於 {BM25_PKL},請先執行索引生成腳本。")
    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
    def _llm_call(self, messages, temperature=LLM_MODEL_CONFIG["temperature"], max_tokens=LLM_MODEL_CONFIG["max_tokens"]):
        """Call the LLM API, retrying on failure."""
        try:
            response = self.llm_client.chat.completions.create(
                model=LLM_API_CONFIG["model"],
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=LLM_MODEL_CONFIG.get("stop_tokens") or None,
            )
            return response.choices[0].message.content
        except Exception as e:
            log.error(f"LLM API 呼叫失敗: {e}")
            raise
    def answer_question(self, q_orig: str) -> str:
        """End-to-end handling of a user question (synchronous version)."""
        start_time = time.time()
        log.info(f"===== 處理新查詢: '{q_orig}' =====")
        try:
            log.info("步驟 1/5: 辨識藥品名稱...")
            drug_ids = self._find_drug_ids_from_name(q_orig, self.df_csv)
            if not drug_ids:
                log.warning("未找到對應藥品,直接回覆。")
                return f"未在資料庫中找到該藥品,請檢查名稱或諮詢醫師/藥師。{DISCLAIMER}"
            log.info(f"找到對應藥品 ID: {drug_ids}")
            log.info("步驟 2/5: 透過 LLM 進行查詢分析 (分解與意圖偵測)...")
            analyzed_result = self._analyze_query(q_orig)
            sub_queries = analyzed_result.get("sub_queries", [q_orig])
            intents = analyzed_result.get("intents", [])
            log.info(f"分析結果 - 子問題: {sub_queries}")
            log.info(f"分析結果 - 意圖: {intents}")
            all_reranked_results = []
            log.info("步驟 3/5: 檢索與重排序...")
            relevant_indices = {i for i, m in enumerate(self.state.meta) if m.get("drug_id") in drug_ids}
            if not relevant_indices:
                log.error("找不到與藥品相關的語料。")
                return f"找不到 drug_id {drug_ids} 對應的任何 chunks。{DISCLAIMER}"
            for sub_q in sub_queries:
                expanded_q = self._expand_query_with_llm(sub_q, intents)
                log.info(f"擴展後的查詢: '{expanded_q}'")
                weights = self._adjust_section_weights(intents)
                log.info(f"根據意圖調整章節權重: {weights}")
                sim_indices, sim_scores = self._semantic_search(self.state.index, expanded_q, PRE_RERANK_K * 5, self.embedding_model)
                log.info(f"語意檢索找到 {len(sim_indices)} 個候選。")
                tokenized_query = list(jieba.cut(expanded_q))
                bm25_scores = self.state.bm25.get_scores(tokenized_query) if self.state.bm25 else np.zeros(len(self.state.sentences))
                log.info(f"關鍵字檢索找到 {len(bm25_scores)} 個分數。")
                # Collect candidates restricted to the matched drug(s), merging semantic and BM25 hits
                candidate_dict = {}
                for i, sem_score in zip(sim_indices, sim_scores):
                    if i in relevant_indices:
                        candidate_dict[i] = {"sem": sem_score, "bm": 0.0}
                bm25_top_indices = np.argsort(bm25_scores)[::-1][:PRE_RERANK_K * 5]
                for i in bm25_top_indices:
                    if i in relevant_indices:
                        bm_score = bm25_scores[i]
                        if i in candidate_dict:
                            candidate_dict[i]["bm"] = bm_score
                        else:
                            candidate_dict[i] = {"sem": 0.0, "bm": bm_score}
                candidates_list = []
                for i, scores in candidate_dict.items():
                    candidates_list.append((i, scores["sem"], scores["bm"]))
                if not candidates_list:
                    continue
                # Min-max normalize both score sets before fusion
                sem_vals = np.array([s for _, s, _ in candidates_list], dtype=np.float32)
                bm_vals = np.array([b for _, _, b in candidates_list], dtype=np.float32)
                def norm(x):
                    rng = x.max() - x.min()
                    return (x - x.min()) / (rng + 1e-8)
                sem_n = norm(sem_vals)
                bm_n = norm(bm_vals)
                fused_candidates = []
                for idx, (i, s_raw, b_raw) in enumerate(candidates_list):
                    section_name = self.state.meta[i].get("section", "其他")
                    section_weight = weights.get(section_name, 1.0)
                    # Weighted fusion of the normalized scores, scaled by the intent-aware section weight
                    fused_score = (sem_n[idx] * 0.5 + bm_n[idx] * 0.4) * section_weight
                    fused_candidates.append((i, fused_score, s_raw, b_raw))
                fused_candidates.sort(key=lambda x: x[1], reverse=True)
                sub_reranked = self._rerank_with_crossencoder(q_orig, fused_candidates, self.state.sentences, self.reranker, TOP_K_SENTENCES, self.state.meta, RERANK_THRESHOLD)
                # De-duplicate across sub-queries using the sentence index
                processed_indices = {res['idx'] for res in all_reranked_results}
                for r in sub_reranked:
                    if r['idx'] in processed_indices:
                        continue
                    all_reranked_results.append(r)
            all_reranked_results.sort(key=lambda x: x['rerank_score'], reverse=True)
            log.info(f"Reranker 最終選出 {len(all_reranked_results)} 個高品質候選。")
            if log.isEnabledFor(logging.DEBUG):
                log.debug("所有重排序結果:\n" + json.dumps(all_reranked_results, indent=2, ensure_ascii=False))
            log.info("步驟 4/5: 建立上下文並生成最終答案...")
            context = self._build_context(all_reranked_results, LLM_MODEL_CONFIG["max_context_chars"])
            prompt = self._make_prompt(q_orig, context, intents)
            messages = [
                {"role": "system", "content": "你是嚴謹的台灣藥師。"},
                {"role": "user", "content": prompt}
            ]
            log.debug("傳送給 LLM 的最終 Prompt:\n" + prompt)
            answer = self._llm_call(messages)
            log.info("成功從 LLM 獲得回答。")
            log.debug("LLM 原始回答:\n" + answer)
            log.info("步驟 5/5: 格式化答案並回覆。")
            final_answer_formatted = self._format_final_answer(answer, DISCLAIMER)
            log.debug("最終回覆內容:\n" + final_answer_formatted)
            end_time = time.time()
            log.info(f"===== 查詢處理完成,總耗時: {end_time - start_time:.2f} 秒 =====")
            return final_answer_formatted
        except Exception as e:
            log.error(f"處理查詢 '{q_orig}' 時發生錯誤: {e}", exc_info=True)
            return f"處理時發生錯誤,請稍後再試。{DISCLAIMER}"
    def _analyze_query(self, query: str) -> Dict[str, Any]:
        """Single LLM call that returns both the sub-queries and the intent."""
        options = "\n".join(f"- {c}" for c in INTENT_CATEGORIES)
        prompt = f"""
請分析以下使用者問題,並完成以下兩個任務:
1. 將問題分解為1-3個子問題。
2. 判斷問題的意圖,從清單中選擇最貼近的分類。
請以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 'intent' (字串) 兩個鍵。
範例: {{"sub_queries": ["子問題一", "子問題二"], "intent": "分類名稱"}}
清單:
{options}
使用者問題:{query}
"""
        messages = [{"role": "user", "content": prompt}]
        response = ""
        try:
            response = self._llm_call(messages, temperature=0.2)
            result = json.loads(response)
            sub_queries = result.get("sub_queries", [])
            intent = result.get("intent", None)
            if not sub_queries:
                sub_queries = [query]
            return {"sub_queries": sub_queries, "intents": [intent] if intent else []}
        except Exception as e:
            log.error(f"分析查詢時發生錯誤,LLM回覆: '{response}',錯誤: {e}", exc_info=True)
            return {"sub_queries": [query], "intents": []}
    def _find_drug_ids_from_name(self, query: str, df: pd.DataFrame) -> List[str]:
        if df is None:
            return []
        candidates = extract_drug_candidates_from_query(query)
        expanded = expand_aliases(candidates)
        drug_ids = set()
        for alias in expanded:
            try:
                # Use regex=False for literal matching, which is safer
                mask = df['drug_name_norm_normalized'].str.contains(alias.lower(), case=False, regex=False, na=False)
                matches = df.loc[mask, 'drug_id'].dropna().unique().tolist()
                drug_ids.update(matches)
            except Exception as e:
                log.warning(f"Failed to match '{alias}': {e}. Skipping this alias.")
        return list(drug_ids)
    def _expand_query_with_llm(self, query: str, intents: List[str]) -> str:
        prompt = f"""請根據以下意圖:{intents},擴展原始查詢,加入相關同義詞、相關術語和不同的說法。
原始查詢:{query}
請僅輸出擴展後的查詢,不需任何額外的解釋或格式。"""
        try:
            return self._llm_call([{"role": "user", "content": prompt}])
        except Exception as e:
            log.error(f"擴展查詢失敗: {e}", exc_info=True)
            return query  # fall back to the original query
    def _adjust_section_weights(self, intents: List[str]) -> Dict[str, float]:
        """Adjust the retrieval weight of each leaflet section according to the detected intent."""
        weights = SECTION_WEIGHTS.copy()
        if not intents:
            return weights
        intent = intents[0]
        if intent in ["操作 (Administration)", "劑型相關 (Dosage Form Concerns)"]:
            weights["用法及用量"] *= 1.5
            weights["病人使用須知"] *= 2.0
        elif intent == "保存/攜帶 (Storage & Handling)":
            weights["儲存條件"] *= 2.0
        elif intent == "副作用/異常 (Side Effects / Issues)":
            weights["警語及注意事項"] *= 3.0
            weights["副作用"] *= 1.5
        elif intent == "時間/併用 (Timing & Interaction)":
            weights["用法及用量"] *= 1.5
            weights["藥物交互作用"] *= 2.0
        elif intent == "劑量調整 (Dosage Adjustment)":
            weights["用法及用量"] *= 2.0
        return weights
    def _semantic_search(self, index, query: str, top_k: int, embedding_model) -> Tuple[List[int], List[float]]:
        if not query:
            return [], []
        q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
        faiss.normalize_L2(q_emb)
        distances, indices = index.search(q_emb, top_k)
        # Normalize scores to a "higher is better" convention regardless of the index metric
        metric = getattr(index, "metric_type", None)
        if metric == getattr(faiss, "METRIC_L2", 1):
            scores = (-distances[0]).tolist()  # L2 distance: smaller is closer, so negate
        else:
            scores = distances[0].tolist()  # Inner product / cosine: larger is closer
        return indices[0].tolist(), scores
    def _rerank_with_crossencoder(self, query: str, candidates: List[Tuple], sentences: List[str], reranker, top_k: int, meta: List[Dict], threshold: float) -> List[Dict]:
        if not candidates:
            return []
        limited_candidates = candidates[:MAX_RERANK_CANDIDATES]
        pairs = [(query, sentences[i]) for i, _, _, _ in limited_candidates]
        scores = reranker.predict(pairs)
        reranked = []
        for (i, fused_score, sem_score, bm_score), rerank_score in zip(limited_candidates, scores):
            if rerank_score >= threshold:
                # Cast numpy scalars to plain Python types so downstream json.dumps logging works
                reranked.append({
                    "idx": int(i),
                    "rerank_score": float(rerank_score),
                    "fused_score": float(fused_score),
                    "sem_score": float(sem_score),
                    "bm_score": float(bm_score),
                    "meta": meta[i],
                    "text": sentences[i]
                })
        reranked.sort(key=lambda x: x['rerank_score'], reverse=True)
        return reranked[:top_k]
    def _build_context(self, reranked_results: List[Dict], max_chars: int) -> str:
        context = ""
        processed_chunks = set()
        for res in reranked_results:
            text = res['text']
            if text in processed_chunks:
                continue
            if len(context) + len(text) > max_chars:
                break
            context += text + "\n\n"
            processed_chunks.add(text)
        return context.strip()
    def _make_prompt(self, query: str, context: str, intents: List[str]) -> str:
        additional_instruction = ""
        if "劑量調整 (Dosage Adjustment)" in intents or "時間/併用 (Timing & Interaction)" in intents:
            additional_instruction = "在回答用藥劑量和時間時,務必提醒使用者,醫師開立的藥袋醫囑優先於仿單的一般建議。"
        return f"""
你是一位專業且謹慎的台灣藥師。請嚴格根據「參考資料」回答使用者問題,使用繁體中文。
規則:
1) 完全依據參考資料,不得捏造或引用外部知識。
2) 以清楚的段落或條列回覆,不要使用 Markdown 符號(如 *, -, #)。
3) 如果資料不足,請回覆:「根據提供的資料,無法回答此問題。」
4) {additional_instruction}
參考資料:
---
{context}
---
使用者問題:{query}
請輸出最終答案:
"""
    def _format_final_answer(self, answer: str, disclaimer: str) -> str:
        return f"{answer}\n\n{disclaimer}"
# ---------- FastAPI events & routes ----------
app = FastAPI()
rag_pipeline = None
class AppConfig:
    CHANNEL_ACCESS_TOKEN = os.getenv("LINE_CHANNEL_ACCESS_TOKEN")
    CHANNEL_SECRET = os.getenv("LINE_CHANNEL_SECRET")
@app.on_event("startup")
async def startup_event():
    """Tasks to run when the application starts."""
    log.info("===== Application Startup =====")
    missing = []
    if not AppConfig.CHANNEL_ACCESS_TOKEN:
        missing.append("LINE_CHANNEL_ACCESS_TOKEN")
    if not AppConfig.CHANNEL_SECRET:
        missing.append("LINE_CHANNEL_SECRET")
    if not LLM_API_CONFIG.get("api_key"):
        missing.append("LITELLM_API_KEY")
    if not LLM_API_CONFIG.get("base_url"):
        missing.append("LITELLM_BASE_URL")
    if not LLM_API_CONFIG.get("model"):
        missing.append("LM_MODEL")
    if missing:
        log.error(f"缺少必要環境變數:{missing}")
        raise RuntimeError(f"Missing required environment variables: {missing}")
    global rag_pipeline
    rag_pipeline = RagPipeline(AppConfig)
    rag_pipeline._load_data()
    log.info("啟動檢查完成。")
@app.get("/health", status_code=status.HTTP_200_OK)
async def health_check():
"""健康檢查端點,用於 Docker HEALTHCHECK"""
return {"status": "ok"}
@app.post("/webhook")
async def handle_webhook(request: Request, response: Response):
"""處理 LINE Message API 的 Webhook 請求"""
signature = request.headers.get("X-Line-Signature")
if not signature:
raise HTTPException(status_code=400, detail="X-Line-Signature header missing")
body = await request.body()
try:
digest = hmac.new(AppConfig.CHANNEL_SECRET.encode("utf-8"), body, hashlib.sha256).digest()
expected = base64.b64encode(digest).decode()
if not hmac.compare_digest(expected, signature):
raise HTTPException(status_code=403, detail="Invalid signature")
except HTTPException:
raise
except Exception as e:
log.error(f"簽名驗證失敗: {e}")
raise HTTPException(status_code=500, detail="Signature verification error")
data = json.loads(body.decode('utf-8'))
for event in data.get("events", []):
if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
reply_token = event.get("replyToken")
user_text = event.get("message", {}).get("text", "").strip()
if not user_text: continue
# Offload heavy work to a thread pool
answer = await run_in_threadpool(rag_pipeline.answer_question, user_text)
if reply_token:
line_reply(reply_token, answer)
return {"status": "ok"}
def line_reply(reply_token: str, text: str):
    """Reply via the LINE Messaging API, chunking the text to stay under the length limit."""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}"
    }
    # A LINE text message is capped at roughly 5000 characters, and one reply may carry
    # at most 5 message objects, so chunk the text and keep only the first 5 chunks.
    chunks = textwrap.wrap(text, 4900)
    messages = ([{"type": "text", "text": c} for c in chunks] or [{"type": "text", "text": text[:4900]}])[:5]
    data = {"replyToken": reply_token, "messages": messages}
    try:
        r = requests.post("https://api.line.me/v2/bot/message/reply", headers=headers, json=data, timeout=10)
        if r.status_code >= 300:
            log.error(f"LINE API 回覆失敗: {r.status_code} {r.text}")
    except Exception as e:
        log.error(f"LINE API 回覆失敗: {e}")
# ---- Helper utilities ----
def extract_drug_candidates_from_query(query: str) -> list:
    # Lowercase ASCII letters so matching against the normalized CSV names is case-insensitive
    query = re.sub(r"[A-Za-z]+", lambda m: m.group(0).lower(), query)
    candidates = set()
    # Only the part before the first (half- or full-width) colon is treated as the drug-name part
    drug_part = re.split(r"[::]", query, maxsplit=1)[0]
    for m in re.finditer(r"[a-zA-Z]{3,}", drug_part):
        candidates.add(m.group(0))
    for token in re.split(r"[\s,/()()]+", drug_part):
        clean_token = re.sub(r'[a-zA-Z0-9\s]+', '', token).strip()
        if clean_token and clean_token.lower() not in DRUG_STOPWORDS:
            candidates.add(clean_token)
    # Deliberately avoid adding the whole drug_part as a single candidate to prevent matching errors
    for query_name, dataset_name in DRUG_NAME_MAPPING.items():
        if query_name in query.lower():
            candidates.add(dataset_name)
    return [c for c in candidates if len(c) > 1]
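# Illustrative behaviour (not an exhaustive spec): a query such as "fentanyl patch 如何使用"
# yields candidates like {"fentanyl", "patch", "如何使用"} — the ASCII tokens from the regex
# scan plus the CJK token, with "fentanyl" also injected via DRUG_NAME_MAPPING.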
def expand_aliases(candidates: list) -> list:
    out = set()
    for c in candidates:
        s = c.strip()
        if not s:
            continue
        out.add(s)
        out.add(re.sub(r"[^0-9A-Za-z\u4e00-\u9fff]+", "", s))
        out.add(s.lower())
        out.add(s.upper())
    return [x for x in out if x]
# ---------- Entrypoint (for local testing) ----------
if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
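# To run locally (assuming the index artifacts and the environment variables above are in place):
#   python app.py                       # serves on PORT, default 7860
#   curl http://localhost:7860/health   # -> {"status": "ok"}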