Delete train_stage_2.py
train_stage_2.py  +0 -773
DELETED
@@ -1,773 +0,0 @@
import argparse
import copy
import logging
import math
import os
import os.path as osp
import random
import time
import warnings
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory

import diffusers
import mlflow
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPVisionModelWithProjection

from src.dataset.dance_video import HumanDanceVideoDataset
from src.models.mutual_self_attention import ReferenceAttentionControl
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid import Pose2VideoPipeline
from src.utils.util import (
    delete_additional_ckpt,
    import_filename,
    read_frames,
    save_videos_grid,
    seed_everything,
)

warnings.filterwarnings("ignore")

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")


class Net(nn.Module):
    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        pose_guider: PoseGuider,
        reference_control_writer,
        reference_control_reader,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.pose_guider = pose_guider
        self.reference_control_writer = reference_control_writer
        self.reference_control_reader = reference_control_reader

    def forward(
        self,
        noisy_latents,
        timesteps,
        ref_image_latents,
        clip_image_embeds,
        pose_img,
        uncond_fwd: bool = False,
    ):
        pose_cond_tensor = pose_img.to(device="cuda")
        pose_fea = self.pose_guider(pose_cond_tensor)

        if not uncond_fwd:
            ref_timesteps = torch.zeros_like(timesteps)
            self.reference_unet(
                ref_image_latents,
                ref_timesteps,
                encoder_hidden_states=clip_image_embeds,
                return_dict=False,
            )
            self.reference_control_reader.update(self.reference_control_writer)

        model_pred = self.denoising_unet(
            noisy_latents,
            timesteps,
            pose_cond_fea=pose_fea,
            encoder_hidden_states=clip_image_embeds,
        ).sample

        return model_pred


