# ASesYusuf1's picture
# Upload folder using huggingface_hub
# 3978e51
# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
import argparse
import time
import os
import glob
import torch
import librosa
import numpy as np
import soundfile as sf
from tqdm.auto import tqdm
from ml_collections import ConfigDict
from typing import Tuple, Dict, List, Union
from utils import demix, get_model_from_config, prefer_target_instrument, draw_spectrogram
from utils import normalize_audio, denormalize_audio, apply_tta, read_audio_transposed, load_start_checkpoint
from metrics import get_metrics
import warnings
warnings.filterwarnings("ignore")
def logging(logs: List[str], text: str, verbose_logging: bool = False) -> None:
    """
    Print a line of validation output and optionally record it in ``logs``.

    NOTE: this module-level function shadows the name of the stdlib ``logging``
    module inside this file.

    Parameters:
    ----------
    logs : List[str]
        List collecting the logged lines. It is modified in place, and only
        when ``verbose_logging`` is True.
    text : str
        The text to print and, optionally, append to ``logs``.
    verbose_logging : bool, optional
        If True, ``text`` is also appended to ``logs``. Default is False.

    Returns:
    -------
    None
    """
    print(text)
    if verbose_logging:
        logs.append(text)
def write_results_in_file(store_dir: str, logs: List[str]) -> None:
    """
    Write the collected log lines to ``<store_dir>/results.txt``.

    Parameters:
    ----------
    store_dir : str
        Directory in which ``results.txt`` is created (overwritten if it exists).
        The directory is created if it does not exist yet.
    logs : List[str]
        Lines to write; each one is terminated with a newline.

    Returns:
    -------
    None
    """
    # The directory may not exist yet, e.g. when no separated stem was written
    # to it during validation.
    os.makedirs(store_dir, exist_ok=True)
    # Explicit UTF-8 so non-ASCII text (e.g. paths inside the logged args) does
    # not depend on the platform's default locale encoding.
    with open(os.path.join(store_dir, 'results.txt'), 'w', encoding='utf-8') as out:
        out.writelines(item + "\n" for item in logs)
def get_mixture_paths(
    args,
    verbose: bool,
    config: ConfigDict,
    extension: str
) -> List[str]:
    """
    Collect ``<dir>/*/mixture.<extension>`` files from every validation directory.

    Parameters:
    ----------
    args : argparse.Namespace
        Object whose ``valid_path`` attribute holds the validation directory:
        either a single path or a list of paths.
    verbose : bool
        If True, prints directories without data and a summary of the search.
    config : ConfigDict
        Configuration object; ``config.inference.num_overlap`` and
        ``config.inference.batch_size`` are only read for the verbose summary.
    extension : str
        File extension of the mixture files (e.g., 'wav').

    Returns:
    -------
    List[str]
        Mixture file paths, sorted within each directory, with directories
        kept in the order given.

    Raises:
    ------
    AttributeError
        If ``args`` has no ``valid_path`` attribute.
    ValueError
        If ``args.valid_path`` is None (no validation path was supplied).
    """
    try:
        valid_path = args.valid_path
    except AttributeError:
        print('No valid path in args')
        raise
    if valid_path is None:
        print('No valid path in args')
        raise ValueError('args.valid_path is not set; pass --valid_path')
    # A bare path string would otherwise be iterated character by character.
    if isinstance(valid_path, (str, os.PathLike)):
        valid_path = [valid_path]

    all_mixtures_path = []
    for path in valid_path:
        part = sorted(glob.glob(f"{path}/*/mixture.{extension}"))
        if not part and verbose:
            print(f'No validation data found in: {path}')
        all_mixtures_path.extend(part)

    if verbose:
        print(f'Total mixtures: {len(all_mixtures_path)}')
        print(f'Overlap: {config.inference.num_overlap} Batch size: {config.inference.batch_size}')
    return all_mixtures_path
