Spaces:
Sleeping
Sleeping
Commit
·
eae60a9
1
Parent(s):
561bfea
Add classifier to inference. Comment out hearbaseline for now.
Browse files
- cfg/exp/chain_inference_aug_classifier.yaml +84 -0
- remfx/classifier.py +5 -3
- remfx/models.py +9 -1
- scripts/chain_inference.py +11 -0
cfg/exp/chain_inference_aug_classifier.yaml
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
defaults:
|
3 |
+
- override /model: demucs
|
4 |
+
- override /effects: all
|
5 |
+
seed: 12345
|
6 |
+
sample_rate: 48000
|
7 |
+
chunk_size: 262144 # 2^18 samples, ~5.46 s at 48 kHz
|
8 |
+
logs_dir: "./logs"
|
9 |
+
render_root: "/scratch/EffectSet"
|
10 |
+
accelerator: "gpu"
|
11 |
+
log_audio: True
|
12 |
+
# Effects
|
13 |
+
num_kept_effects: [0,0] # [min, max]
|
14 |
+
num_removed_effects: [0,5] # [min, max]
|
15 |
+
shuffle_kept_effects: True
|
16 |
+
shuffle_removed_effects: True
|
17 |
+
num_classes: 5
|
18 |
+
effects_to_keep:
|
19 |
+
effects_to_remove:
|
20 |
+
- distortion
|
21 |
+
- compressor
|
22 |
+
- reverb
|
23 |
+
- chorus
|
24 |
+
- delay
|
25 |
+
datamodule:
|
26 |
+
batch_size: 16
|
27 |
+
num_workers: 8
|
28 |
+
|
29 |
+
dcunet:
|
30 |
+
_target_: remfx.models.RemFX
|
31 |
+
lr: 1e-4
|
32 |
+
lr_beta1: 0.95
|
33 |
+
lr_beta2: 0.999
|
34 |
+
lr_eps: 1e-6
|
35 |
+
lr_weight_decay: 1e-3
|
36 |
+
sample_rate: ${sample_rate}
|
37 |
+
network:
|
38 |
+
_target_: remfx.models.DCUNetModel
|
39 |
+
architecture: "Large-DCUNet-20"
|
40 |
+
stft_kernel_size: 512
|
41 |
+
fix_length_mode: "pad"
|
42 |
+
sample_rate: ${sample_rate}
|
43 |
+
num_bins: 1025
|
44 |
+
|
45 |
+
classifier:
|
46 |
+
_target_: remfx.models.FXClassifier
|
47 |
+
lr: 3e-4
|
48 |
+
lr_weight_decay: 1e-3
|
49 |
+
sample_rate: ${sample_rate}
|
50 |
+
network:
|
51 |
+
_target_: remfx.classifier.Cnn14
|
52 |
+
num_classes: ${num_classes}
|
53 |
+
n_fft: 1024
|
54 |
+
hop_length: 256
|
55 |
+
n_mels: 128
|
56 |
+
sample_rate: 44100
|
57 |
+
model_sample_rate: 44100
|
58 |
+
specaugment: False
|
59 |
+
classifier_ckpt: "ckpts/classifier.ckpt"
|
60 |
+
|
61 |
+
ckpts:
|
62 |
+
RandomPedalboardDistortion:
|
63 |
+
model: ${model}
|
64 |
+
ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
|
65 |
+
RandomPedalboardCompressor:
|
66 |
+
model: ${model}
|
67 |
+
ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
|
68 |
+
RandomPedalboardReverb:
|
69 |
+
model: ${dcunet}
|
70 |
+
ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
|
71 |
+
RandomPedalboardChorus:
|
72 |
+
model: ${dcunet}
|
73 |
+
ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
|
74 |
+
RandomPedalboardDelay:
|
75 |
+
model: ${dcunet}
|
76 |
+
ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
|
77 |
+
|
78 |
+
inference_effects_ordering:
|
79 |
+
- "RandomPedalboardDistortion"
|
80 |
+
- "RandomPedalboardCompressor"
|
81 |
+
- "RandomPedalboardReverb"
|
82 |
+
- "RandomPedalboardChorus"
|
83 |
+
- "RandomPedalboardDelay"
|
84 |
+
num_bins: 1025
|
remfx/classifier.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
import torch
|
2 |
import torchaudio
|
3 |
import torch.nn as nn
|
4 |
-
|
5 |
-
import hearbaseline
|
6 |
-
|
|
|
|
|
7 |
|
8 |
import wav2clip_hear
|
9 |
import panns_hear
|
|
|
1 |
import torch
|
2 |
import torchaudio
|
3 |
import torch.nn as nn
|
4 |
+
|
5 |
+
# import hearbaseline
|
6 |
+
|
7 |
+
# import hearbaseline.vggish
|
8 |
+
# import hearbaseline.wav2vec2
|
9 |
|
10 |
import wav2clip_hear
|
11 |
import panns_hear
|
remfx/models.py
CHANGED
@@ -20,7 +20,7 @@ ALL_EFFECTS = effects.Pedalboard_Effects
|
|
20 |
|
21 |
|
22 |
class RemFXChainInference(pl.LightningModule):
|
23 |
-
def __init__(self, models, sample_rate, num_bins, effect_order):
|
24 |
super().__init__()
|
25 |
self.model = models
|
26 |
self.mrstftloss = MultiResolutionSTFTLoss(
|
@@ -35,6 +35,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
35 |
)
|
36 |
self.sample_rate = sample_rate
|
37 |
self.effect_order = effect_order
|
|
|
38 |
|
39 |
def forward(self, batch, batch_idx, order=None):
|
40 |
x, y, _, rem_fx_labels = batch
|
@@ -43,6 +44,13 @@ class RemFXChainInference(pl.LightningModule):
|
|
43 |
effects_order = order
|
44 |
else:
|
45 |
effects_order = self.effect_order
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
effects_present = [
|
47 |
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
|
48 |
for effect_label in rem_fx_labels
|
|
|
20 |
|
21 |
|
22 |
class RemFXChainInference(pl.LightningModule):
|
23 |
+
def __init__(self, models, sample_rate, num_bins, effect_order, classifier=None):
|
24 |
super().__init__()
|
25 |
self.model = models
|
26 |
self.mrstftloss = MultiResolutionSTFTLoss(
|
|
|
35 |
)
|
36 |
self.sample_rate = sample_rate
|
37 |
self.effect_order = effect_order
|
38 |
+
self.classifier = classifier
|
39 |
|
40 |
def forward(self, batch, batch_idx, order=None):
|
41 |
x, y, _, rem_fx_labels = batch
|
|
|
44 |
effects_order = order
|
45 |
else:
|
46 |
effects_order = self.effect_order
|
47 |
+
|
48 |
+
# Use classifier labels
|
49 |
+
if self.classifier:
|
50 |
+
threshold = 0.5
|
51 |
+
labels = self.classifier(x)
|
52 |
+
rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
|
53 |
+
|
54 |
effects_present = [
|
55 |
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
|
56 |
for effect_label in rem_fx_labels
|
scripts/chain_inference.py
CHANGED
@@ -26,6 +26,16 @@ def main(cfg: DictConfig):
|
|
26 |
model.to(device)
|
27 |
models[effect] = model
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
callbacks = []
|
30 |
if "callbacks" in cfg:
|
31 |
for _, cb_conf in cfg["callbacks"].items():
|
@@ -54,6 +64,7 @@ def main(cfg: DictConfig):
|
|
54 |
sample_rate=cfg.sample_rate,
|
55 |
num_bins=cfg.num_bins,
|
56 |
effect_order=cfg.inference_effects_ordering,
|
|
|
57 |
)
|
58 |
trainer.test(model=inference_model, datamodule=datamodule)
|
59 |
|
|
|
26 |
model.to(device)
|
27 |
models[effect] = model
|
28 |
|
29 |
+
classifier = None
|
30 |
+
if "classifier" in cfg:
|
31 |
+
log.info(f"Instantiating classifier <{cfg.classifier._target_}>.")
|
32 |
+
classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
|
33 |
+
ckpt_path = cfg.classifier_ckpt
|
34 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
+
state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
|
36 |
+
classifier.load_state_dict(state_dict)
|
37 |
+
classifier.to(device)
|
38 |
+
|
39 |
callbacks = []
|
40 |
if "callbacks" in cfg:
|
41 |
for _, cb_conf in cfg["callbacks"].items():
|
|
|
64 |
sample_rate=cfg.sample_rate,
|
65 |
num_bins=cfg.num_bins,
|
66 |
effect_order=cfg.inference_effects_ordering,
|
67 |
+
classifier=classifier,
|
68 |
)
|
69 |
trainer.test(model=inference_model, datamodule=datamodule)
|
70 |
|