DreamGenX
commited on
Respect sliding_window=None (#1214)
Browse files
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
CHANGED
|
@@ -94,7 +94,7 @@ def _prepare_decoder_attention_mask(
|
|
| 94 |
sliding_window,
|
| 95 |
): # pylint: disable=unused-argument
|
| 96 |
# [bsz, seq_len]
|
| 97 |
-
if attention_mask is None:
|
| 98 |
return attention_mask
|
| 99 |
|
| 100 |
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
|
@@ -151,7 +151,7 @@ def flashattn_forward(
|
|
| 151 |
)
|
| 152 |
|
| 153 |
use_sliding_windows = (
|
| 154 |
-
|
| 155 |
and kv_seq_len > self.config.sliding_window
|
| 156 |
)
|
| 157 |
|
|
|
|
| 94 |
sliding_window,
|
| 95 |
): # pylint: disable=unused-argument
|
| 96 |
# [bsz, seq_len]
|
| 97 |
+
if attention_mask is None or sliding_window is None:
|
| 98 |
return attention_mask
|
| 99 |
|
| 100 |
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
|
|
|
| 151 |
)
|
| 152 |
|
| 153 |
use_sliding_windows = (
|
| 154 |
+
getattr(self.config, "sliding_window") is not None
|
| 155 |
and kv_seq_len > self.config.sliding_window
|
| 156 |
)
|
| 157 |
|