mattricesound commited on
Commit
aecaaea
·
1 Parent(s): f8fea2a

Fix causal cropping for input metrics

Browse files
Files changed (1) hide show
  1. remfx/models.py +4 -2
remfx/models.py CHANGED
@@ -188,8 +188,9 @@ class RemFX(pl.LightningModule):
188
 
189
  loss, output = self.model((x, y))
190
  # Crop target to match output
 
191
  if output.shape[-1] < y.shape[-1]:
192
- y = causal_crop(y, output.shape[-1])
193
  self.log(f"{mode}_loss", loss)
194
  # Metric logging
195
  with torch.no_grad():
@@ -204,13 +205,14 @@ class RemFX(pl.LightningModule):
204
  continue
205
  self.log(
206
  f"{mode}_{metric}",
207
- negate * self.metrics[metric](output, y),
208
  on_step=False,
209
  on_epoch=True,
210
  logger=True,
211
  prog_bar=True,
212
  sync_dist=True,
213
  )
 
214
  self.log(
215
  f"Input_{metric}",
216
  negate * self.metrics[metric](x, y),
 
188
 
189
  loss, output = self.model((x, y))
190
  # Crop target to match output
191
+ target = y
192
  if output.shape[-1] < y.shape[-1]:
193
+ target = causal_crop(y, output.shape[-1])
194
  self.log(f"{mode}_loss", loss)
195
  # Metric logging
196
  with torch.no_grad():
 
205
  continue
206
  self.log(
207
  f"{mode}_{metric}",
208
+ negate * self.metrics[metric](output, target),
209
  on_step=False,
210
  on_epoch=True,
211
  logger=True,
212
  prog_bar=True,
213
  sync_dist=True,
214
  )
215
+
216
  self.log(
217
  f"Input_{metric}",
218
  negate * self.metrics[metric](x, y),