Commit
·
0f75957
0
Parent(s):
Add Triton-based layer norm from flash-attention
Browse files- README.md +9 -0
- build.toml +5 -0
- flake.nix +14 -0
- torch-ext/triton_layer_norm/__init__.py +3 -0
- torch-ext/triton_layer_norm/layer_norm.py +1112 -0
README.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: bsd-3-clause
|
3 |
+
tags:
|
4 |
+
- kernel
|
5 |
+
---
|
6 |
+
|
7 |
+
## triton-layer-norm
|
8 |
+
|
9 |
+
Triton layer norm [from flash-attention](https://github.com/Dao-AILab/flash-attention).
|
build.toml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[general]
|
2 |
+
name = "triton_layer_norm"
|
3 |
+
|
4 |
+
[torch]
|
5 |
+
universal = true
|
flake.nix
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
description = "Flake for Triton layer norm kernels";
|
3 |
+
|
4 |
+
inputs = {
|
5 |
+
kernel-builder.url = "github:huggingface/kernel-builder";
|
6 |
+
};
|
7 |
+
|
8 |
+
outputs =
|
9 |
+
{
|
10 |
+
self,
|
11 |
+
kernel-builder,
|
12 |
+
}:
|
13 |
+
kernel-builder.lib.genFlakeOutputs ./.;
|
14 |
+
}
|
torch-ext/triton_layer_norm/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .layer_norm import RMSNorm, layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
|
2 |
+
|
3 |
+
__all__ = ["RMSNorm", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
|
torch-ext/triton_layer_norm/layer_norm.py
ADDED
@@ -0,0 +1,1112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024, Tri Dao.
|
2 |
+
# Implement dropout + residual + layer_norm / rms_norm.
|
3 |
+
|
4 |
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
5 |
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
6 |
+
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
7 |
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
8 |
+
|
9 |
+
import math
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.cuda.amp import custom_fwd, custom_bwd
|
14 |
+
|
15 |
+
import triton
|
16 |
+
import triton.language as tl
|
17 |
+
|
18 |
+
|
19 |
+
def layer_norm_ref(
|
20 |
+
x,
|
21 |
+
weight,
|
22 |
+
bias,
|
23 |
+
residual=None,
|
24 |
+
x1=None,
|
25 |
+
weight1=None,
|
26 |
+
bias1=None,
|
27 |
+
eps=1e-6,
|
28 |
+
dropout_p=0.0,
|
29 |
+
rowscale=None,
|
30 |
+
prenorm=False,
|
31 |
+
dropout_mask=None,
|
32 |
+
dropout_mask1=None,
|
33 |
+
upcast=False,
|
34 |
+
):
|
35 |
+
dtype = x.dtype
|
36 |
+
if upcast:
|
37 |
+
x = x.float()
|
38 |
+
weight = weight.float()
|
39 |
+
bias = bias.float() if bias is not None else None
|
40 |
+
residual = residual.float() if residual is not None else residual
|
41 |
+
x1 = x1.float() if x1 is not None else None
|
42 |
+
weight1 = weight1.float() if weight1 is not None else None
|
43 |
+
bias1 = bias1.float() if bias1 is not None else None
|
44 |
+
if x1 is not None:
|
45 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
46 |
+
if rowscale is not None:
|
47 |
+
x = x * rowscale[..., None]
|
48 |
+
if dropout_p > 0.0:
|
49 |
+
if dropout_mask is not None:
|
50 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
51 |
+
else:
|
52 |
+
x = F.dropout(x, p=dropout_p)
|
53 |
+
if x1 is not None:
|
54 |
+
if dropout_mask1 is not None:
|
55 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
56 |
+
else:
|
57 |
+
x1 = F.dropout(x1, p=dropout_p)
|
58 |
+
if x1 is not None:
|
59 |
+
x = x + x1
|
60 |
+
if residual is not None:
|
61 |
+
x = (x + residual).to(x.dtype)
|
62 |
+
out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
|
63 |
+
dtype
|
64 |
+
)
|
65 |
+
if weight1 is None:
|
66 |
+
return out if not prenorm else (out, x)
|
67 |
+
else:
|
68 |
+
out1 = F.layer_norm(
|
69 |
+
x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
|
70 |
+
).to(dtype)
|
71 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
72 |
+
|
73 |
+
|
74 |
+
def rms_norm_ref(
|
75 |
+
x,
|
76 |
+
weight,
|
77 |
+
bias,
|
78 |
+
residual=None,
|
79 |
+
x1=None,
|
80 |
+
weight1=None,
|
81 |
+
bias1=None,
|
82 |
+
eps=1e-6,
|
83 |
+
dropout_p=0.0,
|
84 |
+
rowscale=None,
|
85 |
+
prenorm=False,
|
86 |
+
dropout_mask=None,
|
87 |
+
dropout_mask1=None,
|
88 |
+
upcast=False,
|
89 |
+
):
|
90 |
+
dtype = x.dtype
|
91 |
+
if upcast:
|
92 |
+
x = x.float()
|
93 |
+
weight = weight.float()
|
94 |
+
bias = bias.float() if bias is not None else None
|
95 |
+
residual = residual.float() if residual is not None else residual
|
96 |
+
x1 = x1.float() if x1 is not None else None
|
97 |
+
weight1 = weight1.float() if weight1 is not None else None
|
98 |
+
bias1 = bias1.float() if bias1 is not None else None
|
99 |
+
if x1 is not None:
|
100 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
101 |
+
if rowscale is not None:
|
102 |
+
x = x * rowscale[..., None]
|
103 |
+
if dropout_p > 0.0:
|
104 |
+
if dropout_mask is not None:
|
105 |
+
x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
|
106 |
+
else:
|
107 |
+
x = F.dropout(x, p=dropout_p)
|
108 |
+
if x1 is not None:
|
109 |
+
if dropout_mask1 is not None:
|
110 |
+
x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
|
111 |
+
else:
|
112 |
+
x1 = F.dropout(x1, p=dropout_p)
|
113 |
+
if x1 is not None:
|
114 |
+
x = x + x1
|
115 |
+
if residual is not None:
|
116 |
+
x = (x + residual).to(x.dtype)
|
117 |
+
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
118 |
+
out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
|
119 |
+
if weight1 is None:
|
120 |
+
return out if not prenorm else (out, x)
|
121 |
+
else:
|
122 |
+
out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
|
123 |
+
dtype
|
124 |
+
)
|
125 |
+
return (out, out1) if not prenorm else (out, out1, x)
|
126 |
+
|
127 |
+
|
128 |
+
@triton.autotune(
|
129 |
+
configs=[
|
130 |
+
triton.Config({}, num_warps=1),
|
131 |
+
triton.Config({}, num_warps=2),
|
132 |
+
triton.Config({}, num_warps=4),
|
133 |
+
triton.Config({}, num_warps=8),
|
134 |
+
triton.Config({}, num_warps=16),
|
135 |
+
triton.Config({}, num_warps=32),
|
136 |
+
],
|
137 |
+
key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
|
138 |
+
)
|
139 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
140 |
+
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
141 |
+
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
|
142 |
+
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
|
143 |
+
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
|
144 |
+
@triton.jit
|
145 |
+
def _layer_norm_fwd_1pass_kernel(
|
146 |
+
X, # pointer to the input
|
147 |
+
Y, # pointer to the output
|
148 |
+
W, # pointer to the weights
|
149 |
+
B, # pointer to the biases
|
150 |
+
RESIDUAL, # pointer to the residual
|
151 |
+
X1,
|
152 |
+
W1,
|
153 |
+
B1,
|
154 |
+
Y1,
|
155 |
+
RESIDUAL_OUT, # pointer to the residual
|
156 |
+
ROWSCALE,
|
157 |
+
SEEDS, # Dropout seeds for each row
|
158 |
+
DROPOUT_MASK,
|
159 |
+
Mean, # pointer to the mean
|
160 |
+
Rstd, # pointer to the 1/std
|
161 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
162 |
+
stride_y_row,
|
163 |
+
stride_res_row,
|
164 |
+
stride_res_out_row,
|
165 |
+
stride_x1_row,
|
166 |
+
stride_y1_row,
|
167 |
+
M, # number of rows in X
|
168 |
+
N, # number of columns in X
|
169 |
+
eps, # epsilon to avoid division by zero
|
170 |
+
dropout_p, # Dropout probability
|
171 |
+
IS_RMS_NORM: tl.constexpr,
|
172 |
+
BLOCK_N: tl.constexpr,
|
173 |
+
HAS_RESIDUAL: tl.constexpr,
|
174 |
+
STORE_RESIDUAL_OUT: tl.constexpr,
|
175 |
+
HAS_BIAS: tl.constexpr,
|
176 |
+
HAS_DROPOUT: tl.constexpr,
|
177 |
+
STORE_DROPOUT_MASK: tl.constexpr,
|
178 |
+
HAS_ROWSCALE: tl.constexpr,
|
179 |
+
HAS_X1: tl.constexpr,
|
180 |
+
HAS_W1: tl.constexpr,
|
181 |
+
HAS_B1: tl.constexpr,
|
182 |
+
):
|
183 |
+
# Map the program id to the row of X and Y it should compute.
|
184 |
+
row = tl.program_id(0)
|
185 |
+
X += row * stride_x_row
|
186 |
+
Y += row * stride_y_row
|
187 |
+
if HAS_RESIDUAL:
|
188 |
+
RESIDUAL += row * stride_res_row
|
189 |
+
if STORE_RESIDUAL_OUT:
|
190 |
+
RESIDUAL_OUT += row * stride_res_out_row
|
191 |
+
if HAS_X1:
|
192 |
+
X1 += row * stride_x1_row
|
193 |
+
if HAS_W1:
|
194 |
+
Y1 += row * stride_y1_row
|
195 |
+
# Compute mean and variance
|
196 |
+
cols = tl.arange(0, BLOCK_N)
|
197 |
+
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
198 |
+
if HAS_ROWSCALE:
|
199 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
200 |
+
x *= rowscale
|
201 |
+
if HAS_DROPOUT:
|
202 |
+
# Compute dropout mask
|
203 |
+
# 7 rounds is good enough, and reduces register pressure
|
204 |
+
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
205 |
+
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
|
206 |
+
if STORE_DROPOUT_MASK:
|
207 |
+
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
|
208 |
+
if HAS_X1:
|
209 |
+
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
|
210 |
+
if HAS_ROWSCALE:
|
211 |
+
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
|
212 |
+
x1 *= rowscale
|
213 |
+
if HAS_DROPOUT:
|
214 |
+
# Compute dropout mask
|
215 |
+
# 7 rounds is good enough, and reduces register pressure
|
216 |
+
keep_mask = (
|
217 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
218 |
+
)
|
219 |
+
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
|
220 |
+
if STORE_DROPOUT_MASK:
|
221 |
+
tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
|
222 |
+
x += x1
|
223 |
+
if HAS_RESIDUAL:
|
224 |
+
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
|
225 |
+
x += residual
|
226 |
+
if STORE_RESIDUAL_OUT:
|
227 |
+
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
228 |
+
if not IS_RMS_NORM:
|
229 |
+
mean = tl.sum(x, axis=0) / N
|
230 |
+
tl.store(Mean + row, mean)
|
231 |
+
xbar = tl.where(cols < N, x - mean, 0.0)
|
232 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
233 |
+
else:
|
234 |
+
xbar = tl.where(cols < N, x, 0.0)
|
235 |
+
var = tl.sum(xbar * xbar, axis=0) / N
|
236 |
+
rstd = 1 / tl.sqrt(var + eps)
|
237 |
+
tl.store(Rstd + row, rstd)
|
238 |
+
# Normalize and apply linear transformation
|
239 |
+
mask = cols < N
|
240 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
241 |
+
if HAS_BIAS:
|
242 |
+
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
243 |
+
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
244 |
+
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
245 |
+
# Write output
|
246 |
+
tl.store(Y + cols, y, mask=mask)
|
247 |
+
if HAS_W1:
|
248 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
249 |
+
if HAS_B1:
|
250 |
+
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
|
251 |
+
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
|
252 |
+
tl.store(Y1 + cols, y1, mask=mask)
|
253 |
+
|
254 |
+
|
255 |
+
def _layer_norm_fwd(
|
256 |
+
x,
|
257 |
+
weight,
|
258 |
+
bias,
|
259 |
+
eps,
|
260 |
+
residual=None,
|
261 |
+
x1=None,
|
262 |
+
weight1=None,
|
263 |
+
bias1=None,
|
264 |
+
dropout_p=0.0,
|
265 |
+
rowscale=None,
|
266 |
+
out_dtype=None,
|
267 |
+
residual_dtype=None,
|
268 |
+
is_rms_norm=False,
|
269 |
+
return_dropout_mask=False,
|
270 |
+
out=None,
|
271 |
+
residual_out=None
|
272 |
+
):
|
273 |
+
if residual is not None:
|
274 |
+
residual_dtype = residual.dtype
|
275 |
+
M, N = x.shape
|
276 |
+
assert x.stride(-1) == 1
|
277 |
+
if residual is not None:
|
278 |
+
assert residual.stride(-1) == 1
|
279 |
+
assert residual.shape == (M, N)
|
280 |
+
assert weight.shape == (N,)
|
281 |
+
assert weight.stride(-1) == 1
|
282 |
+
if bias is not None:
|
283 |
+
assert bias.stride(-1) == 1
|
284 |
+
assert bias.shape == (N,)
|
285 |
+
if x1 is not None:
|
286 |
+
assert x1.shape == x.shape
|
287 |
+
assert rowscale is None
|
288 |
+
assert x1.stride(-1) == 1
|
289 |
+
if weight1 is not None:
|
290 |
+
assert weight1.shape == (N,)
|
291 |
+
assert weight1.stride(-1) == 1
|
292 |
+
if bias1 is not None:
|
293 |
+
assert bias1.shape == (N,)
|
294 |
+
assert bias1.stride(-1) == 1
|
295 |
+
if rowscale is not None:
|
296 |
+
assert rowscale.is_contiguous()
|
297 |
+
assert rowscale.shape == (M,)
|
298 |
+
# allocate output
|
299 |
+
if out is None:
|
300 |
+
out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
|
301 |
+
else:
|
302 |
+
assert out.shape == x.shape
|
303 |
+
assert out.stride(-1) == 1
|
304 |
+
if weight1 is not None:
|
305 |
+
y1 = torch.empty_like(out)
|
306 |
+
assert y1.stride(-1) == 1
|
307 |
+
else:
|
308 |
+
y1 = None
|
309 |
+
if (
|
310 |
+
residual is not None
|
311 |
+
or (residual_dtype is not None and residual_dtype != x.dtype)
|
312 |
+
or dropout_p > 0.0
|
313 |
+
or rowscale is not None
|
314 |
+
or x1 is not None
|
315 |
+
):
|
316 |
+
if residual_out is None:
|
317 |
+
residual_out = torch.empty(
|
318 |
+
M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
|
319 |
+
)
|
320 |
+
else:
|
321 |
+
assert residual_out.shape == x.shape
|
322 |
+
assert residual_out.stride(-1) == 1
|
323 |
+
else:
|
324 |
+
residual_out = None
|
325 |
+
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
|
326 |
+
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
|
327 |
+
if dropout_p > 0.0:
|
328 |
+
seeds = torch.randint(
|
329 |
+
2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
|
330 |
+
)
|
331 |
+
else:
|
332 |
+
seeds = None
|
333 |
+
if return_dropout_mask and dropout_p > 0.0:
|
334 |
+
dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
|
335 |
+
else:
|
336 |
+
dropout_mask = None
|
337 |
+
# Less than 64KB per feature: enqueue fused kernel
|
338 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
339 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
340 |
+
if N > BLOCK_N:
|
341 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
342 |
+
with torch.cuda.device(x.device.index):
|
343 |
+
_layer_norm_fwd_1pass_kernel[(M,)](
|
344 |
+
x,
|
345 |
+
out,
|
346 |
+
weight,
|
347 |
+
bias,
|
348 |
+
residual,
|
349 |
+
x1,
|
350 |
+
weight1,
|
351 |
+
bias1,
|
352 |
+
y1,
|
353 |
+
residual_out,
|
354 |
+
rowscale,
|
355 |
+
seeds,
|
356 |
+
dropout_mask,
|
357 |
+
mean,
|
358 |
+
rstd,
|
359 |
+
x.stride(0),
|
360 |
+
out.stride(0),
|
361 |
+
residual.stride(0) if residual is not None else 0,
|
362 |
+
residual_out.stride(0) if residual_out is not None else 0,
|
363 |
+
x1.stride(0) if x1 is not None else 0,
|
364 |
+
y1.stride(0) if y1 is not None else 0,
|
365 |
+
M,
|
366 |
+
N,
|
367 |
+
eps,
|
368 |
+
dropout_p,
|
369 |
+
is_rms_norm,
|
370 |
+
BLOCK_N,
|
371 |
+
residual is not None,
|
372 |
+
residual_out is not None,
|
373 |
+
bias is not None,
|
374 |
+
dropout_p > 0.0,
|
375 |
+
dropout_mask is not None,
|
376 |
+
rowscale is not None,
|
377 |
+
)
|
378 |
+
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
|
379 |
+
if dropout_mask is not None and x1 is not None:
|
380 |
+
dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
|
381 |
+
else:
|
382 |
+
dropout_mask1 = None
|
383 |
+
return (
|
384 |
+
out,
|
385 |
+
y1,
|
386 |
+
mean,
|
387 |
+
rstd,
|
388 |
+
residual_out if residual_out is not None else x,
|
389 |
+
seeds,
|
390 |
+
dropout_mask,
|
391 |
+
dropout_mask1,
|
392 |
+
)
|
393 |
+
|
394 |
+
|
395 |
+
@triton.autotune(
|
396 |
+
configs=[
|
397 |
+
triton.Config({}, num_warps=1),
|
398 |
+
triton.Config({}, num_warps=2),
|
399 |
+
triton.Config({}, num_warps=4),
|
400 |
+
triton.Config({}, num_warps=8),
|
401 |
+
triton.Config({}, num_warps=16),
|
402 |
+
triton.Config({}, num_warps=32),
|
403 |
+
],
|
404 |
+
key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
|
405 |
+
)
|
406 |
+
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
407 |
+
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
408 |
+
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
409 |
+
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
|
410 |
+
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
|
411 |
+
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
|
412 |
+
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
|
413 |
+
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
414 |
+
@triton.jit
|
415 |
+
def _layer_norm_bwd_kernel(
|
416 |
+
X, # pointer to the input
|
417 |
+
W, # pointer to the weights
|
418 |
+
B, # pointer to the biases
|
419 |
+
Y, # pointer to the output to be recomputed
|
420 |
+
DY, # pointer to the output gradient
|
421 |
+
DX, # pointer to the input gradient
|
422 |
+
DW, # pointer to the partial sum of weights gradient
|
423 |
+
DB, # pointer to the partial sum of biases gradient
|
424 |
+
DRESIDUAL,
|
425 |
+
W1,
|
426 |
+
DY1,
|
427 |
+
DX1,
|
428 |
+
DW1,
|
429 |
+
DB1,
|
430 |
+
DRESIDUAL_IN,
|
431 |
+
ROWSCALE,
|
432 |
+
SEEDS,
|
433 |
+
Mean, # pointer to the mean
|
434 |
+
Rstd, # pointer to the 1/std
|
435 |
+
stride_x_row, # how much to increase the pointer when moving by 1 row
|
436 |
+
stride_y_row,
|
437 |
+
stride_dy_row,
|
438 |
+
stride_dx_row,
|
439 |
+
stride_dres_row,
|
440 |
+
stride_dy1_row,
|
441 |
+
stride_dx1_row,
|
442 |
+
stride_dres_in_row,
|
443 |
+
M, # number of rows in X
|
444 |
+
N, # number of columns in X
|
445 |
+
eps, # epsilon to avoid division by zero
|
446 |
+
dropout_p,
|
447 |
+
rows_per_program,
|
448 |
+
IS_RMS_NORM: tl.constexpr,
|
449 |
+
BLOCK_N: tl.constexpr,
|
450 |
+
HAS_DRESIDUAL: tl.constexpr,
|
451 |
+
STORE_DRESIDUAL: tl.constexpr,
|
452 |
+
HAS_BIAS: tl.constexpr,
|
453 |
+
HAS_DROPOUT: tl.constexpr,
|
454 |
+
HAS_ROWSCALE: tl.constexpr,
|
455 |
+
HAS_DY1: tl.constexpr,
|
456 |
+
HAS_DX1: tl.constexpr,
|
457 |
+
HAS_B1: tl.constexpr,
|
458 |
+
RECOMPUTE_OUTPUT: tl.constexpr,
|
459 |
+
):
|
460 |
+
# Map the program id to the elements of X, DX, and DY it should compute.
|
461 |
+
row_block_id = tl.program_id(0)
|
462 |
+
row_start = row_block_id * rows_per_program
|
463 |
+
# Do not early exit if row_start >= M, because we need to write DW and DB
|
464 |
+
cols = tl.arange(0, BLOCK_N)
|
465 |
+
mask = cols < N
|
466 |
+
X += row_start * stride_x_row
|
467 |
+
if HAS_DRESIDUAL:
|
468 |
+
DRESIDUAL += row_start * stride_dres_row
|
469 |
+
if STORE_DRESIDUAL:
|
470 |
+
DRESIDUAL_IN += row_start * stride_dres_in_row
|
471 |
+
DY += row_start * stride_dy_row
|
472 |
+
DX += row_start * stride_dx_row
|
473 |
+
if HAS_DY1:
|
474 |
+
DY1 += row_start * stride_dy1_row
|
475 |
+
if HAS_DX1:
|
476 |
+
DX1 += row_start * stride_dx1_row
|
477 |
+
if RECOMPUTE_OUTPUT:
|
478 |
+
Y += row_start * stride_y_row
|
479 |
+
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
480 |
+
if RECOMPUTE_OUTPUT and HAS_BIAS:
|
481 |
+
b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
|
482 |
+
if HAS_DY1:
|
483 |
+
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
|
484 |
+
dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
485 |
+
if HAS_BIAS:
|
486 |
+
db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
487 |
+
if HAS_DY1:
|
488 |
+
dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
489 |
+
if HAS_B1:
|
490 |
+
db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
491 |
+
row_end = min((row_block_id + 1) * rows_per_program, M)
|
492 |
+
for row in range(row_start, row_end):
|
493 |
+
# Load data to SRAM
|
494 |
+
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
495 |
+
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
496 |
+
if HAS_DY1:
|
497 |
+
dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
|
498 |
+
if not IS_RMS_NORM:
|
499 |
+
mean = tl.load(Mean + row)
|
500 |
+
rstd = tl.load(Rstd + row)
|
501 |
+
# Compute dx
|
502 |
+
xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
503 |
+
xhat = tl.where(mask, xhat, 0.0)
|
504 |
+
if RECOMPUTE_OUTPUT:
|
505 |
+
y = xhat * w + b if HAS_BIAS else xhat * w
|
506 |
+
tl.store(Y + cols, y, mask=mask)
|
507 |
+
wdy = w * dy
|
508 |
+
dw += dy * xhat
|
509 |
+
if HAS_BIAS:
|
510 |
+
db += dy
|
511 |
+
if HAS_DY1:
|
512 |
+
wdy += w1 * dy1
|
513 |
+
dw1 += dy1 * xhat
|
514 |
+
if HAS_B1:
|
515 |
+
db1 += dy1
|
516 |
+
if not IS_RMS_NORM:
|
517 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
518 |
+
c2 = tl.sum(wdy, axis=0) / N
|
519 |
+
dx = (wdy - (xhat * c1 + c2)) * rstd
|
520 |
+
else:
|
521 |
+
c1 = tl.sum(xhat * wdy, axis=0) / N
|
522 |
+
dx = (wdy - xhat * c1) * rstd
|
523 |
+
if HAS_DRESIDUAL:
|
524 |
+
dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
525 |
+
dx += dres
|
526 |
+
# Write dx
|
527 |
+
if STORE_DRESIDUAL:
|
528 |
+
tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
529 |
+
if HAS_DX1:
|
530 |
+
if HAS_DROPOUT:
|
531 |
+
keep_mask = (
|
532 |
+
tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
533 |
+
)
|
534 |
+
dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
535 |
+
else:
|
536 |
+
dx1 = dx
|
537 |
+
tl.store(DX1 + cols, dx1, mask=mask)
|
538 |
+
if HAS_DROPOUT:
|
539 |
+
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
|
540 |
+
dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
|
541 |
+
if HAS_ROWSCALE:
|
542 |
+
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
|
543 |
+
dx *= rowscale
|
544 |
+
tl.store(DX + cols, dx, mask=mask)
|
545 |
+
|
546 |
+
X += stride_x_row
|
547 |
+
if HAS_DRESIDUAL:
|
548 |
+
DRESIDUAL += stride_dres_row
|
549 |
+
if STORE_DRESIDUAL:
|
550 |
+
DRESIDUAL_IN += stride_dres_in_row
|
551 |
+
if RECOMPUTE_OUTPUT:
|
552 |
+
Y += stride_y_row
|
553 |
+
DY += stride_dy_row
|
554 |
+
DX += stride_dx_row
|
555 |
+
if HAS_DY1:
|
556 |
+
DY1 += stride_dy1_row
|
557 |
+
if HAS_DX1:
|
558 |
+
DX1 += stride_dx1_row
|
559 |
+
tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
560 |
+
if HAS_BIAS:
|
561 |
+
tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
562 |
+
if HAS_DY1:
|
563 |
+
tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
|
564 |
+
if HAS_B1:
|
565 |
+
tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
|
566 |
+
|
567 |
+
|
568 |
+
def _layer_norm_bwd(
|
569 |
+
dy,
|
570 |
+
x,
|
571 |
+
weight,
|
572 |
+
bias,
|
573 |
+
eps,
|
574 |
+
mean,
|
575 |
+
rstd,
|
576 |
+
dresidual=None,
|
577 |
+
dy1=None,
|
578 |
+
weight1=None,
|
579 |
+
bias1=None,
|
580 |
+
seeds=None,
|
581 |
+
dropout_p=0.0,
|
582 |
+
rowscale=None,
|
583 |
+
has_residual=False,
|
584 |
+
has_x1=False,
|
585 |
+
is_rms_norm=False,
|
586 |
+
x_dtype=None,
|
587 |
+
recompute_output=False,
|
588 |
+
):
|
589 |
+
M, N = x.shape
|
590 |
+
assert x.stride(-1) == 1
|
591 |
+
assert dy.stride(-1) == 1
|
592 |
+
assert dy.shape == (M, N)
|
593 |
+
if dresidual is not None:
|
594 |
+
assert dresidual.stride(-1) == 1
|
595 |
+
assert dresidual.shape == (M, N)
|
596 |
+
assert weight.shape == (N,)
|
597 |
+
assert weight.stride(-1) == 1
|
598 |
+
if bias is not None:
|
599 |
+
assert bias.stride(-1) == 1
|
600 |
+
assert bias.shape == (N,)
|
601 |
+
if dy1 is not None:
|
602 |
+
assert weight1 is not None
|
603 |
+
assert dy1.shape == dy.shape
|
604 |
+
assert dy1.stride(-1) == 1
|
605 |
+
if weight1 is not None:
|
606 |
+
assert weight1.shape == (N,)
|
607 |
+
assert weight1.stride(-1) == 1
|
608 |
+
if bias1 is not None:
|
609 |
+
assert bias1.shape == (N,)
|
610 |
+
assert bias1.stride(-1) == 1
|
611 |
+
if seeds is not None:
|
612 |
+
assert seeds.is_contiguous()
|
613 |
+
assert seeds.shape == (M if not has_x1 else M * 2,)
|
614 |
+
if rowscale is not None:
|
615 |
+
assert rowscale.is_contiguous()
|
616 |
+
assert rowscale.shape == (M,)
|
617 |
+
# allocate output
|
618 |
+
dx = (
|
619 |
+
torch.empty_like(x)
|
620 |
+
if x_dtype is None
|
621 |
+
else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
622 |
+
)
|
623 |
+
dresidual_in = (
|
624 |
+
torch.empty_like(x)
|
625 |
+
if has_residual
|
626 |
+
and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
|
627 |
+
else None
|
628 |
+
)
|
629 |
+
dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
|
630 |
+
y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
|
631 |
+
if recompute_output:
|
632 |
+
assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
|
633 |
+
|
634 |
+
# Less than 64KB per feature: enqueue fused kernel
|
635 |
+
MAX_FUSED_SIZE = 65536 // x.element_size()
|
636 |
+
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
637 |
+
if N > BLOCK_N:
|
638 |
+
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
639 |
+
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
640 |
+
_dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
641 |
+
_db = (
|
642 |
+
torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
|
643 |
+
if bias is not None
|
644 |
+
else None
|
645 |
+
)
|
646 |
+
_dw1 = torch.empty_like(_dw) if weight1 is not None else None
|
647 |
+
_db1 = torch.empty_like(_db) if bias1 is not None else None
|
648 |
+
rows_per_program = math.ceil(M / sm_count)
|
649 |
+
grid = (sm_count,)
|
650 |
+
with torch.cuda.device(x.device.index):
|
651 |
+
_layer_norm_bwd_kernel[grid](
|
652 |
+
x,
|
653 |
+
weight,
|
654 |
+
bias,
|
655 |
+
y,
|
656 |
+
dy,
|
657 |
+
dx,
|
658 |
+
_dw,
|
659 |
+
_db,
|
660 |
+
dresidual,
|
661 |
+
weight1,
|
662 |
+
dy1,
|
663 |
+
dx1,
|
664 |
+
_dw1,
|
665 |
+
_db1,
|
666 |
+
dresidual_in,
|
667 |
+
rowscale,
|
668 |
+
seeds,
|
669 |
+
mean,
|
670 |
+
rstd,
|
671 |
+
x.stride(0),
|
672 |
+
0 if not recompute_output else y.stride(0),
|
673 |
+
dy.stride(0),
|
674 |
+
dx.stride(0),
|
675 |
+
dresidual.stride(0) if dresidual is not None else 0,
|
676 |
+
dy1.stride(0) if dy1 is not None else 0,
|
677 |
+
dx1.stride(0) if dx1 is not None else 0,
|
678 |
+
dresidual_in.stride(0) if dresidual_in is not None else 0,
|
679 |
+
M,
|
680 |
+
N,
|
681 |
+
eps,
|
682 |
+
dropout_p,
|
683 |
+
rows_per_program,
|
684 |
+
is_rms_norm,
|
685 |
+
BLOCK_N,
|
686 |
+
dresidual is not None,
|
687 |
+
dresidual_in is not None,
|
688 |
+
bias is not None,
|
689 |
+
dropout_p > 0.0,
|
690 |
+
)
|
691 |
+
dw = _dw.sum(0).to(weight.dtype)
|
692 |
+
db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
693 |
+
dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
|
694 |
+
db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
|
695 |
+
# Don't need to compute dresidual_in separately in this case
|
696 |
+
if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
|
697 |
+
dresidual_in = dx
|
698 |
+
if has_x1 and dropout_p == 0.0:
|
699 |
+
dx1 = dx
|
700 |
+
return (
|
701 |
+
(dx, dw, db, dresidual_in, dx1, dw1, db1)
|
702 |
+
if not recompute_output
|
703 |
+
else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
|
704 |
+
)
|
705 |
+
|
706 |
+
|
707 |
+
class LayerNormFn(torch.autograd.Function):
|
708 |
+
@staticmethod
|
709 |
+
def forward(
|
710 |
+
ctx,
|
711 |
+
x,
|
712 |
+
weight,
|
713 |
+
bias,
|
714 |
+
residual=None,
|
715 |
+
x1=None,
|
716 |
+
weight1=None,
|
717 |
+
bias1=None,
|
718 |
+
eps=1e-6,
|
719 |
+
dropout_p=0.0,
|
720 |
+
rowscale=None,
|
721 |
+
prenorm=False,
|
722 |
+
residual_in_fp32=False,
|
723 |
+
is_rms_norm=False,
|
724 |
+
return_dropout_mask=False,
|
725 |
+
out=None,
|
726 |
+
residual_out=None
|
727 |
+
):
|
728 |
+
x_shape_og = x.shape
|
729 |
+
# reshape input data into 2D tensor
|
730 |
+
x = x.reshape(-1, x.shape[-1])
|
731 |
+
if x.stride(-1) != 1:
|
732 |
+
x = x.contiguous()
|
733 |
+
if residual is not None:
|
734 |
+
assert residual.shape == x_shape_og
|
735 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
736 |
+
if residual.stride(-1) != 1:
|
737 |
+
residual = residual.contiguous()
|
738 |
+
if x1 is not None:
|
739 |
+
assert x1.shape == x_shape_og
|
740 |
+
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
|
741 |
+
x1 = x1.reshape(-1, x1.shape[-1])
|
742 |
+
if x1.stride(-1) != 1:
|
743 |
+
x1 = x1.contiguous()
|
744 |
+
weight = weight.contiguous()
|
745 |
+
if bias is not None:
|
746 |
+
bias = bias.contiguous()
|
747 |
+
if weight1 is not None:
|
748 |
+
weight1 = weight1.contiguous()
|
749 |
+
if bias1 is not None:
|
750 |
+
bias1 = bias1.contiguous()
|
751 |
+
if rowscale is not None:
|
752 |
+
rowscale = rowscale.reshape(-1).contiguous()
|
753 |
+
residual_dtype = (
|
754 |
+
residual.dtype
|
755 |
+
if residual is not None
|
756 |
+
else (torch.float32 if residual_in_fp32 else None)
|
757 |
+
)
|
758 |
+
if out is not None:
|
759 |
+
out = out.reshape(-1, out.shape[-1])
|
760 |
+
if residual_out is not None:
|
761 |
+
residual_out = residual_out.reshape(-1, residual_out.shape[-1])
|
762 |
+
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
|
763 |
+
x,
|
764 |
+
weight,
|
765 |
+
bias,
|
766 |
+
eps,
|
767 |
+
residual,
|
768 |
+
x1,
|
769 |
+
weight1,
|
770 |
+
bias1,
|
771 |
+
dropout_p=dropout_p,
|
772 |
+
rowscale=rowscale,
|
773 |
+
residual_dtype=residual_dtype,
|
774 |
+
is_rms_norm=is_rms_norm,
|
775 |
+
return_dropout_mask=return_dropout_mask,
|
776 |
+
out=out,
|
777 |
+
residual_out=residual_out
|
778 |
+
)
|
779 |
+
ctx.save_for_backward(
|
780 |
+
residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
|
781 |
+
)
|
782 |
+
ctx.x_shape_og = x_shape_og
|
783 |
+
ctx.eps = eps
|
784 |
+
ctx.dropout_p = dropout_p
|
785 |
+
ctx.is_rms_norm = is_rms_norm
|
786 |
+
ctx.has_residual = residual is not None
|
787 |
+
ctx.has_x1 = x1 is not None
|
788 |
+
ctx.prenorm = prenorm
|
789 |
+
ctx.x_dtype = x.dtype
|
790 |
+
y = y.reshape(x_shape_og)
|
791 |
+
y1 = y1.reshape(x_shape_og) if y1 is not None else None
|
792 |
+
residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
|
793 |
+
dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
|
794 |
+
dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
|
795 |
+
if not return_dropout_mask:
|
796 |
+
if weight1 is None:
|
797 |
+
return y if not prenorm else (y, residual_out)
|
798 |
+
else:
|
799 |
+
return (y, y1) if not prenorm else (y, y1, residual_out)
|
800 |
+
else:
|
801 |
+
if weight1 is None:
|
802 |
+
return (
|
803 |
+
(y, dropout_mask, dropout_mask1)
|
804 |
+
if not prenorm
|
805 |
+
else (y, residual_out, dropout_mask, dropout_mask1)
|
806 |
+
)
|
807 |
+
else:
|
808 |
+
return (
|
809 |
+
(y, y1, dropout_mask, dropout_mask1)
|
810 |
+
if not prenorm
|
811 |
+
else (y, y1, residual_out, dropout_mask, dropout_mask1)
|
812 |
+
)
|
813 |
+
|
814 |
+
@staticmethod
|
815 |
+
def backward(ctx, dy, *args):
|
816 |
+
x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
|
817 |
+
dy = dy.reshape(-1, dy.shape[-1])
|
818 |
+
if dy.stride(-1) != 1:
|
819 |
+
dy = dy.contiguous()
|
820 |
+
assert dy.shape == x.shape
|
821 |
+
if weight1 is not None:
|
822 |
+
dy1, args = args[0], args[1:]
|
823 |
+
dy1 = dy1.reshape(-1, dy1.shape[-1])
|
824 |
+
if dy1.stride(-1) != 1:
|
825 |
+
dy1 = dy1.contiguous()
|
826 |
+
assert dy1.shape == x.shape
|
827 |
+
else:
|
828 |
+
dy1 = None
|
829 |
+
if ctx.prenorm:
|
830 |
+
dresidual = args[0]
|
831 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
832 |
+
if dresidual.stride(-1) != 1:
|
833 |
+
dresidual = dresidual.contiguous()
|
834 |
+
assert dresidual.shape == x.shape
|
835 |
+
else:
|
836 |
+
dresidual = None
|
837 |
+
dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
|
838 |
+
dy,
|
839 |
+
x,
|
840 |
+
weight,
|
841 |
+
bias,
|
842 |
+
ctx.eps,
|
843 |
+
mean,
|
844 |
+
rstd,
|
845 |
+
dresidual,
|
846 |
+
dy1,
|
847 |
+
weight1,
|
848 |
+
bias1,
|
849 |
+
seeds,
|
850 |
+
ctx.dropout_p,
|
851 |
+
rowscale,
|
852 |
+
ctx.has_residual,
|
853 |
+
ctx.has_x1,
|
854 |
+
ctx.is_rms_norm,
|
855 |
+
x_dtype=ctx.x_dtype,
|
856 |
+
)
|
857 |
+
return (
|
858 |
+
dx.reshape(ctx.x_shape_og),
|
859 |
+
dw,
|
860 |
+
db,
|
861 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
862 |
+
dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
|
863 |
+
dw1,
|
864 |
+
db1,
|
865 |
+
None,
|
866 |
+
None,
|
867 |
+
None,
|
868 |
+
None,
|
869 |
+
None,
|
870 |
+
None,
|
871 |
+
None,
|
872 |
+
None,
|
873 |
+
None,
|
874 |
+
)
|
875 |
+
|
876 |
+
|
877 |
+
def layer_norm_fn(
|
878 |
+
x,
|
879 |
+
weight,
|
880 |
+
bias,
|
881 |
+
residual=None,
|
882 |
+
x1=None,
|
883 |
+
weight1=None,
|
884 |
+
bias1=None,
|
885 |
+
eps=1e-6,
|
886 |
+
dropout_p=0.0,
|
887 |
+
rowscale=None,
|
888 |
+
prenorm=False,
|
889 |
+
residual_in_fp32=False,
|
890 |
+
is_rms_norm=False,
|
891 |
+
return_dropout_mask=False,
|
892 |
+
out=None,
|
893 |
+
residual_out=None
|
894 |
+
):
|
895 |
+
return LayerNormFn.apply(
|
896 |
+
x,
|
897 |
+
weight,
|
898 |
+
bias,
|
899 |
+
residual,
|
900 |
+
x1,
|
901 |
+
weight1,
|
902 |
+
bias1,
|
903 |
+
eps,
|
904 |
+
dropout_p,
|
905 |
+
rowscale,
|
906 |
+
prenorm,
|
907 |
+
residual_in_fp32,
|
908 |
+
is_rms_norm,
|
909 |
+
return_dropout_mask,
|
910 |
+
out,
|
911 |
+
residual_out
|
912 |
+
)
|
913 |
+
|
914 |
+
|
915 |
+
def rms_norm_fn(
|
916 |
+
x,
|
917 |
+
weight,
|
918 |
+
bias,
|
919 |
+
residual=None,
|
920 |
+
x1=None,
|
921 |
+
weight1=None,
|
922 |
+
bias1=None,
|
923 |
+
eps=1e-6,
|
924 |
+
dropout_p=0.0,
|
925 |
+
rowscale=None,
|
926 |
+
prenorm=False,
|
927 |
+
residual_in_fp32=False,
|
928 |
+
return_dropout_mask=False,
|
929 |
+
out=None,
|
930 |
+
residual_out=None
|
931 |
+
):
|
932 |
+
return LayerNormFn.apply(
|
933 |
+
x,
|
934 |
+
weight,
|
935 |
+
bias,
|
936 |
+
residual,
|
937 |
+
x1,
|
938 |
+
weight1,
|
939 |
+
bias1,
|
940 |
+
eps,
|
941 |
+
dropout_p,
|
942 |
+
rowscale,
|
943 |
+
prenorm,
|
944 |
+
residual_in_fp32,
|
945 |
+
True,
|
946 |
+
return_dropout_mask,
|
947 |
+
out,
|
948 |
+
residual_out
|
949 |
+
)
|
950 |
+
|
951 |
+
|
952 |
+
class RMSNorm(torch.nn.Module):
|
953 |
+
|
954 |
+
def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
|
955 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
956 |
+
super().__init__()
|
957 |
+
self.eps = eps
|
958 |
+
if dropout_p > 0.0:
|
959 |
+
self.drop = torch.nn.Dropout(dropout_p)
|
960 |
+
else:
|
961 |
+
self.drop = None
|
962 |
+
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
963 |
+
self.register_parameter("bias", None)
|
964 |
+
self.reset_parameters()
|
965 |
+
|
966 |
+
def reset_parameters(self):
|
967 |
+
torch.nn.init.ones_(self.weight)
|
968 |
+
|
969 |
+
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
970 |
+
return rms_norm_fn(
|
971 |
+
x,
|
972 |
+
self.weight,
|
973 |
+
self.bias,
|
974 |
+
residual=residual,
|
975 |
+
eps=self.eps,
|
976 |
+
dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
|
977 |
+
prenorm=prenorm,
|
978 |
+
residual_in_fp32=residual_in_fp32,
|
979 |
+
)
|
980 |
+
|
981 |
+
|
982 |
+
class LayerNormLinearFn(torch.autograd.Function):
|
983 |
+
@staticmethod
|
984 |
+
@custom_fwd
|
985 |
+
def forward(
|
986 |
+
ctx,
|
987 |
+
x,
|
988 |
+
norm_weight,
|
989 |
+
norm_bias,
|
990 |
+
linear_weight,
|
991 |
+
linear_bias,
|
992 |
+
residual=None,
|
993 |
+
eps=1e-6,
|
994 |
+
prenorm=False,
|
995 |
+
residual_in_fp32=False,
|
996 |
+
is_rms_norm=False,
|
997 |
+
):
|
998 |
+
x_shape_og = x.shape
|
999 |
+
# reshape input data into 2D tensor
|
1000 |
+
x = x.reshape(-1, x.shape[-1])
|
1001 |
+
if x.stride(-1) != 1:
|
1002 |
+
x = x.contiguous()
|
1003 |
+
if residual is not None:
|
1004 |
+
assert residual.shape == x_shape_og
|
1005 |
+
residual = residual.reshape(-1, residual.shape[-1])
|
1006 |
+
if residual.stride(-1) != 1:
|
1007 |
+
residual = residual.contiguous()
|
1008 |
+
norm_weight = norm_weight.contiguous()
|
1009 |
+
if norm_bias is not None:
|
1010 |
+
norm_bias = norm_bias.contiguous()
|
1011 |
+
residual_dtype = (
|
1012 |
+
residual.dtype
|
1013 |
+
if residual is not None
|
1014 |
+
else (torch.float32 if residual_in_fp32 else None)
|
1015 |
+
)
|
1016 |
+
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
|
1017 |
+
x,
|
1018 |
+
norm_weight,
|
1019 |
+
norm_bias,
|
1020 |
+
eps,
|
1021 |
+
residual,
|
1022 |
+
out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
|
1023 |
+
residual_dtype=residual_dtype,
|
1024 |
+
is_rms_norm=is_rms_norm,
|
1025 |
+
)
|
1026 |
+
y = y.reshape(x_shape_og)
|
1027 |
+
dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
|
1028 |
+
linear_weight = linear_weight.to(dtype)
|
1029 |
+
linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
|
1030 |
+
out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
|
1031 |
+
# We don't store y, will be recomputed in the backward pass to save memory
|
1032 |
+
ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
|
1033 |
+
ctx.x_shape_og = x_shape_og
|
1034 |
+
ctx.eps = eps
|
1035 |
+
ctx.is_rms_norm = is_rms_norm
|
1036 |
+
ctx.has_residual = residual is not None
|
1037 |
+
ctx.prenorm = prenorm
|
1038 |
+
ctx.x_dtype = x.dtype
|
1039 |
+
ctx.linear_bias_is_none = linear_bias is None
|
1040 |
+
return out if not prenorm else (out, residual_out.reshape(x_shape_og))
|
1041 |
+
|
1042 |
+
@staticmethod
|
1043 |
+
@custom_bwd
|
1044 |
+
def backward(ctx, dout, *args):
|
1045 |
+
x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
|
1046 |
+
dout = dout.reshape(-1, dout.shape[-1])
|
1047 |
+
dy = F.linear(dout, linear_weight.t())
|
1048 |
+
dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
|
1049 |
+
if dy.stride(-1) != 1:
|
1050 |
+
dy = dy.contiguous()
|
1051 |
+
assert dy.shape == x.shape
|
1052 |
+
if ctx.prenorm:
|
1053 |
+
dresidual = args[0]
|
1054 |
+
dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
1055 |
+
if dresidual.stride(-1) != 1:
|
1056 |
+
dresidual = dresidual.contiguous()
|
1057 |
+
assert dresidual.shape == x.shape
|
1058 |
+
else:
|
1059 |
+
dresidual = None
|
1060 |
+
dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
|
1061 |
+
dy,
|
1062 |
+
x,
|
1063 |
+
norm_weight,
|
1064 |
+
norm_bias,
|
1065 |
+
ctx.eps,
|
1066 |
+
mean,
|
1067 |
+
rstd,
|
1068 |
+
dresidual=dresidual,
|
1069 |
+
has_residual=ctx.has_residual,
|
1070 |
+
is_rms_norm=ctx.is_rms_norm,
|
1071 |
+
x_dtype=ctx.x_dtype,
|
1072 |
+
recompute_output=True,
|
1073 |
+
)
|
1074 |
+
dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
|
1075 |
+
return (
|
1076 |
+
dx.reshape(ctx.x_shape_og),
|
1077 |
+
dnorm_weight,
|
1078 |
+
dnorm_bias,
|
1079 |
+
dlinear_weight,
|
1080 |
+
dlinear_bias,
|
1081 |
+
dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
|
1082 |
+
None,
|
1083 |
+
None,
|
1084 |
+
None,
|
1085 |
+
None,
|
1086 |
+
)
|
1087 |
+
|
1088 |
+
|
1089 |
+
def layer_norm_linear_fn(
|
1090 |
+
x,
|
1091 |
+
norm_weight,
|
1092 |
+
norm_bias,
|
1093 |
+
linear_weight,
|
1094 |
+
linear_bias,
|
1095 |
+
residual=None,
|
1096 |
+
eps=1e-6,
|
1097 |
+
prenorm=False,
|
1098 |
+
residual_in_fp32=False,
|
1099 |
+
is_rms_norm=False,
|
1100 |
+
):
|
1101 |
+
return LayerNormLinearFn.apply(
|
1102 |
+
x,
|
1103 |
+
norm_weight,
|
1104 |
+
norm_bias,
|
1105 |
+
linear_weight,
|
1106 |
+
linear_bias,
|
1107 |
+
residual,
|
1108 |
+
eps,
|
1109 |
+
prenorm,
|
1110 |
+
residual_in_fp32,
|
1111 |
+
is_rms_norm,
|
1112 |
+
)
|