jbilcke-hf HF Staff commited on
Commit
60d2ea4
·
2 Parent(s): 5cd8eab 36ad7ca

Merge branch 'main' of hf.co:spaces/jbilcke-hf/Hunyuan-GameCraft into zerogpu

Browse files
Files changed (1) hide show
  1. app.py +26 -1
app.py CHANGED
@@ -30,6 +30,31 @@ import argparse
30
 
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  class CropResize:
34
  def __init__(self, size=(704, 1216)):
35
  self.target_h, self.target_w = size
@@ -57,7 +82,7 @@ def create_args():
57
  args.image_start = True
58
  args.seed = None
59
  args.infer_steps = 8
60
- args.use_fp8 = True
61
  args.flow_shift_eval_video = 5.0
62
  args.sample_n_frames = 33
63
  args.num_images = 1
 
30
 
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
 
33
+ def detect_gpu_supports_fp8():
34
+ """Detect if the current GPU supports FP8 operations."""
35
+ if not torch.cuda.is_available():
36
+ return False
37
+
38
+ try:
39
+ # Get compute capability
40
+ compute_capability = torch.cuda.get_device_capability()
41
+ major, minor = compute_capability
42
+
43
+ # Get GPU name for logging
44
+ gpu_name = torch.cuda.get_device_name()
45
+
46
+ # FP8 with fp8e4m3fn (fp8e4nv) requires compute capability >= 9.0 (H100, H200)
47
+ # A100 has compute capability 8.0 and doesn't support this FP8 variant
48
+ supports_fp8 = major >= 9
49
+
50
+ logger.info(f"GPU detected: {gpu_name} (compute capability {major}.{minor})")
51
+ logger.info(f"FP8 support: {'Enabled' if supports_fp8 else 'Disabled (requires compute capability >= 9.0)'}")
52
+
53
+ return supports_fp8
54
+ except Exception as e:
55
+ logger.warning(f"Could not detect GPU capabilities: {e}. Disabling FP8.")
56
+ return False
57
+
58
  class CropResize:
59
  def __init__(self, size=(704, 1216)):
60
  self.target_h, self.target_w = size
 
82
  args.image_start = True
83
  args.seed = None
84
  args.infer_steps = 8
85
+ args.use_fp8 = detect_gpu_supports_fp8() # Auto-detect FP8 support based on GPU
86
  args.flow_shift_eval_video = 5.0
87
  args.sample_n_frames = 33
88
  args.num_images = 1