File size: 9,556 Bytes
b21e4a2 dde02fc b21e4a2 da97285 b21e4a2 501958b 132eb74 b2430ce b21e4a2 da97285 85dd4d5 da97285 54d2ac1 b21e4a2 05bcc9e b21e4a2 00568c1 b21e4a2 b2430ce b21e4a2 15d3a65 da97285 b21e4a2 b2430ce b21e4a2 c67fb71 b21e4a2 4c834bf 132eb74 b21e4a2 ba944e6 f243c21 2ea70eb b432889 b21e4a2 5ea3aa3 05bcc9e 5ea3aa3 b21e4a2 7523d1f b21e4a2 a546ca2 40a6362 b21e4a2 dde02fc b21e4a2 dde02fc b21e4a2 dde02fc b21e4a2 501958b 85dd4d5 b21e4a2 827ec3d b21e4a2 00568c1 b21e4a2 827ec3d b21e4a2 15d3a65 be75668 b21e4a2 e4d1585 b21e4a2 00568c1 b21e4a2 796a085 b21e4a2 501958b ea00dd0 31d2350 501958b b21e4a2 827ec3d ef24342 827ec3d ef24342 827ec3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import os
import signal
import sys
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
import torch
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import Dataset
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
# optimum is an optional dependency. When it is not installed, BetterTransformer
# is bound to None; the call sites in this module test its truthiness before
# calling BetterTransformer.reverse(...).
try:
    from optimum.bettertransformer import BetterTransformer
except ImportError:
    BetterTransformer = None
# Put the repository's ``src`` directory at the front of the import path.
# NOTE(review): this runs after the ``axolotl`` imports at the top of the
# file, so it cannot influence how those particular imports are resolved.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
src_dir = os.path.join(project_root, "src")
sys.path[:0] = [src_dir]

# Module logger (accelerate's logging wrapper).
LOG = get_logger("axolotl.train")
@dataclass
class TrainDatasetMeta:
    """
    Dataclass capturing the dataset-specific options for training.

    Attributes:
        train_dataset: the dataset to train on.
        eval_dataset: optional evaluation dataset; ``None`` when none is given.
        total_num_steps: optional total step count; ``train`` reads it
            together with the two datasets.
    """

    train_dataset: Dataset
    eval_dataset: Optional[Dataset] = None
    total_num_steps: Optional[int] = None