def update_metrics_and_pbar(
    track_metrics: Dict,
    all_metrics: Dict,
    instr: str,
    pbar_dict: Dict,
    mixture_paths: Union[List[str], tqdm],
    verbose: bool = False
) -> None:
    """
    Record one track's metric values and mirror them onto the progress bar.

    Parameters:
    ----------
    track_metrics : Dict
        Mapping of metric name to the value computed for this track.
    all_metrics : Dict
        Accumulator of the form ``{metric_name: {instrument: [values]}}``;
        each value is appended to ``all_metrics[metric_name][instr]``.
    instr : str
        Instrument the metrics belong to.
    pbar_dict : Dict
        Progress-bar postfix mapping, updated in place with keys
        ``"<metric_name>_<instr>"``.
    mixture_paths : Union[List[str], tqdm]
        The iterable being processed. If it has a ``set_postfix`` method (a tqdm
        progress bar), its postfix is refreshed with ``pbar_dict``; plain lists
        and None are left untouched.
    verbose : bool, optional
        If True, prints each metric value to the console. Default is False.

    Returns:
    -------
    None
    """
    for metric_name, metric_value in track_metrics.items():
        if verbose:
            print(f"Metric {metric_name:11s} value: {metric_value:.4f}")
        all_metrics[metric_name][instr].append(metric_value)
        pbar_dict[f'{metric_name}_{instr}'] = metric_value

    # A plain list of paths is passed when the progress bar is disabled, so
    # check for the method instead of swallowing every possible exception.
    set_postfix = getattr(mixture_paths, 'set_postfix', None)
    if callable(set_postfix):
        set_postfix(pbar_dict)
