Erland committed on
Commit 2cb275d · verified · 1 Parent(s): bec1e88

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +11 -0
  2. checkpoint/step-10000/.metadata +3 -0
  3. checkpoint/step-10000/__0_0.distcp +3 -0
  4. checkpoint/step-10000/__1_0.distcp +3 -0
  5. checkpoint/step-10000/__2_0.distcp +3 -0
  6. checkpoint/step-10000/__3_0.distcp +3 -0
  7. checkpoint/step-15000/.metadata +3 -0
  8. checkpoint/step-15000/__0_0.distcp +3 -0
  9. checkpoint/step-15000/__1_0.distcp +3 -0
  10. checkpoint/step-15000/__2_0.distcp +3 -0
  11. checkpoint/step-15000/__3_0.distcp +3 -0
  12. fla/modules/convolution.py +434 -0
  13. fla/modules/rotary.py +512 -0
  14. fla/ops/abc/__pycache__/__init__.cpython-311.pyc +0 -0
  15. fla/ops/abc/__pycache__/chunk.cpython-311.pyc +0 -0
  16. fla/ops/attn/__init__.py +7 -0
  17. fla/ops/attn/__pycache__/__init__.cpython-311.pyc +0 -0
  18. fla/ops/based/__init__.py +9 -0
  19. fla/ops/based/__pycache__/__init__.cpython-311.pyc +0 -0
  20. fla/ops/based/__pycache__/fused_chunk.cpython-311.pyc +0 -0
  21. fla/ops/based/fused_chunk.py +374 -0
  22. fla/ops/based/naive.py +72 -0
  23. fla/ops/common/__init__.py +1 -0
  24. fla/ops/common/__pycache__/chunk_delta_h.cpython-311.pyc +0 -0
  25. fla/ops/common/__pycache__/chunk_h.cpython-311.pyc +0 -0
  26. fla/ops/common/__pycache__/chunk_o.cpython-311.pyc +0 -0
  27. fla/ops/common/__pycache__/utils.cpython-311.pyc +0 -0
  28. fla/ops/common/chunk_delta_h.py +399 -0
  29. fla/ops/common/chunk_h_parallel.py +650 -0
  30. fla/ops/common/chunk_h_split.py +677 -0
  31. fla/ops/common/fused_recurrent.py +575 -0
  32. fla/ops/common/utils.py +69 -0
  33. fla/ops/delta_rule/__init__.py +11 -0
  34. fla/ops/delta_rule/__pycache__/chunk.cpython-311.pyc +0 -0
  35. fla/ops/delta_rule/__pycache__/fused_chunk.cpython-311.pyc +0 -0
  36. fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-311.pyc +0 -0
  37. fla/ops/delta_rule/__pycache__/wy_fast.cpython-311.pyc +0 -0
  38. fla/ops/delta_rule/chunk.py +373 -0
  39. fla/ops/delta_rule/fused_chunk.py +6 -0
  40. fla/ops/delta_rule/fused_recurrent.py +607 -0
  41. fla/ops/delta_rule/naive.py +120 -0
  42. fla/ops/delta_rule/parallel.py +394 -0
  43. fla/ops/delta_rule/wy_fast.py +340 -0
  44. fla/ops/forgetting_attn/__init__.py +7 -0
  45. fla/ops/forgetting_attn/__pycache__/__init__.cpython-311.pyc +0 -0
  46. fla/ops/forgetting_attn/__pycache__/parallel.cpython-311.pyc +0 -0
  47. fla/ops/forgetting_attn/parallel.py +708 -0
  48. fla/ops/gated_delta_rule/__init__.py +7 -0
  49. fla/ops/gated_delta_rule/__pycache__/__init__.cpython-311.pyc +0 -0
  50. fla/ops/gated_delta_rule/__pycache__/chunk.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoint/step-10000/.metadata filter=lfs diff=lfs merge=lfs -text
37
+ checkpoint/step-15000/.metadata filter=lfs diff=lfs merge=lfs -text
38
+ tb/20250613-1241/wandb/run-20250613_124127--mtp.120M.batch8.seqlen2048.context2048.warmup1000.update1.steps15000.nft4.lr5e-4.cosine-202506131240/run--mtp.120M.batch8.seqlen2048.context2048.warmup1000.update1.steps15000.nft4.lr5e-4.cosine-202506131240.wandb filter=lfs diff=lfs merge=lfs -text
39
+ checkpoint/step-10000/__3_0.distcp filter=lfs diff=lfs merge=lfs -text
40
+ checkpoint/step-10000/__0_0.distcp filter=lfs diff=lfs merge=lfs -text
41
+ checkpoint/step-15000/__1_0.distcp filter=lfs diff=lfs merge=lfs -text
42
+ checkpoint/step-15000/__0_0.distcp filter=lfs diff=lfs merge=lfs -text
43
+ checkpoint/step-10000/__1_0.distcp filter=lfs diff=lfs merge=lfs -text
44
+ checkpoint/step-15000/__2_0.distcp filter=lfs diff=lfs merge=lfs -text
45
+ checkpoint/step-10000/__2_0.distcp filter=lfs diff=lfs merge=lfs -text
46
+ checkpoint/step-15000/__3_0.distcp filter=lfs diff=lfs merge=lfs -text
checkpoint/step-10000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9bfcd53d956e7ce1567219335ee1ed1c484eca733f0aa24870d0c1420f8613f
3
+ size 864528
checkpoint/step-10000/__0_0.distcp ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d41d8a6c5d13e200c4e0d493a72439ec522d3c33f913573a23da5993c09ae78d
3
+ size 401559768
checkpoint/step-10000/__1_0.distcp ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29d95b4dcd9c7a539ec0e96fa83295ad5775eedd16fc132405020a6523cba551
3
+ size 400714252
checkpoint/step-10000/__2_0.distcp ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6d6c7cfb6c2decaf8f865dcb34e714de6ecdc741cd06714095478d0a6b27875
3
+ size 401289872
checkpoint/step-10000/__3_0.distcp ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e319bd1abf88b357c21eca912417e32ed537e4c55c4afccfc28a814c80c2f62c
3
+ size 401088656
checkpoint/step-15000/.metadata ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98d42bcc9c9eee60c545adb82d32950bc5db492d6f788295eabb003ec90c2203
3
+ size 864554
checkpoint/step-15000/__0_0.distcp ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62d6d87b953ce1ac7e33777cac0bcdd1fe3b5fde5e41d6ec55a20b95ec8e325c
3
+ size 400464984
checkpoint/step-15000/__1_0.distcp ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76b34eeecec5e708e4f31621a0efefcc1309150c1d336bb6dbaa4a4b3724183f
3
+ size 401236172
checkpoint/step-15000/__2_0.distcp ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33393b08543267b450fd6335866908323d80eae4e0029438e5986a785edfbf04
3
+ size 400709840
checkpoint/step-15000/__3_0.distcp ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46eb8f21ff5d665a5d869d46206d38064ed9f0db548139242a49e223f424cf46
3
+ size 400992592
fla/modules/convolution.py ADDED
@@ -0,0 +1,434 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # from https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/convolution.py
4
+
5
+ import math
6
+ import warnings
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import triton
13
+ import triton.language as tl
14
+ from einops import rearrange
15
+
16
+ from fla.modules.activations import ACT2FN
17
+ from fla.ops.common.utils import prepare_position_ids, prepare_sequence_ids
18
+ from fla.utils import checkpoint, input_guard
19
+
20
+ try:
21
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
+ except ImportError:
23
+ causal_conv1d_fn = None
24
+ causal_conv1d_update = None
25
+
26
+
27
+ def fft_conv(u, k, dropout_mask, gelu=True, k_rev=None):
28
+ seqlen = u.shape[-1]
29
+ fft_size = 2 * seqlen
30
+ k_f = torch.fft.rfft(k, n=fft_size) / fft_size
31
+ if k_rev is not None:
32
+ k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size
33
+ k_f = k_f + k_rev_f.conj()
34
+ u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
35
+
36
+ if len(u.shape) > 3:
37
+ k_f = k_f.unsqueeze(1)
38
+ y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
39
+
40
+ out = y + u
41
+ if gelu:
42
+ out = F.gelu(out)
43
+ if dropout_mask is not None:
44
+ return (out * rearrange(dropout_mask, "b H -> b H 1")).to(dtype=u.dtype)
45
+ else:
46
+ return out.to(dtype=u.dtype)
47
+
48
+
49
+ @checkpoint
50
+ def proj_then_conv1d(
51
+ x: torch.Tensor,
52
+ proj_weight: torch.Tensor,
53
+ conv1d_weight: torch.Tensor,
54
+ conv1d_bias: Optional[torch.Tensor] = None,
55
+ cache: Optional[torch.Tensor] = None
56
+ ) -> torch.Tensor:
57
+ # We do matmul and transpose BLH -> HBL at the same time
58
+ x = rearrange(proj_weight @ rearrange(x, "b t d -> d (b t)"), "d (b t) -> b d t", t=x.shape[-2])
59
+
60
+ if causal_conv1d_fn is None:
61
+ raise ImportError("`causal_conv1d_fn` is not available. Please install `causal-conv1d` first.")
62
+ if cache is None:
63
+ x = causal_conv1d_fn(
64
+ x=x,
65
+ weight=rearrange(conv1d_weight, "d 1 w -> d w"),
66
+ bias=conv1d_bias,
67
+ activation="silu",
68
+ ).transpose(1, 2)
69
+ else:
70
+ assert x.shape[-1] == 1, "Only support decoding with 1 token at a time for now"
71
+ x = x.squeeze(-1)
72
+ x = causal_conv1d_update(
73
+ x=x,
74
+ weight=rearrange(conv1d_weight, "d 1 w -> d w"),
75
+ bias=conv1d_bias,
76
+ cache=cache,
77
+ activation="silu",
78
+ )
79
+ return x
80
+
81
+
82
+ @triton.jit
83
+ def causal_conv1d_varlen_states_fwd_kernel(
84
+ x,
85
+ cache,
86
+ offsets,
87
+ D,
88
+ W,
89
+ BD: tl.constexpr,
90
+ BW: tl.constexpr
91
+ ):
92
+ i_d, i_w, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2)
93
+ eos = tl.load(offsets + i_n + 1)
94
+ bos = tl.maximum(tl.load(offsets + i_n), eos - W)
95
+ o_t = eos - (i_w + 1) * BW + tl.arange(0, BW)
96
+ o_d = i_d * BD + tl.arange(0, BD)
97
+ o_w = W - (i_w + 1) * BW + tl.arange(0, BW)
98
+
99
+ b_x = tl.load(x + o_t * D + o_d[:, None], mask=(o_t >= bos) & (o_d[:, None] < D), other=0)
100
+ tl.store(cache + i_n * D*W + o_d[:, None] * W + o_w, b_x, mask=(o_d[:, None] < D) & (o_w >= 0))
101
+
102
+
103
+ @input_guard
104
+ def causal_conv1d_varlen_states_fwd(
105
+ x: torch.Tensor,
106
+ cache: torch.Tensor,
107
+ cu_seqlens: torch.Tensor,
108
+ state_len: int
109
+ ) -> torch.Tensor:
110
+ N, D, W = len(cu_seqlens) - 1, x.shape[-1], state_len
111
+ cache = torch.empty(N, D, W, dtype=x.dtype, device=x.device) if cache is None else cache
112
+ BD = min(triton.next_power_of_2(D), 256)
113
+ BW = min(triton.next_power_of_2(state_len), 16)
114
+ grid = (triton.cdiv(D, BD), triton.cdiv(W, BW), N)
115
+ with torch.cuda.device(x.device.index):
116
+ causal_conv1d_varlen_states_fwd_kernel[grid](
117
+ x=x,
118
+ cache=cache,
119
+ offsets=cu_seqlens,
120
+ D=D,
121
+ W=W,
122
+ BW=BW,
123
+ BD=BD
124
+ )
125
+ return cache
126
+
127
+
128
+ class ShortConvolution(nn.Conv1d):
129
+ """
130
+ Simple wrapper around `nn.Conv1d` that accepts dimension last.
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ hidden_size: int,
136
+ kernel_size: int,
137
+ bias: bool = False,
138
+ activation: Optional[str] = 'silu',
139
+ use_fast_conv1d: Optional[bool] = True,
140
+ device: Optional[torch.device] = None,
141
+ dtype: Optional[torch.dtype] = None,
142
+ ):
143
+ super().__init__(
144
+ in_channels=hidden_size,
145
+ out_channels=hidden_size,
146
+ kernel_size=kernel_size,
147
+ groups=hidden_size,
148
+ bias=bias,
149
+ padding=kernel_size - 1,
150
+ device=device,
151
+ dtype=dtype,
152
+ )
153
+
154
+ self.hidden_size = hidden_size
155
+ self.activation = None
156
+ if activation is not None:
157
+ assert activation in ['silu', 'swish'], f"Activation `{activation}` not supported yet."
158
+ self.activation = activation
159
+
160
+ if causal_conv1d_fn is None:
161
+ if use_fast_conv1d:
162
+ raise RuntimeError(
163
+ "Please either install `causal-conv1d>=1.4.0` to enable fast causal short convolution CUDA kernel "
164
+ "or set `use_fast_conv1d` to False"
165
+ )
166
+ else:
167
+ warnings.warn(
168
+ "The naive Pytorch verison is very slow in practice, "
169
+ "please run `pip install causal-conv1d>=1.4.0` to install fast causal short convolution CUDA kernel",
170
+ category=ImportWarning
171
+ )
172
+ self.use_fast_conv1d = use_fast_conv1d
173
+
174
+ def extra_repr(self):
175
+ s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
176
+ ', stride={stride}')
177
+ if self.padding != (0,) * len(self.padding):
178
+ s += ', padding={padding}'
179
+ if self.dilation != (1,) * len(self.dilation):
180
+ s += ', dilation={dilation}'
181
+ if self.output_padding != (0,) * len(self.output_padding):
182
+ s += ', output_padding={output_padding}'
183
+ if self.groups != 1:
184
+ s += ', groups={groups}'
185
+ if self.bias is None:
186
+ s += ', bias=False'
187
+ if self.padding_mode != 'zeros':
188
+ s += ', padding_mode={padding_mode}'
189
+ if self.activation is not None:
190
+ s += ', activation={activation}'
191
+ if not self.use_fast_conv1d:
192
+ s += ', use_fast_conv1d={use_fast_conv1d}'
193
+ return s.format(**self.__dict__)
194
+
195
+ def forward(
196
+ self,
197
+ x: torch.Tensor,
198
+ mask: Optional[torch.Tensor] = None,
199
+ cache: Optional[torch.Tensor] = None,
200
+ output_final_state: bool = False,
201
+ cu_seqlens: Optional[torch.LongTensor] = None,
202
+ **kwargs,
203
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
204
+ """
205
+ Args:
206
+ x (`torch.Tensor`):
207
+ Tensor of shape `[B, T, D]`.
208
+ If `seq_idx` is provided, `B` must be 1.
209
+ mask (`Optional[torch.Tensor]`):
210
+ Attention mask dealing with padded positions.
211
+ cache (`Optional[torch.Tensor]`):
212
+ Previous cache tensor of shape `[N, D, W]`, where `W` is the kernel size.
213
+ If provided, the cache is updated **inplace**.
214
+ output_final_state (Optional[bool]):
215
+ Whether to output the final state of shape `[N, D, W]`. Default: `False`.
216
+ cu_seqlens (Optional[torch.LongTensor]):
217
+ Cumulative sequence lengths for each batch. Used for varlen. Default: `None`.
218
+ Shape: [B+1]
219
+
220
+ Returns:
221
+ Tensor of shape `[B, T, D]`.
222
+ """
223
+
224
+ B, T, D, W = *x.shape, self.kernel_size[0]
225
+ N = B if cu_seqlens is None else len(cu_seqlens) - 1
226
+ if mask is not None:
227
+ if cu_seqlens is not None:
228
+ raise ValueError("`mask` and `cu_seqlens` cannot be provided at the same time")
229
+ x = x.mul_(mask.unsqueeze(-1))
230
+ if output_final_state and cache is None:
231
+ cache = x.new_zeros(N, D, W)
232
+ # during the decoding phase, we assume the batch is composed of sequences of length 1
233
+ if cache is not None and B * T == N:
234
+ return self.step(x, cache, cu_seqlens)
235
+
236
+ if cache is not None:
237
+ if cu_seqlens is not None:
238
+ cache = causal_conv1d_varlen_states_fwd(x, cache, cu_seqlens, W)
239
+ else:
240
+ cache[:, :, -min(W, T):].copy_(rearrange(x[..., -min(W, T):, :], 'n w d -> n d w'))
241
+
242
+ x = rearrange(x, 'b t d -> b d t')
243
+ if self.use_fast_conv1d:
244
+ # Sequence index for each token. Used for varlen.
245
+ # Suppose a batch consists of two sequences with lengths 3 and 4,
246
+ # seq_idx=[0, 0, 0, 1, 1, 1, 1] for this batch.
247
+ # NOTE: No need to provide this arg if `cu_seqlens` is passed.
248
+ # This arg is just for BC, and will be removed in the future.
249
+ # [B, T]
250
+ seq_idx = kwargs.get('seq_idx', None)
251
+ if cu_seqlens is not None and seq_idx is None:
252
+ seq_idx = prepare_sequence_ids(prepare_position_ids(cu_seqlens)).to(torch.int32).unsqueeze(0)
253
+ x = causal_conv1d_fn(
254
+ x=x,
255
+ weight=rearrange(self.weight, "d 1 w -> d w"),
256
+ bias=self.bias,
257
+ activation=self.activation,
258
+ seq_idx=seq_idx,
259
+ )
260
+ else:
261
+ if cu_seqlens is not None:
262
+ raise ValueError("`cu_seqlens` is not supported for the naive Pytorch version")
263
+ x = self._conv_forward(x, self.weight, self.bias)[..., :x.shape[-1]]
264
+ if self.activation is not None:
265
+ x = ACT2FN[self.activation](x)
266
+ return rearrange(x, "b d t -> b t d"), cache
267
+
268
+ def step(
269
+ self,
270
+ x: torch.Tensor,
271
+ cache: torch.Tensor,
272
+ cu_seqlens: Optional[torch.LongTensor] = None
273
+ ):
274
+ shape = x.shape
275
+ x = x.squeeze(0) if cu_seqlens is not None else x.squeeze(1)
276
+ if self.use_fast_conv1d:
277
+ x = causal_conv1d_update(
278
+ x=x,
279
+ conv_state=cache,
280
+ weight=rearrange(self.weight, "d 1 w -> d w"),
281
+ bias=self.bias,
282
+ activation=self.activation,
283
+ )
284
+ else:
285
+ dtype = x.dtype
286
+ # we follow the fast mode that updates the cache in-place
287
+ cache.copy_(cache.roll(shifts=-1, dims=-1))
288
+ cache[:, :, -1] = x
289
+ x = torch.sum(cache * rearrange(self.weight, "d 1 w -> d w"), dim=-1)
290
+ if self.bias is not None:
291
+ x = x + self.bias
292
+ if self.activation is not None:
293
+ x = ACT2FN[self.activation](x).to(dtype=dtype)
294
+ return x.view(shape), cache
295
+
296
+ @property
297
+ def state_size(self) -> int:
298
+ return self.hidden_size * self.kernel_size
299
+
300
+
301
+ class LongConvolution(nn.Module):
302
+ """
303
+ LongConvolution applies a convolution operation on the input tensor using a fixed
304
+ filter of length max_len.
305
+ The filter is learned during training and is applied using FFT convolution.
306
+ Args:
307
+ hidden_size (int): The number of expected features in the input and output.
308
+ max_len (int): The maximum sequence length.
309
+ Returns:
310
+ y: [batch_size, seq_len, hidden_size] tensor
311
+ """
312
+
313
+ def __init__(
314
+ self,
315
+ hidden_size: int,
316
+ max_len: int,
317
+ **kwargs,
318
+ ):
319
+ """
320
+ Initializes the LongConvolution module.
321
+ Args:
322
+ hidden_size (int): The number of expected features in the input and output.
323
+ max_len (int): The maximum sequence length.
324
+ """
325
+ super().__init__()
326
+ self.hidden_size = hidden_size
327
+ self.filter = nn.Parameter(torch.randn(self.hidden_size, max_len), requires_grad=True)
328
+
329
+ def forward(self, x: torch.Tensor, *args, **kwargs):
330
+ """
331
+ Applies the LongConvolution operation on the input tensor.
332
+ Args:
333
+ x: [batch_size, seq_len, hidden_size] tensor
334
+ Returns:
335
+ y: [batch_size, seq_len, hidden_size] tensor
336
+ """
337
+ x = x.transpose(1, 2)
338
+ y = fft_conv(x, self.filter, dropout_mask=None, gelu=False)
339
+ y = y.transpose(1, 2)
340
+ return y.to(dtype=x.dtype)
341
+
342
+
343
+ class PositionalEmbedding(nn.Module):
344
+ def __init__(self, emb_dim: int, seq_len: int, **kwargs):
345
+ """Complex exponential positional embeddings for implicit long convolution filters."""
346
+ super().__init__()
347
+
348
+ self.seq_len = seq_len
349
+ # The time embedding fed to the filters is normalized so that t_f = 1
350
+ t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1
351
+
352
+ if emb_dim > 1:
353
+ bands = (emb_dim - 1) // 2
354
+ # To compute the right embeddings we use the "proper" linspace
355
+ t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None]
356
+ w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1
357
+
358
+ f = torch.linspace(1e-4, bands - 1, bands)[None, None]
359
+ z = torch.exp(-1j * f * w)
360
+ z = torch.cat([t, z.real, z.imag], dim=-1)
361
+ self.z = nn.Parameter(z, requires_grad=False)
362
+
363
+ def forward(self, L):
364
+ return self.z[:, :L]
365
+
366
+
367
+ class ImplicitLongConvolution(nn.Module):
368
+ """
369
+ Long convolution with implicit filter parameterized by an MLP.
370
+
371
+ Args:
372
+ hidden_size (int):
373
+ The number of expected features in the input and output.
374
+ max_len (int):
375
+ The maximum sequence length.
376
+ d_emb (Optional[int]):
377
+ The dimension of the positional embeddings. Must be odd and greater than or equal to 3 (time, sine and cosine).
378
+ Defaults to 3.
379
+ d_hidden (Optional[int]):
380
+ The number of features in the hidden layer of the MLP. Defaults to 16.
381
+
382
+ Attributes:
383
+ pos_emb (`PositionalEmbedding`): The positional embedding layer.
384
+ mlp (`nn.Sequential`): The MLP that parameterizes the implicit filter.
385
+
386
+ """
387
+
388
+ def __init__(
389
+ self,
390
+ hidden_size: int,
391
+ max_len: int,
392
+ d_emb: int = 3,
393
+ d_hidden: int = 16,
394
+ **kwargs,
395
+ ):
396
+ """
397
+ Long convolution with implicit filter parameterized by an MLP.
398
+
399
+
400
+ """
401
+ super().__init__()
402
+ self.hidden_size = hidden_size
403
+ self.d_emb = d_emb
404
+
405
+ assert (
406
+ d_emb % 2 != 0 and d_emb >= 3
407
+ ), "d_emb must be odd and greater or equal to 3 (time, sine and cosine)"
408
+ self.pos_emb = PositionalEmbedding(d_emb, max_len)
409
+
410
+ # final linear layer
411
+ self.mlp = nn.Sequential(
412
+ nn.Linear(d_emb, d_hidden),
413
+ torch.nn.ReLU(),
414
+ nn.Linear(d_hidden, hidden_size),
415
+ )
416
+
417
+ def filter(self, seq_len: int, *args, **kwargs):
418
+ k = self.mlp(self.pos_emb(seq_len))
419
+
420
+ return k.transpose(1, 2)
421
+
422
+ def forward(self, x: torch.Tensor, *args, **kwargs):
423
+ """
424
+ Args:
425
+ x: [batch_size, seq_len, hidden_size] tensor
426
+ Returns:
427
+ y: [batch_size, seq_len, hidden_size] tensor
428
+ """
429
+ x = x.transpose(1, 2)
430
+ k = self.filter(x.shape[-1])
431
+ y = fft_conv(x, k, dropout_mask=None, gelu=False)
432
+
433
+ y = y.transpose(1, 2)
434
+ return y.to(dtype=x.dtype)
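A minimal usage sketch (not part of the commit) for the `ShortConvolution` module added above. The import path is assumed from the file location `fla/modules/convolution.py`; `use_fast_conv1d=False` selects the slower pure-PyTorch path so the snippet does not require the optional `causal-conv1d` package (it will only emit an ImportWarning if that package is missing). Shapes follow the docstring: `[B, T, D]` in, `[B, T, D]` out.

```python
# Hypothetical usage sketch; import path and shapes assumed from this diff.
import torch
from fla.modules.convolution import ShortConvolution

B, T, D = 2, 128, 512
conv = ShortConvolution(hidden_size=D, kernel_size=4, activation='silu', use_fast_conv1d=False)
x = torch.randn(B, T, D)

# Full-sequence forward: returns the convolved sequence and a cache of shape [B, D, W].
y, cache = conv(x, output_final_state=True)

# Single-token decoding step: the cache is updated in place and reused.
x_step = torch.randn(B, 1, D)
y_step, cache = conv(x_step, cache=cache, output_final_state=True)
```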
fla/modules/rotary.py ADDED
@@ -0,0 +1,512 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright (c) 2023, Tri Dao.
4
+ # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py
5
+
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import triton
11
+ import triton.language as tl
12
+ from einops import rearrange, repeat
13
+
14
+ from fla.utils import get_multiprocessor_count, input_guard
15
+
16
+
17
+ def rotate_half(x, interleaved=False):
18
+ if not interleaved:
19
+ x1, x2 = x.chunk(2, dim=-1)
20
+ return torch.cat((-x2, x1), dim=-1)
21
+ else:
22
+ x1, x2 = x[..., ::2], x[..., 1::2]
23
+ return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)
24
+
25
+
26
+ def rotary_embedding_ref(x, cos, sin, interleaved=False):
27
+ ro_dim = cos.shape[-1] * 2
28
+ assert ro_dim <= x.shape[-1]
29
+ cos = repeat(cos, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 1 (d 2)')
30
+ sin = repeat(sin, '... d -> ... 1 (2 d)' if not interleaved else '... d -> ... 1 (d 2)')
31
+ return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], -1)
32
+
33
+
34
+ @triton.autotune(
35
+ configs=[
36
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
37
+ for num_warps in [2, 4, 8, 16, 32]
38
+ for num_stages in [2, 3, 4]
39
+ ],
40
+ key=['B', 'H', 'D', 'INTERLEAVED'],
41
+ )
42
+ @triton.jit
43
+ def rotary_embedding_kernel(
44
+ x,
45
+ cos,
46
+ sin,
47
+ y,
48
+ cu_seqlens,
49
+ seq_offsets, # this could be int or a pointer
50
+ # Matrix dimensions
51
+ B: tl.constexpr,
52
+ T: tl.constexpr,
53
+ H: tl.constexpr,
54
+ D: tl.constexpr,
55
+ R: tl.constexpr,
56
+ TR: tl.constexpr,
57
+ BT: tl.constexpr,
58
+ BD: tl.constexpr,
59
+ IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
60
+ IS_VARLEN: tl.constexpr,
61
+ INTERLEAVED: tl.constexpr,
62
+ CONJUGATE: tl.constexpr
63
+ ):
64
+ i_t, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2)
65
+
66
+ if not IS_VARLEN:
67
+ x = x + i_b * T*H*D + i_h * D
68
+ y = y + i_b * T*H*D + i_h * D
69
+ else:
70
+ bos, eos = tl.load(cu_seqlens + i_b), tl.load(cu_seqlens + i_b + 1)
71
+ T = eos - bos
72
+ x = x + bos * H*D + i_h * D
73
+ y = y + bos * H*D + i_h * D
74
+
75
+ if i_t * BT >= T:
76
+ return
77
+
78
+ o_t = i_t * BT + tl.arange(0, BT)
79
+ if not IS_SEQLEN_OFFSETS_TENSOR:
80
+ o_cs = o_t + seq_offsets
81
+ else:
82
+ o_cs = o_t + tl.load(seq_offsets + i_b)
83
+
84
+ if not INTERLEAVED:
85
+ # Load the 1st and 2nd halves of x, do calculation, then store to 1st and 2nd halves of out
86
+ o_r = tl.arange(0, BD // 2)
87
+ p_x = x + o_t[:, None] * H*D + o_r[None, :]
88
+ p_cos = cos + (o_cs[:, None] * R + o_r[None, :])
89
+ p_sin = sin + (o_cs[:, None] * R + o_r[None, :])
90
+ mask = (o_t[:, None] >= 0) & (o_t[:, None] < T) & (o_r[None, :] < R)
91
+
92
+ b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32)
93
+ b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32)
94
+ b_x0 = tl.load(p_x, mask=mask, other=0.0).to(tl.float32)
95
+ b_x1 = tl.load(p_x + R, mask=mask, other=0.0).to(tl.float32)
96
+ if CONJUGATE:
97
+ b_sin = -b_sin
98
+ b_o0 = b_x0 * b_cos - b_x1 * b_sin
99
+ b_o1 = b_x0 * b_sin + b_x1 * b_cos
100
+ # write back result
101
+ p_y = y + (o_t[:, None] * H*D + o_r[None, :])
102
+ tl.store(p_y, b_o0, mask=mask)
103
+ tl.store(p_y + R, b_o1, mask=mask)
104
+ else:
105
+ # We don't want to load x[0, 2, 4, ...] and x[1, 3, 5, ...] separately since both are slow.
106
+ # Instead, we load x0 = x[0, 1, 2, 3, ...] and x1 = x[1, 0, 3, 2, ...].
107
+ # Loading x0 will be fast but x1 will be slow.
108
+ # Then we load cos = cos[0, 0, 1, 1, ...] and sin = sin[0, 0, 1, 1, ...].
109
+ # Then we do the calculation and use tl.where to pick out the right outputs for the even
110
+ # and for the odd indices.
111
+ o_d = tl.arange(0, BD)
112
+ o_d_swap = o_d + ((o_d + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
113
+ o_d_repeat = tl.arange(0, BD) // 2
114
+ p_x0 = x + o_t[:, None] * H*D + o_d[None, :]
115
+ p_x1 = x + o_t[:, None] * H*D + o_d_swap[None, :]
116
+ p_cos = cos + (o_cs[:, None] * R + o_d_repeat[None, :])
117
+ p_sin = sin + (o_cs[:, None] * R + o_d_repeat[None, :])
118
+ mask = (o_cs[:, None] >= 0) & (o_cs[:, None] < TR) & (o_d_repeat[None, :] < R)
119
+
120
+ b_cos = tl.load(p_cos, mask=mask, other=1.0).to(tl.float32)
121
+ b_sin = tl.load(p_sin, mask=mask, other=0.0).to(tl.float32)
122
+ b_x0 = tl.load(p_x0, mask=mask, other=0.0).to(tl.float32)
123
+ b_x1 = tl.load(p_x1, mask=mask, other=0.0).to(tl.float32)
124
+ if CONJUGATE:
125
+ b_sin = -b_sin
126
+ b_o0 = b_x0 * b_cos
127
+ b_o1 = b_x1 * b_sin
128
+ b_y = tl.where(o_d[None, :] % 2 == 0, b_o0 - b_o1, b_o0 + b_o1)
129
+ p_y = y + (o_t[:, None] * H*D + o_d[None, :])
130
+ tl.store(p_y, b_y, mask=mask)
131
+
132
+
133
+ def rotary_embedding_fwdbwd(
134
+ x: torch.Tensor,
135
+ cos: torch.Tensor,
136
+ sin: torch.Tensor,
137
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
138
+ cu_seqlens: Optional[torch.Tensor] = None,
139
+ max_seqlen: Optional[int] = None,
140
+ interleaved: bool = False,
141
+ inplace: bool = False,
142
+ conjugate: bool = False
143
+ ) -> torch.Tensor:
144
+ """
145
+ Args:
146
+ x: [B, T, H, D].
147
+ cos: [TR, R / 2]
148
+ sin: [TR, R / 2]
149
+ seqlen_offsets: integer or integer tensor of size (N,)
150
+ cu_seqlens: (N + 1,) or None
151
+ max_seqlen: int
152
+
153
+ Returns:
154
+ y: [B, T, H, D]
155
+ """
156
+ is_varlen = cu_seqlens is not None
157
+
158
+ B, T, H, D = x.shape
159
+ if not is_varlen:
160
+ N = B
161
+ else:
162
+ assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
163
+ N, T = cu_seqlens.shape[0] - 1, max_seqlen
164
+ TR, R = cos.shape
165
+ assert sin.shape == cos.shape
166
+ R2 = R * 2
167
+
168
+ assert D <= 256, "Only support D <= 256"
169
+ assert TR >= T, "TR must be >= T"
170
+
171
+ assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
172
+ assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
173
+
174
+ if isinstance(seqlen_offsets, torch.Tensor):
175
+ assert seqlen_offsets.shape == (N,)
176
+ assert seqlen_offsets.dtype in [torch.int32, torch.int64]
177
+ else:
178
+ assert seqlen_offsets + T <= TR
179
+
180
+ y = torch.empty_like(x) if not inplace else x
181
+ if R2 < D and not inplace:
182
+ y[..., R2:].copy_(x[..., R2:])
183
+
184
+ BD = triton.next_power_of_2(R2)
185
+ BT = min(128, triton.next_power_of_2(triton.cdiv(T, get_multiprocessor_count(x.device.index))))
186
+
187
+ def grid(meta): return (triton.cdiv(T, meta['BT']), N, H) # noqa
188
+ rotary_embedding_kernel[grid](
189
+ x,
190
+ cos,
191
+ sin,
192
+ y,
193
+ cu_seqlens,
194
+ seqlen_offsets,
195
+ B=B,
196
+ T=T,
197
+ H=H,
198
+ D=D,
199
+ R=R,
200
+ TR=TR,
201
+ BT=BT,
202
+ BD=BD,
203
+ IS_SEQLEN_OFFSETS_TENSOR=isinstance(seqlen_offsets, torch.Tensor),
204
+ IS_VARLEN=is_varlen,
205
+ INTERLEAVED=interleaved,
206
+ CONJUGATE=conjugate
207
+ )
208
+ return y
209
+
210
+
211
+ class RotaryEmbeddingFunction(torch.autograd.Function):
212
+
213
+ @staticmethod
214
+ @input_guard
215
+ def forward(
216
+ ctx,
217
+ x,
218
+ cos,
219
+ sin,
220
+ interleaved=False,
221
+ inplace=False,
222
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
223
+ cu_seqlens: Optional[torch.Tensor] = None,
224
+ max_seqlen: Optional[int] = None,
225
+ ):
226
+ y = rotary_embedding_fwdbwd(
227
+ x,
228
+ cos,
229
+ sin,
230
+ seqlen_offsets=seqlen_offsets,
231
+ cu_seqlens=cu_seqlens,
232
+ max_seqlen=max_seqlen,
233
+ interleaved=interleaved,
234
+ inplace=inplace,
235
+ )
236
+ if isinstance(seqlen_offsets, int):
237
+ # Can't save int with save_for_backward
238
+ ctx.save_for_backward(cos, sin, cu_seqlens)
239
+ ctx.seqlen_offsets = seqlen_offsets
240
+ else:
241
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
242
+ ctx.seqlen_offsets = None
243
+ ctx.interleaved = interleaved
244
+ ctx.inplace = inplace
245
+ ctx.max_seqlen = max_seqlen
246
+ return y if not inplace else x
247
+
248
+ @staticmethod
249
+ @input_guard
250
+ def backward(ctx, do):
251
+ seqlen_offsets = ctx.seqlen_offsets
252
+ if seqlen_offsets is None:
253
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
254
+ else:
255
+ cos, sin, cu_seqlens = ctx.saved_tensors
256
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
257
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
258
+ if not ctx.interleaved and not ctx.inplace:
259
+ do = do.clone()
260
+ dx = rotary_embedding_fwdbwd(
261
+ do,
262
+ cos,
263
+ sin,
264
+ seqlen_offsets=seqlen_offsets,
265
+ cu_seqlens=cu_seqlens,
266
+ max_seqlen=ctx.max_seqlen,
267
+ interleaved=ctx.interleaved,
268
+ inplace=ctx.inplace,
269
+ conjugate=True,
270
+ )
271
+ return dx, None, None, None, None, None, None, None
272
+
273
+
274
+ def rotary_embedding(
275
+ x,
276
+ cos,
277
+ sin,
278
+ interleaved=False,
279
+ inplace=False,
280
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
281
+ cu_seqlens: Optional[torch.Tensor] = None,
282
+ max_seqlen: Optional[int] = None,
283
+ ):
284
+ """
285
+ Args:
286
+ x: [B, T, H, D]
287
+ cos, sin: [TR, R//2]
288
+ interleaved:
289
+ If True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style).
290
+ inplace:
291
+ If True, apply rotary embedding in-place.
292
+ seqlen_offsets: [N,] or int.
293
+ Each sequence in x is shifted by this amount.
294
+ Most commonly used in inference when we have KV cache.
295
+ cu_seqlens: [N + 1,] or None
296
+ max_seqlen: int
297
+
298
+ Returns:
299
+ out: [B, T, H, D]
300
+ """
301
+ return RotaryEmbeddingFunction.apply(
302
+ x,
303
+ cos,
304
+ sin,
305
+ interleaved,
306
+ inplace,
307
+ seqlen_offsets,
308
+ cu_seqlens,
309
+ max_seqlen
310
+ )
311
+
312
+
313
+ class RotaryEmbedding(nn.Module):
314
+ """
315
+ The rotary position embeddings from RoFormer_ (Su et. al).
316
+ A crucial insight from the method is that the query and keys are
317
+ transformed by rotation matrices which depend on the relative positions.
318
+
319
+ Other implementations are available in the Rotary Transformer repo_ and in
320
+ GPT-NeoX_; the GPT-NeoX implementation was an inspiration.
321
+
322
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
323
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
324
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
325
+
326
+ If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
327
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
328
+ Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
329
+ """
330
+
331
+ def __init__(
332
+ self,
333
+ dim: int,
334
+ base: float = 10000.0,
335
+ scale_base: Optional[float] = None,
336
+ interleaved: bool = False,
337
+ pos_idx_in_fp32: bool = True,
338
+ device: Optional[torch.device] = None,
339
+ ):
340
+ """
341
+ interleaved:
342
+ If True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style).
343
+ pos_idx_in_fp32:
344
+ If True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision.
345
+ This option was added because previously (before 2023-07-02), when we construct
346
+ the position indices, we use the dtype of self.inv_freq.
347
+ In most cases this would be fp32, but if the model is trained in pure bf16 (not mixed precision), then
348
+ self.inv_freq would be bf16, and the position indices are also in bf16.
349
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
350
+ embeddings for some positions will coincide.
351
+ To maintain compatibility with models previously trained in pure bf16, we add this option.
352
+ """
353
+ super().__init__()
354
+
355
+ self.dim = dim
356
+ self.base = float(base)
357
+ self.scale_base = scale_base
358
+ self.interleaved = interleaved
359
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
360
+ self.device = device
361
+
362
+ # Generate and save the inverse frequency buffer (non trainable)
363
+ self.register_buffer("inv_freq", torch.empty(-(dim // -2), dtype=torch.float32, device=device), persistent=False)
364
+
365
+ scale = None
366
+ if scale_base is not None:
367
+ scale = torch.empty(-(dim // -2), dtype=torch.float32, device=device)
368
+ self.register_buffer("scale", scale, persistent=False)
369
+
370
+ self._seq_len_cached = 0
371
+ self._cos_cached = None
372
+ self._sin_cached = None
373
+ self._cos_k_cached = None
374
+ self._sin_k_cached = None
375
+
376
+ self.reset_parameters()
377
+
378
+ def reset_parameters(self):
379
+ with torch.no_grad():
380
+ self.inv_freq.copy_(self._compute_inv_freq(device=self.inv_freq.device))
381
+ if self.scale_base is not None:
382
+ self.scale.copy_(self._compute_scale(device=self.scale.device))
383
+
384
+ def __repr__(self):
385
+ s = f"{self.__class__.__name__}("
386
+ s += f"dim={self.dim}, "
387
+ s += f"base={self.base}, "
388
+ s += f"interleaved={self.interleaved}, "
389
+ if self.scale_base is not None:
390
+ s += f"scale_base={self.scale_base}, "
391
+ s += f"pos_idx_in_fp32={self.pos_idx_in_fp32})"
392
+ return s
393
+
394
+ def _compute_inv_freq(self, device=None):
395
+ return 1.0 / (
396
+ self.base
397
+ ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
398
+ )
399
+
400
+ def _compute_scale(self, device=None):
401
+ return (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) + 0.4 * self.dim) / (1.4 * self.dim)
402
+
403
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
404
+ # Reset the tables if the sequence length has changed,
405
+ # if we're on a new device (possibly due to tracing for instance),
406
+ # or if we're switching from inference mode to training
407
+ if (
408
+ seqlen > self._seq_len_cached
409
+ or self._cos_cached is None
410
+ or self._cos_cached.device != device
411
+ or self._cos_cached.dtype != dtype
412
+ or (self.training and self._cos_cached.is_inference())
413
+ ):
414
+ self._seq_len_cached = seqlen
415
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
416
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
417
+ # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
418
+ if self.pos_idx_in_fp32:
419
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
420
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
421
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
422
+ # cos & sin output to change significantly.
423
+ # We want to recompute self.inv_freq if it was not loaded in fp32
424
+ if self.inv_freq.dtype != torch.float32:
425
+ inv_freq = self._compute_inv_freq(device=device)
426
+ else:
427
+ inv_freq = self.inv_freq
428
+ else:
429
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
430
+ inv_freq = self.inv_freq
431
+ # Don't do einsum, it converts fp32 to fp16 under AMP
432
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
433
+ freqs = torch.outer(t, inv_freq)
434
+ if self.scale is None:
435
+ self._cos_cached = torch.cos(freqs).to(dtype)
436
+ self._sin_cached = torch.sin(freqs).to(dtype)
437
+ else:
438
+ power = (
439
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
440
+ - seqlen // 2
441
+ ) / self.scale_base
442
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
443
+ # We want the multiplication by scale to happen in fp32
444
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
445
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
446
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
447
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
448
+
449
+ def forward(
450
+ self,
451
+ q: torch.Tensor,
452
+ k: torch.Tensor,
453
+ seqlen_offset: Union[int, torch.Tensor] = 0,
454
+ cu_seqlens: Optional[torch.Tensor] = None,
455
+ max_seqlen: Optional[int] = None,
456
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
457
+ """
458
+ q: [B, T, H, D]
459
+ k: [B, T, H, D]
460
+ seqlen_offset:
461
+ (N,) or int. Each sequence in x is shifted by this amount.
462
+ Most commonly used in inference when we have KV cache.
463
+ If it's a tensor of shape (N,), then to update the cos / sin cache, one
464
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
465
+ cu_seqlens: (N + 1,) or None
466
+ max_seqlen: int
467
+ """
468
+ if max_seqlen is not None:
469
+ self._update_cos_sin_cache(max_seqlen, device=q.device, dtype=q.dtype)
470
+ elif isinstance(seqlen_offset, int):
471
+ self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
472
+ if self.scale is None:
473
+ q = rotary_embedding(
474
+ q,
475
+ self._cos_cached,
476
+ self._sin_cached,
477
+ interleaved=self.interleaved,
478
+ seqlen_offsets=seqlen_offset,
479
+ cu_seqlens=cu_seqlens,
480
+ max_seqlen=max_seqlen
481
+ )
482
+ k = rotary_embedding(
483
+ k,
484
+ self._cos_cached,
485
+ self._sin_cached,
486
+ interleaved=self.interleaved,
487
+ seqlen_offsets=seqlen_offset,
488
+ cu_seqlens=cu_seqlens,
489
+ max_seqlen=max_seqlen
490
+ )
491
+
492
+ else:
493
+ q = rotary_embedding(
494
+ q,
495
+ self._cos_cached,
496
+ self._sin_cached,
497
+ interleaved=self.interleaved,
498
+ seqlen_offsets=seqlen_offset,
499
+ cu_seqlens=cu_seqlens,
500
+ max_seqlen=max_seqlen
501
+ )
502
+ k = rotary_embedding(
503
+ k,
504
+ self._cos_k_cached,
505
+ self._sin_k_cached,
506
+ interleaved=self.interleaved,
507
+ seqlen_offsets=seqlen_offset,
508
+ cu_seqlens=cu_seqlens,
509
+ max_seqlen=max_seqlen
510
+ )
511
+
512
+ return q, k
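A quick way to sanity-check the Triton rotary kernel in this file is to compare it against the pure-PyTorch reference `rotary_embedding_ref` defined near the top. The sketch below is illustrative only (not part of the commit): it assumes a CUDA device, since the kernel is launched on the input's device, and it reads the module's internal `_cos_cached`/`_sin_cached` buffers, which are implementation details.

```python
# Illustrative check of the Triton path against the PyTorch reference in this file.
# Assumes a CUDA device and relies on the internal cos/sin cache attributes.
import torch
from fla.modules.rotary import RotaryEmbedding, rotary_embedding_ref

B, T, H, D = 2, 64, 4, 128
q = torch.randn(B, T, H, D, device='cuda')
k = torch.randn(B, T, H, D, device='cuda')

rope = RotaryEmbedding(dim=D).to('cuda')
q_rot, k_rot = rope(q, k)  # Triton kernel path (GPT-NeoX style, non-interleaved)

# Reference path reuses the cos/sin cache built by the forward call above.
q_ref = rotary_embedding_ref(q, rope._cos_cached, rope._sin_cached)
torch.testing.assert_close(q_rot, q_ref, rtol=1e-3, atol=1e-3)
```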
fla/ops/abc/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (234 Bytes).
 
fla/ops/abc/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (73.4 kB).
 
fla/ops/attn/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .parallel import parallel_attn
4
+
5
+ __all__ = [
6
+ 'parallel_attn'
7
+ ]
fla/ops/attn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (242 Bytes).
 
fla/ops/based/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .fused_chunk import fused_chunk_based
4
+ from .parallel import parallel_based
5
+
6
+ __all__ = [
7
+ 'fused_chunk_based',
8
+ 'parallel_based'
9
+ ]
fla/ops/based/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (323 Bytes).
 
fla/ops/based/__pycache__/fused_chunk.cpython-311.pyc ADDED
Binary file (22.9 kB).
 
fla/ops/based/fused_chunk.py ADDED
@@ -0,0 +1,374 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
11
+
12
+
13
+ @triton.jit(do_not_specialize=['T'])
14
+ def fused_chunk_based_fwd_kernel(
15
+ q,
16
+ k,
17
+ v,
18
+ o,
19
+ z,
20
+ scale, # K ** -0.5
21
+ T,
22
+ B: tl.constexpr,
23
+ H: tl.constexpr,
24
+ K: tl.constexpr,
25
+ V: tl.constexpr,
26
+ BT: tl.constexpr,
27
+ BK: tl.constexpr,
28
+ BV: tl.constexpr,
29
+ ):
30
+ # indices
31
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
32
+
33
+ o_i = tl.arange(0, BT)
34
+
35
+ # [BT, BT]
36
+ m_s = o_i[:, None] >= o_i[None, :]
37
+
38
+ # [BV], zero-order taylor expansion
39
+ b_h_0o = tl.zeros([BV], dtype=tl.float32)
40
+ # [BK, BV], first-order taylor expansion
41
+ b_h_1o = tl.zeros([BK, BV], dtype=tl.float32)
42
+ # [BK, BK, BV] second-order taylor expansion
43
+ b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
44
+
45
+ # make block pointers
46
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
47
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
48
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
49
+ p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
50
+
51
+ p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT)
52
+ k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
53
+ k_1o = tl.zeros([1, BK], dtype=tl.float32)
54
+ k_0o = 0
55
+
56
+ for i in range(0, tl.cdiv(T, BT)):
57
+ # [BK, BT]
58
+ b_k = tl.load(p_k, boundary_check=(0, 1))
59
+ # [BK*BK, BT]
60
+ b_k_2o = b_k[:, None, :] * b_k[None, :, :]
61
+ b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype)
62
+ # [BT, BV]
63
+ b_v = tl.load(p_v, boundary_check=(0, 1))
64
+ # [BT, BK]
65
+ b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype)
66
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
67
+ b_z = tl.zeros([BT], dtype=tl.float32)
68
+
69
+ # interchunk
70
+ # zero-order
71
+ b_o += b_h_0o
72
+ b_z += k_0o
73
+ # first-order
74
+ b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False)
75
+ b_z += tl.sum(b_q * k_1o, axis=1)
76
+ # second-order
77
+ b_q_2o = b_q[:, :, None] * b_q[:, None, :]
78
+ b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype)
79
+ b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5
80
+ b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5
81
+
82
+ # update running statistics
83
+ k_1o += tl.sum(b_k, axis=1)[None, :]
84
+ k_2o += tl.sum(b_k_2o, axis=1)[None, :]
85
+ k_0o += BT
86
+
87
+ # intrachunk
88
+ # [BT, BT]
89
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
90
+ b_s = 1 + b_s + 0.5 * b_s * b_s
91
+ b_s = tl.where(m_s, b_s, 0)
92
+ b_z += tl.sum(b_s, axis=1)
93
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
94
+ # [BT, BV]
95
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
96
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T)
97
+
98
+ # update hidden state
99
+ # [BK, BV]
100
+ b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False)
101
+ b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False)
102
+ b_h_0o = b_h_0o + tl.sum(b_v, axis=0)
103
+
104
+ p_q = tl.advance(p_q, (BT, 0))
105
+ p_k = tl.advance(p_k, (0, BT))
106
+ p_v = tl.advance(p_v, (BT, 0))
107
+ p_o = tl.advance(p_o, (BT, 0))
108
+ p_z += BT
109
+
110
+
111
+ # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
112
+ @triton.jit
113
+ def fused_chunk_based_bwd_kernel(
114
+ # NV: number of splits in the V dimension. NK: number of splits in the K dimension
115
+ q,
116
+ k,
117
+ v,
118
+ do,
119
+ dz,
120
+ dq,
121
+ dk,
122
+ dv,
123
+ scale, # K ** -0.5
124
+ T,
125
+ B: tl.constexpr,
126
+ H: tl.constexpr,
127
+ K: tl.constexpr,
128
+ V: tl.constexpr,
129
+ BT: tl.constexpr,
130
+ BK: tl.constexpr,
131
+ BV: tl.constexpr,
132
+ ):
133
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
134
+
135
+ o_i = tl.arange(0, BT)
136
+ m_s = o_i[:, None] >= o_i[None, :]
137
+
138
+ # [BV], zero-order taylor expansion
139
+ # b_h_0o = tl.zeros([BV], dtype=tl.float32)
140
+ # [BK, BV], first-order taylor expansion
141
+ b_h_1o = tl.zeros([BV, BK], dtype=tl.float32)
142
+ # [BK, BK, BV] second-order taylor expansion
143
+ b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32)
144
+
145
+ k_1o = tl.zeros([1, BK], dtype=tl.float32)
146
+ k_2o = tl.zeros([1, BK * BK], dtype=tl.float32)
147
+
148
+ for i in range(0, tl.cdiv(T, BT)):
149
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
150
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
151
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
152
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
153
+ p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))
154
+ p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT
155
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
156
+
157
+ # load tensors
158
+ # [BT, BK]
159
+ b_q = tl.load(p_q, boundary_check=(0, 1))
160
+ b_q = (b_q * scale).to(b_q.dtype)
161
+ b_k = tl.load(p_k, boundary_check=(0, 1))
162
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
163
+ b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T)
164
+ # [BV, BT]
165
+ b_v = tl.load(p_v, boundary_check=(0, 1))
166
+
167
+ # inter-chunk
168
+ b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False)
169
+ if i_v == 0:
170
+ b_dq += b_dz[:, None] * k_1o
171
+ b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5
172
+ if i_v == 0:
173
+ b_dq_2o += (b_dz[:, None] * k_2o) * 0.5
174
+ b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK])
175
+ b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1)
176
+ b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2)
177
+ b_dq *= scale
178
+
179
+ # intra-chunk
180
+ # [BT, BT]
181
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
182
+ if i_v == 0:
183
+ b_ds += b_dz[:, None]
184
+ b_ds = tl.where(m_s, b_ds, 0) * scale
185
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
186
+ b_s = tl.where(m_s, b_s, 0)
187
+ b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False)
188
+
189
+ # store
190
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
191
+
192
+ # update hidden state
193
+ # [BT, BK*BK]
194
+ b_k_2o = b_k[:, :, None] * b_k[:, None, :]
195
+ b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
196
+ # [BV, BK*BK]
197
+ b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False)
198
+ # [BV, BK]
199
+ b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False)
200
+
201
+ if i_v == 0:
202
+ # update running statistics
203
+ k_1o += tl.sum(b_k, axis=0)[None, :]
204
+ k_2o += tl.sum(b_k_2o, axis=0)[None, :]
205
+
206
+ tl.debug_barrier()
207
+ b_h_1o = None
208
+ b_h_2o = None
209
+
210
+ # [BK, BV], first-order taylor expansion
211
+ b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32)
212
+ # [BK, BK, BV] second-order taylor expansion
213
+ b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32)
214
+ b_dh_0o = tl.zeros([BV], dtype=tl.float32)
215
+ m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]
216
+
217
+ dq_1o = tl.zeros([1, BK], dtype=tl.float32)
218
+ dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32)
219
+
220
+ for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT):
221
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BT), (0, 1))
222
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i, i_k * BK), (BT, BK), (1, 0))
223
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0))
224
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0))
225
+ p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (i, i_k*BK), (BT, BK), (1, 0))
226
+ p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (i, i_v*BV), (BT, BV), (1, 0))
227
+ p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i
228
+
229
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
230
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
231
+
232
+ b_q = tl.load(p_q, boundary_check=(0, 1))
233
+ b_k = tl.load(p_k, boundary_check=(0, 1))
234
+ b_v = tl.load(p_v, boundary_check=(0, 1))
235
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
236
+ b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T)
237
+ b_q = (b_q * scale).to(b_k.dtype)
238
+
239
+ # intra chunk
240
+ b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
241
+ if i_v == 0:
242
+ b_ds += b_dz[None, :]
243
+ b_ds = tl.where(m_s, b_ds, 0)
244
+ b_s = tl.dot(b_k, b_q, allow_tf32=False)
245
+ b_s2 = 1 + b_s + 0.5 * b_s * b_s
246
+ b_s = tl.where(m_s, b_s, 0)
247
+ b_s2 = tl.where(m_s, b_s2, 0)
248
+ b_ds *= (1+b_s)
249
+
250
+ b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False)
251
+ b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False)
252
+
253
+ # inter chunk
254
+ b_k_2o = b_k[:, :, None] * b_k[:, None, :]
255
+ b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype)
256
+
257
+ b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False)
258
+ b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False)
259
+ b_dv += b_dh_0o
260
+
261
+ b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False)
262
+
263
+ if i_v == 0:
264
+ b_dk += dq_1o
265
+
266
+ b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False)
267
+ if i_v == 0:
268
+ b_dk_2o += dq_2o
269
+ b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT])
270
+ b_k_fp32 = tl.trans(b_k.to(tl.float32))
271
+ b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0)
272
+ b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1)
273
+ b_dk += tl.trans(b_dk2)
274
+
275
+ # hidden state update
276
+ b_dh_0o += tl.sum(b_do, axis=0)
277
+ b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False)
278
+ b_q_2o = b_q[None, :, :] * b_q[:, None, :]
279
+ b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype)
280
+ b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5
281
+
282
+ if i_v == 0:
283
+ dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :]
284
+ dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None]
285
+
286
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
287
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
288
+
289
+
290
+ class FusedChunkBasedFunction(torch.autograd.Function):
291
+
292
+ @staticmethod
293
+ @input_guard
294
+ @autocast_custom_fwd
295
+ def forward(ctx, q, k, v, scale=1):
296
+ B, H, T, K, V = *k.shape, v.shape[-1]
297
+
298
+ scale = scale
299
+ BT = 16
300
+ BK, BV = min(K, 16), min(V, 32)
301
+ BK, BV = max(BK, 16), max(BV, 16)
302
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
303
+
304
+ num_warps = 4
305
+
306
+ # the norm of o might explode, so we need to use float32 here
307
+ o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)
308
+ z = q.new_empty(NK, B, H, T, dtype=torch.float32)
309
+
310
+ grid = (NV, NK, B * H)
311
+ fused_chunk_based_fwd_kernel[grid](
312
+ q, k, v, o, z,
313
+ scale,
314
+ T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
315
+ num_warps=num_warps,
316
+ )
317
+ o = o.sum(0)
318
+ z = z.sum(0)
319
+ ctx.save_for_backward(q, k, v)
320
+ ctx.scale = scale
321
+ return o.to(q.dtype), z.to(z.dtype)
322
+
323
+ @staticmethod
324
+ @input_guard
325
+ @autocast_custom_bwd
326
+ def backward(ctx, do, dz):
327
+ q, k, v = ctx.saved_tensors
328
+ B, H, T, K, V = *k.shape, v.shape[-1]
329
+ scale = ctx.scale
330
+
331
+ BT = 16
332
+ BK, BV = min(K, 16), min(V, 32)
333
+ BK, BV = max(BK, 16), max(BV, 16)
334
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
335
+ num_stages = 1
336
+ num_warps = 4
337
+
338
+ dq = q.new_empty(NV, B, H, T, K)
339
+ dk = q.new_empty(NV, B, H, T, K)
340
+ dv = q.new_empty(NK, B, H, T, V)
341
+ grid = (NV, NK, B * H)
342
+
343
+ fused_chunk_based_bwd_kernel[grid](
344
+ q, k, v, do, dz, dq, dk, dv,
345
+ scale,
346
+ T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
347
+ num_warps=num_warps,
348
+ num_stages=num_stages
349
+ )
350
+ dq = dq.sum(0)
351
+ dk = dk.sum(0)
352
+ dv = dv.sum(0)
353
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None
354
+
355
+
356
+ def fused_chunk_based(
357
+ q: torch.Tensor,
358
+ k: torch.Tensor,
359
+ v: torch.Tensor,
360
+ scale: Optional[float] = None,
361
+ use_norm: bool = True,
362
+ head_first: bool = True
363
+ ):
364
+ assert q.shape[-1] <= 16, 'only support feature dimension up to 16.'
365
+ if scale is None:
366
+ scale = q.shape[-1] ** -0.5
367
+ if not head_first:
368
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
369
+ o, z = FusedChunkBasedFunction.apply(q, k, v, scale)
370
+ if use_norm:
371
+ o = o / (z[..., None] + 1e-6)
372
+ if not head_first:
373
+ o = o.transpose(1, 2)
374
+ return o.to(q.dtype)
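The forward kernel above carries zeroth-, first- and second-order running states (`b_h_0o`, `b_h_1o`, `b_h_2o`) across chunks. That works because the Taylor feature map used here, 1 + q·k + ½(q·k)², is exactly an inner product of lifted features φ(x) = [1, x, x⊗x/√2]. A small standalone check of that identity (illustrative only, not part of the commit):

```python
# The second-order Taylor kernel equals an inner product of lifted features,
# which is what lets the chunked "based" kernels keep running state per order.
import torch

def phi(x: torch.Tensor) -> torch.Tensor:
    # x: [..., d] -> [..., 1 + d + d*d]
    ones = torch.ones_like(x[..., :1])
    outer = (x.unsqueeze(-1) * x.unsqueeze(-2)).flatten(-2) / (2 ** 0.5)
    return torch.cat([ones, x, outer], dim=-1)

q, k = torch.randn(16), torch.randn(16)
lhs = 1 + q @ k + 0.5 * (q @ k) ** 2   # kernel as computed in the Triton code
rhs = phi(q) @ phi(k)                  # inner product of lifted features
torch.testing.assert_close(lhs, rhs, rtol=1e-4, atol=1e-4)
```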
fla/ops/based/naive.py ADDED
@@ -0,0 +1,72 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import rearrange
7
+
8
+
9
+ def naive_parallel_based(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ scale: Optional[float] = None,
14
+ use_norm: bool = True
15
+ ):
16
+ if scale is None:
17
+ scale = q.shape[-1] ** -0.5
18
+ q = q * scale
19
+ attn = q @ k.transpose(-2, -1)
20
+ attn = 1 + attn + 1/2 * (attn ** 2)
21
+ attn.masked_fill_(~torch.tril(torch.ones(
22
+ q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
23
+ o = attn @ v
24
+ if use_norm:
25
+ z = attn.sum(-1)
26
+ return o / (z[..., None] + 1e-6)
27
+ else:
28
+ return o
29
+
30
+
31
+ def naive_chunk_based(q, k, v, chunk_size=256):
32
+ q = q * (q.shape[-1] ** -0.5)
33
+ # compute normalizer.
34
+ k_cumsum = torch.cumsum(k, dim=-2)
35
+ kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
36
+ # first
37
+ z = (q * k_cumsum).sum(-1)
38
+ # second order
39
+ z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
40
+ # zero-th order
41
+ z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]
42
+
43
+ # compute o
44
+ # constant term
45
+ _o = v.cumsum(-2)
46
+
47
+ q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)
48
+
49
+ k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
50
+ v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
51
+
52
+ intra_chunk_attn = q @ k.transpose(-2, -1)
53
+ intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
54
+ intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0)
55
+ o = intra_chunk_attn @ v
56
+
57
+ # quadratic term
58
+ kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
59
+ kv = kv.cumsum(2)
60
+ kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
61
+
62
+ o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)
63
+
64
+ # linear term
65
+ kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
66
+ kv = kv.cumsum(2)
67
+ kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
68
+ o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)
69
+
70
+ o = rearrange(o, 'b h n c d -> b h (n c) d')
71
+ o = o + _o
72
+ return o / (z[..., None] + 1e-6)
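The two reference implementations above should agree up to numerical error, since the chunked version only regroups the same zeroth-, first- and second-order terms of the Taylor-expanded kernel into intra- and inter-chunk contributions. A hedged cross-check sketch follows; the shapes and tolerances are illustrative assumptions, and T must be divisible by chunk_size because of the rearrange.

import torch
from fla.ops.based.naive import naive_parallel_based, naive_chunk_based

B, H, T, K, V = 1, 2, 512, 16, 32
q = torch.randn(B, H, T, K)
k = torch.randn(B, H, T, K)
v = torch.randn(B, H, T, V)
ref = naive_parallel_based(q, k, v, use_norm=True)
out = naive_chunk_based(q, k, v, chunk_size=256)
# tolerances below are a loose, illustrative choice
torch.testing.assert_close(ref, out, rtol=1e-3, atol=1e-3)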
fla/ops/common/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # -*- coding: utf-8 -*-
fla/ops/common/__pycache__/chunk_delta_h.cpython-311.pyc ADDED
Binary file (24.5 kB). View file
 
fla/ops/common/__pycache__/chunk_h.cpython-311.pyc ADDED
Binary file (25.4 kB). View file
 
fla/ops/common/__pycache__/chunk_o.cpython-311.pyc ADDED
Binary file (37.8 kB). View file
 
fla/ops/common/__pycache__/utils.cpython-311.pyc ADDED
Binary file (5.02 kB). View file
 
fla/ops/common/chunk_delta_h.py ADDED
@@ -0,0 +1,399 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper, use_cuda_graph
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_G': lambda args: args['g'] is not None,
19
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
20
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
21
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
22
+ })
23
+ @triton.autotune(
24
+ configs=[
25
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
26
+ for num_warps in NUM_WARPS
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
30
+ use_cuda_graph=use_cuda_graph,
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_gated_delta_rule_fwd_kernel_h(
34
+ k,
35
+ v,
36
+ d,
37
+ v_new,
38
+ g,
39
+ h,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ chunk_offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BC: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ NT: tl.constexpr,
53
+ USE_G: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ STORE_FINAL_STATE: tl.constexpr,
56
+ USE_OFFSETS: tl.constexpr,
57
+ HEAD_FIRST: tl.constexpr,
58
+ ):
59
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+ i_n, i_h = i_nh // H, i_nh % H
61
+ if USE_OFFSETS:
62
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
63
+ T = eos - bos
64
+ NT = tl.cdiv(T, BT)
65
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
66
+ else:
67
+ bos, eos = i_n * T, i_n * T + T
68
+ NT = tl.cdiv(T, BT)
69
+ boh = i_n * NT
70
+
71
+ # [BK, BV]
72
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
73
+ if USE_INITIAL_STATE:
74
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
75
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
76
+
77
+ for i_t in range(NT):
78
+ if HEAD_FIRST:
79
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ else:
81
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
82
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
83
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
84
+ if USE_G:
85
+ last_idx = min((i_t + 1) * BT, T) - 1
86
+ if HEAD_FIRST:
87
+ b_g_last = tl.load(g + i_nh * T + last_idx)
88
+ else:
89
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
90
+ else:
91
+ b_g_last = None
92
+ last_idx = None
93
+ # since the whole key-dimension block must be kept in SRAM, we face a severe SRAM burden; subchunking over BC alleviates it
94
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
95
+ if HEAD_FIRST:
96
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
97
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
98
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
99
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
100
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
101
+ else:
102
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
103
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
104
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
105
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
106
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT+i_c*BC, ), (BC,), (0,)) if USE_G else None
107
+ b_g = tl.load(p_g, boundary_check=(0, )) if USE_G else None
108
+ # [BK, BC]
109
+ b_k = tl.load(p_k, boundary_check=(0, 1))
110
+ b_k = (b_k * exp(b_g_last - b_g)[None, :]).to(b_k.dtype) if USE_G else b_k
111
+ # [BC, BK]
112
+ b_d = tl.load(p_d, boundary_check=(0, 1))
113
+ b_d = (b_d * exp(b_g)[:, None]).to(b_d.dtype) if USE_G else b_d
114
+ # [BC, BV]
115
+ b_v = tl.load(p_v, boundary_check=(0, 1))
116
+ b_v2 = b_v - tl.dot(b_d, b_h.to(b_d.dtype))
117
+ # [BK, BV]
118
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
119
+ b_hc += tl.dot(b_k, b_v2.to(b_k.dtype), allow_tf32=False)
120
+ b_h *= exp(b_g_last) if USE_G else 1
121
+ b_h += b_hc
122
+
123
+ if STORE_FINAL_STATE:
124
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
125
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
126
+
127
+
128
+ @triton.heuristics({
129
+ 'USE_G': lambda args: args['g'] is not None,
130
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
131
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
132
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
133
+ })
134
+ @triton.autotune(
135
+ configs=[
136
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
137
+ for num_warps in NUM_WARPS
138
+ for num_stages in [2, 3, 4]
139
+ ],
140
+ key=['BT', 'BK', 'BV', 'USE_G'],
141
+ use_cuda_graph=use_cuda_graph,
142
+ )
143
+ @triton.jit(do_not_specialize=['T'])
144
+ def chunk_gated_delta_rule_bwd_kernel_dhu(
145
+ q,
146
+ k,
147
+ d,
148
+ g,
149
+ dht,
150
+ dh0,
151
+ do,
152
+ dh,
153
+ dv,
154
+ dv2,
155
+ offsets,
156
+ chunk_offsets,
157
+ scale,
158
+ T,
159
+ H: tl.constexpr,
160
+ K: tl.constexpr,
161
+ V: tl.constexpr,
162
+ BT: tl.constexpr,
163
+ BC: tl.constexpr,
164
+ BK: tl.constexpr,
165
+ BV: tl.constexpr,
166
+ USE_G: tl.constexpr,
167
+ USE_INITIAL_STATE: tl.constexpr,
168
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
169
+ USE_OFFSETS: tl.constexpr,
170
+ HEAD_FIRST: tl.constexpr
171
+ ):
172
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
173
+ i_n, i_h = i_nh // H, i_nh % H
174
+ if USE_OFFSETS:
175
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
176
+ T = eos - bos
177
+ NT = tl.cdiv(T, BT)
178
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
179
+ else:
180
+ bos, eos = i_n * T, i_n * T + T
181
+ NT = tl.cdiv(T, BT)
182
+ boh = i_n * NT
183
+
184
+ # [BK, BV]
185
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
186
+ if USE_FINAL_STATE_GRADIENT:
187
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
188
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
189
+
190
+ for i_t in range(NT - 1, -1, -1):
191
+ if HEAD_FIRST:
192
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
193
+ else:
194
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
195
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
196
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
197
+ if USE_G:
198
+ last_idx = min((i_t + 1) * BT, T) - 1
199
+ if HEAD_FIRST:
200
+ bg_last = tl.load(g + i_nh * T + last_idx)
201
+ else:
202
+ bg_last = tl.load(g + (bos + last_idx) * H + i_h)
203
+ else:
204
+ bg_last = None
205
+ last_idx = None
206
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
207
+ if HEAD_FIRST:
208
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
209
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
210
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
211
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
212
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
213
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
214
+ p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
215
+ else:
216
+ p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
217
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
218
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
219
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
220
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
221
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT + i_c * BC,), (BC,), (0,)) if USE_G else None
222
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
223
+ b_g = tl.load(p_g, boundary_check=(0,)) if USE_G else None
224
+ # [BK, BT]
225
+ b_q = tl.load(p_q, boundary_check=(0, 1))
226
+ b_q = (b_q * scale * exp(b_g)[None, :]).to(b_q.dtype) if USE_G else (b_q * scale).to(b_q.dtype)
227
+ # [BT, BK]
228
+ b_k = tl.load(p_k, boundary_check=(0, 1))
229
+ b_d = tl.load(p_d, boundary_check=(0, 1))
230
+ b_k = (b_k * exp(bg_last - b_g)[:, None]).to(b_k.dtype) if USE_G else b_k
231
+ b_d = (b_d * exp(b_g)[None, :]).to(b_d.dtype) if USE_G else b_d
232
+ # [BT, V]
233
+ b_do = tl.load(p_do, boundary_check=(0, 1))
234
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
235
+ b_dv2 = b_dv + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
236
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
237
+ # [BK, BV]
238
+ b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
239
+ b_dh_tmp -= tl.dot(b_d, b_dv2.to(b_q.dtype), allow_tf32=False)
240
+ b_dh *= exp(bg_last) if USE_G else 1
241
+ b_dh += b_dh_tmp
242
+
243
+ if USE_INITIAL_STATE:
244
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
245
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
246
+
247
+
248
+ def chunk_gated_delta_rule_fwd_h(
249
+ k: torch.Tensor,
250
+ w: torch.Tensor,
251
+ u: torch.Tensor,
252
+ g: Optional[torch.Tensor] = None,
253
+ initial_state: Optional[torch.Tensor] = None,
254
+ output_final_state: bool = False,
255
+ offsets: Optional[torch.LongTensor] = None,
256
+ indices: Optional[torch.LongTensor] = None,
257
+ head_first: bool = True,
258
+ chunk_size: int = 64
259
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
260
+ if head_first:
261
+ B, H, T, K, V = *k.shape, u.shape[-1]
262
+ else:
263
+ B, T, H, K, V = *k.shape, u.shape[-1]
264
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
265
+ # N: the actual number of sequences in the batch with either equal or variable lengths
266
+ if offsets is None:
267
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
268
+ else:
269
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
270
+ BK = triton.next_power_of_2(K)
271
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
272
+ # H100 can have larger block size
273
+ if check_shared_mem('hopper', k.device.index):
274
+ BV = 64
275
+ BC = 64 if K <= 128 else 32
276
+ # A100
277
+ elif check_shared_mem('ampere', k.device.index):
278
+ BV = 32
279
+ BC = 64
280
+ else:
281
+ BV = 32
282
+ BC = 32 if K <= 128 else 16
283
+ BC = min(BT, BC)
284
+ NK = triton.cdiv(K, BK)
285
+ NV = triton.cdiv(V, BV)
286
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
287
+
288
+ if head_first:
289
+ h = k.new_empty(B, H, NT, K, V)
290
+ else:
291
+ h = k.new_empty(B, NT, H, K, V)
292
+ final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
293
+
294
+ v_new = torch.empty_like(u)
295
+ grid = (NK, NV, N * H)
296
+
297
+ chunk_gated_delta_rule_fwd_kernel_h[grid](
298
+ k=k,
299
+ v=u,
300
+ d=w,
301
+ v_new=v_new,
302
+ g=g,
303
+ h=h,
304
+ h0=initial_state,
305
+ ht=final_state,
306
+ offsets=offsets,
307
+ chunk_offsets=chunk_offsets,
308
+ T=T,
309
+ H=H,
310
+ K=K,
311
+ V=V,
312
+ BT=BT,
313
+ BC=BC,
314
+ BK=BK,
315
+ BV=BV,
316
+ NT=NT,
317
+ HEAD_FIRST=head_first
318
+ )
319
+ return h, v_new, final_state
320
+
321
+
322
+ def chunk_gated_delta_rule_bwd_dhu(
323
+ q: torch.Tensor,
324
+ k: torch.Tensor,
325
+ w: torch.Tensor,
326
+ g: torch.Tensor,
327
+ h0: torch.Tensor,
328
+ dht: Optional[torch.Tensor],
329
+ do: torch.Tensor,
330
+ dv: torch.Tensor,
331
+ scale: float,
332
+ offsets: Optional[torch.LongTensor] = None,
333
+ indices: Optional[torch.LongTensor] = None,
334
+ head_first: bool = True,
335
+ chunk_size: int = 64
336
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
337
+ if head_first:
338
+ B, H, T, K, V = *q.shape, do.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *q.shape, do.shape[-1]
341
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
342
+ # N: the actual number of sequences in the batch with either equal or variable lengths
343
+ if offsets is None:
344
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
345
+ else:
346
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
347
+
348
+ BK = triton.next_power_of_2(K)
349
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
350
+
351
+ # H100
352
+ if check_shared_mem('hopper', q.device.index):
353
+ BV = 64
354
+ BC = 64 if K <= 128 else 32
355
+ # A100
356
+ elif check_shared_mem('ampere', q.device.index):
357
+ BV = 32
358
+ BC = 64 if K <= 128 else 32
359
+ else:
360
+ BV = 32 if K <= 128 else 16
361
+ BC = 16
362
+
363
+ BC = min(BT, BC)
364
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
365
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
366
+
367
+ if head_first:
368
+ dh = q.new_empty(B, H, NT, K, V)
369
+ else:
370
+ dh = q.new_empty(B, NT, H, K, V)
371
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
372
+ dv2 = torch.empty_like(dv)
373
+
374
+ grid = (NK, NV, N * H)
375
+ chunk_gated_delta_rule_bwd_kernel_dhu[grid](
376
+ q=q,
377
+ k=k,
378
+ d=w,
379
+ g=g,
380
+ dht=dht,
381
+ dh0=dh0,
382
+ do=do,
383
+ dh=dh,
384
+ dv=dv,
385
+ dv2=dv2,
386
+ offsets=offsets,
387
+ chunk_offsets=chunk_offsets,
388
+ scale=scale,
389
+ T=T,
390
+ H=H,
391
+ K=K,
392
+ V=V,
393
+ BT=BT,
394
+ BC=BC,
395
+ BK=BK,
396
+ BV=BV,
397
+ HEAD_FIRST=head_first
398
+ )
399
+ return dh, dh0, dv2
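A shape-level usage sketch for the forward driver above in head-first layout. The tensors are random placeholders purely to exercise the interface (g holds per-step log decay factors, typically non-positive), so treat every value below as an assumption; a CUDA device with Triton is required.

import torch
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h

B, H, T, K, V = 2, 4, 1024, 64, 64
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
w = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)  # passed to the kernel as d
u = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)  # passed to the kernel as v
g = -torch.rand(B, H, T, device='cuda', dtype=torch.float32)      # placeholder log decays
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
    k=k, w=w, u=u, g=g,
    output_final_state=True, head_first=True, chunk_size=64,
)
# h: [B, H, NT, K, V] per-chunk states, v_new: [B, H, T, V], final_state: [B, H, K, V]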
fla/ops/common/chunk_h_parallel.py ADDED
@@ -0,0 +1,650 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ Fully parallelized state passing.
6
+ """
7
+
8
+ from typing import Optional, Tuple
9
+
10
+ import torch
11
+ import triton
12
+ import triton.language as tl
13
+
14
+ from fla.ops.utils.op import exp
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
25
+ for BK in [32, 64, 128]
26
+ for BV in [32, 64, 128]
27
+ for num_warps in [2, 4, 8]
28
+ for num_stages in [2, 3, 4]
29
+ ],
30
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_fwd_kernel_h_parallel(
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ gk,
39
+ gv,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ indices,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_G: tl.constexpr,
52
+ USE_GK: tl.constexpr,
53
+ USE_GV: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ STORE_FINAL_STATE: tl.constexpr,
56
+ USE_OFFSETS: tl.constexpr,
57
+ HEAD_FIRST: tl.constexpr
58
+ ):
59
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+
61
+ NV = tl.cdiv(V, BV)
62
+ # i_b: batch index
63
+ # i_h: head index
64
+ # i_n: sequence index
65
+ # i_t: chunk index within current sequence
66
+ # i_tg: (global) chunk index across all sequences
67
+ i_k, i_v = i_kv // NV, i_kv % NV
68
+ i_b, i_h = i_bh // H, i_bh % H
69
+ if USE_OFFSETS:
70
+ i_tg = i_t
71
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
72
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
73
+ T = eos - bos
74
+ NT = tl.cdiv(T, BT)
75
+ else:
76
+ bos, eos = i_b * T, i_b * T + T
77
+ NT = tl.cdiv(T, BT)
78
+ i_n, i_tg = i_b, i_b * NT + i_t
79
+ i_nh = i_n * H + i_h
80
+
81
+ if HEAD_FIRST:
82
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
83
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
84
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
85
+ else:
86
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
87
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
89
+
90
+ if i_t == 0:
91
+ if USE_INITIAL_STATE:
92
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
93
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
94
+ else:
95
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
96
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+ # [BK, BT]
99
+ b_k = tl.load(p_k, boundary_check=(0, 1))
100
+ # [BT, BV]
101
+ b_v = tl.load(p_v, boundary_check=(0, 1))
102
+
103
+ last_idx = min(i_t * BT + BT, T) - 1
104
+ # scalar decay
105
+ if USE_G:
106
+ if HEAD_FIRST:
107
+ b_g_last = tl.load(g + i_bh * T + last_idx)
108
+ p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
109
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
110
+ else:
111
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
112
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
113
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
114
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
115
+
116
+ # vector decay, h = Diag(gk) @ h
117
+ if USE_GK:
118
+ if HEAD_FIRST:
119
+ p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
120
+ p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
121
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
122
+ else:
123
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
124
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
125
+
126
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
127
+
128
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
129
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
130
+
131
+ # vector decay, h = h @ Diag(gv)
132
+ if USE_GV:
133
+ if HEAD_FIRST:
134
+ p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
135
+ p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
136
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
137
+ else:
138
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
139
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
140
+
141
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
142
+
143
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
144
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
145
+
146
+ b_h = tl.dot(b_k, b_v)
147
+ if i_t < NT - 1:
148
+ if HEAD_FIRST:
149
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t + 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
150
+ else:
151
+ p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
152
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
153
+ elif STORE_FINAL_STATE:
154
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
155
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
156
+
157
+
158
+ @triton.heuristics({
159
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
160
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
161
+ })
162
+ @triton.autotune(
163
+ configs=[
164
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
165
+ for BK in [32, 64, 128]
166
+ for BV in [32, 64, 128]
167
+ for num_warps in [2, 4, 8, 16]
168
+ for num_stages in [2, 3]
169
+ ],
170
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
171
+ )
172
+ @triton.jit(do_not_specialize=['T'])
173
+ def chunk_fwd_kernel_h_reduction(
174
+ h,
175
+ g,
176
+ gk,
177
+ gv,
178
+ kvt,
179
+ ht,
180
+ offsets,
181
+ chunk_offsets,
182
+ T,
183
+ H: tl.constexpr,
184
+ K: tl.constexpr,
185
+ V: tl.constexpr,
186
+ BT: tl.constexpr,
187
+ BK: tl.constexpr,
188
+ BV: tl.constexpr,
189
+ USE_G: tl.constexpr,
190
+ USE_GK: tl.constexpr,
191
+ USE_GV: tl.constexpr,
192
+ STORE_FINAL_STATE: tl.constexpr,
193
+ USE_OFFSETS: tl.constexpr,
194
+ HEAD_FIRST: tl.constexpr
195
+ ):
196
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
197
+ i_n, i_h = i_nh // H, i_nh % H
198
+ if USE_OFFSETS:
199
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
200
+ T = eos - bos
201
+ NT = tl.cdiv(T, BT)
202
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
203
+ else:
204
+ bos, eos = i_n * T, i_n * T + T
205
+ NT = tl.cdiv(T, BT)
206
+ boh = i_n * NT
207
+
208
+ # [BK, BV]
209
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
210
+ for i_t in range(NT):
211
+ if HEAD_FIRST:
212
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
213
+ else:
214
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
215
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
216
+ if i_t > 0:
217
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
218
+
219
+ last_idx = min(i_t * BT + BT, T) - 1
220
+ # scalar decay
221
+ if USE_G:
222
+ if HEAD_FIRST:
223
+ b_g_last = tl.load(g + i_nh * T + last_idx)
224
+ else:
225
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
226
+ b_h *= exp(b_g_last)
227
+
228
+ # vector decay, h = Diag(gk) @ h
229
+ if USE_GK:
230
+ if HEAD_FIRST:
231
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
232
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
233
+ else:
234
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
235
+
236
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
237
+ b_h *= exp(b_gk_last)[:, None]
238
+
239
+ # vector decay, h = h @ Diag(gv)
240
+ if USE_GV:
241
+ if HEAD_FIRST:
242
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
243
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
244
+ else:
245
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
246
+
247
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
248
+ b_h *= exp(b_gv_last)[None, :]
249
+
250
+ if STORE_FINAL_STATE:
251
+ p_kvt = tl.make_block_ptr(kvt + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
252
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
253
+ b_h += tl.load(p_kvt, boundary_check=(0, 1)).to(tl.float32)
254
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
255
+
256
+
257
+ @triton.heuristics({
258
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
259
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
260
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
261
+ })
262
+ @triton.autotune(
263
+ configs=[
264
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
265
+ for BK in [32, 64, 128]
266
+ for BV in [32, 64, 128]
267
+ for num_warps in [2, 4, 8]
268
+ for num_stages in [2, 3, 4]
269
+ ],
270
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
271
+ )
272
+ @triton.jit(do_not_specialize=['T'])
273
+ def chunk_bwd_kernel_dh_parallel(
274
+ q,
275
+ g,
276
+ gk,
277
+ gv,
278
+ do,
279
+ dh,
280
+ dht,
281
+ dh0,
282
+ offsets,
283
+ indices,
284
+ scale,
285
+ T,
286
+ HQ: tl.constexpr,
287
+ H: tl.constexpr,
288
+ K: tl.constexpr,
289
+ V: tl.constexpr,
290
+ BT: tl.constexpr,
291
+ BK: tl.constexpr,
292
+ BV: tl.constexpr,
293
+ NG: tl.constexpr,
294
+ USE_G: tl.constexpr,
295
+ USE_GK: tl.constexpr,
296
+ USE_GV: tl.constexpr,
297
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
298
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
299
+ USE_OFFSETS: tl.constexpr,
300
+ HEAD_FIRST: tl.constexpr
301
+ ):
302
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
303
+
304
+ NV = tl.cdiv(V, BV)
305
+ i_k, i_v = i_kv // NV, i_kv % NV
306
+ i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG
307
+ i_h = i_hq // NG
308
+ if USE_OFFSETS:
309
+ i_tg = i_t
310
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
311
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
312
+ T = eos - bos
313
+ NT = tl.cdiv(T, BT)
314
+ else:
315
+ bos, eos = i_b * T, i_b * T + T
316
+ NT = tl.cdiv(T, BT)
317
+ i_n, i_tg = i_b, i_b * NT + i_t
318
+ i_nh = i_n * HQ + i_hq
319
+
320
+ if HEAD_FIRST:
321
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
322
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
323
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
324
+ else:
325
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
326
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
327
+ p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
328
+
329
+ if i_t == NT - 1:
330
+ if USE_FINAL_STATE_GRADIENT:
331
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
332
+ b_dh = tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
333
+ else:
334
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
335
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
336
+
337
+ # [BK, BT]
338
+ b_q = tl.load(p_q, boundary_check=(0, 1))
339
+ b_q = (b_q * scale).to(b_q.dtype)
340
+ # [BT, BV]
341
+ b_do = tl.load(p_do, boundary_check=(0, 1))
342
+
343
+ if USE_G:
344
+ if HEAD_FIRST:
345
+ p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
346
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
347
+ else:
348
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
349
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
350
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
351
+
352
+ if USE_GK:
353
+ if HEAD_FIRST:
354
+ p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
355
+ else:
356
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
357
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
358
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
359
+
360
+ if USE_GV:
361
+ if HEAD_FIRST:
362
+ p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
363
+ else:
364
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
365
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
366
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
367
+
368
+ b_dh = tl.dot(b_q, b_do)
369
+ if i_t > 0:
370
+ if HEAD_FIRST:
371
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t - 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
372
+ else:
373
+ p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
374
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
375
+ elif STORE_INITIAL_STATE_GRADIENT:
376
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
377
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
378
+
379
+
380
+ @triton.heuristics({
381
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
382
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
383
+ })
384
+ @triton.autotune(
385
+ configs=[
386
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
387
+ for BK in [32, 64, 128]
388
+ for BV in [32, 64, 128]
389
+ for num_warps in [2, 4, 8, 16]
390
+ for num_stages in [2, 3]
391
+ ],
392
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
393
+ )
394
+ @triton.jit(do_not_specialize=['T'])
395
+ def chunk_bwd_kernel_dh_reduction(
396
+ g,
397
+ gk,
398
+ gv,
399
+ dh,
400
+ doq0,
401
+ dh0,
402
+ offsets,
403
+ chunk_offsets,
404
+ T,
405
+ HQ: tl.constexpr,
406
+ H: tl.constexpr,
407
+ K: tl.constexpr,
408
+ V: tl.constexpr,
409
+ BT: tl.constexpr,
410
+ BK: tl.constexpr,
411
+ BV: tl.constexpr,
412
+ NG: tl.constexpr,
413
+ USE_G: tl.constexpr,
414
+ USE_GK: tl.constexpr,
415
+ USE_GV: tl.constexpr,
416
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
417
+ USE_OFFSETS: tl.constexpr,
418
+ HEAD_FIRST: tl.constexpr
419
+ ):
420
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
421
+ i_bg = i_nh // NG
422
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
423
+ i_h = i_hq // NG
424
+ if USE_OFFSETS:
425
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
426
+ T = eos - bos
427
+ NT = tl.cdiv(T, BT)
428
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
429
+ else:
430
+ bos, eos = i_n * T, i_n * T + T
431
+ NT = tl.cdiv(T, BT)
432
+ boh = i_n * NT
433
+
434
+ # [BK, BV]
435
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
436
+ for i_t in range(NT - 1, -1, -1):
437
+ if HEAD_FIRST:
438
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
439
+ else:
440
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
441
+ b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32)
442
+ if i_t < NT - 1:
443
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
444
+
445
+ last_idx = min(i_t * BT + BT, T) - 1
446
+ if USE_G:
447
+ if HEAD_FIRST:
448
+ b_g_last = tl.load(g + i_bg * T + last_idx)
449
+ else:
450
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
451
+ b_dh *= exp(b_g_last)
452
+
453
+ if USE_GK:
454
+ if HEAD_FIRST:
455
+ p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
456
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
457
+ else:
458
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
459
+
460
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
461
+ b_dh *= exp(b_gk_last)[:, None]
462
+
463
+ if USE_GV:
464
+ if HEAD_FIRST:
465
+ p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
466
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
467
+ else:
468
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
469
+
470
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
471
+ b_dh *= exp(b_gv_last)[None, :]
472
+
473
+ if STORE_INITIAL_STATE_GRADIENT:
474
+ p_doq0 = tl.make_block_ptr(doq0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
475
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
476
+ b_dh += tl.load(p_doq0, boundary_check=(0, 1)).to(tl.float32)
477
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
478
+
479
+
480
+ def chunk_fwd_h(
481
+ k: torch.Tensor,
482
+ v: torch.Tensor,
483
+ g: torch.Tensor,
484
+ gk: torch.Tensor,
485
+ gv: torch.Tensor,
486
+ h0: torch.Tensor,
487
+ output_final_state: bool,
488
+ states_in_fp32: bool = False,
489
+ offsets: Optional[torch.Tensor] = None,
490
+ indices: Optional[torch.Tensor] = None,
491
+ head_first: bool = True,
492
+ chunk_size: int = 64
493
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
494
+ if head_first:
495
+ B, H, T, K, V = *k.shape, v.shape[-1]
496
+ else:
497
+ B, T, H, K, V = *k.shape, v.shape[-1]
498
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
499
+ # N: the actual number of sequences in the batch with either equal or variable lengths
500
+ if offsets is None:
501
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
502
+ else:
503
+ if indices is None:
504
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
505
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
506
+ N, NT = len(offsets) - 1, len(indices)
507
+ chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
508
+
509
+ h = k.new_empty(B, H, NT, K, V, dtype=torch.float) if head_first else k.new_empty(B, NT, H, K, V, dtype=torch.float)
510
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
511
+ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H)
512
+ chunk_fwd_kernel_h_parallel[grid](
513
+ k=k,
514
+ v=v,
515
+ h=h,
516
+ g=g,
517
+ gk=gk,
518
+ gv=gv,
519
+ h0=h0,
520
+ ht=ht,
521
+ offsets=offsets,
522
+ indices=indices,
523
+ T=T,
524
+ H=H,
525
+ K=K,
526
+ V=V,
527
+ BT=BT,
528
+ USE_G=g is not None,
529
+ USE_GK=gk is not None,
530
+ USE_GV=gv is not None,
531
+ HEAD_FIRST=head_first
532
+ )
533
+ kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None)
534
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
535
+ chunk_fwd_kernel_h_reduction[grid](
536
+ h=h,
537
+ g=g,
538
+ gk=gk,
539
+ gv=gv,
540
+ kvt=kvt,
541
+ ht=ht,
542
+ offsets=offsets,
543
+ chunk_offsets=chunk_offsets,
544
+ T=T,
545
+ H=H,
546
+ K=K,
547
+ V=V,
548
+ BT=BT,
549
+ USE_G=g is not None,
550
+ USE_GK=gk is not None,
551
+ USE_GV=gv is not None,
552
+ HEAD_FIRST=head_first
553
+ )
554
+ h = h.to(k.dtype) if not states_in_fp32 else h
555
+ return h, ht
556
+
557
+
558
+ def chunk_bwd_dh(
559
+ q: torch.Tensor,
560
+ k: torch.Tensor,
561
+ v: torch.Tensor,
562
+ g: torch.Tensor,
563
+ gk: torch.Tensor,
564
+ gv: torch.Tensor,
565
+ do: torch.Tensor,
566
+ h0: torch.Tensor,
567
+ dht: torch.Tensor,
568
+ scale: float,
569
+ states_in_fp32: bool = False,
570
+ offsets: Optional[torch.Tensor] = None,
571
+ indices: Optional[torch.Tensor] = None,
572
+ head_first: bool = True,
573
+ chunk_size: int = 64
574
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
575
+ if head_first:
576
+ B, H, T, K, V = *k.shape, v.shape[-1]
577
+ HQ = q.shape[1]
578
+ else:
579
+ B, T, H, K, V = *k.shape, v.shape[-1]
580
+ HQ = q.shape[2]
581
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
582
+ # N: the actual number of sequences in the batch with either equal or variable lengths
583
+ # NG: number of groups in GQA
584
+ if offsets is None:
585
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
586
+ else:
587
+ if indices is None:
588
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
589
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
590
+ N, NT = len(offsets) - 1, len(indices)
591
+ chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
592
+ NG = HQ // H
593
+
594
+ if head_first:
595
+ dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
596
+ else:
597
+ dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
598
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
599
+
600
+ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ)
601
+ chunk_bwd_kernel_dh_parallel[grid](
602
+ q=q,
603
+ g=g,
604
+ gk=gk,
605
+ gv=gv,
606
+ do=do,
607
+ dh=dh,
608
+ dht=dht,
609
+ dh0=dh0,
610
+ offsets=offsets,
611
+ indices=indices,
612
+ scale=scale,
613
+ T=T,
614
+ HQ=HQ,
615
+ H=H,
616
+ K=K,
617
+ V=V,
618
+ BT=BT,
619
+ NG=NG,
620
+ USE_G=g is not None,
621
+ USE_GK=gk is not None,
622
+ USE_GV=gv is not None,
623
+ HEAD_FIRST=head_first
624
+ )
625
+
626
+ doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None)
627
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
628
+ chunk_bwd_kernel_dh_reduction[grid](
629
+ g=g,
630
+ gk=gk,
631
+ gv=gv,
632
+ dh=dh,
633
+ doq0=doq0,
634
+ dh0=dh0,
635
+ offsets=offsets,
636
+ chunk_offsets=chunk_offsets,
637
+ T=T,
638
+ HQ=HQ,
639
+ H=H,
640
+ K=K,
641
+ V=V,
642
+ BT=BT,
643
+ NG=NG,
644
+ USE_G=g is not None,
645
+ USE_GK=gk is not None,
646
+ USE_GV=gv is not None,
647
+ HEAD_FIRST=head_first
648
+ )
649
+ dh = dh.to(q.dtype) if not states_in_fp32 else dh
650
+ return dh, dh0
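The variable-length path above derives indices (one [sequence id, chunk id] row per chunk) and chunk_offsets (where each sequence's chunks start in the flattened chunk axis) from offsets when they are not supplied. A small worked sketch of that bookkeeping follows; the sequence lengths and BT value are made-up assumptions, mirroring the construction inside chunk_fwd_h / chunk_bwd_dh.

import torch
import triton

offsets = torch.tensor([0, 100, 356])                   # two sequences: lengths 100 and 256
BT = 64
n_chunks = triton.cdiv(offsets[1:] - offsets[:-1], BT)  # tensor([2, 4])
indices = torch.cat([torch.arange(n) for n in n_chunks.tolist()])
indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1)
# indices: [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
chunk_offsets = torch.cat([offsets.new_tensor([0]), n_chunks]).cumsum(-1)
# chunk_offsets: tensor([0, 2, 6]) -> sequence 0's chunks start at slot 0, sequence 1's at slot 2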
fla/ops/common/chunk_h_split.py ADDED
@@ -0,0 +1,677 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+
12
+
13
+ @triton.heuristics({
14
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
15
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
16
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
17
+ })
18
+ @triton.autotune(
19
+ configs=[
20
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
21
+ for BK in [32, 64]
22
+ for BV in [32, 64]
23
+ for num_warps in [2, 4, 8]
24
+ for num_stages in [2, 3]
25
+ ],
26
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def chunk_fwd_kernel_h_split(
30
+ k,
31
+ v,
32
+ g,
33
+ gk,
34
+ gv,
35
+ hs,
36
+ hr,
37
+ h0,
38
+ ht,
39
+ offsets,
40
+ split_indices,
41
+ T,
42
+ S: tl.constexpr,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ USE_G: tl.constexpr,
50
+ USE_GK: tl.constexpr,
51
+ USE_GV: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr
56
+ ):
57
+ # handle one split at a time
58
+ # i_h: head index
59
+ # i_n: sequence index
60
+ # i_s: local split index inside a sequence
61
+ i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
62
+ i_ss, i_h = i_sh // H, i_sh % H
63
+ if USE_OFFSETS:
64
+ i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32)
65
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
66
+ T = eos - bos
67
+ NS = tl.cdiv(T, S)
68
+ else:
69
+ NS = tl.cdiv(T, S)
70
+ i_n, i_s = i_ss // NS, i_ss % NS
71
+ bos, eos = i_n * T, i_n * T + T
72
+ i_nh = i_n * H + i_h
73
+
74
+ # [BK, BV]
75
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
76
+ # for the first split, we directly store the state as the final result
77
+ if i_s == 0:
78
+ if USE_INITIAL_STATE:
79
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ b_h += tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
81
+ p_hr = tl.make_block_ptr(hr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
82
+ tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1))
83
+ for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)):
84
+ if HEAD_FIRST:
85
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
86
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
87
+ else:
88
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
89
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
90
+ # [BK, BT]
91
+ b_k = tl.load(p_k, boundary_check=(0, 1))
92
+ # [BT, BV]
93
+ b_v = tl.load(p_v, boundary_check=(0, 1))
94
+ last_idx = min(i_t * BT + BT, T) - 1
95
+
96
+ # scalar decay
97
+ if USE_G:
98
+ if HEAD_FIRST:
99
+ b_g_last = tl.load(g + i_nh * T + last_idx)
100
+ p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
101
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
102
+ else:
103
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
104
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
105
+ b_h *= exp(b_g_last)
106
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
107
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
108
+
109
+ # vector decay, h = Diag(gk) @ h
110
+ if USE_GK:
111
+ if HEAD_FIRST:
112
+ p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
113
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
114
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
115
+ else:
116
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
117
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
118
+
119
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
120
+ b_h *= exp(b_gk_last)[:, None]
121
+
122
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
123
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
124
+
125
+ # vector decay, h = h @ Diag(gv)
126
+ if USE_GV:
127
+ if HEAD_FIRST:
128
+ p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
129
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
130
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
131
+ else:
132
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
133
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
134
+
135
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
136
+ b_h *= exp(b_gv_last)[None, :]
137
+
138
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
139
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
140
+
141
+ b_h += tl.dot(b_k, b_v)
142
+
143
+ # if there are more than one splits, we store the result to (unreduced) hs
144
+ # otherwise, we store the result to ht as the final state
145
+ if NS > 1:
146
+ p_hs = tl.make_block_ptr(hs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
147
+ tl.store(p_hs, b_h.to(p_hs.dtype.element_ty), boundary_check=(0, 1))
148
+ elif STORE_FINAL_STATE:
149
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
150
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
151
+
152
+
153
+ @triton.heuristics({
154
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
155
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
156
+ })
157
+ @triton.autotune(
158
+ configs=[
159
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
160
+ for BK in [32, 64]
161
+ for BV in [32, 64]
162
+ for num_warps in [2, 4, 8]
163
+ for num_stages in [2, 3, 4]
164
+ ],
165
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
166
+ )
167
+ @triton.jit(do_not_specialize=['T'])
168
+ def chunk_fwd_kernel_h_reduction(
169
+ g,
170
+ gk,
171
+ gv,
172
+ hs,
173
+ hr,
174
+ ht,
175
+ offsets,
176
+ split_offsets,
177
+ T,
178
+ S: tl.constexpr,
179
+ H: tl.constexpr,
180
+ K: tl.constexpr,
181
+ V: tl.constexpr,
182
+ BT: tl.constexpr,
183
+ BK: tl.constexpr,
184
+ BV: tl.constexpr,
185
+ USE_G: tl.constexpr,
186
+ USE_GK: tl.constexpr,
187
+ USE_GV: tl.constexpr,
188
+ STORE_FINAL_STATE: tl.constexpr,
189
+ USE_OFFSETS: tl.constexpr,
190
+ HEAD_FIRST: tl.constexpr
191
+ ):
192
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
193
+ i_n, i_h = i_nh // H, i_nh % H
194
+ if USE_OFFSETS:
195
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
196
+ T = eos - bos
197
+ NS = tl.cdiv(T, S)
198
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
199
+ else:
200
+ bos, eos = i_n * T, i_n * T + T
201
+ NS = tl.cdiv(T, S)
202
+ boh = i_n * NS
203
+
204
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
205
+ # skip the first split
206
+ for i_s in range(1, NS):
207
+ p_hs = tl.make_block_ptr(hs + ((boh + i_s-1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
208
+ p_hr = tl.make_block_ptr(hr + ((boh + i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
209
+ b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32)
210
+ tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1))
211
+
212
+ for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)):
213
+ last_idx = min(i_t * BT + BT, T) - 1
214
+ # scalar decay
215
+ if USE_G:
216
+ if HEAD_FIRST:
217
+ b_g_last = tl.load(g + i_nh * T + last_idx)
218
+ else:
219
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
220
+ b_h *= exp(b_g_last)
221
+
222
+ # vector decay, h = Diag(gk) @ h
223
+ if USE_GK:
224
+ if HEAD_FIRST:
225
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
226
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
227
+ else:
228
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
229
+
230
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
231
+ b_h *= exp(b_gk_last)[:, None]
232
+
233
+ # vector decay, h = h @ Diag(gv)
234
+ if USE_GV:
235
+ if HEAD_FIRST:
236
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
237
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
238
+ else:
239
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
240
+
241
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
242
+ b_h *= exp(b_gv_last)[None, :]
243
+
244
+ if NS > 1:
245
+ if STORE_FINAL_STATE:
246
+ p_hs = tl.make_block_ptr(hs + ((boh + NS-1) * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
247
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
248
+ b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32)
249
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
250
+
251
+
252
+ @triton.heuristics({
253
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
254
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
255
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
256
+ })
257
+ @triton.autotune(
258
+ configs=[
259
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
260
+ for BK in [32, 64]
261
+ for BV in [32, 64]
262
+ for num_warps in [2, 4, 8]
263
+ for num_stages in [2, 3]
264
+ ],
265
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
266
+ )
267
+ @triton.jit(do_not_specialize=['T'])
268
+ def chunk_bwd_kernel_dh_split(
269
+ q,
270
+ g,
271
+ gk,
272
+ gv,
273
+ do,
274
+ dht,
275
+ dhs,
276
+ dhr,
277
+ dh0,
278
+ offsets,
279
+ split_indices,
280
+ scale,
281
+ T,
282
+ S: tl.constexpr,
283
+ HQ: tl.constexpr,
284
+ H: tl.constexpr,
285
+ K: tl.constexpr,
286
+ V: tl.constexpr,
287
+ BT: tl.constexpr,
288
+ BK: tl.constexpr,
289
+ BV: tl.constexpr,
290
+ NG: tl.constexpr,
291
+ USE_G: tl.constexpr,
292
+ USE_GK: tl.constexpr,
293
+ USE_GV: tl.constexpr,
294
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
295
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
296
+ USE_OFFSETS: tl.constexpr,
297
+ HEAD_FIRST: tl.constexpr
298
+ ):
299
+ # handle one split at a time
300
+ # i_h: head index
301
+ # i_n: sequence index
302
+ # i_s: local split index inside a sequence
303
+ i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
304
+ i_ss, i_hq = i_sh // HQ, i_sh % HQ
305
+ if USE_OFFSETS:
306
+ i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32)
307
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
308
+ T = eos - bos
309
+ NS = tl.cdiv(T, S)
310
+ else:
311
+ NS = tl.cdiv(T, S)
312
+ i_n, i_s = i_ss // NS, i_ss % NS
313
+ bos, eos = i_n * T, i_n * T + T
314
+ i_nh = i_n * HQ + i_hq
315
+ i_ng, i_h = i_nh // NG, i_hq // NG
316
+
317
+ # [BK, BV]
318
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
319
+ if i_s == NS - 1:
320
+ if USE_FINAL_STATE_GRADIENT:
321
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
322
+ b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
323
+ p_dhr = tl.make_block_ptr(dhr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
324
+ tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1))
325
+
326
+ for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1):
327
+ if HEAD_FIRST:
328
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
329
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
330
+ else:
331
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
332
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
333
+
334
+ b_q = tl.load(p_q, boundary_check=(0, 1))
335
+ b_q = (b_q * scale).to(b_q.dtype)
336
+ # [BT, BV]
337
+ b_do = tl.load(p_do, boundary_check=(0, 1))
338
+
339
+ last_idx = min(i_t * BT + BT, T) - 1
340
+ if USE_G:
341
+ if HEAD_FIRST:
342
+ p_g = g + i_ng * T + i_t * BT + tl.arange(0, BT)
343
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
344
+ b_g_last = tl.load(g + i_ng * T + last_idx)
345
+ else:
346
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
347
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
348
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
349
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
350
+ b_dh *= exp(b_g_last)
351
+
352
+ if USE_GK:
353
+ if HEAD_FIRST:
354
+ p_gk = tl.make_block_ptr(gk + i_ng * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
355
+ p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
356
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
357
+ else:
358
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
359
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
360
+
361
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
362
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
363
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
364
+ b_dh *= exp(b_gk_last)[:, None]
365
+
366
+ if USE_GV:
367
+ if HEAD_FIRST:
368
+ p_gv = tl.make_block_ptr(gv + i_ng * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
369
+ p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
370
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
371
+ else:
372
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
373
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
374
+
375
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
376
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
377
+
378
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
379
+ b_dh *= exp(b_gv_last)[None, :]
380
+
381
+ b_dh += tl.dot(b_q, b_do)
382
+
383
+ if NS > 1:
384
+ p_dhs = tl.make_block_ptr(dhs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
385
+ tl.store(p_dhs, b_dh.to(p_dhs.dtype.element_ty), boundary_check=(0, 1))
386
+ elif STORE_INITIAL_STATE_GRADIENT:
387
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
388
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
389
+
390
+
391
+ @triton.heuristics({
392
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
393
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
394
+ })
395
+ @triton.autotune(
396
+ configs=[
397
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
398
+ for BK in [32, 64]
399
+ for BV in [32, 64]
400
+ for num_warps in [2, 4, 8]
401
+ for num_stages in [2, 3, 4]
402
+ ],
403
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV'],
404
+ )
405
+ @triton.jit(do_not_specialize=['T'])
406
+ def chunk_bwd_kernel_dh_reduction(
407
+ g,
408
+ gk,
409
+ gv,
410
+ dhs,
411
+ dhr,
412
+ dh0,
413
+ offsets,
414
+ split_offsets,
415
+ T,
416
+ S: tl.constexpr,
417
+ H: tl.constexpr,
418
+ HQ: tl.constexpr,
419
+ K: tl.constexpr,
420
+ V: tl.constexpr,
421
+ BT: tl.constexpr,
422
+ BK: tl.constexpr,
423
+ BV: tl.constexpr,
424
+ NG: tl.constexpr,
425
+ USE_G: tl.constexpr,
426
+ USE_GK: tl.constexpr,
427
+ USE_GV: tl.constexpr,
428
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
429
+ USE_OFFSETS: tl.constexpr,
430
+ HEAD_FIRST: tl.constexpr
431
+ ):
432
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
433
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
434
+ i_ng, i_h = i_nh // NG, i_hq // NG
435
+ if USE_OFFSETS:
436
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
437
+ T = eos - bos
438
+ NS = tl.cdiv(T, S)
439
+ boh = tl.load(split_offsets + i_n).to(tl.int32)
440
+ else:
441
+ bos, eos = i_n * T, i_n * T + T
442
+ NS = tl.cdiv(T, S)
443
+ boh = i_n * NS
444
+
445
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
446
+ for i_s in range(NS - 2, -1, -1):
447
+ p_dhs = tl.make_block_ptr(dhs + ((boh+i_s+1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
448
+ p_dhr = tl.make_block_ptr(dhr + ((boh+i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
449
+ b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32)
450
+ tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1))
451
+
452
+ for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1):
453
+ last_idx = min(i_t * BT + BT, T) - 1
454
+ # scalar decay
455
+ if USE_G:
456
+ if HEAD_FIRST:
457
+ b_g_last = tl.load(g + i_ng * T + last_idx)
458
+ else:
459
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
460
+ b_dh *= exp(b_g_last)
461
+
462
+ if USE_GK:
463
+ if HEAD_FIRST:
464
+ p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
465
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
466
+ else:
467
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
468
+
469
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
470
+ b_dh *= exp(b_gk_last)[:, None]
471
+
472
+ if USE_GV:
473
+ if HEAD_FIRST:
474
+ p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
475
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
476
+ else:
477
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
478
+
479
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
480
+ b_dh *= exp(b_gv_last)[None, :]
481
+
482
+ if NS > 1:
483
+ if STORE_INITIAL_STATE_GRADIENT:
484
+ p_dhs = tl.make_block_ptr(dhs + (boh * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
485
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
486
+ b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32)
487
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
488
+
489
+
490
+ def chunk_fwd_h(
491
+ k: torch.Tensor,
492
+ v: torch.Tensor,
493
+ g: torch.Tensor,
494
+ gk: torch.Tensor,
495
+ gv: torch.Tensor,
496
+ h0: torch.Tensor,
497
+ output_final_state: bool,
498
+ offsets: Optional[torch.LongTensor] = None,
499
+ split_offsets: Optional[torch.LongTensor] = None,
500
+ split_indices: Optional[torch.LongTensor] = None,
501
+ head_first: bool = True,
502
+ chunk_size: int = 64,
503
+ split_size: int = 256,
504
+ states_in_fp32: bool = True
505
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
506
+ if head_first:
507
+ B, H, T, K, V = *k.shape, v.shape[-1]
508
+ else:
509
+ B, T, H, K, V = *k.shape, v.shape[-1]
510
+ # B: batch size
511
+ # N: the actual number of sequences in the batch
512
+ # H: number of heads
513
+ # T: sequence length, can be variable across sequences
514
+ # S: split size, a multiple of chunk size
515
+ # BT: chunk size
516
+ S, BT = split_size, chunk_size
517
+ assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}"
518
+ if offsets is None:
519
+ N = B
520
+ NS = N * triton.cdiv(T, S)
521
+ else:
522
+ N = len(offsets) - 1
523
+ NS = split_offsets[-1]
524
+
525
+ # unreduced kv states per split
526
+ hs = k.new_empty(NS, H, K, V, dtype=torch.float)
527
+ # reduced states per split
528
+ hr = k.new_empty(NS, H, K, V, dtype=torch.float if states_in_fp32 else k.dtype)
529
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
530
+ # parallelized over splits
531
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * H)
532
+ chunk_fwd_kernel_h_split[grid](
533
+ k=k,
534
+ v=v,
535
+ g=g,
536
+ gk=gk,
537
+ gv=gv,
538
+ hs=hs,
539
+ hr=hr,
540
+ h0=h0,
541
+ ht=ht,
542
+ offsets=offsets,
543
+ split_indices=split_indices,
544
+ T=T,
545
+ S=S,
546
+ H=H,
547
+ K=K,
548
+ V=V,
549
+ BT=BT,
550
+ USE_G=g is not None,
551
+ USE_GK=gk is not None,
552
+ USE_GV=gv is not None,
553
+ HEAD_FIRST=head_first
554
+ )
555
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
556
+ chunk_fwd_kernel_h_reduction[grid](
557
+ g=g,
558
+ gk=gk,
559
+ gv=gv,
560
+ hs=hs,
561
+ hr=hr,
562
+ ht=ht,
563
+ offsets=offsets,
564
+ split_offsets=split_offsets,
565
+ T=T,
566
+ S=S,
567
+ H=H,
568
+ K=K,
569
+ V=V,
570
+ BT=BT,
571
+ USE_G=g is not None,
572
+ USE_GK=gk is not None,
573
+ USE_GV=gv is not None,
574
+ HEAD_FIRST=head_first
575
+ )
576
+ return hr, ht
577
+
578
+
579
+ def chunk_bwd_dh(
580
+ q: torch.Tensor,
581
+ k: torch.Tensor,
582
+ v: torch.Tensor,
583
+ g: torch.Tensor,
584
+ gk: torch.Tensor,
585
+ gv: torch.Tensor,
586
+ do: torch.Tensor,
587
+ h0: torch.Tensor,
588
+ dht: torch.Tensor,
589
+ scale: float,
590
+ offsets: Optional[torch.Tensor] = None,
591
+ split_offsets: Optional[torch.Tensor] = None,
592
+ split_indices: Optional[torch.Tensor] = None,
593
+ head_first: bool = True,
594
+ chunk_size: int = 64,
595
+ split_size: int = 256,
596
+ states_in_fp32: bool = True
597
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
598
+ if head_first:
599
+ B, H, T, K, V = *k.shape, v.shape[-1]
600
+ HQ = q.shape[1]
601
+ else:
602
+ B, T, H, K, V = *k.shape, v.shape[-1]
603
+ HQ = q.shape[2]
604
+ # B: batch size
605
+ # N: the actual number of sequences in the batch
606
+ # H: number of heads
607
+ # T: sequence length, can be variable across sequences
608
+ # S: split size, a multiple of chunk size
609
+ # BT: chunk size
610
+ S, BT = max(chunk_size, min(split_size, triton.next_power_of_2(T))), chunk_size
611
+ assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}"
612
+ if offsets is None:
613
+ N = B
614
+ NS = N * triton.cdiv(T, S)
615
+ else:
616
+ N = len(offsets) - 1
617
+ NS = split_offsets[-1]
618
+ # number of groups in GQA
619
+ NG = HQ // H
620
+
621
+ dhs = q.new_empty(NS, HQ, K, V, dtype=torch.float)
622
+ dhr = q.new_empty(NS, HQ, K, V, dtype=torch.float if states_in_fp32 else k.dtype)
623
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
624
+
625
+ # parallelized over splits
626
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * HQ)
627
+ chunk_bwd_kernel_dh_split[grid](
628
+ q=q,
629
+ g=g,
630
+ gk=gk,
631
+ gv=gv,
632
+ do=do,
633
+ dht=dht,
634
+ dhs=dhs,
635
+ dhr=dhr,
636
+ dh0=dh0,
637
+ offsets=offsets,
638
+ split_indices=split_indices,
639
+ scale=scale,
640
+ T=T,
641
+ S=S,
642
+ HQ=HQ,
643
+ H=H,
644
+ K=K,
645
+ V=V,
646
+ BT=BT,
647
+ NG=NG,
648
+ USE_G=g is not None,
649
+ USE_GK=gk is not None,
650
+ USE_GV=gv is not None,
651
+ HEAD_FIRST=head_first,
652
+ )
653
+
654
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
655
+ chunk_bwd_kernel_dh_reduction[grid](
656
+ g=g,
657
+ gk=gk,
658
+ gv=gv,
659
+ dhs=dhs,
660
+ dhr=dhr,
661
+ dh0=dh0,
662
+ offsets=offsets,
663
+ split_offsets=split_offsets,
664
+ T=T,
665
+ S=S,
666
+ HQ=HQ,
667
+ H=H,
668
+ K=K,
669
+ V=V,
670
+ BT=BT,
671
+ NG=NG,
672
+ USE_G=g is not None,
673
+ USE_GK=gk is not None,
674
+ USE_GV=gv is not None,
675
+ HEAD_FIRST=head_first
676
+ )
677
+ return dhr, dh0
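To make the two-pass structure above easier to follow, here is a gate-free sketch of what `chunk_bwd_kernel_dh_split` and `chunk_bwd_kernel_dh_reduction` compute together for a single (sequence, head) pair: the split kernel emits one partial state gradient per split (`dhs`), and the reduction kernel accumulates them from the last split backwards so that each `dhr[s]` holds the gradient entering split `s`. The helper below is a simplified reference and not part of the library: it drops the per-chunk decay factors (`g`/`gk`/`gv`) and keeps `dht` as a separate input rather than folding it into the last split as the kernels do.

import torch

def reduce_split_state_grads(dhs, dht=None):
    # dhs: [NS, K, V] within-split partial state gradients for one (sequence, head) pair
    # dht: [K, V] optional gradient w.r.t. the final state
    NS = dhs.shape[0]
    dhr = torch.zeros_like(dhs)
    carry = dht.clone() if dht is not None else torch.zeros_like(dhs[0])
    for s in range(NS - 1, -1, -1):
        dhr[s] = carry            # state gradient flowing into split s from everything after it
        carry = carry + dhs[s]    # the real kernels also apply per-chunk decays while accumulating
    return dhr, carry             # carry now plays the role of the initial-state gradient dh0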
fla/ops/common/fused_recurrent.py ADDED
@@ -0,0 +1,575 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import chunk_global_cumsum
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps)
23
+ for num_warps in [1, 2, 4]
24
+ ],
25
+ key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"],
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ g,
33
+ gk,
34
+ gv,
35
+ o,
36
+ h0,
37
+ ht,
38
+ offsets,
39
+ scale,
40
+ T,
41
+ B: tl.constexpr,
42
+ H: tl.constexpr,
43
+ K: tl.constexpr,
44
+ V: tl.constexpr,
45
+ BK: tl.constexpr,
46
+ BV: tl.constexpr,
47
+ REVERSE: tl.constexpr,
48
+ USE_G: tl.constexpr,
49
+ USE_GK: tl.constexpr,
50
+ USE_GV: tl.constexpr,
51
+ USE_INITIAL_STATE: tl.constexpr,
52
+ STORE_FINAL_STATE: tl.constexpr,
53
+ USE_OFFSETS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr
55
+ ):
56
+ # indices
57
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
61
+ all = T
62
+ T = eos - bos
63
+ else:
64
+ bos, eos = i_n * T, i_n * T + T
65
+ all = B * T
66
+
67
+ if HEAD_FIRST:
68
+ p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
69
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
70
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
71
+ p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
72
+ if USE_G:
73
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
74
+ if USE_GK:
75
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
76
+ if USE_GV:
77
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
78
+ else:
79
+ p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
80
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
81
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
82
+ p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
83
+ if USE_G:
84
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
85
+ if USE_GK:
86
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
87
+ if USE_GV:
88
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
89
+
90
+ mask_k = (i_k * BK + tl.arange(0, BK)) < K
91
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
92
+ mask_h = mask_k[None, :] & mask_v[:, None]
93
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
94
+
95
+ if USE_INITIAL_STATE:
96
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
97
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
98
+
99
+ for _ in range(0, T):
100
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
101
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
102
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
103
+ if USE_GK:
104
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
105
+ b_h = b_h * exp(b_gk[None, :])
106
+ if USE_GV:
107
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
108
+ b_h = b_h * exp(b_gv[:, None])
109
+ if USE_G:
110
+ b_g = tl.load(p_g).to(tl.float32)
111
+ b_h = b_h * exp(b_g)
112
+ b_h += b_k[None, :] * b_v[:, None]
113
+ b_o = b_h * b_q[None, :]
114
+ b_o = tl.sum(b_o, axis=1)
115
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
116
+ p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
117
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
118
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
119
+ p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
120
+ if USE_GK:
121
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
122
+ if USE_GV:
123
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
124
+ if USE_G:
125
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
126
+
127
+ if STORE_FINAL_STATE:
128
+ p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
129
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
130
+
131
+
132
+ @triton.heuristics({
133
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
134
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
135
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
136
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
137
+ })
138
+ @triton.autotune(
139
+ configs=[
140
+ triton.Config({}, num_warps=num_warps)
141
+ for num_warps in [1, 2, 4]
142
+ ],
143
+ key=['BK', 'BV', 'USE_GK', 'USE_GV', 'USE_G'],
144
+ )
145
+ @triton.jit(do_not_specialize=['T'])
146
+ def fused_recurrent_bwd_kernel(
147
+ q,
148
+ k,
149
+ v,
150
+ g,
151
+ gk,
152
+ gv,
153
+ h0,
154
+ do,
155
+ dq,
156
+ dk,
157
+ dv,
158
+ dht,
159
+ dh0,
160
+ offsets,
161
+ scale,
162
+ T,
163
+ B: tl.constexpr,
164
+ H: tl.constexpr,
165
+ K: tl.constexpr,
166
+ V: tl.constexpr,
167
+ BK: tl.constexpr,
168
+ BV: tl.constexpr,
169
+ REVERSE: tl.constexpr,
170
+ USE_G: tl.constexpr,
171
+ USE_GK: tl.constexpr,
172
+ USE_GV: tl.constexpr,
173
+ USE_INITIAL_STATE: tl.constexpr,
174
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
175
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
176
+ USE_OFFSETS: tl.constexpr,
177
+ HEAD_FIRST: tl.constexpr
178
+ ):
179
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
180
+ i_n, i_h = i_nh // H, i_nh % H
181
+ if USE_OFFSETS:
182
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
183
+ all = T
184
+ T = eos - bos
185
+ else:
186
+ bos, eos = i_n * T, i_n * T + T
187
+ all = B * T
188
+
189
+ if HEAD_FIRST:
190
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
191
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
192
+ p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
193
+ p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
194
+ if USE_G:
195
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
196
+ if USE_GK:
197
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
198
+ if USE_GV:
199
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
200
+ else:
201
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
202
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
203
+ p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
204
+ p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
205
+ if USE_G:
206
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
207
+ if USE_GK:
208
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
209
+ if USE_GV:
210
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
211
+
212
+ mask_k = i_k * BK + tl.arange(0, BK) < K
213
+ mask_v = i_v * BV + tl.arange(0, BV) < V
214
+ mask_h = mask_k[:, None] & mask_v[None, :]
215
+
216
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
217
+ if USE_INITIAL_STATE:
218
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
219
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
220
+
221
+ for _ in range(0, T):
222
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
223
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
224
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
225
+ if USE_G:
226
+ b_g = tl.load(p_g).to(tl.float32)
227
+ b_h = b_h * exp(b_g)
228
+ if USE_GK:
229
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
230
+ b_h = b_h * exp(b_gk[:, None])
231
+ if USE_GV:
232
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
233
+ b_h = b_h * exp(b_gv[None, :])
234
+ b_h += b_k[:, None] * b_v[None, :]
235
+ b_dq = b_h * b_do[None, :]
236
+ b_dq = tl.sum(b_dq, axis=1) * scale
237
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)
238
+
239
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
240
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
241
+ p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
242
+ p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
243
+ if USE_G:
244
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
245
+ if USE_GK:
246
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
247
+ if USE_GV:
248
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
249
+
250
+ # sync threads
251
+ tl.debug_barrier()
252
+
253
+ if HEAD_FIRST:
254
+ p_q = q + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
255
+ p_k = k + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
256
+ p_v = v + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
257
+ p_do = do + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
258
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
259
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
260
+ if USE_G:
261
+ p_g = g + i_nh * T + ((T - 1) if not REVERSE else 0)
262
+ if USE_GK:
263
+ p_gk = gk + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
264
+ if USE_GV:
265
+ p_gv = gv + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
266
+ else:
267
+ p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
268
+ p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
269
+ p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
270
+ p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
271
+ p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
272
+ p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
273
+ if USE_G:
274
+ p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h
275
+ if USE_GK:
276
+ p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
277
+ if USE_GV:
278
+ p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
279
+
280
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
281
+ if USE_FINAL_STATE_GRADIENT:
282
+ p_dht = dht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
283
+ b_dh += tl.load(p_dht, mask=mask_h, other=0).to(tl.float32)
284
+
285
+ for _ in range(T):
286
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
287
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
288
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
289
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
290
+ b_dh += b_q[:, None] * b_do[None, :]
291
+ b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
292
+ b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
293
+ if USE_G:
294
+ b_g = tl.load(p_g).to(tl.float32)
295
+ b_dh *= exp(b_g)
296
+ if USE_GK:
297
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
298
+ b_dh *= exp(b_gk)[:, None]
299
+ if USE_GV:
300
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
301
+ b_dh *= exp(b_gv)[None, :]
302
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
303
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
304
+
305
+ p_q += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
306
+ p_k += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
307
+ p_v += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
308
+ p_do += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
309
+ p_dk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
310
+ p_dv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
311
+ if USE_G:
312
+ p_g += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H)
313
+ if USE_GK:
314
+ p_gk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
315
+ if USE_GV:
316
+ p_gv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
317
+
318
+ if STORE_INITIAL_STATE_GRADIENT:
319
+ p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
320
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h)
321
+
322
+
323
+ def fused_recurrent_fwd(
324
+ q: torch.Tensor,
325
+ k: torch.Tensor,
326
+ v: torch.Tensor,
327
+ g: Optional[torch.Tensor] = None,
328
+ gk: Optional[torch.Tensor] = None,
329
+ gv: Optional[torch.Tensor] = None,
330
+ scale: Optional[float] = None,
331
+ initial_state: Optional[torch.Tensor] = None,
332
+ output_final_state: bool = False,
333
+ reverse: bool = False,
334
+ offsets: Optional[torch.LongTensor] = None,
335
+ head_first: bool = True
336
+ ):
337
+ if head_first:
338
+ B, H, T, K, V = *k.shape, v.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *k.shape, v.shape[-1]
341
+ N = B if offsets is None else len(offsets) - 1
342
+ BK, BV = min(K, 64), min(V, 64)
343
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
344
+
345
+ h0 = initial_state
346
+ if output_final_state:
347
+ ht = q.new_empty(N, H, K, V, dtype=torch.float32)
348
+ else:
349
+ ht = None
350
+ o = q.new_empty(NK, *v.shape, dtype=torch.float32)
351
+
352
+ grid = (NV, NK, N * H)
353
+ fused_recurrent_fwd_kernel[grid](
354
+ q,
355
+ k,
356
+ v,
357
+ g,
358
+ gk,
359
+ gv,
360
+ o,
361
+ h0,
362
+ ht,
363
+ offsets,
364
+ scale,
365
+ T=T,
366
+ B=B,
367
+ H=H,
368
+ K=K,
369
+ V=V,
370
+ BK=BK,
371
+ BV=BV,
372
+ USE_G=g is not None,
373
+ USE_GK=gk is not None,
374
+ USE_GV=gv is not None,
375
+ REVERSE=reverse,
376
+ HEAD_FIRST=head_first
377
+ )
378
+ o = o.sum(0)
379
+ return o, ht
380
+
381
+
382
+ def fused_recurrent_bwd(
383
+ q: torch.Tensor,
384
+ k: torch.Tensor,
385
+ v: torch.Tensor,
386
+ g: Optional[torch.Tensor] = None,
387
+ gk: Optional[torch.Tensor] = None,
388
+ gv: Optional[torch.Tensor] = None,
389
+ o: Optional[torch.Tensor] = None,
390
+ do: Optional[torch.Tensor] = None,
391
+ dht: Optional[torch.Tensor] = None,
392
+ scale: Optional[float] = None,
393
+ initial_state: Optional[torch.Tensor] = None,
394
+ reverse: bool = False,
395
+ offsets: Optional[torch.LongTensor] = None,
396
+ head_first: bool = True
397
+ ):
398
+ if head_first:
399
+ B, H, T, K, V = *k.shape, v.shape[-1]
400
+ else:
401
+ B, T, H, K, V = *k.shape, v.shape[-1]
402
+ N = B if offsets is None else len(offsets) - 1
403
+
404
+ BK, BV = min(K, 64), min(V, 64)
405
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
406
+
407
+ dq = q.new_empty(NV, *q.shape, dtype=torch.float32)
408
+ dk = q.new_empty(NV, *k.shape, dtype=torch.float32)
409
+ dv = q.new_empty(NK, *v.shape, dtype=torch.float32)
410
+ h0 = initial_state
411
+ dh0 = torch.empty_like(initial_state) if initial_state is not None else None
412
+
413
+ grid = (NV, NK, N * H)
414
+ fused_recurrent_bwd_kernel[grid](
415
+ q,
416
+ k,
417
+ v,
418
+ g,
419
+ gk,
420
+ gv,
421
+ h0,
422
+ do,
423
+ dq,
424
+ dk,
425
+ dv,
426
+ dht,
427
+ dh0,
428
+ offsets,
429
+ scale,
430
+ B=B,
431
+ T=T,
432
+ H=H,
433
+ K=K,
434
+ V=V,
435
+ BK=BK,
436
+ BV=BV,
437
+ USE_G=g is not None,
438
+ USE_GK=gk is not None,
439
+ USE_GV=gv is not None,
440
+ REVERSE=reverse,
441
+ HEAD_FIRST=head_first
442
+ )
443
+ dq = dq.sum(0)
444
+ dk = dk.sum(0)
445
+ dv = dv.sum(0)
446
+ dg, dgk, dgv = None, None, None
447
+ if g is not None:
448
+ dg = chunk_global_cumsum(
449
+ (dq * q.float() - dk * k.float()).sum(-1),
450
+ reverse=not reverse,
451
+ offsets=offsets,
452
+ head_first=head_first
453
+ )
454
+ if gk is not None:
455
+ dgk = chunk_global_cumsum(
456
+ dq * q.float() - dk * k.float(),
457
+ reverse=not reverse,
458
+ offsets=offsets,
459
+ head_first=head_first
460
+ )
461
+ if gv is not None:
462
+ dgv = chunk_global_cumsum(
463
+ do.float() * o.float() - dv * v.float(),
464
+ reverse=not reverse,
465
+ offsets=offsets,
466
+ head_first=head_first
467
+ )
468
+
469
+ return dq, dk, dv, dg, dgk, dgv, dh0
470
+
471
+
472
+ class FusedRecurrentFunction(torch.autograd.Function):
473
+
474
+ @staticmethod
475
+ @input_guard
476
+ @autocast_custom_fwd
477
+ def forward(
478
+ ctx,
479
+ q: torch.Tensor,
480
+ k: torch.Tensor,
481
+ v: torch.Tensor,
482
+ g: Optional[torch.Tensor] = None,
483
+ gk: Optional[torch.Tensor] = None,
484
+ gv: Optional[torch.Tensor] = None,
485
+ scale: Optional[float] = None,
486
+ initial_state: Optional[torch.Tensor] = None,
487
+ output_final_state: bool = False,
488
+ reverse: bool = False,
489
+ offsets: Optional[torch.LongTensor] = None,
490
+ head_first: bool = True
491
+ ):
492
+ o, ht = fused_recurrent_fwd(
493
+ q=q,
494
+ k=k,
495
+ v=v,
496
+ g=g,
497
+ gk=gk,
498
+ gv=gv,
499
+ scale=scale,
500
+ initial_state=initial_state,
501
+ output_final_state=output_final_state,
502
+ reverse=reverse,
503
+ offsets=offsets,
504
+ head_first=head_first
505
+ )
506
+ ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o)
507
+ ctx.scale = scale
508
+ ctx.reverse = reverse
509
+ ctx.offsets = offsets
510
+ ctx.head_first = head_first
511
+ return o.to(q.dtype), ht
512
+
513
+ @staticmethod
514
+ @input_guard
515
+ @autocast_custom_bwd
516
+ def backward(ctx, do, dht):
517
+ q, k, v, g, gk, gv, initial_state, o = ctx.saved_tensors
518
+ # gradients w.r.t. the gates are not supported yet when a nonzero final-state gradient is given.
519
+ if dht is not None:
520
+ if not dht.eq(0).all():
521
+ if g is not None:
522
+ assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
523
+ if gk is not None:
524
+ assert gk.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
525
+ if gv is not None:
526
+ assert gv.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
527
+ dq, dk, dv, dg, dgk, dgv, dh0 = fused_recurrent_bwd(
528
+ q=q,
529
+ k=k,
530
+ v=v,
531
+ g=g,
532
+ gk=gk,
533
+ gv=gv,
534
+ o=o,
535
+ do=do,
536
+ dht=dht,
537
+ scale=ctx.scale,
538
+ initial_state=initial_state,
539
+ reverse=ctx.reverse,
540
+ offsets=ctx.offsets,
541
+ head_first=ctx.head_first
542
+ )
543
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None, None
544
+
545
+
546
+ def fused_recurrent(
547
+ q: torch.Tensor,
548
+ k: torch.Tensor,
549
+ v: torch.Tensor,
550
+ g: Optional[torch.Tensor] = None,
551
+ gk: Optional[torch.Tensor] = None,
552
+ gv: Optional[torch.Tensor] = None,
553
+ scale: Optional[float] = None,
554
+ initial_state: Optional[torch.Tensor] = None,
555
+ output_final_state: bool = False,
556
+ reverse: bool = False,
557
+ cu_seqlens: Optional[torch.LongTensor] = None,
558
+ head_first: bool = True
559
+ ):
560
+ if scale is None:
561
+ scale = k.shape[-1] ** -0.5
562
+ return FusedRecurrentFunction.apply(
563
+ q,
564
+ k,
565
+ v,
566
+ g,
567
+ gk,
568
+ gv,
569
+ scale,
570
+ initial_state,
571
+ output_final_state,
572
+ reverse,
573
+ cu_seqlens,
574
+ head_first
575
+ )
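As a quick orientation for the host-level `fused_recurrent` entry point above, here is a minimal usage sketch. It assumes the default `head_first=True` layout, i.e. `[B, H, T, K]`/`[B, H, T, V]` tensors as read off `fused_recurrent_fwd`, and leaves all three gates (`g`, `gk`, `gv`) as `None`, in which case the kernels run the plain, ungated linear-attention recurrence; shapes and dtypes are illustrative only.

import torch
from fla.ops.common.fused_recurrent import fused_recurrent

B, H, T, K, V = 2, 4, 128, 64, 64
q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16, requires_grad=True)
h0 = torch.zeros(B, H, K, V, device='cuda', dtype=torch.float)

# scale defaults to K ** -0.5 when left as None
o, ht = fused_recurrent(q, k, v, initial_state=h0, output_final_state=True, head_first=True)
o.sum().backward()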
fla/ops/common/utils.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from fla.utils import tensor_cache
9
+
10
+
11
+ @triton.autotune(
12
+ configs=[
13
+ triton.Config({}, num_warps=num_warps)
14
+ for num_warps in [4, 8, 16, 32]
15
+ ],
16
+ key=['B'],
17
+ )
18
+ @triton.jit
19
+ def prepare_position_ids_kernel(
20
+ y,
21
+ offsets,
22
+ B: tl.constexpr
23
+ ):
24
+ i_n = tl.program_id(0)
25
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
26
+ T = eos - bos
27
+
28
+ o = tl.arange(0, B)
29
+ for i in range(0, tl.cdiv(T, B) * B, B):
30
+ o_i = o + i
31
+ tl.store(y + bos + o_i, o_i, o_i < T)
32
+
33
+
34
+ @tensor_cache
35
+ def prepare_lens(offsets: torch.LongTensor) -> torch.LongTensor:
36
+ return offsets[1:] - offsets[:-1]
37
+
38
+
39
+ @tensor_cache
40
+ def prepare_position_ids(offsets: torch.LongTensor) -> torch.LongTensor:
41
+ return torch.cat([torch.arange(n, dtype=offsets.dtype, device=offsets.device) for n in prepare_lens(offsets).unbind()])
42
+
43
+
44
+ @tensor_cache
45
+ def prepare_sequence_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
46
+ return position_ids.eq(0).cumsum(0) - 1
47
+
48
+
49
+ @tensor_cache
50
+ def prepare_token_indices(offsets: torch.LongTensor) -> torch.LongTensor:
51
+ position_ids = prepare_position_ids(offsets)
52
+ return torch.stack([prepare_sequence_ids(position_ids), position_ids], 1).to(offsets)
53
+
54
+
55
+ @tensor_cache
56
+ def prepare_chunk_indices(
57
+ offsets: torch.LongTensor,
58
+ chunk_size: int
59
+ ) -> torch.LongTensor:
60
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(offsets), chunk_size).tolist()])
61
+ return torch.stack([prepare_sequence_ids(indices), indices], 1).to(offsets)
62
+
63
+
64
+ @tensor_cache
65
+ def prepare_chunk_offsets(
66
+ offsets: torch.LongTensor,
67
+ chunk_size: int
68
+ ) -> torch.LongTensor:
69
+ return torch.cat([offsets.new_tensor([0]), triton.cdiv(prepare_lens(offsets), chunk_size)]).cumsum(-1)
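For concreteness, the expected outputs of the helpers above for a small `offsets` tensor, worked out by hand from the definitions (two sequences of lengths 100 and 256, chunk size 64):

import torch
from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets, prepare_lens

offsets = torch.tensor([0, 100, 356])
prepare_lens(offsets)               # tensor([100, 256])
prepare_chunk_offsets(offsets, 64)  # tensor([0, 2, 6]): 2 chunks for the 1st sequence, 4 for the 2nd
prepare_chunk_indices(offsets, 64)  # tensor([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]])
                                    # i.e. (sequence index, chunk index within that sequence) pairs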
fla/ops/delta_rule/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_delta_rule
4
+ from .fused_chunk import fused_chunk_delta_rule
5
+ from .fused_recurrent import fused_recurrent_delta_rule
6
+
7
+ __all__ = [
8
+ 'fused_chunk_delta_rule',
9
+ 'fused_recurrent_delta_rule',
10
+ 'chunk_delta_rule'
11
+ ]
fla/ops/delta_rule/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (13.7 kB).
fla/ops/delta_rule/__pycache__/fused_chunk.cpython-311.pyc ADDED
Binary file (430 Bytes).
fla/ops/delta_rule/__pycache__/fused_recurrent.cpython-311.pyc ADDED
Binary file (34.6 kB).
fla/ops/delta_rule/__pycache__/wy_fast.cpython-311.pyc ADDED
Binary file (20.8 kB).
fla/ops/delta_rule/chunk.py ADDED
@@ -0,0 +1,373 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ from einops import rearrange
9
+
10
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
11
+ from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
12
+ from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
13
+ from fla.ops.common.utils import prepare_chunk_indices
14
+ from fla.ops.delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u
15
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
16
+
17
+
18
+ def chunk_delta_rule_fwd(
19
+ q: torch.Tensor,
20
+ k: torch.Tensor,
21
+ v: torch.Tensor,
22
+ beta: torch.Tensor,
23
+ scale: float,
24
+ initial_state: torch.Tensor,
25
+ output_final_state: bool,
26
+ offsets: Optional[torch.LongTensor] = None,
27
+ indices: Optional[torch.LongTensor] = None,
28
+ head_first: bool = True,
29
+ chunk_size: int = 64
30
+ ):
31
+ T = q.shape[2] if head_first else q.shape[1]
32
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
33
+ # obtain WY representation. u is actually the new v.
34
+ w, u, A = fwd_prepare_wy_repr(
35
+ k=k,
36
+ v=v,
37
+ beta=beta,
38
+ offsets=offsets,
39
+ indices=indices,
40
+ head_first=head_first,
41
+ chunk_size=BT
42
+ )
43
+
44
+ h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
45
+ k=k,
46
+ w=w,
47
+ u=u,
48
+ g=None,
49
+ initial_state=initial_state,
50
+ output_final_state=output_final_state,
51
+ offsets=offsets,
52
+ indices=indices,
53
+ head_first=head_first,
54
+ chunk_size=BT
55
+ )
56
+ o = chunk_fwd_o(
57
+ q=q,
58
+ k=k,
59
+ v=v_new,
60
+ h=h,
61
+ g=None,
62
+ scale=scale,
63
+ offsets=offsets,
64
+ indices=indices,
65
+ head_first=head_first,
66
+ chunk_size=BT
67
+ )
68
+ return o, A, final_state
69
+
70
+
71
+ def chunk_delta_rule_bwd(
72
+ q: torch.Tensor,
73
+ k: torch.Tensor,
74
+ v: torch.Tensor,
75
+ beta: torch.Tensor,
76
+ A: torch.Tensor,
77
+ scale: float,
78
+ initial_state: torch.Tensor,
79
+ do: torch.Tensor,
80
+ dht: torch.Tensor,
81
+ offsets: Optional[torch.LongTensor] = None,
82
+ indices: Optional[torch.LongTensor] = None,
83
+ head_first: bool = True,
84
+ chunk_size: int = 64
85
+ ):
86
+ T = q.shape[2] if head_first else q.shape[1]
87
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
88
+ w, u = fwd_recompute_w_u(
89
+ k=k,
90
+ v=v,
91
+ beta=beta,
92
+ A=A,
93
+ offsets=offsets,
94
+ indices=indices,
95
+ head_first=head_first,
96
+ chunk_size=BT
97
+ )
98
+ h, v_new, _ = chunk_gated_delta_rule_fwd_h(
99
+ k=k,
100
+ w=w,
101
+ u=u,
102
+ g=None,
103
+ initial_state=initial_state,
104
+ output_final_state=False,
105
+ offsets=offsets,
106
+ indices=indices,
107
+ head_first=head_first,
108
+ chunk_size=BT
109
+ )
110
+ dv = chunk_bwd_dv_local(
111
+ q=q,
112
+ k=k,
113
+ do=do,
114
+ g=None,
115
+ dh=None,
116
+ scale=scale,
117
+ offsets=offsets,
118
+ indices=indices,
119
+ head_first=head_first,
120
+ chunk_size=BT
121
+ )
122
+ dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu(
123
+ q=q,
124
+ k=k,
125
+ w=w,
126
+ g=None,
127
+ h0=initial_state,
128
+ dht=dht,
129
+ do=do,
130
+ dv=dv,
131
+ scale=scale,
132
+ offsets=offsets,
133
+ indices=indices,
134
+ head_first=head_first,
135
+ chunk_size=BT
136
+ )
137
+ dq, dk, dw, _ = chunk_bwd_dqkwg(
138
+ q=q,
139
+ k=k,
140
+ v=v_new,
141
+ h=h,
142
+ w=w,
143
+ dv=dv,
144
+ do=do,
145
+ dh=dh,
146
+ g=None,
147
+ scale=scale,
148
+ offsets=offsets,
149
+ indices=indices,
150
+ head_first=head_first,
151
+ chunk_size=BT
152
+ )
153
+ dk2, dv, db = bwd_prepare_wy_repr(
154
+ k=k,
155
+ v=v,
156
+ beta=beta,
157
+ A=A,
158
+ dw=dw,
159
+ du=dv,
160
+ offsets=offsets,
161
+ indices=indices,
162
+ head_first=head_first,
163
+ chunk_size=BT
164
+ )
165
+ dk.add_(dk2)
166
+ return dq, dk, dv, db, dh0
167
+
168
+
169
+ class ChunkDeltaRuleFunction(torch.autograd.Function):
170
+
171
+ @staticmethod
172
+ @input_guard
173
+ @autocast_custom_fwd
174
+ def forward(
175
+ ctx,
176
+ q: torch.Tensor,
177
+ k: torch.Tensor,
178
+ v: torch.Tensor,
179
+ beta: torch.Tensor,
180
+ scale: float,
181
+ initial_state: torch.Tensor,
182
+ output_final_state: bool,
183
+ offsets: Optional[torch.LongTensor] = None,
184
+ head_first: bool = True,
185
+ use_qk_l2norm_in_kernel: bool = True
186
+ ):
187
+ T = q.shape[2] if head_first else q.shape[1]
188
+ chunk_size = min(64, max(triton.next_power_of_2(T), 16))
189
+
190
+ q_orig = q
191
+ k_orig = k
192
+
193
+ if use_qk_l2norm_in_kernel:
194
+ q = l2norm_fwd(q)
195
+ k = l2norm_fwd(k)
196
+
197
+ # 2-d indices denoting the offsets of chunks in each sequence
198
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
199
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
200
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
201
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
202
+
203
+ o, A, final_state = chunk_delta_rule_fwd(
204
+ q=q,
205
+ k=k,
206
+ v=v,
207
+ beta=beta,
208
+ scale=scale,
209
+ initial_state=initial_state,
210
+ output_final_state=output_final_state,
211
+ offsets=offsets,
212
+ indices=indices,
213
+ head_first=head_first,
214
+ chunk_size=chunk_size
215
+ )
216
+ ctx.save_for_backward(q_orig, k_orig, v, beta, A, initial_state)
217
+ ctx.chunk_size = chunk_size
218
+ ctx.scale = scale
219
+ ctx.offsets = offsets
220
+ ctx.indices = indices
221
+ ctx.head_first = head_first
222
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
223
+ return o.to(q.dtype), final_state
224
+
225
+ @staticmethod
226
+ @input_guard
227
+ @autocast_custom_bwd
228
+ def backward(
229
+ ctx,
230
+ do: torch.Tensor,
231
+ dht: torch.Tensor
232
+ ):
233
+ q, k, v, beta, A, initial_state = ctx.saved_tensors
234
+ use_qk_l2norm_in_kernel = ctx.use_qk_l2norm_in_kernel
235
+ if use_qk_l2norm_in_kernel:
236
+ q, q_orig = l2norm_fwd(q), q
237
+ k, k_orig = l2norm_fwd(k), k
238
+
239
+ dq, dk, dv, db, dh0 = chunk_delta_rule_bwd(
240
+ q=q,
241
+ k=k,
242
+ v=v,
243
+ beta=beta,
244
+ A=A,
245
+ scale=ctx.scale,
246
+ initial_state=initial_state,
247
+ do=do,
248
+ dht=dht,
249
+ offsets=ctx.offsets,
250
+ indices=ctx.indices,
251
+ head_first=ctx.head_first,
252
+ chunk_size=ctx.chunk_size
253
+ )
254
+ if use_qk_l2norm_in_kernel:
255
+ dq = l2norm_bwd(q_orig, dq)
256
+ dk = l2norm_bwd(k_orig, dk)
257
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), db.to(beta.dtype), None, dh0, None, None, None, None, None, None
258
+
259
+
260
+ @torch.compiler.disable
261
+ def chunk_delta_rule(
262
+ q: torch.Tensor,
263
+ k: torch.Tensor,
264
+ v: torch.Tensor,
265
+ beta: torch.Tensor,
266
+ scale: float = None,
267
+ initial_state: torch.Tensor = None,
268
+ output_final_state: bool = False,
269
+ cu_seqlens: Optional[torch.LongTensor] = None,
270
+ head_first: bool = False,
271
+ use_qk_l2norm_in_kernel: bool = False
272
+ ):
273
+ r"""
274
+ Args:
275
+ q (torch.Tensor):
276
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
277
+ k (torch.Tensor):
278
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
279
+ v (torch.Tensor):
280
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
281
+ beta (torch.Tensor):
282
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
283
+ scale (Optional[float]):
284
+ Scale factor for the attention scores.
285
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
286
+ initial_state (Optional[torch.Tensor]):
287
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
288
+ For equal-length input sequences, `N` equals the batch size `B`.
289
+ Default: `None`.
290
+ output_final_state (Optional[bool]):
291
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
292
+ cu_seqlens (torch.LongTensor):
293
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
294
+ consistent with the FlashAttention API.
295
+ head_first (Optional[bool]):
296
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
297
+ Default: `False`.
298
+ use_qk_l2norm_in_kernel (Optional[bool]):
299
+ Whether to use qk l2norm within the kernel for saving GPU memory.
300
+ Default: `False`.
301
+
302
+ Returns:
303
+ o (torch.Tensor):
304
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
305
+ final_state (torch.Tensor):
306
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
307
+
308
+ Examples::
309
+ >>> import torch
310
+ >>> import torch.nn.functional as F
311
+ >>> from einops import rearrange
312
+ >>> from fla.ops.delta_rule import chunk_delta_rule
313
+ # inputs with equal lengths
314
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
315
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
316
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
317
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
318
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
319
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
320
+ >>> o, ht = chunk_delta_rule(
321
+ q, k, v, beta,
322
+ initial_state=h0,
323
+ output_final_state=True
324
+ )
325
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
326
+ >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
327
+ # for a batch with 4 sequences, a `cu_seqlens` tensor with 5 boundary positions is expected
328
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
329
+ >>> o_var, ht_var = chunk_delta_rule(
330
+ q, k, v, beta,
331
+ initial_state=h0,
332
+ output_final_state=True,
333
+ cu_seqlens=cu_seqlens
334
+ )
335
+ """
336
+ assert q.dtype == k.dtype == v.dtype
337
+ assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
338
+ assert len(beta.shape) == 3, "beta must be of shape [B, H, T] if head_first=True else [B, T, H]."
339
+
340
+ if cu_seqlens is not None:
341
+ if q.shape[0] != 1:
342
+ raise ValueError(
343
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
344
+ f" Please flatten variable-length inputs before processing."
345
+ )
346
+ if head_first:
347
+ raise RuntimeError(
348
+ "Sequences with variable lengths are not supported for head-first mode"
349
+ )
350
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
351
+ raise ValueError(
352
+ f"The number of initial states is expected to be equal to the number of input sequences, "
353
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
354
+ )
355
+ if head_first:
356
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
357
+ beta = rearrange(beta, 'b h t -> b t h')
358
+ scale = k.shape[-1] ** -0.5 if scale is None else scale
359
+ o, final_state = ChunkDeltaRuleFunction.apply(
360
+ q,
361
+ k,
362
+ v,
363
+ beta,
364
+ scale,
365
+ initial_state,
366
+ output_final_state,
367
+ cu_seqlens,
368
+ False,
369
+ use_qk_l2norm_in_kernel
370
+ )
371
+ if head_first:
372
+ o = rearrange(o, 'b t h v -> b h t v')
373
+ return o, final_state
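As a reading aid for what the chunked kernels above compute, here is a step-by-step reference recurrence for the (unchunked) delta rule, transcribed from `fused_recurrent_delta_rule_fwd_kernel` later in this commit. It is a sketch for intuition only (hypothetical helper name, head-first shapes, float32 accumulation, no variable-length handling), not the kernels' exact numerics.

import torch

def delta_rule_reference(q, k, v, beta, scale, h0=None):
    # q, k: [B, H, T, K]; v: [B, H, T, V]; beta: [B, H, T]; state S: [B, H, K, V]
    B, H, T, K = q.shape
    V = v.shape[-1]
    S = h0.float().clone() if h0 is not None else q.new_zeros(B, H, K, V, dtype=torch.float)
    o = q.new_zeros(B, H, T, V, dtype=torch.float)
    for t in range(T):
        q_t, k_t, v_t = q[:, :, t].float(), k[:, :, t].float(), v[:, :, t].float()
        b_t = beta[:, :, t].float().unsqueeze(-1)
        # delta update: move the value stored under k_t towards v_t with step size beta_t
        u_t = b_t * (v_t - torch.einsum('bhk,bhkv->bhv', k_t, S))
        S = S + torch.einsum('bhk,bhv->bhkv', k_t, u_t)   # rank-1 state update
        o[:, :, t] = torch.einsum('bhk,bhkv->bhv', q_t * scale, S)
    return o.to(q.dtype), S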
fla/ops/delta_rule/fused_chunk.py ADDED
@@ -0,0 +1,6 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ def fused_chunk_delta_rule(
4
+ **kwargs
5
+ ):
6
+ raise NotImplementedError("fused_chunk_delta_rule is deprecated. Please use chunk_delta_rule instead.")
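Since the stub above only raises, any remaining call sites have to switch to `chunk_delta_rule`; a sketch of the replacement call, assuming the previously used keyword arguments carry over unchanged:

from fla.ops.delta_rule import chunk_delta_rule

# previously: o, ht = fused_chunk_delta_rule(q, k, v, beta, scale=scale, initial_state=h0, output_final_state=True)
o, ht = chunk_delta_rule(q, k, v, beta, scale=scale, initial_state=h0, output_final_state=True)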
fla/ops/delta_rule/fused_recurrent.py ADDED
@@ -0,0 +1,607 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
12
+ from fla.utils import input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.jit(do_not_specialize=['T'])
21
+ def fused_recurrent_delta_rule_fwd_kernel(
22
+ q,
23
+ k,
24
+ v,
25
+ u,
26
+ beta,
27
+ o,
28
+ h0,
29
+ ht,
30
+ offsets,
31
+ scale,
32
+ T,
33
+ B: tl.constexpr,
34
+ H: tl.constexpr,
35
+ K: tl.constexpr,
36
+ V: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BV: tl.constexpr,
39
+ USE_INITIAL_STATE: tl.constexpr,
40
+ STORE_FINAL_STATE: tl.constexpr,
41
+ IS_BETA_HEADWISE: tl.constexpr,
42
+ USE_OFFSETS: tl.constexpr,
43
+ HEAD_FIRST: tl.constexpr
44
+ ):
45
+ i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
46
+ i_n, i_h = i_nh // H, i_nh % H
47
+ if USE_OFFSETS:
48
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
49
+ all = T
50
+ T = eos - bos
51
+ else:
52
+ bos, eos = i_n * T, i_n * T + T
53
+ all = B * T
54
+
55
+ if HEAD_FIRST:
56
+ p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
57
+ p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
58
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
59
+ p_u = u + i_nh * T*V + i_v * BV + tl.arange(0, BV)
60
+ if IS_BETA_HEADWISE:
61
+ p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
62
+ else:
63
+ p_beta = beta + i_nh * T
64
+ p_o = o + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
65
+ else:
66
+ p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
67
+ p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
68
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
69
+ p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
70
+ if IS_BETA_HEADWISE:
71
+ p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
72
+ else:
73
+ p_beta = beta + bos * H + i_h
74
+ p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)
75
+
76
+ mask_k = (i_k * BK + tl.arange(0, BK)) < K
77
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
78
+ mask_h = mask_k[None, :] & mask_v[:, None]
79
+
80
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
81
+ if USE_INITIAL_STATE:
82
+ p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
83
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
84
+
85
+ for _ in range(0, T):
86
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
87
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
88
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
89
+ b_v_minus = tl.sum(b_h * b_k[None, :], axis=1)
90
+ b_v -= b_v_minus
91
+ if IS_BETA_HEADWISE:
92
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
93
+ else:
94
+ b_beta = tl.load(p_beta).to(tl.float32)
95
+ tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v)
96
+ b_v *= b_beta
97
+ b_h += b_k[None, :] * b_v[:, None]
98
+ b_o = b_h * b_q[None, :]
99
+ b_o = tl.sum(b_o, axis=1)
100
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
101
+
102
+ p_q += K if HEAD_FIRST else H*K
103
+ p_k += K if HEAD_FIRST else H*K
104
+ p_o += V if HEAD_FIRST else H*V
105
+ p_v += V if HEAD_FIRST else H*V
106
+ p_u += V if HEAD_FIRST else H*V
107
+ p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
108
+
109
+ if STORE_FINAL_STATE:
110
+ p_ht = ht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
111
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
116
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
117
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
118
+ })
119
+ @triton.jit(do_not_specialize=['T'])
120
+ def fused_recurrent_delta_rule_bwd_kernel(
121
+ q,
122
+ k,
123
+ v,
124
+ beta,
125
+ h0,
126
+ dh0,
127
+ dht,
128
+ do,
129
+ dq,
130
+ dk,
131
+ dv,
132
+ db,
133
+ offsets,
134
+ scale,
135
+ B: tl.constexpr,
136
+ T,
137
+ H: tl.constexpr,
138
+ K: tl.constexpr,
139
+ V: tl.constexpr,
140
+ BK: tl.constexpr,
141
+ BV: tl.constexpr,
142
+ NK: tl.constexpr,
143
+ IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar
144
+ USE_INITIAL_STATE: tl.constexpr, # whether to use dh0
145
+ USE_FINAL_STATE_GRADIENT: tl.constexpr, # whether to use dht
146
+ USE_OFFSETS: tl.constexpr,
147
+ HEAD_FIRST: tl.constexpr
148
+ ):
149
+ i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
150
+ i_n, i_h = i_nh // H, i_nh % H
151
+ if USE_OFFSETS:
152
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
153
+ all = T
154
+ T = eos - bos
155
+ else:
156
+ bos, eos = i_n * T, i_n * T + T
157
+ all = B * T
158
+
159
+ mask_k = i_k * BK + tl.arange(0, BK) < K
160
+ mask_v = i_v * BV + tl.arange(0, BV) < V
161
+
162
+ if HEAD_FIRST:
163
+ p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
164
+ p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
165
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
166
+ p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
167
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK) + (T - 1) * K
168
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
169
+ if IS_BETA_HEADWISE:
170
+ p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV) + (T - 1) * V
171
+ p_dbeta = db + (i_v * NK*B*H + i_k * B*H + i_nh) * T*V + tl.arange(0, BV) + (T - 1) * V
172
+ else:
173
+ p_beta = beta + i_nh * T + T - 1
174
+ p_dbeta = db + (i_v * B*H + i_nh) * T + T - 1
175
+ else:
176
+ p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
177
+ p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
178
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
179
+ p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
180
+ p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK) + (T - 1) * H*K
181
+ p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV) + (T - 1) * H*V
182
+ if IS_BETA_HEADWISE:
183
+ p_beta = beta + (bos + T - 1) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
184
+ p_dbeta = db + ((i_v * NK + i_k) * all + bos + T - 1) * H*V + i_h * V + tl.arange(0, BV)
185
+ else:
186
+ p_beta = beta + (bos + T - 1) * H + i_h
187
+ p_dbeta = db + (i_v * all + bos + T - 1) * H + i_h
188
+
189
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
190
+ if USE_FINAL_STATE_GRADIENT:
191
+ p_ht = dht + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
192
+ b_dh += tl.load(p_ht, mask=mask_k[:, None] & mask_v[None, :], other=0).to(tl.float32)
193
+
194
+ for _ in range(T):
195
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
196
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
197
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
198
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
199
+ if IS_BETA_HEADWISE:
200
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
201
+ else:
202
+ b_beta = tl.load(p_beta).to(tl.float32)
203
+ b_dh += b_q[:, None] * b_do[None, :]
204
+ b_dk = tl.sum(b_dh * (b_v * b_beta)[None, :], axis=1)
205
+ b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
206
+
207
+ b_db = b_dv * b_v if IS_BETA_HEADWISE else tl.sum(b_dv * b_v)
208
+ b_dv = b_dv * b_beta
209
+
210
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
211
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
212
+ if IS_BETA_HEADWISE:
213
+ tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty), mask=mask_v)
214
+ else:
215
+ tl.store(p_dbeta, b_db.to(p_dbeta.dtype.element_ty))
216
+
217
+ b_dh -= b_k[:, None] * b_dv[None, :]
218
+
219
+ p_q -= K if HEAD_FIRST else H*K
220
+ p_k -= K if HEAD_FIRST else H*K
221
+ p_v -= V if HEAD_FIRST else H*V
222
+ p_do -= V if HEAD_FIRST else H*V
223
+ p_dk -= K if HEAD_FIRST else H*K
224
+ p_dv -= V if HEAD_FIRST else H*V
225
+ p_dbeta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
226
+ p_beta -= (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
227
+
228
+ if USE_INITIAL_STATE:
229
+ p_dh0 = dh0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
230
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :])
231
+
232
+ tl.debug_barrier()
233
+
234
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
235
+
236
+ if HEAD_FIRST:
237
+ p_q = q + i_nh * T*K + i_k * BK + tl.arange(0, BK)
238
+ p_k = k + i_nh * T*K + i_k * BK + tl.arange(0, BK)
239
+ p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV)
240
+ if IS_BETA_HEADWISE:
241
+ p_beta = beta + i_nh * T*V + i_v * BV + tl.arange(0, BV)
242
+ else:
243
+ p_beta = beta + i_nh * T
244
+ p_do = do + i_nh * T*V + i_v * BV + tl.arange(0, BV)
245
+ p_dq = dq + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
246
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + i_k * BK + tl.arange(0, BK)
247
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + i_v * BV + tl.arange(0, BV)
248
+ else:
249
+ p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
250
+ p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
251
+ p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
252
+ if IS_BETA_HEADWISE:
253
+ p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
254
+ else:
255
+ p_beta = beta + bos * H + i_h
256
+ p_do = do + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
257
+ p_dq = dq + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
258
+ p_dk = dk + ((i_v * all + bos) * H + i_h) * K + i_k * BK + tl.arange(0, BK)
259
+ p_dv = dv + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)
260
+
261
+ if USE_INITIAL_STATE:
262
+ mask_h = mask_k[:, None] & mask_v[None, :]
263
+ p_h0 = h0 + i_nh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
264
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
265
+
266
+ for _ in range(0, T):
267
+ b_dk = tl.load(p_dk, mask=mask_k, other=0).to(tl.float32)
268
+ b_dv = tl.load(p_dv, mask=mask_v, other=0).to(tl.float32)
269
+ b_dk -= tl.sum(b_dv[None, :] * b_h, axis=1)
270
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
271
+
272
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
273
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
274
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
275
+ if IS_BETA_HEADWISE:
276
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
277
+ else:
278
+ b_beta = tl.load(p_beta).to(tl.float32)
279
+ b_v *= b_beta
280
+
281
+ b_h += b_k[:, None] * b_v[None, :]
282
+ b_dq = b_h * b_do[None, :]
283
+ d_q = tl.sum(b_dq, axis=1) * scale
284
+ tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k)
285
+
286
+ p_k += K if HEAD_FIRST else H*K
287
+ p_v += V if HEAD_FIRST else H*V
288
+ p_do += V if HEAD_FIRST else H*V
289
+ p_dq += K if HEAD_FIRST else H*K
290
+ p_dk += K if HEAD_FIRST else H*K
291
+ p_dv += V if HEAD_FIRST else H*V
292
+ p_beta += (1 if HEAD_FIRST else H) * (V if IS_BETA_HEADWISE else 1)
293
+
294
+
295
+ def fused_recurrent_delta_rule_fwd(
296
+ q: torch.Tensor,
297
+ k: torch.Tensor,
298
+ v: torch.Tensor,
299
+ beta: torch.Tensor,
300
+ scale: float,
301
+ initial_state: torch.Tensor,
302
+ output_final_state: bool,
303
+ offsets: Optional[torch.LongTensor] = None,
304
+ head_first: bool = True
305
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
306
+ if head_first:
307
+ B, H, T, K, V = *k.shape, v.shape[-1]
308
+ else:
309
+ B, T, H, K, V = *k.shape, v.shape[-1]
310
+ N = B if offsets is None else len(offsets) - 1
311
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
312
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
313
+ assert NK == 1, "NK > 1 is not supported yet"
314
+ num_stages = 1
315
+ num_warps = 1
316
+
317
+ o = q.new_empty(NK, *v.shape)
318
+ if output_final_state:
319
+ final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
320
+ else:
321
+ final_state = None
322
+
323
+ grid = (NV, NK, N * H)
324
+ u = torch.empty_like(v)
325
+ fused_recurrent_delta_rule_fwd_kernel[grid](
326
+ q,
327
+ k,
328
+ v,
329
+ u,
330
+ beta,
331
+ o,
332
+ initial_state,
333
+ final_state,
334
+ offsets,
335
+ scale,
336
+ T=T,
337
+ B=B,
338
+ H=H,
339
+ K=K,
340
+ V=V,
341
+ BK=BK,
342
+ BV=BV,
343
+ IS_BETA_HEADWISE=beta.ndim == v.ndim,
344
+ HEAD_FIRST=head_first,
345
+ num_warps=num_warps,
346
+ num_stages=num_stages,
347
+ )
348
+ o = o.squeeze(0)
349
+ return o, u, final_state
350
+
351
+
352
+ def fused_recurrent_delta_rule_bwd(
353
+ q: torch.Tensor,
354
+ k: torch.Tensor,
355
+ v: torch.Tensor,
356
+ beta: torch.Tensor,
357
+ dht: torch.Tensor,
358
+ do: torch.Tensor,
359
+ scale: float,
360
+ initial_state: torch.Tensor,
361
+ offsets: Optional[torch.LongTensor] = None,
362
+ head_first: bool = True
363
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
364
+ if head_first:
365
+ B, H, T, K, V = *k.shape, v.shape[-1]
366
+ else:
367
+ B, T, H, K, V = *k.shape, v.shape[-1]
368
+ N = B if offsets is None else len(offsets) - 1
369
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
370
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
371
+ assert NK == 1, "NK > 1 is not supported yet"
372
+ num_stages = 1
373
+ num_warps = 2
374
+
375
+ beta_vector = beta.ndim == v.ndim
376
+
377
+ dq = q.new_empty(NV, *q.shape)
378
+ dk = q.new_empty(NV, *k.shape)
379
+ dv = q.new_empty(NK, *v.shape)
380
+ if beta_vector:
381
+ db = q.new_empty(NV, NK, B, H, T, V) if head_first else q.new_empty(NV, NK, B, T, H, V)
382
+ else:
383
+ db = q.new_empty(NV, B, H, T) if head_first else q.new_empty(NV, B, T, H)
384
+ grid = (NV, NK, N * H)
385
+
386
+ if initial_state is not None and initial_state.requires_grad:
387
+ dh0 = torch.empty_like(initial_state, dtype=torch.float32)
388
+ else:
389
+ dh0 = None
390
+
391
+ fused_recurrent_delta_rule_bwd_kernel[grid](
392
+ q,
393
+ k,
394
+ v,
395
+ beta,
396
+ initial_state,
397
+ dh0,
398
+ dht,
399
+ do,
400
+ dq,
401
+ dk,
402
+ dv,
403
+ db,
404
+ offsets,
405
+ scale,
406
+ T=T,
407
+ B=B,
408
+ H=H,
409
+ K=K,
410
+ V=V,
411
+ BK=BK,
412
+ BV=BV,
413
+ NK=NK,
414
+ IS_BETA_HEADWISE=beta_vector,
415
+ HEAD_FIRST=head_first,
416
+ num_warps=num_warps,
417
+ num_stages=num_stages
418
+ )
419
+ dq = dq.sum(0)
420
+ dk = dk.sum(0)
421
+ dv = dv.sum(0)
422
+ db = db.sum((0, 1)) if beta_vector else db.sum(0)
423
+
424
+ return dq, dk, dv, db, dh0
425
+
426
+
427
+ class FusedRecurrentFunction(torch.autograd.Function):
428
+
429
+ @staticmethod
430
+ @input_guard
431
+ def forward(
432
+ ctx,
433
+ q: torch.Tensor,
434
+ k: torch.Tensor,
435
+ v: torch.Tensor,
436
+ beta: torch.Tensor,
437
+ scale: float,
438
+ initial_state: torch.Tensor,
439
+ output_final_state: bool,
440
+ offsets: Optional[torch.LongTensor] = None,
441
+ head_first: bool = True,
442
+ use_qk_l2norm_in_kernel: bool = False
443
+ ):
444
+ q_orig = q
445
+ k_orig = k
446
+
447
+ if use_qk_l2norm_in_kernel:
448
+ q = l2norm_fwd(q)
449
+ k = l2norm_fwd(k)
450
+
451
+ o, u, final_state = fused_recurrent_delta_rule_fwd(
452
+ q=q,
453
+ k=k,
454
+ v=v,
455
+ beta=beta,
456
+ scale=scale,
457
+ initial_state=initial_state,
458
+ output_final_state=output_final_state,
459
+ offsets=offsets,
460
+ head_first=head_first
461
+ )
462
+
463
+ ctx.save_for_backward(q_orig, k_orig, u, beta, initial_state)
464
+ ctx.scale = scale
465
+ ctx.offsets = offsets
466
+ ctx.head_first = head_first
467
+ ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
468
+ return o, final_state
469
+
470
+ @staticmethod
471
+ @input_guard
472
+ def backward(ctx, do, dht):
473
+ q, k, v, beta, initial_state = ctx.saved_tensors
474
+ if ctx.use_qk_l2norm_in_kernel:
475
+ q, q_orig = l2norm_fwd(q), q
476
+ k, k_orig = l2norm_fwd(k), k
477
+ dq, dk, dv, db, dh0 = fused_recurrent_delta_rule_bwd(
478
+ q=q,
479
+ k=k,
480
+ v=v,
481
+ beta=beta,
482
+ dht=dht,
483
+ do=do,
484
+ scale=ctx.scale,
485
+ initial_state=initial_state,
486
+ offsets=ctx.offsets,
487
+ head_first=ctx.head_first
488
+ )
489
+ if ctx.use_qk_l2norm_in_kernel:
490
+ dq, dk = l2norm_bwd(q_orig, dq), l2norm_bwd(k_orig, dk)
491
+ return dq.to(q), dk.to(k), dv.to(v), db.to(beta), None, dh0, None, None, None, None
492
+
493
+
494
+ @torch.compiler.disable
495
+ def fused_recurrent_delta_rule(
496
+ q: torch.Tensor,
497
+ k: torch.Tensor,
498
+ v: torch.Tensor,
499
+ beta: torch.Tensor = None,
500
+ scale: float = None,
501
+ initial_state: torch.Tensor = None,
502
+ output_final_state: bool = False,
503
+ cu_seqlens: Optional[torch.LongTensor] = None,
504
+ head_first: bool = True,
505
+ use_qk_l2norm_in_kernel: bool = False
506
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
507
+ r"""
508
+ Args:
509
+ q (torch.Tensor):
510
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
511
+ k (torch.Tensor):
512
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
513
+ v (torch.Tensor):
514
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
515
+ beta (torch.Tensor):
516
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
517
+ scale (Optional[float]):
518
+ Scale factor for the attention scores.
519
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
520
+ initial_state (Optional[torch.Tensor]):
521
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
522
+ For equal-length input sequences, `N` equals the batch size `B`.
523
+ Default: `None`.
524
+ output_final_state (Optional[bool]):
525
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
526
+ cu_seqlens (torch.LongTensor):
527
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
528
+ consistent with the FlashAttention API.
529
+ head_first (Optional[bool]):
530
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
531
+ Default: `True`.
532
+
533
+ Returns:
534
+ o (torch.Tensor):
535
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
536
+ final_state (torch.Tensor):
537
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
538
+
539
+ Examples::
540
+ >>> import torch
541
+ >>> import torch.nn.functional as F
542
+ >>> from einops import rearrange
543
+ >>> from fla.ops.delta_rule import fused_recurrent_delta_rule
544
+ # inputs with equal lengths
545
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
546
+ >>> q = torch.randn(B, T, H, K, device='cuda')
547
+ >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
548
+ >>> v = torch.randn(B, T, H, V, device='cuda')
549
+ >>> beta = torch.rand(B, T, H, device='cuda').sigmoid()
550
+ >>> h0 = torch.randn(B, H, K, V, device='cuda')
551
+ >>> o, ht = fused_recurrent_delta_rule(
552
+ q, k, v, beta,
553
+ initial_state=h0,
554
+ output_final_state=True
555
+ )
556
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
557
+ >>> q, k, v, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta))
558
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
559
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
560
+ >>> o_var, ht_var = fused_recurrent_delta_rule(
561
+ q, k, v, beta,
562
+ initial_state=h0,
563
+ output_final_state=True,
564
+ cu_seqlens=cu_seqlens
565
+ )
566
+ >>> assert o.allclose(o_var.view(o.shape))
567
+ >>> assert ht.allclose(ht_var)
568
+ """
569
+ if cu_seqlens is not None:
570
+ if q.shape[0] != 1:
571
+ raise ValueError(
572
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
573
+ f"Please flatten variable-length inputs before processing."
574
+ )
575
+ if head_first:
576
+ raise RuntimeError(
577
+ "Sequences with variable lengths are not supported for head-first mode"
578
+ )
579
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
580
+ raise ValueError(
581
+ f"The number of initial states is expected to be equal to the number of input sequences, "
582
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
583
+ )
584
+ if scale is None:
585
+ scale = k.shape[-1] ** -0.5
586
+ else:
587
+ assert scale > 0, "scale must be positive"
588
+ if beta is None:
589
+ beta = torch.ones_like(q[..., 0])
590
+ if head_first:
591
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
592
+ beta = rearrange(beta, 'b h t -> b t h')
593
+ o, final_state = FusedRecurrentFunction.apply(
594
+ q,
595
+ k,
596
+ v,
597
+ beta,
598
+ scale,
599
+ initial_state,
600
+ output_final_state,
601
+ cu_seqlens,
602
+ False,
603
+ use_qk_l2norm_in_kernel
604
+ )
605
+ if head_first:
606
+ o = rearrange(o, 'b t h v -> b h t v')
607
+ return o, final_state
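A minimal consistency sketch for the fused kernel above, assuming a CUDA device; it cross-checks `fused_recurrent_delta_rule` against the pure-PyTorch `delta_rule_recurrence` reference defined in `fla/ops/delta_rule/naive.py` below. The shapes and `atol` are illustrative choices, not part of the library.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

from fla.ops.delta_rule import fused_recurrent_delta_rule
from fla.ops.delta_rule.naive import delta_rule_recurrence

B, T, H, D = 2, 64, 2, 32
q = torch.randn(B, T, H, D, device='cuda')
k = F.normalize(torch.randn(B, T, H, D, device='cuda'), p=2, dim=-1)
v = torch.randn(B, T, H, D, device='cuda')
beta = torch.rand(B, T, H, device='cuda').sigmoid()

# fused kernel, [B, T, H, *] layout
o_fused, _ = fused_recurrent_delta_rule(q, k, v, beta, head_first=False)

# the naive reference runs in the head-first layout and applies the 1/sqrt(D) scale itself
o_ref, _ = delta_rule_recurrence(
    rearrange(q, 'b t h d -> b h t d'),
    rearrange(k, 'b t h d -> b h t d'),
    rearrange(v, 'b t h d -> b h t d'),
    rearrange(beta, 'b t h -> b h t'),
)
assert torch.allclose(o_fused, rearrange(o_ref, 'b h t d -> b t h d'), atol=1e-3)
```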
fla/ops/delta_rule/naive.py ADDED
@@ -0,0 +1,120 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+
7
+ def delta_rule_recurrence(q, k, v, beta, initial_state=None, output_final_state=True):
8
+ orig_dtype = q.dtype
9
+ b, h, l, d_k = q.shape
10
+ q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta])
11
+ d_v = v.shape[-1]
12
+ o = torch.zeros_like(v)
13
+ S = torch.zeros(b, h, d_k, d_v).to(v)
14
+ q = q * (d_k ** -0.5)
15
+
16
+ if beta.ndim < v.ndim:
17
+ beta = beta[..., None]
18
+
19
+ if initial_state is not None:
20
+ S += initial_state
21
+
22
+ for i in range(l):
23
+ _k = k[:, :, i]
24
+ _q = q[:, :, i]
25
+ _v = v[:, :, i].clone()
26
+ beta_i = beta[:, :, i]
27
+ _v = _v - (S.clone() * _k[..., None]).sum(-2)
28
+ _v = _v * beta_i
29
+ S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2)
30
+ o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
31
+ S = None if output_final_state is False else S
32
+ return o.to(orig_dtype), S
33
+
34
+
35
+ def delta_rule_chunkwise(q, k, v, beta, chunk_size=32):
36
+ b, h, l, d_k = q.shape
37
+ d_v = v.shape[-1]
38
+ q = q * (d_k ** -0.5)
39
+ v = v * beta[..., None]
40
+ k_beta = k * beta[..., None]
41
+
42
+ assert l % chunk_size == 0
43
+
44
+ # compute (I - tri(diag(beta) KK^T))^{-1}
45
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
46
+ q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta])
47
+ attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
48
+ for i in range(1, chunk_size):
49
+ attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2)
50
+ attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device)
51
+
52
+ u = attn @ v
53
+ w = attn @ k_beta
54
+ S = k.new_zeros(b, h, d_k, d_v)
55
+ o = torch.zeros_like(v)
56
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
57
+ for i in range(0, l // chunk_size):
58
+ q_i, k_i = q[:, :, i], k[:, :, i]
59
+ attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0)
60
+ u_i = u[:, :, i] - w[:, :, i] @ S
61
+ o_inter = q_i @ S
62
+ o[:, :, i] = o_inter + attn @ u_i
63
+ S = S + k_i.transpose(-1, -2) @ u_i
64
+
65
+ return rearrange(o, 'b h n c d -> b h (n c) d'), S
66
+
67
+
68
+ def delta_rule_parallel(q, k, v, beta, BM=128, BN=32):
69
+ b, h, l, d_k = q.shape
70
+ # d_v = v.shape[-1]
71
+ q = q * (d_k ** -0.5)
72
+ v = v * beta[..., None]
73
+ k_beta = k * beta[..., None]
74
+ # compute (I - tri(diag(beta) KK^T))^{-1}
75
+ q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta])
76
+ mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0)
77
+ T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
78
+ for i in range(1, BN):
79
+ T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2)
80
+ T = T + torch.eye(BN, dtype=torch.float, device=q.device)
81
+
82
+ mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1)
83
+ A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T
84
+ o_intra = A_local @ v
85
+
86
+ # apply cumprod transition matrices on k to the last position within the chunk
87
+ k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta
88
+ # apply cumprod transition matrices on q to the first position within the chunk
89
+ q = q - A_local @ k_beta
90
+ o_intra = A_local @ v
91
+
92
+ A = torch.zeros(b, h, l, l, device=q.device)
93
+
94
+ q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra])
95
+ o = torch.empty_like(v)
96
+ for i in range(0, l, BM):
97
+ q_i = q[:, :, i:i+BM]
98
+ o_i = o_intra[:, :, i:i+BM]
99
+ # intra block
100
+ for j in range(i + BM - 2 * BN, i-BN, -BN):
101
+ k_j = k[:, :, j:j+BN]
102
+ A_ij = q_i @ k_j.transpose(-1, -2)
103
+ mask = torch.arange(i, i+BM) >= (j + BN)
104
+ A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0)
105
+ A[:, :, i:i+BM, j:j+BN] = A_ij
106
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
107
+ o_i += A_ij @ v[:, :, j:j+BN]
108
+ # inter block
109
+ for j in range(i - BN, -BN, -BN):
110
+ k_j = k[:, :, j:j+BN]
111
+ A_ij = q_i @ k_j.transpose(-1, -2)
112
+ A[:, :, i:i+BM, j:j+BN] = A_ij
113
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
114
+ o_i += A_ij @ v[:, :, j:j+BN]
115
+ o[:, :, i:i+BM] = o_i
116
+
117
+ for i in range(0, l//BN):
118
+ A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i]
119
+
120
+ return o, A
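The recurrent and chunkwise references above compute the same delta-rule outputs and final state; a quick equivalence sketch (float32 on CPU, with an illustrative tolerance) is:

```python
import torch
import torch.nn.functional as F

from fla.ops.delta_rule.naive import delta_rule_chunkwise, delta_rule_recurrence

b, h, l, d = 2, 2, 128, 32          # l must be a multiple of chunk_size
q = torch.randn(b, h, l, d)
k = F.normalize(torch.randn(b, h, l, d), p=2, dim=-1)
v = torch.randn(b, h, l, d)
beta = torch.rand(b, h, l).sigmoid()

o_rec, S_rec = delta_rule_recurrence(q, k, v, beta)
o_chk, S_chk = delta_rule_chunkwise(q, k, v, beta, chunk_size=32)

# both paths should agree up to numerical tolerance
assert torch.allclose(o_rec, o_chk, atol=1e-3)
assert torch.allclose(S_rec, S_chk, atol=1e-3)
```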
fla/ops/delta_rule/parallel.py ADDED
@@ -0,0 +1,394 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.ops.delta_rule.wy_fast import fwd_prepare_T
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ @triton.autotune(
16
+ configs=[
17
+ triton.Config({}, num_warps=num_warps)
18
+ for num_warps in [1, 2, 4]
19
+ ],
20
+ key=['BT', 'K', 'V'],
21
+ )
22
+ @triton.jit(do_not_specialize=['T'])
23
+ def chunk_transform_qk_fwd_kernel(
24
+ q,
25
+ k,
26
+ v,
27
+ beta,
28
+ o,
29
+ A,
30
+ q_new,
31
+ k_new,
32
+ A_local,
33
+ scale,
34
+ T,
35
+ K: tl.constexpr,
36
+ V: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BV: tl.constexpr,
39
+ BT: tl.constexpr,
40
+ OUTPUT_ATTENTIONS: tl.constexpr
41
+ ):
42
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
43
+
44
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
45
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
46
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0))
47
+ b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(p_q.dtype.element_ty)
48
+ b_k = tl.load(p_k, boundary_check=(0, 1))
49
+ b_v = tl.load(p_v, boundary_check=(0, 1))
50
+
51
+ p_T = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
52
+ b_T = tl.load(p_T, boundary_check=(0, 1))
53
+
54
+ o_i = tl.arange(0, BT)
55
+ m_t = o_i[:, None] >= o_i[None, :]
56
+ b_qk = tl.where(m_t, tl.dot(b_q, tl.trans(b_k), allow_tf32=False), 0).to(b_q.dtype)
57
+ m_t = o_i[:, None] > o_i[None, :]
58
+ b_kk = tl.where(m_t, tl.dot(b_k, tl.trans(b_k), allow_tf32=False), 0).to(b_k.dtype)
59
+
60
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (i_t * BT, ), (BT, ), (0, ))
61
+ b_beta = tl.load(p_beta, boundary_check=(0, ))
62
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
63
+
64
+ b_qkT = tl.dot(b_qk, b_T, allow_tf32=False).to(b_k.dtype)
65
+
66
+ if OUTPUT_ATTENTIONS:
67
+ p_a = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
68
+ tl.store(p_a, b_qkT.to(p_a.dtype.element_ty), boundary_check=(0, 1))
69
+
70
+ b_kkT = tl.dot(b_kk, b_T, allow_tf32=False).to(b_k.dtype)
71
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0))
72
+ tl.store(p_o, tl.dot(b_qkT, b_v).to(p_o.dtype.element_ty), boundary_check=(0, 1))
73
+
74
+ p_q_new = tl.make_block_ptr(q_new + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
75
+ tl.store(p_q_new, (b_q - tl.dot(b_qkT, b_k_beta, allow_tf32=False)).to(p_q_new.dtype.element_ty), boundary_check=(0, 1))
76
+
77
+ p_k_new = tl.make_block_ptr(k_new + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
78
+ b_k_new = b_k - tl.dot(tl.trans(b_kkT), b_k_beta, allow_tf32=False)
79
+ tl.store(p_k_new, b_k_new.to(p_k_new.dtype.element_ty), boundary_check=(0, 1))
80
+
81
+
82
+ def chunk_transform_qk_fwd(
83
+ q: torch.Tensor,
84
+ k: torch.Tensor,
85
+ v: torch.Tensor,
86
+ beta: torch.Tensor,
87
+ A: torch.Tensor,
88
+ scale: float,
89
+ chunk_size: int,
90
+ output_attentions: bool
91
+ ):
92
+ B, H, T, K = k.shape
93
+ BT = chunk_size
94
+ q_new = torch.empty_like(q)
95
+ k_new = torch.empty_like(k)
96
+ o = torch.empty_like(v)
97
+ grid = (triton.cdiv(T, BT), B*H)
98
+ V = v.shape[-1]
99
+ A_local = torch.empty_like(A) if output_attentions else None
100
+ chunk_transform_qk_fwd_kernel[grid](
101
+ q,
102
+ k,
103
+ v,
104
+ beta,
105
+ o,
106
+ A,
107
+ q_new,
108
+ k_new,
109
+ A_local,
110
+ scale=scale,
111
+ T=T,
112
+ K=K,
113
+ V=V,
114
+ BT=BT,
115
+ BK=triton.next_power_of_2(K),
116
+ BV=triton.next_power_of_2(V),
117
+ OUTPUT_ATTENTIONS=output_attentions
118
+ )
119
+ return q_new, k_new, o, A_local
120
+
121
+
122
+ @triton.autotune(
123
+ configs=[
124
+ triton.Config({}, num_warps=1),
125
+ triton.Config({}, num_warps=2),
126
+ ],
127
+ key=['BT'],
128
+ )
129
+ @triton.jit(do_not_specialize=['T'])
130
+ def save_intra_chunk_attn(
131
+ A,
132
+ A_local,
133
+ T,
134
+ BT: tl.constexpr,
135
+ ):
136
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
137
+ p_A = tl.make_block_ptr(A + i_bh * T * T, (T, T), (T, 1), (i_t * BT, i_t * BT), (BT, BT), (1, 0))
138
+ p_A_local = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
139
+ b_A_local = tl.load(p_A_local, boundary_check=(0, 1))
140
+ tl.store(p_A, b_A_local.to(p_A.dtype.element_ty), boundary_check=(0, 1))
141
+
142
+
143
+ @triton.heuristics({
144
+ 'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None
145
+ })
146
+ @triton.jit(do_not_specialize=['T'])
147
+ def parallel_delta_rule_fwd_kernel(
148
+ q,
149
+ k,
150
+ k2, # original k
151
+ v,
152
+ beta,
153
+ o,
154
+ o_new,
155
+ attn,
156
+ T,
157
+ K: tl.constexpr,
158
+ V: tl.constexpr,
159
+ BT: tl.constexpr,
160
+ BS: tl.constexpr,
161
+ BK: tl.constexpr,
162
+ BV: tl.constexpr,
163
+ OUTPUT_ATTENTIONS: tl.constexpr
164
+ ):
165
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
166
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
167
+
168
+ # the Q block is kept in the shared memory throughout the whole kernel
169
+ # [BT, BK]
170
+ b_q = tl.zeros([BT, BK], dtype=tl.float32)
171
+ b_q += tl.load(p_q, boundary_check=(0, 1))
172
+
173
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
174
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0))
175
+ b_o += tl.load(p_o, boundary_check=(0, 1))
176
+
177
+ # As opposed to FlashAttention, this kernel requires scanning the KV blocks from right to left
178
+ # Q block and K block have overlap.
179
+ # masks required
180
+ for offset in range((i_t + 1) * BT - 2 * BS, i_t * BT - BS, -BS):
181
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (0, offset), (BK, BS), (0, 1))
182
+ p_k2 = tl.make_block_ptr(k2 + i_bh * T*K, (T, K), (K, 1), (offset, 0), (BS, BK), (1, 0))
183
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (offset, 0), (BS, BV), (1, 0))
184
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,))
185
+ # [BK, BS]
186
+ b_k = tl.load(p_k, boundary_check=(0, 1))
187
+ # [BS, BV]
188
+ b_v = tl.load(p_v, boundary_check=(0, 1))
189
+ # [BS]
190
+ b_beta = tl.load(p_beta, boundary_check=(0,))
191
+ # [BT, BS]
192
+ m_s = tl.arange(0, BT) >= (offset - i_t*BT + BS)
193
+ b_s = tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False)
194
+ b_s = tl.where(m_s[:, None], b_s, 0)
195
+
196
+ b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
197
+ b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype)
198
+ b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False)
199
+
200
+ if OUTPUT_ATTENTIONS:
201
+ p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0))
202
+ tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
203
+
204
+ # Q block and K block have no overlap
205
+ # no need for mask, thereby saving flops
206
+ for offset in range(i_t * BT - BS, -BS, -BS):
207
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (0, offset), (BK, BS), (0, 1))
208
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (offset, 0), (BS, BV), (1, 0))
209
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,))
210
+ p_k2 = tl.make_block_ptr(k2 + i_bh * T*K, (T, K), (K, 1), (offset, 0), (BS, BK), (1, 0))
211
+
212
+ # [BK, BS]
213
+ b_k = tl.load(p_k, boundary_check=(0, 1))
214
+ # [BS, BV]
215
+ b_v = tl.load(p_v, boundary_check=(0, 1))
216
+ # [BS]
217
+ b_beta = tl.load(p_beta, boundary_check=(0,))
218
+ # [BT, BS]
219
+ b_s = (tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False))
220
+ # [BT, BV]
221
+ b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
222
+ b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype)
223
+ b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False).to(b_q.dtype)
224
+
225
+ if OUTPUT_ATTENTIONS:
226
+ p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0))
227
+ tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
228
+
229
+ p_o_new = tl.make_block_ptr(o_new + i_bh * T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
230
+ tl.store(p_o_new, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
231
+
232
+
233
+ class ParallelDeltaRuleFunction(torch.autograd.Function):
234
+
235
+ @staticmethod
236
+ @input_guard
237
+ @autocast_custom_fwd
238
+ def forward(ctx, q, k, v, beta, scale, output_attentions):
239
+ B, H, T, K, V = *k.shape, v.shape[-1]
240
+ assert q.shape[-1] <= 128, 'The maximum supported head dimension is 128.'
241
+ BT, BS = 128, 32
242
+ BK = triton.next_power_of_2(k.shape[-1])
243
+ BV = triton.next_power_of_2(v.shape[-1])
244
+ assert BT % BS == 0
245
+
246
+ A = fwd_prepare_T(k, beta, BS)
247
+ attn = q.new_zeros(B, H, T, T) if output_attentions else None
248
+ q_new, k_new, o, A_local = chunk_transform_qk_fwd(
249
+ q,
250
+ k,
251
+ v,
252
+ beta,
253
+ A,
254
+ scale,
255
+ BS,
256
+ output_attentions
257
+ )
258
+
259
+ num_stages = 3 if K <= 64 else 2
260
+ num_warps = 4
261
+ grid = (triton.cdiv(T, BT), B * H)
262
+ o_new = torch.empty_like(o)
263
+
264
+ parallel_delta_rule_fwd_kernel[grid](
265
+ q=q_new,
266
+ k=k_new,
267
+ k2=k,
268
+ v=v,
269
+ beta=beta,
270
+ o=o,
271
+ o_new=o_new,
272
+ attn=attn,
273
+ T=T,
274
+ K=K,
275
+ V=V,
276
+ BT=BT,
277
+ BS=BS,
278
+ BK=BK,
279
+ BV=BV,
280
+ num_stages=num_stages,
281
+ num_warps=num_warps
282
+ )
283
+
284
+ if output_attentions:
285
+ grid = (triton.cdiv(T, BS), B * H)
286
+ save_intra_chunk_attn[grid](
287
+ A=attn,
288
+ A_local=A_local,
289
+ T=T,
290
+ BT=BS
291
+ )
292
+ return o_new.to(q.dtype), attn
293
+
294
+ @staticmethod
295
+ @input_guard
296
+ @autocast_custom_bwd
297
+ def backward(ctx, do, d_attn=None):
298
+ raise NotImplementedError('Backward pass is not implemented. Stay tuned!')
299
+
300
+
301
+ def parallel_delta_rule(
302
+ q: torch.Tensor,
303
+ k: torch.Tensor,
304
+ v: torch.Tensor,
305
+ beta: torch.Tensor,
306
+ scale: float = None,
307
+ output_attentions: bool = False,
308
+ head_first: bool = True
309
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ r"""
311
+ Args:
312
+ q (torch.Tensor):
313
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
314
+ k (torch.Tensor):
315
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
316
+ v (torch.Tensor):
317
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
318
+ beta (torch.Tensor):
319
+ betas of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`.
320
+ scale (Optional[float]):
321
+ Scale factor for attention scores.
322
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
323
+ output_attentions (bool):
324
+ Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`.
325
+ head_first (Optional[bool]):
326
+ Whether the inputs are in the head-first format.
327
+ Default: `True`.
328
+
329
+ Returns:
330
+ o (torch.Tensor):
331
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
332
+ attn (torch.Tensor):
333
+ Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`.
334
+ """
335
+ if not head_first:
336
+ q, k, v, beta = map(lambda x: x.transpose(1, 2), (q, k, v, beta))
337
+ o, attn = ParallelDeltaRuleFunction.apply(q, k, v, beta, scale, output_attentions)
338
+ if not head_first:
339
+ o = o.transpose(1, 2)
340
+ return o, attn
341
+
342
+
343
+ def naive_delta_rule_parallel(q, k, v, beta, BM=128, BN=32):
344
+ b, h, l, d_k = q.shape
345
+ q = q * (d_k ** -0.5)
346
+ v = v * beta[..., None]
347
+ k_beta = k * beta[..., None]
348
+ # compute (I - tri(diag(beta) KK^T))^{-1}
349
+ q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta])
350
+ mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0)
351
+ T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
352
+ for i in range(1, BN):
353
+ T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2)
354
+ T = T + torch.eye(BN, dtype=q.dtype, device=q.device)
355
+
356
+ mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1)
357
+ A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T
358
+ o_intra = A_local @ v
359
+
360
+ # apply cumprod transition matrices on k to the last position within the chunk
361
+ k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta
362
+ # apply cumprod transition matrices on q to the first position within the chunk
363
+ q = q - A_local @ k_beta
364
+ o_intra = A_local @ v
365
+
366
+ A = torch.zeros(b, h, l, l, device=q.device)
367
+
368
+ q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra])
369
+ o = torch.empty_like(v)
370
+ for i in range(0, l, BM):
371
+ q_i = q[:, :, i:i+BM]
372
+ o_i = o_intra[:, :, i:i+BM]
373
+ # intra block
374
+ for j in range(i + BM - 2 * BN, i-BN, -BN):
375
+ k_j = k[:, :, j:j+BN]
376
+ A_ij = q_i @ k_j.transpose(-1, -2)
377
+ mask = torch.arange(i, i+BM) >= (j + BN)
378
+ A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0)
379
+ A[:, :, i:i+BM, j:j+BN] = A_ij
380
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
381
+ o_i += A_ij @ v[:, :, j:j+BN]
382
+ # inter block
383
+ for j in range(i - BN, -BN, -BN):
384
+ k_j = k[:, :, j:j+BN]
385
+ A_ij = q_i @ k_j.transpose(-1, -2)
386
+ A[:, :, i:i+BM, j:j+BN] = A_ij
387
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
388
+ o_i += A_ij @ v[:, :, j:j+BN]
389
+ o[:, :, i:i+BM] = o_i
390
+
391
+ for i in range(0, l//BN):
392
+ A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i]
393
+
394
+ return o, A
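A forward-only usage sketch of `parallel_delta_rule` (the backward pass above raises `NotImplementedError`), assuming a CUDA device, bfloat16 inputs, and that the `fwd_prepare_T` helper imported at the top of this file is available; the shapes respect the head-dimension limit asserted in `forward`:

```python
import torch
import torch.nn.functional as F

from fla.ops.delta_rule.parallel import parallel_delta_rule

B, H, T, D = 2, 4, 512, 64          # head dim must be <= 128
q = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
k = F.normalize(torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16), p=2, dim=-1)
v = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
beta = torch.rand(B, H, T, device='cuda', dtype=torch.bfloat16).sigmoid()

with torch.no_grad():
    o, attn = parallel_delta_rule(q, k, v, beta, scale=D ** -0.5, output_attentions=True)
# o: [B, H, T, D]; attn: [B, H, T, T] materialized attention scores
```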
fla/ops/delta_rule/wy_fast.py ADDED
@@ -0,0 +1,340 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
11
+ from fla.ops.utils.solve_tril import solve_tril
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
27
+ )
28
+ @triton.jit(do_not_specialize=['T'])
29
+ def fwd_recompute_w_u_kernel(
30
+ k,
31
+ v,
32
+ beta,
33
+ w,
34
+ u,
35
+ A,
36
+ offsets,
37
+ indices,
38
+ T,
39
+ H: tl.constexpr,
40
+ K: tl.constexpr,
41
+ V: tl.constexpr,
42
+ BT: tl.constexpr,
43
+ BK: tl.constexpr,
44
+ BV: tl.constexpr,
45
+ HEAD_FIRST: tl.constexpr,
46
+ USE_OFFSETS: tl.constexpr
47
+ ):
48
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
49
+ i_b, i_h = i_bh // H, i_bh % H
50
+ if USE_OFFSETS:
51
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
52
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
53
+ T = eos - bos
54
+ else:
55
+ bos, eos = i_b * T, i_b * T + T
56
+
57
+ if HEAD_FIRST:
58
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
59
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
60
+ else:
61
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
62
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
63
+ b_beta = tl.load(p_beta, boundary_check=(0,))
64
+ b_A = tl.load(p_A, boundary_check=(0, 1))
65
+
66
+ for i_v in range(tl.cdiv(V, BV)):
67
+ if HEAD_FIRST:
68
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
69
+ p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
70
+ else:
71
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
72
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
73
+ b_v = tl.load(p_v, boundary_check=(0, 1))
74
+ b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
75
+ b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False)
76
+ tl.store(p_u, (b_u).to(p_u.dtype.element_ty), boundary_check=(0, 1))
77
+
78
+ for i_k in range(tl.cdiv(K, BK)):
79
+ if HEAD_FIRST:
80
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
81
+ p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
82
+ else:
83
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
84
+ p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
85
+ b_k = tl.load(p_k, boundary_check=(0, 1))
86
+ b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
87
+ b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False)
88
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
89
+
90
+
91
+ @triton.heuristics({
92
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
93
+ })
94
+ @triton.autotune(
95
+ configs=[
96
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
97
+ for num_warps in NUM_WARPS
98
+ for num_stages in [2, 3, 4]
99
+ ],
100
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
101
+ )
102
+ @triton.jit(do_not_specialize=['T'])
103
+ def bwd_prepare_wy_repr_kernel(
104
+ k,
105
+ v,
106
+ beta,
107
+ A,
108
+ dw,
109
+ du,
110
+ dk,
111
+ dv,
112
+ dbeta,
113
+ offsets,
114
+ indices,
115
+ T,
116
+ H: tl.constexpr,
117
+ K: tl.constexpr,
118
+ V: tl.constexpr,
119
+ BT: tl.constexpr,
120
+ BK: tl.constexpr,
121
+ BV: tl.constexpr,
122
+ HEAD_FIRST: tl.constexpr,
123
+ USE_OFFSETS: tl.constexpr
124
+ ):
125
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
126
+ i_b, i_h = i_bh // H, i_bh % H
127
+ if USE_OFFSETS:
128
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
129
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
130
+ T = eos - bos
131
+ else:
132
+ bos, eos = i_b * T, i_b * T + T
133
+
134
+ if HEAD_FIRST:
135
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
136
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
137
+ else:
138
+ p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
139
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))
140
+
141
+ b_beta = tl.load(p_beta, boundary_check=(0,))
142
+ b_A = tl.load(p_A, boundary_check=(0, 1))
143
+
144
+ b_dbeta = tl.zeros([BT], dtype=tl.float32)
145
+ b_dA = tl.zeros([BT, BT], dtype=tl.float32)
146
+ for i_v in range(tl.cdiv(V, BV)):
147
+ if HEAD_FIRST:
148
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
149
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
150
+ p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
151
+ else:
152
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
153
+ p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
154
+ p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
155
+
156
+ b_v = tl.load(p_v, boundary_check=(0, 1))
157
+ b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
158
+ b_du = tl.load(p_du, boundary_check=(0, 1))
159
+ b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
160
+ b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
161
+ b_dv = b_dv_beta * b_beta[:, None]
162
+ b_dbeta += tl.sum(b_dv_beta * b_v, 1)
163
+
164
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
165
+
166
+ for i_k in range(tl.cdiv(K, BK)):
167
+ if HEAD_FIRST:
168
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
169
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
170
+ p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
171
+ else:
172
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
173
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
174
+ p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
175
+ b_k = tl.load(p_k, boundary_check=(0, 1))
176
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
177
+ b_dw = tl.load(p_dw, boundary_check=(0, 1))
178
+ b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
179
+ b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
180
+ b_dk = b_dk_beta * b_beta[:, None]
181
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
182
+
183
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
184
+
185
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
186
+ b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
187
+ b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
188
+ b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)
189
+
190
+ for i_k in range(tl.cdiv(K, BK)):
191
+ if HEAD_FIRST:
192
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
193
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
194
+ else:
195
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
196
+ p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
197
+ b_k = tl.load(p_k, boundary_check=(0, 1))
198
+ b_dk = tl.load(p_dk, boundary_check=(0, 1))
199
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
200
+
201
+ b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
202
+ b_dbeta += tl.sum(b_dk_beta * b_k, 1)
203
+ b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
204
+ b_dk += b_dk_beta * b_beta[:, None]
205
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
206
+
207
+ if HEAD_FIRST:
208
+ p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
209
+ else:
210
+ p_dbeta = tl.make_block_ptr(dbeta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
211
+ tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))
212
+
213
+
214
+ def fwd_prepare_wy_repr(
215
+ k: torch.Tensor,
216
+ v: torch.Tensor,
217
+ beta: torch.Tensor,
218
+ offsets: Optional[torch.LongTensor],
219
+ indices: Optional[torch.LongTensor],
220
+ head_first: bool = False,
221
+ chunk_size: int = 64
222
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
223
+ A = chunk_scaled_dot_kkt_fwd(
224
+ k=k,
225
+ beta=beta,
226
+ cu_seqlens=offsets,
227
+ head_first=head_first,
228
+ chunk_size=chunk_size,
229
+ output_dtype=torch.float32
230
+ )
231
+ A = solve_tril(
232
+ A=A,
233
+ cu_seqlens=offsets,
234
+ head_first=head_first,
235
+ output_dtype=k.dtype
236
+ )
237
+
238
+ w, u = fwd_recompute_w_u(
239
+ k=k,
240
+ v=v,
241
+ beta=beta,
242
+ A=A,
243
+ offsets=offsets,
244
+ indices=indices,
245
+ head_first=head_first,
246
+ chunk_size=chunk_size
247
+ )
248
+ return w, u, A
249
+
250
+
251
+ def fwd_recompute_w_u(
252
+ k: torch.Tensor,
253
+ v: torch.Tensor,
254
+ beta: torch.Tensor,
255
+ A: torch.Tensor,
256
+ offsets: Optional[torch.LongTensor],
257
+ indices: Optional[torch.LongTensor],
258
+ head_first: bool,
259
+ chunk_size: int
260
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
261
+ if head_first:
262
+ B, H, T, K, V = *k.shape, v.shape[-1]
263
+ else:
264
+ B, T, H, K, V = *k.shape, v.shape[-1]
265
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
266
+ CONST_TILING = 64 if check_shared_mem() else 32
267
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
268
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
269
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
270
+
271
+ u = torch.empty_like(v)
272
+ w = torch.empty_like(k)
273
+ fwd_recompute_w_u_kernel[(NT, B*H)](
274
+ k,
275
+ v,
276
+ beta,
277
+ w,
278
+ u,
279
+ A,
280
+ offsets=offsets,
281
+ indices=indices,
282
+ T=T,
283
+ H=H,
284
+ K=K,
285
+ V=V,
286
+ BT=BT,
287
+ BK=BK,
288
+ BV=BV,
289
+ HEAD_FIRST=head_first
290
+ )
291
+ return w, u
292
+
293
+
294
+ def bwd_prepare_wy_repr(
295
+ k: torch.Tensor,
296
+ v: torch.Tensor,
297
+ beta: torch.Tensor,
298
+ A: torch.Tensor,
299
+ dw: torch.Tensor,
300
+ du: torch.Tensor,
301
+ offsets: Optional[torch.LongTensor],
302
+ indices: Optional[torch.LongTensor],
303
+ head_first: bool,
304
+ chunk_size: int
305
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
306
+ if head_first:
307
+ B, H, T, K, V = *k.shape, v.shape[-1]
308
+ else:
309
+ B, T, H, K, V = *k.shape, v.shape[-1]
310
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
311
+ CONST_TILING = 64 if check_shared_mem() else 32
312
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
313
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
314
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
315
+
316
+ dk = torch.empty_like(k)
317
+ dv = torch.empty_like(v)
318
+ dbeta = torch.empty_like(beta)
319
+ bwd_prepare_wy_repr_kernel[(NT, B * H)](
320
+ k,
321
+ v,
322
+ beta,
323
+ A,
324
+ dw,
325
+ du,
326
+ dk,
327
+ dv,
328
+ dbeta,
329
+ offsets=offsets,
330
+ indices=indices,
331
+ T=T,
332
+ H=H,
333
+ K=K,
334
+ V=V,
335
+ BT=BT,
336
+ BK=BK,
337
+ BV=BV,
338
+ HEAD_FIRST=head_first
339
+ )
340
+ return dk, dv, dbeta
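For reference, a per-chunk PyTorch sketch of the WY form that `fwd_prepare_wy_repr` produces, assuming float32 inputs in the `[B, T, H, *]` layout, no `offsets`, and that `chunk_scaled_dot_kkt_fwd`/`solve_tril` follow the same `(I + L)^{-1}` convention as the chunkwise reference in `naive.py`:

```python
import torch
from einops import rearrange


def naive_wy_repr(k, v, beta, chunk_size=64):
    # k: [B, T, H, K], v: [B, T, H, V], beta: [B, T, H]; float32, T divisible by chunk_size
    k_c, v_c = (rearrange(x, 'b (n c) h d -> b h n c d', c=chunk_size) for x in (k, v))
    beta_c = rearrange(beta, 'b (n c) h -> b h n c', c=chunk_size)
    kb = k_c * beta_c[..., None]
    # strictly lower-triangular part of diag(beta) K K^T within each chunk
    L = torch.tril(kb @ k_c.transpose(-1, -2), diagonal=-1)
    eye = torch.eye(chunk_size, dtype=k.dtype, device=k.device)
    A = torch.linalg.inv(eye + L)                       # (I + L)^{-1}, unit lower triangular
    w = rearrange(A @ kb, 'b h n c d -> b (n c) h d')   # W in the WY representation
    u = rearrange(A @ (v_c * beta_c[..., None]), 'b h n c d -> b (n c) h d')
    return w, u, A
```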
fla/ops/forgetting_attn/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .parallel import parallel_forgetting_attn
4
+
5
+ __all__ = [
6
+ 'parallel_forgetting_attn'
7
+ ]
fla/ops/forgetting_attn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (264 Bytes).
 
fla/ops/forgetting_attn/__pycache__/parallel.cpython-311.pyc ADDED
Binary file (40.1 kB).
 
fla/ops/forgetting_attn/parallel.py ADDED
@@ -0,0 +1,708 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum
13
+ from fla.ops.utils.op import div, exp, log
14
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
24
+ for num_stages in [2, 3, 4, 5]
25
+ ],
26
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
27
+ )
28
+ @triton.jit
29
+ def parallel_forgetting_attn_fwd_kernel(
30
+ q,
31
+ k,
32
+ v,
33
+ g,
34
+ o,
35
+ lse,
36
+ scale,
37
+ offsets,
38
+ indices,
39
+ T,
40
+ B: tl.constexpr,
41
+ H: tl.constexpr,
42
+ HQ: tl.constexpr,
43
+ G: tl.constexpr,
44
+ K: tl.constexpr,
45
+ V: tl.constexpr,
46
+ BT: tl.constexpr,
47
+ BS: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ USE_OFFSETS: tl.constexpr
51
+ ):
52
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
53
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
54
+ i_h = i_hq // G
55
+
56
+ if USE_OFFSETS:
57
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
58
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
59
+ T = eos - bos
60
+ else:
61
+ i_n = i_b
62
+ bos, eos = i_n * T, i_n * T + T
63
+
64
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
65
+ p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
66
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
67
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
68
+
69
+ # the Q block is kept in the shared memory throughout the whole kernel
70
+ # [BT, BK]
71
+ b_q = tl.load(p_q, boundary_check=(0, 1))
72
+ b_q = (b_q * scale).to(b_q.dtype)
73
+ # [BT,]
74
+ b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
75
+ # [BT, BV]
76
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
77
+
78
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
79
+ b_acc = tl.zeros([BT], dtype=tl.float32)
80
+
81
+ # [BT]
82
+ o_q = i_t * BT + tl.arange(0, BT)
83
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
84
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
85
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
86
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
87
+
88
+ # [BS]
89
+ o_k = i_s + tl.arange(0, BS)
90
+ # [BK, BS]
91
+ b_k = tl.load(p_k, boundary_check=(0, 1))
92
+ # [BS, BV]
93
+ b_v = tl.load(p_v, boundary_check=(0, 1))
94
+ # [BS,]
95
+ b_gk = tl.load(p_gk, boundary_check=(0,))
96
+ # [BT, BS]
97
+ b_s = tl.dot(b_q, b_k) + b_gq[:, None] - b_gk[None, :]
98
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
99
+
100
+ # [BT]
101
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
102
+ b_r = exp(b_mp - b_m)
103
+ # [BT, BS]
104
+ b_p = exp(b_s - b_m[:, None])
105
+ # [BT]
106
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
107
+ # [BT, BV]
108
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
109
+
110
+ b_mp = b_m
111
+
112
+ for i_s in range(i_t * BT - BS, -BS, -BS):
113
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
114
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
115
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
116
+
117
+ # [BK, BS]
118
+ b_k = tl.load(p_k, boundary_check=(0, 1))
119
+ # [BS, BV]
120
+ b_v = tl.load(p_v, boundary_check=(0, 1))
121
+ # [BS,]
122
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
123
+
124
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
125
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
126
+ # [BT, BS]
127
+ b_s = tl.dot(b_q, b_k) + b_gq[:, None] + (b_gn - b_gk)[None, :]
128
+
129
+ b_gq += b_gn - b_gp
130
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
131
+ b_r = exp(b_mp - b_m)
132
+ # [BT, BS]
133
+ b_p = exp(b_s - b_m[:, None])
134
+ # [BT]
135
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
136
+ # [BT, BV]
137
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
138
+
139
+ b_mp = b_m
140
+
141
+ b_o = div(b_o, b_acc[:, None])
142
+ b_m += log(b_acc)
143
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
144
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
145
+
146
+
147
+ @triton.jit
148
+ def parallel_forgetting_attn_bwd_kernel_preprocess(
149
+ o,
150
+ do,
151
+ delta,
152
+ B: tl.constexpr,
153
+ V: tl.constexpr
154
+ ):
155
+ i_n = tl.program_id(0)
156
+ o_d = tl.arange(0, B)
157
+ m_d = o_d < V
158
+
159
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
160
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
161
+ b_delta = tl.sum(b_o * b_do)
162
+
163
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
164
+
165
+
166
+ @triton.heuristics({
167
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
168
+ })
169
+ @triton.autotune(
170
+ configs=[
171
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
172
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
173
+ for num_stages in [2, 3, 4]
174
+ ],
175
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
176
+ )
177
+ @triton.jit(do_not_specialize=['T'])
178
+ def parallel_forgetting_attn_bwd_kernel_dq(
179
+ q,
180
+ k,
181
+ v,
182
+ g,
183
+ lse,
184
+ delta,
185
+ do,
186
+ dq,
187
+ dg,
188
+ scale,
189
+ offsets,
190
+ indices,
191
+ T,
192
+ B: tl.constexpr,
193
+ H: tl.constexpr,
194
+ HQ: tl.constexpr,
195
+ G: tl.constexpr,
196
+ K: tl.constexpr,
197
+ V: tl.constexpr,
198
+ BT: tl.constexpr,
199
+ BS: tl.constexpr,
200
+ BK: tl.constexpr,
201
+ BV: tl.constexpr,
202
+ USE_OFFSETS: tl.constexpr
203
+ ):
204
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
205
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
206
+ i_h = i_hq // G
207
+
208
+ if USE_OFFSETS:
209
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
210
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
211
+ T = eos - bos
212
+ else:
213
+ i_n = i_b
214
+ bos, eos = i_n * T, i_n * T + T
215
+
216
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
217
+ p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
218
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
219
+ p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))
220
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
221
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
222
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
223
+
224
+ # [BT, BK]
225
+ b_q = tl.load(p_q, boundary_check=(0, 1))
226
+ b_q = (b_q * scale).to(b_q.dtype)
227
+ # [BT, BV]
228
+ b_do = tl.load(p_do, boundary_check=(0, 1))
229
+ # [BT]
230
+ b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32)
231
+ b_lse = tl.load(p_lse, boundary_check=(0,))
232
+ b_delta = tl.load(p_delta, boundary_check=(0,))
233
+
234
+ # [BT]
235
+ o_q = i_t * BT + tl.arange(0, BT)
236
+ # [BT, BK]
237
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
238
+ # [BT]
239
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
240
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
241
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
242
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
243
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
244
+
245
+ # [BS]
246
+ o_k = i_s + tl.arange(0, BS)
247
+ # [BK, BS]
248
+ b_k = tl.load(p_k, boundary_check=(0, 1))
249
+ # [BV, BS]
250
+ b_v = tl.load(p_v, boundary_check=(0, 1))
251
+ # [BS,]
252
+ b_gk = tl.load(p_gk, boundary_check=(0,))
253
+ # [BT, BS]
254
+ b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] - b_gk[None, :]
255
+ b_p = exp(tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')))
256
+
257
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
258
+ b_dp = tl.dot(b_do, b_v)
259
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
260
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
261
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
262
+ # [BT]
263
+ b_dg += tl.sum(b_ds, 1)
264
+
265
+ for i_s in range(i_t * BT - BS, -BS, -BS):
266
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
267
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
268
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
269
+
270
+ # [BK, BS]
271
+ b_k = tl.load(p_k, boundary_check=(0, 1))
272
+ # [BV, BS]
273
+ b_v = tl.load(p_v, boundary_check=(0, 1))
274
+ # [BS,]
275
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
276
+
277
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
278
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
279
+ # [BT, BS]
280
+ b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] + (b_gn - b_gk)[None, :]
281
+ b_p = exp(b_s)
282
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
283
+ b_dp = tl.dot(b_do, b_v)
284
+ b_ds = b_p * (b_dp - b_delta[:, None])
285
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
286
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
287
+ # [BT]
288
+ b_dg += tl.sum(b_ds, 1)
289
+
290
+ b_gq += b_gn - b_gp
291
+
292
+ b_dq *= scale
293
+
294
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
295
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
296
+
297
+
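In the second loop above, `b_gq += b_gn - b_gp` stitches the chunk-local cumulative log gates produced by `chunk_local_cumsum` back into a global decay: the total log decay between a key at position s and a query at position t is the local part at t, minus the local part at s, plus the totals of the complete blocks in between. A small check of that decomposition at chunk granularity (illustrative names, not part of this file):

    import torch

    T, C = 16, 4                               # toy sequence length and chunk size
    logg = torch.randn(T)                      # per-step log forget gates
    glob = logg.cumsum(0)                      # global cumulative log gates
    loc = logg.view(-1, C).cumsum(1).view(-1)  # chunk-local cumulative log gates

    t, s = 13, 2                               # query position and an earlier key position
    ref = glob[t] - glob[s]                    # total log decay over (s, t]
    # local parts at t and s plus the totals of every chunk between s's and t's chunks
    # (each chunk total is the last entry of its local cumsum)
    between = sum(loc[(c + 1) * C - 1] for c in range(s // C, t // C))
    val = loc[t] - loc[s] + between
    assert torch.allclose(ref, val, atol=1e-5)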
298
+ @triton.heuristics({
299
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
300
+ })
301
+ @triton.autotune(
302
+ configs=[
303
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
304
+ for num_warps in [1, 2, 4, 8]
305
+ for num_stages in [2, 3, 4]
306
+ ],
307
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
308
+ )
309
+ @triton.jit(do_not_specialize=['T'])
310
+ def parallel_forgetting_attn_bwd_kernel_dkv(
311
+ q,
312
+ k,
313
+ v,
314
+ g,
315
+ lse,
316
+ delta,
317
+ do,
318
+ dk,
319
+ dv,
320
+ dg,
321
+ offsets,
322
+ indices,
323
+ scale,
324
+ T,
325
+ B: tl.constexpr,
326
+ H: tl.constexpr,
327
+ HQ: tl.constexpr,
328
+ G: tl.constexpr,
329
+ K: tl.constexpr,
330
+ V: tl.constexpr,
331
+ BT: tl.constexpr,
332
+ BS: tl.constexpr,
333
+ BK: tl.constexpr,
334
+ BV: tl.constexpr,
335
+ USE_OFFSETS: tl.constexpr
336
+ ):
337
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
338
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
339
+ i_h = i_hq // G
340
+
341
+ if USE_OFFSETS:
342
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
343
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
344
+ T = eos - bos
345
+ else:
346
+ i_n = i_b
347
+ bos, eos = i_n * T, i_n * T + T
348
+
349
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
350
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
351
+ p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
352
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
353
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
354
+ p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,))
355
+
356
+ # [BT, BK]
357
+ b_k = tl.load(p_k, boundary_check=(0, 1))
358
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
359
+ # [BT, BV]
360
+ b_v = tl.load(p_v, boundary_check=(0, 1))
361
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
362
+ # [BT]
363
+ b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32)
364
+ b_dg = tl.zeros([BT,], dtype=tl.float32)
365
+
366
+ o_k = i_t * BT + tl.arange(0, BT)
367
+ m_k = o_k < T
368
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
369
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
370
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
371
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
372
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
373
+ p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
374
+
375
+ # [BS]
376
+ o_q = i_s + tl.arange(0, BS)
377
+ # [BS, BK]
378
+ b_q = tl.load(p_q, boundary_check=(0, 1))
379
+ b_q = (b_q * scale).to(b_q.dtype)
380
+ # [BS, BV]
381
+ b_do = tl.load(p_do, boundary_check=(0, 1))
382
+ # [BS]
383
+ b_lse = tl.load(p_lse, boundary_check=(0,))
384
+ b_delta = tl.load(p_delta, boundary_check=(0,))
385
+ b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
386
+
387
+ m_q = o_q < T
388
+ m_s = (o_k[:, None] <= o_q[None, :]) & m_k[:, None] & m_q[None, :]
389
+ # [BT, BS]
390
+ b_s = tl.dot(b_k, tl.trans(b_q)) - b_gk[:, None] + (b_gq - b_lse)[None, :]
391
+ b_p = tl.where(m_s, exp(b_s), 0)
392
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
393
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
394
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
395
+ b_dp = tl.dot(b_v, tl.trans(b_do))
396
+ # [BT, BS]
397
+ b_ds = b_p * (b_dp - b_delta[None, :])
398
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
399
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
400
+ # [BT]
401
+ b_dg -= tl.sum(b_ds, 1)
402
+
403
+ b_gk -= tl.load(g + (bos + min((i_t + 1) * BT, T) - 1) * HQ + i_hq).to(tl.float32)
404
+ for i_s in range((i_t + 1) * BT, T, BS):
405
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
406
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
407
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
408
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
409
+ p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
410
+
411
+ # [BS]
412
+ o_q = i_s + tl.arange(0, BS)
413
+ # [BS, BK]
414
+ b_q = tl.load(p_q, boundary_check=(0, 1))
415
+ b_q = (b_q * scale).to(b_q.dtype)
416
+ # [BS, BV]
417
+ b_do = tl.load(p_do, boundary_check=(0, 1))
418
+ # [BS]
419
+ b_lse = tl.load(p_lse, boundary_check=(0,))
420
+ b_delta = tl.load(p_delta, boundary_check=(0,))
421
+ b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32)
422
+
423
+ b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32)
424
+ b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0.
425
+ # [BT, BS]
426
+ b_s = tl.dot(b_k, tl.trans(b_q)) - (b_gk + b_gp)[:, None] + (b_gq - b_lse)[None, :]
427
+ b_p = exp(b_s)
428
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
429
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
430
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
431
+ b_dp = tl.dot(b_v, tl.trans(b_do))
432
+ # [BT, BS]
433
+ b_ds = b_p * (b_dp - b_delta[None, :])
434
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
435
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
436
+ # [BT]
437
+ b_dg -= tl.sum(b_ds, 1)
438
+
439
+ b_gk -= b_gn - b_gp
440
+
441
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
442
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
443
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))
444
+
445
+
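The sign flip on the gate gradient (`b_dg += tl.sum(b_ds, 1)` in the dq kernel versus `b_dg -= tl.sum(b_ds, 1)` here) mirrors how the cumulative log gate enters each score with a plus sign at the query position and a minus sign at the key position. A throwaway autograd check of that convention (illustrative only):

    import torch

    gq = torch.randn((), requires_grad=True)   # cumulative log gate at the query position
    gk = torch.randn((), requires_grad=True)   # cumulative log gate at the key position
    s = torch.randn(()) + gq - gk              # the gated pre-softmax score
    s.backward()
    assert gq.grad.item() == 1.0 and gk.grad.item() == -1.0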
446
+ def parallel_forgetting_attn_fwd(
447
+ q: torch.Tensor,
448
+ k: torch.Tensor,
449
+ v: torch.Tensor,
450
+ g: torch.Tensor,
451
+ scale: float,
452
+ chunk_size: int = 128,
453
+ offsets: Optional[torch.LongTensor] = None,
454
+ indices: Optional[torch.LongTensor] = None,
455
+ ):
456
+ B, T, H, K, V = *k.shape, v.shape[-1]
457
+ HQ = q.shape[2]
458
+ G = HQ // H
459
+ BT = chunk_size
460
+ BK = max(16, triton.next_power_of_2(K))
461
+ assert V <= 256, "V must be less than or equal to 256"
462
+ if check_shared_mem('hopper'):
463
+ BS = min(64, max(16, triton.next_power_of_2(T)))
464
+ else:
465
+ BS = min(32, max(16, triton.next_power_of_2(T)))
466
+ BV = min(256, max(16, triton.next_power_of_2(V)))
467
+ NV = triton.cdiv(V, BV)
468
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
469
+
470
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
471
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
472
+
473
+ grid = (NV, NT, B * HQ)
474
+ parallel_forgetting_attn_fwd_kernel[grid](
475
+ q=q,
476
+ k=k,
477
+ v=v,
478
+ g=g,
479
+ o=o,
480
+ lse=lse,
481
+ scale=scale,
482
+ offsets=offsets,
483
+ indices=indices,
484
+ B=B,
485
+ T=T,
486
+ H=H,
487
+ HQ=HQ,
488
+ G=G,
489
+ K=K,
490
+ V=V,
491
+ BT=BT,
492
+ BS=BS,
493
+ BK=BK,
494
+ BV=BV,
495
+ )
496
+ return o, lse
497
+
498
+
499
+ def parallel_forgetting_attn_bwd_preprocess(
500
+ o: torch.Tensor,
501
+ do: torch.Tensor
502
+ ):
503
+ V = o.shape[-1]
504
+ delta = torch.empty_like(o[..., 0], dtype=torch.float)
505
+ parallel_forgetting_attn_bwd_kernel_preprocess[(delta.numel(),)](
506
+ o=o,
507
+ do=do,
508
+ delta=delta,
509
+ B=triton.next_power_of_2(V),
510
+ V=V,
511
+ )
512
+ return delta
513
+
514
+
515
+ def parallel_forgetting_attn_bwd(
516
+ q: torch.Tensor,
517
+ k: torch.Tensor,
518
+ v: torch.Tensor,
519
+ g: torch.Tensor,
520
+ o: torch.Tensor,
521
+ lse: torch.Tensor,
522
+ do: torch.Tensor,
523
+ scale: Optional[float] = None,
524
+ chunk_size: int = 128,
525
+ offsets: Optional[torch.LongTensor] = None,
526
+ indices: Optional[torch.LongTensor] = None,
527
+ ):
528
+ B, T, H, K, V = *k.shape, v.shape[-1]
529
+ HQ = q.shape[2]
530
+ G = HQ // H
531
+ BT = chunk_size
532
+ BS = min(32, max(16, triton.next_power_of_2(T)))
533
+ BK = max(16, triton.next_power_of_2(K))
534
+ BV = max(16, triton.next_power_of_2(V))
535
+ NV = triton.cdiv(V, BV)
536
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
537
+
538
+ delta = parallel_forgetting_attn_bwd_preprocess(o, do)
539
+ dq = q.new_empty(B, T, HQ, K, dtype=q.dtype)
540
+ dk = q.new_empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float)
541
+ dv = q.new_empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float)
542
+ dg = q.new_empty(g.shape, dtype=torch.float)
543
+ # NOTE: the original `dg` can be destroyed during autotuning;
544
+ # this is a known Triton issue (https://github.com/triton-lang/triton/issues/5082), expected to be fixed in 3.3,
545
+ # so a separate buffer `dg2` is allocated for the dkv kernel and summed back into `dg` afterwards
546
+ dg2 = q.new_empty(g.shape, dtype=torch.float)
547
+ grid = (NV, NT, B * HQ)
548
+ parallel_forgetting_attn_bwd_kernel_dq[grid](
549
+ q=q,
550
+ k=k,
551
+ v=v,
552
+ g=g,
553
+ lse=lse,
554
+ delta=delta,
555
+ do=do,
556
+ dq=dq,
557
+ dg=dg,
558
+ offsets=offsets,
559
+ indices=indices,
560
+ scale=scale,
561
+ T=T,
562
+ B=B,
563
+ H=H,
564
+ HQ=HQ,
565
+ G=G,
566
+ K=K,
567
+ V=V,
568
+ BT=BT,
569
+ BS=BS,
570
+ BK=BK,
571
+ BV=BV
572
+ )
573
+ parallel_forgetting_attn_bwd_kernel_dkv[grid](
574
+ q=q,
575
+ k=k,
576
+ v=v,
577
+ g=g,
578
+ lse=lse,
579
+ delta=delta,
580
+ do=do,
581
+ dk=dk,
582
+ dv=dv,
583
+ dg=dg2,
584
+ offsets=offsets,
585
+ indices=indices,
586
+ scale=scale,
587
+ T=T,
588
+ B=B,
589
+ H=H,
590
+ HQ=HQ,
591
+ G=G,
592
+ K=K,
593
+ V=V,
594
+ BT=BT,
595
+ BS=BS,
596
+ BK=BK,
597
+ BV=BV
598
+ )
599
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
600
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
601
+ dg = dg.add_(dg2)
602
+ return dq, dk, dv, dg
603
+
604
+
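Under GQA (`HQ = H * G`), `dk` and `dv` are materialized per query head and then folded back onto the `H` key/value heads; the `reduce` calls above amount to a reshape-and-sum. An einops-free sketch of the same reduction (illustrative shapes):

    import torch

    B, T, H, G, K = 2, 8, 4, 3, 16
    dk_hq = torch.randn(B, T, H * G, K)        # per-query-head gradient, HQ = H * G
    # equivalent to reduce(dk_hq, 'b t (h g) k -> b t h k', g=G, reduction='sum')
    dk = dk_hq.view(B, T, H, G, K).sum(dim=3)
    assert dk.shape == (B, T, H, K)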
605
+ @torch.compile
606
+ class ParallelForgettingAttentionFunction(torch.autograd.Function):
607
+
608
+ @staticmethod
609
+ @input_guard
610
+ @autocast_custom_fwd
611
+ def forward(ctx, q, k, v, g, scale, offsets):
612
+ ctx.dtype = q.dtype
613
+ if check_shared_mem('hopper'):
614
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
615
+ else:
616
+ chunk_size = min(64, max(16, triton.next_power_of_2(q.shape[1])))
617
+ # 2-d indices denoting the offsets of chunks in each sequence
618
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
619
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
620
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
621
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
622
+
623
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=False)
624
+ o, lse = parallel_forgetting_attn_fwd(
625
+ q=q,
626
+ k=k,
627
+ v=v,
628
+ g=g,
629
+ scale=scale,
630
+ chunk_size=chunk_size,
631
+ offsets=offsets,
632
+ indices=indices
633
+ )
634
+ ctx.save_for_backward(q, k, v, g, o, lse)
635
+ ctx.chunk_size = chunk_size
636
+ ctx.offsets = offsets
637
+ ctx.indices = indices
638
+ ctx.scale = scale
639
+ return o.to(q.dtype)
640
+
641
+ @staticmethod
642
+ @input_guard
643
+ @autocast_custom_bwd
644
+ def backward(ctx, do):
645
+ q, k, v, g, o, lse = ctx.saved_tensors
646
+ dq, dk, dv, dg = parallel_forgetting_attn_bwd(
647
+ q=q,
648
+ k=k,
649
+ v=v,
650
+ g=g,
651
+ o=o,
652
+ lse=lse,
653
+ do=do,
654
+ scale=ctx.scale,
655
+ chunk_size=ctx.chunk_size,
656
+ offsets=ctx.offsets,
657
+ indices=ctx.indices
658
+ )
659
+ dg = chunk_global_cumsum(dg, reverse=True, head_first=False, offsets=ctx.offsets)
660
+ return dq.to(q), dk.to(k), dv.to(v), dg.to(g), None, None, None, None, None, None, None, None
661
+
662
+
663
+ def parallel_forgetting_attn(
664
+ q: torch.Tensor,
665
+ k: torch.Tensor,
666
+ v: torch.Tensor,
667
+ g: torch.Tensor,
668
+ scale: Optional[float] = None,
669
+ cu_seqlens: Optional[torch.LongTensor] = None,
670
+ head_first: bool = False
671
+ ) -> torch.Tensor:
672
+ r"""
673
+ Args:
674
+ q (torch.Tensor):
675
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
676
+ k (torch.Tensor):
677
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
678
+ GQA will be applied if HQ is divisible by H.
679
+ v (torch.Tensor):
680
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
681
+ g (torch.Tensor):
682
+ Forget gates (in **log space**) of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
683
+ scale (Optional[float]):
684
+ Scale factor for attention scores.
685
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
686
+ cu_seqlens (torch.LongTensor):
687
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
688
+ consistent with the FlashAttention API.
689
+ head_first (Optional[bool]):
690
+ Whether the inputs are in the head-first format. Default: `False`.
691
+
692
+ Returns:
693
+ o (torch.Tensor):
694
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
695
+ """
696
+ if scale is None:
697
+ scale = k.shape[-1] ** -0.5
698
+ if cu_seqlens is not None:
699
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
700
+ if g is not None:
701
+ g = g.float()
702
+ if head_first:
703
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
704
+ g = rearrange(g, 'b h t -> b t h')
705
+ o = ParallelForgettingAttentionFunction.apply(q, k, v, g, scale, cu_seqlens)
706
+ if head_first:
707
+ o = rearrange(o, 'b t h d -> b h t d')
708
+ return o
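A minimal end-to-end usage sketch of the public entry point; the import path, device, shapes and dtypes below are assumptions for illustration rather than a prescribed setup:

    import torch
    import torch.nn.functional as F
    from fla.ops.forgetting_attn import parallel_forgetting_attn

    B, T, H, HQ, K, V = 2, 1024, 4, 8, 64, 64
    q = torch.randn(B, T, HQ, K, dtype=torch.bfloat16, device='cuda', requires_grad=True)
    k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda', requires_grad=True)
    v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda', requires_grad=True)
    # forget gates live in log space, one per query head and time step
    g = F.logsigmoid(torch.randn(B, T, HQ, device='cuda')).requires_grad_()
    o = parallel_forgetting_attn(q, k, v, g)   # [B, T, HQ, V]
    o.sum().backward()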
fla/ops/gated_delta_rule/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_gated_delta_rule
2
+ from .fused_recurrent import fused_recurrent_gated_delta_rule
3
+
4
+ __all__ = [
5
+ "chunk_gated_delta_rule",
6
+ "fused_recurrent_gated_delta_rule"
7
+ ]
fla/ops/gated_delta_rule/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (356 Bytes).
 
fla/ops/gated_delta_rule/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (15 kB).