Erland committed on
Commit bec1e88 · verified · 1 parent: 7fdd671

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. fla/__init__.py +110 -0
  2. fla/__pycache__/__init__.cpython-311.pyc +0 -0
  3. fla/__pycache__/utils.cpython-311.pyc +0 -0
  4. fla/ops/__init__.py +45 -0
  5. fla/ops/__pycache__/__init__.cpython-311.pyc +0 -0
  6. fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-311.pyc +0 -0
  7. fla/ops/gla/__init__.py +11 -0
  8. fla/ops/gla/fused_recurrent.py +113 -0
  9. fla/ops/gsa/__init__.py +9 -0
  10. fla/ops/gsa/__pycache__/__init__.cpython-311.pyc +0 -0
  11. fla/ops/gsa/__pycache__/chunk.cpython-311.pyc +0 -0
  12. fla/ops/gsa/__pycache__/fused_recurrent.cpython-311.pyc +0 -0
  13. fla/ops/gsa/chunk.py +1264 -0
  14. fla/ops/gsa/fused_recurrent.py +564 -0
  15. fla/ops/gsa/naive.py +68 -0
  16. fla/ops/hgrn/__init__.py +9 -0
  17. fla/ops/hgrn/__pycache__/__init__.cpython-311.pyc +0 -0
  18. fla/ops/hgrn/__pycache__/chunk.cpython-311.pyc +0 -0
  19. fla/ops/hgrn/__pycache__/fused_recurrent.cpython-311.pyc +0 -0
  20. fla/ops/hgrn/fused_recurrent.py +308 -0
  21. fla/ops/hgrn/naive.py +63 -0
  22. fla/ops/lightning_attn/__pycache__/__init__.cpython-311.pyc +0 -0
  23. fla/ops/lightning_attn/__pycache__/chunk.cpython-311.pyc +0 -0
  24. fla/ops/lightning_attn/chunk.py +74 -0
  25. fla/ops/lightning_attn/fused_recurrent.py +75 -0
  26. fla/ops/linear_attn/__init__.py +11 -0
  27. fla/ops/linear_attn/__pycache__/__init__.cpython-311.pyc +0 -0
  28. fla/ops/linear_attn/__pycache__/chunk.cpython-311.pyc +0 -0
  29. fla/ops/linear_attn/__pycache__/fused_chunk.cpython-311.pyc +0 -0
  30. fla/ops/linear_attn/__pycache__/utils.cpython-311.pyc +0 -0
  31. fla/ops/linear_attn/fused_chunk.py +318 -0
  32. fla/ops/nsa/__init__.py +9 -0
  33. fla/ops/nsa/__pycache__/naive.cpython-311.pyc +0 -0
  34. fla/ops/nsa/__pycache__/parallel.cpython-311.pyc +0 -0
  35. fla/ops/nsa/__pycache__/utils.cpython-311.pyc +0 -0
  36. fla/ops/nsa/naive.py +94 -0
  37. fla/ops/rebased/__pycache__/__init__.cpython-311.pyc +0 -0
  38. fla/ops/rebased/parallel.py +466 -0
  39. fla/ops/retention/__init__.py +13 -0
  40. fla/ops/retention/__pycache__/chunk.cpython-311.pyc +0 -0
  41. fla/ops/retention/__pycache__/parallel.cpython-311.pyc +0 -0
  42. fla/ops/retention/chunk.py +72 -0
  43. fla/ops/retention/fused_recurrent.py +42 -0
  44. fla/ops/retention/naive.py +15 -0
  45. fla/ops/rwkv4/__init__.py +7 -0
  46. fla/ops/rwkv4/fused_recurrent.py +476 -0
  47. fla/ops/rwkv6/__init__.py +9 -0
  48. fla/ops/rwkv6/__pycache__/__init__.cpython-311.pyc +0 -0
  49. fla/ops/rwkv6/__pycache__/chunk.cpython-311.pyc +0 -0
  50. fla/ops/rwkv6/__pycache__/fused_recurrent.cpython-311.pyc +0 -0
fla/__init__.py ADDED
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-

from fla.layers import (
    ABCAttention,
    Attention,
    BasedLinearAttention,
    BitAttention,
    DeltaNet,
    GatedDeltaNet,
    GatedDeltaProduct,
    GatedLinearAttention,
    GatedSlotAttention,
    HGRN2Attention,
    HGRNAttention,
    LightNetAttention,
    LinearAttention,
    MultiScaleRetention,
    NativeSparseAttention,
    ReBasedLinearAttention,
    RWKV6Attention,
    RWKV7Attention
)
from fla.models import (
    ABCForCausalLM,
    ABCModel,
    BitNetForCausalLM,
    BitNetModel,
    DeltaNetForCausalLM,
    DeltaNetModel,
    GatedDeltaNetForCausalLM,
    GatedDeltaNetModel,
    GatedDeltaProductForCausalLM,
    GatedDeltaProductModel,
    GLAForCausalLM,
    GLAModel,
    GSAForCausalLM,
    GSAModel,
    HGRN2ForCausalLM,
    HGRN2Model,
    HGRNForCausalLM,
    LightNetForCausalLM,
    LightNetModel,
    LinearAttentionForCausalLM,
    LinearAttentionModel,
    NSAForCausalLM,
    NSAModel,
    RetNetForCausalLM,
    RetNetModel,
    RWKV6ForCausalLM,
    RWKV6Model,
    RWKV7ForCausalLM,
    RWKV7Model,
    TransformerForCausalLM,
    TransformerModel
)

__all__ = [
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
    'BitAttention',
    'DeltaNet',
    'GatedDeltaNet',
    'GatedDeltaProduct',
    'GatedLinearAttention',
    'GatedSlotAttention',
    'HGRNAttention',
    'HGRN2Attention',
    'LightNetAttention',
    'LinearAttention',
    'MultiScaleRetention',
    'NativeSparseAttention',
    'ReBasedLinearAttention',
    'RWKV6Attention',
    'RWKV7Attention',
    'ABCForCausalLM',
    'ABCModel',
    'BitNetForCausalLM',
    'BitNetModel',
    'DeltaNetForCausalLM',
    'DeltaNetModel',
    'GatedDeltaNetForCausalLM',
    'GatedDeltaNetModel',
    'GatedDeltaProductForCausalLM',
    'GatedDeltaProductModel',
    'GLAForCausalLM',
    'GLAModel',
    'GSAForCausalLM',
    'GSAModel',
    'HGRNForCausalLM',
    'HGRNModel',
    'HGRN2ForCausalLM',
    'HGRN2Model',
    'LightNetForCausalLM',
    'LightNetModel',
    'LinearAttentionForCausalLM',
    'LinearAttentionModel',
    'NSAForCausalLM',
    'NSAModel',
    'RetNetForCausalLM',
    'RetNetModel',
    'RWKV6ForCausalLM',
    'RWKV6Model',
    'RWKV7ForCausalLM',
    'RWKV7Model',
    'TransformerForCausalLM',
    'TransformerModel',
]

__version__ = '0.1.2'
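For orientation, a minimal usage sketch of the top-level re-exports added above (an assumption for illustration: the package is installed and importable as `fla`; all names are taken from the `__init__.py` in this commit):

# Minimal sketch: the package root re-exports layer and model classes.
import fla

print(fla.__version__)  # '0.1.2' per the __init__.py above

from fla import GatedLinearAttention, RWKV7Attention
from fla.layers import GatedLinearAttention as GLAFromLayers

# Both import paths refer to the same class object.
assert GatedLinearAttention is GLAFromLayers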
fla/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.33 kB).
 
fla/__pycache__/utils.cpython-311.pyc ADDED
Binary file (13.8 kB).
 
fla/ops/__init__.py ADDED
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-

from .abc import chunk_abc
from .attn import parallel_attn
from .based import fused_chunk_based, parallel_based
from .delta_rule import chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule
from .forgetting_attn import parallel_forgetting_attn
from .gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
from .generalized_delta_rule import (
    chunk_dplr_delta_rule,
    chunk_iplr_delta_rule,
    fused_recurrent_dplr_delta_rule,
    fused_recurrent_iplr_delta_rule
)
from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
from .gsa import chunk_gsa, fused_recurrent_gsa
from .hgrn import fused_recurrent_hgrn
from .lightning_attn import chunk_lightning_attn, fused_recurrent_lightning_attn
from .linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn
from .nsa import parallel_nsa
from .retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention
from .rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
from .rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7
from .simple_gla import chunk_simple_gla, fused_recurrent_simple_gla, parallel_simple_gla

__all__ = [
    'chunk_abc',
    'parallel_attn',
    'fused_chunk_based', 'parallel_based',
    'chunk_delta_rule', 'fused_chunk_delta_rule', 'fused_recurrent_delta_rule',
    'parallel_forgetting_attn',
    'chunk_gated_delta_rule', 'fused_recurrent_gated_delta_rule',
    'chunk_dplr_delta_rule', 'chunk_iplr_delta_rule',
    'fused_recurrent_dplr_delta_rule', 'fused_recurrent_iplr_delta_rule',
    'chunk_gla', 'fused_chunk_gla', 'fused_recurrent_gla',
    'chunk_gsa', 'fused_recurrent_gsa',
    'fused_recurrent_hgrn',
    'chunk_lightning_attn', 'fused_recurrent_lightning_attn',
    'chunk_linear_attn', 'fused_chunk_linear_attn', 'fused_recurrent_linear_attn',
    'parallel_nsa',
    'chunk_retention', 'fused_chunk_retention', 'fused_recurrent_retention', 'parallel_retention',
    'chunk_rwkv6', 'fused_recurrent_rwkv6',
    'chunk_rwkv7', 'fused_recurrent_rwkv7',
    'chunk_simple_gla', 'fused_recurrent_simple_gla', 'parallel_simple_gla',
]
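As a quick illustration, a sketch of calling one of the kernels re-exported above; the signature and shape conventions follow the `fused_recurrent_gla` docstring included later in this commit (assumptions: a CUDA device is available and the sizes are illustrative only):

# Sketch: invoking a kernel re-exported by fla.ops (illustrative sizes, CUDA assumed).
import torch
import torch.nn.functional as F

from fla.ops import fused_recurrent_gla

B, T, H, K, V = 2, 128, 4, 64, 64
q = torch.randn(B, T, H, K, device='cuda')
k = torch.randn(B, T, H, K, device='cuda')
v = torch.randn(B, T, H, V, device='cuda')
# log-sigmoid keeps the forget gates negative, as in the docstring examples
gk = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))

o, ht = fused_recurrent_gla(q, k, v, gk, output_final_state=True, head_first=False)
print(o.shape, ht.shape)  # [B, T, H, V] and [B, H, K, V]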
fla/ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.29 kB).
 
fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (448 Bytes).
 
fla/ops/gla/__init__.py ADDED
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_gla
from .fused_chunk import fused_chunk_gla
from .fused_recurrent import fused_recurrent_gla

__all__ = [
    'chunk_gla',
    'fused_chunk_gla',
    'fused_recurrent_gla'
]
fla/ops/gla/fused_recurrent.py ADDED
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch

from fla.ops.common.fused_recurrent import fused_recurrent


def fused_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: Optional[torch.Tensor] = None,
    gv: Optional[torch.Tensor] = None,
    scale: Optional[int] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        gk (torch.Tensor):
            Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys.
        gv (torch.Tensor):
            Forget gates of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` applied to values.
        scale (Optional[int]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (Optional[bool]):
            If `True`, process the state passing in reverse order. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gla import fused_recurrent_gla
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gla(q, k, v, g,
                                        initial_state=h0,
                                        output_final_state=True,
                                        head_first=False)
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gla(q, k, v, g,
                                                initial_state=h0,
                                                output_final_state=True,
                                                cu_seqlens=cu_seqlens,
                                                head_first=False)
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                             f"Please flatten variable-length inputs before processing.")
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
                             f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = fused_recurrent(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=gk,
        gv=gv,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=head_first
    )
    return o, final_state
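The docstring above only touches on the variable-length path; the sketch below spells out how `cu_seqlens` is typically built from per-sequence lengths before the call (the lengths and head/feature sizes here are hypothetical, and a CUDA device is assumed):

# Sketch: variable-length inputs for fused_recurrent_gla (hypothetical sizes).
import torch
import torch.nn.functional as F

from fla.ops.gla import fused_recurrent_gla

lengths = [5, 7, 4]          # three sequences of different lengths
T = sum(lengths)             # all tokens packed into a single batch of size 1
H, K, V = 4, 64, 64

# cu_seqlens holds the cumulative boundaries [0, 5, 12, 16], shape [N + 1]
cu = [0]
for n in lengths:
    cu.append(cu[-1] + n)
cu_seqlens = torch.tensor(cu, dtype=torch.long, device='cuda')

q = torch.randn(1, T, H, K, device='cuda')
k = torch.randn(1, T, H, K, device='cuda')
v = torch.randn(1, T, H, V, device='cuda')
gk = F.logsigmoid(torch.randn(1, T, H, K, device='cuda'))

# head_first must be False for variable-length inputs, per the checks above
o, ht = fused_recurrent_gla(q, k, v, gk,
                            cu_seqlens=cu_seqlens,
                            output_final_state=True,
                            head_first=False)
print(ht.shape)  # [len(lengths), H, K, V]: one final state per packed sequence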
fla/ops/gsa/__init__.py ADDED
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

from .chunk import chunk_gsa
from .fused_recurrent import fused_recurrent_gsa

__all__ = [
    'chunk_gsa',
    'fused_recurrent_gsa'
]
fla/ops/gsa/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (319 Bytes).
 
fla/ops/gsa/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (70.3 kB).
 
fla/ops/gsa/__pycache__/fused_recurrent.cpython-311.pyc ADDED
Binary file (26.6 kB).
 
fla/ops/gsa/chunk.py ADDED
@@ -0,0 +1,1264 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import reduce
10
+
11
+ from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h
12
+ from fla.ops.gla.chunk import chunk_gla_bwd, chunk_gla_fwd
13
+ from fla.ops.utils import chunk_local_cumsum, softmax_bwd, softmax_fwd
14
+ from fla.ops.utils.op import exp, safe_exp
15
+ from fla.utils import input_guard
16
+
17
+
18
+ @triton.heuristics({
19
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
20
+ })
21
+ @triton.autotune(
22
+ configs=[
23
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
24
+ for BK in [32, 64]
25
+ for BV in [32, 64]
26
+ for num_warps in [2, 4, 8]
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['BT']
30
+ )
31
+ @triton.jit(do_not_specialize=['T'])
32
+ def chunk_gsa_fwd_k_kernel_inter(
33
+ q,
34
+ k,
35
+ h,
36
+ g,
37
+ o,
38
+ A,
39
+ offsets,
40
+ indices,
41
+ scale,
42
+ T,
43
+ HQ: tl.constexpr,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BK: tl.constexpr,
49
+ BV: tl.constexpr,
50
+ NG: tl.constexpr,
51
+ USE_OFFSETS: tl.constexpr,
52
+ HEAD_FIRST: tl.constexpr
53
+ ):
54
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
55
+ i_bg = i_bh // NG
56
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
57
+ i_h = i_hq // NG
58
+ if USE_OFFSETS:
59
+ i_tg = i_t
60
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
61
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
62
+ T = eos - bos
63
+ NT = tl.cdiv(T, BT)
64
+ else:
65
+ NT = tl.cdiv(T, BT)
66
+ i_tg = i_b * NT + i_t
67
+ bos, eos = i_b * T, i_b * T + T
68
+
69
+ o_i = tl.arange(0, BT)
70
+ m_s = o_i[:, None] >= o_i[None, :]
71
+
72
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
73
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
74
+ for i_k in range(tl.cdiv(K, BK)):
75
+ if HEAD_FIRST:
76
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
77
+ p_k = tl.make_block_ptr(k + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
78
+ p_h = tl.make_block_ptr(h + (i_bg * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
79
+ else:
80
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
81
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
82
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
83
+
84
+ # [BT, BK]
85
+ b_q = tl.load(p_q, boundary_check=(0, 1))
86
+ b_q = (b_q * scale).to(b_q.dtype)
87
+ # [BK, BT]
88
+ b_k = tl.load(p_k, boundary_check=(0, 1))
89
+ # [BK, BV]
90
+ b_h = tl.load(p_h, boundary_check=(0, 1))
91
+ # [BT, BV]
92
+ b_o += tl.dot(b_q, b_h)
93
+ # [BT, BT]
94
+ b_A += tl.dot(b_q, b_k)
95
+ if HEAD_FIRST:
96
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
97
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
98
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
99
+ else:
100
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
101
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
102
+ p_A = tl.make_block_ptr(A + (bos * HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
103
+ # [BT, BV]
104
+ b_g = tl.load(p_g, boundary_check=(0, 1))
105
+ b_o = b_o * exp(b_g)
106
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
107
+
108
+ # [BT, BT]
109
+ b_A = tl.where(m_s, b_A, 0.)
110
+ if i_v == 0:
111
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
112
+
113
+
114
+ @triton.heuristics({
115
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
116
+ })
117
+ @triton.jit(do_not_specialize=['T'])
118
+ def chunk_gsa_fwd_k_kernel_intra(
119
+ v,
120
+ g,
121
+ o,
122
+ A,
123
+ offsets,
124
+ indices,
125
+ T,
126
+ HQ: tl.constexpr,
127
+ H: tl.constexpr,
128
+ V: tl.constexpr,
129
+ BT: tl.constexpr,
130
+ BC: tl.constexpr,
131
+ BV: tl.constexpr,
132
+ NC: tl.constexpr,
133
+ NG: tl.constexpr,
134
+ USE_OFFSETS: tl.constexpr,
135
+ HEAD_FIRST: tl.constexpr
136
+ ):
137
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
138
+ i_bg = i_bh // NG
139
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
140
+ i_h = i_hq // NG
141
+ i_t, i_i = i_c // NC, i_c % NC
142
+ if USE_OFFSETS:
143
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
144
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
145
+ T = eos - bos
146
+ else:
147
+ bos, eos = i_b * T, i_b * T + T
148
+
149
+ o_v = i_v * BV + tl.arange(0, BV)
150
+ m_v = o_v < V
151
+
152
+ if i_t * BT + i_i * BC > T:
153
+ return
154
+
155
+ if HEAD_FIRST:
156
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
157
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + min(i_t * BT + i_i * BC, T) * V + o_v, BV), BV)
158
+ else:
159
+ p_g = tl.make_block_ptr(g + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
160
+ p_gn = g + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v
161
+ # [BV,]
162
+ b_gn = tl.load(p_gn, mask=m_v, other=0)
163
+ # [BC, BV]
164
+ b_o = tl.zeros([BC, BV], dtype=tl.float32)
165
+ for i_j in range(0, i_i):
166
+ if HEAD_FIRST:
167
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
168
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
169
+ p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
170
+ else:
171
+ p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0))
172
+ p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
173
+ p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
174
+ # [BC, BV]
175
+ b_v = tl.load(p_v, boundary_check=(0, 1))
176
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
177
+ b_vg = (b_v * exp(b_gn[None, :] - b_gv)).to(b_v.dtype)
178
+ # [BC, BC]
179
+ b_A = tl.load(p_A, boundary_check=(0, 1))
180
+ b_o += tl.dot(b_A, b_vg)
181
+ # [BC, BV]
182
+ b_g = tl.load(p_g, boundary_check=(0, 1))
183
+ b_o *= exp(b_g - b_gn[None, :])
184
+
185
+ o_i = tl.arange(0, BC)
186
+ if HEAD_FIRST:
187
+ o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC
188
+ else:
189
+ o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * HQ*BT + i_hq * BT + i_i * BC
190
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
191
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
192
+ if HEAD_FIRST:
193
+ p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV)
194
+ p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC + j) * V + o_v, BV), BV)
195
+ else:
196
+ p_v = v + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
197
+ p_gv = g + (bos + i_t * BT + i_i * BC + j) * H*V + i_h * V + o_v
198
+ # [BC,]
199
+ b_A = tl.load(A + o_A + j, mask=m_A, other=0)
200
+ # [BV,]
201
+ b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
202
+ b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
203
+ # [BC, BV]
204
+ b_vg = b_v[None, :] * exp(b_g - b_gv[None, :])
205
+ # avoid 0 * inf = inf
206
+ b_o += tl.where(o_i[:, None] >= j, b_A[:, None] * b_vg, 0.)
207
+ if HEAD_FIRST:
208
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
209
+ else:
210
+ p_o = tl.make_block_ptr(o + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
211
+ b_o += tl.load(p_o, boundary_check=(0, 1))
212
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
213
+
214
+
215
+ @triton.heuristics({
216
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
217
+ })
218
+ @triton.autotune(
219
+ configs=[
220
+ triton.Config({}, num_warps=num_warps)
221
+ for num_warps in [2, 4, 8]
222
+ ],
223
+ key=["BT"]
224
+ )
225
+ @triton.jit(do_not_specialize=['T'])
226
+ def chunk_gsa_bwd_k_kernel_dA(
227
+ v,
228
+ g,
229
+ do,
230
+ dA,
231
+ indices,
232
+ offsets,
233
+ scale,
234
+ T,
235
+ B: tl.constexpr,
236
+ HQ: tl.constexpr,
237
+ H: tl.constexpr,
238
+ V: tl.constexpr,
239
+ BT: tl.constexpr,
240
+ BC: tl.constexpr,
241
+ BV: tl.constexpr,
242
+ NC: tl.constexpr,
243
+ NG: tl.constexpr,
244
+ USE_OFFSETS: tl.constexpr,
245
+ HEAD_FIRST: tl.constexpr
246
+ ):
247
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
248
+ i_bg = i_bh // NG
249
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
250
+ i_h = i_hq // NG
251
+ i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC
252
+ if USE_OFFSETS:
253
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
254
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
255
+ all = T
256
+ T = eos - bos
257
+ else:
258
+ bos, eos = i_b * T, i_b * T + T
259
+ all = B * T
260
+
261
+ o_v = i_v * BV + tl.arange(0, BV)
262
+ m_v = o_v < V
263
+
264
+ if i_t * BT + i_i * BC > T:
265
+ return
266
+
267
+ if HEAD_FIRST:
268
+ p_dA = tl.make_block_ptr(dA+(i_v*B*H+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
269
+ else:
270
+ p_dA = tl.make_block_ptr(dA+((i_v*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t*BT+i_i*BC, i_j*BC), (BC, BC), (1, 0))
271
+
272
+ # [BC, BC]
273
+ b_dA = tl.zeros([BC, BC], dtype=tl.float32)
274
+ if i_i > i_j:
275
+ if HEAD_FIRST:
276
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
277
+ p_gv = tl.make_block_ptr(g + i_bg * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1))
278
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
279
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
280
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
281
+ else:
282
+ p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
283
+ p_gv = tl.make_block_ptr(g + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
284
+ p_gn = g + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v
285
+ p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
286
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
287
+ # [BV,]
288
+ b_gn = tl.load(p_gn, mask=m_v, other=0.)
289
+ # [BC, BV]
290
+ b_g = tl.load(p_g, boundary_check=(0, 1))
291
+ b_do = tl.load(p_do, boundary_check=(0, 1))
292
+ b_do = (b_do * exp(b_g - b_gn[None, :]) * scale).to(b_do.dtype)
293
+ # [BV, BC]
294
+ b_v = tl.load(p_v, boundary_check=(0, 1))
295
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
296
+ b_vg = (b_v * exp(b_gn[:, None] - b_gv)).to(b_v.dtype)
297
+ # [BC, BC]
298
+ b_dA = tl.dot(b_do, b_vg)
299
+ elif i_i == i_j:
300
+ if HEAD_FIRST:
301
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
302
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
303
+ p_v = tl.max_contiguous(tl.multiple_of(v + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV)
304
+ p_gv = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_j * BC) * V + o_v, BV), BV)
305
+ else:
306
+ p_g = tl.make_block_ptr(g + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
307
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
308
+ p_v = v + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
309
+ p_gv = g + (bos + i_t*BT + i_j*BC) * H*V + i_h * V + o_v
310
+ # [BC, BV]
311
+ b_g = tl.load(p_g, boundary_check=(0, 1))
312
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * scale
313
+ m_v = o_v < V
314
+
315
+ o_i = tl.arange(0, BC)
316
+ # [BC, BC]
317
+ m_dA = o_i[:, None] >= o_i[None, :]
318
+ for j in range(0, min(BC, T - i_t * BT - i_j * BC)):
319
+ # [BV,]
320
+ b_v = tl.load(p_v, mask=m_v, other=0).to(tl.float32)
321
+ b_gv = tl.load(p_gv, mask=m_v, other=0).to(tl.float32)
322
+ # [BC,]
323
+ b_dAj = tl.sum(b_do * b_v[None, :] * exp(b_g - b_gv[None, :]), 1)
324
+ b_dA = tl.where((o_i == j)[None, :], b_dAj[:, None], b_dA)
325
+
326
+ p_v += (1 if HEAD_FIRST else H) * V
327
+ p_gv += (1 if HEAD_FIRST else H) * V
328
+ b_dA = tl.where(m_dA, b_dA, 0.)
329
+ tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1))
330
+
331
+
332
+ @triton.heuristics({
333
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
334
+ })
335
+ @triton.autotune(
336
+ configs=[
337
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
338
+ for num_warps in [2, 4]
339
+ for num_stages in [2, 3, 4]
340
+ ],
341
+ key=['BT']
342
+ )
343
+ @triton.jit(do_not_specialize=['T'])
344
+ def chunk_gsa_bwd_k_kernel_dqkvg(
345
+ q,
346
+ k,
347
+ v,
348
+ h,
349
+ g,
350
+ A,
351
+ do,
352
+ dh,
353
+ dq,
354
+ dk,
355
+ dv,
356
+ dg,
357
+ dgv,
358
+ dA,
359
+ offsets,
360
+ indices,
361
+ scale,
362
+ T,
363
+ B: tl.constexpr,
364
+ HQ: tl.constexpr,
365
+ H: tl.constexpr,
366
+ K: tl.constexpr,
367
+ V: tl.constexpr,
368
+ BT: tl.constexpr,
369
+ BK: tl.constexpr,
370
+ BV: tl.constexpr,
371
+ NG: tl.constexpr,
372
+ USE_OFFSETS: tl.constexpr,
373
+ HEAD_FIRST: tl.constexpr
374
+ ):
375
+ i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
376
+ i_bg = i_bh // NG
377
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
378
+ i_h = i_hq // NG
379
+ if USE_OFFSETS:
380
+ i_tg = i_t
381
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
382
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
383
+ all = T
384
+ T = eos - bos
385
+ NT = tl.cdiv(T, BT)
386
+ else:
387
+ NT = tl.cdiv(T, BT)
388
+ i_tg = i_b * NT + i_t
389
+ bos, eos = i_b * T, i_b * T + T
390
+ all = B * T
391
+
392
+ o_i = tl.arange(0, BT)
393
+ o_t = min(i_t * BT + BT, T)
394
+ m_s = o_i[:, None] >= o_i[None, :]
395
+
396
+ if HEAD_FIRST:
397
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
398
+ p_k = tl.make_block_ptr(k + i_bg * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
399
+ p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh) * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
400
+ else:
401
+ p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
402
+ p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
403
+ p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
404
+
405
+ # [BT, BK]
406
+ b_q = tl.load(p_q, boundary_check=(0, 1))
407
+ b_k = tl.load(p_k, boundary_check=(0, 1))
408
+ # [BT, BT]
409
+ b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
410
+ b_A = tl.where(m_s, b_A, 0.)
411
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
412
+
413
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
414
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
415
+ for i_v in range(tl.cdiv(V, BV)):
416
+ o_v = i_v * BV + tl.arange(0, BV)
417
+ if HEAD_FIRST:
418
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
419
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
420
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (o_t - 1) * V + o_v, BV), BV)
421
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
422
+ p_dv = tl.make_block_ptr(dv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
423
+ p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
424
+ p_dgv = tl.make_block_ptr(dgv + (i_k*B*H+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
425
+ p_h = tl.make_block_ptr(h + i_bg * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
426
+ p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
427
+ else:
428
+ p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
429
+ p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
430
+ p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
431
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
432
+ p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
433
+ p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
434
+ p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
435
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
436
+ p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
437
+ m_v = o_v < V
438
+
439
+ # [BV,]
440
+ b_gn = tl.load(p_gn, mask=m_v, other=0)
441
+ # [BT, BV]
442
+ b_v = tl.load(p_v, boundary_check=(0, 1))
443
+ b_g = tl.load(p_g, boundary_check=(0, 1))
444
+ b_gv = exp(b_gn[None, :] - b_g)
445
+ # [BV, BK]
446
+ b_h = tl.load(p_h, boundary_check=(0, 1))
447
+ # [BT, BV]
448
+ b_do = tl.load(p_do, boundary_check=(0, 1))
449
+ b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
450
+ # [BK, BV]
451
+ b_dh = tl.load(p_dh, boundary_check=(0, 1))
452
+ # [BV]
453
+ b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)
454
+
455
+ b_dh = b_dh.to(b_k.dtype)
456
+ # [BT, BK]
457
+ b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
458
+ b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
459
+ # [BT, BV]
460
+ b_dv = tl.dot(b_k, b_dh) * b_gv
461
+ # [BV]
462
+ b_dg += tl.sum(b_dv * b_v, 0)
463
+
464
+ if i_k == 0:
465
+ b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
466
+ else:
467
+ b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]
468
+
469
+ tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
470
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
471
+ if HEAD_FIRST:
472
+ p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
473
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
474
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
475
+ else:
476
+ p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
477
+ p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
478
+ p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
479
+ # [BT, BT]
480
+ b_dA = tl.load(p_dA, boundary_check=(0, 1))
481
+ # [BT, BK]
482
+ b_dq += tl.dot(b_dA, b_k)
483
+ b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)
484
+
485
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
486
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
487
+
488
+
489
+ @triton.heuristics({
490
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
491
+ })
492
+ @triton.jit(do_not_specialize=['T'])
493
+ def chunk_gsa_bwd_k_kernel_intra_dvg(
494
+ v,
495
+ g,
496
+ o,
497
+ A,
498
+ do,
499
+ dv,
500
+ dg,
501
+ offsets,
502
+ indices,
503
+ T,
504
+ HQ: tl.constexpr,
505
+ H: tl.constexpr,
506
+ V: tl.constexpr,
507
+ BT: tl.constexpr,
508
+ BC: tl.constexpr,
509
+ BV: tl.constexpr,
510
+ NC: tl.constexpr,
511
+ NG: tl.constexpr,
512
+ USE_OFFSETS: tl.constexpr,
513
+ HEAD_FIRST: tl.constexpr
514
+ ):
515
+ i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
516
+ i_bg = i_bh // NG
517
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
518
+ i_h = i_hq // NG
519
+ i_t, i_i = i_c // NC, i_c % NC
520
+ if USE_OFFSETS:
521
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
522
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
523
+ T = eos - bos
524
+ else:
525
+ bos, eos = i_b * T, i_b * T + T
526
+
527
+ o_v = i_v * BV + tl.arange(0, BV)
528
+ m_v = o_v < V
529
+
530
+ if i_t * BT + i_i * BC > T:
531
+ return
532
+
533
+ if HEAD_FIRST:
534
+ p_gv = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
535
+ p_gn = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (min(i_t * BT + i_i * BC + BC, T) - 1) * V + o_v, BV), BV)
536
+ else:
537
+ p_gv = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
538
+ p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T)-1)*H*V + i_h*V + o_v
539
+ # [BV,]
540
+ b_gn = tl.load(p_gn, mask=m_v, other=0)
541
+ # [BC, BV]
542
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
543
+ b_dv = tl.zeros([BC, BV], dtype=tl.float32)
544
+ for i_j in range(i_i + 1, NC):
545
+ if HEAD_FIRST:
546
+ p_g = tl.make_block_ptr(g + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
547
+ p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1))
548
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
549
+ else:
550
+ p_g = tl.make_block_ptr(g + (bos*H+i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0))
551
+ p_A = tl.make_block_ptr(A + (bos*HQ+i_hq) * BT, (BT, T), (1, HQ*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
552
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq) * V, (T, V), (HQ*V, 1), (i_t*BT + i_j*BC, i_v*BV), (BC, BV), (1, 0))
553
+ # [BC, BV]
554
+ b_g = tl.load(p_g, boundary_check=(0, 1))
555
+ b_do = tl.load(p_do, boundary_check=(0, 1)) * safe_exp(b_g - b_gn[None, :])
556
+ # [BC, BC]
557
+ b_A = tl.load(p_A, boundary_check=(0, 1))
558
+ # [BC, BV]
559
+ b_dv += tl.dot(b_A, b_do.to(b_A.dtype))
560
+ b_dv *= exp(b_gn[None, :] - b_gv)
561
+
562
+ o_i = tl.arange(0, BC)
563
+ o_c = i_i * BC + tl.arange(0, BC)
564
+
565
+ if HEAD_FIRST:
566
+ p_g = tl.max_contiguous(tl.multiple_of(g + i_bg * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
567
+ p_A = tl.max_contiguous(tl.multiple_of(A + i_bh * T*BT + (i_t * BT + i_i * BC) * BT + o_c, BC), BC)
568
+ p_do = tl.max_contiguous(tl.multiple_of(do + i_bh * T*V + (i_t * BT + i_i * BC) * V + o_v, BV), BV)
569
+ else:
570
+ p_g = g + (bos + i_t * BT + i_i * BC) * H*V + i_h * V + o_v
571
+ p_A = A + (bos + i_t*BT + i_i*BC) * HQ*BT + i_hq * BT + o_c
572
+ p_do = do + (bos + i_t*BT + i_i*BC) * HQ*V + i_hq * V + o_v
573
+
574
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
575
+ # [BC,]
576
+ b_A = tl.load(p_A)
577
+ # [BV,]
578
+ b_g = tl.load(p_g, mask=m_v, other=0)
579
+ b_do = tl.load(p_do, mask=m_v, other=0)
580
+ # [BC, BV]
581
+ m_i = o_i[:, None] <= j
582
+ b_dv += tl.where(m_i, exp(b_g[None, :] - b_gv) * b_A[:, None] * b_do[None, :], 0.)
583
+
584
+ p_g += (1 if HEAD_FIRST else H) * V
585
+ p_A += (1 if HEAD_FIRST else HQ) * BT
586
+ p_do += (1 if HEAD_FIRST else HQ) * V
587
+ if HEAD_FIRST:
588
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
589
+ p_v = tl.make_block_ptr(v + i_bg * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
590
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
591
+ p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
592
+ p_dg = tl.make_block_ptr(dg + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
593
+ else:
594
+ p_o = tl.make_block_ptr(o + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
595
+ p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
596
+ p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
597
+ p_dv = tl.make_block_ptr(dv + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
598
+ p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t*BT + i_i*BC, i_v*BV), (BC, BV), (1, 0))
599
+
600
+ b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
601
+ b_v = tl.load(p_v, boundary_check=(0, 1)).to(tl.float32)
602
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(tl.float32)
603
+ b_dv = b_dv + tl.load(p_dv, boundary_check=(0, 1)).to(tl.float32)
604
+ b_dg = b_o * b_do - b_v * b_dv
605
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
606
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
607
+
608
+
609
+ def chunk_gsa_fwd_v(
610
+ q: torch.Tensor,
611
+ k: torch.Tensor,
612
+ v: torch.Tensor,
613
+ g: torch.Tensor,
614
+ scale: float = 1.,
615
+ initial_state: Optional[torch.Tensor] = None,
616
+ output_final_state: bool = False,
617
+ offsets: Optional[torch.LongTensor] = None,
618
+ indices: Optional[torch.LongTensor] = None,
619
+ head_first: bool = True,
620
+ chunk_size: int = 64
621
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
622
+ _, A, h, ht, o = chunk_gla_fwd(
623
+ q=q,
624
+ k=k,
625
+ v=v,
626
+ g=None,
627
+ g_cumsum=g,
628
+ scale=scale,
629
+ initial_state=initial_state,
630
+ output_final_state=output_final_state,
631
+ offsets=offsets,
632
+ indices=indices,
633
+ head_first=head_first,
634
+ chunk_size=chunk_size
635
+ )
636
+ return A, h, ht, o
637
+
638
+
639
+ def chunk_gsa_fwd_k(
640
+ q: torch.Tensor,
641
+ k: torch.Tensor,
642
+ v: torch.Tensor,
643
+ g: torch.Tensor,
644
+ h0: Optional[torch.Tensor] = None,
645
+ output_final_state: bool = False,
646
+ scale: float = 1.,
647
+ offsets: Optional[torch.LongTensor] = None,
648
+ indices: Optional[torch.LongTensor] = None,
649
+ head_first: bool = True,
650
+ chunk_size: int = 64
651
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
652
+ if head_first:
653
+ B, H, T, K, V = *k.shape, v.shape[-1]
654
+ else:
655
+ B, T, H, K, V = *k.shape, v.shape[-1]
656
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
657
+ BC = min(16, BT)
658
+ BV = min(64, triton.next_power_of_2(V))
659
+ HQ = q.shape[1] if head_first else q.shape[2]
660
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
661
+ NC = triton.cdiv(BT, BC)
662
+ NG = HQ // H
663
+
664
+ h, ht = chunk_fwd_h(
665
+ k=k,
666
+ v=v,
667
+ g=None,
668
+ gk=None,
669
+ gv=g,
670
+ h0=h0,
671
+ output_final_state=output_final_state,
672
+ offsets=offsets,
673
+ head_first=head_first,
674
+ chunk_size=BT,
675
+ states_in_fp32=False
676
+ )
677
+ o = v.new_empty(B, *((HQ, T) if head_first else (T, HQ)), V)
678
+ A = q.new_empty(B, *((HQ, T) if head_first else (T, HQ)), BT)
679
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ)
680
+ chunk_gsa_fwd_k_kernel_inter[grid](
681
+ q,
682
+ k,
683
+ h,
684
+ g,
685
+ o,
686
+ A,
687
+ offsets=offsets,
688
+ indices=indices,
689
+ scale=scale,
690
+ T=T,
691
+ HQ=HQ,
692
+ H=H,
693
+ K=K,
694
+ V=V,
695
+ BT=BT,
696
+ NG=NG,
697
+ HEAD_FIRST=head_first
698
+ )
699
+
700
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
701
+ chunk_gsa_fwd_k_kernel_intra[grid](
702
+ v,
703
+ g,
704
+ o,
705
+ A,
706
+ offsets=offsets,
707
+ indices=indices,
708
+ T=T,
709
+ HQ=HQ,
710
+ H=H,
711
+ V=V,
712
+ BT=BT,
713
+ BC=BC,
714
+ BV=BV,
715
+ NC=NC,
716
+ NG=NG,
717
+ HEAD_FIRST=head_first,
718
+ num_warps=4,
719
+ num_stages=2
720
+ )
721
+ return A, h, ht, o
722
+
723
+
724
+ def chunk_gsa_bwd_v(
725
+ q: torch.Tensor,
726
+ k: torch.Tensor,
727
+ v: torch.Tensor,
728
+ g: torch.Tensor,
729
+ h0: torch.Tensor,
730
+ h: torch.Tensor,
731
+ A: torch.Tensor,
732
+ do: torch.Tensor,
733
+ dht: torch.Tensor,
734
+ dg: torch.Tensor,
735
+ scale: float = 1.,
736
+ offsets: Optional[torch.LongTensor] = None,
737
+ indices: Optional[torch.LongTensor] = None,
738
+ head_first: bool = True,
739
+ chunk_size: int = 64
740
+ ):
741
+ dq, dk, dv, dg, dh0 = chunk_gla_bwd(
742
+ q=q,
743
+ k=k,
744
+ v=v,
745
+ g=None,
746
+ g_cumsum=g,
747
+ scale=scale,
748
+ initial_state=h0,
749
+ h=h,
750
+ A=A,
751
+ do=do,
752
+ dht=dht,
753
+ offsets=offsets,
754
+ indices=indices,
755
+ head_first=head_first,
756
+ chunk_size=chunk_size
757
+ )
758
+ return dq, dk, dv, dg, dh0
759
+
760
+
761
+ def chunk_gsa_bwd_k(
762
+ q: torch.Tensor,
763
+ k: torch.Tensor,
764
+ v: torch.Tensor,
765
+ g: torch.Tensor,
766
+ h: torch.Tensor,
767
+ h0: torch.Tensor,
768
+ o: torch.Tensor,
769
+ do: torch.Tensor,
770
+ dht: torch.Tensor,
771
+ dg: torch.Tensor,
772
+ scale: float = 1.,
773
+ offsets: Optional[torch.LongTensor] = None,
774
+ indices: Optional[torch.LongTensor] = None,
775
+ head_first: bool = True,
776
+ chunk_size: int = 64
777
+ ):
778
+ if head_first:
779
+ B, H, T, K, V = *k.shape, v.shape[-1]
780
+ else:
781
+ B, T, H, K, V = *k.shape, v.shape[-1]
782
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
783
+ BC = min(16, BT)
784
+ BK = min(64, triton.next_power_of_2(K))
785
+ BV = min(64, triton.next_power_of_2(V))
786
+ HQ = q.shape[1] if head_first else q.shape[2]
787
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
788
+ NC = triton.cdiv(BT, BC)
789
+ NK = triton.cdiv(K, BK)
790
+ NV = triton.cdiv(V, BV)
791
+ NG = HQ // H
792
+
793
+ if h is None:
794
+ h, _ = chunk_fwd_h(
795
+ k=k,
796
+ v=v,
797
+ g=None,
798
+ gk=None,
799
+ gv=g,
800
+ h0=h0,
801
+ output_final_state=False,
802
+ offsets=offsets,
803
+ head_first=head_first,
804
+ chunk_size=BT,
805
+ states_in_fp32=False
806
+ )
807
+ dh, dh0 = chunk_bwd_dh(
808
+ q=q,
809
+ k=k,
810
+ v=v,
811
+ g=None,
812
+ gk=None,
813
+ gv=g,
814
+ do=do,
815
+ h0=h0,
816
+ dht=dht,
817
+ scale=scale,
818
+ offsets=offsets,
819
+ head_first=head_first,
820
+ chunk_size=BT,
821
+ states_in_fp32=True
822
+ )
823
+ dA = q.new_empty(NV, B, *((HQ, T) if head_first else (T, HQ)), BT)
824
+ grid = (NV, NT * NC * NC, B * HQ)
825
+ chunk_gsa_bwd_k_kernel_dA[grid](
826
+ v,
827
+ g,
828
+ do,
829
+ dA,
830
+ offsets=offsets,
831
+ indices=indices,
832
+ scale=scale,
833
+ T=T,
834
+ B=B,
835
+ HQ=HQ,
836
+ H=H,
837
+ V=V,
838
+ BT=BT,
839
+ BC=BC,
840
+ BV=BV,
841
+ NC=NC,
842
+ NG=NG,
843
+ HEAD_FIRST=head_first
844
+ )
845
+ dA = dA.sum(0, dtype=dA.dtype)
846
+
847
+ A = do.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), BT)
848
+ dq = torch.empty_like(q)
849
+ dk = k.new_empty(B, *((HQ, T) if head_first else (T, HQ)), K)
850
+ dv = v.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V)
851
+ dgv = g.new_empty(NK, B, *((HQ, T) if head_first else (T, HQ)), V, dtype=torch.float)
852
+ grid = (NK, NT, B * HQ)
853
+ chunk_gsa_bwd_k_kernel_dqkvg[grid](
854
+ q,
855
+ k,
856
+ v,
857
+ h,
858
+ g,
859
+ A,
860
+ do,
861
+ dh,
862
+ dq,
863
+ dk,
864
+ dv,
865
+ dg,
866
+ dgv,
867
+ dA,
868
+ offsets=offsets,
869
+ indices=indices,
870
+ scale=scale,
871
+ T=T,
872
+ B=B,
873
+ HQ=HQ,
874
+ H=H,
875
+ K=K,
876
+ V=V,
877
+ BT=BT,
878
+ BK=BK,
879
+ BV=BV,
880
+ NG=NG,
881
+ HEAD_FIRST=head_first
882
+ )
883
+ A = A.sum(0, dtype=A.dtype)
884
+ dv = dv.sum(0, dtype=dv.dtype)
885
+ dgv = dgv.sum(0, dtype=dgv.dtype)
886
+
887
+ def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
888
+ chunk_gsa_bwd_k_kernel_intra_dvg[grid](
889
+ v,
890
+ g,
891
+ o,
892
+ A,
893
+ do,
894
+ dv,
895
+ dg,
896
+ offsets=offsets,
897
+ indices=indices,
898
+ T=T,
899
+ HQ=HQ,
900
+ H=H,
901
+ V=V,
902
+ BT=BT,
903
+ BC=BC,
904
+ BV=BV,
905
+ NC=NC,
906
+ NG=NG,
907
+ HEAD_FIRST=head_first,
908
+ num_warps=4,
909
+ num_stages=2
910
+ )
911
+ dg = dgv.add_(chunk_local_cumsum(dg, chunk_size=BT, reverse=True, offsets=offsets, indices=indices, head_first=head_first))
912
+
913
+ return dq, dk, dv, dg, dh0
914
+
915
+
916
+ def chunk_gsa_fwd(
917
+ q: torch.Tensor,
918
+ k: torch.Tensor,
919
+ v: torch.Tensor,
920
+ s: torch.Tensor,
921
+ g: torch.Tensor,
922
+ initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
923
+ output_final_state: bool = False,
924
+ scale: float = 1.,
925
+ offsets: Optional[torch.LongTensor] = None,
926
+ indices: Optional[torch.LongTensor] = None,
927
+ head_first: bool = True,
928
+ chunk_size: int = 64
929
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
930
+ hk0, hv0 = None, None
931
+ if initial_state is not None:
932
+ hk0, hv0 = initial_state
933
+ Ak, hk, hkt, ok = chunk_gsa_fwd_k(
934
+ q=q,
935
+ k=k,
936
+ v=s,
937
+ g=g,
938
+ h0=hk0,
939
+ output_final_state=output_final_state,
940
+ scale=scale,
941
+ offsets=offsets,
942
+ indices=indices,
943
+ head_first=head_first,
944
+ chunk_size=chunk_size
945
+ )
946
+
947
+ # p is kept in fp32 for safe softmax backward
948
+ p = softmax_fwd(ok, dtype=torch.float)
949
+
950
+ qv = p.to(q.dtype)
951
+ Av, hv, hvt, ov = chunk_gsa_fwd_v(
952
+ q=qv,
953
+ k=s,
954
+ v=v,
955
+ g=g,
956
+ scale=1.,
957
+ initial_state=hv0,
958
+ output_final_state=output_final_state,
959
+ offsets=offsets,
960
+ indices=indices,
961
+ head_first=head_first,
962
+ chunk_size=chunk_size
963
+ )
964
+ return Ak, hk, hkt, ok, p, Av, hv, hvt, ov
965
+
966
+
967
+ def chunk_gsa_bwd(
968
+ q: torch.Tensor,
969
+ k: torch.Tensor,
970
+ v: torch.Tensor,
971
+ s: torch.Tensor,
972
+ g: torch.Tensor,
973
+ ok: torch.Tensor,
974
+ p: torch.Tensor,
975
+ A: Tuple[torch.Tensor, torch.Tensor],
976
+ h: Tuple[torch.Tensor, torch.Tensor],
977
+ initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]],
978
+ scale: float,
979
+ do: torch.Tensor,
980
+ dht: Tuple[torch.Tensor, torch.Tensor],
981
+ offsets: Optional[torch.LongTensor] = None,
982
+ indices: Optional[torch.LongTensor] = None,
983
+ head_first: bool = True,
984
+ chunk_size: int = 64
985
+ ):
986
+ hk0, hv0 = None, None
987
+ if initial_state is not None:
988
+ hk0, hv0 = initial_state
989
+
990
+ _, Av = A
991
+ hk, hv = h
992
+ dhkt, dhvt = dht
993
+
994
+ qv = p.to(q.dtype)
995
+ dqv, dsv, dv, dg, dhv0 = chunk_gsa_bwd_v(
996
+ q=qv,
997
+ k=s,
998
+ v=v,
999
+ g=g,
1000
+ h0=hv0,
1001
+ h=hv,
1002
+ A=Av,
1003
+ do=do,
1004
+ dht=dhvt,
1005
+ dg=None,
1006
+ scale=1.,
1007
+ offsets=offsets,
1008
+ indices=indices,
1009
+ head_first=head_first,
1010
+ chunk_size=chunk_size
1011
+ )
1012
+
1013
+ # softmax gradient, equivalent to:
1014
+ # dok = qv * (dqv - (qv * dqv).sum(-1, True))
1015
+ dok = softmax_bwd(p, dqv, dtype=ok.dtype)
1016
+
1017
+ dq, dk, dsk, dg, dhk0 = chunk_gsa_bwd_k(
1018
+ q=q,
1019
+ k=k,
1020
+ v=s,
1021
+ g=g,
1022
+ h0=hk0,
1023
+ h=hk,
1024
+ o=ok,
1025
+ do=dok,
1026
+ dht=dhkt,
1027
+ dg=dg,
1028
+ scale=scale,
1029
+ offsets=offsets,
1030
+ indices=indices,
1031
+ head_first=head_first,
1032
+ chunk_size=chunk_size
1033
+ )
1034
+
1035
+ ds = dsv.add_(dsk)
1036
+ if q.shape[1] != k.shape[1]:
1037
+ dk, dv, ds, dg = map(lambda x: reduce(x, 'b (h g) ... -> b h ...', 'sum', h=k.shape[1]), (dk, dv, ds, dg))
1038
+ dg = dg.to(s.dtype)
1039
+ return dq, dk, dv, ds, dg, dhk0, dhv0
1040
+
1041
+
1042
+ class ChunkGSAFunction(torch.autograd.Function):
1043
+
1044
+ @staticmethod
1045
+ @input_guard
1046
+ def forward(
1047
+ ctx,
1048
+ q: torch.Tensor,
1049
+ k: torch.Tensor,
1050
+ v: torch.Tensor,
1051
+ s: torch.Tensor,
1052
+ g: torch.Tensor,
1053
+ scale: float,
1054
+ hk0: Optional[torch.Tensor],
1055
+ hv0: Optional[torch.Tensor],
1056
+ output_final_state: bool,
1057
+ checkpoint_level: int,
1058
+ offsets: Optional[torch.LongTensor],
1059
+ head_first: bool = True
1060
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1061
+ T = q.shape[2] if head_first else q.shape[1]
1062
+ chunk_size = min(64, max(16, triton.next_power_of_2(T)))
1063
+
1064
+ # 2-d indices denoting the offsets of chunks in each sequence
1065
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
1066
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
1067
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
1068
+ indices = None
1069
+ if offsets is not None:
1070
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()])
1071
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
1072
+ g_org, g = g, chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
1073
+ Ak, hk, hkt, ok, p, Av, hv, hvt, ov = chunk_gsa_fwd(
1074
+ q=q,
1075
+ k=k,
1076
+ v=v,
1077
+ s=s,
1078
+ g=g,
1079
+ initial_state=(hk0, hv0),
1080
+ output_final_state=output_final_state,
1081
+ scale=scale,
1082
+ offsets=offsets,
1083
+ indices=indices,
1084
+ head_first=head_first,
1085
+ chunk_size=chunk_size
1086
+ )
1087
+
1088
+ if checkpoint_level >= 1:
1089
+ del g
1090
+ g = g_org
1091
+ if checkpoint_level > 1:
1092
+ del hk
1093
+ del hv
1094
+ hk, hv = None, None
1095
+ else:
1096
+ hk0, hv0 = None, None
1097
+
1098
+ ctx.save_for_backward(q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv)
1099
+ ctx.checkpoint_level = checkpoint_level
1100
+ ctx.scale = scale
1101
+ ctx.offsets = offsets
1102
+ ctx.indices = indices
1103
+ ctx.head_first = head_first
1104
+ ctx.chunk_size = chunk_size
1105
+ return ov, hkt, hvt
1106
+
1107
+ @staticmethod
1108
+ @input_guard
1109
+ def backward(ctx, dov, dhkt=None, dhvt=None):
1110
+ q, k, v, s, g, ok, p, Av, hk0, hv0, hk, hv = ctx.saved_tensors
1111
+ scale = ctx.scale
1112
+ offsets = ctx.offsets
1113
+ indices = ctx.indices
1114
+ head_first = ctx.head_first
1115
+ chunk_size = ctx.chunk_size
1116
+
1117
+ if ctx.checkpoint_level >= 1:
1118
+ g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first)
1119
+ dq, dk, dv, ds, dg, dhk0, dhv0 = chunk_gsa_bwd(
1120
+ q=q,
1121
+ k=k,
1122
+ v=v,
1123
+ s=s,
1124
+ g=g,
1125
+ ok=ok,
1126
+ p=p,
1127
+ A=(None, Av),
1128
+ h=(hk, hv),
1129
+ initial_state=(hk0, hv0),
1130
+ scale=scale,
1131
+ do=dov,
1132
+ dht=(dhkt, dhvt),
1133
+ offsets=offsets,
1134
+ indices=indices,
1135
+ head_first=head_first,
1136
+ chunk_size=chunk_size
1137
+ )
1138
+ return dq, dk, dv, ds, dg, None, dhk0, dhv0, None, None, None, None
1139
+
1140
+
1141
+ @torch.compiler.disable
1142
+ def chunk_gsa(
1143
+ q: torch.Tensor,
1144
+ k: torch.Tensor,
1145
+ v: torch.Tensor,
1146
+ s: torch.Tensor,
1147
+ g: Optional[torch.Tensor] = None,
1148
+ scale: Optional[int] = None,
1149
+ initial_state: Optional[Tuple[torch.Tensor]] = None,
1150
+ output_final_state: Optional[bool] = False,
1151
+ checkpoint_level: Optional[int] = 2,
1152
+ cu_seqlens: Optional[torch.LongTensor] = None,
1153
+ head_first: Optional[bool] = True
1154
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1155
+ r"""
1156
+ Args:
1157
+ q (torch.Tensor):
1158
+ queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`.
1159
+ k (torch.Tensor):
1160
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
1161
+ GQA is performed if `H` is not equal to `HQ`.
1162
+ v (torch.Tensor):
1163
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1164
+ s (torch.Tensor):
1165
+ slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`.
1166
+ g (torch.Tensor):
1167
+ Forget gates of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`, applied to keys.
1168
+ If not provided, this function is equivalent to vanilla ABC.
1169
+ scale (Optional[int]):
1170
+ Scale factor for attention scores.
1171
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
1172
+ initial_state (Optional[Tuple[torch.Tensor]]):
1173
+ Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
1174
+ For equal-length input sequences, `N` equals the batch size `B`.
1175
+ Default: `None`.
1176
+ output_final_state (Optional[bool]):
1177
+ Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
1178
+ Default: `False`.
1179
+ checkpoint_level (Optional[int]):
1180
+ Checkpointing level; higher values save more memory at the cost of more recomputation during the backward pass.
1181
+ Default: `2`:
1182
+ - Level `0`: no memory saved, no recomputation.
1183
+ - Level `1`: recompute the fp32 cumulative values during backward.
1184
+ - Level `2`: recompute the fp32 cumulative values and forward hidden states during backward.
1185
+ cu_seqlens (torch.LongTensor):
1186
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
1187
+ consistent with the FlashAttention API.
1188
+ head_first (Optional[bool]):
1189
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
1190
+ Default: `True`.
1191
+
1192
+ Returns:
1193
+ o (torch.Tensor):
1194
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
1195
+ final_state (Tuple[torch.Tensor]):
1196
+ Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` if `output_final_state=True`.
1197
+ `None` otherwise.
1198
+
1199
+ Examples::
1200
+ >>> import torch
1201
+ >>> import torch.nn.functional as F
1202
+ >>> from einops import rearrange
1203
+ >>> from fla.ops.gsa import chunk_gsa
1204
+ # inputs with equal lengths
1205
+ >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
1206
+ >>> q = torch.randn(B, T, H, K, device='cuda')
1207
+ >>> k = torch.randn(B, T, H, K, device='cuda')
1208
+ >>> v = torch.randn(B, T, H, V, device='cuda')
1209
+ >>> s = torch.randn(B, T, H, M, device='cuda')
1210
+ >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
1211
+ >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
1212
+ >>> o, (hk, hv) = chunk_gsa(q, k, v, s, g,
1213
+ initial_state=h0,
1214
+ output_final_state=True,
1215
+ head_first=False)
1216
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
1217
+ >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
1218
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
1219
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
1220
+ >>> o_var, (hk_var, hv_var) = chunk_gsa(q, k, v, s, g,
1221
+ initial_state=h0,
1222
+ output_final_state=True,
1223
+ cu_seqlens=cu_seqlens,
1224
+ head_first=False)
1225
+ >>> assert o.allclose(o_var.view(o.shape))
1226
+ >>> assert hk.allclose(hk_var)
1227
+ >>> assert hv.allclose(hv_var)
1228
+ """
1229
+ if cu_seqlens is not None:
1230
+ if q.shape[0] != 1:
1231
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
1232
+ f"Please flatten variable-length inputs before processing.")
1233
+ if head_first:
1234
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
1235
+ if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
1236
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
1237
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}.")
1238
+ assert checkpoint_level in [0, 1, 2]
1239
+ if g is None:
1240
+ # TODO: these 3 steps take a huge amount of time and ought to be optimized
1241
+ z = s.float().logcumsumexp(2)
1242
+ g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
1243
+ s = torch.exp(s - z).to(k.dtype)
1244
+ if scale is None:
1245
+ scale = q.shape[-1] ** -0.5
1246
+
1247
+ hk0, hv0 = None, None
1248
+ if initial_state is not None:
1249
+ hk0, hv0 = initial_state
1250
+ o, *final_state = ChunkGSAFunction.apply(
1251
+ q,
1252
+ k,
1253
+ v,
1254
+ s,
1255
+ g,
1256
+ scale,
1257
+ hk0,
1258
+ hv0,
1259
+ output_final_state,
1260
+ checkpoint_level,
1261
+ cu_seqlens,
1262
+ head_first
1263
+ )
1264
+ return o, final_state
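
When `g` is not provided, `chunk_gsa` builds the gates itself via the `logcumsumexp` decomposition shown above. A minimal PyTorch sketch of that construction (not part of the diff; shapes are illustrative), with a sanity check of the identity `exp(g_t) + exp(s_t - z_t) = 1` for `t >= 1`:

import torch

B, H, T, M = 2, 4, 32, 8
s = torch.randn(B, H, T, M)

z = s.float().logcumsumexp(2)                      # z_t = log(sum_{i<=t} exp(s_i))
g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z  # g_t = z_{t-1} - z_t <= 0, with g_0 = 0
s_hat = torch.exp(s - z)                           # normalized slot scores in (0, 1]

# (s_hat, g) drive the gated recurrence h_t = exp(g_t) * h_{t-1} + k_t * s_hat_t;
# since exp(g_t) = Z_{t-1} / Z_t and s_hat_t = exp(s_t) / Z_t, they sum to 1 for t >= 1.
assert torch.allclose(torch.exp(g[:, :, 1:]) + s_hat[:, :, 1:],
                      torch.ones(B, H, T - 1, M), atol=1e-5)
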
fla/ops/gsa/fused_recurrent.py ADDED
@@ -0,0 +1,564 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.fused_recurrent import fused_recurrent_bwd_kernel, fused_recurrent_fwd_kernel
11
+ from fla.ops.utils import chunk_global_cumsum
12
+ from fla.ops.utils.op import exp
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
14
+
15
+
16
+ @triton.jit
17
+ def fused_recurrent_gsa_inference_kernel(
18
+ q,
19
+ k,
20
+ v,
21
+ s,
22
+ g,
23
+ o,
24
+ hk0,
25
+ hv0,
26
+ hkt,
27
+ hvt,
28
+ scale,
29
+ K: tl.constexpr,
30
+ V: tl.constexpr,
31
+ M: tl.constexpr,
32
+ BK: tl.constexpr,
33
+ BV: tl.constexpr,
34
+ NG: tl.constexpr
35
+ ):
36
+ i_bh = tl.program_id(0)
37
+ i_bg = i_bh // NG
38
+
39
+ b_s = tl.load(s + i_bg * M + tl.arange(0, M)).to(tl.float32)
40
+ b_g = tl.load(g + i_bg * M + tl.arange(0, M)).to(tl.float32)
41
+ b_g = exp(b_g)
42
+
43
+ b_ok = tl.zeros([M], dtype=tl.float32)
44
+ for i_k in range(tl.cdiv(K, BK)):
45
+ o_k = i_k * BK + tl.arange(0, BK)
46
+
47
+ p_hk0 = hk0 + i_bg * K * M + (o_k[None, :]) * M + tl.arange(0, M)[:, None]
48
+ # [BK,]
49
+ mask_k = o_k < K
50
+ # [M, BK]
51
+ mask_hk = (tl.arange(0, M) < M)[:, None] & mask_k[None, :]
52
+ # [M, BK]
53
+ b_hk = tl.load(p_hk0, mask=mask_hk, other=0.).to(tl.float32)
54
+ # [BK,]
55
+ b_q = tl.load(q + i_bh * K + o_k, mask=mask_k, other=0.).to(tl.float32) * scale
56
+ b_k = tl.load(k + i_bg * K + o_k, mask=mask_k, other=0.).to(tl.float32)
57
+ b_hk = b_hk * b_g[:, None] + b_k[None, :] * b_s[:, None]
58
+ b_ok += tl.sum(b_hk * b_q[None, :], axis=1)
59
+
60
+ if i_bh % NG == 0:
61
+ p_hkt = hkt + i_bg * K * M + o_k[None, :] * M + tl.arange(0, M)[:, None]
62
+ tl.store(p_hkt, b_hk.to(p_hkt.dtype.element_ty), mask=mask_hk)
63
+
64
+ b_qv = tl.softmax(b_ok)
65
+ for i_v in range(tl.cdiv(V, BV)):
66
+ o_v = i_v * BV + tl.arange(0, BV)
67
+
68
+ p_hv0 = hv0 + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None]
69
+ # [BV,]
70
+ mask_v = o_v < V
71
+ # [BV, M]
72
+ mask_hv = mask_v[:, None] & (tl.arange(0, M) < M)[None, :]
73
+ # [BV, M]
74
+ b_hv = tl.load(p_hv0, mask=mask_hv, other=0).to(tl.float32)
75
+ # [BV,]
76
+ b_v = tl.load(v + i_bg * V + o_v, mask=mask_v, other=0).to(tl.float32)
77
+ b_hv = b_hv * b_g[None, :] + b_s[None, :] * b_v[:, None]
78
+ b_ov = tl.sum(b_hv * b_qv[None, :], axis=1)
79
+
80
+ tl.store(o + i_bh * V + o_v, b_ov.to(o.dtype.element_ty), mask=mask_v)
81
+
82
+ if i_bh % NG == 0:
83
+ p_hvt = hvt + i_bg * M * V + tl.arange(0, M)[None, :] * V + o_v[:, None]
84
+ tl.store(p_hvt, b_hv.to(p_hvt.dtype.element_ty), mask=mask_hv)
85
+
86
+
87
+ def fused_recurrent_gsa_inference(
88
+ q: torch.Tensor,
89
+ k: torch.Tensor,
90
+ v: torch.Tensor,
91
+ s: torch.Tensor,
92
+ g: torch.Tensor,
93
+ initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
94
+ output_final_state: bool = False,
95
+ scale: float = 1.,
96
+ head_first: bool = True
97
+ ) -> torch.Tensor:
98
+ if head_first:
99
+ B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
100
+ else:
101
+ B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
102
+ HQ = q.shape[1] if head_first else q.shape[2]
103
+ BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
104
+ NG = HQ // H
105
+
106
+ if initial_state != (None, None) and initial_state is not None:
107
+ hk0, hv0 = initial_state
108
+ else:
109
+ hk0, hv0 = q.new_zeros(B, H, K, M, dtype=torch.float), q.new_zeros(B, H, M, V, dtype=torch.float)
110
+
111
+ hkt, hvt = None, None
112
+ if output_final_state:
113
+ if NG == 1:
114
+ hkt, hvt = hk0, hv0
115
+ else:
116
+ hkt, hvt = q.new_empty(B, H, K, M, dtype=torch.float), q.new_empty(B, H, M, V, dtype=torch.float)
117
+
118
+ o = v.new_empty(B, HQ, T, V) if head_first else v.new_empty(B, T, HQ, V)
119
+ grid = (B * HQ,)
120
+ fused_recurrent_gsa_inference_kernel[grid](
121
+ q,
122
+ k,
123
+ v,
124
+ s,
125
+ g,
126
+ o,
127
+ hk0,
128
+ hv0,
129
+ hkt,
130
+ hvt,
131
+ scale=scale,
132
+ K=K,
133
+ V=V,
134
+ M=M,
135
+ BK=BK,
136
+ BV=BV,
137
+ NG=NG
138
+ )
139
+ return o, (hkt, hvt)
140
+
141
+
142
+ def fused_recurrent_gsa_fwd(
143
+ q: torch.Tensor,
144
+ k: torch.Tensor,
145
+ v: torch.Tensor,
146
+ s: torch.Tensor,
147
+ g: torch.Tensor,
148
+ initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
149
+ output_final_state: bool = False,
150
+ scale: float = 1.,
151
+ reverse: bool = False,
152
+ offsets: Optional[torch.LongTensor] = None,
153
+ head_first: bool = True
154
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
155
+ if head_first:
156
+ B, H, T, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
157
+ else:
158
+ B, T, H, K, V, M = *k.shape, v.shape[-1], s.shape[-1]
159
+ N = B if offsets is None else len(offsets) - 1
160
+ HQ = q.shape[1] if head_first else q.shape[2]
161
+ if HQ != H:
162
+ raise ValueError("GQA not supported yet.")
163
+
164
+ BK, BV, BM = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64), min(M, 64)
165
+ NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
166
+
167
+ hk0, hv0 = None, None
168
+ if initial_state != (None, None) and initial_state is not None:
169
+ hk0, hv0 = initial_state
170
+ hkt, hvt = None, None
171
+ if output_final_state:
172
+ hkt, hvt = q.new_empty(N, H, K, M, dtype=torch.float), q.new_empty(N, H, M, V, dtype=torch.float)
173
+
174
+ ok = q.new_empty(NK, *s.shape, dtype=torch.float)
175
+ gk, gv = None, g
176
+ grid = (NM, NK, N * H)
177
+ fused_recurrent_fwd_kernel[grid](
178
+ q=q,
179
+ k=k,
180
+ v=s,
181
+ g=None,
182
+ gk=gk,
183
+ gv=gv,
184
+ o=ok,
185
+ h0=hk0,
186
+ ht=hkt,
187
+ offsets=offsets,
188
+ scale=scale,
189
+ B=B,
190
+ T=T,
191
+ H=H,
192
+ K=K,
193
+ V=M,
194
+ BK=BK,
195
+ BV=BM,
196
+ USE_G=False,
197
+ USE_GK=False,
198
+ USE_GV=True,
199
+ REVERSE=reverse,
200
+ HEAD_FIRST=head_first
201
+ )
202
+ ok = ok.sum(0)
203
+
204
+ qv = ok.softmax(-1, dtype=torch.float)
205
+ ov = q.new_empty(NM, *v.shape, dtype=torch.float)
206
+ gk, gv = g, None
207
+ grid = (NV, NM, N * H)
208
+ fused_recurrent_fwd_kernel[grid](
209
+ q=qv,
210
+ k=s,
211
+ v=v,
212
+ g=None,
213
+ gk=gk,
214
+ gv=gv,
215
+ o=ov,
216
+ h0=hv0,
217
+ ht=hvt,
218
+ offsets=offsets,
219
+ scale=1.,
220
+ B=B,
221
+ T=T,
222
+ H=H,
223
+ K=M,
224
+ V=V,
225
+ BK=BM,
226
+ BV=BV,
227
+ USE_G=False,
228
+ USE_GK=True,
229
+ USE_GV=False,
230
+ REVERSE=reverse,
231
+ HEAD_FIRST=head_first
232
+ )
233
+ ov = ov.sum(0)
234
+ return ok, hkt, qv, ov, hvt
235
+
236
+
237
+ def fused_recurrent_gsa_bwd(
238
+ q: torch.Tensor,
239
+ k: torch.Tensor,
240
+ v: torch.Tensor,
241
+ s: torch.Tensor,
242
+ g: torch.Tensor,
243
+ qv: torch.Tensor,
244
+ hk0: Optional[torch.Tensor] = None,
245
+ hv0: Optional[torch.Tensor] = None,
246
+ ok: Optional[torch.Tensor] = None,
247
+ do: Optional[torch.Tensor] = None,
248
+ dhkt: Optional[torch.Tensor] = None,
249
+ dhvt: Optional[torch.Tensor] = None,
250
+ scale: float = 1.,
251
+ reverse: bool = False,
252
+ offsets: Optional[torch.LongTensor] = None,
253
+ head_first: bool = True
254
+ ) -> Tuple[torch.Tensor]:
255
+ if head_first:
256
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
257
+ else:
258
+ B, T, H, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
259
+ N = B if offsets is None else len(offsets) - 1
260
+
261
+ BK, BV, BM = min(K, 64), min(V, 64), min(M, 64)
262
+ NK, NV, NM = triton.cdiv(K, BK), triton.cdiv(V, BV), triton.cdiv(M, BM)
263
+
264
+ if head_first:
265
+ dqv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
266
+ dsv = q.new_empty(NV, B, H, T, M, dtype=torch.float)
267
+ dv = q.new_empty(NM, B, H, T, V, dtype=torch.float)
268
+ else:
269
+ dqv = q.new_empty(NV, B, T, H, M, dtype=torch.float)
270
+ dsv = q.new_empty(NV, B, T, H, M, dtype=torch.float)
271
+ dv = q.new_empty(NM, B, T, H, V, dtype=torch.float)
272
+ dhk0 = torch.empty_like(hk0) if hk0 is not None else None
273
+ dhv0 = torch.empty_like(hv0) if hv0 is not None else None
274
+
275
+ gk, gv = g, None
276
+ grid = (NV, NM, N * H)
277
+ fused_recurrent_bwd_kernel[grid](
278
+ q=qv,
279
+ k=s,
280
+ v=v,
281
+ g=None,
282
+ gk=gk,
283
+ gv=gv,
284
+ h0=hv0,
285
+ do=do,
286
+ dq=dqv,
287
+ dk=dsv,
288
+ dv=dv,
289
+ dht=dhvt,
290
+ dh0=dhv0,
291
+ offsets=offsets,
292
+ scale=1.,
293
+ B=B,
294
+ T=T,
295
+ H=H,
296
+ K=M,
297
+ V=V,
298
+ BK=BM,
299
+ BV=BV,
300
+ USE_G=False,
301
+ USE_GK=True,
302
+ USE_GV=False,
303
+ REVERSE=reverse,
304
+ HEAD_FIRST=head_first
305
+ )
306
+ dqv = dqv.sum(0)
307
+ dsv = dsv.sum(0)
308
+ dv = dv.sum(0)
309
+ dgk = chunk_global_cumsum(dqv * qv.float() - dsv * s.float(),
310
+ reverse=not reverse,
311
+ offsets=offsets,
312
+ head_first=head_first)
313
+
314
+ dok = qv * (dqv - (qv * dqv).sum(-1, True))
315
+ if head_first:
316
+ dq = q.new_empty(NM, B, H, T, K, dtype=torch.float)
317
+ dk = q.new_empty(NM, B, H, T, K, dtype=torch.float)
318
+ dsk = q.new_empty(NK, B, H, T, M, dtype=torch.float)
319
+ else:
320
+ dq = q.new_empty(NM, B, T, H, K, dtype=torch.float)
321
+ dk = q.new_empty(NM, B, T, H, K, dtype=torch.float)
322
+ dsk = q.new_empty(NK, B, T, H, M, dtype=torch.float)
323
+ gk, gv = None, g
324
+ grid = (NM, NK, N * H)
325
+ fused_recurrent_bwd_kernel[grid](
326
+ q=q,
327
+ k=k,
328
+ v=s,
329
+ g=None,
330
+ gk=gk,
331
+ gv=gv,
332
+ h0=hk0,
333
+ do=dok,
334
+ dq=dq,
335
+ dk=dk,
336
+ dv=dsk,
337
+ dht=dhkt,
338
+ dh0=dhk0,
339
+ offsets=offsets,
340
+ scale=scale,
341
+ B=B,
342
+ T=T,
343
+ H=H,
344
+ K=K,
345
+ V=M,
346
+ BK=BK,
347
+ BV=BM,
348
+ USE_G=False,
349
+ USE_GK=False,
350
+ USE_GV=True,
351
+ REVERSE=reverse,
352
+ HEAD_FIRST=head_first
353
+ )
354
+ dq = dq.sum(0)
355
+ dk = dk.sum(0)
356
+ dsk = dsk.sum(0)
357
+
358
+ dgv = chunk_global_cumsum(dok.float() * ok.float() - dsk * s.float(),
359
+ reverse=not reverse,
360
+ offsets=offsets,
361
+ head_first=head_first)
362
+
363
+ ds = dsk.add_(dsv)
364
+ dg = dgk.add_(dgv)
365
+
366
+ return dq, dk, dv, ds, dg, dhk0, dhv0
367
+
368
+
369
+ class FusedRecurrentGSAFunction(torch.autograd.Function):
370
+
371
+ @staticmethod
372
+ @input_guard
373
+ @autocast_custom_fwd
374
+ def forward(
375
+ ctx,
376
+ q: torch.Tensor,
377
+ k: torch.Tensor,
378
+ v: torch.Tensor,
379
+ s: torch.Tensor,
380
+ g: torch.Tensor,
381
+ scale: Optional[float] = None,
382
+ hk0: Optional[torch.Tensor] = None,
383
+ hv0: Optional[torch.Tensor] = None,
384
+ output_final_state: bool = False,
385
+ reverse: bool = False,
386
+ offsets: Optional[torch.LongTensor] = None,
387
+ head_first: bool = True
388
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
389
+ T = q.shape[2] if head_first else q.shape[1]
390
+ if T == 1 and not q.requires_grad:
391
+ o, (hkt, hvt) = fused_recurrent_gsa_inference(
392
+ q=q,
393
+ k=k,
394
+ v=v,
395
+ s=s,
396
+ g=g,
397
+ initial_state=(hk0, hv0),
398
+ output_final_state=output_final_state,
399
+ scale=scale,
400
+ head_first=head_first
401
+ )
402
+ return o, hkt, hvt
403
+ ok, hkt, qv, ov, hvt = fused_recurrent_gsa_fwd(
404
+ q=q,
405
+ k=k,
406
+ v=v,
407
+ s=s,
408
+ g=g,
409
+ initial_state=(hk0, hv0),
410
+ output_final_state=output_final_state,
411
+ scale=scale,
412
+ reverse=reverse,
413
+ offsets=offsets,
414
+ head_first=head_first
415
+ )
416
+ ctx.save_for_backward(q, k, v, s, g, qv, hk0, hv0, ok)
417
+ ctx.scale = scale
418
+ ctx.reverse = reverse
419
+ ctx.offsets = offsets
420
+ ctx.head_first = head_first
421
+ return ov.to(q.dtype), hkt, hvt
422
+
423
+ @staticmethod
424
+ @input_guard
425
+ @autocast_custom_bwd
426
+ def backward(ctx, do, dhkt=None, dhvt=None):
427
+ q, k, v, s, g, qv, hk0, hv0, ok = ctx.saved_tensors
428
+ scale = ctx.scale
429
+ reverse = ctx.reverse
430
+ offsets = ctx.offsets
431
+ head_first = ctx.head_first
432
+
433
+ # passing final-state gradients together with gates is not supported yet
434
+ if dhkt is not None or dhvt is not None:
435
+ if g is not None:
436
+ assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
437
+ dq, dk, dv, ds, dg, dhk0, dhv0 = fused_recurrent_gsa_bwd(
438
+ q=q,
439
+ k=k,
440
+ v=v,
441
+ s=s,
442
+ g=g,
443
+ qv=qv,
444
+ hk0=hk0,
445
+ hv0=hv0,
446
+ ok=ok,
447
+ do=do,
448
+ dhkt=dhkt,
449
+ dhvt=dhvt,
450
+ scale=scale,
451
+ reverse=reverse,
452
+ offsets=offsets,
453
+ head_first=head_first
454
+ )
455
+ return dq.to(q), dk.to(k), dv.to(v), ds.to(s), dg.to(g), None, dhk0, dhv0, None, None, None, None
456
+
457
+
458
+ def fused_recurrent_gsa(
459
+ q: torch.Tensor,
460
+ k: torch.Tensor,
461
+ v: torch.Tensor,
462
+ s: torch.Tensor,
463
+ g: Optional[torch.Tensor] = None,
464
+ scale: Optional[int] = None,
465
+ initial_state: Optional[Tuple[torch.Tensor]] = None,
466
+ output_final_state: Optional[bool] = False,
467
+ reverse: Optional[bool] = False,
468
+ cu_seqlens: Optional[torch.LongTensor] = None,
469
+ head_first: bool = True
470
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
471
+ r"""
472
+ Args:
473
+ q (torch.Tensor):
474
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
475
+ k (torch.Tensor):
476
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
477
+ v (torch.Tensor):
478
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
479
+ s (torch.Tensor):
480
+ slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`.
481
+ g (torch.Tensor):
482
+ Forget gates of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`, applied to keys.
483
+ scale (Optional[int]):
484
+ Scale factor for the attention scores.
485
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
486
+ initial_state (Optional[Tuple[torch.Tensor]]):
487
+ Initial state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]` for `N` input sequences.
488
+ For equal-length input sequences, `N` equals the batch size `B`.
489
+ Default: `None`.
490
+ output_final_state (Optional[bool]):
491
+ Whether to output the final state tuple, having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
492
+ Default: `False`.
493
+ reverse (Optional[bool]):
494
+ If `True`, process the state passing in reverse order. Default: `False`.
495
+ cu_seqlens (torch.LongTensor):
496
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
497
+ consistent with the FlashAttention API.
498
+ head_first (Optional[bool]):
499
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
500
+ Default: `True`.
501
+
502
+ Returns:
503
+ o (torch.Tensor):
504
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
505
+ final_state (Tuple[torch.Tensor]):
506
+ Final state tuple having tensors of shape `[N, H, K, M]` and `[N, H, M, V]`.
507
+
508
+ Examples::
509
+ >>> import torch
510
+ >>> import torch.nn.functional as F
511
+ >>> from einops import rearrange
512
+ >>> from fla.ops.gsa import fused_recurrent_gsa
513
+ # inputs with equal lengths
514
+ >>> B, T, H, K, V, M = 4, 2048, 4, 512, 512, 64
515
+ >>> q = torch.randn(B, T, H, K, device='cuda')
516
+ >>> k = torch.randn(B, T, H, K, device='cuda')
517
+ >>> v = torch.randn(B, T, H, V, device='cuda')
518
+ >>> s = torch.randn(B, T, H, M, device='cuda')
519
+ >>> g = F.logsigmoid(torch.randn(B, T, H, M, device='cuda'))
520
+ >>> h0 = (torch.randn(B, H, K, M, device='cuda'), torch.randn(B, H, M, V, device='cuda'))
521
+ >>> o, (hk, hv) = fused_recurrent_gsa(q, k, v, s, g,
522
+ initial_state=h0,
523
+ output_final_state=True,
524
+ head_first=False)
525
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
526
+ >>> q, k, v, s, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, s, g))
527
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
528
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
529
+ >>> o_var, (hk_var, hv_var) = fused_recurrent_gsa(q, k, v, s, g,
530
+ initial_state=h0,
531
+ output_final_state=True,
532
+ cu_seqlens=cu_seqlens,
533
+ head_first=False)
534
+ >>> assert o.allclose(o_var.view(o.shape))
535
+ >>> assert hk.allclose(hk_var)
536
+ >>> assert hv.allclose(hv_var)
537
+ """
538
+ if cu_seqlens is not None:
539
+ if q.shape[0] != 1:
540
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
541
+ f"Please flatten variable-length inputs before processing.")
542
+ if head_first:
543
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
544
+ if initial_state is not None and initial_state[0].shape[0] != len(cu_seqlens) - 1:
545
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
546
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state[0].shape[0]}.")
547
+ if scale is None:
548
+ scale = k.shape[-1] ** -0.5
549
+ if initial_state is None:
550
+ initial_state = (None, None)
551
+ o, *final_state = FusedRecurrentGSAFunction.apply(
552
+ q,
553
+ k,
554
+ v,
555
+ s,
556
+ g,
557
+ scale,
558
+ *initial_state,
559
+ output_final_state,
560
+ reverse,
561
+ cu_seqlens,
562
+ head_first
563
+ )
564
+ return o, final_state
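
For autoregressive decoding, the `T == 1` branch above dispatches to the single-step inference kernel. A hedged usage sketch (assumes a CUDA device and the default head-first `[B, H, T, *]` layout; the shapes below are illustrative):

import torch
import torch.nn.functional as F
from fla.ops.gsa import fused_recurrent_gsa

B, H, K, V, M = 2, 4, 64, 64, 16
# running state: ([B, H, K, M], [B, H, M, V])
state = (torch.zeros(B, H, K, M, device='cuda'),
         torch.zeros(B, H, M, V, device='cuda'))
with torch.no_grad():
    for _ in range(16):  # decode 16 tokens one step at a time
        q = torch.randn(B, H, 1, K, device='cuda')
        k = torch.randn(B, H, 1, K, device='cuda')
        v = torch.randn(B, H, 1, V, device='cuda')
        s = torch.randn(B, H, 1, M, device='cuda')
        g = F.logsigmoid(torch.randn(B, H, 1, M, device='cuda'))
        o, state = fused_recurrent_gsa(
            q, k, v, s, g, initial_state=state, output_final_state=True
        )
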
fla/ops/gsa/naive.py ADDED
@@ -0,0 +1,68 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import repeat
7
+
8
+
9
+ def naive_recurrent_gsa(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ s: torch.Tensor,
14
+ g: Optional[torch.Tensor] = None,
15
+ scale: Optional[int] = None,
16
+ initial_state: Optional[torch.Tensor] = None,
17
+ output_final_state: Optional[bool] = False
18
+ ) -> torch.Tensor:
19
+ dtype = q.dtype
20
+
21
+ NG = q.shape[1]//k.shape[1]
22
+ # [batch_size, n_heads, seq_len, n_slots]
23
+ if g is None:
24
+ z = s.float().logcumsumexp(2)
25
+ g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
26
+ s = torch.exp(s - z)
27
+ q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
28
+ k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g))
29
+ if initial_state is not None:
30
+ initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state))
31
+
32
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
33
+
34
+ hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
35
+ ok = torch.zeros_like(s)
36
+
37
+ if scale is None:
38
+ scale = q.shape[-1] ** -0.5
39
+
40
+ final_state = None
41
+ if initial_state is not None:
42
+ hk += initial_state[0]
43
+
44
+ for i in range(T):
45
+ q_i = q[:, :, i] * scale
46
+ k_i = k[:, :, i]
47
+ v_i = s[:, :, i]
48
+ g_i = g[:, :, i].exp()
49
+ hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
50
+ ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
51
+
52
+ qv = ok.softmax(-1)
53
+ hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
54
+ ov = torch.zeros_like(v)
55
+ if initial_state is not None:
56
+ hv += initial_state[1]
57
+
58
+ for i in range(T):
59
+ q_i = qv[:, :, i]
60
+ k_i = s[:, :, i]
61
+ v_i = v[:, :, i]
62
+ g_i = g[:, :, i].exp()
63
+ hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
64
+ ov[:, :, i] = (q_i[..., None] * hv).sum(-2)
65
+
66
+ if output_final_state:
67
+ final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
68
+ return ov.to(dtype), final_state
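
The reference above is mainly useful for testing the Triton kernels. A hedged consistency check against `fused_recurrent_gsa` (assumes a CUDA device; small shapes chosen only for illustration):

import torch
import torch.nn.functional as F
from fla.ops.gsa import fused_recurrent_gsa
from fla.ops.gsa.naive import naive_recurrent_gsa

B, H, T, K, V, M = 2, 2, 64, 32, 32, 16
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')
s = torch.randn(B, H, T, M, device='cuda')
g = F.logsigmoid(torch.randn(B, H, T, M, device='cuda'))

ref, _ = naive_recurrent_gsa(q, k, v, s, g)
tri, _ = fused_recurrent_gsa(q, k, v, s, g)  # head_first=True by default
assert ref.allclose(tri, atol=1e-3)
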
fla/ops/hgrn/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_hgrn
4
+ from .fused_recurrent import fused_recurrent_hgrn
5
+
6
+ __all__ = [
7
+ 'chunk_hgrn',
8
+ 'fused_recurrent_hgrn'
9
+ ]
fla/ops/hgrn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (322 Bytes). View file
 
fla/ops/hgrn/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (16.9 kB). View file
 
fla/ops/hgrn/__pycache__/fused_recurrent.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
fla/ops/hgrn/fused_recurrent.py ADDED
@@ -0,0 +1,308 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import input_guard
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({'BD': BD}, num_warps=num_warps)
22
+ for BD in [32, 64, 128]
23
+ for num_warps in [1, 2, 4, 8]
24
+ ],
25
+ key=['D']
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_hgrn_fwd_kernel(
29
+ x,
30
+ g,
31
+ o,
32
+ h0,
33
+ ht,
34
+ offsets,
35
+ T,
36
+ D: tl.constexpr,
37
+ BD: tl.constexpr,
38
+ USE_INITIAL_STATE: tl.constexpr,
39
+ STORE_FINAL_STATE: tl.constexpr,
40
+ USE_OFFSETS: tl.constexpr
41
+ ):
42
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
43
+ if USE_OFFSETS:
44
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
45
+ T = eos - bos
46
+ else:
47
+ bos, eos = i_n * T, i_n * T + T
48
+
49
+ o_d = i_d * BD + tl.arange(0, BD)
50
+ mask = o_d < D
51
+
52
+ p_x = x + bos * D + o_d
53
+ p_g = g + bos * D + o_d
54
+ p_o = o + bos * D + o_d
55
+
56
+ b_h = tl.zeros([BD], dtype=tl.float32)
57
+ if USE_INITIAL_STATE:
58
+ p_h0 = h0 + i_n * D + o_d
59
+ b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
60
+ for _ in range(0, T):
61
+ b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
62
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
63
+ b_h = exp(b_g) * b_h + b_x
64
+ tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
65
+
66
+ p_x += D
67
+ p_g += D
68
+ p_o += D
69
+
70
+ if STORE_FINAL_STATE:
71
+ p_ht = ht + i_n * D + o_d
72
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
73
+
74
+
75
+ @triton.heuristics({
76
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
77
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
78
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
79
+ })
80
+ @triton.autotune(
81
+ configs=[
82
+ triton.Config({'BD': BD}, num_warps=num_warps)
83
+ for BD in [32, 64, 128]
84
+ for num_warps in [1, 2, 4, 8]
85
+ ],
86
+ key=['D']
87
+ )
88
+ @triton.jit(do_not_specialize=['T'])
89
+ def fused_recurrent_hgrn_bwd_kernel(
90
+ g,
91
+ o,
92
+ h0,
93
+ dx,
94
+ dg,
95
+ do,
96
+ dht,
97
+ dh0,
98
+ offsets,
99
+ T,
100
+ D: tl.constexpr,
101
+ BD: tl.constexpr,
102
+ USE_INITIAL_STATE: tl.constexpr,
103
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
104
+ USE_OFFSETS: tl.constexpr
105
+ ):
106
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
107
+ if USE_OFFSETS:
108
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
109
+ T = eos - bos
110
+ else:
111
+ bos, eos = i_n * T, i_n * T + T
112
+
113
+ o_d = i_d * BD + tl.arange(0, BD)
114
+ mask = o_d < D
115
+
116
+ p_g = g + (bos + T - 1) * D + o_d
117
+ p_o = o + (bos + T - 2) * D + o_d
118
+ p_dx = dx + (bos + T - 1) * D + o_d
119
+ p_dg = dg + (bos + T - 1) * D + o_d
120
+ p_do = do + (bos + T - 1) * D + o_d
121
+
122
+ b_dh = tl.zeros([BD], dtype=tl.float32)
123
+ if USE_FINAL_STATE_GRADIENT:
124
+ p_dht = dht + i_n * D + o_d
125
+ b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32)
126
+
127
+ for i in range(T - 1, -1, -1):
128
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
129
+ b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
130
+ if i > 0:
131
+ b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
132
+ elif USE_INITIAL_STATE:
133
+ b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32)
134
+ else:
135
+ b_o = tl.zeros([BD], dtype=tl.float32)
136
+
137
+ b_dh = b_dh + b_do
138
+ b_dx = b_dh
139
+ b_dh = b_dh * exp(b_g)
140
+ b_dg = b_dh * b_o
141
+ tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
142
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)
143
+
144
+ p_g -= D
145
+ p_o -= D
146
+ p_dx -= D
147
+ p_dg -= D
148
+ p_do -= D
149
+
150
+ if USE_INITIAL_STATE:
151
+ p_dh0 = dh0 + i_n * D + o_d
152
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask)
153
+
154
+
155
+ def fused_recurrent_hgrn_fwd(
156
+ x: torch.Tensor,
157
+ g: torch.Tensor,
158
+ initial_state: torch.Tensor = None,
159
+ output_final_state: bool = False,
160
+ offsets: Optional[torch.LongTensor] = None,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ B, T, D = x.shape
163
+ N = B if offsets is None else len(offsets) - 1
164
+
165
+ o = torch.empty_like(x)
166
+ final_state = x.new_empty(N, D) if output_final_state else None
167
+
168
+ def grid(meta): return (triton.cdiv(D, meta['BD']), N)
169
+ fused_recurrent_hgrn_fwd_kernel[grid](
170
+ x=x,
171
+ g=g,
172
+ o=o,
173
+ h0=initial_state,
174
+ ht=final_state,
175
+ offsets=offsets,
176
+ T=T,
177
+ D=D
178
+ )
179
+ return o, final_state
180
+
181
+
182
+ def fused_recurrent_hgrn_bwd(
183
+ g: torch.Tensor,
184
+ o: torch.Tensor,
185
+ do: torch.Tensor,
186
+ dht: torch.Tensor = None,
187
+ initial_state: torch.Tensor = None,
188
+ offsets: Optional[torch.LongTensor] = None
189
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
190
+ B, T, D = do.shape
191
+ N = B if offsets is None else len(offsets) - 1
192
+
193
+ dx = torch.empty_like(o, dtype=torch.float)
194
+ dg = torch.empty_like(g, dtype=torch.float)
195
+ dh0 = torch.empty_like(initial_state, dtype=torch.float) if initial_state is not None else None
196
+ def grid(meta): return (triton.cdiv(D, meta['BD']), N)
197
+ fused_recurrent_hgrn_bwd_kernel[grid](
198
+ g=g,
199
+ o=o,
200
+ h0=initial_state,
201
+ dx=dx,
202
+ dg=dg,
203
+ do=do,
204
+ dht=dht,
205
+ dh0=dh0,
206
+ offsets=offsets,
207
+ T=T,
208
+ D=D
209
+ )
210
+ return dx, dg, dh0
211
+
212
+
213
+ class FusedRecurrentHGRNFunction(torch.autograd.Function):
214
+
215
+ @staticmethod
216
+ @input_guard
217
+ def forward(
218
+ ctx,
219
+ x: torch.Tensor,
220
+ g: torch.Tensor,
221
+ initial_state: torch.Tensor = None,
222
+ output_final_state: bool = False,
223
+ offsets: Optional[torch.LongTensor] = None
224
+ ):
225
+ o, ht = fused_recurrent_hgrn_fwd(
226
+ x=x,
227
+ g=g,
228
+ initial_state=initial_state,
229
+ output_final_state=output_final_state,
230
+ offsets=offsets
231
+ )
232
+ ctx.save_for_backward(g, o, initial_state)
233
+ ctx.offsets = offsets
234
+ return o, ht
235
+
236
+ @staticmethod
237
+ @input_guard
238
+ def backward(ctx, do, dht=None):
239
+ g, o, initial_state = ctx.saved_tensors
240
+ offsets = ctx.offsets
241
+
242
+ dx, dg, dh0 = fused_recurrent_hgrn_bwd(
243
+ g=g,
244
+ o=o,
245
+ do=do,
246
+ dht=dht,
247
+ initial_state=initial_state,
248
+ offsets=offsets
249
+ )
250
+ return dx, dg, dh0, None, None
251
+
252
+
253
+ @torch.compiler.disable
254
+ def fused_recurrent_hgrn(
255
+ x: torch.Tensor,
256
+ g: torch.Tensor,
257
+ initial_state: torch.Tensor = None,
258
+ output_final_state: bool = False,
259
+ cu_seqlens: Optional[torch.LongTensor] = None,
260
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
261
+ r"""
262
+ Args:
263
+ x (torch.Tensor):
264
+ inputs of shape `[B, T, D]`.
265
+ g (torch.Tensor):
266
+ Forget gates of shape `[B, T, D]`.
267
+ initial_state (Optional[torch.Tensor]):
268
+ Initial state of shape `[N, D]` for `N` input sequences.
269
+ For equal-length input sequences, `N` equals the batch size `B`.
270
+ Default: `None`.
271
+ output_final_state (Optional[bool]):
272
+ Whether to output the final state of shape `[N, D]`. Default: `False`.
273
+ cu_seqlens (torch.LongTensor):
274
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
275
+ consistent with the FlashAttention API.
276
+
277
+ Returns:
278
+ o (torch.Tensor):
279
+ Outputs of shape `[B, T, D]`.
280
+ final_state (torch.Tensor):
281
+ Final state of shape `[N, D]` if `output_final_state=True` else `None`.
282
+
283
+ Examples::
284
+ >>> import torch
285
+ >>> import torch.nn.functional as F
286
+ >>> from einops import rearrange
287
+ >>> from fla.ops.hgrn import fused_recurrent_hgrn
288
+ # inputs with equal lengths
289
+ >>> B, T, D = 4, 2048, 512
290
+ >>> x = torch.randn(B, T, D, device='cuda')
291
+ >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
292
+ >>> h0 = torch.randn(B, D, device='cuda')
293
+ >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
294
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
295
+ >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g))
296
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
297
+ >>> cu_seqlens = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
298
+ >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens)
299
+ >>> assert o.allclose(o_var.view(o.shape))
300
+ >>> assert ht.allclose(ht_var)
301
+ """
302
+ return FusedRecurrentHGRNFunction.apply(
303
+ x,
304
+ g,
305
+ initial_state,
306
+ output_final_state,
307
+ cu_seqlens
308
+ )
fla/ops/hgrn/naive.py ADDED
@@ -0,0 +1,63 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
+ def naive_recurrent_hgrn(
9
+ x: torch.Tensor,
10
+ g: torch.Tensor,
11
+ initial_state: Optional[torch.Tensor] = None,
12
+ output_final_state: Optional[bool] = False
13
+ ) -> torch.Tensor:
14
+ dtype = x.dtype
15
+ x, g = map(lambda i: i.float(), (x, g))
16
+ B, T, D = x.shape
17
+
18
+ h = torch.zeros(B, D, dtype=torch.float, device=x.device)
19
+ o = torch.zeros_like(x)
20
+
21
+ final_state = None
22
+ if initial_state is not None:
23
+ h += initial_state
24
+
25
+ for i in range(T):
26
+ h = g[:, i].exp() * h + x[:, i]
27
+ o[:, i] = h
28
+
29
+ if output_final_state:
30
+ final_state = h
31
+ return o.to(dtype), final_state
32
+
33
+
34
+ def naive_chunk_hgrn(
35
+ x: torch.Tensor,
36
+ g: torch.Tensor,
37
+ initial_state: Optional[torch.Tensor] = None,
38
+ output_final_state: Optional[bool] = False,
39
+ chunk_size: int = 64
40
+ ) -> torch.Tensor:
41
+ dtype = x.dtype
42
+ x, g = map(lambda i: i.float(), (x, g))
43
+ B, T, D = x.shape
44
+
45
+ gc = g.view(B, -1, chunk_size, D).cumsum(-2).view_as(g)  # within-chunk cumulative gates
46
+ h = torch.zeros(B, D, dtype=torch.float, device=x.device)
47
+ o = torch.zeros_like(x)
48
+
49
+ final_state = None
50
+ if initial_state is not None:
51
+ h += initial_state
52
+
53
+ for i in range(0, T, chunk_size):
54
+ hp = h
55
+ h = torch.zeros(B, D, dtype=torch.float, device=x.device)
56
+ for j in range(i, i + chunk_size):
57
+ h = g[:, j].exp() * h + x[:, j]
58
+ o[:, j] = hp * gc[:, j].exp() + h
59
+ h = o[:, j].clone()
60
+
61
+ if output_final_state:
62
+ final_state = h
63
+ return o.to(dtype), final_state
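
A hedged cross-check of the Triton kernel against the recurrent reference above (assumes a CUDA device; shapes are illustrative):

import torch
import torch.nn.functional as F
from fla.ops.hgrn import fused_recurrent_hgrn
from fla.ops.hgrn.naive import naive_recurrent_hgrn

B, T, D = 2, 128, 64
x = torch.randn(B, T, D, device='cuda')
g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
h0 = torch.randn(B, D, device='cuda')

ref, ref_ht = naive_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
tri, tri_ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
assert ref.allclose(tri, atol=1e-4) and ref_ht.allclose(tri_ht, atol=1e-4)
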
fla/ops/lightning_attn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (352 Bytes). View file
 
fla/ops/lightning_attn/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (3.8 kB). View file
 
fla/ops/lightning_attn/chunk.py ADDED
@@ -0,0 +1,74 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.simple_gla.chunk import chunk_simple_gla
9
+
10
+
11
+ @torch.compiler.disable
12
+ def chunk_lightning_attn(
13
+ q: torch.Tensor,
14
+ k: torch.Tensor,
15
+ v: torch.Tensor,
16
+ layer_idx: int,
17
+ num_layers: int,
18
+ scale: Optional[float] = None,
19
+ initial_state: Optional[torch.Tensor] = None,
20
+ output_final_state: bool = False,
21
+ cu_seqlens: Optional[torch.LongTensor] = None,
22
+ head_first: bool = True
23
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
24
+ r"""
25
+ Args:
26
+ q (torch.Tensor):
27
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
28
+ k (torch.Tensor):
29
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
30
+ v (torch.Tensor):
31
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
32
+ layer_idx (int):
33
+ The index of the current layer.
34
+ num_layers (int):
35
+ The total number of layers. Both `layer_idx` and `num_layers` are used to compute the decay factor.
36
+ scale (Optional[int]):
37
+ Scale factor for the attention scores.
38
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
39
+ initial_state (Optional[torch.Tensor]):
40
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
41
+ For equal-length input sequences, `N` equals the batch size `B`.
42
+ Default: `None`.
43
+ output_final_state (Optional[bool]):
44
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
45
+ cu_seqlens (torch.LongTensor):
46
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
47
+ consistent with the FlashAttention API.
48
+ head_first (Optional[bool]):
49
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
50
+ Default: `True`.
51
+
52
+ Returns:
53
+ o (torch.Tensor):
54
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
55
+ final_state (torch.Tensor):
56
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
57
+ """
58
+ H = q.shape[1] if head_first else q.shape[2]
59
+ s = -(8 / H * (1 - layer_idx / num_layers)) * q.new_tensor(range(H), dtype=torch.float)
60
+ if head_first:
61
+ g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
62
+ else:
63
+ g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
64
+ return chunk_simple_gla(
65
+ q=q,
66
+ k=k,
67
+ v=v,
68
+ scale=scale,
69
+ g=g,
70
+ initial_state=initial_state,
71
+ output_final_state=output_final_state,
72
+ head_first=head_first,
73
+ cu_seqlens=cu_seqlens
74
+ )
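
The wrapper above only builds a per-head, time-constant log-decay `g` and delegates the actual computation to `chunk_simple_gla`. A small sketch of the decay rates it constructs (illustrative values; not part of the diff):

import torch

H, layer_idx, num_layers = 8, 0, 24
s = -(8 / H * (1 - layer_idx / num_layers)) * torch.arange(H, dtype=torch.float)
decay = s.exp()  # per-step decay factor of each head: head 0 keeps everything (1.0), later heads forget faster
print(decay)     # tensor([1.0000, 0.3679, 0.1353, ...])
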
fla/ops/lightning_attn/fused_recurrent.py ADDED
@@ -0,0 +1,75 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla
9
+
10
+
11
+ def fused_recurrent_lightning_attn(
12
+ q: torch.Tensor,
13
+ k: torch.Tensor,
14
+ v: torch.Tensor,
15
+ layer_idx: int,
16
+ num_layers: int,
17
+ scale: Optional[float] = None,
18
+ initial_state: Optional[torch.Tensor] = None,
19
+ output_final_state: bool = False,
20
+ reverse: bool = False,
21
+ cu_seqlens: Optional[torch.LongTensor] = None,
22
+ head_first: bool = True
23
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
24
+ r"""
25
+ Args:
26
+ q (torch.Tensor):
27
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
28
+ k (torch.Tensor):
29
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
30
+ v (torch.Tensor):
31
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
32
+ layer_idx (int):
33
+ The index of the current layer.
34
+ num_layers (int):
35
+ The total number of layers. Both `layer_idx` and `num_layers` are used to compute the decay factor.
36
+ scale (Optional[int]):
37
+ Scale factor for the attention scores.
38
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
39
+ initial_state (Optional[torch.Tensor]):
40
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
41
+ For equal-length input sequences, `N` equals the batch size `B`.
42
+ Default: `None`.
43
+ output_final_state (Optional[bool]):
44
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
45
+ cu_seqlens (torch.LongTensor):
46
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
47
+ consistent with the FlashAttention API.
48
+ head_first (Optional[bool]):
49
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
50
+ Default: `True`.
51
+
52
+ Returns:
53
+ o (torch.Tensor):
54
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
55
+ final_state (torch.Tensor):
56
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
57
+ """
58
+ H = q.shape[1] if head_first else q.shape[2]
59
+ s = -(8 / H * (1 - layer_idx / num_layers)) * q.new_tensor(range(H), dtype=torch.float)
60
+ if head_first:
61
+ g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
62
+ else:
63
+ g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
64
+ return fused_recurrent_simple_gla(
65
+ q=q,
66
+ k=k,
67
+ v=v,
68
+ g=g,
69
+ scale=scale,
70
+ initial_state=initial_state,
71
+ output_final_state=output_final_state,
72
+ reverse=reverse,
73
+ cu_seqlens=cu_seqlens,
74
+ head_first=head_first
75
+ )
fla/ops/linear_attn/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_linear_attn
4
+ from .fused_chunk import fused_chunk_linear_attn
5
+ from .fused_recurrent import fused_recurrent_linear_attn
6
+
7
+ __all__ = [
8
+ 'chunk_linear_attn',
9
+ 'fused_chunk_linear_attn',
10
+ 'fused_recurrent_linear_attn'
11
+ ]
fla/ops/linear_attn/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (427 Bytes). View file
 
fla/ops/linear_attn/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (2.62 kB). View file
 
fla/ops/linear_attn/__pycache__/fused_chunk.cpython-311.pyc ADDED
Binary file (18.8 kB). View file
 
fla/ops/linear_attn/__pycache__/utils.cpython-311.pyc ADDED
Binary file (583 Bytes). View file
 
fla/ops/linear_attn/fused_chunk.py ADDED
@@ -0,0 +1,318 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from packaging import version
10
+
11
+ from fla.ops.linear_attn.utils import normalize_output
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ @triton.jit
16
+ def fused_chunk_linear_attn_fwd_kernel(
17
+ q, # query [B, H, T, K]
18
+ k, # key [B, H, T, K]
19
+ v, # value [B, H, T, V]
20
+ o, # output [B, H, T, V]
21
+ h0,
22
+ ht,
23
+ scale,
24
+ B, # batch size
25
+ H, # H
26
+ T, # T
27
+ K: tl.constexpr, # K
28
+ V: tl.constexpr, # V
29
+ BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
30
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
31
+ BV: tl.constexpr, # BLOCK SIZE along the V dimension
32
+ USE_INITIAL_STATE: tl.constexpr,
33
+ STORE_FINAL_STATE: tl.constexpr,
34
+ CHECK: tl.constexpr
35
+ ):
36
+ # indices
37
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
38
+
39
+ o_i = tl.arange(0, BT)
40
+
41
+ # [BT, BT]
42
+ m_s = o_i[:, None] >= o_i[None, :]
43
+ # [BK, BV]
44
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
45
+
46
+ # make block pointers
47
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0))
48
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1))
49
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
50
+ p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0))
51
+
52
+ if USE_INITIAL_STATE:
53
+ p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
54
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
55
+
56
+ for i in range(0, tl.cdiv(T, BT)):
57
+ # [BT, BK]
58
+ b_q = tl.load(p_q, boundary_check=(0, 1))
59
+ b_q = (b_q * scale).to(b_q.dtype)
60
+ # [BK, BT]
61
+ b_k = tl.load(p_k, boundary_check=(0, 1))
62
+ # [BT, BV]
63
+ b_v = tl.load(p_v, boundary_check=(0, 1))
64
+
65
+ # [BT, BT]
66
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
67
+ b_s = tl.where(m_s, b_s, 0)
68
+ # [BT, BV]
69
+ b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
70
+ if CHECK and i == 0:
71
+ b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
72
+ b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
73
+ else:
74
+ b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
75
+ b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False)
76
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
77
+ p_q = tl.advance(p_q, (BT, 0))
78
+ p_k = tl.advance(p_k, (0, BT))
79
+ p_v = tl.advance(p_v, (BT, 0))
80
+ p_o = tl.advance(p_o, (BT, 0))
81
+
82
+ if STORE_FINAL_STATE:
83
+ p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
84
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
85
+
86
+
87
+ @triton.jit
88
+ def fused_chunk_linear_attn_bwd_kernel(
89
+ q, # query [B, H, T, K]
90
+ k, # key [B, H, T, V]
91
+ v, # value [B, H, T, V]
92
+ do, # gradient of output [B, H, T, V]
93
+ dq, # gradient of query [NV, B, H, T, K]
94
+ dk, # gradient of key [NV, B, H, T, K]
95
+ dv, # gradient of value [NK, B, H, T, V]
96
+ h0, # initial state of the chunk [B, H, K, V]
97
+ scale, # K ** -0.5
98
+ B, # B
99
+ H, # H
100
+ T, # T
101
+ K: tl.constexpr, # K
102
+ V: tl.constexpr, # V
103
+ BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size
104
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
105
+ BV: tl.constexpr, # BLOCK SIZE along the V dimension
106
+ USE_INITIAL_STATE: tl.constexpr,
107
+ CHECK: tl.constexpr
108
+ ):
109
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
110
+ o_i = tl.arange(0, BT)
111
+
112
+ m_s = o_i[:, None] >= o_i[None, :]
113
+ # [BV, BK]
114
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
115
+ if USE_INITIAL_STATE:
116
+ p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
117
+ b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
118
+
119
+ for i in range(0, tl.cdiv(T, BT)):
120
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0))
121
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1))
122
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0))
123
+ p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0))
124
+
125
+ # [BT, BK]
126
+ b_k = tl.load(p_k, boundary_check=(0, 1))
127
+ # [V, BT]
128
+ b_v = tl.load(p_v, boundary_check=(0, 1))
129
+ # [BT, V]
130
+ b_do = tl.load(p_do, boundary_check=(0, 1))
131
+
132
+ # [BT, BT]
133
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
134
+ b_ds = tl.where(m_s, b_ds, 0)
135
+ # [BT, BK]
136
+ b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)
137
+ # [BV, BK]
138
+ if CHECK and i == 0:
139
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
140
+ b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
141
+ else:
142
+ b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False)
143
+ b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False)
144
+ b_dq *= scale
145
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
146
+
147
+ # sync threads
148
+ b_h = None
149
+ tl.debug_barrier()
150
+ # [BK, BV]
151
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
152
+ m_s = o_i[:, None] <= o_i[None, :]
153
+ for i in range(1, tl.cdiv(T, BT) + 1):
154
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1))
155
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0))
156
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
157
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0))
158
+ p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0))
159
+ p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0))
160
+ # [BK, BT]
161
+ b_q = tl.load(p_q, boundary_check=(0, 1))
162
+ b_q = (b_q * scale).to(b_q.dtype)
163
+ # [BT, BK]
164
+ b_k = tl.load(p_k, boundary_check=(0, 1))
165
+ # [BT, BV]
166
+ b_v = tl.load(p_v, boundary_check=(0, 1))
167
+ b_do = tl.load(p_do, boundary_check=(0, 1))
168
+
169
+ # [BT, BT]
170
+ b_s = tl.dot(b_k, b_q, allow_tf32=False)
171
+ b_s = tl.where(m_s, b_s, 0).to(b_q.dtype)
172
+ # [BT, BT]
173
+ b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False)
174
+ b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype)
175
+ # [BT, BK]
176
+ b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False)
177
+ # [BT, BV]
178
+ b_dv = tl.dot(b_s, b_do, allow_tf32=False)
179
+ if CHECK and i == 1:
180
+ b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
181
+ b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
182
+ b_dh += tl.dot(b_q, b_do, allow_tf32=False)
183
+ else:
184
+ b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False)
185
+ b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
186
+ b_dh += tl.dot(b_q, b_do, allow_tf32=False)
187
+
188
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
189
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
190
+
191
+
192
+ class FusedChunkLinearAttentionFunction(torch.autograd.Function):
193
+
194
+ @staticmethod
195
+ @input_guard
196
+ @autocast_custom_fwd
197
+ def forward(ctx, q, k, v, scale, initial_state, output_final_state):
198
+ B, H, T, K, V = *k.shape, v.shape[-1]
199
+ BT = 64
200
+ BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
201
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
202
+ num_warps = 4
203
+ num_stages = 1
204
+
205
+ o = q.new_empty(NK, B, H, T, V)
206
+ final_state = q.new_empty(B, H, K, V, dtype=torch.float) if output_final_state else None
207
+ # the bug still exists even for Triton 2.2 on H100 GPUs
208
+ # so we always enable initial checks
209
+ CHECK = True
210
+ if version.parse(triton.__version__) < version.parse('2.2.0'):
211
+ import warnings
212
+ warnings.warn(
213
+ "Triton<2.2.0 detected for running this kernel, "
214
+ "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) "
215
+ "that lead to significant precision loss. "
216
+ "We've added some initial condition checks to resolve this, at the cost of some speed. "
217
+ "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
218
+ )
219
+ CHECK = True
220
+
221
+ grid = (NV, NK, B * H)
222
+ fused_chunk_linear_attn_fwd_kernel[grid](
223
+ q, k, v, o, initial_state, final_state,
224
+ scale,
225
+ B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
226
+ USE_INITIAL_STATE=initial_state is not None,
227
+ STORE_FINAL_STATE=output_final_state,
228
+ CHECK=CHECK,
229
+ num_warps=num_warps,
230
+ num_stages=num_stages
231
+ )
232
+ o = o.sum(0) if NK > 1 else o[0]
233
+
234
+ ctx.save_for_backward(q, k, v, initial_state)
235
+ ctx.scale = scale
236
+ ctx.CHECK = CHECK
237
+ return o.to(q.dtype), final_state
238
+
239
+ @staticmethod
240
+ @input_guard
241
+ @autocast_custom_bwd
242
+ def backward(ctx, do, dht=None):
243
+ q, k, v, initial_state = ctx.saved_tensors
244
+ B, H, T, K, V = *k.shape, v.shape[-1]
245
+ scale = ctx.scale
246
+
247
+ BT = 64
248
+ BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64)
249
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
250
+ num_warps = 4
251
+ num_stages = 1
252
+
253
+ dq = q.new_empty(NV, B, H, T, K)
254
+ dk = q.new_empty(NV, B, H, T, K)
255
+ dv = q.new_empty(NK, B, H, T, V)
256
+ grid = (NV, NK, B * H)
257
+
258
+ fused_chunk_linear_attn_bwd_kernel[grid](
259
+ q, k, v, do, dq, dk, dv, initial_state,
260
+ scale,
261
+ B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
262
+ USE_INITIAL_STATE=initial_state is not None,
263
+ CHECK=ctx.CHECK,
264
+ num_warps=num_warps,
265
+ num_stages=num_stages
266
+ )
267
+ dq = dq.sum(0)
268
+ dk = dk.sum(0)
269
+ dv = dv.sum(0)
270
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None
271
+
272
+
273
+ def fused_chunk_linear_attn(
274
+ q: torch.Tensor,
275
+ k: torch.Tensor,
276
+ v: torch.Tensor,
277
+ scale: Optional[float] = None,
278
+ initial_state: torch.Tensor = None,
279
+ output_final_state: bool = False,
280
+ normalize: bool = True,
281
+ head_first: bool = True
282
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
283
+ r"""
284
+ Args:
285
+ q (torch.Tensor):
286
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
287
+ k (torch.Tensor):
288
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
289
+ v (torch.Tensor):
290
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
291
+ scale (Optional[int]):
292
+ Scale factor for linear attention scores.
293
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
294
+ initial_state (Optional[torch.Tensor]):
295
+ Initial state of shape `[B, H, K, V]`. Default: `None`.
296
+ output_final_state (Optional[bool]):
297
+ Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
298
+ normalize (bool):
299
+ Whether to normalize the output. Default: `True`.
300
+ head_first (Optional[bool]):
301
+ Whether the inputs are in the head-first format. Default: `True`.
302
+
303
+ Returns:
304
+ o (torch.Tensor):
305
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
306
+ final_state (torch.Tensor):
307
+ Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`
308
+ """
309
+ if scale is None:
310
+ scale = q.shape[-1] ** -0.5
311
+ if not head_first:
312
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
313
+ o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
314
+ if normalize:
315
+ o = normalize_output(q * scale, k, o)
316
+ if not head_first:
317
+ o = o.transpose(1, 2)
318
+ return o, final_state
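
A minimal usage sketch for the wrapper above (not part of the commit): it assumes `fused_chunk_linear_attn` is re-exported from `fla.ops.linear_attn`, that a CUDA device is available for the Triton kernels, and that the shapes below are purely illustrative.

import torch
from fla.ops.linear_attn import fused_chunk_linear_attn  # assumed re-export; otherwise import from .fused_chunk

B, H, T, K, V = 2, 4, 1024, 64, 64
q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
# head-first inputs; returns the (optionally normalized) output and the final [B, H, K, V] state
o, final_state = fused_chunk_linear_attn(q, k, v, output_final_state=True)
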
fla/ops/nsa/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .naive import naive_nsa
4
+ from .parallel import parallel_nsa
5
+
6
+ __all__ = [
7
+ 'naive_nsa',
8
+ 'parallel_nsa'
9
+ ]
fla/ops/nsa/__pycache__/naive.cpython-311.pyc ADDED
Binary file (6.28 kB). View file
 
fla/ops/nsa/__pycache__/parallel.cpython-311.pyc ADDED
Binary file (70.1 kB). View file
 
fla/ops/nsa/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.99 kB). View file
 
fla/ops/nsa/naive.py ADDED
@@ -0,0 +1,94 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from einops import rearrange, repeat
8
+
9
+
10
+ def naive_nsa(
11
+ q: torch.Tensor,
12
+ k: torch.Tensor,
13
+ v: torch.Tensor,
14
+ indices: torch.LongTensor,
15
+ block_size: int = 64,
16
+ scale: Optional[float] = None,
17
+ head_first: bool = False,
18
+ cu_seqlens: Optional[torch.LongTensor] = None
19
+ ) -> torch.Tensor:
20
+ r"""
21
+ Args:
22
+ q (torch.Tensor):
23
+ queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`.
24
+ k (torch.Tensor):
25
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
26
+ GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
27
+ v (torch.Tensor):
28
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
29
+ indices (torch.LongTensor):
30
+ Block indices of shape `[B, H, T, S]` if `head_first=True` else `[B, T, H, S]`.
31
+ `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
32
+ block_size (int):
33
+ Selected block size. Default: 64.
34
+ scale (Optional[float]):
35
+ Scale factor for attention scores.
36
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
37
+ head_first (Optional[bool]):
38
+ Whether the inputs are in the head-first format. Default: `False`.
39
+ cu_seqlens (torch.LongTensor):
40
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
41
+ consistent with the FlashAttention API.
42
+
43
+ Returns:
44
+ o (torch.Tensor):
45
+ Outputs of shape `[B, HQ, T, V]` if `head_first=True` else `[B, T, HQ, V]`.
46
+ """
47
+ if scale is None:
48
+ scale = k.shape[-1] ** -0.5
49
+ if cu_seqlens is not None:
50
+ if head_first:
51
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
52
+ if head_first:
53
+ q, k, v, indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v, indices))
54
+
55
+ dtype = q.dtype
56
+ G = q.shape[2] // k.shape[2]
57
+ BS = block_size
58
+ k, v, indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, indices))
59
+ q, k, v = map(lambda x: x.float(), (q, k, v))
60
+
61
+ o = torch.zeros_like(v)
62
+ varlen = True
63
+ if cu_seqlens is None:
64
+ varlen = False
65
+ B, T = q.shape[:2]
66
+ cu_seqlens = torch.cat([indices.new_tensor(range(0, B*T, T)), indices.new_tensor([B*T])])
67
+
68
+ for i in range(len(cu_seqlens) - 1):
69
+ if not varlen:
70
+ q_b, k_b, v_b, i_b = q[i], k[i], v[i], indices[i]
71
+ else:
72
+ T = cu_seqlens[i+1] - cu_seqlens[i]
73
+ q_b, k_b, v_b, i_b = map(lambda x: x[0][cu_seqlens[i]:cu_seqlens[i+1]], (q, k, v, indices))
74
+
75
+ i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
76
+ # [T, S*BS, HQ]
77
+ i_b = i_b.view(T, indices.shape[2], -1).transpose(1, 2)
78
+ for i_q in range(T):
79
+ # [HQ, D]
80
+ q_i = q_b[i_q] * scale
81
+ # [S*BS, HQ]
82
+ i_i = i_b[i_q]
83
+ # [S*BS, HQ, -1]
84
+ k_i, v_i = map(lambda x: x.gather(0, i_i.clamp(0, T-1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
85
+ # [S*BS, HQ]
86
+ attn = torch.einsum('h d, n h d -> n h', q_i, k_i).masked_fill(i_i > i_q, float('-inf')).softmax(0)
87
+ if not varlen:
88
+ o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)
89
+ else:
90
+ o[0][cu_seqlens[i]+i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)
91
+
92
+ if head_first:
93
+ o = rearrange(o, 'b t h d -> b h t d')
94
+ return o.to(dtype)
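
An illustrative call of the reference implementation above (not part of the commit). The sizes and the random block indices are assumptions; the first slot is pinned to each token's own block so every query attends to at least one valid (non-future) position.

import torch
from fla.ops.nsa import naive_nsa

B, T, H, HQ, K, V, S, BS = 1, 256, 1, 16, 64, 64, 4, 64
q = torch.randn(B, T, HQ, K)
k = torch.randn(B, T, H, K)
v = torch.randn(B, T, H, V)
indices = torch.randint(0, T // BS, (B, T, H, S))
# ensure the block containing each query token is always among the selected blocks
indices[:, :, :, 0] = torch.arange(T).view(1, T, 1) // BS
o = naive_nsa(q, k, v, indices, block_size=BS)  # -> [B, T, HQ, V]
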
fla/ops/rebased/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (248 Bytes). View file
 
fla/ops/rebased/parallel.py ADDED
@@ -0,0 +1,466 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ import torch
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
10
+
11
+ # Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models
12
+ # https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py
13
+
14
+
15
+ @triton.jit(do_not_specialize=['T'])
16
+ def parallel_rebased_fwd_kernel(
17
+ q,
18
+ k,
19
+ v,
20
+ o,
21
+ z,
22
+ scale,
23
+ T,
24
+ B: tl.constexpr,
25
+ H: tl.constexpr,
26
+ K: tl.constexpr,
27
+ V: tl.constexpr,
28
+ BTL: tl.constexpr,
29
+ BTS: tl.constexpr,
30
+ BK: tl.constexpr,
31
+ BV: tl.constexpr,
32
+ ):
33
+ # i_c: chunk index. used for sequence parallelism
34
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
35
+ NV = tl.cdiv(V, BV)
36
+ i_k = i_kv // (NV)
37
+ i_v = i_kv % (NV)
38
+
39
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
40
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k*BK, 0), (BK, BTS), (0, 1))
41
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v*BV), (BTS, BV), (1, 0))
42
+
43
+ # [BQ, BD] block Q, in the shared memory throughout the whole kernel
44
+ b_q = tl.load(p_q, boundary_check=(0, 1))
45
+ b_q = (b_q * scale).to(b_q.dtype)
46
+ b_o = tl.zeros([BTL, BV], dtype=tl.float32)
47
+ b_z = tl.zeros([BTL], dtype=tl.float32)
48
+
49
+ # Q block and K block have no overlap
50
+ # no need for mask, thereby saving flops
51
+ for _ in range(0, i_c*BTL, BTS):
52
+ # [BK, BTS]
53
+ b_k = tl.load(p_k, boundary_check=(0, 1))
54
+
55
+ # [BTS, BV]
56
+ b_v = tl.load(p_v, boundary_check=(0, 1))
57
+ # [BTL, BTS]
58
+ b_s = tl.dot(b_q, (b_k), allow_tf32=False)
59
+ b_s = b_s * b_s
60
+ b_z += tl.sum(b_s, axis=1)
61
+
62
+ # [BQ, BD]
63
+ b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
64
+ p_k = tl.advance(p_k, (0, BTS))
65
+ p_v = tl.advance(p_v, (BTS, 0))
66
+
67
+ # # rescale interchunk output
68
+ tl.debug_barrier()
69
+ o_q = tl.arange(0, BTL)
70
+ # # sync threads, easy for compiler to optimize
71
+ # tl.debug_barrier()
72
+
73
+ o_k = tl.arange(0, BTS)
74
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k*BK, i_c*BTL), (BK, BTS), (0, 1))
75
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTS, BV), (1, 0))
76
+ # Q block and K block have overlap. masks required
77
+ for _ in range(i_c*BTL, (i_c + 1) * BTL, BTS):
78
+ # [BK, BTS]
79
+ b_k = tl.load(p_k, boundary_check=(0, 1))
80
+ # [BTS, BV]
81
+ b_v = tl.load(p_v, boundary_check=(0, 1))
82
+ # [BTL, BTS]
83
+ m_s = o_q[:, None] >= o_k[None, :]
84
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
85
+ b_s = b_s * b_s
86
+ b_s = tl.where(m_s, b_s, 0)
87
+ b_z += tl.sum(b_s, axis=1)
88
+ # [BTL, BV]
89
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
90
+ p_k = tl.advance(p_k, (0, BTS))
91
+ p_v = tl.advance(p_v, (BTS, 0))
92
+ o_k += BTS
93
+
94
+ p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
95
+ p_z = z + (i_bh + B * H * i_k) * T + i_c*BTL + tl.arange(0, BTL)
96
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
97
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c*BTL + tl.arange(0, BTL)) < T))
98
+
99
+
100
+ @triton.jit(do_not_specialize=['T'])
101
+ def _parallel_rebased_bwd_dq(
102
+ i_bh,
103
+ i_c,
104
+ i_k,
105
+ i_v,
106
+ i_h,
107
+ q,
108
+ k,
109
+ v,
110
+ do,
111
+ dz,
112
+ dq,
113
+ scale,
114
+ T,
115
+ B: tl.constexpr,
116
+ H: tl.constexpr,
117
+ K: tl.constexpr,
118
+ V: tl.constexpr,
119
+ BTL: tl.constexpr,
120
+ BTS: tl.constexpr,
121
+ BK: tl.constexpr,
122
+ BV: tl.constexpr
123
+ ):
124
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
125
+ p_q = tl.make_block_ptr(q + (i_bh) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
126
+ b_q = tl.load(p_q, boundary_check=(0, 1))
127
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
128
+ b_q = (b_q * scale).to(b_q.dtype)
129
+ b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
130
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k*BK), (BTS, BK), (1, 0))
131
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v*BV, 0), (BV, BTS), (0, 1))
132
+ p_dz = dz + i_bh * T + i_c*BTL + tl.arange(0, BTL)
133
+ b_dz = tl.load(p_dz, mask=(i_c*BTL + tl.arange(0, BTL)) < T)
134
+
135
+ for _ in range(0, i_c*BTL, BTS):
136
+ # [BTS, BK]
137
+ b_k = tl.load(p_k, boundary_check=(0, 1))
138
+ # [BV, BTS]
139
+ b_v = tl.load(p_v, boundary_check=(0, 1))
140
+ # [BTL, BTS]
141
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
142
+ if i_v == 0:
143
+ b_ds += b_dz[:, None]
144
+ else:
145
+ b_ds = b_ds
146
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
147
+ # [BQ, BD]
148
+ b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False)
149
+ p_k = tl.advance(p_k, (BTS, 0))
150
+ p_v = tl.advance(p_v, (0, BTS))
151
+
152
+ b_dq *= scale
153
+ o_q = tl.arange(0, BTL)
154
+ o_k = tl.arange(0, BTS)
155
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTS, BK), (1, 0))
156
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v*BV, i_c*BTL), (BV, BTS), (0, 1))
157
+ # Q block and K block have overlap. masks required
158
+ for _ in range(i_c*BTL, (i_c + 1) * BTL, BTS):
159
+ # [BTS, BK]
160
+ b_k = tl.load(p_k, boundary_check=(0, 1))
161
+ # [BV, BTS]
162
+ b_v = tl.load(p_v, boundary_check=(0, 1))
163
+ # [BTL, BTS]
164
+ m_s = o_q[:, None] >= o_k[None, :]
165
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
166
+ if i_v == 0:
167
+ b_ds += b_dz[:, None]
168
+ else:
169
+ b_ds = b_ds
170
+ b_ds = tl.where(m_s, b_ds, 0) * scale
171
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
172
+ b_s = tl.where(m_s, b_s, 0)
173
+ # [BTL, BK]
174
+ b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype),
175
+ b_k, allow_tf32=False)
176
+ p_k = tl.advance(p_k, (BTS, 0))
177
+ p_v = tl.advance(p_v, (0, BTS))
178
+ o_k += BTS
179
+ p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
180
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
181
+ return
182
+
183
+
184
+ @triton.jit(do_not_specialize=['T'])
185
+ def _parallel_rebased_bwd_dkv(
186
+ i_bh,
187
+ i_c,
188
+ i_k,
189
+ i_v,
190
+ i_h,
191
+ q,
192
+ k,
193
+ v,
194
+ do,
195
+ dz,
196
+ dk,
197
+ dv,
198
+ scale,
199
+ T,
200
+ B: tl.constexpr,
201
+ H: tl.constexpr,
202
+ K: tl.constexpr,
203
+ V: tl.constexpr,
204
+ BTL: tl.constexpr,
205
+ BTS: tl.constexpr,
206
+ BK: tl.constexpr,
207
+ BV: tl.constexpr,
208
+ ):
209
+ # compute dk dv
210
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
211
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
212
+ b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1))
213
+ b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
214
+ [BTL, BV], dtype=tl.float32)
215
+
216
+ for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
217
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k*BK, i), (BK, BTS), (0, 1))
218
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v*BV, i), (BV, BTS), (0, 1))
219
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
220
+ # [BK, BTS]
221
+ b_q = tl.load(p_q, boundary_check=(0, 1))
222
+ # [BV, BTS]
223
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
224
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
225
+ # [BTL, BTS]
226
+ b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale
227
+ b_s2 = b_s * b_s
228
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
229
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
230
+ if i_v == 0:
231
+ b_ds += b_dz[None, :] * scale
232
+ else:
233
+ b_ds = b_ds
234
+ b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
235
+
236
+ tl.debug_barrier()
237
+ o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
238
+ for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
239
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k*BK, i), (BK, BTS), (0, 1))
240
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v*BV, i), (BV, BTS), (0, 1))
241
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
242
+ b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
243
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
244
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
245
+ # [BK, BQ]
246
+ m_s = o_k[:, None] <= o_q[None, :]
247
+ b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
248
+ b_s2 = b_s * b_s
249
+ b_s = tl.where(m_s, b_s, 0)
250
+ b_s2 = tl.where(m_s, b_s2, 0)
251
+
252
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False)
253
+ if i_v == 0:
254
+ b_ds += b_dz[None, :]
255
+ else:
256
+ b_ds = b_ds
257
+ b_ds = tl.where(m_s, b_ds, 0) * scale
258
+ # [BK, BD]
259
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
260
+ b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
261
+ o_q += BTS
262
+
263
+ p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
264
+ p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
265
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
266
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
267
+ return
268
+
269
+
270
+ @triton.jit(do_not_specialize=['T'])
271
+ def parallel_rebased_bwd_kernel(
272
+ q,
273
+ k,
274
+ v,
275
+ do,
276
+ dz,
277
+ dq,
278
+ dk,
279
+ dv,
280
+ scale,
281
+ T,
282
+ B: tl.constexpr,
283
+ H: tl.constexpr,
284
+ K: tl.constexpr,
285
+ V: tl.constexpr,
286
+ BTL: tl.constexpr,
287
+ BTS: tl.constexpr,
288
+ BK: tl.constexpr,
289
+ BV: tl.constexpr
290
+ ):
291
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
292
+ NV = tl.cdiv(V, BV)
293
+ i_k = i_kv // (NV)
294
+ i_v = i_kv % (NV)
295
+ i_h = i_bh % H
296
+ _parallel_rebased_bwd_dq(
297
+ i_bh,
298
+ i_c,
299
+ i_k,
300
+ i_v,
301
+ i_h,
302
+ q,
303
+ k,
304
+ v,
305
+ do,
306
+ dz,
307
+ dq,
308
+ scale,
309
+ B=B,
310
+ H=H,
311
+ T=T,
312
+ K=K,
313
+ V=V,
314
+ BTL=BTL,
315
+ BTS=BTS,
316
+ BK=BK,
317
+ BV=BV
318
+ )
319
+ tl.debug_barrier()
320
+ _parallel_rebased_bwd_dkv(
321
+ i_bh,
322
+ i_c,
323
+ i_k,
324
+ i_v,
325
+ i_h,
326
+ q,
327
+ k,
328
+ v,
329
+ do,
330
+ dz,
331
+ dk,
332
+ dv,
333
+ scale,
334
+ B=B,
335
+ H=H,
336
+ T=T,
337
+ K=K,
338
+ V=V,
339
+ BTL=BTL,
340
+ BTS=BTS,
341
+ BK=BK,
342
+ BV=BV
343
+ )
344
+
345
+
346
+ class ParallelBasedFunction(torch.autograd.Function):
347
+
348
+ @staticmethod
349
+ @input_guard
350
+ @autocast_custom_fwd
351
+ def forward(ctx, q, k, v, scale):
352
+ BTL, BTS = 128, 32
353
+ assert BTL % BTS == 0
354
+ # assert q.shape[-1] % 16 == 0
355
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
356
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
357
+ BK, BV = max(BK, 16), max(BV, 16)
358
+ B, H, T, K, V = *k.shape, v.shape[-1]
359
+ num_stages = 2
360
+ num_warps = 4
361
+ NK = triton.cdiv(K, BK)
362
+ NV = triton.cdiv(V, BV)
363
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
364
+
365
+ assert NK == 1, "will encounter some synchronization issue if not."
366
+
367
+ o = torch.empty(NK, B, H, T, V, device=q.device)
368
+ z = torch.empty(NK, B, H, T, device=q.device)
369
+ parallel_rebased_fwd_kernel[grid](
370
+ q,
371
+ k,
372
+ v,
373
+ o,
374
+ z,
375
+ scale,
376
+ T=T,
377
+ B=B,
378
+ H=H,
379
+ K=K,
380
+ V=V,
381
+ BTL=BTL,
382
+ BTS=BTS,
383
+ BK=BK,
384
+ BV=BV,
385
+ num_warps=num_warps,
386
+ num_stages=num_stages
387
+ )
388
+ ctx.save_for_backward(q, k, v)
389
+ ctx.scale = scale
390
+ return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
391
+
392
+ @staticmethod
393
+ @input_guard
394
+ @autocast_custom_bwd
395
+ def backward(ctx, do, dz):
396
+ q, k, v = ctx.saved_tensors
397
+ scale = ctx.scale
398
+ BTL, BTS = 64, 32
399
+ assert BTL % BTS == 0
400
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
401
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
402
+ BK, BV = max(BK, 16), max(BV, 16)
403
+ B, H, T, K, V = *k.shape, v.shape[-1]
404
+ num_stages = 2
405
+ num_warps = 4
406
+ NK = triton.cdiv(K, BK)
407
+ NV = triton.cdiv(V, BV)
408
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
409
+
410
+ assert NK == 1, "will encounter some synchronization issue if not"
411
+
412
+ dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
413
+ dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
414
+ dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)
415
+
416
+ parallel_rebased_bwd_kernel[grid](
417
+ q,
418
+ k,
419
+ v,
420
+ do,
421
+ dz,
422
+ dq,
423
+ dk,
424
+ dv,
425
+ scale,
426
+ T=T,
427
+ B=B,
428
+ H=H,
429
+ K=K,
430
+ V=V,
431
+ BTL=BTL,
432
+ BTS=BTS,
433
+ BK=BK,
434
+ BV=BV,
435
+ num_warps=num_warps,
436
+ num_stages=num_stages
437
+ )
438
+
439
+ return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
440
+
441
+
442
+ def parallel_rebased(
443
+ q: torch.Tensor,
444
+ k: torch.Tensor,
445
+ v: torch.Tensor,
446
+ eps: float = 1e-5,
447
+ use_scale: bool = True,
448
+ use_normalize: bool = True,
449
+ return_both: bool = False,
450
+ head_first: bool = True
451
+ ):
452
+ assert q.shape[-1] <= 128, "only support feature dim up to 128"
453
+ if use_scale:
454
+ scale = q.shape[-1] ** -0.5
455
+ else:
456
+ scale = 1
457
+ if not head_first:
458
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
459
+ o, z = ParallelBasedFunction.apply(q, k, v, scale)
460
+ if return_both:
461
+ return o, z
462
+ if use_normalize:
463
+ o = o / (z[..., None] + eps)
464
+ if not head_first:
465
+ o = o.transpose(1, 2)
466
+ return o.to(q.dtype)
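
A hedged usage sketch (not part of the commit): the kernel path requires CUDA, and since `BK = min(128, next_power_of_2(K))` must cover the whole key dimension (`NK == 1`), `K` is limited to 128 as asserted above. Sizes below are illustrative.

import torch
from fla.ops.rebased.parallel import parallel_rebased

B, H, T, K, V = 2, 4, 512, 16, 64
q = torch.randn(B, H, T, K, device='cuda')
k = torch.randn(B, H, T, K, device='cuda')
v = torch.randn(B, H, T, V, device='cuda')
# squared dot-product feature map; the accumulated z term normalizes the output when use_normalize=True
o = parallel_rebased(q, k, v, use_scale=True, use_normalize=True)
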
fla/ops/retention/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_retention
4
+ from .fused_chunk import fused_chunk_retention
5
+ from .fused_recurrent import fused_recurrent_retention
6
+ from .parallel import parallel_retention
7
+
8
+ __all__ = [
9
+ 'chunk_retention',
10
+ 'fused_chunk_retention',
11
+ 'parallel_retention',
12
+ 'fused_recurrent_retention'
13
+ ]
fla/ops/retention/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (3.7 kB). View file
 
fla/ops/retention/__pycache__/parallel.cpython-311.pyc ADDED
Binary file (3.25 kB). View file
 
fla/ops/retention/chunk.py ADDED
@@ -0,0 +1,72 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.simple_gla.chunk import chunk_simple_gla
9
+
10
+
11
+ @torch.compiler.disable
12
+ def chunk_retention(
13
+ q: torch.Tensor,
14
+ k: torch.Tensor,
15
+ v: torch.Tensor,
16
+ scale: Optional[float] = None,
17
+ initial_state: Optional[torch.Tensor] = None,
18
+ output_final_state: bool = False,
19
+ cu_seqlens: Optional[torch.LongTensor] = None,
20
+ head_first: bool = True
21
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
22
+ r"""
23
+ Args:
24
+ q (torch.Tensor):
25
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
26
+ k (torch.Tensor):
27
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
28
+ v (torch.Tensor):
29
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
30
+ scale (Optional[int]):
31
+ Scale factor for the attention scores.
32
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
33
+ initial_state (Optional[torch.Tensor]):
34
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
35
+ For equal-length input sequences, `N` equals the batch size `B`.
36
+ Default: `None`.
37
+ output_final_state (Optional[bool]):
38
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
39
+ cu_seqlens (torch.LongTensor):
40
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
41
+ consistent with the FlashAttention API.
42
+ head_first (Optional[bool]):
43
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
44
+ Default: `True`.
45
+
46
+ Returns:
47
+ o (torch.Tensor):
48
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
49
+ final_state (torch.Tensor):
50
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
51
+
52
+ """
53
+ if head_first:
54
+ n_heads = q.shape[1]
55
+ else:
56
+ n_heads = q.shape[2]
57
+ s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log()
58
+ if head_first:
59
+ g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
60
+ else:
61
+ g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
62
+ return chunk_simple_gla(
63
+ q=q,
64
+ k=k,
65
+ v=v,
66
+ scale=scale,
67
+ g=g,
68
+ initial_state=initial_state,
69
+ output_final_state=output_final_state,
70
+ head_first=head_first,
71
+ cu_seqlens=cu_seqlens
72
+ )
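
An illustrative call (not part of the commit): the per-head retention decay is derived internally as `log(1 - 2^(-5-h))`, so only `q`/`k`/`v` are required. Shapes and dtypes are assumptions.

import torch
from fla.ops.retention import chunk_retention

B, H, T, K, V = 2, 8, 1024, 64, 64
q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
o, final_state = chunk_retention(q, k, v, output_final_state=True)  # final_state: [B, H, K, V]
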
fla/ops/retention/fused_recurrent.py ADDED
@@ -0,0 +1,42 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla
9
+
10
+
11
+ def fused_recurrent_retention(
12
+ q: torch.Tensor,
13
+ k: torch.Tensor,
14
+ v: torch.Tensor,
15
+ scale: Optional[float] = None,
16
+ initial_state: Optional[torch.Tensor] = None,
17
+ output_final_state: bool = False,
18
+ reverse: bool = False,
19
+ cu_seqlens: Optional[torch.LongTensor] = None,
20
+ head_first: bool = True
21
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
22
+ if head_first:
23
+ n_heads = q.shape[1]
24
+ else:
25
+ n_heads = q.shape[2]
26
+ s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log()
27
+ if head_first:
28
+ g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
29
+ else:
30
+ g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
31
+ return fused_recurrent_simple_gla(
32
+ q=q,
33
+ k=k,
34
+ v=v,
35
+ g=g,
36
+ scale=scale,
37
+ initial_state=initial_state,
38
+ output_final_state=output_final_state,
39
+ reverse=reverse,
40
+ cu_seqlens=cu_seqlens,
41
+ head_first=head_first
42
+ )
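
A hedged sketch of the typical decoding pattern (not part of the commit): one token per step, threading the recurrent state back in through `initial_state`. Sizes are illustrative.

import torch
from fla.ops.retention import fused_recurrent_retention

B, H, K, V = 2, 8, 64, 64
state = None
for _ in range(4):  # four single-token decoding steps
    q = torch.randn(B, H, 1, K, device='cuda')
    k = torch.randn(B, H, 1, K, device='cuda')
    v = torch.randn(B, H, 1, V, device='cuda')
    o, state = fused_recurrent_retention(q, k, v, initial_state=state, output_final_state=True)
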
fla/ops/retention/naive.py ADDED
@@ -0,0 +1,15 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+
5
+
6
+ def naive_retention(q, k, v):
7
+ orig_type = q.dtype
8
+ q, k, v = q.float(), k.float(), v.float()
9
+ _, n_heads, seq_len, d_head = q.shape
10
+ s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2()
11
+ n = q.new_tensor(range(seq_len), dtype=torch.float)
12
+ n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n)
13
+ s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype))
14
+ o = torch.einsum('bhqk,bhkd->bhqd', s, v)
15
+ return o.to(orig_type)
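
A hedged sketch of how this naive O(T^2) reference might be used to spot-check the chunked kernel on small inputs (not part of the commit); the comparison below is illustrative rather than a project test.

import torch
from fla.ops.retention import chunk_retention
from fla.ops.retention.naive import naive_retention

B, H, T, D = 1, 4, 128, 64
q, k, v = (torch.randn(B, H, T, D, device='cuda') for _ in range(3))
ref = naive_retention(q, k, v)
out, _ = chunk_retention(q, k, v)
print((ref - out).abs().max())  # expected to be small
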
fla/ops/rwkv4/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .fused_recurrent import fused_recurrent_rwkv4
4
+
5
+ __all__ = [
6
+ 'fused_recurrent_rwkv4'
7
+ ]
fla/ops/rwkv4/fused_recurrent.py ADDED
@@ -0,0 +1,476 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Any, cast
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from torch import Tensor
10
+ from torch.autograd.function import Function, FunctionCtx, once_differentiable
11
+
12
+ from fla.ops.utils.op import exp
13
+
14
+
15
+ def get_block_size_c(chans: int) -> int:
16
+ if chans < 32:
17
+ return 32
18
+ if chans < 64:
19
+ return 64
20
+ return 128
21
+
22
+
23
+ @triton.jit
24
+ def fused_recurrent_rwkv4_forward_kernel(
25
+ # W
26
+ w_ptr,
27
+ w_s_c,
28
+ # U
29
+ u_ptr,
30
+ u_s_c,
31
+ # K
32
+ k_ptr,
33
+ k_s_b,
34
+ k_s_t,
35
+ k_s_c,
36
+ # V
37
+ v_ptr,
38
+ v_s_b,
39
+ v_s_t,
40
+ v_s_c,
41
+ # State
42
+ state_ptr,
43
+ state_s_b,
44
+ state_s_abe,
45
+ state_s_c,
46
+ # WKV
47
+ wkv_ptr,
48
+ wkv_s_b,
49
+ wkv_s_t,
50
+ wkv_s_c,
51
+ # Output state
52
+ state_out_ptr,
53
+ state_out_s_b,
54
+ state_out_s_abe,
55
+ state_out_s_t,
56
+ state_out_s_c,
57
+ # Params
58
+ chans,
59
+ tsz,
60
+ BLOCK_SIZE_C: tl.constexpr,
61
+ ):
62
+ # Parallelize over the batch dimension.
63
+ b_idx = tl.program_id(0)
64
+ c_idx = tl.program_id(1)
65
+
66
+ cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
67
+ cmask = cs < chans
68
+
69
+ # Pointers to the batch (and possibly channel) for the input tensors.
70
+ k_ptr = k_ptr + b_idx * k_s_b
71
+ v_ptr = v_ptr + b_idx * v_s_b
72
+ alpha_ptr = state_ptr + b_idx * state_s_b
73
+ beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
74
+ eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe
75
+
76
+ # Pointers to the batch (and possibly channel) for the output tensors.
77
+ wkv_ptr = wkv_ptr + b_idx * wkv_s_b
78
+ alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b
79
+ beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe
80
+ eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe
81
+
82
+ # Loads parameters.
83
+ alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
84
+ beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
85
+ eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32)
86
+ w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32)
87
+ u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32)
88
+
89
+ for t in range(tsz):
90
+ kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32)
91
+ vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32)
92
+
93
+ ukt = u + kt
94
+ tau = tl.maximum(ukt, eps)
95
+ e1a = exp(eps - tau)
96
+ e2a = exp(ukt - tau)
97
+ wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a)
98
+ tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask)
99
+
100
+ w_eps = w + eps
101
+ eps = tl.maximum(w_eps, kt)
102
+ e1b = exp(w_eps - eps)
103
+ e2b = exp(kt - eps)
104
+ alpha = e1b * alpha + e2b * vt
105
+ beta = e1b * beta + e2b
106
+ tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask)
107
+ tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask)
108
+ tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask)
109
+
110
+
111
+ def fused_recurrent_rwkv4_forward(
112
+ w: Tensor,
113
+ u: Tensor,
114
+ k: Tensor,
115
+ v: Tensor,
116
+ state: Tensor,
117
+ ) -> tuple[Tensor, Tensor]:
118
+ (bsz, tsz, chans) = k.shape
119
+
120
+ # New tensors to output.
121
+ wkvs = k.new_empty(bsz, tsz, chans)
122
+ state_out = k.new_empty(bsz, 3, tsz, chans)
123
+
124
+ # Constants.
125
+ block_size_c = get_block_size_c(chans)
126
+
127
+ def grid(meta: dict[str, Any]) -> tuple[int, ...]:
128
+ return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))
129
+
130
+ fused_recurrent_rwkv4_forward_kernel[grid](
131
+ # W
132
+ w,
133
+ w.stride(0),
134
+ # U
135
+ u,
136
+ u.stride(0),
137
+ # K
138
+ k,
139
+ k.stride(0),
140
+ k.stride(1),
141
+ k.stride(2),
142
+ # V
143
+ v,
144
+ v.stride(0),
145
+ v.stride(1),
146
+ v.stride(2),
147
+ # State
148
+ state,
149
+ state.stride(0),
150
+ state.stride(1),
151
+ state.stride(3),
152
+ # WKV
153
+ wkvs,
154
+ wkvs.stride(0),
155
+ wkvs.stride(1),
156
+ wkvs.stride(2),
157
+ # Output state
158
+ state_out,
159
+ state_out.stride(0),
160
+ state_out.stride(1),
161
+ state_out.stride(2),
162
+ state_out.stride(3),
163
+ # Params
164
+ chans,
165
+ tsz,
166
+ BLOCK_SIZE_C=block_size_c,
167
+ )
168
+
169
+ state_out = torch.cat((state, state_out), dim=2)
170
+
171
+ return wkvs, state_out
172
+
173
+
174
+ @triton.jit
175
+ def fused_recurrent_rwkv4_backward_kernel(
176
+ # W
177
+ w_ptr,
178
+ w_s_c,
179
+ # U
180
+ u_ptr,
181
+ u_s_c,
182
+ # K
183
+ k_ptr,
184
+ k_s_b,
185
+ k_s_t,
186
+ k_s_c,
187
+ # V
188
+ v_ptr,
189
+ v_s_b,
190
+ v_s_t,
191
+ v_s_c,
192
+ # State
193
+ state_ptr,
194
+ state_s_b,
195
+ state_s_abe,
196
+ state_s_t,
197
+ state_s_c,
198
+ # WKV grad
199
+ gwkv_ptr,
200
+ gwkv_s_b,
201
+ gwkv_s_t,
202
+ gwkv_s_c,
203
+ # Output state grad
204
+ gstate_out_ptr,
205
+ gstate_out_s_b,
206
+ gstate_out_s_abe,
207
+ gstate_out_s_c,
208
+ # W grad
209
+ gw_ptr,
210
+ gw_s_c,
211
+ # U grad
212
+ gu_ptr,
213
+ gu_s_c,
214
+ # K grad
215
+ gk_ptr,
216
+ gk_s_b,
217
+ gk_s_t,
218
+ gk_s_c,
219
+ # V grad
220
+ gv_ptr,
221
+ gv_s_b,
222
+ gv_s_t,
223
+ gv_s_c,
224
+ # State grad
225
+ gstate_ptr,
226
+ gstate_s_b,
227
+ gstate_s_abe,
228
+ gstate_s_c,
229
+ # Params
230
+ tsz,
231
+ chans,
232
+ BLOCK_SIZE_C: tl.constexpr,
233
+ ):
234
+ # Parallelize over the batch dimension.
235
+ b_idx = tl.program_id(0)
236
+ c_idx = tl.program_id(1)
237
+
238
+ cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C)
239
+ cmask = cs < chans
240
+
241
+ # Pointers to the batch (and possibly channel) for the input tensors.
242
+ k_ptr = k_ptr + b_idx * k_s_b
243
+ v_ptr = v_ptr + b_idx * v_s_b
244
+ alpha_ptr = state_ptr + b_idx * state_s_b
245
+ beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe
246
+ eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe
247
+
248
+ # Pointers to the batch (and possibly channel) for the output tensors.
249
+ gk_ptr = gk_ptr + b_idx * gk_s_b
250
+ gv_ptr = gv_ptr + b_idx * gv_s_b
251
+
252
+ # Pointers to gradients which were recieved by the function.
253
+ gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b
254
+ galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b
255
+ gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe
256
+ geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe
257
+
258
+ # Loads parameters.
259
+ galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
260
+ gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
261
+ geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32)
262
+ w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32)
263
+ u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32)
264
+
265
+ # Gradient accumulators.
266
+ gw = tl.zeros_like(w)
267
+ gu = tl.zeros_like(u)
268
+
269
+ alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
270
+ beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
271
+ eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
272
+
273
+ for t in range(tsz):
274
+ tc = tsz - t - 1
275
+
276
+ kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32)
277
+ vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32)
278
+
279
+ alpha_curr = alpha_prev
280
+ beta_curr = beta_prev
281
+ eps_curr = eps_prev
282
+
283
+ alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
284
+ beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
285
+ eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32)
286
+
287
+ ukt = u + kt
288
+ tau = tl.maximum(ukt, eps_prev)
289
+ e1 = exp(eps_prev - tau)
290
+ e2 = exp(ukt - tau)
291
+
292
+ euke = exp(ukt + eps_prev - 2 * tau)
293
+
294
+ denom = e1 * beta_prev + e2
295
+ denom_sq = denom * denom
296
+
297
+ gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32)
298
+
299
+ # Backpropagates wkv gradients.
300
+ guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq
301
+ gu += guk
302
+ gk = guk
303
+ gv = gwkvt * e2 / denom
304
+
305
+ galpha_wkv = gwkvt * e1 / denom
306
+ gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq
307
+ geps_wkv_denom = e1 * beta_prev + e2
308
+ geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom)
309
+
310
+ e1 = exp(w + eps_prev - eps_curr)
311
+ e2 = exp(kt - eps_curr)
312
+
313
+ # Backpropagates alpha gradients.
314
+ galpha_we = galpha * e1 * alpha_prev
315
+ gw += galpha_we
316
+ gk += galpha * e2 * vt
317
+ gv += galpha * e2
318
+ geps += galpha * -alpha_curr
319
+
320
+ # Backpropagates beta gradients.
321
+ gbeta_we = gbeta * e1 * beta_prev
322
+ gw += gbeta_we
323
+ gk += gbeta * e2
324
+ geps += gbeta * -beta_curr
325
+
326
+ # Backpropagates epsilon gradients.
327
+ geps_mask = w + eps_prev > kt
328
+ geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps))
329
+ gw += geps_we
330
+ gk += tl.where(geps_mask, tl.zeros_like(geps), geps)
331
+
332
+ # Stores the gradients for k and v.
333
+ tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask)
334
+ tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask)
335
+
336
+ # Computes new gradients for alpha and beta.
337
+ galpha = galpha * e1 + galpha_wkv
338
+ gbeta = gbeta * e1 + gbeta_wkv
339
+ geps = galpha_we + gbeta_we + geps_we + geps_wkv
340
+
341
+ # Stores final gradients for alpha and beta.
342
+ galpha_ptr = gstate_ptr + b_idx * gstate_s_b
343
+ gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe
344
+ geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe
345
+ tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask)
346
+ tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask)
347
+ tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask)
348
+
349
+ # Stores final gradients for w and u.
350
+ gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32)
351
+ gw_temp += gw
352
+ tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask)
353
+ gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32)
354
+ gu_temp += gu
355
+ tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask)
356
+
357
+
358
+ def fused_recurrent_rwkv4_backward(
359
+ w: Tensor,
360
+ u: Tensor,
361
+ k: Tensor,
362
+ v: Tensor,
363
+ state: Tensor,
364
+ grad_wkv: Tensor,
365
+ grad_state: Tensor,
366
+ ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
367
+ bsz, tsz, chans = k.shape
368
+
369
+ gw = torch.zeros_like(w) # New tensors to output.
370
+ gu = torch.zeros_like(u)
371
+ gk = torch.empty_like(k)
372
+ gv = torch.empty_like(v)
373
+ gstate = k.new_empty(bsz, 3, 1, chans)
374
+
375
+ block_size_c = get_block_size_c(chans) # Constants.
376
+
377
+ def grid(meta: dict[str, Any]) -> tuple[int, ...]:
378
+ return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"]))
379
+
380
+ fused_recurrent_rwkv4_backward_kernel[grid](
381
+ # W
382
+ w,
383
+ w.stride(0),
384
+ # U
385
+ u,
386
+ u.stride(0),
387
+ # K
388
+ k,
389
+ k.stride(0),
390
+ k.stride(1),
391
+ k.stride(2),
392
+ # V
393
+ v,
394
+ v.stride(0),
395
+ v.stride(1),
396
+ v.stride(2),
397
+ # State
398
+ state,
399
+ state.stride(0),
400
+ state.stride(1),
401
+ state.stride(2),
402
+ state.stride(3),
403
+ # WKV grad
404
+ grad_wkv,
405
+ grad_wkv.stride(0),
406
+ grad_wkv.stride(1),
407
+ grad_wkv.stride(2),
408
+ # Output state grad
409
+ grad_state,
410
+ grad_state.stride(0),
411
+ grad_state.stride(1),
412
+ grad_state.stride(3),
413
+ # W grad
414
+ gw,
415
+ gw.stride(0),
416
+ # U grad
417
+ gu,
418
+ gu.stride(0),
419
+ # K grad
420
+ gk,
421
+ gk.stride(0),
422
+ gk.stride(1),
423
+ gk.stride(2),
424
+ # V grad
425
+ gv,
426
+ gv.stride(0),
427
+ gv.stride(1),
428
+ gv.stride(2),
429
+ # State grad
430
+ gstate,
431
+ gstate.stride(0),
432
+ gstate.stride(1),
433
+ gstate.stride(3),
434
+ # Params
435
+ tsz,
436
+ chans,
437
+ BLOCK_SIZE_C=block_size_c,
438
+ )
439
+
440
+ return gw, gu, gk, gv, gstate
441
+
442
+
443
+ class FusedRecurrentRWKV4Function(Function):
444
+ @staticmethod
445
+ def forward(
446
+ ctx: FunctionCtx,
447
+ w: Tensor,
448
+ u: Tensor,
449
+ k: Tensor,
450
+ v: Tensor,
451
+ state: Tensor,
452
+ ) -> tuple[Tensor, Tensor]:
453
+ ctx.input_dtype = k.dtype
454
+
455
+ w = -torch.exp(w.float().contiguous())
456
+ if k.dtype == torch.float16:
457
+ u = u.float()
458
+ k = k.float()
459
+ v = v.float()
460
+ u = u.contiguous()
461
+ k = k.contiguous()
462
+ v = v.contiguous()
463
+ wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state)
464
+ ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1])
465
+ return wkv, state_out[:, :, -1:]
466
+
467
+ @staticmethod
468
+ @once_differentiable
469
+ def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
470
+ w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors)
471
+ gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate)
472
+ return gw, gu, gk, gv, gstate
473
+
474
+
475
+ def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]:
476
+ return FusedRecurrentRWKV4Function.apply(w, u, k, v, state)
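
A hedged usage sketch (not part of the commit): the recurrent state stacks `(alpha, beta, eps)` along dim 1 with a single time slot, i.e. `[B, 3, 1, C]`. Initialising `eps` to a very negative value follows the usual RWKV convention and is an assumption here, as are the sizes.

import torch
from fla.ops.rwkv4 import fused_recurrent_rwkv4

B, T, C = 2, 16, 512
w = torch.randn(C, device='cuda')  # log-space time decay; the op applies -exp(w) internally
u = torch.randn(C, device='cuda')  # per-channel bonus for the current token
k = torch.randn(B, T, C, device='cuda')
v = torch.randn(B, T, C, device='cuda')
state = torch.zeros(B, 3, 1, C, device='cuda')
state[:, 2] = -1e38  # eps channel starts at a very large negative value
wkv, state = fused_recurrent_rwkv4(w, u, k, v, state)  # wkv: [B, T, C], state: [B, 3, 1, C]
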
fla/ops/rwkv6/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_rwkv6
4
+ from .fused_recurrent import fused_recurrent_rwkv6
5
+
6
+ __all__ = [
7
+ 'chunk_rwkv6',
8
+ 'fused_recurrent_rwkv6'
9
+ ]
fla/ops/rwkv6/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (325 Bytes). View file
 
fla/ops/rwkv6/__pycache__/chunk.cpython-311.pyc ADDED
Binary file (83 kB). View file
 
fla/ops/rwkv6/__pycache__/fused_recurrent.cpython-311.pyc ADDED
Binary file (40.3 kB). View file