Song committed on
Commit
4cc218a
·
1 Parent(s): 7c4588b
Files changed (1) hide show
  1. app.py +94 -38
app.py CHANGED
@@ -28,6 +28,7 @@ import time
28
  from typing import List, Dict, Any, Optional, Tuple, Union
29
  from functools import lru_cache
30
  from dataclasses import dataclass, field
 
31
 
32
  # ---------- 第三方函式庫 ----------
33
  import numpy as np
@@ -163,9 +164,8 @@ class RerankResult:
163
 
164
  # ---------- 核心 RAG 邏輯 ----------
165
  class RagPipeline:
166
- def __init__(self, config):
167
- self.config = config
168
- self.state = type('state', (), {})()
169
  if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
170
  raise ValueError("LLM API Key or Base URL is not configured.")
171
  self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
@@ -174,6 +174,7 @@ class RagPipeline:
174
 
175
  self.drug_name_to_ids: Dict[str, List[str]] = {}
176
  self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
 
177
 
178
  def _load_model(self, model_class, model_name: str, model_type: str):
179
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -182,10 +183,19 @@ class RagPipeline:
182
  return model_class(model_name, device=device)
183
  except Exception as e:
184
  log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
185
- return model_class(model_name, device="cpu")
 
 
 
 
186
 
187
  def load_data(self):
188
  log.info("開始載入資料與模型...")
 
 
 
 
 
189
  try:
190
  self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna('')
191
  # [MODIFIED] 增加必要欄位檢查
@@ -236,22 +246,28 @@ class RagPipeline:
236
  for part in parts:
237
  if re.search(r'[\u4e00-\u9fff]', part):
238
  self.drug_vocab["zh"].add(part)
239
- try:
240
- jieba.add_word(part, freq=2_000_000)
241
- except Exception:
242
- pass
 
 
243
  else:
244
  self.drug_vocab["en"].add(part)
245
  for alias in DRUG_NAME_MAPPING:
246
  self.drug_vocab["en"].add(alias.lower())
247
  if re.search(r'[\u4e00-\u9fff]', alias):
248
- try:
249
- jieba.add_word(alias, freq=2_000_000)
250
- except Exception:
251
- pass
 
252
 
253
  @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
254
  def _llm_call(self, messages, **kwargs) -> str:
 
 
 
255
  try:
256
  config = {**LLM_MODEL_CONFIG, **kwargs}
257
  response = self.llm_client.chat.completions.create(
@@ -260,15 +276,25 @@ class RagPipeline:
260
  temperature=config["temperature"],
261
  max_tokens=config["max_tokens"],
262
  )
263
- content = response.choices[0].message.content
264
- # [MODIFIED] 確保回傳值為非空字串
 
 
 
 
 
265
  if not isinstance(content, str) or not content.strip():
 
266
  raise ValueError("LLM response content is empty or not a string.")
 
 
 
267
  return content
268
  except Exception as e:
269
- log.error(f"LLM API 呼叫失敗: {e}")
270
  raise
271
 
 
272
  def answer_question(self, q_orig: str) -> str:
273
  start_time = time.time()
274
  log.info(f"===== 處理新查詢: '{q_orig}' =====")
@@ -277,17 +303,32 @@ class RagPipeline:
277
  if not drug_ids:
278
  log.info("找不到藥品 ID,無法回答。")
279
  return f"抱歉,資料庫中找不到該藥品。請確認藥品名稱,或直接諮詢醫師/藥師。{DISCLAIMER}"
280
- log.info(f"步驟 1/5: 找到藥品 ID: {drug_ids}")
 
281
 
282
  analysis = self._analyze_query(q_orig)
283
  sub_queries, intents = analysis.get("sub_queries", [q_orig]), analysis.get("intents", [])
284
- log.info(f"步驟 2/5: 意圖分析完成。子問題: {sub_queries}, 意圖: {intents}")
285
-
 
 
286
  all_candidates = self._retrieve_candidates_for_all_queries(drug_ids, sub_queries, intents)
287
- log.info(f"步驟 3/5: 檢索完成。所有子查詢共找到 {len(all_candidates)} 個不重複候選 chunks")
288
-
289
- reranked_results = self._rerank_with_crossencoder(q_orig, all_candidates)
290
- log.info(f"步驟 4/5: Reranker 最終選出 {len(reranked_results)} 個高品質候選。")
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  context = self._build_context(reranked_results)
293
  if not context:
@@ -298,7 +339,7 @@ class RagPipeline:
298
  answer = self._llm_call([{"role": "user", "content": prompt}])
299
 
300
  final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
301
- log.info(f"步驟 5/5: 答案生成完成。答案長度: {len(answer.strip())} 字。")
302
  log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====")
303
  return final_answer
304
 
@@ -306,6 +347,10 @@ class RagPipeline:
306
  log.error(f"處理查詢 '{q_orig}' 時發生嚴重錯誤: {e}", exc_info=True)
307
  return f"處理您的問題時發生內部錯誤,請稍後再試。{DISCLAIMER}"
308
 
 
 
 
 
309
  @lru_cache(maxsize=128)
310
  def _find_drug_ids_from_name(self, query: str) -> List[str]:
311
  q = query.lower()
@@ -420,12 +465,13 @@ class RagPipeline:
420
  try:
421
  expanded_query = self._llm_call([{"role": "user", "content": prompt}])
422
  if expanded_query and expanded_query.strip():
 
423
  return expanded_query
424
  else:
425
- log.warning(f"Query expansion for '{query}' returned an empty result. Using original query.")
426
  return query
427
  except Exception as e:
428
- log.error(f"Query expansion for '{query}' failed: {e}. Using original query.")
429
  return query
430
 
431
  def _rerank_with_crossencoder(self, query: str, candidates: List[FusedCandidate]) -> List[RerankResult]:
@@ -460,31 +506,41 @@ class RagPipeline:
460
 
461
  # [MODIFIED] 增強 JSON 解析的穩健性,從字串中提取 JSON 物件
462
  def _safe_json_parse(self, s: str, default: Any = None) -> Any:
463
- m = re.search(r'\{.*?\}', s, re.DOTALL) # 非貪婪
464
- if m:
465
- s = m.group(0)
466
  try:
 
467
  return json.loads(s)
468
- except Exception:
469
- log.warning(f"無法解析 LLM 回傳的 JSON: {s[:200]}...")
470
- return default
 
 
 
 
 
 
 
471
 
472
  # ---------- FastAPI 事件與路由 ----------
473
- app = FastAPI()
474
- rag_pipeline: Optional[RagPipeline] = None
475
-
476
  # [MODIFIED] 將 LINE 配置集中管理並進行啟動時檢查
477
  class AppConfig:
478
  CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
479
  CHANNEL_SECRET = _require_env("CHANNEL_SECRET")
480
 
481
- @app.on_event("startup")
482
- async def startup_event():
483
- global rag_pipeline
 
 
484
  _require_llm_config()
485
- rag_pipeline = RagPipeline(AppConfig)
 
486
  rag_pipeline.load_data()
487
  log.info("啟動完成,服務準備就緒。")
 
 
 
 
 
488
 
489
  @app.post("/webhook")
490
  async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
 
28
  from typing import List, Dict, Any, Optional, Tuple, Union
29
  from functools import lru_cache
30
  from dataclasses import dataclass, field
31
+ from contextlib import asynccontextmanager
32
 
33
  # ---------- 第三方函式庫 ----------
34
  import numpy as np
 
164
 
165
  # ---------- 核心 RAG 邏輯 ----------
166
  class RagPipeline:
167
+ def __init__(self):
168
+ # [MODIFIED] 不再傳入 AppConfig,直接引用
 
169
  if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
170
  raise ValueError("LLM API Key or Base URL is not configured.")
171
  self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
 
174
 
175
  self.drug_name_to_ids: Dict[str, List[str]] = {}
176
  self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
177
+ self.state = type('state', (), {})()
178
 
179
  def _load_model(self, model_class, model_name: str, model_type: str):
180
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
183
  return model_class(model_name, device=device)
184
  except Exception as e:
185
  log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
186
+ try:
187
+ return model_class(model_name, device="cpu")
188
+ except Exception as e_cpu:
189
+ log.error(f"切換至 CPU 仍無法載入模型: {model_name}。請確認模型路徑或網路連線。錯誤訊息: {e_cpu}")
190
+ raise RuntimeError(f"模型載入失敗: {model_name}")
191
 
192
  def load_data(self):
193
  log.info("開始載入資料與模型...")
194
+ # [MODIFIED] 增加檔案存在性檢查
195
+ for path in [CSV_PATH, FAISS_INDEX, SENTENCES_PKL, BM25_PKL]:
196
+ if not pathlib.Path(path).exists():
197
+ raise FileNotFoundError(f"必要的資料檔案不存在: {path}")
198
+
199
  try:
200
  self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna('')