def process_audio_files(
    mixture_paths: List[str],
    model: torch.nn.Module,
    args,
    config,
    device: torch.device,
    verbose: bool = False,
    is_tqdm: bool = True
) -> Dict[str, Dict[str, List[float]]]:
    """
    Separate each mixture with the model and score the estimates against the reference stems.

    For every ``<song>/mixture.<ext>`` path the mixture is read and, when its
    sample rate differs from ``config.audio.sample_rate``, resampled for the
    model. If ``config.inference.normalize`` is True, it is normalized. It is
    then separated with ``demix`` (plus ``apply_tta`` when ``args.use_tta`` is
    set). Each estimate is mapped back to the file's sample rate and length,
    and de-normalized if needed. It is then compared with
    ``<song>/<instr>.<ext>`` using ``get_metrics``. Missing reference stems are
    skipped. When ``config.training.other_fix`` is set (anything other than
    ``False``), the 'other' reference is the mixture minus the vocals stem.

    Parameters:
    ----------
    mixture_paths : List[str]
        List of file paths to the audio mixtures.
    model : torch.nn.Module
        The trained model used for source separation.
    args : Any
        Argument object; reads ``metrics``, ``model_type``, ``draw_spectro`` and
        optionally ``use_tta``, ``store_dir``, ``extension``.
    config : Any
        Configuration object containing model and processing parameters.
    device : torch.device
        Device (CPU or CUDA) on which the model will be executed.
    verbose : bool, optional
        If True, prints detailed logs for each processed file. Default is False.
    is_tqdm : bool, optional
        If True, displays a progress bar for file processing. Default is True.

    Returns:
    -------
    Dict[str, Dict[str, List[float]]]
        A nested dictionary where the outer keys are metric names,
        the inner keys are instrument names, and the values are lists of metric scores.

    Side effects:
    -------
    When ``args.store_dir`` is non-empty, the directory is created. Each estimate
    is written there as ``<song>_<instr>.wav`` (float WAV). If
    ``args.draw_spectro > 0``, spectrograms of the estimate and the reference
    are written as ``.jpg`` files.
    """
    instruments = prefer_target_instrument(config)
    use_tta = getattr(args, 'use_tta', False)
    # dir to save files, if empty no saving
    store_dir = getattr(args, 'store_dir', '')
    # extension of the mixture / reference stem files
    if 'extension' in config['inference']:
        extension = config['inference']['extension']
    else:
        extension = getattr(args, 'extension', 'wav')

    # Config switches that are identical for every file: evaluate them once.
    target_sr = config.audio['sample_rate'] if 'sample_rate' in config.audio else None
    do_normalize = 'normalize' in config.inference and config.inference['normalize'] is True
    # Configs that do not define `other_fix` behave as if it were False instead
    # of raising AttributeError when the 'other' stem is reached.
    other_fix = getattr(config.training, 'other_fix', False)
    if store_dir:
        os.makedirs(store_dir, exist_ok=True)

    # Initialize metrics dictionary
    all_metrics = {
        metric: {instr: [] for instr in config.training.instruments}
        for metric in args.metrics
    }

    if is_tqdm:
        mixture_paths = tqdm(mixture_paths)

    for path in mixture_paths:
        start_time = time.time()
        mix, sr = read_audio_transposed(path)
        mix_orig = mix.copy()
        folder = os.path.dirname(path)
        song_name = os.path.basename(folder)

        # The model runs at target_sr; estimates are converted back further below.
        needs_resample = target_sr is not None and sr != target_sr
        if needs_resample:
            orig_length = mix.shape[-1]
            if verbose:
                print(f'Warning: sample rate is different. In config: {config.audio["sample_rate"]} in file {path}: {sr}')
            mix = librosa.resample(mix, orig_sr=sr, target_sr=target_sr, res_type='kaiser_best')

        if verbose:
            folder_name = os.path.abspath(folder)
            print(f'Song: {folder_name} Shape: {mix.shape}')

        if do_normalize:
            mix, norm_params = normalize_audio(mix)

        waveforms_orig = demix(config, model, mix.copy(), device, model_type=args.model_type)
        if use_tta:
            waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)

        pbar_dict = {}
        for instr in instruments:
            if verbose:
                print(f"Instr: {instr}")

            if instr != 'other' or other_fix is False:
                track, _ = read_audio_transposed(f"{folder}/{instr}.{extension}", instr, skip_err=True)
                if track is None:
                    continue
            else:
                # if track=vocal+other
                track, _ = read_audio_transposed(f"{folder}/vocals.{extension}")
                track = mix_orig - track

            estimates = waveforms_orig[instr]
            if needs_resample:
                estimates = librosa.resample(estimates, orig_sr=target_sr, target_sr=sr,
                                             res_type='kaiser_best')
                # Resampling can be off by a few samples; match the mixture length.
                estimates = librosa.util.fix_length(estimates, size=orig_length)
            if do_normalize:
                estimates = denormalize_audio(estimates, norm_params)

            if store_dir:
                out_wav_name = f"{store_dir}/{song_name}_{instr}.wav"
                sf.write(out_wav_name, estimates.T, sr, subtype='FLOAT')
                if args.draw_spectro > 0:
                    out_img_name = f"{store_dir}/{song_name}_{instr}.jpg"
                    draw_spectrogram(estimates.T, sr, args.draw_spectro, out_img_name)
                    out_img_name_orig = f"{store_dir}/{song_name}_{instr}_orig.jpg"
                    draw_spectrogram(track.T, sr, args.draw_spectro, out_img_name_orig)

            track_metrics = get_metrics(
                args.metrics,
                track,
                estimates,
                mix_orig,
                device=device,
            )

            update_metrics_and_pbar(
                track_metrics,
                all_metrics,
                instr, pbar_dict,
                mixture_paths=mixture_paths,
                verbose=verbose
            )

        if verbose:
            print(f"Time for song: {time.time() - start_time:.2f} sec")

    return all_metrics
