mattricesound commited on
Commit
61b9249
·
1 Parent(s): b99be38

Remove on_validation_epoch_start

Browse files
Files changed (1) hide show
  1. remfx/models.py +19 -19
remfx/models.py CHANGED
@@ -133,26 +133,26 @@ class RemFXModel(pl.LightningModule):
133
  prog_bar=True,
134
  sync_dist=True,
135
  )
 
 
 
 
 
136
 
137
- self.model.eval()
138
- with torch.no_grad():
139
- y = self.model.sample(x)
140
-
141
- # Concat samples together for easier viewing in dashboard
142
- # 2 seconds of silence between each sample
143
- silence = torch.zeros_like(x)
144
- silence = silence[:, : self.sample_rate * 2]
145
-
146
- concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
147
- log_wandb_audio_batch(
148
- logger=self.logger,
149
- id="prediction_input_target",
150
- samples=concat_samples.cpu(),
151
- sampling_rate=self.sample_rate,
152
- caption=f"Epoch {self.current_epoch}",
153
- )
154
- self.log_next = False
155
- self.model.train()
156
 
157
  def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
158
  return self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
 
133
  prog_bar=True,
134
  sync_dist=True,
135
  )
136
+ # Only run on first batch
137
+ if batch_idx == 0:
138
+ self.model.eval()
139
+ with torch.no_grad():
140
+ y = self.model.sample(x)
141
 
142
+ # Concat samples together for easier viewing in dashboard
143
+ # 2 seconds of silence between each sample
144
+ silence = torch.zeros_like(x)
145
+ silence = silence[:, : self.sample_rate * 2]
146
+
147
+ concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
148
+ log_wandb_audio_batch(
149
+ logger=self.logger,
150
+ id="prediction_input_target",
151
+ samples=concat_samples.cpu(),
152
+ sampling_rate=self.sample_rate,
153
+ caption=f"Epoch {self.current_epoch}",
154
+ )
155
+ self.model.train()
 
 
 
 
 
156
 
157
  def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
158
  return self.on_validation_batch_start(batch, batch_idx, dataloader_idx)