ouclxy committed on
Commit 45c12d5 · verified · 1 Parent(s): 9df085e

Update gradio_app.py

Files changed (1)
  1. gradio_app.py +35 -24
gradio_app.py CHANGED
@@ -208,10 +208,41 @@ with open("imgs/background.png", "rb") as f:
 
 @spaces.GPU
 def inference(id_image, hair_image):
-    # Require GPU (HairMapper currently uses CUDA explicitly)
-    if not torch.cuda.is_available():
-        raise RuntimeError("This demo requires a GPU Space. Please enable a GPU in this Space.")
+    # ZeroGPU: force the 'cuda' device (torch.cuda.is_available() may be False on ZeroGPU).
+    device = torch.device("cuda")
 
+    # Disable the StyleGAN2 custom CUDA ops first (before HairMapper is imported) to avoid triggering JIT compilation.
+    # 1) Disable the module-level switches.
+    try:
+        from torch_utils.ops import bias_act as _bias_act2
+        _bias_act2.USING_CUDA_TO_SPEED_UP = False
+    except Exception:
+        pass
+    for _mod_name in ("upfirdn2d", "filtered_lrelu"):
+        try:
+            _m = __import__(f"torch_utils.ops.{_mod_name}", fromlist=["*"])
+            if hasattr(_m, 'USING_CUDA_TO_SPEED_UP'):
+                setattr(_m, 'USING_CUDA_TO_SPEED_UP', False)
+        except Exception:
+            pass
+    try:
+        from HairMapper.styleGAN2_ada_model.stylegan2_ada.torch_utils.ops import bias_act as _bias_act
+        _bias_act.USING_CUDA_TO_SPEED_UP = False
+    except Exception:
+        pass
+
+    # 2) Force bias_act onto the 'ref' implementation (rewrite impl='cuda' to 'ref' even when callers pass it).
+    try:
+        import types
+        from torch_utils.ops import bias_act as _ba_mod
+        _orig_bias_act = _ba_mod.bias_act
+        def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+            return _orig_bias_act(x, b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp, impl='ref')
+        _ba_mod.bias_act = types.FunctionType(_bias_act_ref.__code__, globals(), name='bias_act')
+    except Exception:
+        pass
+
+    # Only then import the dependencies (custom ops disabled, 'ref' implementation forced).
     (
         log_validation,
         UNet3DConditionModel,
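A caveat on step 2 above: `_bias_act_ref` closes over the local `_orig_bias_act`, so its code object carries a free variable, and `types.FunctionType(code, globals, name)` raises `TypeError` for such code objects unless a matching `closure` tuple is supplied. The surrounding `except Exception: pass` would swallow that error and leave `bias_act` unpatched (step 1's flag may still keep it on the reference path). A minimal sketch of an equivalent patch that avoids rebuilding the function object, assuming `torch_utils.ops.bias_act` is importable as in the commit:

import functools
from torch_utils.ops import bias_act as _ba_mod  # assumption: same module path as the commit uses

_orig_bias_act = _ba_mod.bias_act

@functools.wraps(_orig_bias_act)
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    # Ignore the caller's impl and always take the pure-PyTorch reference path.
    return _orig_bias_act(x, b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp, impl='ref')

_ba_mod.bias_act = _bias_act_ref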
@@ -225,26 +256,6 @@ def inference(id_image, hair_image):
         bald_head,
     ) = _import_inference_bits()
 
-    # Disable StyleGAN2 custom CUDA ops to avoid JIT compiling (needs ninja/NVCC).
-    # On ZeroGPU, prefer the pure-PyTorch reference implementation to avoid extension build failures.
-    try:
-        from HairMapper.styleGAN2_ada_model.stylegan2_ada.torch_utils.ops import bias_act as _bias_act
-        _bias_act.USING_CUDA_TO_SPEED_UP = False
-        try:
-            from HairMapper.styleGAN2_ada_model.stylegan2_ada.torch_utils.ops import upfirdn2d as _upfirdn2d
-            if hasattr(_upfirdn2d, 'USING_CUDA_TO_SPEED_UP'):
-                _upfirdn2d.USING_CUDA_TO_SPEED_UP = False
-        except Exception:
-            pass
-        try:
-            from HairMapper.styleGAN2_ada_model.stylegan2_ada.torch_utils.ops import filtered_lrelu as _fl
-            if hasattr(_fl, 'USING_CUDA_TO_SPEED_UP'):
-                _fl.USING_CUDA_TO_SPEED_UP = False
-        except Exception:
-            pass
-    except Exception:
-        pass
-
     os.makedirs("gradio_inputs", exist_ok=True)
     os.makedirs("gradio_outputs", exist_ok=True)
 
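The block removed by this hunk is re-added earlier in the function, before `_import_inference_bits()`, and retargeted at the `torch_utils.ops` path. Running it before the import matters because Python caches modules in `sys.modules`: a flag flipped on the cached module object is seen by every later importer. A self-contained sketch of that mechanism (the `fake_ops.bias_act` name is illustrative, not from the repo):

import importlib
import sys
import types

# Stand-in for torch_utils.ops.bias_act: one module object, shared by all importers.
mod = types.ModuleType("fake_ops.bias_act")
mod.USING_CUDA_TO_SPEED_UP = True
sys.modules["fake_ops.bias_act"] = mod

# Flip the flag before anything else imports the module...
sys.modules["fake_ops.bias_act"].USING_CUDA_TO_SPEED_UP = False

# ...and a later import returns the cached object with the flipped value.
print(importlib.import_module("fake_ops.bias_act").USING_CUDA_TO_SPEED_UP)  # False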
@@ -291,7 +302,7 @@ def inference(id_image, hair_image):
 
     args = Args()
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda")
 
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
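For context on the unconditional `torch.device("cuda")`: per the commit's own comment, `torch.cuda.is_available()` can report False on ZeroGPU outside a `@spaces.GPU`-decorated call, because the GPU is attached only while such a call runs, so the old availability check would reject a Space that actually works. A minimal sketch of the pattern this commit relies on (`demo_fn` is a hypothetical handler, not from the repo):

import spaces
import torch

@spaces.GPU  # on ZeroGPU, a GPU is attached only for the duration of this call
def demo_fn(x):
    device = torch.device("cuda")  # forced, as in the commit, rather than gated on is_available()
    return (torch.as_tensor(x, dtype=torch.float32, device=device) * 2).cpu().numpy()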