bump transformers and update attention class map name (#1023)
Browse files* bump transformers and update attention class map name
* also run the tests in docker
* add mixtral e2e smoke test
* fix base name for docker image in test
* mixtral lora doesn't seem to work, at least check qlora
* add testcase for mixtral w sample packing
* check monkeypatch for flash attn multipack
* also run the e2e tests in docker
* use all gpus to run tests in docker ci
* use privileged mode too for docker w gpus
* rename the docker e2e actions for gh ci
* set privileged mode for docker and update mixtral model self attn check
* use fp16/bf16 for mixtral w fa2
* skip e2e tests on docker w gpus for now
* tests to validate mistral and mixtral patches
* fix rel import
- .github/workflows/tests-docker.yml +62 -0
- requirements.txt +1 -1
- src/axolotl/monkeypatch/mixtral/__init__.py +1 -1
- src/axolotl/monkeypatch/mixtral/modeling_mixtral.py +6 -2
- src/axolotl/utils/models.py +3 -0
- tests/e2e/test_mixtral.py +109 -0
- tests/e2e/test_mixtral_samplepack.py +123 -0
- tests/e2e/test_model_patches.py +99 -0
.github/workflows/tests-docker.yml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: e2e-docker-tests
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
pull_request:
|
| 5 |
+
paths:
|
| 6 |
+
- '**.py'
|
| 7 |
+
- 'requirements.txt'
|
| 8 |
+
workflow_dispatch:
|
| 9 |
+
|
| 10 |
+
jobs:
|
| 11 |
+
build-axolotl:
|
| 12 |
+
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
| 13 |
+
# this job needs to be run on self-hosted GPU runners...
|
| 14 |
+
strategy:
|
| 15 |
+
fail-fast: false
|
| 16 |
+
matrix:
|
| 17 |
+
include:
|
| 18 |
+
- cuda: 118
|
| 19 |
+
cuda_version: 11.8.0
|
| 20 |
+
python_version: "3.10"
|
| 21 |
+
pytorch: 2.0.1
|
| 22 |
+
axolotl_extras:
|
| 23 |
+
is_latest: true
|
| 24 |
+
- cuda: 121
|
| 25 |
+
cuda_version: 12.1.0
|
| 26 |
+
python_version: "3.10"
|
| 27 |
+
pytorch: 2.1.1
|
| 28 |
+
axolotl_extras:
|
| 29 |
+
runs-on: [self-hosted, gpu, docker]
|
| 30 |
+
steps:
|
| 31 |
+
- name: Checkout
|
| 32 |
+
uses: actions/checkout@v4
|
| 33 |
+
- name: Docker metadata
|
| 34 |
+
id: metadata
|
| 35 |
+
uses: docker/metadata-action@v5
|
| 36 |
+
with:
|
| 37 |
+
images: winglian/axolotl
|
| 38 |
+
- name: Set up Docker Buildx
|
| 39 |
+
uses: docker/setup-buildx-action@v3
|
| 40 |
+
- name: Login to Docker Hub
|
| 41 |
+
uses: docker/login-action@v3
|
| 42 |
+
with:
|
| 43 |
+
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
| 44 |
+
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
| 45 |
+
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
|
| 46 |
+
- name: Build and export to Docker
|
| 47 |
+
uses: docker/build-push-action@v5
|
| 48 |
+
with:
|
| 49 |
+
context: .
|
| 50 |
+
load: true
|
| 51 |
+
build-args: |
|
| 52 |
+
BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
| 53 |
+
CUDA=${{ matrix.cuda }}
|
| 54 |
+
PYTORCH_VERSION=${{ matrix.pytorch }}
|
| 55 |
+
file: ./docker/Dockerfile
|
| 56 |
+
tags: |
|
| 57 |
+
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
| 58 |
+
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
| 59 |
+
labels: ${{ steps.metadata.outputs.labels }}
|
| 60 |
+
- name: Unit Tests
|
| 61 |
+
run: |
|
| 62 |
+
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
requirements.txt
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
auto-gptq==0.5.1
|
| 3 |
packaging
|
| 4 |
peft==0.6.0
|
| 5 |
-
transformers
|
| 6 |
tokenizers==0.15.0
|
| 7 |
bitsandbytes>=0.41.1
|
| 8 |
accelerate==0.24.1
|
|
|
|
| 2 |
auto-gptq==0.5.1
|
| 3 |
packaging
|
| 4 |
peft==0.6.0
|
| 5 |
+
transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
|
| 6 |
tokenizers==0.15.0
|
| 7 |
bitsandbytes>=0.41.1
|
| 8 |
accelerate==0.24.1
|
src/axolotl/monkeypatch/mixtral/__init__.py
CHANGED
|
@@ -17,6 +17,6 @@ def replace_mixtral_attn_with_multipack_flash_attn():
|
|
| 17 |
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
|
| 18 |
mixtral_model_forward
|
| 19 |
)
|
| 20 |
-
transformers.models.mixtral.modeling_mixtral.
|
| 21 |
"flash_attention_2"
|
| 22 |
] = MixtralMultipackFlashAttention2
|
|
|
|
| 17 |
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
|
| 18 |
mixtral_model_forward
|
| 19 |
)
|
| 20 |
+
transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
|
| 21 |
"flash_attention_2"
|
| 22 |
] = MixtralMultipackFlashAttention2
|
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
CHANGED
|
@@ -261,7 +261,11 @@ def mixtral_model_forward(
|
|
| 261 |
if inputs_embeds is None:
|
| 262 |
inputs_embeds = self.embed_tokens(input_ids)
|
| 263 |
|
| 264 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
| 266 |
if is_padding_right:
|
| 267 |
raise ValueError(
|
|
@@ -270,7 +274,7 @@ def mixtral_model_forward(
|
|
| 270 |
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
| 271 |
)
|
| 272 |
|
| 273 |
-
if self.
|
| 274 |
# 2d mask is passed through the layers
|
| 275 |
attention_mask = (
|
| 276 |
attention_mask
|
|
|
|
| 261 |
if inputs_embeds is None:
|
| 262 |
inputs_embeds = self.embed_tokens(input_ids)
|
| 263 |
|
| 264 |
+
if (
|
| 265 |
+
attention_mask is not None
|
| 266 |
+
and self._attn_implementation == "flash_attention_2"
|
| 267 |
+
and use_cache
|
| 268 |
+
):
|
| 269 |
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
| 270 |
if is_padding_right:
|
| 271 |
raise ValueError(
|
|
|
|
| 274 |
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
| 275 |
)
|
| 276 |
|
| 277 |
+
if self._attn_implementation == "flash_attention_2":
|
| 278 |
# 2d mask is passed through the layers
|
| 279 |
attention_mask = (
|
| 280 |
attention_mask
|
src/axolotl/utils/models.py
CHANGED
|
@@ -332,15 +332,18 @@ def load_model(
|
|
| 332 |
or cfg.is_mistral_derived_model
|
| 333 |
or model_config.model_type == "mixtral"
|
| 334 |
):
|
|
|
|
| 335 |
model_config._attn_implementation = ( # pylint: disable=protected-access
|
| 336 |
"flash_attention_2"
|
| 337 |
)
|
| 338 |
else:
|
| 339 |
if model_config.model_type == "mixtral":
|
|
|
|
| 340 |
model_config._attn_implementation = ( # pylint: disable=protected-access
|
| 341 |
"flash_attention_2"
|
| 342 |
)
|
| 343 |
else:
|
|
|
|
| 344 |
model_config._attn_implementation = ( # pylint: disable=protected-access
|
| 345 |
"eager"
|
| 346 |
)
|
|
|
|
| 332 |
or cfg.is_mistral_derived_model
|
| 333 |
or model_config.model_type == "mixtral"
|
| 334 |
):
|
| 335 |
+
model_kwargs["attn_implementation"] = "flash_attention_2"
|
| 336 |
model_config._attn_implementation = ( # pylint: disable=protected-access
|
| 337 |
"flash_attention_2"
|
| 338 |
)
|
| 339 |
else:
|
| 340 |
if model_config.model_type == "mixtral":
|
| 341 |
+
model_kwargs["attn_implementation"] = "flash_attention_2"
|
| 342 |
model_config._attn_implementation = ( # pylint: disable=protected-access
|
| 343 |
"flash_attention_2"
|
| 344 |
)
|
| 345 |
else:
|
| 346 |
+
model_kwargs["attn_implementation"] = "eager"
|
| 347 |
model_config._attn_implementation = ( # pylint: disable=protected-access
|
| 348 |
"eager"
|
| 349 |
)
|
tests/e2e/test_mixtral.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
E2E tests for mixtral
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import unittest
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from transformers.utils import is_torch_bf16_gpu_available
|
| 11 |
+
|
| 12 |
+
from axolotl.cli import load_datasets
|
| 13 |
+
from axolotl.common.cli import TrainerCliArgs
|
| 14 |
+
from axolotl.train import train
|
| 15 |
+
from axolotl.utils.config import normalize_config
|
| 16 |
+
from axolotl.utils.dict import DictDefault
|
| 17 |
+
|
| 18 |
+
from .utils import with_temp_dir
|
| 19 |
+
|
| 20 |
+
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 21 |
+
os.environ["WANDB_DISABLED"] = "true"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TestMixtral(unittest.TestCase):
|
| 25 |
+
"""
|
| 26 |
+
Test case for Llama models using LoRA
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
@with_temp_dir
|
| 30 |
+
def test_qlora(self, temp_dir):
|
| 31 |
+
# pylint: disable=duplicate-code
|
| 32 |
+
cfg = DictDefault(
|
| 33 |
+
{
|
| 34 |
+
"base_model": "hf-internal-testing/Mixtral-tiny",
|
| 35 |
+
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
|
| 36 |
+
"flash_attention": True,
|
| 37 |
+
"sequence_len": 1024,
|
| 38 |
+
"load_in_4bit": True,
|
| 39 |
+
"adapter": "qlora",
|
| 40 |
+
"lora_r": 16,
|
| 41 |
+
"lora_alpha": 32,
|
| 42 |
+
"lora_dropout": 0.1,
|
| 43 |
+
"lora_target_linear": True,
|
| 44 |
+
"val_set_size": 0.1,
|
| 45 |
+
"special_tokens": {},
|
| 46 |
+
"datasets": [
|
| 47 |
+
{
|
| 48 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
| 49 |
+
"type": "alpaca",
|
| 50 |
+
},
|
| 51 |
+
],
|
| 52 |
+
"num_epochs": 2,
|
| 53 |
+
"micro_batch_size": 2,
|
| 54 |
+
"gradient_accumulation_steps": 1,
|
| 55 |
+
"output_dir": temp_dir,
|
| 56 |
+
"learning_rate": 0.00001,
|
| 57 |
+
"optimizer": "adamw_bnb_8bit",
|
| 58 |
+
"lr_scheduler": "cosine",
|
| 59 |
+
"max_steps": 20,
|
| 60 |
+
"save_steps": 10,
|
| 61 |
+
"eval_steps": 10,
|
| 62 |
+
}
|
| 63 |
+
)
|
| 64 |
+
normalize_config(cfg)
|
| 65 |
+
cli_args = TrainerCliArgs()
|
| 66 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 67 |
+
|
| 68 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 69 |
+
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
| 70 |
+
|
| 71 |
+
@with_temp_dir
|
| 72 |
+
def test_ft(self, temp_dir):
|
| 73 |
+
# pylint: disable=duplicate-code
|
| 74 |
+
cfg = DictDefault(
|
| 75 |
+
{
|
| 76 |
+
"base_model": "hf-internal-testing/Mixtral-tiny",
|
| 77 |
+
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
|
| 78 |
+
"flash_attention": True,
|
| 79 |
+
"sequence_len": 1024,
|
| 80 |
+
"val_set_size": 0.1,
|
| 81 |
+
"special_tokens": {},
|
| 82 |
+
"datasets": [
|
| 83 |
+
{
|
| 84 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
| 85 |
+
"type": "alpaca",
|
| 86 |
+
},
|
| 87 |
+
],
|
| 88 |
+
"num_epochs": 2,
|
| 89 |
+
"micro_batch_size": 2,
|
| 90 |
+
"gradient_accumulation_steps": 1,
|
| 91 |
+
"output_dir": temp_dir,
|
| 92 |
+
"learning_rate": 0.00001,
|
| 93 |
+
"optimizer": "adamw_bnb_8bit",
|
| 94 |
+
"lr_scheduler": "cosine",
|
| 95 |
+
"max_steps": 20,
|
| 96 |
+
"save_steps": 10,
|
| 97 |
+
"eval_steps": 10,
|
| 98 |
+
}
|
| 99 |
+
)
|
| 100 |
+
if is_torch_bf16_gpu_available():
|
| 101 |
+
cfg.bf16 = True
|
| 102 |
+
else:
|
| 103 |
+
cfg.fp16 = True
|
| 104 |
+
normalize_config(cfg)
|
| 105 |
+
cli_args = TrainerCliArgs()
|
| 106 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 107 |
+
|
| 108 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 109 |
+
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
tests/e2e/test_mixtral_samplepack.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
E2E tests for mixtral
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import unittest
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from transformers.utils import is_torch_bf16_gpu_available
|
| 11 |
+
|
| 12 |
+
from axolotl.cli import load_datasets
|
| 13 |
+
from axolotl.common.cli import TrainerCliArgs
|
| 14 |
+
from axolotl.train import train
|
| 15 |
+
from axolotl.utils.config import normalize_config
|
| 16 |
+
from axolotl.utils.dict import DictDefault
|
| 17 |
+
|
| 18 |
+
from .utils import with_temp_dir
|
| 19 |
+
|
| 20 |
+
LOG = logging.getLogger("axolotl.tests.e2e")
|
| 21 |
+
os.environ["WANDB_DISABLED"] = "true"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TestMixtral(unittest.TestCase):
|
| 25 |
+
"""
|
| 26 |
+
Test case for Llama models using LoRA
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
@with_temp_dir
|
| 30 |
+
def test_qlora(self, temp_dir):
|
| 31 |
+
# pylint: disable=duplicate-code
|
| 32 |
+
cfg = DictDefault(
|
| 33 |
+
{
|
| 34 |
+
"base_model": "hf-internal-testing/Mixtral-tiny",
|
| 35 |
+
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
|
| 36 |
+
"flash_attention": True,
|
| 37 |
+
"sequence_len": 2048,
|
| 38 |
+
"load_in_4bit": True,
|
| 39 |
+
"adapter": "qlora",
|
| 40 |
+
"lora_r": 16,
|
| 41 |
+
"lora_alpha": 32,
|
| 42 |
+
"lora_dropout": 0.1,
|
| 43 |
+
"lora_target_linear": True,
|
| 44 |
+
"val_set_size": 0.1,
|
| 45 |
+
"special_tokens": {},
|
| 46 |
+
"datasets": [
|
| 47 |
+
{
|
| 48 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
| 49 |
+
"type": "alpaca",
|
| 50 |
+
},
|
| 51 |
+
],
|
| 52 |
+
"num_epochs": 2,
|
| 53 |
+
"micro_batch_size": 2,
|
| 54 |
+
"gradient_accumulation_steps": 1,
|
| 55 |
+
"output_dir": temp_dir,
|
| 56 |
+
"learning_rate": 0.00001,
|
| 57 |
+
"optimizer": "adamw_bnb_8bit",
|
| 58 |
+
"lr_scheduler": "cosine",
|
| 59 |
+
"max_steps": 20,
|
| 60 |
+
"save_steps": 10,
|
| 61 |
+
"eval_steps": 10,
|
| 62 |
+
"sample_packing": True,
|
| 63 |
+
}
|
| 64 |
+
)
|
| 65 |
+
if is_torch_bf16_gpu_available():
|
| 66 |
+
cfg.bf16 = True
|
| 67 |
+
else:
|
| 68 |
+
cfg.fp16 = True
|
| 69 |
+
normalize_config(cfg)
|
| 70 |
+
cli_args = TrainerCliArgs()
|
| 71 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 72 |
+
|
| 73 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 74 |
+
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
| 75 |
+
|
| 76 |
+
@with_temp_dir
|
| 77 |
+
def test_ft(self, temp_dir):
|
| 78 |
+
# pylint: disable=duplicate-code
|
| 79 |
+
cfg = DictDefault(
|
| 80 |
+
{
|
| 81 |
+
"base_model": "hf-internal-testing/Mixtral-tiny",
|
| 82 |
+
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
|
| 83 |
+
"flash_attention": True,
|
| 84 |
+
"sequence_len": 2048,
|
| 85 |
+
"val_set_size": 0.1,
|
| 86 |
+
"special_tokens": {},
|
| 87 |
+
"datasets": [
|
| 88 |
+
{
|
| 89 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
| 90 |
+
"type": "alpaca",
|
| 91 |
+
},
|
| 92 |
+
],
|
| 93 |
+
"num_epochs": 2,
|
| 94 |
+
"micro_batch_size": 2,
|
| 95 |
+
"gradient_accumulation_steps": 1,
|
| 96 |
+
"output_dir": temp_dir,
|
| 97 |
+
"learning_rate": 0.00001,
|
| 98 |
+
"optimizer": "adamw_bnb_8bit",
|
| 99 |
+
"lr_scheduler": "cosine",
|
| 100 |
+
"max_steps": 20,
|
| 101 |
+
"save_steps": 10,
|
| 102 |
+
"eval_steps": 10,
|
| 103 |
+
"sample_packing": True,
|
| 104 |
+
}
|
| 105 |
+
)
|
| 106 |
+
if is_torch_bf16_gpu_available():
|
| 107 |
+
cfg.bf16 = True
|
| 108 |
+
else:
|
| 109 |
+
cfg.fp16 = True
|
| 110 |
+
normalize_config(cfg)
|
| 111 |
+
cli_args = TrainerCliArgs()
|
| 112 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
| 113 |
+
|
| 114 |
+
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
| 115 |
+
assert (
|
| 116 |
+
"axolotl.monkeypatch.mixtral.modeling_mixtral"
|
| 117 |
+
in model.model.layers[0].self_attn.__class__.__module__
|
| 118 |
+
)
|
| 119 |
+
assert (
|
| 120 |
+
"MixtralMultipackFlashAttention2"
|
| 121 |
+
in model.model.layers[0].self_attn.__class__.__name__
|
| 122 |
+
)
|
| 123 |
+
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
tests/e2e/test_model_patches.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
E2E smoke tests to check that the monkeypatches are in place for certain configurations
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import unittest
|
| 6 |
+
|
| 7 |
+
from axolotl.common.cli import TrainerCliArgs
|
| 8 |
+
from axolotl.utils.config import normalize_config
|
| 9 |
+
from axolotl.utils.dict import DictDefault
|
| 10 |
+
from axolotl.utils.models import load_model, load_tokenizer
|
| 11 |
+
|
| 12 |
+
from .utils import with_temp_dir
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TestModelPatches(unittest.TestCase):
|
| 16 |
+
"""
|
| 17 |
+
TestCases for the multipack monkey patches
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
@with_temp_dir
|
| 21 |
+
def test_mixtral_multipack(self, temp_dir):
|
| 22 |
+
cfg = DictDefault(
|
| 23 |
+
{
|
| 24 |
+
"base_model": "hf-internal-testing/Mixtral-tiny",
|
| 25 |
+
"tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
|
| 26 |
+
"flash_attention": True,
|
| 27 |
+
"sample_packing": True,
|
| 28 |
+
"sequence_len": 2048,
|
| 29 |
+
"val_set_size": 0.1,
|
| 30 |
+
"special_tokens": {},
|
| 31 |
+
"datasets": [
|
| 32 |
+
{
|
| 33 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
| 34 |
+
"type": "alpaca",
|
| 35 |
+
},
|
| 36 |
+
],
|
| 37 |
+
"num_epochs": 2,
|
| 38 |
+
"micro_batch_size": 2,
|
| 39 |
+
"gradient_accumulation_steps": 1,
|
| 40 |
+
"output_dir": temp_dir,
|
| 41 |
+
"learning_rate": 0.00001,
|
| 42 |
+
"optimizer": "adamw_bnb_8bit",
|
| 43 |
+
"lr_scheduler": "cosine",
|
| 44 |
+
"max_steps": 20,
|
| 45 |
+
"save_steps": 10,
|
| 46 |
+
"eval_steps": 10,
|
| 47 |
+
}
|
| 48 |
+
)
|
| 49 |
+
normalize_config(cfg)
|
| 50 |
+
cli_args = TrainerCliArgs()
|
| 51 |
+
tokenizer = load_tokenizer(cfg)
|
| 52 |
+
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
| 53 |
+
|
| 54 |
+
assert (
|
| 55 |
+
"axolotl.monkeypatch.mixtral.modeling_mixtral"
|
| 56 |
+
in model.model.layers[0].self_attn.__class__.__module__
|
| 57 |
+
)
|
| 58 |
+
assert (
|
| 59 |
+
"MixtralMultipackFlashAttention2"
|
| 60 |
+
in model.model.layers[0].self_attn.__class__.__name__
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
@with_temp_dir
|
| 64 |
+
def test_mistral_multipack(self, temp_dir):
|
| 65 |
+
cfg = DictDefault(
|
| 66 |
+
{
|
| 67 |
+
"base_model": "openaccess-ai-collective/tiny-mistral",
|
| 68 |
+
"flash_attention": True,
|
| 69 |
+
"sample_packing": True,
|
| 70 |
+
"sequence_len": 2048,
|
| 71 |
+
"val_set_size": 0.1,
|
| 72 |
+
"special_tokens": {},
|
| 73 |
+
"datasets": [
|
| 74 |
+
{
|
| 75 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
| 76 |
+
"type": "alpaca",
|
| 77 |
+
},
|
| 78 |
+
],
|
| 79 |
+
"num_epochs": 2,
|
| 80 |
+
"micro_batch_size": 2,
|
| 81 |
+
"gradient_accumulation_steps": 1,
|
| 82 |
+
"output_dir": temp_dir,
|
| 83 |
+
"learning_rate": 0.00001,
|
| 84 |
+
"optimizer": "adamw_bnb_8bit",
|
| 85 |
+
"lr_scheduler": "cosine",
|
| 86 |
+
"max_steps": 20,
|
| 87 |
+
"save_steps": 10,
|
| 88 |
+
"eval_steps": 10,
|
| 89 |
+
}
|
| 90 |
+
)
|
| 91 |
+
normalize_config(cfg)
|
| 92 |
+
cli_args = TrainerCliArgs()
|
| 93 |
+
tokenizer = load_tokenizer(cfg)
|
| 94 |
+
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
| 95 |
+
|
| 96 |
+
assert (
|
| 97 |
+
"axolotl.monkeypatch.mistral_attn_hijack_flash"
|
| 98 |
+
in model.model.layers[0].self_attn.forward.__module__
|
| 99 |
+
)
|