Spaces:
Sleeping
Sleeping
Song
commited on
Commit
·
c40d5cc
1
Parent(s):
f03b053
hi
Browse files
app.py
CHANGED
@@ -26,7 +26,7 @@ for d in (os.getenv("HF_HOME"), os.getenv("SENTENCE_TRANSFORMERS_HOME"), os.gete
|
|
26 |
|
27 |
# ---------- Imports ----------
|
28 |
import re, hmac, base64, hashlib, pickle, logging, time, json
|
29 |
-
from typing import List, Dict, Any, Optional, Tuple
|
30 |
|
31 |
import numpy as np
|
32 |
import pandas as pd
|
@@ -146,6 +146,7 @@ SECTION_WEIGHTS = {
|
|
146 |
}
|
147 |
|
148 |
IMPORTANT_SECTIONS = ["用法及用量", "病人使用須知", "包裝及儲存", "不良反應", "警語及注意事項"]
|
|
|
149 |
|
150 |
# ---------- 路徑工具 ----------
|
151 |
def pick_existing_or_tmp(candidates: List[str]) -> str:
|
@@ -246,7 +247,7 @@ class State:
|
|
246 |
faiss_index: Optional[Any] = None
|
247 |
bm25: Optional[Any] = None
|
248 |
df_csv: Optional[pd.DataFrame] = None
|
249 |
-
user_sessions: Dict[str, Dict[str, Any]] = {}
|
250 |
query_cache: Dict[str, Dict[str, Any]] = {}
|
251 |
|
252 |
STATE = State()
|
@@ -320,7 +321,8 @@ def ensure_bm25(pkl_path: str, sentences: List[str]) -> Optional[Any]:
|
|
320 |
try:
|
321 |
with open(pkl_path, "rb") as f:
|
322 |
bm = pickle.load(f)
|
323 |
-
|
|
|
324 |
if n_bm == len(sentences):
|
325 |
log.info("Loaded BM25: %s (n=%d)", pkl_path, n_bm)
|
326 |
return bm
|
@@ -334,42 +336,141 @@ def ensure_bm25(pkl_path: str, sentences: List[str]) -> Optional[Any]:
|
|
334 |
safe_pickle_dump(bm, pkl_path)
|
335 |
return bm
|
336 |
|
337 |
-
# ----------
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
348 |
for cand in candidates:
|
349 |
-
|
350 |
-
if cand in DRUG_NAME_MAPPING:
|
351 |
-
drug_ids.add(DRUG_NAME_MAPPING[cand])
|
352 |
-
continue
|
353 |
-
bucket = []
|
354 |
-
for _, row in df.iterrows():
|
355 |
name_joined = f"{(row.get('drug_name_zh') or '').lower()} {(row.get('drug_name_en') or '').lower()} {(row.get('drug_name_norm') or '').lower()}".strip()
|
356 |
-
|
357 |
-
|
|
|
|
|
|
|
358 |
fuzz.token_set_ratio(cand, name_joined),
|
359 |
fuzz.partial_ratio(cand, name_joined)
|
360 |
)
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
score =
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
|
374 |
# ---------- 意圖偵測 ----------
|
375 |
def detect_intent(query: str) -> List[str]:
|
@@ -402,58 +503,71 @@ def rerank_results(query: str, candidates: List[Tuple[int, float, float, float]]
|
|
402 |
log.warning("Rerank failed: %s", e)
|
403 |
return [{"idx": i, "score": fused} for i, fused, _, _ in sorted(candidates, key=lambda x: -x[1])[:top_k]] # fallback
|
404 |
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
|
|
410 |
return STATE.query_cache[cache_key]['idxs']
|
411 |
-
|
412 |
-
log.info("
|
413 |
-
if not
|
414 |
-
log.warning("No
|
415 |
-
|
416 |
-
|
|
|
|
|
|
|
417 |
if bm25:
|
418 |
-
|
419 |
-
|
420 |
-
if np.max(
|
421 |
-
scores_norm = (
|
422 |
else:
|
423 |
-
scores_norm =
|
424 |
-
|
425 |
-
for
|
426 |
-
if 0 <=
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
sem_results = []
|
431 |
if emb_model and index:
|
432 |
-
q_emb = emb_model.encode([
|
433 |
_, idxs = index.search(q_emb, top_k * 8)
|
434 |
-
for rank,
|
435 |
-
if 0 <=
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
440 |
-
for
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
453 |
idxs = [r["idx"] for r in reranked]
|
|
|
454 |
STATE.query_cache[cache_key] = {'idxs': idxs, 'time': time.time()}
|
455 |
return idxs
|
456 |
|
|
|
457 |
def build_context(idxs: List[int], sentences: List[str], meta: List[Dict[str, Any]]) -> str:
|
458 |
ctx_lines, total_len, seen = [], 0, set()
|
459 |
for i in idxs:
|
@@ -461,34 +575,40 @@ def build_context(idxs: List[int], sentences: List[str], meta: List[Dict[str, An
|
|
461 |
text = sentences[i]
|
462 |
if text in seen: continue
|
463 |
chunk_id = meta[i].get("chunk_id", "None")
|
464 |
-
|
|
|
465 |
if total_len + len(line) > MAX_CONTEXT_CHARS: break
|
466 |
ctx_lines.append(line)
|
467 |
total_len += len(line) + 1
|
468 |
seen.add(text)
|
469 |
-
return "\n".join(ctx_lines) or "[
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
if
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
elif intent in ["時間/併用 (Timing & Interaction)", "劑量調整 (Dosage Adjustment)"]:
|
485 |
-
ts_parts.append("優先藥袋醫囑(如每日1顆,早餐後)。範圍 [Sxxx]。特殊:病人使用須知。")
|
486 |
-
trouble_shooting = " ".join(ts_parts)
|
487 |
return (
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
492 |
)
|
493 |
|
494 |
def call_llm(prompt: str, max_tokens: int = 2048) -> Optional[str]:
|
@@ -505,12 +625,9 @@ def call_llm(prompt: str, max_tokens: int = 2048) -> Optional[str]:
|
|
505 |
t0 = time.time()
|
506 |
resp = client.chat.completions.create(
|
507 |
model=LM_MODEL,
|
508 |
-
messages=[
|
509 |
-
|
510 |
-
|
511 |
-
],
|
512 |
-
temperature=0.2,
|
513 |
-
timeout=10,
|
514 |
max_tokens=max_tokens,
|
515 |
)
|
516 |
used = time.time() - t0
|
@@ -520,142 +637,77 @@ def call_llm(prompt: str, max_tokens: int = 2048) -> Optional[str]:
|
|
520 |
log.warning("LLM 失敗:%s", e)
|
521 |
return None
|
522 |
|
523 |
-
|
524 |
-
|
525 |
-
parsed = {"drug_name": "", "dosage_form": "", "strength": "", "symptom": "", "usage_record": "", "patient_info": "", "concomitant": "", "question": query}
|
526 |
-
# 改進正則:支援中英、單位如 mcg/hr
|
527 |
-
drug_match = re.search(r"([\w\s]+?)(?:\s*(\d+(?:\.\d+)?\s*(?:mg|mcg|µg|g|mcg/hr|mg/hr))?\s*(\w+))?", query, re.I)
|
528 |
-
if drug_match:
|
529 |
-
parsed["drug_name"] = drug_match.group(1).strip()
|
530 |
-
parsed["strength"] = drug_match.group(2) or ""
|
531 |
-
parsed["dosage_form"] = drug_match.group(3) or ""
|
532 |
-
symptom_match = re.search(r"(頭痛|發燒|拉肚子|症狀|目的)\s*([^,;]*)", query)
|
533 |
-
if symptom_match:
|
534 |
-
parsed["symptom"] = symptom_match.group(2).strip()
|
535 |
-
usage_match = re.search(r"(吃|用|貼|喝|注射)\s*(\d+)\s*(\w+)", query)
|
536 |
-
if usage_match:
|
537 |
-
parsed["usage_record"] = f"{usage_match.group(2)} {usage_match.group(3)}"
|
538 |
-
patient_match = re.search(r"(成人|兒童|孕|哺|過敏|肝|腎)", query)
|
539 |
-
if patient_match:
|
540 |
-
parsed["patient_info"] = patient_match.group(1)
|
541 |
-
concomitant_match = re.search(r"(併用|其他藥|保健|咖啡|酒|空腹)\s*([^,;]*)", query)
|
542 |
-
if concomitant_match:
|
543 |
-
parsed["concomitant"] = concomitant_match.group(2).strip()
|
544 |
-
return parsed
|
545 |
-
|
546 |
-
def make_clarify_message(drug_name_hint: str = "") -> str:
|
547 |
msg = (
|
548 |
-
"
|
549 |
-
"1.
|
550 |
-
"2.
|
551 |
-
"3.
|
|
|
552 |
)
|
553 |
-
|
554 |
-
msg = (
|
555 |
-
"目前無法識別特定藥名,我會先提供一般性建議。"
|
556 |
-
"請補充藥名、劑型與強度。\n" + msg
|
557 |
-
)
|
558 |
-
return msg + DISCLAIMER
|
559 |
-
|
560 |
-
# ---------- 新增: 藥名處理 ----------
|
561 |
-
def process_drug_names(drug_ids: List[str]) -> List[str]:
|
562 |
-
if not drug_ids:
|
563 |
-
log.warning("NO_DRUG_ID")
|
564 |
-
return []
|
565 |
|
566 |
-
if len(drug_ids) == 1:
|
567 |
-
log.info(f"單藥名模式: {drug_ids[0]}")
|
568 |
-
return drug_ids
|
569 |
-
else:
|
570 |
-
log.info(f"多藥名模式: {drug_ids}")
|
571 |
-
return drug_ids
|
572 |
-
|
573 |
-
# ---------- 新增: 回覆壓縮 ----------
|
574 |
-
def compress_reply(reply: str, max_len: int = 100) -> str:
|
575 |
-
if len(reply) <= max_len:
|
576 |
-
return reply
|
577 |
-
|
578 |
-
# 刪除修飾詞、合併句
|
579 |
-
reply = re.sub(r'(例如|比如|像是|可能會|有時候|通常|一般來說|另外|而且|因此|所以)', '', reply)
|
580 |
-
reply = re.sub(r'\s+', ' ', reply).strip()
|
581 |
-
reply = re.sub(r'(\.|\?|\!)\s*', r'\1 ', reply) # 合併句
|
582 |
-
compressed = reply[:max_len] + '...' if len(reply) > max_len else reply
|
583 |
-
log.info("回覆超長,自動壓縮")
|
584 |
-
return compressed
|
585 |
-
|
586 |
-
# ---------- 新增: 異常處理 ----------
|
587 |
def handle_error(code: str) -> str:
|
588 |
log.error(f"Pipeline error: {code}")
|
589 |
-
return "
|
590 |
-
|
591 |
-
# ---------- 新增: 全鏈路log ----------
|
592 |
-
def log_pipeline(user_id: str, query: str, parsed: Dict, candidates: List, drug_choices: List,
|
593 |
-
retrieval: Dict, sections: List, context: str, prompt: str, reply: str, error_code: str = None):
|
594 |
-
log.info(f"1. user_id: {user_id}, query: {query[:50] + '...' if len(query)>50 else query} (desensitized)")
|
595 |
-
log.info(f"2. parsed={parsed}")
|
596 |
-
log.info(f"3. candidates={candidates}")
|
597 |
-
log.info(f"4. drug_id_pick: {drug_choices} (每個候選的分數、淘汰原因 in find_drug_ids)")
|
598 |
-
log.info(f"5. retrieval: BM25_topN={retrieval.get('bm_top', 0)}, SEM_topN={retrieval.get('sem_top', 0)}, fused_top10={retrieval.get('fused_top10', [])}")
|
599 |
-
log.info(f"6. Injected sections: {sections}")
|
600 |
-
log.info(f"7. contexts_chars: {len(context)}, prompt_chars: {len(prompt)}, tokens: ~{len(prompt)//4}")
|
601 |
-
log.info(f"8. LLM: model={LM_MODEL}, time={retrieval.get('llm_time', 0):.2f}s, truncated={len(reply)>200}, output_chars={len(reply)}")
|
602 |
-
if error_code:
|
603 |
-
log.error(f"error_code={error_code}")
|
604 |
|
|
|
605 |
async def answer_pipeline(query: str, user_id: str) -> str:
|
606 |
log.info("Pipeline start for user_id: %s, query: %s", user_id, query[:50])
|
607 |
if not query or not isinstance(query, str):
|
608 |
return handle_error("INVALID_QUERY")
|
609 |
if not STATE.sentences:
|
610 |
return handle_error("NO_CORPUS")
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
|
658 |
-
|
|
|
659 |
|
660 |
# ---------- LINE 驗簽與回覆 ----------
|
661 |
def verify_line_signature(body_bytes: bytes, signature: str) -> bool:
|
|
|
26 |
|
27 |
# ---------- Imports ----------
|
28 |
import re, hmac, base64, hashlib, pickle, logging, time, json
|
29 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
30 |
|
31 |
import numpy as np
|
32 |
import pandas as pd
|
|
|
146 |
}
|
147 |
|
148 |
IMPORTANT_SECTIONS = ["用法及用量", "病人使用須知", "包裝及儲存", "不良反應", "警語及注意事項"]
|
149 |
+
DOSAGE_FORM_BOOST = 1.2 # NEW: 劑型匹配的權重提升
|
150 |
|
151 |
# ---------- 路徑工具 ----------
|
152 |
def pick_existing_or_tmp(candidates: List[str]) -> str:
|
|
|
247 |
faiss_index: Optional[Any] = None
|
248 |
bm25: Optional[Any] = None
|
249 |
df_csv: Optional[pd.DataFrame] = None
|
250 |
+
user_sessions: Dict[str, Dict[str, Any]] = {}
|
251 |
query_cache: Dict[str, Dict[str, Any]] = {}
|
252 |
|
253 |
STATE = State()
|
|
|
321 |
try:
|
322 |
with open(pkl_path, "rb") as f:
|
323 |
bm = pickle.load(f)
|
324 |
+
# MODIFIED: BM25 has corpus, not corpus_size attribute
|
325 |
+
n_bm = len(bm.corpus) if hasattr(bm, 'corpus') else 0
|
326 |
if n_bm == len(sentences):
|
327 |
log.info("Loaded BM25: %s (n=%d)", pkl_path, n_bm)
|
328 |
return bm
|
|
|
336 |
safe_pickle_dump(bm, pkl_path)
|
337 |
return bm
|
338 |
|
339 |
+
# ---------- 資訊解析與藥名處理 (MODIFIED & NEW) ----------
|
340 |
+
|
341 |
+
def parse_user_message(query: str) -> Dict[str, Any]:
|
342 |
+
"""
|
343 |
+
NEW: 使用更強的正規表示式從使用者問題中提取結構化資訊
|
344 |
+
"""
|
345 |
+
parsed = {
|
346 |
+
"drug_name": "", "strength": "", "dosage_form": "", "question": query, "raw_query": query
|
347 |
+
}
|
348 |
+
|
349 |
+
# Regex to find drug name, strength (e.g., 500mg, 25 mcg/hr), and dosage form
|
350 |
+
# It tries to find a noun-like part followed by numbers/units and another noun-like part
|
351 |
+
drug_pattern = re.compile(
|
352 |
+
r"([\u4e00-\u9fa5a-zA-Z\s\d\.\-]+?)" # Drug name (non-greedy)
|
353 |
+
r"[::\s]*?" # Optional separator
|
354 |
+
r"(\d+(?:\.\d+)?\s*(?:mg|mcg|µg|g|mcg/hr|mg/hr|iu|國際單位))?" # Strength (optional)
|
355 |
+
r"[\s]*?" # Optional space
|
356 |
+
r"([\u4e00-\u9fa5a-zA-Z]+(?:劑|錠|膠囊|糖漿|乳膏|貼片|噴劑|噴霧|吸入劑))?" # Dosage form (optional)
|
357 |
+
, re.I)
|
358 |
+
|
359 |
+
match = drug_pattern.search(query)
|
360 |
+
question_part = query
|
361 |
+
|
362 |
+
if match:
|
363 |
+
drug_name = (match.group(1) or "").strip()
|
364 |
+
strength = (match.group(2) or "").strip()
|
365 |
+
dosage_form = (match.group(3) or "").strip()
|
366 |
+
|
367 |
+
# Clean up drug name from common question words if they are at the end
|
368 |
+
drug_name = re.sub(r"(怎麼用|怎麼吃|副作用|的用法)$", "", drug_name).strip()
|
369 |
+
|
370 |
+
parsed.update({"drug_name": drug_name, "strength": strength, "dosage_form": dosage_form})
|
371 |
+
|
372 |
+
# The rest of the query is the question
|
373 |
+
question_part = query[match.end():].strip()
|
374 |
+
if not question_part:
|
375 |
+
question_part = query # fallback if parsing consumes whole string
|
376 |
+
|
377 |
+
parsed["question"] = question_part
|
378 |
+
log.info("Parsed user message: %s", parsed)
|
379 |
+
return parsed
|
380 |
+
|
381 |
+
|
382 |
+
def find_drug_candidates(parsed_info: Dict[str, Any], df: pd.DataFrame) -> List[Dict[str, Any]]:
|
383 |
+
"""
|
384 |
+
MODIFIED: Find drug candidates based on parsed info and return a ranked list of dicts.
|
385 |
+
"""
|
386 |
+
if df is None or df.empty or not parsed_info.get("drug_name"):
|
387 |
+
return []
|
388 |
+
|
389 |
+
query_drug_name = parsed_info["drug_name"].lower()
|
390 |
+
query_dosage_form = parsed_info["dosage_form"]
|
391 |
+
|
392 |
+
# Use jieba to get core drug name
|
393 |
+
tokens = tokenize_zh(query_drug_name)
|
394 |
+
candidates = [t.lower() for t in tokens if len(t) > 1 and t not in DRUG_STOPWORDS]
|
395 |
+
candidates.append(query_drug_name) # also search the raw name
|
396 |
+
candidates = list(set([DRUG_NAME_MAPPING.get(c, c) for c in candidates]))
|
397 |
+
|
398 |
+
drug_bucket = []
|
399 |
+
|
400 |
+
# Get all unique drug_ids and their names to avoid iterating the whole dataframe repeatedly
|
401 |
+
unique_drugs = df.drop_duplicates(subset=['drug_id'])
|
402 |
+
|
403 |
for cand in candidates:
|
404 |
+
for _, row in unique_drugs.iterrows():
|
|
|
|
|
|
|
|
|
|
|
405 |
name_joined = f"{(row.get('drug_name_zh') or '').lower()} {(row.get('drug_name_en') or '').lower()} {(row.get('drug_name_norm') or '').lower()}".strip()
|
406 |
+
|
407 |
+
if not fuzz:
|
408 |
+
raw_score = 100 if cand in name_joined else 0
|
409 |
+
else:
|
410 |
+
raw_score = max(
|
411 |
fuzz.token_set_ratio(cand, name_joined),
|
412 |
fuzz.partial_ratio(cand, name_joined)
|
413 |
)
|
414 |
+
|
415 |
+
if raw_score >= 85:
|
416 |
+
# Boost score for English name match, length, and dosage form match
|
417 |
+
score = raw_score
|
418 |
+
if re.search(r'[a-zA-Z]', cand):
|
419 |
+
score *= 1.2
|
420 |
+
if query_dosage_form and query_dosage_form in name_joined:
|
421 |
+
score *= 1.1
|
422 |
+
|
423 |
+
drug_bucket.append({
|
424 |
+
"score": score,
|
425 |
+
"drug_id": row["drug_id"],
|
426 |
+
"drug_name_zh": row.get('drug_name_zh'),
|
427 |
+
"drug_name_en": row.get('drug_name_en'),
|
428 |
+
"matched_term": cand
|
429 |
+
})
|
430 |
+
|
431 |
+
if not drug_bucket:
|
432 |
+
return []
|
433 |
+
|
434 |
+
# Sort and deduplicate results
|
435 |
+
sorted_bucket = sorted(drug_bucket, key=lambda x: x['score'], reverse=True)
|
436 |
+
seen_ids = set()
|
437 |
+
unique_top = []
|
438 |
+
for item in sorted_bucket:
|
439 |
+
if item['drug_id'] not in seen_ids:
|
440 |
+
unique_top.append(item)
|
441 |
+
seen_ids.add(item['drug_id'])
|
442 |
+
|
443 |
+
log.info(f"Found drug candidates: {unique_top[:5]}")
|
444 |
+
return unique_top[:5] # Return top 5 candidates
|
445 |
+
|
446 |
+
def select_best_drug_candidate(candidates: List[Dict[str, Any]]) -> Union[Dict[str, Any], str, None]:
|
447 |
+
"""
|
448 |
+
NEW: Logic to decide if we have a clear winner or need clarification.
|
449 |
+
"""
|
450 |
+
if not candidates:
|
451 |
+
return None
|
452 |
+
|
453 |
+
# Case 1: The top candidate has a very high score and is significantly better than the second
|
454 |
+
if len(candidates) == 1 and candidates[0]['score'] >= 90:
|
455 |
+
return candidates[0]
|
456 |
+
|
457 |
+
if len(candidates) > 1:
|
458 |
+
top_score = candidates[0]['score']
|
459 |
+
second_score = candidates[1]['score']
|
460 |
+
if top_score >= 95 and (top_score - second_score) > 10:
|
461 |
+
return candidates[0]
|
462 |
+
|
463 |
+
# Case 2: Multiple candidates are close in score, ask for clarification
|
464 |
+
if len(candidates) > 1 and (candidates[0]['score'] - candidates[1]['score']) <= 10:
|
465 |
+
options = [f"「{c.get('drug_name_zh') or c.get('drug_name_en')}」" for c in candidates[:3]]
|
466 |
+
return f"請問您指的是以下哪一種藥物?\n- " + "\n- ".join(options)
|
467 |
+
|
468 |
+
# Case 3: One candidate, but score is not high enough to be confident
|
469 |
+
if candidates[0]['score'] < 90:
|
470 |
+
return None
|
471 |
+
|
472 |
+
return candidates[0]
|
473 |
+
|
474 |
|
475 |
# ---------- 意圖偵測 ----------
|
476 |
def detect_intent(query: str) -> List[str]:
|
|
|
503 |
log.warning("Rerank failed: %s", e)
|
504 |
return [{"idx": i, "score": fused} for i, fused, _, _ in sorted(candidates, key=lambda x: -x[1])[:top_k]] # fallback
|
505 |
|
506 |
+
# MODIFIED: fuse_and_select now accepts parsed_info to boost scores based on dosage form
|
507 |
+
def fuse_and_select(query: str, sentences: List[str], meta: List[Dict[str, Any]], bm25: Optional[Any], index: Optional[Any], emb_model: Optional[Any], reranker: Optional[Any], top_k: int = 10, drug_id: str = None, parsed_info: Dict[str, Any] = None) -> List[int]:
|
508 |
+
clean_query = parsed_info.get("question", query).strip().lower()
|
509 |
+
cache_key = clean_query + str(drug_id)
|
510 |
+
if cache_key in STATE.query_cache and time.time() - STATE.query_cache[cache_key]['time'] < 180:
|
511 |
+
log.info("Cache hit for query: %s", clean_query[:50])
|
512 |
return STATE.query_cache[cache_key]['idxs']
|
513 |
+
|
514 |
+
log.info("Searching for drug_id: %s with query: %s", drug_id, clean_query[:50])
|
515 |
+
if not drug_id:
|
516 |
+
log.warning("No drug_id provided; falling back to full corpus search.")
|
517 |
+
|
518 |
+
tokenized_query = tokenize_zh(clean_query)
|
519 |
+
scores = {}
|
520 |
+
|
521 |
+
# BM25 lexical search
|
522 |
if bm25:
|
523 |
+
bm_scores = bm25.get_scores(tokenized_query)
|
524 |
+
bm_scores_np = np.array(bm_scores)
|
525 |
+
if np.max(bm_scores_np) > np.min(bm_scores_np):
|
526 |
+
scores_norm = (bm_scores_np - np.min(bm_scores_np)) / (np.max(bm_scores_np) - np.min(bm_scores_np))
|
527 |
else:
|
528 |
+
scores_norm = bm_scores_np
|
529 |
+
|
530 |
+
for i, s_norm in enumerate(scores_norm):
|
531 |
+
if 0 <= i < len(meta) and (not drug_id or meta[i].get("drug_id") == drug_id):
|
532 |
+
scores[i] = scores.get(i, 0.0) + BM25_WEIGHT * s_norm
|
533 |
+
|
534 |
+
# FAISS semantic search
|
|
|
535 |
if emb_model and index:
|
536 |
+
q_emb = emb_model.encode([clean_query], normalize_embeddings=True).astype(np.float32)
|
537 |
_, idxs = index.search(q_emb, top_k * 8)
|
538 |
+
for rank, i in enumerate(idxs[0].tolist()):
|
539 |
+
if 0 <= i < len(meta) and (not drug_id or meta[i].get("drug_id") == drug_id):
|
540 |
+
scores[i] = scores.get(i, 0.0) + SEM_WEIGHT * (1.0 / (1 + rank))
|
541 |
+
|
542 |
+
# Apply boosts
|
543 |
+
query_dosage_form = parsed_info.get("dosage_form") if parsed_info else ""
|
544 |
+
for i in list(scores.keys()): # Iterate over a copy of keys
|
545 |
+
meta_item = meta[i]
|
546 |
+
|
547 |
+
# Section weight boost
|
548 |
+
sec = meta_item.get("section", "其他")
|
549 |
+
scores[i] *= SECTION_WEIGHTS.get(sec, 1.0)
|
550 |
+
|
551 |
+
# NEW: Dosage form boost
|
552 |
+
if query_dosage_form and query_dosage_form in sentences[i]:
|
553 |
+
scores[i] *= DOSAGE_FORM_BOOST
|
554 |
+
|
555 |
+
# Inject important sections if they are missing
|
556 |
+
for sec in IMPORTANT_SECTIONS:
|
557 |
+
sec_idx = next((i for i, m in enumerate(meta) if (m.get("drug_id") == drug_id) and m.get("section") == sec), None)
|
558 |
+
if sec_idx is not None and sec_idx not in scores:
|
559 |
+
scores[sec_idx] = 1.0 # Give it a moderate score to ensure inclusion before reranking
|
560 |
+
|
561 |
+
# Prepare for reranking
|
562 |
+
candidates = [(i, sc, 0.0, 0.0) for i, sc in scores.items()]
|
563 |
+
|
564 |
+
reranked = rerank_results(clean_query, candidates, sentences, reranker, top_k, RERANK_THRESHOLD)
|
565 |
idxs = [r["idx"] for r in reranked]
|
566 |
+
|
567 |
STATE.query_cache[cache_key] = {'idxs': idxs, 'time': time.time()}
|
568 |
return idxs
|
569 |
|
570 |
+
|
571 |
def build_context(idxs: List[int], sentences: List[str], meta: List[Dict[str, Any]]) -> str:
|
572 |
ctx_lines, total_len, seen = [], 0, set()
|
573 |
for i in idxs:
|
|
|
575 |
text = sentences[i]
|
576 |
if text in seen: continue
|
577 |
chunk_id = meta[i].get("chunk_id", "None")
|
578 |
+
section = meta[i].get("section", "未知章節")
|
579 |
+
line = f"[{section}]: {text}" # MODIFIED: Show section name for better context
|
580 |
if total_len + len(line) > MAX_CONTEXT_CHARS: break
|
581 |
ctx_lines.append(line)
|
582 |
total_len += len(line) + 1
|
583 |
seen.add(text)
|
584 |
+
return "\n".join(ctx_lines) or "[未知章節]: 沒有找到相關資料,請諮詢醫師或藥師。"
|
585 |
+
|
586 |
+
# MODIFIED: Prompt now includes structured patient info
|
587 |
+
def build_prompt(parsed_info: Dict[str, Any], contexts: str, drug_choice: Dict[str, Any]) -> str:
|
588 |
+
|
589 |
+
patient_context_parts = []
|
590 |
+
if parsed_info.get('strength'):
|
591 |
+
patient_context_parts.append(f"劑量: {parsed_info['strength']}")
|
592 |
+
if parsed_info.get('dosage_form'):
|
593 |
+
patient_context_parts.append(f"劑型: {parsed_info['dosage_form']}")
|
594 |
+
|
595 |
+
patient_context_str = " ".join(patient_context_parts)
|
596 |
+
if not patient_context_str:
|
597 |
+
patient_context_str = "未提供"
|
598 |
+
|
|
|
|
|
|
|
599 |
return (
|
600 |
+
"你是一位專業、有同理心的藥師。請根據提供的「參考片段」,並考量「病患已知資訊」,簡潔地回答使用者的「問題」。\n"
|
601 |
+
"---限制---\n"
|
602 |
+
"- 絕對忠於「參考片段」,不可捏造或過度推論。你的知識僅限於提供的片段。\n"
|
603 |
+
"- 回覆少於 120 字,並使用繁體中文條列式 2-4 點說明。\n"
|
604 |
+
"- 語氣親切、精簡、專業。\n"
|
605 |
+
"- 若片段中無足夠資訊回答,必須回覆:「根據提供的資料,我無法找到關於您問題的明確答案,建議您諮詢醫師或藥師。」\n"
|
606 |
+
"---輸入資訊---\n"
|
607 |
+
f"藥物名稱: {drug_choice.get('drug_name_zh') or drug_choice.get('drug_name_en')}\n"
|
608 |
+
f"病患已知資訊: {patient_context_str}\n"
|
609 |
+
f"問題: {parsed_info.get('raw_query')}\n\n"
|
610 |
+
f"參考片段:\n{contexts}\n"
|
611 |
+
"---你的回答---"
|
612 |
)
|
613 |
|
614 |
def call_llm(prompt: str, max_tokens: int = 2048) -> Optional[str]:
|
|
|
625 |
t0 = time.time()
|
626 |
resp = client.chat.completions.create(
|
627 |
model=LM_MODEL,
|
628 |
+
messages=[{"role": "user", "content": prompt}], # MODIFIED: Simplified to user role only, as system prompt is now part of the main prompt
|
629 |
+
temperature=0.1, # MODIFIED: slightly lower temperature for more deterministic answers
|
630 |
+
timeout=15, # MODIFIED: slightly longer timeout
|
|
|
|
|
|
|
631 |
max_tokens=max_tokens,
|
632 |
)
|
633 |
used = time.time() - t0
|
|
|
637 |
log.warning("LLM 失敗:%s", e)
|
638 |
return None
|
639 |
|
640 |
+
def make_clarify_message() -> str:
|
641 |
+
# MODIFIED: More generic clarification message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
642 |
msg = (
|
643 |
+
"我需要更多資訊才能準確回答,請您提供:\n"
|
644 |
+
"1. 完整的藥物名稱\n"
|
645 |
+
"2. 劑量和劑型(例如:普拿疼 500mg 錠劑)\n"
|
646 |
+
"3. 您的具體問題\n\n"
|
647 |
+
f"{DISCLAIMER}"
|
648 |
)
|
649 |
+
return msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
650 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
651 |
def handle_error(code: str) -> str:
|
652 |
log.error(f"Pipeline error: {code}")
|
653 |
+
return f"抱歉,系統暫時無法回覆 ({code})。請諮詢醫師或藥師。{DISCLAIMER}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
654 |
|
655 |
+
# ---------- 主流程 (MODIFIED) ----------
|
656 |
async def answer_pipeline(query: str, user_id: str) -> str:
|
657 |
log.info("Pipeline start for user_id: %s, query: %s", user_id, query[:50])
|
658 |
if not query or not isinstance(query, str):
|
659 |
return handle_error("INVALID_QUERY")
|
660 |
if not STATE.sentences:
|
661 |
return handle_error("NO_CORPUS")
|
662 |
+
|
663 |
+
# 1. 解析使用者輸入
|
664 |
+
parsed_info = parse_user_message(query)
|
665 |
+
|
666 |
+
# 2. 尋找藥物候選
|
667 |
+
drug_candidates = find_drug_candidates(parsed_info, STATE.df_csv)
|
668 |
+
|
669 |
+
# 3. 選擇最佳藥物或要求澄清
|
670 |
+
drug_choice_or_clarification = select_best_drug_candidate(drug_candidates)
|
671 |
+
|
672 |
+
if drug_choice_or_clarification is None:
|
673 |
+
log.warning("No confident drug match found.")
|
674 |
+
return make_clarify_message()
|
675 |
+
|
676 |
+
if isinstance(drug_choice_or_clarification, str): # It's a clarification message
|
677 |
+
log.info("Requesting clarification from user.")
|
678 |
+
return drug_choice_or_clarification + f"\n\n{DISCLAIMER}"
|
679 |
+
|
680 |
+
drug_choice = drug_choice_or_clarification
|
681 |
+
log.info("Selected drug: %s", drug_choice)
|
682 |
+
|
683 |
+
# 4. 檢索相關內文
|
684 |
+
idxs = fuse_and_select(
|
685 |
+
query=parsed_info["raw_query"],
|
686 |
+
sentences=STATE.sentences,
|
687 |
+
meta=STATE.meta,
|
688 |
+
bm25=STATE.bm25,
|
689 |
+
index=STATE.faiss_index,
|
690 |
+
emb_model=STATE.emb_model,
|
691 |
+
reranker=STATE.reranker_model,
|
692 |
+
top_k=TOP_K_SENTENCES,
|
693 |
+
drug_id=drug_choice['drug_id'],
|
694 |
+
parsed_info=parsed_info
|
695 |
+
)
|
696 |
+
|
697 |
+
if not idxs:
|
698 |
+
return handle_error("NO_CONTEXT")
|
699 |
+
|
700 |
+
# 5. 建立上下文和 Prompt
|
701 |
+
context = build_context(idxs, STATE.sentences, STATE.meta)
|
702 |
+
prompt = build_prompt(parsed_info, context, drug_choice)
|
703 |
+
log.info("Generated Prompt:\n%s", prompt)
|
704 |
+
|
705 |
+
# 6. 呼叫 LLM 生成答案
|
706 |
+
answer = call_llm(prompt)
|
707 |
+
if not answer:
|
708 |
+
return handle_error("LLM_ERROR")
|
709 |
+
|
710 |
+
return f"{answer}\n\n{DISCLAIMER}"
|
711 |
|
712 |
# ---------- LINE 驗簽與回覆 ----------
|
713 |
def verify_line_signature(body_bytes: bytes, signature: str) -> bool:
|