jeonchangbin49 committed
Commit a00b67a · 1 Parent(s): da27cbe

first commit
.gitattributes CHANGED
@@ -1,35 +1 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 jeonchangbin49
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,2 @@
- ---
- title: De Limiter
- emoji: 🏃
- colorFrom: pink
- colorTo: indigo
- sdk: gradio
- sdk_version: 3.39.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # De-limiter
+ An official demo of "Music De-limiter Networks via Sample-wise Gain Inversion", to be presented at WASPAA 2023.
add.py ADDED
@@ -0,0 +1,293 @@
+ import os
+ import json
+ import argparse
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import torch
+ import tqdm
+ import librosa
+ import librosa.display
+ import soundfile as sf
+ import pyloudnorm as pyln
+ from dotmap import DotMap
+ import gradio as gr
+
+ from models import load_model_with_args
+ from separate_func import (
+     conv_tasnet_separate,
+ )
+ from utils import db2linear
+
+
+ tqdm.monitor_interval = 0
+
+
+ def separate_track_with_model(
+     args, model, device, track_audio, track_name, meter, augmented_gain
+ ):
+     with torch.no_grad():
+         if (
+             args.model_loss_params.architecture == "conv_tasnet_mask_on_output"
+             or args.model_loss_params.architecture == "conv_tasnet"
+         ):
+             estimates = conv_tasnet_separate(
+                 args,
+                 model,
+                 device,
+                 track_audio,
+                 track_name,
+                 meter=meter,
+                 augmented_gain=augmented_gain,
+             )
+
+         return estimates
+
+
+ def main(input, mix_coefficient):
+     parser = argparse.ArgumentParser(description="model test.py")
+     parser.add_argument("--target", type=str, default="all")
+     parser.add_argument("--weight_directory", type=str, default="weight")
+     parser.add_argument("--output_directory", type=str, default="output")
+     parser.add_argument("--use_gpu", type=bool, default=True)
+     parser.add_argument("--save_name_as_target", type=bool, default=False)
+     parser.add_argument(
+         "--loudnorm_input_lufs",
+         type=float,
+         default=None,
+         help="If you want to use loudnorm for the input",
+     )
+     parser.add_argument(
+         "--save_output_loudnorm",
+         type=float,
+         default=-14.0,
+         help="Save loudness normalized outputs or not. If you want to save, input the target loudness",
+     )
+     parser.add_argument(
+         "--save_mixed_output",
+         type=float,
+         default=None,
+         help="Save original+delimited-estimation mixed output with a ratio of default 0.5 (original) and 1 - 0.5 (estimation)",
+     )
+     parser.add_argument(
+         "--save_16k_mono",
+         type=bool,
+         default=False,
+         help="Save 16k mono wav files for FAD evaluation.",
+     )
+     parser.add_argument(
+         "--save_histogram",
+         type=bool,
+         default=False,
+         help="Save histogram of the output. Only valid when the task is 'delimit'",
+     )
+     parser.add_argument(
+         "--use_singletrackset",
+         type=bool,
+         default=False,
+         help="Use SingleTrackSet if input data is too long.",
+     )
+
+     args, _ = parser.parse_known_args()
+
+     with open(f"{args.weight_directory}/{args.target}.json", "r") as f:
+         args_dict = json.load(f)
+         args_dict = DotMap(args_dict)
+
+     for key, value in args_dict["args"].items():
+         if key in list(vars(args).keys()):
+             pass
+         else:
+             setattr(args, key, value)
+
+     args.test_output_dir = f"{args.output_directory}"
+     os.makedirs(args.test_output_dir, exist_ok=True)
+
+     device = torch.device(
+         "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
+     )
+
+     ###################### Define Models ######################
+     our_model = load_model_with_args(args)
+     our_model = our_model.to(device)
+
+     target_model_path = f"{args.weight_directory}/{args.target}.pth"
+     checkpoint = torch.load(target_model_path, map_location=device)
+     our_model.load_state_dict(checkpoint)
+
+     our_model.eval()
+
+     meter = pyln.Meter(44100)
+
+     sr, track_audio = input
+     track_audio = track_audio.T
+     track_name = "gradio_demo"
+
+     orig_audio = track_audio.copy()
+
+     if sr != 44100:
+         raise ValueError("Sample rate should be 44100")
+     augmented_gain = None
+
+     if args.loudnorm_input_lufs:  # If you want to use a loudness-normalized input
+         track_lufs = meter.integrated_loudness(track_audio.T)
+         augmented_gain = args.loudnorm_input_lufs - track_lufs
+         track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
+
+     track_audio = (
+         torch.as_tensor(track_audio, dtype=torch.float32).unsqueeze(0).to(device)
+     )
+
+     estimates = separate_track_with_model(
+         args, our_model, device, track_audio, track_name, meter, augmented_gain
+     )
+
+     if args.save_mixed_output:
+         track_lufs = meter.integrated_loudness(orig_audio.T)
+         augmented_gain = args.save_output_loudnorm - track_lufs
+         orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)
+
+         mixed_output = orig_audio * args.save_mixed_output + estimates * (
+             1 - args.save_mixed_output
+         )
+
+         sf.write(
+             f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
+             mixed_output.T,
+             args.data_params.sample_rate,
+         )
+
+     return (
+         (sr, estimates.T),
+         (sr, orig_audio.T),
+         (sr, orig_audio.T * mix_coefficient + estimates.T * (1 - mix_coefficient)),
+     )
+
+
+ def parallel_mix(input, output, mix_coefficient):
+     sr = 44100
+     return sr, input[1] * mix_coefficient + output[1] * (1 - mix_coefficient)
+
+
+ def int16_to_float32(wav):
+     wav = np.frombuffer(wav, dtype=np.int16)
+     X = wav / 32768
+     return X
+
+
+ def waveform_plot(input, output, prl_mix_output, figsize_x=20, figsize_y=9):
+     sr = 44100
+     fig, ax = plt.subplots(
+         nrows=3, sharex=True, sharey=True, figsize=(figsize_x, figsize_y)
+     )
+     librosa.display.waveshow(int16_to_float32(input[1]).T, sr=sr, ax=ax[0])
+     ax[0].set(title="Loudness Normalized Input")
+     ax[0].label_outer()
+     librosa.display.waveshow(int16_to_float32(output[1]).T, sr=sr, ax=ax[1])
+     ax[1].set(title="De-limiter Output")
+     ax[1].label_outer()
+     librosa.display.waveshow(int16_to_float32(prl_mix_output[1]).T, sr=sr, ax=ax[2])
+     ax[2].set(title="Parallel Mix of the Input and its De-limiter Output")
+     ax[2].label_outer()
+     return fig
+
+
+ with gr.Blocks() as demo:
+     gr.HTML(
+         """
+         <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+           <div
+             style="
+               display: inline-flex;
+               align-items: center;
+               gap: 0.8rem;
+               font-size: 1.75rem;
+             "
+           >
+             <h1 style="font-weight: 900; margin-bottom: 7px;">
+               Music De-limiter
+             </h1>
+           </div>
+           <p style="margin-bottom: 10px; font-size: 94%">
+             A demo for "Music De-limiter Networks via Sample-wise Gain Inversion", to appear at WASPAA 2023.
+             First upload a music (.wav or .mp3) file, then press the "De-limit" button to apply the De-limiter. Since we use a CPU instead of a GPU, it may take a few minutes.
+             You can then apply the Parallel Mix technique, a simple linear mix of the "loudness normalized input" and the "de-limiter output".
+             You can adjust the mixing coefficient yourself.
+             If the coefficient is 0.3, the output will be "loudness_normalized_input * 0.3 + de-limiter_output * 0.7".
+           </p>
+         </div>
+         """
+     )
+     with gr.Row().style(mobile_collapse=False, equal_height=True):
+         with gr.Column():
+             with gr.Box():
+                 input_audio = gr.Audio(source="upload", label="De-limiter Input")
+                 btn = gr.Button("De-limit")
+         with gr.Column():
+             with gr.Box():
+                 loud_norm_input = gr.Audio(label="Loudness Normalized Input (-14 LUFS)")
+             with gr.Box():
+                 output_audio = gr.Audio(label="De-limiter Output")
+             with gr.Box():
+                 output_audio_parallel = gr.Audio(
+                     label="Parallel Mix of the Input and its De-limiter Output"
+                 )
+                 slider = gr.Slider(
+                     minimum=0,
+                     maximum=1,
+                     step=0.1,
+                     value=0.5,
+                     label="Parallel Mix Coefficient",
+                 )
+     btn.click(
+         main,
+         inputs=[input_audio, slider],
+         outputs=[output_audio, loud_norm_input, output_audio_parallel],
+     )
+     slider.release(
+         parallel_mix,
+         inputs=[input_audio, output_audio, slider],
+         outputs=output_audio_parallel,
+     )
+     with gr.Row().style(mobile_collapse=False, equal_height=True):
+         with gr.Column():
+             with gr.Box():
+                 plot = gr.Plot(label="Plots")
+                 btn2 = gr.Button("Show Plots")
+                 slider_plot_x = gr.Slider(
+                     minimum=1,
+                     maximum=100,
+                     step=1,
+                     value=20,
+                     label="Plot X-axis size",
+                 )
+                 slider_plot_y = gr.Slider(
+                     minimum=1,
+                     maximum=30,
+                     step=1,
+                     value=9,
+                     label="Plot Y-axis size",
+                 )
+     btn2.click(
+         waveform_plot,
+         inputs=[
+             loud_norm_input,
+             output_audio,
+             output_audio_parallel,
+             slider_plot_x,
+             slider_plot_y,
+         ],
+         outputs=plot,
+     )
+     slider.release(
+         waveform_plot,
+         inputs=[
+             loud_norm_input,
+             output_audio,
+             output_audio_parallel,
+             slider_plot_x,
+             slider_plot_y,
+         ],
+         outputs=plot,
+     )
+
+ if __name__ == "__main__":
+     demo.launch(debug=True)
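
`parallel_mix` above is a plain sample-wise linear blend of the loudness-normalized input and the de-limiter output. A minimal standalone sketch of the same arithmetic (hypothetical helper, not part of this commit; it assumes both signals are float arrays of the same shape and sample rate):

```python
import numpy as np

def parallel_mix_sketch(loudnorm_input: np.ndarray, delimited: np.ndarray, c: float = 0.5) -> np.ndarray:
    """Sample-wise blend: c * input + (1 - c) * de-limiter output."""
    assert loudnorm_input.shape == delimited.shape
    return c * loudnorm_input + (1.0 - c) * delimited

# c = 0.3 reproduces the example from the demo text:
# loudness_normalized_input * 0.3 + de-limiter_output * 0.7
```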
configs/delimit_6_s.yaml ADDED
@@ -0,0 +1,92 @@
+ # For De-limit task, Conv-TasNet.
+ # si_sdr loss
+ #
+ # ozone_train_fixed is about 6.36 hours
+ # 300,000 segments is about 333.33 hours
+ # ratio should be about 0.019
+
+ wandb_params:
+   use_wandb: true
+   entity: null # your wandb id
+   project: delimit # your wandb project
+   rerun_id: null # use when you rerun wandb.
+   sweep: false
+
+ sys_params:
+   nb_workers: 4
+   seed: 777
+   n_nodes: 1
+   port: null
+   rank: 0
+
+ task_params:
+   target: all # choices=["all"]
+   train: true
+   dataset: delimit # choices=["musdb", "delimit"]
+
+ dir_params:
+   root: /path/to/musdb18hq
+   output_directory: /path/to/results
+   exp_name: convtasnet_6_s # you MUST specify this
+   resume: null # "path of checkpoint folder"
+   continual_train: false # when we want to use a pre-trained model but do not want to use the lr_scheduler history.
+   delimit_valid_root: null
+   delimit_valid_L_root: null
+   ozone_root: /path/to/musdb-XL-train # you also have to specify data_params.use_fixed
+
+ hyperparams:
+   batch_size: 8 # with 1 GPU (we used a 2080 Ti, 11GB)
+   epochs: 200
+   optimizer: adamw
+   weight_decay: 0.01
+   lr: 0.00003
+   lr_decay_gamma: 0.5
+   lr_decay_patience: 15
+   patience: 50
+   lr_scheduler: step_lr
+   gradient_clip: 5.0
+   ema: false
+
+ data_params:
+   nfft: 4096
+   nhop: 1024
+   nb_channels: 2
+   sample_rate: 44100
+   seq_dur: 4.0
+   singleset_num_frames: null
+   samples_per_track: 128 # "Number of samples per track to use for training."
+   limitaug_method: ozone
+   limitaug_mode: null
+   limitaug_custom_target_lufs: null
+   limitaug_custom_target_lufs_std: null
+   target_loudnorm_lufs: -14.0
+   random_mix: true
+   target_limitaug_mode: null
+   target_limitaug_custom_target_lufs: null
+   target_limitaug_custom_target_lufs_std: null
+   custom_limiter_attack_range: null
+   custom_limiter_release_range: null
+   use_fixed: 0.019 # range 0.0 ~ 1.0 => 1.0 will use the fixed Ozoned_mixture training examples only.
+
+ model_loss_params:
+   architecture: conv_tasnet_mask_on_output # Sample-wise Gain Inversion (SGI)
+   train_loss_func: [si_sdr]
+   train_loss_scales: [1.]
+   valid_loss_func: [si_sdr]
+   valid_loss_scales: [1.]
+
+ conv_tasnet_params:
+   encoder_activation: relu
+   n_filters: 512
+   kernel_size: 128 # about 3 ms at 44100 Hz
+   stride: 64
+   n_blocks: 5
+   n_repeats: 2
+   bn_chan: 128
+   hid_chan: 512
+   skip_chan: 128
+   # conv_kernel_size:
+   # norm_type:
+   mask_act: relu
+   # causal:
+   decoder_activation: sigmoid
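
The training entry point that consumes this config is not part of this commit; a plausible loading pattern, mirroring how `add.py` turns its JSON args file into a `DotMap` (this sketch assumes PyYAML is installed and is not the repo's actual loader):

```python
import yaml
from dotmap import DotMap

with open("configs/delimit_6_s.yaml", "r") as f:
    args = DotMap(yaml.safe_load(f))

print(args.model_loss_params.architecture)  # conv_tasnet_mask_on_output
print(args.data_params.use_fixed)           # 0.019
```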
dataloader/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .dataset import aug_from_str, MusdbTrainDataset, MusdbValidDataset
+ from .singleset import SingleTrackSet
+ from .delimit_dataset import (
+     DelimitTrainDataset,
+     DelimitValidDataset,
+     OzoneTrainDataset,
+     OzoneValidDataset,
+ )
dataloader/dataset.py ADDED
@@ -0,0 +1,579 @@
+ # Dataloader based on https://github.com/jeonchangbin49/LimitAug
+ import os
+ from glob import glob
+ import random
+ from typing import Optional, Callable
+
+ import numpy as np
+ import torch
+ import librosa
+ from torch.utils.data import Dataset
+ import pyloudnorm as pyln
+ from pedalboard import Pedalboard, Limiter, Gain, Compressor, Clipping
+
+ from utils import load_wav_arbitrary_position_stereo, db2linear
+
+
+ # based on https://github.com/sigsep/open-unmix-pytorch
+ def aug_from_str(list_of_function_names: list):
+     if list_of_function_names:
+         return Compose([globals()["_augment_" + aug] for aug in list_of_function_names])
+     else:
+         return lambda audio: audio
+
+
+ class Compose(object):
+     """Composes several augmentation transforms.
+     Args:
+         augmentations: list of augmentations to compose.
+     """
+
+     def __init__(self, transforms):
+         self.transforms = transforms
+
+     def __call__(self, audio: torch.Tensor) -> torch.Tensor:
+         for t in self.transforms:
+             audio = t(audio)
+         return audio
+
+
+ # numpy based augmentation
+ # based on https://github.com/sigsep/open-unmix-pytorch
+ def _augment_gain(audio, low=0.25, high=1.25):
+     """Applies a random gain between `low` and `high`"""
+     g = low + random.random() * (high - low)
+     return audio * g
+
+
+ def _augment_channelswap(audio):
+     """Swap channels of stereo signals with a probability of p=0.5"""
+     if audio.shape[0] == 2 and random.random() < 0.5:
+         return np.flip(audio, axis=0)  # axis=0 must be given
+     else:
+         return audio
+
+
+ # Linear gain increasing implementation for Method (1)
+ def apply_linear_gain_increase(mixture, target, board, meter, samplerate, target_lufs):
+     mixture, target = mixture.T, target.T
+     loudness = meter.integrated_loudness(mixture)
+
+     if np.isinf(loudness):
+         augmented_gain = 0.0
+         board[0].gain_db = augmented_gain
+     else:
+         augmented_gain = target_lufs - loudness
+         board[0].gain_db = augmented_gain
+     mixture = board(mixture.T, samplerate)
+     target = board(target.T, samplerate)
+     return mixture, target
+
+
+ # LimitAug implementation for Method (2) and
+ # implementation of LimitAug then Loudness normalization for Method (4)
+ def apply_limitaug(
+     audio,
+     board,
+     meter,
+     samplerate,
+     target_lufs,
+     target_loudnorm_lufs=None,
+     loudness=None,
+ ):
+     audio = audio.T
+     if loudness is None:
+         loudness = meter.integrated_loudness(audio)
+
+     if np.isinf(loudness):
+         augmented_gain = 0.0
+         board[0].gain_db = augmented_gain
+     else:
+         augmented_gain = target_lufs - loudness
+         board[0].gain_db = augmented_gain
+     audio = board(audio.T, samplerate)
+
+     if target_loudnorm_lufs:
+         after_loudness = meter.integrated_loudness(audio.T)
+
+         if np.isinf(after_loudness):
+             pass
+         else:
+             target_gain = target_loudnorm_lufs - after_loudness
+             audio = audio * db2linear(target_gain)
+     return audio, loudness
+
+
+ """
+ This dataloader implementation is based on https://github.com/sigsep/open-unmix-pytorch
+ """
+
+
+ class MusdbTrainDataset(Dataset):
+     def __init__(
+         self,
+         target: str = "vocals",
+         root: str = None,
+         seq_duration: Optional[float] = 6.0,
+         samples_per_track: int = 64,
+         source_augmentations: Optional[Callable] = lambda audio: audio,
+         sample_rate: int = 44100,
+         seed: int = 42,
+         limitaug_method: str = "limitaug_then_loudnorm",
+         limitaug_mode: str = "normal_L",
+         limitaug_custom_target_lufs: float = None,
+         limitaug_custom_target_lufs_std: float = None,
+         target_loudnorm_lufs: float = -14.0,
+         custom_limiter_attack_range: list = [2.0, 2.0],
+         custom_limiter_release_range: list = [200.0, 200.0],
+         *args,
+         **kwargs,
+     ) -> None:
+         """
+         Parameters
+         ----------
+         limitaug_method : str
+             choose from ["linear_gain_increase", "limitaug", "limitaug_then_loudnorm", "only_loudnorm"]
+         limitaug_mode : str
+             choose from ["uniform", "normal", "normal_L", "normal_XL", "normal_short_term", "normal_L_short_term", "normal_XL_short_term", "custom"]
+         limitaug_custom_target_lufs : float
+             valid only when limitaug_mode == "custom"
+         limitaug_custom_target_lufs_std : float
+             also valid only when limitaug_mode == "custom"
+         target_loudnorm_lufs : float
+             valid only when limitaug_method == 'limitaug_then_loudnorm' or 'only_loudnorm'.
+             Defaults to -14.
+             To the best of my knowledge, Spotify and YouTube Music use -14 LUFS as their reference loudness normalization level.
+             No special reason for the choice of -14 as target_loudnorm_lufs.
+         target : str
+             target name of the source to be separated, defaults to ``vocals``.
+         root : str
+             root path of MUSDB
+         seq_duration : float
+             training is performed in chunks of ``seq_duration`` (in seconds);
+             defaults to ``None``, which loads the full audio track.
+         samples_per_track : int
+             sets the number of samples yielded from each track per epoch.
+             Defaults to 64.
+         source_augmentations : list[callables]
+             provide a list of augmentation functions that take a multi-channel
+             audio file of shape (src, samples) as input and output. Defaults to
+             no augmentations (input = output).
+         seed : int
+             controls randomness of dataset iterations.
+         args, kwargs : additional keyword arguments
+             used to add further control for the musdb dataset
+             initialization function.
+         """
+
+         self.seed = seed
+         random.seed(seed)
+         self.seq_duration = seq_duration
+         self.target = target
+         self.samples_per_track = samples_per_track
+         self.source_augmentations = source_augmentations
+         self.sample_rate = sample_rate
+
+         self.root = root
+         self.sources = ["vocals", "bass", "drums", "other"]
+         self.train_list = glob(f"{self.root}/train/*")
+         self.valid_list = [
+             "ANiMAL - Rockshow",
+             "Actions - One Minute Smile",
+             "Alexander Ross - Goodbye Bolero",
+             "Clara Berry And Wooldog - Waltz For My Victims",
+             "Fergessen - Nos Palpitants",
+             "James May - On The Line",
+             "Johnny Lokke - Promises & Lies",
+             "Leaf - Summerghost",
+             "Meaxic - Take A Step",
+             "Patrick Talbot - A Reason To Leave",
+             "Skelpolu - Human Mistakes",
+             "Traffic Experiment - Sirens",
+             "Triviul - Angelsaint",
+             "Young Griffo - Pennies",
+         ]
+
+         self.train_list = [
+             x for x in self.train_list if os.path.basename(x) not in self.valid_list
+         ]
+
+         # limitaug related
+         self.limitaug_method = limitaug_method
+         self.limitaug_mode = limitaug_mode
+         self.limitaug_custom_target_lufs = limitaug_custom_target_lufs
+         self.limitaug_custom_target_lufs_std = limitaug_custom_target_lufs_std
+         self.target_loudnorm_lufs = target_loudnorm_lufs
+         self.meter = pyln.Meter(self.sample_rate)
+
+         # Method (1) in our paper's Results section and Table 5
+         if self.limitaug_method == "linear_gain_increase":
+             print("using linear gain increasing!")
+             self.board = Pedalboard([Gain(gain_db=0.0)])
+
+         # Method (2) in our paper's Results section and Table 5
+         elif self.limitaug_method == "limitaug":
+             print("using limitaug!")
+             self.board = Pedalboard(
+                 [Gain(gain_db=0.0), Limiter(threshold_db=0.0, release_ms=100.0)]
+             )
+
+         # Method (3) in our paper's Results section and Table 5
+         elif self.limitaug_method == "only_loudnorm":
+             print("using only loudness normalized inputs")
+
+         # Method (4) in our paper's Results section and Table 5
+         elif self.limitaug_method == "limitaug_then_loudnorm":
+             print("using limitaug then loudness normalize!")
+             self.board = Pedalboard(
+                 [Gain(gain_db=0.0), Limiter(threshold_db=0.0, release_ms=100.0)]
+             )
+
+         elif self.limitaug_method == "custom_limiter_limitaug":
+             print("using Custom limiter limitaug!")
+             self.custom_limiter_attack_range = custom_limiter_attack_range
+             self.custom_limiter_release_range = custom_limiter_release_range
+             self.board = Pedalboard(
+                 [
+                     Gain(gain_db=0.0),
+                     Compressor(
+                         threshold_db=-10.0, ratio=4.0, attack_ms=2.0, release_ms=200.0
+                     ),  # attack_ms and release_ms will be changed later.
+                     Compressor(
+                         threshold_db=0.0,
+                         ratio=1000.0,
+                         attack_ms=0.001,
+                         release_ms=100.0,
+                     ),
+                     Gain(gain_db=3.75),
+                     Clipping(threshold_db=0.0),
+                 ]
+             )  # This implementation is the same as the JUCE Limiter.
+             # However, we want the first compressor to have a variable attack and release time.
+             # Therefore, we use the Custom Limiter instead of the JUCE Limiter.
+
+         self.limitaug_mode_statistics = {
+             "normal": [
+                 -15.954,
+                 1.264,
+             ],  # -15.954 is the mean LUFS of musdb-hq and 1.264 is the standard deviation
+             "normal_L": [
+                 -10.887,
+                 1.191,
+             ],  # -10.887 is the mean LUFS of musdb-L and 1.191 is the standard deviation
+             "normal_XL": [
+                 -8.608,
+                 1.165,
+             ],  # -8.608 is the mean LUFS of musdb-XL and 1.165 is the standard deviation
+             "normal_short_term": [
+                 -17.317,
+                 5.036,
+             ],  # In our experiments, short-term statistics were not helpful.
+             "normal_L_short_term": [-12.303, 5.233],
+             "normal_XL_short_term": [-9.988, 5.518],
+             "custom": [limitaug_custom_target_lufs, limitaug_custom_target_lufs_std],
+         }
+
+     def sample_target_lufs(self):
+         if (
+             self.limitaug_mode == "uniform"
+         ):  # if limitaug_mode is uniform, then choose target_lufs from a uniform distribution
+             target_lufs = random.uniform(-20, -5)
+         else:  # else, choose target_lufs from a Gaussian distribution
+             target_lufs = random.gauss(
+                 self.limitaug_mode_statistics[self.limitaug_mode][0],
+                 self.limitaug_mode_statistics[self.limitaug_mode][1],
+             )
+
+         return target_lufs
+
+     def get_limitaug_results(self, mixture, target):
+         # Apply linear gain increasing (Method (1))
+         if self.limitaug_method == "linear_gain_increase":
+             target_lufs = self.sample_target_lufs()
+             mixture, target = apply_linear_gain_increase(
+                 mixture,
+                 target,
+                 self.board,
+                 self.meter,
+                 self.sample_rate,
+                 target_lufs=target_lufs,
+             )
+
+         # Apply LimitAug (Method (2))
+         elif self.limitaug_method == "limitaug":
+             self.board[1].release_ms = random.uniform(30.0, 200.0)
+             mixture_orig = mixture.copy()
+             target_lufs = self.sample_target_lufs()
+             mixture, _ = apply_limitaug(
+                 mixture,
+                 self.board,
+                 self.meter,
+                 self.sample_rate,
+                 target_lufs=target_lufs,
+             )
+             print("mixture shape:", mixture.shape)
+             print("target shape:", target.shape)
+             target *= mixture / (mixture_orig + 1e-8)
+
+         # Apply only loudness normalization (Method (3))
+         elif self.limitaug_method == "only_loudnorm":
+             mixture_loudness = self.meter.integrated_loudness(mixture.T)
+             if np.isinf(
+                 mixture_loudness
+             ):  # if the source is silence, then mixture_loudness is -inf.
+                 pass
+             else:
+                 augmented_gain = (
+                     self.target_loudnorm_lufs - mixture_loudness
+                 )  # default target_loudnorm_lufs is -14.
+                 mixture = mixture * db2linear(augmented_gain)
+                 target = target * db2linear(augmented_gain)
+
+         # Apply LimitAug then loudness normalization (Method (4))
+         elif self.limitaug_method == "limitaug_then_loudnorm":
+             self.board[1].release_ms = random.uniform(30.0, 200.0)
+             mixture_orig = mixture.copy()
+             target_lufs = self.sample_target_lufs()
+             mixture, _ = apply_limitaug(
+                 mixture,
+                 self.board,
+                 self.meter,
+                 self.sample_rate,
+                 target_lufs=target_lufs,
+                 target_loudnorm_lufs=self.target_loudnorm_lufs,
+             )
+             target *= mixture / (mixture_orig + 1e-8)
+
+         # Apply LimitAug using the Custom Limiter
+         elif self.limitaug_method == "custom_limiter_limitaug":
+             # Change the attack time of the first compressor of the Limiter
+             self.board[1].attack_ms = random.uniform(
+                 self.custom_limiter_attack_range[0], self.custom_limiter_attack_range[1]
+             )
+             # Change the release time of the first compressor of the Limiter
+             self.board[1].release_ms = random.uniform(
+                 self.custom_limiter_release_range[0],
+                 self.custom_limiter_release_range[1],
+             )
+             # Change the release time of the second compressor of the Limiter
+             self.board[2].release_ms = random.uniform(30.0, 200.0)
+             mixture_orig = mixture.copy()
+             target_lufs = self.sample_target_lufs()
+             mixture, _ = apply_limitaug(
+                 mixture,
+                 self.board,
+                 self.meter,
+                 self.sample_rate,
+                 target_lufs=target_lufs,
+                 target_loudnorm_lufs=self.target_loudnorm_lufs,
+             )
+             target *= mixture / (mixture_orig + 1e-8)
+
+         return mixture, target
+
+     def __getitem__(self, index):
+         audio_sources = []
+         target_ind = None
+
+         for k, source in enumerate(self.sources):
+             # memorize index of target source
+             if source == self.target:  # if source is 'vocals'
+                 target_ind = k
+                 track_path = self.train_list[
+                     index // self.samples_per_track
+                 ]  # we want to use `samples_per_track` training samples per track.
+                 audio_path = f"{track_path}/{source}.wav"
+                 audio = load_wav_arbitrary_position_stereo(
+                     audio_path, self.sample_rate, self.seq_duration
+                 )
+             else:
+                 track_path = random.choice(self.train_list)
+                 audio_path = f"{track_path}/{source}.wav"
+                 audio = load_wav_arbitrary_position_stereo(
+                     audio_path, self.sample_rate, self.seq_duration
+                 )
+             audio = self.source_augmentations(audio)
+             audio_sources.append(audio)
+
+         stems = np.stack(audio_sources, axis=0)
+
+         # apply linear mix over source index=0
+         x = stems.sum(0)
+         # get the target stem
+         y = stems[target_ind]
+
+         # Apply the limitaug
+         x, y = self.get_limitaug_results(x, y)
+
+         x = torch.as_tensor(x, dtype=torch.float32)
+         y = torch.as_tensor(y, dtype=torch.float32)
+
+         return x, y
+
+     def __len__(self):
+         return len(self.train_list) * self.samples_per_track
+
+
+ class MusdbValidDataset(Dataset):
+     def __init__(
+         self,
+         target: str = "vocals",
+         root: str = None,
+         *args,
+         **kwargs,
+     ) -> None:
+         """MUSDB18 torch.data.Dataset that samples from the MUSDB tracks
+         using tracks and excerpts with replacement.
+         Parameters
+         ----------
+         target : str
+             target name of the source to be separated, defaults to ``vocals``.
+         root : str
+             root path of the MUSDB18HQ dataset, defaults to ``None``.
+         args, kwargs : additional keyword arguments
+             used to add further control for the musdb dataset
+             initialization function.
+         """
+         self.target = target
+         self.sample_rate = 44100.0  # musdb has a fixed sample rate
+
+         self.root = root
+         self.sources = ["vocals", "bass", "drums", "other"]
+         self.train_list = glob(f"{self.root}/train/*")
+
+         self.valid_list = [
+             "ANiMAL - Rockshow",
+             "Actions - One Minute Smile",
+             "Alexander Ross - Goodbye Bolero",
+             "Clara Berry And Wooldog - Waltz For My Victims",
+             "Fergessen - Nos Palpitants",
+             "James May - On The Line",
+             "Johnny Lokke - Promises & Lies",
+             "Leaf - Summerghost",
+             "Meaxic - Take A Step",
+             "Patrick Talbot - A Reason To Leave",
+             "Skelpolu - Human Mistakes",
+             "Traffic Experiment - Sirens",
+             "Triviul - Angelsaint",
+             "Young Griffo - Pennies",
+         ]
+         self.valid_list = [
+             x for x in self.train_list if os.path.basename(x) in self.valid_list
+         ]
+
+     def __getitem__(self, index):
+         audio_sources = []
+         target_ind = None
+
+         for k, source in enumerate(self.sources):
+             # memorize index of target source
+             if source == self.target:  # if source is 'vocals'
+                 target_ind = k
+                 track_path = self.valid_list[index]
+                 song_name = os.path.basename(track_path)
+                 audio_path = f"{track_path}/{source}.wav"
+                 # audio = utils.load_wav_stereo(audio_path, self.sample_rate)
+                 audio = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
+             else:
+                 track_path = self.valid_list[index]
+                 song_name = os.path.basename(track_path)
+                 audio_path = f"{track_path}/{source}.wav"
+                 # audio = utils.load_wav_stereo(audio_path, self.sample_rate)
+                 audio = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
+
+             audio = torch.as_tensor(audio, dtype=torch.float32)
+             audio_sources.append(audio)
+
+         stems = torch.stack(audio_sources, dim=0)
+         # apply linear mix over source index=0
+         x = stems.sum(0)
+         # get the target stem
+         y = stems[target_ind]
+
+         return x, y, song_name
+
+     def __len__(self):
+         return len(self.valid_list)
+
+
+ # If you want to check the LUFS values of training examples, run this.
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="Check the LUFS values of LimitAug training examples"
+     )
+
+     parser.add_argument(
+         "--musdb_root",
+         type=str,
+         default="/path/to/musdb",
+         help="root path of musdb-hq dataset",
+     )
+     parser.add_argument(
+         "--limitaug_method",
+         type=str,
+         default="limitaug",
+         choices=[
+             "linear_gain_increase",
+             "limitaug",
+             "limitaug_then_loudnorm",
+             "only_loudnorm",
+             None,
+         ],
+         help="choose limitaug method",
+     )
+     parser.add_argument(
+         "--limitaug_mode",
+         type=str,
+         default="normal_L",
+         choices=[
+             "uniform",
+             "normal",
+             "normal_L",
+             "normal_XL",
+             "normal_short_term",
+             "normal_L_short_term",
+             "normal_XL_short_term",
+             "custom",
+         ],
+         help="if you use LimitAug, what lufs distribution to target",
+     )
+     parser.add_argument(
+         "--limitaug_custom_target_lufs",
+         type=float,
+         default=None,
+         help="if limitaug_mode is custom, set custom target lufs for LimitAug",
+     )
+
+     args, _ = parser.parse_known_args()
+
+     source_augmentations_ = aug_from_str(["gain", "channelswap"])
+
+     train_dataset = MusdbTrainDataset(
+         target="vocals",
+         root=args.musdb_root,
+         seq_duration=6.0,
+         source_augmentations=source_augmentations_,
+         limitaug_method=args.limitaug_method,
+         limitaug_mode=args.limitaug_mode,
+         limitaug_custom_target_lufs=args.limitaug_custom_target_lufs,
+     )
+
+     dataloader = torch.utils.data.DataLoader(
+         train_dataset,
+         batch_size=1,
+         shuffle=True,
+         num_workers=4,
+         pin_memory=True,
+         drop_last=False,
+     )
+
+     meter = pyln.Meter(44100)
+     for i in range(5):
+         for x, y in dataloader:
+             loudness = meter.integrated_loudness(x[0].numpy().T)
+             print(f"mixture loudness : {loudness} LUFS")
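
`db2linear` and the wav loaders are imported from a `utils` module that is not included in this commit. Given how it is used here (a gain in dB multiplied onto a waveform, with an optional `eps`), it presumably implements the standard dB-to-amplitude conversion; a sketch under that assumption:

```python
def db2linear(gain_db: float, eps: float = 0.0) -> float:
    # Assumed behaviour: 10 ** (dB / 20), with eps as an additive guard term.
    return 10.0 ** (gain_db / 20.0) + eps

# Worked example: loudness-normalizing a -20.3 LUFS mixture to -14.0 LUFS
# requires a gain of -14.0 - (-20.3) = 6.3 dB, i.e. a factor of about 2.07.
print(db2linear(6.3))  # ~2.065
```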
dataloader/delimit_dataset.py ADDED
@@ -0,0 +1,573 @@
1
+ import os
2
+ import random
3
+ from typing import Optional, Callable
4
+ import json
5
+ import glob
6
+ import csv
7
+
8
+ import numpy as np
9
+ import torch
10
+ import librosa
11
+ import pyloudnorm as pyln
12
+ from pedalboard import Pedalboard, Limiter, Gain, Compressor, Clipping
13
+
14
+ from .dataset import (
15
+ MusdbTrainDataset,
16
+ MusdbValidDataset,
17
+ apply_limitaug,
18
+ )
19
+ from utils import (
20
+ load_wav_arbitrary_position_stereo,
21
+ load_wav_specific_position_stereo,
22
+ db2linear,
23
+ )
24
+
25
+
26
+ class DelimitTrainDataset(MusdbTrainDataset):
27
+ def __init__(
28
+ self,
29
+ target: str = "all",
30
+ root: str = None,
31
+ seq_duration: Optional[float] = 6.0,
32
+ samples_per_track: int = 64,
33
+ source_augmentations: Optional[Callable] = lambda audio: audio,
34
+ sample_rate: int = 44100,
35
+ seed: int = 42,
36
+ limitaug_method: str = "limitaug",
37
+ limitaug_mode: str = "normal_L",
38
+ limitaug_custom_target_lufs: float = None,
39
+ limitaug_custom_target_lufs_std: float = None,
40
+ target_loudnorm_lufs: float = -14.0,
41
+ target_limitaug_mode: str = None,
42
+ target_limitaug_custom_target_lufs: float = None,
43
+ target_limitaug_custom_target_lufs_std: float = None,
44
+ custom_limiter_attack_range: list = [2.0, 2.0],
45
+ custom_limiter_release_range: list = [200.0, 200.0],
46
+ *args,
47
+ **kwargs,
48
+ ) -> None:
49
+ super().__init__(
50
+ target=target,
51
+ root=root,
52
+ seq_duration=seq_duration,
53
+ samples_per_track=samples_per_track,
54
+ source_augmentations=source_augmentations,
55
+ sample_rate=sample_rate,
56
+ seed=seed,
57
+ limitaug_method=limitaug_method,
58
+ limitaug_mode=limitaug_mode,
59
+ limitaug_custom_target_lufs=limitaug_custom_target_lufs,
60
+ limitaug_custom_target_lufs_std=limitaug_custom_target_lufs_std,
61
+ target_loudnorm_lufs=target_loudnorm_lufs,
62
+ custom_limiter_attack_range=custom_limiter_attack_range,
63
+ custom_limiter_release_range=custom_limiter_release_range,
64
+ *args,
65
+ **kwargs,
66
+ )
67
+
68
+ self.target_limitaug_mode = target_limitaug_mode
69
+
70
+ self.target_limitaug_custom_target_lufs = (target_limitaug_custom_target_lufs,)
71
+ self.target_limitaug_custom_target_lufs_std = (
72
+ target_limitaug_custom_target_lufs_std,
73
+ )
74
+ self.limitaug_mode_statistics["target_custom"] = [
75
+ target_limitaug_custom_target_lufs,
76
+ target_limitaug_custom_target_lufs_std,
77
+ ]
78
+
79
+ """
80
+ Parameters
81
+ ----------
82
+ limitaug_method : str
83
+ choose from ["linear_gain_increase", "limitaug", "limitaug_then_loudnorm", "only_loudnorm"]
84
+ limitaug_mode : str
85
+ choose from ["uniform", "normal", "normal_L", "normal_XL", "normal_short_term", "normal_L_short_term", "normal_XL_short_term", "custom"]
86
+ limitaug_custom_target_lufs : float
87
+ valid only when
88
+ limitaug_mode == "custom"
89
+ target_loudnorm_lufs : float
90
+ valid only when
91
+ limitaug_method == 'limitaug_then_loudnorm' or 'only_loudnorm'
92
+ default is -14.
93
+ To the best of my knowledge, Spotify and Youtube music is using -14 as a reference loudness normalization level.
94
+ No special reason for the choice of -14 as target_loudnorm_lufs.
95
+ target : str
96
+ target name of the source to be separated, defaults to ``vocals``.
97
+ root : str
98
+ root path of MUSDB
99
+ seq_duration : float
100
+ training is performed in chunks of ``seq_duration`` (in seconds,
101
+ defaults to ``None`` which loads the full audio track
102
+ samples_per_track : int
103
+ sets the number of samples, yielded from each track per epoch.
104
+ Defaults to 64
105
+ source_augmentations : list[callables]
106
+ provide list of augmentation function that take a multi-channel
107
+ audio file of shape (src, samples) as input and output. Defaults to
108
+ no-augmentations (input = output)
109
+ seed : int
110
+ control randomness of dataset iterations
111
+ args, kwargs : additional keyword arguments
112
+ used to add further control for the musdb dataset
113
+ initialization function.
114
+ """
115
+
116
+ # Get a limitaug result without target (individual stem source)
117
+ def get_limitaug_mixture(self, mixture):
118
+ if self.limitaug_method == "limitaug":
119
+ self.board[1].release_ms = random.uniform(30.0, 200.0)
120
+ target_lufs = self.sample_target_lufs()
121
+ mixture_limited, mixture_lufs = apply_limitaug(
122
+ mixture,
123
+ self.board,
124
+ self.meter,
125
+ self.sample_rate,
126
+ target_lufs=target_lufs,
127
+ )
128
+
129
+ elif self.limitaug_method == "limitaug_then_loudnorm":
130
+ self.board[1].release_ms = random.uniform(30.0, 200.0)
131
+ target_lufs = self.sample_target_lufs()
132
+ mixture_limited, mixture_lufs = (
133
+ apply_limitaug(
134
+ mixture,
135
+ self.board,
136
+ self.meter,
137
+ self.sample_rate,
138
+ target_lufs=target_lufs,
139
+ target_loudnorm_lufs=self.target_loudnorm_lufs,
140
+ ),
141
+ )
142
+
143
+ # Apply LimitAug using Custom Limiter
144
+ elif self.limitaug_method == "custom_limiter_limitaug":
145
+ # Change attack time of First compressor of the Limiter
146
+ self.board[1].attack_ms = random.uniform(
147
+ self.custom_limiter_attack_range[0], self.custom_limiter_attack_range[1]
148
+ )
149
+ # Change release time of First compressor of the Limiter
150
+ self.board[1].release_ms = random.uniform(
151
+ self.custom_limiter_release_range[0],
152
+ self.custom_limiter_release_range[1],
153
+ )
154
+ # Change release time of Second compressor of the Limiter
155
+ self.board[2].release_ms = random.uniform(30.0, 200.0)
156
+ target_lufs = self.sample_target_lufs()
157
+ mixture_limited, mixture_lufs = apply_limitaug(
158
+ mixture,
159
+ self.board,
160
+ self.meter,
161
+ self.sample_rate,
162
+ target_lufs=target_lufs,
163
+ target_loudnorm_lufs=self.target_loudnorm_lufs,
164
+ )
165
+
166
+ # When we want to force NN to output an appropriately compressed target output
167
+ if self.target_limitaug_mode:
168
+ mixture_target_lufs = random.gauss(
169
+ self.limitaug_mode_statistics[self.target_limitaug_mode][0],
170
+ self.limitaug_mode_statistics[self.target_limitaug_mode][1],
171
+ )
172
+ mixture, target_lufs = apply_limitaug(
173
+ mixture,
174
+ self.board,
175
+ self.meter,
176
+ self.sample_rate,
177
+ target_lufs=mixture_target_lufs,
178
+ loudness=mixture_lufs,
179
+ )
180
+
181
+ if np.isinf(mixture_lufs):
182
+ mixture_loudnorm = mixture
183
+ else:
184
+ augmented_gain = self.target_loudnorm_lufs - mixture_lufs
185
+ mixture_loudnorm = mixture * db2linear(augmented_gain, eps=0.0)
186
+
187
+ return mixture_limited, mixture_loudnorm
188
+
189
+ def __getitem__(self, index):
190
+ audio_sources = []
191
+
192
+ for k, source in enumerate(self.sources):
193
+ # memorize index of target source
194
+ if source == self.target: # if source is 'vocals'
195
+ track_path = self.train_list[
196
+ index // self.samples_per_track
197
+ ] # we want to use # training samples per each track.
198
+ audio_path = f"{track_path}/{source}.wav"
199
+ audio = load_wav_arbitrary_position_stereo(
200
+ audio_path, self.sample_rate, self.seq_duration
201
+ )
202
+ else:
203
+ track_path = random.choice(self.train_list)
204
+ audio_path = f"{track_path}/{source}.wav"
205
+ audio = load_wav_arbitrary_position_stereo(
206
+ audio_path, self.sample_rate, self.seq_duration
207
+ )
208
+ audio = self.source_augmentations(audio)
209
+ audio_sources.append(audio)
210
+
211
+ stems = np.stack(audio_sources, axis=0)
212
+
213
+ # apply linear mix over source index=0
214
+ # and here, linear mixture is a target unlike in MusdbTrainDataset
215
+ mixture = stems.sum(0)
216
+ mixture_limited, mixture_loudnorm = self.get_limitaug_mixture(mixture)
217
+ # We will give mixture_limited as an input and mixture_loudnorm as a target to the model.
218
+
219
+ mixture_limited = np.clip(mixture_limited, -1.0, 1.0)
220
+ mixture_limited = torch.as_tensor(mixture_limited, dtype=torch.float32)
221
+ mixture_loudnorm = torch.as_tensor(mixture_loudnorm, dtype=torch.float32)
222
+
223
+ return mixture_limited, mixture_loudnorm
224
+
225
+
226
+ class OzoneTrainDataset(DelimitTrainDataset):
227
+ def __init__(
228
+ self,
229
+ target: str = "all",
230
+ root: str = None,
231
+ ozone_root: str = None,
232
+ use_fixed: float = 0.1, # ratio of fixed samples
233
+ seq_duration: Optional[float] = 6.0,
234
+ samples_per_track: int = 64,
235
+ source_augmentations: Optional[Callable] = lambda audio: audio,
236
+ sample_rate: int = 44100,
237
+ seed: int = 42,
238
+ limitaug_method: str = "limitaug",
239
+ limitaug_mode: str = "normal_L",
240
+ limitaug_custom_target_lufs: float = None,
241
+ limitaug_custom_target_lufs_std: float = None,
242
+ target_loudnorm_lufs: float = -14.0,
243
+ target_limitaug_mode: str = None,
244
+ target_limitaug_custom_target_lufs: float = None,
245
+ target_limitaug_custom_target_lufs_std: float = None,
246
+ custom_limiter_attack_range: list = [2.0, 2.0],
247
+ custom_limiter_release_range: list = [200.0, 200.0],
248
+ *args,
249
+ **kwargs,
250
+ ) -> None:
251
+ super().__init__(
252
+ target,
253
+ root,
254
+ seq_duration,
255
+ samples_per_track,
256
+ source_augmentations,
257
+ sample_rate,
258
+ seed,
259
+ limitaug_method,
260
+ limitaug_mode,
261
+ limitaug_custom_target_lufs,
262
+ limitaug_custom_target_lufs_std,
263
+ target_loudnorm_lufs,
264
+ target_limitaug_mode,
265
+ target_limitaug_custom_target_lufs,
266
+ target_limitaug_custom_target_lufs_std,
267
+ custom_limiter_attack_range,
268
+ custom_limiter_release_range,
269
+ *args,
270
+ **kwargs,
271
+ )
272
+
273
+ self.ozone_root = ozone_root
274
+ self.use_fixed = use_fixed
275
+ self.list_train_fixed = glob.glob(f"{self.ozone_root}/ozone_train_fixed/*.wav")
276
+ self.list_train_random = glob.glob(
277
+ f"{self.ozone_root}/ozone_train_random/*.wav"
278
+ )
279
+ self.dict_train_random = {}
280
+
281
+ # Load information of pre-generated random training examples
282
+ list_csv_files = glob.glob(f"{self.ozone_root}/ozone_train_random_*.csv")
283
+ list_csv_files.sort()
284
+ for csv_file in list_csv_files:
285
+ with open(csv_file, "r") as f:
286
+ reader = csv.reader(f)
287
+ next(reader)
288
+ for row in reader:
289
+ self.dict_train_random[row[0]] = {
290
+ "max_threshold": float(row[1]),
291
+ "max_character": float(row[2]),
292
+ "vocals": {
293
+ "name": row[3],
294
+ "start_sec": float(row[4]),
295
+ "gain": float(row[5]),
296
+ "channelswap": bool(row[6]),
297
+ },
298
+ "bass": {
299
+ "name": row[7],
300
+ "start_sec": float(row[8]),
301
+ "gain": float(row[9]),
302
+ "channelswap": bool(row[10]),
303
+ },
304
+ "drums": {
305
+ "name": row[11],
306
+ "start_sec": float(row[12]),
307
+ "gain": float(row[13]),
308
+ "channelswap": bool(row[14]),
309
+ },
310
+ "other": {
311
+ "name": row[15],
312
+ "start_sec": float(row[16]),
313
+ "gain": float(row[17]),
314
+ "channelswap": bool(row[18]),
315
+ },
316
+ }
317
+
318
+ def __getitem__(self, idx):
319
+ use_fixed_prob = random.random()
320
+
321
+ if use_fixed_prob <= self.use_fixed:
322
+ # Fixed examples
323
+ audio_path = random.choice(self.list_train_fixed)
324
+ song_name = os.path.basename(audio_path).replace(".wav", "")
325
+ mixture_limited, start_pos_sec = load_wav_arbitrary_position_stereo(
326
+ audio_path, self.sample_rate, self.seq_duration, return_pos=True
327
+ )
328
+
329
+ audio_sources = []
330
+ track_path = f"{self.root}/train/{song_name}"
331
+ for source in self.sources:
332
+ audio_path = f"{track_path}/{source}.wav"
333
+ audio = load_wav_specific_position_stereo(
334
+ audio_path,
335
+ self.sample_rate,
336
+ self.seq_duration,
337
+ start_position=start_pos_sec,
338
+ )
339
+ audio_sources.append(audio)
340
+
341
+ else:
342
+ # Random examples
343
+ # Load mixture_limited (pre-generated)
344
+ audio_path = random.choice(self.list_train_random)
345
+ seg_name = os.path.basename(audio_path).replace(".wav", "")
346
+ mixture_limited, sr = librosa.load(
347
+ audio_path, sr=self.sample_rate, mono=False
348
+ )
349
+
350
+ # Load mixture_unlimited (from the original musdb18, using metadata)
351
+ audio_sources = []
352
+ for source in self.sources:
353
+ dict_seg_info = self.dict_train_random[seg_name]
354
+ dict_seg_source_info = dict_seg_info[source]
355
+ audio_path = (
356
+ f"{self.root}/train/{dict_seg_source_info['name']}/{source}.wav"
357
+ )
358
+ audio = load_wav_specific_position_stereo(
359
+ audio_path,
360
+ self.sample_rate,
361
+ self.seq_duration,
362
+ start_position=dict_seg_source_info["start_sec"],
363
+ )
364
+
365
+ # apply augmentations
366
+ audio = audio * dict_seg_source_info["gain"]
367
+ if dict_seg_source_info["channelswap"]:
368
+ audio = np.flip(audio, axis=0)
369
+
370
+ audio_sources.append(audio)
371
+
372
+ stems = np.stack(audio_sources, axis=0)
373
+ mixture = stems.sum(axis=0)
374
+ mixture_lufs = self.meter.integrated_loudness(mixture.T)
375
+ if np.isinf(mixture_lufs):
376
+ mixture_loudnorm = mixture
377
+ else:
378
+ augmented_gain = self.target_loudnorm_lufs - mixture_lufs
379
+ mixture_loudnorm = mixture * db2linear(augmented_gain, eps=0.0)
380
+
381
+ return mixture_limited, mixture_loudnorm
382
+
383
+
384
+ class DelimitValidDataset(MusdbValidDataset):
385
+ def __init__(
386
+ self,
387
+ target: str = "vocals",
388
+ root: str = None,
389
+ delimit_valid_root: str = None,
390
+ valid_target_lufs: float = -8.05, # From the Table 1 of the "Towards robust music source separation on loud commercial music" paper, the average loudness of commerical music.
391
+ target_loudnorm_lufs: float = -14.0,
392
+ delimit_valid_L_root: str = None, # This will be used when using the target as compressed (normal_L) mixture.
393
+ use_custom_limiter: bool = False,
394
+ custom_limiter_attack_range: list = [0.1, 10.0],
395
+ custom_limiter_release_range: list = [30.0, 200.0],
396
+ *args,
397
+ **kwargs,
398
+ ) -> None:
399
+ super().__init__(target=target, root=root, *args, **kwargs)
400
+ self.delimit_valid_root = delimit_valid_root
401
+ if self.delimit_valid_root:
402
+ with open(f"{self.delimit_valid_root}/valid_loudness.json", "r") as f:
403
+ self.dict_valid_loudness = json.load(f)
404
+ self.delimit_valid_L_root = delimit_valid_L_root
405
+ if self.delimit_valid_L_root:
406
+ with open(f"{self.delimit_valid_L_root}/valid_loudness.json", "r") as f:
407
+ self.dict_valid_L_loudness = json.load(f)
408
+
409
+ self.valid_target_lufs = valid_target_lufs
410
+ self.target_loudnorm_lufs = target_loudnorm_lufs
411
+ self.meter = pyln.Meter(self.sample_rate)
412
+ self.use_custom_limiter = use_custom_limiter
413
+
414
+ if self.use_custom_limiter:
415
+ print("using Custom limiter limitaug for validation!!")
416
+ self.custom_limiter_attack_range = custom_limiter_attack_range
417
+ self.custom_limiter_release_range = custom_limiter_release_range
418
+ self.board = Pedalboard(
419
+ [
420
+ Gain(gain_db=0.0),
421
+ Compressor(
422
+ threshold_db=-10.0, ratio=4.0, attack_ms=2.0, release_ms=200.0
423
+ ), # attack_ms and release_ms will be changed later.
424
+ Compressor(
425
+ threshold_db=0.0,
426
+ ratio=1000.0,
427
+ attack_ms=0.001,
428
+ release_ms=100.0,
429
+ ),
430
+ Gain(gain_db=3.75),
431
+ Clipping(threshold_db=0.0),
432
+ ]
433
+ ) # This implementation is the same as JUCE Limiter.
434
+ # However, we want the first compressor to have a variable attack and release time.
435
+ # Therefore, we use the Custom Limiter instead of the JUCE Limiter.
436
+ else:
437
+ self.board = Pedalboard(
438
+ [Gain(gain_db=0.0), Limiter(threshold_db=0.0, release_ms=100.0)]
439
+ ) # Currently, we are using a limiter with a release time of 100ms.
440
+
441
+ def __getitem__(self, index):
442
+ audio_sources = []
443
+ target_ind = None
444
+
445
+ for k, source in enumerate(self.sources):
446
+ # memorize index of target source
447
+ if source == self.target: # if source is 'vocals'
448
+ target_ind = k
449
+ track_path = self.valid_list[index]
450
+ song_name = os.path.basename(track_path)
451
+ audio_path = f"{track_path}/{source}.wav"
452
+ # audio = utils.load_wav_stereo(audio_path, self.sample_rate)
453
+ audio = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
454
+ else:
455
+ track_path = self.valid_list[index]
456
+ song_name = os.path.basename(track_path)
457
+ audio_path = f"{track_path}/{source}.wav"
458
+ # audio = utils.load_wav_stereo(audio_path, self.sample_rate)
459
+ audio = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
460
+
461
+ audio = torch.as_tensor(audio, dtype=torch.float32)
462
+ audio_sources.append(audio)
463
+
464
+ stems = np.stack(audio_sources, axis=0)
465
+
466
+ # apply linear mix over source index=0
467
+ # and here, linear mixture is a target unlike in MusdbTrainDataset
468
+ mixture = stems.sum(0)
469
+ if (
470
+ self.delimit_valid_root
471
+ ): # If there exists a pre-processed delimit valid dataset
472
+ audio_path = f"{self.delimit_valid_root}/valid/{song_name}.wav"
473
+ mixture_limited = librosa.load(audio_path, mono=False, sr=self.sample_rate)[
474
+ 0
475
+ ]
476
+ mixture_lufs = self.dict_valid_loudness[song_name]
477
+
478
+ else:
479
+ if self.use_custom_limiter:
480
+ custom_limiter_attack = random.uniform(
481
+ self.custom_limiter_attack_range[0],
482
+ self.custom_limiter_attack_range[1],
483
+ )
484
+ self.board[1].attack_ms = custom_limiter_attack
485
+
486
+ custom_limiter_release = random.uniform(
487
+ self.custom_limiter_release_range[0],
488
+ self.custom_limiter_release_range[1],
489
+ )
490
+ self.board[1].release_ms = custom_limiter_release
491
+
492
+ mixture_limited, mixture_lufs = apply_limitaug(
493
+ mixture,
494
+ self.board,
495
+ self.meter,
496
+ self.sample_rate,
497
+ target_lufs=self.valid_target_lufs,
498
+ )
499
+ else:
500
+ mixture_limited, mixture_lufs = apply_limitaug(
501
+ mixture,
502
+ self.board,
503
+ self.meter,
504
+ self.sample_rate,
505
+ target_lufs=self.valid_target_lufs,
506
+ # target_loudnorm_lufs=self.target_loudnorm_lufs,
507
+ ) # mixture_limited is a limiter applied mixture
508
+ # We will give mixture_limited as an input and mixture_loudnorm as a target to the model.
509
+
510
+ if self.delimit_valid_L_root:
511
+ audio_L_path = f"{self.delimit_valid_L_root}/valid/{song_name}.wav"
512
+ mixture_loudnorm = librosa.load(
513
+ audio_L_path, mono=False, sr=self.sample_rate
514
+ )[0]
515
+ mixture_lufs = self.dict_valid_L_loudness[song_name]
516
+ mixture = mixture_loudnorm
517
+
518
+ augmented_gain = self.target_loudnorm_lufs - mixture_lufs
519
+ mixture_loudnorm = mixture * db2linear(augmented_gain)
520
+
521
+ if self.use_custom_limiter:
522
+ return (
523
+ mixture_limited,
524
+ mixture_loudnorm,
525
+ song_name,
526
+ mixture_lufs,
527
+ custom_limiter_attack,
528
+ custom_limiter_release,
529
+ )
530
+ else:
531
+ return mixture_limited, mixture_loudnorm, song_name, mixture_lufs
532
+
533
+
534
+ class OzoneValidDataset(MusdbValidDataset):
535
+ def __init__(
536
+ self,
537
+ target: str = "all",
538
+ root: str = None,
539
+ ozone_root: str = None,
540
+ target_loudnorm_lufs: float = -14.0,
541
+ *args,
542
+ **kwargs,
543
+ ) -> None:
544
+ super().__init__(target=target, root=root, *args, **kwargs)
545
+
546
+ self.ozone_root = ozone_root
547
+ self.target_loudnorm_lufs = target_loudnorm_lufs
548
+
549
+ with open(f"{self.ozone_root}/valid_loudness.json", "r") as f:
550
+ self.dict_valid_loudness = json.load(f)
551
+
552
+ def __getitem__(self, index):
553
+ audio_sources = []
554
+
555
+ track_path = self.valid_list[index]
556
+ song_name = os.path.basename(track_path)
557
+ for k, source in enumerate(self.sources):
558
+ audio_path = f"{track_path}/{source}.wav"
559
+ audio = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
560
+ audio_sources.append(audio)
561
+
562
+ stems = np.stack(audio_sources, axis=0)
563
+
564
+ mixture = stems.sum(0)
565
+
566
+ audio_path = f"{self.ozone_root}/ozone_train_fixed/{song_name}.wav"
567
+ mixture_limited = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
568
+
569
+ mixture_lufs = self.dict_valid_loudness[song_name]
570
+ augmented_gain = self.target_loudnorm_lufs - mixture_lufs
571
+ mixture_loudnorm = mixture * db2linear(augmented_gain)
572
+
573
+ return mixture_limited, mixture_loudnorm, song_name, mixture_lufs
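
For reference, the loudness normalization in both validation datasets reduces to a dB-difference gain. A minimal sketch, assuming `db2linear` in utils.py (not shown here) follows the usual 10 ** (dB / 20) convention:

import numpy as np

def db2linear(x_db, eps=0.0):
    # 10 ** (dB / 20) maps a decibel gain to a linear amplitude factor
    return 10 ** (x_db / 20) + eps

measured_lufs = -9.5              # hypothetical integrated loudness of a mixture
target_loudnorm_lufs = -14.0      # the target used by the datasets above
augmented_gain = target_loudnorm_lufs - measured_lufs   # -4.5 dB
mixture_loudnorm = np.ones(4) * db2linear(augmented_gain)  # scales samples by ~0.596
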
dataloader/singleset.py ADDED
@@ -0,0 +1,95 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ import torch.nn.functional as F
6
+
7
+ # Modified version from woosungchoi's original implementation
8
+ class SingleTrackSet(Dataset):
9
+ def __init__(self, track, hop_length, num_frame=128, target_name="vocals"):
10
+
11
+ assert len(track.shape) == 2
12
+ assert track.shape[0] == 2 # check stereo audio
13
+
14
+ self.hop_length = hop_length
15
+ self.window_length = hop_length * (num_frame - 1) # 130048
16
+ self.trim_length = self.get_trim_length(self.hop_length) # 5120
17
+
18
+ self.true_samples = self.window_length - 2 * self.trim_length # 119808
19
+
20
+ self.lengths = [track.shape[1]] # track lengths (in sample level)
21
+ self.source_names = [
22
+ "vocals",
23
+ "drums",
24
+ "bass",
25
+ "other",
26
+ ] # == self.musdb_train.targets_names[:-2]
27
+
28
+ self.target_names = [target_name]
29
+
30
+ self.num_tracks = 1
31
+
32
+ # (math is already imported at the top of this file; duplicate import removed)
33
+
34
+ num_chunks = [
35
+ math.ceil(length / self.true_samples) for length in self.lengths
36
+ ] # example : 44.1khz 180sec audio, => [67]
37
+ self.acc_chunk_final_ids = [
38
+ sum(num_chunks[: i + 1]) for i in range(self.num_tracks)
39
+ ] # [67]
40
+
41
+ self.cache_mode = True
42
+ self.cache = {}
43
+ self.cache[0] = {}
44
+ self.cache[0]["linear_mixture"] = track
45
+
46
+ def __len__(self):
47
+ return self.acc_chunk_final_ids[-1] * len(self.target_names) # 67
48
+
49
+ def __getitem__(self, idx):
50
+
51
+ target_offset = idx % len(self.target_names) # 0
52
+ idx = idx // len(self.target_names) # idx
53
+
54
+ target_name = self.target_names[target_offset] # 'vocals'
55
+ mixture_idx, start_pos = self.idx_to_track_offset(
56
+ idx
57
+ ) # idx * self.true_samples
58
+
59
+ length = self.true_samples
60
+ left_padding_num = right_padding_num = self.trim_length # 5120
61
+ if mixture_idx is None:
62
+ raise IndexError(idx)  # out-of-range chunk index; __len__ keeps valid indices below this
63
+ mixture_length = self.lengths[mixture_idx]
64
+ if start_pos + length > mixture_length: # last
65
+ right_padding_num += self.true_samples - (mixture_length - start_pos)
66
+ length = None
67
+
68
+ mixture = self.get_audio(mixture_idx, "linear_mixture", start_pos, length)
69
+ mixture = F.pad(mixture, (left_padding_num, right_padding_num), "constant", 0)
70
+
71
+ return mixture
72
+
73
+ def idx_to_track_offset(self, idx):
74
+
75
+ for i, last_chunk in enumerate(self.acc_chunk_final_ids):
76
+ if idx < last_chunk:
77
+ if i != 0:
78
+ offset = (idx - self.acc_chunk_final_ids[i - 1]) * self.true_samples
79
+ else:
80
+ offset = idx * self.true_samples
81
+ return i, offset
82
+
83
+ return None, None
84
+
85
+ def get_audio(self, idx, target_name, pos=0, length=None):
86
+ track = self.cache[idx][target_name]
87
+ return track[:, pos : pos + length] if length is not None else track[:, pos:]
88
+
89
+ def get_trim_length(self, hop_length, min_trim=5000):
90
+ trim_per_hop = math.ceil(min_trim / hop_length)
91
+
92
+ trim_length = trim_per_hop * hop_length
93
+ assert trim_per_hop > 1
94
+ return trim_length
95
+
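
A usage sketch of SingleTrackSet under hop_length=1024 (so window_length=130048, trim_length=5120, true_samples=119808); the identity assignment stands in for a real model call:

import torch
from torch.utils.data import DataLoader

track = torch.randn(2, 44100 * 30)                  # 30 s of stereo audio
dataset = SingleTrackSet(track, hop_length=1024, num_frame=128)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

outs = []
for chunk in loader:                                # (batch, 2, window_length)
    with torch.no_grad():
        est = chunk                                 # stand-in for model(chunk)
    # discard the padded borders; each chunk contributes true_samples samples
    outs.append(est[..., dataset.trim_length : -dataset.trim_length])
est_track = torch.cat(outs, dim=0).permute(1, 0, 2).reshape(2, -1)[:, : track.shape[1]]
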
eval_delimit/calc_flops.py ADDED
@@ -0,0 +1,44 @@
1
+ import os
2
+ import argparse
3
+ import random
4
+
5
+ import torch
6
+ from deepspeed.profiling.flops_profiler import get_model_profile
7
+
8
+ from utils import get_config
9
+ from models import load_model_with_args
10
+
11
+
12
+ # def main():
13
+ parser = argparse.ArgumentParser(description="FLOPs calculation")
14
+
15
+ parser.add_argument(
16
+ "-c", "--config", default="delimit_6_s", type=str, help="Name of the setting file."
17
+ )
18
+
19
+ config_args = parser.parse_args()
20
+
21
+ args = get_config(config_args.config)
22
+ print(args)
23
+
24
+ with torch.cuda.device(0):
25
+ model = load_model_with_args(args)
26
+ batch_size = 1
27
+ flops, macs, params = get_model_profile(
28
+ model=model, # model
29
+ input_shape=(batch_size, 2, 44100 * 60), # input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
30
+ args=[], # list of positional arguments to the model.
31
+ kwargs={}, # dictionary of keyword arguments to the model.
32
+ print_profile=True, # prints the model graph with the measured profile attached to each module
33
+ detailed=True, # print the detailed profile
34
+ module_depth=-1, # depth into the nested modules, with -1 being the inner most modules
35
+ top_modules=1, # the number of top modules to print aggregated profile
36
+ warm_up=1, # the number of warm-ups before measuring the time of each module
37
+ as_string=True, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
38
+ output_file=None, # path to the output file. If None, the profiler prints to stdout.
39
+ ignore_modules=None,
40
+ ) # the list of modules to ignore in the profiling
41
+ print(args.dir_params.exp_name)
42
+ print('flops: ', flops)
43
+ print('macs: ', macs)
44
+ print('params: ', params)
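
The profiled input (1, 2, 44100 * 60) is one minute of 44.1 kHz stereo audio, i.e. 2,646,000 samples per channel, so dividing the measured FLOPs by 60 gives the cost per second of audio. A sketch, assuming as_string=False so the profiler returns raw numbers instead of formatted strings:

flops, macs, params = get_model_profile(
    model=model, input_shape=(1, 2, 44100 * 60), args=[], kwargs={}, as_string=False
)
print(f"GFLOPs per second of audio: {flops / 60 / 1e9:.2f}")
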
eval_delimit/score_calc_delimit.py ADDED
@@ -0,0 +1,145 @@
1
+ # Calculate SI-SDR and multi-resolution spectrogram MSE scores of the pre-inferred sources
2
+ import os
3
+ import argparse
4
+ import csv
5
+ import json
6
+ import glob
7
+
8
+ import tqdm
9
+ import numpy as np
10
+ import librosa
11
+ import pyloudnorm as pyln
12
+ from asteroid.metrics import get_metrics
13
+
14
+ from utils import str2bool
15
+
16
+
17
+ def multi_resolution_spectrogram_mse(
18
+ gt, est, n_fft=[2048, 1024, 512], n_hop=[512, 256, 128]
19
+ ):
20
+ assert gt.shape == est.shape
21
+ assert len(n_fft) == len(n_hop)
22
+
23
+ score = 0.0
24
+ for i in range(len(n_fft)):
25
+ gt_spec = librosa.magphase(
26
+ librosa.stft(gt, n_fft=n_fft[i], hop_length=n_hop[i])
27
+ )[0]
28
+ est_spec = librosa.magphase(
29
+ librosa.stft(est, n_fft=n_fft[i], hop_length=n_hop[i])
30
+ )[0]
31
+ score = score + np.mean((gt_spec - est_spec) ** 2)
32
+
33
+ return score
34
+
35
+
36
+ parser = argparse.ArgumentParser(description="model test.py")
37
+
38
+ parser.add_argument(
39
+ "--target",
40
+ type=str,
41
+ default="all",
42
+ help="target source. all, vocals, drums, bass, other, 0.5_mixed",
43
+ )
44
+ parser.add_argument(
45
+ "--root", type=str, default="/path/to/musdb18hq_loudnorm"
46
+ )
47
+ parser.add_argument("--exp_name", type=str, default="convtasnet_6_s")
48
+ parser.add_argument(
49
+ "--output_directory",
50
+ type=str,
51
+ default="/path/to/results",
52
+ )
53
+ parser.add_argument("--loudnorm_lufs", type=float, default=-14.0)
54
+ parser.add_argument(
55
+ "--calc_mse",
56
+ type=str2bool,
57
+ default=True,
58
+ help="calculate multi-resolution spectrogram mse",
59
+ )
60
+
61
+ parser.add_argument(
62
+ "--calc_results",
63
+ type=str2bool,
64
+ default=True,
65
+ help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
66
+ )
67
+
68
+ args, _ = parser.parse_known_args()
69
+
70
+ args.sample_rate = 44100
71
+
72
+ meter = pyln.Meter(args.sample_rate)
73
+
74
+ if args.calc_results:
75
+ args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
76
+ else:
77
+ args.test_output_dir = f"{args.output_directory}/{args.exp_name}"
78
+
79
+ if args.target == "all" or args.target == "0.5_mixed":
80
+ test_tracks = glob.glob(f"{args.root}/*/mixture.wav")
81
+ else:
82
+ test_tracks = glob.glob(f"{args.root}/*/{args.target}.wav")
83
+ i = 0
84
+
85
+ dict_song_score = {}
86
+ list_si_sdr = []
87
+ list_multi_mse = []
88
+ for track in tqdm.tqdm(test_tracks):
89
+ if args.target == "all": # for standard de-limiter estimation
90
+ audio_name = os.path.basename(os.path.dirname(track))
91
+ gt_source = librosa.load(track, sr=args.sample_rate, mono=False)[0]
92
+
93
+ est_delimiter = librosa.load(
94
+ f"{args.test_output_dir}/{audio_name}/all.wav",
95
+ sr=args.sample_rate,
96
+ mono=False,
97
+ )[0]
98
+
99
+ else: # for source-separated de-limiter estimation
100
+ audio_name = os.path.basename(os.path.dirname(track))
101
+ gt_source = librosa.load(track, sr=args.sample_rate, mono=False)[0]
102
+ est_delimiter = librosa.load(
103
+ f"{args.test_output_dir}/{audio_name}/{args.target}.wav",
104
+ sr=args.sample_rate,
105
+ mono=False,
106
+ )[0]
107
+
108
+
109
+ metrics_dict = get_metrics(
110
+ gt_source + est_delimiter,
111
+ gt_source,
112
+ est_delimiter,
113
+ sample_rate=args.sample_rate,
114
+ metrics_list=["si_sdr"],
115
+ )
116
+
117
+ if args.calc_mse:
118
+ multi_resolution_spectrogram_mse_score = multi_resolution_spectrogram_mse(
119
+ gt_source, est_delimiter
120
+ )
121
+ else:
122
+ multi_resolution_spectrogram_mse_score = None
123
+
124
+ dict_song_score[audio_name] = {
125
+ "si_sdr": metrics_dict["si_sdr"],
126
+ "multi_mse": multi_resolution_spectrogram_mse_score,
127
+ }
128
+ list_si_sdr.append(metrics_dict["si_sdr"])
129
+ list_multi_mse.append(multi_resolution_spectrogram_mse_score)
130
+
131
+ i += 1
132
+
133
+ print(f"{args.exp_name} on {args.target}")
134
+ print(f"SI-SDR score: {sum(list_si_sdr) / len(list_si_sdr)}")
135
+ if args.calc_mse:
136
+ print(f"multi-mse score: {sum(list_multi_mse) / len(list_multi_mse)}")
137
+
138
+ if args.target != "all":
139
+ # save dict_song_score to json file
140
+ with open(f"{args.test_output_dir}/score_{args.target}.json", "w") as f:
141
+ json.dump(dict_song_score, f, indent=4)
142
+ else:
143
+ # save dict_song_score to json file
144
+ with open(f"{args.test_output_dir}/score.json", "w") as f:
145
+ json.dump(dict_song_score, f, indent=4)
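
The SI-SDR above comes from asteroid's get_metrics; for reference, the textbook definition on mono signals looks like this (a sketch, not the asteroid implementation):

import numpy as np

def si_sdr(reference: np.ndarray, estimate: np.ndarray) -> float:
    # scale-invariant SDR: project the estimate onto the reference first,
    # so a global gain on the estimate does not change the score
    alpha = np.dot(estimate, reference) / np.dot(reference, reference)
    target = alpha * reference
    noise = estimate - target
    return 10 * np.log10(np.sum(target**2) / np.sum(noise**2))
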
eval_delimit/score_diff_dyn_complexity.py ADDED
@@ -0,0 +1,87 @@
1
+ import os
2
+ import argparse
3
+ import csv
4
+ import json
5
+ import glob
6
+
7
+ import tqdm
8
+ import numpy as np
9
+ import librosa
10
+ import musdb
11
+ import pyloudnorm as pyln
12
+
13
+ from utils import str2bool, db2linear
14
+
15
+ parser = argparse.ArgumentParser(description="model test.py")
16
+
17
+ parser.add_argument(
18
+ "--target",
19
+ type=str,
20
+ default="all",
21
+ help="target source. all, vocals, bass, drums, other.",
22
+ )
23
+ parser.add_argument(
24
+ "--root",
25
+ type=str,
26
+ default="/path/to/musdb18hq_loudnorm",
27
+ )
28
+ parser.add_argument(
29
+ "--output_directory",
30
+ type=str,
31
+ default="/path/to/results",
32
+ )
33
+ parser.add_argument("--exp_name", type=str, default="convtasnet_6_s")
34
+ parser.add_argument(
35
+ "--calc_results",
36
+ type=str2bool,
37
+ default=True,
38
+ help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
39
+ )
40
+
41
+ args, _ = parser.parse_known_args()
42
+
43
+ args.sample_rate = 44100
44
+ meter = pyln.Meter(args.sample_rate)
45
+
46
+ if args.calc_results:
47
+ args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
48
+ else:
49
+ args.test_output_dir = f"{args.output_directory}/{args.exp_name}"
50
+
51
+
52
+ est_track_list = glob.glob(f"{args.test_output_dir}/*/{args.target}.wav")
53
+ f = open(
54
+ f"{args.test_output_dir}/score_feature{'' if args.target == 'all' else '_' + args.target}.json",  # match the file names written by score_features.py
55
+ encoding="UTF-8",
56
+ )
57
+ dict_song_score_est = json.loads(f.read())
58
+
59
+ if args.target == "all":
60
+ ref_track_list = glob.glob(f"{args.root}/*/mixture.wav")
61
+ f = open(f"{args.root}/score_feature.json", encoding="UTF-8")
62
+ dict_song_score_ref = json.loads(f.read())
63
+ else:
64
+ ref_track_list = glob.glob(f"{args.root}/*/{args.target}.wav")
65
+ f = open(f"{args.root}/score_feature_{args.target}.json", encoding="UTF-8")
66
+ dict_song_score_ref = json.loads(f.read())
67
+
68
+ i = 0
69
+
70
+ dict_song_score = {}
71
+ list_diff_dynamic_complexity = []
72
+
73
+ for track in tqdm.tqdm(ref_track_list):
74
+ audio_name = os.path.basename(os.path.dirname(track))
75
+ ref_dyn_complexity = dict_song_score_ref[audio_name]["dynamic_complexity_score"]
76
+ est_dyn_complexity = dict_song_score_est[audio_name]["dynamic_complexity_score"]
77
+
78
+ list_diff_dynamic_complexity.append(est_dyn_complexity - ref_dyn_complexity)
79
+
80
+ i += 1
81
+
82
+ print(
83
+ f"Dynamic complexity difference {args.exp_name} vs {os.path.basename(args.root)} on {args.target}"
84
+ )
85
+ print("mean: ", np.mean(list_diff_dynamic_complexity))
86
+ print("median: ", np.median(list_diff_dynamic_complexity))
87
+ print("std: ", np.std(list_diff_dynamic_complexity))
eval_delimit/score_fad.py ADDED
@@ -0,0 +1,75 @@
1
+ # We are going to use FAD based on https://github.com/gudgud96/frechet-audio-distance
2
+ import os
3
+ import subprocess
4
+ import glob
5
+ import argparse
6
+
7
+ from frechet_audio_distance import FrechetAudioDistance
8
+
9
+ from utils import str2bool
10
+
11
+
12
+ parser = argparse.ArgumentParser(description="model test.py")
13
+
14
+ parser.add_argument(
15
+ "--target",
16
+ type=str,
17
+ default="all",
18
+ help="target source. all, vocals, drums, bass, other",
19
+ )
20
+ parser.add_argument(
21
+ "--root",
22
+ type=str,
23
+ default="/path/to/musdb18hq_loudnorm",
24
+ )
25
+ parser.add_argument(
26
+ "--output_directory",
27
+ type=str,
28
+ default="/path/to/results",
29
+ )
30
+ parser.add_argument("--exp_name", type=str, default="delimit_6_s")
31
+ parser.add_argument(
32
+ "--calc_results",
33
+ type=str2bool,
34
+ default=True,
35
+ help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
36
+ )
37
+
38
+ args, _ = parser.parse_known_args()
39
+
40
+ os.makedirs(f"{args.root}/musdb_hq_loudnorm_16k_mono_link", exist_ok=True)
41
+
42
+ song_list = glob.glob(f"{args.root}/musdb_hq_loudnorm_16k_mono/*/mixture.wav")
43
+ for song in song_list:
44
+ song_name = os.path.basename(os.path.dirname(song))
45
+ subprocess.run(
46
+ f'ln --symbolic "{song}" "{args.root}/musdb_hq_loudnorm_16k_mono_link/{song_name}.wav"',
47
+ shell=True,
48
+ )
49
+
50
+
51
+ if args.calc_results:
52
+ args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
53
+ else:
54
+ args.test_output_dir = f"{args.output_directory}/{args.exp_name}"
55
+
56
+ os.makedirs(f"{args.test_output_dir}_16k_mono_link", exist_ok=True)
57
+
58
+ song_list = glob.glob(f"{args.test_output_dir}_16k_mono/*/{args.target}.wav")
59
+ for song in song_list:
60
+ song_name = os.path.basename(os.path.dirname(song))
61
+ subprocess.run(
62
+ f'ln --symbolic "{song}" "{args.test_output_dir}_16k_mono_link/{song_name}.wav"',
63
+ shell=True,
64
+ )
65
+
66
+
67
+ frechet = FrechetAudioDistance()
68
+
69
+ fad_score = frechet.score(
70
+ f"{args.root}/musdb_hq_loudnorm_16k_mono_link",
71
+ f"{args.test_output_dir}_16k_mono_link",
72
+ )
73
+
74
+ print(f"{args.exp_name}")
75
+ print(f"FAD score: {fad_score}")
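
For reference, FAD is the Fréchet distance between two Gaussians fitted to audio embeddings (VGGish by default in this library) of the reference and estimated sets; a sketch of the closed form that the library computes internally:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2))
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard numerical noise from sqrtm
    diff = mu1 - mu2
    return diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
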
eval_delimit/score_features.py ADDED
@@ -0,0 +1,233 @@
1
+ import os
2
+ import argparse
3
+ import csv
4
+ import json
5
+ import glob
6
+ from typing import Any, Optional, Union, Collection
7
+
8
+ import tqdm
9
+ import numpy as np
10
+ import librosa
11
+ from librosa.core.spectrum import _spectrogram
12
+ import musdb
13
+ import essentia
14
+ import essentia.standard
15
+ import pyloudnorm as pyln
16
+
17
+ from utils import str2bool, db2linear
18
+
19
+
20
+ def spectral_crest(
21
+ *,
22
+ y: Optional[np.ndarray] = None,
23
+ S: Optional[np.ndarray] = None,
24
+ n_fft: int = 2048,
25
+ hop_length: int = 512,
26
+ win_length: Optional[int] = None,
27
+ window: str = "hann",
28
+ center: bool = True,
29
+ pad_mode: str = "constant",
30
+ amin: float = 1e-10,
31
+ power: float = 2.0,
32
+ ) -> np.ndarray:
33
+ """Compute spectral crest
34
+
35
+ Spectral crest (or tonality coefficient) is a measure of
36
+ the ratio of the maximum of the spectrum to the arithmetic mean of the spectrum
37
+
38
+ A higher spectral crest => a more tonal signal;
39
+ a lower spectral crest => a noisier signal.
40
+
41
+
42
+ Parameters
43
+ ----------
44
+ y : np.ndarray [shape=(..., n)] or None
45
+ audio time series. Multi-channel is supported.
46
+ S : np.ndarray [shape=(..., d, t)] or None
47
+ (optional) pre-computed spectrogram magnitude
48
+ n_fft : int > 0 [scalar]
49
+ FFT window size
50
+ hop_length : int > 0 [scalar]
51
+ hop length for STFT. See `librosa.stft` for details.
52
+ win_length : int <= n_fft [scalar]
53
+ Each frame of audio is windowed by `window()`.
54
+ The window will be of length `win_length` and then padded
55
+ with zeros to match ``n_fft``.
56
+ If unspecified, defaults to ``win_length = n_fft``.
57
+ window : string, tuple, number, function, or np.ndarray [shape=(n_fft,)]
58
+ - a window specification (string, tuple, or number);
59
+ see `scipy.signal.get_window`
60
+ - a window function, such as `scipy.signal.windows.hann`
61
+ - a vector or array of length ``n_fft``
62
+ .. see also:: `librosa.filters.get_window`
63
+ center : boolean
64
+ - If `True`, the signal ``y`` is padded so that frame
65
+ ``t`` is centered at ``y[t * hop_length]``.
66
+ - If `False`, then frame `t` begins at ``y[t * hop_length]``
67
+ pad_mode : string
68
+ If ``center=True``, the padding mode to use at the edges of the signal.
69
+ By default, STFT uses zero padding.
70
+ amin : float > 0 [scalar]
71
+ minimum threshold for ``S`` (=added noise floor for numerical stability)
72
+ power : float > 0 [scalar]
73
+ Exponent for the magnitude spectrogram.
74
+ e.g., 1 for energy, 2 for power, etc.
75
+ Power spectrogram is usually used for computing spectral flatness.
76
+
77
+ Returns
78
+ -------
79
+ crest : np.ndarray [shape=(..., 1, t)]
80
+ spectral crest for each frame.
81
+
82
+
83
+ """
84
+
85
+ S, n_fft = _spectrogram(
86
+ y=y,
87
+ S=S,
88
+ n_fft=n_fft,
89
+ hop_length=hop_length,
90
+ power=1.0,
91
+ win_length=win_length,
92
+ window=window,
93
+ center=center,
94
+ pad_mode=pad_mode,
95
+ )
96
+
97
+ S_thresh = np.maximum(amin, S**power)
98
+ # gmean = np.exp(np.mean(np.log(S_thresh), axis=-2, keepdims=True))
99
+ gmax = np.max(S_thresh, axis=-2, keepdims=True)
100
+ amean = np.mean(S_thresh, axis=-2, keepdims=True)
101
+ crest: np.ndarray = gmax / amean
102
+ return crest
103
+
104
+
105
+ parser = argparse.ArgumentParser(description="model test.py")
106
+
107
+ parser.add_argument(
108
+ "--target",
109
+ type=str,
110
+ default="all",
111
+ help="target source. all, vocals, drums, bass, other",
112
+ )
113
+ parser.add_argument(
114
+ "--root", type=str, default="/path/to/musdb18hq_loudnorm"
115
+ )
116
+ parser.add_argument("--exp_name", type=str, default="delimit_6_s")
117
+ parser.add_argument(
118
+ "--output_directory",
119
+ type=str,
120
+ default="/path/to/results",
121
+ )
122
+ parser.add_argument(
123
+ "--calc_results",
124
+ type=str2bool,
125
+ default=True,
126
+ help="calculate results or musdb-hq or musdb-XL test dataset",
127
+ )
128
+
129
+
130
+ args, _ = parser.parse_known_args()
131
+
132
+ args.sample_rate = 44100
133
+
134
+ args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
135
+
136
+ if args.calc_results:
137
+ track_list = glob.glob(
138
+ f"{args.output_directory}/test/{args.exp_name}/*/{args.target}.wav"
139
+ )
140
+ else:
141
+ if args.target == "all":
142
+ track_list = glob.glob(f"{args.root}/*/mixture.wav")
143
+ else:
144
+ track_list = glob.glob(f"{args.root}/*/{args.target}.wav")
145
+
146
+ i = 0
147
+
148
+
149
+ dynamic_complexity = essentia.standard.DynamicComplexity()
150
+ loudness_range = essentia.standard.LoudnessEBUR128()
151
+ spectral_centroid = essentia.standard.SpectralCentroidTime()
152
+ crest = essentia.standard.Crest()
153
+ dynamic_spread = essentia.standard.DistributionShape()
154
+ central_moments = essentia.standard.CentralMoments()
155
+
156
+ dict_song_score = {}
157
+ list_rms = []
158
+ list_crest_factor = []
159
+ list_dc_score = []
160
+ list_lra_score = []
161
+ list_sc_hertz = []
162
+ list_sf_score = []
163
+ list_spectral_crest_score = []
164
+
165
+ for track in tqdm.tqdm(track_list):
166
+ audio_name = os.path.basename(os.path.dirname(track))
167
+ gt_source_librosa = librosa.load(f"{track}", sr=args.sample_rate, mono=False)[
168
+ 0
169
+ ] # (nb_channels, nb_samples)
170
+ gt_source_librosa_mono = librosa.to_mono(gt_source_librosa) # (nb_samples)
171
+
172
+ gt_source_essentia = essentia.standard.AudioLoader(filename=f"{track}")()[
173
+ 0
174
+ ] # (nb_samples, nb_channels)
175
+ gt_source_essentia_cat = np.concatenate(
176
+ [gt_source_essentia[:, 0], gt_source_essentia[:, 1]]
177
+ ) # (nb_samples * nb_channels)
178
+ gt_source_essentia_mono = np.mean(gt_source_essentia, axis=1) # (nb_samples)
179
+
180
+ rms = np.sqrt(np.mean(gt_source_essentia_cat**2))
181
+ crest_factor = np.max(np.abs(gt_source_essentia_cat)) / rms
182
+
183
+ dc_score, _ = dynamic_complexity(gt_source_essentia_mono)
184
+ _, _, _, lra_score = loudness_range(gt_source_essentia)
185
+ sc_hertz = spectral_centroid(gt_source_essentia_mono)
186
+ sf_score = np.mean(librosa.feature.spectral_flatness(y=gt_source_librosa_mono))
187
+ spectral_crest_score = np.mean(spectral_crest(y=gt_source_librosa_mono))
188
+
189
+ dict_song_score[audio_name] = {
190
+ "rms": float(rms),
191
+ "crest_factor": float(crest_factor),
192
+ "dynamic_complexity_score": float(dc_score),
193
+ "lra_score": float(lra_score),
194
+ "spectral_centroid_hertz": float(sc_hertz),
195
+ "spectral_flatness_score": float(sf_score),
196
+ "spectral_crest_score": float(spectral_crest_score),
197
+ }
198
+ list_rms.append(rms)
199
+ list_crest_factor.append(crest_factor)
200
+ list_dc_score.append(dc_score)
201
+ list_lra_score.append(lra_score)
202
+ list_sc_hertz.append(sc_hertz)
203
+ list_sf_score.append(sf_score)
204
+ list_spectral_crest_score.append(spectral_crest_score)
205
+
206
+ i += 1
207
+
208
+ if args.calc_results:
209
+ print(f"{args.exp_name} on {args.target}")
210
+ else:
211
+ print(f"{os.path.basename(args.root)} on {args.target}")
212
+ print(f"rms: {np.mean(list_rms)}")
213
+ print(f"crest_factor: {np.mean(list_crest_factor)}")
214
+ print(f"dynamic_complexity_score: {np.mean(list_dc_score)}")
215
+ print(f"lra_score: {np.mean(list_lra_score)}")
216
+ print(f"sc_hertz: {np.mean(list_sc_hertz)}")
217
+ print(f"sf_score: {np.mean(list_sf_score)}")
218
+ print(f"spectral_crest_score: {np.mean(list_spectral_crest_score)}")
219
+
220
+
221
+ # save dict_song_score to json file
222
+ if args.target == "all":
223
+ file_name = "score_feature"  # singular, matching what score_diff_dyn_complexity.py reads
224
+ else:
225
+ file_name = f"score_feature_{args.target}"
226
+ if args.calc_results:
227
+ with open(
228
+ f"{args.output_directory}/test/{args.exp_name}/{file_name}.json", "w"
229
+ ) as f:
230
+ json.dump(dict_song_score, f, indent=4)
231
+ else:
232
+ with open(f"{args.root}/{file_name}.json", "w") as f:
233
+ json.dump(dict_song_score, f, indent=4)
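
A quick sanity check of spectral_crest(): a pure tone concentrates its energy in one bin and should score far higher than white noise (illustrative values only, not from the paper):

import numpy as np

sr = 44100
t = np.arange(sr) / sr
tone = np.sin(2 * np.pi * 440 * t)        # 1 s of a 440 Hz sine
noise = np.random.randn(sr)               # 1 s of white noise
print(np.mean(spectral_crest(y=tone)), np.mean(spectral_crest(y=noise)))
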
eval_delimit/score_peaq.py ADDED
@@ -0,0 +1,77 @@
1
+ # We are going to use PEAQ based on https://github.com/HSU-ANT/gstpeaq
2
+
3
+ """
4
+ python3 score_peaq.py --exp_name=delimit_6_s | tee /path/to/results/delimit_6_s/score_peaq.txt
5
+ """
6
+
7
+
8
+
9
+ import os
10
+ import subprocess
11
+ import glob
12
+ import argparse
13
+
14
+
15
+ def str2bool(v):
16
+ if v.lower() in ("yes", "true", "t", "y", "1"):
17
+ return True
18
+ elif v.lower() in ("no", "false", "f", "n", "0"):
19
+ return False
20
+ else:
21
+ raise argparse.ArgumentTypeError("Boolean value expected.")
22
+
23
+
24
+ parser = argparse.ArgumentParser(description="model test.py")
25
+
26
+ parser.add_argument(
27
+ "--target",
28
+ type=str,
29
+ default="all",
30
+ help="target source. all, vocals, drums, bass, other",
31
+ )
32
+ parser.add_argument(
33
+ "--root",
34
+ type=str,
35
+ default="/path/to/musdb_XL_loudnorm",
36
+ )
37
+ parser.add_argument(
38
+ "--output_directory",
39
+ type=str,
40
+ default="/path/to/results/",
41
+ )
42
+ parser.add_argument("--exp_name", type=str, default="delimit_6_s")
43
+ parser.add_argument(
44
+ "--calc_results",
45
+ type=str2bool,
46
+ default=True,
47
+ help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
48
+ )
49
+
50
+ args, _ = parser.parse_known_args()
51
+
52
+ if args.calc_results:
53
+ args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
54
+ else:
55
+ args.test_output_dir = f"{args.output_directory}/{args.exp_name}"
56
+
57
+ if args.target == "all":
58
+ song_list = sorted(glob.glob(f"{args.root}/*/mixture.wav"))
59
+
60
+ for song in song_list:
61
+ song_name = os.path.basename(os.path.dirname(song))
62
+ est_path = f"{args.test_output_dir}/{song_name}/{args.target}.wav"
63
+ subprocess.run(
64
+ f'peaq --gst-plugin-load=/usr/local/lib/gstreamer-1.0/libgstpeaq.so "{song}" "{est_path}"',
65
+ shell=True,
66
+ )
67
+
68
+ else:
69
+ song_list = sorted(glob.glob(f"{args.root}/*/{args.target}.wav"))
70
+
71
+ for song in song_list:
72
+ song_name = os.path.basename(os.path.dirname(song))
73
+ est_path = f"{args.test_output_dir}/{song_name}/{args.target}.wav"
74
+ subprocess.run(
75
+ f'peaq --gst-plugin-load=/usr/local/lib/gstreamer-1.0/libgstpeaq.so "{song}" "{est_path}"',
76
+ shell=True,
77
+ )
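
Rather than tee-ing stdout to a text file, the ODG could be captured per song inside the loop above; a sketch, assuming peaq prints an "Objective Difference Grade: x" line, which is the format score_peaq_aggregate.py parses below:

result = subprocess.run(
    f'peaq --gst-plugin-load=/usr/local/lib/gstreamer-1.0/libgstpeaq.so "{song}" "{est_path}"',
    shell=True, capture_output=True, text=True,
)
odg_line = next(l for l in result.stdout.splitlines() if l.startswith("Objective Difference Grade"))
odg = float(odg_line.replace("Objective Difference Grade: ", ""))
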
eval_delimit/score_peaq_aggregate.py ADDED
@@ -0,0 +1,88 @@
1
+ # PEAQ aggregate score
2
+ """
3
+ /path/to/results/delimit_6_s/score_peaq.txt
4
+ """
5
+
6
+ import os
7
+ import glob
8
+ import argparse
9
+ import json
10
+
11
+
12
+ def str2bool(v):
13
+ if v.lower() in ("yes", "true", "t", "y", "1"):
14
+ return True
15
+ elif v.lower() in ("no", "false", "f", "n", "0"):
16
+ return False
17
+ else:
18
+ raise argparse.ArgumentTypeError("Boolean value expected.")
19
+
20
+
21
+ parser = argparse.ArgumentParser(description="model test.py")
22
+
23
+ parser.add_argument(
24
+ "--target",
25
+ type=str,
26
+ default="all",
27
+ help="target source. all, vocals, drums, bass, other",
28
+ )
29
+ parser.add_argument(
30
+ "--root",
31
+ type=str,
32
+ default="/path/to/musdb18hq_loudnorm",
33
+ )
34
+ parser.add_argument(
35
+ "--output_directory",
36
+ type=str,
37
+ default="/path/to/results",
38
+ )
39
+ parser.add_argument("--exp_name", type=str, default="delimit_6_s")
40
+ parser.add_argument(
41
+ "--calc_results",
42
+ type=str2bool,
43
+ default=True,
44
+ help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
45
+ )
46
+
47
+ args, _ = parser.parse_known_args()
48
+
49
+
50
+ if args.calc_results:
51
+ args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
52
+ else:
53
+ args.test_output_dir = f"{args.output_directory}/{args.exp_name}"
54
+
55
+
56
+ if args.target == "all":
57
+ score_path = f"{args.test_output_dir}/score_peaq.txt"
58
+ else:
59
+ score_path = f"{args.test_output_dir}/score_peaq_{args.target}.txt"
60
+
61
+ # write the code to load score_peaq.txt
62
+ with open(score_path, "r") as f:
63
+ score_txt = f.readlines()
64
+
65
+ song_list = sorted(glob.glob(f"{args.root}/*"))  # score_peaq.py iterated a sorted list, so keep the same order here
66
+
67
+ dict_song_peaq = {}
68
+ list_peaq = []
69
+ for idx, song in enumerate(song_list):
70
+ song_name = os.path.basename(song)
71
+ peaq = float(score_txt[idx * 2].replace("Objective Difference Grade: ", ""))
72
+ dict_song_peaq[song_name] = peaq
73
+ list_peaq.append(peaq)
74
+
75
+ print(f"{args.exp_name} on {args.target}")
76
+ print(f"PEAQ score: {sum(list_peaq) / len(list_peaq)}")
77
+
78
+ if args.target == "all":
79
+ # save dict_song_peaq to json file
80
+ with open(f"{args.test_output_dir}/score_peaq.json", "w") as f:
81
+ json.dump(dict_song_peaq, f, indent=4)
82
+ else:
83
+ # save dict_song_peaq to json file
84
+ with open(
85
+ f"{args.test_output_dir}/score_peaq_{args.target}.json",
86
+ "w",
87
+ ) as f:
88
+ json.dump(dict_song_peaq, f, indent=4)
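
Indexing the text file by idx * 2 assumes exactly two output lines per song; a slightly more defensive parse of the same file, keyed on the line prefix instead (a sketch against the variables of the script above):

odg_lines = [l for l in score_txt if l.startswith("Objective Difference Grade")]
assert len(odg_lines) == len(song_list), "expected one ODG line per song"
for song, line in zip(song_list, odg_lines):
    dict_song_peaq[os.path.basename(song)] = float(line.split(": ")[1])
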
inference.py ADDED
@@ -0,0 +1,165 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import glob
5
+
6
+ import torch
7
+ import tqdm
8
+ import librosa
9
+ import soundfile as sf
10
+ import pyloudnorm as pyln
11
+ from dotmap import DotMap
12
+
13
+ from models import load_model_with_args
14
+ from separate_func import (
15
+ conv_tasnet_separate,
16
+ )
17
+ from utils import str2bool, db2linear
18
+
19
+
20
+ tqdm.monitor_interval = 0
21
+
22
+
23
+ def separate_track_with_model(
24
+ args, model, device, track_audio, track_name, meter, augmented_gain
25
+ ):
26
+ with torch.no_grad():
27
+ if (
28
+ args.model_loss_params.architecture == "conv_tasnet_mask_on_output"
29
+ or args.model_loss_params.architecture == "conv_tasnet"
30
+ ):
31
+ estimates = conv_tasnet_separate(
32
+ args,
33
+ model,
34
+ device,
35
+ track_audio,
36
+ track_name,
37
+ meter=meter,
38
+ augmented_gain=augmented_gain,
39
+ )
40
+
41
+ return estimates
42
+
43
+
44
+ def main():
45
+ parser = argparse.ArgumentParser(description="model test.py")
46
+ parser.add_argument("--target", type=str, default="all")
47
+ parser.add_argument("--data_root", type=str, default="./input_data")
48
+ parser.add_argument("--weight_directory", type=str, default="./weight")
49
+ parser.add_argument("--output_directory", type=str, default="./output")
50
+ parser.add_argument("--use_gpu", type=str2bool, default=True)
51
+ parser.add_argument("--save_name_as_target", type=str2bool, default=False)
52
+ parser.add_argument(
53
+ "--loudnorm_input_lufs",
54
+ type=float,
55
+ default=None,
56
+ help="If you want to use loudnorm for input",
57
+ )
58
+ parser.add_argument(
59
+ "--save_output_loudnorm",
60
+ type=float,
61
+ default=-14.0,
62
+ help="Save loudness-normalized outputs. Give the target loudness (LUFS) to enable.",
63
+ )
64
+ parser.add_argument(
65
+ "--save_mixed_output",
66
+ type=float,
67
+ default=None,
68
+ help="Save the original and the de-limited estimation mixed with a ratio of 0.5 (original) to 1 - 0.5 (estimation) by default",
69
+ )
70
+ parser.add_argument(
71
+ "--save_16k_mono",
72
+ type=str2bool,
73
+ default=False,
74
+ help="Save 16k mono wav files for FAD evaluation.",
75
+ )
76
+ parser.add_argument(
77
+ "--save_histogram",
78
+ type=str2bool,
79
+ default=False,
80
+ help="Save histogram of the output. Only valid when the task is 'delimit'",
81
+ )
82
+ parser.add_argument(
83
+ "--use_singletrackset",
84
+ type=str2bool,
85
+ default=False,
86
+ help="Use SingleTrackSet if input data is too long.",
87
+ )
88
+
89
+ args, _ = parser.parse_known_args()
90
+
91
+ with open(f"{args.weight_directory}/{args.target}.json", "r") as f:
92
+ args_dict = json.load(f)
93
+ args_dict = DotMap(args_dict)
94
+
95
+ for key, value in args_dict["args"].items():
96
+ if key in list(vars(args).keys()):
97
+ pass
98
+ else:
99
+ setattr(args, key, value)
100
+
101
+ args.test_output_dir = f"{args.output_directory}"
102
+ os.makedirs(args.test_output_dir, exist_ok=True)
103
+
104
+ device = torch.device(
105
+ "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
106
+ )
107
+
108
+ ###################### Define Models ######################
109
+ our_model = load_model_with_args(args)
110
+ our_model = our_model.to(device)
111
+
112
+ target_model_path = f"{args.weight_directory}/{args.target}.pth"
113
+ checkpoint = torch.load(target_model_path, map_location=device)
114
+ our_model.load_state_dict(checkpoint)
115
+
116
+ our_model.eval()
117
+
118
+ meter = pyln.Meter(44100)
119
+
120
+ test_tracks = glob.glob(f"{args.data_root}/*.wav") + glob.glob(
121
+ f"{args.data_root}/*.mp3"
122
+ )
123
+
124
+ for track in tqdm.tqdm(test_tracks):
125
+ track_name = os.path.basename(track).replace(".wav", "").replace(".mp3", "")
126
+ track_audio, sr = librosa.load(track, sr=None, mono=False) # sr should be 44100
127
+
128
+ orig_audio = track_audio.copy()
129
+
130
+ if sr != 44100:
131
+ raise ValueError("Sample rate should be 44100")
132
+ augmented_gain = None
133
+ print("Now De-limiting : ", track_name)
134
+
135
+ if args.loudnorm_input_lufs: # If you want to use loud-normalized input
136
+ track_lufs = meter.integrated_loudness(track_audio.T)
137
+ augmented_gain = args.loudnorm_input_lufs - track_lufs
138
+ track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
139
+
140
+ track_audio = (
141
+ torch.as_tensor(track_audio, dtype=torch.float32).unsqueeze(0).to(device)
142
+ )
143
+
144
+ estimates = separate_track_with_model(
145
+ args, our_model, device, track_audio, track_name, meter, augmented_gain
146
+ )
147
+
148
+ if args.save_mixed_output:
149
+ track_lufs = meter.integrated_loudness(orig_audio.T)
150
+ augmented_gain = args.save_output_loudnorm - track_lufs
151
+ orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)
152
+
153
+ mixed_output = orig_audio * args.save_mixed_output + estimates * (
154
+ 1 - args.save_mixed_output
155
+ )
156
+
157
+ sf.write(
158
+ f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
159
+ mixed_output.T,
160
+ args.data_params.sample_rate,
161
+ )
162
+
163
+
164
+ if __name__ == "__main__":
165
+ main()
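
A typical invocation (paths are placeholders; the weights and the matching JSON config must exist under --weight_directory):

python3 inference.py --data_root=./input_data --output_directory=./output --save_output_loudnorm=-14.0
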
main_ddp.py ADDED
@@ -0,0 +1,49 @@
1
+ import os
2
+ import argparse
3
+ import random
4
+
5
+ import torch
6
+
7
+ from train_ddp import train
8
+ from utils import get_config
9
+
10
+
11
+ def main():
12
+ parser = argparse.ArgumentParser(description="Trainer")
13
+
14
+ # Put every argument in './configs/yymmdd_architecture_number.yaml' and load it.
15
+ parser.add_argument(
16
+ "-c",
17
+ "--config",
18
+ default="delimit_6_s",
19
+ type=str,
20
+ help="Name of the setting file.",
21
+ )
22
+
23
+ config_args = parser.parse_args()
24
+
25
+ args = get_config(config_args.config)
26
+
27
+ args.img_check = (
28
+ f"{args.dir_params.output_directory}/img_check/{args.dir_params.exp_name}"
29
+ )
30
+ args.output = (
31
+ f"{args.dir_params.output_directory}/checkpoint/{args.dir_params.exp_name}"
32
+ )
33
+
34
+ # Set which devices to use
35
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
36
+ os.environ["MASTER_PORT"] = str(random.randint(10000, 60000))  # ports below 1024 are privileged, so pick from the ephemeral range
37
+
38
+ os.makedirs(args.img_check, exist_ok=True)
39
+ os.makedirs(args.output, exist_ok=True)
40
+
41
+ torch.manual_seed(args.sys_params.seed)
42
+ random.seed(args.sys_params.seed)
43
+
44
+ print(args)
45
+ train(args)
46
+
47
+
48
+ if __name__ == "__main__":
49
+ main()
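
Launched with a config name that resolves to ./configs/<name>.yaml, e.g.:

python3 main_ddp.py -c delimit_6_s
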
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .load_models import load_model_with_args
models/base_models.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from asteroid.models.base_models import (
4
+ BaseEncoderMaskerDecoder,
5
+ _unsqueeze_to_3d,
6
+ _shape_reconstructed,
7
+ )
8
+ from asteroid.utils.torch_utils import pad_x_to_y, jitable_shape
9
+ from einops import rearrange
10
+
11
+
12
+ class BaseEncoderMaskerDecoderWithConfigs(BaseEncoderMaskerDecoder):
13
+ def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
14
+ super().__init__(encoder, masker, decoder, encoder_activation)
15
+ self.use_encoder = kwargs.get("use_encoder", True)
16
+ self.apply_mask = kwargs.get("apply_mask", True)
17
+ self.use_decoder = kwargs.get("use_decoder", True)
18
+
19
+ def forward(self, wav):
20
+ """
21
+ Enc/Mask/Dec model forward with some additional options.
22
+ Some of the models we use, like TFC-TDF-UNet, have no masker.
23
+ In UMX or X-UMX, they already use masking in their model implementation.
24
+ Since we do not want to modify the model code itself, we use this wrapper.
25
+
26
+ Args:
27
+ wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.
28
+
29
+ Returns:
30
+ torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
31
+ """
32
+ # Remember shape to shape reconstruction, cast to Tensor for torchscript
33
+ shape = jitable_shape(wav)
34
+ # Reshape to (batch, n_mix, time)
35
+ wav = _unsqueeze_to_3d(wav)
36
+
37
+ # Real forward
38
+ if self.use_encoder:
39
+ tf_rep = self.forward_encoder(wav)
40
+ else:
41
+ tf_rep = wav
42
+
43
+ est_masks = self.forward_masker(tf_rep)
44
+
45
+ if self.apply_mask:
46
+ masked_tf_rep = self.apply_masks(tf_rep, est_masks)
47
+ else: # model already used masking
48
+ masked_tf_rep = est_masks
49
+
50
+ if self.use_decoder:
51
+ decoded = self.forward_decoder(masked_tf_rep)
52
+ reconstructed = pad_x_to_y(decoded, wav)
53
+
54
+ return masked_tf_rep, _shape_reconstructed(reconstructed, shape)
55
+
56
+ else: # In UMX or X-UMX, decoder is not used
57
+ decoded = masked_tf_rep
58
+
59
+ return decoded
60
+
61
+
62
+ class BaseEncoderMaskerDecoder_mixture_consistency(BaseEncoderMaskerDecoder):
63
+ def __init__(self, encoder, masker, decoder, encoder_activation=None):
64
+ super().__init__(encoder, masker, decoder, encoder_activation)
65
+
66
+ def forward(self, wav):
67
+ """Enc/Mask/Dec model forward with mixture consistent output
68
+
69
+ References:
70
+ [1] : Wisdom, Scott, et al. "Differentiable consistency constraints for improved deep speech enhancement." ICASSP 2019.
71
+ [2] : Wisdom, Scott, et al. "Unsupervised sound separation using mixture invariant training." NeurIPS 2020.
72
+
73
+ Args:
74
+ wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.
75
+
76
+ Returns:
77
+ torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
78
+ """
79
+ # Remember shape to shape reconstruction, cast to Tensor for torchscript
80
+ shape = jitable_shape(wav)
81
+ # Reshape to (batch, n_mix, time)
82
+ wav = _unsqueeze_to_3d(wav)
83
+
84
+ # Real forward
85
+ tf_rep = self.forward_encoder(wav)
86
+ est_masks = self.forward_masker(tf_rep)
87
+ masked_tf_rep = self.apply_masks(tf_rep, est_masks)
88
+ decoded = self.forward_decoder(masked_tf_rep)
89
+
90
+ reconstructed = _shape_reconstructed(pad_x_to_y(decoded, wav), shape)
91
+
92
+ reconstructed = reconstructed + 1 / reconstructed.shape[1] * (
93
+ wav - reconstructed.sum(dim=1, keepdim=True)
94
+ )
95
+
96
+ return reconstructed
97
+
98
+
99
+ class BaseEncoderMaskerDecoderWithConfigsMaskOnOutput(BaseEncoderMaskerDecoder):
100
+ def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
101
+ super().__init__(encoder, masker, decoder, encoder_activation)
102
+ self.use_encoder = kwargs.get("use_encoder", True)
103
+ self.apply_mask = kwargs.get("apply_mask", True)
104
+ self.use_decoder = kwargs.get("use_decoder", True)
105
+ self.nb_channels = kwargs.get("nb_channels", 2)
106
+ self.decoder_activation = kwargs.get("decoder_activation", "sigmoid")
107
+ if self.decoder_activation == "sigmoid":
108
+ self.act_after_dec = nn.Sigmoid()
109
+ elif self.decoder_activation == "relu":
110
+ self.act_after_dec = nn.ReLU()
111
+ elif self.decoder_activation == "relu6":
112
+ self.act_after_dec = nn.ReLU6()
113
+ elif self.decoder_activation == "tanh":
114
+ self.act_after_dec = nn.Tanh()
115
+ elif self.decoder_activation == "none":
116
+ self.act_after_dec = nn.Identity()
117
+ else:
118
+ self.act_after_dec = nn.Sigmoid()
119
+
120
+ def forward(self, wav):
121
+ """
122
+ For the De-limit task, we will apply the mask on the output of the decoder.
123
+ We want the decoder to learn the sample-wise ratio of the sources.
124
+
125
+ Args:
126
+ wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.
127
+
128
+ Returns:
129
+ torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
130
+ """
131
+ # Remember shape to shape reconstruction, cast to Tensor for torchscript
132
+ shape = jitable_shape(wav)
133
+ # Reshape to (batch, n_mix, time)
134
+ wav = _unsqueeze_to_3d(wav) # (batch, n_channels, time)
135
+
136
+ # Real forward
137
+ if self.use_encoder:
138
+ tf_rep = self.forward_encoder(wav) # (batch, n_channels, freq, time)
139
+ else:
140
+ tf_rep = wav
141
+
142
+ if self.nb_channels == 2:
143
+ tf_rep = rearrange(
144
+ tf_rep, "b c f t -> b (c f) t"
145
+ ) # c == 2 when stereo input.
146
+ est_masks = self.forward_masker(tf_rep) # (batch, 1, freq, time)
147
+
148
+ # we are going to apply the mask on the output of the decoder
149
+ if self.use_decoder:
150
+ if self.nb_channels == 2:
151
+ est_masks = rearrange(est_masks, "b 1 f t -> b f t")
152
+ est_masks_decoded = self.forward_decoder(est_masks)
153
+ est_masks_decoded = pad_x_to_y(est_masks_decoded, wav) # (batch, 1, time)
154
+ est_masks_decoded = self.act_after_dec(
155
+ est_masks_decoded
156
+ ) # (batch, 1, time)
157
+ decoded = wav * est_masks_decoded # (batch, n_channels, time)
158
+
159
+ return (
160
+ est_masks_decoded,
161
+ decoded,
162
+ )
163
+
164
+ else:
165
+ decoded = est_masks
166
+
167
+ return (decoded,)
168
+
169
+
170
+ class BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid(BaseEncoderMaskerDecoder):
171
+ def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
172
+ super().__init__(encoder, masker, decoder, encoder_activation)
173
+ self.use_encoder = kwargs.get("use_encoder", True)
174
+ self.apply_mask = kwargs.get("apply_mask", True)
175
+ self.use_decoder = kwargs.get("use_decoder", True)
176
+ self.nb_channels = kwargs.get("nb_channels", 2)
177
+ self.decoder_activation = kwargs.get("decoder_activation", "none")
178
+ if self.decoder_activation == "sigmoid":
179
+ self.act_after_dec = nn.Sigmoid()
180
+ elif self.decoder_activation == "relu":
181
+ self.act_after_dec = nn.ReLU()
182
+ elif self.decoder_activation == "relu6":
183
+ self.act_after_dec = nn.ReLU6()
184
+ elif self.decoder_activation == "tanh":
185
+ self.act_after_dec = nn.Tanh()
186
+ elif self.decoder_activation == "none":
187
+ self.act_after_dec = nn.Identity()
188
+ else:
189
+ self.act_after_dec = nn.Sigmoid()
190
+
191
+ def forward(self, wav):
192
+ """
193
+ Enc/Mask/Dec model forward with some additional options.
194
+ For MultiChannel usage of asteroid-based models. (e.g. ConvTasNet)
195
+
196
+
197
+ Args:
198
+ wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.
199
+
200
+ Returns:
201
+ torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
202
+ """
203
+ # Remember shape to shape reconstruction, cast to Tensor for torchscript
204
+ shape = jitable_shape(wav)
205
+ # Reshape to (batch, n_mix, time)
206
+ wav = _unsqueeze_to_3d(wav)
207
+
208
+ # Real forward
209
+ if self.use_encoder:
210
+ tf_rep = self.forward_encoder(wav)
211
+ else:
212
+ tf_rep = wav
213
+
214
+ if self.nb_channels == 2:
215
+ tf_rep = rearrange(
216
+ tf_rep, "b c f t -> b (c f) t"
217
+ ) # c == 2 when stereo input.
218
+ est_masks = self.forward_masker(tf_rep)
219
+
220
+ if self.nb_channels == 2:
221
+ tf_rep = rearrange(tf_rep, "b (c f) t -> b c f t", c=self.nb_channels)
222
+
223
+ if self.apply_mask:
224
+ # Since original asteroid implementation of masking includes unnecessary unsqueeze operation, we will do it manually.
225
+ masked_tf_rep = est_masks * tf_rep
226
+ else:
227
+ masked_tf_rep = est_masks
228
+
229
+ if self.use_decoder:
230
+ decoded = self.forward_decoder(masked_tf_rep)
231
+ reconstructed = pad_x_to_y(decoded, wav)
232
+ reconstructed = self.act_after_dec(reconstructed)
233
+
234
+ return masked_tf_rep, _shape_reconstructed(reconstructed, shape)
235
+
236
+ else:
237
+ decoded = masked_tf_rep
238
+
239
+ return decoded
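
A numeric check of the mixture-consistency projection used in BaseEncoderMaskerDecoder_mixture_consistency above: redistributing the residual equally across sources guarantees that the corrected estimates sum back to the mixture (a sketch):

import torch

mix = torch.tensor([[1.0, 2.0]])                   # (batch=1, time=2)
est = torch.tensor([[[0.4, 0.9], [0.5, 0.8]]])     # (batch, n_src=2, time)
# est <- est + (mix - sum(est)) / n_src, matching the forward pass above
est = est + (mix.unsqueeze(1) - est.sum(dim=1, keepdim=True)) / est.shape[1]
print(est.sum(dim=1))                              # tensor([[1., 2.]])
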
models/load_models.py ADDED
@@ -0,0 +1,87 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from asteroid_filterbanks import make_enc_dec
5
+
6
+ from asteroid.masknn import TDConvNet
7
+
8
+ import utils
9
+ from .base_models import (
10
+ BaseEncoderMaskerDecoderWithConfigs,
11
+ BaseEncoderMaskerDecoderWithConfigsMaskOnOutput,
12
+ BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid,
13
+ )
14
+
15
+
16
+ def load_model_with_args(args):
17
+ if args.model_loss_params.architecture == "conv_tasnet_mask_on_output":
18
+ encoder, decoder = make_enc_dec(
19
+ "free",
20
+ n_filters=args.conv_tasnet_params.n_filters,
21
+ kernel_size=args.conv_tasnet_params.kernel_size,
22
+ stride=args.conv_tasnet_params.stride,
23
+ sample_rate=args.sample_rate,
24
+ )
25
+ masker = TDConvNet(
26
+ in_chan=encoder.n_feats_out * args.data_params.nb_channels, # stereo
27
+ n_src=1, # for de-limit task.
28
+ out_chan=encoder.n_feats_out,
29
+ n_blocks=args.conv_tasnet_params.n_blocks,
30
+ n_repeats=args.conv_tasnet_params.n_repeats,
31
+ bn_chan=args.conv_tasnet_params.bn_chan,
32
+ hid_chan=args.conv_tasnet_params.hid_chan,
33
+ skip_chan=args.conv_tasnet_params.skip_chan,
34
+ # conv_kernel_size=args.conv_tasnet_params.conv_kernel_size,
35
+ norm_type=args.conv_tasnet_params.norm_type if args.conv_tasnet_params.norm_type else 'gLN',
36
+ mask_act=args.conv_tasnet_params.mask_act,
37
+ # causal=args.conv_tasnet_params.causal,
38
+ )
39
+
40
+ model = BaseEncoderMaskerDecoderWithConfigsMaskOnOutput(
41
+ encoder,
42
+ masker,
43
+ decoder,
44
+ encoder_activation=args.conv_tasnet_params.encoder_activation,
45
+ use_encoder=True,
46
+ apply_mask=True,
47
+ use_decoder=True,
48
+ decoder_activation=args.conv_tasnet_params.decoder_activation,
49
+ )
50
+ model.use_encoder_to_target = False
51
+
52
+ elif args.model_loss_params.architecture == "conv_tasnet":
53
+ encoder, decoder = make_enc_dec(
54
+ "free",
55
+ n_filters=args.conv_tasnet_params.n_filters,
56
+ kernel_size=args.conv_tasnet_params.kernel_size,
57
+ stride=args.conv_tasnet_params.stride,
58
+ sample_rate=args.sample_rate,
59
+ )
60
+ masker = TDConvNet(
61
+ in_chan=encoder.n_feats_out * args.data_params.nb_channels, # stereo
62
+ n_src=args.conv_tasnet_params.n_src, # for de-limit task with the standard conv-tasnet setting.
63
+ out_chan=encoder.n_feats_out,
64
+ n_blocks=args.conv_tasnet_params.n_blocks,
65
+ n_repeats=args.conv_tasnet_params.n_repeats,
66
+ bn_chan=args.conv_tasnet_params.bn_chan,
67
+ hid_chan=args.conv_tasnet_params.hid_chan,
68
+ skip_chan=args.conv_tasnet_params.skip_chan,
69
+ # conv_kernel_size=args.conv_tasnet_params.conv_kernel_size,
70
+ norm_type=args.conv_tasnet_params.norm_type if args.conv_tasnet_params.norm_type else 'gLN',
71
+ mask_act=args.conv_tasnet_params.mask_act,
72
+ # causal=args.conv_tasnet_params.causal,
73
+ )
74
+
75
+ model = BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid(
76
+ encoder,
77
+ masker,
78
+ decoder,
79
+ encoder_activation=args.conv_tasnet_params.encoder_activation,
80
+ use_encoder=True,
81
+ apply_mask=False if args.conv_tasnet_params.synthesis else True,
82
+ use_decoder=True,
83
+ decoder_activation=args.conv_tasnet_params.decoder_activation,
84
+ )
85
+ model.use_encoder_to_target = False
86
+
87
+ return model
prepro/delimit_save_delimiter_stems.py ADDED
@@ -0,0 +1,93 @@
1
+ # Save per-stem outputs of the de-limiter, reconstructed via sample-wise gain inversion, for evaluation
2
+
3
+ import os
4
+ import argparse
5
+
6
+ import tqdm
7
+ import musdb
8
+ import soundfile as sf
9
+ import librosa
10
+ import pyloudnorm as pyln
11
+
12
+ from utils import db2linear, str2bool
13
+
14
+
15
+ tqdm.monitor_interval = 0
16
+
17
+
18
+ def main():
19
+ parser = argparse.ArgumentParser(description="model test.py")
20
+
21
+ parser.add_argument(
22
+ "--target",
23
+ type=str,
24
+ default="vocals",
25
+ help="target source. all, vocals, drums, bass, other",
26
+ )
27
+ parser.add_argument("--data_root", type=str, default="/path/to/musdb_XL")
28
+ parser.add_argument(
29
+ "--data_root_hq",
30
+ type=str,
31
+ default="/data1/Music/musdb18hq",
32
+ help="this is used when saving loud-norm stem of musdb-XL")
33
+ parser.add_argument(
34
+ "--output_directory",
35
+ type=str,
36
+ default="/path/to/results",
37
+ )
38
+ parser.add_argument("--exp_name", type=str, default="delimit_6_s")
39
+ parser.add_argument(
40
+ "--save_16k_mono",
41
+ type=str2bool,
42
+ default=False,
43
+ help="Save 16k mono wav files for FAD evaluation.",
44
+ )
45
+
46
+
47
+ args, _ = parser.parse_known_args()
48
+
49
+ os.makedirs(args.output_directory, exist_ok=True)
50
+
51
+ meter = pyln.Meter(44100)
52
+ args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
53
+
54
+ test_tracks = musdb.DB(root=args.data_root, subsets="test", is_wav=True)
55
+ if args.target != "mixture": # In this file, args.target should not be "mixture"
56
+ hq_tracks = musdb.DB(root=args.data_root_hq, subsets='test', is_wav=True)
57
+
58
+ for idx, track in tqdm.tqdm(enumerate(test_tracks)):
59
+ track_name = track.name
60
+ if (
61
+ os.path.basename(args.data_root) == "musdb18hq"
62
+ and track_name == "PR - Oh No"
63
+ ): # We have to consider this exception because 'PR - Oh No' mixture.wav is left-panned. We will use the linear mixture instead.
64
+ # Please refer https://github.com/jeonchangbin49/musdb-XL/blob/main/make_L_and_XL.py
65
+ track_audio = (
66
+ track.targets["vocals"].audio
67
+ + track.targets["drums"].audio
68
+ + track.targets["bass"].audio
69
+ + track.targets["other"].audio
70
+ )
71
+ else:
72
+ track_audio = track.audio
73
+
74
+ delimiter_track = librosa.load(f"{args.test_output_dir}/{track_name}/all.wav", sr=44100, mono=False)[0].T
75
+
76
+ print(track_name)
77
+
78
+ if args.target != "mixture":
79
+ hq_track = hq_tracks[idx]
80
+ hq_audio = hq_track.audio
81
+ hq_stem = hq_track.targets[args.target].audio
82
+ hq_samplewise_gain = track_audio / (hq_audio + 1e-8)
83
+ XL_stem = hq_samplewise_gain * hq_stem
84
+ XL_samplewise_gain = delimiter_track / (track_audio + 1e-8)
85
+ delimiter_stem = XL_samplewise_gain * XL_stem
86
+
87
+ sf.write(
88
+ f"{args.test_output_dir}/{track_name}/{args.target}.wav", delimiter_stem, 44100
89
+ )
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()
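
A toy check of the sample-wise gain trick used above: dividing the limited mixture by the unlimited one recovers the per-sample limiter gain, which can then be applied to any stem (the values are made up):

import numpy as np

mixture = np.array([0.5, 1.6, -1.2])                 # unlimited mixture samples
g = np.array([1.0, 0.55, 0.7])                       # hypothetical limiter gains
limited = g * mixture                                # limiter output
stem = np.array([0.2, 0.9, -0.4])                    # one source inside the mixture
limited_stem = (limited / (mixture + 1e-8)) * stem   # recovers g * stem
print(np.allclose(limited_stem, g * stem))           # True
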
prepro/delimit_save_musdb_loudnorm.py ADDED
@@ -0,0 +1,118 @@
1
+ # Save loudness-normalized (-14 LUFS) musdb-XL audio files for evaluation of the de-limiter
2
+
3
+ import os
4
+ import argparse
5
+
6
+ import tqdm
7
+ import musdb
8
+ import soundfile as sf
9
+ import librosa
10
+ import pyloudnorm as pyln
11
+
12
+ from utils import db2linear, str2bool
13
+
14
+
15
+ tqdm.monitor_interval = 0
16
+
17
+
18
+ def main():
19
+ parser = argparse.ArgumentParser(description="model test.py")
20
+
21
+ parser.add_argument(
22
+ "--target",
23
+ type=str,
24
+ default="mixture",
25
+ help="target source. all, vocals, drums, bass, other",
26
+ )
27
+ parser.add_argument("--data_root", type=str, default="/path/to/musdb_XL")
28
+ parser.add_argument(
29
+ "--data_root_hq",
30
+ type=str,
31
+ default="/path/to/musdb18hq",
32
+ help="this is used when saving loud-norm stem of musdb-XL")
33
+ parser.add_argument(
34
+ "--output_directory",
35
+ type=str,
36
+ default="/path/to/musdb_XL_loudnorm",
37
+ )
38
+ parser.add_argument(
39
+ "--loudnorm_input_lufs",
40
+ type=float,
41
+ default=-14.0,
42
+ help="If you want to use loudnorm, input target lufs",
43
+ )
44
+ parser.add_argument(
45
+ "--save_16k_mono",
46
+ type=str2bool,
47
+ default=True,
48
+ help="Save 16k mono wav files for FAD evaluation.",
49
+ )
50
+
51
+
52
+ args, _ = parser.parse_known_args()
53
+
54
+ os.makedirs(args.output_directory, exist_ok=True)
55
+
56
+ meter = pyln.Meter(44100)
57
+
58
+ test_tracks = musdb.DB(root=args.data_root, subsets="test", is_wav=True)
59
+ if args.target != "mixture":
60
+ hq_tracks = musdb.DB(root=args.data_root_hq, subsets='test', is_wav=True)
61
+
62
+ for idx, track in tqdm.tqdm(enumerate(test_tracks)):
63
+ track_name = track.name
64
+ if (
65
+ os.path.basename(args.data_root) == "musdb18hq"
66
+ and track_name == "PR - Oh No"
67
+ ): # We have to consider this exception because 'PR - Oh No' mixture.wav is left-panned. We will use the linear mixture instead.
68
+ # Please refer https://github.com/jeonchangbin49/musdb-XL/blob/main/make_L_and_XL.py
69
+ track_audio = (
70
+ track.targets["vocals"].audio
71
+ + track.targets["drums"].audio
72
+ + track.targets["bass"].audio
73
+ + track.targets["other"].audio
74
+ )
75
+ else:
76
+ track_audio = track.audio
77
+
78
+ print(track_name)
79
+
80
+ augmented_gain = None
81
+
82
+ track_lufs = meter.integrated_loudness(track_audio)
83
+ augmented_gain = args.loudnorm_input_lufs - track_lufs
84
+ if os.path.basename(args.data_root) == "musdb18hq":
85
+ if args.target != "mixture":
86
+ track_audio = track.targets[args.target].audio
87
+ track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
88
+ elif os.path.basename(args.data_root) == "musdb_XL":
89
+ track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
90
+ if args.target != "mixture":
91
+ hq_track = hq_tracks[idx]
92
+ hq_audio = hq_track.audio
93
+ hq_stem = hq_track.targets[args.target].audio
94
+ samplewise_gain = track_audio / (hq_audio + 1e-8)
95
+ track_audio = samplewise_gain * hq_stem
96
+
97
+ os.makedirs(f"{args.output_directory}/{track_name}", exist_ok=True)
98
+ sf.write(
99
+ f"{args.output_directory}/{track_name}/{args.target}.wav", track_audio, 44100
100
+ )
101
+
102
+ if args.save_16k_mono:
103
+ track_audio_16k_mono = librosa.to_mono(track_audio.T)
104
+ track_audio_16k_mono = librosa.resample(
105
+ track_audio_16k_mono,
106
+ orig_sr=44100,
107
+ target_sr=16000,
108
+ )
109
+ os.makedirs(f"{args.output_directory}_16k_mono/{track_name}", exist_ok=True)
110
+ sf.write(
111
+ f"{args.output_directory}_16k_mono/{track_name}/{args.target}.wav",
112
+ track_audio_16k_mono,
113
+ samplerate=16000,
114
+ )
115
+
116
+
117
+ if __name__ == "__main__":
118
+ main()
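The loudness normalization in this script reduces to measuring integrated loudness once and applying a single uniform gain. A minimal sketch with pyloudnorm (the signal here is illustrative noise):

    import numpy as np
    import pyloudnorm as pyln

    rate = 44100
    audio = 0.1 * np.random.randn(rate * 5, 2)  # 5 s of stereo noise, shape (samples, channels)

    meter = pyln.Meter(rate)
    lufs = meter.integrated_loudness(audio)
    gain_db = -14.0 - lufs                       # gain needed to reach -14 LUFS
    normalized = audio * 10 ** (gain_db / 20.0)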
prepro/delimit_train_ozone_prepro.py ADDED
@@ -0,0 +1,293 @@
+ import os
+ import json
+ import csv
+ import glob
+ import argparse
+ import random
+ import math
+
+ import librosa
+ import soundfile as sf
+ import pedalboard
+ import numpy as np
+ import pyloudnorm as pyln
+ from scipy.stats import gamma
+ import torchaudio
+
+
+ def str2bool(v):
+     if v.lower() in ("yes", "true", "t", "y", "1"):
+         return True
+     elif v.lower() in ("no", "false", "f", "n", "0"):
+         return False
+     else:
+         raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+ def _augment_gain_ozone(audio, low=0.25, high=1.25):
+     """Applies a random gain between `low` and `high`"""
+     g = low + random.random() * (high - low)
+     return audio * g, g
+
+
+ def _augment_channelswap_ozone(audio):
+     """Swap channels of stereo signals with a probability of p=0.5"""
+     if audio.shape[0] == 2 and random.random() < 0.5:
+         return np.flip(audio, axis=0), True  # axis=0 must be given
+     else:
+         return audio, False
+
+
+ # load wav file from arbitrary positions of 16bit stereo wav file
+ def load_wav_arbitrary_position_stereo(
+     filename, sample_rate, seq_duration, return_pos=False
+ ):
+     # stereo
+     # seq_duration[second]
+     length = torchaudio.info(filename).num_frames
+
+     random_start = random.randint(
+         0, int(length - math.ceil(seq_duration * sample_rate) - 1)
+     )
+     random_start_sec = librosa.samples_to_time(random_start, sr=sample_rate)
+     X, sr = librosa.load(
+         filename, sr=None, mono=False, offset=random_start_sec, duration=seq_duration
+     )
+
+     if return_pos:
+         return X, random_start_sec
+     else:
+         return X
+
+
+ # def main():
+ parser = argparse.ArgumentParser(description="Preprocess audio files for training")
+ parser.add_argument(
+     "--root",
+     type=str,
+     default="/path/to/musdb18hq",
+     help="Root directory",
+ )
+ parser.add_argument(
+     "--output",
+     type=str,
+     default="/path/to/musdb-XL-train",
+     help="Where to save output files",
+ )
+ parser.add_argument(
+     "--n_samples", type=int, default=300000, help="Number of samples to save"
+ )
+ parser.add_argument("--seq_duration", type=float, default=4.0, help="Sequence duration")
+ parser.add_argument(
+     "--save_fixed", type=str2bool, default=False, help="Save fixed mixture audio"
+ )
+ parser.add_argument(
+     "--target_lufs_mean", type=float, default=-8.0, help="Target LUFS mean"
+ )
+ parser.add_argument(
+     "--target_lufs_std", type=float, default=-1.0, help="Target LUFS std"
+ )
+ parser.add_argument("--sample_rate", type=int, default=44100, help="Sample rate")
+ parser.add_argument("--seed", type=int, default=46, help="Random seed")
+ args = parser.parse_args()
+ random.seed(args.seed)
+
+ valid_list = [
+     "ANiMAL - Rockshow",
+     "Actions - One Minute Smile",
+     "Alexander Ross - Goodbye Bolero",
+     "Clara Berry And Wooldog - Waltz For My Victims",
+     "Fergessen - Nos Palpitants",
+     "James May - On The Line",
+     "Johnny Lokke - Promises & Lies",
+     "Leaf - Summerghost",
+     "Meaxic - Take A Step",
+     "Patrick Talbot - A Reason To Leave",
+     "Skelpolu - Human Mistakes",
+     "Traffic Experiment - Sirens",
+     "Triviul - Angelsaint",
+     "Young Griffo - Pennies",
+ ]
+
+ meter = pyln.Meter(args.sample_rate)
+
+ sources = ["vocals", "bass", "drums", "other"]
+ song_list = glob.glob(f"{args.root}/train/*")
+
+ vst = pedalboard.load_plugin(
+     "/Library/Audio/Plug-Ins/Components/iZOzone9ElementsAUHook.component"
+ )
+
+ if args.save_fixed:
+     vst_params = []
+
+     os.makedirs(f"{args.output}/ozone_train_fixed", exist_ok=True)
+
+     for song in song_list:
+         print(f"Processing {song}...")
+         song_name = os.path.basename(song)
+         audio_sources = []
+         for source in sources:
+             audio_path = f"{song}/{source}.wav"
+             audio, sr = librosa.load(audio_path, sr=args.sample_rate, mono=False)
+             audio_sources.append(audio)
+         stems = np.stack(audio_sources, axis=0)
+         mixture = stems.sum(0)
+         lufs = meter.integrated_loudness(mixture.T)
+         target_lufs = random.gauss(args.target_lufs_mean, args.target_lufs_std)
+         adjusted_loudness = target_lufs - lufs
+
+         vst.reset()
+         vst.eq_bypass = True
+         vst.img_bypass = True
+         vst.max_mode = 1.0  # Set IRC2 mode
+         vst.max_threshold = min(-adjusted_loudness, 0.0)
+         vst.max_character = min(gamma.rvs(2), 10.0)
+
+         print(
+             f"Applying Ozone 9 Elements IRC2 with threshold {vst.max_threshold} and character {vst.max_character}..."
+         )
+         limited_mixture = vst(mixture, args.sample_rate)
+
+         sf.write(
+             f"{args.output}/ozone_train_fixed/{song_name}.wav",
+             limited_mixture.T,
+             args.sample_rate,
+         )
+         vst_params.append([song_name, vst.max_threshold, vst.max_character])
+     # Save the song name and vst parameters (vst.max_threshold and vst.max_character) to a csv file
+     with open(f"{args.output}/ozone_train_fixed.csv", "w") as f:
+         writer = csv.writer(f)
+         writer.writerow(["song_name", "max_threshold", "max_character"])
+         for list_vst_param in vst_params:
+             writer.writerow(list_vst_param)
+
+ else:
+     if os.path.exists(f"{args.output}/ozone_train_random_0.csv"):
+         vst_params = []
+         list_csv_files = glob.glob(f"{args.output}/ozone_train_random_*.csv")
+         list_csv_files.sort()
+         for csv_file in list_csv_files:
+             with open(csv_file, "r") as f:
+                 reader = csv.reader(f)
+                 next(reader)
+                 vst_params.extend([row for row in reader])
+
+     else:
+         vst_params = []
+
+     song_list = [x for x in song_list if os.path.basename(x) not in valid_list]
+
+     os.makedirs(f"{args.output}/ozone_train_random", exist_ok=True)
+
+     for n in range(len(vst_params), args.n_samples):
+         print(f"Processing {n} / {args.n_samples}...")
+         seg_name = f"ozone_seg_{n}"
+
+         lufs_is_inf = True
+         while lufs_is_inf:
+             audio_sources = []
+             source_song_names = {}
+             source_start_secs = {}
+             source_gains = {}
+             source_channelswaps = {}
+             for source in sources:
+                 track_path = random.choice(song_list)
+                 song_name = os.path.basename(track_path)
+                 audio_path = f"{track_path}/{source}.wav"
+                 audio, start_sec = load_wav_arbitrary_position_stereo(
+                     audio_path, args.sample_rate, args.seq_duration, return_pos=True
+                 )
+                 audio, gain = _augment_gain_ozone(audio)
+                 audio, channelswap = _augment_channelswap_ozone(audio)
+                 audio_sources.append(audio)
+                 source_song_names[source] = song_name
+                 source_start_secs[source] = start_sec
+                 source_gains[source] = gain
+                 source_channelswaps[source] = channelswap
+
+             stems = np.stack(audio_sources, axis=0)
+             mixture = stems.sum(0)
+             lufs = meter.integrated_loudness(mixture.T)
+
+             # if lufs is inf, the mixture is silent, so we need to generate a new mixture
+             lufs_is_inf = np.isinf(lufs)
+
+         target_lufs = random.gauss(args.target_lufs_mean, args.target_lufs_std)
+         adjusted_loudness = target_lufs - lufs
+
+         vst.reset()
+         vst.eq_bypass = True
+         vst.img_bypass = True
+         vst.max_mode = 1.0  # Set IRC2 mode
+         vst.max_threshold = min(max(-20, -adjusted_loudness), 0.0)
+         vst.max_character = min(gamma.rvs(2), 10.0)
+
+         print(
+             f"Applying Ozone 9 Elements IRC2 with threshold {vst.max_threshold} and character {vst.max_character}..."
+         )
+         limited_mixture = vst(mixture, args.sample_rate)
+
+         sf.write(
+             f"{args.output}/ozone_train_random/{seg_name}.wav",
+             limited_mixture.T,
+             args.sample_rate,
+         )
+         vst_params.append(
+             [
+                 seg_name,
+                 vst.max_threshold,
+                 vst.max_character,
+                 source_song_names["vocals"],
+                 source_start_secs["vocals"],
+                 source_gains["vocals"],
+                 source_channelswaps["vocals"],
+                 source_song_names["bass"],
+                 source_start_secs["bass"],
+                 source_gains["bass"],
+                 source_channelswaps["bass"],
+                 source_song_names["drums"],
+                 source_start_secs["drums"],
+                 source_gains["drums"],
+                 source_channelswaps["drums"],
+                 source_song_names["other"],
+                 source_start_secs["other"],
+                 source_gains["other"],
+                 source_channelswaps["other"],
+             ]
+         )
+
+         if (n + 1) % 20000 == 0 or n == args.n_samples - 1:
+             # We will separate the csv file into multiple files to avoid memory errors
+             # Save the song name and vst parameters (vst.max_threshold and vst.max_character) to a csv file
+             number = int(n // 20000)
+             with open(f"{args.output}/ozone_train_random_{number}.csv", "w") as f:
+                 writer = csv.writer(f)
+                 writer.writerow(
+                     [
+                         "song_name",
+                         "max_threshold",
+                         "max_character",
+                         "vocals_name",
+                         "vocals_start_sec",
+                         "vocals_gain",
+                         "vocals_channelswap",
+                         "bass_name",
+                         "bass_start_sec",
+                         "bass_gain",
+                         "bass_channelswap",
+                         "drums_name",
+                         "drums_start_sec",
+                         "drums_gain",
+                         "drums_channelswap",
+                         "other_name",
+                         "other_start_sec",
+                         "other_gain",
+                         "other_channelswap",
+                     ]
+                 )
+                 for list_vst_param in vst_params[number * 20000 : (number + 1) * 20000]:
+                     writer.writerow(list_vst_param)
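Since every random segment is fully described by a CSV row (source songs, start positions, gains, channel swaps), the unlimited target mixture can be re-synthesized on the fly at training time instead of being stored. A sketch of such a reader; the helper name and dict-style row access are illustrative, not part of the repository:

    import librosa
    import numpy as np

    def rebuild_target_mixture(root, row, sample_rate=44100, seq_duration=4.0):
        mixture = None
        for source in ["vocals", "bass", "drums", "other"]:
            audio, _ = librosa.load(
                f"{root}/train/{row[source + '_name']}/{source}.wav",
                sr=sample_rate,
                mono=False,
                offset=float(row[source + "_start_sec"]),
                duration=seq_duration,
            )
            audio = audio * float(row[source + "_gain"])
            if row[source + "_channelswap"] == "True":  # csv stores booleans as strings
                audio = np.flip(audio, axis=0)
            mixture = audio if mixture is None else mixture + audio
        return mixture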
prepro/delimit_valid_L_prepro.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ import json
+
+ from torch.utils.data import DataLoader
+ import soundfile as sf
+ import tqdm
+
+ from dataloader import DelimitValidDataset
+
+
+ def main():
+     # Parameters
+     data_path = "/path/to/musdb18hq"
+     save_path = "/path/to/musdb18hq_limited_L"
+     batch_size = 1
+     num_workers = 1
+     sr = 44100
+
+     # Dataset
+     dataset = DelimitValidDataset(root=data_path, valid_target_lufs=-14.39)
+     data_loader = DataLoader(
+         dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
+     )
+     dict_valid_loudness = {}
+     # Preprocessing
+     for limited_audio, orig_audio, audio_name, loudness in tqdm.tqdm(data_loader):
+         audio_name = audio_name[0]
+         limited_audio = limited_audio[0].numpy()
+         loudness = float(loudness[0].numpy())
+         dict_valid_loudness[audio_name] = loudness
+         # Save audio
+         os.makedirs(os.path.join(save_path, "valid"), exist_ok=True)
+         audio_path = os.path.join(save_path, "valid", audio_name)
+         sf.write(f"{audio_path}.wav", limited_audio.T, sr)
+     # Write the measured loudness values to a json file
+     with open(os.path.join(save_path, "valid_loudness.json"), "w") as f:
+         json.dump(dict_valid_loudness, f, indent=4)
+
+
+ if __name__ == "__main__":
+     main()
prepro/delimit_valid_custom_limiter_prepro.py ADDED
@@ -0,0 +1,59 @@
+ import os
+ import json
+
+ from torch.utils.data import DataLoader
+ import soundfile as sf
+ import tqdm
+
+ from dataloader import DelimitValidDataset
+
+
+ def main():
+     # Parameters
+     data_path = "/path/to/musdb18hq"
+     save_path = "/path/to/musdb18hq_custom_limiter_fixed_attack"
+     batch_size = 1
+     num_workers = 1
+     sr = 44100
+
+     # Dataset
+     dataset = DelimitValidDataset(
+         root=data_path, use_custom_limiter=True, custom_limiter_attack_range=[2.0, 2.0]
+     )
+     data_loader = DataLoader(
+         dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
+     )
+     dict_valid_loudness = {}
+     dict_limiter_params = {}
+     # Preprocessing
+     for (
+         limited_audio,
+         orig_audio,
+         audio_name,
+         loudness,
+         custom_attack,
+         custom_release,
+     ) in tqdm.tqdm(data_loader):
+         audio_name = audio_name[0]
+         limited_audio = limited_audio[0].numpy()
+         loudness = float(loudness[0].numpy())
+         dict_valid_loudness[audio_name] = loudness
+         dict_limiter_params[audio_name] = {
+             "attack_ms": float(custom_attack[0].numpy()),
+             "release_ms": float(custom_release[0].numpy()),
+         }
+         # Save audio
+         os.makedirs(os.path.join(save_path, "valid"), exist_ok=True)
+         audio_path = os.path.join(save_path, "valid", audio_name)
+         sf.write(f"{audio_path}.wav", limited_audio.T, sr)
+     # Write the measured loudness values and limiter parameters to json files
+     with open(os.path.join(save_path, "valid_loudness.json"), "w") as f:
+         json.dump(dict_valid_loudness, f, indent=4)
+     with open(os.path.join(save_path, "valid_limiter_params.json"), "w") as f:
+         json.dump(dict_limiter_params, f, indent=4)
+
+
+ if __name__ == "__main__":
+     main()
prepro/delimit_valid_prepro.py ADDED
@@ -0,0 +1,41 @@
+ import os
+ import json
+
+ from torch.utils.data import DataLoader
+ import soundfile as sf
+ import tqdm
+
+ from dataloader import DelimitValidDataset
+
+
+ def main():
+     # Parameters
+     data_path = "/path/to/musdb18hq"
+     save_path = "/path/to/musdb18hq_limited"
+     batch_size = 1
+     num_workers = 1
+     sr = 44100
+
+     # Dataset
+     dataset = DelimitValidDataset(root=data_path)
+     data_loader = DataLoader(
+         dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
+     )
+     dict_valid_loudness = {}
+     # Preprocessing
+     for limited_audio, orig_audio, audio_name, loudness in tqdm.tqdm(data_loader):
+         audio_name = audio_name[0]
+         limited_audio = limited_audio[0].numpy()
+         loudness = float(loudness[0].numpy())
+         dict_valid_loudness[audio_name] = loudness
+         # Save audio
+         os.makedirs(os.path.join(save_path, "valid"), exist_ok=True)
+         audio_path = os.path.join(save_path, "valid", audio_name)
+         sf.write(f"{audio_path}.wav", limited_audio.T, sr)
+     # Write the measured loudness values to a json file
+     with open(os.path.join(save_path, "valid_loudness.json"), "w") as f:
+         json.dump(dict_valid_loudness, f, indent=4)
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ git+https://github.com/asteroid-team/asteroid.git@master
+ numpy
+ librosa
+ soundfile
+ torch
+ torchaudio
+ matplotlib
+ wandb
+ musdb
+ dotmap
+ ema-pytorch
+ pedalboard
+ einops
separate_func/__init__.py ADDED
@@ -0,0 +1 @@
+ from .conv_tasnet_separate import conv_tasnet_separate
separate_func/conv_tasnet_separate.py ADDED
@@ -0,0 +1,89 @@
+ import os
+
+ import soundfile as sf
+ import torch
+ import pyloudnorm as pyln
+ import librosa
+ import matplotlib
+ import matplotlib.pyplot as plt
+
+ from dataloader import SingleTrackSet
+ from utils import db2linear
+
+
+ def conv_tasnet_separate(
+     args, our_model, device, track_audio, track_name, meter=None, augmented_gain=None
+ ):
+
+     if args.use_singletrackset:
+         db = SingleTrackSet(
+             track_audio.squeeze(dim=0),
+             hop_length=args.data_params.nhop,
+             num_frame=128,
+             target_name=args.target,
+         )
+         separated = []
+
+         for item in db:
+             item = item.unsqueeze(0).to(device)
+             estimates, *estimates_vars = our_model(item)
+             if args.task_params.dataset == "delimit":
+                 estimates = estimates_vars[0]
+
+             estimates = estimates.cpu().detach()
+             separated.append(
+                 estimates[..., db.trim_length : -db.trim_length].clone()
+             )
+
+         estimates = torch.cat(separated, dim=-1)
+         estimates = estimates[0, :, : track_audio.shape[-1]].numpy()
+     else:
+         estimates, *estimates_vars = our_model(track_audio)
+         if args.save_histogram and args.task_params.dataset == "delimit":
+             plt.figure(figsize=(10, 10))
+             plt.hist(estimates.cpu().detach().numpy().flatten(), bins=100)
+             os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
+             plt.savefig(
+                 f"{args.test_output_dir}/{track_name}/{args.target}_histogram.png"
+             )
+         if args.task_params.dataset == "delimit":
+             estimates = estimates_vars[0]
+
+         estimates = estimates.cpu().detach().numpy()
+         estimates = estimates[0, :, : track_audio.shape[-1]]
+
+     if args.save_name_as_target:
+         os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
+
+     if args.save_output_loudnorm:
+         print("Saving loudness-normalized output.")
+         loudness = meter.integrated_loudness(estimates.T)
+         estimates = estimates * db2linear(args.save_output_loudnorm - loudness, eps=0.0)
+     elif augmented_gain is not None and args.save_output_loudnorm is None:
+         estimates = estimates * db2linear(-augmented_gain, eps=0.0)
+
+     sf.write(
+         f"{args.test_output_dir}/{track_name}/{args.target}.wav"
+         if args.save_name_as_target
+         else f"{args.test_output_dir}/{track_name}.wav",
+         estimates.T,
+         samplerate=args.data_params.sample_rate,
+     )
+
+     if args.save_16k_mono:
+         estimates_16k_mono = librosa.to_mono(estimates)
+         estimates_16k_mono = librosa.resample(
+             estimates_16k_mono,
+             orig_sr=args.data_params.sample_rate,
+             target_sr=16000,
+         )
+         os.makedirs(f"{args.test_output_dir}_16k_mono/{track_name}", exist_ok=True)
+         sf.write(
+             f"{args.test_output_dir}_16k_mono/{track_name}/{args.target}.wav"
+             if args.save_name_as_target
+             else f"{args.test_output_dir}_16k_mono/{track_name}.wav",
+             estimates_16k_mono,
+             samplerate=16000,
+         )
+
+     return estimates
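When a track is too long for one forward pass, the SingleTrackSet path above runs overlapping windows and trims `trim_length` samples of padded context from each side before concatenating. A self-contained sketch of that trim-and-concatenate pattern (helper and parameter values are illustrative; the repository's SingleTrackSet handles the windowing itself):

    import torch
    import torch.nn.functional as F

    def chunked_forward(model, audio, chunk_len=44100 * 4, trim=4410):
        # audio: (channels, samples); each window carries `trim` samples of
        # extra context on both sides, which is cut away again after the model.
        hop = chunk_len - 2 * trim
        padded = F.pad(audio, (trim, chunk_len))
        outputs = []
        for start in range(0, audio.shape[-1], hop):
            window = padded[..., start : start + chunk_len]
            with torch.no_grad():
                out = model(window.unsqueeze(0))[0]
            outputs.append(out[..., trim : chunk_len - trim])
        return torch.cat(outputs, dim=-1)[..., : audio.shape[-1]]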
solver_ddp.py ADDED
@@ -0,0 +1,643 @@
+ import time
+ import json
+
+ import torch
+ import torch.nn as nn
+ import wandb
+ import matplotlib
+
+ matplotlib.use("Agg")
+ import matplotlib.pyplot as plt
+ import torch.distributed as dist
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+ from asteroid.losses import (
+     pairwise_neg_sisdr,
+     PairwiseNegSDR,
+ )
+ from einops import rearrange, reduce
+ from ema_pytorch import EMA
+
+ from models import load_model_with_args
+ import utils
+ from dataloader import (
+     MusdbTrainDataset,
+     MusdbValidDataset,
+     DelimitTrainDataset,
+     DelimitValidDataset,
+     OzoneTrainDataset,
+     OzoneValidDataset,
+     aug_from_str,
+     SingleTrackSet,
+ )
+
+
+ class Solver(object):
+     def __init__(self):
+         pass
+
+     def set_gpu(self, args):
+
+         if args.wandb_params.use_wandb and args.gpu == 0:
+             if args.wandb_params.sweep:
+                 wandb.init(
+                     entity=args.wandb_params.entity,
+                     project=args.wandb_params.project,
+                     config=args,
+                     resume=True
+                     if args.dir_params.resume is not None and args.gpu == 0
+                     else False,
+                 )
+             else:
+                 wandb.init(
+                     entity=args.wandb_params.entity,
+                     project=args.wandb_params.project,
+                     name=f"{args.dir_params.exp_name}",
+                     config=args,
+                     resume="must"
+                     if args.dir_params.resume is not None
+                     and not args.dir_params.continual_train
+                     else False,
+                     id=args.wandb_params.rerun_id
+                     if args.wandb_params.rerun_id
+                     else None,
+                     settings=wandb.Settings(start_method="fork"),
+                 )
+
+         ###################### Define Models ######################
+         self.model = load_model_with_args(args)
+
+         trainable_params = list(self.model.parameters())
+
+         if args.hyperparams.optimizer == "sgd":
+             print("Use SGD optimizer.")
+             self.optimizer = torch.optim.SGD(
+                 params=trainable_params,
+                 lr=args.hyperparams.lr,
+                 momentum=0.9,
+                 weight_decay=args.hyperparams.weight_decay,
+             )
+         elif args.hyperparams.optimizer == "adamw":
+             print("Use AdamW optimizer.")
+             self.optimizer = torch.optim.AdamW(
+                 params=trainable_params,
+                 lr=args.hyperparams.lr,
+                 betas=(0.9, 0.999),
+                 amsgrad=False,
+                 weight_decay=args.hyperparams.weight_decay,
+             )
+         elif args.hyperparams.optimizer == "radam":
+             print("Use RAdam optimizer.")
+             self.optimizer = torch.optim.RAdam(
+                 params=trainable_params,
+                 lr=args.hyperparams.lr,
+                 betas=(0.9, 0.999),
+                 eps=1e-08,
+                 weight_decay=args.hyperparams.weight_decay,
+             )
+         elif args.hyperparams.optimizer == "adam":
+             print("Use Adam optimizer.")
+             self.optimizer = torch.optim.Adam(
+                 params=trainable_params,
+                 lr=args.hyperparams.lr,
+                 betas=(0.9, 0.999),
+                 weight_decay=args.hyperparams.weight_decay,
+             )
+         else:
+             print("no optimizer loaded")
+             raise NotImplementedError
+
+         if args.hyperparams.lr_scheduler == "step_lr":
+             if args.model_loss_params.architecture == "umx":
+                 self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                     self.optimizer,
+                     mode="min",
+                     factor=args.hyperparams.lr_decay_gamma,
+                     patience=args.hyperparams.lr_decay_patience,
+                     cooldown=10,
+                     verbose=True,
+                 )
+             else:
+                 self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                     self.optimizer,
+                     mode="min",
+                     factor=args.hyperparams.lr_decay_gamma,
+                     patience=args.hyperparams.lr_decay_patience,
+                     cooldown=0,
+                     min_lr=5e-5,
+                     verbose=True,
+                 )
+         elif args.hyperparams.lr_scheduler == "cos_warmup":
+             self.scheduler = utils.CosineAnnealingWarmUpRestarts(
+                 self.optimizer,
+                 T_0=40,
+                 T_mult=1,
+                 eta_max=args.hyperparams.lr,
+                 T_up=10,
+                 gamma=0.5,
+             )
+
+         torch.cuda.set_device(args.gpu)
+
+         self.model = self.model.to(f"cuda:{args.gpu}")
+
+         ############################################################
+         # Define Losses
+         self.criterion = {}
+
+         self.criterion["l1"] = nn.L1Loss().to(args.gpu)
+         self.criterion["mse"] = nn.MSELoss().to(args.gpu)
+         self.criterion["si_sdr"] = pairwise_neg_sisdr.to(args.gpu)
+         self.criterion["snr"] = PairwiseNegSDR("snr").to(args.gpu)
+         self.criterion["bcewithlogits"] = nn.BCEWithLogitsLoss().to(args.gpu)
+         self.criterion["bce"] = nn.BCELoss().to(args.gpu)
+         self.criterion["kl"] = nn.KLDivLoss(log_target=True).to(args.gpu)
+
+         print("Loss functions we use in this training:")
+         print(args.model_loss_params.train_loss_func)
+
+         # Early stopping utils
+         self.es = utils.EarlyStopping(patience=args.hyperparams.patience)
+         self.stop = False
+
+         if args.wandb_params.use_wandb and args.gpu == 0:
+             wandb.watch(self.model, log="all")
+
+         self.start_epoch = 1
+         self.train_losses = []
+         self.valid_losses = []
+         self.train_times = []
+         self.best_epoch = 0
+
+         if args.dir_params.resume and not args.hyperparams.ema:
+             self.resume(args)
+
+         # Distribute models to machine
+         self.model = DDP(
+             self.model,
+             device_ids=[args.gpu],
+             output_device=args.gpu,
+             find_unused_parameters=True,
+         )
+
+         if args.hyperparams.ema:
+             self.model_ema = EMA(
+                 self.model,
+                 beta=0.999,
+                 update_after_step=100,
+                 update_every=10,
+             )
+
+         if args.dir_params.resume and args.hyperparams.ema:
+             self.resume(args)
+
+         ###################### Define data pipeline ######################
+         args.hyperparams.batch_size = int(
+             args.hyperparams.batch_size / args.ngpus_per_node
+         )
+         self.mp_context = torch.multiprocessing.get_context("fork")
+
+         if args.task_params.dataset == "musdb":
+             self.train_dataset = MusdbTrainDataset(
+                 target=args.task_params.target,
+                 root=args.dir_params.root,
+                 seq_duration=args.data_params.seq_dur,
+                 samples_per_track=args.data_params.samples_per_track,
+                 source_augmentations=aug_from_str(
+                     ["gain", "channelswap"],
+                 ),
+                 sample_rate=args.data_params.sample_rate,
+                 seed=args.sys_params.seed,
+                 limitaug_method=args.data_params.limitaug_method,
+                 limitaug_mode=args.data_params.limitaug_mode,
+                 limitaug_custom_target_lufs=args.data_params.limitaug_custom_target_lufs,
+                 limitaug_custom_target_lufs_std=args.data_params.limitaug_custom_target_lufs_std,
+                 target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
+                 custom_limiter_attack_range=args.data_params.custom_limiter_attack_range,
+                 custom_limiter_release_range=args.data_params.custom_limiter_release_range,
+             )
+             self.valid_dataset = MusdbValidDataset(
+                 target=args.task_params.target, root=args.dir_params.root
+             )
+         elif args.task_params.dataset == "delimit":
+             if args.data_params.limitaug_method == "ozone":
+                 self.train_dataset = OzoneTrainDataset(
+                     target=args.task_params.target,
+                     root=args.dir_params.root,
+                     ozone_root=args.dir_params.ozone_root,
+                     use_fixed=args.data_params.use_fixed,
+                     seq_duration=args.data_params.seq_dur,
+                     samples_per_track=args.data_params.samples_per_track,
+                     source_augmentations=aug_from_str(
+                         ["gain", "channelswap"],
+                     ),
+                     sample_rate=args.data_params.sample_rate,
+                     seed=args.sys_params.seed,
+                     limitaug_method=args.data_params.limitaug_method,
+                     limitaug_mode=args.data_params.limitaug_mode,
+                     limitaug_custom_target_lufs=args.data_params.limitaug_custom_target_lufs,
+                     limitaug_custom_target_lufs_std=args.data_params.limitaug_custom_target_lufs_std,
+                     target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
+                     target_limitaug_mode=args.data_params.target_limitaug_mode,
+                     target_limitaug_custom_target_lufs=args.data_params.target_limitaug_custom_target_lufs,
+                     target_limitaug_custom_target_lufs_std=args.data_params.target_limitaug_custom_target_lufs_std,
+                     custom_limiter_attack_range=args.data_params.custom_limiter_attack_range,
+                     custom_limiter_release_range=args.data_params.custom_limiter_release_range,
+                 )
+                 self.valid_dataset = OzoneValidDataset(
+                     target=args.task_params.target,
+                     root=args.dir_params.root,
+                     ozone_root=args.dir_params.ozone_root,
+                     target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
+                 )
+             else:
+                 self.train_dataset = DelimitTrainDataset(
+                     target=args.task_params.target,
+                     root=args.dir_params.root,
+                     seq_duration=args.data_params.seq_dur,
+                     samples_per_track=args.data_params.samples_per_track,
+                     source_augmentations=aug_from_str(
+                         ["gain", "channelswap"],
+                     ),
+                     sample_rate=args.data_params.sample_rate,
+                     seed=args.sys_params.seed,
+                     limitaug_method=args.data_params.limitaug_method,
+                     limitaug_mode=args.data_params.limitaug_mode,
+                     limitaug_custom_target_lufs=args.data_params.limitaug_custom_target_lufs,
+                     limitaug_custom_target_lufs_std=args.data_params.limitaug_custom_target_lufs_std,
+                     target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
+                     target_limitaug_mode=args.data_params.target_limitaug_mode,
+                     target_limitaug_custom_target_lufs=args.data_params.target_limitaug_custom_target_lufs,
+                     target_limitaug_custom_target_lufs_std=args.data_params.target_limitaug_custom_target_lufs_std,
+                     custom_limiter_attack_range=args.data_params.custom_limiter_attack_range,
+                     custom_limiter_release_range=args.data_params.custom_limiter_release_range,
+                 )
+                 self.valid_dataset = DelimitValidDataset(
+                     target=args.task_params.target,
+                     root=args.dir_params.root,
+                     delimit_valid_root=args.dir_params.delimit_valid_root,
+                     valid_target_lufs=args.data_params.valid_target_lufs,
+                     target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
+                     delimit_valid_L_root=args.dir_params.delimit_valid_L_root,
+                 )
+
+         self.train_sampler = DistributedSampler(
+             self.train_dataset, shuffle=True, rank=args.gpu
+         )
+         self.train_loader = torch.utils.data.DataLoader(
+             self.train_dataset,
+             batch_size=args.hyperparams.batch_size,
+             shuffle=False,
+             num_workers=args.sys_params.nb_workers,
+             multiprocessing_context=self.mp_context,
+             pin_memory=True,
+             sampler=self.train_sampler,
+             drop_last=False,
+         )
+
+         self.valid_sampler = DistributedSampler(
+             self.valid_dataset, shuffle=False, rank=args.gpu
+         )
+         self.valid_loader = torch.utils.data.DataLoader(
+             self.valid_dataset,
+             batch_size=1,
+             shuffle=False,
+             num_workers=args.sys_params.nb_workers,
+             multiprocessing_context=self.mp_context,
+             pin_memory=False,
+             sampler=self.valid_sampler,
+             drop_last=False,
+         )
+
+     def train(self, args, epoch):
+         self.end = time.time()
+         self.model.train()
+
+         # get current learning rate
+         for param_group in self.optimizer.param_groups:
+             current_lr = param_group["lr"]
+
+         if (
+             args.sys_params.rank % args.ngpus_per_node == 0
+         ):  # when the last rank process is finished
+             print(f"Epoch {epoch}, Learning rate: {current_lr}")
+
+         losses = utils.AverageMeter()
+         loss_logger = {}
+
+         loss_logger["train/train loss"] = 0
+         # with torch.autograd.detect_anomaly():  # use this if you want to detect anomalous behavior while training.
+         for i, values in enumerate(self.train_loader):
+             mixture, clean, *train_vars = values
+
+             mixture = mixture.cuda(args.gpu, non_blocking=True)
+             clean = clean.cuda(args.gpu, non_blocking=True)
+             target = clean  # target_shape = [batch_size, n_srcs, nb_channels (if stereo: 2), wave_length]
+
+             estimates, *estimates_vars = self.model(mixture)
+
+             dict_loss = {}
+
+             if args.task_params.dataset == "delimit":
+                 estimates = estimates_vars[0]
+
+             for train_loss_idx, single_train_loss_func in enumerate(
+                 args.model_loss_params.train_loss_func
+             ):
+                 if self.model.module.use_encoder_to_target:
+                     target_spec = self.model.module.encoder(
+                         rearrange(target, "b s c t -> (b s) c t")
+                     )
+                     target_spec = rearrange(
+                         target_spec,
+                         "(b s) c f t -> b s c f t",
+                         s=args.task_params.bleeding_nsrcs,
+                     )
+                 loss_else = self.criterion[single_train_loss_func](
+                     estimates,
+                     target_spec
+                     if self.model.module.use_encoder_to_target
+                     else target,
+                 )
+                 dict_loss[single_train_loss_func] = (
+                     loss_else.mean()
+                     * args.model_loss_params.train_loss_scales[train_loss_idx]
+                 )
+
+             loss = sum([value for key, value in dict_loss.items()])
+
+             ############################################################
+
+             #################### 5. Back propagation ####################
+             loss.backward()
+             if args.hyperparams.gradient_clip:
+                 nn.utils.clip_grad_norm_(
+                     self.model.parameters(), max_norm=args.hyperparams.gradient_clip
+                 )
+
+             losses.update(loss.item(), clean.size(0))
+
+             loss_logger["train/train loss"] = losses.avg
+             for key, value in dict_loss.items():
+                 loss_logger[f"train/{key}"] = value.item()
+
+             self.optimizer.step()
+
+             self.model.zero_grad(
+                 set_to_none=True
+             )  # set_to_none=True is for memory saving
+
+             if args.hyperparams.ema:
+                 self.model_ema.update()
+             ############################################################
+
+             # ###################### 6. Plot ######################
+             if i % 30 == 0:
+                 # loss print for multiple loss functions
+                 multiple_score = torch.Tensor(
+                     [value for key, value in loss_logger.items()]
+                 ).to(args.gpu)
+                 gathered_score_list = [
+                     torch.ones_like(multiple_score)
+                     for _ in range(dist.get_world_size())
+                 ]
+                 dist.all_gather(gathered_score_list, multiple_score)
+                 gathered_score = torch.mean(
+                     torch.stack(gathered_score_list, dim=0), dim=0
+                 )
+                 if args.gpu == 0:
+                     print(f"Epoch {epoch}, step {i} / {len(self.train_loader)}")
+                     temp_loss_logger = {}
+                     for index, (key, value) in enumerate(loss_logger.items()):
+                         temp_key = key.replace("train/", "iter-wise/")
+                         temp_loss_logger[temp_key] = round(
+                             gathered_score[index].item(), 6
+                         )
+                         print(f"{key} : {round(gathered_score[index].item(), 6)}")
+
+         single_score = torch.Tensor([losses.avg]).to(args.gpu)
+
+         gathered_score_list = [
+             torch.ones_like(single_score) for _ in range(dist.get_world_size())
+         ]
+         dist.all_gather(gathered_score_list, single_score)
+         gathered_score = torch.mean(torch.cat(gathered_score_list)).item()
+         if args.gpu == 0:
+             self.train_losses.append(gathered_score)
+             if args.wandb_params.use_wandb:
+                 loss_logger["train/train loss"] = single_score
+                 loss_logger["train/epoch"] = epoch
+                 wandb.log(loss_logger)
+         ############################################################
+
+     def multi_validate(self, args, epoch):
+         if args.gpu == 0:
+             print(f"Epoch {epoch} Validation session!")
+
+         losses = utils.AverageMeter()
+
+         loss_logger = {}
+
+         self.model.eval()
+
+         with torch.no_grad():
+             for i, values in enumerate(self.valid_loader, start=1):
+                 mixture, clean, song_name, *valid_vars = values
+
+                 mixture = mixture.cuda(args.gpu, non_blocking=True)
+                 clean = clean.cuda(args.gpu, non_blocking=True)
+                 target = clean
+
+                 dict_loss = {}
+                 if not args.data_params.singleset_num_frames:
+                     if args.hyperparams.ema:
+                         estimates, *estimates_vars = self.model_ema(mixture)
+                     else:
+                         estimates, *estimates_vars = self.model(mixture)
+                     if args.task_params.dataset == "delimit":
+                         estimates = estimates_vars[0]
+
+                     estimates = estimates[..., : clean.size(-1)]
+
+                 else:  # use SingleTrackSet
+                     db = SingleTrackSet(
+                         mixture[0],
+                         hop_length=args.data_params.nhop,
+                         num_frame=args.data_params.singleset_num_frames,
+                         target_name=args.task_params.target,
+                     )
+                     separated = []
+
+                     for item in db:
+                         if args.hyperparams.ema:
+                             estimates, *estimates_vars = self.model_ema(
+                                 item.unsqueeze(0).to(args.gpu)
+                             )
+                         else:
+                             estimates, *estimates_vars = self.model(
+                                 item.unsqueeze(0).to(args.gpu)
+                             )
+
+                         if args.task_params.dataset == "delimit":
+                             estimates = estimates_vars[0]
+
+                         separated.append(
+                             estimates_vars[0][
+                                 ..., db.trim_length : -db.trim_length
+                             ].clone()
+                         )
+
+                     estimates = torch.cat(separated, dim=-1)
+                     estimates = estimates[..., : target.shape[-1]]
+
+                 for valid_loss_idx, single_valid_loss_func in enumerate(
+                     args.model_loss_params.valid_loss_func
+                 ):
+                     loss_else = self.criterion[single_valid_loss_func](
+                         estimates,
+                         target,
+                     )
+                     dict_loss[single_valid_loss_func] = (
+                         loss_else.mean()
+                         * args.model_loss_params.valid_loss_scales[valid_loss_idx]
+                     )
+
+                 loss = sum([value for key, value in dict_loss.items()])
+
+                 losses.update(loss.item(), clean.size(0))
+
+         list_sum_count = torch.Tensor([losses.sum, losses.count]).to(args.gpu)
+         list_gathered_sum_count = [
+             torch.ones_like(list_sum_count) for _ in range(dist.get_world_size())
+         ]
+         dist.all_gather(list_gathered_sum_count, list_sum_count)
+         gathered_score = reduce(
+             torch.stack(list_gathered_sum_count), "s c -> c", "sum"
+         )  # s: sum of losses.sum, c: sum of losses.count
+         gathered_score = (gathered_score[0] / gathered_score[1]).item()
+
+         loss_logger["valid/valid loss"] = gathered_score
+         for key, value in dict_loss.items():
+             loss_logger[f"valid/{key}"] = value.item()
+
+         if args.hyperparams.lr_scheduler == "step_lr":
+             self.scheduler.step(gathered_score)
+         elif args.hyperparams.lr_scheduler == "cos_warmup":
+             self.scheduler.step(epoch)
+         else:
+             self.scheduler.step(gathered_score)
+
+         if args.wandb_params.use_wandb and args.gpu == 0:
+             loss_logger["valid/epoch"] = epoch
+             wandb.log(loss_logger)
+
+         if args.gpu == 0:
+             self.valid_losses.append(gathered_score)
+
+             self.stop = self.es.step(gathered_score)
+
+             print(f"Epoch {epoch}, validation loss : {round(gathered_score, 6)}")
+
+             plt.plot(self.train_losses, label="train loss")
+             plt.plot(self.valid_losses, label="valid loss")
+             plt.legend(loc="upper right")
+             plt.savefig(f"{args.output}/loss_graph_{args.task_params.target}.png")
+             plt.close()
+
+             save_states = {
+                 "epoch": epoch,
+                 "state_dict": self.model.module.state_dict()
+                 if not args.hyperparams.ema
+                 else self.model_ema.state_dict(),
+                 "best_loss": self.es.best,
+                 "optimizer": self.optimizer.state_dict(),
+                 "scheduler": self.scheduler.state_dict(),
+             }
+
+             utils.save_checkpoint(
+                 save_states,
+                 state_dict_only=gathered_score == self.es.best,
+                 path=args.output,
+                 target=args.task_params.target,
+             )
+
+             self.train_times.append(time.time() - self.end)
+
+             if gathered_score == self.es.best:
+                 self.best_epoch = epoch
+
+             # save params
+             params = {
+                 "epochs_trained": epoch,
+                 "args": args.toDict(),
+                 "best_loss": self.es.best,
+                 "best_epoch": self.best_epoch,
+                 "train_loss_history": self.train_losses,
+                 "valid_loss_history": self.valid_losses,
+                 "train_time_history": self.train_times,
+                 "num_bad_epochs": self.es.num_bad_epochs,
+             }
+
+             with open(
+                 f"{args.output}/{args.task_params.target}.json", "w"
+             ) as outfile:
+                 outfile.write(json.dumps(params, indent=4, sort_keys=True))
+
+             print(
+                 f"Epoch {epoch} train completed. Took {round(self.train_times[-1], 3)} seconds"
+             )
+
+     def resume(self, args):
+         print(f"Resume checkpoint from: {args.dir_params.resume}:")
+         loc = f"cuda:{args.gpu}"
+         checkpoint_path = f"{args.dir_params.resume}/{args.task_params.target}"
+         with open(f"{checkpoint_path}.json", "r") as stream:
+             results = json.load(stream)
+         checkpoint = torch.load(f"{checkpoint_path}.chkpnt", map_location=loc)
+
+         if args.hyperparams.ema:
+             self.model_ema.load_state_dict(checkpoint["state_dict"])
+         else:
+             self.model.load_state_dict(checkpoint["state_dict"])
+         self.optimizer.load_state_dict(checkpoint["optimizer"])
+
+         if (
+             args.dir_params.continual_train
+         ):  # we want to use a pre-trained model but do not want to reuse the lr_scheduler history.
+             for param_group in self.optimizer.param_groups:
+                 param_group["lr"] = args.hyperparams.lr
+         else:
+             self.scheduler.load_state_dict(checkpoint["scheduler"])
+             self.es.best = results["best_loss"]
+             self.es.num_bad_epochs = results["num_bad_epochs"]
+
+         self.start_epoch = results["epochs_trained"]
+         self.train_losses = results["train_loss_history"]
+         self.valid_losses = results["valid_loss_history"]
+         self.train_times = results["train_time_history"]
+         self.best_epoch = results["best_epoch"]
+         if args.sys_params.rank % args.ngpus_per_node == 0:
+             print(
+                 f"=> loaded checkpoint {checkpoint_path} (epoch {results['epochs_trained']})"
+             )
+
+     def cal_loss(self, args, loss_input):
+         loss_dict = {}
+         for key, value in loss_input.items():
+             loss_dict[key] = self.criterion[key](*value)
+
+         return loss_dict
+
+     def cal_multiple_losses(self, args, dict_loss_name_input):
+         loss_dict = {}
+         for loss_name, loss_input in dict_loss_name_input.items():
+             loss_dict[loss_name] = self.cal_loss(args, loss_input)
+
+         return loss_dict
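multi_validate gathers (sum, count) pairs across ranks and divides the totals, rather than averaging per-rank means; with uneven sampler splits the two are not the same. A quick plain-Python check:

    # (sum of losses, sample count) per rank; values are illustrative
    ranks = [(12.0, 4), (8.0, 1)]

    global_avg = sum(s for s, _ in ranks) / sum(c for _, c in ranks)  # 20 / 5 = 4.0
    naive_avg = sum(s / c for s, c in ranks) / len(ranks)             # (3 + 8) / 2 = 5.5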
test_ddp.py ADDED
@@ -0,0 +1,245 @@
+ # To be honest... this is not ddp.
+ import os
+ import json
+ import argparse
+ import glob
+
+ import torch
+ import tqdm
+ import musdb
+ import librosa
+ import soundfile as sf
+ import pyloudnorm as pyln
+ from dotmap import DotMap
+
+ from models import load_model_with_args
+ from separate_func import (
+     conv_tasnet_separate,
+ )
+ from utils import str2bool, db2linear
+
+
+ tqdm.monitor_interval = 0
+
+
+ def separate_track_with_model(
+     args, model, device, track_audio, track_name, meter, augmented_gain
+ ):
+     with torch.no_grad():
+         if (
+             args.model_loss_params.architecture == "conv_tasnet_mask_on_output"
+             or args.model_loss_params.architecture == "conv_tasnet"
+         ):
+             estimates = conv_tasnet_separate(
+                 args,
+                 model,
+                 device,
+                 track_audio,
+                 track_name,
+                 meter=meter,
+                 augmented_gain=augmented_gain,
+             )
+
+             return estimates
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="model test.py")
+
+     parser.add_argument("--target", type=str, default="all")
+     parser.add_argument("--data_root", type=str, default="/path/to/musdb_XL")
+     parser.add_argument(
+         "--use_musdb",
+         type=str2bool,
+         default=True,
+         help="Use musdb test data, or just run inference on other samples?",
+     )
+     parser.add_argument("--exp_name", type=str, default="delimit_6_s")
+     parser.add_argument("--manual_output_name", type=str, default=None)
+     parser.add_argument(
+         "--output_directory", type=str, default="/path/to/results"
+     )
+     parser.add_argument("--use_gpu", type=str2bool, default=True)
+     parser.add_argument("--save_name_as_target", type=str2bool, default=True)
+     parser.add_argument(
+         "--loudnorm_input_lufs",
+         type=float,
+         default=None,
+         help="If you want to use loudnorm, input target lufs",
+     )
+     parser.add_argument(
+         "--use_singletrackset",
+         type=str2bool,
+         default=False,
+         help="Use SingleTrackSet for X-UMX",
+     )
+     parser.add_argument(
+         "--best_model",
+         type=str2bool,
+         default=True,
+         help="Use the best model or the lastly saved model",
+     )
+     parser.add_argument(
+         "--save_output_loudnorm",
+         type=float,
+         default=None,
+         help="Save loudness normalized outputs or not. If you want to save, input target loudness",
+     )
+     parser.add_argument(
+         "--save_mixed_output",
+         type=float,
+         default=None,
+         help="Save original+delimited-estimation mixed output with a ratio of default 0.5 (original) and 1 - 0.5 (estimation)",
+     )
+     parser.add_argument(
+         "--save_16k_mono",
+         type=str2bool,
+         default=False,
+         help="Save 16k mono wav files for FAD evaluation.",
+     )
+     parser.add_argument(
+         "--save_histogram",
+         type=str2bool,
+         default=False,
+         help="Save histogram of the output. Only valid when the task is 'delimit'",
+     )
+
+     args, _ = parser.parse_known_args()
+
+     args.output_dir = f"{args.output_directory}/checkpoint/{args.exp_name}"
+     with open(f"{args.output_dir}/{args.target}.json", "r") as f:
+         args_dict = json.load(f)
+     args_dict = DotMap(args_dict)
+
+     for key, value in args_dict["args"].items():
+         if key in list(vars(args).keys()):
+             pass
+         else:
+             setattr(args, key, value)
+
+     args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
+
+     if args.manual_output_name is not None:
+         args.test_output_dir = f"{args.output_directory}/test/{args.manual_output_name}"
+     os.makedirs(args.test_output_dir, exist_ok=True)
+
+     device = torch.device(
+         "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
+     )
+
+     ###################### Define Models ######################
+     our_model = load_model_with_args(args)
+     our_model = our_model.to(device)
+     print(our_model)
+     pytorch_total_params = sum(
+         p.numel() for p in our_model.parameters() if p.requires_grad
+     )
+     print("Total number of parameters", pytorch_total_params)
+     # Future work => torchinfo would be better for this purpose.
+
+     if args.best_model:
+         target_model_path = f"{args.output_dir}/{args.target}.pth"
+         checkpoint = torch.load(target_model_path, map_location=device)
+         our_model.load_state_dict(checkpoint)
+     else:  # when using the lastly saved model
+         target_model_path = f"{args.output_dir}/{args.target}.chkpnt"
+         checkpoint = torch.load(target_model_path, map_location=device)
+         our_model.load_state_dict(checkpoint["state_dict"])
+
+     our_model.eval()
+
+     meter = pyln.Meter(44100)
+
+     if args.use_musdb:
+         test_tracks = musdb.DB(root=args.data_root, subsets="test", is_wav=True)
+
+         for track in tqdm.tqdm(test_tracks):
+             track_name = track.name
+             track_audio = track.audio
+
+             orig_audio = track_audio.copy()
+
+             augmented_gain = None
+             print("Now De-limiting : ", track_name)
+
+             if args.loudnorm_input_lufs:  # If you want to use loud-normalized input
+                 track_lufs = meter.integrated_loudness(track_audio)
+                 augmented_gain = args.loudnorm_input_lufs - track_lufs
+                 track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
+
+             track_audio = (
+                 torch.as_tensor(track_audio.T, dtype=torch.float32)
+                 .unsqueeze(0)
+                 .to(device)
+             )
+
+             estimates = separate_track_with_model(
+                 args, our_model, device, track_audio, track_name, meter, augmented_gain
+             )
+
+             if args.save_mixed_output:
+                 orig_audio = orig_audio.T
+                 track_lufs = meter.integrated_loudness(orig_audio.T)
+                 augmented_gain = args.save_output_loudnorm - track_lufs
+                 orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)
+
+                 mixed_output = orig_audio * args.save_mixed_output + estimates * (
+                     1 - args.save_mixed_output
+                 )
+
+                 sf.write(
+                     f"{args.test_output_dir}/{track_name}/{str(args.save_mixed_output)}_mixed.wav",
+                     mixed_output.T,
+                     args.data_params.sample_rate,
+                 )
+     else:
+         test_tracks = glob.glob(f"{args.data_root}/*.wav") + glob.glob(
+             f"{args.data_root}/*.mp3"
+         )
+
+         for track in tqdm.tqdm(test_tracks):
+             track_name = os.path.basename(track).replace(".wav", "").replace(".mp3", "")
+             track_audio, sr = librosa.load(
+                 track, sr=None, mono=False
+             )  # sr should be 44100
+
+             orig_audio = track_audio.copy()
+
+             if sr != 44100:
+                 raise ValueError("Sample rate should be 44100")
+             augmented_gain = None
+             print("Now De-limiting : ", track_name)
+
+             if args.loudnorm_input_lufs:  # If you want to use loud-normalized input
+                 track_lufs = meter.integrated_loudness(track_audio.T)
+                 augmented_gain = args.loudnorm_input_lufs - track_lufs
+                 track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
+
+             track_audio = (
+                 torch.as_tensor(track_audio, dtype=torch.float32)
+                 .unsqueeze(0)
+                 .to(device)
+             )
+
+             estimates = separate_track_with_model(
+                 args, our_model, device, track_audio, track_name, meter, augmented_gain
+             )
+
+             if args.save_mixed_output:
+                 track_lufs = meter.integrated_loudness(orig_audio.T)
+                 augmented_gain = args.save_output_loudnorm - track_lufs
+                 orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)
+
+                 mixed_output = orig_audio * args.save_mixed_output + estimates * (
+                     1 - args.save_mixed_output
+                 )
+
+                 sf.write(
+                     f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
+                     mixed_output.T,
+                     args.data_params.sample_rate,
+                 )
+
+
+ if __name__ == "__main__":
+     main()
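--save_mixed_output is a plain linear wet/dry blend between the input and the de-limited estimate. A one-liner sketch (function name illustrative):

    def blend(original, delimited, ratio=0.5):
        # ratio weights the original; 1 - ratio weights the estimate
        return ratio * original + (1.0 - ratio) * delimited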
train_ddp.py ADDED
@@ -0,0 +1,56 @@
+ import sys
+ import time
+
+ import torch
+ import torch.multiprocessing as mp
+ import torch.distributed as dist
+ import wandb
+
+ from solver_ddp import Solver
+
+
+ def train(args):
+     solver = Solver()
+
+     ngpus_per_node = int(torch.cuda.device_count() / args.sys_params.n_nodes)
+     print(f"use {ngpus_per_node} gpu machine")
+     args.sys_params.world_size = ngpus_per_node * args.sys_params.n_nodes
+     mp.spawn(worker, nprocs=ngpus_per_node, args=(solver, ngpus_per_node, args))
+
+
+ def worker(gpu, solver, ngpus_per_node, args):
+     args.sys_params.rank = args.sys_params.rank * ngpus_per_node + gpu
+     dist.init_process_group(
+         backend="nccl",
+         world_size=args.sys_params.world_size,
+         init_method="env://",
+         rank=args.sys_params.rank,
+     )
+     args.gpu = gpu
+     args.ngpus_per_node = ngpus_per_node
+
+     solver.set_gpu(args)
+
+     start_epoch = solver.start_epoch
+
+     if args.dir_params.resume:
+         start_epoch = start_epoch + 1
+
+     for epoch in range(start_epoch, args.hyperparams.epochs + 1):
+
+         solver.train_sampler.set_epoch(epoch)
+         solver.train(args, epoch)
+
+         time.sleep(1)
+
+         solver.multi_validate(args, epoch)
+
+         if solver.stop:
+             print("Apply Early Stopping")
+             if args.wandb_params.use_wandb:
+                 wandb.finish()
+             sys.exit()
+
+     if args.wandb_params.use_wandb:
+         wandb.finish()
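The global rank in worker() is node_rank * ngpus_per_node + local_gpu, so ranks tile contiguously across nodes. A quick check of the layout:

    ngpus_per_node = 4
    for node_rank in range(2):
        print([node_rank * ngpus_per_node + gpu for gpu in range(ngpus_per_node)])
    # [0, 1, 2, 3]
    # [4, 5, 6, 7]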
utils/__init__.py ADDED
@@ -0,0 +1,19 @@
+ from .read_wave_utils import (
+     load_wav_arbitrary_position_mono,
+     load_wav_specific_position_mono,
+     load_wav_arbitrary_position_stereo,
+     load_wav_specific_position_stereo,
+ )
+ from .loudness_utils import (
+     linear2db,
+     db2linear,
+     normalize_mag_spec,
+     denormalize_mag_spec,
+     loudness_match_and_norm,
+     loudness_normal_match_and_norm,
+     loudness_normal_match_and_norm_output_louder_first,
+     loudnorm,
+ )
+ from .logging import save_img_and_npy, save_checkpoint, AverageMeter, EarlyStopping
+ from .lr_scheduler import CosineAnnealingWarmUpRestarts
+ from .train_utils import worker_init_fn, str2bool, get_config
utils/logging.py ADDED
@@ -0,0 +1,79 @@
+ import os
+
+ import torch
+ import numpy as np
+ import matplotlib
+
+ matplotlib.use("Agg")
+ import matplotlib.pyplot as plt
+
+
+ def save_img_and_npy(path, matrix):
+     plt.imsave(path + ".png", matrix, origin="lower")
+
+
+ def save_checkpoint(state, state_dict_only, path, target):
+     torch.save(state, os.path.join(path, target + ".chkpnt"))
+     if state_dict_only:
+         # save just the weights
+         torch.save(state["state_dict"], os.path.join(path, target + ".pth"))
+
+
+ class AverageMeter(object):
+     """Computes and stores the average and current value"""
+
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.val = 0
+         self.avg = 0
+         self.sum = 0
+         self.count = 0
+
+     def update(self, val, n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / self.count
+
+
+ class EarlyStopping(object):
+     def __init__(self, mode="min", min_delta=0, patience=10):
+         self.mode = mode
+         self.min_delta = min_delta
+         self.patience = patience
+         self.best = None
+         self.num_bad_epochs = 0
+         self.is_better = None
+         self._init_is_better(mode, min_delta)
+
+         if patience == 0:
+             self.is_better = lambda a, b: True
+
+     def step(self, metrics):
+         if self.best is None:
+             self.best = metrics
+             return False
+
+         if np.isnan(metrics):
+             return True
+
+         if self.is_better(metrics, self.best):
+             self.num_bad_epochs = 0
+             self.best = metrics
+         else:
+             self.num_bad_epochs += 1
+
+         if self.num_bad_epochs >= self.patience:
+             return True
+
+         return False
+
+     def _init_is_better(self, mode, min_delta):
+         if mode not in {"min", "max"}:
+             raise ValueError("mode " + mode + " is unknown!")
+         if mode == "min":
+             self.is_better = lambda a, best: a < best - min_delta
+         if mode == "max":
+             self.is_better = lambda a, best: a > best + min_delta
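Typical wiring for the EarlyStopping helper above: feed it the validation metric once per epoch and stop when step() returns True. A minimal sketch with made-up loss values:

    es = EarlyStopping(mode="min", patience=3)

    for epoch, valid_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.82]):
        if es.step(valid_loss):  # True after `patience` epochs without improvement
            print(f"stopping at epoch {epoch}")
            break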
utils/loudness_utils.py ADDED
@@ -0,0 +1,71 @@
+ import random
+
+ import numpy as np
+ import torch
+
+
+ def linear2db(x, eps=1e-5, scale=20):
+     return scale * np.log10(x + eps)
+
+
+ def db2linear(x, eps=1e-5, scale=20):
+     return 10 ** (x / scale) - eps
+
+
+ def normalize_mag_spec(S, min_level_db=-100.0):
+     return torch.clamp((S - min_level_db) / -min_level_db, min=0.0, max=1.0)
+
+
+ def denormalize_mag_spec(S, min_level_db=-100.0):
+     return torch.clamp(S, min=0.0, max=1.0) * -min_level_db + min_level_db
+
+
+ def loudness_match_and_norm(audio1, audio2, meter):
+     # match audio2's integrated loudness (LUFS) to audio1's
+     lufs_1 = meter.integrated_loudness(audio1)
+     lufs_2 = meter.integrated_loudness(audio2)
+
+     if np.isinf(lufs_1) or np.isinf(lufs_2):
+         # silence measures as -inf LUFS; leave both signals untouched
+         return audio1, audio2
+     else:
+         audio2 = audio2 * db2linear(lufs_1 - lufs_2)
+
+     return audio1, audio2
+
+
+ def loudness_normal_match_and_norm(audio1, audio2, meter):
+     lufs_1 = meter.integrated_loudness(audio1)
+     lufs_2 = meter.integrated_loudness(audio2)
+
+     if np.isinf(lufs_1) or np.isinf(lufs_2):
+         return audio1, audio2
+     else:
+         target_lufs = random.normalvariate(lufs_1, 6.0)
+         audio2 = audio2 * db2linear(target_lufs - lufs_2)
+
+     return audio1, audio2
+
+
+ def loudness_normal_match_and_norm_output_louder_first(audio1, audio2, meter):
+     lufs_1 = meter.integrated_loudness(audio1)
+     lufs_2 = meter.integrated_loudness(audio2)
+
+     if np.isinf(lufs_1) or np.isinf(lufs_2):
+         return audio1, audio2
+     else:
+         # sample a target ~2 LUFS below audio1, so audio1 ends up louder on average
+         target_lufs = random.normalvariate(lufs_1 - 2.0, 2.0)
+         audio2 = audio2 * db2linear(target_lufs - lufs_2)
+
+     return audio1, audio2
+
+
+ def loudnorm(audio, target_lufs, meter, eps=1e-5):
+     lufs = meter.integrated_loudness(audio)
+     if np.isinf(lufs):
+         return audio, 0.0
+     else:
+         adjusted_gain = target_lufs - lufs
+         audio = audio * db2linear(adjusted_gain, eps)
+
+     return audio, adjusted_gain
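These helpers expect a `meter` object exposing `integrated_loudness()`; a `pyloudnorm.Meter` fits that interface and is assumed (not shown in this commit) in the sketch below. The -14 LUFS target matches `target_loudnorm_lufs` in `weight/all.json`:

```python
import numpy as np
import pyloudnorm as pyln  # assumed meter implementation

from utils.loudness_utils import loudnorm

sr = 44100
audio = 0.1 * np.random.randn(sr * 4)  # placeholder 4-second mono signal
meter = pyln.Meter(sr)                 # ITU-R BS.1770 loudness meter

normalized, gain_db = loudnorm(audio, -14.0, meter)
print(f"applied gain: {gain_db:.2f} dB")
```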
utils/lr_scheduler.py ADDED
@@ -0,0 +1,80 @@
+ import math
+
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class CosineAnnealingWarmUpRestarts(_LRScheduler):
+     def __init__(
+         self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1.0, last_epoch=-1
+     ):
+         if T_0 <= 0 or not isinstance(T_0, int):
+             raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
+         if T_mult < 1 or not isinstance(T_mult, int):
+             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
+         if T_up < 0 or not isinstance(T_up, int):
+             raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
+         self.T_0 = T_0
+         self.T_mult = T_mult
+         self.base_eta_max = eta_max
+         self.eta_max = eta_max
+         self.T_up = T_up
+         self.T_i = T_0
+         self.gamma = gamma
+         self.cycle = 0
+         self.T_cur = last_epoch
+         super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if self.T_cur == -1:
+             return self.base_lrs
+         elif self.T_cur < self.T_up:
+             # linear warm-up from base_lr to eta_max over T_up steps
+             return [
+                 (self.eta_max - base_lr) * self.T_cur / self.T_up + base_lr
+                 for base_lr in self.base_lrs
+             ]
+         else:
+             # cosine annealing from eta_max back down to base_lr
+             return [
+                 base_lr
+                 + (self.eta_max - base_lr)
+                 * (
+                     1
+                     + math.cos(
+                         math.pi * (self.T_cur - self.T_up) / (self.T_i - self.T_up)
+                     )
+                 )
+                 / 2
+                 for base_lr in self.base_lrs
+             ]
+
+     def step(self, epoch=None):
+         if epoch is None:
+             epoch = self.last_epoch + 1
+             self.T_cur = self.T_cur + 1
+             if self.T_cur >= self.T_i:
+                 # restart: begin a new cycle, lengthened by T_mult
+                 self.cycle += 1
+                 self.T_cur = self.T_cur - self.T_i
+                 self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
+         else:
+             if epoch >= self.T_0:
+                 if self.T_mult == 1:
+                     self.T_cur = epoch % self.T_0
+                     self.cycle = epoch // self.T_0
+                 else:
+                     n = int(
+                         math.log(
+                             (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult
+                         )
+                     )
+                     self.cycle = n
+                     self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (
+                         self.T_mult - 1
+                     )
+                     self.T_i = self.T_0 * self.T_mult ** (n)
+             else:
+                 self.T_i = self.T_0
+                 self.T_cur = epoch
+
+         self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
+         self.last_epoch = math.floor(epoch)
+         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+             param_group["lr"] = lr
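A usage sketch with made-up hyperparameters (the shipped `weight/all.json` config actually trains with `step_lr`, so this scheduler is optional here). Note that the optimizer's initial `lr` acts as the warm-up floor: the schedule rises to `eta_max` over `T_up` epochs, then anneals with cosine restarts:

```python
import torch

from utils.lr_scheduler import CosineAnnealingWarmUpRestarts

model = torch.nn.Linear(10, 10)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-7)  # small warm-up floor
scheduler = CosineAnnealingWarmUpRestarts(
    optimizer, T_0=40, T_mult=2, eta_max=3e-5, T_up=5, gamma=0.5
)

for epoch in range(200):
    # ... one epoch of training ...
    scheduler.step()
```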
utils/read_wave_utils.py ADDED
@@ -0,0 +1,109 @@
+ import random
+ import math
+
+ import numpy as np
+ import librosa
+ import torchaudio
+
+
+ def load_wav_arbitrary_position_mono(filename, sample_rate, seq_duration):
+     # mono; seq_duration in seconds
+     length = torchaudio.info(filename).num_frames
+
+     read_length = librosa.time_to_samples(seq_duration, sr=sample_rate)
+     if length > read_length:
+         random_start = random.randint(0, int(length - read_length - 1)) / sample_rate
+         X, sr = librosa.load(
+             filename, sr=None, offset=random_start, duration=seq_duration
+         )
+     else:
+         # file is shorter than the requested excerpt: load it all and zero-pad
+         random_start = 0
+         total_pad_length = read_length - length
+         X, sr = librosa.load(filename, sr=None, offset=0, duration=seq_duration)
+         pad_left = random.randint(0, total_pad_length)
+         X = np.pad(X, (pad_left, total_pad_length - pad_left))
+
+     return X
+
+
+ def load_wav_specific_position_mono(
+     filename, sample_rate, seq_duration, start_position
+ ):
+     # mono; seq_duration and start_position in seconds
+     length = torchaudio.info(filename).num_frames
+     read_length = librosa.time_to_samples(seq_duration, sr=sample_rate)
+
+     start_pos_sec = max(start_position, 0)  # if start_position is negative, start from 0
+     start_pos_sample = librosa.time_to_samples(start_pos_sec, sr=sample_rate)
+
+     if length <= start_pos_sample:  # if the start position exceeds the audio length, start from 0
+         start_pos_sec = 0
+         start_pos_sample = 0
+     X, sr = librosa.load(filename, sr=None, offset=start_pos_sec, duration=seq_duration)
+
+     if length < start_pos_sample + read_length:
+         X = np.pad(X, (0, (start_pos_sample + read_length) - length))
+
+     return X
+
+
+ # load a wav excerpt from an arbitrary position of a 16-bit stereo wav file
+ def load_wav_arbitrary_position_stereo(
+     filename, sample_rate, seq_duration, return_pos=False
+ ):
+     # stereo; seq_duration in seconds; assumes the file is longer than seq_duration
+     length = torchaudio.info(filename).num_frames
+     read_length = librosa.time_to_samples(seq_duration, sr=sample_rate)
+
+     random_start_sample = random.randint(
+         0, int(length - math.ceil(seq_duration * sample_rate) - 1)
+     )
+     random_start_sec = librosa.samples_to_time(random_start_sample, sr=sample_rate)
+     X, sr = librosa.load(
+         filename, sr=None, mono=False, offset=random_start_sec, duration=seq_duration
+     )
+
+     if length < random_start_sample + read_length:
+         X = np.pad(X, ((0, 0), (0, (random_start_sample + read_length) - length)))
+
+     if return_pos:
+         return X, random_start_sec
+     else:
+         return X
+
+
+ def load_wav_specific_position_stereo(
+     filename, sample_rate, seq_duration, start_position
+ ):
+     # stereo; seq_duration and start_position in seconds
+     length = torchaudio.info(filename).num_frames
+     read_length = librosa.time_to_samples(seq_duration, sr=sample_rate)
+
+     start_pos_sec = max(start_position, 0)  # if start_position is negative, start from 0
+     start_pos_sample = librosa.time_to_samples(start_pos_sec, sr=sample_rate)
+
+     if length <= start_pos_sample:  # if the start position exceeds the audio length, start from 0
+         start_pos_sec = 0
+         start_pos_sample = 0
+     X, sr = librosa.load(
+         filename, sr=None, mono=False, offset=start_pos_sec, duration=seq_duration
+     )
+
+     if length < start_pos_sample + read_length:
+         X = np.pad(X, ((0, 0), (0, (start_pos_sample + read_length) - length)))
+
+     return X
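A quick sketch of the random-excerpt loader; `mixture.wav` is a hypothetical path and should be a stereo file at a native 44.1 kHz that is longer than `seq_duration` (the loaders pass `sr=None` to librosa, so no resampling happens):

```python
from utils.read_wave_utils import load_wav_arbitrary_position_stereo

X = load_wav_arbitrary_position_stereo(
    "mixture.wav", sample_rate=44100, seq_duration=4.0
)
print(X.shape)  # roughly (2, 176400): channels x samples for 4 s at 44.1 kHz
```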
utils/train_utils.py ADDED
@@ -0,0 +1,27 @@
+ import argparse
+
+ import yaml
+ from dotmap import DotMap
+ import numpy as np
+
+
+ def worker_init_fn(worker_id):
+     # give each DataLoader worker a distinct, reproducible numpy seed
+     np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+
+ def str2bool(v):
+     if v.lower() in ("yes", "true", "t", "y", "1"):
+         return True
+     elif v.lower() in ("no", "false", "f", "n", "0"):
+         return False
+     else:
+         raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+ def get_config(config_name="default"):
+     with open(f"./configs/{config_name}.yaml", "r") as f:
+         config = yaml.load(f, Loader=yaml.FullLoader)
+     config = DotMap(config)
+     return config
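A sketch of how these helpers are meant to be used; the config keys accessed at the end are illustrative and assume a `configs/default.yaml` exists with that structure:

```python
import argparse

from utils.train_utils import get_config, str2bool

parser = argparse.ArgumentParser()
parser.add_argument("--use_wandb", type=str2bool, default=False)
args = parser.parse_args(["--use_wandb", "yes"])  # -> args.use_wandb == True

config = get_config("default")  # loads ./configs/default.yaml into a DotMap
print(config.hyperparams.lr)    # attribute-style access (assumed key)
```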
weight/all.json ADDED
@@ -0,0 +1,957 @@
+ {
+     "args": {
+         "classifier_params": {
+             "chosen_source_mean": 0.7,
+             "chosen_source_std": 0.15,
+             "classifier_activation": "softmax",
+             "classifier_n_classes": 4,
+             "classifier_n_srcs": 4,
+             "freeze_when_mixit": true,
+             "melspec_power": 2.0,
+             "model_name": "hrnet_w18_small",
+             "n_mels": 128,
+             "other_source_mean": 0.3,
+             "other_source_std": 0.15,
+             "pretrained_model": false,
+             "use_one_source_prob": 0.2,
+             "use_stereo": true
+         },
+         "conv_tasnet_params": {
+             "bn_chan": 128,
+             "decoder_activation": "sigmoid",
+             "encoder_activation": "relu",
+             "hid_chan": 512,
+             "kernel_size": 128,
+             "mask_act": "relu",
+             "n_blocks": 5,
+             "n_filters": 512,
+             "n_repeats": 2,
+             "skip_chan": 128,
+             "stride": 64
+         },
+         "data_params": {
+             "custom_limiter_attack_range": null,
+             "custom_limiter_release_range": null,
+             "limitaug_custom_target_lufs": null,
+             "limitaug_custom_target_lufs_std": null,
+             "limitaug_method": "ozone",
+             "limitaug_mode": null,
+             "nb_channels": 2,
+             "nfft": 4096,
+             "nhop": 1024,
+             "random_mix": true,
+             "sample_rate": 44100,
+             "samples_per_track": 128,
+             "seq_dur": 4.0,
+             "singleset_num_frames": null,
+             "target_limitaug_custom_target_lufs": null,
+             "target_limitaug_custom_target_lufs_std": null,
+             "target_limitaug_mode": null,
+             "target_loudnorm_lufs": -14.0,
+             "use_fixed": 0.019
+         },
+         "dir_params": {
+             "continual_train": false,
+             "delimit_valid_L_root": null,
+             "delimit_valid_root": null,
+             "exp_name": "convtasnet_35",
+             "output_directory": "/data2/personal/jeon/delimit/results",
+             "ozone_root": "/data5/personal/jeon/delimit/data",
+             "pretrained_classifier": null,
+             "resume": null,
+             "root": "/data1/Music/musdb18hq"
+         },
+         "gpu": 0,
+         "hyperparams": {
+             "batch_size": 8,
+             "ema": false,
+             "epochs": 200,
+             "gradient_clip": 5.0,
+             "lr": 3e-05,
+             "lr_decay_gamma": 0.5,
+             "lr_decay_patience": 15,
+             "lr_scheduler": "step_lr",
+             "optimizer": "adamw",
+             "patience": 50,
+             "weight_decay": 0.01
+         },
+         "img_check": "/data2/personal/jeon/delimit/results/img_check/convtasnet_35",
+         "invest_unet_params": {
+             "bn_factor": 16,
+             "f_down_layers": null,
+             "first_conv_activation": "relu",
+             "input_channels": 4,
+             "internal_channels": 24,
+             "kernel_size_f": 3,
+             "kernel_size_t": 3,
+             "last_activation": "identity",
+             "min_bn_units": 16,
+             "n_blocks": 7,
+             "n_internal_layers": 5,
+             "t_down_layers": null,
+             "tfc_tdf_activation": "relu",
+             "tfc_tdf_bias": true,
+             "tif_init_mode": null
+         },
+         "model_loss_params": {
+             "architecture": "conv_tasnet_mask_on_output",
+             "efficient_mixit_threshold": null,
+             "train_loss_func": [
+                 "si_sdr"
+             ],
+             "train_loss_scales": [
+                 1.0
+             ],
+             "valid_loss_func": [
+                 "si_sdr"
+             ],
+             "valid_loss_scales": [
+                 1.0
+             ]
+         },
+         "ngpus_per_node": 1,
+         "output": "/data2/personal/jeon/delimit/results/checkpoint/convtasnet_35",
+         "resume": {},
+         "sample_rate": {},
+         "sys_params": {
+             "n_nodes": 1,
+             "nb_workers": 4,
+             "port": null,
+             "rank": 0,
+             "seed": 777,
+             "world_size": 1
+         },
+         "task_params": {
+             "bleeding_nsrcs": null,
+             "dataset": "delimit",
+             "target": "all",
+             "train": true
+         },
+         "umx_params": {
+             "activation": "relu",
+             "dropout_rate": 0.05,
+             "hidden_size": 512,
+             "instead_tanh_activation": "tanh",
+             "lstm_dropout_rate": 0.4,
+             "nb_layers": 3,
+             "normalization": "bn",
+             "umx_get_statistics": false
+         },
+         "wandb_params": {
+             "entity": "vinyne",
+             "project": "delimit",
+             "rerun_id": null,
+             "sweep": false,
+             "use_wandb": true
+         }
+     },
+     "best_epoch": 183,
+     "best_loss": -14.165373802185059,
+     "epochs_trained": 200,
+     "num_bad_epochs": 17,
+     "train_loss_history": [
+         -11.723381042480469,
+         -11.759103775024414,
+         -11.818404197692871,
+         -11.88597583770752,
+         -11.882278442382812,
+         -11.943178176879883,
+         -11.909675598144531,
+         -11.93053913116455,
+         -11.922198295593262,
+         -12.013456344604492,
+         -12.106053352355957,
+         -11.999975204467773,
+         -12.067265510559082,
+         -12.079473495483398,
+         -12.13272762298584,
+         -12.15418529510498,
+         -12.08314037322998,
+         -12.152527809143066,
+         -12.096565246582031,
+         -12.219636917114258,
+         -12.246475219726562,
+         -12.170637130737305,
+         -12.188806533813477,
+         -12.230484962463379,
+         -12.207123756408691,
+         -12.307502746582031,
+         -12.200200080871582,
+         -12.284586906433105,
+         -12.244038581848145,
+         -12.302275657653809,
+         -12.200104713439941,
+         -12.31570816040039,
+         -12.42324447631836,
+         -12.352653503417969,
+         -12.367401123046875,
+         -12.295838356018066,
+         -12.404874801635742,
+         -12.338440895080566,
+         -12.365501403808594,
+         -12.365768432617188,
+         -12.225799560546875,
+         -12.26883602142334,
+         -12.390016555786133,
+         -12.410661697387695,
+         -12.311858177185059,
+         -12.408061027526855,
+         -12.396013259887695,
+         -12.353321075439453,
+         -12.470121383666992,
+         -12.469389915466309,
+         -12.452675819396973,
+         -12.381932258605957,
+         -12.31003475189209,
+         -12.412126541137695,
+         -12.267746925354004,
+         -12.440984725952148,
+         -12.413816452026367,
+         -12.417757034301758,
+         -12.4945650100708,
+         -12.445524215698242,
+         -12.38110065460205,
+         -12.454893112182617,
+         -12.390727996826172,
+         -12.339771270751953,
+         -12.528243064880371,
+         -12.434144973754883,
+         -12.43438720703125,
+         -12.458473205566406,
+         -12.424423217773438,
+         -12.387894630432129,
+         -12.438997268676758,
+         -12.528799057006836,
+         -12.423232078552246,
+         -12.534538269042969,
+         -12.495400428771973,
+         -12.53675651550293,
+         -12.551910400390625,
+         -12.478575706481934,
+         -12.461804389953613,
+         -12.483702659606934,
+         -12.474960327148438,
+         -12.441666603088379,
+         -12.42241096496582,
+         -12.48852252960205,
+         -12.513558387756348,
+         -12.40845012664795,
+         -12.555559158325195,
+         -12.589385032653809,
+         -12.395785331726074,
+         -12.496671676635742,
+         -12.554829597473145,
+         -12.530548095703125,
+         -12.564457893371582,
+         -12.52737808227539,
+         -12.608246803283691,
+         -12.3996000289917,
+         -12.433905601501465,
+         -12.490935325622559,
+         -12.477506637573242,
+         -12.470728874206543,
+         -12.564470291137695,
+         -12.525967597961426,
+         -12.502660751342773,
+         -12.440997123718262,
+         -12.576118469238281,
+         -12.538352966308594,
+         -12.512738227844238,
+         -12.525115966796875,
+         -12.511483192443848,
+         -12.571795463562012,
+         -12.59391975402832,
+         -12.442131996154785,
+         -12.617898941040039,
+         -12.495210647583008,
+         -12.551814079284668,
+         -12.4913330078125,
+         -12.626816749572754,
+         -12.556028366088867,
+         -12.477901458740234,
+         -12.596776008605957,
+         -12.597326278686523,
+         -12.484386444091797,
+         -12.660898208618164,
+         -12.440162658691406,
+         -12.530372619628906,
+         -12.51207447052002,
+         -12.503606796264648,
+         -12.670214653015137,
+         -12.51667308807373,
+         -12.546160697937012,
+         -12.504158020019531,
+         -12.6427001953125,
+         -12.56100082397461,
+         -12.506058692932129,
+         -12.637288093566895,
+         -12.572591781616211,
+         -12.544734001159668,
+         -12.604019165039062,
+         -12.549866676330566,
+         -12.521714210510254,
+         -12.601127624511719,
+         -12.629931449890137,
+         -12.587185859680176,
+         -12.605366706848145,
+         -12.606413841247559,
+         -12.536269187927246,
+         -12.577346801757812,
+         -12.703147888183594,
+         -12.60477066040039,
+         -12.603355407714844,
+         -12.536528587341309,
+         -12.601842880249023,
+         -12.698568344116211,
+         -12.72192668914795,
+         -12.663148880004883,
+         -12.644909858703613,
+         -12.631479263305664,
+         -12.596253395080566,
+         -12.61674690246582,
+         -12.701379776000977,
+         -12.664311408996582,
+         -12.646204948425293,
+         -12.597058296203613,
+         -12.652384757995605,
+         -12.579480171203613,
+         -12.757433891296387,
+         -12.686827659606934,
+         -12.65634536743164,
+         -12.552176475524902,
+         -12.625761032104492,
+         -12.652499198913574,
+         -12.668974876403809,
+         -12.700301170349121,
+         -12.591926574707031,
+         -12.54333782196045,
+         -12.541864395141602,
+         -12.720565795898438,
+         -12.625009536743164,
+         -12.577120780944824,
+         -12.67569637298584,
+         -12.634958267211914,
+         -12.660367012023926,
+         -12.646204948425293,
+         -12.713308334350586,
+         -12.734916687011719,
+         -12.602835655212402,
+         -12.596168518066406,
+         -12.66109848022461,
+         -12.568808555603027,
+         -12.719843864440918,
+         -12.746356010437012,
+         -12.602999687194824,
+         -12.632689476013184,
+         -12.715725898742676,
+         -12.671126365661621,
+         -12.659911155700684,
+         -12.755860328674316,
+         -12.591080665588379,
+         -12.623464584350586,
+         -12.643362045288086
+     ],
+     "train_time_history": [
+         308.12283968925476,
+         308.12408661842346,
+         305.56318974494934,
+         305.6093053817749,
+         304.1926734447479,
+         304.2103099822998,
+         301.78035831451416,
+         301.7819468975067,
+         317.8168547153473,
+         317.818119764328,
+         314.8585801124573,
+         314.8601076602936,
+         311.61795926094055,
+         311.61953926086426,
+         316.2616910934448,
+         316.2639091014862,
+         312.59282636642456,
+         312.59408020973206,
+         314.6765525341034,
+         314.6778757572174,
+         314.4039900302887,
+         314.40531301498413,
+         313.9343922138214,
+         313.9356322288513,
+         315.1470823287964,
+         315.14854192733765,
+         317.65793561935425,
+         317.65903544425964,
+         316.41589403152466,
+         316.4171371459961,
+         316.253050327301,
+         316.2544617652893,
+         316.2039670944214,
+         316.20542550086975,
+         316.30707120895386,
+         316.30964159965515,
+         315.7812213897705,
+         315.7832131385803,
+         315.77191638946533,
+         315.7732570171356,
+         315.7776229381561,
+         315.77907848358154,
+         315.80343294143677,
+         315.8051166534424,
+         314.40133929252625,
+         314.403112411499,
+         314.32283997535706,
+         314.32424092292786,
+         314.90000677108765,
+         314.90242648124695,
+         313.8207128047943,
+         313.8227391242981,
+         313.86938881874084,
+         313.87079215049744,
+         316.9037547111511,
+         316.9056947231293,
+         317.4321286678314,
+         317.43361139297485,
+         316.41515493392944,
+         316.4182825088501,
+         315.69741559028625,
+         315.699245929718,
+         315.9285054206848,
+         315.930716753006,
+         314.25376319885254,
+         314.25567531585693,
+         312.997665643692,
+         313.0005877017975,
+         315.5962414741516,
+         315.5977747440338,
+         315.49425506591797,
+         315.4961242675781,
+         315.980491399765,
+         315.98283791542053,
+         315.5533638000488,
+         315.55492901802063,
+         313.9896593093872,
+         313.99131321907043,
+         314.3214478492737,
+         314.3232262134552,
+         314.6442220211029,
+         314.64620661735535,
+         315.69726514816284,
+         315.7001700401306,
+         314.78302001953125,
+         314.7847316265106,
+         313.14448523521423,
+         313.1465194225311,
+         311.8232834339142,
+         311.8251144886017,
+         318.88225960731506,
+         318.8843643665314,
+         319.20725083351135,
+         319.20886182785034,
+         317.81429648399353,
+         317.8159878253937,
+         320.23738193511963,
+         320.23904752731323,
+         315.8315763473511,
+         315.83344054222107,
+         317.32581615448,
+         317.3274848461151,
+         316.7596924304962,
+         316.7628848552704,
+         316.3167974948883,
+         316.3188827037811,
+         316.44567823410034,
+         316.44802141189575,
+         313.8653395175934,
+         313.8687484264374,
+         308.43933939933777,
+         308.44151163101196,
+         312.1857454776764,
+         312.18967509269714,
+         307.8407344818115,
+         307.84401679039,
+         307.48447585105896,
+         307.48623728752136,
+         310.300940990448,
+         310.3029022216797,
+         310.32225275039673,
+         310.3257050514221,
+         309.351779460907,
+         309.3539865016937,
+         309.4356527328491,
+         309.4380919933319,
+         312.63360381126404,
+         312.63535809516907,
+         311.7453818321228,
+         311.7476508617401,
+         311.3258364200592,
+         311.327698469162,
+         312.28111600875854,
+         312.2828998565674,
+         311.3383209705353,
+         311.34200048446655,
+         306.9764757156372,
+         306.9787657260895,
+         309.35506653785706,
+         309.3569576740265,
+         310.2506465911865,
+         310.2529339790344,
+         310.65880727767944,
+         310.66108298301697,
+         311.18562865257263,
+         311.1874952316284,
+         309.07765316963196,
+         309.07997822761536,
+         313.3008818626404,
+         313.3029179573059,
+         311.267498254776,
+         311.26989102363586,
+         310.62635374069214,
+         310.6306185722351,
+         308.1883268356323,
+         308.19112515449524,
+         310.65689158439636,
+         310.65896558761597,
+         308.98754620552063,
+         309.03386878967285,
+         309.21512937545776,
+         309.2185757160187,
+         309.93750405311584,
+         309.93965554237366,
+         310.2938587665558,
+         310.29592084884644,
+         308.24257493019104,
+         308.2463102340698,
+         310.6870594024658,
+         310.6905345916748,
+         310.7875945568085,
+         310.78995156288147,
+         310.9882712364197,
+         310.9906806945801,
+         310.95856285095215,
+         310.96066546440125,
+         312.4489221572876,
+         312.45125246047974,
+         312.24022579193115,
+         312.2863116264343,
+         309.68400406837463,
+         309.6862533092499,
+         309.64014887809753,
+         309.64232993125916,
+         309.9094281196594,
+         309.9119017124176,
+         309.40677762031555,
+         309.40893173217773,
+         309.1595506668091,
+         309.1617259979248,
+         308.4178020954132,
+         308.4198989868164,
+         308.5063133239746,
+         308.5085346698761,
+         307.5796904563904,
+         307.5972898006439,
+         309.66309905052185,
+         309.66530561447144,
+         312.70798993110657,
+         312.7102212905884,
+         310.2431013584137,
+         310.2453660964966,
+         312.2640459537506,
+         312.26635122299194,
+         311.27055287361145,
+         311.27321219444275,
+         312.58145689964294,
+         312.58376598358154,
+         313.1553518772125,
+         313.1574249267578,
+         308.4067575931549,
+         308.4089684486389,
+         311.0251498222351,
+         311.0274658203125,
+         308.0227520465851,
+         308.02498388290405,
+         308.0182030200958,
+         308.0204634666443,
+         308.63523149490356,
+         308.63751220703125,
+         308.53969383239746,
+         308.5420751571655,
+         306.51329946517944,
+         306.51555824279785,
+         309.59846591949463,
+         309.60128831863403,
+         305.3712034225464,
+         305.37409830093384,
+         305.43984270095825,
+         305.4421238899231,
+         309.3166663646698,
+         309.3195414543152,
+         308.8618497848511,
+         308.86409974098206,
+         304.8731882572174,
+         304.8755958080292,
+         306.6576888561249,
+         306.663143157959,
+         306.6716537475586,
+         306.6740062236786,
+         309.47339940071106,
+         309.47578954696655,
+         307.73386335372925,
+         307.7363700866699,
+         308.0688214302063,
+         308.07209277153015,
+         311.58968901634216,
+         311.6099576950073,
+         308.70460844039917,
+         308.70710158348083,
+         312.0563473701477,
+         312.05881452560425,
+         310.89456367492676,
+         310.9119510650635,
+         308.73097705841064,
+         308.73414373397827,
+         309.4255359172821,
+         309.42857813835144,
+         311.0751721858978,
+         311.07801842689514,
+         309.5860447883606,
+         309.5896680355072,
+         309.87396597862244,
+         309.8803391456604,
+         310.9183626174927,
+         310.92147397994995,
+         308.4321529865265,
+         308.4359757900238,
+         312.4424922466278,
+         312.44731879234314,
+         312.3443009853363,
+         312.3491401672363,
+         310.3139410018921,
+         310.3165555000305,
+         312.09410762786865,
+         312.09656262397766,
+         311.11144399642944,
+         311.1577796936035,
+         309.1589603424072,
+         309.16152119636536,
+         312.51157093048096,
+         312.51463317871094,
+         314.15198159217834,
+         314.15485286712646,
+         310.00070810317993,
+         310.0033264160156,
+         311.2290298938751,
+         311.23188829421997,
+         313.0510983467102,
+         313.05362153053284,
+         313.48791670799255,
+         313.4910161495209,
+         307.60272216796875,
+         307.6053590774536,
+         303.84622287750244,
+         303.8494029045105,
+         304.8547012805939,
+         304.85784125328064,
+         310.63141536712646,
+         310.63450264930725,
+         304.8634753227234,
+         304.8664004802704,
+         308.1505949497223,
+         308.15428018569946,
+         310.18936228752136,
+         310.1920323371887,
+         309.2550263404846,
+         309.2577428817749,
+         310.08596634864807,
+         310.08910751342773,
+         307.4643654823303,
+         307.4670605659485,
+         308.558221578598,
+         308.5638659000397,
+         309.7440264225006,
+         309.7467608451843,
+         308.2091956138611,
+         308.2125828266144,
+         307.0199763774872,
+         307.02332496643066,
+         306.3482081890106,
+         306.35128688812256,
+         307.3764581680298,
+         307.37923669815063,
+         311.61060428619385,
+         311.6135311126709,
+         306.8187861442566,
+         306.8240280151367,
+         305.19880175590515,
+         305.20313119888306,
+         309.252712726593,
+         309.256165266037,
+         310.80801463127136,
+         310.81236577033997,
+         309.1079206466675,
+         309.11073756217957,
+         310.6556165218353,
+         310.65838623046875,
+         310.94868993759155,
+         310.95155143737793,
+         308.4552607536316,
+         308.4580717086792,
+         308.2857587337494,
+         308.2886221408844,
+         306.4856150150299,
+         306.4887855052948,
+         306.8667871952057,
+         306.86966013908386,
+         306.1964519023895,
+         306.2005341053009,
+         308.2178611755371,
+         308.22126364707947,
+         305.94888377189636,
+         305.9523375034332,
+         307.48926973342896,
+         307.4920620918274,
+         307.60354018211365,
+         307.63674998283386,
+         307.2473645210266,
+         307.2501358985901,
+         308.16573452949524,
+         308.2115182876587,
+         307.30736780166626,
+         307.3109815120697,
+         307.2137475013733,
+         307.2178246974945,
+         308.5944905281067,
+         308.59843826293945,
+         307.2346291542053,
+         307.2382435798645,
+         308.417338848114,
+         308.4208617210388,
+         305.5816307067871,
+         305.5852439403534,
+         307.69459652900696,
+         307.6975119113922,
+         307.20833134651184,
+         307.212299823761,
+         305.9614431858063,
+         305.965185880661,
+         305.31594157218933,
+         305.3195445537567,
+         307.46696519851685,
+         307.47079825401306,
+         306.23966455459595,
+         306.2433180809021,
+         306.1235647201538,
+         306.1273248195648,
+         307.02436780929565,
+         307.02733421325684,
+         306.9687819480896,
+         306.97225856781006,
+         306.23205065727234,
+         306.2356073856354,
+         305.3567383289337,
+         305.36028504371643,
+         305.94446635246277,
+         305.9480822086334,
+         307.2553553581238
+     ],
+     "valid_loss_history": [
+         -12.743322372436523,
+         -12.724347114562988,
+         -12.86701488494873,
+         -12.694435119628906,
+         -12.706733703613281,
+         -13.048251152038574,
+         -12.943618774414062,
+         -13.120084762573242,
+         -13.121935844421387,
+         -13.146740913391113,
+         -13.197364807128906,
+         -13.224929809570312,
+         -13.255891799926758,
+         -13.311783790588379,
+         -13.386489868164062,
+         -13.390006065368652,
+         -13.45509147644043,
+         -13.444679260253906,
+         -13.456311225891113,
+         -13.36051082611084,
+         -13.478644371032715,
+         -13.503388404846191,
+         -13.540580749511719,
+         -13.579903602600098,
+         -13.551591873168945,
+         -13.638075828552246,
+         -13.617512702941895,
+         -13.64240550994873,
+         -13.618767738342285,
+         -13.65319538116455,
+         -13.601574897766113,
+         -13.693778038024902,
+         -13.658882141113281,
+         -13.649510383605957,
+         -13.477263450622559,
+         -13.643564224243164,
+         -13.732584953308105,
+         -13.643271446228027,
+         -13.655325889587402,
+         -13.71172046661377,
+         -13.564180374145508,
+         -13.708178520202637,
+         -13.688010215759277,
+         -13.711198806762695,
+         -13.612863540649414,
+         -13.702019691467285,
+         -13.704530715942383,
+         -13.716957092285156,
+         -13.76714038848877,
+         -13.719636917114258,
+         -13.738469123840332,
+         -13.759002685546875,
+         -13.721348762512207,
+         -13.727803230285645,
+         -13.768327713012695,
+         -13.73253345489502,
+         -13.75208568572998,
+         -13.754429817199707,
+         -13.76417064666748,
+         -13.805985450744629,
+         -13.762914657592773,
+         -13.75927448272705,
+         -13.781553268432617,
+         -13.744827270507812,
+         -13.805213928222656,
+         -13.792055130004883,
+         -13.736992835998535,
+         -13.804685592651367,
+         -13.802186012268066,
+         -13.812178611755371,
+         -13.781081199645996,
+         -13.836441993713379,
+         -13.787053108215332,
+         -13.824462890625,
+         -13.827963829040527,
+         -13.768393516540527,
+         -13.824796676635742,
+         -13.809252738952637,
+         -13.820283889770508,
+         -13.811989784240723,
+         -13.845786094665527,
+         -13.801295280456543,
+         -13.795866966247559,
+         -13.847658157348633,
+         -13.841630935668945,
+         -13.887687683105469,
+         -13.838217735290527,
+         -13.833791732788086,
+         -13.8090181350708,
+         -13.810338973999023,
+         -13.812939643859863,
+         -13.813563346862793,
+         -13.72245979309082,
+         -13.829062461853027,
+         -13.820122718811035,
+         -13.764768600463867,
+         -13.882962226867676,
+         -13.887824058532715,
+         -13.874728202819824,
+         -13.83934211730957,
+         -13.854304313659668,
+         -13.853861808776855,
+         -13.878510475158691,
+         -13.855673789978027,
+         -13.935111999511719,
+         -13.873315811157227,
+         -13.88434886932373,
+         -13.913508415222168,
+         -13.804875373840332,
+         -13.874313354492188,
+         -13.925950050354004,
+         -13.898317337036133,
+         -13.861913681030273,
+         -13.83596134185791,
+         -13.907777786254883,
+         -13.832358360290527,
+         -13.936162948608398,
+         -13.925071716308594,
+         -13.906752586364746,
+         -13.87073040008545,
+         -13.964620590209961,
+         -13.925311088562012,
+         -13.974698066711426,
+         -13.957905769348145,
+         -13.918564796447754,
+         -13.975790023803711,
+         -13.988444328308105,
+         -13.959516525268555,
+         -14.01569652557373,
+         -13.992425918579102,
+         -14.039790153503418,
+         -13.940314292907715,
+         -14.011497497558594,
+         -13.953152656555176,
+         -13.920698165893555,
+         -13.960227966308594,
+         -13.907439231872559,
+         -14.014067649841309,
+         -13.972914695739746,
+         -13.942621231079102,
+         -14.019667625427246,
+         -14.037107467651367,
+         -13.85366153717041,
+         -13.980110168457031,
+         -13.97785472869873,
+         -13.983843803405762,
+         -13.843756675720215,
+         -14.002585411071777,
+         -14.026784896850586,
+         -14.028115272521973,
+         -14.02059268951416,
+         -13.985837936401367,
+         -14.076154708862305,
+         -14.060620307922363,
+         -13.936518669128418,
+         -13.957221031188965,
+         -14.017061233520508,
+         -13.995661735534668,
+         -14.056286811828613,
+         -14.037705421447754,
+         -13.940332412719727,
+         -14.092416763305664,
+         -14.024917602539062,
+         -14.002346992492676,
+         -14.026989936828613,
+         -13.944084167480469,
+         -14.002883911132812,
+         -14.120462417602539,
+         -14.043062210083008,
+         -14.008293151855469,
+         -14.040563583374023,
+         -13.994155883789062,
+         -14.08944034576416,
+         -14.078422546386719,
+         -14.014589309692383,
+         -14.083242416381836,
+         -14.104707717895508,
+         -14.103189468383789,
+         -14.063937187194824,
+         -14.0596284866333,
+         -14.059121131896973,
+         -14.102814674377441,
+         -14.165373802185059,
+         -14.106118202209473,
+         -14.107162475585938,
+         -14.085371017456055,
+         -14.123793601989746,
+         -14.053537368774414,
+         -14.077792167663574,
+         -14.056371688842773,
+         -14.033655166625977,
+         -14.096640586853027,
+         -14.057114601135254,
+         -14.115262985229492,
+         -14.074142456054688,
+         -14.067980766296387,
+         -14.118453025817871,
+         -14.117535591125488,
+         -14.126029968261719,
+         -14.117874145507812
+     ]
+ }
weight/all.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:34f2ef4e5c32542060621f7ea9f7a06a2acf91be22825a38f9270077a7346679
+ size 9424379
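Note that `weight/all.pth` is stored via Git LFS, so the pointer above must be resolved (`git lfs pull`) before loading. Given `save_checkpoint` in `utils/logging.py`, the `.pth` file should contain only a `state_dict`; a sketch of consuming both artifacts follows (the model class itself lives elsewhere in the repo and is elided):

```python
import json

import torch

# training metadata written alongside the weights
with open("weight/all.json") as f:
    meta = json.load(f)
print(meta["best_epoch"], meta["best_loss"])  # 183, about -14.17 (SI-SDR loss)

# all.pth holds the bare state_dict (see save_checkpoint in utils/logging.py)
state_dict = torch.load("weight/all.pth", map_location="cpu")
# model.load_state_dict(state_dict)  # with the repo's de-limiter network
```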