khoaaaaa commited on
Commit
4bdaf32
·
verified ·
1 Parent(s): 465670d

update llm_classifier_follow_up

Browse files
Files changed (1) hide show
  1. app.py +14 -30
app.py CHANGED
@@ -27,6 +27,7 @@ from rank_bm25 import BM25Okapi
27
  import google.generativeai as genai
28
  from cachetools import TTLCache
29
  from huggingface_hub import login, hf_hub_download
 
30
 
31
  # --- Login ---
32
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
@@ -130,37 +131,20 @@ def minmax_scale(arr):
130
  return np.zeros_like(arr)
131
  return (arr - np.min(arr)) / (np.max(arr) - np.min(arr))
132
 
 
 
 
 
 
 
133
  def classify_followup(text: str):
134
- text = text.lower().strip()
135
- score = 0
136
- strong_followup_keywords = [
137
- r"\b(nó|cái (này|đó|ấy)|thủ tục (này|đó|ấy))\b",
138
- r"\b(vừa (nói|hỏi)|trước đó|ở trên|phía trên)\b",
139
- r"\b(tiếp theo|tiếp|còn nữa|ngoài ra)\b",
140
- r"\b(thế (thì|à)|vậy (thì|à)|như vậy)\b"
141
- ]
142
- # SỬA LỖI: Thêm "lệ phí" và "chuẩn bị" vào đây
143
- detail_questions = [
144
- r"\b(mất bao lâu|thời gian|bao nhiêu tiền|chi phí|phí|lệ phí)\b",
145
- r"\b(ở đâu|tại đâu|chỗ nào|địa chỉ)\b",
146
- r"\b(cần (gì|những gì)|yêu cầu|điều kiện|chuẩn bị)\b"
147
- ]
148
- specific_services = [
149
- r"\b(làm|cấp|gia hạn|đổi|đăng ký)\s+(căn cước|cmnd|cccd)\b",
150
- r"\b(làm|cấp|gia hạn|đổi)\s+hộ chiếu\b",
151
- r"\b(đăng ký)\s+(kết hôn|sinh|tử|hộ khẩu)\b"
152
- ]
153
-
154
- if any(re.search(p, text) for p in strong_followup_keywords):
155
- score -= 5 # Tăng điểm phạt
156
- if any(re.search(p, text) for p in detail_questions):
157
- score -= 4 # Tăng điểm phạt
158
- if any(re.search(p, text) for p in specific_services):
159
- score += 1 # Giảm điểm cộng
160
- if len(text.split()) <= 3: # Giảm ngưỡng độ dài
161
- score -= 1
162
-
163
- return 0 if score < 0 else 1
164
 
165
  def retrieve(query: str, top_k=TOP_K):
166
  print("Retrieving using FAISS -> BM25 Rerank method on CHUNKS...")
 
27
  import google.generativeai as genai
28
  from cachetools import TTLCache
29
  from huggingface_hub import login, hf_hub_download
30
+ from transformers import pipeline
31
 
32
  # --- Login ---
33
  HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
 
131
  return np.zeros_like(arr)
132
  return (arr - np.min(arr)) / (np.max(arr) - np.min(arr))
133
 
134
+ classifier = pipeline(
135
+ "text-classification",
136
+ model="Qwen/Qwen2-0.5B-Instruct",
137
+ device_map="auto"
138
+ )
139
+
140
  def classify_followup(text: str):
141
+ prompt = f"""
142
+ Xác định xem câu sau có phải là follow-up (câu hỏi tiếp nối từ ngữ cảnh trước đó) hay không.
143
+ Trả lời duy nhất: 0 (không) hoặc 1 (có).
144
+ Câu: "{text}"
145
+ """
146
+ result = classifier(prompt, truncation=True)[0]["label"]
147
+ return 1 if "1" in result else 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  def retrieve(query: str, top_k=TOP_K):
150
  print("Retrieving using FAISS -> BM25 Rerank method on CHUNKS...")