Spaces:
Sleeping
Sleeping
Commit
·
61b9249
1
Parent(s):
b99be38
Remove on_validation_epoch_start
Browse files- remfx/models.py +19 -19
remfx/models.py
CHANGED
@@ -133,26 +133,26 @@ class RemFXModel(pl.LightningModule):
|
|
133 |
prog_bar=True,
|
134 |
sync_dist=True,
|
135 |
)
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
sampling_rate=self.sample_rate,
|
152 |
-
caption=f"Epoch {self.current_epoch}",
|
153 |
-
)
|
154 |
-
self.log_next = False
|
155 |
-
self.model.train()
|
156 |
|
157 |
def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
|
158 |
return self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
|
|
|
133 |
prog_bar=True,
|
134 |
sync_dist=True,
|
135 |
)
|
136 |
+
# Only run on first batch
|
137 |
+
if batch_idx == 0:
|
138 |
+
self.model.eval()
|
139 |
+
with torch.no_grad():
|
140 |
+
y = self.model.sample(x)
|
141 |
|
142 |
+
# Concat samples together for easier viewing in dashboard
|
143 |
+
# 2 seconds of silence between each sample
|
144 |
+
silence = torch.zeros_like(x)
|
145 |
+
silence = silence[:, : self.sample_rate * 2]
|
146 |
+
|
147 |
+
concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
|
148 |
+
log_wandb_audio_batch(
|
149 |
+
logger=self.logger,
|
150 |
+
id="prediction_input_target",
|
151 |
+
samples=concat_samples.cpu(),
|
152 |
+
sampling_rate=self.sample_rate,
|
153 |
+
caption=f"Epoch {self.current_epoch}",
|
154 |
+
)
|
155 |
+
self.model.train()
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
|
158 |
return self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
|