Spaces:
Running
Running
import os | |
import shutil | |
# --- Cache + Env setup --- | |
os.environ["HF_HOME"] = "/tmp/hf_home" | |
os.environ["HF_HUB_CACHE"] = "/tmp/hf_cache" | |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache" | |
os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_datasets" | |
os.environ["XDG_CACHE_HOME"] = "/tmp/.cache" | |
os.environ["HOME"] = "/tmp" | |
os.makedirs("/tmp/hf_home", exist_ok=True) | |
os.makedirs("/tmp/hf_cache", exist_ok=True) | |
os.makedirs("/tmp/hf_datasets", exist_ok=True) | |
os.makedirs("/tmp/.cache", exist_ok=True) | |
shutil.rmtree("/.cache", ignore_errors=True) | |
# --- Imports --- | |
import time, hashlib, gzip, pickle, json, traceback, re | |
import torch | |
from flask import Flask, request, jsonify, Response | |
from flask_cors import CORS | |
import numpy as np | |
import faiss | |
from sentence_transformers import SentenceTransformer | |
from rank_bm25 import BM25Okapi | |
import google.generativeai as genai | |
from cachetools import TTLCache | |
from huggingface_hub import login, hf_hub_download | |
from transformers import pipeline | |
# --- Login --- | |
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") | |
if HF_TOKEN: | |
try: | |
login(HF_TOKEN) | |
print("HF login successful") | |
except Exception as e: | |
print("Warning: HF login failed:", e) | |
else: | |
print("Warning: HF_TOKEN not found - only public repos accessible") | |
# --- Config --- | |
HF_REPO_ID = os.environ.get("HF_REPO_ID", "HungBB/egov-bot-data") | |
REPO_TYPE = os.environ.get("REPO_TYPE", "dataset") | |
EMB_MODEL = os.environ.get("EMB_MODEL", "AITeamVN/Vietnamese_Embedding") | |
GENAI_MODEL = os.environ.get("GENAI_MODEL", "gemini-2.5-flash") | |
TOP_K = int(os.environ.get("TOP_K", "3")) | |
FAISS_CANDIDATES = int(os.environ.get("FAISS_CANDIDATES", str(max(10, TOP_K*5)))) | |
BM25_PREFILTER = int(os.environ.get("BM25_PREFILTER", "200")) | |
CACHE_TTL = int(os.environ.get("CACHE_TTL", "3600")) | |
CACHE_MAX = int(os.environ.get("CACHE_MAX", "2000")) | |
print("--- KHỞI ĐỘNG MÁY CHỦ CHATBOT (optimized & fixed) ---") | |
t0 = time.perf_counter() | |
# --- TẢI VÀ LOAD TÀI NGUYÊN (SẮP XẾP LẠI CHO ĐÚNG LOGIC) --- | |
try: | |
print("Downloading resources from Hugging Face Hub...") | |
FAISS_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="index.faiss", repo_type=REPO_TYPE) | |
METAS_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="metas.pkl.gz", repo_type=REPO_TYPE) | |
BM25_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="bm25.pkl.gz", repo_type=REPO_TYPE) | |
print("✅ All files downloaded or already available.") | |
# 1. Load data from files | |
print("Loading data into memory...") | |
faiss_index = faiss.read_index(FAISS_PATH) | |
with gzip.open(METAS_PATH, "rb") as f: | |
metas = pickle.load(f) | |
if isinstance(metas, dict) and "corpus" in metas: | |
corpus = metas["corpus"] | |
else: | |
corpus = metas | |
with gzip.open(BM25_PATH, "rb") as f: | |
bm25 = pickle.load(f) | |
metadatas = corpus | |
print("✅ Data loaded. Corpus size:", len(metadatas)) | |
print("FAISS index ntotal =", getattr(faiss_index, "ntotal", "unknown")) | |
# 2. SỬA LỖI: Tạo mapping SAU KHI đã load xong metadatas | |
# --- Đặt khối code này vào đúng vị trí --- | |
# Hãy đảm bảo bạn sao chép chính xác cả các khoảng trắng thụt lề ở đầu dòng. | |
print("Building parent_id -> list of chunks map...") | |
parent_id_to_chunks = {} | |
for chunk in metadatas: | |
key = chunk.get("parent_id") or chunk.get("nguon") | |
if key: | |
if key not in parent_id_to_chunks: | |
parent_id_to_chunks[key] = [] | |
parent_id_to_chunks[key].append(chunk) | |
print("✅ Chunks map created.") | |
# --- Kết thúc khối code cần thay thế --- | |
# 3. Load embedding model | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
print(f"Loading embedding model: {EMB_MODEL} on device: {device}") | |
embedding_model = SentenceTransformer(EMB_MODEL, device=device) | |
print("✅ Embedding model loaded.") | |
except Exception as e: | |
print(f"❌ LỖI NGHIÊM TRỌNG KHI KHỞI ĐỘNG: {e}") | |
traceback.print_exc() | |
# Dừng ứng dụng nếu không tải được tài nguyên | |
raise RuntimeError("Failed to load critical resources.") from e | |
print("Total resources load time: %.2fs" % (time.perf_counter() - t0)) | |
# --- External APIs --- | |
API_KEY = os.environ.get("GOOGLE_API_KEY") | |
if not API_KEY: | |
print("Warning: GOOGLE_API_KEY missing. LLM calls will fail.") | |
generation_model = None | |
else: | |
genai.configure(api_key=API_KEY) | |
generation_model = genai.GenerativeModel(GENAI_MODEL) | |
# --- Cache --- | |
answer_cache = TTLCache(maxsize=CACHE_MAX, ttl=CACHE_TTL) | |
# --- Utility functions --- | |
def _norm_key(s: str) -> str: | |
return " ".join(s.lower().strip().split()) | |
def cache_key_for_query(q: str) -> str: | |
raw = f"{_norm_key(q)}|emb={EMB_MODEL}|k={TOP_K}|p={BM25_PREFILTER}" | |
return hashlib.sha256(raw.encode("utf-8")).hexdigest() | |
def minmax_scale(arr): | |
arr = np.array(arr, dtype="float32") | |
if len(arr) == 0 or np.max(arr) == np.min(arr): | |
return np.zeros_like(arr) | |
return (arr - np.min(arr)) / (np.max(arr) - np.min(arr)) | |
classifier = pipeline( | |
"text-classification", | |
model="Qwen/Qwen2-0.5B-Instruct", | |
device_map="auto" | |
) | |
def classify_followup(text: str): | |
prompt = f""" | |
Xác định xem câu sau có phải là follow-up (câu hỏi tiếp nối từ ngữ cảnh trước đó) hay không. | |
Trả lời duy nhất: 0 (không) hoặc 1 (có). | |
Câu: "{text}" | |
""" | |
result = classifier(prompt, truncation=True)[0]["label"] | |
return 1 if "1" in result else 0 | |
def retrieve(query: str, top_k=TOP_K): | |
print("Retrieving using FAISS -> BM25 Rerank method on CHUNKS...") | |
# --- BƯỚC 1: LẤY CÁC CHUNK ỨNG VIÊN TỪ FAISS --- | |
qv = embedding_model.encode([query], convert_to_numpy=True, normalize_embeddings=True).astype("float32") | |
num_candidates = 50 | |
try: | |
distances, candidate_indices = faiss_index.search(qv, num_candidates) | |
candidate_indices = candidate_indices[0].tolist() | |
except Exception as e: | |
print(f"Error during FAISS search: {e}") | |
return [] | |
# --- BƯỚC 2: DÙNG BM25 ĐỂ XẾP HẠNG LẠI DỰA TRÊN NỘI DUNG CỦA CHUNK --- | |
tokenized_query = query.split() | |
valid_candidate_indices = [idx for idx in candidate_indices if isinstance(metadatas[idx], dict)] | |
# ✅ ĐÃ SỬA: Sử dụng đúng key 'text' từ file metas.pkl của bạn. | |
candidate_docs_text = [metadatas[i].get('text', '') for i in valid_candidate_indices] | |
if not any(candidate_docs_text): | |
print("Warning: All candidate chunks have empty text. Returning FAISS results directly.") | |
return valid_candidate_indices[:top_k] | |
# Tạo chỉ mục BM25 tạm thời | |
try: | |
temp_bm25 = BM25Okapi(candidate_docs_text) | |
doc_scores = temp_bm25.get_scores(tokenized_query) | |
except Exception as e: | |
print(f"Error during BM25 reranking: {e}") | |
# Nếu BM25 lỗi, trả về kết quả của FAISS | |
return valid_candidate_indices[:top_k] | |
# Sắp xếp lại | |
reranked_order = np.argsort(-doc_scores) | |
# --- BƯỚC 3: LẤY KẾT QUẢ CUỐI CÙNG --- | |
final_indices = [valid_candidate_indices[i] for i in reranked_order[:top_k]] | |
print(f"Reranked and retrieved top {len(final_indices)} chunk indices: {final_indices}") | |
return final_indices | |
# Tạo một chỉ mục BM25 tạm thời chỉ với các ứng viên | |
temp_bm25 = BM25Okapi(candidate_docs_text) | |
doc_scores = temp_bm25.get_scores(tokenized_query) | |
# Sắp xếp lại danh sách ứng viên gốc dựa trên điểm số BM25 | |
reranked_order = np.argsort(-doc_scores) | |
# --- BƯỚC 3: LẤY KẾT QUẢ CUỐI CÙNG --- | |
final_indices = [valid_candidate_indices[i] for i in reranked_order[:top_k]] | |
print(f"Reranked and retrieved top {len(final_indices)} indices: {final_indices}") | |
return final_indices | |
def get_full_procedure_text_by_parent(parent_id): | |
chunks = parent_id_to_chunks.get(parent_id) | |
if not chunks: | |
return "Không tìm thấy thông tin chi tiết cho thủ tục này." | |
# Lấy tên thủ tục từ chunk đầu tiên (vì chúng giống nhau) | |
procedure_name = chunks[0].get('ten_thu_tuc', 'Không rõ') | |
# Dùng dictionary để gom nội dung của các chunk lại theo từng mục | |
procedure_details = {} | |
for chunk in chunks: | |
field = chunk.get('field') | |
# Dùng 'raw' để có text gốc, sạch sẽ | |
raw_text = chunk.get('raw') | |
if field and raw_text: | |
if field not in procedure_details: | |
procedure_details[field] = [] | |
procedure_details[field].append(raw_text.strip()) | |
# Định dạng lại output cho đẹp | |
field_map = { | |
"thanh_phan_ho_so": "Thành phần hồ sơ", | |
"trinh_tu_thuc_hien": "Trình tự thực hiện", | |
"cach_thuc_thuc_hien": "Cách thức thực hiện", | |
"yeu_cau_dieu_kien": "Yêu cầu, điều kiện", | |
"co_quan_thuc_hien": "Cơ quan thực hiện", | |
"nguon": "Nguồn" | |
} | |
context_parts = [f"Tên thủ tục:\n{procedure_name}"] | |
for field_key, field_name in field_map.items(): | |
if field_key in procedure_details: | |
# Nối tất cả các phần text của cùng một mục lại | |
content = "\n".join(procedure_details[field_key]) | |
context_parts.append(f"--- \n{field_name}:\n{content.strip()}") | |
return "\n\n".join(context_parts) | |
# --- Flask App --- | |
app = Flask(__name__) | |
CORS(app) | |
chat_histories = {} # Lưu lịch sử chat theo session | |
def health(): | |
return {"status": "ok"} | |
def chat(): | |
try: | |
data = request.get_json(force=True) | |
except Exception as e: | |
return jsonify({"error": "Cannot parse JSON", "detail": str(e)}), 400 | |
user_query = data.get('question') | |
session_id = data.get('session_id', 'default') | |
use_stream = data.get('stream', False) # Thêm tùy chọn stream | |
if not user_query: | |
return jsonify({"error": "No question provided"}), 400 | |
if session_id not in chat_histories: | |
chat_histories[session_id] = [] | |
current_history = chat_histories[session_id] | |
# --- TỐI ƯU HÓA: KÍCH HOẠT CACHING --- | |
cache_key = cache_key_for_query(user_query) | |
if cache_key in answer_cache: | |
print(f"CACHE HIT for query: '{user_query}'") | |
cached_answer = answer_cache[cache_key] | |
current_history.append({'role': 'user', 'content': user_query}) | |
current_history.append({'role': 'model', 'content': cached_answer, 'context': 'FROM_CACHE'}) | |
return jsonify({"answer": cached_answer}) | |
# --- KẾT THÚC TỐI ƯU HÓA CACHING --- | |
# Logic truy xuất ngữ cảnh | |
if classify_followup(user_query) == 0 and current_history: | |
context = current_history[-1].get('context', '') | |
else: | |
try: | |
idxs = retrieve(user_query, top_k=TOP_K) | |
if idxs: | |
parent_id = metadatas[idxs[0]].get("parent_id") or metadatas[idxs[0]].get("nguon") | |
context = get_full_procedure_text_by_parent(parent_id) | |
else: | |
context = "" | |
except Exception as e: | |
print(f"Error during retrieval: {e}") | |
context = "" | |
# Tạo prompt | |
# Tạo prompt | |
history_str = "\n".join([f"{item['role']}: {item['content']}" for item in current_history]) | |
# SỬA LỖI: Prompt mạnh mẽ và dứt khoát hơn | |
prompt = f"""Bạn là một trợ lý ảo chuyên về dịch vụ công của Việt Nam tên là eGov-Bot. | |
Nhiệm vụ của bạn là trả lời câu hỏi của người dùng một cách chính xác, đi thẳng vào vấn đề và chỉ dựa trên DỮ LIỆU được cung cấp. | |
QUY TẮC BẮT BUỘC: | |
1. **Trích xuất và Tóm tắt**: Phải đọc kỹ DỮ LIỆU, trích xuất các thông tin quan trọng như "Thành phần hồ sơ", "Trình tự thực hiện", "Cách thức thực hiện", "Lệ phí" và trình bày lại dưới dạng các gạch đầu dòng hoặc các bước rõ ràng. | |
2. **KHÔNG ĐƯỢC CHỈ ĐƯA LINK**: Tuyệt đối không trả lời bằng cách chỉ đưa ra một đường link. Link trong mục "Nguồn" chỉ để tham khảo thêm. | |
3. **Thiếu thông tin**: Nếu sau khi đọc kỹ DỮ LIỆU mà vẫn không tìm thấy thông tin cho câu hỏi, hãy trả lời: "Mình chưa có thông tin về [chủ đề câu hỏi]. Bạn có thể tham khảo thêm tại nguồn sau:" và sau đó đưa link nguồn có trong DỮ LIỆU. | |
4. **Câu hỏi về nguồn**: Nếu được hỏi về NGUỒN, phải lấy và đưa link trong mục "Nguồn" để người dùng tham khảo. | |
5. **Tạo bảng so sánh**: Nếu có so sánh thì hãy tạo bảng có khung từ DỮ LIỆU, bảng phải đẹp, gọn và dễ nhìn. | |
Lịch sử trò chuyện: | |
{history_str} | |
DỮ LIỆU: | |
--- | |
{context} | |
--- | |
CÂU HỎI: {user_query} | |
TRẢ LỜI:""" | |
# Gọi LLM | |
try: | |
if generation_model is None: | |
raise RuntimeError("generation_model is not available. Check GOOGLE_API_KEY.") | |
if not use_stream: | |
# Chế độ không stream | |
response = generation_model.generate_content(prompt) | |
final_answer = getattr(response, "text", str(response)) | |
# Lưu vào cache | |
answer_cache[cache_key] = final_answer | |
current_history.append({'role': 'user', 'content': user_query}) | |
current_history.append({'role': 'model', 'content': final_answer, 'context': context}) | |
return jsonify({"answer": final_answer}) | |
else: | |
# Chế độ stream | |
def generate(): | |
response_stream = generation_model.generate_content(prompt, stream=True) | |
full_answer = "" | |
for chunk in response_stream: | |
text_part = getattr(chunk, "text", "") | |
full_answer += text_part | |
yield text_part | |
# Lưu toàn bộ câu trả lời vào cache và lịch sử sau khi stream xong | |
answer_cache[cache_key] = full_answer | |
current_history.append({'role': 'user', 'content': user_query}) | |
current_history.append({'role': 'model', 'content': full_answer, 'context': context}) | |
return Response(generate(), mimetype='text/plain') | |
except Exception as e: | |
tb = traceback.format_exc() | |
print(f"LLM call failed: {e}\n{tb}") | |
return jsonify({"error": "LLM call failed", "detail": str(e), "trace": tb}), 500 | |
if __name__ == "__main__": | |
app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860))) |