jbilcke-hf HF Staff commited on
Commit
36ad7ca
·
1 Parent(s): d67e4c8

trying a fix for the A100 but I think I'll regret it

Browse files
Files changed (1) hide show
  1. app.py +26 -1
app.py CHANGED
@@ -29,6 +29,31 @@ import argparse
29
 
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  class CropResize:
33
  def __init__(self, size=(704, 1216)):
34
  self.target_h, self.target_w = size
@@ -56,7 +81,7 @@ def create_args():
56
  args.image_start = True
57
  args.seed = None
58
  args.infer_steps = 8
59
- args.use_fp8 = True
60
  args.flow_shift_eval_video = 5.0
61
  args.sample_n_frames = 33
62
  args.num_images = 1
 
29
 
30
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
 
32
+ def detect_gpu_supports_fp8():
33
+ """Detect if the current GPU supports FP8 operations."""
34
+ if not torch.cuda.is_available():
35
+ return False
36
+
37
+ try:
38
+ # Get compute capability
39
+ compute_capability = torch.cuda.get_device_capability()
40
+ major, minor = compute_capability
41
+
42
+ # Get GPU name for logging
43
+ gpu_name = torch.cuda.get_device_name()
44
+
45
+ # FP8 with fp8e4m3fn (fp8e4nv) requires compute capability >= 9.0 (H100, H200)
46
+ # A100 has compute capability 8.0 and doesn't support this FP8 variant
47
+ supports_fp8 = major >= 9
48
+
49
+ logger.info(f"GPU detected: {gpu_name} (compute capability {major}.{minor})")
50
+ logger.info(f"FP8 support: {'Enabled' if supports_fp8 else 'Disabled (requires compute capability >= 9.0)'}")
51
+
52
+ return supports_fp8
53
+ except Exception as e:
54
+ logger.warning(f"Could not detect GPU capabilities: {e}. Disabling FP8.")
55
+ return False
56
+
57
  class CropResize:
58
  def __init__(self, size=(704, 1216)):
59
  self.target_h, self.target_w = size
 
81
  args.image_start = True
82
  args.seed = None
83
  args.infer_steps = 8
84
+ args.use_fp8 = detect_gpu_supports_fp8() # Auto-detect FP8 support based on GPU
85
  args.flow_shift_eval_video = 5.0
86
  args.sample_n_frames = 33
87
  args.num_images = 1