Spaces:
Running
Running
Commit
·
a00b67a
1
Parent(s):
da27cbe
first commit
Browse files- .gitattributes +0 -34
- LICENSE +21 -0
- README.md +2 -13
- add.py +293 -0
- configs/delimit_6_s.yaml +92 -0
- dataloader/__init__.py +8 -0
- dataloader/dataset.py +579 -0
- dataloader/delimit_dataset.py +573 -0
- dataloader/singleset.py +95 -0
- eval_delimit/calc_flops.py +44 -0
- eval_delimit/score_calc_delimit.py +145 -0
- eval_delimit/score_diff_dyn_complexity.py +87 -0
- eval_delimit/score_fad.py +75 -0
- eval_delimit/score_features.py +233 -0
- eval_delimit/score_peaq.py +77 -0
- eval_delimit/score_peaq_aggregate.py +88 -0
- inference.py +165 -0
- main_ddp.py +49 -0
- models/__init__.py +1 -0
- models/base_models.py +239 -0
- models/load_models.py +87 -0
- prepro/delimit_save_delimiter_stems.py +93 -0
- prepro/delimit_save_musdb_loudnorm.py +118 -0
- prepro/delimit_train_ozone_prepro.py +293 -0
- prepro/delimit_valid_L_prepro.py +41 -0
- prepro/delimit_valid_custom_limiter_prepro.py +59 -0
- prepro/delimit_valid_prepro.py +41 -0
- requirements.txt +13 -0
- separate_func/__init__.py +1 -0
- separate_func/conv_tasnet_separate.py +89 -0
- solver_ddp.py +643 -0
- test_ddp.py +245 -0
- train_ddp.py +56 -0
- utils/__init__.py +19 -0
- utils/logging.py +79 -0
- utils/loudness_utils.py +71 -0
- utils/lr_scheduler.py +80 -0
- utils/read_wave_utils.py +109 -0
- utils/train_utils.py +27 -0
- weight/all.json +957 -0
- weight/all.pth +3 -0
.gitattributes
CHANGED
@@ -1,35 +1 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 jeonchangbin49
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,13 +1,2 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
emoji: 🏃
|
4 |
-
colorFrom: pink
|
5 |
-
colorTo: indigo
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 3.39.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: mit
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
# De-limiter
|
2 |
+
An official demo of "Music De-limiter Networks via Sample-wise Gain Inversion", which will be presented in WASPAA 2023.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
add.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import torch
|
8 |
+
import tqdm
|
9 |
+
import librosa
|
10 |
+
import librosa.display
|
11 |
+
import soundfile as sf
|
12 |
+
import pyloudnorm as pyln
|
13 |
+
from dotmap import DotMap
|
14 |
+
import gradio as gr
|
15 |
+
|
16 |
+
from models import load_model_with_args
|
17 |
+
from separate_func import (
|
18 |
+
conv_tasnet_separate,
|
19 |
+
)
|
20 |
+
from utils import db2linear
|
21 |
+
|
22 |
+
|
23 |
+
tqdm.monitor_interval = 0
|
24 |
+
|
25 |
+
|
26 |
+
def separate_track_with_model(
    args, model, device, track_audio, track_name, meter, augmented_gain
):
    """Run the model on one track and return its estimate.

    Dispatches on ``args.model_loss_params.architecture``. Only the
    Conv-TasNet variants ("conv_tasnet_mask_on_output" and "conv_tasnet")
    are handled; both go through ``conv_tasnet_separate``.

    Args:
        args: Configuration object exposing ``model_loss_params.architecture``.
        model: Model forwarded to ``conv_tasnet_separate``.
        device: Device forwarded to ``conv_tasnet_separate``.
        track_audio: Input audio forwarded to ``conv_tasnet_separate``.
        track_name: Track name forwarded to ``conv_tasnet_separate``.
        meter: Loudness meter, forwarded as the ``meter`` keyword.
        augmented_gain: Gain value (or None), forwarded as ``augmented_gain``.

    Returns:
        Whatever ``conv_tasnet_separate`` returns.

    Raises:
        ValueError: If the configured architecture is not supported.
            Previously this case fell through to ``return estimates`` and
            failed with an opaque ``UnboundLocalError``.
    """
    architecture = args.model_loss_params.architecture
    if architecture not in ("conv_tasnet_mask_on_output", "conv_tasnet"):
        raise ValueError(f"Unsupported architecture: {architecture!r}")

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        estimates = conv_tasnet_separate(
            args,
            model,
            device,
            track_audio,
            track_name,
            meter=meter,
            augmented_gain=augmented_gain,
        )

    return estimates
|
45 |
+
|
46 |
+
|
47 |
+
def _str2bool(value):
    """Parse a command-line boolean for argparse.

    ``type=bool`` is a trap: ``bool("False")`` is ``True`` because every
    non-empty string is truthy. The empty string still maps to ``False`` so
    the one input that used to yield ``False`` keeps doing so.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def main(input, mix_coefficient):
    """Gradio callback: de-limit an uploaded track.

    Args:
        input: ``(sample_rate, audio)`` pair. ``audio`` is transposed before
            being fed to the model -- presumably (samples, channels) as
            delivered by ``gr.Audio``; TODO confirm.
        mix_coefficient: Weight of the input in the parallel mix; the
            estimate gets ``1 - mix_coefficient``.

    Returns:
        Three ``(sample_rate, audio)`` pairs: the model estimate, the input
        audio, and their parallel mix.

    Raises:
        ValueError: If the sample rate is not 44100.
    """
    sr, track_audio = input
    # Validate before the expensive config/model loading below.
    if sr != 44100:
        raise ValueError("Sample rate should be 44100")

    parser = argparse.ArgumentParser(description="model test.py")
    parser.add_argument("--target", type=str, default="all")
    parser.add_argument("--weight_directory", type=str, default="weight")
    parser.add_argument("--output_directory", type=str, default="output")
    parser.add_argument("--use_gpu", type=_str2bool, default=True)
    parser.add_argument("--save_name_as_target", type=_str2bool, default=False)
    parser.add_argument(
        "--loudnorm_input_lufs",
        type=float,
        default=None,
        help="If you want to use loudnorm for input",
    )
    parser.add_argument(
        "--save_output_loudnorm",
        type=float,
        default=-14.0,
        help="Save loudness normalized outputs or not. If you want to save, input target loudness",
    )
    parser.add_argument(
        "--save_mixed_output",
        type=float,
        default=None,
        help="Save original+delimited-estimation mixed output with a ratio of default 0.5 (orginal) and 1 - 0.5 (estimation)",
    )
    parser.add_argument(
        "--save_16k_mono",
        type=_str2bool,
        default=False,
        help="Save 16k mono wav files for FAD evaluation.",
    )
    parser.add_argument(
        "--save_histogram",
        type=_str2bool,
        default=False,
        help="Save histogram of the output. Only valid when the task is 'delimit'",
    )
    parser.add_argument(
        "--use_singletrackset",
        type=_str2bool,
        default=False,
        help="Use SingleTrackSet if input data is too long.",
    )

    # parse_known_args: tolerate unrelated arguments on the launcher's argv.
    args, _ = parser.parse_known_args()

    with open(f"{args.weight_directory}/{args.target}.json", "r") as f:
        args_dict = DotMap(json.load(f))

    # Fill in training-time settings, but never override a CLI-defined key.
    cli_values = vars(args)
    for key, value in args_dict["args"].items():
        if key not in cli_values:
            setattr(args, key, value)

    args.test_output_dir = f"{args.output_directory}"
    os.makedirs(args.test_output_dir, exist_ok=True)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
    )

    ###################### Define Models ######################
    our_model = load_model_with_args(args)
    our_model = our_model.to(device)

    target_model_path = f"{args.weight_directory}/{args.target}.pth"
    checkpoint = torch.load(target_model_path, map_location=device)
    our_model.load_state_dict(checkpoint)

    our_model.eval()

    meter = pyln.Meter(44100)

    track_audio = track_audio.T
    track_name = "gradio_demo"

    orig_audio = track_audio.copy()

    augmented_gain = None

    # `is not None` so that an explicit 0.0 LUFS target is honoured.
    if args.loudnorm_input_lufs is not None:
        track_lufs = meter.integrated_loudness(track_audio.T)
        augmented_gain = args.loudnorm_input_lufs - track_lufs
        track_audio = track_audio * db2linear(augmented_gain, eps=0.0)

    track_audio = (
        torch.as_tensor(track_audio, dtype=torch.float32).unsqueeze(0).to(device)
    )

    estimates = separate_track_with_model(
        args, our_model, device, track_audio, track_name, meter, augmented_gain
    )

    # NOTE(review): orig_audio is loudness-normalized only inside this branch,
    # yet the UI labels the second return value "Loudness Normalized Input".
    # Confirm whether normalization should happen unconditionally.
    if args.save_mixed_output:
        track_lufs = meter.integrated_loudness(orig_audio.T)
        augmented_gain = args.save_output_loudnorm - track_lufs
        orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)

        mixed_output = orig_audio * args.save_mixed_output + estimates * (
            1 - args.save_mixed_output
        )

        # Make sure the per-track folder exists before writing into it.
        os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
        sf.write(
            f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
            mixed_output.T,
            args.data_params.sample_rate,
        )

    return (
        (sr, estimates.T),
        (sr, orig_audio.T),
        (sr, orig_audio.T * mix_coefficient + estimates.T * (1 - mix_coefficient)),
    )
|
165 |
+
|
166 |
+
|
167 |
+
def parallel_mix(input, output, mix_coefficient):
    """Linearly blend two audio clips.

    ``input`` and ``output`` are indexable pairs whose element 1 is the audio
    data. The result is ``input * mix_coefficient + output * (1 -
    mix_coefficient)``, returned with a fixed sample rate of 44100.
    """
    dry_weight = mix_coefficient
    wet_weight = 1 - mix_coefficient
    blended = input[1] * dry_weight + output[1] * wet_weight
    return 44100, blended
|
170 |
+
|
171 |
+
|
172 |
+
def int16_to_float32(wav):
    """Reinterpret a buffer as int16 samples and scale them by 1/32768.

    The buffer is read with ``np.frombuffer``, so the result is always
    one-dimensional. The division is plain NumPy true division; no explicit
    cast to float32 is performed despite the function's name.
    """
    samples = np.frombuffer(wav, dtype=np.int16)
    return samples / 32768
|
176 |
+
|
177 |
+
|
178 |
+
def waveform_plot(input, output, prl_mix_ouptut, figsize_x=20, figsize_y=9):
    """Draw three stacked waveforms and return the matplotlib figure.

    Each argument is an indexable pair whose element 1 holds the audio data;
    it is converted with ``int16_to_float32`` and drawn at a fixed sample rate
    of 44100. The three axes share both x and y scales.
    """
    sr = 44100
    fig, ax = plt.subplots(
        nrows=3, sharex=True, sharey=True, figsize=(figsize_x, figsize_y)
    )
    panels = (
        (input, "Loudness Normalized Input"),
        (output, "De-limiter Output"),
        (prl_mix_ouptut, "Parallel Mix of the Input and its De-limiter Output"),
    )
    # One panel per clip, drawn top to bottom in the order given above.
    for axis, (clip, title) in zip(ax, panels):
        librosa.display.waveshow(int16_to_float32(clip[1]).T, sr=sr, ax=axis)
        axis.set(title=title)
        axis.label_outer()
    return fig
|
193 |
+
|
194 |
+
|
195 |
+
# Gradio UI definition. The scraped source lost its indentation, so the
# container nesting below is a reconstruction -- NOTE(review): confirm the
# layout against the original app before applying.
with gr.Blocks() as demo:
    # Header text. Fixes relative to the original markup: the <p> element is
    # now closed, and "a few minute" reads "a few minutes".
    gr.HTML(
        """
        <div style="text-align: center; max-width: 700px; margin: 0 auto;">
            <div
            style="
                display: inline-flex;
                align-items: center;
                gap: 0.8rem;
                font-size: 1.75rem;
            "
            >
            <h1 style="font-weight: 900; margin-bottom: 7px;">
                Music De-limiter
            </h1>
            </div>
            <p style="margin-bottom: 10px; font-size: 94%">
            A demo for "Music De-limiter via Sample-wise Gain Inversion" to appear in WASPAA 2023.
            You can first upload a music (.wav or .mp3) file and then press "De-limit" button to apply the De-limiter. Since we use a CPU instead of a GPU, it may require a few minutes.
            Then, you can apply a Parallel Mix technique, which is a simple linear mixing technique of "loudness normalized input" and the "de-limiter output".
            You can modify the mixing coefficient by yourself.
            If the coefficient is 0.3 then the output will be the "loudness_normalized_input * 0.3 + de-limiter_output * 0.7"
            </p>
        </div>
        """
    )
    with gr.Row().style(mobile_collapse=False, equal_height=True):
        with gr.Column():
            with gr.Box():
                input_audio = gr.Audio(source="upload", label="De-limiter Input")
                btn = gr.Button("De-limit")
        with gr.Column():
            with gr.Box():
                loud_norm_input = gr.Audio(label="Loudness Normalized Input (-14LUFS)")
            with gr.Box():
                output_audio = gr.Audio(label="De-limiter Output")
            with gr.Box():
                output_audio_parallel = gr.Audio(
                    label="Parallel Mix of the Input and its De-limiter Output"
                )
                slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.5,
                    label="Parallel Mix Coefficient",
                )
        # "De-limit" runs the model and fills all three audio outputs.
        btn.click(
            main,
            inputs=[input_audio, slider],
            outputs=[output_audio, loud_norm_input, output_audio_parallel],
        )
        # Releasing the slider re-mixes without re-running the model.
        slider.release(
            parallel_mix,
            inputs=[input_audio, output_audio, slider],
            outputs=output_audio_parallel,
        )
    with gr.Row().style(mobile_collapse=False, equal_height=True):
        with gr.Column():
            with gr.Box():
                plot = gr.Plot(label="Plots")
                btn2 = gr.Button("Show Plots")
                slider_plot_x = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=20,
                    label="Plot X-axis size",
                )
                slider_plot_y = gr.Slider(
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=9,
                    label="Plot Y-axis size",
                )
        btn2.click(
            waveform_plot,
            inputs=[
                loud_norm_input,
                output_audio,
                output_audio_parallel,
                slider_plot_x,
                slider_plot_y,
            ],
            outputs=plot,
        )
        # NOTE(review): this second handler on the same slider is registered
        # independently of the parallel_mix handler above; whether it sees the
        # refreshed mix depends on how Gradio orders the two callbacks.
        slider.release(
            waveform_plot,
            inputs=[
                loud_norm_input,
                output_audio,
                output_audio_parallel,
                slider_plot_x,
                slider_plot_y,
            ],
            outputs=plot,
        )

if __name__ == "__main__":
    demo.launch(debug=True)
|
configs/delimit_6_s.yaml
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# For De-limit task, Conv-TasNet.
|
2 |
+
# si_sdr loss
|
3 |
+
#
|
4 |
+
# ozone_train_fixed is about 6.36 hours
|
5 |
+
# 300,000 segments is about 333.33 hours
|
6 |
+
# ratio should be about 0.019
|
7 |
+
|
8 |
+
wandb_params:
|
9 |
+
use_wandb: true
|
10 |
+
entity: null # your wandb id
|
11 |
+
project: delimit # your wandb project
|
12 |
+
rerun_id: null # use when you rerun wandb.
|
13 |
+
sweep: false
|
14 |
+
|
15 |
+
sys_params:
|
16 |
+
nb_workers: 4
|
17 |
+
seed: 777
|
18 |
+
n_nodes: 1
|
19 |
+
port: null
|
20 |
+
rank: 0
|
21 |
+
|
22 |
+
task_params:
|
23 |
+
target: all # choices=["all"]
|
24 |
+
train: true
|
25 |
+
dataset: delimit # choices=["musdb", "delimit"]
|
26 |
+
|
27 |
+
dir_params:
|
28 |
+
root: /path/to/musdb18hq
|
29 |
+
output_directory: /path/to/results
|
30 |
+
exp_name: convtasnet_6_s # you MUST specify this
|
31 |
+
resume: null # "path of checkpoint folder"
|
32 |
+
continual_train: false # when we want to use a pre-trained model but not want to use lr_scheduler history.
|
33 |
+
delimit_valid_root: null
|
34 |
+
delimit_valid_L_root: null
|
35 |
+
ozone_root: /path/to/musdb-XL-train # you have to specify data_params.use_fixed
|
36 |
+
|
37 |
+
hyperparams:
|
38 |
+
batch_size: 8 # with 1 GPU (we used a 2080 Ti, 11GB)
|
39 |
+
epochs: 200
|
40 |
+
optimizer: adamw
|
41 |
+
weight_decay: 0.01
|
42 |
+
lr: 0.00003
|
43 |
+
lr_decay_gamma: 0.5
|
44 |
+
lr_decay_patience: 15
|
45 |
+
patience: 50
|
46 |
+
lr_scheduler: step_lr
|
47 |
+
gradient_clip: 5.0
|
48 |
+
ema: false
|
49 |
+
|
50 |
+
data_params:
|
51 |
+
nfft: 4096
|
52 |
+
nhop: 1024
|
53 |
+
nb_channels: 2
|
54 |
+
sample_rate: 44100
|
55 |
+
seq_dur: 4.0
|
56 |
+
singleset_num_frames: null
|
57 |
+
samples_per_track: 128 # "Number of samples per track to use for training."
|
58 |
+
limitaug_method: ozone
|
59 |
+
limitaug_mode: null
|
60 |
+
limitaug_custom_target_lufs: null
|
61 |
+
limitaug_custom_target_lufs_std: null
|
62 |
+
target_loudnorm_lufs: -14.0
|
63 |
+
random_mix: true
|
64 |
+
target_limitaug_mode: null
|
65 |
+
target_limitaug_custom_target_lufs: null
|
66 |
+
target_limitaug_custom_target_lufs_std: null
|
67 |
+
custom_limiter_attack_range: null
|
68 |
+
custom_limiter_release_range: null
|
69 |
+
use_fixed: 0.019 # range 0.0 ~ 1.0 => 1.0 will use fixed Ozoned_mixture training examples only.
|
70 |
+
|
71 |
+
model_loss_params:
|
72 |
+
architecture: conv_tasnet_mask_on_output # Sample-wise Gain Inversion (SGI)
|
73 |
+
train_loss_func: [si_sdr]
|
74 |
+
train_loss_scales: [1.]
|
75 |
+
valid_loss_func: [si_sdr]
|
76 |
+
valid_loss_scales: [1.]
|
77 |
+
|
78 |
+
conv_tasnet_params:
|
79 |
+
encoder_activation: relu
|
80 |
+
n_filters: 512
|
81 |
+
kernel_size: 128 # about 3ms in 44100Hz
|
82 |
+
stride: 64
|
83 |
+
n_blocks: 5
|
84 |
+
n_repeats: 2
|
85 |
+
bn_chan: 128
|
86 |
+
hid_chan: 512
|
87 |
+
skip_chan: 128
|
88 |
+
# conv_kernel_size:
|
89 |
+
# norm_type:
|
90 |
+
mask_act: relu
|
91 |
+
# causal:
|
92 |
+
decoder_activation: sigmoid
|
dataloader/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .dataset import aug_from_str, MusdbTrainDataset, MusdbValidDataset
|
2 |
+
from .singleset import SingleTrackSet
|
3 |
+
from .delimit_dataset import (
|
4 |
+
DelimitTrainDataset,
|
5 |
+
DelimitValidDataset,
|
6 |
+
OzoneTrainDataset,
|
7 |
+
OzoneValidDataset,
|
8 |
+
)
|
dataloader/dataset.py
ADDED
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Dataloader based on https://github.com/jeonchangbin49/LimitAug
|
2 |
+
import os
|
3 |
+
from glob import glob
|
4 |
+
import random
|
5 |
+
from typing import Optional, Callable
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import librosa
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
import pyloudnorm as pyln
|
12 |
+
from pedalboard import Pedalboard, Limiter, Gain, Compressor, Clipping
|
13 |
+
|
14 |
+
from utils import load_wav_arbitrary_position_stereo, db2linear
|
15 |
+
|
16 |
+
|
17 |
+
# based on https://github.com/sigsep/open-unmix-pytorch
|
18 |
+
def aug_from_str(list_of_function_names: list):
    """Build an augmentation callable from a list of augmentation names.

    Each name ``x`` is resolved to the module-level function ``_augment_x``
    and the resolved functions are chained with ``Compose``. An empty or
    ``None`` list yields an identity function.
    """
    if not list_of_function_names:
        return lambda audio: audio
    return Compose(
        [globals()["_augment_" + name] for name in list_of_function_names]
    )
|
23 |
+
|
24 |
+
|
25 |
+
class Compose(object):
    """Chain several augmentation callables into one.

    Args:
        transforms: iterable of callables, applied in order; each receives
            the previous one's return value.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, audio: torch.Tensor) -> torch.Tensor:
        """Feed ``audio`` through every transform and return the result."""
        result = audio
        for transform in self.transforms:
            result = transform(result)
        return result
|
38 |
+
|
39 |
+
|
40 |
+
# numpy based augmentation
|
41 |
+
# based on https://github.com/sigsep/open-unmix-pytorch
|
42 |
+
def _augment_gain(audio, low=0.25, high=1.25):
|
43 |
+
"""Applies a random gain between `low` and `high`"""
|
44 |
+
g = low + random.random() * (high - low)
|
45 |
+
return audio * g
|
46 |
+
|
47 |
+
|
48 |
+
def _augment_channelswap(audio):
|
49 |
+
"""Swap channels of stereo signals with a probability of p=0.5"""
|
50 |
+
if audio.shape[0] == 2 and random.random() < 0.5:
|
51 |
+
return np.flip(audio, axis=0) # axis=0 must be given
|
52 |
+
else:
|
53 |
+
return audio
|
54 |
+
|
55 |
+
|
56 |
+
# Linear gain increasing implementation for Method (1)
|
57 |
+
def apply_linear_gain_increase(mixture, target, board, meter, samplerate, target_lufs):
    """Apply one gain to mixture and target so the mixture reaches ``target_lufs``.

    The loudness of the transposed mixture is measured with ``meter``. The
    gain ``target_lufs - loudness`` (or 0.0 when the loudness is infinite) is
    written to ``board[0].gain_db``, then ``board`` processes the mixture and
    the target in their original orientation.

    Returns:
        The processed ``(mixture, target)`` pair.
    """
    mixture_t = mixture.T
    target_t = target.T
    measured = meter.integrated_loudness(mixture_t)

    # An infinite reading means there is no gain to compute; leave level as is.
    board[0].gain_db = 0.0 if np.isinf(measured) else target_lufs - measured

    processed_mixture = board(mixture_t.T, samplerate)
    processed_target = board(target_t.T, samplerate)
    return processed_mixture, processed_target
|
70 |
+
|
71 |
+
|
72 |
+
# LimitAug implementation for Method (2) and
|
73 |
+
# implementation of LimitAug then Loudness normalization for Method (4)
|
74 |
+
def apply_limitaug(
    audio,
    board,
    meter,
    samplerate,
    target_lufs,
    target_loudnorm_lufs=None,
    loudness=None,
):
    """Drive ``audio`` through ``board`` at a gain aimed at ``target_lufs``.

    If ``loudness`` is not supplied it is measured on the transposed audio
    with ``meter``. The gain ``target_lufs - loudness`` (0.0 when the loudness
    is infinite) is written to ``board[0].gain_db`` before ``board`` processes
    the audio. When ``target_loudnorm_lufs`` is truthy, the processed audio is
    measured again and rescaled towards that level, unless the new reading is
    infinite.

    Returns:
        ``(audio, loudness)``: the processed audio and the loudness used for
        the gain computation.
    """
    audio_t = audio.T
    if loudness is None:
        loudness = meter.integrated_loudness(audio_t)

    # An infinite reading means there is no gain to compute; use 0 dB.
    board[0].gain_db = 0.0 if np.isinf(loudness) else target_lufs - loudness
    processed = board(audio_t.T, samplerate)

    if target_loudnorm_lufs:
        post_loudness = meter.integrated_loudness(processed.T)
        if not np.isinf(post_loudness):
            processed = processed * db2linear(target_loudnorm_lufs - post_loudness)

    return processed, loudness
|
104 |
+
|
105 |
+
|
106 |
+
"""
|
107 |
+
This dataloader implementation is based on https://github.com/sigsep/open-unmix-pytorch
|
108 |
+
"""
|
109 |
+
|
110 |
+
|
111 |
+
class MusdbTrainDataset(Dataset):
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
target: str = "vocals",
|
115 |
+
root: str = None,
|
116 |
+
seq_duration: Optional[float] = 6.0,
|
117 |
+
samples_per_track: int = 64,
|
118 |
+
source_augmentations: Optional[Callable] = lambda audio: audio,
|
119 |
+
sample_rate: int = 44100,
|
120 |
+
seed: int = 42,
|
121 |
+
limitaug_method: str = "limitaug_then_loudnorm",
|
122 |
+
limitaug_mode: str = "normal_L",
|
123 |
+
limitaug_custom_target_lufs: float = None,
|
124 |
+
limitaug_custom_target_lufs_std: float = None,
|
125 |
+
target_loudnorm_lufs: float = -14.0,
|
126 |
+
custom_limiter_attack_range: list = [2.0, 2.0],
|
127 |
+
custom_limiter_release_range: list = [200.0, 200.0],
|
128 |
+
*args,
|
129 |
+
**kwargs,
|
130 |
+
) -> None:
|
131 |
+
"""
|
132 |
+
Parameters
|
133 |
+
----------
|
134 |
+
limitaug_method : str
|
135 |
+
choose from ["linear_gain_increase", "limitaug", "limitaug_then_loudnorm", "only_loudnorm"]
|
136 |
+
limitaug_mode : str
|
137 |
+
choose from ["uniform", "normal", "normal_L", "normal_XL", "normal_short_term", "normal_L_short_term", "normal_XL_short_term", "custom"]
|
138 |
+
limitaug_custom_target_lufs : float
|
139 |
+
valid only when
|
140 |
+
limitaug_mode == "custom"
|
141 |
+
limitaug_custom_target_lufs_std : float
|
142 |
+
also valid only when
|
143 |
+
limitaug_mode == "custom"
|
144 |
+
target_loudnorm_lufs : float
|
145 |
+
valid only when
|
146 |
+
limitaug_method == 'limitaug_then_loudnorm' or 'only_loudnorm'
|
147 |
+
default is -14.
|
148 |
+
To the best of my knowledge, Spotify and Youtube music is using -14 as a reference loudness normalization level.
|
149 |
+
No special reason for the choice of -14 as target_loudnorm_lufs.
|
150 |
+
target : str
|
151 |
+
target name of the source to be separated, defaults to ``vocals``.
|
152 |
+
root : str
|
153 |
+
root path of MUSDB
|
154 |
+
seq_duration : float
|
155 |
+
training is performed in chunks of ``seq_duration`` (in seconds,
|
156 |
+
defaults to ``None`` which loads the full audio track
|
157 |
+
samples_per_track : int
|
158 |
+
sets the number of samples, yielded from each track per epoch.
|
159 |
+
Defaults to 64
|
160 |
+
source_augmentations : list[callables]
|
161 |
+
provide list of augmentation function that take a multi-channel
|
162 |
+
audio file of shape (src, samples) as input and output. Defaults to
|
163 |
+
no-augmentations (input = output)
|
164 |
+
seed : int
|
165 |
+
control randomness of dataset iterations
|
166 |
+
args, kwargs : additional keyword arguments
|
167 |
+
used to add further control for the musdb dataset
|
168 |
+
initialization function.
|
169 |
+
"""
|
170 |
+
|
171 |
+
self.seed = seed
|
172 |
+
random.seed(seed)
|
173 |
+
self.seq_duration = seq_duration
|
174 |
+
self.target = target
|
175 |
+
self.samples_per_track = samples_per_track
|
176 |
+
self.source_augmentations = source_augmentations
|
177 |
+
self.sample_rate = sample_rate
|
178 |
+
|
179 |
+
self.root = root
|
180 |
+
self.sources = ["vocals", "bass", "drums", "other"]
|
181 |
+
self.train_list = glob(f"{self.root}/train/*")
|
182 |
+
self.valid_list = [
|
183 |
+
"ANiMAL - Rockshow",
|
184 |
+
"Actions - One Minute Smile",
|
185 |
+
"Alexander Ross - Goodbye Bolero",
|
186 |
+
"Clara Berry And Wooldog - Waltz For My Victims",
|
187 |
+
"Fergessen - Nos Palpitants",
|
188 |
+
"James May - On The Line",
|
189 |
+
"Johnny Lokke - Promises & Lies",
|
190 |
+
"Leaf - Summerghost",
|
191 |
+
"Meaxic - Take A Step",
|
192 |
+
"Patrick Talbot - A Reason To Leave",
|
193 |
+
"Skelpolu - Human Mistakes",
|
194 |
+
"Traffic Experiment - Sirens",
|
195 |
+
"Triviul - Angelsaint",
|
196 |
+
"Young Griffo - Pennies",
|
197 |
+
]
|
198 |
+
|
199 |
+
self.train_list = [
|
200 |
+
x for x in self.train_list if os.path.basename(x) not in self.valid_list
|
201 |
+
]
|
202 |
+
|
203 |
+
# limitaug related
|
204 |
+
self.limitaug_method = limitaug_method
|
205 |
+
self.limitaug_mode = limitaug_mode
|
206 |
+
self.limitaug_custom_target_lufs = limitaug_custom_target_lufs
|
207 |
+
self.limitaug_custom_target_lufs_std = limitaug_custom_target_lufs_std
|
208 |
+
self.target_loudnorm_lufs = target_loudnorm_lufs
|
209 |
+
self.meter = pyln.Meter(self.sample_rate)
|
210 |
+
|
211 |
+
# Method (1) in our paper's Results section and Table 5
|
212 |
+
if self.limitaug_method == "linear_gain_increase":
|
213 |
+
print("using linear gain increasing!")
|
214 |
+
self.board = Pedalboard([Gain(gain_db=0.0)])
|
215 |
+
|
216 |
+
# Method (2) in our paper's Results section and Table 5
|
217 |
+
elif self.limitaug_method == "limitaug":
|
218 |
+
print("using limitaug!")
|
219 |
+
self.board = Pedalboard(
|
220 |
+
[Gain(gain_db=0.0), Limiter(threshold_db=0.0, release_ms=100.0)]
|
221 |
+
)
|
222 |
+
|
223 |
+
# Method (3) in our paper's Results section and Table 5
|
224 |
+
elif self.limitaug_method == "only_loudnorm":
|
225 |
+
print("using only loudness normalized inputs")
|
226 |
+
|
227 |
+
# Method (4) in our paper's Results section and Table 5
|
228 |
+
elif self.limitaug_method == "limitaug_then_loudnorm":
|
229 |
+
print("using limitaug then loudness normalize!")
|
230 |
+
self.board = Pedalboard(
|
231 |
+
[Gain(gain_db=0.0), Limiter(threshold_db=0.0, release_ms=100.0)]
|
232 |
+
)
|
233 |
+
|
234 |
+
elif self.limitaug_method == "custom_limiter_limitaug":
|
235 |
+
print("using Custom limiter limitaug!")
|
236 |
+
self.custom_limiter_attack_range = custom_limiter_attack_range
|
237 |
+
self.custom_limiter_release_range = custom_limiter_release_range
|
238 |
+
self.board = Pedalboard(
|
239 |
+
[
|
240 |
+
Gain(gain_db=0.0),
|
241 |
+
Compressor(
|
242 |
+
threshold_db=-10.0, ratio=4.0, attack_ms=2.0, release_ms=200.0
|
243 |
+
), # attack_ms and release_ms will be changed later.
|
244 |
+
Compressor(
|
245 |
+
threshold_db=0.0,
|
246 |
+
ratio=1000.0,
|
247 |
+
attack_ms=0.001,
|
248 |
+
release_ms=100.0,
|
249 |
+
),
|
250 |
+
Gain(gain_db=3.75),
|
251 |
+
Clipping(threshold_db=0.0),
|
252 |
+
]
|
253 |
+
) # This implementation is the same as JUCE Limiter.
|
254 |
+
# However, we want the first compressor to have a variable attack and release time.
|
255 |
+
# Therefore, we use the Custom Limiter instead of the JUCE Limiter.
|
256 |
+
|
257 |
+
self.limitaug_mode_statistics = {
|
258 |
+
"normal": [
|
259 |
+
-15.954,
|
260 |
+
1.264,
|
261 |
+
], # -15.954 is mean LUFS of musdb-hq and 1.264 is standard deviation
|
262 |
+
"normal_L": [
|
263 |
+
-10.887,
|
264 |
+
1.191,
|
265 |
+
], # -10.887 is mean LUFS of musdb-L and 1.191 is standard deviation
|
266 |
+
"normal_XL": [
|
267 |
+
-8.608,
|
268 |
+
1.165,
|
269 |
+
], # -8.608 is mean LUFS of musdb-XL and 1.165 is standard deviation
|
270 |
+
"normal_short_term": [
|
271 |
+
-17.317,
|
272 |
+
5.036,
|
273 |
+
], # In our experiments, short-term statistics were not helpful.
|
274 |
+
"normal_L_short_term": [-12.303, 5.233],
|
275 |
+
"normal_XL_short_term": [-9.988, 5.518],
|
276 |
+
"custom": [limitaug_custom_target_lufs, limitaug_custom_target_lufs_std],
|
277 |
+
}
|
278 |
+
|
279 |
+
def sample_target_lufs(self):
    """Draw a random target loudness (LUFS) for the limiter augmentation.

    When ``self.limitaug_mode`` is ``"uniform"`` the value is drawn uniformly
    from [-20, -5]. For every other mode the value is drawn from a Gaussian
    whose mean and standard deviation are the two entries stored under that
    mode in ``self.limitaug_mode_statistics``.

    Returns
    -------
    float
        The sampled target loudness.
    """
    mode = self.limitaug_mode

    # Uniform mode does not consult the statistics table at all.
    if mode == "uniform":
        return random.uniform(-20, -5)

    # Every other mode is a key of the statistics table: [mean, std].
    stats = self.limitaug_mode_statistics[mode]
    return random.gauss(stats[0], stats[1])
