Erland committed
Commit e4ebaab · verified · 1 Parent(s): f0698db

Add files using upload-large-folder tool

Files changed (50)
  1. configs/delta_net_340M.json +27 -0
  2. fla/layers/__init__.py +44 -0
  3. fla/layers/__pycache__/bitattn.cpython-311.pyc +0 -0
  4. fla/layers/__pycache__/gated_deltanet.cpython-311.pyc +0 -0
  5. fla/layers/__pycache__/hgrn2.cpython-311.pyc +0 -0
  6. fla/layers/__pycache__/lightnet.cpython-311.pyc +0 -0
  7. fla/layers/__pycache__/rebased.cpython-311.pyc +0 -0
  8. fla/layers/__pycache__/rwkv7.cpython-311.pyc +0 -0
  9. fla/layers/abc.py +218 -0
  10. fla/layers/delta_net.py +291 -0
  11. fla/layers/gated_deltanet.py +293 -0
  12. fla/layers/gated_deltaproduct.py +351 -0
  13. fla/layers/gla.py +294 -0
  14. fla/layers/gsa.py +227 -0
  15. fla/layers/hgrn.py +168 -0
  16. fla/layers/rebased.py +133 -0
  17. fla/layers/rwkv6.py +307 -0
  18. fla/layers/simple_gla.py +261 -0
  19. fla/models/abc/__init__.py +13 -0
  20. fla/models/abc/__pycache__/__init__.cpython-311.pyc +0 -0
  21. fla/models/abc/modeling_abc.py +418 -0
  22. fla/models/bitnet/__init__.py +13 -0
  23. fla/models/bitnet/__pycache__/__init__.cpython-311.pyc +0 -0
  24. fla/models/bitnet/__pycache__/configuration_bitnet.cpython-311.pyc +0 -0
  25. fla/models/bitnet/__pycache__/modeling_bitnet.cpython-311.pyc +0 -0
  26. fla/models/bitnet/configuration_bitnet.py +67 -0
  27. fla/models/delta_net/__init__.py +12 -0
  28. fla/models/forgetting_transformer/__pycache__/__init__.cpython-311.pyc +0 -0
  29. fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-311.pyc +0 -0
  30. fla/models/forgetting_transformer/configuration_forgetting_transformer.py +68 -0
  31. fla/models/gated_deltanet/__pycache__/__init__.cpython-311.pyc +0 -0
  32. fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-311.pyc +0 -0
  33. fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py +90 -0
  34. fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py +520 -0
  35. fla/models/gla/__init__.py +13 -0
  36. fla/models/gla/__pycache__/__init__.cpython-311.pyc +0 -0
  37. fla/models/gsa/__pycache__/__init__.cpython-311.pyc +0 -0
  38. fla/models/gsa/__pycache__/configuration_gsa.cpython-311.pyc +0 -0
  39. fla/models/gsa/configuration_gsa.py +97 -0
  40. fla/models/gsa/modeling_gsa.py +420 -0
  41. fla/models/hgrn/modeling_hgrn.py +420 -0
  42. fla/models/linear_attn/configuration_linear_attn.py +91 -0
  43. fla/models/nsa/modeling_nsa.py +398 -0
  44. flame/__init__.py +1 -0
  45. flame/__pycache__/__init__.cpython-311.pyc +0 -0
  46. flame/__pycache__/train.cpython-311.pyc +0 -0
  47. flame/models/parallelize_fla.py +550 -0
  48. flame/train.py +897 -0
  49. flame/utils/convert_dcp_to_hf.py +66 -0
  50. generation_config.json +6 -0
configs/delta_net_340M.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "attn_mode": "chunk",
+   "bos_token_id": 1,
+   "conv_size": 4,
+   "eos_token_id": 2,
+   "expand_k": 1,
+   "expand_v": 1,
+   "fuse_cross_entropy": true,
+   "hidden_act": "swish",
+   "hidden_ratio": 4,
+   "hidden_size": 1024,
+   "initializer_range": 0.006,
+   "intermediate_size": null,
+   "model_type": "delta_net",
+   "norm_eps": 1e-06,
+   "norm_first": false,
+   "num_heads": 8,
+   "num_hidden_layers": 24,
+   "qk_activation": "silu",
+   "qk_norm": "l2",
+   "tie_word_embeddings": false,
+   "use_beta": true,
+   "use_cache": true,
+   "use_gate": false,
+   "use_output_norm": true,
+   "use_short_conv": true
+ }
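For reference, a config like this can be materialized into a model through the model classes shipped with the library. A minimal sketch, assuming `flash-linear-attention` is installed and exposes `DeltaNetConfig` / `DeltaNetForCausalLM` from `fla.models` (as the upstream package does); the path is the file added in this commit:

```python
# Minimal sketch: build a DeltaNet LM from the JSON above.
# Assumes the flash-linear-attention package (`fla`) is installed and that
# DeltaNetConfig / DeltaNetForCausalLM are exported from fla.models.
from fla.models import DeltaNetConfig, DeltaNetForCausalLM

config = DeltaNetConfig.from_json_file("configs/delta_net_340M.json")
model = DeltaNetForCausalLM(config)

# Sanity check on parameter count (roughly 340M for this 24-layer, 1024-dim config).
print(sum(p.numel() for p in model.parameters()))
```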
fla/layers/__init__.py ADDED
@@ -0,0 +1,44 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from .abc import ABCAttention
+ from .attn import Attention
+ from .based import BasedLinearAttention
+ from .bitattn import BitAttention
+ from .delta_net import DeltaNet
+ from .forgetting_attn import ForgettingAttention
+ from .gated_deltanet import GatedDeltaNet
+ from .gated_deltaproduct import GatedDeltaProduct
+ from .gla import GatedLinearAttention
+ from .gsa import GatedSlotAttention
+ from .hgrn import HGRNAttention
+ from .hgrn2 import HGRN2Attention
+ from .lightnet import LightNetAttention
+ from .linear_attn import LinearAttention
+ from .multiscale_retention import MultiScaleRetention
+ from .nsa import NativeSparseAttention
+ from .rebased import ReBasedLinearAttention
+ from .rwkv6 import RWKV6Attention
+ from .rwkv7 import RWKV7Attention
+
+ __all__ = [
+     'ABCAttention',
+     'Attention',
+     'BasedLinearAttention',
+     'BitAttention',
+     'DeltaNet',
+     'ForgettingAttention',
+     'GatedDeltaNet',
+     'GatedDeltaProduct',
+     'GatedLinearAttention',
+     'GatedSlotAttention',
+     'HGRNAttention',
+     'HGRN2Attention',
+     'LightNetAttention',
+     'LinearAttention',
+     'MultiScaleRetention',
+     'NativeSparseAttention',
+     'ReBasedLinearAttention',
+     'RWKV6Attention',
+     'RWKV7Attention',
+ ]
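All layers exported here share the same call signature: `forward(hidden_states, attention_mask=None, past_key_values=None, use_cache=False, ...)` returning `(output, attentions, past_key_values)` with `attentions` always `None`. A minimal smoke test, assuming the package is installed and a CUDA device is available for the Triton kernels (shapes are illustrative):

```python
import torch
from fla.layers import DeltaNet

# Illustrative shapes only; the Triton kernels behind these layers expect a CUDA device.
layer = DeltaNet(hidden_size=1024, num_heads=8, layer_idx=0).cuda().to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16)

o, _, _ = layer(x)   # (output, attentions placeholder, cache)
print(o.shape)       # torch.Size([2, 128, 1024])
```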
fla/layers/__pycache__/bitattn.cpython-311.pyc ADDED
Binary file (9.62 kB).
fla/layers/__pycache__/gated_deltanet.cpython-311.pyc ADDED
Binary file (13.9 kB).
fla/layers/__pycache__/hgrn2.cpython-311.pyc ADDED
Binary file (9.09 kB).
fla/layers/__pycache__/lightnet.cpython-311.pyc ADDED
Binary file (9.33 kB).
fla/layers/__pycache__/rebased.cpython-311.pyc ADDED
Binary file (7.17 kB).
fla/layers/__pycache__/rwkv7.cpython-311.pyc ADDED
Binary file (11 kB).
fla/layers/abc.py ADDED
@@ -0,0 +1,218 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, RotaryEmbedding, ShortConvolution
14
+ from fla.modules.activations import swiglu, swish
15
+ from fla.ops.abc.chunk import chunk_abc
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class ABCAttention(nn.Module):
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size: int = 1024,
26
+ expand_k: float = 0.5,
27
+ expand_v: float = 1.0,
28
+ num_heads: int = 4,
29
+ use_short_conv: bool = False,
30
+ conv_size: int = 4,
31
+ conv_bias: bool = False,
32
+ num_slots: Optional[int] = None,
33
+ elementwise_affine: Optional[bool] = True,
34
+ norm_eps: float = 1e-5,
35
+ gate_low_rank_dim: int = 16,
36
+ gate_logit_normalizer: int = 16,
37
+ use_rope: bool = True,
38
+ use_input_gate: bool = False,
39
+ use_output_gate: bool = True,
40
+ use_norm: bool = True,
41
+ clamp_min: Optional[float] = -32,
42
+ clamp_max: Optional[float] = 32,
43
+ layer_idx: Optional[int] = None,
44
+ **kwargs
45
+ ) -> ABCAttention:
46
+ super().__init__()
47
+
48
+ self.hidden_size = hidden_size
49
+ self.expand_k = expand_k
50
+ self.expand_v = expand_v
51
+ self.num_heads = num_heads
52
+ self.key_dim = int(self.hidden_size * self.expand_k)
53
+ self.value_dim = int(self.hidden_size * self.expand_v)
54
+ self.head_k_dim = self.key_dim // self.num_heads
55
+ self.head_v_dim = self.value_dim // self.num_heads
56
+
57
+ self.use_short_conv = use_short_conv
58
+ self.conv_size = conv_size
59
+ self.conv_bias = conv_bias
60
+
61
+ self.gate_low_rank_dim = gate_low_rank_dim
62
+ self.gate_logit_normalizer = gate_logit_normalizer
63
+
64
+ self.use_rope = use_rope
65
+ self.use_input_gate = use_input_gate
66
+ self.use_output_gate = use_output_gate
67
+ self.use_norm = use_norm
68
+
69
+ if num_slots is None:
70
+ num_slots = self.head_k_dim
71
+ self.num_slots = num_slots
72
+
73
+ self.norm_eps = norm_eps
74
+
75
+ self.clamp_min = clamp_min
76
+ self.clamp_max = clamp_max
77
+ self.layer_idx = layer_idx
78
+
79
+ if layer_idx is None:
80
+ warnings.warn(
81
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
82
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
83
+ "when creating this class."
84
+ )
85
+
86
+ self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
87
+ self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
88
+ self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
89
+
90
+ if use_output_gate:
91
+ self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
92
+ self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
93
+ self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
94
+
95
+ if use_short_conv:
96
+ self.conv_size = conv_size
97
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
98
+ self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
99
+ self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')
100
+
101
+ if self.use_norm:
102
+ if self.use_output_gate:
103
+ self.g_norm = FusedRMSNormGated(
104
+ hidden_size=self.head_v_dim,
105
+ elementwise_affine=elementwise_affine,
106
+ eps=norm_eps
107
+ )
108
+ else:
109
+ self.g_norm = RMSNorm(
110
+ hidden_size=self.head_v_dim,
111
+ elementwise_affine=elementwise_affine,
112
+ eps=norm_eps
113
+ )
114
+
115
+ if self.use_rope:
116
+ self.rotary = RotaryEmbedding(self.head_k_dim)
117
+
118
+ def forward(
119
+ self,
120
+ hidden_states: torch.Tensor,
121
+ attention_mask: Optional[torch.Tensor] = None,
122
+ past_key_values: Optional[Cache] = None,
123
+ use_cache: Optional[bool] = False,
124
+ output_attentions: Optional[bool] = False,
125
+ **kwargs
126
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
127
+ if attention_mask is not None:
128
+ assert len(attention_mask.shape) == 2, (
129
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
130
+ "for padding purposes (0 indicating padding). "
131
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
132
+ )
133
+
134
+ last_state = None
135
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
136
+ last_state = past_key_values[self.layer_idx]
137
+
138
+ cu_seqlens = kwargs.get('cu_seqlens', None)
139
+ if cu_seqlens is not None:
140
+ raise NotImplementedError("Training with cu_seqlens is not supported yet for ABCAttention")
141
+ if self.use_short_conv:
142
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
143
+ if last_state is not None:
144
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
145
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
146
+ q, conv_state_q = self.q_conv1d(
147
+ x=self.q_proj(hidden_states),
148
+ mask=conv_mask,
149
+ cache=conv_state_q,
150
+ output_final_state=use_cache,
151
+ cu_seqlens=cu_seqlens
152
+ )
153
+ k, conv_state_k = self.k_conv1d(
154
+ x=self.k_proj(hidden_states),
155
+ mask=conv_mask,
156
+ cache=conv_state_k,
157
+ output_final_state=use_cache,
158
+ cu_seqlens=cu_seqlens
159
+ )
160
+ v, conv_state_v = self.v_conv1d(
161
+ x=self.v_proj(hidden_states),
162
+ mask=conv_mask,
163
+ cache=conv_state_v,
164
+ output_final_state=use_cache,
165
+ cu_seqlens=cu_seqlens
166
+ )
167
+ else:
168
+ q = self.q_proj(hidden_states)
169
+ k = self.k_proj(hidden_states)
170
+ v = self.v_proj(hidden_states)
171
+
172
+ if self.use_input_gate:
173
+ q, k, v = map(lambda x: swish(x), (q, k, v))
174
+ # dealing with left-padding
175
+ if attention_mask is not None:
176
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
177
+
178
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
179
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
180
+ if self.use_rope:
181
+ seqlen_offset = 0
182
+ if past_key_values is not None:
183
+ seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
184
+ q, k = self.rotary(q, k, seqlen_offset=seqlen_offset)
185
+
186
+ s = rearrange(self.s_proj(hidden_states), '... (h m) -> ... h m', m=self.num_slots)
187
+ s = s.clamp_(self.clamp_min, self.clamp_max)
188
+
189
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
190
+ o, recurrent_state = chunk_abc(
191
+ q=q,
192
+ k=k,
193
+ v=v,
194
+ s=s,
195
+ initial_state=recurrent_state,
196
+ output_final_state=use_cache,
197
+ head_first=False
198
+ )
199
+ if past_key_values is not None:
200
+ past_key_values.update(
201
+ recurrent_state=recurrent_state,
202
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
203
+ layer_idx=self.layer_idx,
204
+ offset=q.shape[1]
205
+ )
206
+
207
+ if self.use_norm and not self.use_output_gate:
208
+ o = self.g_norm(o)
209
+ elif self.use_output_gate:
210
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
211
+ o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
212
+ o = rearrange(o, '... h d -> ... (h d)')
213
+ o = self.o_proj(o)
214
+
215
+ return o, None, past_key_values
216
+
217
+ def state_size(self, seq_len: int = 2048):
218
+ return 2 * self.num_slots * self.hidden_size
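The output path above either applies a gated RMSNorm (`FusedRMSNormGated`) when both `use_norm` and `use_output_gate` are set, or falls back to `swiglu(g, o) = swish(g) * o`. As a rough, unfused reference for the gated-norm path (one plausible formulation, an assumption about the fused kernel rather than a drop-in replacement):

```python
import torch
import torch.nn.functional as F

def gated_rmsnorm_reference(x, g, weight, eps=1e-5):
    # One plausible unfused formulation: RMS-normalize x, scale by the learned
    # weight, then gate elementwise with swish(g). The fused Triton module in
    # fla.modules may arrange the gating differently; treat this as exposition.
    rms = x.float().pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return (x.float() * rms).to(x) * weight * F.silu(g)
```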
fla/layers/delta_net.py ADDED
@@ -0,0 +1,291 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from einops import rearrange
11
+ from torch.nn import functional as F
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
15
+
16
+ if TYPE_CHECKING:
17
+ from transformers.processing_utils import Unpack
18
+
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ def elu_p1(x):
23
+ return (F.elu(x, 1., False) + 1.).to(x)
24
+
25
+
26
+ def sum_norm(x):
27
+ return (x / x.sum(-1, keepdim=True)).to(x)
28
+
29
+
30
+ class DeltaNet(nn.Module):
31
+ r"""
32
+ The layer implementation for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484). # noqa
33
+ DeltaNet was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174). # noqa
34
+
35
+ Args:
36
+ mode (str, Optional):
37
+ Which DeltaNet kernel to use.
38
+ Currently available: `chunk` and `fused_recurrent` (`fused_chunk` is deprecated).
39
+ Default: `chunk`.
40
+ hidden_size (int, Optional):
41
+ The hidden size of the input. Default: 1024.
42
+ expand_k (float, Optional):
43
+ The expansion ratio for the key dim. Default: 1.0.
44
+ expand_v (float, Optional):
45
+ The expansion ratio for the value dim. Default: 1.0.
46
+ num_heads (int, Optional):
47
+ The number of heads. Default: 4.
48
+ use_beta (bool, Optional):
49
+ Whether to use beta. Default: `True`.
50
+ use_gate (bool, Optional):
51
+ Whether to use output gate. Default: `False`.
52
+ use_short_conv (bool, Optional):
53
+ Whether to use short convolutions. Default: `True`.
54
+ conv_size (int, Optional):
55
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
56
+ conv_bias (bool, Optional):
57
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
58
+ allow_neg_eigval (bool, Optional):
59
+ Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
60
+ See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
61
+ layer_idx (int, Optional):
62
+ The index of the layer. Default: None.
63
+ norm_eps (float, Optional):
64
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
65
+ qk_activation (str, Optional):
66
+ The activation function for the query and key. Default: `silu`.
67
+ qk_norm (str, Optional):
68
+ The normalization method for the query and key. Default: `l2`.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ mode: str = 'chunk',
74
+ d_model: int = None,
75
+ hidden_size: int = 1024,
76
+ expand_k: float = 1.0,
77
+ expand_v: float = 1.0,
78
+ num_heads: int = 4,
79
+ use_beta: bool = True,
80
+ use_gate: bool = False,
81
+ use_short_conv: bool = True,
82
+ conv_size: int = 4,
83
+ conv_bias: bool = False,
84
+ allow_neg_eigval: bool = False,
85
+ layer_idx: int = None,
86
+ qk_activation: str = 'silu',
87
+ qk_norm: str = 'l2',
88
+ norm_eps: float = 1e-5,
89
+ **kwargs
90
+ ) -> DeltaNet:
91
+ super().__init__()
92
+
93
+ self.mode = mode
94
+ self.qk_activation = qk_activation
95
+ self.qk_norm = qk_norm
96
+
97
+ assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
98
+ assert self.qk_norm in ['l2', 'sum']
99
+
100
+ if d_model is not None:
101
+ hidden_size = d_model
102
+ self.hidden_size = hidden_size
103
+ self.expand_k = expand_k
104
+ self.expand_v = expand_v
105
+ self.num_heads = num_heads
106
+ self.use_gate = use_gate
107
+ self.use_short_conv = use_short_conv
108
+ self.conv_size = conv_size
109
+ self.conv_bias = conv_bias
110
+ self.allow_neg_eigval = allow_neg_eigval
111
+
112
+ self.key_dim = int(hidden_size * expand_k)
113
+ self.value_dim = int(hidden_size * expand_v)
114
+ self.head_k_dim = self.key_dim // num_heads
115
+ self.head_v_dim = self.value_dim // num_heads
116
+ self.layer_idx = layer_idx
117
+
118
+ self.silu = nn.SiLU()
119
+ if mode == 'fused_chunk':
120
+ raise NotImplementedError("fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead.")
121
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
122
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
123
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
124
+
125
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
126
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
127
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
128
+
129
+ self.use_beta = use_beta
130
+ if self.use_beta:
131
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
132
+ if use_short_conv:
133
+ self.conv_size = conv_size
134
+ self.q_conv1d = ShortConvolution(
135
+ hidden_size=self.key_dim,
136
+ kernel_size=conv_size,
137
+ activation='silu' if qk_activation == 'silu' else None
138
+ )
139
+ self.k_conv1d = ShortConvolution(
140
+ hidden_size=self.key_dim,
141
+ kernel_size=conv_size,
142
+ activation='silu' if qk_activation == 'silu' else None
143
+ )
144
+ self.v_conv1d = ShortConvolution(
145
+ hidden_size=self.value_dim,
146
+ kernel_size=conv_size,
147
+ activation='silu'
148
+ )
149
+ else:
150
+ raise UserWarning(
151
+ "ShortConvolution is crucial to the performance. "
152
+ "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
153
+ )
154
+ if use_gate:
155
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
156
+ self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
157
+ else:
158
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
159
+
160
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
161
+
162
+ def forward(
163
+ self,
164
+ hidden_states: torch.Tensor,
165
+ attention_mask: Optional[torch.Tensor] = None,
166
+ past_key_values: Optional[Cache] = None,
167
+ use_cache: Optional[bool] = False,
168
+ output_attentions: Optional[bool] = False,
169
+ **kwargs: Unpack[Dict]
170
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
171
+ if attention_mask is not None:
172
+ assert len(attention_mask.shape) == 2, (
173
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
174
+ "for padding purposes (0 indicating padding). "
175
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
176
+ )
177
+
178
+ # change to inference mode.
179
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
180
+
181
+ last_state = None
182
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
183
+ last_state = past_key_values[self.layer_idx]
184
+
185
+ cu_seqlens = kwargs.get('cu_seqlens', None)
186
+ if self.use_short_conv:
187
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
188
+ if last_state is not None:
189
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
190
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
191
+ q, conv_state_q = self.q_conv1d(
192
+ x=self.q_proj(hidden_states),
193
+ mask=conv_mask,
194
+ cache=conv_state_q,
195
+ output_final_state=use_cache,
196
+ cu_seqlens=cu_seqlens
197
+ )
198
+ k, conv_state_k = self.k_conv1d(
199
+ x=self.k_proj(hidden_states),
200
+ mask=conv_mask,
201
+ cache=conv_state_k,
202
+ output_final_state=use_cache,
203
+ cu_seqlens=cu_seqlens
204
+ )
205
+ v, conv_state_v = self.v_conv1d(
206
+ x=self.v_proj(hidden_states),
207
+ mask=conv_mask,
208
+ cache=conv_state_v,
209
+ output_final_state=use_cache,
210
+ cu_seqlens=cu_seqlens
211
+ )
212
+ else:
213
+ q = self.q_proj(hidden_states)
214
+ k = self.k_proj(hidden_states)
215
+ if self.qk_activation == 'silu':
216
+ q, k = self.silu(q), self.silu(k)
217
+ v = self.silu(self.v_proj(hidden_states))
218
+
219
+ q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
220
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
221
+ if self.qk_activation != 'silu':
222
+ if self.qk_activation == 'relu':
223
+ q, k = q.relu(), k.relu()
224
+ elif self.qk_activation == 'elu':
225
+ q, k = elu_p1(q), elu_p1(k)
226
+ elif self.qk_activation == 'identity':
227
+ pass
228
+ else:
229
+ raise NotImplementedError
230
+
231
+ if self.qk_norm == 'sum':
232
+ q = sum_norm(q).to(q)
233
+ k = sum_norm(k).to(k)
234
+
235
+ if self.use_beta:
236
+ beta = self.b_proj(hidden_states).sigmoid()
237
+ else:
238
+ beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])
239
+
240
+ if self.allow_neg_eigval:
241
+ beta = beta * 2.
242
+
243
+ # dealing with padding
244
+ if attention_mask is not None:
245
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
246
+
247
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
248
+ if mode == 'fused_recurrent':
249
+ o, recurrent_state = fused_recurrent_delta_rule(
250
+ q=q,
251
+ k=k,
252
+ v=v,
253
+ beta=beta,
254
+ initial_state=recurrent_state,
255
+ output_final_state=use_cache,
256
+ cu_seqlens=cu_seqlens,
257
+ head_first=False,
258
+ use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
259
+ )
260
+ elif mode == 'chunk':
261
+ o, recurrent_state = chunk_delta_rule(
262
+ q=q,
263
+ k=k,
264
+ v=v,
265
+ beta=beta,
266
+ initial_state=recurrent_state,
267
+ output_final_state=use_cache,
268
+ cu_seqlens=cu_seqlens,
269
+ head_first=False,
270
+ use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
271
+ )
272
+ else:
273
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
274
+
275
+ if past_key_values is not None:
276
+ past_key_values.update(
277
+ recurrent_state=recurrent_state,
278
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
279
+ layer_idx=self.layer_idx,
280
+ offset=q.shape[1]
281
+ )
282
+
283
+ if self.use_gate:
284
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
285
+ o = self.o_norm(o, g)
286
+ else:
287
+ o = self.o_norm(o)
288
+ o = rearrange(o, 'b t h d -> b t (h d)')
289
+ o = self.o_proj(o)
290
+
291
+ return o, None, past_key_values
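For intuition about what `chunk_delta_rule` / `fused_recurrent_delta_rule` compute, here is a deliberately slow, single-head PyTorch sketch of the delta-rule recurrence (the state `S` maps keys to values and `beta` is the per-token writing strength). This is exposition only and assumes L2-normalized `q`/`k`, which the kernels handle internally via `use_qk_l2norm_in_kernel`:

```python
import torch

def delta_rule_reference(q, k, v, beta):
    """Slow single-head reference.

    q, k: (T, d_k) L2-normalized queries/keys
    v:    (T, d_v)
    beta: (T,)     writing strength in [0, 1] (or [0, 2] with allow_neg_eigval)
    """
    S = k.new_zeros(k.shape[-1], v.shape[-1])   # fast-weight state, (d_k, d_v)
    outputs = []
    for t in range(k.shape[0]):
        # delta rule: overwrite the value currently associated with k_t
        v_old = S.t() @ k[t]                               # (d_v,)
        S = S + beta[t] * torch.outer(k[t], v[t] - v_old)  # (I - b k k^T) S + b k v^T
        outputs.append(S.t() @ q[t])
    return torch.stack(outputs)                  # (T, d_v)
```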
fla/layers/gated_deltanet.py ADDED
@@ -0,0 +1,293 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from torch.nn import functional as F
13
+
14
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
15
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
16
+
17
+ if TYPE_CHECKING:
18
+ from transformers.processing_utils import Unpack
19
+
20
+ from fla.models.utils import Cache
21
+
22
+
23
+ @torch.compile
24
+ def elu_p1(x):
25
+ return (F.elu(x, 1., False) + 1.).to(x)
26
+
27
+
28
+ @torch.compile
29
+ def sum_norm(x):
30
+ return (x / x.sum(-1, keepdim=True)).to(x)
31
+
32
+
33
+ class GatedDeltaNet(nn.Module):
34
+ """
35
+ The layer implementation for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa
36
+
37
+ Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.
38
+
39
+ Parameter allocation when use_gate=True:
40
+ - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
41
+ - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
42
+ - Others are ignorably small.
43
+ - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
44
+ NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim.
45
+
46
+ Parameter allocation when use_gate=False:
47
+ - 1 * hidden_size * hidden_size for the q_proj and k_proj each
48
+ - 2 * hidden_size * hidden_size for the v_proj and o_proj each
49
+ - Others are ignorably small.
50
+ - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size
51
+
52
+ Args:
53
+ hidden_size (int, Optional):
54
+ The hidden size of the input. Default: 2048.
55
+ expand_v (float, Optional):
56
+ The expansion ratio for the value dim. Default: 2.0.
57
+ head_dim (int, Optional):
58
+ The dimension of each head. Default: 256.
59
+ num_heads (int, Optional):
60
+ The number of heads. Default: 6.
61
+ mode (str, Optional):
62
+ Which Gated DeltaNet kernel to use.
63
+ Currently available: `chunk` and `fused_recurrent`.
64
+ Default: `chunk`.
65
+ use_beta (bool, Optional):
66
+ Whether to use beta. Default: `True`.
67
+ use_gate (bool, Optional):
68
+ Whether to use output gate. Default: `True`.
69
+ use_short_conv (bool, Optional):
70
+ Whether to use short convolutions. Default: `True`.
71
+ conv_size (int, Optional):
72
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
73
+ conv_bias (bool, Optional):
74
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
75
+ layer_idx (int, Optional):
76
+ The index of the layer. Default: None.
77
+ norm_eps (float, Optional):
78
+ The epsilon value for the normalization layer. Default: 1e-5.
79
+ """
80
+
81
+ def __init__(
82
+ self,
83
+ hidden_size: int = 2048,
84
+ expand_v: float = 2,
85
+ head_dim: int = 256,
86
+ num_heads: int = 6,
87
+ mode: str = 'chunk',
88
+ use_gate: bool = True,
89
+ use_short_conv: bool = True,
90
+ conv_size: int = 4,
91
+ conv_bias: bool = False,
92
+ layer_idx: int = None,
93
+ norm_eps: float = 1e-5,
94
+ **kwargs
95
+ ) -> GatedDeltaNet:
96
+ super().__init__()
97
+
98
+ self.mode = mode
99
+
100
+ self.hidden_size = hidden_size
101
+ self.expand_v = expand_v
102
+
103
+ self.use_gate = use_gate
104
+ self.use_short_conv = use_short_conv
105
+ self.conv_size = conv_size
106
+ self.conv_bias = conv_bias
107
+
108
+ self.head_dim = head_dim
109
+ self.num_heads = num_heads
110
+
111
+ self.key_dim = int(self.num_heads * self.head_dim)
112
+ self.value_dim = int(self.key_dim * self.expand_v)
113
+ self.head_k_dim = head_dim
114
+ self.head_v_dim = int(head_dim * self.expand_v)
115
+ self.layer_idx = layer_idx
116
+
117
+ # Consistency check: Ensure expand_v produces integer values
118
+ if not math.isclose(self.key_dim * expand_v, self.value_dim, rel_tol=1e-5):
119
+ raise ValueError(
120
+ f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
121
+ f"Resulting value_dim would be {self.key_dim * expand_v}, which is invalid for nn.Linear."
122
+ )
123
+ if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
124
+ raise ValueError(
125
+ f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
126
+ f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated."
127
+ )
128
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
129
+
130
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
131
+ self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
132
+ self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
133
+ self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
134
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
135
+
136
+ A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
137
+ self.A_log = nn.Parameter(torch.log(A))
138
+ self.A_log._no_weight_decay = True
139
+ # hard coded for now
140
+ dt_min = 0.001
141
+ dt_max = 0.1
142
+ dt_init_floor = 1e-4
143
+ dt = torch.exp(
144
+ torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
145
+ + math.log(dt_min)
146
+ )
147
+ dt = torch.clamp(dt, min=dt_init_floor)
148
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
149
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
150
+ self.dt_bias = nn.Parameter(inv_dt)
151
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
152
+ # name.endswith("bias") in param_grouping.py
153
+ self.dt_bias._no_weight_decay = True
154
+
155
+ if use_short_conv:
156
+ self.conv_size = conv_size
157
+ self.q_conv1d = ShortConvolution(
158
+ hidden_size=self.key_dim,
159
+ kernel_size=conv_size,
160
+ activation='silu'
161
+ )
162
+ self.k_conv1d = ShortConvolution(
163
+ hidden_size=self.key_dim,
164
+ kernel_size=conv_size,
165
+ activation='silu'
166
+ )
167
+ self.v_conv1d = ShortConvolution(
168
+ hidden_size=self.value_dim,
169
+ kernel_size=conv_size,
170
+ activation='silu'
171
+ )
172
+ else:
173
+ raise UserWarning(
174
+ "ShortConvolution is crucial to the performance. "
175
+ "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
176
+ )
177
+ if use_gate:
178
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
179
+ self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
180
+ else:
181
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
182
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
183
+
184
+ def forward(
185
+ self,
186
+ hidden_states: torch.Tensor,
187
+ attention_mask: Optional[torch.Tensor] = None,
188
+ past_key_values: Optional[Cache] = None,
189
+ use_cache: Optional[bool] = False,
190
+ output_attentions: Optional[bool] = False,
191
+ **kwargs: Unpack[Dict]
192
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
193
+ if attention_mask is not None:
194
+ assert len(attention_mask.shape) == 2, (
195
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
196
+ "for padding purposes (0 indicating padding). "
197
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
198
+ )
199
+
200
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
201
+ if self.training:
202
+ assert mode == 'chunk', "Only chunk mode is supported in training."
203
+
204
+ last_state = None
205
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
206
+ last_state = past_key_values[self.layer_idx]
207
+
208
+ cu_seqlens = kwargs.get('cu_seqlens', None)
209
+ if self.use_short_conv:
210
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
211
+ if last_state is not None:
212
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
213
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
214
+ q, conv_state_q = self.q_conv1d(
215
+ x=self.q_proj(hidden_states),
216
+ mask=conv_mask,
217
+ cache=conv_state_q,
218
+ output_final_state=use_cache,
219
+ cu_seqlens=cu_seqlens
220
+ )
221
+ k, conv_state_k = self.k_conv1d(
222
+ x=self.k_proj(hidden_states),
223
+ mask=conv_mask,
224
+ cache=conv_state_k,
225
+ output_final_state=use_cache,
226
+ cu_seqlens=cu_seqlens
227
+ )
228
+ v, conv_state_v = self.v_conv1d(
229
+ x=self.v_proj(hidden_states),
230
+ mask=conv_mask,
231
+ cache=conv_state_v,
232
+ output_final_state=use_cache,
233
+ cu_seqlens=cu_seqlens
234
+ )
235
+ else:
236
+ q = F.silu(self.q_proj(hidden_states))
237
+ k = F.silu(self.k_proj(hidden_states))
238
+ v = F.silu(self.v_proj(hidden_states))
239
+
240
+ q, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (q, k))
241
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
242
+ beta = self.b_proj(hidden_states).sigmoid()
243
+ g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
244
+
245
+ # dealing with padding
246
+ if attention_mask is not None:
247
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
248
+ g = g.mul(attention_mask[:, -g.shape[-2]:, None])
249
+
250
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
251
+ if mode == 'chunk':
252
+ o, recurrent_state = chunk_gated_delta_rule(
253
+ q=q,
254
+ k=k,
255
+ v=v,
256
+ g=g,
257
+ beta=beta,
258
+ initial_state=recurrent_state,
259
+ output_final_state=use_cache,
260
+ cu_seqlens=cu_seqlens,
261
+ head_first=False,
262
+ use_qk_l2norm_in_kernel=True
263
+ )
264
+ elif mode == 'fused_recurrent':
265
+ o, recurrent_state = fused_recurrent_gated_delta_rule(
266
+ q=q,
267
+ k=k,
268
+ v=v,
269
+ g=g,
270
+ beta=beta,
271
+ initial_state=recurrent_state,
272
+ output_final_state=use_cache,
273
+ cu_seqlens=cu_seqlens,
274
+ head_first=False,
275
+ use_qk_l2norm_in_kernel=True
276
+ )
277
+ if past_key_values is not None:
278
+ past_key_values.update(
279
+ recurrent_state=recurrent_state,
280
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
281
+ layer_idx=self.layer_idx,
282
+ offset=q.shape[1]
283
+ )
284
+
285
+ if self.use_gate:
286
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
287
+ o = self.o_norm(o, g)
288
+ else:
289
+ o = self.o_norm(o)
290
+ o = rearrange(o, 'b t h d -> b t (h d)')
291
+ o = self.o_proj(o)
292
+
293
+ return o, None, past_key_values
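The gated variant adds a per-head, per-token forget factor `alpha_t = exp(g_t)` in `(0, 1]`, with `g_t = -exp(A_log) * softplus(a_proj(x_t) + dt_bias)` as computed above. A slow single-head sketch of the recurrence that `chunk_gated_delta_rule` implements (exposition only, not the fused algorithm):

```python
import torch

def gated_delta_rule_reference(q, k, v, g, beta):
    """Slow single-head reference.

    q, k: (T, d_k); v: (T, d_v)
    g:    (T,) log decay, g_t <= 0, so alpha_t = exp(g_t) lies in (0, 1]
    beta: (T,) writing strength
    """
    S = k.new_zeros(k.shape[-1], v.shape[-1])
    outputs = []
    for t in range(k.shape[0]):
        S = S * g[t].exp()                                    # forget
        v_old = S.t() @ k[t]
        S = S + beta[t] * torch.outer(k[t], v[t] - v_old)     # delta-rule write
        outputs.append(S.t() @ q[t])
    return torch.stack(outputs)
```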
fla/layers/gated_deltaproduct.py ADDED
@@ -0,0 +1,351 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
12
+ from fla.ops.delta_rule import chunk_delta_rule
13
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule
14
+
15
+ if TYPE_CHECKING:
16
+ from transformers.processing_utils import Unpack
17
+
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ def elu_p1(x):
22
+ return (F.elu(x, 1.0, False) + 1.0).to(x)
23
+
24
+
25
+ def sum_norm(x):
26
+ return (x / x.sum(-1, keepdim=True)).to(x)
27
+
28
+
29
+ def interleave_multiple_sequences(*sequences):
30
+ """
31
+ Interleave multiple sequences together.
32
+ For example, with sequences [A1, A2], [B1, B2], [C1, C2],
33
+ returns [A1, B1, C1, A2, B2, C2]
34
+ """
35
+ if isinstance(sequences[0], (list, tuple)):
36
+ sequences = sequences[0]
37
+
38
+ if len(sequences) == 1:
39
+ return sequences[0]
40
+
41
+ # All sequences should have the same shape
42
+ assert all(s.shape == sequences[0].shape for s in sequences)
43
+
44
+ # Get the original shape
45
+ batch_size, seq_len, *rest = sequences[0].shape
46
+
47
+ # Stack sequences along a new dimension
48
+ stacked = torch.stack(sequences, dim=2)
49
+
50
+ # Reshape to interleave
51
+ reshaped = stacked.view(batch_size, seq_len * len(sequences), *rest)
52
+
53
+ return reshaped
54
+
55
+
56
+ class GatedDeltaProduct(nn.Module):
57
+ """
58
+ Generalized version of GatedDoubleDeltaNet that supports an arbitrary number of Householder transformations.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ hidden_size: int = 2048,
64
+ expand_v: float = 2,
65
+ head_dim: int = 256,
66
+ num_heads: int = 6,
67
+ num_householder: int = 2, # New parameter for number of householder transformations
68
+ mode: str = "chunk",
69
+ use_gate: bool = True,
70
+ use_forget_gate: bool = True, # when true Gated DeltaProduct, when false DeltaProduct
71
+ use_short_conv: bool = True,
72
+ conv_size: int = 4,
73
+ conv_bias: bool = False,
74
+ layer_idx: int | None = None,
75
+ norm_eps: float = 1e-5,
76
+ allow_neg_eigval: bool = False, # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1]
77
+ **kwargs,
78
+ ) -> None:
79
+ super().__init__()
80
+
81
+ self.mode = mode
82
+ self.hidden_size = hidden_size
83
+ self.expand_v = expand_v
84
+ self.use_gate = use_gate
85
+ self.use_short_conv = use_short_conv
86
+ self.conv_size = conv_size
87
+ self.conv_bias = conv_bias
88
+ self.head_dim = head_dim
89
+ self.num_heads = num_heads
90
+ self.num_householder = num_householder
91
+ self.allow_neg_eigval = allow_neg_eigval
92
+ self.use_forget_gate = use_forget_gate
93
+ self.key_dim = self.num_heads * self.head_dim
94
+ self.value_dim = int(self.key_dim * self.expand_v)
95
+ self.head_qk_dim = head_dim
96
+ self.head_v_dim = int(head_dim * self.expand_v)
97
+ self.layer_idx = layer_idx
98
+ self.silu = nn.SiLU()
99
+ assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."
100
+ # Create multiple projection layers for each householder transformation
101
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
102
+
103
+ self.k_projs = nn.ModuleList(
104
+ [
105
+ nn.Linear(hidden_size, self.key_dim, bias=False)
106
+ for _ in range(num_householder)
107
+ ]
108
+ )
109
+ self.v_projs = nn.ModuleList(
110
+ [
111
+ nn.Linear(hidden_size, self.value_dim, bias=False)
112
+ for _ in range(num_householder)
113
+ ]
114
+ )
115
+ self.b_projs = nn.ModuleList(
116
+ [
117
+ nn.Linear(hidden_size, self.num_heads, bias=False)
118
+ for _ in range(num_householder)
119
+ ]
120
+ )
121
+ if use_short_conv:
122
+ self.q_conv1ds = nn.ModuleList(
123
+ [
124
+ ShortConvolution(
125
+ hidden_size=self.key_dim,
126
+ kernel_size=conv_size,
127
+ activation="silu",
128
+ )
129
+ for _ in range(num_householder)
130
+ ]
131
+ )
132
+ self.k_conv1ds = nn.ModuleList(
133
+ [
134
+ ShortConvolution(
135
+ hidden_size=self.key_dim,
136
+ kernel_size=conv_size,
137
+ activation="silu",
138
+ )
139
+ for _ in range(num_householder)
140
+ ]
141
+ )
142
+ self.v_conv1ds = nn.ModuleList(
143
+ [
144
+ ShortConvolution(
145
+ hidden_size=self.value_dim,
146
+ kernel_size=conv_size,
147
+ activation="silu",
148
+ )
149
+ for _ in range(num_householder)
150
+ ]
151
+ )
152
+
153
+ if self.use_forget_gate:
154
+ self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
155
+ A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
156
+ A_log = torch.log(A)
157
+ self.A_log = nn.Parameter(A_log)
158
+ self.A_log._no_weight_decay = True
159
+
160
+ # Initialize dt parameters
161
+ dt_min = 0.001
162
+ dt_max = 0.1
163
+ dt_init_floor = 1e-4
164
+ dt = torch.exp(
165
+ torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
166
+ + math.log(dt_min)
167
+ )
168
+ dt = torch.clamp(dt, min=dt_init_floor)
169
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
170
+ self.dt_bias = nn.Parameter(inv_dt)
171
+ self.dt_bias._no_weight_decay = True
172
+
173
+ if use_gate:
174
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
175
+ self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
176
+ else:
177
+ self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
178
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
179
+ self.k_id = torch.nn.Identity()
180
+ self.apply(self._initialize_weights)
181
+
182
+ def _initialize_weights(self, module: nn.Module):
183
+ if getattr(module, "_is_hf_initialized", False):
184
+ return
185
+ if isinstance(module, nn.Linear):
186
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
187
+ if module.bias is not None:
188
+ nn.init.zeros_(module.bias)
189
+ module._is_hf_initialized = True
190
+
191
+ def forward(
192
+ self,
193
+ hidden_states: torch.Tensor,
194
+ attention_mask: Optional[torch.Tensor] = None,
195
+ past_key_values: Optional[Cache] = None,
196
+ use_cache: Optional[bool] = False,
197
+ output_attentions: Optional[bool] = False,
198
+ **kwargs: Unpack[Dict],
199
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
200
+ if attention_mask is not None:
201
+ assert len(attention_mask.shape) == 2, (
202
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
203
+ "for padding purposes (0 indicating padding)."
204
+ )
205
+
206
+ mode = (
207
+ "chunk" # 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
208
+ )
209
+ if self.training:
210
+ assert mode == "chunk", "Only chunk mode is supported in training."
211
+
212
+ last_state = None
213
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
214
+ last_state = past_key_values[self.layer_idx]
215
+
216
+ # Process each householder transformation
217
+ ks, vs, betas = [], [], []
218
+ conv_states = []
219
+
220
+ for i in range(self.num_householder):
221
+ if self.use_short_conv:
222
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
223
+ if last_state is not None:
224
+ conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"][
225
+ i
226
+ ]
227
+ conv_mask = (
228
+ attention_mask[:, -hidden_states.shape[1]:]
229
+ if attention_mask is not None
230
+ else None
231
+ )
232
+
233
+ k, conv_state_k = self.k_conv1ds[i](
234
+ x=self.k_projs[i](hidden_states),
235
+ mask=conv_mask,
236
+ cache=conv_state_k,
237
+ output_final_state=use_cache,
238
+ )
239
+ v, conv_state_v = self.v_conv1ds[i](
240
+ x=self.v_projs[i](hidden_states),
241
+ mask=conv_mask,
242
+ cache=conv_state_v,
243
+ output_final_state=use_cache,
244
+ )
245
+ conv_states.append((conv_state_q, conv_state_k, conv_state_v))
246
+ else:
247
+ k = self.silu(self.k_projs[i](hidden_states))
248
+ v = self.silu(self.v_projs[i](hidden_states))
249
+
250
+ ks.append(k)
251
+ vs.append(v)
252
+
253
+ beta = self.b_projs[i](
254
+ hidden_states
255
+ ).sigmoid() # bs, sequence_length, num_heads
256
+ if attention_mask is not None:
257
+ beta = beta.mul(attention_mask[:, -hidden_states.shape[1]:, None])
258
+ if self.allow_neg_eigval:
259
+ beta = beta * 2
260
+ betas.append(beta)
261
+
262
+ if self.use_short_conv:
263
+ q, conv_state_q = self.q_conv1ds[0](
264
+ x=self.q_proj(hidden_states),
265
+ mask=conv_mask,
266
+ cache=conv_state_q,
267
+ output_final_state=use_cache,
268
+ )
269
+ else:
270
+ q = self.silu(self.q_proj(hidden_states))
271
+ q = interleave_multiple_sequences(
272
+ [torch.zeros_like(q)] * (self.num_householder - 1) + [q]
273
+ )
274
+ # Interleave all sequences
275
+ k = interleave_multiple_sequences(ks)
276
+ v = interleave_multiple_sequences(vs)
277
+ beta = interleave_multiple_sequences(betas)
278
+
279
+ q, k, v = (
280
+ rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in (q, k, v)
281
+ )
282
+
283
+ recurrent_state = (
284
+ last_state["recurrent_state"] if last_state is not None else None
285
+ )
286
+ offsets = kwargs.get("offsets")
287
+
288
+ if mode == "chunk":
289
+ if self.use_forget_gate:
290
+ g = -self.A_log.float().exp() * F.softplus(
291
+ self.a_proj(hidden_states).float() + self.dt_bias
292
+ )
293
+ if attention_mask is not None:
294
+ g = g.mul(attention_mask[:, -g.shape[-2]:, None])
295
+
296
+ # Interleave g with zeros for non-first transformations
297
+ g = interleave_multiple_sequences(
298
+ [g] + [torch.zeros_like(g)] * (self.num_householder - 1)
299
+ )
300
+
301
+ o, recurrent_state = chunk_gated_delta_rule(
302
+ q=q,
303
+ k=k,
304
+ v=v,
305
+ g=g,
306
+ beta=beta,
307
+ initial_state=recurrent_state,
308
+ output_final_state=use_cache,
309
+ cu_seqlens=offsets,
310
+ head_first=False,
311
+ use_qk_l2norm_in_kernel=True
312
+ )
313
+ else:
314
+ o, recurrent_state = chunk_delta_rule(
315
+ q=q,
316
+ k=k,
317
+ v=v,
318
+ beta=beta,
319
+ initial_state=recurrent_state,
320
+ output_final_state=use_cache,
321
+ cu_seqlens=offsets,
322
+ head_first=False,
323
+ use_qk_l2norm_in_kernel=True
324
+ )
325
+ else:
326
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
327
+
328
+ # Take every nth element for n householder transformations
329
+ o = o[:, self.num_householder - 1:: self.num_householder, :]
330
+
331
+ if past_key_values is not None:
332
+ past_key_values.update(
333
+ recurrent_state=recurrent_state,
334
+ conv_state=conv_states if self.use_short_conv else None,
335
+ layer_idx=self.layer_idx,
336
+ offset=q.shape[2],
337
+ )
338
+
339
+ if self.use_gate:
340
+ g = rearrange(
341
+ self.g_proj(hidden_states),
342
+ "... (h d) -> ... h d",
343
+ h=self.num_heads,
344
+ )
345
+ o = self.o_norm(o, g)
346
+ else:
347
+ o = self.o_norm(o)
348
+ o = rearrange(o, "b t h d -> b t (h d)")
349
+ o = self.o_proj(o)
350
+
351
+ return o, None, past_key_values
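The layer realizes a product of `num_householder` rank-1 (Householder-style) updates per token by interleaving the per-step key/value/beta sequences into one sequence of length `num_householder * T`, running a single (gated) delta-rule pass, and then keeping every `num_householder`-th output. A tiny concrete check of the interleaving performed by `interleave_multiple_sequences` (toy shapes only):

```python
import torch

# Two toy "sequences" with batch=1, seq_len=2, feature=1.
a = torch.tensor([[[1.], [2.]]])    # A1, A2
b = torch.tensor([[[10.], [20.]]])  # B1, B2

stacked = torch.stack([a, b], dim=2)   # (1, 2, 2, 1)
interleaved = stacked.view(1, 4, 1)    # (1, 4, 1)
print(interleaved.squeeze(-1))         # tensor([[ 1., 10.,  2., 20.]]) -> A1, B1, A2, B2
```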
fla/layers/gla.py ADDED
@@ -0,0 +1,294 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+
14
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
+ class GatedLinearAttention(nn.Module):
25
+ r"""
26
+ The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
27
+
28
+ Args:
29
+ mode (str, Optional):
30
+ Which GLA kernel to use.
31
+ Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
32
+ Default: `chunk`.
33
+ hidden_size (int, Optional):
34
+ The hidden size of the input. Default: 1024.
35
+ expand_k (float, Optional):
36
+ The expansion ratio for the key dim. Default: 0.5.
37
+ expand_v (float, Optional):
38
+ The expansion ratio for the value dim. Default: 1.0.
39
+ num_heads (int, Optional):
40
+ The number of heads. Default: 4.
41
+ num_kv_heads (int, Optional):
42
+ The number of key/value heads, used for MQA. Default: None.
43
+ feature_map (str, Optional):
44
+ Feature map function applied to queries/keys. Default: None.
45
+ use_short_conv (bool, Optional):
46
+ Whether to use short convolutions. Default: `False`.
47
+ conv_size (int, Optional):
48
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
49
+ conv_bias (bool, Optional):
50
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
51
+ use_output_gate (bool, Optional):
52
+ Whether to use output gate. Default: `True`.
53
+ gate_fn (str, Optional):
54
+ The activation function for the output gate. Default: `swish`.
55
+ elementwise_affine (bool, Optional):
56
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
57
+ norm_eps (float, Optional):
58
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
59
+ gate_logit_normalizer (int, Optional):
60
+ The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
61
+ gate_low_rank_dim (int, Optional):
62
+ The low rank dim for the gate projection. Default: 16.
63
+ clamp_min (float, Optional):
64
+ The minimum value for the gate logits. Default: None.
65
+ fuse_norm (bool, Optional):
66
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
67
+ layer_idx (int, Optional):
68
+ The index of the layer. Default: None.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ mode: str = 'chunk',
74
+ hidden_size: int = 1024,
75
+ expand_k: float = 0.5,
76
+ expand_v: float = 1.0,
77
+ num_heads: int = 4,
78
+ num_kv_heads: Optional[int] = None,
79
+ feature_map: Optional[str] = None,
80
+ use_short_conv: bool = False,
81
+ conv_size: int = 4,
82
+ conv_bias: bool = False,
83
+ use_output_gate: bool = True,
84
+ gate_fn: str = 'swish',
85
+ elementwise_affine: Optional[bool] = True,
86
+ norm_eps: float = 1e-5,
87
+ gate_logit_normalizer: int = 16,
88
+ gate_low_rank_dim: int = 16,
89
+ clamp_min: Optional[float] = None,
90
+ fuse_norm: bool = True,
91
+ layer_idx: int = None,
92
+ ) -> GatedLinearAttention:
93
+ super().__init__()
94
+
95
+ self.mode = mode
96
+ self.hidden_size = hidden_size
97
+ self.expand_k = expand_k
98
+ self.expand_v = expand_v
99
+ self.num_heads = num_heads
100
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
101
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
102
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
103
+
104
+ self.use_short_conv = use_short_conv
105
+ self.conv_size = conv_size
106
+ self.conv_bias = conv_bias
107
+ self.use_output_gate = use_output_gate
108
+
109
+ self.key_dim = int(hidden_size * expand_k)
110
+ self.value_dim = int(hidden_size * expand_v)
111
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
112
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
113
+ self.clamp_min = clamp_min
114
+ self.layer_idx = layer_idx
115
+
116
+ assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`."
117
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
118
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
119
+
120
+ self.head_k_dim = self.key_dim // num_heads
121
+ self.head_v_dim = self.value_dim // num_heads
122
+
123
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
124
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
125
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
126
+ if self.use_output_gate:
127
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
128
+
129
+ if use_short_conv:
130
+ self.conv_size = conv_size
131
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
132
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
133
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
134
+
135
+ self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
136
+ nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
137
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
138
+
139
+ if gate_fn == 'swish' and fuse_norm and use_output_gate:
140
+ self.g_norm_swish_gate = FusedRMSNormGated(
141
+ hidden_size=self.head_v_dim,
142
+ elementwise_affine=elementwise_affine,
143
+ eps=norm_eps
144
+ )
145
+ self.fuse_norm_and_gate = True
146
+ else:
147
+ self.fuse_norm_and_gate = False
148
+ self.g_norm = RMSNorm(
149
+ hidden_size=self.head_v_dim,
150
+ elementwise_affine=elementwise_affine,
151
+ eps=norm_eps
152
+ )
153
+ self.gate_fn = ACT2FN[gate_fn]
154
+
155
+ self.gate_logit_normalizer = gate_logit_normalizer
156
+
157
+ def forward(
158
+ self,
159
+ hidden_states: torch.Tensor,
160
+ attention_mask: Optional[torch.Tensor] = None,
161
+ past_key_values: Optional[Cache] = None,
162
+ use_cache: Optional[bool] = False,
163
+ output_attentions: Optional[bool] = False,
164
+ **kwargs: Unpack[Dict]
165
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
166
+ if attention_mask is not None:
167
+ assert len(attention_mask.shape) == 2, (
168
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
169
+ "for padding purposes (0 indicating padding). "
170
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
171
+ )
172
+
173
+ # launching the triton kernel for just one token will actually be slower
174
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
175
+
176
+ last_state = None
177
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
178
+ last_state = past_key_values[self.layer_idx]
179
+
180
+ cu_seqlens = kwargs.get('cu_seqlens', None)
181
+ if self.use_short_conv:
182
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
183
+ if last_state is not None:
184
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
185
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
186
+ q, conv_state_q = self.q_conv1d(
187
+ x=self.q_proj(hidden_states),
188
+ mask=conv_mask,
189
+ cache=conv_state_q,
190
+ output_final_state=use_cache,
191
+ cu_seqlens=cu_seqlens
192
+ )
193
+ k, conv_state_k = self.k_conv1d(
194
+ x=self.k_proj(hidden_states),
195
+ mask=conv_mask,
196
+ cache=conv_state_k,
197
+ output_final_state=use_cache,
198
+ cu_seqlens=cu_seqlens
199
+ )
200
+ v, conv_state_v = self.v_conv1d(
201
+ x=self.v_proj(hidden_states),
202
+ mask=conv_mask,
203
+ cache=conv_state_v,
204
+ output_final_state=use_cache,
205
+ cu_seqlens=cu_seqlens
206
+ )
207
+ else:
208
+ q = self.q_proj(hidden_states)
209
+ k = self.k_proj(hidden_states)
210
+ v = self.v_proj(hidden_states)
211
+ gk = self.gk_proj(hidden_states)
212
+
213
+ if self.feature_map_fn is not None:
214
+ q, k = map(self.feature_map_fn, (q, k))
215
+ # dealing with left-padding
216
+ if attention_mask is not None:
217
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
218
+ q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
219
+ if self.num_kv_groups > 1:
220
+ k, gk = (repeat(x, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_k_dim) for x in (k, gk))
221
+ v = repeat(v, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_v_dim)
222
+ else:
223
+ k, gk = (rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim) for x in (k, gk))
224
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
225
+ gk = F.logsigmoid(gk) / self.gate_logit_normalizer
226
+
227
+ if self.clamp_min is not None:
228
+ gk = torch.clamp_min(gk, self.clamp_min)
229
+
230
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
231
+ if mode == 'fused_recurrent':
232
+ o, recurrent_state = fused_recurrent_gla(
233
+ q=q,
234
+ k=k,
235
+ v=v,
236
+ gk=gk,
237
+ initial_state=recurrent_state,
238
+ output_final_state=use_cache,
239
+ cu_seqlens=cu_seqlens,
240
+ head_first=False
241
+ )
242
+ elif mode == 'fused_chunk':
243
+ o, recurrent_state = fused_chunk_gla(
244
+ q=q,
245
+ k=k,
246
+ v=v,
247
+ g=gk,
248
+ initial_state=recurrent_state,
249
+ output_final_state=use_cache,
250
+ head_first=False
251
+ )
252
+ elif mode == 'chunk':
253
+ o, recurrent_state = chunk_gla(
254
+ q=q,
255
+ k=k,
256
+ v=v,
257
+ g=gk,
258
+ initial_state=recurrent_state,
259
+ output_final_state=use_cache,
260
+ cu_seqlens=cu_seqlens,
261
+ head_first=False
262
+ )
263
+ else:
264
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
265
+
266
+ if past_key_values is not None:
267
+ past_key_values.update(
268
+ recurrent_state=recurrent_state,
269
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
270
+ layer_idx=self.layer_idx,
271
+ offset=q.shape[1]
272
+ )
273
+
274
+ if self.use_output_gate:
275
+ g = self.g_proj(hidden_states)
276
+ if self.fuse_norm_and_gate:
277
+ g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
278
+ o = self.g_norm_swish_gate(o, g)
279
+ o = rearrange(o, 'b t h d -> b t (h d)')
280
+ else:
281
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
282
+ o = o * self.gate_fn(g)
283
+ else:
284
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
285
+ o = self.o_proj(o)
286
+
287
+ return o, None, past_key_values
288
+
289
+ def state_size(self, **kwargs) -> int:
290
+ state_size = self.key_dim * self.head_v_dim
291
+ for module in self.children():
292
+ if isinstance(module, ShortConvolution):
293
+ state_size += module.state_size
294
+ return state_size
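A minimal smoke test for the GLA layer above; this is a sketch only, assuming the class defined earlier in fla/layers/gla.py is exported as `GatedLinearAttention` (its definition sits above the excerpt shown here) and that a CUDA device is available for the Triton kernels.

import torch
from fla.layers.gla import GatedLinearAttention  # assumed export name for the class defined in this file

# the chunk/fused kernels are Triton-based, so run on CUDA in bf16
layer = GatedLinearAttention(hidden_size=1024, num_heads=4, layer_idx=0).to('cuda', torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)  # seq_len > 64, so self.mode (chunk by default) is used instead of fused_recurrent
print(o.shape)  # torch.Size([2, 128, 1024])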
fla/layers/gsa.py ADDED
@@ -0,0 +1,227 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+
14
+ from fla.modules import RMSNorm, ShortConvolution
15
+ from fla.modules.feature_map import ReLUFeatureMap, SwishFeatureMap, T2RFeatureMap
16
+ from fla.modules.layernorm import rms_norm_linear
17
+ from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.processing_utils import Unpack
21
+
22
+ from fla.models.utils import Cache
23
+
24
+
25
+ class GatedSlotAttention(nn.Module):
26
+
27
+ def __init__(
28
+ self,
29
+ mode: str = 'chunk',
30
+ hidden_size: int = 1024,
31
+ expand_k: float = 1.,
32
+ expand_v: float = 1.,
33
+ num_heads: int = 4,
34
+ num_kv_heads: Optional[int] = None,
35
+ use_short_conv: bool = False,
36
+ conv_size: int = 4,
37
+ conv_bias: bool = False,
38
+ num_slots: Optional[int] = None,
39
+ elementwise_affine: Optional[bool] = True,
40
+ norm_eps: float = 1e-5,
41
+ gate_logit_normalizer: int = 8,
42
+ feature_map: str = 'swish',
43
+ use_output_gate: bool = False,
44
+ use_norm: bool = True,
45
+ layer_idx: Optional[int] = None,
46
+ scale: Optional[float] = 1.,
47
+ **kwargs
48
+ ) -> GatedSlotAttention:
49
+ super().__init__()
50
+
51
+ self.mode = mode
52
+ self.hidden_size = hidden_size
53
+ self.expand_k = expand_k
54
+ self.expand_v = expand_v
55
+ self.num_heads = num_heads
56
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
57
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
58
+ self.key_dim = int(hidden_size * expand_k)
59
+ self.value_dim = int(hidden_size * expand_v)
60
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
61
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
62
+ self.head_k_dim = self.key_dim // self.num_heads
63
+ self.head_v_dim = self.value_dim // self.num_heads
64
+
65
+ self.use_short_conv = use_short_conv
66
+ self.conv_size = conv_size
67
+ self.conv_bias = conv_bias
68
+
69
+ self.gate_logit_normalizer = gate_logit_normalizer
70
+
71
+ self.use_output_gate = use_output_gate
72
+ self.use_norm = use_norm
73
+ self.scale = scale
74
+
75
+ if num_slots is None:
76
+ num_slots = self.head_k_dim
77
+ self.num_slots = num_slots
78
+
79
+ self.layer_idx = layer_idx
80
+
81
+ if layer_idx is None:
82
+ warnings.warn(
83
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
84
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
85
+ "when creating this class."
86
+ )
87
+
88
+ self.register_module('feature_map', None)
89
+ if feature_map == 'swish':
90
+ self.feature_map = SwishFeatureMap()
91
+ elif feature_map == 'relu':
92
+ self.feature_map = ReLUFeatureMap()
93
+ elif feature_map == 't2r':
94
+ self.feature_map = T2RFeatureMap(self.head_k_dim, self.head_k_dim)
95
+ else:
96
+ raise NotImplementedError(f"Feature map `{feature_map}` is not supported now.")
97
+
98
+ self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
99
+ self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
100
+ self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
101
+ self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)
102
+
103
+ if use_short_conv:
104
+ self.conv_size = conv_size
105
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
106
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
107
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
108
+
109
+ self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps)
110
+ self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
111
+
112
+ def forward(
113
+ self,
114
+ hidden_states: torch.Tensor,
115
+ attention_mask: Optional[torch.Tensor] = None,
116
+ past_key_values: Optional[Cache] = None,
117
+ use_cache: Optional[bool] = False,
118
+ output_attentions: Optional[bool] = False,
119
+ **kwargs: Unpack[Dict]
120
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
121
+ if attention_mask is not None:
122
+ assert len(attention_mask.shape) == 2, (
123
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
124
+ "for padding purposes (0 indicating padding). "
125
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
126
+ )
127
+
128
+ # launching the triton kernel for just one token will actually be slower
129
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
130
+
131
+ last_state = None
132
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
133
+ last_state = past_key_values[self.layer_idx]
134
+
135
+ cu_seqlens = kwargs.get('cu_seqlens', None)
136
+ if self.use_short_conv:
137
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
138
+ if last_state is not None:
139
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
140
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
141
+ q, conv_state_q = self.q_conv1d(
142
+ x=self.q_proj(hidden_states),
143
+ mask=conv_mask,
144
+ cache=conv_state_q,
145
+ output_final_state=use_cache,
146
+ cu_seqlens=cu_seqlens
147
+ )
148
+ k, conv_state_k = self.k_conv1d(
149
+ x=self.k_proj(hidden_states),
150
+ mask=conv_mask,
151
+ cache=conv_state_k,
152
+ output_final_state=use_cache,
153
+ cu_seqlens=cu_seqlens
154
+ )
155
+ v, conv_state_v = self.v_conv1d(
156
+ x=self.v_proj(hidden_states),
157
+ mask=conv_mask,
158
+ cache=conv_state_v,
159
+ output_final_state=use_cache,
160
+ cu_seqlens=cu_seqlens
161
+ )
162
+ else:
163
+ q = self.q_proj(hidden_states)
164
+ k = self.k_proj(hidden_states)
165
+ v = self.v_proj(hidden_states)
166
+ f = self.f_proj(hidden_states)
167
+
168
+ q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
169
+ k = rearrange(k, 'b t (h d) -> b t h d', d=self.head_k_dim)
170
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
171
+ f = rearrange(f, 'b t (h m) -> b t h m', m=self.num_slots)
172
+
173
+ if self.feature_map is not None:
174
+ q, k = map(lambda x: self.feature_map(x), (q, k))
175
+ v = F.silu(v)
176
+
177
+ f = F.logsigmoid(f) / self.gate_logit_normalizer
178
+ s = (1 - f.exp()).to(f.dtype)
179
+ # dealing with left-padding
180
+ if attention_mask is not None:
181
+ s = s.mul_(attention_mask[:, -s.shape[1]:, None, None])
182
+ v = v.mul_(attention_mask[:, -v.shape[1]:, None, None])
183
+
184
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
185
+ if mode == 'fused_recurrent':
186
+ o, recurrent_state = fused_recurrent_gsa(
187
+ q=q,
188
+ k=k,
189
+ v=v,
190
+ s=s,
191
+ g=f,
192
+ initial_state=recurrent_state,
193
+ output_final_state=use_cache,
194
+ scale=self.scale,
195
+ cu_seqlens=cu_seqlens,
196
+ head_first=False
197
+ )
198
+ elif mode == 'chunk':
199
+ o, recurrent_state = chunk_gsa(
200
+ q=q,
201
+ k=k,
202
+ v=v,
203
+ s=s,
204
+ g=f,
205
+ initial_state=recurrent_state,
206
+ output_final_state=use_cache,
207
+ scale=self.scale,
208
+ cu_seqlens=cu_seqlens,
209
+ head_first=False
210
+ )
211
+ else:
212
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
213
+
214
+ if past_key_values is not None:
215
+ past_key_values.update(
216
+ recurrent_state=recurrent_state,
217
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
218
+ layer_idx=self.layer_idx,
219
+ offset=q.shape[1]
220
+ )
221
+
222
+ o = rearrange(o, 'b t h d -> b t (h d)')
223
+ o = rms_norm_linear(F.silu(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
224
+ return o, None, past_key_values
225
+
226
+ def state_size(self, *args, **kwargs) -> int:
227
+ return 2 * self.num_slots * self.hidden_size
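A sketch of driving GatedSlotAttention with the layer-level cache, assuming `fla.models.utils.Cache` can be instantiated directly and that a CUDA device is available; the second call shows the single-token decode path that selects the fused recurrent kernel.

import torch
from fla.layers.gsa import GatedSlotAttention
from fla.models.utils import Cache  # assumed to be directly constructible

layer = GatedSlotAttention(hidden_size=512, num_heads=4, num_slots=32, layer_idx=0).to('cuda', torch.bfloat16)
x = torch.randn(1, 100, 512, device='cuda', dtype=torch.bfloat16)

cache = Cache()
o, _, cache = layer(x, past_key_values=cache, use_cache=True)      # prefill: chunk kernel
step = torch.randn(1, 1, 512, device='cuda', dtype=torch.bfloat16)
o, _, cache = layer(step, past_key_values=cache, use_cache=True)   # decode: fused recurrent kernel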
fla/layers/hgrn.py ADDED
@@ -0,0 +1,168 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from fla.modules import FusedRMSNormGated, ShortConvolution
15
+ from fla.modules.activations import swiglu
16
+ from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
+ class HGRNAttention(nn.Module):
25
+
26
+ def __init__(
27
+ self,
28
+ mode: str = 'chunk',
29
+ hidden_size: int = 1024,
30
+ expand_ratio: Optional[int] = 1,
31
+ use_short_conv: bool = False,
32
+ conv_size: int = 4,
33
+ conv_bias: bool = False,
34
+ elementwise_affine: Optional[bool] = True,
35
+ norm_eps: float = 1e-5,
36
+ layer_idx: int = None
37
+ ) -> HGRNAttention:
38
+ super().__init__()
39
+
40
+ self.mode = mode
41
+ self.hidden_size = hidden_size
42
+ self.expand_ratio = expand_ratio
43
+ self.input_dim = int(hidden_size * expand_ratio)
44
+
45
+ self.use_short_conv = use_short_conv
46
+ self.conv_size = conv_size
47
+ self.conv_bias = conv_bias
48
+
49
+ self.layer_idx = layer_idx
50
+
51
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
52
+
53
+ self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
54
+ self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
55
+ self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
56
+
57
+ if use_short_conv:
58
+ self.conv_size = conv_size
59
+ self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
60
+ self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
61
+ self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
62
+
63
+ self.g_norm = FusedRMSNormGated(
64
+ hidden_size=self.input_dim,
65
+ elementwise_affine=elementwise_affine,
66
+ eps=norm_eps
67
+ )
68
+ self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)
69
+
70
+ def forward(
71
+ self,
72
+ hidden_states: torch.Tensor,
73
+ attention_mask: Optional[torch.Tensor] = None,
74
+ past_key_values: Optional[Cache] = None,
75
+ use_cache: Optional[bool] = False,
76
+ output_attentions: Optional[bool] = False,
77
+ lower_bound: Optional[torch.Tensor] = None,
78
+ **kwargs: Unpack[Dict]
79
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
80
+ if attention_mask is not None:
81
+ assert len(attention_mask.shape) == 2, (
82
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
83
+ "for padding purposes (0 indicating padding). "
84
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
85
+ )
86
+
87
+ # launching the triton kernel for just one token will actually be slower
88
+ mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode
89
+
90
+ last_state = None
91
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
92
+ last_state = past_key_values[self.layer_idx]
93
+
94
+ cu_seqlens = kwargs.get('cu_seqlens', None)
95
+ if self.use_short_conv:
96
+ conv_state_i, conv_state_f = None, None
97
+ if last_state is not None:
98
+ conv_state_i, conv_state_f = last_state['conv_state']
99
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
100
+ i, conv_state_i = self.i_conv1d(
101
+ x=self.i_proj(hidden_states),
102
+ mask=conv_mask,
103
+ cache=conv_state_i,
104
+ output_final_state=use_cache,
105
+ cu_seqlens=cu_seqlens
106
+ )
107
+ f, conv_state_f = self.f_conv1d(
108
+ x=self.f_proj(hidden_states),
109
+ mask=conv_mask,
110
+ cache=conv_state_f,
111
+ output_final_state=use_cache,
112
+ cu_seqlens=cu_seqlens
113
+ )
114
+ else:
115
+ i = self.i_proj(hidden_states)
116
+ f = self.f_proj(hidden_states)
117
+
118
+ # the lower bound for the first layer is zero
119
+ if lower_bound is None or self.layer_idx == 0:
120
+ i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
121
+ else:
122
+ g = lower_bound + (1 - lower_bound) * f.sigmoid()
123
+ i, f = swiglu(i, 1 - g), g.log()
124
+
125
+ # dealing with left-padding
126
+ if attention_mask is not None:
127
+ i = i.mul_(attention_mask[:, -i.shape[-2]:, None])
128
+
129
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
130
+ if mode == 'chunk':
131
+ if cu_seqlens is not None:
132
+ raise NotImplementedError("Chunk mode does not support variable-length sequences.")
133
+ o, recurrent_state = chunk_hgrn(
134
+ x=i,
135
+ g=f,
136
+ initial_state=recurrent_state,
137
+ output_final_state=use_cache,
138
+ )
139
+ elif mode == 'fused_recurrent':
140
+ o, recurrent_state = fused_recurrent_hgrn(
141
+ x=i,
142
+ g=f,
143
+ initial_state=recurrent_state,
144
+ output_final_state=use_cache,
145
+ cu_seqlens=cu_seqlens
146
+ )
147
+ else:
148
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
149
+
150
+ if past_key_values is not None:
151
+ past_key_values.update(
152
+ recurrent_state=recurrent_state,
153
+ conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
154
+ layer_idx=self.layer_idx,
155
+ offset=i.shape[1]
156
+ )
157
+
158
+ o = self.g_norm(o, self.g_proj(hidden_states))
159
+ o = self.o_proj(o)
160
+
161
+ return o, None, past_key_values
162
+
163
+ def state_size(self, **kwargs) -> int:
164
+ state_size = self.hidden_size
165
+ for module in self.children():
166
+ if isinstance(module, ShortConvolution):
167
+ state_size += module.state_size
168
+ return state_size
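A sketch of using HGRNAttention with the per-layer lower bound; the shape of `lower_bound` (a vector of size `input_dim` that broadcasts against the forget gate) is an assumption, and a CUDA device is assumed for the kernels.

import torch
from fla.layers.hgrn import HGRNAttention

layer = HGRNAttention(hidden_size=512, expand_ratio=1, layer_idx=1).to('cuda', torch.bfloat16)
x = torch.randn(2, 256, 512, device='cuda', dtype=torch.bfloat16)
# layers with layer_idx > 0 receive a lower bound on the forget gate (the first layer uses zero)
lower_bound = torch.rand(512, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x, lower_bound=lower_bound)
print(o.shape)  # torch.Size([2, 256, 512])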
fla/layers/rebased.py ADDED
@@ -0,0 +1,133 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from einops import rearrange
15
+
16
+ from fla.modules.feature_map import RebasedFeatureMap
17
+ from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
18
+ from fla.ops.rebased import parallel_rebased
19
+
20
+
21
+ class ReBasedLinearAttention(nn.Module):
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size: int,
26
+ l_max: int = 2048,
27
+ feature_dim: int = 16,
28
+ num_key_value_heads: int = 16,
29
+ num_heads: int = 16,
30
+ use_gamma: Optional[bool] = True,
31
+ use_beta: Optional[bool] = True,
32
+ normalize: Optional[bool] = True,
33
+ causal: bool = True,
34
+ eps: float = 1e-5,
35
+ mode: str = "parallel",
36
+ layer_idx: Optional[int] = None,
37
+ **kwargs
38
+ ) -> ReBasedLinearAttention:
39
+ super().__init__()
40
+ self.hidden_size = hidden_size
41
+ self.l_max = l_max
42
+ self.mode = mode
43
+ assert self.mode in ["fused_chunk", "parallel", 'chunk']
44
+
45
+ self.feature_dim = feature_dim
46
+ self.num_key_value_heads = num_key_value_heads
47
+ self.num_heads = num_heads
48
+ self.head_dim = self.hidden_size // self.num_key_value_heads
49
+ self.use_gamma = use_gamma
50
+ self.use_beta = use_beta
51
+ self.normalize = normalize
52
+ self.causal = causal
53
+ self.eps = eps
54
+ self.mode = mode
55
+ self.layer_idx = layer_idx
56
+
57
+ self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
58
+ self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
59
+ self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
60
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
61
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
62
+ self.dropout = nn.Identity()
63
+
64
+ def forward(self, hidden_states: torch.Tensor, **kwargs):
65
+ mode = self.mode
66
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
67
+ q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
68
+ q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
69
+ if mode == "fused_chunk":
70
+ o = fused_chunk_linear_attn(
71
+ q=q,
72
+ k=k,
73
+ v=v,
74
+ normalize=True,
75
+ scale=1,
76
+ head_first=False
77
+ )
78
+ elif mode == 'chunk':
79
+ o = chunk_linear_attn(
80
+ q=q,
81
+ k=k,
82
+ v=v,
83
+ normalize=True,
84
+ scale=1,
85
+ head_first=False
86
+ )
87
+ elif mode == 'parallel':
88
+ assert q.shape[-1] <= 128
89
+ o = parallel_rebased(
90
+ q=q,
91
+ k=k,
92
+ v=v,
93
+ eps=self.eps,
94
+ use_scale=True,
95
+ use_normalize=True,
96
+ head_first=False
97
+ )
98
+ o = self.o_proj(o)
99
+ o = self.dropout(o)
100
+ return o
101
+
102
+ # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
103
+ def forward_reference(
104
+ self,
105
+ hidden_states: torch.Tensor,
106
+ filters: torch.Tensor = None,
107
+ *args,
108
+ **kwargs
109
+ ):
110
+ """
111
+ hidden_states (torch.Tensor): tensor of shape (b, t, d)
112
+ returns (torch.Tensor): tensor of shape (b, t, d)
113
+ """
114
+ b, t, _ = hidden_states.size()
115
+ q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
116
+
117
+ q = q.view(b, t, -1, self.feature_dim).transpose(1, 2)
118
+ k = k.view(b, t, -1, self.feature_dim).transpose(1, 2)
119
+ v = v.view(b, t, -1, self.head_dim).transpose(1, 2)
120
+
121
+ # Linear attention
122
+ q, k = self.feature_map(q), self.feature_map(k)
123
+ q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
124
+
125
+ # Compute attention
126
+ if self.causal:
127
+ y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
128
+ else:
129
+ y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
130
+ y = rearrange(y, 'b h t d -> b t (h d)')
131
+ y = self.o_proj(y.to(hidden_states.dtype))
132
+ y = self.dropout(y)
133
+ return y.to(hidden_states.dtype)
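A sketch of calling ReBasedLinearAttention in its parallel mode; unlike the cache-aware layers above, `forward` returns only the output tensor. The sizes below are chosen so that `feature_dim` equals the per-head dim, since the same rearrange is applied to q, k and v; a CUDA device is assumed.

import torch
from fla.layers.rebased import ReBasedLinearAttention

# hidden_size=256 with 16 kv heads gives head_dim=16 == feature_dim, so q/k/v
# end up with matching head counts after the shared rearrange in forward
layer = ReBasedLinearAttention(
    hidden_size=256, num_heads=16, num_key_value_heads=16,
    feature_dim=16, mode='parallel', layer_idx=0
).to('cuda', torch.bfloat16)
x = torch.randn(2, 512, 256, device='cuda', dtype=torch.bfloat16)
o = layer(x)
print(o.shape)  # torch.Size([2, 512, 256])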
fla/layers/rwkv6.py ADDED
@@ -0,0 +1,307 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892]
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import TYPE_CHECKING, Optional, Tuple
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from einops import rearrange
13
+
14
+ from fla.modules import GroupNorm
15
+ from fla.modules.activations import ACT2FN
16
+ from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
17
+
18
+ if TYPE_CHECKING:
19
+ from fla.models.utils import Cache
20
+
21
+
22
+ class RWKV6Attention(nn.Module):
23
+
24
+ def __init__(
25
+ self,
26
+ mode: str = 'chunk',
27
+ hidden_size: int = 1024,
28
+ expand_k: float = 0.5,
29
+ expand_v: float = 1.0,
30
+ num_heads: int = 4,
31
+ gate_fn: str = 'swish',
32
+ proj_low_rank_dim: int = 32,
33
+ gate_low_rank_dim: int = 64,
34
+ fuse_norm: bool = True,
35
+ elementwise_affine: Optional[bool] = True,
36
+ norm_eps: float = 1e-5,
37
+ layer_idx: int = None,
38
+ **kwargs
39
+ ) -> RWKV6Attention:
40
+ super().__init__()
41
+
42
+ self.mode = mode
43
+ self.hidden_size = hidden_size
44
+ self.expand_k = expand_k
45
+ self.expand_v = expand_v
46
+ self.num_heads = num_heads
47
+ self.proj_low_rank_dim = proj_low_rank_dim
48
+ self.gate_low_rank_dim = gate_low_rank_dim
49
+
50
+ self.key_dim = int(hidden_size * expand_k)
51
+ self.value_dim = int(hidden_size * expand_v)
52
+ self.layer_idx = layer_idx
53
+
54
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
55
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
56
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
57
+
58
+ self.head_k_dim = self.key_dim // num_heads
59
+ self.head_v_dim = self.value_dim // num_heads
60
+
61
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
62
+ self.x_proj = nn.Sequential(
63
+ LerpLinear(hidden_size, proj_low_rank_dim * 5),
64
+ nn.Tanh(),
65
+ nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False)
66
+ )
67
+ self.x_bias = nn.Parameter(torch.zeros(5, hidden_size))
68
+
69
+ self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
70
+ self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
71
+ self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
72
+ self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
73
+ self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
74
+ self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_k_dim))
75
+
76
+ # TODO: fuse GroupNorm and output gate
77
+ self.g_norm = GroupNorm(self.num_heads, self.value_dim, elementwise_affine=elementwise_affine, bias=True, eps=norm_eps)
78
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
79
+ self.gate_fn = ACT2FN[gate_fn]
80
+
81
+ self.apply(self._initialize_weights)
82
+
83
+ def _initialize_weights(self, module: nn.Module):
84
+ if getattr(module, "_is_hf_initialized", False):
85
+ return
86
+ if isinstance(module, nn.Linear):
87
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
88
+ if module.bias is not None:
89
+ nn.init.zeros_(module.bias)
90
+ if isinstance(module, nn.Parameter):
91
+ nn.init.xavier_uniform_(module, gain=2 ** -2.5)
92
+ module._is_hf_initialized = True
93
+
94
+ def forward(
95
+ self,
96
+ hidden_states: torch.Tensor,
97
+ attention_mask: Optional[torch.Tensor] = None,
98
+ past_key_values: Optional[Cache] = None,
99
+ use_cache: Optional[bool] = False,
100
+ output_attentions: Optional[bool] = False,
101
+ **kwargs
102
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
103
+ if attention_mask is not None:
104
+ assert len(attention_mask.shape) == 2, (
105
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
106
+ "for padding purposes (0 indicating padding). "
107
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
108
+ )
109
+
110
+ batch_size, seq_len, hidden_size = hidden_states.shape
111
+ # launching the triton kernel for just one token will actually be slower
112
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
113
+
114
+ last_state = None
115
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
116
+ last_state = past_key_values[self.layer_idx]
117
+
118
+ if attention_mask is not None:
119
+ hidden_states = hidden_states.mul_(attention_mask[:, -hidden_states.shape[-2]:, None])
120
+ if hidden_states.shape[1] == 1 and last_state is not None:
121
+ shifted = last_state['conv_state'].unsqueeze(1)
122
+ else:
123
+ shifted = self.time_shift(hidden_states)
124
+ if last_state is not None:
125
+ shifted[:, 0] = last_state['conv_state']
126
+
127
+ delta = shifted - hidden_states
128
+ x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
129
+ x = torch.einsum('b t n r, h n r-> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1))
130
+
131
+ r, w, k, v, g = x.add_(self.x_bias).unbind(-2)
132
+ r = self.r_proj(hidden_states, r, delta)
133
+ w = self.w_proj(hidden_states, w, delta)
134
+ k = self.k_proj(hidden_states, k, delta)
135
+ v = self.v_proj(hidden_states, v, delta)
136
+ g = self.g_proj(hidden_states, g, delta)
137
+
138
+ # dealing with left-padding
139
+ if attention_mask is not None:
140
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
141
+ r, w, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (r, w, k))
142
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
143
+ w = -torch.exp(w)
144
+ u = self.bonus
145
+
146
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
147
+ cu_seqlens = kwargs.get('cu_seqlens', None)
148
+ if mode == 'fused_recurrent':
149
+ o, recurrent_state = fused_recurrent_rwkv6(
150
+ r=r,
151
+ k=k,
152
+ v=v,
153
+ w=w,
154
+ u=u,
155
+ scale=1.,
156
+ initial_state=recurrent_state,
157
+ output_final_state=use_cache,
158
+ cu_seqlens=cu_seqlens,
159
+ head_first=False
160
+ )
161
+ elif mode == 'chunk':
162
+ o, recurrent_state = chunk_rwkv6(
163
+ q=r,
164
+ k=k,
165
+ v=v,
166
+ g=w,
167
+ u=u,
168
+ scale=1.,
169
+ initial_state=recurrent_state,
170
+ output_final_state=use_cache,
171
+ cu_seqlens=cu_seqlens,
172
+ head_first=False
173
+ )
174
+ else:
175
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
176
+
177
+ if past_key_values is not None:
178
+ past_key_values.update(
179
+ recurrent_state=recurrent_state,
180
+ conv_state=hidden_states[:, -1],
181
+ layer_idx=self.layer_idx,
182
+ offset=r.shape[1]
183
+ )
184
+
185
+ o = self.g_norm(rearrange(o, '... h d -> ... (h d)')) * self.gate_fn(g)
186
+ o = self.o_proj(o)
187
+
188
+ return o, None, past_key_values
189
+
190
+
191
+ class LoRA(nn.Module):
192
+
193
+ def __init__(
194
+ self,
195
+ input_dim: int,
196
+ output_dim: int,
197
+ low_rank_dim: int,
198
+ bias: Optional[bool] = True,
199
+ activation: Optional[str] = 'tanh'
200
+ ):
201
+ super().__init__()
202
+
203
+ self.input_dim = input_dim
204
+ self.output_dim = output_dim
205
+ self.low_rank_dim = low_rank_dim
206
+ self.bias = bias
207
+
208
+ if activation is None:
209
+ self.activation = nn.Identity()
210
+ elif activation == 'sigmoid':
211
+ self.activation = nn.Sigmoid()
212
+ elif activation == 'tanh':
213
+ self.activation = nn.Tanh()
214
+ elif activation == 'relu':
215
+ self.activation = nn.ReLU()
216
+ else:
217
+ raise ValueError(f"Not supported activation `{activation}`.")
218
+
219
+ self.lora = nn.Sequential(
220
+ nn.Linear(input_dim, low_rank_dim, bias=False),
221
+ self.activation,
222
+ nn.Linear(low_rank_dim, output_dim, bias=bias)
223
+ )
224
+
225
+ def __repr__(self) -> str:
226
+ s = f"{self.__class__.__name__}("
227
+ s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
228
+ if not self.bias:
229
+ s += f", bias={self.bias}"
230
+ s += ")"
231
+ return s
232
+
233
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
234
+ return self.lora(x)
235
+
236
+
237
+ class LerpLinear(nn.Module):
238
+
239
+ def __init__(
240
+ self,
241
+ input_dim: int,
242
+ output_dim: int,
243
+ low_rank_dim: Optional[int] = None
244
+ ):
245
+ super().__init__()
246
+
247
+ self.input_dim = input_dim
248
+ self.output_dim = output_dim
249
+ self.low_rank_dim = low_rank_dim
250
+
251
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
252
+ if low_rank_dim is None:
253
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
254
+ else:
255
+ self.linear = LoRA(input_dim, output_dim, low_rank_dim)
256
+ self.mu = nn.Parameter(torch.zeros(input_dim))
257
+
258
+ def __repr__(self) -> str:
259
+ s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
260
+ if self.low_rank_dim is not None:
261
+ s += f", low_rank_dim={self.low_rank_dim}"
262
+ s += ")"
263
+ return s
264
+
265
+ def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
266
+ if delta is None:
267
+ shifted = self.time_shift(x)
268
+ if len(shifted.shape) == 2:
269
+ shifted = shifted.unsqueeze(1)
270
+ delta = shifted - x
271
+ return self.linear(x + delta * self.mu)
272
+
273
+
274
+ class DDLerpLinear(nn.Module):
275
+
276
+ def __init__(
277
+ self,
278
+ input_dim: int,
279
+ output_dim: int,
280
+ low_rank_dim: Optional[int] = None
281
+ ):
282
+ super().__init__()
283
+
284
+ self.input_dim = input_dim
285
+ self.output_dim = output_dim
286
+ self.low_rank_dim = low_rank_dim
287
+
288
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
289
+ if low_rank_dim is None:
290
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
291
+ else:
292
+ self.linear = LoRA(input_dim, output_dim, low_rank_dim)
293
+
294
+ def __repr__(self) -> str:
295
+ s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
296
+ if self.low_rank_dim is not None:
297
+ s += f", low_rank_dim={self.low_rank_dim}"
298
+ s += ")"
299
+ return s
300
+
301
+ def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
302
+ if delta is None:
303
+ shifted = self.time_shift(x)
304
+ if len(shifted.shape) == 2:
305
+ shifted = shifted.unsqueeze(1)
306
+ delta = shifted - x
307
+ return self.linear(x + delta * mu)
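A sketch of a forward pass through RWKV6Attention with its default expansion ratios (expand_k=0.5, expand_v=1.0); the token-shift path is exercised implicitly, and a CUDA device is assumed for the chunk/fused kernels.

import torch
from fla.layers.rwkv6 import RWKV6Attention

layer = RWKV6Attention(hidden_size=1024, num_heads=4, layer_idx=0).to('cuda', torch.bfloat16)
x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)  # key_dim = 512, value_dim = 1024 with the defaults above
print(o.shape)  # torch.Size([2, 128, 1024])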
fla/layers/simple_gla.py ADDED
@@ -0,0 +1,261 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.modules.activations import ACT2FN
15
+ from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class SimpleGatedLinearAttention(nn.Module):
22
+ r"""
23
+ The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
24
+ This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.
25
+
26
+ Args:
27
+ mode (str, Optional):
28
+ Which GLA kernel to use.
29
+ Currently available: `chunk` and `fused_recurrent`.
30
+ Default: `chunk`.
31
+ hidden_size (int, Optional):
32
+ The hidden size of the input. Default: 1024.
33
+ expand_k (float, Optional):
34
+ The expansion ratio for the key dim. Default: 1.0.
35
+ expand_v (float, Optional):
36
+ The expansion ratio for the value dim. Default: 1.0.
37
+ num_heads (int, Optional):
38
+ The number of heads. Default: 4.
39
+ num_kv_heads (int, Optional):
40
+ The number of key/value heads, used for MQA. Default: None.
41
+ feature_map (str, Optional):
42
+ Feature map function applied to queries/keys. Default: None.
43
+ use_short_conv (bool, Optional):
44
+ Whether to use short convolutions. Default: `False`.
45
+ conv_size (int, Optional):
46
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
47
+ conv_bias (bool, Optional):
48
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
49
+ gate_fn (str, Optional):
50
+ The activation function for the output gate. Default: `swish`.
51
+ elementwise_affine (bool, Optional):
52
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
53
+ norm_eps (float, Optional):
54
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
55
+ gate_logit_normalizer (int, Optional):
56
+ The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
57
+ fuse_norm (bool, Optional):
58
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
59
+ layer_idx (int, Optional):
60
+ The index of the layer. Default: None.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ mode: str = 'chunk',
66
+ hidden_size: int = 1024,
67
+ expand_k: float = 1.,
68
+ expand_v: float = 1.,
69
+ num_heads: int = 4,
70
+ num_kv_heads: Optional[int] = None,
71
+ feature_map: Optional[str] = None,
72
+ use_short_conv: bool = True,
73
+ conv_size: int = 4,
74
+ conv_bias: bool = False,
75
+ gate_fn: str = 'swish',
76
+ elementwise_affine: Optional[bool] = True,
77
+ norm_eps: float = 1e-5,
78
+ gate_logit_normalizer: int = 16,
79
+ fuse_norm: bool = True,
80
+ layer_idx: int = None,
81
+ ) -> SimpleGatedLinearAttention:
82
+ super().__init__()
83
+
84
+ self.mode = mode
85
+ self.hidden_size = hidden_size
86
+ self.expand_k = expand_k
87
+ self.expand_v = expand_v
88
+ self.num_heads = num_heads
89
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
90
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
91
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
92
+
93
+ self.use_short_conv = use_short_conv
94
+ self.conv_size = conv_size
95
+ self.conv_bias = conv_bias
96
+
97
+ self.key_dim = int(hidden_size * expand_k)
98
+ self.value_dim = int(hidden_size * expand_v)
99
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
100
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
101
+ self.layer_idx = layer_idx
102
+
103
+ assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
104
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
105
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
106
+
107
+ self.head_k_dim = self.key_dim // num_heads
108
+ self.head_v_dim = self.value_dim // num_heads
109
+
110
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
111
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
112
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
113
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
114
+
115
+ if use_short_conv:
116
+ self.conv_size = conv_size
117
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
118
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
119
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
120
+
121
+ self.gk_proj = nn.Linear(hidden_size, self.num_heads)
122
+
123
+ if gate_fn == 'swish' and fuse_norm:
124
+ self.g_norm_swish_gate = FusedRMSNormGated(
125
+ hidden_size=self.head_v_dim,
126
+ elementwise_affine=elementwise_affine,
127
+ eps=norm_eps
128
+ )
129
+ self.fuse_norm_and_gate = True
130
+ else:
131
+ self.fuse_norm_and_gate = False
132
+ self.g_norm = RMSNorm(
133
+ hidden_size=self.head_v_dim,
134
+ elementwise_affine=elementwise_affine,
135
+ eps=norm_eps
136
+ )
137
+ self.gate_fn = ACT2FN[gate_fn]
138
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
139
+
140
+ self.gate_logit_normalizer = gate_logit_normalizer
141
+
142
+ def forward(
143
+ self,
144
+ hidden_states: torch.Tensor,
145
+ attention_mask: Optional[torch.Tensor] = None,
146
+ past_key_values: Optional[Cache] = None,
147
+ use_cache: Optional[bool] = False,
148
+ output_attentions: Optional[bool] = False,
149
+ **kwargs
150
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
151
+ if attention_mask is not None:
152
+ assert len(attention_mask.shape) == 2, (
153
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
154
+ "for padding purposes (0 indicating padding). "
155
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
156
+ )
157
+
158
+ # launching the triton kernel for just one token will actually be slower
159
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
160
+
161
+ last_state = None
162
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
163
+ last_state = past_key_values[self.layer_idx]
164
+
165
+ cu_seqlens = kwargs.get('cu_seqlens', None)
166
+ if self.use_short_conv:
167
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
168
+ if last_state is not None:
169
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
170
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
171
+ q, conv_state_q = self.q_conv1d(
172
+ x=self.q_proj(hidden_states),
173
+ mask=conv_mask,
174
+ cache=conv_state_q,
175
+ output_final_state=use_cache,
176
+ cu_seqlens=cu_seqlens
177
+ )
178
+ k, conv_state_k = self.k_conv1d(
179
+ x=self.k_proj(hidden_states),
180
+ mask=conv_mask,
181
+ cache=conv_state_k,
182
+ output_final_state=use_cache,
183
+ cu_seqlens=cu_seqlens
184
+ )
185
+ v, conv_state_v = self.v_conv1d(
186
+ x=self.v_proj(hidden_states),
187
+ mask=conv_mask,
188
+ cache=conv_state_v,
189
+ output_final_state=use_cache,
190
+ cu_seqlens=cu_seqlens
191
+ )
192
+ else:
193
+ q = self.q_proj(hidden_states)
194
+ k = self.k_proj(hidden_states)
195
+ v = self.v_proj(hidden_states)
196
+ gk = self.gk_proj(hidden_states)
197
+
198
+ if self.feature_map_fn is not None:
199
+ q, k = map(self.feature_map_fn, (q, k))
200
+ # dealing with left-padding
201
+ if attention_mask is not None:
202
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
203
+ q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
204
+ if self.num_kv_groups > 1:
205
+ k, v = (repeat(x, '... (h d) -> ... (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
206
+ else:
207
+ k, v = (rearrange(x, '... (h d) -> ... h d', h=self.num_kv_heads) for x in (k, v))
208
+ gk = F.logsigmoid(gk) / self.gate_logit_normalizer
209
+
210
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
211
+ if mode == 'chunk':
212
+ o, recurrent_state = chunk_simple_gla(
213
+ q=q,
214
+ k=k,
215
+ v=v,
216
+ gk=gk,
217
+ initial_state=recurrent_state,
218
+ output_final_state=use_cache,
219
+ cu_seqlens=cu_seqlens,
220
+ head_first=False
221
+ )
222
+ elif mode == 'fused_recurrent':
223
+ o, recurrent_state = fused_recurrent_simple_gla(
224
+ q=q,
225
+ k=k,
226
+ v=v,
227
+ gk=gk,
228
+ initial_state=recurrent_state,
229
+ output_final_state=use_cache,
230
+ cu_seqlens=cu_seqlens,
231
+ head_first=False
232
+ )
233
+ else:
234
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
235
+
236
+ if past_key_values is not None:
237
+ past_key_values.update(
238
+ recurrent_state=recurrent_state,
239
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
240
+ layer_idx=self.layer_idx,
241
+ offset=q.shape[1]
242
+ )
243
+
244
+ g = self.g_proj(hidden_states)
245
+ if self.fuse_norm_and_gate:
246
+ g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads)
247
+ o = self.g_norm_swish_gate(o, g)
248
+ o = rearrange(o, 'b t h d -> b t (h d)')
249
+ else:
250
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
251
+ o = o * self.gate_fn(g)
252
+ o = self.o_proj(o)
253
+
254
+ return o, None, past_key_values
255
+
256
+ def state_size(self, **kwargs) -> int:
257
+ state_size = self.key_dim * self.head_v_dim
258
+ for module in self.children():
259
+ if isinstance(module, ShortConvolution):
260
+ state_size += module.state_size
261
+ return state_size
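A sketch matching the docstring above: head-wise gating with short convolutions enabled; a CUDA device is assumed for the Triton kernels.

import torch
from fla.layers.simple_gla import SimpleGatedLinearAttention

layer = SimpleGatedLinearAttention(hidden_size=1024, num_heads=4, use_short_conv=True, layer_idx=0).to('cuda', torch.bfloat16)
x = torch.randn(2, 256, 1024, device='cuda', dtype=torch.bfloat16)
o, _, _ = layer(x)  # gk_proj yields one gate logit per head, i.e. shape (2, 256, num_heads)
print(o.shape)  # torch.Size([2, 256, 1024])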
fla/models/abc/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.abc.configuration_abc import ABCConfig
6
+ from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel
7
+
8
+ AutoConfig.register(ABCConfig.model_type, ABCConfig)
9
+ AutoModel.register(ABCConfig, ABCModel)
10
+ AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM)
11
+
12
+
13
+ __all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel']
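A sketch of what the registrations above enable: importing `fla.models.abc` lets the `Auto*` factories resolve the ABC classes. The config argument names are assumptions based on how modeling_abc.py reads the config.

from transformers import AutoModel, AutoModelForCausalLM

from fla.models.abc import ABCConfig  # importing the package runs the Auto* registrations above

config = ABCConfig(hidden_size=512, num_hidden_layers=2, num_heads=4)  # argument names assumed
model = AutoModelForCausalLM.from_config(config)  # resolves to ABCForCausalLM via the registry
backbone = AutoModel.from_config(config)          # resolves to ABCModel
print(type(model).__name__, type(backbone).__name__)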
fla/models/abc/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (714 Bytes).
 
fla/models/abc/modeling_abc.py ADDED
@@ -0,0 +1,418 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.abc import ABCAttention
19
+ from fla.layers.attn import Attention
20
+ from fla.models.abc.configuration_abc import ABCConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as ABCMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ if TYPE_CHECKING:
29
+ from transformers.processing_utils import Unpack
30
+
31
+
32
+ class ABCBlock(nn.Module):
33
+ def __init__(self, config: ABCConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = ABCAttention(
53
+ hidden_size=config.hidden_size,
54
+ expand_k=config.expand_k,
55
+ expand_v=config.expand_v,
56
+ num_heads=config.num_heads,
57
+ num_slots=config.num_slots,
58
+ use_short_conv=config.use_short_conv,
59
+ conv_size=config.conv_size,
60
+ gate_fn=config.hidden_act,
61
+ elementwise_affine=config.elementwise_affine,
62
+ norm_eps=config.norm_eps,
63
+ use_rope=config.use_rope,
64
+ clamp_min=config.clamp_min,
65
+ clamp_max=config.clamp_max,
66
+ fuse_norm=config.fuse_norm,
67
+ layer_idx=layer_idx
68
+ )
69
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
70
+ self.mlp = ABCMLP(
71
+ hidden_size=config.hidden_size,
72
+ hidden_ratio=config.hidden_ratio,
73
+ intermediate_size=config.intermediate_size,
74
+ hidden_act=config.hidden_act,
75
+ fuse_swiglu=config.fuse_swiglu
76
+ )
77
+
78
+ def forward(
79
+ self,
80
+ hidden_states: torch.Tensor,
81
+ attention_mask: Optional[torch.Tensor] = None,
82
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
83
+ use_cache: Optional[bool] = False,
84
+ output_attentions: Optional[bool] = False,
85
+ **kwargs: Unpack[Dict]
86
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
87
+
88
+ residual = hidden_states
89
+
90
+ hidden_states = self.attn_norm(hidden_states)
91
+ hidden_states, attentions, past_key_values = self.attn(
92
+ hidden_states=hidden_states,
93
+ attention_mask=attention_mask,
94
+ past_key_values=past_key_values,
95
+ use_cache=use_cache,
96
+ output_attentions=output_attentions,
97
+ **kwargs
98
+ )
99
+ if self.config.fuse_norm:
100
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
101
+ else:
102
+ hidden_states = residual + hidden_states
103
+ residual = hidden_states
104
+ hidden_states = self.mlp_norm(hidden_states)
105
+ hidden_states = self.mlp(hidden_states)
106
+ hidden_states = residual + hidden_states
107
+
108
+ outputs = (hidden_states, attentions, past_key_values)
109
+
110
+ return outputs
111
+
112
+
113
+ class ABCPreTrainedModel(PreTrainedModel):
114
+
115
+ config_class = ABCConfig
116
+ base_model_prefix = 'model'
117
+ supports_gradient_checkpointing = True
118
+ _no_split_modules = ['ABCBlock']
119
+ _supports_cache_class = True
120
+
121
+ def __init__(self, *inputs, **kwargs):
122
+ super().__init__(*inputs, **kwargs)
123
+
124
+ def _init_weights(
125
+ self,
126
+ module: nn.Module,
127
+ prenorm_residual_strategy: Optional[str] = 'rescale',
128
+ num_residuals_per_layer: int = 2,
129
+ ):
130
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
131
+ # Slightly different from the TF version which uses truncated_normal for initialization
132
+ # cf https://github.com/pytorch/pytorch/pull/5617
133
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
134
+ if module.bias is not None:
135
+ nn.init.zeros_(module.bias)
136
+ elif isinstance(module, nn.Embedding):
137
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
138
+ elif hasattr(module, 'reset_parameters'):
139
+ module.reset_parameters()
140
+
141
+ if prenorm_residual_strategy is not None:
142
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
143
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
144
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
145
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
146
+ #
147
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
148
+ p = None
149
+ if hasattr(module, 'o_proj'):
150
+ p = module.o_proj.weight
151
+ elif hasattr(module, 'down_proj'):
152
+ p = module.down_proj.weight
153
+ if p is not None:
154
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
155
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
156
+ # We need to reinit p since this code could be called multiple times
157
+ # Having just p *= scale would repeatedly scale it down
158
+ if prenorm_residual_strategy == 'rescale':
159
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
160
+ with torch.no_grad():
161
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
162
+ elif prenorm_residual_strategy == 'zero':
163
+ nn.init.zeros_(p)
164
+ else:
165
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
166
+
167
+
168
+ class ABCModel(ABCPreTrainedModel):
169
+
170
+ def __init__(self, config: ABCConfig):
171
+ super().__init__(config)
172
+ self.padding_idx = config.pad_token_id
173
+ self.vocab_size = config.vocab_size
174
+
175
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
176
+ self.layers = nn.ModuleList([ABCBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
177
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
178
+
179
+ self.gradient_checkpointing = False
180
+
181
+ self.post_init()
182
+
183
+ def get_input_embeddings(self):
184
+ return self.embeddings
185
+
186
+ def set_input_embeddings(self, value):
187
+ self.embeddings = value
188
+
189
+ def forward(
190
+ self,
191
+ input_ids: Optional[torch.LongTensor] = None,
192
+ attention_mask: Optional[torch.Tensor] = None, # noqa
193
+ inputs_embeds: Optional[torch.FloatTensor] = None,
194
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
195
+ use_cache: Optional[bool] = None,
196
+ output_attentions: Optional[bool] = None,
197
+ output_hidden_states: Optional[bool] = None,
198
+ return_dict: Optional[bool] = None,
199
+ **kwargs: Unpack[Dict]
200
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
201
+ if output_attentions:
202
+ warnings.warn("`ABCModel` does not `output_attentions` now, setting it to `False`.")
203
+ output_attentions = False
204
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
205
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
206
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
207
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
208
+
209
+ # retrieve input_ids and inputs_embeds
210
+ if input_ids is not None and inputs_embeds is not None:
211
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
212
+ if input_ids is None and inputs_embeds is None:
213
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
214
+
215
+ if inputs_embeds is None:
216
+ inputs_embeds = self.embeddings(input_ids)
217
+ hidden_states = inputs_embeds
218
+
219
+ if use_cache and not isinstance(past_key_values, Cache):
220
+ past_key_values = Cache.from_legacy_cache(past_key_values)
221
+
222
+ if self.gradient_checkpointing and self.training and use_cache:
223
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
224
+ use_cache = False
225
+
226
+ all_hidden_states = () if output_hidden_states else None
227
+ all_attns = () if output_attentions else None
228
+ for layer in self.layers:
229
+ if output_hidden_states:
230
+ all_hidden_states += (hidden_states,)
231
+
232
+ if self.gradient_checkpointing and self.training:
233
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
234
+ layer.__call__,
235
+ hidden_states,
236
+ attention_mask,
237
+ past_key_values,
238
+ use_cache,
239
+ output_attentions,
240
+ **kwargs
241
+ )
242
+ else:
243
+ hidden_states, attentions, past_key_values = layer(
244
+ hidden_states,
245
+ attention_mask,
246
+ past_key_values=past_key_values,
247
+ use_cache=use_cache,
248
+ output_attentions=output_attentions,
249
+ **kwargs
250
+ )
251
+
252
+ if output_attentions:
253
+ all_attns += (attentions,)
254
+
255
+ hidden_states = self.norm(hidden_states)
256
+
257
+ # add hidden states from the last decoder layer
258
+ if output_hidden_states:
259
+ all_hidden_states += (hidden_states,)
260
+
261
+ if not return_dict:
262
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
263
+ return BaseModelOutputWithPast(
264
+ last_hidden_state=hidden_states,
265
+ past_key_values=past_key_values,
266
+ hidden_states=all_hidden_states,
267
+ attentions=all_attns
268
+ )
269
+
270
+
271
+ class ABCForCausalLM(ABCPreTrainedModel, GenerationMixin):
272
+
273
+ _tied_weights_keys = ["lm_head.weight"]
274
+
275
+ def __init__(self, config):
276
+ super().__init__(config)
277
+ self.model = ABCModel(config)
278
+ self.vocab_size = config.vocab_size
279
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
280
+ self.criterion = None
281
+
282
+ # Initialize weights and apply final processing
283
+ self.post_init()
284
+
285
+ def get_input_embeddings(self):
286
+ return self.model.embeddings
287
+
288
+ def set_input_embeddings(self, value):
289
+ self.model.embeddings = value
290
+
291
+ def get_output_embeddings(self):
292
+ return self.lm_head
293
+
294
+ def set_output_embeddings(self, new_embeddings):
295
+ self.lm_head = new_embeddings
296
+
297
+ def set_decoder(self, decoder):
298
+ self.model = decoder
299
+
300
+ def get_decoder(self):
301
+ return self.model
302
+
303
+ def generate(self, *args, **kwargs):
304
+ try:
305
+ return super().generate(*args, **kwargs)
306
+ except AttributeError as exception:
307
+ if 'past_key_values' in str(exception):
308
+ raise AttributeError(
309
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
310
+ f"which is not supported for {self.__class__.__name__}. "
311
+ f"Try another generation strategy instead. "
312
+ f"For the available generation strategies, check this doc: "
313
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
314
+ )
315
+ else:
316
+ raise exception
317
+
318
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
319
+ def prepare_inputs_for_generation(
320
+ self,
321
+ input_ids: torch.LongTensor = None,
322
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
323
+ attention_mask: Optional[torch.Tensor] = None,
324
+ inputs_embeds: Optional[torch.Tensor] = None,
325
+ use_cache: bool = True,
326
+ logits_to_keep: Optional[int] = None,
327
+ **kwargs
328
+ ):
329
+ # only keep the last token of `input_ids` if `past_key_values` is not empty.
330
+ if past_key_values is not None and len(past_key_values) > 0:
331
+ input_ids = input_ids[:, -1:]
332
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
333
+ if inputs_embeds is not None and len(past_key_values) == 0:
334
+ model_inputs = {'inputs_embeds': inputs_embeds}
335
+ else:
336
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
337
+ # recompiles graphs as the stride of the inputs is a guard.
338
+ # Ref: https://github.com/huggingface/transformers/pull/29114
339
+ # TODO: use `next_tokens` directly instead.
340
+ model_inputs = {'input_ids': input_ids.contiguous()}
341
+
342
+ if logits_to_keep is not None:
343
+ model_inputs['logits_to_keep'] = logits_to_keep
344
+
345
+ model_inputs.update({
346
+ 'past_key_values': past_key_values,
347
+ 'use_cache': use_cache,
348
+ 'attention_mask': attention_mask,
349
+ })
350
+ return model_inputs
351
+
352
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
353
+ def forward(
354
+ self,
355
+ input_ids: torch.LongTensor = None,
356
+ attention_mask: Optional[torch.Tensor] = None,
357
+ inputs_embeds: Optional[torch.Tensor] = None,
358
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
359
+ labels: Optional[torch.LongTensor] = None,
360
+ use_cache: Optional[bool] = None,
361
+ output_attentions: Optional[bool] = None,
362
+ output_hidden_states: Optional[bool] = None,
363
+ return_dict: Optional[bool] = None,
364
+ logits_to_keep: Optional[int] = 0,
365
+ **kwargs: Unpack[Dict]
366
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
367
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
368
+ output_hidden_states = (
369
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
370
+ )
371
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
372
+
373
+ outputs = self.model(
374
+ input_ids=input_ids,
375
+ attention_mask=attention_mask,
376
+ inputs_embeds=inputs_embeds,
377
+ past_key_values=past_key_values,
378
+ use_cache=use_cache,
379
+ output_attentions=output_attentions,
380
+ output_hidden_states=output_hidden_states,
381
+ return_dict=return_dict,
382
+ **kwargs
383
+ )
384
+
385
+ hidden_states = outputs[0]
386
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
387
+
388
+ loss, logits = None, None
389
+ if not fuse_linear_and_cross_entropy or labels is None:
390
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
391
+ if labels is not None:
392
+ if getattr(self, 'criterion', None) is None:
393
+ if fuse_linear_and_cross_entropy:
394
+ criterion = FusedLinearCrossEntropyLoss()
395
+ elif self.config.fuse_cross_entropy:
396
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
397
+ else:
398
+ criterion = nn.CrossEntropyLoss()
399
+ else:
400
+ criterion = self.criterion
401
+ labels = labels.to(hidden_states.device)
402
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
403
+ if fuse_linear_and_cross_entropy:
404
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
405
+ else:
406
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
407
+
408
+ if not return_dict:
409
+ output = (logits,) + outputs[1:]
410
+ return (loss,) + output if loss is not None else output
411
+
412
+ return CausalLMOutputWithPast(
413
+ loss=loss,
414
+ logits=logits,
415
+ past_key_values=outputs.past_key_values,
416
+ hidden_states=outputs.hidden_states,
417
+ attentions=outputs.attentions,
418
+ )
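
Note that the loss path above does not slice the logits; instead the labels are shifted one step to the left and the final slot is filled with the criterion's `ignore_index`, so the logits at position t are scored against token t + 1 (and `logits_to_keep=0` keeps every position, since `hidden_states[:, -0:]` is the full sequence). A minimal sketch of the shift, assuming the usual `ignore_index` of -100:

import torch

ignore_index = -100  # nn.CrossEntropyLoss's default; the fused losses are assumed to match
labels = torch.tensor([[10, 11, 12, 13]])
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
print(shifted)  # tensor([[  11,   12,   13, -100]])
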
fla/models/bitnet/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.bitnet.configuration_bitnet import BitNetConfig
6
+ from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM, BitNetModel
7
+
8
+ AutoConfig.register(BitNetConfig.model_type, BitNetConfig)
9
+ AutoModel.register(BitNetConfig, BitNetModel)
10
+ AutoModelForCausalLM.register(BitNetConfig, BitNetForCausalLM)
11
+
12
+
13
+ __all__ = ['BitNetConfig', 'BitNetForCausalLM', 'BitNetModel']
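
Registering the config and model classes with the `Auto*` factories lets them be constructed (or later loaded with `from_pretrained`) through the standard transformers entry points. A minimal sketch of what this enables; the sizes below are illustrative and the weights are randomly initialized:

from transformers import AutoModelForCausalLM
from fla.models.bitnet import BitNetConfig

config = BitNetConfig(hidden_size=512, num_hidden_layers=2, num_heads=8, vocab_size=32000)
model = AutoModelForCausalLM.from_config(config)  # dispatches to BitNetForCausalLM via the registry above
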
fla/models/bitnet/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (739 Bytes).
 
fla/models/bitnet/__pycache__/configuration_bitnet.cpython-311.pyc ADDED
Binary file (2.64 kB).
 
fla/models/bitnet/__pycache__/modeling_bitnet.cpython-311.pyc ADDED
Binary file (19.6 kB).
 
fla/models/bitnet/configuration_bitnet.py ADDED
@@ -0,0 +1,67 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class BitNetConfig(PretrainedConfig):
9
+
10
+ model_type = 'bitnet'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ num_hidden_layers: int = 24,
17
+ num_heads: int = 32,
18
+ num_kv_heads: int = None,
19
+ window_size: Optional[int] = None,
20
+ rope_theta: Optional[float] = 10000.,
21
+ max_position_embeddings: int = 2048,
22
+ hidden_ratio: Optional[int] = 4,
23
+ intermediate_size: Optional[int] = None,
24
+ hidden_act: str = "swish",
25
+ initializer_range: float = 0.006,
26
+ elementwise_affine: Optional[bool] = True,
27
+ norm_eps: float = 1e-6,
28
+ use_cache: bool = True,
29
+ pad_token_id: int = None,
30
+ bos_token_id: int = 1,
31
+ eos_token_id: int = 2,
32
+ tie_word_embeddings: bool = False,
33
+ fuse_norm: bool = True,
34
+ fuse_swiglu: bool = True,
35
+ fuse_cross_entropy: bool = True,
36
+ vocab_size: int = 32000,
37
+ **kwargs,
38
+ ):
39
+ self.hidden_size = hidden_size
40
+ self.num_hidden_layers = num_hidden_layers
41
+ self.num_heads = num_heads
42
+ self.num_kv_heads = num_kv_heads
43
+ self.window_size = window_size
44
+ self.rope_theta = rope_theta
45
+ self.max_position_embeddings = max_position_embeddings
46
+
47
+ self.hidden_ratio = hidden_ratio
48
+ self.intermediate_size = intermediate_size
49
+ self.hidden_act = hidden_act
50
+
51
+ self.initializer_range = initializer_range
52
+ self.elementwise_affine = elementwise_affine
53
+ self.norm_eps = norm_eps
54
+ self.use_cache = use_cache
55
+
56
+ self.fuse_norm = fuse_norm
57
+ self.fuse_swiglu = fuse_swiglu
58
+ self.fuse_cross_entropy = fuse_cross_entropy
59
+ self.vocab_size = vocab_size
60
+
61
+ super().__init__(
62
+ pad_token_id=pad_token_id,
63
+ bos_token_id=bos_token_id,
64
+ eos_token_id=eos_token_id,
65
+ tie_word_embeddings=tie_word_embeddings,
66
+ **kwargs,
67
+ )
fla/models/delta_net/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
6
+ from fla.models.delta_net.modeling_delta_net import DeltaNetForCausalLM, DeltaNetModel
7
+
8
+ AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig)
9
+ AutoModel.register(DeltaNetConfig, DeltaNetModel)
10
+ AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)
11
+
12
+ __all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel']
fla/models/forgetting_transformer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (888 Bytes).
 
fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-311.pyc ADDED
Binary file (18.1 kB).
 
fla/models/forgetting_transformer/configuration_forgetting_transformer.py ADDED
@@ -0,0 +1,68 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class ForgettingTransformerConfig(PretrainedConfig):
9
+
10
+ model_type = 'forgetting_transformer'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ num_hidden_layers: int = 24,
17
+ num_heads: int = 32,
18
+ num_kv_heads: Optional[int] = None,
19
+ qkv_bias: bool = False,
20
+ qk_norm: bool = False,
21
+ window_size: Optional[int] = None,
22
+ use_output_gate: bool = False,
23
+ hidden_ratio: Optional[int] = 4,
24
+ intermediate_size: Optional[int] = None,
25
+ hidden_act: str = "swish",
26
+ initializer_range: float = 0.006,
27
+ elementwise_affine: Optional[bool] = True,
28
+ norm_eps: float = 1e-6,
29
+ use_cache: bool = True,
30
+ pad_token_id: Optional[int] = None,
31
+ bos_token_id: int = 1,
32
+ eos_token_id: int = 2,
33
+ tie_word_embeddings: bool = False,
34
+ fuse_norm: bool = True,
35
+ fuse_swiglu: bool = True,
36
+ fuse_cross_entropy: bool = True,
37
+ vocab_size: int = 32000,
38
+ **kwargs,
39
+ ):
40
+ self.hidden_size = hidden_size
41
+ self.num_hidden_layers = num_hidden_layers
42
+ self.num_heads = num_heads
43
+ self.num_kv_heads = num_kv_heads
44
+ self.qkv_bias = qkv_bias
45
+ self.qk_norm = qk_norm
46
+ self.window_size = window_size
47
+ self.use_output_gate = use_output_gate
48
+ self.hidden_ratio = hidden_ratio
49
+ self.intermediate_size = intermediate_size
50
+ self.hidden_act = hidden_act
51
+
52
+ self.initializer_range = initializer_range
53
+ self.elementwise_affine = elementwise_affine
54
+ self.norm_eps = norm_eps
55
+ self.use_cache = use_cache
56
+
57
+ self.fuse_norm = fuse_norm
58
+ self.fuse_swiglu = fuse_swiglu
59
+ self.fuse_cross_entropy = fuse_cross_entropy
60
+ self.vocab_size = vocab_size
61
+
62
+ super().__init__(
63
+ pad_token_id=pad_token_id,
64
+ bos_token_id=bos_token_id,
65
+ eos_token_id=eos_token_id,
66
+ tie_word_embeddings=tie_word_embeddings,
67
+ **kwargs,
68
+ )
fla/models/gated_deltanet/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (803 Bytes).
 
fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-311.pyc ADDED
Binary file (3.75 kB).
 
fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py ADDED
@@ -0,0 +1,90 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class GatedDeltaProductConfig(PretrainedConfig):
9
+ model_type = "gated_deltaproduct"
10
+ keys_to_ignore_at_inference = ["past_key_values"]
11
+
12
+ def __init__(
13
+ self,
14
+ attn_mode: str = "chunk",
15
+ hidden_size: int = 2048,
16
+ expand_v: int = 2,
17
+ use_gate: bool = True,
18
+ use_short_conv: bool = True,
19
+ conv_size: int = 4,
20
+ head_dim: int = 256,
21
+ num_heads: int = 6,
22
+ max_position_embeddings: int = 2048,
23
+ hidden_ratio: Optional[int] = 4,
24
+ intermediate_size: Optional[int] = None,
25
+ hidden_act: str = "swish",
26
+ num_hidden_layers: int = 21,
27
+ norm_first: bool = False,
28
+ norm_eps: float = 1e-6,
29
+ attn: Optional[Dict] = None,
30
+ use_cache: bool = True,
31
+ pad_token_id: Optional[int] = None,
32
+ bos_token_id: int = 1,
33
+ eos_token_id: int = 2,
34
+ tie_word_embeddings: bool = False,
35
+ initializer_range: float = 0.006,
36
+ fuse_cross_entropy: bool = True,
37
+ vocab_size: int = 32000,
38
+ use_forget_gate: bool = False, # when true Gated DeltaProduct, when false DeltaProduct
39
+ allow_neg_eigval: bool = False, # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1]
40
+ num_householder: int = 1,
41
+ **kwargs,
42
+ ):
43
+ self.attn_mode = attn_mode
44
+ self.hidden_size = hidden_size
45
+ self.expand_v = expand_v
46
+ self.use_gate = use_gate
47
+ self.use_short_conv = use_short_conv
48
+ self.conv_size = conv_size
49
+ self.head_dim = head_dim
50
+ self.num_heads = num_heads
51
+ self.max_position_embeddings = max_position_embeddings
52
+
53
+ self.hidden_ratio = hidden_ratio
54
+ self.intermediate_size = intermediate_size
55
+ self.hidden_act = hidden_act
56
+ self.num_hidden_layers = num_hidden_layers
57
+ self.norm_first = norm_first
58
+ self.norm_eps = norm_eps
59
+ self.attn = attn
60
+ self.use_cache = use_cache
61
+ self.initializer_range = initializer_range
62
+ self.fuse_cross_entropy = fuse_cross_entropy
63
+ self.vocab_size = vocab_size
64
+
65
+ # DeltaProduct specific
66
+ self.allow_neg_eigval = allow_neg_eigval
67
+ self.num_householder = num_householder
68
+ self.use_forget_gate = use_forget_gate
69
+
70
+ if attn is not None:
71
+ if not isinstance(attn, Dict):
72
+ raise ValueError("attn must be a dictionary")
73
+ if "layers" not in attn:
74
+ raise ValueError(
75
+ "Layer indices must be provided to initialize hybrid attention layers"
76
+ )
77
+ if "num_heads" not in attn:
78
+ raise ValueError(
79
+ "Number of heads must be provided to initialize hybrid attention layers"
80
+ )
81
+ attn["num_kv_heads"] = attn.get("num_kv_heads", attn["num_heads"])
82
+ attn["window_size"] = attn.get("window_size", None)
83
+
84
+ super().__init__(
85
+ pad_token_id=pad_token_id,
86
+ bos_token_id=bos_token_id,
87
+ eos_token_id=eos_token_id,
88
+ tie_word_embeddings=tie_word_embeddings,
89
+ **kwargs,
90
+ )
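
Per the flags above, `use_forget_gate` switches between plain DeltaProduct and Gated DeltaProduct, `allow_neg_eigval` widens the per-step eigenvalue range from [0, 1] to [-1, 1], `num_householder` sets the number of Householder steps applied per token, and a non-None `attn` dict turns the listed layers into standard softmax attention. An illustrative sketch (all sizes are placeholders):

from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig

# DeltaProduct: no forget gate, two Householder steps per token, eigenvalues in [-1, 1]
delta_product = GatedDeltaProductConfig(
    hidden_size=1024,
    num_heads=4,
    head_dim=256,
    use_forget_gate=False,
    allow_neg_eigval=True,
    num_householder=2,
)

# Hybrid Gated DeltaProduct: layers 2 and 5 use softmax attention instead
hybrid = GatedDeltaProductConfig(
    use_forget_gate=True,
    attn={'layers': [2, 5], 'num_heads': 8},  # num_kv_heads defaults to num_heads, window_size to None
)
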
fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py ADDED
@@ -0,0 +1,520 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.activations import ACT2FN
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.utils import logging
17
+ from transformers.utils.deprecation import deprecate_kwarg
18
+
19
+ from fla.layers.attn import Attention
20
+ from fla.layers.gated_deltaproduct import GatedDeltaProduct
21
+ from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
22
+ from fla.models.utils import Cache
23
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
24
+ from fla.modules.activations import swiglu_linear
25
+ from fla.modules.layernorm import rms_norm_linear
26
+
27
+ if TYPE_CHECKING:
28
+ from transformers.processing_utils import Unpack
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class GatedDeltaNetMLP(nn.Module):
34
+ def __init__(
35
+ self,
36
+ hidden_size: int,
37
+ hidden_ratio: Optional[int] = None,
38
+ intermediate_size: Optional[int] = None,
39
+ hidden_act: str = "swish",
40
+ norm_first: bool = True,
41
+ norm_eps: float = 1e-5,
42
+ ) -> None:
43
+ super().__init__()
44
+
45
+ self.hidden_size = hidden_size
46
+ # the final number of params is `hidden_ratio * hidden_size^2`
47
+ # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
48
+ if hidden_ratio is None:
49
+ hidden_ratio = 4
50
+ if intermediate_size is None:
51
+ intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
52
+ intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
53
+ self.hidden_ratio = hidden_ratio
54
+ self.intermediate_size = intermediate_size
55
+ self.norm_first = norm_first
56
+
57
+ if norm_first:
58
+ self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)
59
+
60
+ self.gate_proj = nn.Linear(
61
+ self.hidden_size, self.intermediate_size * 2, bias=False
62
+ )
63
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
64
+ self.act_fn = ACT2FN[hidden_act]
65
+
66
+ def forward(
67
+ self,
68
+ x: torch.Tensor,
69
+ **kwargs: Unpack[Dict],
70
+ ) -> torch.Tensor:
71
+ if self.norm_first:
72
+ x = rms_norm_linear(
73
+ x,
74
+ self.norm.weight,
75
+ self.norm.bias,
76
+ self.gate_proj.weight,
77
+ self.gate_proj.bias,
78
+ )
79
+ else:
80
+ x = self.gate_proj(x)
81
+ gate, y = x.chunk(2, -1)
82
+ return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
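
For concreteness, the rounding rule in `__init__` works out as follows at the config defaults `hidden_size=2048`, `hidden_ratio=4` (a worked example only):

hidden_size, hidden_ratio = 2048, 4
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)       # 5461
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)  # rounded up to the next multiple of 256 -> 5632
# gate_proj then maps 2048 -> 2 * 5632 (gate and value halves), down_proj maps 5632 -> 2048
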
83
+
84
+
85
+ class GatedDeltaProductBlock(nn.Module):
86
+ def __init__(self, config: GatedDeltaProductConfig, layer_idx: int):
87
+ super().__init__()
88
+ self.hidden_size = config.hidden_size
89
+
90
+ if not config.norm_first:
91
+ self.attn_norm = RMSNorm(
92
+ hidden_size=config.hidden_size, eps=config.norm_eps
93
+ )
94
+ if config.attn is not None and layer_idx in config.attn["layers"]:
95
+ self.attn = Attention(
96
+ hidden_size=config.hidden_size,
97
+ num_heads=config.attn["num_heads"],
98
+ num_kv_heads=config.attn["num_kv_heads"],
99
+ window_size=config.attn["window_size"],
100
+ max_position_embeddings=config.max_position_embeddings,
101
+ layer_idx=layer_idx,
102
+ )
103
+ else:
104
+ self.attn = GatedDeltaProduct(
105
+ mode=config.attn_mode,
106
+ hidden_size=config.hidden_size,
107
+ expand_v=config.expand_v,
108
+ head_dim=config.head_dim,
109
+ num_heads=config.num_heads,
110
+ use_gate=config.use_gate,
111
+ use_forget_gate=config.use_forget_gate,
112
+ use_short_conv=config.use_short_conv,
113
+ conv_size=config.conv_size,
114
+ norm_first=config.norm_first,
115
+ norm_eps=config.norm_eps,
116
+ allow_neg_eigval=config.allow_neg_eigval,
117
+ num_householder=config.num_householder,
118
+ layer_idx=layer_idx,
119
+ use_beta_conv=config.use_beta_conv
120
+ )
121
+ if not config.norm_first:
122
+ self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
123
+ self.mlp = GatedDeltaNetMLP(
124
+ hidden_size=config.hidden_size,
125
+ hidden_ratio=config.hidden_ratio,
126
+ intermediate_size=config.intermediate_size,
127
+ hidden_act=config.hidden_act,
128
+ norm_first=config.norm_first,
129
+ norm_eps=config.norm_eps,
130
+ )
131
+
132
+ def forward(
133
+ self,
134
+ hidden_states: torch.Tensor,
135
+ attention_mask: Optional[torch.Tensor] = None,
136
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
137
+ use_cache: Optional[bool] = False,
138
+ output_attentions: Optional[bool] = False,
139
+ **kwargs: Unpack[Dict],
140
+ ) -> Tuple[
141
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
142
+ ]:
143
+ residual = hidden_states
144
+ if hasattr(self, "attn_norm"):
145
+ hidden_states = self.attn_norm(hidden_states)
146
+ hidden_states, attentions, past_key_values = self.attn(
147
+ hidden_states=hidden_states,
148
+ attention_mask=attention_mask,
149
+ past_key_values=past_key_values,
150
+ use_cache=use_cache,
151
+ output_attentions=output_attentions,
152
+ **kwargs,
153
+ )
154
+ if hasattr(self, "mlp_norm"):
155
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
156
+ else:
157
+ hidden_states = residual + hidden_states
158
+ residual = hidden_states
159
+ hidden_states = self.mlp(hidden_states, **kwargs)
160
+ hidden_states = residual + hidden_states
161
+
162
+ outputs = (hidden_states, attentions, past_key_values)
163
+
164
+ return outputs
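
When `norm_first` is disabled, the block takes the fused branch `self.mlp_norm(hidden_states, residual, True)` above. A sketch of the unfused equivalent, assuming fla's fused `RMSNorm` follows the usual prenorm contract of returning both the normalized activations and the updated residual:

# at this point `residual` still holds the value that entered the attention sub-layer
residual = residual + hidden_states      # fold the attention output into the running residual
hidden_states = self.mlp_norm(residual)  # normalize the updated residual before the MLP
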
165
+
166
+
167
+ class GatedDeltaProductPreTrainedModel(PreTrainedModel):
168
+ config_class = GatedDeltaProductConfig
169
+ supports_gradient_checkpointing = True
170
+ _no_split_modules = ["GatedDeltaProductBlock"]
171
+
172
+ def __init__(self, *inputs, **kwargs):
173
+ super().__init__(*inputs, **kwargs)
174
+
175
+ def _init_weights(
176
+ self,
177
+ module: nn.Module,
178
+ rescale_prenorm_residual: bool = True,
179
+ num_residuals_per_layer: int = 2,
180
+ ):
181
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
182
+ # Slightly different from the TF version which uses truncated_normal for initialization
183
+ # cf https://github.com/pytorch/pytorch/pull/5617
184
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
185
+ if module.bias is not None:
186
+ nn.init.zeros_(module.bias)
187
+ elif isinstance(module, nn.Embedding):
188
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
189
+ if module.padding_idx is not None:
190
+ module.weight.data[module.padding_idx].zero_()
191
+
192
+ if rescale_prenorm_residual:
193
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
194
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
195
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
196
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
197
+ #
198
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
199
+ for name, p in module.named_parameters():
200
+ if name in ["o_proj.weight", "down_proj.weight"]:
201
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
202
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
203
+ # We need to reinit p since this code could be called multiple times
204
+ # Having just p *= scale would repeatedly scale it down
205
+ with torch.no_grad():
206
+ p /= math.sqrt(
207
+ num_residuals_per_layer * self.config.num_hidden_layers
208
+ )
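
With the config default of 21 hidden layers and two residual branches per layer, the rescaling above divides `o_proj.weight` and `down_proj.weight` by √(2 · 21) ≈ 6.48 after the normal initialization:

import math

num_residuals_per_layer, num_hidden_layers = 2, 21  # defaults used in this file and the config
print(math.sqrt(num_residuals_per_layer * num_hidden_layers))  # ≈ 6.4807
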
209
+
210
+
211
+ class GatedDeltaProductModel(GatedDeltaProductPreTrainedModel):
212
+ def __init__(self, config: GatedDeltaProductConfig):
213
+ super().__init__(config)
214
+ self.padding_idx = config.pad_token_id
215
+ self.vocab_size = config.vocab_size
216
+
217
+ self.embeddings = nn.Embedding(
218
+ config.vocab_size, config.hidden_size, self.padding_idx
219
+ )
220
+ self.layers = nn.ModuleList(
221
+ [
222
+ GatedDeltaProductBlock(config, layer_idx)
223
+ for layer_idx in range(config.num_hidden_layers)
224
+ ]
225
+ )
226
+ self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
227
+
228
+ self.gradient_checkpointing = False
229
+
230
+ self.post_init()
231
+
232
+ def get_input_embeddings(self):
233
+ return self.embeddings
234
+
235
+ def set_input_embeddings(self, value):
236
+ self.embeddings = value
237
+
238
+ def forward(
239
+ self,
240
+ input_ids: Optional[torch.LongTensor] = None,
241
+ attention_mask: Optional[torch.Tensor] = None,
242
+ inputs_embeds: Optional[torch.FloatTensor] = None,
243
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
244
+ use_cache: Optional[bool] = None,
245
+ output_attentions: Optional[bool] = None,
246
+ output_hidden_states: Optional[bool] = None,
247
+ return_dict: Optional[bool] = None,
248
+ **kwargs: Unpack[Dict],
249
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
250
+ if output_attentions:
251
+ warnings.warn(
252
+ "`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.",
253
+ stacklevel=2,
254
+ )
255
+ output_attentions = False
256
+ output_attentions = (
257
+ output_attentions
258
+ if output_attentions is not None
259
+ else self.config.output_attentions
260
+ )
261
+ output_hidden_states = (
262
+ output_hidden_states
263
+ if output_hidden_states is not None
264
+ else self.config.output_hidden_states
265
+ )
266
+ use_cache = (
267
+ use_cache
268
+ if use_cache is not None
269
+ else (self.config.use_cache if not self.training else False)
270
+ )
271
+ return_dict = (
272
+ return_dict if return_dict is not None else self.config.use_return_dict
273
+ )
274
+
275
+ # retrieve input_ids and inputs_embeds
276
+ if input_ids is not None and inputs_embeds is not None:
277
+ raise ValueError(
278
+ "You cannot specify both input_ids and inputs_embeds at the same time"
279
+ )
280
+ if input_ids is None and inputs_embeds is None:
281
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
282
+
283
+ if inputs_embeds is None:
284
+ inputs_embeds = self.embeddings(input_ids)
285
+ hidden_states = inputs_embeds
286
+
287
+ if use_cache and not isinstance(past_key_values, Cache):
288
+ past_key_values = Cache.from_legacy_cache(past_key_values)
289
+
290
+ if self.gradient_checkpointing and self.training and use_cache:
291
+ logger.warning_once(
292
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
293
+ )
294
+ use_cache = False
295
+
296
+ all_hidden_states = () if output_hidden_states else None
297
+ all_attns = () if output_attentions else None
298
+ for layer in self.layers:
299
+ if output_hidden_states:
300
+ all_hidden_states += (hidden_states,)
301
+
302
+ if self.gradient_checkpointing and self.training:
303
+ hidden_states, attentions, past_key_values = (
304
+ self._gradient_checkpointing_func(
305
+ layer.__call__,
306
+ hidden_states,
307
+ attention_mask,
308
+ past_key_values,
309
+ use_cache,
310
+ output_attentions,
311
+ **kwargs,
312
+ )
313
+ )
314
+ else:
315
+ hidden_states, attentions, past_key_values = layer(
316
+ hidden_states,
317
+ attention_mask=attention_mask,
318
+ past_key_values=past_key_values,
319
+ use_cache=use_cache,
320
+ output_attentions=output_attentions,
321
+ **kwargs,
322
+ )
323
+
324
+ if output_attentions:
325
+ all_attns += (attentions,)
326
+
327
+ hidden_states = self.norm(hidden_states)
328
+ # add hidden states from the last decoder layer
329
+ if output_hidden_states:
330
+ all_hidden_states += (hidden_states,)
331
+
332
+ if not return_dict:
333
+ return tuple(
334
+ i
335
+ for i in [
336
+ hidden_states,
337
+ past_key_values,
338
+ all_hidden_states,
339
+ all_attns,
340
+ ]
341
+ if i is not None
342
+ )
343
+ return BaseModelOutputWithPast(
344
+ last_hidden_state=hidden_states,
345
+ past_key_values=past_key_values,
346
+ hidden_states=all_hidden_states,
347
+ attentions=all_attns,
348
+ )
349
+
350
+
351
+ class GatedDeltaProductForCausalLM(GatedDeltaProductPreTrainedModel, GenerationMixin):
352
+ _tied_weights_keys = ["lm_head.weight"]
353
+
354
+ def __init__(self, config):
355
+ super().__init__(config)
356
+ self.model = GatedDeltaProductModel(config)
357
+ self.vocab_size = config.vocab_size
358
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
359
+
360
+ # Initialize weights and apply final processing
361
+ self.post_init()
362
+
363
+ def get_input_embeddings(self):
364
+ return self.model.embeddings
365
+
366
+ def set_input_embeddings(self, value):
367
+ self.model.embeddings = value
368
+
369
+ def get_output_embeddings(self):
370
+ return self.lm_head
371
+
372
+ def set_output_embeddings(self, new_embeddings):
373
+ self.lm_head = new_embeddings
374
+
375
+ def set_decoder(self, decoder):
376
+ self.model = decoder
377
+
378
+ def get_decoder(self):
379
+ return self.model
380
+
381
+ def generate(self, *args, **kwargs):
382
+ try:
383
+ return super().generate(*args, **kwargs)
384
+ except AttributeError as exception:
385
+ if "past_key_values" in str(exception):
386
+ raise AttributeError(
387
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
388
+ f"which is not supported for {self.__class__.__name__}. "
389
+ f"Try another generation strategy instead. "
390
+ f"For the available generation strategies, check this doc: "
391
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
392
+ )
393
+ else:
394
+ raise exception
395
+
396
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
397
+ def prepare_inputs_for_generation(
398
+ self,
399
+ input_ids: torch.LongTensor = None,
400
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
401
+ attention_mask: Optional[torch.Tensor] = None,
402
+ inputs_embeds: Optional[torch.Tensor] = None,
403
+ use_cache: bool = True,
404
+ num_logits_to_keep: Optional[int] = None,
405
+ logits_to_keep: Optional[int] = None,
406
+ **kwargs,
407
+ ):
408
+ # only keep the last token of `input_ids` if `past_key_values` is passed along and is not empty.
409
+ if past_key_values is not None and len(past_key_values) > 0:
410
+ input_ids = input_ids[:, -1:]
411
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
412
+ if inputs_embeds is not None and past_key_values is None:
413
+ model_inputs = {"inputs_embeds": inputs_embeds}
414
+ else:
415
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
416
+ # recompiles graphs as the stride of the inputs is a guard.
417
+ # Ref: https://github.com/huggingface/transformers/pull/29114
418
+ # TODO: use `next_tokens` directly instead.
419
+ model_inputs = {"input_ids": input_ids.contiguous()}
420
+
421
+ if logits_to_keep is not None:
422
+ model_inputs['logits_to_keep'] = logits_to_keep
423
+
424
+ model_inputs.update(
425
+ {
426
+ "past_key_values": past_key_values,
427
+ "use_cache": use_cache,
428
+ "attention_mask": attention_mask,
429
+ "num_logits_to_keep": num_logits_to_keep,
430
+ }
431
+ )
432
+ return model_inputs
433
+
434
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
435
+ def forward(
436
+ self,
437
+ input_ids: torch.LongTensor = None,
438
+ attention_mask: Optional[torch.Tensor] = None,
439
+ inputs_embeds: Optional[torch.Tensor] = None,
440
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
441
+ labels: Optional[torch.LongTensor] = None,
442
+ use_cache: Optional[bool] = None,
443
+ output_attentions: Optional[bool] = None,
444
+ output_hidden_states: Optional[bool] = None,
445
+ return_dict: Optional[bool] = None,
446
+ num_logits_to_keep: Optional[int] = 0,
447
+ logits_to_keep: Optional[int] = 0,
448
+ **kwargs: Unpack[Dict],
449
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
450
+ num_logits_to_keep = 0 if num_logits_to_keep is None else num_logits_to_keep
451
+ output_attentions = (
452
+ output_attentions
453
+ if output_attentions is not None
454
+ else self.config.output_attentions
455
+ )
456
+ output_hidden_states = (
457
+ output_hidden_states
458
+ if output_hidden_states is not None
459
+ else self.config.output_hidden_states
460
+ )
461
+ return_dict = (
462
+ return_dict if return_dict is not None else self.config.use_return_dict
463
+ )
464
+ kwargs.pop("num_items_in_batch", None)
465
+ outputs = self.model(
466
+ input_ids=input_ids,
467
+ attention_mask=attention_mask,
468
+ inputs_embeds=inputs_embeds,
469
+ past_key_values=past_key_values,
470
+ use_cache=use_cache,
471
+ output_attentions=output_attentions,
472
+ output_hidden_states=output_hidden_states,
473
+ return_dict=return_dict,
474
+ **kwargs,
475
+ )
476
+ hidden_states = outputs[0]
477
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
478
+
479
+ loss, logits = None, None
480
+ if not fuse_linear_and_cross_entropy or labels is None:
481
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
482
+ if labels is not None:
483
+ if self.config.fuse_cross_entropy:
484
+ if fuse_linear_and_cross_entropy:
485
+ loss_fct = FusedLinearCrossEntropyLoss()
486
+ else:
487
+ loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
488
+ else:
489
+ loss_fct = nn.CrossEntropyLoss()
490
+ # Enable model parallelism
491
+ labels = labels.to(hidden_states.device)
492
+ labels = torch.cat(
493
+ (
494
+ labels[..., 1:],
495
+ torch.full_like(labels[:, :1], loss_fct.ignore_index),
496
+ ),
497
+ 1,
498
+ )
499
+ if fuse_linear_and_cross_entropy:
500
+ loss = loss_fct(
501
+ hidden_states.view(-1, self.config.hidden_size),
502
+ labels.view(-1),
503
+ self.lm_head.weight,
504
+ self.lm_head.bias,
505
+ )
506
+ else:
507
+ loss = loss_fct(
508
+ logits.view(-1, self.config.vocab_size), labels.view(-1)
509
+ )
510
+
511
+ if not return_dict:
512
+ output = (logits,) + outputs[1:]
513
+ return (loss, *output) if loss is not None else output
514
+ return CausalLMOutputWithPast(
515
+ loss=loss,
516
+ logits=logits,
517
+ past_key_values=outputs.past_key_values,
518
+ hidden_states=outputs.hidden_states,
519
+ attentions=outputs.attentions,
520
+ )
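
During training with `fuse_cross_entropy=True`, the branch above never materializes the `[batch, seq_len, vocab]` logits: `FusedLinearCrossEntropyLoss` consumes the hidden states and the `lm_head` weights directly. A plain-PyTorch reference for what that fused call computes (an unfused sketch assuming mean reduction over non-ignored tokens, not the kernel itself):

import torch
import torch.nn.functional as F

def reference_fused_linear_ce(hidden_states, labels, weight, bias=None, ignore_index=-100):
    # project to the vocabulary, then apply ordinary cross-entropy on the flattened tokens
    logits = F.linear(hidden_states, weight, bias)
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=ignore_index)
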
fla/models/gla/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gla.configuration_gla import GLAConfig
6
+ from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel
7
+
8
+ AutoConfig.register(GLAConfig.model_type, GLAConfig)
9
+ AutoModel.register(GLAConfig, GLAModel)
10
+ AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM)
11
+
12
+
13
+ __all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel']
fla/models/gla/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (714 Bytes).
 
fla/models/gsa/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (714 Bytes).
 
fla/models/gsa/__pycache__/configuration_gsa.cpython-311.pyc ADDED
Binary file (4.27 kB).
 
fla/models/gsa/configuration_gsa.py ADDED
@@ -0,0 +1,97 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class GSAConfig(PretrainedConfig):
9
+
10
+ model_type = 'gsa'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ hidden_size: int = 2048,
16
+ gate_logit_normalizer: Optional[int] = 8,
17
+ clamp_min: Optional[float] = None,
18
+ clamp_max: Optional[float] = None,
19
+ hidden_ratio: Optional[int] = 4,
20
+ intermediate_size: Optional[int] = None,
21
+ num_hidden_layers: int = 24,
22
+ num_heads: int = 4,
23
+ num_kv_heads: Optional[int] = None,
24
+ num_slots: Optional[int] = 64,
25
+ use_short_conv: bool = False,
26
+ conv_size: int = 4,
27
+ expand_k: float = 1,
28
+ expand_v: float = 1,
29
+ feature_map: str = 'swish',
30
+ use_output_gate: bool = False,
31
+ use_norm: bool = True,
32
+ max_position_embeddings: int = 2048,
33
+ hidden_act: str = "swish",
34
+ elementwise_affine: Optional[bool] = True,
35
+ norm_eps: float = 1e-6,
36
+ attn: Optional[Dict] = None,
37
+ use_cache: bool = True,
38
+ pad_token_id: int = None,
39
+ bos_token_id: int = 1,
40
+ eos_token_id: int = 2,
41
+ initializer_range: float = 0.006,
42
+ tie_word_embeddings: bool = False,
43
+ fuse_norm: bool = True,
44
+ fuse_swiglu: bool = True,
45
+ fuse_cross_entropy: bool = True,
46
+ vocab_size: int = 32000,
47
+ **kwargs
48
+ ):
49
+ self.hidden_size = hidden_size
50
+ self.gate_logit_normalizer = gate_logit_normalizer
51
+ self.clamp_min = clamp_min
52
+ self.clamp_max = clamp_max
53
+ self.hidden_ratio = hidden_ratio
54
+ self.intermediate_size = intermediate_size
55
+ self.num_hidden_layers = num_hidden_layers
56
+ self.num_heads = num_heads
57
+ self.num_kv_heads = num_kv_heads
58
+ self.num_slots = num_slots
59
+ self.use_short_conv = use_short_conv
60
+ self.conv_size = conv_size
61
+ self.expand_k = expand_k
62
+ self.expand_v = expand_v
63
+ self.feature_map = feature_map
64
+ self.use_output_gate = use_output_gate
65
+ self.use_norm = use_norm
66
+ self.max_position_embeddings = max_position_embeddings
67
+ self.hidden_act = hidden_act
68
+ self.elementwise_affine = elementwise_affine
69
+ self.norm_eps = norm_eps
70
+ self.attn = attn
71
+ self.use_cache = use_cache
72
+ self.initializer_range = initializer_range
73
+
74
+ self.fuse_norm = fuse_norm
75
+ self.fuse_swiglu = fuse_swiglu
76
+ self.fuse_cross_entropy = fuse_cross_entropy
77
+ self.vocab_size = vocab_size
78
+
79
+ if attn is not None:
80
+ if not isinstance(attn, Dict):
81
+ raise ValueError("attn must be a dictionary")
82
+ if 'layers' not in attn:
83
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
84
+ if 'num_heads' not in attn:
85
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
86
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
87
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
88
+ attn['window_size'] = attn.get('window_size', None)
89
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
90
+
91
+ super().__init__(
92
+ pad_token_id=pad_token_id,
93
+ bos_token_id=bos_token_id,
94
+ eos_token_id=eos_token_id,
95
+ tie_word_embeddings=tie_word_embeddings,
96
+ **kwargs,
97
+ )
fla/models/gsa/modeling_gsa.py ADDED
@@ -0,0 +1,420 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gsa import GatedSlotAttention
20
+ from fla.models.gsa.configuration_gsa import GSAConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GSAMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class GSABlock(nn.Module):
33
+ def __init__(self, config: GSAConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = GatedSlotAttention(
53
+ hidden_size=config.hidden_size,
54
+ expand_k=config.expand_k,
55
+ expand_v=config.expand_v,
56
+ num_heads=config.num_heads,
57
+ num_kv_heads=config.num_kv_heads,
58
+ num_slots=config.num_slots,
59
+ use_short_conv=config.use_short_conv,
60
+ conv_size=config.conv_size,
61
+ feature_map=config.feature_map,
62
+ use_output_gate=config.use_output_gate,
63
+ use_norm=config.use_norm,
64
+ gate_fn=config.hidden_act,
65
+ gate_logit_normalizer=config.gate_logit_normalizer,
66
+ elementwise_affine=config.elementwise_affine,
67
+ norm_eps=config.norm_eps,
68
+ fuse_norm=config.fuse_norm,
69
+ layer_idx=layer_idx
70
+ )
71
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
72
+ self.mlp = GSAMLP(
73
+ hidden_size=config.hidden_size,
74
+ hidden_ratio=config.hidden_ratio,
75
+ intermediate_size=config.intermediate_size,
76
+ hidden_act=config.hidden_act,
77
+ fuse_swiglu=config.fuse_swiglu
78
+ )
79
+
80
+ def forward(
81
+ self,
82
+ hidden_states: torch.Tensor,
83
+ attention_mask: Optional[torch.Tensor] = None,
84
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
85
+ use_cache: Optional[bool] = False,
86
+ output_attentions: Optional[bool] = False,
87
+ **kwargs: Unpack[Dict]
88
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
89
+ residual = hidden_states
90
+ hidden_states = self.attn_norm(hidden_states)
91
+ hidden_states, attentions, past_key_values = self.attn(
92
+ hidden_states=hidden_states,
93
+ attention_mask=attention_mask,
94
+ past_key_values=past_key_values,
95
+ use_cache=use_cache,
96
+ output_attentions=output_attentions,
97
+ **kwargs
98
+ )
99
+ if self.config.fuse_norm:
100
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
101
+ else:
102
+ hidden_states = residual + hidden_states
103
+ residual = hidden_states
104
+ hidden_states = self.mlp_norm(hidden_states)
105
+ hidden_states = self.mlp(hidden_states, **kwargs)
106
+ hidden_states = residual + hidden_states
107
+
108
+ outputs = (hidden_states, attentions, past_key_values)
109
+
110
+ return outputs
111
+
112
+
113
+ class GSAPreTrainedModel(PreTrainedModel):
114
+
115
+ config_class = GSAConfig
116
+ base_model_prefix = 'model'
117
+ supports_gradient_checkpointing = True
118
+ _no_split_modules = ['GSABlock']
119
+ _supports_cache_class = True
120
+
121
+ def __init__(self, *inputs, **kwargs):
122
+ super().__init__(*inputs, **kwargs)
123
+
124
+ def _init_weights(
125
+ self,
126
+ module: nn.Module,
127
+ prenorm_residual_strategy: Optional[str] = 'rescale',
128
+ num_residuals_per_layer: int = 2,
129
+ ):
130
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
131
+ # Slightly different from the TF version which uses truncated_normal for initialization
132
+ # cf https://github.com/pytorch/pytorch/pull/5617
133
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
134
+ if module.bias is not None:
135
+ nn.init.zeros_(module.bias)
136
+ elif isinstance(module, nn.Embedding):
137
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
138
+ elif hasattr(module, 'reset_parameters'):
139
+ module.reset_parameters()
140
+
141
+ if prenorm_residual_strategy is not None:
142
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
143
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
144
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
145
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
146
+ #
147
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
148
+ p = None
149
+ if hasattr(module, 'o_proj'):
150
+ p = module.o_proj.weight
151
+ elif hasattr(module, 'down_proj'):
152
+ p = module.down_proj.weight
153
+ if p is not None:
154
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
155
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
156
+ # We need to reinit p since this code could be called multiple times
157
+ # Having just p *= scale would repeatedly scale it down
158
+ if prenorm_residual_strategy == 'rescale':
159
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
160
+ with torch.no_grad():
161
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
162
+ elif prenorm_residual_strategy == 'zero':
163
+ nn.init.zeros_(p)
164
+ else:
165
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
166
+
167
+
168
+ class GSAModel(GSAPreTrainedModel):
169
+
170
+ def __init__(self, config: GSAConfig):
171
+ super().__init__(config)
172
+ self.padding_idx = config.pad_token_id
173
+ self.vocab_size = config.vocab_size
174
+
175
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
176
+ self.layers = nn.ModuleList([GSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
177
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
178
+
179
+ self.gradient_checkpointing = False
180
+
181
+ self.post_init()
182
+
183
+ def get_input_embeddings(self):
184
+ return self.embeddings
185
+
186
+ def set_input_embeddings(self, value):
187
+ self.embeddings = value
188
+
189
+ def forward(
190
+ self,
191
+ input_ids: Optional[torch.LongTensor] = None,
192
+ attention_mask: Optional[torch.Tensor] = None, # noqa
193
+ inputs_embeds: Optional[torch.FloatTensor] = None,
194
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
195
+ use_cache: Optional[bool] = None,
196
+ output_attentions: Optional[bool] = None,
197
+ output_hidden_states: Optional[bool] = None,
198
+ return_dict: Optional[bool] = None,
199
+ **kwargs: Unpack[Dict]
200
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
201
+ if output_attentions:
202
+ warnings.warn("`GSAModel` does not support `output_attentions` for now, setting it to `False`.")
203
+ output_attentions = False
204
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
205
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
206
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
207
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
208
+
209
+ # retrieve input_ids and inputs_embeds
210
+ if input_ids is not None and inputs_embeds is not None:
211
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
212
+ if input_ids is None and inputs_embeds is None:
213
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
214
+
215
+ if inputs_embeds is None:
216
+ inputs_embeds = self.embeddings(input_ids)
217
+ hidden_states = inputs_embeds
218
+
219
+ if use_cache and not isinstance(past_key_values, Cache):
220
+ past_key_values = Cache.from_legacy_cache(past_key_values)
221
+
222
+ if self.gradient_checkpointing and self.training and use_cache:
223
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
224
+ use_cache = False
225
+
226
+ all_hidden_states = () if output_hidden_states else None
227
+ all_attns = () if output_attentions else None
228
+ for layer in self.layers:
229
+ if output_hidden_states:
230
+ all_hidden_states += (hidden_states,)
231
+
232
+ if self.gradient_checkpointing and self.training:
233
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
234
+ layer.__call__,
235
+ hidden_states,
236
+ attention_mask,
237
+ past_key_values,
238
+ use_cache,
239
+ output_attentions,
240
+ **kwargs
241
+ )
242
+ else:
243
+ hidden_states, attentions, past_key_values = layer(
244
+ hidden_states,
245
+ attention_mask=attention_mask,
246
+ past_key_values=past_key_values,
247
+ use_cache=use_cache,
248
+ output_attentions=output_attentions,
249
+ **kwargs
250
+ )
251
+
252
+ if output_attentions:
253
+ all_attns += (attentions,)
254
+
255
+ hidden_states = self.norm(hidden_states)
256
+
257
+ # add hidden states from the last decoder layer
258
+ if output_hidden_states:
259
+ all_hidden_states += (hidden_states,)
260
+
261
+ if not return_dict:
262
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
263
+ return BaseModelOutputWithPast(
264
+ last_hidden_state=hidden_states,
265
+ past_key_values=past_key_values,
266
+ hidden_states=all_hidden_states,
267
+ attentions=all_attns
268
+ )
269
+
270
+
271
+ class GSAForCausalLM(GSAPreTrainedModel, GenerationMixin):
272
+
273
+ _tied_weights_keys = ["lm_head.weight"]
274
+
275
+ def __init__(self, config):
276
+
277
+ super().__init__(config)
278
+ self.model = GSAModel(config)
279
+ self.vocab_size = config.vocab_size
280
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
281
+ self.criterion = None
282
+
283
+ # Initialize weights and apply final processing
284
+ self.post_init()
285
+
286
+ def get_input_embeddings(self):
287
+ return self.model.embeddings
288
+
289
+ def set_input_embeddings(self, value):
290
+ self.model.embeddings = value
291
+
292
+ def get_output_embeddings(self):
293
+ return self.lm_head
294
+
295
+ def set_output_embeddings(self, new_embeddings):
296
+ self.lm_head = new_embeddings
297
+
298
+ def set_decoder(self, decoder):
299
+ self.model = decoder
300
+
301
+ def get_decoder(self):
302
+ return self.model
303
+
304
+ def generate(self, *args, **kwargs):
305
+ try:
306
+ return super().generate(*args, **kwargs)
307
+ except AttributeError as exception:
308
+ if 'past_key_values' in str(exception):
309
+ raise AttributeError(
310
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
311
+ f"which is not supported for {self.__class__.__name__}. "
312
+ f"Try another generation strategy instead. "
313
+ f"For the available generation strategies, check this doc: "
314
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
315
+ )
316
+ else:
317
+ raise exception
318
+
319
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
320
+ def prepare_inputs_for_generation(
321
+ self,
322
+ input_ids: torch.LongTensor = None,
323
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
324
+ attention_mask: Optional[torch.Tensor] = None,
325
+ inputs_embeds: Optional[torch.Tensor] = None,
326
+ use_cache: bool = True,
327
+ logits_to_keep: Optional[int] = None,
328
+ **kwargs
329
+ ):
330
+ # only keep the last token of `input_ids` if `past_key_values` is not empty.
331
+ if past_key_values is not None and len(past_key_values) > 0:
332
+ input_ids = input_ids[:, -1:]
333
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
334
+ if inputs_embeds is not None and len(past_key_values) == 0:
335
+ model_inputs = {'inputs_embeds': inputs_embeds}
336
+ else:
337
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
338
+ # recompiles graphs as the stride of the inputs is a guard.
339
+ # Ref: https://github.com/huggingface/transformers/pull/29114
340
+ # TODO: use `next_tokens` directly instead.
341
+ model_inputs = {'input_ids': input_ids.contiguous()}
342
+
343
+ if logits_to_keep is not None:
344
+ model_inputs['logits_to_keep'] = logits_to_keep
345
+
346
+ model_inputs.update({
347
+ 'past_key_values': past_key_values,
348
+ 'use_cache': use_cache,
349
+ 'attention_mask': attention_mask,
350
+ })
351
+ return model_inputs
352
+
353
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
354
+ def forward(
355
+ self,
356
+ input_ids: torch.LongTensor = None,
357
+ attention_mask: Optional[torch.Tensor] = None,
358
+ inputs_embeds: Optional[torch.Tensor] = None,
359
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
360
+ labels: Optional[torch.LongTensor] = None,
361
+ use_cache: Optional[bool] = None,
362
+ output_attentions: Optional[bool] = None,
363
+ output_hidden_states: Optional[bool] = None,
364
+ return_dict: Optional[bool] = None,
365
+ logits_to_keep: Optional[int] = 0,
366
+ **kwargs: Unpack[Dict]
367
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
368
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
369
+ output_hidden_states = (
370
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
371
+ )
372
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
373
+
374
+ outputs = self.model(
375
+ input_ids=input_ids,
376
+ attention_mask=attention_mask,
377
+ inputs_embeds=inputs_embeds,
378
+ past_key_values=past_key_values,
379
+ use_cache=use_cache,
380
+ output_attentions=output_attentions,
381
+ output_hidden_states=output_hidden_states,
382
+ return_dict=return_dict,
383
+ **kwargs
384
+ )
385
+
386
+ hidden_states = outputs[0]
387
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
388
+
389
+ loss, logits = None, None
390
+ if not fuse_linear_and_cross_entropy or labels is None:
391
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
392
+ if labels is not None:
393
+ if getattr(self, 'criterion', None) is None:
394
+ if fuse_linear_and_cross_entropy:
395
+ criterion = FusedLinearCrossEntropyLoss()
396
+ elif self.config.fuse_cross_entropy:
397
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
398
+ else:
399
+ criterion = nn.CrossEntropyLoss()
400
+ else:
401
+ criterion = self.criterion
402
+ # Enable model parallelism
403
+ labels = labels.to(hidden_states.device)
404
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
405
+ if fuse_linear_and_cross_entropy:
406
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
407
+ else:
408
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
409
+
410
+ if not return_dict:
411
+ output = (logits,) + outputs[1:]
412
+ return (loss,) + output if loss is not None else output
413
+
414
+ return CausalLMOutputWithPast(
415
+ loss=loss,
416
+ logits=logits,
417
+ past_key_values=outputs.past_key_values,
418
+ hidden_states=outputs.hidden_states,
419
+ attentions=outputs.attentions,
420
+ )
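
A smoke-test sketch for the classes above (sizes are illustrative, weights are random, and the fused Triton kernels assume a CUDA device):

import torch

from fla.models.gsa.configuration_gsa import GSAConfig
from fla.models.gsa.modeling_gsa import GSAForCausalLM

config = GSAConfig(hidden_size=256, num_hidden_layers=2, num_heads=4, num_slots=32, vocab_size=1000)
model = GSAForCausalLM(config).cuda().eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16), device='cuda')
with torch.no_grad():
    out = model(input_ids=input_ids, labels=input_ids)  # labels are shifted left internally
print(out.loss.item(), out.logits.shape)  # logits: [1, 16, 1000]
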
fla/models/hgrn/modeling_hgrn.py ADDED
@@ -0,0 +1,420 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.hgrn import HGRNAttention
20
+ from fla.models.hgrn.configuration_hgrn import HGRNConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as HGRNMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class HGRNBlock(nn.Module):
33
+ def __init__(self, config: HGRNConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = HGRNAttention(
53
+ mode=config.attn_mode,
54
+ hidden_size=config.hidden_size,
55
+ expand_ratio=config.expand_ratio,
56
+ use_short_conv=config.use_short_conv,
57
+ conv_size=config.conv_size,
58
+ elementwise_affine=config.elementwise_affine,
59
+ norm_eps=config.norm_eps,
60
+ layer_idx=layer_idx
61
+ )
62
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
63
+ self.mlp = HGRNMLP(
64
+ hidden_size=config.hidden_size,
65
+ hidden_ratio=config.hidden_ratio,
66
+ intermediate_size=config.intermediate_size,
67
+ hidden_act=config.hidden_act,
68
+ fuse_swiglu=config.fuse_swiglu
69
+ )
70
+
71
+ def forward(
72
+ self,
73
+ hidden_states: torch.Tensor,
74
+ attention_mask: Optional[torch.Tensor] = None,
75
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
76
+ use_cache: Optional[bool] = False,
77
+ output_attentions: Optional[bool] = False,
78
+ lower_bound: Optional[torch.Tensor] = None,
79
+ **kwargs: Unpack[Dict]
80
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
81
+ residual = hidden_states
82
+ hidden_states = self.attn_norm(hidden_states)
83
+ hidden_states, attentions, past_key_values = self.attn(
84
+ hidden_states=hidden_states,
85
+ attention_mask=attention_mask,
86
+ past_key_values=past_key_values,
87
+ use_cache=use_cache,
88
+ output_attentions=output_attentions,
89
+ lower_bound=lower_bound,
90
+ **kwargs
91
+ )
92
+ if self.config.fuse_norm:
93
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
94
+ else:
95
+ hidden_states = residual + hidden_states
96
+ residual = hidden_states
97
+ hidden_states = self.mlp_norm(hidden_states)
98
+ hidden_states = self.mlp(hidden_states, **kwargs)
99
+ hidden_states = residual + hidden_states
100
+
101
+ outputs = (hidden_states, attentions, past_key_values)
102
+
103
+ return outputs
104
+
105
+
106
+ class HGRNPreTrainedModel(PreTrainedModel):
107
+
108
+ config_class = HGRNConfig
109
+ base_model_prefix = 'model'
110
+ supports_gradient_checkpointing = True
111
+ _no_split_modules = ['HGRNBlock']
112
+ _supports_cache_class = True
113
+
114
+ def __init__(self, *inputs, **kwargs):
115
+ super().__init__(*inputs, **kwargs)
116
+
117
+ def _init_weights(
118
+ self,
119
+ module: nn.Module,
120
+ prenorm_residual_strategy: Optional[str] = 'rescale',
121
+ num_residuals_per_layer: int = 2,
122
+ ):
123
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
124
+ # Slightly different from the TF version which uses truncated_normal for initialization
125
+ # cf https://github.com/pytorch/pytorch/pull/5617
126
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
127
+ if module.bias is not None:
128
+ nn.init.zeros_(module.bias)
129
+ elif isinstance(module, nn.Embedding):
130
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
131
+ elif hasattr(module, 'reset_parameters'):
132
+ module.reset_parameters()
133
+
134
+ if prenorm_residual_strategy is not None:
135
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
136
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
137
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
138
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
139
+ #
140
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
141
+ p = None
142
+ if hasattr(module, 'o_proj'):
143
+ p = module.o_proj.weight
144
+ elif hasattr(module, 'down_proj'):
145
+ p = module.down_proj.weight
146
+ if p is not None:
147
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
148
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
149
+ # We need to reinit p since this code could be called multiple times
150
+ # Having just p *= scale would repeatedly scale it down
151
+ if prenorm_residual_strategy == 'rescale':
152
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
153
+ with torch.no_grad():
154
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
155
+ elif prenorm_residual_strategy == 'zero':
156
+ nn.init.zeros_(p)
157
+ else:
158
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
159
+
160
+
161
+ class HGRNModel(HGRNPreTrainedModel):
162
+
163
+ def __init__(self, config: HGRNConfig):
164
+ super().__init__(config)
165
+ self.padding_idx = config.pad_token_id
166
+ self.vocab_size = config.vocab_size
167
+
168
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
169
+ if config.use_lower_bound:
170
+ self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
171
+ self.layers = nn.ModuleList([HGRNBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
172
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
173
+
174
+ self.gradient_checkpointing = False
175
+
176
+ self.post_init()
177
+
178
+ def get_input_embeddings(self):
179
+ return self.embeddings
180
+
181
+ def set_input_embeddings(self, value):
182
+ self.embeddings = value
183
+
184
+ def forward(
185
+ self,
186
+ input_ids: Optional[torch.LongTensor] = None,
187
+ attention_mask: Optional[torch.Tensor] = None, # noqa
188
+ inputs_embeds: Optional[torch.FloatTensor] = None,
189
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
190
+ use_cache: Optional[bool] = None,
191
+ output_attentions: Optional[bool] = None,
192
+ output_hidden_states: Optional[bool] = None,
193
+ return_dict: Optional[bool] = None,
194
+ **kwargs: Unpack[Dict]
195
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
196
+ if output_attentions:
197
+ warnings.warn("`HGRNModel` does not support `output_attentions` for now, setting it to `False`.")
198
+ output_attentions = False
199
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
200
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
201
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
202
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
203
+
204
+ # retrieve input_ids and inputs_embeds
205
+ if input_ids is not None and inputs_embeds is not None:
206
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
207
+ if input_ids is None and inputs_embeds is None:
208
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
209
+
210
+ if inputs_embeds is None:
211
+ inputs_embeds = self.embeddings(input_ids)
212
+ hidden_states = inputs_embeds
213
+
214
+ if use_cache and not isinstance(past_key_values, Cache):
215
+ past_key_values = Cache.from_legacy_cache(past_key_values)
216
+
217
+ if self.gradient_checkpointing and self.training and use_cache:
218
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
219
+ use_cache = False
220
+
221
+ all_hidden_states = () if output_hidden_states else None
222
+ all_attns = () if output_attentions else None
223
+
224
+ if self.config.use_lower_bound:
225
+ lower_bounds = self.lower_bounds.softmax(0)
226
+ lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
227
+ for i, layer in enumerate(self.layers):
228
+ if output_hidden_states:
229
+ all_hidden_states += (hidden_states,)
230
+
231
+ lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
232
+ if self.gradient_checkpointing and self.training:
233
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
234
+ layer.__call__,
235
+ hidden_states,
236
+ attention_mask,
237
+ past_key_values,
238
+ use_cache,
239
+ output_attentions,
240
+ lower_bound,
241
+ **kwargs
242
+ )
243
+ else:
244
+ hidden_states, attentions, past_key_values = layer(
245
+ hidden_states,
246
+ attention_mask=attention_mask,
247
+ past_key_values=past_key_values,
248
+ use_cache=use_cache,
249
+ output_attentions=output_attentions,
250
+ lower_bound=lower_bound,
251
+ **kwargs
252
+ )
253
+
254
+ if output_attentions:
255
+ all_attns += (attentions,)
256
+
257
+ hidden_states = self.norm(hidden_states)
258
+
259
+ # add hidden states from the last decoder layer
260
+ if output_hidden_states:
261
+ all_hidden_states += (hidden_states,)
262
+
263
+ if not return_dict:
264
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
265
+ return BaseModelOutputWithPast(
266
+ last_hidden_state=hidden_states,
267
+ past_key_values=past_key_values,
268
+ hidden_states=all_hidden_states,
269
+ attentions=all_attns
270
+ )
271
+
272
+
273
+ class HGRNForCausalLM(HGRNPreTrainedModel, GenerationMixin):
274
+
275
+ _tied_weights_keys = ["lm_head.weight"]
276
+
277
+ def __init__(self, config):
278
+ super().__init__(config)
279
+ self.model = HGRNModel(config)
280
+ self.vocab_size = config.vocab_size
281
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
282
+ self.criterion = None
283
+
284
+ # Initialize weights and apply final processing
285
+ self.post_init()
286
+
287
+ def get_input_embeddings(self):
288
+ return self.model.embeddings
289
+
290
+ def set_input_embeddings(self, value):
291
+ self.model.embeddings = value
292
+
293
+ def get_output_embeddings(self):
294
+ return self.lm_head
295
+
296
+ def set_output_embeddings(self, new_embeddings):
297
+ self.lm_head = new_embeddings
298
+
299
+ def set_decoder(self, decoder):
300
+ self.model = decoder
301
+
302
+ def get_decoder(self):
303
+ return self.model
304
+
305
+ def generate(self, *args, **kwargs):
306
+ try:
307
+ return super().generate(*args, **kwargs)
308
+ except AttributeError as exception:
309
+ if 'past_key_values' in str(exception):
310
+ raise AttributeError(
311
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
312
+ f"which is not supported for {self.__class__.__name__}. "
313
+ f"Try another generation strategy instead. "
314
+ f"For the available generation strategies, check this doc: "
315
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
316
+ )
317
+ else:
318
+ raise exception
319
+
320
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
321
+ def prepare_inputs_for_generation(
322
+ self,
323
+ input_ids: torch.LongTensor = None,
324
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
325
+ attention_mask: Optional[torch.Tensor] = None,
326
+ inputs_embeds: Optional[torch.Tensor] = None,
327
+ use_cache: bool = True,
328
+ logits_to_keep: Optional[int] = None,
329
+ **kwargs: Unpack[Dict]
330
+ ):
331
+ # only keep the last token of `input_ids` if `past_key_values` is not empty.
332
+ if past_key_values is not None and len(past_key_values) > 0:
333
+ input_ids = input_ids[:, -1:]
334
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
335
+ if inputs_embeds is not None and len(past_key_values) == 0:
336
+ model_inputs = {'inputs_embeds': inputs_embeds}
337
+ else:
338
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
339
+ # recompiles graphs as the stride of the inputs is a guard.
340
+ # Ref: https://github.com/huggingface/transformers/pull/29114
341
+ # TODO: use `next_tokens` directly instead.
342
+ model_inputs = {'input_ids': input_ids.contiguous()}
343
+
344
+ if logits_to_keep is not None:
345
+ model_inputs['logits_to_keep'] = logits_to_keep
346
+
347
+ model_inputs.update({
348
+ 'past_key_values': past_key_values,
349
+ 'use_cache': use_cache,
350
+ 'attention_mask': attention_mask,
351
+ })
352
+ return model_inputs
353
+
354
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
355
+ def forward(
356
+ self,
357
+ input_ids: torch.LongTensor = None,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ inputs_embeds: Optional[torch.Tensor] = None,
360
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
361
+ labels: Optional[torch.LongTensor] = None,
362
+ use_cache: Optional[bool] = None,
363
+ output_attentions: Optional[bool] = None,
364
+ output_hidden_states: Optional[bool] = None,
365
+ return_dict: Optional[bool] = None,
366
+ logits_to_keep: Optional[int] = 0,
367
+ **kwargs: Unpack[Dict]
368
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
369
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
370
+ output_hidden_states = (
371
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
372
+ )
373
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
374
+
375
+ outputs = self.model(
376
+ input_ids=input_ids,
377
+ attention_mask=attention_mask,
378
+ inputs_embeds=inputs_embeds,
379
+ past_key_values=past_key_values,
380
+ use_cache=use_cache,
381
+ output_attentions=output_attentions,
382
+ output_hidden_states=output_hidden_states,
383
+ return_dict=return_dict,
384
+ **kwargs
385
+ )
386
+
387
+ hidden_states = outputs[0]
388
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
389
+
390
+ loss, logits = None, None
391
+ if not fuse_linear_and_cross_entropy or labels is None:
392
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
393
+ if labels is not None:
394
+ if getattr(self, 'criterion', None) is None:
395
+ if fuse_linear_and_cross_entropy:
396
+ criterion = FusedLinearCrossEntropyLoss()
397
+ elif self.config.fuse_cross_entropy:
398
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
399
+ else:
400
+ criterion = nn.CrossEntropyLoss()
401
+ else:
402
+ criterion = self.criterion
403
+ labels = labels.to(hidden_states.device)
404
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
405
+ if fuse_linear_and_cross_entropy:
406
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
407
+ else:
408
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
409
+
410
+ if not return_dict:
411
+ output = (logits,) + outputs[1:]
412
+ return (loss,) + output if loss is not None else output
413
+
414
+ return CausalLMOutputWithPast(
415
+ loss=loss,
416
+ logits=logits,
417
+ past_key_values=outputs.past_key_values,
418
+ hidden_states=outputs.hidden_states,
419
+ attentions=outputs.attentions,
420
+ )
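In `HGRNModel.forward` above, the learnable `lower_bounds` parameter is turned into a per-layer schedule via a softmax over the layer dimension followed by a cumulative sum minus the first entry, giving the first layer a lower bound of 0 and deeper layers progressively larger bounds below 1. A minimal numeric sketch of that computation (sizes are hypothetical; the real parameter has shape [num_hidden_layers, hidden_size] and is zero-initialized):

import torch

num_layers, hidden_size = 4, 3
lower_bounds = torch.zeros(num_layers, hidden_size)  # zero init, as in HGRNModel.__init__

bounds = lower_bounds.softmax(0)          # each column sums to 1 across layers
bounds = bounds.cumsum(0) - bounds[0]     # layer 0 gets 0; deeper layers get larger bounds < 1

print(bounds[:, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500]) at initialization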
fla/models/linear_attn/configuration_linear_attn.py ADDED
@@ -0,0 +1,91 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict, Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class LinearAttentionConfig(PretrainedConfig):
9
+
10
+ model_type = 'linear_attn'
11
+ keys_to_ignore_at_inference = ['past_key_values']
12
+
13
+ def __init__(
14
+ self,
15
+ attn_mode: str = "fused_chunk",
16
+ hidden_size: int = 2048,
17
+ expand_k: int = 1,
18
+ expand_v: int = 1,
19
+ hidden_ratio: Optional[int] = 4,
20
+ intermediate_size: Optional[int] = None,
21
+ num_hidden_layers: int = 24,
22
+ num_heads: int = 4,
23
+ num_kv_heads: Optional[int] = None,
24
+ feature_map: str = "elementwise_product",
25
+ tie_feature_map_qk: bool = False,
26
+ norm_q: bool = False,
27
+ norm_k: bool = False,
28
+ norm_feature_map: bool = False,
29
+ hidden_act: str = "swish",
30
+ max_position_embeddings: int = 2048,
31
+ elementwise_affine: Optional[bool] = True,
32
+ norm_eps: float = 1e-6,
33
+ attn: Optional[Dict] = None,
34
+ use_cache: bool = True,
35
+ pad_token_id: Optional[int] = None,
36
+ bos_token_id: int = 1,
37
+ eos_token_id: int = 2,
38
+ tie_word_embeddings: bool = False,
39
+ initializer_range: float = 0.006,
40
+ fuse_norm: bool = True,
41
+ fuse_swiglu: bool = True,
42
+ fuse_cross_entropy: bool = True,
43
+ vocab_size: int = 32000,
44
+ **kwargs
45
+ ):
46
+ self.attn_mode = attn_mode
47
+ self.hidden_size = hidden_size
48
+ self.expand_k = expand_k
49
+ self.expand_v = expand_v
50
+ self.hidden_ratio = hidden_ratio
51
+ self.intermediate_size = intermediate_size
52
+ self.num_hidden_layers = num_hidden_layers
53
+ self.num_heads = num_heads
54
+ self.num_kv_heads = num_kv_heads
55
+ self.feature_map = feature_map
56
+ self.tie_feature_map_qk = tie_feature_map_qk
57
+ self.norm_q = norm_q
58
+ self.norm_k = norm_k
59
+ self.norm_feature_map = norm_feature_map
60
+ self.hidden_act = hidden_act
61
+ self.max_position_embeddings = max_position_embeddings
62
+ self.elementwise_affine = elementwise_affine
63
+ self.norm_eps = norm_eps
64
+ self.attn = attn
65
+ self.use_cache = use_cache
66
+ self.initializer_range = initializer_range
67
+
68
+ self.fuse_norm = fuse_norm
69
+ self.fuse_swiglu = fuse_swiglu
70
+ self.fuse_cross_entropy = fuse_cross_entropy
71
+ self.vocab_size = vocab_size
72
+
73
+ if attn is not None:
74
+ if not isinstance(attn, Dict):
75
+ raise ValueError("attn must be a dictionary")
76
+ if 'layers' not in attn:
77
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
78
+ if 'num_heads' not in attn:
79
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
80
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
81
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
82
+ attn['window_size'] = attn.get('window_size', None)
83
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
84
+
85
+ super().__init__(
86
+ pad_token_id=pad_token_id,
87
+ bos_token_id=bos_token_id,
88
+ eos_token_id=eos_token_id,
89
+ tie_word_embeddings=tie_word_embeddings,
90
+ **kwargs,
91
+ )
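A small usage sketch of the hybrid-attention `attn` dict that `LinearAttentionConfig.__init__` validates above: `layers` and `num_heads` are required, and the remaining keys are filled with defaults. The import path follows the file layout of this commit and is an assumption, as is the printed output:

# Use standard softmax attention in layers 2 and 5; linear attention elsewhere.
from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig

config = LinearAttentionConfig(
    hidden_size=1024,
    num_hidden_layers=24,
    attn={'layers': [2, 5], 'num_heads': 8},
)
print(config.attn)
# {'layers': [2, 5], 'num_heads': 8, 'num_kv_heads': 8, 'qkv_bias': False,
#  'window_size': None, 'rope_theta': 10000.0}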
fla/models/nsa/modeling_nsa.py ADDED
@@ -0,0 +1,398 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.nsa import NativeSparseAttention
19
+ from fla.models.nsa.configuration_nsa import NSAConfig
20
+ from fla.models.utils import Cache
21
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
22
+ from fla.modules import GatedMLP as NSAMLP
23
+ from fla.modules import RMSNorm
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class NSABlock(nn.Module):
32
+ def __init__(self, config: NSAConfig, layer_idx: int):
33
+ super().__init__()
34
+
35
+ self.config = config
36
+ self.layer_idx = layer_idx
37
+
38
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
39
+ self.attn = NativeSparseAttention(
40
+ hidden_size=config.hidden_size,
41
+ num_heads=config.num_heads,
42
+ num_kv_heads=config.num_kv_heads,
43
+ qkv_bias=config.qkv_bias,
44
+ block_size=config.block_size,
45
+ block_counts=config.block_counts,
46
+ window_size=config.window_size,
47
+ rope_theta=config.rope_theta,
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
52
+ self.mlp = NSAMLP(
53
+ hidden_size=config.hidden_size,
54
+ hidden_ratio=config.hidden_ratio,
55
+ intermediate_size=config.intermediate_size,
56
+ hidden_act=config.hidden_act,
57
+ fuse_swiglu=config.fuse_swiglu
58
+ )
59
+
60
+ def forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ attention_mask: Optional[torch.Tensor] = None,
64
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
65
+ use_cache: Optional[bool] = False,
66
+ output_attentions: Optional[bool] = False,
67
+ **kwargs: Unpack[Dict]
68
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
69
+ residual = hidden_states
70
+ hidden_states = self.attn_norm(hidden_states)
71
+ hidden_states, attentions, past_key_values = self.attn(
72
+ hidden_states=hidden_states,
73
+ attention_mask=attention_mask,
74
+ past_key_values=past_key_values,
75
+ use_cache=use_cache,
76
+ output_attentions=output_attentions,
77
+ **kwargs
78
+ )
79
+ if self.config.fuse_norm:
80
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
81
+ else:
82
+ hidden_states = residual + hidden_states
83
+ residual = hidden_states
84
+ hidden_states = self.mlp_norm(hidden_states)
85
+ hidden_states = self.mlp(hidden_states, **kwargs)
86
+ hidden_states = residual + hidden_states
87
+
88
+ outputs = (hidden_states, attentions, past_key_values)
89
+
90
+ return outputs
91
+
92
+
93
+ class NSAPreTrainedModel(PreTrainedModel):
94
+
95
+ config_class = NSAConfig
96
+ base_model_prefix = 'model'
97
+ supports_gradient_checkpointing = True
98
+ _no_split_modules = ['NSABlock']
99
+ _supports_cache_class = True
100
+
101
+ def __init__(self, *inputs, **kwargs):
102
+ super().__init__(*inputs, **kwargs)
103
+
104
+ def _init_weights(
105
+ self,
106
+ module: nn.Module,
107
+ prenorm_residual_strategy: Optional[str] = 'rescale',
108
+ num_residuals_per_layer: int = 2,
109
+ ):
110
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
111
+ # Slightly different from the TF version which uses truncated_normal for initialization
112
+ # cf https://github.com/pytorch/pytorch/pull/5617
113
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
114
+ if module.bias is not None:
115
+ nn.init.zeros_(module.bias)
116
+ elif isinstance(module, nn.Embedding):
117
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
118
+ elif hasattr(module, 'reset_parameters'):
119
+ module.reset_parameters()
120
+
121
+ if prenorm_residual_strategy is not None:
122
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
123
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
124
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
125
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
126
+ #
127
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
128
+ p = None
129
+ if hasattr(module, 'o_proj'):
130
+ p = module.o_proj.weight
131
+ elif hasattr(module, 'down_proj'):
132
+ p = module.down_proj.weight
133
+ if p is not None:
134
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
135
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
136
+ # We need to reinit p since this code could be called multiple times
137
+ # Having just p *= scale would repeatedly scale it down
138
+ if prenorm_residual_strategy == 'rescale':
139
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
140
+ with torch.no_grad():
141
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
142
+ elif prenorm_residual_strategy == 'zero':
143
+ nn.init.zeros_(p)
144
+ else:
145
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
146
+
147
+
148
+ class NSAModel(NSAPreTrainedModel):
149
+
150
+ def __init__(self, config: NSAConfig):
151
+ super().__init__(config)
152
+ self.padding_idx = config.pad_token_id
153
+ self.vocab_size = config.vocab_size
154
+
155
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
156
+ self.layers = nn.ModuleList([NSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
157
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
158
+
159
+ self.gradient_checkpointing = False
160
+
161
+ self.post_init()
162
+
163
+ def get_input_embeddings(self):
164
+ return self.embeddings
165
+
166
+ def set_input_embeddings(self, value):
167
+ self.embeddings = value
168
+
169
+ def forward(
170
+ self,
171
+ input_ids: Optional[torch.LongTensor] = None,
172
+ attention_mask: Optional[torch.Tensor] = None, # noqa
173
+ inputs_embeds: Optional[torch.FloatTensor] = None,
174
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
175
+ use_cache: Optional[bool] = None,
176
+ output_attentions: Optional[bool] = None,
177
+ output_hidden_states: Optional[bool] = None,
178
+ return_dict: Optional[bool] = None,
179
+ **kwargs: Unpack[Dict]
180
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
181
+ if output_attentions:
182
+ warnings.warn("`NSAModel` does not `output_attentions` now, setting it to `False`.")
183
+ output_attentions = False
184
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
185
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
186
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
187
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
188
+
189
+ # retrieve input_ids and inputs_embeds
190
+ if input_ids is not None and inputs_embeds is not None:
191
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
192
+ if input_ids is None and inputs_embeds is None:
193
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
194
+
195
+ if inputs_embeds is None:
196
+ inputs_embeds = self.embeddings(input_ids)
197
+ hidden_states = inputs_embeds
198
+
199
+ if use_cache and not isinstance(past_key_values, Cache):
200
+ past_key_values = Cache.from_legacy_cache(past_key_values)
201
+
202
+ if self.gradient_checkpointing and self.training and use_cache:
203
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
204
+ use_cache = False
205
+
206
+ all_hidden_states = () if output_hidden_states else None
207
+ all_attns = () if output_attentions else None
208
+ for layer in self.layers:
209
+ if output_hidden_states:
210
+ all_hidden_states += (hidden_states,)
211
+
212
+ if self.gradient_checkpointing and self.training:
213
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
214
+ layer.__call__,
215
+ hidden_states,
216
+ attention_mask,
217
+ past_key_values,
218
+ use_cache,
219
+ output_attentions,
220
+ **kwargs
221
+ )
222
+ else:
223
+ hidden_states, attentions, past_key_values = layer(
224
+ hidden_states,
225
+ attention_mask=attention_mask,
226
+ past_key_values=past_key_values,
227
+ use_cache=use_cache,
228
+ output_attentions=output_attentions,
229
+ **kwargs
230
+ )
231
+
232
+ if output_attentions:
233
+ all_attns += (attentions,)
234
+
235
+ hidden_states = self.norm(hidden_states)
236
+
237
+ # add hidden states from the last decoder layer
238
+ if output_hidden_states:
239
+ all_hidden_states += (hidden_states,)
240
+
241
+ if not return_dict:
242
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
243
+ return BaseModelOutputWithPast(
244
+ last_hidden_state=hidden_states,
245
+ past_key_values=past_key_values,
246
+ hidden_states=all_hidden_states,
247
+ attentions=all_attns
248
+ )
249
+
250
+
251
+ class NSAForCausalLM(NSAPreTrainedModel, GenerationMixin):
252
+
253
+ _tied_weights_keys = ["lm_head.weight"]
254
+
255
+ def __init__(self, config):
256
+ super().__init__(config)
257
+ self.model = NSAModel(config)
258
+ self.vocab_size = config.vocab_size
259
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
260
+ self.criterion = None
261
+
262
+ # Initialize weights and apply final processing
263
+ self.post_init()
264
+
265
+ def get_input_embeddings(self):
266
+ return self.model.embeddings
267
+
268
+ def set_input_embeddings(self, value):
269
+ self.model.embeddings = value
270
+
271
+ def get_output_embeddings(self):
272
+ return self.lm_head
273
+
274
+ def set_output_embeddings(self, new_embeddings):
275
+ self.lm_head = new_embeddings
276
+
277
+ def set_decoder(self, decoder):
278
+ self.model = decoder
279
+
280
+ def get_decoder(self):
281
+ return self.model
282
+
283
+ def generate(self, *args, **kwargs):
284
+ try:
285
+ return super().generate(*args, **kwargs)
286
+ except AttributeError as exception:
287
+ if 'past_key_values' in str(exception):
288
+ raise AttributeError(
289
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
290
+ f"which is not supported for {self.__class__.__name__}. "
291
+ f"Try another generation strategy instead. "
292
+ f"For the available generation strategies, check this doc: "
293
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
294
+ )
295
+ else:
296
+ raise exception
297
+
298
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
299
+ def prepare_inputs_for_generation(
300
+ self,
301
+ input_ids: torch.LongTensor = None,
302
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
303
+ attention_mask: Optional[torch.Tensor] = None,
304
+ inputs_embeds: Optional[torch.Tensor] = None,
305
+ use_cache: bool = True,
306
+ logits_to_keep: Optional[int] = None,
307
+ **kwargs
308
+ ):
309
+ # only last token for `inputs_ids` if the `past_key_values` is not empty.
310
+ if past_key_values is not None and len(past_key_values) > 0:
311
+ input_ids = input_ids[:, -1:]
312
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
313
+ if inputs_embeds is not None and len(past_key_values) == 0:
314
+ model_inputs = {'inputs_embeds': inputs_embeds}
315
+ else:
316
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
317
+ # recompiles graphs as the stride of the inputs is a guard.
318
+ # Ref: https://github.com/huggingface/transformers/pull/29114
319
+ # TODO: use `next_tokens` directly instead.
320
+ model_inputs = {'input_ids': input_ids.contiguous()}
321
+
322
+ if logits_to_keep is not None:
323
+ model_inputs['logits_to_keep'] = logits_to_keep
324
+
325
+ model_inputs.update({
326
+ 'past_key_values': past_key_values,
327
+ 'use_cache': use_cache,
328
+ 'attention_mask': attention_mask,
329
+ })
330
+ return model_inputs
331
+
332
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
333
+ def forward(
334
+ self,
335
+ input_ids: torch.LongTensor = None,
336
+ attention_mask: Optional[torch.Tensor] = None,
337
+ inputs_embeds: Optional[torch.Tensor] = None,
338
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
339
+ labels: Optional[torch.LongTensor] = None,
340
+ use_cache: Optional[bool] = None,
341
+ output_attentions: Optional[bool] = None,
342
+ output_hidden_states: Optional[bool] = None,
343
+ return_dict: Optional[bool] = None,
344
+ logits_to_keep: Optional[int] = 0,
345
+ **kwargs: Unpack[Dict]
346
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
347
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
348
+ output_hidden_states = (
349
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
350
+ )
351
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
352
+
353
+ outputs = self.model(
354
+ input_ids=input_ids,
355
+ attention_mask=attention_mask,
356
+ inputs_embeds=inputs_embeds,
357
+ past_key_values=past_key_values,
358
+ use_cache=use_cache,
359
+ output_attentions=output_attentions,
360
+ output_hidden_states=output_hidden_states,
361
+ return_dict=return_dict,
362
+ **kwargs
363
+ )
364
+
365
+ hidden_states = outputs[0]
366
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
367
+
368
+ loss, logits = None, None
369
+ if not fuse_linear_and_cross_entropy or labels is None:
370
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
371
+ if labels is not None:
372
+ if getattr(self, 'criterion', None) is None:
373
+ if fuse_linear_and_cross_entropy:
374
+ criterion = FusedLinearCrossEntropyLoss()
375
+ elif self.config.fuse_cross_entropy:
376
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
377
+ else:
378
+ criterion = nn.CrossEntropyLoss()
379
+ else:
380
+ criterion = self.criterion
381
+ labels = labels.to(hidden_states.device)
382
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
383
+ if fuse_linear_and_cross_entropy:
384
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
385
+ else:
386
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
387
+
388
+ if not return_dict:
389
+ output = (logits,) + outputs[1:]
390
+ return (loss,) + output if loss is not None else output
391
+
392
+ return CausalLMOutputWithPast(
393
+ loss=loss,
394
+ logits=logits,
395
+ past_key_values=outputs.past_key_values,
396
+ hidden_states=outputs.hidden_states,
397
+ attentions=outputs.attentions,
398
+ )
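The `logits_to_keep` argument in the forward above controls how many trailing positions are projected through `lm_head`; a positive value keeps only the last k hidden states, which avoids materializing vocabulary-sized logits for the whole prefix during decoding. Note that the default of 0 behaves like "keep everything", since `-0 == 0` and `hidden_states[:, -0:]` is the full sequence. A small sketch with hypothetical toy shapes:

import torch

batch, seq_len, hidden = 2, 6, 4
hidden_states = torch.randn(batch, seq_len, hidden)

# logits_to_keep = 3: only the last 3 positions are fed to the lm_head.
assert hidden_states[:, -3:].shape == (batch, 3, hidden)

# logits_to_keep = 0 (the default) keeps the full sequence, same as None.
assert torch.equal(hidden_states[:, -0:], hidden_states)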
flame/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
flame/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (163 Bytes).
 
flame/__pycache__/train.cpython-311.pyc ADDED
Binary file (38.8 kB).
 
flame/models/parallelize_fla.py ADDED
@@ -0,0 +1,550 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to FLA models.
9
+
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.distributed import DeviceMesh
15
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
16
+ from torch.distributed._composable.replicate import replicate
17
+ from torch.distributed._tensor import Replicate, Shard
18
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
19
+ from torch.distributed.tensor.parallel import (
20
+ ColwiseParallel,
21
+ PrepareModuleInput,
22
+ PrepareModuleOutput,
23
+ RowwiseParallel,
24
+ SequenceParallel,
25
+ parallelize_module
26
+ )
27
+
28
+ from fla.modules.fused_linear_cross_entropy import LinearLossParallel
29
+ from fla.modules.mlp import SwiGLULinearParallel
30
+ from fla.modules.parallel import PrepareModuleWeight
31
+ from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
32
+ from torchtitan.distributed.parallel_dims import ParallelDims
33
+ from torchtitan.tools.logging import logger
34
+
35
+
36
+ def parallelize_fla(
37
+ model: nn.Module,
38
+ world_mesh: DeviceMesh,
39
+ parallel_dims: ParallelDims,
40
+ job_config: JobConfig,
41
+ ):
42
+ """
43
+ Apply tensor parallelism, activation checkpointing, torch.compile, and data
44
+ parallelism to the model.
45
+
46
+ NOTE: The passed-in model should preferably be on the meta device. Otherwise,
47
+ the model must fit on GPU or CPU memory.
48
+ """
49
+
50
+ if parallel_dims.tp_enabled:
51
+ if (
52
+ job_config.experimental.enable_async_tensor_parallel
53
+ and not job_config.training.compile
54
+ ):
55
+ raise RuntimeError("Async TP requires --training.compile")
56
+ enable_float8_linear = "float8" in job_config.model.converters
57
+ apply_tp(
58
+ model,
59
+ world_mesh["tp"],
60
+ loss_parallel=parallel_dims.loss_parallel_enabled,
61
+ enable_float8=enable_float8_linear,
62
+ enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
63
+ )
64
+
65
+ if job_config.activation_checkpoint.mode != "none":
66
+ apply_ac(model, job_config.activation_checkpoint)
67
+
68
+ # turn on per-block compile after AC wrapping and before FSDP
69
+ if job_config.training.compile:
70
+ apply_compile(model)
71
+
72
+ if (
73
+ parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
74
+ ): # apply FSDP or HSDP, potentially with Context Parallel
75
+ if parallel_dims.dp_replicate_enabled:
76
+ dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
77
+ else:
78
+ dp_mesh_dim_names = ("dp_shard_cp",)
79
+
80
+ apply_fsdp(
81
+ model,
82
+ world_mesh[tuple(dp_mesh_dim_names)],
83
+ param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
84
+ reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
85
+ pp_enabled=parallel_dims.pp_enabled,
86
+ cpu_offload=job_config.training.enable_cpu_offload,
87
+ reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
88
+ )
89
+
90
+ if parallel_dims.dp_replicate_enabled:
91
+ logger.info("Applied HSDP to the model")
92
+ else:
93
+ logger.info("Applied FSDP to the model")
94
+
95
+ if parallel_dims.cp_enabled:
96
+ logger.info("Applied Context Parallel to the model")
97
+
98
+ if job_config.training.enable_cpu_offload:
99
+ logger.info("Applied CPU Offloading to the model")
100
+ elif parallel_dims.dp_replicate_enabled:
101
+ if world_mesh.ndim > 1:
102
+ raise RuntimeError("DDP has not supported > 1D parallelism")
103
+ apply_ddp(
104
+ model,
105
+ world_mesh,
106
+ enable_compile=job_config.training.compile,
107
+ enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
108
+ )
109
+
110
+
111
+ class TPPlan:
112
+ def __init__(
113
+ self,
114
+ model=None,
115
+ loss_parallel=False,
116
+ enable_float8=False,
117
+ ):
118
+ self.model = model
119
+ self.loss_parallel = loss_parallel
120
+ self.enable_float8 = enable_float8
121
+ self.base_model_prefix = getattr(model, "base_model_prefix", "model")
122
+
123
+ # TODO(vkuzo): once float8 configuration supports delayed scaling,
124
+ # add a check here to enforce supported float8 all-gather configurations
125
+ # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
126
+ try:
127
+ from torchao.float8.float8_tensor_parallel import (
128
+ Float8ColwiseParallel,
129
+ Float8RowwiseParallel,
130
+ PrepareFloat8ModuleInput
131
+ )
132
+ except ImportError:
133
+ Float8ColwiseParallel = None
134
+ Float8RowwiseParallel = None
135
+ PrepareFloat8ModuleInput = None
136
+ if self.enable_float8 and Float8ColwiseParallel is not None:
137
+ self.rowwise_parallel = Float8RowwiseParallel
138
+ self.colwise_parallel = Float8ColwiseParallel
139
+ self.prepare_module_input = PrepareFloat8ModuleInput
140
+ self.prepare_module_output = PrepareModuleOutput
141
+ else:
142
+ self.rowwise_parallel = RowwiseParallel
143
+ self.colwise_parallel = ColwiseParallel
144
+ self.prepare_module_input = PrepareModuleInput
145
+ self.prepare_module_output = PrepareModuleOutput
146
+
147
+ @property
148
+ def model_plan(self):
149
+ plans = {
150
+ f"{self.base_model_prefix}.embeddings": RowwiseParallel(
151
+ input_layouts=Replicate(),
152
+ output_layouts=Shard(1),
153
+ ),
154
+ f"{self.base_model_prefix}.norm": SequenceParallel(),
155
+ }
156
+ if self.loss_parallel:
157
+ plans.update(
158
+ {
159
+ "lm_head": ColwiseParallel(
160
+ input_layouts=Shard(1),
161
+ output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
162
+ use_local_output=not self.loss_parallel,
163
+ ),
164
+ }
165
+ )
166
+ else:
167
+ plans.update(
168
+ {
169
+ "lm_head": PrepareModuleWeight(layouts=Replicate()),
170
+ "criterion": LinearLossParallel(),
171
+ }
172
+ )
173
+ return plans
174
+
175
+ @property
176
+ def layer_plan(self):
177
+ return {
178
+ "attn_norm": SequenceParallel(),
179
+ **self.attn_plan,
180
+ "mlp_norm": SequenceParallel(),
181
+ **self.mlp_plan,
182
+ }
183
+
184
+ @property
185
+ def attn_plan(self):
186
+ raise NotImplementedError(
187
+ f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
188
+ )
189
+
190
+ @property
191
+ def mlp_plan(self):
192
+ return {
193
+ "mlp": self.prepare_module_input(
194
+ input_layouts=(Shard(1),),
195
+ desired_input_layouts=(Replicate(),),
196
+ ),
197
+ "mlp.gate_proj": self.colwise_parallel(),
198
+ "mlp.up_proj": self.colwise_parallel(),
199
+ "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
200
+ "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
201
+ }
202
+
203
+
204
+ class TransformerTPPlan(TPPlan):
205
+
206
+ @property
207
+ def attn_plan(self):
208
+ return {
209
+ "attn": self.prepare_module_input(
210
+ input_kwarg_layouts={"hidden_states": Shard(1)},
211
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
212
+ ),
213
+ "attn.q_proj": self.colwise_parallel(),
214
+ "attn.k_proj": self.colwise_parallel(),
215
+ "attn.v_proj": self.colwise_parallel(),
216
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
217
+ }
218
+
219
+
220
+ class GLATPPlan(TPPlan):
221
+
222
+ @property
223
+ def attn_plan(self):
224
+ return {
225
+ "attn": self.prepare_module_input(
226
+ input_kwarg_layouts={"hidden_states": Shard(1)},
227
+ desired_input_kwarg_layouts={"hidden_states": Replicate()},
228
+ ),
229
+ "attn.q_proj": self.colwise_parallel(),
230
+ "attn.k_proj": self.colwise_parallel(),
231
+ "attn.v_proj": self.colwise_parallel(),
232
+ "attn.g_proj": self.colwise_parallel(),
233
+ "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
234
+ "attn.gk_proj.1": self.colwise_parallel(),
235
+ "attn.g_norm": SequenceParallel(sequence_dim=-1),
236
+ "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
237
+ }
238
+
239
+
240
+ TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}
241
+
242
+
243
+ def apply_tp(
244
+ model: nn.Module,
245
+ tp_mesh: DeviceMesh,
246
+ loss_parallel: bool,
247
+ enable_float8: bool,
248
+ enable_async_tp: bool,
249
+ ):
250
+ """Apply tensor parallelism."""
251
+ # 1. Parallelize the embedding and shard its outputs (which are the first
252
+ # transformer block's inputs)
253
+ # 2. Parallelize the root norm layer over the sequence dim
254
+ # 3. Parallelize the final linear output layer
255
+ tp_plan = TP_PLAN_MAP[model.config.model_type](
256
+ model, loss_parallel=loss_parallel, enable_float8=enable_float8
257
+ )
258
+ parallelize_module(model, tp_mesh, tp_plan.model_plan)
259
+
260
+ blocks = get_blocks(model)
261
+ if blocks is None:
262
+ logger.warning("No block found for tensor parallelism")
263
+ else:
264
+ for _, block in enumerate(blocks):
265
+ parallelize_module(
266
+ module=block,
267
+ device_mesh=tp_mesh,
268
+ parallelize_plan=tp_plan.layer_plan,
269
+ )
270
+
271
+ if enable_async_tp:
272
+ from torch.distributed._symmetric_memory import enable_symm_mem_for_group
273
+
274
+ torch._inductor.config._micro_pipeline_tp = True
275
+ enable_symm_mem_for_group(tp_mesh.get_group().group_name)
276
+
277
+ logger.info(
278
+ f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
279
+ "Tensor Parallelism to the model"
280
+ )
281
+
282
+
283
+ # for selective op activation checkpointing
284
+ _save_list = {
285
+ torch.ops.aten.mm.default,
286
+ torch.ops.aten._scaled_dot_product_efficient_attention.default,
287
+ torch.ops.aten._scaled_dot_product_flash_attention.default,
288
+ torch.ops._c10d_functional.reduce_scatter_tensor.default,
289
+ # for low precision training, it's useful to always save
290
+ # the result of max, since the absolute maximum is
291
+ # used to compute the scaling factor for quantization.
292
+ torch.ops.aten.max.default,
293
+ }
294
+
295
+
296
+ def _apply_ac_to_block(module: nn.Module, ac_config):
297
+ valid_ac_modes = ("full", "selective")
298
+ if ac_config.mode not in valid_ac_modes:
299
+ raise ValueError(
300
+ f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
301
+ )
302
+
303
+ if ac_config.mode == "full":
304
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
305
+
306
+ assert ac_config.mode == "selective", f"{ac_config.mode}"
307
+ use_op_sac = ac_config.selective_ac_option == "op"
308
+ use_layer_sac = ac_config.selective_ac_option.isdigit()
309
+ if not use_op_sac and not use_layer_sac:
310
+ raise ValueError(
311
+ f"Invalid selective AC option: {ac_config.selective_ac_option}. "
312
+ f"Valid options: 'op' or a positive int representing layer frequency"
313
+ )
314
+ if use_op_sac:
315
+ from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
316
+
317
+ def _get_custom_policy(meta):
318
+ def _custom_policy(ctx, func, *args, **kwargs):
319
+ mode = "recompute" if ctx.is_recompute else "forward"
320
+ mm_count_key = f"{mode}_mm_count"
321
+ if func == torch.ops.aten.mm.default:
322
+ meta[mm_count_key] += 1
323
+ # Saves output of all compute ops, except every second mm
324
+ to_save = func in _save_list and not (
325
+ func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
326
+ )
327
+ return (
328
+ CheckpointPolicy.MUST_SAVE
329
+ if to_save
330
+ else CheckpointPolicy.PREFER_RECOMPUTE
331
+ )
332
+
333
+ return _custom_policy
334
+
335
+ def selective_checkpointing_context_fn():
336
+ meta = defaultdict(int)
337
+ return create_selective_checkpoint_contexts(_get_custom_policy(meta))
338
+
339
+ return ptd_checkpoint_wrapper(
340
+ module,
341
+ context_fn=selective_checkpointing_context_fn,
342
+ preserve_rng_state=False,
343
+ )
344
+ elif use_layer_sac:
345
+ # Checkpoint every `ac_freq`-th module passed to this function
346
+ ac_freq = int(ac_config.selective_ac_option)
347
+ ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
348
+ ptd_checkpoint_wrapper._count += 1
349
+ if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
350
+ return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
351
+ else:
352
+ return module
353
+
354
+
355
+ def apply_ac(model: nn.Module, ac_config):
356
+ """Apply activation checkpointing to the model."""
357
+ blocks = get_blocks(model)
358
+ if blocks is None:
359
+ logger.warning("No block found for activation checkpointing")
360
+ return
361
+
362
+ for layer_id, block in blocks.named_children():
363
+ block = _apply_ac_to_block(block, ac_config)
364
+ blocks.register_module(layer_id, block)
365
+
366
+ logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
367
+
368
+
369
+ def apply_compile(model: nn.Module):
370
+ """
371
+ Apply torch.compile to each block, which makes compilation efficient due to
372
+ repeated structure. Alternatively one can compile the whole model (after applying DP).
373
+ """
374
+
375
+ blocks = get_blocks(model)
376
+ if blocks is None:
377
+ logger.warning("No block found for torch.compile")
378
+ else:
379
+ for layer_id, block in blocks.named_children():
380
+ block = torch.compile(block)
381
+ blocks.register_module(layer_id, block)
382
+ logger.info("Compiling each block with torch.compile")
383
+
384
+ real_model = get_model(model)
385
+
386
+ logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
387
+ embeddings_key = get_components_name(real_model, "tok_embeddings")
388
+ if embeddings_key is not None:
389
+ embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
390
+ real_model.register_module(embeddings_key, embeddings)
391
+
392
+ norm_key = get_components_name(real_model, "norm")
393
+ if norm_key is not None:
394
+ norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
395
+ real_model.register_module(norm_key, norm)
396
+
397
+ lm_head_key = get_components_name(model, "lm_head")
398
+ if lm_head_key is not None:
399
+ lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
400
+ model.register_module(lm_head_key, lm_head)
401
+
402
+ logger.info("Compiling the entire model with torch.compile")
403
+ model = torch.compile(model)
404
+
405
+
406
+ def apply_fsdp(
407
+ model: nn.Module,
408
+ dp_mesh: DeviceMesh,
409
+ param_dtype: torch.dtype,
410
+ reduce_dtype: torch.dtype,
411
+ pp_enabled: bool,
412
+ cpu_offload: bool = False,
413
+ reshard_after_forward_policy: str = "default",
414
+ ):
415
+ """
416
+ Apply data parallelism (via FSDP2) to the model.
417
+
418
+ Args:
419
+ model (nn.Module): The model to apply data parallelism to.
420
+ dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
421
+ param_dtype (torch.dtype): The data type to use for model parameters.
422
+ reduce_dtype (torch.dtype): The data type to use for reduction operations.
423
+ pp_enabled (bool): Whether pipeline parallelism is enabled.
424
+ cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
425
+ reshard_after_forward_policy (str, optional):
426
+ The policy to use for resharding after forward pass. Defaults to "default".
427
+ Other options: "never", "always".
428
+ - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
429
+ - "always" will enable `reshard_after_forward` for all forward passes.
430
+ - "never" will disable `reshard_after_forward` for all forward passes.
431
+
432
+ """
433
+ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
434
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
435
+ if cpu_offload:
436
+ fsdp_config["offload_policy"] = CPUOffloadPolicy()
437
+
438
+ blocks = get_blocks(model)
439
+ if blocks is None:
440
+ logger.warning("No block found for FSDP")
441
+ else:
442
+ total_blocks = len(blocks)
443
+ for layer_id, block in enumerate(blocks):
444
+ if reshard_after_forward_policy == "always":
445
+ reshard_after_forward = True
446
+ elif reshard_after_forward_policy == "never":
447
+ reshard_after_forward = False
448
+ elif reshard_after_forward_policy == "default":
449
+ if pp_enabled:
450
+ # For PP, do not reshard after forward to avoid per-microbatch
451
+ # all-gathers, which can be expensive and non-overlapped
452
+ reshard_after_forward = False
453
+ else:
454
+ # As an optimization, do not reshard after forward for the last
455
+ # transformer block since FSDP would prefetch it immediately
456
+ reshard_after_forward = int(layer_id) < total_blocks - 1
457
+ else:
458
+ raise ValueError(
459
+ f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
460
+ )
461
+ fully_shard(
462
+ block,
463
+ **fsdp_config,
464
+ reshard_after_forward=reshard_after_forward,
465
+ )
466
+
467
+ fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
468
+
469
+
470
+ def apply_ddp(
471
+ model: nn.Module,
472
+ dp_mesh: DeviceMesh,
473
+ enable_compile: bool,
474
+ enable_compiled_autograd: bool,
475
+ ):
476
+ if enable_compile:
477
+ if enable_compiled_autograd:
478
+ torch._dynamo.config.optimize_ddp = (
479
+ "python_reducer_without_compiled_forward"
480
+ )
481
+ else:
482
+ torch._dynamo.config.optimize_ddp = "ddp_optimizer"
483
+
484
+ replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
485
+
486
+ logger.info("Applied DDP to the model")
487
+
488
+
489
+ def get_model(model):
490
+ base_model_prefix = getattr(model, "base_model_prefix", "model")
491
+ if not hasattr(model, base_model_prefix):
492
+ return None
493
+ model = getattr(model, base_model_prefix)
494
+ return model
495
+
496
+
497
+ def get_blocks(model):
498
+ # TODO[flame]: adapt for network not using 'layers' attribute
499
+ model = get_model(model)
500
+ if not hasattr(model, "layers"):
501
+ logger.warning('No "layers" attribute found in the model')
502
+ return None
503
+ return model.layers
504
+
505
+
506
+ def get_components_name(model, component_name):
507
+ """
508
+ Resolve the attribute names of the token embeddings, the final norm, and the lm_head.
509
+ Layer names inside the decoder blocks are not handled here; see `get_blocks` for those.
510
+ We assume the model has the following structure:
511
+ LlamaForCausalLM:
512
+ Model:
513
+ embed_tokens,
514
+ layers,
515
+ norm,
516
+ lm_head
517
+ ***
518
+ so, to search for 'tok_embeddings' and 'norm', pass `get_model(model)`,
519
+ and for 'lm_head', pass `model`
520
+ ***
521
+ """
522
+
523
+ if component_name == "tok_embeddings":
524
+ if hasattr(model, "tok_embeddings"):
525
+ return "tok_embeddings"
526
+ elif hasattr(model, "embed_tokens"):
527
+ return "embed_tokens"
528
+ elif hasattr(model, "embeddings"):
529
+ return "embeddings"
530
+ else:
531
+ logger.warning("No tok_embeddings found in model")
532
+ return None
533
+
534
+ elif component_name == "norm":
535
+ if hasattr(model, "norm"):
536
+ return "norm"
537
+ elif hasattr(model, "norms"):
538
+ return "norms"
539
+ elif hasattr(model, "layernorm"):
540
+ return "layernorm"
541
+ else:
542
+ logger.warning("No norm found in model")
543
+ return None
544
+
545
+ elif component_name == "lm_head":
546
+ if hasattr(model, "lm_head"):
547
+ return "lm_head"
548
+ else:
549
+ logger.warning("No lm_head found in model")
550
+ return None
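A hypothetical illustration (not part of this commit) of how these helpers locate the modules that the parallelization plans operate on; `model` is assumed to be a HF-style causal LM wrapper:

blocks = get_blocks(model)                                  # e.g. model.model.layers
base = get_model(model)                                     # e.g. model.model
emb_name = get_components_name(base, "tok_embeddings")      # "embed_tokens" for Llama-style models
norm_name = get_components_name(base, "norm")               # final norm before the head
head_name = get_components_name(model, "lm_head")           # lm_head lives on the outer wrapper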
flame/train.py ADDED
@@ -0,0 +1,897 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import os
9
+ import time
10
+ from datetime import timedelta
11
+ from collections import defaultdict
12
+ import dataclasses
13
+
14
+ import torch
15
+ from datasets import interleave_datasets, load_dataset
16
+ from torch.distributed.elastic.multiprocessing.errors import record
17
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
18
+
19
+ import fla # noqa
20
+ from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
21
+ from fla.ops.common.utils import prepare_position_ids
22
+ from flame.components.checkpoint import TrainState
23
+ from flame.config_manager import JobConfig
24
+ from flame.data import build_dataloader, shuffle
25
+ from flame.models.parallelize_fla import parallelize_fla
26
+ from flame.models.pipeline_fla import pipeline_fla
27
+ from flame.tools.utils import get_nparams_and_flops
28
+ from flame.utils.checkpoint import cleanup_local_checkpoints
29
+ from flame.utils.convert_dcp_to_hf import save_pretrained
30
+ from flame.utils.hf_utils import upload_checkpoint_to_hf
31
+ from datetime import datetime
32
+ from torchtitan.components.checkpoint import CheckpointManager
33
+ from torchtitan.components.ft import FTParallelDims, init_ft_manager
34
+ from torchtitan.components.loss import build_cross_entropy_loss
35
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
36
+ from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible
37
+ from torchtitan.components.optimizer import build_optimizers
38
+ from torchtitan.distributed import ParallelDims
39
+ from torchtitan.distributed import utils as dist_utils
40
+ from torchtitan.protocols.model_converter import build_model_converters
41
+ from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
42
+ from torchtitan.tools import utils
43
+ from torchtitan.tools.logging import init_logger, logger
44
+ from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
45
+
46
+ from dotenv import load_dotenv
47
+ load_dotenv()
48
+
49
+ import wandb
50
+ wandb.login(key=os.environ["WANDB_API_KEY"])
51
+
52
+ import huggingface_hub
53
+ huggingface_hub.login(token=os.environ["HF_TOKEN"])
54
+
55
+
56
+ def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
57
+ return AutoTokenizer.from_pretrained(job_config.model.tokenizer_path)
58
+
59
+
60
+ register_train_spec(
61
+ TrainSpec(
62
+ name="fla",
63
+ cls=AutoModelForCausalLM,
64
+ config=AutoConfig,
65
+ parallelize_fn=parallelize_fla,
66
+ pipelining_fn=pipeline_fla,
67
+ build_optimizers_fn=build_optimizers,
68
+ build_lr_schedulers_fn=build_lr_schedulers,
69
+ build_dataloader_fn=build_dataloader,
70
+ build_tokenizer_fn=build_tokenizer,
71
+ build_loss_fn=build_cross_entropy_loss,
72
+ )
73
+ )
74
+
75
+
76
+ # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
77
+ @record
78
+ def main(job_config: JobConfig):
79
+ logger.info(f"Starting job: {job_config.job.description}")
80
+
81
+ if job_config.experimental.custom_model_path:
82
+ utils.import_module_from_path(job_config.experimental.custom_model_path)
83
+
84
+ # used for colorful printing
85
+ color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color
86
+
87
+ if job_config.job.print_args:
88
+ logger.info(
89
+ f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
90
+ )
91
+
92
+ # take control of garbage collection to avoid stragglers
93
+ gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
94
+
95
+ device_module, device_type = utils.device_module, utils.device_type
96
+ device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
97
+ # Device has to be set before creating TorchFT manager.
98
+ device_module.set_device(device)
99
+ ft_manager = init_ft_manager(job_config)
100
+
101
+ run_specific_repo_id = None
102
+ if getattr(job_config.checkpoint, "hf_upload_enabled", False):
103
+ hf_repo_base = getattr(job_config.checkpoint, "hf_repo_base_name", None)
104
+ if hf_repo_base:
105
+ # Generate timestamp (adjust format if desired)
106
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
107
+ run_specific_repo_id = f"{hf_repo_base}-{timestamp}"
108
+ logger.info(f"Target Hugging Face repository for this run: {run_specific_repo_id}")
109
+ else:
110
+ logger.warning("HF Hub upload enabled, but 'checkpoint.hf_repo_base_name' is not set.")
111
+ # Disable upload if base name is missing
112
+ job_config.checkpoint.hf_upload_enabled = False
113
+
114
+ # init distributed
115
+ world_size = int(os.environ["WORLD_SIZE"])
116
+ if not ft_manager.enabled:
117
+ parallel_dims = ParallelDims(
118
+ dp_shard=job_config.training.data_parallel_shard_degree,
119
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
120
+ cp=job_config.experimental.context_parallel_degree,
121
+ tp=job_config.training.tensor_parallel_degree,
122
+ pp=job_config.experimental.pipeline_parallel_degree,
123
+ world_size=world_size,
124
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
125
+ )
126
+ else:
127
+ parallel_dims = FTParallelDims(
128
+ dp_shard=job_config.training.data_parallel_shard_degree,
129
+ dp_replicate=job_config.training.data_parallel_replicate_degree,
130
+ cp=job_config.experimental.context_parallel_degree,
131
+ tp=job_config.training.tensor_parallel_degree,
132
+ pp=job_config.experimental.pipeline_parallel_degree,
133
+ world_size=world_size,
134
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
135
+ ft_manager=ft_manager,
136
+ )
137
+ dist_utils.init_distributed(job_config)
138
+ # initialize device memory monitor and get peak flops for MFU calculation
139
+ device_memory_monitor = build_device_memory_monitor()
140
+ gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
141
+ logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
142
+
143
+ # build meshes
144
+ world_mesh = parallel_dims.build_mesh(device_type=device_type)
145
+ if parallel_dims.dp_enabled:
146
+ dp_mesh = world_mesh["dp"]
147
+ dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
148
+ else:
149
+ dp_degree, dp_rank = 1, 0
150
+
151
+ if parallel_dims.pp_enabled:
152
+ raise NotImplementedError(
153
+ "Pipeline parallelism is not supported in this version"
154
+ )
155
+ """
156
+ ! TODO[flame]: We need to fix the pipeline parallelism for flame
157
+ [x] Match the keys of the model's components with the actual naming
158
+ [ ] Fix the post-init and tie-embedding for pipeline parallelism; HF transformers automatically
159
+ forces weight tying if the head is None, and we need to handle this case
160
+ [ ]
161
+ """
162
+ pp_mesh = world_mesh["pp"]
163
+
164
+ # Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
165
+ dist_utils.set_determinism(
166
+ world_mesh, device, job_config.training.seed, job_config.training.deterministic
167
+ )
168
+ train_spec = get_train_spec(job_config.model.name)
169
+
170
+ logger.info("Loading tokenizer...")
171
+ tokenizer = AutoTokenizer.from_pretrained(
172
+ job_config.model.tokenizer_path,
173
+ trust_remote_code=True,
174
+ model_max_length=int(1e10),
175
+ )
176
+ logger.info(f"{tokenizer}")
177
+ logger.info(
178
+ f"Loading dataset {job_config.training.dataset}"
179
+ f":{job_config.training.dataset_name}"
180
+ if job_config.training.dataset_name is not None
181
+ else ""
182
+ )
183
+
184
+ min_num_shards = dp_degree * job_config.training.num_workers
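+ # e.g., with dp_degree=8 and num_workers=4 (hypothetical values), the dataset must expose at
+ # least 32 shards so that every dataloader worker on every DP rank receives a distinct shard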
185
+ if len(job_config.training.dataset.split(",")) == 1:
186
+ dataset = load_dataset(
187
+ path=job_config.training.dataset,
188
+ name=getattr(job_config.training, "dataset_name", None),
189
+ data_dir=getattr(job_config.training, "data_dir", None),
190
+ data_files=getattr(job_config.training, "data_files", None),
191
+ split=job_config.training.dataset_split or "train",
192
+ trust_remote_code=True,
193
+ streaming=job_config.training.streaming,
194
+ num_proc=(
195
+ job_config.training.num_workers
196
+ if not job_config.training.streaming
197
+ else None
198
+ ),
199
+ )
200
+ logger.info(f"{dataset}")
201
+
202
+ logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
203
+ if not job_config.training.streaming:
204
+ # the state of a map-style dataset is recoverable after shuffling
205
+ dataset = dataset.shuffle(
206
+ seed=job_config.training.seed
207
+ ).to_iterable_dataset(num_shards=min_num_shards)
208
+ else:
209
+ if dataset.num_shards < min_num_shards:
210
+ logger.warning(
211
+ f"{color.red}"
212
+ f"Dataset {job_config.training.dataset} has insufficient shards ({dataset.num_shards}). "
213
+ f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
214
+ f"{job_config.training.num_workers} dataloader workers. "
215
+ f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards."
216
+ f"{color.reset}"
217
+ )
218
+ dataset = (
219
+ load_dataset(
220
+ path=job_config.training.dataset,
221
+ name=getattr(job_config.training, "dataset_name", None),
222
+ data_dir=getattr(job_config.training, "data_dir", None),
223
+ data_files=getattr(job_config.training, "data_files", None),
224
+ split=job_config.training.dataset_split or "train",
225
+ trust_remote_code=True,
226
+ streaming=False,
227
+ num_proc=job_config.training.num_workers,
228
+ )
229
+ .shuffle(seed=job_config.training.seed)
230
+ .to_iterable_dataset(num_shards=min_num_shards)
231
+ )
232
+ else:
233
+ dataset = shuffle(dataset, seed=job_config.training.seed)
234
+ else:
235
+ datasets = job_config.training.dataset.split(",")
236
+ if job_config.training.dataset_name is not None:
237
+ dataset_names = [
238
+ name or None for name in job_config.training.dataset_name.split(",")
239
+ ]
240
+ assert len(dataset_names) == len(datasets), (
241
+ "The number of dataset names must match the number of datasets"
242
+ )
243
+ else:
244
+ dataset_names = [None] * len(datasets)
245
+ if job_config.training.dataset_split is not None:
246
+ dataset_splits = [
247
+ split or "train"
248
+ for split in job_config.training.dataset_split.split(",")
249
+ ]
250
+ assert len(dataset_splits) == len(datasets), (
251
+ "The number of dataset splits must match the number of datasets"
252
+ )
253
+ else:
254
+ dataset_splits = ["train"] * len(datasets)
255
+ if job_config.training.data_dir is not None:
256
+ data_dirs = [
257
+ data_dir or None for data_dir in job_config.training.data_dir.split(",")
258
+ ]
259
+ assert len(data_dirs) == len(datasets), (
260
+ "The number of data dirs must match the number of datasets"
261
+ )
262
+ else:
263
+ data_dirs = [None] * len(datasets)
264
+ if job_config.training.data_files is not None:
265
+ data_files = job_config.training.data_files.split(",")
266
+ assert len(data_files) == len(datasets), (
267
+ "The number of data files must match the number of datasets"
268
+ )
269
+ else:
270
+ data_files = [None] * len(datasets)
271
+ if job_config.training.data_probs is not None:
272
+ data_probs = [float(p) for p in job_config.training.data_probs.split(",")]
273
+ assert len(data_probs) == len(datasets), (
274
+ "The number of data probabilities must match the number of datasets"
275
+ )
276
+ else:
277
+ raise ValueError(
278
+ "Data sampling probabilities are required if using multiple datasets"
279
+ )
280
+
281
+ subsets = []
282
+ for i, prob in enumerate(data_probs):
283
+ subset = load_dataset(
284
+ path=datasets[i],
285
+ name=dataset_names[i],
286
+ data_dir=data_dirs[i],
287
+ data_files=data_files[i],
288
+ split=dataset_splits[i],
289
+ trust_remote_code=True,
290
+ streaming=job_config.training.streaming,
291
+ num_proc=(
292
+ job_config.training.num_workers
293
+ if not job_config.training.streaming
294
+ else None
295
+ ),
296
+ )
297
+ logger.info(
298
+ f"Subset {color.cyan}{datasets[i]}"
299
+ + (f":{dataset_names[i]} " if dataset_names[i] else " ")
300
+ + f"(p = {prob:.3f}){color.reset}:\n"
301
+ + f"{subset}"
302
+ )
303
+
304
+ logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
305
+ if not job_config.training.streaming:
306
+ # the state of a map-style dataset is recoverable after shuffling
307
+ subset = subset.shuffle(
308
+ seed=job_config.training.seed
309
+ ).to_iterable_dataset(num_shards=min_num_shards)
310
+ else:
311
+ if subset.num_shards < min_num_shards:
312
+ logger.warning(
313
+ f"{color.red}"
314
+ f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
315
+ f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
316
+ f"{job_config.training.num_workers} dataloader workers. "
317
+ f"Resharding dataset to {min_num_shards} shards and disabling streaming mode."
318
+ f"{color.reset}"
319
+ )
320
+ # again, it's ok to directly shuffle the map-style dataset;
321
+ # we expect an error to be raised if the map-style dataset still does not have enough shards
322
+ subset = (
323
+ load_dataset(
324
+ path=datasets[i],
325
+ name=dataset_names[i],
326
+ data_dir=data_dirs[i],
327
+ data_files=data_files[i],
328
+ split=dataset_splits[i],
329
+ trust_remote_code=True,
330
+ streaming=False,
331
+ num_proc=job_config.training.num_workers,
332
+ )
333
+ .shuffle(seed=job_config.training.seed)
334
+ .to_iterable_dataset(min_num_shards)
335
+ )
336
+ else:
337
+ # we set relatively small buffer size here as interleaving could provide some randomness
338
+ subset = shuffle(
339
+ subset,
340
+ seed=job_config.training.seed,
341
+ buffer_size=max(128, 1024 // len(datasets)),
342
+ )
343
+
344
+ if "text" in subset.column_names:
345
+ subset = subset.select_columns("text")
346
+ elif "content" in subset.column_names:
347
+ subset = subset.select_columns("content")
348
+ else:
349
+ raise ValueError(
350
+ f"Subset {datasets[i]} has no 'text' or 'content' column"
351
+ )
352
+ subsets.append(subset)
353
+
354
+ logger.info(
355
+ f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
356
+ )
357
+ dataset = interleave_datasets(
358
+ datasets=subsets,
359
+ probabilities=data_probs,
360
+ stopping_strategy="all_exhausted",
361
+ seed=job_config.training.seed,
362
+ )
363
+ logger.info(f"{dataset}")
364
+
365
+
366
+ logger.info(f"Loading model config from {job_config.model.config}")
367
+ model_config = AutoConfig.from_pretrained(job_config.model.config)
368
+
369
+ logger.info("Building dataloader...")
370
+ dataloader = build_dataloader(
371
+ dataset=dataset,
372
+ tokenizer=tokenizer,
373
+ rank=dp_rank,
374
+ world_size=dp_degree,
375
+ batch_size=job_config.training.batch_size,
376
+ # TODO: Make this more modular
377
+ # seq_len=job_config.training.seq_len if not model_config.use_myopic_loss else job_config.training.seq_len*2,
378
+ seq_len=job_config.training.seq_len * 2,
379
+ context_len=job_config.training.context_len,
380
+ varlen=job_config.training.varlen,
381
+ num_workers=job_config.training.num_workers,
382
+ pin_memory=job_config.training.pin_memory,
383
+ persistent_workers=job_config.training.persistent_workers,
384
+ snapshot_every_n_steps=job_config.checkpoint.interval,
385
+ )
386
+
387
+ # set the model configs from training inputs:
388
+ # 1. norm type to decide which norm layer to use
389
+ # 2. disable fused norm if TP is enabled
390
+ # 3. vocab size from tokenizer
391
+ # 4. context_len based on training inputs
392
+ if parallel_dims.tp_enabled:
393
+ if model_config.fuse_norm:
394
+ logger.warning(
395
+ f"{color.red}"
396
+ f"Fused norm is not compatible with tensor parallelism. "
397
+ f"Disabling it for now."
398
+ f"{color.reset}"
399
+ )
400
+ model_config.fuse_norm = False
401
+ if parallel_dims.loss_parallel_enabled:
402
+ if model_config.fuse_cross_entropy:
403
+ logger.warning(
404
+ f"{color.red}"
405
+ f"Loss parallel enabled. Disabling fused cross entropy for now."
406
+ f"{color.reset}"
407
+ )
408
+ model_config.fuse_cross_entropy = False
409
+ model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)
410
+
411
+ logger.info(
412
+ f"Building model from the config\n{color.green}{model_config}{color.reset}"
413
+ )
414
+ with torch.device("meta"):
415
+ model = AutoModelForCausalLM.from_config(model_config)
416
+ if (
417
+ getattr(model_config, "fuse_cross_entropy", False)
418
+ and FusedLinearCrossEntropyLoss is not None
419
+ ):
420
+ model.criterion = FusedLinearCrossEntropyLoss(
421
+ num_chunks=8 // parallel_dims.tp
422
+ )
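+ # e.g., with tp=2 (hypothetical) the fused loss splits the vocab projection into 8 // 2 = 4 chunks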
423
+ # defer weight initialization until after parallelisms are applied
424
+ model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
425
+ logger.info(f"{color.blue}\n{model}{color.reset}\n")
426
+
427
+ # Build the collection of model converters. No-op if `model.converters` empty
428
+ model_converters = build_model_converters(job_config, parallel_dims)
429
+ model_converters.convert(model)
430
+
431
+ # calculate model size and flops per token
432
+ model_param_count, num_flops_per_token = get_nparams_and_flops(
433
+ model, model_config, job_config.training.context_len
434
+ )
435
+
436
+ # move sharded model to CPU/GPU and initialize weights via DTensor
437
+ if job_config.checkpoint.create_seed_checkpoint:
438
+ init_device = "cpu"
439
+ elif job_config.training.enable_cpu_offload:
440
+ init_device = "cpu"
441
+ else:
442
+ init_device = device_type
443
+
444
+ # apply parallelisms and initialization
445
+ if parallel_dims.pp_enabled:
446
+ # apply PT-D Pipeline Parallel
447
+ (
448
+ pp_schedule,
449
+ model_parts,
450
+ has_first_stage,
451
+ has_last_stage,
452
+ ) = train_spec.pipelining_fn(
453
+ model,
454
+ pp_mesh,
455
+ parallel_dims,
456
+ job_config,
457
+ device,
458
+ model_config,
459
+ train_spec.loss_fn,
460
+ )
461
+ # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
462
+ del model
463
+
464
+ # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
465
+ # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
466
+ # optimizer, and checkpointing
467
+ for m in model_parts:
468
+ # apply SPMD-style PT-D techniques
469
+ train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
470
+ m.to_empty(device=init_device)
471
+ with torch.no_grad():
472
+ m.post_init()
473
+ m.train()
474
+
475
+ # confirm that user will be able to view loss metrics on the console
476
+ ensure_pp_loss_visible(parallel_dims, job_config, color)
477
+ else:
478
+ # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
479
+ train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
480
+ model.to_empty(device=init_device)
481
+ with torch.no_grad():
482
+ model.post_init()
483
+ model.train()
484
+
485
+ model_parts = [model]
486
+
487
+ device_mem_stats = device_memory_monitor.get_peak_stats()
488
+ logger.info(
489
+ f"{device_type.upper()} memory usage for model: "
490
+ f"{device_mem_stats.max_reserved_gib:.2f}GiB"
491
+ f"({device_mem_stats.max_reserved_pct:.2f}%)"
492
+ )
493
+
494
+ # build optimizer after applying parallelisms to the model
495
+ optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
496
+ lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
497
+ # Post optimizer step model converters hook.
498
+ # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
499
+ # where it issues a single all-reduce for all parameters at once for better performance
500
+ optimizers.register_step_post_hook(
501
+ lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
502
+ )
503
+
504
+ train_state = TrainState()
505
+
506
+ # load initial checkpoint
507
+ checkpoint = CheckpointManager(
508
+ dataloader=dataloader,
509
+ model_parts=model_parts,
510
+ optimizers=optimizers,
511
+ lr_schedulers=lr_schedulers,
512
+ states={"train_state": train_state},
513
+ job_config=job_config,
514
+ ft_manager=ft_manager,
515
+ )
516
+
517
+ if job_config.checkpoint.create_seed_checkpoint:
518
+ assert world_size == 1, (
519
+ "Must create seed checkpoint using a single device, to disable sharding"
520
+ )
521
+ assert job_config.checkpoint.enable_checkpoint, (
522
+ "Must enable checkpointing when creating a seed checkpoint"
523
+ )
524
+ checkpoint.save(curr_step=0, force=True)
525
+ logger.info("Created seed checkpoint")
526
+ return
527
+
528
+ checkpoint.load(step=job_config.checkpoint.load_step)
529
+ metric_logger = build_metrics_processor(job_config, parallel_dims)
530
+ # Set dependent attributes for metric_logger
531
+ metric_logger.num_flops_per_token = num_flops_per_token
532
+ metric_logger.optimizers = optimizers # Pass optimizers if needed by logger logic
533
+ metric_logger.lr_schedulers = (
534
+ lr_schedulers # Pass schedulers if needed by logger logic
535
+ )
536
+
537
+ # plot losses loaded from checkpoint (if any) to TensorBoard
538
+ # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
539
+ # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
540
+ if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
541
+ for idx, step in enumerate(train_state.log_steps):
542
+ metric_logger.log(
543
+ step,
544
+ global_avg_loss=train_state.global_avg_losses[idx],
545
+ global_max_loss=train_state.global_max_losses[idx],
546
+ )
547
+
548
+ data_iterator = iter(dataloader)
549
+
550
+ train_context = dist_utils.get_train_context(
551
+ parallel_dims.loss_parallel_enabled,
552
+ job_config.experimental.enable_compiled_autograd,
553
+ )
554
+
555
+ # variables used to keep info for metrics logging
556
+ device_memory_monitor.reset_peak_stats()
557
+
558
+ global_batch_size = (
559
+ job_config.training.batch_size
560
+ * dp_degree
561
+ * job_config.training.gradient_accumulation_steps
562
+ )
563
+ num_tokens_per_step = global_batch_size * job_config.training.seq_len
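+ # e.g., batch_size=8, dp_degree=4 and gradient_accumulation_steps=4 (hypothetical values) give a
+ # global batch of 128 sequences; with seq_len=2048 that is 262,144 tokens per optimization step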
564
+ # train loop
565
+ logger.info(f"{color.red}***** Running training *****{color.reset}")
566
+ logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
567
+ logger.info(
568
+ f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
569
+ )
570
+ logger.info(
571
+ f"{color.green} Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}"
572
+ )
573
+ logger.info(
574
+ f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
575
+ )
576
+ logger.info(
577
+ f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
578
+ f" ({num_tokens_per_step:,} tokens)"
579
+ )
580
+ logger.info(
581
+ f"{color.green} Total optimization steps = {job_config.training.steps:,} "
582
+ f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
583
+ )
584
+ logger.info(
585
+ f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
586
+ f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
587
+ )
588
+ logger.info(
589
+ f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
590
+ )
591
+
592
+ with (
593
+ maybe_enable_profiling(
594
+ job_config, global_step=train_state.step
595
+ ) as torch_profiler,
596
+ maybe_enable_memory_snapshot(
597
+ job_config, global_step=train_state.step
598
+ ) as memory_profiler,
599
+ ):
600
+ while train_state.step < job_config.training.steps:
601
+ train_state.step += 1
602
+ gc_handler.run(train_state.step)
603
+
604
+ optimizers.zero_grad()
605
+
606
+ losses = defaultdict(list)
607
+ actual_loss = []
608
+ # do gradient accumulation if enabled
609
+ for _ in range(job_config.training.gradient_accumulation_steps):
610
+ # get batch
611
+ data_load_start = time.perf_counter()
612
+ batch = next(data_iterator)
613
+ # Recall that, for the myopic and MTP objectives, the batch layout is
614
+ # input_ids : (B, seq_len)
615
+ # labels : (B, seq_len * 2)
616
+ input_ids, labels = batch["input_ids"][:, :job_config.training.seq_len], batch["labels"]
617
+
618
+ # Update metrics processor state before forward/backward
619
+ metric_logger.ntokens_since_last_log += input_ids.numel()
620
+ metric_logger.data_loading_times.append(
621
+ time.perf_counter() - data_load_start
622
+ )
623
+
624
+ input_ids = input_ids.to(device_type)
625
+
626
+ """
627
+ TODO[flame]: We need to carefully handle the position_ids for TP/CP
628
+ Depending on the Models'PE, the position_ids might be different.
629
+
630
+ e.g. for TP
631
+ For RoPE, all ranks have the same position_ids. [FOR HF model]
632
+ For sinusoidal, each rank has the coresponding chunked position_ids. [FOR HF model]
633
+
634
+ e.g. for CP, [optional_context_parallel_ctx shoudl automatically distbute the position_ids]
635
+ Each rank has the coresponding chunked position_ids. [FOR All model]
636
+
637
+ """
638
+ labels = labels.to(device_type)
639
+ cu_seqlens = (
640
+ batch["cu_seqlens"].to(device_type)
641
+ if "cu_seqlens" in batch
642
+ else None
643
+ )
644
+ if cu_seqlens is not None:
645
+ position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
646
+ else:
647
+ position_ids = (
648
+ torch.arange(0, input_ids.shape[1], device=device_type)
649
+ .repeat(input_ids.shape[0], 1)
650
+ .to(torch.int32)
651
+ )
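+ # e.g., for a packed (varlen) batch with cu_seqlens=[0, 3, 7], positions restart per sequence,
+ # giving [0, 1, 2, 0, 1, 2, 3]; for regular batches every row is simply [0, ..., seq_len - 1]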
652
+ # apply context parallelism if cp is enabled
653
+ # ensure CP handles the separate freqs_cis buffer for each pp stage
654
+ optional_context_parallel_ctx = (
655
+ dist_utils.create_context_parallel_ctx(
656
+ cp_mesh=world_mesh["cp"],
657
+ cp_buffers=[input_ids, labels, position_ids],
658
+ cp_seq_dims=[1, 1, 1],
659
+ cp_no_restore_buffers={input_ids, labels, position_ids},
660
+ cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
661
+ )
662
+ if parallel_dims.cp_enabled
663
+ else None
664
+ )
665
+
666
+ # #! TODO[flame], we should distribute the position_ids as well with CP
667
+ if parallel_dims.pp_enabled:
668
+ raise NotImplementedError(
669
+ "Pipeline parallelism is not supported in this version"
670
+ )
671
+ # Pipeline Parallel forward / backward inside step() call
672
+ with train_context(optional_context_parallel_ctx):
673
+ targets, losses = (
674
+ (labels, []) if has_last_stage else (None, None)
675
+ )
676
+
677
+ if has_first_stage:
678
+ pp_schedule.step(input_ids, target=targets, losses=losses)
679
+ else:
680
+ pp_schedule.step(target=targets, losses=losses)
681
+
682
+ # accumulate losses across pipeline microbatches
683
+ # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
684
+ loss = (
685
+ torch.mean(torch.stack(losses)).to(device)
686
+ if has_last_stage
687
+ else torch.tensor([-1.0], device=device)
688
+ )
689
+ else:
690
+ # Non-PP forward / backward
691
+ with train_context(optional_context_parallel_ctx):
692
+ output = model(
693
+ input_ids=input_ids,
694
+ labels=labels,
695
+ position_ids=position_ids,
696
+ cu_seqlens=cu_seqlens,
697
+ )
698
+ output_attributes = [field.name for field in dataclasses.fields(output)]
699
+ loss_attributes = [x for x in output_attributes if "loss" in x and x != "loss"]
700
+ loss = (
701
+ output.loss
702
+ / job_config.training.gradient_accumulation_steps
703
+ )
704
+ loss.backward()
705
+
706
+ actual_loss.append(loss)
707
+ for loss_attr in loss_attributes:
708
+ custom_loss = getattr(output, loss_attr, None)
709
+ if custom_loss is not None:
710
+ custom_loss = custom_loss / job_config.training.gradient_accumulation_steps
712
+ losses[loss_attr].append(custom_loss)
713
+
714
+ loss = sum(actual_loss)
715
+ for loss_attr, loss_values in losses.items():
716
+ losses[loss_attr] = sum(loss_values)
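+ # each micro-batch loss was already divided by gradient_accumulation_steps above, so these
+ # sums recover the mean loss over the accumulation window rather than the raw total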
717
+
718
+ # clip gradients
719
+ grad_norm = dist_utils.clip_grad_norm_(
720
+ [p for m in model_parts for p in m.parameters()],
721
+ job_config.training.max_norm,
722
+ foreach=True,
723
+ pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
724
+ )
725
+
726
+ # optimizer step
727
+ checkpoint.maybe_wait_for_staging()
728
+ if job_config.training.skip_nan_inf and (
729
+ grad_norm.isnan() or grad_norm.isinf()
730
+ ):
731
+ logger.warning(
732
+ f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
733
+ )
734
+ optimizers.zero_grad()
735
+ train_state.skipped_step += 1
736
+ else:
737
+ optimizers.step()
738
+ lr_schedulers.step()
739
+
740
+ # log metrics - Use MetricsProcessor
741
+ global_avg_custom_loss = {}
742
+ global_max_custom_loss = {}
743
+ if metric_logger.should_log(train_state.step):
744
+ if (
745
+ parallel_dims.dp_replicate_enabled
746
+ or parallel_dims.dp_shard_enabled
747
+ or parallel_dims.cp_enabled
748
+ ):
749
+ loss = loss.detach()
750
+ # Use dist_mean/max on the accumulated loss for the step
751
+ global_avg_loss, global_max_loss = (
752
+ dist_utils.dist_mean(
753
+ loss,
754
+ world_mesh["dp_cp"],
755
+ ),
756
+ dist_utils.dist_max(
757
+ loss,
758
+ world_mesh["dp_cp"],
759
+ ),
760
+ )
761
+ for loss_attr, loss_value in losses.items():
762
+ global_avg_custom_loss[loss_attr] = dist_utils.dist_mean(
763
+ loss_value, world_mesh["dp_cp"]
764
+ )
765
+ global_max_custom_loss[loss_attr] = dist_utils.dist_max(
766
+ loss_value, world_mesh["dp_cp"]
767
+ )
768
+ else:
769
+ # Single data-parallel group: read the accumulated loss directly
770
+ global_avg_loss = global_max_loss = loss.item()
771
+ for loss_attr, loss_value in losses.items():
772
+ global_avg_custom_loss[loss_attr] = global_max_custom_loss[
773
+ loss_attr
774
+ ] = loss_value.item()
775
+
776
+ # Update train state tokens and elapsed time
777
+ time_now = time.perf_counter()
778
+ time_delta = (
779
+ time_now - metric_logger.time_last_log
780
+ ) # Use metric_logger's time
781
+ train_state.token += (
782
+ metric_logger.ntokens_since_last_log # Use tokens tracked by metric_logger
783
+ * parallel_dims.world_size
784
+ / parallel_dims.non_data_parallel_size
785
+ )
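+ # world_size / non_data_parallel_size equals the DP degree, e.g. world_size=8 with tp=2
+ # (hypothetical values) scales the per-rank token count by 4 to obtain the global count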
786
+ train_state.elapsed += timedelta(seconds=time_delta)
787
+ train_state.log_steps.append(train_state.step)
788
+ train_state.global_avg_losses.append(global_avg_loss)
789
+ train_state.global_max_losses.append(global_max_loss)
790
+
791
+ # Log using the metric processor
792
+ last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
793
+ eta = (
794
+ train_state.elapsed
795
+ * (job_config.training.steps - train_state.step)
796
+ / train_state.step
797
+ )
798
+ extra_metrics = {
799
+ "optimizer/lr": last_lr,
800
+ "optimizer/grad_norm": grad_norm.item(),
801
+ "optimizer/skipped_step": train_state.skipped_step,
802
+ }
803
+ for loss_attr, loss_value in global_avg_custom_loss.items():
804
+ extra_metrics[f"loss_metrics/global_avg_{loss_attr}"] = loss_value.item() if isinstance(loss_value, torch.Tensor) else loss_value
805
+ metric_logger.log(
806
+ train_state.step,
807
+ global_avg_loss,
808
+ global_max_loss,
809
+ extra_metrics=extra_metrics,
810
+ )
811
+
812
+ logger.info(
813
+ f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
814
+ f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
815
+ )
816
+
817
+ checkpoint.save(
818
+ train_state.step, force=(train_state.step == job_config.training.steps)
819
+ )
820
+
821
+ if torch.distributed.get_rank() == 0:
822
+ if job_config.checkpoint.enable_checkpoint:
823
+ hf_target_path = None
824
+ dcp_save_path = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder, f"step-{train_state.step}")
825
+
826
+ # TODO: Haven't tested this one yet
827
+ if getattr(job_config.checkpoint, "convert_to_hf_on_save", False):
828
+ try:
829
+ # Get the path where DCP was just saved
830
+ # Check CheckpointManager API for the best way, assuming get_save_path exists
831
+ hf_target_path = f"{dcp_save_path}" # e.g., .../checkpoint/step-1000-hf
832
+
833
+ logger.info(f"Converting step {train_state.step} DCP checkpoint to HF format at: {hf_target_path}")
834
+ save_pretrained( # Call the imported function
835
+ path=hf_target_path, # Pass target HF path as 'path'
836
+ step=train_state.step,
837
+ config=job_config.model.config, # Pass model config path/id
838
+ tokenizer=job_config.model.tokenizer_path # Pass tokenizer path/id
839
+ )
840
+ logger.info(f"Successfully converted step {train_state.step} to HF format.")
841
+
842
+ except Exception as e:
843
+ logger.error(f"Failed to convert checkpoint step {train_state.step} to HF format: {e}", exc_info=True)
844
+
845
+ base_checkpoint_dir = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder)
846
+ if getattr(job_config.checkpoint, "hf_upload_enabled", True):
847
+ upload_format = getattr(job_config.checkpoint, "hf_upload_format", "hf")
848
+ keep_k_hub = getattr(job_config.checkpoint, "hf_keep_latest_k", 5)
849
+
850
+ local_path_to_upload = None
851
+ if upload_format == "hf":
852
+ if hf_target_path and os.path.isdir(hf_target_path):
853
+ local_path_to_upload = hf_target_path
854
+ elif upload_format == "dcp":
855
+ if dcp_save_path and os.path.isdir(dcp_save_path):
856
+ local_path_to_upload = dcp_save_path
857
+
858
+ if local_path_to_upload:
859
+ try:
860
+ upload_checkpoint_to_hf(
861
+ local_path=local_path_to_upload,
862
+ step=train_state.step,
863
+ hf_repo_id_for_run=run_specific_repo_id,
864
+ upload_format=upload_format,
865
+ hf_keep_latest_k=keep_k_hub,
866
+ )
867
+ except Exception as e:
868
+ logger.error(f"Failed during HF Hub upload for step {train_state.step}: {e}", exc_info=True)
869
+
870
+ # signal the profiler that the next profiling step has started
871
+ if torch_profiler:
872
+ torch_profiler.step()
873
+ if memory_profiler:
874
+ memory_profiler.step()
875
+
876
+ # reduce timeout after first train step for faster signal
877
+ # (assuming lazy init and compilation are finished)
878
+ if train_state.step == 1:
879
+ dist_utils.set_pg_timeouts(
880
+ timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
881
+ world_mesh=world_mesh,
882
+ )
883
+
884
+ if torch.distributed.get_rank() == 0:
885
+ logger.info("Sleeping 2 seconds for other ranks to complete")
886
+ time.sleep(2)
887
+
888
+ metric_logger.close()
889
+ logger.info("Training completed")
890
+
891
+
892
+ if __name__ == "__main__":
893
+ init_logger()
894
+ config = JobConfig()
895
+ config.parse_args()
896
+ main(config)
897
+ torch.distributed.destroy_process_group()
flame/utils/convert_dcp_to_hf.py ADDED
@@ -0,0 +1,66 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import argparse
5
+ import io
6
+ import os
7
+ import tempfile
8
+ from datetime import timedelta
9
+
10
+ import torch
11
+ import torch.serialization
12
+ from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
13
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
14
+
15
+ import fla # noqa
16
+ from torchtitan.tools.logging import init_logger, logger
17
+
18
+
19
+ @torch.inference_mode()
20
+ def save_pretrained(
21
+ path: str,
22
+ step: int,
23
+ config: str,
24
+ tokenizer: str
25
+ ):
26
+ logger.info(f"Loading the config from {config}")
27
+ config = AutoConfig.from_pretrained(config, trust_remote_code=True)
28
+
29
+ logger.info(f"Saving the config to {path}")
30
+ config.save_pretrained(path)
31
+ logger.info(f"Loading the tokenizer from {tokenizer}")
32
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
33
+ logger.info(f"Saving the tokenizer to {path}")
34
+ tokenizer.save_pretrained(path)
35
+
36
+ with tempfile.TemporaryDirectory() as tmpdir:
37
+ # base_checkpoint_dir = os.path.dirname(path)
38
+ base_checkpoint_dir = path
39
+ checkpoint = os.path.join(base_checkpoint_dir, f'checkpoint/step-{step}')
40
+ checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
41
+ logger.info(f"Saving the distributed checkpoint to {checkpoint_path}")
42
+ dcp_to_torch_save(checkpoint, checkpoint_path)
43
+
44
+ logger.info(f"Initializing the model from config\n{config}")
45
+ model = AutoModelForCausalLM.from_config(config)
46
+ logger.info(model)
47
+ logger.info("Loading state dict from the checkpoint")
48
+
49
+ # Add datetime.timedelta and io.BytesIO to safe globals
50
+ torch.serialization.add_safe_globals([timedelta, io.BytesIO])
51
+ # torch.load now with default weights_only=True will work
52
+ model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])
53
+
54
+ logger.info(f"Saving the model to {path}")
55
+ model.save_pretrained(path)
56
+
57
+
58
+ if __name__ == "__main__":
59
+ init_logger()
60
+ parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
61
+ parser.add_argument("--path", type=str, required=True)
62
+ parser.add_argument("--step", type=int, required=True)
63
+ parser.add_argument("--config", type=str, required=True)
64
+ parser.add_argument("--tokenizer", type=str, required=True)
65
+ args = parser.parse_args()
66
+ save_pretrained(args.path, args.step, args.config, args.tokenizer)
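A hypothetical invocation of the converter (the output path and tokenizer are placeholders, not from this commit); `path` must contain the DCP shards under `checkpoint/step-{step}`:

save_pretrained(
    path="exp/delta_net_340M",               # expects exp/delta_net_340M/checkpoint/step-1000
    step=1000,
    config="configs/delta_net_340M.json",    # config shipped with this commit
    tokenizer="<tokenizer-name-or-path>",
)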
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.51.3"
6
+ }