make mlflow optional (#1317)
* make mlflow optional
* fix xformers
  * don't patch SwiGLU if xformers isn't working
  * fix the check for xformers SwiGLU
* fix install of xformers with an extra-index-url for docker builds
* fix docker build arg quoting
- .github/workflows/main.yml +2 -0
- .github/workflows/tests.yml +3 -0
- docker/Dockerfile +3 -2
- docker/Dockerfile-tests +3 -2
- requirements.txt +0 -1
- setup.py +3 -0
- src/axolotl/core/trainer_builder.py +10 -2
- src/axolotl/monkeypatch/llama_attn_hijack_flash.py +12 -0
- src/axolotl/utils/{callbacks.py → callbacks/__init__.py} +1 -30
- src/axolotl/utils/callbacks/mlflow_.py +44 -0
- src/axolotl/utils/models.py +2 -1
- tests/e2e/patched/test_fused_llama.py +3 -3
.github/workflows/main.yml (CHANGED)

```diff
@@ -18,6 +18,7 @@ jobs:
             python_version: "3.10"
             pytorch: 2.1.2
             axolotl_extras:
+            axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
             is_latest: true
           - cuda: 121
             cuda_version: 12.1.0
@@ -54,6 +55,7 @@ jobs:
             BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
             CUDA=${{ matrix.cuda }}
             PYTORCH_VERSION=${{ matrix.pytorch }}
+            AXOLOTL_ARGS=${{ matrix.axolotl_args }}
           file: ./docker/Dockerfile
           push: ${{ github.event_name != 'pull_request' }}
           tags: |
```
.github/workflows/tests.yml (CHANGED)

```diff
@@ -70,6 +70,7 @@ jobs:
             cuda_version: 11.8.0
             python_version: "3.10"
             pytorch: 2.1.2
+            axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
           - cuda: 121
             cuda_version: 12.1.0
             python_version: "3.10"
@@ -87,11 +88,13 @@ jobs:
           # Set up build arguments
           BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
           CUDA="${{ matrix.cuda }}"
+          AXOLOTL_ARGS="${{ matrix.axolotl_args }}"
           PYTORCH_VERSION="${{ matrix.pytorch }}"
           # Build the Docker image
           docker build . \
             --file ./docker/Dockerfile-tests \
             --build-arg BASE_TAG=$BASE_TAG \
+            --build-arg AXOLOTL_ARGS="$AXOLOTL_ARGS" \
             --build-arg CUDA=$CUDA \
             --build-arg GITHUB_REF=$GITHUB_REF \
             --build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
```
docker/Dockerfile (CHANGED)

```diff
@@ -3,6 +3,7 @@ FROM winglian/axolotl-base:$BASE_TAG
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
+ARG AXOLOTL_ARGS=""
 ARG CUDA="118"
 ENV BNB_CUDA_VERSION=$CUDA
 ARG PYTORCH_VERSION="2.0.1"
@@ -20,9 +21,9 @@ WORKDIR /workspace/axolotl
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
     fi
 
 # So we can test the Docker image
```
docker/Dockerfile-tests (CHANGED)

```diff
@@ -3,6 +3,7 @@ FROM winglian/axolotl-base:$BASE_TAG
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
+ARG AXOLOTL_ARGS=""
 ARG CUDA="118"
 ENV BNB_CUDA_VERSION=$CUDA
 ARG PYTORCH_VERSION="2.0.1"
@@ -24,9 +25,9 @@ RUN git fetch origin +$GITHUB_REF && \
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
     fi
 
 # So we can test the Docker image
```
requirements.txt (CHANGED)

```diff
@@ -21,7 +21,6 @@ hf_transfer
 colorama
 numba
 numpy>=1.24.4
-mlflow
 # qlora things
 evaluate==0.4.1
 scipy
```
setup.py (CHANGED)

```diff
@@ -82,5 +82,8 @@ setup(
         "auto-gptq": [
             "auto-gptq==0.5.1",
         ],
+        "mlflow": [
+            "mlflow",
+        ],
     },
 )
```
src/axolotl/core/trainer_builder.py (CHANGED)

```diff
@@ -5,6 +5,7 @@ Builder for the training args and trainer
 
 import abc
 import importlib
+import importlib.util
 import logging
 import math
 import sys
@@ -34,7 +35,6 @@ from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
     LossWatchDogCallback,
-    SaveAxolotlConfigtoMlflowCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
@@ -62,6 +62,10 @@ except ImportError:
 LOG = logging.getLogger("axolotl.core.trainer_builder")
 
 
+def is_mlflow_available():
+    return importlib.util.find_spec("mlflow") is not None
+
+
 def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
     if isinstance(tag_names, str):
         tag_names = [tag_names]
@@ -648,7 +652,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             callbacks.append(
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
             )
-        if self.cfg.use_mlflow:
+        if self.cfg.use_mlflow and is_mlflow_available():
+            from axolotl.utils.callbacks.mlflow_ import (
+                SaveAxolotlConfigtoMlflowCallback,
+            )
+
             callbacks.append(
                 SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
             )
```
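The pattern used here is worth noting: `importlib.util.find_spec` checks whether a module is importable without actually importing it, and the `mlflow` import itself is deferred into the branch that needs it, so the rest of the trainer builder never touches the package. A minimal self-contained sketch of the same pattern, with `report_metric` as a hypothetical helper (not part of axolotl):

```python
import importlib.util


def is_mlflow_available() -> bool:
    # find_spec only consults the import machinery; it does not import
    # mlflow, so this is safe even when the package is not installed.
    return importlib.util.find_spec("mlflow") is not None


def report_metric(key: str, value: float) -> None:
    # Hypothetical helper: only touch mlflow when it is importable,
    # mirroring the guard added in trainer_builder.py.
    if is_mlflow_available():
        import mlflow  # deferred import, resolved only on this path

        mlflow.log_metric(key, value)
    else:
        print(f"mlflow not installed, skipping metric {key}={value}")
```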
src/axolotl/monkeypatch/llama_attn_hijack_flash.py (CHANGED)

```diff
@@ -44,6 +44,18 @@ except ImportError:
 LOG = logging.getLogger("axolotl")
 
 
+def is_xformers_swiglu_available() -> bool:
+    from xformers.ops.common import get_xformers_operator
+
+    try:
+        get_xformers_operator("swiglu_packedw")()
+        return True
+    except RuntimeError as exc:
+        if "No such operator xformers::swiglu_packedw " in str(exc):
+            return False
+        return True
+
+
 def replace_llama_mlp_with_swiglu(model):
     for name, module in model.named_modules():
         if isinstance(module, LlamaMLP):
```
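Background on the check: xformers registers its fused SwiGLU kernel as the custom torch operator `xformers::swiglu_packedw`, and calling the handle returned by `get_xformers_operator` raises a `RuntimeError` naming the missing operator when the installed build lacks that kernel; any other `RuntimeError` means the operator exists but was invoked without arguments, hence the `return True` fallback. A rough sketch of the same probe written against torch's operator registry directly (`has_custom_op` is our illustration, not an API of either library):

```python
import torch


def has_custom_op(qualified_name: str) -> bool:
    """Return True if a custom operator such as
    'xformers::swiglu_packedw' is registered with torch. Sketch only."""
    namespace, op_name = qualified_name.split("::")
    try:
        # Attribute access on torch.ops resolves operators lazily and
        # fails (RuntimeError or AttributeError, depending on the torch
        # version) when the operator was never registered.
        getattr(getattr(torch.ops, namespace), op_name)
        return True
    except (RuntimeError, AttributeError):
        return False


if __name__ == "__main__":
    print(has_custom_op("xformers::swiglu_packedw"))
```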
src/axolotl/utils/{callbacks.py → callbacks/__init__.py} (RENAMED)

```diff
@@ -9,7 +9,6 @@ from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Dict, List
 
 import evaluate
-import mlflow
 import numpy as np
 import pandas as pd
 import torch
@@ -42,8 +41,8 @@ from axolotl.utils.distributed import (
 if TYPE_CHECKING:
     from axolotl.core.trainer_builder import AxolotlTrainingArguments
 
-LOG = logging.getLogger("axolotl.callbacks")
 IGNORE_INDEX = -100
+LOG = logging.getLogger("axolotl.callbacks")
 
 
 class EvalFirstStepCallback(
@@ -756,31 +755,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
         except (FileNotFoundError, ConnectionError) as err:
             LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
         return control
-
-
-class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
-    """Callback to save axolotl config to mlflow"""
-
-    def __init__(self, axolotl_config_path):
-        self.axolotl_config_path = axolotl_config_path
-
-    def on_train_begin(
-        self,
-        args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
-        state: TrainerState,  # pylint: disable=unused-argument
-        control: TrainerControl,
-        **kwargs,  # pylint: disable=unused-argument
-    ):
-        if is_main_process():
-            try:
-                with NamedTemporaryFile(
-                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
-                ) as temp_file:
-                    copyfile(self.axolotl_config_path, temp_file.name)
-                    mlflow.log_artifact(temp_file.name, artifact_path="")
-                    LOG.info(
-                        "The Axolotl config has been saved to the MLflow artifacts."
-                    )
-            except (FileNotFoundError, ConnectionError) as err:
-                LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
-        return control
```
src/axolotl/utils/callbacks/mlflow_.py (ADDED)

```diff
@@ -0,0 +1,44 @@
+"""MLFlow module for trainer callbacks"""
+import logging
+from shutil import copyfile
+from tempfile import NamedTemporaryFile
+from typing import TYPE_CHECKING
+
+import mlflow
+from transformers import TrainerCallback, TrainerControl, TrainerState
+
+from axolotl.utils.distributed import is_main_process
+
+if TYPE_CHECKING:
+    from axolotl.core.trainer_builder import AxolotlTrainingArguments
+
+LOG = logging.getLogger("axolotl.callbacks")
+
+
+class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
+    # pylint: disable=duplicate-code
+    """Callback to save axolotl config to mlflow"""
+
+    def __init__(self, axolotl_config_path):
+        self.axolotl_config_path = axolotl_config_path
+
+    def on_train_begin(
+        self,
+        args: "AxolotlTrainingArguments",  # pylint: disable=unused-argument
+        state: TrainerState,  # pylint: disable=unused-argument
+        control: TrainerControl,
+        **kwargs,  # pylint: disable=unused-argument
+    ):
+        if is_main_process():
+            try:
+                with NamedTemporaryFile(
+                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
+                ) as temp_file:
+                    copyfile(self.axolotl_config_path, temp_file.name)
+                    mlflow.log_artifact(temp_file.name, artifact_path="")
+                    LOG.info(
+                        "The Axolotl config has been saved to the MLflow artifacts."
+                    )
+            except (FileNotFoundError, ConnectionError) as err:
+                LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
+        return control
```
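Because the callback only uses the standard `TrainerCallback` hook signature, it can be exercised outside a full training run. A minimal sketch, assuming mlflow is installed and an axolotl config exists at `config.yml` (the path and tracking URI here are illustrative, not part of the PR):

```python
import mlflow

from axolotl.utils.callbacks.mlflow_ import SaveAxolotlConfigtoMlflowCallback

# Illustrative only: point mlflow at a local file store and start a run,
# then fire the hook the HF Trainer would call at the start of training.
mlflow.set_tracking_uri("file:./mlruns")
callback = SaveAxolotlConfigtoMlflowCallback("config.yml")

with mlflow.start_run():
    # args and state are unused by this callback, so None is acceptable
    # here even though a real Trainer would pass populated objects.
    callback.on_train_begin(args=None, state=None, control=None)
```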
src/axolotl/utils/models.py (CHANGED)

```diff
@@ -512,11 +512,12 @@ def load_model(
 
     if cfg.flash_attention and not inference:
         from axolotl.monkeypatch.llama_attn_hijack_flash import (
+            is_xformers_swiglu_available,
             replace_llama_mlp_with_swiglu,
             replace_llama_qkv_with_fused,
         )
 
-        if cfg.flash_attn_fuse_mlp:
+        if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
             LOG.info("patching with SwiGLU")
             replace_llama_mlp_with_swiglu(model)
 
```
tests/e2e/patched/test_fused_llama.py (CHANGED)

```diff
@@ -57,9 +57,9 @@ class TestFusedLlama(unittest.TestCase):
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
-                "max_steps":
-                "save_steps":
-                "eval_steps":
+                "max_steps": 10,
+                "save_steps": 5,
+                "eval_steps": 5,
             }
         )
         if is_torch_bf16_gpu_available():
```

(The removed values were truncated in the rendered diff and are left as shown.)