Spaces:
Sleeping
Sleeping
Song
committed on
Commit
·
5f83bd3
1
Parent(s):
f69df5b
hi
Browse files
app.py
CHANGED
@@ -43,7 +43,24 @@ from openai import OpenAI
|
|
43 |
from tenacity import retry, stop_after_attempt, wait_fixed
|
44 |
import requests
|
45 |
|
|
|
|
|
|
|
46 |
# ==== CONFIG (從環境變數載入,或使用預設值) ====
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
|
48 |
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
|
49 |
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
|
@@ -51,7 +68,7 @@ BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
|
|
51 |
|
52 |
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 15))
|
53 |
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
|
54 |
-
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
|
55 |
|
56 |
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
|
57 |
RERANKER_MODEL = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
|
@@ -81,7 +98,6 @@ DRUG_NAME_MAPPING = {
|
|
81 |
}
|
82 |
DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務務必諮詢您的醫師或藥師。"
|
83 |
|
84 |
-
# [NEW] 集中管理 Prompt 模板
|
85 |
PROMPT_TEMPLATES = {
|
86 |
"analyze_query": """
|
87 |
請分析以下使用者問題,並完成以下兩個任務:
|
@@ -126,7 +142,6 @@ PROMPT_TEMPLATES = {
|
|
126 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
127 |
log = logging.getLogger(__name__)
|
128 |
|
129 |
-
# [NEW] 使用 Dataclasses 提升程式碼可讀性
|
130 |
@dataclass
|
131 |
class FusedCandidate:
|
132 |
idx: int
|
@@ -145,7 +160,9 @@ class RerankResult:
|
|
145 |
class RagPipeline:
|
146 |
def __init__(self, config):
|
147 |
self.config = config
|
148 |
-
self.state = type('state', (), {})()
|
|
|
|
|
149 |
self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
|
150 |
self.embedding_model = self._load_model(SentenceTransformer, EMBEDDING_MODEL, "embedding")
|
151 |
self.reranker = self._load_model(CrossEncoder, RERANKER_MODEL, "reranker")
|
@@ -164,24 +181,33 @@ class RagPipeline:
|
|
164 |
|
165 |
def load_data(self):
|
166 |
log.info("開始載入資料與模型...")
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
-
log.info("載入 FAISS 索引與句子資料...")
|
175 |
-
self.state.index = faiss.read_index(FAISS_INDEX)
|
176 |
-
with open(SENTENCES_PKL, "rb") as f:
|
177 |
-
data = pickle.load(f)
|
178 |
-
self.state.sentences = data["sentences"]
|
179 |
-
self.state.meta = data["meta"]
|
180 |
-
|
181 |
-
log.info("載入 BM25 索引...")
|
182 |
-
with open(BM25_PKL, "rb") as f:
|
183 |
-
self.state.bm25 = pickle.load(f)
|
184 |
-
|
185 |
log.info("所有模型與資料載入完���。")
|
186 |
|
187 |
def _load_drug_name_vocabulary(self):
|
@@ -248,15 +274,24 @@ class RagPipeline:
|
|
248 |
@lru_cache(maxsize=128)
|
249 |
def _find_drug_ids_from_name(self, query: str) -> List[str]:
|
250 |
candidates = extract_drug_candidates_from_query(query.lower(), self.drug_vocab)
|
251 |
-
expanded = {c.lower().replace(" ", "") for c in candidates} | set(candidates)
|
252 |
|
253 |
drug_ids = set()
|
254 |
-
for alias in
|
|
|
|
|
255 |
for drug_name_norm, ids in self.drug_name_to_ids.items():
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
drug_ids.update(ids)
|
258 |
return list(drug_ids)
|
259 |
|
|
|
260 |
def _analyze_query(self, query: str) -> Dict[str, Any]:
|
261 |
prompt = PROMPT_TEMPLATES["analyze_query"].format(
|
262 |
options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
|
@@ -280,48 +315,52 @@ class RagPipeline:
|
|
280 |
distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
|
281 |
|
282 |
tokenized_query = list(jieba.cut(expanded_q))
|
283 |
-
# [MODIFIED] 使用 get_top_n 提升 BM25 效率
|
284 |
-
bm25_results = self.state.bm25.get_top_n(tokenized_query, self.state.sentences, n=PRE_RERANK_K)
|
285 |
|
286 |
-
|
287 |
-
|
|
|
|
|
|
|
288 |
candidate_scores: Dict[int, Dict[str, float]] = {}
|
|
|
|
|
|
|
|
|
|
|
289 |
for i, dist in zip(sim_indices[0], distances[0]):
|
290 |
if i in relevant_indices:
|
291 |
-
|
|
|
292 |
|
293 |
for i, score in doc_to_bm25_score.items():
|
294 |
if i in relevant_indices:
|
295 |
-
|
296 |
-
candidate_scores[i]["bm"] = score
|
297 |
-
else:
|
298 |
-
candidate_scores[i] = {"sem": 0.0, "bm": score}
|
299 |
|
300 |
if not candidate_scores: continue
|
301 |
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
|
|
|
|
|
|
|
|
|
307 |
sem_n, bm_n = norm(sem_scores), norm(bm_scores)
|
308 |
|
309 |
-
for idx,
|
310 |
fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
|
311 |
|
312 |
-
if
|
313 |
-
all_fused_candidates[
|
314 |
-
idx=
|
315 |
)
|
316 |
|
317 |
return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
|
318 |
|
319 |
-
|
320 |
def _expand_query_with_llm(self, query: str, intents: tuple) -> str:
|
321 |
-
"""
|
322 |
-
Expands a query using the LLM.
|
323 |
-
[CORRECTED] Now safely handles failed or empty LLM responses.
|
324 |
-
"""
|
325 |
if not intents:
|
326 |
return query
|
327 |
|
@@ -329,14 +368,13 @@ class RagPipeline:
|
|
329 |
|
330 |
try:
|
331 |
expanded_query = self._llm_call([{"role": "user", "content": prompt}])
|
332 |
-
# Ensure the result is a non-empty string before returning
|
333 |
if expanded_query and expanded_query.strip():
|
334 |
return expanded_query
|
335 |
else:
|
336 |
log.warning(f"Query expansion for '{query}' returned an empty result. Using original query.")
|
337 |
return query
|
338 |
except Exception as e:
|
339 |
-
log.error(f"Query expansion for '{query}' failed
|
340 |
return query
|
341 |
|
342 |
def _rerank_with_crossencoder(self, query: str, candidates: List[FusedCandidate]) -> List[RerankResult]:
|
@@ -369,7 +407,13 @@ class RagPipeline:
|
|
369 |
additional_instruction=add_instr, context=context, query=query
|
370 |
)
|
371 |
|
|
|
372 |
def _safe_json_parse(self, json_str: str, default: Any = None) -> Any:
|
|
|
|
|
|
|
|
|
|
|
373 |
try:
|
374 |
return json.loads(json_str)
|
375 |
except json.JSONDecodeError:
|
@@ -380,9 +424,10 @@ class RagPipeline:
|
|
380 |
app = FastAPI()
|
381 |
rag_pipeline: Optional[RagPipeline] = None
|
382 |
|
|
|
383 |
class AppConfig:
|
384 |
-
CHANNEL_ACCESS_TOKEN =
|
385 |
-
CHANNEL_SECRET =
|
386 |
|
387 |
@app.on_event("startup")
|
388 |
async def startup_event():
|
@@ -393,40 +438,53 @@ async def startup_event():
|
|
393 |
|
394 |
@app.post("/webhook")
|
395 |
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|
|
|
396 |
signature = request.headers.get("X-Line-Signature")
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
body = await request.body()
|
398 |
-
|
399 |
-
|
400 |
-
base64.b64encode(
|
401 |
-
|
402 |
-
|
|
|
|
|
|
|
403 |
raise HTTPException(status_code=403, detail="Invalid signature")
|
404 |
|
405 |
data = json.loads(body.decode('utf-8'))
|
406 |
for event in data.get("events", []):
|
407 |
if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
|
408 |
-
reply_token = event
|
409 |
-
user_text = event
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
|
|
|
|
|
|
|
|
|
|
414 |
|
415 |
return Response(status_code=status.HTTP_200_OK)
|
416 |
|
417 |
-
|
418 |
-
|
419 |
-
line_push(user_id, "收到您的問題,正在查詢資料庫,請稍候...")
|
420 |
-
|
421 |
try:
|
422 |
if rag_pipeline:
|
423 |
answer = rag_pipeline.answer_question(user_text)
|
424 |
-
line_reply(reply_token, answer)
|
425 |
else:
|
426 |
-
|
|
|
427 |
except Exception as e:
|
428 |
log.error(f"背景處理 user_id={user_id} 發生錯誤: {e}", exc_info=True)
|
429 |
-
|
430 |
|
431 |
def line_api_call(endpoint: str, data: Dict):
|
432 |
headers = {
|
@@ -447,15 +505,17 @@ def line_push(user_id: str, text: str):
|
|
447 |
messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
|
448 |
line_api_call("push", {"to": user_id, "messages": messages})
|
449 |
|
450 |
-
#
|
451 |
def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
|
452 |
candidates = set()
|
453 |
-
|
|
|
|
|
454 |
for word in words:
|
455 |
if word in drug_vocab["en"]:
|
456 |
candidates.add(word)
|
457 |
|
458 |
-
for token in jieba.cut(
|
459 |
if token in drug_vocab["zh"]:
|
460 |
candidates.add(token)
|
461 |
|
|
|
43 |
from tenacity import retry, stop_after_attempt, wait_fixed
|
44 |
import requests
|
45 |
|
46 |
+
# [MODIFIED] 限制 PyTorch 執行緒數量,避免在 CPU 環境下過度佔用資源
|
47 |
+
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "1")))
|
48 |
+
|
49 |
# ==== CONFIG (從環境變數載入,或使用預設值) ====
|
50 |
+
# [MODIFIED] 新增環境變數健檢函式
|
51 |
+
def _require_env(var: str) -> str:
    """Return the value of a required environment variable.

    Args:
        var: Name of the environment variable to read.

    Returns:
        The variable's value exactly as stored in the environment.

    Raises:
        RuntimeError: If the variable is unset, empty, or whitespace-only.
    """
    v = os.getenv(var)
    # A whitespace-only value is as unusable as a missing one, so reject it
    # here instead of letting it surface later as an opaque auth failure.
    if v is None or not v.strip():
        raise RuntimeError(f"FATAL: Missing required environment variable: {var}")
    return v
|
56 |
+
|
57 |
+
# [MODIFIED] 檢查 LLM 相關環境變數
|
58 |
+
def _require_llm_config() -> None:
    """Verify that every LLM-related environment variable is set.

    All variables are checked before failing so that a single error names
    every missing one, rather than forcing a fix-and-restart cycle per
    variable.

    Raises:
        RuntimeError: If one or more of the required variables is missing.
    """
    missing = []
    for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
        try:
            _require_env(k)
        except RuntimeError:
            missing.append(k)
    if missing:
        raise RuntimeError(
            f"FATAL: Missing required environment variable: {', '.join(missing)}"
        )
|
61 |
+
|
62 |
+
_require_llm_config()
|
63 |
+
|
64 |
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
|
65 |
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
|
66 |
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
|
|
|
68 |
|
69 |
TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 15))
|
70 |
PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
|
71 |
+
MAX_RERANK_CANDIDATES = int(os.getenv("MAX_RERANK_CANDIDATES", 30))
|
72 |
|
73 |
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "DMetaSoul/Dmeta-embedding-zh")
|
74 |
RERANKER_MODEL = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-v2-m3")
|
|
|
98 |
}
|
99 |
# User-facing disclaimer text (fixed the duplicated character in "請務必").
DISCLAIMER = "本資訊僅供參考,若您對藥物使用有任何疑問,請務必諮詢您的醫師或藥師。"
|
100 |
|
|
|
101 |
PROMPT_TEMPLATES = {
|
102 |
"analyze_query": """
|
103 |
請分析以下使用者問題,並完成以下兩個任務:
|
|
|
142 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
143 |
log = logging.getLogger(__name__)
|
144 |
|
|
|
145 |
@dataclass
|
146 |
class FusedCandidate:
|
147 |
idx: int
|
|
|
160 |
class RagPipeline:
|
161 |
def __init__(self, config):
|
162 |
self.config = config
|
163 |
+
self.state = type('state', (), {})()
|
164 |
+
if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
|
165 |
+
raise ValueError("LLM API Key or Base URL is not configured.")
|
166 |
self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
|
167 |
self.embedding_model = self._load_model(SentenceTransformer, EMBEDDING_MODEL, "embedding")
|
168 |
self.reranker = self._load_model(CrossEncoder, RERANKER_MODEL, "reranker")
|
|
|
181 |
|
182 |
def load_data(self):
    """Load the drug CSV, the FAISS index, the sentence store and the BM25 index.

    Populates ``self.df_csv``, ``self.drug_name_to_ids``, and the ``index``,
    ``sentences``, ``meta`` and ``bm25`` attributes of ``self.state``.

    Raises:
        RuntimeError: If a file is missing, or the CSV / sentence pickle lacks
            a required column or key. The original error is chained as the
            cause.
    """
    log.info("開始載入資料與模型...")
    try:
        # Read every column as text and replace missing cells with ''.
        self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna('')
        # Fail early with a clear message if a required column is absent.
        for col in ("drug_name_norm", "drug_id"):
            if col not in self.df_csv.columns:
                raise KeyError(f"CSV 檔案 '{CSV_PATH}' 中缺少必要欄位: {col}")

        # Lookup key: lower-cased name with every character that is neither a
        # word character nor whitespace removed, then trimmed.
        self.df_csv['drug_name_norm_normalized'] = (
            self.df_csv['drug_name_norm']
            .str.lower()
            .str.replace(r'[^\w\s]', '', regex=True)
            .str.strip()
        )
        # Map each normalized name to the list of distinct drug ids under it.
        self.drug_name_to_ids = (
            self.df_csv.groupby('drug_name_norm_normalized')['drug_id']
            .unique()
            .apply(list)
            .to_dict()
        )
        self._load_drug_name_vocabulary()

        log.info("載入 FAISS 索引與句子資料...")
        self.state.index = faiss.read_index(FAISS_INDEX)
        # SECURITY: pickle.load can execute arbitrary code. These files must be
        # trusted build artifacts, never user-supplied data.
        with open(SENTENCES_PKL, "rb") as f:
            data = pickle.load(f)
        self.state.sentences = data["sentences"]
        self.state.meta = data["meta"]

        log.info("載入 BM25 索引...")
        with open(BM25_PKL, "rb") as f:
            self.state.bm25 = pickle.load(f)
    except (FileNotFoundError, KeyError) as e:
        log.exception(f"資料或索引檔案載入失敗: {e}")
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"資料初始化失敗,請檢查檔案路徑與內容: {e}") from e

    log.info("所有模型與資料載入完成。")
