Song commited on
Commit
ab6561e
·
1 Parent(s): f5c1888
Files changed (1) hide show
  1. app.py +48 -30
app.py CHANGED
@@ -30,7 +30,7 @@ import textwrap
30
  # ---------- 第三方函式庫 ----------
31
  import numpy as np
32
  import pandas as pd
33
- from fastapi import FastAPI, Request, Response, HTTPException, status
34
  import uvicorn
35
  import jieba
36
  from rank_bm25 import BM25Okapi
@@ -40,7 +40,6 @@ import torch
40
  from openai import OpenAI
41
  from tenacity import retry, stop_after_attempt, wait_fixed
42
  import requests
43
- from starlette.concurrency import run_in_threadpool
44
 
45
  # ==== CONFIG (從環境變數載入,或使用預設值) ====
46
  # 根據提供的檔案清單,將預設路徑設定為當前目錄
@@ -157,7 +156,6 @@ class RagPipeline:
157
  """在啟動時載入所有必要的模型與資料"""
158
  log.info("開始載入資料與模型...")
159
 
160
- # Load CSV and check for required columns
161
  if not os.path.exists(self.csv_path):
162
  raise FileNotFoundError(f"找不到 CSV 檔案於 {self.csv_path}")
163
 
@@ -172,15 +170,19 @@ class RagPipeline:
172
  )
173
  log.info(f"成功載入 CSV: {self.csv_path} (rows={len(self.df_csv)})")
174
 
175
- # Load corpus and index
176
  self.state.index, self.state.sentences, self.state.meta = self._load_or_build_sentence_index()
177
  self.state.bm25 = self._ensure_bm25_index()
178
 
179
- # Check for BM25 alignment
180
- bm_n = len(self.state.bm25.corpus)
181
  sent_n = len(self.state.sentences)
182
- if bm_n != sent_n:
 
 
 
183
  raise RuntimeError(f"BM25 文件數 ({bm_n}) 與 sentences ({sent_n}) 不一致,請重新生成索引。")
 
 
 
184
  log.info("所有模型與資料載入完成。")
185
 
186
  def _load_or_build_sentence_index(self):
@@ -194,7 +196,6 @@ class RagPipeline:
194
  return index, sentences, meta
195
 
196
  log.info("索引檔案不存在,正在從 CSV 重新建立...")
197
- # This function should be run by a separate script, not here.
198
  raise RuntimeError("FAISS 和句子 PKL 檔案未找到,請先執行索引生成腳本。")
199
 
200
  def _ensure_bm25_index(self):
@@ -203,11 +204,15 @@ class RagPipeline:
203
  try:
204
  with open(BM25_PKL, "rb") as f:
205
  data = pickle.load(f)
206
- bm25 = data.get("bm25") if isinstance(data, dict) else data
207
- if not hasattr(bm25, 'get_scores'):
208
- raise ValueError("載入的 BM25 索引無效。")
209
- log.info(f"成功載入 BM25 索引,包含 {len(bm25.corpus)} 篇文件。")
210
- return bm25
 
 
 
 
211
  except Exception as e:
212
  log.error(f"載入 BM25 索引失敗 ({e}),請檢查檔案格式。")
213
  raise RuntimeError("BM25 索引檔案損壞或格式不符。")
@@ -254,7 +259,10 @@ class RagPipeline:
254
  all_reranked_results = []
255
 
256
  log.info("步驟 3/5: 檢索與重排序...")
257
- relevant_indices = {i for i, m in enumerate(self.state.meta) if m.get("drug_id") in drug_ids}
 
 
 
258
  if not relevant_indices:
259
  log.error("找不到與藥品相關的語料。")
260
  return f"找不到 drug_id {drug_ids} 對應的任何 chunks。{DISCLAIMER}"
@@ -299,6 +307,8 @@ class RagPipeline:
299
  bm_vals = np.array([b for _, _, b in candidates_list], dtype=np.float32)
300
 
301
  def norm(x):
 
 
302
  rng = x.max() - x.min()
303
  return (x - x.min()) / (rng + 1e-8)
304
 
@@ -441,6 +451,9 @@ class RagPipeline:
441
  def _semantic_search(self, index, query: str, top_k: int, embedding_model) -> Tuple[List[int], List[float]]:
442
  if not query:
443
  return [], []
 
 
 
444
  q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
445
  faiss.normalize_L2(q_emb)
446
 
@@ -528,16 +541,16 @@ app = FastAPI()
528
  rag_pipeline = None
529
 
530
  class AppConfig:
531
- CHANNEL_ACCESS_TOKEN = os.getenv("CHANNEL_ACCESS_TOKEN")
532
- CHANNEL_SECRET = os.getenv("CHANNEL_SECRET")
533
 
534
  @app.on_event("startup")
535
  async def startup_event():
