File size: 4,250 Bytes
a00b67a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Calculate the SI-SDR and multi-resolution spectrogram MSE scores of pre-inferred sources
import os
import argparse
import csv
import json
import glob

import tqdm
import numpy as np
import librosa
import pyloudnorm as pyln
from asteroid.metrics import get_metrics

from utils import str2bool


def multi_resolution_spectrogram_mse(
    gt, est, n_fft=(2048, 1024, 512), n_hop=(512, 256, 128)
):
    """Sum the magnitude-spectrogram MSE of two signals over several STFT resolutions.

    For each (FFT size, hop length) pair, both signals are transformed with
    ``librosa.stft``, their magnitudes are extracted with ``librosa.magphase``,
    and the mean squared difference of the magnitudes is added to the total.

    Args:
        gt: Reference signal array; must have the same shape as ``est``.
        est: Estimated signal array.
        n_fft: Sequence of FFT sizes, one per resolution.
        n_hop: Sequence of hop lengths, paired element-wise with ``n_fft``.

    Returns:
        The sum over all resolutions of the mean squared magnitude error
        (``0.0`` when no resolutions are given).

    Raises:
        AssertionError: If ``gt`` and ``est`` differ in shape, or if ``n_fft``
            and ``n_hop`` differ in length.
    """
    assert gt.shape == est.shape
    assert len(n_fft) == len(n_hop)

    score = 0.0
    # Defaults are tuples (immutable) so they cannot be altered between calls.
    for fft_size, hop_length in zip(n_fft, n_hop):
        gt_mag, _ = librosa.magphase(
            librosa.stft(gt, n_fft=fft_size, hop_length=hop_length)
        )
        est_mag, _ = librosa.magphase(
            librosa.stft(est, n_fft=fft_size, hop_length=hop_length)
        )
        score = score + np.mean((gt_mag - est_mag) ** 2)

    return score


# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser(description="model test.py")

# (flag, add_argument options) pairs, registered in the order listed here.
_cli_options = (
    (
        "--target",
        {
            "type": str,
            "default": "all",
            "help": "target source. all, vocals, drums, bass, other, 0.5_mixed",
        },
    ),
    ("--root", {"type": str, "default": "/path/to/musdb18hq_loudnorm"}),
    ("--exp_name", {"type": str, "default": "convtasnet_6_s"}),
    ("--output_directory", {"type": str, "default": "/path/to/results"}),
    ("--loudnorm_lufs", {"type": float, "default": -14.0}),
    (
        "--calc_mse",
        {
            "type": str2bool,
            "default": True,
            "help": "calculate multi-resolution spectrogram mse",
        },
    ),
    (
        "--calc_results",
        {
            "type": str2bool,
            "default": True,
            "help": "Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
        },
    ),
)
for _flag, _options in _cli_options:
    parser.add_argument(_flag, **_options)

# parse_known_args returns (namespace, leftover argv); unrecognized flags are
# collected in the leftover list instead of aborting the run.
args, _ = parser.parse_known_args()

# Every file is loaded with this sample rate passed to librosa.load.
args.sample_rate = 44100

# NOTE(review): `meter` (and --loudnorm_lufs) are not used anywhere below in
# the visible code; kept so the module-level name stays available.
meter = pyln.Meter(args.sample_rate)

# Estimates are read from <output_directory>/test/<exp_name> when scoring
# test-set results, otherwise from <output_directory>/<exp_name>.
if args.calc_results:
    args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
else:
    args.test_output_dir = f"{args.output_directory}/{args.exp_name}"

# "all" and "0.5_mixed" are scored against each track's mixture.wav; any other
# target is scored against the file named after that target.
reference_name = (
    "mixture" if args.target in ("all", "0.5_mixed") else args.target
)
# Sorted so the processing order (and the key order of the JSON report) is
# deterministic; glob itself returns paths in arbitrary order.
test_tracks = sorted(glob.glob(f"{args.root}/*/{reference_name}.wav"))
if not test_tracks:
    # Fail with a clear message instead of a ZeroDivisionError when averaging.
    raise SystemExit(
        f"No reference tracks matched '{args.root}/*/{reference_name}.wav'; "
        "nothing to evaluate."
    )

dict_song_score = {}
list_si_sdr = []
list_multi_mse = []
for track in tqdm.tqdm(test_tracks):
    # The track name is the directory that contains the reference file.
    audio_name = os.path.basename(os.path.dirname(track))
    gt_source = librosa.load(track, sr=args.sample_rate, mono=False)[0]
    # The estimate is stored as <target>.wav ("all.wav" for the default
    # target), so a single path expression covers every target.
    est_delimiter = librosa.load(
        f"{args.test_output_dir}/{audio_name}/{args.target}.wav",
        sr=args.sample_rate,
        mono=False,
    )[0]

    # NOTE(review): the first positional argument is the sum of reference and
    # estimate -- presumably a stand-in for the mixture input of get_metrics;
    # confirm against the asteroid documentation.
    metrics_dict = get_metrics(
        gt_source + est_delimiter,
        gt_source,
        est_delimiter,
        sample_rate=args.sample_rate,
        metrics_list=["si_sdr"],
    )
    # Cast to built-in float: json.dump cannot serialize NumPy scalar types
    # that do not subclass float (e.g. float32).
    si_sdr = float(metrics_dict["si_sdr"])

    if args.calc_mse:
        multi_resolution_spectrogram_mse_score = float(
            multi_resolution_spectrogram_mse(gt_source, est_delimiter)
        )
    else:
        multi_resolution_spectrogram_mse_score = None

    dict_song_score[audio_name] = {
        "si_sdr": si_sdr,
        "multi_mse": multi_resolution_spectrogram_mse_score,
    }
    list_si_sdr.append(si_sdr)
    list_multi_mse.append(multi_resolution_spectrogram_mse_score)

print(f"{args.exp_name} on {args.target}")
print(f"SI-SDR score: {sum(list_si_sdr) / len(list_si_sdr)}")
if args.calc_mse:
    print(f"multi-mse score: {sum(list_multi_mse) / len(list_multi_mse)}")

# Save the per-song scores as JSON: "score.json" for the default target,
# "score_<target>.json" for any other target.
score_filename = (
    "score.json" if args.target == "all" else f"score_{args.target}.json"
)
with open(f"{args.test_output_dir}/{score_filename}", "w") as f:
    json.dump(dict_song_score, f, indent=4)