|
212 |
|
213 |
def _load_drug_name_vocabulary(self):
|
|
|
274 |
@lru_cache(maxsize=128)
def _find_drug_ids_from_name(self, query: str) -> List[str]:
    """Return the drug ids whose normalized name matches a drug mentioned in *query*.

    Candidate names are extracted from the lower-cased query with
    ``extract_drug_candidates_from_query``. A candidate containing a CJK
    ideograph (U+4E00–U+9FFF) matches by substring; any other candidate must
    match as a whole token so that a short name does not hit inside a longer
    word.

    Args:
        query: Raw user query text.

    Returns:
        The matching drug ids, in no particular order; empty if none match.

    NOTE(review): results are memoized by ``lru_cache``, so the same list
    object is handed to every caller with the same arguments — callers must
    not mutate it. The cache also holds a reference to ``self``.
    """
    candidates = extract_drug_candidates_from_query(query.lower(), self.drug_vocab)

    drug_ids = set()
    for alias in candidates:
        if re.search(r'[\u4e00-\u9fff]', alias):
            # Chinese names carry no word delimiters: substring containment.
            def matches(name, alias=alias):
                return alias in name
        else:
            # Compile once per alias instead of once per (alias, name) pair.
            # Lookarounds rather than \b: they are equivalent when the alias
            # ends in a word character, and still work when it ends in
            # punctuation such as '+' or '.', where a trailing \b cannot
            # match at the end of the name.
            matches = re.compile(rf"(?<!\w){re.escape(alias)}(?!\w)").search

        for drug_name_norm, ids in self.drug_name_to_ids.items():
            if matches(drug_name_norm):
                drug_ids.update(ids)
    return list(drug_ids)
