Spaces:
Sleeping
Sleeping
Song
committed on
Commit
·
d49cddf
1
Parent(s):
ee31585
hi
Browse files
app.py
CHANGED
@@ -40,495 +40,360 @@ from sentence_transformers import SentenceTransformer, CrossEncoder
|
|
40 |
import faiss
|
41 |
import torch
|
42 |
from openai import OpenAI
|
43 |
-
from tenacity import retry, stop_after_attempt, wait_fixed
|
44 |
import requests
|
45 |
|
46 |
-
#
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
if not v:
|
54 |
-
raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
|
55 |
-
return v
|
56 |
-
|
57 |
-
# [MODIFIED] 檢查 LLM 相關環境變數
|
58 |
-
def _require_llm_config():
|
59 |
-
for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
|
60 |
-
_require_env(k)
|
61 |
|
62 |
-
|
63 |
-
|
64 |
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
|
65 |
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
|
|
|
|
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
|
70 |
|
71 |
-
|
72 |
-
|
|
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
"
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
PROMPT_TEMPLATES = {
|
100 |
-
"analyze_query": """
|
101 |
-
請分析以下使用者問題,並完成以下兩個任務:
|
102 |
-
1. 將問題分解為1-3個核心的子問題。
|
103 |
-
2. 從清單中選擇所有相關的意圖分類。
|
104 |
-
|
105 |
-
請嚴格以 JSON 格式回覆,包含 'sub_queries' (字串陣列) 和 'intents' (字串陣列) 兩個鍵。
|
106 |
-
範例: {{"sub_queries": ["子問題一", "子問題二"], "intents": ["分類名稱一", "分類名稱二"]}}
|
107 |
-
|
108 |
-
意圖分類清單:
|
109 |
-
{options}
|
110 |
-
|
111 |
-
使用者問題:{query}
|
112 |
-
""",
|
113 |
-
"expand_query": """
|
114 |
-
請根據以下意圖:{intents},擴展這個查詢,加入相關同義詞或術語。
|
115 |
-
原始查詢:{query}
|
116 |
-
請僅輸出擴展後的查詢,不需任何額外的解釋或格式。
|
117 |
-
""",
|
118 |
-
"final_answer": """
|
119 |
-
你是一位專業且謹慎的台灣藥師。請嚴格根據「參考資料」回答使用者問題,使用繁體中文。
|
120 |
-
|
121 |
-
規則:
|
122 |
-
1) 完全依據參考資料,不得捏造或引用外部知識。
|
123 |
-
2) 使用清晰的條列式 (例如 1., 2., 3.) 或分段來組織回答,使其易於閱讀。
|
124 |
-
3) 如果資料不足以回答,請直接回覆:「根據提供的資料,無法回答您的問題。」
|
125 |
-
4) {additional_instruction}
|
126 |
-
5) 總結答案,使其簡潔扼要,總長度應在 100 字以內。
|
127 |
-
|
128 |
-
---
|
129 |
-
參考資料:
|
130 |
-
{context}
|
131 |
-
---
|
132 |
|
133 |
-
|
|
|
|
|
|
|
134 |
|
135 |
-
|
|
|
|
|
|
|
|
|
136 |
"""
|
137 |
-
}
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
|
|
143 |
@dataclass
|
144 |
-
class
|
145 |
-
|
146 |
-
|
147 |
-
sem_score: float
|
148 |
-
bm_score: float
|
149 |
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
rerank_score: float
|
154 |
-
text: str
|
155 |
-
meta: Dict[str, Any] = field(default_factory=dict)
|
156 |
|
157 |
-
#
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
self.embedding_model = self._load_model(SentenceTransformer, EMBEDDING_MODEL, "embedding")
|
166 |
-
self.reranker = self._load_model(CrossEncoder, RERANKER_MODEL, "reranker")
|
167 |
-
|
168 |
-
self.drug_name_to_ids: Dict[str, List[str]] = {}
|
169 |
-
self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
|
170 |
|
171 |
-
|
172 |
-
|
173 |
-
log.info(f"載入 {model_type} 模型:{model_name} 至 {device}...")
|
174 |
try:
|
175 |
-
|
176 |
except Exception as e:
|
177 |
-
log.
|
178 |
-
|
179 |
|
180 |
-
|
181 |
-
log.info("
|
182 |
try:
|
183 |
-
self.
|
184 |
-
# [MODIFIED] 增加必要欄位檢查
|
185 |
-
for col in ("drug_name_norm", "drug_id"):
|
186 |
-
if col not in self.df_csv.columns:
|
187 |
-
raise KeyError(f"CSV 檔案 '{CSV_PATH}' 中缺少必要欄位: {col}")
|
188 |
-
|
189 |
-
self.df_csv['drug_name_norm_normalized'] = (
|
190 |
-
self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
|
191 |
-
)
|
192 |
-
self.drug_name_to_ids = self.df_csv.groupby('drug_name_norm_normalized')['drug_id'].unique().apply(list).to_dict()
|
193 |
-
# [MODIFIED] 把別名也變成可查鍵
|
194 |
-
for alias, canonical in DRUG_NAME_MAPPING.items():
|
195 |
-
alias_key = re.sub(r'[^\w\s]', '', alias.lower()).strip()
|
196 |
-
canonical_key = re.sub(r'[^\w\s]', '', canonical.lower()).strip()
|
197 |
-
if canonical_key in self.drug_name_to_ids:
|
198 |
-
self.drug_name_to_ids[alias_key] = self.drug_name_to_ids[canonical_key]
|
199 |
-
self._load_drug_name_vocabulary()
|
200 |
-
|
201 |
-
log.info("載入 FAISS 索引與句子資料...")
|
202 |
-
self.state.index = faiss.read_index(FAISS_INDEX)
|
203 |
-
self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2)
|
204 |
-
if hasattr(self.state.index, "nprobe"):
|
205 |
-
self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
|
206 |
with open(SENTENCES_PKL, "rb") as f:
|
207 |
data = pickle.load(f)
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
210 |
|
211 |
-
|
|
|
|
|
|
|
|
|
212 |
with open(BM25_PKL, "rb") as f:
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
|
|
|
|
|
|
|
|
221 |
|
222 |
-
|
223 |
log.info("建立藥名詞庫...")
|
224 |
-
for norm_name in self.df_csv['drug_name_norm_normalized'].dropna().unique():
|
225 |
-
parts = norm_name.split()
|
226 |
-
for part in parts:
|
227 |
-
if re.search(r'[\u4e00-\u9fff]', part):
|
228 |
-
self.drug_vocab["zh"].add(part)
|
229 |
-
try:
|
230 |
-
jieba.add_word(part, freq=2_000_000)
|
231 |
-
except Exception:
|
232 |
-
pass
|
233 |
-
else:
|
234 |
-
self.drug_vocab["en"].add(part)
|
235 |
-
for alias in DRUG_NAME_MAPPING:
|
236 |
-
self.drug_vocab["en"].add(alias.lower())
|
237 |
-
if re.search(r'[\u4e00-\u9fff]', alias):
|
238 |
-
try:
|
239 |
-
jieba.add_word(alias, freq=2_000_000)
|
240 |
-
except Exception:
|
241 |
-
pass
|
242 |
-
|
243 |
-
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
244 |
-
def _llm_call(self, messages, **kwargs) -> str:
|
245 |
try:
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
temperature=config["temperature"],
|
251 |
-
max_tokens=config["max_tokens"],
|
252 |
-
)
|
253 |
-
return response.choices[0].message.content
|
254 |
except Exception as e:
|
255 |
-
log.
|
256 |
-
raise
|
257 |
|
258 |
-
|
259 |
-
|
260 |
-
log.info(f"===== 處理新查詢: '{q_orig}' =====")
|
261 |
try:
|
262 |
-
|
263 |
-
if not drug_ids:
|
264 |
-
return f"抱歉,資料庫中找不到該藥品。請確認藥品名稱,或直接諮詢醫師/藥師。{DISCLAIMER}"
|
265 |
-
log.info(f"找到藥品 ID: {drug_ids}")
|
266 |
-
|
267 |
-
analysis = self._analyze_query(q_orig)
|
268 |
-
sub_queries, intents = analysis.get("sub_queries", [q_orig]), analysis.get("intents", [])
|
269 |
-
log.info(f"子問題: {sub_queries}, 意圖: {intents}")
|
270 |
-
|
271 |
-
all_candidates = self._retrieve_candidates_for_all_queries(drug_ids, sub_queries, intents)
|
272 |
-
log.info(f"所有子查詢共找到 {len(all_candidates)} 個不重複候選 chunks。")
|
273 |
-
|
274 |
-
reranked_results = self._rerank_with_crossencoder(q_orig, all_candidates)
|
275 |
-
log.info(f"Reranker 最終選出 {len(reranked_results)} 個高品質候選。")
|
276 |
-
|
277 |
-
context = self._build_context(reranked_results)
|
278 |
-
if not context:
|
279 |
-
return f"根據您的問題,找不到相關的具體說明。建議您直接諮詢醫師或藥師以獲得最準確的資訊。{DISCLAIMER}"
|
280 |
-
|
281 |
-
prompt = self._make_final_prompt(q_orig, context, intents)
|
282 |
-
answer = self._llm_call([{"role": "user", "content": prompt}])
|
283 |
-
|
284 |
-
final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
|
285 |
-
log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====")
|
286 |
-
return final_answer
|
287 |
-
|
288 |
except Exception as e:
|
289 |
-
log.
|
290 |
-
|
291 |
-
|
292 |
-
@lru_cache(maxsize=128)
def _find_drug_ids_from_name(self, query: str) -> List[str]:
    """Collect the drug ids whose known names occur in *query*.

    Two passes feed the same result set:

    1. Every key of ``self.drug_name_to_ids`` is looked up in the
       lower-cased query: keys containing CJK characters by plain substring,
       all other keys with a word-boundary regex.
    2. Every candidate returned by ``extract_drug_candidates_from_query`` is
       matched against the keys of ``self.drug_name_to_ids`` using the same
       CJK / word-boundary distinction.

    Returns:
        The matched ids as a list (built from a set, so unordered).
    """
    has_cjk = re.compile(r'[\u4e00-\u9fff]').search
    lowered = query.lower()
    alias_candidates = extract_drug_candidates_from_query(lowered, self.drug_vocab)
    matched_ids = set()

    # Pass 1: scan the query for every known name.
    for name_key, id_list in self.drug_name_to_ids.items():
        if has_cjk(name_key):
            found = name_key in lowered
        else:
            # Word boundaries keep a short Latin name from matching inside a longer word.
            found = re.search(rf"\b{re.escape(name_key)}\b", lowered) is not None
        if found:
            matched_ids.update(id_list)

    # Pass 2: legacy candidate path, kept as a supplement to pass 1.
    for alias in alias_candidates:
        if has_cjk(alias):
            hits = (ids for name, ids in self.drug_name_to_ids.items() if alias in name)
        else:
            # Latin aliases use word boundaries to avoid substring false positives.
            alias_re = re.compile(rf"\b{re.escape(alias)}\b")
            hits = (ids for name, ids in self.drug_name_to_ids.items() if alias_re.search(name))
        for ids in hits:
            matched_ids.update(ids)

    return list(matched_ids)
