Spaces:
Sleeping
Sleeping
Commit
·
d2155c7
1
Parent(s):
3b4e474
Fix train batch initial audio
Browse files- remfx/models.py +18 -6
remfx/models.py
CHANGED
@@ -39,7 +39,8 @@ class RemFXModel(pl.LightningModule):
|
|
39 |
}
|
40 |
)
|
41 |
# Log first batch metrics input vs output only once
|
42 |
-
self.
|
|
|
43 |
|
44 |
@property
|
45 |
def device(self):
|
@@ -87,22 +88,33 @@ class RemFXModel(pl.LightningModule):
|
|
87 |
return loss
|
88 |
|
89 |
def on_train_batch_start(self, batch, batch_idx):
|
90 |
-
if self.
|
91 |
x, y, label = batch
|
|
|
|
|
|
|
92 |
log_wandb_audio_batch(
|
93 |
logger=self.logger,
|
94 |
-
id="
|
95 |
-
samples=
|
96 |
sampling_rate=self.sample_rate,
|
97 |
caption="Training Data",
|
98 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
def on_validation_epoch_start(self):
|
101 |
self.log_next = True
|
102 |
|
103 |
def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
|
104 |
x, target, label = batch
|
105 |
-
if self.
|
106 |
for metric in self.metrics:
|
107 |
# SISDR returns negative values, so negate them
|
108 |
if metric == "SISDR":
|
@@ -118,7 +130,7 @@ class RemFXModel(pl.LightningModule):
|
|
118 |
prog_bar=True,
|
119 |
sync_dist=True,
|
120 |
)
|
121 |
-
self.
|
122 |
|
123 |
if self.log_next:
|
124 |
self.model.eval()
|
|
|
39 |
}
|
40 |
)
|
41 |
# Log first batch metrics input vs output only once
|
42 |
+
self.log_first_metrics = True
|
43 |
+
self.log_train_audio = True
|
44 |
|
45 |
@property
|
46 |
def device(self):
|
|
|
88 |
return loss
|
89 |
|
90 |
def on_train_batch_start(self, batch, batch_idx):
|
91 |
+
if self.log_train_audio:
|
92 |
x, y, label = batch
|
93 |
+
input_samples = rearrange(x, "b c t -> c (b t)")
|
94 |
+
target_samples = rearrange(y, "b c t -> c (b t)")
|
95 |
+
|
96 |
log_wandb_audio_batch(
|
97 |
logger=self.logger,
|
98 |
+
id="input_effected_audio",
|
99 |
+
samples=input_samples.cpu(),
|
100 |
sampling_rate=self.sample_rate,
|
101 |
caption="Training Data",
|
102 |
)
|
103 |
+
log_wandb_audio_batch(
|
104 |
+
logger=self.logger,
|
105 |
+
id="target_audio",
|
106 |
+
samples=target_samples.cpu(),
|
107 |
+
sampling_rate=self.sample_rate,
|
108 |
+
caption="Target Data",
|
109 |
+
)
|
110 |
+
self.log_train_audio = False
|
111 |
|
112 |
def on_validation_epoch_start(self):
|
113 |
self.log_next = True
|
114 |
|
115 |
def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
|
116 |
x, target, label = batch
|
117 |
+
if self.log_first_metrics:
|
118 |
for metric in self.metrics:
|
119 |
# SISDR returns negative values, so negate them
|
120 |
if metric == "SISDR":
|
|
|
130 |
prog_bar=True,
|
131 |
sync_dist=True,
|
132 |
)
|
133 |
+
self.log_first_metrics = False
|
134 |
|
135 |
if self.log_next:
|
136 |
self.model.eval()
|