mattricesound commited on
Commit
32701af
·
1 Parent(s): 568c3f1

Fix classifier for non-panns models

Browse files
Files changed (1) hide show
  1. remfx/models.py +95 -29
remfx/models.py CHANGED
@@ -12,6 +12,7 @@ from remfx.utils import spectrogram
12
  from remfx.tcn import TCN
13
  from remfx.utils import causal_crop
14
  from remfx import effects
 
15
  import asteroid
16
  import random
17
 
@@ -438,19 +439,54 @@ class FXClassifier(pl.LightningModule):
438
  self.mixup = mixup
439
  self.label_smoothing = label_smoothing
440
 
441
- self.loss_fn = torch.nn.BCELoss()
442
- self.metrics = torch.nn.ModuleDict()
443
- for effect in self.effects:
444
- self.metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(
445
- task="binary"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  )
447
- self.metrics[f"valid_{effect}_acc"] = torchmetrics.classification.Accuracy(
448
- task="binary"
 
 
 
 
449
  )
450
- self.metrics[f"test_{effect}_acc"] = torchmetrics.classification.Accuracy(
451
- task="binary"
452
  )
453
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  def forward(self, x: torch.Tensor, train: bool = False):
455
  return self.network(x, train=train)
456
 
@@ -467,8 +503,13 @@ class FXClassifier(pl.LightningModule):
467
  else:
468
  outputs = self(x, train)
469
  loss = 0
470
- for idx, output in enumerate(outputs):
471
- loss += self.loss_fn(output.squeeze(-1), wet_label[..., idx])
 
 
 
 
 
472
 
473
  self.log(
474
  f"{mode}_loss",
@@ -480,32 +521,57 @@ class FXClassifier(pl.LightningModule):
480
  sync_dist=True,
481
  )
482
 
483
- acc_metrics = []
484
- for idx, effect_name in enumerate(self.effects):
485
- acc_metric = self.metrics[f"{mode}_{effect_name}_acc"](
486
- outputs[idx].squeeze(-1), wet_label[..., idx]
487
- )
 
 
 
 
 
 
 
 
 
 
 
 
488
  self.log(
489
- f"{mode}_{effect_name}_acc",
490
- acc_metric,
491
  on_step=True,
492
  on_epoch=True,
493
  prog_bar=True,
494
  logger=True,
495
  sync_dist=True,
496
  )
497
- acc_metrics.append(acc_metric)
498
-
499
- self.log(
500
- f"{mode}_avg_acc",
501
- torch.mean(torch.stack(acc_metrics)),
502
- on_step=True,
503
- on_epoch=True,
504
- prog_bar=True,
505
- logger=True,
506
- sync_dist=True,
507
- )
 
 
 
 
508
 
 
 
 
 
 
 
 
 
 
509
  return loss
510
 
511
  def training_step(self, batch, batch_idx):
 
12
  from remfx.tcn import TCN
13
  from remfx.utils import causal_crop
14
  from remfx import effects
15
+ from remfx.classifier import Cnn14
16
  import asteroid
17
  import random
18
 
 
439
  self.mixup = mixup
440
  self.label_smoothing = label_smoothing
441
 
442
+ if isinstance(self.network, Cnn14):
443
+ self.loss_fn = torch.nn.BCELoss()
444
+
445
+ self.metrics = torch.nn.ModuleDict()
446
+ for effect in self.effects:
447
+ self.metrics[
448
+ f"train_{effect}_acc"
449
+ ] = torchmetrics.classification.Accuracy(task="binary")
450
+ self.metrics[
451
+ f"valid_{effect}_acc"
452
+ ] = torchmetrics.classification.Accuracy(task="binary")
453
+ self.metrics[
454
+ f"test_{effect}_acc"
455
+ ] = torchmetrics.classification.Accuracy(task="binary")
456
+ else:
457
+ self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
458
+ self.train_f1 = torchmetrics.classification.MultilabelF1Score(
459
+ 5, average="none", multidim_average="global"
460
+ )
461
+ self.val_f1 = torchmetrics.classification.MultilabelF1Score(
462
+ 5, average="none", multidim_average="global"
463
+ )
464
+ self.test_f1 = torchmetrics.classification.MultilabelF1Score(
465
+ 5, average="none", multidim_average="global"
466
  )
467
+
468
+ self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
469
+ 5, threshold=0.5, average="macro", multidim_average="global"
470
+ )
471
+ self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
472
+ 5, threshold=0.5, average="macro", multidim_average="global"
473
  )
474
+ self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
475
+ 5, threshold=0.5, average="macro", multidim_average="global"
476
  )
477
 
478
+ self.metrics = {
479
+ "train": self.train_f1,
480
+ "valid": self.val_f1,
481
+ "test": self.test_f1,
482
+ }
483
+
484
+ self.avg_metrics = {
485
+ "train": self.train_f1_avg,
486
+ "valid": self.val_f1_avg,
487
+ "test": self.test_f1_avg,
488
+ }
489
+
490
  def forward(self, x: torch.Tensor, train: bool = False):
491
  return self.network(x, train=train)
492
 
 
503
  else:
504
  outputs = self(x, train)
505
  loss = 0
506
+ # Multi-head binary loss
507
+ if isinstance(self.network, Cnn14):
508
+ for idx, output in enumerate(outputs):
509
+ loss += self.loss_fn(output.squeeze(-1), wet_label[..., idx])
510
+ else:
511
+ # Output is a 2d tensor
512
+ loss = self.loss_fn(outputs, wet_label)
513
 
514
  self.log(
515
  f"{mode}_loss",
 
521
  sync_dist=True,
522
  )
523
 
524
+ if isinstance(self.network, Cnn14):
525
+ acc_metrics = []
526
+ for idx, effect_name in enumerate(self.effects):
527
+ acc_metric = self.metrics[f"{mode}_{effect_name}_acc"](
528
+ outputs[idx].squeeze(-1), wet_label[..., idx]
529
+ )
530
+ self.log(
531
+ f"{mode}_{effect_name}_acc",
532
+ acc_metric,
533
+ on_step=True,
534
+ on_epoch=True,
535
+ prog_bar=True,
536
+ logger=True,
537
+ sync_dist=True,
538
+ )
539
+ acc_metrics.append(acc_metric)
540
+
541
  self.log(
542
+ f"{mode}_avg_acc",
543
+ torch.mean(torch.stack(acc_metrics)),
544
  on_step=True,
545
  on_epoch=True,
546
  prog_bar=True,
547
  logger=True,
548
  sync_dist=True,
549
  )
550
+ else:
551
+ metrics = self.metrics[mode](torch.sigmoid(outputs), wet_label.long())
552
+ for idx, effect_name in enumerate(self.effects):
553
+ self.log(
554
+ f"{mode}_f1_{effect_name}",
555
+ metrics[idx],
556
+ on_step=True,
557
+ on_epoch=True,
558
+ prog_bar=True,
559
+ logger=True,
560
+ sync_dist=True,
561
+ )
562
+ avg_metrics = self.avg_metrics[mode](
563
+ torch.sigmoid(outputs), wet_label.long()
564
+ )
565
 
566
+ self.log(
567
+ f"{mode}_avg_acc",
568
+ avg_metrics,
569
+ on_step=True,
570
+ on_epoch=True,
571
+ prog_bar=True,
572
+ logger=True,
573
+ sync_dist=True,
574
+ )
575
  return loss
576
 
577
  def training_step(self, batch, batch_idx):