# De-limiter / eval_delimit / score_calc_delimit.py
# jeonchangbin49's picture
# first commit
# a00b67a
# Calculate the SI-SDR and multi-resolution spectrogram MSE scores of pre-inferred sources
import os
import argparse
import csv
import json
import glob
import tqdm
import numpy as np
import librosa
import pyloudnorm as pyln
from asteroid.metrics import get_metrics
from utils import str2bool
def multi_resolution_spectrogram_mse(
    gt, est, n_fft=(2048, 1024, 512), n_hop=(512, 256, 128)
):
    """Sum the magnitude-spectrogram MSE of two signals over several STFT resolutions.

    Args:
        gt: Reference signal; passed unchanged to ``librosa.stft``.
        est: Estimated signal; must have exactly the same shape as ``gt``.
        n_fft: FFT sizes, one entry per resolution.
        n_hop: Hop lengths, paired element-wise with ``n_fft``.

    Returns:
        float: The sum, over all resolutions, of the mean squared difference
        between the two magnitude spectrograms (``0.0`` when no resolution
        is given).

    Raises:
        AssertionError: If ``gt`` and ``est`` differ in shape, or ``n_fft``
            and ``n_hop`` differ in length.
    """
    assert gt.shape == est.shape, "gt and est must have the same shape"
    assert len(n_fft) == len(n_hop), "n_fft and n_hop must have the same length"

    score = 0.0
    for fft_size, hop_size in zip(n_fft, n_hop):
        # magphase returns (magnitude, phase); only the magnitude is compared.
        gt_spec = librosa.magphase(
            librosa.stft(gt, n_fft=fft_size, hop_length=hop_size)
        )[0]
        est_spec = librosa.magphase(
            librosa.stft(est, n_fft=fft_size, hop_length=hop_size)
        )[0]
        # np.mean returns a NumPy scalar; convert it so the accumulated score
        # stays a built-in float, which json can always serialize.
        score += float(np.mean((gt_spec - est_spec) ** 2))
    return score
# Command-line interface of the scoring script.
parser = argparse.ArgumentParser(description="model test.py")

# Which stem to score; selects both the reference and the estimate file names.
parser.add_argument(
    "--target",
    type=str,
    default="all",
    help="target source. all, vocals, drums, bass, other, 0.5_mixed",
)

# Directory that is globbed for one sub-directory per reference track.
parser.add_argument("--root", type=str, default="/path/to/musdb18hq_loudnorm")

# Experiment name and results root; together they locate the estimated files.
parser.add_argument("--exp_name", type=str, default="convtasnet_6_s")
parser.add_argument("--output_directory", type=str, default="/path/to/results")

parser.add_argument("--loudnorm_lufs", type=float, default=-14.0)

# Boolean switches, parsed from text by the project's str2bool helper.
parser.add_argument(
    "--calc_mse",
    type=str2bool,
    default=True,
    help="calculate multi-resolution spectrogram mse",
)
parser.add_argument(
    "--calc_results",
    type=str2bool,
    default=True,
    help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
)

# parse_known_args ignores unrecognised options instead of exiting on them.
args, _ = parser.parse_known_args()
# Fixed sample rate, stored on args so later code reads it from one place.
args.sample_rate = 44100
# NOTE(review): the loudness meter is not used in the visible part of the script.
meter = pyln.Meter(args.sample_rate)

# Estimates are read from <output_directory>/test/<exp_name> when scoring
# test-set results, and from <output_directory>/<exp_name> otherwise.
args.test_output_dir = (
    f"{args.output_directory}/test/{args.exp_name}"
    if args.calc_results
    else f"{args.output_directory}/{args.exp_name}"
)

# "all" and "0.5_mixed" are scored against mixture.wav; any other target is
# scored against <target>.wav.
reference_name = "mixture" if args.target in ("all", "0.5_mixed") else args.target
test_tracks = glob.glob(f"{args.root}/*/{reference_name}.wav")

# Accumulators filled while iterating over the tracks.
i = 0
dict_song_score = {}
list_si_sdr = []
list_multi_mse = []
# Score every reference track against its pre-computed estimate.
for track in tqdm.tqdm(test_tracks):
    # The song name is the directory that contains the reference file.
    audio_name = os.path.basename(os.path.dirname(track))

    # Standard de-limiter estimates are stored as all.wav and source-separated
    # ones as <target>.wav, so a single pattern covers both cases.
    est_path = f"{args.test_output_dir}/{audio_name}/{args.target}.wav"

    # librosa.load returns (audio, sample_rate); only the audio is kept.
    gt_source = librosa.load(track, sr=args.sample_rate, mono=False)[0]
    est_delimiter = librosa.load(est_path, sr=args.sample_rate, mono=False)[0]

    # The sum of reference and estimate is passed as the "mixture" argument.
    metrics_dict = get_metrics(
        gt_source + est_delimiter,
        gt_source,
        est_delimiter,
        sample_rate=args.sample_rate,
        metrics_list=["si_sdr"],
    )

    # The spectrogram MSE is optional; None marks it as not computed.
    multi_resolution_spectrogram_mse_score = (
        multi_resolution_spectrogram_mse(gt_source, est_delimiter)
        if args.calc_mse
        else None
    )

    dict_song_score[audio_name] = {
        "si_sdr": metrics_dict["si_sdr"],
        "multi_mse": multi_resolution_spectrogram_mse_score,
    }
    list_si_sdr.append(metrics_dict["si_sdr"])
    list_multi_mse.append(multi_resolution_spectrogram_mse_score)
    i += 1
# list_si_sdr gets one entry per scored track. If it is empty, no reference
# file was matched, so stop with a clear message instead of dividing by zero.
if not list_si_sdr:
    raise SystemExit(
        f"No tracks were scored for target '{args.target}'; "
        f"check that --root ({args.root}) contains the expected files."
    )

# Print the averages over all scored tracks.
print(f"{args.exp_name} on {args.target}")
print(f"SI-SDR score: {sum(list_si_sdr) / len(list_si_sdr)}")
if args.calc_mse:
    print(f"multi-mse score: {sum(list_multi_mse) / len(list_multi_mse)}")

# Save the per-song scores: score.json for "all", score_<target>.json otherwise.
score_filename = (
    "score.json" if args.target == "all" else f"score_{args.target}.json"
)
with open(f"{args.test_output_dir}/{score_filename}", "w") as f:
    json.dump(dict_song_score, f, indent=4)