536
  """應用程式啟動時執行的任務"""
537
  log.info("===== Application Startup =====")
538
  missing = []
539
- if not AppConfig.CHANNEL_ACCESS_TOKEN: missing.append("CHANNEL_ACCESS_TOKEN")
540
- if not AppConfig.CHANNEL_SECRET: missing.append("CHANNEL_SECRET")
541
  if not LLM_API_CONFIG.get("api_key"): missing.append("LITELLM_API_KEY")
542
  if not LLM_API_CONFIG.get("base_url"): missing.append("LITELLM_BASE_URL")
543
  if not LLM_API_CONFIG.get("model"): missing.append("LM_MODEL")
@@ -556,8 +569,18 @@ async def health_check():
556
  """健康檢查端點,用於 Docker HEALTHCHECK"""
557
  return {"status": "ok"}
558
 
 
 
 
 
 
 
 
 
 
 
559
  @app.post("/webhook")
560
- async def handle_webhook(request: Request, response: Response):
561
  """處理 LINE Message API 的 Webhook 請求"""
562
  signature = request.headers.get("X-Line-Signature")
563
  if not signature:
@@ -580,16 +603,11 @@ async def handle_webhook(request: Request, response: Response):
580
  if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
581
  reply_token = event.get("replyToken")
582
  user_text = event.get("message", {}).get("text", "").strip()
583
-
584
- if not user_text: continue
585
 
586
- # Offload heavy work to a thread pool
587
- answer = await run_in_threadpool(rag_pipeline.answer_question, user_text)
588
-
589
- if reply_token:
590
- line_reply(reply_token, answer)
591
 
592
- return {"status": "ok"}
593
 
594
  def line_reply(reply_token: str, text: str):
595
  """透過 LINE Message API 回覆訊息,並進行分塊以避免長度限制"""
@@ -597,9 +615,9 @@ def line_reply(reply_token: str, text: str):
597
  "Content-Type": "application/json",
598
  "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}"
599
  }
600
- # LINE 文本長度上限約為 5000 字元
601
- chunks = textwrap.wrap(text, 4900)
602
- messages = [{"type": "text", "text": c} for c in chunks] or [{"type": "text", "text": text[:4900]}]
603
  data = {"replyToken": reply_token, "messages": messages}
604
  try:
605
  r = requests.post("https://api.line.me/v2/bot/message/reply", headers=headers, json=data, timeout=10)
 
30
  # ---------- 第三方函式庫 ----------
31
  import numpy as np
32
  import pandas as pd
33
+ from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
34
  import uvicorn
35
  import jieba
36
  from rank_bm25 import BM25Okapi
 
40
  from openai import OpenAI
41
  from tenacity import retry, stop_after_attempt, wait_fixed
42
  import requests
 
43
 
44
  # ==== CONFIG (從環境變數載入,或使用預設值) ====
45
  # 根據提供的檔案清單,將預設路徑設定為當前目錄
 
156
  """在啟動時載入所有必要的模型與資料"""
157
  log.info("開始載入資料與模型...")
158
 
 
159
  if not os.path.exists(self.csv_path):
160
  raise FileNotFoundError(f"找不到 CSV 檔案於 {self.csv_path}")
161
 
 
170
  )
171
  log.info(f"成功載入 CSV: {self.csv_path} (rows={len(self.df_csv)})")
172
 
 
173
  self.state.index, self.state.sentences, self.state.meta = self._load_or_build_sentence_index()
174
  self.state.bm25 = self._ensure_bm25_index()
175
 
176
+ # Check for BM25 and meta alignment
 
177
  sent_n = len(self.state.sentences)
178
+ meta_n = len(self.state.meta)
179
+ bm_n = getattr(self.state.bm25, 'corpus_size', len(getattr(self.state.bm25, 'doc_len', [])))
180
+
181
+ if sent_n != bm_n:
182
  raise RuntimeError(f"BM25 文件數 ({bm_n}) 與 sentences ({sent_n}) 不一致,請重新生成索引。")
183
+ if sent_n != meta_n:
184
+ raise RuntimeError(f"sentences ({sent_n}) 與 meta ({meta_n}) 長度不一致,請重新生成索引。")
185
+
186
  log.info("所有模型與資料載入完成。")
187
 
188
  def _load_or_build_sentence_index(self):
 
196
  return index, sentences, meta
197
 
198
  log.info("索引檔案不存在,正在從 CSV 重新建立...")
 
199
  raise RuntimeError("FAISS 和句子 PKL 檔案未找到,請先執行索引生成腳本。")
200
 
201
  def _ensure_bm25_index(self):
 
204
  try:
205
  with open(BM25_PKL, "rb") as f:
206
  data = pickle.load(f)
