import copy
import os

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch.cuda import amp
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataset import TensorDataset
from model.seq2seq import DiffusionPredictor

from config import *
from dist_utils import *
from renderer import *


class LitModel(pl.LightningModule):
    def __init__(self, conf: TrainConfig):
        super().__init__()
        assert conf.train_mode != TrainMode.manipulate
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        self.save_hyperparameters(conf.as_dict_jsonable())

        self.conf = conf

        self.model = DiffusionPredictor(conf)

        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()

        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf(
            ).make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
            ).make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        self.register_buffer(
            'x_T',
            torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size))

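    # Generate predictions with the EMA model. `step_T` selects how many
    # diffusion steps the sampler uses (None falls back to the evaluation
    # sampler built from the config). `render_condition` is presumably
    # provided by the `renderer` module star-imported above.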
    def render(self, start, motion_direction_start, audio_driven,
               face_location, face_scale, ypr_info, noisyT, step_T,
               control_flag):
        if step_T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(step_T).make_sampler()

        pred_img = render_condition(self.conf, self.ema_model, sampler, start,
                                    motion_direction_start, audio_driven,
                                    face_location, face_scale, ypr_info,
                                    noisyT, control_flag)
        return pred_img

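    # Sampling entry point used outside training: draws samples with the
    # evaluation sampler from either the EMA weights (ema_model=True) or the
    # online model.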
    def forward(self, noise=None, x_start=None, ema_model: bool = False):
        with amp.autocast(False):
            if ema_model:
                model = self.ema_model
            else:
                model = self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
            return gen

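    # Lightning calls `setup` once per process; each rank derives its own
    # seed (conf.seed * world_size + global_rank) so workers do not share RNG
    # streams, then the dataset is built (val data reuses the train data).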
    def setup(self, stage=None) -> None:
        """
        make datasets & seeding each worker separately
        """
        if self.conf.seed is not None:
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            print('local seed:', seed)

        self.train_data = self.conf.make_dataset()
        print('train data:', len(self.train_data))
        self.val_data = self.train_data
        print('val data:', len(self.val_data))

    def _train_dataloader(self, drop_last=True):
        """
        really make the dataloader
        """
        conf = self.conf.clone()
        conf.batch_size = self.batch_size

        dataloader = conf.make_loader(self.train_data,
                                      shuffle=True,
                                      drop_last=drop_last)
        return dataloader

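    # NOTE: the latent branch below relies on self.conds, self.conds_mean,
    # self.conds_std and infer_whole_dataset(), which this file does not
    # define; they are presumably set up elsewhere for latent-mode training.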
    def train_dataloader(self):
        """
        return the dataloader, if diffusion mode => return image dataset
        if latent mode => return the inferred latent dataset
        """
        print('on train dataloader start ...')
        if self.conf.train_mode.require_dataset_infer():
            if self.conds is None:
                self.conds = self.infer_whole_dataset()
                self.conds_mean.data = self.conds.float().mean(dim=0,
                                                               keepdim=True)
                self.conds_std.data = self.conds.float().std(dim=0,
                                                             keepdim=True)
            print('mean:', self.conds_mean.mean(), 'std:',
                  self.conds_std.mean())

            conf = self.conf.clone()
            conf.batch_size = self.batch_size
            data = TensorDataset(self.conds)
            return conf.make_loader(data, shuffle=True)
        else:
            return self._train_dataloader()

    @property
    def batch_size(self):
        """
        local batch size for each worker
        """
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

    @property
    def num_samples(self):
        """
        (global) batch size * iterations
        """
        return self.global_step * self.conf.batch_size_effective

    def is_last_accum(self, batch_idx):
        """
        is this the last gradient accumulation micro-batch?
        used with gradient accumulation > 1 to check whether the optimizer
        will actually perform a "step" in this iteration
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

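    # Core diffusion objective: sample one timestep per example with
    # T_sampler, compute the sampler's training losses conditioned on the
    # audio features, face location/scale and yaw-pitch-roll, then all-gather
    # the losses across ranks for logging.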
    def training_step(self, batch, batch_idx):
        """
        given an input, calculate the loss function
        no optimization at this stage.
        """
        with amp.autocast(False):
            motion_start = batch['motion_start']
            motion_direction = batch['motion_direction']
            audio_feats = batch['audio_feats'].float()
            face_location = batch['face_location'].float()
            face_scale = batch['face_scale'].float()
            yaw_pitch_roll = batch['yaw_pitch_roll'].float()
            motion_direction_start = batch['motion_direction_start'].float()

            if self.conf.train_mode == TrainMode.diffusion:
                """
                main training mode!!!
                """
                t, weight = self.T_sampler.sample(len(motion_start),
                                                  motion_start.device)
                losses = self.sampler.training_losses(
                    model=self.model,
                    motion_direction_start=motion_direction_start,
                    motion_target=motion_direction,
                    motion_start=motion_start,
                    audio_feats=audio_feats,
                    face_location=face_location,
                    face_scale=face_scale,
                    yaw_pitch_roll=yaw_pitch_roll,
                    t=t)
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()

            for key in losses.keys():
                losses[key] = self.all_gather(losses[key]).mean()

            if self.global_rank == 0:
                self.logger.experiment.add_scalar('loss', losses['loss'],
                                                  self.num_samples)
                for key in losses:
                    self.logger.experiment.add_scalar(
                        f'loss/{key}', losses[key], self.num_samples)

        return {'loss': loss}

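    # Update the EMA copy only after the last accumulation micro-batch, i.e.
    # once per optimizer step; in latent mode only the latent net is tracked.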
    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        after each training step ...
        """
        if self.is_last_accum(batch_idx):
            if self.conf.train_mode == TrainMode.latent_diffusion:
                ema(self.model.latent_net, self.ema_model.latent_net,
                    self.conf.ema_decay)
            else:
                ema(self.model, self.ema_model, self.conf.ema_decay)

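    # Manual gradient clipping over every parameter group of the optimizer;
    # disabled when conf.grad_clip <= 0.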
    def on_before_optimizer_step(self, optimizer: Optimizer,
                                 optimizer_idx: int) -> None:
        if self.conf.grad_clip > 0:
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            torch.nn.utils.clip_grad_norm_(params,
                                           max_norm=self.conf.grad_clip)

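    # Adam or AdamW depending on the config, with an optional per-step linear
    # warmup schedule (see WarmupLR below).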
    def configure_optimizers(self):
        out = {}
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.model.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.model.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out['optimizer'] = optim
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(
                                                          self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out

    def split_tensor(self, x):
        """
        extract the tensor for a corresponding "worker" in the batch dimension

        Args:
            x: (n, c)

        Returns: x: (n_local, c)
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()

        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]


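# Parameter-wise exponential moving average:
# target <- decay * target + (1 - decay) * source, applied over state_dict
# entries so buffers are tracked as well as parameters.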
def ema(source, target, decay):
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay +
                                    source_dict[key].data * (1 - decay))


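# Linear warmup multiplier for LambdaLR: ramps from 0 to 1 over the first
# `warmup` steps and stays at 1 afterwards.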
class WarmupLR:
    def __init__(self, warmup) -> None:
        self.warmup = warmup

    def __call__(self, step):
        return min(step, self.warmup) / self.warmup


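# True once every `every` samples: whenever `num_samples` has just crossed a
# multiple of `every` (within one `step_size`-sized window).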
def is_time(num_samples, every, step_size):
    closest = (num_samples // every) * every
    return num_samples - closest < step_size


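# Build the LitModel, resume from `<logdir>/last.ckpt` when present
# (otherwise from conf.continue_from.path, if set), and launch a pl.Trainer
# with TensorBoard logging, periodic checkpoints, and DDP when more than one
# GPU or node is used.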
def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'):
    print('conf:', conf.name)

    model = LitModel(conf)

    if not os.path.exists(conf.logdir):
        os.makedirs(conf.logdir)
    checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}',
                                 save_last=True,
                                 save_top_k=-1,
                                 every_n_epochs=10)
    checkpoint_path = f'{conf.logdir}/last.ckpt'
    print('ckpt path:', checkpoint_path)
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
        print('resume!')
    else:
        if conf.continue_from is not None:
            resume = conf.continue_from.path
        else:
            resume = None

    tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
                                             name=None,
                                             version='')

    plugins = []
    if len(gpus) == 1 and nodes == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin

        plugins.append(DDPPlugin(find_unused_parameters=True))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        num_nodes=nodes,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        replace_sampler_ddp=True,
        logger=tb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )

    trainer.fit(model)
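

# Example invocation (a minimal sketch): this assumes a TrainConfig factory
# such as `base_train_config()` exists in config.py -- replace it with
# whatever config constructor this repo actually provides.
#
# if __name__ == '__main__':
#     conf = base_train_config()   # hypothetical helper returning a TrainConfig
#     train(conf, gpus=[0])        # single-GPU run; pass e.g. [0, 1] for DDP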