|
291 |
+
|
292 |
+
def get_limitaug_results(self, mixture, target):
    """Apply the configured loudness augmentation to a mixture/target pair.

    The behaviour is selected by ``self.limitaug_method``:

    * ``"linear_gain_increase"`` (Method 1): delegate to
      ``apply_linear_gain_increase`` with a sampled target loudness.
    * ``"limitaug"`` (Method 2): randomize the limiter release time, run
      ``apply_limitaug`` on the mixture and scale the target by the same
      per-sample gain the mixture received.
    * ``"only_loudnorm"`` (Method 3): scale mixture and target so that the
      mixture reaches ``self.target_loudnorm_lufs``; a silent mixture
      (loudness of -inf) is left untouched.
    * ``"limitaug_then_loudnorm"`` (Method 4): like ``"limitaug"`` but
      ``target_loudnorm_lufs`` is additionally passed to ``apply_limitaug``.
    * ``"custom_limiter_limitaug"``: like Method 4, but the attack/release of
      the first compressor and the release of the second compressor of the
      custom limiter chain are randomized.

    Any other value (e.g. ``None``) returns the inputs unchanged.

    Parameters
    ----------
    mixture : np.ndarray
        Mixture audio; transposed before being passed to the loudness meter.
    target : np.ndarray
        Target stem. NOTE: in the limiter-based methods it is scaled
        *in place*.

    Returns
    -------
    tuple
        ``(mixture, target)`` after augmentation.
    """
    method = self.limitaug_method

    # Method (1): linear gain increase.
    if method == "linear_gain_increase":
        target_lufs = self.sample_target_lufs()
        mixture, target = apply_linear_gain_increase(
            mixture,
            target,
            self.board,
            self.meter,
            self.sample_rate,
            target_lufs=target_lufs,
        )

    # Method (3): loudness normalization only.
    elif method == "only_loudnorm":
        mixture_loudness = self.meter.integrated_loudness(mixture.T)
        # If the source is silence the measured loudness is -inf; skip scaling.
        if not np.isinf(mixture_loudness):
            # default target_loudnorm_lufs is -14.
            augmented_gain = self.target_loudnorm_lufs - mixture_loudness
            mixture = mixture * db2linear(augmented_gain)
            target = target * db2linear(augmented_gain)

    # Methods (2), (4) and the custom limiter share the same limiter pipeline.
    elif method in ("limitaug", "limitaug_then_loudnorm", "custom_limiter_limitaug"):
        if method == "custom_limiter_limitaug":
            # Randomize attack and release time of the first compressor.
            self.board[1].attack_ms = random.uniform(
                self.custom_limiter_attack_range[0],
                self.custom_limiter_attack_range[1],
            )
            self.board[1].release_ms = random.uniform(
                self.custom_limiter_release_range[0],
                self.custom_limiter_release_range[1],
            )
            # Randomize release time of the second compressor.
            self.board[2].release_ms = random.uniform(30.0, 200.0)
        else:
            # board[1] is the limiter; only its release time is randomized.
            self.board[1].release_ms = random.uniform(30.0, 200.0)

        # Plain "limitaug" does not forward a loudness-normalization target.
        loudnorm_kwargs = (
            {}
            if method == "limitaug"
            else {"target_loudnorm_lufs": self.target_loudnorm_lufs}
        )

        mixture_orig = mixture.copy()
        target_lufs = self.sample_target_lufs()
        mixture, _ = apply_limitaug(
            mixture,
            self.board,
            self.meter,
            self.sample_rate,
            target_lufs=target_lufs,
            **loudnorm_kwargs,
        )
        # Apply to the target the same per-sample gain the mixture received
        # (processed / original); 1e-8 guards against division by zero.
        target *= mixture / (mixture_orig + 1e-8)

    return mixture, target
|
376 |
+
|
377 |
+
def __getitem__(self, index):
    """Build one augmented (mixture, target) training pair.

    The target source is read from the track selected by
    ``index // self.samples_per_track``; every other source is read from a
    randomly chosen training track. Each stem goes through
    ``self.source_augmentations``, the stems are summed into a mixture and
    the pair is passed to ``get_limitaug_results``.

    Returns
    -------
    tuple of torch.Tensor
        ``(x, y)``: float32 mixture and float32 target stem.
    """
    target_ind = None
    audio_sources = []

    for k, source in enumerate(self.sources):
        if source == self.target:
            # Remember where the target stem sits and take it from the
            # indexed track, so each track yields `samples_per_track` items.
            target_ind = k
            track_path = self.train_list[index // self.samples_per_track]
        else:
            # Accompaniment stems come from an arbitrary training track.
            track_path = random.choice(self.train_list)

        audio = load_wav_arbitrary_position_stereo(
            f"{track_path}/{source}.wav", self.sample_rate, self.seq_duration
        )
        audio_sources.append(self.source_augmentations(audio))

    stems = np.stack(audio_sources, axis=0)

    # Linear mix over the source axis, plus the target stem, then limitaug.
    x, y = self.get_limitaug_results(stems.sum(0), stems[target_ind])

    return (
        torch.as_tensor(x, dtype=torch.float32),
        torch.as_tensor(y, dtype=torch.float32),
    )
|
415 |
+
|
416 |
+
def __len__(self):
    """Return the epoch size: ``samples_per_track`` items per training track."""
    n_tracks = len(self.train_list)
    return n_tracks * self.samples_per_track
|
418 |
+
|
419 |
+
|
420 |
+
class MusdbValidDataset(Dataset):
    """Validation split of MUSDB18-HQ, built from a fixed list of track names.

    Every item is a full-length track: the sum of the four stems, the target
    stem, and the track (folder) name.
    """

    def __init__(
        self,
        target: str = "vocals",
        root: str = None,
        *args,
        **kwargs,
    ) -> None:
        """Collect the validation tracks found under ``{root}/train``.

        Parameters
        ----------
        target : str
            name of the source returned as the target stem, defaults to
            ``vocals``.
        root : str
            root path of the MUSDB18HQ dataset, defaults to ``None``.
        args, kwargs :
            accepted for signature compatibility; not used in this method.
        """
        valid_names = (
            "ANiMAL - Rockshow",
            "Actions - One Minute Smile",
            "Alexander Ross - Goodbye Bolero",
            "Clara Berry And Wooldog - Waltz For My Victims",
            "Fergessen - Nos Palpitants",
            "James May - On The Line",
            "Johnny Lokke - Promises & Lies",
            "Leaf - Summerghost",
            "Meaxic - Take A Step",
            "Patrick Talbot - A Reason To Leave",
            "Skelpolu - Human Mistakes",
            "Traffic Experiment - Sirens",
            "Triviul - Angelsaint",
            "Young Griffo - Pennies",
        )

        self.target = target
        self.sample_rate = 44100.0  # musdb is fixed sample rate
        self.root = root
        self.sources = ["vocals", "bass", "drums", "other"]

        # All track folders, then the subset whose folder name is a
        # validation track.
        self.train_list = glob(f"{self.root}/train/*")
        self.valid_list = [
            path for path in self.train_list if os.path.basename(path) in valid_names
        ]

    def __getitem__(self, index):
        """Return ``(mixture, target_stem, song_name)`` for one validation track."""
        target_ind = None
        audio_sources = []

        for k, source in enumerate(self.sources):
            # memorize index of target source
            if source == self.target:
                target_ind = k

            # Every stem of the item comes from the same validation track.
            track_path = self.valid_list[index]
            song_name = os.path.basename(track_path)
            stem_path = f"{track_path}/{source}.wav"
            waveform, _ = librosa.load(stem_path, mono=False, sr=self.sample_rate)
            audio_sources.append(torch.as_tensor(waveform, dtype=torch.float32))

        stems = torch.stack(audio_sources, dim=0)

        # Linear mix over the source axis, and the target stem.
        x = stems.sum(0)
        y = stems[target_ind]

        return x, y, song_name

    def __len__(self):
        """Number of validation tracks found on disk."""
        return len(self.valid_list)
|
500 |
+
|
501 |
+
|
502 |
+
# If you want to check the LUFS values of training examples, run this.
|
503 |
+
if __name__ == "__main__":
    # Debug entry point: iterate over MusdbTrainDataset and print the
    # integrated loudness of every generated training mixture.
    import argparse

    parser = argparse.ArgumentParser(
        description="Print the integrated loudness (LUFS) of MusdbTrainDataset training examples"
    )

    parser.add_argument(
        "--musdb_root",
        type=str,
        default="/path/to/musdb",
        help="root path of musdb-hq dataset",
    )
    parser.add_argument(
        "--limitaug_method",
        type=str,
        default="limitaug",
        choices=[
            "linear_gain_increase",
            "limitaug",
            "limitaug_then_loudnorm",
            "only_loudnorm",
            None,
        ],
        help="choose limitaug method",
    )
    parser.add_argument(
        "--limitaug_mode",
        type=str,
        default="normal_L",
        choices=[
            "uniform",
            "normal",
            "normal_L",
            "normal_XL",
            "normal_short_term",
            "normal_L_short_term",
            "normal_XL_short_term",
            "custom",
        ],
        help="if you use LimitAug, what lufs distribution to target",
    )
    parser.add_argument(
        "--limitaug_custom_target_lufs",
        type=float,
        default=None,
        help="if limitaug_mode is custom, set custom target lufs for LimitAug",
    )

    # Unknown command-line arguments are ignored rather than rejected.
    args, _ = parser.parse_known_args()

    source_augmentations_ = aug_from_str(["gain", "channelswap"])

    train_dataset = MusdbTrainDataset(
        target="vocals",
        root=args.musdb_root,
        seq_duration=6.0,
        source_augmentations=source_augmentations_,
        limitaug_method=args.limitaug_method,
        limitaug_mode=args.limitaug_mode,
        limitaug_custom_target_lufs=args.limitaug_custom_target_lufs,
    )

    dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )

    meter = pyln.Meter(44100)
    # Five passes over the dataset; batch_size is 1, so x[0] is the single
    # mixture of the batch.
    for _ in range(5):
        for x, y in dataloader:
            loudness = meter.integrated_loudness(x[0].numpy().T)
            print(f"mixture loudness : {loudness} LUFS")
