Spaces:
Runtime error
Runtime error
Commit
·
f07a3b6
1
Parent(s):
1efdf7d
Fix wandb on gpu
Browse files
models.py
CHANGED
@@ -71,23 +71,26 @@ class OpenUnmixModel(pl.LightningModule):
|
|
71 |
sample_rate=SAMPLE_RATE,
|
72 |
n_fft=self.n_fft,
|
73 |
n_hop=self.hop_length,
|
74 |
-
)
|
75 |
outputs = s(x).squeeze(1)
|
76 |
log_wandb_audio_batch(
|
|
|
77 |
id="sample",
|
78 |
-
samples=x,
|
79 |
sampling_rate=SAMPLE_RATE,
|
80 |
caption=f"Epoch {self.current_epoch}",
|
81 |
)
|
82 |
log_wandb_audio_batch(
|
|
|
83 |
id="prediction",
|
84 |
-
samples=outputs,
|
85 |
sampling_rate=SAMPLE_RATE,
|
86 |
caption=f"Epoch {self.current_epoch}",
|
87 |
)
|
88 |
log_wandb_audio_batch(
|
|
|
89 |
id="target",
|
90 |
-
samples=target,
|
91 |
sampling_rate=SAMPLE_RATE,
|
92 |
caption=f"Epoch {self.current_epoch}",
|
93 |
)
|
@@ -146,12 +149,16 @@ class DiffusionGenerationModel(pl.LightningModule):
|
|
146 |
|
147 |
|
148 |
def log_wandb_audio_batch(
|
149 |
-
|
|
|
|
|
|
|
|
|
150 |
):
|
151 |
num_items = samples.shape[0]
|
152 |
samples = rearrange(samples, "b c t -> b t c")
|
153 |
for idx in range(num_items):
|
154 |
-
|
155 |
{
|
156 |
f"{id}_{idx}": wandb.Audio(
|
157 |
samples[idx].cpu().numpy(),
|
|
|
71 |
sample_rate=SAMPLE_RATE,
|
72 |
n_fft=self.n_fft,
|
73 |
n_hop=self.hop_length,
|
74 |
+
).to(self.device)
|
75 |
outputs = s(x).squeeze(1)
|
76 |
log_wandb_audio_batch(
|
77 |
+
logger=self.logger,
|
78 |
id="sample",
|
79 |
+
samples=x.cpu(),
|
80 |
sampling_rate=SAMPLE_RATE,
|
81 |
caption=f"Epoch {self.current_epoch}",
|
82 |
)
|
83 |
log_wandb_audio_batch(
|
84 |
+
logger=self.logger,
|
85 |
id="prediction",
|
86 |
+
samples=outputs.cpu(),
|
87 |
sampling_rate=SAMPLE_RATE,
|
88 |
caption=f"Epoch {self.current_epoch}",
|
89 |
)
|
90 |
log_wandb_audio_batch(
|
91 |
+
logger=self.loggger,
|
92 |
id="target",
|
93 |
+
samples=target.cpu(),
|
94 |
sampling_rate=SAMPLE_RATE,
|
95 |
caption=f"Epoch {self.current_epoch}",
|
96 |
)
|
|
|
149 |
|
150 |
|
151 |
def log_wandb_audio_batch(
|
152 |
+
logger: pl.loggers.WandbLogger,
|
153 |
+
id: str,
|
154 |
+
samples: Tensor,
|
155 |
+
sampling_rate: int,
|
156 |
+
caption: str = "",
|
157 |
):
|
158 |
num_items = samples.shape[0]
|
159 |
samples = rearrange(samples, "b c t -> b t c")
|
160 |
for idx in range(num_items):
|
161 |
+
logger.experiment.log(
|
162 |
{
|
163 |
f"{id}_{idx}": wandb.Audio(
|
164 |
samples[idx].cpu().numpy(),
|
train.py
CHANGED
@@ -12,7 +12,7 @@ TRAIN_SPLIT = 0.8
|
|
12 |
|
13 |
def main():
|
14 |
wandb_logger = WandbLogger(project="RemFX", save_dir="./")
|
15 |
-
trainer = pl.Trainer(logger=wandb_logger, max_epochs=
|
16 |
guitfx = GuitarFXDataset(
|
17 |
root="./data/egfx",
|
18 |
sample_rate=SAMPLE_RATE,
|
|
|
12 |
|
13 |
def main():
|
14 |
wandb_logger = WandbLogger(project="RemFX", save_dir="./")
|
15 |
+
trainer = pl.Trainer(logger=wandb_logger, max_epochs=100)
|
16 |
guitfx = GuitarFXDataset(
|
17 |
root="./data/egfx",
|
18 |
sample_rate=SAMPLE_RATE,
|