VideoModelStudio
/
docs
/finetrainers-src-codebase
/examples
/_legacy
/training
/cogvideox
/utils.py
| import gc | |
| import inspect | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| from accelerate import Accelerator | |
| from accelerate.logging import get_logger | |
| from diffusers.models.embeddings import get_3d_rotary_pos_embed | |
| from diffusers.utils.torch_utils import is_compiled_module | |
# Module-level logger from accelerate's `get_logger`; used below for optimizer-selection warnings.
logger = get_logger(__name__)
def _get_adam_optimizer_class(use_adamw: bool, use_8bit: bool, use_4bit: bool, use_torchao: bool):
    """Return the Adam or AdamW class matching the requested precision and backend.

    Full-precision requests always resolve to ``torch.optim.Adam`` / ``torch.optim.AdamW``, so torchao and
    bitsandbytes are only imported when a low-bit optimizer is actually requested.

    Raises:
        ImportError: If an 8-bit optimizer is requested without torchao and bitsandbytes is not installed.
    """
    if not (use_8bit or use_4bit):
        return torch.optim.AdamW if use_adamw else torch.optim.Adam

    if use_torchao:
        from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit, AdamW4bit, AdamW8bit

        if use_adamw:
            return AdamW8bit if use_8bit else AdamW4bit
        return Adam8bit if use_8bit else Adam4bit

    # Only 8-bit can get here: the caller rejects 4-bit requests that do not use torchao.
    try:
        import bitsandbytes as bnb
    except ImportError as e:
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
        ) from e
    return bnb.optim.AdamW8bit if use_adamw else bnb.optim.Adam8bit


