Spaces:
Running
Running
File size: 4,250 Bytes
a00b67a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
# Calculate the SI-SDR and multi-resolution spectrogram MSE scores of previously inferred sources.
import os
import argparse
import csv
import json
import glob
import tqdm
import numpy as np
import librosa
import pyloudnorm as pyln
from asteroid.metrics import get_metrics
from utils import str2bool
def multi_resolution_spectrogram_mse(
    gt, est, n_fft=(2048, 1024, 512), n_hop=(512, 256, 128)
):
    """Return the spectrogram MSE between two signals, summed over STFT resolutions.

    For every (FFT size, hop length) pair, the magnitude STFT of ``gt`` and
    ``est`` is computed with librosa, and the mean squared difference between
    the two magnitude spectrograms is added to the running score.

    Args:
        gt: Reference audio array.
        est: Estimated audio array; must have the same shape as ``gt``.
        n_fft: FFT sizes, one per resolution.
        n_hop: Hop lengths, paired element-wise with ``n_fft``.

    Returns:
        The sum over all resolutions of the per-resolution mean squared
        error (``0.0`` when no resolutions are given).

    Raises:
        AssertionError: If ``gt`` and ``est`` differ in shape, or if
            ``n_fft`` and ``n_hop`` differ in length.
    """
    assert gt.shape == est.shape, "gt and est must have the same shape"
    assert len(n_fft) == len(n_hop), "n_fft and n_hop must have the same length"

    score = 0.0
    # Walk the resolutions pairwise; the defaults are tuples so they can never
    # be mutated between calls.
    for fft_size, hop_length in zip(n_fft, n_hop):
        gt_spec = librosa.magphase(
            librosa.stft(gt, n_fft=fft_size, hop_length=hop_length)
        )[0]
        est_spec = librosa.magphase(
            librosa.stft(est, n_fft=fft_size, hop_length=hop_length)
        )[0]
        score = score + np.mean((gt_spec - est_spec) ** 2)
    return score
# Command-line interface. Each entry pairs an option flag with the keyword
# arguments handed to ``add_argument``; options are registered in this order.
_CLI_OPTIONS = (
    (
        "--target",
        {
            "type": str,
            "default": "all",
            "help": "target source. all, vocals, drums, bass, other, 0.5_mixed",
        },
    ),
    ("--root", {"type": str, "default": "/path/to/musdb18hq_loudnorm"}),
    ("--exp_name", {"type": str, "default": "convtasnet_6_s"}),
    ("--output_directory", {"type": str, "default": "/path/to/results"}),
    ("--loudnorm_lufs", {"type": float, "default": -14.0}),
    (
        "--calc_mse",
        {
            "type": str2bool,
            "default": True,
            "help": "calculate multi-resolution spectrogram mse",
        },
    ),
    (
        "--calc_results",
        {
            "type": str2bool,
            "default": True,
            "help": "Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
        },
    ),
)

parser = argparse.ArgumentParser(description="model test.py")
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Unrecognised command-line arguments are tolerated and discarded.
args, _ = parser.parse_known_args()
# Sample rate used for the loudness meter and for every audio file loaded below.
args.sample_rate = 44100
meter = pyln.Meter(args.sample_rate)

# With --calc_results the estimates are read from a "test" sub-directory of the
# output directory; otherwise from the experiment directory directly under it.
args.test_output_dir = (
    f"{args.output_directory}/test/{args.exp_name}"
    if args.calc_results
    else f"{args.output_directory}/{args.exp_name}"
)

# "all" and "0.5_mixed" are scored against each song's mixture file; any other
# target is scored against the stem file that carries the target's name.
reference_name = "mixture" if args.target in ("all", "0.5_mixed") else args.target
test_tracks = glob.glob(f"{args.root}/*/{reference_name}.wav")
# Per-song results, plus flat score lists used for the averages printed later.
dict_song_score = {}
list_si_sdr = []
list_multi_mse = []

for track in tqdm.tqdm(test_tracks):
    # Every reference file sits in a per-song directory; that directory name
    # identifies the song.
    audio_name = os.path.basename(os.path.dirname(track))
    gt_source = librosa.load(track, sr=args.sample_rate, mono=False)[0]
    # The estimate file is always named after the target ("all.wav" for the
    # standard de-limiter, "<target>.wav" for source-separated estimation),
    # so a single path expression covers both cases.
    est_delimiter = librosa.load(
        f"{args.test_output_dir}/{audio_name}/{args.target}.wav",
        sr=args.sample_rate,
        mono=False,
    )[0]

    metrics_dict = get_metrics(
        gt_source + est_delimiter,
        gt_source,
        est_delimiter,
        sample_rate=args.sample_rate,
        metrics_list=["si_sdr"],
    )
    # Store built-in floats: NumPy scalar types such as float32 are rejected
    # by json.dump when the scores are written out.
    si_sdr_score = float(metrics_dict["si_sdr"])

    if args.calc_mse:
        multi_resolution_spectrogram_mse_score = float(
            multi_resolution_spectrogram_mse(gt_source, est_delimiter)
        )
    else:
        multi_resolution_spectrogram_mse_score = None

    dict_song_score[audio_name] = {
        "si_sdr": si_sdr_score,
        "multi_mse": multi_resolution_spectrogram_mse_score,
    }
    list_si_sdr.append(si_sdr_score)
    list_multi_mse.append(multi_resolution_spectrogram_mse_score)
# Print the average scores and save the per-song scores as JSON.
print(f"{args.exp_name} on {args.target}")

# Without any scored track the averages below would divide by zero; stop with
# an explicit message instead.
if not list_si_sdr:
    raise SystemExit(
        f"No tracks were scored for target '{args.target}' under {args.root}."
    )

print(f"SI-SDR score: {sum(list_si_sdr) / len(list_si_sdr)}")
if args.calc_mse:
    print(f"multi-mse score: {sum(list_multi_mse) / len(list_multi_mse)}")

# save dict_song_score to json file; the "all" target uses the plain file
# name, every other target gets its name appended.
score_file = "score.json" if args.target == "all" else f"score_{args.target}.json"
with open(f"{args.test_output_dir}/{score_file}", "w", encoding="utf-8") as f:
    json.dump(dict_song_score, f, indent=4)
|