Christian J. Steinmetz commited on
Commit
1b6bb59
·
1 Parent(s): c756b1d

adding mixup augmentation and fixing ordering of labels

Browse files
cfg/model/cls_panns_16k_mixup.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ model:
3
+ _target_: remfx.models.FXClassifier
4
+ lr: 3e-4
5
+ lr_weight_decay: 1e-3
6
+ sample_rate: ${sample_rate}
7
+ mixup: True
8
+ network:
9
+ _target_: remfx.classifier.Cnn14
10
+ num_classes: ${num_classes}
11
+ n_fft: 2048
12
+ hop_length: 512
13
+ n_mels: 128
14
+ sample_rate: 44100
15
+ model_sample_rate: 16000
16
+
cfg/model/cls_panns_44k.yaml CHANGED
@@ -4,12 +4,14 @@ model:
4
  lr: 3e-4
5
  lr_weight_decay: 1e-3
6
  sample_rate: ${sample_rate}
 
7
  network:
8
  _target_: remfx.classifier.Cnn14
9
  num_classes: ${num_classes}
10
- n_fft: 1024
11
- hop_length: 256
12
  n_mels: 128
13
  sample_rate: 44100
14
  model_sample_rate: 44100
15
- specaugment: True
 
 
4
  lr: 3e-4
5
  lr_weight_decay: 1e-3
6
  sample_rate: ${sample_rate}
7
+ mixup: False
8
  network:
9
  _target_: remfx.classifier.Cnn14
10
  num_classes: ${num_classes}
11
+ n_fft: 2048
12
+ hop_length: 512
13
  n_mels: 128
14
  sample_rate: 44100
15
  model_sample_rate: 44100
16
+ specaugment: False
17
+
cfg/model/{cls_panns_44k_noaug.yaml → cls_panns_44k_mixup.yaml} RENAMED
@@ -4,12 +4,14 @@ model:
4
  lr: 3e-4
5
  lr_weight_decay: 1e-3
6
  sample_rate: ${sample_rate}
 
7
  network:
8
  _target_: remfx.classifier.Cnn14
9
  num_classes: ${num_classes}
10
- n_fft: 1024
11
- hop_length: 256
12
- n_mels: 128
13
  sample_rate: 44100
14
  model_sample_rate: 44100
15
- specaugment: False
 
 
4
  lr: 3e-4
5
  lr_weight_decay: 1e-3
6
  sample_rate: ${sample_rate}
7
+ mixup: True
8
  network:
9
  _target_: remfx.classifier.Cnn14
10
  num_classes: ${num_classes}
11
+ n_fft: 2048
12
+ hop_length: 512
13
+ n_mels: 64
14
  sample_rate: 44100
15
  model_sample_rate: 44100
16
+ specaugment: False
17
+
cfg/model/cls_panns_pt_mixup.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ model:
3
+ _target_: remfx.models.FXClassifier
4
+ lr: 3e-4
5
+ lr_weight_decay: 1e-3
6
+ sample_rate: ${sample_rate}
7
+ mixup: True
8
+ network:
9
+ _target_: remfx.classifier.PANNs
10
+ num_classes: ${num_classes}
11
+ sample_rate: ${sample_rate}
12
+
remfx/models.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  import torchmetrics
3
  import pytorch_lightning as pl
4
  from torch import Tensor, nn
@@ -409,6 +410,30 @@ class TCNModel(nn.Module):
409
  return output
410
 
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  class FXClassifier(pl.LightningModule):
413
  def __init__(
414
  self,
@@ -416,13 +441,15 @@ class FXClassifier(pl.LightningModule):
416
  lr_weight_decay: float,
417
  sample_rate: float,
418
  network: nn.Module,
 
419
  ):
420
  super().__init__()
421
  self.lr = lr
422
  self.lr_weight_decay = lr_weight_decay
423
  self.sample_rate = sample_rate
424
  self.network = network
425
- self.effects = ["distortion", "compressor", "reverb", "chorus", "delay"]
 
426
 
427
  self.train_f1 = torchmetrics.classification.MultilabelF1Score(
428
  5, average="none", multidim_average="global"
@@ -441,13 +468,24 @@ class FXClassifier(pl.LightningModule):
441
  }
