Build
build/torch-universal/triton_layer_norm/__init__.py
CHANGED
@@ -1,3 +1,5 @@
-from .layer_norm import
+from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
 
+from . import layers
+
+__all__ = ["layers", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
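For context, a minimal usage sketch of the re-exported public API; the shapes, dtypes, and eps value below are illustrative assumptions, not taken from this diff (a CUDA device and Triton are required):

import torch

from triton_layer_norm import layer_norm_fn, rms_norm_fn

# Hypothetical shapes/dtypes; any (..., hidden_size) CUDA tensor works the same way.
x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
weight = torch.ones(1024, device="cuda", dtype=torch.float16)
bias = torch.zeros(1024, device="cuda", dtype=torch.float16)

y = layer_norm_fn(x, weight, bias, eps=1e-6)    # fused LayerNorm
y_rms = rms_norm_fn(x, weight, None, eps=1e-6)  # fused RMSNorm (bias omitted)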
build/torch-universal/triton_layer_norm/layer_norm.py
CHANGED
@@ -10,7 +10,7 @@ import math
 
 import torch
 import torch.nn.functional as F
-from torch.
+from torch.amp import custom_fwd, custom_bwd
 
 import triton
 import triton.language as tl
@@ -59,9 +59,9 @@ def layer_norm_ref(
         x = x + x1
     if residual is not None:
         x = (x + residual).to(x.dtype)
-    out = F.layer_norm(
-        dtype
-    )
+    out = F.layer_norm(
+        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
+    ).to(dtype)
     if weight1 is None:
         return out if not prenorm else (out, x)
     else:
@@ -115,13 +115,15 @@ def rms_norm_ref(
     if residual is not None:
         x = (x + residual).to(x.dtype)
     rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
-    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
+    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
+        dtype
+    )
     if weight1 is None:
         return out if not prenorm else (out, x)
     else:
-        out1 = (
-        )
+        out1 = (
+            (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
+        ).to(dtype)
         return (out, out1) if not prenorm else (out, out1, x)
 
 
@@ -201,7 +203,9 @@ def _layer_norm_fwd_1pass_kernel(
     if HAS_DROPOUT:
         # Compute dropout mask
         # 7 rounds is good enough, and reduces register pressure
-        keep_mask =
+        keep_mask = (
+            tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
+        )
         x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
         if STORE_DROPOUT_MASK:
             tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
@@ -214,7 +218,8 @@ def _layer_norm_fwd_1pass_kernel(
             # Compute dropout mask
             # 7 rounds is good enough, and reduces register pressure
             keep_mask = (
-                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
+                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
+                > dropout_p
             )
             x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
             if STORE_DROPOUT_MASK:
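Both dropout hunks follow the same inverted-dropout convention: draw the keep mask from the per-row seed with tl.rand(..., n_rounds=7), compare it against dropout_p, and rescale the kept elements by 1 / (1 - dropout_p). A plain-PyTorch reference of that convention, added here for comparison and not part of this diff:

import torch

def dropout_ref(x: torch.Tensor, dropout_p: float, generator=None) -> torch.Tensor:
    # Inverted dropout: keep each element with probability (1 - p) and rescale
    # the kept elements by 1 / (1 - p) so the output matches x in expectation.
    keep_mask = torch.rand(x.shape, device=x.device, generator=generator) > dropout_p
    return torch.where(keep_mask, x / (1.0 - dropout_p), torch.zeros_like(x))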
@@ -268,7 +273,7 @@ def _layer_norm_fwd(
     is_rms_norm=False,
     return_dropout_mask=False,
     out=None,
-    residual_out=None
+    residual_out=None,
 ):
     if residual is not None:
         residual_dtype = residual.dtype
@@ -315,14 +320,21 @@ def _layer_norm_fwd(
     ):
         if residual_out is None:
             residual_out = torch.empty(
-                M,
+                M,
+                N,
+                device=x.device,
+                dtype=residual_dtype if residual_dtype is not None else x.dtype,
             )
         else:
             assert residual_out.shape == x.shape
             assert residual_out.stride(-1) == 1
     else:
         residual_out = None
-    mean =
+    mean = (
+        torch.empty((M,), dtype=torch.float32, device=x.device)
+        if not is_rms_norm
+        else None
+    )
     rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
     if dropout_p > 0.0:
         seeds = torch.randint(
@@ -331,7 +343,9 @@ def _layer_norm_fwd(
     else:
         seeds = None
     if return_dropout_mask and dropout_p > 0.0:
-        dropout_mask = torch.empty(
+        dropout_mask = torch.empty(
+            M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
+        )
     else:
         dropout_mask = None
     # Less than 64KB per feature: enqueue fused kernel
@@ -401,7 +415,14 @@ def _layer_norm_fwd(
         triton.Config({}, num_warps=16),
         triton.Config({}, num_warps=32),
     ],
-    key=[
+    key=[
+        "N",
+        "HAS_DRESIDUAL",
+        "STORE_DRESIDUAL",
+        "IS_RMS_NORM",
+        "HAS_BIAS",
+        "HAS_DROPOUT",
+    ],
 )
 # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
 # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
@@ -529,14 +550,18 @@ def _layer_norm_bwd_kernel(
         if HAS_DX1:
             if HAS_DROPOUT:
                 keep_mask = (
-                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
+                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
+                    > dropout_p
                 )
                 dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
             else:
                 dx1 = dx
             tl.store(DX1 + cols, dx1, mask=mask)
         if HAS_DROPOUT:
-            keep_mask =
+            keep_mask = (
+                tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
+                > dropout_p
+            )
             dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
         if HAS_ROWSCALE:
             rowscale = tl.load(ROWSCALE + row).to(tl.float32)
@@ -627,9 +652,15 @@ def _layer_norm_bwd(
         else None
     )
     dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
-    y =
+    y = (
+        torch.empty(M, N, dtype=dy.dtype, device=dy.device)
+        if recompute_output
+        else None
+    )
     if recompute_output:
-        assert
+        assert (
+            weight1 is None
+        ), "recompute_output is not supported with parallel LayerNorm"
 
     # Less than 64KB per feature: enqueue fused kernel
     MAX_FUSED_SIZE = 65536 // x.element_size()
@@ -723,7 +754,7 @@ class LayerNormFn(torch.autograd.Function):
         is_rms_norm=False,
         return_dropout_mask=False,
         out=None,
-        residual_out=None
+        residual_out=None,
     ):
         x_shape_og = x.shape
         # reshape input data into 2D tensor
@@ -759,22 +790,24 @@ class LayerNormFn(torch.autograd.Function):
             out = out.reshape(-1, out.shape[-1])
         if residual_out is not None:
             residual_out = residual_out.reshape(-1, residual_out.shape[-1])
-        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 =
+        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
+            _layer_norm_fwd(
+                x,
+                weight,
+                bias,
+                eps,
+                residual,
+                x1,
+                weight1,
+                bias1,
+                dropout_p=dropout_p,
+                rowscale=rowscale,
+                residual_dtype=residual_dtype,
+                is_rms_norm=is_rms_norm,
+                return_dropout_mask=return_dropout_mask,
+                out=out,
+                residual_out=residual_out,
+            )
         )
         ctx.save_for_backward(
             residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
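This is the call the public layer_norm_fn/rms_norm_fn wrappers ultimately reach through LayerNormFn.apply. A hedged usage sketch of the prenorm + residual pattern those extra arguments support (the argument names appear in this diff; the shapes and the dropout_p value are made up):

import torch

from triton_layer_norm import layer_norm_fn

x = torch.randn(4, 2048, device="cuda", dtype=torch.bfloat16)
residual = torch.randn(4, 2048, device="cuda", dtype=torch.float32)
weight = torch.ones(2048, device="cuda", dtype=torch.bfloat16)
bias = torch.zeros(2048, device="cuda", dtype=torch.bfloat16)

# With prenorm=True the call also returns the updated residual stream
# (x + residual), which the next block can consume.
y, new_residual = layer_norm_fn(
    x,
    weight,
    bias,
    residual=residual,
    prenorm=True,
    dropout_p=0.1,
)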
@@ -789,9 +822,15 @@ class LayerNormFn(torch.autograd.Function):
         ctx.x_dtype = x.dtype
         y = y.reshape(x_shape_og)
         y1 = y1.reshape(x_shape_og) if y1 is not None else None
-        residual_out =
+        residual_out = (
+            residual_out.reshape(x_shape_og) if residual_out is not None else None
+        )
+        dropout_mask = (
+            dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
+        )
+        dropout_mask1 = (
+            dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
+        )
         if not return_dropout_mask:
             if weight1 is None:
                 return y if not prenorm else (y, residual_out)
@@ -890,7 +929,7 @@ def layer_norm_fn(
     is_rms_norm=False,
     return_dropout_mask=False,
     out=None,
-    residual_out=None
+    residual_out=None,
 ):
     return LayerNormFn.apply(
         x,
@@ -908,7 +947,7 @@ def layer_norm_fn(
         is_rms_norm,
         return_dropout_mask,
         out,
-        residual_out
+        residual_out,
     )
 
 
@@ -927,7 +966,7 @@ def rms_norm_fn(
     residual_in_fp32=False,
     return_dropout_mask=False,
     out=None,
-    residual_out=None
+    residual_out=None,
 ):
     return LayerNormFn.apply(
         x,
@@ -945,7 +984,7 @@ def rms_norm_fn(
         True,
         return_dropout_mask,
         out,
-        residual_out
+        residual_out,
     )
 
 
@@ -981,7 +1020,7 @@ class RMSNorm(torch.nn.Module):
 
 class LayerNormLinearFn(torch.autograd.Function):
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type="cuda")
     def forward(
         ctx,
         x,
@@ -1019,17 +1058,25 @@ class LayerNormLinearFn(torch.autograd.Function):
             norm_bias,
             eps,
             residual,
-            out_dtype=
+            out_dtype=(
+                None
+                if not torch.is_autocast_enabled()
+                else torch.get_autocast_gpu_dtype()
+            ),
             residual_dtype=residual_dtype,
             is_rms_norm=is_rms_norm,
         )
         y = y.reshape(x_shape_og)
-        dtype =
+        dtype = (
+            torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
+        )
         linear_weight = linear_weight.to(dtype)
         linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
         out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
         # We don't store y, will be recomputed in the backward pass to save memory
-        ctx.save_for_backward(
+        ctx.save_for_backward(
+            residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
+        )
         ctx.x_shape_og = x_shape_og
         ctx.eps = eps
         ctx.is_rms_norm = is_rms_norm
@@ -1040,7 +1087,7 @@ class LayerNormLinearFn(torch.autograd.Function):
         return out if not prenorm else (out, residual_out.reshape(x_shape_og))
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type="cuda")
     def backward(ctx, dout, *args):
         x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
         dout = dout.reshape(-1, dout.shape[-1])
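The decorator hunks above switch from the older parameterless @custom_fwd/@custom_bwd style (the removed import line is truncated in this view, presumably the torch.cuda.amp variant) to the torch.amp API that takes an explicit device_type. A self-contained sketch of that pattern on a made-up autograd Function, not code from this repo:

import torch
from torch.amp import custom_fwd, custom_bwd

class Scale(torch.autograd.Function):
    # custom_fwd/custom_bwd keep a custom autograd.Function well behaved under
    # autocast: backward runs with the same autocast state as forward.
    @staticmethod
    @custom_fwd(device_type="cuda")
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x * alpha

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_out):
        return grad_out * ctx.alpha, None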
build/torch-universal/triton_layer_norm/layers.py
ADDED
@@ -0,0 +1,4 @@
+from .layer_norm import RMSNorm
+
+
+__all__ = ["RMSNorm"]
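A usage sketch for the new layers module; only the import path comes from this diff, and the RMSNorm(hidden_size, eps=...) constructor shown is an assumption about the class defined in layer_norm.py:

import torch

from triton_layer_norm.layers import RMSNorm

# Hypothetical hidden size; assumes the common RMSNorm(hidden_size, eps=...) signature.
norm = RMSNorm(1024, eps=1e-6).to(device="cuda", dtype=torch.float16)
x = torch.randn(2, 16, 1024, device="cuda", dtype=torch.float16)
y = norm(x)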