|
322 |
-
|
323 |
-
def _analyze_query(self, query: str) -> Dict[str, Any]:
    """Ask the LLM to break *query* into sub-queries and intent labels.

    The ``analyze_query`` prompt template is filled with one bullet per entry
    of ``INTENT_CATEGORIES`` and sent at temperature 0.1. The reply is handed
    to ``self._safe_json_parse``; the default passed to it treats the whole
    query as a single sub-query with no intents.
    """
    intent_options = "\n".join(f"- {c}" for c in INTENT_CATEGORIES)
    prompt = PROMPT_TEMPLATES["analyze_query"].format(options=intent_options, query=query)
    messages = [{"role": "user", "content": prompt}]
    raw_reply = self._llm_call(messages, temperature=0.1)
    fallback = {"sub_queries": [query], "intents": []}
    return self._safe_json_parse(raw_reply, default=fallback)
|
330 |
-
|
331 |
-
def _retrieve_candidates_for_all_queries(self, drug_ids: List[str], sub_queries: List[str], intents: List[str]) -> List[FusedCandidate]:
|
332 |
-
drug_ids_set = set(map(str, drug_ids))
|
333 |
-
relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id", "")) in drug_ids_set}
|
334 |
-
if not relevant_indices: return []
|
335 |
-
|
336 |
-
all_fused_candidates: Dict[int, FusedCandidate] = {}
|
337 |
-
|
338 |
-
for sub_q in sub_queries:
|
339 |
-
expanded_q = self._expand_query_with_llm(sub_q, tuple(intents))
|
340 |
-
|
341 |
-
q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
|
342 |
-
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
|
343 |
-
faiss.normalize_L2(q_emb)
|
344 |
-
distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
|
345 |
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
|
353 |
-
doc_to_bm25_score = {int(i): float(bm25_scores[i]) for i in top_rel}
|
354 |
-
|
355 |
-
candidate_scores: Dict[int, Dict[str, float]] = {}
|
356 |
-
|
357 |
-
# [MODIFIED] 把 distance 轉成「越大越好的相似度」
|
358 |
-
def to_similarity(d: float) -> float:
|
359 |
-
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
|
360 |
-
return float(d) # IP 越大越好
|
361 |
-
else: # METRIC_L2(多半是平方 L2)
|
362 |
-
return 1.0 / (1.0 + float(d))
|
363 |
-
|
364 |
-
for i, dist in zip(sim_indices[0], distances[0]):
|
365 |
-
if i in relevant_indices:
|
366 |
-
similarity = to_similarity(dist)
|
367 |
-
candidate_scores[int(i)] = {"sem": float(similarity), "bm": 0.0}
|
368 |
-
|
369 |
-
for i, score in doc_to_bm25_score.items():
|
370 |
-
if i in relevant_indices:
|
371 |
-
candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
|
372 |
-
|
373 |
-
if not candidate_scores: continue
|
374 |
-
|
375 |
-
# [MODIFIED] 使用固定的鍵順序來確保分數對齊
|
376 |
-
keys = list(candidate_scores.keys())
|
377 |
-
sem_scores = np.array([candidate_scores[k]['sem'] for k in keys])
|
378 |
-
bm_scores = np.array([candidate_scores[k]['bm'] for k in keys])
|
379 |
-
|
380 |
-
def norm(x):
|
381 |
-
rng = x.max() - x.min()
|
382 |
-
return (x - x.min()) / (rng + 1e-8) if rng > 0 else np.zeros_like(x)
|
383 |
|
384 |
-
|
385 |
-
|
386 |
-
for idx, k in enumerate(keys):
|
387 |
-
fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
|
388 |
-
|
389 |
-
if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
|
390 |
-
all_fused_candidates[k] = FusedCandidate(
|
391 |
-
idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
|
392 |
-
)
|
393 |
-
|
394 |
-
return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
|
395 |
-
|
396 |
-
# [MODIFIED] 移除 lru_cache,因對多變的長查詢效果不佳
|
397 |
-
def _expand_query_with_llm(self, query: str, intents: tuple) -> str:
    """Return an LLM-expanded version of *query*, or *query* itself on any problem.

    No LLM call is made when *intents* is empty. An empty or whitespace-only
    reply is logged as a warning and any exception as an error; both fall
    back to the original query.
    """
    if not intents:
        return query

    prompt = PROMPT_TEMPLATES["expand_query"].format(intents=list(intents), query=query)
    messages = [{"role": "user", "content": prompt}]

    try:
        expanded = self._llm_call(messages)
        usable = bool(expanded and expanded.strip())
    except Exception as e:
        log.error(f"Query expansion for '{query}' failed: {e}. Using original query.")
        return query

    if usable:
        return expanded
    log.warning(f"Query expansion for '{query}' returned an empty result. Using original query.")
    return query