def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
"""
Load the tokenizer and model, train, and save the resulting model.

Visible steps, in order:

- Load the tokenizer.
- Optionally auto-resume from the highest-numbered ``checkpoint-*``
  directory in ``cfg.output_dir``.
- Load the model, plus a reference model for RL modes other than ``orpo``
  when one is needed.
- Optionally freeze layers via ``freeze_layers_except``.
- Build the trainer with ``setup_trainer`` and make sure
  ``cfg.output_dir`` exists.
- On local rank 0, register a handler that saves the model on ctrl+c.
- Extend the transformers model-card comment with a badge and the raw
  axolotl config.
- Run ``pretrain_hooks`` / ``post_train_hooks`` around training.
- Save the final model via the FSDP, DeepSpeed ZeRO-3 or rank-0 branch.
- Handle the model card / hub push.

:param cfg: the axolotl configuration (keyword-only).
:param cli_args: CLI arguments; ``cli_args.inference`` is forwarded to
    ``load_model``.
:param dataset_meta: the train/eval datasets and total step count.
:return: ``(model, tokenizer)``; the model is a ``PeftModel`` or a
    ``PreTrainedModel`` per the return annotation.

NOTE(review): this copy has lost all indentation and many lines: logging
calls, closing brackets, ``else:``/``try:`` lines and several save calls.
No call that actually starts training is visible. Restore from version
control rather than from these fragments.
"""
# load the tokenizer first
# NOTE(review): orphaned argument; the logging call that wrapped it is missing.
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
tokenizer = load_tokenizer(cfg)
# unpack the dataset bundle prepared by the caller
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Auto-resume only when no explicit checkpoint was configured: pick the
# ``checkpoint-<N>`` entry with the largest integer N.
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
if len(possible_checkpoints) > 0:
# Sort numerically (not lexically) so checkpoint-1000 comes after checkpoint-200.
# NOTE(review): the iterable argument to sorted() and closing brackets are
# missing. int() raises ValueError for any glob match whose text after the
# last "-" is not an integer (e.g. "checkpoint-final").
sorted_paths = sorted(
key=lambda path: int(path.split("-")[-1]),
cfg.resume_from_checkpoint = sorted_paths[-1]
# NOTE(review): orphaned log message; the logging call is missing.
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
resume_from_checkpoint = cfg.resume_from_checkpoint
# Load the model (and the PEFT config when an adapter is configured)
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
# we wait until the last possible moment to setup Accelerator
# NOTE(review): neither the logging of `msg` nor any Accelerator construction
# is visible here; those lines appear to be missing.
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model.generation_config.do_sample = True
# Reference model for RL trainers; never loaded for ``orpo``.
model_ref = None
if cfg.rl and cfg.rl != "orpo":
if cfg.adapter and not cfg.rl_adapter_ref_model:
# use built-in trl autounwrap
LOG.debug("Passing model_ref: None to RL trainer")
model_ref = None # explicit setting to None
# NOTE(review): the reload below is presumably the ``else`` branch of the check
# above; the ``else:`` line and the closing ``)`` are missing. Confirm.
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
# Only an explicit True enables safetensors; None or other truthy values do not.
safe_serialization = cfg.save_safetensors is True
if cfg.unfrozen_parameters:
freeze_layers_except(model, cfg.unfrozen_parameters)
# NOTE(review): most of setup_trainer's arguments and its closing ``)`` are missing.
trainer = setup_trainer(
(model, model_ref, peft_config),
# go ahead and presave, so we have the adapter config available to inspect
# NOTE(review): the next line fuses the ``if`` with the tail of a log call; the
# adapter-config save itself is not visible.
if peft_config:"Pre-saving adapter config to {cfg.output_dir}")
# additionally presave the tokenizer and model configs
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
# NOTE(review): the tokenizer/model-config save calls (including this ``if``'s
# body) are missing.
if hasattr(model, "config"):
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
# Only local rank 0 registers the handler. It holds a weak reference, so the
# handler by itself does not keep the model alive.
if cfg.local_rank == 0:
def terminate_handler(_, __, model_weakref):
# save only if the model object still exists
if model_weakref() is not None:
_model = model_weakref()
# revert via BetterTransformer.reverse before saving
if cfg.flash_optimum and BetterTransformer:
_model = BetterTransformer.reverse(_model)
# NOTE(review): orphaned save_pretrained arguments; the call itself is missing.
cfg.output_dir, safe_serialization=safe_serialization
_model_weakref = weakref.ref(model)
# NOTE(review): orphaned argument; the registering call (presumably
# ``signal.signal(...)``, given the ``signal`` import) is missing.
lambda signum, frame: terminate_handler(signum, frame, _model_weakref),
# Append a "Built with Axolotl" badge to the auto-generated model-card comment.
# NOTE(review): the image src and link target are empty in this copy. Also,
# ``+=`` mutates a transformers module global, so calling train() more than
# once per process appends the badge/config text again.
badge_markdown = """[<img src="" alt="Built with Axolotl" width="200" height="32"/>]("""
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
# Embed the axolotl version and the raw YAML config in the model card.
# NOTE(review): getattr() without a default relies on cfg (a DictDefault) not
# raising for a missing key. Confirm.
if getattr(cfg, "axolotl_config_path"):
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
version = get_distribution("axolotl").version
if raw_axolotl_cfg.is_file():
# NOTE(review): the end of this line is fused with a following log call.
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n""Starting trainer...")
# NOTE(review): fused ``if`` and log-call tail.
if cfg.group_by_length:"hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer)
# NOTE(review): the sdp_kernel keyword arguments, the body of this ``with`` and
# the non-flash_optimum branch are not visible.
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
# NOTE(review): hook call fused with the tail of a log call.
post_train_hooks(cfg, trainer)"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# post training
# let any module that defines a ``_post_training`` hook run it
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
# Under FSDP, switch the plugin to FULL_STATE_DICT before saving.
# NOTE(review): the next line is fused with the tail of a log call.
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")"Set FSDP state dict type to FULL_STATE_DICT for saving.")
# ReLoRA: merge LoRA weights when the model is not loaded in 4/8-bit.
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
# NOTE(review): an ``else:`` presumably precedes the early return below (the
# merged model falls through to the save logic). Confirm.
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return model, tokenizer
# TODO do we need this fix?
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
# Final save by backend: FSDP, DeepSpeed ZeRO-3, otherwise local rank 0 only.
# NOTE(review): the ``cfg.fsdp`` branch body is missing.
if cfg.fsdp:
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from:
# NOTE(review): the source URL above, and the save call that consumes
# ``unwrapped_model`` below, are missing.
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)
# NOTE(review): orphaned save_pretrained arguments inside this ``if``; the call
# and its receiver are missing.
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
cfg.output_dir, safe_serialization=safe_serialization
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
# Model card: presumably written locally (tolerating AttributeError) when no hub
# id is set, otherwise pushed to the hub.
# NOTE(review): the ``try:`` and the model-card / push calls are missing;
# ``elif cfg.hub_model_id`` is equivalent to a plain ``else`` here.
if not cfg.hub_model_id:
except AttributeError:
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated
return model, tokenizer
def pretrain_hooks(_cfg, _trainer):
    """
    Run hooks right before kicking off the training.

    Currently a no-op extension point; ``train`` calls it once the trainer
    has been built, before training.

    :param _cfg: the axolotl config (unused).
    :param _trainer: the trainer built by ``setup_trainer`` (unused).
    :return: None
    """
def post_train_hooks(_cfg, _trainer):
    """
    Run hooks right after training completes.

    Currently a no-op extension point; ``train`` calls it after training and
    before the final model is saved.

    :param _cfg: the axolotl config (unused).
    :param _trainer: the trainer built by ``setup_trainer`` (unused).
    :return: None
    """