aman81 committed on
Commit 24cdb2a · verified · 1 Parent(s): 20608e3

Delete train_stage_2.py

Files changed (1)
  1. train_stage_2.py +0 -773
train_stage_2.py DELETED
@@ -1,773 +0,0 @@
- import argparse
- import copy
- import logging
- import math
- import os
- import os.path as osp
- import random
- import time
- import warnings
- from collections import OrderedDict
- from datetime import datetime
- from pathlib import Path
- from tempfile import TemporaryDirectory
-
- import diffusers
- import mlflow
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint
- import transformers
- from accelerate import Accelerator
- from accelerate.logging import get_logger
- from accelerate.utils import DistributedDataParallelKwargs
- from diffusers import AutoencoderKL, DDIMScheduler
- from diffusers.optimization import get_scheduler
- from diffusers.utils import check_min_version
- from diffusers.utils.import_utils import is_xformers_available
- from einops import rearrange
- from omegaconf import OmegaConf
- from PIL import Image
- from torchvision import transforms
- from tqdm.auto import tqdm
- from transformers import CLIPVisionModelWithProjection
-
- from src.dataset.dance_video import HumanDanceVideoDataset
- from src.models.mutual_self_attention import ReferenceAttentionControl
- from src.models.pose_guider import PoseGuider
- from src.models.unet_2d_condition import UNet2DConditionModel
- from src.models.unet_3d import UNet3DConditionModel
- from src.pipelines.pipeline_pose2vid import Pose2VideoPipeline
- from src.utils.util import (
-     delete_additional_ckpt,
-     import_filename,
-     read_frames,
-     save_videos_grid,
-     seed_everything,
- )
-
- warnings.filterwarnings("ignore")
-
- # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
- check_min_version("0.10.0.dev0")
-
- logger = get_logger(__name__, log_level="INFO")
-
-
- class Net(nn.Module):
-     def __init__(
-         self,
-         reference_unet: UNet2DConditionModel,
-         denoising_unet: UNet3DConditionModel,
-         pose_guider: PoseGuider,
-         reference_control_writer,
-         reference_control_reader,
-     ):
-         super().__init__()
-         self.reference_unet = reference_unet
-         self.denoising_unet = denoising_unet
-         self.pose_guider = pose_guider
-         self.reference_control_writer = reference_control_writer
-         self.reference_control_reader = reference_control_reader
-
-     def forward(
-         self,
-         noisy_latents,
-         timesteps,
-         ref_image_latents,
-         clip_image_embeds,
-         pose_img,
-         uncond_fwd: bool = False,
-     ):
-         pose_cond_tensor = pose_img.to(device="cuda")
-         pose_fea = self.pose_guider(pose_cond_tensor)
-
-         if not uncond_fwd:
-             ref_timesteps = torch.zeros_like(timesteps)
-             self.reference_unet(
-                 ref_image_latents,
-                 ref_timesteps,
-                 encoder_hidden_states=clip_image_embeds,
-                 return_dict=False,
-             )
-             self.reference_control_reader.update(self.reference_control_writer)
-
-         model_pred = self.denoising_unet(
-             noisy_latents,
-             timesteps,
-             pose_cond_fea=pose_fea,
-             encoder_hidden_states=clip_image_embeds,
-         ).sample
-
-         return model_pred
-
-
- def compute_snr(noise_scheduler, timesteps):
-     """
-     Computes SNR as per
-     https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
-     """
-     alphas_cumprod = noise_scheduler.alphas_cumprod
-     sqrt_alphas_cumprod = alphas_cumprod**0.5
-     sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
-
-     # Expand the tensors.
-     # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
-     sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
-         timesteps
-     ].float()
-     while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
-         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
-     alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
-
-     sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
-         device=timesteps.device
-     )[timesteps].float()
-     while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
-         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
-     sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
-
-     # Compute SNR.
-     snr = (alpha / sigma) ** 2
-     return snr
-
-
- def log_validation(
-     vae,
-     image_enc,
-     net,
-     scheduler,
-     accelerator,
-     width,
-     height,
-     clip_length=24,
-     generator=None,
- ):
-     logger.info("Running validation... ")
-
-     ori_net = accelerator.unwrap_model(net)
-     reference_unet = ori_net.reference_unet
-     denoising_unet = ori_net.denoising_unet
-     pose_guider = ori_net.pose_guider
-
-     if generator is None:
-         generator = torch.manual_seed(42)
-     tmp_denoising_unet = copy.deepcopy(denoising_unet)
-     tmp_denoising_unet = tmp_denoising_unet.to(dtype=torch.float16)
-
-     pipe = Pose2VideoPipeline(
-         vae=vae,
-         image_encoder=image_enc,
-         reference_unet=reference_unet,
-         denoising_unet=tmp_denoising_unet,
-         pose_guider=pose_guider,
-         scheduler=scheduler,
-     )
-     pipe = pipe.to(accelerator.device)
-
-     test_cases = [
-         (
-             "./configs/inference/ref_images/anyone-3.png",
-             "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
-         ),
-         (
-             "./configs/inference/ref_images/anyone-2.png",
-             "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
-         ),
-     ]
-
-     results = []
-     for test_case in test_cases:
-         ref_image_path, pose_video_path = test_case
-         ref_name = Path(ref_image_path).stem
-         pose_name = Path(pose_video_path).stem
-         ref_image_pil = Image.open(ref_image_path).convert("RGB")
-
-         pose_list = []
-         pose_tensor_list = []
-         pose_images = read_frames(pose_video_path)
-         pose_transform = transforms.Compose(
-             [transforms.Resize((height, width)), transforms.ToTensor()]
-         )
-         for pose_image_pil in pose_images[:clip_length]:
-             pose_tensor_list.append(pose_transform(pose_image_pil))
-             pose_list.append(pose_image_pil)
-
-         pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
-         pose_tensor = pose_tensor.transpose(0, 1)
-
-         pipeline_output = pipe(
-             ref_image_pil,
-             pose_list,
-             width,
-             height,
-             clip_length,
-             20,
-             3.5,
-             generator=generator,
-         )
-         video = pipeline_output.videos
-
-         # Concat it with pose tensor
-         pose_tensor = pose_tensor.unsqueeze(0)
-         video = torch.cat([video, pose_tensor], dim=0)
-
-         results.append({"name": f"{ref_name}_{pose_name}", "vid": video})
-
-     del tmp_denoising_unet
-     del pipe
-     torch.cuda.empty_cache()
-
-     return results
-
-
- def main(cfg):
-     kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
-     accelerator = Accelerator(
-         gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
-         mixed_precision=cfg.solver.mixed_precision,
-         log_with="mlflow",
-         project_dir="./mlruns",
-         kwargs_handlers=[kwargs],
-     )
-
-     # Make one log on every process with the configuration for debugging.
-     logging.basicConfig(
-         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-         datefmt="%m/%d/%Y %H:%M:%S",
-         level=logging.INFO,
-     )
-     logger.info(accelerator.state, main_process_only=False)
-     if accelerator.is_local_main_process:
-         transformers.utils.logging.set_verbosity_warning()
-         diffusers.utils.logging.set_verbosity_info()
-     else:
-         transformers.utils.logging.set_verbosity_error()
-         diffusers.utils.logging.set_verbosity_error()
-
-     # If passed along, set the training seed now.
-     if cfg.seed is not None:
-         seed_everything(cfg.seed)
-
-     exp_name = cfg.exp_name
-     save_dir = f"{cfg.output_dir}/{exp_name}"
-     if accelerator.is_main_process:
-         if not os.path.exists(save_dir):
-             os.makedirs(save_dir)
-
-     inference_config_path = "./configs/inference/inference_v2.yaml"
-     infer_config = OmegaConf.load(inference_config_path)
-
-     if cfg.weight_dtype == "fp16":
-         weight_dtype = torch.float16
-     elif cfg.weight_dtype == "fp32":
-         weight_dtype = torch.float32
-     else:
-         raise ValueError(
-             f"Do not support weight dtype: {cfg.weight_dtype} during training"
-         )
-
-     sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
-     if cfg.enable_zero_snr:
-         sched_kwargs.update(
-             rescale_betas_zero_snr=True,
-             timestep_spacing="trailing",
-             prediction_type="v_prediction",
-         )
-     val_noise_scheduler = DDIMScheduler(**sched_kwargs)
-     sched_kwargs.update({"beta_schedule": "scaled_linear"})
-     train_noise_scheduler = DDIMScheduler(**sched_kwargs)
-
-     image_enc = CLIPVisionModelWithProjection.from_pretrained(
-         cfg.image_encoder_path,
-     ).to(dtype=weight_dtype, device="cuda")
-     vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
-         "cuda", dtype=weight_dtype
-     )
-     reference_unet = UNet2DConditionModel.from_pretrained(
-         cfg.base_model_path,
-         subfolder="unet",
-     ).to(device="cuda", dtype=weight_dtype)
-
-     denoising_unet = UNet3DConditionModel.from_pretrained_2d(
-         cfg.base_model_path,
-         cfg.mm_path,
-         subfolder="unet",
-         unet_additional_kwargs=OmegaConf.to_container(
-             infer_config.unet_additional_kwargs
-         ),
-     ).to(device="cuda")
-
-     pose_guider = PoseGuider(
-         conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
-     ).to(device="cuda", dtype=weight_dtype)
-
-     stage1_ckpt_dir = cfg.stage1_ckpt_dir
-     stage1_ckpt_step = cfg.stage1_ckpt_step
-     denoising_unet.load_state_dict(
-         torch.load(
-             os.path.join(stage1_ckpt_dir, f"denoising_unet-{stage1_ckpt_step}.pth"),
-             map_location="cpu",
-         ),
-         strict=False,
-     )
-     reference_unet.load_state_dict(
-         torch.load(
-             os.path.join(stage1_ckpt_dir, f"reference_unet-{stage1_ckpt_step}.pth"),
-             map_location="cpu",
-         ),
-         strict=False,
-     )
-     pose_guider.load_state_dict(
-         torch.load(
-             os.path.join(stage1_ckpt_dir, f"pose_guider-{stage1_ckpt_step}.pth"),
-             map_location="cpu",
-         ),
-         strict=False,
-     )
-
-     # Freeze
-     vae.requires_grad_(False)
-     image_enc.requires_grad_(False)
-     reference_unet.requires_grad_(False)
-     denoising_unet.requires_grad_(False)
-     pose_guider.requires_grad_(False)
-
-     # Set motion module learnable
-     for name, module in denoising_unet.named_modules():
-         if "motion_modules" in name:
-             for params in module.parameters():
-                 params.requires_grad = True
-
-     reference_control_writer = ReferenceAttentionControl(
-         reference_unet,
-         do_classifier_free_guidance=False,
-         mode="write",
-         fusion_blocks="full",
-     )
-     reference_control_reader = ReferenceAttentionControl(
-         denoising_unet,
-         do_classifier_free_guidance=False,
-         mode="read",
-         fusion_blocks="full",
-     )
-
-     net = Net(
-         reference_unet,
-         denoising_unet,
-         pose_guider,
-         reference_control_writer,
-         reference_control_reader,
-     )
-
-     if cfg.solver.enable_xformers_memory_efficient_attention:
-         if is_xformers_available():
-             reference_unet.enable_xformers_memory_efficient_attention()
-             denoising_unet.enable_xformers_memory_efficient_attention()
-         else:
-             raise ValueError(
-                 "xformers is not available. Make sure it is installed correctly"
-             )
-
-     if cfg.solver.gradient_checkpointing:
-         reference_unet.enable_gradient_checkpointing()
-         denoising_unet.enable_gradient_checkpointing()
-
-     if cfg.solver.scale_lr:
-         learning_rate = (
-             cfg.solver.learning_rate
-             * cfg.solver.gradient_accumulation_steps
-             * cfg.data.train_bs
-             * accelerator.num_processes
-         )
-     else:
-         learning_rate = cfg.solver.learning_rate
-
-     # Initialize the optimizer
-     if cfg.solver.use_8bit_adam:
-         try:
-             import bitsandbytes as bnb
-         except ImportError:
-             raise ImportError(
-                 "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
-             )
-
-         optimizer_cls = bnb.optim.AdamW8bit
-     else:
-         optimizer_cls = torch.optim.AdamW
-
-     trainable_params = list(filter(lambda p: p.requires_grad, net.parameters()))
-     logger.info(f"Total trainable params {len(trainable_params)}")
-     optimizer = optimizer_cls(
-         trainable_params,
-         lr=learning_rate,
-         betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
-         weight_decay=cfg.solver.adam_weight_decay,
-         eps=cfg.solver.adam_epsilon,
-     )
-
-     # Scheduler
-     lr_scheduler = get_scheduler(
-         cfg.solver.lr_scheduler,
-         optimizer=optimizer,
-         num_warmup_steps=cfg.solver.lr_warmup_steps
-         * cfg.solver.gradient_accumulation_steps,
-         num_training_steps=cfg.solver.max_train_steps
-         * cfg.solver.gradient_accumulation_steps,
-     )
-
-     train_dataset = HumanDanceVideoDataset(
-         width=cfg.data.train_width,
-         height=cfg.data.train_height,
-         n_sample_frames=cfg.data.n_sample_frames,
-         sample_rate=cfg.data.sample_rate,
-         img_scale=(1.0, 1.0),
-         data_meta_paths=cfg.data.meta_paths,
-     )
-     train_dataloader = torch.utils.data.DataLoader(
-         train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
-     )
-
-     # Prepare everything with our `accelerator`.
-     (
-         net,
-         optimizer,
-         train_dataloader,
-         lr_scheduler,
-     ) = accelerator.prepare(
-         net,
-         optimizer,
-         train_dataloader,
-         lr_scheduler,
-     )
-
-     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-     num_update_steps_per_epoch = math.ceil(
-         len(train_dataloader) / cfg.solver.gradient_accumulation_steps
-     )
-     # Afterwards we recalculate our number of training epochs
-     num_train_epochs = math.ceil(
-         cfg.solver.max_train_steps / num_update_steps_per_epoch
-     )
-
-     # We need to initialize the trackers we use, and also store our configuration.
-     # The trackers initializes automatically on the main process.
-     if accelerator.is_main_process:
-         run_time = datetime.now().strftime("%Y%m%d-%H%M")
-         accelerator.init_trackers(
-             exp_name,
-             init_kwargs={"mlflow": {"run_name": run_time}},
-         )
-         # dump config file
-         mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")
-
-     # Train!
-     total_batch_size = (
-         cfg.data.train_bs
-         * accelerator.num_processes
-         * cfg.solver.gradient_accumulation_steps
-     )
-
-     logger.info("***** Running training *****")
-     logger.info(f" Num examples = {len(train_dataset)}")
-     logger.info(f" Num Epochs = {num_train_epochs}")
-     logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
-     logger.info(
-         f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
-     )
-     logger.info(
-         f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
-     )
-     logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
-     global_step = 0
-     first_epoch = 0
-
-     # Potentially load in the weights and states from a previous save
-     if cfg.resume_from_checkpoint:
-         if cfg.resume_from_checkpoint != "latest":
-             resume_dir = cfg.resume_from_checkpoint
-         else:
-             resume_dir = save_dir
-         # Get the most recent checkpoint
-         dirs = os.listdir(resume_dir)
-         dirs = [d for d in dirs if d.startswith("checkpoint")]
-         dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
-         path = dirs[-1]
-         accelerator.load_state(os.path.join(resume_dir, path))
-         accelerator.print(f"Resuming from checkpoint {path}")
-         global_step = int(path.split("-")[1])
-
-         first_epoch = global_step // num_update_steps_per_epoch
-         resume_step = global_step % num_update_steps_per_epoch
-
-     # Only show the progress bar once on each machine.
-     progress_bar = tqdm(
-         range(global_step, cfg.solver.max_train_steps),
-         disable=not accelerator.is_local_main_process,
-     )
-     progress_bar.set_description("Steps")
-
-     for epoch in range(first_epoch, num_train_epochs):
-         train_loss = 0.0
-         t_data_start = time.time()
-         for step, batch in enumerate(train_dataloader):
-             t_data = time.time() - t_data_start
-             with accelerator.accumulate(net):
-                 # Convert videos to latent space
-                 pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype)
-                 with torch.no_grad():
-                     video_length = pixel_values_vid.shape[1]
-                     pixel_values_vid = rearrange(
-                         pixel_values_vid, "b f c h w -> (b f) c h w"
-                     )
-                     latents = vae.encode(pixel_values_vid).latent_dist.sample()
-                     latents = rearrange(
-                         latents, "(b f) c h w -> b c f h w", f=video_length
-                     )
-                     latents = latents * 0.18215
-
-                 noise = torch.randn_like(latents)
-                 if cfg.noise_offset > 0:
-                     noise += cfg.noise_offset * torch.randn(
-                         (latents.shape[0], latents.shape[1], 1, 1, 1),
-                         device=latents.device,
-                     )
-                 bsz = latents.shape[0]
-                 # Sample a random timestep for each video
-                 timesteps = torch.randint(
-                     0,
-                     train_noise_scheduler.num_train_timesteps,
-                     (bsz,),
-                     device=latents.device,
-                 )
-                 timesteps = timesteps.long()
-
-                 pixel_values_pose = batch["pixel_values_pose"]  # (bs, f, c, H, W)
-                 pixel_values_pose = pixel_values_pose.transpose(
-                     1, 2
-                 )  # (bs, c, f, H, W)
-
-                 uncond_fwd = random.random() < cfg.uncond_ratio
-                 clip_image_list = []
-                 ref_image_list = []
-                 for batch_idx, (ref_img, clip_img) in enumerate(
-                     zip(
-                         batch["pixel_values_ref_img"],
-                         batch["clip_ref_img"],
-                     )
-                 ):
-                     if uncond_fwd:
-                         clip_image_list.append(torch.zeros_like(clip_img))
-                     else:
-                         clip_image_list.append(clip_img)
-                     ref_image_list.append(ref_img)
-
-                 with torch.no_grad():
-                     ref_img = torch.stack(ref_image_list, dim=0).to(
-                         dtype=vae.dtype, device=vae.device
-                     )
-                     ref_image_latents = vae.encode(
-                         ref_img
-                     ).latent_dist.sample()  # (bs, d, 64, 64)
-                     ref_image_latents = ref_image_latents * 0.18215
-
-                     clip_img = torch.stack(clip_image_list, dim=0).to(
-                         dtype=image_enc.dtype, device=image_enc.device
-                     )
-                     clip_img = clip_img.to(device="cuda", dtype=weight_dtype)
-                     clip_image_embeds = image_enc(
-                         clip_img.to("cuda", dtype=weight_dtype)
-                     ).image_embeds
-                     clip_image_embeds = clip_image_embeds.unsqueeze(1)  # (bs, 1, d)
-
-                 # add noise
-                 noisy_latents = train_noise_scheduler.add_noise(
-                     latents, noise, timesteps
-                 )
-
-                 # Get the target for loss depending on the prediction type
-                 if train_noise_scheduler.prediction_type == "epsilon":
-                     target = noise
-                 elif train_noise_scheduler.prediction_type == "v_prediction":
-                     target = train_noise_scheduler.get_velocity(
-                         latents, noise, timesteps
-                     )
-                 else:
-                     raise ValueError(
-                         f"Unknown prediction type {train_noise_scheduler.prediction_type}"
-                     )
-
-                 # ---- Forward!!! -----
-                 model_pred = net(
-                     noisy_latents,
-                     timesteps,
-                     ref_image_latents,
-                     clip_image_embeds,
-                     pixel_values_pose,
-                     uncond_fwd=uncond_fwd,
-                 )
-
-                 if cfg.snr_gamma == 0:
-                     loss = F.mse_loss(
-                         model_pred.float(), target.float(), reduction="mean"
-                     )
-                 else:
-                     snr = compute_snr(train_noise_scheduler, timesteps)
-                     if train_noise_scheduler.config.prediction_type == "v_prediction":
-                         # Velocity objective requires that we add one to SNR values before we divide by them.
-                         snr = snr + 1
-                     mse_loss_weights = (
-                         torch.stack(
-                             [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
-                         ).min(dim=1)[0]
-                         / snr
-                     )
-                     loss = F.mse_loss(
-                         model_pred.float(), target.float(), reduction="none"
-                     )
-                     loss = (
-                         loss.mean(dim=list(range(1, len(loss.shape))))
-                         * mse_loss_weights
-                     )
-                     loss = loss.mean()
-
-                 # Gather the losses across all processes for logging (if we use distributed training).
-                 avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
-                 train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps
-
-                 # Backpropagate
-                 accelerator.backward(loss)
-                 if accelerator.sync_gradients:
-                     accelerator.clip_grad_norm_(
-                         trainable_params,
-                         cfg.solver.max_grad_norm,
-                     )
-                 optimizer.step()
-                 lr_scheduler.step()
-                 optimizer.zero_grad()
-
-             if accelerator.sync_gradients:
-                 reference_control_reader.clear()
-                 reference_control_writer.clear()
-                 progress_bar.update(1)
-                 global_step += 1
-                 accelerator.log({"train_loss": train_loss}, step=global_step)
-                 train_loss = 0.0
-
-                 if global_step % cfg.val.validation_steps == 0:
-                     if accelerator.is_main_process:
-                         generator = torch.Generator(device=accelerator.device)
-                         generator.manual_seed(cfg.seed)
-
-                         sample_dicts = log_validation(
-                             vae=vae,
-                             image_enc=image_enc,
-                             net=net,
-                             scheduler=val_noise_scheduler,
-                             accelerator=accelerator,
-                             width=cfg.data.train_width,
-                             height=cfg.data.train_height,
-                             clip_length=cfg.data.n_sample_frames,
-                             generator=generator,
-                         )
-
-                         for sample_id, sample_dict in enumerate(sample_dicts):
-                             sample_name = sample_dict["name"]
-                             vid = sample_dict["vid"]
-                             with TemporaryDirectory() as temp_dir:
-                                 out_file = Path(
-                                     f"{temp_dir}/{global_step:06d}-{sample_name}.gif"
-                                 )
-                                 save_videos_grid(vid, out_file, n_rows=2)
-                                 mlflow.log_artifact(out_file)
-
-             logs = {
-                 "step_loss": loss.detach().item(),
-                 "lr": lr_scheduler.get_last_lr()[0],
-                 "td": f"{t_data:.2f}s",
-             }
-             t_data_start = time.time()
-             progress_bar.set_postfix(**logs)
-
-             if global_step >= cfg.solver.max_train_steps:
-                 break
-         # save model after each epoch
-         if accelerator.is_main_process:
-             save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
-             delete_additional_ckpt(save_dir, 1)
-             accelerator.save_state(save_path)
-             # save motion module only
-             unwrap_net = accelerator.unwrap_model(net)
-             save_checkpoint(
-                 unwrap_net.denoising_unet,
-                 save_dir,
-                 "motion_module",
-                 global_step,
-                 total_limit=3,
-             )
-
-     # Create the pipeline using the trained modules and save it.
-     accelerator.wait_for_everyone()
-     accelerator.end_training()
-
-
- def save_checkpoint(model, save_dir, prefix, ckpt_num, total_limit=None):
-     save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")
-
-     if total_limit is not None:
-         checkpoints = os.listdir(save_dir)
-         checkpoints = [d for d in checkpoints if d.startswith(prefix)]
-         checkpoints = sorted(
-             checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
-         )
-
-         if len(checkpoints) >= total_limit:
-             num_to_remove = len(checkpoints) - total_limit + 1
-             removing_checkpoints = checkpoints[0:num_to_remove]
-             logger.info(
-                 f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
-             )
-             logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
-
-             for removing_checkpoint in removing_checkpoints:
-                 removing_checkpoint = os.path.join(save_dir, removing_checkpoint)
-                 os.remove(removing_checkpoint)
-
-     mm_state_dict = OrderedDict()
-     state_dict = model.state_dict()
-     for key in state_dict:
-         if "motion_module" in key:
-             mm_state_dict[key] = state_dict[key]
-
-     torch.save(mm_state_dict, save_path)
-
-
- def decode_latents(vae, latents):
-     video_length = latents.shape[2]
-     latents = 1 / 0.18215 * latents
-     latents = rearrange(latents, "b c f h w -> (b f) c h w")
-     # video = self.vae.decode(latents).sample
-     video = []
-     for frame_idx in tqdm(range(latents.shape[0])):
-         video.append(vae.decode(latents[frame_idx : frame_idx + 1]).sample)
-     video = torch.cat(video)
-     video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
-     video = (video / 2 + 0.5).clamp(0, 1)
-     # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
-     video = video.cpu().float().numpy()
-     return video
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
-     args = parser.parse_args()
-
-     if args.config[-5:] == ".yaml":
-         config = OmegaConf.load(args.config)
-     elif args.config[-3:] == ".py":
-         config = import_filename(args.config).cfg
-     else:
-         raise ValueError("Do not support this format config file")
-     main(config)