def compute_metric_avg(
    store_dir: str,
    args,
    instruments: List[str],
    config: ConfigDict,
    all_metrics: Dict[str, Dict[str, List[float]]],
    start_time: float
) -> Dict[str, float]:
    """
    Summarise collected metrics per instrument and averaged over instruments.

    For every instrument and metric, the mean and standard deviation of the
    collected values are printed. The per-instrument means are then averaged
    over ``instruments``. When ``store_dir`` is non-empty, every printed line
    (preceded by ``str(args)``) is also written to ``<store_dir>/results.txt``.

    Parameters:
    ----------
    store_dir : str
        Directory for ``results.txt``; an empty string disables writing.
    args : argparse.Namespace
        Run arguments; their string form heads the stored log.
    instruments : List[str]
        Instruments included in the summary.
    config : ConfigDict
        Configuration; ``config.inference.num_overlap`` is reported.
    all_metrics : Dict[str, Dict[str, List[float]]]
        Collected scores as ``{metric_name: {instrument_name: [values]}}``.
    start_time : float
        ``time.time()`` value marking the start of validation.

    Returns:
    -------
    Dict[str, float]
        For each metric, the mean over instruments of the per-instrument means.
    """
    keep_logs = bool(store_dir)
    logs = [str(args)] if keep_logs else []

    def emit(message: str) -> None:
        # Print the line; also collect it when results are going to be stored.
        logging(logs, text=message, verbose_logging=keep_logs)

    emit(f"Num overlap: {config.inference.num_overlap}")

    metric_avg = {}
    for instr in instruments:
        for metric_name, per_instrument in all_metrics.items():
            values = np.array(per_instrument[instr])
            mean_val = values.mean()
            std_val = values.std()
            emit(f"Instr {instr} {metric_name}: {mean_val:.4f} (Std: {std_val:.4f})")
            metric_avg[metric_name] = metric_avg.get(metric_name, 0.0) + mean_val

    instrument_count = len(instruments)
    for metric_name in all_metrics:
        metric_avg[metric_name] /= instrument_count

    if instrument_count > 1:
        for metric_name, avg_value in metric_avg.items():
            emit(f'Metric avg {metric_name:11s}: {avg_value:.4f}')
    emit(f"Elapsed time: {time.time() - start_time:.2f} sec")

    if keep_logs:
        write_results_in_file(store_dir, logs)

    return metric_avg
def valid(
    model: torch.nn.Module,
    args,
    config: ConfigDict,
    device: torch.device,
    verbose: bool = False
) -> Tuple[dict, dict]:
    """
    Validate a trained model on the validation mixtures using a single device.

    The model is put in eval mode and moved to ``device``. Every mixture found
    under ``args.valid_path`` is separated and scored, and the results are
    summarised (and stored when ``args.store_dir`` is set).

    Parameters:
    ----------
    model : torch.nn.Module
        The trained model for source separation.
    args : Namespace
        Command-line arguments or equivalent object containing configurations.
    config : ConfigDict
        Configuration with model and processing parameters.
    device : torch.device
        The device (CPU or CUDA) to run the model on.
    verbose : bool, optional
        If True, prints per-file details instead of showing a progress bar.
        Default is False.

    Returns:
    -------
    Tuple[dict, dict]
        ``(metric_avg, all_metrics)``: the per-metric average over instruments,
        and the raw ``{metric: {instrument: [scores]}}`` collection.
    """
    start_time = time.time()
    model.eval().to(device)

    # Empty string means nothing is saved.
    store_dir = getattr(args, 'store_dir', '')
    # File extension of the inputs: the config value takes precedence over args.
    inference_cfg = config['inference']
    extension = inference_cfg['extension'] if 'extension' in inference_cfg else getattr(args, 'extension', 'wav')

    mixture_paths = get_mixture_paths(args, verbose, config, extension)
    # The progress bar is only shown when per-file verbose output is off.
    all_metrics = process_audio_files(
        mixture_paths, model, args, config, device, verbose=verbose, is_tqdm=not verbose
    )
    instruments = prefer_target_instrument(config)
    metric_avg = compute_metric_avg(store_dir, args, instruments, config, all_metrics, start_time)
    return metric_avg, all_metrics
