Spaces:
Sleeping
Sleeping
Song
committed on
Commit
·
ee31585
1
Parent(s):
5f83bd3
hi
Browse files
app.py
CHANGED
@@ -59,8 +59,6 @@ def _require_llm_config():
|
|
59 |
for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
|
60 |
_require_env(k)
|
61 |
|
62 |
-
_require_llm_config()
|
63 |
-
|
64 |
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
|
65 |
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
|
66 |
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
|
@@ -192,10 +190,19 @@ class RagPipeline:
|
|
192 |
self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
|
193 |
)
|
194 |
self.drug_name_to_ids = self.df_csv.groupby('drug_name_norm_normalized')['drug_id'].unique().apply(list).to_dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
self._load_drug_name_vocabulary()
|
196 |
|
197 |
log.info("載入 FAISS 索引與句子資料...")
|
198 |
self.state.index = faiss.read_index(FAISS_INDEX)
|
|
|
|
|
|
|
199 |
with open(SENTENCES_PKL, "rb") as f:
|
200 |
data = pickle.load(f)
|
201 |
self.state.sentences = data["sentences"]
|
@@ -204,6 +211,8 @@ class RagPipeline:
|
|
204 |
log.info("載入 BM25 索引...")
|
205 |
with open(BM25_PKL, "rb") as f:
|
206 |
self.state.bm25 = pickle.load(f)
|
|
|
|
|
207 |
except (FileNotFoundError, KeyError) as e:
|
208 |
log.exception(f"資料或索引檔案載入失敗: {e}")
|
209 |
raise RuntimeError(f"資料初始化失敗,請檢查檔案路徑與內容: {e}")
|
@@ -217,10 +226,19 @@ class RagPipeline:
|
|
217 |
for part in parts:
|
218 |
if re.search(r'[\u4e00-\u9fff]', part):
|
219 |
self.drug_vocab["zh"].add(part)
|
|
|
|
|
|
|
|
|
220 |
else:
|
221 |
self.drug_vocab["en"].add(part)
|
222 |
for alias in DRUG_NAME_MAPPING:
|
223 |
self.drug_vocab["en"].add(alias.lower())
|
|
|
|
|
|
|
|
|
|
|
224 |
|
225 |
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
226 |
def _llm_call(self, messages, **kwargs) -> str:
|
@@ -273,9 +291,20 @@ class RagPipeline:
|
|
273 |
|
274 |
@lru_cache(maxsize=128)
|
275 |
def _find_drug_ids_from_name(self, query: str) -> List[str]:
|
276 |
-
|
277 |
-
|
278 |
drug_ids = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
for alias in candidates:
|
280 |
# [MODIFIED] 英文藥名比對使用詞邊界,避免子字串誤判
|
281 |
is_english = not re.search(r'[\u4e00-\u9fff]', alias)
|
@@ -291,7 +320,6 @@ class RagPipeline:
|
|
291 |
drug_ids.update(ids)
|
292 |
return list(drug_ids)
|
293 |
|
294 |
-
|
295 |
def _analyze_query(self, query: str) -> Dict[str, Any]:
|
296 |
prompt = PROMPT_TEMPLATES["analyze_query"].format(
|
297 |
options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
|
@@ -311,25 +339,31 @@ class RagPipeline:
|
|
311 |
expanded_q = self._expand_query_with_llm(sub_q, tuple(intents))
|
312 |
|
313 |
q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
|
314 |
-
faiss.
|
|
|
315 |
distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
|
316 |
|
317 |
tokenized_query = list(jieba.cut(expanded_q))
|
318 |
|
319 |
-
# [MODIFIED]
|
320 |
bm25_scores = self.state.bm25.get_scores(tokenized_query)
|
321 |
-
|
322 |
-
|
|
|
|
|
323 |
|
324 |
candidate_scores: Dict[int, Dict[str, float]] = {}
|
325 |
|
326 |
-
# [MODIFIED]
|
327 |
-
def
|
328 |
-
|
|
|
|
|
|
|
329 |
|
330 |
for i, dist in zip(sim_indices[0], distances[0]):
|
331 |
if i in relevant_indices:
|
332 |
-
similarity =
|
333 |
candidate_scores[int(i)] = {"sem": float(similarity), "bm": 0.0}
|
334 |
|
335 |
for i, score in doc_to_bm25_score.items():
|
@@ -408,16 +442,14 @@ class RagPipeline:
|
|
408 |
)
|
409 |
|
410 |
# [MODIFIED] 增強 JSON 解析的穩健性,從字串中提取 JSON 物件
|
411 |
-
def _safe_json_parse(self,
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
json_str = match.group(0)
|
416 |
-
|
417 |
try:
|
418 |
-
return json.loads(
|
419 |
-
except
|
420 |
-
log.warning(f"無法解析 LLM 回傳的 JSON: {
|
421 |
return default
|
422 |
|
423 |
# ---------- FastAPI 事件與路由 ----------
|
@@ -432,6 +464,7 @@ class AppConfig:
|
|
432 |
@app.on_event("startup")
|
433 |
async def startup_event():
|
434 |
global rag_pipeline
|
|
|
435 |
rag_pipeline = RagPipeline(AppConfig)
|
436 |
rag_pipeline.load_data()
|
437 |
log.info("啟動完成,服務準備就緒。")
|
@@ -457,35 +490,41 @@ async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|
|
457 |
if not hmac.compare_digest(expected_signature, signature):
|
458 |
raise HTTPException(status_code=403, detail="Invalid signature")
|
459 |
|
460 |
-
|
|
|
|
|
|
|
|
|
461 |
for event in data.get("events", []):
|
462 |
if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
|
463 |
reply_token = event.get("replyToken")
|
464 |
user_text = event.get("message", {}).get("text", "").strip()
|
465 |
-
# [MODIFIED]
|
466 |
source = event.get("source", {})
|
467 |
-
|
|
|
468 |
|
469 |
-
if reply_token and
|
470 |
# [MODIFIED] 更改回覆策略:立即回覆處理中訊息,避免 replyToken 逾時
|
471 |
line_reply(reply_token, "收到您的問題,正在查詢資料庫,請稍候...")
|
472 |
# 將耗時的任務交給背景處理,使用 push message 回覆最終答案
|
473 |
-
background_tasks.add_task(process_user_query,
|
474 |
|
475 |
return Response(status_code=status.HTTP_200_OK)
|
476 |
|
477 |
# [MODIFIED] 調整函式簽名,只接收 user_id 和 text,並使用 push message
|
478 |
-
def process_user_query(
|
479 |
try:
|
480 |
if rag_pipeline:
|
481 |
answer = rag_pipeline.answer_question(user_text)
|
482 |
else:
|
483 |
answer = "系統正在啟動中,請稍後再試。"
|
484 |
-
|
485 |
except Exception as e:
|
486 |
-
log.error(f"背景處理
|
487 |
-
|
488 |
|
|
|
489 |
def line_api_call(endpoint: str, data: Dict):
|
490 |
headers = {
|
491 |
"Content-Type": "application/json",
|
@@ -496,14 +535,17 @@ def line_api_call(endpoint: str, data: Dict):
|
|
496 |
response.raise_for_status()
|
497 |
except requests.exceptions.RequestException as e:
|
498 |
log.error(f"LINE API ({endpoint}) 呼叫失敗: {e} | Response: {e.response.text if e.response else 'N/A'}")
|
|
|
499 |
|
500 |
def line_reply(reply_token: str, text: str):
|
501 |
messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
|
502 |
line_api_call("reply", {"replyToken": reply_token, "messages": messages})
|
503 |
|
504 |
-
def
|
505 |
messages = [{"type": "text", "text": chunk} for chunk in textwrap.wrap(text, 4800, replace_whitespace=False)[:5]]
|
506 |
-
|
|
|
|
|
507 |
|
508 |
# [MODIFIED] 改善藥名提取的正則表達式
|
509 |
def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
|
|
|
59 |
for k in ("LITELLM_BASE_URL", "LITELLM_API_KEY", "LM_MODEL"):
|
60 |
_require_env(k)
|
61 |
|
|
|
|
|
62 |
CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
|
63 |
FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
|
64 |
SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
|
|
|
190 |
self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
|
191 |
)
|
192 |
self.drug_name_to_ids = self.df_csv.groupby('drug_name_norm_normalized')['drug_id'].unique().apply(list).to_dict()
|
193 |
+
# [MODIFIED] 把別名也變成可查鍵
|
194 |
+
for alias, canonical in DRUG_NAME_MAPPING.items():
|
195 |
+
alias_key = re.sub(r'[^\w\s]', '', alias.lower()).strip()
|
196 |
+
canonical_key = re.sub(r'[^\w\s]', '', canonical.lower()).strip()
|
197 |
+
if canonical_key in self.drug_name_to_ids:
|
198 |
+
self.drug_name_to_ids[alias_key] = self.drug_name_to_ids[canonical_key]
|
199 |
self._load_drug_name_vocabulary()
|
200 |
|
201 |
log.info("載入 FAISS 索引與句子資料...")
|
202 |
self.state.index = faiss.read_index(FAISS_INDEX)
|
203 |
+
self.state.faiss_metric = getattr(self.state.index, "metric_type", faiss.METRIC_L2)
|
204 |
+
if hasattr(self.state.index, "nprobe"):
|
205 |
+
self.state.index.nprobe = int(os.getenv("FAISS_NPROBE", "16"))
|
206 |
with open(SENTENCES_PKL, "rb") as f:
|
207 |
data = pickle.load(f)
|
208 |
self.state.sentences = data["sentences"]
|
|
|
211 |
log.info("載入 BM25 索引...")
|
212 |
with open(BM25_PKL, "rb") as f:
|
213 |
self.state.bm25 = pickle.load(f)
|
214 |
+
if not isinstance(self.state.bm25, BM25Okapi):
|
215 |
+
raise ValueError("Loaded BM25 is not a BM25Okapi instance.")
|
216 |
except (FileNotFoundError, KeyError) as e:
|
217 |
log.exception(f"資料或索引檔案載入失敗: {e}")
|
218 |
raise RuntimeError(f"資料初始化失敗,請檢查檔案路徑與內容: {e}")
|
|
|
226 |
for part in parts:
|
227 |
if re.search(r'[\u4e00-\u9fff]', part):
|
228 |
self.drug_vocab["zh"].add(part)
|
229 |
+
try:
|
230 |
+
jieba.add_word(part, freq=2_000_000)
|
231 |
+
except Exception:
|
232 |
+
pass
|
233 |
else:
|
234 |
self.drug_vocab["en"].add(part)
|
235 |
for alias in DRUG_NAME_MAPPING:
|
236 |
self.drug_vocab["en"].add(alias.lower())
|
237 |
+
if re.search(r'[\u4e00-\u9fff]', alias):
|
238 |
+
try:
|
239 |
+
jieba.add_word(alias, freq=2_000_000)
|
240 |
+
except Exception:
|
241 |
+
pass
|
242 |
|
243 |
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
244 |
def _llm_call(self, messages, **kwargs) -> str:
|
|
|
291 |
|
292 |
@lru_cache(maxsize=128)
|
293 |
def _find_drug_ids_from_name(self, query: str) -> List[str]:
|
294 |
+
q = query.lower()
|
295 |
+
candidates = extract_drug_candidates_from_query(q, self.drug_vocab)
|
296 |
drug_ids = set()
|
297 |
+
|
298 |
+
# 英文:詞邊界;中文:也做子字串掃描
|
299 |
+
for k, ids in self.drug_name_to_ids.items():
|
300 |
+
if re.search(r'[\u4e00-\u9fff]', k):
|
301 |
+
if k in q:
|
302 |
+
drug_ids.update(ids)
|
303 |
+
else:
|
304 |
+
if re.search(rf"\b{re.escape(k)}\b", q):
|
305 |
+
drug_ids.update(ids)
|
306 |
+
|
307 |
+
# 仍保留舊的候選詞路徑(補強)
|
308 |
for alias in candidates:
|
309 |
# [MODIFIED] 英文藥名比對使用詞邊界,避免子字串誤判
|
310 |
is_english = not re.search(r'[\u4e00-\u9fff]', alias)
|
|
|
320 |
drug_ids.update(ids)
|
321 |
return list(drug_ids)
|
322 |
|
|
|
323 |
def _analyze_query(self, query: str) -> Dict[str, Any]:
|
324 |
prompt = PROMPT_TEMPLATES["analyze_query"].format(
|
325 |
options="\n".join(f"- {c}" for c in INTENT_CATEGORIES),
|
|
|
339 |
expanded_q = self._expand_query_with_llm(sub_q, tuple(intents))
|
340 |
|
341 |
q_emb = self.embedding_model.encode([expanded_q], convert_to_numpy=True).astype("float32")
|
342 |
+
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
|
343 |
+
faiss.normalize_L2(q_emb)
|
344 |
distances, sim_indices = self.state.index.search(q_emb, PRE_RERANK_K)
|
345 |
|
346 |
tokenized_query = list(jieba.cut(expanded_q))
|
347 |
|
348 |
+
# [MODIFIED] 先過濾 relevant_indices 再取 TopK
|
349 |
bm25_scores = self.state.bm25.get_scores(tokenized_query)
|
350 |
+
rel_idx = np.fromiter(relevant_indices, dtype=int)
|
351 |
+
rel_scores = bm25_scores[rel_idx]
|
352 |
+
top_rel = rel_idx[np.argsort(rel_scores)[::-1][:PRE_RERANK_K]]
|
353 |
+
doc_to_bm25_score = {int(i): float(bm25_scores[i]) for i in top_rel}
|
354 |
|
355 |
candidate_scores: Dict[int, Dict[str, float]] = {}
|
356 |
|
357 |
+
# [MODIFIED] 把 distance 轉成「越大越好的相似度」
|
358 |
+
def to_similarity(d: float) -> float:
|
359 |
+
if self.state.faiss_metric == faiss.METRIC_INNER_PRODUCT:
|
360 |
+
return float(d) # IP 越大越好
|
361 |
+
else: # METRIC_L2(多半是平方 L2)
|
362 |
+
return 1.0 / (1.0 + float(d))
|
363 |
|
364 |
for i, dist in zip(sim_indices[0], distances[0]):
|
365 |
if i in relevant_indices:
|
366 |
+
similarity = to_similarity(dist)
|
367 |
candidate_scores[int(i)] = {"sem": float(similarity), "bm": 0.0}
|
368 |
|
369 |
for i, score in doc_to_bm25_score.items():
|
|
|
442 |
)
|
443 |
|
444 |
# [MODIFIED] 增強 JSON 解析的穩健性,從字串中提取 JSON 物件
|
445 |
+
def _safe_json_parse(self, s: str, default: Any = None) -> Any:
    """Parse the first JSON object embedded in ``s``, tolerating surrounding text.

    The previous implementation located the object with the non-greedy regex
    ``\\{.*?\\}``, which stops at the first ``}``.  Any nested object (for
    example ``{"a": {"b": 1}}``) was therefore truncated, failed to parse and
    silently became ``default``.  This version lets the JSON decoder itself
    find where the object ends.

    Args:
        s: Raw text, expected to contain a JSON object somewhere inside it.
        default: Value returned when nothing in ``s`` can be parsed.

    Returns:
        The first JSON object that decodes successfully starting at a ``{``;
        otherwise the result of parsing the whole string as JSON (so bare
        arrays or scalars still work); otherwise ``default``.
    """
    if not isinstance(s, str):
        # Nothing to scan; the old code would have raised from re.search.
        return default

    decoder = json.JSONDecoder()
    brace_at = s.find("{")
    while brace_at != -1:
        try:
            # raw_decode parses one complete JSON value starting at brace_at
            # and ignores whatever follows it, so nesting and trailing
            # commentary are both handled.
            parsed, _end = decoder.raw_decode(s, brace_at)
        except (ValueError, RecursionError):
            # This "{" did not start valid JSON; try the next one.
            brace_at = s.find("{", brace_at + 1)
        else:
            return parsed

    try:
        # No usable object: fall back to parsing the entire string.
        return json.loads(s)
    except (ValueError, RecursionError):
        log.warning(f"無法解析 LLM 回傳的 JSON: {s[:200]}...")
        return default
|
454 |
|
455 |
# ---------- FastAPI 事件與路由 ----------
|
|
|
464 |
@app.on_event("startup")
|
465 |
async def startup_event():
|
466 |
global rag_pipeline
|
467 |
+
_require_llm_config()
|
468 |
rag_pipeline = RagPipeline(AppConfig)
|
469 |
rag_pipeline.load_data()
|
470 |
log.info("啟動完成,服務準備就緒。")
|
|
|
490 |
if not hmac.compare_digest(expected_signature, signature):
|
491 |
raise HTTPException(status_code=403, detail="Invalid signature")
|
492 |
|
493 |
+
try:
|
494 |
+
data = json.loads(body.decode('utf-8'))
|
495 |
+
except json.JSONDecodeError:
|
496 |
+
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
497 |
+
|
498 |
for event in data.get("events", []):
|
499 |
if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
|
500 |
reply_token = event.get("replyToken")
|
501 |
user_text = event.get("message", {}).get("text", "").strip()
|
502 |
+
# [MODIFIED] 擷取 target
|
503 |
source = event.get("source", {})
|
504 |
+
stype = source.get("type") # "user" | "group" | "room"
|
505 |
+
target_id = source.get("userId") or source.get("groupId") or source.get("roomId")
|
506 |
|
507 |
+
if reply_token and user_text and target_id:
|
508 |
# [MODIFIED] 更改回覆策略:立即回覆處理中訊息,避免 replyToken 逾時
|
509 |
line_reply(reply_token, "收到您的問題,正在查詢資料庫,請稍候...")
|
510 |
# 將耗時的任務交給背景處理,使用 push message 回覆最終答案
|
511 |
+
background_tasks.add_task(process_user_query, stype, target_id, user_text)
|
512 |
|
513 |
return Response(status_code=status.HTTP_200_OK)
|
514 |
|
515 |
# [MODIFIED] Signature now takes source_type, target_id and user_text; the final answer is sent with a push message
|
516 |
+
def process_user_query(source_type: str, target_id: str, user_text: str):
    """Answer one user message in the background and push the result to LINE.

    Args:
        source_type: Value of the event's ``source.type`` field, forwarded
            unchanged to ``line_push_generic``.
        target_id: Identifier the answer is pushed to.
        user_text: The user's question.

    This function is scheduled as a background task, so it must not let an
    exception escape.  ``line_push_generic`` can raise once the retries inside
    ``line_api_call`` are exhausted; previously a failed push caused the
    ``except`` branch to push again and re-raise from inside the handler.
    """
    try:
        if rag_pipeline:
            answer = rag_pipeline.answer_question(user_text)
        else:
            answer = "系統正在啟動中,請稍後再試。"
        line_push_generic(source_type, target_id, answer)
    except Exception as e:
        log.error(f"背景處理 target_id={target_id} 發生錯誤: {e}", exc_info=True)
        try:
            # Best-effort apology; if LINE is unreachable this fails too.
            line_push_generic(source_type, target_id, f"抱歉,處理時發生未預期的錯誤。{DISCLAIMER}")
        except Exception as push_err:
            # Log and stop here so the failure does not escape the task.
            log.error(f"背景處理 target_id={target_id} 無法推送錯誤通知: {push_err}", exc_info=True)
|
526 |
|
527 |
+
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
528 |
def line_api_call(endpoint: str, data: Dict):
|
529 |
headers = {
|
530 |
"Content-Type": "application/json",
|
|
|
535 |
response.raise_for_status()
|
536 |
except requests.exceptions.RequestException as e:
|
537 |
log.error(f"LINE API ({endpoint}) 呼叫失敗: {e} | Response: {e.response.text if e.response else 'N/A'}")
|
538 |
+
raise
|
539 |
|
540 |
def line_reply(reply_token: str, text: str):
    """Send ``text`` to LINE through the ``reply`` endpoint.

    The text is split into chunks of at most 4800 characters and only the
    first five chunks are sent (presumably to stay within LINE's per-message
    and per-request limits; confirm against the Messaging API reference).

    Args:
        reply_token: Reply token taken from the incoming webhook event.
        text: Message body; may be long or contain newlines.

    Empty or whitespace-only ``text`` produces no chunks.  Previously that
    still issued a request with an empty ``messages`` list, which was then
    retried by ``line_api_call``; now the call is skipped.
    """
    chunks = textwrap.wrap(text or "", 4800, replace_whitespace=False)[:5]
    if not chunks:
        # Nothing to send; skip the request rather than post an empty payload.
        return
    messages = [{"type": "text", "text": chunk} for chunk in chunks]
    line_api_call("reply", {"replyToken": reply_token, "messages": messages})
|
543 |
|
544 |
+
def line_push_generic(source_type: str, target_id: str, text: str):
    """Send ``text`` to ``target_id`` through LINE's ``push`` endpoint.

    Args:
        source_type: Kind of conversation the message came from.  Currently
            unused: the same ``push`` endpoint and ``to`` field are used for
            every value.  Kept so existing callers remain valid.
        target_id: Recipient identifier placed in the ``to`` field.
        text: Message body.  It is split into chunks of at most 4800
            characters and only the first five chunks are sent (presumably to
            respect LINE's limits; confirm against the Messaging API
            reference).

    Empty or whitespace-only ``text`` produces no chunks.  Previously that
    still issued a request with an empty ``messages`` list, which was then
    retried by ``line_api_call``; now the call is skipped.
    """
    chunks = textwrap.wrap(text or "", 4800, replace_whitespace=False)[:5]
    if not chunks:
        # Nothing to send; skip the request rather than post an empty payload.
        return
    messages = [{"type": "text", "text": chunk} for chunk in chunks]
    line_api_call("push", {"to": target_id, "messages": messages})
|
549 |
|
550 |
# [MODIFIED] 改善藥名提取的正則表達式
|
551 |
def extract_drug_candidates_from_query(query: str, drug_vocab: dict) -> list:
|