mattricesound committed on
Commit
eae60a9
·
1 Parent(s): 561bfea

Add classifier to inference. Comment out hearbaseline for now.

Browse files
cfg/exp/chain_inference_aug_classifier.yaml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+ defaults:
3
+ - override /model: demucs
4
+ - override /effects: all
5
+ seed: 12345
6
+ sample_rate: 48000
7
+ chunk_size: 262144 # 2^18 samples, ~5.46 s at 48 kHz
8
+ logs_dir: "./logs"
9
+ render_root: "/scratch/EffectSet"
10
+ accelerator: "gpu"
11
+ log_audio: True
12
+ # Effects
13
+ num_kept_effects: [0,0] # [min, max]
14
+ num_removed_effects: [0,5] # [min, max]
15
+ shuffle_kept_effects: True
16
+ shuffle_removed_effects: True
17
+ num_classes: 5
18
+ effects_to_keep:
19
+ effects_to_remove:
20
+ - distortion
21
+ - compressor
22
+ - reverb
23
+ - chorus
24
+ - delay
25
+ datamodule:
26
+ batch_size: 16
27
+ num_workers: 8
28
+
29
+ dcunet:
30
+ _target_: remfx.models.RemFX
31
+ lr: 1e-4
32
+ lr_beta1: 0.95
33
+ lr_beta2: 0.999
34
+ lr_eps: 1e-6
35
+ lr_weight_decay: 1e-3
36
+ sample_rate: ${sample_rate}
37
+ network:
38
+ _target_: remfx.models.DCUNetModel
39
+ architecture: "Large-DCUNet-20"
40
+ stft_kernel_size: 512
41
+ fix_length_mode: "pad"
42
+ sample_rate: ${sample_rate}
43
+ num_bins: 1025
44
+
45
+ classifier:
46
+ _target_: remfx.models.FXClassifier
47
+ lr: 3e-4
48
+ lr_weight_decay: 1e-3
49
+ sample_rate: ${sample_rate}
50
+ network:
51
+ _target_: remfx.classifier.Cnn14
52
+ num_classes: ${num_classes}
53
+ n_fft: 1024
54
+ hop_length: 256
55
+ n_mels: 128
56
+ sample_rate: 44100
57
+ model_sample_rate: 44100
58
+ specaugment: False
59
+ classifier_ckpt: "ckpts/classifier.ckpt"
60
+
61
+ ckpts:
62
+ RandomPedalboardDistortion:
63
+ model: ${model}
64
+ ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
65
+ RandomPedalboardCompressor:
66
+ model: ${model}
67
+ ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
68
+ RandomPedalboardReverb:
69
+ model: ${dcunet}
70
+ ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
71
+ RandomPedalboardChorus:
72
+ model: ${dcunet}
73
+ ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
74
+ RandomPedalboardDelay:
75
+ model: ${dcunet}
76
+ ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
77
+
78
+ inference_effects_ordering:
79
+ - "RandomPedalboardDistortion"
80
+ - "RandomPedalboardCompressor"
81
+ - "RandomPedalboardReverb"
82
+ - "RandomPedalboardChorus"
83
+ - "RandomPedalboardDelay"
84
+ num_bins: 1025
remfx/classifier.py CHANGED
@@ -1,9 +1,11 @@
1
  import torch
2
  import torchaudio
3
  import torch.nn as nn
4
- import hearbaseline
5
- import hearbaseline.vggish
6
- import hearbaseline.wav2vec2
 
 
7
 
8
  import wav2clip_hear
9
  import panns_hear
 
1
  import torch
2
  import torchaudio
3
  import torch.nn as nn
4
+
5
+ # import hearbaseline
6
+
7
+ # import hearbaseline.vggish
8
+ # import hearbaseline.wav2vec2
9
 
10
  import wav2clip_hear
11
  import panns_hear
remfx/models.py CHANGED
@@ -20,7 +20,7 @@ ALL_EFFECTS = effects.Pedalboard_Effects
20
 
21
 
22
  class RemFXChainInference(pl.LightningModule):
23
- def __init__(self, models, sample_rate, num_bins, effect_order):
24
  super().__init__()
25
  self.model = models
26
  self.mrstftloss = MultiResolutionSTFTLoss(
@@ -35,6 +35,7 @@ class RemFXChainInference(pl.LightningModule):
35
  )
36
  self.sample_rate = sample_rate
37
  self.effect_order = effect_order
 
38
 
39
  def forward(self, batch, batch_idx, order=None):
40
  x, y, _, rem_fx_labels = batch
@@ -43,6 +44,13 @@ class RemFXChainInference(pl.LightningModule):
43
  effects_order = order
44
  else:
45
  effects_order = self.effect_order
 
 
 
 
 
 
 
46
  effects_present = [
47
  [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
48
  for effect_label in rem_fx_labels
 
20
 
21
 
22
  class RemFXChainInference(pl.LightningModule):
23
+ def __init__(self, models, sample_rate, num_bins, effect_order, classifier=None):
24
  super().__init__()
25
  self.model = models
26
  self.mrstftloss = MultiResolutionSTFTLoss(
 
35
  )
36
  self.sample_rate = sample_rate
37
  self.effect_order = effect_order
38
+ self.classifier = classifier
39
 
40
  def forward(self, batch, batch_idx, order=None):
41
  x, y, _, rem_fx_labels = batch
 
44
  effects_order = order
45
  else:
46
  effects_order = self.effect_order
47
+
48
+ # Use classifier labels
49
+ if self.classifier:
50
+ threshold = 0.5
51
+ labels = self.classifier(x)
52
+ rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
53
+
54
  effects_present = [
55
  [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
56
  for effect_label in rem_fx_labels
scripts/chain_inference.py CHANGED
@@ -26,6 +26,16 @@ def main(cfg: DictConfig):
26
  model.to(device)
27
  models[effect] = model
28
 
 
 
 
 
 
 
 
 
 
 
29
  callbacks = []
30
  if "callbacks" in cfg:
31
  for _, cb_conf in cfg["callbacks"].items():
@@ -54,6 +64,7 @@ def main(cfg: DictConfig):
54
  sample_rate=cfg.sample_rate,
55
  num_bins=cfg.num_bins,
56
  effect_order=cfg.inference_effects_ordering,
 
57
  )
58
  trainer.test(model=inference_model, datamodule=datamodule)
59
 
 
26
  model.to(device)
27
  models[effect] = model
28
 
29
+ classifier = None
30
+ if "classifier" in cfg:
31
+ log.info(f"Instantiating classifier <{cfg.classifier._target_}>.")
32
+ classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
33
+ ckpt_path = cfg.classifier_ckpt
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
36
+ classifier.load_state_dict(state_dict)
37
+ classifier.to(device)
38
+
39
  callbacks = []
40
  if "callbacks" in cfg:
41
  for _, cb_conf in cfg["callbacks"].items():
 
64
  sample_rate=cfg.sample_rate,
65
  num_bins=cfg.num_bins,
66
  effect_order=cfg.inference_effects_ordering,
67
+ classifier=classifier,
68
  )
69
  trainer.test(model=inference_model, datamodule=datamodule)
70