def validate_in_subprocess(
    proc_id: int,
    queue: torch.multiprocessing.Queue,
    all_mixtures_path: List[str],
    model: torch.nn.Module,
    args,
    config: ConfigDict,
    device: str,
    return_dict
) -> None:
    """
    Worker loop for multi-process validation.

    The worker takes ``(index, path)`` items from ``queue`` until it receives a
    ``(None, None)`` sentinel. It scores each mixture with
    ``process_audio_files``, then stores its accumulated metrics in
    ``return_dict[proc_id]``.

    Parameters:
    ----------
    proc_id : int
        Worker index; used as the key in ``return_dict``. Worker 0 also shows a
        progress bar.
    queue : torch.multiprocessing.Queue
        Queue of ``(global_index, mixture_path)`` items, terminated by a
        ``(None, None)`` sentinel.
    all_mixtures_path : List[str]
        All mixture paths; only its length is used, as the progress-bar total.
    model : torch.nn.Module
        The model to be used for inference.
    args : Namespace
        Arguments; ``args.metrics`` selects the metrics to compute.
    config : ConfigDict
        Configuration object containing model settings and training parameters.
    device : str
        The device to use for inference (e.g., 'cpu', 'cuda:0').
    return_dict : Manager().dict proxy
        Shared dictionary receiving ``{metric: {instrument: [scores]}}`` under
        key ``proc_id``.

    Returns:
    -------
    None
    """
    m1 = model.eval().to(device)
    # The bar is approximate: it advances to the global index of the items this
    # worker processes, not to the number of items finished by all workers.
    progress_bar = tqdm(total=len(all_mixtures_path)) if proc_id == 0 else None

    # Initialize metrics dictionary
    all_metrics = {
        metric: {instr: [] for instr in config.training.instruments}
        for metric in args.metrics
    }

    try:
        while True:
            current_step, path = queue.get()
            if path is None:  # sentinel: no more work
                break
            single_metrics = process_audio_files([path], m1, args, config, device, False, False)
            pbar_dict = {}
            for instr in config.training.instruments:
                for metric_name in all_metrics:
                    values = single_metrics[metric_name][instr]
                    all_metrics[metric_name][instr] += values
                    if values:
                        pbar_dict[f"{metric_name}_{instr}"] = f"{values[0]:.4f}"
            if progress_bar is not None:
                # current_step is a 0-based index, so the item just finished is
                # number current_step + 1; this lets the bar reach its total.
                progress_bar.update(current_step + 1 - progress_bar.n)
                progress_bar.set_postfix(pbar_dict)
    finally:
        if progress_bar is not None:
            progress_bar.close()

    return_dict[proc_id] = all_metrics
