Song committed · Commit ab6561e · Parent(s): f5c1888
Commit message: hi
app.py
CHANGED
@@ -30,7 +30,7 @@ import textwrap
 # ---------- 第三方函式庫 ----------
 import numpy as np
 import pandas as pd
-from fastapi import FastAPI, Request, Response, HTTPException, status
+from fastapi import FastAPI, Request, Response, HTTPException, status, BackgroundTasks
 import uvicorn
 import jieba
 from rank_bm25 import BM25Okapi
@@ -40,7 +40,6 @@ import torch
 from openai import OpenAI
 from tenacity import retry, stop_after_attempt, wait_fixed
 import requests
-from starlette.concurrency import run_in_threadpool
 
 # ==== CONFIG (從環境變數載入,或使用預設值) ====
 # 根據提供的檔案清單,將預設路徑設定為當前目錄
@@ -157,7 +156,6 @@ class RagPipeline:
         """在啟動時載入所有必要的模型與資料"""
         log.info("開始載入資料與模型...")
 
-        # Load CSV and check for required columns
         if not os.path.exists(self.csv_path):
             raise FileNotFoundError(f"找不到 CSV 檔案於 {self.csv_path}")
 
@@ -172,15 +170,19 @@ class RagPipeline:
         )
         log.info(f"成功載入 CSV: {self.csv_path} (rows={len(self.df_csv)})")
 
-        # Load corpus and index
         self.state.index, self.state.sentences, self.state.meta = self._load_or_build_sentence_index()
         self.state.bm25 = self._ensure_bm25_index()
 
-        # Check for BM25 alignment
-        bm_n = len(self.state.bm25.corpus)
+        # Check for BM25 and meta alignment
         sent_n = len(self.state.sentences)
-        if bm_n != sent_n:
+        meta_n = len(self.state.meta)
+        bm_n = getattr(self.state.bm25, 'corpus_size', len(getattr(self.state.bm25, 'doc_len', [])))
+
+        if sent_n != bm_n:
             raise RuntimeError(f"BM25 文件數 ({bm_n}) 與 sentences ({sent_n}) 不一致,請重新生成索引。")
+        if sent_n != meta_n:
+            raise RuntimeError(f"sentences ({sent_n}) 與 meta ({meta_n}) 長度不一致,請重新生成索引。")
+
         log.info("所有模型與資料載入完成。")
 
     def _load_or_build_sentence_index(self):
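The new alignment check no longer needs the raw pickled corpus: it reads the document count from rank_bm25's corpus_size attribute, with doc_len as a fallback, and compares it against the sentence and metadata lists. A minimal standalone sketch of the same consistency check, using a toy corpus and placeholder metadata rather than the app's data:

from rank_bm25 import BM25Okapi

sentences = ["阿斯匹靈 用於 止痛", "飯後 服用 每日 兩次"]   # placeholder corpus
meta = [{"drug_id": "A001"}, {"drug_id": "A001"}]

bm25 = BM25Okapi([s.split() for s in sentences])

# rank_bm25 stores the number of indexed documents on corpus_size;
# doc_len (one entry per document) works as a fallback.
bm_n = getattr(bm25, "corpus_size", len(getattr(bm25, "doc_len", [])))

if not (len(sentences) == len(meta) == bm_n):
    raise RuntimeError(f"index mismatch: sentences={len(sentences)}, meta={len(meta)}, bm25={bm_n}")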
@@ -194,7 +196,6 @@ class RagPipeline:
             return index, sentences, meta
 
         log.info("索引檔案不存在,正在從 CSV 重新建立...")
-        # This function should be run by a separate script, not here.
         raise RuntimeError("FAISS 和句子 PKL 檔案未找到,請先執行索引生成腳本。")
 
     def _ensure_bm25_index(self):
@@ -203,11 +204,15 @@ class RagPipeline:
         try:
             with open(BM25_PKL, "rb") as f:
                 data = pickle.load(f)
-
-
-
-
-
+            bm25 = data.get("bm25") if isinstance(data, dict) else data
+            if not hasattr(bm25, 'get_scores'):
+                raise ValueError("載入的 BM25 索引無效。")
+
+            # Use a more robust way to get corpus size
+            corpus_size = getattr(bm25, 'corpus_size', len(getattr(bm25, 'doc_len', [])))
+            log.info(f"成功載入 BM25 索引,包含 {corpus_size} 篇文件。")
+            setattr(self.state, 'bm25_corpus_len', corpus_size)
+            return bm25
         except Exception as e:
             log.error(f"載入 BM25 索引失敗 ({e}),請檢查檔案格式。")
             raise RuntimeError("BM25 索引檔案損壞或格式不符。")
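The loader above now accepts either a bare BM25 object or a dict that wraps one under the key "bm25", and it validates the object by checking for get_scores before use. A hedged sketch of how such a pickle might be produced and re-loaded; the file name and tokenization here are illustrative assumptions, not taken from the repository's indexing script:

import pickle
import jieba
from rank_bm25 import BM25Okapi

sentences = ["每日一次,飯後服用", "孕婦使用前請諮詢醫師"]   # placeholder corpus
tokenized = [list(jieba.cut(s)) for s in sentences]

# Wrap the index in a dict, matching the data.get("bm25") branch in _ensure_bm25_index.
with open("bm25.pkl", "wb") as f:                             # hypothetical path
    pickle.dump({"bm25": BM25Okapi(tokenized)}, f)

with open("bm25.pkl", "rb") as f:
    data = pickle.load(f)
bm25 = data.get("bm25") if isinstance(data, dict) else data
assert hasattr(bm25, "get_scores")                            # same validity check as the app
print(bm25.get_scores(list(jieba.cut("飯後服用"))))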
@@ -254,7 +259,10 @@ class RagPipeline:
         all_reranked_results = []
 
         log.info("步驟 3/5: 檢索與重排序...")
-
+        # Ensure drug_id is always string for robust matching
+        drug_ids_set = {str(did) for did in drug_ids}
+        relevant_indices = {i for i, m in enumerate(self.state.meta) if str(m.get("drug_id")) in drug_ids_set}
+
         if not relevant_indices:
             log.error("找不到與藥品相關的語料。")
             return f"找不到 drug_id {drug_ids} 對應的任何 chunks。{DISCLAIMER}"
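The str() cast guards against type drift: pandas often reads a numeric drug_id column as int64 while IDs coming from metadata or user text are strings, so a raw set-membership test can silently match nothing. A small illustration with made-up values:

meta = [{"drug_id": 1001}, {"drug_id": "1002"}]   # mixed types, as after a CSV round-trip
drug_ids = ["1001"]

# Without normalisation: 1001 (int) != "1001" (str), so nothing matches.
assert {i for i, m in enumerate(meta) if m.get("drug_id") in set(drug_ids)} == set()

# With str() applied to both sides, index 0 is found as intended.
drug_ids_set = {str(d) for d in drug_ids}
assert {i for i, m in enumerate(meta) if str(m.get("drug_id")) in drug_ids_set} == {0}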
@@ -299,6 +307,8 @@ class RagPipeline:
         bm_vals = np.array([b for _, _, b in candidates_list], dtype=np.float32)
 
         def norm(x):
+            if len(x) == 0 or (x.max() - x.min()) == 0:
+                return np.zeros_like(x)
             rng = x.max() - x.min()
             return (x - x.min()) / (rng + 1e-8)
 
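The added guard covers the degenerate cases of min-max scaling: an empty score array would make .max() raise, and a constant array only avoids a division by zero thanks to the 1e-8 epsilon. A standalone sketch of the patched normaliser:

import numpy as np

def norm(x: np.ndarray) -> np.ndarray:
    # Degenerate cases: nothing to scale, or every value identical.
    if len(x) == 0 or (x.max() - x.min()) == 0:
        return np.zeros_like(x)
    rng = x.max() - x.min()
    return (x - x.min()) / (rng + 1e-8)

print(norm(np.array([0.2, 0.5, 0.9], dtype=np.float32)))   # spread to roughly [0, 1]
print(norm(np.array([0.7, 0.7], dtype=np.float32)))        # constant scores -> zeros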
@@ -441,6 +451,9 @@ class RagPipeline:
     def _semantic_search(self, index, query: str, top_k: int, embedding_model) -> Tuple[List[int], List[float]]:
         if not query:
             return [], []
+
+        top_k = min(top_k, index.ntotal)
+
         q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
         faiss.normalize_L2(q_emb)
 
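Clamping top_k to index.ntotal avoids asking FAISS for more neighbours than the index holds; otherwise FAISS pads the result with -1 ids that would need filtering downstream. A minimal sketch with a tiny inner-product index (dimension and data are arbitrary):

import numpy as np
import faiss

d = 8
xb = np.random.rand(3, d).astype("float32")   # only 3 vectors in the index
faiss.normalize_L2(xb)
index = faiss.IndexFlatIP(d)
index.add(xb)

q = np.random.rand(1, d).astype("float32")
faiss.normalize_L2(q)

top_k = min(10, index.ntotal)                 # clamp 10 -> 3, so no -1 padding
scores, ids = index.search(q, top_k)
print(ids)                                    # three valid row ids, no -1 entries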
@@ -528,16 +541,16 @@ app = FastAPI()
 rag_pipeline = None
 
 class AppConfig:
-    CHANNEL_ACCESS_TOKEN = os.getenv("
-    CHANNEL_SECRET = os.getenv("
+    CHANNEL_ACCESS_TOKEN = os.getenv("LINE_CHANNEL_ACCESS_TOKEN")
+    CHANNEL_SECRET = os.getenv("LINE_CHANNEL_SECRET")
 
 @app.on_event("startup")
 async def startup_event():
     """應用程式啟動時執行的任務"""
     log.info("===== Application Startup =====")
     missing = []
-    if not AppConfig.CHANNEL_ACCESS_TOKEN: missing.append("
-    if not AppConfig.CHANNEL_SECRET: missing.append("
+    if not AppConfig.CHANNEL_ACCESS_TOKEN: missing.append("LINE_CHANNEL_ACCESS_TOKEN")
+    if not AppConfig.CHANNEL_SECRET: missing.append("LINE_CHANNEL_SECRET")
     if not LLM_API_CONFIG.get("api_key"): missing.append("LITELLM_API_KEY")
     if not LLM_API_CONFIG.get("base_url"): missing.append("LITELLM_BASE_URL")
     if not LLM_API_CONFIG.get("model"): missing.append("LM_MODEL")
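After this change the startup check expects LINE_CHANNEL_ACCESS_TOKEN, LINE_CHANNEL_SECRET, LITELLM_API_KEY, LITELLM_BASE_URL and LM_MODEL to be set in the environment (on a Space, typically as repository secrets). A standalone sketch, not part of app.py, for sanity-checking a local environment before launch with the same logic:

import os

REQUIRED = [
    "LINE_CHANNEL_ACCESS_TOKEN",
    "LINE_CHANNEL_SECRET",
    "LITELLM_API_KEY",
    "LITELLM_BASE_URL",
    "LM_MODEL",
]

missing = [name for name in REQUIRED if not os.getenv(name)]
if missing:
    raise SystemExit(f"Missing environment variables: {', '.join(missing)}")
print("All required environment variables are set.")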
@@ -556,8 +569,18 @@ async def health_check():
     """健康檢查端點,用於 Docker HEALTHCHECK"""
     return {"status": "ok"}
 
+def process_and_reply(reply_token: str, user_text: str):
+    """將耗時的 RAG 處理和回覆任務移至後台執行"""
+    try:
+        answer = rag_pipeline.answer_question(user_text)
+    except Exception as e:
+        log.error(f"後台處理錯誤: {e}", exc_info=True)
+        answer = "處理時發生錯誤,請稍後再試。"
+    line_reply(reply_token, answer)
+
+
 @app.post("/webhook")
-async def handle_webhook(request: Request, response: Response):
+async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
     """處理 LINE Message API 的 Webhook 請求"""
     signature = request.headers.get("X-Line-Signature")
     if not signature:
@@ -580,16 +603,11 @@ async def handle_webhook(request: Request, response: Response):
         if event.get("type") == "message" and event.get("message", {}).get("type") == "text":
             reply_token = event.get("replyToken")
             user_text = event.get("message", {}).get("text", "").strip()
-
-            if not user_text: continue
 
-
-
-
-            if reply_token:
-                line_reply(reply_token, answer)
+            if reply_token and user_text:
+                background_tasks.add_task(process_and_reply, reply_token, user_text)
 
-    return {"status": "ok"}
+    return {"status": "ok"}  # Immediately return 200 OK
 
 def line_reply(reply_token: str, text: str):
     """透過 LINE Message API 回覆訊息,並進行分塊以避免長度限制"""
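The pairing of process_and_reply with BackgroundTasks is the usual FastAPI pattern for webhooks with a tight acknowledgement deadline: the handler schedules the slow RAG work and returns 200 immediately, and the task runs after the response is sent (sync callables are executed in a threadpool, which is likely why the explicit run_in_threadpool import was dropped). A self-contained toy version of the pattern, with the RAG call replaced by a sleep:

import time
from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

def slow_reply(token: str, text: str) -> None:
    time.sleep(5)                      # stand-in for rag_pipeline.answer_question + line_reply
    print(f"replied to {token}: {text!r}")

@app.post("/webhook")
async def webhook(background_tasks: BackgroundTasks):
    # Schedule the heavy work and acknowledge right away; FastAPI runs the
    # task after the response has been sent, using a threadpool for sync callables.
    background_tasks.add_task(slow_reply, "reply-token", "hello")
    return {"status": "ok"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)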
@@ -597,9 +615,9 @@ def line_reply(reply_token: str, text: str):
         "Content-Type": "application/json",
         "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}"
     }
-    # LINE 文本長度上限約為 5000
-    chunks = textwrap.wrap(text,
-    messages = [{"type": "text", "text": c} for c in chunks] or [{"type": "text", "text": text[:
+    # LINE 文本長度上限約為 5000 字元,且回覆訊息數上限為 5 則
+    chunks = textwrap.wrap(text, 4000)
+    messages = [{"type": "text", "text": c} for c in chunks[:5]] or [{"type": "text", "text": text[:4000]}]
     data = {"replyToken": reply_token, "messages": messages}
     try:
         r = requests.post("https://api.line.me/v2/bot/message/reply", headers=headers, json=data, timeout=10)
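The reply changes encode two LINE Messaging API limits: at most 5 message objects per reply and at most 5000 characters per text message, with wrapping at 4000 leaving headroom. A small sketch of the chunking behaviour on its own (the sample text is arbitrary):

import textwrap

def to_line_messages(text: str) -> list[dict]:
    # Wrap to 4000 chars per chunk, keep at most 5 messages per reply,
    # and fall back to a single truncated message if wrapping yields nothing.
    chunks = textwrap.wrap(text, 4000)
    return [{"type": "text", "text": c} for c in chunks[:5]] or [{"type": "text", "text": text[:4000]}]

long_answer = "用藥說明 " * 3000          # roughly 15000 characters
msgs = to_line_messages(long_answer)
print(len(msgs), [len(m["text"]) for m in msgs])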