Spaces:
Running
Running
Song
committed on
Commit
·
92ee3c2
1
Parent(s):
7b2e5cd
hi
Browse files
- app.py +105 -143
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,14 +1,3 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
|
4 |
-
"""
|
5 |
-
DrugQA (ZH) — 優化版 FastAPI LINE Webhook (最終版)
|
6 |
-
整合 RAG 邏輯,包含 LLM 意圖偵測、子查詢分解、Intent-aware 檢索與 Rerank。
|
7 |
-
新增動態字數調整、多次互動邏輯與對話狀態管理,提升使用者體驗。
|
8 |
-
僅支援十種藥物。
|
9 |
-
"""
|
10 |
-
|
11 |
-
# ---------- 環境與快取設定 ----------
|
12 |
import os
|
13 |
import pathlib
|
14 |
import re
|
@@ -28,9 +17,8 @@ from contextlib import asynccontextmanager
|
|
28 |
import unicodedata
|
29 |
from collections import defaultdict
|
30 |
import asyncio
|
31 |
-
import aiohttp # 新增:導入 aiohttp 用於異步 HTTP 請求
|
32 |
|
33 |
-
#
|
34 |
import numpy as np
|
35 |
import pandas as pd
|
36 |
import jieba
|
@@ -44,7 +32,7 @@ import requests
|
|
44 |
import uvicorn
|
45 |
from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
|
46 |
|
47 |
-
#
|
48 |
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1")))
|
49 |
|
50 |
# ===== CONFIG =====
|
@@ -55,30 +43,24 @@ def _require_env(var: str) -> str:
|
|
55 |
raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
|
56 |
return v
|
57 |
|
58 |
-
|
59 |
def _require_llm_config():
|
60 |
for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
|
61 |
_require_env(k)
|
62 |
|
63 |
-
|
64 |
# --------- 路徑設定 ------------
|
65 |
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
|
66 |
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
|
67 |
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
|
68 |
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
|
69 |
-
|
70 |
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 20))
|
71 |
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
|
72 |
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
|
73 |
-
|
74 |
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
|
75 |
-
|
76 |
LLM_API_CONFIG = {
|
77 |
"base_url": _require_env("LITELLM_BASE_URL"),
|
78 |
"api_key": _require_env("LITELLM_API_KEY"),
|
79 |
"model": _require_env("LM_MODEL"),
|
80 |
}
|
81 |
-
|
82 |
LLM_MODEL_CONFIG = {
|
83 |
"max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 10000)),
|
84 |
"max_tokens_simple": int(os.getenv("MAX_TOKENS_SIMPLE", 256)),
|
@@ -95,7 +77,6 @@ INTENT_CATEGORIES = [
|
|
95 |
"劑量調整 (Dosage Adjustment)",
|
96 |
"禁忌症/適應症 (Contraindications/Indications)",
|
97 |
]
|
98 |
-
|
99 |
INTENT_TO_SECTION = {
|
100 |
"操作 (Administration)": ["用法用量", "病人使用須知"],
|
101 |
"保存/攜帶 (Storage & Handling)": ["包裝及儲存"],
|
@@ -105,7 +86,6 @@ INTENT_TO_SECTION = {
|
|
105 |
"劑量調整 (Dosage Adjustment)": ["用法用量"],
|
106 |
"禁忌症/適應症 (Contraindications/Indications)": ["適應症", "禁忌", "警語與注意事項"],
|
107 |
}
|
108 |
-
|
109 |
DRUG_NAME_MAPPING = {
|
110 |
"fentanyl patch": "fentanyl",
|
111 |
"spiriva respimat": "spiriva",
|
@@ -122,7 +102,6 @@ SUPPORTED_DRUGS = list(DRUG_NAME_MAPPING.keys())
|
|
122 |
DISCLAIMER = (
|
123 |
"本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
|
124 |
)
|
125 |
-
|
126 |
REFERENCE_MAPPING = {
|
127 |
"如何用藥?": "病人使用須知、用法用量",
|
128 |
"如何保存與攜帶?": "包裝及儲存",
|
@@ -130,7 +109,6 @@ REFERENCE_MAPPING = {
|
|
130 |
"每次劑量多少?": "用法用量、藥袋上的醫囑",
|
131 |
"用藥時間?": "用法用量、藥袋上的醫囑",
|
132 |
}
|
133 |
-
|
134 |
REFERENCE_TO_INTENT = {
|
135 |
"如何用藥?": ["操作 (Administration)"],
|
136 |
"如何保存與攜帶?": ["保存/攜帶 (Storage & Handling)"],
|
@@ -138,20 +116,16 @@ REFERENCE_TO_INTENT = {
|
|
138 |
"每次劑量多少?": ["劑量調整 (Dosage Adjustment)"],
|
139 |
"用藥時間?": ["時間/併用 (Timing & Interaction)"],
|
140 |
}
|
141 |
-
|
142 |
PROMPT_TEMPLATES = {
|
143 |
"analyze_query": """
|
144 |
請分析以下使用者問題,並完成以下三個任務:
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列)、'intents' (字串陣列) 和 'complexity' (字串) 三個鍵。
|
150 |
範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"], "complexity": "simple"}}
|
151 |
-
|
152 |
意圖分類清單:
|
153 |
{options}。
|
154 |
-
|
155 |
使用者問題:{query}
|
156 |
""",
|
157 |
"expand_query": """
|
@@ -161,54 +135,41 @@ PROMPT_TEMPLATES = {
|
|
161 |
""",
|
162 |
"final_answer": """
|
163 |
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆:
|
164 |
-
|
165 |
一、 回覆規範:
|
166 |
-
- 回覆語言:使用繁體中文,口語化且易懂,避免專業術語或解釋之。
|
167 |
-
- 結構:先以「簡答:」標記提供簡短總結答案(50-100字),然後以「詳答:」標記提供詳細解釋,最後提醒使用者諮詢醫師。
|
168 |
-
- 長度:簡答控制在50-100字,詳答根據問題複雜度調整,簡單問題約100-200字,複雜問題(如多步驟的裝置安裝或藥品使用)可達300-500字。
|
169 |
-
- 態度:親切、專業、關懷,避免驚嚇使用者。
|
170 |
-
{additional_instruction}
|
171 |
-
|
172 |
-
---
|
173 |
-
參考資料:
|
174 |
-
{context}
|
175 |
-
---
|
176 |
|
|
|
|
|
|
|
|
|
|
|
177 |
使用者問題:{query}
|
178 |
-
|
179 |
請直接輸出最終的答案:
|
180 |
""",
|
181 |
"analyze_reference": """
|
182 |
從以下清單選擇最匹配的使用者問題分類,如果沒有匹配,返回 'none'。
|
183 |
-
|
184 |
分類清單:
|
185 |
{options}
|
186 |
-
|
187 |
使用者問題:{query}
|
188 |
-
|
189 |
請僅輸出分類名稱或 'none',不需任何額外的解釋或格式。
|
190 |
""",
|
191 |
"clarification": """
|
192 |
請根據以下使用者問題,生成一個簡潔、禮貌的澄清性提問,以幫助我更精準地回答。問題應引導使用者提供更多細節,例如具體藥名、使用情境,並附上範例問題。請在回覆中明確告知使用者,目前僅支援以下藥物詢問:
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
範例:
|
205 |
使用者問題:這個藥會怎麼樣?
|
206 |
澄清提問:您好,請問您指的藥物是下列哪一種?目前僅支援以下藥物詢問:Fentanyl patch、Spiriva Respimat...等。例如,您可以問:「Fentanyl patch 的副作用有哪些?」請確認藥名或提供更多細節。
|
207 |
-
|
208 |
使用者問題:{query}
|
209 |
"""
|
210 |
}
|
211 |
-
|
212 |
# ---------- 日誌設定 ----------
|
213 |
logging.basicConfig(
|
214 |
level=logging.INFO,
|
@@ -222,7 +183,6 @@ def _norm(s: str) -> str:
|
|
222 |
s = unicodedata.normalize("NFKC", s)
|
223 |
return re.sub(r"[^\w\s]", "", s.lower()).strip()
|
224 |
|
225 |
-
|
226 |
@dataclass
|
227 |
class FusedCandidate:
|
228 |
idx: int
|
@@ -230,7 +190,6 @@ class FusedCandidate:
|
|
230 |
sem_score: float
|
231 |
bm_score: float
|
232 |
|
233 |
-
|
234 |
@dataclass
|
235 |
class RerankResult:
|
236 |
idx: int
|
@@ -238,7 +197,6 @@ class RerankResult:
|
|
238 |
text: str
|
239 |
meta: Dict[str, Any] = field(default_factory=dict)
|
240 |
|
241 |
-
|
242 |
@dataclass
|
243 |
class ConversationState:
|
244 |
query_history: List[str] = field(default_factory=list)
|
@@ -248,7 +206,6 @@ class ConversationState:
|
|
248 |
last_answer: Optional[str] = None
|
249 |
clarification_count: int = 0
|
250 |
|
251 |
-
|
252 |
# ---------- 核心 RAG 邏輯 ----------
|
253 |
class RagPipeline:
|
254 |
def __init__(self):
|
@@ -314,8 +271,8 @@ class RagPipeline:
|
|
314 |
with open(BM25_PKL, "rb") as f:
|
315 |
bm25_data = pickle.load(f)
|
316 |
self.state.bm25 = bm25_data["bm25"]
|
317 |
-
|
318 |
-
|
319 |
|
320 |
log.info("所有模型與資料載入完成。")
|
321 |
|
@@ -334,9 +291,11 @@ class RagPipeline:
|
|
334 |
for part in q_norm_parts:
|
335 |
if part in self.drug_name_to_ids:
|
336 |
drug_ids.update(self.drug_name_to_ids[part])
|
|
|
337 |
for drug_name, ids in self.drug_name_to_ids.items():
|
338 |
if drug_name in _norm(query):
|
339 |
drug_ids.update(ids)
|
|
|
340 |
return sorted(drug_ids)
|
341 |
|
342 |
def _build_drug_name_to_ids(self) -> Dict[str, List[str]]:
|
@@ -355,11 +314,14 @@ class RagPipeline:
|
|
355 |
part = part.strip()
|
356 |
if part and len(part) > 1:
|
357 |
self.drug_name_to_ids.setdefault(part, []).append(drug_id)
|
|
|
358 |
for alias, canonical_name in DRUG_NAME_MAPPING.items():
|
359 |
if _norm(canonical_name) in _norm(row["drug_name_norm"]):
|
360 |
self.drug_name_to_ids.setdefault(_norm(alias), []).append(drug_id)
|
|
|
361 |
for key in self.drug_name_to_ids:
|
362 |
self.drug_name_to_ids[key] = sorted(set(self.drug_name_to_ids[key]))
|
|
|
363 |
return self.drug_name_to_ids
|
364 |
|
365 |
def _load_drug_name_vocabulary(self):
|
@@ -372,17 +334,19 @@ class RagPipeline:
|
|
372 |
self.drug_vocab["zh"].add(word)
|
373 |
else:
|
374 |
self.drug_vocab["en"].add(word)
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
|
|
|
|
386 |
|
387 |
@tenacity.retry(
|
388 |
wait=tenacity.wait_fixed(2),
|
@@ -421,6 +385,7 @@ class RagPipeline:
|
|
421 |
conv_state.clarification_count += 1
|
422 |
if conv_state.clarification_count > 3:
|
423 |
return "抱歉,多次無法識別您的問題,請確認藥物名稱或聯繫醫師。\n" + DISCLAIMER, []
|
|
|
424 |
clarification = self._generate_clarification_query(q_orig)
|
425 |
conv_state.last_answer = clarification
|
426 |
return f"{clarification}\n\n{DISCLAIMER}", []
|
@@ -436,31 +401,37 @@ class RagPipeline:
|
|
436 |
sections = [s.strip() for s in sections_str.split('、') if s.strip() and s != '藥袋上的醫囑']
|
437 |
intents = REFERENCE_TO_INTENT.get(ref_key, [])
|
438 |
context = self._build_context_from_csv(drug_ids, sections)
|
|
|
439 |
# 根據參考資料判斷複雜度
|
440 |
if any(sec in ["用法用量", "病人使用須知", "劑型相關"] for sec in sections):
|
441 |
complexity = "complex" # 多步驟的裝置安裝或藥品使用
|
442 |
elif any(sec in ["不良反應", "警語與注意事項"] for sec in sections):
|
443 |
complexity = "simple" # 副作用問題
|
|
|
|
|
444 |
else:
|
445 |
-
|
|
|
446 |
|
447 |
conv_state.intents = intents
|
448 |
conv_state.complexity = complexity
|
449 |
-
|
450 |
max_tokens = LLM_MODEL_CONFIG["max_tokens_complex"] if complexity == "complex" else LLM_MODEL_CONFIG["max_tokens_simple"]
|
451 |
prompt = self._make_final_prompt(q_orig, context, intents)
|
452 |
answer = self._llm_call(
|
453 |
[{"role": "user", "content": prompt}],
|
454 |
max_tokens=max_tokens
|
455 |
)
|
|
|
456 |
if not answer:
|
457 |
return f"無法回答您的問題。\n{DISCLAIMER}", drug_ids
|
458 |
|
459 |
answer = answer.replace("*", "")
|
460 |
conv_state.last_answer = answer
|
461 |
final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
|
|
|
462 |
log.info(f"查詢處理完成,耗時: {time.time() - start_time:.2f}秒")
|
463 |
return final_answer, drug_ids
|
|
|
464 |
except Exception as e:
|
465 |
log.error(f"處理查詢時發生錯誤: {e}", exc_info=True)
|
466 |
return f"處理時發生內部錯誤,請稍後再試。\n{DISCLAIMER}", []
|
@@ -471,13 +442,16 @@ class RagPipeline:
|
|
471 |
sub_queries = analysis.get("sub_queries", [q_orig])
|
472 |
intents = analysis.get("intents", [])
|
473 |
complexity = "simple" # 預設為簡單
|
|
|
474 |
sections = []
|
475 |
for intent in intents:
|
476 |
sections.extend(INTENT_TO_SECTION.get(intent, []))
|
|
|
477 |
if any(sec in ["用法用量", "病人使用須知", "劑型相關"] for sec in sections):
|
478 |
complexity = "complex"
|
479 |
elif any(sec in ["不良反應", "警語與注意事項"] for sec in sections):
|
480 |
complexity = "simple"
|
|
|
481 |
conv_state.intents = intents
|
482 |
conv_state.complexity = complexity
|
483 |
|
@@ -486,6 +460,7 @@ class RagPipeline:
|
|
486 |
conv_state.clarification_count += 1
|
487 |
if conv_state.clarification_count > 3:
|
488 |
return "抱歉,多次無法識別您的問題,請確認藥物名稱或聯繫醫師。\n" + DISCLAIMER, drug_ids
|
|
|
489 |
clarification = self._generate_clarification_query(q_orig)
|
490 |
conv_state.last_answer = clarification
|
491 |
return f"{clarification}\n\n{DISCLAIMER}", drug_ids
|
@@ -494,7 +469,6 @@ class RagPipeline:
|
|
494 |
drug_ids, sub_queries, intents
|
495 |
)
|
496 |
final_candidates = all_candidates[:TOP_K_SENTENCES]
|
497 |
-
|
498 |
reranked_results = [
|
499 |
RerankResult(
|
500 |
idx=c.idx,
|
@@ -504,6 +478,7 @@ class RagPipeline:
|
|
504 |
)
|
505 |
for c in final_candidates
|
506 |
]
|
|
|
507 |
prioritized = self._prioritize_context(reranked_results, intents)
|
508 |
context = self._build_context(prioritized)
|
509 |
|
@@ -516,6 +491,7 @@ class RagPipeline:
|
|
516 |
[{"role": "user", "content": prompt}],
|
517 |
max_tokens=max_tokens
|
518 |
)
|
|
|
519 |
if not answer:
|
520 |
return f"無法回答您的問題。\n{DISCLAIMER}", drug_ids
|
521 |
|
@@ -540,9 +516,9 @@ class RagPipeline:
|
|
540 |
for drug_id in drug_ids:
|
541 |
drug_df = self.df_csv[self.df_csv['drug_id'] == drug_id]
|
542 |
for sec in sections:
|
543 |
-
|
544 |
-
|
545 |
-
content =
|
546 |
if len(context) + len(content) > LLM_MODEL_CONFIG["max_context_chars"]:
|
547 |
return context.strip()
|
548 |
context += content + "\n\n"
|
@@ -572,32 +548,42 @@ class RagPipeline:
|
|
572 |
return []
|
573 |
|
574 |
all_fused_candidates: Dict[int, FusedCandidate] = {}
|
|
|
575 |
for sub_q in sub_queries:
|
576 |
expanded_q = self._expand_query_with_llm(sub_q, intents)
|
577 |
q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
|
|
|
578 |
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
|
579 |
faiss.normalize_L2(q_emb)
|
|
|
580 |
distances, sem_indices = self.state.index.search(q_emb, PRE_RERANK_K)
|
581 |
|
582 |
tokenized_query = list(jieba.cut(expanded_q))
|
583 |
bm25_scores = self.state.bm25.get_scores(tokenized_query)
|
|
|
584 |
rel_idx = np.fromiter(relevant_indices, dtype=np.int64)
|
585 |
rel_scores = bm25_scores[rel_idx]
|
586 |
top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
|
587 |
doc_to_bm25_score: Dict[int, float] = {
|
588 |
int(i): float(bm25_scores[i]) for i in top_rel
|
589 |
}
|
|
|
590 |
candidate_scores: Dict[int, Dict[str, float]] = {}
|
|
|
591 |
def to_similarity(d: float) -> float:
|
592 |
return float(d) if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT else 1.0 / (1.0 + float(d))
|
|
|
593 |
for i, dist in zip(sem_indices[0], distances[0]):
|
594 |
if i in relevant_indices:
|
595 |
candidate_scores[i] = {"sem": to_similarity(dist), "bm": 0.0}
|
|
|
596 |
for i, score in doc_to_bm25_score.items():
|
597 |
if i in relevant_indices:
|
598 |
candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
|
|
|
599 |
if not candidate_scores:
|
600 |
continue
|
|
|
601 |
keys = list(candidate_scores.keys())
|
602 |
sem_scores = np.array([candidate_scores[k]["sem"] for k in keys])
|
603 |
bm_scores = np.array([candidate_scores[k]["bm"] for k in keys])
|
@@ -606,12 +592,14 @@ class RagPipeline:
|
|
606 |
return (x - x.min()) / (x.max() - x.min() + 1e-8) if x.max() - x.min() > 0 else np.zeros_like(x)
|
607 |
|
608 |
sem_n, bm_n = norm(sem_scores), norm(bm_scores)
|
|
|
609 |
for idx, k in enumerate(keys):
|
610 |
fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
|
611 |
if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
|
612 |
all_fused_candidates[k] = FusedCandidate(
|
613 |
idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
|
614 |
)
|
|
|
615 |
return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
|
616 |
|
617 |
def _expand_query_with_llm(self, query: str, intents: List[str]) -> str:
|
@@ -626,11 +614,14 @@ class RagPipeline:
|
|
626 |
def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]:
|
627 |
if not intents:
|
628 |
return results
|
|
|
629 |
prioritized_sections = set()
|
630 |
for intent in intents:
|
631 |
prioritized_sections.update(INTENT_TO_SECTION.get(intent, []))
|
|
|
632 |
if not prioritized_sections:
|
633 |
return results
|
|
|
634 |
prioritized, other = [], []
|
635 |
for res in results:
|
636 |
if res.meta.get("section") in prioritized_sections:
|
@@ -665,6 +656,7 @@ class RagPipeline:
|
|
665 |
add_instr += "\n請根據以下問題與參考資料對應回答:"
|
666 |
for q, refs in REFERENCE_MAPPING.items():
|
667 |
add_instr += f"\n- {q}: {refs}"
|
|
|
668 |
return PROMPT_TEMPLATES["final_answer"].format(
|
669 |
additional_instruction=add_instr, context=context, query=query
|
670 |
)
|
@@ -674,22 +666,18 @@ class RagPipeline:
|
|
674 |
return json.loads(s)
|
675 |
except json.JSONDecodeError:
|
676 |
try:
|
677 |
-
m = re.search(r"
|
678 |
if m:
|
679 |
return json.loads(m.group(0))
|
680 |
except json.JSONDecodeError:
|
681 |
pass
|
682 |
-
|
683 |
-
|
684 |
|
685 |
# ---------- FastAPI 事件與路由 ----------
|
686 |
class AppConfig:
|
687 |
CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
|
688 |
CHANNEL_SECRET = _require_env("CHANNEL_SECRET")
|
689 |
-
|
690 |
-
|
691 |
-
rag_pipeline: Optional[RagPipeline] = None
|
692 |
-
|
693 |
|
694 |
@asynccontextmanager
|
695 |
async def lifespan(app: FastAPI):
|
@@ -701,10 +689,8 @@ async def lifespan(app: FastAPI):
|
|
701 |
yield
|
702 |
log.info("服務關閉中。")
|
703 |
|
704 |
-
|
705 |
app = FastAPI(lifespan=lifespan)
|
706 |
|
707 |
-
|
708 |
@app.post("/webhook")
|
709 |
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|
710 |
signature = request.headers.get("X-Line-Signature")
|
@@ -712,6 +698,7 @@ async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|
|
712 |
raise HTTPException(status_code=400, detail="Missing LINE X-Line-Signature header")
|
713 |
|
714 |
body = await request.body()
|
|
|
715 |
try:
|
716 |
hash_obj = hmac.new(AppConfig.CHANNEL_SECRET.encode("utf-8"), body, hashlib.sha256)
|
717 |
expected_signature = base64.b64encode(hash_obj.digest()).decode("utf-8")
|
@@ -728,65 +715,55 @@ async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|
|
728 |
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
729 |
|
730 |
for event in data.get("events", []):
|
731 |
-
if (
|
732 |
-
event.get("
|
733 |
-
and event.get("message", {}).get("type") == "text"
|
734 |
-
):
|
735 |
-
user_text = event.get("message", {}).get("text", "").strip()
|
736 |
source = event.get("source", {})
|
737 |
stype = source.get("type")
|
738 |
target_id = (
|
739 |
source.get("userId") or source.get("groupId") or source.get("roomId")
|
740 |
)
|
741 |
-
if user_text and target_id:
|
742 |
-
background_tasks.add_task(
|
743 |
-
process_user_query, stype, target_id, user_text
|
744 |
-
)
|
745 |
-
return Response(status_code=status.HTTP_200_OK)
|
746 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
747 |
|
748 |
-
async def process_user_query(source_type: str, target_id: str,
|
749 |
try:
|
750 |
if not rag_pipeline:
|
751 |
-
|
752 |
"系統正在啟動中,請稍後再試。")
|
753 |
return
|
754 |
-
|
755 |
-
await
|
|
|
|
|
756 |
except Exception as e:
|
757 |
log.error(f"背景處理 target_id={target_id} 發生錯誤: {e}", exc_info=True)
|
758 |
-
|
759 |
source_type,
|
760 |
target_id,
|
761 |
f"抱歉,處理時發生未預期的錯誤。\n{DISCLAIMER}",
|
762 |
)
|
763 |
|
764 |
-
|
765 |
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
766 |
-
|
767 |
headers = {
|
768 |
"Content-Type": "application/json",
|
769 |
"Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}",
|
770 |
}
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
response.raise_for_status()
|
779 |
-
|
780 |
-
|
781 |
-
async def line_reply(reply_token: str, text: str):
|
782 |
-
messages = [
|
783 |
-
{"type": "text", "text": chunk}
|
784 |
-
for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]
|
785 |
-
]
|
786 |
-
await line_api_call("reply", {"replyToken": reply_token, "messages": messages})
|
787 |
-
|
788 |
|
789 |
-
|
790 |
messages = [
|
791 |
{"type": "text", "text": chunk}
|
792 |
for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]
|
@@ -794,24 +771,9 @@ async def line_push_generic(source_type: str, target_id: str, text: str):
|
|
794 |
if "目前僅支援以下藥物詢問" in text:
|
795 |
drug_list = "\n".join(f"- {drug}" for drug in SUPPORTED_DRUGS)
|
796 |
messages.append({"type": "text", "text": f"支援的藥物清單:\n{drug_list}"})
|
797 |
-
data = {"to": target_id, "messages": messages}
|
798 |
-
await line_api_call("push", data)
|
799 |
-
|
800 |
-
|
801 |
-
def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> List[str]:
|
802 |
-
candidates = set()
|
803 |
-
q_norm = _norm(query)
|
804 |
-
for word in re.findall(r"[a-z0-9]+", q_norm):
|
805 |
-
if word in drug_vocab["en"]:
|
806 |
-
candidates.add(word)
|
807 |
-
for token in jieba.cut(q_norm):
|
808 |
-
if token in drug_vocab["zh"]:
|
809 |
-
candidates.add(token)
|
810 |
-
supported_drugs = set(DRUG_NAME_MAPPING.keys()).union(DRUG_NAME_MAPPING.values())
|
811 |
-
if not candidates.issubset(supported_drugs):
|
812 |
-
candidates = set()
|
813 |
-
return list(candidates)
|
814 |
|
|
|
|
|
815 |
|
816 |
# ---------- 執行 ----------
|
817 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import pathlib
|
3 |
import re
|
|
|
17 |
import unicodedata
|
18 |
from collections import defaultdict
|
19 |
import asyncio
|
|
|
20 |
|
21 |
+
# 第三方函式庫
|
22 |
import numpy as np
|
23 |
import pandas as pd
|
24 |
import jieba
|
|
|
32 |
import uvicorn
|
33 |
from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
|
34 |
|
35 |
+
# 限制 PyTorch 執行緒數量,避免 CPU 環境下過度佔用資源
|
36 |
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1")))
|
37 |
|
38 |
# ===== CONFIG =====
|
|
|
43 |
raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
|
44 |
return v
|
45 |
|
|
|
46 |
def _require_llm_config():
|
47 |
for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
|
48 |
_require_env(k)
|
49 |
|
|
|
50 |
# --------- 路徑設定 ------------
|
51 |
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
|
52 |
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
|
53 |
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
|
54 |
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
|
|
|
55 |
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 20))
|
56 |
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
|
57 |
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
|
|
|
58 |
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
|
|
|
59 |
LLM_API_CONFIG = {
|
60 |
"base_url": _require_env("LITELLM_BASE_URL"),
|
61 |
"api_key": _require_env("LITELLM_API_KEY"),
|
62 |
"model": _require_env("LM_MODEL"),
|
63 |
}
|
|
|
64 |
LLM_MODEL_CONFIG = {
|
65 |
"max_context_chars": int(os.getenv("MAX_CONTEXT_CHARS", 10000)),
|
66 |
"max_tokens_simple": int(os.getenv("MAX_TOKENS_SIMPLE", 256)),
|
|
|
77 |
"劑量調整 (Dosage Adjustment)",
|
78 |
"禁忌症/適應症 (Contraindications/Indications)",
|
79 |
]
|
|
|
80 |
INTENT_TO_SECTION = {
|
81 |
"操作 (Administration)": ["用法用量", "病人使用須知"],
|
82 |
"保存/攜帶 (Storage & Handling)": ["包裝及儲存"],
|
|
|
86 |
"劑量調整 (Dosage Adjustment)": ["用法用量"],
|
87 |
"禁忌症/適應症 (Contraindications/Indications)": ["適應症", "禁忌", "警語與注意事項"],
|
88 |
}
|
|
|
89 |
DRUG_NAME_MAPPING = {
|
90 |
"fentanyl patch": "fentanyl",
|
91 |
"spiriva respimat": "spiriva",
|
|
|
102 |
DISCLAIMER = (
|
103 |
"本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
|
104 |
)
|
|
|
105 |
REFERENCE_MAPPING = {
|
106 |
"如何用藥?": "病人使用須知、用法用量",
|
107 |
"如何保存與攜帶?": "包裝及儲存",
|
|
|
109 |
"每次劑量多少?": "用法用量、藥袋上的醫囑",
|
110 |
"用藥時間?": "用法用量、藥袋上的醫囑",
|
111 |
}
|
|
|
112 |
REFERENCE_TO_INTENT = {
|
113 |
"如何用藥?": ["操作 (Administration)"],
|
114 |
"如何保存與攜帶?": ["保存/攜帶 (Storage & Handling)"],
|
|
|
116 |
"每次劑量多少?": ["劑量調整 (Dosage Adjustment)"],
|
117 |
"用藥時間?": ["時間/併用 (Timing & Interaction)"],
|
118 |
}
|
|
|
119 |
PROMPT_TEMPLATES = {
|
120 |
"analyze_query": """
|
121 |
請分析以下使用者問題,並完成以下三個任務:
|
122 |
+
將問題分解為 1-3 個核心子問題。
|
123 |
+
從清單中選擇所有相關的意圖分類。
|
124 |
+
評估問題複雜度,返回 'simple'(單一問題或簡單意圖)或 'complex'(多子問題或複雜意圖,如副作用、劑量調整)。
|
|
|
125 |
請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列)、'intents' (字串陣列) 和 'complexity' (字串) 三個鍵。
|
126 |
範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"], "complexity": "simple"}}
|
|
|
127 |
意圖分類清單:
|
128 |
{options}。
|
|
|
129 |
使用者問題:{query}
|
130 |
""",
|
131 |
"expand_query": """
|
|
|
135 |
""",
|
136 |
"final_answer": """
|
137 |
您是一位專業、親切的台灣藥師,將在LINE上為使用者解答疑問。請依循以下規範,嚴謹地根據提供的「參考資料」給予回覆:
|
|
|
138 |
一、 回覆規範:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
+
回覆語言:使用繁體中文,口語化且易懂,避免專業術語或解釋之。
|
141 |
+
結構:先以「簡答:」標記提供簡短總結答案(50-100字),然後以「詳答:」標記提供詳細解釋,最後提醒使用者諮詢醫師。
|
142 |
+
長度:簡答控制在50-100字,詳答根據問題複雜度調整,簡單問題約100-200字,複雜問題(如多步驟的裝置安裝或藥品使用)可達300-500字。
|
143 |
+
態度:親切、專業、關懷,避免驚嚇使用者。 {additional_instruction}
|
144 |
+
參考資料: {context}
|
145 |
使用者問題:{query}
|
|
|
146 |
請直接輸出最終的答案:
|
147 |
""",
|
148 |
"analyze_reference": """
|
149 |
從以下清單選擇最匹配的使用者問題分類,如果沒有匹配,返回 'none'。
|
|
|
150 |
分類清單:
|
151 |
{options}
|
|
|
152 |
使用者問題:{query}
|
|
|
153 |
請僅輸出分類名稱或 'none',不需任何額外的解釋或格式。
|
154 |
""",
|
155 |
"clarification": """
|
156 |
請根據以下使用者問題,生成一個簡潔、禮貌的澄清性提問,以幫助我更精準地回答。問題應引導使用者提供更多細節,例如具體藥名、使用情境,並附上範例問題。請在回覆中明確告知使用者,目前僅支援以下藥物詢問:
|
157 |
+
Fentanyl patch
|
158 |
+
Spiriva Respimat
|
159 |
+
NITROSTAT
|
160 |
+
AUGMENTIN FOR SYRUP
|
161 |
+
Ozempic
|
162 |
+
NIFLEC
|
163 |
+
Fosamax
|
164 |
+
Humira
|
165 |
+
PREMARIN
|
166 |
+
SMECTA
|
|
|
167 |
範例:
|
168 |
使用者問題:這個藥會怎麼樣?
|
169 |
澄清提問:您好,請問您指的藥物是下列哪一種?目前僅支援以下藥物詢問:Fentanyl patch、Spiriva Respimat...等。例如,您可以問:「Fentanyl patch 的副作用有哪些?」請確認藥名或提供更多細節。
|
|
|
170 |
使用者問題:{query}
|
171 |
"""
|
172 |
}
|
|
|
173 |
# ---------- 日誌設定 ----------
|
174 |
logging.basicConfig(
|
175 |
level=logging.INFO,
|
|
|
183 |
s = unicodedata.normalize("NFKC", s)
|
184 |
return re.sub(r"[^\w\s]", "", s.lower()).strip()
|
185 |
|
|
|
186 |
@dataclass
|
187 |
class FusedCandidate:
|
188 |
idx: int
|
|
|
190 |
sem_score: float
|
191 |
bm_score: float
|
192 |
|
|
|
193 |
@dataclass
|
194 |
class RerankResult:
|
195 |
idx: int
|
|
|
197 |
text: str
|
198 |
meta: Dict[str, Any] = field(default_factory=dict)
|
199 |
|
|
|
200 |
@dataclass
|
201 |
class ConversationState:
|
202 |
query_history: List[str] = field(default_factory=list)
|
|
|
206 |
last_answer: Optional[str] = None
|
207 |
clarification_count: int = 0
|
208 |
|
|
|
209 |
# ---------- 核心 RAG 邏輯 ----------
|
210 |
class RagPipeline:
|
211 |
def __init__(self):
|
|
|
271 |
with open(BM25_PKL, "rb") as f:
|
272 |
bm25_data = pickle.load(f)
|
273 |
self.state.bm25 = bm25_data["bm25"]
|
274 |
+
if not isinstance(self.state.bm25, BM25Okapi):
|
275 |
+
raise ValueError("Loaded BM25 is not a BM25Okapi instance.")
|
276 |
|
277 |
log.info("所有模型與資料載入完成。")
|
278 |
|
|
|
291 |
for part in q_norm_parts:
|
292 |
if part in self.drug_name_to_ids:
|
293 |
drug_ids.update(self.drug_name_to_ids[part])
|
294 |
+
|
295 |
for drug_name, ids in self.drug_name_to_ids.items():
|
296 |
if drug_name in _norm(query):
|
297 |
drug_ids.update(ids)
|
298 |
+
|
299 |
return sorted(drug_ids)
|
300 |
|
301 |
def _build_drug_name_to_ids(self) -> Dict[str, List[str]]:
|
|
|
314 |
part = part.strip()
|
315 |
if part and len(part) > 1:
|
316 |
self.drug_name_to_ids.setdefault(part, []).append(drug_id)
|
317 |
+
|
318 |
for alias, canonical_name in DRUG_NAME_MAPPING.items():
|
319 |
if _norm(canonical_name) in _norm(row["drug_name_norm"]):
|
320 |
self.drug_name_to_ids.setdefault(_norm(alias), []).append(drug_id)
|
321 |
+
|
322 |
for key in self.drug_name_to_ids:
|
323 |
self.drug_name_to_ids[key] = sorted(set(self.drug_name_to_ids[key]))
|
324 |
+
|
325 |
return self.drug_name_to_ids
|
326 |
|
327 |
def _load_drug_name_vocabulary(self):
|
|
|
334 |
self.drug_vocab["zh"].add(word)
|
335 |
else:
|
336 |
self.drug_vocab["en"].add(word)
|
337 |
+
|
338 |
+
for alias in DRUG_NAME_MAPPING:
|
339 |
+
if re.search(r"[\u4e00-\u9fff]", alias):
|
340 |
+
self.drug_vocab["zh"].add(alias)
|
341 |
+
else:
|
342 |
+
self.drug_vocab["en"].add(alias)
|
343 |
+
|
344 |
+
for word in self.drug_vocab["zh"]:
|
345 |
+
try:
|
346 |
+
if word not in jieba.dt.FREQ:
|
347 |
+
jieba.add_word(word, freq=2_000_000)
|
348 |
+
except Exception:
|
349 |
+
pass
|
350 |
|
351 |
@tenacity.retry(
|
352 |
wait=tenacity.wait_fixed(2),
|
|
|
385 |
conv_state.clarification_count += 1
|
386 |
if conv_state.clarification_count > 3:
|
387 |
return "抱歉,多次無法識別您的問題,請確認藥物名稱或聯繫醫師。\n" + DISCLAIMER, []
|
388 |
+
|
389 |
clarification = self._generate_clarification_query(q_orig)
|
390 |
conv_state.last_answer = clarification
|
391 |
return f"{clarification}\n\n{DISCLAIMER}", []
|
|
|
401 |
sections = [s.strip() for s in sections_str.split('、') if s.strip() and s != '藥袋上的醫囑']
|
402 |
intents = REFERENCE_TO_INTENT.get(ref_key, [])
|
403 |
context = self._build_context_from_csv(drug_ids, sections)
|
404 |
+
|
405 |
# 根據參考資料判斷複雜度
|
406 |
if any(sec in ["用法用量", "病人使用須知", "劑型相關"] for sec in sections):
|
407 |
complexity = "complex" # 多步驟的裝置安裝或藥品使用
|
408 |
elif any(sec in ["不良反應", "警語與注意事項"] for sec in sections):
|
409 |
complexity = "simple" # 副作用問題
|
410 |
+
else:
|
411 |
+
return await self._fallback_rag(target_id, q_orig, drug_ids)
|
412 |
else:
|
413 |
+
# If no direct reference mapping, use fallback RAG
|
414 |
+
return await self._fallback_rag(target_id, q_orig, drug_ids)
|
415 |
|
416 |
conv_state.intents = intents
|
417 |
conv_state.complexity = complexity
|
|
|
418 |
max_tokens = LLM_MODEL_CONFIG["max_tokens_complex"] if complexity == "complex" else LLM_MODEL_CONFIG["max_tokens_simple"]
|
419 |
prompt = self._make_final_prompt(q_orig, context, intents)
|
420 |
answer = self._llm_call(
|
421 |
[{"role": "user", "content": prompt}],
|
422 |
max_tokens=max_tokens
|
423 |
)
|
424 |
+
|
425 |
if not answer:
|
426 |
return f"無法回答您的問題。\n{DISCLAIMER}", drug_ids
|
427 |
|
428 |
answer = answer.replace("*", "")
|
429 |
conv_state.last_answer = answer
|
430 |
final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
|
431 |
+
|
432 |
log.info(f"查詢處理完成,耗時: {time.time() - start_time:.2f}秒")
|
433 |
return final_answer, drug_ids
|
434 |
+
|
435 |
except Exception as e:
|
436 |
log.error(f"處理查詢時發生錯誤: {e}", exc_info=True)
|
437 |
return f"處理時發生內部錯誤,請稍後再試。\n{DISCLAIMER}", []
|
|
|
442 |
sub_queries = analysis.get("sub_queries", [q_orig])
|
443 |
intents = analysis.get("intents", [])
|
444 |
complexity = "simple" # 預設為簡單
|
445 |
+
|
446 |
sections = []
|
447 |
for intent in intents:
|
448 |
sections.extend(INTENT_TO_SECTION.get(intent, []))
|
449 |
+
|
450 |
if any(sec in ["用法用量", "病人使用須知", "劑型相關"] for sec in sections):
|
451 |
complexity = "complex"
|
452 |
elif any(sec in ["不良反應", "警語與注意事項"] for sec in sections):
|
453 |
complexity = "simple"
|
454 |
+
|
455 |
conv_state.intents = intents
|
456 |
conv_state.complexity = complexity
|
457 |
|
|
|
460 |
conv_state.clarification_count += 1
|
461 |
if conv_state.clarification_count > 3:
|
462 |
return "抱歉,多次無法識別您的問題,請確認藥物名稱或聯繫醫師。\n" + DISCLAIMER, drug_ids
|
463 |
+
|
464 |
clarification = self._generate_clarification_query(q_orig)
|
465 |
conv_state.last_answer = clarification
|
466 |
return f"{clarification}\n\n{DISCLAIMER}", drug_ids
|
|
|
469 |
drug_ids, sub_queries, intents
|
470 |
)
|
471 |
final_candidates = all_candidates[:TOP_K_SENTENCES]
|
|
|
472 |
reranked_results = [
|
473 |
RerankResult(
|
474 |
idx=c.idx,
|
|
|
478 |
)
|
479 |
for c in final_candidates
|
480 |
]
|
481 |
+
|
482 |
prioritized = self._prioritize_context(reranked_results, intents)
|
483 |
context = self._build_context(prioritized)
|
484 |
|
|
|
491 |
[{"role": "user", "content": prompt}],
|
492 |
max_tokens=max_tokens
|
493 |
)
|
494 |
+
|
495 |
if not answer:
|
496 |
return f"無法回答您的問題。\n{DISCLAIMER}", drug_ids
|
497 |
|
|
|
516 |
for drug_id in drug_ids:
|
517 |
drug_df = self.df_csv[self.df_csv['drug_id'] == drug_id]
|
518 |
for sec in sections:
|
519 |
+
sec_rows = drug_df[drug_df['section'].str.contains(sec, na=False)]
|
520 |
+
for _, row in sec_rows.iterrows():
|
521 |
+
content = row['content']
|
522 |
if len(context) + len(content) > LLM_MODEL_CONFIG["max_context_chars"]:
|
523 |
return context.strip()
|
524 |
context += content + "\n\n"
|
|
|
548 |
return []
|
549 |
|
550 |
all_fused_candidates: Dict[int, FusedCandidate] = {}
|
551 |
+
|
552 |
for sub_q in sub_queries:
|
553 |
expanded_q = self._expand_query_with_llm(sub_q, intents)
|
554 |
q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
|
555 |
+
|
556 |
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
|
557 |
faiss.normalize_L2(q_emb)
|
558 |
+
|
559 |
distances, sem_indices = self.state.index.search(q_emb, PRE_RERANK_K)
|
560 |
|
561 |
tokenized_query = list(jieba.cut(expanded_q))
|
562 |
bm25_scores = self.state.bm25.get_scores(tokenized_query)
|
563 |
+
|
564 |
rel_idx = np.fromiter(relevant_indices, dtype=np.int64)
|
565 |
rel_scores = bm25_scores[rel_idx]
|
566 |
top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
|
567 |
doc_to_bm25_score: Dict[int, float] = {
|
568 |
int(i): float(bm25_scores[i]) for i in top_rel
|
569 |
}
|
570 |
+
|
571 |
candidate_scores: Dict[int, Dict[str, float]] = {}
|
572 |
+
|
573 |
def to_similarity(d: float) -> float:
|
574 |
return float(d) if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT else 1.0 / (1.0 + float(d))
|
575 |
+
|
576 |
for i, dist in zip(sem_indices[0], distances[0]):
|
577 |
if i in relevant_indices:
|
578 |
candidate_scores[i] = {"sem": to_similarity(dist), "bm": 0.0}
|
579 |
+
|
580 |
for i, score in doc_to_bm25_score.items():
|
581 |
if i in relevant_indices:
|
582 |
candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
|
583 |
+
|
584 |
if not candidate_scores:
|
585 |
continue
|
586 |
+
|
587 |
keys = list(candidate_scores.keys())
|
588 |
sem_scores = np.array([candidate_scores[k]["sem"] for k in keys])
|
589 |
bm_scores = np.array([candidate_scores[k]["bm"] for k in keys])
|
|
|
592 |
return (x - x.min()) / (x.max() - x.min() + 1e-8) if x.max() - x.min() > 0 else np.zeros_like(x)
|
593 |
|
594 |
sem_n, bm_n = norm(sem_scores), norm(bm_scores)
|
595 |
+
|
596 |
for idx, k in enumerate(keys):
|
597 |
fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
|
598 |
if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
|
599 |
all_fused_candidates[k] = FusedCandidate(
|
600 |
idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
|
601 |
)
|
602 |
+
|
603 |
return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
|
604 |
|
605 |
def _expand_query_with_llm(self, query: str, intents: List[str]) -> str:
|
|
|
614 |
def _prioritize_context(self, results: List[RerankResult], intents: List[str]) -> List[RerankResult]:
|
615 |
if not intents:
|
616 |
return results
|
617 |
+
|
618 |
prioritized_sections = set()
|
619 |
for intent in intents:
|
620 |
prioritized_sections.update(INTENT_TO_SECTION.get(intent, []))
|
621 |
+
|
622 |
if not prioritized_sections:
|
623 |
return results
|
624 |
+
|
625 |
prioritized, other = [], []
|
626 |
for res in results:
|
627 |
if res.meta.get("section") in prioritized_sections:
|
|
|
656 |
add_instr += "\n請根據以下問題與參考資料對應回答:"
|
657 |
for q, refs in REFERENCE_MAPPING.items():
|
658 |
add_instr += f"\n- {q}: {refs}"
|
659 |
+
|
660 |
return PROMPT_TEMPLATES["final_answer"].format(
|
661 |
additional_instruction=add_instr, context=context, query=query
|
662 |
)
|
|
|
666 |
return json.loads(s)
|
667 |
except json.JSONDecodeError:
|
668 |
try:
|
669 |
+
m = re.search(r"{.*?}", s, re.DOTALL)
|
670 |
if m:
|
671 |
return json.loads(m.group(0))
|
672 |
except json.JSONDecodeError:
|
673 |
pass
|
674 |
+
return default
|
|
|
675 |
|
676 |
# ---------- FastAPI 事件與路由 ----------
|
677 |
class AppConfig:
    """LINE channel credentials, read from the environment at class-definition time.

    Both values are fetched through ``_require_env`` while the class body is
    executed, so they are resolved once, when this module is imported.
    """

    # Sent as the Bearer token on outgoing LINE Messaging API requests.
    CHANNEL_ACCESS_TOKEN: str = _require_env("CHANNEL_ACCESS_TOKEN")
    # HMAC-SHA256 key used to verify the X-Line-Signature of incoming webhooks.
    CHANNEL_SECRET: str = _require_env("CHANNEL_SECRET")
|
680 |
+
rag_pipeline: Optional[RagPipeline] = None
|
|
|
|
|
|
|
681 |
|
682 |
@asynccontextmanager
|
683 |
async def lifespan(app: FastAPI):
|
|
|
689 |
yield
|
690 |
log.info("服務關閉中。")
|
691 |
|
|
|
692 |
app = FastAPI(lifespan=lifespan)
|
693 |
|
|
|
694 |
@app.post("/webhook")
|
695 |
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|
696 |
signature = request.headers.get("X-Line-Signature")
|
|
|
698 |
raise HTTPException(status_code=400, detail="Missing LINE X-Line-Signature header")
|
699 |
|
700 |
body = await request.body()
|
701 |
+
|
702 |
try:
|
703 |
hash_obj = hmac.new(AppConfig.CHANNEL_SECRET.encode("utf-8"), body, hashlib.sha256)
|
704 |
expected_signature = base64.b64encode(hash_obj.digest()).decode("utf-8")
|
|
|
715 |
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
716 |
|
717 |
for event in data.get("events", []):
|
718 |
+
if event.get("type") == "message":
|
719 |
+
msg = event.get("message", {})
|
|
|
|
|
|
|
720 |
source = event.get("source", {})
|
721 |
stype = source.get("type")
|
722 |
target_id = (
|
723 |
source.get("userId") or source.get("groupId") or source.get("roomId")
|
724 |
)
|
|
|
|
|
|
|
|
|
|
|
725 |
|
726 |
+
if msg.get("type") == "text" and target_id:
|
727 |
+
user_text = msg.get("text", "").strip()
|
728 |
+
if user_text:
|
729 |
+
background_tasks.add_task(
|
730 |
+
process_user_query, stype, target_id, user_text
|
731 |
+
)
|
732 |
+
return Response(status_code=status.HTTP_200_OK)
|
733 |
|
734 |
+
async def process_user_query(source_type: str, target_id: str, input_data: str):
    """Answer one user text message and push the reply back through LINE.

    Scheduled by the webhook handler as a background task, so it runs on the
    server's event loop after the HTTP 200 has been returned.

    Args:
        source_type: The ``type`` field of the LINE event source.
        target_id: The user, group, or room id the reply is pushed to; also
            passed to the RAG pipeline as the conversation key.
        input_data: The stripped text of the user's message.

    Returns:
        None. All outcomes are reported to the user via a push message, and
        failures are logged rather than raised.
    """
    loop = asyncio.get_running_loop()

    async def _push(text: str) -> None:
        # line_push_generic performs blocking HTTP (requests plus retry waits).
        # Running it in the default executor keeps the event loop free to
        # serve other webhooks while the push is in flight.
        await loop.run_in_executor(
            None, line_push_generic, source_type, target_id, text
        )

    try:
        if not rag_pipeline:
            # The pipeline global is still None until startup has finished.
            await _push("系統正在啟動中,請稍後再試。")
            return

        # Only the answer text is needed here; the drug ids are discarded.
        answer, _ = await rag_pipeline.answer_question(target_id, input_data)
        await _push(answer)

    except Exception as e:
        log.error(f"背景處理 target_id={target_id} 發生錯誤: {e}", exc_info=True)
        try:
            await _push(f"抱歉,處理時發生未預期的錯誤。\n{DISCLAIMER}")
        except Exception as push_err:
            # The apology push can fail too (e.g. LINE is unreachable). Log it
            # instead of letting the exception escape the background task.
            log.error(
                f"背景處理 target_id={target_id} 推送錯誤訊息失敗: {push_err}",
                exc_info=True,
            )
|
751 |
|
|
|
752 |
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def line_api_call(endpoint: str, data: Dict):
    """POST a JSON payload to a LINE Messaging API message endpoint.

    Args:
        endpoint: Path segment appended to
            ``https://api.line.me/v2/bot/message/`` (this file calls it with
            ``"push"``).
        data: Request body, serialised as JSON.

    Raises:
        requests.HTTPError: When LINE answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or when the
            10-second timeout expires.

    The function is wrapped by the ``retry`` decorator configured above with a
    three-attempt stop and a fixed two-second wait, so any raised exception
    triggers another attempt until that limit is reached.
    """
    url = f"https://api.line.me/v2/bot/message/{endpoint}"
    auth_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}",
    }
    reply = requests.post(url, headers=auth_headers, json=data, timeout=10)
    # Turn an error status into an exception so the retry wrapper sees it.
    reply.raise_for_status()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
765 |
|
766 |
+
def line_push_generic(source_type: str, target_id: str, text: str):
|
767 |
messages = [
|
768 |
{"type": "text", "text": chunk}
|
769 |
for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]
|
|
|
771 |
if "目前僅支援以下藥物詢問" in text:
|
772 |
drug_list = "\n".join(f"- {drug}" for drug in SUPPORTED_DRUGS)
|
773 |
messages.append({"type": "text", "text": f"支援的藥物清單:\n{drug_list}"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
774 |
|
775 |
+
data = {"to": target_id, "messages": messages}
|
776 |
+
line_api_call("push", data)
|
777 |
|
778 |
# ---------- 執行 ----------
|
779 |
if __name__ == "__main__":
|
requirements.txt
CHANGED
@@ -12,4 +12,5 @@ torch
|
|
12 |
# LLM 呼叫相關
|
13 |
openai
|
14 |
tenacity
|
15 |
-
requests
|
|
|
|
12 |
# LLM 呼叫相關
|
13 |
openai
|
14 |
tenacity
|
15 |
+
requests
|
16 |
+
aiohttp
|