danieldk committed
Commit 9fc83e6 · Parent: 15057d1
build/torch-universal/triton_layer_norm/__init__.py CHANGED
@@ -1,3 +1,5 @@
-from .layer_norm import RMSNorm, layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
+from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
 
-__all__ = ["RMSNorm", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
+from . import layers
+
+__all__ = ["layers", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
build/torch-universal/triton_layer_norm/layer_norm.py CHANGED
@@ -10,7 +10,7 @@ import math
 
 import torch
 import torch.nn.functional as F
-from torch.cuda.amp import custom_fwd, custom_bwd
+from torch.amp import custom_fwd, custom_bwd
 
 import triton
 import triton.language as tl
@@ -59,9 +59,9 @@ def layer_norm_ref(
         x = x + x1
     if residual is not None:
         x = (x + residual).to(x.dtype)
-    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
-        dtype
-    )
+    out = F.layer_norm(
+        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
+    ).to(dtype)
     if weight1 is None:
         return out if not prenorm else (out, x)
     else:
@@ -115,13 +115,15 @@ def rms_norm_ref(
     if residual is not None:
         x = (x + residual).to(x.dtype)
     rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
-    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
+    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
+        dtype
+    )
     if weight1 is None:
         return out if not prenorm else (out, x)
     else:
-        out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
-            dtype
-        )
+        out1 = (
+            (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
+        ).to(dtype)
         return (out, out1) if not prenorm else (out, out1, x)
 
 
@@ -201,7 +203,9 @@ def _layer_norm_fwd_1pass_kernel(
     if HAS_DROPOUT:
         # Compute dropout mask
         # 7 rounds is good enough, and reduces register pressure
-        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+        keep_mask = (
+            tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+        )
         x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
         if STORE_DROPOUT_MASK:
             tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
@@ -214,7 +218,8 @@ def _layer_norm_fwd_1pass_kernel(
             # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
             keep_mask = (
-                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
+                > dropout_p
             )
             x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
             if STORE_DROPOUT_MASK:
@@ -268,7 +273,7 @@ def _layer_norm_fwd(
     is_rms_norm=False,
     return_dropout_mask=False,
     out=None,
-    residual_out=None
+    residual_out=None,
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -315,14 +320,21 @@ def _layer_norm_fwd(
     ):
         if residual_out is None:
             residual_out = torch.empty(
-                M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
+                M,
+                N,
+                device=x.device,
+                dtype=residual_dtype if residual_dtype is not None else x.dtype,
             )
         else:
             assert residual_out.shape == x.shape
             assert residual_out.stride(-1) == 1
     else:
         residual_out = None
-    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
+    mean = (
+        torch.empty((M,), dtype=torch.float32, device=x.device)
+        if not is_rms_norm
+        else None
+    )
     rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
     if dropout_p > 0.0:
         seeds = torch.randint(
@@ -331,7 +343,9 @@ def _layer_norm_fwd(
     else:
         seeds = None
     if return_dropout_mask and dropout_p > 0.0:
-        dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
+        dropout_mask = torch.empty(
+            M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
+        )
     else:
         dropout_mask = None
     # Less than 64KB per feature: enqueue fused kernel
@@ -401,7 +415,14 @@ def _layer_norm_fwd(
         triton.Config({}, num_warps=16),
         triton.Config({}, num_warps=32),
     ],
-    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
+    key=[
+        "N",
+        "HAS_DRESIDUAL",
+        "STORE_DRESIDUAL",
+        "IS_RMS_NORM",
+        "HAS_BIAS",
+        "HAS_DROPOUT",
+    ],
 )
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
@@ -529,14 +550,18 @@ def _layer_norm_bwd_kernel(
         if HAS_DX1:
             if HAS_DROPOUT:
                 keep_mask = (
-                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
+                    > dropout_p
                 )
                 dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
             else:
                 dx1 = dx
             tl.store(DX1 + cols, dx1, mask=mask)
         if HAS_DROPOUT:
-            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+            keep_mask = (
+                tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
+                > dropout_p
+            )
             dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
         if HAS_ROWSCALE:
             rowscale = tl.load(ROWSCALE + row).to(tl.float32)
@@ -627,9 +652,15 @@ def _layer_norm_bwd(
         else None
     )
     dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
-    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
+    y = (
+        torch.empty(M, N, dtype=dy.dtype, device=dy.device)
+        if recompute_output
+        else None
+    )
     if recompute_output:
-        assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
+        assert (
+            weight1 is None
+        ), "recompute_output is not supported with parallel LayerNorm"
 
     # Less than 64KB per feature: enqueue fused kernel
     MAX_FUSED_SIZE = 65536 // x.element_size()
@@ -723,7 +754,7 @@ class LayerNormFn(torch.autograd.Function):
         is_rms_norm=False,
         return_dropout_mask=False,
         out=None,
-        residual_out=None
+        residual_out=None,
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -759,22 +790,24 @@ class LayerNormFn(torch.autograd.Function):
             out = out.reshape(-1, out.shape[-1])
         if residual_out is not None:
             residual_out = residual_out.reshape(-1, residual_out.shape[-1])
-        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
-            x,
-            weight,
-            bias,
-            eps,
-            residual,
-            x1,
-            weight1,
-            bias1,
-            dropout_p=dropout_p,
-            rowscale=rowscale,
-            residual_dtype=residual_dtype,
-            is_rms_norm=is_rms_norm,
-            return_dropout_mask=return_dropout_mask,
-            out=out,
-            residual_out=residual_out
+        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
+            _layer_norm_fwd(
+                x,
+                weight,
+                bias,
+                eps,
+                residual,
+                x1,
+                weight1,
+                bias1,
+                dropout_p=dropout_p,
+                rowscale=rowscale,
+                residual_dtype=residual_dtype,
+                is_rms_norm=is_rms_norm,
+                return_dropout_mask=return_dropout_mask,
+                out=out,
+                residual_out=residual_out,
+            )
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -789,9 +822,15 @@ class LayerNormFn(torch.autograd.Function):
         ctx.x_dtype = x.dtype
         y = y.reshape(x_shape_og)
         y1 = y1.reshape(x_shape_og) if y1 is not None else None
-        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
-        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
-        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
+        residual_out = (
+            residual_out.reshape(x_shape_og) if residual_out is not None else None
+        )
+        dropout_mask = (
+            dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
+        )
+        dropout_mask1 = (
+            dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
+        )
         if not return_dropout_mask:
             if weight1 is None:
                 return y if not prenorm else (y, residual_out)
@@ -890,7 +929,7 @@ def layer_norm_fn(
     is_rms_norm=False,
     return_dropout_mask=False,
     out=None,
-    residual_out=None
+    residual_out=None,
 ):
     return LayerNormFn.apply(
         x,
@@ -908,7 +947,7 @@ def layer_norm_fn(
         is_rms_norm,
         return_dropout_mask,
         out,
-        residual_out
+        residual_out,
     )
 
 
@@ -927,7 +966,7 @@ def rms_norm_fn(
     residual_in_fp32=False,
     return_dropout_mask=False,
     out=None,
-    residual_out=None
+    residual_out=None,
 ):
     return LayerNormFn.apply(
         x,
@@ -945,7 +984,7 @@ def rms_norm_fn(
         True,
         return_dropout_mask,
         out,
-        residual_out
+        residual_out,
     )
 
 
@@ -981,7 +1020,7 @@ class RMSNorm(torch.nn.Module):
 
 class LayerNormLinearFn(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type="cuda")
     def forward(
         ctx,
         x,
@@ -1019,17 +1058,25 @@ class LayerNormLinearFn(torch.autograd.Function):
             norm_bias,
             eps,
             residual,
-            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
+            out_dtype=(
+                None
+                if not torch.is_autocast_enabled()
+                else torch.get_autocast_gpu_dtype()
+            ),
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
         )
         y = y.reshape(x_shape_og)
-        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
+        dtype = (
+            torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
+        )
         linear_weight = linear_weight.to(dtype)
         linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
         out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
         # We don't store y, will be recomputed in the backward pass to save memory
-        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
+        ctx.save_for_backward(
+            residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
+        )
         ctx.x_shape_og = x_shape_og
         ctx.eps = eps
         ctx.is_rms_norm = is_rms_norm
@@ -1040,7 +1087,7 @@ class LayerNormLinearFn(torch.autograd.Function):
         return out if not prenorm else (out, residual_out.reshape(x_shape_og))
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type="cuda")
     def backward(ctx, dout, *args):
         x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
         dout = dout.reshape(-1, dout.shape[-1])
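The substantive change above is the migration from the deprecated `torch.cuda.amp.custom_fwd`/`custom_bwd` decorators to `torch.amp.custom_fwd`/`custom_bwd`, which take an explicit `device_type`; the remaining hunks (trailing commas on `residual_out=None`, wrapped call sites, the reformatted autotune `key` list) are line-length reformatting with no behavioral change. A minimal sketch of the new decorator API on a toy autograd function (the `Scale` class is illustrative, not part of this repo):

    import torch
    from torch.amp import custom_fwd, custom_bwd  # new import path used in this commit


    class Scale(torch.autograd.Function):
        # Toy example of the torch.amp decorator style adopted above.

        @staticmethod
        @custom_fwd(device_type="cuda")  # replaces torch.cuda.amp.custom_fwd
        def forward(ctx, x, alpha):
            ctx.alpha = alpha
            return x * alpha

        @staticmethod
        @custom_bwd(device_type="cuda")  # replaces torch.cuda.amp.custom_bwd
        def backward(ctx, grad_out):
            return grad_out * ctx.alpha, None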
build/torch-universal/triton_layer_norm/layers.py ADDED
@@ -0,0 +1,4 @@
+from .layer_norm import RMSNorm
+
+
+__all__ = ["RMSNorm"]