File size: 2,478 Bytes
a00b67a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import argparse
import csv
import json
import glob

import tqdm
import numpy as np
import librosa
import musdb
import pyloudnorm as pyln

from utils import str2bool, db2linear

# Command-line options for this script. parse_known_args() below returns any
# unrecognised flags separately (discarded into `_`) instead of erroring.
parser = argparse.ArgumentParser(description="model test.py")

# (flag, add_argument keyword options), registered in this order.
for _flag, _options in (
    (
        "--target",
        {
            "type": str,
            "default": "all",
            "help": "target source. all, vocals, bass, drums, other.",
        },
    ),
    ("--root", {"type": str, "default": "/path/to/musdb18hq_loudnorm"}),
    ("--output_directory", {"type": str, "default": "/path/to/results"}),
    ("--exp_name", {"type": str, "default": "convtasnet_6_s"}),
    (
        "--calc_results",
        {
            # str2bool comes from the project's utils module; presumably it
            # converts strings such as "true"/"false" to bool — verify there.
            "type": str2bool,
            "default": True,
            "help": "Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
        },
    ),
):
    parser.add_argument(_flag, **_options)
# Drop the loop variables so the module namespace matches the explicit form.
del _flag, _options

args, _ = parser.parse_known_args()

# Fixed sample rate; a pyloudnorm meter is built for it.
# NOTE(review): `meter` is not referenced in the visible code — confirm before removing.
args.sample_rate = 44100
meter = pyln.Meter(args.sample_rate)

# With calc_results enabled the estimates live under an extra "test/" level;
# otherwise directly under <output_directory>/<exp_name>.
args.test_output_dir = (
    f"{args.output_directory}/test/{args.exp_name}"
    if args.calc_results
    else f"{args.output_directory}/{args.exp_name}"
)


def _load_json(path):
    """Parse the UTF-8 JSON file at *path*, closing the handle afterwards."""
    with open(path, encoding="UTF-8") as fh:
        return json.load(fh)


# Estimated tracks under test_output_dir and their precomputed score features.
# NOTE(review): est_track_list is not used in the visible code.
est_track_list = glob.glob(f"{args.test_output_dir}/*/{args.target}.wav")
dict_song_score_est = _load_json(
    f"{args.test_output_dir}/score_feature_{args.target}.json"
)

# Reference side: the mixture when target == "all", otherwise the named stem,
# each with its matching score-feature JSON under --root.
if args.target == "all":
    ref_track_list = glob.glob(f"{args.root}/*/mixture.wav")
    dict_song_score_ref = _load_json(f"{args.root}/score_feature.json")
else:
    ref_track_list = glob.glob(f"{args.root}/*/{args.target}.wav")
    dict_song_score_ref = _load_json(f"{args.root}/score_feature_{args.target}.json")

# For every reference track, look up the dynamic-complexity score of the
# reference and of the estimate (matched by the track's parent directory
# name) and record estimate minus reference.
list_diff_dynamic_complexity = []
for track in tqdm.tqdm(ref_track_list):
    audio_name = os.path.basename(os.path.dirname(track))
    ref_dyn_complexity = dict_song_score_ref[audio_name]["dynamic_complexity_score"]
    est_dyn_complexity = dict_song_score_est[audio_name]["dynamic_complexity_score"]
    list_diff_dynamic_complexity.append(est_dyn_complexity - ref_dyn_complexity)

if not list_diff_dynamic_complexity:
    # np.mean/median/std of an empty list only warn and yield nan, which hides
    # a wrong --root/--target; stop with an explicit message instead.
    raise SystemExit(
        f"No reference tracks found under {args.root!r} for target {args.target!r}."
    )

# normpath drops a trailing separator so the dataset name is not printed as "".
print(
    f"Dynamic complexity difference {args.exp_name} vs {os.path.basename(os.path.normpath(args.root))} on {args.target}"
)
print("mean: ", np.mean(list_diff_dynamic_complexity))
print("median: ", np.median(list_diff_dynamic_complexity))
print("std: ", np.std(list_diff_dynamic_complexity))