201
  # [MODIFIED] 增加必要欄位檢查
 
246
  for part in parts:
247
  if re.search(r'[\u4e00-\u9fff]', part):
248
  self.drug_vocab["zh"].add(part)
249
+ # [MODIFIED] 檢查詞彙是否已存在
250
+ if part not in jieba.dt.FREQ:
251
+ try:
252
+ jieba.add_word(part, freq=2_000_000)
253
+ except Exception:
254
+ pass
255
  else:
256
  self.drug_vocab["en"].add(part)
257
  for alias in DRUG_NAME_MAPPING:
258
  self.drug_vocab["en"].add(alias.lower())
259
  if re.search(r'[\u4e00-\u9fff]', alias):
260
+ if alias not in jieba.dt.FREQ:
261
+ try:
262
+ jieba.add_word(alias, freq=2_000_000)
263
+ except Exception:
264
+ pass
265
 
266
  @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
267
  def _llm_call(self, messages, **kwargs) -> str:
268
+ start_time = time.time()
269
+ log.info(f"LLM 呼叫開始. 模型: {LLM_API_CONFIG['model']}, max_tokens: {kwargs.get('max_tokens', 'N/A')}, temperature: {kwargs.get('temperature', 'N/A')}")
270
+
271
  try:
272
  config = {**LLM_MODEL_CONFIG, **kwargs}
273
  response = self.llm_client.chat.completions.create(
 
276
  temperature=config["temperature"],
277
  max_tokens=config["max_tokens"],
278
  )
279
+
280
+ # [MODIFIED] 檢查回應結構並使用 getattr 安全地獲取內容
281
+ if not response or not response.choices or not response.choices[0].message:
282
+ log.error(f"LLM 呼叫成功 (200 OK),但回傳的 JSON 結構不完整。回傳: {response.model_dump_json() if response else 'None'}")
283
+ raise ValueError("LLM response content is empty or not a string.")
284
+
285
+ content = getattr(response.choices[0].message, "content", None)
286
  if not isinstance(content, str) or not content.strip():
287
+ log.error(f"LLM 呼叫成功 (200 OK),但回傳內容為空。Response: {content}")
288
  raise ValueError("LLM response content is empty or not a string.")
289
+
290
+ elapsed = time.time() - start_time
291
+ log.info(f"LLM 呼叫完成,耗時: {elapsed:.2f} 秒。內容長度: {len(content)} 字。")
292
  return content
293
  except Exception as e:
294
+ log.error(f"LLM API 呼叫失敗: {e}", exc_info=True)
295
  raise
296
 
297
+ # [MODIFIED] 實現動態流程,根據查詢複雜度決定是否使用 Reranker
298
  def answer_question(self, q_orig: str) -> str:
299
  start_time = time.time()
300
  log.info(f"===== 處理新查詢: '{q_orig}' =====")
 
303
  if not drug_ids:
304
  log.info("找不到藥品 ID,無法回答。")
305
  return f"抱歉,資料庫中找不到該藥品。請確認藥品名稱,或直接諮詢醫師/藥師。{DISCLAIMER}"
306
+ log.info(f"步驟 1/5: 找到藥品 ID: {drug_ids},耗時: {time.time() - start_time:.2f} 秒")
307
+ step_start = time.time()
308
 
309
  analysis = self._analyze_query(q_orig)
310
  sub_queries, intents = analysis.get("sub_queries", [q_orig]), analysis.get("intents", [])
311
+ is_simple_query = self._is_simple_query(sub_queries, intents)
312
+ log.info(f"步驟 2/5: 意圖分析完成。子問題: {sub_queries}, 意圖: {intents}。判定為簡單查詢: {is_simple_query}。耗時: {time.time() - step_start:.2f} 秒")
313
+ step_start = time.time()
314
+
315
  all_candidates = self._retrieve_candidates_for_all_queries(drug_ids, sub_queries, intents)
316
+ log.info(f"步驟 3/5: 檢索完成。所有子查詢共找到 {len(all_candidates)} 個不重複候選 chunks。耗時: {time.time() - step_start:.2f} 秒")
317
+ step_start = time.time()
318
+
319
+ if is_simple_query:
320
+ log.info("偵測到簡單查詢,跳過 Reranker 步驟。")
321
+ final_candidates = all_candidates[:TOP_K_SENTENCES]
322
+ reranked_results = [
323
+ RerankResult(idx=c.idx, rerank_score=c.fused_score, text=self.state.sentences[c.idx], meta=self.state.meta[c.idx])
324
+ for c in final_candidates
325
+ ]
326
+ else:
327
+ log.info("偵測到複雜查詢,執行 Reranker。")
328
+ reranked_results = self._rerank_with_crossencoder(q_orig, all_candidates)
329
+
330
+ log.info(f"步驟 4/5: 最終選出 {len(reranked_results)} 個高品質候選。耗時: {time.time() - step_start:.2f} 秒")
331
+ step_start = time.time()
332
 
