Spaces:
Sleeping
Sleeping
Song
commited on
Commit
·
4cc218a
1
Parent(s):
7c4588b
hi
Browse files
app.py
CHANGED
@@ -28,6 +28,7 @@ import time
|
|
28 |
from typing import List, Dict, Any, Optional, Tuple, Union
|
29 |
from functools import lru_cache
|
30 |
from dataclasses import dataclass, field
|
|
|
31 |
|
32 |
# ---------- 第三方函式庫 ----------
|
33 |
import numpy as np
|
@@ -163,9 +164,8 @@ class RerankResult:
|
|
163 |
|
164 |
# ---------- 核心 RAG 邏輯 ----------
|
165 |
class RagPipeline:
|
166 |
-
def __init__(self
|
167 |
-
|
168 |
-
self.state = type('state', (), {})()
|
169 |
if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
|
170 |
raise ValueError("LLM API Key or Base URL is not configured.")
|
171 |
self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
|
@@ -174,6 +174,7 @@ class RagPipeline:
|
|
174 |
|
175 |
self.drug_name_to_ids: Dict[str, List[str]] = {}
|
176 |
self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
|
|
|
177 |
|
178 |
def _load_model(self, model_class, model_name: str, model_type: str):
|
179 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -182,10 +183,19 @@ class RagPipeline:
|
|
182 |
return model_class(model_name, device=device)
|
183 |
except Exception as e:
|
184 |
log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
|
185 |
-
|
|
|
|
|
|
|
|
|
186 |
|
187 |
def load_data(self):
|
188 |
log.info("開始載入資料與模型...")
|
|
|
|
|
|
|
|
|
|
|
189 |
try:
|
190 |
self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna('')
|
191 |
# [MODIFIED] 增加必要欄位檢查
|
@@ -236,22 +246,28 @@ class RagPipeline:
|
|
236 |
for part in parts:
|
237 |
if re.search(r'[\u4e00-\u9fff]', part):
|
238 |
self.drug_vocab["zh"].add(part)
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
|
|
|
|
243 |
else:
|
244 |
self.drug_vocab["en"].add(part)
|
245 |
for alias in DRUG_NAME_MAPPING:
|
246 |
self.drug_vocab["en"].add(alias.lower())
|
247 |
if re.search(r'[\u4e00-\u9fff]', alias):
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
|
|
252 |
|
253 |
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
254 |
def _llm_call(self, messages, **kwargs) -> str:
|
|
|
|
|
|
|
255 |
try:
|
256 |
config = {**LLM_MODEL_CONFIG, **kwargs}
|
257 |
response = self.llm_client.chat.completions.create(
|
@@ -260,15 +276,25 @@ class RagPipeline:
|
|
260 |
temperature=config["temperature"],
|
261 |
max_tokens=config["max_tokens"],
|
262 |
)
|
263 |
-
|
264 |
-
# [MODIFIED]
|
|
|
|
|
|
|
|
|
|
|
265 |
if not isinstance(content, str) or not content.strip():
|
|
|
266 |
raise ValueError("LLM response content is empty or not a string.")
|
|
|
|
|
|
|
267 |
return content
|
268 |
except Exception as e:
|
269 |
-
log.error(f"LLM API 呼叫失敗: {e}")
|
270 |
raise
|
271 |
|
|
|
272 |
def answer_question(self, q_orig: str) -> str:
|
273 |
start_time = time.time()
|
274 |
log.info(f"===== 處理新查詢: '{q_orig}' =====")
|
@@ -277,17 +303,32 @@ class RagPipeline:
|
|
277 |
if not drug_ids:
|
278 |
log.info("找不到藥品 ID,無法回答。")
|
279 |
return f"抱歉,資料庫中找不到該藥品。請確認藥品名稱,或直接諮詢醫師/藥師。{DISCLAIMER}"
|
280 |
-
log.info(f"步驟 1/5: 找到藥品 ID: {drug_ids}")
|
|
|
281 |
|
282 |
analysis = self._analyze_query(q_orig)
|
283 |
sub_queries, intents = analysis.get("sub_queries", [q_orig]), analysis.get("intents", [])
|
284 |
-
|
285 |
-
|
|
|
|
|
286 |
all_candidates = self._retrieve_candidates_for_all_queries(drug_ids, sub_queries, intents)
|
287 |
-
log.info(f"步驟 3/5: 檢索完成。所有子查詢共找到 {len(all_candidates)} 個不重複候選 chunks
|
288 |
-
|
289 |
-
|
290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
|
292 |
context = self._build_context(reranked_results)
|
293 |
if not context:
|
@@ -298,7 +339,7 @@ class RagPipeline:
|
|
298 |
answer = self._llm_call([{"role": "user", "content": prompt}])
|
299 |
|
300 |
final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
|
301 |
-
log.info(f"步驟 5/5: 答案生成完成。答案長度: {len(answer.strip())}
|
302 |
log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====")
|
303 |
return final_answer
|
304 |
|
@@ -306,6 +347,10 @@ class RagPipeline:
|
|
306 |
log.error(f"處理查詢 '{q_orig}' 時發生嚴重錯誤: {e}", exc_info=True)
|
307 |
return f"處理您的問題時發生內部錯誤,請稍後再試。{DISCLAIMER}"
|
308 |
|
|
|
|
|
|
|
|
|
309 |
@lru_cache(maxsize=128)
|
310 |
def _find_drug_ids_from_name(self, query: str) -> List[str]:
|
311 |
q = query.lower()
|
@@ -420,12 +465,13 @@ class RagPipeline:
|
|
420 |
try:
|
421 |
expanded_query = self._llm_call([{"role": "user", "content": prompt}])
|
422 |
if expanded_query and expanded_query.strip():
|
|
|
423 |
return expanded_query
|
424 |
else:
|
425 |
-
log.warning(f"
|
426 |
return query
|
427 |
except Exception as e:
|
428 |
-
log.error(f"
|
429 |
return query
|
430 |
|
431 |
def _rerank_with_crossencoder(self, query: str, candidates: List[FusedCandidate]) -> List[RerankResult]:
|
@@ -460,31 +506,41 @@ class RagPipeline:
|
|
460 |
|
461 |
# [MODIFIED] 增強 JSON 解析的穩健性,從字串中提取 JSON 物件
|
462 |
def _safe_json_parse(self, s: str, default: Any = None) -> Any:
|
463 |
-
m = re.search(r'\{.*?\}', s, re.DOTALL) # 非貪婪
|
464 |
-
if m:
|
465 |
-
s = m.group(0)
|
466 |
try:
|
|
|
467 |
return json.loads(s)
|
468 |
-
except
|
469 |
-
log.warning(f"
|
470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
|
472 |
# ---------- FastAPI 事件與路由 ----------
|
473 |
-
app = FastAPI()
|
474 |
-
rag_pipeline: Optional[RagPipeline] = None
|
475 |
-
|
476 |
# [MODIFIED] 將 LINE 配置集中管理並進行啟動時檢查
|
477 |
class AppConfig:
|
478 |
CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
|
479 |
CHANNEL_SECRET = _require_env("CHANNEL_SECRET")
|
480 |
|
481 |
-
|
482 |
-
|
483 |
-
|
|
|
|
|
484 |
_require_llm_config()
|
485 |
-
rag_pipeline
|
|
|
486 |
rag_pipeline.load_data()
|
487 |
log.info("啟動完成,服務準備就緒。")
|
|
|
|
|
|
|
|
|
|
|
488 |
|
489 |
@app.post("/webhook")
|
490 |
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|
|
|
28 |
from typing import List, Dict, Any, Optional, Tuple, Union
|
29 |
from functools import lru_cache
|
30 |
from dataclasses import dataclass, field
|
31 |
+
from contextlib import asynccontextmanager
|
32 |
|
33 |
# ---------- 第三方函式庫 ----------
|
34 |
import numpy as np
|
|
|
164 |
|
165 |
# ---------- 核心 RAG 邏輯 ----------
|
166 |
class RagPipeline:
|
167 |
+
def __init__(self):
|
168 |
+
# [MODIFIED] 不再傳入 AppConfig,直接引用
|
|
|
169 |
if not LLM_API_CONFIG["api_key"] or not LLM_API_CONFIG["base_url"]:
|
170 |
raise ValueError("LLM API Key or Base URL is not configured.")
|
171 |
self.llm_client = OpenAI(api_key=LLM_API_CONFIG["api_key"], base_url=LLM_API_CONFIG["base_url"])
|
|
|
174 |
|
175 |
self.drug_name_to_ids: Dict[str, List[str]] = {}
|
176 |
self.drug_vocab: Dict[str, set] = {"zh": set(), "en": set()}
|
177 |
+
self.state = type('state', (), {})()
|
178 |
|
179 |
def _load_model(self, model_class, model_name: str, model_type: str):
|
180 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
183 |
return model_class(model_name, device=device)
|
184 |
except Exception as e:
|
185 |
log.warning(f"載入模型至 {device} 失敗: {e}。嘗試切換至 CPU。")
|
186 |
+
try:
|
187 |
+
return model_class(model_name, device="cpu")
|
188 |
+
except Exception as e_cpu:
|
189 |
+
log.error(f"切換至 CPU 仍無法載入模型: {model_name}。請確認模型路徑或網路連線。錯誤訊息: {e_cpu}")
|
190 |
+
raise RuntimeError(f"模型載入失敗: {model_name}")
|
191 |
|
192 |
def load_data(self):
|
193 |
log.info("開始載入資料與模型...")
|
194 |
+
# [MODIFIED] 增加檔案存在性檢查
|
195 |
+
for path in [CSV_PATH, FAISS_INDEX, SENTENCES_PKL, BM25_PKL]:
|
196 |
+
if not pathlib.Path(path).exists():
|
197 |
+
raise FileNotFoundError(f"必要的資料檔案不存在: {path}")
|
198 |
+
|
199 |
try:
|
200 |
self.df_csv = pd.read_csv(CSV_PATH, dtype=str).fillna('')
|
201 |
# [MODIFIED] 增加必要欄位檢查
|
|
|
246 |
for part in parts:
|
247 |
if re.search(r'[\u4e00-\u9fff]', part):
|
248 |
self.drug_vocab["zh"].add(part)
|
249 |
+
# [MODIFIED] 檢查詞彙是否已存在
|
250 |
+
if part not in jieba.dt.FREQ:
|
251 |
+
try:
|
252 |
+
jieba.add_word(part, freq=2_000_000)
|
253 |
+
except Exception:
|
254 |
+
pass
|
255 |
else:
|
256 |
self.drug_vocab["en"].add(part)
|
257 |
for alias in DRUG_NAME_MAPPING:
|
258 |
self.drug_vocab["en"].add(alias.lower())
|
259 |
if re.search(r'[\u4e00-\u9fff]', alias):
|
260 |
+
if alias not in jieba.dt.FREQ:
|
261 |
+
try:
|
262 |
+
jieba.add_word(alias, freq=2_000_000)
|
263 |
+
except Exception:
|
264 |
+
pass
|
265 |
|
266 |
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
|
267 |
def _llm_call(self, messages, **kwargs) -> str:
|
268 |
+
start_time = time.time()
|
269 |
+
log.info(f"LLM 呼叫開始. 模型: {LLM_API_CONFIG['model']}, max_tokens: {kwargs.get('max_tokens', 'N/A')}, temperature: {kwargs.get('temperature', 'N/A')}")
|
270 |
+
|
271 |
try:
|
272 |
config = {**LLM_MODEL_CONFIG, **kwargs}
|
273 |
response = self.llm_client.chat.completions.create(
|
|
|
276 |
temperature=config["temperature"],
|
277 |
max_tokens=config["max_tokens"],
|
278 |
)
|
279 |
+
|
280 |
+
# [MODIFIED] 檢查回應結構並使用 getattr 安全地獲取內容
|
281 |
+
if not response or not response.choices or not response.choices[0].message:
|
282 |
+
log.error(f"LLM 呼叫成功 (200 OK),但回傳的 JSON 結構不完整。回傳: {response.model_dump_json() if response else 'None'}")
|
283 |
+
raise ValueError("LLM response content is empty or not a string.")
|
284 |
+
|
285 |
+
content = getattr(response.choices[0].message, "content", None)
|
286 |
if not isinstance(content, str) or not content.strip():
|
287 |
+
log.error(f"LLM 呼叫成功 (200 OK),但回傳內容為空。Response: {content}")
|
288 |
raise ValueError("LLM response content is empty or not a string.")
|
289 |
+
|
290 |
+
elapsed = time.time() - start_time
|
291 |
+
log.info(f"LLM 呼叫完成,耗時: {elapsed:.2f} 秒。��容長度: {len(content)} 字。")
|
292 |
return content
|
293 |
except Exception as e:
|
294 |
+
log.error(f"LLM API 呼叫失敗: {e}", exc_info=True)
|
295 |
raise
|
296 |
|
297 |
+
# [MODIFIED] 實現動態流程,根據查詢複雜度決定是否使用 Reranker
|
298 |
def answer_question(self, q_orig: str) -> str:
|
299 |
start_time = time.time()
|
300 |
log.info(f"===== 處理新查詢: '{q_orig}' =====")
|
|
|
303 |
if not drug_ids:
|
304 |
log.info("找不到藥品 ID,無法回答。")
|
305 |
return f"抱歉,資料庫中找不到該藥品。請確認藥品名稱,或直接諮詢醫師/藥師。{DISCLAIMER}"
|
306 |
+
log.info(f"步驟 1/5: 找到藥品 ID: {drug_ids},耗時: {time.time() - start_time:.2f} 秒")
|
307 |
+
step_start = time.time()
|
308 |
|
309 |
analysis = self._analyze_query(q_orig)
|
310 |
sub_queries, intents = analysis.get("sub_queries", [q_orig]), analysis.get("intents", [])
|
311 |
+
is_simple_query = self._is_simple_query(sub_queries, intents)
|
312 |
+
log.info(f"步驟 2/5: 意圖分析完成。子問題: {sub_queries}, 意圖: {intents}。判定為簡單查詢: {is_simple_query}。耗時: {time.time() - step_start:.2f} 秒")
|
313 |
+
step_start = time.time()
|
314 |
+
|
315 |
all_candidates = self._retrieve_candidates_for_all_queries(drug_ids, sub_queries, intents)
|
316 |
+
log.info(f"步驟 3/5: 檢索完成。所有子查詢共找到 {len(all_candidates)} 個不重複候選 chunks。耗時: {time.time() - step_start:.2f} 秒")
|
317 |
+
step_start = time.time()
|
318 |
+
|
319 |
+
if is_simple_query:
|
320 |
+
log.info("偵測到簡單查詢,跳過 Reranker 步驟。")
|
321 |
+
final_candidates = all_candidates[:TOP_K_SENTENCES]
|
322 |
+
reranked_results = [
|
323 |
+
RerankResult(idx=c.idx, rerank_score=c.fused_score, text=self.state.sentences[c.idx], meta=self.state.meta[c.idx])
|
324 |
+
for c in final_candidates
|
325 |
+
]
|
326 |
+
else:
|
327 |
+
log.info("偵測到複雜查詢,執行 Reranker。")
|
328 |
+
reranked_results = self._rerank_with_crossencoder(q_orig, all_candidates)
|
329 |
+
|
330 |
+
log.info(f"步驟 4/5: 最終選出 {len(reranked_results)} 個高品質候選。耗時: {time.time() - step_start:.2f} 秒")
|
331 |
+
step_start = time.time()
|
332 |
|
333 |
context = self._build_context(reranked_results)
|
334 |
if not context:
|
|
|
339 |
answer = self._llm_call([{"role": "user", "content": prompt}])
|
340 |
|
341 |
final_answer = f"{answer.strip()}\n\n{DISCLAIMER}"
|
342 |
+
log.info(f"步驟 5/5: 答案生成完成。答案長度: {len(answer.strip())} 字。耗時: {time.time() - step_start:.2f} 秒")
|
343 |
log.info(f"===== 查詢處理完成,總耗時: {time.time() - start_time:.2f} 秒 =====")
|
344 |
return final_answer
|
345 |
|
|
|
347 |
log.error(f"處理查詢 '{q_orig}' 時發生嚴重錯誤: {e}", exc_info=True)
|
348 |
return f"處理您的問題時發生內部錯誤,請稍後再試。{DISCLAIMER}"
|
349 |
|
350 |
+
def _is_simple_query(self, sub_queries: List[str], intents: List[str]) -> bool:
|
351 |
+
# 如果意圖分析回傳的子查詢數量 <= 1,且意圖分類數量也 <= 1,則判定為簡單問題
|
352 |
+
return len(sub_queries) <= 1 and len(intents) <= 1
|
353 |
+
|
354 |
@lru_cache(maxsize=128)
|
355 |
def _find_drug_ids_from_name(self, query: str) -> List[str]:
|
356 |
q = query.lower()
|
|
|
465 |
try:
|
466 |
expanded_query = self._llm_call([{"role": "user", "content": prompt}])
|
467 |
if expanded_query and expanded_query.strip():
|
468 |
+
log.info(f"查詢擴展成功。原始: '{query}', 擴展後: '{expanded_query}'")
|
469 |
return expanded_query
|
470 |
else:
|
471 |
+
log.warning(f"查詢擴展回傳空內容。原始查詢: '{query}'。將使用原始查詢。")
|
472 |
return query
|
473 |
except Exception as e:
|
474 |
+
log.error(f"查詢擴展失敗: {e}。原始查詢: '{query}'。將使用原始查詢。")
|
475 |
return query
|
476 |
|
477 |
def _rerank_with_crossencoder(self, query: str, candidates: List[FusedCandidate]) -> List[RerankResult]:
|
|
|
506 |
|
507 |
# [MODIFIED] 增強 JSON 解析的穩健性,從字串中提取 JSON 物件
|
508 |
def _safe_json_parse(self, s: str, default: Any = None) -> Any:
|
|
|
|
|
|
|
509 |
try:
|
510 |
+
# 嘗試解析完整字串
|
511 |
return json.loads(s)
|
512 |
+
except json.JSONDecodeError:
|
513 |
+
log.warning(f"無法解析完整 JSON。嘗試從字串中提取: {s[:200]}...")
|
514 |
+
# 如果失敗,嘗試用 regex 提取第一個 JSON 物件
|
515 |
+
m = re.search(r'\{.*?\}', s, re.DOTALL)
|
516 |
+
if m:
|
517 |
+
try:
|
518 |
+
return json.loads(m.group(0))
|
519 |
+
except json.JSONDecodeError:
|
520 |
+
log.warning(f"提取的 JSON 仍無法解析: {m.group(0)[:100]}...")
|
521 |
+
return default
|
522 |
|
523 |
# ---------- FastAPI 事件與路由 ----------
|
|
|
|
|
|
|
524 |
# [MODIFIED] 將 LINE 配置集中管理並進行啟動時檢查
|
525 |
class AppConfig:
|
526 |
CHANNEL_ACCESS_TOKEN = _require_env("CHANNEL_ACCESS_TOKEN")
|
527 |
CHANNEL_SECRET = _require_env("CHANNEL_SECRET")
|
528 |
|
529 |
+
rag_pipeline: Optional[RagPipeline] = None
|
530 |
+
|
531 |
+
# [MODIFIED] 使用 lifespan context manager
|
532 |
+
@asynccontextmanager
|
533 |
+
async def lifespan(app: FastAPI):
|
534 |
_require_llm_config()
|
535 |
+
global rag_pipeline
|
536 |
+
rag_pipeline = RagPipeline()
|
537 |
rag_pipeline.load_data()
|
538 |
log.info("啟動完成,服務準備就緒。")
|
539 |
+
yield
|
540 |
+
# 若有資源需要關閉可在這裡實作
|
541 |
+
log.info("服務關閉中。")
|
542 |
+
|
543 |
+
app = FastAPI(lifespan=lifespan)
|
544 |
|
545 |
@app.post("/webhook")
|
546 |
async def handle_webhook(request: Request, background_tasks: BackgroundTasks):
|