|
293 |
|
294 |
+
|
295 |
def _analyze_query(self, query: str) -> Dict[str, Any]:
|
296 |
prompt = PROMPT_TEMPLATES["analyze_query"].format(
|
297 |
options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
|
|
|
315 |
distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
|
316 |
|
317 |
tokenized_query = list(jieba.cut(expanded_q))
|
|
|
|
|
318 |
|
319 |
+
# [MODIFIED] 改為獲取真實 BM25 分數,而非使用排名
|
320 |
+
bm25_scores = self.state.bm25.get_scores(tokenized_query)
|
321 |
+
top_bm25_indices = np.argsort(bm25_scores)[::-1][:PRE_RERANK_K]
|
322 |
+
doc_to_bm25_score = {int(i): float(bm25_scores[i]) for i in top_bm25_indices}
|
323 |
+
|
324 |
candidate_scores: Dict[int, Dict[str, float]] = {}
|
325 |
+
|
326 |
+
# [MODIFIED] FAISS L2 距離轉為相似度 (分數越高越好)
|
327 |
+
def dist_to_sim(d: float) -> float:
|
328 |
+
return 1.0 / (1.0 + d)
|
329 |
+
|
330 |
for i, dist in zip(sim_indices[0], distances[0]):
|
331 |
if i in relevant_indices:
|
332 |
+
similarity = dist_to_sim(dist)
|
333 |
+
candidate_scores[int(i)] = {"sem": float(similarity), "bm": 0.0}
|
334 |
|
335 |
for i, score in doc_to_bm25_score.items():
|
336 |
if i in relevant_indices:
|
337 |
+
candidate_scores.setdefault(i, {"sem": 0.0, "bm": 0.0})["bm"] = score
|
|
|
|
|
|
|
338 |
|
339 |
if not candidate_scores: continue
|
340 |
|
341 |
+
# [MODIFIED] 使用固定的鍵順序來確保分數對齊
|
342 |
+
keys = list(candidate_scores.keys())
|
343 |
+
sem_scores = np.array([candidate_scores[k]['sem'] for k in keys])
|
344 |
+
bm_scores = np.array([candidate_scores[k]['bm'] for k in keys])
|
345 |
|
346 |
+
def norm(x):
|
347 |
+
rng = x.max() - x.min()
|
348 |
+
return (x - x.min()) / (rng + 1e-8) if rng > 0 else np.zeros_like(x)
|
349 |
+
|
350 |
sem_n, bm_n = norm(sem_scores), norm(bm_scores)
|
351 |
|
352 |
+
for idx, k in enumerate(keys):
|
353 |
fused_score = sem_n[idx] * 0.6 + bm_n[idx] * 0.4
|
354 |
|
355 |
+
if k not in all_fused_candidates or fused_score > all_fused_candidates[k].fused_score:
|
356 |
+
all_fused_candidates[k] = FusedCandidate(
|
357 |
+
idx=k, fused_score=fused_score, sem_score=sem_scores[idx], bm_score=bm_scores[idx]
|
358 |
)
|
359 |
|
360 |
return sorted(all_fused_candidates.values(), key=lambda x: x.fused_score, reverse=True)
|
361 |
|
362 |
+
# [MODIFIED] 移除 lru_cache,因對多變的長查詢效果不佳
|
363 |
def _expand_query_with_llm(self, query: str, intents: tuple) -> str:
|
|
|
|
|
|
|
|
|
364 |
if not intents:
|
365 |
return query
|
366 |
|
|
|
368 |
|
369 |
try:
|
370 |
expanded_query = self._llm_call([{"role": "user", "content": prompt}])
|
|
|
371 |
if expanded_query and expanded_query.strip():
|
372 |
return expanded_query
|
373 |
else:
|
374 |
log.warning(f"Query expansion for '{query}' returned an empty result. Using original query.")
|
375 |
return query
|
376 |
except Exception as e:
|
377 |
+
log.error(f"Query expansion for '{query}' failed: {e}. Using original query.")
|
378 |
return query
|
379 |
|
380 |
def _rerank_with_crossencoder(self, query: str, candidates: List[FusedCandidate]) -> List[RerankResult]:
|
|
|
407 |
additional_instruction=add_instr, context=context, query=query
|
408 |
)
|
409 |
|
410 |
+
# [MODIFIED] 增強 JSON 解析的穩健性,從字串中提取 JSON 物件
|
411 |
def _safe_json_parse(self, json_str: str, default: Any = None) -> Any:
|
412 |
+
# Find the JSON object within the string
|
413 |
+
match = re.search(r'\{.*\}', json_str, re.DOTALL)
|
414 |
+
if match:
|
415 |
+
json_str = match.group(0)
|
416 |
+
|
417 |
try:
|
418 |
return json.loads(json_str)
|
419 |
except json.JSONDecodeError:
|
|
|
424 |
app = FastAPI()
|
425 |
rag_pipeline: Optional[RagPipeline] = None
|
426 |
|
427 |
+
# [MODIFIED] 將 LINE 配置集中管理並進行啟動時檢查
|
428 |
class AppConfig:
    """LINE channel credentials, read from the environment when the class is defined.

    Both values go through ``_require_env``, so a missing variable raises
    ``RuntimeError`` when this class body executes rather than on first use.
    """

    # Passed to _require_env("CHANNEL_ACCESS_TOKEN").
    CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
    # Key for the HMAC check of the X-Line-Signature header in the webhook handler.
    CHANNEL_SECRET = _require_env("CHANNEL_SECRET")
