mattricesound committed on
Commit
6da1b0d
·
1 Parent(s): c1cb017

Add custom-inference dataset

Browse files
cfg/exp/chain_inference_custom.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: all
5
+ seed: 12345
6
+ sample_rate: 48000
7
+ chunk_size: 262144 # 5.5s
8
+ logs_dir: "./logs"
9
+ render_files: False
10
+ render_root: "/scratch/EffectSet"
11
+ accelerator: "gpu"
12
+ log_audio: True
13
+ # Effects
14
+ num_kept_effects: [0,0] # [min, max]
15
+ num_removed_effects: [0,5] # [min, max]
16
+ shuffle_kept_effects: True
17
+ shuffle_removed_effects: True
18
+ num_classes: 5
19
+ effects_to_keep:
20
+ effects_to_remove:
21
+ - distortion
22
+ - compressor
23
+ - reverb
24
+ - chorus
25
+ - delay
26
+ datamodule:
27
+ batch_size: 16
28
+ num_workers: 8
29
+ train_dataset: None
30
+ val_dataset: None
31
+ test_dataset:
32
+ _target_: remfx.datasets.InferenceDataset
33
+ root: "./data/fx-examples"
34
+ sample_rate: ${sample_rate}
35
+ ckpts:
36
+ RandomPedalboardDistortion: "ckpts/distortion.ckpt"
37
+ RandomPedalboardCompressor: "ckpts/compressor.ckpt"
38
+ RandomPedalboardReverb: "ckpts/reverb.ckpt"
39
+ RandomPedalboardChorus: "ckpts/chorus.ckpt"
40
+ RandomPedalboardDelay: "ckpts/delay.ckpt"
41
+ num_bins: 1025
remfx/datasets.py CHANGED
@@ -360,6 +360,42 @@ class EffectDataset(Dataset):
360
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
361
 
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  class EffectDatamodule(pl.LightningDataModule):
364
  def __init__(
365
  self,
 
360
  return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
361
 
362
 
363
+ class InferenceDataset(Dataset):
364
+ def __init__(self, root: str, sample_rate: int):
365
+ self.root = Path(root)
366
+ self.sample_rate = sample_rate
367
+ self.clean_paths = list(self.root.glob("clean/*.wav"))
368
+ self.effected_paths = list(self.root.glob("effected/*.wav"))
369
+
370
+ def __len__(self) -> int:
371
+ return len(self.audio_paths)
372
+
373
+ def __getitem__(self, idx: int) -> torch.Tensor:
374
+ clean_path = self.clean_paths[idx]
375
+ effected_path = self.effected_paths[idx]
376
+ clean_audio, sr = torchaudio.load(clean_path)
377
+ clean = torchaudio.functional.resample(clean_audio, sr, self.sample_rate)
378
+ effected_audio, sr = torchaudio.load(effected_path)
379
+ effected = torchaudio.functional.resample(effected_audio, sr, self.sample_rate)
380
+
381
+ # Sum to mono
382
+ clean = torch.sum(clean, dim=0)
383
+ effected = torch.sum(effected, dim=0)
384
+
385
+ # Pad or trim effected to clean
386
+ if len(clean) > len(effected):
387
+ effected = torch.nn.functional.pad(
388
+ effected, (0, len(clean) - len(effected))
389
+ )
390
+ elif len(effected) > len(clean):
391
+ effected = effected[: len(clean)]
392
+
393
+ dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
394
+ wet_labels_tensor = torch.ones(len(ALL_EFFECTS))
395
+
396
+ return clean, effected, dry_labels_tensor, wet_labels_tensor
397
+
398
+
399
  class EffectDatamodule(pl.LightningDataModule):
400
  def __init__(
401
  self,
remfx/models.py CHANGED
@@ -96,7 +96,8 @@ class RemFXChainInference(pl.LightningModule):
96
  else:
97
  negate = 1
98
  self.log(
99
- f"test_{metric}_" + "".join(self.effect_order),
 
100
  negate * self.metrics[metric](output, y),
101
  on_step=False,
102
  on_epoch=True,
@@ -307,27 +308,6 @@ class DPTNetModel(nn.Module):
307
  def sample(self, x: Tensor) -> Tensor:
308
  return self.model(x.squeeze(1))
309
 
310
- def __init__(self, sample_rate, num_bins, **kwargs):
311
- super().__init__()
312
- self.model = asteroid.models.DCUNet(**kwargs)
313
- self.mrstftloss = MultiResolutionSTFTLoss(
314
- n_bins=num_bins, sample_rate=sample_rate
315
- )
316
- self.l1loss = nn.L1Loss()
317
-
318
- def forward(self, batch):
319
- x, target = batch
320
- output = self.model(x.squeeze(1)) # B x T
321
- # Crop target to match output
322
- if output.shape[-1] < target.shape[-1]:
323
- target = causal_crop(target, output.shape[-1])
324
- loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
325
- return loss, output
326
-
327
- def sample(self, x: Tensor) -> Tensor:
328
- output = self.model(x.squeeze(1)) # B x T
329
- return output
330
-
331
 
332
  class TCNModel(nn.Module):
333
  def __init__(self, sample_rate, num_bins, **kwargs):
 
96
  else:
97
  negate = 1
98
  self.log(
99
+ f"test_{metric}_"
100
+ + "".join(self.effect_order).replace("RandomPedalboard", ""),
101
  negate * self.metrics[metric](output, y),
102
  on_step=False,
103
  on_epoch=True,
 
308
  def sample(self, x: Tensor) -> Tensor:
309
  return self.model(x.squeeze(1))
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
  class TCNModel(nn.Module):
313
  def __init__(self, sample_rate, num_bins, **kwargs):