Spaces:
Running
Running
# app.py | |
import os | |
import pickle | |
import gzip | |
import json | |
import re | |
import numpy as np | |
import faiss | |
from sentence_transformers import SentenceTransformer | |
from rank_bm25 import BM25Okapi | |
import google.generativeai as genai | |
from flask import Flask, request, jsonify | |
from flask_cors import CORS | |
from huggingface_hub import hf_hub_download # Thư viện để tải file từ Hub | |
print("--- KHỞI ĐỘNG MÁY CHỦ CHATBOT ---") | |
# --- 1. THIẾT LẬP VÀ TẢI TÀI NGUYÊN TỪ HUGGING FACE HUB --- | |
try: | |
print("Đang tải các tài nguyên cần thiết từ Hugging Face Hub...") | |
# !!! THAY THẾ BẰNG USERNAME VÀ TÊN DATASET CỦA BẠN !!! | |
HF_REPO_ID = "TEN_USERNAME_HF/egov-bot-data" | |
# Tự động tải các file từ "Kho Dữ liệu" về môi trường của Space | |
RAW_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="toan_bo_du_lieu_final.json") | |
FAISS_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="index.faiss") | |
METAS_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="metas.pkl.gz") | |
BM25_PATH = hf_hub_download(repo_id=HF_REPO_ID, filename="bm25.pkl.gz") | |
print("✅ Tải file dữ liệu thành công!") | |
# Lấy API Key từ Secret của Hugging Face | |
API_KEY = os.environ.get("GOOGLE_API_KEY") | |
genai.configure(api_key=API_KEY) | |
# Tải các mô hình và dữ liệu | |
generation_model = genai.GenerativeModel('gemini-2.5-flash') | |
embedding_model = SentenceTransformer("AITeamVN/Vietnamese_Embedding") | |
faiss_index = faiss.read_index(FAISS_PATH) | |
with gzip.open(METAS_PATH, "rb") as f: | |
metadatas = pickle.load(f) | |
with gzip.open(BM25_PATH, "rb") as f: | |
bm25 = pickle.load(f) | |
with open(RAW_PATH, "r", encoding="utf-8") as f: | |
raw_data = json.load(f) | |
procedure_map = {item['nguon']: item for item in raw_data} | |
print(f"✅ Tải tài nguyên thành công! Sẵn có {faiss_index.ntotal} chunks kiến thức.") | |
except Exception as e: | |
print(f"❌ LỖI KHI TẢI TÀI NGUYÊN: {e}") | |
# --- 2. CÁC HÀM XỬ LÝ CỦA BỘ NÃO (LOGIC TỪ COLAB CỦA BẠN) --- | |
# (Toàn bộ các hàm classify_followup, minmax_scale, retrieve, get_full_procedure_text của bạn được giữ nguyên ở đây) | |
def classify_followup(text: str): | |
text = text.lower().strip() | |
score = 0 | |
strong_followup_keywords = [r"\b(nó|cái (này|đó|ấy)|thủ tục (này|đó|ấy))\b", r"\b(vừa (nói|hỏi)|trước đó|ở trên|phía trên)\b", r"\b(tiếp theo|tiếp|còn nữa|ngoài ra)\b", r"\b(thế (thì|à)|vậy (thì|à)|như vậy)\b"] | |
detail_questions = [r"\b(mất bao lâu|thời gian|bao nhiêu tiền|chi phí|phí)\b", r"\b(ở đâu|tại đâu|chỗ nào|địa chỉ)\b", r"\b(cần (gì|những gì)|yêu cầu|điều kiện)\b"] | |
specific_services = [r"\b(làm|cấp|gia hạn|đổi|đăng ký)\s+(căn cước|cmnd|cccd)\b", r"\b(làm|cấp|gia hạn|đổi)\s+hộ chiếu\b", r"\b(đăng ký)\s+(kết hôn|sinh|tử|hộ khẩu)\b"] | |
if any(re.search(p, text) for p in strong_followup_keywords): score -= 3 | |
if any(re.search(p, text) for p in detail_questions): score -= 2 | |
if any(re.search(p, text) for p in specific_services): score += 3 | |
if len(text.split()) <= 4: score -=1 | |
return 0 if score < 0 else 1 | |
def minmax_scale(arr): | |
arr = np.array(arr, dtype="float32") | |
if len(arr) == 0 or np.max(arr) == np.min(arr): return np.zeros_like(arr) | |
return (arr - np.min(arr)) / (np.max(arr) - np.min(arr)) | |
def retrieve(query: str, top_k=3): | |
qv = embedding_model.encode([query], normalize_embeddings=True).astype("float32") | |
D, I = faiss_index.search(qv, top_k * 5) | |
vec_scores = (1 - D[0]).tolist() | |
vec_idx = I[0].tolist() | |
tokenized_query = query.split() | |
bm25_scores_all = bm25.get_scores(tokenized_query) | |
bm25_top_idx = np.argsort(-bm25_scores_all)[:top_k * 5].tolist() | |
union_idx = list(dict.fromkeys(vec_idx + bm25_top_idx)) | |
vec_map = {i: s for i, s in zip(vec_idx, vec_scores)} | |
vec_list = [vec_map.get(i, 0.0) for i in union_idx] | |
bm25_list = [bm25_scores_all[i] for i in union_idx] | |
vec_scaled = minmax_scale(vec_list) | |
bm25_scaled = minmax_scale(bm25_list) | |
fused = 0.7 * vec_scaled + 0.3 * bm25_scaled | |
order = np.argsort(-fused) | |
return [union_idx[i] for i in order[:top_k]] | |
def get_full_procedure_text(parent_id): | |
procedure = procedure_map.get(parent_id) | |
if not procedure: return "Không tìm thấy thủ tục." | |
parts = [] | |
field_map = {"ten_thu_tuc": "Tên thủ tục", "cach_thuc_thuc_hien": "Cách thức thực hiện", "thanh_phan_ho_so": "Thành phần hồ sơ", "trinh_tu_thuc_hien": "Trình tự thực hiện", "co_quan_thuc_hien": "Cơ quan thực hiện", "yeu_cau_dieu_kien": "Yêu cầu, điều kiện", "thu_tuc_lien_quan": "Thủ tục liên quan", "nguon": "Nguồn"} | |
for k, v in procedure.items(): | |
if v and k in field_map: | |
parts.append(f"{field_map[k]}:\n{str(v).strip()}") | |
return "\n\n".join(parts) | |
# --- 3. KHỞI TẠO MÁY CHỦ FLASK VÀ API --- | |
app = Flask(__name__) | |
CORS(app) | |
chat_histories = {} | |
def chat(): | |
data = request.json | |
user_query = data.get('question') | |
session_id = data.get('session_id', 'default') | |
if not user_query: | |
return jsonify({"error": "Không có câu hỏi nào được cung cấp"}), 400 | |
if session_id not in chat_histories: | |
chat_histories[session_id] = [] | |
current_history = chat_histories[session_id] | |
context = "" | |
if classify_followup(user_query) == 0 and current_history: | |
context = current_history[-1].get('context', '') | |
print(f"[{session_id}] Dùng lại ngữ cảnh cũ cho câu hỏi followup.") | |
else: | |
retrieved_indices = retrieve(user_query) | |
if retrieved_indices: | |
parent_id = metadatas[retrieved_indices[0]]["parent_id"] | |
context = get_full_procedure_text(parent_id) | |
print(f"[{session_id}] Đã tìm được ngữ cảnh mới.") | |
history_str = "\n".join([f"{item['role']}: {item['content']}" for item in current_history]) | |
prompt = f"""Bạn là trợ lý eGov-Bot... (Nội dung prompt của bạn ở đây) | |
Lịch sử trò chuyện: {history_str} | |
DỮ LIỆU: --- {context} --- | |
CÂU HỎI: {user_query} | |
""" | |
response = generation_model.generate_content(prompt) | |
final_answer = response.text | |
current_history.append({'role': 'user', 'content': user_query}) | |
current_history.append({'role': 'model', 'content': final_answer, 'context': context}) | |
return jsonify({"answer": final_answer}) | |
if __name__ == '__main__': | |
app.run(host='0.0.0.0', port=7860) | |