adds llama and mistral dropout support (#858)

* adds llama and mistral dropout support
* gracefully handle attention dropout if not available yet
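
Both patches gate dropout on `self.training` and fall back to `0.0` when the attention module does not expose `attention_dropout` yet. A minimal sketch of that behavior (the `DummyAttention` class below is purely illustrative and not part of the patch):

```python
# Minimal sketch of the dropout-rate selection both patches add.
# `DummyAttention` stands in for the patched attention module; real modules
# may not define `attention_dropout` yet, which is why getattr falls back to 0.0.
import torch.nn as nn


class DummyAttention(nn.Module):
    def __init__(self, attention_dropout=None):
        super().__init__()
        if attention_dropout is not None:
            self.attention_dropout = attention_dropout

    def dropout_rate(self):
        # same expression as the one added to flashattn_forward
        return 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)


attn = DummyAttention(attention_dropout=0.1)
attn.train()
assert attn.dropout_rate() == 0.1   # dropout applied while training
attn.eval()
assert attn.dropout_rate() == 0.0   # never applied at inference

legacy = DummyAttention()            # module without attention_dropout yet
legacy.train()
assert legacy.dropout_rate() == 0.0  # graceful fallback
```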
src/axolotl/monkeypatch/llama_attn_hijack_flash.py

```diff
@@ -321,6 +321,8 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
+    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
+
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -330,7 +332,12 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
+            qkv,
+            cu_seqlens,
+            max_seqlen,
+            dropout_p=dropout_rate,
+            softmax_scale=None,
+            causal=True,
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
@@ -353,7 +360,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=is_causal,
         )
@@ -366,6 +373,7 @@ def flashattn_forward(
             output = flash_attn_kvpacked_func(
                 query_states,
                 torch.stack([key_states, value_states], 2),
+                dropout_p=dropout_rate,
                 causal=is_causal,
             )
         else:
@@ -398,7 +406,7 @@ def flashattn_forward(
                 cu_seqlens_k,
                 max_seqlen_q,
                 max_seqlen_k,
-                0.0,
+                dropout_p=dropout_rate,
                 softmax_scale=None,
                 causal=is_causal,
             )
```
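For reference, a rough sketch of how the threaded `dropout_p` reaches flash-attn's varlen kernel in the sample-packing branch. It assumes flash-attn v2 and a CUDA device, and the tensor shapes are illustrative rather than taken from axolotl:

```python
# Hedged sketch of the sample-packing call path with the new dropout_p argument.
import torch
from flash_attn import flash_attn_varlen_qkvpacked_func

nheads, headdim = 8, 64
seqlens = [5, 3, 7]                          # three samples packed into one sequence
total = sum(seqlens)
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32, device="cuda")
max_seqlen = max(seqlens)

# packed qkv: (total_tokens, 3, nheads, headdim), fp16 on GPU
qkv = torch.randn(total, 3, nheads, headdim, dtype=torch.float16, device="cuda")

dropout_rate = 0.1                           # would come from the expression above
output = flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=dropout_rate,                  # the argument this PR threads through
    softmax_scale=None,
    causal=True,
)
print(output.shape)                          # (total, nheads, headdim)
```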
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py

```diff
@@ -201,6 +201,8 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
+    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
+
     if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
@@ -213,7 +215,7 @@ def flashattn_forward(
             qkv,
             cu_seqlens,
             max_seqlen,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=True,
             window_size=window_size,
@@ -239,7 +241,7 @@ def flashattn_forward(
             qkv_unpad,
             cu_seqlens_q,
             max_seqlen_q,
-            0.0,
+            dropout_p=dropout_rate,
             softmax_scale=None,
             causal=is_causal,
             window_size=window_size,
@@ -253,6 +255,7 @@ def flashattn_forward(
             output = flash_attn_kvpacked_func(
                 query_states,
                 torch.stack([key_states, value_states], 2),
+                dropout_p=dropout_rate,
                 causal=is_causal,
                 window_size=window_size,
             )
@@ -286,7 +289,7 @@ def flashattn_forward(
                 cu_seqlens_k,
                 max_seqlen_q,
                 max_seqlen_k,
-                0.0,
+                dropout_p=dropout_rate,
                 softmax_scale=None,
                 causal=is_causal,
                 window_size=window_size,
```
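The mistral patch threads the same `dropout_p` through calls that also carry the sliding-window argument. A hedged sketch of the kv-packed variant (assumes a flash-attn version new enough to accept `window_size`, roughly >= 2.3, and a CUDA device; all values are illustrative, not taken from the patch):

```python
# Illustrative only: dropout_p plus the sliding-window argument on the
# kv-packed code path used when no packing/padding handling is needed.
import torch
from flash_attn import flash_attn_kvpacked_func

bsz, q_len, kv_len, nheads, headdim = 2, 16, 16, 8, 64
query_states = torch.randn(bsz, q_len, nheads, headdim, dtype=torch.float16, device="cuda")
# kv stacked on dim 2, mirroring torch.stack([key_states, value_states], 2)
kv = torch.randn(bsz, kv_len, 2, nheads, headdim, dtype=torch.float16, device="cuda")

dropout_rate = 0.1                     # from the same training/getattr expression
window_size = (4096, 4096)             # illustrative sliding-window width

output = flash_attn_kvpacked_func(
    query_states,
    kv,
    dropout_p=dropout_rate,            # new in this PR
    causal=True,
    window_size=window_size,
)
print(output.shape)                    # (bsz, q_len, nheads, headdim)
```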