jbilcke-hf committed
Commit 80ebcb3 · 1 Parent(s): 76eb17f

upgrading finetrainers (and losing my extra code + improvements)

This view is limited to 50 files because it contains too many changes. See the raw diff for the full set of changes.
Files changed (50)
  1. accelerate_configs/uncompiled_4.yaml +17 -0
  2. finetrainers/__init__.py +5 -2
  3. finetrainers/args.py +447 -778
  4. finetrainers/config.py +52 -0
  5. finetrainers/constants.py +3 -0
  6. finetrainers/data/__init__.py +19 -0
  7. finetrainers/data/_artifact.py +29 -0
  8. finetrainers/data/dataloader.py +40 -0
  9. finetrainers/data/dataset.py +844 -0
  10. finetrainers/data/precomputation.py +163 -0
  11. finetrainers/data/sampler.py +58 -0
  12. finetrainers/data/utils.py +20 -0
  13. finetrainers/dataset.py +0 -564
  14. finetrainers/functional/__init__.py +16 -0
  15. finetrainers/functional/diffusion.py +11 -0
  16. finetrainers/functional/image.py +54 -0
  17. finetrainers/functional/text.py +26 -0
  18. finetrainers/functional/video.py +94 -0
  19. finetrainers/hooks/__init__.py +0 -1
  20. finetrainers/hooks/hooks.py +0 -176
  21. finetrainers/hooks/layerwise_upcasting.py +0 -140
  22. finetrainers/logging.py +111 -0
  23. finetrainers/models/__init__.py +1 -33
  24. finetrainers/models/cogvideox/__init__.py +1 -2
  25. finetrainers/models/cogvideox/base_specification.py +424 -0
  26. finetrainers/models/cogvideox/full_finetune.py +0 -32
  27. finetrainers/models/cogvideox/lora.py +0 -334
  28. finetrainers/models/hunyuan_video/__init__.py +1 -2
  29. finetrainers/models/hunyuan_video/base_specification.py +413 -0
  30. finetrainers/models/hunyuan_video/full_finetune.py +0 -30
  31. finetrainers/models/hunyuan_video/lora.py +0 -368
  32. finetrainers/models/ltx_video/__init__.py +1 -2
  33. finetrainers/models/ltx_video/base_specification.py +522 -0
  34. finetrainers/models/ltx_video/full_finetune.py +0 -30
  35. finetrainers/models/ltx_video/lora.py +0 -331
  36. finetrainers/models/modeling_utils.py +292 -0
  37. finetrainers/models/utils.py +62 -0
  38. finetrainers/models/wan/__init__.py +1 -0
  39. finetrainers/models/wan/base_specification.py +378 -0
  40. finetrainers/optimizer.py +449 -0
  41. finetrainers/parallel/__init__.py +22 -0
  42. finetrainers/parallel/accelerate.py +218 -0
  43. finetrainers/parallel/base.py +96 -0
  44. finetrainers/parallel/deepspeed.py +7 -0
  45. finetrainers/parallel/ptd.py +228 -0
  46. finetrainers/parallel/utils.py +99 -0
  47. finetrainers/patches/__init__.py +23 -0
  48. finetrainers/{patches.py → patches/dependencies/peft/patch.py} +3 -28
  49. finetrainers/patches/models/ltx_video/patch.py +127 -0
  50. finetrainers/patches/utils.py +18 -0
accelerate_configs/uncompiled_4.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ gpu_ids: 0,1,2,3
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 4
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
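Note: a single-node, 4-GPU, bf16 config like this one is normally consumed by the Accelerate launcher, e.g. `accelerate launch --config_file accelerate_configs/uncompiled_4.yaml train.py ...`; the `train.py` entrypoint name is an assumption here and is not part of this commit.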
finetrainers/__init__.py CHANGED
@@ -1,2 +1,5 @@
- from .args import Args, parse_arguments
- from .trainer import Trainer
+ from .args import BaseArgs
+ from .config import ModelType, TrainingType
+ from .logging import get_logger
+ from .models import ModelSpecification
+ from .trainer import SFTTrainer
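The re-exports above are the new public surface of the upgraded finetrainers: `Args`/`parse_arguments` and `Trainer` are gone, replaced by `BaseArgs`, `SFTTrainer`, and per-model `ModelSpecification` classes. A minimal sketch of the import side of that API (only the names come from this diff; how they are wired together at runtime is an assumption):

    # Sketch only: import surface re-exported by finetrainers/__init__.py after this commit.
    # How these pieces are combined (argument parsing, trainer construction, run loop)
    # is not shown in this diff and is assumed to live in the training entrypoint.
    from finetrainers import BaseArgs, ModelSpecification, SFTTrainer, get_logger
    from finetrainers import ModelType, TrainingType

    logger = get_logger()
    # BaseArgs replaces the old Args/parse_arguments pair, SFTTrainer replaces Trainer,
    # and model-specific behaviour moves into ModelSpecification subclasses such as the
    # new */base_specification.py files added in this commit.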
finetrainers/args.py CHANGED
@@ -1,14 +1,21 @@
  import argparse
  import sys
- from typing import Any, Dict, List, Optional, Tuple

  import torch

- from .constants import DEFAULT_IMAGE_RESOLUTION_BUCKETS, DEFAULT_VIDEO_RESOLUTION_BUCKETS
- from .models import SUPPORTED_MODEL_CONFIGS


- class Args:
  r"""
  The arguments for the finetrainers training script.

@@ -19,6 +26,19 @@ class Args:
  TODO(aryan): add `python train.py --memory_requirements --model_name <model_name>` to show
  memory requirements per model, per training type with sensible training settings.

  MODEL ARGUMENTS
  ---------------
  model_name (`str`):
@@ -33,6 +53,22 @@ class Args:
  storage requirements.
  cache_dir (`str`, defaults to `None`):
  The directory where the downloaded models and datasets will be stored, or loaded from.
  text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
  Data type for the text encoder when generating text embeddings.
  text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
@@ -54,41 +90,47 @@ class Args:

  DATASET ARGUMENTS
  -----------------
- data_root (`str`):
- A folder containing the training data.
- dataset_file (`str`, defaults to `None`):
- Path to a CSV/JSON/JSONL file containing metadata for training. This should be provided if you're not using
- a directory dataset format containing a simple `prompts.txt` and `videos.txt`/`images.txt` for example.
- video_column (`str`):
- The column of the dataset containing videos. Or, the name of the file in `data_root` folder containing the
- line-separated path to video data.
- caption_column (`str`):
- The column of the dataset containing the instance prompt for each video. Or, the name of the file in
- `data_root` folder containing the line-separated instance prompts.
- id_token (`str`, defaults to `None`):
- Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training.
- image_resolution_buckets (`List[Tuple[int, int]]`, defaults to `None`):
- Resolution buckets for images. This should be a list of integer tuples, where each tuple represents the
- resolution (height, width) of the image. All images will be resized to the nearest bucket resolution.
- video_resolution_buckets (`List[Tuple[int, int, int]]`, defaults to `None`):
- Resolution buckets for videos. This should be a list of integer tuples, where each tuple represents the
- resolution (num_frames, height, width) of the video. All videos will be resized to the nearest bucket
- resolution.
- video_reshape_mode (`str`, defaults to `None`):
- All input videos are reshaped to this mode. Choose between ['center', 'random', 'none'].
- TODO(aryan): We don't support this.
- caption_dropout_p (`float`, defaults to `0.00`):
- Probability of dropout for the caption tokens. This is useful to improve the unconditional generation
- quality of the model.
- caption_dropout_technique (`str`, defaults to `empty`):
- Technique to use for caption dropout. Choose between ['empty', 'zero']. Some models apply caption dropout
- by setting the prompt condition to an empty string, while others zero-out the text embedding tensors.
- precompute_conditions (`bool`, defaults to `False`):
- Whether or not to precompute the conditionings for the model. This is useful for faster training, and
- reduces the memory requirements.
- remove_common_llm_caption_prefixes (`bool`, defaults to `False`):
- Whether or not to remove common LLM caption prefixes. This is useful for improving the quality of the
- generated text.

  DATALOADER_ARGUMENTS
  --------------------
@@ -136,16 +178,11 @@ class Args:
  A seed for reproducible training.
  batch_size (`int`, defaults to `1`):
  Per-device batch size.
- train_epochs (`int`, defaults to `1`):
- Number of training epochs.
- train_steps (`int`, defaults to `None`):
- Total number of training steps to perform. If provided, overrides `train_epochs`.
- rank (`int`, defaults to `128`):
- The rank for LoRA matrices.
- lora_alpha (`float`, defaults to `64`):
- The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.
- target_modules (`List[str]`, defaults to `["to_k", "to_q", "to_v", "to_out.0"]`):
- The target modules for LoRA. Make sure to modify this based on the model.
  gradient_accumulation_steps (`int`, defaults to `1`):
  Number of gradients steps to accumulate before performing an optimizer step.
  gradient_checkpointing (`bool`, defaults to `False`):
@@ -164,13 +201,11 @@ class Args:
  OPTIMIZER ARGUMENTS
  -------------------
  optimizer (`str`, defaults to `adamw`):
- The optimizer type to use. Choose between ['adam', 'adamw'].
- use_8bit_bnb (`bool`, defaults to `False`):
- Whether to use 8bit variant of the `optimizer` using `bitsandbytes`.
  lr (`float`, defaults to `1e-4`):
  Initial learning rate (after the potential warmup period) to use.
- scale_lr (`bool`, defaults to `False`):
- Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.
  lr_scheduler (`str`, defaults to `cosine_with_restarts`):
  The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',
  'constant', 'constant_with_warmup'].
@@ -192,29 +227,21 @@ class Args:

  VALIDATION ARGUMENTS
  --------------------
- validation_prompts (`List[str]`, defaults to `None`):
- List of prompts to use for validation. If not provided, a random prompt will be selected from the training
- dataset.
- validation_images (`List[str]`, defaults to `None`):
- List of image paths to use for validation.
- validation_videos (`List[str]`, defaults to `None`):
- List of video paths to use for validation.
- validation_heights (`List[int]`, defaults to `None`):
- List of heights for the validation videos.
- validation_widths (`List[int]`, defaults to `None`):
- List of widths for the validation videos.
- validation_num_frames (`List[int]`, defaults to `None`):
- List of number of frames for the validation videos.
- num_validation_videos_per_prompt (`int`, defaults to `1`):
- Number of videos to use for validation per prompt.
- validation_every_n_epochs (`int`, defaults to `None`):
- Perform validation every `n` training epochs.
- validation_every_n_steps (`int`, defaults to `None`):
- Perform validation every `n` training steps.
  enable_model_cpu_offload (`bool`, defaults to `False`):
  Whether or not to offload different modeling components to CPU during validation.
- validation_frame_rate (`int`, defaults to `25`):
- Frame rate to use for the validation videos. This value is defaulted to 25, as used in LTX Video pipeline.

  MISCELLANEOUS ARGUMENTS
  -----------------------
@@ -230,20 +257,44 @@ class Args:
  The directory where the model checkpoints and logs will be stored.
  logging_dir (`str`, defaults to `logs`):
  The directory where the logs will be stored.
  allow_tf32 (`bool`, defaults to `False`):
  Whether or not to allow the use of TF32 matmul on compatible hardware.
  nccl_timeout (`int`, defaults to `1800`):
  Timeout for the NCCL communication.
  report_to (`str`, defaults to `wandb`):
  The name of the logger to use for logging training metrics. Choose between ['wandb'].
  """

  # Model arguments
  model_name: str = None
  pretrained_model_name_or_path: str = None
  revision: Optional[str] = None
  variant: Optional[str] = None
  cache_dir: Optional[str] = None
  text_encoder_dtype: torch.dtype = torch.bfloat16
  text_encoder_2_dtype: torch.dtype = torch.bfloat16
  text_encoder_3_dtype: torch.dtype = torch.bfloat16
@@ -263,18 +314,11 @@ class Args:
  ]

  # Dataset arguments
- data_root: str = None
- dataset_file: Optional[str] = None
- video_column: str = None
- caption_column: str = None
- id_token: Optional[str] = None
- image_resolution_buckets: List[Tuple[int, int]] = None
- video_resolution_buckets: List[Tuple[int, int, int]] = None
- video_reshape_mode: Optional[str] = None
- caption_dropout_p: float = 0.00
- caption_dropout_technique: str = "empty"
- precompute_conditions: bool = False
- remove_common_llm_caption_prefixes: bool = False

  # Dataloader arguments
  dataloader_num_workers: int = 0
@@ -296,11 +340,8 @@ class Args:
  training_type: str = None
  seed: int = 42
  batch_size: int = 1
- train_epochs: int = 1
- train_steps: int = None
- rank: int = 128
- lora_alpha: float = 64
- target_modules: List[str] = ["to_k", "to_q", "to_v", "to_out.0"]
  gradient_accumulation_steps: int = 1
  gradient_checkpointing: bool = False
  checkpointing_steps: int = 500
@@ -311,9 +352,7 @@ class Args:

  # Optimizer arguments
  optimizer: str = "adamw"
- use_8bit_bnb: bool = False
  lr: float = 1e-4
- scale_lr: bool = False
  lr_scheduler: str = "cosine_with_restarts"
  lr_warmup_steps: int = 0
  lr_num_cycles: int = 1
@@ -326,17 +365,9 @@ class Args:
  max_grad_norm: float = 1.0

  # Validation arguments
- validation_prompts: List[str] = None
- validation_images: List[str] = None
- validation_videos: List[str] = None
- validation_heights: List[int] = None
- validation_widths: List[int] = None
- validation_num_frames: List[int] = None
- num_validation_videos_per_prompt: int = 1
- validation_every_n_epochs: Optional[int] = None
- validation_every_n_steps: Optional[int] = None
  enable_model_cpu_offload: bool = False
- validation_frame_rate: int = 25

  # Miscellaneous arguments
  tracker_name: str = "finetrainers"
@@ -345,664 +376,343 @@ class Args:
  hub_model_id: Optional[str] = None
  output_dir: str = None
  logging_dir: Optional[str] = "logs"
  allow_tf32: bool = False
- nccl_timeout: int = 1800 # 30 minutes
  report_to: str = "wandb"

  def to_dict(self) -> Dict[str, Any]:
- return {
- "model_arguments": {
- "model_name": self.model_name,
- "pretrained_model_name_or_path": self.pretrained_model_name_or_path,
- "revision": self.revision,
- "variant": self.variant,
- "cache_dir": self.cache_dir,
- "text_encoder_dtype": self.text_encoder_dtype,
- "text_encoder_2_dtype": self.text_encoder_2_dtype,
- "text_encoder_3_dtype": self.text_encoder_3_dtype,
- "transformer_dtype": self.transformer_dtype,
- "vae_dtype": self.vae_dtype,
- "layerwise_upcasting_modules": self.layerwise_upcasting_modules,
- "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
- "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
- },
- "dataset_arguments": {
- "data_root": self.data_root,
- "dataset_file": self.dataset_file,
- "video_column": self.video_column,
- "caption_column": self.caption_column,
- "id_token": self.id_token,
- "image_resolution_buckets": self.image_resolution_buckets,
- "video_resolution_buckets": self.video_resolution_buckets,
- "video_reshape_mode": self.video_reshape_mode,
- "caption_dropout_p": self.caption_dropout_p,
- "caption_dropout_technique": self.caption_dropout_technique,
- "precompute_conditions": self.precompute_conditions,
- "remove_common_llm_caption_prefixes": self.remove_common_llm_caption_prefixes,
- },
- "dataloader_arguments": {
- "dataloader_num_workers": self.dataloader_num_workers,
- "pin_memory": self.pin_memory,
- },
- "diffusion_arguments": {
- "flow_resolution_shifting": self.flow_resolution_shifting,
- "flow_base_seq_len": self.flow_base_seq_len,
- "flow_max_seq_len": self.flow_max_seq_len,
- "flow_base_shift": self.flow_base_shift,
- "flow_max_shift": self.flow_max_shift,
- "flow_shift": self.flow_shift,
- "flow_weighting_scheme": self.flow_weighting_scheme,
- "flow_logit_mean": self.flow_logit_mean,
- "flow_logit_std": self.flow_logit_std,
- "flow_mode_scale": self.flow_mode_scale,
- },
- "training_arguments": {
- "training_type": self.training_type,
- "seed": self.seed,
- "batch_size": self.batch_size,
- "train_epochs": self.train_epochs,
- "train_steps": self.train_steps,
- "rank": self.rank,
- "lora_alpha": self.lora_alpha,
- "target_modules": self.target_modules,
- "gradient_accumulation_steps": self.gradient_accumulation_steps,
- "gradient_checkpointing": self.gradient_checkpointing,
- "checkpointing_steps": self.checkpointing_steps,
- "checkpointing_limit": self.checkpointing_limit,
- "resume_from_checkpoint": self.resume_from_checkpoint,
- "enable_slicing": self.enable_slicing,
- "enable_tiling": self.enable_tiling,
- },
- "optimizer_arguments": {
- "optimizer": self.optimizer,
- "use_8bit_bnb": self.use_8bit_bnb,
- "lr": self.lr,
- "scale_lr": self.scale_lr,
- "lr_scheduler": self.lr_scheduler,
- "lr_warmup_steps": self.lr_warmup_steps,
- "lr_num_cycles": self.lr_num_cycles,
- "lr_power": self.lr_power,
- "beta1": self.beta1,
- "beta2": self.beta2,
- "beta3": self.beta3,
- "weight_decay": self.weight_decay,
- "epsilon": self.epsilon,
- "max_grad_norm": self.max_grad_norm,
- },
- "validation_arguments": {
- "validation_prompts": self.validation_prompts,
- "validation_images": self.validation_images,
- "validation_videos": self.validation_videos,
- "num_validation_videos_per_prompt": self.num_validation_videos_per_prompt,
- "validation_every_n_epochs": self.validation_every_n_epochs,
- "validation_every_n_steps": self.validation_every_n_steps,
- "enable_model_cpu_offload": self.enable_model_cpu_offload,
- "validation_frame_rate": self.validation_frame_rate,
- },
- "miscellaneous_arguments": {
- "tracker_name": self.tracker_name,
- "push_to_hub": self.push_to_hub,
- "hub_token": self.hub_token,
- "hub_model_id": self.hub_model_id,
- "output_dir": self.output_dir,
- "logging_dir": self.logging_dir,
- "allow_tf32": self.allow_tf32,
- "nccl_timeout": self.nccl_timeout,
- "report_to": self.report_to,
- },
  }


- # TODO(aryan): handle more informative messages
- _IS_ARGUMENTS_REQUIRED = "--list_models" not in sys.argv
-
-
- def parse_arguments() -> Args:
- parser = argparse.ArgumentParser()

- if _IS_ARGUMENTS_REQUIRED:
- _add_model_arguments(parser)
- _add_dataset_arguments(parser)
- _add_dataloader_arguments(parser)
- _add_diffusion_arguments(parser)
- _add_training_arguments(parser)
- _add_optimizer_arguments(parser)
- _add_validation_arguments(parser)
- _add_miscellaneous_arguments(parser)

- args = parser.parse_args()
- return _map_to_args_type(args)
- else:
- _add_helper_arguments(parser)

- args = parser.parse_args()
- _display_helper_messages(args)
- sys.exit(0)



- def validate_args(args: Args):
- _validated_model_args(args)
- _validate_training_args(args)
  _validate_validation_args(args)

- def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
- parser.add_argument(
- "--model_name",
- type=str,
- required=True,
- choices=list(SUPPORTED_MODEL_CONFIGS.keys()),
- help="Name of model to train.",
- )
- parser.add_argument(
- "--pretrained_model_name_or_path",
- type=str,
- required=True,
- help="Path to pretrained model or model identifier from huggingface.co/models.",
- )
  parser.add_argument(
- "--revision",
  type=str,
- default=None,
- required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
  )
  parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
- parser.add_argument(
- "--cache_dir",
- type=str,
- default=None,
- help="The directory where the downloaded models and datasets will be stored.",
- )
- parser.add_argument("--text_encoder_dtype", type=str, default="bf16", help="Data type for the text encoder.")
- parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16", help="Data type for the text encoder 2.")
- parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16", help="Data type for the text encoder 3.")
- parser.add_argument("--transformer_dtype", type=str, default="bf16", help="Data type for the transformer model.")
- parser.add_argument("--vae_dtype", type=str, default="bf16", help="Data type for the VAE model.")
- parser.add_argument(
- "--layerwise_upcasting_modules",
- type=str,
- default=[],
- nargs="+",
- choices=["transformer"],
- help="Modules that should have fp8 storage weights but higher precision computation.",
- )
  parser.add_argument(
  "--layerwise_upcasting_storage_dtype",
  type=str,
  default="float8_e4m3fn",
  choices=["float8_e4m3fn", "float8_e5m2"],
- help="Data type for the layerwise upcasting storage.",
  )
  parser.add_argument(
  "--layerwise_upcasting_skip_modules_pattern",
  type=str,
  default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
  nargs="+",
- help="Modules to skip for layerwise upcasting.",
  )


  def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
- def parse_resolution_bucket(resolution_bucket: str) -> Tuple[int, ...]:
- return tuple(map(int, resolution_bucket.split("x")))
-
- def parse_image_resolution_bucket(resolution_bucket: str) -> Tuple[int, int]:
- resolution_bucket = parse_resolution_bucket(resolution_bucket)
- assert (
- len(resolution_bucket) == 2
- ), f"Expected 2D resolution bucket, got {len(resolution_bucket)}D resolution bucket"
- return resolution_bucket
-
- def parse_video_resolution_bucket(resolution_bucket: str) -> Tuple[int, int, int]:
- resolution_bucket = parse_resolution_bucket(resolution_bucket)
- assert (
- len(resolution_bucket) == 3
- ), f"Expected 3D resolution bucket, got {len(resolution_bucket)}D resolution bucket"
- return resolution_bucket
-
- parser.add_argument(
- "--data_root",
- type=str,
- required=True,
- help=("A folder containing the training data."),
- )
- parser.add_argument(
- "--dataset_file",
- type=str,
- default=None,
- help=("Path to a CSV file if loading prompts/video paths using this format."),
- )
- parser.add_argument(
- "--video_column",
- type=str,
- default="video",
- help="The column of the dataset containing videos. Or, the name of the file in `--data_root` folder containing the line-separated path to video data.",
- )
- parser.add_argument(
- "--caption_column",
- type=str,
- default="text",
- help="The column of the dataset containing the instance prompt for each video. Or, the name of the file in `--data_root` folder containing the line-separated instance prompts.",
- )
- parser.add_argument(
- "--id_token",
- type=str,
- default=None,
- help="Identifier token appended to the start of each prompt if provided.",
- )
- parser.add_argument(
- "--image_resolution_buckets",
- type=parse_image_resolution_bucket,
- default=None,
- nargs="+",
- help="Resolution buckets for images.",
- )
- parser.add_argument(
- "--video_resolution_buckets",
- type=parse_video_resolution_bucket,
- default=None,
- nargs="+",
- help="Resolution buckets for videos.",
- )
- parser.add_argument(
- "--video_reshape_mode",
- type=str,
- default=None,
- help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
- )
- parser.add_argument(
- "--caption_dropout_p",
- type=float,
- default=0.00,
- help="Probability of dropout for the caption tokens.",
- )
- parser.add_argument(
- "--caption_dropout_technique",
- type=str,
- default="empty",
- choices=["empty", "zero"],
- help="Technique to use for caption dropout.",
- )
- parser.add_argument(
- "--precompute_conditions",
- action="store_true",
- help="Whether or not to precompute the conditionings for the model.",
- )
- parser.add_argument(
- "--remove_common_llm_caption_prefixes",
- action="store_true",
- help="Whether or not to remove common LLM caption prefixes.",
- )


  def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None:
- parser.add_argument(
- "--dataloader_num_workers",
- type=int,
- default=0,
- help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
- )
- parser.add_argument(
- "--pin_memory",
- action="store_true",
- help="Whether or not to use the pinned memory setting in pytorch dataloader.",
- )


  def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
- parser.add_argument(
- "--flow_resolution_shifting",
- action="store_true",
- help="Resolution-dependent shifting of timestep schedules.",
- )
- parser.add_argument(
- "--flow_base_seq_len",
- type=int,
- default=256,
- help="Base image/video sequence length for the diffusion model.",
- )
- parser.add_argument(
- "--flow_max_seq_len",
- type=int,
- default=4096,
- help="Maximum image/video sequence length for the diffusion model.",
- )
- parser.add_argument(
- "--flow_base_shift",
- type=float,
- default=0.5,
- help="Base shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)",
- )
- parser.add_argument(
- "--flow_max_shift",
- type=float,
- default=1.15,
- help="Maximum shift as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2403.03206)",
- )
- parser.add_argument(
- "--flow_shift",
- type=float,
- default=1.0,
- help="Shift value to use for the flow matching timestep schedule.",
- )
  parser.add_argument(
  "--flow_weighting_scheme",
  type=str,
  default="none",
  choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
- help='We default to the "none" weighting scheme for uniform sampling and uniform loss',
- )
- parser.add_argument(
- "--flow_logit_mean",
- type=float,
- default=0.0,
- help="Mean to use when using the `'logit_normal'` weighting scheme.",
- )
- parser.add_argument(
- "--flow_logit_std",
- type=float,
- default=1.0,
- help="Standard deviation to use when using the `'logit_normal'` weighting scheme.",
- )
- parser.add_argument(
- "--flow_mode_scale",
- type=float,
- default=1.29,
- help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
  )


  def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
- # TODO: support full finetuning and other kinds
- parser.add_argument(
- "--training_type",
- type=str,
- choices=["lora", "full-finetune"],
- required=True,
- help="Type of training to perform. Choose between ['lora', 'full-finetune']",
- )
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
- parser.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="Batch size (per device) for the training dataloader.",
- )
- parser.add_argument("--train_epochs", type=int, default=1, help="Number of training epochs.")
- parser.add_argument(
- "--train_steps",
- type=int,
- default=None,
- help="Total number of training steps to perform. If provided, overrides `--num_train_epochs`.",
- )
- parser.add_argument("--rank", type=int, default=64, help="The rank for LoRA matrices.")
- parser.add_argument(
- "--lora_alpha",
- type=int,
- default=64,
- help="The lora_alpha to compute scaling factor (lora_alpha / rank) for LoRA matrices.",
- )
- parser.add_argument(
- "--target_modules",
- type=str,
- default=["to_k", "to_q", "to_v", "to_out.0"],
- nargs="+",
- help="The target modules for LoRA.",
- )
- parser.add_argument(
- "--gradient_accumulation_steps",
- type=int,
- default=1,
- help="Number of updates steps to accumulate before performing a backward/update pass.",
- )
  parser.add_argument(
- "--gradient_checkpointing",
- action="store_true",
- help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
- )
- parser.add_argument(
- "--checkpointing_steps",
- type=int,
- default=500,
- help=(
- "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
- " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
- " training using `--resume_from_checkpoint`."
- ),
- )
- parser.add_argument(
- "--checkpointing_limit",
- type=int,
- default=None,
- help=("Max number of checkpoints to store."),
- )
- parser.add_argument(
- "--resume_from_checkpoint",
- type=str,
- default=None,
- help=(
- "Whether training should be resumed from a previous checkpoint. Use a path saved by"
- ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
- ),
- )
- parser.add_argument(
- "--enable_slicing",
- action="store_true",
- help="Whether or not to use VAE slicing for saving memory.",
- )
- parser.add_argument(
- "--enable_tiling",
- action="store_true",
- help="Whether or not to use VAE tiling for saving memory.",
  )


  def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None:
- parser.add_argument(
- "--lr",
- type=float,
- default=1e-4,
- help="Initial learning rate (after the potential warmup period) to use.",
- )
- parser.add_argument(
- "--scale_lr",
- action="store_true",
- help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
- )
- parser.add_argument(
- "--lr_scheduler",
- type=str,
- default="constant",
- help=(
- 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
- ' "constant", "constant_with_warmup"]'
- ),
- )
- parser.add_argument(
- "--lr_warmup_steps",
- type=int,
- default=500,
- help="Number of steps for the warmup in the lr scheduler.",
- )
- parser.add_argument(
- "--lr_num_cycles",
- type=int,
- default=1,
- help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
- )
- parser.add_argument(
- "--lr_power",
- type=float,
- default=1.0,
- help="Power factor of the polynomial scheduler.",
- )
  parser.add_argument(
  "--optimizer",
  type=lambda s: s.lower(),
  default="adam",
- choices=["adam", "adamw"],
- help=("The optimizer type to use."),
  )
- parser.add_argument(
- "--use_8bit_bnb",
- action="store_true",
- help=("Whether to use 8bit variant of the `--optimizer` using `bitsandbytes`."),
- )
- parser.add_argument(
- "--beta1",
- type=float,
- default=0.9,
- help="The beta1 parameter for the Adam and Prodigy optimizers.",
- )
- parser.add_argument(
- "--beta2",
- type=float,
- default=0.95,
- help="The beta2 parameter for the Adam and Prodigy optimizers.",
- )
- parser.add_argument(
- "--beta3",
- type=float,
- default=None,
- help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
- )
- parser.add_argument(
- "--weight_decay",
- type=float,
- default=1e-04,
- help="Weight decay to use for optimizer.",
- )
- parser.add_argument(
- "--epsilon",
- type=float,
- default=1e-8,
- help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
- )
- parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")


  def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
- parser.add_argument(
- "--validation_prompts",
- type=str,
- default=None,
- help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
- )
- parser.add_argument(
- "--validation_images",
- type=str,
- default=None,
- help="One or more image path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
- )
- parser.add_argument(
- "--validation_videos",
- type=str,
- default=None,
- help="One or more video path(s)/URLs that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
- )
- parser.add_argument(
- "--validation_separator",
- type=str,
- default=":::",
- help="String that separates multiple validation prompts",
- )
- parser.add_argument(
- "--num_validation_videos",
- type=int,
- default=1,
- help="Number of videos that should be generated during validation per `validation_prompt`.",
- )
- parser.add_argument(
- "--validation_epochs",
- type=int,
- default=None,
- help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
- )
- parser.add_argument(
- "--validation_steps",
- type=int,
- default=None,
- help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
- )
- parser.add_argument(
- "--validation_frame_rate",
- type=int,
- default=25,
- help="Frame rate to use for the validation videos.",
- )
- parser.add_argument(
- "--enable_model_cpu_offload",
- action="store_true",
- help="Whether or not to enable model-wise CPU offloading when performing validation/testing to save memory.",
- )


  def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
- parser.add_argument("--tracker_name", type=str, default="finetrainers", help="Project tracker name")
- parser.add_argument(
- "--push_to_hub",
- action="store_true",
- help="Whether or not to push the model to the Hub.",
- )
- parser.add_argument(
- "--hub_token",
- type=str,
- default=None,
- help="The token to use to push to the Model Hub.",
- )
- parser.add_argument(
- "--hub_model_id",
- type=str,
- default=None,
- help="The name of the repository to keep in sync with the local `output_dir`.",
- )
- parser.add_argument(
- "--output_dir",
- type=str,
- default="finetrainers-training",
- help="The output directory where the model predictions and checkpoints will be written.",
- )
- parser.add_argument(
- "--logging_dir",
- type=str,
- default="logs",
- help="Directory where logs are stored.",
- )
- parser.add_argument(
- "--allow_tf32",
- action="store_true",
- help=(
- "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
- " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
- ),
- )
- parser.add_argument(
- "--nccl_timeout",
- type=int,
- default=600,
- help="Maximum timeout duration before which allgather, or related, operations fail in multi-GPU/multi-node training settings.",
- )
- parser.add_argument(
- "--report_to",
- type=str,
- default="none",
- choices=["none", "wandb"],
- help="The integration to report the results and logs to.",
- )


  def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
- parser.add_argument(
- "--list_models",
- action="store_true",
- help="List all the supported models.",
- )


  _DTYPE_MAP = {
@@ -1014,8 +724,16 @@ _DTYPE_MAP = {
  }


- def _map_to_args_type(args: Dict[str, Any]) -> Args:
- result_args = Args()

  # Model arguments
  result_args.model_name = args.model_name
@@ -1023,6 +741,14 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
  result_args.revision = args.revision
  result_args.variant = args.variant
  result_args.cache_dir = args.cache_dir
  result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
  result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
  result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
@@ -1033,21 +759,11 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
  result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern

  # Dataset arguments
- if args.data_root is None and args.dataset_file is None:
- raise ValueError("At least one of `data_root` or `dataset_file` should be provided.")
-
- result_args.data_root = args.data_root
- result_args.dataset_file = args.dataset_file
- result_args.video_column = args.video_column
- result_args.caption_column = args.caption_column
- result_args.id_token = args.id_token
- result_args.image_resolution_buckets = args.image_resolution_buckets or DEFAULT_IMAGE_RESOLUTION_BUCKETS
- result_args.video_resolution_buckets = args.video_resolution_buckets or DEFAULT_VIDEO_RESOLUTION_BUCKETS
- result_args.video_reshape_mode = args.video_reshape_mode
- result_args.caption_dropout_p = args.caption_dropout_p
- result_args.caption_dropout_technique = args.caption_dropout_technique
- result_args.precompute_conditions = args.precompute_conditions
- result_args.remove_common_llm_caption_prefixes = args.remove_common_llm_caption_prefixes

  # Dataloader arguments
  result_args.dataloader_num_workers = args.dataloader_num_workers
@@ -1069,11 +785,8 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
  result_args.training_type = args.training_type
  result_args.seed = args.seed
  result_args.batch_size = args.batch_size
- result_args.train_epochs = args.train_epochs
  result_args.train_steps = args.train_steps
- result_args.rank = args.rank
- result_args.lora_alpha = args.lora_alpha
- result_args.target_modules = args.target_modules
  result_args.gradient_accumulation_steps = args.gradient_accumulation_steps
  result_args.gradient_checkpointing = args.gradient_checkpointing
  result_args.checkpointing_steps = args.checkpointing_steps
@@ -1084,9 +797,7 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:

  # Optimizer arguments
  result_args.optimizer = args.optimizer or "adamw"
- result_args.use_8bit_bnb = args.use_8bit_bnb
  result_args.lr = args.lr or 1e-4
- result_args.scale_lr = args.scale_lr
  result_args.lr_scheduler = args.lr_scheduler
  result_args.lr_warmup_steps = args.lr_warmup_steps
  result_args.lr_num_cycles = args.lr_num_cycles
@@ -1099,42 +810,9 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
  result_args.max_grad_norm = args.max_grad_norm

  # Validation arguments
- validation_prompts = args.validation_prompts.split(args.validation_separator) if args.validation_prompts else []
- validation_images = args.validation_images.split(args.validation_separator) if args.validation_images else None
- validation_videos = args.validation_videos.split(args.validation_separator) if args.validation_videos else None
- stripped_validation_prompts = []
- validation_heights = []
- validation_widths = []
- validation_num_frames = []
- for prompt in validation_prompts:
- prompt: str
- prompt = prompt.strip()
- actual_prompt, separator, resolution = prompt.rpartition("@@@")
- stripped_validation_prompts.append(actual_prompt)
- num_frames, height, width = None, None, None
- if len(resolution) > 0:
- num_frames, height, width = map(int, resolution.split("x"))
- validation_num_frames.append(num_frames)
- validation_heights.append(height)
- validation_widths.append(width)
-
- if validation_images is None:
- validation_images = [None] * len(validation_prompts)
- if validation_videos is None:
- validation_videos = [None] * len(validation_prompts)
-
- result_args.validation_prompts = stripped_validation_prompts
- result_args.validation_heights = validation_heights
- result_args.validation_widths = validation_widths
- result_args.validation_num_frames = validation_num_frames
- result_args.validation_images = validation_images
- result_args.validation_videos = validation_videos
-
- result_args.num_validation_videos_per_prompt = args.num_validation_videos
- result_args.validation_every_n_epochs = args.validation_epochs
- result_args.validation_every_n_steps = args.validation_steps
  result_args.enable_model_cpu_offload = args.enable_model_cpu_offload
- result_args.validation_frame_rate = args.validation_frame_rate

  # Miscellaneous arguments
  result_args.tracker_name = args.tracker_name
@@ -1143,45 +821,36 @@ def _map_to_args_type(args: Dict[str, Any]) -> Args:
  result_args.hub_model_id = args.hub_model_id
  result_args.output_dir = args.output_dir
  result_args.logging_dir = args.logging_dir
  result_args.allow_tf32 = args.allow_tf32
  result_args.nccl_timeout = args.nccl_timeout
  result_args.report_to = args.report_to

  return result_args


- def _validated_model_args(args: Args):
  if args.training_type == "full-finetune":
  assert (
  "transformer" not in args.layerwise_upcasting_modules
  ), "Layerwise upcasting is not supported for full-finetune training"


- def _validate_training_args(args: Args):
- if args.training_type == "lora":
- assert args.rank is not None, "Rank is required for LoRA training"
- assert args.lora_alpha is not None, "LoRA alpha is required for LoRA training"
- assert (
- args.target_modules is not None and len(args.target_modules) > 0
- ), "Target modules are required for LoRA training"
-
-
- def _validate_validation_args(args: Args):
- assert args.validation_prompts is not None, "Validation prompts are required for validation"
- if args.validation_images is not None:
- assert len(args.validation_images) == len(
- args.validation_prompts
- ), "Validation images and prompts should be of same length"
- if args.validation_videos is not None:
- assert len(args.validation_videos) == len(
- args.validation_prompts
- ), "Validation videos and prompts should be of same length"
- assert len(args.validation_prompts) == len(
- args.validation_heights
- ), "Validation prompts and heights should be of same length"
- assert len(args.validation_prompts) == len(
- args.validation_widths
- ), "Validation prompts and widths should be of same length"


  def _display_helper_messages(args: argparse.Namespace):
 
  import argparse
+ import os
+ import pathlib
  import sys
+ from typing import Any, Callable, Dict, List, Optional

  import torch

+ from .config import SUPPORTED_MODEL_CONFIGS, ModelType, TrainingType
+ from .logging import get_logger
+ from .parallel import ParallelBackendEnum
+ from .utils import get_non_null_items


+ logger = get_logger()
+
+
+ class BaseArgs:
  r"""
  The arguments for the finetrainers training script.

  TODO(aryan): add `python train.py --memory_requirements --model_name <model_name>` to show
  memory requirements per model, per training type with sensible training settings.

+ PARALLEL ARGUMENTS
+ ------------------
+ parallel_backend (`str`, defaults to `accelerate`):
+ The parallel backend to use for training. Choose between ['accelerate', 'ptd'].
+ pp_degree (`int`, defaults to `1`):
+ The degree of pipeline parallelism.
+ dp_degree (`int`, defaults to `1`):
+ The degree of data parallelism (number of model replicas).
+ dp_shards (`int`, defaults to `-1`):
+ The number of data parallel shards (number of model partitions).
+ cp_degree (`int`, defaults to `1`):
+ The degree of context parallelism.
+
  MODEL ARGUMENTS
  ---------------
  model_name (`str`):

  storage requirements.
  cache_dir (`str`, defaults to `None`):
  The directory where the downloaded models and datasets will be stored, or loaded from.
+ tokenizer_id (`str`, defaults to `None`):
+ Identifier for the tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
+ tokenizer_2_id (`str`, defaults to `None`):
+ Identifier for the second tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
+ tokenizer_3_id (`str`, defaults to `None`):
+ Identifier for the third tokenizer model. This is useful when using a different tokenizer than the default from `pretrained_model_name_or_path`.
+ text_encoder_id (`str`, defaults to `None`):
+ Identifier for the text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
+ text_encoder_2_id (`str`, defaults to `None`):
+ Identifier for the second text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
+ text_encoder_3_id (`str`, defaults to `None`):
+ Identifier for the third text encoder model. This is useful when using a different text encoder than the default from `pretrained_model_name_or_path`.
+ transformer_id (`str`, defaults to `None`):
+ Identifier for the transformer model. This is useful when using a different transformer model than the default from `pretrained_model_name_or_path`.
+ vae_id (`str`, defaults to `None`):
+ Identifier for the VAE model. This is useful when using a different VAE model than the default from `pretrained_model_name_or_path`.
  text_encoder_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
  Data type for the text encoder when generating text embeddings.
  text_encoder_2_dtype (`torch.dtype`, defaults to `torch.bfloat16`):

  DATASET ARGUMENTS
  -----------------
+ dataset_config (`str`):
+ File to a dataset file containing information about training data. This file can contain information about one or
+ more datasets in JSON format. The file must have a key called "datasets", which is a list of dictionaries. Each
+ dictionary must contain the following keys:
+ - "data_root": (`str`)
+ The root directory containing the dataset. This parameter must be provided if `dataset_file` is not provided.
+ - "dataset_file": (`str`)
+ Path to a CSV/JSON/JSONL/PARQUET/ARROW/HF_HUB_DATASET file containing metadata for training. This parameter
+ must be provided if `data_root` is not provided.
+ - "dataset_type": (`str`)
+ Type of dataset. Choose between ['image', 'video'].
+ - "id_token": (`str`)
+ Identifier token appended to the start of each prompt if provided. This is useful for LoRA-type training
+ for single subject/concept/style training, but is not necessary.
+ - "image_resolution_buckets": (`List[Tuple[int, int]]`)
+ Resolution buckets for image. This should be a list of tuples containing 2 values, where each tuple
+ represents the resolution (height, width). All images will be resized to the nearest bucket resolution.
+ This parameter must be provided if `dataset_type` is 'image'.
+ - "video_resolution_buckets": (`List[Tuple[int, int, int]]`)
+ Resolution buckets for video. This should be a list of tuples containing 3 values, where each tuple
+ represents the resolution (num_frames, height, width). All videos will be resized to the nearest bucket
+ resolution. This parameter must be provided if `dataset_type` is 'video'.
+ - "reshape_mode": (`str`)
+ All input images/videos are reshaped using this mode. Choose between the following:
+ ["center_crop", "random_crop", "bicubic"].
+ - "remove_common_llm_caption_prefixes": (`boolean`)
+ Whether or not to remove common LLM caption prefixes. See `~constants.py` for the list of common prefixes.
+ dataset_shuffle_buffer_size (`int`, defaults to `1`):
+ The buffer size for shuffling the dataset. This is useful for shuffling the dataset before training. The default
+ value of `1` means that the dataset will not be shuffled.
+ precomputation_items (`int`, defaults to `512`):
+ Number of data samples to precompute at once for memory-efficient training. The higher this value,
+ the more disk memory will be used to save the precomputed samples (conditions and latents).
+ precomputation_dir (`str`, defaults to `None`):
+ The directory where the precomputed samples will be stored. If not provided, the precomputed samples
+ will be stored in a temporary directory of the output directory.
+ precomputation_once (`bool`, defaults to `False`):
+ Precompute embeddings from all datasets at once before training. This is useful to save time during training
+ with smaller datasets. If set to `False`, will save disk space by precomputing embeddings on-the-fly during
+ training when required. Make sure to set `precomputation_items` to a reasonable value in line with the size
+ of your dataset(s).

  DATALOADER_ARGUMENTS
  --------------------

  A seed for reproducible training.
  batch_size (`int`, defaults to `1`):
  Per-device batch size.
+ train_steps (`int`, defaults to `1000`):
+ Total number of training steps to perform.
+ max_data_samples (`int`, defaults to `2**64`):
+ Maximum number of data samples observed during training training. If lesser than that required by `train_steps`,
+ the training will stop early.
  gradient_accumulation_steps (`int`, defaults to `1`):
  Number of gradients steps to accumulate before performing an optimizer step.
  gradient_checkpointing (`bool`, defaults to `False`):

  OPTIMIZER ARGUMENTS
  -------------------
  optimizer (`str`, defaults to `adamw`):
+ The optimizer type to use. Choose between the following:
+ - Torch optimizers: ["adam", "adamw"]
+ - Bitsandbytes optimizers: ["adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"]
  lr (`float`, defaults to `1e-4`):
  Initial learning rate (after the potential warmup period) to use.
  lr_scheduler (`str`, defaults to `cosine_with_restarts`):
  The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',
  'constant', 'constant_with_warmup'].

  VALIDATION ARGUMENTS
  --------------------
+ validation_dataset_file (`str`, defaults to `None`):
+ Path to a CSV/JSON/PARQUET/ARROW file containing information for validation. The file must contain atleast the
+ "caption" column. Other columns such as "image_path" and "video_path" can be provided too. If provided, "image_path"
+ will be used to load a PIL.Image.Image and set the "image" key in the sample dictionary. Similarly, "video_path"
+ will be used to load a List[PIL.Image.Image] and set the "video" key in the sample dictionary.
+ The validation dataset file may contain other attributes specific to inference/validation such as:
+ - "height" and "width" and "num_frames": Resolution
+ - "num_inference_steps": Number of inference steps
+ - "guidance_scale": Classifier-free Guidance Scale
+ - ... (any number of additional attributes can be provided. The ModelSpecification::validate method will be
+ invoked with the sample dictionary to validate the sample.)
+ validation_steps (`int`, defaults to `500`):
+ Number of training steps after which a validation step is performed.
  enable_model_cpu_offload (`bool`, defaults to `False`):
  Whether or not to offload different modeling components to CPU during validation.

  MISCELLANEOUS ARGUMENTS
  -----------------------

  The directory where the model checkpoints and logs will be stored.
  logging_dir (`str`, defaults to `logs`):
  The directory where the logs will be stored.
+ logging_steps (`int`, defaults to `1`):
+ Training logs will be tracked every `logging_steps` steps.
  allow_tf32 (`bool`, defaults to `False`):
  Whether or not to allow the use of TF32 matmul on compatible hardware.
  nccl_timeout (`int`, defaults to `1800`):
  Timeout for the NCCL communication.
  report_to (`str`, defaults to `wandb`):
  The name of the logger to use for logging training metrics. Choose between ['wandb'].
+ verbose (`int`, defaults to `1`):
+ Whether or not to print verbose logs.
+ - 0: Diffusers/Transformers warning logging on local main process only
+ - 1: Diffusers/Transformers info logging on local main process only
+ - 2: Diffusers/Transformers debug logging on local main process only
+ - 3: Diffusers/Transformers debug logging on all processes
  """

+ # Parallel arguments
+ parallel_backend = ParallelBackendEnum.ACCELERATE
+ pp_degree: int = 1
+ dp_degree: int = 1
+ dp_shards: int = 1
+ cp_degree: int = 1
+ tp_degree: int = 1
+
  # Model arguments
  model_name: str = None
  pretrained_model_name_or_path: str = None
  revision: Optional[str] = None
  variant: Optional[str] = None
  cache_dir: Optional[str] = None
+ tokenizer_id: Optional[str] = None
+ tokenizer_2_id: Optional[str] = None
+ tokenizer_3_id: Optional[str] = None
+ text_encoder_id: Optional[str] = None
+ text_encoder_2_id: Optional[str] = None
+ text_encoder_3_id: Optional[str] = None
+ transformer_id: Optional[str] = None
+ vae_id: Optional[str] = None
  text_encoder_dtype: torch.dtype = torch.bfloat16
  text_encoder_2_dtype: torch.dtype = torch.bfloat16
  text_encoder_3_dtype: torch.dtype = torch.bfloat16

  ]

  # Dataset arguments
+ dataset_config: str = None
+ dataset_shuffle_buffer_size: int = 1
+ precomputation_items: int = 512
+ precomputation_dir: Optional[str] = None
+ precomputation_once: bool = False

  # Dataloader arguments
  dataloader_num_workers: int = 0

  training_type: str = None
  seed: int = 42
  batch_size: int = 1
+ train_steps: int = 1000
+ max_data_samples: int = 2**64
  gradient_accumulation_steps: int = 1
  gradient_checkpointing: bool = False
  checkpointing_steps: int = 500

  # Optimizer arguments
  optimizer: str = "adamw"
  lr: float = 1e-4
  lr_scheduler: str = "cosine_with_restarts"
  lr_warmup_steps: int = 0
  lr_num_cycles: int = 1

  max_grad_norm: float = 1.0

  # Validation arguments
+ validation_dataset_file: Optional[str] = None
+ validation_steps: int = 500
  enable_model_cpu_offload: bool = False

  # Miscellaneous arguments
  tracker_name: str = "finetrainers"

  hub_model_id: Optional[str] = None
  output_dir: str = None
  logging_dir: Optional[str] = "logs"
+ logging_steps: int = 1
  allow_tf32: bool = False
+ init_timeout: int = 300 # 5 minutes
+ nccl_timeout: int = 600 # 10 minutes, considering that validation may be performed
  report_to: str = "wandb"
+ verbose: int = 1

386
  def to_dict(self) -> Dict[str, Any]:
387
+ parallel_arguments = {
388
+ "pp_degree": self.pp_degree,
389
+ "dp_degree": self.dp_degree,
390
+ "dp_shards": self.dp_shards,
391
+ "cp_degree": self.cp_degree,
392
+ "tp_degree": self.tp_degree,
393
  }
394
 
395
+ model_arguments = {
396
+ "model_name": self.model_name,
397
+ "pretrained_model_name_or_path": self.pretrained_model_name_or_path,
398
+ "revision": self.revision,
399
+ "variant": self.variant,
400
+ "cache_dir": self.cache_dir,
401
+ "tokenizer_id": self.tokenizer_id,
402
+ "tokenizer_2_id": self.tokenizer_2_id,
403
+ "tokenizer_3_id": self.tokenizer_3_id,
404
+ "text_encoder_id": self.text_encoder_id,
405
+ "text_encoder_2_id": self.text_encoder_2_id,
406
+ "text_encoder_3_id": self.text_encoder_3_id,
407
+ "transformer_id": self.transformer_id,
408
+ "vae_id": self.vae_id,
409
+ "text_encoder_dtype": self.text_encoder_dtype,
410
+ "text_encoder_2_dtype": self.text_encoder_2_dtype,
411
+ "text_encoder_3_dtype": self.text_encoder_3_dtype,
412
+ "transformer_dtype": self.transformer_dtype,
413
+ "vae_dtype": self.vae_dtype,
414
+ "layerwise_upcasting_modules": self.layerwise_upcasting_modules,
415
+ "layerwise_upcasting_storage_dtype": self.layerwise_upcasting_storage_dtype,
416
+ "layerwise_upcasting_skip_modules_pattern": self.layerwise_upcasting_skip_modules_pattern,
417
+ }
418
+ model_arguments = get_non_null_items(model_arguments)
419
+
420
+ dataset_arguments = {
421
+ "dataset_config": self.dataset_config,
422
+ "dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size,
423
+ "precomputation_items": self.precomputation_items,
424
+ "precomputation_dir": self.precomputation_dir,
425
+ "precomputation_once": self.precomputation_once,
426
+ }
427
+ dataset_arguments = get_non_null_items(dataset_arguments)
428
 
429
+ dataloader_arguments = {
430
+ "dataloader_num_workers": self.dataloader_num_workers,
431
+ "pin_memory": self.pin_memory,
432
+ }
 
 
433
 
434
+ diffusion_arguments = {
435
+ "flow_resolution_shifting": self.flow_resolution_shifting,
436
+ "flow_base_seq_len": self.flow_base_seq_len,
437
+ "flow_max_seq_len": self.flow_max_seq_len,
438
+ "flow_base_shift": self.flow_base_shift,
439
+ "flow_max_shift": self.flow_max_shift,
440
+ "flow_shift": self.flow_shift,
441
+ "flow_weighting_scheme": self.flow_weighting_scheme,
442
+ "flow_logit_mean": self.flow_logit_mean,
443
+ "flow_logit_std": self.flow_logit_std,
444
+ "flow_mode_scale": self.flow_mode_scale,
445
+ }
446
 
447
+ training_arguments = {
448
+ "training_type": self.training_type,
449
+ "seed": self.seed,
450
+ "batch_size": self.batch_size,
451
+ "train_steps": self.train_steps,
452
+ "max_data_samples": self.max_data_samples,
453
+ "gradient_accumulation_steps": self.gradient_accumulation_steps,
454
+ "gradient_checkpointing": self.gradient_checkpointing,
455
+ "checkpointing_steps": self.checkpointing_steps,
456
+ "checkpointing_limit": self.checkpointing_limit,
457
+ "resume_from_checkpoint": self.resume_from_checkpoint,
458
+ "enable_slicing": self.enable_slicing,
459
+ "enable_tiling": self.enable_tiling,
460
+ }
461
+ training_arguments = get_non_null_items(training_arguments)
462
+
463
+ optimizer_arguments = {
464
+ "optimizer": self.optimizer,
465
+ "lr": self.lr,
466
+ "lr_scheduler": self.lr_scheduler,
467
+ "lr_warmup_steps": self.lr_warmup_steps,
468
+ "lr_num_cycles": self.lr_num_cycles,
469
+ "lr_power": self.lr_power,
470
+ "beta1": self.beta1,
471
+ "beta2": self.beta2,
472
+ "beta3": self.beta3,
473
+ "weight_decay": self.weight_decay,
474
+ "epsilon": self.epsilon,
475
+ "max_grad_norm": self.max_grad_norm,
476
+ }
477
+ optimizer_arguments = get_non_null_items(optimizer_arguments)
478
 
479
+ validation_arguments = {
480
+ "validation_dataset_file": self.validation_dataset_file,
481
+ "validation_steps": self.validation_steps,
482
+ "enable_model_cpu_offload": self.enable_model_cpu_offload,
483
+ }
484
+ validation_arguments = get_non_null_items(validation_arguments)
485
+
486
+ miscellaneous_arguments = {
487
+ "tracker_name": self.tracker_name,
488
+ "push_to_hub": self.push_to_hub,
489
+ "hub_token": self.hub_token,
490
+ "hub_model_id": self.hub_model_id,
491
+ "output_dir": self.output_dir,
492
+ "logging_dir": self.logging_dir,
493
+ "logging_steps": self.logging_steps,
494
+ "allow_tf32": self.allow_tf32,
495
+ "init_timeout": self.init_timeout,
496
+ "nccl_timeout": self.nccl_timeout,
497
+ "report_to": self.report_to,
498
+ "verbose": self.verbose,
499
+ }
500
+ miscellaneous_arguments = get_non_null_items(miscellaneous_arguments)
501
 
502
+ return {
503
+ "parallel_arguments": parallel_arguments,
504
+ "model_arguments": model_arguments,
505
+ "dataset_arguments": dataset_arguments,
506
+ "dataloader_arguments": dataloader_arguments,
507
+ "diffusion_arguments": diffusion_arguments,
508
+ "training_arguments": training_arguments,
509
+ "optimizer_arguments": optimizer_arguments,
510
+ "validation_arguments": validation_arguments,
511
+ "miscellaneous_arguments": miscellaneous_arguments,
512
+ }
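A small usage sketch (illustrative only, assuming `args` is an already-parsed `BaseArgs` instance): the grouped dictionary returned by `to_dict` can be dumped for experiment tracking.

```python
# Illustrative sketch only: printing the grouped configuration returned by to_dict().
# Assumes `args` is a parsed BaseArgs instance (see parse_args below).
config = args.to_dict()
for group_name, group in config.items():
    print(f"[{group_name}]")
    for key, value in group.items():
        print(f"  {key} = {value}")
```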
513
 
514
+ def extend_args(
515
+ self,
516
+ add_fn: Callable[[argparse.ArgumentParser], None],
517
+ map_fn: Callable[["BaseArgs"], None],
518
+ validate_fn: Callable[["BaseArgs"], None],
519
+ ) -> None:
520
+ if not hasattr(self, "_extended_add_arguments"):
521
+ self._extended_add_arguments = []
522
+ self._extended_add_arguments.append((add_fn, validate_fn, map_fn))
523
+
524
+ def parse_args(self):
525
+ _LIST_MODELS = "--list_models"
526
+
527
+ parser = argparse.ArgumentParser()
528
+
529
+ special_args = [_LIST_MODELS]
530
+ if any(arg in sys.argv for arg in special_args):
531
+ _add_helper_arguments(parser)
532
+ args = parser.parse_args()
533
+ _display_helper_messages(args)
534
+ sys.exit(0)
535
+ else:
536
+ _add_args(parser)
537
+ for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
538
+ add_fn, _, _ = extended_add_arg_fns
539
+ add_fn(parser)
540
+
541
+ args, remaining_args = parser.parse_known_args()
542
+ logger.debug(f"Remaining unparsed arguments: {remaining_args}")
543
+
544
+ mapped_args = _map_to_args_type(args)
545
+ for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
546
+ _, _, map_fn = extended_add_arg_fns
547
+ map_fn(args, mapped_args)
548
+
549
+ _validate_args(mapped_args)
550
+ for extended_add_arg_fns in getattr(self, "_extended_add_arguments", []):
551
+ _, validate_fn, _ = extended_add_arg_fns
552
+ validate_fn(mapped_args)
553
+
554
+ return mapped_args
555
+
556
+
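A hedged sketch of how trainer-specific flags could hook into `extend_args`; the `--rank` flag and `rank` attribute below are hypothetical, but the three callbacks follow the order in which `parse_args` above invokes them (note that the mapping callback is called with both the raw argparse namespace and the mapped `BaseArgs`).

```python
# Illustrative sketch only: extending BaseArgs with a hypothetical trainer-specific flag.
import argparse

from finetrainers.args import BaseArgs  # module path assumed from this diff


def add_lora_args(parser: argparse.ArgumentParser) -> None:
    parser.add_argument("--rank", type=int, default=64)  # hypothetical flag


def map_lora_args(argparse_args: argparse.Namespace, mapped_args: BaseArgs) -> None:
    mapped_args.rank = argparse_args.rank  # hypothetical attribute


def validate_lora_args(mapped_args: BaseArgs) -> None:
    if mapped_args.rank < 1:
        raise ValueError("--rank must be a positive integer")


args = BaseArgs()
args.extend_args(add_lora_args, map_lora_args, validate_lora_args)
args = args.parse_args()
```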
557
+ def _add_args(parser: argparse.ArgumentParser) -> None:
558
+ _add_parallel_arguments(parser)
559
+ _add_model_arguments(parser)
560
+ _add_dataset_arguments(parser)
561
+ _add_dataloader_arguments(parser)
562
+ _add_diffusion_arguments(parser)
563
+ _add_training_arguments(parser)
564
+ _add_optimizer_arguments(parser)
565
+ _add_validation_arguments(parser)
566
+ _add_miscellaneous_arguments(parser)
567
+
568
+
569
+ def _validate_args(args: BaseArgs):
570
+ _validate_model_args(args)
571
+ _validate_dataset_args(args)
572
  _validate_validation_args(args)
573
 
574
 
575
+ def _add_parallel_arguments(parser: argparse.ArgumentParser) -> None:
576
  parser.add_argument(
577
+ "--parallel_backend",
578
  type=str,
579
+ default=ParallelBackendEnum.ACCELERATE,
580
+ choices=[ParallelBackendEnum.ACCELERATE, ParallelBackendEnum.PTD],
 
581
  )
582
+ parser.add_argument("--pp_degree", type=int, default=1)
583
+ parser.add_argument("--dp_degree", type=int, default=1)
584
+ parser.add_argument("--dp_shards", type=int, default=1)
585
+ parser.add_argument("--cp_degree", type=int, default=1)
586
+ parser.add_argument("--tp_degree", type=int, default=1)
587
+
588
+
589
+ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
590
  parser.add_argument(
591
+ "--model_name", type=str, required=True, choices=[x.value for x in ModelType.__members__.values()]
592
+ )
593
+ parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
594
+ parser.add_argument("--revision", type=str, default=None, required=False)
595
+ parser.add_argument("--variant", type=str, default=None)
596
+ parser.add_argument("--cache_dir", type=str, default=None)
597
+ parser.add_argument("--tokenizer_id", type=str, default=None)
598
+ parser.add_argument("--tokenizer_2_id", type=str, default=None)
599
+ parser.add_argument("--tokenizer_3_id", type=str, default=None)
600
+ parser.add_argument("--text_encoder_id", type=str, default=None)
601
+ parser.add_argument("--text_encoder_2_id", type=str, default=None)
602
+ parser.add_argument("--text_encoder_3_id", type=str, default=None)
603
+ parser.add_argument("--transformer_id", type=str, default=None)
604
+ parser.add_argument("--vae_id", type=str, default=None)
605
+ parser.add_argument("--text_encoder_dtype", type=str, default="bf16")
606
+ parser.add_argument("--text_encoder_2_dtype", type=str, default="bf16")
607
+ parser.add_argument("--text_encoder_3_dtype", type=str, default="bf16")
608
+ parser.add_argument("--transformer_dtype", type=str, default="bf16")
609
+ parser.add_argument("--vae_dtype", type=str, default="bf16")
610
+ parser.add_argument("--layerwise_upcasting_modules", type=str, default=[], nargs="+", choices=["transformer"])
 
 
 
 
611
  parser.add_argument(
612
  "--layerwise_upcasting_storage_dtype",
613
  type=str,
614
  default="float8_e4m3fn",
615
  choices=["float8_e4m3fn", "float8_e5m2"],
 
616
  )
617
  parser.add_argument(
618
  "--layerwise_upcasting_skip_modules_pattern",
619
  type=str,
620
  default=["patch_embed", "pos_embed", "x_embedder", "context_embedder", "^proj_in$", "^proj_out$", "norm"],
621
  nargs="+",
 
622
  )
623
 
624
 
625
  def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
626
+ parser.add_argument("--dataset_config", type=str, required=True)
627
+ parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1)
628
+ parser.add_argument("--precomputation_items", type=int, default=512)
629
+ parser.add_argument("--precomputation_dir", type=str, default=None)
630
+ parser.add_argument("--precomputation_once", action="store_true")
631
 
632
 
633
  def _add_dataloader_arguments(parser: argparse.ArgumentParser) -> None:
634
+ parser.add_argument("--dataloader_num_workers", type=int, default=0)
635
+ parser.add_argument("--pin_memory", action="store_true")
 
 
 
 
 
 
 
 
 
636
 
637
 
638
  def _add_diffusion_arguments(parser: argparse.ArgumentParser) -> None:
639
+ parser.add_argument("--flow_resolution_shifting", action="store_true")
640
+ parser.add_argument("--flow_base_seq_len", type=int, default=256)
641
+ parser.add_argument("--flow_max_seq_len", type=int, default=4096)
642
+ parser.add_argument("--flow_base_shift", type=float, default=0.5)
643
+ parser.add_argument("--flow_max_shift", type=float, default=1.15)
644
+ parser.add_argument("--flow_shift", type=float, default=1.0)
645
  parser.add_argument(
646
  "--flow_weighting_scheme",
647
  type=str,
648
  default="none",
649
  choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
650
  )
651
+ parser.add_argument("--flow_logit_mean", type=float, default=0.0)
652
+ parser.add_argument("--flow_logit_std", type=float, default=1.0)
653
+ parser.add_argument("--flow_mode_scale", type=float, default=1.29)
654
 
655
 
656
  def _add_training_arguments(parser: argparse.ArgumentParser) -> None:
657
  parser.add_argument(
658
+ "--training_type", type=str, choices=[x.value for x in TrainingType.__members__.values()], required=True
659
  )
660
+ parser.add_argument("--seed", type=int, default=None)
661
+ parser.add_argument("--batch_size", type=int, default=1)
662
+ parser.add_argument("--train_steps", type=int, default=1000)
663
+ parser.add_argument("--max_data_samples", type=int, default=2**64)
664
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
665
+ parser.add_argument("--gradient_checkpointing", action="store_true")
666
+ parser.add_argument("--checkpointing_steps", type=int, default=500)
667
+ parser.add_argument("--checkpointing_limit", type=int, default=None)
668
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None)
669
+ parser.add_argument("--enable_slicing", action="store_true")
670
+ parser.add_argument("--enable_tiling", action="store_true")
671
 
672
 
673
  def _add_optimizer_arguments(parser: argparse.ArgumentParser) -> None:
674
+ parser.add_argument("--lr", type=float, default=1e-4)
675
+ parser.add_argument("--lr_scheduler", type=str, default="constant")
676
+ parser.add_argument("--lr_warmup_steps", type=int, default=500)
677
+ parser.add_argument("--lr_num_cycles", type=int, default=1)
678
+ parser.add_argument("--lr_power", type=float, default=1.0)
679
  parser.add_argument(
680
  "--optimizer",
681
  type=lambda s: s.lower(),
682
  default="adam",
683
+ choices=["adam", "adamw", "adam-bnb", "adamw-bnb", "adam-bnb-8bit", "adamw-bnb-8bit"],
 
684
  )
685
+ parser.add_argument("--beta1", type=float, default=0.9)
686
+ parser.add_argument("--beta2", type=float, default=0.95)
687
+ parser.add_argument("--beta3", type=float, default=None)
688
+ parser.add_argument("--weight_decay", type=float, default=1e-04)
689
+ parser.add_argument("--epsilon", type=float, default=1e-8)
690
+ parser.add_argument("--max_grad_norm", default=1.0, type=float)
691
 
692
 
693
  def _add_validation_arguments(parser: argparse.ArgumentParser) -> None:
694
+ parser.add_argument("--validation_dataset_file", type=str, default=None)
695
+ parser.add_argument("--validation_steps", type=int, default=500)
696
+ parser.add_argument("--enable_model_cpu_offload", action="store_true")
697
 
698
 
699
  def _add_miscellaneous_arguments(parser: argparse.ArgumentParser) -> None:
700
+ parser.add_argument("--tracker_name", type=str, default="finetrainers")
701
+ parser.add_argument("--push_to_hub", action="store_true")
702
+ parser.add_argument("--hub_token", type=str, default=None)
703
+ parser.add_argument("--hub_model_id", type=str, default=None)
704
+ parser.add_argument("--output_dir", type=str, default="finetrainers-training")
705
+ parser.add_argument("--logging_dir", type=str, default="logs")
706
+ parser.add_argument("--logging_steps", type=int, default=1)
707
+ parser.add_argument("--allow_tf32", action="store_true")
708
+ parser.add_argument("--init_timeout", type=int, default=300)
709
+ parser.add_argument("--nccl_timeout", type=int, default=600)
710
+ parser.add_argument("--report_to", type=str, default="none", choices=["none", "wandb"])
711
+ parser.add_argument("--verbose", type=int, default=0, choices=[0, 1, 2, 3])
712
 
713
 
714
  def _add_helper_arguments(parser: argparse.ArgumentParser) -> None:
715
+ parser.add_argument("--list_models", action="store_true")
 
 
 
 
716
 
717
 
718
  _DTYPE_MAP = {
 
724
  }
725
 
726
 
727
+ def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs:
728
+ result_args = BaseArgs()
729
+
730
+ # Parallel arguments
731
+ result_args.parallel_backend = args.parallel_backend
732
+ result_args.pp_degree = args.pp_degree
733
+ result_args.dp_degree = args.dp_degree
734
+ result_args.dp_shards = args.dp_shards
735
+ result_args.cp_degree = args.cp_degree
736
+ result_args.tp_degree = args.tp_degree
737
 
738
  # Model arguments
739
  result_args.model_name = args.model_name
 
741
  result_args.revision = args.revision
742
  result_args.variant = args.variant
743
  result_args.cache_dir = args.cache_dir
744
+ result_args.tokenizer_id = args.tokenizer_id
745
+ result_args.tokenizer_2_id = args.tokenizer_2_id
746
+ result_args.tokenizer_3_id = args.tokenizer_3_id
747
+ result_args.text_encoder_id = args.text_encoder_id
748
+ result_args.text_encoder_2_id = args.text_encoder_2_id
749
+ result_args.text_encoder_3_id = args.text_encoder_3_id
750
+ result_args.transformer_id = args.transformer_id
751
+ result_args.vae_id = args.vae_id
752
  result_args.text_encoder_dtype = _DTYPE_MAP[args.text_encoder_dtype]
753
  result_args.text_encoder_2_dtype = _DTYPE_MAP[args.text_encoder_2_dtype]
754
  result_args.text_encoder_3_dtype = _DTYPE_MAP[args.text_encoder_3_dtype]
 
759
  result_args.layerwise_upcasting_skip_modules_pattern = args.layerwise_upcasting_skip_modules_pattern
760
 
761
  # Dataset arguments
762
+ result_args.dataset_config = args.dataset_config
763
+ result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size
764
+ result_args.precomputation_items = args.precomputation_items
765
+ result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed")
766
+ result_args.precomputation_once = args.precomputation_once
767
 
768
  # Dataloader arguments
769
  result_args.dataloader_num_workers = args.dataloader_num_workers
 
785
  result_args.training_type = args.training_type
786
  result_args.seed = args.seed
787
  result_args.batch_size = args.batch_size
 
788
  result_args.train_steps = args.train_steps
789
+ result_args.max_data_samples = args.max_data_samples
 
 
790
  result_args.gradient_accumulation_steps = args.gradient_accumulation_steps
791
  result_args.gradient_checkpointing = args.gradient_checkpointing
792
  result_args.checkpointing_steps = args.checkpointing_steps
 
797
 
798
  # Optimizer arguments
799
  result_args.optimizer = args.optimizer or "adamw"
 
800
  result_args.lr = args.lr or 1e-4
 
801
  result_args.lr_scheduler = args.lr_scheduler
802
  result_args.lr_warmup_steps = args.lr_warmup_steps
803
  result_args.lr_num_cycles = args.lr_num_cycles
 
810
  result_args.max_grad_norm = args.max_grad_norm
811
 
812
  # Validation arguments
813
+ result_args.validation_dataset_file = args.validation_dataset_file
814
+ result_args.validation_steps = args.validation_steps
815
  result_args.enable_model_cpu_offload = args.enable_model_cpu_offload
 
816
 
817
  # Miscellaneous arguments
818
  result_args.tracker_name = args.tracker_name
 
821
  result_args.hub_model_id = args.hub_model_id
822
  result_args.output_dir = args.output_dir
823
  result_args.logging_dir = args.logging_dir
824
+ result_args.logging_steps = args.logging_steps
825
  result_args.allow_tf32 = args.allow_tf32
826
+ result_args.init_timeout = args.init_timeout
827
  result_args.nccl_timeout = args.nccl_timeout
828
  result_args.report_to = args.report_to
829
+ result_args.verbose = args.verbose
830
 
831
  return result_args
832
 
833
 
834
+ def _validate_model_args(args: BaseArgs):
835
  if args.training_type == "full-finetune":
836
  assert (
837
  "transformer" not in args.layerwise_upcasting_modules
838
  ), "Layerwise upcasting is not supported for full-finetune training"
839
 
840
 
841
+ def _validate_dataset_args(args: BaseArgs):
842
+ dataset_config = pathlib.Path(args.dataset_config)
843
+ if not dataset_config.exists():
844
+ raise ValueError(f"Dataset config file {args.dataset_config} does not exist.")
845
+ if args.dataset_shuffle_buffer_size < 1:
846
+ raise ValueError("Dataset shuffle buffer size must be greater than 0.")
847
+ if args.precomputation_items < 1:
848
+ raise ValueError("Precomputation items must be greater than 0.")
849
+
850
+
851
+ def _validate_validation_args(args: BaseArgs):
852
+ if args.dp_shards > 1 and args.enable_model_cpu_offload:
853
+ raise ValueError("Model CPU offload is not supported with FSDP at the moment.")
854
 
855
 
856
  def _display_helper_messages(args: argparse.Namespace):
finetrainers/config.py ADDED
@@ -0,0 +1,52 @@
1
+ from enum import Enum
2
+ from typing import Type
3
+
4
+ from .models import ModelSpecification
5
+ from .models.cogvideox import CogVideoXModelSpecification
6
+ from .models.hunyuan_video import HunyuanVideoModelSpecification
7
+ from .models.ltx_video import LTXVideoModelSpecification
8
+ from .models.wan import WanModelSpecification
9
+
10
+
11
+ class ModelType(str, Enum):
12
+ COGVIDEOX = "cogvideox"
13
+ HUNYUAN_VIDEO = "hunyuan_video"
14
+ LTX_VIDEO = "ltx_video"
15
+ WAN = "wan"
16
+
17
+
18
+ class TrainingType(str, Enum):
19
+ LORA = "lora"
20
+ FULL_FINETUNE = "full-finetune"
21
+
22
+
23
+ SUPPORTED_MODEL_CONFIGS = {
24
+ ModelType.HUNYUAN_VIDEO: {
25
+ TrainingType.LORA: HunyuanVideoModelSpecification,
26
+ TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
27
+ },
28
+ ModelType.LTX_VIDEO: {
29
+ TrainingType.LORA: LTXVideoModelSpecification,
30
+ TrainingType.FULL_FINETUNE: LTXVideoModelSpecification,
31
+ },
32
+ ModelType.COGVIDEOX: {
33
+ TrainingType.LORA: CogVideoXModelSpecification,
34
+ TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
35
+ },
36
+ ModelType.WAN: {
37
+ TrainingType.LORA: WanModelSpecification,
38
+ TrainingType.FULL_FINETUNE: WanModelSpecification,
39
+ },
40
+ }
41
+
42
+
43
+ def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]:
44
+ if model_name not in SUPPORTED_MODEL_CONFIGS:
45
+ raise ValueError(
46
+ f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}"
47
+ )
48
+ if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]:
49
+ raise ValueError(
50
+ f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}"
51
+ )
52
+ return SUPPORTED_MODEL_CONFIGS[model_name][training_type]
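For reference, a minimal sketch (illustrative only) of resolving a specification class through the registry defined above. The plain strings produced by argparse (e.g. "ltx_video", "lora") also work for the lookup, since the enums subclass `str`.

```python
# Illustrative sketch only: looking up the specification class for a model/training combination.
from finetrainers.config import ModelType, TrainingType, _get_model_specifiction_cls

spec_cls = _get_model_specifiction_cls(ModelType.LTX_VIDEO, TrainingType.LORA)
print(spec_cls.__name__)  # LTXVideoModelSpecification

# Unsupported combinations raise a ValueError listing the supported keys.
```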
finetrainers/constants.py CHANGED
@@ -78,3 +78,6 @@ COMMON_LLM_START_PHRASES = (
78
  for continuation in _COMMON_CONTINUATION_WORDS
79
  ),
80
  )
81
+
82
+ SUPPORTED_IMAGE_FILE_EXTENSIONS = ("jpg", "jpeg", "png")
83
+ SUPPORTED_VIDEO_FILE_EXTENSIONS = ("mp4", "mov")
finetrainers/data/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from ._artifact import ImageArtifact, VideoArtifact
2
+ from .dataloader import DPDataLoader
3
+ from .dataset import (
4
+ ImageCaptionFilePairDataset,
5
+ ImageFileCaptionFileListDataset,
6
+ ImageFolderDataset,
7
+ ImageWebDataset,
8
+ ValidationDataset,
9
+ VideoCaptionFilePairDataset,
10
+ VideoFileCaptionFileListDataset,
11
+ VideoFolderDataset,
12
+ VideoWebDataset,
13
+ combine_datasets,
14
+ initialize_dataset,
15
+ wrap_iterable_dataset_for_preprocessing,
16
+ )
17
+ from .precomputation import DistributedDataPreprocessor, PreprocessedDataIterable
18
+ from .sampler import ResolutionSampler
19
+ from .utils import find_files
finetrainers/data/_artifact.py ADDED
@@ -0,0 +1,29 @@
1
+ # ===== THIS FILE ONLY EXISTS FOR THE TIME BEING SINCE I DID NOT KNOW WHERE TO PUT IT =====
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, List
5
+
6
+ from PIL.Image import Image
7
+
8
+
9
+ @dataclass
10
+ class Artifact:
11
+ type: str
12
+ value: Any
13
+ file_extension: str
14
+
15
+
16
+ @dataclass
17
+ class ImageArtifact(Artifact):
18
+ value: Image
19
+
20
+ def __init__(self, value: Image):
21
+ super().__init__(type="image", value=value, file_extension="png")
22
+
23
+
24
+ @dataclass
25
+ class VideoArtifact(Artifact):
26
+ value: List[Image]
27
+
28
+ def __init__(self, value: List[Image]):
29
+ super().__init__(type="video", value=value, file_extension="mp4")
finetrainers/data/dataloader.py ADDED
@@ -0,0 +1,40 @@
1
+ import pickle
2
+ from typing import Any, Dict
3
+
4
+ import torch.distributed.checkpoint.stateful
5
+ import torchdata.stateful_dataloader
6
+
7
+ from ..logging import get_logger
8
+
9
+
10
+ logger = get_logger()
11
+
12
+
13
+ class DPDataLoader(torchdata.stateful_dataloader.StatefulDataLoader, torch.distributed.checkpoint.stateful.Stateful):
14
+ def __init__(
15
+ self,
16
+ rank: int,
17
+ dataset: torch.utils.data.IterableDataset,
18
+ batch_size: int = 1,
19
+ num_workers: int = 0,
20
+ collate_fn=None,
21
+ ) -> None:
22
+ super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)
23
+
24
+ self._dp_rank = rank
25
+ self._rank_id = f"dp_rank_{rank}"
26
+
27
+ def state_dict(self) -> Dict[str, Any]:
28
+ # Store state only for dp rank to avoid replicating the same state across other dimensions
29
+ return {self._rank_id: pickle.dumps(super().state_dict())}
30
+
31
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
32
+ # State being empty is valid
33
+ if not state_dict:
34
+ return
35
+
36
+ if self._rank_id not in state_dict:
37
+ logger.warning(f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}")
38
+ return
39
+
40
+ super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
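A brief sketch (illustrative only) of the checkpoint round-trip this dataloader enables, assuming `dataset` is one of the stateful iterable datasets added in `finetrainers/data/dataset.py` below.

```python
# Illustrative sketch only: saving and restoring dataloader progress for a single dp rank.
# `dataset` is assumed to be a stateful IterableDataset (e.g. VideoFolderDataset) from dataset.py.
from finetrainers.data import DPDataLoader

dataloader = DPDataLoader(rank=0, dataset=dataset, batch_size=1, num_workers=0)

it = iter(dataloader)
first_batch = next(it)           # consume one batch
state = dataloader.state_dict()  # {"dp_rank_0": <pickled StatefulDataLoader state>}

resumed = DPDataLoader(rank=0, dataset=dataset, batch_size=1, num_workers=0)
resumed.load_state_dict(state)   # iteration continues from the stored position
```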
finetrainers/data/dataset.py ADDED
@@ -0,0 +1,844 @@
1
+ import pathlib
2
+ import random
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import datasets
6
+ import datasets.data_files
7
+ import datasets.distributed
8
+ import datasets.exceptions
9
+ import huggingface_hub
10
+ import huggingface_hub.errors
11
+ import numpy as np
12
+ import PIL.Image
13
+ import torch
14
+ import torch.distributed.checkpoint.stateful
15
+ from diffusers.utils import load_image, load_video
16
+ from huggingface_hub import list_repo_files, repo_exists, snapshot_download
17
+ from tqdm.auto import tqdm
18
+
19
+ from .. import constants
20
+ from .. import functional as FF
21
+ from ..logging import get_logger
22
+ from . import utils
23
+
24
+
25
+ import decord # isort:skip
26
+
27
+ decord.bridge.set_bridge("torch")
28
+
29
+ logger = get_logger()
30
+
31
+
32
+ MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
33
+ COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
34
+ COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
35
+ COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
36
+
37
+
38
+ class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
39
+ def __init__(self, root: str, infinite: bool = False) -> None:
40
+ super().__init__()
41
+
42
+ self.root = pathlib.Path(root)
43
+ self.infinite = infinite
44
+
45
+ data = []
46
+ caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
47
+ for caption_file in caption_files:
48
+ data_file = self._find_data_file(caption_file)
49
+ if data_file:
50
+ data.append(
51
+ {
52
+ "caption": (self.root / caption_file).as_posix(),
53
+ "image": (self.root / data_file).as_posix(),
54
+ }
55
+ )
56
+
57
+ data = datasets.Dataset.from_list(data)
58
+ data = data.cast_column("image", datasets.Image(mode="RGB"))
59
+
60
+ self._data = data.to_iterable_dataset()
61
+ self._sample_index = 0
62
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
63
+
64
+ def _get_data_iter(self):
65
+ if self._sample_index == 0:
66
+ return iter(self._data)
67
+ return iter(self._data.skip(self._sample_index))
68
+
69
+ def __iter__(self):
70
+ while True:
71
+ for sample in self._get_data_iter():
72
+ self._sample_index += 1
73
+ sample["caption"] = _read_caption_from_file(sample["caption"])
74
+ sample["image"] = _preprocess_image(sample["image"])
75
+ yield sample
76
+
77
+ if not self.infinite:
78
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
79
+ break
80
+ else:
81
+ self._sample_index = 0
82
+
83
+ def load_state_dict(self, state_dict):
84
+ self._sample_index = state_dict["sample_index"]
85
+
86
+ def state_dict(self):
87
+ return {"sample_index": self._sample_index}
88
+
89
+ def _find_data_file(self, caption_file: str) -> str:
90
+ caption_file = pathlib.Path(caption_file)
91
+ data_file = None
92
+ found_data = 0
93
+
94
+ for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
95
+ image_filename = caption_file.with_suffix(f".{extension}")
96
+ if image_filename.exists():
97
+ found_data += 1
98
+ data_file = image_filename
99
+
100
+ if found_data == 0:
101
+ return False
102
+ elif found_data > 1:
103
+ raise ValueError(
104
+ f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
105
+ f"file per caption file. The following extensions are supported:\n"
106
+ f" - Images: {constants.SUPPORTED_IMAGE_FILE_EXTENSIONS}\n"
107
+ )
108
+
109
+ return data_file.as_posix()
110
+
111
+
112
+ class VideoCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
113
+ def __init__(self, root: str, infinite: bool = False) -> None:
114
+ super().__init__()
115
+
116
+ self.root = pathlib.Path(root)
117
+ self.infinite = infinite
118
+
119
+ data = []
120
+ caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
121
+ for caption_file in caption_files:
122
+ data_file = self._find_data_file(caption_file)
123
+ if data_file:
124
+ data.append(
125
+ {
126
+ "caption": (self.root / caption_file).as_posix(),
127
+ "video": (self.root / data_file).as_posix(),
128
+ }
129
+ )
130
+
131
+ data = datasets.Dataset.from_list(data)
132
+ data = data.cast_column("video", datasets.Video())
133
+
134
+ self._data = data.to_iterable_dataset()
135
+ self._sample_index = 0
136
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
137
+
138
+ def _get_data_iter(self):
139
+ if self._sample_index == 0:
140
+ return iter(self._data)
141
+ return iter(self._data.skip(self._sample_index))
142
+
143
+ def __iter__(self):
144
+ while True:
145
+ for sample in self._get_data_iter():
146
+ self._sample_index += 1
147
+ sample["caption"] = _read_caption_from_file(sample["caption"])
148
+ sample["video"] = _preprocess_video(sample["video"])
149
+ yield sample
150
+
151
+ if not self.infinite:
152
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
153
+ break
154
+ else:
155
+ self._sample_index = 0
156
+
157
+ def load_state_dict(self, state_dict):
158
+ self._sample_index = state_dict["sample_index"]
159
+
160
+ def state_dict(self):
161
+ return {"sample_index": self._sample_index}
162
+
163
+ def _find_data_file(self, caption_file: str) -> str:
164
+ caption_file = pathlib.Path(caption_file)
165
+ data_file = None
166
+ found_data = 0
167
+
168
+ for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
169
+ video_filename = caption_file.with_suffix(f".{extension}")
170
+ if video_filename.exists():
171
+ found_data += 1
172
+ data_file = video_filename
173
+
174
+ if found_data == 0:
175
+ return False
176
+ elif found_data > 1:
177
+ raise ValueError(
178
+ f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
179
+ f"file per caption file. The following extensions are supported:\n"
180
+ f" - Videos: {constants.SUPPORTED_VIDEO_FILE_EXTENSIONS}\n"
181
+ )
182
+
183
+ return data_file.as_posix()
184
+
185
+
186
+ class ImageFileCaptionFileListDataset(
187
+ torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
188
+ ):
189
+ def __init__(self, root: str, infinite: bool = False) -> None:
190
+ super().__init__()
191
+
192
+ VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
193
+ VALID_IMAGE_FILES = ["image.txt", "images.txt"]
194
+
195
+ self.root = pathlib.Path(root)
196
+ self.infinite = infinite
197
+
198
+ data = []
199
+ existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
200
+ existing_image_files = [file for file in VALID_IMAGE_FILES if (self.root / file).exists()]
201
+
202
+ if len(existing_caption_files) == 0:
203
+ raise FileNotFoundError(
204
+ f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
205
+ )
206
+ if len(existing_image_files) == 0:
207
+ raise FileNotFoundError(
208
+ f"No image file found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
209
+ )
210
+ if len(existing_caption_files) > 1:
211
+ raise ValueError(
212
+ f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
213
+ )
214
+ if len(existing_image_files) > 1:
215
+ raise ValueError(
216
+ f"Multiple image files found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
217
+ )
218
+
219
+ caption_file = existing_caption_files[0]
220
+ image_file = existing_image_files[0]
221
+
222
+ with open((self.root / caption_file).as_posix(), "r") as f:
223
+ captions = f.read().splitlines()
224
+ with open((self.root / image_file).as_posix(), "r") as f:
225
+ images = f.read().splitlines()
226
+ images = [(self.root / image).as_posix() for image in images]
227
+
228
+ if len(captions) != len(images):
229
+ raise ValueError(f"Number of captions ({len(captions)}) must match number of images ({len(images)})")
230
+
231
+ for caption, image in zip(captions, images):
232
+ data.append({"caption": caption, "image": image})
233
+
234
+ data = datasets.Dataset.from_list(data)
235
+ data = data.cast_column("image", datasets.Image(mode="RGB"))
236
+
237
+ self._data = data.to_iterable_dataset()
238
+ self._sample_index = 0
239
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
240
+
241
+ def _get_data_iter(self):
242
+ if self._sample_index == 0:
243
+ return iter(self._data)
244
+ return iter(self._data.skip(self._sample_index))
245
+
246
+ def __iter__(self):
247
+ while True:
248
+ for sample in self._get_data_iter():
249
+ self._sample_index += 1
250
+ sample["image"] = _preprocess_image(sample["image"])
251
+ yield sample
252
+
253
+ if not self.infinite:
254
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
255
+ break
256
+ else:
257
+ self._sample_index = 0
258
+
259
+ def load_state_dict(self, state_dict):
260
+ self._sample_index = state_dict["sample_index"]
261
+
262
+ def state_dict(self):
263
+ return {"sample_index": self._sample_index}
264
+
265
+
266
+ class VideoFileCaptionFileListDataset(
267
+ torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
268
+ ):
269
+ def __init__(self, root: str, infinite: bool = False) -> None:
270
+ super().__init__()
271
+
272
+ VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
273
+ VALID_VIDEO_FILES = ["video.txt", "videos.txt"]
274
+
275
+ self.root = pathlib.Path(root)
276
+ self.infinite = infinite
277
+
278
+ data = []
279
+ existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
280
+ existing_video_files = [file for file in VALID_VIDEO_FILES if (self.root / file).exists()]
281
+
282
+ if len(existing_caption_files) == 0:
283
+ raise FileNotFoundError(
284
+ f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
285
+ )
286
+ if len(existing_video_files) == 0:
287
+ raise FileNotFoundError(
288
+ f"No video file found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
289
+ )
290
+ if len(existing_caption_files) > 1:
291
+ raise ValueError(
292
+ f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
293
+ )
294
+ if len(existing_video_files) > 1:
295
+ raise ValueError(
296
+ f"Multiple video files found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
297
+ )
298
+
299
+ caption_file = existing_caption_files[0]
300
+ video_file = existing_video_files[0]
301
+
302
+ with open((self.root / caption_file).as_posix(), "r") as f:
303
+ captions = f.read().splitlines()
304
+ with open((self.root / video_file).as_posix(), "r") as f:
305
+ videos = f.read().splitlines()
306
+ videos = [(self.root / video).as_posix() for video in videos]
307
+
308
+ if len(captions) != len(videos):
309
+ raise ValueError(f"Number of captions ({len(captions)}) must match number of videos ({len(videos)})")
310
+
311
+ for caption, video in zip(captions, videos):
312
+ data.append({"caption": caption, "video": video})
313
+
314
+ data = datasets.Dataset.from_list(data)
315
+ data = data.cast_column("video", datasets.Video())
316
+
317
+ self._data = data.to_iterable_dataset()
318
+ self._sample_index = 0
319
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
320
+
321
+ def _get_data_iter(self):
322
+ if self._sample_index == 0:
323
+ return iter(self._data)
324
+ return iter(self._data.skip(self._sample_index))
325
+
326
+ def __iter__(self):
327
+ while True:
328
+ for sample in self._get_data_iter():
329
+ self._sample_index += 1
330
+ sample["video"] = _preprocess_video(sample["video"])
331
+ yield sample
332
+
333
+ if not self.infinite:
334
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
335
+ break
336
+ else:
337
+ self._sample_index = 0
338
+
339
+ def load_state_dict(self, state_dict):
340
+ self._sample_index = state_dict["sample_index"]
341
+
342
+ def state_dict(self):
343
+ return {"sample_index": self._sample_index}
344
+
345
+
346
+ class ImageFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
347
+ def __init__(self, root: str, infinite: bool = False) -> None:
348
+ super().__init__()
349
+
350
+ self.root = pathlib.Path(root)
351
+ self.infinite = infinite
352
+
353
+ data = datasets.load_dataset("imagefolder", data_dir=self.root.as_posix(), split="train")
354
+
355
+ self._data = data.to_iterable_dataset()
356
+ self._sample_index = 0
357
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
358
+
359
+ def _get_data_iter(self):
360
+ if self._sample_index == 0:
361
+ return iter(self._data)
362
+ return iter(self._data.skip(self._sample_index))
363
+
364
+ def __iter__(self):
365
+ while True:
366
+ for sample in self._get_data_iter():
367
+ self._sample_index += 1
368
+ sample["image"] = _preprocess_image(sample["image"])
369
+ yield sample
370
+
371
+ if not self.infinite:
372
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
373
+ break
374
+ else:
375
+ self._sample_index = 0
376
+
377
+ def load_state_dict(self, state_dict):
378
+ self._sample_index = state_dict["sample_index"]
379
+
380
+ def state_dict(self):
381
+ return {"sample_index": self._sample_index}
382
+
383
+
384
+ class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
385
+ def __init__(self, root: str, infinite: bool = False) -> None:
386
+ super().__init__()
387
+
388
+ self.root = pathlib.Path(root)
389
+ self.infinite = infinite
390
+
391
+ data = datasets.load_dataset("videofolder", data_dir=self.root.as_posix(), split="train")
392
+
393
+ self._data = data.to_iterable_dataset()
394
+ self._sample_index = 0
395
+ self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
396
+
397
+ def _get_data_iter(self):
398
+ if self._sample_index == 0:
399
+ return iter(self._data)
400
+ return iter(self._data.skip(self._sample_index))
401
+
402
+ def __iter__(self):
403
+ while True:
404
+ for sample in self._get_data_iter():
405
+ self._sample_index += 1
406
+ sample["video"] = _preprocess_video(sample["video"])
407
+ yield sample
408
+
409
+ if not self.infinite:
410
+ logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
411
+ break
412
+ else:
413
+ self._sample_index = 0
414
+
415
+ def load_state_dict(self, state_dict):
416
+ self._sample_index = state_dict["sample_index"]
417
+
418
+ def state_dict(self):
419
+ return {"sample_index": self._sample_index}
420
+
421
+
422
+ class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
423
+ def __init__(self, dataset_name: str, infinite: bool = False) -> None:
424
+ super().__init__()
425
+
426
+ self.dataset_name = dataset_name
427
+ self.infinite = infinite
428
+
429
+ data = datasets.load_dataset(dataset_name, split="train", streaming=True)
430
+ data = data.rename_column("txt", "caption")
431
+ for column_name in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
432
+ if column_name in data.column_names:
433
+ data = data.cast_column(column_name, datasets.Image(mode="RGB"))
434
+ data = data.rename_column(column_name, "image")
435
+
436
+ self._data = data
437
+ self._sample_index = 0
438
+ self._precomputable_once = False
439
+
440
+ def _get_data_iter(self):
441
+ if self._sample_index == 0:
442
+ return iter(self._data)
443
+ return iter(self._data.skip(self._sample_index))
444
+
445
+ def __iter__(self):
446
+ while True:
447
+ for sample in self._get_data_iter():
448
+ self._sample_index += 1
449
+ yield sample
450
+
451
+ if not self.infinite:
452
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
453
+ break
454
+ else:
455
+ # Reset offset for the next iteration
456
+ self._sample_index = 0
457
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
458
+
459
+ def load_state_dict(self, state_dict):
460
+ self._sample_index = state_dict["sample_index"]
461
+
462
+ def state_dict(self):
463
+ return {"sample_index": self._sample_index}
464
+
465
+
466
+ class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
467
+ def __init__(self, dataset_name: str, infinite: bool = False) -> None:
468
+ super().__init__()
469
+
470
+ self.dataset_name = dataset_name
471
+ self.infinite = infinite
472
+
473
+ data = datasets.load_dataset(dataset_name, split="train", streaming=True)
474
+ data = data.rename_column("txt", "caption")
475
+ for column_name in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
476
+ if column_name in data.column_names:
477
+ data = data.cast_column(column_name, datasets.Video())
478
+ data = data.rename_column(column_name, "video")
479
+
480
+ self._data = data
481
+ self._sample_index = 0
482
+ self._precomputable_once = False
483
+
484
+ def _get_data_iter(self):
485
+ if self._sample_index == 0:
486
+ return iter(self._data)
487
+ return iter(self._data.skip(self._sample_index))
488
+
489
+ def __iter__(self):
490
+ while True:
491
+ for sample in self._get_data_iter():
492
+ self._sample_index += 1
493
+ yield sample
494
+
495
+ if not self.infinite:
496
+ logger.warning(f"Dataset {self.dataset_name} has run out of data")
497
+ break
498
+ else:
499
+ # Reset offset for the next iteration
500
+ self._sample_index = 0
501
+ logger.warning(f"Dataset {self.dataset_name} is being re-looped")
502
+
503
+ def load_state_dict(self, state_dict):
504
+ self._sample_index = state_dict["sample_index"]
505
+
506
+ def state_dict(self):
507
+ return {"sample_index": self._sample_index}
508
+
509
+
510
+ class ValidationDataset(torch.utils.data.IterableDataset):
511
+ def __init__(self, filename: str):
512
+ super().__init__()
513
+
514
+ self.filename = pathlib.Path(filename)
515
+
516
+ if not self.filename.exists():
517
+ raise FileNotFoundError(f"File {self.filename.as_posix()} does not exist")
518
+
519
+ if self.filename.suffix == ".csv":
520
+ data = datasets.load_dataset("csv", data_files=self.filename.as_posix(), split="train")
521
+ elif self.filename.suffix == ".json":
522
+ data = datasets.load_dataset("json", data_files=self.filename.as_posix(), split="train", field="data")
523
+ elif self.filename.suffix == ".parquet":
524
+ data = datasets.load_dataset("parquet", data_files=self.filename.as_posix(), split="train")
525
+ elif self.filename.suffix == ".arrow":
526
+ data = datasets.load_dataset("arrow", data_files=self.filename.as_posix(), split="train")
527
+ else:
528
+ _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"]
529
+ raise ValueError(
530
+ f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}"
531
+ )
532
+
533
+ self._data = data.to_iterable_dataset()
534
+
535
+ def __iter__(self):
536
+ for sample in self._data:
537
+ # For consistency reasons, we mandate that "caption" is always present in the validation dataset.
538
+ # However, since the model specifications use "prompt", we create an alias here.
539
+ sample["prompt"] = sample["caption"]
540
+
541
+ # Load image or video if the path is provided
542
+ # TODO(aryan): need to handle custom columns here for control conditions
543
+ sample["image"] = None
544
+ sample["video"] = None
545
+
546
+ if sample.get("image_path", None) is not None:
547
+ image_path = pathlib.Path(sample["image_path"])
548
+ if not image_path.is_file():
549
+ logger.warning(f"Image file {image_path.as_posix()} does not exist.")
550
+ else:
551
+ sample["image"] = load_image(sample["image_path"])
552
+
553
+ if sample.get("video_path", None) is not None:
554
+ video_path = pathlib.Path(sample["video_path"])
555
+ if not video_path.is_file():
556
+ logger.warning(f"Video file {video_path.as_posix()} does not exist.")
557
+ else:
558
+ sample["video"] = load_video(sample["video_path"])
559
+
560
+ sample = {k: v for k, v in sample.items() if v is not None}
561
+ yield sample
562
+
563
+
564
+ class IterableDatasetPreprocessingWrapper(
565
+ torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
566
+ ):
567
+ def __init__(
568
+ self,
569
+ dataset: torch.utils.data.IterableDataset,
570
+ dataset_type: str,
571
+ id_token: Optional[str] = None,
572
+ image_resolution_buckets: List[Tuple[int, int]] = None,
573
+ video_resolution_buckets: List[Tuple[int, int, int]] = None,
574
+ reshape_mode: str = "bicubic",
575
+ remove_common_llm_caption_prefixes: bool = False,
576
+ **kwargs,
577
+ ):
578
+ super().__init__()
579
+
580
+ self.dataset = dataset
581
+ self.dataset_type = dataset_type
582
+ self.id_token = id_token
583
+ self.image_resolution_buckets = image_resolution_buckets
584
+ self.video_resolution_buckets = video_resolution_buckets
585
+ self.reshape_mode = reshape_mode
586
+ self.remove_common_llm_caption_prefixes = remove_common_llm_caption_prefixes
587
+
588
+ logger.info(
589
+ f"Initializing IterableDatasetPreprocessingWrapper for the dataset with the following configuration:\n"
590
+ f" - Dataset Type: {dataset_type}\n"
591
+ f" - ID Token: {id_token}\n"
592
+ f" - Image Resolution Buckets: {image_resolution_buckets}\n"
593
+ f" - Video Resolution Buckets: {video_resolution_buckets}\n"
594
+ f" - Reshape Mode: {reshape_mode}\n"
595
+ f" - Remove Common LLM Caption Prefixes: {remove_common_llm_caption_prefixes}\n"
596
+ )
597
+
598
+ def __iter__(self):
599
+ logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset")
600
+ for sample in iter(self.dataset):
601
+ if self.dataset_type == "image":
602
+ if self.image_resolution_buckets:
603
+ sample["image"] = FF.resize_to_nearest_bucket_image(
604
+ sample["image"], self.image_resolution_buckets, self.reshape_mode
605
+ )
606
+ elif self.dataset_type == "video":
607
+ if self.video_resolution_buckets:
608
+ sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
609
+ sample["video"], self.video_resolution_buckets, self.reshape_mode
610
+ )
611
+ if _first_frame_only:
612
+ msg = (
613
+ "The number of frames in the video is less than the minimum bucket size "
614
+ "specified. The first frame is being used as a single frame video. This "
615
+ "message is logged at the first occurence and for every 128th occurence "
616
+ "after that."
617
+ )
618
+ logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE", msg, frequency=128)
619
+ sample["video"] = sample["video"][0]
620
+
621
+ if self.remove_common_llm_caption_prefixes:
622
+ sample["caption"] = FF.remove_prefix(sample["caption"], constants.COMMON_LLM_START_PHRASES)
623
+
624
+ if self.id_token is not None:
625
+ sample["caption"] = f"{self.id_token} {sample['caption']}"
626
+
627
+ yield sample
628
+
629
+ def load_state_dict(self, state_dict):
630
+ self.dataset.load_state_dict(state_dict["dataset"])
631
+
632
+ def state_dict(self):
633
+ return {"dataset": self.dataset.state_dict()}
634
+
635
+
636
+ class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
637
+ def __init__(self, datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False):
638
+ super().__init__()
639
+
640
+ self.datasets = datasets
641
+ self.buffer_size = buffer_size
642
+ self.shuffle = shuffle
643
+
644
+ logger.info(
645
+ f"Initializing IterableCombinedDataset with the following configuration:\n"
646
+ f" - Number of Datasets: {len(datasets)}\n"
647
+ f" - Buffer Size: {buffer_size}\n"
648
+ f" - Shuffle: {shuffle}\n"
649
+ )
650
+
651
+ def __iter__(self):
652
+ logger.info(f"Starting IterableCombinedDataset with {len(self.datasets)} datasets")
653
+ iterators = [iter(dataset) for dataset in self.datasets]
654
+ buffer = []
655
+ per_iter = max(1, self.buffer_size // len(iterators))
656
+
657
+ for index, it in enumerate(iterators):
658
+ for _ in tqdm(range(per_iter), desc=f"Filling buffer from data iterator {index}"):
659
+ try:
660
+ buffer.append((it, next(it)))
661
+ except StopIteration:
662
+ continue
663
+
664
+ while len(buffer) > 0:
665
+ idx = 0
666
+ if self.shuffle:
667
+ idx = random.randint(0, len(buffer) - 1)
668
+ current_it, sample = buffer.pop(idx)
669
+ yield sample
670
+ try:
671
+ buffer.append((current_it, next(current_it)))
672
+ except StopIteration:
673
+ pass
674
+
675
+ def load_state_dict(self, state_dict):
676
+ for dataset, dataset_state_dict in zip(self.datasets, state_dict["datasets"]):
677
+ dataset.load_state_dict(dataset_state_dict)
678
+
679
+ def state_dict(self):
680
+ return {"datasets": [dataset.state_dict() for dataset in self.datasets]}
681
+
682
+
683
+ # TODO(aryan): maybe write a test for this
684
+ def initialize_dataset(
685
+ dataset_name_or_root: str, dataset_type: str = "video", streaming: bool = True, infinite: bool = False
686
+ ) -> torch.utils.data.IterableDataset:
687
+ assert dataset_type in ["image", "video"]
688
+
689
+ try:
690
+ does_repo_exist_on_hub = repo_exists(dataset_name_or_root, repo_type="dataset")
691
+ except huggingface_hub.errors.HFValidationError:
692
+ does_repo_exist_on_hub = False
693
+
694
+ if does_repo_exist_on_hub:
695
+ return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite)
696
+ else:
697
+ return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)
698
+
699
+
700
+ def combine_datasets(
701
+ datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False
702
+ ) -> torch.utils.data.IterableDataset:
703
+ return IterableCombinedDataset(datasets=datasets, buffer_size=buffer_size, shuffle=shuffle)
704
+
705
+
706
+ def wrap_iterable_dataset_for_preprocessing(
707
+ dataset: torch.utils.data.IterableDataset, dataset_type: str, config: Dict[str, Any]
708
+ ) -> torch.utils.data.IterableDataset:
709
+ return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config)
710
+
711
+
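Putting the three public helpers above together, a minimal pipeline sketch (illustrative only; the dataset path and bucket sizes are placeholders, and the bucket tuple ordering is an assumption based on the type hints rather than a documented contract):

```python
# Illustrative sketch only: initialize -> preprocess-wrap -> combine, mirroring the helpers above.
from finetrainers.data import (
    combine_datasets,
    initialize_dataset,
    wrap_iterable_dataset_for_preprocessing,
)

raw = initialize_dataset("path/or/hub-id", dataset_type="video", infinite=True)
preprocessed = wrap_iterable_dataset_for_preprocessing(
    raw,
    dataset_type="video",
    config={
        "video_resolution_buckets": [(49, 512, 768)],  # assumed (frames, height, width) ordering
        "reshape_mode": "bicubic",
        "remove_common_llm_caption_prefixes": True,
    },
)
combined = combine_datasets([preprocessed], buffer_size=8, shuffle=True)
sample = next(iter(combined))  # {"caption": ..., "video": ...}
```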
712
+ def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infinite: bool = False):
713
+ root = pathlib.Path(dataset_name_or_root)
714
+ supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"]
715
+ metadata_files = [root / metadata_file for metadata_file in supported_metadata_files]
716
+ metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()]
717
+
718
+ if len(metadata_files) > 1:
719
+ raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.")
720
+
721
+ if len(metadata_files) == 1:
722
+ if dataset_type == "image":
723
+ dataset = ImageFolderDataset(root.as_posix(), infinite=infinite)
724
+ else:
725
+ dataset = VideoFolderDataset(root.as_posix(), infinite=infinite)
726
+ return dataset
727
+
728
+ if _has_data_caption_file_pairs(root, remote=False):
729
+ if dataset_type == "image":
730
+ dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite)
731
+ else:
732
+ dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite)
733
+ elif _has_data_file_caption_file_lists(root, remote=False):
734
+ if dataset_type == "image":
735
+ dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
736
+ else:
737
+ dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
738
+ else:
739
+ raise ValueError(
740
+ f"Could not find any supported dataset structure in the directory {root}. Please open an issue at "
741
+ f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will "
742
+ f"help you set it up."
743
+ )
744
+
745
+ return dataset
746
+
747
+
748
+ def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool = False):
749
+ repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
750
+ if _has_data_caption_file_pairs(repo_file_list, remote=True):
751
+ return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
752
+ elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
753
+ return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
754
+ else:
755
+ return _initialize_webdataset(dataset_name, dataset_type, infinite)
756
+
757
+
758
+ def _initialize_data_caption_file_dataset_from_hub(
759
+ dataset_name: str, dataset_type: str, infinite: bool = False
760
+ ) -> torch.utils.data.IterableDataset:
761
+ logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
762
+ dataset_root = snapshot_download(dataset_name, repo_type="dataset")
763
+ if dataset_type == "image":
764
+ return ImageCaptionFilePairDataset(dataset_root, infinite=infinite)
765
+ else:
766
+ return VideoCaptionFilePairDataset(dataset_root, infinite=infinite)
767
+
768
+
769
+ def _initialize_data_file_caption_file_dataset_from_hub(
770
+ dataset_name: str, dataset_type: str, infinite: bool = False
771
+ ) -> torch.utils.data.IterableDataset:
772
+ logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
773
+ dataset_root = snapshot_download(dataset_name, repo_type="dataset")
774
+ if dataset_type == "image":
775
+ return ImageFileCaptionFileListDataset(dataset_root, infinite=infinite)
776
+ else:
777
+ return VideoFileCaptionFileListDataset(dataset_root, infinite=infinite)
778
+
779
+
780
+ def _initialize_webdataset(
781
+ dataset_name: str, dataset_type: str, infinite: bool = False
782
+ ) -> torch.utils.data.IterableDataset:
783
+ logger.info(f"Streaming webdataset {dataset_name} from the HF Hub")
784
+ if dataset_type == "image":
785
+ return ImageWebDataset(dataset_name, infinite=infinite)
786
+ else:
787
+ return VideoWebDataset(dataset_name, infinite=infinite)
788
+
789
+
790
+ def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
791
+ # TODO(aryan): this logic can be improved
792
+ if not remote:
793
+ caption_files = utils.find_files(root.as_posix(), "*.txt", depth=0)
794
+ for caption_file in caption_files:
795
+ caption_file = pathlib.Path(caption_file)
796
+ for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
797
+ data_filename = caption_file.with_suffix(f".{extension}")
798
+ if data_filename.exists():
799
+ return True
800
+ return False
801
+ else:
802
+ caption_files = [file for file in root if file.endswith(".txt")]
803
+ for caption_file in caption_files:
804
+ caption_file = pathlib.Path(caption_file)
805
+ for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
806
+ data_filename = caption_file.with_suffix(f".{extension}").name
807
+ if data_filename in root:
808
+ return True
809
+ return False
810
+
811
+
812
+ def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
813
+ # TODO(aryan): this logic can be improved
814
+ if not remote:
815
+ file_list = {x.name for x in root.iterdir()}
816
+ has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES)
817
+ has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES)
818
+ has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES)
819
+ return has_caption_files and (has_video_files or has_image_files)
820
+ else:
821
+ has_caption_files = any(file in root for file in COMMON_CAPTION_FILES)
822
+ has_video_files = any(file in root for file in COMMON_VIDEO_FILES)
823
+ has_image_files = any(file in root for file in COMMON_IMAGE_FILES)
824
+ return has_caption_files and (has_video_files or has_image_files)
825
+
826
+
827
+ def _read_caption_from_file(filename: str) -> str:
828
+ with open(filename, "r") as f:
829
+ return f.read().strip()
830
+
831
+
832
+ def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
833
+ image = image.convert("RGB")
834
+ image = np.array(image).astype(np.float32)
835
+ image = torch.from_numpy(image)
836
+ image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
837
+ return image
838
+
839
+
840
+ def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
841
+ video = video.get_batch(list(range(len(video))))
842
+ video = video.permute(0, 3, 1, 2).contiguous()
843
+ video = video.float() / 127.5 - 1.0
844
+ return video
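Taken together, these helpers cover the full path from a dataset name to a training-ready iterable: resolve local folder vs. Hub repo, pick a matching dataset class, optionally merge several sources, and wrap the result with a preprocessing config. A minimal sketch of that flow follows; the entry-point name `initialize_dataset` and the import path are assumptions (only its signature is visible above), and the repo id and config contents are illustrative.

# Hedged sketch, not part of the commit. `initialize_dataset` is the assumed name of the
# function whose signature appears above; the import path and repo id are illustrative.
from finetrainers.data import combine_datasets, initialize_dataset, wrap_iterable_dataset_for_preprocessing

ds_a = initialize_dataset("path/to/local/folder", dataset_type="video", infinite=True)
ds_b = initialize_dataset("username/some-hub-dataset", dataset_type="video", infinite=True)

# Interleave both sources through a shuffle buffer of 32 samples.
combined = combine_datasets([ds_a, ds_b], buffer_size=32, shuffle=True)

# The config dict is unpacked into IterableDatasetPreprocessingWrapper; an empty dict keeps defaults.
wrapped = wrap_iterable_dataset_for_preprocessing(combined, dataset_type="video", config={})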
finetrainers/data/precomputation.py ADDED
@@ -0,0 +1,163 @@
1
+ import pathlib
2
+ from typing import Any, Callable, Dict, Iterable, Optional
3
+
4
+ import torch
5
+ from tqdm.auto import tqdm
6
+
7
+ from .. import utils
8
+
9
+
10
+ class DistributedDataPreprocessor:
11
+ def __init__(
12
+ self,
13
+ rank: int,
14
+ num_items: int,
15
+ processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
16
+ save_dir: str,
17
+ ) -> None:
18
+ self._rank = rank
19
+ self._num_items = num_items
20
+ self._processor_fn = processor_fn
21
+ self._save_dir = pathlib.Path(save_dir)
22
+
23
+ self._cached_samples = []
24
+ self._preprocessed_iterator: "PreprocessedDataIterable" = None
25
+
26
+ self._save_dir.mkdir(parents=True, exist_ok=True)
27
+
28
+ subdirectories = [f for f in self._save_dir.iterdir() if f.is_dir()]
29
+ utils.delete_files(subdirectories)
30
+
31
+ def consume(
32
+ self,
33
+ data_type: str,
34
+ components: Dict[str, Any],
35
+ data_iterator,
36
+ generator: Optional[torch.Generator] = None,
37
+ cache_samples: bool = False,
38
+ use_cached_samples: bool = False,
39
+ drop_samples: bool = False,
40
+ ) -> Iterable[Dict[str, Any]]:
41
+ if data_type not in self._processor_fn.keys():
42
+ raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
43
+ if cache_samples:
44
+ if use_cached_samples:
45
+ raise ValueError("Cannot cache and use cached samples at the same time.")
46
+ if drop_samples:
47
+ raise ValueError("Cannot cache and drop samples at the same time.")
48
+
49
+ for i in tqdm(range(self._num_items), desc=f"Rank {self._rank}", total=self._num_items):
50
+ if use_cached_samples:
51
+ item = self._cached_samples[i]
52
+ else:
53
+ item = next(data_iterator)
54
+ if cache_samples:
55
+ self._cached_samples.append(item)
56
+ item = self._processor_fn[data_type](**item, **components, generator=generator)
57
+ _save_item(self._rank, i, item, self._save_dir, data_type)
58
+
59
+ if drop_samples:
60
+ del self._cached_samples
61
+ self._cached_samples = []
62
+ utils.free_memory()
63
+
64
+ self._preprocessed_iterator = PreprocessedDataIterable(self._rank, self._save_dir, data_type)
65
+ return iter(self._preprocessed_iterator)
66
+
67
+ def consume_once(
68
+ self,
69
+ data_type: str,
70
+ components: Dict[str, Any],
71
+ data_iterator,
72
+ generator: Optional[torch.Generator] = None,
73
+ cache_samples: bool = False,
74
+ use_cached_samples: bool = False,
75
+ drop_samples: bool = False,
76
+ ) -> Iterable[Dict[str, Any]]:
77
+ if data_type not in self._processor_fn.keys():
78
+ raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
79
+ if cache_samples:
80
+ if use_cached_samples:
81
+ raise ValueError("Cannot cache and use cached samples at the same time.")
82
+ if drop_samples:
83
+ raise ValueError("Cannot cache and drop samples at the same time.")
84
+
85
+ for i in tqdm(range(self._num_items), desc=f"Processing data on rank {self._rank}", total=self._num_items):
86
+ if use_cached_samples:
87
+ item = self._cached_samples[i]
88
+ else:
89
+ item = next(data_iterator)
90
+ if cache_samples:
91
+ self._cached_samples.append(item)
92
+ item = self._processor_fn[data_type](**item, **components, generator=generator)
93
+ _save_item(self._rank, i, item, self._save_dir, data_type)
94
+
95
+ if drop_samples:
96
+ del self._cached_samples
97
+ self._cached_samples = []
98
+ utils.free_memory()
99
+
100
+ self._preprocessed_iterator = PreprocessedOnceDataIterable(self._rank, self._save_dir, data_type)
101
+ return iter(self._preprocessed_iterator)
102
+
103
+ @property
104
+ def requires_data(self):
105
+ if self._preprocessed_iterator is None:
106
+ return True
107
+ return self._preprocessed_iterator.requires_data
108
+
109
+
110
+ class PreprocessedDataIterable:
111
+ def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
112
+ self._rank = rank
113
+ self._save_dir = pathlib.Path(save_dir)
114
+ self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
115
+ self._data_type = data_type
116
+
117
+ self._requires_data = False
118
+
119
+ def __iter__(self) -> Iterable[Dict[str, Any]]:
120
+ for i in range(self._num_items):
121
+ if i == self._num_items - 1:
122
+ self._requires_data = True
123
+ yield _load_item(self._rank, i, self._save_dir, self._data_type)
124
+
125
+ def __len__(self) -> int:
126
+ return self._num_items
127
+
128
+ @property
129
+ def requires_data(self):
130
+ return self._requires_data
131
+
132
+
133
+ class PreprocessedOnceDataIterable:
134
+ def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
135
+ self._rank = rank
136
+ self._save_dir = pathlib.Path(save_dir)
137
+ self._num_items = len(list(self._save_dir.glob(f"{data_type}-{rank}-*.pt")))
138
+ self._data_type = data_type
139
+
140
+ self._requires_data = False
141
+
142
+ def __iter__(self) -> Iterable[Dict[str, Any]]:
143
+ index = 0
144
+ while True:
145
+ yield _load_item(self._rank, index, self._save_dir, self._data_type)
146
+ index = (index + 1) % self._num_items
147
+
148
+ def __len__(self) -> int:
149
+ return self._num_items
150
+
151
+ @property
152
+ def requires_data(self):
153
+ return self._requires_data
154
+
155
+
156
+ def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None:
157
+ filename = directory / f"{data_type}-{rank}-{index}.pt"
158
+ torch.save(item, filename.as_posix())
159
+
160
+
161
+ def _load_item(rank: int, index: int, directory: pathlib.Path, data_type: str) -> Dict[str, Any]:
162
+ filename = directory / f"{data_type}-{rank}-{index}.pt"
163
+ return torch.load(filename.as_posix(), weights_only=True)
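The preprocessor above trades memory for disk: each rank runs its processor functions over `num_items` samples, serialises the results as `{data_type}-{rank}-{index}.pt` files, and hands back an iterator over those files (`consume_once` cycles over them indefinitely). A hedged single-rank sketch, with a made-up processor, dummy tensors, and an illustrative import path:

# Hedged sketch, not part of the commit; the processor and its outputs are made up.
import torch
from finetrainers.data.precomputation import DistributedDataPreprocessor  # illustrative path

def encode_text(caption=None, text_encoder=None, generator=None, **kwargs):
    # A processor receives one raw sample plus the shared components and returns the dict to save.
    return {"prompt_embeds": torch.randn(1, 128, 768)}

preprocessor = DistributedDataPreprocessor(
    rank=0, num_items=4, processor_fn={"condition": encode_text}, save_dir="precomputed"
)
samples = iter([{"caption": f"sample {i}"} for i in range(4)])

# Processes 4 samples to disk, then yields them back one by one from the saved .pt files.
for item in preprocessor.consume("condition", components={"text_encoder": None}, data_iterator=samples):
    print(item["prompt_embeds"].shape)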
finetrainers/data/sampler.py ADDED
@@ -0,0 +1,58 @@
+ from typing import Any, Dict, List, Tuple
+
+ import torch
+
+
+ class ResolutionSampler:
+     def __init__(self, batch_size: int = 1, dim_keys: Dict[str, Tuple[int, ...]] = None) -> None:
+         self.batch_size = batch_size
+         self.dim_keys = dim_keys
+         assert dim_keys is not None, "dim_keys must be provided"
+
+         self._chosen_leader_key = None
+         self._unsatisfied_buckets: Dict[Tuple[int, ...], List[Dict[Any, Any]]] = {}
+         self._satisfied_buckets: List[Dict[Any, Any]] = []
+
+     def consume(self, *dict_items: Dict[Any, Any]) -> None:
+         if self._chosen_leader_key is None:
+             self._determine_leader_item(*dict_items)
+         self._update_buckets(*dict_items)
+
+     def get_batch(self) -> List[Dict[str, Any]]:
+         return list(zip(*self._satisfied_buckets.pop(-1)))
+
+     @property
+     def is_ready(self) -> bool:
+         return len(self._satisfied_buckets) > 0
+
+     def _determine_leader_item(self, *dict_items: Dict[Any, Any]) -> None:
+         num_observed = 0
+         for dict_item in dict_items:
+             for key in self.dim_keys.keys():
+                 if key in dict_item.keys():
+                     self._chosen_leader_key = key
+                     if not torch.is_tensor(dict_item[key]):
+                         raise ValueError(f"Leader key {key} must be a tensor")
+                     num_observed += 1
+         if num_observed > 1:
+             raise ValueError(
+                 f"Only one leader key is allowed in provided list of data dictionaries. Found {num_observed} leader keys"
+             )
+         if self._chosen_leader_key is None:
+             raise ValueError("No leader key found in provided list of data dictionaries")
+
+     def _update_buckets(self, *dict_items: Dict[Any, Any]) -> None:
+         chosen_value = [
+             dict_item[self._chosen_leader_key]
+             for dict_item in dict_items
+             if self._chosen_leader_key in dict_item.keys()
+         ]
+         if len(chosen_value) == 0:
+             raise ValueError(f"Leader key {self._chosen_leader_key} not found in provided list of data dictionaries")
+         chosen_value = chosen_value[0]
+         dims = tuple(chosen_value.size(x) for x in self.dim_keys[self._chosen_leader_key])
+         if dims not in self._unsatisfied_buckets:
+             self._unsatisfied_buckets[dims] = []
+         self._unsatisfied_buckets[dims].append(dict_items)
+         if len(self._unsatisfied_buckets[dims]) == self.batch_size:
+             self._satisfied_buckets.append(self._unsatisfied_buckets.pop(dims))
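The sampler buckets incoming samples by the sizes of a designated "leader" tensor, so every emitted batch is shape-homogeneous; `dim_keys` names that tensor and which of its dimensions define the bucket. A hedged sketch of the intended usage (import path illustrative):

# Hedged sketch, not part of the commit.
import torch
from finetrainers.data.sampler import ResolutionSampler  # illustrative path

# Bucket latents by (frames, height, width), i.e. dimensions 0, 2 and 3 of a (F, C, H, W) tensor.
sampler = ResolutionSampler(batch_size=2, dim_keys={"latents": (0, 2, 3)})

for _ in range(4):
    sampler.consume({"latents": torch.randn(16, 8, 32, 32)}, {"caption": "a prompt"})
    if sampler.is_ready:
        latents_group, captions_group = sampler.get_batch()
        # Each group is a tuple of `batch_size` dicts that share the same latent resolution.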
finetrainers/data/utils.py ADDED
@@ -0,0 +1,20 @@
+ import pathlib
+ from typing import List
+
+
+ def find_files(root: str, pattern: str, depth: int = 0) -> List[str]:
+     root_path = pathlib.Path(root)
+     result_files = []
+
+     def within_depth(path: pathlib.Path) -> bool:
+         return len(path.relative_to(root_path).parts) <= depth
+
+     if depth == 0:
+         result_files.extend([str(file) for file in root_path.glob(pattern)])
+     else:
+         # rglob matches all levels, but we filter by depth
+         for file in root_path.rglob(pattern):
+             if file.is_file() and within_depth(file.parent):
+                 result_files.append(str(file))
+
+     return result_files
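`depth` counts directory levels below `root`: `depth=0` searches only the top level, while larger values walk that many levels of subdirectories. A hedged example (the directory layout is made up):

# Hedged example, not part of the commit.
captions = find_files("my-dataset", "*.txt", depth=0)   # only my-dataset/*.txt
videos = find_files("my-dataset", "*.mp4", depth=2)     # also my-dataset/a/*.mp4 and my-dataset/a/b/*.mp4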
finetrainers/dataset.py DELETED
@@ -1,564 +0,0 @@
1
- import json
2
- import os
3
- import random
4
- from pathlib import Path
5
- from typing import Any, Dict, List, Optional, Tuple
6
-
7
- import numpy as np
8
- import pandas as pd
9
- import torch
10
- import torchvision.transforms as TT
11
- import torchvision.transforms.functional as TTF
12
- from accelerate.logging import get_logger
13
- from torch.utils.data import Dataset, Sampler
14
- from torchvision import transforms
15
- from torchvision.transforms import InterpolationMode
16
- from torchvision.transforms.functional import resize
17
-
18
- import gc
19
- import time
20
- import resource
21
-
22
- # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
23
- # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
24
- import decord # isort:skip
25
-
26
- decord.bridge.set_bridge("torch")
27
-
28
- from .constants import ( # noqa
29
- COMMON_LLM_START_PHRASES,
30
- PRECOMPUTED_CONDITIONS_DIR_NAME,
31
- PRECOMPUTED_DIR_NAME,
32
- PRECOMPUTED_LATENTS_DIR_NAME,
33
- )
34
-
35
- # Decord is causing us some issues!
36
- # Let's try to increase file descriptor limits to avoid this error:
37
- #
38
- # decord._ffi.base.DECORDError: Resource temporarily unavailable
39
- try:
40
- soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
41
- print(f"Current file descriptor limits: soft={soft}, hard={hard}")
42
-
43
- # Try to increase to hard limit if possible
44
- if soft < hard:
45
- resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
46
- new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
47
- print(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
48
- except Exception as e:
49
- print(f"Could not check or update file descriptor limits: {e}")
50
-
51
- logger = get_logger(__name__)
52
-
53
- # TODO(aryan): This needs a refactor with separation of concerns.
54
- # Images should be handled separately. Videos should be handled separately.
55
- # Loading should be handled separately.
56
- # Preprocessing (aspect ratio, resizing) should be handled separately.
57
- # URL loading should be handled.
58
- # Parquet format should be handled.
59
- # Loading from ZIP should be handled.
60
- class ImageOrVideoDataset(Dataset):
61
- def __init__(
62
- self,
63
- data_root: str,
64
- caption_column: str,
65
- video_column: str,
66
- resolution_buckets: List[Tuple[int, int, int]],
67
- dataset_file: Optional[str] = None,
68
- id_token: Optional[str] = None,
69
- remove_llm_prefixes: bool = False,
70
- ) -> None:
71
- super().__init__()
72
-
73
- self.data_root = Path(data_root)
74
- self.dataset_file = dataset_file
75
- self.caption_column = caption_column
76
- self.video_column = video_column
77
- self.id_token = f"{id_token.strip()} " if id_token else ""
78
- self.resolution_buckets = resolution_buckets
79
-
80
- # Four methods of loading data are supported.
81
- # - Using a CSV: caption_column and video_column must be some column in the CSV. One could
82
- # make use of other columns too, such as a motion score or aesthetic score, by modifying the
83
- # logic in CSV processing.
84
- # - Using two files containing line-separate captions and relative paths to videos.
85
- # - Using a JSON file containing a list of dictionaries, where each dictionary has a `caption_column` and `video_column` key.
86
- # - Using a JSONL file containing a list of line-separated dictionaries, where each dictionary has a `caption_column` and `video_column` key.
87
- # For a more detailed explanation about preparing dataset format, checkout the README.
88
- if dataset_file is None:
89
- (
90
- self.prompts,
91
- self.video_paths,
92
- ) = self._load_dataset_from_local_path()
93
- elif dataset_file.endswith(".csv"):
94
- (
95
- self.prompts,
96
- self.video_paths,
97
- ) = self._load_dataset_from_csv()
98
- elif dataset_file.endswith(".json"):
99
- (
100
- self.prompts,
101
- self.video_paths,
102
- ) = self._load_dataset_from_json()
103
- elif dataset_file.endswith(".jsonl"):
104
- (
105
- self.prompts,
106
- self.video_paths,
107
- ) = self._load_dataset_from_jsonl()
108
- else:
109
- raise ValueError(
110
- "Expected `--dataset_file` to be a path to a CSV file or a directory containing line-separated text prompts and video paths."
111
- )
112
-
113
- if len(self.video_paths) != len(self.prompts):
114
- raise ValueError(
115
- f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
116
- )
117
-
118
- # Clean LLM start phrases
119
- if remove_llm_prefixes:
120
- for i in range(len(self.prompts)):
121
- self.prompts[i] = self.prompts[i].strip()
122
- for phrase in COMMON_LLM_START_PHRASES:
123
- if self.prompts[i].startswith(phrase):
124
- self.prompts[i] = self.prompts[i].removeprefix(phrase).strip()
125
-
126
- self.video_transforms = transforms.Compose(
127
- [
128
- transforms.Lambda(self.scale_transform),
129
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
130
- ]
131
- )
132
-
133
- @staticmethod
134
- def scale_transform(x):
135
- return x / 255.0
136
-
137
- def __len__(self) -> int:
138
- return len(self.video_paths)
139
-
140
- def __getitem__(self, index: int) -> Dict[str, Any]:
141
- if isinstance(index, list):
142
- # Here, index is actually a list of data objects that we need to return.
143
- # The BucketSampler should ideally return indices. But, in the sampler, we'd like
144
- # to have information about num_frames, height and width. Since this is not stored
145
- # as metadata, we need to read the video to get this information. You could read this
146
- # information without loading the full video in memory, but we do it anyway. In order
147
- # to not load the video twice (once to get the metadata, and once to return the loaded video
148
- # based on sampled indices), we cache it in the BucketSampler. When the sampler is
149
- # to yield, we yield the cache data instead of indices. So, this special check ensures
150
- # that data is not loaded a second time. PRs are welcome for improvements.
151
- return index
152
-
153
- prompt = self.id_token + self.prompts[index]
154
-
155
- video_path: Path = self.video_paths[index]
156
- if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]:
157
- video = self._preprocess_image(video_path)
158
- else:
159
- video = self._preprocess_video(video_path)
160
-
161
- return {
162
- "prompt": prompt,
163
- "video": video,
164
- "video_metadata": {
165
- "num_frames": video.shape[0],
166
- "height": video.shape[2],
167
- "width": video.shape[3],
168
- },
169
- }
170
-
171
- def _load_dataset_from_local_path(self) -> Tuple[List[str], List[str]]:
172
- if not self.data_root.exists():
173
- raise ValueError("Root folder for videos does not exist")
174
-
175
- prompt_path = self.data_root.joinpath(self.caption_column)
176
- video_path = self.data_root.joinpath(self.video_column)
177
-
178
- if not prompt_path.exists() or not prompt_path.is_file():
179
- raise ValueError(
180
- "Expected `--caption_column` to be path to a file in `--data_root` containing line-separated text prompts."
181
- )
182
- if not video_path.exists() or not video_path.is_file():
183
- raise ValueError(
184
- "Expected `--video_column` to be path to a file in `--data_root` containing line-separated paths to video data in the same directory."
185
- )
186
-
187
- with open(prompt_path, "r", encoding="utf-8") as file:
188
- prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
189
- with open(video_path, "r", encoding="utf-8") as file:
190
- video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]
191
-
192
- if any(not path.is_file() for path in video_paths):
193
- raise ValueError(
194
- f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
195
- )
196
-
197
- return prompts, video_paths
198
-
199
- def _load_dataset_from_csv(self) -> Tuple[List[str], List[str]]:
200
- df = pd.read_csv(self.dataset_file)
201
- prompts = df[self.caption_column].tolist()
202
- video_paths = df[self.video_column].tolist()
203
- video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]
204
-
205
- if any(not path.is_file() for path in video_paths):
206
- raise ValueError(
207
- f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
208
- )
209
-
210
- return prompts, video_paths
211
-
212
- def _load_dataset_from_json(self) -> Tuple[List[str], List[str]]:
213
- with open(self.dataset_file, "r", encoding="utf-8") as file:
214
- data = json.load(file)
215
-
216
- prompts = [entry[self.caption_column] for entry in data]
217
- video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data]
218
-
219
- if any(not path.is_file() for path in video_paths):
220
- raise ValueError(
221
- f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
222
- )
223
-
224
- return prompts, video_paths
225
-
226
- def _load_dataset_from_jsonl(self) -> Tuple[List[str], List[str]]:
227
- with open(self.dataset_file, "r", encoding="utf-8") as file:
228
- data = [json.loads(line) for line in file]
229
-
230
- prompts = [entry[self.caption_column] for entry in data]
231
- video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data]
232
-
233
- if any(not path.is_file() for path in video_paths):
234
- raise ValueError(
235
- f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found atleast one path that is not a valid file."
236
- )
237
-
238
- return prompts, video_paths
239
-
240
- def _preprocess_image(self, path: Path) -> torch.Tensor:
241
- # TODO(aryan): Support alpha channel in future by whitening background
242
- image = TTF.Image.open(path.as_posix()).convert("RGB")
243
- image = TTF.to_tensor(image)
244
- image = image * 2.0 - 1.0
245
- image = image.unsqueeze(0).contiguous() # [C, H, W] -> [1, C, H, W] (1-frame video)
246
- return image
247
-
248
- def _preprocess_video(self, path: Path) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
249
- """
250
- Loads a single video, or latent and prompt embedding, based on initialization parameters.
251
- Returns a [F, C, H, W] video tensor.
252
- """
253
- max_retries = 3
254
- retry_delay = 1.0 # seconds
255
-
256
- for attempt in range(max_retries):
257
- try:
258
- # Create video reader
259
- video_reader = decord.VideoReader(uri=path.as_posix())
260
- video_num_frames = len(video_reader)
261
-
262
- # Process frames
263
- indices = list(range(0, video_num_frames, video_num_frames // self.max_num_frames))
264
- frames = video_reader.get_batch(indices)
265
- frames = frames[: self.max_num_frames].float()
266
- frames = frames.permute(0, 3, 1, 2).contiguous()
267
- frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)
268
-
269
- # Explicitly clean up resources
270
- del video_reader
271
-
272
- # Force garbage collection occasionally
273
- if random.random() < 0.05: # 5% chance
274
- gc.collect()
275
-
276
- return frames
277
-
278
- except decord._ffi.base.DECORDError as e:
279
- # Log the error
280
- error_msg = str(e)
281
- if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
282
- logger.warning(f"Retry {attempt+1}/{max_retries} loading video {path}: {error_msg}")
283
-
284
- # Clean up and wait before retrying
285
- gc.collect()
286
- time.sleep(retry_delay * (attempt + 1)) # Increasing backoff
287
- else:
288
- # Either not a resource error or we've run out of retries
289
- logger.error(f"Failed to load video {path} after {attempt+1} attempts: {error_msg}")
290
- raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}")
291
-
292
-
293
- class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset):
294
- def __init__(self, *args, **kwargs) -> None:
295
- super().__init__(*args, **kwargs)
296
-
297
- self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]
298
-
299
- def _preprocess_image(self, path: Path) -> torch.Tensor:
300
- # TODO(aryan): Support alpha channel in future by whitening background
301
- image = TTF.Image.open(path.as_posix()).convert("RGB")
302
- image = TTF.to_tensor(image)
303
-
304
- nearest_res = self._find_nearest_resolution(image.shape[1], image.shape[2])
305
- image = resize(image, nearest_res)
306
-
307
- image = image * 2.0 - 1.0
308
- image = image.unsqueeze(0).contiguous()
309
- return image
310
-
311
- def _preprocess_video(self, path: Path) -> torch.Tensor:
312
- max_retries = 3
313
- retry_delay = 1.0 # seconds
314
-
315
- for attempt in range(max_retries):
316
- try:
317
- # Create video reader
318
- video_reader = decord.VideoReader(uri=path.as_posix())
319
- video_num_frames = len(video_reader)
320
-
321
- # Find appropriate bucket for the video
322
- video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
323
-
324
- if not video_buckets:
325
- _, h, w = self.resolution_buckets[0]
326
- video_buckets = [(1, h, w)]
327
-
328
- nearest_frame_bucket = min(
329
- video_buckets,
330
- key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
331
- default=video_buckets[0],
332
- )[0]
333
-
334
- # Extract and process frames
335
- frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
336
- frames = video_reader.get_batch(frame_indices)
337
- frames = frames[:nearest_frame_bucket].float()
338
- frames = frames.permute(0, 3, 1, 2).contiguous()
339
-
340
- nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
341
- frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
342
- frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
343
-
344
- # Explicitly clean up resources
345
- del video_reader
346
-
347
- # Force garbage collection occasionally
348
- if random.random() < 0.05: # 5% chance
349
- gc.collect()
350
-
351
- return frames
352
-
353
- except decord._ffi.base.DECORDError as e:
354
- # Log the error
355
- error_msg = str(e)
356
- if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
357
- logger.warning(f"Retry {attempt+1}/{max_retries} loading video {path}: {error_msg}")
358
-
359
- # Clean up and wait before retrying
360
- gc.collect()
361
- time.sleep(retry_delay * (attempt + 1)) # Increasing backoff
362
- else:
363
- # Either not a resource error or we've run out of retries
364
- logger.error(f"Failed to load video {path} after {attempt+1} attempts: {error_msg}")
365
- raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}")
366
-
367
- def _find_nearest_resolution(self, height, width):
368
- nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
369
- return nearest_res[1], nearest_res[2]
370
-
371
-
372
- class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset):
373
- def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
374
- super().__init__(*args, **kwargs)
375
-
376
- self.video_reshape_mode = video_reshape_mode
377
- self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]
378
-
379
- def _resize_for_rectangle_crop(self, arr, image_size):
380
- reshape_mode = self.video_reshape_mode
381
- if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
382
- arr = resize(
383
- arr,
384
- size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
385
- interpolation=InterpolationMode.BICUBIC,
386
- )
387
- else:
388
- arr = resize(
389
- arr,
390
- size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
391
- interpolation=InterpolationMode.BICUBIC,
392
- )
393
-
394
- h, w = arr.shape[2], arr.shape[3]
395
- arr = arr.squeeze(0)
396
-
397
- delta_h = h - image_size[0]
398
- delta_w = w - image_size[1]
399
-
400
- if reshape_mode == "random" or reshape_mode == "none":
401
- top = np.random.randint(0, delta_h + 1)
402
- left = np.random.randint(0, delta_w + 1)
403
- elif reshape_mode == "center":
404
- top, left = delta_h // 2, delta_w // 2
405
- else:
406
- raise NotImplementedError
407
- arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
408
- return arr
409
-
410
- def _preprocess_video(self, path: Path) -> torch.Tensor:
411
- max_retries = 3
412
- retry_delay = 1.0 # seconds
413
-
414
- for attempt in range(max_retries):
415
- try:
416
- # Create video reader
417
- video_reader = decord.VideoReader(uri=path.as_posix())
418
- video_num_frames = len(video_reader)
419
-
420
- # Find appropriate bucket for the video
421
- video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
422
-
423
- if not video_buckets:
424
- _, h, w = self.resolution_buckets[0]
425
- video_buckets = [(1, h, w)]
426
-
427
- nearest_frame_bucket = min(
428
- video_buckets,
429
- key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
430
- default=video_buckets[0],
431
- )[0]
432
-
433
- # Extract and process frames
434
- frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
435
- frames = video_reader.get_batch(frame_indices)
436
- frames = frames[:nearest_frame_bucket].float()
437
- frames = frames.permute(0, 3, 1, 2).contiguous()
438
-
439
- # Fix: Change self.resolutions to self.resolution_buckets to match the class attribute
440
- nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
441
- frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
442
- frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
443
-
444
- # Explicitly clean up resources
445
- del video_reader
446
-
447
- # Force garbage collection occasionally
448
- if random.random() < 0.05: # 5% chance
449
- gc.collect()
450
-
451
- return frames
452
-
453
- except decord._ffi.base.DECORDError as e:
454
- # Log the error
455
- error_msg = str(e)
456
- if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
457
- logger.warning(f"Retry {attempt+1}/{max_retries} loading video {path}: {error_msg}")
458
-
459
- # Clean up and wait before retrying
460
- gc.collect()
461
- time.sleep(retry_delay * (attempt + 1)) # Increasing backoff
462
- else:
463
- # Either not a resource error or we've run out of retries
464
- logger.error(f"Failed to load video {path} after {attempt+1} attempts: {error_msg}")
465
- raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}")
466
-
467
- def _find_nearest_resolution(self, height, width):
468
- nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
469
- return nearest_res[1], nearest_res[2]
470
-
471
-
472
- class PrecomputedDataset(Dataset):
473
- def __init__(self, data_root: str, model_name: str = None, cleaned_model_id: str = None) -> None:
474
- super().__init__()
475
-
476
- self.data_root = Path(data_root)
477
-
478
- if model_name and cleaned_model_id:
479
- precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
480
- self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
481
- self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
482
- else:
483
- self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
484
- self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME
485
-
486
- self.latent_conditions = sorted(os.listdir(self.latents_path))
487
- self.text_conditions = sorted(os.listdir(self.conditions_path))
488
-
489
- assert len(self.latent_conditions) == len(self.text_conditions), "Number of captions and videos do not match"
490
-
491
- def __len__(self) -> int:
492
- return len(self.latent_conditions)
493
-
494
- def __getitem__(self, index: int) -> Dict[str, Any]:
495
- conditions = {}
496
- latent_path = self.latents_path / self.latent_conditions[index]
497
- condition_path = self.conditions_path / self.text_conditions[index]
498
- conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True)
499
- conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True)
500
- return conditions
501
-
502
-
503
- class BucketSampler(Sampler):
504
- r"""
505
- PyTorch Sampler that groups 3D data by height, width and frames.
506
-
507
- Args:
508
- data_source (`ImageOrVideoDataset`):
509
- A PyTorch dataset object that is an instance of `ImageOrVideoDataset`.
510
- batch_size (`int`, defaults to `8`):
511
- The batch size to use for training.
512
- shuffle (`bool`, defaults to `True`):
513
- Whether or not to shuffle the data in each batch before dispatching to dataloader.
514
- drop_last (`bool`, defaults to `False`):
515
- Whether or not to drop incomplete buckets of data after completely iterating over all data
516
- in the dataset. If set to True, only batches that have `batch_size` number of entries will
517
- be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
518
- and batches that do not have `batch_size` number of entries will also be yielded.
519
- """
520
-
521
- def __init__(
522
- self, data_source: ImageOrVideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
523
- ) -> None:
524
- self.data_source = data_source
525
- self.batch_size = batch_size
526
- self.shuffle = shuffle
527
- self.drop_last = drop_last
528
-
529
- self.buckets = {resolution: [] for resolution in data_source.resolution_buckets}
530
-
531
- self._raised_warning_for_drop_last = False
532
-
533
- def __len__(self):
534
- if self.drop_last and not self._raised_warning_for_drop_last:
535
- self._raised_warning_for_drop_last = True
536
- logger.warning(
537
- "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
538
- )
539
- return (len(self.data_source) + self.batch_size - 1) // self.batch_size
540
-
541
- def __iter__(self):
542
- for index, data in enumerate(self.data_source):
543
- video_metadata = data["video_metadata"]
544
- f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
545
-
546
- self.buckets[(f, h, w)].append(data)
547
- if len(self.buckets[(f, h, w)]) == self.batch_size:
548
- if self.shuffle:
549
- random.shuffle(self.buckets[(f, h, w)])
550
- yield self.buckets[(f, h, w)]
551
- del self.buckets[(f, h, w)]
552
- self.buckets[(f, h, w)] = []
553
-
554
- if self.drop_last:
555
- return
556
-
557
- for fhw, bucket in list(self.buckets.items()):
558
- if len(bucket) == 0:
559
- continue
560
- if self.shuffle:
561
- random.shuffle(bucket)
562
- yield bucket
563
- del self.buckets[fhw]
564
- self.buckets[fhw] = []
finetrainers/functional/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from .diffusion import flow_match_target, flow_match_xt
+ from .image import (
+     bicubic_resize_image,
+     center_crop_image,
+     find_nearest_resolution_image,
+     resize_crop_image,
+     resize_to_nearest_bucket_image,
+ )
+ from .text import dropout_caption, dropout_embeddings_to_zero, remove_prefix
+ from .video import (
+     bicubic_resize_video,
+     center_crop_video,
+     find_nearest_video_resolution,
+     resize_crop_video,
+     resize_to_nearest_bucket_video,
+ )
finetrainers/functional/diffusion.py ADDED
@@ -0,0 +1,11 @@
+ import torch
+
+
+ def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+     r"""Forward process of flow matching."""
+     return (1.0 - t) * x0 + t * n
+
+
+ def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
+     r"""Loss target for flow matching."""
+     return n - x0
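These two helpers implement the linear interpolation path x_t = (1 - t) * x0 + t * n and its velocity target n - x0, which the denoiser is regressed against. A hedged sketch of how they pair up in a training step (model, shapes and loss are stand-ins):

# Hedged sketch, not part of the commit.
import torch
import torch.nn.functional as F
from finetrainers.functional import flow_match_target, flow_match_xt

x0 = torch.randn(2, 4, 32, 32)        # clean latents
noise = torch.randn_like(x0)          # n ~ N(0, I)
t = torch.rand(2).view(-1, 1, 1, 1)   # per-sample timesteps, broadcast over latent dims

xt = flow_match_xt(x0, noise, t)      # noisy latents on the straight path
target = flow_match_target(noise, x0) # velocity the model should predict
# loss = F.mse_loss(model_prediction, target)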
finetrainers/functional/image.py ADDED
@@ -0,0 +1,54 @@
+ from typing import List, Literal, Tuple
+
+ import torch
+ import torch.nn.functional as F
+
+
+ def center_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+     num_channels, height, width = image.shape
+     crop_h, crop_w = size
+     top = (height - crop_h) // 2
+     left = (width - crop_w) // 2
+     return image[:, top : top + crop_h, left : left + crop_w]
+
+
+ def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+     num_channels, height, width = image.shape
+     target_h, target_w = size
+     scale = max(target_h / height, target_w / width)
+     new_h, new_w = int(height * scale), int(width * scale)
+     image = F.interpolate(image, size=(new_h, new_w), mode="bilinear", align_corners=False)
+     return center_crop_image(image, size)
+
+
+ def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
+     return F.interpolate(image, size=size, mode="bicubic", align_corners=False)
+
+
+ def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
+     num_channels, height, width = image.shape
+     aspect_ratio = width / height
+
+     def aspect_ratio_diff(bucket):
+         return abs((bucket[1] / bucket[0]) - aspect_ratio)
+
+     return min(resolution_buckets, key=aspect_ratio_diff)
+
+
+ def resize_to_nearest_bucket_image(
+     image: torch.Tensor,
+     resolution_buckets: List[Tuple[int, int]],
+     resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic",
+ ) -> torch.Tensor:
+     target_size = find_nearest_resolution_image(image, resolution_buckets)
+
+     if resize_mode == "center_crop":
+         return center_crop_image(image, target_size)
+     elif resize_mode == "resize_crop":
+         return resize_crop_image(image, target_size)
+     elif resize_mode == "bicubic":
+         return bicubic_resize_image(image, target_size)
+     else:
+         raise ValueError(
+             f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'."
+         )
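Bucketing picks the (height, width) pair whose aspect ratio is closest to the input image and then resizes with the requested mode. A hedged sketch; the center_crop mode is used here because it operates directly on an unbatched (C, H, W) tensor:

# Hedged sketch, not part of the commit.
import torch
from finetrainers.functional import resize_to_nearest_bucket_image

buckets = [(480, 720), (512, 512), (720, 480)]   # (height, width) buckets
image = torch.rand(3, 500, 740)                  # (C, H, W), aspect ratio ~1.48
resized = resize_to_nearest_bucket_image(image, buckets, resize_mode="center_crop")
print(resized.shape)                             # torch.Size([3, 480, 720]) -- the 1.5 bucket wins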
finetrainers/functional/text.py ADDED
@@ -0,0 +1,26 @@
+ import random
+ from typing import List, Union
+
+ import torch
+
+
+ def dropout_caption(caption: Union[str, List[str]], dropout_p: float = 0) -> Union[str, List[str]]:
+     if random.random() >= dropout_p:
+         return caption
+     if isinstance(caption, str):
+         return ""
+     return [""] * len(caption)
+
+
+ def dropout_embeddings_to_zero(embed: torch.Tensor, dropout_p: float = 0) -> torch.Tensor:
+     if random.random() >= dropout_p:
+         return embed
+     embed = torch.zeros_like(embed)
+     return embed
+
+
+ def remove_prefix(text: str, prefixes: List[str]) -> str:
+     for prefix in prefixes:
+         if text.startswith(prefix):
+             return text.removeprefix(prefix).strip()
+     return text
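`remove_prefix` strips boilerplate openers that captioning LLMs tend to produce, and `dropout_caption` randomly blanks captions so the model also learns the unconditional branch used for classifier-free guidance. A hedged sketch:

# Hedged sketch, not part of the commit.
from finetrainers.functional import dropout_caption, remove_prefix

caption = remove_prefix("In this video, a cat chases a red laser dot.", ["In this video, "])
caption = dropout_caption(caption, dropout_p=0.1)   # becomes "" roughly 10% of the time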
finetrainers/functional/video.py ADDED
@@ -0,0 +1,94 @@
1
+ from typing import List, Literal, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def center_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
8
+ num_frames, num_channels, height, width = video.shape
9
+ crop_h, crop_w = size
10
+ top = (height - crop_h) // 2
11
+ left = (width - crop_w) // 2
12
+ return video[:, :, top : top + crop_h, left : left + crop_w]
13
+
14
+
15
+ def resize_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
16
+ num_frames, num_channels, height, width = video.shape
17
+ target_h, target_w = size
18
+ scale = max(target_h / height, target_w / width)
19
+ new_h, new_w = int(height * scale), int(width * scale)
20
+ video = F.interpolate(video, size=(new_h, new_w), mode="bilinear", align_corners=False)
21
+ return center_crop_video(video, size)
22
+
23
+
24
+ def bicubic_resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
25
+ num_frames, num_channels, height, width = video.shape
26
+ video = F.interpolate(video, size=size, mode="bicubic", align_corners=False)
27
+ return video
28
+
29
+
30
+ def find_nearest_video_resolution(
31
+ video: torch.Tensor, resolution_buckets: List[Tuple[int, int, int]]
32
+ ) -> Tuple[int, int, int]:
33
+ num_frames, num_channels, height, width = video.shape
34
+ aspect_ratio = width / height
35
+ possible_buckets = [b for b in resolution_buckets if b[0] <= num_frames]
36
+
37
+ if not possible_buckets:
38
+ best_frame_match = min(resolution_buckets, key=lambda b: abs(b[0] - num_frames))
39
+ else:
40
+ best_frame_match = max(possible_buckets, key=lambda b: b[0])
41
+
42
+ frame_filtered_buckets = [b for b in resolution_buckets if b[0] == best_frame_match[0]]
43
+
44
+ def aspect_ratio_diff(bucket):
45
+ return abs((bucket[2] / bucket[1]) - aspect_ratio)
46
+
47
+ return min(frame_filtered_buckets, key=aspect_ratio_diff)
48
+
49
+
50
+ def resize_to_nearest_bucket_video(
51
+ video: torch.Tensor,
52
+ resolution_buckets: List[Tuple[int, int, int]],
53
+ resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic",
54
+ ) -> torch.Tensor:
55
+ """
56
+ Resizes a video tensor to the nearest resolution bucket using the specified mode.
57
+ - It first finds a frame match with <= T frames.
58
+ - Then, it selects the closest height/width bucket.
59
+
60
+ Args:
61
+ video (`torch.Tensor`):
62
+ Input video tensor of shape `(B, T, C, H, W)`.
63
+ resolution_buckets (`List[Tuple[int, int, int]]`):
64
+ Available (num_frames, height, width) resolution buckets.
65
+ resize_mode (`str`):
66
+ One of ["center_crop", "resize_crop", "bicubic"].
67
+
68
+ Returns:
69
+ `torch.Tensor`:
70
+ Resized video tensor of the nearest bucket resolution.
71
+ """
72
+ target_frames, target_h, target_w = find_nearest_video_resolution(video, resolution_buckets)
73
+
74
+ # Adjust frame count: only interpolate frames if no lesser/equal frame count exists
75
+ num_frames, num_channels, height, width = video.shape
76
+ _first_frame_only = False
77
+ if num_frames > target_frames:
78
+ # Downsample: Select frames evenly
79
+ indices = torch.linspace(0, num_frames - 1, target_frames).long()
80
+ video = video[indices, :, :, :]
81
+ elif num_frames < target_frames:
82
+ _first_frame_only = False
83
+
84
+ # Resize spatial resolution
85
+ if resize_mode == "center_crop":
86
+ return center_crop_video(video, (target_h, target_w)), _first_frame_only
87
+ elif resize_mode == "resize_crop":
88
+ return resize_crop_video(video, (target_h, target_w)), _first_frame_only
89
+ elif resize_mode == "bicubic":
90
+ return bicubic_resize_video(video, (target_h, target_w)), _first_frame_only
91
+ else:
92
+ raise ValueError(
93
+ f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'."
94
+ )
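For videos the frame count is matched first (preferring the largest bucket with at most as many frames), then the closest aspect ratio among those buckets; unlike the image variant, this helper returns a `(video, first_frame_only)` tuple. A hedged sketch:

# Hedged sketch, not part of the commit.
import torch
from finetrainers.functional import resize_to_nearest_bucket_video

buckets = [(49, 480, 720), (49, 512, 512), (81, 480, 720)]  # (frames, height, width)
video = torch.rand(60, 3, 500, 740)                         # (T, C, H, W)
video, first_frame_only = resize_to_nearest_bucket_video(video, buckets, resize_mode="bicubic")
print(video.shape)   # torch.Size([49, 3, 480, 720]) -- 60 frames downsampled to 49 evenly spaced ones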
finetrainers/hooks/__init__.py DELETED
@@ -1 +0,0 @@
- from .layerwise_upcasting import apply_layerwise_upcasting
finetrainers/hooks/hooks.py DELETED
@@ -1,176 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import functools
16
- from typing import Any, Dict, Optional, Tuple
17
-
18
- import torch
19
- from accelerate.logging import get_logger
20
-
21
- from ..constants import FINETRAINERS_LOG_LEVEL
22
-
23
-
24
- logger = get_logger("finetrainers") # pylint: disable=invalid-name
25
- logger.setLevel(FINETRAINERS_LOG_LEVEL)
26
-
27
-
28
- class ModelHook:
29
- r"""
30
- A hook that contains callbacks to be executed just before and after the forward method of a model.
31
- """
32
-
33
- _is_stateful = False
34
-
35
- def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
36
- r"""
37
- Hook that is executed when a model is initialized.
38
- Args:
39
- module (`torch.nn.Module`):
40
- The module attached to this hook.
41
- """
42
- return module
43
-
44
- def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
45
- r"""
46
- Hook that is executed when a model is deinitalized.
47
- Args:
48
- module (`torch.nn.Module`):
49
- The module attached to this hook.
50
- """
51
- module.forward = module._old_forward
52
- del module._old_forward
53
- return module
54
-
55
- def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
56
- r"""
57
- Hook that is executed just before the forward method of the model.
58
- Args:
59
- module (`torch.nn.Module`):
60
- The module whose forward pass will be executed just after this event.
61
- args (`Tuple[Any]`):
62
- The positional arguments passed to the module.
63
- kwargs (`Dict[Str, Any]`):
64
- The keyword arguments passed to the module.
65
- Returns:
66
- `Tuple[Tuple[Any], Dict[Str, Any]]`:
67
- A tuple with the treated `args` and `kwargs`.
68
- """
69
- return args, kwargs
70
-
71
- def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
72
- r"""
73
- Hook that is executed just after the forward method of the model.
74
- Args:
75
- module (`torch.nn.Module`):
76
- The module whose forward pass been executed just before this event.
77
- output (`Any`):
78
- The output of the module.
79
- Returns:
80
- `Any`: The processed `output`.
81
- """
82
- return output
83
-
84
- def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
85
- r"""
86
- Hook that is executed when the hook is detached from a module.
87
- Args:
88
- module (`torch.nn.Module`):
89
- The module detached from this hook.
90
- """
91
- return module
92
-
93
- def reset_state(self, module: torch.nn.Module):
94
- if self._is_stateful:
95
- raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
96
- return module
97
-
98
-
99
- class HookRegistry:
100
- def __init__(self, module_ref: torch.nn.Module) -> None:
101
- super().__init__()
102
-
103
- self.hooks: Dict[str, ModelHook] = {}
104
-
105
- self._module_ref = module_ref
106
- self._hook_order = []
107
-
108
- def register_hook(self, hook: ModelHook, name: str) -> None:
109
- if name in self.hooks.keys():
110
- logger.warning(f"Hook with name {name} already exists, replacing it.")
111
-
112
- if hasattr(self._module_ref, "_old_forward"):
113
- old_forward = self._module_ref._old_forward
114
- else:
115
- old_forward = self._module_ref.forward
116
- self._module_ref._old_forward = self._module_ref.forward
117
-
118
- self._module_ref = hook.initialize_hook(self._module_ref)
119
-
120
- if hasattr(hook, "new_forward"):
121
- rewritten_forward = hook.new_forward
122
-
123
- def new_forward(module, *args, **kwargs):
124
- args, kwargs = hook.pre_forward(module, *args, **kwargs)
125
- output = rewritten_forward(module, *args, **kwargs)
126
- return hook.post_forward(module, output)
127
- else:
128
-
129
- def new_forward(module, *args, **kwargs):
130
- args, kwargs = hook.pre_forward(module, *args, **kwargs)
131
- output = old_forward(*args, **kwargs)
132
- return hook.post_forward(module, output)
133
-
134
- self._module_ref.forward = functools.update_wrapper(
135
- functools.partial(new_forward, self._module_ref), old_forward
136
- )
137
-
138
- self.hooks[name] = hook
139
- self._hook_order.append(name)
140
-
141
- def get_hook(self, name: str) -> Optional[ModelHook]:
142
- if name not in self.hooks.keys():
143
- return None
144
- return self.hooks[name]
145
-
146
- def remove_hook(self, name: str) -> None:
147
- if name not in self.hooks.keys():
148
- raise ValueError(f"Hook with name {name} not found.")
149
- self.hooks[name].deinitalize_hook(self._module_ref)
150
- del self.hooks[name]
151
- self._hook_order.remove(name)
152
-
153
- def reset_stateful_hooks(self, recurse: bool = True) -> None:
154
- for hook_name in self._hook_order:
155
- hook = self.hooks[hook_name]
156
- if hook._is_stateful:
157
- hook.reset_state(self._module_ref)
158
-
159
- if recurse:
160
- for module in self._module_ref.modules():
161
- if hasattr(module, "_diffusers_hook"):
162
- module._diffusers_hook.reset_stateful_hooks(recurse=False)
163
-
164
- @classmethod
165
- def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
166
- if not hasattr(module, "_diffusers_hook"):
167
- module._diffusers_hook = cls(module)
168
- return module._diffusers_hook
169
-
170
- def __repr__(self) -> str:
171
- hook_repr = ""
172
- for i, hook_name in enumerate(self._hook_order):
173
- hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
174
- if i < len(self._hook_order) - 1:
175
- hook_repr += "\n"
176
- return f"HookRegistry(\n{hook_repr}\n)"
finetrainers/hooks/layerwise_upcasting.py DELETED
@@ -1,140 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import re
16
- from typing import Optional, Tuple, Type
17
-
18
- import torch
19
- from accelerate.logging import get_logger
20
-
21
- from ..constants import FINETRAINERS_LOG_LEVEL
22
- from .hooks import HookRegistry, ModelHook
23
-
24
-
25
- logger = get_logger("finetrainers") # pylint: disable=invalid-name
26
- logger.setLevel(FINETRAINERS_LOG_LEVEL)
27
-
28
-
29
- # fmt: off
30
- _SUPPORTED_PYTORCH_LAYERS = (
31
- torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
32
- torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
33
- torch.nn.Linear,
34
- )
35
-
36
- _DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")
37
- # fmt: on
38
-
39
-
40
- class LayerwiseUpcastingHook(ModelHook):
41
- r"""
42
- A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
43
- for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
44
- footprint.
45
- """
46
-
47
- _is_stateful = False
48
-
49
- def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
50
- self.storage_dtype = storage_dtype
51
- self.compute_dtype = compute_dtype
52
- self.non_blocking = non_blocking
53
-
54
- def initialize_hook(self, module: torch.nn.Module):
55
- module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
56
- return module
57
-
58
- def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
59
- module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
60
- return args, kwargs
61
-
62
- def post_forward(self, module: torch.nn.Module, output):
63
- module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
64
- return output
65
-
66
-
67
- def apply_layerwise_upcasting(
68
- module: torch.nn.Module,
69
- storage_dtype: torch.dtype,
70
- compute_dtype: torch.dtype,
71
- skip_modules_pattern: Optional[Tuple[str]] = _DEFAULT_SKIP_MODULES_PATTERN,
72
- skip_modules_classes: Optional[Tuple[Type[torch.nn.Module]]] = None,
73
- non_blocking: bool = False,
74
- _prefix: str = "",
75
- ) -> None:
76
- r"""
77
- Applies layerwise upcasting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
78
- nn.Module using diffusers layers or pytorch primitives.
79
- Args:
80
- module (`torch.nn.Module`):
81
- The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
82
- precision dtype for storage.
83
- storage_dtype (`torch.dtype`):
84
- The dtype to cast the module to before/after the forward pass for storage.
85
- compute_dtype (`torch.dtype`):
86
- The dtype to cast the module to during the forward pass for computation.
87
- skip_modules_pattern (`Tuple[str]`, defaults to `["pos_embed", "patch_embed", "norm"]`):
88
- A list of patterns to match the names of the modules to skip during the layerwise upcasting process.
89
- skip_modules_classes (`Tuple[Type[torch.nn.Module]]`, defaults to `None`):
90
- A list of module classes to skip during the layerwise upcasting process.
91
- non_blocking (`bool`, defaults to `False`):
92
- If `True`, the weight casting operations are non-blocking.
93
- """
94
- if skip_modules_classes is None and skip_modules_pattern is None:
95
- apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
96
- return
97
-
98
- should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
99
- skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
100
- )
101
- if should_skip:
102
- logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
103
- return
104
-
105
- if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
106
- logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
107
- apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
108
- return
109
-
110
- for name, submodule in module.named_children():
111
- layer_name = f"{_prefix}.{name}" if _prefix else name
112
- apply_layerwise_upcasting(
113
- submodule,
114
- storage_dtype,
115
- compute_dtype,
116
- skip_modules_pattern,
117
- skip_modules_classes,
118
- non_blocking,
119
- _prefix=layer_name,
120
- )
121
-
122
-
123
- def apply_layerwise_upcasting_hook(
124
- module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
125
- ) -> None:
126
- r"""
127
- Applies a `LayerwiseUpcastingHook` to a given module.
128
- Args:
129
- module (`torch.nn.Module`):
130
- The module to attach the hook to.
131
- storage_dtype (`torch.dtype`):
132
- The dtype to cast the module to before the forward pass.
133
- compute_dtype (`torch.dtype`):
134
- The dtype to cast the module to during the forward pass.
135
- non_blocking (`bool`):
136
- If `True`, the weight casting operations are non-blocking.
137
- """
138
- registry = HookRegistry.check_if_exists_or_initialize(module)
139
- hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
140
- registry.register_hook(hook, "layerwise_upcasting")
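For reference, the helper removed above was the pre-upgrade way to keep weights in a low-precision storage dtype and upcast them only for each forward pass. A minimal sketch of that pre-upgrade usage, assuming the old finetrainers.hooks.layerwise_upcasting module path and a PyTorch build with float8 support:

import torch

# Pre-upgrade import path; this module is deleted by this commit.
from finetrainers.hooks.layerwise_upcasting import apply_layerwise_upcasting

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8)).to(torch.bfloat16)

apply_layerwise_upcasting(
    model,
    storage_dtype=torch.float8_e4m3fn,  # weights are stored in fp8 between forward passes
    compute_dtype=torch.bfloat16,       # and upcast to bf16 only while forward() runs
    non_blocking=False,
)

with torch.no_grad():
    _ = model(torch.randn(2, 64, dtype=torch.bfloat16))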
finetrainers/logging.py ADDED
@@ -0,0 +1,111 @@
1
+ import logging
2
+ import os
3
+ from typing import TYPE_CHECKING, Union
4
+
5
+ from .constants import FINETRAINERS_LOG_LEVEL
6
+
7
+
8
+ if TYPE_CHECKING:
9
+ from .parallel import ParallelBackendType
10
+
11
+
12
+ class FinetrainersLoggerAdapter(logging.LoggerAdapter):
13
+ def __init__(self, logger: logging.Logger, parallel_backend: "ParallelBackendType" = None) -> None:
14
+ super().__init__(logger, {})
15
+ self.parallel_backend = parallel_backend
16
+ self._log_freq = {}
17
+ self._log_freq_counter = {}
18
+
19
+ def log(
20
+ self,
21
+ level,
22
+ msg,
23
+ *args,
24
+ main_process_only: bool = False,
25
+ local_main_process_only: bool = True,
26
+ in_order: bool = False,
27
+ **kwargs,
28
+ ):
29
+ # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
30
+ kwargs.setdefault("stacklevel", 2)
31
+
32
+ if not self.isEnabledFor(level):
33
+ return
34
+
35
+ if self.parallel_backend is None:
36
+ if int(os.environ.get("RANK", 0)) == 0:
37
+ msg, kwargs = self.process(msg, kwargs)
38
+ self.logger.log(level, msg, *args, **kwargs)
39
+ return
40
+
41
+ if (main_process_only or local_main_process_only) and in_order:
42
+ raise ValueError(
43
+ "Cannot set `main_process_only` or `local_main_process_only` to True while `in_order` is True."
44
+ )
45
+
46
+ if (main_process_only and self.parallel_backend.is_main_process) or (
47
+ local_main_process_only and self.parallel_backend.is_local_main_process
48
+ ):
49
+ msg, kwargs = self.process(msg, kwargs)
50
+ self.logger.log(level, msg, *args, **kwargs)
51
+ return
52
+
53
+ if in_order:
54
+ for i in range(self.parallel_backend.world_size):
55
+ if self.rank == i:
56
+ msg, kwargs = self.process(msg, kwargs)
57
+ self.logger.log(level, msg, *args, **kwargs)
58
+ self.parallel_backend.wait_for_everyone()
59
+ return
60
+
61
+ if not main_process_only and not local_main_process_only:
62
+ msg, kwargs = self.process(msg, kwargs)
63
+ self.logger.log(level, msg, *args, **kwargs)
64
+ return
65
+
66
+ def log_freq(
67
+ self,
68
+ level: str,
69
+ name: str,
70
+ msg: str,
71
+ frequency: int,
72
+ *,
73
+ main_process_only: bool = False,
74
+ local_main_process_only: bool = True,
75
+ in_order: bool = False,
76
+ **kwargs,
77
+ ) -> None:
78
+ if frequency <= 0:
79
+ return
80
+ if name not in self._log_freq_counter:
81
+ self._log_freq[name] = frequency
82
+ self._log_freq_counter[name] = 0
83
+ if self._log_freq_counter[name] % self._log_freq[name] == 0:
84
+ self.log(
85
+ level,
86
+ msg,
87
+ main_process_only=main_process_only,
88
+ local_main_process_only=local_main_process_only,
89
+ in_order=in_order,
90
+ **kwargs,
91
+ )
92
+ self._log_freq_counter[name] += 1
93
+
94
+
95
+ def get_logger() -> Union[logging.Logger, FinetrainersLoggerAdapter]:
96
+ global _logger
97
+ return _logger
98
+
99
+
100
+ def _set_parallel_backend(parallel_backend: "ParallelBackendType") -> FinetrainersLoggerAdapter:
101
+ _logger.parallel_backend = parallel_backend
102
+
103
+
104
+ _logger = logging.getLogger("finetrainers")
105
+ _logger.setLevel(FINETRAINERS_LOG_LEVEL)
106
+ _console_handler = logging.StreamHandler()
107
+ _console_handler.setLevel(FINETRAINERS_LOG_LEVEL)
108
+ _formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
109
+ _console_handler.setFormatter(_formatter)
110
+ _logger.addHandler(_console_handler)
111
+ _logger = FinetrainersLoggerAdapter(_logger)
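A minimal sketch of how the new logging module is meant to be consumed, assuming FINETRAINERS_LOG_LEVEL resolves to INFO and that a parallel backend is attached later by the trainer (without one, only environment rank 0 emits):

import logging

from finetrainers.logging import get_logger

logger = get_logger()

# Regular logging goes through the adapter; by default only the (local) main process emits.
logger.info("starting training run")

# Rate-limited logging: under the key "loss", this message is emitted once every 50 calls.
for step in range(200):
    logger.log_freq(logging.INFO, "loss", f"step={step}", frequency=50)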
finetrainers/models/__init__.py CHANGED
@@ -1,33 +1 @@
- from typing import Any, Dict
-
- from .cogvideox import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG, COGVIDEOX_T2V_LORA_CONFIG
- from .hunyuan_video import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG, HUNYUAN_VIDEO_T2V_LORA_CONFIG
- from .ltx_video import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG, LTX_VIDEO_T2V_LORA_CONFIG
-
-
- SUPPORTED_MODEL_CONFIGS = {
-     "hunyuan_video": {
-         "lora": HUNYUAN_VIDEO_T2V_LORA_CONFIG,
-         "full-finetune": HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG,
-     },
-     "ltx_video": {
-         "lora": LTX_VIDEO_T2V_LORA_CONFIG,
-         "full-finetune": LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG,
-     },
-     "cogvideox": {
-         "lora": COGVIDEOX_T2V_LORA_CONFIG,
-         "full-finetune": COGVIDEOX_T2V_FULL_FINETUNE_CONFIG,
-     },
- }
-
-
- def get_config_from_model_name(model_name: str, training_type: str) -> Dict[str, Any]:
-     if model_name not in SUPPORTED_MODEL_CONFIGS:
-         raise ValueError(
-             f"Model {model_name} not supported. Supported models are: {list(SUPPORTED_MODEL_CONFIGS.keys())}"
-         )
-     if training_type not in SUPPORTED_MODEL_CONFIGS[model_name]:
-         raise ValueError(
-             f"Training type {training_type} not supported for model {model_name}. Supported training types are: {list(SUPPORTED_MODEL_CONFIGS[model_name].keys())}"
-         )
-     return SUPPORTED_MODEL_CONFIGS[model_name][training_type]

+ from .modeling_utils import ModelSpecification
finetrainers/models/cogvideox/__init__.py CHANGED
@@ -1,2 +1 @@
- from .full_finetune import COGVIDEOX_T2V_FULL_FINETUNE_CONFIG
- from .lora import COGVIDEOX_T2V_LORA_CONFIG

+ from .base_specification import CogVideoXModelSpecification
 
finetrainers/models/cogvideox/base_specification.py ADDED
@@ -0,0 +1,424 @@
1
+ import os
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ from accelerate import init_empty_weights
6
+ from diffusers import (
7
+ AutoencoderKLCogVideoX,
8
+ CogVideoXDDIMScheduler,
9
+ CogVideoXImageToVideoPipeline,
10
+ CogVideoXPipeline,
11
+ CogVideoXTransformer3DModel,
12
+ )
13
+ from PIL.Image import Image
14
+ from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
15
+
16
+ from ... import data
17
+ from ...logging import get_logger
18
+ from ...processors import ProcessorMixin, T5Processor
19
+ from ...typing import ArtifactType, SchedulerType
20
+ from ...utils import get_non_null_items
21
+ from ..modeling_utils import ModelSpecification
22
+ from ..utils import DiagonalGaussianDistribution
23
+ from .utils import prepare_rotary_positional_embeddings
24
+
25
+
26
+ logger = get_logger()
27
+
28
+
29
+ class CogVideoXLatentEncodeProcessor(ProcessorMixin):
30
+ r"""
31
+ Processor to encode image/video into latents using the CogVideoX VAE.
32
+
33
+ Args:
34
+ output_names (`List[str]`):
35
+ The names of the outputs that the processor returns. The outputs are in the following order:
36
+ - latents: The latents of the input image/video.
37
+ """
38
+
39
+ def __init__(self, output_names: List[str]):
40
+ super().__init__()
41
+ self.output_names = output_names
42
+ assert len(self.output_names) == 1
43
+
44
+ def forward(
45
+ self,
46
+ vae: AutoencoderKLCogVideoX,
47
+ image: Optional[torch.Tensor] = None,
48
+ video: Optional[torch.Tensor] = None,
49
+ generator: Optional[torch.Generator] = None,
50
+ compute_posterior: bool = True,
51
+ ) -> Dict[str, torch.Tensor]:
52
+ device = vae.device
53
+ dtype = vae.dtype
54
+
55
+ if image is not None:
56
+ video = image.unsqueeze(1)
57
+
58
+ assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
59
+ video = video.to(device=device, dtype=vae.dtype)
60
+ video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
61
+
62
+ if compute_posterior:
63
+ latents = vae.encode(video).latent_dist.sample(generator=generator)
64
+ latents = latents.to(dtype=dtype)
65
+ else:
66
+ if vae.use_slicing and video.shape[0] > 1:
67
+ encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
68
+ moments = torch.cat(encoded_slices)
69
+ else:
70
+ moments = vae._encode(video)
71
+ latents = moments.to(dtype=dtype)
72
+
73
+ latents = latents.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] -> [B, F, C, H, W]
74
+ return {self.output_names[0]: latents}
75
+
76
+
77
+ class CogVideoXModelSpecification(ModelSpecification):
78
+ def __init__(
79
+ self,
80
+ pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b",
81
+ tokenizer_id: Optional[str] = None,
82
+ text_encoder_id: Optional[str] = None,
83
+ transformer_id: Optional[str] = None,
84
+ vae_id: Optional[str] = None,
85
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
86
+ transformer_dtype: torch.dtype = torch.bfloat16,
87
+ vae_dtype: torch.dtype = torch.bfloat16,
88
+ revision: Optional[str] = None,
89
+ cache_dir: Optional[str] = None,
90
+ condition_model_processors: List[ProcessorMixin] = None,
91
+ latent_model_processors: List[ProcessorMixin] = None,
92
+ **kwargs,
93
+ ) -> None:
94
+ super().__init__(
95
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
96
+ tokenizer_id=tokenizer_id,
97
+ text_encoder_id=text_encoder_id,
98
+ transformer_id=transformer_id,
99
+ vae_id=vae_id,
100
+ text_encoder_dtype=text_encoder_dtype,
101
+ transformer_dtype=transformer_dtype,
102
+ vae_dtype=vae_dtype,
103
+ revision=revision,
104
+ cache_dir=cache_dir,
105
+ )
106
+
107
+ if condition_model_processors is None:
108
+ condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])]
109
+ if latent_model_processors is None:
110
+ latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])]
111
+
112
+ self.condition_model_processors = condition_model_processors
113
+ self.latent_model_processors = latent_model_processors
114
+
115
+ @property
116
+ def _resolution_dim_keys(self):
117
+ return {"latents": (1, 3, 4)}
118
+
119
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
120
+ if self.tokenizer_id is not None:
121
+ tokenizer = AutoTokenizer.from_pretrained(
122
+ self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
123
+ )
124
+ else:
125
+ tokenizer = T5Tokenizer.from_pretrained(
126
+ self.pretrained_model_name_or_path,
127
+ subfolder="tokenizer",
128
+ revision=self.revision,
129
+ cache_dir=self.cache_dir,
130
+ )
131
+
132
+ if self.text_encoder_id is not None:
133
+ text_encoder = AutoModel.from_pretrained(
134
+ self.text_encoder_id,
135
+ torch_dtype=self.text_encoder_dtype,
136
+ revision=self.revision,
137
+ cache_dir=self.cache_dir,
138
+ )
139
+ else:
140
+ text_encoder = T5EncoderModel.from_pretrained(
141
+ self.pretrained_model_name_or_path,
142
+ subfolder="text_encoder",
143
+ torch_dtype=self.text_encoder_dtype,
144
+ revision=self.revision,
145
+ cache_dir=self.cache_dir,
146
+ )
147
+
148
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
149
+
150
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
151
+ if self.vae_id is not None:
152
+ vae = AutoencoderKLCogVideoX.from_pretrained(
153
+ self.vae_id,
154
+ torch_dtype=self.vae_dtype,
155
+ revision=self.revision,
156
+ cache_dir=self.cache_dir,
157
+ )
158
+ else:
159
+ vae = AutoencoderKLCogVideoX.from_pretrained(
160
+ self.pretrained_model_name_or_path,
161
+ subfolder="vae",
162
+ torch_dtype=self.vae_dtype,
163
+ revision=self.revision,
164
+ cache_dir=self.cache_dir,
165
+ )
166
+
167
+ return {"vae": vae}
168
+
169
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
170
+ if self.transformer_id is not None:
171
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
172
+ self.transformer_id,
173
+ torch_dtype=self.transformer_dtype,
174
+ revision=self.revision,
175
+ cache_dir=self.cache_dir,
176
+ )
177
+ else:
178
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
179
+ self.pretrained_model_name_or_path,
180
+ subfolder="transformer",
181
+ torch_dtype=self.transformer_dtype,
182
+ revision=self.revision,
183
+ cache_dir=self.cache_dir,
184
+ )
185
+
186
+ scheduler = CogVideoXDDIMScheduler.from_pretrained(
187
+ self.pretrained_model_name_or_path, subfolder="scheduler", revision=self.revision, cache_dir=self.cache_dir
188
+ )
189
+
190
+ return {"transformer": transformer, "scheduler": scheduler}
191
+
192
+ def load_pipeline(
193
+ self,
194
+ tokenizer: Optional[T5Tokenizer] = None,
195
+ text_encoder: Optional[T5EncoderModel] = None,
196
+ transformer: Optional[CogVideoXTransformer3DModel] = None,
197
+ vae: Optional[AutoencoderKLCogVideoX] = None,
198
+ scheduler: Optional[CogVideoXDDIMScheduler] = None,
199
+ enable_slicing: bool = False,
200
+ enable_tiling: bool = False,
201
+ enable_model_cpu_offload: bool = False,
202
+ training: bool = False,
203
+ **kwargs,
204
+ ) -> CogVideoXPipeline:
205
+ components = {
206
+ "tokenizer": tokenizer,
207
+ "text_encoder": text_encoder,
208
+ "transformer": transformer,
209
+ "vae": vae,
210
+ "scheduler": scheduler,
211
+ }
212
+ components = get_non_null_items(components)
213
+
214
+ pipe = CogVideoXPipeline.from_pretrained(
215
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
216
+ )
217
+ pipe.text_encoder.to(self.text_encoder_dtype)
218
+ pipe.vae.to(self.vae_dtype)
219
+
220
+ if not training:
221
+ pipe.transformer.to(self.transformer_dtype)
222
+
223
+ if enable_slicing:
224
+ pipe.vae.enable_slicing()
225
+ if enable_tiling:
226
+ pipe.vae.enable_tiling()
227
+ if enable_model_cpu_offload:
228
+ pipe.enable_model_cpu_offload()
229
+
230
+ return pipe
231
+
232
+ @torch.no_grad()
233
+ def prepare_conditions(
234
+ self,
235
+ tokenizer: T5Tokenizer,
236
+ text_encoder: T5EncoderModel,
237
+ caption: str,
238
+ max_sequence_length: int = 226,
239
+ **kwargs,
240
+ ) -> Dict[str, Any]:
241
+ conditions = {
242
+ "tokenizer": tokenizer,
243
+ "text_encoder": text_encoder,
244
+ "caption": caption,
245
+ "max_sequence_length": max_sequence_length,
246
+ **kwargs,
247
+ }
248
+ input_keys = set(conditions.keys())
249
+ conditions = super().prepare_conditions(**conditions)
250
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
251
+ conditions.pop("prompt_attention_mask", None)
252
+ return conditions
253
+
254
+ @torch.no_grad()
255
+ def prepare_latents(
256
+ self,
257
+ vae: AutoencoderKLCogVideoX,
258
+ image: Optional[torch.Tensor] = None,
259
+ video: Optional[torch.Tensor] = None,
260
+ generator: Optional[torch.Generator] = None,
261
+ compute_posterior: bool = True,
262
+ **kwargs,
263
+ ) -> Dict[str, torch.Tensor]:
264
+ conditions = {
265
+ "vae": vae,
266
+ "image": image,
267
+ "video": video,
268
+ "generator": generator,
269
+ "compute_posterior": compute_posterior,
270
+ **kwargs,
271
+ }
272
+ input_keys = set(conditions.keys())
273
+ conditions = super().prepare_latents(**conditions)
274
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
275
+ return conditions
276
+
277
+ def forward(
278
+ self,
279
+ transformer: CogVideoXTransformer3DModel,
280
+ scheduler: CogVideoXDDIMScheduler,
281
+ condition_model_conditions: Dict[str, torch.Tensor],
282
+ latent_model_conditions: Dict[str, torch.Tensor],
283
+ sigmas: torch.Tensor,
284
+ generator: Optional[torch.Generator] = None,
285
+ compute_posterior: bool = True,
286
+ **kwargs,
287
+ ) -> Tuple[torch.Tensor, ...]:
288
+ # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself.
289
+ VAE_SPATIAL_SCALE_FACTOR = 8
290
+ rope_base_height = self.transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
291
+ rope_base_width = self.transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR
292
+ patch_size = self.transformer_config.patch_size
293
+ patch_size_t = getattr(self.transformer_config, "patch_size_t", None)
294
+
295
+ if compute_posterior:
296
+ latents = latent_model_conditions.pop("latents")
297
+ else:
298
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"), _dim=2)
299
+ latents = posterior.sample(generator=generator)
300
+ del posterior
301
+
302
+ if not self.vae_config.invert_scale_latents:
303
+ latents = latents * self.vae_config.scaling_factor
304
+
305
+ if patch_size_t is not None:
306
+ latents = self._pad_frames(latents, patch_size_t)
307
+
308
+ timesteps = (sigmas.flatten() * 1000.0).long()
309
+
310
+ noise = torch.zeros_like(latents).normal_(generator=generator)
311
+ noisy_latents = scheduler.add_noise(latents, noise, timesteps)
312
+
313
+ batch_size, num_frames, num_channels, height, width = latents.shape
314
+ ofs_emb = (
315
+ None
316
+ if getattr(self.transformer_config, "ofs_embed_dim", None) is None
317
+ else latents.new_full((batch_size,), fill_value=2.0)
318
+ )
319
+
320
+ image_rotary_emb = (
321
+ prepare_rotary_positional_embeddings(
322
+ height=height * VAE_SPATIAL_SCALE_FACTOR,
323
+ width=width * VAE_SPATIAL_SCALE_FACTOR,
324
+ num_frames=num_frames,
325
+ vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
326
+ patch_size=patch_size,
327
+ patch_size_t=patch_size_t,
328
+ attention_head_dim=self.transformer_config.attention_head_dim,
329
+ device=transformer.device,
330
+ base_height=rope_base_height,
331
+ base_width=rope_base_width,
332
+ )
333
+ if self.transformer_config.use_rotary_positional_embeddings
334
+ else None
335
+ )
336
+
337
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
338
+ latent_model_conditions["image_rotary_emb"] = image_rotary_emb
339
+ latent_model_conditions["ofs"] = ofs_emb
340
+ condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
341
+
342
+ velocity = transformer(
343
+ **latent_model_conditions,
344
+ **condition_model_conditions,
345
+ timestep=timesteps,
346
+ return_dict=False,
347
+ )[0]
348
+ # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same
349
+ # code paths as scheduler.get_velocity(), which can be confusing to understand.
350
+ pred = scheduler.get_velocity(velocity, noisy_latents, timesteps)
351
+ target = latents
352
+
353
+ return pred, target, sigmas
354
+
355
+ def validation(
356
+ self,
357
+ pipeline: CogVideoXPipeline,
358
+ prompt: str,
359
+ image: Optional[Image] = None,
360
+ height: Optional[int] = None,
361
+ width: Optional[int] = None,
362
+ num_frames: Optional[int] = None,
363
+ num_inference_steps: int = 50,
364
+ generator: Optional[torch.Generator] = None,
365
+ **kwargs,
366
+ ) -> List[ArtifactType]:
367
+ # TODO(aryan): add support for more parameters
368
+ if image is not None:
369
+ pipeline = CogVideoXImageToVideoPipeline.from_pipe(pipeline)
370
+
371
+ generation_kwargs = {
372
+ "prompt": prompt,
373
+ "image": image,
374
+ "height": height,
375
+ "width": width,
376
+ "num_frames": num_frames,
377
+ "num_inference_steps": num_inference_steps,
378
+ "generator": generator,
379
+ "return_dict": True,
380
+ "output_type": "pil",
381
+ }
382
+ generation_kwargs = get_non_null_items(generation_kwargs)
383
+ video = pipeline(**generation_kwargs).frames[0]
384
+ return [data.VideoArtifact(value=video)]
385
+
386
+ def _save_lora_weights(
387
+ self,
388
+ directory: str,
389
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
390
+ scheduler: Optional[SchedulerType] = None,
391
+ *args,
392
+ **kwargs,
393
+ ) -> None:
394
+ # TODO(aryan): this needs refactoring
395
+ if transformer_state_dict is not None:
396
+ CogVideoXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
397
+ if scheduler is not None:
398
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
399
+
400
+ def _save_model(
401
+ self,
402
+ directory: str,
403
+ transformer: CogVideoXTransformer3DModel,
404
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
405
+ scheduler: Optional[SchedulerType] = None,
406
+ ) -> None:
407
+ # TODO(aryan): this needs refactoring
408
+ if transformer_state_dict is not None:
409
+ with init_empty_weights():
410
+ transformer_copy = CogVideoXTransformer3DModel.from_config(transformer.config)
411
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
412
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
413
+ if scheduler is not None:
414
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
415
+
416
+ @staticmethod
417
+ def _pad_frames(latents: torch.Tensor, patch_size_t: int) -> torch.Tensor:
418
+ num_frames = latents.size(1)
419
+ additional_frames = patch_size_t - (num_frames % patch_size_t)
420
+ if additional_frames > 0:
421
+ last_frame = latents[:, -1:]
422
+ padding_frames = last_frame.expand(-1, additional_frames, -1, -1, -1)
423
+ latents = torch.cat([latents, padding_frames], dim=1)
424
+ return latents
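A rough sketch of how the new specification class is used on its own, based only on the methods defined above (the full training wiring goes through ModelSpecification and SFTTrainer, which sit outside this excerpt):

import torch

from finetrainers.models.cogvideox import CogVideoXModelSpecification

spec = CogVideoXModelSpecification(
    pretrained_model_name_or_path="THUDM/CogVideoX-5b",
    transformer_dtype=torch.bfloat16,
    vae_dtype=torch.bfloat16,
)

condition_models = spec.load_condition_models()  # {"tokenizer": ..., "text_encoder": ...}
latent_models = spec.load_latent_models()        # {"vae": ...}
diffusion_models = spec.load_diffusion_models()  # {"transformer": ..., "scheduler": ...}

# For validation/inference, a pipeline can be assembled from the already-loaded components.
pipe = spec.load_pipeline(
    **condition_models,
    **latent_models,
    **diffusion_models,
    enable_slicing=True,
    training=False,
)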
finetrainers/models/cogvideox/full_finetune.py DELETED
@@ -1,32 +0,0 @@
1
- from diffusers import CogVideoXPipeline
2
-
3
- from .lora import (
4
- calculate_noisy_latents,
5
- collate_fn_t2v,
6
- forward_pass,
7
- initialize_pipeline,
8
- load_condition_models,
9
- load_diffusion_models,
10
- load_latent_models,
11
- post_latent_preparation,
12
- prepare_conditions,
13
- prepare_latents,
14
- validation,
15
- )
16
-
17
-
18
- # TODO(aryan): refactor into model specs for better re-use
19
- COGVIDEOX_T2V_FULL_FINETUNE_CONFIG = {
20
- "pipeline_cls": CogVideoXPipeline,
21
- "load_condition_models": load_condition_models,
22
- "load_latent_models": load_latent_models,
23
- "load_diffusion_models": load_diffusion_models,
24
- "initialize_pipeline": initialize_pipeline,
25
- "prepare_conditions": prepare_conditions,
26
- "prepare_latents": prepare_latents,
27
- "post_latent_preparation": post_latent_preparation,
28
- "collate_fn": collate_fn_t2v,
29
- "calculate_noisy_latents": calculate_noisy_latents,
30
- "forward_pass": forward_pass,
31
- "validation": validation,
32
- }
finetrainers/models/cogvideox/lora.py DELETED
@@ -1,334 +0,0 @@
1
- from typing import Any, Dict, List, Optional, Union
2
-
3
- import torch
4
- from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
5
- from PIL import Image
6
- from transformers import T5EncoderModel, T5Tokenizer
7
-
8
- from .utils import prepare_rotary_positional_embeddings
9
-
10
-
11
- def load_condition_models(
12
- model_id: str = "THUDM/CogVideoX-5b",
13
- text_encoder_dtype: torch.dtype = torch.bfloat16,
14
- revision: Optional[str] = None,
15
- cache_dir: Optional[str] = None,
16
- **kwargs,
17
- ):
18
- tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
19
- text_encoder = T5EncoderModel.from_pretrained(
20
- model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
21
- )
22
- return {"tokenizer": tokenizer, "text_encoder": text_encoder}
23
-
24
-
25
- def load_latent_models(
26
- model_id: str = "THUDM/CogVideoX-5b",
27
- vae_dtype: torch.dtype = torch.bfloat16,
28
- revision: Optional[str] = None,
29
- cache_dir: Optional[str] = None,
30
- **kwargs,
31
- ):
32
- vae = AutoencoderKLCogVideoX.from_pretrained(
33
- model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
34
- )
35
- return {"vae": vae}
36
-
37
-
38
- def load_diffusion_models(
39
- model_id: str = "THUDM/CogVideoX-5b",
40
- transformer_dtype: torch.dtype = torch.bfloat16,
41
- revision: Optional[str] = None,
42
- cache_dir: Optional[str] = None,
43
- **kwargs,
44
- ):
45
- transformer = CogVideoXTransformer3DModel.from_pretrained(
46
- model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
47
- )
48
- scheduler = CogVideoXDDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
49
- return {"transformer": transformer, "scheduler": scheduler}
50
-
51
-
52
- def initialize_pipeline(
53
- model_id: str = "THUDM/CogVideoX-5b",
54
- text_encoder_dtype: torch.dtype = torch.bfloat16,
55
- transformer_dtype: torch.dtype = torch.bfloat16,
56
- vae_dtype: torch.dtype = torch.bfloat16,
57
- tokenizer: Optional[T5Tokenizer] = None,
58
- text_encoder: Optional[T5EncoderModel] = None,
59
- transformer: Optional[CogVideoXTransformer3DModel] = None,
60
- vae: Optional[AutoencoderKLCogVideoX] = None,
61
- scheduler: Optional[CogVideoXDDIMScheduler] = None,
62
- device: Optional[torch.device] = None,
63
- revision: Optional[str] = None,
64
- cache_dir: Optional[str] = None,
65
- enable_slicing: bool = False,
66
- enable_tiling: bool = False,
67
- enable_model_cpu_offload: bool = False,
68
- is_training: bool = False,
69
- **kwargs,
70
- ) -> CogVideoXPipeline:
71
- component_name_pairs = [
72
- ("tokenizer", tokenizer),
73
- ("text_encoder", text_encoder),
74
- ("transformer", transformer),
75
- ("vae", vae),
76
- ("scheduler", scheduler),
77
- ]
78
- components = {}
79
- for name, component in component_name_pairs:
80
- if component is not None:
81
- components[name] = component
82
-
83
- pipe = CogVideoXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
84
- pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
85
- pipe.vae = pipe.vae.to(dtype=vae_dtype)
86
-
87
- # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
88
- # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
89
- # DDP optimizer step.
90
- if not is_training:
91
- pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
92
-
93
- if enable_slicing:
94
- pipe.vae.enable_slicing()
95
- if enable_tiling:
96
- pipe.vae.enable_tiling()
97
-
98
- if enable_model_cpu_offload:
99
- pipe.enable_model_cpu_offload(device=device)
100
- else:
101
- pipe.to(device=device)
102
-
103
- return pipe
104
-
105
-
106
- def prepare_conditions(
107
- tokenizer,
108
- text_encoder,
109
- prompt: Union[str, List[str]],
110
- device: Optional[torch.device] = None,
111
- dtype: Optional[torch.dtype] = None,
112
- max_sequence_length: int = 226, # TODO: this should be configurable
113
- **kwargs,
114
- ):
115
- device = device or text_encoder.device
116
- dtype = dtype or text_encoder.dtype
117
- return _get_t5_prompt_embeds(
118
- tokenizer=tokenizer,
119
- text_encoder=text_encoder,
120
- prompt=prompt,
121
- max_sequence_length=max_sequence_length,
122
- device=device,
123
- dtype=dtype,
124
- )
125
-
126
-
127
- def prepare_latents(
128
- vae: AutoencoderKLCogVideoX,
129
- image_or_video: torch.Tensor,
130
- device: Optional[torch.device] = None,
131
- dtype: Optional[torch.dtype] = None,
132
- generator: Optional[torch.Generator] = None,
133
- precompute: bool = False,
134
- **kwargs,
135
- ) -> torch.Tensor:
136
- device = device or vae.device
137
- dtype = dtype or vae.dtype
138
-
139
- if image_or_video.ndim == 4:
140
- image_or_video = image_or_video.unsqueeze(2)
141
- assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
142
-
143
- image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
144
- image_or_video = image_or_video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
145
- if not precompute:
146
- latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
147
- if not vae.config.invert_scale_latents:
148
- latents = latents * vae.config.scaling_factor
149
- # For training Cog 1.5, we don't need to handle the scaling factor here.
150
- # The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents
151
- # is probably only needed for image-to-video training.
152
- # TODO(aryan): investigate this
153
- # else:
154
- # latents = 1 / vae.config.scaling_factor * latents
155
- latents = latents.to(dtype=dtype)
156
- return {"latents": latents}
157
- else:
158
- # handle vae scaling in the `train()` method directly.
159
- if vae.use_slicing and image_or_video.shape[0] > 1:
160
- encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
161
- h = torch.cat(encoded_slices)
162
- else:
163
- h = vae._encode(image_or_video)
164
- return {"latents": h}
165
-
166
-
167
- def post_latent_preparation(
168
- vae_config: Dict[str, Any], latents: torch.Tensor, patch_size_t: Optional[int] = None, **kwargs
169
- ) -> torch.Tensor:
170
- if not vae_config.invert_scale_latents:
171
- latents = latents * vae_config.scaling_factor
172
- # For training Cog 1.5, we don't need to handle the scaling factor here.
173
- # The CogVideoX team forgot to multiply here, so we should not do it too. Invert scale latents
174
- # is probably only needed for image-to-video training.
175
- # TODO(aryan): investigate this
176
- # else:
177
- # latents = 1 / vae_config.scaling_factor * latents
178
- latents = _pad_frames(latents, patch_size_t)
179
- latents = latents.permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
180
- return {"latents": latents}
181
-
182
-
183
- def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
184
- return {
185
- "prompts": [x["prompt"] for x in batch[0]],
186
- "videos": torch.stack([x["video"] for x in batch[0]]),
187
- }
188
-
189
-
190
- def calculate_noisy_latents(
191
- scheduler: CogVideoXDDIMScheduler,
192
- noise: torch.Tensor,
193
- latents: torch.Tensor,
194
- timesteps: torch.LongTensor,
195
- ) -> torch.Tensor:
196
- noisy_latents = scheduler.add_noise(latents, noise, timesteps)
197
- return noisy_latents
198
-
199
-
200
- def forward_pass(
201
- transformer: CogVideoXTransformer3DModel,
202
- scheduler: CogVideoXDDIMScheduler,
203
- prompt_embeds: torch.Tensor,
204
- latents: torch.Tensor,
205
- noisy_latents: torch.Tensor,
206
- timesteps: torch.LongTensor,
207
- ofs_emb: Optional[torch.Tensor] = None,
208
- **kwargs,
209
- ) -> torch.Tensor:
210
- # Just hardcode for now. In Diffusers, we will refactor such that RoPE would be handled within the model itself.
211
- VAE_SPATIAL_SCALE_FACTOR = 8
212
- transformer_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
213
- batch_size, num_frames, num_channels, height, width = noisy_latents.shape
214
- rope_base_height = transformer_config.sample_height * VAE_SPATIAL_SCALE_FACTOR
215
- rope_base_width = transformer_config.sample_width * VAE_SPATIAL_SCALE_FACTOR
216
-
217
- image_rotary_emb = (
218
- prepare_rotary_positional_embeddings(
219
- height=height * VAE_SPATIAL_SCALE_FACTOR,
220
- width=width * VAE_SPATIAL_SCALE_FACTOR,
221
- num_frames=num_frames,
222
- vae_scale_factor_spatial=VAE_SPATIAL_SCALE_FACTOR,
223
- patch_size=transformer_config.patch_size,
224
- patch_size_t=transformer_config.patch_size_t if hasattr(transformer_config, "patch_size_t") else None,
225
- attention_head_dim=transformer_config.attention_head_dim,
226
- device=transformer.device,
227
- base_height=rope_base_height,
228
- base_width=rope_base_width,
229
- )
230
- if transformer_config.use_rotary_positional_embeddings
231
- else None
232
- )
233
- ofs_emb = None if transformer_config.ofs_embed_dim is None else latents.new_full((batch_size,), fill_value=2.0)
234
-
235
- velocity = transformer(
236
- hidden_states=noisy_latents,
237
- timestep=timesteps,
238
- encoder_hidden_states=prompt_embeds,
239
- ofs=ofs_emb,
240
- image_rotary_emb=image_rotary_emb,
241
- return_dict=False,
242
- )[0]
243
- # For CogVideoX, the transformer predicts the velocity. The denoised output is calculated by applying the same
244
- # code paths as scheduler.get_velocity(), which can be confusing to understand.
245
- denoised_latents = scheduler.get_velocity(velocity, noisy_latents, timesteps)
246
-
247
- return {"latents": denoised_latents}
248
-
249
-
250
- def validation(
251
- pipeline: CogVideoXPipeline,
252
- prompt: str,
253
- image: Optional[Image.Image] = None,
254
- video: Optional[List[Image.Image]] = None,
255
- height: Optional[int] = None,
256
- width: Optional[int] = None,
257
- num_frames: Optional[int] = None,
258
- num_videos_per_prompt: int = 1,
259
- generator: Optional[torch.Generator] = None,
260
- **kwargs,
261
- ):
262
- generation_kwargs = {
263
- "prompt": prompt,
264
- "height": height,
265
- "width": width,
266
- "num_frames": num_frames,
267
- "num_videos_per_prompt": num_videos_per_prompt,
268
- "generator": generator,
269
- "return_dict": True,
270
- "output_type": "pil",
271
- }
272
- generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
273
- output = pipeline(**generation_kwargs).frames[0]
274
- return [("video", output)]
275
-
276
-
277
- def _get_t5_prompt_embeds(
278
- tokenizer: T5Tokenizer,
279
- text_encoder: T5EncoderModel,
280
- prompt: Union[str, List[str]] = None,
281
- max_sequence_length: int = 226,
282
- device: Optional[torch.device] = None,
283
- dtype: Optional[torch.dtype] = None,
284
- ):
285
- prompt = [prompt] if isinstance(prompt, str) else prompt
286
-
287
- text_inputs = tokenizer(
288
- prompt,
289
- padding="max_length",
290
- max_length=max_sequence_length,
291
- truncation=True,
292
- add_special_tokens=True,
293
- return_tensors="pt",
294
- )
295
- text_input_ids = text_inputs.input_ids
296
-
297
- prompt_embeds = text_encoder(text_input_ids.to(device))[0]
298
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
299
-
300
- return {"prompt_embeds": prompt_embeds}
301
-
302
-
303
- def _pad_frames(latents: torch.Tensor, patch_size_t: int):
304
- if patch_size_t is None or patch_size_t == 1:
305
- return latents
306
-
307
- # `latents` should be of the following format: [B, C, F, H, W].
308
- # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
309
- latent_num_frames = latents.shape[2]
310
- additional_frames = patch_size_t - latent_num_frames % patch_size_t
311
-
312
- if additional_frames > 0:
313
- last_frame = latents[:, :, -1:, :, :]
314
- padding_frames = last_frame.repeat(1, 1, additional_frames, 1, 1)
315
- latents = torch.cat([latents, padding_frames], dim=2)
316
-
317
- return latents
318
-
319
-
320
- # TODO(aryan): refactor into model specs for better re-use
321
- COGVIDEOX_T2V_LORA_CONFIG = {
322
- "pipeline_cls": CogVideoXPipeline,
323
- "load_condition_models": load_condition_models,
324
- "load_latent_models": load_latent_models,
325
- "load_diffusion_models": load_diffusion_models,
326
- "initialize_pipeline": initialize_pipeline,
327
- "prepare_conditions": prepare_conditions,
328
- "prepare_latents": prepare_latents,
329
- "post_latent_preparation": post_latent_preparation,
330
- "collate_fn": collate_fn_t2v,
331
- "calculate_noisy_latents": calculate_noisy_latents,
332
- "forward_pass": forward_pass,
333
- "validation": validation,
334
- }
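For contrast, the pre-upgrade flow (removed in this commit) selected these function-style configs through the get_config_from_model_name registry that is also deleted above; a sketch of that older call pattern, assuming the pre-upgrade package layout:

# Pre-upgrade flow: model configs were plain dicts of callables.
from finetrainers.models import get_config_from_model_name

config = get_config_from_model_name("cogvideox", training_type="lora")

components = {}
components.update(config["load_condition_models"]())   # tokenizer, text_encoder
components.update(config["load_latent_models"]())      # vae
components.update(config["load_diffusion_models"]())   # transformer, scheduler

pipe = config["initialize_pipeline"](**components, is_training=False)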
finetrainers/models/hunyuan_video/__init__.py CHANGED
@@ -1,2 +1 @@
- from .full_finetune import HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG
- from .lora import HUNYUAN_VIDEO_T2V_LORA_CONFIG

+ from .base_specification import HunyuanVideoModelSpecification
 
finetrainers/models/hunyuan_video/base_specification.py ADDED
@@ -0,0 +1,413 @@
1
+ import os
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ from accelerate import init_empty_weights
6
+ from diffusers import (
7
+ AutoencoderKLHunyuanVideo,
8
+ FlowMatchEulerDiscreteScheduler,
9
+ HunyuanVideoPipeline,
10
+ HunyuanVideoTransformer3DModel,
11
+ )
12
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
13
+ from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel
14
+
15
+ from ... import data
16
+ from ... import functional as FF
17
+ from ...logging import get_logger
18
+ from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin
19
+ from ...typing import ArtifactType, SchedulerType
20
+ from ...utils import get_non_null_items
21
+ from ..modeling_utils import ModelSpecification
22
+
23
+
24
+ logger = get_logger()
25
+
26
+
27
+ class HunyuanLatentEncodeProcessor(ProcessorMixin):
28
+ r"""
29
+ Processor to encode image/video into latents using the HunyuanVideo VAE.
30
+
31
+ Args:
32
+ output_names (`List[str]`):
33
+ The names of the outputs that the processor returns. The outputs are in the following order:
34
+ - latents: The latents of the input image/video.
35
+ """
36
+
37
+ def __init__(self, output_names: List[str]):
38
+ super().__init__()
39
+ self.output_names = output_names
40
+ assert len(self.output_names) == 1
41
+
42
+ def forward(
43
+ self,
44
+ vae: AutoencoderKLHunyuanVideo,
45
+ image: Optional[torch.Tensor] = None,
46
+ video: Optional[torch.Tensor] = None,
47
+ generator: Optional[torch.Generator] = None,
48
+ compute_posterior: bool = True,
49
+ ) -> Dict[str, torch.Tensor]:
50
+ device = vae.device
51
+ dtype = vae.dtype
52
+
53
+ if image is not None:
54
+ video = image.unsqueeze(1)
55
+
56
+ assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
57
+ video = video.to(device=device, dtype=vae.dtype)
58
+ video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
59
+
60
+ if compute_posterior:
61
+ latents = vae.encode(video).latent_dist.sample(generator=generator)
62
+ latents = latents.to(dtype=dtype)
63
+ else:
64
+ if vae.use_slicing and video.shape[0] > 1:
65
+ encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
66
+ moments = torch.cat(encoded_slices)
67
+ else:
68
+ moments = vae._encode(video)
69
+ latents = moments.to(dtype=dtype)
70
+
71
+ return {self.output_names[0]: latents}
72
+
73
+
74
+ class HunyuanVideoModelSpecification(ModelSpecification):
75
+ def __init__(
76
+ self,
77
+ pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo",
78
+ tokenizer_id: Optional[str] = None,
79
+ text_encoder_id: Optional[str] = None,
80
+ transformer_id: Optional[str] = None,
81
+ vae_id: Optional[str] = None,
82
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
83
+ transformer_dtype: torch.dtype = torch.bfloat16,
84
+ vae_dtype: torch.dtype = torch.bfloat16,
85
+ revision: Optional[str] = None,
86
+ cache_dir: Optional[str] = None,
87
+ condition_model_processors: List[ProcessorMixin] = None,
88
+ latent_model_processors: List[ProcessorMixin] = None,
89
+ **kwargs,
90
+ ) -> None:
91
+ super().__init__(
92
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
93
+ tokenizer_id=tokenizer_id,
94
+ text_encoder_id=text_encoder_id,
95
+ transformer_id=transformer_id,
96
+ vae_id=vae_id,
97
+ text_encoder_dtype=text_encoder_dtype,
98
+ transformer_dtype=transformer_dtype,
99
+ vae_dtype=vae_dtype,
100
+ revision=revision,
101
+ cache_dir=cache_dir,
102
+ )
103
+
104
+ if condition_model_processors is None:
105
+ condition_model_processors = [
106
+ LlamaProcessor(["encoder_hidden_states", "encoder_attention_mask"]),
107
+ CLIPPooledProcessor(
108
+ ["pooled_projections"],
109
+ input_names={"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"},
110
+ ),
111
+ ]
112
+ if latent_model_processors is None:
113
+ latent_model_processors = [HunyuanLatentEncodeProcessor(["latents"])]
114
+
115
+ self.condition_model_processors = condition_model_processors
116
+ self.latent_model_processors = latent_model_processors
117
+
118
+ @property
119
+ def _resolution_dim_keys(self):
120
+ # TODO
121
+ return {
122
+ "latents": (2, 3, 4),
123
+ }
124
+
125
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
126
+ if self.tokenizer_id is not None:
127
+ tokenizer = AutoTokenizer.from_pretrained(
128
+ self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
129
+ )
130
+ else:
131
+ tokenizer = AutoTokenizer.from_pretrained(
132
+ self.pretrained_model_name_or_path,
133
+ subfolder="tokenizer",
134
+ revision=self.revision,
135
+ cache_dir=self.cache_dir,
136
+ )
137
+
138
+ if self.tokenizer_2_id is not None:
139
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
140
+ self.tokenizer_2_id, revision=self.revision, cache_dir=self.cache_dir
141
+ )
142
+ else:
143
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
144
+ self.pretrained_model_name_or_path,
145
+ subfolder="tokenizer_2",
146
+ revision=self.revision,
147
+ cache_dir=self.cache_dir,
148
+ )
149
+
150
+ if self.text_encoder_id is not None:
151
+ text_encoder = LlamaModel.from_pretrained(
152
+ self.text_encoder_id,
153
+ torch_dtype=self.text_encoder_dtype,
154
+ revision=self.revision,
155
+ cache_dir=self.cache_dir,
156
+ )
157
+ else:
158
+ text_encoder = LlamaModel.from_pretrained(
159
+ self.pretrained_model_name_or_path,
160
+ subfolder="text_encoder",
161
+ torch_dtype=self.text_encoder_dtype,
162
+ revision=self.revision,
163
+ cache_dir=self.cache_dir,
164
+ )
165
+
166
+ if self.text_encoder_2_id is not None:
167
+ text_encoder_2 = CLIPTextModel.from_pretrained(
168
+ self.text_encoder_2_id,
169
+ torch_dtype=self.text_encoder_2_dtype,
170
+ revision=self.revision,
171
+ cache_dir=self.cache_dir,
172
+ )
173
+ else:
174
+ text_encoder_2 = CLIPTextModel.from_pretrained(
175
+ self.pretrained_model_name_or_path,
176
+ subfolder="text_encoder_2",
177
+ torch_dtype=self.text_encoder_2_dtype,
178
+ revision=self.revision,
179
+ cache_dir=self.cache_dir,
180
+ )
181
+
182
+ return {
183
+ "tokenizer": tokenizer,
184
+ "tokenizer_2": tokenizer_2,
185
+ "text_encoder": text_encoder,
186
+ "text_encoder_2": text_encoder_2,
187
+ }
188
+
189
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
190
+ if self.vae_id is not None:
191
+ vae = AutoencoderKLHunyuanVideo.from_pretrained(
192
+ self.vae_id,
193
+ torch_dtype=self.vae_dtype,
194
+ revision=self.revision,
195
+ cache_dir=self.cache_dir,
196
+ )
197
+ else:
198
+ vae = AutoencoderKLHunyuanVideo.from_pretrained(
199
+ self.pretrained_model_name_or_path,
200
+ subfolder="vae",
201
+ torch_dtype=self.vae_dtype,
202
+ revision=self.revision,
203
+ cache_dir=self.cache_dir,
204
+ )
205
+
206
+ return {"vae": vae}
207
+
208
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
209
+ if self.transformer_id is not None:
210
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
211
+ self.transformer_id,
212
+ torch_dtype=self.transformer_dtype,
213
+ revision=self.revision,
214
+ cache_dir=self.cache_dir,
215
+ )
216
+ else:
217
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
218
+ self.pretrained_model_name_or_path,
219
+ subfolder="transformer",
220
+ torch_dtype=self.transformer_dtype,
221
+ revision=self.revision,
222
+ cache_dir=self.cache_dir,
223
+ )
224
+
225
+ scheduler = FlowMatchEulerDiscreteScheduler()
226
+
227
+ return {"transformer": transformer, "scheduler": scheduler}
228
+
229
+ def load_pipeline(
230
+ self,
231
+ tokenizer: Optional[AutoTokenizer] = None,
232
+ tokenizer_2: Optional[CLIPTokenizer] = None,
233
+ text_encoder: Optional[LlamaModel] = None,
234
+ text_encoder_2: Optional[CLIPTextModel] = None,
235
+ transformer: Optional[HunyuanVideoTransformer3DModel] = None,
236
+ vae: Optional[AutoencoderKLHunyuanVideo] = None,
237
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
238
+ enable_slicing: bool = False,
239
+ enable_tiling: bool = False,
240
+ enable_model_cpu_offload: bool = False,
241
+ training: bool = False,
242
+ **kwargs,
243
+ ) -> HunyuanVideoPipeline:
244
+ components = {
245
+ "tokenizer": tokenizer,
246
+ "tokenizer_2": tokenizer_2,
247
+ "text_encoder": text_encoder,
248
+ "text_encoder_2": text_encoder_2,
249
+ "transformer": transformer,
250
+ "vae": vae,
251
+ "scheduler": scheduler,
252
+ }
253
+ components = get_non_null_items(components)
254
+
255
+ pipe = HunyuanVideoPipeline.from_pretrained(
256
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
257
+ )
258
+ pipe.text_encoder.to(self.text_encoder_dtype)
259
+ pipe.text_encoder_2.to(self.text_encoder_2_dtype)
260
+ pipe.vae.to(self.vae_dtype)
261
+
262
+ if not training:
263
+ pipe.transformer.to(self.transformer_dtype)
264
+
265
+ if enable_slicing:
266
+ pipe.vae.enable_slicing()
267
+ if enable_tiling:
268
+ pipe.vae.enable_tiling()
269
+ if enable_model_cpu_offload:
270
+ pipe.enable_model_cpu_offload()
271
+
272
+ return pipe
273
+
274
+ @torch.no_grad()
275
+ def prepare_conditions(
276
+ self,
277
+ tokenizer: AutoTokenizer,
278
+ tokenizer_2: CLIPTokenizer,
279
+ text_encoder: LlamaModel,
280
+ text_encoder_2: CLIPTextModel,
281
+ caption: str,
282
+ max_sequence_length: int = 256,
283
+ **kwargs,
284
+ ) -> Dict[str, Any]:
285
+ conditions = {
286
+ "tokenizer": tokenizer,
287
+ "tokenizer_2": tokenizer_2,
288
+ "text_encoder": text_encoder,
289
+ "text_encoder_2": text_encoder_2,
290
+ "caption": caption,
291
+ "max_sequence_length": max_sequence_length,
292
+ **kwargs,
293
+ }
294
+ input_keys = set(conditions.keys())
295
+ conditions = super().prepare_conditions(**conditions)
296
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
297
+ return conditions
298
+
299
+ @torch.no_grad()
300
+ def prepare_latents(
301
+ self,
302
+ vae: AutoencoderKLHunyuanVideo,
303
+ image: Optional[torch.Tensor] = None,
304
+ video: Optional[torch.Tensor] = None,
305
+ generator: Optional[torch.Generator] = None,
306
+ compute_posterior: bool = True,
307
+ **kwargs,
308
+ ) -> Dict[str, torch.Tensor]:
309
+ conditions = {
310
+ "vae": vae,
311
+ "image": image,
312
+ "video": video,
313
+ "generator": generator,
314
+ "compute_posterior": compute_posterior,
315
+ **kwargs,
316
+ }
317
+ input_keys = set(conditions.keys())
318
+ conditions = super().prepare_latents(**conditions)
319
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
320
+ return conditions
321
+
322
+ def forward(
323
+ self,
324
+ transformer: HunyuanVideoTransformer3DModel,
325
+ condition_model_conditions: Dict[str, torch.Tensor],
326
+ latent_model_conditions: Dict[str, torch.Tensor],
327
+ sigmas: torch.Tensor,
328
+ guidance: float = 1.0,
329
+ generator: Optional[torch.Generator] = None,
330
+ compute_posterior: bool = True,
331
+ **kwargs,
332
+ ) -> Tuple[torch.Tensor, ...]:
333
+ if compute_posterior:
334
+ latents = latent_model_conditions.pop("latents")
335
+ else:
336
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
337
+ latents = posterior.sample(generator=generator)
338
+ del posterior
339
+
340
+ latents = latents * self.vae_config.scaling_factor
341
+ noise = torch.zeros_like(latents).normal_(generator=generator)
342
+ noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
343
+
344
+ timesteps = (sigmas.flatten() * 1000.0).long()
345
+ guidance = latents.new_full((latents.size(0),), fill_value=guidance) * 1000.0
346
+
347
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
348
+ latent_model_conditions["guidance"] = guidance
349
+
350
+ pred = transformer(
351
+ **latent_model_conditions,
352
+ **condition_model_conditions,
353
+ timestep=timesteps,
354
+ return_dict=False,
355
+ )[0]
356
+ target = FF.flow_match_target(noise, latents)
357
+
358
+ return pred, target, sigmas
359
+
360
+ def validation(
361
+ self,
362
+ pipeline: HunyuanVideoPipeline,
363
+ prompt: str,
364
+ height: Optional[int] = None,
365
+ width: Optional[int] = None,
366
+ num_frames: Optional[int] = None,
367
+ num_inference_steps: int = 50,
368
+ generator: Optional[torch.Generator] = None,
369
+ **kwargs,
370
+ ) -> List[ArtifactType]:
371
+ generation_kwargs = {
372
+ "prompt": prompt,
373
+ "height": height,
374
+ "width": width,
375
+ "num_frames": num_frames,
376
+ "num_inference_steps": num_inference_steps,
377
+ "generator": generator,
378
+ "return_dict": True,
379
+ "output_type": "pil",
380
+ }
381
+ generation_kwargs = get_non_null_items(generation_kwargs)
382
+ video = pipeline(**generation_kwargs).frames[0]
383
+ return [data.VideoArtifact(value=video)]
384
+
385
+ def _save_lora_weights(
386
+ self,
387
+ directory: str,
388
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
389
+ scheduler: Optional[SchedulerType] = None,
390
+ *args,
391
+ **kwargs,
392
+ ) -> None:
393
+ # TODO(aryan): this needs refactoring
394
+ if transformer_state_dict is not None:
395
+ HunyuanVideoPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
396
+ if scheduler is not None:
397
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
398
+
399
+ def _save_model(
400
+ self,
401
+ directory: str,
402
+ transformer: HunyuanVideoTransformer3DModel,
403
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
404
+ scheduler: Optional[SchedulerType] = None,
405
+ ) -> None:
406
+ # TODO(aryan): this needs refactoring
407
+ if transformer_state_dict is not None:
408
+ with init_empty_weights():
409
+ transformer_copy = HunyuanVideoTransformer3DModel.from_config(transformer.config)
410
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
411
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
412
+ if scheduler is not None:
413
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
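The forward pass above trains HunyuanVideo with a flow-matching objective via FF.flow_match_xt and FF.flow_match_target. The exact implementations live in finetrainers/functional (outside this excerpt); a small sketch of the standard rectified-flow formulas they are assumed to follow:

import torch

def flow_match_xt(x0: torch.Tensor, noise: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # Noisy sample at level sigma in [0, 1]: linear interpolation between clean latents and noise.
    # Broadcasting of `sigmas` against the latent shape is simplified here.
    return (1.0 - sigmas) * x0 + sigmas * noise

def flow_match_target(noise: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
    # Velocity target the transformer is trained to predict.
    return noise - x0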
finetrainers/models/hunyuan_video/full_finetune.py DELETED
@@ -1,30 +0,0 @@
1
- from diffusers import HunyuanVideoPipeline
2
-
3
- from .lora import (
4
- collate_fn_t2v,
5
- forward_pass,
6
- initialize_pipeline,
7
- load_condition_models,
8
- load_diffusion_models,
9
- load_latent_models,
10
- post_latent_preparation,
11
- prepare_conditions,
12
- prepare_latents,
13
- validation,
14
- )
15
-
16
-
17
- # TODO(aryan): refactor into model specs for better re-use
18
- HUNYUAN_VIDEO_T2V_FULL_FINETUNE_CONFIG = {
19
- "pipeline_cls": HunyuanVideoPipeline,
20
- "load_condition_models": load_condition_models,
21
- "load_latent_models": load_latent_models,
22
- "load_diffusion_models": load_diffusion_models,
23
- "initialize_pipeline": initialize_pipeline,
24
- "prepare_conditions": prepare_conditions,
25
- "prepare_latents": prepare_latents,
26
- "post_latent_preparation": post_latent_preparation,
27
- "collate_fn": collate_fn_t2v,
28
- "forward_pass": forward_pass,
29
- "validation": validation,
30
- }
finetrainers/models/hunyuan_video/lora.py DELETED
@@ -1,368 +0,0 @@
1
- from typing import Any, Dict, List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
- from accelerate.logging import get_logger
6
- from diffusers import (
7
- AutoencoderKLHunyuanVideo,
8
- FlowMatchEulerDiscreteScheduler,
9
- HunyuanVideoPipeline,
10
- HunyuanVideoTransformer3DModel,
11
- )
12
- from PIL import Image
13
- from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizer
14
-
15
-
16
- logger = get_logger("finetrainers") # pylint: disable=invalid-name
17
-
18
-
19
- def load_condition_models(
20
- model_id: str = "hunyuanvideo-community/HunyuanVideo",
21
- text_encoder_dtype: torch.dtype = torch.float16,
22
- text_encoder_2_dtype: torch.dtype = torch.float16,
23
- revision: Optional[str] = None,
24
- cache_dir: Optional[str] = None,
25
- **kwargs,
26
- ) -> Dict[str, nn.Module]:
27
- tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
28
- text_encoder = LlamaModel.from_pretrained(
29
- model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
30
- )
31
- tokenizer_2 = CLIPTokenizer.from_pretrained(
32
- model_id, subfolder="tokenizer_2", revision=revision, cache_dir=cache_dir
33
- )
34
- text_encoder_2 = CLIPTextModel.from_pretrained(
35
- model_id, subfolder="text_encoder_2", torch_dtype=text_encoder_2_dtype, revision=revision, cache_dir=cache_dir
36
- )
37
- return {
38
- "tokenizer": tokenizer,
39
- "text_encoder": text_encoder,
40
- "tokenizer_2": tokenizer_2,
41
- "text_encoder_2": text_encoder_2,
42
- }
43
-
44
-
45
- def load_latent_models(
46
- model_id: str = "hunyuanvideo-community/HunyuanVideo",
47
- vae_dtype: torch.dtype = torch.float16,
48
- revision: Optional[str] = None,
49
- cache_dir: Optional[str] = None,
50
- **kwargs,
51
- ) -> Dict[str, nn.Module]:
52
- vae = AutoencoderKLHunyuanVideo.from_pretrained(
53
- model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
54
- )
55
- return {"vae": vae}
56
-
57
-
58
- def load_diffusion_models(
59
- model_id: str = "hunyuanvideo-community/HunyuanVideo",
60
- transformer_dtype: torch.dtype = torch.bfloat16,
61
- shift: float = 1.0,
62
- revision: Optional[str] = None,
63
- cache_dir: Optional[str] = None,
64
- **kwargs,
65
- ) -> Dict[str, Union[nn.Module, FlowMatchEulerDiscreteScheduler]]:
66
- transformer = HunyuanVideoTransformer3DModel.from_pretrained(
67
- model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
68
- )
69
- scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
70
- return {"transformer": transformer, "scheduler": scheduler}
71
-
72
-
73
- def initialize_pipeline(
74
- model_id: str = "hunyuanvideo-community/HunyuanVideo",
75
- text_encoder_dtype: torch.dtype = torch.float16,
76
- text_encoder_2_dtype: torch.dtype = torch.float16,
77
- transformer_dtype: torch.dtype = torch.bfloat16,
78
- vae_dtype: torch.dtype = torch.float16,
79
- tokenizer: Optional[LlamaTokenizer] = None,
80
- text_encoder: Optional[LlamaModel] = None,
81
- tokenizer_2: Optional[CLIPTokenizer] = None,
82
- text_encoder_2: Optional[CLIPTextModel] = None,
83
- transformer: Optional[HunyuanVideoTransformer3DModel] = None,
84
- vae: Optional[AutoencoderKLHunyuanVideo] = None,
85
- scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
86
- device: Optional[torch.device] = None,
87
- revision: Optional[str] = None,
88
- cache_dir: Optional[str] = None,
89
- enable_slicing: bool = False,
90
- enable_tiling: bool = False,
91
- enable_model_cpu_offload: bool = False,
92
- is_training: bool = False,
93
- **kwargs,
94
- ) -> HunyuanVideoPipeline:
95
- component_name_pairs = [
96
- ("tokenizer", tokenizer),
97
- ("text_encoder", text_encoder),
98
- ("tokenizer_2", tokenizer_2),
99
- ("text_encoder_2", text_encoder_2),
100
- ("transformer", transformer),
101
- ("vae", vae),
102
- ("scheduler", scheduler),
103
- ]
104
- components = {}
105
- for name, component in component_name_pairs:
106
- if component is not None:
107
- components[name] = component
108
-
109
- pipe = HunyuanVideoPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
110
- pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
111
- pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=text_encoder_2_dtype)
112
- pipe.vae = pipe.vae.to(dtype=vae_dtype)
113
-
114
- # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
115
- # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
116
- # DDP optimizer step.
117
- if not is_training:
118
- pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
119
-
120
- if enable_slicing:
121
- pipe.vae.enable_slicing()
122
- if enable_tiling:
123
- pipe.vae.enable_tiling()
124
-
125
- if enable_model_cpu_offload:
126
- pipe.enable_model_cpu_offload(device=device)
127
- else:
128
- pipe.to(device=device)
129
-
130
- return pipe
131
-
132
-
133
- def prepare_conditions(
134
- tokenizer: LlamaTokenizer,
135
- text_encoder: LlamaModel,
136
- tokenizer_2: CLIPTokenizer,
137
- text_encoder_2: CLIPTextModel,
138
- prompt: Union[str, List[str]],
139
- guidance: float = 1.0,
140
- device: Optional[torch.device] = None,
141
- dtype: Optional[torch.dtype] = None,
142
- max_sequence_length: int = 256,
143
- # TODO(aryan): make configurable
144
- prompt_template: Dict[str, Any] = {
145
- "template": (
146
- "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
147
- "1. The main content and theme of the video."
148
- "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
149
- "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
150
- "4. background environment, light, style and atmosphere."
151
- "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
152
- "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
153
- ),
154
- "crop_start": 95,
155
- },
156
- **kwargs,
157
- ) -> torch.Tensor:
158
- device = device or text_encoder.device
159
- dtype = dtype or text_encoder.dtype
160
-
161
- if isinstance(prompt, str):
162
- prompt = [prompt]
163
-
164
- conditions = {}
165
- conditions.update(
166
- _get_llama_prompt_embeds(tokenizer, text_encoder, prompt, prompt_template, device, dtype, max_sequence_length)
167
- )
168
- conditions.update(_get_clip_prompt_embeds(tokenizer_2, text_encoder_2, prompt, device, dtype))
169
-
170
- guidance = torch.tensor([guidance], device=device, dtype=dtype) * 1000.0
171
- conditions["guidance"] = guidance
172
-
173
- return conditions
174
-
175
-
176
- def prepare_latents(
177
- vae: AutoencoderKLHunyuanVideo,
178
- image_or_video: torch.Tensor,
179
- device: Optional[torch.device] = None,
180
- dtype: Optional[torch.dtype] = None,
181
- generator: Optional[torch.Generator] = None,
182
- precompute: bool = False,
183
- **kwargs,
184
- ) -> torch.Tensor:
185
- device = device or vae.device
186
- dtype = dtype or vae.dtype
187
-
188
- if image_or_video.ndim == 4:
189
- image_or_video = image_or_video.unsqueeze(2)
190
- assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
191
-
192
- image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
193
- image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W]
194
- if not precompute:
195
- latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
196
- latents = latents * vae.config.scaling_factor
197
- latents = latents.to(dtype=dtype)
198
- return {"latents": latents}
199
- else:
200
- if vae.use_slicing and image_or_video.shape[0] > 1:
201
- encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
202
- h = torch.cat(encoded_slices)
203
- else:
204
- h = vae._encode(image_or_video)
205
- return {"latents": h}
206
-
207
-
208
- def post_latent_preparation(
209
- vae_config: Dict[str, Any],
210
- latents: torch.Tensor,
211
- **kwargs,
212
- ) -> torch.Tensor:
213
- latents = latents * vae_config.scaling_factor
214
- return {"latents": latents}
215
-
216
-
217
- def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
218
- return {
219
- "prompts": [x["prompt"] for x in batch[0]],
220
- "videos": torch.stack([x["video"] for x in batch[0]]),
221
- }
222
-
223
-
224
- def forward_pass(
225
- transformer: HunyuanVideoTransformer3DModel,
226
- prompt_embeds: torch.Tensor,
227
- pooled_prompt_embeds: torch.Tensor,
228
- prompt_attention_mask: torch.Tensor,
229
- guidance: torch.Tensor,
230
- latents: torch.Tensor,
231
- noisy_latents: torch.Tensor,
232
- timesteps: torch.LongTensor,
233
- **kwargs,
234
- ) -> torch.Tensor:
235
- denoised_latents = transformer(
236
- hidden_states=noisy_latents,
237
- timestep=timesteps,
238
- encoder_hidden_states=prompt_embeds,
239
- pooled_projections=pooled_prompt_embeds,
240
- encoder_attention_mask=prompt_attention_mask,
241
- guidance=guidance,
242
- return_dict=False,
243
- )[0]
244
-
245
- return {"latents": denoised_latents}
246
-
247
-
248
- def validation(
249
- pipeline: HunyuanVideoPipeline,
250
- prompt: str,
251
- image: Optional[Image.Image] = None,
252
- video: Optional[List[Image.Image]] = None,
253
- height: Optional[int] = None,
254
- width: Optional[int] = None,
255
- num_frames: Optional[int] = None,
256
- num_videos_per_prompt: int = 1,
257
- generator: Optional[torch.Generator] = None,
258
- **kwargs,
259
- ):
260
- generation_kwargs = {
261
- "prompt": prompt,
262
- "height": height,
263
- "width": width,
264
- "num_frames": num_frames,
265
- "num_inference_steps": 30,
266
- "num_videos_per_prompt": num_videos_per_prompt,
267
- "generator": generator,
268
- "return_dict": True,
269
- "output_type": "pil",
270
- }
271
- generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
272
- output = pipeline(**generation_kwargs).frames[0]
273
- return [("video", output)]
274
-
275
-
276
- def _get_llama_prompt_embeds(
277
- tokenizer: LlamaTokenizer,
278
- text_encoder: LlamaModel,
279
- prompt: List[str],
280
- prompt_template: Dict[str, Any],
281
- device: torch.device,
282
- dtype: torch.dtype,
283
- max_sequence_length: int = 256,
284
- num_hidden_layers_to_skip: int = 2,
285
- ) -> Tuple[torch.Tensor, torch.Tensor]:
286
- batch_size = len(prompt)
287
- prompt = [prompt_template["template"].format(p) for p in prompt]
288
-
289
- crop_start = prompt_template.get("crop_start", None)
290
- if crop_start is None:
291
- prompt_template_input = tokenizer(
292
- prompt_template["template"],
293
- padding="max_length",
294
- return_tensors="pt",
295
- return_length=False,
296
- return_overflowing_tokens=False,
297
- return_attention_mask=False,
298
- )
299
- crop_start = prompt_template_input["input_ids"].shape[-1]
300
- # Remove <|eot_id|> token and placeholder {}
301
- crop_start -= 2
302
-
303
- max_sequence_length += crop_start
304
- text_inputs = tokenizer(
305
- prompt,
306
- max_length=max_sequence_length,
307
- padding="max_length",
308
- truncation=True,
309
- return_tensors="pt",
310
- return_length=False,
311
- return_overflowing_tokens=False,
312
- return_attention_mask=True,
313
- )
314
- text_input_ids = text_inputs.input_ids.to(device=device)
315
- prompt_attention_mask = text_inputs.attention_mask.to(device=device)
316
-
317
- prompt_embeds = text_encoder(
318
- input_ids=text_input_ids,
319
- attention_mask=prompt_attention_mask,
320
- output_hidden_states=True,
321
- ).hidden_states[-(num_hidden_layers_to_skip + 1)]
322
- prompt_embeds = prompt_embeds.to(dtype=dtype)
323
-
324
- if crop_start is not None and crop_start > 0:
325
- prompt_embeds = prompt_embeds[:, crop_start:]
326
- prompt_attention_mask = prompt_attention_mask[:, crop_start:]
327
-
328
- prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
329
-
330
- return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}
331
-
332
-
333
- def _get_clip_prompt_embeds(
334
- tokenizer_2: CLIPTokenizer,
335
- text_encoder_2: CLIPTextModel,
336
- prompt: Union[str, List[str]],
337
- device: torch.device,
338
- dtype: torch.dtype,
339
- max_sequence_length: int = 77,
340
- ) -> torch.Tensor:
341
- text_inputs = tokenizer_2(
342
- prompt,
343
- padding="max_length",
344
- max_length=max_sequence_length,
345
- truncation=True,
346
- return_tensors="pt",
347
- )
348
-
349
- prompt_embeds = text_encoder_2(text_inputs.input_ids.to(device), output_hidden_states=False).pooler_output
350
- prompt_embeds = prompt_embeds.to(dtype=dtype)
351
-
352
- return {"pooled_prompt_embeds": prompt_embeds}
353
-
354
-
355
- # TODO(aryan): refactor into model specs for better re-use
356
- HUNYUAN_VIDEO_T2V_LORA_CONFIG = {
357
- "pipeline_cls": HunyuanVideoPipeline,
358
- "load_condition_models": load_condition_models,
359
- "load_latent_models": load_latent_models,
360
- "load_diffusion_models": load_diffusion_models,
361
- "initialize_pipeline": initialize_pipeline,
362
- "prepare_conditions": prepare_conditions,
363
- "prepare_latents": prepare_latents,
364
- "post_latent_preparation": post_latent_preparation,
365
- "collate_fn": collate_fn_t2v,
366
- "forward_pass": forward_pass,
367
- "validation": validation,
368
- }
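
Note on the crop_start logic in the deleted _get_llama_prompt_embeds above: the Llama text encoder is fed the full chat-style template, and the first crop_start tokens (the fixed system prefix, 95 tokens for the template in this file) are then sliced off the hidden states and the attention mask so that only the user caption conditions the transformer. A minimal sketch of that slicing with dummy tensors (an illustration under those assumptions, not code from this commit):

import torch

crop_start = 95                      # assumed prefix length, matching the template in this file
max_sequence_length = 256
batch_size, hidden_dim = 2, 16       # dummy sizes; the real hidden_dim is the Llama hidden size

# The tokenizer pads/truncates to max_sequence_length + crop_start before encoding.
seq_len = max_sequence_length + crop_start
prompt_embeds = torch.randn(batch_size, seq_len, hidden_dim)
prompt_attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)

# Drop the embeddings of the fixed template prefix; only the caption tokens (and padding) remain.
prompt_embeds = prompt_embeds[:, crop_start:]
prompt_attention_mask = prompt_attention_mask[:, crop_start:].view(batch_size, -1)

assert prompt_embeds.shape[1] == max_sequence_length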
 
finetrainers/models/ltx_video/__init__.py CHANGED
@@ -1,2 +1 @@
1
- from .full_finetune import LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG
2
- from .lora import LTX_VIDEO_T2V_LORA_CONFIG
 
1
+ from .base_specification import LTXVideoModelSpecification
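
With this change, downstream code imports a model specification class instead of the old config dictionaries. A minimal sketch of the new usage (the trainer wiring and entry point are assumptions based on the exports in finetrainers/__init__.py; this diff does not show them end to end):

from finetrainers import BaseArgs, SFTTrainer
from finetrainers.models.ltx_video import LTXVideoModelSpecification

# Previously: from finetrainers.models.ltx_video import LTX_VIDEO_T2V_LORA_CONFIG
args = BaseArgs()  # assumed to be populated from CLI flags in a real training script
model_specification = LTXVideoModelSpecification(
    pretrained_model_name_or_path="Lightricks/LTX-Video",
)
trainer = SFTTrainer(args, model_specification)  # assumed constructor signature
trainer.run()  # hypothetical entry point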
 
finetrainers/models/ltx_video/base_specification.py ADDED
@@ -0,0 +1,522 @@
 
1
+ import os
2
+ import random
3
+ from typing import Any, Dict, List, Optional, Tuple
4
+
5
+ import torch
6
+ from accelerate import init_empty_weights
7
+ from diffusers import (
8
+ AutoencoderKLLTXVideo,
9
+ FlowMatchEulerDiscreteScheduler,
10
+ LTXImageToVideoPipeline,
11
+ LTXPipeline,
12
+ LTXVideoTransformer3DModel,
13
+ )
14
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
15
+ from PIL.Image import Image
16
+ from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
17
+
18
+ from ... import data
19
+ from ... import functional as FF
20
+ from ...logging import get_logger
21
+ from ...parallel import ParallelBackendEnum
22
+ from ...processors import ProcessorMixin, T5Processor
23
+ from ...typing import ArtifactType, SchedulerType
24
+ from ...utils import get_non_null_items
25
+ from ..modeling_utils import ModelSpecification
26
+
27
+
28
+ logger = get_logger()
29
+
30
+
31
+ class LTXLatentEncodeProcessor(ProcessorMixin):
32
+ r"""
33
+ Processor to encode image/video into latents using the LTX VAE.
34
+
35
+ Args:
36
+ output_names (`List[str]`):
37
+ The names of the outputs that the processor returns. The outputs are in the following order:
38
+ - latents: The latents of the input image/video.
39
+ - num_frames: The number of frames in the input video.
40
+ - height: The height of the input image/video.
41
+ - width: The width of the input image/video.
42
+ - latents_mean: The latent channel means from the VAE state dict.
43
+ - latents_std: The latent channel standard deviations from the VAE state dict.
44
+ """
45
+
46
+ def __init__(self, output_names: List[str]):
47
+ super().__init__()
48
+ self.output_names = output_names
49
+ assert len(self.output_names) == 6
50
+
51
+ def forward(
52
+ self,
53
+ vae: AutoencoderKLLTXVideo,
54
+ image: Optional[torch.Tensor] = None,
55
+ video: Optional[torch.Tensor] = None,
56
+ generator: Optional[torch.Generator] = None,
57
+ compute_posterior: bool = True,
58
+ ) -> Dict[str, torch.Tensor]:
59
+ device = vae.device
60
+ dtype = vae.dtype
61
+
62
+ if image is not None:
63
+ video = image.unsqueeze(1)
64
+
65
+ assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
66
+ video = video.to(device=device, dtype=vae.dtype)
67
+ video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
68
+
69
+ if compute_posterior:
70
+ latents = vae.encode(video).latent_dist.sample(generator=generator)
71
+ latents = latents.to(dtype=dtype)
72
+ else:
73
+ if vae.use_slicing and video.shape[0] > 1:
74
+ encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
75
+ moments = torch.cat(encoded_slices)
76
+ else:
77
+ moments = vae._encode(video)
78
+ latents = moments.to(dtype=dtype)
79
+
80
+ _, _, num_frames, height, width = latents.shape
81
+
82
+ return {
83
+ self.output_names[0]: latents,
84
+ self.output_names[1]: num_frames,
85
+ self.output_names[2]: height,
86
+ self.output_names[3]: width,
87
+ self.output_names[4]: vae.latents_mean,
88
+ self.output_names[5]: vae.latents_std,
89
+ }
90
+
91
+
92
+ class LTXVideoModelSpecification(ModelSpecification):
93
+ def __init__(
94
+ self,
95
+ pretrained_model_name_or_path: str = "Lightricks/LTX-Video",
96
+ tokenizer_id: Optional[str] = None,
97
+ text_encoder_id: Optional[str] = None,
98
+ transformer_id: Optional[str] = None,
99
+ vae_id: Optional[str] = None,
100
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
101
+ transformer_dtype: torch.dtype = torch.bfloat16,
102
+ vae_dtype: torch.dtype = torch.bfloat16,
103
+ revision: Optional[str] = None,
104
+ cache_dir: Optional[str] = None,
105
+ condition_model_processors: List[ProcessorMixin] = None,
106
+ latent_model_processors: List[ProcessorMixin] = None,
107
+ **kwargs,
108
+ ) -> None:
109
+ super().__init__(
110
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
111
+ tokenizer_id=tokenizer_id,
112
+ text_encoder_id=text_encoder_id,
113
+ transformer_id=transformer_id,
114
+ vae_id=vae_id,
115
+ text_encoder_dtype=text_encoder_dtype,
116
+ transformer_dtype=transformer_dtype,
117
+ vae_dtype=vae_dtype,
118
+ revision=revision,
119
+ cache_dir=cache_dir,
120
+ )
121
+
122
+ if condition_model_processors is None:
123
+ condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])]
124
+ if latent_model_processors is None:
125
+ latent_model_processors = [
126
+ LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
127
+ ]
128
+
129
+ self.condition_model_processors = condition_model_processors
130
+ self.latent_model_processors = latent_model_processors
131
+
132
+ @property
133
+ def _resolution_dim_keys(self):
134
+ return {
135
+ "latents": (2, 3, 4),
136
+ }
137
+
138
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
139
+ if self.tokenizer_id is not None:
140
+ tokenizer = AutoTokenizer.from_pretrained(
141
+ self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
142
+ )
143
+ else:
144
+ tokenizer = T5Tokenizer.from_pretrained(
145
+ self.pretrained_model_name_or_path,
146
+ subfolder="tokenizer",
147
+ revision=self.revision,
148
+ cache_dir=self.cache_dir,
149
+ )
150
+
151
+ if self.text_encoder_id is not None:
152
+ text_encoder = AutoModel.from_pretrained(
153
+ self.text_encoder_id,
154
+ torch_dtype=self.text_encoder_dtype,
155
+ revision=self.revision,
156
+ cache_dir=self.cache_dir,
157
+ )
158
+ else:
159
+ text_encoder = T5EncoderModel.from_pretrained(
160
+ self.pretrained_model_name_or_path,
161
+ subfolder="text_encoder",
162
+ torch_dtype=self.text_encoder_dtype,
163
+ revision=self.revision,
164
+ cache_dir=self.cache_dir,
165
+ )
166
+
167
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
168
+
169
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
170
+ if self.vae_id is not None:
171
+ vae = AutoencoderKLLTXVideo.from_pretrained(
172
+ self.vae_id,
173
+ torch_dtype=self.vae_dtype,
174
+ revision=self.revision,
175
+ cache_dir=self.cache_dir,
176
+ )
177
+ else:
178
+ vae = AutoencoderKLLTXVideo.from_pretrained(
179
+ self.pretrained_model_name_or_path,
180
+ subfolder="vae",
181
+ torch_dtype=self.vae_dtype,
182
+ revision=self.revision,
183
+ cache_dir=self.cache_dir,
184
+ )
185
+
186
+ return {"vae": vae}
187
+
188
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
189
+ if self.transformer_id is not None:
190
+ transformer = LTXVideoTransformer3DModel.from_pretrained(
191
+ self.transformer_id,
192
+ torch_dtype=self.transformer_dtype,
193
+ revision=self.revision,
194
+ cache_dir=self.cache_dir,
195
+ )
196
+ else:
197
+ transformer = LTXVideoTransformer3DModel.from_pretrained(
198
+ self.pretrained_model_name_or_path,
199
+ subfolder="transformer",
200
+ torch_dtype=self.transformer_dtype,
201
+ revision=self.revision,
202
+ cache_dir=self.cache_dir,
203
+ )
204
+
205
+ scheduler = FlowMatchEulerDiscreteScheduler()
206
+
207
+ return {"transformer": transformer, "scheduler": scheduler}
208
+
209
+ def load_pipeline(
210
+ self,
211
+ tokenizer: Optional[T5Tokenizer] = None,
212
+ text_encoder: Optional[T5EncoderModel] = None,
213
+ transformer: Optional[LTXVideoTransformer3DModel] = None,
214
+ vae: Optional[AutoencoderKLLTXVideo] = None,
215
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
216
+ enable_slicing: bool = False,
217
+ enable_tiling: bool = False,
218
+ enable_model_cpu_offload: bool = False,
219
+ training: bool = False,
220
+ **kwargs,
221
+ ) -> LTXPipeline:
222
+ components = {
223
+ "tokenizer": tokenizer,
224
+ "text_encoder": text_encoder,
225
+ "transformer": transformer,
226
+ "vae": vae,
227
+ "scheduler": scheduler,
228
+ }
229
+ components = get_non_null_items(components)
230
+
231
+ pipe = LTXPipeline.from_pretrained(
232
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
233
+ )
234
+ pipe.text_encoder.to(self.text_encoder_dtype)
235
+ pipe.vae.to(self.vae_dtype)
236
+
237
+ if not training:
238
+ pipe.transformer.to(self.transformer_dtype)
239
+
240
+ if enable_slicing:
241
+ pipe.vae.enable_slicing()
242
+ if enable_tiling:
243
+ pipe.vae.enable_tiling()
244
+ if enable_model_cpu_offload:
245
+ pipe.enable_model_cpu_offload()
246
+
247
+ return pipe
248
+
249
+ @torch.no_grad()
250
+ def prepare_conditions(
251
+ self,
252
+ tokenizer: T5Tokenizer,
253
+ text_encoder: T5EncoderModel,
254
+ caption: str,
255
+ max_sequence_length: int = 128,
256
+ **kwargs,
257
+ ) -> Dict[str, Any]:
258
+ conditions = {
259
+ "tokenizer": tokenizer,
260
+ "text_encoder": text_encoder,
261
+ "caption": caption,
262
+ "max_sequence_length": max_sequence_length,
263
+ **kwargs,
264
+ }
265
+ input_keys = set(conditions.keys())
266
+ conditions = super().prepare_conditions(**conditions)
267
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
268
+ return conditions
269
+
270
+ @torch.no_grad()
271
+ def prepare_latents(
272
+ self,
273
+ vae: AutoencoderKLLTXVideo,
274
+ image: Optional[torch.Tensor] = None,
275
+ video: Optional[torch.Tensor] = None,
276
+ generator: Optional[torch.Generator] = None,
277
+ compute_posterior: bool = True,
278
+ **kwargs,
279
+ ) -> Dict[str, torch.Tensor]:
280
+ conditions = {
281
+ "vae": vae,
282
+ "image": image,
283
+ "video": video,
284
+ "generator": generator,
285
+ "compute_posterior": compute_posterior,
286
+ **kwargs,
287
+ }
288
+ input_keys = set(conditions.keys())
289
+ conditions = super().prepare_latents(**conditions)
290
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
291
+ return conditions
292
+
293
+ def forward(
294
+ self,
295
+ transformer: LTXVideoTransformer3DModel,
296
+ condition_model_conditions: Dict[str, torch.Tensor],
297
+ latent_model_conditions: Dict[str, torch.Tensor],
298
+ sigmas: torch.Tensor,
299
+ generator: Optional[torch.Generator] = None,
300
+ compute_posterior: bool = True,
301
+ **kwargs,
302
+ ) -> Tuple[torch.Tensor, ...]:
303
+ # TODO(aryan): make this configurable? Should it be?
304
+ first_frame_conditioning_p = 0.1
305
+ min_first_frame_sigma = 0.25
306
+
307
+ if compute_posterior:
308
+ latents = latent_model_conditions.pop("latents")
309
+ else:
310
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
311
+ latents = posterior.sample(generator=generator)
312
+ del posterior
313
+
314
+ latents_mean = latent_model_conditions.pop("latents_mean")
315
+ latents_std = latent_model_conditions.pop("latents_std")
316
+
317
+ latents = self._normalize_latents(latents, latents_mean, latents_std)
318
+ noise = torch.zeros_like(latents).normal_(generator=generator)
319
+
320
+ if random.random() < first_frame_conditioning_p:
321
+ # Section 2.4 of the paper mentions that the first-frame timesteps should be a small random value.
322
+ # As an estimated guess, we cap the first-frame sigma at min_first_frame_sigma (0.25).
323
+ # torch.rand_like returns values in [0, 1). We want to make sure that the first frame sigma is <= actual sigmas
324
+ # for image conditioning. In order to do this, we rescale by multiplying with sigmas so the range is [0, sigmas).
325
+ first_frame_sigma = torch.rand_like(sigmas) * sigmas
326
+ first_frame_sigma = torch.min(first_frame_sigma, sigmas.new_full(sigmas.shape, min_first_frame_sigma))
327
+
328
+ latents_first_frame, latents_rest = latents[:, :, :1], latents[:, :, 1:]
329
+ noisy_latents_first_frame = FF.flow_match_xt(latents_first_frame, noise[:, :, :1], first_frame_sigma)
330
+ noisy_latents_remaining = FF.flow_match_xt(latents_rest, noise[:, :, 1:], sigmas)
331
+ noisy_latents = torch.cat([noisy_latents_first_frame, noisy_latents_remaining], dim=2)
332
+ else:
333
+ noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
334
+
335
+ patch_size = self.transformer_config.patch_size
336
+ patch_size_t = self.transformer_config.patch_size_t
337
+
338
+ latents = self._pack_latents(latents, patch_size, patch_size_t)
339
+ noise = self._pack_latents(noise, patch_size, patch_size_t)
340
+ noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t)
341
+
342
+ sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
343
+
344
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
345
+ condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
346
+ condition_model_conditions["encoder_attention_mask"] = condition_model_conditions.pop("prompt_attention_mask")
347
+
348
+ # TODO(aryan): make this configurable
349
+ frame_rate = 25
350
+ temporal_compression_ratio = 8
351
+ vae_spatial_compression_ratio = 32
352
+ latent_frame_rate = frame_rate / temporal_compression_ratio
353
+
354
+ rope_interpolation_scale = [
355
+ 1 / latent_frame_rate,
356
+ vae_spatial_compression_ratio,
357
+ vae_spatial_compression_ratio,
358
+ ]
359
+ timesteps = (sigmas * 1000.0).long()
360
+
361
+ pred = transformer(
362
+ **latent_model_conditions,
363
+ **condition_model_conditions,
364
+ timestep=timesteps,
365
+ rope_interpolation_scale=rope_interpolation_scale,
366
+ return_dict=False,
367
+ )[0]
368
+ target = FF.flow_match_target(noise, latents)
369
+
370
+ return pred, target, sigmas
371
+
372
+ def validation(
373
+ self,
374
+ pipeline: LTXPipeline,
375
+ prompt: str,
376
+ image: Optional[Image] = None,
377
+ height: Optional[int] = None,
378
+ width: Optional[int] = None,
379
+ num_frames: Optional[int] = None,
380
+ frame_rate: int = 25,
381
+ num_inference_steps: int = 50,
382
+ generator: Optional[torch.Generator] = None,
383
+ **kwargs,
384
+ ) -> List[ArtifactType]:
385
+ if image is not None:
386
+ pipeline = LTXImageToVideoPipeline.from_pipe(pipeline)
387
+
388
+ generation_kwargs = {
389
+ "prompt": prompt,
390
+ "image": image,
391
+ "height": height,
392
+ "width": width,
393
+ "num_frames": num_frames,
394
+ "frame_rate": frame_rate,
395
+ "num_inference_steps": num_inference_steps,
396
+ "generator": generator,
397
+ "return_dict": True,
398
+ "output_type": "pil",
399
+ }
400
+ generation_kwargs = get_non_null_items(generation_kwargs)
401
+ video = pipeline(**generation_kwargs).frames[0]
402
+ return [data.VideoArtifact(value=video)]
403
+
404
+ def _save_lora_weights(
405
+ self,
406
+ directory: str,
407
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
408
+ scheduler: Optional[SchedulerType] = None,
409
+ *args,
410
+ **kwargs,
411
+ ) -> None:
412
+ # TODO(aryan): this needs refactoring
413
+ if transformer_state_dict is not None:
414
+ LTXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
415
+ if scheduler is not None:
416
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
417
+
418
+ def _save_model(
419
+ self,
420
+ directory: str,
421
+ transformer: LTXVideoTransformer3DModel,
422
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
423
+ scheduler: Optional[SchedulerType] = None,
424
+ ) -> None:
425
+ # TODO(aryan): this needs refactoring
426
+ if transformer_state_dict is not None:
427
+ with init_empty_weights():
428
+ transformer_copy = LTXVideoTransformer3DModel.from_config(transformer.config)
429
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
430
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
431
+ if scheduler is not None:
432
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
433
+
434
+ def apply_tensor_parallel(
435
+ self,
436
+ backend: ParallelBackendEnum,
437
+ device_mesh: torch.distributed.DeviceMesh,
438
+ transformer: LTXVideoTransformer3DModel,
439
+ **kwargs,
440
+ ) -> None:
441
+ if backend == ParallelBackendEnum.PTD:
442
+ _apply_tensor_parallel_ptd(device_mesh, transformer)
443
+ else:
444
+ raise NotImplementedError(f"Parallel backend {backend} is not supported for LTXVideoModelSpecification")
445
+
446
+ @staticmethod
447
+ def _normalize_latents(
448
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
449
+ ) -> torch.Tensor:
450
+ # Normalize latents across the channel dimension [B, C, F, H, W]
451
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
452
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
453
+ latents = (latents - latents_mean) * scaling_factor / latents_std
454
+ return latents
455
+
456
+ @staticmethod
457
+ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
458
+ # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
459
+ # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
460
+ # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
461
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
462
+ batch_size, num_channels, num_frames, height, width = latents.shape
463
+ post_patch_num_frames = num_frames // patch_size_t
464
+ post_patch_height = height // patch_size
465
+ post_patch_width = width // patch_size
466
+ latents = latents.reshape(
467
+ batch_size,
468
+ -1,
469
+ post_patch_num_frames,
470
+ patch_size_t,
471
+ post_patch_height,
472
+ patch_size,
473
+ post_patch_width,
474
+ patch_size,
475
+ )
476
+ latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
477
+ return latents
478
+
479
+
480
+ def _apply_tensor_parallel_ptd(
481
+ device_mesh: torch.distributed.device_mesh.DeviceMesh, transformer: LTXVideoTransformer3DModel
482
+ ) -> None:
483
+ from torch.distributed.tensor.parallel import parallelize_module
484
+ from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel
485
+
486
+ transformer_plan = {
487
+ # ===== Condition embeddings =====
488
+ # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(),
489
+ # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)),
490
+ # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()),
491
+ # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())),
492
+ # "caption_projection.linear_1": ColwiseParallel(),
493
+ # "caption_projection.linear_2": RowwiseParallel(),
494
+ # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
495
+ # ===== =====
496
+ }
497
+
498
+ for block in transformer.transformer_blocks:
499
+ block_plan = {}
500
+
501
+ # ===== Attention =====
502
+ # 8 all-to-all, 3 all-reduce
503
+ # block_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False)
504
+ # block_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False)
505
+ # block_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False)
506
+ # block_plan["attn1.norm_q"] = SequenceParallel()
507
+ # block_plan["attn1.norm_k"] = SequenceParallel()
508
+ # block_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
509
+ # block_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False)
510
+ # block_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False)
511
+ # block_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False)
512
+ # block_plan["attn2.norm_q"] = SequenceParallel()
513
+ # block_plan["attn2.norm_k"] = SequenceParallel()
514
+ # block_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
515
+ # ===== =====
516
+
517
+ block_plan["ff.net.0.proj"] = ColwiseParallel()
518
+ block_plan["ff.net.2"] = RowwiseParallel()
519
+
520
+ parallelize_module(block, device_mesh, block_plan)
521
+
522
+ parallelize_module(transformer, device_mesh, transformer_plan)
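
The forward method above relies on the flow-matching helpers imported from finetrainers.functional. A minimal sketch of the quantities involved, assuming the standard rectified-flow definitions for FF.flow_match_xt and FF.flow_match_target (the helper names come from the imports above; the formulas are the usual ones and are not spelled out in this diff):

import torch

def flow_match_xt(x0: torch.Tensor, noise: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # Interpolate between clean latents (sigma = 0) and pure noise (sigma = 1).
    return (1.0 - sigma) * x0 + sigma * noise

def flow_match_target(noise: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
    # Velocity target the transformer is trained to predict.
    return noise - x0

latents = torch.randn(1, 128, 8, 16, 16)   # [B, C, F, H, W] after _normalize_latents
noise = torch.randn_like(latents)
sigmas = torch.rand(1, 1, 1, 1, 1)

noisy_latents = flow_match_xt(latents, noise, sigmas)
target = flow_match_target(noise, latents)

With probability first_frame_conditioning_p, only the first latent frame receives a much smaller sigma, which teaches the model to treat a nearly clean first frame as an image condition.
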
finetrainers/models/ltx_video/full_finetune.py DELETED
@@ -1,30 +0,0 @@
1
- from diffusers import LTXPipeline
2
-
3
- from .lora import (
4
- collate_fn_t2v,
5
- forward_pass,
6
- initialize_pipeline,
7
- load_condition_models,
8
- load_diffusion_models,
9
- load_latent_models,
10
- post_latent_preparation,
11
- prepare_conditions,
12
- prepare_latents,
13
- validation,
14
- )
15
-
16
-
17
- # TODO(aryan): refactor into model specs for better re-use
18
- LTX_VIDEO_T2V_FULL_FINETUNE_CONFIG = {
19
- "pipeline_cls": LTXPipeline,
20
- "load_condition_models": load_condition_models,
21
- "load_latent_models": load_latent_models,
22
- "load_diffusion_models": load_diffusion_models,
23
- "initialize_pipeline": initialize_pipeline,
24
- "prepare_conditions": prepare_conditions,
25
- "prepare_latents": prepare_latents,
26
- "post_latent_preparation": post_latent_preparation,
27
- "collate_fn": collate_fn_t2v,
28
- "forward_pass": forward_pass,
29
- "validation": validation,
30
- }
 
finetrainers/models/ltx_video/lora.py DELETED
@@ -1,331 +0,0 @@
1
- from typing import Dict, List, Optional, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
- from accelerate.logging import get_logger
6
- from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
7
- from PIL import Image
8
- from transformers import T5EncoderModel, T5Tokenizer
9
-
10
-
11
- logger = get_logger("finetrainers") # pylint: disable=invalid-name
12
-
13
-
14
- def load_condition_models(
15
- model_id: str = "Lightricks/LTX-Video",
16
- text_encoder_dtype: torch.dtype = torch.bfloat16,
17
- revision: Optional[str] = None,
18
- cache_dir: Optional[str] = None,
19
- **kwargs,
20
- ) -> Dict[str, nn.Module]:
21
- tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer", revision=revision, cache_dir=cache_dir)
22
- text_encoder = T5EncoderModel.from_pretrained(
23
- model_id, subfolder="text_encoder", torch_dtype=text_encoder_dtype, revision=revision, cache_dir=cache_dir
24
- )
25
- return {"tokenizer": tokenizer, "text_encoder": text_encoder}
26
-
27
-
28
- def load_latent_models(
29
- model_id: str = "Lightricks/LTX-Video",
30
- vae_dtype: torch.dtype = torch.bfloat16,
31
- revision: Optional[str] = None,
32
- cache_dir: Optional[str] = None,
33
- **kwargs,
34
- ) -> Dict[str, nn.Module]:
35
- vae = AutoencoderKLLTXVideo.from_pretrained(
36
- model_id, subfolder="vae", torch_dtype=vae_dtype, revision=revision, cache_dir=cache_dir
37
- )
38
- return {"vae": vae}
39
-
40
-
41
- def load_diffusion_models(
42
- model_id: str = "Lightricks/LTX-Video",
43
- transformer_dtype: torch.dtype = torch.bfloat16,
44
- revision: Optional[str] = None,
45
- cache_dir: Optional[str] = None,
46
- **kwargs,
47
- ) -> Dict[str, nn.Module]:
48
- transformer = LTXVideoTransformer3DModel.from_pretrained(
49
- model_id, subfolder="transformer", torch_dtype=transformer_dtype, revision=revision, cache_dir=cache_dir
50
- )
51
- scheduler = FlowMatchEulerDiscreteScheduler()
52
- return {"transformer": transformer, "scheduler": scheduler}
53
-
54
-
55
- def initialize_pipeline(
56
- model_id: str = "Lightricks/LTX-Video",
57
- text_encoder_dtype: torch.dtype = torch.bfloat16,
58
- transformer_dtype: torch.dtype = torch.bfloat16,
59
- vae_dtype: torch.dtype = torch.bfloat16,
60
- tokenizer: Optional[T5Tokenizer] = None,
61
- text_encoder: Optional[T5EncoderModel] = None,
62
- transformer: Optional[LTXVideoTransformer3DModel] = None,
63
- vae: Optional[AutoencoderKLLTXVideo] = None,
64
- scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
65
- device: Optional[torch.device] = None,
66
- revision: Optional[str] = None,
67
- cache_dir: Optional[str] = None,
68
- enable_slicing: bool = False,
69
- enable_tiling: bool = False,
70
- enable_model_cpu_offload: bool = False,
71
- is_training: bool = False,
72
- **kwargs,
73
- ) -> LTXPipeline:
74
- component_name_pairs = [
75
- ("tokenizer", tokenizer),
76
- ("text_encoder", text_encoder),
77
- ("transformer", transformer),
78
- ("vae", vae),
79
- ("scheduler", scheduler),
80
- ]
81
- components = {}
82
- for name, component in component_name_pairs:
83
- if component is not None:
84
- components[name] = component
85
-
86
- pipe = LTXPipeline.from_pretrained(model_id, **components, revision=revision, cache_dir=cache_dir)
87
- pipe.text_encoder = pipe.text_encoder.to(dtype=text_encoder_dtype)
88
- pipe.vae = pipe.vae.to(dtype=vae_dtype)
89
- # The transformer should already be in the correct dtype when training, so we don't need to cast it here.
90
- # If we cast, whilst using fp8 layerwise upcasting hooks, it will lead to an error in the training during
91
- # DDP optimizer step.
92
- if not is_training:
93
- pipe.transformer = pipe.transformer.to(dtype=transformer_dtype)
94
-
95
- if enable_slicing:
96
- pipe.vae.enable_slicing()
97
- if enable_tiling:
98
- pipe.vae.enable_tiling()
99
-
100
- if enable_model_cpu_offload:
101
- pipe.enable_model_cpu_offload(device=device)
102
- else:
103
- pipe.to(device=device)
104
-
105
- return pipe
106
-
107
-
108
- def prepare_conditions(
109
- tokenizer: T5Tokenizer,
110
- text_encoder: T5EncoderModel,
111
- prompt: Union[str, List[str]],
112
- device: Optional[torch.device] = None,
113
- dtype: Optional[torch.dtype] = None,
114
- max_sequence_length: int = 128,
115
- **kwargs,
116
- ) -> torch.Tensor:
117
- device = device or text_encoder.device
118
- dtype = dtype or text_encoder.dtype
119
-
120
- if isinstance(prompt, str):
121
- prompt = [prompt]
122
-
123
- return _encode_prompt_t5(tokenizer, text_encoder, prompt, device, dtype, max_sequence_length)
124
-
125
-
126
- def prepare_latents(
127
- vae: AutoencoderKLLTXVideo,
128
- image_or_video: torch.Tensor,
129
- patch_size: int = 1,
130
- patch_size_t: int = 1,
131
- device: Optional[torch.device] = None,
132
- dtype: Optional[torch.dtype] = None,
133
- generator: Optional[torch.Generator] = None,
134
- precompute: bool = False,
135
- ) -> torch.Tensor:
136
- device = device or vae.device
137
-
138
- if image_or_video.ndim == 4:
139
- image_or_video = image_or_video.unsqueeze(2)
140
- assert image_or_video.ndim == 5, f"Expected 5D tensor, got {image_or_video.ndim}D tensor"
141
-
142
- image_or_video = image_or_video.to(device=device, dtype=vae.dtype)
143
- image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, C, F, H, W] -> [B, F, C, H, W]
144
- if not precompute:
145
- latents = vae.encode(image_or_video).latent_dist.sample(generator=generator)
146
- latents = latents.to(dtype=dtype)
147
- _, _, num_frames, height, width = latents.shape
148
- latents = _normalize_latents(latents, vae.latents_mean, vae.latents_std)
149
- latents = _pack_latents(latents, patch_size, patch_size_t)
150
- return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}
151
- else:
152
- if vae.use_slicing and image_or_video.shape[0] > 1:
153
- encoded_slices = [vae._encode(x_slice) for x_slice in image_or_video.split(1)]
154
- h = torch.cat(encoded_slices)
155
- else:
156
- h = vae._encode(image_or_video)
157
- _, _, num_frames, height, width = h.shape
158
-
159
- # TODO(aryan): This is very stupid that we might possibly be storing the latents_mean and latents_std in every file
160
- # if precomputation is enabled. We should probably have a single file where re-usable properties like this are stored
161
- # so as to reduce the disk memory requirements of the precomputed files.
162
- return {
163
- "latents": h,
164
- "num_frames": num_frames,
165
- "height": height,
166
- "width": width,
167
- "latents_mean": vae.latents_mean,
168
- "latents_std": vae.latents_std,
169
- }
170
-
171
-
172
- def post_latent_preparation(
173
- latents: torch.Tensor,
174
- latents_mean: torch.Tensor,
175
- latents_std: torch.Tensor,
176
- num_frames: int,
177
- height: int,
178
- width: int,
179
- patch_size: int = 1,
180
- patch_size_t: int = 1,
181
- **kwargs,
182
- ) -> torch.Tensor:
183
- latents = _normalize_latents(latents, latents_mean, latents_std)
184
- latents = _pack_latents(latents, patch_size, patch_size_t)
185
- return {"latents": latents, "num_frames": num_frames, "height": height, "width": width}
186
-
187
-
188
- def collate_fn_t2v(batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
189
- return {
190
- "prompts": [x["prompt"] for x in batch[0]],
191
- "videos": torch.stack([x["video"] for x in batch[0]]),
192
- }
193
-
194
-
195
- def forward_pass(
196
- transformer: LTXVideoTransformer3DModel,
197
- prompt_embeds: torch.Tensor,
198
- prompt_attention_mask: torch.Tensor,
199
- latents: torch.Tensor,
200
- noisy_latents: torch.Tensor,
201
- timesteps: torch.LongTensor,
202
- num_frames: int,
203
- height: int,
204
- width: int,
205
- **kwargs,
206
- ) -> torch.Tensor:
207
- # TODO(aryan): make configurable
208
- frame_rate = 25
209
- latent_frame_rate = frame_rate / 8
210
- spatial_compression_ratio = 32
211
- rope_interpolation_scale = [1 / latent_frame_rate, spatial_compression_ratio, spatial_compression_ratio]
212
-
213
- denoised_latents = transformer(
214
- hidden_states=noisy_latents,
215
- encoder_hidden_states=prompt_embeds,
216
- timestep=timesteps,
217
- encoder_attention_mask=prompt_attention_mask,
218
- num_frames=num_frames,
219
- height=height,
220
- width=width,
221
- rope_interpolation_scale=rope_interpolation_scale,
222
- return_dict=False,
223
- )[0]
224
-
225
- return {"latents": denoised_latents}
226
-
227
-
228
- def validation(
229
- pipeline: LTXPipeline,
230
- prompt: str,
231
- image: Optional[Image.Image] = None,
232
- video: Optional[List[Image.Image]] = None,
233
- height: Optional[int] = None,
234
- width: Optional[int] = None,
235
- num_frames: Optional[int] = None,
236
- frame_rate: int = 24,
237
- num_videos_per_prompt: int = 1,
238
- generator: Optional[torch.Generator] = None,
239
- **kwargs,
240
- ):
241
- generation_kwargs = {
242
- "prompt": prompt,
243
- "height": height,
244
- "width": width,
245
- "num_frames": num_frames,
246
- "frame_rate": frame_rate,
247
- "num_videos_per_prompt": num_videos_per_prompt,
248
- "generator": generator,
249
- "return_dict": True,
250
- "output_type": "pil",
251
- }
252
- generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
253
- video = pipeline(**generation_kwargs).frames[0]
254
- return [("video", video)]
255
-
256
-
257
- def _encode_prompt_t5(
258
- tokenizer: T5Tokenizer,
259
- text_encoder: T5EncoderModel,
260
- prompt: List[str],
261
- device: torch.device,
262
- dtype: torch.dtype,
263
- max_sequence_length,
264
- ) -> torch.Tensor:
265
- batch_size = len(prompt)
266
-
267
- text_inputs = tokenizer(
268
- prompt,
269
- padding="max_length",
270
- max_length=max_sequence_length,
271
- truncation=True,
272
- add_special_tokens=True,
273
- return_tensors="pt",
274
- )
275
- text_input_ids = text_inputs.input_ids
276
- prompt_attention_mask = text_inputs.attention_mask
277
- prompt_attention_mask = prompt_attention_mask.bool().to(device)
278
-
279
- prompt_embeds = text_encoder(text_input_ids.to(device))[0]
280
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
281
- prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
282
-
283
- return {"prompt_embeds": prompt_embeds, "prompt_attention_mask": prompt_attention_mask}
284
-
285
-
286
- def _normalize_latents(
287
- latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
288
- ) -> torch.Tensor:
289
- # Normalize latents across the channel dimension [B, C, F, H, W]
290
- latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
291
- latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
292
- latents = (latents - latents_mean) * scaling_factor / latents_std
293
- return latents
294
-
295
-
296
- def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
297
- # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
298
- # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
299
- # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
300
- # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
301
- batch_size, num_channels, num_frames, height, width = latents.shape
302
- post_patch_num_frames = num_frames // patch_size_t
303
- post_patch_height = height // patch_size
304
- post_patch_width = width // patch_size
305
- latents = latents.reshape(
306
- batch_size,
307
- -1,
308
- post_patch_num_frames,
309
- patch_size_t,
310
- post_patch_height,
311
- patch_size,
312
- post_patch_width,
313
- patch_size,
314
- )
315
- latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
316
- return latents
317
-
318
-
319
- LTX_VIDEO_T2V_LORA_CONFIG = {
320
- "pipeline_cls": LTXPipeline,
321
- "load_condition_models": load_condition_models,
322
- "load_latent_models": load_latent_models,
323
- "load_diffusion_models": load_diffusion_models,
324
- "initialize_pipeline": initialize_pipeline,
325
- "prepare_conditions": prepare_conditions,
326
- "prepare_latents": prepare_latents,
327
- "post_latent_preparation": post_latent_preparation,
328
- "collate_fn": collate_fn_t2v,
329
- "forward_pass": forward_pass,
330
- "validation": validation,
331
- }
 
finetrainers/models/modeling_utils.py ADDED
@@ -0,0 +1,292 @@
 
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from diffusers import DiffusionPipeline
5
+ from diffusers.configuration_utils import FrozenDict
6
+ from PIL.Image import Image
7
+
8
+ from ..logging import get_logger
9
+ from ..parallel import ParallelBackendEnum
10
+ from ..processors import ProcessorMixin
11
+ from ..typing import ArtifactType, SchedulerType, TokenizerType
12
+ from ..utils import resolve_component_cls
13
+
14
+
15
+ logger = get_logger()
16
+
17
+ # TODO(aryan): we most likely don't need this. take a look after refactoring more
18
+ # fmt: off
19
+ IGNORE_KEYS_FOR_COLLATION = {"height", "width", "num_frames", "frame_rate", "rope_interpolation_scale", "return_dict", "attention_kwargs", "cross_attention_kwargs", "joint_attention_kwargs", "latents_mean", "latents_std"}
20
+ # fmt: on
21
+
22
+
23
+ class ModelSpecification:
24
+ r"""
25
+ The ModelSpecification class is an interface for diffusion training recipes. It provides a
26
+ loose structure for organizing training code. The trainer implementations use this interface to
27
+ load models, prepare conditions, prepare latents, run the forward pass, and so on.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ pretrained_model_name_or_path: Optional[str] = None,
33
+ tokenizer_id: Optional[str] = None,
34
+ tokenizer_2_id: Optional[str] = None,
35
+ tokenizer_3_id: Optional[str] = None,
36
+ text_encoder_id: Optional[str] = None,
37
+ text_encoder_2_id: Optional[str] = None,
38
+ text_encoder_3_id: Optional[str] = None,
39
+ transformer_id: Optional[str] = None,
40
+ vae_id: Optional[str] = None,
41
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
42
+ text_encoder_2_dtype: torch.dtype = torch.bfloat16,
43
+ text_encoder_3_dtype: torch.dtype = torch.bfloat16,
44
+ transformer_dtype: torch.dtype = torch.bfloat16,
45
+ vae_dtype: torch.dtype = torch.bfloat16,
46
+ revision: Optional[str] = None,
47
+ cache_dir: Optional[str] = None,
48
+ condition_model_processors: List[ProcessorMixin] = None,
49
+ latent_model_processors: List[ProcessorMixin] = None,
50
+ ) -> None:
51
+ self.pretrained_model_name_or_path = pretrained_model_name_or_path
52
+ self.tokenizer_id = tokenizer_id
53
+ self.tokenizer_2_id = tokenizer_2_id
54
+ self.tokenizer_3_id = tokenizer_3_id
55
+ self.text_encoder_id = text_encoder_id
56
+ self.text_encoder_2_id = text_encoder_2_id
57
+ self.text_encoder_3_id = text_encoder_3_id
58
+ self.transformer_id = transformer_id
59
+ self.vae_id = vae_id
60
+ self.text_encoder_dtype = text_encoder_dtype
61
+ self.text_encoder_2_dtype = text_encoder_2_dtype
62
+ self.text_encoder_3_dtype = text_encoder_3_dtype
63
+ self.transformer_dtype = transformer_dtype
64
+ self.vae_dtype = vae_dtype
65
+ self.revision = revision
66
+ self.cache_dir = cache_dir
67
+ self.condition_model_processors = condition_model_processors or []
68
+ self.latent_model_processors = latent_model_processors or []
69
+
70
+ self.transformer_config: Dict[str, Any] = None
71
+ self.vae_config: Dict[str, Any] = None
72
+
73
+ self._load_configs()
74
+
75
+ # TODO(aryan): revisit how to do this better without user having to worry about it
76
+ @property
77
+ def _resolution_dim_keys(self) -> Dict[str, Tuple[int, ...]]:
78
+ raise NotImplementedError(
79
+ f"ModelSpecification::_resolution_dim_keys is not implemented for {self.__class__.__name__}"
80
+ )
81
+
82
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
83
+ raise NotImplementedError(
84
+ f"ModelSpecification::load_condition_models is not implemented for {self.__class__.__name__}"
85
+ )
86
+
87
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
88
+ raise NotImplementedError(
89
+ f"ModelSpecification::load_latent_models is not implemented for {self.__class__.__name__}"
90
+ )
91
+
92
+ def load_diffusion_models(self) -> Dict[str, Union[torch.nn.Module]]:
93
+ raise NotImplementedError(
94
+ f"ModelSpecification::load_diffusion_models is not implemented for {self.__class__.__name__}"
95
+ )
96
+
97
+ def load_pipeline(
98
+ self,
99
+ tokenizer: Optional[TokenizerType] = None,
100
+ tokenizer_2: Optional[TokenizerType] = None,
101
+ tokenizer_3: Optional[TokenizerType] = None,
102
+ text_encoder: Optional[torch.nn.Module] = None,
103
+ text_encoder_2: Optional[torch.nn.Module] = None,
104
+ text_encoder_3: Optional[torch.nn.Module] = None,
105
+ transformer: Optional[torch.nn.Module] = None,
106
+ vae: Optional[torch.nn.Module] = None,
107
+ scheduler: Optional[SchedulerType] = None,
108
+ enable_slicing: bool = False,
109
+ enable_tiling: bool = False,
110
+ enable_model_cpu_offload: bool = False,
111
+ training: bool = False,
112
+ **kwargs,
113
+ ) -> DiffusionPipeline:
114
+ raise NotImplementedError(
115
+ f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}"
116
+ )
117
+
118
+ def collate_fn(self, batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
119
+ raise NotImplementedError(f"ModelSpecification::collate_fn is not implemented for {self.__class__.__name__}")
120
+
121
+ def prepare_conditions(self, **kwargs) -> Dict[str, Any]:
122
+ for processor in self.condition_model_processors:
123
+ result = processor(**kwargs)
124
+ result_keys = set(result.keys())
125
+ repeat_keys = result_keys.intersection(kwargs.keys())
126
+ if repeat_keys:
127
+ logger.warning(
128
+ f"Processor {processor.__class__.__name__} returned keys that already exist in "
129
+ f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
130
+ f"be intended. Please rename the keys in the processor to avoid conflicts."
131
+ )
132
+ kwargs.update(result)
133
+ return kwargs
134
+
135
+ def prepare_latents(self, **kwargs) -> Dict[str, Any]:
136
+ for processor in self.latent_model_processors:
137
+ result = processor(**kwargs)
138
+ result_keys = set(result.keys())
139
+ repeat_keys = result_keys.intersection(kwargs.keys())
140
+ if repeat_keys:
141
+ logger.warning(
142
+ f"Processor {processor.__class__.__name__} returned keys that already exist in "
143
+ f"conditions: {repeat_keys}. Overwriting the existing values, but this may not "
144
+ f"be intended. Please rename the keys in the processor to avoid conflicts."
145
+ )
146
+ kwargs.update(result)
147
+ return kwargs
148
+
149
+ def collate_conditions(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
150
+ keys = list(data[0].keys())
151
+ collated_data = {}
152
+ for key in keys:
153
+ if key in IGNORE_KEYS_FOR_COLLATION:
154
+ collated_data[key] = data[0][key]
155
+ continue
156
+ collated_d = [d[key] for d in data]
157
+ if isinstance(collated_d[0], torch.Tensor):
158
+ collated_d = torch.cat(collated_d)
159
+ collated_data[key] = collated_d
160
+ return collated_data
161
+
162
+ def collate_latents(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
163
+ keys = list(data[0].keys())
164
+ collated_data = {}
165
+ for key in keys:
166
+ if key in IGNORE_KEYS_FOR_COLLATION:
167
+ collated_data[key] = data[0][key]
168
+ continue
169
+ collated_d = [d[key] for d in data]
170
+ # TODO(aryan): Support multi-resolution collation
171
+ if isinstance(collated_d[0], torch.Tensor):
172
+ collated_d = torch.cat(collated_d)
173
+ collated_data[key] = collated_d
174
+ return collated_data
175
+
176
+ def forward(
177
+ self, transformer: torch.nn.Module, generator: Optional[torch.Generator] = None, **kwargs
178
+ ) -> Dict[str, torch.Tensor]:
179
+ raise NotImplementedError(f"ModelSpecification::forward is not implemented for {self.__class__.__name__}")
180
+
181
+ def validation(
182
+ self,
183
+ pipeline: DiffusionPipeline,
184
+ prompt: Optional[str] = None,
185
+ image: Optional[Image] = None,
186
+ video: Optional[List[Image]] = None,
187
+ height: Optional[int] = None,
188
+ width: Optional[int] = None,
189
+ num_frames: Optional[int] = None,
190
+ frame_rate: Optional[int] = None,
191
+ generator: Optional[torch.Generator] = None,
192
+ ) -> List[ArtifactType]:
193
+ raise NotImplementedError(f"ModelSpecification::validation is not implemented for {self.__class__.__name__}")
194
+
195
+ def _save_lora_weights(
196
+ self,
197
+ directory: str,
198
+ transformer: torch.nn.Module,
199
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
200
+ scheduler: Optional[SchedulerType] = None,
201
+ ) -> None:
202
+ r"""
203
+ Save the lora state dicts of the model to the given directory.
204
+
205
+ This API is not backwards compatible and will be changed in the near future.
206
+ """
207
+ raise NotImplementedError(
208
+ f"ModelSpecification::save_lora_weights is not implemented for {self.__class__.__name__}"
209
+ )
210
+
211
+ def _save_model(
212
+ self,
213
+ directory: str,
214
+ transformer: torch.nn.Module,
215
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
216
+ scheduler: Optional[SchedulerType] = None,
217
+ ) -> None:
218
+ r"""
219
+ Save the state dicts to the given directory.
220
+
221
+ This API is not backwards compatible and will be changed in the near future.
222
+ """
223
+ raise NotImplementedError(f"ModelSpecification::save_model is not implemented for {self.__class__.__name__}")
224
+
225
+ def apply_tensor_parallel(
226
+ self,
227
+ backend: ParallelBackendEnum,
228
+ device_mesh: torch.distributed.DeviceMesh,
229
+ text_encoder: torch.nn.Module,
230
+ text_encoder_2: torch.nn.Module,
231
+ text_encoder_3: torch.nn.Module,
232
+ transformer: torch.nn.Module,
233
+ vae: torch.nn.Module,
234
+ ) -> None:
235
+ raise NotImplementedError(
236
+ f"ModelSpecification::apply_tensor_parallel is not implemented for {self.__class__.__name__}"
237
+ )
238
+
239
+ def _load_configs(self) -> None:
240
+ self._load_transformer_config()
241
+ self._load_vae_config()
242
+
243
+ def _load_transformer_config(self) -> None:
244
+ if self.transformer_id is not None:
245
+ transformer_cls = resolve_component_cls(
246
+ self.transformer_id,
247
+ component_name="_class_name",
248
+ filename="config.json",
249
+ revision=self.revision,
250
+ cache_dir=self.cache_dir,
251
+ )
252
+ self.transformer_config = transformer_cls.load_config(
253
+ self.transformer_id, revision=self.revision, cache_dir=self.cache_dir
254
+ )
255
+ else:
256
+ transformer_cls = resolve_component_cls(
257
+ self.pretrained_model_name_or_path,
258
+ component_name="transformer",
259
+ filename="model_index.json",
260
+ revision=self.revision,
261
+ cache_dir=self.cache_dir,
262
+ )
263
+ self.transformer_config = transformer_cls.load_config(
264
+ self.pretrained_model_name_or_path,
265
+ subfolder="transformer",
266
+ revision=self.revision,
267
+ cache_dir=self.cache_dir,
268
+ )
269
+ self.transformer_config = FrozenDict(**self.transformer_config)
270
+
271
+ def _load_vae_config(self) -> None:
272
+ if self.vae_id is not None:
273
+ vae_cls = resolve_component_cls(
274
+ self.vae_id,
275
+ component_name="_class_name",
276
+ filename="config.json",
277
+ revision=self.revision,
278
+ cache_dir=self.cache_dir,
279
+ )
280
+ self.vae_config = vae_cls.load_config(self.vae_id, revision=self.revision, cache_dir=self.cache_dir)
281
+ else:
282
+ vae_cls = resolve_component_cls(
283
+ self.pretrained_model_name_or_path,
284
+ component_name="vae",
285
+ filename="model_index.json",
286
+ revision=self.revision,
287
+ cache_dir=self.cache_dir,
288
+ )
289
+ self.vae_config = vae_cls.load_config(
290
+ self.pretrained_model_name_or_path, subfolder="vae", revision=self.revision, cache_dir=self.cache_dir
291
+ )
292
+ self.vae_config = FrozenDict(**self.vae_config)
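A small side sketch of why the loaded configs are wrapped in `FrozenDict` (from `diffusers.configuration_utils`): downstream code can read them like a dict or attribute bag, but accidental mutation is rejected. The config keys below are made up for illustration.

```python
from diffusers.configuration_utils import FrozenDict

config = FrozenDict(in_channels=16, patch_size=(1, 2, 2))  # hypothetical config values
print(config["in_channels"], config.patch_size)
try:
    config["in_channels"] = 32
except Exception as err:
    print(type(err).__name__)  # mutation raises instead of silently changing the config
```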
finetrainers/models/utils.py ADDED
@@ -0,0 +1,62 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from diffusers.utils.torch_utils import randn_tensor
6
+
7
+
8
+ class DiagonalGaussianDistribution(object):
9
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False, _dim: int = 1):
10
+ # Note: _dim is the new argument added here after copying from diffusers
11
+ self.parameters = parameters
12
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=_dim)
13
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
14
+ self.deterministic = deterministic
15
+ self.std = torch.exp(0.5 * self.logvar)
16
+ self.var = torch.exp(self.logvar)
17
+ if self.deterministic:
18
+ self.var = self.std = torch.zeros_like(
19
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
20
+ )
21
+
22
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
23
+ # make sure sample is on the same device as the parameters and has same dtype
24
+ sample = randn_tensor(
25
+ self.mean.shape,
26
+ generator=generator,
27
+ device=self.parameters.device,
28
+ dtype=self.parameters.dtype,
29
+ )
30
+ x = self.mean + self.std * sample
31
+ return x
32
+
33
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
34
+ if self.deterministic:
35
+ return torch.Tensor([0.0])
36
+ else:
37
+ if other is None:
38
+ return 0.5 * torch.sum(
39
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
40
+ dim=[1, 2, 3],
41
+ )
42
+ else:
43
+ return 0.5 * torch.sum(
44
+ torch.pow(self.mean - other.mean, 2) / other.var
45
+ + self.var / other.var
46
+ - 1.0
47
+ - self.logvar
48
+ + other.logvar,
49
+ dim=[1, 2, 3],
50
+ )
51
+
52
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
53
+ if self.deterministic:
54
+ return torch.Tensor([0.0])
55
+ logtwopi = np.log(2.0 * np.pi)
56
+ return 0.5 * torch.sum(
57
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
58
+ dim=dims,
59
+ )
60
+
61
+ def mode(self) -> torch.Tensor:
62
+ return self.mean
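A quick illustration of how this distribution is consumed: the VAE encoder returns concatenated mean/logvar "moments", and `sample()` draws `mean + std * eps` via the reparameterization trick. Shapes below are illustrative only, not Wan's actual latent layout.

```python
import torch

moments = torch.randn(1, 2 * 16, 4, 32, 32)   # assumed layout: [B, 2*C, F, H, W]
mean, logvar = torch.chunk(moments, 2, dim=1)
std = torch.exp(0.5 * logvar.clamp(-30.0, 20.0))
latents = mean + std * torch.randn_like(mean)  # same reparameterization as sample()
print(latents.shape)  # torch.Size([1, 16, 4, 32, 32])
```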
finetrainers/models/wan/__init__.py ADDED
@@ -0,0 +1 @@

1
+ from .base_specification import WanModelSpecification
finetrainers/models/wan/base_specification.py ADDED
@@ -0,0 +1,378 @@
1
+ import os
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ from accelerate import init_empty_weights
6
+ from diffusers import (
7
+ AutoencoderKLWan,
8
+ FlowMatchEulerDiscreteScheduler,
9
+ WanImageToVideoPipeline,
10
+ WanPipeline,
11
+ WanTransformer3DModel,
12
+ )
13
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
14
+ from PIL.Image import Image
15
+ from transformers import AutoModel, AutoTokenizer, UMT5EncoderModel
16
+
17
+ from ... import data
18
+ from ... import functional as FF
19
+ from ...logging import get_logger
20
+ from ...processors import ProcessorMixin, T5Processor
21
+ from ...typing import ArtifactType, SchedulerType
22
+ from ...utils import get_non_null_items
23
+ from ..modeling_utils import ModelSpecification
24
+
25
+
26
+ logger = get_logger()
27
+
28
+
29
+ class WanLatentEncodeProcessor(ProcessorMixin):
30
+ r"""
31
+ Processor to encode image/video into latents using the Wan VAE.
32
+
33
+ Args:
34
+ output_names (`List[str]`):
35
+ The names of the outputs that the processor returns. A single output name is expected:
+ - latents: The latents of the input image/video (or the raw VAE moments when `compute_posterior` is False).
42
+ """
43
+
44
+ def __init__(self, output_names: List[str]):
45
+ super().__init__()
46
+ self.output_names = output_names
47
+ assert len(self.output_names) == 1
48
+
49
+ def forward(
50
+ self,
51
+ vae: AutoencoderKLWan,
52
+ image: Optional[torch.Tensor] = None,
53
+ video: Optional[torch.Tensor] = None,
54
+ generator: Optional[torch.Generator] = None,
55
+ compute_posterior: bool = True,
56
+ ) -> Dict[str, torch.Tensor]:
57
+ device = vae.device
58
+ dtype = vae.dtype
59
+
60
+ if image is not None:
61
+ video = image.unsqueeze(1)
62
+
63
+ assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
64
+ video = video.to(device=device, dtype=vae.dtype)
65
+ video = video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W]
66
+
67
+ if compute_posterior:
68
+ latents = vae.encode(video).latent_dist.sample(generator=generator)
69
+ latents = latents.to(dtype=dtype)
70
+ else:
71
+ # TODO(aryan): refactor in diffusers to have use_slicing attribute
72
+ # if vae.use_slicing and video.shape[0] > 1:
73
+ # encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
74
+ # moments = torch.cat(encoded_slices)
75
+ # else:
76
+ # moments = vae._encode(video)
77
+ moments = vae._encode(video)
78
+ latents = moments.to(dtype=dtype)
79
+
80
+ return {self.output_names[0]: latents}
81
+
82
+
83
+ class WanModelSpecification(ModelSpecification):
84
+ def __init__(
85
+ self,
86
+ pretrained_model_name_or_path: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
87
+ tokenizer_id: Optional[str] = None,
88
+ text_encoder_id: Optional[str] = None,
89
+ transformer_id: Optional[str] = None,
90
+ vae_id: Optional[str] = None,
91
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
92
+ transformer_dtype: torch.dtype = torch.bfloat16,
93
+ vae_dtype: torch.dtype = torch.bfloat16,
94
+ revision: Optional[str] = None,
95
+ cache_dir: Optional[str] = None,
96
+ condition_model_processors: List[ProcessorMixin] = None,
97
+ latent_model_processors: List[ProcessorMixin] = None,
98
+ **kwargs,
99
+ ) -> None:
100
+ super().__init__(
101
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
102
+ tokenizer_id=tokenizer_id,
103
+ text_encoder_id=text_encoder_id,
104
+ transformer_id=transformer_id,
105
+ vae_id=vae_id,
106
+ text_encoder_dtype=text_encoder_dtype,
107
+ transformer_dtype=transformer_dtype,
108
+ vae_dtype=vae_dtype,
109
+ revision=revision,
110
+ cache_dir=cache_dir,
111
+ )
112
+
113
+ if condition_model_processors is None:
114
+ condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])]
115
+ if latent_model_processors is None:
116
+ latent_model_processors = [WanLatentEncodeProcessor(["latents"])]
117
+
118
+ self.condition_model_processors = condition_model_processors
119
+ self.latent_model_processors = latent_model_processors
120
+
121
+ @property
122
+ def _resolution_dim_keys(self):
123
+ # TODO
124
+ return {
125
+ "latents": (2, 3, 4),
126
+ }
127
+
128
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
129
+ if self.tokenizer_id is not None:
130
+ tokenizer = AutoTokenizer.from_pretrained(
131
+ self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
132
+ )
133
+ else:
134
+ tokenizer = AutoTokenizer.from_pretrained(
135
+ self.pretrained_model_name_or_path,
136
+ subfolder="tokenizer",
137
+ revision=self.revision,
138
+ cache_dir=self.cache_dir,
139
+ )
140
+
141
+ if self.text_encoder_id is not None:
142
+ text_encoder = AutoModel.from_pretrained(
143
+ self.text_encoder_id,
144
+ torch_dtype=self.text_encoder_dtype,
145
+ revision=self.revision,
146
+ cache_dir=self.cache_dir,
147
+ )
148
+ else:
149
+ text_encoder = UMT5EncoderModel.from_pretrained(
150
+ self.pretrained_model_name_or_path,
151
+ subfolder="text_encoder",
152
+ torch_dtype=self.text_encoder_dtype,
153
+ revision=self.revision,
154
+ cache_dir=self.cache_dir,
155
+ )
156
+
157
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
158
+
159
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
160
+ if self.vae_id is not None:
161
+ vae = AutoencoderKLWan.from_pretrained(
162
+ self.vae_id,
163
+ torch_dtype=self.vae_dtype,
164
+ revision=self.revision,
165
+ cache_dir=self.cache_dir,
166
+ )
167
+ else:
168
+ vae = AutoencoderKLWan.from_pretrained(
169
+ self.pretrained_model_name_or_path,
170
+ subfolder="vae",
171
+ torch_dtype=self.vae_dtype,
172
+ revision=self.revision,
173
+ cache_dir=self.cache_dir,
174
+ )
175
+
176
+ return {"vae": vae}
177
+
178
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
179
+ if self.transformer_id is not None:
180
+ transformer = WanTransformer3DModel.from_pretrained(
181
+ self.transformer_id,
182
+ torch_dtype=self.transformer_dtype,
183
+ revision=self.revision,
184
+ cache_dir=self.cache_dir,
185
+ )
186
+ else:
187
+ transformer = WanTransformer3DModel.from_pretrained(
188
+ self.pretrained_model_name_or_path,
189
+ subfolder="transformer",
190
+ torch_dtype=self.transformer_dtype,
191
+ revision=self.revision,
192
+ cache_dir=self.cache_dir,
193
+ )
194
+
195
+ scheduler = FlowMatchEulerDiscreteScheduler()
196
+
197
+ return {"transformer": transformer, "scheduler": scheduler}
198
+
199
+ def load_pipeline(
200
+ self,
201
+ tokenizer: Optional[AutoTokenizer] = None,
202
+ text_encoder: Optional[UMT5EncoderModel] = None,
203
+ transformer: Optional[WanTransformer3DModel] = None,
204
+ vae: Optional[AutoencoderKLWan] = None,
205
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
206
+ enable_slicing: bool = False,
207
+ enable_tiling: bool = False,
208
+ enable_model_cpu_offload: bool = False,
209
+ training: bool = False,
210
+ **kwargs,
211
+ ) -> WanPipeline:
212
+ components = {
213
+ "tokenizer": tokenizer,
214
+ "text_encoder": text_encoder,
215
+ "transformer": transformer,
216
+ "vae": vae,
217
+ "scheduler": scheduler,
218
+ }
219
+ components = get_non_null_items(components)
220
+
221
+ pipe = WanPipeline.from_pretrained(
222
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
223
+ )
224
+ pipe.text_encoder.to(self.text_encoder_dtype)
225
+ pipe.vae.to(self.vae_dtype)
226
+
227
+ if not training:
228
+ pipe.transformer.to(self.transformer_dtype)
229
+
230
+ # TODO(aryan): add support in diffusers
231
+ # if enable_slicing:
232
+ # pipe.vae.enable_slicing()
233
+ # if enable_tiling:
234
+ # pipe.vae.enable_tiling()
235
+ if enable_model_cpu_offload:
236
+ pipe.enable_model_cpu_offload()
237
+
238
+ return pipe
239
+
240
+ @torch.no_grad()
241
+ def prepare_conditions(
242
+ self,
243
+ tokenizer: AutoTokenizer,
244
+ text_encoder: UMT5EncoderModel,
245
+ caption: str,
246
+ max_sequence_length: int = 512,
247
+ **kwargs,
248
+ ) -> Dict[str, Any]:
249
+ conditions = {
250
+ "tokenizer": tokenizer,
251
+ "text_encoder": text_encoder,
252
+ "caption": caption,
253
+ "max_sequence_length": max_sequence_length,
254
+ **kwargs,
255
+ }
256
+ input_keys = set(conditions.keys())
257
+ conditions = super().prepare_conditions(**conditions)
258
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
259
+ conditions.pop("prompt_attention_mask", None)
260
+ return conditions
261
+
262
+ @torch.no_grad()
263
+ def prepare_latents(
264
+ self,
265
+ vae: AutoencoderKLWan,
266
+ image: Optional[torch.Tensor] = None,
267
+ video: Optional[torch.Tensor] = None,
268
+ generator: Optional[torch.Generator] = None,
269
+ compute_posterior: bool = True,
270
+ **kwargs,
271
+ ) -> Dict[str, torch.Tensor]:
272
+ conditions = {
273
+ "vae": vae,
274
+ "image": image,
275
+ "video": video,
276
+ "generator": generator,
277
+ "compute_posterior": compute_posterior,
278
+ **kwargs,
279
+ }
280
+ input_keys = set(conditions.keys())
281
+ conditions = super().prepare_latents(**conditions)
282
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
283
+ return conditions
284
+
285
+ def forward(
286
+ self,
287
+ transformer: WanTransformer3DModel,
288
+ condition_model_conditions: Dict[str, torch.Tensor],
289
+ latent_model_conditions: Dict[str, torch.Tensor],
290
+ sigmas: torch.Tensor,
291
+ generator: Optional[torch.Generator] = None,
292
+ compute_posterior: bool = True,
293
+ **kwargs,
294
+ ) -> Tuple[torch.Tensor, ...]:
295
+ if compute_posterior:
296
+ latents = latent_model_conditions.pop("latents")
297
+ else:
298
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
299
+ latents = posterior.sample(generator=generator)
300
+ del posterior
301
+
302
+ noise = torch.zeros_like(latents).normal_(generator=generator)
303
+ noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
304
+
305
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
306
+ condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
307
+
308
+ timesteps = (sigmas.flatten() * 1000.0).long()
309
+
310
+ pred = transformer(
311
+ **latent_model_conditions,
312
+ **condition_model_conditions,
313
+ timestep=timesteps,
314
+ return_dict=False,
315
+ )[0]
316
+ target = FF.flow_match_target(noise, latents)
317
+
318
+ return pred, target, sigmas
319
+
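A hedged sketch of the objective computed in `forward`, assuming the conventional rectified-flow definitions for `FF.flow_match_xt` and `FF.flow_match_target` (x_t = (1 - sigma) * x_0 + sigma * noise, target = noise - x_0); the real helpers live in `finetrainers.functional`, and the transformer output is replaced by a random placeholder here.

```python
import torch

latents = torch.randn(2, 16, 4, 32, 32)          # stand-in for VAE latents
noise = torch.randn_like(latents)
sigmas = torch.rand(2, 1, 1, 1, 1)               # one sigma per sample, broadcastable

noisy_latents = (1.0 - sigmas) * latents + sigmas * noise  # assumed flow_match_xt
target = noise - latents                                   # assumed flow_match_target
timesteps = (sigmas.flatten() * 1000.0).long()             # same scaling as above

pred = torch.randn_like(target)  # placeholder for the transformer prediction
loss = torch.nn.functional.mse_loss(pred.float(), target.float())
print(loss.item(), timesteps.tolist())
```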
320
+ def validation(
321
+ self,
322
+ pipeline: WanPipeline,
323
+ prompt: str,
324
+ image: Optional[Image] = None,
325
+ height: Optional[int] = None,
326
+ width: Optional[int] = None,
327
+ num_frames: Optional[int] = None,
328
+ num_inference_steps: int = 50,
329
+ generator: Optional[torch.Generator] = None,
330
+ **kwargs,
331
+ ) -> List[ArtifactType]:
332
+ if image is not None:
333
+ pipeline = WanImageToVideoPipeline.from_pipe(pipeline)
334
+
335
+ generation_kwargs = {
336
+ "prompt": prompt,
337
+ "image": image,
338
+ "height": height,
339
+ "width": width,
340
+ "num_frames": num_frames,
341
+ "num_inference_steps": num_inference_steps,
342
+ "generator": generator,
343
+ "return_dict": True,
344
+ "output_type": "pil",
345
+ }
346
+ generation_kwargs = get_non_null_items(generation_kwargs)
347
+ video = pipeline(**generation_kwargs).frames[0]
348
+ return [data.VideoArtifact(value=video)]
349
+
350
+ def _save_lora_weights(
351
+ self,
352
+ directory: str,
353
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
354
+ scheduler: Optional[SchedulerType] = None,
355
+ *args,
356
+ **kwargs,
357
+ ) -> None:
358
+ # TODO(aryan): this needs refactoring
359
+ if transformer_state_dict is not None:
360
+ WanPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
361
+ if scheduler is not None:
362
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
363
+
364
+ def _save_model(
365
+ self,
366
+ directory: str,
367
+ transformer: WanTransformer3DModel,
368
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
369
+ scheduler: Optional[SchedulerType] = None,
370
+ ) -> None:
371
+ # TODO(aryan): this needs refactoring
372
+ if transformer_state_dict is not None:
373
+ with init_empty_weights():
374
+ transformer_copy = WanTransformer3DModel.from_config(transformer.config)
375
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
376
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
377
+ if scheduler is not None:
378
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
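For context, a minimal text-to-video sampling sketch along the lines of `validation` above. It downloads the Wan-AI/Wan2.1-T2V-1.3B-Diffusers weights and needs a large GPU, so treat it as illustrative rather than something run during training.

```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="a cat walking through a garden",
    height=480,
    width=832,
    num_frames=81,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "sample.mp4", fps=16)
```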
finetrainers/optimizer.py ADDED
@@ -0,0 +1,449 @@
1
+ import functools
2
+ import math
3
+ from typing import Any, Callable, Dict, List, Optional, Type, Union
4
+
5
+ import torch
6
+ from torch.distributed.checkpoint.state_dict import (
7
+ StateDictOptions,
8
+ get_optimizer_state_dict,
9
+ set_optimizer_state_dict,
10
+ )
11
+ from torch.distributed.checkpoint.stateful import Stateful
12
+
13
+ from .parallel import ParallelBackendEnum
14
+ from .utils.import_utils import is_bitsandbytes_available
15
+
16
+
17
+ class OptimizerWrapper(Stateful):
18
+ r"""
19
+ Optimizer wrapper that:
20
+ - allows step/zero_grad on multiple optimizers needed for virtual pipeline stages
21
+ - saves/loads the optimizer state_dict at checkpoint time
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ model_parts: List[torch.nn.Module],
27
+ optimizer_cls: Type[torch.optim.Optimizer],
28
+ optimizer_kwargs: Dict[str, Any],
29
+ ) -> None:
30
+ self.optimizer_cls = optimizer_cls
31
+ self.optimizer_kwargs = optimizer_kwargs
32
+
33
+ self.optimizers = []
34
+ self.model_parts = model_parts
35
+
36
+ for model in self.model_parts:
37
+ optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
38
+ self.optimizers.append(optimizer)
39
+
40
+ def step(self) -> None:
41
+ for optimizer in self.optimizers:
42
+ optimizer.step()
43
+
44
+ def zero_grad(self) -> None:
45
+ for optimizer in self.optimizers:
46
+ optimizer.zero_grad()
47
+
48
+ def state_dict(self) -> Dict[str, Any]:
49
+ func = functools.partial(
50
+ get_optimizer_state_dict,
51
+ options=StateDictOptions(flatten_optimizer_state_dict=True),
52
+ )
53
+ return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()}
54
+
55
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
56
+ func = functools.partial(
57
+ set_optimizer_state_dict,
58
+ optim_state_dict=state_dict,
59
+ options=StateDictOptions(flatten_optimizer_state_dict=True),
60
+ )
61
+ list(map(func, self.model_parts, self.optimizers))
62
+
63
+
64
+ class SchedulerWrapper:
65
+ def __init__(
66
+ self, optimizers, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int
67
+ ) -> None:
68
+ self.schedulers = []
69
+ for optimizer in optimizers:
70
+ self.schedulers.append(torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch))
71
+
72
+ def step(self) -> None:
73
+ for scheduler in self.schedulers:
74
+ scheduler.step()
75
+
76
+ def get_last_lr(self) -> List[float]:
77
+ # TODO(aryan): look into this later. Currently calling it leads to NCCL hang?????
78
+ return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)}
79
+
80
+ def get_lr_scheduler_state(self) -> Dict[str, Any]:
81
+ state_dict = {}
82
+ if len(self.schedulers) == 1:
83
+ state_dict["lr_scheduler"] = self.schedulers[0]
84
+ else:
85
+ # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
86
+ # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
87
+ for idx, lr_scheduler in enumerate(self.schedulers):
88
+ state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
89
+ return state_dict
90
+
91
+
92
+ def get_optimizer(
93
+ parallel_backend: ParallelBackendEnum,
94
+ name: str,
95
+ model_parts: List[torch.nn.Module],
96
+ learning_rate: float = 1e-3,
97
+ beta1: float = 0.9,
98
+ beta2: float = 0.95,
99
+ beta3: float = 0.999,
100
+ epsilon: float = 1e-8,
101
+ weight_decay: float = 1e-4,
102
+ fused: bool = False,
103
+ ) -> Union[torch.optim.Optimizer, OptimizerWrapper]:
104
+ name = name.lower()
105
+
106
+ _raise_errors_if_packages_not_available(name)
107
+
108
+ if name == "adam":
109
+ optimizer_cls = torch.optim.Adam
110
+ optimizer_kwargs = {
111
+ "lr": learning_rate,
112
+ "betas": (beta1, beta2),
113
+ "eps": epsilon,
114
+ "weight_decay": weight_decay,
115
+ "fused": fused,
116
+ }
117
+ elif name == "adamw":
118
+ optimizer_cls = torch.optim.AdamW
119
+ optimizer_kwargs = {
120
+ "lr": learning_rate,
121
+ "betas": (beta1, beta2),
122
+ "eps": epsilon,
123
+ "weight_decay": weight_decay,
124
+ "fused": fused,
125
+ }
126
+ elif name == "adam-bnb":
127
+ from bitsandbytes.optim import Adam
128
+
129
+ optimizer_cls = Adam
130
+ optimizer_kwargs = {
131
+ "lr": learning_rate,
132
+ "betas": (beta1, beta2),
133
+ "eps": epsilon,
134
+ "weight_decay": weight_decay,
135
+ }
136
+ elif name == "adamw-bnb":
137
+ from bitsandbytes.optim import AdamW
138
+
139
+ optimizer_cls = AdamW
140
+ optimizer_kwargs = {
141
+ "lr": learning_rate,
142
+ "betas": (beta1, beta2),
143
+ "eps": epsilon,
144
+ "weight_decay": weight_decay,
145
+ }
146
+ elif name == "adam-bnb-8bit":
147
+ from bitsandbytes.optim import Adam8bit
148
+
149
+ optimizer_cls = Adam8bit
150
+ optimizer_kwargs = {
151
+ "lr": learning_rate,
152
+ "betas": (beta1, beta2),
153
+ "eps": epsilon,
154
+ "weight_decay": weight_decay,
155
+ }
156
+ elif name == "adamw-bnb-8bit":
157
+ from bitsandbytes.optim import AdamW8bit
158
+
159
+ optimizer_cls = AdamW8bit
160
+ optimizer_kwargs = {
161
+ "lr": learning_rate,
162
+ "betas": (beta1, beta2),
163
+ "eps": epsilon,
164
+ "weight_decay": weight_decay,
165
+ }
166
+
167
+ # TODO(aryan): handle bitsandbytes and torchao
168
+ else:
169
+ raise ValueError(f"Unsupported optimizer: {name}")
170
+
171
+ if parallel_backend == ParallelBackendEnum.ACCELERATE:
172
+ return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs)
173
+ elif parallel_backend == ParallelBackendEnum.PTD:
174
+ return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs)
175
+
176
+
177
+ def get_optimizer_accelerate(
178
+ model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
179
+ ) -> torch.optim.Optimizer:
180
+ params = [param for model in model_parts for param in model.parameters() if param.requires_grad]
181
+ optimizer = optimizer_cls(params, **optimizer_kwargs)
182
+ return optimizer
183
+
184
+
185
+ def get_optimizer_ptd(
186
+ model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
187
+ ) -> OptimizerWrapper:
188
+ return OptimizerWrapper(model_parts, optimizer_cls, optimizer_kwargs)
189
+
190
+
191
+ def get_lr_scheduler(
192
+ parallel_backend: ParallelBackendEnum,
193
+ name: str,
194
+ optimizer: Union[torch.optim.Optimizer, OptimizerWrapper],
195
+ step_rules: Optional[str] = None,
196
+ num_warmup_steps: Optional[int] = None,
197
+ num_training_steps: Optional[int] = None,
198
+ num_cycles: int = 1,
199
+ power: float = 1.0,
200
+ lr_init: float = 1e-3,
201
+ lr_end: float = 1e-7,
202
+ last_epoch: int = -1,
203
+ ) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]:
204
+ name = name.lower()
205
+ if name == "constant":
206
+ scheduler_lambda_fn = get_constant_schedule()
207
+ elif name == "constant_with_warmup":
208
+ scheduler_lambda_fn = get_constant_schedule_with_warmup(num_warmup_steps)
209
+ elif name == "piecewise_constant":
210
+ scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules)
211
+ elif name == "linear":
212
+ scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps)
213
+ elif name == "cosine":
214
+ scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles)
215
+ elif name == "cosine_with_restarts":
216
+ scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup(
217
+ num_warmup_steps, num_training_steps, num_cycles
218
+ )
219
+ elif name == "polynomial":
220
+ scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup(
221
+ num_warmup_steps, num_training_steps, lr_init, lr_end, power
222
+ )
223
+ else:
224
+ raise ValueError(f"Unsupported scheduler: {name}")
225
+
226
+ if parallel_backend == ParallelBackendEnum.ACCELERATE:
227
+ return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch)
228
+ elif parallel_backend == ParallelBackendEnum.PTD:
229
+ return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch)
230
+
231
+
232
+ def get_lr_scheduler_accelerate(
233
+ optimizer: torch.optim.Optimizer,
234
+ scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler],
235
+ last_epoch: int = -1,
236
+ ) -> torch.optim.lr_scheduler.LambdaLR:
237
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)
238
+ return scheduler
239
+
240
+
241
+ def get_lr_scheduler_ptd(
242
+ optimizer: OptimizerWrapper, scheduler_lambda_fn: Type[torch.optim.lr_scheduler.LRScheduler], last_epoch: int = -1
243
+ ) -> SchedulerWrapper:
244
+ return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch)
245
+
246
+
247
+ # ==============================
248
+ # Adapted from https://github.com/huggingface/diffusers/blob/196aef5a6f76e1ad6ba889184860c3633d166910/src/diffusers/optimization.py
249
+ # ==============================
250
+
251
+
252
+ def get_constant_schedule() -> Callable[[int], float]:
253
+ r"""
254
+ Create a schedule with a constant learning rate, using the learning rate set in the optimizer.
255
+ """
256
+
257
+ def lr_lambda(current_step: int):
258
+ return 1.0
259
+
260
+ return lr_lambda
261
+
262
+
263
+ def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]:
264
+ r"""
265
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
266
+ increases linearly between 0 and the initial lr set in the optimizer.
267
+
268
+ Args:
269
+ num_warmup_steps (`int`):
270
+ The number of steps for the warmup phase.
271
+ """
272
+
273
+ def lr_lambda(current_step: int):
274
+ if current_step < num_warmup_steps:
275
+ return float(current_step) / float(max(1.0, num_warmup_steps))
276
+ return 1.0
277
+
278
+ return lr_lambda
279
+
280
+
281
+ def get_piecewise_constant_schedule(step_rules: str) -> Callable[[int], float]:
282
+ r"""
283
+ Create a piecewise constant learning rate schedule, scaling the learning rate set in the optimizer according to the given step rules.
284
+
285
+ Args:
286
+ step_rules (`string`):
287
+ The rules for the learning rate multiplier. For example, step_rules="1:10,0.1:20,0.01:30,0.005" means the
+ learning rate is multiplied by 1 for the first 10 steps, by 0.1 for the next 20 steps, by 0.01 for the next
+ 30 steps, and by 0.005 for all remaining steps.
290
+ """
291
+
292
+ rules_dict = {}
293
+ rule_list = step_rules.split(",")
294
+ for rule_str in rule_list[:-1]:
295
+ value_str, steps_str = rule_str.split(":")
296
+ steps = int(steps_str)
297
+ value = float(value_str)
298
+ rules_dict[steps] = value
299
+ last_lr_multiple = float(rule_list[-1])
300
+
301
+ def create_rules_function(rules_dict, last_lr_multiple):
302
+ def rule_func(steps: int) -> float:
303
+ sorted_steps = sorted(rules_dict.keys())
304
+ for i, sorted_step in enumerate(sorted_steps):
305
+ if steps < sorted_step:
306
+ return rules_dict[sorted_steps[i]]
307
+ return last_lr_multiple
308
+
309
+ return rule_func
310
+
311
+ rules_func = create_rules_function(rules_dict, last_lr_multiple)
312
+ return rules_func
313
+
314
+
315
+ def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]:
316
+ r"""
317
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
318
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
319
+
320
+ Args:
321
+ num_warmup_steps (`int`):
322
+ The number of steps for the warmup phase.
323
+ num_training_steps (`int`):
324
+ The total number of training steps.
325
+ """
326
+
327
+ def lr_lambda(current_step: int):
328
+ if current_step < num_warmup_steps:
329
+ return float(current_step) / float(max(1, num_warmup_steps))
330
+ return max(
331
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
332
+ )
333
+
334
+ return lr_lambda
335
+
336
+
337
+ def get_cosine_schedule_with_warmup(
338
+ num_warmup_steps: int,
339
+ num_training_steps: int,
340
+ num_cycles: float = 0.5,
341
+ ) -> Callable[[int], float]:
342
+ r"""
343
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
344
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
345
+ initial lr set in the optimizer.
346
+
347
+ Args:
348
+ num_warmup_steps (`int`):
349
+ The number of steps for the warmup phase.
350
+ num_training_steps (`int`):
351
+ The total number of training steps.
352
+ num_cycles (`float`, *optional*, defaults to 0.5):
353
+ The number of periods of the cosine function in a schedule (the default is to just decrease from the max
354
+ value to 0 following a half-cosine).
355
+ """
356
+
357
+ def lr_lambda(current_step):
358
+ if current_step < num_warmup_steps:
359
+ return float(current_step) / float(max(1, num_warmup_steps))
360
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
361
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
362
+
363
+ return lr_lambda
364
+
365
+
366
+ def get_cosine_with_hard_restarts_schedule_with_warmup(
367
+ num_warmup_steps: int,
368
+ num_training_steps: int,
369
+ num_cycles: int = 1,
370
+ ) -> Callable[[int], float]:
371
+ r"""
372
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
373
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
374
+ linearly between 0 and the initial lr set in the optimizer.
375
+
376
+ Args:
377
+ num_warmup_steps (`int`):
378
+ The number of steps for the warmup phase.
379
+ num_training_steps (`int`):
380
+ The total number of training steps.
381
+ num_cycles (`int`, *optional*, defaults to 1):
382
+ The number of hard restarts to use.
383
+ """
384
+
385
+ def lr_lambda(current_step):
386
+ if current_step < num_warmup_steps:
387
+ return float(current_step) / float(max(1, num_warmup_steps))
388
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
389
+ if progress >= 1.0:
390
+ return 0.0
391
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
392
+
393
+ return lr_lambda
394
+
395
+
396
+ def get_polynomial_decay_schedule_with_warmup(
397
+ num_warmup_steps: int,
398
+ num_training_steps: int,
399
+ lr_init: float,
400
+ lr_end: float = 1e-7,
401
+ power: float = 1.0,
402
+ ) -> Callable[[int], float]:
403
+ r"""
404
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
405
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
406
+ initial lr set in the optimizer.
407
+
408
+ Args:
409
+ num_warmup_steps (`int`):
410
+ The number of steps for the warmup phase.
411
+ num_training_steps (`int`):
412
+ The total number of training steps.
413
+ lr_end (`float`, *optional*, defaults to 1e-7):
414
+ The end LR.
415
+ power (`float`, *optional*, defaults to 1.0):
416
+ Power factor.
417
+
418
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at
419
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
420
+ """
421
+
422
+ if not (lr_init > lr_end):
423
+ raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
424
+
425
+ def lr_lambda(current_step: int):
426
+ if current_step < num_warmup_steps:
427
+ return float(current_step) / float(max(1, num_warmup_steps))
428
+ elif current_step > num_training_steps:
429
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
430
+ else:
431
+ lr_range = lr_init - lr_end
432
+ decay_steps = num_training_steps - num_warmup_steps
433
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
434
+ decay = lr_range * pct_remaining**power + lr_end
435
+ return decay / lr_init # as LambdaLR multiplies by lr_init
436
+
437
+ return lr_lambda
438
+
439
+
440
+ def _raise_errors_if_packages_not_available(name: str) -> None:
441
+ name_split = name.split("-")
442
+ if len(name_split) < 2:
443
+ return
444
+ package_name = name_split[1]
445
+ if package_name == "bnb":
446
+ if not is_bitsandbytes_available():
447
+ raise ImportError(
448
+ f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer."
449
+ )
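To see how these lambda factories are consumed, here is a self-contained sketch mirroring `get_lr_scheduler_accelerate`: the factory returns a multiplier that `torch.optim.lr_scheduler.LambdaLR` applies on top of the optimizer's base learning rate. The warmup helper below is a simplified stand-in for `get_constant_schedule_with_warmup`.

```python
import torch

def warmup_then_constant(num_warmup_steps):
    # simplified stand-in for get_constant_schedule_with_warmup above
    def lr_lambda(step):
        return min(1.0, step / max(1, num_warmup_steps))
    return lr_lambda

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_then_constant(10))
for _ in range(5):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # base lr scaled by the current warmup multiplier
```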
finetrainers/parallel/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ from enum import Enum
2
+ from typing import Union
3
+
4
+ from .accelerate import AccelerateParallelBackend
5
+ from .ptd import PytorchDTensorParallelBackend
6
+ from .utils import apply_ddp_ptd, apply_fsdp2_ptd, dist_max, dist_mean
7
+
8
+
9
+ ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend]
10
+
11
+
12
+ class ParallelBackendEnum(str, Enum):
13
+ ACCELERATE = "accelerate"
14
+ PTD = "ptd"
15
+
16
+
17
+ def get_parallel_backend_cls(backend: ParallelBackendEnum) -> ParallelBackendType:
18
+ if backend == ParallelBackendEnum.ACCELERATE:
19
+ return AccelerateParallelBackend
20
+ if backend == ParallelBackendEnum.PTD:
21
+ return PytorchDTensorParallelBackend
22
+ raise ValueError(f"Unknown parallel backend: {backend}")
finetrainers/parallel/accelerate.py ADDED
@@ -0,0 +1,218 @@
1
+ import datetime
2
+ import pathlib
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from diffusers.utils import is_accelerate_available
7
+
8
+ from ..logging import get_logger
9
+ from ..utils import get_device_info
10
+ from .base import BaseParallelBackend
11
+ from .utils import apply_ddp_accelerate
12
+
13
+
14
+ if not is_accelerate_available():
15
+ raise ImportError(
16
+ "Please install the accelerate package using `pip install accelerate` to use the AccelerateParallelBackend."
17
+ )
18
+
19
+ from accelerate import Accelerator
20
+ from accelerate.data_loader import DataLoader
21
+ from accelerate.utils import (
22
+ DataLoaderConfiguration,
23
+ DistributedDataParallelKwargs,
24
+ InitProcessGroupKwargs,
25
+ ProjectConfiguration,
26
+ )
27
+
28
+
29
+ logger = get_logger()
30
+ _device_type, _device_module = get_device_info()
31
+
32
+
33
+ class AccelerateParallelBackend(BaseParallelBackend):
34
+ def __init__(
35
+ self,
36
+ world_size: int,
37
+ pp_degree: int = 1,
38
+ dp_degree: int = 1,
39
+ dp_shards: int = -1,
40
+ cp_degree: int = 1,
41
+ tp_degree: int = 1,
42
+ backend: str = "nccl",
43
+ timeout: int = 180,
44
+ logging_dir: Optional[str] = None,
45
+ output_dir: Optional[str] = None,
46
+ gradient_accumulation_steps: Optional[int] = None,
47
+ ) -> None:
48
+ super().__init__()
49
+
50
+ self._world_size = world_size
51
+ self._pp_degree = pp_degree
52
+ self._dp_degree = dp_degree
53
+ self._dp_shards = dp_shards
54
+ self._cp_degree = cp_degree
55
+ self._tp_degree = tp_degree
56
+ self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
57
+ self._logging_dir = (
58
+ self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
59
+ )
60
+ self._backend = backend
61
+ self._timeout = timeout
62
+ self._gradient_accumulation_steps = gradient_accumulation_steps
63
+
64
+ if pp_degree > 1 or dp_shards > 1 or cp_degree > 1 or tp_degree > 1:
65
+ raise ValueError(
66
+ "AccelerateParallelBackend does not support anything but Distributed Data Parallelism at the moment."
67
+ )
68
+ if dp_degree != world_size:
69
+ raise ValueError("Data parallel degree must be equal to world size.")
70
+
71
+ self._accelerator: Accelerator = None
72
+ self._mesh: torch.distributed.DeviceMesh = None
73
+
74
+ def apply_ddp(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
75
+ project_config = None
76
+ ddp_kwargs = None
77
+ init_process_group_kwargs = None
78
+ if self._accelerator is None:
79
+ project_config = ProjectConfiguration(project_dir=self._output_dir, logging_dir=self._logging_dir)
80
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
81
+ dataloader_config = DataLoaderConfiguration(
82
+ split_batches=False, dispatch_batches=False, use_stateful_dataloader=True
83
+ )
84
+ init_process_group_kwargs = InitProcessGroupKwargs(
85
+ backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout)
86
+ )
87
+ self._accelerator, model = apply_ddp_accelerate(
88
+ model,
89
+ project_config,
90
+ ddp_kwargs,
91
+ init_process_group_kwargs,
92
+ dataloader_config,
93
+ self._gradient_accumulation_steps,
94
+ accelerator=self._accelerator,
95
+ )
96
+ logger.debug("Applied AccelerateParallel::apply_ddp to model.")
97
+ return model
98
+
99
+ def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
100
+ logger.debug("AccelerateParallelBackend::prepare_dataset completed!")
101
+ return dataset
102
+
103
+ def prepare_dataloader(
104
+ self,
105
+ dataset: torch.utils.data.IterableDataset,
106
+ batch_size: int = 1,
107
+ num_workers: int = 0,
108
+ pin_memory: bool = False,
109
+ ) -> DataLoader:
110
+ dataloader = torch.utils.data.DataLoader(
111
+ dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
112
+ )
113
+ dataloader = self._accelerator.prepare_data_loader(dataloader)
114
+ logger.debug("AccelerateParallelBackend::prepare_dataloader completed!")
115
+ return dataloader
116
+
117
+ def prepare_optimizer(self, optimizer, lr_scheduler):
118
+ optimizer = self._accelerator.prepare_optimizer(optimizer)
119
+ lr_scheduler = self._accelerator.prepare_scheduler(lr_scheduler)
120
+ return optimizer, lr_scheduler
121
+
122
+ def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
123
+ def _get_mesh():
124
+ if name is None:
125
+ return self._mesh
126
+ try:
127
+ return self._mesh[name]
128
+ except (KeyError, RuntimeError):
129
+ return self._mesh
130
+
131
+ if self._mesh is not None:
132
+ return _get_mesh()
133
+
134
+ mesh_list = [("dp_replicate", self._dp_degree), ("dp_shard", self._dp_shards)]
135
+ mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
136
+ names = [x[0] for x in mesh_list]
137
+ degrees = [x[1] for x in mesh_list]
138
+ mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)
139
+
140
+ dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []
141
+
142
+ if self.data_replication_enabled:
143
+ dp_mesh_names.append("dp_replicate")
144
+ dp_cp_mesh_names.append("dp_replicate")
145
+ if self.data_sharding_enabled:
146
+ dp_mesh_names.append("dp_shard")
147
+ dp_cp_mesh_names.append("dp_shard")
148
+ dp_shard_cp_mesh_names.append("dp_shard")
149
+ if self.context_parallel_enabled:
150
+ dp_cp_mesh_names.append("cp")
151
+ dp_shard_cp_mesh_names.append("cp")
152
+
153
+ if len(dp_mesh_names) > 0:
154
+ mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
155
+ if len(dp_cp_mesh_names) > 0:
156
+ mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
157
+ if len(dp_shard_cp_mesh_names) > 0:
158
+ mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")
159
+
160
+ logger.debug(f"Device mesh: {mesh}")
161
+ self._mesh = mesh
162
+ return _get_mesh()
163
+
164
+ @property
165
+ def world_size(self):
166
+ return self._accelerator.num_processes
167
+
168
+ @property
169
+ def rank(self):
170
+ return self._accelerator.process_index
171
+
172
+ @property
173
+ def local_rank(self):
174
+ return self._accelerator.local_process_index
175
+
176
+ @property
177
+ def is_main_process(self):
178
+ r"""Returns `True` if the current process is the main process on the master node."""
179
+ return self._accelerator.is_main_process
180
+
181
+ @property
182
+ def is_local_main_process(self):
183
+ r"""Returns `True` if the current process is the main process on local node."""
184
+ return self._accelerator.is_local_main_process
185
+
186
+ @property
187
+ def device(self):
188
+ return self._accelerator.device
189
+
190
+ def wait_for_everyone(self):
191
+ self._accelerator.wait_for_everyone()
192
+
193
+ def destroy(self):
194
+ self._accelerator.end_training()
195
+
196
+ @property
197
+ def pipeline_parallel_enabled(self):
198
+ return self._pp_degree > 1
199
+
200
+ @property
201
+ def data_parallel_enabled(self):
202
+ return self._dp_degree > 1 or self._dp_shards > 1
203
+
204
+ @property
205
+ def data_replication_enabled(self):
206
+ return self._dp_degree > 1
207
+
208
+ @property
209
+ def data_sharding_enabled(self):
210
+ return self._dp_shards > 1
211
+
212
+ @property
213
+ def context_parallel_enabled(self):
214
+ return self._cp_degree > 1
215
+
216
+ @property
217
+ def tensor_parallel_enabled(self):
218
+ return self._tp_degree > 1
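The mesh construction above only keeps dimensions whose degree is greater than one; a toy illustration of that filtering in plain Python (no process group needed):

```python
degrees = {"pp": 1, "dp_replicate": 2, "dp_shard": 4, "cp": 1, "tp": 1}
mesh_list = [(name, degree) for name, degree in degrees.items() if degree > 1]
names = [name for name, _ in mesh_list]
shape = [degree for _, degree in mesh_list]
print(names, shape)  # ['dp_replicate', 'dp_shard'] [2, 4] -> a 2x4 device mesh
```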
finetrainers/parallel/base.py ADDED
@@ -0,0 +1,96 @@
1
+ from contextlib import contextmanager
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import torch
5
+
6
+ from ..trackers import TrackerType, initialize_trackers
7
+
8
+
9
+ class BaseParallelBackend:
10
+ r"""
11
+ Base class that contains properties and methods that should be implemented by different parallel backends.
12
+ """
13
+
14
+ def apply_ddp(self, *args, **kwargs) -> torch.nn.Module:
15
+ raise NotImplementedError("Method `apply_ddp` must be implemented by subclass.")
16
+
17
+ def prepare_dataset(self, *args, **kwargs) -> Any:
18
+ raise NotImplementedError("Method `prepare_dataset` must be implemented by subclass.")
19
+
20
+ def prepare_dataloader(self, *args, **kwargs) -> Any:
21
+ raise NotImplementedError("Method `prepare_dataloader` must be implemented by subclass.")
22
+
23
+ def prepare_optimizer(self, *args, **kwargs) -> Any:
24
+ raise NotImplementedError("Method `prepare_optimizer` must be implemented by subclass.")
25
+
26
+ def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
27
+ raise NotImplementedError("Method `get_mesh` must be implemented by subclass.")
28
+
29
+ def initialize_trackers(
30
+ self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
31
+ ) -> TrackerType:
32
+ self.tracker = None
33
+ if self.is_main_process:
34
+ self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir)
35
+
36
+ def log(self, metrics: Dict[str, Any], step: int) -> None:
37
+ if self.is_main_process:
38
+ self.tracker.log(metrics, step)
39
+
40
+ def wait_for_everyone(self):
41
+ raise NotImplementedError("Method `wait_for_everyone` must be implemented by subclass.")
42
+
43
+ @contextmanager
44
+ def main_process_first(self):
45
+ raise NotImplementedError("Method `main_process_first` must be implemented by subclass.")
46
+
47
+ def destroy(self):
48
+ raise NotImplementedError("Method `destroy` must be implemented by subclass.")
49
+
50
+ @property
51
+ def world_size(self):
52
+ raise NotImplementedError("Method `world_size` must be implemented by subclass.")
53
+
54
+ @property
55
+ def rank(self):
56
+ raise NotImplementedError("Method `rank` must be implemented by subclass.")
57
+
58
+ @property
59
+ def local_rank(self):
60
+ raise NotImplementedError("Method `local_rank` must be implemented by subclass.")
61
+
62
+ @property
63
+ def is_main_process(self):
64
+ raise NotImplementedError("Method `is_main_process` must be implemented by subclass.")
65
+
66
+ @property
67
+ def is_local_main_process(self):
68
+ raise NotImplementedError("Method `is_local_main_process` must be implemented by subclass.")
69
+
70
+ @property
71
+ def device(self):
72
+ raise NotImplementedError("Method `device` must be implemented by subclass.")
73
+
74
+ @property
75
+ def pipeline_parallel_enabled(self):
76
+ raise NotImplementedError("Property `pipeline_parallel_enabled` must be implemented by subclass.")
77
+
78
+ @property
79
+ def data_parallel_enabled(self):
80
+ raise NotImplementedError("Property `data_parallel_enabled` must be implemented by subclass.")
81
+
82
+ @property
83
+ def data_replication_enabled(self):
84
+ raise NotImplementedError("Property `data_replication_enabled` must be implemented by subclass.")
85
+
86
+ @property
87
+ def data_sharding_enabled(self):
88
+ raise NotImplementedError("Property `data_sharding_enabled` must be implemented by subclass.")
89
+
90
+ @property
91
+ def context_parallel_enabled(self):
92
+ raise NotImplementedError("Property `context_parallel_enabled` must be implemented by subclass.")
93
+
94
+ @property
95
+ def tensor_parallel_enabled(self):
96
+ raise NotImplementedError("Property `tensor_parallel_enabled` must be implemented by subclass.")
finetrainers/parallel/deepspeed.py ADDED
@@ -0,0 +1,7 @@
1
+ from .base import BaseParallelBackend
2
+
3
+
4
+ class DeepspeedParallelBackend(BaseParallelBackend):
5
+ def __init__(self):
6
+ # TODO(aryan)
7
+ raise NotImplementedError("DeepspeedParallelBackend is not implemented yet.")
finetrainers/parallel/ptd.py ADDED
@@ -0,0 +1,228 @@
1
+ import datetime
2
+ import os
3
+ import pathlib
4
+ from typing import Optional
5
+
6
+ import datasets.distributed
7
+ import torch
8
+
9
+ from ..data import DPDataLoader
10
+ from ..logging import get_logger
11
+ from ..utils import get_device_info
12
+ from .base import BaseParallelBackend
13
+ from .utils import apply_ddp_ptd
14
+
15
+
16
+ _device_type, _device_module = get_device_info()
17
+ logger = get_logger()
18
+
19
+
20
+ class PytorchDTensorParallelBackend(BaseParallelBackend):
21
+ def __init__(
22
+ self,
23
+ world_size: int,
24
+ pp_degree: int = 1,
25
+ dp_degree: int = 1,
26
+ dp_shards: int = -1,
27
+ cp_degree: int = 1,
28
+ tp_degree: int = 1,
29
+ backend: str = "nccl",
30
+ timeout: int = 180,
31
+ logging_dir: Optional[str] = None,
32
+ output_dir: Optional[str] = None,
33
+ gradient_accumulation_steps: Optional[int] = None,
34
+ ) -> None:
35
+ super().__init__()
36
+
37
+ self._world_size = world_size
38
+ self._pp_degree = pp_degree
39
+ self._dp_degree = dp_degree
40
+ self._dp_shards = dp_shards
41
+ self._cp_degree = cp_degree
42
+ self._tp_degree = tp_degree
43
+ self._output_dir = pathlib.Path(output_dir) if output_dir is not None else None
44
+ self._logging_dir = (
45
+ self._output_dir / logging_dir if output_dir is not None and logging_dir is not None else None
46
+ )
47
+ self._backend = backend
48
+ self._timeout = timeout
49
+
50
+ for degree in [pp_degree, dp_degree, dp_shards, cp_degree, tp_degree]:
51
+ if degree < 1:
52
+ raise ValueError(f"Parallel degree must be at least 1, got {degree}.")
53
+
54
+ if dp_shards * pp_degree * dp_degree * cp_degree * tp_degree != world_size:
55
+ raise ValueError(
56
+ f"World size {world_size} must be divisible by the product of all parallel degrees and data parallel shards."
57
+ )
58
+
59
+ torch.distributed.init_process_group(backend=self._backend, timeout=datetime.timedelta(seconds=self._timeout))
60
+ _device_module.set_device(self.local_rank)
61
+
62
+ logger.info(
63
+ f"Initialized parallel state with:\n"
64
+ f" - World size: {world_size}\n"
65
+ f" - Pipeline parallel degree: {pp_degree}\n"
66
+ f" - Data parallel degree: {dp_degree}\n"
67
+ f" - Context parallel degree: {cp_degree}\n"
68
+ f" - Tensor parallel degree: {tp_degree}\n"
69
+ f" - Data parallel shards: {dp_shards}\n"
70
+ )
71
+
72
+ self._mesh: torch.distributed.DeviceMesh = None
73
+
74
+ def apply_ddp(
75
+ self, model: torch.nn.Module, device_mesh: Optional[torch.distributed.DeviceMesh] = None
76
+ ) -> torch.nn.Module:
77
+ if device_mesh is None:
78
+ device_mesh = self.get_mesh()
79
+ apply_ddp_ptd(model, device_mesh)
80
+ logger.debug("Applied PytorchDTensorParallel::apply_ddp to model.")
81
+ return model
82
+
83
+ def prepare_dataset(self, dataset: torch.utils.data.IterableDataset) -> torch.utils.data.IterableDataset:
84
+ dp_mesh = self.get_mesh("dp_replicate")
85
+ if dp_mesh is None:
86
+ dp_mesh = self.get_mesh()
87
+ if self.world_size > 1:
88
+ dp_local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
89
+ else:
90
+ dp_local_rank, dp_world_size = 0, 1
91
+ dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, dp_local_rank, dp_world_size)
92
+ logger.debug("PytorchDTensorParallelBackend::prepare_dataset completed!")
93
+ return dataset
94
+
95
+ def prepare_dataloader(
96
+ self, dataset: torch.utils.data.IterableDataset, batch_size: int, num_workers: int, pin_memory: bool
97
+ ) -> DPDataLoader:
98
+ dp_mesh = self.get_mesh("dp_replicate")
99
+ if dp_mesh is None:
100
+ dp_mesh = self.get_mesh()
101
+ if self.world_size > 1:
102
+ dp_local_rank = dp_mesh.get_local_rank()
103
+ else:
104
+ dp_local_rank = 0
105
+ dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers)
106
+ logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!")
107
+ return dataloader
108
+
109
+ def prepare_optimizer(self, optimizer, lr_scheduler):
110
+ logger.debug("PytorchDTensorParallelBackend::prepare_optimizer completed!")
111
+ return optimizer, lr_scheduler
112
+
113
+ def get_mesh(self, name: Optional[str] = None) -> torch.distributed.DeviceMesh:
114
+ def _get_mesh():
115
+ if name is None:
116
+ return self._mesh
117
+ try:
118
+ return self._mesh[name]
119
+ except (KeyError, RuntimeError):
120
+ if self._mesh.ndim == 0:
121
+ return None
122
+ return self._mesh
123
+
124
+ if self._mesh is not None:
125
+ return _get_mesh()
126
+
127
+ mesh_list = [
128
+ ("pp", self._pp_degree),
129
+ ("dp_replicate", self._dp_degree),
130
+ ("dp_shard", self._dp_shards),
131
+ ("cp", self._cp_degree),
132
+ ("tp", self._tp_degree),
133
+ ]
134
+ mesh_list = [(name, degree) for name, degree in mesh_list if degree > 1]
135
+ names = [x[0] for x in mesh_list]
136
+ degrees = [x[1] for x in mesh_list]
137
+ mesh = torch.distributed.device_mesh.init_device_mesh(_device_type, mesh_shape=degrees, mesh_dim_names=names)
138
+
139
+ dp_mesh_names, dp_cp_mesh_names, dp_shard_cp_mesh_names = [], [], []
140
+
141
+ if self.data_replication_enabled:
142
+ dp_mesh_names.append("dp_replicate")
143
+ dp_cp_mesh_names.append("dp_replicate")
144
+ if self.data_sharding_enabled:
145
+ dp_mesh_names.append("dp_shard")
146
+ dp_cp_mesh_names.append("dp_shard")
147
+ dp_shard_cp_mesh_names.append("dp_shard")
148
+ if self.context_parallel_enabled:
149
+ dp_cp_mesh_names.append("cp")
150
+ dp_shard_cp_mesh_names.append("cp")
151
+
152
+ if len(dp_mesh_names) > 0:
153
+ mesh[tuple(dp_mesh_names)]._flatten(mesh_dim_name="dp")
154
+ if len(dp_cp_mesh_names) > 0:
155
+ mesh[tuple(dp_cp_mesh_names)]._flatten(mesh_dim_name="dp_cp")
156
+ if len(dp_shard_cp_mesh_names) > 0:
157
+ mesh[tuple(dp_shard_cp_mesh_names)]._flatten(mesh_dim_name="dp_shard_cp")
158
+
159
+ logger.debug(f"Device mesh: {mesh}")
160
+ self._mesh = mesh
161
+ return _get_mesh()
162
+
163
+ @property
164
+ def world_size(self):
165
+ return torch.distributed.get_world_size()
166
+
167
+ @property
168
+ def rank(self):
169
+ return torch.distributed.get_rank()
170
+
171
+ @property
172
+ def local_rank(self):
173
+ return int(os.environ.get("LOCAL_RANK", 0))
174
+
175
+ @property
176
+ def is_main_process(self):
177
+ r"""Returns `True` if the current process is the main process on the master node."""
178
+ return self.rank == 0
179
+
180
+ @property
181
+ def is_local_main_process(self):
182
+ r"""Returns `True` if the current process is the main process on local node."""
183
+ return self.local_rank == 0
184
+
185
+ @property
186
+ def device(self):
187
+ return torch.device(_device_type, self.local_rank)
188
+
189
+ def wait_for_everyone(self):
190
+ return torch.distributed.barrier()
191
+
192
+ # @contextmanager
193
+ # def main_process_first(self):
194
+ # if self.is_main_process:
195
+ # yield
196
+ # self.wait_for_everyone()
197
+ # else:
198
+ # self.wait_for_everyone()
199
+ # yield
200
+
201
+ def destroy(self):
202
+ if self.is_main_process:
203
+ self.tracker.finish()
204
+ return torch.distributed.destroy_process_group()
205
+
206
+ @property
207
+ def pipeline_parallel_enabled(self):
208
+ return self._pp_degree > 1
209
+
210
+ @property
211
+ def data_parallel_enabled(self):
212
+ return self._dp_degree > 1 or self._dp_shards > 1
213
+
214
+ @property
215
+ def data_replication_enabled(self):
216
+ return self._dp_degree > 1
217
+
218
+ @property
219
+ def data_sharding_enabled(self):
220
+ return self._dp_shards > 1
221
+
222
+ @property
223
+ def context_parallel_enabled(self):
224
+ return self._cp_degree > 1
225
+
226
+ @property
227
+ def tensor_parallel_enabled(self):
228
+ return self._tp_degree > 1
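Note on the mesh construction above: get_mesh names each parallel dimension and later flattens combinations (dp, dp_cp, dp_shard_cp) with the private DeviceMesh._flatten helper. For readers unfamiliar with named device meshes, the following standalone sketch (not part of this commit; the 8-GPU layout and the dimension names are illustrative assumptions) shows the public API it builds on:

    # Hedged sketch: assumes 8 GPUs and a `torchrun --nproc_per_node=8` launch.
    from torch.distributed.device_mesh import init_device_mesh

    def build_example_mesh():
        # 2-way replication x 4-way sharding, mirroring two of the dims composed in get_mesh()
        mesh = init_device_mesh(
            "cuda", mesh_shape=(2, 4), mesh_dim_names=("dp_replicate", "dp_shard")
        )
        replicate_mesh = mesh["dp_replicate"]  # 1-D sub-mesh, e.g. for DDP-style replication
        shard_mesh = mesh["dp_shard"]          # 1-D sub-mesh, e.g. handed to fully_shard
        return mesh, replicate_mesh, shard_mesh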
finetrainers/parallel/utils.py ADDED
@@ -0,0 +1,99 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.distributed._functional_collectives as funcol
5
+ import torch.distributed.tensor
6
+ from diffusers.utils import is_accelerate_available
7
+ from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
8
+ from torch.distributed._composable.replicate import replicate
9
+
10
+ from ..utils._common import DIFFUSERS_TRANSFORMER_BLOCK_NAMES
11
+
12
+
13
+ if is_accelerate_available():
14
+ from accelerate import Accelerator
15
+ from accelerate.utils import (
16
+ DataLoaderConfiguration,
17
+ DistributedDataParallelKwargs,
18
+ InitProcessGroupKwargs,
19
+ ProjectConfiguration,
20
+ )
21
+
22
+
23
+ def apply_fsdp2_ptd(
24
+ model: torch.nn.Module,
25
+ dp_mesh: torch.distributed.device_mesh.DeviceMesh,
26
+ param_dtype: torch.dtype,
27
+ reduce_dtype: torch.dtype,
28
+ output_dtype: torch.dtype,
29
+ pp_enabled: bool = False,
30
+ cpu_offload: bool = False,
31
+ ) -> None:
32
+ r"""Apply FSDP2 on a model."""
33
+ mp_policy = MixedPrecisionPolicy(param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=True)
34
+ fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
35
+
36
+ if cpu_offload:
37
+ fsdp_config["offload_policy"] = CPUOffloadPolicy(pin_memory=True)
38
+
39
+ def apply_fully_shard(blocks):
40
+ for layer_index, block in enumerate(blocks):
41
+ if pp_enabled:
42
+ # For PP, do not reshard after forward to avoid per-microbatch
43
+ # all-gathers, which can be expensive and non-overlapped
44
+ reshard_after_forward = False
45
+ else:
46
+ # As an optimization, do not reshard after forward for the last
47
+ # transformer block since FSDP would prefetch it immediately
48
+ reshard_after_forward = layer_index < len(blocks) - 1
49
+ fully_shard(block, **fsdp_config, reshard_after_forward=reshard_after_forward)
50
+
51
+ for transformer_block_name in DIFFUSERS_TRANSFORMER_BLOCK_NAMES:
52
+ blocks = getattr(model, transformer_block_name, None)
53
+ if blocks is not None:
54
+ apply_fully_shard(blocks)
55
+
56
+ fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
57
+
58
+
59
+ def apply_ddp_accelerate(
60
+ model: torch.nn.Module,
61
+ project_config: Optional[ProjectConfiguration] = None,
62
+ ddp_kwargs: Optional[DistributedDataParallelKwargs] = None,
63
+ init_process_group_kwargs: Optional[InitProcessGroupKwargs] = None,
64
+ dataloader_config: Optional[DataLoaderConfiguration] = None,
65
+ gradient_accumulation_steps: Optional[int] = None,
66
+ accelerator: Optional[Accelerator] = None,
67
+ ) -> torch.nn.Module:
68
+ if accelerator is None:
69
+ accelerator = Accelerator(
70
+ project_config=project_config,
71
+ dataloader_config=dataloader_config,
72
+ gradient_accumulation_steps=gradient_accumulation_steps,
73
+ log_with=None,
74
+ kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
75
+ )
76
+ if torch.backends.mps.is_available():
77
+ accelerator.native_amp = False
78
+ accelerator.prepare_model(model)
79
+ return accelerator, model
80
+
81
+
82
+ def apply_ddp_ptd(model: torch.nn.Module, dp_mesh: torch.distributed.device_mesh.DeviceMesh) -> None:
83
+ replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
84
+
85
+
86
+ def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
87
+ if isinstance(x, torch.distributed.tensor.DTensor):
88
+ # functional collectives do not support DTensor inputs
89
+ x = x.full_tensor()
90
+ assert x.numel() == 1 # required by `.item()`
91
+ return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
92
+
93
+
94
+ def dist_max(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
95
+ return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.MAX.name, mesh=mesh)
96
+
97
+
98
+ def dist_mean(x: torch.Tensor, mesh: torch.distributed.device_mesh.DeviceMesh) -> float:
99
+ return dist_reduce(x, reduceOp=torch.distributed.distributed_c10d.ReduceOp.AVG.name, mesh=mesh)
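Usage illustration for the helpers in this file (a hedged sketch, not code from the commit): the toy module, the one-dimensional dp_shard mesh, and the assumption that transformer_blocks matches an entry of DIFFUSERS_TRANSFORMER_BLOCK_NAMES are all illustrative.

    # Hedged sketch: assumes a torchrun launch and an importable finetrainers package.
    import os
    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from finetrainers.parallel.utils import apply_fsdp2_ptd, dist_mean

    class TinyTransformer(torch.nn.Module):
        """Toy stand-in for a diffusers transformer (not a real model)."""
        def __init__(self):
            super().__init__()
            # assumed to match one of DIFFUSERS_TRANSFORMER_BLOCK_NAMES, so each
            # block is sharded individually before the root module is sharded
            self.transformer_blocks = torch.nn.ModuleList(torch.nn.Linear(64, 64) for _ in range(4))

    def shard_and_average_loss(world_size: int) -> float:
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
        mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("dp_shard",))
        model = TinyTransformer().cuda()
        apply_fsdp2_ptd(
            model, mesh,
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, output_dtype=torch.bfloat16,
        )
        local_loss = torch.tensor(0.5, device="cuda")  # hypothetical per-rank scalar
        return dist_mean(local_loss, mesh)             # same averaged value on every rank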
finetrainers/patches/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+
4
+ if TYPE_CHECKING:
5
+ from ..args import BaseArgs
6
+ from ..parallel import ParallelBackendType
7
+
8
+
9
+ def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBackendType") -> None:
10
+ # To avoid circular imports
11
+ from ..config import ModelType, TrainingType
12
+
13
+ if args.model_name == ModelType.LTX_VIDEO:
14
+ from .models.ltx_video import patch
15
+
16
+ patch.patch_transformer_forward()
17
+ if parallel_backend.tensor_parallel_enabled:
18
+ patch.patch_apply_rotary_emb_for_tp_compatibility()
19
+
20
+ if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
21
+ from .dependencies.peft import patch
22
+
23
+ patch.patch_peft_move_adapter_to_device_of_base_layer()
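A hedged sketch of the intended call site (the trainer code is not part of this excerpt, so the setup order shown here is an assumption): patches are applied once, after the parallel backend exists and before any models are instantiated.

    # Hedged sketch; args is a finetrainers BaseArgs and parallel_backend a backend
    # instance, matching the signature of perform_patches_for_training above.
    from finetrainers import patches

    def prepare_training(args, parallel_backend):
        patches.perform_patches_for_training(args, parallel_backend)
        # ... build the ModelSpecification, dataloaders and optimizer afterwards ...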
finetrainers/{patches.py → patches/dependencies/peft/patch.py} RENAMED
@@ -1,50 +1,25 @@
1
  import functools
2
 
3
- import torch
4
- from accelerate.logging import get_logger
5
  from peft.tuners.tuners_utils import BaseTunerLayer
6
 
7
- from .constants import FINETRAINERS_LOG_LEVEL
8
 
9
 
10
- logger = get_logger("finetrainers") # pylint: disable=invalid-name
11
- logger.setLevel(FINETRAINERS_LOG_LEVEL)
12
-
13
-
14
- def perform_peft_patches() -> None:
15
  _perform_patch_move_adapter_to_device_of_base_layer()
16
 
17
 
18
  def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
19
- # We don't patch the method for torch.float32 and torch.bfloat16 because it is okay to train with them. If the model weights
20
- # are in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2, we need to patch this method to avoid conversion of
21
- # LoRA weights from higher precision dtype.
22
  BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
23
  BaseTunerLayer._move_adapter_to_device_of_base_layer
24
  )
25
 
26
 
27
  def _patched_move_adapter_to_device_of_base_layer(func) -> None:
28
  @functools.wraps(func)
29
  def wrapper(self, *args, **kwargs):
30
  with DisableTensorToDtype():
31
  return func(self, *args, **kwargs)
32
 
33
  return wrapper
34
-
35
-
36
- class DisableTensorToDtype:
37
- def __enter__(self):
38
- self.original_to = torch.Tensor.to
39
-
40
- def modified_to(tensor, *args, **kwargs):
41
- # remove dtype from args if present
42
- args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
43
- if "dtype" in kwargs:
44
- kwargs.pop("dtype")
45
- return self.original_to(tensor, *args, **kwargs)
46
-
47
- torch.Tensor.to = modified_to
48
-
49
- def __exit__(self, exc_type, exc_val, exc_tb):
50
- torch.Tensor.to = self.original_to
 
1
  import functools
2
 
3
  from peft.tuners.tuners_utils import BaseTunerLayer
4
 
5
+ from ...utils import DisableTensorToDtype
6
 
7
 
8
+ def patch_peft_move_adapter_to_device_of_base_layer() -> None:
9
  _perform_patch_move_adapter_to_device_of_base_layer()
10
 
11
 
12
  def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
13
  BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
14
  BaseTunerLayer._move_adapter_to_device_of_base_layer
15
  )
16
 
17
 
18
  def _patched_move_adapter_to_device_of_base_layer(func) -> None:
19
+ # TODO(aryan): This is really unsafe probably and may break things. It works for now, but revisit and refactor.
20
  @functools.wraps(func)
21
  def wrapper(self, *args, **kwargs):
22
  with DisableTensorToDtype():
23
  return func(self, *args, **kwargs)
24
 
25
  return wrapper
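The rationale comment from the old patches.py (avoiding LoRA weights being cast down when base layers sit in float16 or float8) was dropped in the move; here is a hedged usage sketch of the relocated patch (the call ordering is an assumption based on perform_patches_for_training above):

    # Hedged sketch (not from the commit): the new import path after the rename.
    from finetrainers.patches.dependencies.peft.patch import (
        patch_peft_move_adapter_to_device_of_base_layer,
    )

    patch_peft_move_adapter_to_device_of_base_layer()
    # ... then apply layerwise upcasting (e.g. to torch.float8_e4m3fn) and add LoRA
    # adapters; the patched method still moves adapters to the base layer's device
    # but no longer re-casts their dtype.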
finetrainers/patches/models/ltx_video/patch.py ADDED
@@ -0,0 +1,127 @@
1
+ from typing import Any, Dict, Optional, Tuple
2
+
3
+ import diffusers
4
+ import torch
5
+ from diffusers import LTXVideoTransformer3DModel
6
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
7
+ from diffusers.utils.import_utils import is_torch_version
8
+
9
+
10
+ def patch_transformer_forward() -> None:
11
+ _perform_ltx_transformer_forward_patch()
12
+
13
+
14
+ def patch_apply_rotary_emb_for_tp_compatibility() -> None:
15
+ _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch()
16
+
17
+
18
+ def _perform_ltx_transformer_forward_patch() -> None:
19
+ LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3Dforward
20
+
21
+
22
+ def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
23
+ def apply_rotary_emb(x, freqs):
24
+ cos, sin = freqs
25
+ # ======== THIS IS CHANGED FROM THE ORIGINAL IMPLEMENTATION ========
26
+ # The change is made due to unsupported DTensor operation aten.ops.unbind
27
+ # FIXME: Once aten.ops.unbind support lands, this will no longer be required
28
+ # x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
29
+ x_real, x_imag = x.unflatten(2, (-1, 2)).chunk(2, dim=-1) # [B, S, H, D // 2]
30
+ # ==================================================================
31
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
32
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
33
+ return out
34
+
35
+ diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
36
+
37
+
38
+ def _patched_LTXVideoTransformer3Dforward(
39
+ self,
40
+ hidden_states: torch.Tensor,
41
+ encoder_hidden_states: torch.Tensor,
42
+ timestep: torch.LongTensor,
43
+ encoder_attention_mask: torch.Tensor,
44
+ num_frames: int,
45
+ height: int,
46
+ width: int,
47
+ rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
48
+ return_dict: bool = True,
49
+ *args,
50
+ **kwargs,
51
+ ) -> torch.Tensor:
52
+ image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
53
+
54
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
55
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
56
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
57
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
58
+
59
+ batch_size = hidden_states.size(0)
60
+
61
+ # ===== This is modified compared to Diffusers =====
62
+ # This is done because the Diffusers pipeline will pass in a 1D tensor for timestep
63
+ if timestep.ndim == 1:
64
+ timestep = timestep.view(-1, 1, 1).expand(-1, *hidden_states.shape[1:-1], -1)
65
+ # ==================================================
66
+
67
+ temb, embedded_timestep = self.time_embed(
68
+ timestep.flatten(),
69
+ batch_size=batch_size,
70
+ hidden_dtype=hidden_states.dtype,
71
+ )
72
+
73
+ # ===== This is modified compared to Diffusers =====
74
+ # temb = temb.view(batch_size, -1, temb.size(-1))
75
+ # embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
76
+ # ==================================================
77
+ # This is done to make it possible to use per-token timestep embedding
78
+ temb = temb.view(batch_size, *hidden_states.shape[1:-1], temb.size(-1))
79
+ embedded_timestep = embedded_timestep.view(batch_size, *hidden_states.shape[1:-1], embedded_timestep.size(-1))
80
+ # ==================================================
81
+
82
+ hidden_states = self.proj_in(hidden_states)
83
+
84
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
85
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
86
+
87
+ for block in self.transformer_blocks:
88
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
89
+
90
+ def create_custom_forward(module, return_dict=None):
91
+ def custom_forward(*inputs):
92
+ if return_dict is not None:
93
+ return module(*inputs, return_dict=return_dict)
94
+ else:
95
+ return module(*inputs)
96
+
97
+ return custom_forward
98
+
99
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
100
+ hidden_states = torch.utils.checkpoint.checkpoint(
101
+ create_custom_forward(block),
102
+ hidden_states,
103
+ encoder_hidden_states,
104
+ temb,
105
+ image_rotary_emb,
106
+ encoder_attention_mask,
107
+ **ckpt_kwargs,
108
+ )
109
+ else:
110
+ hidden_states = block(
111
+ hidden_states=hidden_states,
112
+ encoder_hidden_states=encoder_hidden_states,
113
+ temb=temb,
114
+ image_rotary_emb=image_rotary_emb,
115
+ encoder_attention_mask=encoder_attention_mask,
116
+ )
117
+
118
+ scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
119
+ shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
120
+
121
+ hidden_states = self.norm_out(hidden_states)
122
+ hidden_states = hidden_states * (1 + scale) + shift
123
+ output = self.proj_out(hidden_states)
124
+
125
+ if not return_dict:
126
+ return (output,)
127
+ return Transformer2DModelOutput(sample=output)
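The rotary-embedding patch above swaps unbind(-1) for chunk(2, dim=-1) purely for DTensor compatibility; the flattened rotation it produces is unchanged. A small self-contained check (illustrative only; the shapes are arbitrary):

    import torch

    x = torch.randn(2, 3, 8)  # [B, S, D]

    # original diffusers path
    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)
    rotated_unbind = torch.stack([-x_imag, x_real], dim=-1).flatten(2)

    # patched, DTensor-friendly path
    x_real, x_imag = x.unflatten(2, (-1, 2)).chunk(2, dim=-1)
    rotated_chunk = torch.stack([-x_imag, x_real], dim=-1).flatten(2)

    assert torch.equal(rotated_unbind, rotated_chunk)  # identical result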
finetrainers/patches/utils.py ADDED
@@ -0,0 +1,18 @@
1
+ import torch
2
+
3
+
4
+ class DisableTensorToDtype:
5
+ def __enter__(self):
6
+ self.original_to = torch.Tensor.to
7
+
8
+ def modified_to(tensor, *args, **kwargs):
9
+ # remove dtype from args if present
10
+ args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
11
+ if "dtype" in kwargs:
12
+ kwargs.pop("dtype")
13
+ return self.original_to(tensor, *args, **kwargs)
14
+
15
+ torch.Tensor.to = modified_to
16
+
17
+ def __exit__(self, *args, **kwargs):
18
+ torch.Tensor.to = self.original_to
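A quick illustration of the context manager's effect (a standalone sketch, not part of the commit): inside the block, dtype arguments to Tensor.to are dropped while device moves still go through.

    import torch
    from finetrainers.patches.utils import DisableTensorToDtype

    t = torch.randn(2, 2, dtype=torch.float32)

    with DisableTensorToDtype():
        kept = t.to(dtype=torch.bfloat16)    # dtype keyword is stripped, so dtype is kept
    assert kept.dtype == torch.float32

    converted = t.to(dtype=torch.bfloat16)   # outside the block, .to() behaves normally
    assert converted.dtype == torch.bfloat16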