def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # Expand the tensors.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
        timesteps
    ].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
        device=timesteps.device
    )[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR.
    snr = (alpha / sigma) ** 2
    return snr


def log_validation(
    vae,
    image_enc,
    net,
    scheduler,
    accelerator,
    width,
    height,
    clip_length=24,
    generator=None,
):
    logger.info("Running validation... ")

    ori_net = accelerator.unwrap_model(net)
    reference_unet = ori_net.reference_unet
    denoising_unet = ori_net.denoising_unet
    pose_guider = ori_net.pose_guider

    if generator is None:
        generator = torch.manual_seed(42)
    tmp_denoising_unet = copy.deepcopy(denoising_unet)
    tmp_denoising_unet = tmp_denoising_unet.to(dtype=torch.float16)

    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=tmp_denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to(accelerator.device)

    test_cases = [
        (
            "./configs/inference/ref_images/anyone-3.png",
            "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
        ),
        (
            "./configs/inference/ref_images/anyone-2.png",
            "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
        ),
    ]

    results = []
    for test_case in test_cases:
        ref_image_path, pose_video_path = test_case
        ref_name = Path(ref_image_path).stem
        pose_name = Path(pose_video_path).stem
        ref_image_pil = Image.open(ref_image_path).convert("RGB")

        pose_list = []
        pose_tensor_list = []
        pose_images = read_frames(pose_video_path)
        pose_transform = transforms.Compose(
            [transforms.Resize((height, width)), transforms.ToTensor()]
        )
        for pose_image_pil in pose_images[:clip_length]:
            pose_tensor_list.append(pose_transform(pose_image_pil))
            pose_list.append(pose_image_pil)

        pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
        pose_tensor = pose_tensor.transpose(0, 1)

        pipeline_output = pipe(
            ref_image_pil,
            pose_list,
            width,
            height,
            clip_length,
            20,
            3.5,
            generator=generator,
        )
        video = pipeline_output.videos

        # Concat it with pose tensor
        pose_tensor = pose_tensor.unsqueeze(0)
        video = torch.cat([video, pose_tensor], dim=0)

        results.append({"name": f"{ref_name}_{pose_name}", "vid": video})

    del tmp_denoising_unet
    del pipe
    torch.cuda.empty_cache()

    return results


def main(cfg):
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,
        log_with="mlflow",
        project_dir="./mlruns",
        kwargs_handlers=[kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if cfg.seed is not None:
        seed_everything(cfg.seed)

    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    if accelerator.is_main_process:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

    inference_config_path = "./configs/inference/inference_v2.yaml"
    infer_config = OmegaConf.load(inference_config_path)

    if cfg.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif cfg.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(
            f"Do not support weight dtype: {cfg.weight_dtype} during training"
        )

    sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
    if cfg.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})
    train_noise_scheduler = DDIMScheduler(**sched_kwargs)

    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        cfg.image_encoder_path,
    ).to(dtype=weight_dtype, device="cuda")
    vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
        "cuda", dtype=weight_dtype
    )
    reference_unet = UNet2DConditionModel.from_pretrained(
        cfg.base_model_path,
        subfolder="unet",
    ).to(device="cuda", dtype=weight_dtype)

    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        cfg.mm_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            infer_config.unet_additional_kwargs
        ),
    ).to(device="cuda")

    pose_guider = PoseGuider(
        conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
    ).to(device="cuda", dtype=weight_dtype)

    stage1_ckpt_dir = cfg.stage1_ckpt_dir
    stage1_ckpt_step = cfg.stage1_ckpt_step
    denoising_unet.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, f"denoising_unet-{stage1_ckpt_step}.pth"),
            map_location="cpu",
        ),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, f"reference_unet-{stage1_ckpt_step}.pth"),
            map_location="cpu",
        ),
        strict=False,
    )
    pose_guider.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, f"pose_guider-{stage1_ckpt_step}.pth"),
            map_location="cpu",
        ),
        strict=False,
    )

    # Freeze
    vae.requires_grad_(False)
    image_enc.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    pose_guider.requires_grad_(False)

    # Set motion module learnable
    for name, module in denoising_unet.named_modules():
        if "motion_modules" in name:
            for params in module.parameters():
                params.requires_grad = True

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    reference_control_reader = ReferenceAttentionControl(
        denoising_unet,
        do_classifier_free_guidance=False,
        mode="read",
        fusion_blocks="full",
    )

    net = Net(
        reference_unet,
        denoising_unet,
        pose_guider,
        reference_control_writer,
        reference_control_reader,
    )

    if cfg.solver.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            reference_unet.enable_xformers_memory_efficient_attention()
            denoising_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if cfg.solver.gradient_checkpointing:
        reference_unet.enable_gradient_checkpointing()
        denoising_unet.enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    # Initialize the optimizer
    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    trainable_params = list(filter(lambda p: p.requires_grad, net.parameters()))
    logger.info(f"Total trainable params {len(trainable_params)}")
    optimizer = optimizer_cls(
        trainable_params,
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    # Scheduler
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps
        * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps
        * cfg.solver.gradient_accumulation_steps,
    )

    train_dataset = HumanDanceVideoDataset(
        width=cfg.data.train_width,
        height=cfg.data.train_height,
        n_sample_frames=cfg.data.n_sample_frames,
        sample_rate=cfg.data.sample_rate,
        img_scale=(1.0, 1.0),
        data_meta_paths=cfg.data.meta_paths,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
    )

    # Prepare everything with our `accelerator`.
    (
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / cfg.solver.gradient_accumulation_steps
    )
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            exp_name,
            init_kwargs={"mlflow": {"run_name": run_time}},
        )
        # dump config file
        mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")

    # Train!
    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f"  Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
    )
    logger.info(f"  Total optimization steps = {cfg.solver.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if cfg.resume_from_checkpoint:
        if cfg.resume_from_checkpoint != "latest":
            resume_dir = cfg.resume_from_checkpoint
        else:
            resume_dir = save_dir
        # Get the most recent checkpoint
        dirs = os.listdir(resume_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1]
        accelerator.load_state(os.path.join(resume_dir, path))
        accelerator.print(f"Resuming from checkpoint {path}")
        global_step = int(path.split("-")[1])

        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = global_step % num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        t_data_start = time.time()
        for step, batch in enumerate(train_dataloader):
            t_data = time.time() - t_data_start
            with accelerator.accumulate(net):
                # Convert videos to latent space
                pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype)
                with torch.no_grad():
                    video_length = pixel_values_vid.shape[1]
                    pixel_values_vid = rearrange(
                        pixel_values_vid, "b f c h w -> (b f) c h w"
                    )
                    latents = vae.encode(pixel_values_vid).latent_dist.sample()
                    latents = rearrange(
                        latents, "(b f) c h w -> b c f h w", f=video_length
                    )
                    latents = latents * 0.18215

                noise = torch.randn_like(latents)
                if cfg.noise_offset > 0:
                    noise += cfg.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], 1, 1, 1),
                        device=latents.device,
                    )
                bsz = latents.shape[0]
                # Sample a random timestep for each video
                timesteps = torch.randint(
                    0,
                    train_noise_scheduler.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                )
                timesteps = timesteps.long()

                pixel_values_pose = batch["pixel_values_pose"]  # (bs, f, c, H, W)
                pixel_values_pose = pixel_values_pose.transpose(
                    1, 2
                )  # (bs, c, f, H, W)

                uncond_fwd = random.random() < cfg.uncond_ratio
                clip_image_list = []
                ref_image_list = []
                for batch_idx, (ref_img, clip_img) in enumerate(
                    zip(
                        batch["pixel_values_ref_img"],
                        batch["clip_ref_img"],
                    )
                ):
                    if uncond_fwd:
                        clip_image_list.append(torch.zeros_like(clip_img))
                    else:
                        clip_image_list.append(clip_img)
                    ref_image_list.append(ref_img)

                with torch.no_grad():
                    ref_img = torch.stack(ref_image_list, dim=0).to(
                        dtype=vae.dtype, device=vae.device
                    )
                    ref_image_latents = vae.encode(
                        ref_img
                    ).latent_dist.sample()  # (bs, d, 64, 64)
                    ref_image_latents = ref_image_latents * 0.18215

                    clip_img = torch.stack(clip_image_list, dim=0).to(
                        dtype=image_enc.dtype, device=image_enc.device
                    )
                    clip_img = clip_img.to(device="cuda", dtype=weight_dtype)
                    clip_image_embeds = image_enc(
                        clip_img.to("cuda", dtype=weight_dtype)
                    ).image_embeds
                    clip_image_embeds = clip_image_embeds.unsqueeze(1)  # (bs, 1, d)

                # add noise
                noisy_latents = train_noise_scheduler.add_noise(
                    latents, noise, timesteps
                )

                # Get the target for loss depending on the prediction type
                if train_noise_scheduler.prediction_type == "epsilon":
                    target = noise
                elif train_noise_scheduler.prediction_type == "v_prediction":
                    target = train_noise_scheduler.get_velocity(
                        latents, noise, timesteps
                    )
                else:
                    raise ValueError(
                        f"Unknown prediction type {train_noise_scheduler.prediction_type}"
                    )

                # ---- Forward!!! -----
                model_pred = net(
                    noisy_latents,
                    timesteps,
                    ref_image_latents,
                    clip_image_embeds,
                    pixel_values_pose,
                    uncond_fwd=uncond_fwd,
                )

                if cfg.snr_gamma == 0:
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="mean"
                    )
                else:
                    snr = compute_snr(train_noise_scheduler, timesteps)
                    if train_noise_scheduler.config.prediction_type == "v_prediction":
                        # Velocity objective requires that we add one to SNR values before we divide by them.
                        snr = snr + 1
                    mse_loss_weights = (
                        torch.stack(
                            [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
                        ).min(dim=1)[0]
                        / snr
                    )
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="none"
                    )
                    loss = (
                        loss.mean(dim=list(range(1, len(loss.shape))))
                        * mse_loss_weights
                    )
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
                train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        trainable_params,
                        cfg.solver.max_grad_norm,
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                reference_control_reader.clear()
                reference_control_writer.clear()
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % cfg.val.validation_steps == 0:
                    if accelerator.is_main_process:
                        generator = torch.Generator(device=accelerator.device)
                        generator.manual_seed(cfg.seed)

                        sample_dicts = log_validation(
                            vae=vae,
                            image_enc=image_enc,
                            net=net,
                            scheduler=val_noise_scheduler,
                            accelerator=accelerator,
                            width=cfg.data.train_width,
                            height=cfg.data.train_height,
                            clip_length=cfg.data.n_sample_frames,
                            generator=generator,
                        )

                        for sample_id, sample_dict in enumerate(sample_dicts):
                            sample_name = sample_dict["name"]
                            vid = sample_dict["vid"]
                            with TemporaryDirectory() as temp_dir:
                                out_file = Path(
                                    f"{temp_dir}/{global_step:06d}-{sample_name}.gif"
                                )
                                save_videos_grid(vid, out_file, n_rows=2)
                                mlflow.log_artifact(out_file)

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "td": f"{t_data:.2f}s",
            }
            t_data_start = time.time()
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                break
        # save model after each epoch
        if accelerator.is_main_process:
            save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
            delete_additional_ckpt(save_dir, 1)
            accelerator.save_state(save_path)
            # save motion module only
            unwrap_net = accelerator.unwrap_model(net)
            save_checkpoint(
                unwrap_net.denoising_unet,
                save_dir,
                "motion_module",
                global_step,
                total_limit=3,
            )

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    accelerator.end_training()


