Commit 0f8b8aa · Parent(s): da1563e
Song committed: hi
app.py CHANGED
@@ -25,6 +25,7 @@ import json
 from typing import List, Dict, Any, Optional, Tuple, Union
 from functools import lru_cache
 import time
+import textwrap
 
 # ---------- Third-party libraries ----------
 import numpy as np
@@ -39,6 +40,7 @@ import torch
 from openai import OpenAI
 from tenacity import retry, stop_after_attempt, wait_fixed
 import requests
+from starlette.concurrency import run_in_threadpool
 
 # ==== CONFIG (loaded from environment variables, with defaults) ====
 # Default paths point to the current directory, per the provided file list
@@ -46,7 +48,6 @@ CSV_PATH = os.getenv("CSV_PATH", "cleaned_combined.csv")
 FAISS_INDEX = os.getenv("FAISS_INDEX", "drug_sentences.index")
 SENTENCES_PKL = os.getenv("SENTENCES_PKL", "drug_sentences.pkl")
 BM25_PKL = os.getenv("BM25_PKL", "bm25.pkl")
-META_PKL = "/tmp/drug_meta.pkl"  # No longer needed, but kept to avoid load-time errors
 
 TOP_K_SENTENCES = int(os.getenv("TOP_K_SENTENCES", 30))
 PRE_RERANK_K = int(os.getenv("PRE_RERANK_K", 30))
@@ -124,7 +125,6 @@ class RagPipeline:
         self.embedding_model = self._load_embedding_model()
         self.reranker = self._load_reranker_model()
         self.csv_path = self._ensure_csv_path(CSV_PATH)
-        self.app = FastAPI()
 
     def _load_embedding_model(self):
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -156,18 +156,31 @@ class RagPipeline:
     def _load_data(self):
         """Load all required models and data at startup"""
         log.info("Starting to load data and models...")
-        # Load CSV
-        if os.path.exists(self.csv_path):
-            self.df_csv = pd.read_csv(self.csv_path, dtype=str).fillna('')
-            self.df_csv['drug_name_norm_normalized'] = self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
-            log.info(f"Loaded CSV: {self.csv_path} (rows={len(self.df_csv)})")
-        else:
-            log.error(f"Error: CSV file not found at {self.csv_path}")
-            self.df_csv = None
 
-        #
+        # Load CSV and check for required columns
+        if not os.path.exists(self.csv_path):
+            raise FileNotFoundError(f"CSV file not found at {self.csv_path}")
+
+        self.df_csv = pd.read_csv(self.csv_path, dtype=str).fillna('')
+        required_cols = {"drug_id", "drug_name_norm", "section"}
+        missing_cols = required_cols - set(self.df_csv.columns)
+        if missing_cols:
+            raise ValueError(f"CSV is missing required columns: {missing_cols}")
+
+        self.df_csv['drug_name_norm_normalized'] = (
+            self.df_csv['drug_name_norm'].str.lower().str.replace(r'[^\w\s]', '', regex=True).str.strip()
+        )
+        log.info(f"Loaded CSV: {self.csv_path} (rows={len(self.df_csv)})")
+
+        # Load corpus and index
         self.state.index, self.state.sentences, self.state.meta = self._load_or_build_sentence_index()
         self.state.bm25 = self._ensure_bm25_index()
+
+        # Check BM25 alignment
+        bm_n = len(self.state.bm25.corpus)
+        sent_n = len(self.state.sentences)
+        if bm_n != sent_n:
+            raise RuntimeError(f"BM25 doc count ({bm_n}) does not match sentences ({sent_n}); regenerate the index.")
         log.info("All models and data loaded.")
 
     def _load_or_build_sentence_index(self):
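The rewritten _load_data fails fast instead of continuing with df_csv = None. The sketch below isolates the two invariants it now enforces: the CSV must carry the columns the pipeline reads, and the BM25 corpus must line up one-to-one with the sentence list, because BM25 scores are positional. The column names follow this diff; everything else is illustrative.

import pandas as pd

REQUIRED_COLS = {"drug_id", "drug_name_norm", "section"}

def load_csv_strict(path: str) -> pd.DataFrame:
    # Reject a CSV that cannot feed the retrieval pipeline
    df = pd.read_csv(path, dtype=str).fillna("")
    missing = REQUIRED_COLS - set(df.columns)
    if missing:
        raise ValueError(f"CSV is missing required columns: {missing}")
    return df

def check_alignment(bm25_doc_count: int, sentence_count: int) -> None:
    # get_scores()[i] must refer to sentences[i], so lengths must match
    if bm25_doc_count != sentence_count:
        raise RuntimeError(f"BM25 docs ({bm25_doc_count}) != sentences ({sentence_count})")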
@@ -181,7 +194,7 @@ class RagPipeline:
             return index, sentences, meta
 
         log.info("Index files not found; rebuilding from CSV...")
-        #
+        # Index building should be done by a separate script, not here.
         raise RuntimeError("FAISS and sentence PKL files not found; run the index generation script first.")
 
     def _ensure_bm25_index(self):
@@ -190,11 +203,10 @@ class RagPipeline:
         try:
             with open(BM25_PKL, "rb") as f:
                 data = pickle.load(f)
-
-            bm25 = data.get("bm25")
+            bm25 = data.get("bm25") if isinstance(data, dict) else data
             if not hasattr(bm25, 'get_scores'):
                 raise ValueError("The loaded BM25 index is invalid.")
-            log.info(f"Loaded BM25 index with {len(
+            log.info(f"Loaded BM25 index with {len(bm25.corpus)} documents.")
             return bm25
         except Exception as e:
             log.error(f"Failed to load BM25 index ({e}); check the file format.")
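The loader above now accepts either a bare BM25 object or a {"bm25": ...} dict, and logs len(bm25.corpus). Below is a sketch of a matching generation script, assuming the rank_bm25 package; note that BM25Okapi does not retain the raw corpus itself, so the sketch attaches it explicitly to make the corpus-length checks work at load time. The two sentences are hypothetical.

import pickle
from rank_bm25 import BM25Okapi

sentences = ["example sentence one", "example sentence two"]  # hypothetical corpus
tokenized = [s.split() for s in sentences]

bm25 = BM25Okapi(tokenized)
bm25.corpus = tokenized  # attached so len(bm25.corpus) works in _load_data

with open("bm25.pkl", "wb") as f:
    pickle.dump({"bm25": bm25}, f)  # the dict form this loader expects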
@@ -211,14 +223,15 @@ class RagPipeline:
                 messages=messages,
                 temperature=temperature,
                 max_tokens=max_tokens,
+                stop=LLM_MODEL_CONFIG.get("stop_tokens") or None,
             )
             return response.choices[0].message.content
         except Exception as e:
             log.error(f"LLM API call failed: {e}")
             raise
 
-    async def answer_question(self, q_orig: str) -> str:
-        """Full pipeline for handling a user question"""
+    def answer_question(self, q_orig: str) -> str:
+        """Full pipeline for handling a user question (synchronous version)"""
         start_time = time.time()
         log.info(f"===== Handling new query: '{q_orig}' =====")
 
@@ -239,8 +252,7 @@ class RagPipeline:
         log.info(f"Analysis result - intents: {intents}")
 
         all_reranked_results = []
-
-
+
         log.info("Step 3/5: retrieval and reranking...")
         relevant_indices = {i for i, m in enumerate(self.state.meta) if m.get("drug_id") in drug_ids}
         if not relevant_indices:
@@ -275,26 +287,41 @@ class RagPipeline:
                 else:
                     candidate_dict[i] = {"sem": 0.0, "bm": bm_score}
 
-
+            candidates_list = []
             for i, scores in candidate_dict.items():
+                candidates_list.append((i, scores["sem"], scores["bm"]))
+
+            if not candidates_list:
+                continue
+
+            # Normalize scores
+            sem_vals = np.array([s for _, s, _ in candidates_list], dtype=np.float32)
+            bm_vals = np.array([b for _, _, b in candidates_list], dtype=np.float32)
+
+            def norm(x):
+                rng = x.max() - x.min()
+                return (x - x.min()) / (rng + 1e-8)
+
+            sem_n = norm(sem_vals)
+            bm_n = norm(bm_vals)
+
+            fused_candidates = []
+            for idx, (i, s_raw, b_raw) in enumerate(candidates_list):
                 section_name = self.state.meta[i].get("section", "其他")
                 section_weight = weights.get(section_name, 1.0)
-                fused_score = (
-                    scores["sem"] * 0.5
-                    + scores["bm"] * 0.4
-                ) * section_weight
+                fused_score = (sem_n[idx] * 0.5 + bm_n[idx] * 0.4) * section_weight
+                fused_candidates.append((i, fused_score, s_raw, b_raw))
+
+            fused_candidates.sort(key=lambda x: x[1], reverse=True)
 
-            sub_reranked = self._rerank_with_crossencoder(q_orig,
+            sub_reranked = self._rerank_with_crossencoder(q_orig, fused_candidates, self.state.sentences, self.reranker, TOP_K_SENTENCES, self.state.meta, RERANK_THRESHOLD)
 
+            # De-duplicate by sentence index
+            processed_indices = {res['idx'] for res in all_reranked_results}
             for r in sub_reranked:
-                if r.get('chunk_id') and r['chunk_id'] in processed_chunk_ids:
-                    continue
-                elif not r.get('chunk_id') and r['idx'] in {res['idx'] for res in all_reranked_results}:
+                if r['idx'] in processed_indices:
                     continue
-
                 all_reranked_results.append(r)
-                if r.get('chunk_id'):
-                    processed_chunk_ids.add(r['chunk_id'])
 
         all_reranked_results.sort(key=lambda x: x['rerank_score'], reverse=True)
         log.info(f"Reranker selected {len(all_reranked_results)} final high-quality candidates.")
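The new fusion step min-max normalises each score channel before combining, so unbounded BM25 scores cannot drown out cosine similarities that live roughly in [0, 1]. A self-contained sketch of the same arithmetic with hypothetical inputs:

import numpy as np

def min_max(x: np.ndarray) -> np.ndarray:
    rng = x.max() - x.min()
    return (x - x.min()) / (rng + 1e-8)  # epsilon guards the all-equal case

sem = np.array([0.82, 0.40, 0.10], dtype=np.float32)     # cosine similarities
bm = np.array([12.0, 3.5, 0.0], dtype=np.float32)        # raw BM25 scores
section_w = np.array([1.2, 1.0, 0.8], dtype=np.float32)  # per-section weights

fused = (min_max(sem) * 0.5 + min_max(bm) * 0.4) * section_w
print(fused.argsort()[::-1])  # candidate order, best first

The 0.5 and 0.4 weights need not sum to 1: only the relative order of fused scores matters, since they pick candidates for the cross-encoder rather than serve as calibrated probabilities.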
@@ -358,12 +385,21 @@ class RagPipeline:
         return {"sub_queries": [query], "intents": []}
 
     def _find_drug_ids_from_name(self, query: str, df: pd.DataFrame) -> List[str]:
+        if df is None:
+            return []
+
         candidates = extract_drug_candidates_from_query(query)
         expanded = expand_aliases(candidates)
+
         drug_ids = set()
         for alias in expanded:
-            mask = df['drug_name_norm_normalized'].str.contains(alias.lower(), case=False, na=False)
-            drug_ids.update(df.loc[mask, 'drug_id'].dropna().unique().tolist())
+            try:
+                # Use regex=False for literal matching, which is safer
+                mask = df['drug_name_norm_normalized'].str.contains(alias.lower(), case=False, regex=False, na=False)
+                matches = df.loc[mask, 'drug_id'].dropna().unique().tolist()
+                drug_ids.update(matches)
+            except Exception as e:
+                log.warning(f"Failed to match '{alias}': {e}. Skipping this alias.")
         return list(drug_ids)
 
     def _expand_query_with_llm(self, query: str, intents: List[str]) -> str:
@@ -407,8 +443,23 @@ class RagPipeline:
             return [], []
         q_emb = embedding_model.encode([query], convert_to_numpy=True).astype("float32")
         faiss.normalize_L2(q_emb)
+
         distances, indices = index.search(q_emb, top_k)
-        return indices[0].tolist(), distances[0].tolist()
+
+        # Check the metric type so scores are consistently "higher is better"
+        metric = getattr(index, "metric_type", None)
+        try:
+            import faiss
+            METRIC_L2 = faiss.METRIC_L2
+        except Exception:
+            METRIC_L2 = 1
+
+        if metric == METRIC_L2:
+            scores = (-distances[0]).tolist()  # L2 distance is smaller for closer points
+        else:
+            scores = distances[0].tolist()  # Inner product (cosine) is larger for closer points
+
+        return indices[0].tolist(), scores
 
     def _rerank_with_crossencoder(self, query: str, candidates: List[Tuple], sentences: List[str], reranker, top_k: int, meta: List[Dict], threshold: float) -> List[Dict]:
         if not candidates:
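The added metric check keeps retrieval scores on one "higher is better" scale. A small sketch of why, using faiss directly: an inner-product index over L2-normalised vectors returns cosine similarity, while an L2 index returns a distance that must be negated. The dimensions and data here are arbitrary.

import numpy as np
import faiss

d = 8
xb = np.random.rand(16, d).astype("float32")
faiss.normalize_L2(xb)

ip = faiss.IndexFlatIP(d)  # metric_type == faiss.METRIC_INNER_PRODUCT
l2 = faiss.IndexFlatL2(d)  # metric_type == faiss.METRIC_L2
ip.add(xb)
l2.add(xb)

q = xb[:1]
sim, _ = ip.search(q, 3)   # usable as-is: higher means closer
dist, _ = l2.search(q, 3)  # negate so higher still means closer
print(sim[0], -dist[0])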
@@ -420,13 +471,14 @@ class RagPipeline:
         scores = reranker.predict(pairs)
 
         reranked = []
-        for (i, sem_score, bm_score), score in zip(limited_candidates, scores):
-            if score >= threshold:
+        for (i, fused_score, sem_score, bm_score), rerank_score in zip(limited_candidates, scores):
+            if rerank_score >= threshold:
                 reranked.append({
                     "idx": i,
-                    "rerank_score": score,
-                    "sem_score": sem_score,
-                    "bm_score": bm_score,
+                    "rerank_score": rerank_score,
+                    "fused_score": fused_score,
+                    "sem_score": sem_score,
+                    "bm_score": bm_score,
                     "meta": meta[i],
                     "text": sentences[i]
                 })
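The rerank threshold is applied to raw cross-encoder scores. A minimal sketch of that scoring pattern, assuming the sentence-transformers CrossEncoder API; the model name, texts, and threshold here are illustrative, not necessarily what this Space deploys.

from sentence_transformers import CrossEncoder

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
query = "aspirin dosage"
texts = ["Take one 100 mg tablet daily.", "Store below 25 degrees C."]

pairs = [(query, t) for t in texts]
scores = reranker.predict(pairs)  # one relevance score per (query, text) pair

threshold = 0.0
kept = [t for t, s in zip(texts, scores) if s >= threshold]
print(kept)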
@@ -474,6 +526,7 @@ class RagPipeline:
 # ---------- FastAPI events and routes ----------
 app = FastAPI()
 rag_pipeline = None
+
 class AppConfig:
     CHANNEL_ACCESS_TOKEN = os.getenv("LINE_CHANNEL_ACCESS_TOKEN")
     CHANNEL_SECRET = os.getenv("LINE_CHANNEL_SECRET")
@@ -482,12 +535,21 @@ class AppConfig:
 async def startup_event():
     """Tasks to run at application startup"""
     log.info("===== Application Startup =====")
+    missing = []
+    if not AppConfig.CHANNEL_ACCESS_TOKEN: missing.append("LINE_CHANNEL_ACCESS_TOKEN")
+    if not AppConfig.CHANNEL_SECRET: missing.append("LINE_CHANNEL_SECRET")
+    if not LLM_API_CONFIG.get("api_key"): missing.append("LITELLM_API_KEY")
+    if not LLM_API_CONFIG.get("base_url"): missing.append("LITELLM_BASE_URL")
+    if not LLM_API_CONFIG.get("model"): missing.append("LM_MODEL")
+    if missing:
+        log.error(f"Missing required environment variables: {missing}")
+        raise RuntimeError(f"Missing required environment variables: {missing}")
+
 global rag_pipeline
 rag_pipeline = RagPipeline(AppConfig)
 rag_pipeline._load_data()
-
-    if not AppConfig.CHANNEL_ACCESS_TOKEN or not AppConfig.CHANNEL_SECRET:
-        log.error("Error: LINE_CHANNEL_ACCESS_TOKEN or LINE_CHANNEL_SECRET is not set!")
+    log.info("Startup checks complete.")
+
 
 @app.get("/health", status_code=status.HTTP_200_OK)
 async def health_check():
@@ -503,12 +565,15 @@ async def handle_webhook(request: Request, response: Response):
 
     body = await request.body()
     try:
-        expected = base64.b64encode(hmac.new(AppConfig.CHANNEL_SECRET.encode("utf-8"), body, hashlib.sha256).digest()).decode()
-        if expected != signature:
+        digest = hmac.new(AppConfig.CHANNEL_SECRET.encode("utf-8"), body, hashlib.sha256).digest()
+        expected = base64.b64encode(digest).decode()
+        if not hmac.compare_digest(expected, signature):
             raise HTTPException(status_code=403, detail="Invalid signature")
+    except HTTPException:
+        raise
     except Exception as e:
         log.error(f"Signature verification failed: {e}")
-        raise HTTPException(status_code=403, detail="Invalid signature")
+        raise HTTPException(status_code=500, detail="Signature verification error")
 
     data = json.loads(body.decode('utf-8'))
     for event in data.get("events", []):
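The handler now verifies X-Line-Signature itself. Per LINE's documentation, the signature is the base64 encoding of an HMAC-SHA256 over the raw request body, keyed with the channel secret; compare_digest gives a constant-time comparison. A standalone sketch with hypothetical values:

import base64
import hashlib
import hmac

def verify_line_signature(channel_secret: str, body: bytes, signature: str) -> bool:
    digest = hmac.new(channel_secret.encode("utf-8"), body, hashlib.sha256).digest()
    expected = base64.b64encode(digest).decode()
    return hmac.compare_digest(expected, signature)

# Round-trip check with a made-up secret and body
sig = base64.b64encode(hmac.new(b"secret", b"{}", hashlib.sha256).digest()).decode()
assert verify_line_signature("secret", b"{}", sig)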
@@ -518,7 +583,8 @@ async def handle_webhook(request: Request, response: Response):
 
         if not user_text: continue
 
-        answer = await rag_pipeline.answer_question(user_text)
+        # Offload the heavy pipeline work to a thread pool
+        answer = await run_in_threadpool(rag_pipeline.answer_question, user_text)
 
         if reply_token:
             line_reply(reply_token, answer)
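Since answer_question is now synchronous, calling it directly from the async handler would block the event loop for the whole retrieve-rerank-generate cycle. run_in_threadpool moves that work onto Starlette's thread pool. A sketch of the pattern; slow_pipeline is a hypothetical stand-in for rag_pipeline.answer_question.

import time
from fastapi import FastAPI
from starlette.concurrency import run_in_threadpool

app = FastAPI()

def slow_pipeline(question: str) -> str:
    time.sleep(2)  # stands in for retrieval, reranking, and the LLM call
    return f"answer to {question!r}"

@app.get("/ask")
async def ask(q: str):
    answer = await run_in_threadpool(slow_pipeline, q)  # event loop stays free
    return {"answer": answer}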
@@ -526,21 +592,23 @@ async def handle_webhook(request: Request, response: Response):
     return {"status": "ok"}
 
 def line_reply(reply_token: str, text: str):
-    """Reply via the LINE Messaging API"""
+    """Reply via the LINE Messaging API, chunking the text to respect the length limit"""
     headers = {
         "Content-Type": "application/json",
         "Authorization": f"Bearer {AppConfig.CHANNEL_ACCESS_TOKEN}"
     }
-    data = {
-        "replyToken": reply_token,
-        "messages": [{"type": "text", "text": text}]
-    }
+    # A LINE text message is capped at roughly 5000 characters
+    chunks = textwrap.wrap(text, 4900)
+    messages = [{"type": "text", "text": c} for c in chunks] or [{"type": "text", "text": text[:4900]}]
+    data = {"replyToken": reply_token, "messages": messages}
     try:
-        requests.post("https://api.line.me/v2/bot/message/reply", headers=headers, json=data)
+        r = requests.post("https://api.line.me/v2/bot/message/reply", headers=headers, json=data, timeout=10)
+        if r.status_code >= 300:
+            log.error(f"LINE API reply failed: {r.status_code} {r.text}")
     except Exception as e:
         log.error(f"LINE API reply failed: {e}")
 
-# ---- Utility functions ----
+# ---- Additional utility functions ----
 def extract_drug_candidates_from_query(query: str) -> list:
     query = re.sub(r"[A-Za-z]+", lambda m: m.group(0).lower(), query)
     candidates = set()
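One behaviour worth knowing about the new chunking: textwrap.wrap splits on whitespace and collapses newlines, so formatting inside a long answer is not preserved, and it returns an empty list for whitespace-only input, which is what the "or" fallback covers. A small sketch with a hypothetical long answer:

import textwrap

text = "word " * 2000  # roughly 10,000 characters
chunks = textwrap.wrap(text, 4900)
messages = [{"type": "text", "text": c} for c in chunks]
print(len(chunks), max(len(c) for c in chunks))  # every chunk fits the cap

LINE also caps a single reply at five message objects, so extremely long answers may still need truncation on top of chunking.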
@@ -552,8 +620,11 @@ def extract_drug_candidates_from_query(query: str) -> list:
         clean_token = re.sub(r'[a-zA-Z0-9\s]+', '', token).strip()
         if clean_token and clean_token.lower() not in DRUG_STOPWORDS:
             candidates.add(clean_token)
-    if drug_part.strip():
-        candidates.add(drug_part.strip())
+
+    # Avoid adding the whole drug_part to prevent regex errors
+    # if drug_part.strip():
+    #     candidates.add(drug_part.strip())
+
     for query_name, dataset_name in DRUG_NAME_MAPPING.items():
         if query_name in query.lower():
             candidates.add(dataset_name)
@@ -575,5 +646,3 @@ def expand_aliases(candidates: list) -> list:
 if __name__ == "__main__":
     port = int(os.getenv("PORT", 7860))
     uvicorn.run(app, host="0.0.0.0", port=port)
-
-