Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -39,24 +39,34 @@ import spaces
|
|
| 39 |
import torch.cuda.amp
|
| 40 |
|
| 41 |
|
| 42 |
-
@spaces.GPU
|
| 43 |
def get_device():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
print("Initializing device configuration...")
|
| 45 |
|
| 46 |
try:
|
| 47 |
-
|
| 48 |
-
# 使用 mixed precision
|
| 49 |
-
torch.set_float32_matmul_precision('medium')
|
| 50 |
-
|
| 51 |
if torch.cuda.is_available():
|
| 52 |
device = torch.device('cuda')
|
| 53 |
-
torch.cuda.
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
return device
|
| 56 |
-
|
|
|
|
| 57 |
print(f"GPU initialization error: {str(e)}")
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
| 60 |
return torch.device('cpu')
|
| 61 |
|
| 62 |
device = get_device()
|
|
@@ -152,50 +162,41 @@ class BaseModel(nn.Module):
|
|
| 152 |
|
| 153 |
def load_model(model_path, model_instance, device):
|
| 154 |
"""
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
Args:
|
| 158 |
-
model_path: 模型檔案的路徑
|
| 159 |
-
model_instance: BaseModel 的實例
|
| 160 |
-
device: 計算設備(CPU 或 GPU)
|
| 161 |
-
|
| 162 |
-
Returns:
|
| 163 |
-
載入權重後的模型實例
|
| 164 |
"""
|
| 165 |
try:
|
| 166 |
-
print(f"
|
| 167 |
|
| 168 |
-
#
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
|
| 178 |
-
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
-
#
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
# 設置為評估模式
|
| 185 |
model_instance.eval()
|
| 186 |
-
|
| 187 |
-
print("模型載入成功")
|
| 188 |
return model_instance
|
| 189 |
-
|
| 190 |
-
except Exception as e:
|
| 191 |
-
print(f"模型載入出錯: {str(e)}")
|
| 192 |
-
print("嘗試使用基本載入方式...")
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
model_instance.load_state_dict(checkpoint['base_model'], strict=False)
|
| 197 |
-
model_instance.eval()
|
| 198 |
-
return model_instance
|
| 199 |
|
| 200 |
# Initialize model
|
| 201 |
num_classes = len(dog_breeds)
|
|
|
|
| 39 |
import torch.cuda.amp
|
| 40 |
|
| 41 |
|
| 42 |
+
@spaces.GPU(duration=30)  # Ask ZeroGPU for a short (30 s) allocation window
def get_device():
    """
    Choose the torch device for this process, preferring CUDA.

    When CUDA is reported as available, CUDA is initialised, float32 matmul
    precision is set to 'medium', the cuDNN benchmark flag is switched on
    (with the deterministic flag off), and the CUDA device is returned.

    When CUDA is not available, or initialisation raises one of the handled
    errors, torch is limited to 4 CPU threads and the CPU device is returned.

    Returns:
        torch.device: ``cuda`` on success, otherwise ``cpu``.
    """
    print("Initializing device configuration...")

    try:
        if torch.cuda.is_available():
            cuda_device = torch.device('cuda')
            torch.cuda.init()
            torch.set_float32_matmul_precision('medium')

            # Trade run-to-run reproducibility for cuDNN kernel selection speed.
            torch.backends.cudnn.benchmark = True
            torch.backends.cudnn.deterministic = False

            gpu_name = torch.cuda.get_device_name(cuda_device)
            print(f"Successfully initialized CUDA device: {gpu_name}")
            return cuda_device
    except (spaces.zero.gradio.HTMLError, RuntimeError) as err:
        # Report the failure, then fall through to the CPU path below.
        print(f"GPU initialization error: {err}")

    # Reached when CUDA is unavailable or its initialisation failed.
    print("Using CPU mode")
    torch.set_num_threads(4)  # Fixed thread count for CPU inference
    return torch.device('cpu')
|
| 71 |
|
| 72 |
device = get_device()
|
|
|
|
| 162 |
|
| 163 |
def _load_base_model_weights(model_path, model_instance, device):
    """Read the checkpoint onto ``device``, apply its 'base_model' weights, and return the model in eval mode."""
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    checkpoint = torch.load(
        model_path,
        map_location=device,
        weights_only=True,
    )

    # strict=False tolerates missing or unexpected keys in the state dict.
    model_instance.load_state_dict(checkpoint['base_model'], strict=False)
    model_instance = model_instance.to(device)
    model_instance.eval()
    return model_instance


def load_model(model_path, model_instance, device):
    """
    Load checkpoint weights into ``model_instance`` and place it on ``device``.

    The checkpoint must be a mapping whose ``'base_model'`` entry holds the
    state dict. If loading fails with a CUDA out-of-memory error, the load is
    retried once on the CPU with the same safe-unpickling settings.

    Args:
        model_path: Path to the checkpoint file.
        model_instance: Module instance that receives the weights.
        device: Target device for the checkpoint tensors and the model.

    Returns:
        The model with weights applied, in evaluation mode. It is on
        ``device``, or on the CPU if the out-of-memory fallback was taken.

    Raises:
        RuntimeError: Re-raised for any runtime failure other than CUDA
            out-of-memory. Other exception types propagate unchanged.
    """
    try:
        print(f"Loading model to device: {device}")
        model_instance = _load_base_model_weights(model_path, model_instance, device)
        print("Model loading successful")
        return model_instance

    except RuntimeError as e:
        if "CUDA out of memory" not in str(e):
            print(f"Model loading error: {str(e)}")
            raise

        print("GPU memory exceeded, falling back to CPU")
        # Move any partially transferred parameters back before retrying.
        model_instance = model_instance.cpu()
        return _load_base_model_weights(
            model_path, model_instance, torch.device('cpu')
        )
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
# Initialize model
|
| 202 |
num_classes = len(dog_breeds)
|