Merge branch 'cjs--classifier-v2' of https://github.com/mhrice/RemFx into classifier-inference
Files changed:
- cfg/exp/5-5_cls.yaml +2 -2
- cfg/model/cls_panns_16k.yaml +1 -1
- cfg/model/cls_panns_44k_label_smoothing.yaml +17 -0
- cfg/model/cls_panns_48k.yaml +17 -0
- cfg/model/cls_panns_48k_64.yaml +17 -0
- cfg/model/{cls_panns_44k.yaml → cls_panns_48k_mixup.yaml} +6 -5
- cfg/model/{cls_panns_44k_noaug.yaml → cls_panns_48k_specaugment.yaml} +6 -5
- cfg/model/cls_panns_48k_specaugment_label_smoothing.yaml +17 -0
- cfg/model/cls_panns_pt.yaml +1 -0
- remfx/classifier.py +8 -4
- remfx/datasets.py +144 -22
- remfx/models.py +77 -16
cfg/exp/5-5_cls.yaml
CHANGED
@@ -7,7 +7,7 @@ sample_rate: 48000
 chunk_size: 262144 # 5.5s
 logs_dir: "./logs"
 render_files: True
-render_root: "/scratch/
+render_root: "/scratch/EffectSet_cjs_nobass"
 accelerator: "gpu"
 log_audio: False
 # Effects
@@ -56,4 +56,4 @@ trainer:
   accelerator: ${accelerator}
   devices: 1
   gradient_clip_val: 10.0
-  max_steps:
+  max_steps: 100000
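The `${accelerator}` value left untouched in the second hunk is an OmegaConf/Hydra interpolation: the trainer block re-uses the top-level key rather than repeating it. A minimal standalone sketch of how that resolves (illustration only, not code from this repo):

from omegaconf import OmegaConf

# Minimal sketch of the ${accelerator} interpolation visible in this config:
# trainer.accelerator resolves to the top-level accelerator value on access.
cfg = OmegaConf.create(
    {
        "accelerator": "gpu",
        "trainer": {"accelerator": "${accelerator}", "max_steps": 100000},
    }
)
print(cfg.trainer.accelerator)  # -> "gpu", pulled from the top-level key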
cfg/model/cls_panns_16k.yaml
CHANGED
@@ -10,6 +10,6 @@ model:
     n_fft: 2048
     hop_length: 512
     n_mels: 128
-    sample_rate:
+    sample_rate: ${sample_rate}
     model_sample_rate: 16000
 
cfg/model/cls_panns_44k_label_smoothing.yaml
ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: True
+  label_smoothing: 0.1
+  network:
+    _target_: remfx.classifier.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 128
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: False
cfg/model/cls_panns_48k.yaml
ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: False
+  network:
+    _target_: remfx.classifier.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 128
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: False
+
cfg/model/cls_panns_48k_64.yaml
ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: False
+  network:
+    _target_: remfx.classifier.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 64
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: False
+
cfg/model/{cls_panns_44k.yaml → cls_panns_48k_mixup.yaml}
RENAMED
@@ -4,12 +4,13 @@ model:
   lr: 3e-4
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
+  mixup: True
   network:
     _target_: remfx.classifier.Cnn14
     num_classes: ${num_classes}
-    n_fft:
-    hop_length:
+    n_fft: 2048
+    hop_length: 512
     n_mels: 128
-    sample_rate:
-    model_sample_rate:
-    specaugment:
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: False
cfg/model/{cls_panns_44k_noaug.yaml → cls_panns_48k_specaugment.yaml}
RENAMED
@@ -4,12 +4,13 @@ model:
   lr: 3e-4
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
+  mixup: False
   network:
     _target_: remfx.classifier.Cnn14
     num_classes: ${num_classes}
-    n_fft:
-    hop_length:
+    n_fft: 2048
+    hop_length: 512
     n_mels: 128
-    sample_rate:
-    model_sample_rate:
-    specaugment:
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: True
cfg/model/cls_panns_48k_specaugment_label_smoothing.yaml
ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: False
+  label_smoothing: 0.15
+  network:
+    _target_: remfx.classifier.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 128
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: True
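All of the new and renamed model configs share the same `_target_` layout, which Hydra resolves into nested object construction. A rough sketch of that mechanism with `hydra.utils.instantiate`, where the `${...}` interpolations from the real configs are replaced by hand-picked literals (48 kHz, 5 classes) purely for illustration:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Inline mirror of cfg/model/cls_panns_48k.yaml; interpolated values filled in by hand.
cfg = OmegaConf.create(
    {
        "model": {
            "_target_": "remfx.models.FXClassifier",
            "lr": 3e-4,
            "lr_weight_decay": 1e-3,
            "sample_rate": 48000,
            "mixup": False,
            "network": {
                "_target_": "remfx.classifier.Cnn14",
                "num_classes": 5,
                "n_fft": 2048,
                "hop_length": 512,
                "n_mels": 128,
                "sample_rate": 48000,
                "model_sample_rate": 48000,
                "specaugment": False,
            },
        }
    }
)

# instantiate() resolves each _target_ and passes sibling keys as kwargs, so the
# nested network node becomes a Cnn14 handed to FXClassifier (requires remfx installed).
model = instantiate(cfg.model)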
cfg/model/cls_panns_pt.yaml
CHANGED
@@ -4,6 +4,7 @@ model:
   lr: 3e-4
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
+  mixup: False
   network:
     _target_: remfx.classifier.PANNs
     num_classes: ${num_classes}
remfx/classifier.py
CHANGED
@@ -33,7 +33,7 @@ class PANNs(torch.nn.Module):
             torch.nn.Linear(hidden_dim, num_classes),
         )
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, **kwargs):
         with torch.no_grad():
             x = self.resample(x)
         embed = panns_hear.get_scene_embeddings(x.view(x.shape[0], -1), self.model)
@@ -61,7 +61,7 @@ class Wav2CLIP(nn.Module):
             torch.nn.Linear(hidden_dim, num_classes),
         )
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, **kwargs):
         with torch.no_grad():
             x = self.resample(x)
         embed = wav2clip_hear.get_scene_embeddings(
@@ -91,7 +91,7 @@ class VGGish(nn.Module):
             torch.nn.Linear(hidden_dim, num_classes),
         )
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, **kwargs):
         with torch.no_grad():
             x = self.resample(x)
         embed = hearbaseline.vggish.get_scene_embeddings(
@@ -121,7 +121,7 @@ class wav2vec2(nn.Module):
             torch.nn.Linear(hidden_dim, num_classes),
         )
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, **kwargs):
         with torch.no_grad():
             x = self.resample(x)
         embed = hearbaseline.wav2vec2.get_scene_embeddings(
@@ -181,6 +181,10 @@ class Cnn14(nn.Module):
             orig_freq=sample_rate, new_freq=model_sample_rate
         )
 
+        if self.specaugment:
+            self.freq_mask = torchaudio.transforms.FrequencyMasking(64, True)
+            self.time_mask = torchaudio.transforms.TimeMasking(128, True)
+
     def init_weight(self):
         init_bn(self.bn0)
         init_layer(self.fc1)
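For reference, the two masking transforms registered on Cnn14 when `specaugment` is enabled come from torchaudio. The snippet below is a standalone illustration of what they do to a batch of log-mel spectrograms; the tensor shapes are made up here, not taken from the repo:

import torch
import torchaudio

# Same transforms as in the diff: mask up to 64 mel bins and up to 128 frames,
# with independent (iid) masks per example when given a 4D batch.
freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=64, iid_masks=True)
time_mask = torchaudio.transforms.TimeMasking(time_mask_param=128, iid_masks=True)

spec = torch.rand(4, 1, 128, 256)  # (batch, channel, n_mels, frames) dummy batch
augmented = time_mask(freq_mask(spec))
print(augmented.shape)  # shape is unchanged; masked bands/frames are zeroed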
remfx/datasets.py
CHANGED
@@ -8,15 +8,16 @@ import pytorch_lightning as pl
 import random
 from tqdm import tqdm
 from pathlib import Path
-from remfx import effects
+from remfx import effects as effect_lib
 from typing import Any, List, Dict
 from torch.utils.data import Dataset, DataLoader
 from remfx.utils import select_random_chunk
+import multiprocessing
 
 
 # https://zenodo.org/record/1193957 -> VocalSet
 
-ALL_EFFECTS =
+ALL_EFFECTS = effect_lib.Pedalboard_Effects
 # print(ALL_EFFECTS)
 
 
@@ -146,6 +147,101 @@ def locate_files(root: str, mode: str):
     return file_list
 
 
+def parallel_process_effects(
+    chunk_idx: int,
+    proc_root: str,
+    files: list,
+    chunk_size: int,
+    effects: list,
+    effects_to_keep: list,
+    num_kept_effects: tuple,
+    shuffle_kept_effects: bool,
+    effects_to_remove: list,
+    num_removed_effects: tuple,
+    shuffle_removed_effects: bool,
+    sample_rate: int,
+    target_lufs_db: float,
+):
+    chunk = None
+    random_dataset_choice = random.choice(files)
+    while chunk is None:
+        random_file_choice = random.choice(random_dataset_choice)
+        chunk = select_random_chunk(random_file_choice, chunk_size, sample_rate)
+
+    # Sum to mono
+    if chunk.shape[0] > 1:
+        chunk = chunk.sum(0, keepdim=True)
+
+    dry = chunk
+
+    # loudness normalization
+    normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=target_lufs_db)
+
+    # Apply Kept Effects
+    # Shuffle effects if specified
+    if shuffle_kept_effects:
+        effect_indices = torch.randperm(len(effects_to_keep))
+    else:
+        effect_indices = torch.arange(len(effects_to_keep))
+
+    r1 = num_kept_effects[0]
+    r2 = num_kept_effects[1]
+    num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+    effect_indices = effect_indices[:num_kept_effects]
+    # Index in effect settings
+    effect_names_to_apply = [effects_to_keep[i] for i in effect_indices]
+    effects_to_apply = [effects[i] for i in effect_names_to_apply]
+    # Apply
+    dry_labels = []
+    for effect in effects_to_apply:
+        # Normalize in-between effects
+        dry = normalize(effect(dry))
+        dry_labels.append(ALL_EFFECTS.index(type(effect)))
+
+    # Apply effects_to_remove
+    # Shuffle effects if specified
+    if shuffle_removed_effects:
+        effect_indices = torch.randperm(len(effects_to_remove))
+    else:
+        effect_indices = torch.arange(len(effects_to_remove))
+    wet = torch.clone(dry)
+    r1 = num_removed_effects[0]
+    r2 = num_removed_effects[1]
+    num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+    effect_indices = effect_indices[:num_removed_effects]
+    # Index in effect settings
+    effect_names_to_apply = [effects_to_remove[i] for i in effect_indices]
+    effects_to_apply = [effects[i] for i in effect_names_to_apply]
+    # Apply
+    wet_labels = []
+    for effect in effects_to_apply:
+        # Normalize in-between effects
+        wet = normalize(effect(wet))
+        wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+    wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+    dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+    for label_idx in wet_labels:
+        wet_labels_tensor[label_idx] = 1.0
+
+    for label_idx in dry_labels:
+        dry_labels_tensor[label_idx] = 1.0
+
+    # Normalize
+    normalized_dry = normalize(dry)
+    normalized_wet = normalize(wet)
+
+    output_dir = proc_root / str(chunk_idx)
+    output_dir.mkdir(exist_ok=True)
+    torchaudio.save(output_dir / "input.wav", normalized_wet, sample_rate)
+    torchaudio.save(output_dir / "target.wav", normalized_dry, sample_rate)
+    torch.save(dry_labels_tensor, output_dir / "dry_effects.pt")
+    torch.save(wet_labels_tensor, output_dir / "wet_effects.pt")
+
+    # return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
+
+
 class EffectDataset(Dataset):
     def __init__(
         self,
@@ -163,6 +259,7 @@ class EffectDataset(Dataset):
         render_files: bool = True,
         render_root: str = None,
         mode: str = "train",
+        parallel: bool = True,
     ):
         super().__init__()
         self.chunks = []
@@ -177,7 +274,7 @@ class EffectDataset(Dataset):
         self.num_removed_effects = num_removed_effects
         self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
        self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
-        self.normalize =
+        self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
         self.effects = effect_modules
         self.shuffle_kept_effects = shuffle_kept_effects
         self.shuffle_removed_effects = shuffle_removed_effects
@@ -192,6 +289,7 @@ class EffectDataset(Dataset):
         )
         self.validate_effect_input()
         self.proc_root = self.render_root / "processed" / effects_string / self.mode
+        self.parallel = parallel
 
         self.files = locate_files(self.root, self.mode)
 
@@ -212,26 +310,50 @@ class EffectDataset(Dataset):
         if render_files:
             # Split audio file into chunks, resample, then apply random effects
             self.proc_root.mkdir(parents=True, exist_ok=True)
-            for num_chunk in tqdm(range(self.total_chunks)):
-                chunk = None
-                random_dataset_choice = random.choice(self.files)
-                while chunk is None:
-                    random_file_choice = random.choice(random_dataset_choice)
-                    chunk = select_random_chunk(
-                        random_file_choice, self.chunk_size, self.sample_rate
-                    )
-
-                # Sum to mono
-                if chunk.shape[0] > 1:
-                    chunk = chunk.sum(0, keepdim=True)
-
-                dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
-                output_dir = self.proc_root / str(num_chunk)
-                output_dir.mkdir(exist_ok=True)
-                torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
-                torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
-                torch.save(dry_effects, output_dir / "dry_effects.pt")
-                torch.save(wet_effects, output_dir / "wet_effects.pt")
+
+            if self.parallel:
+                items = [
+                    (
+                        chunk_idx,
+                        self.proc_root,
+                        self.files,
+                        self.chunk_size,
+                        self.effects,
+                        self.effects_to_keep,
+                        self.num_kept_effects,
+                        self.shuffle_kept_effects,
+                        self.effects_to_remove,
+                        self.num_removed_effects,
+                        self.shuffle_removed_effects,
+                        self.sample_rate,
+                        -20.0,
+                    )
+                    for chunk_idx in range(self.total_chunks)
+                ]
+                with multiprocessing.Pool(processes=32) as pool:
+                    pool.starmap(parallel_process_effects, items)
+                print(f"Done proccessing {self.total_chunks}", flush=True)
+            else:
+                for num_chunk in tqdm(range(self.total_chunks)):
+                    chunk = None
+                    random_dataset_choice = random.choice(self.files)
+                    while chunk is None:
+                        random_file_choice = random.choice(random_dataset_choice)
+                        chunk = select_random_chunk(
+                            random_file_choice, self.chunk_size, self.sample_rate
+                        )
+
+                    # Sum to mono
+                    if chunk.shape[0] > 1:
+                        chunk = chunk.sum(0, keepdim=True)
+
+                    dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
+                    output_dir = self.proc_root / str(num_chunk)
+                    output_dir.mkdir(exist_ok=True)
+                    torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
+                    torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
+                    torch.save(dry_effects, output_dir / "dry_effects.pt")
+                    torch.save(wet_effects, output_dir / "wet_effects.pt")
 
             print("Finished rendering")
         else:
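The `dry_effects.pt` / `wet_effects.pt` files written above hold multi-hot vectors over `ALL_EFFECTS`. A toy illustration of that encoding, with effect names standing in for the effect classes the real code indexes:

import torch

# Stand-in list; the dataset looks up effect *classes* in ALL_EFFECTS instead.
ALL_EFFECTS = ["Reverb", "Chorus", "Delay", "Distortion", "Compressor"]
applied = ["Delay", "Compressor"]  # effects applied to this chunk

label = torch.zeros(len(ALL_EFFECTS))
for name in applied:
    label[ALL_EFFECTS.index(name)] = 1.0

print(label)  # tensor([0., 0., 1., 0., 1.])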
remfx/models.py
CHANGED
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 import torchmetrics
 import pytorch_lightning as pl
 from torch import Tensor, nn
@@ -424,6 +425,30 @@ class TCNModel(nn.Module):
         return output
 
 
+def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
+    """Mixup data augmentation for time-domain signals.
+    Args:
+        x (torch.Tensor): Batch of time-domain signals, shape [batch, 1, time].
+        y (torch.Tensor): Batch of labels, shape [batch, n_classes].
+        alpha (float): Beta distribution parameter.
+    Returns:
+        torch.Tensor: Mixed time-domain signals, shape [batch, 1, time].
+        torch.Tensor: Mixed labels, shape [batch, n_classes].
+        torch.Tensor: Lambda
+    """
+    batch_size = x.size(0)
+    if alpha > 0:
+        lam = np.random.beta(alpha, alpha)
+    else:
+        lam = 1
+
+    index = torch.randperm(batch_size).to(x.device)
+    mixed_x = lam * x + (1 - lam) * x[index, :]
+    mixed_y = lam * y + (1 - lam) * y[index, :]
+
+    return mixed_x, mixed_y, lam
+
+
 class FXClassifier(pl.LightningModule):
     def __init__(
         self,
@@ -431,13 +456,19 @@ class FXClassifier(pl.LightningModule):
         lr_weight_decay: float,
         sample_rate: float,
         network: nn.Module,
+        mixup: bool = False,
+        label_smoothing: float = 0.0,
     ):
         super().__init__()
         self.lr = lr
         self.lr_weight_decay = lr_weight_decay
         self.sample_rate = sample_rate
         self.network = network
-        self.effects = ["
+        self.effects = ["Reverb", "Chorus", "Delay", "Distortion", "Compressor"]
+        self.mixup = mixup
+        self.label_smoothing = label_smoothing
+
+        self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
 
         self.train_f1 = torchmetrics.classification.MultilabelF1Score(
             5, average="none", multidim_average="global"
@@ -449,20 +480,47 @@
             5, average="none", multidim_average="global"
         )
 
+        self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+        self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+        self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+
         self.metrics = {
             "train": self.train_f1,
             "valid": self.val_f1,
             "test": self.test_f1,
         }
 
+        self.avg_metrics = {
+            "train": self.train_f1_avg,
+            "valid": self.val_f1_avg,
+            "test": self.test_f1_avg,
+        }
+
     def forward(self, x: torch.Tensor, train: bool = False):
-        return self.network(x)
+        return self.network(x, train=train)
 
     def common_step(self, batch, batch_idx, mode: str = "train"):
         train = True if mode == "train" else False
         x, y, dry_label, wet_label = batch
-
-
+
+        if mode == "train" and self.mixup:
+            x_mixed, label_mixed, lam = mixup(x, wet_label)
+            pred_label = self(x_mixed, train)
+            loss = self.loss_fn(pred_label, label_mixed)
+            print(torch.sigmoid(pred_label[0, ...]))
+            print(label_mixed[0, ...])
+        else:
+            pred_label = self(x, train)
+            loss = self.loss_fn(pred_label, wet_label)
+            print(torch.where(torch.sigmoid(pred_label[0, ...]) > 0.5, 1.0, 0.0).long())
+            print(wet_label.long()[0, ...])
+
         self.log(
             f"{mode}_loss",
             loss,
@@ -473,18 +531,7 @@
             sync_dist=True,
         )
 
-        metrics = self.metrics[mode](pred_label, wet_label.long())
-        avg_metrics = torch.mean(metrics)
-
-        self.log(
-            f"{mode}_f1_avg",
-            avg_metrics,
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+        metrics = self.metrics[mode](torch.sigmoid(pred_label), wet_label.long())
 
         for idx, effect_name in enumerate(self.effects):
             self.log(
@@ -497,6 +544,20 @@
             sync_dist=True,
         )
 
+        avg_metrics = self.avg_metrics[mode](
+            torch.sigmoid(pred_label), wet_label.long()
+        )
+
+        self.log(
+            f"{mode}_f1_avg",
+            avg_metrics,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
         return loss
 
     def training_step(self, batch, batch_idx):
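To see what the new `mixup` helper does, here is a tiny self-contained re-run of its core lines on a two-item batch; the values are made up for illustration and are not taken from the repo:

import numpy as np
import torch

np.random.seed(0)
torch.manual_seed(0)

x = torch.tensor([[[1.0, 1.0, 1.0]], [[0.0, 0.0, 0.0]]])  # (batch, 1, time)
y = torch.tensor([[1.0, 0.0], [0.0, 1.0]])                 # multi-hot labels

lam = np.random.beta(1.0, 1.0)        # mixing coefficient drawn from Beta(alpha, alpha)
index = torch.randperm(x.size(0))     # random pairing within the batch
mixed_x = lam * x + (1 - lam) * x[index, :]
mixed_y = lam * y + (1 - lam) * y[index, :]

# Each mixed example is a convex combination of two originals, and the label is
# blended with the same lambda, which is what the training loss then targets.
print(lam, mixed_x.squeeze(1), mixed_y)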