|
413 |
|
414 |
-
def
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
425 |
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
439 |
|
440 |
-
return PROMPT_TEMPLATES["final_answer"].format(
|
441 |
-
additional_instruction=add_instr, context=context, query=query
|
442 |
-
)
|
443 |
-
|
444 |
-
# [MODIFIED] 增強 JSON 解析的穩健性,從字串中提取 JSON 物件
|
445 |
-
def _safe_json_parse(self, s: str, default: Any = None) -> Any:
|
446 |
-
m = re.search(r'\{.*?\}', s, re.DOTALL) # 非貪婪
|
447 |
-
if m:
|
448 |
-
s = m.group(0)
|
449 |
try:
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
454 |
|
455 |
-
|
456 |
-
app = FastAPI()
|
457 |
-
rag_pipeline: Optional[RagPipeline] = None
|
458 |
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
481 |
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
|
|
489 |
|
490 |
-
|
491 |
-
|
|
|
|
|
|
|
492 |
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
497 |
-
|
498 |
-
for event in data.get("events", []):
|
499 |
-
if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
|
500 |
-
reply_token = event.get("replyToken")
|
501 |
-
user_text = event.get("message", {}).get("text", "").strip()
|
502 |
-
# [MODIFIED] 擷取 target
|
503 |
-
source = event.get("source", {})
|
504 |
-
stype = source.get("type") # "user" | "group" | "room"
|
505 |
-
target_id = source.get("userId") or source.get("groupId") or source.get("roomId")
|
506 |
-
|
507 |
-
if reply_token and user_text and target_id:
|
508 |
-
# [MODIFIED] 更改回覆策略:立即回覆處理中訊息,避免 replyToken 逾時
|
509 |
-
line_reply(reply_token, "收到您的問題,正在查詢資料庫,請稍候...")
|
510 |
-
# 將耗時的任務交給背景處理,使用 push message 回覆最終答案
|
511 |
-
background_tasks.add_task(process_user_query, stype, target_id, user_text)
|
512 |
-
|
513 |
-
return Response(status_code=status.HTTP_200_OK)
|
514 |
-
|
515 |
-
# [MODIFIED] 調整函式簽名,只接收 user_id 和 text,並使用 push message
|
516 |
-
def process_user_query(source_type: str, target_id: str, user_text: str):
    """Answer *user_text* and push the result to *target_id*.

    Uses ``rag_pipeline.answer_question`` when the pipeline object is set,
    otherwise sends a "still starting up" notice. Any exception is logged
    with its traceback and an apology (plus ``DISCLAIMER``) is pushed instead.
    """
    try:
        answer = (
            rag_pipeline.answer_question(user_text)
            if rag_pipeline
            else "系統正在啟動中,請稍後再試。"
        )
        line_push_generic(source_type, target_id, answer)
    except Exception as exc:
        log.error(f"背景處理 target_id={target_id} 發生錯誤: {exc}", exc_info=True)
        apology = f"抱歉,處理時發生未預期的錯誤。{DISCLAIMER}"
        line_push_generic(source_type, target_id, apology)
