import logging
import math
import random
from datetime import timedelta
from pathlib import Path

import hydra
import numpy as np
import torch
import torch.distributed as distributed
from hydra import compose
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict
from torch.distributed.elastic.multiprocessing.errors import record

from mmaudio.data.data_setup import setup_training_datasets, setup_val_datasets
from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K
from mmaudio.runner import Runner
from mmaudio.sample import sample
from mmaudio.utils.dist_utils import info_if_rank_zero, local_rank, world_size
from mmaudio.utils.logger import TensorboardLogger
from mmaudio.utils.synthesize_ema import synthesize_ema
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()


def distributed_setup():
    distributed.init_process_group(backend="nccl", timeout=timedelta(hours=2))
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size


@record
@hydra.main(version_base='1.3.2', config_path='config', config_name='train_config.yaml')
def train(cfg: DictConfig):
    # initial setup
    torch.cuda.set_device(local_rank)
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
    distributed_setup()
    num_gpus = world_size
    run_dir = HydraConfig.get().run.dir

    # compose the eval config early so that it does not rely on reading the disk later
    eval_cfg = compose('eval_config', overrides=[f'exp_id={cfg.exp_id}'])

    # patch data dims to match the model's sequence configuration
    if cfg.model.endswith('16k'):
        seq_cfg = CONFIG_16K
    elif cfg.model.endswith('44k'):
        seq_cfg = CONFIG_44K
    else:
        raise ValueError(f'Unknown model: {cfg.model}')
    with open_dict(cfg):
        cfg.data_dim.latent_seq_len = seq_cfg.latent_seq_len
        cfg.data_dim.clip_seq_len = seq_cfg.clip_seq_len
        cfg.data_dim.sync_seq_len = seq_cfg.sync_seq_len

    # wrap the python logger with a tensorboard logger
    log = TensorboardLogger(cfg.exp_id,
                            run_dir,
                            logging.getLogger(),
                            is_rank0=(local_rank == 0),
                            enable_email=cfg.enable_email and not cfg.debug)

    info_if_rank_zero(log, f'All configuration: {cfg}')
    info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}')

    # number of dataloader workers
    info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}')

    # set seeds to ensure the same initialization
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    # setting up configurations
    info_if_rank_zero(log, f'Training configuration: {cfg}')
    cfg.batch_size //= num_gpus
    info_if_rank_zero(log, f'Batch size (per GPU): {cfg.batch_size}')

    # total number of training iterations
    total_iterations = cfg['num_iterations']

    # setup datasets
    dataset, sampler, loader = setup_training_datasets(cfg)
    info_if_rank_zero(log, f'Number of training samples: {len(dataset)}')
    info_if_rank_zero(log, f'Number of training batches: {len(loader)}')
    val_dataset, val_loader, eval_loader = setup_val_datasets(cfg)
    info_if_rank_zero(log, f'Number of val samples: {len(val_dataset)}')
    val_cfg = cfg.data.ExtractedVGG_val

    # compute and set latent mean and std
    latent_mean, latent_std = dataset.compute_latent_stats()

    # construct the trainer
    trainer = Runner(cfg,
                     log=log,
                     run_path=run_dir,
                     for_training=True,
                     latent_mean=latent_mean,
                     latent_std=latent_std).enter_train()
    eval_rng_clone = trainer.rng.graphsafe_get_state()

    # load previous checkpoint if needed
    if cfg['checkpoint'] is not None:
        curr_iter = trainer.load_checkpoint(cfg['checkpoint'])
        cfg['checkpoint'] = None
        info_if_rank_zero(log, 'Model checkpoint loaded!')
    else:
        # if run_dir exists, load the latest checkpoint
        checkpoint = trainer.get_latest_checkpoint_path()
        if checkpoint is not None:
            curr_iter = trainer.load_checkpoint(checkpoint)
            info_if_rank_zero(log, 'Latest checkpoint loaded!')
        else:
            # otherwise, load previous network weights if provided
            curr_iter = 0
            if cfg['weights'] is not None:
                info_if_rank_zero(log, 'Loading weights from the disk')
                trainer.load_weights(cfg['weights'])
                cfg['weights'] = None

    # determine max epoch
    total_epoch = math.ceil(total_iterations / len(loader))
    current_epoch = curr_iter // len(loader)
    info_if_rank_zero(log, f'We will approximately use {total_epoch} epochs.')

    # training loop
    try:
        # need this to select random bases in different workers
        np.random.seed(np.random.randint(2**30 - 1) + local_rank * 1000)
        while curr_iter < total_iterations:
            # crucial for randomness!
            sampler.set_epoch(current_epoch)
            current_epoch += 1
            log.debug(f'Current epoch: {current_epoch}')

            trainer.enter_train()
            trainer.log.data_timer.start()
            for data in loader:
                trainer.train_pass(data, curr_iter)

                if (curr_iter + 1) % cfg.val_interval == 0:
                    # swap into an eval rng state, i.e., use the same seed for every validation pass
                    train_rng_snapshot = trainer.rng.graphsafe_get_state()
                    trainer.rng.graphsafe_set_state(eval_rng_clone)
                    info_if_rank_zero(log, f'Iteration {curr_iter}: validating')
                    for data in val_loader:
                        trainer.validation_pass(data, curr_iter)
                    distributed.barrier()
                    trainer.val_integrator.finalize('val', curr_iter, ignore_timer=True)
                    trainer.rng.graphsafe_set_state(train_rng_snapshot)

                if (curr_iter + 1) % cfg.eval_interval == 0:
                    save_eval = (curr_iter + 1) % cfg.save_eval_interval == 0
                    train_rng_snapshot = trainer.rng.graphsafe_get_state()
                    trainer.rng.graphsafe_set_state(eval_rng_clone)
                    info_if_rank_zero(log, f'Iteration {curr_iter}: evaluating')
                    for data in eval_loader:
                        audio_path = trainer.inference_pass(data,
                                                            curr_iter,
                                                            val_cfg,
                                                            save_eval=save_eval)
                    distributed.barrier()
                    trainer.rng.graphsafe_set_state(train_rng_snapshot)
                    trainer.eval(audio_path, curr_iter, val_cfg)

                curr_iter += 1
                if curr_iter >= total_iterations:
                    break
    except Exception as e:
        log.error(f'Error occurred at iteration {curr_iter}!')
        log.critical(e.message if hasattr(e, 'message') else str(e))
        raise
    finally:
        if not cfg.debug:
            trainer.save_checkpoint(curr_iter)
            trainer.save_weights(curr_iter)

    # free the trainer before the final inference pass
    del trainer
    torch.cuda.empty_cache()

    # synthesize EMA weights on rank 0 only
    if local_rank == 0:
        log.info(f'Synthesizing EMA with sigma={cfg.ema.default_output_sigma}')
        ema_sigma = cfg.ema.default_output_sigma
        state_dict = synthesize_ema(cfg, ema_sigma, step=None)
        save_dir = Path(run_dir) / f'{cfg.exp_id}_ema_final.pth'
        torch.save(state_dict, save_dir)
        log.info(f'Synthesized EMA saved to {save_dir}!')
    distributed.barrier()

    # final evaluation/sampling pass
    log.info(f'Evaluation: {eval_cfg}')
    sample(eval_cfg)

    # clean-up
    log.complete()
    distributed.barrier()
    distributed.destroy_process_group()

if __name__ == '__main__':
    train()
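
# Launch note (a hedged sketch, not part of the original script): train() initializes
# torch.distributed with the NCCL backend and reads local_rank/world_size, so it is
# meant to be started under a distributed launcher such as torchrun. The script name,
# GPU count, and override values below are illustrative assumptions; only the exp_id
# and model config keys are taken from the code above, and cfg.model must end in
# '16k' or '44k'.
#
#   torchrun --nproc_per_node=8 train.py exp_id=my_experiment model=my_model_16k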