speed up flash-attn inference
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -16,6 +16,7 @@ from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
 try:
     from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
+        flash_attn_kvpacked_func,
         flash_attn_varlen_kvpacked_func,
         flash_attn_varlen_qkvpacked_func,
     )
@@ -146,7 +147,7 @@ def flashattn_forward(
     else:
         # turn off FA causal mask after first inference autoregressive iteration
         # only on first autoregressive step q,k,v have same seqlen
-        is_causal =
+        is_causal = past_key_value is not None
 
     if self.training and attention_mask.shape[0] == 1:
         # special handling using sample packing
@@ -163,14 +164,20 @@ def flashattn_forward(
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
         qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
-            query_states
-            key_states
-            value_states
+            query_states,
+            key_states,
+            value_states,
             qkvpacked=True,
             # We have disabled _prepare_decoder_attention_mask in LlamaModel
             # the attention_mask should be the same as the key_padding_mask
             key_padding_mask=attention_mask,
+            query_padding_mask=attention_mask[:, -query_states.size(1) :]
+            if attention_mask is not None
+            else None,
         )
         output_unpad = flash_attn_varlen_qkvpacked_func(
             qkv_unpad,
@@ -182,35 +189,48 @@ def flashattn_forward(
         )
         output = output_pad_fn(output_unpad)
     else:
-        (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        if attention_mask is None or attention_mask.all().item():
+            output = flash_attn_kvpacked_func(
+                query_states,
+                torch.stack([key_states, value_states], 2),
+                causal=is_causal,
+            )
+        else:
+            (  # pylint: disable=unbalanced-tuple-unpacking
+                q_unpad,
+                kv_unpad,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                _,
+                _,
+                output_pad_fn,
+            ) = generate_qkv(
+                query_states,
+                key_states,
+                value_states,
+                kvpacked=True,
+                key_padding_mask=attention_mask,
+                query_padding_mask=attention_mask[:, -query_states.size(1) :]
+                if attention_mask is not None
+                else None,
+            )
+            output_unpad = flash_attn_varlen_kvpacked_func(
+                q_unpad,
+                kv_unpad,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+            )
+            output = output_pad_fn(output_unpad)
 
     attn_output = output
     if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
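For readers unfamiliar with the kvpacked interface, the sketch below shows the call pattern the new fast path relies on: a single-token query attending to a cached key/value sequence through flash-attn's dense `flash_attn_kvpacked_func`, with K and V stacked along dim 2 the way `torch.stack([key_states, value_states], 2)` does in the patch. This is an illustration only, not part of the change; the tensor sizes (`bsz`, `kv_len`, `n_heads`, `head_dim`) and the `causal=False` setting are assumptions for the example, and it needs a CUDA GPU with flash-attn 2 installed.

```python
# Illustrative sketch (not part of the patch): exercising the dense kvpacked
# kernel the way the new inference fast path does when no padding is present.
# All sizes below are arbitrary assumptions for the example.
import torch
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func

bsz, kv_len, n_heads, head_dim = 2, 64, 32, 128

# flash-attn expects (batch, seqlen, nheads, headdim), i.e. sequence length
# before heads, which is why the patch transposes dims 1 and 2 before the call.
q = torch.randn(bsz, 1, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, kv_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, kv_len, n_heads, head_dim, device="cuda", dtype=torch.float16)

# Pack K and V along a new dim 2, giving kv shape (batch, kv_len, 2, nheads, headdim),
# mirroring torch.stack([key_states, value_states], 2) in the patch.
kv = torch.stack([k, v], dim=2)

# The patch forwards its is_causal flag here; False is used in this sketch
# purely to show the non-causal decode-step case.
out = flash_attn_kvpacked_func(q, kv, causal=False)
assert out.shape == (bsz, 1, n_heads, head_dim)
```

Compared with the varlen path, this dense call skips the unpad/re-pad round trip and the cu_seqlens bookkeeping, which is presumably where most of the inference speedup comes from when the batch has no padding.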