def run_parallel_validation(
    verbose: bool,
    all_mixtures_path: List[str],
    config: ConfigDict,
    model: torch.nn.Module,
    device_ids: List[int],
    args,
    return_dict
) -> None:
    """
    Run validation in one worker process per device and wait for all of them.

    The model is moved to the CPU (and unwrapped from a ``.module`` wrapper if
    present) before it is handed to the workers. One worker is started per entry
    of ``device_ids`` (``cuda:<id>`` when CUDA is available, otherwise 'cpu').
    Every mixture path is queued with its index, followed by one ``(None, None)``
    sentinel per worker, and the call blocks until all workers have exited.

    Parameters:
    ----------
    verbose : bool
        Currently unused; kept for interface compatibility.
    all_mixtures_path : List[str]
        List of paths to the mixture files to be processed.
    config : ConfigDict
        Configuration object containing model settings and validation parameters.
    model : torch.nn.Module
        The model to be used for inference.
    device_ids : List[int]
        List of device IDs (for multi-GPU setups) to use for validation.
    args : Namespace
        Arguments forwarded to the workers (e.g. metrics to calculate).
    return_dict : Manager().dict proxy
        Shared dictionary that the workers fill, keyed by worker index
        ``0 .. len(device_ids) - 1``.

    Returns:
    -------
    None
        Results are delivered through ``return_dict``.
    """
    model = model.to('cpu')
    # For multi-GPU training wrappers (e.g. nn.DataParallel) use the wrapped
    # single model; plain modules have no `.module` and are used as-is.
    model = getattr(model, 'module', model)

    queue = torch.multiprocessing.Queue()
    use_cuda = torch.cuda.is_available()
    processes = []
    for proc_id, device_id in enumerate(device_ids):
        device = f'cuda:{device_id}' if use_cuda else 'cpu'
        p = torch.multiprocessing.Process(
            target=validate_in_subprocess,
            args=(proc_id, queue, all_mixtures_path, model, args, config, device, return_dict)
        )
        p.start()
        processes.append(p)

    # Items carry their global index so worker 0 can drive its progress bar.
    for step, path in enumerate(all_mixtures_path):
        queue.put((step, path))
    # One sentinel per worker so that every worker leaves its loop.
    for _ in processes:
        queue.put((None, None))

    for p in processes:
        p.join()  # wait for all subprocesses to finish
def valid_multi_gpu(
    model: torch.nn.Module,
    args,
    config: ConfigDict,
    device_ids: List[int],
    verbose: bool = False
) -> Tuple[Dict[str, float], dict]:
    """
    Validate across several devices with one worker process per device.

    The per-worker results are concatenated (in worker order) per metric and
    instrument, and then summarised with ``compute_metric_avg``.

    Parameters:
    ----------
    model : torch.nn.Module
        The model to be used for inference.
    args : Namespace
        Arguments, e.g. ``store_dir``, ``extension``, ``metrics``, ``valid_path``.
    config : ConfigDict
        Configuration object containing model settings and validation parameters.
    device_ids : List[int]
        List of device IDs (for multi-GPU setups) to use for validation.
    verbose : bool, optional
        Flag to print detailed information about the validation process. Default is False.

    Returns:
    -------
    Tuple[Dict[str, float], dict]
        ``(metric_avg, all_metrics)``: the per-metric average over instruments,
        and the raw ``{metric: {instrument: [scores]}}`` collection.

    Raises:
    ------
    RuntimeError
        If a worker process did not report results (for example because it crashed).
    """
    start_time = time.time()
    # dir to save files, if empty no saving
    store_dir = getattr(args, 'store_dir', '')
    # extension of the mixture files
    if 'extension' in config['inference']:
        extension = config['inference']['extension']
    else:
        extension = getattr(args, 'extension', 'wav')

    all_mixtures_path = get_mixture_paths(args, verbose, config, extension)

    # The manager hosts the shared dict in a helper process. Using it as a
    # context manager shuts that process down once the results have been
    # copied into a plain dict.
    with torch.multiprocessing.Manager() as manager:
        return_dict = manager.dict()
        run_parallel_validation(verbose, all_mixtures_path, config, model, device_ids, args, return_dict)
        worker_results = return_dict.copy()

    missing = [i for i in range(len(device_ids)) if i not in worker_results]
    if missing:
        raise RuntimeError(
            f'No validation results from worker(s) {missing}; check the worker output above for errors.'
        )

    all_metrics = {}
    for metric in args.metrics:
        all_metrics[metric] = {}
        for instr in config.training.instruments:
            combined = []
            for i in range(len(device_ids)):
                combined.extend(worker_results[i][metric][instr])
            all_metrics[metric][instr] = combined

    instruments = prefer_target_instrument(config)
    return compute_metric_avg(store_dir, args, instruments, config, all_metrics, start_time), all_metrics