|
dataloader/delimit_dataset.py
ADDED
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from typing import Optional, Callable
|
4 |
+
import json
|
5 |
+
import glob
|
6 |
+
import csv
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import librosa
|
11 |
+
import pyloudnorm as pyln
|
12 |
+
from pedalboard import Pedalboard, Limiter, Gain, Compressor, Clipping
|
13 |
+
|
14 |
+
from .dataset import (
|
15 |
+
MusdbTrainDataset,
|
16 |
+
MusdbValidDataset,
|
17 |
+
apply_limitaug,
|
18 |
+
)
|
19 |
+
from utils import (
|
20 |
+
load_wav_arbitrary_position_stereo,
|
21 |
+
load_wav_specific_position_stereo,
|
22 |
+
db2linear,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class DelimitTrainDataset(MusdbTrainDataset):
    """Training dataset for the de-limiter task.

    Each item is a pair ``(mixture_limited, mixture_loudnorm)``: a stem
    mixture processed by the limiter chain (model input) and the mixture
    scaled towards ``target_loudnorm_lufs`` (model target).

    Parameters
    ----------
    target : str
        source name looked up in ``self.sources``; the matching stem (if
        any) is read from the indexed track, all other stems from random
        tracks. Defaults to ``"all"``.
    root : str
        root path of MUSDB
    seq_duration : float
        length in seconds of each training chunk, defaults to 6.0
    samples_per_track : int
        number of samples yielded from each track per epoch, defaults to 64
    source_augmentations : callable
        function applied to every loaded stem; defaults to the identity
    sample_rate : int
        audio sample rate, defaults to 44100
    seed : int
        forwarded to the parent class
    limitaug_method : str
        one of ``"limitaug"``, ``"limitaug_then_loudnorm"`` or
        ``"custom_limiter_limitaug"`` (the methods handled by
        ``get_limitaug_mixture``)
    limitaug_mode : str
        choose from ["uniform", "normal", "normal_L", "normal_XL",
        "normal_short_term", "normal_L_short_term", "normal_XL_short_term",
        "custom"]
    limitaug_custom_target_lufs, limitaug_custom_target_lufs_std : float
        mean / standard deviation used when ``limitaug_mode == "custom"``
    target_loudnorm_lufs : float
        loudness the target mixture is normalized to, defaults to -14.
    target_limitaug_mode : str
        optional key of ``limitaug_mode_statistics`` (e.g. ``"target_custom"``);
        when set, the target mixture is itself passed through the limiter
        chain with a loudness sampled from that entry.
    target_limitaug_custom_target_lufs, target_limitaug_custom_target_lufs_std : float
        mean / standard deviation stored under the ``"target_custom"`` key
    custom_limiter_attack_range, custom_limiter_release_range : list
        ``[min, max]`` ranges (ms) for the first compressor of the custom
        limiter
    args, kwargs :
        forwarded to ``MusdbTrainDataset.__init__``
    """

    def __init__(
        self,
        target: str = "all",
        root: str = None,
        seq_duration: Optional[float] = 6.0,
        samples_per_track: int = 64,
        source_augmentations: Optional[Callable] = lambda audio: audio,
        sample_rate: int = 44100,
        seed: int = 42,
        limitaug_method: str = "limitaug",
        limitaug_mode: str = "normal_L",
        limitaug_custom_target_lufs: float = None,
        limitaug_custom_target_lufs_std: float = None,
        target_loudnorm_lufs: float = -14.0,
        target_limitaug_mode: str = None,
        target_limitaug_custom_target_lufs: float = None,
        target_limitaug_custom_target_lufs_std: float = None,
        custom_limiter_attack_range: list = [2.0, 2.0],
        custom_limiter_release_range: list = [200.0, 200.0],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            target=target,
            root=root,
            seq_duration=seq_duration,
            samples_per_track=samples_per_track,
            source_augmentations=source_augmentations,
            sample_rate=sample_rate,
            seed=seed,
            limitaug_method=limitaug_method,
            limitaug_mode=limitaug_mode,
            limitaug_custom_target_lufs=limitaug_custom_target_lufs,
            limitaug_custom_target_lufs_std=limitaug_custom_target_lufs_std,
            target_loudnorm_lufs=target_loudnorm_lufs,
            custom_limiter_attack_range=custom_limiter_attack_range,
            custom_limiter_release_range=custom_limiter_release_range,
            *args,
            **kwargs,
        )

        self.target_limitaug_mode = target_limitaug_mode

        # Stored as plain values; stray trailing commas previously turned
        # these two attributes into 1-tuples.
        self.target_limitaug_custom_target_lufs = target_limitaug_custom_target_lufs
        self.target_limitaug_custom_target_lufs_std = (
            target_limitaug_custom_target_lufs_std
        )
        # Extra statistics entry selectable through `target_limitaug_mode`.
        self.limitaug_mode_statistics["target_custom"] = [
            target_limitaug_custom_target_lufs,
            target_limitaug_custom_target_lufs_std,
        ]

    # Get a limitaug result without target (individual stem source)
    def get_limitaug_mixture(self, mixture):
        """Produce the limited input and the loudness-normalized target.

        Parameters
        ----------
        mixture : np.ndarray
            linear sum of the stems.

        Returns
        -------
        tuple
            ``(mixture_limited, mixture_loudnorm)``.

        Raises
        ------
        ValueError
            if ``self.limitaug_method`` is not one of the limiter-based
            methods handled here.
        """
        method = self.limitaug_method

        if method == "limitaug":
            self.board[1].release_ms = random.uniform(30.0, 200.0)
            target_lufs = self.sample_target_lufs()
            mixture_limited, mixture_lufs = apply_limitaug(
                mixture,
                self.board,
                self.meter,
                self.sample_rate,
                target_lufs=target_lufs,
            )

        elif method == "limitaug_then_loudnorm":
            self.board[1].release_ms = random.uniform(30.0, 200.0)
            target_lufs = self.sample_target_lufs()
            # The call result is unpacked directly; wrapping it in a 1-tuple
            # (trailing comma) made this unpacking raise ValueError.
            mixture_limited, mixture_lufs = apply_limitaug(
                mixture,
                self.board,
                self.meter,
                self.sample_rate,
                target_lufs=target_lufs,
                target_loudnorm_lufs=self.target_loudnorm_lufs,
            )

        # Apply LimitAug using Custom Limiter
        elif method == "custom_limiter_limitaug":
            # Change attack time of First compressor of the Limiter
            self.board[1].attack_ms = random.uniform(
                self.custom_limiter_attack_range[0],
                self.custom_limiter_attack_range[1],
            )
            # Change release time of First compressor of the Limiter
            self.board[1].release_ms = random.uniform(
                self.custom_limiter_release_range[0],
                self.custom_limiter_release_range[1],
            )
            # Change release time of Second compressor of the Limiter
            self.board[2].release_ms = random.uniform(30.0, 200.0)
            target_lufs = self.sample_target_lufs()
            mixture_limited, mixture_lufs = apply_limitaug(
                mixture,
                self.board,
                self.meter,
                self.sample_rate,
                target_lufs=target_lufs,
                target_loudnorm_lufs=self.target_loudnorm_lufs,
            )

        else:
            # Without a limiter-based method there is no limited mixture to
            # return; fail with a clear message instead of an unbound local.
            raise ValueError(
                f"limitaug_method {method!r} is not supported by "
                "DelimitTrainDataset; use 'limitaug', 'limitaug_then_loudnorm' "
                "or 'custom_limiter_limitaug'"
            )

        # When we want to force NN to output an appropriately compressed target output
        if self.target_limitaug_mode:
            mixture_target_lufs = random.gauss(
                self.limitaug_mode_statistics[self.target_limitaug_mode][0],
                self.limitaug_mode_statistics[self.target_limitaug_mode][1],
            )
            # The second return value is not used afterwards.
            mixture, _ = apply_limitaug(
                mixture,
                self.board,
                self.meter,
                self.sample_rate,
                target_lufs=mixture_target_lufs,
                loudness=mixture_lufs,
            )

        if np.isinf(mixture_lufs):
            # Infinite loudness (silence): no gain can be computed.
            mixture_loudnorm = mixture
        else:
            augmented_gain = self.target_loudnorm_lufs - mixture_lufs
            mixture_loudnorm = mixture * db2linear(augmented_gain, eps=0.0)

        return mixture_limited, mixture_loudnorm

    def __getitem__(self, index):
        """Return one ``(mixture_limited, mixture_loudnorm)`` float32 tensor pair."""
        audio_sources = []

        for source in self.sources:
            if source == self.target:
                # we want to use # training samples per each track.
                track_path = self.train_list[index // self.samples_per_track]
            else:
                track_path = random.choice(self.train_list)

            audio = load_wav_arbitrary_position_stereo(
                f"{track_path}/{source}.wav", self.sample_rate, self.seq_duration
            )
            audio_sources.append(self.source_augmentations(audio))

        stems = np.stack(audio_sources, axis=0)

        # apply linear mix over source index=0
        # and here, linear mixture is a target unlike in MusdbTrainDataset
        mixture = stems.sum(0)
        mixture_limited, mixture_loudnorm = self.get_limitaug_mixture(mixture)
        # We will give mixture_limited as an input and mixture_loudnorm as a target to the model.

        mixture_limited = np.clip(mixture_limited, -1.0, 1.0)
        mixture_limited = torch.as_tensor(mixture_limited, dtype=torch.float32)
        mixture_loudnorm = torch.as_tensor(mixture_loudnorm, dtype=torch.float32)

        return mixture_limited, mixture_loudnorm
|
224 |
+
|
225 |
+
|
226 |
+
class OzoneTrainDataset(DelimitTrainDataset):
    """Training dataset built from pre-generated limited mixtures.

    Items are read from two folders under ``ozone_root``:

    * ``ozone_train_fixed/*.wav``: files named after a training song; the
      matching stems are read from ``{root}/train/{song_name}``.
    * ``ozone_train_random/*.wav``: files whose stem composition (track
      name, start position, gain, channel swap per source) is described by
      the ``ozone_train_random_*.csv`` metadata files.

    ``use_fixed`` is the probability of drawing an item from the fixed set.
    All other parameters are forwarded to ``DelimitTrainDataset``.
    """

    # CSV layout: column 0 is the segment name, 1-2 the limiter settings,
    # then four columns (name, start_sec, gain, channelswap) per source.
    _CSV_SOURCES = ("vocals", "bass", "drums", "other")

    def __init__(
        self,
        target: str = "all",
        root: str = None,
        ozone_root: str = None,
        use_fixed: float = 0.1,  # ratio of fixed samples
        seq_duration: Optional[float] = 6.0,
        samples_per_track: int = 64,
        source_augmentations: Optional[Callable] = lambda audio: audio,
        sample_rate: int = 44100,
        seed: int = 42,
        limitaug_method: str = "limitaug",
        limitaug_mode: str = "normal_L",
        limitaug_custom_target_lufs: float = None,
        limitaug_custom_target_lufs_std: float = None,
        target_loudnorm_lufs: float = -14.0,
        target_limitaug_mode: str = None,
        target_limitaug_custom_target_lufs: float = None,
        target_limitaug_custom_target_lufs_std: float = None,
        custom_limiter_attack_range: list = [2.0, 2.0],
        custom_limiter_release_range: list = [200.0, 200.0],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            target,
            root,
            seq_duration,
            samples_per_track,
            source_augmentations,
            sample_rate,
            seed,
            limitaug_method,
            limitaug_mode,
            limitaug_custom_target_lufs,
            limitaug_custom_target_lufs_std,
            target_loudnorm_lufs,
            target_limitaug_mode,
            target_limitaug_custom_target_lufs,
            target_limitaug_custom_target_lufs_std,
            custom_limiter_attack_range,
            custom_limiter_release_range,
            *args,
            **kwargs,
        )

        self.ozone_root = ozone_root
        self.use_fixed = use_fixed
        self.list_train_fixed = glob.glob(f"{self.ozone_root}/ozone_train_fixed/*.wav")
        self.list_train_random = glob.glob(
            f"{self.ozone_root}/ozone_train_random/*.wav"
        )
        self.dict_train_random = {}

        # Load information of pre-generated random training examples
        list_csv_files = sorted(glob.glob(f"{self.ozone_root}/ozone_train_random_*.csv"))
        for csv_file in list_csv_files:
            # newline="" is what the csv module expects for file objects.
            with open(csv_file, "r", newline="") as f:
                reader = csv.reader(f)
                next(reader, None)  # skip the first row (presumably the header)
                for row in reader:
                    self.dict_train_random[row[0]] = self._parse_random_row(row)

    @staticmethod
    def _parse_bool(text):
        """Convert a CSV cell to bool.

        ``bool(text)`` is True for every non-empty string (including
        ``"False"`` and ``"0"``), so the cell is interpreted explicitly.
        """
        value = text.strip().lower()
        if value in ("true", "t", "yes", "y"):
            return True
        if value in ("false", "f", "no", "n", ""):
            return False
        # Numeric flags such as "0", "1" or "1.0"; raises ValueError otherwise.
        return float(value) != 0.0

    @classmethod
    def _parse_random_row(cls, row):
        """Turn one CSV metadata row into the per-segment info dictionary."""
        info = {
            "max_threshold": float(row[1]),
            "max_character": float(row[2]),
        }
        for i, source in enumerate(cls._CSV_SOURCES):
            base = 3 + 4 * i
            info[source] = {
                "name": row[base],
                "start_sec": float(row[base + 1]),
                "gain": float(row[base + 2]),
                "channelswap": cls._parse_bool(row[base + 3]),
            }
        return info

    def __getitem__(self, idx):
        """Return ``(mixture_limited, mixture_loudnorm)`` for a random example.

        ``idx`` is not used: the example is drawn at random, from the fixed
        set with probability ``self.use_fixed`` and from the random set
        otherwise.
        """
        use_fixed_prob = random.random()

        if use_fixed_prob <= self.use_fixed:
            # Fixed examples
            audio_path = random.choice(self.list_train_fixed)
            song_name = os.path.basename(audio_path).replace(".wav", "")
            mixture_limited, start_pos_sec = load_wav_arbitrary_position_stereo(
                audio_path, self.sample_rate, self.seq_duration, return_pos=True
            )

            # Read the stems of the same song at the same start position.
            audio_sources = []
            track_path = f"{self.root}/train/{song_name}"
            for source in self.sources:
                audio = load_wav_specific_position_stereo(
                    f"{track_path}/{source}.wav",
                    self.sample_rate,
                    self.seq_duration,
                    start_position=start_pos_sec,
                )
                audio_sources.append(audio)

        else:
            # Random examples
            # Load mixture_limited (pre-generated)
            audio_path = random.choice(self.list_train_random)
            seg_name = os.path.basename(audio_path).replace(".wav", "")
            mixture_limited, _ = librosa.load(
                audio_path, sr=self.sample_rate, mono=False
            )

            # Load mixture_unlimited (from the original musdb18, using metadata)
            dict_seg_info = self.dict_train_random[seg_name]
            audio_sources = []
            for source in self.sources:
                dict_seg_source_info = dict_seg_info[source]
                audio = load_wav_specific_position_stereo(
                    f"{self.root}/train/{dict_seg_source_info['name']}/{source}.wav",
                    self.sample_rate,
                    self.seq_duration,
                    start_position=dict_seg_source_info["start_sec"],
                )

                # apply augmentations
                audio = audio * dict_seg_source_info["gain"]
                if dict_seg_source_info["channelswap"]:
                    audio = np.flip(audio, axis=0)

                audio_sources.append(audio)

        stems = np.stack(audio_sources, axis=0)
        mixture = stems.sum(axis=0)
        mixture_lufs = self.meter.integrated_loudness(mixture.T)
        if np.isinf(mixture_lufs):
            # Infinite loudness (silence): no gain can be computed.
            mixture_loudnorm = mixture
        else:
            augmented_gain = self.target_loudnorm_lufs - mixture_lufs
            mixture_loudnorm = mixture * db2linear(augmented_gain, eps=0.0)

        return mixture_limited, mixture_loudnorm
|
382 |
+
|
383 |
+
|
384 |
+
class DelimitValidDataset(MusdbValidDataset):
|
385 |
+
def __init__(
|
386 |
+
self,
|
387 |
+
target: str = "vocals",
|
388 |
+
root: str = None,
|
389 |
+
delimit_valid_root: str = None,
|
390 |
+
valid_target_lufs: float = -8.05, # From the Table 1 of the "Towards robust music source separation on loud commercial music" paper, the average loudness of commerical music.
|
391 |
+
target_loudnorm_lufs: float = -14.0,
|
392 |
+
delimit_valid_L_root: str = None, # This will be used when using the target as compressed (normal_L) mixture.
|
393 |
+
use_custom_limiter: bool = False,
|
394 |
+
custom_limiter_attack_range: list = [0.1, 10.0],
|
395 |
+
custom_limiter_release_range: list = [30.0, 200.0],
|
396 |
+
*args,
|
397 |
+
**kwargs,
|
398 |
+
) -> None:
|
399 |
+
super().__init__(target=target, root=root, *args, **kwargs)
|
400 |
+
self.delimit_valid_root = delimit_valid_root
|
401 |
+
if self.delimit_valid_root:
|
402 |
+
with open(f"{self.delimit_valid_root}/valid_loudness.json", "r") as f:
|
403 |
+
self.dict_valid_loudness = json.load(f)
|
404 |
+
self.delimit_valid_L_root = delimit_valid_L_root
|
405 |
+
if self.delimit_valid_L_root:
|
406 |
+
with open(f"{self.delimit_valid_L_root}/valid_loudness.json", "r") as f:
|
407 |
+
self.dict_valid_L_loudness = json.load(f)
|
408 |
+
|
409 |
+
self.valid_target_lufs = valid_target_lufs
|
410 |
+
self.target_loudnorm_lufs = target_loudnorm_lufs
|
411 |
+
self.meter = pyln.Meter(self.sample_rate)
|
412 |
+
self.use_custom_limiter = use_custom_limiter
|
413 |
+
|
414 |
+
if self.use_custom_limiter:
|
415 |
+
print("using Custom limiter limitaug for validation!!")
|
416 |
+
self.custom_limiter_attack_range = custom_limiter_attack_range
|
417 |
+
self.custom_limiter_release_range = custom_limiter_release_range
|
418 |
+
self.board = Pedalboard(
|
419 |
+
[
|
420 |
+
Gain(gain_db=0.0),
|
421 |
+
Compressor(
|
422 |
+
threshold_db=-10.0, ratio=4.0, attack_ms=2.0, release_ms=200.0
|
423 |
+
), # attack_ms and release_ms will be changed later.
|
424 |
+
Compressor(
|
425 |
+
threshold_db=0.0,
|
426 |
+
ratio=1000.0,
|
427 |
+
attack_ms=0.001,
|
428 |
+
release_ms=100.0,
|
429 |
+
),
|
430 |
+
Gain(gain_db=3.75),
|
431 |
+
Clipping(threshold_db=0.0),
|
432 |
+
]
|
433 |
+
) # This implementation is the same as JUCE Limiter.
|
434 |
+
# However, we want the first compressor to have a variable attack and release time.
|
435 |
+
# Therefore, we use the Custom Limiter instead of the JUCE Limiter.
|
436 |
+
else:
|
437 |
+
self.board = Pedalboard(
|
438 |
+
[Gain(gain_db=0.0), Limiter(threshold_db=0.0, release_ms=100.0)]
|
439 |
+
) # Currently, we are using a limiter with a release time of 100ms.
|
440 |
+
|
441 |
+
def __getitem__(self, index):
    """Build one validation example for the track at ``index``.

    Every stem named in ``self.sources`` is loaded and summed into a linear
    mixture.  A limited version of that mixture is then either read from the
    pre-processed set under ``self.delimit_valid_root`` or produced on the
    fly with ``apply_limitaug``.  Finally the mixture is scaled by the gain
    that separates its measured loudness from ``self.target_loudnorm_lufs``.

    Returns
    -------
    tuple
        ``(mixture_limited, mixture_loudnorm, song_name, mixture_lufs)``.
        When ``self.use_custom_limiter`` is set, the sampled attack and
        release times are appended as two extra items.
    """
    # All stems of a track sit in the same directory, so the path and the
    # song name do not depend on which source is being loaded.
    track_path = self.valid_list[index]
    song_name = os.path.basename(track_path)

    stem_tensors = []
    for source in self.sources:
        stem_path = f"{track_path}/{source}.wav"
        waveform = librosa.load(stem_path, mono=False, sr=self.sample_rate)[0]
        stem_tensors.append(torch.as_tensor(waveform, dtype=torch.float32))

    # Linear mix over the source axis (axis 0).  Here the linear mixture is
    # the target, unlike in MusdbTrainDataset.
    mixture = np.stack(stem_tensors, axis=0).sum(0)

    if self.delimit_valid_root:
        # A pre-processed delimit validation set exists: read the limited
        # mixture and its loudness instead of limiting on the fly.
        limited_path = f"{self.delimit_valid_root}/valid/{song_name}.wav"
        mixture_limited = librosa.load(
            limited_path, mono=False, sr=self.sample_rate
        )[0]
        mixture_lufs = self.dict_valid_loudness[song_name]
    else:
        if self.use_custom_limiter:
            # Randomise the attack/release of the first compressor
            # (board index 1) before limiting.
            custom_limiter_attack = random.uniform(
                self.custom_limiter_attack_range[0],
                self.custom_limiter_attack_range[1],
            )
            self.board[1].attack_ms = custom_limiter_attack

            custom_limiter_release = random.uniform(
                self.custom_limiter_release_range[0],
                self.custom_limiter_release_range[1],
            )
            self.board[1].release_ms = custom_limiter_release

        # mixture_limited is the limiter-applied mixture.  It is the model
        # input; mixture_loudnorm (computed below) is the target.
        mixture_limited, mixture_lufs = apply_limitaug(
            mixture,
            self.board,
            self.meter,
            self.sample_rate,
            target_lufs=self.valid_target_lufs,
        )

    if self.delimit_valid_L_root:
        # Replace the mixture and its loudness with the pre-processed
        # versions stored under delimit_valid_L_root.
        audio_L_path = f"{self.delimit_valid_L_root}/valid/{song_name}.wav"
        mixture = librosa.load(audio_L_path, mono=False, sr=self.sample_rate)[0]
        mixture_lufs = self.dict_valid_L_loudness[song_name]

    augmented_gain = self.target_loudnorm_lufs - mixture_lufs
    mixture_loudnorm = mixture * db2linear(augmented_gain)

    example = (mixture_limited, mixture_loudnorm, song_name, mixture_lufs)
    if self.use_custom_limiter:
        # NOTE(review): the attack/release values are only sampled when
        # delimit_valid_root is falsy; with both options enabled this line
        # raises UnboundLocalError.  Confirm the intended behaviour.
        return example + (custom_limiter_attack, custom_limiter_release)
    return example
|
532 |
+
|
533 |
+
|
534 |
+
class OzoneValidDataset(MusdbValidDataset):
    """Validation set pairing each track's stem mixture with a limited file.

    The limited mixtures and a ``valid_loudness.json`` table (song name ->
    loudness) are read from ``ozone_root``.
    """

    def __init__(
        self,
        target: str = "all",
        root: str = None,
        ozone_root: str = None,
        target_loudnorm_lufs: float = -14.0,
        *args,
        **kwargs,
    ) -> None:
        """Initialise the parent dataset and load the loudness table.

        Parameters
        ----------
        target, root
            Forwarded to the parent constructor.
        ozone_root
            Directory holding ``valid_loudness.json`` and the limited audio.
        target_loudnorm_lufs
            Loudness the returned target mixture is scaled to.
        """
        super().__init__(target=target, root=root, *args, **kwargs)

        self.ozone_root = ozone_root
        self.target_loudnorm_lufs = target_loudnorm_lufs

        loudness_path = f"{self.ozone_root}/valid_loudness.json"
        with open(loudness_path, "r") as f:
            self.dict_valid_loudness = json.load(f)

    def __getitem__(self, index):
        """Return ``(mixture_limited, mixture_loudnorm, song_name, mixture_lufs)``."""
        track_path = self.valid_list[index]
        song_name = os.path.basename(track_path)

        # Load every stem of the track and sum them into the linear mixture.
        stem_waveforms = [
            librosa.load(
                f"{track_path}/{source}.wav", mono=False, sr=self.sample_rate
            )[0]
            for source in self.sources
        ]
        mixture = np.stack(stem_waveforms, axis=0).sum(0)

        limited_path = f"{self.ozone_root}/ozone_train_fixed/{song_name}.wav"
        mixture_limited = librosa.load(
            limited_path, mono=False, sr=self.sample_rate
        )[0]

        # Scale the mixture by the gain separating its stored loudness from
        # the requested target loudness.
        mixture_lufs = self.dict_valid_loudness[song_name]
        gain_db = self.target_loudnorm_lufs - mixture_lufs
        mixture_loudnorm = mixture * db2linear(gain_db)

        return mixture_limited, mixture_loudnorm, song_name, mixture_lufs
|
dataloader/singleset.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
# Modified version from woosungchoi's original implementation
|
8 |
+
class SingleTrackSet(Dataset):
    """Split one stereo track into fixed-size, zero-padded chunks.

    Each item is a window of ``window_length`` samples: ``true_samples``
    samples taken from the track, framed by ``trim_length`` zeros on both
    sides.  The last chunk receives extra right padding so that it has the
    same length as the others.
    """

    def __init__(self, track, hop_length, num_frame=128, target_name="vocals"):
        """Store the track and derive the chunking geometry.

        Parameters
        ----------
        track
            2-D array-like whose first dimension has size 2 (stereo).
        hop_length
            Hop size that the window and trim lengths are multiples of.
        num_frame
            Number of frames per window.
        target_name
            Single target name kept in ``self.target_names``.
        """
        assert len(track.shape) == 2
        assert track.shape[0] == 2  # check stereo audio

        self.hop_length = hop_length
        # Example values below are for hop_length=1024, num_frame=128.
        self.window_length = hop_length * (num_frame - 1)  # 130048
        self.trim_length = self.get_trim_length(self.hop_length)  # 5120
        self.true_samples = self.window_length - 2 * self.trim_length  # 119808

        self.lengths = [track.shape[1]]  # track lengths (in sample level)
        self.source_names = [
            "vocals",
            "drums",
            "bass",
            "other",
        ]  # == self.musdb_train.targets_names[:-2]
        self.target_names = [target_name]
        self.num_tracks = 1

        # Chunks per track, e.g. 44.1 kHz / 180 s audio => [67].
        chunks_per_track = [
            math.ceil(n_samples / self.true_samples) for n_samples in self.lengths
        ]
        # Cumulative chunk count up to and including each track, e.g. [67].
        self.acc_chunk_final_ids = [
            sum(chunks_per_track[: track_id + 1])
            for track_id in range(self.num_tracks)
        ]

        self.cache_mode = True
        self.cache = {0: {"linear_mixture": track}}

    def __len__(self):
        """Return the total number of chunks times the number of targets."""
        return self.acc_chunk_final_ids[-1] * len(self.target_names)

    def __getitem__(self, idx):
        """Return the zero-padded chunk addressed by ``idx``."""
        chunk_idx = idx // len(self.target_names)
        mixture_idx, start_pos = self.idx_to_track_offset(chunk_idx)

        if mixture_idx is None:
            # NOTE(review): IndexError is the usual out-of-range signal for
            # __getitem__; StopIteration is kept for compatibility.
            raise StopIteration

        pad_left = pad_right = self.trim_length
        length = self.true_samples
        samples_left = self.lengths[mixture_idx] - start_pos
        if start_pos + length > self.lengths[mixture_idx]:
            # Last chunk: read to the end of the track and make up the
            # missing samples with additional right padding.
            pad_right += self.true_samples - samples_left
            length = None

        chunk = self.get_audio(mixture_idx, "linear_mixture", start_pos, length)
        return F.pad(chunk, (pad_left, pad_right), "constant", 0)

    def idx_to_track_offset(self, idx):
        """Map a global chunk index to ``(track index, sample offset)``.

        Returns ``(None, None)`` when ``idx`` lies beyond the last chunk.
        """
        previous_final = 0
        for track_id, final_id in enumerate(self.acc_chunk_final_ids):
            if idx < final_id:
                return track_id, (idx - previous_final) * self.true_samples
            previous_final = final_id

        return None, None

    def get_audio(self, idx, target_name, pos=0, length=None):
        """Slice ``length`` samples from ``pos`` (to the end if ``length`` is None)."""
        track = self.cache[idx][target_name]
        if length is None:
            return track[:, pos:]
        return track[:, pos : pos + length]

    def get_trim_length(self, hop_length, min_trim=5000):
        """Return the smallest multiple of ``hop_length`` that is >= ``min_trim``."""
        hops = math.ceil(min_trim / hop_length)
        trim_length = hops * hop_length
        assert hops > 1
        return trim_length
|
95 |
+
|
eval_delimit/calc_flops.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import random
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from deepspeed.profiling.flops_profiler import get_model_profile
|
7 |
+
|
8 |
+
from utils import get_config
|
9 |
+
from models import load_model_with_args
|
10 |
+
|
11 |
+
|
12 |
+
# def main():
|
13 |
+
# Command line: only the name of the experiment configuration is needed.
parser = argparse.ArgumentParser(description="FLOPs calculation")
parser.add_argument(
    "-c",
    "--config",
    default="delimit_6_s",
    type=str,
    help="Name of the setting file.",
)
config_args = parser.parse_args()

args = get_config(config_args.config)
print(args)

batch_size = 1
# The profiler builds a tensor of this shape and feeds it to the model as
# its only positional argument: (batch, 2 channels, 44100 * 60 samples).
profile_input_shape = (batch_size, 2, 44100 * 60)

with torch.cuda.device(0):
    model = load_model_with_args(args)
    flops, macs, params = get_model_profile(
        model=model,
        input_shape=profile_input_shape,
        args=[],  # list of positional arguments to the model.
        kwargs={},  # dictionary of keyword arguments to the model.
        print_profile=True,  # print the model graph with the measured profile attached to each module
        detailed=True,  # print the detailed profile
        module_depth=-1,  # depth into the nested modules, -1 being the innermost modules
        top_modules=1,  # number of top modules to print in the aggregated profile
        warm_up=1,  # number of warm-ups before measuring the time of each module
        as_string=True,  # raw numbers (e.g. 1000) or human-readable strings (e.g. 1k)
        output_file=None,  # path to the output file; None prints to stdout
        ignore_modules=None,  # list of modules to ignore in the profiling
    )

print(args.dir_params.exp_name)
print('flops: ', flops)
print('macs: ', macs)
print('params: ', params)
|
eval_delimit/score_calc_delimit.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Calculate SI-SDR, Multi-resolution spectrogram mse score of the pre-inferenced sources
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import csv
|
5 |
+
import json
|
6 |
+
import glob
|
7 |
+
|
8 |
+
import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import librosa
|
11 |
+
import pyloudnorm as pyln
|
12 |
+
from asteroid.metrics import get_metrics
|
13 |
+
|
14 |
+
from utils import str2bool
|
15 |
+
|
16 |
+
|
17 |
+
def multi_resolution_spectrogram_mse(
    gt, est, n_fft=[2048, 1024, 512], n_hop=[512, 256, 128]
):
    """Sum the mean squared magnitude-spectrogram error over several STFTs.

    Parameters
    ----------
    gt, est
        Reference and estimated signals; their shapes must match.
    n_fft, n_hop
        FFT sizes and hop lengths, paired element by element.  The list
        defaults are only read, never mutated.

    Returns
    -------
    float
        Sum, over all resolutions, of the mean squared difference between
        the two magnitude spectrograms.
    """
    assert gt.shape == est.shape
    assert len(n_fft) == len(n_hop)

    total = 0.0
    for fft_size, hop_size in zip(n_fft, n_hop):
        gt_mag, _ = librosa.magphase(
            librosa.stft(gt, n_fft=fft_size, hop_length=hop_size)
        )
        est_mag, _ = librosa.magphase(
            librosa.stft(est, n_fft=fft_size, hop_length=hop_size)
        )
        total = total + np.mean((gt_mag - est_mag) ** 2)

    return total
|
34 |
+
|
35 |
+
|
36 |
+
parser = argparse.ArgumentParser(description="model test.py")

parser.add_argument(
    "--target",
    type=str,
    default="all",
    help="target source. all, vocals, drums, bass, other, 0.5_mixed",
)
parser.add_argument("--root", type=str, default="/path/to/musdb18hq_loudnorm")
parser.add_argument("--exp_name", type=str, default="convtasnet_6_s")
parser.add_argument(
    "--output_directory",
    type=str,
    default="/path/to/results",
)
parser.add_argument("--loudnorm_lufs", type=float, default=-14.0)
parser.add_argument(
    "--calc_mse",
    type=str2bool,
    default=True,
    help="calculate multi-resolution spectrogram mse",
)
parser.add_argument(
    "--calc_results",
    type=str2bool,
    default=True,
    help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
)

args, _ = parser.parse_known_args()

args.sample_rate = 44100

meter = pyln.Meter(args.sample_rate)

# Estimates live under ".../test/<exp_name>" for test-set results and
# directly under ".../<exp_name>" otherwise.
if args.calc_results:
    args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
else:
    args.test_output_dir = f"{args.output_directory}/{args.exp_name}"

# The reference is the full mixture for "all" and "0.5_mixed", and the
# individual stem for any other target.
if args.target in ("all", "0.5_mixed"):
    test_tracks = glob.glob(f"{args.root}/*/mixture.wav")
else:
    test_tracks = glob.glob(f"{args.root}/*/{args.target}.wav")

dict_song_score = {}
list_si_sdr = []
list_multi_mse = []
for track in tqdm.tqdm(test_tracks):
    audio_name = os.path.basename(os.path.dirname(track))
    gt_source = librosa.load(track, sr=args.sample_rate, mono=False)[0]

    # The estimate is always stored as "<target>.wav" ("all.wav" for the
    # standard de-limiter estimation).
    est_delimiter = librosa.load(
        f"{args.test_output_dir}/{audio_name}/{args.target}.wav",
        sr=args.sample_rate,
        mono=False,
    )[0]

    metrics_dict = get_metrics(
        gt_source + est_delimiter,
        gt_source,
        est_delimiter,
        sample_rate=args.sample_rate,
        metrics_list=["si_sdr"],
    )
    si_sdr = metrics_dict["si_sdr"]

    multi_mse = (
        multi_resolution_spectrogram_mse(gt_source, est_delimiter)
        if args.calc_mse
        else None
    )

    dict_song_score[audio_name] = {
        "si_sdr": si_sdr,
        "multi_mse": multi_mse,
    }
    list_si_sdr.append(si_sdr)
    list_multi_mse.append(multi_mse)

print(f"{args.exp_name} on {args.target}")
print(f"SI-SDR score: {sum(list_si_sdr) / len(list_si_sdr)}")
if args.calc_mse:
    print(f"multi-mse score: {sum(list_multi_mse) / len(list_multi_mse)}")

# Save the per-song scores to a json file.
if args.target != "all":
    score_path = f"{args.test_output_dir}/score_{args.target}.json"
else:
    score_path = f"{args.test_output_dir}/score.json"
with open(score_path, "w") as f:
    json.dump(dict_song_score, f, indent=4)
|
eval_delimit/score_diff_dyn_complexity.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import csv
|
4 |
+
import json
|
5 |
+
import glob
|
6 |
+
|
7 |
+
import tqdm
|
8 |
+
import numpy as np
|
9 |
+
import librosa
|
10 |
+
import musdb
|
11 |
+
import pyloudnorm as pyln
|
12 |
+
|
13 |
+
from utils import str2bool, db2linear
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser(description="model test.py")

parser.add_argument(
    "--target",
    type=str,
    default="all",
    help="target source. all, vocals, bass, drums, other.",
)
parser.add_argument(
    "--root",
    type=str,
    default="/path/to/musdb18hq_loudnorm",
)
parser.add_argument(
    "--output_directory",
    type=str,
    default="/path/to/results",
)
parser.add_argument("--exp_name", type=str, default="convtasnet_6_s")
parser.add_argument(
    "--calc_results",
    type=str2bool,
    default=True,
    help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
)

args, _ = parser.parse_known_args()

args.sample_rate = 44100
meter = pyln.Meter(args.sample_rate)

if args.calc_results:
    args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
else:
    args.test_output_dir = f"{args.output_directory}/{args.exp_name}"


def _load_feature_scores(directory, file_names):
    """Load the first existing json file among ``file_names`` in ``directory``.

    Raises FileNotFoundError when none of the candidates exists.
    """
    for file_name in file_names:
        path = f"{directory}/{file_name}"
        if os.path.exists(path):
            # The context manager closes the handle; the previous bare
            # open() calls were never closed.
            with open(path, encoding="UTF-8") as f:
                return json.load(f)
    raise FileNotFoundError(
        f"None of {list(file_names)} found in {directory}. "
        "Run score_features.py first."
    )


# score_features.py writes "score_features.json" for target "all" and
# "score_feature_<target>.json" otherwise.  The names this script used to
# read are tried first so existing result folders keep working; the name
# the writer actually produces is the fallback.
if args.target == "all":
    written_name = "score_features.json"
    ref_legacy_name = "score_feature.json"
    ref_track_list = glob.glob(f"{args.root}/*/mixture.wav")
else:
    written_name = f"score_feature_{args.target}.json"
    ref_legacy_name = written_name
    ref_track_list = glob.glob(f"{args.root}/*/{args.target}.wav")
est_legacy_name = f"score_feature_{args.target}.json"

dict_song_score_est = _load_feature_scores(
    args.test_output_dir, (est_legacy_name, written_name)
)
dict_song_score_ref = _load_feature_scores(
    args.root, (ref_legacy_name, written_name)
)

# Per-song difference (estimate minus reference) of the dynamic complexity.
list_diff_dynamic_complexity = []
for track in tqdm.tqdm(ref_track_list):
    audio_name = os.path.basename(os.path.dirname(track))
    ref_dyn_complexity = dict_song_score_ref[audio_name]["dynamic_complexity_score"]
    est_dyn_complexity = dict_song_score_est[audio_name]["dynamic_complexity_score"]

    list_diff_dynamic_complexity.append(est_dyn_complexity - ref_dyn_complexity)

print(
    f"Dynamic complexity difference {args.exp_name} vs {os.path.basename(args.root)} on {args.target}"
)
print("mean: ", np.mean(list_diff_dynamic_complexity))
print("median: ", np.median(list_diff_dynamic_complexity))
print("std: ", np.std(list_diff_dynamic_complexity))
|
eval_delimit/score_fad.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# We are going to use FAD based on https://github.com/gudgud96/frechet-audio-distance
|
2 |
+
import os
|
3 |
+
import subprocess
|
4 |
+
import glob
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
from frechet_audio_distance import FrechetAudioDistance
|
8 |
+
|
9 |
+
from utils import str2bool
|
10 |
+
|
11 |
+
|
12 |
+
parser = argparse.ArgumentParser(description="model test.py")

parser.add_argument(
    "--target",
    type=str,
    default="all",
    help="target source. all, vocals, drums, bass, other",
)
parser.add_argument(
    "--root",
    type=str,
    default="/path/to/musdb18hq_loudnorm",
)
parser.add_argument(
    "--output_directory",
    type=str,
    default="/path/to/results",
)
parser.add_argument("--exp_name", type=str, default="delimit_6_s")
parser.add_argument(
    "--calc_results",
    type=str2bool,
    default=True,
    help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
)

args, _ = parser.parse_known_args()


def _link_tracks(track_paths, link_dir):
    """Create ``<link_dir>/<song name>.wav`` symlinks pointing at each track.

    The song name is the name of the directory that contains the track.
    """
    os.makedirs(link_dir, exist_ok=True)
    for track_path in track_paths:
        song_name = os.path.basename(os.path.dirname(track_path))
        link_path = f"{link_dir}/{song_name}.wav"
        try:
            # os.symlink replaces the former `ln --symbolic` shell command:
            # no shell quoting problems with unusual file names and no
            # dependency on the GNU-only long option.
            os.symlink(track_path, link_path)
        except FileExistsError:
            # A link left by a previous run is kept untouched, which is
            # what the failing `ln` call effectively did.
            pass


# Flatten the reference tracks into a single directory of links.
reference_link_dir = f"{args.root}/musdb_hq_loudnorm_16k_mono_link"
_link_tracks(
    glob.glob(f"{args.root}/musdb_hq_loudnorm_16k_mono/*/mixture.wav"),
    reference_link_dir,
)

if args.calc_results:
    args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
else:
    args.test_output_dir = f"{args.output_directory}/{args.exp_name}"

# Same for the estimated tracks of the requested target.
estimate_link_dir = f"{args.test_output_dir}_16k_mono_link"
_link_tracks(
    glob.glob(f"{args.test_output_dir}_16k_mono/*/{args.target}.wav"),
    estimate_link_dir,
)

frechet = FrechetAudioDistance()

fad_score = frechet.score(reference_link_dir, estimate_link_dir)

print(f"{args.exp_name}")
print(f"FAD score: {fad_score}")
|
eval_delimit/score_features.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import csv
|
4 |
+
import json
|
5 |
+
import glob
|
6 |
+
from typing import Any, Optional, Union, Collection
|
7 |
+
|
8 |
+
import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import librosa
|
11 |
+
from librosa.core.spectrum import _spectrogram
|
12 |
+
import musdb
|
13 |
+
import essentia
|
14 |
+
import essentia.standard
|
15 |
+
import pyloudnorm as pyln
|
16 |
+
|
17 |
+
from utils import str2bool, db2linear
|
18 |
+
|
19 |
+
|
20 |
+
def spectral_crest(
    *,
    y: Optional[np.ndarray] = None,
    S: Optional[np.ndarray] = None,
    n_fft: int = 2048,
    hop_length: int = 512,
    win_length: Optional[int] = None,
    window: str = "hann",
    center: bool = True,
    pad_mode: str = "constant",
    amin: float = 1e-10,
    power: float = 2.0,
) -> np.ndarray:
    """Compute the spectral crest of each frame.

    The spectral crest is the ratio of the maximum of the spectrum to its
    arithmetic mean.  A higher value indicates a more tonal frame, a lower
    value a noisier one.

    Parameters
    ----------
    y : np.ndarray [shape=(..., n)] or None
        Audio time series.  Multi-channel is supported.
    S : np.ndarray [shape=(..., d, t)] or None
        (optional) Pre-computed magnitude spectrogram.
    n_fft : int > 0
        FFT window size.
    hop_length : int > 0
        Hop length for the STFT.  See `librosa.stft` for details.
    win_length : int <= n_fft or None
        Window length; the window is zero-padded to ``n_fft``.
        If unspecified, defaults to ``win_length = n_fft``.
    window : string, tuple, number, function, or np.ndarray [shape=(n_fft,)]
        Window specification.  See `librosa.filters.get_window`.
    center : bool
        If `True`, frame ``t`` is centered at ``y[t * hop_length]``;
        if `False`, frame ``t`` begins at ``y[t * hop_length]``.
    pad_mode : string
        Padding mode used at the signal edges when ``center=True``.
    amin : float > 0
        Floor applied to the spectrogram raised to ``power``, for
        numerical stability.
    power : float > 0
        Exponent applied to the magnitude spectrogram before the
        max/mean ratio is taken.

    Returns
    -------
    crest : np.ndarray [shape=(..., 1, t)]
        Spectral crest for each frame.
    """
    # The magnitude (power=1.0) spectrogram is requested here; the
    # user-facing exponent is applied explicitly below.
    magnitude, n_fft = _spectrogram(
        y=y,
        S=S,
        n_fft=n_fft,
        hop_length=hop_length,
        power=1.0,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
    )

    floored = np.maximum(amin, magnitude**power)
    # Reduce over the frequency axis (second to last), keeping it as size 1.
    frame_peak = np.max(floored, axis=-2, keepdims=True)
    frame_mean = np.mean(floored, axis=-2, keepdims=True)
    crest: np.ndarray = frame_peak / frame_mean
    return crest
|
103 |
+
|
104 |
+
|
105 |
+
parser = argparse.ArgumentParser(description="model test.py")

parser.add_argument(
    "--target",
    type=str,
    default="all",
    help="target source. all, vocals, drums, bass, other",
)
parser.add_argument(
    "--root", type=str, default="/path/to/musdb18hq_loudnorm"
)
parser.add_argument("--exp_name", type=str, default="delimit_6_s")
parser.add_argument(
    "--output_directory",
    type=str,
    default="/path/to/results",
)
parser.add_argument(
    "--calc_results",
    type=str2bool,
    default=True,
    help="calculate results or musdb-hq or musdb-XL test dataset",
)


args, _ = parser.parse_known_args()

args.sample_rate = 44100

args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"

# With calc_results the model outputs are analysed; otherwise the dataset
# under args.root is analysed (the mixture when the target is "all").
if args.calc_results:
    track_list = glob.glob(
        f"{args.output_directory}/test/{args.exp_name}/*/{args.target}.wav"
    )
else:
    if args.target == "all":
        track_list = glob.glob(f"{args.root}/*/mixture.wav")
    else:
        track_list = glob.glob(f"{args.root}/*/{args.target}.wav")


# Essentia algorithm instances.  crest, dynamic_spread and central_moments
# are created but not used below.
dynamic_complexity = essentia.standard.DynamicComplexity()
loudness_range = essentia.standard.LoudnessEBUR128()
spectral_centroid = essentia.standard.SpectralCentroidTime()
crest = essentia.standard.Crest()
dynamic_spread = essentia.standard.DistributionShape()
central_moments = essentia.standard.CentralMoments()

dict_song_score = {}
list_rms = []
list_crest_factor = []
list_dc_score = []
list_lra_score = []
list_sc_hertz = []
list_sf_score = []
list_spectral_crest_score = []

for track in tqdm.tqdm(track_list):
    audio_name = os.path.basename(os.path.dirname(track))
    gt_source_librosa = librosa.load(f"{track}", sr=args.sample_rate, mono=False)[
        0
    ]  # (nb_channels, nb_samples)
    gt_source_librosa_mono = librosa.to_mono(gt_source_librosa)  # (nb_samples)

    gt_source_essentia = essentia.standard.AudioLoader(filename=f"{track}")()[
        0
    ]  # (nb_samples, nb_channels)
    gt_source_essentia_cat = np.concatenate(
        [gt_source_essentia[:, 0], gt_source_essentia[:, 1]]
    )  # (nb_samples * nb_channels)
    gt_source_essentia_mono = np.mean(gt_source_essentia, axis=1)  # (nb_samples)

    # RMS and crest factor are computed over both channels concatenated.
    rms = np.sqrt(np.mean(gt_source_essentia_cat**2))
    crest_factor = np.max(np.abs(gt_source_essentia_cat)) / rms

    dc_score, _ = dynamic_complexity(gt_source_essentia_mono)
    _, _, _, lra_score = loudness_range(gt_source_essentia)
    sc_hertz = spectral_centroid(gt_source_essentia_mono)
    # The signal is passed by keyword: recent librosa releases accept the
    # arguments of spectral_flatness by keyword only.
    sf_score = np.mean(
        librosa.feature.spectral_flatness(y=gt_source_librosa_mono)
    )
    spectral_crest_score = np.mean(spectral_crest(y=gt_source_librosa_mono))

    dict_song_score[audio_name] = {
        "rms": float(rms),
        "crest_factor": float(crest_factor),
        "dynamic_complexity_score": float(dc_score),
        "lra_score": float(lra_score),
        "spectral_centroid_hertz": float(sc_hertz),
        "spectral_flatness_score": float(sf_score),
        "spectral_crest_score": float(spectral_crest_score),
    }
    list_rms.append(rms)
    list_crest_factor.append(crest_factor)
    list_dc_score.append(dc_score)
    list_lra_score.append(lra_score)
    list_sc_hertz.append(sc_hertz)
    list_sf_score.append(sf_score)
    list_spectral_crest_score.append(spectral_crest_score)

if args.calc_results:
    print(f"{args.exp_name} on {args.target}")
else:
    print(f"{os.path.basename(args.root)} on {args.target}")
print(f"rms: {np.mean(list_rms)}")
print(f"crest_factor: {np.mean(list_crest_factor)}")
print(f"dynamic_complexity_score: {np.mean(list_dc_score)}")
print(f"lra_score: {np.mean(list_lra_score)}")
print(f"sc_hertz: {np.mean(list_sc_hertz)}")
print(f"sf_score: {np.mean(list_sf_score)}")
print(f"spectral_crest_score: {np.mean(list_spectral_crest_score)}")


# Save dict_song_score to a json file, next to the analysed audio.
if args.target == "all":
    file_name = "score_features"
else:
    file_name = f"score_feature_{args.target}"
if args.calc_results:
    with open(
        f"{args.output_directory}/test/{args.exp_name}/{file_name}.json", "w"
    ) as f:
        json.dump(dict_song_score, f, indent=4)
else:
    with open(f"{args.root}/{file_name}.json", "w") as f:
        json.dump(dict_song_score, f, indent=4)
|
eval_delimit/score_peaq.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# We are going to use PEAQ based on https://github.com/HSU-ANT/gstpeaq
|
2 |
+
|
3 |
+
"""
|
4 |
+
python3 score_peaq.py --exp_name=delimit_6_s | tee /path/to/results/delimit_6_s/score_peaq.txt
|
5 |
+
"""
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
import os
|
10 |
+
import subprocess
|
11 |
+
import glob
|
12 |
+
import argparse
|
13 |
+
|
14 |
+
|
15 |
+
def str2bool(v):
    """Convert a command-line string into a boolean (argparse ``type=`` helper).

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string such as
            "yes"/"no", "true"/"false", "t"/"f", "y"/"n", "1"/"0"
            (case-insensitive).

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised boolean.
    """
    # A real bool has no .lower(); pass it through so the helper can also be
    # called directly with True/False.
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
22 |
+
|
23 |
+
|
24 |
+
parser = argparse.ArgumentParser(description="model test.py")

parser.add_argument(
    "--target",
    type=str,
    default="all",
    help="target source. all, vocals, drums, bass, other",
)
parser.add_argument(
    "--root",
    type=str,
    default="/path/to/musdb_XL_loudnorm",
)
parser.add_argument(
    "--output_directory",
    type=str,
    default="/path/to/results/",
)
parser.add_argument("--exp_name", type=str, default="delimit_6_s")
parser.add_argument(
    "--calc_results",
    type=str2bool,
    default=True,
    help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
)

args, _ = parser.parse_known_args()

# Estimates live under "<output>/test/<exp>" for model results and directly
# under "<output>/<exp>" otherwise.
if args.calc_results:
    args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
else:
    args.test_output_dir = f"{args.output_directory}/{args.exp_name}"

# For target "all" the reference file is the mixture; for a stem target it is
# the stem itself. The estimate is always named after the target.
reference_name = "mixture" if args.target == "all" else args.target

# Sorted so that the order of the printed scores is deterministic.
song_list = sorted(glob.glob(f"{args.root}/*/{reference_name}.wav"))

for song in song_list:
    song_name = os.path.basename(os.path.dirname(song))
    est_path = f"{args.test_output_dir}/{song_name}/{args.target}.wav"
    # Pass an argument list instead of a shell string so that song names
    # containing quotes, "$" or backticks reach peaq unmodified.
    # NOTE: unlike the shell form, a missing `peaq` binary now raises
    # FileNotFoundError instead of only printing a shell error.
    subprocess.run(
        [
            "peaq",
            "--gst-plugin-load=/usr/local/lib/gstreamer-1.0/libgstpeaq.so",
            song,
            est_path,
        ],
    )
|
eval_delimit/score_peaq_aggregate.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PEAQ aggregate score
|
2 |
+
"""
|
3 |
+
/path/to/results/delimit_6_s/score_peaq.txt
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import glob
|
8 |
+
import argparse
|
9 |
+
import json
|
10 |
+
|
11 |
+
|
12 |
+
def str2bool(v):
    """Parse a boolean command-line value for use as an argparse ``type=``.

    Args:
        v: A ``bool`` (returned as-is) or one of the case-insensitive strings
            "yes", "true", "t", "y", "1", "no", "false", "f", "n", "0".

    Returns:
        bool: ``True`` or ``False`` according to ``v``.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is a string that does not name a
            boolean.
    """
    # Booleans have no .lower(); accept them directly so the function is safe
    # to call with an already-parsed value.
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
19 |
+
|
20 |
+
|
21 |
+
parser = argparse.ArgumentParser(description="model test.py")

parser.add_argument(
    "--target",
    type=str,
    default="all",
    help="target source. all, vocals, drums, bass, other",
)
parser.add_argument(
    "--root",
    type=str,
    default="/path/to/musdb18hq_loudnorm",
)
parser.add_argument(
    "--output_directory",
    type=str,
    default="/path/to/results",
)
parser.add_argument("--exp_name", type=str, default="delimit_6_s")
parser.add_argument(
    "--calc_results",
    type=str2bool,
    default=True,
    help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
)

args, _ = parser.parse_known_args()


if args.calc_results:
    args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
else:
    args.test_output_dir = f"{args.output_directory}/{args.exp_name}"

# Input text file and output JSON file share the same stem.
if args.target == "all":
    score_stem = "score_peaq"
else:
    score_stem = f"score_peaq_{args.target}"
score_path = f"{args.test_output_dir}/{score_stem}.txt"

with open(score_path, "r") as f:
    score_txt = f.readlines()

# Keep only the grade lines. The file is expected to alternate between this
# line and one other line per song, but selecting by prefix also tolerates
# stray lines in the captured output.
ODG_PREFIX = "Objective Difference Grade: "
odg_lines = [line for line in score_txt if line.startswith(ODG_PREFIX)]

# The scores were written in the sorted order of "<root>/<song>/<ref>.wav"
# paths, while glob returns entries in arbitrary order. Keep directories only
# (root may also contain result files) and sort with a trailing "/" so the
# order matches a sort of the full reference paths.
song_list = sorted(
    (path for path in glob.glob(f"{args.root}/*") if os.path.isdir(path)),
    key=lambda path: f"{path}/",
)

if not song_list:
    raise ValueError(f"No song directories found under {args.root}")
if len(odg_lines) != len(song_list):
    # A count mismatch means scores would be attached to the wrong songs.
    raise ValueError(
        f"Found {len(odg_lines)} PEAQ scores in {score_path} "
        f"but {len(song_list)} songs under {args.root}"
    )

dict_song_peaq = {}
list_peaq = []
for song, odg_line in zip(song_list, odg_lines):
    song_name = os.path.basename(song)
    peaq = float(odg_line.replace(ODG_PREFIX, ""))
    dict_song_peaq[song_name] = peaq
    list_peaq.append(peaq)

print(f"{args.exp_name} on {args.target}")
print(f"PEAQ score: {sum(list_peaq) / len(list_peaq)}")

# save dict_song_peaq to json file
with open(f"{args.test_output_dir}/{score_stem}.json", "w") as f:
    json.dump(dict_song_peaq, f, indent=4)
|
inference.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import glob
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import tqdm
|
8 |
+
import librosa
|
9 |
+
import soundfile as sf
|
10 |
+
import pyloudnorm as pyln
|
11 |
+
from dotmap import DotMap
|
12 |
+
|
13 |
+
from models import load_model_with_args
|
14 |
+
from separate_func import (
|
15 |
+
conv_tasnet_separate,
|
16 |
+
)
|
17 |
+
from utils import str2bool, db2linear
|
18 |
+
|
19 |
+
|
20 |
+
tqdm.monitor_interval = 0
|
21 |
+
|
22 |
+
|
23 |
+
def separate_track_with_model(
    args, model, device, track_audio, track_name, meter, augmented_gain
):
    """Run the model on one track with gradient tracking disabled.

    Dispatches on ``args.model_loss_params.architecture``; both supported
    architectures are handled by ``conv_tasnet_separate``.

    Args:
        args: Configuration object; ``args.model_loss_params.architecture``
            selects the separation routine.
        model: Model forwarded to the separation routine.
        device: Device forwarded to the separation routine.
        track_audio: Input audio forwarded to the separation routine.
        track_name: Track name forwarded to the separation routine.
        meter: Loudness meter forwarded as ``meter=``.
        augmented_gain: Gain value (or ``None``) forwarded as
            ``augmented_gain=``.

    Returns:
        Whatever ``conv_tasnet_separate`` returns for this track.

    Raises:
        ValueError: If the configured architecture is not supported.
    """
    architecture = args.model_loss_params.architecture
    # Fail with a clear message instead of an UnboundLocalError on the
    # return value when the architecture is unknown.
    if architecture not in ("conv_tasnet_mask_on_output", "conv_tasnet"):
        raise ValueError(f"Unsupported architecture: {architecture!r}")

    with torch.no_grad():
        estimates = conv_tasnet_separate(
            args,
            model,
            device,
            track_audio,
            track_name,
            meter=meter,
            augmented_gain=augmented_gain,
        )

    return estimates
|
42 |
+
|
43 |
+
|
44 |
+
def main():
    """Run inference on every ``.wav`` / ``.mp3`` file found in ``--data_root``.

    Loads ``<weight_directory>/<target>.json`` for the model configuration and
    ``<weight_directory>/<target>.pth`` for the weights, optionally
    loudness-normalises each input, runs the model and, when
    ``--save_mixed_output`` is given, writes a blend of the original and the
    estimate to ``<output_directory>/<track>/<track>_mixed.wav``.

    Raises:
        ValueError: If an input file is not sampled at 44100 Hz.
    """
    parser = argparse.ArgumentParser(description="model test.py")
    parser.add_argument("--target", type=str, default="all")
    parser.add_argument("--data_root", type=str, default="./input_data")
    parser.add_argument("--weight_directory", type=str, default="./weight")
    parser.add_argument("--output_directory", type=str, default="./output")
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--save_name_as_target", type=str2bool, default=False)
    parser.add_argument(
        "--loudnorm_input_lufs",
        type=float,
        default=None,
        help="If you want to use loudnorm for input",
    )
    parser.add_argument(
        "--save_output_loudnorm",
        type=float,
        default=-14.0,
        help="Save loudness normalized outputs or not. If you want to save, input target loudness",
    )
    parser.add_argument(
        "--save_mixed_output",
        type=float,
        default=None,
        help="Save original+delimited-estimation mixed output with a ratio of default 0.5 (orginal) and 1 - 0.5 (estimation)",
    )
    parser.add_argument(
        "--save_16k_mono",
        type=str2bool,
        default=False,
        help="Save 16k mono wav files for FAD evaluation.",
    )
    parser.add_argument(
        "--save_histogram",
        type=str2bool,
        default=False,
        help="Save histogram of the output. Only valid when the task is 'delimit'",
    )
    parser.add_argument(
        "--use_singletrackset",
        type=str2bool,
        default=False,
        help="Use SingleTrackSet if input data is too long.",
    )

    args, _ = parser.parse_known_args()

    with open(f"{args.weight_directory}/{args.target}.json", "r") as f:
        args_dict = json.load(f)
        args_dict = DotMap(args_dict)

    # Fill in settings stored next to the weights; anything already defined
    # on the command-line namespace takes precedence.
    for key, value in args_dict["args"].items():
        if key not in vars(args):
            setattr(args, key, value)

    args.test_output_dir = f"{args.output_directory}"
    os.makedirs(args.test_output_dir, exist_ok=True)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
    )

    ###################### Define Models ######################
    our_model = load_model_with_args(args)
    our_model = our_model.to(device)

    target_model_path = f"{args.weight_directory}/{args.target}.pth"
    checkpoint = torch.load(target_model_path, map_location=device)
    our_model.load_state_dict(checkpoint)

    our_model.eval()

    meter = pyln.Meter(44100)

    test_tracks = glob.glob(f"{args.data_root}/*.wav") + glob.glob(
        f"{args.data_root}/*.mp3"
    )

    for track in tqdm.tqdm(test_tracks):
        # Strip only the file extension; a ".wav"/".mp3" substring elsewhere
        # in the file name must be preserved.
        track_name = os.path.splitext(os.path.basename(track))[0]
        track_audio, sr = librosa.load(track, sr=None, mono=False)  # sr should be 44100

        if sr != 44100:
            raise ValueError("Sample rate should be 44100")

        orig_audio = track_audio.copy()

        augmented_gain = None
        print("Now De-limiting : ", track_name)

        # Compare against None so that an explicit 0.0 LUFS target is honoured.
        if args.loudnorm_input_lufs is not None:  # If you want to use loud-normalized input
            track_lufs = meter.integrated_loudness(track_audio.T)
            augmented_gain = args.loudnorm_input_lufs - track_lufs
            track_audio = track_audio * db2linear(augmented_gain, eps=0.0)

        track_audio = (
            torch.as_tensor(track_audio, dtype=torch.float32).unsqueeze(0).to(device)
        )

        estimates = separate_track_with_model(
            args, our_model, device, track_audio, track_name, meter, augmented_gain
        )

        # Compare against None so that a ratio of 0.0 is honoured.
        if args.save_mixed_output is not None:
            # Bring the original to the output loudness before blending.
            track_lufs = meter.integrated_loudness(orig_audio.T)
            output_gain = args.save_output_loudnorm - track_lufs
            orig_audio = orig_audio * db2linear(output_gain, eps=0.0)

            mixed_output = orig_audio * args.save_mixed_output + estimates * (
                1 - args.save_mixed_output
            )

            # Make sure the per-track directory exists before writing into it.
            mixed_dir = f"{args.test_output_dir}/{track_name}"
            os.makedirs(mixed_dir, exist_ok=True)
            sf.write(
                f"{mixed_dir}/{track_name}_mixed.wav",
                mixed_output.T,
                args.data_params.sample_rate,
            )


if __name__ == "__main__":
    main()
|
main_ddp.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import random
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from train_ddp import train
|
8 |
+
from utils import get_config
|
9 |
+
|
10 |
+
|
11 |
+
def main():
    """Load the named configuration, prepare output folders and start training.

    Sets ``MASTER_ADDR`` / ``MASTER_PORT`` in the environment, creates the
    ``img_check`` and ``checkpoint`` directories under the configured output
    directory, seeds ``torch`` and ``random`` from ``args.sys_params.seed`` and
    calls ``train(args)``.
    """
    parser = argparse.ArgumentParser(description="Trainer")

    # Put every argument in './configs/yymmdd_architecture_number.yaml' and load it.
    parser.add_argument(
        "-c",
        "--config",
        default="delimit_6_s",
        type=str,
        help="Name of the setting file.",
    )

    config_args = parser.parse_args()

    args = get_config(config_args.config)

    args.img_check = (
        f"{args.dir_params.output_directory}/img_check/{args.dir_params.exp_name}"
    )
    args.output = (
        f"{args.dir_params.output_directory}/checkpoint/{args.dir_params.exp_name}"
    )

    # Set which devices to use
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    # Pick the port from an unprivileged range: ports below 1024 normally need
    # root to bind and port 0 is not a usable fixed port. This is drawn before
    # the RNG is seeded below, so it differs between runs.
    os.environ["MASTER_PORT"] = str(random.randint(10000, 32767))

    os.makedirs(args.img_check, exist_ok=True)
    os.makedirs(args.output, exist_ok=True)

    torch.manual_seed(args.sys_params.seed)
    random.seed(args.sys_params.seed)

    print(args)
    train(args)


if __name__ == "__main__":
    main()
|
models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .load_models import load_model_with_args
|
models/base_models.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from asteroid.models.base_models import (
|
4 |
+
BaseEncoderMaskerDecoder,
|
5 |
+
_unsqueeze_to_3d,
|
6 |
+
_shape_reconstructed,
|
7 |
+
)
|
8 |
+
from asteroid.utils.torch_utils import pad_x_to_y, jitable_shape
|
9 |
+
from einops import rearrange
|
10 |
+
|
11 |
+
|
12 |
+
class BaseEncoderMaskerDecoderWithConfigs(BaseEncoderMaskerDecoder):
    """Encoder/masker/decoder wrapper whose three stages can be switched off.

    Keyword Args:
        use_encoder (bool): Run the encoder on the input. Defaults to True.
        apply_mask (bool): Apply the masker output to the encoded input.
            Defaults to True.
        use_decoder (bool): Decode the masked representation. Defaults to True.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)

    def forward(self, wav):
        """
        Enc/Mask/Dec model forward with some additional options.
        Some of the models we use, like TFC-TDF-UNet, have no masker.
        In UMX or X-UMX, they already use masking in their model implementation.
        Since we do not want to manipulate the model codes, we use this wrapper.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            When the decoder is used, a tuple
            ``(masked_tf_rep, reconstructed)``; otherwise only the masked
            representation.
        """
        # Keep the incoming shape so the output can be reshaped to match.
        input_shape = jitable_shape(wav)
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)

        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav
        est_masks = self.forward_masker(tf_rep)

        # When masking is disabled the model already used masking internally.
        masked_tf_rep = (
            self.apply_masks(tf_rep, est_masks) if self.apply_mask else est_masks
        )

        if not self.use_decoder:
            # In UMX or X-UMX, decoder is not used
            return masked_tf_rep

        reconstructed = pad_x_to_y(self.forward_decoder(masked_tf_rep), wav)
        return masked_tf_rep, _shape_reconstructed(reconstructed, input_shape)