|
526 |
-
|
527 |
-
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
528 |
-
def line_api_call(endpoint: str, data: Dict):
|
529 |
headers = {
|
530 |
"Content-Type": "application/json",
|
531 |
-
"Authorization": f"Bearer {
|
532 |
}
|
533 |
try:
|
534 |
response = requests.post(f"https://api.line.me/v2/bot/message/{endpoint}", headers=headers, json=data, timeout=10)
|
@@ -538,32 +403,78 @@ def line_api_call(endpoint: str, data: Dict):
|
|
538 |
raise
|
539 |
|
540 |
def line_reply(reply_token: str, text: str):
|
541 |
-
|
|
|
542 |
line_api_call("reply", {"replyToken": reply_token, "messages": messages})
|
543 |
|
544 |
def line_push_generic(source_type: str, target_id: str, text: str):
|
545 |
-
|
|
|
546 |
endpoint = "push"
|
547 |
data = {"to": target_id, "messages": messages}
|
548 |
line_api_call(endpoint, data)
|
549 |
|
550 |
-
#
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
565 |
|
566 |
-
# ---------- 執行 ----------
|
567 |
if __name__ == "__main__":
|
568 |
-
|
569 |
-
uvicorn.run(app, host="0.0.0.0", port=port)
|
|
|
40 |
import faiss
|
41 |
import torch
|
42 |
from openai import OpenAI
|
|
|
43 |
import requests
|
44 |
|
45 |
+
# ---------- 應用程式設定與環境變數 ----------
|
46 |
+
# 預設值皆針對 Dockerfile 設定
|
47 |
+
SECRET_TOKEN = os.getenv("LINE_CHANNEL_SECRET", "YOUR_SECRET_TOKEN")
|
48 |
+
ACCESS_TOKEN = os.getenv("LINE_CHANNEL_ACCESS_TOKEN", "YOUR_ACCESS_TOKEN")
|
49 |
+
RERANKER_MODEL = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
|
50 |
+
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
|
51 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
+
# 資料來源檔案路徑
|
54 |
+
SENTENCE_FAISS = os.getenv("SENTENCE_FAISS", "drug_sentences.index")
|
55 |
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
|
56 |
BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
|
57 |
+
DRUG_VOCAB_JSON = os.getenv("DRUG_VOCAB_JSON", "drug_vocab.json")
|
58 |
+
PHARMACY_DATA = os.getenv("PHARMACY_DATA", "pharmacy_data.csv")
|
59 |
|
60 |
+
# 針對 LINE API 訊息長度限制
|
61 |
+
MAX_REPLY_LEN = 4800
|
|
|
62 |
|
63 |
+
# 設定日誌
|
64 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
65 |
+
log = logging.getLogger(__name__)
|
66 |
|
67 |
+
# --- 自定義資料結構 ---
|
68 |
+
@dataclass
class RagState:
    """Single storage location for application state and RAG model objects."""

    # Every slot starts out as None.
    faiss_index: Optional[faiss.Index] = field(default=None)
    sentences: Optional[List[str]] = field(default=None)
    meta: Optional[List[Dict]] = field(default=None)
    bm25: Optional[BM25Okapi] = field(default=None)
    bm25_tokenized: Optional[List[List[str]]] = field(default=None)
    reranker: Optional[CrossEncoder] = field(default=None)
    embedding_model: Optional[SentenceTransformer] = field(default=None)
    drug_vocab: Optional[Dict[str, str]] = field(default=None)
    pharmacy_df: Optional[pd.DataFrame] = field(default=None)
    openai_client: Optional[OpenAI] = field(default=None)
