import logging
import math
import random
from datetime import timedelta
from pathlib import Path

import hydra
import numpy as np
import torch
import torch.distributed as distributed
from hydra import compose
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict
from torch.distributed.elastic.multiprocessing.errors import record

from mmaudio.data.data_setup import setup_training_datasets, setup_val_datasets
from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K
from mmaudio.runner import Runner
from mmaudio.sample import sample
from mmaudio.utils.dist_utils import info_if_rank_zero, local_rank, world_size
from mmaudio.utils.logger import TensorboardLogger
from mmaudio.utils.synthesize_ema import synthesize_ema

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
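# NOTE: TF32 trades a small amount of matmul/conv precision for a large speedup
# on Ampere and newer GPUs; disable both flags if exact fp32 reproducibility is needed.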

log = logging.getLogger()


def distributed_setup():
    distributed.init_process_group(backend="nccl", timeout=timedelta(hours=2))
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size
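

# NOTE: launch this script with a distributed launcher (e.g., torchrun) so that
# RANK/LOCAL_RANK/WORLD_SIZE are set for init_process_group; an illustrative invocation
# (script name and override are assumptions):
#   torchrun --nproc_per_node=8 train.py exp_id=my_experiment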
@record
# config_path/config_name below are assumed from the repo's Hydra layout; adjust if yours differs
@hydra.main(version_base='1.3.2', config_path='config', config_name='train_config.yaml')
def train(cfg: DictConfig):
    # initial setup
    torch.cuda.set_device(local_rank)
    torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
    distributed_setup()
    num_gpus = world_size
    run_dir = HydraConfig.get().run.dir

    # compose the eval config early so that it does not depend on files that may change on disk later
    eval_cfg = compose('eval_config', overrides=[f'exp_id={cfg.exp_id}'])

    # patch data dimensions to match the model's sequence configuration
    if cfg.model.endswith('16k'):
        seq_cfg = CONFIG_16K
    elif cfg.model.endswith('44k'):
        seq_cfg = CONFIG_44K
    else:
        raise ValueError(f'Unknown model: {cfg.model}')
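
    # open_dict temporarily lifts OmegaConf's struct flag so that new keys can be added to cfg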
    with open_dict(cfg):
        cfg.data_dim.latent_seq_len = seq_cfg.latent_seq_len
        cfg.data_dim.clip_seq_len = seq_cfg.clip_seq_len
        cfg.data_dim.sync_seq_len = seq_cfg.sync_seq_len

    # wrap the python logger with a tensorboard logger
    log = TensorboardLogger(cfg.exp_id,
                            run_dir,
                            logging.getLogger(),
                            is_rank0=(local_rank == 0),
                            enable_email=cfg.enable_email and not cfg.debug)

    info_if_rank_zero(log, f'All configuration: {cfg}')
    info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}')

    # number of dataloader workers
    info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}')

    # set seeds so that model initialization is identical on every rank
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    # setting up configurations
    info_if_rank_zero(log, f'Training configuration: {cfg}')
    cfg.batch_size //= num_gpus
    info_if_rank_zero(log, f'Batch size (per GPU): {cfg.batch_size}')
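    # NOTE: cfg.batch_size is treated as the global batch size and split evenly across GPUs above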

    # total number of training iterations
    total_iterations = cfg['num_iterations']

    # setup datasets
    dataset, sampler, loader = setup_training_datasets(cfg)
    info_if_rank_zero(log, f'Number of training samples: {len(dataset)}')
    info_if_rank_zero(log, f'Number of training batches: {len(loader)}')
    val_dataset, val_loader, eval_loader = setup_val_datasets(cfg)
    info_if_rank_zero(log, f'Number of val samples: {len(val_dataset)}')
    val_cfg = cfg.data.ExtractedVGG_val

    # compute and set the latent mean and std
    latent_mean, latent_std = dataset.compute_latent_stats()
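    # (assumption) the statistics are passed to the Runner so that VAE latents are
    # normalized to roughly zero mean and unit variance during training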

    # construct the trainer
    trainer = Runner(cfg,
                     log=log,
                     run_path=run_dir,
                     for_training=True,
                     latent_mean=latent_mean,
                     latent_std=latent_std).enter_train()
    eval_rng_clone = trainer.rng.graphsafe_get_state()
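    # the RNG state cloned here is restored before every validation/inference pass,
    # so each evaluation runs with identical noise and metrics stay comparable across iterations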

    # load a previous checkpoint if needed
    if cfg['checkpoint'] is not None:
        curr_iter = trainer.load_checkpoint(cfg['checkpoint'])
        cfg['checkpoint'] = None
        info_if_rank_zero(log, 'Model checkpoint loaded!')
    else:
        # if run_dir exists, load the latest checkpoint
        checkpoint = trainer.get_latest_checkpoint_path()
        if checkpoint is not None:
            curr_iter = trainer.load_checkpoint(checkpoint)
            info_if_rank_zero(log, 'Latest checkpoint loaded!')
        else:
            # otherwise, load pretrained network weights if given
            curr_iter = 0
            if cfg['weights'] is not None:
                info_if_rank_zero(log, 'Loading weights from disk')
                trainer.load_weights(cfg['weights'])
                cfg['weights'] = None

    # determine the number of epochs needed to reach total_iterations
    total_epoch = math.ceil(total_iterations / len(loader))
    current_epoch = curr_iter // len(loader)
    info_if_rank_zero(log, f'We will approximately use {total_epoch} epochs.')

    # training loop
    try:
        # re-seed so that each rank draws different random numbers from here on
        np.random.seed(np.random.randint(2**30 - 1) + local_rank * 1000)
        while curr_iter < total_iterations:
            # crucial for randomness: reshuffle the distributed sampler every epoch
            sampler.set_epoch(current_epoch)
            current_epoch += 1
            log.debug(f'Current epoch: {current_epoch}')

            trainer.enter_train()
            trainer.log.data_timer.start()

            for data in loader:
                trainer.train_pass(data, curr_iter)

                if (curr_iter + 1) % cfg.val_interval == 0:
                    # swap into an eval rng state, i.e., use the same seed for every validation pass
                    train_rng_snapshot = trainer.rng.graphsafe_get_state()
                    trainer.rng.graphsafe_set_state(eval_rng_clone)
                    info_if_rank_zero(log, f'Iteration {curr_iter}: validating')
                    for data in val_loader:
                        trainer.validation_pass(data, curr_iter)
                    distributed.barrier()
                    trainer.val_integrator.finalize('val', curr_iter, ignore_timer=True)
                    trainer.rng.graphsafe_set_state(train_rng_snapshot)

                if (curr_iter + 1) % cfg.eval_interval == 0:
                    save_eval = (curr_iter + 1) % cfg.save_eval_interval == 0
                    train_rng_snapshot = trainer.rng.graphsafe_get_state()
                    trainer.rng.graphsafe_set_state(eval_rng_clone)
                    info_if_rank_zero(log, f'Iteration {curr_iter}: running inference')
                    for data in eval_loader:
                        audio_path = trainer.inference_pass(data,
                                                            curr_iter,
                                                            val_cfg,
                                                            save_eval=save_eval)
                    distributed.barrier()
                    trainer.rng.graphsafe_set_state(train_rng_snapshot)
                    trainer.eval(audio_path, curr_iter, val_cfg)

                curr_iter += 1
                if curr_iter >= total_iterations:
                    break
    except Exception as e:
        log.error(f'Error occurred at iteration {curr_iter}!')
        log.critical(e.message if hasattr(e, 'message') else str(e))
        raise
    finally:
        if not cfg.debug:
            trainer.save_checkpoint(curr_iter)
            trainer.save_weights(curr_iter)
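    # NOTE: the finally-block saves a checkpoint even when training crashes,
    # so interrupted runs can resume via get_latest_checkpoint_path() above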

    # Inference pass
    del trainer
    torch.cuda.empty_cache()
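    # the trainer is deleted and the CUDA cache cleared to free GPU memory
    # before EMA synthesis and the final sampling run below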

    # Synthesize EMA
    if local_rank == 0:
        log.info(f'Synthesizing EMA with sigma={cfg.ema.default_output_sigma}')
        ema_sigma = cfg.ema.default_output_sigma
        state_dict = synthesize_ema(cfg, ema_sigma, step=None)
        save_dir = Path(run_dir) / f'{cfg.exp_id}_ema_final.pth'
        torch.save(state_dict, save_dir)
        log.info(f'Synthesized EMA saved to {save_dir}!')
    distributed.barrier()
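    # the barrier keeps the other ranks waiting until rank 0 has written the EMA
    # weights, so every rank sees the file before sampling starts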

    log.info(f'Evaluation: {eval_cfg}')
    sample(eval_cfg)

    # clean-up
    log.complete()
    distributed.barrier()
    distributed.destroy_process_group()


if __name__ == '__main__':
    train()