mattricesound commited on
Commit
b676040
·
1 Parent(s): 8f8de0d

Remove FAD logging on input data during train

Browse files
Files changed (1) hide show
  1. remfx/models.py +4 -1
remfx/models.py CHANGED
@@ -127,6 +127,9 @@ class RemFXModel(pl.LightningModule):
127
  negate = -1
128
  else:
129
  negate = 1
 
 
 
130
  self.log(
131
  f"Input_{metric}",
132
  negate * self.metrics[metric](x, target),
@@ -215,7 +218,7 @@ class DemucsModel(torch.nn.Module):
215
  self.model = HDemucs(**kwargs)
216
  self.num_bins = kwargs["nfft"] // 2 + 1
217
  self.mrstftloss = MultiResolutionSTFTLoss(
218
- n_bins=self.num_bins, sample_rate=self.sample_rate
219
  )
220
  self.l1loss = torch.nn.L1Loss()
221
 
 
127
  negate = -1
128
  else:
129
  negate = 1
130
+ # Only Log FAD on test set
131
+ if metric == "FAD":
132
+ continue
133
  self.log(
134
  f"Input_{metric}",
135
  negate * self.metrics[metric](x, target),
 
218
  self.model = HDemucs(**kwargs)
219
  self.num_bins = kwargs["nfft"] // 2 + 1
220
  self.mrstftloss = MultiResolutionSTFTLoss(
221
+ n_bins=self.num_bins, sample_rate=sample_rate
222
  )
223
  self.l1loss = torch.nn.L1Loss()
224