|
60 |
+
|
61 |
+
|
62 |
+
class BaseEncoderMaskerDecoder_mixture_consistency(BaseEncoderMaskerDecoder):
    """Encoder/masker/decoder model whose outputs are made mixture consistent."""

    def __init__(self, encoder, masker, decoder, encoder_activation=None):
        super().__init__(encoder, masker, decoder, encoder_activation)

    def forward(self, wav):
        """Enc/Mask/Dec model forward with mixture consistent output

        The residual between the input and the sum of the estimates is spread
        equally over the estimates, so that they add up to the input.

        References:
            [1] : Wisdom, Scott, et al. "Differentiable consistency constraints for improved deep speech enhancement." ICASSP 2019.
            [2] : Wisdom, Scott, et al. "Unsupervised sound separation using mixture invariant training." NeurIPS 2020.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
        """
        # Remember shape to shape reconstruction, cast to Tensor for torchscript
        shape = jitable_shape(wav)
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)

        # Real forward
        tf_rep = self.forward_encoder(wav)
        est_masks = self.forward_masker(tf_rep)
        masked_tf_rep = self.apply_masks(tf_rep, est_masks)
        decoded = self.forward_decoder(masked_tf_rep)

        # Apply the consistency projection while the tensor is still 3D, so
        # that dim 1 is the source axis regardless of the input rank
        # (assumes the decoder output is (batch, n_src, time) -- TODO confirm).
        reconstructed = pad_x_to_y(decoded, wav)
        reconstructed = reconstructed + 1 / reconstructed.shape[1] * (
            wav - reconstructed.sum(dim=1, keepdim=True)
        )

        # Only now restore the rank of the original input.
        return _shape_reconstructed(reconstructed, shape)
|
97 |
+
|
98 |
+
|
99 |
+
class BaseEncoderMaskerDecoderWithConfigsMaskOnOutput(BaseEncoderMaskerDecoder):
    """Encoder/masker/decoder model that multiplies the input waveform by a decoded mask.

    Keyword Args:
        use_encoder (bool): Run the encoder on the input. Defaults to True.
        apply_mask (bool): Stored on the instance; not read in ``forward``.
            Defaults to True.
        use_decoder (bool): Decode the mask and apply it to the waveform.
            Defaults to True.
        nb_channels (int): Number of input channels. Defaults to 2.
        decoder_activation: One of "sigmoid", "relu", "relu6", "tanh", "none";
            any other value falls back to sigmoid. Defaults to "sigmoid".
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)
        self.nb_channels = kwargs.get("nb_channels", 2)
        self.decoder_activation = kwargs.get("decoder_activation", "sigmoid")

        # Matched with == (not a dict lookup) so that values which are not
        # hashable simply fall through to the sigmoid default.
        known_activations = (
            ("sigmoid", nn.Sigmoid),
            ("relu", nn.ReLU),
            ("relu6", nn.ReLU6),
            ("tanh", nn.Tanh),
            ("none", nn.Identity),
        )
        activation_cls = nn.Sigmoid
        for activation_name, candidate_cls in known_activations:
            if self.decoder_activation == activation_name:
                activation_cls = candidate_cls
                break
        self.act_after_dec = activation_cls()

    def forward(self, wav):
        """
        For the De-limit task, we will apply the mask on the output of the decoder.
        We want decoder to learn the sample-wise ratio of the sources.

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            With the decoder enabled, a tuple ``(est_masks_decoded, decoded)``
            where ``decoded`` is the input multiplied by the activated mask;
            otherwise a one-element tuple holding the masker output.
        """
        # Computed as in the reference forward pass; the value itself is not
        # used for reshaping here.
        shape = jitable_shape(wav)
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)  # (batch, n_channels, time)

        # (batch, n_channels, freq, time) when encoded
        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav

        is_stereo = self.nb_channels == 2
        if is_stereo:
            # c == 2 when stereo input.
            tf_rep = rearrange(tf_rep, "b c f t -> b (c f) t")
        est_masks = self.forward_masker(tf_rep)  # (batch, 1, freq, time)

        if not self.use_decoder:
            decoded = est_masks
            return (decoded,)

        # we are going to apply the mask on the output of the decoder
        if is_stereo:
            est_masks = rearrange(est_masks, "b 1 f t -> b f t")
        mask_wave = pad_x_to_y(self.forward_decoder(est_masks), wav)  # (batch, 1, time)
        est_masks_decoded = self.act_after_dec(mask_wave)  # (batch, 1, time)
        decoded = wav * est_masks_decoded  # (batch, n_channels, time)

        return est_masks_decoded, decoded
|
168 |
+
|
169 |
+
|
170 |
+
class BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid(BaseEncoderMaskerDecoder):
    """Multi-channel variant of the configurable encoder/masker/decoder wrapper.

    Keyword Args:
        use_encoder (bool): Run the encoder on the input. Defaults to True.
        apply_mask (bool): Multiply the masker output with the encoded input.
            Defaults to True.
        use_decoder (bool): Decode the masked representation. Defaults to True.
        nb_channels (int): Number of input channels. Defaults to 2.
        decoder_activation: One of "sigmoid", "relu", "relu6", "tanh", "none";
            any other value falls back to sigmoid. Defaults to "none".
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
        super().__init__(encoder, masker, decoder, encoder_activation)
        self.use_encoder = kwargs.get("use_encoder", True)
        self.apply_mask = kwargs.get("apply_mask", True)
        self.use_decoder = kwargs.get("use_decoder", True)
        self.nb_channels = kwargs.get("nb_channels", 2)
        self.decoder_activation = kwargs.get("decoder_activation", "none")

        # Matched with == (not a dict lookup) so that values which are not
        # hashable simply fall through to the sigmoid default.
        known_activations = (
            ("sigmoid", nn.Sigmoid),
            ("relu", nn.ReLU),
            ("relu6", nn.ReLU6),
            ("tanh", nn.Tanh),
            ("none", nn.Identity),
        )
        activation_cls = nn.Sigmoid
        for activation_name, candidate_cls in known_activations:
            if self.decoder_activation == activation_name:
                activation_cls = candidate_cls
                break
        self.act_after_dec = activation_cls()

    def forward(self, wav):
        """
        Enc/Mask/Dec model forward with some additional options.
        For MultiChannel usage of asteroid-based models. (e.g. ConvTasNet)

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            With the decoder enabled, a tuple
            ``(masked_tf_rep, reconstructed)``; otherwise only the masked
            representation.
        """
        # Keep the incoming shape so the output can be reshaped to match.
        input_shape = jitable_shape(wav)
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)

        tf_rep = self.forward_encoder(wav) if self.use_encoder else wav

        is_stereo = self.nb_channels == 2
        if is_stereo:
            # c == 2 when stereo input.
            tf_rep = rearrange(tf_rep, "b c f t -> b (c f) t")
        est_masks = self.forward_masker(tf_rep)

        if is_stereo:
            # Undo the channel folding before combining with the masker output.
            tf_rep = rearrange(tf_rep, "b (c f) t -> b c f t", c=self.nb_channels)

        # Since original asteroid implementation of masking includes unnecessary
        # unsqueeze operation, the multiplication is done manually.
        masked_tf_rep = est_masks * tf_rep if self.apply_mask else est_masks

        if not self.use_decoder:
            return masked_tf_rep

        reconstructed = pad_x_to_y(self.forward_decoder(masked_tf_rep), wav)
        reconstructed = self.act_after_dec(reconstructed)
        return masked_tf_rep, _shape_reconstructed(reconstructed, input_shape)
