__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
__version__ = '1.0.4'

import random
import argparse
import os
import warnings
from typing import List, Tuple, Dict, Union, Callable, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW, SGD, RAdam, RMSprop
from torch.utils.data import DataLoader
from torch.cuda.amp.grad_scaler import GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm
from ml_collections import ConfigDict
import wandb
import auraloss
import loralib as lora

from dataset import MSSDataset
from utils import get_model_from_config, bind_lora_to_model, load_start_checkpoint
from valid import valid_multi_gpu, valid

warnings.filterwarnings("ignore")


def parse_args(dict_args: Union[Dict, None]) -> argparse.Namespace:
    """
    Parse command-line arguments for configuring the model, dataset, and training parameters.

    Args:
        dict_args: Dict of command-line arguments. If None, arguments will be parsed from sys.argv.

    Returns:
        Namespace object containing parsed arguments and their values.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c',
                        help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit")
    parser.add_argument("--config_path", type=str, help="Path to config file")
    parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to start training from")
    parser.add_argument("--results_path", type=str,
                        help="Path to folder where results will be stored (weights, metadata)")
    parser.add_argument("--data_path", nargs="+", type=str, help="Dataset data paths. You can provide several folders.")
    parser.add_argument("--dataset_type", type=int, default=1,
                        help="Dataset type. Must be one of: 1, 2, 3 or 4. Details here: https://github.com/ZFTurbo/Music-Source-Separation-Training/blob/main/docs/dataset_types.md")
    parser.add_argument("--valid_path", nargs="+", type=str,
                        help="Validation data paths. You can provide several folders.")
    parser.add_argument("--num_workers", type=int, default=0, help="Dataloader num_workers")
    parser.add_argument("--pin_memory", action='store_true', help="Dataloader pin_memory")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help="List of GPU ids")
    parser.add_argument("--loss", type=str, nargs='+', choices=['masked_loss', 'mse_loss', 'l1_loss', 'multistft_loss'],
                        default=['masked_loss'], help="List of loss functions to use")
    parser.add_argument("--wandb_key", type=str, default='', help="wandb API key")
    parser.add_argument("--pre_valid", action='store_true', help="Run validation before training")
    parser.add_argument("--metrics", nargs='+', type=str, default=["sdr"],
                        choices=['sdr', 'l1_freq', 'si_sdr', 'log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
                                 'fullness'], help="List of metrics to use.")
    parser.add_argument("--metric_for_scheduler", default="sdr",
                        choices=['sdr', 'l1_freq', 'si_sdr', 'log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
                                 'fullness'], help="Metric which will be used for the scheduler.")
    parser.add_argument("--train_lora", action='store_true', help="Train with LoRA")
    parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint with LoRA weights")

    if dict_args is not None:
        args = parser.parse_args([])
        args_dict = vars(args)
        args_dict.update(dict_args)
        args = argparse.Namespace(**args_dict)
    else:
        args = parser.parse_args()

    if args.metric_for_scheduler not in args.metrics:
        args.metrics += [args.metric_for_scheduler]

    return args
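
# Behavior sketch of the dict path above: defaults are first materialised with parse_args([])
# and then overwritten by dict_args, so any key present in the dict (even one that is not a
# declared argument) ends up as an attribute on the returned Namespace. For example,
# parse_args({'seed': 42}) yields a Namespace with every default plus seed=42.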


def manual_seed(seed: int) -> None:
    """
    Set the random seed for reproducibility across Python, NumPy, and PyTorch.

    Args:
        seed: The seed value to set.
    """

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    os.environ["PYTHONHASHSEED"] = str(seed)


def initialize_environment(seed: int, results_path: str) -> None:
    """
    Initialize the environment by setting the random seed, configuring PyTorch settings,
    and creating the results directory.

    Args:
        seed: The seed value for reproducibility.
        results_path: Path to the directory where results will be stored.
    """

    manual_seed(seed)
    torch.backends.cudnn.deterministic = False
    try:
        torch.multiprocessing.set_start_method('spawn')
    except Exception:
        pass
    os.makedirs(results_path, exist_ok=True)


def wandb_init(args: argparse.Namespace, config: Dict, device_ids: List[int], batch_size: int) -> None:
    """
    Initialize the Weights & Biases (wandb) logging system.

    Args:
        args: Parsed command-line arguments containing the wandb key.
        config: Configuration dictionary for the experiment.
        device_ids: List of GPU device IDs used for training.
        batch_size: Batch size for training.
    """

    if args.wandb_key is None or args.wandb_key.strip() == '':
        wandb.init(mode='disabled')
    else:
        wandb.login(key=args.wandb_key)
        wandb.init(project='msst', config={'config': config, 'args': args, 'device_ids': device_ids, 'batch_size': batch_size})


def prepare_data(config: Dict, args: argparse.Namespace, batch_size: int) -> DataLoader:
    """
    Prepare the training dataset and data loader.

    Args:
        config: Configuration dictionary for the dataset.
        args: Parsed command-line arguments containing dataset paths and settings.
        batch_size: Batch size for training.

    Returns:
        DataLoader object for the training dataset.
    """

    trainset = MSSDataset(
        config,
        args.data_path,
        batch_size=batch_size,
        metadata_path=os.path.join(args.results_path, f'metadata_{args.dataset_type}.pkl'),
        dataset_type=args.dataset_type,
    )

    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory
    )
    return train_loader
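
# The loader built above is consumed by train_one_epoch(), which unpacks each element as
# (batch, mixes): the target stems first, the mixture second. The metadata_path argument
# places the dataset's metadata pickle at results_path/metadata_<dataset_type>.pkl.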


def initialize_model_and_device(model: torch.nn.Module, device_ids: List[int]) -> Tuple[Union[torch.device, str], torch.nn.Module]:
    """
    Initialize the model and assign it to the appropriate device (GPU or CPU).

    Args:
        model: The PyTorch model to be initialized.
        device_ids: List of GPU device IDs to use for parallel processing.

    Returns:
        A tuple containing the device and the model moved to that device.
    """

    if torch.cuda.is_available():
        device = torch.device(f'cuda:{device_ids[0]}')
        if len(device_ids) <= 1:
            model = model.to(device)
        else:
            model = nn.DataParallel(model, device_ids=device_ids).to(device)
    else:
        device = 'cpu'
        model = model.to(device)
        print("CUDA is not available. Running on CPU.")

    return device, model


def get_optimizer(config: ConfigDict, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Initializes an optimizer based on the configuration.

    Args:
        config: Configuration object containing training parameters.
        model: PyTorch model whose parameters will be optimized.

    Returns:
        A PyTorch optimizer object configured based on the specified settings.
    """

    optim_params = dict()
    if 'optimizer' in config:
        optim_params = dict(config['optimizer'])
        print(f'Optimizer params from config:\n{optim_params}')

    name_optimizer = getattr(config.training, 'optimizer', 'No optimizer in config')

    if name_optimizer == 'adam':
        optimizer = Adam(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'adamw':
        optimizer = AdamW(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'radam':
        optimizer = RAdam(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'rmsprop':
        optimizer = RMSprop(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'prodigy':
        from prodigyopt import Prodigy
        optimizer = Prodigy(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'adamw8bit':
        import bitsandbytes as bnb
        optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.training.lr, **optim_params)
    elif name_optimizer == 'sgd':
        print('Use SGD optimizer')
        optimizer = SGD(model.parameters(), lr=config.training.lr, **optim_params)
    else:
        print(f'Unknown optimizer: {name_optimizer}')
        exit()
    return optimizer
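
# A minimal sketch of the config fields get_optimizer() reads (the YAML layout below is an
# illustration, not a shipped config; real configs live under the repository's configs/ folder):
#
#   training:
#     lr: 1.0e-4
#     optimizer: adamw
#   optimizer:          # optional section; extra kwargs passed straight to the optimizer
#     weight_decay: 0.01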


def multistft_loss(y: torch.Tensor, y_: torch.Tensor,
                   loss_multistft: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> torch.Tensor:
    """
    Apply a multi-resolution STFT loss. 4D inputs of shape (B, S, C, T) are flattened to
    (B, S * C, T) before the loss is computed; 3D inputs are passed through unchanged.
    """
    if len(y_.shape) == 4:
        y1_ = y_.reshape(y_.shape[0], y_.shape[1] * y_.shape[2], y_.shape[3])
        y1 = y.reshape(y.shape[0], y.shape[1] * y.shape[2], y.shape[3])
    elif len(y_.shape) == 3:
        y1_, y1 = y_, y
    else:
        raise ValueError(f"Invalid shape for predicted array: {y_.shape}. Expected 3 or 4 dimensions.")
    return loss_multistft(y1_, y1)


def masked_loss(y_: torch.Tensor, y: torch.Tensor, q: float, coarse: bool = True) -> torch.Tensor:
    """
    Element-wise MSE loss where, per source, values above the q-quantile are zeroed out
    before averaging, which downweights outlier (hard-to-fit) segments.
    """
    loss = torch.nn.MSELoss(reduction='none')(y_, y).transpose(0, 1)
    if coarse:
        loss = loss.mean(dim=(-1, -2))
    loss = loss.reshape(loss.shape[0], -1)
    quantile = torch.quantile(loss.detach(), q, interpolation='linear', dim=1, keepdim=True)
    mask = loss < quantile
    return (loss * mask).mean()
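
# A minimal usage sketch for masked_loss (shapes and q are illustrative only; in training,
# q and coarse come from config['training']):
#
#   y_pred = torch.randn(4, 2, 2, 44100)   # (batch, sources, channels, time)
#   y_true = torch.randn(4, 2, 2, 44100)
#   loss = masked_loss(y_pred, y_true, q=0.95, coarse=True)
#
# With coarse=True the masking decision is made per (source, batch item) chunk rather than
# per individual element, so whole outlier segments are dropped at once.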


def choice_loss(args: argparse.Namespace, config: ConfigDict) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Select and return the appropriate loss function based on the configuration and arguments.

    Args:
        args: Parsed command-line arguments containing flags for different loss functions.
        config: Configuration object containing loss settings and parameters.

    Returns:
        A loss function that can be applied to the predicted and ground truth tensors.
    """

    print(f'Losses for training: {args.loss}')
    loss_fns = []
    if 'masked_loss' in args.loss:
        loss_fns.append(
            lambda y_, y: masked_loss(y_, y, q=config['training']['q'], coarse=config['training']['coarse_loss_clip']))
    if 'mse_loss' in args.loss:
        loss_fns.append(nn.MSELoss())
    if 'l1_loss' in args.loss:
        loss_fns.append(F.l1_loss)
    if 'multistft_loss' in args.loss:
        loss_options = dict(config.get('loss_multistft', {}))
        loss_multistft = auraloss.freq.MultiResolutionSTFTLoss(**loss_options)
        loss_fns.append(lambda y_, y: multistft_loss(y_, y, loss_multistft) / 1000)

    def multi_loss(y_, y):
        return sum(loss_fn(y_, y) for loss_fn in loss_fns)

    return multi_loss
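
# Example of what the combined loss evaluates to: with --loss mse_loss multistft_loss, the
# returned multi_loss(y_, y) is MSELoss(y_, y) + MultiResolutionSTFTLoss(y_, y) / 1000
# (the 1/1000 factor presumably keeps the STFT term on a scale comparable to the waveform losses).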


def normalize_batch(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Normalize a batch by subtracting the mean and dividing by the standard deviation of x.

    Args:
        x: Tensor whose statistics are used for normalization (the mixture).
        y: Tensor normalized with the same statistics (the targets).

    Returns:
        A tuple of normalized tensors (x, y).
    """

    mean = x.mean()
    std = x.std()
    if std != 0:
        x = (x - mean) / std
        y = (y - mean) / std
    return x, y


def train_one_epoch(model: torch.nn.Module, config: ConfigDict, args: argparse.Namespace, optimizer: torch.optim.Optimizer,
                    device: torch.device, device_ids: List[int], epoch: int, use_amp: bool, scaler: torch.cuda.amp.GradScaler,
                    gradient_accumulation_steps: int, train_loader: torch.utils.data.DataLoader,
                    multi_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> None:
    """
    Train the model for one epoch.

    Args:
        model: The model to train.
        config: Configuration object containing training parameters.
        args: Command-line arguments with specific settings (e.g., model type).
        optimizer: Optimizer used for training.
        device: Device to run the model on (CPU or GPU).
        device_ids: List of GPU device IDs if using multiple GPUs.
        epoch: The current epoch number.
        use_amp: Whether to use automatic mixed precision (AMP) for training.
        scaler: Scaler for AMP to manage gradient scaling.
        gradient_accumulation_steps: Number of gradient accumulation steps before updating the optimizer.
        train_loader: DataLoader for the training dataset.
        multi_loss: The loss function to use during training.

    Returns:
        None
    """

    model.train().to(device)
    print(f'Train epoch: {epoch} Learning rate: {optimizer.param_groups[0]["lr"]}')
    loss_val = 0.
    total = 0

    normalize = getattr(config.training, 'normalize', False)

    pbar = tqdm(train_loader)
    for i, (batch, mixes) in enumerate(pbar):
        x = mixes.to(device)
        y = batch.to(device)

        if normalize:
            x, y = normalize_batch(x, y)

        with torch.cuda.amp.autocast(enabled=use_amp):
            if args.model_type in ['mel_band_roformer', 'bs_roformer']:
                loss = model(x, y)
                if isinstance(device_ids, (list, tuple)):
                    loss = loss.mean()
            else:
                y_ = model(x)
                loss = multi_loss(y_, y)

        loss /= gradient_accumulation_steps
        scaler.scale(loss).backward()
        if config.training.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)

        if ((i + 1) % gradient_accumulation_steps == 0) or (i == len(train_loader) - 1):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        li = loss.item() * gradient_accumulation_steps
        loss_val += li
        total += 1
        pbar.set_postfix({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1)})
        wandb.log({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1), 'i': i})
        loss.detach()

    print(f'Training loss: {loss_val / total}')
    wandb.log({'train_loss': loss_val / total, 'epoch': epoch, 'learning_rate': optimizer.param_groups[0]['lr']})
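
# Note on the AMP/clipping interaction above: clip_grad_norm_ runs on gradients that are still
# multiplied by the GradScaler's scale factor. PyTorch's AMP recipe calls scaler.unscale_(optimizer)
# first so the threshold applies to true gradient norms. A possible variant (a sketch, not this
# script's current behavior) would clip only at the accumulation boundary:
#
#   if ((i + 1) % gradient_accumulation_steps == 0) or (i == len(train_loader) - 1):
#       if config.training.grad_clip:
#           scaler.unscale_(optimizer)
#           nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)
#       scaler.step(optimizer)
#       scaler.update()
#       optimizer.zero_grad(set_to_none=True)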


def save_weights(store_path, model, device_ids, train_lora):
    """
    Save model weights to store_path: only the LoRA state dict when train_lora is set,
    otherwise the full (or DataParallel module) state dict.
    """

    if train_lora:
        torch.save(lora.lora_state_dict(model), store_path)
    else:
        state_dict = model.state_dict() if len(device_ids) <= 1 else model.module.state_dict()
        torch.save(state_dict, store_path)


def save_last_weights(args: argparse.Namespace, model: torch.nn.Module, device_ids: List[int]) -> None:
    """
    Save the model's state_dict to a file for later use.

    Args:
        args: Command-line arguments containing the results path and model type.
        model: The model whose weights will be saved.
        device_ids: List of GPU device IDs if using multiple GPUs.

    Returns:
        None
    """

    store_path = f'{args.results_path}/last_{args.model_type}.ckpt'
    train_lora = args.train_lora
    save_weights(store_path, model, device_ids, train_lora)


def compute_epoch_metrics(model: torch.nn.Module, args: argparse.Namespace, config: ConfigDict,
                          device: torch.device, device_ids: List[int], best_metric: float,
                          epoch: int, scheduler: torch.optim.lr_scheduler._LRScheduler) -> float:
    """
    Compute and log the metrics for the current epoch, and save model weights if the metric improves.

    Args:
        model: The model to evaluate.
        args: Command-line arguments containing configuration paths and other settings.
        config: Configuration dictionary containing training settings.
        device: The device (CPU or GPU) used for evaluation.
        device_ids: List of GPU device IDs when using multiple GPUs.
        best_metric: The best metric value seen so far.
        epoch: The current epoch number.
        scheduler: The learning rate scheduler to adjust the learning rate.

    Returns:
        The updated best_metric.
    """

    if torch.cuda.is_available() and len(device_ids) > 1:
        metrics_avg, all_metrics = valid_multi_gpu(model, args, config, args.device_ids, verbose=False)
    else:
        metrics_avg, all_metrics = valid(model, args, config, device, verbose=False)
    metric_avg = metrics_avg[args.metric_for_scheduler]
    if metric_avg > best_metric:
        store_path = f'{args.results_path}/model_{args.model_type}_ep_{epoch}_{args.metric_for_scheduler}_{metric_avg:.4f}.ckpt'
        print(f'Store weights: {store_path}')
        train_lora = args.train_lora
        save_weights(store_path, model, device_ids, train_lora)
        best_metric = metric_avg
    scheduler.step(metric_avg)
    wandb.log({'metric_main': metric_avg, 'best_metric': best_metric})
    for metric_name in metrics_avg:
        wandb.log({f'metric_{metric_name}': metrics_avg[metric_name]})

    return best_metric


def train_model(args: argparse.Namespace) -> None:
    """
    Trains the model based on the provided arguments, including data preparation, optimizer setup,
    and loss calculation. The model is trained for multiple epochs with logging via wandb.

    Args:
        args: Command-line arguments containing configuration paths, hyperparameters, and other settings.

    Returns:
        None
    """

    args = parse_args(args)

    initialize_environment(args.seed, args.results_path)
    model, config = get_model_from_config(args.model_type, args.config_path)
    use_amp = getattr(config.training, 'use_amp', True)
    device_ids = args.device_ids
    batch_size = config.training.batch_size * len(device_ids)

    wandb_init(args, config, device_ids, batch_size)

    train_loader = prepare_data(config, args, batch_size)

    if args.start_check_point:
        load_start_checkpoint(args, model, type_='train')

    if args.train_lora:
        model = bind_lora_to_model(config, model)
        lora.mark_only_lora_as_trainable(model)

    device, model = initialize_model_and_device(model, args.device_ids)

    if args.pre_valid:
        if torch.cuda.is_available() and len(device_ids) > 1:
            valid_multi_gpu(model, args, config, args.device_ids, verbose=True)
        else:
            valid(model, args, config, device, verbose=True)

    optimizer = get_optimizer(config, model)
    gradient_accumulation_steps = int(getattr(config.training, 'gradient_accumulation_steps', 1))

    scheduler = ReduceLROnPlateau(optimizer, 'max', patience=config.training.patience,
                                  factor=config.training.reduce_factor)

    multi_loss = choice_loss(args, config)
    scaler = GradScaler()
    best_metric = float('-inf')

    print(
        f"Instruments: {config.training.instruments}\n"
        f"Metrics for training: {args.metrics}. Metric for scheduler: {args.metric_for_scheduler}\n"
        f"Patience: {config.training.patience} "
        f"Reduce factor: {config.training.reduce_factor}\n"
        f"Batch size: {batch_size} "
        f"Grad accum steps: {gradient_accumulation_steps} "
        f"Effective batch size: {batch_size * gradient_accumulation_steps}\n"
        f"Dataset type: {args.dataset_type}\n"
        f"Optimizer: {config.training.optimizer}"
    )

    print(f'Train for: {config.training.num_epochs} epochs')

    for epoch in range(config.training.num_epochs):
        train_one_epoch(model, config, args, optimizer, device, device_ids, epoch,
                        use_amp, scaler, gradient_accumulation_steps, train_loader, multi_loss)
        save_last_weights(args, model, device_ids)
        best_metric = compute_epoch_metrics(model, args, config, device, device_ids, best_metric, epoch, scheduler)


if __name__ == "__main__":
    train_model(None)
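
# Programmatic usage sketch from another script (paths and config file name are placeholders,
# not repository defaults):
#
#   from train import train_model
#
#   train_model({
#       'model_type': 'mdx23c',
#       'config_path': 'configs/config_musdb18_mdx23c.yaml',
#       'results_path': './results',
#       'data_path': ['./datasets/train'],
#       'valid_path': ['./datasets/valid'],
#       'device_ids': [0],
#   })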