def get_optimizer(
    params_to_optimize,
    optimizer_name: str = "adam",
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.98,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    prodigy_decouple: bool = False,
    prodigy_use_bias_correction: bool = False,
    prodigy_safeguard_warmup: bool = False,
    use_8bit: bool = False,
    use_4bit: bool = False,
    use_torchao: bool = False,
    use_deepspeed: bool = False,
    use_cpu_offload_optimizer: bool = False,
    offload_gradients: bool = False,
) -> torch.optim.Optimizer:
    """Build the training optimizer from a name and a set of flags.

    Args:
        params_to_optimize: Parameters or parameter-group dicts, passed straight to the optimizer
            constructor. With torch optimizers, a per-group ``"lr"`` entry overrides ``learning_rate``.
        optimizer_name: ``"adam"``, ``"adamw"``, ``"prodigy"`` or ``"came"`` (case-insensitive). Any other
            name logs a warning and falls back to AdamW.
        learning_rate: Default learning rate for the optimizer.
        beta1: First beta coefficient.
        beta2: Second beta coefficient.
        beta3: Third beta coefficient; only used by Prodigy and CAME.
        epsilon: Numerical-stability term. CAME ignores it and uses the fixed pair ``(1e-30, 1e-16)``.
        weight_decay: Weight decay coefficient.
        prodigy_decouple: Forwarded to Prodigy as ``decouple``.
        prodigy_use_bias_correction: Forwarded to Prodigy as ``use_bias_correction``.
        prodigy_safeguard_warmup: Forwarded to Prodigy as ``safeguard_warmup``.
        use_8bit: Use an 8-bit Adam/AdamW (from torchao if ``use_torchao``, else bitsandbytes).
        use_4bit: Use a 4-bit Adam/AdamW; requires ``use_torchao``. Mutually exclusive with ``use_8bit``.
        use_torchao: Take low-bit optimizers from torchao instead of bitsandbytes.
        use_deepspeed: Return accelerate's ``DummyOptim`` placeholder instead of a real optimizer.
        use_cpu_offload_optimizer: Wrap the optimizer in torchao's ``CPUOffloadOptimizer``.
        offload_gradients: Forwarded to ``CPUOffloadOptimizer``.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: On incompatible flag combinations (8-bit with 4-bit, 4-bit without torchao, or low-bit
            with a non-Adam optimizer).
        ImportError: If a required optional package (torchao, bitsandbytes, prodigyopt, came-pytorch) is missing.
    """
    optimizer_name = optimizer_name.lower()

    # Use DeepSpeed optimizer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=learning_rate,
            betas=(beta1, beta2),
            eps=epsilon,
            weight_decay=weight_decay,
        )

    if use_8bit and use_4bit:
        raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")

    # Fail early with an actionable message if torchao will be needed but is missing.
    if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
        try:
            import torchao  # noqa: F401  (availability check only)
        except ImportError as e:
            raise ImportError(
                "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
            ) from e

    if not use_torchao and use_4bit:
        raise ValueError("4-bit Optimizers are only supported with torchao.")

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy", "came"]
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
        )
        optimizer_name = "adamw"

    if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
        raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")

    if optimizer_name in ("adam", "adamw"):
        optimizer_class = _get_adam_optimizer_class(
            use_adamw=optimizer_name == "adamw",
            use_8bit=use_8bit,
            use_4bit=use_4bit,
            use_torchao=use_torchao,
        )
        # Pass `lr` explicitly so `learning_rate` applies even when the param groups carry no "lr" of their own;
        # per-group values still take precedence.
        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif optimizer_name == "prodigy":
        try:
            import prodigyopt
        except ImportError as e:
            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") from e

        optimizer_class = prodigyopt.Prodigy

        if learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
            )

        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "beta3": beta3,
            "eps": epsilon,
            "weight_decay": weight_decay,
            "decouple": prodigy_decouple,
            "use_bias_correction": prodigy_use_bias_correction,
            "safeguard_warmup": prodigy_safeguard_warmup,
        }
    else:  # "came": the only remaining supported name at this point
        try:
            import came_pytorch
        except ImportError as e:
            raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`") from e

        optimizer_class = came_pytorch.CAME

        init_kwargs = {
            "lr": learning_rate,
            # CAME takes a pair of epsilons; these fixed values are used instead of `epsilon`.
            "eps": (1e-30, 1e-16),
            "betas": (beta1, beta2, beta3),
            "weight_decay": weight_decay,
        }

    if use_cpu_offload_optimizer:
        from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

        # Request the fused implementation only when the wrapped optimizer class accepts it.
        if "fused" in inspect.signature(optimizer_class.__init__).parameters:
            init_kwargs["fused"] = True

        return CPUOffloadOptimizer(
            params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
        )

    return optimizer_class(params_to_optimize, **init_kwargs)
def get_gradient_norm(parameters):
    """Return the global L2 norm of the gradients in ``parameters`` as a Python float.

    Args:
        parameters: Iterable of parameters; a generator such as ``model.parameters()`` is fine (it is consumed
            once). Entries whose ``.grad`` is ``None`` are skipped.

    Returns:
        ``sqrt(sum(||grad_i||_2 ** 2))`` over every parameter that has a gradient, or ``0.0`` if none do.

    Each per-parameter norm is computed on its own device and cast to float32. The norms are then gathered into
    one tensor and copied to the host in a single transfer, instead of calling ``.item()`` (a host sync) once per
    parameter. The squares are summed in Python floats (double precision).
    """
    local_norms = [param.grad.detach().norm(2).float() for param in parameters if param.grad is not None]
    if not local_norms:
        return 0.0

    # Parameters may live on different devices; bring the scalar norms onto one device before stacking.
    gather_device = local_norms[0].device
    stacked = torch.stack([local_norm.to(gather_device) for local_norm in local_norms])

    total = 0.0
    for value in stacked.tolist():  # single device-to-host copy
        total += value**2
    return total**0.5
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Fit a ``(height, width)`` grid inside a target box, keeping its aspect ratio, and centre it.

    Args:
        src: ``(height, width)`` of the source grid.
        tgt_width: Width of the target box.
        tgt_height: Height of the target box.

    Returns:
        ``((top, left), (bottom, right))`` corners of the centred region inside the target box.
        Sizes and offsets are rounded with Python's ``round``.
    """
    src_h, src_w = src

    if src_h / src_w > tgt_height / tgt_width:
        # Source is relatively taller than the target: height is the limiting side.
        new_h = tgt_height
        new_w = int(round(tgt_height / src_h * src_w))
    else:
        # Otherwise width is the limiting side.
        new_w = tgt_width
        new_h = int(round(tgt_width / src_w * src_h))

    top = int(round((tgt_height - new_h) / 2.0))
    left = int(round((tgt_width - new_w) / 2.0))
    return (top, left), (top + new_h, left + new_w)
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    patch_size_t: Optional[int] = None,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the ``(cos, sin)`` 3D rotary position embedding tables for a CogVideoX transformer.

    Args:
        height: Input height; divided by ``vae_scale_factor_spatial * patch_size`` to get the grid height.
        width: Input width; divided by ``vae_scale_factor_spatial * patch_size`` to get the grid width.
        num_frames: Temporal size. It is passed as-is on the CogVideoX 1.0 path, and ceil-divided by
            ``patch_size_t`` on the 1.5 path. Presumably a latent frame count; confirm against callers.
        vae_scale_factor_spatial: Spatial scale factor (of the VAE, per the name).
        patch_size: Spatial patch size.
        patch_size_t: Temporal patch size. ``None`` selects the CogVideoX 1.0 path, which uses a resize/crop
            region on the base grid. An integer selects the CogVideoX 1.5 path (``grid_type="slice"``).
        attention_head_dim: Passed to ``get_3d_rotary_pos_embed`` as ``embed_dim``.
        device: Device the returned tensors are moved to; ``None`` leaves them where they were created.
        base_height: Reference height that defines the base grid.
        base_width: Reference width that defines the base grid.

    Returns:
        ``(freqs_cos, freqs_sin)`` as returned by ``get_3d_rotary_pos_embed``, moved to ``device``.
    """
    # Input size -> number of patch-grid cells.
    spatial_downscale = vae_scale_factor_spatial * patch_size
    grid_height = height // spatial_downscale
    grid_width = width // spatial_downscale
    base_size_width = base_width // spatial_downscale
    base_size_height = base_height // spatial_downscale

    if patch_size_t is None:
        # CogVideoX 1.0
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )
    else:
        # CogVideoX 1.5: the temporal axis is patchified too (ceil division).
        base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(base_size_height, base_size_width),
        )

    freqs_cos = freqs_cos.to(device=device)
    freqs_sin = freqs_sin.to(device=device)
    return freqs_cos, freqs_sin
def reset_memory(device: Union[str, torch.device]) -> None:
    """Run garbage collection and reset the CUDA memory statistics for ``device``.

    Steps:
        - Always run ``gc.collect()``.
        - If CUDA is available, empty the CUDA caching allocator.
        - If ``device`` is a CUDA device (or ``None``), reset its peak and accumulated memory counters.

    The stats resets are skipped for non-CUDA devices (e.g. ``"cpu"``) and on machines without CUDA.
    In those cases ``torch.cuda.reset_*_memory_stats`` would raise rather than do anything useful.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    if device is not None and torch.device(device).type != "cuda":
        return

    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)
def print_memory(device: Union[str, torch.device]) -> None:
    """Print allocated, peak-allocated and peak-reserved CUDA memory for ``device`` in GiB.

    The values are ``torch.cuda.memory_allocated``, ``max_memory_allocated`` and ``max_memory_reserved``,
    each divided by ``1024**3``.

    CUDA memory statistics only exist for CUDA devices. For a non-CUDA ``device``, or when CUDA is unavailable,
    a short notice is printed instead of raising.
    """
    if not torch.cuda.is_available() or (device is not None and torch.device(device).type != "cuda"):
        print(f"CUDA memory statistics are unavailable for device {device!r}.")
        return

    bytes_per_gib = 1024**3
    memory_allocated = torch.cuda.memory_allocated(device) / bytes_per_gib
    max_memory_allocated = torch.cuda.max_memory_allocated(device) / bytes_per_gib
    max_memory_reserved = torch.cuda.max_memory_reserved(device) / bytes_per_gib
    print(f"{memory_allocated=:.3f} GB")
    print(f"{max_memory_allocated=:.3f} GB")
    print(f"{max_memory_reserved=:.3f} GB")
def unwrap_model(accelerator: Accelerator, model):
    """Return the bare module underneath accelerate's wrappers and any compiled-module wrapper.

    ``model`` is first passed through ``accelerator.unwrap_model``. If diffusers' ``is_compiled_module``
    reports the result as compiled, its ``_orig_mod`` is returned; otherwise the unwrapped model itself is.
    """
    unwrapped = accelerator.unwrap_model(model)
    if not is_compiled_module(unwrapped):
        return unwrapped
    return unwrapped._orig_mod