Add files using upload-large-folder tool
- configs/delta_net_340M.json +27 -0
- fla/layers/__init__.py +44 -0
- fla/layers/__pycache__/bitattn.cpython-311.pyc +0 -0
- fla/layers/__pycache__/gated_deltanet.cpython-311.pyc +0 -0
- fla/layers/__pycache__/hgrn2.cpython-311.pyc +0 -0
- fla/layers/__pycache__/lightnet.cpython-311.pyc +0 -0
- fla/layers/__pycache__/rebased.cpython-311.pyc +0 -0
- fla/layers/__pycache__/rwkv7.cpython-311.pyc +0 -0
- fla/layers/abc.py +218 -0
- fla/layers/delta_net.py +291 -0
- fla/layers/gated_deltanet.py +293 -0
- fla/layers/gated_deltaproduct.py +351 -0
- fla/layers/gla.py +294 -0
- fla/layers/gsa.py +227 -0
- fla/layers/hgrn.py +168 -0
- fla/layers/rebased.py +133 -0
- fla/layers/rwkv6.py +307 -0
- fla/layers/simple_gla.py +261 -0
- fla/models/abc/__init__.py +13 -0
- fla/models/abc/__pycache__/__init__.cpython-311.pyc +0 -0
- fla/models/abc/modeling_abc.py +418 -0
- fla/models/bitnet/__init__.py +13 -0
- fla/models/bitnet/__pycache__/__init__.cpython-311.pyc +0 -0
- fla/models/bitnet/__pycache__/configuration_bitnet.cpython-311.pyc +0 -0
- fla/models/bitnet/__pycache__/modeling_bitnet.cpython-311.pyc +0 -0
- fla/models/bitnet/configuration_bitnet.py +67 -0
- fla/models/delta_net/__init__.py +12 -0
- fla/models/forgetting_transformer/__pycache__/__init__.cpython-311.pyc +0 -0
- fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-311.pyc +0 -0
- fla/models/forgetting_transformer/configuration_forgetting_transformer.py +68 -0
- fla/models/gated_deltanet/__pycache__/__init__.cpython-311.pyc +0 -0
- fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-311.pyc +0 -0
- fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py +90 -0
- fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py +520 -0
- fla/models/gla/__init__.py +13 -0
- fla/models/gla/__pycache__/__init__.cpython-311.pyc +0 -0
- fla/models/gsa/__pycache__/__init__.cpython-311.pyc +0 -0
- fla/models/gsa/__pycache__/configuration_gsa.cpython-311.pyc +0 -0
- fla/models/gsa/configuration_gsa.py +97 -0
- fla/models/gsa/modeling_gsa.py +420 -0
- fla/models/hgrn/modeling_hgrn.py +420 -0
- fla/models/linear_attn/configuration_linear_attn.py +91 -0
- fla/models/nsa/modeling_nsa.py +398 -0
- flame/__init__.py +1 -0
- flame/__pycache__/__init__.cpython-311.pyc +0 -0
- flame/__pycache__/train.cpython-311.pyc +0 -0
- flame/models/parallelize_fla.py +550 -0
- flame/train.py +897 -0
- flame/utils/convert_dcp_to_hf.py +66 -0
- generation_config.json +6 -0
configs/delta_net_340M.json
ADDED
@@ -0,0 +1,27 @@
{
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "conv_size": 4,
  "eos_token_id": 2,
  "expand_k": 1,
  "expand_v": 1,
  "fuse_cross_entropy": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "model_type": "delta_net",
  "norm_eps": 1e-06,
  "norm_first": false,
  "num_heads": 8,
  "num_hidden_layers": 24,
  "qk_activation": "silu",
  "qk_norm": "l2",
  "tie_word_embeddings": false,
  "use_beta": true,
  "use_cache": true,
  "use_gate": false,
  "use_output_norm": true,
  "use_short_conv": true
}
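For reference, a minimal sketch of wiring this config into the `DeltaNet` layer defined in `fla/layers/delta_net.py` below. Mapping the model-level config keys onto a single attention layer is illustrative only (the full 340M model is assembled by the `fla/models/delta_net` package added in this commit), and the forward pass assumes a CUDA device because the `fla.ops` kernels are Triton-based.

import json

import torch
from fla.layers import DeltaNet

with open("configs/delta_net_340M.json") as f:
    cfg = json.load(f)

# Build one attention layer from the model-level hyperparameters (illustrative only).
layer = DeltaNet(
    mode=cfg["attn_mode"],
    hidden_size=cfg["hidden_size"],
    expand_k=cfg["expand_k"],
    expand_v=cfg["expand_v"],
    num_heads=cfg["num_heads"],
    use_beta=cfg["use_beta"],
    use_gate=cfg["use_gate"],
    use_short_conv=cfg["use_short_conv"],
    conv_size=cfg["conv_size"],
    qk_activation=cfg["qk_activation"],
    qk_norm=cfg["qk_norm"],
    norm_eps=cfg["norm_eps"],
    layer_idx=0,
).cuda().bfloat16()  # fla's Triton kernels expect CUDA tensors

x = torch.randn(2, 128, cfg["hidden_size"], device="cuda", dtype=torch.bfloat16)
o, _, _ = layer(x)  # output has the same (batch, seq_len, hidden_size) shape
print(o.shape)      # torch.Size([2, 128, 1024])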
fla/layers/__init__.py
ADDED
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from .abc import ABCAttention
from .attn import Attention
from .based import BasedLinearAttention
from .bitattn import BitAttention
from .delta_net import DeltaNet
from .forgetting_attn import ForgettingAttention
from .gated_deltanet import GatedDeltaNet
from .gated_deltaproduct import GatedDeltaProduct
from .gla import GatedLinearAttention
from .gsa import GatedSlotAttention
from .hgrn import HGRNAttention
from .hgrn2 import HGRN2Attention
from .lightnet import LightNetAttention
from .linear_attn import LinearAttention
from .multiscale_retention import MultiScaleRetention
from .nsa import NativeSparseAttention
from .rebased import ReBasedLinearAttention
from .rwkv6 import RWKV6Attention
from .rwkv7 import RWKV7Attention

__all__ = [
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
    'BitAttention',
    'DeltaNet',
    'ForgettingAttention',
    'GatedDeltaNet',
    'GatedDeltaProduct',
    'GatedLinearAttention',
    'GatedSlotAttention',
    'HGRNAttention',
    'HGRN2Attention',
    'LightNetAttention',
    'LinearAttention',
    'MultiScaleRetention',
    'NativeSparseAttention',
    'ReBasedLinearAttention',
    'RWKV6Attention',
    'RWKV7Attention',
]
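As a quick sanity check (a sketch, assuming `fla` and its Triton dependency are importable in the environment), every exported name should resolve to an `nn.Module` subclass:

import importlib

import torch.nn as nn

layers = importlib.import_module("fla.layers")
for name in layers.__all__:
    cls = getattr(layers, name)
    # each exported symbol is expected to be an attention layer class
    assert isinstance(cls, type) and issubclass(cls, nn.Module), name
print(f"{len(layers.__all__)} layer classes exported")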
fla/layers/__pycache__/bitattn.cpython-311.pyc
ADDED
Binary file (9.62 kB).

fla/layers/__pycache__/gated_deltanet.cpython-311.pyc
ADDED
Binary file (13.9 kB).

fla/layers/__pycache__/hgrn2.cpython-311.pyc
ADDED
Binary file (9.09 kB).

fla/layers/__pycache__/lightnet.cpython-311.pyc
ADDED
Binary file (9.33 kB).

fla/layers/__pycache__/rebased.cpython-311.pyc
ADDED
Binary file (7.17 kB).

fla/layers/__pycache__/rwkv7.cpython-311.pyc
ADDED
Binary file (11 kB).
fla/layers/abc.py
ADDED
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange

from fla.modules import FusedRMSNormGated, RMSNorm, RotaryEmbedding, ShortConvolution
from fla.modules.activations import swiglu, swish
from fla.ops.abc.chunk import chunk_abc

if TYPE_CHECKING:
    from fla.models.utils import Cache


class ABCAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        num_slots: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_low_rank_dim: int = 16,
        gate_logit_normalizer: int = 16,
        use_rope: bool = True,
        use_input_gate: bool = False,
        use_output_gate: bool = True,
        use_norm: bool = True,
        clamp_min: Optional[float] = -32,
        clamp_max: Optional[float] = 32,
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> ABCAttention:
        super().__init__()

        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.key_dim = int(self.hidden_size * self.expand_k)
        self.value_dim = int(self.hidden_size * self.expand_v)
        self.head_k_dim = self.key_dim // self.num_heads
        self.head_v_dim = self.value_dim // self.num_heads

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.gate_low_rank_dim = gate_low_rank_dim
        self.gate_logit_normalizer = gate_logit_normalizer

        self.use_rope = use_rope
        self.use_input_gate = use_input_gate
        self.use_output_gate = use_output_gate
        self.use_norm = use_norm

        if num_slots is None:
            num_slots = self.head_k_dim
        self.num_slots = num_slots

        self.norm_eps = norm_eps

        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.layer_idx = layer_idx

        if layer_idx is None:
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)

        if use_output_gate:
            self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False)
        self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu')

        if self.use_norm:
            if self.use_output_gate:
                self.g_norm = FusedRMSNormGated(
                    hidden_size=self.head_v_dim,
                    elementwise_affine=elementwise_affine,
                    eps=norm_eps
                )
            else:
                self.g_norm = RMSNorm(
                    hidden_size=self.head_v_dim,
                    elementwise_affine=elementwise_affine,
                    eps=norm_eps
                )

        if self.use_rope:
            self.rotary = RotaryEmbedding(self.head_k_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if cu_seqlens is not None:
            raise NotImplementedError("Training with cu_seqlens is not supported yet for ABCAttention")
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)

        if self.use_input_gate:
            q, k, v = map(lambda x: swish(x), (q, k, v))
        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])

        q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
        if self.use_rope:
            seqlen_offset = 0
            if past_key_values is not None:
                seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            q, k = self.rotary(q, k, seqlen_offset=seqlen_offset)

        s = rearrange(self.s_proj(hidden_states), '... (h m) -> ... h m', m=self.num_slots)
        s = s.clamp_(self.clamp_min, self.clamp_max)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        o, recurrent_state = chunk_abc(
            q=q,
            k=k,
            v=v,
            s=s,
            initial_state=recurrent_state,
            output_final_state=use_cache,
            head_first=False
        )
        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_norm and not self.use_output_gate:
            o = self.g_norm(o)
        elif self.use_output_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
            o = self.g_norm(o, g) if self.use_norm else swiglu(g, o)
        o = rearrange(o, '... h d -> ... (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, seq_len: int = 2048):
        return 2 * self.num_slots * self.hidden_size
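A small usage note (a sketch; constructing the layer only builds projections and norms, while the `chunk_abc` kernel itself needs a CUDA device): `state_size` reports a sequence-length-independent recurrent state of `2 * num_slots * hidden_size` elements, i.e. the key and value slot memories.

from fla.layers import ABCAttention

layer = ABCAttention(hidden_size=1024, num_heads=4, num_slots=64, layer_idx=0)
# 2 * num_slots * hidden_size = 2 * 64 * 1024
print(layer.state_size())  # 131072, regardless of seq_len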
fla/layers/delta_net.py
ADDED
@@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F

from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


def elu_p1(x):
    return (F.elu(x, 1., False) + 1.).to(x)


def sum_norm(x):
    return (x / x.sum(-1, keepdim=True)).to(x)


class DeltaNet(nn.Module):
    r"""
    The layer implementation for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484).  # noqa
    DeltaNet was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174).  # noqa

    Args:
        mode (str, Optional):
            Which DeltaNet kernel to use.
            Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 1.0.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        use_beta (bool, Optional):
            Whether to use beta. Default: `True`.
        use_gate (bool, Optional):
            Whether to use output gate. Default: `False`.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        allow_neg_eigval (bool, Optional):
            Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2.
            See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537)
        layer_idx (int, Optional):
            The index of the layer. Default: None.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        qk_activation (str, Optional):
            The activation function for the query and key. Default: `silu`.
        qk_norm (str, Optional):
            The normalization method for the query and key. Default: `l2`.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        d_model: int = None,
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 4,
        use_beta: bool = True,
        use_gate: bool = False,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        allow_neg_eigval: bool = False,
        layer_idx: int = None,
        qk_activation: str = 'silu',
        qk_norm: str = 'l2',
        norm_eps: float = 1e-5,
        **kwargs
    ) -> DeltaNet:
        super().__init__()

        self.mode = mode
        self.qk_activation = qk_activation
        self.qk_norm = qk_norm

        assert self.qk_activation in ['silu', 'relu', 'elu', 'identity']
        assert self.qk_norm in ['l2', 'sum']

        if d_model is not None:
            hidden_size = d_model
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.allow_neg_eigval = allow_neg_eigval

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.layer_idx = layer_idx

        self.silu = nn.SiLU()
        if mode == 'fused_chunk':
            raise NotImplementedError("fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead.")
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        self.use_beta = use_beta
        if self.use_beta:
            self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu' if qk_activation == 'silu' else None
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu' if qk_activation == 'silu' else None
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim,
                kernel_size=conv_size,
                activation='silu'
            )
        else:
            raise UserWarning(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # change to inference mode.
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            if self.qk_activation == 'silu':
                q, k = self.silu(q), self.silu(k)
            v = self.silu(self.v_proj(hidden_states))

        q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)
        if self.qk_activation != 'silu':
            if self.qk_activation == 'relu':
                q, k = q.relu(), k.relu()
            elif self.qk_activation == 'elu':
                q, k = elu_p1(q), elu_p1(k)
            elif self.qk_activation == 'identity':
                pass
            else:
                raise NotImplementedError

        if self.qk_norm == 'sum':
            q = sum_norm(q).to(q)
            k = sum_norm(k).to(k)

        if self.use_beta:
            beta = self.b_proj(hidden_states).sigmoid()
        else:
            beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2])

        if self.allow_neg_eigval:
            beta = beta * 2.

        # dealing with padding
        if attention_mask is not None:
            beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_delta_rule(
                q=q,
                k=k,
                v=v,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_delta_rule(
                q=q,
                k=k,
                v=v,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values
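To make explicit what `fused_recurrent_delta_rule` / `chunk_delta_rule` compute, here is a minimal single-head reference of the delta-rule recurrence from the papers cited in the docstring (a naive sketch, not the fused Triton kernel; it assumes `qk_norm='l2'`, i.e. L2-normalized queries and keys):

import torch
import torch.nn.functional as F

def delta_rule_reference(q, k, v, beta):
    """Naive per-token delta rule: S_t = S_{t-1}(I - beta_t k_t k_t^T) + beta_t v_t k_t^T, o_t = S_t q_t."""
    # q, k: (T, d_k); v: (T, d_v); beta: (T,)
    S = k.new_zeros(v.shape[-1], k.shape[-1])
    outputs = []
    for q_t, k_t, v_t, b_t in zip(q, k, v, beta):
        q_t, k_t = F.normalize(q_t, dim=-1), F.normalize(k_t, dim=-1)
        # erase the value currently stored under k_t and write the new one, both scaled by beta
        S = S - b_t * torch.outer(S @ k_t - v_t, k_t)
        outputs.append(S @ q_t)
    return torch.stack(outputs)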
fla/layers/gated_deltanet.py
ADDED
@@ -0,0 +1,293 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import math
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F

from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


@torch.compile
def elu_p1(x):
    return (F.elu(x, 1., False) + 1.).to(x)


@torch.compile
def sum_norm(x):
    return (x / x.sum(-1, keepdim=True)).to(x)


class GatedDeltaNet(nn.Module):
    """
    The layer implementation for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464).  # noqa

    Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.

    Parameter allocation when use_gate=True:
        - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
        - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
        - Others are ignorably small.
        - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
    NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim.

    Parameter allocation when use_gate=False:
        - 1 * hidden_size * hidden_size for the q_proj and k_proj each
        - 2 * hidden_size * hidden_size for the v_proj and o_proj each
        - Others are ignorably small.
        - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size

    Args:
        hidden_size (int, Optional):
            The hidden size of the input. Default: 2048.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 2.0.
        head_dim (int, Optional):
            The dimension of each head. Default: 256.
        num_heads (int, Optional):
            The number of heads. Default: 6.
        mode (str, Optional):
            Which Gated DeltaNet kernel to use.
            Currently available: `chunk` and `fused_recurrent`.
            Default: `chunk`.
        use_beta (bool, Optional):
            Whether to use beta. Default: `True`.
        use_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
        norm_eps (float, Optional):
            The epsilon value for the normalization layer. Default: 1e-5.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 2,
        head_dim: int = 256,
        num_heads: int = 6,
        mode: str = 'chunk',
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: int = None,
        norm_eps: float = 1e-5,
        **kwargs
    ) -> GatedDeltaNet:
        super().__init__()

        self.mode = mode

        self.hidden_size = hidden_size
        self.expand_v = expand_v

        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.head_dim = head_dim
        self.num_heads = num_heads

        self.key_dim = int(self.num_heads * self.head_dim)
        self.value_dim = int(self.key_dim * self.expand_v)
        self.head_k_dim = head_dim
        self.head_v_dim = int(head_dim * self.expand_v)
        self.layer_idx = layer_idx

        # Consistency check: Ensure expand_v produces integer values
        if not math.isclose(self.key_dim * expand_v, self.value_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
                f"Resulting value_dim would be {self.key_dim * expand_v}, which is invalid for nn.Linear."
            )
        if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
                f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated."
            )
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)

        A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        # hard coded for now
        dt_min = 0.001
        dt_max = 0.1
        dt_init_floor = 1e-4
        dt = torch.exp(
            torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu'
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu'
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim,
                kernel_size=conv_size,
                activation='silu'
            )
        else:
            raise UserWarning(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
        if self.training:
            assert mode == 'chunk', "Only chunk mode is supported in training."

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = F.silu(self.q_proj(hidden_states))
            k = F.silu(self.k_proj(hidden_states))
            v = F.silu(self.v_proj(hidden_states))

        q, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (q, k))
        v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        beta = self.b_proj(hidden_states).sigmoid()
        g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)

        # dealing with padding
        if attention_mask is not None:
            beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
            g = g.mul(attention_mask[:, -g.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            o, recurrent_state = chunk_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True
            )
        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim)
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values
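A quick check of the parameter-allocation note in the docstring (a sketch; with `hidden_size=2048`, `head_dim=256` and `num_heads=6`, `num_heads * head_dim = 1536 = 0.75 * hidden_size`, and `expand_v=2` gives `value_dim = 3072 = 1.5 * hidden_size`):

from fla.layers import GatedDeltaNet

layer = GatedDeltaNet(hidden_size=2048, head_dim=256, num_heads=6, expand_v=2, layer_idx=0)
n_params = sum(p.numel() for p in layer.parameters())
# q/k: 0.75x each, v/g/o: 1.5x each -> ~6x hidden_size^2, plus small conv/norm/scalar extras
print(round(n_params / 2048 ** 2, 2))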
fla/layers/gated_deltaproduct.py
ADDED
@@ -0,0 +1,351 @@
from __future__ import annotations

import math
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
from fla.ops.delta_rule import chunk_delta_rule
from fla.ops.gated_delta_rule import chunk_gated_delta_rule

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


def elu_p1(x):
    return (F.elu(x, 1.0, False) + 1.0).to(x)


def sum_norm(x):
    return (x / x.sum(-1, keepdim=True)).to(x)


def interleave_multiple_sequences(*sequences):
    """
    Interleave multiple sequences together.
    For example, with sequences [A1, A2], [B1, B2], [C1, C2],
    returns [A1, B1, C1, A2, B2, C2]
    """
    if isinstance(sequences[0], (list, tuple)):
        sequences = sequences[0]

    if len(sequences) == 1:
        return sequences[0]

    # All sequences should have the same shape
    assert all(s.shape == sequences[0].shape for s in sequences)

    # Get the original shape
    batch_size, seq_len, *rest = sequences[0].shape

    # Stack sequences along a new dimension
    stacked = torch.stack(sequences, dim=2)

    # Reshape to interleave
    reshaped = stacked.view(batch_size, seq_len * len(sequences), *rest)

    return reshaped


class GatedDeltaProduct(nn.Module):
    """
    Generalized version of GatedDoubleDeltaNet that supports an arbitrary number of Householder transformations.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 2,
        head_dim: int = 256,
        num_heads: int = 6,
        num_householder: int = 2,  # New parameter for number of householder transformations
        mode: str = "chunk",
        use_gate: bool = True,
        use_forget_gate: bool = True,  # when true Gated DeltaProduct, when false DeltaProduct
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: int | None = None,
        norm_eps: float = 1e-5,
        allow_neg_eigval: bool = False,  # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1]
        **kwargs,
    ) -> None:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_householder = num_householder
        self.allow_neg_eigval = allow_neg_eigval
        self.use_forget_gate = use_forget_gate
        self.key_dim = self.num_heads * self.head_dim
        self.value_dim = int(self.key_dim * self.expand_v)
        self.head_qk_dim = head_dim
        self.head_v_dim = int(head_dim * self.expand_v)
        self.layer_idx = layer_idx
        self.silu = nn.SiLU()
        assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."
        # Create multiple projection layers for each householder transformation
        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)

        self.k_projs = nn.ModuleList(
            [
                nn.Linear(hidden_size, self.key_dim, bias=False)
                for _ in range(num_householder)
            ]
        )
        self.v_projs = nn.ModuleList(
            [
                nn.Linear(hidden_size, self.value_dim, bias=False)
                for _ in range(num_householder)
            ]
        )
        self.b_projs = nn.ModuleList(
            [
                nn.Linear(hidden_size, self.num_heads, bias=False)
                for _ in range(num_householder)
            ]
        )
        if use_short_conv:
            self.q_conv1ds = nn.ModuleList(
                [
                    ShortConvolution(
                        hidden_size=self.key_dim,
                        kernel_size=conv_size,
                        activation="silu",
                    )
                    for _ in range(num_householder)
                ]
            )
            self.k_conv1ds = nn.ModuleList(
                [
                    ShortConvolution(
                        hidden_size=self.key_dim,
                        kernel_size=conv_size,
                        activation="silu",
                    )
                    for _ in range(num_householder)
                ]
            )
            self.v_conv1ds = nn.ModuleList(
                [
                    ShortConvolution(
                        hidden_size=self.value_dim,
                        kernel_size=conv_size,
                        activation="silu",
                    )
                    for _ in range(num_householder)
                ]
            )

        if self.use_forget_gate:
            self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
            A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
            A_log = torch.log(A)
            self.A_log = nn.Parameter(A_log)
            self.A_log._no_weight_decay = True

            # Initialize dt parameters
            dt_min = 0.001
            dt_max = 0.1
            dt_init_floor = 1e-4
            dt = torch.exp(
                torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            )
            dt = torch.clamp(dt, min=dt_init_floor)
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias = nn.Parameter(inv_dt)
            self.dt_bias._no_weight_decay = True

        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
        self.k_id = torch.nn.Identity()
        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding)."
            )

        mode = (
            "chunk"  # 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
        )
        if self.training:
            assert mode == "chunk", "Only chunk mode is supported in training."

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        # Process each householder transformation
        ks, vs, betas = [], [], []
        conv_states = []

        for i in range(self.num_householder):
            if self.use_short_conv:
                conv_state_q, conv_state_k, conv_state_v = None, None, None
                if last_state is not None:
                    conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"][
                        i
                    ]
                conv_mask = (
                    attention_mask[:, -hidden_states.shape[1]:]
                    if attention_mask is not None
                    else None
                )

                k, conv_state_k = self.k_conv1ds[i](
                    x=self.k_projs[i](hidden_states),
                    mask=conv_mask,
                    cache=conv_state_k,
                    output_final_state=use_cache,
                )
                v, conv_state_v = self.v_conv1ds[i](
                    x=self.v_projs[i](hidden_states),
                    mask=conv_mask,
                    cache=conv_state_v,
                    output_final_state=use_cache,
                )
                conv_states.append((conv_state_q, conv_state_k, conv_state_v))
            else:
                k = self.silu(self.k_projs[i](hidden_states))
                v = self.silu(self.v_projs[i](hidden_states))

            ks.append(k)
            vs.append(v)

            beta = self.b_projs[i](
                hidden_states
            ).sigmoid()  # bs, sequence_length, num_heads
            if attention_mask is not None:
                beta = beta.mul(attention_mask[:, -hidden_states.shape[1]:, None])
            if self.allow_neg_eigval:
                beta = beta * 2
            betas.append(beta)

        if self.use_short_conv:
            q, conv_state_q = self.q_conv1ds[0](
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
            )
        else:
            q = self.silu(self.q_proj(hidden_states))
        q = interleave_multiple_sequences(
            [torch.zeros_like(q)] * (self.num_householder - 1) + [q]
        )
        # Interleave all sequences
        k = interleave_multiple_sequences(ks)
        v = interleave_multiple_sequences(vs)
        beta = interleave_multiple_sequences(betas)

        q, k, v = (
            rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in (q, k, v)
        )

        recurrent_state = (
            last_state["recurrent_state"] if last_state is not None else None
        )
        offsets = kwargs.get("offsets")

        if mode == "chunk":
            if self.use_forget_gate:
                g = -self.A_log.float().exp() * F.softplus(
                    self.a_proj(hidden_states).float() + self.dt_bias
                )
                if attention_mask is not None:
                    g = g.mul(attention_mask[:, -g.shape[-2]:, None])

                # Interleave g with zeros for non-first transformations
                g = interleave_multiple_sequences(
                    [g] + [torch.zeros_like(g)] * (self.num_householder - 1)
                )

                o, recurrent_state = chunk_gated_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    g=g,
                    beta=beta,
                    initial_state=recurrent_state,
                    output_final_state=use_cache,
                    cu_seqlens=offsets,
                    head_first=False,
                    use_qk_l2norm_in_kernel=True
                )
            else:
                o, recurrent_state = chunk_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    beta=beta,
                    initial_state=recurrent_state,
                    output_final_state=use_cache,
                    cu_seqlens=offsets,
                    head_first=False,
                    use_qk_l2norm_in_kernel=True
                )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        # Take every nth element for n householder transformations
        o = o[:, self.num_householder - 1:: self.num_householder, :]

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=conv_states if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[2],
            )

        if self.use_gate:
            g = rearrange(
                self.g_proj(hidden_states),
                "... (h d) -> ... h d",
                h=self.num_heads,
            )
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, "b t h d -> b t (h d)")
        o = self.o_proj(o)

        return o, None, past_key_values
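The interleaving above is the core of DeltaProduct: per original time step it feeds `num_householder` key/value/beta triples into the delta-rule kernel and keeps only every `num_householder`-th output. A tiny sketch of `interleave_multiple_sequences` on toy tensors (importing the module assumes the Triton-backed `fla.ops` dependencies are installed):

import torch
from fla.layers.gated_deltaproduct import interleave_multiple_sequences

a = torch.tensor([[[1.], [2.]]])    # (batch=1, seq_len=2, dim=1)
b = torch.tensor([[[10.], [20.]]])
# time steps are merged pairwise: A1, B1, A2, B2
print(interleave_multiple_sequences(a, b).squeeze(-1))  # tensor([[ 1., 10.,  2., 20.]])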
fla/layers/gla.py
ADDED
@@ -0,0 +1,294 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.modules.activations import ACT2FN
from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


class GatedLinearAttention(nn.Module):
    r"""
    The layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635).  # noqa

    Args:
        mode (str, Optional):
            Which GLA kernel to use.
            Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`.
            Default: `chunk`.
        hidden_size (int, Optional):
            The hidden size of the input. Default: 1024.
        expand_k (float, Optional):
            The expansion ratio for the key dim. Default: 0.5.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.0.
        num_heads (int, Optional):
            The number of heads. Default: 4.
        num_kv_heads (int, Optional):
            The number of key/value heads, used for MQA. Default: None.
        feature_map (str, Optional):
            Feature map function applied to queries/keys. Default: None.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `False`.
        conv_size (int, Optional):
            The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
        conv_bias (bool, Optional):
            Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
        use_output_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        gate_fn (str, Optional):
            The activation function for the output gate. Default: `swish`.
        elementwise_affine (bool, Optional):
            If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
        norm_eps (float, Optional):
            The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
        gate_logit_normalizer (int, Optional):
            The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
        gate_low_rank_dim (int, Optional):
            The low rank dim for the gate projection. Default: 16.
        clamp_min (float, Optional):
            The minimum value for the gate logits. Default: None.
        fuse_norm (bool, Optional):
            Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
        layer_idx (int, Optional):
            The index of the layer. Default: None.
    """

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        feature_map: Optional[str] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        use_output_gate: bool = True,
        gate_fn: str = 'swish',
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 16,
        gate_low_rank_dim: int = 16,
        clamp_min: Optional[float] = None,
        fuse_norm: bool = True,
        layer_idx: int = None,
    ) -> GatedLinearAttention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.use_output_gate = use_output_gate

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.clamp_min = clamp_min
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
        if self.use_output_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        self.gk_proj = nn.Sequential(nn.Linear(hidden_size, gate_low_rank_dim, bias=False),
                                     nn.Linear(gate_low_rank_dim, self.key_dim_per_group, bias=True))
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        if gate_fn == 'swish' and fuse_norm and use_output_gate:
            self.g_norm_swish_gate = FusedRMSNormGated(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.fuse_norm_and_gate = True
        else:
            self.fuse_norm_and_gate = False
            self.g_norm = RMSNorm(
                hidden_size=self.head_v_dim,
                elementwise_affine=elementwise_affine,
                eps=norm_eps
            )
            self.gate_fn = ACT2FN[gate_fn]

        self.gate_logit_normalizer = gate_logit_normalizer

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
        gk = self.gk_proj(hidden_states)

        if self.feature_map_fn is not None:
            q, k = map(self.feature_map_fn, (q, k))
        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
        q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
        if self.num_kv_groups > 1:
            k, gk = (repeat(x, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_k_dim) for x in (k, gk))
            v = repeat(v, 'b t (h d) -> b t (h g) d', g=self.num_kv_groups, d=self.head_v_dim)
        else:
            k, gk = (rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim) for x in (k, gk))
            v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        gk = F.logsigmoid(gk) / self.gate_logit_normalizer

        if self.clamp_min is not None:
            gk = torch.clamp_min(gk, self.clamp_min)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gla(
                q=q,
                k=k,
                v=v,
                gk=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'fused_chunk':
            o, recurrent_state = fused_chunk_gla(
                q=q,
                k=k,
                v=v,
                g=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_gla(
                q=q,
                k=k,
                v=v,
                g=gk,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        if self.use_output_gate:
            g = self.g_proj(hidden_states)
            if self.fuse_norm_and_gate:
                g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim)
                o = self.g_norm_swish_gate(o, g)
                o = rearrange(o, 'b t h d -> b t (h d)')
            else:
                o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
                o = o * self.gate_fn(g)
        else:
            o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        state_size = self.key_dim * self.head_v_dim
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
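Side note (not part of the uploaded files): a rough usage sketch for the layer above, built only from the constructor and forward signatures shown in this file. The Triton kernels behind `chunk`/`fused_recurrent` generally expect a CUDA device, so treat this as illustrative rather than a tested recipe; the batch and sequence sizes are arbitrary.

import torch
from fla.layers.gla import GatedLinearAttention

layer = GatedLinearAttention(hidden_size=1024, expand_k=0.5, expand_v=1.0, num_heads=4, layer_idx=0)
x = torch.randn(2, 128, 1024)   # [batch, seq_len, hidden_size]
o, _, _ = layer(x)              # forward returns (output, None, past_key_values)
print(o.shape)                  # expected: torch.Size([2, 128, 1024])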
fla/layers/gsa.py
ADDED
@@ -0,0 +1,227 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from fla.modules import RMSNorm, ShortConvolution
from fla.modules.feature_map import ReLUFeatureMap, SwishFeatureMap, T2RFeatureMap
from fla.modules.layernorm import rms_norm_linear
from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


class GatedSlotAttention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.,
        expand_v: float = 1.,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        num_slots: Optional[int] = None,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        gate_logit_normalizer: int = 8,
        feature_map: str = 'swish',
        use_output_gate: bool = False,
        use_norm: bool = True,
        layer_idx: Optional[int] = None,
        scale: Optional[float] = 1.,
        **kwargs
    ) -> GatedSlotAttention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups
        self.head_k_dim = self.key_dim // self.num_heads
        self.head_v_dim = self.value_dim // self.num_heads

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.gate_logit_normalizer = gate_logit_normalizer

        self.use_output_gate = use_output_gate
        self.use_norm = use_norm
        self.scale = scale

        if num_slots is None:
            num_slots = self.head_k_dim
        self.num_slots = num_slots

        self.layer_idx = layer_idx

        if layer_idx is None:
            warnings.warn(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.register_module('feature_map', None)
        if feature_map == 'swish':
            self.feature_map = SwishFeatureMap()
        elif feature_map == 'relu':
            self.feature_map = ReLUFeatureMap()
        elif feature_map == 't2r':
            self.feature_map = T2RFeatureMap(self.head_k_dim, self.head_k_dim)
        else:
            raise NotImplementedError(f"Feature map `{feature_map}` is not supported now.")

        self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False)
        self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
            self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
            self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')

        self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.q_proj(hidden_states)
            k = self.k_proj(hidden_states)
            v = self.v_proj(hidden_states)
        f = self.f_proj(hidden_states)

        q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim)
        k = rearrange(k, 'b t (h d) -> b t h d', d=self.head_k_dim)
        v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        f = rearrange(f, 'b t (h m) -> b t h m', m=self.num_slots)

        if self.feature_map is not None:
            q, k = map(lambda x: self.feature_map(x), (q, k))
        v = F.silu(v)

        f = F.logsigmoid(f) / self.gate_logit_normalizer
        s = (1 - f.exp()).to(f.dtype)
        # dealing with left-padding
        if attention_mask is not None:
            s = s.mul_(attention_mask[:, -s.shape[1]:, None, None])
            v = v.mul_(attention_mask[:, -v.shape[1]:, None, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gsa(
                q=q,
                k=k,
                v=v,
                s=s,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                scale=self.scale,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_gsa(
                q=q,
                k=k,
                v=v,
                s=s,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                scale=self.scale,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )

        o = rearrange(o, 'b t h d -> b t (h d)')
        o = rms_norm_linear(F.silu(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias)
        return o, None, past_key_values

    def state_size(self, *args, **kwargs) -> int:
        return 2 * self.num_slots * self.hidden_size
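Side note (not part of the uploaded files): an illustrative standalone sketch of the slot-gate computation above. `f` becomes a log-space forget gate and `s = 1 - exp(f)` the complementary write strength; shapes here are made up.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, 4, 64)                  # [batch, seq, heads, num_slots], illustrative
gate_logit_normalizer = 8
f = F.logsigmoid(logits) / gate_logit_normalizer   # log forget gate, always negative
s = 1 - f.exp()                                    # write strength in (0, 1)
assert torch.all(s > 0) and torch.all(s < 1)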
fla/layers/hgrn.py
ADDED
@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fla.modules import FusedRMSNormGated, ShortConvolution
from fla.modules.activations import swiglu
from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


class HGRNAttention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None
    ) -> HGRNAttention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_ratio = expand_ratio
        self.input_dim = int(hidden_size * expand_ratio)

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

        self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
        self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
        self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
            self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
            self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)

        self.g_norm = FusedRMSNormGated(
            hidden_size=self.input_dim,
            elementwise_affine=elementwise_affine,
            eps=norm_eps
        )
        self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = None,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_i, conv_state_f = None, None
            if last_state is not None:
                conv_state_i, conv_state_f = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            i, conv_state_i = self.i_conv1d(
                x=self.i_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_i,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            f, conv_state_f = self.f_conv1d(
                x=self.f_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_f,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            i = self.i_proj(hidden_states)
            f = self.f_proj(hidden_states)

        # the lower bound for the first layer is zero
        if lower_bound is None or self.layer_idx == 0:
            i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
        else:
            g = lower_bound + (1 - lower_bound) * f.sigmoid()
            i, f = swiglu(i, 1 - g), g.log()

        # dealing with left-padding
        if attention_mask is not None:
            i = i.mul_(attention_mask[:, -i.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            if cu_seqlens is not None:
                raise NotImplementedError("Chunk mode does not support variable-length sequences.")
            o, recurrent_state = chunk_hgrn(
                x=i,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_hgrn(
                x=i,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=i.shape[2]
            )

        o = self.g_norm(o, self.g_proj(hidden_states))
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        state_size = self.hidden_size
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
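Side note (not part of the uploaded files): an illustrative sketch of the layer-wise lower bound used above. Deeper layers receive a `lower_bound` in [0, 1) that keeps their forget gate `g = lower_bound + (1 - lower_bound) * sigmoid(f)` away from zero; the sizes and the 0.5 value are made up.

import torch

f = torch.randn(2, 16, 512)          # raw forget-gate logits, illustrative shape
lower_bound = torch.full((512,), 0.5)
g = lower_bound + (1 - lower_bound) * f.sigmoid()
assert torch.all(g >= 0.5) and torch.all(g <= 1.0)
i_scale = 1 - g                       # the input branch is scaled by the complement, as in the layer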
fla/layers/rebased.py
ADDED
@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange

from fla.modules.feature_map import RebasedFeatureMap
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
from fla.ops.rebased import parallel_rebased


class ReBasedLinearAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        l_max: int = 2048,
        feature_dim: int = 16,
        num_key_value_heads: int = 16,
        num_heads: int = 16,
        use_gamma: Optional[bool] = True,
        use_beta: Optional[bool] = True,
        normalize: Optional[bool] = True,
        causal: bool = True,
        eps: float = 1e-5,
        mode: str = "parallel",
        layer_idx: Optional[int] = None,
        **kwargs
    ) -> ReBasedLinearAttention:
        super().__init__()
        self.hidden_size = hidden_size
        self.l_max = l_max
        self.mode = mode
        assert self.mode in ["fused_chunk", "parallel", 'chunk']

        self.feature_dim = feature_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_key_value_heads
        self.use_gamma = use_gamma
        self.use_beta = use_beta
        self.normalize = normalize
        self.causal = causal
        self.eps = eps
        self.mode = mode
        self.layer_idx = layer_idx

        self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.dropout = nn.Identity()

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        mode = self.mode
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
        q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
        if mode == "fused_chunk":
            o = fused_chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=True,
                scale=1,
                head_first=False
            )
        elif mode == 'chunk':
            o = chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=True,
                scale=1,
                head_first=False
            )
        elif mode == 'parallel':
            assert q.shape[-1] <= 128
            o = parallel_rebased(
                q=q,
                k=k,
                v=v,
                eps=self.eps,
                use_scale=True,
                use_normalize=True,
                head_first=False
            )
        o = self.o_proj(o)
        o = self.dropout(o)
        return o

    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119
    def forward_reference(
        self,
        hidden_states: torch.Tensor,
        filters: torch.Tensor = None,
        *args,
        **kwargs
    ):
        """
        x (torch.Tensor): tensor of shape (b, d, t)
        y (torch.Tensor): tensor of shape (b, d, t)
        """
        b, t, _ = hidden_states.size()
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        q = q.view(b, t, -1, self.feature_dim).transpose(1, 2)
        k = k.view(b, t, -1, self.feature_dim).transpose(1, 2)
        v = v.view(b, t, -1, self.head_dim).transpose(1, 2)

        # Linear attention
        q, k = self.feature_map(q), self.feature_map(k)
        q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)

        # Compute attention
        if self.causal:
            y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
        else:
            y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
        y = rearrange(y, 'b h t d -> b t (h d)')
        y = self.o_proj(y.to(hidden_states.dtype))
        y = self.dropout(y)
        return y.to(hidden_states.dtype)
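Side note (not part of the uploaded files): a small self-contained check that the cumulative-sum form used in `forward_reference` matches explicit causal linear attention on tiny random tensors. All sizes are illustrative.

import torch

b, h, t, f, d = 1, 1, 5, 3, 4
eps = 1e-6
q = torch.rand(b, h, t, f)
k = torch.rand(b, h, t, f)
v = torch.rand(b, h, t, d)

# cumulative form, mirroring the reference path above
qe, ke, ve = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
y_cum = (qe * (ke * ve).cumsum(2)).sum(-1) / ((qe * ke.cumsum(2)).sum(-1) + eps)

# explicit causal form: y_t = sum_{s<=t} (q_t . k_s) v_s / sum_{s<=t} (q_t . k_s)
scores = torch.einsum('bhtf,bhsf->bhts', q, k)
mask = torch.tril(torch.ones(t, t, dtype=torch.bool))
scores = scores.masked_fill(~mask, 0.)
y_ref = torch.einsum('bhts,bhsd->bhtd', scores, v) / (scores.sum(-1, keepdim=True) + eps)

assert torch.allclose(y_cum, y_ref, atol=1e-5)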
fla/layers/rwkv6.py
ADDED
@@ -0,0 +1,307 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence" [https://arxiv.org/abs/2404.05892]

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange

from fla.modules import GroupNorm
from fla.modules.activations import ACT2FN
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6

if TYPE_CHECKING:
    from fla.models.utils import Cache


class RWKV6Attention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        gate_fn: str = 'swish',
        proj_low_rank_dim: int = 32,
        gate_low_rank_dim: int = 64,
        fuse_norm: bool = True,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None,
        **kwargs
    ) -> RWKV6Attention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.proj_low_rank_dim = proj_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.x_proj = nn.Sequential(
            LerpLinear(hidden_size, proj_low_rank_dim * 5),
            nn.Tanh(),
            nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False)
        )
        self.x_bias = nn.Parameter(torch.zeros(5, hidden_size))

        self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
        self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
        self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
        self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
        self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
        self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_k_dim))

        # TODO: fuse GroupNorm and output gate
        self.g_norm = GroupNorm(self.num_heads, self.value_dim, elementwise_affine=elementwise_affine, bias=True, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
        self.gate_fn = ACT2FN[gate_fn]

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        if isinstance(module, nn.Parameter):
            nn.init.xavier_uniform_(module, gain=2 ** -2.5)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, seq_len, hidden_size = hidden_states.shape
        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        if attention_mask is not None:
            hidden_states = hidden_states.mul_(attention_mask[:, -hidden_states.shape[-2]:, None])
        if hidden_states.shape[1] == 1 and last_state is not None:
            shifted = last_state['conv_state'].unsqueeze(1)
        else:
            shifted = self.time_shift(hidden_states)
            if last_state is not None:
                shifted[:, 0] = last_state['conv_state']

        delta = shifted - hidden_states
        x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
        x = torch.einsum('b t n r, h n r-> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1))

        r, w, k, v, g = x.add_(self.x_bias).unbind(-2)
        r = self.r_proj(hidden_states, r, delta)
        w = self.w_proj(hidden_states, w, delta)
        k = self.k_proj(hidden_states, k, delta)
        v = self.v_proj(hidden_states, v, delta)
        g = self.g_proj(hidden_states, g, delta)

        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
        r, w, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (r, w, k))
        v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        w = -torch.exp(w)
        u = self.bonus

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        cu_seqlens = kwargs.get('cu_seqlens', None)
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_rwkv6(
                r=r,
                k=k,
                v=v,
                w=w,
                u=u,
                scale=1.,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_rwkv6(
                q=r,
                k=k,
                v=v,
                g=w,
                u=u,
                scale=1.,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=hidden_states[:, -1],
                layer_idx=self.layer_idx,
                offset=r.shape[2]
            )

        o = self.g_norm(rearrange(o, '... h d -> ... (h d)')) * self.gate_fn(g)
        o = self.o_proj(o)

        return o, None, past_key_values


class LoRA(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: int,
        bias: Optional[bool] = True,
        activation: Optional[str] = 'tanh'
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim
        self.bias = bias

        if activation is None:
            self.activation = nn.Identity()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Not supported activation `{activation}`.")

        self.lora = nn.Sequential(
            nn.Linear(input_dim, low_rank_dim, bias=False),
            self.activation,
            nn.Linear(low_rank_dim, output_dim, bias=bias)
        )

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}("
        s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
        if not self.bias:
            s += f", bias={self.bias}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lora(x)


class LerpLinear(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: Optional[int] = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        if low_rank_dim is None:
            self.linear = nn.Linear(input_dim, output_dim, bias=False)
        else:
            self.linear = LoRA(input_dim, output_dim, low_rank_dim)
        self.mu = nn.Parameter(torch.zeros(input_dim))

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
        if self.low_rank_dim is not None:
            s += f", low_rank_dim={self.low_rank_dim}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
        if delta is None:
            shifted = self.time_shift(x)
            if len(shifted.shape) == 2:
                shifted = shifted.unsqueeze(1)
            delta = shifted - x
        return self.linear(x + delta * self.mu)


class DDLerpLinear(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: Optional[int] = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        if low_rank_dim is None:
            self.linear = nn.Linear(input_dim, output_dim, bias=False)
        else:
            self.linear = LoRA(input_dim, output_dim, low_rank_dim)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
        if self.low_rank_dim is not None:
            s += f", low_rank_dim={self.low_rank_dim}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
        if delta is None:
            shifted = self.time_shift(x)
            if len(shifted.shape) == 2:
                shifted = shifted.unsqueeze(1)
            delta = shifted - x
        return self.linear(x + delta * mu)
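Side note (not part of the uploaded files): an illustrative sketch of the token-shift interpolation that `LerpLinear`/`DDLerpLinear` build on. Each position mixes its own features with the previous position's before the projection; the tensor sizes are made up.

import torch
import torch.nn as nn

x = torch.randn(2, 6, 16)                 # [batch, seq, hidden], illustrative
time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # shifts the sequence right by one step
shifted = time_shift(x)                   # position t now sees features of t-1 (zeros at t=0)
mu = torch.rand(16)                       # stands in for the learnable interpolation weights
mixed = x + (shifted - x) * mu            # lerp(x_t, x_{t-1}, mu), done before the linear projection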
fla/layers/simple_gla.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
3 |
+
|
4 |
+
from __future__ import annotations
|
5 |
+
|
6 |
+
from typing import TYPE_CHECKING, Optional, Tuple
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from einops import rearrange, repeat
|
12 |
+
|
13 |
+
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
|
14 |
+
from fla.modules.activations import ACT2FN
|
15 |
+
from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla
|
16 |
+
|
17 |
+
if TYPE_CHECKING:
|
18 |
+
from fla.models.utils import Cache
|
19 |
+
|
20 |
+
|
21 |
+
class SimpleGatedLinearAttention(nn.Module):
|
22 |
+
r"""
|
23 |
+
The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
|
24 |
+
This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
mode (str, Optional):
|
28 |
+
Which GLA kernel to use.
|
29 |
+
Currently available: `chunk`.
|
30 |
+
Default: `chunk`.
|
31 |
+
hidden_size (int, Optional):
|
32 |
+
The hidden size of the input. Default: 1024.
|
33 |
+
expand_k (float, Optional):
|
34 |
+
The expansion ratio for the key dim. Default: 1.0.
|
35 |
+
expand_v (float, Optional):
|
36 |
+
The expansion ratio for the value dim. Default: 1.0.
|
37 |
+
num_heads (int, Optional):
|
38 |
+
The number of heads. Default: 4.
|
39 |
+
num_kv_heads (int, Optional):
|
40 |
+
The number of key/value heads, used for MQA. Default: None.
|
41 |
+
feature_map (str, Optional):
|
42 |
+
Feature map function applied to queries/keys. Default: None.
|
43 |
+
use_short_conv (bool, Optional):
|
44 |
+
Whether to use short convolutions. Default: `False`.
|
45 |
+
conv_size (int, Optional):
|
46 |
+
The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
|
47 |
+
conv_bias (bool, Optional):
|
48 |
+
Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
|
49 |
+
gate_fn (str, Optional):
|
50 |
+
The activation function for the output gate. Default: `swish`.
|
51 |
+
elementwise_affine (bool, Optional):
|
52 |
+
If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
|
53 |
+
norm_eps (float, Optional):
|
54 |
+
The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
|
55 |
+
gate_logit_normalizer (int, Optional):
|
56 |
+
The normalizer for the gate logits, appied after `logsigmoid`. Default: 16.
|
57 |
+
fuse_norm (bool, Optional):
|
58 |
+
Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
|
59 |
+
layer_idx (int, Optional):
|
60 |
+
The index of the layer. Default: None.
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
mode: str = 'chunk',
|
66 |
+
hidden_size: int = 1024,
|
67 |
+
expand_k: float = 1.,
|
68 |
+
expand_v: float = 1.,
|
69 |
+
num_heads: int = 4,
|
70 |
+
num_kv_heads: Optional[int] = None,
|
71 |
+
feature_map: Optional[str] = None,
|
72 |
+
use_short_conv: bool = True,
|
73 |
+
conv_size: int = 4,
|
74 |
+
conv_bias: bool = False,
|
75 |
+
gate_fn: str = 'swish',
|
76 |
+
elementwise_affine: Optional[bool] = True,
|
77 |
+
norm_eps: float = 1e-5,
|
78 |
+
gate_logit_normalizer: int = 16,
|
79 |
+
fuse_norm: bool = True,
|
80 |
+
layer_idx: int = None,
|
81 |
+
) -> SimpleGatedLinearAttention:
|
82 |
+
super().__init__()
|
83 |
+
|
84 |
+
self.mode = mode
|
85 |
+
self.hidden_size = hidden_size
|
86 |
+
self.expand_k = expand_k
|
87 |
+
self.expand_v = expand_v
|
88 |
+
self.num_heads = num_heads
|
89 |
+
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
|
90 |
+
self.num_kv_groups = self.num_heads // self.num_kv_heads
|
91 |
+
self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
|
92 |
+
|
93 |
+
self.use_short_conv = use_short_conv
|
94 |
+
self.conv_size = conv_size
|
95 |
+
self.conv_bias = conv_bias
|
96 |
+
|
97 |
+
self.key_dim = int(hidden_size * expand_k)
|
98 |
+
self.value_dim = int(hidden_size * expand_v)
|
99 |
+
self.key_dim_per_group = self.key_dim // self.num_kv_groups
|
100 |
+
self.value_dim_per_group = self.value_dim // self.num_kv_groups
|
101 |
+
self.layer_idx = layer_idx
|
102 |
+
|
103 |
+
assert mode in ['chunk', "fused_recurrent"], f"Not suppoerted mode `{mode}`."
|
104 |
+
assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
|
105 |
+
assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
|
106 |
+
|
107 |
+
self.head_k_dim = self.key_dim // num_heads
|
108 |
+
self.head_v_dim = self.value_dim // num_heads
|
109 |
+
|
110 |
+
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
|
111 |
+
self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
|
112 |
+
self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
|
113 |
+
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
|
114 |
+
|
115 |
+
if use_short_conv:
|
116 |
+
self.conv_size = conv_size
|
117 |
+
self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
|
118 |
+
self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
|
119 |
+
self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
|
120 |
+
|
121 |
+
self.gk_proj = nn.Linear(hidden_size, self.num_heads)
|
122 |
+
|
123 |
+
if gate_fn == 'swish' and fuse_norm:
|
124 |
+
self.g_norm_swish_gate = FusedRMSNormGated(
|
125 |
+
hidden_size=self.head_v_dim,
|
126 |
+
elementwise_affine=elementwise_affine,
|
127 |
+
eps=norm_eps
|
128 |
+
)
|
129 |
+
self.fuse_norm_and_gate = True
|
130 |
+
else:
|
131 |
+
self.fuse_norm_and_gate = False
|
132 |
+
self.g_norm = RMSNorm(
|
133 |
+
hidden_size=self.head_v_dim,
|
134 |
+
elementwise_affine=elementwise_affine,
|
135 |
+
eps=norm_eps
|
136 |
+
)
|
137 |
+
self.gate_fn = ACT2FN[gate_fn]
|
138 |
+
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
|
139 |
+
|
140 |
+
self.gate_logit_normalizer = gate_logit_normalizer
|
141 |
+
|
142 |
+
def forward(
|
143 |
+
self,
|
144 |
+
hidden_states: torch.Tensor,
|
145 |
+
attention_mask: Optional[torch.Tensor] = None,
|
146 |
+
past_key_values: Optional[Cache] = None,
|
147 |
+
use_cache: Optional[bool] = False,
|
148 |
+
output_attentions: Optional[bool] = False,
|
149 |
+
**kwargs
|
150 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
151 |
+
if attention_mask is not None:
|
152 |
+
assert len(attention_mask.shape) == 2, (
|
153 |
+
"Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
|
154 |
+
"for padding purposes (0 indicating padding). "
|
155 |
+
"Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
|
156 |
+
)
|
157 |
+
|
158 |
+
# launching the triton kernel for just one token will actually be slower
|
159 |
+
mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
|
160 |
+
|
161 |
+
last_state = None
|
162 |
+
if past_key_values is not None and len(past_key_values) > self.layer_idx:
|
163 |
+
last_state = past_key_values[self.layer_idx]
|
164 |
+
|
165 |
+
cu_seqlens = kwargs.get('cu_seqlens', None)
|
166 |
+
if self.use_short_conv:
|
167 |
+
conv_state_q, conv_state_k, conv_state_v = None, None, None
|
168 |
+
if last_state is not None:
|
169 |
+
conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
|
170 |
+
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
|
171 |
+
q, conv_state_q = self.q_conv1d(
|
172 |
+
x=self.q_proj(hidden_states),
|
173 |
+
mask=conv_mask,
|
174 |
+
cache=conv_state_q,
|
175 |
+
output_final_state=use_cache,
|
176 |
+
cu_seqlens=cu_seqlens
|
177 |
+
)
|
178 |
+
k, conv_state_k = self.k_conv1d(
|
179 |
+
x=self.k_proj(hidden_states),
|
180 |
+
mask=conv_mask,
|
181 |
+
cache=conv_state_k,
|
182 |
+
output_final_state=use_cache,
|
183 |
+
cu_seqlens=cu_seqlens
|
184 |
+
)
|
185 |
+
v, conv_state_v = self.v_conv1d(
|
186 |
+
x=self.v_proj(hidden_states),
|
187 |
+
mask=conv_mask,
|
188 |
+
cache=conv_state_v,
|
189 |
+
output_final_state=use_cache,
|
190 |
+
cu_seqlens=cu_seqlens
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
q = self.q_proj(hidden_states)
|
194 |
+
k = self.k_proj(hidden_states)
|
195 |
+
v = self.v_proj(hidden_states)
|
196 |
+
gk = self.gk_proj(hidden_states)
|
197 |
+
|
198 |
+
if self.feature_map_fn is not None:
|
199 |
+
q, k = map(self.feature_map_fn, (q, k))
|
200 |
+
# dealing with left-padding
|
201 |
+
if attention_mask is not None:
|
202 |
+
v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
|
203 |
+
q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
|
204 |
+
if self.num_kv_groups > 1:
|
205 |
+
k, v = (repeat(x, '... (h d) -> ... (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
|
206 |
+
else:
|
207 |
+
k, v = (rearrange(x, '... (h d) -> ... h d', h=self.num_kv_heads) for x in (k, v))
|
208 |
+
gk = F.logsigmoid(gk) / self.gate_logit_normalizer
|
209 |
+
|
210 |
+
recurrent_state = last_state['recurrent_state'] if last_state is not None else None
|
211 |
+
if mode == 'chunk':
|
212 |
+
o, recurrent_state = chunk_simple_gla(
|
213 |
+
q=q,
|
214 |
+
k=k,
|
215 |
+
v=v,
|
216 |
+
gk=gk,
|
217 |
+
initial_state=recurrent_state,
|
218 |
+
output_final_state=use_cache,
|
219 |
+
cu_seqlens=cu_seqlens,
|
220 |
+
head_first=False
|
221 |
+
)
|
222 |
+
elif mode == 'fused_recurrent':
|
223 |
+
o, recurrent_state = fused_recurrent_simple_gla(
|
224 |
+
q=q,
|
225 |
+
k=k,
|
226 |
+
v=v,
|
227 |
+
gk=gk,
|
228 |
+
initial_state=recurrent_state,
|
229 |
+
output_final_state=use_cache,
|
230 |
+
cu_seqlens=cu_seqlens,
|
231 |
+
head_first=False
|
232 |
+
)
|
233 |
+
else:
|
234 |
+
raise NotImplementedError(f"Not supported mode `{mode}`.")
|
235 |
+
|
236 |
+
if past_key_values is not None:
|
237 |
+
past_key_values.update(
|
238 |
+
recurrent_state=recurrent_state,
|
239 |
+
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
|
240 |
+
layer_idx=self.layer_idx,
|
241 |
+
offset=q.shape[1]
|
242 |
+
)
|
243 |
+
|
244 |
+
g = self.g_proj(hidden_states)
|
245 |
+
if self.fuse_norm_and_gate:
|
246 |
+
g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads)
|
247 |
+
o = self.g_norm_swish_gate(o, g)
|
248 |
+
o = rearrange(o, 'b t h d -> b t (h d)')
|
249 |
+
else:
|
250 |
+
o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
|
251 |
+
o = o * self.gate_fn(g)
|
252 |
+
o = self.o_proj(o)
|
253 |
+
|
254 |
+
return o, None, past_key_values
|
255 |
+
|
256 |
+
def state_size(self, **kwargs) -> int:
|
257 |
+
state_size = self.key_dim * self.head_v_dim
|
258 |
+
for module in self.children():
|
259 |
+
if isinstance(module, ShortConvolution):
|
260 |
+
state_size += module.state_size
|
261 |
+
return state_size
|
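Note on the kernel dispatch in the forward pass above: inputs of at most 64 tokens (which covers single-token decoding) are routed to the fused recurrent kernel, because launching the chunked Triton kernel for so few tokens costs more than it saves; longer prompts keep the configured chunk mode. A minimal sketch of that dispatch, with the 64-token threshold pulled out as a parameter (the helper name is illustrative, not part of fla):

import torch

def select_kernel_mode(hidden_states: torch.Tensor, default_mode: str = 'chunk', threshold: int = 64) -> str:
    # Short inputs (e.g. one decoding step) go through the fused recurrent kernel;
    # longer prompts use the chunked kernel configured on the layer.
    return 'fused_recurrent' if hidden_states.shape[1] <= threshold else default_mode

assert select_kernel_mode(torch.randn(2, 512, 1024)) == 'chunk'
assert select_kernel_mode(torch.randn(2, 1, 1024)) == 'fused_recurrent'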
fla/models/abc/__init__.py
ADDED
@@ -0,0 +1,13 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
4 |
+
|
5 |
+
from fla.models.abc.configuration_abc import ABCConfig
|
6 |
+
from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel
|
7 |
+
|
8 |
+
AutoConfig.register(ABCConfig.model_type, ABCConfig)
|
9 |
+
AutoModel.register(ABCConfig, ABCModel)
|
10 |
+
AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM)
|
11 |
+
|
12 |
+
|
13 |
+
__all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel']
|
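Since this module registers `ABCConfig`, `ABCModel`, and `ABCForCausalLM` with the transformers Auto* factories at import time, the model can be built through the generic API. A small usage sketch, assuming the `fla` package is importable (standard transformers calls, nothing specific to this repo):

from transformers import AutoConfig, AutoModelForCausalLM

import fla.models.abc  # noqa: F401  (runs the registrations above)

config = AutoConfig.for_model('abc')               # resolves to ABCConfig via its model_type
model = AutoModelForCausalLM.from_config(config)   # randomly initialized ABCForCausalLM
print(type(model).__name__)                        # ABCForCausalLM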
fla/models/abc/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (714 Bytes).
|
fla/models/abc/modeling_abc.py
ADDED
@@ -0,0 +1,418 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import math
|
6 |
+
import warnings
|
7 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
from transformers.generation import GenerationMixin
|
13 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
14 |
+
from transformers.modeling_utils import PreTrainedModel
|
15 |
+
from transformers.utils import logging
|
16 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
17 |
+
|
18 |
+
from fla.layers.abc import ABCAttention
|
19 |
+
from fla.layers.attn import Attention
|
20 |
+
from fla.models.abc.configuration_abc import ABCConfig
|
21 |
+
from fla.models.utils import Cache
|
22 |
+
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
|
23 |
+
from fla.modules import GatedMLP as ABCMLP
|
24 |
+
from fla.modules import RMSNorm
|
25 |
+
|
26 |
+
logger = logging.get_logger(__name__)
|
27 |
+
|
28 |
+
if TYPE_CHECKING:
|
29 |
+
from transformers.processing_utils import Unpack
|
30 |
+
|
31 |
+
|
32 |
+
class ABCBlock(nn.Module):
|
33 |
+
def __init__(self, config: ABCConfig, layer_idx: int):
|
34 |
+
super().__init__()
|
35 |
+
|
36 |
+
self.config = config
|
37 |
+
self.layer_idx = layer_idx
|
38 |
+
|
39 |
+
self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
40 |
+
if config.attn is not None and layer_idx in config.attn['layers']:
|
41 |
+
self.attn = Attention(
|
42 |
+
hidden_size=config.hidden_size,
|
43 |
+
num_heads=config.attn['num_heads'],
|
44 |
+
num_kv_heads=config.attn['num_kv_heads'],
|
45 |
+
qkv_bias=config.attn['qkv_bias'],
|
46 |
+
window_size=config.attn['window_size'],
|
47 |
+
rope_theta=config.attn['rope_theta'],
|
48 |
+
max_position_embeddings=config.max_position_embeddings,
|
49 |
+
layer_idx=layer_idx
|
50 |
+
)
|
51 |
+
else:
|
52 |
+
self.attn = ABCAttention(
|
53 |
+
hidden_size=config.hidden_size,
|
54 |
+
expand_k=config.expand_k,
|
55 |
+
expand_v=config.expand_v,
|
56 |
+
num_heads=config.num_heads,
|
57 |
+
num_slots=config.num_slots,
|
58 |
+
use_short_conv=config.use_short_conv,
|
59 |
+
conv_size=config.conv_size,
|
60 |
+
gate_fn=config.hidden_act,
|
61 |
+
elementwise_affine=config.elementwise_affine,
|
62 |
+
norm_eps=config.norm_eps,
|
63 |
+
use_rope=config.use_rope,
|
64 |
+
clamp_min=config.clamp_min,
|
65 |
+
clamp_max=config.clamp_max,
|
66 |
+
fuse_norm=config.fuse_norm,
|
67 |
+
layer_idx=layer_idx
|
68 |
+
)
|
69 |
+
self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
70 |
+
self.mlp = ABCMLP(
|
71 |
+
hidden_size=config.hidden_size,
|
72 |
+
hidden_ratio=config.hidden_ratio,
|
73 |
+
intermediate_size=config.intermediate_size,
|
74 |
+
hidden_act=config.hidden_act,
|
75 |
+
fuse_swiglu=config.fuse_swiglu
|
76 |
+
)
|
77 |
+
|
78 |
+
def forward(
|
79 |
+
self,
|
80 |
+
hidden_states: torch.Tensor,
|
81 |
+
attention_mask: Optional[torch.Tensor] = None,
|
82 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
83 |
+
use_cache: Optional[bool] = False,
|
84 |
+
output_attentions: Optional[bool] = False,
|
85 |
+
**kwargs: Unpack[Dict]
|
86 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
87 |
+
|
88 |
+
residual = hidden_states
|
89 |
+
|
90 |
+
hidden_states = self.attn_norm(hidden_states)
|
91 |
+
hidden_states, attentions, past_key_values = self.attn(
|
92 |
+
hidden_states=hidden_states,
|
93 |
+
attention_mask=attention_mask,
|
94 |
+
past_key_values=past_key_values,
|
95 |
+
use_cache=use_cache,
|
96 |
+
output_attentions=output_attentions,
|
97 |
+
**kwargs
|
98 |
+
)
|
99 |
+
if self.config.fuse_norm:
|
100 |
+
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
|
101 |
+
else:
|
102 |
+
hidden_states = residual + hidden_states
|
103 |
+
residual = hidden_states
|
104 |
+
hidden_states = self.mlp_norm(hidden_states)
|
105 |
+
hidden_states = self.mlp(hidden_states)
|
106 |
+
hidden_states = residual + hidden_states
|
107 |
+
|
108 |
+
outputs = (hidden_states, attentions, past_key_values)
|
109 |
+
|
110 |
+
return outputs
|
111 |
+
|
112 |
+
|
113 |
+
class ABCPreTrainedModel(PreTrainedModel):
|
114 |
+
|
115 |
+
config_class = ABCConfig
|
116 |
+
base_model_prefix = 'model'
|
117 |
+
supports_gradient_checkpointing = True
|
118 |
+
_no_split_modules = ['ABCBlock']
|
119 |
+
_supports_cache_class = True
|
120 |
+
|
121 |
+
def __init__(self, *inputs, **kwargs):
|
122 |
+
super().__init__(*inputs, **kwargs)
|
123 |
+
|
124 |
+
def _init_weights(
|
125 |
+
self,
|
126 |
+
module: nn.Module,
|
127 |
+
prenorm_residual_strategy: Optional[str] = 'rescale',
|
128 |
+
num_residuals_per_layer: int = 2,
|
129 |
+
):
|
130 |
+
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
131 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
132 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
133 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
134 |
+
if module.bias is not None:
|
135 |
+
nn.init.zeros_(module.bias)
|
136 |
+
elif isinstance(module, nn.Embedding):
|
137 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
138 |
+
elif hasattr(module, 'reset_parameters'):
|
139 |
+
module.reset_parameters()
|
140 |
+
|
141 |
+
if prenorm_residual_strategy is not None:
|
142 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
143 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
144 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
145 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
146 |
+
#
|
147 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
148 |
+
p = None
|
149 |
+
if hasattr(module, 'o_proj'):
|
150 |
+
p = module.o_proj.weight
|
151 |
+
elif hasattr(module, 'down_proj'):
|
152 |
+
p = module.down_proj.weight
|
153 |
+
if p is not None:
|
154 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
155 |
+
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
156 |
+
# We need to reinit p since this code could be called multiple times
|
157 |
+
# Having just p *= scale would repeatedly scale it down
|
158 |
+
if prenorm_residual_strategy == 'rescale':
|
159 |
+
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
160 |
+
with torch.no_grad():
|
161 |
+
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
|
162 |
+
elif prenorm_residual_strategy == 'zero':
|
163 |
+
nn.init.zeros_(p)
|
164 |
+
else:
|
165 |
+
raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
|
166 |
+
|
167 |
+
|
168 |
+
class ABCModel(ABCPreTrainedModel):
|
169 |
+
|
170 |
+
def __init__(self, config: ABCConfig):
|
171 |
+
super().__init__(config)
|
172 |
+
self.padding_idx = config.pad_token_id
|
173 |
+
self.vocab_size = config.vocab_size
|
174 |
+
|
175 |
+
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
176 |
+
self.layers = nn.ModuleList([ABCBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
177 |
+
self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
178 |
+
|
179 |
+
self.gradient_checkpointing = False
|
180 |
+
|
181 |
+
self.post_init()
|
182 |
+
|
183 |
+
def get_input_embeddings(self):
|
184 |
+
return self.embeddings
|
185 |
+
|
186 |
+
def set_input_embeddings(self, value):
|
187 |
+
self.embeddings = value
|
188 |
+
|
189 |
+
def forward(
|
190 |
+
self,
|
191 |
+
input_ids: Optional[torch.LongTensor] = None,
|
192 |
+
attention_mask: Optional[torch.Tensor] = None, # noqa
|
193 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
194 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
195 |
+
use_cache: Optional[bool] = None,
|
196 |
+
output_attentions: Optional[bool] = None,
|
197 |
+
output_hidden_states: Optional[bool] = None,
|
198 |
+
return_dict: Optional[bool] = None,
|
199 |
+
**kwargs: Unpack[Dict]
|
200 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
201 |
+
if output_attentions:
|
202 |
+
warnings.warn("`ABCModel` does not support `output_attentions` for now, setting it to `False`.")
|
203 |
+
output_attentions = False
|
204 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
205 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
206 |
+
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
|
207 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
208 |
+
|
209 |
+
# retrieve input_ids and inputs_embeds
|
210 |
+
if input_ids is not None and inputs_embeds is not None:
|
211 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
212 |
+
if input_ids is None and inputs_embeds is None:
|
213 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
214 |
+
|
215 |
+
if inputs_embeds is None:
|
216 |
+
inputs_embeds = self.embeddings(input_ids)
|
217 |
+
hidden_states = inputs_embeds
|
218 |
+
|
219 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
220 |
+
past_key_values = Cache.from_legacy_cache(past_key_values)
|
221 |
+
|
222 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
223 |
+
logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
224 |
+
use_cache = False
|
225 |
+
|
226 |
+
all_hidden_states = () if output_hidden_states else None
|
227 |
+
all_attns = () if output_attentions else None
|
228 |
+
for layer in self.layers:
|
229 |
+
if output_hidden_states:
|
230 |
+
all_hidden_states += (hidden_states,)
|
231 |
+
|
232 |
+
if self.gradient_checkpointing and self.training:
|
233 |
+
hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
|
234 |
+
layer.__call__,
|
235 |
+
hidden_states,
|
236 |
+
attention_mask,
|
237 |
+
past_key_values,
|
238 |
+
use_cache,
|
239 |
+
output_attentions,
|
240 |
+
**kwargs
|
241 |
+
)
|
242 |
+
else:
|
243 |
+
hidden_states, attentions, past_key_values = layer(
|
244 |
+
hidden_states,
|
245 |
+
attention_mask,
|
246 |
+
past_key_values=past_key_values,
|
247 |
+
use_cache=use_cache,
|
248 |
+
output_attentions=output_attentions,
|
249 |
+
**kwargs
|
250 |
+
)
|
251 |
+
|
252 |
+
if output_attentions:
|
253 |
+
all_attns += (attentions,)
|
254 |
+
|
255 |
+
hidden_states = self.norm(hidden_states)
|
256 |
+
|
257 |
+
# add hidden states from the last decoder layer
|
258 |
+
if output_hidden_states:
|
259 |
+
all_hidden_states += (hidden_states,)
|
260 |
+
|
261 |
+
if not return_dict:
|
262 |
+
return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
|
263 |
+
return BaseModelOutputWithPast(
|
264 |
+
last_hidden_state=hidden_states,
|
265 |
+
past_key_values=past_key_values,
|
266 |
+
hidden_states=all_hidden_states,
|
267 |
+
attentions=all_attns
|
268 |
+
)
|
269 |
+
|
270 |
+
|
271 |
+
class ABCForCausalLM(ABCPreTrainedModel, GenerationMixin):
|
272 |
+
|
273 |
+
_tied_weights_keys = ["lm_head.weight"]
|
274 |
+
|
275 |
+
def __init__(self, config):
|
276 |
+
super().__init__(config)
|
277 |
+
self.model = ABCModel(config)
|
278 |
+
self.vocab_size = config.vocab_size
|
279 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
280 |
+
self.criterion = None
|
281 |
+
|
282 |
+
# Initialize weights and apply final processing
|
283 |
+
self.post_init()
|
284 |
+
|
285 |
+
def get_input_embeddings(self):
|
286 |
+
return self.model.embeddings
|
287 |
+
|
288 |
+
def set_input_embeddings(self, value):
|
289 |
+
self.model.embeddings = value
|
290 |
+
|
291 |
+
def get_output_embeddings(self):
|
292 |
+
return self.lm_head
|
293 |
+
|
294 |
+
def set_output_embeddings(self, new_embeddings):
|
295 |
+
self.lm_head = new_embeddings
|
296 |
+
|
297 |
+
def set_decoder(self, decoder):
|
298 |
+
self.model = decoder
|
299 |
+
|
300 |
+
def get_decoder(self):
|
301 |
+
return self.model
|
302 |
+
|
303 |
+
def generate(self, *args, **kwargs):
|
304 |
+
try:
|
305 |
+
return super().generate(*args, **kwargs)
|
306 |
+
except AttributeError as exception:
|
307 |
+
if 'past_key_values' in str(exception):
|
308 |
+
raise AttributeError(
|
309 |
+
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
|
310 |
+
f"which is not supported for {self.__class__.__name__}. "
|
311 |
+
f"Try another generation strategy instead. "
|
312 |
+
f"For the available generation strategies, check this doc: "
|
313 |
+
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
|
314 |
+
)
|
315 |
+
else:
|
316 |
+
raise exception
|
317 |
+
|
318 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
319 |
+
def prepare_inputs_for_generation(
|
320 |
+
self,
|
321 |
+
input_ids: torch.LongTensor = None,
|
322 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
323 |
+
attention_mask: Optional[torch.Tensor] = None,
|
324 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
325 |
+
use_cache: bool = True,
|
326 |
+
logits_to_keep: Optional[int] = None,
|
327 |
+
**kwargs
|
328 |
+
):
|
329 |
+
# only use the last token of `input_ids` if `past_key_values` is not empty.
|
330 |
+
if past_key_values is not None and len(past_key_values) > 0:
|
331 |
+
input_ids = input_ids[:, -1:]
|
332 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
333 |
+
if inputs_embeds is not None and len(past_key_values) == 0:
|
334 |
+
model_inputs = {'inputs_embeds': inputs_embeds}
|
335 |
+
else:
|
336 |
+
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
337 |
+
# recompiles graphs as the stride of the inputs is a guard.
|
338 |
+
# Ref: https://github.com/huggingface/transformers/pull/29114
|
339 |
+
# TODO: use `next_tokens` directly instead.
|
340 |
+
model_inputs = {'input_ids': input_ids.contiguous()}
|
341 |
+
|
342 |
+
if logits_to_keep is not None:
|
343 |
+
model_inputs['logits_to_keep'] = logits_to_keep
|
344 |
+
|
345 |
+
model_inputs.update({
|
346 |
+
'past_key_values': past_key_values,
|
347 |
+
'use_cache': use_cache,
|
348 |
+
'attention_mask': attention_mask,
|
349 |
+
})
|
350 |
+
return model_inputs
|
351 |
+
|
352 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
353 |
+
def forward(
|
354 |
+
self,
|
355 |
+
input_ids: torch.LongTensor = None,
|
356 |
+
attention_mask: Optional[torch.Tensor] = None,
|
357 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
358 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
359 |
+
labels: Optional[torch.LongTensor] = None,
|
360 |
+
use_cache: Optional[bool] = None,
|
361 |
+
output_attentions: Optional[bool] = None,
|
362 |
+
output_hidden_states: Optional[bool] = None,
|
363 |
+
return_dict: Optional[bool] = None,
|
364 |
+
logits_to_keep: Optional[int] = 0,
|
365 |
+
**kwargs: Unpack[Dict]
|
366 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
367 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
368 |
+
output_hidden_states = (
|
369 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
370 |
+
)
|
371 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
372 |
+
|
373 |
+
outputs = self.model(
|
374 |
+
input_ids=input_ids,
|
375 |
+
attention_mask=attention_mask,
|
376 |
+
inputs_embeds=inputs_embeds,
|
377 |
+
past_key_values=past_key_values,
|
378 |
+
use_cache=use_cache,
|
379 |
+
output_attentions=output_attentions,
|
380 |
+
output_hidden_states=output_hidden_states,
|
381 |
+
return_dict=return_dict,
|
382 |
+
**kwargs
|
383 |
+
)
|
384 |
+
|
385 |
+
hidden_states = outputs[0]
|
386 |
+
fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
|
387 |
+
|
388 |
+
loss, logits = None, None
|
389 |
+
if not fuse_linear_and_cross_entropy or labels is None:
|
390 |
+
logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
|
391 |
+
if labels is not None:
|
392 |
+
if getattr(self, 'criterion', None) is None:
|
393 |
+
if fuse_linear_and_cross_entropy:
|
394 |
+
criterion = FusedLinearCrossEntropyLoss()
|
395 |
+
elif self.config.fuse_cross_entropy:
|
396 |
+
criterion = FusedCrossEntropyLoss(inplace_backward=True)
|
397 |
+
else:
|
398 |
+
criterion = nn.CrossEntropyLoss()
|
399 |
+
else:
|
400 |
+
criterion = self.criterion
|
401 |
+
labels = labels.to(hidden_states.device)
|
402 |
+
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
|
403 |
+
if fuse_linear_and_cross_entropy:
|
404 |
+
loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
|
405 |
+
else:
|
406 |
+
loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
|
407 |
+
|
408 |
+
if not return_dict:
|
409 |
+
output = (logits,) + outputs[1:]
|
410 |
+
return (loss,) + output if loss is not None else output
|
411 |
+
|
412 |
+
return CausalLMOutputWithPast(
|
413 |
+
loss=loss,
|
414 |
+
logits=logits,
|
415 |
+
past_key_values=outputs.past_key_values,
|
416 |
+
hidden_states=outputs.hidden_states,
|
417 |
+
attentions=outputs.attentions,
|
418 |
+
)
|
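One detail of the loss computation in `ABCForCausalLM.forward` worth spelling out: instead of slicing the logits, the labels are shifted left by one position and the freed final slot is filled with `ignore_index`, so logits and labels keep the same sequence length. A minimal sketch of that shift in isolation (using `nn.CrossEntropyLoss`'s default `ignore_index` of -100):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()    # ignore_index defaults to -100
labels = torch.tensor([[11, 12, 13, 14]])

# Score the logits at position t against the token at position t + 1;
# the last position has no next token, so it is masked out.
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
print(shifted)                       # tensor([[  12,   13,   14, -100]])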
fla/models/bitnet/__init__.py
ADDED
@@ -0,0 +1,13 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
4 |
+
|
5 |
+
from fla.models.bitnet.configuration_bitnet import BitNetConfig
|
6 |
+
from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM, BitNetModel
|
7 |
+
|
8 |
+
AutoConfig.register(BitNetConfig.model_type, BitNetConfig)
|
9 |
+
AutoModel.register(BitNetConfig, BitNetModel)
|
10 |
+
AutoModelForCausalLM.register(BitNetConfig, BitNetForCausalLM)
|
11 |
+
|
12 |
+
|
13 |
+
__all__ = ['BitNetConfig', 'BitNetForCausalLM', 'BitNetModel']
|
fla/models/bitnet/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (739 Bytes).
|
fla/models/bitnet/__pycache__/configuration_bitnet.cpython-311.pyc
ADDED
Binary file (2.64 kB).
|
fla/models/bitnet/__pycache__/modeling_bitnet.cpython-311.pyc
ADDED
Binary file (19.6 kB).
|
fla/models/bitnet/configuration_bitnet.py
ADDED
@@ -0,0 +1,67 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from transformers.configuration_utils import PretrainedConfig
|
6 |
+
|
7 |
+
|
8 |
+
class BitNetConfig(PretrainedConfig):
|
9 |
+
|
10 |
+
model_type = 'bitnet'
|
11 |
+
keys_to_ignore_at_inference = ['past_key_values']
|
12 |
+
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
hidden_size: int = 2048,
|
16 |
+
num_hidden_layers: int = 24,
|
17 |
+
num_heads: int = 32,
|
18 |
+
num_kv_heads: int = None,
|
19 |
+
window_size: Optional[int] = None,
|
20 |
+
rope_theta: Optional[float] = 10000.,
|
21 |
+
max_position_embeddings: int = 2048,
|
22 |
+
hidden_ratio: Optional[int] = 4,
|
23 |
+
intermediate_size: Optional[int] = None,
|
24 |
+
hidden_act: str = "swish",
|
25 |
+
initializer_range: float = 0.006,
|
26 |
+
elementwise_affine: Optional[bool] = True,
|
27 |
+
norm_eps: float = 1e-6,
|
28 |
+
use_cache: bool = True,
|
29 |
+
pad_token_id: int = None,
|
30 |
+
bos_token_id: int = 1,
|
31 |
+
eos_token_id: int = 2,
|
32 |
+
tie_word_embeddings: bool = False,
|
33 |
+
fuse_norm: bool = True,
|
34 |
+
fuse_swiglu: bool = True,
|
35 |
+
fuse_cross_entropy: bool = True,
|
36 |
+
vocab_size: int = 32000,
|
37 |
+
**kwargs,
|
38 |
+
):
|
39 |
+
self.hidden_size = hidden_size
|
40 |
+
self.num_hidden_layers = num_hidden_layers
|
41 |
+
self.num_heads = num_heads
|
42 |
+
self.num_kv_heads = num_kv_heads
|
43 |
+
self.window_size = window_size
|
44 |
+
self.rope_theta = rope_theta
|
45 |
+
self.max_position_embeddings = max_position_embeddings
|
46 |
+
|
47 |
+
self.hidden_ratio = hidden_ratio
|
48 |
+
self.intermediate_size = intermediate_size
|
49 |
+
self.hidden_act = hidden_act
|
50 |
+
|
51 |
+
self.initializer_range = initializer_range
|
52 |
+
self.elementwise_affine = elementwise_affine
|
53 |
+
self.norm_eps = norm_eps
|
54 |
+
self.use_cache = use_cache
|
55 |
+
|
56 |
+
self.fuse_norm = fuse_norm
|
57 |
+
self.fuse_swiglu = fuse_swiglu
|
58 |
+
self.fuse_cross_entropy = fuse_cross_entropy
|
59 |
+
self.vocab_size = vocab_size
|
60 |
+
|
61 |
+
super().__init__(
|
62 |
+
pad_token_id=pad_token_id,
|
63 |
+
bos_token_id=bos_token_id,
|
64 |
+
eos_token_id=eos_token_id,
|
65 |
+
tie_word_embeddings=tie_word_embeddings,
|
66 |
+
**kwargs,
|
67 |
+
)
|
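`BitNetConfig` follows the usual decoder-only transformer knobs. A quick sketch of overriding a few of them when building a smaller variant; every value below is chosen purely for illustration, not taken from this repo:

from fla.models.bitnet import BitNetConfig

config = BitNetConfig(
    hidden_size=1024,        # illustrative override
    num_hidden_layers=12,    # illustrative override
    num_heads=16,            # illustrative override
)
print(config.model_type)     # 'bitnet'
print(config.hidden_ratio)   # 4, the default MLP expansion ratio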
fla/models/delta_net/__init__.py
ADDED
@@ -0,0 +1,12 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
4 |
+
|
5 |
+
from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
|
6 |
+
from fla.models.delta_net.modeling_delta_net import DeltaNetForCausalLM, DeltaNetModel
|
7 |
+
|
8 |
+
AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig)
|
9 |
+
AutoModel.register(DeltaNetConfig, DeltaNetModel)
|
10 |
+
AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)
|
11 |
+
|
12 |
+
__all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel']
|
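With these registrations in place, a checkpoint directory whose `config.json` declares `model_type: delta_net` resolves to the DeltaNet classes through the generic loading path. A sketch, assuming `fla` is importable; the checkpoint path is a placeholder:

from transformers import AutoModelForCausalLM

import fla.models.delta_net  # noqa: F401  (registers DeltaNetConfig / DeltaNetForCausalLM)

# '/path/to/delta_net_checkpoint' is a placeholder for any directory saved with these classes.
model = AutoModelForCausalLM.from_pretrained('/path/to/delta_net_checkpoint')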
fla/models/forgetting_transformer/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (888 Bytes).
|
fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-311.pyc
ADDED
Binary file (18.1 kB).
|
fla/models/forgetting_transformer/configuration_forgetting_transformer.py
ADDED
@@ -0,0 +1,68 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
from transformers.configuration_utils import PretrainedConfig
|
6 |
+
|
7 |
+
|
8 |
+
class ForgettingTransformerConfig(PretrainedConfig):
|
9 |
+
|
10 |
+
model_type = 'forgetting_transformer'
|
11 |
+
keys_to_ignore_at_inference = ['past_key_values']
|
12 |
+
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
hidden_size: int = 2048,
|
16 |
+
num_hidden_layers: int = 24,
|
17 |
+
num_heads: int = 32,
|
18 |
+
num_kv_heads: Optional[int] = None,
|
19 |
+
qkv_bias: bool = False,
|
20 |
+
qk_norm: bool = False,
|
21 |
+
window_size: Optional[int] = None,
|
22 |
+
use_output_gate: bool = False,
|
23 |
+
hidden_ratio: Optional[int] = 4,
|
24 |
+
intermediate_size: Optional[int] = None,
|
25 |
+
hidden_act: str = "swish",
|
26 |
+
initializer_range: float = 0.006,
|
27 |
+
elementwise_affine: Optional[bool] = True,
|
28 |
+
norm_eps: float = 1e-6,
|
29 |
+
use_cache: bool = True,
|
30 |
+
pad_token_id: Optional[int] = None,
|
31 |
+
bos_token_id: int = 1,
|
32 |
+
eos_token_id: int = 2,
|
33 |
+
tie_word_embeddings: bool = False,
|
34 |
+
fuse_norm: bool = True,
|
35 |
+
fuse_swiglu: bool = True,
|
36 |
+
fuse_cross_entropy: bool = True,
|
37 |
+
vocab_size: int = 32000,
|
38 |
+
**kwargs,
|
39 |
+
):
|
40 |
+
self.hidden_size = hidden_size
|
41 |
+
self.num_hidden_layers = num_hidden_layers
|
42 |
+
self.num_heads = num_heads
|
43 |
+
self.num_kv_heads = num_kv_heads
|
44 |
+
self.qkv_bias = qkv_bias
|
45 |
+
self.qk_norm = qk_norm
|
46 |
+
self.window_size = window_size
|
47 |
+
self.use_output_gate = use_output_gate
|
48 |
+
self.hidden_ratio = hidden_ratio
|
49 |
+
self.intermediate_size = intermediate_size
|
50 |
+
self.hidden_act = hidden_act
|
51 |
+
|
52 |
+
self.initializer_range = initializer_range
|
53 |
+
self.elementwise_affine = elementwise_affine
|
54 |
+
self.norm_eps = norm_eps
|
55 |
+
self.use_cache = use_cache
|
56 |
+
|
57 |
+
self.fuse_norm = fuse_norm
|
58 |
+
self.fuse_swiglu = fuse_swiglu
|
59 |
+
self.fuse_cross_entropy = fuse_cross_entropy
|
60 |
+
self.vocab_size = vocab_size
|
61 |
+
|
62 |
+
super().__init__(
|
63 |
+
pad_token_id=pad_token_id,
|
64 |
+
bos_token_id=bos_token_id,
|
65 |
+
eos_token_id=eos_token_id,
|
66 |
+
tie_word_embeddings=tie_word_embeddings,
|
67 |
+
**kwargs,
|
68 |
+
)
|
fla/models/gated_deltanet/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (803 Bytes).
|
fla/models/gated_deltaproduct/__pycache__/configuration_gated_deltaproduct.cpython-311.pyc
ADDED
Binary file (3.75 kB).
|
fla/models/gated_deltaproduct/configuration_gated_deltaproduct.py
ADDED
@@ -0,0 +1,90 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from typing import Dict, Optional
|
4 |
+
|
5 |
+
from transformers.configuration_utils import PretrainedConfig
|
6 |
+
|
7 |
+
|
8 |
+
class GatedDeltaProductConfig(PretrainedConfig):
|
9 |
+
model_type = "gated_deltaproduct"
|
10 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
11 |
+
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
attn_mode: str = "chunk",
|
15 |
+
hidden_size: int = 2048,
|
16 |
+
expand_v: int = 2,
|
17 |
+
use_gate: bool = True,
|
18 |
+
use_short_conv: bool = True,
|
19 |
+
conv_size: int = 4,
|
20 |
+
head_dim: int = 256,
|
21 |
+
num_heads: int = 6,
|
22 |
+
max_position_embeddings: int = 2048,
|
23 |
+
hidden_ratio: Optional[int] = 4,
|
24 |
+
intermediate_size: Optional[int] = None,
|
25 |
+
hidden_act: str = "swish",
|
26 |
+
num_hidden_layers: int = 21,
|
27 |
+
norm_first: bool = False,
|
28 |
+
norm_eps: float = 1e-6,
|
29 |
+
attn: Optional[Dict] = None,
|
30 |
+
use_cache: bool = True,
|
31 |
+
pad_token_id: int | None = None,
|
32 |
+
bos_token_id: int = 1,
|
33 |
+
eos_token_id: int = 2,
|
34 |
+
tie_word_embeddings: bool = False,
|
35 |
+
initializer_range: float = 0.006,
|
36 |
+
fuse_cross_entropy: bool = True,
|
37 |
+
vocab_size: int = 32000,
|
38 |
+
use_forget_gate: bool = False, # when true Gated DeltaProduct, when false DeltaProduct
|
39 |
+
allow_neg_eigval: bool = False, # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1]
|
40 |
+
num_householder: int = 1,
|
41 |
+
**kwargs,
|
42 |
+
):
|
43 |
+
self.attn_mode = attn_mode
|
44 |
+
self.hidden_size = hidden_size
|
45 |
+
self.expand_v = expand_v
|
46 |
+
self.use_gate = use_gate
|
47 |
+
self.use_short_conv = use_short_conv
|
48 |
+
self.conv_size = conv_size
|
49 |
+
self.head_dim = head_dim
|
50 |
+
self.num_heads = num_heads
|
51 |
+
self.max_position_embeddings = max_position_embeddings
|
52 |
+
|
53 |
+
self.hidden_ratio = hidden_ratio
|
54 |
+
self.intermediate_size = intermediate_size
|
55 |
+
self.hidden_act = hidden_act
|
56 |
+
self.num_hidden_layers = num_hidden_layers
|
57 |
+
self.norm_first = norm_first
|
58 |
+
self.norm_eps = norm_eps
|
59 |
+
self.attn = attn
|
60 |
+
self.use_cache = use_cache
|
61 |
+
self.initializer_range = initializer_range
|
62 |
+
self.fuse_cross_entropy = fuse_cross_entropy
|
63 |
+
self.vocab_size = vocab_size
|
64 |
+
|
65 |
+
# DeltaProduct specific
|
66 |
+
self.allow_neg_eigval = allow_neg_eigval
|
67 |
+
self.num_householder = num_householder
|
68 |
+
self.use_forget_gate = use_forget_gate
|
69 |
+
|
70 |
+
if attn is not None:
|
71 |
+
if not isinstance(attn, Dict):
|
72 |
+
raise ValueError("attn must be a dictionary")
|
73 |
+
if "layers" not in attn:
|
74 |
+
raise ValueError(
|
75 |
+
"Layer indices must be provided to initialize hybrid attention layers"
|
76 |
+
)
|
77 |
+
if "num_heads" not in attn:
|
78 |
+
raise ValueError(
|
79 |
+
"Number of heads must be provided to initialize hybrid attention layers"
|
80 |
+
)
|
81 |
+
attn["num_kv_heads"] = attn.get("num_kv_heads", attn["num_heads"])
|
82 |
+
attn["window_size"] = attn.get("window_size", None)
|
83 |
+
|
84 |
+
super().__init__(
|
85 |
+
pad_token_id=pad_token_id,
|
86 |
+
bos_token_id=bos_token_id,
|
87 |
+
eos_token_id=eos_token_id,
|
88 |
+
tie_word_embeddings=tie_word_embeddings,
|
89 |
+
**kwargs,
|
90 |
+
)
|
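The flags documented inline above select between the model variants: `use_forget_gate` toggles Gated DeltaProduct vs plain DeltaProduct, `allow_neg_eigval` widens the eigenvalue range from [0, 1] to [-1, 1], and `num_householder` sets the number of Householder steps per token; hybrid softmax-attention layers are requested through the `attn` dict, which must name the layer indices and head count. A configuration sketch with illustrative values:

from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig

# Plain DeltaProduct with two Householder steps per token.
delta_product = GatedDeltaProductConfig(use_forget_gate=False, num_householder=2)

# Gated DeltaProduct whose state-transition eigenvalues may lie in [-1, 1].
gated = GatedDeltaProductConfig(use_forget_gate=True, allow_neg_eigval=True)

# Hybrid stack: softmax attention in layers 2 and 5; omitting 'layers' or 'num_heads'
# from the attn dict raises a ValueError in __init__ above.
hybrid = GatedDeltaProductConfig(attn={'layers': [2, 5], 'num_heads': 8})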
fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py
ADDED
@@ -0,0 +1,520 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import math
|
6 |
+
import warnings
|
7 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
from transformers.activations import ACT2FN
|
13 |
+
from transformers.generation import GenerationMixin
|
14 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
15 |
+
from transformers.modeling_utils import PreTrainedModel
|
16 |
+
from transformers.utils import logging
|
17 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
18 |
+
|
19 |
+
from fla.layers.attn import Attention
|
20 |
+
from fla.layers.gated_deltaproduct import GatedDeltaProduct
|
21 |
+
from fla.models.gated_deltaproduct.configuration_gated_deltaproduct import GatedDeltaProductConfig
|
22 |
+
from fla.models.utils import Cache
|
23 |
+
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, RMSNorm
|
24 |
+
from fla.modules.activations import swiglu_linear
|
25 |
+
from fla.modules.layernorm import rms_norm_linear
|
26 |
+
|
27 |
+
if TYPE_CHECKING:
|
28 |
+
from transformers.processing_utils import Unpack
|
29 |
+
|
30 |
+
logger = logging.get_logger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
class GatedDeltaNetMLP(nn.Module):
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
hidden_size: int,
|
37 |
+
hidden_ratio: Optional[int] = None,
|
38 |
+
intermediate_size: Optional[int] = None,
|
39 |
+
hidden_act: str = "swish",
|
40 |
+
norm_first: bool = True,
|
41 |
+
norm_eps: float = 1e-5,
|
42 |
+
) -> GatedDeltaNetMLP:
|
43 |
+
super().__init__()
|
44 |
+
|
45 |
+
self.hidden_size = hidden_size
|
46 |
+
# the final number of params is `hidden_ratio * hidden_size^2`
|
47 |
+
# `intermediate_size` is `2/3 * hidden_size * hidden_ratio`, rounded up to the next multiple of 256
|
48 |
+
if hidden_ratio is None:
|
49 |
+
hidden_ratio = 4
|
50 |
+
if intermediate_size is None:
|
51 |
+
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
|
52 |
+
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
|
53 |
+
self.hidden_ratio = hidden_ratio
|
54 |
+
self.intermediate_size = intermediate_size
|
55 |
+
self.norm_first = norm_first
|
56 |
+
|
57 |
+
if norm_first:
|
58 |
+
self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)
|
59 |
+
|
60 |
+
self.gate_proj = nn.Linear(
|
61 |
+
self.hidden_size, self.intermediate_size * 2, bias=False
|
62 |
+
)
|
63 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
64 |
+
self.act_fn = ACT2FN[hidden_act]
|
65 |
+
|
66 |
+
def forward(
|
67 |
+
self,
|
68 |
+
x: torch.Tensor,
|
69 |
+
**kwargs: Unpack[Dict],
|
70 |
+
) -> torch.Tensor:
|
71 |
+
if self.norm_first:
|
72 |
+
x = rms_norm_linear(
|
73 |
+
x,
|
74 |
+
self.norm.weight,
|
75 |
+
self.norm.bias,
|
76 |
+
self.gate_proj.weight,
|
77 |
+
self.gate_proj.bias,
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
x = self.gate_proj(x)
|
81 |
+
gate, y = x.chunk(2, -1)
|
82 |
+
return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
|
83 |
+
|
84 |
+
|
85 |
+
class GatedDeltaProductBlock(nn.Module):
|
86 |
+
def __init__(self, config: GatedDeltaProductConfig, layer_idx: int):
|
87 |
+
super().__init__()
|
88 |
+
self.hidden_size = config.hidden_size
|
89 |
+
|
90 |
+
if not config.norm_first:
|
91 |
+
self.attn_norm = RMSNorm(
|
92 |
+
hidden_size=config.hidden_size, eps=config.norm_eps
|
93 |
+
)
|
94 |
+
if config.attn is not None and layer_idx in config.attn["layers"]:
|
95 |
+
self.attn = Attention(
|
96 |
+
hidden_size=config.hidden_size,
|
97 |
+
num_heads=config.attn["num_heads"],
|
98 |
+
num_kv_heads=config.attn["num_kv_heads"],
|
99 |
+
window_size=config.attn["window_size"],
|
100 |
+
max_position_embeddings=config.max_position_embeddings,
|
101 |
+
layer_idx=layer_idx,
|
102 |
+
)
|
103 |
+
else:
|
104 |
+
self.attn = GatedDeltaProduct(
|
105 |
+
mode=config.attn_mode,
|
106 |
+
hidden_size=config.hidden_size,
|
107 |
+
expand_v=config.expand_v,
|
108 |
+
head_dim=config.head_dim,
|
109 |
+
num_heads=config.num_heads,
|
110 |
+
use_gate=config.use_gate,
|
111 |
+
use_forget_gate=config.use_forget_gate,
|
112 |
+
use_short_conv=config.use_short_conv,
|
113 |
+
conv_size=config.conv_size,
|
114 |
+
norm_first=config.norm_first,
|
115 |
+
norm_eps=config.norm_eps,
|
116 |
+
allow_neg_eigval=config.allow_neg_eigval,
|
117 |
+
num_householder=config.num_householder,
|
118 |
+
layer_idx=layer_idx,
|
119 |
+
use_beta_conv=config.use_beta_conv
|
120 |
+
)
|
121 |
+
if not config.norm_first:
|
122 |
+
self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
|
123 |
+
self.mlp = GatedDeltaNetMLP(
|
124 |
+
hidden_size=config.hidden_size,
|
125 |
+
hidden_ratio=config.hidden_ratio,
|
126 |
+
intermediate_size=config.intermediate_size,
|
127 |
+
hidden_act=config.hidden_act,
|
128 |
+
norm_first=config.norm_first,
|
129 |
+
norm_eps=config.norm_eps,
|
130 |
+
)
|
131 |
+
|
132 |
+
def forward(
|
133 |
+
self,
|
134 |
+
hidden_states: torch.Tensor,
|
135 |
+
attention_mask: Optional[torch.Tensor] = None,
|
136 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
137 |
+
use_cache: Optional[bool] = False,
|
138 |
+
output_attentions: Optional[bool] = False,
|
139 |
+
**kwargs: Unpack[Dict],
|
140 |
+
) -> Tuple[
|
141 |
+
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
142 |
+
]:
|
143 |
+
residual = hidden_states
|
144 |
+
if hasattr(self, "attn_norm"):
|
145 |
+
hidden_states = self.attn_norm(hidden_states)
|
146 |
+
hidden_states, attentions, past_key_values = self.attn(
|
147 |
+
hidden_states=hidden_states,
|
148 |
+
attention_mask=attention_mask,
|
149 |
+
past_key_values=past_key_values,
|
150 |
+
use_cache=use_cache,
|
151 |
+
output_attentions=output_attentions,
|
152 |
+
**kwargs,
|
153 |
+
)
|
154 |
+
if hasattr(self, "mlp_norm"):
|
155 |
+
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
|
156 |
+
else:
|
157 |
+
hidden_states = residual + hidden_states
|
158 |
+
residual = hidden_states
|
159 |
+
hidden_states = self.mlp(hidden_states, **kwargs)
|
160 |
+
hidden_states = residual + hidden_states
|
161 |
+
|
162 |
+
outputs = (hidden_states, attentions, past_key_values)
|
163 |
+
|
164 |
+
return outputs
|
165 |
+
|
166 |
+
|
167 |
+
class GatedDeltaProductPreTrainedModel(PreTrainedModel):
|
168 |
+
config_class = GatedDeltaProductConfig
|
169 |
+
supports_gradient_checkpointing = True
|
170 |
+
_no_split_modules = ["GatedDeltaProductBlock"]
|
171 |
+
|
172 |
+
def __init__(self, *inputs, **kwargs):
|
173 |
+
super().__init__(*inputs, **kwargs)
|
174 |
+
|
175 |
+
def _init_weights(
|
176 |
+
self,
|
177 |
+
module: nn.Module,
|
178 |
+
rescale_prenorm_residual: bool = True,
|
179 |
+
num_residuals_per_layer: int = 2,
|
180 |
+
):
|
181 |
+
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
182 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
183 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
184 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
185 |
+
if module.bias is not None:
|
186 |
+
nn.init.zeros_(module.bias)
|
187 |
+
elif isinstance(module, nn.Embedding):
|
188 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
189 |
+
if module.padding_idx is not None:
|
190 |
+
module.weight.data[module.padding_idx].zero_()
|
191 |
+
|
192 |
+
if rescale_prenorm_residual:
|
193 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
194 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
195 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
196 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
197 |
+
#
|
198 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
199 |
+
for name, p in module.named_parameters():
|
200 |
+
if name in ["o_proj.weight", "down_proj.weight"]:
|
201 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
202 |
+
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
203 |
+
# We need to reinit p since this code could be called multiple times
|
204 |
+
# Having just p *= scale would repeatedly scale it down
|
205 |
+
with torch.no_grad():
|
206 |
+
p /= math.sqrt(
|
207 |
+
num_residuals_per_layer * self.config.num_hidden_layers
|
208 |
+
)
|
209 |
+
|
210 |
+
|
211 |
+
class GatedDeltaProductModel(GatedDeltaProductPreTrainedModel):
|
212 |
+
def __init__(self, config: GatedDeltaProductConfig):
|
213 |
+
super().__init__(config)
|
214 |
+
self.padding_idx = config.pad_token_id
|
215 |
+
self.vocab_size = config.vocab_size
|
216 |
+
|
217 |
+
self.embeddings = nn.Embedding(
|
218 |
+
config.vocab_size, config.hidden_size, self.padding_idx
|
219 |
+
)
|
220 |
+
self.layers = nn.ModuleList(
|
221 |
+
[
|
222 |
+
GatedDeltaProductBlock(config, layer_idx)
|
223 |
+
for layer_idx in range(config.num_hidden_layers)
|
224 |
+
]
|
225 |
+
)
|
226 |
+
self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
|
227 |
+
|
228 |
+
self.gradient_checkpointing = False
|
229 |
+
|
230 |
+
self.post_init()
|
231 |
+
|
232 |
+
def get_input_embeddings(self):
|
233 |
+
return self.embeddings
|
234 |
+
|
235 |
+
def set_input_embeddings(self, value):
|
236 |
+
self.embeddings = value
|
237 |
+
|
238 |
+
def forward(
|
239 |
+
self,
|
240 |
+
input_ids: Optional[torch.LongTensor] = None,
|
241 |
+
attention_mask: Optional[torch.Tensor] = None,
|
242 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
243 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
244 |
+
use_cache: Optional[bool] = None,
|
245 |
+
output_attentions: Optional[bool] = None,
|
246 |
+
output_hidden_states: Optional[bool] = None,
|
247 |
+
return_dict: Optional[bool] = None,
|
248 |
+
**kwargs: Unpack[Dict],
|
249 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
250 |
+
if output_attentions:
|
251 |
+
warnings.warn(
|
252 |
+
"`GatedDeltaNetModel` does not `output_attentions` now, setting it to `False`.",
|
253 |
+
stacklevel=2,
|
254 |
+
)
|
255 |
+
output_attentions = False
|
256 |
+
output_attentions = (
|
257 |
+
output_attentions
|
258 |
+
if output_attentions is not None
|
259 |
+
else self.config.output_attentions
|
260 |
+
)
|
261 |
+
output_hidden_states = (
|
262 |
+
output_hidden_states
|
263 |
+
if output_hidden_states is not None
|
264 |
+
else self.config.output_hidden_states
|
265 |
+
)
|
266 |
+
use_cache = (
|
267 |
+
use_cache
|
268 |
+
if use_cache is not None
|
269 |
+
else (self.config.use_cache if not self.training else False)
|
270 |
+
)
|
271 |
+
return_dict = (
|
272 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
273 |
+
)
|
274 |
+
|
275 |
+
# retrieve input_ids and inputs_embeds
|
276 |
+
if input_ids is not None and inputs_embeds is not None:
|
277 |
+
raise ValueError(
|
278 |
+
"You cannot specify both input_ids and inputs_embeds at the same time"
|
279 |
+
)
|
280 |
+
if input_ids is None and inputs_embeds is None:
|
281 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
282 |
+
|
283 |
+
if inputs_embeds is None:
|
284 |
+
inputs_embeds = self.embeddings(input_ids)
|
285 |
+
hidden_states = inputs_embeds
|
286 |
+
|
287 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
288 |
+
past_key_values = Cache.from_legacy_cache(past_key_values)
|
289 |
+
|
290 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
291 |
+
logger.warning_once(
|
292 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
293 |
+
)
|
294 |
+
use_cache = False
|
295 |
+
|
296 |
+
all_hidden_states = () if output_hidden_states else None
|
297 |
+
all_attns = () if output_attentions else None
|
298 |
+
for layer in self.layers:
|
299 |
+
if output_hidden_states:
|
300 |
+
all_hidden_states += (hidden_states,)
|
301 |
+
|
302 |
+
if self.gradient_checkpointing and self.training:
|
303 |
+
hidden_states, attentions, past_key_values = (
|
304 |
+
self._gradient_checkpointing_func(
|
305 |
+
layer.__call__,
|
306 |
+
hidden_states,
|
307 |
+
attention_mask,
|
308 |
+
past_key_values,
|
309 |
+
use_cache,
|
310 |
+
output_attentions,
|
311 |
+
**kwargs,
|
312 |
+
)
|
313 |
+
)
|
314 |
+
else:
|
315 |
+
hidden_states, attentions, past_key_values = layer(
|
316 |
+
hidden_states,
|
317 |
+
attention_mask=attention_mask,
|
318 |
+
past_key_values=past_key_values,
|
319 |
+
use_cache=use_cache,
|
320 |
+
output_attentions=output_attentions,
|
321 |
+
**kwargs,
|
322 |
+
)
|
323 |
+
|
324 |
+
if output_attentions:
|
325 |
+
all_attns += (attentions,)
|
326 |
+
|
327 |
+
hidden_states = self.norm(hidden_states)
|
328 |
+
# add hidden states from the last decoder layer
|
329 |
+
if output_hidden_states:
|
330 |
+
all_hidden_states += (hidden_states,)
|
331 |
+
|
332 |
+
if not return_dict:
|
333 |
+
return tuple(
|
334 |
+
i
|
335 |
+
for i in [
|
336 |
+
hidden_states,
|
337 |
+
past_key_values,
|
338 |
+
all_hidden_states,
|
339 |
+
all_attns,
|
340 |
+
]
|
341 |
+
if i is not None
|
342 |
+
)
|
343 |
+
return BaseModelOutputWithPast(
|
344 |
+
last_hidden_state=hidden_states,
|
345 |
+
past_key_values=past_key_values,
|
346 |
+
hidden_states=all_hidden_states,
|
347 |
+
attentions=all_attns,
|
348 |
+
)
|
349 |
+
|
350 |
+
|
351 |
+
class GatedDeltaProductForCausalLM(GatedDeltaProductPreTrainedModel, GenerationMixin):
|
352 |
+
_tied_weights_keys = ["lm_head.weight"]
|
353 |
+
|
354 |
+
def __init__(self, config):
|
355 |
+
super().__init__(config)
|
356 |
+
self.model = GatedDeltaProductModel(config)
|
357 |
+
self.vocab_size = config.vocab_size
|
358 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
359 |
+
|
360 |
+
# Initialize weights and apply final processing
|
361 |
+
self.post_init()
|
362 |
+
|
363 |
+
def get_input_embeddings(self):
|
364 |
+
return self.model.embeddings
|
365 |
+
|
366 |
+
def set_input_embeddings(self, value):
|
367 |
+
self.model.embeddings = value
|
368 |
+
|
369 |
+
def get_output_embeddings(self):
|
370 |
+
return self.lm_head
|
371 |
+
|
372 |
+
def set_output_embeddings(self, new_embeddings):
|
373 |
+
self.lm_head = new_embeddings
|
374 |
+
|
375 |
+
def set_decoder(self, decoder):
|
376 |
+
self.model = decoder
|
377 |
+
|
378 |
+
def get_decoder(self):
|
379 |
+
return self.model
|
380 |
+
|
381 |
+
def generate(self, *args, **kwargs):
|
382 |
+
try:
|
383 |
+
return super().generate(*args, **kwargs)
|
384 |
+
except AttributeError as exception:
|
385 |
+
if "past_key_values" in str(exception):
|
386 |
+
raise AttributeError(
|
387 |
+
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
|
388 |
+
f"which is not supported for {self.__class__.__name__}. "
|
389 |
+
f"Try another generation strategy instead. "
|
390 |
+
f"For the available generation strategies, check this doc: "
|
391 |
+
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
|
392 |
+
)
|
393 |
+
else:
|
394 |
+
raise exception
|
395 |
+
|
396 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
397 |
+
def prepare_inputs_for_generation(
|
398 |
+
self,
|
399 |
+
input_ids: torch.LongTensor = None,
|
400 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
401 |
+
attention_mask: Optional[torch.Tensor] = None,
|
402 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
403 |
+
use_cache: bool = True,
|
404 |
+
num_logits_to_keep: Optional[int] = None,
|
405 |
+
logits_to_keep: Optional[int] = None,
|
406 |
+
**kwargs,
|
407 |
+
):
|
408 |
+
# only use the last token of `input_ids` if `past_key_values` is passed along and not empty.
|
409 |
+
if past_key_values is not None and len(past_key_values) > 0:
|
410 |
+
input_ids = input_ids[:, -1:]
|
411 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
412 |
+
if inputs_embeds is not None and past_key_values is None:
|
413 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
414 |
+
else:
|
415 |
+
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
416 |
+
# recompiles graphs as the stride of the inputs is a guard.
|
417 |
+
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "num_logits_to_keep": num_logits_to_keep,
            }
        )
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: Optional[int] = 0,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        num_logits_to_keep = 0 if num_logits_to_keep is None else num_logits_to_keep
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        kwargs.pop("num_items_in_batch", None)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if self.config.fuse_cross_entropy:
                if fuse_linear_and_cross_entropy:
                    loss_fct = FusedLinearCrossEntropyLoss()
                else:
                    loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
            else:
                loss_fct = nn.CrossEntropyLoss()
            # Enable model parallelism
            labels = labels.to(hidden_states.device)
            labels = torch.cat(
                (
                    labels[..., 1:],
                    torch.full_like(labels[:, :1], loss_fct.ignore_index),
                ),
                1,
            )
            if fuse_linear_and_cross_entropy:
                loss = loss_fct(
                    hidden_states.view(-1, self.config.hidden_size),
                    labels.view(-1),
                    self.lm_head.weight,
                    self.lm_head.bias,
                )
            else:
                loss = loss_fct(
                    logits.view(-1, self.config.vocab_size), labels.view(-1)
                )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss, *output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
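
A minimal sketch (not from the repository) of the label shift used in the forward above: instead of slicing logits[:, :-1] against labels[:, 1:], the labels are shifted left by one position and the last slot is filled with the loss's ignore_index, so logits and labels keep the same length and position t predicts token t + 1. The values below are illustrative only.

import torch

ignore_index = -100
labels = torch.tensor([[10, 11, 12, 13]])
# shift left and pad the final position so it contributes nothing to the loss
shifted = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1)
print(shifted)  # tensor([[  11,   12,   13, -100]])
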
fla/models/gla/__init__.py
ADDED
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from fla.models.gla.configuration_gla import GLAConfig
from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel

AutoConfig.register(GLAConfig.model_type, GLAConfig)
AutoModel.register(GLAConfig, GLAModel)
AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM)


__all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel']
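
A minimal sketch, not part of the upload: once `fla.models.gla` has been imported, the Auto* registrations above let the `gla` model type resolve through the transformers Auto classes. The config values below are illustrative and assume `GLAConfig` accepts these keyword arguments.

import fla.models.gla  # noqa: F401  (runs the AutoConfig/AutoModel registrations above)
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.for_model('gla', hidden_size=512, num_hidden_layers=2, num_heads=4)
model = AutoModelForCausalLM.from_config(config)  # randomly initialized GLAForCausalLM
print(type(model).__name__)
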
fla/models/gla/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (714 Bytes).
fla/models/gsa/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (714 Bytes).
fla/models/gsa/__pycache__/configuration_gsa.cpython-311.pyc
ADDED
Binary file (4.27 kB).
fla/models/gsa/configuration_gsa.py
ADDED
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class GSAConfig(PretrainedConfig):

    model_type = 'gsa'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        gate_logit_normalizer: Optional[int] = 8,
        clamp_min: Optional[float] = None,
        clamp_max: Optional[float] = None,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        num_slots: Optional[int] = 64,
        use_short_conv: bool = False,
        conv_size: int = 4,
        expand_k: float = 1,
        expand_v: float = 1,
        feature_map: str = 'swish',
        use_output_gate: bool = False,
        use_norm: bool = True,
        max_position_embeddings: int = 2048,
        hidden_act: str = "swish",
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        initializer_range: float = 0.006,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.hidden_size = hidden_size
        self.gate_logit_normalizer = gate_logit_normalizer
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_slots = num_slots
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.feature_map = feature_map
        self.use_output_gate = use_output_gate
        self.use_norm = use_norm
        self.max_position_embeddings = max_position_embeddings
        self.hidden_act = hidden_act
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
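
A minimal sketch, not part of the upload: constructing a GSAConfig with a hybrid attention spec. The layer indices and sizes below are illustrative only; the constructor above validates the `attn` dictionary and fills in any missing keys with defaults.

from fla.models.gsa.configuration_gsa import GSAConfig

config = GSAConfig(
    hidden_size=1024,
    num_hidden_layers=12,
    num_heads=4,
    num_slots=64,
    # standard softmax attention in layers 3 and 7, GatedSlotAttention elsewhere
    attn={'layers': [3, 7], 'num_heads': 8},
)
# missing keys were filled with defaults by the constructor
print(config.attn['num_kv_heads'], config.attn['rope_theta'])  # 8 10000.0
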
fla/models/gsa/modeling_gsa.py
ADDED
@@ -0,0 +1,420 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import math
|
6 |
+
import warnings
|
7 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
from transformers.generation import GenerationMixin
|
13 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
14 |
+
from transformers.modeling_utils import PreTrainedModel
|
15 |
+
from transformers.utils import logging
|
16 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
17 |
+
|
18 |
+
from fla.layers.attn import Attention
|
19 |
+
from fla.layers.gsa import GatedSlotAttention
|
20 |
+
from fla.models.gsa.configuration_gsa import GSAConfig
|
21 |
+
from fla.models.utils import Cache
|
22 |
+
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
|
23 |
+
from fla.modules import GatedMLP as GSAMLP
|
24 |
+
from fla.modules import RMSNorm
|
25 |
+
|
26 |
+
if TYPE_CHECKING:
|
27 |
+
from transformers.processing_utils import Unpack
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
class GSABlock(nn.Module):
|
33 |
+
def __init__(self, config: GSAConfig, layer_idx: int):
|
34 |
+
super().__init__()
|
35 |
+
|
36 |
+
self.config = config
|
37 |
+
self.layer_idx = layer_idx
|
38 |
+
|
39 |
+
self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
40 |
+
if config.attn is not None and layer_idx in config.attn['layers']:
|
41 |
+
self.attn = Attention(
|
42 |
+
hidden_size=config.hidden_size,
|
43 |
+
num_heads=config.attn['num_heads'],
|
44 |
+
num_kv_heads=config.attn['num_kv_heads'],
|
45 |
+
qkv_bias=config.attn['qkv_bias'],
|
46 |
+
window_size=config.attn['window_size'],
|
47 |
+
rope_theta=config.attn['rope_theta'],
|
48 |
+
max_position_embeddings=config.max_position_embeddings,
|
49 |
+
layer_idx=layer_idx
|
50 |
+
)
|
51 |
+
else:
|
52 |
+
self.attn = GatedSlotAttention(
|
53 |
+
hidden_size=config.hidden_size,
|
54 |
+
expand_k=config.expand_k,
|
55 |
+
expand_v=config.expand_v,
|
56 |
+
num_heads=config.num_heads,
|
57 |
+
num_kv_heads=config.num_kv_heads,
|
58 |
+
num_slots=config.num_slots,
|
59 |
+
use_short_conv=config.use_short_conv,
|
60 |
+
conv_size=config.conv_size,
|
61 |
+
feature_map=config.feature_map,
|
62 |
+
use_output_gate=config.use_output_gate,
|
63 |
+
use_norm=config.use_norm,
|
64 |
+
gate_fn=config.hidden_act,
|
65 |
+
gate_logit_normalizer=config.gate_logit_normalizer,
|
66 |
+
elementwise_affine=config.elementwise_affine,
|
67 |
+
norm_eps=config.norm_eps,
|
68 |
+
fuse_norm=config.fuse_norm,
|
69 |
+
layer_idx=layer_idx
|
70 |
+
)
|
71 |
+
self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
72 |
+
self.mlp = GSAMLP(
|
73 |
+
hidden_size=config.hidden_size,
|
74 |
+
hidden_ratio=config.hidden_ratio,
|
75 |
+
intermediate_size=config.intermediate_size,
|
76 |
+
hidden_act=config.hidden_act,
|
77 |
+
fuse_swiglu=config.fuse_swiglu
|
78 |
+
)
|
79 |
+
|
80 |
+
def forward(
|
81 |
+
self,
|
82 |
+
hidden_states: torch.Tensor,
|
83 |
+
attention_mask: Optional[torch.Tensor] = None,
|
84 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
85 |
+
use_cache: Optional[bool] = False,
|
86 |
+
output_attentions: Optional[bool] = False,
|
87 |
+
**kwargs: Unpack[Dict]
|
88 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
89 |
+
residual = hidden_states
|
90 |
+
hidden_states = self.attn_norm(hidden_states)
|
91 |
+
hidden_states, attentions, past_key_values = self.attn(
|
92 |
+
hidden_states=hidden_states,
|
93 |
+
attention_mask=attention_mask,
|
94 |
+
past_key_values=past_key_values,
|
95 |
+
use_cache=use_cache,
|
96 |
+
output_attentions=output_attentions,
|
97 |
+
**kwargs
|
98 |
+
)
|
99 |
+
if self.config.fuse_norm:
|
100 |
+
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
|
101 |
+
else:
|
102 |
+
hidden_states = residual + hidden_states
|
103 |
+
residual = hidden_states
|
104 |
+
hidden_states = self.mlp_norm(hidden_states)
|
105 |
+
hidden_states = self.mlp(hidden_states, **kwargs)
|
106 |
+
hidden_states = residual + hidden_states
|
107 |
+
|
108 |
+
outputs = (hidden_states, attentions, past_key_values)
|
109 |
+
|
110 |
+
return outputs
|
111 |
+
|
112 |
+
|
113 |
+
class GSAPreTrainedModel(PreTrainedModel):
|
114 |
+
|
115 |
+
config_class = GSAConfig
|
116 |
+
base_model_prefix = 'model'
|
117 |
+
supports_gradient_checkpointing = True
|
118 |
+
_no_split_modules = ['GSABlock']
|
119 |
+
_supports_cache_class = True
|
120 |
+
|
121 |
+
def __init__(self, *inputs, **kwargs):
|
122 |
+
super().__init__(*inputs, **kwargs)
|
123 |
+
|
124 |
+
def _init_weights(
|
125 |
+
self,
|
126 |
+
module: nn.Module,
|
127 |
+
prenorm_residual_strategy: Optional[str] = 'rescale',
|
128 |
+
num_residuals_per_layer: int = 2,
|
129 |
+
):
|
130 |
+
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
131 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
132 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
133 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
134 |
+
if module.bias is not None:
|
135 |
+
nn.init.zeros_(module.bias)
|
136 |
+
elif isinstance(module, nn.Embedding):
|
137 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
138 |
+
elif hasattr(module, 'reset_parameters'):
|
139 |
+
module.reset_parameters()
|
140 |
+
|
141 |
+
if prenorm_residual_strategy is not None:
|
142 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
143 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
144 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
145 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
146 |
+
#
|
147 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
148 |
+
p = None
|
149 |
+
if hasattr(module, 'o_proj'):
|
150 |
+
p = module.o_proj.weight
|
151 |
+
elif hasattr(module, 'down_proj'):
|
152 |
+
p = module.down_proj.weight
|
153 |
+
if p is not None:
|
154 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
155 |
+
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
156 |
+
# We need to reinit p since this code could be called multiple times
|
157 |
+
# Having just p *= scale would repeatedly scale it down
|
158 |
+
if prenorm_residual_strategy == 'rescale':
|
159 |
+
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
160 |
+
with torch.no_grad():
|
161 |
+
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
|
162 |
+
elif prenorm_residual_strategy == 'zero':
|
163 |
+
nn.init.zeros_(p)
|
164 |
+
else:
|
165 |
+
raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
|
166 |
+
|
167 |
+
|
168 |
+
class GSAModel(GSAPreTrainedModel):
|
169 |
+
|
170 |
+
def __init__(self, config: GSAConfig):
|
171 |
+
super().__init__(config)
|
172 |
+
self.padding_idx = config.pad_token_id
|
173 |
+
self.vocab_size = config.vocab_size
|
174 |
+
|
175 |
+
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
176 |
+
self.layers = nn.ModuleList([GSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
177 |
+
self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
178 |
+
|
179 |
+
self.gradient_checkpointing = False
|
180 |
+
|
181 |
+
self.post_init()
|
182 |
+
|
183 |
+
def get_input_embeddings(self):
|
184 |
+
return self.embeddings
|
185 |
+
|
186 |
+
def set_input_embeddings(self, value):
|
187 |
+
self.embeddings = value
|
188 |
+
|
189 |
+
def forward(
|
190 |
+
self,
|
191 |
+
input_ids: Optional[torch.LongTensor] = None,
|
192 |
+
attention_mask: Optional[torch.Tensor] = None, # noqa
|
193 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
194 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
195 |
+
use_cache: Optional[bool] = None,
|
196 |
+
output_attentions: Optional[bool] = None,
|
197 |
+
output_hidden_states: Optional[bool] = None,
|
198 |
+
return_dict: Optional[bool] = None,
|
199 |
+
**kwargs: Unpack[Dict]
|
200 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
201 |
+
if output_attentions:
|
202 |
+
warnings.warn("`GSAModel` does not `output_attentions` now, setting it to `False`.")
|
203 |
+
output_attentions = False
|
204 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
205 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
206 |
+
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
|
207 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
208 |
+
|
209 |
+
# retrieve input_ids and inputs_embeds
|
210 |
+
if input_ids is not None and inputs_embeds is not None:
|
211 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
212 |
+
if input_ids is None and inputs_embeds is None:
|
213 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
214 |
+
|
215 |
+
if inputs_embeds is None:
|
216 |
+
inputs_embeds = self.embeddings(input_ids)
|
217 |
+
hidden_states = inputs_embeds
|
218 |
+
|
219 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
220 |
+
past_key_values = Cache.from_legacy_cache(past_key_values)
|
221 |
+
|
222 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
223 |
+
logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
224 |
+
use_cache = False
|
225 |
+
|
226 |
+
all_hidden_states = () if output_hidden_states else None
|
227 |
+
all_attns = () if output_attentions else None
|
228 |
+
for layer in self.layers:
|
229 |
+
if output_hidden_states:
|
230 |
+
all_hidden_states += (hidden_states,)
|
231 |
+
|
232 |
+
if self.gradient_checkpointing and self.training:
|
233 |
+
hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
|
234 |
+
layer.__call__,
|
235 |
+
hidden_states,
|
236 |
+
attention_mask,
|
237 |
+
past_key_values,
|
238 |
+
use_cache,
|
239 |
+
output_attentions,
|
240 |
+
**kwargs
|
241 |
+
)
|
242 |
+
else:
|
243 |
+
hidden_states, attentions, past_key_values = layer(
|
244 |
+
hidden_states,
|
245 |
+
attention_mask=attention_mask,
|
246 |
+
past_key_values=past_key_values,
|
247 |
+
use_cache=use_cache,
|
248 |
+
output_attentions=output_attentions,
|
249 |
+
**kwargs
|
250 |
+
)
|
251 |
+
|
252 |
+
if output_attentions:
|
253 |
+
all_attns += (attentions,)
|
254 |
+
|
255 |
+
hidden_states = self.norm(hidden_states)
|
256 |
+
|
257 |
+
# add hidden states from the last decoder layer
|
258 |
+
if output_hidden_states:
|
259 |
+
all_hidden_states += (hidden_states,)
|
260 |
+
|
261 |
+
if not return_dict:
|
262 |
+
return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
|
263 |
+
return BaseModelOutputWithPast(
|
264 |
+
last_hidden_state=hidden_states,
|
265 |
+
past_key_values=past_key_values,
|
266 |
+
hidden_states=all_hidden_states,
|
267 |
+
attentions=all_attns
|
268 |
+
)
|
269 |
+
|
270 |
+
|
271 |
+
class GSAForCausalLM(GSAPreTrainedModel, GenerationMixin):
|
272 |
+
|
273 |
+
_tied_weights_keys = ["lm_head.weight"]
|
274 |
+
|
275 |
+
def __init__(self, config):
|
276 |
+
|
277 |
+
super().__init__(config)
|
278 |
+
self.model = GSAModel(config)
|
279 |
+
self.vocab_size = config.vocab_size
|
280 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
281 |
+
self.criterion = None
|
282 |
+
|
283 |
+
# Initialize weights and apply final processing
|
284 |
+
self.post_init()
|
285 |
+
|
286 |
+
def get_input_embeddings(self):
|
287 |
+
return self.model.embeddings
|
288 |
+
|
289 |
+
def set_input_embeddings(self, value):
|
290 |
+
self.model.embeddings = value
|
291 |
+
|
292 |
+
def get_output_embeddings(self):
|
293 |
+
return self.lm_head
|
294 |
+
|
295 |
+
def set_output_embeddings(self, new_embeddings):
|
296 |
+
self.lm_head = new_embeddings
|
297 |
+
|
298 |
+
def set_decoder(self, decoder):
|
299 |
+
self.model = decoder
|
300 |
+
|
301 |
+
def get_decoder(self):
|
302 |
+
return self.model
|
303 |
+
|
304 |
+
def generate(self, *args, **kwargs):
|
305 |
+
try:
|
306 |
+
return super().generate(*args, **kwargs)
|
307 |
+
except AttributeError as exception:
|
308 |
+
if 'past_key_values' in str(exception):
|
309 |
+
raise AttributeError(
|
310 |
+
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
|
311 |
+
f"which is not supported for {self.__class__.__name__}. "
|
312 |
+
f"Try another generation strategy instead. "
|
313 |
+
f"For the available generation strategies, check this doc: "
|
314 |
+
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
|
315 |
+
)
|
316 |
+
else:
|
317 |
+
raise exception
|
318 |
+
|
319 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
320 |
+
def prepare_inputs_for_generation(
|
321 |
+
self,
|
322 |
+
input_ids: torch.LongTensor = None,
|
323 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
324 |
+
attention_mask: Optional[torch.Tensor] = None,
|
325 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
326 |
+
use_cache: bool = True,
|
327 |
+
logits_to_keep: Optional[int] = None,
|
328 |
+
**kwargs
|
329 |
+
):
|
330 |
+
# only last token for `inputs_ids` if the `past_key_values` is not empty.
|
331 |
+
if past_key_values is not None and len(past_key_values) > 0:
|
332 |
+
input_ids = input_ids[:, -1:]
|
333 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
334 |
+
if inputs_embeds is not None and len(past_key_values) == 0:
|
335 |
+
model_inputs = {'inputs_embeds': inputs_embeds}
|
336 |
+
else:
|
337 |
+
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
338 |
+
# recompiles graphs as the stride of the inputs is a guard.
|
339 |
+
# Ref: https://github.com/huggingface/transformers/pull/29114
|
340 |
+
# TODO: use `next_tokens` directly instead.
|
341 |
+
model_inputs = {'input_ids': input_ids.contiguous()}
|
342 |
+
|
343 |
+
if logits_to_keep is not None:
|
344 |
+
model_inputs['logits_to_keep'] = logits_to_keep
|
345 |
+
|
346 |
+
model_inputs.update({
|
347 |
+
'past_key_values': past_key_values,
|
348 |
+
'use_cache': use_cache,
|
349 |
+
'attention_mask': attention_mask,
|
350 |
+
})
|
351 |
+
return model_inputs
|
352 |
+
|
353 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
354 |
+
def forward(
|
355 |
+
self,
|
356 |
+
input_ids: torch.LongTensor = None,
|
357 |
+
attention_mask: Optional[torch.Tensor] = None,
|
358 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
359 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
360 |
+
labels: Optional[torch.LongTensor] = None,
|
361 |
+
use_cache: Optional[bool] = None,
|
362 |
+
output_attentions: Optional[bool] = None,
|
363 |
+
output_hidden_states: Optional[bool] = None,
|
364 |
+
return_dict: Optional[bool] = None,
|
365 |
+
logits_to_keep: Optional[int] = 0,
|
366 |
+
**kwargs: Unpack[Dict]
|
367 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
368 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
369 |
+
output_hidden_states = (
|
370 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
371 |
+
)
|
372 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
373 |
+
|
374 |
+
outputs = self.model(
|
375 |
+
input_ids=input_ids,
|
376 |
+
attention_mask=attention_mask,
|
377 |
+
inputs_embeds=inputs_embeds,
|
378 |
+
past_key_values=past_key_values,
|
379 |
+
use_cache=use_cache,
|
380 |
+
output_attentions=output_attentions,
|
381 |
+
output_hidden_states=output_hidden_states,
|
382 |
+
return_dict=return_dict,
|
383 |
+
**kwargs
|
384 |
+
)
|
385 |
+
|
386 |
+
hidden_states = outputs[0]
|
387 |
+
fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
|
388 |
+
|
389 |
+
loss, logits = None, None
|
390 |
+
if not fuse_linear_and_cross_entropy or labels is None:
|
391 |
+
logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
|
392 |
+
if labels is not None:
|
393 |
+
if getattr(self, 'criterion', None) is None:
|
394 |
+
if fuse_linear_and_cross_entropy:
|
395 |
+
criterion = FusedLinearCrossEntropyLoss()
|
396 |
+
elif self.config.fuse_cross_entropy:
|
397 |
+
criterion = FusedCrossEntropyLoss(inplace_backward=True)
|
398 |
+
else:
|
399 |
+
criterion = nn.CrossEntropyLoss()
|
400 |
+
else:
|
401 |
+
criterion = self.criterion
|
402 |
+
# Enable model parallelism
|
403 |
+
labels = labels.to(hidden_states.device)
|
404 |
+
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
|
405 |
+
if fuse_linear_and_cross_entropy:
|
406 |
+
loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
|
407 |
+
else:
|
408 |
+
loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
|
409 |
+
|
410 |
+
if not return_dict:
|
411 |
+
output = (logits,) + outputs[1:]
|
412 |
+
return (loss,) + output if loss is not None else output
|
413 |
+
|
414 |
+
return CausalLMOutputWithPast(
|
415 |
+
loss=loss,
|
416 |
+
logits=logits,
|
417 |
+
past_key_values=outputs.past_key_values,
|
418 |
+
hidden_states=outputs.hidden_states,
|
419 |
+
attentions=outputs.attentions,
|
420 |
+
)
|
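
A minimal sketch, not part of the upload, wiring GSAForCausalLM end to end with a tiny config. The sizes are illustrative, the fla Triton kernels are assumed to be available on the current device, and the fused options are disabled to keep the example plain.

import torch
from fla.models.gsa.configuration_gsa import GSAConfig
from fla.models.gsa.modeling_gsa import GSAForCausalLM

config = GSAConfig(
    hidden_size=256, num_hidden_layers=2, num_heads=4, num_slots=32,
    vocab_size=1000, fuse_cross_entropy=False, fuse_norm=False, fuse_swiglu=False,
)
model = GSAForCausalLM(config)
input_ids = torch.randint(0, config.vocab_size, (1, 16))
# labels are shifted inside forward(), so the raw input_ids can be passed directly
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.logits.shape)  # scalar loss, torch.Size([1, 16, 1000])
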
fla/models/hgrn/modeling_hgrn.py
ADDED
@@ -0,0 +1,420 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import math
|
6 |
+
import warnings
|
7 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
from transformers.generation import GenerationMixin
|
13 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
14 |
+
from transformers.modeling_utils import PreTrainedModel
|
15 |
+
from transformers.utils import logging
|
16 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
17 |
+
|
18 |
+
from fla.layers.attn import Attention
|
19 |
+
from fla.layers.hgrn import HGRNAttention
|
20 |
+
from fla.models.hgrn.configuration_hgrn import HGRNConfig
|
21 |
+
from fla.models.utils import Cache
|
22 |
+
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
|
23 |
+
from fla.modules import GatedMLP as HGRNMLP
|
24 |
+
from fla.modules import RMSNorm
|
25 |
+
|
26 |
+
if TYPE_CHECKING:
|
27 |
+
from transformers.processing_utils import Unpack
|
28 |
+
|
29 |
+
logger = logging.get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
class HGRNBlock(nn.Module):
|
33 |
+
def __init__(self, config: HGRNConfig, layer_idx: int):
|
34 |
+
super().__init__()
|
35 |
+
|
36 |
+
self.config = config
|
37 |
+
self.layer_idx = layer_idx
|
38 |
+
|
39 |
+
self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
40 |
+
if config.attn is not None and layer_idx in config.attn['layers']:
|
41 |
+
self.attn = Attention(
|
42 |
+
hidden_size=config.hidden_size,
|
43 |
+
num_heads=config.attn['num_heads'],
|
44 |
+
num_kv_heads=config.attn['num_kv_heads'],
|
45 |
+
qkv_bias=config.attn['qkv_bias'],
|
46 |
+
window_size=config.attn['window_size'],
|
47 |
+
rope_theta=config.attn['rope_theta'],
|
48 |
+
max_position_embeddings=config.max_position_embeddings,
|
49 |
+
layer_idx=layer_idx
|
50 |
+
)
|
51 |
+
else:
|
52 |
+
self.attn = HGRNAttention(
|
53 |
+
mode=config.attn_mode,
|
54 |
+
hidden_size=config.hidden_size,
|
55 |
+
expand_ratio=config.expand_ratio,
|
56 |
+
use_short_conv=config.use_short_conv,
|
57 |
+
conv_size=config.conv_size,
|
58 |
+
elementwise_affine=config.elementwise_affine,
|
59 |
+
norm_eps=config.norm_eps,
|
60 |
+
layer_idx=layer_idx
|
61 |
+
)
|
62 |
+
self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
63 |
+
self.mlp = HGRNMLP(
|
64 |
+
hidden_size=config.hidden_size,
|
65 |
+
hidden_ratio=config.hidden_ratio,
|
66 |
+
intermediate_size=config.intermediate_size,
|
67 |
+
hidden_act=config.hidden_act,
|
68 |
+
fuse_swiglu=config.fuse_swiglu
|
69 |
+
)
|
70 |
+
|
71 |
+
def forward(
|
72 |
+
self,
|
73 |
+
hidden_states: torch.Tensor,
|
74 |
+
attention_mask: Optional[torch.Tensor] = None,
|
75 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
76 |
+
use_cache: Optional[bool] = False,
|
77 |
+
output_attentions: Optional[bool] = False,
|
78 |
+
lower_bound: Optional[torch.Tensor] = False,
|
79 |
+
**kwargs: Unpack[Dict]
|
80 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
81 |
+
residual = hidden_states
|
82 |
+
hidden_states = self.attn_norm(hidden_states)
|
83 |
+
hidden_states, attentions, past_key_values = self.attn(
|
84 |
+
hidden_states=hidden_states,
|
85 |
+
attention_mask=attention_mask,
|
86 |
+
past_key_values=past_key_values,
|
87 |
+
use_cache=use_cache,
|
88 |
+
output_attentions=output_attentions,
|
89 |
+
lower_bound=lower_bound,
|
90 |
+
**kwargs
|
91 |
+
)
|
92 |
+
if self.config.fuse_norm:
|
93 |
+
hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
|
94 |
+
else:
|
95 |
+
hidden_states = residual + hidden_states
|
96 |
+
residual = hidden_states
|
97 |
+
hidden_states = self.mlp_norm(hidden_states)
|
98 |
+
hidden_states = self.mlp(hidden_states, **kwargs)
|
99 |
+
hidden_states = residual + hidden_states
|
100 |
+
|
101 |
+
outputs = (hidden_states, attentions, past_key_values)
|
102 |
+
|
103 |
+
return outputs
|
104 |
+
|
105 |
+
|
106 |
+
class HGRNPreTrainedModel(PreTrainedModel):
|
107 |
+
|
108 |
+
config_class = HGRNConfig
|
109 |
+
base_model_prefix = 'model'
|
110 |
+
supports_gradient_checkpointing = True
|
111 |
+
_no_split_modules = ['HGRNBlock']
|
112 |
+
_supports_cache_class = True
|
113 |
+
|
114 |
+
def __init__(self, *inputs, **kwargs):
|
115 |
+
super().__init__(*inputs, **kwargs)
|
116 |
+
|
117 |
+
def _init_weights(
|
118 |
+
self,
|
119 |
+
module: nn.Module,
|
120 |
+
prenorm_residual_strategy: Optional[str] = 'rescale',
|
121 |
+
num_residuals_per_layer: int = 2,
|
122 |
+
):
|
123 |
+
if isinstance(module, (nn.Linear, nn.Conv1d)):
|
124 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
125 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
126 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
127 |
+
if module.bias is not None:
|
128 |
+
nn.init.zeros_(module.bias)
|
129 |
+
elif isinstance(module, nn.Embedding):
|
130 |
+
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
|
131 |
+
elif hasattr(module, 'reset_parameters'):
|
132 |
+
module.reset_parameters()
|
133 |
+
|
134 |
+
if prenorm_residual_strategy is not None:
|
135 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
136 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
137 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
138 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
139 |
+
#
|
140 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
141 |
+
p = None
|
142 |
+
if hasattr(module, 'o_proj'):
|
143 |
+
p = module.o_proj.weight
|
144 |
+
elif hasattr(module, 'down_proj'):
|
145 |
+
p = module.down_proj.weight
|
146 |
+
if p is not None:
|
147 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
148 |
+
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
149 |
+
# We need to reinit p since this code could be called multiple times
|
150 |
+
# Having just p *= scale would repeatedly scale it down
|
151 |
+
if prenorm_residual_strategy == 'rescale':
|
152 |
+
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
153 |
+
with torch.no_grad():
|
154 |
+
p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
|
155 |
+
elif prenorm_residual_strategy == 'zero':
|
156 |
+
nn.init.zeros_(p)
|
157 |
+
else:
|
158 |
+
raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
|
159 |
+
|
160 |
+
|
161 |
+
class HGRNModel(HGRNPreTrainedModel):
|
162 |
+
|
163 |
+
def __init__(self, config: HGRNConfig):
|
164 |
+
super().__init__(config)
|
165 |
+
self.padding_idx = config.pad_token_id
|
166 |
+
self.vocab_size = config.vocab_size
|
167 |
+
|
168 |
+
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
169 |
+
if config.use_lower_bound:
|
170 |
+
self.lower_bounds = nn.Parameter(torch.zeros(config.num_hidden_layers, config.hidden_size))
|
171 |
+
self.layers = nn.ModuleList([HGRNBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
172 |
+
self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
173 |
+
|
174 |
+
self.gradient_checkpointing = False
|
175 |
+
|
176 |
+
self.post_init()
|
177 |
+
|
178 |
+
def get_input_embeddings(self):
|
179 |
+
return self.embeddings
|
180 |
+
|
181 |
+
def set_input_embeddings(self, value):
|
182 |
+
self.embeddings = value
|
183 |
+
|
184 |
+
def forward(
|
185 |
+
self,
|
186 |
+
input_ids: Optional[torch.LongTensor] = None,
|
187 |
+
attention_mask: Optional[torch.Tensor] = None, # noqa
|
188 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
189 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
190 |
+
use_cache: Optional[bool] = None,
|
191 |
+
output_attentions: Optional[bool] = None,
|
192 |
+
output_hidden_states: Optional[bool] = None,
|
193 |
+
return_dict: Optional[bool] = None,
|
194 |
+
**kwargs: Unpack[Dict]
|
195 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
196 |
+
if output_attentions:
|
197 |
+
warnings.warn("`HGRNModel` does not `output_attentions` now, setting it to `False`.")
|
198 |
+
output_attentions = False
|
199 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
200 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
201 |
+
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
|
202 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
203 |
+
|
204 |
+
# retrieve input_ids and inputs_embeds
|
205 |
+
if input_ids is not None and inputs_embeds is not None:
|
206 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
207 |
+
if input_ids is None and inputs_embeds is None:
|
208 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
209 |
+
|
210 |
+
if inputs_embeds is None:
|
211 |
+
inputs_embeds = self.embeddings(input_ids)
|
212 |
+
hidden_states = inputs_embeds
|
213 |
+
|
214 |
+
if use_cache and not isinstance(past_key_values, Cache):
|
215 |
+
past_key_values = Cache.from_legacy_cache(past_key_values)
|
216 |
+
|
217 |
+
if self.gradient_checkpointing and self.training and use_cache:
|
218 |
+
logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
219 |
+
use_cache = False
|
220 |
+
|
221 |
+
all_hidden_states = () if output_hidden_states else None
|
222 |
+
all_attns = () if output_attentions else None
|
223 |
+
|
224 |
+
if self.config.use_lower_bound:
|
225 |
+
lower_bounds = self.lower_bounds.softmax(0)
|
226 |
+
lower_bounds = lower_bounds.cumsum(0) - lower_bounds[0]
|
227 |
+
for i, layer in enumerate(self.layers):
|
228 |
+
if output_hidden_states:
|
229 |
+
all_hidden_states += (hidden_states,)
|
230 |
+
|
231 |
+
lower_bound = lower_bounds[i] if self.config.use_lower_bound else None
|
232 |
+
if self.gradient_checkpointing and self.training:
|
233 |
+
hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
|
234 |
+
layer.__call__,
|
235 |
+
hidden_states,
|
236 |
+
attention_mask,
|
237 |
+
past_key_values,
|
238 |
+
use_cache,
|
239 |
+
output_attentions,
|
240 |
+
lower_bound,
|
241 |
+
**kwargs
|
242 |
+
)
|
243 |
+
else:
|
244 |
+
hidden_states, attentions, past_key_values = layer(
|
245 |
+
hidden_states,
|
246 |
+
attention_mask=attention_mask,
|
247 |
+
past_key_values=past_key_values,
|
248 |
+
use_cache=use_cache,
|
249 |
+
output_attentions=output_attentions,
|
250 |
+
lower_bound=lower_bound,
|
251 |
+
**kwargs
|
252 |
+
)
|
253 |
+
|
254 |
+
if output_attentions:
|
255 |
+
all_attns += (attentions,)
|
256 |
+
|
257 |
+
hidden_states = self.norm(hidden_states)
|
258 |
+
|
259 |
+
# add hidden states from the last decoder layer
|
260 |
+
if output_hidden_states:
|
261 |
+
all_hidden_states += (hidden_states,)
|
262 |
+
|
263 |
+
if not return_dict:
|
264 |
+
return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
|
265 |
+
return BaseModelOutputWithPast(
|
266 |
+
last_hidden_state=hidden_states,
|
267 |
+
past_key_values=past_key_values,
|
268 |
+
hidden_states=all_hidden_states,
|
269 |
+
attentions=all_attns
|
270 |
+
)
|
271 |
+
|
272 |
+
|
273 |
+
class HGRNForCausalLM(HGRNPreTrainedModel, GenerationMixin):
|
274 |
+
|
275 |
+
_tied_weights_keys = ["lm_head.weight"]
|
276 |
+
|
277 |
+
def __init__(self, config):
|
278 |
+
super().__init__(config)
|
279 |
+
self.model = HGRNModel(config)
|
280 |
+
self.vocab_size = config.vocab_size
|
281 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
282 |
+
self.criterion = None
|
283 |
+
|
284 |
+
# Initialize weights and apply final processing
|
285 |
+
self.post_init()
|
286 |
+
|
287 |
+
def get_input_embeddings(self):
|
288 |
+
return self.model.embeddings
|
289 |
+
|
290 |
+
def set_input_embeddings(self, value):
|
291 |
+
self.model.embeddings = value
|
292 |
+
|
293 |
+
def get_output_embeddings(self):
|
294 |
+
return self.lm_head
|
295 |
+
|
296 |
+
def set_output_embeddings(self, new_embeddings):
|
297 |
+
self.lm_head = new_embeddings
|
298 |
+
|
299 |
+
def set_decoder(self, decoder):
|
300 |
+
self.model = decoder
|
301 |
+
|
302 |
+
def get_decoder(self):
|
303 |
+
return self.model
|
304 |
+
|
305 |
+
def generate(self, *args, **kwargs):
|
306 |
+
try:
|
307 |
+
return super().generate(*args, **kwargs)
|
308 |
+
except AttributeError as exception:
|
309 |
+
if 'past_key_values' in str(exception):
|
310 |
+
raise AttributeError(
|
311 |
+
f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
|
312 |
+
f"which is not supported for {self.__class__.__name__}. "
|
313 |
+
f"Try another generation strategy instead. "
|
314 |
+
f"For the available generation strategies, check this doc: "
|
315 |
+
f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
|
316 |
+
)
|
317 |
+
else:
|
318 |
+
raise exception
|
319 |
+
|
320 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
321 |
+
def prepare_inputs_for_generation(
|
322 |
+
self,
|
323 |
+
input_ids: torch.LongTensor = None,
|
324 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
325 |
+
attention_mask: Optional[torch.Tensor] = None,
|
326 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
327 |
+
use_cache: bool = True,
|
328 |
+
logits_to_keep: Optional[int] = None,
|
329 |
+
**kwargs: Unpack[Dict]
|
330 |
+
):
|
331 |
+
# only last token for `inputs_ids` if the `past_key_values` is not empty.
|
332 |
+
if past_key_values is not None and len(past_key_values) > 0:
|
333 |
+
input_ids = input_ids[:, -1:]
|
334 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
335 |
+
if inputs_embeds is not None and len(past_key_values) == 0:
|
336 |
+
model_inputs = {'inputs_embeds': inputs_embeds}
|
337 |
+
else:
|
338 |
+
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
339 |
+
# recompiles graphs as the stride of the inputs is a guard.
|
340 |
+
# Ref: https://github.com/huggingface/transformers/pull/29114
|
341 |
+
# TODO: use `next_tokens` directly instead.
|
342 |
+
model_inputs = {'input_ids': input_ids.contiguous()}
|
343 |
+
|
344 |
+
if logits_to_keep is not None:
|
345 |
+
model_inputs['logits_to_keep'] = logits_to_keep
|
346 |
+
|
347 |
+
model_inputs.update({
|
348 |
+
'past_key_values': past_key_values,
|
349 |
+
'use_cache': use_cache,
|
350 |
+
'attention_mask': attention_mask,
|
351 |
+
})
|
352 |
+
return model_inputs
|
353 |
+
|
354 |
+
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
355 |
+
def forward(
|
356 |
+
self,
|
357 |
+
input_ids: torch.LongTensor = None,
|
358 |
+
attention_mask: Optional[torch.Tensor] = None,
|
359 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
360 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
361 |
+
labels: Optional[torch.LongTensor] = None,
|
362 |
+
use_cache: Optional[bool] = None,
|
363 |
+
output_attentions: Optional[bool] = None,
|
364 |
+
output_hidden_states: Optional[bool] = None,
|
365 |
+
return_dict: Optional[bool] = None,
|
366 |
+
logits_to_keep: Optional[int] = 0,
|
367 |
+
**kwargs: Unpack[Dict]
|
368 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
369 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
370 |
+
output_hidden_states = (
|
371 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
372 |
+
)
|
373 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
374 |
+
|
375 |
+
outputs = self.model(
|
376 |
+
input_ids=input_ids,
|
377 |
+
attention_mask=attention_mask,
|
378 |
+
inputs_embeds=inputs_embeds,
|
379 |
+
past_key_values=past_key_values,
|
380 |
+
use_cache=use_cache,
|
381 |
+
output_attentions=output_attentions,
|
382 |
+
output_hidden_states=output_hidden_states,
|
383 |
+
return_dict=return_dict,
|
384 |
+
**kwargs
|
385 |
+
)
|
386 |
+
|
387 |
+
hidden_states = outputs[0]
|
388 |
+
fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
|
389 |
+
|
390 |
+
loss, logits = None, None
|
391 |
+
if not fuse_linear_and_cross_entropy or labels is None:
|
392 |
+
logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
|
393 |
+
if labels is not None:
|
394 |
+
if getattr(self, 'criterion', None) is None:
|
395 |
+
if fuse_linear_and_cross_entropy:
|
396 |
+
criterion = FusedLinearCrossEntropyLoss()
|
397 |
+
elif self.config.fuse_cross_entropy:
|
398 |
+
criterion = FusedCrossEntropyLoss(inplace_backward=True)
|
399 |
+
else:
|
400 |
+
criterion = nn.CrossEntropyLoss()
|
401 |
+
else:
|
402 |
+
criterion = self.criterion
|
403 |
+
labels = labels.to(hidden_states.device)
|
404 |
+
labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
|
405 |
+
if fuse_linear_and_cross_entropy:
|
406 |
+
loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
|
407 |
+
else:
|
408 |
+
loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
|
409 |
+
|
410 |
+
if not return_dict:
|
411 |
+
output = (logits,) + outputs[1:]
|
412 |
+
return (loss,) + output if loss is not None else output
|
413 |
+
|
414 |
+
return CausalLMOutputWithPast(
|
415 |
+
loss=loss,
|
416 |
+
logits=logits,
|
417 |
+
past_key_values=outputs.past_key_values,
|
418 |
+
hidden_states=outputs.hidden_states,
|
419 |
+
attentions=outputs.attentions,
|
420 |
+
)
|
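
A standalone sketch (not from the repository) of the lower-bound schedule computed in HGRNModel.forward above when `use_lower_bound` is enabled: a softmax across layers followed by a cumulative sum, shifted so the first layer's bound is zero and deeper layers receive monotonically larger bounds. Shapes are illustrative.

import torch

num_layers, hidden_size = 4, 8
lower_bounds = torch.zeros(num_layers, hidden_size)   # mirrors the learned nn.Parameter
bounds = lower_bounds.softmax(0)                       # normalize across layers
bounds = bounds.cumsum(0) - bounds[0]                  # monotone in depth, starts at 0
print(bounds[0].abs().max().item(), bounds[-1][0].item())  # 0.0 at layer 0, 0.75 at the last layer
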
fla/models/linear_attn/configuration_linear_attn.py
ADDED
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-

from typing import Dict, Optional

from transformers.configuration_utils import PretrainedConfig


class LinearAttentionConfig(PretrainedConfig):

    model_type = 'linear_attn'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        attn_mode: str = "fused_chunk",
        hidden_size: int = 2048,
        expand_k: int = 1,
        expand_v: int = 1,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        num_hidden_layers: int = 24,
        num_heads: int = 4,
        num_kv_heads: Optional[int] = None,
        feature_map: str = "elementwise_product",
        tie_feature_map_qk: bool = False,
        norm_q: bool = False,
        norm_k: bool = False,
        norm_feature_map: bool = False,
        hidden_act: str = "swish",
        max_position_embeddings: int = 2048,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        attn: Optional[Dict] = None,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.006,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs
    ):
        self.attn_mode = attn_mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.feature_map = feature_map
        self.tie_feature_map_qk = tie_feature_map_qk
        self.norm_q = norm_q
        self.norm_k = norm_k
        self.norm_feature_map = norm_feature_map
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        if attn is not None:
            if not isinstance(attn, Dict):
                raise ValueError("attn must be a dictionary")
            if 'layers' not in attn:
                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
            if 'num_heads' not in attn:
                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
            attn['qkv_bias'] = attn.get('qkv_bias', False)
            attn['window_size'] = attn.get('window_size', None)
            attn['rope_theta'] = attn.get('rope_theta', 10000.)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
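
A minimal sketch, not part of the upload: instantiating the config above and checking that it round-trips through the standard PretrainedConfig serialization. The values are illustrative only.

from fla.models.linear_attn.configuration_linear_attn import LinearAttentionConfig

config = LinearAttentionConfig(attn_mode="fused_chunk", hidden_size=1024, num_heads=8)
restored = LinearAttentionConfig.from_dict(config.to_dict())
assert restored.attn_mode == "fused_chunk" and restored.hidden_size == 1024
print(restored.model_type)  # linear_attn
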
fla/models/nsa/modeling_nsa.py
ADDED
@@ -0,0 +1,398 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
from __future__ import annotations
|
4 |
+
|
5 |
+
import math
|
6 |
+
import warnings
|
7 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.utils.checkpoint
|
12 |
+
from transformers.generation import GenerationMixin
|
13 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
14 |
+
from transformers.modeling_utils import PreTrainedModel
|
15 |
+
from transformers.utils import logging
|
16 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
17 |
+
|
18 |
+
from fla.layers.nsa import NativeSparseAttention
|
19 |
+
from fla.models.nsa.configuration_nsa import NSAConfig
|
20 |
+
from fla.models.utils import Cache
|
21 |
+
from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
|
22 |
+
from fla.modules import GatedMLP as NSAMLP
|
23 |
+
from fla.modules import RMSNorm
|
24 |
+
|
25 |
+
if TYPE_CHECKING:
|
26 |
+
from transformers.processing_utils import Unpack
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
class NSABlock(nn.Module):
|
32 |
+
def __init__(self, config: NSAConfig, layer_idx: int):
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
self.config = config
|
36 |
+
self.layer_idx = layer_idx
|
37 |
+
|
38 |
+
self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
39 |
+
self.attn = NativeSparseAttention(
|
40 |
+
hidden_size=config.hidden_size,
|
41 |
+
num_heads=config.num_heads,
|
42 |
+
num_kv_heads=config.num_kv_heads,
|
43 |
+
qkv_bias=config.qkv_bias,
|
44 |
+
block_size=config.block_size,
|
45 |
+
block_counts=config.block_counts,
|
46 |
+
window_size=config.window_size,
|
47 |
+
rope_theta=config.rope_theta,
|
48 |
+
max_position_embeddings=config.max_position_embeddings,
|
49 |
+
layer_idx=layer_idx
|
50 |
+
)
|
51 |
+
self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
|
52 |
+
self.mlp = NSAMLP(
|
53 |
+
hidden_size=config.hidden_size,
|
54 |
+
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            **kwargs
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs


class NSAPreTrainedModel(PreTrainedModel):

    config_class = NSAConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True
    _no_split_modules = ['NSABlock']
    _supports_cache_class = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        prenorm_residual_strategy: Optional[str] = 'rescale',
        num_residuals_per_layer: int = 2,
    ):
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if prenorm_residual_strategy is not None:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            p = None
            if hasattr(module, 'o_proj'):
                p = module.o_proj.weight
            elif hasattr(module, 'down_proj'):
                p = module.down_proj.weight
            if p is not None:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                if prenorm_residual_strategy == 'rescale':
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
                elif prenorm_residual_strategy == 'zero':
                    nn.init.zeros_(p)
                else:
                    raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")


class NSAModel(NSAPreTrainedModel):

    def __init__(self, config: NSAConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([NSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if output_attentions:
            warnings.warn("`NSAModel` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if use_cache and not isinstance(past_key_values, Cache):
            past_key_values = Cache.from_legacy_cache(past_key_values)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values,
                    use_cache,
                    output_attentions,
                    **kwargs
                )
            else:
                hidden_states, attentions, past_key_values = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    **kwargs
                )

            if output_attentions:
                all_attns += (attentions,)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )


class NSAForCausalLM(NSAPreTrainedModel, GenerationMixin):

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = NSAModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.criterion = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
        except AttributeError as exception:
            if 'past_key_values' in str(exception):
                raise AttributeError(
                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                    f"which is not supported for {self.__class__.__name__}. "
                    f"Try another generation strategy instead. "
                    f"For the available generation strategies, check this doc: "
                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
                )
            else:
                raise exception

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        # only last token for `inputs_ids` if the `past_key_values` is not empty.
        if past_key_values is not None and len(past_key_values) > 0:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(past_key_values) == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        if logits_to_keep is not None:
            model_inputs['logits_to_keep'] = logits_to_keep

        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
        })
        return model_inputs

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Dict]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training

        loss, logits = None, None
        if not fuse_linear_and_cross_entropy or labels is None:
            logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
        if labels is not None:
            if getattr(self, 'criterion', None) is None:
                if fuse_linear_and_cross_entropy:
                    criterion = FusedLinearCrossEntropyLoss()
                elif self.config.fuse_cross_entropy:
                    criterion = FusedCrossEntropyLoss(inplace_backward=True)
                else:
                    criterion = nn.CrossEntropyLoss()
            else:
                criterion = self.criterion
            labels = labels.to(hidden_states.device)
            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
            if fuse_linear_and_cross_entropy:
                loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
            else:
                loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
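For orientation, here is a minimal usage sketch of the classes added above. It assumes `NSAConfig` and `NSAForCausalLM` are exported from `fla.models.nsa` (mirroring the other model packages) and that a GPU is available; the config values are placeholders, not recommended settings.

# Hedged sketch: tiny NSA model, one training-style forward and one generate call.
import torch
from fla.models.nsa import NSAConfig, NSAForCausalLM  # assumed export path

config = NSAConfig(hidden_size=256, num_hidden_layers=2, vocab_size=32000)  # illustrative values only
model = NSAForCausalLM(config).to("cuda")

input_ids = torch.randint(0, config.vocab_size, (1, 128), device="cuda")
# Passing labels exercises the loss path shown above; the fused linear + cross-entropy
# branch is only taken in training mode when `fuse_cross_entropy` is enabled.
outputs = model(input_ids=input_ids, labels=input_ids)
print(outputs.loss)

# Decoding goes through GenerationMixin.generate; the wrapper above re-raises
# unsupported cache-manipulating strategies with a clearer error message.
tokens = model.generate(input_ids, max_new_tokens=16)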
flame/__init__.py
ADDED
@@ -0,0 +1 @@
__version__ = "0.1.0"
flame/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (163 Bytes)
flame/__pycache__/train.cpython-311.pyc
ADDED
Binary file (38.8 kB)
flame/models/parallelize_fla.py
ADDED
@@ -0,0 +1,550 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module
)

from fla.modules.fused_linear_cross_entropy import LinearLossParallel
from fla.modules.mlp import SwiGLULinearParallel
from fla.modules.parallel import PrepareModuleWeight
from torchtitan.config_manager import TORCH_DTYPE_MAP, JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.tools.logging import logger


def parallelize_fla(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.experimental.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")
        enable_float8_linear = "float8" in job_config.model.converters
        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8=enable_float8_linear,
            enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-block compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.training.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.experimental.enable_compiled_autograd,
        )


class TPPlan:
    def __init__(
        self,
        model=None,
        loss_parallel=False,
        enable_float8=False,
    ):
        self.model = model
        self.loss_parallel = loss_parallel
        self.enable_float8 = enable_float8
        self.base_model_prefix = getattr(model, "base_model_prefix", "model")

        # TODO(vkuzo): once float8 configuration supports delayed scaling,
        # add a check here to enforce supported float8 all-gather configurations
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        try:
            from torchao.float8.float8_tensor_parallel import (
                Float8ColwiseParallel,
                Float8RowwiseParallel,
                PrepareFloat8ModuleInput
            )
        except ImportError:
            Float8ColwiseParallel = None
            Float8RowwiseParallel = None
            PrepareFloat8ModuleInput = None
        if self.enable_float8 and Float8ColwiseParallel is not None:
            self.rowwise_parallel = Float8RowwiseParallel
            self.colwise_parallel = Float8ColwiseParallel
            self.prepare_module_input = PrepareFloat8ModuleInput
            self.prepare_module_output = PrepareModuleOutput
        else:
            self.rowwise_parallel = RowwiseParallel
            self.colwise_parallel = ColwiseParallel
            self.prepare_module_input = PrepareModuleInput
            self.prepare_module_output = PrepareModuleOutput

    @property
    def model_plan(self):
        plans = {
            f"{self.base_model_prefix}.embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            f"{self.base_model_prefix}.norm": SequenceParallel(),
        }
        if self.loss_parallel:
            plans.update(
                {
                    "lm_head": ColwiseParallel(
                        input_layouts=Shard(1),
                        output_layouts=Shard(-1) if self.loss_parallel else Replicate(),
                        use_local_output=not self.loss_parallel,
                    ),
                }
            )
        else:
            plans.update(
                {
                    "lm_head": PrepareModuleWeight(layouts=Replicate()),
                    "criterion": LinearLossParallel(),
                }
            )
        return plans

    @property
    def layer_plan(self):
        return {
            "attn_norm": SequenceParallel(),
            **self.attn_plan,
            "mlp_norm": SequenceParallel(),
            **self.mlp_plan,
        }

    @property
    def attn_plan(self):
        raise NotImplementedError(
            f"TP plans for token mixing layers of {self.model.config.model_type} not implemented"
        )

    @property
    def mlp_plan(self):
        return {
            "mlp": self.prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "mlp.gate_proj": self.colwise_parallel(),
            "mlp.up_proj": self.colwise_parallel(),
            "mlp.down_proj": self.rowwise_parallel(output_layouts=Shard(1)),
            "mlp.swiglu_linear": SwiGLULinearParallel(output_layouts=Shard(1)),
        }


class TransformerTPPlan(TPPlan):

    @property
    def attn_plan(self):
        return {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
            "attn.q_proj": self.colwise_parallel(),
            "attn.k_proj": self.colwise_parallel(),
            "attn.v_proj": self.colwise_parallel(),
            "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
        }


class GLATPPlan(TPPlan):

    @property
    def attn_plan(self):
        return {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
            "attn.q_proj": self.colwise_parallel(),
            "attn.k_proj": self.colwise_parallel(),
            "attn.v_proj": self.colwise_parallel(),
            "attn.g_proj": self.colwise_parallel(),
            "attn.gk_proj.0": PrepareModuleWeight(layouts=Replicate()),
            "attn.gk_proj.1": self.colwise_parallel(),
            "attn.g_norm": SequenceParallel(sequence_dim=-1),
            "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
        }


TP_PLAN_MAP = {"transformer": TransformerTPPlan, "gla": GLATPPlan}


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    tp_plan = TP_PLAN_MAP[model.config.model_type](
        model, loss_parallel=loss_parallel, enable_float8=enable_float8
    )
    parallelize_module(model, tp_mesh, tp_plan.model_plan)

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for tensor parallelism")
    else:
        for _, block in enumerate(blocks):
            parallelize_module(
                module=block,
                device_mesh=tp_mesh,
                parallelize_plan=tp_plan.layer_plan,
            )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )


# for selective op activation checkpointing
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
}


def _apply_ac_to_block(module: nn.Module, ac_config):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq` of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module


def apply_ac(model: nn.Module, ac_config):
    """Apply activation checkpointing to the model."""
    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for activation checkpointing")
        return

    for layer_id, block in blocks.named_children():
        block = _apply_ac_to_block(block, ac_config)
        blocks.register_module(layer_id, block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")


def apply_compile(model: nn.Module):
    """
    Apply torch.compile to each block, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for torch.compile")
    else:
        for layer_id, block in blocks.named_children():
            block = torch.compile(block)
            blocks.register_module(layer_id, block)
        logger.info("Compiling each block with torch.compile")

    real_model = get_model(model)

    logger.info("Compiling the embedding, norm, and lm_head layers with torch.compile")
    embeddings_key = get_components_name(real_model, "tok_embeddings")
    if embeddings_key is not None:
        embeddings = torch.compile(getattr(real_model, embeddings_key), fullgraph=True)
        real_model.register_module(embeddings_key, embeddings)

    norm_key = get_components_name(real_model, "norm")
    if norm_key is not None:
        norm = torch.compile(getattr(real_model, norm_key), fullgraph=True)
        real_model.register_module(norm_key, norm)

    lm_head_key = get_components_name(model, "lm_head")
    if lm_head_key is not None:
        lm_head = torch.compile(getattr(model, lm_head_key), fullgraph=True)
        model.register_module(lm_head_key, lm_head)

    logger.info("Compiling the entire model with torch.compile")
    model = torch.compile(model)


def apply_fsdp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    param_dtype: torch.dtype,
    reduce_dtype: torch.dtype,
    pp_enabled: bool,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
):
    """
    Apply data parallelism (via FSDP2) to the model.

    Args:
        model (nn.Module): The model to apply data parallelism to.
        dp_mesh (DeviceMesh): The device mesh to use for data parallelism.
        param_dtype (torch.dtype): The data type to use for model parameters.
        reduce_dtype (torch.dtype): The data type to use for reduction operations.
        pp_enabled (bool): Whether pipeline parallelism is enabled.
        cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False.
        reshard_after_forward_policy (str, optional):
            The policy to use for resharding after forward pass. Defaults to "default".
            Other options: "never", "always".
            - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios.
            - "always" will enable `reshard_after_forward` for all forward passes.
            - "never" will disable `reshard_after_forward` for all forward passes.

    """
    mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype)
    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
    if cpu_offload:
        fsdp_config["offload_policy"] = CPUOffloadPolicy()

    blocks = get_blocks(model)
    if blocks is None:
        logger.warning("No block found for FSDP")
    else:
        total_blocks = len(blocks)
        for layer_id, block in enumerate(blocks):
            if reshard_after_forward_policy == "always":
                reshard_after_forward = True
            elif reshard_after_forward_policy == "never":
                reshard_after_forward = False
            elif reshard_after_forward_policy == "default":
                if pp_enabled:
                    # For PP, do not reshard after forward to avoid per-microbatch
                    # all-gathers, which can be expensive and non-overlapped
                    reshard_after_forward = False
                else:
                    # As an optimization, do not reshard after forward for the last
                    # transformer block since FSDP would prefetch it immediately
                    reshard_after_forward = int(layer_id) < total_blocks - 1
            else:
                raise ValueError(
                    f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
                )
            fully_shard(
                block,
                **fsdp_config,
                reshard_after_forward=reshard_after_forward,
            )

    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)


def apply_ddp(
    model: nn.Module,
    dp_mesh: DeviceMesh,
    enable_compile: bool,
    enable_compiled_autograd: bool,
):
    if enable_compile:
        if enable_compiled_autograd:
            torch._dynamo.config.optimize_ddp = (
                "python_reducer_without_compiled_forward"
            )
        else:
            torch._dynamo.config.optimize_ddp = "ddp_optimizer"

    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)

    logger.info("Applied DDP to the model")


def get_model(model):
    base_model_prefix = getattr(model, "base_model_prefix", "model")
    if not hasattr(model, base_model_prefix):
        return None
    model = getattr(model, base_model_prefix)
    return model


def get_blocks(model):
    # TODO[flame]: adapt for network not using 'layers' attribute
    model = get_model(model)
    if not hasattr(model, "layers"):
        logger.warning('no "layers" in model can be found')
        return None
    return model.layers


def get_components_name(model, component_name):
    """
    We try to catch tok_embeddings, norm layers and lm_head layers
    We do not catch the layer names in the blocks, for blocks see `get_blocks`
    We assume the model has the following structure:
    LlamaForCausalLM:
        Model:
            embed_tokens,
            layers,
            norm,
        lm_head
    ***
    so, to search 'tok_embeddings' and 'norm' we need to pass `get_model(model)`
    and for 'lm_head' we need to pass `model`
    ***
    """

    if component_name == "tok_embeddings":
        if hasattr(model, "tok_embeddings"):
            return "tok_embeddings"
        elif hasattr(model, "embed_tokens"):
            return "embed_tokens"
        elif hasattr(model, "embeddings"):
            return "embeddings"
        else:
            logger.warning("No tok_embeddings found in model")
            return None

    elif component_name == "norm":
        if hasattr(model, "norm"):
            return "norm"
        elif hasattr(model, "norms"):
            return "norms"
        elif hasattr(model, "layernorm"):
            return "layernorm"
        else:
            logger.warning("No norm found in model")
            return None

    elif component_name == "lm_head":
        if hasattr(model, "lm_head"):
            return "lm_head"
        else:
            logger.warning("No lm_head found in model")
            return None
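As a sketch of how the TP plan registry above is meant to be extended, the snippet below would live in the same module, so `TPPlan`, `TP_PLAN_MAP`, `Shard`, and `Replicate` are already in scope; the model type and projection names are hypothetical, chosen only to mirror the Transformer plan above.

# Hypothetical plan for a model whose config.model_type is "my_attn" and whose token-mixing
# module exposes q/k/v/o projections; the names here are assumptions, not part of this file.
class MyAttnTPPlan(TPPlan):

    @property
    def attn_plan(self):
        return {
            "attn": self.prepare_module_input(
                input_kwarg_layouts={"hidden_states": Shard(1)},
                desired_input_kwarg_layouts={"hidden_states": Replicate()},
            ),
            "attn.q_proj": self.colwise_parallel(),
            "attn.k_proj": self.colwise_parallel(),
            "attn.v_proj": self.colwise_parallel(),
            "attn.o_proj": self.rowwise_parallel(output_layouts=Shard(1)),
        }


# apply_tp looks the plan up via model.config.model_type, so registration is a one-liner.
TP_PLAN_MAP["my_attn"] = MyAttnTPPlan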
flame/train.py
ADDED
@@ -0,0 +1,897 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the BSD-style license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
from datetime import timedelta
|
11 |
+
from collections import defaultdict
|
12 |
+
import dataclasses
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from datasets import interleave_datasets, load_dataset
|
16 |
+
from torch.distributed.elastic.multiprocessing.errors import record
|
17 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
18 |
+
|
19 |
+
import fla # noqa
|
20 |
+
from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
|
21 |
+
from fla.ops.common.utils import prepare_position_ids
|
22 |
+
from flame.components.checkpoint import TrainState
|
23 |
+
from flame.config_manager import JobConfig
|
24 |
+
from flame.data import build_dataloader, shuffle
|
25 |
+
from flame.models.parallelize_fla import parallelize_fla
|
26 |
+
from flame.models.pipeline_fla import pipeline_fla
|
27 |
+
from flame.tools.utils import get_nparams_and_flops
|
28 |
+
from flame.utils.checkpoint import cleanup_local_checkpoints
|
29 |
+
from flame.utils.convert_dcp_to_hf import save_pretrained
|
30 |
+
from flame.utils.hf_utils import upload_checkpoint_to_hf
|
31 |
+
from datetime import datetime
|
32 |
+
from torchtitan.components.checkpoint import CheckpointManager
|
33 |
+
from torchtitan.components.ft import FTParallelDims, init_ft_manager
|
34 |
+
from torchtitan.components.loss import build_cross_entropy_loss
|
35 |
+
from torchtitan.components.lr_scheduler import build_lr_schedulers
|
36 |
+
from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible
|
37 |
+
from torchtitan.components.optimizer import build_optimizers
|
38 |
+
from torchtitan.distributed import ParallelDims
|
39 |
+
from torchtitan.distributed import utils as dist_utils
|
40 |
+
from torchtitan.protocols.model_converter import build_model_converters
|
41 |
+
from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
|
42 |
+
from torchtitan.tools import utils
|
43 |
+
from torchtitan.tools.logging import init_logger, logger
|
44 |
+
from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
|
45 |
+
|
46 |
+
from dotenv import load_dotenv
|
47 |
+
load_dotenv()
|
48 |
+
|
49 |
+
import wandb
|
50 |
+
wandb.login(key=os.environ["WANDB_API_KEY"])
|
51 |
+
|
52 |
+
import huggingface_hub
|
53 |
+
huggingface_hub.login(token=os.environ["HF_TOKEN"])
|
54 |
+
|
55 |
+
|
56 |
+
def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
|
57 |
+
return AutoTokenizer.from_pretrained(job_config.model.tokenizer_path)
|
58 |
+
|
59 |
+
|
60 |
+
register_train_spec(
|
61 |
+
TrainSpec(
|
62 |
+
name="fla",
|
63 |
+
cls=AutoModelForCausalLM,
|
64 |
+
config=AutoConfig,
|
65 |
+
parallelize_fn=parallelize_fla,
|
66 |
+
pipelining_fn=pipeline_fla,
|
67 |
+
build_optimizers_fn=build_optimizers,
|
68 |
+
build_lr_schedulers_fn=build_lr_schedulers,
|
69 |
+
build_dataloader_fn=build_dataloader,
|
70 |
+
build_tokenizer_fn=build_tokenizer,
|
71 |
+
build_loss_fn=build_cross_entropy_loss,
|
72 |
+
)
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
|
77 |
+
@record
|
78 |
+
def main(job_config: JobConfig):
|
79 |
+
logger.info(f"Starting job: {job_config.job.description}")
|
80 |
+
|
81 |
+
if job_config.experimental.custom_model_path:
|
82 |
+
utils.import_module_from_path(job_config.experimental.custom_model_path)
|
83 |
+
|
84 |
+
# used for colorful printing
|
85 |
+
color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color
|
86 |
+
|
87 |
+
if job_config.job.print_args:
|
88 |
+
logger.info(
|
89 |
+
f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
|
90 |
+
)
|
91 |
+
|
92 |
+
# take control of garbage collection to avoid stragglers
|
93 |
+
gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
|
94 |
+
|
95 |
+
device_module, device_type = utils.device_module, utils.device_type
|
96 |
+
device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
|
97 |
+
# Device has to be set before creating TorchFT manager.
|
98 |
+
device_module.set_device(device)
|
99 |
+
ft_manager = init_ft_manager(job_config)
|
100 |
+
|
101 |
+
run_specific_repo_id = None
|
102 |
+
if getattr(job_config.checkpoint, "hf_upload_enabled", False):
|
103 |
+
hf_repo_base = getattr(job_config.checkpoint, "hf_repo_base_name", None)
|
104 |
+
if hf_repo_base:
|
105 |
+
# Generate timestamp (adjust format if desired)
|
106 |
+
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
107 |
+
run_specific_repo_id = f"{hf_repo_base}-{timestamp}"
|
108 |
+
logger.info(f"Target Hugging Face repository for this run: {run_specific_repo_id}")
|
109 |
+
else:
|
110 |
+
logger.warning("HF Hub upload enabled, but 'checkpoint.hf_repo_base_name' is not set.")
|
111 |
+
# Disable upload if base name is missing
|
112 |
+
job_config.checkpoint.hf_upload_enabled = False
|
113 |
+
|
114 |
+
# init distributed
|
115 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
116 |
+
if not ft_manager.enabled:
|
117 |
+
parallel_dims = ParallelDims(
|
118 |
+
dp_shard=job_config.training.data_parallel_shard_degree,
|
119 |
+
dp_replicate=job_config.training.data_parallel_replicate_degree,
|
120 |
+
cp=job_config.experimental.context_parallel_degree,
|
121 |
+
tp=job_config.training.tensor_parallel_degree,
|
122 |
+
pp=job_config.experimental.pipeline_parallel_degree,
|
123 |
+
world_size=world_size,
|
124 |
+
enable_loss_parallel=not job_config.training.disable_loss_parallel,
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
parallel_dims = FTParallelDims(
|
128 |
+
dp_shard=job_config.training.data_parallel_shard_degree,
|
129 |
+
dp_replicate=job_config.training.data_parallel_replicate_degree,
|
130 |
+
cp=job_config.experimental.context_parallel_degree,
|
131 |
+
tp=job_config.training.tensor_parallel_degree,
|
132 |
+
pp=job_config.experimental.pipeline_parallel_degree,
|
133 |
+
world_size=world_size,
|
134 |
+
enable_loss_parallel=not job_config.training.disable_loss_parallel,
|
135 |
+
ft_manager=ft_manager,
|
136 |
+
)
|
137 |
+
dist_utils.init_distributed(job_config)
|
138 |
+
# initialize device memory monitor and get peak flops for MFU calculation
|
139 |
+
device_memory_monitor = build_device_memory_monitor()
|
140 |
+
gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
|
141 |
+
logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
|
142 |
+
|
143 |
+
# build meshes
|
144 |
+
world_mesh = parallel_dims.build_mesh(device_type=device_type)
|
145 |
+
if parallel_dims.dp_enabled:
|
146 |
+
dp_mesh = world_mesh["dp"]
|
147 |
+
dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
|
148 |
+
else:
|
149 |
+
dp_degree, dp_rank = 1, 0
|
150 |
+
|
151 |
+
if parallel_dims.pp_enabled:
|
152 |
+
raise NotImplementedError(
|
153 |
+
"Pipeline parallelism is not supported in this version"
|
154 |
+
)
|
155 |
+
"""
|
156 |
+
! TODO[flame]: We need to fix the pipeline parallelism for flame
|
157 |
+
[x] Match the key of models' components with the actual naming
|
158 |
+
[ ] Fix the post-init and tie-embedding for pipeline parallelism, HF's transformer automatically
|
159 |
+
forces to tie if head is None, we need to handle this case
|
160 |
+
[ ]
|
161 |
+
"""
|
162 |
+
pp_mesh = world_mesh["pp"]
|
163 |
+
|
164 |
+
# Set random seed, and maybe enable deterministic mode (mainly for debugging, expect perf loss)
|
165 |
+
dist_utils.set_determinism(
|
166 |
+
world_mesh, device, job_config.training.seed, job_config.training.deterministic
|
167 |
+
)
|
168 |
+
train_spec = get_train_spec(job_config.model.name)
|
169 |
+
|
170 |
+
logger.info("Loading tokenizer...")
|
171 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
172 |
+
job_config.model.tokenizer_path,
|
173 |
+
trust_remote_code=True,
|
174 |
+
model_max_length=int(1e10),
|
175 |
+
)
|
176 |
+
logger.info(f"{tokenizer}")
|
177 |
+
logger.info(
|
178 |
+
f"Loading dataset {job_config.training.dataset}"
|
179 |
+
f":{job_config.training.dataset_name}"
|
180 |
+
if job_config.training.dataset_name is not None
|
181 |
+
else ""
|
182 |
+
)
|
183 |
+
|
184 |
+
min_num_shards = dp_degree * job_config.training.num_workers
|
185 |
+
if len(job_config.training.dataset.split(",")) == 1:
|
186 |
+
dataset = load_dataset(
|
187 |
+
path=job_config.training.dataset,
|
188 |
+
name=getattr(job_config.training, "dataset_name", None),
|
189 |
+
data_dir=getattr(job_config.training, "data_dir", None),
|
190 |
+
data_files=getattr(job_config.training, "data_files", None),
|
191 |
+
split=job_config.training.dataset_split or "train",
|
192 |
+
trust_remote_code=True,
|
193 |
+
streaming=job_config.training.streaming,
|
194 |
+
num_proc=(
|
195 |
+
job_config.training.num_workers
|
196 |
+
if not job_config.training.streaming
|
197 |
+
else None
|
198 |
+
),
|
199 |
+
)
|
200 |
+
logger.info(f"{dataset}")
|
201 |
+
|
202 |
+
logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
|
203 |
+
if not job_config.training.streaming:
|
204 |
+
# the states of map-style dataset is recoverable after shuffling
|
205 |
+
dataset = dataset.shuffle(
|
206 |
+
seed=job_config.training.seed
|
207 |
+
).to_iterable_dataset(num_shards=min_num_shards)
|
208 |
+
else:
|
209 |
+
if dataset.num_shards < min_num_shards:
|
210 |
+
logger.warning(
|
211 |
+
f"{color.red}"
|
212 |
+
f"Dataset {job_config.training.dataset} has insufficient shards ({dataset.num_shards}). "
|
213 |
+
f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
|
214 |
+
f"{job_config.training.num_workers} dataloader workers. "
|
215 |
+
f"Disabling the streaming mode and resharding dataset to {min_num_shards} shards."
|
216 |
+
f"{color.reset}"
|
217 |
+
)
|
218 |
+
dataset = (
|
219 |
+
load_dataset(
|
220 |
+
path=job_config.training.dataset,
|
221 |
+
name=getattr(job_config.training, "dataset_name", None),
|
222 |
+
data_dir=getattr(job_config.training, "data_dir", None),
|
223 |
+
data_files=getattr(job_config.training, "data_files", None),
|
224 |
+
split=job_config.training.dataset_split or "train",
|
225 |
+
trust_remote_code=True,
|
226 |
+
streaming=False,
|
227 |
+
num_proc=job_config.training.num_workers,
|
228 |
+
)
|
229 |
+
.shuffle(seed=job_config.training.seed)
|
230 |
+
.to_iterable_dataset(num_shards=min_num_shards)
|
231 |
+
)
|
232 |
+
else:
|
233 |
+
dataset = shuffle(dataset, seed=job_config.training.seed)
|
234 |
+
else:
|
235 |
+
datasets = job_config.training.dataset.split(",")
|
236 |
+
if job_config.training.dataset_name is not None:
|
237 |
+
dataset_names = [
|
238 |
+
name or None for name in job_config.training.dataset_name.split(",")
|
239 |
+
]
|
240 |
+
assert len(dataset_names) == len(datasets), (
|
241 |
+
"The number of dataset names must match the number of datasets"
|
242 |
+
)
|
243 |
+
else:
|
244 |
+
dataset_names = [None] * len(datasets)
|
245 |
+
if job_config.training.dataset_split is not None:
|
246 |
+
dataset_splits = [
|
247 |
+
split or "train"
|
248 |
+
for split in job_config.training.dataset_split.split(",")
|
249 |
+
]
|
250 |
+
assert len(dataset_splits) == len(datasets), (
|
251 |
+
"The number of dataset splits must match the number of datasets"
|
252 |
+
)
|
253 |
+
else:
|
254 |
+
dataset_splits = ["train"] * len(datasets)
|
255 |
+
if job_config.training.data_dir is not None:
|
256 |
+
data_dirs = [
|
257 |
+
data_dir or None for data_dir in job_config.training.data_dir.split(",")
|
258 |
+
]
|
259 |
+
assert len(data_dirs) == len(datasets), (
|
260 |
+
"The number of data dirs must match the number of datasets"
|
261 |
+
)
|
262 |
+
else:
|
263 |
+
data_dirs = [None] * len(datasets)
|
264 |
+
if job_config.training.data_files is not None:
|
265 |
+
data_files = job_config.training.data_files.split(",")
|
266 |
+
assert len(data_files) == len(datasets), (
|
267 |
+
"The number of data files must match the number of datasets"
|
268 |
+
)
|
269 |
+
else:
|
270 |
+
data_files = [None] * len(datasets)
|
271 |
+
if job_config.training.data_probs is not None:
|
272 |
+
data_probs = [float(p) for p in job_config.training.data_probs.split(",")]
|
273 |
+
assert len(data_probs) == len(datasets), (
|
274 |
+
"The number of data probabilities must match the number of datasets"
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
raise ValueError(
|
278 |
+
"Data sampling probabilities are required if using multiple datasets"
|
279 |
+
)
|
280 |
+
|
281 |
+
subsets = []
|
282 |
+
for i, prob in enumerate(data_probs):
|
283 |
+
subset = load_dataset(
|
284 |
+
path=datasets[i],
|
285 |
+
name=dataset_names[i],
|
286 |
+
data_dir=data_dirs[i],
|
287 |
+
data_files=data_files[i],
|
288 |
+
split=dataset_splits[i],
|
289 |
+
trust_remote_code=True,
|
290 |
+
streaming=job_config.training.streaming,
|
291 |
+
num_proc=(
|
292 |
+
job_config.training.num_workers
|
293 |
+
if not job_config.training.streaming
|
294 |
+
else None
|
295 |
+
),
|
296 |
+
)
|
297 |
+
logger.info(
|
298 |
+
f"Subset {color.cyan}{datasets[i]}"
|
299 |
+
+ (f":{dataset_names[i]} " if dataset_names[i] else " ")
|
300 |
+
+ f"(p = {prob:.3f}){color.reset}:\n"
|
301 |
+
+ f"{subset}"
|
302 |
+
)
|
303 |
+
|
304 |
+
logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
|
305 |
+
if not job_config.training.streaming:
|
306 |
+
# the states of map-style dataset is recoverable after shuffling
|
307 |
+
subset = subset.shuffle(
|
308 |
+
seed=job_config.training.seed
|
309 |
+
).to_iterable_dataset(num_shards=min_num_shards)
|
310 |
+
else:
|
311 |
+
if subset.num_shards < min_num_shards:
|
312 |
+
logger.warning(
|
313 |
+
f"{color.red}"
|
314 |
+
f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
|
315 |
+
f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
|
316 |
+
f"{job_config.training.num_workers} dataloader workers. "
|
317 |
+
f"Resharding dataset to {min_num_shards} shards and disabling streaming mode."
|
318 |
+
f"{color.reset}"
|
319 |
+
)
|
320 |
+
# again, it's ok to directly shuffle the map-style dataset
|
321 |
+
# we expect an error raised if the map-style dataset still has not enough data shards
|
322 |
+
subset = (
|
323 |
+
load_dataset(
|
324 |
+
path=datasets[i],
|
325 |
+
name=dataset_names[i],
|
326 |
+
data_dir=data_dirs[i],
|
327 |
+
data_files=data_files[i],
|
328 |
+
split=dataset_splits[i],
|
329 |
+
trust_remote_code=True,
|
330 |
+
streaming=False,
|
331 |
+
num_proc=job_config.training.num_workers,
|
332 |
+
)
|
333 |
+
.shuffle(seed=job_config.training.seed)
|
334 |
+
.to_iterable_dataset(min_num_shards)
|
335 |
+
)
|
336 |
+
else:
|
337 |
+
# we set relatively small buffer size here as interleaving could provide some randomness
|
338 |
+
subset = shuffle(
|
339 |
+
subset,
|
340 |
+
seed=job_config.training.seed,
|
341 |
+
buffer_size=max(128, 1024 // len(datasets)),
|
342 |
+
)
|
343 |
+
|
344 |
+
if "text" in subset.column_names:
|
345 |
+
subset = subset.select_columns("text")
|
346 |
+
elif "content" in subset.column_names:
|
347 |
+
subset = subset.select_columns("content")
|
348 |
+
else:
|
349 |
+
raise ValueError(
|
350 |
+
f"Subset {datasets[i]} has no 'text' or 'content' column"
|
351 |
+
)
|
352 |
+
subsets.append(subset)
|
353 |
+
|
354 |
+
logger.info(
|
355 |
+
f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
|
356 |
+
)
|
357 |
+
dataset = interleave_datasets(
|
358 |
+
datasets=subsets,
|
359 |
+
probabilities=data_probs,
|
360 |
+
stopping_strategy="all_exhausted",
|
361 |
+
seed=job_config.training.seed,
|
362 |
+
)
|
363 |
+
logger.info(f"{dataset}")
|
364 |
+
|
365 |
+
|
366 |
+
    logger.info(f"Loading model config from {job_config.model.config}")
    model_config = AutoConfig.from_pretrained(job_config.model.config)

    logger.info("Building dataloader...")
    dataloader = build_dataloader(
        dataset=dataset,
        tokenizer=tokenizer,
        rank=dp_rank,
        world_size=dp_degree,
        batch_size=job_config.training.batch_size,
        # TODO: Make this more modular
        # seq_len=job_config.training.seq_len if not model_config.use_myopic_loss else job_config.training.seq_len*2,
        seq_len=job_config.training.seq_len * 2,
        context_len=job_config.training.context_len,
        varlen=job_config.training.varlen,
        num_workers=job_config.training.num_workers,
        pin_memory=job_config.training.pin_memory,
        persistent_workers=job_config.training.persistent_workers,
        snapshot_every_n_steps=job_config.checkpoint.interval,
    )

    # set the model configs from training inputs:
    # 1. norm type to decide which norm layer to use
    # 2. disable fused norm if TP is enabled
    # 3. vocab size from tokenizer
    # 4. context_len based on inputs
    if parallel_dims.tp_enabled:
        if model_config.fuse_norm:
            logger.warning(
                f"{color.red}"
                f"Fused norm is not compatible with tensor parallelism. "
                f"Disabling it for now."
                f"{color.reset}"
            )
            model_config.fuse_norm = False
    if parallel_dims.loss_parallel_enabled:
        if model_config.fuse_cross_entropy:
            logger.warning(
                f"{color.red}"
                f"Loss parallel enabled. Disabling fused cross entropy for now."
                f"{color.reset}"
            )
            model_config.fuse_cross_entropy = False
    model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)

    logger.info(
        f"Building model from the config\n{color.green}{model_config}{color.reset}"
    )
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(model_config)
        if (
            getattr(model_config, "fuse_cross_entropy", False)
            and FusedLinearCrossEntropyLoss is not None
        ):
            model.criterion = FusedLinearCrossEntropyLoss(
                num_chunks=8 // parallel_dims.tp
            )
        # defer weight initialization until after parallelisms are applied
        model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
    logger.info(f"{color.blue}\n{model}{color.reset}\n")

    # Build the collection of model converters. No-op if `model.converters` empty
    model_converters = build_model_converters(job_config, parallel_dims)
    model_converters.convert(model)

    # calculate model size and flops per token
    model_param_count, num_flops_per_token = get_nparams_and_flops(
        model, model_config, job_config.training.context_len
    )

    # move sharded model to CPU/GPU and initialize weights via DTensor
    if job_config.checkpoint.create_seed_checkpoint:
        init_device = "cpu"
    elif job_config.training.enable_cpu_offload:
        init_device = "cpu"
    else:
        init_device = device_type

    # apply parallelisms and initialization
    if parallel_dims.pp_enabled:
        # apply PT-D Pipeline Parallel
        (
            pp_schedule,
            model_parts,
            has_first_stage,
            has_last_stage,
        ) = train_spec.pipelining_fn(
            model,
            pp_mesh,
            parallel_dims,
            job_config,
            device,
            model_config,
            train_spec.loss_fn,
        )
        # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
        del model

        # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
        # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
        # optimizer, and checkpointing
        for m in model_parts:
            # apply SPMD-style PT-D techniques
            train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
            m.to_empty(device=init_device)
            with torch.no_grad():
                m.post_init()
            m.train()

        # confirm that user will be able to view loss metrics on the console
        ensure_pp_loss_visible(parallel_dims, job_config, color)
    else:
        # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
        train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
        model.to_empty(device=init_device)
        with torch.no_grad():
            model.post_init()
        model.train()

        model_parts = [model]

    device_mem_stats = device_memory_monitor.get_peak_stats()
    logger.info(
        f"{device_type.upper()} memory usage for model: "
        f"{device_mem_stats.max_reserved_gib:.2f}GiB"
        f"({device_mem_stats.max_reserved_pct:.2f}%)"
    )

    # build optimizer after applying parallelisms to the model
    optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
    lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
    # Post optimizer step model converters hook.
    # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
    # where it issues a single all-reduce for all parameters at once for better performance
    optimizers.register_step_post_hook(
        lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
    )

    train_state = TrainState()

    # load initial checkpoint
    checkpoint = CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        ft_manager=ft_manager,
    )

    if job_config.checkpoint.create_seed_checkpoint:
        assert world_size == 1, (
            "Must create seed checkpoint using a single device, to disable sharding"
        )
        assert job_config.checkpoint.enable_checkpoint, (
            "Must enable checkpointing when creating a seed checkpoint"
        )
        checkpoint.save(curr_step=0, force=True)
        logger.info("Created seed checkpoint")
        return

    checkpoint.load(step=job_config.checkpoint.load_step)
    metric_logger = build_metrics_processor(job_config, parallel_dims)
    # Set dependent attributes for metric_logger
    metric_logger.num_flops_per_token = num_flops_per_token
    metric_logger.optimizers = optimizers  # Pass optimizers if needed by logger logic
    metric_logger.lr_schedulers = (
        lr_schedulers  # Pass schedulers if needed by logger logic
    )

    # plot losses loaded from checkpoint (if any) to TensorBoard
    # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
    # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
    if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
        for idx, step in enumerate(train_state.log_steps):
            metric_logger.log(
                step,
                global_avg_loss=train_state.global_avg_losses[idx],
                global_max_loss=train_state.global_max_losses[idx],
            )

    data_iterator = iter(dataloader)

    train_context = dist_utils.get_train_context(
        parallel_dims.loss_parallel_enabled,
        job_config.experimental.enable_compiled_autograd,
    )

    # variables used to keep info for metrics logging
    device_memory_monitor.reset_peak_stats()

    global_batch_size = (
        job_config.training.batch_size
        * dp_degree
        * job_config.training.gradient_accumulation_steps
    )
    num_tokens_per_step = global_batch_size * job_config.training.seq_len
    # train loop
    logger.info(f"{color.red}***** Running training *****{color.reset}")
    logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
    logger.info(
        f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
    )
    logger.info(
        f"{color.green} Gradient Accumulation steps = {job_config.training.gradient_accumulation_steps}"
    )
    logger.info(
        f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
    )
    logger.info(
        f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
        f" ({num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Total optimization steps = {job_config.training.steps:,} "
        f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
        f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
    )

    with (
        maybe_enable_profiling(
            job_config, global_step=train_state.step
        ) as torch_profiler,
        maybe_enable_memory_snapshot(
            job_config, global_step=train_state.step
        ) as memory_profiler,
    ):
        while train_state.step < job_config.training.steps:
            train_state.step += 1
            gc_handler.run(train_state.step)

            optimizers.zero_grad()

            losses = defaultdict(list)
            actual_loss = []
            # do gradient accumulation if enabled
            for _ in range(job_config.training.gradient_accumulation_steps):
                # get batch
                data_load_start = time.perf_counter()
                batch = next(data_iterator)
                # Recall that for myopic and MTP training the batch is laid out as
                # input_ids : (B, seq_len)
                # labels    : (B, seq_len * 2)
                input_ids, labels = batch["input_ids"][:, :job_config.training.seq_len], batch["labels"]

                # Update metrics processor state before forward/backward
                metric_logger.ntokens_since_last_log += input_ids.numel()
                metric_logger.data_loading_times.append(
                    time.perf_counter() - data_load_start
                )

                input_ids = input_ids.to(device_type)

                """
                TODO[flame]: We need to carefully handle the position_ids for TP/CP
                Depending on the model's PE, the position_ids might be different.

                e.g. for TP
                    For RoPE, all ranks have the same position_ids. [FOR HF model]
                    For sinusoidal, each rank has the corresponding chunked position_ids. [FOR HF model]

                e.g. for CP, [optional_context_parallel_ctx should automatically distribute the position_ids]
                    Each rank has the corresponding chunked position_ids. [FOR all models]
                """
                labels = labels.to(device_type)
                cu_seqlens = (
                    batch["cu_seqlens"].to(device_type)
                    if "cu_seqlens" in batch
                    else None
                )
                if cu_seqlens is not None:
                    position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
                else:
                    position_ids = (
                        torch.arange(0, input_ids.shape[1], device=device_type)
                        .repeat(input_ids.shape[0], 1)
                        .to(torch.int32)
                    )
                # apply context parallelism if cp is enabled
                # ensure CP handles the separate freqs_cis buffer for each pp stage
                optional_context_parallel_ctx = (
                    dist_utils.create_context_parallel_ctx(
                        cp_mesh=world_mesh["cp"],
                        cp_buffers=[input_ids, labels, position_ids],
                        cp_seq_dims=[1, 1, 1],
                        cp_no_restore_buffers={input_ids, labels, position_ids},
                        cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
                    )
                    if parallel_dims.cp_enabled
                    else None
                )

                # #! TODO[flame], we should distribute the position_ids as well with CP
                if parallel_dims.pp_enabled:
                    raise NotImplementedError(
                        "Pipeline parallelism is not supported in this version"
                    )
                    # Pipeline Parallel forward / backward inside step() call
                    with train_context(optional_context_parallel_ctx):
                        targets, losses = (
                            (labels, []) if has_last_stage else (None, None)
                        )

                        if has_first_stage:
                            pp_schedule.step(input_ids, target=targets, losses=losses)
                        else:
                            pp_schedule.step(target=targets, losses=losses)

                    # accumulate losses across pipeline microbatches
                    # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
                    loss = (
                        torch.mean(torch.stack(losses)).to(device)
                        if has_last_stage
                        else torch.tensor([-1.0], device=device)
                    )
                else:
                    # Non-PP forward / backward
                    with train_context(optional_context_parallel_ctx):
                        output = model(
                            input_ids=input_ids,
                            labels=labels,
                            position_ids=position_ids,
                            cu_seqlens=cu_seqlens,
                        )
                        output_attributes = [field.name for field in dataclasses.fields(output)]
                        losses_attributes = [x for x in output_attributes if "loss" in x and x != "loss"]
                        loss = (
                            output.loss
                            / job_config.training.gradient_accumulation_steps
                        )
                        loss.backward()

                actual_loss.append(loss)
                for loss_attr in losses_attributes:
                    custom_loss = getattr(output, loss_attr, None)
                    if custom_loss is not None:
                        custom_loss = custom_loss / job_config.training.gradient_accumulation_steps
                        losses[loss_attr].append(custom_loss)

            loss = sum(actual_loss)
            for loss_attr, loss_values in losses.items():
                losses[loss_attr] = sum(loss_values)

            # clip gradients
            grad_norm = dist_utils.clip_grad_norm_(
                [p for m in model_parts for p in m.parameters()],
                job_config.training.max_norm,
                foreach=True,
                pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
            )

            # optimizer step
            checkpoint.maybe_wait_for_staging()
            if job_config.training.skip_nan_inf and (
                grad_norm.isnan() or grad_norm.isinf()
            ):
                logger.warning(
                    f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
                )
                optimizers.zero_grad()
                train_state.skipped_step += 1
            else:
                optimizers.step()
                lr_schedulers.step()

            # log metrics - Use MetricsProcessor
            global_avg_custom_loss = {}
            global_max_custom_loss = {}
            if metric_logger.should_log(train_state.step):
                if (
                    parallel_dims.dp_replicate_enabled
                    or parallel_dims.dp_shard_enabled
                    or parallel_dims.cp_enabled
                ):
                    loss = loss.detach()
                    # Use dist_mean/max on the accumulated loss for the step
                    global_avg_loss, global_max_loss = (
                        dist_utils.dist_mean(
                            loss,
                            world_mesh["dp_cp"],
                        ),
                        dist_utils.dist_max(
                            loss,
                            world_mesh["dp_cp"],
                        ),
                    )
                    for loss_attr, loss_value in losses.items():
                        global_avg_custom_loss[loss_attr] = dist_utils.dist_mean(
                            loss_value, world_mesh["dp_cp"]
                        )
                        global_max_custom_loss[loss_attr] = dist_utils.dist_max(
                            loss_value, world_mesh["dp_cp"]
                        )
                else:
                    # Scale back the loss before logging
                    global_avg_loss = global_max_loss = loss.item()
                    for loss_attr, loss_value in losses.items():
                        global_avg_custom_loss[loss_attr] = global_max_custom_loss[
                            loss_attr
                        ] = loss_value.item()

                # Update train state tokens and elapsed time
                time_now = time.perf_counter()
                time_delta = (
                    time_now - metric_logger.time_last_log
                )  # Use metric_logger's time
                train_state.token += (
                    metric_logger.ntokens_since_last_log  # Use tokens tracked by metric_logger
                    * parallel_dims.world_size
                    / parallel_dims.non_data_parallel_size
                )
                train_state.elapsed += timedelta(seconds=time_delta)
                train_state.log_steps.append(train_state.step)
                train_state.global_avg_losses.append(global_avg_loss)
                train_state.global_max_losses.append(global_max_loss)

                # Log using the metric processor
                last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
                eta = (
                    train_state.elapsed
                    * (job_config.training.steps - train_state.step)
                    / train_state.step
                )
                extra_metrics = {
                    "optimizer/lr": last_lr,
                    "optimizer/grad_norm": grad_norm.item(),
                    "optimizer/skipped_step": train_state.skipped_step,
                }
                for loss_attr, loss_value in global_avg_custom_loss.items():
                    extra_metrics[f"loss_metrics/global_avg_{loss_attr}"] = loss_value.item() if isinstance(loss_value, torch.Tensor) else loss_value
                metric_logger.log(
                    train_state.step,
                    global_avg_loss,
                    global_max_loss,
                    extra_metrics=extra_metrics,
                )

                logger.info(
                    f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
                    f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
                )

            checkpoint.save(
                train_state.step, force=(train_state.step == job_config.training.steps)
            )
            if torch.distributed.get_rank() == 0:
                if job_config.checkpoint.enable_checkpoint:
                    hf_target_path = None
                    dcp_save_path = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder, f"step-{train_state.step}")

                    # TODO: Haven't tested this one yet
                    if getattr(job_config.checkpoint, "convert_to_hf_on_save", False):
                        try:
                            # Get the path where DCP was just saved
                            # Check CheckpointManager API for the best way, assuming get_save_path exists
                            hf_target_path = f"{dcp_save_path}"  # e.g., .../checkpoint/step-1000-hf

                            logger.info(f"Converting step {train_state.step} DCP checkpoint to HF format at: {hf_target_path}")
                            save_pretrained(  # Call the imported function
                                path=hf_target_path,  # Pass target HF path as 'path'
                                step=train_state.step,
                                config=job_config.model.config,  # Pass model config path/id
                                tokenizer=job_config.model.tokenizer_path  # Pass tokenizer path/id
                            )
                            logger.info(f"Successfully converted step {train_state.step} to HF format.")

                        except Exception as e:
                            logger.error(f"Failed to convert checkpoint step {train_state.step} to HF format: {e}", exc_info=True)

                    base_checkpoint_dir = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder)
                    if getattr(job_config.checkpoint, "hf_upload_enabled", True):
                        upload_format = getattr(job_config.checkpoint, "hf_upload_format", "hf")
                        keep_k_hub = getattr(job_config.checkpoint, "hf_keep_latest_k", 5)

                        local_path_to_upload = None
                        if upload_format == "hf":
                            if hf_target_path and os.path.isdir(hf_target_path):
                                local_path_to_upload = hf_target_path
                        elif upload_format == "dcp":
                            if dcp_save_path and os.path.isdir(dcp_save_path):
                                local_path_to_upload = dcp_save_path

                        if local_path_to_upload:
                            try:
                                upload_checkpoint_to_hf(
                                    local_path=local_path_to_upload,
                                    step=train_state.step,
                                    hf_repo_id_for_run=run_specific_repo_id,
                                    upload_format=upload_format,
                                    hf_keep_latest_k=job_config.checkpoint.keep_latest_k,
                                )
                            except Exception as e:
                                logger.error(f"Failed during HF Hub upload for step {train_state.step}: {e}", exc_info=True)

            # signal the profiler that the next profiling step has started
            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step()

            # reduce timeout after first train step for faster signal
            # (assuming lazy init and compilation are finished)
            if train_state.step == 1:
                dist_utils.set_pg_timeouts(
                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
                    world_mesh=world_mesh,
                )

    if torch.distributed.get_rank() == 0:
        logger.info("Sleeping 2 seconds for other ranks to complete")
        time.sleep(2)

    metric_logger.close()
    logger.info("Training completed")


if __name__ == "__main__":
    init_logger()
    config = JobConfig()
    config.parse_args()
    main(config)
    torch.distributed.destroy_process_group()
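For orientation, the batch-size accounting used in the training loop above multiplies out as follows; the values below are illustrative assumptions, not defaults shipped with this repo.

# Hedged worked example of the global-batch bookkeeping in flame/train.py above.
# All values are assumed for illustration only.
batch_size = 8                        # job_config.training.batch_size (per device)
dp_degree = 4                         # data-parallel degree
gradient_accumulation_steps = 4       # job_config.training.gradient_accumulation_steps
seq_len = 2048                        # job_config.training.seq_len

global_batch_size = batch_size * dp_degree * gradient_accumulation_steps  # 128 sequences
num_tokens_per_step = global_batch_size * seq_len                         # 262,144 tokens per optimizer step
print(global_batch_size, num_tokens_per_step)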
flame/utils/convert_dcp_to_hf.py
ADDED
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import argparse
import io
import os
import tempfile
from datetime import timedelta

import torch
import torch.serialization
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import fla  # noqa
from torchtitan.tools.logging import init_logger, logger


@torch.inference_mode()
def save_pretrained(
    path: str,
    step: int,
    config: str,
    tokenizer: str
):
    logger.info(f"Loading the config from {config}")
    config = AutoConfig.from_pretrained(config, trust_remote_code=True)

    logger.info(f"Saving the config to {path}")
    config.save_pretrained(path)
    logger.info(f"Loading the tokenizer from {tokenizer}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True)
    logger.info(f"Saving the tokenizer to {path}")
    tokenizer.save_pretrained(path)

    with tempfile.TemporaryDirectory() as tmpdir:
        # base_checkpoint_dir = os.path.dirname(path)
        base_checkpoint_dir = path
        checkpoint = os.path.join(base_checkpoint_dir, f'checkpoint/step-{step}')
        checkpoint_path = os.path.join(tmpdir, 'checkpoint.pt')
        logger.info(f"Saving the distributed checkpoint to {checkpoint_path}")
        dcp_to_torch_save(checkpoint, checkpoint_path)

        logger.info(f"Initializing the model from config\n{config}")
        model = AutoModelForCausalLM.from_config(config)
        logger.info(model)
        logger.info("Loading state dict from the checkpoint")

        # Add datetime.timedelta and io.BytesIO to safe globals
        torch.serialization.add_safe_globals([timedelta, io.BytesIO])
        # torch.load now with default weights_only=True will work
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])

        logger.info(f"Saving the model to {path}")
        model.save_pretrained(path)


if __name__ == "__main__":
    init_logger()
    parser = argparse.ArgumentParser("Convert DCP format model weights to huggingface-style.")
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--step", type=int, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--tokenizer", type=str, required=True)
    args = parser.parse_args()
    save_pretrained(args.path, args.step, args.config, args.tokenizer)
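For reference, the converter above can also be driven programmatically; a minimal sketch, assuming the target directory contains a checkpoint/step-<step> subfolder written by DCP (all paths and the step value below are placeholders, not values from this repo).

# Hedged usage sketch for the save_pretrained converter defined above.
# Paths and the step number are placeholders for illustration only.
from flame.utils.convert_dcp_to_hf import save_pretrained

save_pretrained(
    path="path/to/dump_folder",          # must contain checkpoint/step-20480; HF files are written here too
    step=20480,                          # assumed training step to convert
    config="path/to/model_config.json",  # HF config used to rebuild the model
    tokenizer="path/to/tokenizer",       # tokenizer path or hub id
)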
generation_config.json
ADDED
@@ -0,0 +1,6 @@
{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "transformers_version": "4.51.3"
}
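The exported folder (config, tokenizer, weights, plus this generation config) can then be loaded with transformers; a minimal sketch with a placeholder path, assuming the fla package is installed so the custom model classes are registered.

# Hedged loading sketch; "path/to/exported_model" is a placeholder for the
# directory produced by save_pretrained above.
import fla  # noqa: F401  (registers the custom FLA model classes with transformers)
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/exported_model")
model = AutoModelForCausalLM.from_pretrained("path/to/exported_model")

inputs = tokenizer("Linear attention is", return_tensors="pt")
# bos_token_id=1 and eos_token_id=2 are picked up from generation_config.json above
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))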