Make `configuration_phi4flash.py` and `modeling_phi4flash.py` compatible with standard sliding window config (#7)
- `sliding_window: list[Optional[int]]` -> `sliding_window: int` + `layer_types: list[str]` (8928fa33c0a966cf0822f7601ec722d38dd61285)
- configuration_phi4flash.py +10 -4
- modeling_phi4flash.py +9 -9
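The point of the change: generic Transformers utilities expect `config.sliding_window` to be a single integer (or `None`), so the previous per-layer list tripped them up. Below is a minimal sketch of the difference, using made-up example values; the `min()` call mirrors the cache-sizing change in `modeling_phi4flash.py` further down.

```python
# Old format: one Optional[int] per layer -- standard helpers cannot min() it against an int.
old_sliding_window = [None, 2047, None, 2047, None, None, None, None]

# New format: a single int plus per-layer type tags (same 8-layer example).
sliding_window = 2047
layer_types = ["full_attention", "sliding_attention"] * 2 + ["full_attention"] * 4

max_cache_len = 4096
print(min(sliding_window, max_cache_len))    # 2047 -- works with the new int field

try:
    min(old_sliding_window, max_cache_len)   # int-vs-list comparison
except TypeError as err:
    print(err)  # "'<' not supported between instances of 'int' and 'list'"
```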
configuration_phi4flash.py
CHANGED

@@ -112,6 +112,7 @@ class Phi4FlashConfig(PretrainedConfig):
         bos_token_id=1,
         eos_token_id=2,
         sliding_window=2047,
+        layer_types=None,
         mb_per_layer= 2,
         mamba_d_state=16,
         mamba_d_conv=4,
@@ -141,11 +142,16 @@ class Phi4FlashConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.mb_per_layer = mb_per_layer
-        self.sliding_window = [
-            sliding_window if … else None
-            for layer_idx in range(num_hidden_layers)
-        ]
+        self.sliding_window = sliding_window
+        self.layer_types = layer_types
 
+        if self.layer_types is None:
+            is_sliding = lambda i: i < num_hidden_layers // 2 and i % 2 == 1
+            self.layer_types = [
+                "sliding_attention" if is_sliding(layer_idx) else "full_attention"
+                for layer_idx in range(num_hidden_layers)
+            ]
+
         self.mamba_d_state = mamba_d_state
         self.mamba_d_conv = mamba_d_conv
         self.mamba_expand = mamba_expand

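The fallback above keeps existing checkpoints loadable: a `config.json` that predates the change has no `layer_types`, so the schedule is re-derived from `num_hidden_layers`, reproducing the layout the removed list encoded. A standalone sketch of that default, using 8 layers as an example count rather than the model's real depth:

```python
# Mirrors the layer_types fallback added in the hunk above, outside the config class.
num_hidden_layers = 8   # example value only
layer_types = None      # what an older config.json effectively provides

if layer_types is None:
    is_sliding = lambda i: i < num_hidden_layers // 2 and i % 2 == 1
    layer_types = [
        "sliding_attention" if is_sliding(layer_idx) else "full_attention"
        for layer_idx in range(num_hidden_layers)
    ]

print(layer_types)
# ['full_attention', 'sliding_attention', 'full_attention', 'sliding_attention',
#  'full_attention', 'full_attention', 'full_attention', 'full_attention']
```

Passing `layer_types` explicitly skips the fallback, so configs written by newer tooling round-trip unchanged.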
modeling_phi4flash.py
CHANGED

@@ -129,7 +129,7 @@ def _get_cache(
         cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
 
         if cache_implementation == "sliding_window":
-            max_cache_len = min(self.config.sliding_window[…], max_cache_len)
+            max_cache_len = min(self.config.sliding_window, max_cache_len)
 
         need_new_cache = (
             not hasattr(self, "_cache")
@@ -243,7 +243,7 @@ class SambaYCache(Cache):
         sliding_cache_shape = (
             self.max_batch_size,
             self.num_key_value_heads,
-            min(config.sliding_window[…], max_cache_len),
+            min(config.sliding_window, max_cache_len),
             self.head_dim,
         )
         conv_cache_shape = (self.max_batch_size, intermediate_size, conv_kernel_size)
@@ -573,7 +573,7 @@ class SambaYFlashAttention2(SambaYAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        use_sliding_windows = self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None
+        use_sliding_windows = self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention"
 
         if past_key_value is not None:
 
@@ -710,8 +710,8 @@ class SambaYFlashAttention2(SambaYAttention):
                 softmax_scale=softmax_scale,
                 causal=causal,
                 window_size=(
-                    self.config.sliding_window[self.layer_idx] - 1,
-                    self.config.sliding_window[self.layer_idx] - 1,
+                    self.config.sliding_window - 1,
+                    self.config.sliding_window - 1,
                 ),
             )
 
@@ -735,8 +735,8 @@ class SambaYFlashAttention2(SambaYAttention):
                 softmax_scale=softmax_scale,
                 causal=causal,
                 window_size=(
-                    self.config.sliding_window[self.layer_idx] - 1,
-                    self.config.sliding_window[self.layer_idx] - 1,
+                    self.config.sliding_window - 1,
+                    self.config.sliding_window - 1,
                 ),
             )
 
@@ -1085,9 +1085,9 @@ class SambaYDecoderLayer(nn.Module):
             residual = residual.to(torch.float32)
             self_attn_weights = None
         else:
-            if self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None and attention_mask is not None: # efficient SDPA and no padding
+            if self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention" and attention_mask is not None: # efficient SDPA and no padding
                 if past_key_value is not None and cache_position[0] > 0: # when decoding
-                    attention_mask = attention_mask[:, -self.config.sliding_window[self.layer_idx]:]
+                    attention_mask = attention_mask[:, -self.config.sliding_window:]
             #hidden_states = self.input_layernorm2(hidden_states.to(dtype=self.input_layernorm2.weight.dtype))
             # Self Attention
             attn_outputs, self_attn_weights, yoco_key_values = self.attn(
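Taken together, the modeling changes derive every per-layer decision from the two config fields instead of indexing a per-layer window list. The loop below is an illustrative standalone sketch of that logic with made-up layer count, cache budget, and mask length; the real code additionally gates on the cache state and the attention implementation.

```python
# Sketch: per-layer sliding-window decisions driven by config.sliding_window + config.layer_types.
import torch

sliding_window = 2047
layer_types = ["full_attention", "sliding_attention", "full_attention", "sliding_attention"]
max_cache_len = 4096  # example generation budget

# Cache sizing (as in _get_cache / SambaYCache): a plain int now works with min().
sliding_cache_len = min(sliding_window, max_cache_len)  # 2047

for layer_idx, layer_type in enumerate(layer_types):
    use_sliding_windows = sliding_window is not None and layer_type == "sliding_attention"

    # flash-attn takes a (left, right) window; the calls above pass sliding_window - 1
    # for both, while (-1, -1) disables windowing on full-attention layers.
    window_size = (sliding_window - 1, sliding_window - 1) if use_sliding_windows else (-1, -1)

    # During decoding, the SDPA path keeps only the last sliding_window mask columns.
    attention_mask = torch.ones(1, 5000, dtype=torch.bool)  # example padding mask
    if use_sliding_windows:
        attention_mask = attention_mask[:, -sliding_window:]

    print(layer_idx, layer_type, window_size, attention_mask.shape[-1])
```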