Spaces:
Runtime error
Runtime error
Commit
·
6da1b0d
1
Parent(s):
c1cb017
Add custom-inference dataset
Browse files
- cfg/exp/chain_inference_custom.yaml +41 -0
- remfx/datasets.py +36 -0
- remfx/models.py +2 -22
cfg/exp/chain_inference_custom.yaml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
defaults:
|
3 |
+
- override /model: demucs
|
4 |
+
- override /effects: all
|
5 |
+
seed: 12345
|
6 |
+
sample_rate: 48000
|
7 |
+
chunk_size: 262144 # 5.5s
|
8 |
+
logs_dir: "./logs"
|
9 |
+
render_files: False
|
10 |
+
render_root: "/scratch/EffectSet"
|
11 |
+
accelerator: "gpu"
|
12 |
+
log_audio: True
|
13 |
+
# Effects
|
14 |
+
num_kept_effects: [0,0] # [min, max]
|
15 |
+
num_removed_effects: [0,5] # [min, max]
|
16 |
+
shuffle_kept_effects: True
|
17 |
+
shuffle_removed_effects: True
|
18 |
+
num_classes: 5
|
19 |
+
effects_to_keep:
|
20 |
+
effects_to_remove:
|
21 |
+
- distortion
|
22 |
+
- compressor
|
23 |
+
- reverb
|
24 |
+
- chorus
|
25 |
+
- delay
|
26 |
+
datamodule:
|
27 |
+
batch_size: 16
|
28 |
+
num_workers: 8
|
29 |
+
train_dataset: None
|
30 |
+
val_dataset: None
|
31 |
+
test_dataset:
|
32 |
+
_target_: remfx.datasets.InferenceDataset
|
33 |
+
root: "./data/fx-examples"
|
34 |
+
sample_rate: ${sample_rate}
|
35 |
+
ckpts:
|
36 |
+
RandomPedalboardDistortion: "ckpts/distortion.ckpt"
|
37 |
+
RandomPedalboardCompressor: "ckpts/compressor.ckpt"
|
38 |
+
RandomPedalboardReverb: "ckpts/reverb.ckpt"
|
39 |
+
RandomPedalboardChorus: "ckpts/chorus.ckpt"
|
40 |
+
RandomPedalboardDelay: "ckpts/delay.ckpt"
|
41 |
+
num_bins: 1025
|
remfx/datasets.py
CHANGED
@@ -360,6 +360,42 @@ class EffectDataset(Dataset):
|
|
360 |
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
361 |
|
362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
class EffectDatamodule(pl.LightningDataModule):
|
364 |
def __init__(
|
365 |
self,
|
|
|
360 |
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
361 |
|
362 |
|
363 |
+
class InferenceDataset(Dataset):
|
364 |
+
def __init__(self, root: str, sample_rate: int):
|
365 |
+
self.root = Path(root)
|
366 |
+
self.sample_rate = sample_rate
|
367 |
+
self.clean_paths = list(self.root.glob("clean/*.wav"))
|
368 |
+
self.effected_paths = list(self.root.glob("effected/*.wav"))
|
369 |
+
|
370 |
+
def __len__(self) -> int:
|
371 |
+
return len(self.audio_paths)
|
372 |
+
|
373 |
+
def __getitem__(self, idx: int) -> torch.Tensor:
|
374 |
+
clean_path = self.clean_paths[idx]
|
375 |
+
effected_path = self.effected_paths[idx]
|
376 |
+
clean_audio, sr = torchaudio.load(clean_path)
|
377 |
+
clean = torchaudio.functional.resample(clean_audio, sr, self.sample_rate)
|
378 |
+
effected_audio, sr = torchaudio.load(effected_path)
|
379 |
+
effected = torchaudio.functional.resample(effected_audio, sr, self.sample_rate)
|
380 |
+
|
381 |
+
# Sum to mono
|
382 |
+
clean = torch.sum(clean, dim=0)
|
383 |
+
effected = torch.sum(effected, dim=0)
|
384 |
+
|
385 |
+
# Pad or trim effected to clean
|
386 |
+
if len(clean) > len(effected):
|
387 |
+
effected = torch.nn.functional.pad(
|
388 |
+
effected, (0, len(clean) - len(effected))
|
389 |
+
)
|
390 |
+
elif len(effected) > len(clean):
|
391 |
+
effected = effected[: len(clean)]
|
392 |
+
|
393 |
+
dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
|
394 |
+
wet_labels_tensor = torch.ones(len(ALL_EFFECTS))
|
395 |
+
|
396 |
+
return clean, effected, dry_labels_tensor, wet_labels_tensor
|
397 |
+
|
398 |
+
|
399 |
class EffectDatamodule(pl.LightningDataModule):
|
400 |
def __init__(
|
401 |
self,
|
remfx/models.py
CHANGED
@@ -96,7 +96,8 @@ class RemFXChainInference(pl.LightningModule):
|
|
96 |
else:
|
97 |
negate = 1
|
98 |
self.log(
|
99 |
-
f"test_{metric}_"
|
|
|
100 |
negate * self.metrics[metric](output, y),
|
101 |
on_step=False,
|
102 |
on_epoch=True,
|
@@ -307,27 +308,6 @@ class DPTNetModel(nn.Module):
|
|
307 |
def sample(self, x: Tensor) -> Tensor:
|
308 |
return self.model(x.squeeze(1))
|
309 |
|
310 |
-
def __init__(self, sample_rate, num_bins, **kwargs):
|
311 |
-
super().__init__()
|
312 |
-
self.model = asteroid.models.DCUNet(**kwargs)
|
313 |
-
self.mrstftloss = MultiResolutionSTFTLoss(
|
314 |
-
n_bins=num_bins, sample_rate=sample_rate
|
315 |
-
)
|
316 |
-
self.l1loss = nn.L1Loss()
|
317 |
-
|
318 |
-
def forward(self, batch):
|
319 |
-
x, target = batch
|
320 |
-
output = self.model(x.squeeze(1)) # B x T
|
321 |
-
# Crop target to match output
|
322 |
-
if output.shape[-1] < target.shape[-1]:
|
323 |
-
target = causal_crop(target, output.shape[-1])
|
324 |
-
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
325 |
-
return loss, output
|
326 |
-
|
327 |
-
def sample(self, x: Tensor) -> Tensor:
|
328 |
-
output = self.model(x.squeeze(1)) # B x T
|
329 |
-
return output
|
330 |
-
|
331 |
|
332 |
class TCNModel(nn.Module):
|
333 |
def __init__(self, sample_rate, num_bins, **kwargs):
|
|
|
96 |
else:
|
97 |
negate = 1
|
98 |
self.log(
|
99 |
+
f"test_{metric}_"
|
100 |
+
+ "".join(self.effect_order).replace("RandomPedalboard", ""),
|
101 |
negate * self.metrics[metric](output, y),
|
102 |
on_step=False,
|
103 |
on_epoch=True,
|
|
|
308 |
def sample(self, x: Tensor) -> Tensor:
|
309 |
return self.model(x.squeeze(1))
|
310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
|
312 |
class TCNModel(nn.Module):
|
313 |
def __init__(self, sample_rate, num_bins, **kwargs):
|