__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
__version__ = '1.0.3'
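
# Training entry point for music source separation (MSS) models, built on
# HuggingFace Accelerate for single- and multi-GPU runs. Each epoch trains on
# chunks served by MSSDataset, saves the latest weights, computes per-instrument
# SDR on full validation mixtures, and steps a ReduceLROnPlateau scheduler on
# the average SDR.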

import argparse
import glob
import os
import time
import warnings

import auraloss
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from accelerate import Accelerator
from torch.optim import Adam, AdamW, SGD, RAdam, RMSprop
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from dataset import MSSDataset
from train import masked_loss, manual_seed, load_not_compatible_weights
from utils import get_model_from_config, demix, sdr, prefer_target_instrument

warnings.filterwarnings("ignore")

def valid(model, valid_loader, args, config, device, verbose=False):
    """Compute SDR for every target instrument on each validation track."""
    instruments = prefer_target_instrument(config)

    all_sdr = dict()
    for instr in instruments:
        all_sdr[instr] = []

    all_mixtures_path = valid_loader
    if verbose:
        all_mixtures_path = tqdm(valid_loader)

    pbar_dict = {}
    for path_list in all_mixtures_path:
        path = path_list[0]
        mix, sr = sf.read(path)
        folder = os.path.dirname(path)
        res = demix(config, model, mix.T, device, model_type=args.model_type)
        for instr in instruments:
            if instr != 'other' or config.training.other_fix is False:
                track, sr1 = sf.read(folder + '/{}.wav'.format(instr))
            else:
                # other_fix: reconstruct the 'other' reference as mixture minus vocals
                track, sr1 = sf.read(folder + '/{}.wav'.format('vocals'))
                track = mix - track

            references = np.expand_dims(track, axis=0)
            estimates = np.expand_dims(res[instr].T, axis=0)
            sdr_val = sdr(references, estimates)[0]
            single_val = torch.from_numpy(np.array([sdr_val])).to(device)
            all_sdr[instr].append(single_val)
            pbar_dict['sdr_{}'.format(instr)] = sdr_val
            if verbose:
                all_mixtures_path.set_postfix(pbar_dict)

    return all_sdr


class MSSValidationDataset(torch.utils.data.Dataset):
    """Lists paths to mixture.wav files found in the validation folders."""

    def __init__(self, args):
        all_mixtures_path = []
        for valid_path in args.valid_path:
            part = sorted(glob.glob(valid_path + '/*/mixture.wav'))
            if len(part) == 0:
                print('No validation data found in: {}'.format(valid_path))
            all_mixtures_path += part

        self.list_of_files = all_mixtures_path

    def __len__(self):
        return len(self.list_of_files)

    def __getitem__(self, index):
        return self.list_of_files[index]

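
# Expected validation layout (implied by the glob above and the reads in
# valid()): one folder per track containing mixture.wav plus one
# <instrument>.wav file for every name in config.training.instruments, e.g.
#
#   valid_root/
#       song_01/
#           mixture.wav
#           vocals.wav
#           drums.wav
#           bass.wav
#           other.wav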


def train_model(args):
    accelerator = Accelerator()
    device = accelerator.device

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c', help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit")
    parser.add_argument("--config_path", type=str, help="Path to config file")
    parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to start training from")
    parser.add_argument("--results_path", type=str, help="Path to folder where results will be stored (weights, metadata)")
    parser.add_argument("--data_path", nargs="+", type=str, help="Dataset data paths. You can provide several folders.")
    parser.add_argument("--dataset_type", type=int, default=1, help="Dataset type. Must be one of: 1, 2, 3 or 4. Details here: https://github.com/ZFTurbo/Music-Source-Separation-Training/blob/main/docs/dataset_types.md")
    parser.add_argument("--valid_path", nargs="+", type=str, help="Validation data paths. You can provide several folders.")
    parser.add_argument("--num_workers", type=int, default=0, help="Dataloader num_workers")
    parser.add_argument("--pin_memory", action='store_true', help="Use pin_memory in the dataloader")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help="List of GPU ids")
    parser.add_argument("--use_multistft_loss", action='store_true', help="Use MultiResolutionSTFTLoss (from the auraloss package)")
    parser.add_argument("--use_mse_loss", action='store_true', help="Use MSE loss")
    parser.add_argument("--use_l1_loss", action='store_true', help="Use L1 loss")
    parser.add_argument("--wandb_key", type=str, default='', help="wandb API key")
    parser.add_argument("--pre_valid", action='store_true', help="Run validation before training")
    if args is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(args)

    manual_seed(args.seed + int(time.time()))
    # Allow non-deterministic cuDNN kernels for speed
    torch.backends.cudnn.deterministic = False
    torch.multiprocessing.set_start_method('spawn')

    model, config = get_model_from_config(args.model_type, args.config_path)
    accelerator.print("Instruments: {}".format(config.training.instruments))

    os.makedirs(args.results_path, exist_ok=True)

    device_ids = args.device_ids
    batch_size = config.training.batch_size

    # Only the main process logs to wandb; all other processes get a disabled run
    if accelerator.is_main_process and args.wandb_key is not None and args.wandb_key.strip() != '':
        wandb.login(key=args.wandb_key)
        wandb.init(project='msst-accelerate', config={'config': config, 'args': args, 'device_ids': device_ids, 'batch_size': batch_size})
    else:
        wandb.init(mode='disabled')

    # accelerator.prepare() shards the train loader across processes, so scale
    # the step count to keep the per-process epoch length unchanged
    config.training.num_steps *= accelerator.num_processes

    trainset = MSSDataset(
        config,
        args.data_path,
        batch_size=batch_size,
        metadata_path=os.path.join(args.results_path, 'metadata_{}.pkl'.format(args.dataset_type)),
        dataset_type=args.dataset_type,
        verbose=accelerator.is_main_process,
    )

    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory
    )

    validset = MSSValidationDataset(args)
    valid_dataset_length = len(validset)

    # Batch size 1: every validation item is a path to a full-length mixture
    valid_loader = DataLoader(
        validset,
        batch_size=1,
        shuffle=False,
    )

    valid_loader = accelerator.prepare(valid_loader)

    if args.start_check_point != '':
        accelerator.print('Start from checkpoint: {}'.format(args.start_check_point))
        # Tolerant alternative to a strict model.load_state_dict(): weights that
        # do not match the current model are skipped
        load_not_compatible_weights(model, args.start_check_point, verbose=False)

    optim_params = dict()
    if 'optimizer' in config:
        optim_params = dict(config['optimizer'])
        accelerator.print('Optimizer params from config:\n{}'.format(optim_params))

    if config.training.optimizer == 'adam':
        optimizer = Adam(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == 'adamw':
        optimizer = AdamW(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == 'radam':
        optimizer = RAdam(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == 'rmsprop':
        optimizer = RMSprop(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == 'prodigy':
        # Optional dependency: pip install prodigyopt
        from prodigyopt import Prodigy
        optimizer = Prodigy(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == 'adamw8bit':
        # Optional dependency: pip install bitsandbytes
        import bitsandbytes as bnb
        optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.training.lr, **optim_params)
    elif config.training.optimizer == 'sgd':
        accelerator.print('Use SGD optimizer')
        optimizer = SGD(model.parameters(), lr=config.training.lr, **optim_params)
    else:
        accelerator.print('Unknown optimizer: {}'.format(config.training.optimizer))
        exit()

    if accelerator.is_main_process:
        print('Processes (GPUs): {}'.format(accelerator.num_processes))
        print("Patience: {} Reduce factor: {} Batch size: {} Optimizer: {}".format(
            config.training.patience,
            config.training.reduce_factor,
            batch_size,
            config.training.optimizer,
        ))

    # 'max' mode: the scheduler tracks validation SDR, where higher is better
    scheduler = ReduceLROnPlateau(
        optimizer,
        'max',
        patience=config.training.patience,
        factor=config.training.reduce_factor
    )

    if args.use_multistft_loss:
        try:
            loss_options = dict(config.loss_multistft)
        except Exception:
            loss_options = dict()
        accelerator.print('Loss options: {}'.format(loss_options))
        loss_multistft = auraloss.freq.MultiResolutionSTFTLoss(
            **loss_options
        )

    model, optimizer, train_loader, scheduler = accelerator.prepare(model, optimizer, train_loader, scheduler)

    if args.pre_valid:
        sdr_list = valid(model, valid_loader, args, config, device, verbose=accelerator.is_main_process)
        sdr_list = accelerator.gather(sdr_list)
        accelerator.wait_for_everyone()

        sdr_avg = 0.0
        instruments = prefer_target_instrument(config)
        accelerator.print("Valid length: {}".format(valid_dataset_length))

        for instr in instruments:
            sdr_data = torch.cat(sdr_list[instr], dim=0).cpu().numpy()
            # gather() can pad with duplicated tracks to equalize shards, so
            # average only the first valid_dataset_length entries
            sdr_val = sdr_data[:valid_dataset_length].mean()
            accelerator.print("Instr SDR {}: {:.4f} (gathered: {})".format(instr, sdr_val, len(sdr_data)))
            sdr_avg += sdr_val
        sdr_avg /= len(instruments)
        if len(instruments) > 1:
            accelerator.print('SDR Avg: {:.4f}'.format(sdr_avg))
        sdr_list = None

    accelerator.print('Train for: {} epochs'.format(config.training.num_epochs))
    best_sdr = -100
    for epoch in range(config.training.num_epochs):
        model.train().to(device)
        accelerator.print('Train epoch: {} Learning rate: {}'.format(epoch, optimizer.param_groups[0]['lr']))
        loss_val = 0.
        total = 0

        pbar = tqdm(train_loader, disable=not accelerator.is_main_process)
        for i, (batch, mixes) in enumerate(pbar):
            y = batch  # target stems
            x = mixes  # input mixtures

            if args.model_type in ['mel_band_roformer', 'bs_roformer']:
                # These models compute their training loss internally when the
                # target is passed to forward()
                loss = model(x, y)
            else:
                y_ = model(x)
                if args.use_multistft_loss:
                    # Fold dims 1 and 2 together so the STFT loss sees
                    # (batch, channels, time) tensors
                    y1_ = torch.reshape(y_, (y_.shape[0], y_.shape[1] * y_.shape[2], y_.shape[3]))
                    y1 = torch.reshape(y, (y.shape[0], y.shape[1] * y.shape[2], y.shape[3]))
                    loss = loss_multistft(y1_, y1)

                    if args.use_mse_loss:
                        loss += 1000 * nn.MSELoss()(y1_, y1)
                    if args.use_l1_loss:
                        loss += 1000 * F.l1_loss(y1_, y1)
                elif args.use_mse_loss:
                    loss = nn.MSELoss()(y_, y)
                elif args.use_l1_loss:
                    loss = F.l1_loss(y_, y)
                else:
                    loss = masked_loss(
                        y_,
                        y,
                        q=config.training.q,
                        coarse=config.training.coarse_loss_clip
                    )

            accelerator.backward(loss)
            if config.training.grad_clip:
                accelerator.clip_grad_norm_(model.parameters(), config.training.grad_clip)

            optimizer.step()
            optimizer.zero_grad()
            li = loss.item()
            loss_val += li
            total += 1
            if accelerator.is_main_process:
                wandb.log({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1), 'total': total, 'loss_val': loss_val, 'i': i})
                pbar.set_postfix({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1)})

        if accelerator.is_main_process:
            print('Training loss: {:.6f}'.format(loss_val / total))
            wandb.log({'train_loss': loss_val / total, 'epoch': epoch})

        # Always save the latest weights, then run validation
        store_path = args.results_path + '/last_{}.ckpt'.format(args.model_type)
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            unwrapped_model = accelerator.unwrap_model(model)
            accelerator.save(unwrapped_model.state_dict(), store_path)

        sdr_list = valid(model, valid_loader, args, config, device, verbose=accelerator.is_main_process)
        sdr_list = accelerator.gather(sdr_list)
        accelerator.wait_for_everyone()

        sdr_avg = 0.0
        instruments = prefer_target_instrument(config)

        for instr in instruments:
            sdr_data = torch.cat(sdr_list[instr], dim=0).cpu().numpy()
            # Drop any duplicated entries added by gather() before averaging
            sdr_val = sdr_data[:valid_dataset_length].mean()
            if accelerator.is_main_process:
                print("Instr SDR {}: {:.4f} (gathered: {})".format(instr, sdr_val, len(sdr_data)))
                wandb.log({f'{instr}_sdr': sdr_val})
            sdr_avg += sdr_val
        sdr_avg /= len(instruments)
        if accelerator.is_main_process:
            if len(instruments) > 1:
                print('SDR Avg: {:.4f}'.format(sdr_avg))
            wandb.log({'sdr_avg': sdr_avg, 'best_sdr': best_sdr})

        # Keep a separate checkpoint whenever the average SDR improves
        if accelerator.is_main_process:
            if sdr_avg > best_sdr:
                store_path = args.results_path + '/model_{}_ep_{}_sdr_{:.4f}.ckpt'.format(args.model_type, epoch, sdr_avg)
                print('Store weights: {}'.format(store_path))
                unwrapped_model = accelerator.unwrap_model(model)
                accelerator.save(unwrapped_model.state_dict(), store_path)
                best_sdr = sdr_avg

        scheduler.step(sdr_avg)

        sdr_list = None
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    train_model(None)
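
# Example launch (a sketch; the script file name and all paths below are
# placeholders, adjust them to your setup):
#
#   accelerate launch train_accelerate.py \
#       --model_type mdx23c \
#       --config_path configs/your_config.yaml \
#       --results_path results/ \
#       --data_path /path/to/train_data \
#       --valid_path /path/to/valid_data \
#       --num_workers 4 --pre_valid
#
# Multi-GPU runs are configured through `accelerate config` or
# `accelerate launch --num_processes N ...`; --device_ids is only recorded in
# the wandb config and does not control device placement here.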