mattricesound commited on
Commit
d2155c7
·
1 Parent(s): 3b4e474

Fix train batch initial audio

Browse files
Files changed (1) hide show
  1. remfx/models.py +18 -6
remfx/models.py CHANGED
@@ -39,7 +39,8 @@ class RemFXModel(pl.LightningModule):
39
  }
40
  )
41
  # Log first batch metrics input vs output only once
42
- self.log_first = True
 
43
 
44
  @property
45
  def device(self):
@@ -87,22 +88,33 @@ class RemFXModel(pl.LightningModule):
87
  return loss
88
 
89
  def on_train_batch_start(self, batch, batch_idx):
90
- if self.log_first:
91
  x, y, label = batch
 
 
 
92
  log_wandb_audio_batch(
93
  logger=self.logger,
94
- id="input_target",
95
- samples=x.cpu(),
96
  sampling_rate=self.sample_rate,
97
  caption="Training Data",
98
  )
 
 
 
 
 
 
 
 
99
 
100
  def on_validation_epoch_start(self):
101
  self.log_next = True
102
 
103
  def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
104
  x, target, label = batch
105
- if self.log_first:
106
  for metric in self.metrics:
107
  # SISDR returns negative values, so negate them
108
  if metric == "SISDR":
@@ -118,7 +130,7 @@ class RemFXModel(pl.LightningModule):
118
  prog_bar=True,
119
  sync_dist=True,
120
  )
121
- self.log_first = False
122
 
123
  if self.log_next:
124
  self.model.eval()
 
39
  }
40
  )
41
  # Log first batch metrics input vs output only once
42
+ self.log_first_metrics = True
43
+ self.log_train_audio = True
44
 
45
  @property
46
  def device(self):
 
88
  return loss
89
 
90
  def on_train_batch_start(self, batch, batch_idx):
91
+ if self.log_train_audio:
92
  x, y, label = batch
93
+ input_samples = rearrange(x, "b c t -> c (b t)")
94
+ target_samples = rearrange(y, "b c t -> c (b t)")
95
+
96
  log_wandb_audio_batch(
97
  logger=self.logger,
98
+ id="input_effected_audio",
99
+ samples=input_samples.cpu(),
100
  sampling_rate=self.sample_rate,
101
  caption="Training Data",
102
  )
103
+ log_wandb_audio_batch(
104
+ logger=self.logger,
105
+ id="target_audio",
106
+ samples=target_samples.cpu(),
107
+ sampling_rate=self.sample_rate,
108
+ caption="Target Data",
109
+ )
110
+ self.log_train_audio = False
111
 
112
  def on_validation_epoch_start(self):
113
  self.log_next = True
114
 
115
  def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
116
  x, target, label = batch
117
+ if self.log_first_metrics:
118
  for metric in self.metrics:
119
  # SISDR returns negative values, so negate them
120
  if metric == "SISDR":
 
130
  prog_bar=True,
131
  sync_dist=True,
132
  )
133
+ self.log_first_metrics = False
134
 
135
  if self.log_next:
136
  self.model.eval()