Spaces:
Running
Running
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
DrugQA (ZH) — FastAPI LINE webhook only (/webhook). | |
使用以下 HF 環境變數(另可選用 LOG_LEVEL、PORT):
- CHANNEL_ACCESS_TOKEN | |
- CHANNEL_SECRET | |
- LITELLM_API_KEY | |
- LITELLM_BASE_URL | |
- LM_MODEL | |
依序在目前工作目錄、/app、/tmp 尋找資料檔(drug_sentences.pkl / drug_sentences.index / bm25.pkl / cleaned_combined.csv)。
重建索引時先寫回選定的路徑;若該處無寫入權限,才改寫到 /tmp,避免唯讀權限問題。
所有快取統一 /tmp。 | |
""" | |
# ---------- Cache directories (must be configured before transformers is imported) ----------
import os
import pathlib
import errno

# Cache-location env vars and the /tmp default used when the variable is not already set.
_CACHE_ENV_DEFAULTS = {
    "HF_HOME": "/tmp/hf",
    "SENTENCE_TRANSFORMERS_HOME": "/tmp/sentence_transformers",
    "XDG_CACHE_HOME": "/tmp/.cache",
}
for _env_name, _default_dir in _CACHE_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _default_dir)
os.environ.pop("TRANSFORMERS_CACHE", None)  # deprecated variable; removed on purpose
for _env_name in _CACHE_ENV_DEFAULTS:
    pathlib.Path(os.environ[_env_name]).mkdir(parents=True, exist_ok=True)
# ---------- Imports ---------- | |
import re, hmac, base64, hashlib, pickle, logging, time, json | |
from typing import List, Dict, Any, Optional, Tuple, Union | |
import numpy as np | |
import pandas as pd | |
from scipy import stats | |
try: | |
import torch # 僅用於檢查裝置 | |
except Exception: | |
torch = None | |
try: | |
import faiss # type: ignore | |
except Exception as e: | |
raise RuntimeError(f"faiss not available: {e}") | |
try: | |
from sentence_transformers import SentenceTransformer, CrossEncoder # type: ignore | |
except Exception: | |
SentenceTransformer = None | |
try: | |
from rank_bm25 import BM25Okapi # type: ignore | |
except Exception: | |
BM25Okapi = None | |
try: | |
import jieba # type: ignore | |
except Exception: | |
jieba = None | |
try: | |
from fuzzywuzzy import fuzz # type: ignore | |
except Exception: | |
fuzz = None | |
try: | |
import requests # type: ignore | |
except Exception: | |
requests = None | |
from fastapi import FastAPI, HTTPException, Header, Request | |
# ---------- Logging ----------
# Levels accepted by BOTH logging.basicConfig() and uvicorn's log_level (used in the __main__
# block). Anything else ("WARN", "TRACE", a typo) would be rejected by one of them, so fall
# back to INFO instead of failing at import/startup.
_SUPPORTED_LOG_LEVELS = ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG")
_REQUESTED_LOG_LEVEL = (os.getenv("LOG_LEVEL") or "INFO").upper()
LOG_LEVEL = _REQUESTED_LOG_LEVEL if _REQUESTED_LOG_LEVEL in _SUPPORTED_LOG_LEVELS else "INFO"
logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s - %(levelname)s - %(message)s")
log = logging.getLogger("app")
if LOG_LEVEL != _REQUESTED_LOG_LEVEL:
    log.warning("Unsupported LOG_LEVEL %r; falling back to INFO.", _REQUESTED_LOG_LEVEL)
# ---------- Credentials / endpoints read from the Space environment (None when unset) ----------
(
    CHANNEL_ACCESS_TOKEN,
    CHANNEL_SECRET,
    LITELLM_API_KEY,
    LITELLM_BASE_URL,
    LM_MODEL,
) = (
    os.getenv(_var_name)
    for _var_name in (
        "CHANNEL_ACCESS_TOKEN",
        "CHANNEL_SECRET",
        "LITELLM_API_KEY",
        "LITELLM_BASE_URL",
        "LM_MODEL",
    )
)
# ---------- Retrieval settings (fixed constants) ----------
# How many sentences fuse_and_select() returns, and how BM25 and dense (FAISS) scores are mixed.
TOP_K_SENTENCES = 10
BM25_WEIGHT, SEM_WEIGHT = 0.8, 0.2

# Hugging Face model ids: bi-encoder for FAISS, cross-encoder for reranking.
EMBEDDING_MODEL_ID = "DMetaSoul/Dmeta-embedding-zh"
RERANKER_MODEL_ID = "BAAI/bge-reranker-base"
USE_CPU = True  # when True the model loaders pick "cpu" even if CUDA is available

RERANK_THRESHOLD = 0.5      # threshold handed to the reranking step
MAX_CONTEXT_CHARS = 8_000   # cap on the context text placed in the prompt
DISCLAIMER = "此回覆僅供參考,請遵循醫師/藥師指示。"
# Drug-name aliases and tokenizer stopwords.
# DRUG_NAME_MAPPING appears to map brand / Chinese names to a canonical drug name.
# NOTE(review): DRUG_NAME_MAPPING is not referenced anywhere else in this file — confirm it is still needed.
DRUG_NAME_MAPPING = {
    "fentanyl patch": "fentanyl",
    "spiriva respimat": "spiriva",
    "augmentin for syrup": "augmentin syrup",
    "nitrostat": "nitroglycerin",
    "ozempic": "ozempic",
    "niflec": "niflec",
    "fosamax": "alendronate",
    "humira": "adalimumab",
    "premarin": "premarin",
    "smecta": "smecta",
    "duragesic": "fentanyl",
    "芬太尼貼片": "fentanyl",
    "透皮止痛貼片": "fentanyl",
}
# Generic dosage-form / packaging / route words; tokenize_zh() drops these from jieba tokens.
DRUG_STOPWORDS = {"藥", "劑", "錠", "膠囊", "糖漿", "乳膏", "貼片", "含錠", "膜衣錠", "緩釋錠", "滴劑", "懸液", "注射液",
                  "吸入劑", "噴霧", "噴霧劑", "吸入器", "注射筆", "藥水", "小袋", "條", "包", "瓶", "外用", "口服"}
# Intent label -> trigger keywords (a dict, so each label carries its own keyword list).
# fuse_and_select() boosts sections via substring tests on these labels: only the storage
# ("保存"), administration ("操作") and side-effect ("副作用") labels contain the tested substrings.
INTENT_KEYWORDS = {
    "操作 (Administration)": ["操作", "使用", "怎麼用", "怎麼吃", "怎麼貼", "怎麼喝", "怎麼注射", "服用", "組裝", "安裝"],
    "保存/攜帶 (Storage & Handling)": ["保存", "儲存", "攜帶", "冷藏", "室溫", "潮濕", "保冰袋"],
    "副作用/異常 (Side Effects / Issues)": ["副作用", "異常", "拉肚子", "頭痛", "不良反應", "問題"],
    "劑型相關 (Dosage Form Concerns)": ["劑型", "錠", "膠囊", "糖漿", "乳膏", "貼片", "脫落", "剪", "泡"],
    "時間/併用 (Timing & Interaction)": ["時間", "併用", "同用", "間隔", "多久", "上限", "空腹", "咖啡", "酒精"],
    "劑量調整 (Dosage Adjustment)": ["劑量", "調整", "幾顆", "幾天", "忘記吃"],
    "禁忌症/適應症 (Contraindications/Indications)": ["禁忌", "適應症", "不能用", "適合"],
}
# Raw section titles -> canonical section names. _startup() rewrites every meta["section"]
# through this map, so retrieval code only ever sees the canonical names.
SECTION_NORMALIZE = {
    "用法用量": "用法及用量",
    "副作用不良反應": "不良反應",
    "警語注意事項": "警語及注意事項",
    "交互作用": "藥物交互作用",
    "包裝及儲存": "儲存條件",
    "儲存條件": "儲存條件"
}
# Per-section multiplier applied to fused retrieval scores (currently all neutral).
SECTION_WEIGHTS = {
    "用法及用量": 1.0,
    "病人使用須知": 1.0,
    "儲存條件": 1.0,
    "警語及注意事項": 1.0,
    "禁忌": 1.0,
    "副作用": 1.0,
    "藥物交互作用": 1.0,
    "其他": 1.0,
    "包裝及儲存": 1.0,
    "不良反應": 1.0,
}
# Sections whose first sentence for the matched drug is force-added to the rerank candidates.
# Entries must be canonical (post-SECTION_NORMALIZE) names: "包裝及儲存" is normalized to
# "儲存條件" at startup, so listing the raw title here could never match any sentence.
IMPORTANT_SECTIONS = ["用法及用量", "病人使用須知", "儲存條件", "不良反應", "警語及注意事項"]
# ---------- 路徑工具 ---------- | |
def pick_existing_or_tmp(candidates: List[str]) -> str:
    """Return the first path in *candidates* that exists on disk.

    If none exists, return ``/tmp/<basename of candidates[0]>`` (its directory is created if
    needed). An empty *candidates* list raises IndexError.
    """
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is not None:
        return found
    tmp_path = os.path.join("/tmp", os.path.basename(candidates[0]))
    pathlib.Path(tmp_path).parent.mkdir(parents=True, exist_ok=True)
    return tmp_path
# errno values meaning "this location is not writable" — the cases where retrying under /tmp helps.
_UNWRITABLE_ERRNOS = frozenset({errno.EACCES, errno.EPERM, errno.EROFS})


def _is_unwritable_error(exc: BaseException) -> bool:
    """Return True when *exc* signals a permission or read-only-filesystem problem.

    Python file I/O raises OSError carrying an errno. faiss is expected to surface C++ I/O
    failures as RuntimeError whose message contains the strerror() text, so that is matched by
    substring. NOTE(review): confirm the message format against the installed faiss build.
    """
    if isinstance(exc, OSError):
        return exc.errno in _UNWRITABLE_ERRNOS
    if isinstance(exc, RuntimeError):
        message = str(exc).lower()
        return "permission denied" in message or "read-only file system" in message
    return False


def _save_with_tmp_fallback(write, preferred_path: str, what: str) -> str:
    """Run ``write(path)`` for *preferred_path*, retrying under /tmp when that location is unwritable.

    Args:
        write: callable that writes the artifact to the destination path it is given.
        preferred_path: first-choice destination; its parent directory is created if missing.
        what: short label used in log messages (e.g. "pickle", "FAISS").

    Returns:
        The path actually written, or "" when saving failed. Failures are logged, never raised,
        because these files are only caches that can be rebuilt.
    """
    try:
        pathlib.Path(preferred_path).parent.mkdir(parents=True, exist_ok=True)
        write(preferred_path)
        return preferred_path
    except Exception as e:
        first_error = e

    alt_path = os.path.join("/tmp", os.path.basename(preferred_path))
    same_place = os.path.abspath(alt_path) == os.path.abspath(preferred_path)
    if same_place or not _is_unwritable_error(first_error):
        log.warning("%s write failed: %s", what, first_error)
        return ""
    try:
        write(alt_path)
    except Exception as ee:
        log.warning("Failed to save %s to /tmp as well: %s", what, ee)
        return ""
    log.warning("No write permission for %s, saved %s to %s instead.", preferred_path, what, alt_path)
    return alt_path


def safe_pickle_dump(obj: Any, preferred_path: str) -> str:
    """Pickle *obj* to *preferred_path* (or /tmp when unwritable); return the written path or ""."""
    def _write(path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(obj, f)

    return _save_with_tmp_fallback(_write, preferred_path, "pickle")


def safe_faiss_write(index, preferred_path: str) -> str:
    """Write a FAISS *index* to *preferred_path* (or /tmp when unwritable); return the written path or ""."""
    return _save_with_tmp_fallback(lambda path: faiss.write_index(index, path), preferred_path, "FAISS")
# ---------- 檔案路徑(優先專案根目錄,其次 /app,最後 /tmp) ---------- | |
CWD = os.getcwd()
# Each artifact is looked up in the working directory, then /app, then /tmp; when none exists
# pick_existing_or_tmp() falls back to /tmp/<name>. Assignment order matches the tuple below.
SENTENCES_PKL, FAISS_INDEX, BM25_PKL, CSV_PATH = (
    pick_existing_or_tmp([
        os.path.join(CWD, _artifact),
        os.path.join("/app", _artifact),
        os.path.join("/tmp", _artifact),
    ])
    for _artifact in ("drug_sentences.pkl", "drug_sentences.index", "bm25.pkl", "cleaned_combined.csv")
)
# ---------- FastAPI ---------- | |
# ASGI application object; the __main__ block at the bottom of the file serves it with uvicorn.
app = FastAPI(title="DrugQA (ZH) — LINE Webhook Only")
# ---------- Helpers ---------- | |
# Sentence terminators: CJK full stop (U+3002), full-width "!" (U+FF01) and "?" (U+FF1F),
# their ASCII counterparts, and newlines. Escapes keep the full-width forms unambiguous.
_ZH_SPLIT_RE = re.compile(r"[\u3002\uff01\uff1f!?\n]")


def split_sentences(text: str, min_chars: int = 7) -> List[str]:
    """Split Chinese *text* into stripped sentences, keeping those with at least *min_chars* chars.

    Handles both full-width and ASCII "!"/"?" terminators. Empty fragments are always dropped.
    Non-string input (e.g. a NaN cell from pandas) yields [].
    """
    if not isinstance(text, str):
        return []
    parts = (part.strip() for part in _ZH_SPLIT_RE.split(text))
    return [part for part in parts if part and len(part) >= min_chars]
def tokenize_zh(s: str) -> List[str]:
    """Tokenize *s* for BM25.

    With jieba available, returns jieba tokens minus blank tokens and DRUG_STOPWORDS; without
    it, falls back to whitespace splitting. Non-string or empty input yields [].
    """
    if not isinstance(s, str) or s == "":
        return []
    if jieba is None:
        return s.split()
    tokens: List[str] = []
    for token in jieba.lcut(s):
        if not token.strip() or token in DRUG_STOPWORDS:
            continue
        tokens.append(token)
    return tokens
class State:
    """Mutable runtime state shared by the startup hook and the request pipeline.

    All containers are created per instance in __init__ so that separate State objects never
    share (and mutate through) the same list/dict, as class-level ``[]``/``{}`` defaults would.
    """

    def __init__(self) -> None:
        self.sentences: List[str] = []            # corpus sentences, index-aligned with meta
        self.meta: List[Dict[str, Any]] = []      # per-sentence metadata dicts ("section", "drug_id", ...)
        self.emb_model: Optional[Any] = None      # bi-encoder used for FAISS, or None
        self.reranker_model: Optional[Any] = None # cross-encoder used for reranking, or None
        self.faiss_index: Optional[Any] = None
        self.bm25: Optional[Any] = None
        self.df_csv: Optional[pd.DataFrame] = None  # drug table loaded from CSV_PATH
        # NOTE(review): user_sessions is not read or written elsewhere in this file.
        self.user_sessions: Dict[str, Dict[str, Any]] = {}
        self.query_cache: Dict[str, Dict[str, Any]] = {}  # retrieval cache used by fuse_and_select


STATE = State()
# ---------- 載入與建立 ---------- | |
def ensure_sentences_meta() -> Tuple[List[str], List[Dict[str, Any]]]:
    """Load ``(sentences, meta)`` from SENTENCES_PKL.

    The pickle should hold a dict with "sentences" and "meta" entries; any other object yields
    two empty lists. A missing or unreadable file also yields ``([], [])`` after logging.
    NOTE: pickle.load can execute arbitrary code — SENTENCES_PKL must point at a trusted file.
    """
    if not os.path.exists(SENTENCES_PKL):
        log.info("Sentences pkl not found: %s", SENTENCES_PKL)
        return [], []
    try:
        with open(SENTENCES_PKL, "rb") as fh:
            payload = pickle.load(fh)
        if isinstance(payload, dict):
            sents, meta = payload.get("sentences", []), payload.get("meta", [])
        else:
            sents, meta = [], []
        log.info("Loaded sentences/meta: %s (n=%d)", SENTENCES_PKL, len(sents))
        return sents, meta
    except Exception as e:
        log.warning("Failed to load sentences pkl (%s). Corpus will be empty.", e)
    return [], []
def load_embedding_model(model_id: str):
    """Load a SentenceTransformer bi-encoder for *model_id*, or return None.

    None means sentence-transformers could not be imported (retrieval then uses BM25 only) or
    the model failed to load. The device is "cpu" when USE_CPU is set or torch is missing,
    otherwise "cuda" if available.
    """
    if SentenceTransformer is None:
        log.warning("sentence-transformers 不可用;僅以 BM25 檢索。")
        return None
    if USE_CPU or torch is None:
        device = "cpu"
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    log.info("Load SentenceTransformer: %s on %s", model_id, device)
    try:
        model = SentenceTransformer(model_id, device=device)
    except Exception as e:
        log.warning("Failed to load embedding model: %s", e)
        return None
    return model
def load_reranker_model(model_id: str):
    """Load a sentence-transformers CrossEncoder for *model_id*, or return None (no reranking).

    The class is resolved here instead of through the module-level name: the top-of-file import
    fallback only assigns ``SentenceTransformer = None``, so when sentence-transformers is
    missing the global ``CrossEncoder`` is never defined and referencing it raised NameError.
    The device is "cpu" when USE_CPU is set or torch is missing, otherwise "cuda" if available.
    """
    try:
        from sentence_transformers import CrossEncoder as cross_encoder_cls  # type: ignore
    except Exception:
        cross_encoder_cls = None
    if cross_encoder_cls is None:
        log.warning("CrossEncoder 不可用;略過 rerank。")
        return None
    device = "cpu" if (USE_CPU or (torch is None)) else ("cuda" if torch.cuda.is_available() else "cpu")
    log.info("Load CrossEncoder: %s on %s", model_id, device)
    try:
        return cross_encoder_cls(model_id, device=device)
    except Exception as e:
        log.warning("Failed to load reranker model: %s", e)
        return None
def ensure_faiss(index_path: str, sentences: List[str]) -> Optional[Any]:
    """Load the FAISS index at *index_path*, rebuilding it when it is stale or missing.

    A cached index is reused only when its ``ntotal`` equals ``len(sentences)``. Otherwise it is
    rebuilt from STATE.emb_model embeddings (normalized, stored in an inner-product index) and
    persisted with safe_faiss_write(). Returns None when faiss is unavailable, the corpus is
    empty, no embedding model is loaded, or the embeddings have an unexpected shape.
    """
    if not faiss:
        return None
    if not sentences:
        # Nothing to index; building from an empty corpus would fail at the (n, d) shape access.
        log.warning("Empty corpus; skip FAISS.")
        return None
    if os.path.exists(index_path):
        try:
            idx = faiss.read_index(index_path)
            if idx.ntotal == len(sentences):
                log.info("Loaded FAISS: %s (d=%d n=%d)", index_path, idx.d, idx.ntotal)
                return idx
            log.warning("FAISS ntotal mismatch (%d != %d). Rebuilding.", idx.ntotal, len(sentences))
        except Exception as e:
            log.warning("Failed to load FAISS (%s): %s", index_path, e)
    if STATE.emb_model is None:
        log.warning("No emb_model; skip FAISS build.")
        return None
    log.info("Building FAISS (n=%d)...", len(sentences))
    embeds = np.ascontiguousarray(
        STATE.emb_model.encode(sentences, normalize_embeddings=True, show_progress_bar=True),
        dtype=np.float32,
    )
    if embeds.ndim != 2 or embeds.shape[0] != len(sentences):
        log.warning("Unexpected embedding shape %s; skip FAISS build.", embeds.shape)
        return None
    idx = faiss.IndexFlatIP(embeds.shape[1])
    idx.add(embeds)
    safe_faiss_write(idx, index_path)
    return idx
def _bm25_doc_count(bm: Any) -> int:
    """Return how many documents a pickled rank_bm25 model was built from (-1 if unknown).

    rank_bm25's BM25 classes record this as ``corpus_size`` (plus one ``doc_len`` entry per
    document) and do not keep a ``corpus`` attribute, so probing only ``corpus`` always
    reported 0 and forced a rebuild on every start.
    """
    n = getattr(bm, "corpus_size", None)
    if isinstance(n, int):
        return n
    doc_len = getattr(bm, "doc_len", None)
    if doc_len is not None:
        return len(doc_len)
    corpus = getattr(bm, "corpus", None)
    return len(corpus) if corpus is not None else -1


def ensure_bm25(pkl_path: str, sentences: List[str]) -> Optional[Any]:
    """Load the pickled BM25Okapi model at *pkl_path*, rebuilding it when stale or missing.

    The cached model is reused only if it was built from ``len(sentences)`` documents; otherwise
    a new model is built from tokenize_zh() tokens and persisted with safe_pickle_dump().
    Returns None when rank_bm25 is unavailable or the corpus is empty (rank_bm25 divides by the
    corpus size while initializing).
    NOTE: pickle.load can execute arbitrary code — *pkl_path* must be a trusted file.
    """
    if BM25Okapi is None:
        return None
    if not sentences:
        log.warning("Empty corpus; skip BM25.")
        return None
    if os.path.exists(pkl_path):
        try:
            with open(pkl_path, "rb") as f:
                bm = pickle.load(f)
            n_bm = _bm25_doc_count(bm)
            if n_bm == len(sentences):
                log.info("Loaded BM25: %s (n=%d)", pkl_path, n_bm)
                return bm
            log.warning("BM25 corpus size mismatch (%d != %d). Rebuilding.", n_bm, len(sentences))
        except Exception as e:
            log.warning("Failed to load BM25 (%s): %s", pkl_path, e)
    log.info("Building BM25 (n=%d)...", len(sentences))
    bm = BM25Okapi([tokenize_zh(s) for s in sentences])
    safe_pickle_dump(bm, pkl_path)
    return bm
# ---------- 資訊解析與藥名處理 (簡化) ---------- | |
# ---------- Drug-name matching against the CSV drug table ----------


def _clean_drug_name(value: Any) -> str:
    """Lower-case *value* when it is a string; anything else (NaN / None cells) becomes ""."""
    return value.lower() if isinstance(value, str) else ""


def parse_user_message(query: str, df: pd.DataFrame, min_score: int = 80) -> Dict[str, Any]:
    """Find the single drug in *df* that best matches *query*.

    A case-insensitive exact match on ``drug_name_norm`` is accepted immediately (no fuzzy
    library needed). Otherwise each unique ``drug_id`` is scored with fuzzywuzzy
    ``token_set_ratio`` and the best one is kept if it scores at least *min_score*.

    Returns:
        ``{"drug_name", "drug_id", "question"}``; drug_name and drug_id are None when nothing
        matched confidently. "question" is always the original *query*.
    """
    no_match = {"drug_name": None, "drug_id": None, "question": query}
    if df is None or df.empty or not isinstance(query, str):
        return no_match

    unique_drugs = df.drop_duplicates(subset=["drug_id"])
    query_lower = query.lower().strip()

    # Exact name match first; it is accepted as-is without any fuzzy score threshold.
    direct_match = unique_drugs[unique_drugs["drug_name_norm"].str.lower() == query_lower]
    if not direct_match.empty:
        row = direct_match.iloc[0]
        log.info("Direct match found: %s", row["drug_name_norm"])
        return {"drug_name": row["drug_name_norm"], "drug_id": row["drug_id"], "question": query}

    if not fuzz:
        log.warning("fuzzywuzzy not available; skipping fuzzy match.")
        return no_match

    best_id, best_name, max_score = None, None, 0
    for drug_id, raw_name in zip(unique_drugs["drug_id"], unique_drugs["drug_name_norm"]):
        name = _clean_drug_name(raw_name)
        score = fuzz.token_set_ratio(query_lower, name)
        if score > max_score:  # strict ">" keeps the first drug on ties
            best_id, best_name, max_score = drug_id, name, score

    if best_name is None or max_score < min_score:
        log.warning("No confident drug match found (score: %s)", max_score)
        return no_match
    log.info("Parsed user message (best match): %s, score: %s", best_name, max_score)
    return {"drug_name": best_name, "drug_id": best_id, "question": query}


def find_drug_candidates(parsed_info: Dict[str, Any], df: pd.DataFrame, top_k: int = 5) -> List[Dict[str, Any]]:
    """Score every unique drug in *df* against ``parsed_info["question"]``; return the best *top_k*.

    Each candidate is ``{"drug_id", "drug_name" (lower-cased), "score" (token_set_ratio)}``,
    sorted by descending score with ties in table order. Returns [] when the table or the
    question is empty, or fuzzywuzzy is unavailable.
    """
    query_text = (parsed_info.get("question") or "").lower()
    if df is None or df.empty or not query_text or not fuzz:
        return []

    unique_drugs = df.drop_duplicates(subset=["drug_id"])
    candidates = []
    for drug_id, raw_name in zip(unique_drugs["drug_id"], unique_drugs["drug_name_norm"]):
        name = _clean_drug_name(raw_name)
        candidates.append({
            "drug_id": drug_id,
            "drug_name": name,
            "score": fuzz.token_set_ratio(query_text, name),
        })
    candidates.sort(key=lambda c: c["score"], reverse=True)  # stable sort keeps table order on ties
    top = candidates[:top_k]
    log.info("Found drug candidates: %s", top)
    return top
# ---------- Question-answering pipeline ----------
async def answer_pipeline(query: str, user_id: str) -> str:
    """Answer one user question end-to-end and return the text to send back.

    Steps: match a drug in STATE.df_csv -> ask the user to disambiguate when the match is not
    clearly ahead -> retrieve that drug's sentences -> build the prompt -> call the LLM.
    Expected failures (bad input, no corpus, no context, LLM error) return handle_error() /
    clarification text instead of raising.

    NOTE(review): retrieval and call_llm() are synchronous, so they block the event loop while
    they run.
    """
    if not query or not isinstance(query, str):
        return handle_error("INVALID_QUERY")
    log.info("Pipeline start for user_id: %s, query: %s", user_id, query[:50])

    df = STATE.df_csv
    # A DataFrame has no truth value (``not df`` raises ValueError), so test None/empty explicitly.
    if not STATE.sentences or df is None or df.empty:
        return handle_error("NO_CORPUS")

    # 1. Best-matching drug for the message.
    best_drug_info = parse_user_message(query, df)
    if not best_drug_info.get("drug_id"):
        log.warning("No confident drug match found.")
        return make_clarify_message()

    # 2-3. Ask the user to pick unless the top candidate is clearly ahead (score >= 95, or more
    # than 10 points above the runner-up). Without a candidate list (e.g. an exact-name match
    # while fuzzy matching is unavailable) the parsed match is used directly.
    drug_candidates = find_drug_candidates(best_drug_info, df)
    if drug_candidates:
        top_score = drug_candidates[0]["score"]
        second_score = drug_candidates[1]["score"] if len(drug_candidates) > 1 else 0
        if not (top_score >= 95 or (top_score - second_score) > 10):
            log.info("Scores are too close, requesting clarification.")
            options = [f"「{c.get('drug_name')}」" for c in drug_candidates[:3]]
            return "請問您指的是以下哪一種藥物?\n- " + "\n- ".join(options) + f"\n\n{DISCLAIMER}"
    drug_choice = best_drug_info
    log.info("Confidently selected drug: %s", drug_choice["drug_name"])

    # 4. Retrieve supporting sentences restricted to the chosen drug.
    idxs = fuse_and_select(
        query=best_drug_info["question"],
        sentences=STATE.sentences,
        meta=STATE.meta,
        bm25=STATE.bm25,
        index=STATE.faiss_index,
        emb_model=STATE.emb_model,
        reranker=STATE.reranker_model,
        top_k=TOP_K_SENTENCES,
        drug_id=drug_choice["drug_id"],
    )
    if not idxs:
        return handle_error("NO_CONTEXT")

    # 5. Context + prompt. The prompt embeds the user's question, so keep it out of INFO logs.
    context = build_context(idxs, STATE.sentences, STATE.meta)
    prompt = build_prompt(best_drug_info, context, drug_choice)
    log.debug("Generated Prompt:\n%s", prompt)

    # 6. Generate the answer.
    answer = call_llm(prompt)
    if not answer:
        return handle_error("LLM_ERROR")
    return f"{answer}\n\n{DISCLAIMER}"
# ---------- Prompt construction ----------
def build_prompt(parsed_info: Dict[str, Any], contexts: str, drug_choice: Dict[str, Any]) -> str:
    """Assemble the single-turn LLM prompt.

    Layout: pharmacist persona and answering rules, then the drug name (from *drug_choice*),
    the user's question (from *parsed_info*) and the retrieved reference snippets (*contexts*).
    """
    drug_name = drug_choice.get("drug_name")
    question = parsed_info.get("question")
    lines = [
        "你是一位專業、有同理心的藥師。請根據提供的「參考片段」,簡潔地回答使用者的「問題」。",
        "---限制---",
        "- 絕對忠於「參考片段」,不可捏造或過度推論。你的知識僅限於提供的片段。",
        "- 回覆少於 120 字,並使用繁體中文條列式 2-4 點說明。",
        "- 語氣親切、精簡、專業。",
        "- 若片段中無足夠資訊回答,必須回覆:「根據提供的資料,我無法找到關於您問題的明確答案,建議您諮詢醫師或藥師。」",
        "---輸入資訊---",
        f"藥物名稱: {drug_name}",
        f"問題: {question}",
        "",
        "參考片段:",
        f"{contexts}",
        "---你的回答---",
    ]
    return "\n".join(lines)
# OpenAI-compatible client, created on first use and reused so each request does not build a
# fresh HTTP client / connection pool.
_LLM_CLIENT: Optional[Any] = None


def call_llm(prompt: str, max_tokens: int = 2048) -> Optional[str]:
    """Send *prompt* as one user message to the LiteLLM (OpenAI-compatible) endpoint.

    Uses LITELLM_BASE_URL / LITELLM_API_KEY / LM_MODEL with temperature 0.1 and a 15 s request
    timeout. Returns the stripped reply text, or None when the openai package is missing, the
    LLM settings are incomplete, or client creation / the request fails (all logged).
    """
    global _LLM_CLIENT
    try:
        from openai import OpenAI
    except Exception as e:
        log.warning("openai client 不可用:%s", e)
        return None
    if not (LITELLM_API_KEY and LM_MODEL and LITELLM_BASE_URL):
        log.warning("LLM 未完整設定;略過生成。")
        return None
    try:
        if _LLM_CLIENT is None:
            _LLM_CLIENT = OpenAI(base_url=LITELLM_BASE_URL, api_key=LITELLM_API_KEY)
        t0 = time.time()
        resp = _LLM_CLIENT.chat.completions.create(
            model=LM_MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            timeout=15,
            max_tokens=max_tokens,
        )
        log.info("LLM ok (%.2fs)", time.time() - t0)
        if not resp.choices:
            log.warning("LLM 失敗:%s", "empty choices")
            return None
        return (resp.choices[0].message.content or "").strip()
    except Exception as e:
        log.warning("LLM 失敗:%s", e)
        return None
def make_clarify_message() -> str:
    """Return the reply asking the user for the full drug name and a concrete question."""
    parts = [
        "我需要更多資訊才能準確回答,請您提供:",
        "1. 完整的藥物名稱",
        "2. 您的具體問題",
        "",
        DISCLAIMER,
    ]
    return "\n".join(parts)
def handle_error(code: str) -> str:
    """Log *code* as a pipeline error and return the apology text shown to the user."""
    log.error("Pipeline error: %s", code)
    apology = f"抱歉,系統暫時無法回覆 ({code})。請諮詢醫師或藥師。"
    return apology + DISCLAIMER
# ---------- Hybrid retrieval: BM25 + FAISS, section/intent boosts, reranking ----------


def _detect_intents(query: str) -> List[str]:
    """Return the INTENT_KEYWORDS labels with at least one keyword contained in *query*."""
    return [label for label, keywords in INTENT_KEYWORDS.items() if any(k in query for k in keywords)]


def _rerank_candidates(
    query: str,
    candidates: List[Tuple[int, float]],
    sentences: List[str],
    reranker: Optional[Any],
    top_k: int,
    threshold: float,
) -> List[Dict[str, Any]]:
    """Order ``(sentence_idx, fused_score)`` candidates and keep at most *top_k*.

    Without a reranker the fused score decides. With one, the best ``4 * top_k`` fused
    candidates are re-scored via ``reranker.predict`` on (query, sentence) pairs and those
    scoring >= *threshold* are returned in reranker order. If none pass, or scoring fails, the
    fused order is returned so the caller still gets context.

    Returns a list of ``{"idx": int, "score": float}``.
    """
    fused = sorted(candidates, key=lambda c: c[1], reverse=True)
    fallback = [{"idx": i, "score": float(s)} for i, s in fused[:top_k]]
    if reranker is None or not fused:
        return fallback
    pool = fused[: max(top_k, 1) * 4]  # bound cross-encoder work; every drug sentence can be a candidate
    try:
        rr_scores = reranker.predict([(query, sentences[i]) for i, _ in pool])
    except Exception as e:
        log.warning("Rerank failed; using fused scores: %s", e)
        return fallback
    rescored = sorted(
        ({"idx": i, "score": float(s)} for (i, _), s in zip(pool, rr_scores)),
        key=lambda r: r["score"],
        reverse=True,
    )
    kept = [r for r in rescored if r["score"] >= threshold]
    if not kept:
        log.info("No candidate passed rerank threshold %.2f; using fused order.", threshold)
        return fallback
    return kept[:top_k]


def _prune_query_cache(now: float, ttl: float) -> None:
    """Drop STATE.query_cache entries older than *ttl* seconds so the cache stays bounded."""
    stale = [key for key, entry in STATE.query_cache.items() if now - entry["time"] >= ttl]
    for key in stale:
        del STATE.query_cache[key]


def fuse_and_select(
    query: str,
    sentences: List[str],
    meta: List[Dict[str, Any]],
    bm25: Optional[Any],
    index: Optional[Any],
    emb_model: Optional[Any],
    reranker: Optional[Any],
    top_k: int = 10,
    drug_id: str = None,
    cache_ttl: float = 180.0,
) -> List[int]:
    """Return up to *top_k* sentence indices relevant to *query*, restricted to *drug_id*.

    Scoring: min-max-normalized BM25 (weight BM25_WEIGHT) plus a reciprocal-rank FAISS term
    (weight SEM_WEIGHT), multiplied by SECTION_WEIGHTS and by 1.5 for each detected intent that
    matches the sentence's section. The first sentence of every IMPORTANT_SECTIONS entry for
    the drug is added when missing, then candidates are reranked. Results are cached in
    STATE.query_cache for *cache_ttl* seconds.

    NOTE(review): intent detection and reranking use the private helpers above; the names
    ``detect_intent`` / ``rerank_results`` referenced previously have no definition in this
    module — confirm the intended behaviour.
    """
    clean_query = query.strip().lower()
    # Unit-separator-joined key: unambiguous, unlike plain concatenation of query and drug id.
    cache_key = f"{drug_id}\x1f{top_k}\x1f{clean_query}"
    now = time.time()
    cached = STATE.query_cache.get(cache_key)
    if cached is not None and now - cached["time"] < cache_ttl:
        log.info("Cache hit for query: %s", clean_query[:50])
        return cached["idxs"]

    log.info("Searching for drug_id: %s with query: %s", drug_id, clean_query[:50])
    if not drug_id:
        log.warning("No drug_id provided; falling back to full corpus search.")

    def in_scope(i: int) -> bool:
        # Valid metadata row and, when a drug is selected, belonging to that drug.
        return 0 <= i < len(meta) and (not drug_id or meta[i].get("drug_id") == drug_id)

    scores: Dict[int, float] = {}

    # BM25 lexical scores, min-max normalized to [0, 1]; all-equal scores carry no signal -> 0.
    if bm25 is not None:
        raw = np.asarray(bm25.get_scores(tokenize_zh(clean_query)), dtype=float)
        if raw.size:
            lo, hi = float(raw.min()), float(raw.max())
            norm = (raw - lo) / (hi - lo) if hi > lo else np.zeros_like(raw)
            for i, s in enumerate(norm.tolist()):
                if in_scope(i):
                    scores[i] = scores.get(i, 0.0) + BM25_WEIGHT * s

    # FAISS semantic hits, weighted by reciprocal rank (faiss pads missing hits with -1).
    if emb_model is not None and index is not None:
        q_emb = emb_model.encode([clean_query], normalize_embeddings=True).astype(np.float32)
        _, hits = index.search(q_emb, top_k * 8)
        for rank, i in enumerate(hits[0].tolist()):
            if in_scope(i):
                scores[i] = scores.get(i, 0.0) + SEM_WEIGHT * (1.0 / (1 + rank))

    # Section weight and intent boosts, combined into one multiplier per sentence.
    intents = _detect_intents(clean_query)
    for i in scores:
        sec = meta[i].get("section", "其他")
        factor = SECTION_WEIGHTS.get(sec, 1.0)
        for intent in intents:
            if ("保存" in intent or "儲存" in intent) and sec in ("儲存條件", "包裝及儲存"):
                factor *= 1.5
            elif ("使用" in intent or "操作" in intent) and sec in ("用法及用量", "病人使用須知"):
                factor *= 1.5
            elif ("副作用" in intent or "不良反應" in intent) and sec == "不良反應":
                factor *= 1.5
        scores[i] *= factor

    # Force the first sentence of each important section for this drug into the candidate set
    # (single pass over meta). Skipped without a drug, where "first sentence" is meaningless.
    if drug_id:
        wanted = set(IMPORTANT_SECTIONS)
        first_in_section: Dict[str, int] = {}
        for i, m in enumerate(meta):
            sec = m.get("section")
            if sec in wanted and m.get("drug_id") == drug_id:
                first_in_section.setdefault(sec, i)
        for i in first_in_section.values():
            scores.setdefault(i, 1.0)  # fixed score so it survives into reranking

    reranked = _rerank_candidates(clean_query, list(scores.items()), sentences, reranker, top_k, RERANK_THRESHOLD)
    idxs = [r["idx"] for r in reranked]
    _prune_query_cache(now, cache_ttl)
    STATE.query_cache[cache_key] = {"idxs": idxs, "time": now}
    return idxs
def build_context(
    idxs: List[int],
    sentences: List[str],
    meta: List[Dict[str, Any]],
    max_chars: Optional[int] = None,
) -> str:
    """Render the selected sentences as ``[section]: text`` lines for the prompt.

    Indices outside ``sentences``/``meta`` are skipped and duplicate texts are kept once. Lines
    are added in *idxs* order until the newline-joined result would exceed *max_chars*
    (default MAX_CONTEXT_CHARS). Returns a fixed "no data" line when nothing was added.
    """
    limit = MAX_CONTEXT_CHARS if max_chars is None else max_chars
    upper = min(len(sentences), len(meta))
    lines: List[str] = []
    seen = set()
    total_len = 0  # characters used so far, counting one joining newline per added line
    for i in idxs:
        if not 0 <= i < upper:
            continue
        text = sentences[i]
        if text in seen:
            continue
        line = f"[{meta[i].get('section', '未知章節')}]: {text}"
        if total_len + len(line) > limit:
            break
        lines.append(line)
        total_len += len(line) + 1
        seen.add(text)
    return "\n".join(lines) or "[未知章節]: 沒有找到相關資料,請諮詢醫師或藥師。"
# ---------- LINE 驗簽與回覆 ---------- | |
def verify_line_signature(body_bytes: bytes, signature: str) -> bool:
    """Check the LINE signature: base64(HMAC-SHA256(CHANNEL_SECRET, raw body)) == *signature*.

    When CHANNEL_SECRET is unset the check is skipped and True is returned (testing only — every
    request is then accepted). Any error while verifying counts as a failed check.
    """
    if CHANNEL_SECRET:
        try:
            digest = hmac.new(CHANNEL_SECRET.encode("utf-8"), body_bytes, hashlib.sha256).digest()
            return hmac.compare_digest(base64.b64encode(digest).decode("utf-8"), signature)
        except Exception as e:
            log.warning("簽章驗證錯誤:%s", e)
            return False
    log.warning("CHANNEL_SECRET 未設定;跳過簽章驗證(僅供測試)。")
    return True
def line_reply(reply_token: str, text: str) -> None:
    """Send *text* (truncated to 4900 chars) as one text message through LINE's reply API.

    Best effort: missing credentials or ``requests``, non-200 responses and request exceptions
    are logged and otherwise ignored.
    """
    if requests is None or not CHANNEL_ACCESS_TOKEN:
        log.warning("缺少 CHANNEL_ACCESS_TOKEN 或 requests;略過回覆。")
        return
    payload = {
        "replyToken": reply_token,
        "messages": [{"type": "text", "text": text[:4900]}],
    }
    auth_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {CHANNEL_ACCESS_TOKEN}",
    }
    try:
        resp = requests.post(
            "https://api.line.me/v2/bot/message/reply",
            headers=auth_headers,
            json=payload,
            timeout=10,
        )
        if resp.status_code != 200:
            log.warning("LINE 回覆失敗:%s %s", resp.status_code, resp.text[:200])
    except Exception as e:
        log.warning("LINE 回覆例外:%s", e)
# ---------- 只有這一條路由:POST /webhook ---------- | |
# Registered explicitly: without the route decorator FastAPI never dispatches to this function.
@app.post("/webhook")
async def webhook(request: Request, x_line_signature: str = Header(default="")):
    """LINE Messaging API webhook (POST /webhook).

    Verifies ``X-Line-Signature`` against the raw body, answers every text-message event with
    answer_pipeline(), and replies using the event's reply token. Other events are ignored.

    Raises:
        HTTPException: 401 for an invalid signature, 400 when the body is not a JSON object.
    """
    body = await request.body()
    if not verify_line_signature(body, x_line_signature):
        raise HTTPException(status_code=401, detail="Invalid LINE signature")
    try:
        # Parse exactly the bytes whose signature was just verified.
        payload = json.loads(body)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body") from None
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON body")

    for ev in payload.get("events") or []:
        if not isinstance(ev, dict):
            continue
        message = ev.get("message")
        if ev.get("type") != "message" or not isinstance(message, dict) or message.get("type") != "text":
            continue
        reply_token = ev.get("replyToken")
        source = ev.get("source")
        user_id = source.get("userId", "unknown") if isinstance(source, dict) else "unknown"
        user_text = (message.get("text") or "").strip()
        try:
            answer = await answer_pipeline(user_text, user_id)
        except Exception as e:
            log.exception("Pipeline 失敗:%s", e)
            answer = "抱歉,系統暫時無法回覆。"
        if reply_token:
            line_reply(reply_token, answer)
    return {"ok": True}
# ---------- 啟動 ---------- | |
# Registered explicitly: without this hook STATE would never be populated.
@app.on_event("startup")
async def _startup():
    """Load corpus, models, indexes and the drug table into STATE once at server start.

    The embedding model is loaded before ensure_faiss() because that function reads
    STATE.emb_model when it has to rebuild the index. Section names in STATE.meta are then
    normalized through SECTION_NORMALIZE.
    """
    log.info("===== Application Startup =====")
    if torch is not None:
        log.info("PyTorch version %s available.", getattr(torch, "__version__", "?"))

    # Corpus, models and indexes.
    STATE.sentences, STATE.meta = ensure_sentences_meta()
    STATE.emb_model = load_embedding_model(EMBEDDING_MODEL_ID)
    STATE.reranker_model = load_reranker_model(RERANKER_MODEL_ID)
    STATE.faiss_index = ensure_faiss(FAISS_INDEX, STATE.sentences)
    STATE.bm25 = ensure_bm25(BM25_PKL, STATE.sentences)
    for m in STATE.meta:
        sec = m.get("section", "其他")
        m["section"] = SECTION_NORMALIZE.get(sec, sec)

    # Drug table used for drug-name matching.
    if os.path.exists(CSV_PATH):
        try:
            STATE.df_csv = pd.read_csv(CSV_PATH, dtype=str)
            log.info("Loaded drug table: %s (rows=%d)", CSV_PATH, len(STATE.df_csv))
        except Exception as e:
            # Keep serving; with df_csv left as None the pipeline answers with NO_CORPUS.
            log.warning("Failed to read CSV %s: %s", CSV_PATH, e)
    else:
        log.warning("Drug table CSV not found: %s", CSV_PATH)

    log.info("LLM via LiteLLM: base=%s model=%s", str(LITELLM_BASE_URL), str(LM_MODEL))
    log.info("Startup complete.")
async def health():
    """Return a static liveness payload.

    NOTE(review): no route decorator registers this coroutine in this file, so it is not
    reachable over HTTP as written — confirm whether a GET route was intended.
    """
    payload = {"status": "healthy"}
    return payload
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "app:app" import string. The string only resolves
    # when this file is named app.py, and even then uvicorn imports it again as a separate
    # module (the script itself runs as __main__), re-executing all module-level setup.
    # An object is allowed here because reload is disabled.
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port, log_level=LOG_LEVEL.lower(), reload=False)