|
81 |
+
|
82 |
+
@dataclass
class IntentClassifier:
    """Zero-shot intent classifier backed by an OpenAI chat model.

    Attributes:
        client: OpenAI client used for the chat-completion call.
        prompt_template: Prompt with a single ``{query}`` placeholder; built
            in ``__post_init__`` and not accepted by ``__init__``.
        model: Chat model name. Defaults to the value that was previously
            hard-coded, so existing ``IntentClassifier(client)`` calls behave
            the same.
    """
    client: OpenAI
    prompt_template: str = field(init=False)
    model: str = "gpt-3.5-turbo"

    # Deliberately un-annotated so the dataclass machinery does not turn it into a field.
    _KNOWN_INTENTS = ("general_qa", "drug_inquiry", "pharmacy_search")

    def __post_init__(self):
        """Build the classification prompt."""
        self.prompt_template = """你是一個能判斷使用者意圖的 AI 助手。
請根據以下提供的意圖清單,判斷使用者查詢的意圖。

意圖清單:
- general_qa: 提問有關藥物或健康資訊的通用問題。
- drug_inquiry: 查詢特定藥物的資訊。
- pharmacy_search: 詢問藥局的相關資訊。

使用者查詢:
{query}

請直接回覆一個意圖,例如:
general_qa
"""

    @classmethod
    def _match_intent(cls, text: str) -> Optional[str]:
        """Return the known intent named in *text*, or None if there is none.

        An exact match wins. Otherwise the label that appears earliest in the
        text is used, so replies such as ``"drug_inquiry."`` or
        ``"intent: drug_inquiry"`` are still recognised.
        """
        if text in cls._KNOWN_INTENTS:
            return text
        found = [(text.find(name), name) for name in cls._KNOWN_INTENTS if name in text]
        return min(found)[1] if found else None

    def classify_intent(self, query: str) -> str:
        """Classify *query* with the LLM.

        Returns:
            One of ``general_qa``, ``drug_inquiry`` or ``pharmacy_search``.
            ``general_qa`` is also returned when the reply names no known
            intent or when the API call fails.
        """
        log.info(f"分類意圖:{query}")
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "user", "content": self.prompt_template.format(query=query)}
                ],
                temperature=0.0,
            )
            # The message content may be None; treat that as an empty reply.
            raw = (response.choices[0].message.content or "").strip().lower()
            intent = self._match_intent(raw)
            if intent is None:
                log.warning(f"偵測到未知意圖:{raw},將視為 general_qa。")
                return "general_qa"
            log.info(f"意圖判定:{intent}")
            return intent
        except Exception as e:
            # Classification is best-effort: any failure degrades to the generic flow.
            log.error(f"意圖分類失敗:{e},將使用預設意圖 general_qa。")
            return "general_qa"