442
 
443
  def forward(self, x: torch.Tensor, train: bool = False):
444
- return self.network(x)
445
 
446
  def common_step(self, batch, batch_idx, mode: str = "train"):
447
  train = True if mode == "train" else False
448
  x, y, dry_label, wet_label = batch
449
- pred_label = self(x, train)
450
- loss = nn.functional.cross_entropy(pred_label, wet_label)
 
 
 
 
 
 
 
 
 
 
 
451
  self.log(
452
  f"{mode}_loss",
453
  loss,
@@ -458,7 +496,7 @@ class FXClassifier(pl.LightningModule):
458
  sync_dist=True,
459
  )
460
 
461
- metrics = self.metrics[mode](pred_label, wet_label.long())
462
  avg_metrics = torch.mean(metrics)
463
 
464
  self.log(
 
1
  import torch
2
+ import numpy as np
3
  import torchmetrics
4
  import pytorch_lightning as pl
5
  from torch import Tensor, nn
 
410
  return output
411
 
412
 
413
+ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
414
+ """Mixup data augmentation for time-domain signals.
415
+ Args:
416
+ x (torch.Tensor): Batch of time-domain signals, shape [batch, 1, time].
417
+ y (torch.Tensor): Batch of labels, shape [batch, n_classes].
418
+ alpha (float): Beta distribution parameter.
419
+ Returns:
420
+ torch.Tensor: Mixed time-domain signals, shape [batch, 1, time].
421
+ torch.Tensor: Mixed labels, shape [batch, n_classes].
422
+ torch.Tensor: Lambda
423
+ """
424
+ batch_size = x.size(0)
425
+ if alpha > 0:
426
+ lam = np.random.beta(alpha, alpha)
427
+ else:
428
+ lam = 1
429
+
430
+ index = torch.randperm(batch_size).to(x.device)
431
+ mixed_x = lam * x + (1 - lam) * x[index, :]
432
+ mixed_y = lam * y + (1 - lam) * y[index, :]
433
+
434
+ return mixed_x, mixed_y, lam
435
+
436
+
437
  class FXClassifier(pl.LightningModule):
438
  def __init__(
439
  self,
 
441
  lr_weight_decay: float,
442
  sample_rate: float,
443
  network: nn.Module,
444
+ mixup: bool = False,
445
  ):
446
  super().__init__()
447
  self.lr = lr
448
  self.lr_weight_decay = lr_weight_decay
449
  self.sample_rate = sample_rate
450
  self.network = network
451
+ self.effects = ["Reverb", "Chorus", "Delay", "Distortion", "Compressor"]
452
+ self.mixup = mixup
453
 
454
  self.train_f1 = torchmetrics.classification.MultilabelF1Score(
455
  5, average="none", multidim_average="global"
 
468
  }
469
 
470
  def forward(self, x: torch.Tensor, train: bool = False):
471
+ return self.network(x, train=train)
472
 
473
  def common_step(self, batch, batch_idx, mode: str = "train"):
474
  train = True if mode == "train" else False
475
  x, y, dry_label, wet_label = batch
476
+
477
+ if mode == "train" and self.mixup:
478
+ x_mixed, label_mixed, lam = mixup(x, wet_label)
479
+ pred_label = self(x_mixed, train)
480
+ loss = nn.functional.cross_entropy(pred_label, label_mixed)
481
+ print(torch.sigmoid(pred_label[0, ...]))
482
+ print(label_mixed[0, ...])
483
+ else:
484
+ pred_label = self(x, train)
485
+ loss = nn.functional.cross_entropy(pred_label, wet_label)
486
+ print(torch.where(torch.sigmoid(pred_label[0, ...]) > 0.5, 1.0, 0.0).long())
487
+ print(wet_label.long()[0, ...])
488
+
489
  self.log(
490
  f"{mode}_loss",
491
  loss,
 
496
  sync_dist=True,
497
  )
498
 
499
+ metrics = self.metrics[mode](torch.sigmoid(pred_label), wet_label.long())
500
  avg_metrics = torch.mean(metrics)
501
 
502
  self.log(