diff --git a/accelerate_configs/uncompiled_4.yaml b/accelerate_configs/uncompiled_4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e15d4c6cd56145c4653e97b8cbdd823b154b6207 --- /dev/null +++ b/accelerate_configs/uncompiled_4.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: 0,1,2,3 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/finetrainers/__init__.py b/finetrainers/__init__.py index 412e298eb519f037e08dc755f92d136cfe2ef2e6..7da2391e864af71edf8b826d1f1263d5c8f1afe5 100644 --- a/finetrainers/__init__.py +++ b/finetrainers/__init__.py @@ -1,2 +1,5 @@ -from .args import Args, parse_arguments -from .trainer import Trainer +from .args import BaseArgs +from .config import ModelType, TrainingType +from .logging import get_logger +from .models import ModelSpecification +from .trainer import SFTTrainer diff --git a/finetrainers/args.py b/finetrainers/args.py index 46cd04cca1c0d7368be8395ce3382bc28fc6865b..199c7493ab0fcba1724b9e6eb7dbc70ca237ed64 100644 --- a/finetrainers/args.py +++ b/finetrainers/args.py @@ -1,14 +1,21 @@ import argparse +import os +import pathlib import sys -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import torch -from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS -from .models import SUPPORTED_MODEL_CONFIGS +from .config import SUPPORTED_MODEL_CONFIGS, ModelType, TrainingType +from .logging import get_logger +from .parallel import ParallelBackendEnum +from .utils import get_non_null_items -class Args: +logger = get_logger() + + +class BaseArgs: r""" The arguments for the finetrainers training script. @@ -19,6 +26,19 @@ class Args: TODO(aryan): add `python train.py --memory_requirements --model_name <model_name>` to show memory requirements per model, per training type with sensible training settings. + PARALLEL ARGUMENTS + ------------------ + parallel_backend (`str`, defaults to `accelerate`): + The parallel backend to use for training. Choose between ['accelerate', 'ptd']. + pp_degree (`int`, defaults to `1`): + The degree of pipeline parallelism. + dp_degree (`int`, defaults to `1`): + The degree of data parallelism (number of model replicas). + dp_shards (`int`, defaults to `-1`): + The number of data parallel shards (number of model partitions). + cp_degree (`int`, defaults to `1`): + The degree of context parallelism. + MODEL ARGUMENTS --------------- model_name (`str`): @@ -33,6 +53,22 @@ class Args: storage requirements. cache_dir (`str`, defaults to `None`): The directory where the downloaded models and datasets will be stored, or loaded from. + tokenizer_id (`str`, defaults to `None`): + Identifier for the tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. + tokenizer_2_id (`str`, defaults to `None`): + Identifier for the second tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. + tokenizer_3_id (`str`, defaults to `None`): + Identifier for the third tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`. 
+ text_encoder_id (`str`, defaults to `None`): + Identifier for the text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. + text_encoder_2_id (`str`, defaults to `None`): + Identifier for the second text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. + text_encoder_3_id (`str`, defaults to `None`): + Identifier for the third text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`. + transformer_id (`str`, defaults to `None`): + Identifier for the transformer model. This is useful when using a different transformer model than the default from `pretrained_model_name_or_path`. + vae_id (`str`, defaults to `None`): + Identifier for the VAE model. This is useful when using a different VAE model than the default from `pretrained_model_name_or_path`. text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`): Data type for the text encoder when generating text embeddings. text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`): @@ -54,41 +90,47 @@ class Args: DATASET ARGUMENTS ----------------- - data_root (`str`): - A folder containing the training data. - dataset_file (`str`, defaults to `None`): - Path to a CSV/JSON/JSONL file containing metadata for training. This should be provided if you're not using - a directory dataset format containing a simple `prompts.txt` and `videos.txt`/`images.txt` for example. - video_column (`str`): - The column of the dataset containing videos. Or, the name of the file in `data_root` folder containing the - line-separated path to video data. - caption_column (`str`): - The column of the dataset containing the instance prompt for each video. Or, the name of the file in - `data_root` folder containing the line-separated instance prompts. - id_token (`str`, defaults to `None`): - Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training. - image_resolution_buckets (`List[Tuple[int, int]]`, defaults to `None`): - Resolution buckets for images. This should be a list of integer tuples, where each tuple represents the - resolution (height, width) of the image. All images will be resized to the nearest bucket resolution. - video_resolution_buckets (`List[Tuple[int, int, int]]`, defaults to `None`): - Resolution buckets for videos. This should be a list of integer tuples, where each tuple represents the - resolution (num_frames, height, width) of the video. All videos will be resized to the nearest bucket - resolution. - video_reshape_mode (`str`, defaults to `None`): - All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']. - TODO(aryan): We don't support this. - caption_dropout_p (`float`, defaults to `0.00`): - Probability of dropout for the caption tokens. This is useful to improve the unconditional generation - quality of the model. - caption_dropout_technique (`str`, defaults to `empty`): - Technique to use for caption dropout. Choose between ['empty', 'zero']. Some models apply caption dropout - by setting the prompt condition to an empty string, while others zero-out the text embedding tensors. - precompute_conditions (`bool`, defaults to `False`): - Whether or not to precompute the conditionings for the model. This is useful for faster training, and - reduces the memory requirements. 
- remove_common_llm_caption_prefixes (`bool`, defaults to `False`): - Whether or not to remove common LLM caption prefixes. This is useful for improving the quality of the - generated text. + dataset_config (`str`): + Path to a dataset config file containing information about the training data. This file can contain information about one or + more datasets in JSON format. The file must have a key called "datasets", which is a list of dictionaries. Each + dictionary must contain the following keys: + - "data_root": (`str`) + The root directory containing the dataset. This parameter must be provided if `dataset_file` is not provided. + - "dataset_file": (`str`) + Path to a CSV/JSON/JSONL/PARQUET/ARROW/HF_HUB_DATASET file containing metadata for training. This parameter + must be provided if `data_root` is not provided. + - "dataset_type": (`str`) + Type of dataset. Choose between ['image', 'video']. + - "id_token": (`str`) + Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training + on a single subject/concept/style, but is not necessary. + - "image_resolution_buckets": (`List[Tuple[int, int]]`) + Resolution buckets for images. This should be a list of tuples containing 2 values, where each tuple + represents the resolution (height, width). All images will be resized to the nearest bucket resolution. + This parameter must be provided if `dataset_type` is 'image'. + - "video_resolution_buckets": (`List[Tuple[int, int, int]]`) + Resolution buckets for videos. This should be a list of tuples containing 3 values, where each tuple + represents the resolution (num_frames, height, width). All videos will be resized to the nearest bucket + resolution. This parameter must be provided if `dataset_type` is 'video'. + - "reshape_mode": (`str`) + All input images/videos are reshaped using this mode. Choose between the following: + ["center_crop", "random_crop", "bicubic"]. + - "remove_common_llm_caption_prefixes": (`bool`) + Whether or not to remove common LLM caption prefixes. See `~constants.py` for the list of common prefixes. + dataset_shuffle_buffer_size (`int`, defaults to `1`): + The buffer size for shuffling the dataset. This is useful for shuffling the dataset before training. The default + value of `1` means that the dataset will not be shuffled. + precomputation_items (`int`, defaults to `512`): + Number of data samples to precompute at once for memory-efficient training. The higher this value, + the more disk space will be used to save the precomputed samples (conditions and latents). + precomputation_dir (`str`, defaults to `None`): + The directory where the precomputed samples will be stored. If not provided, the precomputed samples + will be stored in a temporary directory inside the output directory. + precomputation_once (`bool`, defaults to `False`): + Precompute embeddings from all datasets at once before training. This is useful to save time during training + with smaller datasets. If set to `False`, disk space is saved by precomputing embeddings on-the-fly during + training when required. Make sure to set `precomputation_items` to a reasonable value in line with the size + of your dataset(s). DATALOADER_ARGUMENTS -------------------- @@ -136,16 +178,11 @@ class Args: A seed for reproducible training. batch_size (`int`, defaults to `1`): Per-device batch size. - train_epochs (`int`, defaults to `1`): - Number of training epochs. - train_steps (`int`, defaults to `None`): - Total number of training steps to perform. 
If provided, overrides `train_epochs`. - rank (`int`, defaults to `128`): - The rank for LoRA matrices. - lora_alpha (`float`, defaults to `64`): - The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices. - target_modules (`List[str]`, defaults to `["to_k", "to_q", "to_v", "to_out.0"]`): - The target modules for LoRA. Make sure to modify this based on the model. + train_steps (`int`, defaults to `1000`): + Total number of training steps to perform. + max_data_samples (`int`, defaults to `2**64`): + Maximum number of data samples observed during training. If less than the number required by `train_steps`, + training will stop early. gradient_accumulation_steps (`int`, defaults to `1`): Number of gradients steps to accumulate before performing an optimizer step. gradient_checkpointing (`bool`, defaults to `False`): @@ -164,13 +201,11 @@ class Args: OPTIMIZER ARGUMENTS ------------------- optimizer (`str`, defaults to `adamw`): - The optimizer type to use. Choose between ['adam', 'adamw']. - use_8bit_bnb (`bool`, defaults to `False`): - Whether to use 8bit variant of the `optimizer` using `bitsandbytes`. + The optimizer type to use. Choose between the following: + - Torch optimizers: ["adam", "adamw"] + - Bitsandbytes optimizers: ["adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"] lr (`float`, defaults to `1e-4`): Initial learning rate (after the potential warmup period) to use. - scale_lr (`bool`, defaults to `False`): - Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size. lr_scheduler (`str`, defaults to `cosine_with_restarts`): The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup']. @@ -192,29 +227,21 @@ class Args: VALIDATION ARGUMENTS -------------------- - validation_prompts (`List[str]`, defaults to `None`): - List of prompts to use for validation. If not provided, a random prompt will be selected from the training - dataset. - validation_images (`List[str]`, defaults to `None`): - List of image paths to use for validation. - validation_videos (`List[str]`, defaults to `None`): - List of video paths to use for validation. - validation_heights (`List[int]`, defaults to `None`): - List of heights for the validation videos. - validation_widths (`List[int]`, defaults to `None`): - List of widths for the validation videos. - validation_num_frames (`List[int]`, defaults to `None`): - List of number of frames for the validation videos. - num_validation_videos_per_prompt (`int`, defaults to `1`): - Number of videos to use for validation per prompt. - validation_every_n_epochs (`int`, defaults to `None`): - Perform validation every `n` training epochs. - validation_every_n_steps (`int`, defaults to `None`): - Perform validation every `n` training steps. + validation_dataset_file (`str`, defaults to `None`): + Path to a CSV/JSON/PARQUET/ARROW file containing information for validation. The file must contain at least the + "caption" column. Other columns such as "image_path" and "video_path" can be provided too. If provided, "image_path" + will be used to load a PIL.Image.Image and set the "image" key in the sample dictionary. Similarly, "video_path" + will be used to load a List[PIL.Image.Image] and set the "video" key in the sample dictionary. 
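To make the two on-disk formats described in this docstring concrete, here is a minimal sketch (not part of the diff): it writes a hypothetical `dataset_config` JSON following the "datasets" schema documented above, plus a hypothetical validation dataset file with the required "caption" column. All paths, values, and extra columns are illustrative assumptions, not fixed requirements of the library.

```python
import json

# Hypothetical dataset config following the "datasets" schema documented above.
# JSON has no tuples, so resolution buckets are written as lists of lists.
dataset_config = {
    "datasets": [
        {
            "data_root": "/data/my-videos",  # alternatively, provide "dataset_file" instead
            "dataset_type": "video",
            "id_token": "SKS",
            "video_resolution_buckets": [[49, 480, 720]],
            "reshape_mode": "center_crop",
            "remove_common_llm_caption_prefixes": True,
        }
    ]
}
with open("my_dataset_config.json", "w") as f:
    json.dump(dataset_config, f, indent=2)

# Hypothetical validation dataset file: one record per validation sample,
# with "caption" required and any extra inference attributes optional.
validation_samples = [
    {
        "caption": "A cat walking on the grass, realistic style",
        "num_frames": 49,
        "height": 480,
        "width": 720,
        "num_inference_steps": 30,
    }
]
with open("my_validation_dataset.json", "w") as f:
    json.dump(validation_samples, f, indent=2)
```

A training run would then point at these files via the `--dataset_config` and `--validation_dataset_file` arguments added elsewhere in this diff.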
+ The validation dataset file may contain other attributes specific to inference/validation such as: + - "height" and "width" and "num_frames": Resolution + - "num_inference_steps": Number of inference steps + - "guidance_scale": Classifier-free Guidance Scale + - ... (any number of additional attributes can be provided. The ModelSpecification::validate method will be + invoked with the sample dictionary to validate the sample.) + validation_steps (`int`, defaults to `500`): + Number of training steps after which a validation step is performed. enable_model_cpu_offload (`bool`, defaults to `False`): Whether or not to offload different modeling components to CPU during validation. - validation_frame_rate (`int`, defaults to `25`): - Frame rate to use for the validation videos. This value is defaulted to 25, as used in LTX Video pipeline. MISCELLANEOUS ARGUMENTS ----------------------- @@ -230,20 +257,44 @@ class Args: The directory where the model checkpoints and logs will be stored. logging_dir (`str`, defaults to `logs`): The directory where the logs will be stored. + logging_steps (`int`, defaults to `1`): + Training logs will be tracked every `logging_steps` steps. allow_tf32 (`bool`, defaults to `False`): Whether or not to allow the use of TF32 matmul on compatible hardware. nccl_timeout (`int`, defaults to `1800`): Timeout for the NCCL communication. report_to (`str`, defaults to `wandb`): The name of the logger to use for logging training metrics. Choose between ['wandb']. + verbose (`int`, defaults to `1`): + Whether or not to print verbose logs. + - 0: Diffusers/Transformers warning logging on local main process only + - 1: Diffusers/Transformers info logging on local main process only + - 2: Diffusers/Transformers debug logging on local main process only + - 3: Diffusers/Transformers debug logging on all processes """ + # Parallel arguments + parallel_backend = ParallelBackendEnum.ACCELERATE + pp_degree: int = 1 + dp_degree: int = 1 + dp_shards: int = 1 + cp_degree: int = 1 + tp_degree: int = 1 + # Model arguments model_name: str = None pretrained_model_name_or_path: str = None revision: Optional[str] = None variant: Optional[str] = None cache_dir: Optional[str] = None + tokenizer_id: Optional[str] = None + tokenizer_2_id: Optional[str] = None + tokenizer_3_id: Optional[str] = None + text_encoder_id: Optional[str] = None + text_encoder_2_id: Optional[str] = None + text_encoder_3_id: Optional[str] = None + transformer_id: Optional[str] = None + vae_id: Optional[str] = None text_encoder_dtype: torch.dtype = torch.bfloat16 text_encoder_2_dtype: torch.dtype = torch.bfloat16 text_encoder_3_dtype: torch.dtype = torch.bfloat16 @@ -263,18 +314,11 @@ class Args: ] # Dataset arguments - data_root: str = None - dataset_file: Optional[str] = None - video_column: str = None - caption_column: str = None - id_token: Optional[str] = None - image_resolution_buckets: List[Tuple[int, int]] = None - video_resolution_buckets: List[Tuple[int, int, int]] = None - video_reshape_mode: Optional[str] = None - caption_dropout_p: float = 0.00 - caption_dropout_technique: str = "empty" - precompute_conditions: bool = False - remove_common_llm_caption_prefixes: bool = False + dataset_config: str = None + dataset_shuffle_buffer_size: int = 1 + precomputation_items: int = 512 + precomputation_dir: Optional[str] = None + precomputation_once: bool = False # Dataloader arguments dataloader_num_workers: int = 0 @@ -296,11 +340,8 @@ class Args: training_type: str = None seed: int = 42 batch_size: int = 1 - 
train_epochs: int = 1 - train_steps: int = None - rank: int = 128 - lora_alpha: float = 64 - target_modules: List[str] = ["to_k", "to_q", "to_v", "to_out.0"] + train_steps: int = 1000 + max_data_samples: int = 2**64 gradient_accumulation_steps: int = 1 gradient_checkpointing: bool = False checkpointing_steps: int = 500 @@ -311,9 +352,7 @@ class Args: # Optimizer arguments optimizer: str = "adamw" - use_8bit_bnb: bool = False lr: float = 1e-4 - scale_lr: bool = False lr_scheduler: str = "cosine_with_restarts" lr_warmup_steps: int = 0 lr_num_cycles: int = 1 @@ -326,17 +365,9 @@ class Args: max_grad_norm: float = 1.0 # Validation arguments - validation_prompts: List[str] = None - validation_images: List[str] = None - validation_videos: List[str] = None - validation_heights: List[int] = None - validation_widths: List[int] = None - validation_num_frames: List[int] = None - num_validation_videos_per_prompt: int = 1 - validation_every_n_epochs: Optional[int] = None - validation_every_n_steps: Optional[int] = None + validation_dataset_file: Optional[str] = None + validation_steps: int = 500 enable_model_cpu_offload: bool = False - validation_frame_rate: int = 25 # Miscellaneous arguments tracker_name: str = "finetrainers" @@ -345,664 +376,343 @@ class Args: hub_model_id: Optional[str] = None output_dir: str = None logging_dir: Optional[str] = "logs" + logging_steps: int = 1 allow_tf32: bool = False - nccl_timeout: int = 1800 # 30 minutes + init_timeout: int = 300 # 5 minutes + nccl_timeout: int = 600 # 10 minutes, considering that validation may be performed report_to: str = "wandb" + verbose: int = 1 def to_dict(self) -> Dict[str, Any]: - return { - "model_arguments": { - "model_name": self.model_name, - "pretrained_model_name_or_path": self.pretrained_model_name_or_path, - "revision": self.revision, - "variant": self.variant, - "cache_dir": self.cache_dir, - "text_encoder_dtype": self.text_encoder_dtype, - "text_encoder_2_dtype": self.text_encoder_2_dtype, - "text_encoder_3_dtype": self.text_encoder_3_dtype, - "transformer_dtype": self.transformer_dtype, - "vae_dtype": self.vae_dtype, - "layerwise_upcasting_modules": self.layerwise_upcasting_modules, - "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype, - "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern, - }, - "dataset_arguments": { - "data_root": self.data_root, - "dataset_file": self.dataset_file, - "video_column": self.video_column, - "caption_column": self.caption_column, - "id_token": self.id_token, - "image_resolution_buckets": self.image_resolution_buckets, - "video_resolution_buckets": self.video_resolution_buckets, - "video_reshape_mode": self.video_reshape_mode, - "caption_dropout_p": self.caption_dropout_p, - "caption_dropout_technique": self.caption_dropout_technique, - "precompute_conditions": self.precompute_conditions, - "remove_common_llm_caption_prefixes": self.remove_common_llm_caption_prefixes, - }, - "dataloader_arguments": { - "dataloader_num_workers": self.dataloader_num_workers, - "pin_memory": self.pin_memory, - }, - "diffusion_arguments": { - "flow_resolution_shifting": self.flow_resolution_shifting, - "flow_base_seq_len": self.flow_base_seq_len, - "flow_max_seq_len": self.flow_max_seq_len, - "flow_base_shift": self.flow_base_shift, - "flow_max_shift": self.flow_max_shift, - "flow_shift": self.flow_shift, - "flow_weighting_scheme": self.flow_weighting_scheme, - "flow_logit_mean": self.flow_logit_mean, - "flow_logit_std": self.flow_logit_std, - 
"flow_mode_scale": self.flow_mode_scale, - }, - "training_arguments": { - "training_type": self.training_type, - "seed": self.seed, - "batch_size": self.batch_size, - "train_epochs": self.train_epochs, - "train_steps": self.train_steps, - "rank": self.rank, - "lora_alpha": self.lora_alpha, - "target_modules": self.target_modules, - "gradient_accumulation_steps": self.gradient_accumulation_steps, - "gradient_checkpointing": self.gradient_checkpointing, - "checkpointing_steps": self.checkpointing_steps, - "checkpointing_limit": self.checkpointing_limit, - "resume_from_checkpoint": self.resume_from_checkpoint, - "enable_slicing": self.enable_slicing, - "enable_tiling": self.enable_tiling, - }, - "optimizer_arguments": { - "optimizer": self.optimizer, - "use_8bit_bnb": self.use_8bit_bnb, - "lr": self.lr, - "scale_lr": self.scale_lr, - "lr_scheduler": self.lr_scheduler, - "lr_warmup_steps": self.lr_warmup_steps, - "lr_num_cycles": self.lr_num_cycles, - "lr_power": self.lr_power, - "beta1": self.beta1, - "beta2": self.beta2, - "beta3": self.beta3, - "weight_decay": self.weight_decay, - "epsilon": self.epsilon, - "max_grad_norm": self.max_grad_norm, - }, - "validation_arguments": { - "validation_prompts": self.validation_prompts, - "validation_images": self.validation_images, - "validation_videos": self.validation_videos, - "num_validation_videos_per_prompt": self.num_validation_videos_per_prompt, - "validation_every_n_epochs": self.validation_every_n_epochs, - "validation_every_n_steps": self.validation_every_n_steps, - "enable_model_cpu_offload": self.enable_model_cpu_offload, - "validation_frame_rate": self.validation_frame_rate, - }, - "miscellaneous_arguments": { - "tracker_name": self.tracker_name, - "push_to_hub": self.push_to_hub, - "hub_token": self.hub_token, - "hub_model_id": self.hub_model_id, - "output_dir": self.output_dir, - "logging_dir": self.logging_dir, - "allow_tf32": self.allow_tf32, - "nccl_timeout": self.nccl_timeout, - "report_to": self.report_to, - }, + parallel_arguments = { + "pp_degree": self.pp_degree, + "dp_degree": self.dp_degree, + "dp_shards": self.dp_shards, + "cp_degree": self.cp_degree, + "tp_degree": self.tp_degree, } + model_arguments = { + "model_name": self.model_name, + "pretrained_model_name_or_path": self.pretrained_model_name_or_path, + "revision": self.revision, + "variant": self.variant, + "cache_dir": self.cache_dir, + "tokenizer_id": self.tokenizer_id, + "tokenizer_2_id": self.tokenizer_2_id, + "tokenizer_3_id": self.tokenizer_3_id, + "text_encoder_id": self.text_encoder_id, + "text_encoder_2_id": self.text_encoder_2_id, + "text_encoder_3_id": self.text_encoder_3_id, + "transformer_id": self.transformer_id, + "vae_id": self.vae_id, + "text_encoder_dtype": self.text_encoder_dtype, + "text_encoder_2_dtype": self.text_encoder_2_dtype, + "text_encoder_3_dtype": self.text_encoder_3_dtype, + "transformer_dtype": self.transformer_dtype, + "vae_dtype": self.vae_dtype, + "layerwise_upcasting_modules": self.layerwise_upcasting_modules, + "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype, + "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern, + } + model_arguments = get_non_null_items(model_arguments) + + dataset_arguments = { + "dataset_config": self.dataset_config, + "dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size, + "precomputation_items": self.precomputation_items, + "precomputation_dir": self.precomputation_dir, + "precomputation_once": self.precomputation_once, + } + 
dataset_arguments = get_non_null_items(dataset_arguments) -# TODO(aryan): handle more informative messages -_IS_ARGUMENTS_REQUIRED = "--list_models" not in sys.argv - - -def parse_arguments() -> Args: - parser = argparse.ArgumentParser() + dataloader_arguments = { + "dataloader_num_workers": self.dataloader_num_workers, + "pin_memory": self.pin_memory, + } - if _IS_ARGUMENTS_REQUIRED: - _add_model_arguments(parser) - _add_dataset_arguments(parser) - _add_dataloader_arguments(parser) - _add_diffusion_arguments(parser) - _add_training_arguments(parser) - _add_optimizer_arguments(parser) - _add_validation_arguments(parser) - _add_miscellaneous_arguments(parser) + diffusion_arguments = { + "flow_resolution_shifting": self.flow_resolution_shifting, + "flow_base_seq_len": self.flow_base_seq_len, + "flow_max_seq_len": self.flow_max_seq_len, + "flow_base_shift": self.flow_base_shift, + "flow_max_shift": self.flow_max_shift, + "flow_shift": self.flow_shift, + "flow_weighting_scheme": self.flow_weighting_scheme, + "flow_logit_mean": self.flow_logit_mean, + "flow_logit_std": self.flow_logit_std, + "flow_mode_scale": self.flow_mode_scale, + } - args = parser.parse_args() - return _map_to_args_type(args) - else: - _add_helper_arguments(parser) + training_arguments = { + "training_type": self.training_type, + "seed": self.seed, + "batch_size": self.batch_size, + "train_steps": self.train_steps, + "max_data_samples": self.max_data_samples, + "gradient_accumulation_steps": self.gradient_accumulation_steps, + "gradient_checkpointing": self.gradient_checkpointing, + "checkpointing_steps": self.checkpointing_steps, + "checkpointing_limit": self.checkpointing_limit, + "resume_from_checkpoint": self.resume_from_checkpoint, + "enable_slicing": self.enable_slicing, + "enable_tiling": self.enable_tiling, + } + training_arguments = get_non_null_items(training_arguments) + + optimizer_arguments = { + "optimizer": self.optimizer, + "lr": self.lr, + "lr_scheduler": self.lr_scheduler, + "lr_warmup_steps": self.lr_warmup_steps, + "lr_num_cycles": self.lr_num_cycles, + "lr_power": self.lr_power, + "beta1": self.beta1, + "beta2": self.beta2, + "beta3": self.beta3, + "weight_decay": self.weight_decay, + "epsilon": self.epsilon, + "max_grad_norm": self.max_grad_norm, + } + optimizer_arguments = get_non_null_items(optimizer_arguments) - args = parser.parse_args() - _display_helper_messages(args) - sys.exit(0) + validation_arguments = { + "validation_dataset_file": self.validation_dataset_file, + "validation_steps": self.validation_steps, + "enable_model_cpu_offload": self.enable_model_cpu_offload, + } + validation_arguments = get_non_null_items(validation_arguments) + + miscellaneous_arguments = { + "tracker_name": self.tracker_name, + "push_to_hub": self.push_to_hub, + "hub_token": self.hub_token, + "hub_model_id": self.hub_model_id, + "output_dir": self.output_dir, + "logging_dir": self.logging_dir, + "logging_steps": self.logging_steps, + "allow_tf32": self.allow_tf32, + "init_timeout": self.init_timeout, + "nccl_timeout": self.nccl_timeout, + "report_to": self.report_to, + "verbose": self.verbose, + } + miscellaneous_arguments = get_non_null_items(miscellaneous_arguments) + return { + "parallel_arguments": parallel_arguments, + "model_arguments": model_arguments, + "dataset_arguments": dataset_arguments, + "dataloader_arguments": dataloader_arguments, + "diffusion_arguments": diffusion_arguments, + "training_arguments": training_arguments, + "optimizer_arguments": optimizer_arguments, + "validation_arguments": 
validation_arguments, + "miscellaneous_arguments": miscellaneous_arguments, + } -def validate_args(args: Args): - _validated_model_args(args) - _validate_training_args(args) + def extend_args( + self, + add_fn: Callable[[argparse.ArgumentParser], None], + map_fn: Callable[["BaseArgs"], None], + validate_fn: Callable[["BaseArgs"], None], + ) -> None: + if not hasattr(self, "_extended_add_arguments"): + self._extended_add_arguments = [] + self._extended_add_arguments.append((add_fn, validate_fn, map_fn)) + + def parse_args(self): + _LIST_MODELS = "--list_models" + + parser = argparse.ArgumentParser() + + special_args = [_LIST_MODELS] + if any(arg in sys.argv for arg in special_args): + _add_helper_arguments(parser) + args = parser.parse_args() + _display_helper_messages(args) + sys.exit(0) + else: + _add_args(parser) + for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): + add_fn, _, _ = extended_add_arg_fns + add_fn(parser) + + args, remaining_args = parser.parse_known_args() + logger.debug(f"Remaining unparsed arguments: {remaining_args}") + + mapped_args = _map_to_args_type(args) + for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): + _, _, map_fn = extended_add_arg_fns + map_fn(args, mapped_args) + + _validate_args(mapped_args) + for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []): + _, validate_fn, _ = extended_add_arg_fns + validate_fn(mapped_args) + + return mapped_args + + +def _add_args(parser: argparse.ArgumentParser) -> None: + _add_parallel_arguments(parser) + _add_model_arguments(parser) + _add_dataset_arguments(parser) + _add_dataloader_arguments(parser) + _add_diffusion_arguments(parser) + _add_training_arguments(parser) + _add_optimizer_arguments(parser) + _add_validation_arguments(parser) + _add_miscellaneous_arguments(parser) + + +def _validate_args(args: BaseArgs): + _validate_model_args(args) + _validate_dataset_args(args) _validate_validation_args(args) -def _add_model_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument( - "--model_name", - type=str, - required=True, - choices=list(SUPPORTED_MODEL_CONFIGS.keys()), - help="Name of model to train.", - ) - parser.add_argument( - "--pretrained_model_name_or_path", - type=str, - required=True, - help="Path to pretrained model or model identifier from huggingface.co/models.", - ) +def _add_parallel_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument( - "--revision", + "--parallel_backend", type=str, - default=None, - required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", + default=ParallelBackendEnum.ACCELERATE, + choices=[ParallelBackendEnum.ACCELERATE, ParallelBackendEnum.PTD], ) + parser.add_argument("--pp_degree", type=int, default=1) + parser.add_argument("--dp_degree", type=int, default=1) + parser.add_argument("--dp_shards", type=int, default=1) + parser.add_argument("--cp_degree", type=int, default=1) + parser.add_argument("--tp_degree", type=int, default=1) + + +def _add_model_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument( - "--variant", - type=str, - default=None, - help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", - ) - parser.add_argument( - "--cache_dir", - type=str, - default=None, - help="The directory where the downloaded models and datasets will be stored.", - ) - parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.") - parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.") - parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.") - parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.") - parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.") - parser.add_argument( - "--layerwise_upcasting_modules", - type=str, - default=[], - nargs="+", - choices=["transformer"], - help="Modules that should have fp8 storage weights but higher precision computation.", - ) + "--model_name", type=str, required=True, choices=[x.value for x in ModelType.__members__.values()] + ) + parser.add_argument("--pretrained_model_name_or_path", type=str, required=True) + parser.add_argument("--revision", type=str, default=None, required=False) + parser.add_argument("--variant", type=str, default=None) + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--tokenizer_id", type=str, default=None) + parser.add_argument("--tokenizer_2_id", type=str, default=None) + parser.add_argument("--tokenizer_3_id", type=str, default=None) + parser.add_argument("--text_encoder_id", type=str, default=None) + parser.add_argument("--text_encoder_2_id", type=str, default=None) + parser.add_argument("--text_encoder_3_id", type=str, default=None) + parser.add_argument("--transformer_id", type=str, default=None) + parser.add_argument("--vae_id", type=str, default=None) + parser.add_argument("--text_encoder_dtype", type=str, default="bf16") + parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16") + parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16") + parser.add_argument("--transformer_dtype", type=str, default="bf16") + parser.add_argument("--vae_dtype", type=str, default="bf16") + parser.add_argument("--layerwise_upcasting_modules", type=str, default=[], nargs="+", choices=["transformer"]) parser.add_argument( "--layerwise_upcasting_storage_dtype", type=str, default="float8_e4m3fn", choices=["float8_e4m3fn", "float8_e5m2"], - help="Data type for the layerwise upcasting storage.", ) parser.add_argument( "--layerwise_upcasting_skip_modules_pattern", type=str, default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"], nargs="+", - help="Modules to skip for layerwise upcasting.", ) def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None: - def parse_resolution_bucket(resolution_bucket: str) -> Tuple[int, ...]: - return tuple(map(int, resolution_bucket.split("x"))) - - def parse_image_resolution_bucket(resolution_bucket: str) -> Tuple[int, int]: - resolution_bucket = parse_resolution_bucket(resolution_bucket) - assert ( - len(resolution_bucket) == 2 - ), f"Expected 2D resolution bucket, got {len(resolution_bucket)}D resolution bucket" - return resolution_bucket - - def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int]: - resolution_bucket = parse_resolution_bucket(resolution_bucket) - assert ( - len(resolution_bucket) == 3 - ), f"Expected 3D resolution bucket, got {len(resolution_bucket)}D resolution bucket" - return 
resolution_bucket - - parser.add_argument( - "--data_root", - type=str, - required=True, - help=("A folder containing the training data."), - ) - parser.add_argument( - "--dataset_file", - type=str, - default=None, - help=("Path to a CSV file if loading prompts/video paths using this format."), - ) - parser.add_argument( - "--video_column", - type=str, - default="video", - help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.", - ) - parser.add_argument( - "--caption_column", - type=str, - default="text", - help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.", - ) - parser.add_argument( - "--id_token", - type=str, - default=None, - help="Identifier token appended to the start of each prompt if provided.", - ) - parser.add_argument( - "--image_resolution_buckets", - type=parse_image_resolution_bucket, - default=None, - nargs="+", - help="Resolution buckets for images.", - ) - parser.add_argument( - "--video_resolution_buckets", - type=parse_video_resolution_bucket, - default=None, - nargs="+", - help="Resolution buckets for videos.", - ) - parser.add_argument( - "--video_reshape_mode", - type=str, - default=None, - help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']", - ) - parser.add_argument( - "--caption_dropout_p", - type=float, - default=0.00, - help="Probability of dropout for the caption tokens.", - ) - parser.add_argument( - "--caption_dropout_technique", - type=str, - default="empty", - choices=["empty", "zero"], - help="Technique to use for caption dropout.", - ) - parser.add_argument( - "--precompute_conditions", - action="store_true", - help="Whether or not to precompute the conditionings for the model.", - ) - parser.add_argument( - "--remove_common_llm_caption_prefixes", - action="store_true", - help="Whether or not to remove common LLM caption prefixes.", - ) + parser.add_argument("--dataset_config", type=str, required=True) + parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1) + parser.add_argument("--precomputation_items", type=int, default=512) + parser.add_argument("--precomputation_dir", type=str, default=None) + parser.add_argument("--precomputation_once", action="store_true") def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=0, - help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", - ) - parser.add_argument( - "--pin_memory", - action="store_true", - help="Whether or not to use the pinned memory setting in pytorch dataloader.", - ) + parser.add_argument("--dataloader_num_workers", type=int, default=0) + parser.add_argument("--pin_memory", action="store_true") def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument( - "--flow_resolution_shifting", - action="store_true", - help="Resolution-dependent shifting of timestep schedules.", - ) - parser.add_argument( - "--flow_base_seq_len", - type=int, - default=256, - help="Base image/video sequence length for the diffusion model.", - ) - parser.add_argument( - "--flow_max_seq_len", - type=int, - default=4096, - help="Maximum image/video sequence length for the diffusion model.", - ) - parser.add_argument( - "--flow_base_shift", - type=float, - default=0.5, - help="Base shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)", - ) - parser.add_argument( - "--flow_max_shift", - type=float, - default=1.15, - help="Maximum shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)", - ) - parser.add_argument( - "--flow_shift", - type=float, - default=1.0, - help="Shift value to use for the flow matching timestep schedule.", - ) + parser.add_argument("--flow_resolution_shifting", action="store_true") + parser.add_argument("--flow_base_seq_len", type=int, default=256) + parser.add_argument("--flow_max_seq_len", type=int, default=4096) + parser.add_argument("--flow_base_shift", type=float, default=0.5) + parser.add_argument("--flow_max_shift", type=float, default=1.15) + parser.add_argument("--flow_shift", type=float, default=1.0) parser.add_argument( "--flow_weighting_scheme", type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], - help='We default to the "none" weighting scheme for uniform sampling and uniform loss', - ) - parser.add_argument( - "--flow_logit_mean", - type=float, - default=0.0, - help="Mean to use when using the `'logit_normal'` weighting scheme.", - ) - parser.add_argument( - "--flow_logit_std", - type=float, - default=1.0, - help="Standard deviation to use when using the `'logit_normal'` weighting scheme.", - ) - parser.add_argument( - "--flow_mode_scale", - type=float, - default=1.29, - help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", ) + parser.add_argument("--flow_logit_mean", type=float, default=0.0) + parser.add_argument("--flow_logit_std", type=float, default=1.0) + parser.add_argument("--flow_mode_scale", type=float, default=1.29) def _add_training_arguments(parser: argparse.ArgumentParser) -> None: - # TODO: support full finetuning and other kinds - parser.add_argument( - "--training_type", - type=str, - choices=["lora", "full-finetune"], - required=True, - help="Type of training to perform. 
Choose between ['lora', 'full-finetune']", - ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") - parser.add_argument( - "--batch_size", - type=int, - default=1, - help="Batch size (per device) for the training dataloader.", - ) - parser.add_argument("--train_epochs", type=int, default=1, help="Number of training epochs.") - parser.add_argument( - "--train_steps", - type=int, - default=None, - help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.", - ) - parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.") - parser.add_argument( - "--lora_alpha", - type=int, - default=64, - help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.", - ) - parser.add_argument( - "--target_modules", - type=str, - default=["to_k", "to_q", "to_v", "to_out.0"], - nargs="+", - help="The target modules for LoRA.", - ) - parser.add_argument( - "--gradient_accumulation_steps", - type=int, - default=1, - help="Number of updates steps to accumulate before performing a backward/update pass.", - ) parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", - ) - parser.add_argument( - "--checkpointing_steps", - type=int, - default=500, - help=( - "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" - " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" - " training using `--resume_from_checkpoint`." - ), - ) - parser.add_argument( - "--checkpointing_limit", - type=int, - default=None, - help=("Max number of checkpoints to store."), - ) - parser.add_argument( - "--resume_from_checkpoint", - type=str, - default=None, - help=( - "Whether training should be resumed from a previous checkpoint. Use a path saved by" - ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
- ), - ) - parser.add_argument( - "--enable_slicing", - action="store_true", - help="Whether or not to use VAE slicing for saving memory.", - ) - parser.add_argument( - "--enable_tiling", - action="store_true", - help="Whether or not to use VAE tiling for saving memory.", + "--training_type", type=str, choices=[x.value for x in TrainingType.__members__.values()], required=True ) + parser.add_argument("--seed", type=int, default=None) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--train_steps", type=int, default=1000) + parser.add_argument("--max_data_samples", type=int, default=2**64) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--gradient_checkpointing", action="store_true") + parser.add_argument("--checkpointing_steps", type=int, default=500) + parser.add_argument("--checkpointing_limit", type=int, default=None) + parser.add_argument("--resume_from_checkpoint", type=str, default=None) + parser.add_argument("--enable_slicing", action="store_true") + parser.add_argument("--enable_tiling", action="store_true") def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument( - "--lr", - type=float, - default=1e-4, - help="Initial learning rate (after the potential warmup period) to use.", - ) - parser.add_argument( - "--scale_lr", - action="store_true", - help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", - ) - parser.add_argument( - "--lr_scheduler", - type=str, - default="constant", - help=( - 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' - ' "constant", "constant_with_warmup"]' - ), - ) - parser.add_argument( - "--lr_warmup_steps", - type=int, - default=500, - help="Number of steps for the warmup in the lr scheduler.", - ) - parser.add_argument( - "--lr_num_cycles", - type=int, - default=1, - help="Number of hard resets of the lr in cosine_with_restarts scheduler.", - ) - parser.add_argument( - "--lr_power", - type=float, - default=1.0, - help="Power factor of the polynomial scheduler.", - ) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--lr_scheduler", type=str, default="constant") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + parser.add_argument("--lr_num_cycles", type=int, default=1) + parser.add_argument("--lr_power", type=float, default=1.0) parser.add_argument( "--optimizer", type=lambda s: s.lower(), default="adam", - choices=["adam", "adamw"], - help=("The optimizer type to use."), + choices=["adam", "adamw", "adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"], ) - parser.add_argument( - "--use_8bit_bnb", - action="store_true", - help=("Whether to use 8bit variant of the `--optimizer` using `bitsandbytes`."), - ) - parser.add_argument( - "--beta1", - type=float, - default=0.9, - help="The beta1 parameter for the Adam and Prodigy optimizers.", - ) - parser.add_argument( - "--beta2", - type=float, - default=0.95, - help="The beta2 parameter for the Adam and Prodigy optimizers.", - ) - parser.add_argument( - "--beta3", - type=float, - default=None, - help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. 
If set to None, uses the value of square root of beta2.", - ) - parser.add_argument( - "--weight_decay", - type=float, - default=1e-04, - help="Weight decay to use for optimizer.", - ) - parser.add_argument( - "--epsilon", - type=float, - default=1e-8, - help="Epsilon value for the Adam optimizer and Prodigy optimizers.", - ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.95) + parser.add_argument("--beta3", type=float, default=None) + parser.add_argument("--weight_decay", type=float, default=1e-04) + parser.add_argument("--epsilon", type=float, default=1e-8) + parser.add_argument("--max_grad_norm", default=1.0, type=float) def _add_validation_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument( - "--validation_prompts", - type=str, - default=None, - help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.", - ) - parser.add_argument( - "--validation_images", - type=str, - default=None, - help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", - ) - parser.add_argument( - "--validation_videos", - type=str, - default=None, - help="One or more video path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.", - ) - parser.add_argument( - "--validation_separator", - type=str, - default=":::", - help="String that separates multiple validation prompts", - ) - parser.add_argument( - "--num_validation_videos", - type=int, - default=1, - help="Number of videos that should be generated during validation per `validation_prompt`.", - ) - parser.add_argument( - "--validation_epochs", - type=int, - default=None, - help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.", - ) - parser.add_argument( - "--validation_steps", - type=int, - default=None, - help="Run validation every X training steps. 
Validation consists of running the validation prompt `args.num_validation_videos` times.", - ) - parser.add_argument( - "--validation_frame_rate", - type=int, - default=25, - help="Frame rate to use for the validation videos.", - ) - parser.add_argument( - "--enable_model_cpu_offload", - action="store_true", - help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.", - ) + parser.add_argument("--validation_dataset_file", type=str, default=None) + parser.add_argument("--validation_steps", type=int, default=500) + parser.add_argument("--enable_model_cpu_offload", action="store_true") def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument("--tracker_name", type=str, default="finetrainers", help="Project tracker name") - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Whether or not to push the model to the Hub.", - ) - parser.add_argument( - "--hub_token", - type=str, - default=None, - help="The token to use to push to the Model Hub.", - ) - parser.add_argument( - "--hub_model_id", - type=str, - default=None, - help="The name of the repository to keep in sync with the local `output_dir`.", - ) - parser.add_argument( - "--output_dir", - type=str, - default="finetrainers-training", - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument( - "--logging_dir", - type=str, - default="logs", - help="Directory where logs are stored.", - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help=( - "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" - " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" - ), - ) - parser.add_argument( - "--nccl_timeout", - type=int, - default=600, - help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.", - ) - parser.add_argument( - "--report_to", - type=str, - default="none", - choices=["none", "wandb"], - help="The integration to report the results and logs to.", - ) + parser.add_argument("--tracker_name", type=str, default="finetrainers") + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--hub_token", type=str, default=None) + parser.add_argument("--hub_model_id", type=str, default=None) + parser.add_argument("--output_dir", type=str, default="finetrainers-training") + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--allow_tf32", action="store_true") + parser.add_argument("--init_timeout", type=int, default=300) + parser.add_argument("--nccl_timeout", type=int, default=600) + parser.add_argument("--report_to", type=str, default="none", choices=["none", "wandb"]) + parser.add_argument("--verbose", type=int, default=0, choices=[0, 1, 2, 3]) def _add_helper_arguments(parser: argparse.ArgumentParser) -> None: - parser.add_argument( - "--list_models", - action="store_true", - help="List all the supported models.", - ) + parser.add_argument("--list_models", action="store_true") _DTYPE_MAP = { @@ -1014,8 +724,16 @@ _DTYPE_MAP = { } -def _map_to_args_type(args: Dict[str, Any]) -> Args: - result_args = Args() +def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs: + result_args = BaseArgs() + + # Parallel arguments + result_args.parallel_backend = args.parallel_backend + 
result_args.pp_degree = args.pp_degree + result_args.dp_degree = args.dp_degree + result_args.dp_shards = args.dp_shards + result_args.cp_degree = args.cp_degree + result_args.tp_degree = args.tp_degree # Model arguments result_args.model_name = args.model_name @@ -1023,6 +741,14 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: result_args.revision = args.revision result_args.variant = args.variant result_args.cache_dir = args.cache_dir + result_args.tokenizer_id = args.tokenizer_id + result_args.tokenizer_2_id = args.tokenizer_2_id + result_args.tokenizer_3_id = args.tokenizer_3_id + result_args.text_encoder_id = args.text_encoder_id + result_args.text_encoder_2_id = args.text_encoder_2_id + result_args.text_encoder_3_id = args.text_encoder_3_id + result_args.transformer_id = args.transformer_id + result_args.vae_id = args.vae_id result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype] result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype] result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype] @@ -1033,21 +759,11 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern # Dataset arguments - if args.data_root is None and args.dataset_file is None: - raise ValueError("At least one of `data_root` or `dataset_file` should be provided.") - - result_args.data_root = args.data_root - result_args.dataset_file = args.dataset_file - result_args.video_column = args.video_column - result_args.caption_column = args.caption_column - result_args.id_token = args.id_token - result_args.image_resolution_buckets = args.image_resolution_buckets or DEFAULT_IMAGE_RESOLUTION_BUCKETS - result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS - result_args.video_reshape_mode = args.video_reshape_mode - result_args.caption_dropout_p = args.caption_dropout_p - result_args.caption_dropout_technique = args.caption_dropout_technique - result_args.precompute_conditions = args.precompute_conditions - result_args.remove_common_llm_caption_prefixes = args.remove_common_llm_caption_prefixes + result_args.dataset_config = args.dataset_config + result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size + result_args.precomputation_items = args.precomputation_items + result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed") + result_args.precomputation_once = args.precomputation_once # Dataloader arguments result_args.dataloader_num_workers = args.dataloader_num_workers @@ -1069,11 +785,8 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: result_args.training_type = args.training_type result_args.seed = args.seed result_args.batch_size = args.batch_size - result_args.train_epochs = args.train_epochs result_args.train_steps = args.train_steps - result_args.rank = args.rank - result_args.lora_alpha = args.lora_alpha - result_args.target_modules = args.target_modules + result_args.max_data_samples = args.max_data_samples result_args.gradient_accumulation_steps = args.gradient_accumulation_steps result_args.gradient_checkpointing = args.gradient_checkpointing result_args.checkpointing_steps = args.checkpointing_steps @@ -1084,9 +797,7 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: # Optimizer arguments result_args.optimizer = args.optimizer or "adamw" - result_args.use_8bit_bnb = args.use_8bit_bnb result_args.lr = args.lr or 1e-4 - result_args.scale_lr = 
args.scale_lr result_args.lr_scheduler = args.lr_scheduler result_args.lr_warmup_steps = args.lr_warmup_steps result_args.lr_num_cycles = args.lr_num_cycles @@ -1099,42 +810,9 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: result_args.max_grad_norm = args.max_grad_norm # Validation arguments - validation_prompts = args.validation_prompts.split(args.validation_separator) if args.validation_prompts else [] - validation_images = args.validation_images.split(args.validation_separator) if args.validation_images else None - validation_videos = args.validation_videos.split(args.validation_separator) if args.validation_videos else None - stripped_validation_prompts = [] - validation_heights = [] - validation_widths = [] - validation_num_frames = [] - for prompt in validation_prompts: - prompt: str - prompt = prompt.strip() - actual_prompt, separator, resolution = prompt.rpartition("@@@") - stripped_validation_prompts.append(actual_prompt) - num_frames, height, width = None, None, None - if len(resolution) > 0: - num_frames, height, width = map(int, resolution.split("x")) - validation_num_frames.append(num_frames) - validation_heights.append(height) - validation_widths.append(width) - - if validation_images is None: - validation_images = [None] * len(validation_prompts) - if validation_videos is None: - validation_videos = [None] * len(validation_prompts) - - result_args.validation_prompts = stripped_validation_prompts - result_args.validation_heights = validation_heights - result_args.validation_widths = validation_widths - result_args.validation_num_frames = validation_num_frames - result_args.validation_images = validation_images - result_args.validation_videos = validation_videos - - result_args.num_validation_videos_per_prompt = args.num_validation_videos - result_args.validation_every_n_epochs = args.validation_epochs - result_args.validation_every_n_steps = args.validation_steps + result_args.validation_dataset_file = args.validation_dataset_file + result_args.validation_steps = args.validation_steps result_args.enable_model_cpu_offload = args.enable_model_cpu_offload - result_args.validation_frame_rate = args.validation_frame_rate # Miscellaneous arguments result_args.tracker_name = args.tracker_name @@ -1143,45 +821,36 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args: result_args.hub_model_id = args.hub_model_id result_args.output_dir = args.output_dir result_args.logging_dir = args.logging_dir + result_args.logging_steps = args.logging_steps result_args.allow_tf32 = args.allow_tf32 + result_args.init_timeout = args.init_timeout result_args.nccl_timeout = args.nccl_timeout result_args.report_to = args.report_to + result_args.verbose = args.verbose return result_args -def _validated_model_args(args: Args): +def _validate_model_args(args: BaseArgs): if args.training_type == "full-finetune": assert ( "transformer" not in args.layerwise_upcasting_modules ), "Layerwise upcasting is not supported for full-finetune training" -def _validate_training_args(args: Args): - if args.training_type == "lora": - assert args.rank is not None, "Rank is required for LoRA training" - assert args.lora_alpha is not None, "LoRA alpha is required for LoRA training" - assert ( - args.target_modules is not None and len(args.target_modules) > 0 - ), "Target modules are required for LoRA training" - - -def _validate_validation_args(args: Args): - assert args.validation_prompts is not None, "Validation prompts are required for validation" - if args.validation_images is not None: - assert 
len(args.validation_images) == len( - args.validation_prompts - ), "Validation images and prompts should be of same length" - if args.validation_videos is not None: - assert len(args.validation_videos) == len( - args.validation_prompts - ), "Validation videos and prompts should be of same length" - assert len(args.validation_prompts) == len( - args.validation_heights - ), "Validation prompts and heights should be of same length" - assert len(args.validation_prompts) == len( - args.validation_widths - ), "Validation prompts and widths should be of same length" +def _validate_dataset_args(args: BaseArgs): + dataset_config = pathlib.Path(args.dataset_config) + if not dataset_config.exists(): + raise ValueError(f"Dataset config file {args.dataset_config} does not exist.") + if args.dataset_shuffle_buffer_size < 1: + raise ValueError("Dataset shuffle buffer size must be greater than 0.") + if args.precomputation_items < 1: + raise ValueError("Precomputation items must be greater than 0.") + + +def _validate_validation_args(args: BaseArgs): + if args.dp_shards > 1 and args.enable_model_cpu_offload: + raise ValueError("Model CPU offload is not supported with FSDP at the moment.") def _display_helper_messages(args: argparse.Namespace): diff --git a/finetrainers/config.py b/finetrainers/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e0cda5e831d89f2bce9555bbea44ac87260fe67d --- /dev/null +++ b/finetrainers/config.py @@ -0,0 +1,52 @@ +from enum import Enum +from typing import Type + +from .models import ModelSpecification +from .models.cogvideox import CogVideoXModelSpecification +from .models.hunyuan_video import HunyuanVideoModelSpecification +from .models.ltx_video import LTXVideoModelSpecification +from .models.wan import WanModelSpecification + + +class ModelType(str, Enum): + COGVIDEOX = "cogvideox" + HUNYUAN_VIDEO = "hunyuan_video" + LTX_VIDEO = "ltx_video" + WAN = "wan" + + +class TrainingType(str, Enum): + LORA = "lora" + FULL_FINETUNE = "full-finetune" + + +SUPPORTED_MODEL_CONFIGS = { + ModelType.HUNYUAN_VIDEO: { + TrainingType.LORA: HunyuanVideoModelSpecification, + TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification, + }, + ModelType.LTX_VIDEO: { + TrainingType.LORA: LTXVideoModelSpecification, + TrainingType.FULL_FINETUNE: LTXVideoModelSpecification, + }, + ModelType.COGVIDEOX: { + TrainingType.LORA: CogVideoXModelSpecification, + TrainingType.FULL_FINETUNE: CogVideoXModelSpecification, + }, + ModelType.WAN: { + TrainingType.LORA: WanModelSpecification, + TrainingType.FULL_FINETUNE: WanModelSpecification, + }, +} + + +def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]: + if model_name not in SUPPORTED_MODEL_CONFIGS: + raise ValueError( + f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}" + ) + if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]: + raise ValueError( + f"Training type {training_type} not supported for model {model_name}. 
Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}" + ) + return SUPPORTED_MODEL_CONFIGS[model_name][training_type] diff --git a/finetrainers/constants.py b/finetrainers/constants.py index f6318f4cc43f96db9adf94b22ae92684f976f6fe..693495ca0e91617dce6c35583b2e1a18c9025708 100644 --- a/finetrainers/constants.py +++ b/finetrainers/constants.py @@ -78,3 +78,6 @@ COMMON_LLM_START_PHRASES = ( for continuation in _COMMON_CONTINUATION_WORDS ), ) + +SUPPORTED_IMAGE_FILE_EXTENSIONS = ("jpg", "jpeg", "png") +SUPPORTED_VIDEO_FILE_EXTENSIONS = ("mp4", "mov") diff --git a/finetrainers/data/__init__.py b/finetrainers/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7706fb999099fa324c2b5fe73f62ad57e89e0796 --- /dev/null +++ b/finetrainers/data/__init__.py @@ -0,0 +1,19 @@ +from ._artifact import ImageArtifact, VideoArtifact +from .dataloader import DPDataLoader +from .dataset import ( + ImageCaptionFilePairDataset, + ImageFileCaptionFileListDataset, + ImageFolderDataset, + ImageWebDataset, + ValidationDataset, + VideoCaptionFilePairDataset, + VideoFileCaptionFileListDataset, + VideoFolderDataset, + VideoWebDataset, + combine_datasets, + initialize_dataset, + wrap_iterable_dataset_for_preprocessing, +) +from .precomputation import DistributedDataPreprocessor, PreprocessedDataIterable +from .sampler import ResolutionSampler +from .utils import find_files diff --git a/finetrainers/data/_artifact.py b/finetrainers/data/_artifact.py new file mode 100644 index 0000000000000000000000000000000000000000..400f25d143f5062d77ed6391ca9862654d295de7 --- /dev/null +++ b/finetrainers/data/_artifact.py @@ -0,0 +1,29 @@ +# ===== THIS FILE ONLY EXISTS FOR THE TIME BEING SINCE I DID NOT KNOW WHERE TO PUT IT ===== + +from dataclasses import dataclass +from typing import Any, List + +from PIL.Image import Image + + +@dataclass +class Artifact: + type: str + value: Any + file_extension: str + + +@dataclass +class ImageArtifact(Artifact): + value: Image + + def __init__(self, value: Image): + super().__init__(type="image", value=value, file_extension="png") + + +@dataclass +class VideoArtifact(Artifact): + value: List[Image] + + def __init__(self, value: List[Image]): + super().__init__(type="video", value=value, file_extension="mp4") diff --git a/finetrainers/data/dataloader.py b/finetrainers/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b0a4b1f6253bd943cf9fe7b9e31c06aa060b35 --- /dev/null +++ b/finetrainers/data/dataloader.py @@ -0,0 +1,40 @@ +import pickle +from typing import Any, Dict + +import torch.distributed.checkpoint.stateful +import torchdata.stateful_dataloader + +from ..logging import get_logger + + +logger = get_logger() + + +class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful): + def __init__( + self, + rank: int, + dataset: torch.utils.data.IterableDataset, + batch_size: int = 1, + num_workers: int = 0, + collate_fn=None, + ) -> None: + super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn) + + self._dp_rank = rank + self._rank_id = f"dp_rank_{rank}" + + def state_dict(self) -> Dict[str, Any]: + # Store state only for dp rank to avoid replicating the same state across other dimensions + return {self._rank_id: pickle.dumps(super().state_dict())} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + # State being empty is valid + if not state_dict: + return + + if self._rank_id not in 
state_dict:
+            logger.warning(f"DataLoader state has no entry for dp rank {self._dp_rank} (expected key {self._rank_id}); skipping load")
+            return
+
+        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
diff --git a/finetrainers/data/dataset.py b/finetrainers/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..672bc1131b95900abfa6bdfcd50c6b644dc45556
--- /dev/null
+++ b/finetrainers/data/dataset.py
@@ -0,0 +1,844 @@
+import pathlib
+import random
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import datasets
+import datasets.data_files
+import datasets.distributed
+import datasets.exceptions
+import huggingface_hub
+import huggingface_hub.errors
+import numpy as np
+import PIL.Image
+import torch
+import torch.distributed.checkpoint.stateful
+from diffusers.utils import load_image, load_video
+from huggingface_hub import list_repo_files, repo_exists, snapshot_download
+from tqdm.auto import tqdm
+
+from .. import constants
+from .. import functional as FF
+from ..logging import get_logger
+from . import utils
+
+
+import decord  # isort:skip
+
+decord.bridge.set_bridge("torch")
+
+logger = get_logger()
+
+
+MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
+COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
+COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
+COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
+
+
+class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+    def __init__(self, root: str, infinite: bool = False) -> None:
+        super().__init__()
+
+        self.root = pathlib.Path(root)
+        self.infinite = infinite
+
+        data = []
+        caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
+        for caption_file in caption_files:
+            data_file = self._find_data_file(caption_file)
+            if data_file:
+                data.append(
+                    {
+                        "caption": (self.root / caption_file).as_posix(),
+                        "image": (self.root / data_file).as_posix(),
+                    }
+                )
+
+        data = datasets.Dataset.from_list(data)
+        data = data.cast_column("image", datasets.Image(mode="RGB"))
+
+        self._data = data.to_iterable_dataset()
+        self._sample_index = 0
+        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+    def _get_data_iter(self):
+        if self._sample_index == 0:
+            return iter(self._data)
+        return iter(self._data.skip(self._sample_index))
+
+    def __iter__(self):
+        while True:
+            for sample in self._get_data_iter():
+                self._sample_index += 1
+                sample["caption"] = _read_caption_from_file(sample["caption"])
+                sample["image"] = _preprocess_image(sample["image"])
+                yield sample
+
+            if not self.infinite:
+                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+                break
+            else:
+                self._sample_index = 0
+
+    def load_state_dict(self, state_dict):
+        self._sample_index = state_dict["sample_index"]
+
+    def state_dict(self):
+        return {"sample_index": self._sample_index}
+
+    def _find_data_file(self, caption_file: str) -> str:
+        caption_file = pathlib.Path(caption_file)
+        data_file = None
+        found_data = 0
+
+        for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
+            image_filename = caption_file.with_suffix(f".{extension}")
+            if image_filename.exists():
+                found_data += 1
+                data_file = image_filename
+
+        if found_data == 0:
+            return False
+        elif found_data > 1:
+            raise ValueError(
+                f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
+                f"file per caption file. 
The following extensions are supported:\n" + f" - Images: {constants.SUPPORTED_IMAGE_FILE_EXTENSIONS}\n" + ) + + return data_file.as_posix() + + +class VideoCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = [] + caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0)) + for caption_file in caption_files: + data_file = self._find_data_file(caption_file) + if data_file: + data.append( + { + "caption": (self.root / caption_file).as_posix(), + "video": (self.root / data_file).as_posix(), + } + ) + + data = datasets.Dataset.from_list(data) + data = data.cast_column("video", datasets.Video()) + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + sample["caption"] = _read_caption_from_file(sample["caption"]) + sample["video"] = _preprocess_video(sample["video"]) + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + def _find_data_file(self, caption_file: str) -> str: + caption_file = pathlib.Path(caption_file) + data_file = None + found_data = 0 + + for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS: + video_filename = caption_file.with_suffix(f".{extension}") + if video_filename.exists(): + found_data += 1 + data_file = video_filename + + if found_data == 0: + return False + elif found_data > 1: + raise ValueError( + f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data " + f"file per caption file. The following extensions are supported:\n" + f" - Videos: {constants.SUPPORTED_VIDEO_FILE_EXTENSIONS}\n" + ) + + return data_file.as_posix() + + +class ImageFileCaptionFileListDataset( + torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful +): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"] + VALID_IMAGE_FILES = ["image.txt", "images.txt"] + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = [] + existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()] + existing_image_files = [file for file in VALID_IMAGE_FILES if (self.root / file).exists()] + + if len(existing_caption_files) == 0: + raise FileNotFoundError( + f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}" + ) + if len(existing_image_files) == 0: + raise FileNotFoundError( + f"No image file found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}" + ) + if len(existing_caption_files) > 1: + raise ValueError( + f"Multiple caption files found in {self.root}. 
Must have exactly one of {VALID_CAPTION_FILES}" + ) + if len(existing_image_files) > 1: + raise ValueError( + f"Multiple image files found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}" + ) + + caption_file = existing_caption_files[0] + image_file = existing_image_files[0] + + with open((self.root / caption_file).as_posix(), "r") as f: + captions = f.read().splitlines() + with open((self.root / image_file).as_posix(), "r") as f: + images = f.read().splitlines() + images = [(self.root / image).as_posix() for image in images] + + if len(captions) != len(images): + raise ValueError(f"Number of captions ({len(captions)}) must match number of images ({len(images)})") + + for caption, image in zip(captions, images): + data.append({"caption": caption, "image": image}) + + data = datasets.Dataset.from_list(data) + data = data.cast_column("image", datasets.Image(mode="RGB")) + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + sample["image"] = _preprocess_image(sample["image"]) + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class VideoFileCaptionFileListDataset( + torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful +): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"] + VALID_VIDEO_FILES = ["video.txt", "videos.txt"] + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = [] + existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()] + existing_video_files = [file for file in VALID_VIDEO_FILES if (self.root / file).exists()] + + if len(existing_caption_files) == 0: + raise FileNotFoundError( + f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}" + ) + if len(existing_video_files) == 0: + raise FileNotFoundError( + f"No video file found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}" + ) + if len(existing_caption_files) > 1: + raise ValueError( + f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}" + ) + if len(existing_video_files) > 1: + raise ValueError( + f"Multiple video files found in {self.root}. 
Must have exactly one of {VALID_VIDEO_FILES}" + ) + + caption_file = existing_caption_files[0] + video_file = existing_video_files[0] + + with open((self.root / caption_file).as_posix(), "r") as f: + captions = f.read().splitlines() + with open((self.root / video_file).as_posix(), "r") as f: + videos = f.read().splitlines() + videos = [(self.root / video).as_posix() for video in videos] + + if len(captions) != len(videos): + raise ValueError(f"Number of captions ({len(captions)}) must match number of videos ({len(videos)})") + + for caption, video in zip(captions, videos): + data.append({"caption": caption, "video": video}) + + data = datasets.Dataset.from_list(data) + data = data.cast_column("video", datasets.Video()) + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + sample["video"] = _preprocess_video(sample["video"]) + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class ImageFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = datasets.load_dataset("imagefolder", data_dir=self.root.as_posix(), split="train") + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + sample["image"] = _preprocess_image(sample["image"]) + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, root: str, infinite: bool = False) -> None: + super().__init__() + + self.root = pathlib.Path(root) + self.infinite = infinite + + data = datasets.load_dataset("videofolder", data_dir=self.root.as_posix(), split="train") + + self._data = data.to_iterable_dataset() + self._sample_index = 0 + self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + sample["video"] = _preprocess_video(sample["video"]) + yield sample + + if not self.infinite: + logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data") + 
break + else: + self._sample_index = 0 + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, dataset_name: str, infinite: bool = False) -> None: + super().__init__() + + self.dataset_name = dataset_name + self.infinite = infinite + + data = datasets.load_dataset(dataset_name, split="train", streaming=True) + data = data.rename_column("txt", "caption") + for column_name in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS: + if column_name in data.column_names: + data = data.cast_column(column_name, datasets.Image(mode="RGB")) + data = data.rename_column(column_name, "image") + + self._data = data + self._sample_index = 0 + self._precomputable_once = False + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + yield sample + + if not self.infinite: + logger.warning(f"Dataset {self.dataset_name} has run out of data") + break + else: + # Reset offset for the next iteration + self._sample_index = 0 + logger.warning(f"Dataset {self.dataset_name} is being re-looped") + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful): + def __init__(self, dataset_name: str, infinite: bool = False) -> None: + super().__init__() + + self.dataset_name = dataset_name + self.infinite = infinite + + data = datasets.load_dataset(dataset_name, split="train", streaming=True) + data = data.rename_column("txt", "caption") + for column_name in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS: + if column_name in data.column_names: + data = data.cast_column(column_name, datasets.Video()) + data = data.rename_column(column_name, "video") + + self._data = data + self._sample_index = 0 + self._precomputable_once = False + + def _get_data_iter(self): + if self._sample_index == 0: + return iter(self._data) + return iter(self._data.skip(self._sample_index)) + + def __iter__(self): + while True: + for sample in self._get_data_iter(): + self._sample_index += 1 + yield sample + + if not self.infinite: + logger.warning(f"Dataset {self.dataset_name} has run out of data") + break + else: + # Reset offset for the next iteration + self._sample_index = 0 + logger.warning(f"Dataset {self.dataset_name} is being re-looped") + + def load_state_dict(self, state_dict): + self._sample_index = state_dict["sample_index"] + + def state_dict(self): + return {"sample_index": self._sample_index} + + +class ValidationDataset(torch.utils.data.IterableDataset): + def __init__(self, filename: str): + super().__init__() + + self.filename = pathlib.Path(filename) + + if not self.filename.exists(): + raise FileNotFoundError(f"File {self.filename.as_posix()} does not exist") + + if self.filename.suffix == ".csv": + data = datasets.load_dataset("csv", data_files=self.filename.as_posix(), split="train") + elif self.filename.suffix == ".json": + data = datasets.load_dataset("json", data_files=self.filename.as_posix(), split="train", field="data") + elif self.filename.suffix == ".parquet": + data = datasets.load_dataset("parquet", 
data_files=self.filename.as_posix(), split="train") + elif self.filename.suffix == ".arrow": + data = datasets.load_dataset("arrow", data_files=self.filename.as_posix(), split="train") + else: + _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"] + raise ValueError( + f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}" + ) + + self._data = data.to_iterable_dataset() + + def __iter__(self): + for sample in self._data: + # For consistency reasons, we mandate that "caption" is always present in the validation dataset. + # However, since the model specifications use "prompt", we create an alias here. + sample["prompt"] = sample["caption"] + + # Load image or video if the path is provided + # TODO(aryan): need to handle custom columns here for control conditions + sample["image"] = None + sample["video"] = None + + if sample.get("image_path", None) is not None: + image_path = pathlib.Path(sample["image_path"]) + if not image_path.is_file(): + logger.warning(f"Image file {image_path.as_posix()} does not exist.") + else: + sample["image"] = load_image(sample["image_path"]) + + if sample.get("video_path", None) is not None: + video_path = pathlib.Path(sample["video_path"]) + if not video_path.is_file(): + logger.warning(f"Video file {video_path.as_posix()} does not exist.") + else: + sample["video"] = load_video(sample["video_path"]) + + sample = {k: v for k, v in sample.items() if v is not None} + yield sample + + +class IterableDatasetPreprocessingWrapper( + torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful +): + def __init__( + self, + dataset: torch.utils.data.IterableDataset, + dataset_type: str, + id_token: Optional[str] = None, + image_resolution_buckets: List[Tuple[int, int]] = None, + video_resolution_buckets: List[Tuple[int, int, int]] = None, + reshape_mode: str = "bicubic", + remove_common_llm_caption_prefixes: bool = False, + **kwargs, + ): + super().__init__() + + self.dataset = dataset + self.dataset_type = dataset_type + self.id_token = id_token + self.image_resolution_buckets = image_resolution_buckets + self.video_resolution_buckets = video_resolution_buckets + self.reshape_mode = reshape_mode + self.remove_common_llm_caption_prefixes = remove_common_llm_caption_prefixes + + logger.info( + f"Initializing IterableDatasetPreprocessingWrapper for the dataset with the following configuration:\n" + f" - Dataset Type: {dataset_type}\n" + f" - ID Token: {id_token}\n" + f" - Image Resolution Buckets: {image_resolution_buckets}\n" + f" - Video Resolution Buckets: {video_resolution_buckets}\n" + f" - Reshape Mode: {reshape_mode}\n" + f" - Remove Common LLM Caption Prefixes: {remove_common_llm_caption_prefixes}\n" + ) + + def __iter__(self): + logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset") + for sample in iter(self.dataset): + if self.dataset_type == "image": + if self.image_resolution_buckets: + sample["image"] = FF.resize_to_nearest_bucket_image( + sample["image"], self.image_resolution_buckets, self.reshape_mode + ) + elif self.dataset_type == "video": + if self.video_resolution_buckets: + sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video( + sample["video"], self.video_resolution_buckets, self.reshape_mode + ) + if _first_frame_only: + msg = ( + "The number of frames in the video is less than the minimum bucket size " + "specified. The first frame is being used as a single frame video. 
This "
+                            "message is logged at the first occurrence and for every 128th occurrence "
+                            "after that."
+                        )
+                        logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE", msg, frequency=128)
+                        sample["video"] = sample["video"][0]
+
+            if self.remove_common_llm_caption_prefixes:
+                sample["caption"] = FF.remove_prefix(sample["caption"], constants.COMMON_LLM_START_PHRASES)
+
+            if self.id_token is not None:
+                sample["caption"] = f"{self.id_token} {sample['caption']}"
+
+            yield sample
+
+    def load_state_dict(self, state_dict):
+        self.dataset.load_state_dict(state_dict["dataset"])
+
+    def state_dict(self):
+        return {"dataset": self.dataset.state_dict()}
+
+
+class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+    def __init__(self, datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False):
+        super().__init__()
+
+        self.datasets = datasets
+        self.buffer_size = buffer_size
+        self.shuffle = shuffle
+
+        logger.info(
+            f"Initializing IterableCombinedDataset with the following configuration:\n"
+            f"  - Number of Datasets: {len(datasets)}\n"
+            f"  - Buffer Size: {buffer_size}\n"
+            f"  - Shuffle: {shuffle}\n"
+        )
+
+    def __iter__(self):
+        logger.info(f"Starting IterableCombinedDataset with {len(self.datasets)} datasets")
+        iterators = [iter(dataset) for dataset in self.datasets]
+        buffer = []
+        per_iter = max(1, self.buffer_size // len(iterators))
+
+        for index, it in enumerate(iterators):
+            for _ in tqdm(range(per_iter), desc=f"Filling buffer from data iterator {index}"):
+                try:
+                    buffer.append((it, next(it)))
+                except StopIteration:
+                    continue
+
+        while len(buffer) > 0:
+            idx = 0
+            if self.shuffle:
+                idx = random.randint(0, len(buffer) - 1)
+            current_it, sample = buffer.pop(idx)
+            yield sample
+            try:
+                buffer.append((current_it, next(current_it)))
+            except StopIteration:
+                pass
+
+    def load_state_dict(self, state_dict):
+        for dataset, dataset_state_dict in zip(self.datasets, state_dict["datasets"]):
+            dataset.load_state_dict(dataset_state_dict)
+
+    def state_dict(self):
+        return {"datasets": [dataset.state_dict() for dataset in self.datasets]}
+
+
+# TODO(aryan): maybe write a test for this
+def initialize_dataset(
+    dataset_name_or_root: str, dataset_type: str = "video", streaming: bool = True, infinite: bool = False
+) -> torch.utils.data.IterableDataset:
+    assert dataset_type in ["image", "video"]
+
+    try:
+        does_repo_exist_on_hub = repo_exists(dataset_name_or_root, repo_type="dataset")
+    except huggingface_hub.errors.HFValidationError:
+        does_repo_exist_on_hub = False
+
+    if does_repo_exist_on_hub:
+        return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite)
+    else:
+        return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)
+
+
+def combine_datasets(
+    datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False
+) -> torch.utils.data.IterableDataset:
+    return IterableCombinedDataset(datasets=datasets, buffer_size=buffer_size, shuffle=shuffle)
+
+
+def wrap_iterable_dataset_for_preprocessing(
+    dataset: torch.utils.data.IterableDataset, dataset_type: str, config: Dict[str, Any]
+) -> torch.utils.data.IterableDataset:
+    return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config)
+
+
+def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infinite: bool = False):
+    root = pathlib.Path(dataset_name_or_root)
+    supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"]
+    metadata_files = [root / 
metadata_file for metadata_file in supported_metadata_files] + metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()] + + if len(metadata_files) > 1: + raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.") + + if len(metadata_files) == 1: + if dataset_type == "image": + dataset = ImageFolderDataset(root.as_posix(), infinite=infinite) + else: + dataset = VideoFolderDataset(root.as_posix(), infinite=infinite) + return dataset + + if _has_data_caption_file_pairs(root, remote=False): + if dataset_type == "image": + dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite) + else: + dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite) + elif _has_data_file_caption_file_lists(root, remote=False): + if dataset_type == "image": + dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite) + else: + dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite) + else: + raise ValueError( + f"Could not find any supported dataset structure in the directory {root}. Please open an issue at " + f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will " + f"help you set it up." + ) + + return dataset + + +def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool = False): + repo_file_list = list_repo_files(dataset_name, repo_type="dataset") + if _has_data_caption_file_pairs(repo_file_list, remote=True): + return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite) + elif _has_data_file_caption_file_lists(repo_file_list, remote=True): + return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite) + else: + return _initialize_webdataset(dataset_name, dataset_type, infinite) + + +def _initialize_data_caption_file_dataset_from_hub( + dataset_name: str, dataset_type: str, infinite: bool = False +) -> torch.utils.data.IterableDataset: + logger.info(f"Downloading dataset {dataset_name} from the HF Hub") + dataset_root = snapshot_download(dataset_name, repo_type="dataset") + if dataset_type == "image": + return ImageCaptionFilePairDataset(dataset_root, infinite=infinite) + else: + return VideoCaptionFilePairDataset(dataset_root, infinite=infinite) + + +def _initialize_data_file_caption_file_dataset_from_hub( + dataset_name: str, dataset_type: str, infinite: bool = False +) -> torch.utils.data.IterableDataset: + logger.info(f"Downloading dataset {dataset_name} from the HF Hub") + dataset_root = snapshot_download(dataset_name, repo_type="dataset") + if dataset_type == "image": + return ImageFileCaptionFileListDataset(dataset_root, infinite=infinite) + else: + return VideoFileCaptionFileListDataset(dataset_root, infinite=infinite) + + +def _initialize_webdataset( + dataset_name: str, dataset_type: str, infinite: bool = False +) -> torch.utils.data.IterableDataset: + logger.info(f"Streaming webdataset {dataset_name} from the HF Hub") + if dataset_type == "image": + return ImageWebDataset(dataset_name, infinite=infinite) + else: + return VideoWebDataset(dataset_name, infinite=infinite) + + +def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool: + # TODO(aryan): this logic can be improved + if not remote: + caption_files = utils.find_files(root.as_posix(), "*.txt", depth=0) + for caption_file in caption_files: + caption_file = pathlib.Path(caption_file) + for extension in 
[*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]: + data_filename = caption_file.with_suffix(f".{extension}") + if data_filename.exists(): + return True + return False + else: + caption_files = [file for file in root if file.endswith(".txt")] + for caption_file in caption_files: + caption_file = pathlib.Path(caption_file) + for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]: + data_filename = caption_file.with_suffix(f".{extension}").name + if data_filename in root: + return True + return False + + +def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool: + # TODO(aryan): this logic can be improved + if not remote: + file_list = {x.name for x in root.iterdir()} + has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES) + has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES) + has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES) + return has_caption_files and (has_video_files or has_image_files) + else: + has_caption_files = any(file in root for file in COMMON_CAPTION_FILES) + has_video_files = any(file in root for file in COMMON_VIDEO_FILES) + has_image_files = any(file in root for file in COMMON_IMAGE_FILES) + return has_caption_files and (has_video_files or has_image_files) + + +def _read_caption_from_file(filename: str) -> str: + with open(filename, "r") as f: + return f.read().strip() + + +def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor: + image = image.convert("RGB") + image = np.array(image).astype(np.float32) + image = torch.from_numpy(image) + image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0 + return image + + +def _preprocess_video(video: decord.VideoReader) -> torch.Tensor: + video = video.get_batch(list(range(len(video)))) + video = video.permute(0, 3, 1, 2).contiguous() + video = video.float() / 127.5 - 1.0 + return video diff --git a/finetrainers/data/precomputation.py b/finetrainers/data/precomputation.py new file mode 100644 index 0000000000000000000000000000000000000000..9b1f020d9d3e715dcaef769604ab819f1ccfc5a4 --- /dev/null +++ b/finetrainers/data/precomputation.py @@ -0,0 +1,163 @@ +import pathlib +from typing import Any, Callable, Dict, Iterable, Optional + +import torch +from tqdm.auto import tqdm + +from .. import utils + + +class DistributedDataPreprocessor: + def __init__( + self, + rank: int, + num_items: int, + processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]], + save_dir: str, + ) -> None: + self._rank = rank + self._num_items = num_items + self._processor_fn = processor_fn + self._save_dir = pathlib.Path(save_dir) + + self._cached_samples = [] + self._preprocessed_iterator: "PreprocessedDataIterable" = None + + self._save_dir.mkdir(parents=True, exist_ok=True) + + subdirectories = [f for f in self._save_dir.iterdir() if f.is_dir()] + utils.delete_files(subdirectories) + + def consume( + self, + data_type: str, + components: Dict[str, Any], + data_iterator, + generator: Optional[torch.Generator] = None, + cache_samples: bool = False, + use_cached_samples: bool = False, + drop_samples: bool = False, + ) -> Iterable[Dict[str, Any]]: + if data_type not in self._processor_fn.keys(): + raise ValueError(f"Invalid data type: {data_type}. 
Supported types: {list(self._processor_fn.keys())}") + if cache_samples: + if use_cached_samples: + raise ValueError("Cannot cache and use cached samples at the same time.") + if drop_samples: + raise ValueError("Cannot cache and drop samples at the same time.") + + for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items): + if use_cached_samples: + item = self._cached_samples[i] + else: + item = next(data_iterator) + if cache_samples: + self._cached_samples.append(item) + item = self._processor_fn[data_type](**item, **components, generator=generator) + _save_item(self._rank, i, item, self._save_dir, data_type) + + if drop_samples: + del self._cached_samples + self._cached_samples = [] + utils.free_memory() + + self._preprocessed_iterator = PreprocessedDataIterable(self._rank, self._save_dir, data_type) + return iter(self._preprocessed_iterator) + + def consume_once( + self, + data_type: str, + components: Dict[str, Any], + data_iterator, + generator: Optional[torch.Generator] = None, + cache_samples: bool = False, + use_cached_samples: bool = False, + drop_samples: bool = False, + ) -> Iterable[Dict[str, Any]]: + if data_type not in self._processor_fn.keys(): + raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}") + if cache_samples: + if use_cached_samples: + raise ValueError("Cannot cache and use cached samples at the same time.") + if drop_samples: + raise ValueError("Cannot cache and drop samples at the same time.") + + for i in tqdm(range(self._num_items), desc=f"Processing data on rank {self._rank}", total=self._num_items): + if use_cached_samples: + item = self._cached_samples[i] + else: + item = next(data_iterator) + if cache_samples: + self._cached_samples.append(item) + item = self._processor_fn[data_type](**item, **components, generator=generator) + _save_item(self._rank, i, item, self._save_dir, data_type) + + if drop_samples: + del self._cached_samples + self._cached_samples = [] + utils.free_memory() + + self._preprocessed_iterator = PreprocessedOnceDataIterable(self._rank, self._save_dir, data_type) + return iter(self._preprocessed_iterator) + + @property + def requires_data(self): + if self._preprocessed_iterator is None: + return True + return self._preprocessed_iterator.requires_data + + +class PreprocessedDataIterable: + def __init__(self, rank: int, save_dir: str, data_type: str) -> None: + self._rank = rank + self._save_dir = pathlib.Path(save_dir) + self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt"))) + self._data_type = data_type + + self._requires_data = False + + def __iter__(self) -> Iterable[Dict[str, Any]]: + for i in range(self._num_items): + if i == self._num_items - 1: + self._requires_data = True + yield _load_item(self._rank, i, self._save_dir, self._data_type) + + def __len__(self) -> int: + return self._num_items + + @property + def requires_data(self): + return self._requires_data + + +class PreprocessedOnceDataIterable: + def __init__(self, rank: int, save_dir: str, data_type: str) -> None: + self._rank = rank + self._save_dir = pathlib.Path(save_dir) + self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt"))) + self._data_type = data_type + + self._requires_data = False + + def __iter__(self) -> Iterable[Dict[str, Any]]: + index = 0 + while True: + yield _load_item(self._rank, index, self._save_dir, self._data_type) + index = (index + 1) % self._num_items + + def __len__(self) -> int: + return self._num_items + + @property + def 
requires_data(self): + return self._requires_data + + +def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None: + filename = directory / f"{data_type}-{rank}-{index}.pt" + torch.save(item, filename.as_posix()) + + +def _load_item(rank: int, index: int, directory: pathlib.Path, data_type: str) -> Dict[str, Any]: + filename = directory / f"{data_type}-{rank}-{index}.pt" + return torch.load(filename.as_posix(), weights_only=True) diff --git a/finetrainers/data/sampler.py b/finetrainers/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..5d9d650e1d610e8ce91b4168a9960479cfcfe8f7 --- /dev/null +++ b/finetrainers/data/sampler.py @@ -0,0 +1,58 @@ +from typing import Any, Dict, List, Tuple + +import torch + + +class ResolutionSampler: + def __init__(self, batch_size: int = 1, dim_keys: Dict[str, Tuple[int, ...]] = None) -> None: + self.batch_size = batch_size + self.dim_keys = dim_keys + assert dim_keys is not None, "dim_keys must be provided" + + self._chosen_leader_key = None + self._unsatisfied_buckets: Dict[Tuple[int, ...], List[Dict[Any, Any]]] = {} + self._satisfied_buckets: List[Dict[Any, Any]] = [] + + def consume(self, *dict_items: Dict[Any, Any]) -> None: + if self._chosen_leader_key is None: + self._determine_leader_item(*dict_items) + self._update_buckets(*dict_items) + + def get_batch(self) -> List[Dict[str, Any]]: + return list(zip(*self._satisfied_buckets.pop(-1))) + + @property + def is_ready(self) -> bool: + return len(self._satisfied_buckets) > 0 + + def _determine_leader_item(self, *dict_items: Dict[Any, Any]) -> None: + num_observed = 0 + for dict_item in dict_items: + for key in self.dim_keys.keys(): + if key in dict_item.keys(): + self._chosen_leader_key = key + if not torch.is_tensor(dict_item[key]): + raise ValueError(f"Leader key {key} must be a tensor") + num_observed += 1 + if num_observed > 1: + raise ValueError( + f"Only one leader key is allowed in provided list of data dictionaries. 
Found {num_observed} leader keys" + ) + if self._chosen_leader_key is None: + raise ValueError("No leader key found in provided list of data dictionaries") + + def _update_buckets(self, *dict_items: Dict[Any, Any]) -> None: + chosen_value = [ + dict_item[self._chosen_leader_key] + for dict_item in dict_items + if self._chosen_leader_key in dict_item.keys() + ] + if len(chosen_value) == 0: + raise ValueError(f"Leader key {self._chosen_leader_key} not found in provided list of data dictionaries") + chosen_value = chosen_value[0] + dims = tuple(chosen_value.size(x) for x in self.dim_keys[self._chosen_leader_key]) + if dims not in self._unsatisfied_buckets: + self._unsatisfied_buckets[dims] = [] + self._unsatisfied_buckets[dims].append(dict_items) + if len(self._unsatisfied_buckets[dims]) == self.batch_size: + self._satisfied_buckets.append(self._unsatisfied_buckets.pop(dims)) diff --git a/finetrainers/data/utils.py b/finetrainers/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd507f348efc0e123532e8082502c7cfad956ed --- /dev/null +++ b/finetrainers/data/utils.py @@ -0,0 +1,20 @@ +import pathlib +from typing import List + + +def find_files(root: str, pattern: str, depth: int = 0) -> List[str]: + root_path = pathlib.Path(root) + result_files = [] + + def within_depth(path: pathlib.Path) -> bool: + return len(path.relative_to(root_path).parts) <= depth + + if depth == 0: + result_files.extend([str(file) for file in root_path.glob(pattern)]) + else: + # rglob matches all levels, but we filter by depth + for file in root_path.rglob(pattern): + if file.is_file() and within_depth(file.parent): + result_files.append(str(file)) + + return result_files diff --git a/finetrainers/dataset.py b/finetrainers/dataset.py deleted file mode 100644 index b40a9bffe0b05b50e37ee3f4cb773a6a4c10ede9..0000000000000000000000000000000000000000 --- a/finetrainers/dataset.py +++ /dev/null @@ -1,564 +0,0 @@ -import json -import os -import random -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -import pandas as pd -import torch -import torchvision.transforms as TT -import torchvision.transforms.functional as TTF -from accelerate.logging import get_logger -from torch.utils.data import Dataset, Sampler -from torchvision import transforms -from torchvision.transforms import InterpolationMode -from torchvision.transforms.functional import resize - -import gc -import time -import resource - -# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error -# Very few bug reports but it happens. Look in decord Github issues for more relevant information. -import decord # isort:skip - -decord.bridge.set_bridge("torch") - -from .constants import ( # noqa - COMMON_LLM_START_PHRASES, - PRECOMPUTED_CONDITIONS_DIR_NAME, - PRECOMPUTED_DIR_NAME, - PRECOMPUTED_LATENTS_DIR_NAME, -) - -# Decord is causing us some issues! 
-# Let's try to increase file descriptor limits to avoid this error: -# -# decord._ffi.base.DECORDError: Resource temporarily unavailable -try: - soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) - print(f"Current file descriptor limits: soft={soft}, hard={hard}") - - # Try to increase to hard limit if possible - if soft < hard: - resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) - new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE) - print(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}") -except Exception as e: - print(f"Could not check or update file descriptor limits: {e}") - -logger = get_logger(__name__) - -# TODO(aryan): This needs a refactor with separation of concerns. -# Images should be handled separately. Videos should be handled separately. -# Loading should be handled separately. -# Preprocessing (aspect ratio, resizing) should be handled separately. -# URL loading should be handled. -# Parquet format should be handled. -# Loading from ZIP should be handled. -class ImageOrVideoDataset(Dataset): - def __init__( - self, - data_root: str, - caption_column: str, - video_column: str, - resolution_buckets: List[Tuple[int, int, int]], - dataset_file: Optional[str] = None, - id_token: Optional[str] = None, - remove_llm_prefixes: bool = False, - ) -> None: - super().__init__() - - self.data_root = Path(data_root) - self.dataset_file = dataset_file - self.caption_column = caption_column - self.video_column = video_column - self.id_token = f"{id_token.strip()} " if id_token else "" - self.resolution_buckets = resolution_buckets - - # Four methods of loading data are supported. - # - Using a CSV: caption_column and video_column must be some column in the CSV. One could - # make use of other columns too, such as a motion score or aesthetic score, by modifying the - # logic in CSV processing. - # - Using two files containing line-separate captions and relative paths to videos. - # - Using a JSON file containing a list of dictionaries, where each dictionary has a `caption_column` and `video_column` key. - # - Using a JSONL file containing a list of line-separated dictionaries, where each dictionary has a `caption_column` and `video_column` key. - # For a more detailed explanation about preparing dataset format, checkout the README. - if dataset_file is None: - ( - self.prompts, - self.video_paths, - ) = self._load_dataset_from_local_path() - elif dataset_file.endswith(".csv"): - ( - self.prompts, - self.video_paths, - ) = self._load_dataset_from_csv() - elif dataset_file.endswith(".json"): - ( - self.prompts, - self.video_paths, - ) = self._load_dataset_from_json() - elif dataset_file.endswith(".jsonl"): - ( - self.prompts, - self.video_paths, - ) = self._load_dataset_from_jsonl() - else: - raise ValueError( - "Expected `--dataset_file` to be a path to a CSV file or a directory containing line-separated text prompts and video paths." - ) - - if len(self.video_paths) != len(self.prompts): - raise ValueError( - f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset." 
- ) - - # Clean LLM start phrases - if remove_llm_prefixes: - for i in range(len(self.prompts)): - self.prompts[i] = self.prompts[i].strip() - for phrase in COMMON_LLM_START_PHRASES: - if self.prompts[i].startswith(phrase): - self.prompts[i] = self.prompts[i].removeprefix(phrase).strip() - - self.video_transforms = transforms.Compose( - [ - transforms.Lambda(self.scale_transform), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), - ] - ) - - @staticmethod - def scale_transform(x): - return x / 255.0 - - def __len__(self) -> int: - return len(self.video_paths) - - def __getitem__(self, index: int) -> Dict[str, Any]: - if isinstance(index, list): - # Here, index is actually a list of data objects that we need to return. - # The BucketSampler should ideally return indices. But, in the sampler, we'd like - # to have information about num_frames, height and width. Since this is not stored - # as metadata, we need to read the video to get this information. You could read this - # information without loading the full video in memory, but we do it anyway. In order - # to not load the video twice (once to get the metadata, and once to return the loaded video - # based on sampled indices), we cache it in the BucketSampler. When the sampler is - # to yield, we yield the cache data instead of indices. So, this special check ensures - # that data is not loaded a second time. PRs are welcome for improvements. - return index - - prompt = self.id_token + self.prompts[index] - - video_path: Path = self.video_paths[index] - if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]: - video = self._preprocess_image(video_path) - else: - video = self._preprocess_video(video_path) - - return { - "prompt": prompt, - "video": video, - "video_metadata": { - "num_frames": video.shape[0], - "height": video.shape[2], - "width": video.shape[3], - }, - } - - def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]: - if not self.data_root.exists(): - raise ValueError("Root folder for videos does not exist") - - prompt_path = self.data_root.joinpath(self.caption_column) - video_path = self.data_root.joinpath(self.video_column) - - if not prompt_path.exists() or not prompt_path.is_file(): - raise ValueError( - "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts." - ) - if not video_path.exists() or not video_path.is_file(): - raise ValueError( - "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory." - ) - - with open(prompt_path, "r", encoding="utf-8") as file: - prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0] - with open(video_path, "r", encoding="utf-8") as file: - video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0] - - if any(not path.is_file() for path in video_paths): - raise ValueError( - f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." 
- ) - - return prompts, video_paths - - def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]: - df = pd.read_csv(self.dataset_file) - prompts = df[self.caption_column].tolist() - video_paths = df[self.video_column].tolist() - video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths] - - if any(not path.is_file() for path in video_paths): - raise ValueError( - f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." - ) - - return prompts, video_paths - - def _load_dataset_from_json(self) -> Tuple[List[str], List[str]]: - with open(self.dataset_file, "r", encoding="utf-8") as file: - data = json.load(file) - - prompts = [entry[self.caption_column] for entry in data] - video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data] - - if any(not path.is_file() for path in video_paths): - raise ValueError( - f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." - ) - - return prompts, video_paths - - def _load_dataset_from_jsonl(self) -> Tuple[List[str], List[str]]: - with open(self.dataset_file, "r", encoding="utf-8") as file: - data = [json.loads(line) for line in file] - - prompts = [entry[self.caption_column] for entry in data] - video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data] - - if any(not path.is_file() for path in video_paths): - raise ValueError( - f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file." - ) - - return prompts, video_paths - - def _preprocess_image(self, path: Path) -> torch.Tensor: - # TODO(aryan): Support alpha channel in future by whitening background - image = TTF.Image.open(path.as_posix()).convert("RGB") - image = TTF.to_tensor(image) - image = image * 2.0 - 1.0 - image = image.unsqueeze(0).contiguous() # [C, H, W] -> [1, C, H, W] (1-frame video) - return image - - def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Loads a single video, or latent and prompt embedding, based on initialization parameters. - Returns a [F, C, H, W] video tensor. 
- """ - max_retries = 3 - retry_delay = 1.0 # seconds - - for attempt in range(max_retries): - try: - # Create video reader - video_reader = decord.VideoReader(uri=path.as_posix()) - video_num_frames = len(video_reader) - - # Process frames - indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames)) - frames = video_reader.get_batch(indices) - frames = frames[: self.max_num_frames].float() - frames = frames.permute(0, 3, 1, 2).contiguous() - frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0) - - # Explicitly clean up resources - del video_reader - - # Force garbage collection occasionally - if random.random() < 0.05: # 5% chance - gc.collect() - - return frames - - except decord._ffi.base.DECORDError as e: - # Log the error - error_msg = str(e) - if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1: - logger.warning(f"Retry {attempt+1}/{max_retries} loading video {path}: {error_msg}") - - # Clean up and wait before retrying - gc.collect() - time.sleep(retry_delay * (attempt + 1)) # Increasing backoff - else: - # Either not a resource error or we've run out of retries - logger.error(f"Failed to load video {path} after {attempt+1} attempts: {error_msg}") - raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}") - - -class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0] - - def _preprocess_image(self, path: Path) -> torch.Tensor: - # TODO(aryan): Support alpha channel in future by whitening background - image = TTF.Image.open(path.as_posix()).convert("RGB") - image = TTF.to_tensor(image) - - nearest_res = self._find_nearest_resolution(image.shape[1], image.shape[2]) - image = resize(image, nearest_res) - - image = image * 2.0 - 1.0 - image = image.unsqueeze(0).contiguous() - return image - - def _preprocess_video(self, path: Path) -> torch.Tensor: - max_retries = 3 - retry_delay = 1.0 # seconds - - for attempt in range(max_retries): - try: - # Create video reader - video_reader = decord.VideoReader(uri=path.as_posix()) - video_num_frames = len(video_reader) - - # Find appropriate bucket for the video - video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames] - - if not video_buckets: - _, h, w = self.resolution_buckets[0] - video_buckets = [(1, h, w)] - - nearest_frame_bucket = min( - video_buckets, - key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)), - default=video_buckets[0], - )[0] - - # Extract and process frames - frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) - frames = video_reader.get_batch(frame_indices) - frames = frames[:nearest_frame_bucket].float() - frames = frames.permute(0, 3, 1, 2).contiguous() - - nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) - frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0) - frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) - - # Explicitly clean up resources - del video_reader - - # Force garbage collection occasionally - if random.random() < 0.05: # 5% chance - gc.collect() - - return frames - - except decord._ffi.base.DECORDError as e: - # Log the error - error_msg = str(e) - if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1: - 
logger.warning(f"Retry {attempt+1}/{max_retries} loading video {path}: {error_msg}") - - # Clean up and wait before retrying - gc.collect() - time.sleep(retry_delay * (attempt + 1)) # Increasing backoff - else: - # Either not a resource error or we've run out of retries - logger.error(f"Failed to load video {path} after {attempt+1} attempts: {error_msg}") - raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}") - - def _find_nearest_resolution(self, height, width): - nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) - return nearest_res[1], nearest_res[2] - - -class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset): - def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - self.video_reshape_mode = video_reshape_mode - self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0] - - def _resize_for_rectangle_crop(self, arr, image_size): - reshape_mode = self.video_reshape_mode - if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]: - arr = resize( - arr, - size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])], - interpolation=InterpolationMode.BICUBIC, - ) - else: - arr = resize( - arr, - size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]], - interpolation=InterpolationMode.BICUBIC, - ) - - h, w = arr.shape[2], arr.shape[3] - arr = arr.squeeze(0) - - delta_h = h - image_size[0] - delta_w = w - image_size[1] - - if reshape_mode == "random" or reshape_mode == "none": - top = np.random.randint(0, delta_h + 1) - left = np.random.randint(0, delta_w + 1) - elif reshape_mode == "center": - top, left = delta_h // 2, delta_w // 2 - else: - raise NotImplementedError - arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1]) - return arr - - def _preprocess_video(self, path: Path) -> torch.Tensor: - max_retries = 3 - retry_delay = 1.0 # seconds - - for attempt in range(max_retries): - try: - # Create video reader - video_reader = decord.VideoReader(uri=path.as_posix()) - video_num_frames = len(video_reader) - - # Find appropriate bucket for the video - video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames] - - if not video_buckets: - _, h, w = self.resolution_buckets[0] - video_buckets = [(1, h, w)] - - nearest_frame_bucket = min( - video_buckets, - key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)), - default=video_buckets[0], - )[0] - - # Extract and process frames - frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) - frames = video_reader.get_batch(frame_indices) - frames = frames[:nearest_frame_bucket].float() - frames = frames.permute(0, 3, 1, 2).contiguous() - - # Fix: Change self.resolutions to self.resolution_buckets to match the class attribute - nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3]) - frames_resized = self._resize_for_rectangle_crop(frames, nearest_res) - frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0) - - # Explicitly clean up resources - del video_reader - - # Force garbage collection occasionally - if random.random() < 0.05: # 5% chance - gc.collect() - - return frames - - except decord._ffi.base.DECORDError as e: - # Log the error - error_msg = str(e) - if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1: - logger.warning(f"Retry 
{attempt+1}/{max_retries} loading video {path}: {error_msg}") - - # Clean up and wait before retrying - gc.collect() - time.sleep(retry_delay * (attempt + 1)) # Increasing backoff - else: - # Either not a resource error or we've run out of retries - logger.error(f"Failed to load video {path} after {attempt+1} attempts: {error_msg}") - raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}") - - def _find_nearest_resolution(self, height, width): - nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width)) - return nearest_res[1], nearest_res[2] - - -class PrecomputedDataset(Dataset): - def __init__(self, data_root: str, model_name: str = None, cleaned_model_id: str = None) -> None: - super().__init__() - - self.data_root = Path(data_root) - - if model_name and cleaned_model_id: - precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}" - self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME - self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME - else: - self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME - self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME - - self.latent_conditions = sorted(os.listdir(self.latents_path)) - self.text_conditions = sorted(os.listdir(self.conditions_path)) - - assert len(self.latent_conditions) == len(self.text_conditions), "Number of captions and videos do not match" - - def __len__(self) -> int: - return len(self.latent_conditions) - - def __getitem__(self, index: int) -> Dict[str, Any]: - conditions = {} - latent_path = self.latents_path / self.latent_conditions[index] - condition_path = self.conditions_path / self.text_conditions[index] - conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True) - conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True) - return conditions - - -class BucketSampler(Sampler): - r""" - PyTorch Sampler that groups 3D data by height, width and frames. - - Args: - data_source (`ImageOrVideoDataset`): - A PyTorch dataset object that is an instance of `ImageOrVideoDataset`. - batch_size (`int`, defaults to `8`): - The batch size to use for training. - shuffle (`bool`, defaults to `True`): - Whether or not to shuffle the data in each batch before dispatching to dataloader. - drop_last (`bool`, defaults to `False`): - Whether or not to drop incomplete buckets of data after completely iterating over all data - in the dataset. If set to True, only batches that have `batch_size` number of entries will - be yielded. If set to False, it is guaranteed that all data in the dataset will be processed - and batches that do not have `batch_size` number of entries will also be yielded. - """ - - def __init__( - self, data_source: ImageOrVideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False - ) -> None: - self.data_source = data_source - self.batch_size = batch_size - self.shuffle = shuffle - self.drop_last = drop_last - - self.buckets = {resolution: [] for resolution in data_source.resolution_buckets} - - self._raised_warning_for_drop_last = False - - def __len__(self): - if self.drop_last and not self._raised_warning_for_drop_last: - self._raised_warning_for_drop_last = True - logger.warning( - "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. 
This may cause problems when setting the number of epochs used for training." - ) - return (len(self.data_source) + self.batch_size - 1) // self.batch_size - - def __iter__(self): - for index, data in enumerate(self.data_source): - video_metadata = data["video_metadata"] - f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"] - - self.buckets[(f, h, w)].append(data) - if len(self.buckets[(f, h, w)]) == self.batch_size: - if self.shuffle: - random.shuffle(self.buckets[(f, h, w)]) - yield self.buckets[(f, h, w)] - del self.buckets[(f, h, w)] - self.buckets[(f, h, w)] = [] - - if self.drop_last: - return - - for fhw, bucket in list(self.buckets.items()): - if len(bucket) == 0: - continue - if self.shuffle: - random.shuffle(bucket) - yield bucket - del self.buckets[fhw] - self.buckets[fhw] = [] diff --git a/finetrainers/functional/__init__.py b/finetrainers/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a62a87847ac0e61521a2331ec3b9ea08cbd49abb --- /dev/null +++ b/finetrainers/functional/__init__.py @@ -0,0 +1,16 @@ +from .diffusion import flow_match_target, flow_match_xt +from .image import ( + bicubic_resize_image, + center_crop_image, + find_nearest_resolution_image, + resize_crop_image, + resize_to_nearest_bucket_image, +) +from .text import dropout_caption, dropout_embeddings_to_zero, remove_prefix +from .video import ( + bicubic_resize_video, + center_crop_video, + find_nearest_video_resolution, + resize_crop_video, + resize_to_nearest_bucket_video, +) diff --git a/finetrainers/functional/diffusion.py b/finetrainers/functional/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d553895c2fb251abf80f01f284049acf84f87d --- /dev/null +++ b/finetrainers/functional/diffusion.py @@ -0,0 +1,11 @@ +import torch + + +def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + r"""Forward process of flow matching.""" + return (1.0 - t) * x0 + t * n + + +def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: + r"""Loss target for flow matching.""" + return n - x0 diff --git a/finetrainers/functional/image.py b/finetrainers/functional/image.py new file mode 100644 index 0000000000000000000000000000000000000000..8b644e4495dd38ec127ec748c03a64dafd783f00 --- /dev/null +++ b/finetrainers/functional/image.py @@ -0,0 +1,54 @@ +from typing import List, Literal, Tuple + +import torch +import torch.nn.functional as F + + +def center_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + num_channels, height, width = image.shape + crop_h, crop_w = size + top = (height - crop_h) // 2 + left = (width - crop_w) // 2 + return image[:, top : top + crop_h, left : left + crop_w] + + +def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + num_channels, height, width = image.shape + target_h, target_w = size + scale = max(target_h / height, target_w / width) + new_h, new_w = int(height * scale), int(width * scale) + image = F.interpolate(image, size=(new_h, new_w), mode="bilinear", align_corners=False) + return center_crop_image(image, size) + + +def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + return F.interpolate(image, size=size, mode="bicubic", align_corners=False) + + +def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]: + num_channels, height, width = image.shape + aspect_ratio = width / height + + def 
aspect_ratio_diff(bucket): + return abs((bucket[1] / bucket[0]) - aspect_ratio) + + return min(resolution_buckets, key=aspect_ratio_diff) + + +def resize_to_nearest_bucket_image( + image: torch.Tensor, + resolution_buckets: List[Tuple[int, int]], + resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic", +) -> torch.Tensor: + target_size = find_nearest_resolution_image(image, resolution_buckets) + + if resize_mode == "center_crop": + return center_crop_image(image, target_size) + elif resize_mode == "resize_crop": + return resize_crop_image(image, target_size) + elif resize_mode == "bicubic": + return bicubic_resize_image(image, target_size) + else: + raise ValueError( + f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'." + ) diff --git a/finetrainers/functional/text.py b/finetrainers/functional/text.py new file mode 100644 index 0000000000000000000000000000000000000000..6e823edfc2e3f4a93d2afddf6df71a4198f05219 --- /dev/null +++ b/finetrainers/functional/text.py @@ -0,0 +1,26 @@ +import random +from typing import List, Union + +import torch + + +def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]: + if random.random() >= dropout_p: + return caption + if isinstance(caption, str): + return "" + return [""] * len(caption) + + +def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor: + if random.random() >= dropout_p: + return embed + embed = torch.zeros_like(embed) + return embed + + +def remove_prefix(text: str, prefixes: List[str]) -> str: + for prefix in prefixes: + if text.startswith(prefix): + return text.removeprefix(prefix).strip() + return text diff --git a/finetrainers/functional/video.py b/finetrainers/functional/video.py new file mode 100644 index 0000000000000000000000000000000000000000..fcbc382b0615e53270f3b17746fec14f438ddd16 --- /dev/null +++ b/finetrainers/functional/video.py @@ -0,0 +1,94 @@ +from typing import List, Literal, Tuple + +import torch +import torch.nn.functional as F + + +def center_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + num_frames, num_channels, height, width = video.shape + crop_h, crop_w = size + top = (height - crop_h) // 2 + left = (width - crop_w) // 2 + return video[:, :, top : top + crop_h, left : left + crop_w] + + +def resize_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + num_frames, num_channels, height, width = video.shape + target_h, target_w = size + scale = max(target_h / height, target_w / width) + new_h, new_w = int(height * scale), int(width * scale) + video = F.interpolate(video, size=(new_h, new_w), mode="bilinear", align_corners=False) + return center_crop_video(video, size) + + +def bicubic_resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor: + num_frames, num_channels, height, width = video.shape + video = F.interpolate(video, size=size, mode="bicubic", align_corners=False) + return video + + +def find_nearest_video_resolution( + video: torch.Tensor, resolution_buckets: List[Tuple[int, int, int]] +) -> Tuple[int, int, int]: + num_frames, num_channels, height, width = video.shape + aspect_ratio = width / height + possible_buckets = [b for b in resolution_buckets if b[0] <= num_frames] + + if not possible_buckets: + best_frame_match = min(resolution_buckets, key=lambda b: abs(b[0] - num_frames)) + else: + best_frame_match = max(possible_buckets, key=lambda b: b[0]) + + frame_filtered_buckets = [b for b in 
resolution_buckets if b[0] == best_frame_match[0]] + + def aspect_ratio_diff(bucket): + return abs((bucket[2] / bucket[1]) - aspect_ratio) + + return min(frame_filtered_buckets, key=aspect_ratio_diff) + + +def resize_to_nearest_bucket_video( + video: torch.Tensor, + resolution_buckets: List[Tuple[int, int, int]], + resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic", +) -> torch.Tensor: + """ + Resizes a video tensor to the nearest resolution bucket using the specified mode. + - It first finds a frame match with <= T frames. + - Then, it selects the closest height/width bucket. + + Args: + video (`torch.Tensor`): + Input video tensor of shape `(B, T, C, H, W)`. + resolution_buckets (`List[Tuple[int, int, int]]`): + Available (num_frames, height, width) resolution buckets. + resize_mode (`str`): + One of ["center_crop", "resize_crop", "bicubic"]. + + Returns: + `torch.Tensor`: + Resized video tensor of the nearest bucket resolution. + """ + target_frames, target_h, target_w = find_nearest_video_resolution(video, resolution_buckets) + + # Adjust frame count: only interpolate frames if no lesser/equal frame count exists + num_frames, num_channels, height, width = video.shape + _first_frame_only = False + if num_frames > target_frames: + # Downsample: Select frames evenly + indices = torch.linspace(0, num_frames - 1, target_frames).long() + video = video[indices, :, :, :] + elif num_frames < target_frames: + _first_frame_only = False + + # Resize spatial resolution + if resize_mode == "center_crop": + return center_crop_video(video, (target_h, target_w)), _first_frame_only + elif resize_mode == "resize_crop": + return resize_crop_video(video, (target_h, target_w)), _first_frame_only + elif resize_mode == "bicubic": + return bicubic_resize_video(video, (target_h, target_w)), _first_frame_only + else: + raise ValueError( + f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'." + ) diff --git a/finetrainers/hooks/__init__.py b/finetrainers/hooks/__init__.py deleted file mode 100644 index f0c3a432f4021ec9b2666b48047c6fd40f3849b9..0000000000000000000000000000000000000000 --- a/finetrainers/hooks/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .layerwise_upcasting import apply_layerwise_upcasting diff --git a/finetrainers/hooks/hooks.py b/finetrainers/hooks/hooks.py deleted file mode 100644 index e779795279e2302de286096563538d2beb818bac..0000000000000000000000000000000000000000 --- a/finetrainers/hooks/hooks.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import functools -from typing import Any, Dict, Optional, Tuple - -import torch -from accelerate.logging import get_logger - -from ..constants import FINETRAINERS_LOG_LEVEL - - -logger = get_logger("finetrainers") # pylint: disable=invalid-name -logger.setLevel(FINETRAINERS_LOG_LEVEL) - - -class ModelHook: - r""" - A hook that contains callbacks to be executed just before and after the forward method of a model. - """ - - _is_stateful = False - - def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when a model is initialized. - Args: - module (`torch.nn.Module`): - The module attached to this hook. - """ - return module - - def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when a model is deinitalized. - Args: - module (`torch.nn.Module`): - The module attached to this hook. - """ - module.forward = module._old_forward - del module._old_forward - return module - - def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: - r""" - Hook that is executed just before the forward method of the model. - Args: - module (`torch.nn.Module`): - The module whose forward pass will be executed just after this event. - args (`Tuple[Any]`): - The positional arguments passed to the module. - kwargs (`Dict[Str, Any]`): - The keyword arguments passed to the module. - Returns: - `Tuple[Tuple[Any], Dict[Str, Any]]`: - A tuple with the treated `args` and `kwargs`. - """ - return args, kwargs - - def post_forward(self, module: torch.nn.Module, output: Any) -> Any: - r""" - Hook that is executed just after the forward method of the model. - Args: - module (`torch.nn.Module`): - The module whose forward pass been executed just before this event. - output (`Any`): - The output of the module. - Returns: - `Any`: The processed `output`. - """ - return output - - def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when the hook is detached from a module. - Args: - module (`torch.nn.Module`): - The module detached from this hook. 
- """ - return module - - def reset_state(self, module: torch.nn.Module): - if self._is_stateful: - raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") - return module - - -class HookRegistry: - def __init__(self, module_ref: torch.nn.Module) -> None: - super().__init__() - - self.hooks: Dict[str, ModelHook] = {} - - self._module_ref = module_ref - self._hook_order = [] - - def register_hook(self, hook: ModelHook, name: str) -> None: - if name in self.hooks.keys(): - logger.warning(f"Hook with name {name} already exists, replacing it.") - - if hasattr(self._module_ref, "_old_forward"): - old_forward = self._module_ref._old_forward - else: - old_forward = self._module_ref.forward - self._module_ref._old_forward = self._module_ref.forward - - self._module_ref = hook.initialize_hook(self._module_ref) - - if hasattr(hook, "new_forward"): - rewritten_forward = hook.new_forward - - def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = rewritten_forward(module, *args, **kwargs) - return hook.post_forward(module, output) - else: - - def new_forward(module, *args, **kwargs): - args, kwargs = hook.pre_forward(module, *args, **kwargs) - output = old_forward(*args, **kwargs) - return hook.post_forward(module, output) - - self._module_ref.forward = functools.update_wrapper( - functools.partial(new_forward, self._module_ref), old_forward - ) - - self.hooks[name] = hook - self._hook_order.append(name) - - def get_hook(self, name: str) -> Optional[ModelHook]: - if name not in self.hooks.keys(): - return None - return self.hooks[name] - - def remove_hook(self, name: str) -> None: - if name not in self.hooks.keys(): - raise ValueError(f"Hook with name {name} not found.") - self.hooks[name].deinitalize_hook(self._module_ref) - del self.hooks[name] - self._hook_order.remove(name) - - def reset_stateful_hooks(self, recurse: bool = True) -> None: - for hook_name in self._hook_order: - hook = self.hooks[hook_name] - if hook._is_stateful: - hook.reset_state(self._module_ref) - - if recurse: - for module in self._module_ref.modules(): - if hasattr(module, "_diffusers_hook"): - module._diffusers_hook.reset_stateful_hooks(recurse=False) - - @classmethod - def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry": - if not hasattr(module, "_diffusers_hook"): - module._diffusers_hook = cls(module) - return module._diffusers_hook - - def __repr__(self) -> str: - hook_repr = "" - for i, hook_name in enumerate(self._hook_order): - hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})" - if i < len(self._hook_order) - 1: - hook_repr += "\n" - return f"HookRegistry(\n{hook_repr}\n)" diff --git a/finetrainers/hooks/layerwise_upcasting.py b/finetrainers/hooks/layerwise_upcasting.py deleted file mode 100644 index b7bdc38021c5145a2a6ac515270dc356385d03a7..0000000000000000000000000000000000000000 --- a/finetrainers/hooks/layerwise_upcasting.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -from typing import Optional, Tuple, Type - -import torch -from accelerate.logging import get_logger - -from ..constants import FINETRAINERS_LOG_LEVEL -from .hooks import HookRegistry, ModelHook - - -logger = get_logger("finetrainers") # pylint: disable=invalid-name -logger.setLevel(FINETRAINERS_LOG_LEVEL) - - -# fmt: off -_SUPPORTED_PYTORCH_LAYERS = ( - torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, - torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, - torch.nn.Linear, -) - -_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm") -# fmt: on - - -class LayerwiseUpcastingHook(ModelHook): - r""" - A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype - for storage. This process may lead to quality loss in the output, but can significantly reduce the memory - footprint. - """ - - _is_stateful = False - - def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None: - self.storage_dtype = storage_dtype - self.compute_dtype = compute_dtype - self.non_blocking = non_blocking - - def initialize_hook(self, module: torch.nn.Module): - module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) - return module - - def pre_forward(self, module: torch.nn.Module, *args, **kwargs): - module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking) - return args, kwargs - - def post_forward(self, module: torch.nn.Module, output): - module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking) - return output - - -def apply_layerwise_upcasting( - module: torch.nn.Module, - storage_dtype: torch.dtype, - compute_dtype: torch.dtype, - skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN, - skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None, - non_blocking: bool = False, - _prefix: str = "", -) -> None: - r""" - Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any - nn.Module using diffusers layers or pytorch primitives. - Args: - module (`torch.nn.Module`): - The module whose leaf modules will be cast to a high precision dtype for computation, and to a low - precision dtype for storage. - storage_dtype (`torch.dtype`): - The dtype to cast the module to before/after the forward pass for storage. - compute_dtype (`torch.dtype`): - The dtype to cast the module to during the forward pass for computation. - skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`): - A list of patterns to match the names of the modules to skip during the layerwise upcasting process. - skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`): - A list of module classes to skip during the layerwise upcasting process. - non_blocking (`bool`, defaults to `False`): - If `True`, the weight casting operations are non-blocking. 
- """ - if skip_modules_classes is None and skip_modules_pattern is None: - apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) - return - - should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or ( - skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern) - ) - if should_skip: - logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"') - return - - if isinstance(module, _SUPPORTED_PYTORCH_LAYERS): - logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"') - apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking) - return - - for name, submodule in module.named_children(): - layer_name = f"{_prefix}.{name}" if _prefix else name - apply_layerwise_upcasting( - submodule, - storage_dtype, - compute_dtype, - skip_modules_pattern, - skip_modules_classes, - non_blocking, - _prefix=layer_name, - ) - - -def apply_layerwise_upcasting_hook( - module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool -) -> None: - r""" - Applies a `LayerwiseUpcastingHook` to a given module. - Args: - module (`torch.nn.Module`): - The module to attach the hook to. - storage_dtype (`torch.dtype`): - The dtype to cast the module to before the forward pass. - compute_dtype (`torch.dtype`): - The dtype to cast the module to during the forward pass. - non_blocking (`bool`): - If `True`, the weight casting operations are non-blocking. - """ - registry = HookRegistry.check_if_exists_or_initialize(module) - hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking) - registry.register_hook(hook, "layerwise_upcasting") diff --git a/finetrainers/logging.py b/finetrainers/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..29d3597db47a91c1f3ec353a677c7309ceffc19d --- /dev/null +++ b/finetrainers/logging.py @@ -0,0 +1,111 @@ +import logging +import os +from typing import TYPE_CHECKING, Union + +from .constants import FINETRAINERS_LOG_LEVEL + + +if TYPE_CHECKING: + from .parallel import ParallelBackendType + + +class FinetrainersLoggerAdapter(logging.LoggerAdapter): + def __init__(self, logger: logging.Logger, parallel_backend: "ParallelBackendType" = None) -> None: + super().__init__(logger, {}) + self.parallel_backend = parallel_backend + self._log_freq = {} + self._log_freq_counter = {} + + def log( + self, + level, + msg, + *args, + main_process_only: bool = False, + local_main_process_only: bool = True, + in_order: bool = False, + **kwargs, + ): + # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice + kwargs.setdefault("stacklevel", 2) + + if not self.isEnabledFor(level): + return + + if self.parallel_backend is None: + if int(os.environ.get("RANK", 0)) == 0: + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + return + + if (main_process_only or local_main_process_only) and in_order: + raise ValueError( + "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True." 
+ ) + + if (main_process_only and self.parallel_backend.is_main_process) or ( + local_main_process_only and self.parallel_backend.is_local_main_process + ): + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + return + + if in_order: + for i in range(self.parallel_backend.world_size): + if self.rank == i: + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + self.parallel_backend.wait_for_everyone() + return + + if not main_process_only and not local_main_process_only: + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + return + + def log_freq( + self, + level: str, + name: str, + msg: str, + frequency: int, + *, + main_process_only: bool = False, + local_main_process_only: bool = True, + in_order: bool = False, + **kwargs, + ) -> None: + if frequency <= 0: + return + if name not in self._log_freq_counter: + self._log_freq[name] = frequency + self._log_freq_counter[name] = 0 + if self._log_freq_counter[name] % self._log_freq[name] == 0: + self.log( + level, + msg, + main_process_only=main_process_only, + local_main_process_only=local_main_process_only, + in_order=in_order, + **kwargs, + ) + self._log_freq_counter[name] += 1 + + +def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]: + global _logger + return _logger + + +def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> FinetrainersLoggerAdapter: + _logger.parallel_backend = parallel_backend + + +_logger = logging.getLogger("finetrainers") +_logger.setLevel(FINETRAINERS_LOG_LEVEL) +_console_handler = logging.StreamHandler() +_console_handler.setLevel(FINETRAINERS_LOG_LEVEL) +_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +_console_handler.setFormatter(_formatter) +_logger.addHandler(_console_handler) +_logger = FinetrainersLoggerAdapter(_logger) diff --git a/finetrainers/models/__init__.py b/finetrainers/models/__init__.py index c24ab951d8b3cd8e52dfa0b3647c7d8183c1352c..fb7091a5e1650715591fdd7377e7c2850c0e3bb3 100644 --- a/finetrainers/models/__init__.py +++ b/finetrainers/models/__init__.py @@ -1,33 +1 @@ -from typing import Any, Dict - -from .cogvideox import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, COGVIDEOX_T2V_LORA_CONFIG -from .hunyuan_video import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, HUNYUAN_VIDEO_T2V_LORA_CONFIG -from .ltx_video import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, LTX_VIDEO_T2V_LORA_CONFIG - - -SUPPORTED_MODEL_CONFIGS = { - "hunyuan_video": { - "lora": HUNYUAN_VIDEO_T2V_LORA_CONFIG, - "full-finetune": HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, - }, - "ltx_video": { - "lora": LTX_VIDEO_T2V_LORA_CONFIG, - "full-finetune": LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, - }, - "cogvideox": { - "lora": COGVIDEOX_T2V_LORA_CONFIG, - "full-finetune": COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, - }, -} - - -def get_config_from_model_name(model_name: str, training_type: str) -> Dict[str, Any]: - if model_name not in SUPPORTED_MODEL_CONFIGS: - raise ValueError( - f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}" - ) - if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]: - raise ValueError( - f"Training type {training_type} not supported for model {model_name}. 
Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}" - ) - return SUPPORTED_MODEL_CONFIGS[model_name][training_type] +from .modeling_utils import ModelSpecification diff --git a/finetrainers/models/cogvideox/__init__.py b/finetrainers/models/cogvideox/__init__.py index 7a72064347e083b0d277437e6a4e6e2e54164277..e1f9a84073541b0e764877bac0335637f03d32ca 100644 --- a/finetrainers/models/cogvideox/__init__.py +++ b/finetrainers/models/cogvideox/__init__.py @@ -1,2 +1 @@ -from .full_finetune import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG -from .lora import COGVIDEOX_T2V_LORA_CONFIG +from .base_specification import CogVideoXModelSpecification diff --git a/finetrainers/models/cogvideox/base_specification.py b/finetrainers/models/cogvideox/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..918c1a9a950b6b2404fc7362c608dc1ebf319f27 --- /dev/null +++ b/finetrainers/models/cogvideox/base_specification.py @@ -0,0 +1,424 @@ +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKLCogVideoX, + CogVideoXDDIMScheduler, + CogVideoXImageToVideoPipeline, + CogVideoXPipeline, + CogVideoXTransformer3DModel, +) +from PIL.Image import Image +from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer + +from ... import data +from ...logging import get_logger +from ...processors import ProcessorMixin, T5Processor +from ...typing import ArtifactType, SchedulerType +from ...utils import get_non_null_items +from ..modeling_utils import ModelSpecification +from ..utils import DiagonalGaussianDistribution +from .utils import prepare_rotary_positional_embeddings + + +logger = get_logger() + + +class CogVideoXLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the CogVideoX VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. 
+ """ + + def __init__(self, output_names: List[str]): + super().__init__() + self.output_names = output_names + assert len(self.output_names) == 1 + + def forward( + self, + vae: AutoencoderKLCogVideoX, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if image is not None: + video = image.unsqueeze(1) + + assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" + video = video.to(device=device, dtype=vae.dtype) + video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] + + if compute_posterior: + latents = vae.encode(video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + if vae.use_slicing and video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] + moments = torch.cat(encoded_slices) + else: + moments = vae._encode(video) + latents = moments.to(dtype=dtype) + + latents = latents.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] -> [B, F, C, H, W] + return {self.output_names[0]: latents} + + +class CogVideoXModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])] + if latent_model_processors is None: + latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + return {"latents": (1, 3, 4)} + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir + ) + else: + tokenizer = T5Tokenizer.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=self.revision, + cache_dir=self.cache_dir, + ) + + if self.text_encoder_id is not None: + text_encoder = AutoModel.from_pretrained( + self.text_encoder_id, + torch_dtype=self.text_encoder_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + text_encoder = T5EncoderModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + def 
load_latent_models(self) -> Dict[str, torch.nn.Module]: + if self.vae_id is not None: + vae = AutoencoderKLCogVideoX.from_pretrained( + self.vae_id, + torch_dtype=self.vae_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + vae = AutoencoderKLCogVideoX.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=self.vae_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + return {"vae": vae} + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + if self.transformer_id is not None: + transformer = CogVideoXTransformer3DModel.from_pretrained( + self.transformer_id, + torch_dtype=self.transformer_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + transformer = CogVideoXTransformer3DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + scheduler = CogVideoXDDIMScheduler.from_pretrained( + self.pretrained_model_name_or_path, subfolder="scheduler", revision=self.revision, cache_dir=self.cache_dir + ) + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[T5Tokenizer] = None, + text_encoder: Optional[T5EncoderModel] = None, + transformer: Optional[CogVideoXTransformer3DModel] = None, + vae: Optional[AutoencoderKLCogVideoX] = None, + scheduler: Optional[CogVideoXDDIMScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> CogVideoXPipeline: + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = CogVideoXPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.vae.to(self.vae_dtype) + + if not training: + pipe.transformer.to(self.transformer_dtype) + + if enable_slicing: + pipe.vae.enable_slicing() + if enable_tiling: + pipe.vae.enable_tiling() + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + caption: str, + max_sequence_length: int = 226, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + conditions.pop("prompt_attention_mask", None) + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKLCogVideoX, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image": image, + "video": video, + "generator": generator, + "compute_posterior": compute_posterior, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return 
conditions + + def forward( + self, + transformer: CogVideoXTransformer3DModel, + scheduler: CogVideoXDDIMScheduler, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself. + VAE_SPATIAL_SCALE_FACTOR = 8 + rope_base_height = self.transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR + rope_base_width = self.transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR + patch_size = self.transformer_config.patch_size + patch_size_t = getattr(self.transformer_config, "patch_size_t", None) + + if compute_posterior: + latents = latent_model_conditions.pop("latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"), _dim=2) + latents = posterior.sample(generator=generator) + del posterior + + if not self.vae_config.invert_scale_latents: + latents = latents * self.vae_config.scaling_factor + + if patch_size_t is not None: + latents = self._pad_frames(latents, patch_size_t) + + timesteps = (sigmas.flatten() * 1000.0).long() + + noise = torch.zeros_like(latents).normal_(generator=generator) + noisy_latents = scheduler.add_noise(latents, noise, timesteps) + + batch_size, num_frames, num_channels, height, width = latents.shape + ofs_emb = ( + None + if getattr(self.transformer_config, "ofs_embed_dim", None) is None + else latents.new_full((batch_size,), fill_value=2.0) + ) + + image_rotary_emb = ( + prepare_rotary_positional_embeddings( + height=height * VAE_SPATIAL_SCALE_FACTOR, + width=width * VAE_SPATIAL_SCALE_FACTOR, + num_frames=num_frames, + vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR, + patch_size=patch_size, + patch_size_t=patch_size_t, + attention_head_dim=self.transformer_config.attention_head_dim, + device=transformer.device, + base_height=rope_base_height, + base_width=rope_base_width, + ) + if self.transformer_config.use_rotary_positional_embeddings + else None + ) + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + latent_model_conditions["image_rotary_emb"] = image_rotary_emb + latent_model_conditions["ofs"] = ofs_emb + condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds") + + velocity = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + return_dict=False, + )[0] + # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same + # code paths as scheduler.get_velocity(), which can be confusing to understand. 
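+        # Roughly: for x_t = a_t * x_0 + s_t * n, the true velocity is v = a_t * n - s_t * x_0, and
+        # scheduler.get_velocity(v, x_t, t) returns a_t * x_t - s_t * v = (a_t**2 + s_t**2) * x_0 = x_0.
+        # So `pred` below is the model's estimate of the clean latents, compared against `target = latents`.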
+ pred = scheduler.get_velocity(velocity, noisy_latents, timesteps) + target = latents + + return pred, target, sigmas + + def validation( + self, + pipeline: CogVideoXPipeline, + prompt: str, + image: Optional[Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + # TODO(aryan): add support for more parameters + if image is not None: + pipeline = CogVideoXImageToVideoPipeline.from_pipe(pipeline) + + generation_kwargs = { + "prompt": prompt, + "image": image, + "height": height, + "width": width, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + video = pipeline(**generation_kwargs).frames[0] + return [data.VideoArtifact(value=video)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + CogVideoXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: CogVideoXTransformer3DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = CogVideoXTransformer3DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + @staticmethod + def _pad_frames(latents: torch.Tensor, patch_size_t: int) -> torch.Tensor: + num_frames = latents.size(1) + additional_frames = patch_size_t - (num_frames % patch_size_t) + if additional_frames > 0: + last_frame = latents[:, -1:] + padding_frames = last_frame.expand(-1, additional_frames, -1, -1, -1) + latents = torch.cat([latents, padding_frames], dim=1) + return latents diff --git a/finetrainers/models/cogvideox/full_finetune.py b/finetrainers/models/cogvideox/full_finetune.py deleted file mode 100644 index b7f2b4bbcff806c7c12b736c36bb9733b9980353..0000000000000000000000000000000000000000 --- a/finetrainers/models/cogvideox/full_finetune.py +++ /dev/null @@ -1,32 +0,0 @@ -from diffusers import CogVideoXPipeline - -from .lora import ( - calculate_noisy_latents, - collate_fn_t2v, - forward_pass, - initialize_pipeline, - load_condition_models, - load_diffusion_models, - load_latent_models, - post_latent_preparation, - prepare_conditions, - prepare_latents, - validation, -) - - -# TODO(aryan): refactor into model specs for better re-use -COGVIDEOX_T2V_FULL_FINETUNE_CONFIG = { - "pipeline_cls": CogVideoXPipeline, - "load_condition_models": load_condition_models, - "load_latent_models": load_latent_models, - "load_diffusion_models": load_diffusion_models, - "initialize_pipeline": initialize_pipeline, - "prepare_conditions": prepare_conditions, - "prepare_latents": prepare_latents, - 
"post_latent_preparation": post_latent_preparation, - "collate_fn": collate_fn_t2v, - "calculate_noisy_latents": calculate_noisy_latents, - "forward_pass": forward_pass, - "validation": validation, -} diff --git a/finetrainers/models/cogvideox/lora.py b/finetrainers/models/cogvideox/lora.py deleted file mode 100644 index 65d86ee901d73296c94c2abe20f21293cace45b3..0000000000000000000000000000000000000000 --- a/finetrainers/models/cogvideox/lora.py +++ /dev/null @@ -1,334 +0,0 @@ -from typing import Any, Dict, List, Optional, Union - -import torch -from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel -from PIL import Image -from transformers import T5EncoderModel, T5Tokenizer - -from .utils import prepare_rotary_positional_embeddings - - -def load_condition_models( - model_id: str = "THUDM/CogVideoX-5b", - text_encoder_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -): - tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) - text_encoder = T5EncoderModel.from_pretrained( - model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir - ) - return {"tokenizer": tokenizer, "text_encoder": text_encoder} - - -def load_latent_models( - model_id: str = "THUDM/CogVideoX-5b", - vae_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -): - vae = AutoencoderKLCogVideoX.from_pretrained( - model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir - ) - return {"vae": vae} - - -def load_diffusion_models( - model_id: str = "THUDM/CogVideoX-5b", - transformer_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -): - transformer = CogVideoXTransformer3DModel.from_pretrained( - model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir - ) - scheduler = CogVideoXDDIMScheduler.from_pretrained(model_id, subfolder="scheduler") - return {"transformer": transformer, "scheduler": scheduler} - - -def initialize_pipeline( - model_id: str = "THUDM/CogVideoX-5b", - text_encoder_dtype: torch.dtype = torch.bfloat16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.bfloat16, - tokenizer: Optional[T5Tokenizer] = None, - text_encoder: Optional[T5EncoderModel] = None, - transformer: Optional[CogVideoXTransformer3DModel] = None, - vae: Optional[AutoencoderKLCogVideoX] = None, - scheduler: Optional[CogVideoXDDIMScheduler] = None, - device: Optional[torch.device] = None, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - enable_slicing: bool = False, - enable_tiling: bool = False, - enable_model_cpu_offload: bool = False, - is_training: bool = False, - **kwargs, -) -> CogVideoXPipeline: - component_name_pairs = [ - ("tokenizer", tokenizer), - ("text_encoder", text_encoder), - ("transformer", transformer), - ("vae", vae), - ("scheduler", scheduler), - ] - components = {} - for name, component in component_name_pairs: - if component is not None: - components[name] = component - - pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir) - pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype) - pipe.vae = pipe.vae.to(dtype=vae_dtype) - - # The transformer should already be in 
the correct dtype when training, so we don't need to cast it here. - # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during - # DDP optimizer step. - if not is_training: - pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) - - if enable_slicing: - pipe.vae.enable_slicing() - if enable_tiling: - pipe.vae.enable_tiling() - - if enable_model_cpu_offload: - pipe.enable_model_cpu_offload(device=device) - else: - pipe.to(device=device) - - return pipe - - -def prepare_conditions( - tokenizer, - text_encoder, - prompt: Union[str, List[str]], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - max_sequence_length: int = 226, # TODO: this should be configurable - **kwargs, -): - device = device or text_encoder.device - dtype = dtype or text_encoder.dtype - return _get_t5_prompt_embeds( - tokenizer=tokenizer, - text_encoder=text_encoder, - prompt=prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - - -def prepare_latents( - vae: AutoencoderKLCogVideoX, - image_or_video: torch.Tensor, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - generator: Optional[torch.Generator] = None, - precompute: bool = False, - **kwargs, -) -> torch.Tensor: - device = device or vae.device - dtype = dtype or vae.dtype - - if image_or_video.ndim == 4: - image_or_video = image_or_video.unsqueeze(2) - assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor" - - image_or_video = image_or_video.to(device=device, dtype=vae.dtype) - image_or_video = image_or_video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] - if not precompute: - latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) - if not vae.config.invert_scale_latents: - latents = latents * vae.config.scaling_factor - # For training Cog 1.5, we don't need to handle the scaling factor here. - # The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents - # is probably only needed for image-to-video training. - # TODO(aryan): investigate this - # else: - # latents = 1 / vae.config.scaling_factor * latents - latents = latents.to(dtype=dtype) - return {"latents": latents} - else: - # handle vae scaling in the `train()` method directly. - if vae.use_slicing and image_or_video.shape[0] > 1: - encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)] - h = torch.cat(encoded_slices) - else: - h = vae._encode(image_or_video) - return {"latents": h} - - -def post_latent_preparation( - vae_config: Dict[str, Any], latents: torch.Tensor, patch_size_t: Optional[int] = None, **kwargs -) -> torch.Tensor: - if not vae_config.invert_scale_latents: - latents = latents * vae_config.scaling_factor - # For training Cog 1.5, we don't need to handle the scaling factor here. - # The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents - # is probably only needed for image-to-video training. 
- # TODO(aryan): investigate this - # else: - # latents = 1 / vae_config.scaling_factor * latents - latents = _pad_frames(latents, patch_size_t) - latents = latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W] - return {"latents": latents} - - -def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: - return { - "prompts": [x["prompt"] for x in batch[0]], - "videos": torch.stack([x["video"] for x in batch[0]]), - } - - -def calculate_noisy_latents( - scheduler: CogVideoXDDIMScheduler, - noise: torch.Tensor, - latents: torch.Tensor, - timesteps: torch.LongTensor, -) -> torch.Tensor: - noisy_latents = scheduler.add_noise(latents, noise, timesteps) - return noisy_latents - - -def forward_pass( - transformer: CogVideoXTransformer3DModel, - scheduler: CogVideoXDDIMScheduler, - prompt_embeds: torch.Tensor, - latents: torch.Tensor, - noisy_latents: torch.Tensor, - timesteps: torch.LongTensor, - ofs_emb: Optional[torch.Tensor] = None, - **kwargs, -) -> torch.Tensor: - # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself. - VAE_SPATIAL_SCALE_FACTOR = 8 - transformer_config = transformer.module.config if hasattr(transformer, "module") else transformer.config - batch_size, num_frames, num_channels, height, width = noisy_latents.shape - rope_base_height = transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR - rope_base_width = transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR - - image_rotary_emb = ( - prepare_rotary_positional_embeddings( - height=height * VAE_SPATIAL_SCALE_FACTOR, - width=width * VAE_SPATIAL_SCALE_FACTOR, - num_frames=num_frames, - vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR, - patch_size=transformer_config.patch_size, - patch_size_t=transformer_config.patch_size_t if hasattr(transformer_config, "patch_size_t") else None, - attention_head_dim=transformer_config.attention_head_dim, - device=transformer.device, - base_height=rope_base_height, - base_width=rope_base_width, - ) - if transformer_config.use_rotary_positional_embeddings - else None - ) - ofs_emb = None if transformer_config.ofs_embed_dim is None else latents.new_full((batch_size,), fill_value=2.0) - - velocity = transformer( - hidden_states=noisy_latents, - timestep=timesteps, - encoder_hidden_states=prompt_embeds, - ofs=ofs_emb, - image_rotary_emb=image_rotary_emb, - return_dict=False, - )[0] - # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same - # code paths as scheduler.get_velocity(), which can be confusing to understand. 
- denoised_latents = scheduler.get_velocity(velocity, noisy_latents, timesteps) - - return {"latents": denoised_latents} - - -def validation( - pipeline: CogVideoXPipeline, - prompt: str, - image: Optional[Image.Image] = None, - video: Optional[List[Image.Image]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - num_videos_per_prompt: int = 1, - generator: Optional[torch.Generator] = None, - **kwargs, -): - generation_kwargs = { - "prompt": prompt, - "height": height, - "width": width, - "num_frames": num_frames, - "num_videos_per_prompt": num_videos_per_prompt, - "generator": generator, - "return_dict": True, - "output_type": "pil", - } - generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} - output = pipeline(**generation_kwargs).frames[0] - return [("video", output)] - - -def _get_t5_prompt_embeds( - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]] = None, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - prompt_embeds = text_encoder(text_input_ids.to(device))[0] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - return {"prompt_embeds": prompt_embeds} - - -def _pad_frames(latents: torch.Tensor, patch_size_t: int): - if patch_size_t is None or patch_size_t == 1: - return latents - - # `latents` should be of the following format: [B, C, F, H, W]. - # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t - latent_num_frames = latents.shape[2] - additional_frames = patch_size_t - latent_num_frames % patch_size_t - - if additional_frames > 0: - last_frame = latents[:, :, -1:, :, :] - padding_frames = last_frame.repeat(1, 1, additional_frames, 1, 1) - latents = torch.cat([latents, padding_frames], dim=2) - - return latents - - -# TODO(aryan): refactor into model specs for better re-use -COGVIDEOX_T2V_LORA_CONFIG = { - "pipeline_cls": CogVideoXPipeline, - "load_condition_models": load_condition_models, - "load_latent_models": load_latent_models, - "load_diffusion_models": load_diffusion_models, - "initialize_pipeline": initialize_pipeline, - "prepare_conditions": prepare_conditions, - "prepare_latents": prepare_latents, - "post_latent_preparation": post_latent_preparation, - "collate_fn": collate_fn_t2v, - "calculate_noisy_latents": calculate_noisy_latents, - "forward_pass": forward_pass, - "validation": validation, -} diff --git a/finetrainers/models/hunyuan_video/__init__.py b/finetrainers/models/hunyuan_video/__init__.py index 8ac729e91bb0d8af781ea51e856a43bfff1990df..518a42865f0cee30a534da458ec63b08c1a8d7e4 100644 --- a/finetrainers/models/hunyuan_video/__init__.py +++ b/finetrainers/models/hunyuan_video/__init__.py @@ -1,2 +1 @@ -from .full_finetune import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG -from .lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG +from .base_specification import HunyuanVideoModelSpecification diff --git a/finetrainers/models/hunyuan_video/base_specification.py b/finetrainers/models/hunyuan_video/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..1cd73fce85e8c804d1526ff0a464288d4da67740 --- /dev/null +++ 
b/finetrainers/models/hunyuan_video/base_specification.py @@ -0,0 +1,413 @@ +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + HunyuanVideoPipeline, + HunyuanVideoTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel + +from ... import data +from ... import functional as FF +from ...logging import get_logger +from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin +from ...typing import ArtifactType, SchedulerType +from ...utils import get_non_null_items +from ..modeling_utils import ModelSpecification + + +logger = get_logger() + + +class HunyuanLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the HunyuanVideo VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + """ + + def __init__(self, output_names: List[str]): + super().__init__() + self.output_names = output_names + assert len(self.output_names) == 1 + + def forward( + self, + vae: AutoencoderKLHunyuanVideo, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if image is not None: + video = image.unsqueeze(1) + + assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" + video = video.to(device=device, dtype=vae.dtype) + video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] + + if compute_posterior: + latents = vae.encode(video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + if vae.use_slicing and video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] + moments = torch.cat(encoded_slices) + else: + moments = vae._encode(video) + latents = moments.to(dtype=dtype) + + return {self.output_names[0]: latents} + + +class HunyuanVideoModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [ + LlamaProcessor(["encoder_hidden_states", "encoder_attention_mask"]), + CLIPPooledProcessor( + ["pooled_projections"], + input_names={"tokenizer_2": 
"tokenizer", "text_encoder_2": "text_encoder"}, + ), + ] + if latent_model_processors is None: + latent_model_processors = [HunyuanLatentEncodeProcessor(["latents"])] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + # TODO + return { + "latents": (2, 3, 4), + } + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=self.revision, + cache_dir=self.cache_dir, + ) + + if self.tokenizer_2_id is not None: + tokenizer_2 = CLIPTokenizer.from_pretrained( + self.tokenizer_2_id, revision=self.revision, cache_dir=self.cache_dir + ) + else: + tokenizer_2 = CLIPTokenizer.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=self.revision, + cache_dir=self.cache_dir, + ) + + if self.text_encoder_id is not None: + text_encoder = LlamaModel.from_pretrained( + self.text_encoder_id, + torch_dtype=self.text_encoder_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + text_encoder = LlamaModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + if self.text_encoder_2_id is not None: + text_encoder_2 = CLIPTextModel.from_pretrained( + self.text_encoder_2_id, + torch_dtype=self.text_encoder_2_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + text_encoder_2 = CLIPTextModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder_2", + torch_dtype=self.text_encoder_2_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + return { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + } + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + if self.vae_id is not None: + vae = AutoencoderKLHunyuanVideo.from_pretrained( + self.vae_id, + torch_dtype=self.vae_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + vae = AutoencoderKLHunyuanVideo.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=self.vae_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + return {"vae": vae} + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + if self.transformer_id is not None: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + self.transformer_id, + torch_dtype=self.transformer_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + transformer = HunyuanVideoTransformer3DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[AutoTokenizer] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + text_encoder: Optional[LlamaModel] = None, + text_encoder_2: Optional[CLIPTextModel] = None, + transformer: Optional[HunyuanVideoTransformer3DModel] = None, + vae: 
Optional[AutoencoderKLHunyuanVideo] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> HunyuanVideoPipeline: + components = { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = HunyuanVideoPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.text_encoder_2.to(self.text_encoder_2_dtype) + pipe.vae.to(self.vae_dtype) + + if not training: + pipe.transformer.to(self.transformer_dtype) + + if enable_slicing: + pipe.vae.enable_slicing() + if enable_tiling: + pipe.vae.enable_tiling() + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: AutoTokenizer, + tokenizer_2: CLIPTokenizer, + text_encoder: LlamaModel, + text_encoder_2: CLIPTextModel, + caption: str, + max_sequence_length: int = 256, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "text_encoder": text_encoder, + "text_encoder_2": text_encoder_2, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKLHunyuanVideo, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image": image, + "video": video, + "generator": generator, + "compute_posterior": compute_posterior, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + def forward( + self, + transformer: HunyuanVideoTransformer3DModel, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + guidance: float = 1.0, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + if compute_posterior: + latents = latent_model_conditions.pop("latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) + latents = posterior.sample(generator=generator) + del posterior + + latents = latents * self.vae_config.scaling_factor + noise = torch.zeros_like(latents).normal_(generator=generator) + noisy_latents = FF.flow_match_xt(latents, noise, sigmas) + + timesteps = (sigmas.flatten() * 1000.0).long() + guidance = latents.new_full((latents.size(0),), fill_value=guidance) * 1000.0 + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + latent_model_conditions["guidance"] = guidance + + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + return_dict=False, + )[0] + target = 
FF.flow_match_target(noise, latents) + + return pred, target, sigmas + + def validation( + self, + pipeline: HunyuanVideoPipeline, + prompt: str, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + generation_kwargs = { + "prompt": prompt, + "height": height, + "width": width, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + video = pipeline(**generation_kwargs).frames[0] + return [data.VideoArtifact(value=video)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + HunyuanVideoPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: HunyuanVideoTransformer3DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = HunyuanVideoTransformer3DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) diff --git a/finetrainers/models/hunyuan_video/full_finetune.py b/finetrainers/models/hunyuan_video/full_finetune.py deleted file mode 100644 index 65e73f5451cacbc4ea4540b6bfcd3c3bd1b9e531..0000000000000000000000000000000000000000 --- a/finetrainers/models/hunyuan_video/full_finetune.py +++ /dev/null @@ -1,30 +0,0 @@ -from diffusers import HunyuanVideoPipeline - -from .lora import ( - collate_fn_t2v, - forward_pass, - initialize_pipeline, - load_condition_models, - load_diffusion_models, - load_latent_models, - post_latent_preparation, - prepare_conditions, - prepare_latents, - validation, -) - - -# TODO(aryan): refactor into model specs for better re-use -HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG = { - "pipeline_cls": HunyuanVideoPipeline, - "load_condition_models": load_condition_models, - "load_latent_models": load_latent_models, - "load_diffusion_models": load_diffusion_models, - "initialize_pipeline": initialize_pipeline, - "prepare_conditions": prepare_conditions, - "prepare_latents": prepare_latents, - "post_latent_preparation": post_latent_preparation, - "collate_fn": collate_fn_t2v, - "forward_pass": forward_pass, - "validation": validation, -} diff --git a/finetrainers/models/hunyuan_video/lora.py b/finetrainers/models/hunyuan_video/lora.py deleted file mode 100644 index 1d8ccd1f61f3131f9fb0c2ba1235070ce7439ba0..0000000000000000000000000000000000000000 --- a/finetrainers/models/hunyuan_video/lora.py +++ /dev/null @@ -1,368 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from accelerate.logging import get_logger -from diffusers import ( - AutoencoderKLHunyuanVideo, - 
FlowMatchEulerDiscreteScheduler, - HunyuanVideoPipeline, - HunyuanVideoTransformer3DModel, -) -from PIL import Image -from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer - - -logger = get_logger("finetrainers") # pylint: disable=invalid-name - - -def load_condition_models( - model_id: str = "hunyuanvideo-community/HunyuanVideo", - text_encoder_dtype: torch.dtype = torch.float16, - text_encoder_2_dtype: torch.dtype = torch.float16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -) -> Dict[str, nn.Module]: - tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) - text_encoder = LlamaModel.from_pretrained( - model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir - ) - tokenizer_2 = CLIPTokenizer.from_pretrained( - model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir - ) - text_encoder_2 = CLIPTextModel.from_pretrained( - model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir - ) - return { - "tokenizer": tokenizer, - "text_encoder": text_encoder, - "tokenizer_2": tokenizer_2, - "text_encoder_2": text_encoder_2, - } - - -def load_latent_models( - model_id: str = "hunyuanvideo-community/HunyuanVideo", - vae_dtype: torch.dtype = torch.float16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -) -> Dict[str, nn.Module]: - vae = AutoencoderKLHunyuanVideo.from_pretrained( - model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir - ) - return {"vae": vae} - - -def load_diffusion_models( - model_id: str = "hunyuanvideo-community/HunyuanVideo", - transformer_dtype: torch.dtype = torch.bfloat16, - shift: float = 1.0, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]: - transformer = HunyuanVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir - ) - scheduler = FlowMatchEulerDiscreteScheduler(shift=shift) - return {"transformer": transformer, "scheduler": scheduler} - - -def initialize_pipeline( - model_id: str = "hunyuanvideo-community/HunyuanVideo", - text_encoder_dtype: torch.dtype = torch.float16, - text_encoder_2_dtype: torch.dtype = torch.float16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.float16, - tokenizer: Optional[LlamaTokenizer] = None, - text_encoder: Optional[LlamaModel] = None, - tokenizer_2: Optional[CLIPTokenizer] = None, - text_encoder_2: Optional[CLIPTextModel] = None, - transformer: Optional[HunyuanVideoTransformer3DModel] = None, - vae: Optional[AutoencoderKLHunyuanVideo] = None, - scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, - device: Optional[torch.device] = None, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - enable_slicing: bool = False, - enable_tiling: bool = False, - enable_model_cpu_offload: bool = False, - is_training: bool = False, - **kwargs, -) -> HunyuanVideoPipeline: - component_name_pairs = [ - ("tokenizer", tokenizer), - ("text_encoder", text_encoder), - ("tokenizer_2", tokenizer_2), - ("text_encoder_2", text_encoder_2), - ("transformer", transformer), - ("vae", vae), - ("scheduler", scheduler), - ] - components = {} - for name, component in component_name_pairs: - if 
component is not None: - components[name] = component - - pipe = HunyuanVideoPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir) - pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype) - pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=text_encoder_2_dtype) - pipe.vae = pipe.vae.to(dtype=vae_dtype) - - # The transformer should already be in the correct dtype when training, so we don't need to cast it here. - # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during - # DDP optimizer step. - if not is_training: - pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) - - if enable_slicing: - pipe.vae.enable_slicing() - if enable_tiling: - pipe.vae.enable_tiling() - - if enable_model_cpu_offload: - pipe.enable_model_cpu_offload(device=device) - else: - pipe.to(device=device) - - return pipe - - -def prepare_conditions( - tokenizer: LlamaTokenizer, - text_encoder: LlamaModel, - tokenizer_2: CLIPTokenizer, - text_encoder_2: CLIPTextModel, - prompt: Union[str, List[str]], - guidance: float = 1.0, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - max_sequence_length: int = 256, - # TODO(aryan): make configurable - prompt_template: Dict[str, Any] = { - "template": ( - "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " - "1. The main content and theme of the video." - "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." - "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." - "4. background environment, light, style and atmosphere." - "5. camera angles, movements, and transitions used in the video:<|eot_id|>" - "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" - ), - "crop_start": 95, - }, - **kwargs, -) -> torch.Tensor: - device = device or text_encoder.device - dtype = dtype or text_encoder.dtype - - if isinstance(prompt, str): - prompt = [prompt] - - conditions = {} - conditions.update( - _get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length) - ) - conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype)) - - guidance = torch.tensor([guidance], device=device, dtype=dtype) * 1000.0 - conditions["guidance"] = guidance - - return conditions - - -def prepare_latents( - vae: AutoencoderKLHunyuanVideo, - image_or_video: torch.Tensor, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - generator: Optional[torch.Generator] = None, - precompute: bool = False, - **kwargs, -) -> torch.Tensor: - device = device or vae.device - dtype = dtype or vae.dtype - - if image_or_video.ndim == 4: - image_or_video = image_or_video.unsqueeze(2) - assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor" - - image_or_video = image_or_video.to(device=device, dtype=vae.dtype) - image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W] - if not precompute: - latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) - latents = latents * vae.config.scaling_factor - latents = latents.to(dtype=dtype) - return {"latents": latents} - else: - if vae.use_slicing and image_or_video.shape[0] > 1: - encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)] - h = torch.cat(encoded_slices) - else: - h = 
vae._encode(image_or_video) - return {"latents": h} - - -def post_latent_preparation( - vae_config: Dict[str, Any], - latents: torch.Tensor, - **kwargs, -) -> torch.Tensor: - latents = latents * vae_config.scaling_factor - return {"latents": latents} - - -def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: - return { - "prompts": [x["prompt"] for x in batch[0]], - "videos": torch.stack([x["video"] for x in batch[0]]), - } - - -def forward_pass( - transformer: HunyuanVideoTransformer3DModel, - prompt_embeds: torch.Tensor, - pooled_prompt_embeds: torch.Tensor, - prompt_attention_mask: torch.Tensor, - guidance: torch.Tensor, - latents: torch.Tensor, - noisy_latents: torch.Tensor, - timesteps: torch.LongTensor, - **kwargs, -) -> torch.Tensor: - denoised_latents = transformer( - hidden_states=noisy_latents, - timestep=timesteps, - encoder_hidden_states=prompt_embeds, - pooled_projections=pooled_prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - guidance=guidance, - return_dict=False, - )[0] - - return {"latents": denoised_latents} - - -def validation( - pipeline: HunyuanVideoPipeline, - prompt: str, - image: Optional[Image.Image] = None, - video: Optional[List[Image.Image]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - num_videos_per_prompt: int = 1, - generator: Optional[torch.Generator] = None, - **kwargs, -): - generation_kwargs = { - "prompt": prompt, - "height": height, - "width": width, - "num_frames": num_frames, - "num_inference_steps": 30, - "num_videos_per_prompt": num_videos_per_prompt, - "generator": generator, - "return_dict": True, - "output_type": "pil", - } - generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} - output = pipeline(**generation_kwargs).frames[0] - return [("video", output)] - - -def _get_llama_prompt_embeds( - tokenizer: LlamaTokenizer, - text_encoder: LlamaModel, - prompt: List[str], - prompt_template: Dict[str, Any], - device: torch.device, - dtype: torch.dtype, - max_sequence_length: int = 256, - num_hidden_layers_to_skip: int = 2, -) -> Tuple[torch.Tensor, torch.Tensor]: - batch_size = len(prompt) - prompt = [prompt_template["template"].format(p) for p in prompt] - - crop_start = prompt_template.get("crop_start", None) - if crop_start is None: - prompt_template_input = tokenizer( - prompt_template["template"], - padding="max_length", - return_tensors="pt", - return_length=False, - return_overflowing_tokens=False, - return_attention_mask=False, - ) - crop_start = prompt_template_input["input_ids"].shape[-1] - # Remove <|eot_id|> token and placeholder {} - crop_start -= 2 - - max_sequence_length += crop_start - text_inputs = tokenizer( - prompt, - max_length=max_sequence_length, - padding="max_length", - truncation=True, - return_tensors="pt", - return_length=False, - return_overflowing_tokens=False, - return_attention_mask=True, - ) - text_input_ids = text_inputs.input_ids.to(device=device) - prompt_attention_mask = text_inputs.attention_mask.to(device=device) - - prompt_embeds = text_encoder( - input_ids=text_input_ids, - attention_mask=prompt_attention_mask, - output_hidden_states=True, - ).hidden_states[-(num_hidden_layers_to_skip + 1)] - prompt_embeds = prompt_embeds.to(dtype=dtype) - - if crop_start is not None and crop_start > 0: - prompt_embeds = prompt_embeds[:, crop_start:] - prompt_attention_mask = prompt_attention_mask[:, crop_start:] - - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) - 
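# A self-contained sketch of the crop_start logic above (hidden size and shapes are made up;
# 95 matches the hardcoded crop_start of the default prompt template): the template tokens
# prepended to every caption are dropped from the encoder output so that only the caption
# tokens remain in the returned embeddings and attention mask.
import torch

_crop_start, _max_sequence_length, _hidden_size = 95, 256, 4096
_embeds = torch.randn(1, _max_sequence_length + _crop_start, _hidden_size)
_mask = torch.ones(1, _max_sequence_length + _crop_start, dtype=torch.long)
_embeds, _mask = _embeds[:, _crop_start:], _mask[:, _crop_start:]
assert _embeds.shape[1] == _max_sequence_length and _mask.shape[1] == _max_sequence_length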
- return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask} - - -def _get_clip_prompt_embeds( - tokenizer_2: CLIPTokenizer, - text_encoder_2: CLIPTextModel, - prompt: Union[str, List[str]], - device: torch.device, - dtype: torch.dtype, - max_sequence_length: int = 77, -) -> torch.Tensor: - text_inputs = tokenizer_2( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - - prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output - prompt_embeds = prompt_embeds.to(dtype=dtype) - - return {"pooled_prompt_embeds": prompt_embeds} - - -# TODO(aryan): refactor into model specs for better re-use -HUNYUAN_VIDEO_T2V_LORA_CONFIG = { - "pipeline_cls": HunyuanVideoPipeline, - "load_condition_models": load_condition_models, - "load_latent_models": load_latent_models, - "load_diffusion_models": load_diffusion_models, - "initialize_pipeline": initialize_pipeline, - "prepare_conditions": prepare_conditions, - "prepare_latents": prepare_latents, - "post_latent_preparation": post_latent_preparation, - "collate_fn": collate_fn_t2v, - "forward_pass": forward_pass, - "validation": validation, -} diff --git a/finetrainers/models/ltx_video/__init__.py b/finetrainers/models/ltx_video/__init__.py index 69391cdf9157831343a5fb73b237de618e8288bd..ff4e3550d54bb33fac80dd2d075ad2846eeeed46 100644 --- a/finetrainers/models/ltx_video/__init__.py +++ b/finetrainers/models/ltx_video/__init__.py @@ -1,2 +1 @@ -from .full_finetune import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG -from .lora import LTX_VIDEO_T2V_LORA_CONFIG +from .base_specification import LTXVideoModelSpecification diff --git a/finetrainers/models/ltx_video/base_specification.py b/finetrainers/models/ltx_video/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..37f6ceced22e53a5383c1fb42c7f6b4d2dc08ad2 --- /dev/null +++ b/finetrainers/models/ltx_video/base_specification.py @@ -0,0 +1,522 @@ +import os +import random +from typing import Any, Dict, List, Optional, Tuple + +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXImageToVideoPipeline, + LTXPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from PIL.Image import Image +from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer + +from ... import data +from ... import functional as FF +from ...logging import get_logger +from ...parallel import ParallelBackendEnum +from ...processors import ProcessorMixin, T5Processor +from ...typing import ArtifactType, SchedulerType +from ...utils import get_non_null_items +from ..modeling_utils import ModelSpecification + + +logger = get_logger() + + +class LTXLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the LTX VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + - num_frames: The number of frames in the input video. + - height: The height of the input image/video. + - width: The width of the input image/video. + - latents_mean: The latent channel means from the VAE state dict. + - latents_std: The latent channel standard deviations from the VAE state dict. 
+ """ + + def __init__(self, output_names: List[str]): + super().__init__() + self.output_names = output_names + assert len(self.output_names) == 6 + + def forward( + self, + vae: AutoencoderKLLTXVideo, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if image is not None: + video = image.unsqueeze(1) + + assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" + video = video.to(device=device, dtype=vae.dtype) + video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] + + if compute_posterior: + latents = vae.encode(video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + if vae.use_slicing and video.shape[0] > 1: + encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] + moments = torch.cat(encoded_slices) + else: + moments = vae._encode(video) + latents = moments.to(dtype=dtype) + + _, _, num_frames, height, width = latents.shape + + return { + self.output_names[0]: latents, + self.output_names[1]: num_frames, + self.output_names[2]: height, + self.output_names[3]: width, + self.output_names[4]: vae.latents_mean, + self.output_names[5]: vae.latents_std, + } + + +class LTXVideoModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "Lightricks/LTX-Video", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])] + if latent_model_processors is None: + latent_model_processors = [ + LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"]) + ] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + return { + "latents": (2, 3, 4), + } + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir + ) + else: + tokenizer = T5Tokenizer.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=self.revision, + cache_dir=self.cache_dir, + ) + + if self.text_encoder_id is not None: + text_encoder = AutoModel.from_pretrained( + self.text_encoder_id, + torch_dtype=self.text_encoder_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + text_encoder = T5EncoderModel.from_pretrained( + 
self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + return {"tokenizer": tokenizer, "text_encoder": text_encoder} + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + if self.vae_id is not None: + vae = AutoencoderKLLTXVideo.from_pretrained( + self.vae_id, + torch_dtype=self.vae_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + vae = AutoencoderKLLTXVideo.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=self.vae_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + return {"vae": vae} + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + if self.transformer_id is not None: + transformer = LTXVideoTransformer3DModel.from_pretrained( + self.transformer_id, + torch_dtype=self.transformer_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + transformer = LTXVideoTransformer3DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[T5Tokenizer] = None, + text_encoder: Optional[T5EncoderModel] = None, + transformer: Optional[LTXVideoTransformer3DModel] = None, + vae: Optional[AutoencoderKLLTXVideo] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> LTXPipeline: + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = LTXPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.vae.to(self.vae_dtype) + + if not training: + pipe.transformer.to(self.transformer_dtype) + + if enable_slicing: + pipe.vae.enable_slicing() + if enable_tiling: + pipe.vae.enable_tiling() + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + caption: str, + max_sequence_length: int = 128, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKLLTXVideo, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image": image, + "video": video, + "generator": generator, + "compute_posterior": compute_posterior, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in 
conditions.items() if k not in input_keys} + return conditions + + def forward( + self, + transformer: LTXVideoTransformer3DModel, + condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + # TODO(aryan): make this configurable? Should it be? + first_frame_conditioning_p = 0.1 + min_first_frame_sigma = 0.25 + + if compute_posterior: + latents = latent_model_conditions.pop("latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) + latents = posterior.sample(generator=generator) + del posterior + + latents_mean = latent_model_conditions.pop("latents_mean") + latents_std = latent_model_conditions.pop("latents_std") + + latents = self._normalize_latents(latents, latents_mean, latents_std) + noise = torch.zeros_like(latents).normal_(generator=generator) + + if random.random() < first_frame_conditioning_p: + # Based on Section 2.4 of the paper, it mentions that the first frame timesteps should be a small random value. + # Making as estimated guess, we limit the sigmas to be at least 0.2. + # torch.rand_like returns values in [0, 1). We want to make sure that the first frame sigma is <= actual sigmas + # for image conditioning. In order to do this, we rescale by multiplying with sigmas so the range is [0, sigmas). + first_frame_sigma = torch.rand_like(sigmas) * sigmas + first_frame_sigma = torch.min(first_frame_sigma, sigmas.new_full(sigmas.shape, min_first_frame_sigma)) + + latents_first_frame, latents_rest = latents[:, :, :1], latents[:, :, 1:] + noisy_latents_first_frame = FF.flow_match_xt(latents_first_frame, noise[:, :, :1], first_frame_sigma) + noisy_latents_remaining = FF.flow_match_xt(latents_rest, noise[:, :, 1:], sigmas) + noisy_latents = torch.cat([noisy_latents_first_frame, noisy_latents_remaining], dim=2) + else: + noisy_latents = FF.flow_match_xt(latents, noise, sigmas) + + patch_size = self.transformer_config.patch_size + patch_size_t = self.transformer_config.patch_size_t + + latents = self._pack_latents(latents, patch_size, patch_size_t) + noise = self._pack_latents(noise, patch_size, patch_size_t) + noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t) + + sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1) + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds") + condition_model_conditions["encoder_attention_mask"] = condition_model_conditions.pop("prompt_attention_mask") + + # TODO(aryan): make this configurable + frame_rate = 25 + temporal_compression_ratio = 8 + vae_spatial_compression_ratio = 32 + latent_frame_rate = frame_rate / temporal_compression_ratio + + rope_interpolation_scale = [ + 1 / latent_frame_rate, + vae_spatial_compression_ratio, + vae_spatial_compression_ratio, + ] + timesteps = (sigmas * 1000.0).long() + + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + rope_interpolation_scale=rope_interpolation_scale, + return_dict=False, + )[0] + target = FF.flow_match_target(noise, latents) + + return pred, target, sigmas + + def validation( + self, + pipeline: LTXPipeline, + prompt: str, + image: Optional[Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] 
= None, + frame_rate: int = 25, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + if image is not None: + pipeline = LTXImageToVideoPipeline.from_pipe(pipeline) + + generation_kwargs = { + "prompt": prompt, + "image": image, + "height": height, + "width": width, + "num_frames": num_frames, + "frame_rate": frame_rate, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + video = pipeline(**generation_kwargs).frames[0] + return [data.VideoArtifact(value=video)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + LTXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: LTXVideoTransformer3DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = LTXVideoTransformer3DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def apply_tensor_parallel( + self, + backend: ParallelBackendEnum, + device_mesh: torch.distributed.DeviceMesh, + transformer: LTXVideoTransformer3DModel, + **kwargs, + ) -> None: + if backend == ParallelBackendEnum.PTD: + _apply_tensor_parallel_ptd(device_mesh, transformer) + else: + raise NotImplementedError(f"Parallel backend {backend} is not supported for LTXVideoModelSpecification") + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). 
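# A runnable shape check for the packing described above (dimensions are made up; patch sizes
# follow the p / p_t notation in the comment):
import torch

_B, _C, _F, _H, _W = 1, 128, 8, 16, 16
_p, _p_t = 2, 1
_latents = torch.randn(_B, _C, _F, _H, _W)
_packed = _latents.reshape(_B, -1, _F // _p_t, _p_t, _H // _p, _p, _W // _p, _p)
_packed = _packed.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
assert _packed.shape == (_B, (_F // _p_t) * (_H // _p) * (_W // _p), _C * _p_t * _p * _p)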
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + +def _apply_tensor_parallel_ptd( + device_mesh: torch.distributed.device_mesh.DeviceMesh, transformer: LTXVideoTransformer3DModel +) -> None: + from torch.distributed.tensor.parallel import parallelize_module + from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel + + transformer_plan = { + # ===== Condition embeddings ===== + # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(), + # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)), + # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()), + # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())), + # "caption_projection.linear_1": ColwiseParallel(), + # "caption_projection.linear_2": RowwiseParallel(), + # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False), + # ===== ===== + } + + for block in transformer.transformer_blocks: + block_plan = {} + + # ===== Attention ===== + # 8 all-to-all, 3 all-reduce + # block_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False) + # block_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False) + # block_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False) + # block_plan["attn1.norm_q"] = SequenceParallel() + # block_plan["attn1.norm_k"] = SequenceParallel() + # block_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1)) + # block_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False) + # block_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False) + # block_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False) + # block_plan["attn2.norm_q"] = SequenceParallel() + # block_plan["attn2.norm_k"] = SequenceParallel() + # block_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1)) + # ===== ===== + + block_plan["ff.net.0.proj"] = ColwiseParallel() + block_plan["ff.net.2"] = RowwiseParallel() + + parallelize_module(block, device_mesh, block_plan) + + parallelize_module(transformer, device_mesh, transformer_plan) diff --git a/finetrainers/models/ltx_video/full_finetune.py b/finetrainers/models/ltx_video/full_finetune.py deleted file mode 100644 index ca799ea6f1b4b075efa9f2c27bb69564832bcb7d..0000000000000000000000000000000000000000 --- a/finetrainers/models/ltx_video/full_finetune.py +++ /dev/null @@ -1,30 +0,0 @@ -from diffusers import LTXPipeline - -from .lora import ( - collate_fn_t2v, - forward_pass, - initialize_pipeline, - load_condition_models, - load_diffusion_models, - load_latent_models, - post_latent_preparation, - prepare_conditions, - prepare_latents, - validation, -) - - -# TODO(aryan): refactor into model specs for better re-use -LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG = { - "pipeline_cls": LTXPipeline, - "load_condition_models": load_condition_models, - "load_latent_models": 
load_latent_models, - "load_diffusion_models": load_diffusion_models, - "initialize_pipeline": initialize_pipeline, - "prepare_conditions": prepare_conditions, - "prepare_latents": prepare_latents, - "post_latent_preparation": post_latent_preparation, - "collate_fn": collate_fn_t2v, - "forward_pass": forward_pass, - "validation": validation, -} diff --git a/finetrainers/models/ltx_video/lora.py b/finetrainers/models/ltx_video/lora.py deleted file mode 100644 index bdd6ffa3e3b91564ff88222a2314f66c6e465116..0000000000000000000000000000000000000000 --- a/finetrainers/models/ltx_video/lora.py +++ /dev/null @@ -1,331 +0,0 @@ -from typing import Dict, List, Optional, Union - -import torch -import torch.nn as nn -from accelerate.logging import get_logger -from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel -from PIL import Image -from transformers import T5EncoderModel, T5Tokenizer - - -logger = get_logger("finetrainers") # pylint: disable=invalid-name - - -def load_condition_models( - model_id: str = "Lightricks/LTX-Video", - text_encoder_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -) -> Dict[str, nn.Module]: - tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir) - text_encoder = T5EncoderModel.from_pretrained( - model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir - ) - return {"tokenizer": tokenizer, "text_encoder": text_encoder} - - -def load_latent_models( - model_id: str = "Lightricks/LTX-Video", - vae_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -) -> Dict[str, nn.Module]: - vae = AutoencoderKLLTXVideo.from_pretrained( - model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir - ) - return {"vae": vae} - - -def load_diffusion_models( - model_id: str = "Lightricks/LTX-Video", - transformer_dtype: torch.dtype = torch.bfloat16, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - **kwargs, -) -> Dict[str, nn.Module]: - transformer = LTXVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir - ) - scheduler = FlowMatchEulerDiscreteScheduler() - return {"transformer": transformer, "scheduler": scheduler} - - -def initialize_pipeline( - model_id: str = "Lightricks/LTX-Video", - text_encoder_dtype: torch.dtype = torch.bfloat16, - transformer_dtype: torch.dtype = torch.bfloat16, - vae_dtype: torch.dtype = torch.bfloat16, - tokenizer: Optional[T5Tokenizer] = None, - text_encoder: Optional[T5EncoderModel] = None, - transformer: Optional[LTXVideoTransformer3DModel] = None, - vae: Optional[AutoencoderKLLTXVideo] = None, - scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, - device: Optional[torch.device] = None, - revision: Optional[str] = None, - cache_dir: Optional[str] = None, - enable_slicing: bool = False, - enable_tiling: bool = False, - enable_model_cpu_offload: bool = False, - is_training: bool = False, - **kwargs, -) -> LTXPipeline: - component_name_pairs = [ - ("tokenizer", tokenizer), - ("text_encoder", text_encoder), - ("transformer", transformer), - ("vae", vae), - ("scheduler", scheduler), - ] - components = {} - for name, component in component_name_pairs: - if component is not None: - components[name] = component - 
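# The filtering loop above is the pattern that the refactor replaces with `get_non_null_items`
# from finetrainers.utils. A sketch of the behaviour assumed throughout the new model
# specifications (illustrative only, not the library implementation):
def _non_null_items_sketch(mapping):
    return {key: value for key, value in mapping.items() if value is not None}

assert _non_null_items_sketch({"vae": None, "scheduler": "keep"}) == {"scheduler": "keep"}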
- pipe = LTXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir) - pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype) - pipe.vae = pipe.vae.to(dtype=vae_dtype) - # The transformer should already be in the correct dtype when training, so we don't need to cast it here. - # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during - # DDP optimizer step. - if not is_training: - pipe.transformer = pipe.transformer.to(dtype=transformer_dtype) - - if enable_slicing: - pipe.vae.enable_slicing() - if enable_tiling: - pipe.vae.enable_tiling() - - if enable_model_cpu_offload: - pipe.enable_model_cpu_offload(device=device) - else: - pipe.to(device=device) - - return pipe - - -def prepare_conditions( - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - prompt: Union[str, List[str]], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - max_sequence_length: int = 128, - **kwargs, -) -> torch.Tensor: - device = device or text_encoder.device - dtype = dtype or text_encoder.dtype - - if isinstance(prompt, str): - prompt = [prompt] - - return _encode_prompt_t5(tokenizer, text_encoder, prompt, device, dtype, max_sequence_length) - - -def prepare_latents( - vae: AutoencoderKLLTXVideo, - image_or_video: torch.Tensor, - patch_size: int = 1, - patch_size_t: int = 1, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - generator: Optional[torch.Generator] = None, - precompute: bool = False, -) -> torch.Tensor: - device = device or vae.device - - if image_or_video.ndim == 4: - image_or_video = image_or_video.unsqueeze(2) - assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor" - - image_or_video = image_or_video.to(device=device, dtype=vae.dtype) - image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W] - if not precompute: - latents = vae.encode(image_or_video).latent_dist.sample(generator=generator) - latents = latents.to(dtype=dtype) - _, _, num_frames, height, width = latents.shape - latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std) - latents = _pack_latents(latents, patch_size, patch_size_t) - return {"latents": latents, "num_frames": num_frames, "height": height, "width": width} - else: - if vae.use_slicing and image_or_video.shape[0] > 1: - encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)] - h = torch.cat(encoded_slices) - else: - h = vae._encode(image_or_video) - _, _, num_frames, height, width = h.shape - - # TODO(aryan): This is very stupid that we might possibly be storing the latents_mean and latents_std in every file - # if precomputation is enabled. We should probably have a single file where re-usable properties like this are stored - # so as to reduce the disk memory requirements of the precomputed files. 
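# A small sketch of how the statistics returned below are applied later (mirrors
# _normalize_latents; tensor sizes are made up):
import torch

_latents = torch.randn(1, 128, 8, 16, 16)                # [B, C, F, H, W]
_mean = torch.zeros(128).view(1, -1, 1, 1, 1)            # per-channel stats from the VAE state dict
_std = torch.ones(128).view(1, -1, 1, 1, 1)
_normalized = (_latents - _mean) * 1.0 / _std            # scaling_factor defaults to 1.0
assert _normalized.shape == _latents.shape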
- return { - "latents": h, - "num_frames": num_frames, - "height": height, - "width": width, - "latents_mean": vae.latents_mean, - "latents_std": vae.latents_std, - } - - -def post_latent_preparation( - latents: torch.Tensor, - latents_mean: torch.Tensor, - latents_std: torch.Tensor, - num_frames: int, - height: int, - width: int, - patch_size: int = 1, - patch_size_t: int = 1, - **kwargs, -) -> torch.Tensor: - latents = _normalize_latents(latents, latents_mean, latents_std) - latents = _pack_latents(latents, patch_size, patch_size_t) - return {"latents": latents, "num_frames": num_frames, "height": height, "width": width} - - -def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: - return { - "prompts": [x["prompt"] for x in batch[0]], - "videos": torch.stack([x["video"] for x in batch[0]]), - } - - -def forward_pass( - transformer: LTXVideoTransformer3DModel, - prompt_embeds: torch.Tensor, - prompt_attention_mask: torch.Tensor, - latents: torch.Tensor, - noisy_latents: torch.Tensor, - timesteps: torch.LongTensor, - num_frames: int, - height: int, - width: int, - **kwargs, -) -> torch.Tensor: - # TODO(aryan): make configurable - frame_rate = 25 - latent_frame_rate = frame_rate / 8 - spatial_compression_ratio = 32 - rope_interpolation_scale = [1 / latent_frame_rate, spatial_compression_ratio, spatial_compression_ratio] - - denoised_latents = transformer( - hidden_states=noisy_latents, - encoder_hidden_states=prompt_embeds, - timestep=timesteps, - encoder_attention_mask=prompt_attention_mask, - num_frames=num_frames, - height=height, - width=width, - rope_interpolation_scale=rope_interpolation_scale, - return_dict=False, - )[0] - - return {"latents": denoised_latents} - - -def validation( - pipeline: LTXPipeline, - prompt: str, - image: Optional[Image.Image] = None, - video: Optional[List[Image.Image]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_frames: Optional[int] = None, - frame_rate: int = 24, - num_videos_per_prompt: int = 1, - generator: Optional[torch.Generator] = None, - **kwargs, -): - generation_kwargs = { - "prompt": prompt, - "height": height, - "width": width, - "num_frames": num_frames, - "frame_rate": frame_rate, - "num_videos_per_prompt": num_videos_per_prompt, - "generator": generator, - "return_dict": True, - "output_type": "pil", - } - generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None} - video = pipeline(**generation_kwargs).frames[0] - return [("video", video)] - - -def _encode_prompt_t5( - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - prompt: List[str], - device: torch.device, - dtype: torch.dtype, - max_sequence_length, -) -> torch.Tensor: - batch_size = len(prompt) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - prompt_attention_mask = text_inputs.attention_mask - prompt_attention_mask = prompt_attention_mask.bool().to(device) - - prompt_embeds = text_encoder(text_input_ids.to(device))[0] - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) - - return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask} - - -def _normalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 -) -> torch.Tensor: - # Normalize latents across the 
channel dimension [B, C, F, H, W] - latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - latents = (latents - latents_mean) * scaling_factor / latents_std - return latents - - -def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: - # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. - # The patch dimensions are then permuted and collapsed into the channel dimension of shape: - # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). - # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features - batch_size, num_channels, num_frames, height, width = latents.shape - post_patch_num_frames = num_frames // patch_size_t - post_patch_height = height // patch_size - post_patch_width = width // patch_size - latents = latents.reshape( - batch_size, - -1, - post_patch_num_frames, - patch_size_t, - post_patch_height, - patch_size, - post_patch_width, - patch_size, - ) - latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) - return latents - - -LTX_VIDEO_T2V_LORA_CONFIG = { - "pipeline_cls": LTXPipeline, - "load_condition_models": load_condition_models, - "load_latent_models": load_latent_models, - "load_diffusion_models": load_diffusion_models, - "initialize_pipeline": initialize_pipeline, - "prepare_conditions": prepare_conditions, - "prepare_latents": prepare_latents, - "post_latent_preparation": post_latent_preparation, - "collate_fn": collate_fn_t2v, - "forward_pass": forward_pass, - "validation": validation, -} diff --git a/finetrainers/models/modeling_utils.py b/finetrainers/models/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..805b0d94e30c9d56157634740672b27f642993c4 --- /dev/null +++ b/finetrainers/models/modeling_utils.py @@ -0,0 +1,292 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from diffusers import DiffusionPipeline +from diffusers.configuration_utils import FrozenDict +from PIL.Image import Image + +from ..logging import get_logger +from ..parallel import ParallelBackendEnum +from ..processors import ProcessorMixin +from ..typing import ArtifactType, SchedulerType, TokenizerType +from ..utils import resolve_component_cls + + +logger = get_logger() + +# TODO(aryan): we most likely don't need this. take a look after refactoring more +# fmt: off +IGNORE_KEYS_FOR_COLLATION = {"height", "width", "num_frames", "frame_rate", "rope_interpolation_scale", "return_dict", "attention_kwargs", "cross_attention_kwargs", "joint_attention_kwargs", "latents_mean", "latents_std"} +# fmt: on + + +class ModelSpecification: + r""" + The ModelSpecification class is an interface to be used for Diffusion training recipes. It provides + loose structure about how to organize the code for training. The trainer implementations will + make use of this interface to load models, prepare conditions, prepare latents, forward pass, etc. 
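+
+    Subclasses are expected to override the `load_*`, `prepare_*`, `forward` and `validation` methods declared
+    below; anything left unimplemented raises `NotImplementedError`. As a rough, illustrative sketch (the class
+    name, method bodies and returned component names here are placeholders, not part of this interface):
+
+        class MyModelSpecification(ModelSpecification):
+            def load_condition_models(self):
+                return {"tokenizer": ..., "text_encoder": ...}
+
+            def load_latent_models(self):
+                return {"vae": ...}
+
+            def load_diffusion_models(self):
+                return {"transformer": ..., "scheduler": ...}
+
+            def forward(self, transformer, **kwargs):
+                # compute the model prediction from the prepared conditions/latents
+                ...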
+ """ + + def __init__( + self, + pretrained_model_name_or_path: Optional[str] = None, + tokenizer_id: Optional[str] = None, + tokenizer_2_id: Optional[str] = None, + tokenizer_3_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + text_encoder_2_id: Optional[str] = None, + text_encoder_3_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + text_encoder_2_dtype: torch.dtype = torch.bfloat16, + text_encoder_3_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: str = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + ) -> None: + self.pretrained_model_name_or_path = pretrained_model_name_or_path + self.tokenizer_id = tokenizer_id + self.tokenizer_2_id = tokenizer_2_id + self.tokenizer_3_id = tokenizer_3_id + self.text_encoder_id = text_encoder_id + self.text_encoder_2_id = text_encoder_2_id + self.text_encoder_3_id = text_encoder_3_id + self.transformer_id = transformer_id + self.vae_id = vae_id + self.text_encoder_dtype = text_encoder_dtype + self.text_encoder_2_dtype = text_encoder_2_dtype + self.text_encoder_3_dtype = text_encoder_3_dtype + self.transformer_dtype = transformer_dtype + self.vae_dtype = vae_dtype + self.revision = revision + self.cache_dir = cache_dir + self.condition_model_processors = condition_model_processors or [] + self.latent_model_processors = latent_model_processors or [] + + self.transformer_config: Dict[str, Any] = None + self.vae_config: Dict[str, Any] = None + + self._load_configs() + + # TODO(aryan): revisit how to do this better without user having to worry about it + @property + def _resolution_dim_keys(self) -> Dict[str, Tuple[int, ...]]: + raise NotImplementedError( + f"ModelSpecification::_resolution_dim_keys is not implemented for {self.__class__.__name__}" + ) + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + raise NotImplementedError( + f"ModelSpecification::load_condition_models is not implemented for {self.__class__.__name__}" + ) + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + raise NotImplementedError( + f"ModelSpecification::load_latent_models is not implemented for {self.__class__.__name__}" + ) + + def load_diffusion_models(self) -> Dict[str, Union[torch.nn.Module]]: + raise NotImplementedError( + f"ModelSpecification::load_diffusion_models is not implemented for {self.__class__.__name__}" + ) + + def load_pipeline( + self, + tokenizer: Optional[TokenizerType] = None, + tokenizer_2: Optional[TokenizerType] = None, + tokenizer_3: Optional[TokenizerType] = None, + text_encoder: Optional[torch.nn.Module] = None, + text_encoder_2: Optional[torch.nn.Module] = None, + text_encoder_3: Optional[torch.nn.Module] = None, + transformer: Optional[torch.nn.Module] = None, + vae: Optional[torch.nn.Module] = None, + scheduler: Optional[SchedulerType] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> DiffusionPipeline: + raise NotImplementedError( + f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}" + ) + + def collate_fn(self, batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]: + raise NotImplementedError(f"ModelSpecification::collate_fn is not 
implemented for {self.__class__.__name__}") + + def prepare_conditions(self, **kwargs) -> Dict[str, Any]: + for processor in self.condition_model_processors: + result = processor(**kwargs) + result_keys = set(result.keys()) + repeat_keys = result_keys.intersection(kwargs.keys()) + if repeat_keys: + logger.warning( + f"Processor {processor.__class__.__name__} returned keys that already exist in " + f"conditions: {repeat_keys}. Overwriting the existing values, but this may not " + f"be intended. Please rename the keys in the processor to avoid conflicts." + ) + kwargs.update(result) + return kwargs + + def prepare_latents(self, **kwargs) -> Dict[str, Any]: + for processor in self.latent_model_processors: + result = processor(**kwargs) + result_keys = set(result.keys()) + repeat_keys = result_keys.intersection(kwargs.keys()) + if repeat_keys: + logger.warning( + f"Processor {processor.__class__.__name__} returned keys that already exist in " + f"conditions: {repeat_keys}. Overwriting the existing values, but this may not " + f"be intended. Please rename the keys in the processor to avoid conflicts." + ) + kwargs.update(result) + return kwargs + + def collate_conditions(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + keys = list(data[0].keys()) + collated_data = {} + for key in keys: + if key in IGNORE_KEYS_FOR_COLLATION: + collated_data[key] = data[0][key] + continue + collated_d = [d[key] for d in data] + if isinstance(collated_d[0], torch.Tensor): + collated_d = torch.cat(collated_d) + collated_data[key] = collated_d + return collated_data + + def collate_latents(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: + keys = list(data[0].keys()) + collated_data = {} + for key in keys: + if key in IGNORE_KEYS_FOR_COLLATION: + collated_data[key] = data[0][key] + continue + collated_d = [d[key] for d in data] + # TODO(aryan): Support multi-resolution collation + if isinstance(collated_d[0], torch.Tensor): + collated_d = torch.cat(collated_d) + collated_data[key] = collated_d + return collated_data + + def forward( + self, transformer: torch.nn.Module, generator: Optional[torch.Generator] = None, **kwargs + ) -> Dict[str, torch.Tensor]: + raise NotImplementedError(f"ModelSpecification::forward is not implemented for {self.__class__.__name__}") + + def validation( + self, + pipeline: DiffusionPipeline, + prompt: Optional[str] = None, + image: Optional[Image] = None, + video: Optional[List[Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + frame_rate: Optional[int] = None, + generator: Optional[torch.Generator] = None, + ) -> List[ArtifactType]: + raise NotImplementedError(f"ModelSpecification::validation is not implemented for {self.__class__.__name__}") + + def _save_lora_weights( + self, + directory: str, + transformer: torch.nn.Module, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + r""" + Save the lora state dicts of the model to the given directory. + + This API is not backwards compatible and will be changed in near future. + """ + raise NotImplementedError( + f"ModelSpecification::save_lora_weights is not implemented for {self.__class__.__name__}" + ) + + def _save_model( + self, + directory: str, + transformer: torch.nn.Module, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + r""" + Save the state dicts to the given directory. 
+ + This API is not backwards compatible and will be changed in near future. + """ + raise NotImplementedError(f"ModelSpecification::save_model is not implemented for {self.__class__.__name__}") + + def apply_tensor_parallel( + self, + backend: ParallelBackendEnum, + device_mesh: torch.distributed.DeviceMesh, + text_encoder: torch.nn.Module, + text_encoder_2: torch.nn.Module, + text_encoder_3: torch.nn.Module, + transformer: torch.nn.Module, + vae: torch.nn.Module, + ) -> None: + raise NotImplementedError( + f"ModelSpecification::apply_tensor_parallel is not implemented for {self.__class__.__name__}" + ) + + def _load_configs(self) -> None: + self._load_transformer_config() + self._load_vae_config() + + def _load_transformer_config(self) -> None: + if self.transformer_id is not None: + transformer_cls = resolve_component_cls( + self.transformer_id, + component_name="_class_name", + filename="config.json", + revision=self.revision, + cache_dir=self.cache_dir, + ) + self.transformer_config = transformer_cls.load_config( + self.transformer_id, revision=self.revision, cache_dir=self.cache_dir + ) + else: + transformer_cls = resolve_component_cls( + self.pretrained_model_name_or_path, + component_name="transformer", + filename="model_index.json", + revision=self.revision, + cache_dir=self.cache_dir, + ) + self.transformer_config = transformer_cls.load_config( + self.pretrained_model_name_or_path, + subfolder="transformer", + revision=self.revision, + cache_dir=self.cache_dir, + ) + self.transformer_config = FrozenDict(**self.transformer_config) + + def _load_vae_config(self) -> None: + if self.vae_id is not None: + vae_cls = resolve_component_cls( + self.vae_id, + component_name="_class_name", + filename="config.json", + revision=self.revision, + cache_dir=self.cache_dir, + ) + self.vae_config = vae_cls.load_config(self.vae_id, revision=self.revision, cache_dir=self.cache_dir) + else: + vae_cls = resolve_component_cls( + self.pretrained_model_name_or_path, + component_name="vae", + filename="model_index.json", + revision=self.revision, + cache_dir=self.cache_dir, + ) + self.vae_config = vae_cls.load_config( + self.pretrained_model_name_or_path, subfolder="vae", revision=self.revision, cache_dir=self.cache_dir + ) + self.vae_config = FrozenDict(**self.vae_config) diff --git a/finetrainers/models/utils.py b/finetrainers/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aeda1e4379dfc6ab1d7baba1807f6e0ac71d779b --- /dev/null +++ b/finetrainers/models/utils.py @@ -0,0 +1,62 @@ +from typing import Optional, Tuple + +import numpy as np +import torch +from diffusers.utils.torch_utils import randn_tensor + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False, _dim: int = 1): + # Note: _dim is the new argument added here after copying from diffusers + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=_dim) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + 
device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean diff --git a/finetrainers/models/wan/__init__.py b/finetrainers/models/wan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bfeae2994e6b83b8fcb1337602e5cb39c73fdc7 --- /dev/null +++ b/finetrainers/models/wan/__init__.py @@ -0,0 +1 @@ +from .base_specification import WanModelSpecification diff --git a/finetrainers/models/wan/base_specification.py b/finetrainers/models/wan/base_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..2d27f428dde35f9203b8fb889e209c63b0f1e0da --- /dev/null +++ b/finetrainers/models/wan/base_specification.py @@ -0,0 +1,378 @@ +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +from accelerate import init_empty_weights +from diffusers import ( + AutoencoderKLWan, + FlowMatchEulerDiscreteScheduler, + WanImageToVideoPipeline, + WanPipeline, + WanTransformer3DModel, +) +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from PIL.Image import Image +from transformers import AutoModel, AutoTokenizer, UMT5EncoderModel + +from ... import data +from ... import functional as FF +from ...logging import get_logger +from ...processors import ProcessorMixin, T5Processor +from ...typing import ArtifactType, SchedulerType +from ...utils import get_non_null_items +from ..modeling_utils import ModelSpecification + + +logger = get_logger() + + +class WanLatentEncodeProcessor(ProcessorMixin): + r""" + Processor to encode image/video into latents using the Wan VAE. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor returns. The outputs are in the following order: + - latents: The latents of the input image/video. + - num_frames: The number of frames in the input video. + - height: The height of the input image/video. + - width: The width of the input image/video. + - latents_mean: The latent channel means from the VAE state dict. + - latents_std: The latent channel standard deviations from the VAE state dict. 
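+
+    Note: when `compute_posterior` is `False`, the tensor stored under the first output name holds the raw
+    posterior moments returned by `vae._encode` (concatenated mean and log-variance) rather than a sampled
+    latent; the caller is then expected to build a `DiagonalGaussianDistribution` from it and sample, as done
+    in `WanModelSpecification.forward`.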
+ """ + + def __init__(self, output_names: List[str]): + super().__init__() + self.output_names = output_names + assert len(self.output_names) == 1 + + def forward( + self, + vae: AutoencoderKLWan, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + ) -> Dict[str, torch.Tensor]: + device = vae.device + dtype = vae.dtype + + if image is not None: + video = image.unsqueeze(1) + + assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor" + video = video.to(device=device, dtype=vae.dtype) + video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] + + if compute_posterior: + latents = vae.encode(video).latent_dist.sample(generator=generator) + latents = latents.to(dtype=dtype) + else: + # TODO(aryan): refactor in diffusers to have use_slicing attribute + # if vae.use_slicing and video.shape[0] > 1: + # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)] + # moments = torch.cat(encoded_slices) + # else: + # moments = vae._encode(video) + moments = vae._encode(video) + latents = moments.to(dtype=dtype) + + return {self.output_names[0]: latents} + + +class WanModelSpecification(ModelSpecification): + def __init__( + self, + pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + tokenizer_id: Optional[str] = None, + text_encoder_id: Optional[str] = None, + transformer_id: Optional[str] = None, + vae_id: Optional[str] = None, + text_encoder_dtype: torch.dtype = torch.bfloat16, + transformer_dtype: torch.dtype = torch.bfloat16, + vae_dtype: torch.dtype = torch.bfloat16, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, + condition_model_processors: List[ProcessorMixin] = None, + latent_model_processors: List[ProcessorMixin] = None, + **kwargs, + ) -> None: + super().__init__( + pretrained_model_name_or_path=pretrained_model_name_or_path, + tokenizer_id=tokenizer_id, + text_encoder_id=text_encoder_id, + transformer_id=transformer_id, + vae_id=vae_id, + text_encoder_dtype=text_encoder_dtype, + transformer_dtype=transformer_dtype, + vae_dtype=vae_dtype, + revision=revision, + cache_dir=cache_dir, + ) + + if condition_model_processors is None: + condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])] + if latent_model_processors is None: + latent_model_processors = [WanLatentEncodeProcessor(["latents"])] + + self.condition_model_processors = condition_model_processors + self.latent_model_processors = latent_model_processors + + @property + def _resolution_dim_keys(self): + # TODO + return { + "latents": (2, 3, 4), + } + + def load_condition_models(self) -> Dict[str, torch.nn.Module]: + if self.tokenizer_id is not None: + tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=self.revision, + cache_dir=self.cache_dir, + ) + + if self.text_encoder_id is not None: + text_encoder = AutoModel.from_pretrained( + self.text_encoder_id, + torch_dtype=self.text_encoder_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + text_encoder = UMT5EncoderModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=self.text_encoder_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + return {"tokenizer": tokenizer, 
"text_encoder": text_encoder} + + def load_latent_models(self) -> Dict[str, torch.nn.Module]: + if self.vae_id is not None: + vae = AutoencoderKLWan.from_pretrained( + self.vae_id, + torch_dtype=self.vae_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + vae = AutoencoderKLWan.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="vae", + torch_dtype=self.vae_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + return {"vae": vae} + + def load_diffusion_models(self) -> Dict[str, torch.nn.Module]: + if self.transformer_id is not None: + transformer = WanTransformer3DModel.from_pretrained( + self.transformer_id, + torch_dtype=self.transformer_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + else: + transformer = WanTransformer3DModel.from_pretrained( + self.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=self.transformer_dtype, + revision=self.revision, + cache_dir=self.cache_dir, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return {"transformer": transformer, "scheduler": scheduler} + + def load_pipeline( + self, + tokenizer: Optional[AutoTokenizer] = None, + text_encoder: Optional[UMT5EncoderModel] = None, + transformer: Optional[WanTransformer3DModel] = None, + vae: Optional[AutoencoderKLWan] = None, + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, + enable_slicing: bool = False, + enable_tiling: bool = False, + enable_model_cpu_offload: bool = False, + training: bool = False, + **kwargs, + ) -> WanPipeline: + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + } + components = get_non_null_items(components) + + pipe = WanPipeline.from_pretrained( + self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir + ) + pipe.text_encoder.to(self.text_encoder_dtype) + pipe.vae.to(self.vae_dtype) + + if not training: + pipe.transformer.to(self.transformer_dtype) + + # TODO(aryan): add support in diffusers + # if enable_slicing: + # pipe.vae.enable_slicing() + # if enable_tiling: + # pipe.vae.enable_tiling() + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload() + + return pipe + + @torch.no_grad() + def prepare_conditions( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + caption: str, + max_sequence_length: int = 512, + **kwargs, + ) -> Dict[str, Any]: + conditions = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "caption": caption, + "max_sequence_length": max_sequence_length, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_conditions(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + conditions.pop("prompt_attention_mask", None) + return conditions + + @torch.no_grad() + def prepare_latents( + self, + vae: AutoencoderKLWan, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Dict[str, torch.Tensor]: + conditions = { + "vae": vae, + "image": image, + "video": video, + "generator": generator, + "compute_posterior": compute_posterior, + **kwargs, + } + input_keys = set(conditions.keys()) + conditions = super().prepare_latents(**conditions) + conditions = {k: v for k, v in conditions.items() if k not in input_keys} + return conditions + + def forward( + self, + transformer: WanTransformer3DModel, + 
condition_model_conditions: Dict[str, torch.Tensor], + latent_model_conditions: Dict[str, torch.Tensor], + sigmas: torch.Tensor, + generator: Optional[torch.Generator] = None, + compute_posterior: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, ...]: + if compute_posterior: + latents = latent_model_conditions.pop("latents") + else: + posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents")) + latents = posterior.sample(generator=generator) + del posterior + + noise = torch.zeros_like(latents).normal_(generator=generator) + noisy_latents = FF.flow_match_xt(latents, noise, sigmas) + + latent_model_conditions["hidden_states"] = noisy_latents.to(latents) + condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds") + + timesteps = (sigmas.flatten() * 1000.0).long() + + pred = transformer( + **latent_model_conditions, + **condition_model_conditions, + timestep=timesteps, + return_dict=False, + )[0] + target = FF.flow_match_target(noise, latents) + + return pred, target, sigmas + + def validation( + self, + pipeline: WanPipeline, + prompt: str, + image: Optional[Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: Optional[int] = None, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + **kwargs, + ) -> List[ArtifactType]: + if image is not None: + pipeline = WanImageToVideoPipeline.from_pipe(pipeline) + + generation_kwargs = { + "prompt": prompt, + "image": image, + "height": height, + "width": width, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "generator": generator, + "return_dict": True, + "output_type": "pil", + } + generation_kwargs = get_non_null_items(generation_kwargs) + video = pipeline(**generation_kwargs).frames[0] + return [data.VideoArtifact(value=video)] + + def _save_lora_weights( + self, + directory: str, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + *args, + **kwargs, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + WanPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) + + def _save_model( + self, + directory: str, + transformer: WanTransformer3DModel, + transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None, + scheduler: Optional[SchedulerType] = None, + ) -> None: + # TODO(aryan): this needs refactoring + if transformer_state_dict is not None: + with init_empty_weights(): + transformer_copy = WanTransformer3DModel.from_config(transformer.config) + transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True) + transformer_copy.save_pretrained(os.path.join(directory, "transformer")) + if scheduler is not None: + scheduler.save_pretrained(os.path.join(directory, "scheduler")) diff --git a/finetrainers/optimizer.py b/finetrainers/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..57da28e9377f2bf82b5307fae83338ad0b9ec385 --- /dev/null +++ b/finetrainers/optimizer.py @@ -0,0 +1,449 @@ +import functools +import math +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import torch +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_optimizer_state_dict, + set_optimizer_state_dict, +) +from torch.distributed.checkpoint.stateful import Stateful + +from .parallel 
import ParallelBackendEnum
+from .utils.import_utils import is_bitsandbytes_available
+
+
+class OptimizerWrapper(Stateful):
+    r"""
+    Optimizer wrapper that:
+        - allows calling step/zero_grad on multiple optimizers, as needed for virtual pipeline stages
+        - saves and loads the optimizer state_dict when checkpointing
+    """
+
+    def __init__(
+        self,
+        model_parts: List[torch.nn.Module],
+        optimizer_cls: Type[torch.optim.Optimizer],
+        optimizer_kwargs: Dict[str, Any],
+    ) -> None:
+        self.optimizer_cls = optimizer_cls
+        self.optimizer_kwargs = optimizer_kwargs
+
+        self.optimizers = []
+        self.model_parts = model_parts
+
+        for model in self.model_parts:
+            optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
+            self.optimizers.append(optimizer)
+
+    def step(self) -> None:
+        for optimizer in self.optimizers:
+            optimizer.step()
+
+    def zero_grad(self) -> None:
+        for optimizer in self.optimizers:
+            optimizer.zero_grad()
+
+    def state_dict(self) -> Dict[str, Any]:
+        func = functools.partial(
+            get_optimizer_state_dict,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
+        )
+        return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        func = functools.partial(
+            set_optimizer_state_dict,
+            optim_state_dict=state_dict,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
+        )
+        list(map(func, self.model_parts, self.optimizers))
+
+
+class SchedulerWrapper:
+    def __init__(
+        self, optimizers, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int
+    ) -> None:
+        self.schedulers = []
+        for optimizer in optimizers:
+            self.schedulers.append(torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch))
+
+    def step(self) -> None:
+        for scheduler in self.schedulers:
+            scheduler.step()
+
+    def get_last_lr(self) -> Dict[str, List[float]]:
+        # TODO(aryan): look into this later. Currently calling it leads to NCCL hang?????
+        return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)}
+
+    def get_lr_scheduler_state(self) -> Dict[str, Any]:
+        state_dict = {}
+        if len(self.schedulers) == 1:
+            state_dict["lr_scheduler"] = self.schedulers[0]
+        else:
+            # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
+ # It should only support saving and loading a distributed checkpoint with the same number of pp ranks + for idx, lr_scheduler in enumerate(self.schedulers): + state_dict[f"lr_scheduler_{idx}"] = lr_scheduler + return state_dict + + +def get_optimizer( + parallel_backend: ParallelBackendEnum, + name: str, + model_parts: List[torch.nn.Module], + learning_rate: float = 1e-3, + beta1: float = 0.9, + beta2: float = 0.95, + beta3: float = 0.999, + epsilon: float = 1e-8, + weight_decay: float = 1e-4, + fused: bool = False, +) -> Union[torch.optim.Optimizer, OptimizerWrapper]: + name = name.lower() + + _raise_errors_if_packages_not_available(name) + + if name == "adam": + optimizer_cls = torch.optim.Adam + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + "fused": fused, + } + elif name == "adamw": + optimizer_cls = torch.optim.AdamW + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + "fused": fused, + } + elif name == "adam-bnb": + from bitsandbytes.optim import Adam + + optimizer_cls = Adam + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + elif name == "adamw-bnb": + from bitsandbytes.optim import AdamW + + optimizer_cls = AdamW + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + elif name == "adam-bnb-8bit": + from bitsandbytes.optim import Adam8bit + + optimizer_cls = Adam8bit + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + elif name == "adamw-bnb-8bit": + from bitsandbytes.optim import AdamW8bit + + optimizer_cls = AdamW8bit + optimizer_kwargs = { + "lr": learning_rate, + "betas": (beta1, beta2), + "eps": epsilon, + "weight_decay": weight_decay, + } + + # TODO(aryan): handle bitsandbytes and torchao + else: + raise ValueError(f"Unsupported optimizer: {name}") + + if parallel_backend == ParallelBackendEnum.ACCELERATE: + return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs) + elif parallel_backend == ParallelBackendEnum.PTD: + return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs) + + +def get_optimizer_accelerate( + model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any] +) -> torch.optim.Optimizer: + params = [param for model in model_parts for param in model.parameters() if param.requires_grad] + optimizer = optimizer_cls(params, **optimizer_kwargs) + return optimizer + + +def get_optimizer_ptd( + model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any] +) -> OptimizerWrapper: + return OptimizerWrapper(model_parts, optimizer_cls, optimizer_kwargs) + + +def get_lr_scheduler( + parallel_backend: ParallelBackendEnum, + name: str, + optimizer: Union[torch.optim.Optimizer, OptimizerWrapper], + step_rules: Optional[str] = None, + num_warmup_steps: Optional[int] = None, + num_training_steps: Optional[int] = None, + num_cycles: int = 1, + power: float = 1.0, + lr_init: float = 1e-3, + lr_end: float = 1e-7, + last_epoch: int = -1, +) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]: + name = name.lower() + if name == "constant": + scheduler_lambda_fn = get_constant_schedule() + elif name == "constant_with_warmup": + scheduler_lambda_fn = 
get_constant_schedule_with_warmup(num_warmup_steps)
+    elif name == "piecewise_constant":
+        scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules)
+    elif name == "linear":
+        scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps)
+    elif name == "cosine":
+        scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles)
+    elif name == "cosine_with_restarts":
+        scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup(
+            num_warmup_steps, num_training_steps, num_cycles
+        )
+    elif name == "polynomial":
+        scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup(
+            num_warmup_steps, num_training_steps, lr_init, lr_end, power
+        )
+    else:
+        raise ValueError(f"Unsupported scheduler: {name}")
+
+    if parallel_backend == ParallelBackendEnum.ACCELERATE:
+        return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch)
+    elif parallel_backend == ParallelBackendEnum.PTD:
+        return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch)
+
+
+def get_lr_scheduler_accelerate(
+    optimizer: torch.optim.Optimizer,
+    scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler],
+    last_epoch: int = -1,
+) -> torch.optim.lr_scheduler.LambdaLR:
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)
+    return scheduler
+
+
+def get_lr_scheduler_ptd(
+    optimizer: OptimizerWrapper, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int = -1
+) -> SchedulerWrapper:
+    return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch)
+
+
+# ==============================
+# Adapted from https://github.com/huggingface/diffusers/blob/196aef5a6f76e1ad6ba889184860c3633d166910/src/diffusers/optimization.py
+# ==============================
+
+
+def get_constant_schedule() -> Callable[[int], float]:
+    r"""
+    Create a schedule with a constant learning rate, using the learning rate set in the optimizer.
+    """
+
+    def lr_lambda(current_step: int):
+        return 1.0
+
+    return lr_lambda
+
+
+def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]:
+    r"""
+    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+    increases linearly between 0 and the initial lr set in the optimizer.
+
+    Args:
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1.0, num_warmup_steps))
+        return 1.0
+
+    return lr_lambda
+
+
+def get_piecewise_constant_schedule(step_rules: str) -> Callable[[int], float]:
+    r"""
+    Create a schedule with a piecewise constant learning rate, using multipliers of the learning rate set in the
+    optimizer.
+
+    Args:
+        step_rules (`str`):
+            The rules for the learning rate multiplier. For example, `step_rules="1:10,0.1:20,0.01:30,0.005"` means
+            that the learning rate is multiplied by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30,
+            and by 0.005 for all remaining steps.
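+
+            As an illustrative usage sketch (the rule string below is an arbitrary example, not a recommended
+            setting):
+
+                lr_lambda = get_piecewise_constant_schedule("1:1000,0.1:2000,0.01")
+                lr_lambda(500)   # 1.0
+                lr_lambda(1500)  # 0.1
+                lr_lambda(5000)  # 0.01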
+ """ + + rules_dict = {} + rule_list = step_rules.split(",") + for rule_str in rule_list[:-1]: + value_str, steps_str = rule_str.split(":") + steps = int(steps_str) + value = float(value_str) + rules_dict[steps] = value + last_lr_multiple = float(rule_list[-1]) + + def create_rules_function(rules_dict, last_lr_multiple): + def rule_func(steps: int) -> float: + sorted_steps = sorted(rules_dict.keys()) + for i, sorted_step in enumerate(sorted_steps): + if steps < sorted_step: + return rules_dict[sorted_steps[i]] + return last_lr_multiple + + return rule_func + + rules_func = create_rules_function(rules_dict, last_lr_multiple) + return rules_func + + +def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]: + r""" + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + + return lr_lambda + + +def get_cosine_schedule_with_warmup( + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, +) -> Callable[[int], float]: + r""" + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + + Args: + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_periods (`float`, *optional*, defaults to 0.5): + The number of periods of the cosine function in a schedule (the default is to just decrease from the max + value to 0 following a half-cosine). + """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + return lr_lambda + + +def get_cosine_with_hard_restarts_schedule_with_warmup( + num_warmup_steps: int, + num_training_steps: int, + num_cycles: int = 1, +) -> Callable[[int], float]: + r""" + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases + linearly between 0 and the initial lr set in the optimizer. + + Args: + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`int`, *optional*, defaults to 1): + The number of hard restarts to use. 
+ """ + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.0: + return 0.0 + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) + + return lr_lambda + + +def get_polynomial_decay_schedule_with_warmup( + num_warmup_steps: int, + num_training_steps: int, + lr_init: float, + lr_end: float = 1e-7, + power: float = 1.0, +) -> Callable[[int], float]: + r""" + Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the + optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the + initial lr set in the optimizer. + + Args: + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + lr_end (`float`, *optional*, defaults to 1e-7): + The end LR. + power (`float`, *optional*, defaults to 1.0): + Power factor. + + Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at + https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 + """ + + if not (lr_init > lr_end): + raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})") + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # as LambdaLR multiplies by lr_init + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # as LambdaLR multiplies by lr_init + + return lr_lambda + + +def _raise_errors_if_packages_not_available(name: str) -> None: + name_split = name.split("-") + if len(name_split) < 2: + return + package_name = name_split[1] + if package_name == "bnb": + if not is_bitsandbytes_available(): + raise ImportError( + f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer." 
+ ) diff --git a/finetrainers/parallel/__init__.py b/finetrainers/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fbce75fadfc33312f6d655fc581ac90cfb6c577 --- /dev/null +++ b/finetrainers/parallel/__init__.py @@ -0,0 +1,22 @@ +from enum import Enum +from typing import Union + +from .accelerate import AccelerateParallelBackend +from .ptd import PytorchDTensorParallelBackend +from .utils import apply_ddp_ptd, apply_fsdp2_ptd, dist_max, dist_mean + + +ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend] + + +class ParallelBackendEnum(str, Enum): + ACCELERATE = "accelerate" + PTD = "ptd" + + +def get_parallel_backend_cls(backend: ParallelBackendEnum) -> ParallelBackendType: + if backend == ParallelBackendEnum.ACCELERATE: + return AccelerateParallelBackend + if backend == ParallelBackendEnum.PTD: + return PytorchDTensorParallelBackend + raise ValueError(f"Unknown parallel backend: {backend}") diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py new file mode 100644 index 0000000000000000000000000000000000000000..9a523321f0c333af54949cb2274fb2d60cf014ff --- /dev/null +++ b/finetrainers/parallel/accelerate.py @@ -0,0 +1,218 @@ +import datetime +import pathlib +from typing import Optional + +import torch +from diffusers.utils import is_accelerate_available + +from ..logging import get_logger +from ..utils import get_device_info +from .base import BaseParallelBackend +from .utils import apply_ddp_accelerate + + +if not is_accelerate_available(): + raise ImportError( + "Please install the accelerate package using `pip install accelerate` to use the AccelerateParallelBackend." + ) + +from accelerate import Accelerator +from accelerate.data_loader import DataLoader +from accelerate.utils import ( + DataLoaderConfiguration, + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, +) + + +logger = get_logger() +_device_type, _device_module = get_device_info() + + +class AccelerateParallelBackend(BaseParallelBackend): + def __init__( + self, + world_size: int, + pp_degree: int = 1, + dp_degree: int = 1, + dp_shards: int = -1, + cp_degree: int = 1, + tp_degree: int = 1, + backend: str = "nccl", + timeout: int = 180, + logging_dir: Optional[str] = None, + output_dir: Optional[str] = None, + gradient_accumulation_steps: Optional[int] = None, + ) -> None: + super().__init__() + + self._world_size = world_size + self._pp_degree = pp_degree + self._dp_degree = dp_degree + self._dp_shards = dp_shards + self._cp_degree = cp_degree + self._tp_degree = tp_degree + self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None + self._logging_dir = ( + self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None + ) + self._backend = backend + self._timeout = timeout + self._gradient_accumulation_steps = gradient_accumulation_steps + + if pp_degree > 1 or dp_shards > 1 or cp_degree > 1 or tp_degree > 1: + raise ValueError( + "AccelerateParallelBackend does not support anything but Distributed Data Parallelism at the moment." 
+ ) + if dp_degree != world_size: + raise ValueError("Data parallel degree must be equal to world size.") + + self._accelerator: Accelerator = None + self._mesh: torch.distributed.DeviceMesh = None + + def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module: + project_config = None + ddp_kwargs = None + init_process_group_kwargs = None + if self._accelerator is None: + project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir) + ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) + dataloader_config = DataLoaderConfiguration( + split_batches=False, dispatch_batches=False, use_stateful_dataloader=True + ) + init_process_group_kwargs = InitProcessGroupKwargs( + backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout) + ) + self._accelerator, model = apply_ddp_accelerate( + model, + project_config, + ddp_kwargs, + init_process_group_kwargs, + dataloader_config, + self._gradient_accumulation_steps, + accelerator=self._accelerator, + ) + logger.debug("Applied AccelerateParallel::apply_ddp to model.") + return model + + def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset: + logger.debug("AccelerateParallelBackend::prepare_dataset completed!") + return dataset + + def prepare_dataloader( + self, + dataset: torch.utils.data.IterableDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + ) -> DataLoader: + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory + ) + dataloader = self._accelerator.prepare_data_loader(dataloader) + logger.debug("AccelerateParallelBackend::prepare_dataloader completed!") + return dataloader + + def prepare_optimizer(self, optimizer, lr_scheduler): + optimizer = self._accelerator.prepare_optimizer(optimizer) + lr_scheduler = self._accelerator.prepare_scheduler(lr_scheduler) + return optimizer, lr_scheduler + + def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh: + def _get_mesh(): + if name is None: + return self._mesh + try: + return self._mesh[name] + except (KeyError, RuntimeError): + return self._mesh + + if self._mesh is not None: + return _get_mesh() + + mesh_list = [("dp_replicate", self._dp_degree), ("dp_shard", self._dp_shards)] + mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1] + names = [x[0] for x in mesh_list] + degrees = [x[1] for x in mesh_list] + mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names) + + dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], [] + + if self.data_replication_enabled: + dp_mesh_names.append("dp_replicate") + dp_cp_mesh_names.append("dp_replicate") + if self.data_sharding_enabled: + dp_mesh_names.append("dp_shard") + dp_cp_mesh_names.append("dp_shard") + dp_shard_cp_mesh_names.append("dp_shard") + if self.context_parallel_enabled: + dp_cp_mesh_names.append("cp") + dp_shard_cp_mesh_names.append("cp") + + if len(dp_mesh_names) > 0: + mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp") + if len(dp_cp_mesh_names) > 0: + mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp") + if len(dp_shard_cp_mesh_names) > 0: + mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp") + + logger.debug(f"Device mesh: {mesh}") + self._mesh = mesh + return _get_mesh() + + @property + def world_size(self): + return self._accelerator.num_processes + + 
@property + def rank(self): + return self._accelerator.process_index + + @property + def local_rank(self): + return self._accelerator.local_process_index + + @property + def is_main_process(self): + r"""Returns `True` if the current process is the main process on the master node.""" + return self._accelerator.is_main_process + + @property + def is_local_main_process(self): + r"""Returns `True` if the current process is the main process on local node.""" + return self._accelerator.is_local_main_process + + @property + def device(self): + return self._accelerator.device + + def wait_for_everyone(self): + self._accelerator.wait_for_everyone() + + def destroy(self): + self._accelerator.end_training() + + @property + def pipeline_parallel_enabled(self): + return self._pp_degree > 1 + + @property + def data_parallel_enabled(self): + return self._dp_degree > 1 or self._dp_shards > 1 + + @property + def data_replication_enabled(self): + return self._dp_degree > 1 + + @property + def data_sharding_enabled(self): + return self._dp_shards > 1 + + @property + def context_parallel_enabled(self): + return self._cp_degree > 1 + + @property + def tensor_parallel_enabled(self): + return self._tp_degree > 1 diff --git a/finetrainers/parallel/base.py b/finetrainers/parallel/base.py new file mode 100644 index 0000000000000000000000000000000000000000..eb982ca252fdcd03b644bf5a3215857fada7119c --- /dev/null +++ b/finetrainers/parallel/base.py @@ -0,0 +1,96 @@ +from contextlib import contextmanager +from typing import Any, Dict, List, Optional + +import torch + +from ..trackers import TrackerType, initialize_trackers + + +class BaseParallelBackend: + r""" + Base class that contains properties and methods that should be implemented by different parallel backends. + """ + + def apply_ddp(self, *args, **kwargs) -> torch.nn.Module: + raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.") + + def prepare_dataset(self, *args, **kwargs) -> Any: + raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.") + + def prepare_dataloader(self, *args, **kwargs) -> Any: + raise NotImplementedError("Method `prepare_dataloader` must be implemented by subclass.") + + def prepare_optimizer(self, *args, **kwargs) -> Any: + raise NotImplementedError("Method `prepare_optimizer` must be implemented by subclass.") + + def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh: + raise NotImplementedError("Method `get_mesh` must be implemented by subclass.") + + def initialize_trackers( + self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str + ) -> TrackerType: + self.tracker = None + if self.is_main_process: + self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir) + + def log(self, metrics: Dict[str, Any], step: int) -> None: + if self.is_main_process: + self.tracker.log(metrics, step) + + def wait_for_everyone(self): + raise NotImplementedError("Method `wait_for_everyone` must be implemented by subclass.") + + @contextmanager + def main_process_first(self): + raise NotImplementedError("Method `main_process_first` must be implemented by subclass.") + + def destroy(self): + raise NotImplementedError("Method `destroy` must be implemented by subclass.") + + @property + def world_size(self): + raise NotImplementedError("Method `world_size` must be implemented by subclass.") + + @property + def rank(self): + raise NotImplementedError("Method `rank` must be implemented by subclass.") + + @property + def 
local_rank(self): + raise NotImplementedError("Method `local_rank` must be implemented by subclass.") + + @property + def is_main_process(self): + raise NotImplementedError("Method `is_main_process` must be implemented by subclass.") + + @property + def is_local_main_process(self): + raise NotImplementedError("Method `is_local_main_process` must be implemented by subclass.") + + @property + def device(self): + raise NotImplementedError("Method `device` must be implemented by subclass.") + + @property + def pipeline_parallel_enabled(self): + raise NotImplementedError("Property `pipeline_parallel_enabled` must be implemented by subclass.") + + @property + def data_parallel_enabled(self): + raise NotImplementedError("Property `data_parallel_enabled` must be implemented by subclass.") + + @property + def data_replication_enabled(self): + raise NotImplementedError("Property `data_replication_enabled` must be implemented by subclass.") + + @property + def data_sharding_enabled(self): + raise NotImplementedError("Property `data_sharding_enabled` must be implemented by subclass.") + + @property + def context_parallel_enabled(self): + raise NotImplementedError("Property `context_parallel_enabled` must be implemented by subclass.") + + @property + def tensor_parallel_enabled(self): + raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.") diff --git a/finetrainers/parallel/deepspeed.py b/finetrainers/parallel/deepspeed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9f54d66ec1941ffc44d6239b305cc397ce61d4 --- /dev/null +++ b/finetrainers/parallel/deepspeed.py @@ -0,0 +1,7 @@ +from .base import BaseParallelBackend + + +class DeepspeedParallelBackend(BaseParallelBackend): + def __init__(self): + # TODO(aryan) + raise NotImplementedError("DeepspeedParallelBackend is not implemented yet.") diff --git a/finetrainers/parallel/ptd.py b/finetrainers/parallel/ptd.py new file mode 100644 index 0000000000000000000000000000000000000000..352273b4eff7f4cb21424962748a84ff29d96426 --- /dev/null +++ b/finetrainers/parallel/ptd.py @@ -0,0 +1,228 @@ +import datetime +import os +import pathlib +from typing import Optional + +import datasets.distributed +import torch + +from ..data import DPDataLoader +from ..logging import get_logger +from ..utils import get_device_info +from .base import BaseParallelBackend +from .utils import apply_ddp_ptd + + +_device_type, _device_module = get_device_info() +logger = get_logger() + + +class PytorchDTensorParallelBackend(BaseParallelBackend): + def __init__( + self, + world_size: int, + pp_degree: int = 1, + dp_degree: int = 1, + dp_shards: int = -1, + cp_degree: int = 1, + tp_degree: int = 1, + backend: str = "nccl", + timeout: int = 180, + logging_dir: Optional[str] = None, + output_dir: Optional[str] = None, + gradient_accumulation_steps: Optional[int] = None, + ) -> None: + super().__init__() + + self._world_size = world_size + self._pp_degree = pp_degree + self._dp_degree = dp_degree + self._dp_shards = dp_shards + self._cp_degree = cp_degree + self._tp_degree = tp_degree + self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None + self._logging_dir = ( + self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None + ) + self._backend = backend + self._timeout = timeout + + for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]: + if degree < 1: + raise ValueError(f"Parallel degree must be at least 1, got {degree}.") + + if dp_shards * 
pp_degree * dp_degree * cp_degree * tp_degree != world_size: + raise ValueError( + f"World size {world_size} must be divisible by the product of all parallel degrees and data parallel shards." + ) + + torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)) + _device_module.set_device(self.local_rank) + + logger.info( + f"Initialized parallel state with:\n" + f" - World size: {world_size}\n" + f" - Pipeline parallel degree: {pp_degree}\n" + f" - Data parallel degree: {dp_degree}\n" + f" - Context parallel degree: {cp_degree}\n" + f" - Tensor parallel degree: {tp_degree}\n" + f" - Data parallel shards: {dp_shards}\n" + ) + + self._mesh: torch.distributed.DeviceMesh = None + + def apply_ddp( + self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None + ) -> torch.nn.Module: + if device_mesh is None: + device_mesh = self.get_mesh() + apply_ddp_ptd(model, device_mesh) + logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.") + return model + + def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset: + dp_mesh = self.get_mesh("dp_replicate") + if dp_mesh is None: + dp_mesh = self.get_mesh() + if self.world_size > 1: + dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size() + else: + dp_local_rank, dp_world_size = 0, 1 + dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size) + logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!") + return dataset + + def prepare_dataloader( + self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool + ) -> DPDataLoader: + dp_mesh = self.get_mesh("dp_replicate") + if dp_mesh is None: + dp_mesh = self.get_mesh() + if self.world_size > 1: + dp_local_rank = dp_mesh.get_local_rank() + else: + dp_local_rank = 0 + dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers) + logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!") + return dataloader + + def prepare_optimizer(self, optimizer, lr_scheduler): + logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!") + return optimizer, lr_scheduler + + def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh: + def _get_mesh(): + if name is None: + return self._mesh + try: + return self._mesh[name] + except (KeyError, RuntimeError): + if self._mesh.ndim == 0: + return None + return self._mesh + + if self._mesh is not None: + return _get_mesh() + + mesh_list = [ + ("pp", self._pp_degree), + ("dp_replicate", self._dp_degree), + ("dp_shard", self._dp_shards), + ("cp", self._cp_degree), + ("tp", self._tp_degree), + ] + mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1] + names = [x[0] for x in mesh_list] + degrees = [x[1] for x in mesh_list] + mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names) + + dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], [] + + if self.data_replication_enabled: + dp_mesh_names.append("dp_replicate") + dp_cp_mesh_names.append("dp_replicate") + if self.data_sharding_enabled: + dp_mesh_names.append("dp_shard") + dp_cp_mesh_names.append("dp_shard") + dp_shard_cp_mesh_names.append("dp_shard") + if self.context_parallel_enabled: + dp_cp_mesh_names.append("cp") + dp_shard_cp_mesh_names.append("cp") + + if len(dp_mesh_names) > 0: + 
mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp") + if len(dp_cp_mesh_names) > 0: + mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp") + if len(dp_shard_cp_mesh_names) > 0: + mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp") + + logger.debug(f"Device mesh: {mesh}") + self._mesh = mesh + return _get_mesh() + + @property + def world_size(self): + return torch.distributed.get_world_size() + + @property + def rank(self): + return torch.distributed.get_rank() + + @property + def local_rank(self): + return int(os.environ.get("LOCAL_RANK", 0)) + + @property + def is_main_process(self): + r"""Returns `True` if the current process is the main process on the master node.""" + return self.rank == 0 + + @property + def is_local_main_process(self): + r"""Returns `True` if the current process is the main process on local node.""" + return self.local_rank == 0 + + @property + def device(self): + return torch.device(_device_type, self.local_rank) + + def wait_for_everyone(self): + return torch.distributed.barrier() + + # @contextmanager + # def main_process_first(self): + # if self.is_main_process: + # yield + # self.wait_for_everyone() + # else: + # self.wait_for_everyone() + # yield + + def destroy(self): + if self.is_main_process: + self.tracker.finish() + return torch.distributed.destroy_process_group() + + @property + def pipeline_parallel_enabled(self): + return self._pp_degree > 1 + + @property + def data_parallel_enabled(self): + return self._dp_degree > 1 or self._dp_shards > 1 + + @property + def data_replication_enabled(self): + return self._dp_degree > 1 + + @property + def data_sharding_enabled(self): + return self._dp_shards > 1 + + @property + def context_parallel_enabled(self): + return self._cp_degree > 1 + + @property + def tensor_parallel_enabled(self): + return self._tp_degree > 1 diff --git a/finetrainers/parallel/utils.py b/finetrainers/parallel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe434e1a192504998260118e9bd4b615113b2cd8 --- /dev/null +++ b/finetrainers/parallel/utils.py @@ -0,0 +1,99 @@ +from typing import Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor +from diffusers.utils import is_accelerate_available +from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard +from torch.distributed._composable.replicate import replicate + +from ..utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES + + +if is_accelerate_available(): + from accelerate import Accelerator + from accelerate.utils import ( + DataLoaderConfiguration, + DistributedDataParallelKwargs, + InitProcessGroupKwargs, + ProjectConfiguration, + ) + + +def apply_fsdp2_ptd( + model: torch.nn.Module, + dp_mesh: torch.distributed.device_mesh.DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + output_dtype: torch.dtype, + pp_enabled: bool = False, + cpu_offload: bool = False, +) -> None: + r"""Apply FSDP2 on a model.""" + mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True) + + def apply_fully_shard(blocks): + for layer_index, block in enumerate(blocks): + if pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: 
+ # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = layer_index < len(blocks) - 1 + fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward) + + for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES: + blocks = getattr(model, transformer_block_name, None) + if blocks is not None: + apply_fully_shard(blocks) + + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + +def apply_ddp_accelerate( + model: torch.nn.Module, + project_config: Optional[ProjectConfiguration] = None, + ddp_kwargs: Optional[DistributedDataParallelKwargs] = None, + init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None, + dataloader_config: Optional[DataLoaderConfiguration] = None, + gradient_accumulation_steps: Optional[int] = None, + accelerator: Optional[Accelerator] = None, +) -> torch.nn.Module: + if accelerator is None: + accelerator = Accelerator( + project_config=project_config, + dataloader_config=dataloader_config, + gradient_accumulation_steps=gradient_accumulation_steps, + log_with=None, + kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], + ) + if torch.backends.mps.is_available(): + accelerator.native_amp = False + accelerator.prepare_model(model) + return accelerator, model + + +def apply_ddp_ptd(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None: + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + +def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: + if isinstance(x, torch.distributed.tensor.DTensor): + # functional collectives do not support DTensor inputs + x = x.full_tensor() + assert x.numel() == 1 # required by `.item()` + return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item() + + +def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: + return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh) + + +def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float: + return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh) diff --git a/finetrainers/patches.py b/finetrainers/patches.py deleted file mode 100644 index 1faacbde58e2948a3aa83d8c848de2a9cc681583..0000000000000000000000000000000000000000 --- a/finetrainers/patches.py +++ /dev/null @@ -1,50 +0,0 @@ -import functools - -import torch -from accelerate.logging import get_logger -from peft.tuners.tuners_utils import BaseTunerLayer - -from .constants import FINETRAINERS_LOG_LEVEL - - -logger = get_logger("finetrainers") # pylint: disable=invalid-name -logger.setLevel(FINETRAINERS_LOG_LEVEL) - - -def perform_peft_patches() -> None: - _perform_patch_move_adapter_to_device_of_base_layer() - - -def _perform_patch_move_adapter_to_device_of_base_layer() -> None: - # We don't patch the method for torch.float32 and torch.bfloat16 because it is okay to train with them. If the model weights - # are in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2, we need to patch this method to avoid conversion of - # LoRA weights from higher precision dtype. 
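The comment above is the core rationale for the PEFT patch that this refactor relocates into `finetrainers/patches/`: when base weights are stored in `torch.float16` or a float8 dtype (layerwise upcasting), PEFT's `_move_adapter_to_device_of_base_layer` would also cast the float32 LoRA weights down to that storage dtype. Below is a minimal, self-contained sketch of the idea behind the `DisableTensorToDtype` guard; `no_dtype_cast` is a simplified stand-in written for illustration, not the project's implementation:

```python
import contextlib

import torch


@contextlib.contextmanager
def no_dtype_cast():
    """Temporarily make Tensor.to ignore dtype arguments (device moves still work)."""
    original_to = torch.Tensor.to

    def patched_to(tensor, *args, **kwargs):
        args = [a for a in args if not isinstance(a, torch.dtype)]  # drop positional dtype
        kwargs.pop("dtype", None)  # drop keyword dtype
        return original_to(tensor, *args, **kwargs)

    torch.Tensor.to = patched_to
    try:
        yield
    finally:
        torch.Tensor.to = original_to


lora_weight = torch.randn(4, 4, dtype=torch.float32)
with no_dtype_cast():
    moved = lora_weight.to("cpu", dtype=torch.float16)  # the dtype request is ignored inside the guard
assert moved.dtype == torch.float32  # the LoRA weight keeps its full precision
```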
-    BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
-        BaseTunerLayer._move_adapter_to_device_of_base_layer
-    )
-
-
-def _patched_move_adapter_to_device_of_base_layer(func) -> None:
-    @functools.wraps(func)
-    def wrapper(self, *args, **kwargs):
-        with DisableTensorToDtype():
-            return func(self, *args, **kwargs)
-
-    return wrapper
-
-
-class DisableTensorToDtype:
-    def __enter__(self):
-        self.original_to = torch.Tensor.to
-
-        def modified_to(tensor, *args, **kwargs):
-            # remove dtype from args if present
-            args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
-            if "dtype" in kwargs:
-                kwargs.pop("dtype")
-            return self.original_to(tensor, *args, **kwargs)
-
-        torch.Tensor.to = modified_to
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        torch.Tensor.to = self.original_to
diff --git a/finetrainers/patches/__init__.py b/finetrainers/patches/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d499f4d2136f8cdb4a9eb830289bb17757e7b37
--- /dev/null
+++ b/finetrainers/patches/__init__.py
@@ -0,0 +1,23 @@
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from ..args import BaseArgs
+    from ..parallel import ParallelBackendType
+
+
+def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBackendType") -> None:
+    # To avoid circular imports
+    from ..config import ModelType, TrainingType
+
+    if args.model_name == ModelType.LTX_VIDEO:
+        from .models.ltx_video import patch
+
+        patch.patch_transformer_forward()
+        if parallel_backend.tensor_parallel_enabled:
+            patch.patch_apply_rotary_emb_for_tp_compatibility()
+
+    if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
+        from .dependencies.peft import patch
+
+        patch.patch_peft_move_adapter_to_device_of_base_layer()
diff --git a/finetrainers/patches/dependencies/peft/patch.py b/finetrainers/patches/dependencies/peft/patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c0bf968cf4894b8ccd90c76631e57541e10642d
--- /dev/null
+++ b/finetrainers/patches/dependencies/peft/patch.py
@@ -0,0 +1,25 @@
+import functools
+
+from peft.tuners.tuners_utils import BaseTunerLayer
+
+from ...utils import DisableTensorToDtype
+
+
+def patch_peft_move_adapter_to_device_of_base_layer() -> None:
+    _perform_patch_move_adapter_to_device_of_base_layer()
+
+
+def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
+    BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
+        BaseTunerLayer._move_adapter_to_device_of_base_layer
+    )
+
+
+def _patched_move_adapter_to_device_of_base_layer(func) -> None:
+    # TODO(aryan): This is really unsafe probably and may break things. It works for now, but revisit and refactor.
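+    # `functools.wraps` keeps the name/signature metadata of the original
+    # `_move_adapter_to_device_of_base_layer`, and the wrapper simply runs it inside
+    # `DisableTensorToDtype`: the adapter is still moved to the base layer's device,
+    # but the implicit dtype cast (which would downcast the float32 LoRA weights when
+    # the base weights are stored in float8/float16 via layerwise upcasting) is skipped.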
+ @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with DisableTensorToDtype(): + return func(self, *args, **kwargs) + + return wrapper diff --git a/finetrainers/patches/models/ltx_video/patch.py b/finetrainers/patches/models/ltx_video/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..851da6e795ee779c4a145d3baf89e50bb001adec --- /dev/null +++ b/finetrainers/patches/models/ltx_video/patch.py @@ -0,0 +1,127 @@ +from typing import Any, Dict, Optional, Tuple + +import diffusers +import torch +from diffusers import LTXVideoTransformer3DModel +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils.import_utils import is_torch_version + + +def patch_transformer_forward() -> None: + _perform_ltx_transformer_forward_patch() + + +def patch_apply_rotary_emb_for_tp_compatibility() -> None: + _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() + + +def _perform_ltx_transformer_forward_patch() -> None: + LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3Dforward + + +def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None: + def apply_rotary_emb(x, freqs): + cos, sin = freqs + # ======== THIS IS CHANGED FROM THE ORIGINAL IMPLEMENTATION ======== + # The change is made due to unsupported DTensor operation aten.ops.unbind + # FIXME: Once aten.ops.unbind support lands, this will no longer be required + # x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2] + x_real, x_imag = x.unflatten(2, (-1, 2)).chunk(2, dim=-1) # [B, S, H, D // 2] + # ================================================================== + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb + + +def _patched_LTXVideoTransformer3Dforward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + num_frames: int, + height: int, + width: int, + rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, + return_dict: bool = True, + *args, + **kwargs, +) -> torch.Tensor: + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # ===== This is modified compared to Diffusers ===== + # This is done because the Diffusers pipeline will pass in a 1D tensor for timestep + if timestep.ndim == 1: + timestep = timestep.view(-1, 1, 1).expand(-1, *hidden_states.shape[1:-1], -1) + # ================================================== + + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + # ===== This is modified compared to Diffusers ===== + # temb = temb.view(batch_size, -1, temb.size(-1)) + # embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + # ================================================== + # This is done to make it possible to use per-token timestep embedding + temb = temb.view(batch_size, 
*hidden_states.shape[1:-1], temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, *hidden_states.shape[1:-1], embedded_timestep.size(-1)) + # ================================================== + + hidden_states = self.proj_in(hidden_states) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + encoder_attention_mask, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + ) + + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/finetrainers/patches/utils.py b/finetrainers/patches/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9d7f4726cc8183a461310570762ee95b5c4e6187 --- /dev/null +++ b/finetrainers/patches/utils.py @@ -0,0 +1,18 @@ +import torch + + +class DisableTensorToDtype: + def __enter__(self): + self.original_to = torch.Tensor.to + + def modified_to(tensor, *args, **kwargs): + # remove dtype from args if present + args = [arg if not isinstance(arg, torch.dtype) else None for arg in args] + if "dtype" in kwargs: + kwargs.pop("dtype") + return self.original_to(tensor, *args, **kwargs) + + torch.Tensor.to = modified_to + + def __exit__(self, *args, **kwargs): + torch.Tensor.to = self.original_to diff --git a/finetrainers/processors/__init__.py b/finetrainers/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e55b3d14743fc01128d07ffb115fc84e98cc6eee --- /dev/null +++ b/finetrainers/processors/__init__.py @@ -0,0 +1,5 @@ +from .base import ProcessorMixin +from .clip import CLIPPooledProcessor +from .llama import LlamaProcessor +from .t5 import T5Processor +from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor diff --git a/finetrainers/processors/base.py b/finetrainers/processors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9853ead0ef49610bdfed05ff067918bf80558109 --- /dev/null +++ b/finetrainers/processors/base.py @@ -0,0 +1,20 @@ +import inspect +from typing import Any, Dict, List + + +class ProcessorMixin: + def __init__(self) -> None: + self._forward_parameter_names = inspect.signature(self.forward).parameters.keys() + self.output_names: List[str] = None + self.input_names: Dict[str, Any] = None + + def __call__(self, *args, **kwargs) -> Any: + shallow_copy_kwargs = dict(kwargs.items()) + if self.input_names is not None: + 
for k, v in self.input_names.items():
+                shallow_copy_kwargs[v] = shallow_copy_kwargs.pop(k)
+        acceptable_kwargs = {k: v for k, v in shallow_copy_kwargs.items() if k in self._forward_parameter_names}
+        return self.forward(*args, **acceptable_kwargs)
+
+    def forward(self, *args, **kwargs) -> Any:
+        raise NotImplementedError("ProcessorMixin::forward method should be implemented by the subclass.")
diff --git a/finetrainers/processors/clip.py b/finetrainers/processors/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..178addf8b2556e7c1ae952084790f2575d27f007
--- /dev/null
+++ b/finetrainers/processors/clip.py
@@ -0,0 +1,65 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast
+
+from .base import ProcessorMixin
+
+
+class CLIPPooledProcessor(ProcessorMixin):
+    r"""
+    Processor for CLIP text encoder models. This processor is used to encode text inputs and return the pooled
+    text embedding for the input text.
+
+    Args:
+        output_names (`List[str]`):
+            The names of the outputs that the processor should return. The only output is the pooled embedding of
+            the input text.
+    """
+
+    def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__()
+
+        self.output_names = output_names
+        self.input_names = input_names
+
+        assert len(output_names) == 1
+        if input_names is not None:
+            assert len(input_names) <= 3
+
+    def forward(
+        self,
+        tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
+        text_encoder: CLIPTextModel,
+        caption: Union[str, List[str]],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        r"""
+        Encode the input text and return the pooled text embedding for the input text.
+
+        Args:
+            tokenizer (`Union[CLIPTokenizer, CLIPTokenizerFast]`):
+                The tokenizer used to tokenize the input text.
+            text_encoder (`CLIPTextModel`):
+                The text encoder used to encode the input text.
+            caption (`Union[str, List[str]]`):
+                The input text to be encoded.
+        """
+        if isinstance(caption, str):
+            caption = [caption]
+
+        device = text_encoder.device
+        dtype = text_encoder.dtype
+
+        text_inputs = tokenizer(
+            caption,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids.to(device)
+
+        prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False).pooler_output
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        return {self.output_names[0]: prompt_embeds}
diff --git a/finetrainers/processors/llama.py b/finetrainers/processors/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..749e5f313541b92317279669faf915edeb9129c4
--- /dev/null
+++ b/finetrainers/processors/llama.py
@@ -0,0 +1,118 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import LlamaModel, LlamaTokenizer, LlamaTokenizerFast
+
+from .base import ProcessorMixin
+
+
+DEFAULT_PROMPT_TEMPLATE = {
+    "template": (
+        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+        "1. The main content and theme of the video."
+        "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+        "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+        "4. background environment, light, style and atmosphere."
+        "5. 
camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +class LlamaProcessor(ProcessorMixin): + r""" + Processor for the Llama family of models. This processor is used to encode text inputs and return the embeddings + and attention masks for the input text. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor should return. The first output is the embeddings of the input + text and the second output is the attention mask for the input text. + """ + + def __init__(self, output_names: List[str] = None): + super().__init__() + + self.output_names = output_names + + assert len(output_names) == 2 + + def forward( + self, + tokenizer: Union[LlamaTokenizer, LlamaTokenizerFast], + text_encoder: LlamaModel, + caption: Union[str, List[str]], + max_sequence_length: int, + prompt_template: Optional[Dict[str, Any]] = None, + num_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Encode the input text and return the embeddings and attention mask for the input text. + + Args: + tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`): + The tokenizer used to tokenize the input text. + text_encoder (`LlamaModel`): + The text encoder used to encode the input text. + caption (`Union[str, List[str]]`): + The input text to be encoded. + max_sequence_length (`int`): + The maximum sequence length of the input text. + prompt_template (`Optional[Dict[str, Any]]`): + The prompt template to be used to encode the input text. + """ + if prompt_template is None: + prompt_template = DEFAULT_PROMPT_TEMPLATE + if isinstance(caption, str): + caption = [caption] + + device = text_encoder.device + dtype = text_encoder.dtype + + batch_size = len(caption) + caption = [prompt_template["template"].format(c) for c in caption] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = tokenizer( + caption, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device) + prompt_attention_mask = text_inputs.attention_mask.bool().to(device) + + prompt_embeds = text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ).hidden_states[-(num_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + + return { + self.output_names[0]: prompt_embeds, + self.output_names[1]: prompt_attention_mask, + } diff --git a/finetrainers/processors/t5.py b/finetrainers/processors/t5.py new file mode 100644 index 0000000000000000000000000000000000000000..96c2c194ca01635de404d618afffa93b88cdf953 --- /dev/null +++ b/finetrainers/processors/t5.py @@ -0,0 +1,73 @@ +from typing import List, Tuple, Union + +import 
torch +from transformers import T5EncoderModel, T5Tokenizer, T5TokenizerFast + +from .base import ProcessorMixin + + +class T5Processor(ProcessorMixin): + r""" + Processor for the T5 family of models. This processor is used to encode text inputs and return the embeddings + and attention masks for the input text. + + Args: + output_names (`List[str]`): + The names of the outputs that the processor should return. The first output is the embeddings of the input + text and the second output is the attention mask for the input text. + """ + + def __init__(self, output_names: List[str]): + super().__init__() + + self.output_names = output_names + + assert len(self.output_names) == 2 + + def forward( + self, + tokenizer: Union[T5Tokenizer, T5TokenizerFast], + text_encoder: T5EncoderModel, + caption: Union[str, List[str]], + max_sequence_length: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Encode the input text and return the embeddings and attention mask for the input text. + + Args: + tokenizer (`Union[T5Tokenizer, T5TokenizerFast]`): + The tokenizer used to tokenize the input text. + text_encoder (`T5EncoderModel`): + The text encoder used to encode the input text. + caption (`Union[str, List[str]]`): + The input text to be encoded. + max_sequence_length (`int`): + The maximum sequence length of the input text. + """ + if isinstance(caption, str): + caption = [caption] + + device = text_encoder.device + dtype = text_encoder.dtype + + batch_size = len(caption) + text_inputs = tokenizer( + caption, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + + return { + self.output_names[0]: prompt_embeds, + self.output_names[1]: prompt_attention_mask, + } diff --git a/finetrainers/processors/text.py b/finetrainers/processors/text.py new file mode 100644 index 0000000000000000000000000000000000000000..b51dca68214ba36c5813775f6dbc6d40592e9b3c --- /dev/null +++ b/finetrainers/processors/text.py @@ -0,0 +1,22 @@ +from typing import List, Union + +import torch + +from .. 
import functional as FF +from .base import ProcessorMixin + + +class CaptionTextDropoutProcessor(ProcessorMixin): + def __init__(self, dropout_p: float = 0.0) -> None: + self.dropout_p = dropout_p + + def forward(self, caption: Union[str, List[str]]) -> Union[str, List[str]]: + return FF.dropout_caption(caption, self.dropout_p) + + +class CaptionEmbeddingDropoutProcessor(ProcessorMixin): + def __init__(self, dropout_p: float = 0.0) -> None: + self.dropout_p = dropout_p + + def forward(self, embedding: torch.Tensor) -> torch.Tensor: + return FF.dropout_embeddings_to_zero(embedding, self.dropout_p) diff --git a/finetrainers/state.py b/finetrainers/state.py index 15a92e23da840af7e6920d20ea6cd4252feb47ed..5cda7d91b3f0e82b493d7e88b3565b9df985a228 100644 --- a/finetrainers/state.py +++ b/finetrainers/state.py @@ -1,21 +1,66 @@ +import io +from dataclasses import dataclass, field +from typing import Any, Dict, List + import torch -from accelerate import Accelerator +import torch.distributed.checkpoint.stateful + +from .parallel import ParallelBackendType +from .utils import get_device_info + + +_device_type, _ = get_device_info() + +@dataclass +class TrainState(torch.distributed.checkpoint.stateful.Stateful): + step: int = 0 + observed_data_samples: int = 0 + observed_num_tokens: int = 0 + global_avg_losses: List[float] = field(default_factory=list) + global_max_losses: List[float] = field(default_factory=list) + log_steps: List[int] = field(default_factory=list) + def state_dict(self) -> Dict[str, Any]: + # Only checkpoint global_avg_losses and global_max_losses per log frequency + # to avoid sync overhead in every iteration. + global_avg_losses_bytes = io.BytesIO() + torch.save(self.global_avg_losses, global_avg_losses_bytes) + global_max_losses_bytes = io.BytesIO() + torch.save(self.global_max_losses, global_max_losses_bytes) + log_steps_bytes = io.BytesIO() + torch.save(self.log_steps, log_steps_bytes) + return { + "step": torch.tensor(self.step, dtype=torch.int32), + "observed_data_samples": torch.tensor(self.observed_data_samples, dtype=torch.int32), + "observed_num_tokens": torch.tensor(self.observed_num_tokens, dtype=torch.int32), + "global_avg_losses": global_avg_losses_bytes, + "global_max_losses": global_max_losses_bytes, + "log_steps": log_steps_bytes, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + state_dict["global_avg_losses"].seek(0) + state_dict["global_max_losses"].seek(0) + state_dict["log_steps"].seek(0) + + self.step = state_dict["step"].item() + self.observed_data_samples = state_dict["observed_data_samples"].item() + self.observed_num_tokens = state_dict["observed_num_tokens"].item() + self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False) + self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False) + self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) + + +@dataclass class State: + # Parallel state + parallel_backend: ParallelBackendType = None + # Training state - seed: int = None - model_name: str = None - accelerator: Accelerator = None - weight_dtype: torch.dtype = None - train_epochs: int = None - train_steps: int = None - overwrote_max_train_steps: bool = False + train_state: TrainState = None num_trainable_parameters: int = 0 - learning_rate: float = None - train_batch_size: int = None generator: torch.Generator = None - num_update_steps_per_epoch: int = None # Hub state repo_id: str = None diff --git a/finetrainers/trackers.py b/finetrainers/trackers.py 
new file mode 100644 index 0000000000000000000000000000000000000000..a48716605e1ed2e39ab5d86cd39f72467496fd52 --- /dev/null +++ b/finetrainers/trackers.py @@ -0,0 +1,92 @@ +import pathlib +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from .logging import get_logger + + +logger = get_logger() + + +class BaseTracker: + r"""Base class for loggers. Does nothing by default, so it is useful when you want to disable logging.""" + + def log(self, metrics: Dict[str, Any], step: int) -> None: + pass + + def finish(self) -> None: + pass + + +class WandbTracker(BaseTracker): + r"""Logger implementation for Weights & Biases.""" + + def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None: + import wandb + + self.wandb = wandb + + # WandB does not create a directory if it does not exist and instead starts using the system temp directory. + pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) + + self.run = wandb.init(project=experiment_name, dir=log_dir, config=config) + logger.info("WandB logging enabled") + + def log(self, metrics: Dict[str, Any], step: int) -> None: + self.run.log(metrics, step=step) + + def finish(self) -> None: + self.run.finish() + + +class SequentialTracker(BaseTracker): + r"""Sequential tracker that logs to multiple trackers in sequence.""" + + def __init__(self, trackers: List[BaseTracker]) -> None: + self.trackers = trackers + + def log(self, metrics: Dict[str, Any], step: int) -> None: + for tracker in self.trackers: + tracker.log(metrics, step) + + def finish(self) -> None: + for tracker in self.trackers: + tracker.finish() + + +class Trackers(str, Enum): + r"""Enum for supported trackers.""" + + NONE = "none" + WANDB = "wandb" + + +_SUPPORTED_TRACKERS = [tracker.value for tracker in Trackers.__members__.values()] + + +def initialize_trackers( + trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str +) -> Union[BaseTracker, SequentialTracker]: + r"""Initialize loggers based on the provided configuration.""" + + logger.info(f"Initializing trackers: {trackers}. Logging to {log_dir=}") + + if len(trackers) == 0: + return BaseTracker() + + if any(tracker_name not in _SUPPORTED_TRACKERS for tracker_name in set(trackers)): + raise ValueError(f"Unsupported tracker(s) provided. 
Supported trackers: {_SUPPORTED_TRACKERS}") + + tracker_instances = [] + for tracker_name in set(trackers): + if tracker_name == Trackers.NONE: + tracker = BaseTracker() + elif tracker_name == Trackers.WANDB: + tracker = WandbTracker(experiment_name, log_dir, config) + tracker_instances.append(tracker) + + tracker = SequentialTracker(tracker_instances) + return tracker + + +TrackerType = Union[BaseTracker, SequentialTracker, WandbTracker] diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py deleted file mode 100644 index a213c00a129a407bf304b51cb8ae0965afd54d98..0000000000000000000000000000000000000000 --- a/finetrainers/trainer.py +++ /dev/null @@ -1,1235 +0,0 @@ -import json -import logging -import math -import os -import gc -import random -from datetime import datetime, timedelta -from pathlib import Path -from typing import Any, Dict, List -import resource -import diffusers -import torch -import torch.backends -import transformers -import wandb -from accelerate import Accelerator, DistributedType -from accelerate.logging import get_logger -from accelerate.utils import ( - DistributedDataParallelKwargs, - InitProcessGroupKwargs, - ProjectConfiguration, - gather_object, - set_seed, -) -from diffusers import DiffusionPipeline -from diffusers.configuration_utils import FrozenDict -from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution -from diffusers.optimization import get_scheduler -from diffusers.training_utils import cast_training_params -from diffusers.utils import export_to_video, load_image, load_video -from huggingface_hub import create_repo, upload_folder -from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict -from tqdm import tqdm - -from .args import Args, validate_args -from .constants import ( - FINETRAINERS_LOG_LEVEL, - PRECOMPUTED_CONDITIONS_DIR_NAME, - PRECOMPUTED_DIR_NAME, - PRECOMPUTED_LATENTS_DIR_NAME, -) -from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset -from .hooks import apply_layerwise_upcasting -from .models import get_config_from_model_name -from .patches import perform_peft_patches -from .state import State -from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from -from .utils.data_utils import should_perform_precomputation -from .utils.diffusion_utils import ( - get_scheduler_alphas, - get_scheduler_sigmas, - prepare_loss_weights, - prepare_sigmas, - prepare_target, -) -from .utils.file_utils import string_to_filename -from .utils.hub_utils import save_model_card -from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous -from .utils.model_utils import resolve_vae_cls_from_ckpt_path -from .utils.optimizer_utils import get_optimizer -from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model - - -logger = get_logger("finetrainers") -logger.setLevel(FINETRAINERS_LOG_LEVEL) - - -class Trainer: - def __init__(self, args: Args) -> None: - validate_args(args) - - self.args = args - self.args.seed = self.args.seed or datetime.now().year - self.state = State() - - # Tokenizers - self.tokenizer = None - self.tokenizer_2 = None - self.tokenizer_3 = None - - # Text encoders - self.text_encoder = None - self.text_encoder_2 = None - self.text_encoder_3 = None - - # Denoisers - self.transformer = None - self.unet = None - - # Autoencoders - self.vae = None - - # Scheduler - self.scheduler = None - - self.transformer_config = None - self.vae_config = None - - self._init_distributed() - 
self._init_logging() - self._init_directories_and_repositories() - self._init_config_options() - - # Peform any patches needed for training - if len(self.args.layerwise_upcasting_modules) > 0: - perform_peft_patches() - # TODO(aryan): handle text encoders - # if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]): - # perform_text_encoder_patches() - - self.state.model_name = self.args.model_name - self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type) - - def prepare_dataset(self) -> None: - # TODO(aryan): Make a background process for fetching - logger.info("Initializing dataset and dataloader") - - self.dataset = ImageOrVideoDatasetWithResizing( - data_root=self.args.data_root, - caption_column=self.args.caption_column, - video_column=self.args.video_column, - resolution_buckets=self.args.video_resolution_buckets, - dataset_file=self.args.dataset_file, - id_token=self.args.id_token, - remove_llm_prefixes=self.args.remove_common_llm_caption_prefixes, - ) - self.dataloader = torch.utils.data.DataLoader( - self.dataset, - batch_size=1, - sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True), - collate_fn=self.model_config.get("collate_fn"), - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.pin_memory, - ) - - def prepare_models(self) -> None: - logger.info("Initializing models") - - load_components_kwargs = self._get_load_components_kwargs() - condition_components, latent_components, diffusion_components = {}, {}, {} - if not self.args.precompute_conditions: - # To download the model files first on the main process (if not already present) - # and then load the cached files afterward from the other processes. - with self.state.accelerator.main_process_first(): - condition_components = self.model_config["load_condition_models"](**load_components_kwargs) - latent_components = self.model_config["load_latent_models"](**load_components_kwargs) - diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs) - - components = {} - components.update(condition_components) - components.update(latent_components) - components.update(diffusion_components) - self._set_components(components) - - if self.vae is not None: - if self.args.enable_slicing: - self.vae.enable_slicing() - if self.args.enable_tiling: - self.vae.enable_tiling() - - def prepare_precomputations(self) -> None: - if not self.args.precompute_conditions: - return - - logger.info("Initializing precomputations") - - if self.args.batch_size != 1: - raise ValueError("Precomputation is only supported with batch size 1. 
This will be supported in future.") - - def collate_fn(batch): - latent_conditions = [x["latent_conditions"] for x in batch] - text_conditions = [x["text_conditions"] for x in batch] - batched_latent_conditions = {} - batched_text_conditions = {} - for key in list(latent_conditions[0].keys()): - if torch.is_tensor(latent_conditions[0][key]): - batched_latent_conditions[key] = torch.cat([x[key] for x in latent_conditions], dim=0) - else: - # TODO(aryan): implement batch sampler for precomputed latents - batched_latent_conditions[key] = [x[key] for x in latent_conditions][0] - for key in list(text_conditions[0].keys()): - if torch.is_tensor(text_conditions[0][key]): - batched_text_conditions[key] = torch.cat([x[key] for x in text_conditions], dim=0) - else: - # TODO(aryan): implement batch sampler for precomputed latents - batched_text_conditions[key] = [x[key] for x in text_conditions][0] - return {"latent_conditions": batched_latent_conditions, "text_conditions": batched_text_conditions} - - cleaned_model_id = string_to_filename(self.args.pretrained_model_name_or_path) - precomputation_dir = ( - Path(self.args.data_root) / f"{self.args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}" - ) - should_precompute = should_perform_precomputation(precomputation_dir) - if not should_precompute: - logger.info("Precomputed conditions and latents found. Loading precomputed data.") - self.dataloader = torch.utils.data.DataLoader( - PrecomputedDataset( - data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id - ), - batch_size=self.args.batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.pin_memory, - ) - return - - logger.info("Precomputed conditions and latents not found. Running precomputation.") - - # At this point, no models are loaded, so we need to load and precompute conditions and latents - with self.state.accelerator.main_process_first(): - condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs()) - self._set_components(condition_components) - self._move_components_to_device() - self._disable_grad_for_components([self.text_encoder, self.text_encoder_2, self.text_encoder_3]) - - if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty": - logger.warning( - "Caption dropout is not supported with precomputation yet. This will be supported in the future." 
- ) - - conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME - latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME - conditions_dir.mkdir(parents=True, exist_ok=True) - latents_dir.mkdir(parents=True, exist_ok=True) - - accelerator = self.state.accelerator - - # Precompute conditions - progress_bar = tqdm( - range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes), - desc="Precomputing conditions", - disable=not accelerator.is_local_main_process, - ) - index = 0 - for i, data in enumerate(self.dataset): - if i % accelerator.num_processes != accelerator.process_index: - continue - - logger.debug( - f"Precomputing conditions for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}" - ) - - text_conditions = self.model_config["prepare_conditions"]( - tokenizer=self.tokenizer, - tokenizer_2=self.tokenizer_2, - tokenizer_3=self.tokenizer_3, - text_encoder=self.text_encoder, - text_encoder_2=self.text_encoder_2, - text_encoder_3=self.text_encoder_3, - prompt=data["prompt"], - device=accelerator.device, - dtype=self.args.transformer_dtype, - ) - filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt" - torch.save(text_conditions, filename.as_posix()) - index += 1 - progress_bar.update(1) - self._delete_components() - - memory_statistics = get_memory_statistics() - logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}") - torch.cuda.reset_peak_memory_stats(accelerator.device) - - # Precompute latents - with self.state.accelerator.main_process_first(): - latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs()) - self._set_components(latent_components) - self._move_components_to_device() - self._disable_grad_for_components([self.vae]) - - if self.vae is not None: - if self.args.enable_slicing: - self.vae.enable_slicing() - if self.args.enable_tiling: - self.vae.enable_tiling() - - progress_bar = tqdm( - range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes), - desc="Precomputing latents", - disable=not accelerator.is_local_main_process, - ) - index = 0 - for i, data in enumerate(self.dataset): - if i % accelerator.num_processes != accelerator.process_index: - continue - - logger.debug( - f"Precomputing latents for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}" - ) - - latent_conditions = self.model_config["prepare_latents"]( - vae=self.vae, - image_or_video=data["video"].unsqueeze(0), - device=accelerator.device, - dtype=self.args.transformer_dtype, - generator=self.state.generator, - precompute=True, - ) - filename = latents_dir / f"latents-{accelerator.process_index}-{index}.pt" - torch.save(latent_conditions, filename.as_posix()) - index += 1 - progress_bar.update(1) - self._delete_components() - - accelerator.wait_for_everyone() - logger.info("Precomputation complete") - - memory_statistics = get_memory_statistics() - logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}") - torch.cuda.reset_peak_memory_stats(accelerator.device) - - # Update dataloader to use precomputed conditions and latents - self.dataloader = torch.utils.data.DataLoader( - PrecomputedDataset( - data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id - ), - batch_size=self.args.batch_size, - shuffle=True, - collate_fn=collate_fn, - num_workers=self.args.dataloader_num_workers, - 
pin_memory=self.args.pin_memory, - ) - - def prepare_trainable_parameters(self) -> None: - logger.info("Initializing trainable parameters") - - with self.state.accelerator.main_process_first(): - diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs()) - self._set_components(diffusion_components) - - components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae] - self._disable_grad_for_components(components) - - if self.args.training_type == "full-finetune": - logger.info("Finetuning transformer with no additional parameters") - self._enable_grad_for_components([self.transformer]) - else: - logger.info("Finetuning transformer with PEFT parameters") - self._disable_grad_for_components([self.transformer]) - - # Layerwise upcasting must be applied before adding the LoRA adapter. - # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on - # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly. - if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules: - apply_layerwise_upcasting( - self.transformer, - storage_dtype=self.args.layerwise_upcasting_storage_dtype, - compute_dtype=self.args.transformer_dtype, - skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern, - non_blocking=True, - ) - - self._move_components_to_device() - - if self.args.gradient_checkpointing: - self.transformer.enable_gradient_checkpointing() - - if self.args.training_type == "lora": - transformer_lora_config = LoraConfig( - r=self.args.rank, - lora_alpha=self.args.lora_alpha, - init_lora_weights=True, - target_modules=self.args.target_modules, - ) - self.transformer.add_adapter(transformer_lora_config) - else: - transformer_lora_config = None - - # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32 - # even if layerwise upcasting. Would be nice to have a test as well - - self.register_saving_loading_hooks(transformer_lora_config) - - def register_saving_loading_hooks(self, transformer_lora_config): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - if self.state.accelerator.is_main_process: - transformer_lora_layers_to_save = None - - for model in models: - if isinstance( - unwrap_model(self.state.accelerator, model), - type(unwrap_model(self.state.accelerator, self.transformer)), - ): - model = unwrap_model(self.state.accelerator, model) - if self.args.training_type == "lora": - transformer_lora_layers_to_save = get_peft_model_state_dict(model) - else: - raise ValueError(f"Unexpected save model: {model.__class__}") - - # make sure to pop weight so that corresponding model is not saved again - if weights: - weights.pop() - - if self.args.training_type == "lora": - self.model_config["pipeline_cls"].save_lora_weights( - output_dir, - transformer_lora_layers=transformer_lora_layers_to_save, - ) - else: - model.save_pretrained(os.path.join(output_dir, "transformer")) - - # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need - # to able to load all diffusion components from a specific checkpoint folder during validation, we need to - # ensure the scheduler config is serialized as well. 
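The comment above (from the removed monolithic `Trainer`) is about making each checkpoint folder self-sufficient for validation: the scheduler config has to be written next to the model weights so the full pipeline can be rebuilt from that folder later. A minimal sketch of that save/reload round trip; `DDIMScheduler` and the checkpoint path are illustrative choices, not taken from the diff:

```python
from diffusers import DDIMScheduler

# Persist the scheduler config next to a (hypothetical) checkpoint folder ...
scheduler = DDIMScheduler()
scheduler.save_pretrained("output/checkpoint-500/scheduler")  # writes scheduler_config.json

# ... so validation can later rebuild the scheduler from that folder alone.
reloaded = DDIMScheduler.from_pretrained("output/checkpoint-500/scheduler")
print(reloaded.config.num_train_timesteps)  # 1000 with default settings
```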
- self.scheduler.save_pretrained(os.path.join(output_dir, "scheduler")) - - def load_model_hook(models, input_dir): - if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED: - while len(models) > 0: - model = models.pop() - if isinstance( - unwrap_model(self.state.accelerator, model), - type(unwrap_model(self.state.accelerator, self.transformer)), - ): - transformer_ = unwrap_model(self.state.accelerator, model) - else: - raise ValueError( - f"Unexpected save model: {unwrap_model(self.state.accelerator, model).__class__}" - ) - else: - transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__ - - if self.args.training_type == "lora": - transformer_ = transformer_cls_.from_pretrained( - self.args.pretrained_model_name_or_path, subfolder="transformer" - ) - transformer_.add_adapter(transformer_lora_config) - lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir) - transformer_state_dict = { - f'{k.replace("transformer.", "")}': v - for k, v in lora_state_dict.items() - if k.startswith("transformer.") - } - incompatible_keys = set_peft_model_state_dict( - transformer_, transformer_state_dict, adapter_name="default" - ) - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - else: - transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer")) - - self.state.accelerator.register_save_state_pre_hook(save_model_hook) - self.state.accelerator.register_load_state_pre_hook(load_model_hook) - - def prepare_optimizer(self) -> None: - logger.info("Initializing optimizer and lr scheduler") - - self.state.train_epochs = self.args.train_epochs - self.state.train_steps = self.args.train_steps - - # Make sure the trainable params are in float32 - if self.args.training_type == "lora": - cast_training_params([self.transformer], dtype=torch.float32) - - self.state.learning_rate = self.args.lr - if self.args.scale_lr: - self.state.learning_rate = ( - self.state.learning_rate - * self.args.gradient_accumulation_steps - * self.args.batch_size - * self.state.accelerator.num_processes - ) - - transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, self.transformer.parameters())) - transformer_parameters_with_lr = { - "params": transformer_trainable_parameters, - "lr": self.state.learning_rate, - } - params_to_optimize = [transformer_parameters_with_lr] - self.state.num_trainable_parameters = sum(p.numel() for p in transformer_trainable_parameters) - - use_deepspeed_opt = ( - self.state.accelerator.state.deepspeed_plugin is not None - and "optimizer" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config - ) - optimizer = get_optimizer( - params_to_optimize=params_to_optimize, - optimizer_name=self.args.optimizer, - learning_rate=self.state.learning_rate, - beta1=self.args.beta1, - beta2=self.args.beta2, - beta3=self.args.beta3, - epsilon=self.args.epsilon, - weight_decay=self.args.weight_decay, - use_8bit=self.args.use_8bit_bnb, - use_deepspeed=use_deepspeed_opt, - ) - - num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps) - if self.state.train_steps is None: - self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch - self.state.overwrote_max_train_steps = True - - 
use_deepspeed_lr_scheduler = ( - self.state.accelerator.state.deepspeed_plugin is not None - and "scheduler" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config - ) - total_training_steps = self.state.train_steps * self.state.accelerator.num_processes - num_warmup_steps = self.args.lr_warmup_steps * self.state.accelerator.num_processes - - if use_deepspeed_lr_scheduler: - from accelerate.utils import DummyScheduler - - lr_scheduler = DummyScheduler( - name=self.args.lr_scheduler, - optimizer=optimizer, - total_num_steps=total_training_steps, - num_warmup_steps=num_warmup_steps, - ) - else: - lr_scheduler = get_scheduler( - name=self.args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=total_training_steps, - num_cycles=self.args.lr_num_cycles, - power=self.args.lr_power, - ) - - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - - def prepare_for_training(self) -> None: - self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = self.state.accelerator.prepare( - self.transformer, self.optimizer, self.dataloader, self.lr_scheduler - ) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps) - if self.state.overwrote_max_train_steps: - self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch - # Afterwards we recalculate our number of training epochs - self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch) - self.state.num_update_steps_per_epoch = num_update_steps_per_epoch - - def prepare_trackers(self) -> None: - logger.info("Initializing trackers") - - tracker_name = self.args.tracker_name or "finetrainers-experiment" - self.state.accelerator.init_trackers(tracker_name, config=self._get_training_info()) - - def train(self) -> None: - logger.info("Starting training") - - - # Add these lines at the beginning - if hasattr(resource, 'RLIMIT_NOFILE'): - try: - soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) - logger.info(f"Current file descriptor limits in trainer: soft={soft}, hard={hard}") - # Try to increase to hard limit if possible - if soft < hard: - resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) - new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE) - logger.info(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}") - except Exception as e: - logger.warning(f"Could not check or update file descriptor limits: {e}") - - memory_statistics = get_memory_statistics() - logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") - - if self.vae_config is None: - # If we've precomputed conditions and latents already, and are now re-using it, we will never load - # the VAE so self.vae_config will not be set. So, we need to load it here. - vae_cls = resolve_vae_cls_from_ckpt_path( - self.args.pretrained_model_name_or_path, revision=self.args.revision, cache_dir=self.args.cache_dir - ) - vae_config = vae_cls.load_config( - self.args.pretrained_model_name_or_path, - subfolder="vae", - revision=self.args.revision, - cache_dir=self.args.cache_dir, - ) - self.vae_config = FrozenDict(**vae_config) - - # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). 
Since we need - # to able to load all diffusion components from a specific checkpoint folder during validation, we need to - # ensure the scheduler config is serialized as well. - if self.args.training_type == "full-finetune": - self.scheduler.save_pretrained(os.path.join(self.args.output_dir, "scheduler")) - - self.state.train_batch_size = ( - self.args.batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps - ) - info = { - "trainable parameters": self.state.num_trainable_parameters, - "total samples": len(self.dataset), - "train epochs": self.state.train_epochs, - "train steps": self.state.train_steps, - "batches per device": self.args.batch_size, - "total batches observed per epoch": len(self.dataloader), - "train batch size": self.state.train_batch_size, - "gradient accumulation steps": self.args.gradient_accumulation_steps, - } - logger.info(f"Training configuration: {json.dumps(info, indent=4)}") - - global_step = 0 - first_epoch = 0 - initial_global_step = 0 - - # Potentially load in the weights and states from a previous save - ( - resume_from_checkpoint_path, - initial_global_step, - global_step, - first_epoch, - ) = get_latest_ckpt_path_to_resume_from( - resume_from_checkpoint=self.args.resume_from_checkpoint, - num_update_steps_per_epoch=self.state.num_update_steps_per_epoch, - output_dir=self.args.output_dir, - ) - if resume_from_checkpoint_path: - self.state.accelerator.load_state(resume_from_checkpoint_path) - - progress_bar = tqdm( - range(0, self.state.train_steps), - initial=initial_global_step, - desc="Training steps", - disable=not self.state.accelerator.is_local_main_process, - ) - - accelerator = self.state.accelerator - generator = torch.Generator(device=accelerator.device) - if self.args.seed is not None: - generator = generator.manual_seed(self.args.seed) - self.state.generator = generator - - scheduler_sigmas = get_scheduler_sigmas(self.scheduler) - scheduler_sigmas = ( - scheduler_sigmas.to(device=accelerator.device, dtype=torch.float32) - if scheduler_sigmas is not None - else None - ) - scheduler_alphas = get_scheduler_alphas(self.scheduler) - scheduler_alphas = ( - scheduler_alphas.to(device=accelerator.device, dtype=torch.float32) - if scheduler_alphas is not None - else None - ) - - for epoch in range(first_epoch, self.state.train_epochs): - logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})") - - self.transformer.train() - models_to_accumulate = [self.transformer] - epoch_loss = 0.0 - num_loss_updates = 0 - - for step, batch in enumerate(self.dataloader): - logger.debug(f"Starting step {step + 1}") - logs = {} - - with accelerator.accumulate(models_to_accumulate): - if not self.args.precompute_conditions: - videos = batch["videos"] - prompts = batch["prompts"] - batch_size = len(prompts) - - if self.args.caption_dropout_technique == "empty": - if random.random() < self.args.caption_dropout_p: - prompts = [""] * batch_size - - latent_conditions = self.model_config["prepare_latents"]( - vae=self.vae, - image_or_video=videos, - patch_size=self.transformer_config.patch_size, - patch_size_t=self.transformer_config.patch_size_t, - device=accelerator.device, - dtype=self.args.transformer_dtype, - generator=self.state.generator, - ) - text_conditions = self.model_config["prepare_conditions"]( - tokenizer=self.tokenizer, - text_encoder=self.text_encoder, - tokenizer_2=self.tokenizer_2, - text_encoder_2=self.text_encoder_2, - prompt=prompts, - device=accelerator.device, - dtype=self.args.transformer_dtype, - ) - 
else: - latent_conditions = batch["latent_conditions"] - text_conditions = batch["text_conditions"] - latent_conditions["latents"] = DiagonalGaussianDistribution( - latent_conditions["latents"] - ).sample(self.state.generator) - - # This method should only be called for precomputed latents. - # TODO(aryan): rename this in separate PR - latent_conditions = self.model_config["post_latent_preparation"]( - vae_config=self.vae_config, - patch_size=self.transformer_config.patch_size, - patch_size_t=self.transformer_config.patch_size_t, - **latent_conditions, - ) - align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype) - align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype) - batch_size = latent_conditions["latents"].shape[0] - - latent_conditions = make_contiguous(latent_conditions) - text_conditions = make_contiguous(text_conditions) - - if self.args.caption_dropout_technique == "zero": - if random.random() < self.args.caption_dropout_p: - text_conditions["prompt_embeds"].fill_(0) - text_conditions["prompt_attention_mask"].fill_(False) - - # TODO(aryan): refactor later - if "pooled_prompt_embeds" in text_conditions: - text_conditions["pooled_prompt_embeds"].fill_(0) - - sigmas = prepare_sigmas( - scheduler=self.scheduler, - sigmas=scheduler_sigmas, - batch_size=batch_size, - num_train_timesteps=self.scheduler.config.num_train_timesteps, - flow_weighting_scheme=self.args.flow_weighting_scheme, - flow_logit_mean=self.args.flow_logit_mean, - flow_logit_std=self.args.flow_logit_std, - flow_mode_scale=self.args.flow_mode_scale, - device=accelerator.device, - generator=self.state.generator, - ) - timesteps = (sigmas * 1000.0).long() - - noise = torch.randn( - latent_conditions["latents"].shape, - generator=self.state.generator, - device=accelerator.device, - dtype=self.args.transformer_dtype, - ) - sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim) - - # TODO(aryan): We probably don't need calculate_noisy_latents because we can determine the type of - # scheduler and calculate the noisy latents accordingly. Look into this later. 
- if "calculate_noisy_latents" in self.model_config.keys(): - noisy_latents = self.model_config["calculate_noisy_latents"]( - scheduler=self.scheduler, - noise=noise, - latents=latent_conditions["latents"], - timesteps=timesteps, - ) - else: - # Default to flow-matching noise addition - noisy_latents = (1.0 - sigmas) * latent_conditions["latents"] + sigmas * noise - noisy_latents = noisy_latents.to(latent_conditions["latents"].dtype) - - latent_conditions.update({"noisy_latents": noisy_latents}) - - weights = prepare_loss_weights( - scheduler=self.scheduler, - alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None, - sigmas=sigmas, - flow_weighting_scheme=self.args.flow_weighting_scheme, - ) - weights = expand_tensor_dims(weights, noise.ndim) - - pred = self.model_config["forward_pass"]( - transformer=self.transformer, - scheduler=self.scheduler, - timesteps=timesteps, - **latent_conditions, - **text_conditions, - ) - target = prepare_target( - scheduler=self.scheduler, noise=noise, latents=latent_conditions["latents"] - ) - - loss = weights.float() * (pred["latents"].float() - target.float()).pow(2) - # Average loss across all but batch dimension - loss = loss.mean(list(range(1, loss.ndim))) - # Average loss across batch dimension - loss = loss.mean() - accelerator.backward(loss) - - if accelerator.sync_gradients: - if accelerator.distributed_type == DistributedType.DEEPSPEED: - grad_norm = self.transformer.get_global_grad_norm() - # In some cases the grad norm may not return a float - if torch.is_tensor(grad_norm): - grad_norm = grad_norm.item() - else: - grad_norm = accelerator.clip_grad_norm_( - self.transformer.parameters(), self.args.max_grad_norm - ) - if torch.is_tensor(grad_norm): - grad_norm = grad_norm.item() - - logs["grad_norm"] = grad_norm - - self.optimizer.step() - self.lr_scheduler.step() - self.optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - # Checkpointing - if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process: - if global_step % self.args.checkpointing_steps == 0: - save_path = get_intermediate_ckpt_path( - checkpointing_limit=self.args.checkpointing_limit, - step=global_step, - output_dir=self.args.output_dir, - ) - accelerator.save_state(save_path) - - # Maybe run validation - should_run_validation = ( - self.args.validation_every_n_steps is not None - and global_step % self.args.validation_every_n_steps == 0 - ) - if should_run_validation: - self.validate(global_step) - - loss_item = loss.detach().item() - epoch_loss += loss_item - num_loss_updates += 1 - logs["step_loss"] = loss_item - logs["lr"] = self.lr_scheduler.get_last_lr()[0] - progress_bar.set_postfix(logs) - accelerator.log(logs, step=global_step) - - if global_step % 100 == 0: # Every 100 steps - # Force garbage collection to clean up any lingering resources - gc.collect() - - if global_step >= self.state.train_steps: - break - - - - if num_loss_updates > 0: - epoch_loss /= num_loss_updates - accelerator.log({"epoch_loss": epoch_loss}, step=global_step) - memory_statistics = get_memory_statistics() - logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}") - - # Maybe run validation - should_run_validation = ( - self.args.validation_every_n_epochs is not None - and (epoch + 1) % self.args.validation_every_n_epochs == 0 - ) - if should_run_validation: - 
self.validate(global_step) - - if epoch % 3 == 0: # Every 3 epochs - logger.info("Performing periodic resource cleanup") - free_memory() - gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize(accelerator.device) - - accelerator.wait_for_everyone() - if accelerator.is_main_process: - transformer = unwrap_model(accelerator, self.transformer) - - if self.args.training_type == "lora": - transformer_lora_layers = get_peft_model_state_dict(transformer) - - self.model_config["pipeline_cls"].save_lora_weights( - save_directory=self.args.output_dir, - transformer_lora_layers=transformer_lora_layers, - ) - else: - transformer.save_pretrained(os.path.join(self.args.output_dir, "transformer")) - accelerator.wait_for_everyone() - self.validate(step=global_step, final_validation=True) - - if accelerator.is_main_process: - if self.args.push_to_hub: - upload_folder( - repo_id=self.state.repo_id, folder_path=self.args.output_dir, ignore_patterns=["checkpoint-*"] - ) - - self._delete_components() - memory_statistics = get_memory_statistics() - logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}") - - accelerator.end_training() - - def validate(self, step: int, final_validation: bool = False) -> None: - logger.info("Starting validation") - - accelerator = self.state.accelerator - num_validation_samples = len(self.args.validation_prompts) - - if num_validation_samples == 0: - logger.warning("No validation samples found. Skipping validation.") - if accelerator.is_main_process: - if self.args.push_to_hub: - save_model_card( - args=self.args, - repo_id=self.state.repo_id, - videos=None, - validation_prompts=None, - ) - return - - self.transformer.eval() - - memory_statistics = get_memory_statistics() - logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}") - - pipeline = self._get_and_prepare_pipeline_for_validation(final_validation=final_validation) - - all_processes_artifacts = [] - prompts_to_filenames = {} - for i in range(num_validation_samples): - # Skip current validation on all processes but one - if i % accelerator.num_processes != accelerator.process_index: - continue - - prompt = self.args.validation_prompts[i] - image = self.args.validation_images[i] - video = self.args.validation_videos[i] - height = self.args.validation_heights[i] - width = self.args.validation_widths[i] - num_frames = self.args.validation_num_frames[i] - frame_rate = self.args.validation_frame_rate - if image is not None: - image = load_image(image) - if video is not None: - video = load_video(video) - - logger.debug( - f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}", - main_process_only=False, - ) - validation_artifacts = self.model_config["validation"]( - pipeline=pipeline, - prompt=prompt, - image=image, - video=video, - height=height, - width=width, - num_frames=num_frames, - frame_rate=frame_rate, - num_videos_per_prompt=self.args.num_validation_videos_per_prompt, - generator=torch.Generator(device=accelerator.device).manual_seed( - self.args.seed if self.args.seed is not None else 0 - ), - # todo support passing `fps` for supported pipelines. 
- ) - - prompt_filename = string_to_filename(prompt)[:25] - artifacts = { - "image": {"type": "image", "value": image}, - "video": {"type": "video", "value": video}, - } - for i, (artifact_type, artifact_value) in enumerate(validation_artifacts): - if artifact_value: - artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}}) - logger.debug( - f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}", - main_process_only=False, - ) - - for index, (key, value) in enumerate(list(artifacts.items())): - artifact_type = value["type"] - artifact_value = value["value"] - if artifact_type not in ["image", "video"] or artifact_value is None: - continue - - extension = "png" if artifact_type == "image" else "mp4" - filename = "validation-" if not final_validation else "final-" - filename += f"{step}-{accelerator.process_index}-{index}-{prompt_filename}.{extension}" - if accelerator.is_main_process and extension == "mp4": - prompts_to_filenames[prompt] = filename - filename = os.path.join(self.args.output_dir, filename) - - if artifact_type == "image" and artifact_value: - logger.debug(f"Saving image to {filename}") - artifact_value.save(filename) - artifact_value = wandb.Image(filename) - elif artifact_type == "video" and artifact_value: - logger.debug(f"Saving video to {filename}") - # TODO: this should be configurable here as well as in validation runs where we call the pipeline that has `fps`. - export_to_video(artifact_value, filename, fps=frame_rate) - artifact_value = wandb.Video(filename, caption=prompt) - - all_processes_artifacts.append(artifact_value) - - all_artifacts = gather_object(all_processes_artifacts) - - if accelerator.is_main_process: - tracker_key = "final" if final_validation else "validation" - for tracker in accelerator.trackers: - if tracker.name == "wandb": - artifact_log_dict = {} - - image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)] - if len(image_artifacts) > 0: - artifact_log_dict["images"] = image_artifacts - video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)] - if len(video_artifacts) > 0: - artifact_log_dict["videos"] = video_artifacts - tracker.log({tracker_key: artifact_log_dict}, step=step) - - if self.args.push_to_hub and final_validation: - video_filenames = list(prompts_to_filenames.values()) - prompts = list(prompts_to_filenames.keys()) - save_model_card( - args=self.args, - repo_id=self.state.repo_id, - videos=video_filenames, - validation_prompts=prompts, - ) - - # Remove all hooks that might have been added during pipeline initialization to the models - pipeline.remove_all_hooks() - del pipeline - - accelerator.wait_for_everyone() - - free_memory() - memory_statistics = get_memory_statistics() - logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") - torch.cuda.reset_peak_memory_stats(accelerator.device) - - if not final_validation: - self.transformer.train() - - def evaluate(self) -> None: - raise NotImplementedError("Evaluation has not been implemented yet.") - - def _init_distributed(self) -> None: - logging_dir = Path(self.args.output_dir, self.args.logging_dir) - project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir) - ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - init_process_group_kwargs = InitProcessGroupKwargs( - backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout) - ) - report_to = None 
if self.args.report_to.lower() == "none" else self.args.report_to - - accelerator = Accelerator( - project_config=project_config, - gradient_accumulation_steps=self.args.gradient_accumulation_steps, - log_with=report_to, - kwargs_handlers=[ddp_kwargs, init_process_group_kwargs], - ) - - # Disable AMP for MPS. - if torch.backends.mps.is_available(): - accelerator.native_amp = False - - self.state.accelerator = accelerator - - if self.args.seed is not None: - self.state.seed = self.args.seed - set_seed(self.args.seed) - - def _init_logging(self) -> None: - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=FINETRAINERS_LOG_LEVEL, - ) - if self.state.accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - logger.info("Initialized FineTrainers") - logger.info(self.state.accelerator.state, main_process_only=False) - - def _init_directories_and_repositories(self) -> None: - if self.state.accelerator.is_main_process: - self.args.output_dir = Path(self.args.output_dir) - self.args.output_dir.mkdir(parents=True, exist_ok=True) - self.state.output_dir = Path(self.args.output_dir) - - if self.args.push_to_hub: - repo_id = self.args.hub_model_id or Path(self.args.output_dir).name - self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id - - def _init_config_options(self) -> None: - # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices - if self.args.allow_tf32 and torch.cuda.is_available(): - torch.backends.cuda.matmul.allow_tf32 = True - - def _move_components_to_device(self): - if self.text_encoder is not None: - self.text_encoder = self.text_encoder.to(self.state.accelerator.device) - if self.text_encoder_2 is not None: - self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device) - if self.text_encoder_3 is not None: - self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device) - if self.transformer is not None: - self.transformer = self.transformer.to(self.state.accelerator.device) - if self.unet is not None: - self.unet = self.unet.to(self.state.accelerator.device) - if self.vae is not None: - self.vae = self.vae.to(self.state.accelerator.device) - - def _get_load_components_kwargs(self) -> Dict[str, Any]: - load_component_kwargs = { - "text_encoder_dtype": self.args.text_encoder_dtype, - "text_encoder_2_dtype": self.args.text_encoder_2_dtype, - "text_encoder_3_dtype": self.args.text_encoder_3_dtype, - "transformer_dtype": self.args.transformer_dtype, - "vae_dtype": self.args.vae_dtype, - "shift": self.args.flow_shift, - "revision": self.args.revision, - "cache_dir": self.args.cache_dir, - } - if self.args.pretrained_model_name_or_path is not None: - load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path - return load_component_kwargs - - def _set_components(self, components: Dict[str, Any]) -> None: - # Set models - self.tokenizer = components.get("tokenizer", self.tokenizer) - self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2) - self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3) - self.text_encoder = components.get("text_encoder", self.text_encoder) - self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2) - 
self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3) - self.transformer = components.get("transformer", self.transformer) - self.unet = components.get("unet", self.unet) - self.vae = components.get("vae", self.vae) - self.scheduler = components.get("scheduler", self.scheduler) - - # Set configs - self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config - self.vae_config = self.vae.config if self.vae is not None else self.vae_config - - def _delete_components(self) -> None: - self.tokenizer = None - self.tokenizer_2 = None - self.tokenizer_3 = None - self.text_encoder = None - self.text_encoder_2 = None - self.text_encoder_3 = None - self.transformer = None - self.unet = None - self.vae = None - self.scheduler = None - free_memory() - torch.cuda.synchronize(self.state.accelerator.device) - - def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline: - accelerator = self.state.accelerator - if not final_validation: - pipeline = self.model_config["initialize_pipeline"]( - model_id=self.args.pretrained_model_name_or_path, - tokenizer=self.tokenizer, - text_encoder=self.text_encoder, - tokenizer_2=self.tokenizer_2, - text_encoder_2=self.text_encoder_2, - transformer=unwrap_model(accelerator, self.transformer), - vae=self.vae, - device=accelerator.device, - revision=self.args.revision, - cache_dir=self.args.cache_dir, - enable_slicing=self.args.enable_slicing, - enable_tiling=self.args.enable_tiling, - enable_model_cpu_offload=self.args.enable_model_cpu_offload, - is_training=True, - ) - else: - self._delete_components() - - # Load the transformer weights from the final checkpoint if performing full-finetune - transformer = None - if self.args.training_type == "full-finetune": - transformer = self.model_config["load_diffusion_models"](model_id=self.args.output_dir)["transformer"] - - pipeline = self.model_config["initialize_pipeline"]( - model_id=self.args.pretrained_model_name_or_path, - transformer=transformer, - device=accelerator.device, - revision=self.args.revision, - cache_dir=self.args.cache_dir, - enable_slicing=self.args.enable_slicing, - enable_tiling=self.args.enable_tiling, - enable_model_cpu_offload=self.args.enable_model_cpu_offload, - is_training=False, - ) - - # Load the LoRA weights if performing LoRA finetuning - if self.args.training_type == "lora": - pipeline.load_lora_weights(self.args.output_dir) - - return pipeline - - def _disable_grad_for_components(self, components: List[torch.nn.Module]): - for component in components: - if component is not None: - component.requires_grad_(False) - - def _enable_grad_for_components(self, components: List[torch.nn.Module]): - for component in components: - if component is not None: - component.requires_grad_(True) - - def _get_training_info(self) -> dict: - args = self.args.to_dict() - - training_args = args.get("training_arguments", {}) - training_type = training_args.get("training_type", "") - - # LoRA/non-LoRA stuff. - if training_type == "full-finetune": - filtered_training_args = { - k: v for k, v in training_args.items() if k not in {"rank", "lora_alpha", "target_modules"} - } - else: - filtered_training_args = training_args - - # Diffusion/flow stuff. 
- diffusion_args = args.get("diffusion_arguments", {}) - scheduler_name = self.scheduler.__class__.__name__ - if scheduler_name != "FlowMatchEulerDiscreteScheduler": - filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k} - else: - filtered_diffusion_args = diffusion_args - - # Rest of the stuff. - updated_training_info = args.copy() - updated_training_info["training_arguments"] = filtered_training_args - updated_training_info["diffusion_arguments"] = filtered_diffusion_args - return updated_training_info diff --git a/finetrainers/trainer/__init__.py b/finetrainers/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba65c53fa75115084ceec6dabc2667a8a5d6a29 --- /dev/null +++ b/finetrainers/trainer/__init__.py @@ -0,0 +1 @@ +from .sft_trainer.trainer import SFTTrainer diff --git a/finetrainers/trainer/config_utils.py b/finetrainers/trainer/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c354a30bac60265868c3ba8ef5313c12b7fc224 --- /dev/null +++ b/finetrainers/trainer/config_utils.py @@ -0,0 +1,17 @@ +import argparse +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from ..args import BaseArgs + + +class ConfigMixin: + def add_args(self, parser: argparse.ArgumentParser): + raise NotImplementedError("ConfigMixin::add_args should be implemented by subclasses.") + + def validate_args(self, args: "BaseArgs"): + raise NotImplementedError("ConfigMixin::validate_args should be implemented by subclasses.") + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + raise NotImplementedError("ConfigMixin::map_args should be implemented by subclasses.") diff --git a/finetrainers/trainer/sft_trainer/config.py b/finetrainers/trainer/sft_trainer/config.py new file mode 100644 index 0000000000000000000000000000000000000000..539426c7c9509c1b417d3471cc4391436f532cb7 --- /dev/null +++ b/finetrainers/trainer/sft_trainer/config.py @@ -0,0 +1,58 @@ +import argparse +from typing import TYPE_CHECKING, List, Union + +from ..config_utils import ConfigMixin + + +if TYPE_CHECKING: + from ...args import BaseArgs + + +class SFTLowRankConfig(ConfigMixin): + r""" + Configuration class for SFT low rank training. + + Args: + rank (int): + Rank of the low rank approximation. + lora_alpha (int): + The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices. + target_modules (`str` or `List[str]`): + Target modules for the low rank approximation. Can be a regex string or a list of regex strings. + """ + + rank: int = 64 + lora_alpha: int = 64 + target_modules: Union[str, List[str]] = "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)" + + def add_args(self, parser: argparse.ArgumentParser): + parser.add_argument("--rank", type=int, default=64) + parser.add_argument("--lora_alpha", type=int, default=64) + parser.add_argument( + "--target_modules", + type=str, + nargs="+", + default=["(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)"], + ) + + def validate_args(self, args: "BaseArgs"): + assert self.rank > 0, "Rank must be a positive integer." + assert self.lora_alpha > 0, "lora_alpha must be a positive integer."
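# Illustrative usage sketch (an assumption about how finetrainers wires this up, not part of
# this diff): add_args() registers the CLI flags, map_args() copies the parsed values onto the
# shared BaseArgs object, and validate_args() sanity-checks the configuration. With rank=32 and
# lora_alpha=32, the LoRA scaling factor lora_alpha / rank works out to 1.0.
import argparse

from finetrainers.args import BaseArgs
from finetrainers.trainer.sft_trainer.config import SFTLowRankConfig

config = SFTLowRankConfig()
parser = argparse.ArgumentParser()
config.add_args(parser)  # registers --rank, --lora_alpha and --target_modules
namespace = parser.parse_args(["--rank", "32", "--lora_alpha", "32"])

mapped_args = BaseArgs()  # assumed here to be constructible with defaults
config.map_args(namespace, mapped_args)  # copies rank / lora_alpha / target_modules onto mapped_args
config.validate_args(mapped_args)  # asserts that the config's rank and lora_alpha are positive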
+ + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + mapped_args.rank = argparse_args.rank + mapped_args.lora_alpha = argparse_args.lora_alpha + mapped_args.target_modules = ( + argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules + ) + + +class SFTFullRankConfig(ConfigMixin): + def add_args(self, parser: argparse.ArgumentParser): + pass + + def validate_args(self, args: "BaseArgs"): + pass + + def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"): + pass diff --git a/finetrainers/trainer/sft_trainer/trainer.py b/finetrainers/trainer/sft_trainer/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e9523730d7e7c9af666b618dbfd05858122ff96b --- /dev/null +++ b/finetrainers/trainer/sft_trainer/trainer.py @@ -0,0 +1,934 @@ +import functools +import json +import math +import os +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union + +import datasets.distributed +import diffusers +import torch +import torch.backends +import transformers +import wandb +from diffusers import DiffusionPipeline +from diffusers.hooks import apply_layerwise_casting +from diffusers.training_utils import cast_training_params +from diffusers.utils import export_to_video +from huggingface_hub import create_repo, upload_folder +from peft import LoraConfig, get_peft_model_state_dict +from tqdm import tqdm + +from ... import data, logging, optimizer, parallel, patches, utils +from ...config import TrainingType +from ...state import State, TrainState + + +if TYPE_CHECKING: + from ...args import BaseArgs + from ...models import ModelSpecification + + +logger = logging.get_logger() + + +class SFTTrainer: + def __init__(self, args: "BaseArgs", model_specification: "ModelSpecification") -> None: + self.args = args + self.state = State() + self.state.train_state = TrainState() + + # Tokenizers + self.tokenizer = None + self.tokenizer_2 = None + self.tokenizer_3 = None + + # Text encoders + self.text_encoder = None + self.text_encoder_2 = None + self.text_encoder_3 = None + + # Denoisers + self.transformer = None + self.unet = None + + # Autoencoders + self.vae = None + + # Scheduler + self.scheduler = None + + # Optimizer & LR scheduler + self.optimizer = None + self.lr_scheduler = None + + # Checkpoint manager + self.checkpointer = None + + self._init_distributed() + self._init_config_options() + + # Perform any patches that might be necessary for training to work as expected + patches.perform_patches_for_training(self.args, self.state.parallel_backend) + + self.model_specification = model_specification + + def run(self) -> None: + try: + self._prepare_models() + self._prepare_trainable_parameters() + self._prepare_for_training() + self._prepare_dataset() + self._prepare_checkpointing() + self._train() + # trainer._evaluate() + except Exception as e: + logger.error(f"Error during training: {e}") + self.state.parallel_backend.destroy() + raise e + + def _prepare_models(self) -> None: + logger.info("Initializing models") + + diffusion_components = self.model_specification.load_diffusion_models() + self._set_components(diffusion_components) + + if self.state.parallel_backend.pipeline_parallel_enabled: + raise NotImplementedError( + "Pipeline parallelism is not supported yet. This will be supported in the future." 
+ ) + + def _prepare_trainable_parameters(self) -> None: + logger.info("Initializing trainable parameters") + + parallel_backend = self.state.parallel_backend + + if self.args.training_type == TrainingType.FULL_FINETUNE: + logger.info("Finetuning transformer with no additional parameters") + utils.set_requires_grad([self.transformer], True) + else: + logger.info("Finetuning transformer with PEFT parameters") + utils.set_requires_grad([self.transformer], False) + + # Layerwise upcasting must be applied before adding the LoRA adapter. + # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on + # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly. + if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules: + apply_layerwise_casting( + self.transformer, + storage_dtype=self.args.layerwise_upcasting_storage_dtype, + compute_dtype=self.args.transformer_dtype, + skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern, + non_blocking=True, + ) + + transformer_lora_config = None + if self.args.training_type == TrainingType.LORA: + transformer_lora_config = LoraConfig( + r=self.args.rank, + lora_alpha=self.args.lora_alpha, + init_lora_weights=True, + target_modules=self.args.target_modules, + ) + self.transformer.add_adapter(transformer_lora_config) + + # # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32 + # # even if layerwise upcasting. Would be nice to have a test as well + # self.register_saving_loading_hooks(transformer_lora_config) + + # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all + # parameters to be of the same dtype. + if self.args.training_type == TrainingType.LORA and not parallel_backend.data_sharding_enabled: + cast_training_params([self.transformer], dtype=torch.float32) + + def _prepare_for_training(self) -> None: + # 1. Apply parallelism + parallel_backend = self.state.parallel_backend + world_mesh = parallel_backend.get_mesh() + model_specification = self.model_specification + + if parallel_backend.context_parallel_enabled: + raise NotImplementedError( + "Context parallelism is not supported yet. This will be supported in the future." 
+ ) + + if parallel_backend.tensor_parallel_enabled: + # TODO(aryan): handle fp8 from TorchAO here + model_specification.apply_tensor_parallel( + backend=parallel.ParallelBackendEnum.PTD, + device_mesh=parallel_backend.get_mesh()["tp"], + transformer=self.transformer, + ) + + # Enable gradient checkpointing + if self.args.gradient_checkpointing: + # TODO(aryan): support other checkpointing types + utils.apply_activation_checkpointing(self.transformer, checkpointing_type="full") + + # Enable DDP, FSDP or HSDP + if parallel_backend.data_sharding_enabled: + # TODO(aryan): remove this when supported + if self.args.parallel_backend == "accelerate": + raise NotImplementedError("Data sharding is not supported with Accelerate yet.") + + if parallel_backend.data_replication_enabled: + logger.info("Applying HSDP to the model") + else: + logger.info("Applying FSDP to the model") + + # Apply FSDP or HSDP + if parallel_backend.data_replication_enabled or parallel_backend.context_parallel_enabled: + dp_mesh_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_names = ("dp_shard_cp",) + + parallel.apply_fsdp2_ptd( + model=self.transformer, + dp_mesh=world_mesh[dp_mesh_names], + param_dtype=self.args.transformer_dtype, + reduce_dtype=torch.float32, + output_dtype=None, + pp_enabled=parallel_backend.pipeline_parallel_enabled, + cpu_offload=False, # TODO(aryan): needs to be tested and allowed for enabling later + ) + elif parallel_backend.data_replication_enabled: + logger.info("Applying DDP to the model") + + if world_mesh.ndim > 1: + raise ValueError("DDP not supported for > 1D parallelism") + + parallel_backend.apply_ddp(self.transformer, world_mesh) + + self._move_components_to_device() + + # 2. Prepare optimizer and lr scheduler + # For training LoRAs, we can be a little more optimal. Currently, the OptimizerWrapper only accepts torch::nn::Module. + # This causes us to loop over all the parameters (even ones that don't require gradients, as in LoRA) at each optimizer + # step. This is OK (see https://github.com/pytorch/pytorch/blob/2f40f789dafeaa62c4e4b90dbf4a900ff6da2ca4/torch/optim/sgd.py#L85-L99) + # but can be optimized a bit by maybe creating a simple wrapper module encompassing the actual parameters that require + # gradients. TODO(aryan): look into it in the future. + model_parts = [self.transformer] + self.state.num_trainable_parameters = sum( + p.numel() for m in model_parts for p in m.parameters() if p.requires_grad + ) + + # Setup distributed optimizer and lr scheduler + logger.info("Initializing optimizer and lr scheduler") + self.state.train_state = TrainState() + self.optimizer = optimizer.get_optimizer( + parallel_backend=self.args.parallel_backend, + name=self.args.optimizer, + model_parts=model_parts, + learning_rate=self.args.lr, + beta1=self.args.beta1, + beta2=self.args.beta2, + beta3=self.args.beta3, + epsilon=self.args.epsilon, + weight_decay=self.args.weight_decay, + fused=False, + ) + self.lr_scheduler = optimizer.get_lr_scheduler( + parallel_backend=self.args.parallel_backend, + name=self.args.lr_scheduler, + optimizer=self.optimizer, + num_warmup_steps=self.args.lr_warmup_steps, + num_training_steps=self.args.train_steps, + # TODO(aryan): handle last_epoch + ) + self.optimizer, self.lr_scheduler = parallel_backend.prepare_optimizer(self.optimizer, self.lr_scheduler) + + # 3. 
Initialize trackers, directories and repositories + self._init_logging() + self._init_trackers() + self._init_directories_and_repositories() + + def _prepare_dataset(self) -> None: + logger.info("Initializing dataset and dataloader") + + with open(self.args.dataset_config, "r") as file: + dataset_configs = json.load(file)["datasets"] + logger.info(f"Training configured to use {len(dataset_configs)} datasets") + + datasets = [] + for config in dataset_configs: + data_root = config.pop("data_root", None) + dataset_file = config.pop("dataset_file", None) + dataset_type = config.pop("dataset_type") + + if data_root is not None and dataset_file is not None: + raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.") + + dataset_name_or_root = data_root or dataset_file + dataset = data.initialize_dataset(dataset_name_or_root, dataset_type, streaming=True, infinite=True) + + if not dataset._precomputable_once and self.args.precomputation_once: + raise ValueError( + f"Dataset {dataset_name_or_root} does not support precomputing all embeddings at once." + ) + + logger.info(f"Initialized dataset: {dataset_name_or_root}") + dataset = self.state.parallel_backend.prepare_dataset(dataset) + dataset = data.wrap_iterable_dataset_for_preprocessing(dataset, dataset_type, config) + datasets.append(dataset) + + dataset = data.combine_datasets(datasets, buffer_size=self.args.dataset_shuffle_buffer_size, shuffle=True) + dataloader = self.state.parallel_backend.prepare_dataloader( + dataset, batch_size=1, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.pin_memory + ) + + self.dataset = dataset + self.dataloader = dataloader + + def _prepare_checkpointing(self) -> None: + parallel_backend = self.state.parallel_backend + + def save_model_hook(state_dict: Dict[str, Any]) -> None: + if parallel_backend.is_main_process: + if self.args.training_type == TrainingType.LORA: + state_dict = get_peft_model_state_dict(self.transformer, state_dict) + self.model_specification._save_lora_weights(self.args.output_dir, state_dict, self.scheduler) + elif self.args.training_type == TrainingType.FULL_FINETUNE: + self.model_specification._save_model( + self.args.output_dir, self.transformer, state_dict, self.scheduler + ) + parallel_backend.wait_for_everyone() + + enable_state_checkpointing = self.args.checkpointing_steps > 0 + self.checkpointer = utils.PTDCheckpointManager( + dataloader=self.dataloader, + model_parts=[self.transformer], + optimizers=self.optimizer, + schedulers=self.lr_scheduler, + states={"train_state": self.state.train_state}, + checkpointing_steps=self.args.checkpointing_steps, + checkpointing_limit=self.args.checkpointing_limit, + output_dir=self.args.output_dir, + enable=enable_state_checkpointing, + _callback_fn=save_model_hook, + ) + + resume_from_checkpoint = self.args.resume_from_checkpoint + if resume_from_checkpoint == "latest": + resume_from_checkpoint = -1 + if resume_from_checkpoint is not None: + self.checkpointer.load(resume_from_checkpoint) + + def _train(self) -> None: + logger.info("Starting training") + + parallel_backend = self.state.parallel_backend + train_state = self.state.train_state + device = parallel_backend.device + + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") + + global_batch_size = self.args.batch_size * parallel_backend._dp_degree + info = { + "trainable parameters": self.state.num_trainable_parameters, + "train steps": 
self.args.train_steps, + "per-replica batch size": self.args.batch_size, + "global batch size": global_batch_size, + "gradient accumulation steps": self.args.gradient_accumulation_steps, + } + logger.info(f"Training configuration: {json.dumps(info, indent=4)}") + + progress_bar = tqdm( + range(0, self.args.train_steps), + initial=train_state.step, + desc="Training steps", + disable=not parallel_backend.is_local_main_process, + ) + + generator = torch.Generator(device=device) + if self.args.seed is not None: + generator = generator.manual_seed(self.args.seed) + self.state.generator = generator + + patch_size = 1 + if ( + getattr(self.transformer.config, "patch_size", None) is not None + and getattr(self.transformer.config, "patch_size_t", None) is not None + ): + patch_size = self.transformer.config.patch_size * self.transformer.config.patch_size_t + elif isinstance(getattr(self.transformer.config, "patch_size", None), int): + patch_size = self.transformer.config.patch_size + elif isinstance(getattr(self.transformer.config, "patch_size", None), (list, tuple)): + patch_size = math.prod(self.transformer.config.patch_size) + + scheduler_sigmas = utils.get_scheduler_sigmas(self.scheduler) + scheduler_sigmas = ( + scheduler_sigmas.to(device=device, dtype=torch.float32) if scheduler_sigmas is not None else None + ) + scheduler_alphas = utils.get_scheduler_alphas(self.scheduler) + scheduler_alphas = ( + scheduler_alphas.to(device=device, dtype=torch.float32) if scheduler_alphas is not None else None + ) + timesteps_buffer = [] + + self.transformer.train() + data_iterator = iter(self.dataloader) + + preprocessor = data.DistributedDataPreprocessor( + rank=parallel_backend.rank, + num_items=self.args.precomputation_items, + processor_fn={ + "condition": self.model_specification.prepare_conditions, + "latent": functools.partial( + self.model_specification.prepare_latents, compute_posterior=not self.args.precomputation_once + ), + }, + save_dir=self.args.precomputation_dir, + ) + precomputed_condition_iterator: Iterable[Dict[str, Any]] = None + precomputed_latent_iterator: Iterable[Dict[str, Any]] = None + sampler = data.ResolutionSampler( + batch_size=self.args.batch_size, dim_keys=self.model_specification._resolution_dim_keys + ) + requires_gradient_step = True + accumulated_loss = 0.0 + + while ( + train_state.step < self.args.train_steps and train_state.observed_data_samples < self.args.max_data_samples + ): + # 1. Load & preprocess data if required + if preprocessor.requires_data: + # TODO(aryan): We should do the following here: + # - Force checkpoint the trainable models, optimizers, schedulers and train state + # - Do the precomputation + # - Load the checkpointed models, optimizers, schedulers and train state back, and continue training + # This way we can be more memory efficient again, since the latest rewrite of precomputation removed + # this logic. + precomputed_condition_iterator, precomputed_latent_iterator = self._prepare_data( + preprocessor, data_iterator + ) + + # 2. Prepare batch + try: + condition_item = next(precomputed_condition_iterator) + latent_item = next(precomputed_latent_iterator) + sampler.consume(condition_item, latent_item) + except StopIteration: + if requires_gradient_step: + self.optimizer.step() + self.lr_scheduler.step() + requires_gradient_step = False + logger.info("Data exhausted. 
Exiting training loop.") + break + + if sampler.is_ready: + condition_batch, latent_batch = sampler.get_batch() + condition_model_conditions = self.model_specification.collate_conditions(condition_batch) + latent_model_conditions = self.model_specification.collate_latents(latent_batch) + else: + continue + + train_state.step += 1 + train_state.observed_data_samples += self.args.batch_size * parallel_backend._dp_degree + + lmc_latents = latent_model_conditions["latents"] + train_state.observed_num_tokens += math.prod(lmc_latents.shape[:-1]) // patch_size + + logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})") + + utils.align_device_and_dtype(latent_model_conditions, device, self.args.transformer_dtype) + utils.align_device_and_dtype(condition_model_conditions, device, self.args.transformer_dtype) + latent_model_conditions = utils.make_contiguous(latent_model_conditions) + condition_model_conditions = utils.make_contiguous(condition_model_conditions) + + # 3. Forward pass + sigmas = utils.prepare_sigmas( + scheduler=self.scheduler, + sigmas=scheduler_sigmas, + batch_size=self.args.batch_size, + num_train_timesteps=self.scheduler.config.num_train_timesteps, + flow_weighting_scheme=self.args.flow_weighting_scheme, + flow_logit_mean=self.args.flow_logit_mean, + flow_logit_std=self.args.flow_logit_std, + flow_mode_scale=self.args.flow_mode_scale, + device=device, + generator=self.state.generator, + ) + sigmas = utils.expand_tensor_dims(sigmas, latent_model_conditions["latents"].ndim) + + pred, target, sigmas = self.model_specification.forward( + transformer=self.transformer, + scheduler=self.scheduler, + condition_model_conditions=condition_model_conditions, + latent_model_conditions=latent_model_conditions, + sigmas=sigmas, + compute_posterior=not self.args.precomputation_once, + ) + + timesteps = (sigmas * 1000.0).long() + weights = utils.prepare_loss_weights( + scheduler=self.scheduler, + alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None, + sigmas=sigmas, + flow_weighting_scheme=self.args.flow_weighting_scheme, + ) + weights = utils.expand_tensor_dims(weights, pred.ndim) + + # 4. Compute loss & backward pass + loss = weights.float() * (pred.float() - target.float()).pow(2) + # Average loss across all but batch dimension + loss = loss.mean(list(range(1, loss.ndim))) + # Average loss across batch dimension + loss = loss.mean() + if self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + loss.backward() + accumulated_loss += loss.detach().item() + requires_gradient_step = True + + # 5. Clip gradients + model_parts = [self.transformer] + grad_norm = utils.torch._clip_grad_norm_while_handling_failing_dtensor_cases( + [p for m in model_parts for p in m.parameters()], + self.args.max_grad_norm, + foreach=True, + pp_mesh=parallel_backend.get_mesh("pp") if parallel_backend.pipeline_parallel_enabled else None, + ) + + # 6. Step optimizer & log metrics + logs = {} + + if train_state.step % self.args.gradient_accumulation_steps == 0: + # TODO(aryan): revisit no_sync() for FSDP + # TODO(aryan): average the gradients for accumulation? 
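# Worked example (illustrative numbers, not taken from this diff) of the accumulation
# bookkeeping above: every micro-batch loss is pre-divided by gradient_accumulation_steps
# before backward(), so the gradients consumed by optimizer.step() below approximate the mean
# over the accumulated micro-batches rather than their sum.
per_replica_batch_size = 2  # assumed --batch_size
dp_degree = 4  # assumed number of data-parallel replicas
gradient_accumulation_steps = 8  # assumed accumulation window
samples_per_optimizer_update = per_replica_batch_size * dp_degree * gradient_accumulation_steps
assert samples_per_optimizer_update == 64  # 2 * 4 * 8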
+ self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + if grad_norm is not None: + logs["grad_norm"] = grad_norm if isinstance(grad_norm, float) else grad_norm.detach().item() + if ( + parallel_backend.data_replication_enabled + or parallel_backend.data_sharding_enabled + or parallel_backend.context_parallel_enabled + ): + dp_cp_mesh = parallel_backend.get_mesh("dp_cp") + global_avg_loss, global_max_loss = ( + parallel.dist_mean(torch.tensor([accumulated_loss], device=device), dp_cp_mesh), + parallel.dist_max(torch.tensor([accumulated_loss], device=device), dp_cp_mesh), + ) + else: + global_avg_loss = global_max_loss = accumulated_loss + + logs["global_avg_loss"] = global_avg_loss + logs["global_max_loss"] = global_max_loss + train_state.global_avg_losses.append(global_avg_loss) + train_state.global_max_losses.append(global_max_loss) + accumulated_loss = 0.0 + requires_gradient_step = False + + progress_bar.update(1) + progress_bar.set_postfix(logs) + + timesteps_buffer.extend([(train_state.step, t) for t in timesteps.detach().cpu().numpy().tolist()]) + + if train_state.step % self.args.logging_steps == 0: + # TODO(aryan): handle non-SchedulerWrapper schedulers (probably not required eventually) since they might not be dicts + # TODO(aryan): causes NCCL hang for some reason. look into later + # logs.update(self.lr_scheduler.get_last_lr()) + + # timesteps_table = wandb.Table(data=timesteps_buffer, columns=["step", "timesteps"]) + # logs["timesteps"] = wandb.plot.scatter( + # timesteps_table, "step", "timesteps", title="Timesteps distribution" + # ) + timesteps_buffer = [] + + logs["observed_data_samples"] = train_state.observed_data_samples + logs["observed_num_tokens"] = train_state.observed_num_tokens + + parallel_backend.log(logs, step=train_state.step) + train_state.log_steps.append(train_state.step) + + # 7. Save checkpoint if required + self.checkpointer.save( + step=train_state.step, _device=device, _is_main_process=parallel_backend.is_main_process + ) + + # 8. Perform validation if required + if train_state.step % self.args.validation_steps == 0: + self._validate(step=train_state.step, final_validation=False) + + # 9. Final checkpoint, validation & cleanup + self.checkpointer.save( + train_state.step, force=True, _device=device, _is_main_process=parallel_backend.is_main_process + ) + parallel_backend.wait_for_everyone() + self._validate(step=train_state.step, final_validation=True) + + self._delete_components() + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}") + + # 10. Upload artifacts to hub + if parallel_backend.is_main_process and self.args.push_to_hub: + upload_folder( + repo_id=self.state.repo_id, + folder_path=self.args.output_dir, + ignore_patterns=[f"{self.checkpointer._prefix}_*"], + ) + + parallel_backend.destroy() + + def _validate(self, step: int, final_validation: bool = False) -> None: + if self.args.validation_dataset_file is None: + return + + logger.info("Starting validation") + + # 1. 
Load validation dataset + parallel_backend = self.state.parallel_backend + dp_mesh = parallel_backend.get_mesh("dp_replicate") + + if dp_mesh is not None: + local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size() + else: + local_rank, dp_world_size = 0, 1 + + dataset = data.ValidationDataset(self.args.validation_dataset_file) + dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, local_rank, dp_world_size) + validation_dataloader = data.DPDataLoader( + local_rank, + dataset, + batch_size=1, + num_workers=self.args.dataloader_num_workers, + collate_fn=lambda items: items, + ) + data_iterator = iter(validation_dataloader) + main_process_prompts_to_filenames = {} # Used to save model card + all_processes_artifacts = [] # Used to gather artifacts from all processes + + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}") + + seed = self.args.seed if self.args.seed is not None else 0 + generator = torch.Generator(device=parallel_backend.device).manual_seed(seed) + pipeline = self._init_pipeline(final_validation=final_validation) + + # 2. Run validation + # TODO(aryan): when running validation with FSDP, if the number of data points is not divisible by dp_shards, we + # will hang indefinitely. Either pad the dataset or raise an error early on during initialization if the dataset + # size is not divisible by dp_shards. + self.transformer.eval() + while True: + validation_data = next(data_iterator, None) + if validation_data is None: + break + + logger.debug( + f"Validating {validation_data=} on rank={parallel_backend.rank}.", local_main_process_only=False + ) + + validation_data = validation_data[0] + validation_artifacts = self.model_specification.validation( + pipeline=pipeline, generator=generator, **validation_data + ) + + PROMPT = validation_data["prompt"] + IMAGE = validation_data.get("image", None) + VIDEO = validation_data.get("video", None) + EXPORT_FPS = validation_data.get("export_fps", 30) + + # 2.1. If there are any initial images or videos, they will be logged to keep track of them as + # conditioning for generation. + prompt_filename = utils.string_to_filename(PROMPT)[:25] + artifacts = { + "input_image": data.ImageArtifact(value=IMAGE), + "input_video": data.VideoArtifact(value=VIDEO), + } + + # 2.2. Track the artifacts generated from validation + for i, validation_artifact in enumerate(validation_artifacts): + if validation_artifact.value is None: + continue + artifacts.update({f"artifact_{i}": validation_artifact}) + + # 2.3. Save the artifacts to the output directory and create appropriate logging objects + # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited. 
+ for index, (key, artifact) in enumerate(list(artifacts.items())): + assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact)) + filename = "validation-" if not final_validation else "final-" + filename += f"{step}-{parallel_backend.rank}-{index}-{prompt_filename}.{artifact.file_extension}" + output_filename = os.path.join(self.args.output_dir, filename) + + if parallel_backend.is_main_process and artifact.file_extension == "mp4": + main_process_prompts_to_filenames[PROMPT] = filename + + caption = f"{PROMPT} | (filename: {output_filename})" + if artifact.type == "image" and artifact.value is not None: + logger.debug( + f"Saving image from rank={parallel_backend.rank} to {output_filename}", + local_main_process_only=False, + ) + artifact.value.save(output_filename) + all_processes_artifacts.append(wandb.Image(output_filename, caption=caption)) + elif artifact.type == "video" and artifact.value is not None: + logger.debug( + f"Saving video from rank={parallel_backend.rank} to {output_filename}", + local_main_process_only=False, + ) + export_to_video(artifact.value, output_filename, fps=EXPORT_FPS) + all_processes_artifacts.append(wandb.Video(output_filename, caption=caption)) + + # 3. Cleanup & log artifacts + parallel_backend.wait_for_everyone() + + # Remove all hooks that might have been added during pipeline initialization to the models + pipeline.remove_all_hooks() + del pipeline + + utils.free_memory() + memory_statistics = utils.get_memory_statistics() + logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}") + torch.cuda.reset_peak_memory_stats(parallel_backend.device) + + # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts. + # TODO(aryan): probably should only all gather from dp mesh process group + all_artifacts = [None] * parallel_backend.world_size + torch.distributed.all_gather_object(all_artifacts, all_processes_artifacts) + all_artifacts = [artifact for artifacts in all_artifacts for artifact in artifacts] + + if parallel_backend.is_main_process: + tracker_key = "final" if final_validation else "validation" + artifact_log_dict = {} + + image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)] + if len(image_artifacts) > 0: + artifact_log_dict["images"] = image_artifacts + video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)] + if len(video_artifacts) > 0: + artifact_log_dict["videos"] = video_artifacts + parallel_backend.log({tracker_key: artifact_log_dict}, step=step) + + if self.args.push_to_hub and final_validation: + video_filenames = list(main_process_prompts_to_filenames.values()) + prompts = list(main_process_prompts_to_filenames.keys()) + utils.save_model_card( + args=self.args, repo_id=self.state.repo_id, videos=video_filenames, validation_prompts=prompts + ) + + parallel_backend.wait_for_everyone() + if not final_validation: + self.transformer.train() + + def _evaluate(self) -> None: + raise NotImplementedError("Evaluation has not been implemented yet.") + + def _init_distributed(self) -> None: + # TODO: Accelerate disables native_amp for MPS. Probably need to do the same with implementation. 
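# The initialization below reads WORLD_SIZE from the environment, so it assumes the process was
# started by a launcher such as `torchrun --standalone --nproc_per_node=8 train.py ...`, which
# exports WORLD_SIZE, RANK, LOCAL_RANK, MASTER_ADDR and MASTER_PORT for every process. A minimal
# pre-flight sketch; the constraint that the parallel degrees multiply out to WORLD_SIZE is an
# assumption about how the device mesh is constructed, not something stated in this diff.
import os

world_size = int(os.environ.get("WORLD_SIZE", "1"))
pp, dp, dp_shards, cp, tp = 1, 2, 4, 1, 1  # assumed degrees for an 8-process run
if world_size != pp * dp * dp_shards * cp * tp:
    raise ValueError(f"Parallel degrees multiply to {pp * dp * dp_shards * cp * tp}, but WORLD_SIZE is {world_size}")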
+ world_size = int(os.environ["WORLD_SIZE"]) + + # TODO(aryan): handle other backends + backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend) + self.state.parallel_backend = backend_cls( + world_size=world_size, + pp_degree=self.args.pp_degree, + dp_degree=self.args.dp_degree, + dp_shards=self.args.dp_shards, + cp_degree=self.args.cp_degree, + tp_degree=self.args.tp_degree, + backend="nccl", + timeout=self.args.init_timeout, + logging_dir=self.args.logging_dir, + output_dir=self.args.output_dir, + gradient_accumulation_steps=self.args.gradient_accumulation_steps, + ) + + if self.args.seed is not None: + world_mesh = self.state.parallel_backend.get_mesh() + utils.enable_determinism(self.args.seed, world_mesh) + + def _init_logging(self) -> None: + transformers_log_level = transformers.utils.logging.set_verbosity_error + diffusers_log_level = diffusers.utils.logging.set_verbosity_error + + if self.args.verbose == 0: + if self.state.parallel_backend.is_local_main_process: + transformers_log_level = transformers.utils.logging.set_verbosity_warning + diffusers_log_level = diffusers.utils.logging.set_verbosity_warning + elif self.args.verbose == 1: + if self.state.parallel_backend.is_local_main_process: + transformers_log_level = transformers.utils.logging.set_verbosity_info + diffusers_log_level = diffusers.utils.logging.set_verbosity_info + elif self.args.verbose == 2: + if self.state.parallel_backend.is_local_main_process: + transformers_log_level = transformers.utils.logging.set_verbosity_debug + diffusers_log_level = diffusers.utils.logging.set_verbosity_debug + else: + transformers_log_level = transformers.utils.logging.set_verbosity_debug + diffusers_log_level = diffusers.utils.logging.set_verbosity_debug + + transformers_log_level() + diffusers_log_level() + + logging._set_parallel_backend(self.state.parallel_backend) + logger.info("Initialized FineTrainers") + + def _init_trackers(self) -> None: + # TODO(aryan): handle multiple trackers + trackers = ["wandb"] + experiment_name = self.args.tracker_name or "finetrainers-experiment" + self.state.parallel_backend.initialize_trackers( + trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir + ) + + def _init_directories_and_repositories(self) -> None: + if self.state.parallel_backend.is_main_process: + self.args.output_dir = Path(self.args.output_dir) + self.args.output_dir.mkdir(parents=True, exist_ok=True) + self.state.output_dir = Path(self.args.output_dir) + + if self.args.push_to_hub: + repo_id = self.args.hub_model_id or Path(self.args.output_dir).name + self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id + + def _init_config_options(self) -> None: + # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + def _move_components_to_device( + self, components: Optional[List[torch.nn.Module]] = None, device: Optional[Union[str, torch.device]] = None + ) -> None: + if device is None: + device = self.state.parallel_backend.device + if components is None: + components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.transformer, self.vae] + components = utils.get_non_null_items(components) + components = list(filter(lambda x: hasattr(x, "to"), components)) + for component in components: + 
component.to(device) + + def _set_components(self, components: Dict[str, Any]) -> None: + # fmt: off + component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"] + # fmt: on + + for component_name in component_names: + existing_component = getattr(self, component_name, None) + new_component = components.get(component_name, existing_component) + setattr(self, component_name, new_component) + + def _delete_components(self, component_names: Optional[List[str]] = None) -> None: + if component_names is None: + # fmt: off + component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"] + # fmt: on + + for component_name in component_names: + setattr(self, component_name, None) + + utils.free_memory() + utils.synchronize_device() + + def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline: + parallel_backend = self.state.parallel_backend + module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"] + + if not final_validation: + module_names.remove("transformer") + pipeline = self.model_specification.load_pipeline( + tokenizer=self.tokenizer, + tokenizer_2=self.tokenizer_2, + tokenizer_3=self.tokenizer_3, + text_encoder=self.text_encoder, + text_encoder_2=self.text_encoder_2, + text_encoder_3=self.text_encoder_3, + # TODO(aryan): handle unwrapping for compiled modules + # transformer=utils.unwrap_model(accelerator, self.transformer), + transformer=self.transformer, + vae=self.vae, + enable_slicing=self.args.enable_slicing, + enable_tiling=self.args.enable_tiling, + enable_model_cpu_offload=self.args.enable_model_cpu_offload, + training=True, + ) + else: + # TODO(aryan): this branch does not work yet, needs to be implemented + self._delete_components() + + # Load the transformer weights from the final checkpoint if performing full-finetune + transformer = None + if self.args.training_type == TrainingType.FULL_FINETUNE: + transformer = self.model_specification.load_diffusion_models()["transformer"] + + pipeline = self.model_specification.load_pipeline( + transformer=transformer, + enable_slicing=self.args.enable_slicing, + enable_tiling=self.args.enable_tiling, + enable_model_cpu_offload=self.args.enable_model_cpu_offload, + training=False, + device=parallel_backend.device, + ) + + # Load the LoRA weights if performing LoRA finetuning + if self.args.training_type == TrainingType.LORA: + pipeline.load_lora_weights(self.args.output_dir) + + components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names} + self._set_components(components) + self._move_components_to_device(list(components.values())) + return pipeline + + def _prepare_data(self, preprocessor: data.DistributedDataPreprocessor, data_iterator): + logger.info("Precomputed condition & latent data exhausted. 
Loading & preprocessing new data.") + if self.args.precomputation_once: + consume_fn = preprocessor.consume_once + else: + consume_fn = preprocessor.consume + + condition_components = self.model_specification.load_condition_models() + component_names = list(condition_components.keys()) + component_modules = list(condition_components.values()) + self._set_components(condition_components) + self._move_components_to_device(component_modules) + precomputed_condition_iterator = consume_fn( + "condition", + components=condition_components, + data_iterator=data_iterator, + generator=self.state.generator, + cache_samples=True, + ) + self._delete_components(component_names) + del condition_components, component_names, component_modules + + latent_components = self.model_specification.load_latent_models() + if self.vae is not None: + if self.args.enable_slicing: + self.vae.enable_slicing() + if self.args.enable_tiling: + self.vae.enable_tiling() + component_names = list(latent_components.keys()) + component_modules = list(latent_components.values()) + self._set_components(latent_components) + self._move_components_to_device(component_modules) + precomputed_latent_iterator = consume_fn( + "latent", + components=latent_components, + data_iterator=data_iterator, + generator=self.state.generator, + use_cached_samples=True, + drop_samples=True, + ) + self._delete_components(component_names) + del latent_components, component_names, component_modules + + return precomputed_condition_iterator, precomputed_latent_iterator + + def _get_training_info(self) -> Dict[str, Any]: + info = self.args.to_dict() + + # Removing flow matching arguments when not using flow-matching objective + diffusion_args = info.get("diffusion_arguments", {}) + scheduler_name = self.scheduler.__class__.__name__ if self.scheduler is not None else "" + if scheduler_name != "FlowMatchEulerDiscreteScheduler": + filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k} + else: + filtered_diffusion_args = diffusion_args + + info.update({"diffusion_arguments": filtered_diffusion_args}) + return info diff --git a/finetrainers/typing.py b/finetrainers/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b3b339f252d8f47ef0ff67aa6c6733a2ccd7cf --- /dev/null +++ b/finetrainers/typing.py @@ -0,0 +1,11 @@ +from typing import Union + +from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler +from transformers import CLIPTokenizer, LlamaTokenizer, LlamaTokenizerFast, T5Tokenizer, T5TokenizerFast + +from .data import ImageArtifact, VideoArtifact + + +ArtifactType = Union[ImageArtifact, VideoArtifact] +SchedulerType = Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler] +TokenizerType = Union[CLIPTokenizer, T5Tokenizer, T5TokenizerFast, LlamaTokenizer, LlamaTokenizerFast] diff --git a/finetrainers/utils/__init__.py b/finetrainers/utils/__init__.py index d06bd9def1f5d4c242cafdc03cd2d31414c1f169..7339b296efedaf9c78f8491eaaf78a40f8fdaef1 100644 --- a/finetrainers/utils/__init__.py +++ b/finetrainers/utils/__init__.py @@ -1,4 +1,9 @@ -from .diffusion_utils import ( +import inspect +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +from .activation_checkpoint import apply_activation_checkpointing +from .data import determine_batch_size, should_perform_precomputation +from .diffusion import ( default_flow_shift, get_scheduler_alphas, get_scheduler_sigmas, @@ -7,7 +12,33 @@ from .diffusion_utils import ( prepare_target, 
resolution_dependent_timestep_flow_shift, ) -from .file_utils import delete_files, find_files -from .memory_utils import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous -from .optimizer_utils import get_optimizer, gradient_norm, max_gradient -from .torch_utils import unwrap_model +from .file import delete_files, find_files, string_to_filename +from .hub import save_model_card +from .memory import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous +from .model import resolve_component_cls +from .state_checkpoint import PTDCheckpointManager +from .torch import ( + align_device_and_dtype, + clip_grad_norm_, + enable_determinism, + expand_tensor_dims, + get_device_info, + set_requires_grad, + synchronize_device, + unwrap_model, +) + + +def get_parameter_names(obj: Any, method_name: Optional[str] = None) -> Set[str]: + if method_name is not None: + obj = getattr(obj, method_name) + return {name for name, _ in inspect.signature(obj).parameters.items()} + + +def get_non_null_items( + x: Union[List[Any], Tuple[Any], Dict[str, Any]] +) -> Union[List[Any], Tuple[Any], Dict[str, Any]]: + if isinstance(x, dict): + return {k: v for k, v in x.items() if v is not None} + if isinstance(x, (list, tuple)): + return type(x)(v for v in x if v is not None) diff --git a/finetrainers/utils/_common.py b/finetrainers/utils/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..c230e878d6fe715d696d8285c51f0ba073fd6b3e --- /dev/null +++ b/finetrainers/utils/_common.py @@ -0,0 +1,6 @@ +DIFFUSERS_TRANSFORMER_BLOCK_NAMES = [ + "transformer_blocks", + "single_transformer_blocks", + "temporal_transformer_blocks", + "blocks", +] diff --git a/finetrainers/utils/activation_checkpoint.py b/finetrainers/utils/activation_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4193a6cc027a771fe1fc2c3cb34595fbc336b2 --- /dev/null +++ b/finetrainers/utils/activation_checkpoint.py @@ -0,0 +1,71 @@ +import collections +from enum import Enum + +import torch +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper + +from ._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES + + +class CheckpointType(str, Enum): + FULL = "full" + OPS = "ops" + BLOCK_SKIP = "block_skip" + + +_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, +} + + +def apply_activation_checkpointing( + module: torch.nn.Module, checkpointing_type: str = CheckpointType.FULL, n_layer: int = 1 +) -> torch.nn.Module: + if checkpointing_type == CheckpointType.FULL: + module = _apply_activation_checkpointing_blocks(module) + elif checkpointing_type == CheckpointType.OPS: + module = _apply_activation_checkpointing_ops(module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS) + elif checkpointing_type == CheckpointType.BLOCK_SKIP: + module = _apply_activation_checkpointing_blocks(module, n_layer) + else: + raise ValueError( + f"Checkpointing type '{checkpointing_type}' not supported. 
Supported types are {CheckpointType.__members__.keys()}" + ) + return module + + +def _apply_activation_checkpointing_blocks(module: torch.nn.Module, n_layer: int = None) -> torch.nn.Module: + for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES: + blocks: torch.nn.Module = getattr(module, transformer_block_name, None) + if blocks is None: + continue + for index, (layer_id, block) in enumerate(blocks.named_children()): + if n_layer is None or index % n_layer == 0: + block = checkpoint_wrapper(block, preserve_rng_state=False) + blocks.register_module(layer_id, block) + return module + + +def _apply_activation_checkpointing_ops(module: torch.nn.Module, ops) -> torch.nn.Module: + from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in ops and not (func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0) + return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = collections.defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return checkpoint_wrapper(module, context_fn=selective_checkpointing_context_fn, preserve_rng_state=False) diff --git a/finetrainers/utils/checkpointing.py b/finetrainers/utils/checkpointing.py deleted file mode 100644 index 01dbea0a029144f53ff8587b887204515a7e3250..0000000000000000000000000000000000000000 --- a/finetrainers/utils/checkpointing.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -from typing import Tuple - -from accelerate.logging import get_logger - -from ..constants import FINETRAINERS_LOG_LEVEL -from ..utils.file_utils import delete_files, find_files - - -logger = get_logger("finetrainers") -logger.setLevel(FINETRAINERS_LOG_LEVEL) - - -def get_latest_ckpt_path_to_resume_from( - resume_from_checkpoint: str, num_update_steps_per_epoch: int, output_dir: str -) -> Tuple[str, int, int, int]: - if not resume_from_checkpoint: - initial_global_step = 0 - global_step = 0 - first_epoch = 0 - resume_from_checkpoint_path = None - else: - if resume_from_checkpoint != "latest": - path = os.path.basename(resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] if len(dirs) > 0 else None - - if path is None: - logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. 
Starting a new training run.") - resume_from_checkpoint = None - initial_global_step = 0 - global_step = 0 - first_epoch = 0 - resume_from_checkpoint_path = None - else: - logger.info(f"Resuming from checkpoint {path}") - resume_from_checkpoint_path = os.path.join(output_dir, path) - global_step = int(path.split("-")[1]) - - initial_global_step = global_step - first_epoch = global_step // num_update_steps_per_epoch - - return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch - - -def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str: - # before saving state, check if this save would set us over the `checkpointing_limit` - if checkpointing_limit is not None: - checkpoints = find_files(output_dir, prefix="checkpoint") - - # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints - if len(checkpoints) >= checkpointing_limit: - num_to_remove = len(checkpoints) - checkpointing_limit + 1 - checkpoints_to_remove = [os.path.join(output_dir, x) for x in checkpoints[0:num_to_remove]] - delete_files(checkpoints_to_remove) - - logger.info(f"Checkpointing at step {step}") - save_path = os.path.join(output_dir, f"checkpoint-{step}") - logger.info(f"Saving state to {save_path}") - return save_path diff --git a/finetrainers/utils/data_utils.py b/finetrainers/utils/data.py similarity index 76% rename from finetrainers/utils/data_utils.py rename to finetrainers/utils/data.py index 284dd1a9c9c1d91633dd0ae7e032a0cb507e8b2c..ae3fcc35262f7ec85d015c468983b033d61a154c 100644 --- a/finetrainers/utils/data_utils.py +++ b/finetrainers/utils/data.py @@ -1,6 +1,7 @@ from pathlib import Path -from typing import Union +from typing import Any, Union +import torch from accelerate.logging import get_logger from ..constants import PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME @@ -33,3 +34,18 @@ def should_perform_precomputation(precomputation_dir: Union[str, Path]) -> bool: return False logger.info("Precomputed data not found. 
Running precomputation.") return True + + +def determine_batch_size(x: Any) -> int: + if isinstance(x, list): + return len(x) + if isinstance(x, torch.Tensor): + return x.size(0) + if isinstance(x, dict): + for key in x: + try: + return determine_batch_size(x[key]) + except ValueError: + pass + return 1 + raise ValueError("Could not determine batch size from input.") diff --git a/finetrainers/utils/diffusion_utils.py b/finetrainers/utils/diffusion.py similarity index 100% rename from finetrainers/utils/diffusion_utils.py rename to finetrainers/utils/diffusion.py diff --git a/finetrainers/utils/file_utils.py b/finetrainers/utils/file.py similarity index 86% rename from finetrainers/utils/file_utils.py rename to finetrainers/utils/file.py index eb731771ba9bc7e07a273ca8947949dfd572b465..ba01213e758685aaa339f3ecad12c312a540dd9e 100644 --- a/finetrainers/utils/file_utils.py +++ b/finetrainers/utils/file.py @@ -1,12 +1,12 @@ -import logging import os import shutil from pathlib import Path from typing import List, Union +from ..logging import get_logger -logger = logging.getLogger("finetrainers") -logger.setLevel(os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO")) + +logger = get_logger() def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[str]: @@ -24,7 +24,7 @@ def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None: if not isinstance(dirs, list): dirs = [dirs] dirs = [Path(d) if isinstance(d, str) else d for d in dirs] - logger.info(f"Deleting files: {dirs}") + logger.debug(f"Deleting files: {dirs}") for dir in dirs: if not dir.exists(): continue diff --git a/finetrainers/utils/hub_utils.py b/finetrainers/utils/hub.py similarity index 82% rename from finetrainers/utils/hub_utils.py rename to finetrainers/utils/hub.py index ef865407ef9a1d41300c2e544d622a62b498989b..ea1a16eb42cbb1f2848376440817a3e1680ce61c 100644 --- a/finetrainers/utils/hub_utils.py +++ b/finetrainers/utils/hub.py @@ -28,20 +28,17 @@ def save_model_card( } ) - training_type = "Full" if args.training_type == "full-finetune" else "LoRA" model_description = f""" -# {training_type} Finetune +# LoRA Finetune <Gallery /> ## Model description -This is a {training_type.lower()} finetune of model: `{args.pretrained_model_name_or_path}`. +This is a lora finetune of model: `{args.pretrained_model_name_or_path}`. The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers). -`id_token` used: {args.id_token} (if it's not `None`, it should be used in the prompts.) - ## Download model [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab. @@ -56,7 +53,7 @@ TODO For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers. """ - if wandb.run and wandb.run.url: + if wandb.run.url: model_description += f""" Find out the wandb run URL and training configurations [here]({wandb.run.url}). """ @@ -72,13 +69,9 @@ Find out the wandb run URL and training configurations [here]({wandb.run.url}). 
"text-to-video", "diffusers-training", "diffusers", - "finetrainers", + "lora", "template:sd-lora", ] - if training_type == "Full": - tags.append("full-finetune") - else: - tags.append("lora") model_card = populate_model_card(model_card, tags=tags) model_card.save(os.path.join(args.output_dir, "README.md")) diff --git a/finetrainers/utils/import_utils.py b/finetrainers/utils/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56c19db6e4f032296bca103e38155da1ded1c074 --- /dev/null +++ b/finetrainers/utils/import_utils.py @@ -0,0 +1,20 @@ +import importlib + +import importlib_metadata + +from ..logging import get_logger + + +logger = get_logger() + + +_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None +try: + _bitsandbytes_version = importlib_metadata.version("bitsandbytes") + logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}") +except importlib_metadata.PackageNotFoundError: + _bitsandbytes_available = False + + +def is_bitsandbytes_available(): + return _bitsandbytes_available diff --git a/finetrainers/utils/memory_utils.py b/finetrainers/utils/memory.py similarity index 100% rename from finetrainers/utils/memory_utils.py rename to finetrainers/utils/memory.py diff --git a/finetrainers/utils/model.py b/finetrainers/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4427f97d25ed44b2d9832cf456b082f65d66c2a8 --- /dev/null +++ b/finetrainers/utils/model.py @@ -0,0 +1,32 @@ +import importlib +import json +import os +from typing import Optional + +from huggingface_hub import hf_hub_download + + +def resolve_component_cls( + pretrained_model_name_or_path: str, + component_name: str, + filename: str = "model_index.json", + revision: Optional[str] = None, + cache_dir: Optional[str] = None, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.exists(str(pretrained_model_name_or_path)) and os.path.isdir(pretrained_model_name_or_path): + index_path = os.path.join(pretrained_model_name_or_path, filename) + else: + index_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, filename=filename, revision=revision, cache_dir=cache_dir + ) + + with open(index_path, "r") as f: + model_index_dict = json.load(f) + + if component_name not in model_index_dict: + raise ValueError(f"No {component_name} found in the model index dict.") + + cls_config = model_index_dict[component_name] + library = importlib.import_module(cls_config[0]) + return getattr(library, cls_config[1]) diff --git a/finetrainers/utils/model_utils.py b/finetrainers/utils/model_utils.py deleted file mode 100644 index 1451ebff3d7c29e60d856cd41f56b358b044ffc9..0000000000000000000000000000000000000000 --- a/finetrainers/utils/model_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import importlib -import json -import os - -from huggingface_hub import hf_hub_download - - -def resolve_vae_cls_from_ckpt_path(ckpt_path, **kwargs): - ckpt_path = str(ckpt_path) - if os.path.exists(str(ckpt_path)) and os.path.isdir(ckpt_path): - index_path = os.path.join(ckpt_path, "model_index.json") - else: - revision = kwargs.get("revision", None) - cache_dir = kwargs.get("cache_dir", None) - index_path = hf_hub_download( - repo_id=ckpt_path, filename="model_index.json", revision=revision, cache_dir=cache_dir - ) - - with open(index_path, "r") as f: - model_index_dict = json.load(f) - assert "vae" in model_index_dict, "No VAE found in the modelx index dict." 
- - vae_cls_config = model_index_dict["vae"] - library = importlib.import_module(vae_cls_config[0]) - return getattr(library, vae_cls_config[1]) diff --git a/finetrainers/utils/optimizer_utils.py b/finetrainers/utils/optimizer_utils.py deleted file mode 100644 index 84c215b51e8e458ff3702fd65be223bce7a7aeb9..0000000000000000000000000000000000000000 --- a/finetrainers/utils/optimizer_utils.py +++ /dev/null @@ -1,178 +0,0 @@ -import inspect - -import torch -from accelerate.logging import get_logger - - -logger = get_logger("finetrainers") - - -def get_optimizer( - params_to_optimize, - optimizer_name: str = "adam", - learning_rate: float = 1e-3, - beta1: float = 0.9, - beta2: float = 0.95, - beta3: float = 0.98, - epsilon: float = 1e-8, - weight_decay: float = 1e-4, - prodigy_decouple: bool = False, - prodigy_use_bias_correction: bool = False, - prodigy_safeguard_warmup: bool = False, - use_8bit: bool = False, - use_4bit: bool = False, - use_torchao: bool = False, - use_deepspeed: bool = False, - use_cpu_offload_optimizer: bool = False, - offload_gradients: bool = False, -) -> torch.optim.Optimizer: - optimizer_name = optimizer_name.lower() - - # Use DeepSpeed optimzer - if use_deepspeed: - from accelerate.utils import DummyOptim - - return DummyOptim( - params_to_optimize, - lr=learning_rate, - betas=(beta1, beta2), - eps=epsilon, - weight_decay=weight_decay, - ) - - # TODO: consider moving the validation logic to `args.py` when we have torchao. - if use_8bit and use_4bit: - raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.") - - if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer: - try: - import torchao # noqa - - except ImportError: - raise ImportError( - "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`." - ) - - if not use_torchao and use_4bit: - raise ValueError("4-bit Optimizers are only supported with torchao.") - - # Optimizer creation - supported_optimizers = ["adam", "adamw", "prodigy", "came"] - if optimizer_name not in supported_optimizers: - logger.warning( - f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`." - ) - optimizer_name = "adamw" - - if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]: - raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.") - - if use_8bit: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
- ) - - if optimizer_name == "adamw": - if use_torchao: - from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit - - optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW - else: - optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW - - init_kwargs = { - "betas": (beta1, beta2), - "eps": epsilon, - "weight_decay": weight_decay, - } - - elif optimizer_name == "adam": - if use_torchao: - from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit - - optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam - else: - optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam - - init_kwargs = { - "betas": (beta1, beta2), - "eps": epsilon, - "weight_decay": weight_decay, - } - - elif optimizer_name == "prodigy": - try: - import prodigyopt - except ImportError: - raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") - - optimizer_class = prodigyopt.Prodigy - - if learning_rate <= 0.1: - logger.warning( - "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" - ) - - init_kwargs = { - "lr": learning_rate, - "betas": (beta1, beta2), - "beta3": beta3, - "eps": epsilon, - "weight_decay": weight_decay, - "decouple": prodigy_decouple, - "use_bias_correction": prodigy_use_bias_correction, - "safeguard_warmup": prodigy_safeguard_warmup, - } - - elif optimizer_name == "came": - try: - import came_pytorch - except ImportError: - raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`") - - optimizer_class = came_pytorch.CAME - - init_kwargs = { - "lr": learning_rate, - "eps": (1e-30, 1e-16), - "betas": (beta1, beta2, beta3), - "weight_decay": weight_decay, - } - - if use_cpu_offload_optimizer: - from torchao.prototype.low_bit_optim import CPUOffloadOptimizer - - if "fused" in inspect.signature(optimizer_class.__init__).parameters: - init_kwargs.update({"fused": True}) - - optimizer = CPUOffloadOptimizer( - params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs - ) - else: - optimizer = optimizer_class(params_to_optimize, **init_kwargs) - - return optimizer - - -def gradient_norm(parameters): - norm = 0 - for param in parameters: - if param.grad is None: - continue - local_norm = param.grad.detach().data.norm(2) - norm += local_norm.item() ** 2 - norm = norm**0.5 - return norm - - -def max_gradient(parameters): - max_grad_value = float("-inf") - for param in parameters: - if param.grad is None: - continue - local_max_grad = param.grad.detach().data.abs().max() - max_grad_value = max(max_grad_value, local_max_grad.item()) - return max_grad_value diff --git a/finetrainers/utils/state_checkpoint.py b/finetrainers/utils/state_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0e0b9b5f6214ba56d0c308714e29b9e11f4d8a --- /dev/null +++ b/finetrainers/utils/state_checkpoint.py @@ -0,0 +1,203 @@ +import functools +import pathlib +import shutil +import time +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union + +import torch +import torch.distributed.checkpoint +from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + set_model_state_dict, +) +from torch.distributed.checkpoint.stateful import Stateful + +from ..logging import get_logger + + +if TYPE_CHECKING: + from .. 
import optimizer + + +logger = get_logger() + + +class ModelWrapper(Stateful): + def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None: + self.model = [model] if isinstance(model, torch.nn.Module) else model + + def state_dict(self) -> Dict[str, Any]: + return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + func = functools.partial( + set_model_state_dict, + model_state_dict=state_dict, + options=StateDictOptions(strict=False), + ) + list(map(func, self.model)) + + +class PTDCheckpointManager: + def __init__( + self, + dataloader: torch.utils.data.DataLoader, + model_parts: List[torch.nn.Module], + optimizers: "optimizer.OptimizerWrapper", + schedulers: "optimizer.SchedulerWrapper", + states: Dict[str, Any], + checkpointing_steps: int, + checkpointing_limit: int, + output_dir: str, + enable: bool = True, + _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, + _prefix: str = "finetrainers_step", + ) -> None: + self.states = states + self.states.update( + { + "model": ModelWrapper(model_parts), + "optimizer": optimizers, + "dataloader": dataloader, + } + ) + self.states.update(schedulers.get_lr_scheduler_state()) + + self.checkpointing_steps = checkpointing_steps + self.checkpointing_limit = checkpointing_limit + self.output_dir = pathlib.Path(output_dir) + self.enable = enable + self._callback_fn = _callback_fn + self._prefix = _prefix + + logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'") + + def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str: + if not self._should_checkpoint(step, force): + return None + + checkpoint_dir = self._get_checkpoint_dir(step) + begin_time = time.monotonic() + torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix()) + end_time = time.monotonic() + logger.info( + f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. 
Directory: {checkpoint_dir}" + ) + self._purge_stale_checkpoints() + + state_dicts = [ + gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process) + for model in self.states["model"].model + ] + if self._callback_fn is not None: + list(map(self._callback_fn, state_dicts)) + + return checkpoint_dir.as_posix() + + def load(self, step: int = -1) -> bool: + if not self.enable: + return False + if not self.output_dir.exists(): + return False + if step != -1 and not self._get_checkpoint_dir(step).exists(): + return False + + if step == -1: + latest_checkpoint_dir = self._find_latest_checkpoint_dir() + if latest_checkpoint_dir is None: + return False + step = int(latest_checkpoint_dir.name.split("_")[-1]) + + checkpoint_dir = self._get_checkpoint_dir(step) + logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}") + + # For step 0, optimizers/schedulers are not available as they are created during training after first step + states = {"model": self.states["model"]} if step == 0 else self.states + + # See bug: https://github.com/pytorch/pytorch/pull/138575 + original_stateful_states = {k: v for k, v in states.items() if isinstance(v, Stateful)} + begin_time = time.monotonic() + torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix()) + end_time = time.monotonic() + logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.") + + # bugfix from above: restore the original stateful objects, whose states were already updated in-place by dcp.load() + states.update(original_stateful_states) + + return True + + def _should_checkpoint(self, step: int, force: bool) -> bool: + if not self.enable: + return False + + if not force: + if step % self.checkpointing_steps != 0: + return False + + return True + + def _get_checkpoint_dir(self, step: int) -> pathlib.Path: + return self.output_dir / f"{self._prefix}_{step}" + + def _find_latest_checkpoint_dir(self) -> Union[pathlib.Path, None]: + checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1])) + return checkpoints[-1] if len(checkpoints) > 0 else None + + def _purge_stale_checkpoints(self) -> None: + if self.checkpointing_limit is None or self.checkpointing_limit <= 0: + return + checkpoints = sorted( + self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True + ) + for checkpoint in checkpoints[self.checkpointing_limit :]: + logger.info(f"Deleting stale checkpoint: {checkpoint}") + shutil.rmtree(checkpoint, ignore_errors=True) + + +def gather_state_dict_on_cpu_rank0( + model, device: Optional[torch.device] = None, *, is_main_process: bool +) -> Dict[str, Any]: + cpu_state_dict = {} + sharded_sd = model.state_dict() + for param_name, param in sharded_sd.items(): + if param.is_cpu: + # Move back to device if offloaded to CPU + param = param.to(device) + if hasattr(param, "_local_tensor"): + # Gather DTensor + param = param.full_tensor() + if is_main_process: + cpu_state_dict[param_name] = param.cpu() + torch.distributed.barrier() + return cpu_state_dict + + +# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict +# def dcp_to_torch_save( +# dcp_checkpoint_dir: Union[str, os.PathLike], +# torch_save_path: Union[str, os.PathLike], +# callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, +# ): +# """ +# Given a directory containing a DCP checkpoint, this function will convert it into a +# Torch save file. 
+ +# Args: +# dcp_checkpoint_dir: Directory containing the DCP checkpoint. +# torch_save_path: Filename to store the converted Torch save file. +# callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict. + +# .. warning:: +# To avoid OOM, it's recommended to only run this function on a single rank. +# """ +# state_dict = {} +# _load_state_dict( +# state_dict, +# storage_reader=FileSystemReader(dcp_checkpoint_dir), +# planner=_EmptyStateDictLoadPlanner(), +# no_dist=True, +# ) +# if callback_fn is not None: +# state_dict = callback_fn(state_dict) +# torch.save(state_dict, torch_save_path) diff --git a/finetrainers/utils/torch.py b/finetrainers/utils/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..db434d3e361ea03e6c06811c98f08594c4ab773d --- /dev/null +++ b/finetrainers/utils/torch.py @@ -0,0 +1,338 @@ +import math +import os +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.backends +import torch.distributed as dist +import torch.distributed.tensor +from accelerate import Accelerator +from diffusers.utils.torch_utils import is_compiled_module + +from ..logging import get_logger + + +logger = get_logger() + +_STRING_TO_DTYPE = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +_DTYPE_TO_STRING = {v: k for k, v in _STRING_TO_DTYPE.items()} + +_HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = False + + +def align_device_and_dtype( + x: Union[torch.Tensor, Dict[str, torch.Tensor]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + if isinstance(x, torch.Tensor): + if device is not None: + x = x.to(device) + if dtype is not None: + x = x.to(dtype) + elif isinstance(x, dict): + if device is not None: + x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} + if dtype is not None: + x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} + return x + + +def _clip_grad_norm_while_handling_failing_dtensor_cases( + parameters: Union[torch.Tensor, List[torch.Tensor]], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, + pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, +) -> Optional[torch.Tensor]: + global _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES + + if not _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES: + try: + return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach, pp_mesh) + except NotImplementedError as e: + if "DTensor does not support cross-mesh operation" in str(e): + # https://github.com/pytorch/pytorch/issues/134212 + logger.warning( + "DTensor does not support cross-mesh operation. If you haven't fully tensor-parallelized your " + "model, while combining other parallelisms such as FSDP, it could be the reason for this error. " + "Gradient clipping will be skipped and gradient norm will not be logged." + ) + except Exception as e: + logger.warning( + f"An error occurred while clipping gradients: {e}. Gradient clipping will be skipped and gradient " + f"norm will not be logged." 
+ ) + _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = True + return None + + +# Copied from https://github.com/pytorch/torchtitan/blob/4a169701555ab9bd6ca3769f9650ae3386b84c6e/torchtitan/utils.py#L362 +@torch.no_grad() +def clip_grad_norm_( + parameters: Union[torch.Tensor, List[torch.Tensor]], + max_norm: float, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, + pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None, +) -> torch.Tensor: + r""" + Clip the gradient norm of parameters. + + Gradient norm clipping requires computing the gradient norm over the entire model. + `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions. + We need to manually reduce the gradient norm across PP stages. + See https://github.com/pytorch/torchtitan/issues/596 for details. + + Args: + parameters (`torch.Tensor` or `List[torch.Tensor]`): + Tensors that will have gradients normalized. + max_norm (`float`): + Maximum norm of the gradients after clipping. + norm_type (`float`, defaults to `2.0`): + Type of p-norm to use. Can be `inf` for infinity norm. + error_if_nonfinite (`bool`, defaults to `False`): + If `True`, an error is thrown if the total norm of the gradients from `parameters` is `nan`, `inf`, or `-inf`. + foreach (`bool`, defaults to `None`): + Use the faster foreach-based implementation. If `None`, use the foreach implementation for CUDA and CPU native tensors + and silently fall back to the slow implementation for other device types. + pp_mesh (`torch.distributed.device_mesh.DeviceMesh`, defaults to `None`): + Pipeline parallel device mesh. If not `None`, will reduce gradient norm across PP stages. + + Returns: + `torch.Tensor`: + Total norm of the gradients + """ + grads = [p.grad for p in parameters if p.grad is not None] + + # TODO(aryan): Wait for next Pytorch release to use `torch.nn.utils.get_total_norm` + # total_norm = torch.nn.utils.get_total_norm(grads, norm_type, error_if_nonfinite, foreach) + total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) + + # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`. + # We can simply reduce the DTensor to get the total norm in this tensor's process group + # and then convert it to a local tensor. + # It has two purposes: + # 1. to make sure the total norm is computed correctly when PP is used (see below) + # 2. to return a reduced total_norm tensor whose .item() would return the correct value + if isinstance(total_norm, torch.distributed.tensor.DTensor): + # Will reach here if any non-PP parallelism is used. + # If only using PP, total_norm will be a local tensor. + total_norm = total_norm.full_tensor() + + if pp_mesh is not None: + if math.isinf(norm_type): + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) + else: + total_norm **= norm_type + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) + total_norm **= 1.0 / norm_type + + _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) + return total_norm + + +def enable_determinism( + seed: int, + world_mesh: Optional[torch.distributed.DeviceMesh] = None, + deterministic: bool = False, +) -> None: + r""" + For all ranks within the same DTensor SPMD group, the same seed will be set. + For PP groups, different seeds will be set. 
+ """ + if deterministic: + logger.info("Deterministic algorithms are enabled (expect performance degradation).") + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + if not world_mesh: + if seed is not None: + torch.manual_seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed % 2**32) + logger.debug(f"Single-process job using seed: {seed}") + return + + # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh, + # and choose a unique seed for each rank on the PP mesh. + if torch.distributed.distributed_c10d.get_world_size() > 1 and "pp" in world_mesh.mesh_dim_names: + pp_mesh = world_mesh["pp"] + seed += pp_mesh.get_local_rank() + seed %= 2**64 + + info = { + "pp_rank": pp_mesh.get_local_rank(), + "global_rank": torch.distributed.distributed_c10d.get_rank(), + "seed": seed, + } + logger.debug(f"Enabling determinism: {info}") + spmd_mesh_dims = list(filter(lambda name: name != "pp", world_mesh.mesh_dim_names)) + spmd_mesh = world_mesh[spmd_mesh_dims] if len(spmd_mesh_dims) else None + else: + spmd_mesh = world_mesh + info = {"global_rank": torch.distributed.distributed_c10d.get_rank(), "seed": seed} + logger.debug(f"Enabling determinism: {info}") + + # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency + torch.manual_seed(seed) + # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1] + os.environ["PYTHONHASHSEED"] = str(seed % 2**32) + + # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh. + # IF PP is also used, this seed is unique per PP rank. 
+ if spmd_mesh and spmd_mesh.get_coordinate() is not None: + torch.distributed.tensor._random.manual_seed(seed, spmd_mesh) + + +def expand_tensor_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor: + assert len(tensor.shape) <= ndim + return tensor.reshape(tensor.shape + (1,) * (ndim - len(tensor.shape))) + + +def get_device_info(): + from torch._utils import _get_available_device_type, _get_device_module + + device_type = _get_available_device_type() + if device_type is None: + device_type = "cuda" + device_module = _get_device_module(device_type) + return device_type, device_module + + +def get_dtype_from_string(dtype: str): + return _STRING_TO_DTYPE[dtype] + + +def get_string_from_dtype(dtype: torch.dtype): + return _DTYPE_TO_STRING[dtype] + + +def set_requires_grad(models: Union[torch.nn.Module, List[torch.nn.Module]], value: bool) -> None: + if isinstance(models, torch.nn.Module): + models = [models] + for model in models: + if model is not None: + model.requires_grad_(value) + + +def synchronize_device() -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + elif torch.backends.mps.is_available(): + torch.mps.synchronize() + + +def unwrap_model(accelerator: Accelerator, model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + +# TODO(aryan): remove everything below this after next torch release +def _get_total_norm( + tensors: Union[torch.Tensor, List[torch.Tensor]], + norm_type: float = 2.0, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, +) -> torch.Tensor: + if isinstance(tensors, torch.Tensor): + tensors = [tensors] + else: + tensors = list(tensors) + norm_type = float(norm_type) + if len(tensors) == 0: + return torch.tensor(0.0) + first_device = tensors[0].device + grouped_tensors: dict[ + tuple[torch.device, torch.dtype], tuple[list[list[torch.Tensor]], list[int]] + ] = _group_tensors_by_device_and_dtype( + [tensors] # type: ignore[list-item] + ) # type: ignore[assignment] + + norms: List[torch.Tensor] = [] + for (device, _), ([device_tensors], _) in grouped_tensors.items(): + if (foreach is None and _has_foreach_support(device_tensors, device)) or ( + foreach and _device_has_foreach_support(device) + ): + norms.extend(torch._foreach_norm(device_tensors, norm_type)) + elif foreach: + raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors") + else: + norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_tensors]) + + total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) + + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f"The total norm of order {norm_type} for gradients from " + "`parameters` is non-finite, so it cannot be clipped. 
To disable " + "this error and scale the gradients by the non-finite norm anyway, " + "set `error_if_nonfinite=False`" + ) + return total_norm + + +@torch.no_grad() +def _clip_grads_with_norm_( + parameters: Union[torch.Tensor, List[torch.Tensor]], + max_norm: float, + total_norm: torch.Tensor, + foreach: Optional[bool] = None, +) -> None: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + if len(grads) == 0: + return + grouped_grads: dict[ + Tuple[torch.device, torch.dtype], Tuple[List[List[torch.Tensor]], List[int]] + ] = _group_tensors_by_device_and_dtype([grads]) # type: ignore[assignment] + + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for (device, _), ([device_grads], _) in grouped_grads.items(): + if (foreach is None and _has_foreach_support(device_grads, device)) or ( + foreach and _device_has_foreach_support(device) + ): + torch._foreach_mul_(device_grads, clip_coef_clamped.to(device)) + elif foreach: + raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors") + else: + clip_coef_clamped_device = clip_coef_clamped.to(device) + for g in device_grads: + g.mul_(clip_coef_clamped_device) + + +def _get_foreach_kernels_supported_devices() -> list[str]: + r"""Return the device type list that supports foreach kernels.""" + return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()] + + +@torch.no_grad() +def _group_tensors_by_device_and_dtype( + tensorlistlist: List[List[Optional[torch.Tensor]]], + with_indices: bool = False, +) -> dict[tuple[torch.device, torch.dtype], tuple[List[List[Optional[torch.Tensor]]], List[int]]]: + return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices) + + +def _device_has_foreach_support(device: torch.device) -> bool: + return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting() + + +def _has_foreach_support(tensors: List[torch.Tensor], device: torch.device) -> bool: + return _device_has_foreach_support(device) and all(t is None or type(t) in [torch.Tensor] for t in tensors) diff --git a/finetrainers/utils/torch_utils.py b/finetrainers/utils/torch_utils.py deleted file mode 100644 index 1c6ef5df9ea5a832e671cbaaff008cd6f48078b0..0000000000000000000000000000000000000000 --- a/finetrainers/utils/torch_utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Dict, Optional, Union - -import torch -from accelerate import Accelerator -from diffusers.utils.torch_utils import is_compiled_module - - -def unwrap_model(accelerator: Accelerator, model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - -def align_device_and_dtype( - x: Union[torch.Tensor, Dict[str, torch.Tensor]], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, -): - if isinstance(x, torch.Tensor): - if device is not None: - x = x.to(device) - if dtype is not None: - x = x.to(dtype) - elif isinstance(x, dict): - if device is not None: - x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()} - if dtype is not None: - x = {k: align_device_and_dtype(v, 
device, dtype) for k, v in x.items()} - return x - - -def expand_tensor_dims(tensor, ndim): - while len(tensor.shape) < ndim: - tensor = tensor.unsqueeze(-1) - return tensor diff --git a/train.py b/train.py index 088c0617ae0c37fbc7ef5751fddc2d484f36d1a4..3260752aa563851527c867518c49fa7b1698d3e6 100644 --- a/train.py +++ b/train.py @@ -1,12 +1,12 @@ -import logging +import sys import traceback -from finetrainers import Trainer, parse_arguments -from finetrainers.constants import FINETRAINERS_LOG_LEVEL +from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger +from finetrainers.config import _get_model_specifiction_cls +from finetrainers.trainer.sft_trainer.config import SFTFullRankConfig, SFTLowRankConfig -logger = logging.getLogger("finetrainers") -logger.setLevel(FINETRAINERS_LOG_LEVEL) +logger = get_logger() def main(): @@ -22,18 +22,52 @@ def main(): ) try: - args = parse_arguments() - trainer = Trainer(args) - - trainer.prepare_dataset() - trainer.prepare_models() - trainer.prepare_precomputations() - trainer.prepare_trainable_parameters() - trainer.prepare_optimizer() - trainer.prepare_for_training() - trainer.prepare_trackers() - trainer.train() - # trainer.evaluate() + args = BaseArgs() + + argv = [y.strip() for x in sys.argv for y in x.split()] + training_type_index = argv.index("--training_type") + if training_type_index == -1: + raise ValueError("Training type not provided in command line arguments.") + + training_type = argv[training_type_index + 1] + training_cls = None + if training_type == TrainingType.LORA: + training_cls = SFTLowRankConfig + elif training_type == TrainingType.FULL_FINETUNE: + training_cls = SFTFullRankConfig + else: + raise ValueError(f"Training type {training_type} not supported.") + + training_config = training_cls() + args.extend_args(training_config.add_args, training_config.map_args, training_config.validate_args) + args = args.parse_args() + + model_specification_cls = _get_model_specifiction_cls(args.model_name, args.training_type) + model_specification = model_specification_cls( + pretrained_model_name_or_path=args.pretrained_model_name_or_path, + tokenizer_id=args.tokenizer_id, + tokenizer_2_id=args.tokenizer_2_id, + tokenizer_3_id=args.tokenizer_3_id, + text_encoder_id=args.text_encoder_id, + text_encoder_2_id=args.text_encoder_2_id, + text_encoder_3_id=args.text_encoder_3_id, + transformer_id=args.transformer_id, + vae_id=args.vae_id, + text_encoder_dtype=args.text_encoder_dtype, + text_encoder_2_dtype=args.text_encoder_2_dtype, + text_encoder_3_dtype=args.text_encoder_3_dtype, + transformer_dtype=args.transformer_dtype, + vae_dtype=args.vae_dtype, + revision=args.revision, + cache_dir=args.cache_dir, + ) + + if args.training_type in [TrainingType.LORA, TrainingType.FULL_FINETUNE]: + trainer = SFTTrainer(args, model_specification) + else: + raise ValueError(f"Training type {args.training_type} not supported.") + + trainer.run() except KeyboardInterrupt: logger.info("Received keyboard interrupt. Exiting...")
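Two of the helpers introduced earlier in this patch, `get_non_null_items` (finetrainers/utils/__init__.py) and `determine_batch_size` (finetrainers/utils/data.py), can be sanity-checked in isolation; a short runnable sketch:

import torch
from finetrainers.utils import determine_batch_size, get_non_null_items

components = {"text_encoder": object(), "text_encoder_2": None, "vae": object()}
print(sorted(get_non_null_items(components)))  # ['text_encoder', 'vae'] -- None values are dropped

batch = {"latents": torch.randn(4, 16), "prompts": ["a prompt"] * 4}
print(determine_batch_size(batch))  # 4 -- inferred from the first value whose batch size is known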