mattricesound commited on
Commit
f73ffe0
·
1 Parent(s): fb92c76

Add first-time metrics logging to compare input metrics

Browse files
Files changed (1) hide show
  1. remfx/models.py +22 -0
remfx/models.py CHANGED
@@ -38,6 +38,8 @@ class RemFXModel(pl.LightningModule):
38
  "L1": L1Loss(),
39
  }
40
  )
 
 
41
 
42
  @property
43
  def device(self):
@@ -84,6 +86,26 @@ class RemFXModel(pl.LightningModule):
84
 
85
  return loss
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def on_validation_epoch_start(self):
88
  self.log_next = True
89
 
 
38
  "L1": L1Loss(),
39
  }
40
  )
41
+ # Log first batch metrics input vs output only once
42
+ self.log_first = True
43
 
44
  @property
45
  def device(self):
 
86
 
87
  return loss
88
 
89
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
90
+ if self.log_first:
91
+ x, target, label = batch
92
+ for metric in self.metrics:
93
+ # SISDR returns negative values, so negate them
94
+ if metric == "SISDR":
95
+ negate = -1
96
+ else:
97
+ negate = 1
98
+ self.log(
99
+ f"Input_{metric}",
100
+ negate * self.metrics[metric](x, target),
101
+ on_step=False,
102
+ on_epoch=True,
103
+ logger=True,
104
+ prog_bar=True,
105
+ sync_dist=True,
106
+ )
107
+ self.log_first = False
108
+
109
  def on_validation_epoch_start(self):
110
  self.log_next = True
111