multimodalart HF Staff commited on
Commit
7ced7f8
·
verified ·
1 Parent(s): c760e50

Update comfy/float.py

Browse files
Files changed (1) hide show
  1. comfy/float.py +3 -1
comfy/float.py CHANGED
@@ -55,8 +55,10 @@ def stochastic_rounding(value, dtype, seed=0):
55
  if dtype == torch.bfloat16:
56
  return value.to(dtype=torch.bfloat16)
57
  if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
58
- #generator = torch.Generator()
59
  torch.manual_seed(seed)
 
 
60
  output = torch.empty_like(value, dtype=dtype)
61
  num_slices = max(1, (value.numel() / (4096 * 4096)))
62
  slice_size = max(1, round(value.shape[0] / num_slices))
 
55
  if dtype == torch.bfloat16:
56
  return value.to(dtype=torch.bfloat16)
57
  if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
58
+ generator = torch.Generator()
59
  torch.manual_seed(seed)
60
+ if(torch.cuda.is_available()):
61
+ torch.cuda.manual_seed(seed)
62
  output = torch.empty_like(value, dtype=dtype)
63
  num_slices = max(1, (value.numel() / (4096 * 4096)))
64
  slice_size = max(1, round(value.shape[0] / num_slices))