diff --git a/accelerate_configs/uncompiled_4.yaml b/accelerate_configs/uncompiled_4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e15d4c6cd56145c4653e97b8cbdd823b154b6207
--- /dev/null
+++ b/accelerate_configs/uncompiled_4.yaml
@@ -0,0 +1,17 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+gpu_ids: 0,1,2,3
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/finetrainers/__init__.py b/finetrainers/__init__.py
index 412e298eb519f037e08dc755f92d136cfe2ef2e6..7da2391e864af71edf8b826d1f1263d5c8f1afe5 100644
--- a/finetrainers/__init__.py
+++ b/finetrainers/__init__.py
@@ -1,2 +1,5 @@
-from .args import Args, parse_arguments
-from .trainer import Trainer
+from .args import BaseArgs
+from .config import ModelType, TrainingType
+from .logging import get_logger
+from .models import ModelSpecification
+from .trainer import SFTTrainer
diff --git a/finetrainers/args.py b/finetrainers/args.py
index 46cd04cca1c0d7368be8395ce3382bc28fc6865b..199c7493ab0fcba1724b9e6eb7dbc70ca237ed64 100644
--- a/finetrainers/args.py
+++ b/finetrainers/args.py
@@ -1,14 +1,21 @@
 import argparse
+import os
+import pathlib
 import sys
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 
-from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS
-from .models import SUPPORTED_MODEL_CONFIGS
+from .config import SUPPORTED_MODEL_CONFIGS, ModelType, TrainingType
+from .logging import get_logger
+from .parallel import ParallelBackendEnum
+from .utils import get_non_null_items
 
 
-class Args:
+logger = get_logger()
+
+
+class BaseArgs:
     r"""
     The arguments for the finetrainers training script.
 
@@ -19,6 +26,19 @@ class Args:
     TODO(aryan): add `python train.py --memory_requirements --model_name <model_name>` to show
     memory requirements per model, per training type with sensible training settings.
 
+    PARALLEL ARGUMENTS
+    ------------------
+    parallel_backend (`str`, defaults to `accelerate`):
+        The parallel backend to use for training. Choose between ['accelerate', 'ptd'].
+    pp_degree (`int`, defaults to `1`):
+        The degree of pipeline parallelism.
+    dp_degree (`int`, defaults to `1`):
+        The degree of data parallelism (number of model replicas).
+    dp_shards (`int`, defaults to `1`):
+        The number of data parallel shards (number of model partitions).
+    cp_degree (`int`, defaults to `1`):
+        The degree of context parallelism.
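+    tp_degree (`int`, defaults to `1`):
+        The degree of tensor parallelism.
+
+    Note (an assumption for illustration, not enforced by these arguments): the product of the parallel degrees is
+    usually expected to match the total number of processes. For example, `dp_degree=2` with `dp_shards=4` would run
+    two model replicas across 8 processes, with each replica sharded over 4 devices.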
+
     MODEL ARGUMENTS
     ---------------
     model_name (`str`):
@@ -33,6 +53,22 @@ class Args:
         storage requirements.
     cache_dir (`str`, defaults to `None`):
         The directory where the downloaded models and datasets will be stored, or loaded from.
+    tokenizer_id (`str`, defaults to `None`):
+        Identifier for the tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
+    tokenizer_2_id (`str`, defaults to `None`):
+        Identifier for the second tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
+    tokenizer_3_id (`str`, defaults to `None`):
+        Identifier for the third tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
+    text_encoder_id (`str`, defaults to `None`):
+        Identifier for the text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
+    text_encoder_2_id (`str`, defaults to `None`):
+        Identifier for the second text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
+    text_encoder_3_id (`str`, defaults to `None`):
+        Identifier for the third text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
+    transformer_id (`str`, defaults to `None`):
+        Identifier for the transformer model. This is useful when using a different transformer model than the default from `pretrained_model_name_or_path`.
+    vae_id (`str`, defaults to `None`):
+        Identifier for the VAE model. This is useful when using a different VAE model than the default from `pretrained_model_name_or_path`.
     text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
         Data type for the text encoder when generating text embeddings.
     text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
@@ -54,41 +90,47 @@ class Args:
 
     DATASET ARGUMENTS
     -----------------
-    data_root (`str`):
-        A folder containing the training data.
-    dataset_file (`str`, defaults to `None`):
-        Path to a CSV/JSON/JSONL file containing metadata for training. This should be provided if you're not using
-        a directory dataset format containing a simple `prompts.txt` and `videos.txt`/`images.txt` for example.
-    video_column (`str`):
-        The column of the dataset containing videos. Or, the name of the file in `data_root` folder containing the
-        line-separated path to video data.
-    caption_column (`str`):
-        The column of the dataset containing the instance prompt for each video. Or, the name of the file in
-        `data_root` folder containing the line-separated instance prompts.
-    id_token (`str`, defaults to `None`):
-        Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training.
-    image_resolution_buckets (`List[Tuple[int, int]]`, defaults to `None`):
-        Resolution buckets for images. This should be a list of integer tuples, where each tuple represents the
-        resolution (height, width) of the image. All images will be resized to the nearest bucket resolution.
-    video_resolution_buckets (`List[Tuple[int, int, int]]`, defaults to `None`):
-        Resolution buckets for videos. This should be a list of integer tuples, where each tuple represents the
-        resolution (num_frames, height, width) of the video. All videos will be resized to the nearest bucket
-        resolution.
-    video_reshape_mode (`str`, defaults to `None`):
-        All input videos are reshaped to this mode. Choose between ['center', 'random', 'none'].
-        TODO(aryan): We don't support this.
-    caption_dropout_p (`float`, defaults to `0.00`):
-        Probability of dropout for the caption tokens. This is useful to improve the unconditional generation
-        quality of the model.
-    caption_dropout_technique (`str`, defaults to `empty`):
-        Technique to use for caption dropout. Choose between ['empty', 'zero']. Some models apply caption dropout
-        by setting the prompt condition to an empty string, while others zero-out the text embedding tensors.
-    precompute_conditions (`bool`, defaults to `False`):
-        Whether or not to precompute the conditionings for the model. This is useful for faster training, and
-        reduces the memory requirements.
-    remove_common_llm_caption_prefixes (`bool`, defaults to `False`):
-        Whether or not to remove common LLM caption prefixes. This is useful for improving the quality of the
-        generated text.
+    dataset_config (`str`):
+        Path to a dataset configuration file containing information about training data. The file can describe one or
+        more datasets in JSON format. The file must have a key called "datasets", which is a list of dictionaries. Each
+        dictionary must contain the following keys:
+            - "data_root": (`str`)
+                The root directory containing the dataset. This parameter must be provided if `dataset_file` is not provided.
+            - "dataset_file": (`str`)
+                Path to a CSV/JSON/JSONL/PARQUET/ARROW/HF_HUB_DATASET file containing metadata for training. This parameter
+                must be provided if `data_root` is not provided.
+            - "dataset_type": (`str`)
+                Type of dataset. Choose between ['image', 'video'].
+            - "id_token": (`str`)
+                Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training
+                on a single subject/concept/style, but is not necessary.
+            - "image_resolution_buckets": (`List[Tuple[int, int]]`)
+                Resolution buckets for image. This should be a list of tuples containing 2 values, where each tuple
+                represents the resolution (height, width). All images will be resized to the nearest bucket resolution.
+                This parameter must be provided if `dataset_type` is 'image'.
+            - "video_resolution_buckets": (`List[Tuple[int, int, int]]`)
+                Resolution buckets for video. This should be a list of tuples containing 3 values, where each tuple
+                represents the resolution (num_frames, height, width). All videos will be resized to the nearest bucket
+                resolution. This parameter must be provided if `dataset_type` is 'video'.
+            - "reshape_mode": (`str`)
+                All input images/videos are reshaped using this mode. Choose between the following:
+                ["center_crop", "random_crop", "bicubic"].
+            - "remove_common_llm_caption_prefixes": (`boolean`)
+                Whether or not to remove common LLM caption prefixes. See `~constants.py` for the list of common prefixes.
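+        As an illustrative example (the path and token below are hypothetical), a minimal dataset config could look
+        like:
+            {
+                "datasets": [
+                    {
+                        "data_root": "path/to/my/video/dataset",
+                        "dataset_type": "video",
+                        "id_token": "TOKEN",
+                        "video_resolution_buckets": [[49, 480, 720]],
+                        "reshape_mode": "bicubic",
+                        "remove_common_llm_caption_prefixes": true
+                    }
+                ]
+            }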
+    dataset_shuffle_buffer_size (`int`, defaults to `1`):
+        The buffer size for shuffling the dataset. This is useful for shuffling the dataset before training. The default
+        value of `1` means that the dataset will not be shuffled.
+    precomputation_items (`int`, defaults to `512`):
+        Number of data samples to precompute at once for memory-efficient training. The higher this value,
+        the more disk space will be used to save the precomputed samples (conditions and latents).
+    precomputation_dir (`str`, defaults to `None`):
+        The directory where the precomputed samples will be stored. If not provided, the precomputed samples
+        will be stored in a temporary directory inside the output directory.
+    precomputation_once (`bool`, defaults to `False`):
+        Precompute embeddings from all datasets at once before training. This is useful to save time during training
+        with smaller datasets. If set to `False`, disk space is saved by precomputing embeddings on-the-fly during
+        training as required. Make sure to set `precomputation_items` to a reasonable value in line with the size
+        of your dataset(s).
 
     DATALOADER_ARGUMENTS
     --------------------
@@ -136,16 +178,11 @@ class Args:
         A seed for reproducible training.
     batch_size (`int`, defaults to `1`):
         Per-device batch size.
-    train_epochs (`int`, defaults to `1`):
-        Number of training epochs.
-    train_steps (`int`, defaults to `None`):
-        Total number of training steps to perform. If provided, overrides `train_epochs`.
-    rank (`int`, defaults to `128`):
-        The rank for LoRA matrices.
-    lora_alpha (`float`, defaults to `64`):
-        The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.
-    target_modules (`List[str]`, defaults to `["to_k", "to_q", "to_v", "to_out.0"]`):
-        The target modules for LoRA. Make sure to modify this based on the model.
+    train_steps (`int`, defaults to `1000`):
+        Total number of training steps to perform.
+    max_data_samples (`int`, defaults to `2**64`):
+        Maximum number of data samples observed during training. If this is lower than the number required by
+        `train_steps`, training will stop early.
     gradient_accumulation_steps (`int`, defaults to `1`):
         Number of gradients steps to accumulate before performing an optimizer step.
     gradient_checkpointing (`bool`, defaults to `False`):
@@ -164,13 +201,11 @@ class Args:
     OPTIMIZER ARGUMENTS
     -------------------
     optimizer (`str`, defaults to `adamw`):
-        The optimizer type to use. Choose between ['adam', 'adamw'].
-    use_8bit_bnb (`bool`, defaults to `False`):
-        Whether to use 8bit variant of the `optimizer` using `bitsandbytes`.
+        The optimizer type to use. Choose between the following:
+            - Torch optimizers: ["adam", "adamw"]
+            - Bitsandbytes optimizers: ["adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"]
     lr (`float`, defaults to `1e-4`):
         Initial learning rate (after the potential warmup period) to use.
-    scale_lr (`bool`, defaults to `False`):
-        Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.
     lr_scheduler (`str`, defaults to `cosine_with_restarts`):
         The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',
         'constant', 'constant_with_warmup'].
@@ -192,29 +227,21 @@ class Args:
 
     VALIDATION ARGUMENTS
     --------------------
-    validation_prompts (`List[str]`, defaults to `None`):
-        List of prompts to use for validation. If not provided, a random prompt will be selected from the training
-        dataset.
-    validation_images (`List[str]`, defaults to `None`):
-        List of image paths to use for validation.
-    validation_videos (`List[str]`, defaults to `None`):
-        List of video paths to use for validation.
-    validation_heights (`List[int]`, defaults to `None`):
-        List of heights for the validation videos.
-    validation_widths (`List[int]`, defaults to `None`):
-        List of widths for the validation videos.
-    validation_num_frames (`List[int]`, defaults to `None`):
-        List of number of frames for the validation videos.
-    num_validation_videos_per_prompt (`int`, defaults to `1`):
-        Number of videos to use for validation per prompt.
-    validation_every_n_epochs (`int`, defaults to `None`):
-        Perform validation every `n` training epochs.
-    validation_every_n_steps (`int`, defaults to `None`):
-        Perform validation every `n` training steps.
+    validation_dataset_file (`str`, defaults to `None`):
+        Path to a CSV/JSON/PARQUET/ARROW file containing information for validation. The file must contain at least the
+        "caption" column. Other columns such as "image_path" and "video_path" can be provided too. If provided, "image_path"
+        will be used to load a PIL.Image.Image and set the "image" key in the sample dictionary. Similarly, "video_path"
+        will be used to load a List[PIL.Image.Image] and set the "video" key in the sample dictionary.
+        The validation dataset file may contain other attributes specific to inference/validation such as:
+            - "height" and "width" and "num_frames": Resolution
+            - "num_inference_steps": Number of inference steps
+            - "guidance_scale": Classifier-free Guidance Scale
+            - ... (any number of additional attributes can be provided. The ModelSpecification::validate method will be
+              invoked with the sample dictionary to validate the sample.)
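+        As an illustrative example (hypothetical values), a CSV validation file could look like:
+            caption,video_path,height,width,num_frames,num_inference_steps
+            "A cat playing with a ball of wool","path/to/video.mp4",480,720,49,50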
+    validation_steps (`int`, defaults to `500`):
+        Number of training steps after which a validation step is performed.
     enable_model_cpu_offload (`bool`, defaults to `False`):
         Whether or not to offload different modeling components to CPU during validation.
-    validation_frame_rate (`int`, defaults to `25`):
-        Frame rate to use for the validation videos. This value is defaulted to 25, as used in LTX Video pipeline.
 
     MISCELLANEOUS ARGUMENTS
     -----------------------
@@ -230,20 +257,44 @@ class Args:
         The directory where the model checkpoints and logs will be stored.
     logging_dir (`str`, defaults to `logs`):
         The directory where the logs will be stored.
+    logging_steps (`int`, defaults to `1`):
+        Training logs will be tracked every `logging_steps` steps.
     allow_tf32 (`bool`, defaults to `False`):
         Whether or not to allow the use of TF32 matmul on compatible hardware.
     nccl_timeout (`int`, defaults to `1800`):
         Timeout for the NCCL communication.
     report_to (`str`, defaults to `wandb`):
         The name of the logger to use for logging training metrics. Choose between ['wandb'].
+    verbose (`int`, defaults to `1`):
+        The verbosity level for logging:
+            - 0: Diffusers/Transformers warning logging on local main process only
+            - 1: Diffusers/Transformers info logging on local main process only
+            - 2: Diffusers/Transformers debug logging on local main process only
+            - 3: Diffusers/Transformers debug logging on all processes
     """
 
+    # Parallel arguments
+    parallel_backend = ParallelBackendEnum.ACCELERATE
+    pp_degree: int = 1
+    dp_degree: int = 1
+    dp_shards: int = 1
+    cp_degree: int = 1
+    tp_degree: int = 1
+
     # Model arguments
     model_name: str = None
     pretrained_model_name_or_path: str = None
     revision: Optional[str] = None
     variant: Optional[str] = None
     cache_dir: Optional[str] = None
+    tokenizer_id: Optional[str] = None
+    tokenizer_2_id: Optional[str] = None
+    tokenizer_3_id: Optional[str] = None
+    text_encoder_id: Optional[str] = None
+    text_encoder_2_id: Optional[str] = None
+    text_encoder_3_id: Optional[str] = None
+    transformer_id: Optional[str] = None
+    vae_id: Optional[str] = None
     text_encoder_dtype: torch.dtype = torch.bfloat16
     text_encoder_2_dtype: torch.dtype = torch.bfloat16
     text_encoder_3_dtype: torch.dtype = torch.bfloat16
@@ -263,18 +314,11 @@ class Args:
     ]
 
     # Dataset arguments
-    data_root: str = None
-    dataset_file: Optional[str] = None
-    video_column: str = None
-    caption_column: str = None
-    id_token: Optional[str] = None
-    image_resolution_buckets: List[Tuple[int, int]] = None
-    video_resolution_buckets: List[Tuple[int, int, int]] = None
-    video_reshape_mode: Optional[str] = None
-    caption_dropout_p: float = 0.00
-    caption_dropout_technique: str = "empty"
-    precompute_conditions: bool = False
-    remove_common_llm_caption_prefixes: bool = False
+    dataset_config: str = None
+    dataset_shuffle_buffer_size: int = 1
+    precomputation_items: int = 512
+    precomputation_dir: Optional[str] = None
+    precomputation_once: bool = False
 
     # Dataloader arguments
     dataloader_num_workers: int = 0
@@ -296,11 +340,8 @@ class Args:
     training_type: str = None
     seed: int = 42
     batch_size: int = 1
-    train_epochs: int = 1
-    train_steps: int = None
-    rank: int = 128
-    lora_alpha: float = 64
-    target_modules: List[str] = ["to_k", "to_q", "to_v", "to_out.0"]
+    train_steps: int = 1000
+    max_data_samples: int = 2**64
     gradient_accumulation_steps: int = 1
     gradient_checkpointing: bool = False
     checkpointing_steps: int = 500
@@ -311,9 +352,7 @@ class Args:
 
     # Optimizer arguments
     optimizer: str = "adamw"
-    use_8bit_bnb: bool = False
     lr: float = 1e-4
-    scale_lr: bool = False
     lr_scheduler: str = "cosine_with_restarts"
     lr_warmup_steps: int = 0
     lr_num_cycles: int = 1
@@ -326,17 +365,9 @@ class Args:
     max_grad_norm: float = 1.0
 
     # Validation arguments
-    validation_prompts: List[str] = None
-    validation_images: List[str] = None
-    validation_videos: List[str] = None
-    validation_heights: List[int] = None
-    validation_widths: List[int] = None
-    validation_num_frames: List[int] = None
-    num_validation_videos_per_prompt: int = 1
-    validation_every_n_epochs: Optional[int] = None
-    validation_every_n_steps: Optional[int] = None
+    validation_dataset_file: Optional[str] = None
+    validation_steps: int = 500
     enable_model_cpu_offload: bool = False
-    validation_frame_rate: int = 25
 
     # Miscellaneous arguments
     tracker_name: str = "finetrainers"
@@ -345,664 +376,343 @@ class Args:
     hub_model_id: Optional[str] = None
     output_dir: str = None
     logging_dir: Optional[str] = "logs"
+    logging_steps: int = 1
     allow_tf32: bool = False
-    nccl_timeout: int = 1800  # 30 minutes
+    init_timeout: int = 300  # 5 minutes
+    nccl_timeout: int = 600  # 10 minutes, considering that validation may be performed
     report_to: str = "wandb"
+    verbose: int = 1
 
     def to_dict(self) -> Dict[str, Any]:
-        return {
-            "model_arguments": {
-                "model_name": self.model_name,
-                "pretrained_model_name_or_path": self.pretrained_model_name_or_path,
-                "revision": self.revision,
-                "variant": self.variant,
-                "cache_dir": self.cache_dir,
-                "text_encoder_dtype": self.text_encoder_dtype,
-                "text_encoder_2_dtype": self.text_encoder_2_dtype,
-                "text_encoder_3_dtype": self.text_encoder_3_dtype,
-                "transformer_dtype": self.transformer_dtype,
-                "vae_dtype": self.vae_dtype,
-                "layerwise_upcasting_modules": self.layerwise_upcasting_modules,
-                "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
-                "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
-            },
-            "dataset_arguments": {
-                "data_root": self.data_root,
-                "dataset_file": self.dataset_file,
-                "video_column": self.video_column,
-                "caption_column": self.caption_column,
-                "id_token": self.id_token,
-                "image_resolution_buckets": self.image_resolution_buckets,
-                "video_resolution_buckets": self.video_resolution_buckets,
-                "video_reshape_mode": self.video_reshape_mode,
-                "caption_dropout_p": self.caption_dropout_p,
-                "caption_dropout_technique": self.caption_dropout_technique,
-                "precompute_conditions": self.precompute_conditions,
-                "remove_common_llm_caption_prefixes": self.remove_common_llm_caption_prefixes,
-            },
-            "dataloader_arguments": {
-                "dataloader_num_workers": self.dataloader_num_workers,
-                "pin_memory": self.pin_memory,
-            },
-            "diffusion_arguments": {
-                "flow_resolution_shifting": self.flow_resolution_shifting,
-                "flow_base_seq_len": self.flow_base_seq_len,
-                "flow_max_seq_len": self.flow_max_seq_len,
-                "flow_base_shift": self.flow_base_shift,
-                "flow_max_shift": self.flow_max_shift,
-                "flow_shift": self.flow_shift,
-                "flow_weighting_scheme": self.flow_weighting_scheme,
-                "flow_logit_mean": self.flow_logit_mean,
-                "flow_logit_std": self.flow_logit_std,
-                "flow_mode_scale": self.flow_mode_scale,
-            },
-            "training_arguments": {
-                "training_type": self.training_type,
-                "seed": self.seed,
-                "batch_size": self.batch_size,
-                "train_epochs": self.train_epochs,
-                "train_steps": self.train_steps,
-                "rank": self.rank,
-                "lora_alpha": self.lora_alpha,
-                "target_modules": self.target_modules,
-                "gradient_accumulation_steps": self.gradient_accumulation_steps,
-                "gradient_checkpointing": self.gradient_checkpointing,
-                "checkpointing_steps": self.checkpointing_steps,
-                "checkpointing_limit": self.checkpointing_limit,
-                "resume_from_checkpoint": self.resume_from_checkpoint,
-                "enable_slicing": self.enable_slicing,
-                "enable_tiling": self.enable_tiling,
-            },
-            "optimizer_arguments": {
-                "optimizer": self.optimizer,
-                "use_8bit_bnb": self.use_8bit_bnb,
-                "lr": self.lr,
-                "scale_lr": self.scale_lr,
-                "lr_scheduler": self.lr_scheduler,
-                "lr_warmup_steps": self.lr_warmup_steps,
-                "lr_num_cycles": self.lr_num_cycles,
-                "lr_power": self.lr_power,
-                "beta1": self.beta1,
-                "beta2": self.beta2,
-                "beta3": self.beta3,
-                "weight_decay": self.weight_decay,
-                "epsilon": self.epsilon,
-                "max_grad_norm": self.max_grad_norm,
-            },
-            "validation_arguments": {
-                "validation_prompts": self.validation_prompts,
-                "validation_images": self.validation_images,
-                "validation_videos": self.validation_videos,
-                "num_validation_videos_per_prompt": self.num_validation_videos_per_prompt,
-                "validation_every_n_epochs": self.validation_every_n_epochs,
-                "validation_every_n_steps": self.validation_every_n_steps,
-                "enable_model_cpu_offload": self.enable_model_cpu_offload,
-                "validation_frame_rate": self.validation_frame_rate,
-            },
-            "miscellaneous_arguments": {
-                "tracker_name": self.tracker_name,
-                "push_to_hub": self.push_to_hub,
-                "hub_token": self.hub_token,
-                "hub_model_id": self.hub_model_id,
-                "output_dir": self.output_dir,
-                "logging_dir": self.logging_dir,
-                "allow_tf32": self.allow_tf32,
-                "nccl_timeout": self.nccl_timeout,
-                "report_to": self.report_to,
-            },
+        parallel_arguments = {
+            "pp_degree": self.pp_degree,
+            "dp_degree": self.dp_degree,
+            "dp_shards": self.dp_shards,
+            "cp_degree": self.cp_degree,
+            "tp_degree": self.tp_degree,
         }
 
+        model_arguments = {
+            "model_name": self.model_name,
+            "pretrained_model_name_or_path": self.pretrained_model_name_or_path,
+            "revision": self.revision,
+            "variant": self.variant,
+            "cache_dir": self.cache_dir,
+            "tokenizer_id": self.tokenizer_id,
+            "tokenizer_2_id": self.tokenizer_2_id,
+            "tokenizer_3_id": self.tokenizer_3_id,
+            "text_encoder_id": self.text_encoder_id,
+            "text_encoder_2_id": self.text_encoder_2_id,
+            "text_encoder_3_id": self.text_encoder_3_id,
+            "transformer_id": self.transformer_id,
+            "vae_id": self.vae_id,
+            "text_encoder_dtype": self.text_encoder_dtype,
+            "text_encoder_2_dtype": self.text_encoder_2_dtype,
+            "text_encoder_3_dtype": self.text_encoder_3_dtype,
+            "transformer_dtype": self.transformer_dtype,
+            "vae_dtype": self.vae_dtype,
+            "layerwise_upcasting_modules": self.layerwise_upcasting_modules,
+            "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
+            "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
+        }
+        model_arguments = get_non_null_items(model_arguments)
+
+        dataset_arguments = {
+            "dataset_config": self.dataset_config,
+            "dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size,
+            "precomputation_items": self.precomputation_items,
+            "precomputation_dir": self.precomputation_dir,
+            "precomputation_once": self.precomputation_once,
+        }
+        dataset_arguments = get_non_null_items(dataset_arguments)
 
-# TODO(aryan): handle more informative messages
-_IS_ARGUMENTS_REQUIRED = "--list_models" not in sys.argv
-
-
-def parse_arguments() -> Args:
-    parser = argparse.ArgumentParser()
+        dataloader_arguments = {
+            "dataloader_num_workers": self.dataloader_num_workers,
+            "pin_memory": self.pin_memory,
+        }
 
-    if _IS_ARGUMENTS_REQUIRED:
-        _add_model_arguments(parser)
-        _add_dataset_arguments(parser)
-        _add_dataloader_arguments(parser)
-        _add_diffusion_arguments(parser)
-        _add_training_arguments(parser)
-        _add_optimizer_arguments(parser)
-        _add_validation_arguments(parser)
-        _add_miscellaneous_arguments(parser)
+        diffusion_arguments = {
+            "flow_resolution_shifting": self.flow_resolution_shifting,
+            "flow_base_seq_len": self.flow_base_seq_len,
+            "flow_max_seq_len": self.flow_max_seq_len,
+            "flow_base_shift": self.flow_base_shift,
+            "flow_max_shift": self.flow_max_shift,
+            "flow_shift": self.flow_shift,
+            "flow_weighting_scheme": self.flow_weighting_scheme,
+            "flow_logit_mean": self.flow_logit_mean,
+            "flow_logit_std": self.flow_logit_std,
+            "flow_mode_scale": self.flow_mode_scale,
+        }
 
-        args = parser.parse_args()
-        return _map_to_args_type(args)
-    else:
-        _add_helper_arguments(parser)
+        training_arguments = {
+            "training_type": self.training_type,
+            "seed": self.seed,
+            "batch_size": self.batch_size,
+            "train_steps": self.train_steps,
+            "max_data_samples": self.max_data_samples,
+            "gradient_accumulation_steps": self.gradient_accumulation_steps,
+            "gradient_checkpointing": self.gradient_checkpointing,
+            "checkpointing_steps": self.checkpointing_steps,
+            "checkpointing_limit": self.checkpointing_limit,
+            "resume_from_checkpoint": self.resume_from_checkpoint,
+            "enable_slicing": self.enable_slicing,
+            "enable_tiling": self.enable_tiling,
+        }
+        training_arguments = get_non_null_items(training_arguments)
+
+        optimizer_arguments = {
+            "optimizer": self.optimizer,
+            "lr": self.lr,
+            "lr_scheduler": self.lr_scheduler,
+            "lr_warmup_steps": self.lr_warmup_steps,
+            "lr_num_cycles": self.lr_num_cycles,
+            "lr_power": self.lr_power,
+            "beta1": self.beta1,
+            "beta2": self.beta2,
+            "beta3": self.beta3,
+            "weight_decay": self.weight_decay,
+            "epsilon": self.epsilon,
+            "max_grad_norm": self.max_grad_norm,
+        }
+        optimizer_arguments = get_non_null_items(optimizer_arguments)
 
-        args = parser.parse_args()
-        _display_helper_messages(args)
-        sys.exit(0)
+        validation_arguments = {
+            "validation_dataset_file": self.validation_dataset_file,
+            "validation_steps": self.validation_steps,
+            "enable_model_cpu_offload": self.enable_model_cpu_offload,
+        }
+        validation_arguments = get_non_null_items(validation_arguments)
+
+        miscellaneous_arguments = {
+            "tracker_name": self.tracker_name,
+            "push_to_hub": self.push_to_hub,
+            "hub_token": self.hub_token,
+            "hub_model_id": self.hub_model_id,
+            "output_dir": self.output_dir,
+            "logging_dir": self.logging_dir,
+            "logging_steps": self.logging_steps,
+            "allow_tf32": self.allow_tf32,
+            "init_timeout": self.init_timeout,
+            "nccl_timeout": self.nccl_timeout,
+            "report_to": self.report_to,
+            "verbose": self.verbose,
+        }
+        miscellaneous_arguments = get_non_null_items(miscellaneous_arguments)
 
+        return {
+            "parallel_arguments": parallel_arguments,
+            "model_arguments": model_arguments,
+            "dataset_arguments": dataset_arguments,
+            "dataloader_arguments": dataloader_arguments,
+            "diffusion_arguments": diffusion_arguments,
+            "training_arguments": training_arguments,
+            "optimizer_arguments": optimizer_arguments,
+            "validation_arguments": validation_arguments,
+            "miscellaneous_arguments": miscellaneous_arguments,
+        }
 
-def validate_args(args: Args):
-    _validated_model_args(args)
-    _validate_training_args(args)
+    def extend_args(
+        self,
+        add_fn: Callable[[argparse.ArgumentParser], None],
+        map_fn: Callable[[argparse.Namespace, "BaseArgs"], None],
+        validate_fn: Callable[["BaseArgs"], None],
+    ) -> None:
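+        r"""
+        Register additional CLI arguments on top of the base arguments.
+
+        `add_fn` receives the `argparse.ArgumentParser` and adds new arguments to it, `map_fn` receives the parsed
+        `argparse.Namespace` and the mapped `BaseArgs` instance and copies the new values over, and `validate_fn`
+        receives the mapped `BaseArgs` and should raise an error if the values are invalid.
+
+        Illustrative sketch (the `--rank` argument is a hypothetical example, not defined by `BaseArgs`):
+
+            args = BaseArgs()
+            args.extend_args(
+                add_fn=lambda parser: parser.add_argument("--rank", type=int, default=64),
+                map_fn=lambda namespace, mapped_args: setattr(mapped_args, "rank", namespace.rank),
+                validate_fn=lambda mapped_args: None,
+            )
+            args = args.parse_args()
+        """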
+        if not hasattr(self, "_extended_add_arguments"):
+            self._extended_add_arguments = []
+        self._extended_add_arguments.append((add_fn, validate_fn, map_fn))
+
+    def parse_args(self):
+        _LIST_MODELS = "--list_models"
+
+        parser = argparse.ArgumentParser()
+
+        special_args = [_LIST_MODELS]
+        if any(arg in sys.argv for arg in special_args):
+            _add_helper_arguments(parser)
+            args = parser.parse_args()
+            _display_helper_messages(args)
+            sys.exit(0)
+        else:
+            _add_args(parser)
+            for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
+                add_fn, _, _ = extended_add_arg_fns
+                add_fn(parser)
+
+            args, remaining_args = parser.parse_known_args()
+            logger.debug(f"Remaining unparsed arguments: {remaining_args}")
+
+            mapped_args = _map_to_args_type(args)
+            for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
+                _, _, map_fn = extended_add_arg_fns
+                map_fn(args, mapped_args)
+
+            _validate_args(mapped_args)
+            for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
+                _, validate_fn, _ = extended_add_arg_fns
+                validate_fn(mapped_args)
+
+            return mapped_args
+
+
+def _add_args(parser: argparse.ArgumentParser) -> None:
+    _add_parallel_arguments(parser)
+    _add_model_arguments(parser)
+    _add_dataset_arguments(parser)
+    _add_dataloader_arguments(parser)
+    _add_diffusion_arguments(parser)
+    _add_training_arguments(parser)
+    _add_optimizer_arguments(parser)
+    _add_validation_arguments(parser)
+    _add_miscellaneous_arguments(parser)
+
+
+def _validate_args(args: BaseArgs):
+    _validate_model_args(args)
+    _validate_dataset_args(args)
     _validate_validation_args(args)
 
 
-def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
-    parser.add_argument(
-        "--model_name",
-        type=str,
-        required=True,
-        choices=list(SUPPORTED_MODEL_CONFIGS.keys()),
-        help="Name of model to train.",
-    )
-    parser.add_argument(
-        "--pretrained_model_name_or_path",
-        type=str,
-        required=True,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-    )
+def _add_parallel_arguments(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
-        "--revision",
+        "--parallel_backend",
         type=str,
-        default=None,
-        required=False,
-        help="Revision of pretrained model identifier from huggingface.co/models.",
+        default=ParallelBackendEnum.ACCELERATE,
+        choices=[ParallelBackendEnum.ACCELERATE, ParallelBackendEnum.PTD],
     )
+    parser.add_argument("--pp_degree", type=int, default=1)
+    parser.add_argument("--dp_degree", type=int, default=1)
+    parser.add_argument("--dp_shards", type=int, default=1)
+    parser.add_argument("--cp_degree", type=int, default=1)
+    parser.add_argument("--tp_degree", type=int, default=1)
+
+
+def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
-        "--variant",
-        type=str,
-        default=None,
-        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
-    )
-    parser.add_argument(
-        "--cache_dir",
-        type=str,
-        default=None,
-        help="The directory where the downloaded models and datasets will be stored.",
-    )
-    parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.")
-    parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.")
-    parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
-    parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
-    parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")
-    parser.add_argument(
-        "--layerwise_upcasting_modules",
-        type=str,
-        default=[],
-        nargs="+",
-        choices=["transformer"],
-        help="Modules that should have fp8 storage weights but higher precision computation.",
-    )
+        "--model_name", type=str, required=True, choices=[x.value for x in ModelType.__members__.values()]
+    )
+    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
+    parser.add_argument("--revision", type=str, default=None, required=False)
+    parser.add_argument("--variant", type=str, default=None)
+    parser.add_argument("--cache_dir", type=str, default=None)
+    parser.add_argument("--tokenizer_id", type=str, default=None)
+    parser.add_argument("--tokenizer_2_id", type=str, default=None)
+    parser.add_argument("--tokenizer_3_id", type=str, default=None)
+    parser.add_argument("--text_encoder_id", type=str, default=None)
+    parser.add_argument("--text_encoder_2_id", type=str, default=None)
+    parser.add_argument("--text_encoder_3_id", type=str, default=None)
+    parser.add_argument("--transformer_id", type=str, default=None)
+    parser.add_argument("--vae_id", type=str, default=None)
+    parser.add_argument("--text_encoder_dtype", type=str, default="bf16")
+    parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16")
+    parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16")
+    parser.add_argument("--transformer_dtype", type=str, default="bf16")
+    parser.add_argument("--vae_dtype", type=str, default="bf16")
+    parser.add_argument("--layerwise_upcasting_modules", type=str, default=[], nargs="+", choices=["transformer"])
     parser.add_argument(
         "--layerwise_upcasting_storage_dtype",
         type=str,
         default="float8_e4m3fn",
         choices=["float8_e4m3fn", "float8_e5m2"],
-        help="Data type for the layerwise upcasting storage.",
     )
     parser.add_argument(
         "--layerwise_upcasting_skip_modules_pattern",
         type=str,
         default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
         nargs="+",
-        help="Modules to skip for layerwise upcasting.",
     )
 
 
 def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
-    def parse_resolution_bucket(resolution_bucket: str) -> Tuple[int, ...]:
-        return tuple(map(int, resolution_bucket.split("x")))
-
-    def parse_image_resolution_bucket(resolution_bucket: str) -> Tuple[int, int]:
-        resolution_bucket = parse_resolution_bucket(resolution_bucket)
-        assert (
-            len(resolution_bucket) == 2
-        ), f"Expected 2D resolution bucket, got {len(resolution_bucket)}D resolution bucket"
-        return resolution_bucket
-
-    def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int]:
-        resolution_bucket = parse_resolution_bucket(resolution_bucket)
-        assert (
-            len(resolution_bucket) == 3
-        ), f"Expected 3D resolution bucket, got {len(resolution_bucket)}D resolution bucket"
-        return resolution_bucket
-
-    parser.add_argument(
-        "--data_root",
-        type=str,
-        required=True,
-        help=("A folder containing the training data."),
-    )
-    parser.add_argument(
-        "--dataset_file",
-        type=str,
-        default=None,
-        help=("Path to a CSV file if loading prompts/video paths using this format."),
-    )
-    parser.add_argument(
-        "--video_column",
-        type=str,
-        default="video",
-        help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.",
-    )
-    parser.add_argument(
-        "--caption_column",
-        type=str,
-        default="text",
-        help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
-    )
-    parser.add_argument(
-        "--id_token",
-        type=str,
-        default=None,
-        help="Identifier token appended to the start of each prompt if provided.",
-    )
-    parser.add_argument(
-        "--image_resolution_buckets",
-        type=parse_image_resolution_bucket,
-        default=None,
-        nargs="+",
-        help="Resolution buckets for images.",
-    )
-    parser.add_argument(
-        "--video_resolution_buckets",
-        type=parse_video_resolution_bucket,
-        default=None,
-        nargs="+",
-        help="Resolution buckets for videos.",
-    )
-    parser.add_argument(
-        "--video_reshape_mode",
-        type=str,
-        default=None,
-        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
-    )
-    parser.add_argument(
-        "--caption_dropout_p",
-        type=float,
-        default=0.00,
-        help="Probability of dropout for the caption tokens.",
-    )
-    parser.add_argument(
-        "--caption_dropout_technique",
-        type=str,
-        default="empty",
-        choices=["empty", "zero"],
-        help="Technique to use for caption dropout.",
-    )
-    parser.add_argument(
-        "--precompute_conditions",
-        action="store_true",
-        help="Whether or not to precompute the conditionings for the model.",
-    )
-    parser.add_argument(
-        "--remove_common_llm_caption_prefixes",
-        action="store_true",
-        help="Whether or not to remove common LLM caption prefixes.",
-    )
+    parser.add_argument("--dataset_config", type=str, required=True)
+    parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1)
+    parser.add_argument("--precomputation_items", type=int, default=512)
+    parser.add_argument("--precomputation_dir", type=str, default=None)
+    parser.add_argument("--precomputation_once", action="store_true")
 
 
 def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None:
-    parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
-    )
-    parser.add_argument(
-        "--pin_memory",
-        action="store_true",
-        help="Whether or not to use the pinned memory setting in pytorch dataloader.",
-    )
+    parser.add_argument("--dataloader_num_workers", type=int, default=0)
+    parser.add_argument("--pin_memory", action="store_true")
 
 
 def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
-    parser.add_argument(
-        "--flow_resolution_shifting",
-        action="store_true",
-        help="Resolution-dependent shifting of timestep schedules.",
-    )
-    parser.add_argument(
-        "--flow_base_seq_len",
-        type=int,
-        default=256,
-        help="Base image/video sequence length for the diffusion model.",
-    )
-    parser.add_argument(
-        "--flow_max_seq_len",
-        type=int,
-        default=4096,
-        help="Maximum image/video sequence length for the diffusion model.",
-    )
-    parser.add_argument(
-        "--flow_base_shift",
-        type=float,
-        default=0.5,
-        help="Base shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)",
-    )
-    parser.add_argument(
-        "--flow_max_shift",
-        type=float,
-        default=1.15,
-        help="Maximum shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)",
-    )
-    parser.add_argument(
-        "--flow_shift",
-        type=float,
-        default=1.0,
-        help="Shift value to use for the flow matching timestep schedule.",
-    )
+    parser.add_argument("--flow_resolution_shifting", action="store_true")
+    parser.add_argument("--flow_base_seq_len", type=int, default=256)
+    parser.add_argument("--flow_max_seq_len", type=int, default=4096)
+    parser.add_argument("--flow_base_shift", type=float, default=0.5)
+    parser.add_argument("--flow_max_shift", type=float, default=1.15)
+    parser.add_argument("--flow_shift", type=float, default=1.0)
     parser.add_argument(
         "--flow_weighting_scheme",
         type=str,
         default="none",
         choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
-        help='We default to the "none" weighting scheme for uniform sampling and uniform loss',
-    )
-    parser.add_argument(
-        "--flow_logit_mean",
-        type=float,
-        default=0.0,
-        help="Mean to use when using the `'logit_normal'` weighting scheme.",
-    )
-    parser.add_argument(
-        "--flow_logit_std",
-        type=float,
-        default=1.0,
-        help="Standard deviation to use when using the `'logit_normal'` weighting scheme.",
-    )
-    parser.add_argument(
-        "--flow_mode_scale",
-        type=float,
-        default=1.29,
-        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
+    parser.add_argument("--flow_logit_mean", type=float, default=0.0)
+    parser.add_argument("--flow_logit_std", type=float, default=1.0)
+    parser.add_argument("--flow_mode_scale", type=float, default=1.29)
 
 
 def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
-    # TODO: support full finetuning and other kinds
-    parser.add_argument(
-        "--training_type",
-        type=str,
-        choices=["lora", "full-finetune"],
-        required=True,
-        help="Type of training to perform. Choose between ['lora', 'full-finetune']",
-    )
-    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="Batch size (per device) for the training dataloader.",
-    )
-    parser.add_argument("--train_epochs", type=int, default=1, help="Number of training epochs.")
-    parser.add_argument(
-        "--train_steps",
-        type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
-    )
-    parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.")
-    parser.add_argument(
-        "--lora_alpha",
-        type=int,
-        default=64,
-        help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
-    )
-    parser.add_argument(
-        "--target_modules",
-        type=str,
-        default=["to_k", "to_q", "to_v", "to_out.0"],
-        nargs="+",
-        help="The target modules for LoRA.",
-    )
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
     parser.add_argument(
-        "--gradient_checkpointing",
-        action="store_true",
-        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
-    )
-    parser.add_argument(
-        "--checkpointing_steps",
-        type=int,
-        default=500,
-        help=(
-            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
-            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
-            " training using `--resume_from_checkpoint`."
-        ),
-    )
-    parser.add_argument(
-        "--checkpointing_limit",
-        type=int,
-        default=None,
-        help=("Max number of checkpoints to store."),
-    )
-    parser.add_argument(
-        "--resume_from_checkpoint",
-        type=str,
-        default=None,
-        help=(
-            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
-            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
-        ),
-    )
-    parser.add_argument(
-        "--enable_slicing",
-        action="store_true",
-        help="Whether or not to use VAE slicing for saving memory.",
-    )
-    parser.add_argument(
-        "--enable_tiling",
-        action="store_true",
-        help="Whether or not to use VAE tiling for saving memory.",
+        "--training_type", type=str, choices=[x.value for x in TrainingType.__members__.values()], required=True
     )
+    parser.add_argument("--seed", type=int, default=None)
+    parser.add_argument("--batch_size", type=int, default=1)
+    parser.add_argument("--train_steps", type=int, default=1000)
+    parser.add_argument("--max_data_samples", type=int, default=2**64)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+    parser.add_argument("--gradient_checkpointing", action="store_true")
+    parser.add_argument("--checkpointing_steps", type=int, default=500)
+    parser.add_argument("--checkpointing_limit", type=int, default=None)
+    parser.add_argument("--resume_from_checkpoint", type=str, default=None)
+    parser.add_argument("--enable_slicing", action="store_true")
+    parser.add_argument("--enable_tiling", action="store_true")
 
 
 def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None:
-    parser.add_argument(
-        "--lr",
-        type=float,
-        default=1e-4,
-        help="Initial learning rate (after the potential warmup period) to use.",
-    )
-    parser.add_argument(
-        "--scale_lr",
-        action="store_true",
-        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
-    )
-    parser.add_argument(
-        "--lr_scheduler",
-        type=str,
-        default="constant",
-        help=(
-            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
-            ' "constant", "constant_with_warmup"]'
-        ),
-    )
-    parser.add_argument(
-        "--lr_warmup_steps",
-        type=int,
-        default=500,
-        help="Number of steps for the warmup in the lr scheduler.",
-    )
-    parser.add_argument(
-        "--lr_num_cycles",
-        type=int,
-        default=1,
-        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
-    )
-    parser.add_argument(
-        "--lr_power",
-        type=float,
-        default=1.0,
-        help="Power factor of the polynomial scheduler.",
-    )
+    parser.add_argument("--lr", type=float, default=1e-4)
+    parser.add_argument("--lr_scheduler", type=str, default="constant")
+    parser.add_argument("--lr_warmup_steps", type=int, default=500)
+    parser.add_argument("--lr_num_cycles", type=int, default=1)
+    parser.add_argument("--lr_power", type=float, default=1.0)
     parser.add_argument(
         "--optimizer",
         type=lambda s: s.lower(),
         default="adam",
-        choices=["adam", "adamw"],
-        help=("The optimizer type to use."),
+        choices=["adam", "adamw", "adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"],
     )
-    parser.add_argument(
-        "--use_8bit_bnb",
-        action="store_true",
-        help=("Whether to use 8bit variant of the `--optimizer` using `bitsandbytes`."),
-    )
-    parser.add_argument(
-        "--beta1",
-        type=float,
-        default=0.9,
-        help="The beta1 parameter for the Adam and Prodigy optimizers.",
-    )
-    parser.add_argument(
-        "--beta2",
-        type=float,
-        default=0.95,
-        help="The beta2 parameter for the Adam and Prodigy optimizers.",
-    )
-    parser.add_argument(
-        "--beta3",
-        type=float,
-        default=None,
-        help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
-    )
-    parser.add_argument(
-        "--weight_decay",
-        type=float,
-        default=1e-04,
-        help="Weight decay to use for optimizer.",
-    )
-    parser.add_argument(
-        "--epsilon",
-        type=float,
-        default=1e-8,
-        help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
-    )
-    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument("--beta1", type=float, default=0.9)
+    parser.add_argument("--beta2", type=float, default=0.95)
+    parser.add_argument("--beta3", type=float, default=None)
+    parser.add_argument("--weight_decay", type=float, default=1e-04)
+    parser.add_argument("--epsilon", type=float, default=1e-8)
+    parser.add_argument("--max_grad_norm", default=1.0, type=float)
 
 
 def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
-    parser.add_argument(
-        "--validation_prompts",
-        type=str,
-        default=None,
-        help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
-    )
-    parser.add_argument(
-        "--validation_images",
-        type=str,
-        default=None,
-        help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
-    )
-    parser.add_argument(
-        "--validation_videos",
-        type=str,
-        default=None,
-        help="One or more video path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
-    )
-    parser.add_argument(
-        "--validation_separator",
-        type=str,
-        default=":::",
-        help="String that separates multiple validation prompts",
-    )
-    parser.add_argument(
-        "--num_validation_videos",
-        type=int,
-        default=1,
-        help="Number of videos that should be generated during validation per `validation_prompt`.",
-    )
-    parser.add_argument(
-        "--validation_epochs",
-        type=int,
-        default=None,
-        help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
-    )
-    parser.add_argument(
-        "--validation_steps",
-        type=int,
-        default=None,
-        help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
-    )
-    parser.add_argument(
-        "--validation_frame_rate",
-        type=int,
-        default=25,
-        help="Frame rate to use for the validation videos.",
-    )
-    parser.add_argument(
-        "--enable_model_cpu_offload",
-        action="store_true",
-        help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
-    )
+    parser.add_argument("--validation_dataset_file", type=str, default=None)
+    parser.add_argument("--validation_steps", type=int, default=500)
+    parser.add_argument("--enable_model_cpu_offload", action="store_true")
 
 
 def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
-    parser.add_argument("--tracker_name", type=str, default="finetrainers", help="Project tracker name")
-    parser.add_argument(
-        "--push_to_hub",
-        action="store_true",
-        help="Whether or not to push the model to the Hub.",
-    )
-    parser.add_argument(
-        "--hub_token",
-        type=str,
-        default=None,
-        help="The token to use to push to the Model Hub.",
-    )
-    parser.add_argument(
-        "--hub_model_id",
-        type=str,
-        default=None,
-        help="The name of the repository to keep in sync with the local `output_dir`.",
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="finetrainers-training",
-        help="The output directory where the model predictions and checkpoints will be written.",
-    )
-    parser.add_argument(
-        "--logging_dir",
-        type=str,
-        default="logs",
-        help="Directory where logs are stored.",
-    )
-    parser.add_argument(
-        "--allow_tf32",
-        action="store_true",
-        help=(
-            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
-            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
-        ),
-    )
-    parser.add_argument(
-        "--nccl_timeout",
-        type=int,
-        default=600,
-        help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.",
-    )
-    parser.add_argument(
-        "--report_to",
-        type=str,
-        default="none",
-        choices=["none", "wandb"],
-        help="The integration to report the results and logs to.",
-    )
+    parser.add_argument("--tracker_name", type=str, default="finetrainers")
+    parser.add_argument("--push_to_hub", action="store_true")
+    parser.add_argument("--hub_token", type=str, default=None)
+    parser.add_argument("--hub_model_id", type=str, default=None)
+    parser.add_argument("--output_dir", type=str, default="finetrainers-training")
+    parser.add_argument("--logging_dir", type=str, default="logs")
+    parser.add_argument("--logging_steps", type=int, default=1)
+    parser.add_argument("--allow_tf32", action="store_true")
+    parser.add_argument("--init_timeout", type=int, default=300)
+    parser.add_argument("--nccl_timeout", type=int, default=600)
+    parser.add_argument("--report_to", type=str, default="none", choices=["none", "wandb"])
+    parser.add_argument("--verbose", type=int, default=0, choices=[0, 1, 2, 3])
 
 
 def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
-    parser.add_argument(
-        "--list_models",
-        action="store_true",
-        help="List all the supported models.",
-    )
+    parser.add_argument("--list_models", action="store_true")
 
 
 _DTYPE_MAP = {
@@ -1014,8 +724,16 @@ _DTYPE_MAP = {
 }
 
 
-def _map_to_args_type(args: Dict[str, Any]) -> Args:
-    result_args = Args()
+def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs:
+    result_args = BaseArgs()
+
+    # Parallel arguments
+    result_args.parallel_backend = args.parallel_backend
+    result_args.pp_degree = args.pp_degree
+    result_args.dp_degree = args.dp_degree
+    result_args.dp_shards = args.dp_shards
+    result_args.cp_degree = args.cp_degree
+    result_args.tp_degree = args.tp_degree
 
     # Model arguments
     result_args.model_name = args.model_name
@@ -1023,6 +741,14 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     result_args.revision = args.revision
     result_args.variant = args.variant
     result_args.cache_dir = args.cache_dir
+    result_args.tokenizer_id = args.tokenizer_id
+    result_args.tokenizer_2_id = args.tokenizer_2_id
+    result_args.tokenizer_3_id = args.tokenizer_3_id
+    result_args.text_encoder_id = args.text_encoder_id
+    result_args.text_encoder_2_id = args.text_encoder_2_id
+    result_args.text_encoder_3_id = args.text_encoder_3_id
+    result_args.transformer_id = args.transformer_id
+    result_args.vae_id = args.vae_id
     result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
     result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
     result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
@@ -1033,21 +759,11 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern
 
     # Dataset arguments
-    if args.data_root is None and args.dataset_file is None:
-        raise ValueError("At least one of `data_root` or `dataset_file` should be provided.")
-
-    result_args.data_root = args.data_root
-    result_args.dataset_file = args.dataset_file
-    result_args.video_column = args.video_column
-    result_args.caption_column = args.caption_column
-    result_args.id_token = args.id_token
-    result_args.image_resolution_buckets = args.image_resolution_buckets or DEFAULT_IMAGE_RESOLUTION_BUCKETS
-    result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS
-    result_args.video_reshape_mode = args.video_reshape_mode
-    result_args.caption_dropout_p = args.caption_dropout_p
-    result_args.caption_dropout_technique = args.caption_dropout_technique
-    result_args.precompute_conditions = args.precompute_conditions
-    result_args.remove_common_llm_caption_prefixes = args.remove_common_llm_caption_prefixes
+    result_args.dataset_config = args.dataset_config
+    result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size
+    result_args.precomputation_items = args.precomputation_items
+    result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed")
+    result_args.precomputation_once = args.precomputation_once
 
     # Dataloader arguments
     result_args.dataloader_num_workers = args.dataloader_num_workers
@@ -1069,11 +785,8 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     result_args.training_type = args.training_type
     result_args.seed = args.seed
     result_args.batch_size = args.batch_size
-    result_args.train_epochs = args.train_epochs
     result_args.train_steps = args.train_steps
-    result_args.rank = args.rank
-    result_args.lora_alpha = args.lora_alpha
-    result_args.target_modules = args.target_modules
+    result_args.max_data_samples = args.max_data_samples
     result_args.gradient_accumulation_steps = args.gradient_accumulation_steps
     result_args.gradient_checkpointing = args.gradient_checkpointing
     result_args.checkpointing_steps = args.checkpointing_steps
@@ -1084,9 +797,7 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
 
     # Optimizer arguments
     result_args.optimizer = args.optimizer or "adamw"
-    result_args.use_8bit_bnb = args.use_8bit_bnb
     result_args.lr = args.lr or 1e-4
-    result_args.scale_lr = args.scale_lr
     result_args.lr_scheduler = args.lr_scheduler
     result_args.lr_warmup_steps = args.lr_warmup_steps
     result_args.lr_num_cycles = args.lr_num_cycles
@@ -1099,42 +810,9 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     result_args.max_grad_norm = args.max_grad_norm
 
     # Validation arguments
-    validation_prompts = args.validation_prompts.split(args.validation_separator) if args.validation_prompts else []
-    validation_images = args.validation_images.split(args.validation_separator) if args.validation_images else None
-    validation_videos = args.validation_videos.split(args.validation_separator) if args.validation_videos else None
-    stripped_validation_prompts = []
-    validation_heights = []
-    validation_widths = []
-    validation_num_frames = []
-    for prompt in validation_prompts:
-        prompt: str
-        prompt = prompt.strip()
-        actual_prompt, separator, resolution = prompt.rpartition("@@@")
-        stripped_validation_prompts.append(actual_prompt)
-        num_frames, height, width = None, None, None
-        if len(resolution) > 0:
-            num_frames, height, width = map(int, resolution.split("x"))
-        validation_num_frames.append(num_frames)
-        validation_heights.append(height)
-        validation_widths.append(width)
-
-    if validation_images is None:
-        validation_images = [None] * len(validation_prompts)
-    if validation_videos is None:
-        validation_videos = [None] * len(validation_prompts)
-
-    result_args.validation_prompts = stripped_validation_prompts
-    result_args.validation_heights = validation_heights
-    result_args.validation_widths = validation_widths
-    result_args.validation_num_frames = validation_num_frames
-    result_args.validation_images = validation_images
-    result_args.validation_videos = validation_videos
-
-    result_args.num_validation_videos_per_prompt = args.num_validation_videos
-    result_args.validation_every_n_epochs = args.validation_epochs
-    result_args.validation_every_n_steps = args.validation_steps
+    result_args.validation_dataset_file = args.validation_dataset_file
+    result_args.validation_steps = args.validation_steps
     result_args.enable_model_cpu_offload = args.enable_model_cpu_offload
-    result_args.validation_frame_rate = args.validation_frame_rate
 
     # Miscellaneous arguments
     result_args.tracker_name = args.tracker_name
@@ -1143,45 +821,36 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
     result_args.hub_model_id = args.hub_model_id
     result_args.output_dir = args.output_dir
     result_args.logging_dir = args.logging_dir
+    result_args.logging_steps = args.logging_steps
     result_args.allow_tf32 = args.allow_tf32
+    result_args.init_timeout = args.init_timeout
     result_args.nccl_timeout = args.nccl_timeout
     result_args.report_to = args.report_to
+    result_args.verbose = args.verbose
 
     return result_args
 
 
-def _validated_model_args(args: Args):
+def _validate_model_args(args: BaseArgs):
     if args.training_type == "full-finetune":
         assert (
             "transformer" not in args.layerwise_upcasting_modules
         ), "Layerwise upcasting is not supported for full-finetune training"
 
 
-def _validate_training_args(args: Args):
-    if args.training_type == "lora":
-        assert args.rank is not None, "Rank is required for LoRA training"
-        assert args.lora_alpha is not None, "LoRA alpha is required for LoRA training"
-        assert (
-            args.target_modules is not None and len(args.target_modules) > 0
-        ), "Target modules are required for LoRA training"
-
-
-def _validate_validation_args(args: Args):
-    assert args.validation_prompts is not None, "Validation prompts are required for validation"
-    if args.validation_images is not None:
-        assert len(args.validation_images) == len(
-            args.validation_prompts
-        ), "Validation images and prompts should be of same length"
-    if args.validation_videos is not None:
-        assert len(args.validation_videos) == len(
-            args.validation_prompts
-        ), "Validation videos and prompts should be of same length"
-    assert len(args.validation_prompts) == len(
-        args.validation_heights
-    ), "Validation prompts and heights should be of same length"
-    assert len(args.validation_prompts) == len(
-        args.validation_widths
-    ), "Validation prompts and widths should be of same length"
+def _validate_dataset_args(args: BaseArgs):
+    dataset_config = pathlib.Path(args.dataset_config)
+    if not dataset_config.exists():
+        raise ValueError(f"Dataset config file {args.dataset_config} does not exist.")
+    if args.dataset_shuffle_buffer_size < 1:
+        raise ValueError("Dataset shuffle buffer size must be greater than 0.")
+    if args.precomputation_items < 1:
+        raise ValueError("Precomputation items must be greater than 0.")
+
+
+def _validate_validation_args(args: BaseArgs):
+    if args.dp_shards > 1 and args.enable_model_cpu_offload:
+        raise ValueError("Model CPU offload is not supported with FSDP at the moment.")
 
 
 def _display_helper_messages(args: argparse.Namespace):
diff --git a/finetrainers/config.py b/finetrainers/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0cda5e831d89f2bce9555bbea44ac87260fe67d
--- /dev/null
+++ b/finetrainers/config.py
@@ -0,0 +1,52 @@
+from enum import Enum
+from typing import Type
+
+from .models import ModelSpecification
+from .models.cogvideox import CogVideoXModelSpecification
+from .models.hunyuan_video import HunyuanVideoModelSpecification
+from .models.ltx_video import LTXVideoModelSpecification
+from .models.wan import WanModelSpecification
+
+
+class ModelType(str, Enum):
+    COGVIDEOX = "cogvideox"
+    HUNYUAN_VIDEO = "hunyuan_video"
+    LTX_VIDEO = "ltx_video"
+    WAN = "wan"
+
+
+class TrainingType(str, Enum):
+    LORA = "lora"
+    FULL_FINETUNE = "full-finetune"
+
+
+SUPPORTED_MODEL_CONFIGS = {
+    ModelType.HUNYUAN_VIDEO: {
+        TrainingType.LORA: HunyuanVideoModelSpecification,
+        TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
+    },
+    ModelType.LTX_VIDEO: {
+        TrainingType.LORA: LTXVideoModelSpecification,
+        TrainingType.FULL_FINETUNE: LTXVideoModelSpecification,
+    },
+    ModelType.COGVIDEOX: {
+        TrainingType.LORA: CogVideoXModelSpecification,
+        TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
+    },
+    ModelType.WAN: {
+        TrainingType.LORA: WanModelSpecification,
+        TrainingType.FULL_FINETUNE: WanModelSpecification,
+    },
+}
+
+
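+# Illustrative lookup using the mapping above (names taken from this file):
+#   _get_model_specifiction_cls(ModelType.LTX_VIDEO, TrainingType.LORA) -> LTXVideoModelSpecification
+# Unsupported model names or training types raise a ValueError listing the supported options.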
+def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]:
+    if model_name not in SUPPORTED_MODEL_CONFIGS:
+        raise ValueError(
+            f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}"
+        )
+    if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]:
+        raise ValueError(
+            f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}"
+        )
+    return SUPPORTED_MODEL_CONFIGS[model_name][training_type]
diff --git a/finetrainers/constants.py b/finetrainers/constants.py
index f6318f4cc43f96db9adf94b22ae92684f976f6fe..693495ca0e91617dce6c35583b2e1a18c9025708 100644
--- a/finetrainers/constants.py
+++ b/finetrainers/constants.py
@@ -78,3 +78,6 @@ COMMON_LLM_START_PHRASES = (
         for continuation in _COMMON_CONTINUATION_WORDS
     ),
 )
+
+SUPPORTED_IMAGE_FILE_EXTENSIONS = ("jpg", "jpeg", "png")
+SUPPORTED_VIDEO_FILE_EXTENSIONS = ("mp4", "mov")
diff --git a/finetrainers/data/__init__.py b/finetrainers/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7706fb999099fa324c2b5fe73f62ad57e89e0796
--- /dev/null
+++ b/finetrainers/data/__init__.py
@@ -0,0 +1,19 @@
+from ._artifact import ImageArtifact, VideoArtifact
+from .dataloader import DPDataLoader
+from .dataset import (
+    ImageCaptionFilePairDataset,
+    ImageFileCaptionFileListDataset,
+    ImageFolderDataset,
+    ImageWebDataset,
+    ValidationDataset,
+    VideoCaptionFilePairDataset,
+    VideoFileCaptionFileListDataset,
+    VideoFolderDataset,
+    VideoWebDataset,
+    combine_datasets,
+    initialize_dataset,
+    wrap_iterable_dataset_for_preprocessing,
+)
+from .precomputation import DistributedDataPreprocessor, PreprocessedDataIterable
+from .sampler import ResolutionSampler
+from .utils import find_files
diff --git a/finetrainers/data/_artifact.py b/finetrainers/data/_artifact.py
new file mode 100644
index 0000000000000000000000000000000000000000..400f25d143f5062d77ed6391ca9862654d295de7
--- /dev/null
+++ b/finetrainers/data/_artifact.py
@@ -0,0 +1,29 @@
+# ===== NOTE: This file exists only for the time being, until a better home is found for these artifact classes =====
+
+from dataclasses import dataclass
+from typing import Any, List
+
+from PIL.Image import Image
+
+
+@dataclass
+class Artifact:
+    type: str
+    value: Any
+    file_extension: str
+
+
+@dataclass
+class ImageArtifact(Artifact):
+    value: Image
+
+    def __init__(self, value: Image):
+        super().__init__(type="image", value=value, file_extension="png")
+
+
+@dataclass
+class VideoArtifact(Artifact):
+    value: List[Image]
+
+    def __init__(self, value: List[Image]):
+        super().__init__(type="video", value=value, file_extension="mp4")
diff --git a/finetrainers/data/dataloader.py b/finetrainers/data/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8b0a4b1f6253bd943cf9fe7b9e31c06aa060b35
--- /dev/null
+++ b/finetrainers/data/dataloader.py
@@ -0,0 +1,40 @@
+import pickle
+from typing import Any, Dict
+
+import torch.distributed.checkpoint.stateful
+import torchdata.stateful_dataloader
+
+from ..logging import get_logger
+
+
+logger = get_logger()
+
+
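+# Stateful dataloader used for data-parallel training. Its state participates in
+# torch.distributed.checkpoint save/load: each rank stores its pickled dataloader state under a
+# per-rank key so that resuming restores the correct per-rank position in the dataset.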
+class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful):
+    def __init__(
+        self,
+        rank: int,
+        dataset: torch.utils.data.IterableDataset,
+        batch_size: int = 1,
+        num_workers: int = 0,
+        collate_fn=None,
+    ) -> None:
+        super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)
+
+        self._dp_rank = rank
+        self._rank_id = f"dp_rank_{rank}"
+
+    def state_dict(self) -> Dict[str, Any]:
+        # Store state only for dp rank to avoid replicating the same state across other dimensions
+        return {self._rank_id: pickle.dumps(super().state_dict())}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # State being empty is valid
+        if not state_dict:
+            return
+
+        if self._rank_id not in state_dict:
+            logger.warning(f"DataLoader state does not contain the expected key {self._rank_id} for dp rank {self._dp_rank}")
+            return
+
+        super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
diff --git a/finetrainers/data/dataset.py b/finetrainers/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..672bc1131b95900abfa6bdfcd50c6b644dc45556
--- /dev/null
+++ b/finetrainers/data/dataset.py
@@ -0,0 +1,844 @@
+import pathlib
+import random
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import datasets
+import datasets.data_files
+import datasets.distributed
+import datasets.exceptions
+import huggingface_hub
+import huggingface_hub.errors
+import numpy as np
+import PIL.Image
+import torch
+import torch.distributed.checkpoint.stateful
+from diffusers.utils import load_image, load_video
+from huggingface_hub import list_repo_files, repo_exists, snapshot_download
+from tqdm.auto import tqdm
+
+from .. import constants
+from .. import functional as FF
+from ..logging import get_logger
+from . import utils
+
+
+import decord  # isort:skip
+
+decord.bridge.set_bridge("torch")
+
+logger = get_logger()
+
+
+MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
+COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
+COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
+COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
+
+
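+# Expects a flat directory of `<name>.txt` caption files, each paired with exactly one image that
+# shares the same stem and uses a supported extension (e.g. `0001.txt` + `0001.png`).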
+class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+    def __init__(self, root: str, infinite: bool = False) -> None:
+        super().__init__()
+
+        self.root = pathlib.Path(root)
+        self.infinite = infinite
+
+        data = []
+        caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
+        for caption_file in caption_files:
+            data_file = self._find_data_file(caption_file)
+            if data_file:
+                data.append(
+                    {
+                        "caption": (self.root / caption_file).as_posix(),
+                        "image": (self.root / data_file).as_posix(),
+                    }
+                )
+
+        data = datasets.Dataset.from_list(data)
+        data = data.cast_column("image", datasets.Image(mode="RGB"))
+
+        self._data = data.to_iterable_dataset()
+        self._sample_index = 0
+        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+    def _get_data_iter(self):
+        if self._sample_index == 0:
+            return iter(self._data)
+        return iter(self._data.skip(self._sample_index))
+
+    def __iter__(self):
+        while True:
+            for sample in self._get_data_iter():
+                self._sample_index += 1
+                sample["caption"] = _read_caption_from_file(sample["caption"])
+                sample["image"] = _preprocess_image(sample["image"])
+                yield sample
+
+            if not self.infinite:
+                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+                break
+            else:
+                self._sample_index = 0
+
+    def load_state_dict(self, state_dict):
+        self._sample_index = state_dict["sample_index"]
+
+    def state_dict(self):
+        return {"sample_index": self._sample_index}
+
+    def _find_data_file(self, caption_file: str) -> Optional[str]:
+        caption_file = pathlib.Path(caption_file)
+        data_file = None
+        found_data = 0
+
+        for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
+            image_filename = caption_file.with_suffix(f".{extension}")
+            if image_filename.exists():
+                found_data += 1
+                data_file = image_filename
+
+        if found_data == 0:
+            return None
+        elif found_data > 1:
+            raise ValueError(
+                f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
+                f"file per caption file. The following extensions are supported:\n"
+                f"  - Images: {constants.SUPPORTED_IMAGE_FILE_EXTENSIONS}\n"
+            )
+
+        return data_file.as_posix()
+
+
+class VideoCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+    def __init__(self, root: str, infinite: bool = False) -> None:
+        super().__init__()
+
+        self.root = pathlib.Path(root)
+        self.infinite = infinite
+
+        data = []
+        caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
+        for caption_file in caption_files:
+            data_file = self._find_data_file(caption_file)
+            if data_file:
+                data.append(
+                    {
+                        "caption": (self.root / caption_file).as_posix(),
+                        "video": (self.root / data_file).as_posix(),
+                    }
+                )
+
+        data = datasets.Dataset.from_list(data)
+        data = data.cast_column("video", datasets.Video())
+
+        self._data = data.to_iterable_dataset()
+        self._sample_index = 0
+        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+    def _get_data_iter(self):
+        if self._sample_index == 0:
+            return iter(self._data)
+        return iter(self._data.skip(self._sample_index))
+
+    def __iter__(self):
+        while True:
+            for sample in self._get_data_iter():
+                self._sample_index += 1
+                sample["caption"] = _read_caption_from_file(sample["caption"])
+                sample["video"] = _preprocess_video(sample["video"])
+                yield sample
+
+            if not self.infinite:
+                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+                break
+            else:
+                self._sample_index = 0
+
+    def load_state_dict(self, state_dict):
+        self._sample_index = state_dict["sample_index"]
+
+    def state_dict(self):
+        return {"sample_index": self._sample_index}
+
+    def _find_data_file(self, caption_file: str) -> Optional[str]:
+        caption_file = pathlib.Path(caption_file)
+        data_file = None
+        found_data = 0
+
+        for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
+            video_filename = caption_file.with_suffix(f".{extension}")
+            if video_filename.exists():
+                found_data += 1
+                data_file = video_filename
+
+        if found_data == 0:
+            return None
+        elif found_data > 1:
+            raise ValueError(
+                f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
+                f"file per caption file. The following extensions are supported:\n"
+                f"  - Videos: {constants.SUPPORTED_VIDEO_FILE_EXTENSIONS}\n"
+            )
+
+        return data_file.as_posix()
+
+
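+# Expects exactly one caption list file (e.g. `prompts.txt`) and one image list file (e.g.
+# `images.txt`) in the root directory; the two files are read line by line, must contain the same
+# number of entries, and image paths are interpreted relative to the root.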
+class ImageFileCaptionFileListDataset(
+    torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
+):
+    def __init__(self, root: str, infinite: bool = False) -> None:
+        super().__init__()
+
+        VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
+        VALID_IMAGE_FILES = ["image.txt", "images.txt"]
+
+        self.root = pathlib.Path(root)
+        self.infinite = infinite
+
+        data = []
+        existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
+        existing_image_files = [file for file in VALID_IMAGE_FILES if (self.root / file).exists()]
+
+        if len(existing_caption_files) == 0:
+            raise FileNotFoundError(
+                f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
+            )
+        if len(existing_image_files) == 0:
+            raise FileNotFoundError(
+                f"No image file found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
+            )
+        if len(existing_caption_files) > 1:
+            raise ValueError(
+                f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
+            )
+        if len(existing_image_files) > 1:
+            raise ValueError(
+                f"Multiple image files found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
+            )
+
+        caption_file = existing_caption_files[0]
+        image_file = existing_image_files[0]
+
+        with open((self.root / caption_file).as_posix(), "r") as f:
+            captions = f.read().splitlines()
+        with open((self.root / image_file).as_posix(), "r") as f:
+            images = f.read().splitlines()
+            images = [(self.root / image).as_posix() for image in images]
+
+        if len(captions) != len(images):
+            raise ValueError(f"Number of captions ({len(captions)}) must match number of images ({len(images)})")
+
+        for caption, image in zip(captions, images):
+            data.append({"caption": caption, "image": image})
+
+        data = datasets.Dataset.from_list(data)
+        data = data.cast_column("image", datasets.Image(mode="RGB"))
+
+        self._data = data.to_iterable_dataset()
+        self._sample_index = 0
+        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+    def _get_data_iter(self):
+        if self._sample_index == 0:
+            return iter(self._data)
+        return iter(self._data.skip(self._sample_index))
+
+    def __iter__(self):
+        while True:
+            for sample in self._get_data_iter():
+                self._sample_index += 1
+                sample["image"] = _preprocess_image(sample["image"])
+                yield sample
+
+            if not self.infinite:
+                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+                break
+            else:
+                self._sample_index = 0
+
+    def load_state_dict(self, state_dict):
+        self._sample_index = state_dict["sample_index"]
+
+    def state_dict(self):
+        return {"sample_index": self._sample_index}
+
+
+class VideoFileCaptionFileListDataset(
+    torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
+):
+    def __init__(self, root: str, infinite: bool = False) -> None:
+        super().__init__()
+
+        VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
+        VALID_VIDEO_FILES = ["video.txt", "videos.txt"]
+
+        self.root = pathlib.Path(root)
+        self.infinite = infinite
+
+        data = []
+        existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
+        existing_video_files = [file for file in VALID_VIDEO_FILES if (self.root / file).exists()]
+
+        if len(existing_caption_files) == 0:
+            raise FileNotFoundError(
+                f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
+            )
+        if len(existing_video_files) == 0:
+            raise FileNotFoundError(
+                f"No video file found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
+            )
+        if len(existing_caption_files) > 1:
+            raise ValueError(
+                f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
+            )
+        if len(existing_video_files) > 1:
+            raise ValueError(
+                f"Multiple video files found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
+            )
+
+        caption_file = existing_caption_files[0]
+        video_file = existing_video_files[0]
+
+        with open((self.root / caption_file).as_posix(), "r") as f:
+            captions = f.read().splitlines()
+        with open((self.root / video_file).as_posix(), "r") as f:
+            videos = f.read().splitlines()
+            videos = [(self.root / video).as_posix() for video in videos]
+
+        if len(captions) != len(videos):
+            raise ValueError(f"Number of captions ({len(captions)}) must match number of videos ({len(videos)})")
+
+        for caption, video in zip(captions, videos):
+            data.append({"caption": caption, "video": video})
+
+        data = datasets.Dataset.from_list(data)
+        data = data.cast_column("video", datasets.Video())
+
+        self._data = data.to_iterable_dataset()
+        self._sample_index = 0
+        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+    def _get_data_iter(self):
+        if self._sample_index == 0:
+            return iter(self._data)
+        return iter(self._data.skip(self._sample_index))
+
+    def __iter__(self):
+        while True:
+            for sample in self._get_data_iter():
+                self._sample_index += 1
+                sample["video"] = _preprocess_video(sample["video"])
+                yield sample
+
+            if not self.infinite:
+                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+                break
+            else:
+                self._sample_index = 0
+
+    def load_state_dict(self, state_dict):
+        self._sample_index = state_dict["sample_index"]
+
+    def state_dict(self):
+        return {"sample_index": self._sample_index}
+
+
+class ImageFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+    def __init__(self, root: str, infinite: bool = False) -> None:
+        super().__init__()
+
+        self.root = pathlib.Path(root)
+        self.infinite = infinite
+
+        data = datasets.load_dataset("imagefolder", data_dir=self.root.as_posix(), split="train")
+
+        self._data = data.to_iterable_dataset()
+        self._sample_index = 0
+        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+    def _get_data_iter(self):
+        if self._sample_index == 0:
+            return iter(self._data)
+        return iter(self._data.skip(self._sample_index))
+
+    def __iter__(self):
+        while True:
+            for sample in self._get_data_iter():
+                self._sample_index += 1
+                sample["image"] = _preprocess_image(sample["image"])
+                yield sample
+
+            if not self.infinite:
+                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+                break
+            else:
+                self._sample_index = 0
+
+    def load_state_dict(self, state_dict):
+        self._sample_index = state_dict["sample_index"]
+
+    def state_dict(self):
+        return {"sample_index": self._sample_index}
+
+
+class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+    def __init__(self, root: str, infinite: bool = False) -> None:
+        super().__init__()
+
+        self.root = pathlib.Path(root)
+        self.infinite = infinite
+
+        data = datasets.load_dataset("videofolder", data_dir=self.root.as_posix(), split="train")
+
+        self._data = data.to_iterable_dataset()
+        self._sample_index = 0
+        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
+
+    def _get_data_iter(self):
+        if self._sample_index == 0:
+            return iter(self._data)
+        return iter(self._data.skip(self._sample_index))
+
+    def __iter__(self):
+        while True:
+            for sample in self._get_data_iter():
+                self._sample_index += 1
+                sample["video"] = _preprocess_video(sample["video"])
+                yield sample
+
+            if not self.infinite:
+                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
+                break
+            else:
+                self._sample_index = 0
+
+    def load_state_dict(self, state_dict):
+        self._sample_index = state_dict["sample_index"]
+
+    def state_dict(self):
+        return {"sample_index": self._sample_index}
+
+
+class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+    def __init__(self, dataset_name: str, infinite: bool = False) -> None:
+        super().__init__()
+
+        self.dataset_name = dataset_name
+        self.infinite = infinite
+
+        data = datasets.load_dataset(dataset_name, split="train", streaming=True)
+        data = data.rename_column("txt", "caption")
+        for column_name in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
+            if column_name in data.column_names:
+                data = data.cast_column(column_name, datasets.Image(mode="RGB"))
+                data = data.rename_column(column_name, "image")
+
+        self._data = data
+        self._sample_index = 0
+        self._precomputable_once = False
+
+    def _get_data_iter(self):
+        if self._sample_index == 0:
+            return iter(self._data)
+        return iter(self._data.skip(self._sample_index))
+
+    def __iter__(self):
+        while True:
+            for sample in self._get_data_iter():
+                self._sample_index += 1
+                yield sample
+
+            if not self.infinite:
+                logger.warning(f"Dataset {self.dataset_name} has run out of data")
+                break
+            else:
+                # Reset offset for the next iteration
+                self._sample_index = 0
+                logger.warning(f"Dataset {self.dataset_name} is being re-looped")
+
+    def load_state_dict(self, state_dict):
+        self._sample_index = state_dict["sample_index"]
+
+    def state_dict(self):
+        return {"sample_index": self._sample_index}
+
+
+class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+    def __init__(self, dataset_name: str, infinite: bool = False) -> None:
+        super().__init__()
+
+        self.dataset_name = dataset_name
+        self.infinite = infinite
+
+        data = datasets.load_dataset(dataset_name, split="train", streaming=True)
+        data = data.rename_column("txt", "caption")
+        for column_name in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
+            if column_name in data.column_names:
+                data = data.cast_column(column_name, datasets.Video())
+                data = data.rename_column(column_name, "video")
+
+        self._data = data
+        self._sample_index = 0
+        self._precomputable_once = False
+
+    def _get_data_iter(self):
+        if self._sample_index == 0:
+            return iter(self._data)
+        return iter(self._data.skip(self._sample_index))
+
+    def __iter__(self):
+        while True:
+            for sample in self._get_data_iter():
+                self._sample_index += 1
+                yield sample
+
+            if not self.infinite:
+                logger.warning(f"Dataset {self.dataset_name} has run out of data")
+                break
+            else:
+                # Reset offset for the next iteration
+                self._sample_index = 0
+                logger.warning(f"Dataset {self.dataset_name} is being re-looped")
+
+    def load_state_dict(self, state_dict):
+        self._sample_index = state_dict["sample_index"]
+
+    def state_dict(self):
+        return {"sample_index": self._sample_index}
+
+
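+# Reads validation samples from a CSV/JSON/Parquet/Arrow file. Every sample must provide a
+# "caption" column; optional "image_path"/"video_path" columns are loaded into "image"/"video".
+# A minimal JSON file would look like (illustrative):
+#   {"data": [{"caption": "a cat dancing", "video_path": "validation/cat.mp4"}]}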
+class ValidationDataset(torch.utils.data.IterableDataset):
+    def __init__(self, filename: str):
+        super().__init__()
+
+        self.filename = pathlib.Path(filename)
+
+        if not self.filename.exists():
+            raise FileNotFoundError(f"File {self.filename.as_posix()} does not exist")
+
+        if self.filename.suffix == ".csv":
+            data = datasets.load_dataset("csv", data_files=self.filename.as_posix(), split="train")
+        elif self.filename.suffix == ".json":
+            data = datasets.load_dataset("json", data_files=self.filename.as_posix(), split="train", field="data")
+        elif self.filename.suffix == ".parquet":
+            data = datasets.load_dataset("parquet", data_files=self.filename.as_posix(), split="train")
+        elif self.filename.suffix == ".arrow":
+            data = datasets.load_dataset("arrow", data_files=self.filename.as_posix(), split="train")
+        else:
+            _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"]
+            raise ValueError(
+                f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}"
+            )
+
+        self._data = data.to_iterable_dataset()
+
+    def __iter__(self):
+        for sample in self._data:
+            # For consistency reasons, we mandate that "caption" is always present in the validation dataset.
+            # However, since the model specifications use "prompt", we create an alias here.
+            sample["prompt"] = sample["caption"]
+
+            # Load image or video if the path is provided
+            # TODO(aryan): need to handle custom columns here for control conditions
+            sample["image"] = None
+            sample["video"] = None
+
+            if sample.get("image_path", None) is not None:
+                image_path = pathlib.Path(sample["image_path"])
+                if not image_path.is_file():
+                    logger.warning(f"Image file {image_path.as_posix()} does not exist.")
+                else:
+                    sample["image"] = load_image(sample["image_path"])
+
+            if sample.get("video_path", None) is not None:
+                video_path = pathlib.Path(sample["video_path"])
+                if not video_path.is_file():
+                    logger.warning(f"Video file {video_path.as_posix()} does not exist.")
+                else:
+                    sample["video"] = load_video(sample["video_path"])
+
+            sample = {k: v for k, v in sample.items() if v is not None}
+            yield sample
+
+
+class IterableDatasetPreprocessingWrapper(
+    torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
+):
+    def __init__(
+        self,
+        dataset: torch.utils.data.IterableDataset,
+        dataset_type: str,
+        id_token: Optional[str] = None,
+        image_resolution_buckets: List[Tuple[int, int]] = None,
+        video_resolution_buckets: List[Tuple[int, int, int]] = None,
+        reshape_mode: str = "bicubic",
+        remove_common_llm_caption_prefixes: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.dataset = dataset
+        self.dataset_type = dataset_type
+        self.id_token = id_token
+        self.image_resolution_buckets = image_resolution_buckets
+        self.video_resolution_buckets = video_resolution_buckets
+        self.reshape_mode = reshape_mode
+        self.remove_common_llm_caption_prefixes = remove_common_llm_caption_prefixes
+
+        logger.info(
+            f"Initializing IterableDatasetPreprocessingWrapper for the dataset with the following configuration:\n"
+            f"  - Dataset Type: {dataset_type}\n"
+            f"  - ID Token: {id_token}\n"
+            f"  - Image Resolution Buckets: {image_resolution_buckets}\n"
+            f"  - Video Resolution Buckets: {video_resolution_buckets}\n"
+            f"  - Reshape Mode: {reshape_mode}\n"
+            f"  - Remove Common LLM Caption Prefixes: {remove_common_llm_caption_prefixes}\n"
+        )
+
+    def __iter__(self):
+        logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset")
+        for sample in iter(self.dataset):
+            if self.dataset_type == "image":
+                if self.image_resolution_buckets:
+                    sample["image"] = FF.resize_to_nearest_bucket_image(
+                        sample["image"], self.image_resolution_buckets, self.reshape_mode
+                    )
+            elif self.dataset_type == "video":
+                if self.video_resolution_buckets:
+                    sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
+                        sample["video"], self.video_resolution_buckets, self.reshape_mode
+                    )
+                    if _first_frame_only:
+                        msg = (
+                            "The number of frames in the video is less than the minimum bucket size "
+                            "specified. The first frame is being used as a single-frame video. This "
+                            "message is logged at the first occurrence and for every 128th occurrence "
+                            "after that."
+                        )
+                        logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE", msg, frequency=128)
+                        sample["video"] = sample["video"][0]
+
+            if self.remove_common_llm_caption_prefixes:
+                sample["caption"] = FF.remove_prefix(sample["caption"], constants.COMMON_LLM_START_PHRASES)
+
+            if self.id_token is not None:
+                sample["caption"] = f"{self.id_token} {sample['caption']}"
+
+            yield sample
+
+    def load_state_dict(self, state_dict):
+        self.dataset.load_state_dict(state_dict["dataset"])
+
+    def state_dict(self):
+        return {"dataset": self.dataset.state_dict()}
+
+
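+# Merges several iterable datasets through a sampling buffer: the buffer is first filled with
+# roughly `buffer_size / len(datasets)` samples per dataset, then samples are popped (randomly if
+# `shuffle=True`, otherwise in order) and each popped slot is refilled from the iterator the
+# popped sample came from.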
+class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
+    def __init__(self, datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False):
+        super().__init__()
+
+        self.datasets = datasets
+        self.buffer_size = buffer_size
+        self.shuffle = shuffle
+
+        logger.info(
+            f"Initializing IterableCombinedDataset with the following configuration:\n"
+            f"  - Number of Datasets: {len(datasets)}\n"
+            f"  - Buffer Size: {buffer_size}\n"
+            f"  - Shuffle: {shuffle}\n"
+        )
+
+    def __iter__(self):
+        logger.info(f"Starting IterableCombinedDataset with {len(self.datasets)} datasets")
+        iterators = [iter(dataset) for dataset in self.datasets]
+        buffer = []
+        per_iter = max(1, self.buffer_size // len(iterators))
+
+        for index, it in enumerate(iterators):
+            for _ in tqdm(range(per_iter), desc=f"Filling buffer from data iterator {index}"):
+                try:
+                    buffer.append((it, next(it)))
+                except StopIteration:
+                    continue
+
+        while len(buffer) > 0:
+            idx = 0
+            if self.shuffle:
+                idx = random.randint(0, len(buffer) - 1)
+            current_it, sample = buffer.pop(idx)
+            yield sample
+            try:
+                buffer.append((current_it, next(current_it)))
+            except StopIteration:
+                pass
+
+    def load_state_dict(self, state_dict):
+        for dataset, dataset_state_dict in zip(self.datasets, state_dict["datasets"]):
+            dataset.load_state_dict(dataset_state_dict)
+
+    def state_dict(self):
+        return {"datasets": [dataset.state_dict() for dataset in self.datasets]}
+
+
+# TODO(aryan): maybe write a test for this
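+# Resolution order: if `dataset_name_or_root` exists as a dataset repo on the HF Hub it is loaded
+# from the Hub (caption/data file pairs, caption/data list files, or a webdataset stream);
+# otherwise it is treated as a local directory and matched against the supported local layouts.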
+def initialize_dataset(
+    dataset_name_or_root: str, dataset_type: str = "video", streaming: bool = True, infinite: bool = False
+) -> torch.utils.data.IterableDataset:
+    assert dataset_type in ["image", "video"]
+
+    try:
+        does_repo_exist_on_hub = repo_exists(dataset_name_or_root, repo_type="dataset")
+    except huggingface_hub.errors.HFValidationError:
+        does_repo_exist_on_hub = False
+
+    if does_repo_exist_on_hub:
+        return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite)
+    else:
+        return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)
+
+
+def combine_datasets(
+    datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False
+) -> torch.utils.data.IterableDataset:
+    return IterableCombinedDataset(datasets=datasets, buffer_size=buffer_size, shuffle=shuffle)
+
+
+def wrap_iterable_dataset_for_preprocessing(
+    dataset: torch.utils.data.IterableDataset, dataset_type: str, config: Dict[str, Any]
+) -> torch.utils.data.IterableDataset:
+    return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config)
+
+
+def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infinite: bool = False):
+    root = pathlib.Path(dataset_name_or_root)
+    supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"]
+    metadata_files = [root / metadata_file for metadata_file in supported_metadata_files]
+    metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()]
+
+    if len(metadata_files) > 1:
+        raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.")
+
+    if len(metadata_files) == 1:
+        if dataset_type == "image":
+            dataset = ImageFolderDataset(root.as_posix(), infinite=infinite)
+        else:
+            dataset = VideoFolderDataset(root.as_posix(), infinite=infinite)
+        return dataset
+
+    if _has_data_caption_file_pairs(root, remote=False):
+        if dataset_type == "image":
+            dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite)
+        else:
+            dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite)
+    elif _has_data_file_caption_file_lists(root, remote=False):
+        if dataset_type == "image":
+            dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
+        else:
+            dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
+    else:
+        raise ValueError(
+            f"Could not find any supported dataset structure in the directory {root}. Please open an issue at "
+            f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will "
+            f"help you set it up."
+        )
+
+    return dataset
+
+
+def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool = False):
+    repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
+    if _has_data_caption_file_pairs(repo_file_list, remote=True):
+        return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
+    elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
+        return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
+    else:
+        return _initialize_webdataset(dataset_name, dataset_type, infinite)
+
+
+def _initialize_data_caption_file_dataset_from_hub(
+    dataset_name: str, dataset_type: str, infinite: bool = False
+) -> torch.utils.data.IterableDataset:
+    logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
+    dataset_root = snapshot_download(dataset_name, repo_type="dataset")
+    if dataset_type == "image":
+        return ImageCaptionFilePairDataset(dataset_root, infinite=infinite)
+    else:
+        return VideoCaptionFilePairDataset(dataset_root, infinite=infinite)
+
+
+def _initialize_data_file_caption_file_dataset_from_hub(
+    dataset_name: str, dataset_type: str, infinite: bool = False
+) -> torch.utils.data.IterableDataset:
+    logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
+    dataset_root = snapshot_download(dataset_name, repo_type="dataset")
+    if dataset_type == "image":
+        return ImageFileCaptionFileListDataset(dataset_root, infinite=infinite)
+    else:
+        return VideoFileCaptionFileListDataset(dataset_root, infinite=infinite)
+
+
+def _initialize_webdataset(
+    dataset_name: str, dataset_type: str, infinite: bool = False
+) -> torch.utils.data.IterableDataset:
+    logger.info(f"Streaming webdataset {dataset_name} from the HF Hub")
+    if dataset_type == "image":
+        return ImageWebDataset(dataset_name, infinite=infinite)
+    else:
+        return VideoWebDataset(dataset_name, infinite=infinite)
+
+
+def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
+    # TODO(aryan): this logic can be improved
+    if not remote:
+        caption_files = utils.find_files(root.as_posix(), "*.txt", depth=0)
+        for caption_file in caption_files:
+            caption_file = pathlib.Path(caption_file)
+            for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
+                data_filename = caption_file.with_suffix(f".{extension}")
+                if data_filename.exists():
+                    return True
+        return False
+    else:
+        caption_files = [file for file in root if file.endswith(".txt")]
+        for caption_file in caption_files:
+            caption_file = pathlib.Path(caption_file)
+            for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
+                data_filename = caption_file.with_suffix(f".{extension}").name
+                if data_filename in root:
+                    return True
+        return False
+
+
+def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
+    # TODO(aryan): this logic can be improved
+    if not remote:
+        file_list = {x.name for x in root.iterdir()}
+        has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES)
+        has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES)
+        has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES)
+        return has_caption_files and (has_video_files or has_image_files)
+    else:
+        has_caption_files = any(file in root for file in COMMON_CAPTION_FILES)
+        has_video_files = any(file in root for file in COMMON_VIDEO_FILES)
+        has_image_files = any(file in root for file in COMMON_IMAGE_FILES)
+        return has_caption_files and (has_video_files or has_image_files)
+
+
+def _read_caption_from_file(filename: str) -> str:
+    with open(filename, "r") as f:
+        return f.read().strip()
+
+
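+# Both helpers below return float tensors normalized to [-1, 1]: images as (C, H, W) and videos
+# as (F, C, H, W), with all frames of a video decoded at once.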
+def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
+    image = image.convert("RGB")
+    image = np.array(image).astype(np.float32)
+    image = torch.from_numpy(image)
+    image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
+    return image
+
+
+def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
+    video = video.get_batch(list(range(len(video))))
+    video = video.permute(0, 3, 1, 2).contiguous()
+    video = video.float() / 127.5 - 1.0
+    return video
diff --git a/finetrainers/data/precomputation.py b/finetrainers/data/precomputation.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b1f020d9d3e715dcaef769604ab819f1ccfc5a4
--- /dev/null
+++ b/finetrainers/data/precomputation.py
@@ -0,0 +1,163 @@
+import pathlib
+from typing import Any, Callable, Dict, Iterable, Optional, Union
+
+import torch
+from tqdm.auto import tqdm
+
+from .. import utils
+
+
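+# Runs the per-data-type processor functions over `num_items` samples on each rank and caches the
+# results to disk as `{data_type}-{rank}-{index}.pt` files under `save_dir`. `consume` returns an
+# iterator that is exhausted after one pass, while `consume_once` returns an iterator that cycles
+# over the precomputed items indefinitely.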
+class DistributedDataPreprocessor:
+    def __init__(
+        self,
+        rank: int,
+        num_items: int,
+        processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
+        save_dir: str,
+    ) -> None:
+        self._rank = rank
+        self._num_items = num_items
+        self._processor_fn = processor_fn
+        self._save_dir = pathlib.Path(save_dir)
+
+        self._cached_samples = []
+        self._preprocessed_iterator: Optional[Union["PreprocessedDataIterable", "PreprocessedOnceDataIterable"]] = None
+
+        self._save_dir.mkdir(parents=True, exist_ok=True)
+
+        subdirectories = [f for f in self._save_dir.iterdir() if f.is_dir()]
+        utils.delete_files(subdirectories)
+
+    def consume(
+        self,
+        data_type: str,
+        components: Dict[str, Any],
+        data_iterator,
+        generator: Optional[torch.Generator] = None,
+        cache_samples: bool = False,
+        use_cached_samples: bool = False,
+        drop_samples: bool = False,
+    ) -> Iterable[Dict[str, Any]]:
+        if data_type not in self._processor_fn.keys():
+            raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
+        if cache_samples:
+            if use_cached_samples:
+                raise ValueError("Cannot cache and use cached samples at the same time.")
+            if drop_samples:
+                raise ValueError("Cannot cache and drop samples at the same time.")
+
+        for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items):
+            if use_cached_samples:
+                item = self._cached_samples[i]
+            else:
+                item = next(data_iterator)
+                if cache_samples:
+                    self._cached_samples.append(item)
+            item = self._processor_fn[data_type](**item, **components, generator=generator)
+            _save_item(self._rank, i, item, self._save_dir, data_type)
+
+        if drop_samples:
+            del self._cached_samples
+            self._cached_samples = []
+            utils.free_memory()
+
+        self._preprocessed_iterator = PreprocessedDataIterable(self._rank, self._save_dir, data_type)
+        return iter(self._preprocessed_iterator)
+
+    def consume_once(
+        self,
+        data_type: str,
+        components: Dict[str, Any],
+        data_iterator,
+        generator: Optional[torch.Generator] = None,
+        cache_samples: bool = False,
+        use_cached_samples: bool = False,
+        drop_samples: bool = False,
+    ) -> Iterable[Dict[str, Any]]:
+        if data_type not in self._processor_fn.keys():
+            raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
+        if cache_samples:
+            if use_cached_samples:
+                raise ValueError("Cannot cache and use cached samples at the same time.")
+            if drop_samples:
+                raise ValueError("Cannot cache and drop samples at the same time.")
+
+        for i in tqdm(range(self._num_items), desc=f"Processing data on rank {self._rank}", total=self._num_items):
+            if use_cached_samples:
+                item = self._cached_samples[i]
+            else:
+                item = next(data_iterator)
+                if cache_samples:
+                    self._cached_samples.append(item)
+            item = self._processor_fn[data_type](**item, **components, generator=generator)
+            _save_item(self._rank, i, item, self._save_dir, data_type)
+
+        if drop_samples:
+            del self._cached_samples
+            self._cached_samples = []
+            utils.free_memory()
+
+        self._preprocessed_iterator = PreprocessedOnceDataIterable(self._rank, self._save_dir, data_type)
+        return iter(self._preprocessed_iterator)
+
+    @property
+    def requires_data(self):
+        if self._preprocessed_iterator is None:
+            return True
+        return self._preprocessed_iterator.requires_data
+
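+# Illustrative usage sketch (hypothetical component and function names, not part of this patch): one
+# preprocessor is created per rank with a processor function per data type; `consume()` writes each
+# processed item to `save_dir` and returns an iterator that streams the saved `.pt` files back.
+#
+#     preprocessor = DistributedDataPreprocessor(
+#         rank=0, num_items=64, processor_fn={"latent": encode_fn, "condition": embed_fn}, save_dir="/tmp/precomputed"
+#     )
+#     latent_iter = preprocessor.consume(
+#         "latent", components={"vae": vae}, data_iterator=iter(dataloader), cache_samples=True
+#     )
+#     condition_iter = preprocessor.consume(
+#         "condition", components={"text_encoder": text_encoder}, data_iterator=None, use_cached_samples=True, drop_samples=True
+#     )
+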
+
+class PreprocessedDataIterable:
+    def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
+        self._rank = rank
+        self._save_dir = pathlib.Path(save_dir)
+        self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
+        self._data_type = data_type
+
+        self._requires_data = False
+
+    def __iter__(self) -> Iterable[Dict[str, Any]]:
+        for i in range(self._num_items):
+            if i == self._num_items - 1:
+                self._requires_data = True
+            yield _load_item(self._rank, i, self._save_dir, self._data_type)
+
+    def __len__(self) -> int:
+        return self._num_items
+
+    @property
+    def requires_data(self):
+        return self._requires_data
+
+
+class PreprocessedOnceDataIterable:
+    def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
+        self._rank = rank
+        self._save_dir = pathlib.Path(save_dir)
+        self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
+        self._data_type = data_type
+
+        self._requires_data = False
+
+    def __iter__(self) -> Iterable[Dict[str, Any]]:
+        index = 0
+        while True:
+            yield _load_item(self._rank, index, self._save_dir, self._data_type)
+            index = (index + 1) % self._num_items
+
+    def __len__(self) -> int:
+        return self._num_items
+
+    @property
+    def requires_data(self):
+        return self._requires_data
+
+
+def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None:
+    filename = directory / f"{data_type}-{rank}-{index}.pt"
+    torch.save(item, filename.as_posix())
+
+
+def _load_item(rank: int, index: int, directory: pathlib.Path, data_type: str) -> Dict[str, Any]:
+    filename = directory / f"{data_type}-{rank}-{index}.pt"
+    return torch.load(filename.as_posix(), weights_only=True)
diff --git a/finetrainers/data/sampler.py b/finetrainers/data/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d9d650e1d610e8ce91b4168a9960479cfcfe8f7
--- /dev/null
+++ b/finetrainers/data/sampler.py
@@ -0,0 +1,58 @@
+from typing import Any, Dict, List, Tuple
+
+import torch
+
+
+class ResolutionSampler:
+    def __init__(self, batch_size: int = 1, dim_keys: Dict[str, Tuple[int, ...]] = None) -> None:
+        self.batch_size = batch_size
+        self.dim_keys = dim_keys
+        assert dim_keys is not None, "dim_keys must be provided"
+
+        self._chosen_leader_key = None
+        self._unsatisfied_buckets: Dict[Tuple[int, ...], List[Dict[Any, Any]]] = {}
+        self._satisfied_buckets: List[Dict[Any, Any]] = []
+
+    def consume(self, *dict_items: Dict[Any, Any]) -> None:
+        if self._chosen_leader_key is None:
+            self._determine_leader_item(*dict_items)
+        self._update_buckets(*dict_items)
+
+    def get_batch(self) -> List[Dict[str, Any]]:
+        return list(zip(*self._satisfied_buckets.pop(-1)))
+
+    @property
+    def is_ready(self) -> bool:
+        return len(self._satisfied_buckets) > 0
+
+    def _determine_leader_item(self, *dict_items: Dict[Any, Any]) -> None:
+        num_observed = 0
+        for dict_item in dict_items:
+            for key in self.dim_keys.keys():
+                if key in dict_item.keys():
+                    self._chosen_leader_key = key
+                    if not torch.is_tensor(dict_item[key]):
+                        raise ValueError(f"Leader key {key} must be a tensor")
+                    num_observed += 1
+        if num_observed > 1:
+            raise ValueError(
+                f"Only one leader key is allowed in provided list of data dictionaries. Found {num_observed} leader keys"
+            )
+        if self._chosen_leader_key is None:
+            raise ValueError("No leader key found in provided list of data dictionaries")
+
+    def _update_buckets(self, *dict_items: Dict[Any, Any]) -> None:
+        chosen_value = [
+            dict_item[self._chosen_leader_key]
+            for dict_item in dict_items
+            if self._chosen_leader_key in dict_item.keys()
+        ]
+        if len(chosen_value) == 0:
+            raise ValueError(f"Leader key {self._chosen_leader_key} not found in provided list of data dictionaries")
+        chosen_value = chosen_value[0]
+        dims = tuple(chosen_value.size(x) for x in self.dim_keys[self._chosen_leader_key])
+        if dims not in self._unsatisfied_buckets:
+            self._unsatisfied_buckets[dims] = []
+        self._unsatisfied_buckets[dims].append(dict_items)
+        if len(self._unsatisfied_buckets[dims]) == self.batch_size:
+            self._satisfied_buckets.append(self._unsatisfied_buckets.pop(dims))
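+
+
+# Illustrative usage sketch (hypothetical tensors, not part of this patch): samples are consumed one at a
+# time and bucketed by the sizes of the leader key's chosen dimensions; once a bucket reaches `batch_size`
+# entries, `get_batch()` returns them grouped position-wise (one tuple of dicts per consumed argument).
+#
+#     sampler = ResolutionSampler(batch_size=2, dim_keys={"latents": (2, 3)})  # bucket on (height, width)
+#     sampler.consume({"latents": torch.randn(8, 16, 32, 32)}, {"caption": "a cat"})
+#     sampler.consume({"latents": torch.randn(8, 16, 32, 32)}, {"caption": "a dog"})
+#     if sampler.is_ready:
+#         latent_dicts, caption_dicts = sampler.get_batch()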
diff --git a/finetrainers/data/utils.py b/finetrainers/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bd507f348efc0e123532e8082502c7cfad956ed
--- /dev/null
+++ b/finetrainers/data/utils.py
@@ -0,0 +1,20 @@
+import pathlib
+from typing import List
+
+
+def find_files(root: str, pattern: str, depth: int = 0) -> List[str]:
+    root_path = pathlib.Path(root)
+    result_files = []
+
+    def within_depth(path: pathlib.Path) -> bool:
+        return len(path.relative_to(root_path).parts) <= depth
+
+    if depth == 0:
+        result_files.extend([str(file) for file in root_path.glob(pattern)])
+    else:
+        # rglob matches all levels, but we filter by depth
+        for file in root_path.rglob(pattern):
+            if file.is_file() and within_depth(file.parent):
+                result_files.append(str(file))
+
+    return result_files
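+
+
+# Illustrative usage sketch (hypothetical paths, not part of this patch): with `depth=0` only the root
+# directory itself is globbed; larger values of `depth` also search up to that many levels of subdirectories.
+#
+#     video_files = find_files("/data/videos", "*.mp4", depth=0)   # only /data/videos/*.mp4
+#     nested_files = find_files("/data/videos", "*.mp4", depth=2)  # also .../subdir/*.mp4 and .../subdir/subsubdir/*.mp4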
diff --git a/finetrainers/dataset.py b/finetrainers/dataset.py
deleted file mode 100644
index b40a9bffe0b05b50e37ee3f4cb773a6a4c10ede9..0000000000000000000000000000000000000000
--- a/finetrainers/dataset.py
+++ /dev/null
@@ -1,564 +0,0 @@
-import json
-import os
-import random
-from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
-
-import numpy as np
-import pandas as pd
-import torch
-import torchvision.transforms as TT
-import torchvision.transforms.functional as TTF
-from accelerate.logging import get_logger
-from torch.utils.data import Dataset, Sampler
-from torchvision import transforms
-from torchvision.transforms import InterpolationMode
-from torchvision.transforms.functional import resize
-
-import gc
-import time
-import resource
-
-# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
-# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
-import decord  # isort:skip
-
-decord.bridge.set_bridge("torch")
-
-from .constants import (  # noqa
-    COMMON_LLM_START_PHRASES,
-    PRECOMPUTED_CONDITIONS_DIR_NAME,
-    PRECOMPUTED_DIR_NAME,
-    PRECOMPUTED_LATENTS_DIR_NAME,
-)
-
-# Decord is causing us some issues!
-# Let's try to increase file descriptor limits to avoid this error:
-#
-#     decord._ffi.base.DECORDError: Resource temporarily unavailable
-try:
-    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-    print(f"Current file descriptor limits: soft={soft}, hard={hard}")
-    
-    # Try to increase to hard limit if possible
-    if soft < hard:
-        resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
-        new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-        print(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
-except Exception as e:
-    print(f"Could not check or update file descriptor limits: {e}")
-
-logger = get_logger(__name__)
-
-# TODO(aryan): This needs a refactor with separation of concerns.
-# Images should be handled separately. Videos should be handled separately.
-# Loading should be handled separately.
-# Preprocessing (aspect ratio, resizing) should be handled separately.
-# URL loading should be handled.
-# Parquet format should be handled.
-# Loading from ZIP should be handled.
-class ImageOrVideoDataset(Dataset):
-    def __init__(
-        self,
-        data_root: str,
-        caption_column: str,
-        video_column: str,
-        resolution_buckets: List[Tuple[int, int, int]],
-        dataset_file: Optional[str] = None,
-        id_token: Optional[str] = None,
-        remove_llm_prefixes: bool = False,
-    ) -> None:
-        super().__init__()
-
-        self.data_root = Path(data_root)
-        self.dataset_file = dataset_file
-        self.caption_column = caption_column
-        self.video_column = video_column
-        self.id_token = f"{id_token.strip()} " if id_token else ""
-        self.resolution_buckets = resolution_buckets
-
-        # Four methods of loading data are supported.
-        #   - Using a CSV: caption_column and video_column must be some column in the CSV. One could
-        #     make use of other columns too, such as a motion score or aesthetic score, by modifying the
-        #     logic in CSV processing.
-        #   - Using two files containing line-separate captions and relative paths to videos.
-        #   - Using a JSON file containing a list of dictionaries, where each dictionary has a `caption_column` and `video_column` key.
-        #   - Using a JSONL file containing a list of line-separated dictionaries, where each dictionary has a `caption_column` and `video_column` key.
-        # For a more detailed explanation about preparing dataset format, checkout the README.
-        if dataset_file is None:
-            (
-                self.prompts,
-                self.video_paths,
-            ) = self._load_dataset_from_local_path()
-        elif dataset_file.endswith(".csv"):
-            (
-                self.prompts,
-                self.video_paths,
-            ) = self._load_dataset_from_csv()
-        elif dataset_file.endswith(".json"):
-            (
-                self.prompts,
-                self.video_paths,
-            ) = self._load_dataset_from_json()
-        elif dataset_file.endswith(".jsonl"):
-            (
-                self.prompts,
-                self.video_paths,
-            ) = self._load_dataset_from_jsonl()
-        else:
-            raise ValueError(
-                "Expected `--dataset_file` to be a path to a CSV file or a directory containing line-separated text prompts and video paths."
-            )
-
-        if len(self.video_paths) != len(self.prompts):
-            raise ValueError(
-                f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
-            )
-
-        # Clean LLM start phrases
-        if remove_llm_prefixes:
-            for i in range(len(self.prompts)):
-                self.prompts[i] = self.prompts[i].strip()
-                for phrase in COMMON_LLM_START_PHRASES:
-                    if self.prompts[i].startswith(phrase):
-                        self.prompts[i] = self.prompts[i].removeprefix(phrase).strip()
-
-        self.video_transforms = transforms.Compose(
-            [
-                transforms.Lambda(self.scale_transform),
-                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
-            ]
-        )
-
-    @staticmethod
-    def scale_transform(x):
-        return x / 255.0
-
-    def __len__(self) -> int:
-        return len(self.video_paths)
-
-    def __getitem__(self, index: int) -> Dict[str, Any]:
-        if isinstance(index, list):
-            # Here, index is actually a list of data objects that we need to return.
-            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
-            # to have information about num_frames, height and width. Since this is not stored
-            # as metadata, we need to read the video to get this information. You could read this
-            # information without loading the full video in memory, but we do it anyway. In order
-            # to not load the video twice (once to get the metadata, and once to return the loaded video
-            # based on sampled indices), we cache it in the BucketSampler. When the sampler is
-            # to yield, we yield the cache data instead of indices. So, this special check ensures
-            # that data is not loaded a second time. PRs are welcome for improvements.
-            return index
-
-        prompt = self.id_token + self.prompts[index]
-
-        video_path: Path = self.video_paths[index]
-        if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]:
-            video = self._preprocess_image(video_path)
-        else:
-            video = self._preprocess_video(video_path)
-
-        return {
-            "prompt": prompt,
-            "video": video,
-            "video_metadata": {
-                "num_frames": video.shape[0],
-                "height": video.shape[2],
-                "width": video.shape[3],
-            },
-        }
-
-    def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
-        if not self.data_root.exists():
-            raise ValueError("Root folder for videos does not exist")
-
-        prompt_path = self.data_root.joinpath(self.caption_column)
-        video_path = self.data_root.joinpath(self.video_column)
-
-        if not prompt_path.exists() or not prompt_path.is_file():
-            raise ValueError(
-                "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
-            )
-        if not video_path.exists() or not video_path.is_file():
-            raise ValueError(
-                "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
-            )
-
-        with open(prompt_path, "r", encoding="utf-8") as file:
-            prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
-        with open(video_path, "r", encoding="utf-8") as file:
-            video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
-
-        if any(not path.is_file() for path in video_paths):
-            raise ValueError(
-                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
-            )
-
-        return prompts, video_paths
-
-    def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
-        df = pd.read_csv(self.dataset_file)
-        prompts = df[self.caption_column].tolist()
-        video_paths = df[self.video_column].tolist()
-        video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
-
-        if any(not path.is_file() for path in video_paths):
-            raise ValueError(
-                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
-            )
-
-        return prompts, video_paths
-
-    def _load_dataset_from_json(self) -> Tuple[List[str], List[str]]:
-        with open(self.dataset_file, "r", encoding="utf-8") as file:
-            data = json.load(file)
-
-        prompts = [entry[self.caption_column] for entry in data]
-        video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data]
-
-        if any(not path.is_file() for path in video_paths):
-            raise ValueError(
-                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
-            )
-
-        return prompts, video_paths
-
-    def _load_dataset_from_jsonl(self) -> Tuple[List[str], List[str]]:
-        with open(self.dataset_file, "r", encoding="utf-8") as file:
-            data = [json.loads(line) for line in file]
-
-        prompts = [entry[self.caption_column] for entry in data]
-        video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data]
-
-        if any(not path.is_file() for path in video_paths):
-            raise ValueError(
-                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
-            )
-
-        return prompts, video_paths
-
-    def _preprocess_image(self, path: Path) -> torch.Tensor:
-        # TODO(aryan): Support alpha channel in future by whitening background
-        image = TTF.Image.open(path.as_posix()).convert("RGB")
-        image = TTF.to_tensor(image)
-        image = image * 2.0 - 1.0
-        image = image.unsqueeze(0).contiguous()  # [C, H, W] -> [1, C, H, W] (1-frame video)
-        return image
-
-    def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """
-        Loads a single video, or latent and prompt embedding, based on initialization parameters.
-        Returns a [F, C, H, W] video tensor.
-        """
-        max_retries = 3
-        retry_delay = 1.0  # seconds
-        
-        for attempt in range(max_retries):
-            try:
-                # Create video reader
-                video_reader = decord.VideoReader(uri=path.as_posix())
-                video_num_frames = len(video_reader)
-
-                # Process frames
-                indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames))
-                frames = video_reader.get_batch(indices)
-                frames = frames[: self.max_num_frames].float()
-                frames = frames.permute(0, 3, 1, 2).contiguous()
-                frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
-                
-                # Explicitly clean up resources
-                del video_reader
-                
-                # Force garbage collection occasionally
-                if random.random() < 0.05:  # 5% chance
-                    gc.collect()
-                    
-                return frames
-                
-            except decord._ffi.base.DECORDError as e:
-                # Log the error
-                error_msg = str(e)
-                if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
-                    logger.warning(f"Retry {attempt+1}/{max_retries} loading video {path}: {error_msg}")
-                    
-                    # Clean up and wait before retrying
-                    gc.collect()
-                    time.sleep(retry_delay * (attempt + 1))  # Increasing backoff
-                else:
-                    # Either not a resource error or we've run out of retries
-                    logger.error(f"Failed to load video {path} after {attempt+1} attempts: {error_msg}")
-                    raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}")
-
-
-class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset):
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-        self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]
-
-    def _preprocess_image(self, path: Path) -> torch.Tensor:
-        # TODO(aryan): Support alpha channel in future by whitening background
-        image = TTF.Image.open(path.as_posix()).convert("RGB")
-        image = TTF.to_tensor(image)
-
-        nearest_res = self._find_nearest_resolution(image.shape[1], image.shape[2])
-        image = resize(image, nearest_res)
-
-        image = image * 2.0 - 1.0
-        image = image.unsqueeze(0).contiguous()
-        return image
-
-    def _preprocess_video(self, path: Path) -> torch.Tensor:
-        max_retries = 3
-        retry_delay = 1.0  # seconds
-        
-        for attempt in range(max_retries):
-            try:
-                # Create video reader
-                video_reader = decord.VideoReader(uri=path.as_posix())
-                video_num_frames = len(video_reader)
-                
-                # Find appropriate bucket for the video
-                video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
-                
-                if not video_buckets:
-                    _, h, w = self.resolution_buckets[0]
-                    video_buckets = [(1, h, w)]
-
-                nearest_frame_bucket = min(
-                    video_buckets,
-                    key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
-                    default=video_buckets[0],
-                )[0]
-
-                # Extract and process frames
-                frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
-                frames = video_reader.get_batch(frame_indices)
-                frames = frames[:nearest_frame_bucket].float()
-                frames = frames.permute(0, 3, 1, 2).contiguous()
-
-                nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
-                frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
-                frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
-                
-                # Explicitly clean up resources
-                del video_reader
-                
-                # Force garbage collection occasionally
-                if random.random() < 0.05:  # 5% chance
-                    gc.collect()
-                    
-                return frames
-                
-            except decord._ffi.base.DECORDError as e:
-                # Log the error
-                error_msg = str(e)
-                if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
-                    logger.warning(f"Retry {attempt+1}/{max_retries} loading video {path}: {error_msg}")
-                    
-                    # Clean up and wait before retrying
-                    gc.collect()
-                    time.sleep(retry_delay * (attempt + 1))  # Increasing backoff
-                else:
-                    # Either not a resource error or we've run out of retries
-                    logger.error(f"Failed to load video {path} after {attempt+1} attempts: {error_msg}")
-                    raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}")
-
-    def _find_nearest_resolution(self, height, width):
-        nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
-        return nearest_res[1], nearest_res[2]
-
-
-class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset):
-    def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
-        super().__init__(*args, **kwargs)
-
-        self.video_reshape_mode = video_reshape_mode
-        self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]
-
-    def _resize_for_rectangle_crop(self, arr, image_size):
-        reshape_mode = self.video_reshape_mode
-        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
-            arr = resize(
-                arr,
-                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
-                interpolation=InterpolationMode.BICUBIC,
-            )
-        else:
-            arr = resize(
-                arr,
-                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
-                interpolation=InterpolationMode.BICUBIC,
-            )
-
-        h, w = arr.shape[2], arr.shape[3]
-        arr = arr.squeeze(0)
-
-        delta_h = h - image_size[0]
-        delta_w = w - image_size[1]
-
-        if reshape_mode == "random" or reshape_mode == "none":
-            top = np.random.randint(0, delta_h + 1)
-            left = np.random.randint(0, delta_w + 1)
-        elif reshape_mode == "center":
-            top, left = delta_h // 2, delta_w // 2
-        else:
-            raise NotImplementedError
-        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
-        return arr
-
-    def _preprocess_video(self, path: Path) -> torch.Tensor:
-        max_retries = 3
-        retry_delay = 1.0  # seconds
-        
-        for attempt in range(max_retries):
-            try:
-                # Create video reader
-                video_reader = decord.VideoReader(uri=path.as_posix())
-                video_num_frames = len(video_reader)
-                
-                # Find appropriate bucket for the video
-                video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
-                
-                if not video_buckets:
-                    _, h, w = self.resolution_buckets[0]
-                    video_buckets = [(1, h, w)]
-
-                nearest_frame_bucket = min(
-                    video_buckets,
-                    key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
-                    default=video_buckets[0],
-                )[0]
-
-                # Extract and process frames
-                frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
-                frames = video_reader.get_batch(frame_indices)
-                frames = frames[:nearest_frame_bucket].float()
-                frames = frames.permute(0, 3, 1, 2).contiguous()
-
-                # Fix: Change self.resolutions to self.resolution_buckets to match the class attribute
-                nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
-                frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
-                frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
-                
-                # Explicitly clean up resources
-                del video_reader
-                
-                # Force garbage collection occasionally
-                if random.random() < 0.05:  # 5% chance
-                    gc.collect()
-                    
-                return frames
-                
-            except decord._ffi.base.DECORDError as e:
-                # Log the error
-                error_msg = str(e)
-                if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
-                    logger.warning(f"Retry {attempt+1}/{max_retries} loading video {path}: {error_msg}")
-                    
-                    # Clean up and wait before retrying
-                    gc.collect()
-                    time.sleep(retry_delay * (attempt + 1))  # Increasing backoff
-                else:
-                    # Either not a resource error or we've run out of retries
-                    logger.error(f"Failed to load video {path} after {attempt+1} attempts: {error_msg}")
-                    raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}")
-        
-    def _find_nearest_resolution(self, height, width):
-        nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
-        return nearest_res[1], nearest_res[2]
-
-
-class PrecomputedDataset(Dataset):
-    def __init__(self, data_root: str, model_name: str = None, cleaned_model_id: str = None) -> None:
-        super().__init__()
-
-        self.data_root = Path(data_root)
-
-        if model_name and cleaned_model_id:
-            precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
-            self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
-            self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
-        else:
-            self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
-            self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
-
-        self.latent_conditions = sorted(os.listdir(self.latents_path))
-        self.text_conditions = sorted(os.listdir(self.conditions_path))
-
-        assert len(self.latent_conditions) == len(self.text_conditions), "Number of captions and videos do not match"
-
-    def __len__(self) -> int:
-        return len(self.latent_conditions)
-
-    def __getitem__(self, index: int) -> Dict[str, Any]:
-        conditions = {}
-        latent_path = self.latents_path / self.latent_conditions[index]
-        condition_path = self.conditions_path / self.text_conditions[index]
-        conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True)
-        conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True)
-        return conditions
-
-
-class BucketSampler(Sampler):
-    r"""
-    PyTorch Sampler that groups 3D data by height, width and frames.
-
-    Args:
-        data_source (`ImageOrVideoDataset`):
-            A PyTorch dataset object that is an instance of `ImageOrVideoDataset`.
-        batch_size (`int`, defaults to `8`):
-            The batch size to use for training.
-        shuffle (`bool`, defaults to `True`):
-            Whether or not to shuffle the data in each batch before dispatching to dataloader.
-        drop_last (`bool`, defaults to `False`):
-            Whether or not to drop incomplete buckets of data after completely iterating over all data
-            in the dataset. If set to True, only batches that have `batch_size` number of entries will
-            be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
-            and batches that do not have `batch_size` number of entries will also be yielded.
-    """
-
-    def __init__(
-        self, data_source: ImageOrVideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
-    ) -> None:
-        self.data_source = data_source
-        self.batch_size = batch_size
-        self.shuffle = shuffle
-        self.drop_last = drop_last
-
-        self.buckets = {resolution: [] for resolution in data_source.resolution_buckets}
-
-        self._raised_warning_for_drop_last = False
-
-    def __len__(self):
-        if self.drop_last and not self._raised_warning_for_drop_last:
-            self._raised_warning_for_drop_last = True
-            logger.warning(
-                "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
-            )
-        return (len(self.data_source) + self.batch_size - 1) // self.batch_size
-
-    def __iter__(self):
-        for index, data in enumerate(self.data_source):
-            video_metadata = data["video_metadata"]
-            f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
-
-            self.buckets[(f, h, w)].append(data)
-            if len(self.buckets[(f, h, w)]) == self.batch_size:
-                if self.shuffle:
-                    random.shuffle(self.buckets[(f, h, w)])
-                yield self.buckets[(f, h, w)]
-                del self.buckets[(f, h, w)]
-                self.buckets[(f, h, w)] = []
-
-        if self.drop_last:
-            return
-
-        for fhw, bucket in list(self.buckets.items()):
-            if len(bucket) == 0:
-                continue
-            if self.shuffle:
-                random.shuffle(bucket)
-                yield bucket
-                del self.buckets[fhw]
-                self.buckets[fhw] = []
diff --git a/finetrainers/functional/__init__.py b/finetrainers/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a62a87847ac0e61521a2331ec3b9ea08cbd49abb
--- /dev/null
+++ b/finetrainers/functional/__init__.py
@@ -0,0 +1,16 @@
+from .diffusion import flow_match_target, flow_match_xt
+from .image import (
+    bicubic_resize_image,
+    center_crop_image,
+    find_nearest_resolution_image,
+    resize_crop_image,
+    resize_to_nearest_bucket_image,
+)
+from .text import dropout_caption, dropout_embeddings_to_zero, remove_prefix
+from .video import (
+    bicubic_resize_video,
+    center_crop_video,
+    find_nearest_video_resolution,
+    resize_crop_video,
+    resize_to_nearest_bucket_video,
+)
diff --git a/finetrainers/functional/diffusion.py b/finetrainers/functional/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9d553895c2fb251abf80f01f284049acf84f87d
--- /dev/null
+++ b/finetrainers/functional/diffusion.py
@@ -0,0 +1,11 @@
+import torch
+
+
+def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+    r"""Forward process of flow matching."""
+    return (1.0 - t) * x0 + t * n
+
+
+def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
+    r"""Loss target for flow matching."""
+    return n - x0
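+
+
+# These helpers implement the flow-matching objective: the forward process interpolates linearly between
+# data and noise, x_t = (1 - t) * x_0 + t * n, and the network regresses the constant velocity target
+# n - x_0. Illustrative sketch (hypothetical shapes, not part of this patch):
+#
+#     x0 = torch.randn(2, 16, 32, 32)          # clean latents
+#     n = torch.randn_like(x0)                 # Gaussian noise
+#     t = torch.rand(2).view(-1, 1, 1, 1)      # per-sample timesteps in [0, 1)
+#     xt = flow_match_xt(x0, n, t)             # noisy model input
+#     target = flow_match_target(n, x0)        # regression target for the model prediction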
diff --git a/finetrainers/functional/image.py b/finetrainers/functional/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b644e4495dd38ec127ec748c03a64dafd783f00
--- /dev/null
+++ b/finetrainers/functional/image.py
@@ -0,0 +1,54 @@
+from typing import List, Literal, Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+def center_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+    num_channels, height, width = image.shape
+    crop_h, crop_w = size
+    top = (height - crop_h) // 2
+    left = (width - crop_w) // 2
+    return image[:, top : top + crop_h, left : left + crop_w]
+
+
+def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+    num_channels, height, width = image.shape
+    target_h, target_w = size
+    scale = max(target_h / height, target_w / width)
+    new_h, new_w = int(height * scale), int(width * scale)
+    # F.interpolate requires a batched (N, C, H, W) input for bilinear resizing, so add and remove a batch dim
+    image = F.interpolate(image.unsqueeze(0), size=(new_h, new_w), mode="bilinear", align_corners=False).squeeze(0)
+    return center_crop_image(image, size)
+
+
+def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+    # F.interpolate requires a batched (N, C, H, W) input for bicubic resizing, so add and remove a batch dim
+    return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False).squeeze(0)
+
+
+def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
+    num_channels, height, width = image.shape
+    aspect_ratio = width / height
+
+    def aspect_ratio_diff(bucket):
+        return abs((bucket[1] / bucket[0]) - aspect_ratio)
+
+    return min(resolution_buckets, key=aspect_ratio_diff)
+
+
+def resize_to_nearest_bucket_image(
+    image: torch.Tensor,
+    resolution_buckets: List[Tuple[int, int]],
+    resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic",
+) -> torch.Tensor:
+    target_size = find_nearest_resolution_image(image, resolution_buckets)
+
+    if resize_mode == "center_crop":
+        return center_crop_image(image, target_size)
+    elif resize_mode == "resize_crop":
+        return resize_crop_image(image, target_size)
+    elif resize_mode == "bicubic":
+        return bicubic_resize_image(image, target_size)
+    else:
+        raise ValueError(
+            f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'."
+        )
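+
+
+# Illustrative usage sketch (hypothetical values, not part of this patch): the target bucket is chosen by
+# aspect ratio alone, so `resolution_buckets` is a list of (height, width) pairs.
+#
+#     image = torch.rand(3, 480, 854)  # (C, H, W)
+#     buckets = [(512, 512), (480, 832), (768, 432)]
+#     resized = resize_to_nearest_bucket_image(image, buckets, resize_mode="bicubic")  # -> (3, 480, 832)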
diff --git a/finetrainers/functional/text.py b/finetrainers/functional/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e823edfc2e3f4a93d2afddf6df71a4198f05219
--- /dev/null
+++ b/finetrainers/functional/text.py
@@ -0,0 +1,26 @@
+import random
+from typing import List, Union
+
+import torch
+
+
+def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]:
+    if random.random() >= dropout_p:
+        return caption
+    if isinstance(caption, str):
+        return ""
+    return [""] * len(caption)
+
+
+def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor:
+    if random.random() >= dropout_p:
+        return embed
+    embed = torch.zeros_like(embed)
+    return embed
+
+
+def remove_prefix(text: str, prefixes: List[str]) -> str:
+    for prefix in prefixes:
+        if text.startswith(prefix):
+            return text.removeprefix(prefix).strip()
+    return text
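+
+
+# Illustrative usage sketch (hypothetical values, not part of this patch): caption dropout randomly replaces
+# prompts (or their embeddings) with an unconditional input, which is commonly done to support
+# classifier-free guidance at inference time.
+#
+#     caption = dropout_caption("A video of a cat playing piano", dropout_p=0.1)      # "" ~10% of the time
+#     embeds = dropout_embeddings_to_zero(torch.randn(1, 77, 768), dropout_p=0.1)     # zeroed ~10% of the time
+#     cleaned = remove_prefix("In this video, a cat plays piano", ["In this video,"])  # "a cat plays piano"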
diff --git a/finetrainers/functional/video.py b/finetrainers/functional/video.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcbc382b0615e53270f3b17746fec14f438ddd16
--- /dev/null
+++ b/finetrainers/functional/video.py
@@ -0,0 +1,94 @@
+from typing import List, Literal, Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+def center_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+    num_frames, num_channels, height, width = video.shape
+    crop_h, crop_w = size
+    top = (height - crop_h) // 2
+    left = (width - crop_w) // 2
+    return video[:, :, top : top + crop_h, left : left + crop_w]
+
+
+def resize_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+    num_frames, num_channels, height, width = video.shape
+    target_h, target_w = size
+    scale = max(target_h / height, target_w / width)
+    new_h, new_w = int(height * scale), int(width * scale)
+    video = F.interpolate(video, size=(new_h, new_w), mode="bilinear", align_corners=False)
+    return center_crop_video(video, size)
+
+
+def bicubic_resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+    num_frames, num_channels, height, width = video.shape
+    video = F.interpolate(video, size=size, mode="bicubic", align_corners=False)
+    return video
+
+
+def find_nearest_video_resolution(
+    video: torch.Tensor, resolution_buckets: List[Tuple[int, int, int]]
+) -> Tuple[int, int, int]:
+    num_frames, num_channels, height, width = video.shape
+    aspect_ratio = width / height
+    possible_buckets = [b for b in resolution_buckets if b[0] <= num_frames]
+
+    if not possible_buckets:
+        best_frame_match = min(resolution_buckets, key=lambda b: abs(b[0] - num_frames))
+    else:
+        best_frame_match = max(possible_buckets, key=lambda b: b[0])
+
+    frame_filtered_buckets = [b for b in resolution_buckets if b[0] == best_frame_match[0]]
+
+    def aspect_ratio_diff(bucket):
+        return abs((bucket[2] / bucket[1]) - aspect_ratio)
+
+    return min(frame_filtered_buckets, key=aspect_ratio_diff)
+
+
+def resize_to_nearest_bucket_video(
+    video: torch.Tensor,
+    resolution_buckets: List[Tuple[int, int, int]],
+    resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic",
+) -> torch.Tensor:
+    """
+    Resizes a video tensor to the nearest resolution bucket using the specified mode.
+    - It first finds a frame match with <= T frames.
+    - Then, it selects the closest height/width bucket.
+
+    Args:
+        video (`torch.Tensor`):
+            Input video tensor of shape `(B, T, C, H, W)`.
+        resolution_buckets (`List[Tuple[int, int, int]]`):
+            Available (num_frames, height, width) resolution buckets.
+        resize_mode (`str`):
+            One of ["center_crop", "resize_crop", "bicubic"].
+
+    Returns:
+        `torch.Tensor`:
+            Resized video tensor of the nearest bucket resolution.
+    """
+    target_frames, target_h, target_w = find_nearest_video_resolution(video, resolution_buckets)
+
+    # Adjust frame count: downsample by selecting frames evenly when the video is longer than the bucket.
+    # If the video is shorter than the smallest available frame bucket, flag it for first-frame-only handling.
+    num_frames, num_channels, height, width = video.shape
+    _first_frame_only = False
+    if num_frames > target_frames:
+        # Downsample: select frames evenly
+        indices = torch.linspace(0, num_frames - 1, target_frames).long()
+        video = video[indices, :, :, :]
+    elif num_frames < target_frames:
+        _first_frame_only = True
+
+    # Resize spatial resolution
+    if resize_mode == "center_crop":
+        return center_crop_video(video, (target_h, target_w)), _first_frame_only
+    elif resize_mode == "resize_crop":
+        return resize_crop_video(video, (target_h, target_w)), _first_frame_only
+    elif resize_mode == "bicubic":
+        return bicubic_resize_video(video, (target_h, target_w)), _first_frame_only
+    else:
+        raise ValueError(
+            f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'."
+        )
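+
+
+# Illustrative usage sketch (hypothetical values, not part of this patch): the bucket is selected by frame
+# count first and aspect ratio second, and the helper returns both the resized video and the
+# first-frame-only flag.
+#
+#     video = torch.rand(49, 3, 480, 854)  # (F, C, H, W)
+#     buckets = [(49, 480, 832), (49, 512, 512), (81, 480, 832)]
+#     resized, first_frame_only = resize_to_nearest_bucket_video(video, buckets, resize_mode="bicubic")
+#     # resized has shape (49, 3, 480, 832); first_frame_only is False because a 49-frame bucket exists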
diff --git a/finetrainers/hooks/__init__.py b/finetrainers/hooks/__init__.py
deleted file mode 100644
index f0c3a432f4021ec9b2666b48047c6fd40f3849b9..0000000000000000000000000000000000000000
--- a/finetrainers/hooks/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .layerwise_upcasting import apply_layerwise_upcasting
diff --git a/finetrainers/hooks/hooks.py b/finetrainers/hooks/hooks.py
deleted file mode 100644
index e779795279e2302de286096563538d2beb818bac..0000000000000000000000000000000000000000
--- a/finetrainers/hooks/hooks.py
+++ /dev/null
@@ -1,176 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import functools
-from typing import Any, Dict, Optional, Tuple
-
-import torch
-from accelerate.logging import get_logger
-
-from ..constants import FINETRAINERS_LOG_LEVEL
-
-
-logger = get_logger("finetrainers")  # pylint: disable=invalid-name
-logger.setLevel(FINETRAINERS_LOG_LEVEL)
-
-
-class ModelHook:
-    r"""
-    A hook that contains callbacks to be executed just before and after the forward method of a model.
-    """
-
-    _is_stateful = False
-
-    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
-        r"""
-        Hook that is executed when a model is initialized.
-        Args:
-            module (`torch.nn.Module`):
-                The module attached to this hook.
-        """
-        return module
-
-    def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
-        r"""
-        Hook that is executed when a model is deinitalized.
-        Args:
-            module (`torch.nn.Module`):
-                The module attached to this hook.
-        """
-        module.forward = module._old_forward
-        del module._old_forward
-        return module
-
-    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
-        r"""
-        Hook that is executed just before the forward method of the model.
-        Args:
-            module (`torch.nn.Module`):
-                The module whose forward pass will be executed just after this event.
-            args (`Tuple[Any]`):
-                The positional arguments passed to the module.
-            kwargs (`Dict[Str, Any]`):
-                The keyword arguments passed to the module.
-        Returns:
-            `Tuple[Tuple[Any], Dict[Str, Any]]`:
-                A tuple with the treated `args` and `kwargs`.
-        """
-        return args, kwargs
-
-    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
-        r"""
-        Hook that is executed just after the forward method of the model.
-        Args:
-            module (`torch.nn.Module`):
-                The module whose forward pass been executed just before this event.
-            output (`Any`):
-                The output of the module.
-        Returns:
-            `Any`: The processed `output`.
-        """
-        return output
-
-    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
-        r"""
-        Hook that is executed when the hook is detached from a module.
-        Args:
-            module (`torch.nn.Module`):
-                The module detached from this hook.
-        """
-        return module
-
-    def reset_state(self, module: torch.nn.Module):
-        if self._is_stateful:
-            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
-        return module
-
-
-class HookRegistry:
-    def __init__(self, module_ref: torch.nn.Module) -> None:
-        super().__init__()
-
-        self.hooks: Dict[str, ModelHook] = {}
-
-        self._module_ref = module_ref
-        self._hook_order = []
-
-    def register_hook(self, hook: ModelHook, name: str) -> None:
-        if name in self.hooks.keys():
-            logger.warning(f"Hook with name {name} already exists, replacing it.")
-
-        if hasattr(self._module_ref, "_old_forward"):
-            old_forward = self._module_ref._old_forward
-        else:
-            old_forward = self._module_ref.forward
-            self._module_ref._old_forward = self._module_ref.forward
-
-        self._module_ref = hook.initialize_hook(self._module_ref)
-
-        if hasattr(hook, "new_forward"):
-            rewritten_forward = hook.new_forward
-
-            def new_forward(module, *args, **kwargs):
-                args, kwargs = hook.pre_forward(module, *args, **kwargs)
-                output = rewritten_forward(module, *args, **kwargs)
-                return hook.post_forward(module, output)
-        else:
-
-            def new_forward(module, *args, **kwargs):
-                args, kwargs = hook.pre_forward(module, *args, **kwargs)
-                output = old_forward(*args, **kwargs)
-                return hook.post_forward(module, output)
-
-        self._module_ref.forward = functools.update_wrapper(
-            functools.partial(new_forward, self._module_ref), old_forward
-        )
-
-        self.hooks[name] = hook
-        self._hook_order.append(name)
-
-    def get_hook(self, name: str) -> Optional[ModelHook]:
-        if name not in self.hooks.keys():
-            return None
-        return self.hooks[name]
-
-    def remove_hook(self, name: str) -> None:
-        if name not in self.hooks.keys():
-            raise ValueError(f"Hook with name {name} not found.")
-        self.hooks[name].deinitalize_hook(self._module_ref)
-        del self.hooks[name]
-        self._hook_order.remove(name)
-
-    def reset_stateful_hooks(self, recurse: bool = True) -> None:
-        for hook_name in self._hook_order:
-            hook = self.hooks[hook_name]
-            if hook._is_stateful:
-                hook.reset_state(self._module_ref)
-
-        if recurse:
-            for module in self._module_ref.modules():
-                if hasattr(module, "_diffusers_hook"):
-                    module._diffusers_hook.reset_stateful_hooks(recurse=False)
-
-    @classmethod
-    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
-        if not hasattr(module, "_diffusers_hook"):
-            module._diffusers_hook = cls(module)
-        return module._diffusers_hook
-
-    def __repr__(self) -> str:
-        hook_repr = ""
-        for i, hook_name in enumerate(self._hook_order):
-            hook_repr += f"  ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
-            if i < len(self._hook_order) - 1:
-                hook_repr += "\n"
-        return f"HookRegistry(\n{hook_repr}\n)"
diff --git a/finetrainers/hooks/layerwise_upcasting.py b/finetrainers/hooks/layerwise_upcasting.py
deleted file mode 100644
index b7bdc38021c5145a2a6ac515270dc356385d03a7..0000000000000000000000000000000000000000
--- a/finetrainers/hooks/layerwise_upcasting.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import re
-from typing import Optional, Tuple, Type
-
-import torch
-from accelerate.logging import get_logger
-
-from ..constants import FINETRAINERS_LOG_LEVEL
-from .hooks import HookRegistry, ModelHook
-
-
-logger = get_logger("finetrainers")  # pylint: disable=invalid-name
-logger.setLevel(FINETRAINERS_LOG_LEVEL)
-
-
-# fmt: off
-_SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-)
-
-_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")
-# fmt: on
-
-
-class LayerwiseUpcastingHook(ModelHook):
-    r"""
-    A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
-    for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
-    footprint.
-    """
-
-    _is_stateful = False
-
-    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
-        self.storage_dtype = storage_dtype
-        self.compute_dtype = compute_dtype
-        self.non_blocking = non_blocking
-
-    def initialize_hook(self, module: torch.nn.Module):
-        module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
-        return module
-
-    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
-        module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
-        return args, kwargs
-
-    def post_forward(self, module: torch.nn.Module, output):
-        module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
-        return output
-
-
-def apply_layerwise_upcasting(
-    module: torch.nn.Module,
-    storage_dtype: torch.dtype,
-    compute_dtype: torch.dtype,
-    skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN,
-    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None,
-    non_blocking: bool = False,
-    _prefix: str = "",
-) -> None:
-    r"""
-    Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
-    nn.Module using diffusers layers or pytorch primitives.
-    Args:
-        module (`torch.nn.Module`):
-            The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
-            precision dtype for storage.
-        storage_dtype (`torch.dtype`):
-            The dtype to cast the module to before/after the forward pass for storage.
-        compute_dtype (`torch.dtype`):
-            The dtype to cast the module to during the forward pass for computation.
-        skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
-            A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
-        skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`):
-            A list of module classes to skip during the layerwise upcasting process.
-        non_blocking (`bool`, defaults to `False`):
-            If `True`, the weight casting operations are non-blocking.
-    """
-    if skip_modules_classes is None and skip_modules_pattern is None:
-        apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
-        return
-
-    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
-        skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
-    )
-    if should_skip:
-        logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
-        return
-
-    if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
-        logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
-        apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
-        return
-
-    for name, submodule in module.named_children():
-        layer_name = f"{_prefix}.{name}" if _prefix else name
-        apply_layerwise_upcasting(
-            submodule,
-            storage_dtype,
-            compute_dtype,
-            skip_modules_pattern,
-            skip_modules_classes,
-            non_blocking,
-            _prefix=layer_name,
-        )
-
-
-def apply_layerwise_upcasting_hook(
-    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
-) -> None:
-    r"""
-    Applies a `LayerwiseUpcastingHook` to a given module.
-    Args:
-        module (`torch.nn.Module`):
-            The module to attach the hook to.
-        storage_dtype (`torch.dtype`):
-            The dtype to cast the module to before the forward pass.
-        compute_dtype (`torch.dtype`):
-            The dtype to cast the module to during the forward pass.
-        non_blocking (`bool`):
-            If `True`, the weight casting operations are non-blocking.
-    """
-    registry = HookRegistry.check_if_exists_or_initialize(module)
-    hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
-    registry.register_hook(hook, "layerwise_upcasting")
diff --git a/finetrainers/logging.py b/finetrainers/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..29d3597db47a91c1f3ec353a677c7309ceffc19d
--- /dev/null
+++ b/finetrainers/logging.py
@@ -0,0 +1,111 @@
+import logging
+import os
+from typing import TYPE_CHECKING, Union
+
+from .constants import FINETRAINERS_LOG_LEVEL
+
+
+if TYPE_CHECKING:
+    from .parallel import ParallelBackendType
+
+
+class FinetrainersLoggerAdapter(logging.LoggerAdapter):
+    def __init__(self, logger: logging.Logger, parallel_backend: "ParallelBackendType" = None) -> None:
+        super().__init__(logger, {})
+        self.parallel_backend = parallel_backend
+        self._log_freq = {}
+        self._log_freq_counter = {}
+
+    def log(
+        self,
+        level,
+        msg,
+        *args,
+        main_process_only: bool = False,
+        local_main_process_only: bool = True,
+        in_order: bool = False,
+        **kwargs,
+    ):
+        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
+        kwargs.setdefault("stacklevel", 2)
+
+        if not self.isEnabledFor(level):
+            return
+
+        if self.parallel_backend is None:
+            if int(os.environ.get("RANK", 0)) == 0:
+                msg, kwargs = self.process(msg, kwargs)
+                self.logger.log(level, msg, *args, **kwargs)
+            return
+
+        if (main_process_only or local_main_process_only) and in_order:
+            raise ValueError(
+                "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True."
+            )
+
+        if (main_process_only and self.parallel_backend.is_main_process) or (
+            local_main_process_only and self.parallel_backend.is_local_main_process
+        ):
+            msg, kwargs = self.process(msg, kwargs)
+            self.logger.log(level, msg, *args, **kwargs)
+            return
+
+        if in_order:
+            for i in range(self.parallel_backend.world_size):
+                if self.parallel_backend.rank == i:
+                    msg, kwargs = self.process(msg, kwargs)
+                    self.logger.log(level, msg, *args, **kwargs)
+                self.parallel_backend.wait_for_everyone()
+            return
+
+        if not main_process_only and not local_main_process_only:
+            msg, kwargs = self.process(msg, kwargs)
+            self.logger.log(level, msg, *args, **kwargs)
+            return
+
+    def log_freq(
+        self,
+        level: int,
+        name: str,
+        msg: str,
+        frequency: int,
+        *,
+        main_process_only: bool = False,
+        local_main_process_only: bool = True,
+        in_order: bool = False,
+        **kwargs,
+    ) -> None:
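+        # Each `name` keeps its own counter: the message is emitted on the first call and then on
+        # every `frequency`-th call for that key, while `frequency <= 0` disables it entirely.
+        # Hypothetical usage inside a training loop:
+        #   logger.log_freq(logging.INFO, "train_loss", f"loss = {loss.item():.4f}", frequency=50)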
+        if frequency <= 0:
+            return
+        if name not in self._log_freq_counter:
+            self._log_freq[name] = frequency
+            self._log_freq_counter[name] = 0
+        if self._log_freq_counter[name] % self._log_freq[name] == 0:
+            self.log(
+                level,
+                msg,
+                main_process_only=main_process_only,
+                local_main_process_only=local_main_process_only,
+                in_order=in_order,
+                **kwargs,
+            )
+        self._log_freq_counter[name] += 1
+
+
+def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]:
+    global _logger
+    return _logger
+
+
+def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> None:
+    _logger.parallel_backend = parallel_backend
+
+
+_logger = logging.getLogger("finetrainers")
+_logger.setLevel(FINETRAINERS_LOG_LEVEL)
+_console_handler = logging.StreamHandler()
+_console_handler.setLevel(FINETRAINERS_LOG_LEVEL)
+_formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+_console_handler.setFormatter(_formatter)
+_logger.addHandler(_console_handler)
+_logger = FinetrainersLoggerAdapter(_logger)
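+
+# Minimal usage sketch (assumes a training script imports this module; until a parallel backend
+# is attached, only the process with RANK == 0 logs):
+#
+#   from finetrainers.logging import get_logger
+#   logger = get_logger()
+#   logger.info("Initialized trainer")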
diff --git a/finetrainers/models/__init__.py b/finetrainers/models/__init__.py
index c24ab951d8b3cd8e52dfa0b3647c7d8183c1352c..fb7091a5e1650715591fdd7377e7c2850c0e3bb3 100644
--- a/finetrainers/models/__init__.py
+++ b/finetrainers/models/__init__.py
@@ -1,33 +1 @@
-from typing import Any, Dict
-
-from .cogvideox import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, COGVIDEOX_T2V_LORA_CONFIG
-from .hunyuan_video import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, HUNYUAN_VIDEO_T2V_LORA_CONFIG
-from .ltx_video import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, LTX_VIDEO_T2V_LORA_CONFIG
-
-
-SUPPORTED_MODEL_CONFIGS = {
-    "hunyuan_video": {
-        "lora": HUNYUAN_VIDEO_T2V_LORA_CONFIG,
-        "full-finetune": HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG,
-    },
-    "ltx_video": {
-        "lora": LTX_VIDEO_T2V_LORA_CONFIG,
-        "full-finetune": LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG,
-    },
-    "cogvideox": {
-        "lora": COGVIDEOX_T2V_LORA_CONFIG,
-        "full-finetune": COGVIDEOX_T2V_FULL_FINETUNE_CONFIG,
-    },
-}
-
-
-def get_config_from_model_name(model_name: str, training_type: str) -> Dict[str, Any]:
-    if model_name not in SUPPORTED_MODEL_CONFIGS:
-        raise ValueError(
-            f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}"
-        )
-    if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]:
-        raise ValueError(
-            f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}"
-        )
-    return SUPPORTED_MODEL_CONFIGS[model_name][training_type]
+from .modeling_utils import ModelSpecification
diff --git a/finetrainers/models/cogvideox/__init__.py b/finetrainers/models/cogvideox/__init__.py
index 7a72064347e083b0d277437e6a4e6e2e54164277..e1f9a84073541b0e764877bac0335637f03d32ca 100644
--- a/finetrainers/models/cogvideox/__init__.py
+++ b/finetrainers/models/cogvideox/__init__.py
@@ -1,2 +1 @@
-from .full_finetune import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG
-from .lora import COGVIDEOX_T2V_LORA_CONFIG
+from .base_specification import CogVideoXModelSpecification
diff --git a/finetrainers/models/cogvideox/base_specification.py b/finetrainers/models/cogvideox/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..918c1a9a950b6b2404fc7362c608dc1ebf319f27
--- /dev/null
+++ b/finetrainers/models/cogvideox/base_specification.py
@@ -0,0 +1,424 @@
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+    AutoencoderKLCogVideoX,
+    CogVideoXDDIMScheduler,
+    CogVideoXImageToVideoPipeline,
+    CogVideoXPipeline,
+    CogVideoXTransformer3DModel,
+)
+from PIL.Image import Image
+from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
+
+from ... import data
+from ...logging import get_logger
+from ...processors import ProcessorMixin, T5Processor
+from ...typing import ArtifactType, SchedulerType
+from ...utils import get_non_null_items
+from ..modeling_utils import ModelSpecification
+from ..utils import DiagonalGaussianDistribution
+from .utils import prepare_rotary_positional_embeddings
+
+
+logger = get_logger()
+
+
+class CogVideoXLatentEncodeProcessor(ProcessorMixin):
+    r"""
+    Processor to encode image/video into latents using the CogVideoX VAE.
+
+    Args:
+        output_names (`List[str]`):
+            The names of the outputs that the processor returns. The outputs are in the following order:
+            - latents: The latents of the input image/video.
+    """
+
+    def __init__(self, output_names: List[str]):
+        super().__init__()
+        self.output_names = output_names
+        assert len(self.output_names) == 1
+
+    def forward(
+        self,
+        vae: AutoencoderKLCogVideoX,
+        image: Optional[torch.Tensor] = None,
+        video: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+    ) -> Dict[str, torch.Tensor]:
+        device = vae.device
+        dtype = vae.dtype
+
+        if image is not None:
+            video = image.unsqueeze(1)
+
+        assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
+        video = video.to(device=device, dtype=vae.dtype)
+        video = video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
+
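+        # If `compute_posterior` is False, the raw VAE moments are returned instead of a sample so
+        # that the posterior can be sampled later (see the `DiagonalGaussianDistribution` usage in
+        # `CogVideoXModelSpecification.forward`).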
+        if compute_posterior:
+            latents = vae.encode(video).latent_dist.sample(generator=generator)
+            latents = latents.to(dtype=dtype)
+        else:
+            if vae.use_slicing and video.shape[0] > 1:
+                encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
+                moments = torch.cat(encoded_slices)
+            else:
+                moments = vae._encode(video)
+            latents = moments.to(dtype=dtype)
+
+        latents = latents.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W] -> [B, F, C, H, W]
+        return {self.output_names[0]: latents}
+
+
+class CogVideoXModelSpecification(ModelSpecification):
+    def __init__(
+        self,
+        pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b",
+        tokenizer_id: Optional[str] = None,
+        text_encoder_id: Optional[str] = None,
+        transformer_id: Optional[str] = None,
+        vae_id: Optional[str] = None,
+        text_encoder_dtype: torch.dtype = torch.bfloat16,
+        transformer_dtype: torch.dtype = torch.bfloat16,
+        vae_dtype: torch.dtype = torch.bfloat16,
+        revision: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        condition_model_processors: List[ProcessorMixin] = None,
+        latent_model_processors: List[ProcessorMixin] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            tokenizer_id=tokenizer_id,
+            text_encoder_id=text_encoder_id,
+            transformer_id=transformer_id,
+            vae_id=vae_id,
+            text_encoder_dtype=text_encoder_dtype,
+            transformer_dtype=transformer_dtype,
+            vae_dtype=vae_dtype,
+            revision=revision,
+            cache_dir=cache_dir,
+        )
+
+        if condition_model_processors is None:
+            condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])]
+        if latent_model_processors is None:
+            latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])]
+
+        self.condition_model_processors = condition_model_processors
+        self.latent_model_processors = latent_model_processors
+
+    @property
+    def _resolution_dim_keys(self):
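+        # Latents are returned in [B, F, C, H, W] layout by the processor above, so the
+        # frame/height/width dimensions are 1, 3 and 4.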
+        return {"latents": (1, 3, 4)}
+
+    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        if self.tokenizer_id is not None:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
+            )
+        else:
+            tokenizer = T5Tokenizer.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="tokenizer",
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        if self.text_encoder_id is not None:
+            text_encoder = AutoModel.from_pretrained(
+                self.text_encoder_id,
+                torch_dtype=self.text_encoder_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            text_encoder = T5EncoderModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder",
+                torch_dtype=self.text_encoder_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        return {"tokenizer": tokenizer, "text_encoder": text_encoder}
+
+    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        if self.vae_id is not None:
+            vae = AutoencoderKLCogVideoX.from_pretrained(
+                self.vae_id,
+                torch_dtype=self.vae_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            vae = AutoencoderKLCogVideoX.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="vae",
+                torch_dtype=self.vae_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        return {"vae": vae}
+
+    def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+        if self.transformer_id is not None:
+            transformer = CogVideoXTransformer3DModel.from_pretrained(
+                self.transformer_id,
+                torch_dtype=self.transformer_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            transformer = CogVideoXTransformer3DModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="transformer",
+                torch_dtype=self.transformer_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        scheduler = CogVideoXDDIMScheduler.from_pretrained(
+            self.pretrained_model_name_or_path, subfolder="scheduler", revision=self.revision, cache_dir=self.cache_dir
+        )
+
+        return {"transformer": transformer, "scheduler": scheduler}
+
+    def load_pipeline(
+        self,
+        tokenizer: Optional[T5Tokenizer] = None,
+        text_encoder: Optional[T5EncoderModel] = None,
+        transformer: Optional[CogVideoXTransformer3DModel] = None,
+        vae: Optional[AutoencoderKLCogVideoX] = None,
+        scheduler: Optional[CogVideoXDDIMScheduler] = None,
+        enable_slicing: bool = False,
+        enable_tiling: bool = False,
+        enable_model_cpu_offload: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> CogVideoXPipeline:
+        components = {
+            "tokenizer": tokenizer,
+            "text_encoder": text_encoder,
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+        }
+        components = get_non_null_items(components)
+
+        pipe = CogVideoXPipeline.from_pretrained(
+            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+        )
+        pipe.text_encoder.to(self.text_encoder_dtype)
+        pipe.vae.to(self.vae_dtype)
+
+        if not training:
+            pipe.transformer.to(self.transformer_dtype)
+
+        if enable_slicing:
+            pipe.vae.enable_slicing()
+        if enable_tiling:
+            pipe.vae.enable_tiling()
+        if enable_model_cpu_offload:
+            pipe.enable_model_cpu_offload()
+
+        return pipe
+
+    @torch.no_grad()
+    def prepare_conditions(
+        self,
+        tokenizer: T5Tokenizer,
+        text_encoder: T5EncoderModel,
+        caption: str,
+        max_sequence_length: int = 226,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        conditions = {
+            "tokenizer": tokenizer,
+            "text_encoder": text_encoder,
+            "caption": caption,
+            "max_sequence_length": max_sequence_length,
+            **kwargs,
+        }
+        input_keys = set(conditions.keys())
+        conditions = super().prepare_conditions(**conditions)
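+        # Keep only the newly computed processor outputs; the pass-through inputs are filtered
+        # back out, and the T5 attention mask is dropped since CogVideoX does not consume it.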
+        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+        conditions.pop("prompt_attention_mask", None)
+        return conditions
+
+    @torch.no_grad()
+    def prepare_latents(
+        self,
+        vae: AutoencoderKLCogVideoX,
+        image: Optional[torch.Tensor] = None,
+        video: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        conditions = {
+            "vae": vae,
+            "image": image,
+            "video": video,
+            "generator": generator,
+            "compute_posterior": compute_posterior,
+            **kwargs,
+        }
+        input_keys = set(conditions.keys())
+        conditions = super().prepare_latents(**conditions)
+        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+        return conditions
+
+    def forward(
+        self,
+        transformer: CogVideoXTransformer3DModel,
+        scheduler: CogVideoXDDIMScheduler,
+        condition_model_conditions: Dict[str, torch.Tensor],
+        latent_model_conditions: Dict[str, torch.Tensor],
+        sigmas: torch.Tensor,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, ...]:
+        # Just hardcode for now. In Diffusers, we will refactor so that RoPE is handled within the model itself.
+        VAE_SPATIAL_SCALE_FACTOR = 8
+        rope_base_height = self.transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
+        rope_base_width = self.transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR
+        patch_size = self.transformer_config.patch_size
+        patch_size_t = getattr(self.transformer_config, "patch_size_t", None)
+
+        if compute_posterior:
+            latents = latent_model_conditions.pop("latents")
+        else:
+            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"), _dim=2)
+            latents = posterior.sample(generator=generator)
+            del posterior
+
+        if not self.vae_config.invert_scale_latents:
+            latents = latents * self.vae_config.scaling_factor
+
+        if patch_size_t is not None:
+            latents = self._pad_frames(latents, patch_size_t)
+
+        timesteps = (sigmas.flatten() * 1000.0).long()
+
+        noise = torch.zeros_like(latents).normal_(generator=generator)
+        noisy_latents = scheduler.add_noise(latents, noise, timesteps)
+
+        batch_size, num_frames, num_channels, height, width = latents.shape
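+        # The `ofs` conditioning only exists for checkpoints that define `ofs_embed_dim`
+        # (CogVideoX 1.5); the fill value of 2.0 mirrors the Diffusers CogVideoX pipelines.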
+        ofs_emb = (
+            None
+            if getattr(self.transformer_config, "ofs_embed_dim", None) is None
+            else latents.new_full((batch_size,), fill_value=2.0)
+        )
+
+        image_rotary_emb = (
+            prepare_rotary_positional_embeddings(
+                height=height * VAE_SPATIAL_SCALE_FACTOR,
+                width=width * VAE_SPATIAL_SCALE_FACTOR,
+                num_frames=num_frames,
+                vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
+                patch_size=patch_size,
+                patch_size_t=patch_size_t,
+                attention_head_dim=self.transformer_config.attention_head_dim,
+                device=transformer.device,
+                base_height=rope_base_height,
+                base_width=rope_base_width,
+            )
+            if self.transformer_config.use_rotary_positional_embeddings
+            else None
+        )
+
+        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+        latent_model_conditions["image_rotary_emb"] = image_rotary_emb
+        latent_model_conditions["ofs"] = ofs_emb
+        condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
+
+        velocity = transformer(
+            **latent_model_conditions,
+            **condition_model_conditions,
+            timestep=timesteps,
+            return_dict=False,
+        )[0]
+        # For CogVideoX, the transformer predicts the velocity. Passing the model output back
+        # through scheduler.get_velocity() recovers an estimate of the clean latents, which is why
+        # the target below is simply `latents`; this reuse of the velocity code path can be
+        # confusing at first glance.
+        pred = scheduler.get_velocity(velocity, noisy_latents, timesteps)
+        target = latents
+
+        return pred, target, sigmas
+
+    def validation(
+        self,
+        pipeline: CogVideoXPipeline,
+        prompt: str,
+        image: Optional[Image] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        num_inference_steps: int = 50,
+        generator: Optional[torch.Generator] = None,
+        **kwargs,
+    ) -> List[ArtifactType]:
+        # TODO(aryan): add support for more parameters
+        if image is not None:
+            pipeline = CogVideoXImageToVideoPipeline.from_pipe(pipeline)
+
+        generation_kwargs = {
+            "prompt": prompt,
+            "image": image,
+            "height": height,
+            "width": width,
+            "num_frames": num_frames,
+            "num_inference_steps": num_inference_steps,
+            "generator": generator,
+            "return_dict": True,
+            "output_type": "pil",
+        }
+        generation_kwargs = get_non_null_items(generation_kwargs)
+        video = pipeline(**generation_kwargs).frames[0]
+        return [data.VideoArtifact(value=video)]
+
+    def _save_lora_weights(
+        self,
+        directory: str,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        # TODO(aryan): this needs refactoring
+        if transformer_state_dict is not None:
+            CogVideoXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
+        if scheduler is not None:
+            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    def _save_model(
+        self,
+        directory: str,
+        transformer: CogVideoXTransformer3DModel,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+    ) -> None:
+        # TODO(aryan): this needs refactoring
+        if transformer_state_dict is not None:
+            with init_empty_weights():
+                transformer_copy = CogVideoXTransformer3DModel.from_config(transformer.config)
+            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+        if scheduler is not None:
+            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    @staticmethod
+    def _pad_frames(latents: torch.Tensor, patch_size_t: int) -> torch.Tensor:
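+        # Pad along the frame dimension so the latent frame count is divisible by `patch_size_t`.
+        # Rough shape sketch (assuming a 49-frame, 480x720 input and a CogVideoX 1.5 checkpoint
+        # with patch_size_t == 2): the VAE yields latents reshaped to [B, 13, 16, 60, 90] here,
+        # which get padded to 14 frames by repeating the last frame.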
+        num_frames = latents.size(1)
+        additional_frames = -num_frames % patch_size_t  # 0 when num_frames is already divisible
+        if additional_frames > 0:
+            last_frame = latents[:, -1:]
+            padding_frames = last_frame.expand(-1, additional_frames, -1, -1, -1)
+            latents = torch.cat([latents, padding_frames], dim=1)
+        return latents
diff --git a/finetrainers/models/cogvideox/full_finetune.py b/finetrainers/models/cogvideox/full_finetune.py
deleted file mode 100644
index b7f2b4bbcff806c7c12b736c36bb9733b9980353..0000000000000000000000000000000000000000
--- a/finetrainers/models/cogvideox/full_finetune.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from diffusers import CogVideoXPipeline
-
-from .lora import (
-    calculate_noisy_latents,
-    collate_fn_t2v,
-    forward_pass,
-    initialize_pipeline,
-    load_condition_models,
-    load_diffusion_models,
-    load_latent_models,
-    post_latent_preparation,
-    prepare_conditions,
-    prepare_latents,
-    validation,
-)
-
-
-# TODO(aryan): refactor into model specs for better re-use
-COGVIDEOX_T2V_FULL_FINETUNE_CONFIG = {
-    "pipeline_cls": CogVideoXPipeline,
-    "load_condition_models": load_condition_models,
-    "load_latent_models": load_latent_models,
-    "load_diffusion_models": load_diffusion_models,
-    "initialize_pipeline": initialize_pipeline,
-    "prepare_conditions": prepare_conditions,
-    "prepare_latents": prepare_latents,
-    "post_latent_preparation": post_latent_preparation,
-    "collate_fn": collate_fn_t2v,
-    "calculate_noisy_latents": calculate_noisy_latents,
-    "forward_pass": forward_pass,
-    "validation": validation,
-}
diff --git a/finetrainers/models/cogvideox/lora.py b/finetrainers/models/cogvideox/lora.py
deleted file mode 100644
index 65d86ee901d73296c94c2abe20f21293cace45b3..0000000000000000000000000000000000000000
--- a/finetrainers/models/cogvideox/lora.py
+++ /dev/null
@@ -1,334 +0,0 @@
-from typing import Any, Dict, List, Optional, Union
-
-import torch
-from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
-from PIL import Image
-from transformers import T5EncoderModel, T5Tokenizer
-
-from .utils import prepare_rotary_positional_embeddings
-
-
-def load_condition_models(
-    model_id: str = "THUDM/CogVideoX-5b",
-    text_encoder_dtype: torch.dtype = torch.bfloat16,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    **kwargs,
-):
-    tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
-    text_encoder = T5EncoderModel.from_pretrained(
-        model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
-    )
-    return {"tokenizer": tokenizer, "text_encoder": text_encoder}
-
-
-def load_latent_models(
-    model_id: str = "THUDM/CogVideoX-5b",
-    vae_dtype: torch.dtype = torch.bfloat16,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    **kwargs,
-):
-    vae = AutoencoderKLCogVideoX.from_pretrained(
-        model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
-    )
-    return {"vae": vae}
-
-
-def load_diffusion_models(
-    model_id: str = "THUDM/CogVideoX-5b",
-    transformer_dtype: torch.dtype = torch.bfloat16,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    **kwargs,
-):
-    transformer = CogVideoXTransformer3DModel.from_pretrained(
-        model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
-    )
-    scheduler = CogVideoXDDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
-    return {"transformer": transformer, "scheduler": scheduler}
-
-
-def initialize_pipeline(
-    model_id: str = "THUDM/CogVideoX-5b",
-    text_encoder_dtype: torch.dtype = torch.bfloat16,
-    transformer_dtype: torch.dtype = torch.bfloat16,
-    vae_dtype: torch.dtype = torch.bfloat16,
-    tokenizer: Optional[T5Tokenizer] = None,
-    text_encoder: Optional[T5EncoderModel] = None,
-    transformer: Optional[CogVideoXTransformer3DModel] = None,
-    vae: Optional[AutoencoderKLCogVideoX] = None,
-    scheduler: Optional[CogVideoXDDIMScheduler] = None,
-    device: Optional[torch.device] = None,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    enable_slicing: bool = False,
-    enable_tiling: bool = False,
-    enable_model_cpu_offload: bool = False,
-    is_training: bool = False,
-    **kwargs,
-) -> CogVideoXPipeline:
-    component_name_pairs = [
-        ("tokenizer", tokenizer),
-        ("text_encoder", text_encoder),
-        ("transformer", transformer),
-        ("vae", vae),
-        ("scheduler", scheduler),
-    ]
-    components = {}
-    for name, component in component_name_pairs:
-        if component is not None:
-            components[name] = component
-
-    pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
-    pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
-    pipe.vae = pipe.vae.to(dtype=vae_dtype)
-
-    # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
-    # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
-    # DDP optimizer step.
-    if not is_training:
-        pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
-
-    if enable_slicing:
-        pipe.vae.enable_slicing()
-    if enable_tiling:
-        pipe.vae.enable_tiling()
-
-    if enable_model_cpu_offload:
-        pipe.enable_model_cpu_offload(device=device)
-    else:
-        pipe.to(device=device)
-
-    return pipe
-
-
-def prepare_conditions(
-    tokenizer,
-    text_encoder,
-    prompt: Union[str, List[str]],
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-    max_sequence_length: int = 226,  # TODO: this should be configurable
-    **kwargs,
-):
-    device = device or text_encoder.device
-    dtype = dtype or text_encoder.dtype
-    return _get_t5_prompt_embeds(
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        prompt=prompt,
-        max_sequence_length=max_sequence_length,
-        device=device,
-        dtype=dtype,
-    )
-
-
-def prepare_latents(
-    vae: AutoencoderKLCogVideoX,
-    image_or_video: torch.Tensor,
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-    generator: Optional[torch.Generator] = None,
-    precompute: bool = False,
-    **kwargs,
-) -> torch.Tensor:
-    device = device or vae.device
-    dtype = dtype or vae.dtype
-
-    if image_or_video.ndim == 4:
-        image_or_video = image_or_video.unsqueeze(2)
-    assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
-
-    image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
-    image_or_video = image_or_video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
-    if not precompute:
-        latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
-        if not vae.config.invert_scale_latents:
-            latents = latents * vae.config.scaling_factor
-        # For training Cog 1.5, we don't need to handle the scaling factor here.
-        # The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents
-        # is probably only needed for image-to-video training.
-        # TODO(aryan): investigate this
-        # else:
-        #     latents = 1 / vae.config.scaling_factor * latents
-        latents = latents.to(dtype=dtype)
-        return {"latents": latents}
-    else:
-        # handle vae scaling in the `train()` method directly.
-        if vae.use_slicing and image_or_video.shape[0] > 1:
-            encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
-            h = torch.cat(encoded_slices)
-        else:
-            h = vae._encode(image_or_video)
-        return {"latents": h}
-
-
-def post_latent_preparation(
-    vae_config: Dict[str, Any], latents: torch.Tensor, patch_size_t: Optional[int] = None, **kwargs
-) -> torch.Tensor:
-    if not vae_config.invert_scale_latents:
-        latents = latents * vae_config.scaling_factor
-    # For training Cog 1.5, we don't need to handle the scaling factor here.
-    # The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents
-    # is probably only needed for image-to-video training.
-    # TODO(aryan): investigate this
-    # else:
-    #     latents = 1 / vae_config.scaling_factor * latents
-    latents = _pad_frames(latents, patch_size_t)
-    latents = latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-    return {"latents": latents}
-
-
-def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
-    return {
-        "prompts": [x["prompt"] for x in batch[0]],
-        "videos": torch.stack([x["video"] for x in batch[0]]),
-    }
-
-
-def calculate_noisy_latents(
-    scheduler: CogVideoXDDIMScheduler,
-    noise: torch.Tensor,
-    latents: torch.Tensor,
-    timesteps: torch.LongTensor,
-) -> torch.Tensor:
-    noisy_latents = scheduler.add_noise(latents, noise, timesteps)
-    return noisy_latents
-
-
-def forward_pass(
-    transformer: CogVideoXTransformer3DModel,
-    scheduler: CogVideoXDDIMScheduler,
-    prompt_embeds: torch.Tensor,
-    latents: torch.Tensor,
-    noisy_latents: torch.Tensor,
-    timesteps: torch.LongTensor,
-    ofs_emb: Optional[torch.Tensor] = None,
-    **kwargs,
-) -> torch.Tensor:
-    # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself.
-    VAE_SPATIAL_SCALE_FACTOR = 8
-    transformer_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
-    batch_size, num_frames, num_channels, height, width = noisy_latents.shape
-    rope_base_height = transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
-    rope_base_width = transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR
-
-    image_rotary_emb = (
-        prepare_rotary_positional_embeddings(
-            height=height * VAE_SPATIAL_SCALE_FACTOR,
-            width=width * VAE_SPATIAL_SCALE_FACTOR,
-            num_frames=num_frames,
-            vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
-            patch_size=transformer_config.patch_size,
-            patch_size_t=transformer_config.patch_size_t if hasattr(transformer_config, "patch_size_t") else None,
-            attention_head_dim=transformer_config.attention_head_dim,
-            device=transformer.device,
-            base_height=rope_base_height,
-            base_width=rope_base_width,
-        )
-        if transformer_config.use_rotary_positional_embeddings
-        else None
-    )
-    ofs_emb = None if transformer_config.ofs_embed_dim is None else latents.new_full((batch_size,), fill_value=2.0)
-
-    velocity = transformer(
-        hidden_states=noisy_latents,
-        timestep=timesteps,
-        encoder_hidden_states=prompt_embeds,
-        ofs=ofs_emb,
-        image_rotary_emb=image_rotary_emb,
-        return_dict=False,
-    )[0]
-    # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same
-    # code paths as scheduler.get_velocity(), which can be confusing to understand.
-    denoised_latents = scheduler.get_velocity(velocity, noisy_latents, timesteps)
-
-    return {"latents": denoised_latents}
-
-
-def validation(
-    pipeline: CogVideoXPipeline,
-    prompt: str,
-    image: Optional[Image.Image] = None,
-    video: Optional[List[Image.Image]] = None,
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-    num_frames: Optional[int] = None,
-    num_videos_per_prompt: int = 1,
-    generator: Optional[torch.Generator] = None,
-    **kwargs,
-):
-    generation_kwargs = {
-        "prompt": prompt,
-        "height": height,
-        "width": width,
-        "num_frames": num_frames,
-        "num_videos_per_prompt": num_videos_per_prompt,
-        "generator": generator,
-        "return_dict": True,
-        "output_type": "pil",
-    }
-    generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
-    output = pipeline(**generation_kwargs).frames[0]
-    return [("video", output)]
-
-
-def _get_t5_prompt_embeds(
-    tokenizer: T5Tokenizer,
-    text_encoder: T5EncoderModel,
-    prompt: Union[str, List[str]] = None,
-    max_sequence_length: int = 226,
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-):
-    prompt = [prompt] if isinstance(prompt, str) else prompt
-
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        add_special_tokens=True,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
-
-    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
-    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-    return {"prompt_embeds": prompt_embeds}
-
-
-def _pad_frames(latents: torch.Tensor, patch_size_t: int):
-    if patch_size_t is None or patch_size_t == 1:
-        return latents
-
-    # `latents` should be of the following format: [B, C, F, H, W].
-    # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
-    latent_num_frames = latents.shape[2]
-    additional_frames = patch_size_t - latent_num_frames % patch_size_t
-
-    if additional_frames > 0:
-        last_frame = latents[:, :, -1:, :, :]
-        padding_frames = last_frame.repeat(1, 1, additional_frames, 1, 1)
-        latents = torch.cat([latents, padding_frames], dim=2)
-
-    return latents
-
-
-# TODO(aryan): refactor into model specs for better re-use
-COGVIDEOX_T2V_LORA_CONFIG = {
-    "pipeline_cls": CogVideoXPipeline,
-    "load_condition_models": load_condition_models,
-    "load_latent_models": load_latent_models,
-    "load_diffusion_models": load_diffusion_models,
-    "initialize_pipeline": initialize_pipeline,
-    "prepare_conditions": prepare_conditions,
-    "prepare_latents": prepare_latents,
-    "post_latent_preparation": post_latent_preparation,
-    "collate_fn": collate_fn_t2v,
-    "calculate_noisy_latents": calculate_noisy_latents,
-    "forward_pass": forward_pass,
-    "validation": validation,
-}
diff --git a/finetrainers/models/hunyuan_video/__init__.py b/finetrainers/models/hunyuan_video/__init__.py
index 8ac729e91bb0d8af781ea51e856a43bfff1990df..518a42865f0cee30a534da458ec63b08c1a8d7e4 100644
--- a/finetrainers/models/hunyuan_video/__init__.py
+++ b/finetrainers/models/hunyuan_video/__init__.py
@@ -1,2 +1 @@
-from .full_finetune import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG
-from .lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG
+from .base_specification import HunyuanVideoModelSpecification
diff --git a/finetrainers/models/hunyuan_video/base_specification.py b/finetrainers/models/hunyuan_video/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd73fce85e8c804d1526ff0a464288d4da67740
--- /dev/null
+++ b/finetrainers/models/hunyuan_video/base_specification.py
@@ -0,0 +1,413 @@
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+    AutoencoderKLHunyuanVideo,
+    FlowMatchEulerDiscreteScheduler,
+    HunyuanVideoPipeline,
+    HunyuanVideoTransformer3DModel,
+)
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel
+
+from ... import data
+from ... import functional as FF
+from ...logging import get_logger
+from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin
+from ...typing import ArtifactType, SchedulerType
+from ...utils import get_non_null_items
+from ..modeling_utils import ModelSpecification
+
+
+logger = get_logger()
+
+
+class HunyuanLatentEncodeProcessor(ProcessorMixin):
+    r"""
+    Processor to encode image/video into latents using the HunyuanVideo VAE.
+
+    Args:
+        output_names (`List[str]`):
+            The names of the outputs that the processor returns. The outputs are in the following order:
+            - latents: The latents of the input image/video.
+    """
+
+    def __init__(self, output_names: List[str]):
+        super().__init__()
+        self.output_names = output_names
+        assert len(self.output_names) == 1
+
+    def forward(
+        self,
+        vae: AutoencoderKLHunyuanVideo,
+        image: Optional[torch.Tensor] = None,
+        video: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+    ) -> Dict[str, torch.Tensor]:
+        device = vae.device
+        dtype = vae.dtype
+
+        if image is not None:
+            video = image.unsqueeze(1)
+
+        assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
+        video = video.to(device=device, dtype=vae.dtype)
+        video = video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
+
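+        # Unlike the CogVideoX spec, latents stay in [B, C, F, H, W] layout here (matching
+        # `_resolution_dim_keys`); raw moments are returned when `compute_posterior` is False so
+        # the posterior can be sampled later in `forward`.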
+        if compute_posterior:
+            latents = vae.encode(video).latent_dist.sample(generator=generator)
+            latents = latents.to(dtype=dtype)
+        else:
+            if vae.use_slicing and video.shape[0] > 1:
+                encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
+                moments = torch.cat(encoded_slices)
+            else:
+                moments = vae._encode(video)
+            latents = moments.to(dtype=dtype)
+
+        return {self.output_names[0]: latents}
+
+
+class HunyuanVideoModelSpecification(ModelSpecification):
+    def __init__(
+        self,
+        pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo",
+        tokenizer_id: Optional[str] = None,
+        text_encoder_id: Optional[str] = None,
+        transformer_id: Optional[str] = None,
+        vae_id: Optional[str] = None,
+        text_encoder_dtype: torch.dtype = torch.bfloat16,
+        transformer_dtype: torch.dtype = torch.bfloat16,
+        vae_dtype: torch.dtype = torch.bfloat16,
+        revision: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        condition_model_processors: List[ProcessorMixin] = None,
+        latent_model_processors: List[ProcessorMixin] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            tokenizer_id=tokenizer_id,
+            text_encoder_id=text_encoder_id,
+            transformer_id=transformer_id,
+            vae_id=vae_id,
+            text_encoder_dtype=text_encoder_dtype,
+            transformer_dtype=transformer_dtype,
+            vae_dtype=vae_dtype,
+            revision=revision,
+            cache_dir=cache_dir,
+        )
+
+        if condition_model_processors is None:
+            condition_model_processors = [
+                LlamaProcessor(["encoder_hidden_states", "encoder_attention_mask"]),
+                CLIPPooledProcessor(
+                    ["pooled_projections"],
+                    input_names={"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"},
+                ),
+            ]
+        if latent_model_processors is None:
+            latent_model_processors = [HunyuanLatentEncodeProcessor(["latents"])]
+
+        self.condition_model_processors = condition_model_processors
+        self.latent_model_processors = latent_model_processors
+
+    @property
+    def _resolution_dim_keys(self):
+        # TODO
+        return {
+            "latents": (2, 3, 4),
+        }
+
+    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        if self.tokenizer_id is not None:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
+            )
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="tokenizer",
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        if self.tokenizer_2_id is not None:
+            tokenizer_2 = CLIPTokenizer.from_pretrained(
+                self.tokenizer_2_id, revision=self.revision, cache_dir=self.cache_dir
+            )
+        else:
+            tokenizer_2 = CLIPTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="tokenizer_2",
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        if self.text_encoder_id is not None:
+            text_encoder = LlamaModel.from_pretrained(
+                self.text_encoder_id,
+                torch_dtype=self.text_encoder_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            text_encoder = LlamaModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder",
+                torch_dtype=self.text_encoder_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        if self.text_encoder_2_id is not None:
+            text_encoder_2 = CLIPTextModel.from_pretrained(
+                self.text_encoder_2_id,
+                torch_dtype=self.text_encoder_2_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            text_encoder_2 = CLIPTextModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder_2",
+                torch_dtype=self.text_encoder_2_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        return {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+        }
+
+    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        if self.vae_id is not None:
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(
+                self.vae_id,
+                torch_dtype=self.vae_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="vae",
+                torch_dtype=self.vae_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        return {"vae": vae}
+
+    def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+        if self.transformer_id is not None:
+            transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+                self.transformer_id,
+                torch_dtype=self.transformer_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="transformer",
+                torch_dtype=self.transformer_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {"transformer": transformer, "scheduler": scheduler}
+
+    def load_pipeline(
+        self,
+        tokenizer: Optional[AutoTokenizer] = None,
+        tokenizer_2: Optional[CLIPTokenizer] = None,
+        text_encoder: Optional[LlamaModel] = None,
+        text_encoder_2: Optional[CLIPTextModel] = None,
+        transformer: Optional[HunyuanVideoTransformer3DModel] = None,
+        vae: Optional[AutoencoderKLHunyuanVideo] = None,
+        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+        enable_slicing: bool = False,
+        enable_tiling: bool = False,
+        enable_model_cpu_offload: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> HunyuanVideoPipeline:
+        components = {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+        }
+        components = get_non_null_items(components)
+
+        pipe = HunyuanVideoPipeline.from_pretrained(
+            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+        )
+        pipe.text_encoder.to(self.text_encoder_dtype)
+        pipe.text_encoder_2.to(self.text_encoder_2_dtype)
+        pipe.vae.to(self.vae_dtype)
+
+        if not training:
+            pipe.transformer.to(self.transformer_dtype)
+
+        if enable_slicing:
+            pipe.vae.enable_slicing()
+        if enable_tiling:
+            pipe.vae.enable_tiling()
+        if enable_model_cpu_offload:
+            pipe.enable_model_cpu_offload()
+
+        return pipe
+
+    @torch.no_grad()
+    def prepare_conditions(
+        self,
+        tokenizer: AutoTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        text_encoder: LlamaModel,
+        text_encoder_2: CLIPTextModel,
+        caption: str,
+        max_sequence_length: int = 256,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        conditions = {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "caption": caption,
+            "max_sequence_length": max_sequence_length,
+            **kwargs,
+        }
+        input_keys = set(conditions.keys())
+        conditions = super().prepare_conditions(**conditions)
+        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+        return conditions
+
+    @torch.no_grad()
+    def prepare_latents(
+        self,
+        vae: AutoencoderKLHunyuanVideo,
+        image: Optional[torch.Tensor] = None,
+        video: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        conditions = {
+            "vae": vae,
+            "image": image,
+            "video": video,
+            "generator": generator,
+            "compute_posterior": compute_posterior,
+            **kwargs,
+        }
+        input_keys = set(conditions.keys())
+        conditions = super().prepare_latents(**conditions)
+        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+        return conditions
+
+    def forward(
+        self,
+        transformer: HunyuanVideoTransformer3DModel,
+        condition_model_conditions: Dict[str, torch.Tensor],
+        latent_model_conditions: Dict[str, torch.Tensor],
+        sigmas: torch.Tensor,
+        guidance: float = 1.0,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, ...]:
+        if compute_posterior:
+            latents = latent_model_conditions.pop("latents")
+        else:
+            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+            latents = posterior.sample(generator=generator)
+            del posterior
+
+        latents = latents * self.vae_config.scaling_factor
+        noise = torch.zeros_like(latents).normal_(generator=generator)
+        noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
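+        # Assumed semantics of the FF helpers: `flow_match_xt` interpolates
+        # x_t = (1 - sigma) * x_0 + sigma * noise, and `flow_match_target` (used below) returns
+        # noise - x_0, the rectified-flow velocity target.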
+
+        timesteps = (sigmas.flatten() * 1000.0).long()
+        guidance = latents.new_full((latents.size(0),), fill_value=guidance) * 1000.0
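+        # HunyuanVideo is guidance-distilled; the embedded guidance scale is conventionally scaled
+        # by 1000 before being fed to the transformer (matching the Diffusers pipeline).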
+
+        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+        latent_model_conditions["guidance"] = guidance
+
+        pred = transformer(
+            **latent_model_conditions,
+            **condition_model_conditions,
+            timestep=timesteps,
+            return_dict=False,
+        )[0]
+        target = FF.flow_match_target(noise, latents)
+
+        return pred, target, sigmas
+
+    def validation(
+        self,
+        pipeline: HunyuanVideoPipeline,
+        prompt: str,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        num_inference_steps: int = 50,
+        generator: Optional[torch.Generator] = None,
+        **kwargs,
+    ) -> List[ArtifactType]:
+        generation_kwargs = {
+            "prompt": prompt,
+            "height": height,
+            "width": width,
+            "num_frames": num_frames,
+            "num_inference_steps": num_inference_steps,
+            "generator": generator,
+            "return_dict": True,
+            "output_type": "pil",
+        }
+        generation_kwargs = get_non_null_items(generation_kwargs)
+        video = pipeline(**generation_kwargs).frames[0]
+        return [data.VideoArtifact(value=video)]
+
+    def _save_lora_weights(
+        self,
+        directory: str,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        # TODO(aryan): this needs refactoring
+        if transformer_state_dict is not None:
+            HunyuanVideoPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
+        if scheduler is not None:
+            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    def _save_model(
+        self,
+        directory: str,
+        transformer: HunyuanVideoTransformer3DModel,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+    ) -> None:
+        # TODO(aryan): this needs refactoring
+        if transformer_state_dict is not None:
+            with init_empty_weights():
+                transformer_copy = HunyuanVideoTransformer3DModel.from_config(transformer.config)
+            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+        if scheduler is not None:
+            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
diff --git a/finetrainers/models/hunyuan_video/full_finetune.py b/finetrainers/models/hunyuan_video/full_finetune.py
deleted file mode 100644
index 65e73f5451cacbc4ea4540b6bfcd3c3bd1b9e531..0000000000000000000000000000000000000000
--- a/finetrainers/models/hunyuan_video/full_finetune.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from diffusers import HunyuanVideoPipeline
-
-from .lora import (
-    collate_fn_t2v,
-    forward_pass,
-    initialize_pipeline,
-    load_condition_models,
-    load_diffusion_models,
-    load_latent_models,
-    post_latent_preparation,
-    prepare_conditions,
-    prepare_latents,
-    validation,
-)
-
-
-# TODO(aryan): refactor into model specs for better re-use
-HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG = {
-    "pipeline_cls": HunyuanVideoPipeline,
-    "load_condition_models": load_condition_models,
-    "load_latent_models": load_latent_models,
-    "load_diffusion_models": load_diffusion_models,
-    "initialize_pipeline": initialize_pipeline,
-    "prepare_conditions": prepare_conditions,
-    "prepare_latents": prepare_latents,
-    "post_latent_preparation": post_latent_preparation,
-    "collate_fn": collate_fn_t2v,
-    "forward_pass": forward_pass,
-    "validation": validation,
-}
diff --git a/finetrainers/models/hunyuan_video/lora.py b/finetrainers/models/hunyuan_video/lora.py
deleted file mode 100644
index 1d8ccd1f61f3131f9fb0c2ba1235070ce7439ba0..0000000000000000000000000000000000000000
--- a/finetrainers/models/hunyuan_video/lora.py
+++ /dev/null
@@ -1,368 +0,0 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from accelerate.logging import get_logger
-from diffusers import (
-    AutoencoderKLHunyuanVideo,
-    FlowMatchEulerDiscreteScheduler,
-    HunyuanVideoPipeline,
-    HunyuanVideoTransformer3DModel,
-)
-from PIL import Image
-from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer
-
-
-logger = get_logger("finetrainers")  # pylint: disable=invalid-name
-
-
-def load_condition_models(
-    model_id: str = "hunyuanvideo-community/HunyuanVideo",
-    text_encoder_dtype: torch.dtype = torch.float16,
-    text_encoder_2_dtype: torch.dtype = torch.float16,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    **kwargs,
-) -> Dict[str, nn.Module]:
-    tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
-    text_encoder = LlamaModel.from_pretrained(
-        model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
-    )
-    tokenizer_2 = CLIPTokenizer.from_pretrained(
-        model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir
-    )
-    text_encoder_2 = CLIPTextModel.from_pretrained(
-        model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir
-    )
-    return {
-        "tokenizer": tokenizer,
-        "text_encoder": text_encoder,
-        "tokenizer_2": tokenizer_2,
-        "text_encoder_2": text_encoder_2,
-    }
-
-
-def load_latent_models(
-    model_id: str = "hunyuanvideo-community/HunyuanVideo",
-    vae_dtype: torch.dtype = torch.float16,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    **kwargs,
-) -> Dict[str, nn.Module]:
-    vae = AutoencoderKLHunyuanVideo.from_pretrained(
-        model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
-    )
-    return {"vae": vae}
-
-
-def load_diffusion_models(
-    model_id: str = "hunyuanvideo-community/HunyuanVideo",
-    transformer_dtype: torch.dtype = torch.bfloat16,
-    shift: float = 1.0,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    **kwargs,
-) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]:
-    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
-        model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
-    )
-    scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
-    return {"transformer": transformer, "scheduler": scheduler}
-
-
-def initialize_pipeline(
-    model_id: str = "hunyuanvideo-community/HunyuanVideo",
-    text_encoder_dtype: torch.dtype = torch.float16,
-    text_encoder_2_dtype: torch.dtype = torch.float16,
-    transformer_dtype: torch.dtype = torch.bfloat16,
-    vae_dtype: torch.dtype = torch.float16,
-    tokenizer: Optional[LlamaTokenizer] = None,
-    text_encoder: Optional[LlamaModel] = None,
-    tokenizer_2: Optional[CLIPTokenizer] = None,
-    text_encoder_2: Optional[CLIPTextModel] = None,
-    transformer: Optional[HunyuanVideoTransformer3DModel] = None,
-    vae: Optional[AutoencoderKLHunyuanVideo] = None,
-    scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
-    device: Optional[torch.device] = None,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    enable_slicing: bool = False,
-    enable_tiling: bool = False,
-    enable_model_cpu_offload: bool = False,
-    is_training: bool = False,
-    **kwargs,
-) -> HunyuanVideoPipeline:
-    component_name_pairs = [
-        ("tokenizer", tokenizer),
-        ("text_encoder", text_encoder),
-        ("tokenizer_2", tokenizer_2),
-        ("text_encoder_2", text_encoder_2),
-        ("transformer", transformer),
-        ("vae", vae),
-        ("scheduler", scheduler),
-    ]
-    components = {}
-    for name, component in component_name_pairs:
-        if component is not None:
-            components[name] = component
-
-    pipe = HunyuanVideoPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
-    pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
-    pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=text_encoder_2_dtype)
-    pipe.vae = pipe.vae.to(dtype=vae_dtype)
-
-    # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
-    # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
-    # DDP optimizer step.
-    if not is_training:
-        pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
-
-    if enable_slicing:
-        pipe.vae.enable_slicing()
-    if enable_tiling:
-        pipe.vae.enable_tiling()
-
-    if enable_model_cpu_offload:
-        pipe.enable_model_cpu_offload(device=device)
-    else:
-        pipe.to(device=device)
-
-    return pipe
-
-
-def prepare_conditions(
-    tokenizer: LlamaTokenizer,
-    text_encoder: LlamaModel,
-    tokenizer_2: CLIPTokenizer,
-    text_encoder_2: CLIPTextModel,
-    prompt: Union[str, List[str]],
-    guidance: float = 1.0,
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-    max_sequence_length: int = 256,
-    # TODO(aryan): make configurable
-    prompt_template: Dict[str, Any] = {
-        "template": (
-            "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
-            "1. The main content and theme of the video."
-            "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
-            "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
-            "4. background environment, light, style and atmosphere."
-            "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
-            "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
-        ),
-        "crop_start": 95,
-    },
-    **kwargs,
-) -> torch.Tensor:
-    device = device or text_encoder.device
-    dtype = dtype or text_encoder.dtype
-
-    if isinstance(prompt, str):
-        prompt = [prompt]
-
-    conditions = {}
-    conditions.update(
-        _get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length)
-    )
-    conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype))
-
-    guidance = torch.tensor([guidance], device=device, dtype=dtype) * 1000.0
-    conditions["guidance"] = guidance
-
-    return conditions
-
-
-def prepare_latents(
-    vae: AutoencoderKLHunyuanVideo,
-    image_or_video: torch.Tensor,
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-    generator: Optional[torch.Generator] = None,
-    precompute: bool = False,
-    **kwargs,
-) -> torch.Tensor:
-    device = device or vae.device
-    dtype = dtype or vae.dtype
-
-    if image_or_video.ndim == 4:
-        image_or_video = image_or_video.unsqueeze(2)
-    assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
-
-    image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
-    image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, C, F, H, W] -> [B, F, C, H, W]
-    if not precompute:
-        latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
-        latents = latents * vae.config.scaling_factor
-        latents = latents.to(dtype=dtype)
-        return {"latents": latents}
-    else:
-        if vae.use_slicing and image_or_video.shape[0] > 1:
-            encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
-            h = torch.cat(encoded_slices)
-        else:
-            h = vae._encode(image_or_video)
-        return {"latents": h}
-
-
-def post_latent_preparation(
-    vae_config: Dict[str, Any],
-    latents: torch.Tensor,
-    **kwargs,
-) -> torch.Tensor:
-    latents = latents * vae_config.scaling_factor
-    return {"latents": latents}
-
-
-def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
-    return {
-        "prompts": [x["prompt"] for x in batch[0]],
-        "videos": torch.stack([x["video"] for x in batch[0]]),
-    }
-
-
-def forward_pass(
-    transformer: HunyuanVideoTransformer3DModel,
-    prompt_embeds: torch.Tensor,
-    pooled_prompt_embeds: torch.Tensor,
-    prompt_attention_mask: torch.Tensor,
-    guidance: torch.Tensor,
-    latents: torch.Tensor,
-    noisy_latents: torch.Tensor,
-    timesteps: torch.LongTensor,
-    **kwargs,
-) -> torch.Tensor:
-    denoised_latents = transformer(
-        hidden_states=noisy_latents,
-        timestep=timesteps,
-        encoder_hidden_states=prompt_embeds,
-        pooled_projections=pooled_prompt_embeds,
-        encoder_attention_mask=prompt_attention_mask,
-        guidance=guidance,
-        return_dict=False,
-    )[0]
-
-    return {"latents": denoised_latents}
-
-
-def validation(
-    pipeline: HunyuanVideoPipeline,
-    prompt: str,
-    image: Optional[Image.Image] = None,
-    video: Optional[List[Image.Image]] = None,
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-    num_frames: Optional[int] = None,
-    num_videos_per_prompt: int = 1,
-    generator: Optional[torch.Generator] = None,
-    **kwargs,
-):
-    generation_kwargs = {
-        "prompt": prompt,
-        "height": height,
-        "width": width,
-        "num_frames": num_frames,
-        "num_inference_steps": 30,
-        "num_videos_per_prompt": num_videos_per_prompt,
-        "generator": generator,
-        "return_dict": True,
-        "output_type": "pil",
-    }
-    generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
-    output = pipeline(**generation_kwargs).frames[0]
-    return [("video", output)]
-
-
-def _get_llama_prompt_embeds(
-    tokenizer: LlamaTokenizer,
-    text_encoder: LlamaModel,
-    prompt: List[str],
-    prompt_template: Dict[str, Any],
-    device: torch.device,
-    dtype: torch.dtype,
-    max_sequence_length: int = 256,
-    num_hidden_layers_to_skip: int = 2,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    batch_size = len(prompt)
-    prompt = [prompt_template["template"].format(p) for p in prompt]
-
-    crop_start = prompt_template.get("crop_start", None)
-    if crop_start is None:
-        prompt_template_input = tokenizer(
-            prompt_template["template"],
-            padding="max_length",
-            return_tensors="pt",
-            return_length=False,
-            return_overflowing_tokens=False,
-            return_attention_mask=False,
-        )
-        crop_start = prompt_template_input["input_ids"].shape[-1]
-        # Remove <|eot_id|> token and placeholder {}
-        crop_start -= 2
-
-    max_sequence_length += crop_start
-    text_inputs = tokenizer(
-        prompt,
-        max_length=max_sequence_length,
-        padding="max_length",
-        truncation=True,
-        return_tensors="pt",
-        return_length=False,
-        return_overflowing_tokens=False,
-        return_attention_mask=True,
-    )
-    text_input_ids = text_inputs.input_ids.to(device=device)
-    prompt_attention_mask = text_inputs.attention_mask.to(device=device)
-
-    prompt_embeds = text_encoder(
-        input_ids=text_input_ids,
-        attention_mask=prompt_attention_mask,
-        output_hidden_states=True,
-    ).hidden_states[-(num_hidden_layers_to_skip + 1)]
-    prompt_embeds = prompt_embeds.to(dtype=dtype)
-
-    if crop_start is not None and crop_start > 0:
-        prompt_embeds = prompt_embeds[:, crop_start:]
-        prompt_attention_mask = prompt_attention_mask[:, crop_start:]
-
-    prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
-
-    return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}
-
-
-def _get_clip_prompt_embeds(
-    tokenizer_2: CLIPTokenizer,
-    text_encoder_2: CLIPTextModel,
-    prompt: Union[str, List[str]],
-    device: torch.device,
-    dtype: torch.dtype,
-    max_sequence_length: int = 77,
-) -> torch.Tensor:
-    text_inputs = tokenizer_2(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-
-    prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output
-    prompt_embeds = prompt_embeds.to(dtype=dtype)
-
-    return {"pooled_prompt_embeds": prompt_embeds}
-
-
-# TODO(aryan): refactor into model specs for better re-use
-HUNYUAN_VIDEO_T2V_LORA_CONFIG = {
-    "pipeline_cls": HunyuanVideoPipeline,
-    "load_condition_models": load_condition_models,
-    "load_latent_models": load_latent_models,
-    "load_diffusion_models": load_diffusion_models,
-    "initialize_pipeline": initialize_pipeline,
-    "prepare_conditions": prepare_conditions,
-    "prepare_latents": prepare_latents,
-    "post_latent_preparation": post_latent_preparation,
-    "collate_fn": collate_fn_t2v,
-    "forward_pass": forward_pass,
-    "validation": validation,
-}
diff --git a/finetrainers/models/ltx_video/__init__.py b/finetrainers/models/ltx_video/__init__.py
index 69391cdf9157831343a5fb73b237de618e8288bd..ff4e3550d54bb33fac80dd2d075ad2846eeeed46 100644
--- a/finetrainers/models/ltx_video/__init__.py
+++ b/finetrainers/models/ltx_video/__init__.py
@@ -1,2 +1 @@
-from .full_finetune import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG
-from .lora import LTX_VIDEO_T2V_LORA_CONFIG
+from .base_specification import LTXVideoModelSpecification
diff --git a/finetrainers/models/ltx_video/base_specification.py b/finetrainers/models/ltx_video/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..37f6ceced22e53a5383c1fb42c7f6b4d2dc08ad2
--- /dev/null
+++ b/finetrainers/models/ltx_video/base_specification.py
@@ -0,0 +1,522 @@
+import os
+import random
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+    AutoencoderKLLTXVideo,
+    FlowMatchEulerDiscreteScheduler,
+    LTXImageToVideoPipeline,
+    LTXPipeline,
+    LTXVideoTransformer3DModel,
+)
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from PIL.Image import Image
+from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
+
+from ... import data
+from ... import functional as FF
+from ...logging import get_logger
+from ...parallel import ParallelBackendEnum
+from ...processors import ProcessorMixin, T5Processor
+from ...typing import ArtifactType, SchedulerType
+from ...utils import get_non_null_items
+from ..modeling_utils import ModelSpecification
+
+
+logger = get_logger()
+
+
+class LTXLatentEncodeProcessor(ProcessorMixin):
+    r"""
+    Processor to encode image/video into latents using the LTX VAE.
+
+    Args:
+        output_names (`List[str]`):
+            The names of the outputs that the processor returns. The outputs are in the following order:
+            - latents: The latents of the input image/video.
+            - num_frames: The number of frames in the input video.
+            - height: The height of the input image/video.
+            - width: The width of the input image/video.
+            - latents_mean: The latent channel means from the VAE state dict.
+            - latents_std: The latent channel standard deviations from the VAE state dict.
+    """
+
+    def __init__(self, output_names: List[str]):
+        super().__init__()
+        self.output_names = output_names
+        assert len(self.output_names) == 6
+
+    def forward(
+        self,
+        vae: AutoencoderKLLTXVideo,
+        image: Optional[torch.Tensor] = None,
+        video: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+    ) -> Dict[str, torch.Tensor]:
+        device = vae.device
+        dtype = vae.dtype
+
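+        # A single conditioning image is treated as a one-frame video so the same VAE encode path can be reused.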
+        if image is not None:
+            video = image.unsqueeze(1)
+
+        assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
+        video = video.to(device=device, dtype=vae.dtype)
+        video = video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
+
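+        # When `compute_posterior` is False, the raw VAE moments (concatenated mean and log-variance) are returned
+        # instead of a sampled latent, so sampling from the posterior can be deferred (e.g. when precomputing latents).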
+        if compute_posterior:
+            latents = vae.encode(video).latent_dist.sample(generator=generator)
+            latents = latents.to(dtype=dtype)
+        else:
+            if vae.use_slicing and video.shape[0] > 1:
+                encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
+                moments = torch.cat(encoded_slices)
+            else:
+                moments = vae._encode(video)
+            latents = moments.to(dtype=dtype)
+
+        _, _, num_frames, height, width = latents.shape
+
+        return {
+            self.output_names[0]: latents,
+            self.output_names[1]: num_frames,
+            self.output_names[2]: height,
+            self.output_names[3]: width,
+            self.output_names[4]: vae.latents_mean,
+            self.output_names[5]: vae.latents_std,
+        }
+
+
+class LTXVideoModelSpecification(ModelSpecification):
+    def __init__(
+        self,
+        pretrained_model_name_or_path: str = "Lightricks/LTX-Video",
+        tokenizer_id: Optional[str] = None,
+        text_encoder_id: Optional[str] = None,
+        transformer_id: Optional[str] = None,
+        vae_id: Optional[str] = None,
+        text_encoder_dtype: torch.dtype = torch.bfloat16,
+        transformer_dtype: torch.dtype = torch.bfloat16,
+        vae_dtype: torch.dtype = torch.bfloat16,
+        revision: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        condition_model_processors: List[ProcessorMixin] = None,
+        latent_model_processors: List[ProcessorMixin] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            tokenizer_id=tokenizer_id,
+            text_encoder_id=text_encoder_id,
+            transformer_id=transformer_id,
+            vae_id=vae_id,
+            text_encoder_dtype=text_encoder_dtype,
+            transformer_dtype=transformer_dtype,
+            vae_dtype=vae_dtype,
+            revision=revision,
+            cache_dir=cache_dir,
+        )
+
+        if condition_model_processors is None:
+            condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])]
+        if latent_model_processors is None:
+            latent_model_processors = [
+                LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
+            ]
+
+        self.condition_model_processors = condition_model_processors
+        self.latent_model_processors = latent_model_processors
+
+    @property
+    def _resolution_dim_keys(self):
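+        # Dims (2, 3, 4) of the [B, C, F, H, W] latents are the frame, height and width axes, i.e. the
+        # resolution-carrying dimensions.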
+        return {
+            "latents": (2, 3, 4),
+        }
+
+    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        if self.tokenizer_id is not None:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
+            )
+        else:
+            tokenizer = T5Tokenizer.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="tokenizer",
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        if self.text_encoder_id is not None:
+            text_encoder = AutoModel.from_pretrained(
+                self.text_encoder_id,
+                torch_dtype=self.text_encoder_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            text_encoder = T5EncoderModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder",
+                torch_dtype=self.text_encoder_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        return {"tokenizer": tokenizer, "text_encoder": text_encoder}
+
+    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        if self.vae_id is not None:
+            vae = AutoencoderKLLTXVideo.from_pretrained(
+                self.vae_id,
+                torch_dtype=self.vae_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            vae = AutoencoderKLLTXVideo.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="vae",
+                torch_dtype=self.vae_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        return {"vae": vae}
+
+    def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+        if self.transformer_id is not None:
+            transformer = LTXVideoTransformer3DModel.from_pretrained(
+                self.transformer_id,
+                torch_dtype=self.transformer_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            transformer = LTXVideoTransformer3DModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="transformer",
+                torch_dtype=self.transformer_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {"transformer": transformer, "scheduler": scheduler}
+
+    def load_pipeline(
+        self,
+        tokenizer: Optional[T5Tokenizer] = None,
+        text_encoder: Optional[T5EncoderModel] = None,
+        transformer: Optional[LTXVideoTransformer3DModel] = None,
+        vae: Optional[AutoencoderKLLTXVideo] = None,
+        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+        enable_slicing: bool = False,
+        enable_tiling: bool = False,
+        enable_model_cpu_offload: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> LTXPipeline:
+        components = {
+            "tokenizer": tokenizer,
+            "text_encoder": text_encoder,
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+        }
+        components = get_non_null_items(components)
+
+        pipe = LTXPipeline.from_pretrained(
+            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+        )
+        pipe.text_encoder.to(self.text_encoder_dtype)
+        pipe.vae.to(self.vae_dtype)
+
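+        # The transformer is expected to already be in the correct dtype during training (casting it here can break
+        # fp8 layerwise-upcasting hooks), so it is only cast for inference.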
+        if not training:
+            pipe.transformer.to(self.transformer_dtype)
+
+        if enable_slicing:
+            pipe.vae.enable_slicing()
+        if enable_tiling:
+            pipe.vae.enable_tiling()
+        if enable_model_cpu_offload:
+            pipe.enable_model_cpu_offload()
+
+        return pipe
+
+    @torch.no_grad()
+    def prepare_conditions(
+        self,
+        tokenizer: T5Tokenizer,
+        text_encoder: T5EncoderModel,
+        caption: str,
+        max_sequence_length: int = 128,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        conditions = {
+            "tokenizer": tokenizer,
+            "text_encoder": text_encoder,
+            "caption": caption,
+            "max_sequence_length": max_sequence_length,
+            **kwargs,
+        }
+        input_keys = set(conditions.keys())
+        conditions = super().prepare_conditions(**conditions)
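+        # Keep only the outputs produced by the condition processors; the raw inputs (tokenizer, text encoder,
+        # caption, ...) are dropped before returning.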
+        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+        return conditions
+
+    @torch.no_grad()
+    def prepare_latents(
+        self,
+        vae: AutoencoderKLLTXVideo,
+        image: Optional[torch.Tensor] = None,
+        video: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        conditions = {
+            "vae": vae,
+            "image": image,
+            "video": video,
+            "generator": generator,
+            "compute_posterior": compute_posterior,
+            **kwargs,
+        }
+        input_keys = set(conditions.keys())
+        conditions = super().prepare_latents(**conditions)
+        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+        return conditions
+
+    def forward(
+        self,
+        transformer: LTXVideoTransformer3DModel,
+        condition_model_conditions: Dict[str, torch.Tensor],
+        latent_model_conditions: Dict[str, torch.Tensor],
+        sigmas: torch.Tensor,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, ...]:
+        # TODO(aryan): make this configurable? Should it be?
+        first_frame_conditioning_p = 0.1
+        min_first_frame_sigma = 0.25
+
+        if compute_posterior:
+            latents = latent_model_conditions.pop("latents")
+        else:
+            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+            latents = posterior.sample(generator=generator)
+            del posterior
+
+        latents_mean = latent_model_conditions.pop("latents_mean")
+        latents_std = latent_model_conditions.pop("latents_std")
+
+        latents = self._normalize_latents(latents, latents_mean, latents_std)
+        noise = torch.zeros_like(latents).normal_(generator=generator)
+
+        if random.random() < first_frame_conditioning_p:
+            # Section 2.4 of the paper mentions that the first-frame timesteps should be a small random value, so we
+            # cap the first-frame sigma at `min_first_frame_sigma`.
+            # torch.rand_like returns values in [0, 1). For image conditioning we want the first-frame sigma to be <=
+            # the actual sigmas, so we rescale by multiplying with sigmas to get the range [0, sigmas).
+            first_frame_sigma = torch.rand_like(sigmas) * sigmas
+            first_frame_sigma = torch.min(first_frame_sigma, sigmas.new_full(sigmas.shape, min_first_frame_sigma))
+
+            latents_first_frame, latents_rest = latents[:, :, :1], latents[:, :, 1:]
+            noisy_latents_first_frame = FF.flow_match_xt(latents_first_frame, noise[:, :, :1], first_frame_sigma)
+            noisy_latents_remaining = FF.flow_match_xt(latents_rest, noise[:, :, 1:], sigmas)
+            noisy_latents = torch.cat([noisy_latents_first_frame, noisy_latents_remaining], dim=2)
+        else:
+            noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+
+        patch_size = self.transformer_config.patch_size
+        patch_size_t = self.transformer_config.patch_size_t
+
+        latents = self._pack_latents(latents, patch_size, patch_size_t)
+        noise = self._pack_latents(noise, patch_size, patch_size_t)
+        noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t)
+
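+        # Broadcast the per-sample sigmas to [B, seq_len, 1] so they line up with the packed (ndim=3) latents; the
+        # timesteps below are derived from these expanded sigmas.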
+        sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
+
+        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+        condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
+        condition_model_conditions["encoder_attention_mask"] = condition_model_conditions.pop("prompt_attention_mask")
+
+        # TODO(aryan): make this configurable
+        frame_rate = 25
+        temporal_compression_ratio = 8
+        vae_spatial_compression_ratio = 32
+        latent_frame_rate = frame_rate / temporal_compression_ratio
+
+        rope_interpolation_scale = [
+            1 / latent_frame_rate,
+            vae_spatial_compression_ratio,
+            vae_spatial_compression_ratio,
+        ]
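+        # Flow-matching sigmas lie in [0, 1]; scale them to a 0-1000 range for the transformer's timestep input.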
+        timesteps = (sigmas * 1000.0).long()
+
+        pred = transformer(
+            **latent_model_conditions,
+            **condition_model_conditions,
+            timestep=timesteps,
+            rope_interpolation_scale=rope_interpolation_scale,
+            return_dict=False,
+        )[0]
+        target = FF.flow_match_target(noise, latents)
+
+        return pred, target, sigmas
+
+    def validation(
+        self,
+        pipeline: LTXPipeline,
+        prompt: str,
+        image: Optional[Image] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        frame_rate: int = 25,
+        num_inference_steps: int = 50,
+        generator: Optional[torch.Generator] = None,
+        **kwargs,
+    ) -> List[ArtifactType]:
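+        # If an image condition is provided, reuse the already-loaded components as an image-to-video pipeline.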
+        if image is not None:
+            pipeline = LTXImageToVideoPipeline.from_pipe(pipeline)
+
+        generation_kwargs = {
+            "prompt": prompt,
+            "image": image,
+            "height": height,
+            "width": width,
+            "num_frames": num_frames,
+            "frame_rate": frame_rate,
+            "num_inference_steps": num_inference_steps,
+            "generator": generator,
+            "return_dict": True,
+            "output_type": "pil",
+        }
+        generation_kwargs = get_non_null_items(generation_kwargs)
+        video = pipeline(**generation_kwargs).frames[0]
+        return [data.VideoArtifact(value=video)]
+
+    def _save_lora_weights(
+        self,
+        directory: str,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        # TODO(aryan): this needs refactoring
+        if transformer_state_dict is not None:
+            LTXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
+        if scheduler is not None:
+            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    def _save_model(
+        self,
+        directory: str,
+        transformer: LTXVideoTransformer3DModel,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+    ) -> None:
+        # TODO(aryan): this needs refactoring
+        if transformer_state_dict is not None:
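+            # Instantiate a copy of the transformer on the meta device and load the provided state dict into it with
+            # assign=True, so the full model can be serialized without materializing a second copy of the weights.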
+            with init_empty_weights():
+                transformer_copy = LTXVideoTransformer3DModel.from_config(transformer.config)
+            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+        if scheduler is not None:
+            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    def apply_tensor_parallel(
+        self,
+        backend: ParallelBackendEnum,
+        device_mesh: torch.distributed.DeviceMesh,
+        transformer: LTXVideoTransformer3DModel,
+        **kwargs,
+    ) -> None:
+        if backend == ParallelBackendEnum.PTD:
+            _apply_tensor_parallel_ptd(device_mesh, transformer)
+        else:
+            raise NotImplementedError(f"Parallel backend {backend} is not supported for LTXVideoModelSpecification")
+
+    @staticmethod
+    def _normalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+    ) -> torch.Tensor:
+        # Normalize latents across the channel dimension [B, C, F, H, W]
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+        latents = (latents - latents_mean) * scaling_factor / latents_std
+        return latents
+
+    @staticmethod
+    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features.
+        batch_size, num_channels, num_frames, height, width = latents.shape
+        post_patch_num_frames = num_frames // patch_size_t
+        post_patch_height = height // patch_size
+        post_patch_width = width // patch_size
+        latents = latents.reshape(
+            batch_size,
+            -1,
+            post_patch_num_frames,
+            patch_size_t,
+            post_patch_height,
+            patch_size,
+            post_patch_width,
+            patch_size,
+        )
+        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+        return latents
+
+
+def _apply_tensor_parallel_ptd(
+    device_mesh: torch.distributed.device_mesh.DeviceMesh, transformer: LTXVideoTransformer3DModel
+) -> None:
+    from torch.distributed.tensor.parallel import parallelize_module
+    from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel
+
+    transformer_plan = {
+        # ===== Condition embeddings =====
+        # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(),
+        # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)),
+        # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()),
+        # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())),
+        # "caption_projection.linear_1": ColwiseParallel(),
+        # "caption_projection.linear_2": RowwiseParallel(),
+        # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
+        # ===== =====
+    }
+
+    for block in transformer.transformer_blocks:
+        block_plan = {}
+
+        # ===== Attention =====
+        # 8 all-to-all, 3 all-reduce
+        # block_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False)
+        # block_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False)
+        # block_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False)
+        # block_plan["attn1.norm_q"] = SequenceParallel()
+        # block_plan["attn1.norm_k"] = SequenceParallel()
+        # block_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
+        # block_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False)
+        # block_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False)
+        # block_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False)
+        # block_plan["attn2.norm_q"] = SequenceParallel()
+        # block_plan["attn2.norm_k"] = SequenceParallel()
+        # block_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
+        # ===== =====
+
+        block_plan["ff.net.0.proj"] = ColwiseParallel()
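+        # Only the feed-forward projections are sharded here: ColwiseParallel splits the up-projection and
+        # RowwiseParallel the down-projection, so the intermediate activation stays sharded across the TP mesh.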
+        block_plan["ff.net.2"] = RowwiseParallel()
+
+        parallelize_module(block, device_mesh, block_plan)
+
+    parallelize_module(transformer, device_mesh, transformer_plan)
diff --git a/finetrainers/models/ltx_video/full_finetune.py b/finetrainers/models/ltx_video/full_finetune.py
deleted file mode 100644
index ca799ea6f1b4b075efa9f2c27bb69564832bcb7d..0000000000000000000000000000000000000000
--- a/finetrainers/models/ltx_video/full_finetune.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from diffusers import LTXPipeline
-
-from .lora import (
-    collate_fn_t2v,
-    forward_pass,
-    initialize_pipeline,
-    load_condition_models,
-    load_diffusion_models,
-    load_latent_models,
-    post_latent_preparation,
-    prepare_conditions,
-    prepare_latents,
-    validation,
-)
-
-
-# TODO(aryan): refactor into model specs for better re-use
-LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG = {
-    "pipeline_cls": LTXPipeline,
-    "load_condition_models": load_condition_models,
-    "load_latent_models": load_latent_models,
-    "load_diffusion_models": load_diffusion_models,
-    "initialize_pipeline": initialize_pipeline,
-    "prepare_conditions": prepare_conditions,
-    "prepare_latents": prepare_latents,
-    "post_latent_preparation": post_latent_preparation,
-    "collate_fn": collate_fn_t2v,
-    "forward_pass": forward_pass,
-    "validation": validation,
-}
diff --git a/finetrainers/models/ltx_video/lora.py b/finetrainers/models/ltx_video/lora.py
deleted file mode 100644
index bdd6ffa3e3b91564ff88222a2314f66c6e465116..0000000000000000000000000000000000000000
--- a/finetrainers/models/ltx_video/lora.py
+++ /dev/null
@@ -1,331 +0,0 @@
-from typing import Dict, List, Optional, Union
-
-import torch
-import torch.nn as nn
-from accelerate.logging import get_logger
-from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
-from PIL import Image
-from transformers import T5EncoderModel, T5Tokenizer
-
-
-logger = get_logger("finetrainers")  # pylint: disable=invalid-name
-
-
-def load_condition_models(
-    model_id: str = "Lightricks/LTX-Video",
-    text_encoder_dtype: torch.dtype = torch.bfloat16,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    **kwargs,
-) -> Dict[str, nn.Module]:
-    tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
-    text_encoder = T5EncoderModel.from_pretrained(
-        model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
-    )
-    return {"tokenizer": tokenizer, "text_encoder": text_encoder}
-
-
-def load_latent_models(
-    model_id: str = "Lightricks/LTX-Video",
-    vae_dtype: torch.dtype = torch.bfloat16,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    **kwargs,
-) -> Dict[str, nn.Module]:
-    vae = AutoencoderKLLTXVideo.from_pretrained(
-        model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
-    )
-    return {"vae": vae}
-
-
-def load_diffusion_models(
-    model_id: str = "Lightricks/LTX-Video",
-    transformer_dtype: torch.dtype = torch.bfloat16,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    **kwargs,
-) -> Dict[str, nn.Module]:
-    transformer = LTXVideoTransformer3DModel.from_pretrained(
-        model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
-    )
-    scheduler = FlowMatchEulerDiscreteScheduler()
-    return {"transformer": transformer, "scheduler": scheduler}
-
-
-def initialize_pipeline(
-    model_id: str = "Lightricks/LTX-Video",
-    text_encoder_dtype: torch.dtype = torch.bfloat16,
-    transformer_dtype: torch.dtype = torch.bfloat16,
-    vae_dtype: torch.dtype = torch.bfloat16,
-    tokenizer: Optional[T5Tokenizer] = None,
-    text_encoder: Optional[T5EncoderModel] = None,
-    transformer: Optional[LTXVideoTransformer3DModel] = None,
-    vae: Optional[AutoencoderKLLTXVideo] = None,
-    scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
-    device: Optional[torch.device] = None,
-    revision: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    enable_slicing: bool = False,
-    enable_tiling: bool = False,
-    enable_model_cpu_offload: bool = False,
-    is_training: bool = False,
-    **kwargs,
-) -> LTXPipeline:
-    component_name_pairs = [
-        ("tokenizer", tokenizer),
-        ("text_encoder", text_encoder),
-        ("transformer", transformer),
-        ("vae", vae),
-        ("scheduler", scheduler),
-    ]
-    components = {}
-    for name, component in component_name_pairs:
-        if component is not None:
-            components[name] = component
-
-    pipe = LTXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
-    pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
-    pipe.vae = pipe.vae.to(dtype=vae_dtype)
-    # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
-    # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
-    # DDP optimizer step.
-    if not is_training:
-        pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
-
-    if enable_slicing:
-        pipe.vae.enable_slicing()
-    if enable_tiling:
-        pipe.vae.enable_tiling()
-
-    if enable_model_cpu_offload:
-        pipe.enable_model_cpu_offload(device=device)
-    else:
-        pipe.to(device=device)
-
-    return pipe
-
-
-def prepare_conditions(
-    tokenizer: T5Tokenizer,
-    text_encoder: T5EncoderModel,
-    prompt: Union[str, List[str]],
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-    max_sequence_length: int = 128,
-    **kwargs,
-) -> torch.Tensor:
-    device = device or text_encoder.device
-    dtype = dtype or text_encoder.dtype
-
-    if isinstance(prompt, str):
-        prompt = [prompt]
-
-    return _encode_prompt_t5(tokenizer, text_encoder, prompt, device, dtype, max_sequence_length)
-
-
-def prepare_latents(
-    vae: AutoencoderKLLTXVideo,
-    image_or_video: torch.Tensor,
-    patch_size: int = 1,
-    patch_size_t: int = 1,
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-    generator: Optional[torch.Generator] = None,
-    precompute: bool = False,
-) -> torch.Tensor:
-    device = device or vae.device
-
-    if image_or_video.ndim == 4:
-        image_or_video = image_or_video.unsqueeze(2)
-    assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
-
-    image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
-    image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous()  # [B, C, F, H, W] -> [B, F, C, H, W]
-    if not precompute:
-        latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
-        latents = latents.to(dtype=dtype)
-        _, _, num_frames, height, width = latents.shape
-        latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std)
-        latents = _pack_latents(latents, patch_size, patch_size_t)
-        return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}
-    else:
-        if vae.use_slicing and image_or_video.shape[0] > 1:
-            encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
-            h = torch.cat(encoded_slices)
-        else:
-            h = vae._encode(image_or_video)
-        _, _, num_frames, height, width = h.shape
-
-        # TODO(aryan): This is very stupid that we might possibly be storing the latents_mean and latents_std in every file
-        # if precomputation is enabled. We should probably have a single file where re-usable properties like this are stored
-        # so as to reduce the disk memory requirements of the precomputed files.
-        return {
-            "latents": h,
-            "num_frames": num_frames,
-            "height": height,
-            "width": width,
-            "latents_mean": vae.latents_mean,
-            "latents_std": vae.latents_std,
-        }
-
-
-def post_latent_preparation(
-    latents: torch.Tensor,
-    latents_mean: torch.Tensor,
-    latents_std: torch.Tensor,
-    num_frames: int,
-    height: int,
-    width: int,
-    patch_size: int = 1,
-    patch_size_t: int = 1,
-    **kwargs,
-) -> torch.Tensor:
-    latents = _normalize_latents(latents, latents_mean, latents_std)
-    latents = _pack_latents(latents, patch_size, patch_size_t)
-    return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}
-
-
-def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
-    return {
-        "prompts": [x["prompt"] for x in batch[0]],
-        "videos": torch.stack([x["video"] for x in batch[0]]),
-    }
-
-
-def forward_pass(
-    transformer: LTXVideoTransformer3DModel,
-    prompt_embeds: torch.Tensor,
-    prompt_attention_mask: torch.Tensor,
-    latents: torch.Tensor,
-    noisy_latents: torch.Tensor,
-    timesteps: torch.LongTensor,
-    num_frames: int,
-    height: int,
-    width: int,
-    **kwargs,
-) -> torch.Tensor:
-    # TODO(aryan): make configurable
-    frame_rate = 25
-    latent_frame_rate = frame_rate / 8
-    spatial_compression_ratio = 32
-    rope_interpolation_scale = [1 / latent_frame_rate, spatial_compression_ratio, spatial_compression_ratio]
-
-    denoised_latents = transformer(
-        hidden_states=noisy_latents,
-        encoder_hidden_states=prompt_embeds,
-        timestep=timesteps,
-        encoder_attention_mask=prompt_attention_mask,
-        num_frames=num_frames,
-        height=height,
-        width=width,
-        rope_interpolation_scale=rope_interpolation_scale,
-        return_dict=False,
-    )[0]
-
-    return {"latents": denoised_latents}
-
-
-def validation(
-    pipeline: LTXPipeline,
-    prompt: str,
-    image: Optional[Image.Image] = None,
-    video: Optional[List[Image.Image]] = None,
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-    num_frames: Optional[int] = None,
-    frame_rate: int = 24,
-    num_videos_per_prompt: int = 1,
-    generator: Optional[torch.Generator] = None,
-    **kwargs,
-):
-    generation_kwargs = {
-        "prompt": prompt,
-        "height": height,
-        "width": width,
-        "num_frames": num_frames,
-        "frame_rate": frame_rate,
-        "num_videos_per_prompt": num_videos_per_prompt,
-        "generator": generator,
-        "return_dict": True,
-        "output_type": "pil",
-    }
-    generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
-    video = pipeline(**generation_kwargs).frames[0]
-    return [("video", video)]
-
-
-def _encode_prompt_t5(
-    tokenizer: T5Tokenizer,
-    text_encoder: T5EncoderModel,
-    prompt: List[str],
-    device: torch.device,
-    dtype: torch.dtype,
-    max_sequence_length,
-) -> torch.Tensor:
-    batch_size = len(prompt)
-
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        add_special_tokens=True,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
-    prompt_attention_mask = text_inputs.attention_mask
-    prompt_attention_mask = prompt_attention_mask.bool().to(device)
-
-    prompt_embeds = text_encoder(text_input_ids.to(device))[0]
-    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-    prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
-
-    return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}
-
-
-def _normalize_latents(
-    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
-) -> torch.Tensor:
-    # Normalize latents across the channel dimension [B, C, F, H, W]
-    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
-    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
-    latents = (latents - latents_mean) * scaling_factor / latents_std
-    return latents
-
-
-def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
-    # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
-    # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
-    # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
-    # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
-    batch_size, num_channels, num_frames, height, width = latents.shape
-    post_patch_num_frames = num_frames // patch_size_t
-    post_patch_height = height // patch_size
-    post_patch_width = width // patch_size
-    latents = latents.reshape(
-        batch_size,
-        -1,
-        post_patch_num_frames,
-        patch_size_t,
-        post_patch_height,
-        patch_size,
-        post_patch_width,
-        patch_size,
-    )
-    latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
-    return latents
-
-
-LTX_VIDEO_T2V_LORA_CONFIG = {
-    "pipeline_cls": LTXPipeline,
-    "load_condition_models": load_condition_models,
-    "load_latent_models": load_latent_models,
-    "load_diffusion_models": load_diffusion_models,
-    "initialize_pipeline": initialize_pipeline,
-    "prepare_conditions": prepare_conditions,
-    "prepare_latents": prepare_latents,
-    "post_latent_preparation": post_latent_preparation,
-    "collate_fn": collate_fn_t2v,
-    "forward_pass": forward_pass,
-    "validation": validation,
-}
diff --git a/finetrainers/models/modeling_utils.py b/finetrainers/models/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..805b0d94e30c9d56157634740672b27f642993c4
--- /dev/null
+++ b/finetrainers/models/modeling_utils.py
@@ -0,0 +1,292 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.configuration_utils import FrozenDict
+from PIL.Image import Image
+
+from ..logging import get_logger
+from ..parallel import ParallelBackendEnum
+from ..processors import ProcessorMixin
+from ..typing import ArtifactType, SchedulerType, TokenizerType
+from ..utils import resolve_component_cls
+
+
+logger = get_logger()
+
+# TODO(aryan): we most likely don't need this. take a look after refactoring more
+# fmt: off
+IGNORE_KEYS_FOR_COLLATION = {"height", "width", "num_frames", "frame_rate", "rope_interpolation_scale", "return_dict", "attention_kwargs", "cross_attention_kwargs", "joint_attention_kwargs", "latents_mean", "latents_std"}
+# fmt: on
+
+
+class ModelSpecification:
+    r"""
+    The ModelSpecification class is an interface to be used for Diffusion training recipes. It provides a
+    loose structure for organizing the training code. The trainer implementations make use of this
+    interface to load models, prepare conditions, prepare latents, run the forward pass, etc.
+    """
+
+    def __init__(
+        self,
+        pretrained_model_name_or_path: Optional[str] = None,
+        tokenizer_id: Optional[str] = None,
+        tokenizer_2_id: Optional[str] = None,
+        tokenizer_3_id: Optional[str] = None,
+        text_encoder_id: Optional[str] = None,
+        text_encoder_2_id: Optional[str] = None,
+        text_encoder_3_id: Optional[str] = None,
+        transformer_id: Optional[str] = None,
+        vae_id: Optional[str] = None,
+        text_encoder_dtype: torch.dtype = torch.bfloat16,
+        text_encoder_2_dtype: torch.dtype = torch.bfloat16,
+        text_encoder_3_dtype: torch.dtype = torch.bfloat16,
+        transformer_dtype: torch.dtype = torch.bfloat16,
+        vae_dtype: torch.dtype = torch.bfloat16,
+        revision: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        condition_model_processors: List[ProcessorMixin] = None,
+        latent_model_processors: List[ProcessorMixin] = None,
+    ) -> None:
+        self.pretrained_model_name_or_path = pretrained_model_name_or_path
+        self.tokenizer_id = tokenizer_id
+        self.tokenizer_2_id = tokenizer_2_id
+        self.tokenizer_3_id = tokenizer_3_id
+        self.text_encoder_id = text_encoder_id
+        self.text_encoder_2_id = text_encoder_2_id
+        self.text_encoder_3_id = text_encoder_3_id
+        self.transformer_id = transformer_id
+        self.vae_id = vae_id
+        self.text_encoder_dtype = text_encoder_dtype
+        self.text_encoder_2_dtype = text_encoder_2_dtype
+        self.text_encoder_3_dtype = text_encoder_3_dtype
+        self.transformer_dtype = transformer_dtype
+        self.vae_dtype = vae_dtype
+        self.revision = revision
+        self.cache_dir = cache_dir
+        self.condition_model_processors = condition_model_processors or []
+        self.latent_model_processors = latent_model_processors or []
+
+        self.transformer_config: Dict[str, Any] = None
+        self.vae_config: Dict[str, Any] = None
+
+        self._load_configs()
+
+    # TODO(aryan): revisit how to do this better without user having to worry about it
+    @property
+    def _resolution_dim_keys(self) -> Dict[str, Tuple[int, ...]]:
+        raise NotImplementedError(
+            f"ModelSpecification::_resolution_dim_keys is not implemented for {self.__class__.__name__}"
+        )
+
+    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        raise NotImplementedError(
+            f"ModelSpecification::load_condition_models is not implemented for {self.__class__.__name__}"
+        )
+
+    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        raise NotImplementedError(
+            f"ModelSpecification::load_latent_models is not implemented for {self.__class__.__name__}"
+        )
+
+    def load_diffusion_models(self) -> Dict[str, Union[torch.nn.Module, SchedulerType]]:
+        raise NotImplementedError(
+            f"ModelSpecification::load_diffusion_models is not implemented for {self.__class__.__name__}"
+        )
+
+    def load_pipeline(
+        self,
+        tokenizer: Optional[TokenizerType] = None,
+        tokenizer_2: Optional[TokenizerType] = None,
+        tokenizer_3: Optional[TokenizerType] = None,
+        text_encoder: Optional[torch.nn.Module] = None,
+        text_encoder_2: Optional[torch.nn.Module] = None,
+        text_encoder_3: Optional[torch.nn.Module] = None,
+        transformer: Optional[torch.nn.Module] = None,
+        vae: Optional[torch.nn.Module] = None,
+        scheduler: Optional[SchedulerType] = None,
+        enable_slicing: bool = False,
+        enable_tiling: bool = False,
+        enable_model_cpu_offload: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> DiffusionPipeline:
+        raise NotImplementedError(
+            f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}"
+        )
+
+    def collate_fn(self, batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        raise NotImplementedError(f"ModelSpecification::collate_fn is not implemented for {self.__class__.__name__}")
+
+    def prepare_conditions(self, **kwargs) -> Dict[str, Any]:
+        for processor in self.condition_model_processors:
+            result = processor(**kwargs)
+            result_keys = set(result.keys())
+            repeat_keys = result_keys.intersection(kwargs.keys())
+            if repeat_keys:
+                logger.warning(
+                    f"Processor {processor.__class__.__name__} returned keys that already exist in "
+                    f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
+                    f"be intended. Please rename the keys in the processor to avoid conflicts."
+                )
+            kwargs.update(result)
+        return kwargs
+
+    def prepare_latents(self, **kwargs) -> Dict[str, Any]:
+        for processor in self.latent_model_processors:
+            result = processor(**kwargs)
+            result_keys = set(result.keys())
+            repeat_keys = result_keys.intersection(kwargs.keys())
+            if repeat_keys:
+                logger.warning(
+                    f"Processor {processor.__class__.__name__} returned keys that already exist in "
+                    f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
+                    f"be intended. Please rename the keys in the processor to avoid conflicts."
+                )
+            kwargs.update(result)
+        return kwargs
+
+    def collate_conditions(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+        keys = list(data[0].keys())
+        collated_data = {}
+        for key in keys:
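+            # Keys listed in IGNORE_KEYS_FOR_COLLATION hold metadata (resolutions, frame rate, VAE statistics, ...)
+            # assumed identical across samples, so the value from the first sample is used instead of concatenating.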
+            if key in IGNORE_KEYS_FOR_COLLATION:
+                collated_data[key] = data[0][key]
+                continue
+            collated_d = [d[key] for d in data]
+            if isinstance(collated_d[0], torch.Tensor):
+                collated_d = torch.cat(collated_d)
+            collated_data[key] = collated_d
+        return collated_data
+
+    def collate_latents(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+        keys = list(data[0].keys())
+        collated_data = {}
+        for key in keys:
+            if key in IGNORE_KEYS_FOR_COLLATION:
+                collated_data[key] = data[0][key]
+                continue
+            collated_d = [d[key] for d in data]
+            # TODO(aryan): Support multi-resolution collation
+            if isinstance(collated_d[0], torch.Tensor):
+                collated_d = torch.cat(collated_d)
+            collated_data[key] = collated_d
+        return collated_data
+
+    def forward(
+        self, transformer: torch.nn.Module, generator: Optional[torch.Generator] = None, **kwargs
+    ) -> Dict[str, torch.Tensor]:
+        raise NotImplementedError(f"ModelSpecification::forward is not implemented for {self.__class__.__name__}")
+
+    def validation(
+        self,
+        pipeline: DiffusionPipeline,
+        prompt: Optional[str] = None,
+        image: Optional[Image] = None,
+        video: Optional[List[Image]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        frame_rate: Optional[int] = None,
+        generator: Optional[torch.Generator] = None,
+    ) -> List[ArtifactType]:
+        raise NotImplementedError(f"ModelSpecification::validation is not implemented for {self.__class__.__name__}")
+
+    def _save_lora_weights(
+        self,
+        directory: str,
+        transformer: torch.nn.Module,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+    ) -> None:
+        r"""
+        Save the LoRA state dicts of the model to the given directory.
+
+        This API is not backwards compatible and will be changed in the near future.
+        """
+        raise NotImplementedError(
+            f"ModelSpecification::save_lora_weights is not implemented for {self.__class__.__name__}"
+        )
+
+    def _save_model(
+        self,
+        directory: str,
+        transformer: torch.nn.Module,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+    ) -> None:
+        r"""
+        Save the state dicts to the given directory.
+
+        This API is not backwards compatible and will be changed in the near future.
+        """
+        raise NotImplementedError(f"ModelSpecification::save_model is not implemented for {self.__class__.__name__}")
+
+    def apply_tensor_parallel(
+        self,
+        backend: ParallelBackendEnum,
+        device_mesh: torch.distributed.DeviceMesh,
+        text_encoder: torch.nn.Module,
+        text_encoder_2: torch.nn.Module,
+        text_encoder_3: torch.nn.Module,
+        transformer: torch.nn.Module,
+        vae: torch.nn.Module,
+    ) -> None:
+        raise NotImplementedError(
+            f"ModelSpecification::apply_tensor_parallel is not implemented for {self.__class__.__name__}"
+        )
+
+    def _load_configs(self) -> None:
+        self._load_transformer_config()
+        self._load_vae_config()
+
+    def _load_transformer_config(self) -> None:
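+        # If a standalone `transformer_id` is provided, the class is resolved from that repository's config.json via
+        # its `_class_name`; otherwise it is resolved from the pipeline's model_index.json entry for "transformer".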
+        if self.transformer_id is not None:
+            transformer_cls = resolve_component_cls(
+                self.transformer_id,
+                component_name="_class_name",
+                filename="config.json",
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+            self.transformer_config = transformer_cls.load_config(
+                self.transformer_id, revision=self.revision, cache_dir=self.cache_dir
+            )
+        else:
+            transformer_cls = resolve_component_cls(
+                self.pretrained_model_name_or_path,
+                component_name="transformer",
+                filename="model_index.json",
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+            self.transformer_config = transformer_cls.load_config(
+                self.pretrained_model_name_or_path,
+                subfolder="transformer",
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        self.transformer_config = FrozenDict(**self.transformer_config)
+
+    def _load_vae_config(self) -> None:
+        if self.vae_id is not None:
+            vae_cls = resolve_component_cls(
+                self.vae_id,
+                component_name="_class_name",
+                filename="config.json",
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+            self.vae_config = vae_cls.load_config(self.vae_id, revision=self.revision, cache_dir=self.cache_dir)
+        else:
+            vae_cls = resolve_component_cls(
+                self.pretrained_model_name_or_path,
+                component_name="vae",
+                filename="model_index.json",
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+            self.vae_config = vae_cls.load_config(
+                self.pretrained_model_name_or_path, subfolder="vae", revision=self.revision, cache_dir=self.cache_dir
+            )
+        self.vae_config = FrozenDict(**self.vae_config)
diff --git a/finetrainers/models/utils.py b/finetrainers/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeda1e4379dfc6ab1d7baba1807f6e0ac71d779b
--- /dev/null
+++ b/finetrainers/models/utils.py
@@ -0,0 +1,62 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from diffusers.utils.torch_utils import randn_tensor
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters: torch.Tensor, deterministic: bool = False, _dim: int = 1):
+        # Note: _dim is the new argument added here after copying from diffusers
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=_dim)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(
+                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+            )
+
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+        # make sure the sample is on the same device as the parameters and has the same dtype
+        sample = randn_tensor(
+            self.mean.shape,
+            generator=generator,
+            device=self.parameters.device,
+            dtype=self.parameters.dtype,
+        )
+        x = self.mean + self.std * sample
+        return x
+
+    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        else:
+            if other is None:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                    dim=[1, 2, 3],
+                )
+            else:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )
+
+    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims,
+        )
+
+    def mode(self) -> torch.Tensor:
+        return self.mean
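+
+
+# Usage sketch (hypothetical shapes): `moments` is a VAE encoder output whose channel dimension
+# concatenates means and log-variances, e.g. [B, 2 * C, F, H, W] with `_dim=1`.
+#   posterior = DiagonalGaussianDistribution(moments)
+#   latents = posterior.sample(generator=torch.Generator().manual_seed(0))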
diff --git a/finetrainers/models/wan/__init__.py b/finetrainers/models/wan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bfeae2994e6b83b8fcb1337602e5cb39c73fdc7
--- /dev/null
+++ b/finetrainers/models/wan/__init__.py
@@ -0,0 +1 @@
+from .base_specification import WanModelSpecification
diff --git a/finetrainers/models/wan/base_specification.py b/finetrainers/models/wan/base_specification.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d27f428dde35f9203b8fb889e209c63b0f1e0da
--- /dev/null
+++ b/finetrainers/models/wan/base_specification.py
@@ -0,0 +1,378 @@
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+    AutoencoderKLWan,
+    FlowMatchEulerDiscreteScheduler,
+    WanImageToVideoPipeline,
+    WanPipeline,
+    WanTransformer3DModel,
+)
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from PIL.Image import Image
+from transformers import AutoModel, AutoTokenizer, UMT5EncoderModel
+
+from ... import data
+from ... import functional as FF
+from ...logging import get_logger
+from ...processors import ProcessorMixin, T5Processor
+from ...typing import ArtifactType, SchedulerType
+from ...utils import get_non_null_items
+from ..modeling_utils import ModelSpecification
+
+
+logger = get_logger()
+
+
+class WanLatentEncodeProcessor(ProcessorMixin):
+    r"""
+    Processor to encode image/video into latents using the Wan VAE.
+
+    Args:
+        output_names (`List[str]`):
+            The names of the outputs that the processor returns. Exactly one output name is expected; it is
+            used as the key for the encoded latents of the input image/video (or for the raw encoder moments
+            when `compute_posterior=False`).
+    """
+
+    def __init__(self, output_names: List[str]):
+        super().__init__()
+        self.output_names = output_names
+        assert len(self.output_names) == 1
+
+    def forward(
+        self,
+        vae: AutoencoderKLWan,
+        image: Optional[torch.Tensor] = None,
+        video: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+    ) -> Dict[str, torch.Tensor]:
+        device = vae.device
+        dtype = vae.dtype
+
+        if image is not None:
+            video = image.unsqueeze(1)
+
+        assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
+        video = video.to(device=device, dtype=vae.dtype)
+        video = video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
+
+        if compute_posterior:
+            latents = vae.encode(video).latent_dist.sample(generator=generator)
+            latents = latents.to(dtype=dtype)
+        else:
+            # TODO(aryan): refactor in diffusers to have use_slicing attribute
+            # if vae.use_slicing and video.shape[0] > 1:
+            #     encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
+            #     moments = torch.cat(encoded_slices)
+            # else:
+            #     moments = vae._encode(video)
+            moments = vae._encode(video)
+            latents = moments.to(dtype=dtype)
+
+        return {self.output_names[0]: latents}
+
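+# Usage sketch (hypothetical shapes): for a pixel-space `video` tensor of shape [B, F, C, H, W], the
+# processor returns {"latents": ...}; with compute_posterior=True these are sampled posterior latents,
+# otherwise the raw encoder moments to be wrapped in a DiagonalGaussianDistribution during training.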
+
+class WanModelSpecification(ModelSpecification):
+    def __init__(
+        self,
+        pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+        tokenizer_id: Optional[str] = None,
+        text_encoder_id: Optional[str] = None,
+        transformer_id: Optional[str] = None,
+        vae_id: Optional[str] = None,
+        text_encoder_dtype: torch.dtype = torch.bfloat16,
+        transformer_dtype: torch.dtype = torch.bfloat16,
+        vae_dtype: torch.dtype = torch.bfloat16,
+        revision: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        condition_model_processors: List[ProcessorMixin] = None,
+        latent_model_processors: List[ProcessorMixin] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            tokenizer_id=tokenizer_id,
+            text_encoder_id=text_encoder_id,
+            transformer_id=transformer_id,
+            vae_id=vae_id,
+            text_encoder_dtype=text_encoder_dtype,
+            transformer_dtype=transformer_dtype,
+            vae_dtype=vae_dtype,
+            revision=revision,
+            cache_dir=cache_dir,
+        )
+
+        if condition_model_processors is None:
+            condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])]
+        if latent_model_processors is None:
+            latent_model_processors = [WanLatentEncodeProcessor(["latents"])]
+
+        self.condition_model_processors = condition_model_processors
+        self.latent_model_processors = latent_model_processors
+
+    @property
+    def _resolution_dim_keys(self):
+        # TODO
+        return {
+            "latents": (2, 3, 4),
+        }
+
+    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        if self.tokenizer_id is not None:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
+            )
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="tokenizer",
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        if self.text_encoder_id is not None:
+            text_encoder = AutoModel.from_pretrained(
+                self.text_encoder_id,
+                torch_dtype=self.text_encoder_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            text_encoder = UMT5EncoderModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder",
+                torch_dtype=self.text_encoder_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        return {"tokenizer": tokenizer, "text_encoder": text_encoder}
+
+    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        if self.vae_id is not None:
+            vae = AutoencoderKLWan.from_pretrained(
+                self.vae_id,
+                torch_dtype=self.vae_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            vae = AutoencoderKLWan.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="vae",
+                torch_dtype=self.vae_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        return {"vae": vae}
+
+    def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
+        if self.transformer_id is not None:
+            transformer = WanTransformer3DModel.from_pretrained(
+                self.transformer_id,
+                torch_dtype=self.transformer_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+        else:
+            transformer = WanTransformer3DModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="transformer",
+                torch_dtype=self.transformer_dtype,
+                revision=self.revision,
+                cache_dir=self.cache_dir,
+            )
+
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {"transformer": transformer, "scheduler": scheduler}
+
+    def load_pipeline(
+        self,
+        tokenizer: Optional[AutoTokenizer] = None,
+        text_encoder: Optional[UMT5EncoderModel] = None,
+        transformer: Optional[WanTransformer3DModel] = None,
+        vae: Optional[AutoencoderKLWan] = None,
+        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+        enable_slicing: bool = False,
+        enable_tiling: bool = False,
+        enable_model_cpu_offload: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> WanPipeline:
+        components = {
+            "tokenizer": tokenizer,
+            "text_encoder": text_encoder,
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+        }
+        components = get_non_null_items(components)
+
+        pipe = WanPipeline.from_pretrained(
+            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+        )
+        pipe.text_encoder.to(self.text_encoder_dtype)
+        pipe.vae.to(self.vae_dtype)
+
+        if not training:
+            pipe.transformer.to(self.transformer_dtype)
+
+        # TODO(aryan): add support in diffusers
+        # if enable_slicing:
+        #     pipe.vae.enable_slicing()
+        # if enable_tiling:
+        #     pipe.vae.enable_tiling()
+        if enable_model_cpu_offload:
+            pipe.enable_model_cpu_offload()
+
+        return pipe
+
+    @torch.no_grad()
+    def prepare_conditions(
+        self,
+        tokenizer: AutoTokenizer,
+        text_encoder: UMT5EncoderModel,
+        caption: str,
+        max_sequence_length: int = 512,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        conditions = {
+            "tokenizer": tokenizer,
+            "text_encoder": text_encoder,
+            "caption": caption,
+            "max_sequence_length": max_sequence_length,
+            **kwargs,
+        }
+        input_keys = set(conditions.keys())
+        conditions = super().prepare_conditions(**conditions)
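+        # Keep only the newly computed condition tensors; drop the raw inputs that were forwarded to the processors.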
+        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+        conditions.pop("prompt_attention_mask", None)
+        return conditions
+
+    @torch.no_grad()
+    def prepare_latents(
+        self,
+        vae: AutoencoderKLWan,
+        image: Optional[torch.Tensor] = None,
+        video: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        conditions = {
+            "vae": vae,
+            "image": image,
+            "video": video,
+            "generator": generator,
+            "compute_posterior": compute_posterior,
+            **kwargs,
+        }
+        input_keys = set(conditions.keys())
+        conditions = super().prepare_latents(**conditions)
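+        # Keep only the newly computed latent tensors; drop the raw inputs that were forwarded to the processors.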
+        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+        return conditions
+
+    def forward(
+        self,
+        transformer: WanTransformer3DModel,
+        condition_model_conditions: Dict[str, torch.Tensor],
+        latent_model_conditions: Dict[str, torch.Tensor],
+        sigmas: torch.Tensor,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, ...]:
+        if compute_posterior:
+            latents = latent_model_conditions.pop("latents")
+        else:
+            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+            latents = posterior.sample(generator=generator)
+            del posterior
+
+        noise = torch.zeros_like(latents).normal_(generator=generator)
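+        # Flow-matching interpolation (assuming the rectified-flow convention used by finetrainers'
+        # functional helpers): x_t = (1 - sigma) * latents + sigma * noise, with target noise - latents.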
+        noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+
+        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+        condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
+
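+        # Scale continuous sigmas in [0, 1] to the integer timestep range assumed by the transformer (0-1000).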
+        timesteps = (sigmas.flatten() * 1000.0).long()
+
+        pred = transformer(
+            **latent_model_conditions,
+            **condition_model_conditions,
+            timestep=timesteps,
+            return_dict=False,
+        )[0]
+        target = FF.flow_match_target(noise, latents)
+
+        return pred, target, sigmas
+
+    def validation(
+        self,
+        pipeline: WanPipeline,
+        prompt: str,
+        image: Optional[Image] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        num_inference_steps: int = 50,
+        generator: Optional[torch.Generator] = None,
+        **kwargs,
+    ) -> List[ArtifactType]:
+        if image is not None:
+            pipeline = WanImageToVideoPipeline.from_pipe(pipeline)
+
+        generation_kwargs = {
+            "prompt": prompt,
+            "image": image,
+            "height": height,
+            "width": width,
+            "num_frames": num_frames,
+            "num_inference_steps": num_inference_steps,
+            "generator": generator,
+            "return_dict": True,
+            "output_type": "pil",
+        }
+        generation_kwargs = get_non_null_items(generation_kwargs)
+        video = pipeline(**generation_kwargs).frames[0]
+        return [data.VideoArtifact(value=video)]
+
+    def _save_lora_weights(
+        self,
+        directory: str,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        # TODO(aryan): this needs refactoring
+        if transformer_state_dict is not None:
+            WanPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
+        if scheduler is not None:
+            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    def _save_model(
+        self,
+        directory: str,
+        transformer: WanTransformer3DModel,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+    ) -> None:
+        # TODO(aryan): this needs refactoring
+        if transformer_state_dict is not None:
+            with init_empty_weights():
+                transformer_copy = WanTransformer3DModel.from_config(transformer.config)
+            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+        if scheduler is not None:
+            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
diff --git a/finetrainers/optimizer.py b/finetrainers/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..57da28e9377f2bf82b5307fae83338ad0b9ec385
--- /dev/null
+++ b/finetrainers/optimizer.py
@@ -0,0 +1,449 @@
+import functools
+import math
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+import torch
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_optimizer_state_dict,
+    set_optimizer_state_dict,
+)
+from torch.distributed.checkpoint.stateful import Stateful
+
+from .parallel import ParallelBackendEnum
+from .utils.import_utils import is_bitsandbytes_available
+
+
+class OptimizerWrapper(Stateful):
+    r"""
+    Optimizer wrapper that:
+        - allows step/zero_grad on multiple optimizers needed for virtual pipeline stages
+        - saves/loads the optimizer state_dict at checkpointing time
+    """
+
+    def __init__(
+        self,
+        model_parts: List[torch.nn.Module],
+        optimizer_cls: Type[torch.optim.Optimizer],
+        optimizer_kwargs: Dict[str, Any],
+    ) -> None:
+        self.optimizer_cls = optimizer_cls
+        self.optimizer_kwargs = optimizer_kwargs
+
+        self.optimizers = []
+        self.model_parts = model_parts
+
+        for model in self.model_parts:
+            optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
+            self.optimizers.append(optimizer)
+
+    def step(self) -> None:
+        for optimizer in self.optimizers:
+            optimizer.step()
+
+    def zero_grad(self) -> None:
+        for optimizer in self.optimizers:
+            optimizer.zero_grad()
+
+    def state_dict(self) -> Dict[str, Any]:
+        func = functools.partial(
+            get_optimizer_state_dict,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
+        )
+        return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        func = functools.partial(
+            set_optimizer_state_dict,
+            optim_state_dict=state_dict,
+            options=StateDictOptions(flatten_optimizer_state_dict=True),
+        )
+        list(map(func, self.model_parts, self.optimizers))
+
+
+class SchedulerWrapper:
+    def __init__(
+        self, optimizers, scheduler_lambda_fn: Callable[[int], float], last_epoch: int
+    ) -> None:
+        self.schedulers = []
+        for optimizer in optimizers:
+            self.schedulers.append(torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch))
+
+    def step(self) -> None:
+        for scheduler in self.schedulers:
+            scheduler.step()
+
+    def get_last_lr(self) -> Dict[str, List[float]]:
+        # TODO(aryan): look into this later. Currently, calling this can lead to an NCCL hang.
+        return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)}
+
+    def get_lr_scheduler_state(self) -> Dict[str, Any]:
+        state_dict = {}
+        if len(self.schedulers) == 1:
+            state_dict["lr_scheduler"] = self.schedulers[0]
+        else:
+            # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
+            # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
+            for idx, lr_scheduler in enumerate(self.schedulers):
+                state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
+        return state_dict
+
+
+def get_optimizer(
+    parallel_backend: ParallelBackendEnum,
+    name: str,
+    model_parts: List[torch.nn.Module],
+    learning_rate: float = 1e-3,
+    beta1: float = 0.9,
+    beta2: float = 0.95,
+    beta3: float = 0.999,
+    epsilon: float = 1e-8,
+    weight_decay: float = 1e-4,
+    fused: bool = False,
+) -> Union[torch.optim.Optimizer, OptimizerWrapper]:
+    name = name.lower()
+
+    _raise_errors_if_packages_not_available(name)
+
+    if name == "adam":
+        optimizer_cls = torch.optim.Adam
+        optimizer_kwargs = {
+            "lr": learning_rate,
+            "betas": (beta1, beta2),
+            "eps": epsilon,
+            "weight_decay": weight_decay,
+            "fused": fused,
+        }
+    elif name == "adamw":
+        optimizer_cls = torch.optim.AdamW
+        optimizer_kwargs = {
+            "lr": learning_rate,
+            "betas": (beta1, beta2),
+            "eps": epsilon,
+            "weight_decay": weight_decay,
+            "fused": fused,
+        }
+    elif name == "adam-bnb":
+        from bitsandbytes.optim import Adam
+
+        optimizer_cls = Adam
+        optimizer_kwargs = {
+            "lr": learning_rate,
+            "betas": (beta1, beta2),
+            "eps": epsilon,
+            "weight_decay": weight_decay,
+        }
+    elif name == "adamw-bnb":
+        from bitsandbytes.optim import AdamW
+
+        optimizer_cls = AdamW
+        optimizer_kwargs = {
+            "lr": learning_rate,
+            "betas": (beta1, beta2),
+            "eps": epsilon,
+            "weight_decay": weight_decay,
+        }
+    elif name == "adam-bnb-8bit":
+        from bitsandbytes.optim import Adam8bit
+
+        optimizer_cls = Adam8bit
+        optimizer_kwargs = {
+            "lr": learning_rate,
+            "betas": (beta1, beta2),
+            "eps": epsilon,
+            "weight_decay": weight_decay,
+        }
+    elif name == "adamw-bnb-8bit":
+        from bitsandbytes.optim import AdamW8bit
+
+        optimizer_cls = AdamW8bit
+        optimizer_kwargs = {
+            "lr": learning_rate,
+            "betas": (beta1, beta2),
+            "eps": epsilon,
+            "weight_decay": weight_decay,
+        }
+
+    # TODO(aryan): handle bitsandbytes and torchao
+    else:
+        raise ValueError(f"Unsupported optimizer: {name}")
+
+    if parallel_backend == ParallelBackendEnum.ACCELERATE:
+        return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs)
+    elif parallel_backend == ParallelBackendEnum.PTD:
+        return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs)
+
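+# Usage sketch (hypothetical values): build a fused AdamW optimizer for a single model part.
+#   optimizer = get_optimizer(
+#       ParallelBackendEnum.ACCELERATE, "adamw", [transformer], learning_rate=1e-4, weight_decay=1e-2
+#   )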
+
+def get_optimizer_accelerate(
+    model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
+) -> torch.optim.Optimizer:
+    params = [param for model in model_parts for param in model.parameters() if param.requires_grad]
+    optimizer = optimizer_cls(params, **optimizer_kwargs)
+    return optimizer
+
+
+def get_optimizer_ptd(
+    model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
+) -> OptimizerWrapper:
+    return OptimizerWrapper(model_parts, optimizer_cls, optimizer_kwargs)
+
+
+def get_lr_scheduler(
+    parallel_backend: ParallelBackendEnum,
+    name: str,
+    optimizer: Union[torch.optim.Optimizer, OptimizerWrapper],
+    step_rules: Optional[str] = None,
+    num_warmup_steps: Optional[int] = None,
+    num_training_steps: Optional[int] = None,
+    num_cycles: int = 1,
+    power: float = 1.0,
+    lr_init: float = 1e-3,
+    lr_end: float = 1e-7,
+    last_epoch: int = -1,
+) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]:
+    name = name.lower()
+    if name == "constant":
+        scheduler_lambda_fn = get_constant_schedule()
+    elif name == "constant_with_warmup":
+        scheduler_lambda_fn = get_constant_schedule_with_warmup(num_warmup_steps)
+    elif name == "piecewise_constant":
+        scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules)
+    elif name == "linear":
+        scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps)
+    elif name == "cosine":
+        scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles)
+    elif name == "cosine_with_restarts":
+        scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup(
+            num_warmup_steps, num_training_steps, num_cycles
+        )
+    elif name == "polynomial":
+        scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup(
+            num_warmup_steps, num_training_steps, lr_init, lr_end, power
+        )
+    else:
+        raise ValueError(f"Unsupported scheduler: {name}")
+
+    if parallel_backend == ParallelBackendEnum.ACCELERATE:
+        return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch)
+    elif parallel_backend == ParallelBackendEnum.PTD:
+        return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch)
+
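+# Usage sketch (hypothetical values): pair the optimizer above with a cosine schedule.
+#   lr_scheduler = get_lr_scheduler(
+#       ParallelBackendEnum.ACCELERATE, "cosine", optimizer, num_warmup_steps=100, num_training_steps=10_000
+#   )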
+
+def get_lr_scheduler_accelerate(
+    optimizer: torch.optim.Optimizer,
+    scheduler_lambda_fn: Callable[[int], float],
+    last_epoch: int = -1,
+) -> torch.optim.lr_scheduler.LambdaLR:
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)
+    return scheduler
+
+
+def get_lr_scheduler_ptd(
+    optimizer: OptimizerWrapper, scheduler_lambda_fn: Callable[[int], float], last_epoch: int = -1
+) -> SchedulerWrapper:
+    return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch)
+
+
+# ==============================
+# Adapted from https://github.com/huggingface/diffusers/blob/196aef5a6f76e1ad6ba889184860c3633d166910/src/diffusers/optimization.py
+# ==============================
+
+
+def get_constant_schedule() -> Callable[[int], float]:
+    r"""
+    Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+    """
+
+    def lr_lambda(current_step: int):
+        return 1.0
+
+    return lr_lambda
+
+
+def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]:
+    r"""
+    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+    increases linearly between 0 and the initial lr set in the optimizer.
+
+    Args:
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1.0, num_warmup_steps))
+        return 1.0
+
+    return lr_lambda
+
+
+def get_piecewise_constant_schedule(step_rules: str) -> Callable[[int], float]:
+    r"""
+    Create a schedule with a piecewise constant learning rate, using the learning rate set in the optimizer.
+
+    Args:
+        step_rules (`string`):
+            The rules for the learning rate. For example, `step_rules="1:10,0.1:20,0.01:30,0.005"` means the
+            learning rate is multiplied by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30, and
+            by 0.005 for all later steps.
+    """
+
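+    # Worked example for the docstring rule string "1:10,0.1:20,0.01:30,0.005": rules_dict becomes
+    # {10: 1.0, 20: 0.1, 30: 0.01} and the final multiplier 0.005 applies to all steps >= 30.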
+    rules_dict = {}
+    rule_list = step_rules.split(",")
+    for rule_str in rule_list[:-1]:
+        value_str, steps_str = rule_str.split(":")
+        steps = int(steps_str)
+        value = float(value_str)
+        rules_dict[steps] = value
+    last_lr_multiple = float(rule_list[-1])
+
+    def create_rules_function(rules_dict, last_lr_multiple):
+        def rule_func(steps: int) -> float:
+            sorted_steps = sorted(rules_dict.keys())
+            for i, sorted_step in enumerate(sorted_steps):
+                if steps < sorted_step:
+                    return rules_dict[sorted_steps[i]]
+            return last_lr_multiple
+
+        return rule_func
+
+    rules_func = create_rules_function(rules_dict, last_lr_multiple)
+    return rules_func
+
+
+def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]:
+    r"""
+    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+    Args:
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+    """
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        return max(
+            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+        )
+
+    return lr_lambda
+
+
+def get_cosine_schedule_with_warmup(
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+) -> Callable[[int], float]:
+    r"""
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+    initial lr set in the optimizer.
+
+    Args:
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
+            value to 0 following a half-cosine).
+    """
+
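+    # With the default num_cycles=0.5 the multiplier traces half a cosine period: 1.0 at the end of
+    # warmup, 0.5 halfway through the remaining steps, and 0.0 at num_training_steps.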
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+    return lr_lambda
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: int = 1,
+) -> Callable[[int], float]:
+    r"""
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+    linearly between 0 and the initial lr set in the optimizer.
+
+    Args:
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`int`, *optional*, defaults to 1):
+            The number of hard restarts to use.
+    """
+
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+        if progress >= 1.0:
+            return 0.0
+        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+    return lr_lambda
+
+
+def get_polynomial_decay_schedule_with_warmup(
+    num_warmup_steps: int,
+    num_training_steps: int,
+    lr_init: float,
+    lr_end: float = 1e-7,
+    power: float = 1.0,
+) -> Callable[[int], float]:
+    r"""
+    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+    initial lr set in the optimizer.
+
+    Args:
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        lr_end (`float`, *optional*, defaults to 1e-7):
+            The end LR.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor.
+
+    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at
+    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+    """
+
+    if not (lr_init > lr_end):
+        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
+
+    def lr_lambda(current_step: int):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        elif current_step > num_training_steps:
+            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
+        else:
+            lr_range = lr_init - lr_end
+            decay_steps = num_training_steps - num_warmup_steps
+            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+            decay = lr_range * pct_remaining**power + lr_end
+            return decay / lr_init  # as LambdaLR multiplies by lr_init
+
+    return lr_lambda
+
+
+def _raise_errors_if_packages_not_available(name: str) -> None:
+    name_split = name.split("-")
+    if len(name_split) < 2:
+        return
+    package_name = name_split[1]
+    if package_name == "bnb":
+        if not is_bitsandbytes_available():
+            raise ImportError(
+                f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer."
+            )
diff --git a/finetrainers/parallel/__init__.py b/finetrainers/parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fbce75fadfc33312f6d655fc581ac90cfb6c577
--- /dev/null
+++ b/finetrainers/parallel/__init__.py
@@ -0,0 +1,22 @@
+from enum import Enum
+from typing import Type, Union
+
+from .accelerate import AccelerateParallelBackend
+from .ptd import PytorchDTensorParallelBackend
+from .utils import apply_ddp_ptd, apply_fsdp2_ptd, dist_max, dist_mean
+
+
+ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend]
+
+
+class ParallelBackendEnum(str, Enum):
+    ACCELERATE = "accelerate"
+    PTD = "ptd"
+
+
+def get_parallel_backend_cls(backend: ParallelBackendEnum) -> Type[ParallelBackendType]:
+    if backend == ParallelBackendEnum.ACCELERATE:
+        return AccelerateParallelBackend
+    if backend == ParallelBackendEnum.PTD:
+        return PytorchDTensorParallelBackend
+    raise ValueError(f"Unknown parallel backend: {backend}")
diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a523321f0c333af54949cb2274fb2d60cf014ff
--- /dev/null
+++ b/finetrainers/parallel/accelerate.py
@@ -0,0 +1,218 @@
+import datetime
+import pathlib
+from typing import Optional
+
+import torch
+from diffusers.utils import is_accelerate_available
+
+from ..logging import get_logger
+from ..utils import get_device_info
+from .base import BaseParallelBackend
+from .utils import apply_ddp_accelerate
+
+
+if not is_accelerate_available():
+    raise ImportError(
+        "Please install the accelerate package using `pip install accelerate` to use the AccelerateParallelBackend."
+    )
+
+from accelerate import Accelerator
+from accelerate.data_loader import DataLoader
+from accelerate.utils import (
+    DataLoaderConfiguration,
+    DistributedDataParallelKwargs,
+    InitProcessGroupKwargs,
+    ProjectConfiguration,
+)
+
+
+logger = get_logger()
+_device_type, _device_module = get_device_info()
+
+
+class AccelerateParallelBackend(BaseParallelBackend):
+    def __init__(
+        self,
+        world_size: int,
+        pp_degree: int = 1,
+        dp_degree: int = 1,
+        dp_shards: int = -1,
+        cp_degree: int = 1,
+        tp_degree: int = 1,
+        backend: str = "nccl",
+        timeout: int = 180,
+        logging_dir: Optional[str] = None,
+        output_dir: Optional[str] = None,
+        gradient_accumulation_steps: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        self._world_size = world_size
+        self._pp_degree = pp_degree
+        self._dp_degree = dp_degree
+        self._dp_shards = dp_shards
+        self._cp_degree = cp_degree
+        self._tp_degree = tp_degree
+        self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
+        self._logging_dir = (
+            self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
+        )
+        self._backend = backend
+        self._timeout = timeout
+        self._gradient_accumulation_steps = gradient_accumulation_steps
+
+        if pp_degree > 1 or dp_shards > 1 or cp_degree > 1 or tp_degree > 1:
+            raise ValueError(
+                "AccelerateParallelBackend does not support anything but Distributed Data Parallelism at the moment."
+            )
+        if dp_degree != world_size:
+            raise ValueError("Data parallel degree must be equal to world size.")
+
+        self._accelerator: Accelerator = None
+        self._mesh: torch.distributed.DeviceMesh = None
+
+    def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+        project_config = None
+        ddp_kwargs = None
+        dataloader_config = None
+        init_process_group_kwargs = None
+        if self._accelerator is None:
+            project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
+            ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
+            dataloader_config = DataLoaderConfiguration(
+                split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
+            )
+            init_process_group_kwargs = InitProcessGroupKwargs(
+                backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
+            )
+        self._accelerator, model = apply_ddp_accelerate(
+            model,
+            project_config,
+            ddp_kwargs,
+            init_process_group_kwargs,
+            dataloader_config,
+            self._gradient_accumulation_steps,
+            accelerator=self._accelerator,
+        )
+        logger.debug("Applied AccelerateParallel::apply_ddp to model.")
+        return model
+
+    def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
+        logger.debug("AccelerateParallelBackend::prepare_dataset completed!")
+        return dataset
+
+    def prepare_dataloader(
+        self,
+        dataset: torch.utils.data.IterableDataset,
+        batch_size: int = 1,
+        num_workers: int = 0,
+        pin_memory: bool = False,
+    ) -> DataLoader:
+        dataloader = torch.utils.data.DataLoader(
+            dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
+        )
+        dataloader = self._accelerator.prepare_data_loader(dataloader)
+        logger.debug("AccelerateParallelBackend::prepare_dataloader completed!")
+        return dataloader
+
+    def prepare_optimizer(self, optimizer, lr_scheduler):
+        optimizer = self._accelerator.prepare_optimizer(optimizer)
+        lr_scheduler = self._accelerator.prepare_scheduler(lr_scheduler)
+        return optimizer, lr_scheduler
+
+    def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
+        def _get_mesh():
+            if name is None:
+                return self._mesh
+            try:
+                return self._mesh[name]
+            except (KeyError, RuntimeError):
+                return self._mesh
+
+        if self._mesh is not None:
+            return _get_mesh()
+
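+        # Only dimensions with degree > 1 become mesh axes; for DDP-only training this is typically a
+        # single "dp_replicate" axis spanning all processes.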
+        mesh_list = [("dp_replicate", self._dp_degree), ("dp_shard", self._dp_shards)]
+        mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
+        names = [x[0] for x in mesh_list]
+        degrees = [x[1] for x in mesh_list]
+        mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)
+
+        dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []
+
+        if self.data_replication_enabled:
+            dp_mesh_names.append("dp_replicate")
+            dp_cp_mesh_names.append("dp_replicate")
+        if self.data_sharding_enabled:
+            dp_mesh_names.append("dp_shard")
+            dp_cp_mesh_names.append("dp_shard")
+            dp_shard_cp_mesh_names.append("dp_shard")
+        if self.context_parallel_enabled:
+            dp_cp_mesh_names.append("cp")
+            dp_shard_cp_mesh_names.append("cp")
+
+        if len(dp_mesh_names) > 0:
+            mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
+        if len(dp_cp_mesh_names) > 0:
+            mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
+        if len(dp_shard_cp_mesh_names) > 0:
+            mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")
+
+        logger.debug(f"Device mesh: {mesh}")
+        self._mesh = mesh
+        return _get_mesh()
+
+    @property
+    def world_size(self):
+        return self._accelerator.num_processes
+
+    @property
+    def rank(self):
+        return self._accelerator.process_index
+
+    @property
+    def local_rank(self):
+        return self._accelerator.local_process_index
+
+    @property
+    def is_main_process(self):
+        r"""Returns `True` if the current process is the main process on the master node."""
+        return self._accelerator.is_main_process
+
+    @property
+    def is_local_main_process(self):
+        r"""Returns `True` if the current process is the main process on local node."""
+        return self._accelerator.is_local_main_process
+
+    @property
+    def device(self):
+        return self._accelerator.device
+
+    def wait_for_everyone(self):
+        self._accelerator.wait_for_everyone()
+
+    def destroy(self):
+        self._accelerator.end_training()
+
+    @property
+    def pipeline_parallel_enabled(self):
+        return self._pp_degree > 1
+
+    @property
+    def data_parallel_enabled(self):
+        return self._dp_degree > 1 or self._dp_shards > 1
+
+    @property
+    def data_replication_enabled(self):
+        return self._dp_degree > 1
+
+    @property
+    def data_sharding_enabled(self):
+        return self._dp_shards > 1
+
+    @property
+    def context_parallel_enabled(self):
+        return self._cp_degree > 1
+
+    @property
+    def tensor_parallel_enabled(self):
+        return self._tp_degree > 1
diff --git a/finetrainers/parallel/base.py b/finetrainers/parallel/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb982ca252fdcd03b644bf5a3215857fada7119c
--- /dev/null
+++ b/finetrainers/parallel/base.py
@@ -0,0 +1,96 @@
+from contextlib import contextmanager
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from ..trackers import TrackerType, initialize_trackers
+
+
+class BaseParallelBackend:
+    r"""
+    Base class that contains properties and methods that should be implemented by different parallel backends.
+    """
+
+    def apply_ddp(self, *args, **kwargs) -> torch.nn.Module:
+        raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.")
+
+    def prepare_dataset(self, *args, **kwargs) -> Any:
+        raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.")
+
+    def prepare_dataloader(self, *args, **kwargs) -> Any:
+        raise NotImplementedError("Method `prepare_dataloader` must be implemented by subclass.")
+
+    def prepare_optimizer(self, *args, **kwargs) -> Any:
+        raise NotImplementedError("Method `prepare_optimizer` must be implemented by subclass.")
+
+    def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
+        raise NotImplementedError("Method `get_mesh` must be implemented by subclass.")
+
+    def initialize_trackers(
+        self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
+    ) -> Optional[TrackerType]:
+        self.tracker = None
+        if self.is_main_process:
+            self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir)
+        return self.tracker
+
+    def log(self, metrics: Dict[str, Any], step: int) -> None:
+        if self.is_main_process:
+            self.tracker.log(metrics, step)
+
+    def wait_for_everyone(self):
+        raise NotImplementedError("Method `wait_for_everyone` must be implemented by subclass.")
+
+    @contextmanager
+    def main_process_first(self):
+        raise NotImplementedError("Method `main_process_first` must be implemented by subclass.")
+
+    def destroy(self):
+        raise NotImplementedError("Method `destroy` must be implemented by subclass.")
+
+    @property
+    def world_size(self):
+        raise NotImplementedError("Method `world_size` must be implemented by subclass.")
+
+    @property
+    def rank(self):
+        raise NotImplementedError("Method `rank` must be implemented by subclass.")
+
+    @property
+    def local_rank(self):
+        raise NotImplementedError("Method `local_rank` must be implemented by subclass.")
+
+    @property
+    def is_main_process(self):
+        raise NotImplementedError("Method `is_main_process` must be implemented by subclass.")
+
+    @property
+    def is_local_main_process(self):
+        raise NotImplementedError("Method `is_local_main_process` must be implemented by subclass.")
+
+    @property
+    def device(self):
+        raise NotImplementedError("Method `device` must be implemented by subclass.")
+
+    @property
+    def pipeline_parallel_enabled(self):
+        raise NotImplementedError("Property `pipeline_parallel_enabled` must be implemented by subclass.")
+
+    @property
+    def data_parallel_enabled(self):
+        raise NotImplementedError("Property `data_parallel_enabled` must be implemented by subclass.")
+
+    @property
+    def data_replication_enabled(self):
+        raise NotImplementedError("Property `data_replication_enabled` must be implemented by subclass.")
+
+    @property
+    def data_sharding_enabled(self):
+        raise NotImplementedError("Property `data_sharding_enabled` must be implemented by subclass.")
+
+    @property
+    def context_parallel_enabled(self):
+        raise NotImplementedError("Property `context_parallel_enabled` must be implemented by subclass.")
+
+    @property
+    def tensor_parallel_enabled(self):
+        raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.")
diff --git a/finetrainers/parallel/deepspeed.py b/finetrainers/parallel/deepspeed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9f54d66ec1941ffc44d6239b305cc397ce61d4
--- /dev/null
+++ b/finetrainers/parallel/deepspeed.py
@@ -0,0 +1,7 @@
+from .base import BaseParallelBackend
+
+
+class DeepspeedParallelBackend(BaseParallelBackend):
+    def __init__(self):
+        # TODO(aryan)
+        raise NotImplementedError("DeepspeedParallelBackend is not implemented yet.")
diff --git a/finetrainers/parallel/ptd.py b/finetrainers/parallel/ptd.py
new file mode 100644
index 0000000000000000000000000000000000000000..352273b4eff7f4cb21424962748a84ff29d96426
--- /dev/null
+++ b/finetrainers/parallel/ptd.py
@@ -0,0 +1,228 @@
+import datetime
+import os
+import pathlib
+from typing import Optional
+
+import datasets.distributed
+import torch
+
+from ..data import DPDataLoader
+from ..logging import get_logger
+from ..utils import get_device_info
+from .base import BaseParallelBackend
+from .utils import apply_ddp_ptd
+
+
+_device_type, _device_module = get_device_info()
+logger = get_logger()
+
+
+class PytorchDTensorParallelBackend(BaseParallelBackend):
+    def __init__(
+        self,
+        world_size: int,
+        pp_degree: int = 1,
+        dp_degree: int = 1,
+        dp_shards: int = -1,
+        cp_degree: int = 1,
+        tp_degree: int = 1,
+        backend: str = "nccl",
+        timeout: int = 180,
+        logging_dir: Optional[str] = None,
+        output_dir: Optional[str] = None,
+        gradient_accumulation_steps: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        self._world_size = world_size
+        self._pp_degree = pp_degree
+        self._dp_degree = dp_degree
+        self._dp_shards = dp_shards
+        self._cp_degree = cp_degree
+        self._tp_degree = tp_degree
+        self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
+        self._logging_dir = (
+            self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
+        )
+        self._backend = backend
+        self._timeout = timeout
+
+        for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]:
+            if degree < 1:
+                raise ValueError(f"Parallel degree must be at least 1, got {degree}.")
+
+        if dp_shards * pp_degree * dp_degree * cp_degree * tp_degree != world_size:
+            raise ValueError(
+                f"World size {world_size} must be divisible by the product of all parallel degrees and data parallel shards."
+            )
+
+        torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout))
+        _device_module.set_device(self.local_rank)
+
+        logger.info(
+            f"Initialized parallel state with:\n"
+            f"  - World size: {world_size}\n"
+            f"  - Pipeline parallel degree: {pp_degree}\n"
+            f"  - Data parallel degree: {dp_degree}\n"
+            f"  - Context parallel degree: {cp_degree}\n"
+            f"  - Tensor parallel degree: {tp_degree}\n"
+            f"  - Data parallel shards: {dp_shards}\n"
+        )
+
+        self._mesh: torch.distributed.DeviceMesh = None
+
+    def apply_ddp(
+        self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
+    ) -> torch.nn.Module:
+        if device_mesh is None:
+            device_mesh = self.get_mesh()
+        apply_ddp_ptd(model, device_mesh)
+        logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.")
+        return model
+
+    def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
+        dp_mesh = self.get_mesh("dp_replicate")
+        if dp_mesh is None:
+            dp_mesh = self.get_mesh()
+        if self.world_size > 1:
+            dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
+        else:
+            dp_local_rank, dp_world_size = 0, 1
+        dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size)
+        logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!")
+        return dataset
+
+    def prepare_dataloader(
+        self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool
+    ) -> DPDataLoader:
+        dp_mesh = self.get_mesh("dp_replicate")
+        if dp_mesh is None:
+            dp_mesh = self.get_mesh()
+        if self.world_size > 1:
+            dp_local_rank = dp_mesh.get_local_rank()
+        else:
+            dp_local_rank = 0
+        dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers)
+        logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!")
+        return dataloader
+
+    def prepare_optimizer(self, optimizer, lr_scheduler):
+        logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!")
+        return optimizer, lr_scheduler
+
+    def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
+        def _get_mesh():
+            if name is None:
+                return self._mesh
+            try:
+                return self._mesh[name]
+            except (KeyError, RuntimeError):
+                if self._mesh.ndim == 0:
+                    return None
+                return self._mesh
+
+        if self._mesh is not None:
+            return _get_mesh()
+
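+        # Only dimensions with degree > 1 become mesh axes. For example (hypothetical), with world_size=8,
+        # dp_degree=2 and dp_shards=4, the mesh has shape (2, 4) with dim names ("dp_replicate", "dp_shard"),
+        # and the flattened "dp"/"dp_cp" meshes built below then span all 8 ranks.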
+        mesh_list = [
+            ("pp", self._pp_degree),
+            ("dp_replicate", self._dp_degree),
+            ("dp_shard", self._dp_shards),
+            ("cp", self._cp_degree),
+            ("tp", self._tp_degree),
+        ]
+        mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
+        names = [x[0] for x in mesh_list]
+        degrees = [x[1] for x in mesh_list]
+        mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)
+
+        dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []
+
+        if self.data_replication_enabled:
+            dp_mesh_names.append("dp_replicate")
+            dp_cp_mesh_names.append("dp_replicate")
+        if self.data_sharding_enabled:
+            dp_mesh_names.append("dp_shard")
+            dp_cp_mesh_names.append("dp_shard")
+            dp_shard_cp_mesh_names.append("dp_shard")
+        if self.context_parallel_enabled:
+            dp_cp_mesh_names.append("cp")
+            dp_shard_cp_mesh_names.append("cp")
+
+        if len(dp_mesh_names) > 0:
+            mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
+        if len(dp_cp_mesh_names) > 0:
+            mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
+        if len(dp_shard_cp_mesh_names) > 0:
+            mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")
+
+        logger.debug(f"Device mesh: {mesh}")
+        self._mesh = mesh
+        return _get_mesh()
+
+    @property
+    def world_size(self):
+        return torch.distributed.get_world_size()
+
+    @property
+    def rank(self):
+        return torch.distributed.get_rank()
+
+    @property
+    def local_rank(self):
+        return int(os.environ.get("LOCAL_RANK", 0))
+
+    @property
+    def is_main_process(self):
+        r"""Returns `True` if the current process is the main process on the master node."""
+        return self.rank == 0
+
+    @property
+    def is_local_main_process(self):
+        r"""Returns `True` if the current process is the main process on local node."""
+        return self.local_rank == 0
+
+    @property
+    def device(self):
+        return torch.device(_device_type, self.local_rank)
+
+    def wait_for_everyone(self):
+        return torch.distributed.barrier()
+
+    # @contextmanager
+    # def main_process_first(self):
+    #     if self.is_main_process:
+    #         yield
+    #         self.wait_for_everyone()
+    #     else:
+    #         self.wait_for_everyone()
+    #         yield
+
+    def destroy(self):
+        if self.is_main_process:
+            self.tracker.finish()
+        return torch.distributed.destroy_process_group()
+
+    @property
+    def pipeline_parallel_enabled(self):
+        return self._pp_degree > 1
+
+    @property
+    def data_parallel_enabled(self):
+        return self._dp_degree > 1 or self._dp_shards > 1
+
+    @property
+    def data_replication_enabled(self):
+        return self._dp_degree > 1
+
+    @property
+    def data_sharding_enabled(self):
+        return self._dp_shards > 1
+
+    @property
+    def context_parallel_enabled(self):
+        return self._cp_degree > 1
+
+    @property
+    def tensor_parallel_enabled(self):
+        return self._tp_degree > 1
diff --git a/finetrainers/parallel/utils.py b/finetrainers/parallel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe434e1a192504998260118e9bd4b615113b2cd8
--- /dev/null
+++ b/finetrainers/parallel/utils.py
@@ -0,0 +1,99 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.distributed._functional_collectives as funcol
+import torch.distributed.tensor
+from diffusers.utils import is_accelerate_available
+from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
+from torch.distributed._composable.replicate import replicate
+
+from ..utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES
+
+
+if is_accelerate_available():
+    from accelerate import Accelerator
+    from accelerate.utils import (
+        DataLoaderConfiguration,
+        DistributedDataParallelKwargs,
+        InitProcessGroupKwargs,
+        ProjectConfiguration,
+    )
+
+
+def apply_fsdp2_ptd(
+    model: torch.nn.Module,
+    dp_mesh: torch.distributed.device_mesh.DeviceMesh,
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
+    output_dtype: torch.dtype,
+    pp_enabled: bool = False,
+    cpu_offload: bool = False,
+) -> None:
+    r"""Apply FSDP2 on a model."""
+    mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True)
+    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+
+    if cpu_offload:
+        fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)
+
+    def apply_fully_shard(blocks):
+        for layer_index, block in enumerate(blocks):
+            if pp_enabled:
+                # For PP, do not reshard after forward to avoid per-microbatch
+                # all-gathers, which can be expensive and non-overlapped
+                reshard_after_forward = False
+            else:
+                # As an optimization, do not reshard after forward for the last
+                # transformer block since FSDP would prefetch it immediately
+                reshard_after_forward = layer_index < len(blocks) - 1
+            fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)
+
+    for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
+        blocks = getattr(model, transformer_block_name, None)
+        if blocks is not None:
+            apply_fully_shard(blocks)
+
+    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
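+
+# Minimal usage sketch (illustrative only; `transformer` and the mesh/dtype choices below are assumptions, not part
+# of this module): shard a diffusers transformer over the flattened data-parallel mesh with bf16 compute and fp32
+# gradient reduction.
+#
+#   apply_fsdp2_ptd(
+#       model=transformer,
+#       dp_mesh=parallel_backend.get_mesh("dp_shard_cp"),
+#       param_dtype=torch.bfloat16,
+#       reduce_dtype=torch.float32,
+#       output_dtype=torch.bfloat16,
+#   )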
+
+
+def apply_ddp_accelerate(
+    model: torch.nn.Module,
+    project_config: Optional[ProjectConfiguration] = None,
+    ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
+    init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
+    dataloader_config: Optional[DataLoaderConfiguration] = None,
+    gradient_accumulation_steps: Optional[int] = None,
+    accelerator: Optional[Accelerator] = None,
+) -> Tuple[Accelerator, torch.nn.Module]:
+    if accelerator is None:
+        accelerator = Accelerator(
+            project_config=project_config,
+            dataloader_config=dataloader_config,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            log_with=None,
+            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
+        )
+        if torch.backends.mps.is_available():
+            accelerator.native_amp = False
+    accelerator.prepare_model(model)
+    return accelerator, model
+
+
+def apply_ddp_ptd(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
+    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+
+
+def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
+    if isinstance(x, torch.distributed.tensor.DTensor):
+        # functional collectives do not support DTensor inputs
+        x = x.full_tensor()
+    assert x.numel() == 1  # required by `.item()`
+    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
+
+
+def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh)
+
+
+def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
+    return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh)
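+
+
+# Illustrative usage (a sketch; `loss` and `parallel_backend` come from the caller's context and are assumptions
+# here): reduce a scalar loss across the data-parallel mesh for logging.
+#
+#   global_avg_loss = dist_mean(loss.detach(), parallel_backend.get_mesh("dp_cp"))
+#   global_max_loss = dist_max(loss.detach(), parallel_backend.get_mesh("dp_cp"))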
diff --git a/finetrainers/patches.py b/finetrainers/patches.py
deleted file mode 100644
index 1faacbde58e2948a3aa83d8c848de2a9cc681583..0000000000000000000000000000000000000000
--- a/finetrainers/patches.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import functools
-
-import torch
-from accelerate.logging import get_logger
-from peft.tuners.tuners_utils import BaseTunerLayer
-
-from .constants import FINETRAINERS_LOG_LEVEL
-
-
-logger = get_logger("finetrainers")  # pylint: disable=invalid-name
-logger.setLevel(FINETRAINERS_LOG_LEVEL)
-
-
-def perform_peft_patches() -> None:
-    _perform_patch_move_adapter_to_device_of_base_layer()
-
-
-def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
-    # We don't patch the method for torch.float32 and torch.bfloat16 because it is okay to train with them. If the model weights
-    # are in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2, we need to patch this method to avoid conversion of
-    # LoRA weights from higher precision dtype.
-    BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
-        BaseTunerLayer._move_adapter_to_device_of_base_layer
-    )
-
-
-def _patched_move_adapter_to_device_of_base_layer(func) -> None:
-    @functools.wraps(func)
-    def wrapper(self, *args, **kwargs):
-        with DisableTensorToDtype():
-            return func(self, *args, **kwargs)
-
-    return wrapper
-
-
-class DisableTensorToDtype:
-    def __enter__(self):
-        self.original_to = torch.Tensor.to
-
-        def modified_to(tensor, *args, **kwargs):
-            # remove dtype from args if present
-            args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
-            if "dtype" in kwargs:
-                kwargs.pop("dtype")
-            return self.original_to(tensor, *args, **kwargs)
-
-        torch.Tensor.to = modified_to
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        torch.Tensor.to = self.original_to
diff --git a/finetrainers/patches/__init__.py b/finetrainers/patches/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d499f4d2136f8cdb4a9eb830289bb17757e7b37
--- /dev/null
+++ b/finetrainers/patches/__init__.py
@@ -0,0 +1,23 @@
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from ..args import BaseArgs
+    from ..parallel import ParallelBackendType
+
+
+def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBackendType") -> None:
+    # To avoid circular imports
+    from ..config import ModelType, TrainingType
+
+    if args.model_name == ModelType.LTX_VIDEO:
+        from .models.ltx_video import patch
+
+        patch.patch_transformer_forward()
+        if parallel_backend.tensor_parallel_enabled:
+            patch.patch_apply_rotary_emb_for_tp_compatibility()
+
+    if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
+        from .dependencies.peft import patch
+
+        patch.patch_peft_move_adapter_to_device_of_base_layer()
diff --git a/finetrainers/patches/dependencies/peft/patch.py b/finetrainers/patches/dependencies/peft/patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c0bf968cf4894b8ccd90c76631e57541e10642d
--- /dev/null
+++ b/finetrainers/patches/dependencies/peft/patch.py
@@ -0,0 +1,25 @@
+import functools
+
+from peft.tuners.tuners_utils import BaseTunerLayer
+
+from ...utils import DisableTensorToDtype
+
+
+def patch_peft_move_adapter_to_device_of_base_layer() -> None:
+    _perform_patch_move_adapter_to_device_of_base_layer()
+
+
+def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
+    BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
+        BaseTunerLayer._move_adapter_to_device_of_base_layer
+    )
+
+
+def _patched_move_adapter_to_device_of_base_layer(func):
+    # TODO(aryan): This is really unsafe probably and may break things. It works for now, but revisit and refactor.
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with DisableTensorToDtype():
+            return func(self, *args, **kwargs)
+
+    return wrapper
diff --git a/finetrainers/patches/models/ltx_video/patch.py b/finetrainers/patches/models/ltx_video/patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..851da6e795ee779c4a145d3baf89e50bb001adec
--- /dev/null
+++ b/finetrainers/patches/models/ltx_video/patch.py
@@ -0,0 +1,127 @@
+from typing import Any, Dict, Optional, Tuple
+
+import diffusers
+import torch
+from diffusers import LTXVideoTransformer3DModel
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.utils.import_utils import is_torch_version
+
+
+def patch_transformer_forward() -> None:
+    _perform_ltx_transformer_forward_patch()
+
+
+def patch_apply_rotary_emb_for_tp_compatibility() -> None:
+    _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch()
+
+
+def _perform_ltx_transformer_forward_patch() -> None:
+    LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3Dforward
+
+
+def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
+    def apply_rotary_emb(x, freqs):
+        cos, sin = freqs
+        # ======== THIS IS CHANGED FROM THE ORIGINAL IMPLEMENTATION ========
+        # The change is made due to unsupported DTensor operation aten.ops.unbind
+        # FIXME: Once aten.ops.unbind support lands, this will no longer be required
+        # x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, H, D // 2]
+        x_real, x_imag = x.unflatten(2, (-1, 2)).chunk(2, dim=-1)  # [B, S, H, D // 2]
+        # ==================================================================
+        x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
+        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+        return out
+
+    diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
+
+
+def _patched_LTXVideoTransformer3Dforward(
+    self,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    timestep: torch.LongTensor,
+    encoder_attention_mask: torch.Tensor,
+    num_frames: int,
+    height: int,
+    width: int,
+    rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
+    return_dict: bool = True,
+    *args,
+    **kwargs,
+) -> torch.Tensor:
+    image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
+
+    # convert encoder_attention_mask to a bias the same way we do for attention_mask
+    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+    batch_size = hidden_states.size(0)
+
+    # ===== This is modified compared to Diffusers =====
+    # This is done because the Diffusers pipeline will pass in a 1D tensor for timestep
+    if timestep.ndim == 1:
+        timestep = timestep.view(-1, 1, 1).expand(-1, *hidden_states.shape[1:-1], -1)
+    # ==================================================
+
+    temb, embedded_timestep = self.time_embed(
+        timestep.flatten(),
+        batch_size=batch_size,
+        hidden_dtype=hidden_states.dtype,
+    )
+
+    # ===== This is modified compared to Diffusers =====
+    # The original implementation flattens the token dimension:
+    #     temb = temb.view(batch_size, -1, temb.size(-1))
+    #     embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
+    # Here, the token dimensions of `hidden_states` are kept so that per-token timestep embeddings can be used.
+    temb = temb.view(batch_size, *hidden_states.shape[1:-1], temb.size(-1))
+    embedded_timestep = embedded_timestep.view(batch_size, *hidden_states.shape[1:-1], embedded_timestep.size(-1))
+    # ==================================================
+
+    hidden_states = self.proj_in(hidden_states)
+
+    encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+    encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
+
+    for block in self.transformer_blocks:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+            hidden_states = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(block),
+                hidden_states,
+                encoder_hidden_states,
+                temb,
+                image_rotary_emb,
+                encoder_attention_mask,
+                **ckpt_kwargs,
+            )
+        else:
+            hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                encoder_attention_mask=encoder_attention_mask,
+            )
+
+    scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+    shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
+
+    hidden_states = self.norm_out(hidden_states)
+    hidden_states = hidden_states * (1 + scale) + shift
+    output = self.proj_out(hidden_states)
+
+    if not return_dict:
+        return (output,)
+    return Transformer2DModelOutput(sample=output)
diff --git a/finetrainers/patches/utils.py b/finetrainers/patches/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d7f4726cc8183a461310570762ee95b5c4e6187
--- /dev/null
+++ b/finetrainers/patches/utils.py
@@ -0,0 +1,18 @@
+import torch
+
+
+class DisableTensorToDtype:
+    def __enter__(self):
+        self.original_to = torch.Tensor.to
+
+        def modified_to(tensor, *args, **kwargs):
+            # remove dtype from args if present
+            args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
+            if "dtype" in kwargs:
+                kwargs.pop("dtype")
+            return self.original_to(tensor, *args, **kwargs)
+
+        torch.Tensor.to = modified_to
+
+    def __exit__(self, *args, **kwargs):
+        torch.Tensor.to = self.original_to
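+
+
+# Illustrative behavior (a sketch, not part of the module's API): inside the context manager, dtype arguments to
+# `Tensor.to` are dropped, so only device moves take effect. Assuming `weight` is an fp32 tensor:
+#
+#   with DisableTensorToDtype():
+#       moved = weight.to("cuda", dtype=torch.bfloat16)  # dtype is ignored; `moved` stays in fp32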
diff --git a/finetrainers/processors/__init__.py b/finetrainers/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e55b3d14743fc01128d07ffb115fc84e98cc6eee
--- /dev/null
+++ b/finetrainers/processors/__init__.py
@@ -0,0 +1,5 @@
+from .base import ProcessorMixin
+from .clip import CLIPPooledProcessor
+from .llama import LlamaProcessor
+from .t5 import T5Processor
+from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor
diff --git a/finetrainers/processors/base.py b/finetrainers/processors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..9853ead0ef49610bdfed05ff067918bf80558109
--- /dev/null
+++ b/finetrainers/processors/base.py
@@ -0,0 +1,20 @@
+import inspect
+from typing import Any, Dict, List
+
+
+class ProcessorMixin:
+    def __init__(self) -> None:
+        self._forward_parameter_names = inspect.signature(self.forward).parameters.keys()
+        self.output_names: List[str] = None
+        self.input_names: Dict[str, Any] = None
+
+    def __call__(self, *args, **kwargs) -> Any:
+        shallow_copy_kwargs = dict(kwargs.items())
+        if self.input_names is not None:
+            for k, v in self.input_names.items():
+                shallow_copy_kwargs[v] = shallow_copy_kwargs.pop(k)
+        acceptable_kwargs = {k: v for k, v in shallow_copy_kwargs.items() if k in self._forward_parameter_names}
+        return self.forward(*args, **acceptable_kwargs)
+
+    def forward(self, *args, **kwargs) -> Any:
+        raise NotImplementedError("ProcessorMixin::forward method should be implemented by the subclass.")
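+
+
+# Illustrative behavior (a sketch; the subclass and key names below are assumptions): `input_names` remaps incoming
+# keyword names to the names expected by `forward`, and any extra kwargs are silently dropped.
+#
+#   class UpperCaseProcessor(ProcessorMixin):
+#       def forward(self, caption):
+#           return {"caption": caption.upper()}
+#
+#   proc = UpperCaseProcessor()
+#   proc.input_names = {"prompt": "caption"}
+#   proc(prompt="a cat", unused="ignored")  # -> {"caption": "A CAT"}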
diff --git a/finetrainers/processors/clip.py b/finetrainers/processors/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..178addf8b2556e7c1ae952084790f2575d27f007
--- /dev/null
+++ b/finetrainers/processors/clip.py
@@ -0,0 +1,65 @@
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTokenizerFast
+
+from .base import ProcessorMixin
+
+
+class CLIPPooledProcessor(ProcessorMixin):
+    r"""
+    Processor for the CLIP family of models. This processor is used to encode text inputs and return the pooled
+    embeddings for the input text.
+
+    Args:
+        output_names (`List[str]`):
+            The names of the outputs that the processor should return. The only output is the pooled embedding of
+            the input text.
+    """
+
+    def __init__(self, output_names: List[str] = None, input_names: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__()
+
+        self.output_names = output_names
+        self.input_names = input_names
+
+        assert len(output_names) == 1
+        if input_names is not None:
+            assert len(input_names) <= 3
+
+    def forward(
+        self,
+        tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
+        text_encoder: CLIPTextModel,
+        caption: Union[str, List[str]],
+    ) -> Dict[str, torch.Tensor]:
+        r"""
+        Encode the input text and return the pooled embeddings for the input text.
+
+        Args:
+            tokenizer (`Union[CLIPTokenizer, CLIPTokenizerFast]`):
+                The tokenizer used to tokenize the input text.
+            text_encoder (`CLIPTextModel`):
+                The text encoder used to encode the input text.
+            caption (`Union[str, List[str]]`):
+                The input text to be encoded.
+        """
+        if isinstance(caption, str):
+            caption = [caption]
+
+        device = text_encoder.device
+        dtype = text_encoder.dtype
+
+        text_inputs = tokenizer(
+            caption,
+            padding="max_length",
+            max_length=77,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids.to(device)
+
+        prompt_embeds = text_encoder(text_input_ids, output_hidden_states=False).pooler_output
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        return {self.output_names[0]: prompt_embeds}
diff --git a/finetrainers/processors/llama.py b/finetrainers/processors/llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..749e5f313541b92317279669faf915edeb9129c4
--- /dev/null
+++ b/finetrainers/processors/llama.py
@@ -0,0 +1,118 @@
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from transformers import LlamaModel, LlamaTokenizer, LlamaTokenizerFast
+
+from .base import ProcessorMixin
+
+
+DEFAULT_PROMPT_TEMPLATE = {
+    "template": (
+        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+        "1. The main content and theme of the video."
+        "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+        "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+        "4. background environment, light, style and atmosphere."
+        "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+        "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+    ),
+    "crop_start": 95,
+}
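+# NOTE: `crop_start` is the number of tokens occupied by the template prefix above. After encoding, these leading
+# template tokens are cropped from the hidden states so that only the caption portion of the embedding remains
+# (see `LlamaProcessor.forward` below).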
+
+
+class LlamaProcessor(ProcessorMixin):
+    r"""
+    Processor for the Llama family of models. This processor is used to encode text inputs and return the embeddings
+    and attention masks for the input text.
+
+    Args:
+        output_names (`List[str]`):
+            The names of the outputs that the processor should return. The first output is the embeddings of the input
+            text and the second output is the attention mask for the input text.
+    """
+
+    def __init__(self, output_names: List[str] = None):
+        super().__init__()
+
+        self.output_names = output_names
+
+        assert len(output_names) == 2
+
+    def forward(
+        self,
+        tokenizer: Union[LlamaTokenizer, LlamaTokenizerFast],
+        text_encoder: LlamaModel,
+        caption: Union[str, List[str]],
+        max_sequence_length: int,
+        prompt_template: Optional[Dict[str, Any]] = None,
+        num_layers_to_skip: int = 2,
+    ) -> Dict[str, torch.Tensor]:
+        r"""
+        Encode the input text and return the embeddings and attention mask for the input text.
+
+        Args:
+            tokenizer (`Union[LlamaTokenizer, LlamaTokenizerFast]`):
+                The tokenizer used to tokenize the input text.
+            text_encoder (`LlamaModel`):
+                The text encoder used to encode the input text.
+            caption (`Union[str, List[str]]`):
+                The input text to be encoded.
+            max_sequence_length (`int`):
+                The maximum sequence length of the input text.
+            prompt_template (`Optional[Dict[str, Any]]`):
+                The prompt template to be used to encode the input text.
+            num_layers_to_skip (`int`, defaults to `2`):
+                The number of final hidden layers to skip when selecting the hidden states used as embeddings.
+        """
+        if prompt_template is None:
+            prompt_template = DEFAULT_PROMPT_TEMPLATE
+        if isinstance(caption, str):
+            caption = [caption]
+
+        device = text_encoder.device
+        dtype = text_encoder.dtype
+
+        batch_size = len(caption)
+        caption = [prompt_template["template"].format(c) for c in caption]
+
+        crop_start = prompt_template.get("crop_start", None)
+        if crop_start is None:
+            prompt_template_input = tokenizer(
+                prompt_template["template"],
+                padding="max_length",
+                return_tensors="pt",
+                return_length=False,
+                return_overflowing_tokens=False,
+                return_attention_mask=False,
+            )
+            crop_start = prompt_template_input["input_ids"].shape[-1]
+            # Remove <|eot_id|> token and placeholder {}
+            crop_start -= 2
+
+        max_sequence_length += crop_start
+        text_inputs = tokenizer(
+            caption,
+            max_length=max_sequence_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_attention_mask=True,
+        )
+        text_input_ids = text_inputs.input_ids.to(device)
+        prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
+
+        prompt_embeds = text_encoder(
+            text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+        ).hidden_states[-(num_layers_to_skip + 1)]
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        if crop_start is not None and crop_start > 0:
+            prompt_embeds = prompt_embeds[:, crop_start:]
+            prompt_attention_mask = prompt_attention_mask[:, crop_start:]
+
+        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+
+        return {
+            self.output_names[0]: prompt_embeds,
+            self.output_names[1]: prompt_attention_mask,
+        }
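+
+
+# Illustrative construction (a sketch; the output key names are assumptions): the processor writes its results into
+# a dict keyed by `output_names`, e.g.
+#
+#   processor = LlamaProcessor(output_names=["encoder_hidden_states", "encoder_attention_mask"])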
diff --git a/finetrainers/processors/t5.py b/finetrainers/processors/t5.py
new file mode 100644
index 0000000000000000000000000000000000000000..96c2c194ca01635de404d618afffa93b88cdf953
--- /dev/null
+++ b/finetrainers/processors/t5.py
@@ -0,0 +1,73 @@
+from typing import Dict, List, Union
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer, T5TokenizerFast
+
+from .base import ProcessorMixin
+
+
+class T5Processor(ProcessorMixin):
+    r"""
+    Processor for the T5 family of models. This processor is used to encode text inputs and return the embeddings
+    and attention masks for the input text.
+
+    Args:
+        output_names (`List[str]`):
+            The names of the outputs that the processor should return. The first output is the embeddings of the input
+            text and the second output is the attention mask for the input text.
+    """
+
+    def __init__(self, output_names: List[str]):
+        super().__init__()
+
+        self.output_names = output_names
+
+        assert len(self.output_names) == 2
+
+    def forward(
+        self,
+        tokenizer: Union[T5Tokenizer, T5TokenizerFast],
+        text_encoder: T5EncoderModel,
+        caption: Union[str, List[str]],
+        max_sequence_length: int,
+    ) -> Dict[str, torch.Tensor]:
+        r"""
+        Encode the input text and return the embeddings and attention mask for the input text.
+
+        Args:
+            tokenizer (`Union[T5Tokenizer, T5TokenizerFast]`):
+                The tokenizer used to tokenize the input text.
+            text_encoder (`T5EncoderModel`):
+                The text encoder used to encode the input text.
+            caption (`Union[str, List[str]]`):
+                The input text to be encoded.
+            max_sequence_length (`int`):
+                The maximum sequence length of the input text.
+        """
+        if isinstance(caption, str):
+            caption = [caption]
+
+        device = text_encoder.device
+        dtype = text_encoder.dtype
+
+        batch_size = len(caption)
+        text_inputs = tokenizer(
+            caption,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        prompt_attention_mask = text_inputs.attention_mask
+        prompt_attention_mask = prompt_attention_mask.bool().to(device)
+
+        prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+
+        return {
+            self.output_names[0]: prompt_embeds,
+            self.output_names[1]: prompt_attention_mask,
+        }
diff --git a/finetrainers/processors/text.py b/finetrainers/processors/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..b51dca68214ba36c5813775f6dbc6d40592e9b3c
--- /dev/null
+++ b/finetrainers/processors/text.py
@@ -0,0 +1,22 @@
+from typing import List, Union
+
+import torch
+
+from .. import functional as FF
+from .base import ProcessorMixin
+
+
+class CaptionTextDropoutProcessor(ProcessorMixin):
+    def __init__(self, dropout_p: float = 0.0) -> None:
+        super().__init__()
+        self.dropout_p = dropout_p
+
+    def forward(self, caption: Union[str, List[str]]) -> Union[str, List[str]]:
+        return FF.dropout_caption(caption, self.dropout_p)
+
+
+class CaptionEmbeddingDropoutProcessor(ProcessorMixin):
+    def __init__(self, dropout_p: float = 0.0) -> None:
+        super().__init__()
+        self.dropout_p = dropout_p
+
+    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
+        return FF.dropout_embeddings_to_zero(embedding, self.dropout_p)
diff --git a/finetrainers/state.py b/finetrainers/state.py
index 15a92e23da840af7e6920d20ea6cd4252feb47ed..5cda7d91b3f0e82b493d7e88b3565b9df985a228 100644
--- a/finetrainers/state.py
+++ b/finetrainers/state.py
@@ -1,21 +1,66 @@
+import io
+from dataclasses import dataclass, field
+from typing import Any, Dict, List
+
 import torch
-from accelerate import Accelerator
+import torch.distributed.checkpoint.stateful
+
+from .parallel import ParallelBackendType
+from .utils import get_device_info
+
+
+_device_type, _ = get_device_info()
+
 
+@dataclass
+class TrainState(torch.distributed.checkpoint.stateful.Stateful):
+    step: int = 0
+    observed_data_samples: int = 0
+    observed_num_tokens: int = 0
+    global_avg_losses: List[float] = field(default_factory=list)
+    global_max_losses: List[float] = field(default_factory=list)
+    log_steps: List[int] = field(default_factory=list)
 
+    def state_dict(self) -> Dict[str, Any]:
+        # Only checkpoint global_avg_losses and global_max_losses per log frequency
+        # to avoid sync overhead in every iteration.
+        global_avg_losses_bytes = io.BytesIO()
+        torch.save(self.global_avg_losses, global_avg_losses_bytes)
+        global_max_losses_bytes = io.BytesIO()
+        torch.save(self.global_max_losses, global_max_losses_bytes)
+        log_steps_bytes = io.BytesIO()
+        torch.save(self.log_steps, log_steps_bytes)
+        return {
+            "step": torch.tensor(self.step, dtype=torch.int32),
+            "observed_data_samples": torch.tensor(self.observed_data_samples, dtype=torch.int32),
+            "observed_num_tokens": torch.tensor(self.observed_num_tokens, dtype=torch.int32),
+            "global_avg_losses": global_avg_losses_bytes,
+            "global_max_losses": global_max_losses_bytes,
+            "log_steps": log_steps_bytes,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        state_dict["global_avg_losses"].seek(0)
+        state_dict["global_max_losses"].seek(0)
+        state_dict["log_steps"].seek(0)
+
+        self.step = state_dict["step"].item()
+        self.observed_data_samples = state_dict["observed_data_samples"].item()
+        self.observed_num_tokens = state_dict["observed_num_tokens"].item()
+        self.global_avg_losses = torch.load(state_dict["global_avg_losses"], weights_only=False)
+        self.global_max_losses = torch.load(state_dict["global_max_losses"], weights_only=False)
+        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
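+
+
+# Illustrative usage (a sketch; the checkpoint wiring is an assumption): because `TrainState` implements the
+# `Stateful` protocol, it can be saved and restored alongside other state via torch.distributed.checkpoint:
+#
+#   import torch.distributed.checkpoint as dcp
+#   dcp.save({"train_state": train_state}, checkpoint_id="/path/to/checkpoints/step-1000")
+#   dcp.load({"train_state": train_state}, checkpoint_id="/path/to/checkpoints/step-1000")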
+
+
+@dataclass
 class State:
+    # Parallel state
+    parallel_backend: ParallelBackendType = None
+
     # Training state
-    seed: int = None
-    model_name: str = None
-    accelerator: Accelerator = None
-    weight_dtype: torch.dtype = None
-    train_epochs: int = None
-    train_steps: int = None
-    overwrote_max_train_steps: bool = False
+    train_state: TrainState = None
     num_trainable_parameters: int = 0
-    learning_rate: float = None
-    train_batch_size: int = None
     generator: torch.Generator = None
-    num_update_steps_per_epoch: int = None
 
     # Hub state
     repo_id: str = None
diff --git a/finetrainers/trackers.py b/finetrainers/trackers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a48716605e1ed2e39ab5d86cd39f72467496fd52
--- /dev/null
+++ b/finetrainers/trackers.py
@@ -0,0 +1,92 @@
+import pathlib
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
+
+from .logging import get_logger
+
+
+logger = get_logger()
+
+
+class BaseTracker:
+    r"""Base class for loggers. Does nothing by default, so it is useful when you want to disable logging."""
+
+    def log(self, metrics: Dict[str, Any], step: int) -> None:
+        pass
+
+    def finish(self) -> None:
+        pass
+
+
+class WandbTracker(BaseTracker):
+    r"""Logger implementation for Weights & Biases."""
+
+    def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None:
+        import wandb
+
+        self.wandb = wandb
+
+        # WandB does not create the log directory if it does not exist; it silently falls back to the system temp directory.
+        pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)
+
+        self.run = wandb.init(project=experiment_name, dir=log_dir, config=config)
+        logger.info("WandB logging enabled")
+
+    def log(self, metrics: Dict[str, Any], step: int) -> None:
+        self.run.log(metrics, step=step)
+
+    def finish(self) -> None:
+        self.run.finish()
+
+
+class SequentialTracker(BaseTracker):
+    r"""Sequential tracker that logs to multiple trackers in sequence."""
+
+    def __init__(self, trackers: List[BaseTracker]) -> None:
+        self.trackers = trackers
+
+    def log(self, metrics: Dict[str, Any], step: int) -> None:
+        for tracker in self.trackers:
+            tracker.log(metrics, step)
+
+    def finish(self) -> None:
+        for tracker in self.trackers:
+            tracker.finish()
+
+
+class Trackers(str, Enum):
+    r"""Enum for supported trackers."""
+
+    NONE = "none"
+    WANDB = "wandb"
+
+
+_SUPPORTED_TRACKERS = [tracker.value for tracker in Trackers.__members__.values()]
+
+
+def initialize_trackers(
+    trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
+) -> Union[BaseTracker, SequentialTracker]:
+    r"""Initialize loggers based on the provided configuration."""
+
+    logger.info(f"Initializing trackers: {trackers}. Logging to {log_dir=}")
+
+    if len(trackers) == 0:
+        return BaseTracker()
+
+    if any(tracker_name not in _SUPPORTED_TRACKERS for tracker_name in set(trackers)):
+        raise ValueError(f"Unsupported tracker(s) provided. Supported trackers: {_SUPPORTED_TRACKERS}")
+
+    tracker_instances = []
+    for tracker_name in set(trackers):
+        if tracker_name == Trackers.NONE:
+            tracker = BaseTracker()
+        elif tracker_name == Trackers.WANDB:
+            tracker = WandbTracker(experiment_name, log_dir, config)
+        tracker_instances.append(tracker)
+
+    tracker = SequentialTracker(tracker_instances)
+    return tracker
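+
+
+# Illustrative usage (a sketch; the argument values are assumptions): build a tracker from CLI-style settings and
+# log a scalar each optimization step.
+#
+#   tracker = initialize_trackers(["wandb"], experiment_name="finetrainers", config={"lr": 1e-4}, log_dir="logs")
+#   tracker.log({"train/loss": 0.123}, step=1)
+#   tracker.finish()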
+
+
+TrackerType = Union[BaseTracker, SequentialTracker, WandbTracker]
diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py
deleted file mode 100644
index a213c00a129a407bf304b51cb8ae0965afd54d98..0000000000000000000000000000000000000000
--- a/finetrainers/trainer.py
+++ /dev/null
@@ -1,1235 +0,0 @@
-import json
-import logging
-import math
-import os
-import gc
-import random
-from datetime import datetime, timedelta
-from pathlib import Path
-from typing import Any, Dict, List
-import resource
-import diffusers
-import torch
-import torch.backends
-import transformers
-import wandb
-from accelerate import Accelerator, DistributedType
-from accelerate.logging import get_logger
-from accelerate.utils import (
-    DistributedDataParallelKwargs,
-    InitProcessGroupKwargs,
-    ProjectConfiguration,
-    gather_object,
-    set_seed,
-)
-from diffusers import DiffusionPipeline
-from diffusers.configuration_utils import FrozenDict
-from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
-from diffusers.optimization import get_scheduler
-from diffusers.training_utils import cast_training_params
-from diffusers.utils import export_to_video, load_image, load_video
-from huggingface_hub import create_repo, upload_folder
-from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
-from tqdm import tqdm
-
-from .args import Args, validate_args
-from .constants import (
-    FINETRAINERS_LOG_LEVEL,
-    PRECOMPUTED_CONDITIONS_DIR_NAME,
-    PRECOMPUTED_DIR_NAME,
-    PRECOMPUTED_LATENTS_DIR_NAME,
-)
-from .dataset import BucketSampler, ImageOrVideoDatasetWithResizing, PrecomputedDataset
-from .hooks import apply_layerwise_upcasting
-from .models import get_config_from_model_name
-from .patches import perform_peft_patches
-from .state import State
-from .utils.checkpointing import get_intermediate_ckpt_path, get_latest_ckpt_path_to_resume_from
-from .utils.data_utils import should_perform_precomputation
-from .utils.diffusion_utils import (
-    get_scheduler_alphas,
-    get_scheduler_sigmas,
-    prepare_loss_weights,
-    prepare_sigmas,
-    prepare_target,
-)
-from .utils.file_utils import string_to_filename
-from .utils.hub_utils import save_model_card
-from .utils.memory_utils import free_memory, get_memory_statistics, make_contiguous
-from .utils.model_utils import resolve_vae_cls_from_ckpt_path
-from .utils.optimizer_utils import get_optimizer
-from .utils.torch_utils import align_device_and_dtype, expand_tensor_dims, unwrap_model
-
-
-logger = get_logger("finetrainers")
-logger.setLevel(FINETRAINERS_LOG_LEVEL)
-
-
-class Trainer:
-    def __init__(self, args: Args) -> None:
-        validate_args(args)
-
-        self.args = args
-        self.args.seed = self.args.seed or datetime.now().year
-        self.state = State()
-
-        # Tokenizers
-        self.tokenizer = None
-        self.tokenizer_2 = None
-        self.tokenizer_3 = None
-
-        # Text encoders
-        self.text_encoder = None
-        self.text_encoder_2 = None
-        self.text_encoder_3 = None
-
-        # Denoisers
-        self.transformer = None
-        self.unet = None
-
-        # Autoencoders
-        self.vae = None
-
-        # Scheduler
-        self.scheduler = None
-
-        self.transformer_config = None
-        self.vae_config = None
-
-        self._init_distributed()
-        self._init_logging()
-        self._init_directories_and_repositories()
-        self._init_config_options()
-
-        # Peform any patches needed for training
-        if len(self.args.layerwise_upcasting_modules) > 0:
-            perform_peft_patches()
-        # TODO(aryan): handle text encoders
-        # if any(["text_encoder" in component_name for component_name in self.args.layerwise_upcasting_modules]):
-        #     perform_text_encoder_patches()
-
-        self.state.model_name = self.args.model_name
-        self.model_config = get_config_from_model_name(self.args.model_name, self.args.training_type)
-
-    def prepare_dataset(self) -> None:
-        # TODO(aryan): Make a background process for fetching
-        logger.info("Initializing dataset and dataloader")
-
-        self.dataset = ImageOrVideoDatasetWithResizing(
-            data_root=self.args.data_root,
-            caption_column=self.args.caption_column,
-            video_column=self.args.video_column,
-            resolution_buckets=self.args.video_resolution_buckets,
-            dataset_file=self.args.dataset_file,
-            id_token=self.args.id_token,
-            remove_llm_prefixes=self.args.remove_common_llm_caption_prefixes,
-        )
-        self.dataloader = torch.utils.data.DataLoader(
-            self.dataset,
-            batch_size=1,
-            sampler=BucketSampler(self.dataset, batch_size=self.args.batch_size, shuffle=True),
-            collate_fn=self.model_config.get("collate_fn"),
-            num_workers=self.args.dataloader_num_workers,
-            pin_memory=self.args.pin_memory,
-        )
-
-    def prepare_models(self) -> None:
-        logger.info("Initializing models")
-
-        load_components_kwargs = self._get_load_components_kwargs()
-        condition_components, latent_components, diffusion_components = {}, {}, {}
-        if not self.args.precompute_conditions:
-            # To download the model files first on the main process (if not already present)
-            # and then load the cached files afterward from the other processes.
-            with self.state.accelerator.main_process_first():
-                condition_components = self.model_config["load_condition_models"](**load_components_kwargs)
-                latent_components = self.model_config["load_latent_models"](**load_components_kwargs)
-                diffusion_components = self.model_config["load_diffusion_models"](**load_components_kwargs)
-
-        components = {}
-        components.update(condition_components)
-        components.update(latent_components)
-        components.update(diffusion_components)
-        self._set_components(components)
-
-        if self.vae is not None:
-            if self.args.enable_slicing:
-                self.vae.enable_slicing()
-            if self.args.enable_tiling:
-                self.vae.enable_tiling()
-
-    def prepare_precomputations(self) -> None:
-        if not self.args.precompute_conditions:
-            return
-
-        logger.info("Initializing precomputations")
-
-        if self.args.batch_size != 1:
-            raise ValueError("Precomputation is only supported with batch size 1. This will be supported in future.")
-
-        def collate_fn(batch):
-            latent_conditions = [x["latent_conditions"] for x in batch]
-            text_conditions = [x["text_conditions"] for x in batch]
-            batched_latent_conditions = {}
-            batched_text_conditions = {}
-            for key in list(latent_conditions[0].keys()):
-                if torch.is_tensor(latent_conditions[0][key]):
-                    batched_latent_conditions[key] = torch.cat([x[key] for x in latent_conditions], dim=0)
-                else:
-                    # TODO(aryan): implement batch sampler for precomputed latents
-                    batched_latent_conditions[key] = [x[key] for x in latent_conditions][0]
-            for key in list(text_conditions[0].keys()):
-                if torch.is_tensor(text_conditions[0][key]):
-                    batched_text_conditions[key] = torch.cat([x[key] for x in text_conditions], dim=0)
-                else:
-                    # TODO(aryan): implement batch sampler for precomputed latents
-                    batched_text_conditions[key] = [x[key] for x in text_conditions][0]
-            return {"latent_conditions": batched_latent_conditions, "text_conditions": batched_text_conditions}
-
-        cleaned_model_id = string_to_filename(self.args.pretrained_model_name_or_path)
-        precomputation_dir = (
-            Path(self.args.data_root) / f"{self.args.model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
-        )
-        should_precompute = should_perform_precomputation(precomputation_dir)
-        if not should_precompute:
-            logger.info("Precomputed conditions and latents found. Loading precomputed data.")
-            self.dataloader = torch.utils.data.DataLoader(
-                PrecomputedDataset(
-                    data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
-                ),
-                batch_size=self.args.batch_size,
-                shuffle=True,
-                collate_fn=collate_fn,
-                num_workers=self.args.dataloader_num_workers,
-                pin_memory=self.args.pin_memory,
-            )
-            return
-
-        logger.info("Precomputed conditions and latents not found. Running precomputation.")
-
-        # At this point, no models are loaded, so we need to load and precompute conditions and latents
-        with self.state.accelerator.main_process_first():
-            condition_components = self.model_config["load_condition_models"](**self._get_load_components_kwargs())
-        self._set_components(condition_components)
-        self._move_components_to_device()
-        self._disable_grad_for_components([self.text_encoder, self.text_encoder_2, self.text_encoder_3])
-
-        if self.args.caption_dropout_p > 0 and self.args.caption_dropout_technique == "empty":
-            logger.warning(
-                "Caption dropout is not supported with precomputation yet. This will be supported in the future."
-            )
-
-        conditions_dir = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
-        latents_dir = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
-        conditions_dir.mkdir(parents=True, exist_ok=True)
-        latents_dir.mkdir(parents=True, exist_ok=True)
-
-        accelerator = self.state.accelerator
-
-        # Precompute conditions
-        progress_bar = tqdm(
-            range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
-            desc="Precomputing conditions",
-            disable=not accelerator.is_local_main_process,
-        )
-        index = 0
-        for i, data in enumerate(self.dataset):
-            if i % accelerator.num_processes != accelerator.process_index:
-                continue
-
-            logger.debug(
-                f"Precomputing conditions for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
-            )
-
-            text_conditions = self.model_config["prepare_conditions"](
-                tokenizer=self.tokenizer,
-                tokenizer_2=self.tokenizer_2,
-                tokenizer_3=self.tokenizer_3,
-                text_encoder=self.text_encoder,
-                text_encoder_2=self.text_encoder_2,
-                text_encoder_3=self.text_encoder_3,
-                prompt=data["prompt"],
-                device=accelerator.device,
-                dtype=self.args.transformer_dtype,
-            )
-            filename = conditions_dir / f"conditions-{accelerator.process_index}-{index}.pt"
-            torch.save(text_conditions, filename.as_posix())
-            index += 1
-            progress_bar.update(1)
-        self._delete_components()
-
-        memory_statistics = get_memory_statistics()
-        logger.info(f"Memory after precomputing conditions: {json.dumps(memory_statistics, indent=4)}")
-        torch.cuda.reset_peak_memory_stats(accelerator.device)
-
-        # Precompute latents
-        with self.state.accelerator.main_process_first():
-            latent_components = self.model_config["load_latent_models"](**self._get_load_components_kwargs())
-        self._set_components(latent_components)
-        self._move_components_to_device()
-        self._disable_grad_for_components([self.vae])
-
-        if self.vae is not None:
-            if self.args.enable_slicing:
-                self.vae.enable_slicing()
-            if self.args.enable_tiling:
-                self.vae.enable_tiling()
-
-        progress_bar = tqdm(
-            range(0, (len(self.dataset) + accelerator.num_processes - 1) // accelerator.num_processes),
-            desc="Precomputing latents",
-            disable=not accelerator.is_local_main_process,
-        )
-        index = 0
-        for i, data in enumerate(self.dataset):
-            if i % accelerator.num_processes != accelerator.process_index:
-                continue
-
-            logger.debug(
-                f"Precomputing latents for batch {i + 1}/{len(self.dataset)} on process {accelerator.process_index}"
-            )
-
-            latent_conditions = self.model_config["prepare_latents"](
-                vae=self.vae,
-                image_or_video=data["video"].unsqueeze(0),
-                device=accelerator.device,
-                dtype=self.args.transformer_dtype,
-                generator=self.state.generator,
-                precompute=True,
-            )
-            filename = latents_dir / f"latents-{accelerator.process_index}-{index}.pt"
-            torch.save(latent_conditions, filename.as_posix())
-            index += 1
-            progress_bar.update(1)
-        self._delete_components()
-
-        accelerator.wait_for_everyone()
-        logger.info("Precomputation complete")
-
-        memory_statistics = get_memory_statistics()
-        logger.info(f"Memory after precomputing latents: {json.dumps(memory_statistics, indent=4)}")
-        torch.cuda.reset_peak_memory_stats(accelerator.device)
-
-        # Update dataloader to use precomputed conditions and latents
-        self.dataloader = torch.utils.data.DataLoader(
-            PrecomputedDataset(
-                data_root=self.args.data_root, model_name=self.args.model_name, cleaned_model_id=cleaned_model_id
-            ),
-            batch_size=self.args.batch_size,
-            shuffle=True,
-            collate_fn=collate_fn,
-            num_workers=self.args.dataloader_num_workers,
-            pin_memory=self.args.pin_memory,
-        )
-
-    def prepare_trainable_parameters(self) -> None:
-        logger.info("Initializing trainable parameters")
-
-        with self.state.accelerator.main_process_first():
-            diffusion_components = self.model_config["load_diffusion_models"](**self._get_load_components_kwargs())
-        self._set_components(diffusion_components)
-
-        components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.vae]
-        self._disable_grad_for_components(components)
-
-        if self.args.training_type == "full-finetune":
-            logger.info("Finetuning transformer with no additional parameters")
-            self._enable_grad_for_components([self.transformer])
-        else:
-            logger.info("Finetuning transformer with PEFT parameters")
-            self._disable_grad_for_components([self.transformer])
-
-        # Layerwise upcasting must be applied before adding the LoRA adapter.
-        # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on
-        # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly.
-        if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules:
-            apply_layerwise_upcasting(
-                self.transformer,
-                storage_dtype=self.args.layerwise_upcasting_storage_dtype,
-                compute_dtype=self.args.transformer_dtype,
-                skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
-                non_blocking=True,
-            )
-
-        self._move_components_to_device()
-
-        if self.args.gradient_checkpointing:
-            self.transformer.enable_gradient_checkpointing()
-
-        if self.args.training_type == "lora":
-            transformer_lora_config = LoraConfig(
-                r=self.args.rank,
-                lora_alpha=self.args.lora_alpha,
-                init_lora_weights=True,
-                target_modules=self.args.target_modules,
-            )
-            self.transformer.add_adapter(transformer_lora_config)
-        else:
-            transformer_lora_config = None
-
-        # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32
-        # even if layerwise upcasting. Would be nice to have a test as well
-
-        self.register_saving_loading_hooks(transformer_lora_config)
-
-    def register_saving_loading_hooks(self, transformer_lora_config):
-        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
-        def save_model_hook(models, weights, output_dir):
-            if self.state.accelerator.is_main_process:
-                transformer_lora_layers_to_save = None
-
-                for model in models:
-                    if isinstance(
-                        unwrap_model(self.state.accelerator, model),
-                        type(unwrap_model(self.state.accelerator, self.transformer)),
-                    ):
-                        model = unwrap_model(self.state.accelerator, model)
-                        if self.args.training_type == "lora":
-                            transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                    else:
-                        raise ValueError(f"Unexpected save model: {model.__class__}")
-
-                    # make sure to pop weight so that corresponding model is not saved again
-                    if weights:
-                        weights.pop()
-
-                if self.args.training_type == "lora":
-                    self.model_config["pipeline_cls"].save_lora_weights(
-                        output_dir,
-                        transformer_lora_layers=transformer_lora_layers_to_save,
-                    )
-                else:
-                    model.save_pretrained(os.path.join(output_dir, "transformer"))
-
-                    # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
-                    # to able to load all diffusion components from a specific checkpoint folder during validation, we need to
-                    # ensure the scheduler config is serialized as well.
-                    self.scheduler.save_pretrained(os.path.join(output_dir, "scheduler"))
-
-        def load_model_hook(models, input_dir):
-            if not self.state.accelerator.distributed_type == DistributedType.DEEPSPEED:
-                while len(models) > 0:
-                    model = models.pop()
-                    if isinstance(
-                        unwrap_model(self.state.accelerator, model),
-                        type(unwrap_model(self.state.accelerator, self.transformer)),
-                    ):
-                        transformer_ = unwrap_model(self.state.accelerator, model)
-                    else:
-                        raise ValueError(
-                            f"Unexpected save model: {unwrap_model(self.state.accelerator, model).__class__}"
-                        )
-            else:
-                transformer_cls_ = unwrap_model(self.state.accelerator, self.transformer).__class__
-
-                if self.args.training_type == "lora":
-                    transformer_ = transformer_cls_.from_pretrained(
-                        self.args.pretrained_model_name_or_path, subfolder="transformer"
-                    )
-                    transformer_.add_adapter(transformer_lora_config)
-                    lora_state_dict = self.model_config["pipeline_cls"].lora_state_dict(input_dir)
-                    transformer_state_dict = {
-                        f'{k.replace("transformer.", "")}': v
-                        for k, v in lora_state_dict.items()
-                        if k.startswith("transformer.")
-                    }
-                    incompatible_keys = set_peft_model_state_dict(
-                        transformer_, transformer_state_dict, adapter_name="default"
-                    )
-                    if incompatible_keys is not None:
-                        # check only for unexpected keys
-                        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                        if unexpected_keys:
-                            logger.warning(
-                                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                                f" {unexpected_keys}. "
-                            )
-                else:
-                    transformer_ = transformer_cls_.from_pretrained(os.path.join(input_dir, "transformer"))
-
-        self.state.accelerator.register_save_state_pre_hook(save_model_hook)
-        self.state.accelerator.register_load_state_pre_hook(load_model_hook)
-
-    def prepare_optimizer(self) -> None:
-        logger.info("Initializing optimizer and lr scheduler")
-
-        self.state.train_epochs = self.args.train_epochs
-        self.state.train_steps = self.args.train_steps
-
-        # Make sure the trainable params are in float32
-        if self.args.training_type == "lora":
-            cast_training_params([self.transformer], dtype=torch.float32)
-
-        self.state.learning_rate = self.args.lr
-        if self.args.scale_lr:
-            self.state.learning_rate = (
-                self.state.learning_rate
-                * self.args.gradient_accumulation_steps
-                * self.args.batch_size
-                * self.state.accelerator.num_processes
-            )
-
-        transformer_trainable_parameters = list(filter(lambda p: p.requires_grad, self.transformer.parameters()))
-        transformer_parameters_with_lr = {
-            "params": transformer_trainable_parameters,
-            "lr": self.state.learning_rate,
-        }
-        params_to_optimize = [transformer_parameters_with_lr]
-        self.state.num_trainable_parameters = sum(p.numel() for p in transformer_trainable_parameters)
-
-        use_deepspeed_opt = (
-            self.state.accelerator.state.deepspeed_plugin is not None
-            and "optimizer" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
-        )
-        optimizer = get_optimizer(
-            params_to_optimize=params_to_optimize,
-            optimizer_name=self.args.optimizer,
-            learning_rate=self.state.learning_rate,
-            beta1=self.args.beta1,
-            beta2=self.args.beta2,
-            beta3=self.args.beta3,
-            epsilon=self.args.epsilon,
-            weight_decay=self.args.weight_decay,
-            use_8bit=self.args.use_8bit_bnb,
-            use_deepspeed=use_deepspeed_opt,
-        )
-
-        num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
-        if self.state.train_steps is None:
-            self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
-            self.state.overwrote_max_train_steps = True
-
-        use_deepspeed_lr_scheduler = (
-            self.state.accelerator.state.deepspeed_plugin is not None
-            and "scheduler" in self.state.accelerator.state.deepspeed_plugin.deepspeed_config
-        )
-        total_training_steps = self.state.train_steps * self.state.accelerator.num_processes
-        num_warmup_steps = self.args.lr_warmup_steps * self.state.accelerator.num_processes
-
-        if use_deepspeed_lr_scheduler:
-            from accelerate.utils import DummyScheduler
-
-            lr_scheduler = DummyScheduler(
-                name=self.args.lr_scheduler,
-                optimizer=optimizer,
-                total_num_steps=total_training_steps,
-                num_warmup_steps=num_warmup_steps,
-            )
-        else:
-            lr_scheduler = get_scheduler(
-                name=self.args.lr_scheduler,
-                optimizer=optimizer,
-                num_warmup_steps=num_warmup_steps,
-                num_training_steps=total_training_steps,
-                num_cycles=self.args.lr_num_cycles,
-                power=self.args.lr_power,
-            )
-
-        self.optimizer = optimizer
-        self.lr_scheduler = lr_scheduler
-
-    def prepare_for_training(self) -> None:
-        self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = self.state.accelerator.prepare(
-            self.transformer, self.optimizer, self.dataloader, self.lr_scheduler
-        )
-
-        # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-        num_update_steps_per_epoch = math.ceil(len(self.dataloader) / self.args.gradient_accumulation_steps)
-        if self.state.overwrote_max_train_steps:
-            self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
-        # Afterwards we recalculate our number of training epochs
-        self.state.train_epochs = math.ceil(self.state.train_steps / num_update_steps_per_epoch)
-        self.state.num_update_steps_per_epoch = num_update_steps_per_epoch
-
-    def prepare_trackers(self) -> None:
-        logger.info("Initializing trackers")
-
-        tracker_name = self.args.tracker_name or "finetrainers-experiment"
-        self.state.accelerator.init_trackers(tracker_name, config=self._get_training_info())
-
-    def train(self) -> None:
-        logger.info("Starting training")
-
-
-        # Add these lines at the beginning
-        if hasattr(resource, 'RLIMIT_NOFILE'):
-            try:
-                soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-                logger.info(f"Current file descriptor limits in trainer: soft={soft}, hard={hard}")
-                # Try to increase to hard limit if possible
-                if soft < hard:
-                    resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
-                    new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-                    logger.info(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
-            except Exception as e:
-                logger.warning(f"Could not check or update file descriptor limits: {e}")
-        
-        memory_statistics = get_memory_statistics()
-        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
-
-        if self.vae_config is None:
-            # If we've precomputed conditions and latents already, and are now re-using it, we will never load
-            # the VAE so self.vae_config will not be set. So, we need to load it here.
-            vae_cls = resolve_vae_cls_from_ckpt_path(
-                self.args.pretrained_model_name_or_path, revision=self.args.revision, cache_dir=self.args.cache_dir
-            )
-            vae_config = vae_cls.load_config(
-                self.args.pretrained_model_name_or_path,
-                subfolder="vae",
-                revision=self.args.revision,
-                cache_dir=self.args.cache_dir,
-            )
-            self.vae_config = FrozenDict(**vae_config)
-
-        # In some cases, the scheduler needs to be loaded with specific config (e.g. in CogVideoX). Since we need
-        # to able to load all diffusion components from a specific checkpoint folder during validation, we need to
-        # ensure the scheduler config is serialized as well.
-        if self.args.training_type == "full-finetune":
-            self.scheduler.save_pretrained(os.path.join(self.args.output_dir, "scheduler"))
-
-        self.state.train_batch_size = (
-            self.args.batch_size * self.state.accelerator.num_processes * self.args.gradient_accumulation_steps
-        )
-        info = {
-            "trainable parameters": self.state.num_trainable_parameters,
-            "total samples": len(self.dataset),
-            "train epochs": self.state.train_epochs,
-            "train steps": self.state.train_steps,
-            "batches per device": self.args.batch_size,
-            "total batches observed per epoch": len(self.dataloader),
-            "train batch size": self.state.train_batch_size,
-            "gradient accumulation steps": self.args.gradient_accumulation_steps,
-        }
-        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
-
-        global_step = 0
-        first_epoch = 0
-        initial_global_step = 0
-
-        # Potentially load in the weights and states from a previous save
-        (
-            resume_from_checkpoint_path,
-            initial_global_step,
-            global_step,
-            first_epoch,
-        ) = get_latest_ckpt_path_to_resume_from(
-            resume_from_checkpoint=self.args.resume_from_checkpoint,
-            num_update_steps_per_epoch=self.state.num_update_steps_per_epoch,
-            output_dir=self.args.output_dir,
-        )
-        if resume_from_checkpoint_path:
-            self.state.accelerator.load_state(resume_from_checkpoint_path)
-
-        progress_bar = tqdm(
-            range(0, self.state.train_steps),
-            initial=initial_global_step,
-            desc="Training steps",
-            disable=not self.state.accelerator.is_local_main_process,
-        )
-
-        accelerator = self.state.accelerator
-        generator = torch.Generator(device=accelerator.device)
-        if self.args.seed is not None:
-            generator = generator.manual_seed(self.args.seed)
-        self.state.generator = generator
-
-        scheduler_sigmas = get_scheduler_sigmas(self.scheduler)
-        scheduler_sigmas = (
-            scheduler_sigmas.to(device=accelerator.device, dtype=torch.float32)
-            if scheduler_sigmas is not None
-            else None
-        )
-        scheduler_alphas = get_scheduler_alphas(self.scheduler)
-        scheduler_alphas = (
-            scheduler_alphas.to(device=accelerator.device, dtype=torch.float32)
-            if scheduler_alphas is not None
-            else None
-        )
-
-        for epoch in range(first_epoch, self.state.train_epochs):
-            logger.debug(f"Starting epoch ({epoch + 1}/{self.state.train_epochs})")
-
-            self.transformer.train()
-            models_to_accumulate = [self.transformer]
-            epoch_loss = 0.0
-            num_loss_updates = 0
-
-            for step, batch in enumerate(self.dataloader):
-                logger.debug(f"Starting step {step + 1}")
-                logs = {}
-
-                with accelerator.accumulate(models_to_accumulate):
-                    if not self.args.precompute_conditions:
-                        videos = batch["videos"]
-                        prompts = batch["prompts"]
-                        batch_size = len(prompts)
-
-                        if self.args.caption_dropout_technique == "empty":
-                            if random.random() < self.args.caption_dropout_p:
-                                prompts = [""] * batch_size
-
-                        latent_conditions = self.model_config["prepare_latents"](
-                            vae=self.vae,
-                            image_or_video=videos,
-                            patch_size=self.transformer_config.patch_size,
-                            patch_size_t=self.transformer_config.patch_size_t,
-                            device=accelerator.device,
-                            dtype=self.args.transformer_dtype,
-                            generator=self.state.generator,
-                        )
-                        text_conditions = self.model_config["prepare_conditions"](
-                            tokenizer=self.tokenizer,
-                            text_encoder=self.text_encoder,
-                            tokenizer_2=self.tokenizer_2,
-                            text_encoder_2=self.text_encoder_2,
-                            prompt=prompts,
-                            device=accelerator.device,
-                            dtype=self.args.transformer_dtype,
-                        )
-                    else:
-                        latent_conditions = batch["latent_conditions"]
-                        text_conditions = batch["text_conditions"]
-                        latent_conditions["latents"] = DiagonalGaussianDistribution(
-                            latent_conditions["latents"]
-                        ).sample(self.state.generator)
-
-                        # This method should only be called for precomputed latents.
-                        # TODO(aryan): rename this in separate PR
-                        latent_conditions = self.model_config["post_latent_preparation"](
-                            vae_config=self.vae_config,
-                            patch_size=self.transformer_config.patch_size,
-                            patch_size_t=self.transformer_config.patch_size_t,
-                            **latent_conditions,
-                        )
-                        align_device_and_dtype(latent_conditions, accelerator.device, self.args.transformer_dtype)
-                        align_device_and_dtype(text_conditions, accelerator.device, self.args.transformer_dtype)
-                        batch_size = latent_conditions["latents"].shape[0]
-
-                    latent_conditions = make_contiguous(latent_conditions)
-                    text_conditions = make_contiguous(text_conditions)
-
-                    if self.args.caption_dropout_technique == "zero":
-                        if random.random() < self.args.caption_dropout_p:
-                            text_conditions["prompt_embeds"].fill_(0)
-                            text_conditions["prompt_attention_mask"].fill_(False)
-
-                            # TODO(aryan): refactor later
-                            if "pooled_prompt_embeds" in text_conditions:
-                                text_conditions["pooled_prompt_embeds"].fill_(0)
-
-                    sigmas = prepare_sigmas(
-                        scheduler=self.scheduler,
-                        sigmas=scheduler_sigmas,
-                        batch_size=batch_size,
-                        num_train_timesteps=self.scheduler.config.num_train_timesteps,
-                        flow_weighting_scheme=self.args.flow_weighting_scheme,
-                        flow_logit_mean=self.args.flow_logit_mean,
-                        flow_logit_std=self.args.flow_logit_std,
-                        flow_mode_scale=self.args.flow_mode_scale,
-                        device=accelerator.device,
-                        generator=self.state.generator,
-                    )
-                    timesteps = (sigmas * 1000.0).long()
-
-                    noise = torch.randn(
-                        latent_conditions["latents"].shape,
-                        generator=self.state.generator,
-                        device=accelerator.device,
-                        dtype=self.args.transformer_dtype,
-                    )
-                    sigmas = expand_tensor_dims(sigmas, ndim=noise.ndim)
-
-                    # TODO(aryan): We probably don't need calculate_noisy_latents because we can determine the type of
-                    # scheduler and calculate the noisy latents accordingly. Look into this later.
-                    if "calculate_noisy_latents" in self.model_config.keys():
-                        noisy_latents = self.model_config["calculate_noisy_latents"](
-                            scheduler=self.scheduler,
-                            noise=noise,
-                            latents=latent_conditions["latents"],
-                            timesteps=timesteps,
-                        )
-                    else:
-                        # Default to flow-matching noise addition
-                        noisy_latents = (1.0 - sigmas) * latent_conditions["latents"] + sigmas * noise
-                        noisy_latents = noisy_latents.to(latent_conditions["latents"].dtype)
-
-                    latent_conditions.update({"noisy_latents": noisy_latents})
-
-                    weights = prepare_loss_weights(
-                        scheduler=self.scheduler,
-                        alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
-                        sigmas=sigmas,
-                        flow_weighting_scheme=self.args.flow_weighting_scheme,
-                    )
-                    weights = expand_tensor_dims(weights, noise.ndim)
-
-                    pred = self.model_config["forward_pass"](
-                        transformer=self.transformer,
-                        scheduler=self.scheduler,
-                        timesteps=timesteps,
-                        **latent_conditions,
-                        **text_conditions,
-                    )
-                    target = prepare_target(
-                        scheduler=self.scheduler, noise=noise, latents=latent_conditions["latents"]
-                    )
-
-                    loss = weights.float() * (pred["latents"].float() - target.float()).pow(2)
-                    # Average loss across all but batch dimension
-                    loss = loss.mean(list(range(1, loss.ndim)))
-                    # Average loss across batch dimension
-                    loss = loss.mean()
-                    accelerator.backward(loss)
-
-                    if accelerator.sync_gradients:
-                        if accelerator.distributed_type == DistributedType.DEEPSPEED:
-                            grad_norm = self.transformer.get_global_grad_norm()
-                            # In some cases the grad norm may not return a float
-                            if torch.is_tensor(grad_norm):
-                                grad_norm = grad_norm.item()
-                        else:
-                            grad_norm = accelerator.clip_grad_norm_(
-                                self.transformer.parameters(), self.args.max_grad_norm
-                            )
-                            if torch.is_tensor(grad_norm):
-                                grad_norm = grad_norm.item()
-
-                        logs["grad_norm"] = grad_norm
-
-                    self.optimizer.step()
-                    self.lr_scheduler.step()
-                    self.optimizer.zero_grad()
-
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
-                    progress_bar.update(1)
-                    global_step += 1
-
-                    # Checkpointing
-                    if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
-                        if global_step % self.args.checkpointing_steps == 0:
-                            save_path = get_intermediate_ckpt_path(
-                                checkpointing_limit=self.args.checkpointing_limit,
-                                step=global_step,
-                                output_dir=self.args.output_dir,
-                            )
-                            accelerator.save_state(save_path)
-
-                    # Maybe run validation
-                    should_run_validation = (
-                        self.args.validation_every_n_steps is not None
-                        and global_step % self.args.validation_every_n_steps == 0
-                    )
-                    if should_run_validation:
-                        self.validate(global_step)
-
-                loss_item = loss.detach().item()
-                epoch_loss += loss_item
-                num_loss_updates += 1
-                logs["step_loss"] = loss_item
-                logs["lr"] = self.lr_scheduler.get_last_lr()[0]
-                progress_bar.set_postfix(logs)
-                accelerator.log(logs, step=global_step)
-
-                if global_step % 100 == 0:  # Every 100 steps
-                    # Force garbage collection to clean up any lingering resources
-                    gc.collect()
-
-                if global_step >= self.state.train_steps:
-                    break
-
-                
-
-            if num_loss_updates > 0:
-                epoch_loss /= num_loss_updates
-            accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
-            memory_statistics = get_memory_statistics()
-            logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
-
-            # Maybe run validation
-            should_run_validation = (
-                self.args.validation_every_n_epochs is not None
-                and (epoch + 1) % self.args.validation_every_n_epochs == 0
-            )
-            if should_run_validation:
-                self.validate(global_step)
-
-            if epoch % 3 == 0:  # Every 3 epochs
-                logger.info("Performing periodic resource cleanup")
-                free_memory()
-                gc.collect()
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize(accelerator.device)
-
-        accelerator.wait_for_everyone()
-        if accelerator.is_main_process:
-            transformer = unwrap_model(accelerator, self.transformer)
-
-            if self.args.training_type == "lora":
-                transformer_lora_layers = get_peft_model_state_dict(transformer)
-
-                self.model_config["pipeline_cls"].save_lora_weights(
-                    save_directory=self.args.output_dir,
-                    transformer_lora_layers=transformer_lora_layers,
-                )
-            else:
-                transformer.save_pretrained(os.path.join(self.args.output_dir, "transformer"))
-        accelerator.wait_for_everyone()
-        self.validate(step=global_step, final_validation=True)
-
-        if accelerator.is_main_process:
-            if self.args.push_to_hub:
-                upload_folder(
-                    repo_id=self.state.repo_id, folder_path=self.args.output_dir, ignore_patterns=["checkpoint-*"]
-                )
-
-        self._delete_components()
-        memory_statistics = get_memory_statistics()
-        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
-
-        accelerator.end_training()
-
-    def validate(self, step: int, final_validation: bool = False) -> None:
-        logger.info("Starting validation")
-
-        accelerator = self.state.accelerator
-        num_validation_samples = len(self.args.validation_prompts)
-
-        if num_validation_samples == 0:
-            logger.warning("No validation samples found. Skipping validation.")
-            if accelerator.is_main_process:
-                if self.args.push_to_hub:
-                    save_model_card(
-                        args=self.args,
-                        repo_id=self.state.repo_id,
-                        videos=None,
-                        validation_prompts=None,
-                    )
-            return
-
-        self.transformer.eval()
-
-        memory_statistics = get_memory_statistics()
-        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
-
-        pipeline = self._get_and_prepare_pipeline_for_validation(final_validation=final_validation)
-
-        all_processes_artifacts = []
-        prompts_to_filenames = {}
-        for i in range(num_validation_samples):
-            # Skip current validation on all processes but one
-            if i % accelerator.num_processes != accelerator.process_index:
-                continue
-
-            prompt = self.args.validation_prompts[i]
-            image = self.args.validation_images[i]
-            video = self.args.validation_videos[i]
-            height = self.args.validation_heights[i]
-            width = self.args.validation_widths[i]
-            num_frames = self.args.validation_num_frames[i]
-            frame_rate = self.args.validation_frame_rate
-            if image is not None:
-                image = load_image(image)
-            if video is not None:
-                video = load_video(video)
-
-            logger.debug(
-                f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
-                main_process_only=False,
-            )
-            validation_artifacts = self.model_config["validation"](
-                pipeline=pipeline,
-                prompt=prompt,
-                image=image,
-                video=video,
-                height=height,
-                width=width,
-                num_frames=num_frames,
-                frame_rate=frame_rate,
-                num_videos_per_prompt=self.args.num_validation_videos_per_prompt,
-                generator=torch.Generator(device=accelerator.device).manual_seed(
-                    self.args.seed if self.args.seed is not None else 0
-                ),
-                # todo support passing `fps` for supported pipelines.
-            )
-
-            prompt_filename = string_to_filename(prompt)[:25]
-            artifacts = {
-                "image": {"type": "image", "value": image},
-                "video": {"type": "video", "value": video},
-            }
-            for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
-                if artifact_value:
-                    artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
-            logger.debug(
-                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
-                main_process_only=False,
-            )
-
-            for index, (key, value) in enumerate(list(artifacts.items())):
-                artifact_type = value["type"]
-                artifact_value = value["value"]
-                if artifact_type not in ["image", "video"] or artifact_value is None:
-                    continue
-
-                extension = "png" if artifact_type == "image" else "mp4"
-                filename = "validation-" if not final_validation else "final-"
-                filename += f"{step}-{accelerator.process_index}-{index}-{prompt_filename}.{extension}"
-                if accelerator.is_main_process and extension == "mp4":
-                    prompts_to_filenames[prompt] = filename
-                filename = os.path.join(self.args.output_dir, filename)
-
-                if artifact_type == "image" and artifact_value:
-                    logger.debug(f"Saving image to {filename}")
-                    artifact_value.save(filename)
-                    artifact_value = wandb.Image(filename)
-                elif artifact_type == "video" and artifact_value:
-                    logger.debug(f"Saving video to {filename}")
-                    # TODO: this should be configurable here as well as in validation runs where we call the pipeline that has `fps`.
-                    export_to_video(artifact_value, filename, fps=frame_rate)
-                    artifact_value = wandb.Video(filename, caption=prompt)
-
-                all_processes_artifacts.append(artifact_value)
-
-        all_artifacts = gather_object(all_processes_artifacts)
-
-        if accelerator.is_main_process:
-            tracker_key = "final" if final_validation else "validation"
-            for tracker in accelerator.trackers:
-                if tracker.name == "wandb":
-                    artifact_log_dict = {}
-
-                    image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
-                    if len(image_artifacts) > 0:
-                        artifact_log_dict["images"] = image_artifacts
-                    video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
-                    if len(video_artifacts) > 0:
-                        artifact_log_dict["videos"] = video_artifacts
-                    tracker.log({tracker_key: artifact_log_dict}, step=step)
-
-            if self.args.push_to_hub and final_validation:
-                video_filenames = list(prompts_to_filenames.values())
-                prompts = list(prompts_to_filenames.keys())
-                save_model_card(
-                    args=self.args,
-                    repo_id=self.state.repo_id,
-                    videos=video_filenames,
-                    validation_prompts=prompts,
-                )
-
-        # Remove all hooks that might have been added during pipeline initialization to the models
-        pipeline.remove_all_hooks()
-        del pipeline
-
-        accelerator.wait_for_everyone()
-
-        free_memory()
-        memory_statistics = get_memory_statistics()
-        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
-        torch.cuda.reset_peak_memory_stats(accelerator.device)
-
-        if not final_validation:
-            self.transformer.train()
-
-    def evaluate(self) -> None:
-        raise NotImplementedError("Evaluation has not been implemented yet.")
-
-    def _init_distributed(self) -> None:
-        logging_dir = Path(self.args.output_dir, self.args.logging_dir)
-        project_config = ProjectConfiguration(project_dir=self.args.output_dir, logging_dir=logging_dir)
-        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
-        init_process_group_kwargs = InitProcessGroupKwargs(
-            backend="nccl", timeout=timedelta(seconds=self.args.nccl_timeout)
-        )
-        report_to = None if self.args.report_to.lower() == "none" else self.args.report_to
-
-        accelerator = Accelerator(
-            project_config=project_config,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
-            log_with=report_to,
-            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
-        )
-
-        # Disable AMP for MPS.
-        if torch.backends.mps.is_available():
-            accelerator.native_amp = False
-
-        self.state.accelerator = accelerator
-
-        if self.args.seed is not None:
-            self.state.seed = self.args.seed
-            set_seed(self.args.seed)
-
-    def _init_logging(self) -> None:
-        logging.basicConfig(
-            format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-            datefmt="%m/%d/%Y %H:%M:%S",
-            level=FINETRAINERS_LOG_LEVEL,
-        )
-        if self.state.accelerator.is_local_main_process:
-            transformers.utils.logging.set_verbosity_warning()
-            diffusers.utils.logging.set_verbosity_info()
-        else:
-            transformers.utils.logging.set_verbosity_error()
-            diffusers.utils.logging.set_verbosity_error()
-
-        logger.info("Initialized FineTrainers")
-        logger.info(self.state.accelerator.state, main_process_only=False)
-
-    def _init_directories_and_repositories(self) -> None:
-        if self.state.accelerator.is_main_process:
-            self.args.output_dir = Path(self.args.output_dir)
-            self.args.output_dir.mkdir(parents=True, exist_ok=True)
-            self.state.output_dir = Path(self.args.output_dir)
-
-            if self.args.push_to_hub:
-                repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
-                self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id
-
-    def _init_config_options(self) -> None:
-        # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
-        if self.args.allow_tf32 and torch.cuda.is_available():
-            torch.backends.cuda.matmul.allow_tf32 = True
-
-    def _move_components_to_device(self):
-        if self.text_encoder is not None:
-            self.text_encoder = self.text_encoder.to(self.state.accelerator.device)
-        if self.text_encoder_2 is not None:
-            self.text_encoder_2 = self.text_encoder_2.to(self.state.accelerator.device)
-        if self.text_encoder_3 is not None:
-            self.text_encoder_3 = self.text_encoder_3.to(self.state.accelerator.device)
-        if self.transformer is not None:
-            self.transformer = self.transformer.to(self.state.accelerator.device)
-        if self.unet is not None:
-            self.unet = self.unet.to(self.state.accelerator.device)
-        if self.vae is not None:
-            self.vae = self.vae.to(self.state.accelerator.device)
-
-    def _get_load_components_kwargs(self) -> Dict[str, Any]:
-        load_component_kwargs = {
-            "text_encoder_dtype": self.args.text_encoder_dtype,
-            "text_encoder_2_dtype": self.args.text_encoder_2_dtype,
-            "text_encoder_3_dtype": self.args.text_encoder_3_dtype,
-            "transformer_dtype": self.args.transformer_dtype,
-            "vae_dtype": self.args.vae_dtype,
-            "shift": self.args.flow_shift,
-            "revision": self.args.revision,
-            "cache_dir": self.args.cache_dir,
-        }
-        if self.args.pretrained_model_name_or_path is not None:
-            load_component_kwargs["model_id"] = self.args.pretrained_model_name_or_path
-        return load_component_kwargs
-
-    def _set_components(self, components: Dict[str, Any]) -> None:
-        # Set models
-        self.tokenizer = components.get("tokenizer", self.tokenizer)
-        self.tokenizer_2 = components.get("tokenizer_2", self.tokenizer_2)
-        self.tokenizer_3 = components.get("tokenizer_3", self.tokenizer_3)
-        self.text_encoder = components.get("text_encoder", self.text_encoder)
-        self.text_encoder_2 = components.get("text_encoder_2", self.text_encoder_2)
-        self.text_encoder_3 = components.get("text_encoder_3", self.text_encoder_3)
-        self.transformer = components.get("transformer", self.transformer)
-        self.unet = components.get("unet", self.unet)
-        self.vae = components.get("vae", self.vae)
-        self.scheduler = components.get("scheduler", self.scheduler)
-
-        # Set configs
-        self.transformer_config = self.transformer.config if self.transformer is not None else self.transformer_config
-        self.vae_config = self.vae.config if self.vae is not None else self.vae_config
-
-    def _delete_components(self) -> None:
-        self.tokenizer = None
-        self.tokenizer_2 = None
-        self.tokenizer_3 = None
-        self.text_encoder = None
-        self.text_encoder_2 = None
-        self.text_encoder_3 = None
-        self.transformer = None
-        self.unet = None
-        self.vae = None
-        self.scheduler = None
-        free_memory()
-        torch.cuda.synchronize(self.state.accelerator.device)
-
-    def _get_and_prepare_pipeline_for_validation(self, final_validation: bool = False) -> DiffusionPipeline:
-        accelerator = self.state.accelerator
-        if not final_validation:
-            pipeline = self.model_config["initialize_pipeline"](
-                model_id=self.args.pretrained_model_name_or_path,
-                tokenizer=self.tokenizer,
-                text_encoder=self.text_encoder,
-                tokenizer_2=self.tokenizer_2,
-                text_encoder_2=self.text_encoder_2,
-                transformer=unwrap_model(accelerator, self.transformer),
-                vae=self.vae,
-                device=accelerator.device,
-                revision=self.args.revision,
-                cache_dir=self.args.cache_dir,
-                enable_slicing=self.args.enable_slicing,
-                enable_tiling=self.args.enable_tiling,
-                enable_model_cpu_offload=self.args.enable_model_cpu_offload,
-                is_training=True,
-            )
-        else:
-            self._delete_components()
-
-            # Load the transformer weights from the final checkpoint if performing full-finetune
-            transformer = None
-            if self.args.training_type == "full-finetune":
-                transformer = self.model_config["load_diffusion_models"](model_id=self.args.output_dir)["transformer"]
-
-            pipeline = self.model_config["initialize_pipeline"](
-                model_id=self.args.pretrained_model_name_or_path,
-                transformer=transformer,
-                device=accelerator.device,
-                revision=self.args.revision,
-                cache_dir=self.args.cache_dir,
-                enable_slicing=self.args.enable_slicing,
-                enable_tiling=self.args.enable_tiling,
-                enable_model_cpu_offload=self.args.enable_model_cpu_offload,
-                is_training=False,
-            )
-
-            # Load the LoRA weights if performing LoRA finetuning
-            if self.args.training_type == "lora":
-                pipeline.load_lora_weights(self.args.output_dir)
-
-        return pipeline
-
-    def _disable_grad_for_components(self, components: List[torch.nn.Module]):
-        for component in components:
-            if component is not None:
-                component.requires_grad_(False)
-
-    def _enable_grad_for_components(self, components: List[torch.nn.Module]):
-        for component in components:
-            if component is not None:
-                component.requires_grad_(True)
-
-    def _get_training_info(self) -> dict:
-        args = self.args.to_dict()
-
-        training_args = args.get("training_arguments", {})
-        training_type = training_args.get("training_type", "")
-
-        # LoRA/non-LoRA stuff.
-        if training_type == "full-finetune":
-            filtered_training_args = {
-                k: v for k, v in training_args.items() if k not in {"rank", "lora_alpha", "target_modules"}
-            }
-        else:
-            filtered_training_args = training_args
-
-        # Diffusion/flow stuff.
-        diffusion_args = args.get("diffusion_arguments", {})
-        scheduler_name = self.scheduler.__class__.__name__
-        if scheduler_name != "FlowMatchEulerDiscreteScheduler":
-            filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
-        else:
-            filtered_diffusion_args = diffusion_args
-
-        # Rest of the stuff.
-        updated_training_info = args.copy()
-        updated_training_info["training_arguments"] = filtered_training_args
-        updated_training_info["diffusion_arguments"] = filtered_diffusion_args
-        return updated_training_info
diff --git a/finetrainers/trainer/__init__.py b/finetrainers/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ba65c53fa75115084ceec6dabc2667a8a5d6a29
--- /dev/null
+++ b/finetrainers/trainer/__init__.py
@@ -0,0 +1 @@
+from .sft_trainer.trainer import SFTTrainer
diff --git a/finetrainers/trainer/config_utils.py b/finetrainers/trainer/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c354a30bac60265868c3ba8ef5313c12b7fc224
--- /dev/null
+++ b/finetrainers/trainer/config_utils.py
@@ -0,0 +1,17 @@
+import argparse
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from ..args import BaseArgs
+
+
+class ConfigMixin:
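+    r"""
+    Base interface for per-training-type configuration. Subclasses (for example, the SFT configs in
+    `sft_trainer/config.py`) register their CLI flags in `add_args`, copy the parsed values onto
+    `BaseArgs` in `map_args`, and sanity-check them in `validate_args`.
+    """
+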
+    def add_args(self, parser: argparse.ArgumentParser):
+        raise NotImplementedError("ConfigMixin::add_args should be implemented by subclasses.")
+
+    def validate_args(self, args: "BaseArgs"):
+        raise NotImplementedError("ConfigMixin::map_args should be implemented by subclasses.")
+
+    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
+        raise NotImplementedError("ConfigMixin::validate_args should be implemented by subclasses.")
diff --git a/finetrainers/trainer/sft_trainer/config.py b/finetrainers/trainer/sft_trainer/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..539426c7c9509c1b417d3471cc4391436f532cb7
--- /dev/null
+++ b/finetrainers/trainer/sft_trainer/config.py
@@ -0,0 +1,58 @@
+import argparse
+from typing import TYPE_CHECKING, List, Union
+
+from ..config_utils import ConfigMixin
+
+
+if TYPE_CHECKING:
+    from ...args import BaseArgs
+
+
+class SFTLowRankConfig(ConfigMixin):
+    r"""
+    Configuration class for SFT low rank training.
+
+    Args:
+        rank (int):
+            Rank of the low rank approximation.
+        lora_alpha (int):
+            The lora_alpha parameter to compute scaling factor (lora_alpha / rank) for low-rank matrices.
+        target_modules (`str` or `List[str]`):
+            Target modules for the low rank approximation. Can be a regex string or a list of regex strings.
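+
+    Example (CLI flags registered by `add_args` below; values are illustrative):
+        `--rank 32 --lora_alpha 32 --target_modules "to_q|to_k|to_v|to_out.0"`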
+    """
+
+    rank: int = 64
+    lora_alpha: int = 64
+    target_modules: Union[str, List[str]] = "(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)"
+
+    def add_args(self, parser: argparse.ArgumentParser):
+        parser.add_argument("--rank", type=int, default=64)
+        parser.add_argument("--lora_alpha", type=int, default=64)
+        parser.add_argument(
+            "--target_modules",
+            type=str,
+            nargs="+",
+            default=["(transformer_blocks|single_transformer_blocks).*(to_q|to_k|to_v|to_out.0)"],
+        )
+
+    def validate_args(self, args: "BaseArgs"):
+        assert self.rank > 0, "Rank must be a positive integer."
+        assert self.lora_alpha > 0, "lora_alpha must be a positive integer."
+
+    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
+        mapped_args.rank = argparse_args.rank
+        mapped_args.lora_alpha = argparse_args.lora_alpha
+        mapped_args.target_modules = (
+            argparse_args.target_modules[0] if len(argparse_args.target_modules) == 1 else argparse_args.target_modules
+        )
+
+
+class SFTFullRankConfig(ConfigMixin):
+    def add_args(self, parser: argparse.ArgumentParser):
+        pass
+
+    def validate_args(self, args: "BaseArgs"):
+        pass
+
+    def map_args(self, argparse_args: argparse.Namespace, mapped_args: "BaseArgs"):
+        pass
diff --git a/finetrainers/trainer/sft_trainer/trainer.py b/finetrainers/trainer/sft_trainer/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9523730d7e7c9af666b618dbfd05858122ff96b
--- /dev/null
+++ b/finetrainers/trainer/sft_trainer/trainer.py
@@ -0,0 +1,934 @@
+import functools
+import json
+import math
+import os
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
+
+import datasets.distributed
+import diffusers
+import torch
+import torch.backends
+import transformers
+import wandb
+from diffusers import DiffusionPipeline
+from diffusers.hooks import apply_layerwise_casting
+from diffusers.training_utils import cast_training_params
+from diffusers.utils import export_to_video
+from huggingface_hub import create_repo, upload_folder
+from peft import LoraConfig, get_peft_model_state_dict
+from tqdm import tqdm
+
+from ... import data, logging, optimizer, parallel, patches, utils
+from ...config import TrainingType
+from ...state import State, TrainState
+
+
+if TYPE_CHECKING:
+    from ...args import BaseArgs
+    from ...models import ModelSpecification
+
+
+logger = logging.get_logger()
+
+
+class SFTTrainer:
+    def __init__(self, args: "BaseArgs", model_specification: "ModelSpecification") -> None:
+        self.args = args
+        self.state = State()
+        self.state.train_state = TrainState()
+
+        # Tokenizers
+        self.tokenizer = None
+        self.tokenizer_2 = None
+        self.tokenizer_3 = None
+
+        # Text encoders
+        self.text_encoder = None
+        self.text_encoder_2 = None
+        self.text_encoder_3 = None
+
+        # Denoisers
+        self.transformer = None
+        self.unet = None
+
+        # Autoencoders
+        self.vae = None
+
+        # Scheduler
+        self.scheduler = None
+
+        # Optimizer & LR scheduler
+        self.optimizer = None
+        self.lr_scheduler = None
+
+        # Checkpoint manager
+        self.checkpointer = None
+
+        self._init_distributed()
+        self._init_config_options()
+
+        # Perform any patches that might be necessary for training to work as expected
+        patches.perform_patches_for_training(self.args, self.state.parallel_backend)
+
+        self.model_specification = model_specification
+
+    def run(self) -> None:
+        try:
+            self._prepare_models()
+            self._prepare_trainable_parameters()
+            self._prepare_for_training()
+            self._prepare_dataset()
+            self._prepare_checkpointing()
+            self._train()
+            # self._evaluate()
+        except Exception as e:
+            logger.error(f"Error during training: {e}")
+            self.state.parallel_backend.destroy()
+            raise e
+
+    def _prepare_models(self) -> None:
+        logger.info("Initializing models")
+
+        diffusion_components = self.model_specification.load_diffusion_models()
+        self._set_components(diffusion_components)
+
+        if self.state.parallel_backend.pipeline_parallel_enabled:
+            raise NotImplementedError(
+                "Pipeline parallelism is not supported yet. This will be supported in the future."
+            )
+
+    def _prepare_trainable_parameters(self) -> None:
+        logger.info("Initializing trainable parameters")
+
+        parallel_backend = self.state.parallel_backend
+
+        if self.args.training_type == TrainingType.FULL_FINETUNE:
+            logger.info("Finetuning transformer with no additional parameters")
+            utils.set_requires_grad([self.transformer], True)
+        else:
+            logger.info("Finetuning transformer with PEFT parameters")
+            utils.set_requires_grad([self.transformer], False)
+
+        # Layerwise upcasting must be applied before adding the LoRA adapter.
+        # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on
+        # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly.
+        if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules:
+            apply_layerwise_casting(
+                self.transformer,
+                storage_dtype=self.args.layerwise_upcasting_storage_dtype,
+                compute_dtype=self.args.transformer_dtype,
+                skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
+                non_blocking=True,
+            )
+
+        transformer_lora_config = None
+        if self.args.training_type == TrainingType.LORA:
+            transformer_lora_config = LoraConfig(
+                r=self.args.rank,
+                lora_alpha=self.args.lora_alpha,
+                init_lora_weights=True,
+                target_modules=self.args.target_modules,
+            )
+            self.transformer.add_adapter(transformer_lora_config)
+
+        # # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32
+        # # even if layerwise upcasting. Would be nice to have a test as well
+        # self.register_saving_loading_hooks(transformer_lora_config)
+
+        # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all
+        # parameters to be of the same dtype.
+        if self.args.training_type == TrainingType.LORA and not parallel_backend.data_sharding_enabled:
+            cast_training_params([self.transformer], dtype=torch.float32)
+
+    def _prepare_for_training(self) -> None:
+        # 1. Apply parallelism
+        parallel_backend = self.state.parallel_backend
+        world_mesh = parallel_backend.get_mesh()
+        model_specification = self.model_specification
+
+        if parallel_backend.context_parallel_enabled:
+            raise NotImplementedError(
+                "Context parallelism is not supported yet. This will be supported in the future."
+            )
+
+        if parallel_backend.tensor_parallel_enabled:
+            # TODO(aryan): handle fp8 from TorchAO here
+            model_specification.apply_tensor_parallel(
+                backend=parallel.ParallelBackendEnum.PTD,
+                device_mesh=parallel_backend.get_mesh()["tp"],
+                transformer=self.transformer,
+            )
+
+        # Enable gradient checkpointing
+        if self.args.gradient_checkpointing:
+            # TODO(aryan): support other checkpointing types
+            utils.apply_activation_checkpointing(self.transformer, checkpointing_type="full")
+
+        # Enable DDP, FSDP or HSDP
+        if parallel_backend.data_sharding_enabled:
+            # TODO(aryan): remove this when supported
+            if self.args.parallel_backend == "accelerate":
+                raise NotImplementedError("Data sharding is not supported with Accelerate yet.")
+
+            if parallel_backend.data_replication_enabled:
+                logger.info("Applying HSDP to the model")
+            else:
+                logger.info("Applying FSDP to the model")
+
+            # Apply FSDP or HSDP
+            if parallel_backend.data_replication_enabled or parallel_backend.context_parallel_enabled:
+                dp_mesh_names = ("dp_replicate", "dp_shard_cp")
+            else:
+                dp_mesh_names = ("dp_shard_cp",)
+
+            parallel.apply_fsdp2_ptd(
+                model=self.transformer,
+                dp_mesh=world_mesh[dp_mesh_names],
+                param_dtype=self.args.transformer_dtype,
+                reduce_dtype=torch.float32,
+                output_dtype=None,
+                pp_enabled=parallel_backend.pipeline_parallel_enabled,
+                cpu_offload=False,  # TODO(aryan): needs to be tested and allowed for enabling later
+            )
+        elif parallel_backend.data_replication_enabled:
+            logger.info("Applying DDP to the model")
+
+            if world_mesh.ndim > 1:
+                raise ValueError("DDP not supported for > 1D parallelism")
+
+            parallel_backend.apply_ddp(self.transformer, world_mesh)
+
+        self._move_components_to_device()
+
+        # 2. Prepare optimizer and lr scheduler
+        # For LoRA training, this could be more efficient. Currently, the OptimizerWrapper only accepts torch.nn.Module,
+        # so we loop over all parameters (even ones that don't require gradients, as in LoRA) at each optimizer step.
+        # This is OK (see https://github.com/pytorch/pytorch/blob/2f40f789dafeaa62c4e4b90dbf4a900ff6da2ca4/torch/optim/sgd.py#L85-L99)
+        # but could be optimized by creating a simple wrapper module encompassing only the parameters that require
+        # gradients. TODO(aryan): look into it in the future.
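+        #
+        # A minimal sketch of that idea (illustrative only, not wired in): a hypothetical wrapper
+        # module that exposes only the trainable parameters to the optimizer.
+        #
+        #     class TrainableParams(torch.nn.Module):
+        #         def __init__(self, module: torch.nn.Module):
+        #             super().__init__()
+        #             self.params = torch.nn.ParameterList(p for p in module.parameters() if p.requires_grad)
+        #
+        #     model_parts = [TrainableParams(self.transformer)]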
+        model_parts = [self.transformer]
+        self.state.num_trainable_parameters = sum(
+            p.numel() for m in model_parts for p in m.parameters() if p.requires_grad
+        )
+
+        # Setup distributed optimizer and lr scheduler
+        logger.info("Initializing optimizer and lr scheduler")
+        self.state.train_state = TrainState()
+        self.optimizer = optimizer.get_optimizer(
+            parallel_backend=self.args.parallel_backend,
+            name=self.args.optimizer,
+            model_parts=model_parts,
+            learning_rate=self.args.lr,
+            beta1=self.args.beta1,
+            beta2=self.args.beta2,
+            beta3=self.args.beta3,
+            epsilon=self.args.epsilon,
+            weight_decay=self.args.weight_decay,
+            fused=False,
+        )
+        self.lr_scheduler = optimizer.get_lr_scheduler(
+            parallel_backend=self.args.parallel_backend,
+            name=self.args.lr_scheduler,
+            optimizer=self.optimizer,
+            num_warmup_steps=self.args.lr_warmup_steps,
+            num_training_steps=self.args.train_steps,
+            # TODO(aryan): handle last_epoch
+        )
+        self.optimizer, self.lr_scheduler = parallel_backend.prepare_optimizer(self.optimizer, self.lr_scheduler)
+
+        # 3. Initialize trackers, directories and repositories
+        self._init_logging()
+        self._init_trackers()
+        self._init_directories_and_repositories()
+
+    def _prepare_dataset(self) -> None:
+        logger.info("Initializing dataset and dataloader")
+
+        with open(self.args.dataset_config, "r") as file:
+            dataset_configs = json.load(file)["datasets"]
+        logger.info(f"Training configured to use {len(dataset_configs)} datasets")
+
+        datasets = []
+        for config in dataset_configs:
+            data_root = config.pop("data_root", None)
+            dataset_file = config.pop("dataset_file", None)
+            dataset_type = config.pop("dataset_type")
+
+            if data_root is not None and dataset_file is not None:
+                raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.")
+
+            dataset_name_or_root = data_root or dataset_file
+            dataset = data.initialize_dataset(dataset_name_or_root, dataset_type, streaming=True, infinite=True)
+
+            if not dataset._precomputable_once and self.args.precomputation_once:
+                raise ValueError(
+                    f"Dataset {dataset_name_or_root} does not support precomputing all embeddings at once."
+                )
+
+            logger.info(f"Initialized dataset: {dataset_name_or_root}")
+            dataset = self.state.parallel_backend.prepare_dataset(dataset)
+            dataset = data.wrap_iterable_dataset_for_preprocessing(dataset, dataset_type, config)
+            datasets.append(dataset)
+
+        dataset = data.combine_datasets(datasets, buffer_size=self.args.dataset_shuffle_buffer_size, shuffle=True)
+        dataloader = self.state.parallel_backend.prepare_dataloader(
+            dataset, batch_size=1, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.pin_memory
+        )
+
+        self.dataset = dataset
+        self.dataloader = dataloader
+
+    def _prepare_checkpointing(self) -> None:
+        parallel_backend = self.state.parallel_backend
+
+        def save_model_hook(state_dict: Dict[str, Any]) -> None:
+            if parallel_backend.is_main_process:
+                if self.args.training_type == TrainingType.LORA:
+                    state_dict = get_peft_model_state_dict(self.transformer, state_dict)
+                    self.model_specification._save_lora_weights(self.args.output_dir, state_dict, self.scheduler)
+                elif self.args.training_type == TrainingType.FULL_FINETUNE:
+                    self.model_specification._save_model(
+                        self.args.output_dir, self.transformer, state_dict, self.scheduler
+                    )
+            parallel_backend.wait_for_everyone()
+
+        enable_state_checkpointing = self.args.checkpointing_steps > 0
+        self.checkpointer = utils.PTDCheckpointManager(
+            dataloader=self.dataloader,
+            model_parts=[self.transformer],
+            optimizers=self.optimizer,
+            schedulers=self.lr_scheduler,
+            states={"train_state": self.state.train_state},
+            checkpointing_steps=self.args.checkpointing_steps,
+            checkpointing_limit=self.args.checkpointing_limit,
+            output_dir=self.args.output_dir,
+            enable=enable_state_checkpointing,
+            _callback_fn=save_model_hook,
+        )
+
+        resume_from_checkpoint = self.args.resume_from_checkpoint
+        if resume_from_checkpoint == "latest":
+            resume_from_checkpoint = -1
+        if resume_from_checkpoint is not None:
+            self.checkpointer.load(resume_from_checkpoint)
+
+    def _train(self) -> None:
+        logger.info("Starting training")
+
+        parallel_backend = self.state.parallel_backend
+        train_state = self.state.train_state
+        device = parallel_backend.device
+
+        memory_statistics = utils.get_memory_statistics()
+        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
+
+        global_batch_size = self.args.batch_size * parallel_backend._dp_degree
+        info = {
+            "trainable parameters": self.state.num_trainable_parameters,
+            "train steps": self.args.train_steps,
+            "per-replica batch size": self.args.batch_size,
+            "global batch size": global_batch_size,
+            "gradient accumulation steps": self.args.gradient_accumulation_steps,
+        }
+        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")
+
+        progress_bar = tqdm(
+            range(0, self.args.train_steps),
+            initial=train_state.step,
+            desc="Training steps",
+            disable=not parallel_backend.is_local_main_process,
+        )
+
+        generator = torch.Generator(device=device)
+        if self.args.seed is not None:
+            generator = generator.manual_seed(self.args.seed)
+        self.state.generator = generator
+
+        patch_size = 1
+        if (
+            getattr(self.transformer.config, "patch_size", None) is not None
+            and getattr(self.transformer.config, "patch_size_t", None) is not None
+        ):
+            patch_size = self.transformer.config.patch_size * self.transformer.config.patch_size_t
+        elif isinstance(getattr(self.transformer.config, "patch_size", None), int):
+            patch_size = self.transformer.config.patch_size
+        elif isinstance(getattr(self.transformer.config, "patch_size", None), (list, tuple)):
+            patch_size = math.prod(self.transformer.config.patch_size)
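+        # `patch_size` here is the number of latent elements folded into one transformer token; it is
+        # used further below to convert latent shapes into the `observed_num_tokens` statistic.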
+
+        scheduler_sigmas = utils.get_scheduler_sigmas(self.scheduler)
+        scheduler_sigmas = (
+            scheduler_sigmas.to(device=device, dtype=torch.float32) if scheduler_sigmas is not None else None
+        )
+        scheduler_alphas = utils.get_scheduler_alphas(self.scheduler)
+        scheduler_alphas = (
+            scheduler_alphas.to(device=device, dtype=torch.float32) if scheduler_alphas is not None else None
+        )
+        timesteps_buffer = []
+
+        self.transformer.train()
+        data_iterator = iter(self.dataloader)
+
+        preprocessor = data.DistributedDataPreprocessor(
+            rank=parallel_backend.rank,
+            num_items=self.args.precomputation_items,
+            processor_fn={
+                "condition": self.model_specification.prepare_conditions,
+                "latent": functools.partial(
+                    self.model_specification.prepare_latents, compute_posterior=not self.args.precomputation_once
+                ),
+            },
+            save_dir=self.args.precomputation_dir,
+        )
+        precomputed_condition_iterator: Optional[Iterable[Dict[str, Any]]] = None
+        precomputed_latent_iterator: Optional[Iterable[Dict[str, Any]]] = None
+        sampler = data.ResolutionSampler(
+            batch_size=self.args.batch_size, dim_keys=self.model_specification._resolution_dim_keys
+        )
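+        # Tracks whether gradients from a partially accumulated batch are still pending, so a final
+        # optimizer step can be taken if the data iterator is exhausted mid-accumulation.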
+        requires_gradient_step = True
+        accumulated_loss = 0.0
+
+        while (
+            train_state.step < self.args.train_steps and train_state.observed_data_samples < self.args.max_data_samples
+        ):
+            # 1. Load & preprocess data if required
+            if preprocessor.requires_data:
+                # TODO(aryan): We should do the following here:
+                # - Force checkpoint the trainable models, optimizers, schedulers and train state
+                # - Do the precomputation
+                # - Load the checkpointed models, optimizers, schedulers and train state back, and continue training
+                # This way we can be more memory efficient again, since the latest rewrite of precomputation removed
+                # this logic.
+                precomputed_condition_iterator, precomputed_latent_iterator = self._prepare_data(
+                    preprocessor, data_iterator
+                )
+
+            # 2. Prepare batch
+            try:
+                condition_item = next(precomputed_condition_iterator)
+                latent_item = next(precomputed_latent_iterator)
+                sampler.consume(condition_item, latent_item)
+            except StopIteration:
+                if requires_gradient_step:
+                    self.optimizer.step()
+                    self.lr_scheduler.step()
+                    requires_gradient_step = False
+                logger.info("Data exhausted. Exiting training loop.")
+                break
+
+            if sampler.is_ready:
+                condition_batch, latent_batch = sampler.get_batch()
+                condition_model_conditions = self.model_specification.collate_conditions(condition_batch)
+                latent_model_conditions = self.model_specification.collate_latents(latent_batch)
+            else:
+                continue
+
+            train_state.step += 1
+            train_state.observed_data_samples += self.args.batch_size * parallel_backend._dp_degree
+
+            lmc_latents = latent_model_conditions["latents"]
+            train_state.observed_num_tokens += math.prod(lmc_latents.shape[:-1]) // patch_size
+
+            logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")
+
+            # align_device_and_dtype returns the converted structures, so the results must be reassigned
+            latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, self.args.transformer_dtype)
+            condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, self.args.transformer_dtype)
+            latent_model_conditions = utils.make_contiguous(latent_model_conditions)
+            condition_model_conditions = utils.make_contiguous(condition_model_conditions)
+
+            # 3. Forward pass
+            sigmas = utils.prepare_sigmas(
+                scheduler=self.scheduler,
+                sigmas=scheduler_sigmas,
+                batch_size=self.args.batch_size,
+                num_train_timesteps=self.scheduler.config.num_train_timesteps,
+                flow_weighting_scheme=self.args.flow_weighting_scheme,
+                flow_logit_mean=self.args.flow_logit_mean,
+                flow_logit_std=self.args.flow_logit_std,
+                flow_mode_scale=self.args.flow_mode_scale,
+                device=device,
+                generator=self.state.generator,
+            )
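+            # Broadcast the per-sample sigmas to the rank of the latents so they can be applied
+            # elementwise when noising inside the model-specific forward pass.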
+            sigmas = utils.expand_tensor_dims(sigmas, latent_model_conditions["latents"].ndim)
+
+            pred, target, sigmas = self.model_specification.forward(
+                transformer=self.transformer,
+                scheduler=self.scheduler,
+                condition_model_conditions=condition_model_conditions,
+                latent_model_conditions=latent_model_conditions,
+                sigmas=sigmas,
+                compute_posterior=not self.args.precomputation_once,
+            )
+
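+            # Map continuous sigmas back to integer timesteps (the 1000 factor mirrors the usual
+            # num_train_timesteps) for the alpha lookup and timestep-distribution logging below.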
+            timesteps = (sigmas * 1000.0).long()
+            weights = utils.prepare_loss_weights(
+                scheduler=self.scheduler,
+                alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
+                sigmas=sigmas,
+                flow_weighting_scheme=self.args.flow_weighting_scheme,
+            )
+            weights = utils.expand_tensor_dims(weights, pred.ndim)
+
+            # 4. Compute loss & backward pass
+            loss = weights.float() * (pred.float() - target.float()).pow(2)
+            # Average loss across all but batch dimension
+            loss = loss.mean(list(range(1, loss.ndim)))
+            # Average loss across batch dimension
+            loss = loss.mean()
+            if self.args.gradient_accumulation_steps > 1:
+                loss = loss / self.args.gradient_accumulation_steps
+            loss.backward()
+            accumulated_loss += loss.detach().item()
+            requires_gradient_step = True
+
+            # 5. Clip gradients
+            model_parts = [self.transformer]
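+            # Sharded parameters may be DTensors; this helper handles the cases where the standard
+            # clip_grad_norm_ fails on them and accounts for the pipeline-parallel mesh when enabled.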
+            grad_norm = utils.torch._clip_grad_norm_while_handling_failing_dtensor_cases(
+                [p for m in model_parts for p in m.parameters()],
+                self.args.max_grad_norm,
+                foreach=True,
+                pp_mesh=parallel_backend.get_mesh("pp") if parallel_backend.pipeline_parallel_enabled else None,
+            )
+
+            # 6. Step optimizer & log metrics
+            logs = {}
+
+            if train_state.step % self.args.gradient_accumulation_steps == 0:
+                # TODO(aryan): revisit no_sync() for FSDP
+                # TODO(aryan): average the gradients for accumulation?
+                self.optimizer.step()
+                self.lr_scheduler.step()
+                self.optimizer.zero_grad()
+
+                if grad_norm is not None:
+                    logs["grad_norm"] = grad_norm if isinstance(grad_norm, float) else grad_norm.detach().item()
+                if (
+                    parallel_backend.data_replication_enabled
+                    or parallel_backend.data_sharding_enabled
+                    or parallel_backend.context_parallel_enabled
+                ):
+                    dp_cp_mesh = parallel_backend.get_mesh("dp_cp")
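+                    # Reduce the accumulated loss across the data/context-parallel group so the
+                    # logged values reflect all replicas rather than just the local one.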
+                    global_avg_loss, global_max_loss = (
+                        parallel.dist_mean(torch.tensor([accumulated_loss], device=device), dp_cp_mesh),
+                        parallel.dist_max(torch.tensor([accumulated_loss], device=device), dp_cp_mesh),
+                    )
+                else:
+                    global_avg_loss = global_max_loss = accumulated_loss
+
+                logs["global_avg_loss"] = global_avg_loss
+                logs["global_max_loss"] = global_max_loss
+                train_state.global_avg_losses.append(global_avg_loss)
+                train_state.global_max_losses.append(global_max_loss)
+                accumulated_loss = 0.0
+                requires_gradient_step = False
+
+            progress_bar.update(1)
+            progress_bar.set_postfix(logs)
+
+            timesteps_buffer.extend([(train_state.step, t) for t in timesteps.detach().cpu().numpy().tolist()])
+
+            if train_state.step % self.args.logging_steps == 0:
+                # TODO(aryan): handle non-SchedulerWrapper schedulers (probably not required eventually) since they might not be dicts
+                # TODO(aryan): causes NCCL hang for some reason. look into later
+                # logs.update(self.lr_scheduler.get_last_lr())
+
+                # timesteps_table = wandb.Table(data=timesteps_buffer, columns=["step", "timesteps"])
+                # logs["timesteps"] = wandb.plot.scatter(
+                #     timesteps_table, "step", "timesteps", title="Timesteps distribution"
+                # )
+                timesteps_buffer = []
+
+                logs["observed_data_samples"] = train_state.observed_data_samples
+                logs["observed_num_tokens"] = train_state.observed_num_tokens
+
+                parallel_backend.log(logs, step=train_state.step)
+                train_state.log_steps.append(train_state.step)
+
+            # 7. Save checkpoint if required
+            self.checkpointer.save(
+                step=train_state.step, _device=device, _is_main_process=parallel_backend.is_main_process
+            )
+
+            # 8. Perform validation if required
+            if train_state.step % self.args.validation_steps == 0:
+                self._validate(step=train_state.step, final_validation=False)
+
+        # 9. Final checkpoint, validation & cleanup
+        self.checkpointer.save(
+            train_state.step, force=True, _device=device, _is_main_process=parallel_backend.is_main_process
+        )
+        parallel_backend.wait_for_everyone()
+        self._validate(step=train_state.step, final_validation=True)
+
+        self._delete_components()
+        memory_statistics = utils.get_memory_statistics()
+        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")
+
+        # 10. Upload artifacts to hub
+        if parallel_backend.is_main_process and self.args.push_to_hub:
+            upload_folder(
+                repo_id=self.state.repo_id,
+                folder_path=self.args.output_dir,
+                ignore_patterns=[f"{self.checkpointer._prefix}_*"],
+            )
+
+        parallel_backend.destroy()
+
+    def _validate(self, step: int, final_validation: bool = False) -> None:
+        if self.args.validation_dataset_file is None:
+            return
+
+        logger.info("Starting validation")
+
+        # 1. Load validation dataset
+        parallel_backend = self.state.parallel_backend
+        dp_mesh = parallel_backend.get_mesh("dp_replicate")
+
+        if dp_mesh is not None:
+            local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
+        else:
+            local_rank, dp_world_size = 0, 1
+
+        dataset = data.ValidationDataset(self.args.validation_dataset_file)
+        dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, local_rank, dp_world_size)
+        validation_dataloader = data.DPDataLoader(
+            local_rank,
+            dataset,
+            batch_size=1,
+            num_workers=self.args.dataloader_num_workers,
+            collate_fn=lambda items: items,
+        )
+        data_iterator = iter(validation_dataloader)
+        main_process_prompts_to_filenames = {}  # Used to save model card
+        all_processes_artifacts = []  # Used to gather artifacts from all processes
+
+        memory_statistics = utils.get_memory_statistics()
+        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
+
+        seed = self.args.seed if self.args.seed is not None else 0
+        generator = torch.Generator(device=parallel_backend.device).manual_seed(seed)
+        pipeline = self._init_pipeline(final_validation=final_validation)
+
+        # 2. Run validation
+        # TODO(aryan): when running validation with FSDP, if the number of data points is not divisible by dp_shards, we
+        # will hang indefinitely. Either pad the dataset or raise an error early on during initialization if the dataset
+        # size is not divisible by dp_shards.
+        self.transformer.eval()
+        while True:
+            validation_data = next(data_iterator, None)
+            if validation_data is None:
+                break
+
+            logger.debug(
+                f"Validating {validation_data=} on rank={parallel_backend.rank}.", local_main_process_only=False
+            )
+
+            validation_data = validation_data[0]
+            validation_artifacts = self.model_specification.validation(
+                pipeline=pipeline, generator=generator, **validation_data
+            )
+
+            PROMPT = validation_data["prompt"]
+            IMAGE = validation_data.get("image", None)
+            VIDEO = validation_data.get("video", None)
+            EXPORT_FPS = validation_data.get("export_fps", 30)
+
+            # 2.1. If there are any initial images or videos, they will be logged to keep track of them as
+            # conditioning for generation.
+            prompt_filename = utils.string_to_filename(PROMPT)[:25]
+            artifacts = {
+                "input_image": data.ImageArtifact(value=IMAGE),
+                "input_video": data.VideoArtifact(value=VIDEO),
+            }
+
+            # 2.2. Track the artifacts generated from validation
+            for i, validation_artifact in enumerate(validation_artifacts):
+                if validation_artifact.value is None:
+                    continue
+                artifacts.update({f"artifact_{i}": validation_artifact})
+
+            # 2.3. Save the artifacts to the output directory and create appropriate logging objects
+            # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited.
+            for index, (key, artifact) in enumerate(list(artifacts.items())):
+                assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact))
+                filename = "validation-" if not final_validation else "final-"
+                filename += f"{step}-{parallel_backend.rank}-{index}-{prompt_filename}.{artifact.file_extension}"
+                output_filename = os.path.join(self.args.output_dir, filename)
+
+                if parallel_backend.is_main_process and artifact.file_extension == "mp4":
+                    main_process_prompts_to_filenames[PROMPT] = filename
+
+                caption = f"{PROMPT} | (filename: {output_filename})"
+                if artifact.type == "image" and artifact.value is not None:
+                    logger.debug(
+                        f"Saving image from rank={parallel_backend.rank} to {output_filename}",
+                        local_main_process_only=False,
+                    )
+                    artifact.value.save(output_filename)
+                    all_processes_artifacts.append(wandb.Image(output_filename, caption=caption))
+                elif artifact.type == "video" and artifact.value is not None:
+                    logger.debug(
+                        f"Saving video from rank={parallel_backend.rank} to {output_filename}",
+                        local_main_process_only=False,
+                    )
+                    export_to_video(artifact.value, output_filename, fps=EXPORT_FPS)
+                    all_processes_artifacts.append(wandb.Video(output_filename, caption=caption))
+
+        # 3. Cleanup & log artifacts
+        parallel_backend.wait_for_everyone()
+
+        # Remove all hooks that might have been added during pipeline initialization to the models
+        pipeline.remove_all_hooks()
+        del pipeline
+
+        utils.free_memory()
+        memory_statistics = utils.get_memory_statistics()
+        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
+        torch.cuda.reset_peak_memory_stats(parallel_backend.device)
+
+        # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts.
+        # TODO(aryan): probably should only all gather from dp mesh process group
+        all_artifacts = [None] * parallel_backend.world_size
+        torch.distributed.all_gather_object(all_artifacts, all_processes_artifacts)
+        all_artifacts = [artifact for artifacts in all_artifacts for artifact in artifacts]
+
+        if parallel_backend.is_main_process:
+            tracker_key = "final" if final_validation else "validation"
+            artifact_log_dict = {}
+
+            image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
+            if len(image_artifacts) > 0:
+                artifact_log_dict["images"] = image_artifacts
+            video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
+            if len(video_artifacts) > 0:
+                artifact_log_dict["videos"] = video_artifacts
+            parallel_backend.log({tracker_key: artifact_log_dict}, step=step)
+
+            if self.args.push_to_hub and final_validation:
+                video_filenames = list(main_process_prompts_to_filenames.values())
+                prompts = list(main_process_prompts_to_filenames.keys())
+                utils.save_model_card(
+                    args=self.args, repo_id=self.state.repo_id, videos=video_filenames, validation_prompts=prompts
+                )
+
+        parallel_backend.wait_for_everyone()
+        if not final_validation:
+            self.transformer.train()
+
+    def _evaluate(self) -> None:
+        raise NotImplementedError("Evaluation has not been implemented yet.")
+
+    def _init_distributed(self) -> None:
+        # TODO: Accelerate disables native_amp for MPS. We probably need to do the same in this implementation.
+        world_size = int(os.environ["WORLD_SIZE"])
+
+        # TODO(aryan): handle other backends
+        backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend)
+        self.state.parallel_backend = backend_cls(
+            world_size=world_size,
+            pp_degree=self.args.pp_degree,
+            dp_degree=self.args.dp_degree,
+            dp_shards=self.args.dp_shards,
+            cp_degree=self.args.cp_degree,
+            tp_degree=self.args.tp_degree,
+            backend="nccl",
+            timeout=self.args.init_timeout,
+            logging_dir=self.args.logging_dir,
+            output_dir=self.args.output_dir,
+            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+        )
+
+        if self.args.seed is not None:
+            world_mesh = self.state.parallel_backend.get_mesh()
+            utils.enable_determinism(self.args.seed, world_mesh)
+
+    def _init_logging(self) -> None:
+        transformers_log_level = transformers.utils.logging.set_verbosity_error
+        diffusers_log_level = diffusers.utils.logging.set_verbosity_error
+
+        if self.args.verbose == 0:
+            if self.state.parallel_backend.is_local_main_process:
+                transformers_log_level = transformers.utils.logging.set_verbosity_warning
+                diffusers_log_level = diffusers.utils.logging.set_verbosity_warning
+        elif self.args.verbose == 1:
+            if self.state.parallel_backend.is_local_main_process:
+                transformers_log_level = transformers.utils.logging.set_verbosity_info
+                diffusers_log_level = diffusers.utils.logging.set_verbosity_info
+        elif self.args.verbose == 2:
+            if self.state.parallel_backend.is_local_main_process:
+                transformers_log_level = transformers.utils.logging.set_verbosity_debug
+                diffusers_log_level = diffusers.utils.logging.set_verbosity_debug
+        else:
+            transformers_log_level = transformers.utils.logging.set_verbosity_debug
+            diffusers_log_level = diffusers.utils.logging.set_verbosity_debug
+
+        transformers_log_level()
+        diffusers_log_level()
+
+        logging._set_parallel_backend(self.state.parallel_backend)
+        logger.info("Initialized FineTrainers")
+
+    def _init_trackers(self) -> None:
+        # TODO(aryan): handle multiple trackers
+        trackers = ["wandb"]
+        experiment_name = self.args.tracker_name or "finetrainers-experiment"
+        self.state.parallel_backend.initialize_trackers(
+            trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
+        )
+
+    def _init_directories_and_repositories(self) -> None:
+        if self.state.parallel_backend.is_main_process:
+            self.args.output_dir = Path(self.args.output_dir)
+            self.args.output_dir.mkdir(parents=True, exist_ok=True)
+            self.state.output_dir = Path(self.args.output_dir)
+
+            if self.args.push_to_hub:
+                repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
+                self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id
+
+    def _init_config_options(self) -> None:
+        # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+        if self.args.allow_tf32 and torch.cuda.is_available():
+            torch.backends.cuda.matmul.allow_tf32 = True
+
+    def _move_components_to_device(
+        self, components: Optional[List[torch.nn.Module]] = None, device: Optional[Union[str, torch.device]] = None
+    ) -> None:
+        if device is None:
+            device = self.state.parallel_backend.device
+        if components is None:
+            components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.transformer, self.vae]
+        components = utils.get_non_null_items(components)
+        components = list(filter(lambda x: hasattr(x, "to"), components))
+        for component in components:
+            component.to(device)
+
+    def _set_components(self, components: Dict[str, Any]) -> None:
+        # fmt: off
+        component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
+        # fmt: on
+
+        for component_name in component_names:
+            existing_component = getattr(self, component_name, None)
+            new_component = components.get(component_name, existing_component)
+            setattr(self, component_name, new_component)
+
+    def _delete_components(self, component_names: Optional[List[str]] = None) -> None:
+        if component_names is None:
+            # fmt: off
+            component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
+            # fmt: on
+
+        for component_name in component_names:
+            setattr(self, component_name, None)
+
+        utils.free_memory()
+        utils.synchronize_device()
+
+    def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
+        parallel_backend = self.state.parallel_backend
+        module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"]
+
+        if not final_validation:
+            module_names.remove("transformer")
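+            # The training transformer is passed to the pipeline directly; it is excluded from
+            # `module_names` so that `self.transformer` is not reassigned from the pipeline below.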
+            pipeline = self.model_specification.load_pipeline(
+                tokenizer=self.tokenizer,
+                tokenizer_2=self.tokenizer_2,
+                tokenizer_3=self.tokenizer_3,
+                text_encoder=self.text_encoder,
+                text_encoder_2=self.text_encoder_2,
+                text_encoder_3=self.text_encoder_3,
+                # TODO(aryan): handle unwrapping for compiled modules
+                # transformer=utils.unwrap_model(accelerator, self.transformer),
+                transformer=self.transformer,
+                vae=self.vae,
+                enable_slicing=self.args.enable_slicing,
+                enable_tiling=self.args.enable_tiling,
+                enable_model_cpu_offload=self.args.enable_model_cpu_offload,
+                training=True,
+            )
+        else:
+            # TODO(aryan): this branch does not work yet, needs to be implemented
+            self._delete_components()
+
+            # Load the transformer weights from the final checkpoint if performing full-finetune
+            transformer = None
+            if self.args.training_type == TrainingType.FULL_FINETUNE:
+                transformer = self.model_specification.load_diffusion_models()["transformer"]
+
+            pipeline = self.model_specification.load_pipeline(
+                transformer=transformer,
+                enable_slicing=self.args.enable_slicing,
+                enable_tiling=self.args.enable_tiling,
+                enable_model_cpu_offload=self.args.enable_model_cpu_offload,
+                training=False,
+                device=parallel_backend.device,
+            )
+
+            # Load the LoRA weights if performing LoRA finetuning
+            if self.args.training_type == TrainingType.LORA:
+                pipeline.load_lora_weights(self.args.output_dir)
+
+        components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names}
+        self._set_components(components)
+        self._move_components_to_device(list(components.values()))
+        return pipeline
+
+    def _prepare_data(self, preprocessor: data.DistributedDataPreprocessor, data_iterator):
+        logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")
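+        # consume_once performs a single one-shot precomputation pass, while consume re-runs
+        # precomputation in chunks whenever the preprocessor reports that it needs more data.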
+        if self.args.precomputation_once:
+            consume_fn = preprocessor.consume_once
+        else:
+            consume_fn = preprocessor.consume
+
+        condition_components = self.model_specification.load_condition_models()
+        component_names = list(condition_components.keys())
+        component_modules = list(condition_components.values())
+        self._set_components(condition_components)
+        self._move_components_to_device(component_modules)
+        precomputed_condition_iterator = consume_fn(
+            "condition",
+            components=condition_components,
+            data_iterator=data_iterator,
+            generator=self.state.generator,
+            cache_samples=True,
+        )
+        self._delete_components(component_names)
+        del condition_components, component_names, component_modules
+
+        latent_components = self.model_specification.load_latent_models()
+        if self.vae is not None:
+            if self.args.enable_slicing:
+                self.vae.enable_slicing()
+            if self.args.enable_tiling:
+                self.vae.enable_tiling()
+        component_names = list(latent_components.keys())
+        component_modules = list(latent_components.values())
+        self._set_components(latent_components)
+        self._move_components_to_device(component_modules)
+        precomputed_latent_iterator = consume_fn(
+            "latent",
+            components=latent_components,
+            data_iterator=data_iterator,
+            generator=self.state.generator,
+            use_cached_samples=True,
+            drop_samples=True,
+        )
+        self._delete_components(component_names)
+        del latent_components, component_names, component_modules
+
+        return precomputed_condition_iterator, precomputed_latent_iterator
+
+    def _get_training_info(self) -> Dict[str, Any]:
+        info = self.args.to_dict()
+
+        # Remove flow-matching arguments when not using a flow-matching objective
+        diffusion_args = info.get("diffusion_arguments", {})
+        scheduler_name = self.scheduler.__class__.__name__ if self.scheduler is not None else ""
+        if scheduler_name != "FlowMatchEulerDiscreteScheduler":
+            filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
+        else:
+            filtered_diffusion_args = diffusion_args
+
+        info.update({"diffusion_arguments": filtered_diffusion_args})
+        return info
diff --git a/finetrainers/typing.py b/finetrainers/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7b3b339f252d8f47ef0ff67aa6c6733a2ccd7cf
--- /dev/null
+++ b/finetrainers/typing.py
@@ -0,0 +1,11 @@
+from typing import Union
+
+from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler
+from transformers import CLIPTokenizer, LlamaTokenizer, LlamaTokenizerFast, T5Tokenizer, T5TokenizerFast
+
+from .data import ImageArtifact, VideoArtifact
+
+
+ArtifactType = Union[ImageArtifact, VideoArtifact]
+SchedulerType = Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]
+TokenizerType = Union[CLIPTokenizer, T5Tokenizer, T5TokenizerFast, LlamaTokenizer, LlamaTokenizerFast]
diff --git a/finetrainers/utils/__init__.py b/finetrainers/utils/__init__.py
index d06bd9def1f5d4c242cafdc03cd2d31414c1f169..7339b296efedaf9c78f8491eaaf78a40f8fdaef1 100644
--- a/finetrainers/utils/__init__.py
+++ b/finetrainers/utils/__init__.py
@@ -1,4 +1,9 @@
-from .diffusion_utils import (
+import inspect
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
+
+from .activation_checkpoint import apply_activation_checkpointing
+from .data import determine_batch_size, should_perform_precomputation
+from .diffusion import (
     default_flow_shift,
     get_scheduler_alphas,
     get_scheduler_sigmas,
@@ -7,7 +12,33 @@ from .diffusion_utils import (
     prepare_target,
     resolution_dependent_timestep_flow_shift,
 )
-from .file_utils import delete_files, find_files
-from .memory_utils import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous
-from .optimizer_utils import get_optimizer, gradient_norm, max_gradient
-from .torch_utils import unwrap_model
+from .file import delete_files, find_files, string_to_filename
+from .hub import save_model_card
+from .memory import bytes_to_gigabytes, free_memory, get_memory_statistics, make_contiguous
+from .model import resolve_component_cls
+from .state_checkpoint import PTDCheckpointManager
+from .torch import (
+    align_device_and_dtype,
+    clip_grad_norm_,
+    enable_determinism,
+    expand_tensor_dims,
+    get_device_info,
+    set_requires_grad,
+    synchronize_device,
+    unwrap_model,
+)
+
+
+def get_parameter_names(obj: Any, method_name: Optional[str] = None) -> Set[str]:
+    if method_name is not None:
+        obj = getattr(obj, method_name)
+    return {name for name, _ in inspect.signature(obj).parameters.items()}
+
+
+def get_non_null_items(
+    x: Union[List[Any], Tuple[Any], Dict[str, Any]]
+) -> Union[List[Any], Tuple[Any], Dict[str, Any]]:
+    if isinstance(x, dict):
+        return {k: v for k, v in x.items() if v is not None}
+    if isinstance(x, (list, tuple)):
+        return type(x)(v for v in x if v is not None)
diff --git a/finetrainers/utils/_common.py b/finetrainers/utils/_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..c230e878d6fe715d696d8285c51f0ba073fd6b3e
--- /dev/null
+++ b/finetrainers/utils/_common.py
@@ -0,0 +1,6 @@
+DIFFUSERS_TRANSFORMER_BLOCK_NAMES = [
+    "transformer_blocks",
+    "single_transformer_blocks",
+    "temporal_transformer_blocks",
+    "blocks",
+]
diff --git a/finetrainers/utils/activation_checkpoint.py b/finetrainers/utils/activation_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc4193a6cc027a771fe1fc2c3cb34595fbc336b2
--- /dev/null
+++ b/finetrainers/utils/activation_checkpoint.py
@@ -0,0 +1,71 @@
+import collections
+from enum import Enum
+
+import torch
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
+
+from ._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES
+
+
+class CheckpointType(str, Enum):
+    FULL = "full"
+    OPS = "ops"
+    BLOCK_SKIP = "block_skip"
+
+
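+# Ops whose outputs are preferentially saved (rather than recomputed) under the "ops" checkpointing
+# mode; the policy below still recomputes every second mm to trade a little compute for memory.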
+_SELECTIVE_ACTIVATION_CHECKPOINTING_OPS = {
+    torch.ops.aten.mm.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops._c10d_functional.reduce_scatter_tensor.default,
+}
+
+
+def apply_activation_checkpointing(
+    module: torch.nn.Module, checkpointing_type: str = CheckpointType.FULL, n_layer: int = 1
+) -> torch.nn.Module:
+    if checkpointing_type == CheckpointType.FULL:
+        module = _apply_activation_checkpointing_blocks(module)
+    elif checkpointing_type == CheckpointType.OPS:
+        module = _apply_activation_checkpointing_ops(module, _SELECTIVE_ACTIVATION_CHECKPOINTING_OPS)
+    elif checkpointing_type == CheckpointType.BLOCK_SKIP:
+        module = _apply_activation_checkpointing_blocks(module, n_layer)
+    else:
+        raise ValueError(
+            f"Checkpointing type '{checkpointing_type}' not supported. Supported types are {CheckpointType.__members__.keys()}"
+        )
+    return module
+
+
+def _apply_activation_checkpointing_blocks(module: torch.nn.Module, n_layer: int = None) -> torch.nn.Module:
+    for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
+        blocks: torch.nn.Module = getattr(module, transformer_block_name, None)
+        if blocks is None:
+            continue
+        for index, (layer_id, block) in enumerate(blocks.named_children()):
+            if n_layer is None or index % n_layer == 0:
+                block = checkpoint_wrapper(block, preserve_rng_state=False)
+                blocks.register_module(layer_id, block)
+    return module
+
+
+def _apply_activation_checkpointing_ops(module: torch.nn.Module, ops) -> torch.nn.Module:
+    from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts
+
+    def _get_custom_policy(meta):
+        def _custom_policy(ctx, func, *args, **kwargs):
+            mode = "recompute" if ctx.is_recompute else "forward"
+            mm_count_key = f"{mode}_mm_count"
+            if func == torch.ops.aten.mm.default:
+                meta[mm_count_key] += 1
+            # Saves output of all compute ops, except every second mm
+            to_save = func in ops and not (func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0)
+            return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE
+
+        return _custom_policy
+
+    def selective_checkpointing_context_fn():
+        meta = collections.defaultdict(int)
+        return create_selective_checkpoint_contexts(_get_custom_policy(meta))
+
+    return checkpoint_wrapper(module, context_fn=selective_checkpointing_context_fn, preserve_rng_state=False)
diff --git a/finetrainers/utils/checkpointing.py b/finetrainers/utils/checkpointing.py
deleted file mode 100644
index 01dbea0a029144f53ff8587b887204515a7e3250..0000000000000000000000000000000000000000
--- a/finetrainers/utils/checkpointing.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import os
-from typing import Tuple
-
-from accelerate.logging import get_logger
-
-from ..constants import FINETRAINERS_LOG_LEVEL
-from ..utils.file_utils import delete_files, find_files
-
-
-logger = get_logger("finetrainers")
-logger.setLevel(FINETRAINERS_LOG_LEVEL)
-
-
-def get_latest_ckpt_path_to_resume_from(
-    resume_from_checkpoint: str, num_update_steps_per_epoch: int, output_dir: str
-) -> Tuple[str, int, int, int]:
-    if not resume_from_checkpoint:
-        initial_global_step = 0
-        global_step = 0
-        first_epoch = 0
-        resume_from_checkpoint_path = None
-    else:
-        if resume_from_checkpoint != "latest":
-            path = os.path.basename(resume_from_checkpoint)
-        else:
-            # Get the most recent checkpoint
-            dirs = os.listdir(output_dir)
-            dirs = [d for d in dirs if d.startswith("checkpoint")]
-            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
-            path = dirs[-1] if len(dirs) > 0 else None
-
-        if path is None:
-            logger.info(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
-            resume_from_checkpoint = None
-            initial_global_step = 0
-            global_step = 0
-            first_epoch = 0
-            resume_from_checkpoint_path = None
-        else:
-            logger.info(f"Resuming from checkpoint {path}")
-            resume_from_checkpoint_path = os.path.join(output_dir, path)
-            global_step = int(path.split("-")[1])
-
-            initial_global_step = global_step
-            first_epoch = global_step // num_update_steps_per_epoch
-
-    return resume_from_checkpoint_path, initial_global_step, global_step, first_epoch
-
-
-def get_intermediate_ckpt_path(checkpointing_limit: int, step: int, output_dir: str) -> str:
-    # before saving state, check if this save would set us over the `checkpointing_limit`
-    if checkpointing_limit is not None:
-        checkpoints = find_files(output_dir, prefix="checkpoint")
-
-        # before we save the new checkpoint, we need to have at_most `checkpoints_total_limit - 1` checkpoints
-        if len(checkpoints) >= checkpointing_limit:
-            num_to_remove = len(checkpoints) - checkpointing_limit + 1
-            checkpoints_to_remove = [os.path.join(output_dir, x) for x in checkpoints[0:num_to_remove]]
-            delete_files(checkpoints_to_remove)
-
-    logger.info(f"Checkpointing at step {step}")
-    save_path = os.path.join(output_dir, f"checkpoint-{step}")
-    logger.info(f"Saving state to {save_path}")
-    return save_path
diff --git a/finetrainers/utils/data_utils.py b/finetrainers/utils/data.py
similarity index 76%
rename from finetrainers/utils/data_utils.py
rename to finetrainers/utils/data.py
index 284dd1a9c9c1d91633dd0ae7e032a0cb507e8b2c..ae3fcc35262f7ec85d015c468983b033d61a154c 100644
--- a/finetrainers/utils/data_utils.py
+++ b/finetrainers/utils/data.py
@@ -1,6 +1,7 @@
 from pathlib import Path
-from typing import Union
+from typing import Any, Union
 
+import torch
 from accelerate.logging import get_logger
 
 from ..constants import PRECOMPUTED_CONDITIONS_DIR_NAME, PRECOMPUTED_LATENTS_DIR_NAME
@@ -33,3 +34,18 @@ def should_perform_precomputation(precomputation_dir: Union[str, Path]) -> bool:
             return False
     logger.info("Precomputed data not found. Running precomputation.")
     return True
+
+
+def determine_batch_size(x: Any) -> int:
+    if isinstance(x, list):
+        return len(x)
+    if isinstance(x, torch.Tensor):
+        return x.size(0)
+    if isinstance(x, dict):
+        for key in x:
+            try:
+                return determine_batch_size(x[key])
+            except ValueError:
+                pass
+        return 1
+    raise ValueError("Could not determine batch size from input.")
diff --git a/finetrainers/utils/diffusion_utils.py b/finetrainers/utils/diffusion.py
similarity index 100%
rename from finetrainers/utils/diffusion_utils.py
rename to finetrainers/utils/diffusion.py
diff --git a/finetrainers/utils/file_utils.py b/finetrainers/utils/file.py
similarity index 86%
rename from finetrainers/utils/file_utils.py
rename to finetrainers/utils/file.py
index eb731771ba9bc7e07a273ca8947949dfd572b465..ba01213e758685aaa339f3ecad12c312a540dd9e 100644
--- a/finetrainers/utils/file_utils.py
+++ b/finetrainers/utils/file.py
@@ -1,12 +1,12 @@
-import logging
 import os
 import shutil
 from pathlib import Path
 from typing import List, Union
 
+from ..logging import get_logger
 
-logger = logging.getLogger("finetrainers")
-logger.setLevel(os.environ.get("FINETRAINERS_LOG_LEVEL", "INFO"))
+
+logger = get_logger()
 
 
 def find_files(dir: Union[str, Path], prefix: str = "checkpoint") -> List[str]:
@@ -24,7 +24,7 @@ def delete_files(dirs: Union[str, List[str], Path, List[Path]]) -> None:
     if not isinstance(dirs, list):
         dirs = [dirs]
     dirs = [Path(d) if isinstance(d, str) else d for d in dirs]
-    logger.info(f"Deleting files: {dirs}")
+    logger.debug(f"Deleting files: {dirs}")
     for dir in dirs:
         if not dir.exists():
             continue
diff --git a/finetrainers/utils/hub_utils.py b/finetrainers/utils/hub.py
similarity index 82%
rename from finetrainers/utils/hub_utils.py
rename to finetrainers/utils/hub.py
index ef865407ef9a1d41300c2e544d622a62b498989b..ea1a16eb42cbb1f2848376440817a3e1680ce61c 100644
--- a/finetrainers/utils/hub_utils.py
+++ b/finetrainers/utils/hub.py
@@ -28,20 +28,17 @@ def save_model_card(
                 }
             )
 
-    training_type = "Full" if args.training_type == "full-finetune" else "LoRA"
     model_description = f"""
-# {training_type} Finetune
+# LoRA Finetune
 
 <Gallery />
 
 ## Model description
 
-This is a {training_type.lower()} finetune of model: `{args.pretrained_model_name_or_path}`.
+This is a lora finetune of model: `{args.pretrained_model_name_or_path}`.
 
 The model was trained using [`finetrainers`](https://github.com/a-r-r-o-w/finetrainers).
 
-`id_token` used: {args.id_token} (if it's not `None`, it should be used in the prompts.)
-
 ## Download model
 
 [Download LoRA]({repo_id}/tree/main) in the Files & Versions tab.
@@ -56,7 +53,7 @@ TODO
 
 For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
 """
-    if wandb.run and wandb.run.url:
+    if wandb.run.url:
         model_description += f"""
 Find out the wandb run URL and training configurations [here]({wandb.run.url}).
 """
@@ -72,13 +69,9 @@ Find out the wandb run URL and training configurations [here]({wandb.run.url}).
         "text-to-video",
         "diffusers-training",
         "diffusers",
-        "finetrainers",
+        "lora",
         "template:sd-lora",
     ]
-    if training_type == "Full":
-        tags.append("full-finetune")
-    else:
-        tags.append("lora")
 
     model_card = populate_model_card(model_card, tags=tags)
     model_card.save(os.path.join(args.output_dir, "README.md"))
diff --git a/finetrainers/utils/import_utils.py b/finetrainers/utils/import_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..56c19db6e4f032296bca103e38155da1ded1c074
--- /dev/null
+++ b/finetrainers/utils/import_utils.py
@@ -0,0 +1,20 @@
+import importlib.util
+
+import importlib_metadata
+
+from ..logging import get_logger
+
+
+logger = get_logger()
+
+
+_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
+try:
+    _bitsandbytes_version = importlib_metadata.version("bitsandbytes")
+    logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
+except importlib_metadata.PackageNotFoundError:
+    _bitsandbytes_available = False
+
+
+def is_bitsandbytes_available():
+    return _bitsandbytes_available
diff --git a/finetrainers/utils/memory_utils.py b/finetrainers/utils/memory.py
similarity index 100%
rename from finetrainers/utils/memory_utils.py
rename to finetrainers/utils/memory.py
diff --git a/finetrainers/utils/model.py b/finetrainers/utils/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4427f97d25ed44b2d9832cf456b082f65d66c2a8
--- /dev/null
+++ b/finetrainers/utils/model.py
@@ -0,0 +1,32 @@
+import importlib
+import json
+import os
+from typing import Optional
+
+from huggingface_hub import hf_hub_download
+
+
+def resolve_component_cls(
+    pretrained_model_name_or_path: str,
+    component_name: str,
+    filename: str = "model_index.json",
+    revision: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+):
+    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+    if os.path.exists(str(pretrained_model_name_or_path)) and os.path.isdir(pretrained_model_name_or_path):
+        index_path = os.path.join(pretrained_model_name_or_path, filename)
+    else:
+        index_path = hf_hub_download(
+            repo_id=pretrained_model_name_or_path, filename=filename, revision=revision, cache_dir=cache_dir
+        )
+
+    with open(index_path, "r") as f:
+        model_index_dict = json.load(f)
+
+    if component_name not in model_index_dict:
+        raise ValueError(f"No {component_name} found in the model index dict.")
+
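+    # Each component entry in model_index.json is a [library_name, class_name] pair,
+    # e.g. ["diffusers", "AutoencoderKLCogVideoX"], which is resolved to the actual class here.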
+    cls_config = model_index_dict[component_name]
+    library = importlib.import_module(cls_config[0])
+    return getattr(library, cls_config[1])
diff --git a/finetrainers/utils/model_utils.py b/finetrainers/utils/model_utils.py
deleted file mode 100644
index 1451ebff3d7c29e60d856cd41f56b358b044ffc9..0000000000000000000000000000000000000000
--- a/finetrainers/utils/model_utils.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import importlib
-import json
-import os
-
-from huggingface_hub import hf_hub_download
-
-
-def resolve_vae_cls_from_ckpt_path(ckpt_path, **kwargs):
-    ckpt_path = str(ckpt_path)
-    if os.path.exists(str(ckpt_path)) and os.path.isdir(ckpt_path):
-        index_path = os.path.join(ckpt_path, "model_index.json")
-    else:
-        revision = kwargs.get("revision", None)
-        cache_dir = kwargs.get("cache_dir", None)
-        index_path = hf_hub_download(
-            repo_id=ckpt_path, filename="model_index.json", revision=revision, cache_dir=cache_dir
-        )
-
-    with open(index_path, "r") as f:
-        model_index_dict = json.load(f)
-    assert "vae" in model_index_dict, "No VAE found in the modelx index dict."
-
-    vae_cls_config = model_index_dict["vae"]
-    library = importlib.import_module(vae_cls_config[0])
-    return getattr(library, vae_cls_config[1])
diff --git a/finetrainers/utils/optimizer_utils.py b/finetrainers/utils/optimizer_utils.py
deleted file mode 100644
index 84c215b51e8e458ff3702fd65be223bce7a7aeb9..0000000000000000000000000000000000000000
--- a/finetrainers/utils/optimizer_utils.py
+++ /dev/null
@@ -1,178 +0,0 @@
-import inspect
-
-import torch
-from accelerate.logging import get_logger
-
-
-logger = get_logger("finetrainers")
-
-
-def get_optimizer(
-    params_to_optimize,
-    optimizer_name: str = "adam",
-    learning_rate: float = 1e-3,
-    beta1: float = 0.9,
-    beta2: float = 0.95,
-    beta3: float = 0.98,
-    epsilon: float = 1e-8,
-    weight_decay: float = 1e-4,
-    prodigy_decouple: bool = False,
-    prodigy_use_bias_correction: bool = False,
-    prodigy_safeguard_warmup: bool = False,
-    use_8bit: bool = False,
-    use_4bit: bool = False,
-    use_torchao: bool = False,
-    use_deepspeed: bool = False,
-    use_cpu_offload_optimizer: bool = False,
-    offload_gradients: bool = False,
-) -> torch.optim.Optimizer:
-    optimizer_name = optimizer_name.lower()
-
-    # Use DeepSpeed optimzer
-    if use_deepspeed:
-        from accelerate.utils import DummyOptim
-
-        return DummyOptim(
-            params_to_optimize,
-            lr=learning_rate,
-            betas=(beta1, beta2),
-            eps=epsilon,
-            weight_decay=weight_decay,
-        )
-
-    # TODO: consider moving the validation logic to `args.py` when we have torchao.
-    if use_8bit and use_4bit:
-        raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")
-
-    if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
-        try:
-            import torchao  # noqa
-
-        except ImportError:
-            raise ImportError(
-                "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
-            )
-
-    if not use_torchao and use_4bit:
-        raise ValueError("4-bit Optimizers are only supported with torchao.")
-
-    # Optimizer creation
-    supported_optimizers = ["adam", "adamw", "prodigy", "came"]
-    if optimizer_name not in supported_optimizers:
-        logger.warning(
-            f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
-        )
-        optimizer_name = "adamw"
-
-    if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
-        raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")
-
-    if use_8bit:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError(
-                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
-            )
-
-    if optimizer_name == "adamw":
-        if use_torchao:
-            from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
-
-            optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
-        else:
-            optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW
-
-        init_kwargs = {
-            "betas": (beta1, beta2),
-            "eps": epsilon,
-            "weight_decay": weight_decay,
-        }
-
-    elif optimizer_name == "adam":
-        if use_torchao:
-            from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
-
-            optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
-        else:
-            optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam
-
-        init_kwargs = {
-            "betas": (beta1, beta2),
-            "eps": epsilon,
-            "weight_decay": weight_decay,
-        }
-
-    elif optimizer_name == "prodigy":
-        try:
-            import prodigyopt
-        except ImportError:
-            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
-
-        optimizer_class = prodigyopt.Prodigy
-
-        if learning_rate <= 0.1:
-            logger.warning(
-                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
-            )
-
-        init_kwargs = {
-            "lr": learning_rate,
-            "betas": (beta1, beta2),
-            "beta3": beta3,
-            "eps": epsilon,
-            "weight_decay": weight_decay,
-            "decouple": prodigy_decouple,
-            "use_bias_correction": prodigy_use_bias_correction,
-            "safeguard_warmup": prodigy_safeguard_warmup,
-        }
-
-    elif optimizer_name == "came":
-        try:
-            import came_pytorch
-        except ImportError:
-            raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")
-
-        optimizer_class = came_pytorch.CAME
-
-        init_kwargs = {
-            "lr": learning_rate,
-            "eps": (1e-30, 1e-16),
-            "betas": (beta1, beta2, beta3),
-            "weight_decay": weight_decay,
-        }
-
-    if use_cpu_offload_optimizer:
-        from torchao.prototype.low_bit_optim import CPUOffloadOptimizer
-
-        if "fused" in inspect.signature(optimizer_class.__init__).parameters:
-            init_kwargs.update({"fused": True})
-
-        optimizer = CPUOffloadOptimizer(
-            params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
-        )
-    else:
-        optimizer = optimizer_class(params_to_optimize, **init_kwargs)
-
-    return optimizer
-
-
-def gradient_norm(parameters):
-    norm = 0
-    for param in parameters:
-        if param.grad is None:
-            continue
-        local_norm = param.grad.detach().data.norm(2)
-        norm += local_norm.item() ** 2
-    norm = norm**0.5
-    return norm
-
-
-def max_gradient(parameters):
-    max_grad_value = float("-inf")
-    for param in parameters:
-        if param.grad is None:
-            continue
-        local_max_grad = param.grad.detach().data.abs().max()
-        max_grad_value = max(max_grad_value, local_max_grad.item())
-    return max_grad_value
diff --git a/finetrainers/utils/state_checkpoint.py b/finetrainers/utils/state_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab0e0b9b5f6214ba56d0c308714e29b9e11f4d8a
--- /dev/null
+++ b/finetrainers/utils/state_checkpoint.py
@@ -0,0 +1,203 @@
+import functools
+import pathlib
+import shutil
+import time
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+
+import torch
+import torch.distributed.checkpoint
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_model_state_dict,
+    set_model_state_dict,
+)
+from torch.distributed.checkpoint.stateful import Stateful
+
+from ..logging import get_logger
+
+
+if TYPE_CHECKING:
+    from .. import optimizer
+
+
+logger = get_logger()
+
+
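+# Wraps one or more model parts so torch.distributed.checkpoint can treat them as a single Stateful
+# object: state dicts of all parts are merged on save and re-applied non-strictly on load.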
+class ModelWrapper(Stateful):
+    def __init__(self, model: Union[torch.nn.Module, List[torch.nn.Module]]) -> None:
+        self.model = [model] if isinstance(model, torch.nn.Module) else model
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        func = functools.partial(
+            set_model_state_dict,
+            model_state_dict=state_dict,
+            options=StateDictOptions(strict=False),
+        )
+        list(map(func, self.model))
+
+
+class PTDCheckpointManager:
+    def __init__(
+        self,
+        dataloader: torch.utils.data.DataLoader,
+        model_parts: List[torch.nn.Module],
+        optimizers: "optimizer.OptimizerWrapper",
+        schedulers: "optimizer.SchedulerWrapper",
+        states: Dict[str, Any],
+        checkpointing_steps: int,
+        checkpointing_limit: int,
+        output_dir: str,
+        enable: bool = True,
+        _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
+        _prefix: str = "finetrainers_step",
+    ) -> None:
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+            }
+        )
+        self.states.update(schedulers.get_lr_scheduler_state())
+
+        self.checkpointing_steps = checkpointing_steps
+        self.checkpointing_limit = checkpointing_limit
+        self.output_dir = pathlib.Path(output_dir)
+        self.enable = enable
+        self._callback_fn = _callback_fn
+        self._prefix = _prefix
+
+        logger.info(f"Checkpointing enabled. Checkpoints will be stored in '{self.output_dir}'")
+
+    def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _is_main_process: bool) -> str:
+        if not self._should_checkpoint(step, force):
+            return None
+
+        checkpoint_dir = self._get_checkpoint_dir(step)
+        begin_time = time.monotonic()
+        torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix())
+        end_time = time.monotonic()
+        logger.info(
+            f"Saved checkpoint in {end_time - begin_time:.2f} seconds at step {step}. Directory: {checkpoint_dir}"
+        )
+        self._purge_stale_checkpoints()
+
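+        # Besides the sharded DCP checkpoint above, gather a full state dict on CPU (rank 0 only) so the
+        # optional callback (e.g. exporting LoRA / diffusers-format weights) can work with unsharded tensors.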
+        state_dicts = [
+            gather_state_dict_on_cpu_rank0(model, _device, is_main_process=_is_main_process)
+            for model in self.states["model"].model
+        ]
+        if self._callback_fn is not None:
+            list(map(self._callback_fn, state_dicts))
+
+        return checkpoint_dir.as_posix()
+
+    def load(self, step: int = -1) -> bool:
+        if not self.enable:
+            return False
+        if not self.output_dir.exists():
+            return False
+        if step != -1 and not self._get_checkpoint_dir(step).exists():
+            return False
+
+        if step == -1:
+            latest_checkpoint_dir = self._find_latest_checkpoint_dir()
+            if latest_checkpoint_dir is None:
+                return False
+            step = int(latest_checkpoint_dir.name.split("_")[-1])
+
+        checkpoint_dir = self._get_checkpoint_dir(step)
+        logger.info(f"Loading checkpoint from '{checkpoint_dir}' at step {step}")
+
+        # For step 0, optimizers/schedulers are not available as they are created during training after first step
+        states = {"model": self.states["model"]} if step == 0 else self.states
+
+        # See bug: https://github.com/pytorch/pytorch/pull/138575
+        original_stateful_states = {k: v for k, v in states.items() if isinstance(v, Stateful)}
+        begin_time = time.monotonic()
+        torch.distributed.checkpoint.load(states, checkpoint_id=checkpoint_dir.as_posix())
+        end_time = time.monotonic()
+        logger.info(f"Loaded checkpoint in {end_time - begin_time:.2f} seconds.")
+
+        # bugfix from above: restore the original stateful objects, whose states were already updated in-place by dcp.load()
+        states.update(original_stateful_states)
+
+        return True
+
+    def _should_checkpoint(self, step: int, force: bool) -> bool:
+        if not self.enable:
+            return False
+        return force or step % self.checkpointing_steps == 0
+
+    def _get_checkpoint_dir(self, step: int) -> pathlib.Path:
+        return self.output_dir / f"{self._prefix}_{step}"
+
+    def _find_latest_checkpoint_dir(self) -> Union[pathlib.Path, None]:
+        checkpoints = sorted(self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]))
+        return checkpoints[-1] if len(checkpoints) > 0 else None
+
+    def _purge_stale_checkpoints(self) -> None:
+        if self.checkpointing_limit is None or self.checkpointing_limit <= 0:
+            return
+        checkpoints = sorted(
+            self.output_dir.glob(f"{self._prefix}_*"), key=lambda x: int(x.name.split("_")[-1]), reverse=True
+        )
+        for checkpoint in checkpoints[self.checkpointing_limit :]:
+            logger.info(f"Deleting stale checkpoint: {checkpoint}")
+            shutil.rmtree(checkpoint, ignore_errors=True)
+
+
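+# Example usage (illustrative sketch, not part of the module API): a training loop would
+# typically construct the manager once, try to resume, and then call `save` every step.
+# The names `dataloader`, `model_parts`, `optimizers`, `schedulers`, `device`,
+# `is_main_process` and `num_train_steps` are assumed to be provided by the surrounding
+# trainer setup.
+#
+#     manager = PTDCheckpointManager(
+#         dataloader=dataloader,
+#         model_parts=model_parts,
+#         optimizers=optimizers,
+#         schedulers=schedulers,
+#         states={},
+#         checkpointing_steps=500,
+#         checkpointing_limit=3,
+#         output_dir="outputs/checkpoints",
+#     )
+#     manager.load(step=-1)  # resume from the latest checkpoint, if one exists
+#     for step in range(1, num_train_steps + 1):
+#         ...  # forward / backward / optimizer step
+#         manager.save(step, _device=device, _is_main_process=is_main_process)
+
+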
+def gather_state_dict_on_cpu_rank0(
+    model, device: Optional[torch.device] = None, *, is_main_process: bool
+) -> Dict[str, Any]:
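+    # NOTE: this function performs collective operations (DTensor `full_tensor` gathers and
+    # `torch.distributed.barrier`), so it must be called on every rank in the process group,
+    # even though only the main process keeps the gathered CPU copy of the state dict.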
+    cpu_state_dict = {}
+    sharded_sd = model.state_dict()
+    for param_name, param in sharded_sd.items():
+        if param.is_cpu:
+            # Move back to device if offloaded to CPU
+            param = param.to(device)
+        if hasattr(param, "_local_tensor"):
+            # Gather DTensor
+            param = param.full_tensor()
+        if is_main_process:
+            cpu_state_dict[param_name] = param.cpu()
+        torch.distributed.barrier()
+    return cpu_state_dict
+
+
+# # Copied from pytorch (torch/distributed/checkpoint/format_utils.py) to support callbacks to modify state_dict
+# def dcp_to_torch_save(
+#     dcp_checkpoint_dir: Union[str, os.PathLike],
+#     torch_save_path: Union[str, os.PathLike],
+#     callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
+# ):
+#     """
+#     Given a directory containing a DCP checkpoint, this function will convert it into a
+#     Torch save file.
+
+#     Args:
+#         dcp_checkpoint_dir: Directory containing the DCP checkpoint.
+#         torch_save_path: Filename to store the converted Torch save file.
+#         callback_fn: Optional callback function that takes the state_dict as input and returns a modified state_dict.
+
+#     .. warning::
+#         To avoid OOM, it's recommended to only run this function on a single rank.
+#     """
+#     state_dict = {}
+#     _load_state_dict(
+#         state_dict,
+#         storage_reader=FileSystemReader(dcp_checkpoint_dir),
+#         planner=_EmptyStateDictLoadPlanner(),
+#         no_dist=True,
+#     )
+#     if callback_fn is not None:
+#         state_dict = callback_fn(state_dict)
+#     torch.save(state_dict, torch_save_path)
diff --git a/finetrainers/utils/torch.py b/finetrainers/utils/torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..db434d3e361ea03e6c06811c98f08594c4ab773d
--- /dev/null
+++ b/finetrainers/utils/torch.py
@@ -0,0 +1,338 @@
+import math
+import os
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.backends
+import torch.distributed as dist
+import torch.distributed.tensor
+from accelerate import Accelerator
+from diffusers.utils.torch_utils import is_compiled_module
+
+from ..logging import get_logger
+
+
+logger = get_logger()
+
+_STRING_TO_DTYPE = {
+    "fp32": torch.float32,
+    "fp16": torch.float16,
+    "bf16": torch.bfloat16,
+}
+
+_DTYPE_TO_STRING = {v: k for k, v in _STRING_TO_DTYPE.items()}
+
+_HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = False
+
+
+def align_device_and_dtype(
+    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+):
+    if isinstance(x, torch.Tensor):
+        if device is not None:
+            x = x.to(device)
+        if dtype is not None:
+            x = x.to(dtype)
+    elif isinstance(x, dict):
+        if device is not None or dtype is not None:
+            x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
+    return x
+
+
+def _clip_grad_norm_while_handling_failing_dtensor_cases(
+    parameters: Union[torch.Tensor, List[torch.Tensor]],
+    max_norm: float,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = False,
+    foreach: Optional[bool] = None,
+    pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+) -> Optional[torch.Tensor]:
+    global _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES
+
+    if not _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES:
+        try:
+            return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach, pp_mesh)
+        except NotImplementedError as e:
+            if "DTensor does not support cross-mesh operation" in str(e):
+                # https://github.com/pytorch/pytorch/issues/134212
+                logger.warning(
+                    "DTensor does not support cross-mesh operations. This can happen when the model is only "
+                    "partially tensor-parallelized while combined with other parallelisms such as FSDP. "
+                    "Gradient clipping will be skipped and the gradient norm will not be logged."
+                )
+        except Exception as e:
+            logger.warning(
+                f"An error occurred while clipping gradients: {e}. Gradient clipping will be skipped and gradient "
+                f"norm will not be logged."
+            )
+            _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = True
+    return None
+
+
+# Copied from https://github.com/pytorch/torchtitan/blob/4a169701555ab9bd6ca3769f9650ae3386b84c6e/torchtitan/utils.py#L362
+@torch.no_grad()
+def clip_grad_norm_(
+    parameters: Union[torch.Tensor, List[torch.Tensor]],
+    max_norm: float,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = False,
+    foreach: Optional[bool] = None,
+    pp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+) -> torch.Tensor:
+    r"""
+    Clip the gradient norm of parameters.
+
+    Gradient norm clipping requires computing the gradient norm over the entire model.
+    `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions.
+    We need to manually reduce the gradient norm across PP stages.
+    See https://github.com/pytorch/torchtitan/issues/596 for details.
+
+    Args:
+        parameters (`torch.Tensor` or `List[torch.Tensor]`):
+            Tensors that will have gradients normalized.
+        max_norm (`float`):
+            Maximum norm of the gradients after clipping.
+        norm_type (`float`, defaults to `2.0`):
+            Type of p-norm to use. Can be `inf` for infinity norm.
+        error_if_nonfinite (`bool`, defaults to `False`):
+            If `True`, an error is thrown if the total norm of the gradients from `parameters` is `nan`, `inf`, or `-inf`.
+        foreach (`bool`, defaults to `None`):
+            Use the faster foreach-based implementation. If `None`, use the foreach implementation for CUDA and CPU native tensors
+            and silently fall back to the slow implementation for other device types.
+        pp_mesh (`torch.distributed.device_mesh.DeviceMesh`, defaults to `None`):
+            Pipeline parallel device mesh. If not `None`, will reduce gradient norm across PP stages.
+
+    Returns:
+        `torch.Tensor`:
+            Total norm of the gradients
+    """
+    grads = [p.grad for p in parameters if p.grad is not None]
+
+    # TODO(aryan): Wait for next Pytorch release to use `torch.nn.utils.get_total_norm`
+    # total_norm = torch.nn.utils.get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
+    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
+
+    # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
+    # We can simply reduce the DTensor to get the total norm in this tensor's process group
+    # and then convert it to a local tensor.
+    # It has two purposes:
+    #   1. to make sure the total norm is computed correctly when PP is used (see below)
+    #   2. to return a reduced total_norm tensor whose .item() would return the correct value
+    if isinstance(total_norm, torch.distributed.tensor.DTensor):
+        # Will reach here if any non-PP parallelism is used.
+        # If only using PP, total_norm will be a local tensor.
+        total_norm = total_norm.full_tensor()
+
+    if pp_mesh is not None:
+        if math.isinf(norm_type):
+            dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
+        else:
+            total_norm **= norm_type
+            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group())
+            total_norm **= 1.0 / norm_type
+
+    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
+    return total_norm
+
+
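+# Example usage (illustrative sketch): clipping gradients of the trainable parameters in a
+# training step. `transformer` and `pp_mesh` are assumed to come from the surrounding
+# parallel/training setup; pass `pp_mesh=None` when pipeline parallelism is not used.
+#
+#     grad_norm = clip_grad_norm_(
+#         [p for p in transformer.parameters() if p.requires_grad],
+#         max_norm=1.0,
+#         norm_type=2.0,
+#         pp_mesh=pp_mesh,
+#     )
+
+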
+def enable_determinism(
+    seed: int,
+    world_mesh: Optional[torch.distributed.DeviceMesh] = None,
+    deterministic: bool = False,
+) -> None:
+    r"""
+    For all ranks within the same DTensor SPMD group, the same seed will be set.
+    For PP groups, different seeds will be set.
+    """
+    if deterministic:
+        logger.info("Deterministic algorithms are enabled (expect performance degradation).")
+        torch.use_deterministic_algorithms(True)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+        # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+
+    if not world_mesh:
+        if seed is not None:
+            torch.manual_seed(seed)
+            os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
+            logger.debug(f"Single-process job using seed: {seed}")
+        return
+
+    # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh,
+    # and choose a unique seed for each rank on the PP mesh.
+    if torch.distributed.distributed_c10d.get_world_size() > 1 and "pp" in world_mesh.mesh_dim_names:
+        pp_mesh = world_mesh["pp"]
+        seed += pp_mesh.get_local_rank()
+        seed %= 2**64
+
+        info = {
+            "pp_rank": pp_mesh.get_local_rank(),
+            "global_rank": torch.distributed.distributed_c10d.get_rank(),
+            "seed": seed,
+        }
+        logger.debug(f"Enabling determinism: {info}")
+        spmd_mesh_dims = list(filter(lambda name: name != "pp", world_mesh.mesh_dim_names))
+        spmd_mesh = world_mesh[spmd_mesh_dims] if len(spmd_mesh_dims) else None
+    else:
+        spmd_mesh = world_mesh
+        info = {"global_rank": torch.distributed.distributed_c10d.get_rank(), "seed": seed}
+        logger.debug(f"Enabling determinism: {info}")
+
+    # The native torch RNGs and Python's RNG mostly matter only in the 1-D (PP-only) case, but we seed them for consistency
+    torch.manual_seed(seed)
+    # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1]
+    os.environ["PYTHONHASHSEED"] = str(seed % 2**32)
+
+    # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh.
+    # If PP is also used, this seed is unique per PP rank.
+    if spmd_mesh and spmd_mesh.get_coordinate() is not None:
+        torch.distributed.tensor._random.manual_seed(seed, spmd_mesh)
+
+
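+# Example usage (illustrative sketch): with `seed=42` and a mesh that contains a "pp" dimension,
+# PP rank 0 seeds with 42, PP rank 1 with 43, and so on, while all ranks within one SPMD group
+# share the same seed. `world_mesh` is assumed to be the device mesh built by the parallel backend.
+#
+#     enable_determinism(seed=42, world_mesh=world_mesh, deterministic=False)
+
+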
+def expand_tensor_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor:
+    assert len(tensor.shape) <= ndim
+    return tensor.reshape(tensor.shape + (1,) * (ndim - len(tensor.shape)))
+
+
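+# Example: `expand_tensor_dims(torch.randn(2, 3), ndim=4)` returns a tensor of shape (2, 3, 1, 1),
+# which is convenient for broadcasting per-sample scalars against higher-dimensional latents.
+
+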
+def get_device_info():
+    from torch._utils import _get_available_device_type, _get_device_module
+
+    device_type = _get_available_device_type()
+    if device_type is None:
+        device_type = "cuda"
+    device_module = _get_device_module(device_type)
+    return device_type, device_module
+
+
+def get_dtype_from_string(dtype: str):
+    return _STRING_TO_DTYPE[dtype]
+
+
+def get_string_from_dtype(dtype: torch.dtype):
+    return _DTYPE_TO_STRING[dtype]
+
+
+def set_requires_grad(models: Union[torch.nn.Module, List[torch.nn.Module]], value: bool) -> None:
+    if isinstance(models, torch.nn.Module):
+        models = [models]
+    for model in models:
+        if model is not None:
+            model.requires_grad_(value)
+
+
+def synchronize_device() -> None:
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    elif torch.backends.mps.is_available():
+        torch.mps.synchronize()
+
+
+def unwrap_model(accelerator: Accelerator, model):
+    model = accelerator.unwrap_model(model)
+    model = model._orig_mod if is_compiled_module(model) else model
+    return model
+
+
+# TODO(aryan): remove everything below this after next torch release
+def _get_total_norm(
+    tensors: Union[torch.Tensor, List[torch.Tensor]],
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = False,
+    foreach: Optional[bool] = None,
+) -> torch.Tensor:
+    if isinstance(tensors, torch.Tensor):
+        tensors = [tensors]
+    else:
+        tensors = list(tensors)
+    norm_type = float(norm_type)
+    if len(tensors) == 0:
+        return torch.tensor(0.0)
+    first_device = tensors[0].device
+    grouped_tensors: dict[
+        tuple[torch.device, torch.dtype], tuple[list[list[torch.Tensor]], list[int]]
+    ] = _group_tensors_by_device_and_dtype(
+        [tensors]  # type: ignore[list-item]
+    )  # type: ignore[assignment]
+
+    norms: List[torch.Tensor] = []
+    for (device, _), ([device_tensors], _) in grouped_tensors.items():
+        if (foreach is None and _has_foreach_support(device_tensors, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            norms.extend(torch._foreach_norm(device_tensors, norm_type))
+        elif foreach:
+            raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors")
+        else:
+            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_tensors])
+
+    total_norm = torch.linalg.vector_norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+        raise RuntimeError(
+            f"The total norm of order {norm_type} for gradients from "
+            "`parameters` is non-finite, so it cannot be clipped. To disable "
+            "this error and scale the gradients by the non-finite norm anyway, "
+            "set `error_if_nonfinite=False`"
+        )
+    return total_norm
+
+
+@torch.no_grad()
+def _clip_grads_with_norm_(
+    parameters: Union[torch.Tensor, List[torch.Tensor]],
+    max_norm: float,
+    total_norm: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    if len(grads) == 0:
+        return
+    grouped_grads: dict[
+        Tuple[torch.device, torch.dtype], Tuple[List[List[torch.Tensor]], List[int]]
+    ] = _group_tensors_by_device_and_dtype([grads])  # type: ignore[assignment]
+
+    clip_coef = max_norm / (total_norm + 1e-6)
+    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+    # when the gradients do not reside in CPU memory.
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+    for (device, _), ([device_grads], _) in grouped_grads.items():
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
+        elif foreach:
+            raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors")
+        else:
+            clip_coef_clamped_device = clip_coef_clamped.to(device)
+            for g in device_grads:
+                g.mul_(clip_coef_clamped_device)
+
+
+def _get_foreach_kernels_supported_devices() -> list[str]:
+    r"""Return the device type list that supports foreach kernels."""
+    return ["cuda", "xpu", torch._C._get_privateuse1_backend_name()]
+
+
+@torch.no_grad()
+def _group_tensors_by_device_and_dtype(
+    tensorlistlist: List[List[Optional[torch.Tensor]]],
+    with_indices: bool = False,
+) -> dict[tuple[torch.device, torch.dtype], tuple[List[List[Optional[torch.Tensor]]], List[int]]]:
+    return torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
+
+
+def _device_has_foreach_support(device: torch.device) -> bool:
+    return device.type in (_get_foreach_kernels_supported_devices() + ["cpu"]) and not torch.jit.is_scripting()
+
+
+def _has_foreach_support(tensors: List[torch.Tensor], device: torch.device) -> bool:
+    return _device_has_foreach_support(device) and all(t is None or type(t) in [torch.Tensor] for t in tensors)
diff --git a/finetrainers/utils/torch_utils.py b/finetrainers/utils/torch_utils.py
deleted file mode 100644
index 1c6ef5df9ea5a832e671cbaaff008cd6f48078b0..0000000000000000000000000000000000000000
--- a/finetrainers/utils/torch_utils.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from typing import Dict, Optional, Union
-
-import torch
-from accelerate import Accelerator
-from diffusers.utils.torch_utils import is_compiled_module
-
-
-def unwrap_model(accelerator: Accelerator, model):
-    model = accelerator.unwrap_model(model)
-    model = model._orig_mod if is_compiled_module(model) else model
-    return model
-
-
-def align_device_and_dtype(
-    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
-):
-    if isinstance(x, torch.Tensor):
-        if device is not None:
-            x = x.to(device)
-        if dtype is not None:
-            x = x.to(dtype)
-    elif isinstance(x, dict):
-        if device is not None:
-            x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
-        if dtype is not None:
-            x = {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
-    return x
-
-
-def expand_tensor_dims(tensor, ndim):
-    while len(tensor.shape) < ndim:
-        tensor = tensor.unsqueeze(-1)
-    return tensor
diff --git a/train.py b/train.py
index 088c0617ae0c37fbc7ef5751fddc2d484f36d1a4..3260752aa563851527c867518c49fa7b1698d3e6 100644
--- a/train.py
+++ b/train.py
@@ -1,12 +1,12 @@
-import logging
+import sys
 import traceback
 
-from finetrainers import Trainer, parse_arguments
-from finetrainers.constants import FINETRAINERS_LOG_LEVEL
+from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger
+from finetrainers.config import _get_model_specifiction_cls
+from finetrainers.trainer.sft_trainer.config import SFTFullRankConfig, SFTLowRankConfig
 
 
-logger = logging.getLogger("finetrainers")
-logger.setLevel(FINETRAINERS_LOG_LEVEL)
+logger = get_logger()
 
 
 def main():
@@ -22,18 +22,52 @@ def main():
         )
 
     try:
-        args = parse_arguments()
-        trainer = Trainer(args)
-
-        trainer.prepare_dataset()
-        trainer.prepare_models()
-        trainer.prepare_precomputations()
-        trainer.prepare_trainable_parameters()
-        trainer.prepare_optimizer()
-        trainer.prepare_for_training()
-        trainer.prepare_trackers()
-        trainer.train()
-        # trainer.evaluate()
+        args = BaseArgs()
+
+        argv = [y.strip() for x in sys.argv for y in x.split()]
+        if "--training_type" not in argv:
+            raise ValueError("Training type not provided in command line arguments.")
+        training_type_index = argv.index("--training_type")
+
+        training_type = argv[training_type_index + 1]
+        training_cls = None
+        if training_type == TrainingType.LORA:
+            training_cls = SFTLowRankConfig
+        elif training_type == TrainingType.FULL_FINETUNE:
+            training_cls = SFTFullRankConfig
+        else:
+            raise ValueError(f"Training type {training_type} not supported.")
+
+        training_config = training_cls()
+        args.extend_args(training_config.add_args, training_config.map_args, training_config.validate_args)
+        args = args.parse_args()
+
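+        # Example invocation (illustrative; the exact flags and launcher depend on the chosen
+        # model, training type and `--parallel_backend`), e.g.:
+        #   torchrun --nproc_per_node=4 train.py \
+        #       --model_name <model_name> --training_type lora \
+        #       --pretrained_model_name_or_path <repo_id_or_path> ...
+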
+        model_specification_cls = _get_model_specifiction_cls(args.model_name, args.training_type)
+        model_specification = model_specification_cls(
+            pretrained_model_name_or_path=args.pretrained_model_name_or_path,
+            tokenizer_id=args.tokenizer_id,
+            tokenizer_2_id=args.tokenizer_2_id,
+            tokenizer_3_id=args.tokenizer_3_id,
+            text_encoder_id=args.text_encoder_id,
+            text_encoder_2_id=args.text_encoder_2_id,
+            text_encoder_3_id=args.text_encoder_3_id,
+            transformer_id=args.transformer_id,
+            vae_id=args.vae_id,
+            text_encoder_dtype=args.text_encoder_dtype,
+            text_encoder_2_dtype=args.text_encoder_2_dtype,
+            text_encoder_3_dtype=args.text_encoder_3_dtype,
+            transformer_dtype=args.transformer_dtype,
+            vae_dtype=args.vae_dtype,
+            revision=args.revision,
+            cache_dir=args.cache_dir,
+        )
+
+        if args.training_type in [TrainingType.LORA, TrainingType.FULL_FINETUNE]:
+            trainer = SFTTrainer(args, model_specification)
+        else:
+            raise ValueError(f"Training type {args.training_type} not supported.")
+
+        trainer.run()
 
     except KeyboardInterrupt:
         logger.info("Received keyboard interrupt. Exiting...")