Spaces:
Runtime error
Runtime error
Commit
·
32701af
1
Parent(s):
568c3f1
Fix classifier for non-panns models
Browse files- remfx/models.py +95 -29
remfx/models.py
CHANGED
@@ -12,6 +12,7 @@ from remfx.utils import spectrogram
|
|
12 |
from remfx.tcn import TCN
|
13 |
from remfx.utils import causal_crop
|
14 |
from remfx import effects
|
|
|
15 |
import asteroid
|
16 |
import random
|
17 |
|
@@ -438,19 +439,54 @@ class FXClassifier(pl.LightningModule):
|
|
438 |
self.mixup = mixup
|
439 |
self.label_smoothing = label_smoothing
|
440 |
|
441 |
-
self.
|
442 |
-
|
443 |
-
|
444 |
-
self.metrics
|
445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
446 |
)
|
447 |
-
|
448 |
-
|
|
|
|
|
|
|
|
|
449 |
)
|
450 |
-
self.
|
451 |
-
|
452 |
)
|
453 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
454 |
def forward(self, x: torch.Tensor, train: bool = False):
|
455 |
return self.network(x, train=train)
|
456 |
|
@@ -467,8 +503,13 @@ class FXClassifier(pl.LightningModule):
|
|
467 |
else:
|
468 |
outputs = self(x, train)
|
469 |
loss = 0
|
470 |
-
|
471 |
-
|
|
|
|
|
|
|
|
|
|
|
472 |
|
473 |
self.log(
|
474 |
f"{mode}_loss",
|
@@ -480,32 +521,57 @@ class FXClassifier(pl.LightningModule):
|
|
480 |
sync_dist=True,
|
481 |
)
|
482 |
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
488 |
self.log(
|
489 |
-
f"{mode}
|
490 |
-
|
491 |
on_step=True,
|
492 |
on_epoch=True,
|
493 |
prog_bar=True,
|
494 |
logger=True,
|
495 |
sync_dist=True,
|
496 |
)
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
|
|
|
|
|
|
|
|
508 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
509 |
return loss
|
510 |
|
511 |
def training_step(self, batch, batch_idx):
|
|
|
12 |
from remfx.tcn import TCN
|
13 |
from remfx.utils import causal_crop
|
14 |
from remfx import effects
|
15 |
+
from remfx.classifier import Cnn14
|
16 |
import asteroid
|
17 |
import random
|
18 |
|
|
|
439 |
self.mixup = mixup
|
440 |
self.label_smoothing = label_smoothing
|
441 |
|
442 |
+
if isinstance(self.network, Cnn14):
|
443 |
+
self.loss_fn = torch.nn.BCELoss()
|
444 |
+
|
445 |
+
self.metrics = torch.nn.ModuleDict()
|
446 |
+
for effect in self.effects:
|
447 |
+
self.metrics[
|
448 |
+
f"train_{effect}_acc"
|
449 |
+
] = torchmetrics.classification.Accuracy(task="binary")
|
450 |
+
self.metrics[
|
451 |
+
f"valid_{effect}_acc"
|
452 |
+
] = torchmetrics.classification.Accuracy(task="binary")
|
453 |
+
self.metrics[
|
454 |
+
f"test_{effect}_acc"
|
455 |
+
] = torchmetrics.classification.Accuracy(task="binary")
|
456 |
+
else:
|
457 |
+
self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
458 |
+
self.train_f1 = torchmetrics.classification.MultilabelF1Score(
|
459 |
+
5, average="none", multidim_average="global"
|
460 |
+
)
|
461 |
+
self.val_f1 = torchmetrics.classification.MultilabelF1Score(
|
462 |
+
5, average="none", multidim_average="global"
|
463 |
+
)
|
464 |
+
self.test_f1 = torchmetrics.classification.MultilabelF1Score(
|
465 |
+
5, average="none", multidim_average="global"
|
466 |
)
|
467 |
+
|
468 |
+
self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
|
469 |
+
5, threshold=0.5, average="macro", multidim_average="global"
|
470 |
+
)
|
471 |
+
self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
|
472 |
+
5, threshold=0.5, average="macro", multidim_average="global"
|
473 |
)
|
474 |
+
self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
|
475 |
+
5, threshold=0.5, average="macro", multidim_average="global"
|
476 |
)
|
477 |
|
478 |
+
self.metrics = {
|
479 |
+
"train": self.train_f1,
|
480 |
+
"valid": self.val_f1,
|
481 |
+
"test": self.test_f1,
|
482 |
+
}
|
483 |
+
|
484 |
+
self.avg_metrics = {
|
485 |
+
"train": self.train_f1_avg,
|
486 |
+
"valid": self.val_f1_avg,
|
487 |
+
"test": self.test_f1_avg,
|
488 |
+
}
|
489 |
+
|
490 |
def forward(self, x: torch.Tensor, train: bool = False):
|
491 |
return self.network(x, train=train)
|
492 |
|
|
|
503 |
else:
|
504 |
outputs = self(x, train)
|
505 |
loss = 0
|
506 |
+
# Multi-head binary loss
|
507 |
+
if isinstance(self.network, Cnn14):
|
508 |
+
for idx, output in enumerate(outputs):
|
509 |
+
loss += self.loss_fn(output.squeeze(-1), wet_label[..., idx])
|
510 |
+
else:
|
511 |
+
# Output is a 2d tensor
|
512 |
+
loss = self.loss_fn(outputs, wet_label)
|
513 |
|
514 |
self.log(
|
515 |
f"{mode}_loss",
|
|
|
521 |
sync_dist=True,
|
522 |
)
|
523 |
|
524 |
+
if isinstance(self.network, Cnn14):
|
525 |
+
acc_metrics = []
|
526 |
+
for idx, effect_name in enumerate(self.effects):
|
527 |
+
acc_metric = self.metrics[f"{mode}_{effect_name}_acc"](
|
528 |
+
outputs[idx].squeeze(-1), wet_label[..., idx]
|
529 |
+
)
|
530 |
+
self.log(
|
531 |
+
f"{mode}_{effect_name}_acc",
|
532 |
+
acc_metric,
|
533 |
+
on_step=True,
|
534 |
+
on_epoch=True,
|
535 |
+
prog_bar=True,
|
536 |
+
logger=True,
|
537 |
+
sync_dist=True,
|
538 |
+
)
|
539 |
+
acc_metrics.append(acc_metric)
|
540 |
+
|
541 |
self.log(
|
542 |
+
f"{mode}_avg_acc",
|
543 |
+
torch.mean(torch.stack(acc_metrics)),
|
544 |
on_step=True,
|
545 |
on_epoch=True,
|
546 |
prog_bar=True,
|
547 |
logger=True,
|
548 |
sync_dist=True,
|
549 |
)
|
550 |
+
else:
|
551 |
+
metrics = self.metrics[mode](torch.sigmoid(outputs), wet_label.long())
|
552 |
+
for idx, effect_name in enumerate(self.effects):
|
553 |
+
self.log(
|
554 |
+
f"{mode}_f1_{effect_name}",
|
555 |
+
metrics[idx],
|
556 |
+
on_step=True,
|
557 |
+
on_epoch=True,
|
558 |
+
prog_bar=True,
|
559 |
+
logger=True,
|
560 |
+
sync_dist=True,
|
561 |
+
)
|
562 |
+
avg_metrics = self.avg_metrics[mode](
|
563 |
+
torch.sigmoid(outputs), wet_label.long()
|
564 |
+
)
|
565 |
|
566 |
+
self.log(
|
567 |
+
f"{mode}_avg_acc",
|
568 |
+
avg_metrics,
|
569 |
+
on_step=True,
|
570 |
+
on_epoch=True,
|
571 |
+
prog_bar=True,
|
572 |
+
logger=True,
|
573 |
+
sync_dist=True,
|
574 |
+
)
|
575 |
return loss
|
576 |
|
577 |
def training_step(self, batch, batch_idx):
|