|
models/load_models.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from asteroid_filterbanks import make_enc_dec
|
5 |
+
|
6 |
+
from asteroid.masknn import TDConvNet
|
7 |
+
|
8 |
+
import utils
|
9 |
+
from .base_models import (
|
10 |
+
BaseEncoderMaskerDecoderWithConfigs,
|
11 |
+
BaseEncoderMaskerDecoderWithConfigsMaskOnOutput,
|
12 |
+
BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
def load_model_with_args(args):
    """Build the model selected by ``args.model_loss_params.architecture``.

    Both supported architectures share the same free filterbank
    encoder/decoder and ``TDConvNet`` masker; they differ in the number of
    masker sources, in the wrapper class and in whether the mask is applied.

    Args:
        args: Configuration object providing ``model_loss_params``,
            ``conv_tasnet_params``, ``data_params`` and ``sample_rate``.

    Returns:
        The constructed model, with ``use_encoder_to_target`` set to ``False``.

    Raises:
        ValueError: If the architecture is neither
            "conv_tasnet_mask_on_output" nor "conv_tasnet".
    """
    architecture = args.model_loss_params.architecture

    if architecture == "conv_tasnet_mask_on_output":
        n_src = 1  # for de-limit task.
        model_cls = BaseEncoderMaskerDecoderWithConfigsMaskOnOutput
        apply_mask = True
    elif architecture == "conv_tasnet":
        # for de-limit task with the standard conv-tasnet setting.
        n_src = args.conv_tasnet_params.n_src
        model_cls = BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid
        apply_mask = False if args.conv_tasnet_params.synthesis else True
    else:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError(f"Unsupported architecture: {architecture!r}")

    params = args.conv_tasnet_params

    encoder, decoder = make_enc_dec(
        "free",
        n_filters=params.n_filters,
        kernel_size=params.kernel_size,
        stride=params.stride,
        sample_rate=args.sample_rate,
    )
    # conv_kernel_size and causal are not passed, so TDConvNet uses its own
    # defaults for them.
    masker = TDConvNet(
        in_chan=encoder.n_feats_out * args.data_params.nb_channels,  # stereo
        n_src=n_src,
        out_chan=encoder.n_feats_out,
        n_blocks=params.n_blocks,
        n_repeats=params.n_repeats,
        bn_chan=params.bn_chan,
        hid_chan=params.hid_chan,
        skip_chan=params.skip_chan,
        # Fall back to 'gLN' when no norm type is configured.
        norm_type=params.norm_type if params.norm_type else 'gLN',
        mask_act=params.mask_act,
    )

    model = model_cls(
        encoder,
        masker,
        decoder,
        encoder_activation=params.encoder_activation,
        use_encoder=True,
        apply_mask=apply_mask,
        use_decoder=True,
        decoder_activation=params.decoder_activation,
    )
    model.use_encoder_to_target = False

    return model
|
prepro/delimit_save_delimiter_stems.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Save loudness normalized (-14 LUFS) musdb-XL audio files for delimiter evaluation
|
2 |
+
|
3 |
+
import os
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
import tqdm
|
7 |
+
import musdb
|
8 |
+
import soundfile as sf
|
9 |
+
import librosa
|
10 |
+
import pyloudnorm as pyln
|
11 |
+
|
12 |
+
from utils import db2linear, str2bool
|
13 |
+
|
14 |
+
|
15 |
+
tqdm.monitor_interval = 0
|
16 |
+
|
17 |
+
|
18 |
+
def main():
    """Derive a per-stem estimate from the saved ``all.wav`` output of each test track.

    For every track in the ``test`` subset of ``--data_root`` the script loads
    ``<output_directory>/test/<exp_name>/<track>/all.wav`` and, when
    ``--target`` is a stem, scales the high-quality stem by two per-sample
    ratios (data_root mixture / hq mixture, then saved output / data_root
    mixture) and writes the result to ``<track>/<target>.wav`` in the same
    folder.
    """
    parser = argparse.ArgumentParser(description="model test.py")

    parser.add_argument(
        "--target",
        type=str,
        default="vocals",
        help="target source. all, vocals, drums, bass, other",
    )
    parser.add_argument("--data_root", type=str, default="/path/to/musdb_XL")
    parser.add_argument(
        "--data_root_hq",
        type=str,
        default="/data1/Music/musdb18hq",
        help="this is used when saving loud-norm stem of musdb-XL")
    parser.add_argument(
        "--output_directory",
        type=str,
        default="/path/to/results",
    )
    parser.add_argument("--exp_name", type=str, default="delimit_6_s")
    parser.add_argument(
        "--save_16k_mono",
        type=str2bool,
        default=False,
        help="Save 16k mono wav files for FAD evaluation.",
    )

    args, _ = parser.parse_known_args()

    os.makedirs(args.output_directory, exist_ok=True)

    args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"

    test_tracks = musdb.DB(root=args.data_root, subsets="test", is_wav=True)

    # In this file, args.target should not be "mixture"; the stem is only
    # computed (and written) for other targets.
    save_stem = args.target != "mixture"
    if save_stem:
        hq_tracks = musdb.DB(root=args.data_root_hq, subsets='test', is_wav=True)

    for idx, track in tqdm.tqdm(enumerate(test_tracks)):
        track_name = track.name
        if (
            os.path.basename(args.data_root) == "musdb18hq"
            and track_name == "PR - Oh No"
        ):  # We have to consider this exception because 'PR - Oh No' mixture.wav is left-panned. We will use the linear mixture instead.
            # Please refer https://github.com/jeonchangbin49/musdb-XL/blob/main/make_L_and_XL.py
            track_audio = (
                track.targets["vocals"].audio
                + track.targets["drums"].audio
                + track.targets["bass"].audio
                + track.targets["other"].audio
            )
        else:
            track_audio = track.audio

        delimiter_track = librosa.load(f"{args.test_output_dir}/{track_name}/all.wav", sr=44100, mono=False)[0].T

        print(track_name)

        if save_stem:
            # Tracks are paired by index, so both datasets must list the test
            # tracks in the same order.
            hq_track = hq_tracks[idx]
            hq_audio = hq_track.audio
            hq_stem = hq_track.targets[args.target].audio
            # Per-sample ratio between the data_root mixture and the hq
            # mixture; the small constant avoids division by zero.
            hq_samplewise_gain = track_audio / (hq_audio + 1e-8)
            XL_stem = hq_samplewise_gain * hq_stem
            # Per-sample ratio between the saved output and the data_root mixture.
            XL_samplewise_gain = delimiter_track / (track_audio + 1e-8)
            delimiter_stem = XL_samplewise_gain * XL_stem

            # Written inside this branch because delimiter_stem only exists here.
            sf.write(
                f"{args.test_output_dir}/{track_name}/{args.target}.wav", delimiter_stem, 44100
            )


if __name__ == "__main__":
    main()
|
prepro/delimit_save_musdb_loudnorm.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Save loudness normalized (-14 LUFS) musdb-XL audio files for evaluations of de-limiter
|
2 |
+
|
3 |
+
import os
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
import tqdm
|
7 |
+
import musdb
|
8 |
+
import soundfile as sf
|
9 |
+
import librosa
|
10 |
+
import pyloudnorm as pyln
|
11 |
+
|
12 |
+
from utils import db2linear, str2bool
|
13 |
+
|
14 |
+
|
15 |
+
tqdm.monitor_interval = 0
|
16 |
+
|
17 |
+
|
18 |
+
def main():
    """Write loudness-normalized test tracks of a musdb-style dataset to disk.

    For every track in the ``test`` subset of ``--data_root`` the integrated
    loudness of the mixture is measured and a gain that moves it to
    ``--loudnorm_input_lufs`` is computed.  The gain is applied either to the
    mixture or, when ``--target`` names a stem, to that stem.  For the
    ``musdb_XL`` dataset the stem is rebuilt from the musdb18hq stem using the
    sample-wise ratio between the normalized mixture and the musdb18hq
    mixture.  The result is saved as ``<output_directory>/<track>/<target>.wav``
    and, if ``--save_16k_mono`` is set, a 16 kHz mono copy is written next to
    it under ``<output_directory>_16k_mono``.
    """
    parser = argparse.ArgumentParser(description="model test.py")

    parser.add_argument(
        "--target",
        type=str,
        default="mixture",
        help="target source. all, vocals, drums, bass, other",
    )
    parser.add_argument("--data_root", type=str, default="/path/to/musdb_XL")
    parser.add_argument(
        "--data_root_hq",
        type=str,
        default="/path/to/musdb18hq",
        help="this is used when saving loud-norm stem of musdb-XL")
    parser.add_argument(
        "--output_directory",
        type=str,
        default="/path/to/musdb_XL_loudnorm",
    )
    parser.add_argument(
        "--loudnorm_input_lufs",
        type=float,
        default=-14.0,
        help="If you want to use loudnorm, input target lufs",
    )
    parser.add_argument(
        "--save_16k_mono",
        type=str2bool,
        default=True,
        help="Save 16k mono wav files for FAD evaluation.",
    )

    args, _ = parser.parse_known_args()

    os.makedirs(args.output_directory, exist_ok=True)

    meter = pyln.Meter(44100)

    # The dataset is identified by the last path component of --data_root.
    # normpath strips a trailing separator first; without it
    # basename("/path/to/musdb_XL/") is "" and no branch below would match,
    # silently writing un-normalized audio.
    dataset_name = os.path.basename(os.path.normpath(args.data_root))

    test_tracks = musdb.DB(root=args.data_root, subsets="test", is_wav=True)
    if args.target != "mixture":
        hq_tracks = musdb.DB(root=args.data_root_hq, subsets='test', is_wav=True)

    for idx, track in tqdm.tqdm(enumerate(test_tracks)):
        track_name = track.name
        if (
            dataset_name == "musdb18hq"
            and track_name == "PR - Oh No"
        ):  # We have to consider this exception because 'PR - Oh No' mixture.wav is left-panned. We will use the linear mixture instead.
            # Please refer https://github.com/jeonchangbin49/musdb-XL/blob/main/make_L_and_XL.py
            track_audio = (
                track.targets["vocals"].audio
                + track.targets["drums"].audio
                + track.targets["bass"].audio
                + track.targets["other"].audio
            )
        else:
            track_audio = track.audio

        print(track_name)

        # Gain (in dB) that moves the mixture to the requested loudness.
        track_lufs = meter.integrated_loudness(track_audio)
        augmented_gain = args.loudnorm_input_lufs - track_lufs
        if dataset_name == "musdb18hq":
            # The stem (if requested) receives the gain measured on the mixture.
            if args.target != "mixture":
                track_audio = track.targets[args.target].audio
            track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
        elif dataset_name == "musdb_XL":
            track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
            if args.target != "mixture":
                # NOTE(review): relies on both musdb.DB objects yielding the
                # test tracks in the same order - confirm.
                hq_track = hq_tracks[idx]
                hq_audio = hq_track.audio
                hq_stem = hq_track.targets[args.target].audio
                # Per-sample ratio between the normalized mixture and the
                # musdb18hq mixture; 1e-8 avoids division by zero.
                samplewise_gain = track_audio / (hq_audio + 1e-8)
                track_audio = samplewise_gain * hq_stem

        os.makedirs(f"{args.output_directory}/{track_name}", exist_ok=True)
        sf.write(
            f"{args.output_directory}/{track_name}/{args.target}.wav", track_audio, 44100
        )

        if args.save_16k_mono:
            # to_mono receives the transposed array, i.e. track_audio is
            # presumably (samples, channels) - TODO confirm.
            track_audio_16k_mono = librosa.to_mono(track_audio.T)
            track_audio_16k_mono = librosa.resample(
                track_audio_16k_mono,
                orig_sr=44100,
                target_sr=16000,
            )
            os.makedirs(f"{args.output_directory}_16k_mono/{track_name}", exist_ok=True)
            sf.write(
                f"{args.output_directory}_16k_mono/{track_name}/{args.target}.wav",
                track_audio_16k_mono,
                samplerate=16000,
            )


if __name__ == "__main__":
    main()
|
prepro/delimit_train_ozone_prepro.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import csv
|
4 |
+
import glob
|
5 |
+
import argparse
|
6 |
+
import random
|
7 |
+
import math
|
8 |
+
|
9 |
+
import librosa
|
10 |
+
import soundfile as sf
|
11 |
+
import pedalboard
|
12 |
+
import numpy as np
|
13 |
+
import pyloudnorm as pyln
|
14 |
+
from scipy.stats import gamma
|
15 |
+
import torchaudio
|
16 |
+
|
17 |
+
|
18 |
+
def str2bool(v):
    """Convert a command-line style truth value to ``bool``.

    Args:
        v: Either a ``bool`` (returned unchanged) or a string.  Strings are
            compared case-insensitively against ``yes/true/t/y/1`` and
            ``no/false/f/n/0``.

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is a string that matches neither
            set of spellings.
    """
    # A real bool has no .lower(); pass it through so the function can also be
    # called with an already-parsed value.
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
25 |
+
|
26 |
+
|
27 |
+
def _augment_gain_ozone(audio, low=0.25, high=1.25):
|
28 |
+
"""Applies a random gain between `low` and `high`"""
|
29 |
+
g = low + random.random() * (high - low)
|
30 |
+
return audio * g, g
|
31 |
+
|
32 |
+
|
33 |
+
def _augment_channelswap_ozone(audio):
|
34 |
+
"""Swap channels of stereo signals with a probability of p=0.5"""
|
35 |
+
if audio.shape[0] == 2 and random.random() < 0.5:
|
36 |
+
return np.flip(audio, axis=0), True # axis=0 must be given
|
37 |
+
else:
|
38 |
+
return audio, False
|
39 |
+
|
40 |
+
|
41 |
+
# load wav file from arbitrary positions of 16bit stereo wav file
|
42 |
+
def load_wav_arbitrary_position_stereo(
    filename, sample_rate, seq_duration, return_pos=False
):
    """Load a randomly positioned excerpt of ``seq_duration`` seconds.

    Args:
        filename: Path of the audio file to read.
        sample_rate: Rate used to convert the excerpt length and the start
            position between samples and seconds.  The file itself is loaded
            at its native rate (``sr=None``).
        seq_duration: Excerpt length in seconds.
        return_pos: When true, also return the start position in seconds.

    Returns:
        The loaded audio array, or ``(audio, start_sec)`` if ``return_pos``.

    Raises:
        ValueError: If the file is too short to contain an excerpt of the
            requested duration.
    """
    length = torchaudio.info(filename).num_frames

    # Largest admissible start sample for a full-length excerpt.
    max_start = int(length - math.ceil(seq_duration * sample_rate) - 1)
    if max_start < 0:
        # random.randint would raise an opaque "empty range" ValueError here;
        # raise the same exception type with an actionable message instead.
        raise ValueError(
            f"{filename} has {length} frames, which is too short for a "
            f"{seq_duration} s excerpt at {sample_rate} Hz"
        )
    random_start = random.randint(0, max_start)
    random_start_sec = librosa.samples_to_time(random_start, sr=sample_rate)
    X, _ = librosa.load(
        filename, sr=None, mono=False, offset=random_start_sec, duration=seq_duration
    )

    if return_pos:
        return X, random_start_sec
    return X
|
61 |
+
|
62 |
+
|
63 |
+
# def main():
# NOTE: the `main()` wrapper is commented out, so everything below runs at
# module level when this file is executed.
#
# Two modes, selected by --save_fixed:
#   * fixed  : limit the full-length linear mixture of every training song once.
#   * random : build --n_samples random 4-stem mixtures (each stem from a random
#              song / position / gain / channel order), limit them, and log the
#              parameters to CSV chunks of 20000 rows so a run can be resumed.
parser = argparse.ArgumentParser(description="Preprocess audio files for training")
parser.add_argument(
    "--root",
    type=str,
    default="/path/to/musdb18hq",
    help="Root directory",
)
parser.add_argument(
    "--output",
    type=str,
    default="/path/to/musdb-XL-train",
    help="Where to save output files",
)
parser.add_argument(
    "--n_samples", type=int, default=300000, help="Number of samples to save"
)
parser.add_argument("--seq_duration", type=float, default=4.0, help="Sequence duration")
parser.add_argument(
    "--save_fixed", type=str2bool, default=False, help="Save fixed mixture audio"
)
parser.add_argument(
    "--target_lufs_mean", type=float, default=-8.0, help="Target LUFS mean"
)
# NOTE(review): the default std is negative; random.gauss is called with it
# as-is below. Left unchanged so seeded runs keep producing the same values.
parser.add_argument(
    "--target_lufs_std", type=float, default=-1.0, help="Target LUFS std"
)
parser.add_argument("--sample_rate", type=int, default=44100, help="Sample rate")
parser.add_argument("--seed", type=int, default=46, help="Random seed")
args = parser.parse_args()
random.seed(args.seed)

# Songs excluded from the random-mixture pool (see the filter in the else branch).
valid_list = [
    "ANiMAL - Rockshow",
    "Actions - One Minute Smile",
    "Alexander Ross - Goodbye Bolero",
    "Clara Berry And Wooldog - Waltz For My Victims",
    "Fergessen - Nos Palpitants",
    "James May - On The Line",
    "Johnny Lokke - Promises & Lies",
    "Leaf - Summerghost",
    "Meaxic - Take A Step",
    "Patrick Talbot - A Reason To Leave",
    "Skelpolu - Human Mistakes",
    "Traffic Experiment - Sirens",
    "Triviul - Angelsaint",
    "Young Griffo - Pennies",
]

meter = pyln.Meter(args.sample_rate)


sources = ["vocals", "bass", "drums", "other"]
song_list = glob.glob(f"{args.root}/train/*")

vst = pedalboard.load_plugin(
    "/Library/Audio/Plug-Ins/Components/iZOzone9ElementsAUHook.component"
)


def _csv_chunk_index(path):
    """Return the integer N of an ``ozone_train_random_N.csv`` file name."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.rsplit("_", 1)[1])


if args.save_fixed:
    vst_params = []

    os.makedirs(f"{args.output}/ozone_train_fixed", exist_ok=True)

    for song in song_list:
        print(f"Processing {song}...")
        song_name = os.path.basename(song)
        audio_sources = []
        for source in sources:
            audio_path = f"{song}/{source}.wav"
            audio, sr = librosa.load(audio_path, sr=args.sample_rate, mono=False)
            audio_sources.append(audio)
        stems = np.stack(audio_sources, axis=0)
        mixture = stems.sum(0)
        lufs = meter.integrated_loudness(mixture.T)
        target_lufs = random.gauss(args.target_lufs_mean, args.target_lufs_std)
        # dB of gain needed to reach the sampled target loudness.
        adjusted_loudness = target_lufs - lufs

        vst.reset()
        vst.eq_bypass = True
        vst.img_bypass = True
        vst.max_mode = 1.0  # Set IRC2 mode
        vst.max_threshold = min(-adjusted_loudness, 0.0)
        vst.max_character = min(gamma.rvs(2), 10.0)

        print(
            f"Applying Ozone 9 Elements IRC2 with threshold {vst.max_threshold} and character {vst.max_character}..."
        )
        limited_mixture = vst(mixture, args.sample_rate)

        sf.write(
            f"{args.output}/ozone_train_fixed/{song_name}.wav",
            limited_mixture.T,
            args.sample_rate,
        )
        vst_params.append([song_name, vst.max_threshold, vst.max_character])
    # Save the song name and vst parameters (vst.max_threshold and vst.max_character) to a csv file
    # newline="" is what the csv module expects for files it writes.
    with open(f"{args.output}/ozone_train_fixed.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["song_name", "max_threshold", "max_character"])
        writer.writerows(vst_params)

else:
    # Resume support: reload the rows of every CSV chunk written so far.
    vst_params = []
    if os.path.exists(f"{args.output}/ozone_train_random_0.csv"):
        list_csv_files = glob.glob(f"{args.output}/ozone_train_random_*.csv")
        # Sort by the numeric chunk index. A plain string sort would put
        # "_10" before "_2" and scramble the row order, which the
        # `number * 20000` slices below depend on.
        list_csv_files.sort(key=_csv_chunk_index)
        for csv_file in list_csv_files:
            with open(csv_file, "r", newline="") as f:
                reader = csv.reader(f)
                next(reader)  # skip the header row
                vst_params.extend(reader)

    song_list = [x for x in song_list if os.path.basename(x) not in valid_list]

    os.makedirs(f"{args.output}/ozone_train_random", exist_ok=True)
    # The segments are written to `ozone_train_random_0` (see sf.write below),
    # which the call above does not create; make sure it exists as well.
    os.makedirs(f"{args.output}/ozone_train_random_0", exist_ok=True)

    for n in range(len(vst_params), args.n_samples):
        print(f"Processing {n} / {args.n_samples}...")
        seg_name = f"ozone_seg_{n}"

        # Despite its name, this flag is True while the measured loudness IS
        # infinite, i.e. while a new mixture still has to be drawn.
        lufs_not_inf = True
        while lufs_not_inf:
            audio_sources = []
            source_song_names = {}
            source_start_secs = {}
            source_gains = {}
            source_channelswaps = {}
            for source in sources:
                track_path = random.choice(song_list)
                song_name = os.path.basename(track_path)
                audio_path = f"{track_path}/{source}.wav"
                audio, start_sec = load_wav_arbitrary_position_stereo(
                    audio_path, args.sample_rate, args.seq_duration, return_pos=True
                )
                audio, gain = _augment_gain_ozone(audio)
                audio, channelswap = _augment_channelswap_ozone(audio)
                audio_sources.append(audio)
                source_song_names[source] = song_name
                source_start_secs[source] = start_sec
                source_gains[source] = gain
                source_channelswaps[source] = channelswap

            stems = np.stack(audio_sources, axis=0)
            mixture = stems.sum(0)
            lufs = meter.integrated_loudness(mixture.T)

            # if lufs is inf, then the mixture is silent, so we need to generate a new mixture
            lufs_not_inf = np.isinf(lufs)

        target_lufs = random.gauss(args.target_lufs_mean, args.target_lufs_std)
        adjusted_loudness = target_lufs - lufs

        vst.reset()
        vst.eq_bypass = True
        vst.img_bypass = True
        vst.max_mode = 1.0  # Set IRC2 mode
        # Threshold is clamped to the range [-20, 0].
        vst.max_threshold = min(max(-20, -adjusted_loudness), 0.0)
        vst.max_character = min(gamma.rvs(2), 10.0)

        print(
            f"Applying Ozone 9 Elements IRC2 with threshold {vst.max_threshold} and character {vst.max_character}..."
        )
        limited_mixture = vst(mixture, args.sample_rate)

        sf.write(
            f"{args.output}/ozone_train_random_0/{seg_name}.wav",
            limited_mixture.T,
            args.sample_rate,
        )
        vst_params.append(
            [
                seg_name,
                vst.max_threshold,
                vst.max_character,
                source_song_names["vocals"],
                source_start_secs["vocals"],
                source_gains["vocals"],
                source_channelswaps["vocals"],
                source_song_names["bass"],
                source_start_secs["bass"],
                source_gains["bass"],
                source_channelswaps["bass"],
                source_song_names["drums"],
                source_start_secs["drums"],
                source_gains["drums"],
                source_channelswaps["drums"],
                source_song_names["other"],
                source_start_secs["other"],
                source_gains["other"],
                source_channelswaps["other"],
            ]
        )

        if (n + 1) % 20000 == 0 or n == args.n_samples - 1:
            # We will separate the csv file into multiple files to avoid memory error
            # Save the song name and vst parameters (vst.max_threshold and vst.max_character) to a csv file
            number = int(n // 20000)
            with open(
                f"{args.output}/ozone_train_random_{number}.csv", "w", newline=""
            ) as f:
                writer = csv.writer(f)
                writer.writerow(
                    [
                        "song_name",
                        "max_threshold",
                        "max_character",
                        "vocals_name",
                        "vocals_start_sec",
                        "vocals_gain",
                        "vocals_channelswap",
                        "bass_name",
                        "bass_start_sec",
                        "bass_gain",
                        "bass_channelswap",
                        "drums_name",
                        "drums_start_sec",
                        "drums_gain",
                        "drums_channelswap",
                        "other_name",
                        "other_start_sec",
                        "other_gain",
                        "other_channelswap",
                    ]
                )
                # Only the rows that belong to this 20000-row chunk.
                writer.writerows(vst_params[number * 20000 : (number + 1) * 20000])
|
prepro/delimit_valid_L_prepro.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import soundfile as sf
|
6 |
+
import tqdm
|
7 |
+
|
8 |
+
from dataloader import DelimitValidDataset
|
9 |
+
|
10 |
+
|
11 |
+
def main():
    """Render the validation set through ``DelimitValidDataset`` and save it.

    Each item yielded by the dataset (built with ``valid_target_lufs=-14.39``)
    provides the limited audio, the original audio, the track name and a
    loudness value.  The limited audio is written to
    ``<save_path>/valid/<name>.wav`` and all loudness values are collected in
    ``<save_path>/valid_loudness.json``.
    """
    # Parameters
    data_path = "/path/to/musdb18hq"
    save_path = "/path/to/musdb18hq_limited_L"
    batch_size = 1
    num_workers = 1
    sr = 44100

    # Dataset
    dataset = DelimitValidDataset(root=data_path, valid_target_lufs=-14.39)
    data_loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
    )
    dict_valid_loudness = {}

    # Create the output folder once, up front.  It does not depend on the
    # loop, and creating it here also guarantees that `save_path` exists for
    # the JSON dump below even when the loader yields nothing.
    valid_dir = os.path.join(save_path, "valid")
    os.makedirs(valid_dir, exist_ok=True)

    # Preprocessing
    for limited_audio, orig_audio, audio_name, loudness in tqdm.tqdm(data_loader):
        # batch_size is 1, so index 0 unwraps the single item of the batch.
        audio_name = audio_name[0]
        limited_audio = limited_audio[0].numpy()
        loudness = float(loudness[0].numpy())
        dict_valid_loudness[audio_name] = loudness
        # Save audio
        audio_path = os.path.join(valid_dir, audio_name)
        # Written transposed - presumably (channels, samples) -> (samples, channels); TODO confirm.
        sf.write(f"{audio_path}.wav", limited_audio.T, sr)
    # write json write code
    with open(os.path.join(save_path, "valid_loudness.json"), "w") as f:
        json.dump(dict_valid_loudness, f, indent=4)


if __name__ == "__main__":
    main()
|
prepro/delimit_valid_custom_limiter_prepro.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import soundfile as sf
|
6 |
+
import tqdm
|
7 |
+
|
8 |
+
from dataloader import DelimitValidDataset
|
9 |
+
|
10 |
+
|
11 |
+
def main():
    """Render the validation set with the custom limiter and save the results.

    ``DelimitValidDataset`` is built with ``use_custom_limiter=True`` and a
    fixed attack range of ``[2.0, 2.0]``.  Each item provides the limited
    audio, the original audio, the track name, a loudness value and the
    attack / release values used.  The limited audio goes to
    ``<save_path>/valid/<name>.wav``; loudness values and limiter parameters
    are dumped to ``valid_loudness.json`` and ``valid_limiter_params.json``.
    """
    # Parameters
    data_path = "/path/to/musdb18hq"
    save_path = (
        "/path/to/musdb18hq_custom_limiter_fixed_attack"
    )
    batch_size = 1
    num_workers = 1
    sr = 44100

    # Dataset
    dataset = DelimitValidDataset(
        root=data_path, use_custom_limiter=True, custom_limiter_attack_range=[2.0, 2.0]
    )
    data_loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
    )
    dict_valid_loudness = {}
    dict_limiter_params = {}

    # Create the output folder once, up front.  It does not depend on the
    # loop, and creating it here also guarantees that `save_path` exists for
    # the JSON dumps below even when the loader yields nothing.
    valid_dir = os.path.join(save_path, "valid")
    os.makedirs(valid_dir, exist_ok=True)

    # Preprocessing
    for (
        limited_audio,
        orig_audio,
        audio_name,
        loudness,
        custom_attack,
        custom_release,
    ) in tqdm.tqdm(data_loader):
        # batch_size is 1, so index 0 unwraps the single item of the batch.
        audio_name = audio_name[0]
        limited_audio = limited_audio[0].numpy()
        loudness = float(loudness[0].numpy())
        dict_valid_loudness[audio_name] = loudness
        dict_limiter_params[audio_name] = {
            "attack_ms": float(custom_attack[0].numpy()),
            "release_ms": float(custom_release[0].numpy()),
        }
        # Save audio
        audio_path = os.path.join(valid_dir, audio_name)
        # Written transposed - presumably (channels, samples) -> (samples, channels); TODO confirm.
        sf.write(f"{audio_path}.wav", limited_audio.T, sr)
    # write json write code
    with open(os.path.join(save_path, "valid_loudness.json"), "w") as f:
        json.dump(dict_valid_loudness, f, indent=4)
    with open(os.path.join(save_path, "valid_limiter_params.json"), "w") as f:
        json.dump(dict_limiter_params, f, indent=4)


if __name__ == "__main__":
    main()
|
prepro/delimit_valid_prepro.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import soundfile as sf
|
6 |
+
import tqdm
|
7 |
+
|
8 |
+
from dataloader import DelimitValidDataset
|
9 |
+
|
10 |
+
|
11 |
+
def main():
    """Render the validation set through ``DelimitValidDataset`` and save it.

    Each item yielded by the dataset (built with its default settings)
    provides the limited audio, the original audio, the track name and a
    loudness value.  The limited audio is written to
    ``<save_path>/valid/<name>.wav`` and all loudness values are collected in
    ``<save_path>/valid_loudness.json``.
    """
    # Parameters
    data_path = "/path/to/musdb18hq"
    save_path = "/path/to/musdb18hq_limited"
    batch_size = 1
    num_workers = 1
    sr = 44100

    # Dataset
    dataset = DelimitValidDataset(root=data_path)
    data_loader = DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
    )
    dict_valid_loudness = {}

    # Create the output folder once, up front.  It does not depend on the
    # loop, and creating it here also guarantees that `save_path` exists for
    # the JSON dump below even when the loader yields nothing.
    valid_dir = os.path.join(save_path, "valid")
    os.makedirs(valid_dir, exist_ok=True)

    # Preprocessing
    for limited_audio, orig_audio, audio_name, loudness in tqdm.tqdm(data_loader):
        # batch_size is 1, so index 0 unwraps the single item of the batch.
        audio_name = audio_name[0]
        limited_audio = limited_audio[0].numpy()
        loudness = float(loudness[0].numpy())
        dict_valid_loudness[audio_name] = loudness
        # Save audio
        audio_path = os.path.join(valid_dir, audio_name)
        # Written transposed - presumably (channels, samples) -> (samples, channels); TODO confirm.
        sf.write(f"{audio_path}.wav", limited_audio.T, sr)
    # write json write code
    with open(os.path.join(save_path, "valid_loudness.json"), "w") as f:
        json.dump(dict_valid_loudness, f, indent=4)


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/asteroid-team/asteroid.git@master
|
2 |
+
numpy
|
3 |
+
librosa
|
4 |
+
soundfile
|
5 |
+
torch
|
6 |
+
torchaudio
|
7 |
+
matplotlib
|
8 |
+
wandb
|
9 |
+
musdb
|
10 |
+
dotmap
|
11 |
+
ema-pytorch
|
12 |
+
pedalboard
|
13 |
+
einops
|
separate_func/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .conv_tasnet_separate import conv_tasnet_separate
|
separate_func/conv_tasnet_separate.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import soundfile as sf
|
4 |
+
import torch
|
5 |
+
import pyloudnorm as pyln
|
6 |
+
import librosa
|
7 |
+
import matplotlib
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
from dataloader import SingleTrackSet
|
11 |
+
from utils import db2linear
|
12 |
+
|
13 |
+
|
14 |
+
def conv_tasnet_separate(
    args, our_model, device, track_audio, track_name, meter=None, augmented_gain=None
):
    """Run ``our_model`` on one track, save the estimate to disk and return it.

    Args:
        args: Configuration object.  The fields read here are
            ``use_singletrackset``, ``data_params.nhop``,
            ``data_params.sample_rate``, ``target``, ``task_params.dataset``,
            ``save_histogram``, ``test_output_dir``, ``save_name_as_target``,
            ``save_output_loudnorm`` and ``save_16k_mono``.
        our_model: Callable returning ``(estimates, *extra)``.  When
            ``args.task_params.dataset == "delimit"`` the first extra output
            is used as the estimate instead of the first return value.
        device: Device the chunks are moved to in the chunked path.
        track_audio: Input tensor; its last dimension is used as the output
            length.  Presumably (1, channels, samples) - TODO confirm.
        track_name: Name used to build the output paths.
        meter: Loudness meter providing ``integrated_loudness``; required
            when ``args.save_output_loudnorm`` is truthy.
        augmented_gain: Gain in dB that is undone on the output when no
            output loudness normalization is requested.

    Returns:
        The estimate as a NumPy array, after the optional gain changes.
    """

    if args.use_singletrackset:
        # Chunked inference: process the track piece by piece and stitch the
        # trimmed outputs back together along the time axis.
        db = SingleTrackSet(
            track_audio.squeeze(dim=0),
            hop_length=args.data_params.nhop,
            num_frame=128,
            target_name=args.target,
        )
        separated = []

        for item in db:
            item = item.unsqueeze(0).to(device)
            estimates, *estimates_vars = our_model(item)
            if args.task_params.dataset == "delimit":
                estimates = estimates_vars[0]

            # One transfer/detach is enough; the slice is cloned so the list
            # does not keep the full chunk alive.
            estimates = estimates.cpu().detach()
            separated.append(
                estimates[..., db.trim_length : -db.trim_length].clone()
            )

        estimates = torch.cat(separated, dim=-1)
        estimates = estimates[0, :, : track_audio.shape[-1]].numpy()
    else:
        estimates, *estimates_vars = our_model(track_audio)
        if args.save_histogram and args.task_params.dataset == "delimit":
            # Histogram of the first model output (before it is replaced by
            # estimates_vars[0] below).
            fig = plt.figure(figsize=(10, 10))
            plt.hist(estimates.cpu().detach().numpy().flatten(), bins=100)
            os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
            plt.savefig(
                f"{args.test_output_dir}/{track_name}/{args.target}_histogram.png"
            )
            # Close the figure; otherwise one figure per track stays open.
            plt.close(fig)
        if args.task_params.dataset == "delimit":
            estimates = estimates_vars[0]

        estimates = estimates.cpu().detach().numpy()
        estimates = estimates[0, :, : track_audio.shape[-1]]

    if args.save_name_as_target:
        os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)

    if args.save_output_loudnorm:
        print("SAVE Loudness normalized OUTPUT ")
        loudness = meter.integrated_loudness(estimates.T)
        estimates = estimates * db2linear(args.save_output_loudnorm - loudness, eps=0.0)
    elif augmented_gain is not None and args.save_output_loudnorm is None:
        # Undo the gain that was applied to the input.
        estimates = estimates * db2linear(-augmented_gain, eps=0.0)

    sf.write(
        f"{args.test_output_dir}/{track_name}/{args.target}.wav"
        if args.save_name_as_target
        else f"{args.test_output_dir}/{track_name}.wav",
        estimates.T,
        samplerate=args.data_params.sample_rate,
    )

    if args.save_16k_mono:
        estimates_16k_mono = librosa.to_mono(estimates)
        estimates_16k_mono = librosa.resample(
            estimates_16k_mono,
            orig_sr=args.data_params.sample_rate,
            target_sr=16000,
        )
        os.makedirs(f"{args.test_output_dir}_16k_mono/{track_name}", exist_ok=True)
        sf.write(
            f"{args.test_output_dir}_16k_mono/{track_name}/{args.target}.wav"
            if args.save_name_as_target
            else f"{args.test_output_dir}_16k_mono/{track_name}.wav",
            estimates_16k_mono,
            samplerate=16000,
        )

    return estimates
|
solver_ddp.py
ADDED
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import json
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import wandb
|
7 |
+
import matplotlib
|
8 |
+
|
9 |
+
matplotlib.use("Agg")
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import torch.distributed as dist
|
12 |
+
from torch.utils.data.distributed import DistributedSampler
|
13 |
+
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
14 |
+
from asteroid.losses import (
|
15 |
+
pairwise_neg_sisdr,
|
16 |
+
PairwiseNegSDR,
|
17 |
+
)
|
18 |
+
from einops import rearrange, reduce
|
19 |
+
from ema_pytorch import EMA
|
20 |
+
|
21 |
+
from models import load_model_with_args
|
22 |
+
import utils
|
23 |
+
from dataloader import (
|
24 |
+
MusdbTrainDataset,
|
25 |
+
MusdbValidDataset,
|
26 |
+
DelimitTrainDataset,
|
27 |
+
DelimitValidDataset,
|
28 |
+
OzoneTrainDataset,
|
29 |
+
OzoneValidDataset,
|
30 |
+
aug_from_str,
|
31 |
+
SingleTrackSet,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class Solver(object):
    """Training/validation driver for one DistributedDataParallel worker.

    ``set_gpu`` builds the model, optimizer, LR scheduler, loss functions and
    data loaders for the GPU given by ``args.gpu``; ``train`` and
    ``multi_validate`` then each run one epoch.  ``args`` is the nested
    experiment configuration with attribute-style access (``args.toDict()`` is
    called when the run parameters are written to JSON).
    """

    def __init__(self):
        # All state is created in set_gpu(), once the worker knows its GPU.
        pass

    def set_gpu(self, args):
        """Create model, optimizer, scheduler, losses and loaders on ``args.gpu``.

        Also restores a checkpoint when ``args.dir_params.resume`` is set.
        Note that ``args.hyperparams.batch_size`` is overwritten in place with
        the per-GPU batch size.
        """
        self._init_wandb(args)

        ###################### Define Models ######################
        self.model = load_model_with_args(args)
        self.optimizer = self._build_optimizer(args, list(self.model.parameters()))
        self.scheduler = self._build_scheduler(args)

        torch.cuda.set_device(args.gpu)
        self.model = self.model.to(f"cuda:{args.gpu}")

        ############################################################
        # Define Losses
        self.criterion = {
            "l1": nn.L1Loss().to(args.gpu),
            "mse": nn.MSELoss().to(args.gpu),
            "si_sdr": pairwise_neg_sisdr.to(args.gpu),
            "snr": PairwiseNegSDR("snr").to(args.gpu),
            "bcewithlogits": nn.BCEWithLogitsLoss().to(args.gpu),
            "bce": nn.BCELoss().to(args.gpu),
            "kl": nn.KLDivLoss(log_target=True).to(args.gpu),
        }

        print("Loss functions we use in this training:")
        print(args.model_loss_params.train_loss_func)

        # Early stopping utils
        self.es = utils.EarlyStopping(patience=args.hyperparams.patience)
        self.stop = False

        if args.wandb_params.use_wandb and args.gpu == 0:
            wandb.watch(self.model, log="all")

        self.start_epoch = 1
        self.train_losses = []
        self.valid_losses = []
        self.train_times = []
        self.best_epoch = 0

        # Without EMA the checkpoint holds the bare model's weights, so it has
        # to be loaded before the model is wrapped in DDP.
        if args.dir_params.resume and not args.hyperparams.ema:
            self.resume(args)

        # Distribute models to machine
        self.model = DDP(
            self.model,
            device_ids=[args.gpu],
            output_device=args.gpu,
            find_unused_parameters=True,
        )

        if args.hyperparams.ema:
            self.model_ema = EMA(
                self.model,
                beta=0.999,
                update_after_step=100,
                update_every=10,
            )

        # With EMA the checkpoint holds the EMA wrapper's state, so it can only
        # be loaded once the wrapper exists.  (The original tested
        # ``args.resume``, which is not where the resume path is stored, so an
        # EMA run never reloaded its checkpoint.)
        if args.dir_params.resume and args.hyperparams.ema:
            self.resume(args)

        ###################### Define data pipeline ######################
        args.hyperparams.batch_size = int(
            args.hyperparams.batch_size / args.ngpus_per_node
        )
        self.mp_context = torch.multiprocessing.get_context("fork")

        self.train_dataset, self.valid_dataset = self._build_datasets(args)

        self.train_sampler = DistributedSampler(
            self.train_dataset, shuffle=True, rank=args.gpu
        )
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=args.hyperparams.batch_size,
            shuffle=False,  # shuffling is delegated to the sampler
            num_workers=args.sys_params.nb_workers,
            multiprocessing_context=self.mp_context,
            pin_memory=True,
            sampler=self.train_sampler,
            drop_last=False,
        )

        self.valid_sampler = DistributedSampler(
            self.valid_dataset, shuffle=False, rank=args.gpu
        )
        self.valid_loader = torch.utils.data.DataLoader(
            self.valid_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.sys_params.nb_workers,
            multiprocessing_context=self.mp_context,
            pin_memory=False,
            sampler=self.valid_sampler,
            drop_last=False,
        )

    def _init_wandb(self, args):
        """Start the wandb run; only the process on GPU 0 logs."""
        if not (args.wandb_params.use_wandb and args.gpu == 0):
            return

        resuming = args.dir_params.resume is not None
        if args.wandb_params.sweep:
            wandb.init(
                entity=args.wandb_params.entity,
                project=args.wandb_params.project,
                config=args,
                resume=resuming,
            )
        else:
            wandb.init(
                entity=args.wandb_params.entity,
                project=args.wandb_params.project,
                name=f"{args.dir_params.exp_name}",
                config=args,
                resume="must"
                if resuming and not args.dir_params.continual_train
                else False,
                id=args.wandb_params.rerun_id or None,
                settings=wandb.Settings(start_method="fork"),
            )

    def _build_optimizer(self, args, trainable_params):
        """Return the optimizer selected by ``args.hyperparams.optimizer``.

        Raises:
            NotImplementedError: if the optimizer name is not one of
                ``sgd``, ``adamw``, ``radam`` or ``adam``.
        """
        name = args.hyperparams.optimizer
        lr = args.hyperparams.lr
        weight_decay = args.hyperparams.weight_decay

        if name == "sgd":
            print("Use SGD optimizer.")
            return torch.optim.SGD(
                params=trainable_params,
                lr=lr,
                momentum=0.9,
                weight_decay=weight_decay,
            )
        if name == "adamw":
            print("Use AdamW optimizer.")
            return torch.optim.AdamW(
                params=trainable_params,
                lr=lr,
                betas=(0.9, 0.999),
                amsgrad=False,
                weight_decay=weight_decay,
            )
        if name == "radam":
            print("Use RAdam optimizer.")
            return torch.optim.RAdam(
                params=trainable_params,
                lr=lr,
                betas=(0.9, 0.999),
                eps=1e-08,
                weight_decay=weight_decay,
            )
        if name == "adam":
            print("Use Adam optimizer.")
            return torch.optim.Adam(
                params=trainable_params,
                lr=lr,
                betas=(0.9, 0.999),
                weight_decay=weight_decay,
            )

        print("no optimizer loaded")
        raise NotImplementedError(f"Unsupported optimizer: {name!r}")

    def _build_scheduler(self, args):
        """Return the LR scheduler selected by ``args.hyperparams.lr_scheduler``.

        Despite its name, ``step_lr`` selects ``ReduceLROnPlateau``.

        Raises:
            ValueError: if the scheduler name is unknown.  Previously this
                only surfaced as an AttributeError at the first validation.
        """
        name = args.hyperparams.lr_scheduler

        if name == "step_lr":
            plateau_kwargs = dict(
                mode="min",
                factor=args.hyperparams.lr_decay_gamma,
                patience=args.hyperparams.lr_decay_patience,
                verbose=True,
            )
            if args.model_loss_params.architecture == "umx":
                plateau_kwargs["cooldown"] = 10
            else:
                plateau_kwargs["cooldown"] = 0
                plateau_kwargs["min_lr"] = 5e-5
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, **plateau_kwargs
            )

        if name == "cos_warmup":
            return utils.CosineAnnealingWarmUpRestarts(
                self.optimizer,
                T_0=40,
                T_mult=1,
                eta_max=args.hyperparams.lr,
                T_up=10,
                gamma=0.5,
            )

        raise ValueError(f"Unsupported lr_scheduler: {name!r}")

    def _build_datasets(self, args):
        """Return ``(train_dataset, valid_dataset)`` for ``args.task_params.dataset``.

        Raises:
            ValueError: if the dataset name is neither ``musdb`` nor ``delimit``.
        """
        data = args.data_params
        dirs = args.dir_params

        # Keyword arguments shared by every training dataset.
        train_kwargs = dict(
            target=args.task_params.target,
            root=dirs.root,
            seq_duration=data.seq_dur,
            samples_per_track=data.samples_per_track,
            source_augmentations=aug_from_str(
                ["gain", "channelswap"],
            ),
            sample_rate=data.sample_rate,
            seed=args.sys_params.seed,
            limitaug_method=data.limitaug_method,
            limitaug_mode=data.limitaug_mode,
            limitaug_custom_target_lufs=data.limitaug_custom_target_lufs,
            limitaug_custom_target_lufs_std=data.limitaug_custom_target_lufs_std,
            target_loudnorm_lufs=data.target_loudnorm_lufs,
            custom_limiter_attack_range=data.custom_limiter_attack_range,
            custom_limiter_release_range=data.custom_limiter_release_range,
        )

        if args.task_params.dataset == "musdb":
            train_dataset = MusdbTrainDataset(**train_kwargs)
            valid_dataset = MusdbValidDataset(
                target=args.task_params.target, root=dirs.root
            )
            return train_dataset, valid_dataset

        if args.task_params.dataset == "delimit":
            # Both delimit variants additionally take the target-side options.
            train_kwargs.update(
                target_limitaug_mode=data.target_limitaug_mode,
                target_limitaug_custom_target_lufs=data.target_limitaug_custom_target_lufs,
                target_limitaug_custom_target_lufs_std=data.target_limitaug_custom_target_lufs_std,
            )
            if data.limitaug_method == "ozone":
                train_dataset = OzoneTrainDataset(
                    ozone_root=dirs.ozone_root,
                    use_fixed=data.use_fixed,
                    **train_kwargs,
                )
                valid_dataset = OzoneValidDataset(
                    target=args.task_params.target,
                    root=dirs.root,
                    ozone_root=dirs.ozone_root,
                    target_loudnorm_lufs=data.target_loudnorm_lufs,
                )
            else:
                train_dataset = DelimitTrainDataset(**train_kwargs)
                valid_dataset = DelimitValidDataset(
                    target=args.task_params.target,
                    root=dirs.root,
                    delimit_valid_root=dirs.delimit_valid_root,
                    valid_target_lufs=data.valid_target_lufs,
                    target_loudnorm_lufs=data.target_loudnorm_lufs,
                    delimit_valid_L_root=dirs.delimit_valid_L_root,
                )
            return train_dataset, valid_dataset

        raise ValueError(f"Unsupported dataset: {args.task_params.dataset!r}")

    def train(self, args, epoch):
        """Run one training epoch and record the epoch's mean training loss.

        Every 30 steps the running losses are averaged over all ranks and
        printed on GPU 0.  At the end of the epoch the rank-averaged loss is
        appended to ``self.train_losses`` (GPU 0 only) and optionally logged
        to wandb.
        """
        self.end = time.time()  # epoch start; read again in multi_validate()
        self.model.train()

        # get current learning rate (that of the last parameter group)
        current_lr = self.optimizer.param_groups[-1]["lr"]

        if args.sys_params.rank % args.ngpus_per_node == 0:
            print(f"Epoch {epoch}, Learning rate: {current_lr}")

        losses = utils.AverageMeter()
        loss_logger = {"train/train loss": 0}

        # with torch.autograd.detect_anomaly(): # use this if you want to detect anomaly behavior while training.
        for i, values in enumerate(self.train_loader):
            mixture, clean, *_ = values

            mixture = mixture.cuda(args.gpu, non_blocking=True)
            clean = clean.cuda(args.gpu, non_blocking=True)
            target = clean  # target_shape = [batch_size, n_srcs, nb_channels (if stereo: 2), wave_length]

            estimates, *estimates_vars = self.model(mixture)

            # For the delimit task the loss is computed on the model's second
            # output rather than its first.
            if args.task_params.dataset == "delimit":
                estimates = estimates_vars[0]

            dict_loss = {}
            for train_loss_idx, single_train_loss_func in enumerate(
                args.model_loss_params.train_loss_func
            ):
                if self.model.module.use_encoder_to_target:
                    # Compare in the encoder domain: encode the target too.
                    target_spec = self.model.module.encoder(
                        rearrange(target, "b s c t -> (b s) c t")
                    )
                    loss_target = rearrange(
                        target_spec,
                        "(b s) c f t -> b s c f t",
                        s=args.task_params.bleeding_nsrcs,
                    )
                else:
                    loss_target = target

                loss_else = self.criterion[single_train_loss_func](
                    estimates, loss_target
                )
                dict_loss[single_train_loss_func] = (
                    loss_else.mean()
                    * args.model_loss_params.train_loss_scales[train_loss_idx]
                )

            loss = sum(dict_loss.values())

            ############################################################

            #################### 5. Back propagation ####################
            loss.backward()
            if args.hyperparams.gradient_clip:
                nn.utils.clip_grad_norm_(
                    self.model.parameters(), max_norm=args.hyperparams.gradient_clip
                )

            losses.update(loss.item(), clean.size(0))

            loss_logger["train/train loss"] = losses.avg
            for key, value in dict_loss.items():
                loss_logger[f"train/{key}"] = value.item()

            self.optimizer.step()

            self.model.zero_grad(
                set_to_none=True
            )  # set_to_none=True is for memory saving

            if args.hyperparams.ema:
                self.model_ema.update()
            ############################################################

            # ###################### 6. Plot ######################

            if i % 30 == 0:
                # loss print for multiple loss function, averaged over ranks
                multiple_score = torch.tensor(
                    list(loss_logger.values()), dtype=torch.float32
                ).to(args.gpu)
                gathered_score_list = [
                    torch.ones_like(multiple_score)
                    for _ in range(dist.get_world_size())
                ]
                dist.all_gather(gathered_score_list, multiple_score)
                gathered_score = torch.mean(
                    torch.stack(gathered_score_list, dim=0), dim=0
                )
                if args.gpu == 0:
                    print(f"Epoch {epoch}, step {i} / {len(self.train_loader)}")
                    for index, key in enumerate(loss_logger):
                        print(f"{key} : {round(gathered_score[index].item(), 6)}")

        # Epoch-level training loss: mean of each rank's running average.
        single_score = torch.tensor([losses.avg], dtype=torch.float32).to(args.gpu)

        gathered_score_list = [
            torch.ones_like(single_score) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(gathered_score_list, single_score)
        gathered_score = torch.mean(torch.cat(gathered_score_list)).item()
        if args.gpu == 0:
            self.train_losses.append(gathered_score)
            if args.wandb_params.use_wandb:
                # Log the rank-averaged float (the original logged the local
                # CUDA tensor ``single_score`` here).
                loss_logger["train/train loss"] = gathered_score
                loss_logger["train/epoch"] = epoch
                wandb.log(loss_logger)
        ############################################################

    def multi_validate(self, args, epoch):
        """Validate for one epoch, step the LR scheduler and checkpoint.

        The validation loss is the count-weighted mean over all ranks.  On
        GPU 0 this also updates early stopping (``self.stop``), saves the
        loss plot, the checkpoint and the run parameters as JSON.
        """
        if args.gpu == 0:
            print(f"Epoch {epoch} Validation session!")

        losses = utils.AverageMeter()
        loss_logger = {}

        self.model.eval()
        eval_model = self.model_ema if args.hyperparams.ema else self.model

        with torch.no_grad():
            for i, values in enumerate(self.valid_loader, start=1):
                mixture, clean, song_name, *_ = values

                mixture = mixture.cuda(args.gpu, non_blocking=True)
                clean = clean.cuda(args.gpu, non_blocking=True)
                target = clean

                dict_loss = {}
                if not args.data_params.singleset_num_frames:
                    # Process the whole track in a single forward pass.
                    estimates, *estimates_vars = eval_model(mixture)
                    if args.task_params.dataset == "delimit":
                        estimates = estimates_vars[0]

                    estimates = estimates[..., : clean.size(-1)]

                else:  # use SingleTrackSet
                    # Process the track chunk by chunk and stitch the trimmed
                    # chunk outputs back together.
                    db = SingleTrackSet(
                        mixture[0],
                        hop_length=args.data_params.nhop,
                        num_frame=args.data_params.singleset_num_frames,
                        target_name=args.task_params.target,
                    )
                    separated = []

                    for item in db:
                        estimates, *estimates_vars = eval_model(
                            item.unsqueeze(0).to(args.gpu)
                        )

                        if args.task_params.dataset == "delimit":
                            estimates = estimates_vars[0]

                        # NOTE(review): the second model output is used here
                        # regardless of dataset, unlike the whole-track branch
                        # above -- confirm this is intended for non-delimit runs.
                        separated.append(
                            estimates_vars[0][
                                ..., db.trim_length : -db.trim_length
                            ].clone()
                        )

                    estimates = torch.cat(separated, dim=-1)
                    estimates = estimates[..., : target.shape[-1]]

                for valid_loss_idx, single_valid_loss_func in enumerate(
                    args.model_loss_params.valid_loss_func
                ):
                    loss_else = self.criterion[single_valid_loss_func](
                        estimates,
                        target,
                    )
                    dict_loss[single_valid_loss_func] = (
                        loss_else.mean()
                        * args.model_loss_params.valid_loss_scales[valid_loss_idx]
                    )

                loss = sum(dict_loss.values())

                losses.update(loss.item(), clean.size(0))

        # Combine (sum, count) from every rank into one weighted mean.
        list_sum_count = torch.tensor(
            [losses.sum, losses.count], dtype=torch.float32
        ).to(args.gpu)
        list_gathered_sum_count = [
            torch.ones_like(list_sum_count) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(list_gathered_sum_count, list_sum_count)
        gathered_score = reduce(
            torch.stack(list_gathered_sum_count), "s c -> c", "sum"
        )  # s: sum of losses.sum, c: sum of losses.count
        gathered_score = (gathered_score[0] / gathered_score[1]).item()

        loss_logger["valid/valid loss"] = gathered_score
        # Per-loss entries hold this rank's values from the last batch only.
        for key, value in dict_loss.items():
            loss_logger[f"valid/{key}"] = value.item()

        # The cosine schedule is stepped by epoch; everything else is stepped
        # on the validation loss.
        if args.hyperparams.lr_scheduler == "cos_warmup":
            self.scheduler.step(epoch)
        else:
            self.scheduler.step(gathered_score)

        if args.wandb_params.use_wandb and args.gpu == 0:
            loss_logger["valid/epoch"] = epoch
            wandb.log(loss_logger)

        if args.gpu == 0:
            self.valid_losses.append(gathered_score)

            self.stop = self.es.step(gathered_score)

            print(f"Epoch {epoch}, validation loss : {round(gathered_score, 6)}")

            plt.plot(self.train_losses, label="train loss")
            plt.plot(self.valid_losses, label="valid loss")
            plt.legend(loc="upper right")
            plt.savefig(f"{args.output}/loss_graph_{args.task_params.target}.png")
            plt.close()

            is_best = gathered_score == self.es.best

            save_states = {
                "epoch": epoch,
                "state_dict": self.model.module.state_dict()
                if not args.hyperparams.ema
                else self.model_ema.state_dict(),
                "best_loss": self.es.best,
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict(),
            }

            utils.save_checkpoint(
                save_states,
                state_dict_only=is_best,
                path=args.output,
                target=args.task_params.target,
            )

            # Record the epoch duration exactly once (the original appended it
            # twice per epoch, doubling the saved history).
            self.train_times.append(time.time() - self.end)

            if is_best:
                self.best_epoch = epoch

            # save params
            params = {
                "epochs_trained": epoch,
                "args": args.toDict(),
                "best_loss": self.es.best,
                "best_epoch": self.best_epoch,
                "train_loss_history": self.train_losses,
                "valid_loss_history": self.valid_losses,
                "train_time_history": self.train_times,
                "num_bad_epochs": self.es.num_bad_epochs,
            }

            with open(
                f"{args.output}/{args.task_params.target}.json", "w"
            ) as outfile:
                outfile.write(json.dumps(params, indent=4, sort_keys=True))

            print(
                f"Epoch {epoch} train completed. Took {round(self.train_times[-1], 3)} seconds"
            )

    def resume(self, args):
        """Restore model/optimizer state and histories from ``args.dir_params.resume``.

        Reads ``<resume>/<target>.json`` and ``<resume>/<target>.chkpnt``.
        With ``continual_train`` the learning rate is reset to
        ``args.hyperparams.lr`` and the scheduler / early-stopping state is
        not restored.
        """
        print(f"Resume checkpoint from: {args.dir_params.resume}:")
        loc = f"cuda:{args.gpu}"
        checkpoint_path = f"{args.dir_params.resume}/{args.task_params.target}"
        with open(f"{checkpoint_path}.json", "r") as stream:
            results = json.load(stream)
        checkpoint = torch.load(f"{checkpoint_path}.chkpnt", map_location=loc)

        if args.hyperparams.ema:
            self.model_ema.load_state_dict(checkpoint["state_dict"])
        else:
            self.model.load_state_dict(checkpoint["state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])

        if (
            args.dir_params.continual_train
        ):  # we want to use a pre-trained model but not want to use lr_scheduler history.
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = args.hyperparams.lr
        else:
            self.scheduler.load_state_dict(checkpoint["scheduler"])
            self.es.best = results["best_loss"]
            self.es.num_bad_epochs = results["num_bad_epochs"]

        self.start_epoch = results["epochs_trained"]
        self.train_losses = results["train_loss_history"]
        self.valid_losses = results["valid_loss_history"]
        self.train_times = results["train_time_history"]
        self.best_epoch = results["best_epoch"]
        if args.sys_params.rank % args.ngpus_per_node == 0:
            print(
                f"=> loaded checkpoint {checkpoint_path} (epoch {results['epochs_trained']})"
            )

    def cal_loss(self, args, loss_input):
        """Apply ``self.criterion[key]`` to each ``key -> argument tuple`` entry.

        Returns a dict with the same keys holding the criterion results.
        ``args`` is unused.
        """
        return {key: self.criterion[key](*value) for key, value in loss_input.items()}

    def cal_multiple_losses(self, args, dict_loss_name_input):
        """Run ``cal_loss`` on every named group of loss inputs.

        Returns a dict mapping each group name to its ``cal_loss`` result.
        """
        return {
            loss_name: self.cal_loss(args, loss_input)
            for loss_name, loss_input in dict_loss_name_input.items()
        }
|
test_ddp.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# To be honest... this is not ddp.
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import argparse
|
5 |
+
import glob
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import tqdm
|
9 |
+
import musdb
|
10 |
+
import librosa
|
11 |
+
import soundfile as sf
|
12 |
+
import pyloudnorm as pyln
|
13 |
+
from dotmap import DotMap
|
14 |
+
|
15 |
+
from models import load_model_with_args
|
16 |
+
from separate_func import (
|
17 |
+
conv_tasnet_separate,
|
18 |
+
)
|
19 |
+
from utils import str2bool, db2linear
|
20 |
+
|
21 |
+
|
22 |
+
# A monitor interval of 0 is tqdm's documented way to switch off its background
# monitor thread. NOTE(review): confirm this workaround is still needed.
tqdm.monitor_interval = 0
|
23 |
+
|
24 |
+
|
25 |
+
def separate_track_with_model(
    args, model, device, track_audio, track_name, meter, augmented_gain
):
    """Run the de-limiting model on one track and return its estimates.

    Dispatches on ``args.model_loss_params.architecture``; only the
    Conv-TasNet variants are supported, and all arguments are forwarded to
    ``conv_tasnet_separate``.  Inference runs with gradients disabled.

    Raises:
        NotImplementedError: if the architecture has no separation routine
            (previously this surfaced as an UnboundLocalError).
    """
    architecture = args.model_loss_params.architecture
    if architecture not in ("conv_tasnet_mask_on_output", "conv_tasnet"):
        raise NotImplementedError(
            f"No separation routine for architecture {architecture!r}"
        )

    with torch.no_grad():
        return conv_tasnet_separate(
            args,
            model,
            device,
            track_audio,
            track_name,
            meter=meter,
            augmented_gain=augmented_gain,
        )
|
44 |
+
|
45 |
+
|
46 |
+
def _save_mixed_output(args, meter, orig_audio, estimates, out_path):
    """Blend the loudness-matched original with the estimate and write a wav.

    ``orig_audio`` is transposed before the loudness measurement and the
    blend is transposed before writing, mirroring the original inline code.
    The original is first gained to ``args.save_output_loudnorm`` LUFS, then
    mixed as ``orig * r + estimates * (1 - r)`` with
    ``r = args.save_mixed_output``.
    """
    track_lufs = meter.integrated_loudness(orig_audio.T)
    gain_db = args.save_output_loudnorm - track_lufs
    orig_audio = orig_audio * db2linear(gain_db, eps=0.0)

    ratio = args.save_mixed_output
    mixed_output = orig_audio * ratio + estimates * (1 - ratio)

    # Make sure the per-track directory exists before writing into it.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    sf.write(out_path, mixed_output.T, args.data_params.sample_rate)


def main():
    """Load a trained model and de-limit MUSDB test tracks or a folder of audio.

    Command-line options are merged with the training configuration stored in
    ``<output_directory>/checkpoint/<exp_name>/<target>.json``; options given
    on the command line take precedence over the stored ones.
    """
    parser = argparse.ArgumentParser(description="model test.py")

    parser.add_argument("--target", type=str, default="all")
    parser.add_argument("--data_root", type=str, default="/path/to/musdb_XL")
    parser.add_argument(
        "--use_musdb",
        type=str2bool,
        default=True,
        help="Use musdb test data or just want to inference other samples?",
    )
    parser.add_argument("--exp_name", type=str, default="delimit_6_s")
    parser.add_argument("--manual_output_name", type=str, default=None)
    parser.add_argument(
        "--output_directory", type=str, default="/path/to/results"
    )
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--save_name_as_target", type=str2bool, default=True)
    parser.add_argument(
        "--loudnorm_input_lufs",
        type=float,
        default=None,
        help="If you want to use loudnorm, input target lufs",
    )
    parser.add_argument(
        "--use_singletrackset",
        type=str2bool,
        default=False,
        help="Use SingleTrackSet for X-UMX",
    )
    parser.add_argument(
        "--best_model",
        type=str2bool,
        default=True,
        help="Use best model or lastly saved model",
    )
    parser.add_argument(
        "--save_output_loudnorm",
        type=float,
        default=None,
        help="Save loudness normalized outputs or not. If you want to save, input target loudness",
    )
    parser.add_argument(
        "--save_mixed_output",
        type=float,
        default=None,
        help="Save original+delimited-estimation mixed output with a ratio of default 0.5 (orginal) and 1 - 0.5 (estimation)",
    )
    parser.add_argument(
        "--save_16k_mono",
        type=str2bool,
        default=False,
        help="Save 16k mono wav files for FAD evaluation.",
    )
    parser.add_argument(
        "--save_histogram",
        type=str2bool,
        default=False,
        help="Save histogram of the output. Only valid when the task is 'delimit'",
    )

    args, _ = parser.parse_known_args()

    # The mixed output is built from the original gained to the output
    # loudness, so both options are needed together.  Fail here instead of
    # with a TypeError after the first track has been processed.
    if args.save_mixed_output and args.save_output_loudnorm is None:
        parser.error("--save_mixed_output requires --save_output_loudnorm")

    args.output_dir = f"{args.output_directory}/checkpoint/{args.exp_name}"
    with open(f"{args.output_dir}/{args.target}.json", "r") as f:
        args_dict = json.load(f)
    args_dict = DotMap(args_dict)

    # Fill in the stored training settings without overriding CLI options.
    for key, value in args_dict["args"].items():
        if key not in vars(args):
            setattr(args, key, value)

    args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"

    if args.manual_output_name is not None:
        args.test_output_dir = f"{args.output_directory}/test/{args.manual_output_name}"
    os.makedirs(args.test_output_dir, exist_ok=True)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
    )

    ###################### Define Models ######################
    our_model = load_model_with_args(args)
    our_model = our_model.to(device)
    print(our_model)
    pytorch_total_params = sum(
        p.numel() for p in our_model.parameters() if p.requires_grad
    )
    print("Total number of parameters", pytorch_total_params)
    # Future work => Torchinfo would be better for this purpose.

    if args.best_model:
        # The .pth file is passed to load_state_dict() directly.
        target_model_path = f"{args.output_dir}/{args.target}.pth"
        checkpoint = torch.load(target_model_path, map_location=device)
        our_model.load_state_dict(checkpoint)
    else:  # when using lastly saved model
        target_model_path = f"{args.output_dir}/{args.target}.chkpnt"
        checkpoint = torch.load(target_model_path, map_location=device)
        our_model.load_state_dict(checkpoint["state_dict"])

    our_model.eval()

    meter = pyln.Meter(44100)

    if args.use_musdb:
        test_tracks = musdb.DB(root=args.data_root, subsets="test", is_wav=True)

        for track in tqdm.tqdm(test_tracks):
            track_name = track.name
            track_audio = track.audio

            orig_audio = track_audio.copy()

            augmented_gain = None
            print("Now De-limiting : ", track_name)

            if args.loudnorm_input_lufs:  # If you want to use loud-normalized input
                track_lufs = meter.integrated_loudness(track_audio)
                augmented_gain = args.loudnorm_input_lufs - track_lufs
                track_audio = track_audio * db2linear(augmented_gain, eps=0.0)

            # The model input is the transposed track with a batch axis added.
            track_audio = (
                torch.as_tensor(track_audio.T, dtype=torch.float32)
                .unsqueeze(0)
                .to(device)
            )

            estimates = separate_track_with_model(
                args, our_model, device, track_audio, track_name, meter, augmented_gain
            )

            if args.save_mixed_output:
                _save_mixed_output(
                    args,
                    meter,
                    orig_audio.T,
                    estimates,
                    f"{args.test_output_dir}/{track_name}/{str(args.save_mixed_output)}_mixed.wav",
                )
    else:
        test_tracks = glob.glob(f"{args.data_root}/*.wav") + glob.glob(
            f"{args.data_root}/*.mp3"
        )

        for track in tqdm.tqdm(test_tracks):
            track_name = os.path.basename(track).replace(".wav", "").replace(".mp3", "")
            track_audio, sr = librosa.load(
                track, sr=None, mono=False
            )  # sr should be 44100

            orig_audio = track_audio.copy()

            if sr != 44100:
                raise ValueError("Sample rate should be 44100")
            augmented_gain = None
            print("Now De-limiting : ", track_name)

            if args.loudnorm_input_lufs:  # If you want to use loud-normalized input
                track_lufs = meter.integrated_loudness(track_audio.T)
                augmented_gain = args.loudnorm_input_lufs - track_lufs
                track_audio = track_audio * db2linear(augmented_gain, eps=0.0)

            track_audio = (
                torch.as_tensor(track_audio, dtype=torch.float32)
                .unsqueeze(0)
                .to(device)
            )

            estimates = separate_track_with_model(
                args, our_model, device, track_audio, track_name, meter, augmented_gain
            )

            if args.save_mixed_output:
                _save_mixed_output(
                    args,
                    meter,
                    orig_audio,
                    estimates,
                    f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
                )


if __name__ == "__main__":
    main()
|
train_ddp.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import time
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.multiprocessing as mp
|
6 |
+
import torch.distributed as dist
|
7 |
+
import wandb
|
8 |
+
|
9 |
+
from solver_ddp import Solver
|
10 |
+
|
11 |
+
|
12 |
+
def train(args):
    """Build the solver and spawn one training process per local GPU.

    The number of GPUs per node is taken as the visible CUDA device count
    divided by ``args.sys_params.n_nodes``; the resulting world size is
    written back to ``args.sys_params.world_size`` before spawning.

    Args:
        args: Configuration object exposing ``sys_params.n_nodes``; it is
            mutated in place and handed to every spawned ``worker``.
    """
    print("hello")
    solver = Solver()

    n_nodes = args.sys_params.n_nodes
    ngpus_per_node = int(torch.cuda.device_count() / n_nodes)
    print(f"use {ngpus_per_node} gpu machine")

    args.sys_params.world_size = ngpus_per_node * n_nodes
    mp.spawn(
        worker,
        args=(solver, ngpus_per_node, args),
        nprocs=ngpus_per_node,
    )
|
20 |
+
|
21 |
+
|
22 |
+
def worker(gpu, solver, ngpus_per_node, args):
    """Run the train/validate loop for one process of the DDP job.

    Args:
        gpu: Index of this process on its node (first argument supplied by
            ``mp.spawn``).
        solver: Solver object providing ``set_gpu``, ``train``,
            ``multi_validate``, ``train_sampler``, ``start_epoch`` and
            ``stop``.
        ngpus_per_node: Number of processes launched on this node.
        args: Configuration object; ``sys_params.rank``, ``gpu`` and
            ``ngpus_per_node`` are overwritten here.
    """

    def _close_wandb():
        # Finish the wandb run only when logging was enabled.
        if args.wandb_params.use_wandb:
            wandb.finish()

    # Turn the node rank stored in the config into a global process rank.
    global_rank = args.sys_params.rank * ngpus_per_node + gpu
    args.sys_params.rank = global_rank
    dist.init_process_group(
        init_method="env://",
        backend="nccl",
        rank=args.sys_params.rank,
        world_size=args.sys_params.world_size,
    )
    args.gpu = gpu
    args.ngpus_per_node = ngpus_per_node

    solver.set_gpu(args)

    first_epoch = solver.start_epoch
    if args.dir_params.resume:
        # When resuming, continue with the epoch after the stored one.
        first_epoch = first_epoch + 1

    for epoch in range(first_epoch, args.hyperparams.epochs + 1):
        solver.train_sampler.set_epoch(epoch)
        solver.train(args, epoch)

        time.sleep(1)

        solver.multi_validate(args, epoch)

        if not (solver.stop == True):  # noqa: E712 - comparison kept as in the original
            continue

        print("Apply Early Stopping")
        _close_wandb()
        sys.exit()

    _close_wandb()
|
utils/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .read_wave_utils import (
|
2 |
+
load_wav_arbitrary_position_mono,
|
3 |
+
load_wav_specific_position_mono,
|
4 |
+
load_wav_arbitrary_position_stereo,
|
5 |
+
load_wav_specific_position_stereo,
|
6 |
+
)
|
7 |
+
from .loudness_utils import (
|
8 |
+
linear2db,
|
9 |
+
db2linear,
|
10 |
+
normalize_mag_spec,
|
11 |
+
denormalize_mag_spec,
|
12 |
+
loudness_match_and_norm,
|
13 |
+
loudness_normal_match_and_norm,
|
14 |
+
loudness_normal_match_and_norm_output_louder_first,
|
15 |
+
loudnorm,
|
16 |
+
)
|
17 |
+
from .logging import save_img_and_npy, save_checkpoint, AverageMeter, EarlyStopping
|
18 |
+
from .lr_scheduler import CosineAnnealingWarmUpRestarts
|
19 |
+
from .train_utils import worker_init_fn, str2bool, get_config
|
utils/logging.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib
|
6 |
+
|
7 |
+
matplotlib.use("Agg")
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
|
11 |
+
def save_img_and_npy(path, matrix):
    """Write ``matrix`` as an image to ``path + ".png"``.

    Despite the function name, only the PNG is written here; no ``.npy``
    file is produced.

    Args:
        path: Output path without extension.
        matrix: 2-D data passed to ``plt.imsave`` (drawn with the origin at
            the lower edge).
    """
    png_path = path + ".png"
    plt.imsave(png_path, matrix, origin="lower")
|
13 |
+
|
14 |
+
|
15 |
+
def save_checkpoint(state, state_dict_only, path, target):
    """Save a training checkpoint and, optionally, the bare weights.

    The full ``state`` is always written to ``<path>/<target>.chkpnt``.
    When ``state_dict_only`` is truthy, ``state["state_dict"]`` is written
    in addition to ``<path>/<target>.pth``.

    Args:
        state: Picklable checkpoint object; must contain ``"state_dict"``
            when ``state_dict_only`` is truthy.
        state_dict_only: Whether to also dump the weights on their own.
        path: Destination directory.
        target: File name stem.
    """
    checkpoint_file = os.path.join(path, target + ".chkpnt")
    torch.save(state, checkpoint_file)

    if not state_dict_only:
        return

    # save just the weights
    weights_file = os.path.join(path, target + ".pth")
    torch.save(state["state_dict"], weights_file)
|
20 |
+
|
21 |
+
|
22 |
+
class AverageMeter(object):
    """Keep the most recent value and a running weighted average.

    Attributes are ``val`` (last value), ``sum`` (weighted total),
    ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
|
39 |
+
|
40 |
+
|
41 |
+
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Args:
        mode: ``"min"`` when lower metrics are better, ``"max"`` otherwise.
        min_delta: Margin by which a metric must beat the best value to
            count as an improvement.
        patience: Number of consecutive non-improving steps tolerated
            before ``step`` returns True. ``0`` disables plateau-based
            stopping (NaN metrics still stop).

    Raises:
        ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
    """

    def __init__(self, mode="min", min_delta=0, patience=10):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta)

        if patience == 0:
            # With no patience every step is treated as an improvement.
            self.is_better = lambda a, b: True

    def step(self, metrics):
        """Register a new metric value and report whether to stop.

        Args:
            metrics: Scalar metric for the current epoch.

        Returns:
            True if training should stop, False otherwise.
        """
        # A NaN metric means training diverged; stop right away, also on
        # the very first call so NaN never becomes the stored best value.
        if np.isnan(metrics):
            return True

        if self.best is None:
            self.best = metrics
            return False

        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1

        # patience == 0 means "never stop on a plateau"; without this guard
        # the comparison below would be 0 >= 0 and stop on the second call.
        if self.patience == 0:
            return False

        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta):
        """Install the comparison function matching ``mode``."""
        if mode not in {"min", "max"}:
            raise ValueError(f"mode {mode} is unknown!")
        if mode == "min":
            self.is_better = lambda a, best: a < best - min_delta
        if mode == "max":
            self.is_better = lambda a, best: a > best + min_delta
|
utils/loudness_utils.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def linear2db(x, eps=1e-5, scale=20):
    """Convert a linear value to decibels: ``scale * log10(x + eps)``.

    Args:
        x: Linear value(s), scalar or array-like accepted by NumPy.
        eps: Offset added before the logarithm.
        scale: Decibel multiplier (20 by default).
    """
    offset_value = x + eps
    return scale * np.log10(offset_value)
|
9 |
+
|
10 |
+
|
11 |
+
def db2linear(x, eps=1e-5, scale=20):
    """Convert decibels to a linear value: ``10 ** (x / scale) - eps``.

    Args:
        x: Value(s) in decibels.
        eps: Offset subtracted from the result (mirrors ``linear2db``).
        scale: Decibel divisor (20 by default).
    """
    exponent = x / scale
    return 10**exponent - eps
|
13 |
+
|
14 |
+
|
15 |
+
def normalize_mag_spec(S, min_level_db=-100.0):
    """Map values from ``[min_level_db, 0]`` onto ``[0, 1]`` with clamping.

    Args:
        S: Tensor of levels (presumably dB magnitudes; confirm with callers).
        min_level_db: Level that maps to 0.
    """
    rescaled = (S - min_level_db) / -min_level_db
    return torch.clamp(rescaled, min=0.0, max=1.0)
|
17 |
+
|
18 |
+
|
19 |
+
def denormalize_mag_spec(S, min_level_db=-100.0):
    """Map values from ``[0, 1]`` back onto ``[min_level_db, 0]``.

    The input is clamped to ``[0, 1]`` before rescaling.

    Args:
        S: Tensor of normalized values.
        min_level_db: Level that 0 maps to.
    """
    bounded = torch.clamp(S, min=0.0, max=1.0)
    return bounded * -min_level_db + min_level_db
|
21 |
+
|
22 |
+
|
23 |
+
def loudness_match_and_norm(audio1, audio2, meter):
    """Scale ``audio2`` so its integrated loudness matches ``audio1``.

    If either measured loudness is infinite, both inputs are returned
    untouched.

    Args:
        audio1: Reference audio, returned unchanged.
        audio2: Audio to rescale.
        meter: Object exposing ``integrated_loudness(audio)``.

    Returns:
        Tuple ``(audio1, audio2)`` with ``audio2`` possibly rescaled.
    """
    reference_lufs = meter.integrated_loudness(audio1)
    current_lufs = meter.integrated_loudness(audio2)

    if not (np.isinf(reference_lufs) or np.isinf(current_lufs)):
        gain_db = reference_lufs - current_lufs
        audio2 = audio2 * db2linear(gain_db)

    return audio1, audio2
|
33 |
+
|
34 |
+
|
35 |
+
def loudness_normal_match_and_norm(audio1, audio2, meter):
    """Scale ``audio2`` to a loudness drawn around that of ``audio1``.

    The target loudness is sampled from a normal distribution centered on
    the loudness of ``audio1`` with a standard deviation of 6.0. If either
    measured loudness is infinite, both inputs are returned untouched.

    Args:
        audio1: Reference audio, returned unchanged.
        audio2: Audio to rescale.
        meter: Object exposing ``integrated_loudness(audio)``.

    Returns:
        Tuple ``(audio1, audio2)`` with ``audio2`` possibly rescaled.
    """
    reference_lufs = meter.integrated_loudness(audio1)
    current_lufs = meter.integrated_loudness(audio2)

    if not (np.isinf(reference_lufs) or np.isinf(current_lufs)):
        sampled_lufs = random.normalvariate(reference_lufs, 6.0)
        gain_db = sampled_lufs - current_lufs
        audio2 = audio2 * db2linear(gain_db)

    return audio1, audio2
|
46 |
+
|
47 |
+
|
48 |
+
def loudness_normal_match_and_norm_output_louder_first(audio1, audio2, meter):
    """Scale ``audio2`` so that ``audio1`` tends to be the louder one.

    The target loudness for ``audio2`` is sampled from a normal
    distribution centered 2.0 below the loudness of ``audio1`` with a
    standard deviation of 2.0. If either measured loudness is infinite,
    both inputs are returned untouched.

    Args:
        audio1: Reference audio, returned unchanged.
        audio2: Audio to rescale.
        meter: Object exposing ``integrated_loudness(audio)``.

    Returns:
        Tuple ``(audio1, audio2)`` with ``audio2`` possibly rescaled.
    """
    reference_lufs = meter.integrated_loudness(audio1)
    current_lufs = meter.integrated_loudness(audio2)

    if not (np.isinf(reference_lufs) or np.isinf(current_lufs)):
        # Aim below the reference so that audio1 usually ends up louder.
        sampled_lufs = random.normalvariate(reference_lufs - 2.0, 2.0)
        gain_db = sampled_lufs - current_lufs
        audio2 = audio2 * db2linear(gain_db)

    return audio1, audio2
|
61 |
+
|
62 |
+
|
63 |
+
def loudnorm(audio, target_lufs, meter, eps=1e-5):
    """Normalize ``audio`` to ``target_lufs`` and report the applied gain.

    Args:
        audio: Audio to normalize.
        target_lufs: Desired integrated loudness.
        meter: Object exposing ``integrated_loudness(audio)``.
        eps: Forwarded to ``db2linear``.

    Returns:
        Tuple ``(audio, gain_db)``. When the measured loudness is infinite
        the audio is returned unchanged with a gain of ``0.0``.
    """
    measured_lufs = meter.integrated_loudness(audio)
    if np.isinf(measured_lufs):
        return audio, 0.0

    adjusted_gain = target_lufs - measured_lufs
    return audio * db2linear(adjusted_gain, eps), adjusted_gain
|
utils/lr_scheduler.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
4 |
+
|
5 |
+
|
6 |
+
class CosineAnnealingWarmUpRestarts(_LRScheduler):
    """Cosine annealing with a linear warm-up phase and periodic restarts.

    Within each cycle the learning rate rises linearly from the base rate
    to ``eta_max`` over ``T_up`` steps, then follows a half cosine back
    down to the base rate. After each restart the peak is multiplied by
    ``gamma``.

    Args:
        optimizer: Wrapped optimizer.
        T_0: Length of the first cycle (positive int).
        T_mult: Cycle-length multiplier applied at each restart (int >= 1).
        eta_max: Peak learning rate of the first cycle.
        T_up: Number of warm-up steps per cycle (int >= 0).
        gamma: Factor applied to the peak for every completed cycle.
        last_epoch: Index of the last epoch; ``-1`` starts from scratch.

    Raises:
        ValueError: If ``T_0``, ``T_mult`` or ``T_up`` fail validation.
    """

    def __init__(
        self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1.0, last_epoch=-1
    ):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if T_up < 0 or not isinstance(T_up, int):
            raise ValueError("Expected positive integer T_up, but got {}".format(T_up))

        # Cycle geometry.
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.T_up = T_up
        # Peak learning rate and its per-cycle decay.
        self.base_eta_max = eta_max
        self.eta_max = eta_max
        self.gamma = gamma
        # Position inside the schedule.
        self.cycle = 0
        self.T_cur = last_epoch
        # The base constructor performs the initial step, so every
        # attribute above has to exist before it runs.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per param group for the current step."""
        position = self.T_cur
        if position == -1:
            return self.base_lrs

        peak = self.eta_max
        warmup = self.T_up
        if position < warmup:
            # Linear ramp from the base rate towards the peak.
            return [
                (peak - base_lr) * position / warmup + base_lr
                for base_lr in self.base_lrs
            ]

        # Half-cosine decay from the peak back to the base rate.
        cosine_term = 1 + math.cos(math.pi * (position - warmup) / (self.T_i - warmup))
        return [
            base_lr + (peak - base_lr) * cosine_term / 2 for base_lr in self.base_lrs
        ]

    def _advance_one_epoch(self):
        """Move one step forward, restarting the cycle when it is complete."""
        self.T_cur = self.T_cur + 1
        if self.T_cur >= self.T_i:
            self.cycle += 1
            self.T_cur = self.T_cur - self.T_i
            self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up

    def _jump_to_epoch(self, epoch):
        """Recompute cycle index and position for an explicit ``epoch``."""
        if not epoch >= self.T_0:
            self.T_i = self.T_0
            self.T_cur = epoch
            return

        if self.T_mult == 1:
            self.T_cur = epoch % self.T_0
            self.cycle = epoch // self.T_0
            return

        # Geometric cycle lengths: solve for the index of the cycle that
        # contains ``epoch``.
        n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
        self.cycle = n
        self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (self.T_mult - 1)
        # NOTE(review): unlike _advance_one_epoch this ignores T_up - confirm.
        self.T_i = self.T_0 * self.T_mult ** (n)

    def step(self, epoch=None):
        """Advance the schedule and write the new rates to the optimizer.

        Args:
            epoch: Explicit epoch to jump to; ``None`` advances by one.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self._advance_one_epoch()
        else:
            self._jump_to_epoch(epoch)

        self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr
|
utils/read_wave_utils.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import math
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import librosa
|
6 |
+
import torchaudio
|
7 |
+
|
8 |
+
|
9 |
+
def load_wav_arbitrary_position_mono(filename, sample_rate, seq_duration):
    """Load a random ``seq_duration``-second excerpt of a file as mono.

    Files longer than the requested excerpt are read from a random offset.
    Shorter files are read from the start and zero-padded to the requested
    length, with the padding split randomly between both sides.

    Audio is read with ``sr=None`` (no resampling); ``sample_rate`` is only
    used to convert between seconds and samples.
    NOTE(review): this assumes the file's rate equals ``sample_rate`` - confirm.

    Args:
        filename: Path of the audio file.
        sample_rate: Sample rate used for the second/sample conversions.
        seq_duration: Excerpt length in seconds.

    Returns:
        The loaded (and possibly padded) audio array.
    """
    num_frames = torchaudio.info(filename).num_frames
    wanted_samples = librosa.time_to_samples(seq_duration, sr=sample_rate)

    if not (num_frames > wanted_samples):
        missing = wanted_samples - num_frames
        audio, _ = librosa.load(filename, sr=None, offset=0, duration=seq_duration)
        left_pad = random.randint(0, missing)
        return np.pad(audio, (left_pad, missing - left_pad))

    offset_sec = random.randint(0, int(num_frames - wanted_samples - 1)) / sample_rate
    audio, _ = librosa.load(
        filename, sr=None, offset=offset_sec, duration=seq_duration
    )
    return audio
|
28 |
+
|
29 |
+
|
30 |
+
def load_wav_specific_position_mono(
    filename, sample_rate, seq_duration, start_position
):
    """Load ``seq_duration`` seconds of mono audio from a given position.

    A negative ``start_position`` is treated as 0, and a position at or
    beyond the end of the file falls back to 0 as well. When the file ends
    before the excerpt does, the result is zero-padded on the right.

    Args:
        filename: Path of the audio file.
        sample_rate: Sample rate used for the second/sample conversions.
        seq_duration: Excerpt length in seconds.
        start_position: Requested start in seconds.

    Returns:
        The loaded (and possibly padded) audio array.
    """
    num_frames = torchaudio.info(filename).num_frames
    wanted_samples = librosa.time_to_samples(seq_duration, sr=sample_rate)

    # Negative positions start from the beginning of the file.
    offset_sec = max(start_position, 0)
    offset_samples = librosa.time_to_samples(offset_sec, sr=sample_rate)

    # A start past the end of the audio also falls back to the beginning.
    if num_frames <= offset_samples:
        offset_sec, offset_samples = 0, 0

    audio, _ = librosa.load(
        filename, sr=None, offset=offset_sec, duration=seq_duration
    )

    excerpt_end = offset_samples + wanted_samples
    if num_frames < excerpt_end:
        audio = np.pad(audio, (0, excerpt_end - num_frames))

    return audio
|
55 |
+
|
56 |
+
|
57 |
+
# load wav file from arbitrary positions of 16bit stereo wav file
|
58 |
+
def load_wav_arbitrary_position_stereo(
    filename, sample_rate, seq_duration, return_pos=False
):
    """Load a random ``seq_duration``-second excerpt without downmixing.

    Files that are not longer than the requested excerpt are read from the
    start and zero-padded on the right, instead of failing while drawing
    the random offset.

    Audio is read with ``sr=None`` and ``mono=False``; ``sample_rate`` is
    only used to convert between seconds and samples.
    NOTE(review): the 2-D padding assumes a multi-channel file - confirm.

    Args:
        filename: Path of the audio file.
        sample_rate: Sample rate used for the second/sample conversions.
        seq_duration: Excerpt length in seconds.
        return_pos: When True, also return the start offset in seconds.

    Returns:
        The audio array, or ``(audio, start_seconds)`` if ``return_pos``.
    """
    length = torchaudio.info(filename).num_frames
    read_length = librosa.time_to_samples(seq_duration, sr=sample_rate)

    # Clamp at 0: for short files the upper bound would be negative and
    # random.randint would raise ValueError before the padding below runs.
    latest_start = max(0, int(length - math.ceil(seq_duration * sample_rate) - 1))
    random_start_sample = random.randint(0, latest_start)
    random_start_sec = librosa.samples_to_time(random_start_sample, sr=sample_rate)
    X, _ = librosa.load(
        filename, sr=None, mono=False, offset=random_start_sec, duration=seq_duration
    )

    # Right-pad along the time axis when the file ends before the excerpt.
    shortfall = (random_start_sample + read_length) - length
    if shortfall > 0:
        X = np.pad(X, ((0, 0), (0, shortfall)))

    if return_pos:
        return X, random_start_sec
    return X
|
81 |
+
|
82 |
+
|
83 |
+
def load_wav_specific_position_stereo(
    filename, sample_rate, seq_duration, start_position
):
    """Load ``seq_duration`` seconds from a given position, no downmixing.

    A negative ``start_position`` is treated as 0, and a position at or
    beyond the end of the file falls back to 0 as well. When the file ends
    before the excerpt does, the result is zero-padded on the right along
    the last axis.

    Args:
        filename: Path of the audio file.
        sample_rate: Sample rate used for the second/sample conversions.
        seq_duration: Excerpt length in seconds.
        start_position: Requested start in seconds.

    Returns:
        The loaded (and possibly padded) audio array.
    """
    num_frames = torchaudio.info(filename).num_frames
    wanted_samples = librosa.time_to_samples(seq_duration, sr=sample_rate)

    # Negative positions start from the beginning of the file.
    offset_sec = max(start_position, 0)
    offset_samples = librosa.time_to_samples(offset_sec, sr=sample_rate)

    # A start past the end of the audio also falls back to the beginning.
    if num_frames <= offset_samples:
        offset_sec, offset_samples = 0, 0

    audio, _ = librosa.load(
        filename, sr=None, mono=False, offset=offset_sec, duration=seq_duration
    )

    excerpt_end = offset_samples + wanted_samples
    if num_frames < excerpt_end:
        audio = np.pad(audio, ((0, 0), (0, excerpt_end - num_frames)))

    return audio
|
utils/train_utils.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import yaml
|
4 |
+
from dotmap import DotMap
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
def worker_init_fn(worker_id):
    """Reseed NumPy's global RNG with a worker-specific seed.

    The new seed is the first word of the current NumPy RNG state plus
    ``worker_id``.

    Args:
        worker_id: Integer index of the worker.
    """
    first_state_word = np.random.get_state()[1][0]
    np.random.seed(first_state_word + worker_id)
|
10 |
+
|
11 |
+
|
12 |
+
def str2bool(v):
    """Parse a command-line style boolean.

    Args:
        v: A bool (returned unchanged) or a string such as ``"yes"``,
            ``"true"``, ``"0"`` or ``"n"``, matched case-insensitively.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the string is not a known spelling.
    """
    # Values that are already booleans have no .lower(); pass them through.
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
19 |
+
|
20 |
+
|
21 |
+
def get_config(config_name="default"):
    """Load ``./configs/<config_name>.yaml`` and wrap it in a ``DotMap``.

    The path is relative to the current working directory.

    Args:
        config_name: File name stem of the YAML config.

    Returns:
        The parsed configuration as a ``DotMap``.
    """
    config_path = f"./configs/{config_name}.yaml"
    with open(config_path, "r") as f:
        # NOTE(review): FullLoader is acceptable for trusted local configs;
        # use a safe loader if configs can come from untrusted sources.
        raw_config = yaml.load(f, Loader=yaml.FullLoader)
    return DotMap(raw_config)
|
weight/all.json
ADDED
@@ -0,0 +1,957 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"args": {
|
3 |
+
"classifier_params": {
|
4 |
+
"chosen_source_mean": 0.7,
|
5 |
+
"chosen_source_std": 0.15,
|
6 |
+
"classifier_activation": "softmax",
|
7 |
+
"classifier_n_classes": 4,
|
8 |
+
"classifier_n_srcs": 4,
|
9 |
+
"freeze_when_mixit": true,
|
10 |
+
"melspec_power": 2.0,
|
11 |
+
"model_name": "hrnet_w18_small",
|
12 |
+
"n_mels": 128,
|
13 |
+
"other_source_mean": 0.3,
|
14 |
+
"other_source_std": 0.15,
|
15 |
+
"pretrained_model": false,
|
16 |
+
"use_one_source_prob": 0.2,
|
17 |
+
"use_stereo": true
|
18 |
+
},
|
19 |
+
"conv_tasnet_params": {
|
20 |
+
"bn_chan": 128,
|
21 |
+
"decoder_activation": "sigmoid",
|
22 |
+
"encoder_activation": "relu",
|
23 |
+
"hid_chan": 512,
|
24 |
+
"kernel_size": 128,
|
25 |
+
"mask_act": "relu",
|
26 |
+
"n_blocks": 5,
|
27 |
+
"n_filters": 512,
|
28 |
+
"n_repeats": 2,
|
29 |
+
"skip_chan": 128,
|
30 |
+
"stride": 64
|
31 |
+
},
|
32 |
+
"data_params": {
|
33 |
+
"custom_limiter_attack_range": null,
|
34 |
+
"custom_limiter_release_range": null,
|
35 |
+
"limitaug_custom_target_lufs": null,
|
36 |
+
"limitaug_custom_target_lufs_std": null,
|
37 |
+
"limitaug_method": "ozone",
|
38 |
+
"limitaug_mode": null,
|
39 |
+
"nb_channels": 2,
|
40 |
+
"nfft": 4096,
|
41 |
+
"nhop": 1024,
|
42 |
+
"random_mix": true,
|
43 |
+
"sample_rate": 44100,
|
44 |
+
"samples_per_track": 128,
|
45 |
+
"seq_dur": 4.0,
|
46 |
+
"singleset_num_frames": null,
|
47 |
+
"target_limitaug_custom_target_lufs": null,
|
48 |
+
"target_limitaug_custom_target_lufs_std": null,
|
49 |
+
"target_limitaug_mode": null,
|
50 |
+
"target_loudnorm_lufs": -14.0,
|
51 |
+
"use_fixed": 0.019
|
52 |
+
},
|
53 |
+
"dir_params": {
|
54 |
+
"continual_train": false,
|
55 |
+
"delimit_valid_L_root": null,
|
56 |
+
"delimit_valid_root": null,
|
57 |
+
"exp_name": "convtasnet_35",
|
58 |
+
"output_directory": "/data2/personal/jeon/delimit/results",
|
59 |
+
"ozone_root": "/data5/personal/jeon/delimit/data",
|
60 |
+
"pretrained_classifier": null,
|
61 |
+
"resume": null,
|
62 |
+
"root": "/data1/Music/musdb18hq"
|
63 |
+
},
|
64 |
+
"gpu": 0,
|
65 |
+
"hyperparams": {
|
66 |
+
"batch_size": 8,
|
67 |
+
"ema": false,
|
68 |
+
"epochs": 200,
|
69 |
+
"gradient_clip": 5.0,
|
70 |
+
"lr": 3e-05,
|
71 |
+
"lr_decay_gamma": 0.5,
|
72 |
+
"lr_decay_patience": 15,
|
73 |
+
"lr_scheduler": "step_lr",
|
74 |
+
"optimizer": "adamw",
|
75 |
+
"patience": 50,
|
76 |
+
"weight_decay": 0.01
|
77 |
+
},
|
78 |
+
"img_check": "/data2/personal/jeon/delimit/results/img_check/convtasnet_35",
|
79 |
+
"invest_unet_params": {
|
80 |
+
"bn_factor": 16,
|
81 |
+
"f_down_layers": null,
|
82 |
+
"first_conv_activation": "relu",
|
83 |
+
"input_channels": 4,
|
84 |
+
"internal_channels": 24,
|
85 |
+
"kernel_size_f": 3,
|
86 |
+
"kernel_size_t": 3,
|
87 |
+
"last_activation": "identity",
|
88 |
+
"min_bn_units": 16,
|
89 |
+
"n_blocks": 7,
|
90 |
+
"n_internal_layers": 5,
|
91 |
+
"t_down_layers": null,
|
92 |
+
"tfc_tdf_activation": "relu",
|
93 |
+
"tfc_tdf_bias": true,
|
94 |
+
"tif_init_mode": null
|
95 |
+
},
|
96 |
+
"model_loss_params": {
|
97 |
+
"architecture": "conv_tasnet_mask_on_output",
|
98 |
+
"efficient_mixit_threshold": null,
|
99 |
+
"train_loss_func": [
|
100 |
+
"si_sdr"
|
101 |
+
],
|
102 |
+
"train_loss_scales": [
|
103 |
+
1.0
|
104 |
+
],
|
105 |
+
"valid_loss_func": [
|
106 |
+
"si_sdr"
|
107 |
+
],
|
108 |
+
"valid_loss_scales": [
|
109 |
+
1.0
|
110 |
+
]
|
111 |
+
},
|
112 |
+
"ngpus_per_node": 1,
|
113 |
+
"output": "/data2/personal/jeon/delimit/results/checkpoint/convtasnet_35",
|
114 |
+
"resume": {},
|
115 |
+
"sample_rate": {},
|
116 |
+
"sys_params": {
|
117 |
+
"n_nodes": 1,
|
118 |
+
"nb_workers": 4,
|
119 |
+
"port": null,
|
120 |
+
"rank": 0,
|
121 |
+
"seed": 777,
|
122 |
+
"world_size": 1
|
123 |
+
},
|
124 |
+
"task_params": {
|
125 |
+
"bleeding_nsrcs": null,
|
126 |
+
"dataset": "delimit",
|
127 |
+
"target": "all",
|
128 |
+
"train": true
|
129 |
+
},
|
130 |
+
"umx_params": {
|
131 |
+
"activation": "relu",
|
132 |
+
"dropout_rate": 0.05,
|
133 |
+
"hidden_size": 512,
|
134 |
+
"instead_tanh_activation": "tanh",
|
135 |
+
"lstm_dropout_rate": 0.4,
|
136 |
+
"nb_layers": 3,
|
137 |
+
"normalization": "bn",
|
138 |
+
"umx_get_statistics": false
|
139 |
+
},
|
140 |
+
"wandb_params": {
|
141 |
+
"entity": "vinyne",
|
142 |
+
"project": "delimit",
|
143 |
+
"rerun_id": null,
|
144 |
+
"sweep": false,
|
145 |
+
"use_wandb": true
|
146 |
+
}
|
147 |
+
},
|
148 |
+
"best_epoch": 183,
|
149 |
+
"best_loss": -14.165373802185059,
|
150 |
+
"epochs_trained": 200,
|
151 |
+
"num_bad_epochs": 17,
|
152 |
+
"train_loss_history": [
|
153 |
+
-11.723381042480469,
|
154 |
+
-11.759103775024414,
|
155 |
+
-11.818404197692871,
|
156 |
+
-11.88597583770752,
|
157 |
+
-11.882278442382812,
|
158 |
+
-11.943178176879883,
|
159 |
+
-11.909675598144531,
|
160 |
+
-11.93053913116455,
|
161 |
+
-11.922198295593262,
|
162 |
+
-12.013456344604492,
|
163 |
+
-12.106053352355957,
|
164 |
+
-11.999975204467773,
|
165 |
+
-12.067265510559082,
|
166 |
+
-12.079473495483398,
|
167 |
+
-12.13272762298584,
|
168 |
+
-12.15418529510498,
|
169 |
+
-12.08314037322998,
|
170 |
+
-12.152527809143066,
|
171 |
+
-12.096565246582031,
|
172 |
+
-12.219636917114258,
|
173 |
+
-12.246475219726562,
|
174 |
+
-12.170637130737305,
|
175 |
+
-12.188806533813477,
|
176 |
+
-12.230484962463379,
|
177 |
+
-12.207123756408691,
|
178 |
+
-12.307502746582031,
|
179 |
+
-12.200200080871582,
|
180 |
+
-12.284586906433105,
|
181 |
+
-12.244038581848145,
|
182 |
+
-12.302275657653809,
|
183 |
+
-12.200104713439941,
|
184 |
+
-12.31570816040039,
|
185 |
+
-12.42324447631836,
|
186 |
+
-12.352653503417969,
|
187 |
+
-12.367401123046875,
|
188 |
+
-12.295838356018066,
|
189 |
+
-12.404874801635742,
|
190 |
+
-12.338440895080566,
|
191 |
+
-12.365501403808594,
|
192 |
+
-12.365768432617188,
|
193 |
+
-12.225799560546875,
|
194 |
+
-12.26883602142334,
|
195 |
+
-12.390016555786133,
|
196 |
+
-12.410661697387695,
|
197 |
+
-12.311858177185059,
|
198 |
+
-12.408061027526855,
|
199 |
+
-12.396013259887695,
|
200 |
+
-12.353321075439453,
|
201 |
+
-12.470121383666992,
|
202 |
+
-12.469389915466309,
|
203 |
+
-12.452675819396973,
|
204 |
+
-12.381932258605957,
|
205 |
+
-12.31003475189209,
|
206 |
+
-12.412126541137695,
|
207 |
+
-12.267746925354004,
|
208 |
+
-12.440984725952148,
|
209 |
+
-12.413816452026367,
|
210 |
+
-12.417757034301758,
|
211 |
+
-12.4945650100708,
|
212 |
+
-12.445524215698242,
|
213 |
+
-12.38110065460205,
|
214 |
+
-12.454893112182617,
|
215 |
+
-12.390727996826172,
|
216 |
+
-12.339771270751953,
|
217 |
+
-12.528243064880371,
|
218 |
+
-12.434144973754883,
|
219 |
+
-12.43438720703125,
|
220 |
+
-12.458473205566406,
|
221 |
+
-12.424423217773438,
|
222 |
+
-12.387894630432129,
|
223 |
+
-12.438997268676758,
|
224 |
+
-12.528799057006836,
|
225 |
+
-12.423232078552246,
|
226 |
+
-12.534538269042969,
|
227 |
+
-12.495400428771973,
|
228 |
+
-12.53675651550293,
|
229 |
+
-12.551910400390625,
|
230 |
+
-12.478575706481934,
|
231 |
+
-12.461804389953613,
|
232 |
+
-12.483702659606934,
|
233 |
+
-12.474960327148438,
|
234 |
+
-12.441666603088379,
|
235 |
+
-12.42241096496582,
|
236 |
+
-12.48852252960205,
|
237 |
+
-12.513558387756348,
|
238 |
+
-12.40845012664795,
|
239 |
+
-12.555559158325195,
|
240 |
+
-12.589385032653809,
|
241 |
+
-12.395785331726074,
|
242 |
+
-12.496671676635742,
|
243 |
+
-12.554829597473145,
|
244 |
+
-12.530548095703125,
|
245 |
+
-12.564457893371582,
|
246 |
+
-12.52737808227539,
|
247 |
+
-12.608246803283691,
|
248 |
+
-12.3996000289917,
|
249 |
+
-12.433905601501465,
|
250 |
+
-12.490935325622559,
|
251 |
+
-12.477506637573242,
|
252 |
+
-12.470728874206543,
|
253 |
+
-12.564470291137695,
|
254 |
+
-12.525967597961426,
|
255 |
+
-12.502660751342773,
|
256 |
+
-12.440997123718262,
|
257 |
+
-12.576118469238281,
|
258 |
+
-12.538352966308594,
|
259 |
+
-12.512738227844238,
|
260 |
+
-12.525115966796875,
|
261 |
+
-12.511483192443848,
|
262 |
+
-12.571795463562012,
|
263 |
+
-12.59391975402832,
|
264 |
+
-12.442131996154785,
|
265 |
+
-12.617898941040039,
|
266 |
+
-12.495210647583008,
|
267 |
+
-12.551814079284668,
|
268 |
+
-12.4913330078125,
|
269 |
+
-12.626816749572754,
|
270 |
+
-12.556028366088867,
|
271 |
+
-12.477901458740234,
|
272 |
+
-12.596776008605957,
|
273 |
+
-12.597326278686523,
|
274 |
+
-12.484386444091797,
|
275 |
+
-12.660898208618164,
|
276 |
+
-12.440162658691406,
|
277 |
+
-12.530372619628906,
|
278 |
+
-12.51207447052002,
|
279 |
+
-12.503606796264648,
|
280 |
+
-12.670214653015137,
|
281 |
+
-12.51667308807373,
|
282 |
+
-12.546160697937012,
|
283 |
+
-12.504158020019531,
|
284 |
+
-12.6427001953125,
|
285 |
+
-12.56100082397461,
|
286 |
+
-12.506058692932129,
|
287 |
+
-12.637288093566895,
|
288 |
+
-12.572591781616211,
|
289 |
+
-12.544734001159668,
|
290 |
+
-12.604019165039062,
|
291 |
+
-12.549866676330566,
|
292 |
+
-12.521714210510254,
|
293 |
+
-12.601127624511719,
|
294 |
+
-12.629931449890137,
|
295 |
+
-12.587185859680176,
|
296 |
+
-12.605366706848145,
|
297 |
+
-12.606413841247559,
|
298 |
+
-12.536269187927246,
|
299 |
+
-12.577346801757812,
|
300 |
+
-12.703147888183594,
|
301 |
+
-12.60477066040039,
|
302 |
+
-12.603355407714844,
|
303 |
+
-12.536528587341309,
|
304 |
+
-12.601842880249023,
|
305 |
+
-12.698568344116211,
|
306 |
+
-12.72192668914795,
|
307 |
+
-12.663148880004883,
|
308 |
+
-12.644909858703613,
|
309 |
+
-12.631479263305664,
|
310 |
+
-12.596253395080566,
|
311 |
+
-12.61674690246582,
|
312 |
+
-12.701379776000977,
|
313 |
+
-12.664311408996582,
|
314 |
+
-12.646204948425293,
|
315 |
+
-12.597058296203613,
|
316 |
+
-12.652384757995605,
|
317 |
+
-12.579480171203613,
|
318 |
+
-12.757433891296387,
|
319 |
+
-12.686827659606934,
|
320 |
+
-12.65634536743164,
|
321 |
+
-12.552176475524902,
|
322 |
+
-12.625761032104492,
|
323 |
+
-12.652499198913574,
|
324 |
+
-12.668974876403809,
|
325 |
+
-12.700301170349121,
|
326 |
+
-12.591926574707031,
|
327 |
+
-12.54333782196045,
|
328 |
+
-12.541864395141602,
|
329 |
+
-12.720565795898438,
|
330 |
+
-12.625009536743164,
|
331 |
+
-12.577120780944824,
|
332 |
+
-12.67569637298584,
|
333 |
+
-12.634958267211914,
|
334 |
+
-12.660367012023926,
|
335 |
+
-12.646204948425293,
|
336 |
+
-12.713308334350586,
|
337 |
+
-12.734916687011719,
|
338 |
+
-12.602835655212402,
|
339 |
+
-12.596168518066406,
|
340 |
+
-12.66109848022461,
|
341 |
+
-12.568808555603027,
|
342 |
+
-12.719843864440918,
|
343 |
+
-12.746356010437012,
|
344 |
+
-12.602999687194824,
|
345 |
+
-12.632689476013184,
|
346 |
+
-12.715725898742676,
|
347 |
+
-12.671126365661621,
|
348 |
+
-12.659911155700684,
|
349 |
+
-12.755860328674316,
|
350 |
+
-12.591080665588379,
|
351 |
+
-12.623464584350586,
|
352 |
+
-12.643362045288086
|
353 |
+
],
|
354 |
+
"train_time_history": [
|
355 |
+
308.12283968925476,
|
356 |
+
308.12408661842346,
|
357 |
+
305.56318974494934,
|
358 |
+
305.6093053817749,
|
359 |
+
304.1926734447479,
|
360 |
+
304.2103099822998,
|
361 |
+
301.78035831451416,
|
362 |
+
301.7819468975067,
|
363 |
+
317.8168547153473,
|
364 |
+
317.818119764328,
|
365 |
+
314.8585801124573,
|
366 |
+
314.8601076602936,
|
367 |
+
311.61795926094055,
|
368 |
+
311.61953926086426,
|
369 |
+
316.2616910934448,
|
370 |
+
316.2639091014862,
|
371 |
+
312.59282636642456,
|
372 |
+
312.59408020973206,
|
373 |
+
314.6765525341034,
|
374 |
+
314.6778757572174,
|
375 |
+
314.4039900302887,
|
376 |
+
314.40531301498413,
|
377 |
+
313.9343922138214,
|
378 |
+
313.9356322288513,
|
379 |
+
315.1470823287964,
|
380 |
+
315.14854192733765,
|
381 |
+
317.65793561935425,
|
382 |
+
317.65903544425964,
|
383 |
+
316.41589403152466,
|
384 |
+
316.4171371459961,
|
385 |
+
316.253050327301,
|
386 |
+
316.2544617652893,
|
387 |
+
316.2039670944214,
|
388 |
+
316.20542550086975,
|
389 |
+
316.30707120895386,
|
390 |
+
316.30964159965515,
|
391 |
+
315.7812213897705,
|
392 |
+
315.7832131385803,
|
393 |
+
315.77191638946533,
|
394 |
+
315.7732570171356,
|
395 |
+
315.7776229381561,
|
396 |
+
315.77907848358154,
|
397 |
+
315.80343294143677,
|
398 |
+
315.8051166534424,
|
399 |
+
314.40133929252625,
|
400 |
+
314.403112411499,
|
401 |
+
314.32283997535706,
|
402 |
+
314.32424092292786,
|
403 |
+
314.90000677108765,
|
404 |
+
314.90242648124695,
|
405 |
+
313.8207128047943,
|
406 |
+
313.8227391242981,
|
407 |
+
313.86938881874084,
|
408 |
+
313.87079215049744,
|
409 |
+
316.9037547111511,
|
410 |
+
316.9056947231293,
|
411 |
+
317.4321286678314,
|
412 |
+
317.43361139297485,
|
413 |
+
316.41515493392944,
|
414 |
+
316.4182825088501,
|
415 |
+
315.69741559028625,
|
416 |
+
315.699245929718,
|
417 |
+
315.9285054206848,
|
418 |
+
315.930716753006,
|
419 |
+
314.25376319885254,
|
420 |
+
314.25567531585693,
|
421 |
+
312.997665643692,
|
422 |
+
313.0005877017975,
|
423 |
+
315.5962414741516,
|
424 |
+
315.5977747440338,
|
425 |
+
315.49425506591797,
|
426 |
+
315.4961242675781,
|
427 |
+
315.980491399765,
|
428 |
+
315.98283791542053,
|
429 |
+
315.5533638000488,
|
430 |
+
315.55492901802063,
|
431 |
+
313.9896593093872,
|
432 |
+
313.99131321907043,
|
433 |
+
314.3214478492737,
|
434 |
+
314.3232262134552,
|
435 |
+
314.6442220211029,
|
436 |
+
314.64620661735535,
|
437 |
+
315.69726514816284,
|
438 |
+
315.7001700401306,
|
439 |
+
314.78302001953125,
|
440 |
+
314.7847316265106,
|
441 |
+
313.14448523521423,
|
442 |
+
313.1465194225311,
|
443 |
+
311.8232834339142,
|
444 |
+
311.8251144886017,
|
445 |
+
318.88225960731506,
|
446 |
+
318.8843643665314,
|
447 |
+
319.20725083351135,
|
448 |
+
319.20886182785034,
|
449 |
+
317.81429648399353,
|
450 |
+
317.8159878253937,
|
451 |
+
320.23738193511963,
|
452 |
+
320.23904752731323,
|
453 |
+
315.8315763473511,
|
454 |
+
315.83344054222107,
|
455 |
+
317.32581615448,
|
456 |
+
317.3274848461151,
|
457 |
+
316.7596924304962,
|
458 |
+
316.7628848552704,
|
459 |
+
316.3167974948883,
|
460 |
+
316.3188827037811,
|
461 |
+
316.44567823410034,
|
462 |
+
316.44802141189575,
|
463 |
+
313.8653395175934,
|
464 |
+
313.8687484264374,
|
465 |
+
308.43933939933777,
|
466 |
+
308.44151163101196,
|
467 |
+
312.1857454776764,
|
468 |
+
312.18967509269714,
|
469 |
+
307.8407344818115,
|
470 |
+
307.84401679039,
|
471 |
+
307.48447585105896,
|
472 |
+
307.48623728752136,
|
473 |
+
310.300940990448,
|
474 |
+
310.3029022216797,
|
475 |
+
310.32225275039673,
|
476 |
+
310.3257050514221,
|
477 |
+
309.351779460907,
|
478 |
+
309.3539865016937,
|
479 |
+
309.4356527328491,
|
480 |
+
309.4380919933319,
|
481 |
+
312.63360381126404,
|
482 |
+
312.63535809516907,
|
483 |
+
311.7453818321228,
|
484 |
+
311.7476508617401,
|
485 |
+
311.3258364200592,
|
486 |
+
311.327698469162,
|
487 |
+
312.28111600875854,
|
488 |
+
312.2828998565674,
|
489 |
+
311.3383209705353,
|
490 |
+
311.34200048446655,
|
491 |
+
306.9764757156372,
|
492 |
+
306.9787657260895,
|
493 |
+
309.35506653785706,
|
494 |
+
309.3569576740265,
|
495 |
+
310.2506465911865,
|
496 |
+
310.2529339790344,
|
497 |
+
310.65880727767944,
|
498 |
+
310.66108298301697,
|
499 |
+
311.18562865257263,
|
500 |
+
311.1874952316284,
|
501 |
+
309.07765316963196,
|
502 |
+
309.07997822761536,
|
503 |
+
313.3008818626404,
|
504 |
+
313.3029179573059,
|
505 |
+
311.267498254776,
|
506 |
+
311.26989102363586,
|
507 |
+
310.62635374069214,
|
508 |
+
310.6306185722351,
|
509 |
+
308.1883268356323,
|
510 |
+
308.19112515449524,
|
511 |
+
310.65689158439636,
|
512 |
+
310.65896558761597,
|
513 |
+
308.98754620552063,
|
514 |
+
309.03386878967285,
|
515 |
+
309.21512937545776,
|
516 |
+
309.2185757160187,
|
517 |
+
309.93750405311584,
|
518 |
+
309.93965554237366,
|
519 |
+
310.2938587665558,
|
520 |
+
310.29592084884644,
|
521 |
+
308.24257493019104,
|
522 |
+
308.2463102340698,
|
523 |
+
310.6870594024658,
|
524 |
+
310.6905345916748,
|
525 |
+
310.7875945568085,
|
526 |
+
310.78995156288147,
|
527 |
+
310.9882712364197,
|
528 |
+
310.9906806945801,
|
529 |
+
310.95856285095215,
|
530 |
+
310.96066546440125,
|
531 |
+
312.4489221572876,
|
532 |
+
312.45125246047974,
|
533 |
+
312.24022579193115,
|
534 |
+
312.2863116264343,
|
535 |
+
309.68400406837463,
|
536 |
+
309.6862533092499,
|
537 |
+
309.64014887809753,
|
538 |
+
309.64232993125916,
|
539 |
+
309.9094281196594,
|
540 |
+
309.9119017124176,
|
541 |
+
309.40677762031555,
|
542 |
+
309.40893173217773,
|
543 |
+
309.1595506668091,
|
544 |
+
309.1617259979248,
|
545 |
+
308.4178020954132,
|
546 |
+
308.4198989868164,
|
547 |
+
308.5063133239746,
|
548 |
+
308.5085346698761,
|
549 |
+
307.5796904563904,
|
550 |
+
307.5972898006439,
|
551 |
+
309.66309905052185,
|
552 |
+
309.66530561447144,
|
553 |
+
312.70798993110657,
|
554 |
+
312.7102212905884,
|
555 |
+
310.2431013584137,
|
556 |
+
310.2453660964966,
|
557 |
+
312.2640459537506,
|
558 |
+
312.26635122299194,
|
559 |
+
311.27055287361145,
|
560 |
+
311.27321219444275,
|
561 |
+
312.58145689964294,
|
562 |
+
312.58376598358154,
|
563 |
+
313.1553518772125,
|
564 |
+
313.1574249267578,
|
565 |
+
308.4067575931549,
|
566 |
+
308.4089684486389,
|
567 |
+
311.0251498222351,
|
568 |
+
311.0274658203125,
|
569 |
+
308.0227520465851,
|
570 |
+
308.02498388290405,
|
571 |
+
308.0182030200958,
|
572 |
+
308.0204634666443,
|
573 |
+
308.63523149490356,
|
574 |
+
308.63751220703125,
|
575 |
+
308.53969383239746,
|
576 |
+
308.5420751571655,
|
577 |
+
306.51329946517944,
|
578 |
+
306.51555824279785,
|
579 |
+
309.59846591949463,
|
580 |
+
309.60128831863403,
|
581 |
+
305.3712034225464,
|
582 |
+
305.37409830093384,
|
583 |
+
305.43984270095825,
|
584 |
+
305.4421238899231,
|
585 |
+
309.3166663646698,
|
586 |
+
309.3195414543152,
|
587 |
+
308.8618497848511,
|
588 |
+
308.86409974098206,
|
589 |
+
304.8731882572174,
|
590 |
+
304.8755958080292,
|
591 |
+
306.6576888561249,
|
592 |
+
306.663143157959,
|
593 |
+
306.6716537475586,
|
594 |
+
306.6740062236786,
|
595 |
+
309.47339940071106,
|
596 |
+
309.47578954696655,
|
597 |
+
307.73386335372925,
|
598 |
+
307.7363700866699,
|
599 |
+
308.0688214302063,
|
600 |
+
308.07209277153015,
|
601 |
+
311.58968901634216,
|
602 |
+
311.6099576950073,
|
603 |
+
308.70460844039917,
|
604 |
+
308.70710158348083,
|
605 |
+
312.0563473701477,
|
606 |
+
312.05881452560425,
|
607 |
+
310.89456367492676,
|
608 |
+
310.9119510650635,
|
609 |
+
308.73097705841064,
|
610 |
+
308.73414373397827,
|
611 |
+
309.4255359172821,
|
612 |
+
309.42857813835144,
|
613 |
+
311.0751721858978,
|
614 |
+
311.07801842689514,
|
615 |
+
309.5860447883606,
|
616 |
+
309.5896680355072,
|
617 |
+
309.87396597862244,
|
618 |
+
309.8803391456604,
|
619 |
+
310.9183626174927,
|
620 |
+
310.92147397994995,
|
621 |
+
308.4321529865265,
|
622 |
+
308.4359757900238,
|
623 |
+
312.4424922466278,
|
624 |
+
312.44731879234314,
|
625 |
+
312.3443009853363,
|
626 |
+
312.3491401672363,
|
627 |
+
310.3139410018921,
|
628 |
+
310.3165555000305,
|
629 |
+
312.09410762786865,
|
630 |
+
312.09656262397766,
|
631 |
+
311.11144399642944,
|
632 |
+
311.1577796936035,
|
633 |
+
309.1589603424072,
|
634 |
+
309.16152119636536,
|
635 |
+
312.51157093048096,
|
636 |
+
312.51463317871094,
|
637 |
+
314.15198159217834,
|
638 |
+
314.15485286712646,
|
639 |
+
310.00070810317993,
|
640 |
+
310.0033264160156,
|
641 |
+
311.2290298938751,
|
642 |
+
311.23188829421997,
|
643 |
+
313.0510983467102,
|
644 |
+
313.05362153053284,
|
645 |
+
313.48791670799255,
|
646 |
+
313.4910161495209,
|
647 |
+
307.60272216796875,
|
648 |
+
307.6053590774536,
|
649 |
+
303.84622287750244,
|
650 |
+
303.8494029045105,
|
651 |
+
304.8547012805939,
|
652 |
+
304.85784125328064,
|
653 |
+
310.63141536712646,
|
654 |
+
310.63450264930725,
|
655 |
+
304.8634753227234,
|
656 |
+
304.8664004802704,
|
657 |
+
308.1505949497223,
|
658 |
+
308.15428018569946,
|
659 |
+
310.18936228752136,
|
660 |
+
310.1920323371887,
|
661 |
+
309.2550263404846,
|
662 |
+
309.2577428817749,
|
663 |
+
310.08596634864807,
|
664 |
+
310.08910751342773,
|
665 |
+
307.4643654823303,
|
666 |
+
307.4670605659485,
|
667 |
+
308.558221578598,
|
668 |
+
308.5638659000397,
|
669 |
+
309.7440264225006,
|
670 |
+
309.7467608451843,
|
671 |
+
308.2091956138611,
|
672 |
+
308.2125828266144,
|
673 |
+
307.0199763774872,
|
674 |
+
307.02332496643066,
|
675 |
+
306.3482081890106,
|
676 |
+
306.35128688812256,
|
677 |
+
307.3764581680298,
|
678 |
+
307.37923669815063,
|
679 |
+
311.61060428619385,
|
680 |
+
311.6135311126709,
|
681 |
+
306.8187861442566,
|
682 |
+
306.8240280151367,
|
683 |
+
305.19880175590515,
|
684 |
+
305.20313119888306,
|
685 |
+
309.252712726593,
|
686 |
+
309.256165266037,
|
687 |
+
310.80801463127136,
|
688 |
+
310.81236577033997,
|
689 |
+
309.1079206466675,
|
690 |
+
309.11073756217957,
|
691 |
+
310.6556165218353,
|
692 |
+
310.65838623046875,
|
693 |
+
310.94868993759155,
|
694 |
+
310.95155143737793,
|
695 |
+
308.4552607536316,
|
696 |
+
308.4580717086792,
|
697 |
+
308.2857587337494,
|
698 |
+
308.2886221408844,
|
699 |
+
306.4856150150299,
|
700 |
+
306.4887855052948,
|
701 |
+
306.8667871952057,
|
702 |
+
306.86966013908386,
|
703 |
+
306.1964519023895,
|
704 |
+
306.2005341053009,
|
705 |
+
308.2178611755371,
|
706 |
+
308.22126364707947,
|
707 |
+
305.94888377189636,
|
708 |
+
305.9523375034332,
|
709 |
+
307.48926973342896,
|
710 |
+
307.4920620918274,
|
711 |
+
307.60354018211365,
|
712 |
+
307.63674998283386,
|
713 |
+
307.2473645210266,
|
714 |
+
307.2501358985901,
|
715 |
+
308.16573452949524,
|
716 |
+
308.2115182876587,
|
717 |
+
307.30736780166626,
|
718 |
+
307.3109815120697,
|
719 |
+
307.2137475013733,
|
720 |
+
307.2178246974945,
|
721 |
+
308.5944905281067,
|
722 |
+
308.59843826293945,
|
723 |
+
307.2346291542053,
|
724 |
+
307.2382435798645,
|
725 |
+
308.417338848114,
|
726 |
+
308.4208617210388,
|
727 |
+
305.5816307067871,
|
728 |
+
305.5852439403534,
|
729 |
+
307.69459652900696,
|
730 |
+
307.6975119113922,
|
731 |
+
307.20833134651184,
|
732 |
+
307.212299823761,
|
733 |
+
305.9614431858063,
|
734 |
+
305.965185880661,
|
735 |
+
305.31594157218933,
|
736 |
+
305.3195445537567,
|
737 |
+
307.46696519851685,
|
738 |
+
307.47079825401306,
|
739 |
+
306.23966455459595,
|
740 |
+
306.2433180809021,
|
741 |
+
306.1235647201538,
|
742 |
+
306.1273248195648,
|
743 |
+
307.02436780929565,
|
744 |
+
307.02733421325684,
|
745 |
+
306.9687819480896,
|
746 |
+
306.97225856781006,
|
747 |
+
306.23205065727234,
|
748 |
+
306.2356073856354,
|
749 |
+
305.3567383289337,
|
750 |
+
305.36028504371643,
|
751 |
+
305.94446635246277,
|
752 |
+
305.9480822086334,
|
753 |
+
307.2553553581238
|
754 |
+
],
|
755 |
+
"valid_loss_history": [
|
756 |
+
-12.743322372436523,
|
757 |
+
-12.724347114562988,
|
758 |
+
-12.86701488494873,
|
759 |
+
-12.694435119628906,
|
760 |
+
-12.706733703613281,
|
761 |
+
-13.048251152038574,
|
762 |
+
-12.943618774414062,
|
763 |
+
-13.120084762573242,
|
764 |
+
-13.121935844421387,
|
765 |
+
-13.146740913391113,
|
766 |
+
-13.197364807128906,
|
767 |
+
-13.224929809570312,
|
768 |
+
-13.255891799926758,
|
769 |
+
-13.311783790588379,
|
770 |
+
-13.386489868164062,
|
771 |
+
-13.390006065368652,
|
772 |
+
-13.45509147644043,
|
773 |
+
-13.444679260253906,
|
774 |
+
-13.456311225891113,
|
775 |
+
-13.36051082611084,
|
776 |
+
-13.478644371032715,
|
777 |
+
-13.503388404846191,
|
778 |
+
-13.540580749511719,
|
779 |
+
-13.579903602600098,
|
780 |
+
-13.551591873168945,
|
781 |
+
-13.638075828552246,
|
782 |
+
-13.617512702941895,
|
783 |
+
-13.64240550994873,
|
784 |
+
-13.618767738342285,
|
785 |
+
-13.65319538116455,
|
786 |
+
-13.601574897766113,
|
787 |
+
-13.693778038024902,
|
788 |
+
-13.658882141113281,
|
789 |
+
-13.649510383605957,
|
790 |
+
-13.477263450622559,
|
791 |
+
-13.643564224243164,
|
792 |
+
-13.732584953308105,
|
793 |
+
-13.643271446228027,
|
794 |
+
-13.655325889587402,
|
795 |
+
-13.71172046661377,
|
796 |
+
-13.564180374145508,
|
797 |
+
-13.708178520202637,
|
798 |
+
-13.688010215759277,
|
799 |
+
-13.711198806762695,
|
800 |
+
-13.612863540649414,
|
801 |
+
-13.702019691467285,
|
802 |
+
-13.704530715942383,
|
803 |
+
-13.716957092285156,
|
804 |
+
-13.76714038848877,
|
805 |
+
-13.719636917114258,
|
806 |
+
-13.738469123840332,
|
807 |
+
-13.759002685546875,
|
808 |
+
-13.721348762512207,
|
809 |
+
-13.727803230285645,
|
810 |
+
-13.768327713012695,
|
811 |
+
-13.73253345489502,
|
812 |
+
-13.75208568572998,
|
813 |
+
-13.754429817199707,
|
814 |
+
-13.76417064666748,
|
815 |
+
-13.805985450744629,
|
816 |
+
-13.762914657592773,
|
817 |
+
-13.75927448272705,
|
818 |
+
-13.781553268432617,
|
819 |
+
-13.744827270507812,
|
820 |
+
-13.805213928222656,
|
821 |
+
-13.792055130004883,
|
822 |
+
-13.736992835998535,
|
823 |
+
-13.804685592651367,
|
824 |
+
-13.802186012268066,
|
825 |
+
-13.812178611755371,
|
826 |
+
-13.781081199645996,
|
827 |
+
-13.836441993713379,
|
828 |
+
-13.787053108215332,
|
829 |
+
-13.824462890625,
|
830 |
+
-13.827963829040527,
|
831 |
+
-13.768393516540527,
|
832 |
+
-13.824796676635742,
|
833 |
+
-13.809252738952637,
|
834 |
+
-13.820283889770508,
|
835 |
+
-13.811989784240723,
|
836 |
+
-13.845786094665527,
|
837 |
+
-13.801295280456543,
|
838 |
+
-13.795866966247559,
|
839 |
+
-13.847658157348633,
|
840 |
+
-13.841630935668945,
|
841 |
+
-13.887687683105469,
|
842 |
+
-13.838217735290527,
|
843 |
+
-13.833791732788086,
|
844 |
+
-13.8090181350708,
|
845 |
+
-13.810338973999023,
|
846 |
+
-13.812939643859863,
|
847 |
+
-13.813563346862793,
|
848 |
+
-13.72245979309082,
|
849 |
+
-13.829062461853027,
|
850 |
+
-13.820122718811035,
|
851 |
+
-13.764768600463867,
|
852 |
+
-13.882962226867676,
|
853 |
+
-13.887824058532715,
|
854 |
+
-13.874728202819824,
|
855 |
+
-13.83934211730957,
|
856 |
+
-13.854304313659668,
|
857 |
+
-13.853861808776855,
|
858 |
+
-13.878510475158691,
|
859 |
+
-13.855673789978027,
|
860 |
+
-13.935111999511719,
|
861 |
+
-13.873315811157227,
|
862 |
+
-13.88434886932373,
|
863 |
+
-13.913508415222168,
|
864 |
+
-13.804875373840332,
|
865 |
+
-13.874313354492188,
|
866 |
+
-13.925950050354004,
|
867 |
+
-13.898317337036133,
|
868 |
+
-13.861913681030273,
|
869 |
+
-13.83596134185791,
|
870 |
+
-13.907777786254883,
|
871 |
+
-13.832358360290527,
|
872 |
+
-13.936162948608398,
|
873 |
+
-13.925071716308594,
|
874 |
+
-13.906752586364746,
|
875 |
+
-13.87073040008545,
|
876 |
+
-13.964620590209961,
|
877 |
+
-13.925311088562012,
|
878 |
+
-13.974698066711426,
|
879 |
+
-13.957905769348145,
|
880 |
+
-13.918564796447754,
|
881 |
+
-13.975790023803711,
|
882 |
+
-13.988444328308105,
|
883 |
+
-13.959516525268555,
|
884 |
+
-14.01569652557373,
|
885 |
+
-13.992425918579102,
|
886 |
+
-14.039790153503418,
|
887 |
+
-13.940314292907715,
|
888 |
+
-14.011497497558594,
|
889 |
+
-13.953152656555176,
|
890 |
+
-13.920698165893555,
|
891 |
+
-13.960227966308594,
|
892 |
+
-13.907439231872559,
|
893 |
+
-14.014067649841309,
|
894 |
+
-13.972914695739746,
|
895 |
+
-13.942621231079102,
|
896 |
+
-14.019667625427246,
|
897 |
+
-14.037107467651367,
|
898 |
+
-13.85366153717041,
|
899 |
+
-13.980110168457031,
|
900 |
+
-13.97785472869873,
|
901 |
+
-13.983843803405762,
|
902 |
+
-13.843756675720215,
|
903 |
+
-14.002585411071777,
|
904 |
+
-14.026784896850586,
|
905 |
+
-14.028115272521973,
|
906 |
+
-14.02059268951416,
|
907 |
+
-13.985837936401367,
|
908 |
+
-14.076154708862305,
|
909 |
+
-14.060620307922363,
|
910 |
+
-13.936518669128418,
|
911 |
+
-13.957221031188965,
|
912 |
+
-14.017061233520508,
|
913 |
+
-13.995661735534668,
|
914 |
+
-14.056286811828613,
|
915 |
+
-14.037705421447754,
|
916 |
+
-13.940332412719727,
|
917 |
+
-14.092416763305664,
|
918 |
+
-14.024917602539062,
|
919 |
+
-14.002346992492676,
|
920 |
+
-14.026989936828613,
|
921 |
+
-13.944084167480469,
|
922 |
+
-14.002883911132812,
|
923 |
+
-14.120462417602539,
|
924 |
+
-14.043062210083008,
|
925 |
+
-14.008293151855469,
|
926 |
+
-14.040563583374023,
|
927 |
+
-13.994155883789062,
|
928 |
+
-14.08944034576416,
|
929 |
+
-14.078422546386719,
|
930 |
+
-14.014589309692383,
|
931 |
+
-14.083242416381836,
|
932 |
+
-14.104707717895508,
|
933 |
+
-14.103189468383789,
|
934 |
+
-14.063937187194824,
|
935 |
+
-14.0596284866333,
|
936 |
+
-14.059121131896973,
|
937 |
+
-14.102814674377441,
|
938 |
+
-14.165373802185059,
|
939 |
+
-14.106118202209473,
|
940 |
+
-14.107162475585938,
|
941 |
+
-14.085371017456055,
|
942 |
+
-14.123793601989746,
|
943 |
+
-14.053537368774414,
|
944 |
+
-14.077792167663574,
|
945 |
+
-14.056371688842773,
|
946 |
+
-14.033655166625977,
|
947 |
+
-14.096640586853027,
|
948 |
+
-14.057114601135254,
|
949 |
+
-14.115262985229492,
|
950 |
+
-14.074142456054688,
|
951 |
+
-14.067980766296387,
|
952 |
+
-14.118453025817871,
|
953 |
+
-14.117535591125488,
|
954 |
+
-14.126029968261719,
|
955 |
+
-14.117874145507812
|
956 |
+
]
|
957 |
+
}
|
weight/all.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:34f2ef4e5c32542060621f7ea9f7a06a2acf91be22825a38f9270077a7346679
|
3 |
+
size 9424379
|