multimodalart HF Staff commited on
Commit
cda9eef
·
verified ·
1 Parent(s): b2b6037

Update comfy/float.py

Browse files
Files changed (1) hide show
  1. comfy/float.py +2 -2
comfy/float.py CHANGED
@@ -55,7 +55,7 @@ def stochastic_rounding(value, dtype, seed=0):
55
  if dtype == torch.bfloat16:
56
  return value.to(dtype=torch.bfloat16)
57
  if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
58
- generator = torch.Generator(device='cuda' if torch.cuda.is_available() else 'cpu')
59
  torch.manual_seed(seed)
60
  if(torch.cuda.is_available()):
61
  torch.cuda.manual_seed(seed)
@@ -64,7 +64,7 @@ def stochastic_rounding(value, dtype, seed=0):
64
  slice_size = max(1, round(value.shape[0] / num_slices))
65
  with torch.no_grad():
66
  for i in range(0, value.shape[0], slice_size):
67
- output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
68
  return output
69
 
70
  return value.to(dtype=dtype)
 
55
  if dtype == torch.bfloat16:
56
  return value.to(dtype=torch.bfloat16)
57
  if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
58
+ #generator = torch.Generator(device='cuda' if torch.cuda.is_available() else 'cpu')
59
  torch.manual_seed(seed)
60
  if(torch.cuda.is_available()):
61
  torch.cuda.manual_seed(seed)
 
64
  slice_size = max(1, round(value.shape[0] / num_slices))
65
  with torch.no_grad():
66
  for i in range(0, value.shape[0], slice_size):
67
+ output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype))
68
  return output
69
 
70
  return value.to(dtype=dtype)