import time
import json

import torch
import torch.nn as nn
import wandb
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from asteroid.losses import (
    pairwise_neg_sisdr,
    PairwiseNegSDR,
)
from einops import rearrange, reduce
from ema_pytorch import EMA

from models import load_model_with_args
import utils
from dataloader import (
    MusdbTrainDataset,
    MusdbValidDataset,
    DelimitTrainDataset,
    DelimitValidDataset,
    OzoneTrainDataset,
    OzoneValidDataset,
    aug_from_str,
    SingleTrackSet,
)
class Solver(object):
    def __init__(self):
        pass

    def set_gpu(self, args):
        if args.wandb_params.use_wandb and args.gpu == 0:
            if args.wandb_params.sweep:
                wandb.init(
                    entity=args.wandb_params.entity,
                    project=args.wandb_params.project,
                    config=args,
                    resume=True
                    if args.dir_params.resume is not None and args.gpu == 0
                    else False,
                )
            else:
                wandb.init(
                    entity=args.wandb_params.entity,
                    project=args.wandb_params.project,
                    name=f"{args.dir_params.exp_name}",
                    config=args,
                    resume="must"
                    if args.dir_params.resume is not None
                    and not args.dir_params.continual_train
                    else False,
                    id=args.wandb_params.rerun_id
                    if args.wandb_params.rerun_id
                    else None,
                    settings=wandb.Settings(start_method="fork"),
                )

        ###################### Define Models ######################
        self.model = load_model_with_args(args)

        trainable_params = []
        trainable_params = trainable_params + list(self.model.parameters())

        if args.hyperparams.optimizer == "sgd":
            print("Use SGD optimizer.")
            self.optimizer = torch.optim.SGD(
                params=trainable_params,
                lr=args.hyperparams.lr,
                momentum=0.9,
                weight_decay=args.hyperparams.weight_decay,
            )
        elif args.hyperparams.optimizer == "adamw":
            print("Use AdamW optimizer.")
            self.optimizer = torch.optim.AdamW(
                params=trainable_params,
                lr=args.hyperparams.lr,
                betas=(0.9, 0.999),
                amsgrad=False,
                weight_decay=args.hyperparams.weight_decay,
            )
        elif args.hyperparams.optimizer == "radam":
            print("Use RAdam optimizer.")
            self.optimizer = torch.optim.RAdam(
                params=trainable_params,
                lr=args.hyperparams.lr,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=args.hyperparams.weight_decay,
            )
        elif args.hyperparams.optimizer == "adam":
            print("Use Adam optimizer.")
            self.optimizer = torch.optim.Adam(
                params=trainable_params,
                lr=args.hyperparams.lr,
                betas=(0.9, 0.999),
                weight_decay=args.hyperparams.weight_decay,
            )
        else:
            print("No optimizer loaded.")
            raise NotImplementedError

        if args.hyperparams.lr_scheduler == "step_lr":
            if args.model_loss_params.architecture == "umx":
                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    self.optimizer,
                    mode="min",
                    factor=args.hyperparams.lr_decay_gamma,
                    patience=args.hyperparams.lr_decay_patience,
                    cooldown=10,
                    verbose=True,
                )
            else:
                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    self.optimizer,
                    mode="min",
                    factor=args.hyperparams.lr_decay_gamma,
                    patience=args.hyperparams.lr_decay_patience,
                    cooldown=0,
                    min_lr=5e-5,
                    verbose=True,
                )
        elif args.hyperparams.lr_scheduler == "cos_warmup":
            self.scheduler = utils.CosineAnnealingWarmUpRestarts(
                self.optimizer,
                T_0=40,
                T_mult=1,
                eta_max=args.hyperparams.lr,
                T_up=10,
                gamma=0.5,
            )
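        # Note on the two scheduler branches (stepped in multi_validate below): the
        # "step_lr" path uses ReduceLROnPlateau stepped with the gathered validation
        # loss, while "cos_warmup" is stepped with the epoch index. As a rough sketch
        # only (an assumption about utils.CosineAnnealingWarmUpRestarts, whose
        # implementation is not shown in this file), a warmup-restart schedule with
        # T_0=40, T_up=10, gamma=0.5 typically behaves like:
        #
        #   e = epoch % 40
        #   if e < 10:   # linear warmup over T_up epochs
        #       lr ~= eta_max * e / 10
        #   else:        # cosine decay over the remaining 30 epochs of the cycle
        #       lr ~= eta_max * 0.5 * (1 + math.cos(math.pi * (e - 10) / 30))
        #   # eta_max is multiplied by gamma (0.5) after each 40-epoch restart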
        torch.cuda.set_device(args.gpu)

        self.model = self.model.to(f"cuda:{args.gpu}")

        ############################################################
        # Define Losses
        self.criterion = {}
        self.criterion["l1"] = nn.L1Loss().to(args.gpu)
        self.criterion["mse"] = nn.MSELoss().to(args.gpu)
        self.criterion["si_sdr"] = pairwise_neg_sisdr.to(args.gpu)
        self.criterion["snr"] = PairwiseNegSDR("snr").to(args.gpu)
        self.criterion["bcewithlogits"] = nn.BCEWithLogitsLoss().to(args.gpu)
        self.criterion["bce"] = nn.BCELoss().to(args.gpu)
        self.criterion["kl"] = nn.KLDivLoss(log_target=True).to(args.gpu)

        print("Loss functions we use in this training:")
        print(args.model_loss_params.train_loss_func)
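        # For illustration only: with a hypothetical config such as
        #   model_loss_params.train_loss_func   = ["si_sdr", "l1"]
        #   model_loss_params.train_loss_scales = [1.0, 10.0]
        # the training loop below computes
        #   loss = 1.0 * criterion["si_sdr"](est, tgt).mean() + 10.0 * criterion["l1"](est, tgt).mean()
        # i.e. each named criterion is looked up in self.criterion and weighted by the
        # scale at the same index before the terms are summed.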
        # Early stopping utils
        self.es = utils.EarlyStopping(patience=args.hyperparams.patience)
        self.stop = False

        if args.wandb_params.use_wandb and args.gpu == 0:
            wandb.watch(self.model, log="all")

        self.start_epoch = 1
        self.train_losses = []
        self.valid_losses = []
        self.train_times = []
        self.best_epoch = 0

        if args.dir_params.resume and not args.hyperparams.ema:
            self.resume(args)

        # Distribute models to machine
        self.model = DDP(
            self.model,
            device_ids=[args.gpu],
            output_device=args.gpu,
            find_unused_parameters=True,
        )

        if args.hyperparams.ema:
            self.model_ema = EMA(
                self.model,
                beta=0.999,
                update_after_step=100,
                update_every=10,
            )

        if args.dir_params.resume and args.hyperparams.ema:
            self.resume(args)
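        # Note: per the ema_pytorch API, self.model_ema keeps an exponential moving
        # average copy of the DDP-wrapped model. self.model_ema.update() is called
        # after each optimizer step in train(), and calling self.model_ema(x) at
        # validation time runs the forward pass with the averaged weights
        # (see multi_validate below).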
        ###################### Define data pipeline ######################
        args.hyperparams.batch_size = int(
            args.hyperparams.batch_size / args.ngpus_per_node
        )

        self.mp_context = torch.multiprocessing.get_context("fork")

        if args.task_params.dataset == "musdb":
            self.train_dataset = MusdbTrainDataset(
                target=args.task_params.target,
                root=args.dir_params.root,
                seq_duration=args.data_params.seq_dur,
                samples_per_track=args.data_params.samples_per_track,
                source_augmentations=aug_from_str(
                    ["gain", "channelswap"],
                ),
                sample_rate=args.data_params.sample_rate,
                seed=args.sys_params.seed,
                limitaug_method=args.data_params.limitaug_method,
                limitaug_mode=args.data_params.limitaug_mode,
                limitaug_custom_target_lufs=args.data_params.limitaug_custom_target_lufs,
                limitaug_custom_target_lufs_std=args.data_params.limitaug_custom_target_lufs_std,
                target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
                custom_limiter_attack_range=args.data_params.custom_limiter_attack_range,
                custom_limiter_release_range=args.data_params.custom_limiter_release_range,
            )
            self.valid_dataset = MusdbValidDataset(
                target=args.task_params.target, root=args.dir_params.root
            )

        elif args.task_params.dataset == "delimit":
            if args.data_params.limitaug_method == "ozone":
                self.train_dataset = OzoneTrainDataset(
                    target=args.task_params.target,
                    root=args.dir_params.root,
                    ozone_root=args.dir_params.ozone_root,
                    use_fixed=args.data_params.use_fixed,
                    seq_duration=args.data_params.seq_dur,
                    samples_per_track=args.data_params.samples_per_track,
                    source_augmentations=aug_from_str(
                        ["gain", "channelswap"],
                    ),
                    sample_rate=args.data_params.sample_rate,
                    seed=args.sys_params.seed,
                    limitaug_method=args.data_params.limitaug_method,
                    limitaug_mode=args.data_params.limitaug_mode,
                    limitaug_custom_target_lufs=args.data_params.limitaug_custom_target_lufs,
                    limitaug_custom_target_lufs_std=args.data_params.limitaug_custom_target_lufs_std,
                    target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
                    target_limitaug_mode=args.data_params.target_limitaug_mode,
                    target_limitaug_custom_target_lufs=args.data_params.target_limitaug_custom_target_lufs,
                    target_limitaug_custom_target_lufs_std=args.data_params.target_limitaug_custom_target_lufs_std,
                    custom_limiter_attack_range=args.data_params.custom_limiter_attack_range,
                    custom_limiter_release_range=args.data_params.custom_limiter_release_range,
                )
                self.valid_dataset = OzoneValidDataset(
                    target=args.task_params.target,
                    root=args.dir_params.root,
                    ozone_root=args.dir_params.ozone_root,
                    target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
                )
            else:
                self.train_dataset = DelimitTrainDataset(
                    target=args.task_params.target,
                    root=args.dir_params.root,
                    seq_duration=args.data_params.seq_dur,
                    samples_per_track=args.data_params.samples_per_track,
                    source_augmentations=aug_from_str(
                        ["gain", "channelswap"],
                    ),
                    sample_rate=args.data_params.sample_rate,
                    seed=args.sys_params.seed,
                    limitaug_method=args.data_params.limitaug_method,
                    limitaug_mode=args.data_params.limitaug_mode,
                    limitaug_custom_target_lufs=args.data_params.limitaug_custom_target_lufs,
                    limitaug_custom_target_lufs_std=args.data_params.limitaug_custom_target_lufs_std,
                    target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
                    target_limitaug_mode=args.data_params.target_limitaug_mode,
                    target_limitaug_custom_target_lufs=args.data_params.target_limitaug_custom_target_lufs,
                    target_limitaug_custom_target_lufs_std=args.data_params.target_limitaug_custom_target_lufs_std,
                    custom_limiter_attack_range=args.data_params.custom_limiter_attack_range,
                    custom_limiter_release_range=args.data_params.custom_limiter_release_range,
                )
                self.valid_dataset = DelimitValidDataset(
                    target=args.task_params.target,
                    root=args.dir_params.root,
                    delimit_valid_root=args.dir_params.delimit_valid_root,
                    valid_target_lufs=args.data_params.valid_target_lufs,
                    target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
                    delimit_valid_L_root=args.dir_params.delimit_valid_L_root,
                )

        self.train_sampler = DistributedSampler(
            self.train_dataset, shuffle=True, rank=args.gpu
        )
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=args.hyperparams.batch_size,
            shuffle=False,
            num_workers=args.sys_params.nb_workers,
            multiprocessing_context=self.mp_context,
            pin_memory=True,
            sampler=self.train_sampler,
            drop_last=False,
        )
        self.valid_sampler = DistributedSampler(
            self.valid_dataset, shuffle=False, rank=args.gpu
        )
        self.valid_loader = torch.utils.data.DataLoader(
            self.valid_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.sys_params.nb_workers,
            multiprocessing_context=self.mp_context,
            pin_memory=False,
            sampler=self.valid_sampler,
            drop_last=False,
        )
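    # The entry-point script is not part of this file. As a rough sketch only
    # (names such as `main_worker`, `args.epochs`, and `args.world_size` are
    # assumptions, not the actual script), one process per GPU would typically
    # drive this Solver like:
    #
    #   def main_worker(gpu, ngpus_per_node, args):
    #       args.gpu = gpu
    #       dist.init_process_group("nccl", rank=args.sys_params.rank, world_size=args.world_size)
    #       solver = Solver()
    #       solver.set_gpu(args)
    #       for epoch in range(solver.start_epoch, args.epochs + 1):
    #           solver.train_sampler.set_epoch(epoch)  # reshuffle shards each epoch
    #           solver.train(args, epoch)
    #           solver.multi_validate(args, epoch)
    #           if solver.stop:
    #               break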
    def train(self, args, epoch):
        self.end = time.time()
        self.model.train()

        # get current learning rate
        for param_group in self.optimizer.param_groups:
            current_lr = param_group["lr"]

        if (
            args.sys_params.rank % args.ngpus_per_node == 0
        ):  # print only from the first process on each node
            print(f"Epoch {epoch}, Learning rate: {current_lr}")

        losses = utils.AverageMeter()
        loss_logger = {}
        loss_logger["train/train loss"] = 0

        # with torch.autograd.detect_anomaly():  # use this if you want to detect anomalous behavior while training.
        for i, values in enumerate(self.train_loader):
            mixture, clean, *train_vars = values
            mixture = mixture.cuda(args.gpu, non_blocking=True)
            clean = clean.cuda(args.gpu, non_blocking=True)
            target = clean  # target shape = [batch_size, n_srcs, nb_channels (if stereo: 2), wave_length]

            loss_input = {}
            estimates, *estimates_vars = self.model(mixture)

            dict_loss = {}

            if args.task_params.dataset == "delimit":
                estimates = estimates_vars[0]

            for train_loss_idx, single_train_loss_func in enumerate(
                args.model_loss_params.train_loss_func
            ):
                if self.model.module.use_encoder_to_target:
                    target_spec = self.model.module.encoder(
                        rearrange(target, "b s c t -> (b s) c t")
                    )
                    target_spec = rearrange(
                        target_spec,
                        "(b s) c f t -> b s c f t",
                        s=args.task_params.bleeding_nsrcs,
                    )
                loss_else = self.criterion[single_train_loss_func](
                    estimates,
                    target_spec
                    if self.model.module.use_encoder_to_target
                    else target,
                )
                dict_loss[single_train_loss_func] = (
                    loss_else.mean()
                    * args.model_loss_params.train_loss_scales[train_loss_idx]
                )

            loss = sum([value for key, value in dict_loss.items()])

            ############################################################
            #################### 5. Back propagation ####################
            loss.backward()

            if args.hyperparams.gradient_clip:
                nn.utils.clip_grad_norm_(
                    self.model.parameters(), max_norm=args.hyperparams.gradient_clip
                )

            losses.update(loss.item(), clean.size(0))
            loss_logger["train/train loss"] = losses.avg
            for key, value in dict_loss.items():
                loss_logger[f"train/{key}"] = value.item()

            self.optimizer.step()
            self.model.zero_grad(
                set_to_none=True
            )  # set_to_none=True is for memory saving

            if args.hyperparams.ema:
                self.model_ema.update()

            ############################################################
            # ###################### 6. Plot ######################
            if i % 30 == 0:
                # loss print for multiple loss functions
                multiple_score = torch.Tensor(
                    [value for key, value in loss_logger.items()]
                ).to(args.gpu)
                gathered_score_list = [
                    torch.ones_like(multiple_score)
                    for _ in range(dist.get_world_size())
                ]
                dist.all_gather(gathered_score_list, multiple_score)
                gathered_score = torch.mean(
                    torch.stack(gathered_score_list, dim=0), dim=0
                )
                if args.gpu == 0:
                    print(f"Epoch {epoch}, step {i} / {len(self.train_loader)}")
                    temp_loss_logger = {}
                    for index, (key, value) in enumerate(loss_logger.items()):
                        temp_key = key.replace("train/", "iter-wise/")
                        temp_loss_logger[temp_key] = round(
                            gathered_score[index].item(), 6
                        )
                        print(f"{key} : {round(gathered_score[index].item(), 6)}")

        single_score = torch.Tensor([losses.avg]).to(args.gpu)
        gathered_score_list = [
            torch.ones_like(single_score) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(gathered_score_list, single_score)
        gathered_score = torch.mean(torch.cat(gathered_score_list)).item()

        if args.gpu == 0:
            self.train_losses.append(gathered_score)
            if args.wandb_params.use_wandb:
                loss_logger["train/train loss"] = single_score
                loss_logger["train/epoch"] = epoch
                wandb.log(loss_logger)
        ############################################################
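    # Worked example of the all_gather averaging above (illustrative numbers):
    # with 2 processes whose epoch-average losses are 0.9 and 1.1, each rank
    # gathers [0.9, 1.1] and reports the unweighted mean 1.0. This is a plain
    # mean over ranks; multi_validate below instead gathers per-rank
    # (sum, count) pairs, so the validation loss is weighted by the number of
    # examples each rank actually processed.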
    def multi_validate(self, args, epoch):
        if args.gpu == 0:
            print(f"Epoch {epoch} Validation session!")

        losses = utils.AverageMeter()
        loss_logger = {}

        self.model.eval()
        with torch.no_grad():
            for i, values in enumerate(self.valid_loader, start=1):
                mixture, clean, song_name, *valid_vars = values
                mixture = mixture.cuda(args.gpu, non_blocking=True)
                clean = clean.cuda(args.gpu, non_blocking=True)
                target = clean

                dict_loss = {}

                if not args.data_params.singleset_num_frames:
                    if args.hyperparams.ema:
                        estimates, *estimates_vars = self.model_ema(mixture)
                    else:
                        estimates, *estimates_vars = self.model(mixture)
                    if args.task_params.dataset == "delimit":
                        estimates = estimates_vars[0]
                    estimates = estimates[..., : clean.size(-1)]
                else:  # use SingleTrackSet
                    db = SingleTrackSet(
                        mixture[0],
                        hop_length=args.data_params.nhop,
                        num_frame=args.data_params.singleset_num_frames,
                        target_name=args.task_params.target,
                    )
                    separated = []

                    for item in db:
                        if args.hyperparams.ema:
                            estimates, *estimates_vars = self.model_ema(
                                item.unsqueeze(0).to(args.gpu)
                            )
                        else:
                            estimates, *estimates_vars = self.model(
                                item.unsqueeze(0).to(args.gpu)
                            )
                        if args.task_params.dataset == "delimit":
                            estimates = estimates_vars[0]
                        separated.append(
                            estimates_vars[0][
                                ..., db.trim_length : -db.trim_length
                            ].clone()
                        )

                    estimates = torch.cat(separated, dim=-1)
                    estimates = estimates[..., : target.shape[-1]]

                for valid_loss_idx, single_valid_loss_func in enumerate(
                    args.model_loss_params.valid_loss_func
                ):
                    loss_else = self.criterion[single_valid_loss_func](
                        estimates,
                        target,
                    )
                    dict_loss[single_valid_loss_func] = (
                        loss_else.mean()
                        * args.model_loss_params.valid_loss_scales[valid_loss_idx]
                    )

                loss = sum([value for key, value in dict_loss.items()])
                losses.update(loss.item(), clean.size(0))

        list_sum_count = torch.Tensor([losses.sum, losses.count]).to(args.gpu)
        list_gathered_sum_count = [
            torch.ones_like(list_sum_count) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(list_gathered_sum_count, list_sum_count)
        gathered_score = reduce(
            torch.stack(list_gathered_sum_count), "s c -> c", "sum"
        )  # sum over ranks: result[0] = total of losses.sum, result[1] = total of losses.count
        gathered_score = (gathered_score[0] / gathered_score[1]).item()

        loss_logger["valid/valid loss"] = gathered_score
        for key, value in dict_loss.items():
            loss_logger[f"valid/{key}"] = value.item()

        if args.hyperparams.lr_scheduler == "step_lr":
            self.scheduler.step(gathered_score)
        elif args.hyperparams.lr_scheduler == "cos_warmup":
            self.scheduler.step(epoch)
        else:
            self.scheduler.step(gathered_score)

        if args.wandb_params.use_wandb and args.gpu == 0:
            loss_logger["valid/epoch"] = epoch
            wandb.log(loss_logger)

        if args.gpu == 0:
            self.valid_losses.append(gathered_score)

            self.stop = self.es.step(gathered_score)

            print(f"Epoch {epoch}, validation loss : {round(gathered_score, 6)}")

            plt.plot(self.train_losses, label="train loss")
            plt.plot(self.valid_losses, label="valid loss")
            plt.legend(loc="upper right")
            plt.savefig(f"{args.output}/loss_graph_{args.task_params.target}.png")
            plt.close()

            save_states = {
                "epoch": epoch,
                "state_dict": self.model.module.state_dict()
                if not args.hyperparams.ema
                else self.model_ema.state_dict(),
                "best_loss": self.es.best,
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict(),
            }
            utils.save_checkpoint(
                save_states,
                state_dict_only=gathered_score == self.es.best,
                path=args.output,
                target=args.task_params.target,
            )

            self.train_times.append(time.time() - self.end)

            if gathered_score == self.es.best:
                self.best_epoch = epoch

            # save params
            params = {
                "epochs_trained": epoch,
                "args": args.toDict(),
                "best_loss": self.es.best,
                "best_epoch": self.best_epoch,
                "train_loss_history": self.train_losses,
                "valid_loss_history": self.valid_losses,
                "train_time_history": self.train_times,
                "num_bad_epochs": self.es.num_bad_epochs,
            }
            with open(
                f"{args.output}/{args.task_params.target}.json", "w"
            ) as outfile:
                outfile.write(json.dumps(params, indent=4, sort_keys=True))
            print(
                f"Epoch {epoch} train completed. Took {round(self.train_times[-1], 3)} seconds"
            )
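    # About the SingleTrackSet branch above: long validation tracks are split into
    # fixed-length windows, each window is passed through the model separately,
    # db.trim_length samples are trimmed from both ends of every chunk output
    # (presumably to discard edge artifacts of the windowing), and the trimmed
    # chunks are concatenated and cut back to the target length. The exact
    # windowing is defined in dataloader.SingleTrackSet, which is not shown here.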
    def resume(self, args):
        print(f"Resume checkpoint from: {args.dir_params.resume}")
        loc = f"cuda:{args.gpu}"
        checkpoint_path = f"{args.dir_params.resume}/{args.task_params.target}"
        with open(f"{checkpoint_path}.json", "r") as stream:
            results = json.load(stream)
        checkpoint = torch.load(f"{checkpoint_path}.chkpnt", map_location=loc)
        if args.hyperparams.ema:
            self.model_ema.load_state_dict(checkpoint["state_dict"])
        else:
            self.model.load_state_dict(checkpoint["state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])

        if (
            args.dir_params.continual_train
        ):  # we want to use a pre-trained model but do not want to reuse the lr_scheduler history.
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = args.hyperparams.lr
        else:
            self.scheduler.load_state_dict(checkpoint["scheduler"])

        self.es.best = results["best_loss"]
        self.es.num_bad_epochs = results["num_bad_epochs"]
        self.start_epoch = results["epochs_trained"]
        self.train_losses = results["train_loss_history"]
        self.valid_losses = results["valid_loss_history"]
        self.train_times = results["train_time_history"]
        self.best_epoch = results["best_epoch"]

        if args.sys_params.rank % args.ngpus_per_node == 0:
            print(
                f"=> loaded checkpoint {checkpoint_path} (epoch {results['epochs_trained']})"
            )
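    # resume() expects the artifacts written by multi_validate()/utils.save_checkpoint
    # under the resume directory, i.e. for a target named "vocals" something like:
    #   {args.dir_params.resume}/vocals.json     -> training history / best loss
    #   {args.dir_params.resume}/vocals.chkpnt   -> model / optimizer / scheduler state
    # ("vocals" is just an illustrative target name here.)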
    def cal_loss(self, args, loss_input):
        loss_dict = {}
        for key, value in loss_input.items():
            loss_dict[key] = self.criterion[key](*value)
        return loss_dict

    def cal_multiple_losses(self, args, dict_loss_name_input):
        loss_dict = {}
        for loss_name, loss_input in dict_loss_name_input.items():
            loss_dict[loss_name] = self.cal_loss(args, loss_input)
        return loss_dict
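
# Illustrative usage of the two helpers above (the tensors, keys, and nesting are
# made up for the example; the criterion names must exist in self.criterion, and
# set_gpu(args) must have been called first so self.criterion is populated):
#
#   est, tgt = torch.randn(2, 2, 44100), torch.randn(2, 2, 44100)
#   per_loss = solver.cal_loss(args, {"l1": (est, tgt), "si_sdr": (est, tgt)})
#   # -> {"l1": tensor(...), "si_sdr": tensor(...)}
#   grouped = solver.cal_multiple_losses(args, {"stage1": {"l1": (est, tgt)}})
#   # -> {"stage1": {"l1": tensor(...)}}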