def parse_args(dict_args: Union[Dict, None]) -> argparse.Namespace:
    """
    Parse arguments for configuring the model, validation data and metrics.

    Args:
        dict_args: Dict of argument overrides. If None, arguments are parsed from
            sys.argv. Otherwise the parser defaults are taken and updated with
            ``dict_args``; keys unknown to the parser are kept as extra attributes.

    Returns:
        Namespace object containing parsed arguments and their values.
        ``device_ids`` is always a list of ints.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c',
                        help="One of mdx23c, htdemucs, segm_models, mel_band_roformer,"
                             " bs_roformer, swin_upernet, bandit")
    parser.add_argument("--config_path", type=str, help="Path to config file")
    parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint"
                                                                          " to valid weights")
    parser.add_argument("--valid_path", nargs="+", type=str, help="Validate path")
    parser.add_argument("--store_dir", type=str, default="", help="Path to store results as wav file")
    parser.add_argument("--draw_spectro", type=float, default=0,
                        help="If --store_dir is set then code will generate spectrograms for resulted stems as well."
                             " Value defines for how many seconds of track spectrogram will be generated.")
    # With nargs='+' a parsed value is a list, so the default must be a list too:
    # callers index it (device_ids[0]) and call len() on it.
    parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help='List of gpu ids')
    parser.add_argument("--num_workers", type=int, default=0, help="Dataloader num_workers")
    parser.add_argument("--pin_memory", action='store_true', help="Dataloader pin_memory")
    parser.add_argument("--extension", type=str, default='wav', help="Choose extension for validation")
    parser.add_argument("--use_tta", action='store_true',
                        help="Flag adds test time augmentation during inference (polarity and channel inverse)."
                             " While this triples the runtime, it reduces noise and slightly improves prediction quality.")
    parser.add_argument("--metrics", nargs='+', type=str, default=["sdr"],
                        choices=['sdr', 'l1_freq', 'si_sdr', 'neg_log_wmse', 'aura_stft', 'aura_mrstft', 'bleedless',
                                 'fullness'], help='List of metrics to use.')
    parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights")

    if dict_args is not None:
        args = parser.parse_args([])
        args_dict = vars(args)
        args_dict.update(dict_args)
        args = argparse.Namespace(**args_dict)
    else:
        args = parser.parse_args()

    # dict_args may carry a single int; normalise so callers can index it.
    if isinstance(args.device_ids, int):
        args.device_ids = [args.device_ids]
    return args
def check_validation(dict_args):
    """
    Entry point: build the model from its config, optionally load weights, and validate it.

    Parameters:
    ----------
    dict_args : Dict or None
        Argument overrides as a dict; if None, arguments are parsed from sys.argv.

    Returns:
    -------
    Tuple[Dict[str, float], dict]
        ``(metric_avg, all_metrics)`` as returned by ``valid`` or ``valid_multi_gpu``.
    """
    args = parse_args(dict_args)
    torch.backends.cudnn.benchmark = True
    try:
        torch.multiprocessing.set_start_method('spawn')
    except RuntimeError:
        # The start method can only be set once per process; keep the existing one.
        pass

    model, config = get_model_from_config(args.model_type, args.config_path)
    if args.start_check_point:
        load_start_checkpoint(args, model, type_='valid')
    print(f"Instruments: {config.training.instruments}")

    device_ids = args.device_ids
    # device_ids may arrive as a single int (e.g. through dict_args); it is indexed below.
    if isinstance(device_ids, int):
        device_ids = [device_ids]

    if torch.cuda.is_available():
        device = torch.device(f'cuda:{device_ids[0]}')
    else:
        device = 'cpu'
        print('CUDA is not available. Run validation on CPU. It will be very slow...')

    if torch.cuda.is_available() and len(device_ids) > 1:
        return valid_multi_gpu(model, args, config, device_ids, verbose=False)
    return valid(model, args, config, device, verbose=True)
if __name__ == "__main__":
    # Script mode: every option comes from the command line.
    check_validation(dict_args=None)