Spaces:
Running
Running
# Calculate the SI-SDR and multi-resolution spectrogram MSE scores of previously inferred sources.
import os | |
import argparse | |
import csv | |
import json | |
import glob | |
import tqdm | |
import numpy as np | |
import librosa | |
import pyloudnorm as pyln | |
from asteroid.metrics import get_metrics | |
from utils import str2bool | |
def multi_resolution_spectrogram_mse(
    gt, est, n_fft=(2048, 1024, 512), n_hop=(512, 256, 128)
):
    """Sum the magnitude-spectrogram MSE of two signals over several STFT resolutions.

    Args:
        gt: Reference signal array, passed to ``librosa.stft``.
        est: Estimated signal array; must have the same shape as ``gt``.
        n_fft: Sequence of FFT sizes, one per resolution.
        n_hop: Sequence of hop lengths, paired element-wise with ``n_fft``.

    Returns:
        The sum, over all resolutions, of the mean squared difference between
        the two magnitude spectrograms (``0.0`` if no resolution is given).

    Raises:
        AssertionError: If ``gt`` and ``est`` differ in shape, or ``n_fft``
            and ``n_hop`` differ in length.
    """
    # The defaults are tuples rather than lists so that the shared default
    # objects can never be mutated between calls.
    assert gt.shape == est.shape
    assert len(n_fft) == len(n_hop)

    score = 0.0
    for fft_size, hop_length in zip(n_fft, n_hop):
        # magphase() returns (magnitude, phase); only the magnitude is compared.
        gt_spec = librosa.magphase(
            librosa.stft(gt, n_fft=fft_size, hop_length=hop_length)
        )[0]
        est_spec = librosa.magphase(
            librosa.stft(est, n_fft=fft_size, hop_length=hop_length)
        )[0]
        score = score + np.mean((gt_spec - est_spec) ** 2)
    return score
# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser(description="model test.py")

# Which stem to score, and the dataset root that is globbed for reference files.
parser.add_argument(
    "--target",
    default="all",
    type=str,
    help="target source. all, vocals, drums, bass, other, 0.5_mixed",
)
parser.add_argument("--root", default="/path/to/musdb18hq_loudnorm", type=str)

# Experiment name and base directory; together they locate the estimated files.
parser.add_argument("--exp_name", default="convtasnet_6_s", type=str)
parser.add_argument("--output_directory", default="/path/to/results", type=str)

# NOTE(review): this option is not read anywhere in the visible part of the
# script -- confirm whether it is still needed.
parser.add_argument("--loudnorm_lufs", default=-14.0, type=float)

# Boolean switches, parsed from text by the project's str2bool helper.
parser.add_argument(
    "--calc_mse",
    default=True,
    type=str2bool,
    help="calculate multi-resolution spectrogram mse",
)
parser.add_argument(
    "--calc_results",
    default=True,
    type=str2bool,
    help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
)

# parse_known_args() tolerates unrecognised options; the leftovers are discarded.
args, _ = parser.parse_known_args()
# All audio is loaded at a fixed 44.1 kHz; the loudness meter is built for
# that same rate.
args.sample_rate = 44100
meter = pyln.Meter(args.sample_rate)

# Estimates are read from "<output_directory>/test/<exp_name>" when scoring
# test-set results, and from "<output_directory>/<exp_name>" otherwise.
args.test_output_dir = (
    f"{args.output_directory}/test/{args.exp_name}"
    if args.calc_results
    else f"{args.output_directory}/{args.exp_name}"
)

# "all" and "0.5_mixed" use each track's mixture.wav as the reference; any
# other target uses the stem file named after the target.
if args.target in ("all", "0.5_mixed"):
    test_tracks = glob.glob(f"{args.root}/*/mixture.wav")
else:
    test_tracks = glob.glob(f"{args.root}/*/{args.target}.wav")
# Per-song scores (written to JSON later) plus flat lists used for averaging.
dict_song_score = {}
list_si_sdr = []
list_multi_mse = []

for track in tqdm.tqdm(test_tracks):
    # The name of the folder containing the reference file identifies the song.
    audio_name = os.path.basename(os.path.dirname(track))
    gt_source = librosa.load(track, sr=args.sample_rate, mono=False)[0]

    # The estimate file is named after the target: "all.wav" for the standard
    # de-limiter estimation, "<target>.wav" for source-separated estimation.
    # A single path expression therefore covers both cases.
    est_delimiter = librosa.load(
        f"{args.test_output_dir}/{audio_name}/{args.target}.wav",
        sr=args.sample_rate,
        mono=False,
    )[0]

    # NOTE(review): the first positional argument (the "mix" in asteroid's
    # signature) is the sum of reference and estimate; presumably it only
    # influences the "input_*" metrics, which are not used here -- confirm.
    metrics_dict = get_metrics(
        gt_source + est_delimiter,
        gt_source,
        est_delimiter,
        sample_rate=args.sample_rate,
        metrics_list=["si_sdr"],
    )

    if args.calc_mse:
        # Cast to a built-in float: depending on the NumPy promotion rules in
        # effect, the score can be a np.float32, which json.dump rejects.
        multi_resolution_spectrogram_mse_score = float(
            multi_resolution_spectrogram_mse(gt_source, est_delimiter)
        )
    else:
        multi_resolution_spectrogram_mse_score = None

    dict_song_score[audio_name] = {
        "si_sdr": float(metrics_dict["si_sdr"]),
        "multi_mse": multi_resolution_spectrogram_mse_score,
    }
    list_si_sdr.append(float(metrics_dict["si_sdr"]))
    list_multi_mse.append(multi_resolution_spectrogram_mse_score)
print(f"{args.exp_name} on {args.target}")

# With no scored tracks the averages below would divide by zero; stop with an
# explanatory message instead of a bare ZeroDivisionError.
if not list_si_sdr:
    raise SystemExit(
        f"No tracks were scored: nothing matched under {args.root!r} "
        f"for target {args.target!r}."
    )

print(f"SI-SDR score: {sum(list_si_sdr) / len(list_si_sdr)}")
if args.calc_mse:
    print(f"multi-mse score: {sum(list_multi_mse) / len(list_multi_mse)}")

# Save dict_song_score to a JSON file: "score.json" for the "all" target,
# "score_<target>.json" for every other target.
score_filename = (
    "score.json" if args.target == "all" else f"score_{args.target}.json"
)
with open(
    f"{args.test_output_dir}/{score_filename}", "w", encoding="utf-8"
) as f:
    json.dump(dict_song_score, f, indent=4)