|
124 |
|
125 |
+
# --- RAG 流程與核心邏輯 ---
|
126 |
@dataclass
class RagPipeline:
    """Core retrieval-augmented-generation pipeline.

    Every loaded model, index and data table is kept in ``self.state``.
    ``load_data`` fills it; ``handle_rag_query`` is the entry point that turns
    a user query into a reply string.
    """
    state: RagState = field(default_factory=RagState)

    def load_data(self):
        """Load all models, indexes and data files into ``self.state``.

        The embedding model, the FAISS index with its sentence pickle, and the
        BM25 pickle are mandatory: a failure is logged and re-raised. The
        reranker, the drug vocabulary and the pharmacy table are optional: a
        failure is logged and the feature is left disabled.

        Raises:
            Exception: whatever the failing mandatory load raised.
        """
        log.info("開始載入資料與模型...")

        # Embedding model (mandatory).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        log.info(f"載入 embedding 模型:{EMBEDDING_MODEL} 至 {device}...")
        try:
            self.state.embedding_model = SentenceTransformer(EMBEDDING_MODEL, device=device)
        except Exception as e:
            log.error(f"載入 embedding 模型失敗:{e}")
            raise

        # Reranker (optional): rerank_results degrades gracefully when it is None.
        log.info(f"載入 reranker 模型:{RERANKER_MODEL} 至 {device}...")
        try:
            self.state.reranker = CrossEncoder(RERANKER_MODEL, device=device)
        except Exception as e:
            log.error(f"載入 reranker 模型失敗:{e}")
            self.state.reranker = None

        # FAISS index and the sentence / metadata lists (mandatory).
        log.info("載入 FAISS 索引與句子資料...")
        try:
            self.state.faiss_index = faiss.read_index(SENTENCE_FAISS)
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # only load artefacts produced by a trusted build step.
            with open(SENTENCES_PKL, "rb") as f:
                sentence_data = pickle.load(f)
            self.state.sentences = sentence_data["sentences"]
            self.state.meta = sentence_data["meta"]
        except Exception as e:
            log.error(f"載入 FAISS 或句子 PKL 失敗:{e}")
            raise

        # BM25 index (mandatory). The pickle must be a dict holding the BM25
        # object; its sentence / metadata lists replace the ones loaded above.
        log.info("載入 BM25 索引...")
        try:
            with open(BM25_PKL, "rb") as f:
                bm25_data = pickle.load(f)
            if (
                not isinstance(bm25_data, dict)
                or "bm25" not in bm25_data
                or not isinstance(bm25_data["bm25"], BM25Okapi)
            ):
                raise ValueError("Loaded BM25 is not a BM25Okapi instance or the pickle file has an unexpected format.")
            self.state.bm25 = bm25_data["bm25"]
            self.state.bm25_tokenized = bm25_data["tokenized"]
            self.state.sentences = bm25_data["sentences"]
            self.state.meta = bm25_data["meta"]
        except Exception as e:
            log.error(f"載入 BM25 索引失敗:{e}")
            raise

        # Drug-name vocabulary (optional), also registered with jieba for tokenisation.
        log.info("建立藥名詞庫...")
        try:
            with open(DRUG_VOCAB_JSON, "r", encoding="utf-8") as f:
                self.state.drug_vocab = json.load(f)
            for drug_name in self.state.drug_vocab:
                jieba.add_word(drug_name.lower())
        except Exception as e:
            log.warning(f"載入藥名詞庫失敗:{e}。部分藥名可能無法正確斷詞。")

        # Pharmacy table (optional).
        log.info("載入藥局資料...")
        try:
            # Empty CSV cells are read as NaN even with dtype=str; blank them
            # so replies do not show the literal text "nan".
            self.state.pharmacy_df = pd.read_csv(PHARMACY_DATA, dtype=str).fillna("")
        except Exception as e:
            log.warning(f"載入藥局資料失敗:{e}。藥局查詢功能將無法使用。")
            self.state.pharmacy_df = pd.DataFrame()  # keep it an empty DataFrame, never None

        # OpenAI client.
        if OPENAI_API_KEY:
            self.state.openai_client = OpenAI(api_key=OPENAI_API_KEY)
            log.info("OpenAI 客戶端初始化完成。")
        else:
            log.warning("未設定 OPENAI_API_KEY,意圖偵測與 LLM 回覆功能將無法使用。")

        log.info("所有資料與模型載入完成。")

    def retrieve_by_faiss(self, query: str, top_k: int = 10) -> Tuple[List[str], List[Dict]]:
        """Dense retrieval: return the ``top_k`` nearest sentences and their metadata.

        Returns two empty lists when the embedding model or the index is not loaded.
        """
        if self.state.embedding_model is None or self.state.faiss_index is None:
            log.error("FAISS 或 Embedding 模型未載入。")
            return [], []

        query_emb = self.state.embedding_model.encode(query, convert_to_numpy=True).astype("float32")
        # faiss.normalize_L2 needs a 2-D (n, d) array, so reshape BEFORE
        # normalising; on the raw 1-D vector it fails.
        query_emb = query_emb.reshape(1, -1)
        faiss.normalize_L2(query_emb)

        _, indices = self.state.faiss_index.search(query_emb, top_k)
        # -1 marks the slots FAISS could not fill.
        hits = [int(i) for i in indices[0] if i != -1]
        sentences = [self.state.sentences[i] for i in hits]
        metas = [self.state.meta[i] for i in hits]
        return sentences, metas

    def retrieve_by_bm25(self, query: str, top_k: int = 10) -> Tuple[List[str], List[Dict]]:
        """Sparse retrieval: return the ``top_k`` highest-scoring BM25 sentences and metadata.

        Returns two empty lists when the BM25 index is not loaded.
        """
        if self.state.bm25 is None:
            log.error("BM25 模型未載入。")
            return [], []

        query_tokenized = jieba.lcut(query)
        doc_scores = self.state.bm25.get_scores(query_tokenized)
        top_k_indices = np.argsort(doc_scores)[::-1][:top_k]

        sentences = [self.state.sentences[i] for i in top_k_indices]
        metas = [self.state.meta[i] for i in top_k_indices]
        return sentences, metas

    def rerank_results(self, query: str, pairs: List[Tuple[str, str]]) -> List[Dict]:
        """Score ``(query, sentence)`` pairs with the reranker, best first.

        Each result is ``{"text", "score", "source"}``. Without a reranker the
        pairs are returned in their original order with score 0.0.
        *query* is not read here; the pairs already carry it.
        """
        if self.state.reranker is None:
            log.warning("Reranker 模型未載入,將略過重排序。")
            return [{"text": pair[1], "score": 0.0, "source": ""} for pair in pairs]

        scores = self.state.reranker.predict(pairs)
        results = [{"text": pair[1], "score": score, "source": ""} for pair, score in zip(pairs, scores)]
        results.sort(key=lambda x: x["score"], reverse=True)
        return results

    def handle_rag_query(self, query: str) -> str:
        """Answer *query*: classify its intent, retrieve context, generate a reply.

        * ``pharmacy_search`` returns a formatted pharmacy list, or a
          "nothing found" notice, without calling the answer LLM.
        * ``drug_inquiry`` retrieves context for each drug name found in the
          query; with no recognised drug the LLM is called with empty context.
        * Any other intent uses plain hybrid retrieval.

        Failures are logged with a traceback and a generic apology is
        returned. The exception text is not sent to the user because provider
        errors can contain internal details.
        """
        if not self.state.openai_client:
            return "無法使用 RAG 功能,請檢查 OPENAI_API_KEY 設定。"

        try:
            # 1. Intent detection.
            intent = IntentClassifier(self.state.openai_client).classify_intent(query)

            # 2. Intent-specific retrieval.
            if intent == "pharmacy_search":
                pharmacy_candidates = search_pharmacy(query, self.state.pharmacy_df)
                if not pharmacy_candidates:
                    return "很抱歉,沒有找到符合條件的藥局。"
                return "為您找到以下藥局資訊:\n" + "\n---\n".join(
                    f"藥局名稱:{p['醫事機構名稱']}\n電話:{p['醫事機構電話']}\n地址:{p['醫事機構地址']}"
                    for p in pharmacy_candidates
                )

            if intent == "drug_inquiry":
                # The vocabulary is None when its file failed to load.
                drug_candidates = extract_drug_candidates_from_query(query, self.state.drug_vocab or {})
                if not drug_candidates:
                    # The context parameter is a string, so pass "" rather than a list.
                    return self.generate_llm_response(query, "", intent)
                contexts = []
                for drug_name in drug_candidates:
                    contexts.extend(self.retrieve_and_rerank(query, specific_drug=drug_name))
            else:  # general_qa
                contexts = self.retrieve_and_rerank(query)

            # Several drug names can surface the same sentence; keep the first occurrence only.
            final_context = "\n".join(dict.fromkeys(c["text"] for c in contexts))
            return self.generate_llm_response(query, final_context, intent)

        except Exception as e:
            log.error(f"RAG 查詢處理失敗:{e}", exc_info=True)
            return "對不起,處理您的查詢時發生錯誤,請稍後再試。"

    def retrieve_and_rerank(self, query: str, specific_drug: Optional[str] = None) -> List[Dict]:
        """Hybrid retrieval (FAISS + BM25) followed by reranking; returns the top 5.

        When *specific_drug* is given, every sentence whose metadata names
        that drug is put in front of the retrieved candidates before
        reranking.
        """
        # 1. Retrieve.
        log.info(f"執行檢索:{query} (藥物:{specific_drug})")
        faiss_sents, _ = self.retrieve_by_faiss(query, top_k=20)
        bm25_sents, _ = self.retrieve_by_bm25(query, top_k=20)

        # 2. Merge and de-duplicate, keeping first-seen order.
        combined_sents = list(dict.fromkeys(faiss_sents + bm25_sents))

        # 3. Add the drug's own sentences.
        if specific_drug:
            initial_sentences = self.get_sentences_by_drug_name(specific_drug)
            combined_sents = list(dict.fromkeys(initial_sentences + combined_sents))

        if not combined_sents:
            return []

        # 4. Rerank.
        log.info("執行重排序...")
        pairs = [(query, s) for s in combined_sents]
        return self.rerank_results(query, pairs)[:5]

    def get_sentences_by_drug_name(self, drug_name: str) -> List[str]:
        """Return the sentences whose metadata ``drug_name_norm`` equals *drug_name* lower-cased."""
        if self.state.meta is None or self.state.sentences is None:
            return []
        wanted = drug_name.lower()
        return [
            sentence
            for meta, sentence in zip(self.state.meta, self.state.sentences)
            if meta.get("drug_name_norm") == wanted
        ]

    def generate_llm_response(self, query: str, context: str, intent: str) -> str:
        """Generate the final reply from *query* and the retrieved *context*.

        Returns the model's text stripped of surrounding whitespace, or a
        fixed apology when the call fails.
        """
        log.info(f"使用 LLM 生成回覆,意圖:{intent}")
        system_prompt = f"""你是一個專業的藥物與健康資訊問答助理。
- 請根據使用者提供的「使用者查詢」與「相關資訊」來回答問題。
- 如果「相關資訊」中沒有足夠的資訊來回答,請禮貌地告知使用者。
- 你的回答應簡潔、易懂,並使用繁體中文。
- 在回答中,應明確指出資訊來源是來自衛福部、藥廠、或是其他相關法規文件。
- 如果使用者詢問的是特定藥物,請在回答中提及藥名。
- 如果意圖是 `pharmacy_search`,請直接告知使用者這是關於藥局的查詢,並說「很抱歉,我無法提供藥局資訊。」
- 如果意圖是 `general_qa`,且相關資訊不足,請回答「對不起,我無法回答您的問題。」
- 如果意圖是 `drug_inquiry`,且相關資訊不足,請回答「對不起,我無法找到該藥物的相關資訊。」

相關資訊:
{context}

使用者查詢:
{query}
"""
        try:
            response = self.state.openai_client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"問題:{query}"}
                ],
                temperature=0.0
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            log.error(f"LLM 回覆生成失敗:{e}")
            return "對不起,生成回覆時發生錯誤。"
