mattricesound committed
Commit b8427f9 · 2 Parent(s): be8f3b0 3c4fcfb

Merge branch 'cjs--classifier-v2' of https://github.com/mhrice/RemFx into classifier-inference

cfg/exp/5-5_cls.yaml CHANGED
@@ -7,7 +7,7 @@ sample_rate: 48000
 chunk_size: 262144 # 5.5s
 logs_dir: "./logs"
 render_files: True
-render_root: "/scratch/EffectSet_cjs"
+render_root: "/scratch/EffectSet_cjs_nobass"
 accelerator: "gpu"
 log_audio: False
 # Effects
@@ -56,4 +56,4 @@ trainer:
   accelerator: ${accelerator}
   devices: 1
   gradient_clip_val: 10.0
-  max_steps: 80000
+  max_steps: 100000
cfg/model/cls_panns_16k.yaml CHANGED
@@ -10,6 +10,6 @@ model:
     n_fft: 2048
     hop_length: 512
     n_mels: 128
-    sample_rate: 44100
+    sample_rate: ${sample_rate}
    model_sample_rate: 16000

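The replaced literal 44100 becomes the OmegaConf interpolation ${sample_rate}, so this model config now tracks the experiment-level sample rate (48000 in cfg/exp/5-5_cls.yaml above) while the resampler still targets 16 kHz. A minimal sketch of how such an interpolation resolves (values illustrative):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "sample_rate": 48000,  # experiment-level key, e.g. from cfg/exp/5-5_cls.yaml
            "model": {"sample_rate": "${sample_rate}", "model_sample_rate": 16000},
        }
    )
    assert cfg.model.sample_rate == 48000  # resolved against the config root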
cfg/model/cls_panns_44k_label_smoothing.yaml ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: True
+  label_smoothing: 0.1
+  network:
+    _target_: remfx.classifier.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 128
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: False
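A note on the new label_smoothing key: FXClassifier (see remfx/models.py below) passes it to torch.nn.CrossEntropyLoss. A minimal sketch of the effect, with illustrative logits and targets:

    import torch

    # label_smoothing: 0.1 redistributes target mass: with 5 classes, each
    # one-hot target becomes (1 - 0.1) * one_hot + 0.1 / 5, i.e. 0.92 on the
    # true class and 0.02 on every other class.
    loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    logits = torch.randn(4, 5)          # batch of 4, 5 effect classes
    target = torch.tensor([0, 3, 2, 1])
    print(loss_fn(logits, target))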
cfg/model/cls_panns_48k.yaml ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: False
+  network:
+    _target_: remfx.classifier.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 128
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: False
+
cfg/model/cls_panns_48k_64.yaml ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: False
+  network:
+    _target_: remfx.classifier.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 64
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: False
+
cfg/model/{cls_panns_44k.yaml → cls_panns_48k_mixup.yaml} RENAMED
@@ -4,12 +4,13 @@ model:
   lr: 3e-4
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
+  mixup: True
   network:
     _target_: remfx.classifier.Cnn14
     num_classes: ${num_classes}
-    n_fft: 1024
-    hop_length: 256
+    n_fft: 2048
+    hop_length: 512
     n_mels: 128
-    sample_rate: 44100
-    model_sample_rate: 44100
-    specaugment: True
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: False
cfg/model/{cls_panns_44k_noaug.yaml → cls_panns_48k_specaugment.yaml} RENAMED
@@ -4,12 +4,13 @@ model:
   lr: 3e-4
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
+  mixup: False
   network:
     _target_: remfx.classifier.Cnn14
     num_classes: ${num_classes}
-    n_fft: 1024
-    hop_length: 256
+    n_fft: 2048
+    hop_length: 512
     n_mels: 128
-    sample_rate: 44100
-    model_sample_rate: 44100
-    specaugment: False
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: True
cfg/model/cls_panns_48k_specaugment_label_smoothing.yaml ADDED
@@ -0,0 +1,17 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: False
+  label_smoothing: 0.15
+  network:
+    _target_: remfx.classifier.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 128
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: True
cfg/model/cls_panns_pt.yaml CHANGED
@@ -4,6 +4,7 @@ model:
   lr: 3e-4
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
+  mixup: False
   network:
     _target_: remfx.classifier.PANNs
     num_classes: ${num_classes}
remfx/classifier.py CHANGED
@@ -33,7 +33,7 @@ class PANNs(torch.nn.Module):
             torch.nn.Linear(hidden_dim, num_classes),
         )

-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, **kwargs):
         with torch.no_grad():
             x = self.resample(x)
         embed = panns_hear.get_scene_embeddings(x.view(x.shape[0], -1), self.model)
@@ -61,7 +61,7 @@ class Wav2CLIP(nn.Module):
             torch.nn.Linear(hidden_dim, num_classes),
         )

-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, **kwargs):
         with torch.no_grad():
             x = self.resample(x)
         embed = wav2clip_hear.get_scene_embeddings(
@@ -91,7 +91,7 @@ class VGGish(nn.Module):
             torch.nn.Linear(hidden_dim, num_classes),
         )

-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, **kwargs):
         with torch.no_grad():
             x = self.resample(x)
         embed = hearbaseline.vggish.get_scene_embeddings(
@@ -121,7 +121,7 @@ class wav2vec2(nn.Module):
             torch.nn.Linear(hidden_dim, num_classes),
        )

-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, **kwargs):
         with torch.no_grad():
             x = self.resample(x)
         embed = hearbaseline.wav2vec2.get_scene_embeddings(
@@ -181,6 +181,10 @@ class Cnn14(nn.Module):
             orig_freq=sample_rate, new_freq=model_sample_rate
         )

+        if self.specaugment:
+            self.freq_mask = torchaudio.transforms.FrequencyMasking(64, True)
+            self.time_mask = torchaudio.transforms.TimeMasking(128, True)
+
     def init_weight(self):
         init_bn(self.bn0)
         init_layer(self.fc1)
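The new Cnn14 branch only constructs the SpecAugment masks; the hunk that applies them in the forward pass is not shown in this diff. A hedged sketch of the usual pattern, gating the random masking on the train flag that the classifiers' forward methods now accept (the names spec and train are illustrative, not from this commit):

    import torch
    import torchaudio

    freq_mask = torchaudio.transforms.FrequencyMasking(64, True)   # iid_masks=True
    time_mask = torchaudio.transforms.TimeMasking(128, True)

    spec = torch.randn(8, 1, 128, 256)  # [batch, channel, n_mels, frames]
    train = True
    if train:  # torchaudio masks always apply, so gate them on the training flag
        spec = time_mask(freq_mask(spec))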
remfx/datasets.py CHANGED
@@ -8,15 +8,16 @@ import pytorch_lightning as pl
 import random
 from tqdm import tqdm
 from pathlib import Path
-from remfx import effects
+from remfx import effects as effect_lib
 from typing import Any, List, Dict
 from torch.utils.data import Dataset, DataLoader
 from remfx.utils import select_random_chunk
+import multiprocessing


 # https://zenodo.org/record/1193957 -> VocalSet

-ALL_EFFECTS = effects.Pedalboard_Effects
+ALL_EFFECTS = effect_lib.Pedalboard_Effects
 # print(ALL_EFFECTS)


@@ -146,6 +147,101 @@ def locate_files(root: str, mode: str):
     return file_list


+def parallel_process_effects(
+    chunk_idx: int,
+    proc_root: str,
+    files: list,
+    chunk_size: int,
+    effects: list,
+    effects_to_keep: list,
+    num_kept_effects: tuple,
+    shuffle_kept_effects: bool,
+    effects_to_remove: list,
+    num_removed_effects: tuple,
+    shuffle_removed_effects: bool,
+    sample_rate: int,
+    target_lufs_db: float,
+):
+    chunk = None
+    random_dataset_choice = random.choice(files)
+    while chunk is None:
+        random_file_choice = random.choice(random_dataset_choice)
+        chunk = select_random_chunk(random_file_choice, chunk_size, sample_rate)
+
+    # Sum to mono
+    if chunk.shape[0] > 1:
+        chunk = chunk.sum(0, keepdim=True)
+
+    dry = chunk
+
+    # loudness normalization
+    normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=target_lufs_db)
+
+    # Apply Kept Effects
+    # Shuffle effects if specified
+    if shuffle_kept_effects:
+        effect_indices = torch.randperm(len(effects_to_keep))
+    else:
+        effect_indices = torch.arange(len(effects_to_keep))
+
+    r1 = num_kept_effects[0]
+    r2 = num_kept_effects[1]
+    num_kept_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+    effect_indices = effect_indices[:num_kept_effects]
+    # Index in effect settings
+    effect_names_to_apply = [effects_to_keep[i] for i in effect_indices]
+    effects_to_apply = [effects[i] for i in effect_names_to_apply]
+    # Apply
+    dry_labels = []
+    for effect in effects_to_apply:
+        # Normalize in-between effects
+        dry = normalize(effect(dry))
+        dry_labels.append(ALL_EFFECTS.index(type(effect)))
+
+    # Apply effects_to_remove
+    # Shuffle effects if specified
+    if shuffle_removed_effects:
+        effect_indices = torch.randperm(len(effects_to_remove))
+    else:
+        effect_indices = torch.arange(len(effects_to_remove))
+    wet = torch.clone(dry)
+    r1 = num_removed_effects[0]
+    r2 = num_removed_effects[1]
+    num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+    effect_indices = effect_indices[:num_removed_effects]
+    # Index in effect settings
+    effect_names_to_apply = [effects_to_remove[i] for i in effect_indices]
+    effects_to_apply = [effects[i] for i in effect_names_to_apply]
+    # Apply
+    wet_labels = []
+    for effect in effects_to_apply:
+        # Normalize in-between effects
+        wet = normalize(effect(wet))
+        wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+    wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+    dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+    for label_idx in wet_labels:
+        wet_labels_tensor[label_idx] = 1.0
+
+    for label_idx in dry_labels:
+        dry_labels_tensor[label_idx] = 1.0
+
+    # Normalize
+    normalized_dry = normalize(dry)
+    normalized_wet = normalize(wet)
+
+    output_dir = proc_root / str(chunk_idx)
+    output_dir.mkdir(exist_ok=True)
+    torchaudio.save(output_dir / "input.wav", normalized_wet, sample_rate)
+    torchaudio.save(output_dir / "target.wav", normalized_dry, sample_rate)
+    torch.save(dry_labels_tensor, output_dir / "dry_effects.pt")
+    torch.save(wet_labels_tensor, output_dir / "wet_effects.pt")
+
+    # return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
+
+
 class EffectDataset(Dataset):
     def __init__(
         self,
@@ -163,6 +259,7 @@ class EffectDataset(Dataset):
         render_files: bool = True,
         render_root: str = None,
         mode: str = "train",
+        parallel: bool = True,
     ):
         super().__init__()
         self.chunks = []
@@ -177,7 +274,7 @@ class EffectDataset(Dataset):
         self.num_removed_effects = num_removed_effects
         self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
         self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
-        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+        self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
         self.effects = effect_modules
         self.shuffle_kept_effects = shuffle_kept_effects
         self.shuffle_removed_effects = shuffle_removed_effects
@@ -192,6 +289,7 @@ class EffectDataset(Dataset):
         )
         self.validate_effect_input()
         self.proc_root = self.render_root / "processed" / effects_string / self.mode
+        self.parallel = parallel

         self.files = locate_files(self.root, self.mode)

@@ -212,26 +310,50 @@ class EffectDataset(Dataset):
         if render_files:
             # Split audio file into chunks, resample, then apply random effects
             self.proc_root.mkdir(parents=True, exist_ok=True)
-            for num_chunk in tqdm(range(self.total_chunks)):
-                chunk = None
-                random_dataset_choice = random.choice(self.files)
-                while chunk is None:
-                    random_file_choice = random.choice(random_dataset_choice)
-                    chunk = select_random_chunk(
-                        random_file_choice, self.chunk_size, self.sample_rate
-                    )
-
-                # Sum to mono
-                if chunk.shape[0] > 1:
-                    chunk = chunk.sum(0, keepdim=True)
-
-                dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
-                output_dir = self.proc_root / str(num_chunk)
-                output_dir.mkdir(exist_ok=True)
-                torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
-                torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
-                torch.save(dry_effects, output_dir / "dry_effects.pt")
-                torch.save(wet_effects, output_dir / "wet_effects.pt")
+
+            if self.parallel:
+                items = [
+                    (
+                        chunk_idx,
+                        self.proc_root,
+                        self.files,
+                        self.chunk_size,
+                        self.effects,
+                        self.effects_to_keep,
+                        self.num_kept_effects,
+                        self.shuffle_kept_effects,
+                        self.effects_to_remove,
+                        self.num_removed_effects,
+                        self.shuffle_removed_effects,
+                        self.sample_rate,
+                        -20.0,
+                    )
+                    for chunk_idx in range(self.total_chunks)
+                ]
+                with multiprocessing.Pool(processes=32) as pool:
+                    pool.starmap(parallel_process_effects, items)
+                print(f"Done processing {self.total_chunks}", flush=True)
+            else:
+                for num_chunk in tqdm(range(self.total_chunks)):
+                    chunk = None
+                    random_dataset_choice = random.choice(self.files)
+                    while chunk is None:
+                        random_file_choice = random.choice(random_dataset_choice)
+                        chunk = select_random_chunk(
+                            random_file_choice, self.chunk_size, self.sample_rate
+                        )
+
+                    # Sum to mono
+                    if chunk.shape[0] > 1:
+                        chunk = chunk.sum(0, keepdim=True)
+
+                    dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
+                    output_dir = self.proc_root / str(num_chunk)
+                    output_dir.mkdir(exist_ok=True)
+                    torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
+                    torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
+                    torch.save(dry_effects, output_dir / "dry_effects.pt")
+                    torch.save(wet_effects, output_dir / "wet_effects.pt")

             print("Finished rendering")
         else:
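For reference, the fan-out pattern used by the new parallel path: one argument tuple per chunk, unpacked by multiprocessing.Pool.starmap into the worker function. A self-contained toy version (render and its arguments are stand-ins for parallel_process_effects, not part of this commit):

    import multiprocessing


    def render(chunk_idx: int, out_root: str, sample_rate: int) -> str:
        # stand-in worker; the real one renders and saves one dataset chunk
        return f"{out_root}/{chunk_idx} @ {sample_rate} Hz"


    if __name__ == "__main__":
        items = [(i, "/tmp/render", 48000) for i in range(8)]
        with multiprocessing.Pool(processes=4) as pool:
            print(pool.starmap(render, items))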
remfx/models.py CHANGED
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 import torchmetrics
 import pytorch_lightning as pl
 from torch import Tensor, nn
@@ -424,6 +425,30 @@ class TCNModel(nn.Module):
         return output


+def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
+    """Mixup data augmentation for time-domain signals.
+    Args:
+        x (torch.Tensor): Batch of time-domain signals, shape [batch, 1, time].
+        y (torch.Tensor): Batch of labels, shape [batch, n_classes].
+        alpha (float): Beta distribution parameter.
+    Returns:
+        torch.Tensor: Mixed time-domain signals, shape [batch, 1, time].
+        torch.Tensor: Mixed labels, shape [batch, n_classes].
+        torch.Tensor: Lambda
+    """
+    batch_size = x.size(0)
+    if alpha > 0:
+        lam = np.random.beta(alpha, alpha)
+    else:
+        lam = 1
+
+    index = torch.randperm(batch_size).to(x.device)
+    mixed_x = lam * x + (1 - lam) * x[index, :]
+    mixed_y = lam * y + (1 - lam) * y[index, :]
+
+    return mixed_x, mixed_y, lam
+
+
 class FXClassifier(pl.LightningModule):
     def __init__(
         self,
@@ -431,13 +456,19 @@ class FXClassifier(pl.LightningModule):
         lr_weight_decay: float,
         sample_rate: float,
         network: nn.Module,
+        mixup: bool = False,
+        label_smoothing: float = 0.0,
     ):
         super().__init__()
         self.lr = lr
         self.lr_weight_decay = lr_weight_decay
         self.sample_rate = sample_rate
         self.network = network
-        self.effects = ["distortion", "compressor", "reverb", "chorus", "delay"]
+        self.effects = ["Reverb", "Chorus", "Delay", "Distortion", "Compressor"]
+        self.mixup = mixup
+        self.label_smoothing = label_smoothing
+
+        self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)

         self.train_f1 = torchmetrics.classification.MultilabelF1Score(
             5, average="none", multidim_average="global"
@@ -449,20 +480,47 @@ class FXClassifier(pl.LightningModule):
             5, average="none", multidim_average="global"
         )

+        self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+        self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+        self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
+            5, threshold=0.5, average="macro", multidim_average="global"
+        )
+
         self.metrics = {
             "train": self.train_f1,
             "valid": self.val_f1,
             "test": self.test_f1,
         }

+        self.avg_metrics = {
+            "train": self.train_f1_avg,
+            "valid": self.val_f1_avg,
+            "test": self.test_f1_avg,
+        }
+
     def forward(self, x: torch.Tensor, train: bool = False):
-        return self.network(x)
+        return self.network(x, train=train)

     def common_step(self, batch, batch_idx, mode: str = "train"):
         train = True if mode == "train" else False
         x, y, dry_label, wet_label = batch
-        pred_label = self(x, train)
-        loss = nn.functional.cross_entropy(pred_label, wet_label)
+
+        if mode == "train" and self.mixup:
+            x_mixed, label_mixed, lam = mixup(x, wet_label)
+            pred_label = self(x_mixed, train)
+            loss = self.loss_fn(pred_label, label_mixed)
+            print(torch.sigmoid(pred_label[0, ...]))
+            print(label_mixed[0, ...])
+        else:
+            pred_label = self(x, train)
+            loss = self.loss_fn(pred_label, wet_label)
+            print(torch.where(torch.sigmoid(pred_label[0, ...]) > 0.5, 1.0, 0.0).long())
+            print(wet_label.long()[0, ...])
+
         self.log(
             f"{mode}_loss",
             loss,
@@ -473,18 +531,7 @@ class FXClassifier(pl.LightningModule):
             sync_dist=True,
         )

-        metrics = self.metrics[mode](pred_label, wet_label.long())
-        avg_metrics = torch.mean(metrics)
-
-        self.log(
-            f"{mode}_f1_avg",
-            avg_metrics,
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+        metrics = self.metrics[mode](torch.sigmoid(pred_label), wet_label.long())

         for idx, effect_name in enumerate(self.effects):
             self.log(
@@ -497,6 +544,20 @@ class FXClassifier(pl.LightningModule):
             sync_dist=True,
         )

+        avg_metrics = self.avg_metrics[mode](
+            torch.sigmoid(pred_label), wet_label.long()
+        )
+
+        self.log(
+            f"{mode}_f1_avg",
+            avg_metrics,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
         return loss

     def training_step(self, batch, batch_idx):
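An illustrative call of the mixup helper added above, assuming remfx.models is importable; shapes follow its docstring ([batch, 1, time] signals, [batch, n_classes] multi-hot labels):

    import torch
    from remfx.models import mixup  # module-level helper added in this commit

    x = torch.randn(4, 1, 262144)            # four 5.5 s chunks at 48 kHz
    y = torch.randint(0, 2, (4, 5)).float()  # multi-hot labels for 5 effects
    mixed_x, mixed_y, lam = mixup(x, y, alpha=1.0)  # lam ~ Beta(1, 1), i.e. uniform
    assert mixed_x.shape == x.shape and mixed_y.shape == y.shape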