Spaces:
Runtime error
Runtime error
Commit
·
b676040
1
Parent(s):
8f8de0d
Remove FAD logging on input data during train
Browse files- remfx/models.py +4 -1
remfx/models.py
CHANGED
@@ -127,6 +127,9 @@ class RemFXModel(pl.LightningModule):
|
|
127 |
negate = -1
|
128 |
else:
|
129 |
negate = 1
|
|
|
|
|
|
|
130 |
self.log(
|
131 |
f"Input_{metric}",
|
132 |
negate * self.metrics[metric](x, target),
|
@@ -215,7 +218,7 @@ class DemucsModel(torch.nn.Module):
|
|
215 |
self.model = HDemucs(**kwargs)
|
216 |
self.num_bins = kwargs["nfft"] // 2 + 1
|
217 |
self.mrstftloss = MultiResolutionSTFTLoss(
|
218 |
-
n_bins=self.num_bins, sample_rate=
|
219 |
)
|
220 |
self.l1loss = torch.nn.L1Loss()
|
221 |
|
|
|
127 |
negate = -1
|
128 |
else:
|
129 |
negate = 1
|
130 |
+
# Only Log FAD on test set
|
131 |
+
if metric == "FAD":
|
132 |
+
continue
|
133 |
self.log(
|
134 |
f"Input_{metric}",
|
135 |
negate * self.metrics[metric](x, target),
|
|
|
218 |
self.model = HDemucs(**kwargs)
|
219 |
self.num_bins = kwargs["nfft"] // 2 + 1
|
220 |
self.mrstftloss = MultiResolutionSTFTLoss(
|
221 |
+
n_bins=self.num_bins, sample_rate=sample_rate
|
222 |
)
|
223 |
self.l1loss = torch.nn.L1Loss()
|
224 |
|