|
362 |
+
|
363 |
+
# --- 藥物與藥局輔助函式 ---
|
364 |
+
def search_pharmacy(query: str, df: pd.DataFrame) -> List[Dict]:
    """Return the pharmacy rows that contain at least one keyword of *query*.

    The query is tokenised with ``jieba.lcut_for_search``; a row matches when
    any non-blank token is a substring of the row's ``to_string()`` rendering.
    Rows are returned as dicts in DataFrame order.

    Args:
        query: Free-text user query.
        df: Pharmacy table. ``None`` or an empty frame yields ``[]``.
    """
    if df is None or df.empty:
        return []

    # Drop whitespace-only tokens: the rendered row contains spaces, so a
    # bare " " token would make (nearly) every row match.
    keywords = [tok for tok in (t.strip() for t in jieba.lcut_for_search(query)) if tok]
    if not keywords:
        return []

    results = []
    for _, row in df.iterrows():
        row_text = row.to_string()  # render once per row, not once per keyword
        if any(k in row_text for k in keywords):
            results.append(row.to_dict())
    return results
|
375 |
|
376 |
+
def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
    """Return the normalized drug names whose vocabulary key occurs in *query*.

    Matching is a case-insensitive substring test of each vocabulary key
    against the query.

    Args:
        query: User query text. ``None`` or empty yields ``[]``.
        drug_vocab: Mapping of surface drug name to normalized name.
            ``None`` or empty yields ``[]``.

    Returns:
        Unique normalized names, in the order their keys appear in *drug_vocab*.
    """
    if not query or not drug_vocab:
        return []

    q_lower = query.lower()
    # A dict is used as an ordered set so the result order is deterministic.
    matches = {}
    for drug_name, normalized_name in drug_vocab.items():
        # Lower-case the key too: the query is lower-cased, so a mixed-case
        # key could otherwise never match.
        key = str(drug_name).strip().lower()
        # An empty key is a substring of every query; skip it.
        if key and key in q_lower:
            matches.setdefault(normalized_name, None)
    return list(matches)
|
384 |
|
385 |
+
# --- LINE API 相關函式 ---
|
386 |
+
def validate_signature(request_body: bytes, signature: str):
    """Verify the signature LINE sent with a webhook request.

    Computes the HMAC-SHA256 of *request_body* keyed with ``SECRET_TOKEN``,
    base64-encodes it, and compares it to *signature* with
    ``hmac.compare_digest``.

    Returns:
        True when the two values match, otherwise False.
    """
    mac = hmac.new(SECRET_TOKEN.encode('utf-8'), request_body, hashlib.sha256)
    expected = base64.b64encode(mac.digest())
    received = signature.encode('utf-8')
    return hmac.compare_digest(received, expected)