def save_checkpoint(model, save_dir, prefix, ckpt_num, total_limit=None):
    save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")

    if total_limit is not None:
        checkpoints = os.listdir(save_dir)
        checkpoints = [d for d in checkpoints if d.startswith(prefix)]
        checkpoints = sorted(
            checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
        )

        if len(checkpoints) >= total_limit:
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            logger.info(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint = os.path.join(save_dir, removing_checkpoint)
                os.remove(removing_checkpoint)

    mm_state_dict = OrderedDict()
    state_dict = model.state_dict()
    for key in state_dict:
        if "motion_module" in key:
            mm_state_dict[key] = state_dict[key]

    torch.save(mm_state_dict, save_path)


def decode_latents(vae, latents):
    video_length = latents.shape[2]
    latents = 1 / 0.18215 * latents
    latents = rearrange(latents, "b c f h w -> (b f) c h w")
    # video = self.vae.decode(latents).sample
    video = []
    for frame_idx in tqdm(range(latents.shape[0])):
        video.append(vae.decode(latents[frame_idx : frame_idx + 1]).sample)
    video = torch.cat(video)
    video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
    video = (video / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    video = video.cpu().float().numpy()
    return video


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
    args = parser.parse_args()

    if args.config[-5:] == ".yaml":
        config = OmegaConf.load(args.config)
    elif args.config[-3:] == ".py":
        config = import_filename(args.config).cfg
    else:
        raise ValueError("Do not support this format config file")
    main(config)