mattricesound committed on
Commit 9a9a2c9 · Parent: 7d6f241

Update callbacks, debug new models

cfg/config.yaml CHANGED
@@ -41,6 +41,11 @@ callbacks:
   learning_rate_monitor:
     _target_: pytorch_lightning.callbacks.LearningRateMonitor
     logging_interval: "step"
+  audio_logging:
+    _target_: remfx.callbacks.AudioCallback
+    sample_rate: ${sample_rate}
+  metric_logging:
+    _target_: remfx.callbacks.MetricCallback
 
 datamodule:
   _target_: remfx.datasets.VocalSetDatamodule
@@ -116,4 +121,3 @@ trainer:
   devices: 1
   gradient_clip_val: 10.0
   max_steps: 50000
-
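For context, entries under callbacks: are Hydra-style configs that the training script turns into live objects with hydra.utils.instantiate. A minimal sketch of that wiring (the script itself is not part of this commit, so the cfg field names here are assumptions; only remfx.callbacks.AudioCallback and MetricCallback are confirmed by the diff):

import hydra
import pytorch_lightning as pl

@hydra.main(config_path="cfg", config_name="config", version_base=None)
def main(cfg):
    # Build one callback object per entry under cfg.callbacks
    callbacks = [hydra.utils.instantiate(c) for c in cfg.callbacks.values()]
    trainer = pl.Trainer(callbacks=callbacks, max_steps=cfg.trainer.max_steps)

if __name__ == "__main__":
    main()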
 
cfg/model/audio_diffusion.yaml CHANGED
@@ -1,6 +1,6 @@
 # @package _global_
 model:
-  _target_: remfx.models.RemFx
+  _target_: remfx.models.RemFX
   lr: 1e-4
   lr_beta1: 0.95
   lr_beta2: 0.999
@@ -13,4 +13,4 @@ model:
 datamodule:
   dataset:
     effect_types: ["Clean"]
-  batch_size: 2
+  batch_size: 2
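The _target_ rename from remfx.models.RemFx to remfx.models.RemFX, repeated across all of the model configs below, is the actual debugging fix: Hydra resolves _target_ by importing the dotted path verbatim, and the class in remfx/models.py is spelled RemFX. A quick way to verify such paths (hydra.utils.get_class is real Hydra API; running this assumes the remfx package is importable):

from hydra.utils import get_class

cls = get_class("remfx.models.RemFX")  # resolves after this commit
print(cls.__name__)                    # RemFX
# get_class("remfx.models.RemFx")     # would raise: class "RemFx" not found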
cfg/model/dcunet.yaml CHANGED
@@ -1,6 +1,6 @@
 # @package _global_
 model:
-  _target_: remfx.models.RemFx
+  _target_: remfx.models.RemFX
   lr: 1e-4
   lr_beta1: 0.95
   lr_beta2: 0.999
@@ -9,7 +9,7 @@ model:
   sample_rate: ${sample_rate}
   network:
     _target_: remfx.models.DCUNetModel
-    spec_dim: 256 + 1
+    spec_dim: 257
     hidden_dim: 768
     filter_len: 512
     hop_len: 64
@@ -19,4 +19,6 @@ model:
     refine_layers: 1
     is_mask: True
     norm: 'ins'
-    act: 'comp'
+    act: 'comp'
+    sample_rate: ${sample_rate}
+    num_bins: 1025
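The spec_dim fix is worth a note: neither YAML nor the OmegaConf layer Hydra uses evaluates arithmetic in values, so the old "256 + 1" loaded as a string rather than the integer 257. The new value also lines up with the STFT geometry, since filter_len: 512 yields 512 // 2 + 1 = 257 one-sided frequency bins (that link is an inference from the visible config, not stated in the commit). Illustration with plain PyYAML:

import yaml  # standard PyYAML; OmegaConf behaves the same way here

print(yaml.safe_load("spec_dim: 256 + 1"))  # {'spec_dim': '256 + 1'} -- a string!
print(yaml.safe_load("spec_dim: 257"))      # {'spec_dim': 257}       -- an int
print(512 // 2 + 1)                         # 257 bins for a 512-point transform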
cfg/model/demucs.yaml CHANGED
@@ -1,6 +1,6 @@
 # @package _global_
 model:
-  _target_: remfx.models.RemFx
+  _target_: remfx.models.RemFX
   lr: 1e-4
   lr_beta1: 0.95
   lr_beta2: 0.999
@@ -13,4 +13,3 @@ model:
   audio_channels: 1
   nfft: 4096
   sample_rate: ${sample_rate}
-
cfg/model/dptnet.yaml CHANGED
@@ -1,6 +1,6 @@
 # @package _global_
 model:
-  _target_: remfx.models.RemFx
+  _target_: remfx.models.RemFX
   lr: 1e-4
   lr_beta1: 0.95
   lr_beta2: 0.999
@@ -16,3 +16,5 @@ model:
   segment_size: 250
   nspk: 1
   win_len: 2
+  sample_rate: ${sample_rate}
+  num_bins: 1025
cfg/model/umx.yaml CHANGED
@@ -1,6 +1,6 @@
 # @package _global_
 model:
-  _target_: remfx.models.RemFx
+  _target_: remfx.models.RemFX
   lr: 1e-4
   lr_beta1: 0.95
   lr_beta2: 0.999
@@ -14,4 +14,3 @@ model:
   n_channels: 1
   alpha: 0.3
   sample_rate: ${sample_rate}
-
remfx/callbacks.py ADDED
@@ -0,0 +1,128 @@
+from pytorch_lightning.callbacks import Callback
+import pytorch_lightning as pl
+from einops import rearrange
+import torch
+import wandb
+from torch import Tensor
+
+
+class AudioCallback(Callback):
+    def __init__(self, sample_rate, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.log_train_audio = True
+        self.sample_rate = sample_rate
+
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+        # Log initial audio
+        if self.log_train_audio:
+            x, y, _, _ = batch
+            # Concat samples together for easier viewing in dashboard
+            input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
+            target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
+
+            log_wandb_audio_batch(
+                logger=trainer.logger,
+                id="input_effected_audio",
+                samples=input_samples.cpu(),
+                sampling_rate=self.sample_rate,
+                caption="Training Data",
+            )
+            log_wandb_audio_batch(
+                logger=trainer.logger,
+                id="target_audio",
+                samples=target_samples.cpu(),
+                sampling_rate=self.sample_rate,
+                caption="Target Data",
+            )
+            self.log_train_audio = False
+
+    def on_validation_batch_start(
+        self, trainer, pl_module, batch, batch_idx, dataloader_idx
+    ):
+        x, target, _, _ = batch
+        # Only run on first batch
+        if batch_idx == 0:
+            with torch.no_grad():
+                y = pl_module.model.sample(x)
+            # Concat samples together for easier viewing in dashboard
+            # 2 seconds of silence between each sample
+            silence = torch.zeros_like(x)
+            silence = silence[:, : self.sample_rate * 2]
+
+            concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
+            log_wandb_audio_batch(
+                logger=trainer.logger,
+                id="prediction_input_target",
+                samples=concat_samples.cpu(),
+                sampling_rate=self.sample_rate,
+                caption=f"Epoch {trainer.current_epoch}",
+            )
+
+    def on_test_batch_start(self, *args):
+        self.on_validation_batch_start(*args)
+
+
+class MetricCallback(Callback):
+    def on_validation_batch_start(
+        self, trainer, pl_module, batch, batch_idx, dataloader_idx
+    ):
+        x, target, _, _ = batch
+        # Log Input Metrics
+        for metric in pl_module.metrics:
+            # SISDR returns negative values, so negate them
+            if metric == "SISDR":
+                negate = -1
+            else:
+                negate = 1
+            # Only Log FAD on test set
+            if metric == "FAD":
+                continue
+            pl_module.log(
+                f"Input_{metric}",
+                negate * pl_module.metrics[metric](x, target),
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                prog_bar=True,
+                sync_dist=True,
+            )
+
+    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        self.on_validation_batch_start(
+            trainer, pl_module, batch, batch_idx, dataloader_idx
+        )
+        # Log FAD
+        x, target, _, _ = batch
+        pl_module.log(
+            "Input_FAD",
+            pl_module.metrics["FAD"](x, target),
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            prog_bar=True,
+            sync_dist=True,
+        )
+
+
+def log_wandb_audio_batch(
+    logger: pl.loggers.WandbLogger,
+    id: str,
+    samples: Tensor,
+    sampling_rate: int,
+    caption: str = "",
+    max_items: int = 10,
+):
+    num_items = samples.shape[0]
+    samples = rearrange(samples, "b c t -> b t c")
+    for idx in range(num_items):
+        if idx >= max_items:
+            break
+        logger.experiment.log(
+            {
+                f"{id}_{idx}": wandb.Audio(
+                    samples[idx].cpu().numpy(),
+                    caption=caption,
+                    sample_rate=sampling_rate,
+                )
+            }
+        )
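The two new classes reproduce the logging that previously lived inside the RemFX LightningModule (see remfx/models.py below), but as standalone Lightning Callbacks driven by the Trainer's hook machinery. A minimal usage sketch (the 22050 Hz rate and the commented fit() arguments are placeholders; in the repo these come from the Hydra config):

import pytorch_lightning as pl
from remfx.callbacks import AudioCallback, MetricCallback

# Lightning calls on_train_batch_start / on_validation_batch_start /
# on_test_batch_start automatically; nothing invokes these hooks by hand.
trainer = pl.Trainer(
    callbacks=[AudioCallback(sample_rate=22050), MetricCallback()],
    max_steps=50000,
)
# trainer.fit(model, datamodule=datamodule)  # model and datamodule built elsewhere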
remfx/datasets.py CHANGED
@@ -5,7 +5,6 @@ import torch
 import shutil
 import torchaudio
 import pytorch_lightning as pl
-import torch.nn.functional as F
 
 from tqdm import tqdm
 from pathlib import Path
remfx/dcunet.py CHANGED
@@ -5,11 +5,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from utils import single, concat_complex
 from torch.nn.init import calculate_gain
 from typing import Tuple
 from scipy.signal import get_window
 from librosa.util import pad_center
+from remfx.utils import single, concat_complex
 
 
 class ComplexConvBlock(nn.Module):
@@ -549,7 +549,7 @@ class ComplexActLayer(nn.Module):
 
     def forward(self, x):
         real, img = x.chunk(2, 1)
-        return torch.cat([F.leaky_relu_(real), torch.tanh(img) * np.pi], dim=1)
+        return torch.cat([F.leaky_relu(real), torch.tanh(img) * np.pi], dim=1)
 
 
 class STFT(nn.Module):
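The second hunk is a subtle autograd fix, not a style change: chunk() returns multiple views of x, and PyTorch forbids in-place ops such as F.leaky_relu_ on such views when gradients are tracked. A minimal reproduction of the failure mode (illustrative, not from the repo):

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, requires_grad=True) * 1.0  # grad-tracked, non-leaf
real, img = x.chunk(2, 1)  # two views of the same tensor

# F.leaky_relu_(real)  # raises, e.g.: "Output 0 of SplitBackward0 is a view
#                      # and is being modified inplace ..."

out = torch.cat([F.leaky_relu(real), torch.tanh(img)], dim=1)  # out-of-place: safe
out.sum().backward()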
remfx/dptnet.py CHANGED
@@ -57,11 +57,10 @@ class DPTNet_base(nn.Module):
         self.mask_conv1x1 = nn.Conv1d(self.feature_dim, self.enc_dim, 1, bias=False)
         self.decoder = DPTDecoder(n_filters=enc_dim, window_size=win_len)
 
-    def forward(self, batch):
+    def forward(self, mix):
         """
         mix: shape (batch, T)
         """
-        mix, target = batch
         batch_size = mix.shape[0]
         mix = self.dpt_encoder(mix)  # (B, E, L)
 
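The network now receives the mixture tensor directly; its caller in remfx/models.py (below) unpacks the batch and drops the channel dimension first. A shape-only sketch with a stand-in module, since DPTNet_base's full constructor signature is not visible in this hunk:

import torch
from torch import nn

net = nn.Identity()           # stand-in for DPTNet_base, which expects (batch, T)
x = torch.randn(4, 1, 16384)  # (B, C=1, T) audio as produced by the datamodule
est = net(x.squeeze(1))       # caller squeezes to (B, T); no target passed anymore
print(est.shape)              # torch.Size([4, 16384])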
remfx/models.py CHANGED
@@ -2,16 +2,16 @@ import torch
 import torchmetrics
 import pytorch_lightning as pl
 from torch import Tensor, nn
-from einops import rearrange
+from torch.nn import functional as F
 from torchaudio.models import HDemucs
 from audio_diffusion_pytorch import DiffusionModel
 from auraloss.time import SISDRLoss
 from auraloss.freq import MultiResolutionSTFTLoss
 from umx.openunmix.model import OpenUnmix, Separator
 
-from utils import FADLoss, spectrogram, log_wandb_audio_batch
-from dptnet import DPTNet_base
-from dcunet import RefineSpectrogramUnet
+from remfx.utils import FADLoss, spectrogram
+from remfx.dptnet import DPTNet_base
+from remfx.dcunet import RefineSpectrogramUnet
 
 
 class RemFX(pl.LightningModule):
@@ -55,41 +55,29 @@ class RemFX(pl.LightningModule):
             eps=self.lr_eps,
             weight_decay=self.lr_weight_decay,
         )
-        return optimizer
-
-    # Add step-based learning rate scheduler
-    def optimizer_step(
-        self,
-        epoch,
-        batch_idx,
-        optimizer,
-        optimizer_idx,
-        optimizer_closure,
-        on_tpu,
-        using_lbfgs,
-    ):
-        # update params
-        optimizer.step(closure=optimizer_closure)
-
-        # update learning rate. Reduce by factor of 10 at 80% and 95% of training
-        if self.trainer.global_step == 0.8 * self.trainer.max_steps:
-            for pg in optimizer.param_groups:
-                pg["lr"] = 0.1 * pg["lr"]
-        if self.trainer.global_step == 0.95 * self.trainer.max_steps:
-            for pg in optimizer.param_groups:
-                pg["lr"] = 0.1 * pg["lr"]
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer,
+            [0.8 * self.trainer.max_steps, 0.95 * self.trainer.max_steps],
+            gamma=0.1,
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "monitor": "val_loss",
+                "interval": "step",
+                "frequency": 1,
+            },
+        }
 
     def training_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="train")
-        return loss
+        return self.common_step(batch, batch_idx, mode="train")
 
     def validation_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="valid")
-        return loss
+        return self.common_step(batch, batch_idx, mode="valid")
 
     def test_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="test")
-        return loss
+        return self.common_step(batch, batch_idx, mode="test")
 
     def common_step(self, batch, batch_idx, mode: str = "train"):
         x, y, _, _ = batch  # x, y = (B, C, T), (B, C, T)
@@ -116,89 +104,8 @@ class RemFX(pl.LightningModule):
             prog_bar=True,
             sync_dist=True,
         )
-
         return loss
 
-    def on_train_batch_start(self, batch, batch_idx):
-        # Log initial audio
-        if self.log_train_audio:
-            x, y, _, _ = batch
-            # Concat samples together for easier viewing in dashboard
-            input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
-            target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
-
-            log_wandb_audio_batch(
-                logger=self.logger,
-                id="input_effected_audio",
-                samples=input_samples.cpu(),
-                sampling_rate=self.sample_rate,
-                caption="Training Data",
-            )
-            log_wandb_audio_batch(
-                logger=self.logger,
-                id="target_audio",
-                samples=target_samples.cpu(),
-                sampling_rate=self.sample_rate,
-                caption="Target Data",
-            )
-            self.log_train_audio = False
-
-    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
-        x, target, _, _ = batch
-        # Log Input Metrics
-        for metric in self.metrics:
-            # SISDR returns negative values, so negate them
-            if metric == "SISDR":
-                negate = -1
-            else:
-                negate = 1
-            # Only Log FAD on test set
-            if metric == "FAD":
-                continue
-            self.log(
-                f"Input_{metric}",
-                negate * self.metrics[metric](x, target),
-                on_step=False,
-                on_epoch=True,
-                logger=True,
-                prog_bar=True,
-                sync_dist=True,
-            )
-        # Only run on first batch
-        if batch_idx == 0:
-            self.model.eval()
-            with torch.no_grad():
-                y = self.model.sample(x)
-
-            # Concat samples together for easier viewing in dashboard
-            # 2 seconds of silence between each sample
-            silence = torch.zeros_like(x)
-            silence = silence[:, : self.sample_rate * 2]
-
-            concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
-            log_wandb_audio_batch(
-                logger=self.logger,
-                id="prediction_input_target",
-                samples=concat_samples.cpu(),
-                sampling_rate=self.sample_rate,
-                caption=f"Epoch {self.current_epoch}",
-            )
-            self.model.train()
-
-    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
-        self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
-        # Log FAD
-        x, target, _, _ = batch
-        self.log(
-            "Input_FAD",
-            self.metrics["FAD"](x, target),
-            on_step=False,
-            on_epoch=True,
-            logger=True,
-            prog_bar=True,
-            sync_dist=True,
-        )
-
 
 class OpenUnmixModel(nn.Module):
     def __init__(
@@ -284,9 +191,10 @@ class DiffusionGenerationModel(nn.Module):
 
 
 class DPTNetModel(nn.Module):
-    def __init__(self, sample_rate, **kwargs):
+    def __init__(self, sample_rate, num_bins, **kwargs):
         super().__init__()
         self.model = DPTNet_base(**kwargs)
+        self.num_bins = num_bins
         self.mrstftloss = MultiResolutionSTFTLoss(
             n_bins=self.num_bins, sample_rate=sample_rate
         )
@@ -294,31 +202,42 @@
 
     def forward(self, batch):
         x, target = batch
-        output = self.model(x).squeeze(1)
+        output = self.model(x.squeeze(1))
        loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
         return loss, output
 
     def sample(self, x: Tensor) -> Tensor:
-        return self.model.sample(x)
+        return self.model(x.squeeze(1))
 
 
 class DCUNetModel(nn.Module):
-    def __init__(self, sample_rate, **kwargs):
+    def __init__(self, sample_rate, num_bins, **kwargs):
         super().__init__()
         self.model = RefineSpectrogramUnet(**kwargs)
         self.mrstftloss = MultiResolutionSTFTLoss(
-            n_bins=self.num_bins, sample_rate=sample_rate
+            n_bins=num_bins, sample_rate=sample_rate
         )
         self.l1loss = nn.L1Loss()
 
     def forward(self, batch):
         x, target = batch
-        output = self.model(x).squeeze(1)
+        output = self.model(x.squeeze(1)).unsqueeze(1)  # B x 1 x T
+        # Pad or crop to match target
+        if output.shape[-1] > target.shape[-1]:
+            output = output[:, : target.shape[-1]]
+        elif output.shape[-1] < target.shape[-1]:
+            output = F.pad(output, (0, target.shape[-1] - output.shape[-1]))
         loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
         return loss, output
 
     def sample(self, x: Tensor) -> Tensor:
-        return self.model.sample(x)
+        output = self.model(x.squeeze(1)).unsqueeze(1)  # B x 1 x T
+        # Pad or crop to match target
+        if output.shape[-1] > x.shape[-1]:
+            output = output[:, : x.shape[-1]]
+        elif output.shape[-1] < x.shape[-1]:
+            output = F.pad(output, (0, x.shape[-1] - output.shape[-1]))
+        return output
 
 
 class FXClassifier(pl.LightningModule):
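The configure_optimizers change replaces the hand-rolled optimizer_step override with a standard torch.optim.lr_scheduler.MultiStepLR returned in Lightning's optimizer/lr_scheduler dict, with interval "step" so it ticks once per training step. A standalone sketch of the resulting schedule (max_steps=50000 is taken from cfg/config.yaml, so the milestones land at 40000 and 47500):

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-4)
sched = torch.optim.lr_scheduler.MultiStepLR(
    opt, milestones=[40000, 47500], gamma=0.1
)

for step in range(1, 50001):
    opt.step()
    sched.step()
    if step in (40000, 47500):
        print(step, opt.param_groups[0]["lr"])  # 1e-5 at 40000, 1e-6 at 47500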
remfx/utils.py CHANGED
@@ -7,10 +7,8 @@ from frechet_audio_distance import FrechetAudioDistance
 import numpy as np
 import torch
 import torchaudio
-from torch import Tensor, nn
-import wandb
-from einops import rearrange
-from torch._six import container_abcs
+from torch import nn
+import collections.abc
 
 
 def get_logger(name=__name__) -> logging.Logger:
@@ -144,30 +142,6 @@ def create_sequential_chunks(
     return chunks, sr
 
 
-def log_wandb_audio_batch(
-    logger: pl.loggers.WandbLogger,
-    id: str,
-    samples: Tensor,
-    sampling_rate: int,
-    caption: str = "",
-    max_items: int = 10,
-):
-    num_items = samples.shape[0]
-    samples = rearrange(samples, "b c t -> b t c")
-    for idx in range(num_items):
-        if idx >= max_items:
-            break
-        logger.experiment.log(
-            {
-                f"{id}_{idx}": wandb.Audio(
-                    samples[idx].cpu().numpy(),
-                    caption=caption,
-                    sample_rate=sampling_rate,
-                )
-            }
-        )
-
-
 def spectrogram(
     x: torch.Tensor,
     window: torch.Tensor,
@@ -209,7 +183,7 @@ def init_bn(bn):
 
 def _ntuple(n: int):
     def parse(x):
-        if isinstance(x, container_abcs.Iterable):
+        if isinstance(x, collections.abc.Iterable):
             return x
         return tuple([x] * n)
 
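torch._six was PyTorch's internal Python 2/3 compatibility shim and has been removed from recent PyTorch releases, so importing container_abcs from it breaks; the standard library's collections.abc is the stable replacement. A quick check of _ntuple under the new import (the closing return parse line falls outside the visible hunk and is filled in here as the usual pattern):

import collections.abc

def _ntuple(n: int):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple([x] * n)
    return parse

pair = _ntuple(2)
print(pair(3))       # (3, 3)  -- a scalar is broadcast to an n-tuple
print(pair((3, 5)))  # (3, 5)  -- an iterable passes through unchanged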