danieldk HF staff commited on
Commit
0f75957
·
0 Parent(s):

Add Triton-based layer norm from flash-attention

Browse files
README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: bsd-3-clause
3
+ tags:
4
+ - kernel
5
+ ---
6
+
7
+ ## triton-layer-norm
8
+
9
+ Triton layer norm [from flash-attention](https://github.com/Dao-AILab/flash-attention).
build.toml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [general]
2
+ name = "triton_layer_norm"
3
+
4
+ [torch]
5
+ universal = true
flake.nix ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ description = "Flake for Triton layer norm kernels";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs ./.;
14
+ }
torch-ext/triton_layer_norm/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .layer_norm import RMSNorm, layer_norm_fn, layer_norm_linear_fn, rms_norm_fn
2
+
3
+ __all__ = ["RMSNorm", "layer_norm_fn", "layer_norm_linear_fn", "rms_norm_fn"]
torch-ext/triton_layer_norm/layer_norm.py ADDED
@@ -0,0 +1,1112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Implement dropout + residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.cuda.amp import custom_fwd, custom_bwd
14
+
15
+ import triton
16
+ import triton.language as tl
17
+
18
+
19
+ def layer_norm_ref(
20
+ x,
21
+ weight,
22
+ bias,
23
+ residual=None,
24
+ x1=None,
25
+ weight1=None,
26
+ bias1=None,
27
+ eps=1e-6,
28
+ dropout_p=0.0,
29
+ rowscale=None,
30
+ prenorm=False,
31
+ dropout_mask=None,
32
+ dropout_mask1=None,
33
+ upcast=False,
34
+ ):
35
+ dtype = x.dtype
36
+ if upcast:
37
+ x = x.float()
38
+ weight = weight.float()
39
+ bias = bias.float() if bias is not None else None
40
+ residual = residual.float() if residual is not None else residual
41
+ x1 = x1.float() if x1 is not None else None
42
+ weight1 = weight1.float() if weight1 is not None else None
43
+ bias1 = bias1.float() if bias1 is not None else None
44
+ if x1 is not None:
45
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
46
+ if rowscale is not None:
47
+ x = x * rowscale[..., None]
48
+ if dropout_p > 0.0:
49
+ if dropout_mask is not None:
50
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
51
+ else:
52
+ x = F.dropout(x, p=dropout_p)
53
+ if x1 is not None:
54
+ if dropout_mask1 is not None:
55
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
56
+ else:
57
+ x1 = F.dropout(x1, p=dropout_p)
58
+ if x1 is not None:
59
+ x = x + x1
60
+ if residual is not None:
61
+ x = (x + residual).to(x.dtype)
62
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
63
+ dtype
64
+ )
65
+ if weight1 is None:
66
+ return out if not prenorm else (out, x)
67
+ else:
68
+ out1 = F.layer_norm(
69
+ x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
70
+ ).to(dtype)
71
+ return (out, out1) if not prenorm else (out, out1, x)
72
+
73
+
74
+ def rms_norm_ref(
75
+ x,
76
+ weight,
77
+ bias,
78
+ residual=None,
79
+ x1=None,
80
+ weight1=None,
81
+ bias1=None,
82
+ eps=1e-6,
83
+ dropout_p=0.0,
84
+ rowscale=None,
85
+ prenorm=False,
86
+ dropout_mask=None,
87
+ dropout_mask1=None,
88
+ upcast=False,
89
+ ):
90
+ dtype = x.dtype
91
+ if upcast:
92
+ x = x.float()
93
+ weight = weight.float()
94
+ bias = bias.float() if bias is not None else None
95
+ residual = residual.float() if residual is not None else residual
96
+ x1 = x1.float() if x1 is not None else None
97
+ weight1 = weight1.float() if weight1 is not None else None
98
+ bias1 = bias1.float() if bias1 is not None else None
99
+ if x1 is not None:
100
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
101
+ if rowscale is not None:
102
+ x = x * rowscale[..., None]
103
+ if dropout_p > 0.0:
104
+ if dropout_mask is not None:
105
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
106
+ else:
107
+ x = F.dropout(x, p=dropout_p)
108
+ if x1 is not None:
109
+ if dropout_mask1 is not None:
110
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
111
+ else:
112
+ x1 = F.dropout(x1, p=dropout_p)
113
+ if x1 is not None:
114
+ x = x + x1
115
+ if residual is not None:
116
+ x = (x + residual).to(x.dtype)
117
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
118
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
119
+ if weight1 is None:
120
+ return out if not prenorm else (out, x)
121
+ else:
122
+ out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
123
+ dtype
124
+ )
125
+ return (out, out1) if not prenorm else (out, out1, x)
126
+
127
+
128
+ @triton.autotune(
129
+ configs=[
130
+ triton.Config({}, num_warps=1),
131
+ triton.Config({}, num_warps=2),
132
+ triton.Config({}, num_warps=4),
133
+ triton.Config({}, num_warps=8),
134
+ triton.Config({}, num_warps=16),
135
+ triton.Config({}, num_warps=32),
136
+ ],
137
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
138
+ )
139
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
140
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
141
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
142
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
143
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
144
+ @triton.jit
145
+ def _layer_norm_fwd_1pass_kernel(
146
+ X, # pointer to the input
147
+ Y, # pointer to the output
148
+ W, # pointer to the weights
149
+ B, # pointer to the biases
150
+ RESIDUAL, # pointer to the residual
151
+ X1,
152
+ W1,
153
+ B1,
154
+ Y1,
155
+ RESIDUAL_OUT, # pointer to the residual
156
+ ROWSCALE,
157
+ SEEDS, # Dropout seeds for each row
158
+ DROPOUT_MASK,
159
+ Mean, # pointer to the mean
160
+ Rstd, # pointer to the 1/std
161
+ stride_x_row, # how much to increase the pointer when moving by 1 row
162
+ stride_y_row,
163
+ stride_res_row,
164
+ stride_res_out_row,
165
+ stride_x1_row,
166
+ stride_y1_row,
167
+ M, # number of rows in X
168
+ N, # number of columns in X
169
+ eps, # epsilon to avoid division by zero
170
+ dropout_p, # Dropout probability
171
+ IS_RMS_NORM: tl.constexpr,
172
+ BLOCK_N: tl.constexpr,
173
+ HAS_RESIDUAL: tl.constexpr,
174
+ STORE_RESIDUAL_OUT: tl.constexpr,
175
+ HAS_BIAS: tl.constexpr,
176
+ HAS_DROPOUT: tl.constexpr,
177
+ STORE_DROPOUT_MASK: tl.constexpr,
178
+ HAS_ROWSCALE: tl.constexpr,
179
+ HAS_X1: tl.constexpr,
180
+ HAS_W1: tl.constexpr,
181
+ HAS_B1: tl.constexpr,
182
+ ):
183
+ # Map the program id to the row of X and Y it should compute.
184
+ row = tl.program_id(0)
185
+ X += row * stride_x_row
186
+ Y += row * stride_y_row
187
+ if HAS_RESIDUAL:
188
+ RESIDUAL += row * stride_res_row
189
+ if STORE_RESIDUAL_OUT:
190
+ RESIDUAL_OUT += row * stride_res_out_row
191
+ if HAS_X1:
192
+ X1 += row * stride_x1_row
193
+ if HAS_W1:
194
+ Y1 += row * stride_y1_row
195
+ # Compute mean and variance
196
+ cols = tl.arange(0, BLOCK_N)
197
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
198
+ if HAS_ROWSCALE:
199
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
200
+ x *= rowscale
201
+ if HAS_DROPOUT:
202
+ # Compute dropout mask
203
+ # 7 rounds is good enough, and reduces register pressure
204
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
205
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
206
+ if STORE_DROPOUT_MASK:
207
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
208
+ if HAS_X1:
209
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
210
+ if HAS_ROWSCALE:
211
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
212
+ x1 *= rowscale
213
+ if HAS_DROPOUT:
214
+ # Compute dropout mask
215
+ # 7 rounds is good enough, and reduces register pressure
216
+ keep_mask = (
217
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
218
+ )
219
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
220
+ if STORE_DROPOUT_MASK:
221
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
222
+ x += x1
223
+ if HAS_RESIDUAL:
224
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
225
+ x += residual
226
+ if STORE_RESIDUAL_OUT:
227
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
228
+ if not IS_RMS_NORM:
229
+ mean = tl.sum(x, axis=0) / N
230
+ tl.store(Mean + row, mean)
231
+ xbar = tl.where(cols < N, x - mean, 0.0)
232
+ var = tl.sum(xbar * xbar, axis=0) / N
233
+ else:
234
+ xbar = tl.where(cols < N, x, 0.0)
235
+ var = tl.sum(xbar * xbar, axis=0) / N
236
+ rstd = 1 / tl.sqrt(var + eps)
237
+ tl.store(Rstd + row, rstd)
238
+ # Normalize and apply linear transformation
239
+ mask = cols < N
240
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
241
+ if HAS_BIAS:
242
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
243
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
244
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
245
+ # Write output
246
+ tl.store(Y + cols, y, mask=mask)
247
+ if HAS_W1:
248
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
249
+ if HAS_B1:
250
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
251
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
252
+ tl.store(Y1 + cols, y1, mask=mask)
253
+
254
+
255
+ def _layer_norm_fwd(
256
+ x,
257
+ weight,
258
+ bias,
259
+ eps,
260
+ residual=None,
261
+ x1=None,
262
+ weight1=None,
263
+ bias1=None,
264
+ dropout_p=0.0,
265
+ rowscale=None,
266
+ out_dtype=None,
267
+ residual_dtype=None,
268
+ is_rms_norm=False,
269
+ return_dropout_mask=False,
270
+ out=None,
271
+ residual_out=None
272
+ ):
273
+ if residual is not None:
274
+ residual_dtype = residual.dtype
275
+ M, N = x.shape
276
+ assert x.stride(-1) == 1
277
+ if residual is not None:
278
+ assert residual.stride(-1) == 1
279
+ assert residual.shape == (M, N)
280
+ assert weight.shape == (N,)
281
+ assert weight.stride(-1) == 1
282
+ if bias is not None:
283
+ assert bias.stride(-1) == 1
284
+ assert bias.shape == (N,)
285
+ if x1 is not None:
286
+ assert x1.shape == x.shape
287
+ assert rowscale is None
288
+ assert x1.stride(-1) == 1
289
+ if weight1 is not None:
290
+ assert weight1.shape == (N,)
291
+ assert weight1.stride(-1) == 1
292
+ if bias1 is not None:
293
+ assert bias1.shape == (N,)
294
+ assert bias1.stride(-1) == 1
295
+ if rowscale is not None:
296
+ assert rowscale.is_contiguous()
297
+ assert rowscale.shape == (M,)
298
+ # allocate output
299
+ if out is None:
300
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
301
+ else:
302
+ assert out.shape == x.shape
303
+ assert out.stride(-1) == 1
304
+ if weight1 is not None:
305
+ y1 = torch.empty_like(out)
306
+ assert y1.stride(-1) == 1
307
+ else:
308
+ y1 = None
309
+ if (
310
+ residual is not None
311
+ or (residual_dtype is not None and residual_dtype != x.dtype)
312
+ or dropout_p > 0.0
313
+ or rowscale is not None
314
+ or x1 is not None
315
+ ):
316
+ if residual_out is None:
317
+ residual_out = torch.empty(
318
+ M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
319
+ )
320
+ else:
321
+ assert residual_out.shape == x.shape
322
+ assert residual_out.stride(-1) == 1
323
+ else:
324
+ residual_out = None
325
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
326
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
327
+ if dropout_p > 0.0:
328
+ seeds = torch.randint(
329
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
330
+ )
331
+ else:
332
+ seeds = None
333
+ if return_dropout_mask and dropout_p > 0.0:
334
+ dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
335
+ else:
336
+ dropout_mask = None
337
+ # Less than 64KB per feature: enqueue fused kernel
338
+ MAX_FUSED_SIZE = 65536 // x.element_size()
339
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
340
+ if N > BLOCK_N:
341
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
342
+ with torch.cuda.device(x.device.index):
343
+ _layer_norm_fwd_1pass_kernel[(M,)](
344
+ x,
345
+ out,
346
+ weight,
347
+ bias,
348
+ residual,
349
+ x1,
350
+ weight1,
351
+ bias1,
352
+ y1,
353
+ residual_out,
354
+ rowscale,
355
+ seeds,
356
+ dropout_mask,
357
+ mean,
358
+ rstd,
359
+ x.stride(0),
360
+ out.stride(0),
361
+ residual.stride(0) if residual is not None else 0,
362
+ residual_out.stride(0) if residual_out is not None else 0,
363
+ x1.stride(0) if x1 is not None else 0,
364
+ y1.stride(0) if y1 is not None else 0,
365
+ M,
366
+ N,
367
+ eps,
368
+ dropout_p,
369
+ is_rms_norm,
370
+ BLOCK_N,
371
+ residual is not None,
372
+ residual_out is not None,
373
+ bias is not None,
374
+ dropout_p > 0.0,
375
+ dropout_mask is not None,
376
+ rowscale is not None,
377
+ )
378
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
379
+ if dropout_mask is not None and x1 is not None:
380
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
381
+ else:
382
+ dropout_mask1 = None
383
+ return (
384
+ out,
385
+ y1,
386
+ mean,
387
+ rstd,
388
+ residual_out if residual_out is not None else x,
389
+ seeds,
390
+ dropout_mask,
391
+ dropout_mask1,
392
+ )
393
+
394
+
395
+ @triton.autotune(
396
+ configs=[
397
+ triton.Config({}, num_warps=1),
398
+ triton.Config({}, num_warps=2),
399
+ triton.Config({}, num_warps=4),
400
+ triton.Config({}, num_warps=8),
401
+ triton.Config({}, num_warps=16),
402
+ triton.Config({}, num_warps=32),
403
+ ],
404
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
405
+ )
406
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
407
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
408
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
409
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
410
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
411
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
412
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
413
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
414
+ @triton.jit
415
+ def _layer_norm_bwd_kernel(
416
+ X, # pointer to the input
417
+ W, # pointer to the weights
418
+ B, # pointer to the biases
419
+ Y, # pointer to the output to be recomputed
420
+ DY, # pointer to the output gradient
421
+ DX, # pointer to the input gradient
422
+ DW, # pointer to the partial sum of weights gradient
423
+ DB, # pointer to the partial sum of biases gradient
424
+ DRESIDUAL,
425
+ W1,
426
+ DY1,
427
+ DX1,
428
+ DW1,
429
+ DB1,
430
+ DRESIDUAL_IN,
431
+ ROWSCALE,
432
+ SEEDS,
433
+ Mean, # pointer to the mean
434
+ Rstd, # pointer to the 1/std
435
+ stride_x_row, # how much to increase the pointer when moving by 1 row
436
+ stride_y_row,
437
+ stride_dy_row,
438
+ stride_dx_row,
439
+ stride_dres_row,
440
+ stride_dy1_row,
441
+ stride_dx1_row,
442
+ stride_dres_in_row,
443
+ M, # number of rows in X
444
+ N, # number of columns in X
445
+ eps, # epsilon to avoid division by zero
446
+ dropout_p,
447
+ rows_per_program,
448
+ IS_RMS_NORM: tl.constexpr,
449
+ BLOCK_N: tl.constexpr,
450
+ HAS_DRESIDUAL: tl.constexpr,
451
+ STORE_DRESIDUAL: tl.constexpr,
452
+ HAS_BIAS: tl.constexpr,
453
+ HAS_DROPOUT: tl.constexpr,
454
+ HAS_ROWSCALE: tl.constexpr,
455
+ HAS_DY1: tl.constexpr,
456
+ HAS_DX1: tl.constexpr,
457
+ HAS_B1: tl.constexpr,
458
+ RECOMPUTE_OUTPUT: tl.constexpr,
459
+ ):
460
+ # Map the program id to the elements of X, DX, and DY it should compute.
461
+ row_block_id = tl.program_id(0)
462
+ row_start = row_block_id * rows_per_program
463
+ # Do not early exit if row_start >= M, because we need to write DW and DB
464
+ cols = tl.arange(0, BLOCK_N)
465
+ mask = cols < N
466
+ X += row_start * stride_x_row
467
+ if HAS_DRESIDUAL:
468
+ DRESIDUAL += row_start * stride_dres_row
469
+ if STORE_DRESIDUAL:
470
+ DRESIDUAL_IN += row_start * stride_dres_in_row
471
+ DY += row_start * stride_dy_row
472
+ DX += row_start * stride_dx_row
473
+ if HAS_DY1:
474
+ DY1 += row_start * stride_dy1_row
475
+ if HAS_DX1:
476
+ DX1 += row_start * stride_dx1_row
477
+ if RECOMPUTE_OUTPUT:
478
+ Y += row_start * stride_y_row
479
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
480
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
481
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
482
+ if HAS_DY1:
483
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
484
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
485
+ if HAS_BIAS:
486
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
487
+ if HAS_DY1:
488
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
489
+ if HAS_B1:
490
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
491
+ row_end = min((row_block_id + 1) * rows_per_program, M)
492
+ for row in range(row_start, row_end):
493
+ # Load data to SRAM
494
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
495
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
496
+ if HAS_DY1:
497
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
498
+ if not IS_RMS_NORM:
499
+ mean = tl.load(Mean + row)
500
+ rstd = tl.load(Rstd + row)
501
+ # Compute dx
502
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
503
+ xhat = tl.where(mask, xhat, 0.0)
504
+ if RECOMPUTE_OUTPUT:
505
+ y = xhat * w + b if HAS_BIAS else xhat * w
506
+ tl.store(Y + cols, y, mask=mask)
507
+ wdy = w * dy
508
+ dw += dy * xhat
509
+ if HAS_BIAS:
510
+ db += dy
511
+ if HAS_DY1:
512
+ wdy += w1 * dy1
513
+ dw1 += dy1 * xhat
514
+ if HAS_B1:
515
+ db1 += dy1
516
+ if not IS_RMS_NORM:
517
+ c1 = tl.sum(xhat * wdy, axis=0) / N
518
+ c2 = tl.sum(wdy, axis=0) / N
519
+ dx = (wdy - (xhat * c1 + c2)) * rstd
520
+ else:
521
+ c1 = tl.sum(xhat * wdy, axis=0) / N
522
+ dx = (wdy - xhat * c1) * rstd
523
+ if HAS_DRESIDUAL:
524
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
525
+ dx += dres
526
+ # Write dx
527
+ if STORE_DRESIDUAL:
528
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
529
+ if HAS_DX1:
530
+ if HAS_DROPOUT:
531
+ keep_mask = (
532
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
533
+ )
534
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
535
+ else:
536
+ dx1 = dx
537
+ tl.store(DX1 + cols, dx1, mask=mask)
538
+ if HAS_DROPOUT:
539
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
540
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
541
+ if HAS_ROWSCALE:
542
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
543
+ dx *= rowscale
544
+ tl.store(DX + cols, dx, mask=mask)
545
+
546
+ X += stride_x_row
547
+ if HAS_DRESIDUAL:
548
+ DRESIDUAL += stride_dres_row
549
+ if STORE_DRESIDUAL:
550
+ DRESIDUAL_IN += stride_dres_in_row
551
+ if RECOMPUTE_OUTPUT:
552
+ Y += stride_y_row
553
+ DY += stride_dy_row
554
+ DX += stride_dx_row
555
+ if HAS_DY1:
556
+ DY1 += stride_dy1_row
557
+ if HAS_DX1:
558
+ DX1 += stride_dx1_row
559
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
560
+ if HAS_BIAS:
561
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
562
+ if HAS_DY1:
563
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
564
+ if HAS_B1:
565
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
566
+
567
+
568
+ def _layer_norm_bwd(
569
+ dy,
570
+ x,
571
+ weight,
572
+ bias,
573
+ eps,
574
+ mean,
575
+ rstd,
576
+ dresidual=None,
577
+ dy1=None,
578
+ weight1=None,
579
+ bias1=None,
580
+ seeds=None,
581
+ dropout_p=0.0,
582
+ rowscale=None,
583
+ has_residual=False,
584
+ has_x1=False,
585
+ is_rms_norm=False,
586
+ x_dtype=None,
587
+ recompute_output=False,
588
+ ):
589
+ M, N = x.shape
590
+ assert x.stride(-1) == 1
591
+ assert dy.stride(-1) == 1
592
+ assert dy.shape == (M, N)
593
+ if dresidual is not None:
594
+ assert dresidual.stride(-1) == 1
595
+ assert dresidual.shape == (M, N)
596
+ assert weight.shape == (N,)
597
+ assert weight.stride(-1) == 1
598
+ if bias is not None:
599
+ assert bias.stride(-1) == 1
600
+ assert bias.shape == (N,)
601
+ if dy1 is not None:
602
+ assert weight1 is not None
603
+ assert dy1.shape == dy.shape
604
+ assert dy1.stride(-1) == 1
605
+ if weight1 is not None:
606
+ assert weight1.shape == (N,)
607
+ assert weight1.stride(-1) == 1
608
+ if bias1 is not None:
609
+ assert bias1.shape == (N,)
610
+ assert bias1.stride(-1) == 1
611
+ if seeds is not None:
612
+ assert seeds.is_contiguous()
613
+ assert seeds.shape == (M if not has_x1 else M * 2,)
614
+ if rowscale is not None:
615
+ assert rowscale.is_contiguous()
616
+ assert rowscale.shape == (M,)
617
+ # allocate output
618
+ dx = (
619
+ torch.empty_like(x)
620
+ if x_dtype is None
621
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
622
+ )
623
+ dresidual_in = (
624
+ torch.empty_like(x)
625
+ if has_residual
626
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
627
+ else None
628
+ )
629
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
630
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
631
+ if recompute_output:
632
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
633
+
634
+ # Less than 64KB per feature: enqueue fused kernel
635
+ MAX_FUSED_SIZE = 65536 // x.element_size()
636
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
637
+ if N > BLOCK_N:
638
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
639
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
640
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
641
+ _db = (
642
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
643
+ if bias is not None
644
+ else None
645
+ )
646
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
647
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
648
+ rows_per_program = math.ceil(M / sm_count)
649
+ grid = (sm_count,)
650
+ with torch.cuda.device(x.device.index):
651
+ _layer_norm_bwd_kernel[grid](
652
+ x,
653
+ weight,
654
+ bias,
655
+ y,
656
+ dy,
657
+ dx,
658
+ _dw,
659
+ _db,
660
+ dresidual,
661
+ weight1,
662
+ dy1,
663
+ dx1,
664
+ _dw1,
665
+ _db1,
666
+ dresidual_in,
667
+ rowscale,
668
+ seeds,
669
+ mean,
670
+ rstd,
671
+ x.stride(0),
672
+ 0 if not recompute_output else y.stride(0),
673
+ dy.stride(0),
674
+ dx.stride(0),
675
+ dresidual.stride(0) if dresidual is not None else 0,
676
+ dy1.stride(0) if dy1 is not None else 0,
677
+ dx1.stride(0) if dx1 is not None else 0,
678
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
679
+ M,
680
+ N,
681
+ eps,
682
+ dropout_p,
683
+ rows_per_program,
684
+ is_rms_norm,
685
+ BLOCK_N,
686
+ dresidual is not None,
687
+ dresidual_in is not None,
688
+ bias is not None,
689
+ dropout_p > 0.0,
690
+ )
691
+ dw = _dw.sum(0).to(weight.dtype)
692
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
693
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
694
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
695
+ # Don't need to compute dresidual_in separately in this case
696
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
697
+ dresidual_in = dx
698
+ if has_x1 and dropout_p == 0.0:
699
+ dx1 = dx
700
+ return (
701
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
702
+ if not recompute_output
703
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
704
+ )
705
+
706
+
707
+ class LayerNormFn(torch.autograd.Function):
708
+ @staticmethod
709
+ def forward(
710
+ ctx,
711
+ x,
712
+ weight,
713
+ bias,
714
+ residual=None,
715
+ x1=None,
716
+ weight1=None,
717
+ bias1=None,
718
+ eps=1e-6,
719
+ dropout_p=0.0,
720
+ rowscale=None,
721
+ prenorm=False,
722
+ residual_in_fp32=False,
723
+ is_rms_norm=False,
724
+ return_dropout_mask=False,
725
+ out=None,
726
+ residual_out=None
727
+ ):
728
+ x_shape_og = x.shape
729
+ # reshape input data into 2D tensor
730
+ x = x.reshape(-1, x.shape[-1])
731
+ if x.stride(-1) != 1:
732
+ x = x.contiguous()
733
+ if residual is not None:
734
+ assert residual.shape == x_shape_og
735
+ residual = residual.reshape(-1, residual.shape[-1])
736
+ if residual.stride(-1) != 1:
737
+ residual = residual.contiguous()
738
+ if x1 is not None:
739
+ assert x1.shape == x_shape_og
740
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
741
+ x1 = x1.reshape(-1, x1.shape[-1])
742
+ if x1.stride(-1) != 1:
743
+ x1 = x1.contiguous()
744
+ weight = weight.contiguous()
745
+ if bias is not None:
746
+ bias = bias.contiguous()
747
+ if weight1 is not None:
748
+ weight1 = weight1.contiguous()
749
+ if bias1 is not None:
750
+ bias1 = bias1.contiguous()
751
+ if rowscale is not None:
752
+ rowscale = rowscale.reshape(-1).contiguous()
753
+ residual_dtype = (
754
+ residual.dtype
755
+ if residual is not None
756
+ else (torch.float32 if residual_in_fp32 else None)
757
+ )
758
+ if out is not None:
759
+ out = out.reshape(-1, out.shape[-1])
760
+ if residual_out is not None:
761
+ residual_out = residual_out.reshape(-1, residual_out.shape[-1])
762
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
763
+ x,
764
+ weight,
765
+ bias,
766
+ eps,
767
+ residual,
768
+ x1,
769
+ weight1,
770
+ bias1,
771
+ dropout_p=dropout_p,
772
+ rowscale=rowscale,
773
+ residual_dtype=residual_dtype,
774
+ is_rms_norm=is_rms_norm,
775
+ return_dropout_mask=return_dropout_mask,
776
+ out=out,
777
+ residual_out=residual_out
778
+ )
779
+ ctx.save_for_backward(
780
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
781
+ )
782
+ ctx.x_shape_og = x_shape_og
783
+ ctx.eps = eps
784
+ ctx.dropout_p = dropout_p
785
+ ctx.is_rms_norm = is_rms_norm
786
+ ctx.has_residual = residual is not None
787
+ ctx.has_x1 = x1 is not None
788
+ ctx.prenorm = prenorm
789
+ ctx.x_dtype = x.dtype
790
+ y = y.reshape(x_shape_og)
791
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
792
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
793
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
794
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
795
+ if not return_dropout_mask:
796
+ if weight1 is None:
797
+ return y if not prenorm else (y, residual_out)
798
+ else:
799
+ return (y, y1) if not prenorm else (y, y1, residual_out)
800
+ else:
801
+ if weight1 is None:
802
+ return (
803
+ (y, dropout_mask, dropout_mask1)
804
+ if not prenorm
805
+ else (y, residual_out, dropout_mask, dropout_mask1)
806
+ )
807
+ else:
808
+ return (
809
+ (y, y1, dropout_mask, dropout_mask1)
810
+ if not prenorm
811
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
812
+ )
813
+
814
+ @staticmethod
815
+ def backward(ctx, dy, *args):
816
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
817
+ dy = dy.reshape(-1, dy.shape[-1])
818
+ if dy.stride(-1) != 1:
819
+ dy = dy.contiguous()
820
+ assert dy.shape == x.shape
821
+ if weight1 is not None:
822
+ dy1, args = args[0], args[1:]
823
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
824
+ if dy1.stride(-1) != 1:
825
+ dy1 = dy1.contiguous()
826
+ assert dy1.shape == x.shape
827
+ else:
828
+ dy1 = None
829
+ if ctx.prenorm:
830
+ dresidual = args[0]
831
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
832
+ if dresidual.stride(-1) != 1:
833
+ dresidual = dresidual.contiguous()
834
+ assert dresidual.shape == x.shape
835
+ else:
836
+ dresidual = None
837
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
838
+ dy,
839
+ x,
840
+ weight,
841
+ bias,
842
+ ctx.eps,
843
+ mean,
844
+ rstd,
845
+ dresidual,
846
+ dy1,
847
+ weight1,
848
+ bias1,
849
+ seeds,
850
+ ctx.dropout_p,
851
+ rowscale,
852
+ ctx.has_residual,
853
+ ctx.has_x1,
854
+ ctx.is_rms_norm,
855
+ x_dtype=ctx.x_dtype,
856
+ )
857
+ return (
858
+ dx.reshape(ctx.x_shape_og),
859
+ dw,
860
+ db,
861
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
862
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
863
+ dw1,
864
+ db1,
865
+ None,
866
+ None,
867
+ None,
868
+ None,
869
+ None,
870
+ None,
871
+ None,
872
+ None,
873
+ None,
874
+ )
875
+
876
+
877
+ def layer_norm_fn(
878
+ x,
879
+ weight,
880
+ bias,
881
+ residual=None,
882
+ x1=None,
883
+ weight1=None,
884
+ bias1=None,
885
+ eps=1e-6,
886
+ dropout_p=0.0,
887
+ rowscale=None,
888
+ prenorm=False,
889
+ residual_in_fp32=False,
890
+ is_rms_norm=False,
891
+ return_dropout_mask=False,
892
+ out=None,
893
+ residual_out=None
894
+ ):
895
+ return LayerNormFn.apply(
896
+ x,
897
+ weight,
898
+ bias,
899
+ residual,
900
+ x1,
901
+ weight1,
902
+ bias1,
903
+ eps,
904
+ dropout_p,
905
+ rowscale,
906
+ prenorm,
907
+ residual_in_fp32,
908
+ is_rms_norm,
909
+ return_dropout_mask,
910
+ out,
911
+ residual_out
912
+ )
913
+
914
+
915
+ def rms_norm_fn(
916
+ x,
917
+ weight,
918
+ bias,
919
+ residual=None,
920
+ x1=None,
921
+ weight1=None,
922
+ bias1=None,
923
+ eps=1e-6,
924
+ dropout_p=0.0,
925
+ rowscale=None,
926
+ prenorm=False,
927
+ residual_in_fp32=False,
928
+ return_dropout_mask=False,
929
+ out=None,
930
+ residual_out=None
931
+ ):
932
+ return LayerNormFn.apply(
933
+ x,
934
+ weight,
935
+ bias,
936
+ residual,
937
+ x1,
938
+ weight1,
939
+ bias1,
940
+ eps,
941
+ dropout_p,
942
+ rowscale,
943
+ prenorm,
944
+ residual_in_fp32,
945
+ True,
946
+ return_dropout_mask,
947
+ out,
948
+ residual_out
949
+ )
950
+
951
+
952
+ class RMSNorm(torch.nn.Module):
953
+
954
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
955
+ factory_kwargs = {"device": device, "dtype": dtype}
956
+ super().__init__()
957
+ self.eps = eps
958
+ if dropout_p > 0.0:
959
+ self.drop = torch.nn.Dropout(dropout_p)
960
+ else:
961
+ self.drop = None
962
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
963
+ self.register_parameter("bias", None)
964
+ self.reset_parameters()
965
+
966
+ def reset_parameters(self):
967
+ torch.nn.init.ones_(self.weight)
968
+
969
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
970
+ return rms_norm_fn(
971
+ x,
972
+ self.weight,
973
+ self.bias,
974
+ residual=residual,
975
+ eps=self.eps,
976
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
977
+ prenorm=prenorm,
978
+ residual_in_fp32=residual_in_fp32,
979
+ )
980
+
981
+
982
+ class LayerNormLinearFn(torch.autograd.Function):
983
+ @staticmethod
984
+ @custom_fwd
985
+ def forward(
986
+ ctx,
987
+ x,
988
+ norm_weight,
989
+ norm_bias,
990
+ linear_weight,
991
+ linear_bias,
992
+ residual=None,
993
+ eps=1e-6,
994
+ prenorm=False,
995
+ residual_in_fp32=False,
996
+ is_rms_norm=False,
997
+ ):
998
+ x_shape_og = x.shape
999
+ # reshape input data into 2D tensor
1000
+ x = x.reshape(-1, x.shape[-1])
1001
+ if x.stride(-1) != 1:
1002
+ x = x.contiguous()
1003
+ if residual is not None:
1004
+ assert residual.shape == x_shape_og
1005
+ residual = residual.reshape(-1, residual.shape[-1])
1006
+ if residual.stride(-1) != 1:
1007
+ residual = residual.contiguous()
1008
+ norm_weight = norm_weight.contiguous()
1009
+ if norm_bias is not None:
1010
+ norm_bias = norm_bias.contiguous()
1011
+ residual_dtype = (
1012
+ residual.dtype
1013
+ if residual is not None
1014
+ else (torch.float32 if residual_in_fp32 else None)
1015
+ )
1016
+ y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1017
+ x,
1018
+ norm_weight,
1019
+ norm_bias,
1020
+ eps,
1021
+ residual,
1022
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
1023
+ residual_dtype=residual_dtype,
1024
+ is_rms_norm=is_rms_norm,
1025
+ )
1026
+ y = y.reshape(x_shape_og)
1027
+ dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1028
+ linear_weight = linear_weight.to(dtype)
1029
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1030
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1031
+ # We don't store y, will be recomputed in the backward pass to save memory
1032
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
1033
+ ctx.x_shape_og = x_shape_og
1034
+ ctx.eps = eps
1035
+ ctx.is_rms_norm = is_rms_norm
1036
+ ctx.has_residual = residual is not None
1037
+ ctx.prenorm = prenorm
1038
+ ctx.x_dtype = x.dtype
1039
+ ctx.linear_bias_is_none = linear_bias is None
1040
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1041
+
1042
+ @staticmethod
1043
+ @custom_bwd
1044
+ def backward(ctx, dout, *args):
1045
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1046
+ dout = dout.reshape(-1, dout.shape[-1])
1047
+ dy = F.linear(dout, linear_weight.t())
1048
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1049
+ if dy.stride(-1) != 1:
1050
+ dy = dy.contiguous()
1051
+ assert dy.shape == x.shape
1052
+ if ctx.prenorm:
1053
+ dresidual = args[0]
1054
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1055
+ if dresidual.stride(-1) != 1:
1056
+ dresidual = dresidual.contiguous()
1057
+ assert dresidual.shape == x.shape
1058
+ else:
1059
+ dresidual = None
1060
+ dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1061
+ dy,
1062
+ x,
1063
+ norm_weight,
1064
+ norm_bias,
1065
+ ctx.eps,
1066
+ mean,
1067
+ rstd,
1068
+ dresidual=dresidual,
1069
+ has_residual=ctx.has_residual,
1070
+ is_rms_norm=ctx.is_rms_norm,
1071
+ x_dtype=ctx.x_dtype,
1072
+ recompute_output=True,
1073
+ )
1074
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1075
+ return (
1076
+ dx.reshape(ctx.x_shape_og),
1077
+ dnorm_weight,
1078
+ dnorm_bias,
1079
+ dlinear_weight,
1080
+ dlinear_bias,
1081
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1082
+ None,
1083
+ None,
1084
+ None,
1085
+ None,
1086
+ )
1087
+
1088
+
1089
+ def layer_norm_linear_fn(
1090
+ x,
1091
+ norm_weight,
1092
+ norm_bias,
1093
+ linear_weight,
1094
+ linear_bias,
1095
+ residual=None,
1096
+ eps=1e-6,
1097
+ prenorm=False,
1098
+ residual_in_fp32=False,
1099
+ is_rms_norm=False,
1100
+ ):
1101
+ return LayerNormLinearFn.apply(
1102
+ x,
1103
+ norm_weight,
1104
+ norm_bias,
1105
+ linear_weight,
1106
+ linear_bias,
1107
+ residual,
1108
+ eps,
1109
+ prenorm,
1110
+ residual_in_fp32,
1111
+ is_rms_norm,
1112
+ )