Spaces: Running on Zero
Update query_understanding.py
Browse files
- query_understanding.py +25 -8
query_understanding.py
CHANGED
@@ -43,11 +43,12 @@ class QueryUnderstandingEngine:
|
|
43 |
def __init__(self):
|
44 |
"""初始化查詢理解引擎"""
|
45 |
self.sbert_model = None
|
|
|
46 |
self.breed_list = self._load_breed_list()
|
47 |
self.synonyms = self._initialize_synonyms()
|
48 |
self.semantic_templates = {}
|
49 |
-
|
50 |
-
|
51 |
|
52 |
def _load_breed_list(self) -> List[str]:
|
53 |
"""載入品種清單"""
|
@@ -66,25 +67,35 @@ class QueryUnderstandingEngine:
|
|
66 |
'Bulldog', 'Poodle', 'Beagle', 'Border_Collie', 'Yorkshire_Terrier']
|
67 |
|
68 |
def _initialize_sbert_model(self):
|
69 |
-
"""初始化 SBERT 模型"""
|
|
|
|
|
|
|
70 |
try:
|
|
|
71 |
model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
|
72 |
|
73 |
for model_name in model_options:
|
74 |
try:
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
|
|
78 |
except Exception as e:
|
79 |
print(f"Failed to load {model_name}: {str(e)}")
|
80 |
continue
|
81 |
|
82 |
print("All SBERT models failed to load. Using keyword-only analysis.")
|
83 |
self.sbert_model = None
|
|
|
84 |
|
85 |
except Exception as e:
|
86 |
print(f"Failed to initialize SBERT model: {str(e)}")
|
87 |
self.sbert_model = None
|
|
|
|
|
|
|
88 |
|
89 |
def _initialize_synonyms(self) -> DimensionalSynonyms:
|
90 |
"""初始化多維度同義詞字典"""
|
@@ -143,6 +154,10 @@ class QueryUnderstandingEngine:
|
|
143 |
|
144 |
def _build_semantic_templates(self):
|
145 |
"""建立語義模板向量(僅在 SBERT 可用時)"""
|
|
|
|
|
|
|
|
|
146 |
if not self.sbert_model:
|
147 |
return
|
148 |
|
@@ -192,6 +207,9 @@ class QueryUnderstandingEngine:
|
|
192 |
dimensions = self._extract_keyword_dimensions(normalized_input)
|
193 |
|
194 |
# 如果 SBERT 可用,進行語義分析增強
|
|
|
|
|
|
|
195 |
if self.sbert_model:
|
196 |
semantic_dimensions = self._extract_semantic_dimensions(user_input)
|
197 |
dimensions = self._merge_dimensions(dimensions, semantic_dimensions)
|
@@ -435,7 +453,6 @@ class QueryUnderstandingEngine:
|
|
435 |
])
|
436 |
}
|
437 |
|
438 |
-
# 便利函數
|
439 |
def analyze_user_query(user_input: str) -> QueryDimensions:
|
440 |
"""
|
441 |
便利函數:分析使用者查詢
|
@@ -461,4 +478,4 @@ def get_query_summary(user_input: str) -> Dict[str, Any]:
|
|
461 |
"""
|
462 |
engine = QueryUnderstandingEngine()
|
463 |
dimensions = engine.analyze_query(user_input)
|
464 |
-
return engine.get_dimension_summary(dimensions)
|
|
|
43 |
def __init__(self):
|
44 |
"""初始化查詢理解引擎"""
|
45 |
self.sbert_model = None
|
46 |
+
self._sbert_loading_attempted = False
|
47 |
self.breed_list = self._load_breed_list()
|
48 |
self.synonyms = self._initialize_synonyms()
|
49 |
self.semantic_templates = {}
|
50 |
+
# 延遲SBERT載入直到需要時才在GPU環境中進行
|
51 |
+
print("QueryUnderstandingEngine initialized (SBERT loading deferred)")
|
52 |
|
53 |
def _load_breed_list(self) -> List[str]:
|
54 |
"""載入品種清單"""
|
|
|
67 |
'Bulldog', 'Poodle', 'Beagle', 'Border_Collie', 'Yorkshire_Terrier']
|
68 |
|
69 |
def _initialize_sbert_model(self):
|
70 |
+
"""初始化 SBERT 模型 - 延遲載入以避免ZeroGPU CUDA初始化問題"""
|
71 |
+
if self.sbert_model is not None or getattr(self, '_sbert_loading_attempted', False):
|
72 |
+
return self.sbert_model
|
73 |
+
|
74 |
try:
|
75 |
+
print("Loading SBERT model for query understanding in GPU context...")
|
76 |
model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
|
77 |
|
78 |
for model_name in model_options:
|
79 |
try:
|
80 |
+
import torch
|
81 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
82 |
+
self.sbert_model = SentenceTransformer(model_name, device=device)
|
83 |
+
print(f"SBERT model {model_name} loaded successfully for query understanding on {device}")
|
84 |
+
return self.sbert_model
|
85 |
except Exception as e:
|
86 |
print(f"Failed to load {model_name}: {str(e)}")
|
87 |
continue
|
88 |
|
89 |
print("All SBERT models failed to load. Using keyword-only analysis.")
|
90 |
self.sbert_model = None
|
91 |
+
return None
|
92 |
|
93 |
except Exception as e:
|
94 |
print(f"Failed to initialize SBERT model: {str(e)}")
|
95 |
self.sbert_model = None
|
96 |
+
return None
|
97 |
+
finally:
|
98 |
+
self._sbert_loading_attempted = True
|
99 |
|
100 |
def _initialize_synonyms(self) -> DimensionalSynonyms:
|
101 |
"""初始化多維度同義詞字典"""
|
|
|
154 |
|
155 |
def _build_semantic_templates(self):
|
156 |
"""建立語義模板向量(僅在 SBERT 可用時)"""
|
157 |
+
# Initialize SBERT model if needed
|
158 |
+
if self.sbert_model is None:
|
159 |
+
self._initialize_sbert_model()
|
160 |
+
|
161 |
if not self.sbert_model:
|
162 |
return
|
163 |
|
|
|
207 |
dimensions = self._extract_keyword_dimensions(normalized_input)
|
208 |
|
209 |
# 如果 SBERT 可用,進行語義分析增強
|
210 |
+
if self.sbert_model is None:
|
211 |
+
self._initialize_sbert_model()
|
212 |
+
|
213 |
if self.sbert_model:
|
214 |
semantic_dimensions = self._extract_semantic_dimensions(user_input)
|
215 |
dimensions = self._merge_dimensions(dimensions, semantic_dimensions)
|
|
|
453 |
])
|
454 |
}
|
455 |
|
|
|
456 |
def analyze_user_query(user_input: str) -> QueryDimensions:
|
457 |
"""
|
458 |
便利函數:分析使用者查詢
|
|
|
478 |
"""
|
479 |
engine = QueryUnderstandingEngine()
|
480 |
dimensions = engine.analyze_query(user_input)
|
481 |
+
return engine.get_dimension_summary(dimensions)
|