DawnC commited on
Commit
089125f
·
verified ·
1 Parent(s): c4c78dc

Update query_understanding.py

Browse files
Files changed (1) hide show
  1. query_understanding.py +25 -8
query_understanding.py CHANGED
@@ -43,11 +43,12 @@ class QueryUnderstandingEngine:
43
  def __init__(self):
44
  """初始化查詢理解引擎"""
45
  self.sbert_model = None
 
46
  self.breed_list = self._load_breed_list()
47
  self.synonyms = self._initialize_synonyms()
48
  self.semantic_templates = {}
49
- self._initialize_sbert_model()
50
- self._build_semantic_templates()
51
 
52
  def _load_breed_list(self) -> List[str]:
53
  """載入品種清單"""
@@ -66,25 +67,35 @@ class QueryUnderstandingEngine:
66
  'Bulldog', 'Poodle', 'Beagle', 'Border_Collie', 'Yorkshire_Terrier']
67
 
68
  def _initialize_sbert_model(self):
69
- """初始化 SBERT 模型"""
 
 
 
70
  try:
 
71
  model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
72
 
73
  for model_name in model_options:
74
  try:
75
- self.sbert_model = SentenceTransformer(model_name)
76
- print(f"SBERT model {model_name} loaded successfully for query understanding")
77
- return
 
 
78
  except Exception as e:
79
  print(f"Failed to load {model_name}: {str(e)}")
80
  continue
81
 
82
  print("All SBERT models failed to load. Using keyword-only analysis.")
83
  self.sbert_model = None
 
84
 
85
  except Exception as e:
86
  print(f"Failed to initialize SBERT model: {str(e)}")
87
  self.sbert_model = None
 
 
 
88
 
89
  def _initialize_synonyms(self) -> DimensionalSynonyms:
90
  """初始化多維度同義詞字典"""
@@ -143,6 +154,10 @@ class QueryUnderstandingEngine:
143
 
144
  def _build_semantic_templates(self):
145
  """建立語義模板向量(僅在 SBERT 可用時)"""
 
 
 
 
146
  if not self.sbert_model:
147
  return
148
 
@@ -192,6 +207,9 @@ class QueryUnderstandingEngine:
192
  dimensions = self._extract_keyword_dimensions(normalized_input)
193
 
194
  # 如果 SBERT 可用,進行語義分析增強
 
 
 
195
  if self.sbert_model:
196
  semantic_dimensions = self._extract_semantic_dimensions(user_input)
197
  dimensions = self._merge_dimensions(dimensions, semantic_dimensions)
@@ -435,7 +453,6 @@ class QueryUnderstandingEngine:
435
  ])
436
  }
437
 
438
- # 便利函數
439
  def analyze_user_query(user_input: str) -> QueryDimensions:
440
  """
441
  便利函數:分析使用者查詢
@@ -461,4 +478,4 @@ def get_query_summary(user_input: str) -> Dict[str, Any]:
461
  """
462
  engine = QueryUnderstandingEngine()
463
  dimensions = engine.analyze_query(user_input)
464
- return engine.get_dimension_summary(dimensions)
 
43
  def __init__(self):
44
  """初始化查詢理解引擎"""
45
  self.sbert_model = None
46
+ self._sbert_loading_attempted = False
47
  self.breed_list = self._load_breed_list()
48
  self.synonyms = self._initialize_synonyms()
49
  self.semantic_templates = {}
50
+ # 延遲SBERT載入直到需要時才在GPU環境中進行
51
+ print("QueryUnderstandingEngine initialized (SBERT loading deferred)")
52
 
53
  def _load_breed_list(self) -> List[str]:
54
  """載入品種清單"""
 
67
  'Bulldog', 'Poodle', 'Beagle', 'Border_Collie', 'Yorkshire_Terrier']
68
 
69
  def _initialize_sbert_model(self):
70
+ """初始化 SBERT 模型 - 延遲載入以避免ZeroGPU CUDA初始化問題"""
71
+ if self.sbert_model is not None or getattr(self, '_sbert_loading_attempted', False):
72
+ return self.sbert_model
73
+
74
  try:
75
+ print("Loading SBERT model for query understanding in GPU context...")
76
  model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
77
 
78
  for model_name in model_options:
79
  try:
80
+ import torch
81
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
82
+ self.sbert_model = SentenceTransformer(model_name, device=device)
83
+ print(f"SBERT model {model_name} loaded successfully for query understanding on {device}")
84
+ return self.sbert_model
85
  except Exception as e:
86
  print(f"Failed to load {model_name}: {str(e)}")
87
  continue
88
 
89
  print("All SBERT models failed to load. Using keyword-only analysis.")
90
  self.sbert_model = None
91
+ return None
92
 
93
  except Exception as e:
94
  print(f"Failed to initialize SBERT model: {str(e)}")
95
  self.sbert_model = None
96
+ return None
97
+ finally:
98
+ self._sbert_loading_attempted = True
99
 
100
  def _initialize_synonyms(self) -> DimensionalSynonyms:
101
  """初始化多維度同義詞字典"""
 
154
 
155
  def _build_semantic_templates(self):
156
  """建立語義模板向量(僅在 SBERT 可用時)"""
157
+ # Initialize SBERT model if needed
158
+ if self.sbert_model is None:
159
+ self._initialize_sbert_model()
160
+
161
  if not self.sbert_model:
162
  return
163
 
 
207
  dimensions = self._extract_keyword_dimensions(normalized_input)
208
 
209
  # 如果 SBERT 可用,進行語義分析增強
210
+ if self.sbert_model is None:
211
+ self._initialize_sbert_model()
212
+
213
  if self.sbert_model:
214
  semantic_dimensions = self._extract_semantic_dimensions(user_input)
215
  dimensions = self._merge_dimensions(dimensions, semantic_dimensions)
 
453
  ])
454
  }
455
 
 
456
  def analyze_user_query(user_input: str) -> QueryDimensions:
457
  """
458
  便利函數:分析使用者查詢
 
478
  """
479
  engine = QueryUnderstandingEngine()
480
  dimensions = engine.analyze_query(user_input)
481
+ return engine.get_dimension_summary(dimensions)