renll committed on
Commit 88bd2a6 · verified · 1 Parent(s): 1cbf1b7

Make `configuration_phi4flash.py` and `modeling_phi4flash.py` compatible with standard sliding window config (#7)


- `sliding_window: list[Optional[int]]` -> `sliding_window: int` + `layer_types: list[str]` (8928fa33c0a966cf0822f7601ec722d38dd61285)
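
For reference, a minimal sketch of the schema change (layer count and window value are illustrative; the sliding pattern mirrors the default derived in the diff below):

    num_hidden_layers = 8  # illustrative

    # Old schema: one Optional[int] per layer, None meaning full attention.
    old_sliding_window = [
        2047 if i < num_hidden_layers // 2 and i % 2 == 1 else None
        for i in range(num_hidden_layers)
    ]

    # New schema: a single window size plus a per-layer attention-type tag.
    sliding_window = 2047
    layer_types = [
        "sliding_attention" if i < num_hidden_layers // 2 and i % 2 == 1 else "full_attention"
        for i in range(num_hidden_layers)
    ]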

configuration_phi4flash.py CHANGED
@@ -112,6 +112,7 @@ class Phi4FlashConfig(PretrainedConfig):
         bos_token_id=1,
         eos_token_id=2,
         sliding_window=2047,
+        layer_types=None,
         mb_per_layer= 2,
         mamba_d_state=16,
         mamba_d_conv=4,
@@ -141,11 +142,16 @@ class Phi4FlashConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.mb_per_layer = mb_per_layer
-        self.sliding_window = [
-            sliding_window if layer_idx < num_hidden_layers // 2 and layer_idx % 2 == 1 else None
-            for layer_idx in range(num_hidden_layers)
-        ]
+        self.sliding_window = sliding_window
+        self.layer_types = layer_types

+        if self.layer_types is None:
+            is_sliding = lambda i: i < num_hidden_layers // 2 and i % 2 == 1
+            self.layer_types = [
+                "sliding_attention" if is_sliding(layer_idx) else "full_attention"
+                for layer_idx in range(num_hidden_layers)
+            ]
+
         self.mamba_d_state = mamba_d_state
         self.mamba_d_conv = mamba_d_conv
         self.mamba_expand = mamba_expand
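
A hypothetical usage sketch of the updated config (assuming Phi4FlashConfig is imported from this repo's configuration_phi4flash.py and that the remaining constructor arguments keep their defaults):

    from configuration_phi4flash import Phi4FlashConfig

    config = Phi4FlashConfig(sliding_window=2047)  # layer_types left as None
    # With layer_types unset, the constructor derives the default pattern above:
    # odd-indexed layers in the first half of the stack use sliding attention.
    print(config.sliding_window)   # 2047
    print(config.layer_types[1])   # "sliding_attention" (assuming enough layers)
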
modeling_phi4flash.py CHANGED
@@ -129,7 +129,7 @@ def _get_cache(
         cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache

         if cache_implementation == "sliding_window":
-            max_cache_len = min(self.config.sliding_window[1], max_cache_len)
+            max_cache_len = min(self.config.sliding_window, max_cache_len)

         need_new_cache = (
             not hasattr(self, "_cache")
@@ -243,7 +243,7 @@ class SambaYCache(Cache):
         sliding_cache_shape = (
             self.max_batch_size,
             self.num_key_value_heads,
-            min(config.sliding_window[1], max_cache_len),
+            min(config.sliding_window, max_cache_len),
             self.head_dim,
         )
         conv_cache_shape = (self.max_batch_size, intermediate_size, conv_kernel_size)
@@ -573,7 +573,7 @@ class SambaYFlashAttention2(SambaYAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        use_sliding_windows = self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None
+        use_sliding_windows = self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention"

         if past_key_value is not None:

@@ -710,8 +710,8 @@ class SambaYFlashAttention2(SambaYAttention):
             softmax_scale=softmax_scale,
             causal=causal,
             window_size=(
-                self.config.sliding_window[self.layer_idx] -1,
-                self.config.sliding_window[self.layer_idx] -1,
+                self.config.sliding_window - 1,
+                self.config.sliding_window - 1,
             ),
         )

@@ -735,8 +735,8 @@ class SambaYFlashAttention2(SambaYAttention):
             softmax_scale=softmax_scale,
             causal=causal,
             window_size=(
-                self.config.sliding_window[self.layer_idx] -1,
-                self.config.sliding_window[self.layer_idx] -1,
+                self.config.sliding_window - 1,
+                self.config.sliding_window - 1,
             ),
         )

@@ -1085,9 +1085,9 @@ class SambaYDecoderLayer(nn.Module):
             residual = residual.to(torch.float32)
             self_attn_weights = None
         else:
-            if self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None and attention_mask is not None: # efficient SDPA and no padding
+            if self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention" and attention_mask is not None: # efficient SDPA and no padding
                 if past_key_value is not None and cache_position[0] > 0: # when decoding
-                    attention_mask = attention_mask[:, -self.config.sliding_window[self.layer_idx]:]
+                    attention_mask = attention_mask[:, -self.config.sliding_window:]
             #hidden_states = self.input_layernorm2(hidden_states.to(dtype=self.input_layernorm2.weight.dtype))
             # Self Attention
             attn_outputs, self_attn_weights, yoco_key_values = self.attn(
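
On the modeling side, the per-layer decision now comes from layer_types while the window length always comes from the single sliding_window integer. A hedged sketch of that logic as a standalone helper (not the model's actual code path; it simply mirrors the checks in the diff above):

    def layer_window(config, layer_idx):
        """Return the flash-attn style (left, right) window for a layer, or None for full attention."""
        if config.sliding_window is None:
            return None
        if config.layer_types[layer_idx] != "sliding_attention":
            return None
        # Same value SambaYFlashAttention2 passes as window_size in the diff above.
        return (config.sliding_window - 1, config.sliding_window - 1)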