333
  context = self._build_context(reranked_results)
334
  if not context:
 
339
  answer = self._llm_call([{"role": "user", "content": prompt}])
340
 
341
  final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
342
+ log.info(f"步驟 5/5: 答案生成完成。答案長度: {len(answer.strip())} 字。耗時: {time.time() - step_start:.2f} 秒")
343
  log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====")
344
  return final_answer
345
 
 
347
  log.error(f"處理查詢 '{q_orig}' 時發生嚴重錯誤: {e}", exc_info=True)
348
  return f"處理您的問題時發生內部錯誤,請稍後再試。{DISCLAIMER}"
349
 
350
+ def _is_simple_query(self, sub_queries: List[str], intents: List[str]) -> bool:
351
+ # 如果意圖分析回傳的子查詢數量 <= 1,且意圖分類數量也 <= 1,則判定為簡單問題
352
+ return len(sub_queries) <= 1 and len(intents) <= 1
353
+
354
  @lru_cache(maxsize=128)
355
  def _find_drug_ids_from_name(self, query: str) -> List[str]:
356
  q = query.lower()
 
465
  try:
466
  expanded_query = self._llm_call([{"role": "user", "content": prompt}])
467
  if expanded_query and expanded_query.strip():
468
+ log.info(f"查詢擴展成功。原始: '{query}', 擴展後: '{expanded_query}'")
469
  return expanded_query
470
  else:
471
+ log.warning(f"查詢擴展回傳空內容。原始查詢: '{query}'。將使用原始查詢。")
472
  return query
473
  except Exception as e:
474
+ log.error(f"查詢擴展失敗: {e}。原始查詢: '{query}'。將使用原始查詢。")
475
  return query
476
 
477
  def _rerank_with_crossencoder(self, query: str, candidates: List[FusedCandidate]) -> List[RerankResult]:
 
506
 
507
  # [MODIFIED] 增強 JSON 解析的穩健性,從字串中提取 JSON 物件
508
  def _safe_json_parse(self, s: str, default: Any = None) -> Any:
 
 
 
509
  try:
510
+ # 嘗試解析完整字串
511
  return json.loads(s)
512
+ except json.JSONDecodeError:
513
+ log.warning(f"無法解析完整 JSON。嘗試從字串中提取: {s[:200]}...")
514
+ # 如果失敗,嘗試用 regex 提取第一個 JSON 物件
515
+ m = re.search(r'\{.*?\}', s, re.DOTALL)
516
+ if m:
517
+ try:
518
+ return json.loads(m.group(0))
519
+ except json.JSONDecodeError:
520
+ log.warning(f"提取的 JSON 仍無法解析: {m.group(0)[:100]}...")
521
+ return default
522
 
523
  # ---------- FastAPI 事件與路由 ----------
 
 
 
524
  # [MODIFIED] 將 LINE 配置集中管理並進行啟動時檢查
525
  class AppConfig:
      """Central holder for the LINE channel credentials.

      Both values are resolved through ``_require_env`` while the class
      body executes, i.e. when this module is imported.
      NOTE(review): presumably ``_require_env`` fails fast when a variable
      is missing; confirm against its definition.
      """
      CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
      CHANNEL_SECRET = _require_env("CHANNEL_SECRET")
528
 
529
+ rag_pipeline: Optional[RagPipeline] = None
530
+
531
+ # [MODIFIED] 使用 lifespan context manager
532
+ @asynccontextmanager
533
+ async def lifespan(app: FastAPI):
534
  _require_llm_config()
535
+ global rag_pipeline
536
+ rag_pipeline = RagPipeline()
537
  rag_pipeline.load_data()
538
  log.info("啟動完成,服務準備就緒。")
539
+ yield
540
+ # 若有資源需要關閉可在這裡實作
541
+ log.info("服務關閉中。")
542
+
543
+ app = FastAPI(lifespan=lifespan)
544
 
545
  @app.post("/webhook")
546
  async def handle_webhook(request: Request, background_tasks: BackgroundTasks):