|
431 |
|
432 |
@app.on_event("startup")
|
433 |
async def startup_event():
|
|
|
438 |
|
439 |
@app.post("/webhook")
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
    """Validate a LINE webhook call and queue each text message for answering.

    The ``X-Line-Signature`` header must equal the base64-encoded HMAC-SHA256
    of the raw request body keyed with ``AppConfig.CHANNEL_SECRET``. For every
    text-message event that carries a reply token, a user id and non-empty
    text, a "processing" acknowledgement is sent with the reply token and the
    real work is scheduled as a background task.

    Raises:
        HTTPException: 400 if the signature header is missing or the body is
            not a JSON object; 403 if the signature does not match; 500 if the
            secret is not configured or the signature cannot be computed.
    """
    signature = request.headers.get("X-Line-Signature")
    if not signature:
        raise HTTPException(status_code=400, detail="Missing X-Line-Signature")
    if not AppConfig.CHANNEL_SECRET:
        log.error("CHANNEL_SECRET is not configured.")
        raise HTTPException(status_code=500, detail="Server configuration error")

    body = await request.body()
    try:
        mac = hmac.new(AppConfig.CHANNEL_SECRET.encode('utf-8'), body, hashlib.sha256)
        expected_signature = base64.b64encode(mac.digest())
    except Exception as e:
        log.error(f"Failed to generate signature: {e}")
        raise HTTPException(status_code=500, detail="Signature generation error") from e

    # Compare as bytes: compare_digest() raises TypeError on non-ASCII str
    # input, which would turn a forged header into a 500 instead of a 403.
    if not hmac.compare_digest(expected_signature, signature.encode('utf-8')):
        raise HTTPException(status_code=403, detail="Invalid signature")

    try:
        # ValueError covers both UnicodeDecodeError and json.JSONDecodeError.
        data = json.loads(body.decode('utf-8'))
    except ValueError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON body") from e
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON body")

    for event in data.get("events", []):
        if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
            reply_token = event.get("replyToken")
            user_text = event.get("message", {}).get("text", "").strip()
            # userId is read defensively; it may be missing from the event
            # source (e.g. in group or room chats).
            source = event.get("source", {})
            user_id = source.get("userId")

            if reply_token and user_id and user_text:
                # Acknowledge at once so the reply token is used before it
                # expires; the final answer is delivered later by push.
                # NOTE(review): line_reply looks like a synchronous call made
                # inside an async handler — confirm it cannot stall the loop.
                line_reply(reply_token, "收到您的問題,正在查詢資料庫,請稍候...")
                # Hand the slow work to a background task.
                background_tasks.add_task(process_user_query, user_id, user_text)

    return Response(status_code=status.HTTP_200_OK)
