diff --git "a/supp_trainer_after.py" "b/supp_trainer_after.py" --- "a/supp_trainer_after.py" +++ "b/supp_trainer_after.py" @@ -22,6 +22,7 @@ import functools import glob import importlib.metadata import inspect +import json import math import os import random @@ -33,7 +34,7 @@ import time import warnings from collections.abc import Mapping from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union # Integrations must be imported before ML frameworks: @@ -52,23 +53,35 @@ import torch.distributed as dist from huggingface_hub import ModelCard, create_repo, upload_folder from packaging import version from torch import nn -from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler from . import __version__ from .configuration_utils import PretrainedConfig from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .debug_utils import DebugOption, DebugUnderflowOverflow +from .feature_extraction_sequence_utils import SequenceFeatureExtractor +from .feature_extraction_utils import FeatureExtractionMixin from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend +from .image_processing_utils import BaseImageProcessor from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available +from .integrations.tpu import tpu_spmd_dataloader from .modelcard import TrainingSummary from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model -from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES +from .models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_MAPPING_NAMES, +) from .optimization import Adafactor, get_scheduler -from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 +from .processing_utils import ProcessorMixin +from .pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_2_3, +) from .tokenization_utils_base import PreTrainedTokenizerBase from .trainer_callback import ( CallbackHandler, DefaultFlowCallback, + ExportableState, PrinterCallback, ProgressCallback, TrainerCallback, @@ -77,14 +90,15 @@ from .trainer_callback import ( ) from .trainer_pt_utils import ( DistributedTensorGatherer, + EvalLoopContainer, IterableDatasetShard, LabelSmoother, + LayerWiseDummyOptimizer, LengthGroupedSampler, SequentialDistributedSampler, distributed_broadcast_scalars, distributed_concat, find_batch_size, - get_dataloader_sampler, get_model_param_count, get_module_class_from_name, get_parameter_names, @@ -102,11 +116,12 @@ from .trainer_utils import ( EvalPrediction, HPSearchBackend, HubStrategy, - IntervalStrategy, PredictionOutput, RemoveColumnsCollator, + SaveStrategy, TrainerMemoryTracker, TrainOutput, + check_target_module_exists, default_compute_objective, denumpify_detensorize, enable_full_determinism, @@ -129,26 +144,39 @@ from .utils import ( SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, + XLA_FSDPV2_MIN_VERSION, PushInProgress, + PushToHubMixin, can_return_loss, find_labels, is_accelerate_available, is_apex_available, is_bitsandbytes_available, is_datasets_available, + is_galore_torch_available, + is_grokadamw_available, is_in_notebook, is_ipex_available, + is_liger_kernel_available, + is_lomo_available, is_peft_available, is_safetensors_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, + is_schedulefree_available, is_torch_compile_available, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, is_torch_neuroncore_available, is_torch_npu_available, - is_torch_tpu_available, + is_torch_xla_available, + is_torch_xpu_available, + is_torchao_available, logging, strtobool, ) +from .utils.deprecation import deprecate_kwarg from .utils.quantization_config import QuantizationMethod @@ -166,9 +194,17 @@ if is_apex_available(): if is_datasets_available(): import datasets -if is_torch_tpu_available(check_device=False): +if is_torch_xla_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met + from torch_xla import __version__ as XLA_VERSION + + IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION) + if IS_XLA_FSDPV2_POST_2_2: + import torch_xla.distributed.spmd as xs + import torch_xla.runtime as xr +else: + IS_XLA_FSDPV2_POST_2_2 = False if is_sagemaker_mp_enabled(): @@ -185,7 +221,6 @@ else: if is_safetensors_available(): import safetensors.torch - if is_peft_available(): from peft import PeftModel @@ -193,9 +228,10 @@ if is_peft_available(): if is_accelerate_available(): from accelerate import Accelerator, skip_first_batches from accelerate import __version__ as accelerate_version + from accelerate.state import AcceleratorState from accelerate.utils import ( DistributedDataParallelKwargs, - GradientAccumulationPlugin, + DistributedType, load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, @@ -211,14 +247,54 @@ if is_accelerate_available(): if is_deepspeed_available(): from accelerate.utils import DeepSpeedSchedulerWrapper +if is_accelerate_available("0.28.0"): + from accelerate.utils import DataLoaderConfiguration + def _is_peft_model(model): - return is_peft_available() and isinstance(model, PeftModel) + if is_peft_available(): + classes_to_check = (PeftModel,) if is_peft_available() else () + # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321 + if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"): + from peft import PeftMixedModel + + classes_to_check = (*classes_to_check, PeftMixedModel) + return isinstance(model, classes_to_check) + return False + + +def _get_fsdp_ckpt_kwargs(): + # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release + if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters): + return {"adapter_only": True} + else: + return {} + + +def safe_globals(): + # Starting from version 2.4 PyTorch introduces a check for the objects loaded + # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes + # a default and requires allowlisting of objects being loaded. + # See: https://github.com/pytorch/pytorch/pull/137602 + # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals + # See: https://github.com/huggingface/accelerate/pull/3036 + if version.parse(torch.__version__).release < version.parse("2.6").release: + return contextlib.nullcontext() + + np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core + allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype] + # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for + # all versions of numpy + allowlist += [type(np.dtype(np.uint32))] + + return torch.serialization.safe_globals(allowlist) if TYPE_CHECKING: import optuna + if is_datasets_available(): + import datasets logger = logging.get_logger(__name__) @@ -254,9 +330,9 @@ class Trainer: `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided. data_collator (`DataCollator`, *optional*): The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will - default to [`default_data_collator`] if no `tokenizer` is provided, an instance of - [`DataCollatorWithPadding`] otherwise. - train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): + default to [`default_data_collator`] if no `processing_class` is provided, an instance of + [`DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or tokenizer. + train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*): The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. @@ -265,14 +341,15 @@ class Trainer: `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally sets the seed of the RNGs used. - eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): + eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`, `datasets.Dataset`]), *optional*): The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each dataset prepending the dictionary key to the metric name. - tokenizer ([`PreTrainedTokenizerBase`], *optional*): - The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the - maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an - interrupted training or reuse the fine-tuned model. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + This supercedes the `tokenizer` argument, which is now deprecated. model_init (`Callable[[], PreTrainedModel]`, *optional*): A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start from a new instance of the model as given by this function. @@ -280,9 +357,15 @@ class Trainer: The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to be able to choose different architectures according to hyper parameters (such as layer count, sizes of inner layers, dropout probabilities etc). + compute_loss_func (`Callable`, *optional*): + A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated + batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) used by [`Trainer`]. compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return - a dictionary string to metric values. + a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to + `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered + after the last eval batch to signal that the function needs to calculate and return the global summary + statistics rather than accumulating the batch-level statistics callbacks (List of [`TrainerCallback`], *optional*): A list of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in [here](callback). @@ -291,6 +374,11 @@ class Trainer: optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*): + A tuple containing the optimizer class and keyword arguments to use. + Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): A function that preprocess the logits right before caching them at each evaluation step. Must take two tensors, the logits and the labels, and return the logits once processed as desired. The modifications made @@ -319,27 +407,51 @@ class Trainer: # Those are used as methods of the Trainer in examples. from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state + @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True) def __init__( self, model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, - tokenizer: Optional[PreTrainedTokenizerBase] = None, + train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_loss_func: Optional[Callable] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, ): if args is None: output_dir = "tmp_trainer" logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") args = TrainingArguments(output_dir=output_dir) + if args.batch_eval_metrics and compute_metrics is not None: + if "compute_result" not in inspect.signature(compute_metrics).parameters.keys(): + raise ValueError( + "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`" + " boolean argument which will be triggered after the last batch of the eval set to signal that the" + " summary statistics should be returned by the function." + ) + if args.eval_strategy is not None and args.eval_strategy != "no" and eval_dataset is None: + raise ValueError( + f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. " + ) + if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end: + if args.metric_for_best_model is None: + raise ValueError( + "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`." + ) + self.args = args + self.compute_loss_func = compute_loss_func # Seed must be set before instantiating the model when using model enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.hp_name = None self.deepspeed = None self.is_in_train = False @@ -381,7 +493,7 @@ class Trainer: "https://huggingface.co/docs/transformers/model_doc/auto" ) - if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: + if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False): self.is_model_parallel = True else: self.is_model_parallel = False @@ -402,21 +514,52 @@ class Trainer: " to `True` to avoid any unexpected behavior such as device placement mismatching." ) + if self.args.use_liger_kernel: + if is_liger_kernel_available(): + from liger_kernel.transformers import _apply_liger_kernel_to_instance + + if isinstance(model, PreTrainedModel): + # Patch the model with liger kernels. Use the default kernel configurations. + _apply_liger_kernel_to_instance(model=model) + else: + logger.warning( + "The model is not an instance of PreTrainedModel. No liger kernels will be applied." + ) + else: + raise ImportError( + "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. " + "Please install it with `pip install liger-kernel`" + ) + _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr( model, "_hf_peft_config_loaded", False ) + _quantization_method_supports_training = ( + getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable + ) + + _is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr( + model.hf_quantizer, "is_qat_trainable", False + ) + + # Filter out quantized + compiled models + if _is_quantized_and_base_model and hasattr(model, "_orig_mod"): + raise ValueError( + "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT" + ) # At this stage the model is already loaded - if _is_quantized_and_base_model and not _is_peft_model(model): + if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_qat_trainable: raise ValueError( "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of" " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft" " for more details" ) - elif _is_quantized_and_base_model and not getattr(model, "_is_quantized_training_enabled", False): + elif _is_quantized_and_base_model and not _quantization_method_supports_training: raise ValueError( - "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit" - " model, please make sure that you have installed `bitsandbytes>=0.37.0`. " + f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}" + " but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers" + f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}" ) self.is_fsdp_xla_enabled = args.fsdp_config["xla"] @@ -445,11 +588,16 @@ class Trainer: ): self.place_model_on_device = False - default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) + default_collator = ( + DataCollatorWithPadding(processing_class) + if processing_class is not None + and isinstance(processing_class, (PreTrainedTokenizerBase, SequenceFeatureExtractor)) + else default_data_collator + ) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset self.eval_dataset = eval_dataset - self.tokenizer = tokenizer + self.processing_class = processing_class # Bnb Quantized models doesn't support `.to` operation. if ( @@ -466,17 +614,30 @@ class Trainer: self.model_wrapped = model self.model = model + # Just in case the model was wrapped outside of the `Trainer` + unwrapped_model = self.accelerator.unwrap_model(model) + model_forward = ( + unwrapped_model.forward + if not _is_peft_model(unwrapped_model) + else unwrapped_model.get_base_model().forward + ) + forward_params = inspect.signature(model_forward).parameters + self.model_accepts_loss_kwargs = any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()) + self.neftune_noise_alpha = args.neftune_noise_alpha self.compute_metrics = compute_metrics self.preprocess_logits_for_metrics = preprocess_logits_for_metrics self.optimizer, self.lr_scheduler = optimizers + self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs + if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None: + raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.") if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): raise RuntimeError( "Passing a `model_init` is incompatible with providing the `optimizers` argument. " "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." ) - if is_torch_tpu_available() and self.optimizer is not None: + if is_torch_xla_available() and self.optimizer is not None: for param in self.model.parameters(): model_device = param.device break @@ -491,17 +652,17 @@ class Trainer: " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." ) - if (self.is_deepspeed_enabled or self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and ( + if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and ( self.optimizer is not None or self.lr_scheduler is not None ): raise RuntimeError( - "Passing `optimizers` is not allowed if Deepspeed or PyTorch FSDP is enabled. " + "Passing `optimizers` is not allowed if PyTorch FSDP is enabled. " "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." ) default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks self.callback_handler = CallbackHandler( - callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler + callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler ) self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) @@ -518,7 +679,7 @@ class Trainer: if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") - if args.max_steps > 0: + if args.max_steps > 0 and args.num_train_epochs > 0: logger.info("max_steps is given, it will override any value given in num_train_epochs") if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: @@ -565,7 +726,8 @@ class Trainer: if (args.fp16 or args.bf16) and args.half_precision_backend == "auto": if args.device == torch.device("cpu"): if args.fp16: - raise ValueError("Tried to use `fp16` but it is not supported on cpu") + if not is_torch_greater_or_equal_than_2_3: + raise ValueError("Tried to use `fp16` but it is not supported on cpu") else: args.half_precision_backend = "cpu_amp" logger.info(f"Using {args.half_precision_backend} half precision backend") @@ -589,12 +751,15 @@ class Trainer: else: self.label_smoother = None + self.control = TrainerControl() + self.state = TrainerState( is_local_process_zero=self.is_local_process_zero(), is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], ) - - self.control = TrainerControl() # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then # returned to 0 every time flos need to be logged self.current_flos = 0 @@ -615,12 +780,34 @@ class Trainer: if args.torch_compile and not is_torch_compile_available(): raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.") + self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False) + if self.is_fsdp_xla_v2_enabled: + if not IS_XLA_FSDPV2_POST_2_2: + raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.") + # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper. + # Tensor axis is just a placeholder where it will not be used in FSDPv2. + num_devices = xr.global_runtime_device_count() + xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor"))) + self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled + + @property + def tokenizer(self) -> Optional[PreTrainedTokenizerBase]: + logger.warning("Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.") + return self.processing_class + + @tokenizer.setter + def tokenizer(self, processing_class) -> None: + logger.warning( + "Trainer.tokenizer is now deprecated. You should use `Trainer.processing_class = processing_class` instead." + ) + self.processing_class = processing_class + def _activate_neftune(self, model): r""" Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914 """ - unwrapped_model = unwrap_model(model) + unwrapped_model = self.accelerator.unwrap_model(model) if _is_peft_model(unwrapped_model): embeddings = unwrapped_model.base_model.model.get_input_embeddings() @@ -641,7 +828,7 @@ class Trainer: if not hasattr(self, "neftune_hook_handle"): raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first") - unwrapped_model = unwrap_model(model) + unwrapped_model = self.accelerator.unwrap_model(model) if _is_peft_model(unwrapped_model): embeddings = unwrapped_model.base_model.model.get_input_embeddings() @@ -656,7 +843,7 @@ class Trainer: Add a callback to the current list of [`~transformers.TrainerCallback`]. Args: - callback (`type` or [`~transformers.TrainerCallback`]): + callback (`type` or [`~transformers.TrainerCallback]`): A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the first case, will instantiate a member of that class. """ @@ -669,7 +856,7 @@ class Trainer: If the callback is not found, returns `None` (and no error is raised). Args: - callback (`type` or [`~transformers.TrainerCallback`]): + callback (`type` or [`~transformers.TrainerCallback]`): A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the first case, will pop the first member of that class found in the list of callbacks. @@ -683,7 +870,7 @@ class Trainer: Remove a callback from the current list of [`~transformers.TrainerCallback`]. Args: - callback (`type` or [`~transformers.TrainerCallback`]): + callback (`type` or [`~transformers.TrainerCallback]`): A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the first case, will remove the first member of that class found in the list of callbacks. """ @@ -700,7 +887,11 @@ class Trainer: # Inspect model forward signature to keep only the arguments it accepts. model_to_inspect = self.model if _is_peft_model(self.model): - model_to_inspect = self.model.get_base_model() + if hasattr(self.model, "get_base_model"): + model_to_inspect = self.model.get_base_model() + else: + # PeftMixedModel do not provide a `get_base_model` method + model_to_inspect = self.model.base_model.model signature = inspect.signature(model_to_inspect.forward) self._signature_columns = list(signature.parameters.keys()) # Labels may be named label or label_ids, the default data collator handles that. @@ -723,6 +914,12 @@ class Trainer: ) columns = [k for k in signature_columns if k in dataset.column_names] + if len(columns) == 0: + raise ValueError( + "No columns in the dataset match the model's forward method signature. " + f"The following columns have been ignored: [{', '.join(ignored_columns)}]. " + "Please check the dataset and model. You may need to set `remove_unused_columns=False` in `TrainingArguments`." + ) if version.parse(datasets.__version__) < version.parse("1.4.0"): dataset.set_format( @@ -764,7 +961,9 @@ class Trainer: ) else: lengths = None - model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None + model_input_name = ( + self.processing_class.model_input_names[0] if self.processing_class is not None else None + ) return LengthGroupedSampler( self.args.train_batch_size * self.args.gradient_accumulation_steps, dataset=self.train_dataset, @@ -806,13 +1005,18 @@ class Trainer: dataloader_params["sampler"] = self._get_train_sampler() dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: + if eval_dataset is None or not has_length(eval_dataset): + return None + # Build the sampler. + # Deprecated code if self.args.use_legacy_prediction_loop: - if is_torch_tpu_available(): + if is_torch_xla_available(): return SequentialDistributedSampler( eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() ) @@ -826,25 +1030,58 @@ class Trainer: else: return SequentialSampler(eval_dataset) + if self.args.group_by_length: + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + lengths = ( + eval_dataset[self.args.length_column_name] + if self.args.length_column_name in eval_dataset.column_names + else None + ) + else: + lengths = None + model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None + return LengthGroupedSampler( + self.args.eval_batch_size, + dataset=eval_dataset, + lengths=lengths, + model_input_name=model_input_name, + ) + if self.args.world_size <= 1: return SequentialSampler(eval_dataset) else: return None - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader: """ Returns the evaluation [`~torch.utils.data.DataLoader`]. Subclass and override this method if you want to inject some custom behavior. Args: - eval_dataset (`torch.utils.data.Dataset`, *optional*): - If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted - by the `model.forward()` method are automatically removed. It must implement `__len__`. + eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*): + If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. """ if eval_dataset is None and self.eval_dataset is None: raise ValueError("Trainer: evaluation requires an eval_dataset.") - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + # If we have persistent workers, don't do a fork bomb especially as eval datasets + # don't change during training + dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval" + if ( + hasattr(self, "_eval_dataloaders") + and dataloader_key in self._eval_dataloaders + and self.args.dataloader_persistent_workers + ): + return self.accelerator.prepare(self._eval_dataloaders[dataloader_key]) + + eval_dataset = ( + self.eval_dataset[eval_dataset] + if isinstance(eval_dataset, str) + else eval_dataset + if eval_dataset is not None + else self.eval_dataset + ) data_collator = self.data_collator if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): @@ -863,8 +1100,18 @@ class Trainer: if not isinstance(eval_dataset, torch.utils.data.IterableDataset): dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + # accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version + eval_dataloader = DataLoader(eval_dataset, **dataloader_params) + if self.args.dataloader_persistent_workers: + if hasattr(self, "_eval_dataloaders"): + self._eval_dataloaders[dataloader_key] = eval_dataloader + else: + self._eval_dataloaders = {dataloader_key: eval_dataloader} - return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) + return self.accelerator.prepare(eval_dataloader) def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: """ @@ -895,6 +1142,7 @@ class Trainer: if not isinstance(test_dataset, torch.utils.data.IterableDataset): dataloader_params["sampler"] = self._get_eval_sampler(test_dataset) dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor # We use the same batch_size as for eval. return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params)) @@ -952,9 +1200,28 @@ class Trainer: }, ] - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + if self.optimizer_cls_and_kwargs is not None: + optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs + else: + optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model) + + # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for GaLore optimizer. + if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("params") + + # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs` + # e.g. for LOMO optimizer. + if "model" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("model") + + # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict` + # to avoid arguments conflicts. + if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict") self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": import bitsandbytes @@ -974,8 +1241,40 @@ class Trainer: return self.optimizer + def get_num_trainable_parameters(self): + """ + Get the number of trainable parameters. + """ + return sum(p.numel() for p in self.model.parameters() if p.requires_grad) + + def get_learning_rates(self): + """ + Returns the learning rate of each parameter from self.optimizer. + """ + if self.optimizer is None: + raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.") + return [group["lr"] for group in self.optimizer.param_groups] + + def get_optimizer_group(self, param: Optional[Union[str, torch.nn.parameter.Parameter]] = None): + """ + Returns optimizer group for a parameter if given, else returns all optimizer groups for params. + + Args: + param (`str` or `torch.nn.parameter.Parameter`, *optional*): + The parameter for which optimizer group needs to be returned. + """ + if self.optimizer is None: + raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.") + if param is not None: + for group in self.optimizer.param_groups: + if param in group["params"]: + return group + return [group["params"] for group in self.optimizer.param_groups] + @staticmethod - def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: + def get_optimizer_cls_and_kwargs( + args: TrainingArguments, model: Optional[PreTrainedModel] = None + ) -> Tuple[Any, Any]: """ Returns the optimizer class and optimizer parameters based on the training arguments. @@ -1042,13 +1341,20 @@ class Trainer: OptimizerNames.ADAMW_8BIT, OptimizerNames.PAGED_ADAMW, OptimizerNames.PAGED_ADAMW_8BIT, + OptimizerNames.ADEMAMIX, + OptimizerNames.ADEMAMIX_8BIT, + OptimizerNames.PAGED_ADEMAMIX, + OptimizerNames.PAGED_ADEMAMIX_8BIT, OptimizerNames.LION, OptimizerNames.LION_8BIT, OptimizerNames.PAGED_LION, OptimizerNames.PAGED_LION_8BIT, + OptimizerNames.RMSPROP_BNB, + OptimizerNames.RMSPROP_8BIT, + OptimizerNames.RMSPROP_32BIT, ]: try: - from bitsandbytes.optim import AdamW, Lion + from bitsandbytes.optim import AdamW, Lion, RMSprop is_paged = False optim_bits = 32 @@ -1063,12 +1369,47 @@ class Trainer: elif "lion" in args.optim: optimizer_cls = Lion additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)} + elif "rmsprop" in args.optim: + optimizer_cls = RMSprop + # Above we pass all `adam_kwargs` to the optimizer, here + # we only pass `optim_args` which can be passed by the user. + additional_optim_kwargs = optim_args + elif "ademamix" in args.optim: + if is_bitsandbytes_available() and version.parse( + importlib.metadata.version("bitsandbytes") + ) < version.parse("0.44.0"): + raise ValueError( + "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. " + "Please install `bitsandbytes` >= 0.44.0." + ) + + from bitsandbytes.optim import AdEMAMix + + optimizer_cls = AdEMAMix + additional_optim_kwargs = { + "betas": ( + float(optim_args.get("beta1", args.adam_beta1)), + float(optim_args.get("beta2", args.adam_beta2)), + float(optim_args.get("beta3", 0.9999)), + ), + "alpha": float(optim_args.get("alpha", 5.0)), + "eps": float(optim_args.get("eps", args.adam_epsilon)), + } + + if "t_alpha" in optim_args: + additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"]) + + if "t_beta3" in optim_args: + additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"]) + + bnb_kwargs = {"optim_bits": optim_bits} + if "rmsprop" not in args.optim: + bnb_kwargs["is_paged"] = is_paged - bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits} optimizer_kwargs.update(additional_optim_kwargs) optimizer_kwargs.update(bnb_kwargs) except ImportError: - raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!") + raise ValueError("Trainer tried to instantiate bnb optimizer but `bitsandbytes` is not installed!") if is_bitsandbytes_available() and version.parse( importlib.metadata.version("bitsandbytes") ) < version.parse("0.41.1"): @@ -1102,6 +1443,216 @@ class Trainer: optimizer_cls = torch.optim.Adagrad elif args.optim == OptimizerNames.RMSPROP: optimizer_cls = torch.optim.RMSprop + elif args.optim in [ + OptimizerNames.GALORE_ADAMW, + OptimizerNames.GALORE_ADAMW_8BIT, + OptimizerNames.GALORE_ADAFACTOR, + OptimizerNames.GALORE_ADAMW_LAYERWISE, + OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE, + OptimizerNames.GALORE_ADAFACTOR_LAYERWISE, + ]: + if not is_galore_torch_available(): + raise ImportError( + "You need to install `galore_torch` in order to use GaLore optimizers" + " install it with `pip install git+https://github.com/jiaweizzhao/GaLore`" + ) + from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit + + is_layerwise = args.optim.lower().endswith("layerwise") + if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED: + raise NotImplementedError("Layer-wise GaLore does not support DDP at this time") + + optimizer_mapping = { + OptimizerNames.GALORE_ADAMW: GaLoreAdamW, + OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit, + OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor, + OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW, + OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit, + OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor, + } + + optimizer_cls = optimizer_mapping[args.optim] + + if args.optim_target_modules is None: + raise ValueError( + "You need to define a `optim_target_modules` in order to properly use GaLore optimizers" + ) + + if not isinstance(args.optim_target_modules, (list, str)): + raise ValueError( + f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}" + ) + + if model is None: + raise ValueError("You need to pass a model in order to correctly initialize a GaLore optimizer.") + + logger.warning( + "Activated GaLoRE fine-tuning, depending on your model size and hardware, the training might take a while before starting. Please be patient !" + ) + + all_linear = ( + isinstance(args.optim_target_modules, str) + and args.optim_target_modules.replace("_", "-") == "all-linear" + ) + + galore_params = [] + galore_params_names = [] + for module_name, module in model.named_modules(): + target_module_exists, is_regex = check_target_module_exists( + args.optim_target_modules, module_name, return_is_regex=True + ) + + if not isinstance(module, nn.Linear): + # Warn in case we match but it's not a linear layer + if target_module_exists and not is_regex: + logger.warning( + f"{module_name} has been matched but ignored as GaLore only supports linear layers. Please double check your `optim_target_modules`!" + ) + + continue + + if not target_module_exists and not all_linear: + continue + + galore_params.append(module.weight) + galore_params_names.append(module_name + ".weight") + + if len(galore_params) == 0: + raise ValueError( + f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `target_modules`." + ) + + non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names] + + galore_optim_kwargs = { + "rank": int(optim_args.pop("rank", 128)), + "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)), + "scale": float(optim_args.pop("scale", 0.25)), + "proj_type": optim_args.pop("proj_type", "std"), + } + + # The default args are from the official repository: https://github.com/jiaweizzhao/GaLore + param_groups = [ + {"params": non_galore_params}, + {"params": galore_params, **galore_optim_kwargs}, + ] + + if is_layerwise: + # For layer-wise optimizers, the optimization step is done through post accumulation + # gradient hooks. The trick is to first attach these hooks to the model parameters then + # create a dummy optimizer that will perform no-ops in the Trainer. + # See the original implementation or the nice implementation from @hiyouga + # here: https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba + if args.gradient_accumulation_steps != 1: + raise ValueError("Layerwise GaLoRE optimizer do not support gradient accumulation !") + + optimizer_dict = {} + for param in non_galore_params: + param_groups = [{"params": [param]}] + optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs) + for param in galore_params: + param_groups = [{"params": [param], **galore_optim_kwargs}] + optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs) + + def optimizer_hook(param): + if param.grad is not None: + optimizer_dict[param].step() + optimizer_dict[param].zero_grad() + + for param in model.parameters(): + if param.requires_grad: + param.register_post_accumulate_grad_hook(optimizer_hook) + + optimizer_cls = LayerWiseDummyOptimizer + optimizer_kwargs.update({"optimizer_dict": optimizer_dict}) + + optimizer_kwargs.update({"params": param_groups}) + + if args.optim == OptimizerNames.GALORE_ADAFACTOR: + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) + elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + if not is_lomo_available(): + raise ImportError( + "You need to install `lomo_optim` in order to use LOMO optimizers" + " install it with `pip install lomo-optim`" + ) + if not is_accelerate_available("0.30.0"): + raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers") + + if model is None: + raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.") + + from lomo_optim import AdaLomo, Lomo + + if "ada" in args.optim: + optimizer_cls = AdaLomo + else: + optimizer_cls = Lomo + + optimizer_kwargs.update({"model": model}) + elif args.optim == OptimizerNames.GROKADAMW: + if not is_grokadamw_available(): + raise ValueError("Please install grokadamw with `pip install grokadamw`") + + from grokadamw import GrokAdamW + + optimizer_cls = GrokAdamW + optimizer_kwargs.update( + { + "alpha_init": float(optim_args.get("alpha_init", 0.98)), + "lamb": float(optim_args.get("lamb", 2.0)), + "gamma": float(optim_args.get("gamma", 0.1)), + "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)), + "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)), + } + ) + elif args.optim == OptimizerNames.ADAMW_TORCH_4BIT: + if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse( + "0.4.0" + ): + raise ImportError( + "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers." + "Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao" + ) + if version.parse(importlib.metadata.version("torch")) <= version.parse("2.4"): + raise ImportError( + "You need to have `torch>2.4` in order to use torch 4-bit optimizers. " + "Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly." + ) + from torchao.prototype.low_bit_optim import AdamW4bit + + optimizer_cls = AdamW4bit + optimizer_kwargs.update(adam_kwargs) + elif args.optim in [ + OptimizerNames.SCHEDULE_FREE_ADAMW, + OptimizerNames.SCHEDULE_FREE_SGD, + ]: + if not is_schedulefree_available(): + raise ImportError( + "You need to install `schedulefree` in order to use schedulefree optimizers" + " install it with `pip install schedulefree`" + ) + if not is_accelerate_available("0.30.0"): + raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers") + from schedulefree import AdamWScheduleFree, SGDScheduleFree + + additional_optim_kwargs = {} + if args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW: + optimizer_cls = AdamWScheduleFree + additional_optim_kwargs = adam_kwargs + elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD: + optimizer_cls = SGDScheduleFree + else: + raise ValueError("Invalid schedulefree optimizer") + additional_optim_kwargs["weight_decay"] = args.weight_decay + additional_optim_kwargs["warmup_steps"] = args.warmup_steps + additional_optim_kwargs.update( + { + "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)), + "r": float(optim_args.get("r", 0.0)), + } + ) + optimizer_kwargs.update(additional_optim_kwargs) else: raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs @@ -1139,21 +1690,21 @@ class Trainer: except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader return len(dataloader) * self.args.per_device_train_batch_size - def num_tokens(self, train_dl: DataLoader, max_steps: Optional[int] = None) -> int: + @staticmethod + def num_tokens(train_dl: DataLoader, max_steps: Optional[int] = None) -> int: """ Helper to get number of tokens in a [`~torch.utils.data.DataLoader`] by enumerating dataloader. """ train_tokens = 0 try: - for step, batch in enumerate(train_dl): + for batch in train_dl: tokens = batch["input_ids"].numel() if max_steps is not None: return tokens * max_steps train_tokens += tokens - return train_tokens except KeyError: logger.warning("Cannot get num_tokens from dataloader") - return train_tokens + return train_tokens def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): """HP search setup code""" @@ -1193,6 +1744,9 @@ class Trainer: if self.is_deepspeed_enabled: if self.args.deepspeed is None: raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set") + + self.accelerator.free_memory() + # Rebuild the deepspeed config to reflect the updated training parameters from accelerate.utils import DeepSpeedPlugin @@ -1202,6 +1756,10 @@ class Trainer: self.args.hf_deepspeed_config.trainer_config_process(self.args) self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config) + # From 1.0 on, we need to fully wipe the DS plugin when doing sweeps. + # Simply calling `_reset_state` is enough and doesn't need a version pin. + AcceleratorState()._reset_state() + self.create_accelerator_and_postprocess() def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]): @@ -1212,7 +1770,7 @@ class Trainer: if self.hp_search_backend == HPSearchBackend.OPTUNA: import optuna - if not trial.study._is_multi_objective(): + if hasattr(trial, "study") and not trial.study._is_multi_objective(): trial.report(self.objective, step) if trial.should_prune(): self.callback_handler.on_train_end(self.args, self.state, self.control) @@ -1232,6 +1790,8 @@ class Trainer: output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") self.save_model(output_dir, _internal_call=True) if self.args.should_save: + # Update the `TrainerControl` state to where we are currently + self.state.stateful_callbacks["TrainerControl"] = self.control.state() self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) @@ -1315,6 +1875,34 @@ class Trainer: return model + def compare_trainer_and_checkpoint_args(self, training_args, trainer_state): + attributes_map = { + "logging_steps": "logging_steps", + "eval_steps": "eval_steps", + "save_steps": "save_steps", + } + + has_warning = False + warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: " + for arg_attr, state_attr in attributes_map.items(): + arg_value = getattr(training_args, arg_attr, None) + state_value = getattr(trainer_state, state_attr, None) + + if arg_value is not None and state_value is not None and arg_value != state_value: + warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)" + has_warning = True + + # train bs is special as we need to account for multi-GPU + train_bs_args = training_args.per_device_train_batch_size + train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu) + + if train_bs_args != train_bs_state: + warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)" + has_warning = True + + if has_warning: + logger.warning_once(warning_str) + def _wrap_model(self, model, training=True, dataloader=None): if self.args.use_ipex: dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 @@ -1327,7 +1915,7 @@ class Trainer: return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again - if unwrap_model(model) is not model: + if self.accelerator.unwrap_model(model) is not model: return model # Mixed precision training with apex (torch < 1.6) @@ -1358,6 +1946,11 @@ class Trainer: size_based_auto_wrap_policy, transformer_auto_wrap_policy, ) + + if self.is_fsdp_xla_v2_enabled: + from torch_xla.experimental.spmd_fully_sharded_data_parallel import ( + SpmdFullyShardedDataParallel as FSDPv2, + ) except ImportError: raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") auto_wrap_policy = None @@ -1387,17 +1980,48 @@ class Trainer: ) fsdp_kwargs = self.args.xla_fsdp_config if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: + if model.config.use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + model.config.use_cache = False + # Apply gradient checkpointing to auto-wrapped sub-modules if specified def auto_wrapper_callable(m, *args, **kwargs): - return FSDP(checkpoint_module(m), *args, **kwargs) + target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2 + return target_cls(checkpoint_module(m), *args, **kwargs) # Wrap the base model with an outer FSDP wrapper - self.model = model = FSDP( - model, - auto_wrap_policy=auto_wrap_policy, - auto_wrapper_callable=auto_wrapper_callable, - **fsdp_kwargs, - ) + if self.is_fsdp_xla_v2_enabled: + + def shard_output(output, mesh): + from .modeling_outputs import CausalLMOutputWithPast + + real_output = None + if isinstance(output, torch.Tensor): + real_output = output + elif isinstance(output, tuple): + real_output = output[0] + elif isinstance(output, CausalLMOutputWithPast): + real_output = output.logits + + if real_output is None: + raise ValueError("Something went wrong, the output of the model shouldn't be `None`") + xs.mark_sharding(real_output, mesh, ("fsdp", None, None)) + + self.model = model = FSDPv2( + model, + shard_output=shard_output, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + ) + else: + self.model = model = FSDP( + model, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + **fsdp_kwargs, + ) # Patch `xm.optimizer_step` should not reduce gradients in this case, # as FSDP does not need gradient reduction over sharded parameters. @@ -1474,7 +2098,7 @@ class Trainer: # do_train is not a reliable argument, as it might not be set and .train() still called, so # the following is a workaround: - if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: + if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train and not self.is_model_parallel: self._move_model_to_device(self.model, args.device) if "model_path" in kwargs: @@ -1485,7 +2109,7 @@ class Trainer: FutureWarning, ) if len(kwargs) > 0: - raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") + raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") # This might change the seed so needs to run first. self._hp_search_setup(trial) self._train_batch_size = self.args.train_batch_size @@ -1566,6 +2190,8 @@ class Trainer: logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") # Data loader and number of training steps train_dataloader = self.get_train_dataloader() + if self.is_fsdp_xla_v2_enabled: + train_dataloader = tpu_spmd_dataloader(train_dataloader) # Setting up training control variables: # number of training epochs: num_train_epochs @@ -1637,7 +2263,11 @@ class Trainer: if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) - self.state = TrainerState() + self.state = TrainerState( + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ] + ) self.state.is_hyper_param_search = trial is not None self.state.train_batch_size = self._train_batch_size @@ -1660,12 +2290,7 @@ class Trainer: # Activate gradient checkpointing if needed if args.gradient_checkpointing: - if args.gradient_checkpointing_kwargs is None: - gradient_checkpointing_kwargs = {} - else: - gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs - - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) model = self._wrap_model(self.model_wrapped) @@ -1674,7 +2299,17 @@ class Trainer: # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX use_accelerator_prepare = True if model is self.model else False + if use_accelerator_prepare and self.is_fsdp_enabled: + # In case of auto_find_batch_size=True + # Remove FSDP wrapping from sub-models. + self.model = unwrap_model(self.model, recursive=True) + if delay_optimizer_creation: + if use_accelerator_prepare: + # configure fsdp plugin for qlora if any + self._fsdp_qlora_plugin_updates() + if self.accelerator.mixed_precision != "fp8": + self.model = self.accelerator.prepare(self.model) self.create_optimizer_and_scheduler(num_training_steps=max_steps) # prepare using `accelerator` prepare @@ -1690,6 +2325,9 @@ class Trainer: model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( self.model, self.optimizer, self.lr_scheduler ) + elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + # In this case we are in DDP + LOMO, which should be supported + self.optimizer = self.accelerator.prepare(self.optimizer) if self.is_fsdp_enabled: self.model = self.model_wrapped = model @@ -1705,7 +2343,9 @@ class Trainer: # ckpt loading if resume_from_checkpoint is not None: if self.is_deepspeed_enabled: - deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) + deepspeed_load_checkpoint( + self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model) + ) elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) @@ -1740,7 +2380,9 @@ class Trainer: os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) ): self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) - epochs_trained = self.state.global_step // num_update_steps_per_epoch + self.compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() + epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) if not args.ignore_data_skip: steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) steps_trained_in_current_epoch *= args.gradient_accumulation_steps @@ -1783,39 +2425,23 @@ class Trainer: self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step model.zero_grad() - + grad_norm: Optional[float] = None self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. - if not args.ignore_data_skip: - for epoch in range(epochs_trained): - sampler = get_dataloader_sampler(train_dataloader) - sampler_kinds = [RandomSampler] - if version.parse(accelerate_version) > version.parse("0.23.0"): - sampler_kinds.append(SeedableRandomSampler) - is_random_sampler = isinstance(sampler, tuple(sampler_kinds)) - if not is_random_sampler: - # We just need to begin an iteration to create the randomization of the sampler. - for _ in train_dataloader: - break - else: - # Otherwise we need to call the whooooole sampler cause there is some random operation added - # AT THE VERY END! - sampler = sampler if sampler is not None else [] - _ = list(sampler) + if args.eval_on_start: + self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) - total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): - epoch_iterator = train_dataloader - if hasattr(epoch_iterator, "set_epoch"): - epoch_iterator.set_epoch(epoch) + epoch_dataloader = train_dataloader + if hasattr(epoch_dataloader, "set_epoch"): + epoch_dataloader.set_epoch(epoch) # Reset the past mems state at the beginning of each epoch if necessary. if args.past_index >= 0: self._past = None steps_in_epoch = ( - len(epoch_iterator) + len(epoch_dataloader) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps ) @@ -1827,124 +2453,171 @@ class Trainer: rng_to_sync = False steps_skipped = 0 if steps_trained_in_current_epoch > 0: - epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch) steps_skipped = steps_trained_in_current_epoch steps_trained_in_current_epoch = 0 rng_to_sync = True step = -1 - for step, inputs in enumerate(epoch_iterator): - total_batched_samples += 1 - - if self.args.include_num_input_tokens_seen: - main_input_name = getattr(self.model, "main_input_name", "input_ids") - if main_input_name not in inputs: - logger.warning( - "Tried to track the number of tokens seen, however the current model is " - "not configured properly to know what item is the input. To fix this, add " - "a `main_input_name` attribute to the model class you are using." - ) + epoch_iterator = iter(epoch_dataloader) + # We chunkify the epoch iterator into gradient accumulation steps `n` batches + remainder = num_examples % args.gradient_accumulation_steps + if remainder == 0: + remainder = args.gradient_accumulation_steps + update_step = -1 + total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 + for _ in range(total_updates): + update_step += 1 + num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder + batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches) + for i, inputs in enumerate(batch_samples): + step += 1 + do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch + # Since we perform prefetching, we need to manually set sync_gradients + if not do_sync_step: + self.accelerator.gradient_state._set_sync_gradients(False) else: - self.state.num_input_tokens_seen += self.accelerator.gather(inputs[main_input_name]).numel() - if rng_to_sync: - self._load_rng_state(resume_from_checkpoint) - rng_to_sync = False - - # Skip past any already trained steps if resuming training - if steps_trained_in_current_epoch > 0: - steps_trained_in_current_epoch -= 1 - if steps_trained_progress_bar is not None: - steps_trained_progress_bar.update(1) - if steps_trained_in_current_epoch == 0: - self._load_rng_state(resume_from_checkpoint) - continue - elif steps_trained_progress_bar is not None: - steps_trained_progress_bar.close() - steps_trained_progress_bar = None - - if step % args.gradient_accumulation_steps == 0: - self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - - with self.accelerator.accumulate(model): - tr_loss_step = self.training_step(model, inputs) - - if ( - args.logging_nan_inf_filter - and not is_torch_tpu_available() - and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) - ): - # if loss is nan or inf simply add the average of previous logged losses - tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) - else: - tr_loss += tr_loss_step - - self.current_flos += float(self.floating_point_ops(inputs)) - - is_last_step_and_steps_less_than_grad_acc = ( - steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch - ) - - if ( - total_batched_samples % args.gradient_accumulation_steps == 0 - or - # last step in epoch but step is always smaller than gradient_accumulation_steps - is_last_step_and_steps_less_than_grad_acc - ): - # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered - # in accelerate. So, explicitly enable sync gradients to True in that case. - if is_last_step_and_steps_less_than_grad_acc: self.accelerator.gradient_state._set_sync_gradients(True) - # Gradient clipping - if args.max_grad_norm is not None and args.max_grad_norm > 0: - # deepspeed does its own clipping - - if is_sagemaker_mp_enabled() and args.fp16: - self.optimizer.clip_master_grads(args.max_grad_norm) - elif self.use_apex: - # Revert to normal clipping otherwise, handling Apex or full precision - nn.utils.clip_grad_norm_( - amp.master_params(self.optimizer), - args.max_grad_norm, + if self.args.include_num_input_tokens_seen: + main_input_name = getattr(self.model, "main_input_name", "input_ids") + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." ) else: - self.accelerator.clip_grad_norm_( - model.parameters(), - args.max_grad_norm, + input_tokens = inputs[main_input_name].numel() + input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64) + self.state.num_input_tokens_seen += ( + self.accelerator.gather(input_tokens).sum().cpu().item() + ) + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + # We explicitly want to avoid relying on `accelerator.accumulate` for generation training + context = ( + functools.partial(self.accelerator.no_sync, model=model) + if i != len(batch_samples) - 1 + and self.accelerator.distributed_type != DistributedType.DEEPSPEED + else contextlib.nullcontext + ) + with context(): + tr_loss_step = self.training_step(model, inputs, num_items_in_batch) + + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + if tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" ) + tr_loss = tr_loss + tr_loss_step - # Optimizer step - self.optimizer.step() - optimizer_was_run = not self.accelerator.optimizer_step_was_skipped - if optimizer_was_run: - # Delay optimizer scheduling until metrics are generated - if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - self.lr_scheduler.step() + self.current_flos += float(self.floating_point_ops(inputs)) - model.zero_grad() - self.state.global_step += 1 - self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch - self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if do_sync_step: + # Since we perform prefetching, we need to manually set sync_gradients to True + self.accelerator.gradient_state._set_sync_gradients(True) - self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) - else: - self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + # deepspeed does its own clipping + + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) + elif self.use_apex: + # Revert to normal clipping otherwise, handling Apex or full precision + _grad_norm = nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + if ( + is_accelerate_available() + and self.accelerator.distributed_type == DistributedType.DEEPSPEED + ): + grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + else: + grad_norm = _grad_norm + + self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control) + + self.optimizer.step() + + self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control) + + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + self._maybe_log_save_evaluate( + tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time + ) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + + # PyTorch/XLA relies on the data loader to insert the mark_step for + # each step. Since we are breaking the loop early, we need to manually + # insert the mark_step here. + if self.control.should_epoch_stop or self.control.should_training_stop: + if is_torch_xla_available(): + xm.mark_step() + break + # We also need to break out of the nested loop if self.control.should_epoch_stop or self.control.should_training_stop: + if is_torch_xla_available(): + xm.mark_step() break if step < 0: logger.warning( - "There seems to be not a single sample in your epoch_iterator, stopping training at step" + "There seems not to be a single sample in your epoch_iterator, stopping training at step" f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" f" num_steps ({max_steps}) higher than the number of available samples." ) self.control.should_training_stop = True self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) - self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: - if is_torch_tpu_available(): + if is_torch_xla_available(): # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) xm.master_print(met.metrics_report()) else: @@ -1962,7 +2635,7 @@ class Trainer: logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: # Wait for everyone to get here so we are sure the model has been saved by process 0. - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.rendezvous("load_best_model_at_end") elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() @@ -1973,7 +2646,8 @@ class Trainer: # add remaining tr_loss self._total_loss_scalar += tr_loss.item() - train_loss = self._total_loss_scalar / self.state.global_step + effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError + train_loss = self._total_loss_scalar / effective_global_step metrics = speed_metrics( "train", @@ -2000,7 +2674,7 @@ class Trainer: for checkpoint in checkpoints_sorted: if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") - shutil.rmtree(checkpoint) + shutil.rmtree(checkpoint, ignore_errors=True) self.control = self.callback_handler.on_train_end(args, self.state, self.control) @@ -2055,6 +2729,20 @@ class Trainer: # this checks the FSDP state dict when `FULL_STATE_DICT` is used or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin")) ) + # if multiple adapters exist, they get saved in sub directories + adapter_subdirs = ( + [ + folder_name + for folder_name in os.listdir(resume_from_checkpoint) + if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name)) + and ( + os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME)) + or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME)) + ) + ] + if os.path.isdir(resume_from_checkpoint) + else [] + ) if is_fsdp_ckpt and not self.is_fsdp_enabled: raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP") @@ -2072,6 +2760,7 @@ class Trainer: ] ) or is_fsdp_ckpt + or adapter_subdirs ): raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") @@ -2088,7 +2777,7 @@ class Trainer: ) if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt: - weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": True} # If the model is on the GPU, it still works! if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): @@ -2115,7 +2804,13 @@ class Trainer: # release memory del state_dict elif self.is_fsdp_enabled: - load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint) + load_fsdp_model( + self.accelerator.state.fsdp_plugin, + self.accelerator, + model, + resume_from_checkpoint, + **_get_fsdp_ckpt_kwargs(), + ) else: # We load the model state dict on the CPU to avoid an OOM error. if self.args.save_safetensors and os.path.isfile(safe_weights_file): @@ -2137,9 +2832,27 @@ class Trainer: # Load adapters following PR # 24096 elif _is_peft_model(model): # If train a model using PEFT & LoRA, assume that adapter have been saved properly. - if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): + # TODO: in the future support only specific min PEFT versions + if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr( + model, "load_adapter" + ): if os.path.exists(resume_from_checkpoint): - model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True) + # For BC for older PEFT versions + if hasattr(model, "active_adapters"): + active_adapters = model.active_adapters + if len(active_adapters) > 1: + logger.warning("Multiple active adapters detected will only consider the first adapter") + active_adapter = active_adapters[0] + else: + active_adapter = model.active_adapter + + if adapter_subdirs: + for subdir_name in adapter_subdirs: + peft_id = os.path.join(resume_from_checkpoint, subdir_name) + model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter)) + model.set_adapter(active_adapter) + else: + model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True) else: logger.warning( "The intermediate checkpoints of PEFT may not be saved correctly, " @@ -2165,10 +2878,18 @@ class Trainer: model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model if self.is_deepspeed_enabled: - deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint) + deepspeed_load_checkpoint( + self.model_wrapped, + self.state.best_model_checkpoint, + load_module_strict=not _is_peft_model(self.model), + ) elif self.is_fsdp_enabled: load_result = load_fsdp_model( - self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint + self.accelerator.state.fsdp_plugin, + self.accelerator, + model, + self.state.best_model_checkpoint, + **_get_fsdp_ckpt_kwargs(), ) elif ( os.path.exists(best_model_path) @@ -2177,7 +2898,7 @@ class Trainer: or os.path.exists(best_safe_adapter_model_path) ): has_been_loaded = True - weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {} + weights_only_kwarg = {"weights_only": True} if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): # If the 'user_content.pt' file exists, load with the new smp api. @@ -2205,9 +2926,35 @@ class Trainer: else: if _is_peft_model(model): # If train a model using PEFT & LoRA, assume that adapter have been saved properly. - if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): + # TODO: in the future support only specific min PEFT versions + if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr( + model, "load_adapter" + ): + # For BC for older PEFT versions + if hasattr(model, "active_adapters"): + active_adapter = model.active_adapters[0] + if len(model.active_adapters) > 1: + logger.warning("Detected multiple active adapters, will only consider the first one") + else: + active_adapter = model.active_adapter + if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): - model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) + try: + model.load_adapter(self.state.best_model_checkpoint, active_adapter) + except RuntimeError as exc: + if model.peft_config[active_adapter].is_prompt_learning: + # for context: https://github.com/huggingface/peft/issues/2256 + msg = ( + "When using prompt learning PEFT methods such as " + f"{model.peft_config[active_adapter].peft_type.value}, setting " + "load_best_model_at_end=True can lead to errors, it is recommended " + "to set this to False and to load the model manually from the checkpoint " + "directory using PeftModel.from_pretrained(base_model, ) after training " + "has finished." + ) + raise RuntimeError(msg) from exc + else: + raise # Load_adapter has no return value present, modify it when appropriate. from torch.nn.modules.module import _IncompatibleKeys @@ -2239,7 +2986,9 @@ class Trainer: load_result = model.load_state_dict(state_dict, False) if not is_sagemaker_mp_enabled() and has_been_loaded: self._issue_warnings_after_load(load_result) - elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists( + os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME) + ): load_result = load_sharded_checkpoint( model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled() ) @@ -2264,9 +3013,30 @@ class Trainer: f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." ) - def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): + def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False): + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + # Run delayed LR scheduler now that metrics are populated + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and not skip_scheduler: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + try: + self.lr_scheduler.step(metrics[metric_to_check]) + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is set to '{metric_to_check}', " + f"which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}. " + f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or " + f"consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc + return metrics + + def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time): if self.control.should_log and self.state.global_step > self._globalstep_last_logged: - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.mark_step() logs: Dict[str, float] = {} @@ -2278,28 +3048,26 @@ class Trainer: tr_loss -= tr_loss logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + if grad_norm is not None: + logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm logs["learning_rate"] = self._get_learning_rate() self._total_loss_scalar += tr_loss_scalar self._globalstep_last_logged = self.state.global_step self.store_flos() - self.log(logs) + self.log(logs, start_time) metrics = None if self.control.should_evaluate: - metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) - self._report_to_hp_search(trial, self.state.global_step, metrics) - - # Run delayed LR scheduler now that metrics are populated - if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - metric_to_check = self.args.metric_for_best_model - if not metric_to_check.startswith("eval_"): - metric_to_check = f"eval_{metric_to_check}" - self.lr_scheduler.step(metrics[metric_to_check]) + metrics = self._evaluate(trial, ignore_keys_for_eval) + is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) + + if self.args.save_strategy == SaveStrategy.BEST: + self.control.should_save = is_new_best_metric if self.control.should_save: - self._save_checkpoint(model, trial, metrics=metrics) + self._save_checkpoint(model, trial) self.control = self.callback_handler.on_save(self.args, self.state, self.control) def _load_rng_state(self, checkpoint): @@ -2325,7 +3093,8 @@ class Trainer: ) return - checkpoint_rng_state = torch.load(rng_file) + with safe_globals(): + checkpoint_rng_state = torch.load(rng_file) random.setstate(checkpoint_rng_state["python"]) np.random.set_state(checkpoint_rng_state["numpy"]) torch.random.set_rng_state(checkpoint_rng_state["cpu"]) @@ -2340,7 +3109,7 @@ class Trainer: f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}" "\nThis won't yield the same results as if the training had not been interrupted." ) - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.set_rng_state(checkpoint_rng_state["xla"]) if is_torch_npu_available(): if self.args.parallel_mode == ParallelMode.DISTRIBUTED: @@ -2353,8 +3122,71 @@ class Trainer: f"Didn't manage to set back the RNG states of the NPU because of the following error:\n {e}" "\nThis won't yield the same results as if the training had not been interrupted." ) + if is_torch_mlu_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + torch.mlu.random.set_rng_state_all(checkpoint_rng_state["mlu"]) + else: + try: + torch.mlu.random.set_rng_state(checkpoint_rng_state["mlu"]) + except Exception as e: + logger.info( + f"Didn't manage to set back the RNG states of the MLU because of the following error:\n {e}" + "\nThis won't yield the same results as if the training had not been interrupted." + ) + if is_torch_musa_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + torch.musa.set_rng_state_all(checkpoint_rng_state["musa"]) + else: + try: + torch.musa.set_rng_state(checkpoint_rng_state["musa"]) + except Exception as e: + logger.info( + f"Didn't manage to set back the RNG states of the MUSA because of the following error:\n {e}" + "\nThis won't yield the same results as if the training had not been interrupted." + ) + + def _determine_best_metric(self, metrics, trial): + """ + Determine if the model should be saved based on the evaluation metrics. + If args.metric_for_best_model is not set, the loss is used. - def _save_checkpoint(self, model, trial, metrics=None): + Returns: + bool: True if a new best metric was found, else False + """ + is_new_best_metric = False + + if self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + + try: + metric_value = metrics[metric_to_check] + except KeyError as exc: + raise KeyError( + f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. " + f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments." + ) from exc + + operator = np.greater if self.args.greater_is_better else np.less + + if self.state.best_metric is None: + self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf") + + if operator(metric_value, self.state.best_metric): + run_dir = self._get_output_dir(trial=trial) + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + output_dir = os.path.join(run_dir, checkpoint_folder) + + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + is_new_best_metric = True + + return is_new_best_metric + + def _save_checkpoint(self, model, trial): # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we # want to save except FullyShardedDDP. # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" @@ -2367,67 +3199,36 @@ class Trainer: run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) - if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 0: - logger.warning( - f"Checkpoint destination directory {output_dir} already exists and is non-empty." - "Saving will proceed but saved results may be invalid." - ) - staging_output_dir = output_dir - else: - staging_output_dir = os.path.join(run_dir, f"tmp-{checkpoint_folder}") - self.save_model(staging_output_dir, _internal_call=True) + self.save_model(output_dir, _internal_call=True) if not self.args.save_only_model: # Save optimizer and scheduler - self._save_optimizer_and_scheduler(staging_output_dir) + self._save_optimizer_and_scheduler(output_dir) # Save RNG state - self._save_rng_state(staging_output_dir) - - # Determine the new best metric / best model checkpoint - if metrics is not None and self.args.metric_for_best_model is not None: - metric_to_check = self.args.metric_for_best_model - if not metric_to_check.startswith("eval_"): - metric_to_check = f"eval_{metric_to_check}" - metric_value = metrics[metric_to_check] - - operator = np.greater if self.args.greater_is_better else np.less - if ( - self.state.best_metric is None - or self.state.best_model_checkpoint is None - or operator(metric_value, self.state.best_metric) - ): - self.state.best_metric = metric_value - self.state.best_model_checkpoint = output_dir + self._save_rng_state(output_dir) # Save the Trainer state if self.args.should_save: - self.state.save_to_json(os.path.join(staging_output_dir, TRAINER_STATE_NAME)) + # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently + for cb in [ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ]: + cb_name = cb.__class__.__name__ + cb_state = cb.state() + if isinstance(self.state.stateful_callbacks[cb_name], list): + self.state.stateful_callbacks[cb_name].append(cb_state) + else: + self.state.stateful_callbacks[cb_name] = cb_state + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) if self.args.push_to_hub: - self._push_from_checkpoint(staging_output_dir) - - # Place checkpoint in final location after all saving is finished. - # First wait for everyone to finish writing - self.args.distributed_state.wait_for_everyone() - - # Then go through the rewriting process, only renaming and rotating from main process(es) - if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero(): - if staging_output_dir != output_dir: - if os.path.exists(staging_output_dir): - os.rename(staging_output_dir, output_dir) - - # Ensure rename completed in cases where os.rename is not atomic - # And can only happen on non-windows based systems - if os.name != "nt": - fd = os.open(output_dir, os.O_RDONLY) - os.fsync(fd) - os.close(fd) - - # Maybe delete some older checkpoints. - if self.args.should_save: - self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + self._push_from_checkpoint(output_dir) - self.args.distributed_state.wait_for_everyone() + # Maybe delete some older checkpoints. + if self.args.should_save: + # Solely rely on numerical checkpoint id for rotation. + # mtime is not reliable especially on some fuse fs in cloud environments. + self._rotate_checkpoints(use_mtime=False, output_dir=run_dir) def _save_rng_state(self, output_dir): # Save RNG state in non-distributed training @@ -2443,7 +3244,7 @@ class Trainer: else: rng_states["cuda"] = torch.cuda.random.get_rng_state() - if is_torch_tpu_available(): + if is_torch_xla_available(): rng_states["xla"] = xm.get_rng_state() if is_torch_npu_available(): @@ -2452,6 +3253,18 @@ class Trainer: else: rng_states["npu"] = torch.npu.random.get_rng_state() + if is_torch_mlu_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + rng_states["mlu"] = torch.mlu.random.get_rng_state_all() + else: + rng_states["mlu"] = torch.mlu.random.get_rng_state() + + if is_torch_musa_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + rng_states["musa"] = torch.musa.get_rng_state_all() + else: + rng_states["musa"] = torch.musa.get_rng_state() + # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may # not yet exist. os.makedirs(output_dir, exist_ok=True) @@ -2462,9 +3275,22 @@ class Trainer: torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) def _save_optimizer_and_scheduler(self, output_dir): - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.rendezvous("saving_optimizer_states") - xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + if self.is_fsdp_xla_v1_enabled: + optm = { + "optimizer": self.optimizer.state_dict(), + "shard_metadata": self.model.get_shard_metadata(), + } + xm.save( + optm, + os.path.join( + output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}" + ), + master_only=False, + ) + else: + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) with warnings.catch_warnings(record=True) as caught_warnings: xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) reissue_pt_warnings(caught_warnings) @@ -2490,7 +3316,9 @@ class Trainer: self.model_wrapped.save_checkpoint(output_dir) elif self.is_fsdp_enabled: # save fsdp specific ckpt for resuming from ckpt - save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir) + save_fsdp_model( + self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs() + ) save_fsdp_optimizer( self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir ) @@ -2505,7 +3333,7 @@ class Trainer: if ( self.args.should_save and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler) - and not is_torch_tpu_available() + and not is_torch_xla_available() ): with warnings.catch_warnings(record=True) as caught_warnings: torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) @@ -2540,11 +3368,26 @@ class Trainer: ) ) ) + checkpoint_file_exists = ( + glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}")) + if self.is_fsdp_xla_v1_enabled + else checkpoint_file_exists + ) if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): # Load in optimizer and scheduler states - if is_torch_tpu_available(): + if is_torch_xla_available(): # On TPU we have to take some extra precautions to properly load the states on the right device. - optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") + if self.is_fsdp_xla_v1_enabled: + optimizer_state = torch.load( + os.path.join( + checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}" + ), + map_location="cpu", + ) + # We only need `optimizer` when resuming from checkpoint + optimizer_state = optimizer_state["optimizer"] + else: + optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") with warnings.catch_warnings(record=True) as caught_warnings: lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") reissue_pt_warnings(caught_warnings) @@ -2584,6 +3427,7 @@ class Trainer: self.optimizer, self.model, checkpoint, + **_get_fsdp_ckpt_kwargs(), ) else: self.optimizer.load_state_dict( @@ -2593,6 +3437,45 @@ class Trainer: self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) reissue_pt_warnings(caught_warnings) + def _load_callback_state(self): + """If callback states exist and were passed in, restore their states if enabled""" + if not self.args.restore_callback_states_from_checkpoint: + return + # Callback states are stored in stateful_callbacks + not_found = [] + new_callbacks = [] + original_callbacks = self.callback_handler.callbacks + [self.control] + for stored_callback, data in self.state.stateful_callbacks.items(): + if not isinstance(data, list): + data = [data] + if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks): + # We can load/restore from multiple callbacks of the same type. + duplicates = [ + callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback + ] + for callback, callback_data in zip(duplicates, data): + args = callback_data.get("args", {}) + attributes = callback_data.get("attributes", {}) + new_callback = type(callback)(**args) + for attribute, value in attributes.items(): + setattr(new_callback, attribute, value) + if isinstance(callback, TrainerControl): + # Specifically for restoring the `control` state + self.control = new_callback + else: + new_callbacks.append(new_callback) + # We remove the existing callback and add it to the list of new callbacks + self.callback_handler.remove_callback(type(new_callback)) + logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in") + else: + not_found.append(stored_callback) + if len(not_found) > 0: + logger.warning( + f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})" + ) + for callback in new_callbacks: + self.callback_handler.add_callback(callback) + def hyperparameter_search( self, hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None, @@ -2639,13 +3522,18 @@ class Trainer: hp_name (`Callable[["optuna.Trial"], str]]`, *optional*): A function that defines the trial/run name. Will default to None. kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more - information see: - - - the documentation of - [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html) - - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run) - - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create) + Additional keyword arguments for each backend: + + - `optuna`: parameters from + [optuna.study.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html) + and also the parameters `timeout`, `n_jobs` and `gc_after_trial` from + [optuna.study.Study.optimize](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.optimize) + - `ray`: parameters from [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run). + If `resources_per_trial` is not set in the `kwargs`, it defaults to 1 CPU core and 1 GPU (if available). + If `progress_reporter` is not set in the `kwargs`, + [ray.tune.CLIReporter](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html) is used. + - `sigopt`: the parameter `proxies` from + [sigopt.Connection.set_proxies](https://docs.sigopt.com/support/faq#how-do-i-use-sigopt-with-a-proxy). Returns: [`trainer_utils.BestRun` or `List[trainer_utils.BestRun]`]: All the information about the best run or best @@ -2672,7 +3560,7 @@ class Trainer: self.hp_search_backend = None return best_run - def log(self, logs: Dict[str, float]) -> None: + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: """ Log `logs` on the various objects watching training. @@ -2681,11 +3569,15 @@ class Trainer: Args: logs (`Dict[str, float]`): The values to log. + start_time (`Optional[float]`): + The start of training. """ if self.state.epoch is not None: - logs["epoch"] = round(self.state.epoch, 2) + logs["epoch"] = self.state.epoch if self.args.include_num_input_tokens_seen: logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen + if start_time is not None: + speed_metrics("train", start_time, num_tokens=self.state.num_input_tokens_seen) output = {**logs, **{"step": self.state.global_step}} self.state.log_history.append(output) @@ -2743,7 +3635,9 @@ class Trainer: return ctx_manager - def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + def training_step( + self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None + ) -> torch.Tensor: """ Perform a training step on a batch of inputs. @@ -2762,67 +3656,144 @@ class Trainer: `torch.Tensor`: The tensor with training loss on this batch. """ model.train() - inputs = self._prepare_inputs(inputs) + if hasattr(self.optimizer, "train") and callable(self.optimizer.train): + self.optimizer.train() + inputs = self._prepare_inputs(inputs) if is_sagemaker_mp_enabled(): loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) return loss_mb.reduce_mean().detach().to(self.args.device) + # ################################################################################################################################################################################ - # ####################################################### - # # print(inputs['dataset_id']) - # data_info_temp = inputs['data_info'] - # del inputs['dataset_id'] - # del inputs['data_info'] - # ####################################################### - + ####################################################### + # print(inputs['dataset_id']) + data_info_temp = inputs['data_info'] + del inputs['dataset_id'] + del inputs['data_info'] + ####################################################### + # import pdb; pdb.set_trace() + + + + with self.compute_loss_context_manager(): - loss = self.compute_loss(model, inputs) - - # ####################################################### - # import json - # for i in range(len(data_info_temp)): - # data_info_temp[i]['loss'] = float(loss[0][i]) - - # file_path = '/code/NIPS_2024/Llava_Med/inference_demo/llava_cherry_loss_all.jsonl' - # with open(file_path, 'a', encoding='utf-8') as file: - # # json.dump(data_info_temp[0], file, ensure_ascii=False, indent=4) - # for content in data_info_temp: - # json_string = json.dumps(content, ensure_ascii=False) - # file.write(json_string + '\n') + # loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) + (loss, outputs) = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch,return_outputs=True) + + # 没用了 不要了 + # ####################################################### + # import pprint + # # pprint.pprint(outputs) + # # import pdb; pdb.set_trace() + # last_token_logits_yes = outputs.logits[:, -1, :] + # yes_target_token_id = 4874 + # yes_target_logprob = torch.log_softmax(last_token_logits_yes, dim=-1)[0, yes_target_token_id].item() # ####################################################### + + + import json + for i in range(len(data_info_temp)): + + tensor = outputs.logits[i] + mask = (tensor != 0).any(dim=1) + last_token_logits_yes = tensor[mask][-1].unsqueeze(0) + # import pdb; pdb.set_trace() + yes_target_token_id = 4874 + yes_target_logprob = torch.log_softmax(last_token_logits_yes, dim=-1)[0, yes_target_token_id].item() + print(yes_target_logprob) + + # data_info_temp[i]['yes_target_logprob_7B_Img'] = yes_target_logprob + data_info_temp[i]['yes_target_logprob_7B_NImg'] = yes_target_logprob + data_info_temp[i]['logits_shape'] = last_token_logits_yes.shape + + from datetime import datetime + # 获取当前时间并将其格式化为字符串 + # current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + current_time = datetime.now().strftime('_%Y_%m_%d_') + + + file_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/NIPS_2025/LLaVA_Fliter/inference_demo/New_llava_Logits_NImg_' + current_time + '.jsonl' + # file_path = '/inspire/hdd/ws-ba572160-47f8-4ca1-984e-d6bcdeb95dbb/a100-maybe/albus/NIPS_2025/LLaVA_Fliter/inference_demo/New_llava_Logits_Img_' + current_time + '.jsonl' + with open(file_path, 'a', encoding='utf-8') as file: + # json.dump(data_info_temp[0], file, ensure_ascii=False, indent=4) + for content in data_info_temp: + json_string = json.dumps(content, ensure_ascii=False) + file.write(json_string + '\n') + ####################################################### + + + ####################################################### + loss = loss[0].sum() + ####################################################### + ################################################################################################################################################################################ + + + del inputs + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_mlu_available(): + torch.mlu.empty_cache() + elif is_torch_musa_available(): + torch.musa.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(min_version="2.0"): + torch.mps.empty_cache() + else: + torch.cuda.empty_cache() + kwargs = {} - # ####################################################### - # loss = loss[0].sum() - # ####################################################### + # For LOMO optimizers you need to explicitly use the learnign rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training - ########################################## - # # train loss 提取 / 屏蔽backward # ########################################## - if self.use_apex: - with amp.scale_loss(loss, self.optimizer) as scaled_loss: - scaled_loss.backward() - else: - self.accelerator.backward(loss) - ####################################################################################### - ############################################################################################################################################################################### - + # # # train loss 提取 / 屏蔽backward + # # ########################################## + + # if self.use_apex: + # with amp.scale_loss(loss, self.optimizer) as scaled_loss: + # scaled_loss.backward() + # else: + # # Finally we need to normalize the loss for reporting + # if num_items_in_batch is None: + # loss = loss / self.args.gradient_accumulation_steps + + # self.accelerator.backward(loss, **kwargs) + + # return loss.detach() + + # ####################################################################################### + # ############################################################################################################################################################################### return loss.detach() / self.args.gradient_accumulation_steps + ############################################################################################################################################################################### + - def compute_loss(self, model, inputs, return_outputs=False): + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. Subclass and override for custom behavior. """ - if self.label_smoother is not None and "labels" in inputs: + if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs: labels = inputs.pop("labels") else: labels = None + if self.model_accepts_loss_kwargs: + loss_kwargs = {} + if num_items_in_batch is not None: + loss_kwargs["num_items_in_batch"] = num_items_in_batch + inputs = {**inputs, **loss_kwargs} outputs = model(**inputs) # Save past state if it exists # TODO: this needs to be fixed and made cleaner later. @@ -2830,12 +3801,15 @@ class Trainer: self._past = outputs[self.args.past_index] if labels is not None: - unwrapped_model = unwrap_model(model) + unwrapped_model = self.accelerator.unwrap_model(model) if _is_peft_model(unwrapped_model): model_name = unwrapped_model.base_model.model._get_name() else: model_name = unwrapped_model._get_name() - if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + # User-defined compute_loss function + if self.compute_loss_func is not None: + loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch) + elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): loss = self.label_smoother(outputs, labels, shift_labels=True) else: loss = self.label_smoother(outputs, labels) @@ -2848,6 +3822,9 @@ class Trainer: # We don't use .loss here since the model may return tuples instead of ModelOutput. loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + loss *= self.accelerator.num_processes + return (loss, outputs) if return_outputs else loss def is_local_process_zero(self) -> bool: @@ -2879,7 +3856,7 @@ class Trainer: if output_dir is None: output_dir = self.args.output_dir - if is_torch_tpu_available(): + if is_torch_xla_available(): self._save_tpu(output_dir) elif is_sagemaker_mp_enabled(): # Calling the state_dict needs to be done on the wrapped model and on all processes. @@ -2922,29 +3899,64 @@ class Trainer: def _save_tpu(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else self.args.output_dir + logger.info(f"Saving model checkpoint to {output_dir}") model = self.model - model.to("cpu") + xm.mark_step() - if xm.is_master_ordinal(): + if xm.is_master_ordinal(local=False): os.makedirs(output_dir, exist_ok=True) torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` + supported_classes = (PushToHubMixin,) xm.rendezvous("saving_checkpoint") - if not isinstance(model, PreTrainedModel): - if isinstance(unwrap_model(model), PreTrainedModel): - unwrap_model(model).save_pretrained( + if self.is_fsdp_xla_v1_enabled: + ckpt = { + "model": model.state_dict(), + "shard_metadata": model.get_shard_metadata(), + } + ckpt_path = os.path.join( + output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}" + ) + # All ranks save sharded checkpoint + xm.save(ckpt, ckpt_path, master_only=False) + # Make sure all ranks have saved checkpoints + xm.rendezvous("save_full_checkpoints") + # Master save full checkpoint + if self.args.should_save: + from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints + + full_state_dict, _ = consolidate_sharded_model_checkpoints( + ckpt_prefix=os.path.join(output_dir, ""), + ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}", + save_model=False, + ) + model = model.module.module + unwrapped_model = self.accelerator.unwrap_model(model) + if isinstance(unwrapped_model, supported_classes): + unwrapped_model.save_pretrained( + output_dir, + state_dict=full_state_dict, + save_function=xm.save, + safe_serialization=self.args.save_safetensors, + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + elif not isinstance(model, supported_classes): + if isinstance(self.accelerator.unwrap_model(model), supported_classes): + self.accelerator.unwrap_model(model).save_pretrained( output_dir, is_main_process=self.args.should_save, - state_dict=model.state_dict(), + state_dict=xm._maybe_convert_to_cpu(model.state_dict()), save_function=xm.save, safe_serialization=self.args.save_safetensors, ) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") - state_dict = model.state_dict() + state_dict = xm._maybe_convert_to_cpu(model.state_dict()) xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: model.save_pretrained( @@ -2952,13 +3964,10 @@ class Trainer: is_main_process=self.args.should_save, save_function=xm.save, safe_serialization=self.args.save_safetensors, + state_dict=xm._maybe_convert_to_cpu(model.state_dict()), ) - if self.tokenizer is not None and self.args.should_save: - self.tokenizer.save_pretrained(output_dir) - - # We moved the model from TPU -> CPU for saving the weights. - # Now we should move it back to subsequent compute still works. - model.to(self.args.device) + if self.processing_class is not None and self.args.should_save: + self.processing_class.save_pretrained(output_dir) def _save(self, output_dir: Optional[str] = None, state_dict=None): # If we are executing this function, we are the process zero, so we don't check for that. @@ -2973,8 +3982,8 @@ class Trainer: if state_dict is None: state_dict = self.model.state_dict() - if isinstance(unwrap_model(self.model), supported_classes): - unwrap_model(self.model).save_pretrained( + if isinstance(self.accelerator.unwrap_model(self.model), supported_classes): + self.accelerator.unwrap_model(self.model).save_pretrained( output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors ) else: @@ -2990,8 +3999,8 @@ class Trainer: output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors ) - if self.tokenizer is not None: - self.tokenizer.save_pretrained(output_dir) + if self.processing_class is not None: + self.processing_class.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) @@ -3088,7 +4097,7 @@ class Trainer: When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for using the - loss on `data1` and `metric_for_best_model="eval_data1_loss"` for the loss on `data2`. + loss on `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`. @@ -3104,12 +4113,13 @@ class Trainer: dictionary also contains the epoch number which comes from the training state. """ # handle multipe eval datasets - eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + override = eval_dataset is not None + eval_dataset = eval_dataset if override else self.eval_dataset if isinstance(eval_dataset, dict): metrics = {} for eval_dataset_name, _eval_dataset in eval_dataset.items(): dataset_metrics = self.evaluate( - eval_dataset=_eval_dataset, + eval_dataset=_eval_dataset if override else eval_dataset_name, ignore_keys=ignore_keys, metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}", ) @@ -3120,6 +4130,9 @@ class Trainer: self._memory_tracker.start() eval_dataloader = self.get_eval_dataloader(eval_dataset) + if self.is_fsdp_xla_v2_enabled: + eval_dataloader = tpu_spmd_dataloader(eval_dataloader) + start_time = time.time() eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop @@ -3136,6 +4149,8 @@ class Trainer: total_batch_size = self.args.eval_batch_size * self.args.world_size if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + if f"{metric_key_prefix}_model_preparation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"] output.metrics.update( speed_metrics( metric_key_prefix, @@ -3205,6 +4220,8 @@ class Trainer: total_batch_size = self.args.eval_batch_size * self.args.world_size if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + if f"{metric_key_prefix}_model_preparation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"] output.metrics.update( speed_metrics( metric_key_prefix, @@ -3243,11 +4260,13 @@ class Trainer: model = self._wrap_model(self.model, training=False, dataloader=dataloader) if len(self.accelerator._models) == 0 and model is self.model: + start_time = time.time() model = ( self.accelerator.prepare(model) - if self.is_deepspeed_enabled + if self.is_deepspeed_enabled or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8") else self.accelerator.prepare_model(model, evaluation_mode=True) ) + self.model_preparation_time = round(time.time() - start_time, 4) if self.is_fsdp_enabled: self.model = model @@ -3270,7 +4289,7 @@ class Trainer: batch_size = self.args.eval_batch_size - logger.info(f"***** Running {description} *****") + logger.info(f"\n***** Running {description} *****") if has_length(dataloader): logger.info(f" Num examples = {self.num_examples(dataloader)}") else: @@ -3278,6 +4297,8 @@ class Trainer: logger.info(f" Batch size = {batch_size}") model.eval() + if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + self.optimizer.eval() self.callback_handler.eval_dataloader = dataloader # Do this before wrapping. @@ -3287,20 +4308,17 @@ class Trainer: self._past = None # Initialize containers - # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) - losses_host = None - preds_host = None - labels_host = None - inputs_host = None - - # losses/preds/labels on CPU (final containers) - all_losses = None - all_preds = None - all_labels = None - all_inputs = None - # Will be useful when we have an iterable dataset so don't know its length. + all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100) + + metrics = None + eval_set_kwargs = {} + # Will be useful when we have an iterable dataset so don't know its length. observed_num_examples = 0 + # Main evaluation loop for step, inputs in enumerate(dataloader): # Update the observed num examples @@ -3312,63 +4330,64 @@ class Trainer: batch_size = observed_batch_size # Prediction step - loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) main_input_name = getattr(self.model, "main_input_name", "input_ids") - inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None + inputs_decode = ( + self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None + ) - if is_torch_tpu_available(): + if is_torch_xla_available(): xm.mark_step() - # Update containers on host - if loss is not None: - losses = self.gather_function((loss.repeat(batch_size))) - losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100) - if labels is not None: - labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) + # Update containers + if losses is not None: + losses = self.gather_function((losses.repeat(batch_size))) + all_losses.add(losses) if inputs_decode is not None: inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) inputs_decode = self.gather_function((inputs_decode)) - inputs_host = ( - inputs_decode - if inputs_host is None - else nested_concat(inputs_host, inputs_decode, padding_index=-100) - ) + if not self.args.batch_eval_metrics or description == "Prediction": + all_inputs.add(inputs_decode) + if labels is not None: + # Pad labels here, preparing for preprocess_logits_for_metrics in next logits block. + labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) if logits is not None: logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) logits = self.gather_function((logits)) - preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) - + if not self.args.batch_eval_metrics or description == "Prediction": + all_preds.add(logits) if labels is not None: labels = self.gather_function((labels)) - labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + if not self.args.batch_eval_metrics or description == "Prediction": + all_labels.add(labels) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) - # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. - if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: - if losses_host is not None: - losses = nested_numpify(losses_host) - all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) - if preds_host is not None: - logits = nested_numpify(preds_host) - all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) - if inputs_host is not None: - inputs_decode = nested_numpify(inputs_host) - all_inputs = ( - inputs_decode - if all_inputs is None - else nested_concat(all_inputs, inputs_decode, padding_index=-100) - ) - if labels_host is not None: - labels = nested_numpify(labels_host) - all_labels = ( - labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + if self.args.batch_eval_metrics: + if self.compute_metrics is not None and logits is not None and labels is not None: + is_last_step = self.accelerator.gradient_state.end_of_dataloader + batch_kwargs = {} + batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None + batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics( + EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs), + compute_result=is_last_step, ) - # Set back to None to begin a new accumulation - losses_host, preds_host, inputs_host, labels_host = None, None, None, None + del losses, logits, labels, inputs + torch.cuda.empty_cache() + + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + all_losses.to_cpu_and_numpy() + all_preds.to_cpu_and_numpy() + all_labels.to_cpu_and_numpy() + all_inputs.to_cpu_and_numpy() + + del losses, logits, labels, inputs + torch.cuda.empty_cache() # After all calls to `.gather_function`, reset to `gather_for_metrics`: self.gather_function = self.accelerator.gather_for_metrics @@ -3377,20 +4396,10 @@ class Trainer: delattr(self, "_past") # Gather all remaining tensors and put them back on the CPU - if losses_host is not None: - losses = nested_numpify(losses_host) - all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) - if preds_host is not None: - logits = nested_numpify(preds_host) - all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) - if inputs_host is not None: - inputs_decode = nested_numpify(inputs_host) - all_inputs = ( - inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) - ) - if labels_host is not None: - labels = nested_numpify(labels_host) - all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + all_losses = all_losses.get_arrays() + all_preds = all_preds.get_arrays() + all_labels = all_labels.get_arrays() + all_inputs = all_inputs.get_arrays() # Number of samples if has_length(eval_dataset): @@ -3408,23 +4417,31 @@ class Trainer: num_samples = observed_num_examples # Metrics! - if self.compute_metrics is not None and all_preds is not None and all_labels is not None: - if args.include_inputs_for_metrics: - metrics = self.compute_metrics( - EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs) - ) - else: - metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) - else: + if ( + self.compute_metrics is not None + and all_preds is not None + and all_labels is not None + and not self.args.batch_eval_metrics + ): + eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None + eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics( + EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs) + ) + elif metrics is None: metrics = {} # To be JSON-serializable, we need to remove numpy types or zero-d tensors metrics = denumpify_detensorize(metrics) - if all_losses is not None: + if isinstance(all_losses, list) and all_losses: + metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item() + elif isinstance(all_losses, np.ndarray): metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() if hasattr(self, "jit_compilation_time"): metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time + if hasattr(self, "model_preparation_time"): + metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time # Prefix all keys with metric_key_prefix + '_' for key in list(metrics.keys()): @@ -3440,7 +4457,7 @@ class Trainer: """ if tensors is None: return - if is_torch_tpu_available(): + if is_torch_xla_available(): if name is None: name = "nested_gather" tensors = nested_xla_mesh_reduce(tensors, name) @@ -3575,7 +4592,7 @@ class Trainer: else: return 0 - def init_hf_repo(self): + def init_hf_repo(self, token: Optional[str] = None): """ Initializes a git repo in `self.args.hub_model_id`. """ @@ -3588,7 +4605,8 @@ class Trainer: else: repo_name = self.args.hub_model_id - repo_url = create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True) + token = token if token is not None else self.args.hub_token + repo_url = create_repo(repo_name, token=token, private=self.args.hub_private_repo, exist_ok=True) self.hub_model_id = repo_url.repo_id self.push_in_progress = None @@ -3664,7 +4682,7 @@ class Trainer: f.write(model_card) if is_peft_library: - unwrap_model(self.model).create_or_update_model_card(self.args.output_dir) + self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir) def _push_from_checkpoint(self, checkpoint_folder): # Only push from one node. @@ -3677,18 +4695,27 @@ class Trainer: output_dir = self.args.output_dir # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] + # Add sharded checkpoints if we have an index + for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]: + index_path = os.path.join(checkpoint_folder, index_file) + if os.path.isfile(index_path): + modeling_files.append(index_file) + with open(index_path) as f: + index = json.loads(f.read()) + shard_files = list(set(index["weight_map"].values())) + modeling_files.extend(shard_files) if is_peft_available(): modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME]) for modeling_file in modeling_files: if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) - # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure. - if self.tokenizer is not None: - self.tokenizer.save_pretrained(output_dir) + # Saving the processing class is fast and we don't know how many files it may have spawned, so we resave it to be sure. + if self.processing_class is not None: + self.processing_class.save_pretrained(output_dir) # Same for the training arguments torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) - if self.args.save_strategy == IntervalStrategy.STEPS: + if self.args.save_strategy == SaveStrategy.STEPS: commit_message = f"Training in progress, step {self.state.global_step}" else: commit_message = f"Training in progress, epoch {int(self.state.epoch)}" @@ -3730,15 +4757,26 @@ class Trainer: logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.") self.push_in_progress.wait_until_done() - def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + def push_to_hub( + self, + commit_message: Optional[str] = "End of training", + blocking: bool = True, + token: Optional[str] = None, + revision: Optional[str] = None, + **kwargs, + ) -> str: """ - Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`. + Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`. Parameters: commit_message (`str`, *optional*, defaults to `"End of training"`): Message to commit while pushing. blocking (`bool`, *optional*, defaults to `True`): Whether the function should return only when the `git push` has finished. + token (`str`, *optional*, defaults to `None`): + Token with write permission to overwrite Trainer's original args. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the "main" branch. kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to [`~Trainer.create_model_card`]. @@ -3752,10 +4790,11 @@ class Trainer: model_name = Path(self.args.output_dir).name else: model_name = self.args.hub_model_id.split("/")[-1] + token = token if token is not None else self.args.hub_token # In case the user calls this method with args.push_to_hub = False if self.hub_model_id is None: - self.init_hf_repo() + self.init_hf_repo(token=token) # Needs to be executed on all processes for TPU training, but will only save on the processed determined by # self.args.should_save. @@ -3768,7 +4807,10 @@ class Trainer: # Add additional tags in the case the model has already some tags and users pass # "tags" argument to `push_to_hub` so that trainer automatically handles internal tags # from all models since Trainer does not call `model.push_to_hub`. - if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None: + if getattr(self.model, "model_tags", None) is not None: + if "tags" not in kwargs: + kwargs["tags"] = [] + # If it is a string, convert it to a list if isinstance(kwargs["tags"], str): kwargs["tags"] = [kwargs["tags"]] @@ -3785,9 +4827,10 @@ class Trainer: repo_id=self.hub_model_id, folder_path=self.args.output_dir, commit_message=commit_message, - token=self.args.hub_token, + token=token, run_as_future=not blocking, ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"], + revision=revision, ) # @@ -3823,7 +4866,7 @@ class Trainer: if len(self.accelerator._models) == 0 and model is self.model: model = ( self.accelerator.prepare(model) - if self.is_deepspeed_enabled + if self.is_deepspeed_enabled or self.is_fsdp_enabled else self.accelerator.prepare_model(model, evaluation_mode=True) ) @@ -3846,15 +4889,28 @@ class Trainer: elif args.bf16_full_eval: model = model.to(dtype=torch.bfloat16, device=args.device) - batch_size = dataloader.batch_size + batch_size = ( + dataloader.total_batch_size + if getattr(dataloader, "_is_accelerate_prepared", False) + else dataloader.batch_size + ) + + if batch_size is None: + raise ValueError( + "Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size." + ) + num_examples = self.num_examples(dataloader) - logger.info(f"***** Running {description} *****") + logger.info(f"\n***** Running {description} *****") logger.info(f" Num examples = {num_examples}") logger.info(f" Batch size = {batch_size}") + losses_host: torch.Tensor = None preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None + metrics: Optional[dict] = None + eval_set_kwargs: dict = {} world_size = max(1, args.world_size) @@ -3870,6 +4926,8 @@ class Trainer: inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) model.eval() + if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval): + self.optimizer.eval() if args.past_index >= 0: self._past = None @@ -3879,7 +4937,9 @@ class Trainer: for step, inputs in enumerate(dataloader): loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) main_input_name = getattr(self.model, "main_input_name", "input_ids") - inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None + inputs_decode = ( + self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None + ) if loss is not None: losses = loss.repeat(batch_size) @@ -3896,8 +4956,21 @@ class Trainer: ) self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) - # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. - if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + if self.args.batch_eval_metrics: + if self.compute_metrics is not None and preds_host is not None and labels_host is not None: + is_last_step = self.accelerator.gradient_state.end_of_dataloader + batch_kwargs = {} + batch_kwargs["losses"] = losses_host if "loss" in args.include_for_metrics else None + batch_kwargs["inputs"] = inputs_host if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics( + EvalPrediction(predictions=preds_host, label_ids=labels_host, **batch_kwargs), + compute_result=is_last_step, + ) + + if self.args.batch_eval_metrics or ( + args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0 + ): + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) @@ -3905,6 +4978,8 @@ class Trainer: inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) # Set back to None to begin a new accumulation + del losses_host, preds_host, labels_host, inputs_host + torch.cuda.empty_cache() losses_host, preds_host, labels_host, inputs_host = None, None, None, None if args.past_index and hasattr(self, "_past"): @@ -3923,14 +4998,16 @@ class Trainer: label_ids = labels_gatherer.finalize() if not prediction_loss_only else None inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None - if self.compute_metrics is not None and preds is not None and label_ids is not None: - if args.include_inputs_for_metrics: - metrics = self.compute_metrics( - EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) - ) - else: - metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) - else: + if ( + self.compute_metrics is not None + and preds is not None + and label_ids is not None + and not self.args.batch_eval_metrics + ): + eval_set_kwargs["losses"] = eval_loss if "loss" in args.include_for_metrics else None + eval_set_kwargs["inputs"] = inputs_ids if "inputs" in args.include_for_metrics else None + metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, **eval_set_kwargs)) + elif metrics is None: metrics = {} # To be JSON-serializable, we need to remove numpy types or zero-d tensors @@ -3953,7 +5030,7 @@ class Trainer: """ if tensors is None: return - if is_torch_tpu_available(): + if is_torch_xla_available(): tensors = nested_xla_mesh_reduce(tensors, name) elif is_sagemaker_mp_enabled(): tensors = smp_gather(tensors) @@ -4002,20 +5079,67 @@ class Trainer: self.repo.git_push() def create_accelerator_and_postprocess(self): - grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps} - grad_acc_kwargs["sync_with_dataloader"] = False - gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) + # We explicitly don't rely on the `Accelerator` to do gradient accumulation + grad_acc_kwargs = {} + if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None: + grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs + + # check if num_steps is attempted to be passed in gradient_accumulation_kwargs + if "num_steps" in grad_acc_kwargs: + if self.args.gradient_accumulation_steps > 1: + # raise because we do not know which setting is intended. + raise ValueError( + "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`" + "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`." + ) + else: + self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"] + + accelerator_config = self.args.accelerator_config.to_dict() + + if is_accelerate_available("0.28.0"): + dataloader_config = DataLoaderConfiguration( + split_batches=accelerator_config.pop("split_batches"), + dispatch_batches=accelerator_config.pop("dispatch_batches"), + even_batches=accelerator_config.pop("even_batches"), + use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"), + ) + if is_accelerate_available("1.1.0"): + dataloader_config.data_seed = self.args.data_seed + + non_blocking = accelerator_config.pop("non_blocking") + if not is_accelerate_available("0.30.0"): + if non_blocking: + raise ImportError( + "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature." + ) + else: + if non_blocking and not self.args.dataloader_pin_memory: + logger.warning( + "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both." + ) + dataloader_config.non_blocking = non_blocking + # this would have been updated above, no need for it anymore + accelerator_config.pop("gradient_accumulation_kwargs") + + args = { + "deepspeed_plugin": self.args.deepspeed_plugin, + } + if is_accelerate_available("0.28.0"): + args["dataloader_config"] = dataloader_config + else: + args.update(accelerator_config) # create accelerator object - self.accelerator = Accelerator( - dispatch_batches=self.args.dispatch_batches, - split_batches=self.args.split_batches, - deepspeed_plugin=self.args.deepspeed_plugin, - gradient_accumulation_plugin=gradient_accumulation_plugin, - ) + self.accelerator = Accelerator(**args) # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag self.gather_function = self.accelerator.gather_for_metrics + if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys(): + self.gather_function = functools.partial( + self.gather_function, use_gather_object=self.args.eval_use_gather_object + ) + # deepspeed and accelerate flags covering both trainer args and accelerate launcher self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None @@ -4026,20 +5150,38 @@ class Trainer: fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get( "limit_all_gathers", fsdp_plugin.limit_all_gathers ) - if is_accelerate_available("0.23.0"): - fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get( - "activation_checkpointing", fsdp_plugin.activation_checkpointing + fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get( + "activation_checkpointing", fsdp_plugin.activation_checkpointing + ) + if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: + raise ValueError( + "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " + "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " + "when using FSDP." ) - if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: - raise ValueError( - "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " - "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " - "when using FSDP." - ) if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: self.propagate_args_to_deepspeed() + # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end` + if ( + self.args.save_only_model + and (self.is_deepspeed_enabled or self.is_fsdp_enabled) + and self.args.load_best_model_at_end + ): + wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP" + raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.") + + # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3 + if ( + self.is_deepspeed_enabled + and self.accelerator.state.deepspeed_plugin.zero_stage == 3 + and self.args.auto_find_batch_size + ): + raise ValueError( + "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP" + ) + def propagate_args_to_deepspeed(self, auto_find_batch_size=False): """ Sets values in the deepspeed plugin based on the Trainer args @@ -4051,3 +5193,44 @@ class Trainer: ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config) ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size) + + def _fsdp_qlora_plugin_updates(self): + if self.is_fsdp_enabled and _is_peft_model(self.model): + from peft import LoraConfig + from peft.utils.other import fsdp_auto_wrap_policy + + if isinstance(self.model.active_peft_config, LoraConfig): + fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model) + if ( + getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point + and version.parse(accelerate_version) > version.parse("0.27.0") + ): + fsdp_plugin.set_mixed_precision( + self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True + ) + + def get_batch_samples(self, epoch_iterator, num_batches): + batch_samples = [] + num_items_in_batch = None + for _ in range(num_batches): + try: + batch_samples += [next(epoch_iterator)] + except StopIteration: + break + + if len(batch_samples) > 0 and "labels" in batch_samples[0]: + # For now we don't support object detection + try: + num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + except (TypeError, AttributeError): + pass + + if self.args.average_tokens_across_devices: + num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() + + if torch.is_tensor(num_items_in_batch): + num_items_in_batch = num_items_in_batch.item() + + return batch_samples, num_items_in_batch