use flash_attn xentropy when available (#525)
* use flash_attn xentropy when available
* log when xentropy is not found
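For context, a minimal sketch of the idea behind the change (illustrative only, not part of the diff): flash-attn's fused CrossEntropyLoss follows the torch.nn.CrossEntropyLoss call convention, so it can stand in for the stock loss whenever the xentropy CUDA extension is installed, with a fallback to the stock loss otherwise. The toy shapes, vocabulary size, CUDA device, and fp16 dtype below are assumptions for the example.

# Illustrative sketch only: the fused loss mirrors torch.nn.CrossEntropyLoss's
# call convention, which is why rebinding the symbol inside transformers'
# modeling_llama module (as the diff below does) is enough for LlamaForCausalLM
# to pick it up.
import torch

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss

    # inplace_backward=True lets the backward kernel overwrite the logits buffer,
    # saving activation memory on large vocabularies.
    loss_fct = CrossEntropyLoss(inplace_backward=True)
except ImportError:
    loss_fct = torch.nn.CrossEntropyLoss()  # fallback when xentropy_cuda_lib is absent

# Assumed toy inputs: (num_tokens, vocab_size) fp16 logits and integer labels on GPU.
logits = torch.randn(16, 32000, device="cuda", dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, 32000, (16,), device="cuda")
loss = loss_fct(logits, labels)
loss.backward()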
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -2,7 +2,9 @@
 
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
 
+import logging
 import warnings
+from functools import partial
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -33,6 +35,9 @@ except ImportError:
     )
 
 
+LOG = logging.getLogger("axolotl")
+
+
 def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
@@ -44,6 +49,18 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
         llama_model_forward
     )
 
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+        LOG.info("patching with flash_attn.losses.cross_entropy")
+        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+            CrossEntropyLoss, inplace_backward=True
+        )
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention CrossEntropyLoss not found (run `pip install git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy`)"
+        )
+
 
     # Disable the transformation of the attention mask in LlamaModel as the flash attention
     # requires the attention mask to be the same as the key_padding_mask