|
476 |
|
477 |
+
# [MODIFIED] 調整函式簽名,只接收 user_id 和 text,並使用 push message
|
478 |
+
def process_user_query(user_id: str, user_text: str):
    """Answer one user question and deliver the result as a LINE push message.

    Scheduled by the webhook handler as a background task. If the pipeline is
    not yet initialised, a "starting up" notice is pushed instead. Any failure
    is logged and an apology is pushed; this function never raises.

    Args:
        user_id: LINE user id that receives the push message.
        user_text: The user's question text.
    """
    try:
        if rag_pipeline:
            answer = rag_pipeline.answer_question(user_text)
        else:
            answer = "系統正在啟動中,請稍後再試。"
        line_push(user_id, answer)
    except Exception as e:
        log.error(f"背景處理 user_id={user_id} 發生錯誤: {e}", exc_info=True)
        try:
            line_push(user_id, f"抱歉,處理時發生未預期的錯誤。{DISCLAIMER}")
        except Exception:
            # The apology is best-effort. If the first failure came from
            # line_push itself this will likely fail as well, and it must not
            # escape the background task.
            log.exception("Failed to push error notice to user_id=%s", user_id)
|
488 |
|
489 |
def line_api_call(endpoint: str, data: Dict):
|
490 |
headers = {
|
|
|
505 |
messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
|
506 |
line_api_call("push", {"to": user_id, "messages": messages})
|
507 |
|
508 |
+
# [MODIFIED] 改善藥名提取的正則表達式
|
509 |
def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
|
510 |
candidates = set()
|
511 |
+
q_lower = query.lower()
|
512 |
+
# 允許藥名中包含 -, /, . 等符號
|
513 |
+
words = re.findall(r"[a-z0-9][a-z0-9+\-/\.]*", q_lower)
|
514 |
for word in words:
|
515 |
if word in drug_vocab["en"]:
|
516 |
candidates.add(word)
|
517 |
|
518 |
+
for token in jieba.cut(q_lower):
|
519 |
if token in drug_vocab["zh"]:
|
520 |
candidates.add(token)
|
521 |
|