Christian J. Steinmetz committed
Commit: 1b6bb59 · Parent(s): c756b1d

adding mixup augmentation and fixing ordering of labels
cfg/model/cls_panns_16k_mixup.yaml
ADDED
@@ -0,0 +1,16 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: True
+  network:
+    _target_: remfx.classifier.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 128
+    sample_rate: 44100
+    model_sample_rate: 16000
+
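For context, these cfg/model files are Hydra configs: _target_ names the class to instantiate and the sibling keys become constructor arguments, with ${...} values interpolated from the top-level config. A minimal sketch of how such a file is turned into objects, assuming the remfx package is importable (the repo's actual training entry point may wire this up differently):

# Hypothetical sketch, not the repo's entry point: shows Hydra's recursive
# instantiation of the nested _target_ fields in cls_panns_16k_mixup.yaml.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    sample_rate: 16000
    num_classes: 5
    model:
      _target_: remfx.models.FXClassifier
      lr: 3e-4
      lr_weight_decay: 1e-3
      sample_rate: ${sample_rate}
      mixup: True
      network:
        _target_: remfx.classifier.Cnn14
        num_classes: ${num_classes}
        n_fft: 2048
        hop_length: 512
        n_mels: 128
        sample_rate: 44100
        model_sample_rate: 16000
    """
)
# Builds the inner Cnn14 first, then passes it as `network` to FXClassifier.
model = instantiate(cfg.model)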
cfg/model/cls_panns_44k.yaml
CHANGED
@@ -4,12 +4,14 @@ model:
   lr: 3e-4
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
+  mixup: False
   network:
     _target_: remfx.classifier.Cnn14
     num_classes: ${num_classes}
-    n_fft:
-    hop_length:
+    n_fft: 2048
+    hop_length: 512
     n_mels: 128
     sample_rate: 44100
     model_sample_rate: 44100
-    specaugment:
+    specaugment: False
+
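Because mixup is now an ordinary config field (and these files use # @package _global_, placing model: at the config root), it can be flipped per run with a standard Hydra command-line override instead of editing the YAML. A hypothetical invocation; the entry-point script name is assumed, not taken from the repo:

python train.py model=cls_panns_44k model.mixup=True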
cfg/model/{cls_panns_44k_noaug.yaml → cls_panns_44k_mixup.yaml}
RENAMED
@@ -4,12 +4,14 @@ model:
   lr: 3e-4
   lr_weight_decay: 1e-3
   sample_rate: ${sample_rate}
+  mixup: True
   network:
     _target_: remfx.classifier.Cnn14
     num_classes: ${num_classes}
-    n_fft:
-    hop_length:
-    n_mels:
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 64
     sample_rate: 44100
     model_sample_rate: 44100
-    specaugment: False
+    specaugment: False
+
cfg/model/cls_panns_pt_mixup.yaml
ADDED
@@ -0,0 +1,12 @@
+# @package _global_
+model:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: True
+  network:
+    _target_: remfx.classifier.PANNs
+    num_classes: ${num_classes}
+    sample_rate: ${sample_rate}
+
remfx/models.py
CHANGED
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 import torchmetrics
 import pytorch_lightning as pl
 from torch import Tensor, nn
@@ -409,6 +410,30 @@ class TCNModel(nn.Module):
         return output


+def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
+    """Mixup data augmentation for time-domain signals.
+    Args:
+        x (torch.Tensor): Batch of time-domain signals, shape [batch, 1, time].
+        y (torch.Tensor): Batch of labels, shape [batch, n_classes].
+        alpha (float): Beta distribution parameter.
+    Returns:
+        torch.Tensor: Mixed time-domain signals, shape [batch, 1, time].
+        torch.Tensor: Mixed labels, shape [batch, n_classes].
+        torch.Tensor: Lambda
+    """
+    batch_size = x.size(0)
+    if alpha > 0:
+        lam = np.random.beta(alpha, alpha)
+    else:
+        lam = 1
+
+    index = torch.randperm(batch_size).to(x.device)
+    mixed_x = lam * x + (1 - lam) * x[index, :]
+    mixed_y = lam * y + (1 - lam) * y[index, :]
+
+    return mixed_x, mixed_y, lam
+
+
 class FXClassifier(pl.LightningModule):
     def __init__(
         self,
@@ -416,13 +441,15 @@ class FXClassifier(pl.LightningModule):
         lr_weight_decay: float,
         sample_rate: float,
         network: nn.Module,
+        mixup: bool = False,
     ):
         super().__init__()
         self.lr = lr
         self.lr_weight_decay = lr_weight_decay
         self.sample_rate = sample_rate
         self.network = network
-        self.effects = ["
+        self.effects = ["Reverb", "Chorus", "Delay", "Distortion", "Compressor"]
+        self.mixup = mixup

         self.train_f1 = torchmetrics.classification.MultilabelF1Score(
             5, average="none", multidim_average="global"
@@ -441,13 +468,24 @@ class FXClassifier(pl.LightningModule):
         }

     def forward(self, x: torch.Tensor, train: bool = False):
-        return self.network(x)
+        return self.network(x, train=train)

     def common_step(self, batch, batch_idx, mode: str = "train"):
         train = True if mode == "train" else False
         x, y, dry_label, wet_label = batch
-
-
+
+        if mode == "train" and self.mixup:
+            x_mixed, label_mixed, lam = mixup(x, wet_label)
+            pred_label = self(x_mixed, train)
+            loss = nn.functional.cross_entropy(pred_label, label_mixed)
+            print(torch.sigmoid(pred_label[0, ...]))
+            print(label_mixed[0, ...])
+        else:
+            pred_label = self(x, train)
+            loss = nn.functional.cross_entropy(pred_label, wet_label)
+            print(torch.where(torch.sigmoid(pred_label[0, ...]) > 0.5, 1.0, 0.0).long())
+            print(wet_label.long()[0, ...])
+
         self.log(
             f"{mode}_loss",
             loss,
@@ -458,7 +496,7 @@ class FXClassifier(pl.LightningModule):
             sync_dist=True,
         )

-        metrics = self.metrics[mode](pred_label, wet_label.long())
+        metrics = self.metrics[mode](torch.sigmoid(pred_label), wet_label.long())
         avg_metrics = torch.mean(metrics)

         self.log(
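As a quick sanity check of the new mixup helper, the sketch below (self-contained, with the helper's logic restated) confirms that mixed batches keep their shapes and that labels become convex combinations weighted by lam; the batch size, signal length, and class count are made up for illustration. Note that passing such soft targets to nn.functional.cross_entropy, as the new training branch does, requires PyTorch 1.10 or newer.

import numpy as np
import torch

def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    # Same logic as the committed helper: a single lambda per batch, drawn
    # from Beta(alpha, alpha), mixes each example with a random partner.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    index = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    mixed_y = lam * y + (1 - lam) * y[index, :]
    return mixed_x, mixed_y, lam

x = torch.randn(4, 1, 16000)             # 4 mono signals, 1 s at 16 kHz
y = torch.randint(0, 2, (4, 5)).float()  # multi-label targets, 5 effect classes
mx, my, lam = mixup(x, y)
assert mx.shape == x.shape and my.shape == y.shape
print(lam, my[0])                        # soft labels in [0, 1], weighted by lam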