mattricesound committed
Commit 652f240 · 1 Parent(s): 30c1d67

Add effects chain inference code

cfg/config.yaml CHANGED
@@ -51,7 +51,7 @@ datamodule:
   _target_: remfx.datasets.EffectDatamodule
   train_dataset:
     _target_: remfx.datasets.EffectDataset
-    total_chunks: 8000
+    total_chunks: 80
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
     chunk_size: ${chunk_size}
@@ -67,7 +67,7 @@ datamodule:
     render_root: ${render_root}
   val_dataset:
     _target_: remfx.datasets.EffectDataset
-    total_chunks: 1000
+    total_chunks: 10
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
     chunk_size: ${chunk_size}
@@ -83,7 +83,7 @@ datamodule:
     render_root: ${render_root}
   test_dataset:
     _target_: remfx.datasets.EffectDataset
-    total_chunks: 1000
+    total_chunks: 10
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
     chunk_size: ${chunk_size}
@@ -124,3 +124,10 @@ trainer:
   devices: 1
   gradient_clip_val: 10.0
   max_steps: 50000
+ckpts:
+  RandomPedalboardChorus: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
+  RandomPedalboardDelay: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
+  RandomPedalboardDistortion: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
+  RandomPedalboardCompressor: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
+  RandomPedalboardReverb: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
+num_bins: 1025
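
Note on the new keys: ckpts maps each effect class name to the checkpoint of the model trained to remove that effect, and num_bins is forwarded to the multi-resolution STFT loss inside RemFXChainInference (see remfx/models.py below). Because the chain-inference model looks its per-effect models up by effect.__name__, the keys of ckpts must match the class names in remfx.effects.Pedalboard_Effects. A minimal sanity check, sketched under the assumption that the config file is loadable with OmegaConf (the check itself is not part of this commit):

    # Hypothetical check: every ckpts key should name a class in
    # remfx.effects.Pedalboard_Effects, since RemFXChainInference indexes its
    # models dict with effect.__name__.
    from omegaconf import OmegaConf
    from remfx import effects

    cfg = OmegaConf.load("cfg/config.yaml")
    known = {cls.__name__ for cls in effects.Pedalboard_Effects}
    missing = [name for name in cfg.ckpts if name not in known]
    assert not missing, f"ckpts entries without a matching effect class: {missing}"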
cfg/exp/chain_inference.yaml ADDED
@@ -0,0 +1,35 @@
+# @package _global_
+defaults:
+  - override /model: demucs
+  - override /effects: all
+seed: 12345
+sample_rate: 48000
+chunk_size: 262144 # 5.5s
+logs_dir: "./logs"
+render_files: True
+render_root: "/scratch/EffectSet"
+accelerator: "gpu"
+log_audio: True
+# Effects
+num_kept_effects: [0,0] # [min, max]
+num_removed_effects: [0,5] # [min, max]
+shuffle_kept_effects: True
+shuffle_removed_effects: True
+num_classes: 5
+effects_to_keep:
+effects_to_remove:
+  - distortion
+  - compressor
+  - reverb
+  - chorus
+  - delay
+datamodule:
+  batch_size: 16
+  num_workers: 8
+ckpts:
+  RandomPedalboardChorus: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
+  RandomPedalboardDelay: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
+  RandomPedalboardDistortion: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
+  RandomPedalboardCompressor: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
+  RandomPedalboardReverb: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
+num_bins: 1025
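
This experiment keeps no effects and removes up to five (num_removed_effects: [0,5]); the dataset therefore attaches a five-dimensional effect-presence label to each example, which the chain-inference model decodes into a sequence of removal models. A small illustration of that decoding, assuming the ordering of remfx.effects.Pedalboard_Effects matches the label encoding, as the models.py diff below does:

    # Hypothetical illustration: turn a 5-dim effect-presence label into a chain
    # of effect classes, mirroring RemFXChainInference.forward below.
    import torch
    from remfx import effects

    ALL_EFFECTS = effects.Pedalboard_Effects
    label = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0])  # example label, not taken from the dataset
    chain = [ALL_EFFECTS[i] for i, present in enumerate(label) if present == 1.0]
    print([cls.__name__ for cls in chain])  # the matching removal models are applied in this order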
cfg/exp/{dist.yaml → distortion.yaml} RENAMED
File without changes
remfx/callbacks.py CHANGED
@@ -4,6 +4,7 @@ from einops import rearrange
 import torch
 import wandb
 from torch import Tensor
+from remfx.models import RemFXChainInference


 class AudioCallback(Callback):
@@ -46,7 +47,10 @@ class AudioCallback(Callback):
         # Only run on first batch
         if batch_idx == 0 and self.log_audio:
             with torch.no_grad():
-                y = pl_module.model.sample(x)
+                if type(pl_module) == RemFXChainInference:
+                    y = pl_module.sample(batch)
+                else:
+                    y = pl_module.model.sample(x)
             # Concat samples together for easier viewing in dashboard
             # 2 seconds of silence between each sample
             silence = torch.zeros_like(x)
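
The audio-logging callback now dispatches on the module type: for the new RemFXChainInference it calls sample(batch), since the chain model needs the effect labels carried in the batch to decide which removal models to run, while the single-effect RemFX modules keep the existing path, y = pl_module.model.sample(x).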
remfx/models.py CHANGED
@@ -11,8 +11,78 @@ from umx.openunmix.model import OpenUnmix, Separator
 from remfx.utils import FADLoss, spectrogram
 from remfx.tcn import TCN
 from remfx.utils import causal_crop
+from remfx import effects
 import asteroid

+ALL_EFFECTS = effects.Pedalboard_Effects
+
+
+class RemFXChainInference(pl.LightningModule):
+    def __init__(self, models, sample_rate, num_bins):
+        super().__init__()
+        self.model = models
+        self.mrstftloss = MultiResolutionSTFTLoss(
+            n_bins=num_bins, sample_rate=sample_rate
+        )
+        self.l1loss = nn.L1Loss()
+        self.metrics = nn.ModuleDict(
+            {
+                "SISDR": SISDRLoss(),
+                "STFT": MultiResolutionSTFTLoss(),
+                "FAD": FADLoss(sample_rate=sample_rate),
+            }
+        )
+
+    def forward(self, batch):
+        x, y, _, rem_fx_labels = batch
+        # Use chain of effects defined in config
+        effects = [
+            [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
+            for effect_label in rem_fx_labels
+        ]
+        output = []
+        with torch.no_grad():
+            for elem, effect_chain in zip(x, effects):
+                elem = elem.unsqueeze(0)  # Add batch dim
+                for effect in effect_chain:
+                    # Get correct model based on effect name. This is a bit hacky
+                    # Then sample the model
+                    elem = self.model[effect.__name__].model.sample(elem)
+                output.append(elem.squeeze(0))
+            output = torch.stack(output)
+
+        loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
+        return loss, output
+
+    def test_step(self, batch, batch_idx):
+        x, y, _, _ = batch  # x, y = (B, C, T), (B, C, T)
+
+        loss, output = self.forward(batch)
+        # Crop target to match output
+        if output.shape[-1] < y.shape[-1]:
+            y = causal_crop(y, output.shape[-1])
+        self.log("test_loss", loss)
+        # Metric logging
+        with torch.no_grad():
+            for metric in self.metrics:
+                # SISDR returns negative values, so negate them
+                if metric == "SISDR":
+                    negate = -1
+                else:
+                    negate = 1
+                self.log(
+                    f"test_{metric}",
+                    negate * self.metrics[metric](output, y),
+                    on_step=False,
+                    on_epoch=True,
+                    logger=True,
+                    prog_bar=True,
+                    sync_dist=True,
+                )
+
+    def sample(self, batch):
+        return self.forward(batch)[1]
+

 class RemFX(pl.LightningModule):
     def __init__(
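
RemFXChainInference wraps a dict of already-trained single-effect removal models and applies them one example at a time, following the fixed ordering of ALL_EFFECTS for whichever effects the label marks as present; it defines forward, test_step, and sample but no training_step, so it is an evaluation-time wrapper rather than a trainable module. A minimal usage sketch, assuming models is a dict keyed by effect class name whose values expose .model.sample() as RemFX does, and that a datamodule yielding (x, y, dry/wet labels, effect labels) batches is available (both are built in scripts/chain_inference.py below):

    # Hypothetical usage sketch; `models` and `datamodule` are assumed to be
    # constructed as in scripts/chain_inference.py below.
    import torch
    from remfx.models import RemFXChainInference

    chain = RemFXChainInference(models, sample_rate=48000, num_bins=1025)
    batch = next(iter(datamodule.test_dataloader()))  # (x, y, dry/wet labels, effect labels)
    with torch.no_grad():
        restored = chain.sample(batch)                # (B, C, T) estimate with the effects removed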
scripts/chain_inference.py ADDED
@@ -0,0 +1,61 @@
+import pytorch_lightning as pl
+import hydra
+from omegaconf import DictConfig
+import remfx.utils as utils
+from pytorch_lightning.utilities.model_summary import ModelSummary
+import torch
+from remfx.models import RemFXChainInference
+
+log = utils.get_logger(__name__)
+
+
+@hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
+def main(cfg: DictConfig):
+    # Apply seed for reproducibility
+    if cfg.seed:
+        pl.seed_everything(cfg.seed)
+    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
+    datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
+    log.info(f"Instantiating model <{cfg.model._target_}>.")
+    models = {}
+    for effect in cfg.ckpts:
+        ckpt_path = cfg.ckpts[effect]
+        model = hydra.utils.instantiate(cfg.model, _convert_="partial")
+        state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[
+            "state_dict"
+        ]
+        model.load_state_dict(state_dict)
+        models[effect] = model
+
+    callbacks = []
+    if "callbacks" in cfg:
+        for _, cb_conf in cfg["callbacks"].items():
+            if "_target_" in cb_conf:
+                log.info(f"Instantiating callback <{cb_conf._target_}>.")
+                callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))
+
+    logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
+    trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
+    )
+    log.info("Logging hyperparameters!")
+    utils.log_hyperparameters(
+        config=cfg,
+        model=model,
+        datamodule=datamodule,
+        trainer=trainer,
+        callbacks=callbacks,
+        logger=logger,
+    )
+    summary = ModelSummary(model)
+    print(summary)
+
+    inference_model = RemFXChainInference(
+        models, sample_rate=cfg.sample_rate, num_bins=cfg.num_bins
+    )
+    trainer.test(model=inference_model, datamodule=datamodule)
+
+
+if __name__ == "__main__":
+    main()
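
The script builds one removal model per entry in cfg.ckpts and hands the resulting dict to RemFXChainInference before calling trainer.test. Note that log_hyperparameters and ModelSummary receive model, i.e. whichever per-effect model was loaded last in the loop; with the config above all five entries point at the same demucs checkpoint, so the summary is arguably still representative. Chain inference would then be launched through Hydra with the experiment config added above, e.g. something like python scripts/chain_inference.py +exp=chain_inference, though the exact override syntax depends on the defaults list in cfg/config.yaml, which is not shown in this diff.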