207
+ bm25 = data.get("bm25") if isinstance(data, dict) else data
208
+ if not hasattr(bm25, 'get_scores'):
209
+ raise ValueError("載入的 BM25 索引無效。")
210
+
211
+ # Use a more robust way to get corpus size
212
+ corpus_size = getattr(bm25, 'corpus_size', len(getattr(bm25, 'doc_len', [])))
213
+ log.info(f"成功載入 BM25 索引,包含 {corpus_size} 篇文件。")
214
+ setattr(self.state, 'bm25_corpus_len', corpus_size)
215
+ return bm25
216
  except Exception as e:
217
  log.error(f"載入 BM25 索引失敗 ({e}),請檢查檔案格式。")
218
  raise RuntimeError("BM25 索引檔案損壞或格式不符。")
 
259
  all_reranked_results = []
260
 
261
  log.info("步驟 3/5: 檢索與重排序...")
262
+ # Ensure drug_id is always string for robust matching
263
+ drug_ids_set = {str(did) for did in drug_ids}
264
+ relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id")) in drug_ids_set}
265
+
266
  if not relevant_indices:
267
  log.error("找不到與藥品相關的語料。")
268
  return f"找不到 drug_id {drug_ids} 對應的任何 chunks。{DISCLAIMER}"
 
307
  bm_vals = np.array([b for _, _, b in candidates_list], dtype=np.float32)
308
 
309
  def norm(x):
310
+ if len(x) == 0 or (x.max() - x.min()) == 0:
311
+ return np.zeros_like(x)
312
  rng = x.max() - x.min()
313
  return (x - x.min()) / (rng + 1e-8)
314
 
 
451
  def _semantic_search(self, index, query: str, top_k: int, embedding_model) -> Tuple[List[int], List[float]]:
452
  if not query:
453
  return [], []
454
+
455
+ top_k = min(top_k, index.ntotal)
456
+
457
  q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
458
  faiss.normalize_L2(q_emb)
459
 
 
541
  rag_pipeline = None
542
 
543
  class AppConfig:
544
+ CHANNEL_ACCESS_TOKEN = os.getenv("LINE_CHANNEL_ACCESS_TOKEN")
545
+ CHANNEL_SECRET = os.getenv("LINE_CHANNEL_SECRET")
546
 
547
  @app.on_event("startup")
548
  async def startup_event():
549
  """應用程式啟動時執行的任務"""
550
  log.info("===== Application Startup =====")
551
  missing = []
552
+ if not AppConfig.CHANNEL_ACCESS_TOKEN: missing.append("LINE_CHANNEL_ACCESS_TOKEN")
553
+ if not AppConfig.CHANNEL_SECRET: missing.append("LINE_CHANNEL_SECRET")
554
  if not LLM_API_CONFIG.get("api_key"): missing.append("LITELLM_API_KEY")
555
  if not LLM_API_CONFIG.get("base_url"): missing.append("LITELLM_BASE_URL")
556
  if not LLM_API_CONFIG.get("model"): missing.append("LM_MODEL")
 
569
  """健康檢查端點,用於 Docker HEALTHCHECK"""
570
  return {"status": "ok"}
571
 
572
+ def process_and_reply(reply_token: str, user_text: str):
573
+ """將耗時的 RAG 處理和回覆任務移至後台執行"""
574
+ try:
575
+ answer = rag_pipeline.answer_question(user_text)
576
+ except Exception as e:
577
+ log.error(f"後台處理錯誤: {e}", exc_info=True)
578
+ answer = "處理時發生錯誤,請稍後再試。"
579
+ line_reply(reply_token, answer)
580
+
581
+
582
  @app.post("/webhook")
583
+ async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
584
  """處理 LINE Message API 的 Webhook 請求"""
585
  signature = request.headers.get("X-Line-Signature")
586
  if not signature:
 
603
  if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
604
  reply_token = event.get("replyToken")
605
  user_text = event.get("message", {}).get("text", "").strip()
 
 
606
 
607
+ if reply_token and user_text:
608
+ background_tasks.add_task(process_and_reply, reply_token, user_text)
 
 
 
609
 
610
+ return {"status": "ok"} # Immediately return 200 OK
611
 
612
  def line_reply(reply_token: str, text: str):
613
  """透過 LINE Message API 回覆訊息,並進行分塊以避免長度限制"""
 
615
  "Content-Type": "application/json",
616
  "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}"
617
  }
618
+ # LINE 文本長度上限約為 5000 字元,且回覆訊息數上限為 5 則
619
+ chunks = textwrap.wrap(text, 4000)
620
+ messages = [{"type": "text", "text": c} for c in chunks[:5]] or [{"type": "text", "text": text[:4000]}]
621
  data = {"replyToken": reply_token, "messages": messages}
622
  try:
623
  r = requests.post("https://api.line.me/v2/bot/message/reply", headers=headers, json=data, timeout=10)