|
390 |
|
391 |
+
@lru_cache(maxsize=128)
|
392 |
+
def line_api_call(endpoint: str, data: dict):
|
393 |
+
"""呼叫 LINE Messaging API。"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
394 |
headers = {
|
395 |
"Content-Type": "application/json",
|
396 |
+
"Authorization": f"Bearer {ACCESS_TOKEN}"
|
397 |
}
|
398 |
try:
|
399 |
response = requests.post(f"https://api.line.me/v2/bot/message/{endpoint}", headers=headers, json=data, timeout=10)
|
|
|
403 |
raise
|
404 |
|
405 |
def line_reply(reply_token: str, text: str):
    """Reply to a LINE message.

    *text* is split with ``textwrap.wrap`` into chunks of at most
    ``MAX_REPLY_LEN`` characters; only the first five chunks are sent, each
    as its own text message, through the ``reply`` endpoint.
    """
    chunks = textwrap.wrap(text, MAX_REPLY_LEN, replace_whitespace=False)
    payload = {
        "replyToken": reply_token,
        "messages": [{"type": "text", "text": part} for part in chunks[:5]],
    }
    line_api_call("reply", payload)
|
409 |
|
410 |
def line_push_generic(source_type: str, target_id: str, text: str):
    """Push a LINE message to *target_id*.

    *text* is split with ``textwrap.wrap`` into chunks of at most
    ``MAX_REPLY_LEN`` characters; only the first five chunks are sent, each
    as its own text message, through the ``push`` endpoint.
    *source_type* is accepted but not read here.
    """
    chunks = textwrap.wrap(text, MAX_REPLY_LEN, replace_whitespace=False)
    messages = [{"type": "text", "text": part} for part in chunks[:5]]
    line_api_call("push", {"to": target_id, "messages": messages})
|
416 |
|
417 |
+
# --- FastAPI 應用程式設定 ---
|
418 |
+
app = FastAPI()
|
419 |
+
rag_pipeline = RagPipeline()
|
420 |
+
|
421 |
+
@app.on_event("startup")
async def startup_event():
    """Load all RAG models and data when the application starts.

    A load failure is fatal: it is logged with its traceback and the process
    exits immediately with status 3.
    """
    try:
        rag_pipeline.load_data()
    except Exception as e:
        # Keep the traceback: after the hard exit below, this log record is
        # the only trace of why startup failed.
        log.error(f"應用程式啟動失敗:{e}", exc_info=True)
        # Exit code 3 marks an application-level failure for CI/CD. In a
        # Docker environment it ends the container; when run locally it
        # simply ends the program.
        os._exit(3)
|
432 |
+
|
433 |
+
@app.get("/")
async def root():
    """Root path, used as a health check; returns a fixed greeting."""
    greeting = "Hello, I am a DrugQA bot! Use me with LINE."
    return {"message": greeting}
|
437 |
+
|
438 |
+
@app.post("/callback")
async def callback(request: Request, background_tasks: BackgroundTasks):
    """Handle a LINE webhook call.

    After the signature is verified, one background task is queued for each
    text-message event. Events of other kinds, or events missing a reply
    token or text, are skipped so that one odd event cannot fail the whole
    delivery.

    Raises:
        HTTPException: 400 when the signature header is missing or invalid,
            or when the body is not a JSON object with an ``events`` list;
            500 on any unexpected error while queueing.
    """
    signature = request.headers.get("X-Line-Signature")
    if not signature:
        raise HTTPException(status_code=400, detail="X-Line-Signature header is missing.")

    body = await request.body()
    if not validate_signature(body, signature):
        raise HTTPException(status_code=400, detail="Invalid signature.")

    try:
        payload = json.loads(body)
    except ValueError:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        raise HTTPException(status_code=400, detail="Invalid JSON body.")

    events = payload.get("events") if isinstance(payload, dict) else None
    if not isinstance(events, list):
        raise HTTPException(status_code=400, detail="Invalid JSON body.")

    try:
        for event in events:
            if not isinstance(event, dict) or event.get("type") != "message":
                continue
            message = event.get("message")
            if not isinstance(message, dict) or message.get("type") != "text":
                continue
            reply_token = event.get("replyToken")
            query_text = message.get("text")
            if not reply_token or not query_text:
                continue
            # The RAG work is slow, so run it in the background.
            background_tasks.add_task(process_user_message, reply_token, query_text)
        return "OK"
    except Exception as e:
        log.error(f"處理 LINE 訊息失敗:{e}")
        raise HTTPException(status_code=500, detail="Internal Server Error.")
|
463 |
+
|
464 |
+
def process_user_message(reply_token: str, query: str):
    """Process one user message in the background and reply to it.

    Runs the RAG pipeline, logs the elapsed time and sends the answer with
    *reply_token*. On any failure the error is logged with its traceback and
    a fallback apology is attempted. A failure of that fallback is logged
    too, so this function never raises.
    """
    # NOTE(review): the reply token is used only after the RAG run finishes;
    # reply tokens appear to be time-limited, so confirm a slow run cannot
    # outlive the token.
    try:
        start_time = time.time()
        response = rag_pipeline.handle_rag_query(query)
        end_time = time.time()
        log.info(f"查詢 '{query}' 處理完成,耗時 {end_time - start_time:.2f} 秒。")

        line_reply(reply_token, response)
    except Exception as e:
        log.error(f"背景任務執行失敗:{e}", exc_info=True)
        try:
            line_reply(reply_token, "對不起,服務目前無法使用,請稍後再試。")
        except Exception as reply_err:
            # The failure may have been the reply call itself; a second
            # failure must not escape the background task.
            log.error(f"無法回覆錯誤訊息:{reply_err}")
|
478 |
|
|
|
479 |
if __name__ == "__main__":
    # Listen on the port named by the PORT environment variable, 8000 by default.
    listen_port = int(os.getenv("PORT", 8000))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, log_level="info")
|
|