Spaces: Running on Zero
Update semantic_breed_recommender.py
Browse files- semantic_breed_recommender.py +45 -14
semantic_breed_recommender.py
CHANGED
|
@@ -37,6 +37,7 @@ class SemanticBreedRecommender:
|
|
| 37 |
"""Initialize the semantic recommender"""
|
| 38 |
self.model_name = 'all-MiniLM-L6-v2' # Efficient SBERT model
|
| 39 |
self.sbert_model = None
|
|
|
|
| 40 |
self.breed_vectors = {}
|
| 41 |
self.breed_list = self._get_breed_list()
|
| 42 |
self.comparative_keywords = {
|
|
@@ -44,13 +45,9 @@ class SemanticBreedRecommender:
|
|
| 44 |
'then': 0.7, 'second': 0.7, 'followed': 0.6,
|
| 45 |
'third': 0.5, 'least': 0.3, 'dislike': 0.2
|
| 46 |
}
|
| 47 |
-
#
|
| 48 |
-
#
|
| 49 |
-
|
| 50 |
-
# self.score_calibrator = ScoreCalibrator()
|
| 51 |
-
# self.config_manager = get_config_manager()
|
| 52 |
-
self._initialize_model()
|
| 53 |
-
self._build_breed_vectors()
|
| 54 |
|
| 55 |
# Initialize multi-head scorer with SBERT model if enhanced mode is enabled
|
| 56 |
# if self.sbert_model:
|
|
@@ -74,18 +71,24 @@ class SemanticBreedRecommender:
|
|
| 74 |
'Bulldog', 'Poodle', 'Beagle', 'Rottweiler', 'Yorkshire_Terrier']
|
| 75 |
|
| 76 |
def _initialize_model(self):
|
| 77 |
-
"""Initialize SBERT model with fallback"""
|
|
|
|
|
|
|
|
|
|
| 78 |
try:
|
| 79 |
-
print("Loading SBERT model...")
|
| 80 |
# Try different model names if the primary one fails
|
| 81 |
model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
|
| 82 |
|
| 83 |
for model_name in model_options:
|
| 84 |
try:
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
| 86 |
self.model_name = model_name
|
| 87 |
-
print(f"SBERT model {model_name} loaded successfully")
|
| 88 |
-
return
|
| 89 |
except Exception as model_e:
|
| 90 |
print(f"Failed to load {model_name}: {str(model_e)}")
|
| 91 |
continue
|
|
@@ -93,12 +96,16 @@ class SemanticBreedRecommender:
|
|
| 93 |
# If all models fail
|
| 94 |
print("All SBERT models failed to load. Using basic text matching fallback.")
|
| 95 |
self.sbert_model = None
|
|
|
|
| 96 |
|
| 97 |
except Exception as e:
|
| 98 |
print(f"Failed to initialize any SBERT model: {str(e)}")
|
| 99 |
print(traceback.format_exc())
|
| 100 |
print("Will provide basic text-based recommendations without embeddings")
|
| 101 |
self.sbert_model = None
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
def _create_breed_description(self, breed: str) -> str:
|
| 104 |
"""Create comprehensive natural language description for breed with all key characteristics"""
|
|
@@ -321,10 +328,14 @@ class SemanticBreedRecommender:
|
|
| 321 |
return f"{breed.replace('_', ' ')} is a dog breed with unique characteristics."
|
| 322 |
|
| 323 |
def _build_breed_vectors(self):
|
| 324 |
-
"""Build vector representations for all breeds"""
|
| 325 |
try:
|
| 326 |
print("Building breed vector database...")
|
| 327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
# Skip if model is not available
|
| 329 |
if self.sbert_model is None:
|
| 330 |
print("SBERT model not available, skipping vector building")
|
|
@@ -959,12 +970,20 @@ class SemanticBreedRecommender:
|
|
| 959 |
try:
|
| 960 |
print(f"Processing user input: {user_input}")
|
| 961 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 962 |
# Check if model is available - if not, raise error
|
| 963 |
if self.sbert_model is None:
|
| 964 |
error_msg = "SBERT model not available. This could be due to:\n• Model download failed\n• Insufficient memory\n• Network connectivity issues\n\nPlease check your environment and try again."
|
| 965 |
print(f"ERROR: {error_msg}")
|
| 966 |
raise RuntimeError(error_msg)
|
| 967 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 968 |
# Generate user input embedding
|
| 969 |
user_embedding = self.sbert_model.encode(user_input, convert_to_tensor=False)
|
| 970 |
|
|
@@ -1584,6 +1603,10 @@ def get_breed_recommendations_by_description(user_description: str,
|
|
| 1584 |
try:
|
| 1585 |
print("Initializing Enhanced SemanticBreedRecommender...")
|
| 1586 |
recommender = SemanticBreedRecommender()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1587 |
|
| 1588 |
# 優先使用整合統一評分系統的增強推薦
|
| 1589 |
print("Using enhanced recommendation system with unified scoring")
|
|
@@ -1628,11 +1651,19 @@ def get_enhanced_recommendations_with_unified_scoring(user_description: str, top
|
|
| 1628 |
# 創建基本推薦器實例
|
| 1629 |
recommender = SemanticBreedRecommender()
|
| 1630 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1631 |
if not recommender.sbert_model:
|
| 1632 |
print("SBERT model not available, using basic text matching...")
|
| 1633 |
# 使用基本文字匹配邏輯
|
| 1634 |
return _get_basic_text_matching_recommendations(user_description, top_k)
|
| 1635 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1636 |
# 使用語意相似度推薦
|
| 1637 |
recommendations = []
|
| 1638 |
user_embedding = recommender.sbert_model.encode(user_description)
|
|
@@ -2212,4 +2243,4 @@ def _get_basic_text_matching_recommendations(user_description: str, top_k: int =
|
|
| 2212 |
except Exception as e:
|
| 2213 |
error_msg = f"Error in basic text matching: {str(e)}"
|
| 2214 |
print(f"ERROR: {error_msg}")
|
| 2215 |
-
raise RuntimeError(error_msg) from e
|
|
|
|
| 37 |
"""Initialize the semantic recommender"""
|
| 38 |
self.model_name = 'all-MiniLM-L6-v2' # Efficient SBERT model
|
| 39 |
self.sbert_model = None
|
| 40 |
+
self._sbert_loading_attempted = False
|
| 41 |
self.breed_vectors = {}
|
| 42 |
self.breed_list = self._get_breed_list()
|
| 43 |
self.comparative_keywords = {
|
|
|
|
| 45 |
'then': 0.7, 'second': 0.7, 'followed': 0.6,
|
| 46 |
'third': 0.5, 'least': 0.3, 'dislike': 0.2
|
| 47 |
}
|
| 48 |
+
# Defer SBERT model loading until needed in GPU context
|
| 49 |
+
# This prevents CUDA initialization issues in ZeroGPU environment
|
| 50 |
+
print("SemanticBreedRecommender initialized (SBERT loading deferred)")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
# Initialize multi-head scorer with SBERT model if enhanced mode is enabled
|
| 53 |
# if self.sbert_model:
|
|
|
|
| 71 |
'Bulldog', 'Poodle', 'Beagle', 'Rottweiler', 'Yorkshire_Terrier']
|
| 72 |
|
| 73 |
def _initialize_model(self):
|
| 74 |
+
"""Initialize SBERT model with fallback - designed for ZeroGPU compatibility"""
|
| 75 |
+
if self.sbert_model is not None or self._sbert_loading_attempted:
|
| 76 |
+
return self.sbert_model
|
| 77 |
+
|
| 78 |
try:
|
| 79 |
+
print("Loading SBERT model in GPU context...")
|
| 80 |
# Try different model names if the primary one fails
|
| 81 |
model_options = ['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2']
|
| 82 |
|
| 83 |
for model_name in model_options:
|
| 84 |
try:
|
| 85 |
+
# Specify device explicitly to handle ZeroGPU environment
|
| 86 |
+
import torch
|
| 87 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 88 |
+
self.sbert_model = SentenceTransformer(model_name, device=device)
|
| 89 |
self.model_name = model_name
|
| 90 |
+
print(f"SBERT model {model_name} loaded successfully on {device}")
|
| 91 |
+
return self.sbert_model
|
| 92 |
except Exception as model_e:
|
| 93 |
print(f"Failed to load {model_name}: {str(model_e)}")
|
| 94 |
continue
|
|
|
|
| 96 |
# If all models fail
|
| 97 |
print("All SBERT models failed to load. Using basic text matching fallback.")
|
| 98 |
self.sbert_model = None
|
| 99 |
+
return None
|
| 100 |
|
| 101 |
except Exception as e:
|
| 102 |
print(f"Failed to initialize any SBERT model: {str(e)}")
|
| 103 |
print(traceback.format_exc())
|
| 104 |
print("Will provide basic text-based recommendations without embeddings")
|
| 105 |
self.sbert_model = None
|
| 106 |
+
return None
|
| 107 |
+
finally:
|
| 108 |
+
self._sbert_loading_attempted = True
|
| 109 |
|
| 110 |
def _create_breed_description(self, breed: str) -> str:
|
| 111 |
"""Create comprehensive natural language description for breed with all key characteristics"""
|
|
|
|
| 328 |
return f"{breed.replace('_', ' ')} is a dog breed with unique characteristics."
|
| 329 |
|
| 330 |
def _build_breed_vectors(self):
|
| 331 |
+
"""Build vector representations for all breeds - called lazily when needed"""
|
| 332 |
try:
|
| 333 |
print("Building breed vector database...")
|
| 334 |
|
| 335 |
+
# Initialize model if not already done
|
| 336 |
+
if self.sbert_model is None:
|
| 337 |
+
self._initialize_model()
|
| 338 |
+
|
| 339 |
# Skip if model is not available
|
| 340 |
if self.sbert_model is None:
|
| 341 |
print("SBERT model not available, skipping vector building")
|
|
|
|
| 970 |
try:
|
| 971 |
print(f"Processing user input: {user_input}")
|
| 972 |
|
| 973 |
+
# 嘗試載入SBERT模型(如果尚未載入)
|
| 974 |
+
if self.sbert_model is None:
|
| 975 |
+
self._initialize_model()
|
| 976 |
+
|
| 977 |
# Check if model is available - if not, raise error
|
| 978 |
if self.sbert_model is None:
|
| 979 |
error_msg = "SBERT model not available. This could be due to:\n• Model download failed\n• Insufficient memory\n• Network connectivity issues\n\nPlease check your environment and try again."
|
| 980 |
print(f"ERROR: {error_msg}")
|
| 981 |
raise RuntimeError(error_msg)
|
| 982 |
|
| 983 |
+
# 確保breed vectors已建構
|
| 984 |
+
if not self.breed_vectors:
|
| 985 |
+
self._build_breed_vectors()
|
| 986 |
+
|
| 987 |
# Generate user input embedding
|
| 988 |
user_embedding = self.sbert_model.encode(user_input, convert_to_tensor=False)
|
| 989 |
|
|
|
|
| 1603 |
try:
|
| 1604 |
print("Initializing Enhanced SemanticBreedRecommender...")
|
| 1605 |
recommender = SemanticBreedRecommender()
|
| 1606 |
+
|
| 1607 |
+
# 嘗試載入SBERT模型(如果尚未載入)
|
| 1608 |
+
if not recommender.sbert_model:
|
| 1609 |
+
recommender._initialize_model()
|
| 1610 |
|
| 1611 |
# 優先使用整合統一評分系統的增強推薦
|
| 1612 |
print("Using enhanced recommendation system with unified scoring")
|
|
|
|
| 1651 |
# 創建基本推薦器實例
|
| 1652 |
recommender = SemanticBreedRecommender()
|
| 1653 |
|
| 1654 |
+
# 嘗試載入SBERT模型(如果尚未載入)
|
| 1655 |
+
if not recommender.sbert_model:
|
| 1656 |
+
recommender._initialize_model()
|
| 1657 |
+
|
| 1658 |
if not recommender.sbert_model:
|
| 1659 |
print("SBERT model not available, using basic text matching...")
|
| 1660 |
# 使用基本文字匹配邏輯
|
| 1661 |
return _get_basic_text_matching_recommendations(user_description, top_k)
|
| 1662 |
|
| 1663 |
+
# 確保breed vectors已建構
|
| 1664 |
+
if not recommender.breed_vectors:
|
| 1665 |
+
recommender._build_breed_vectors()
|
| 1666 |
+
|
| 1667 |
# 使用語意相似度推薦
|
| 1668 |
recommendations = []
|
| 1669 |
user_embedding = recommender.sbert_model.encode(user_description)
|
|
|
|
| 2243 |
except Exception as e:
|
| 2244 |
error_msg = f"Error in basic text matching: {str(e)}"
|
| 2245 |
print(f"ERROR: {error_msg}")
|
| 2246 |
+
raise RuntimeError(error_msg) from e
|