Spaces:
Sleeping
Sleeping
Commit
·
ace4057
1
Parent(s):
e8eaf47
Add custom model choice for chain inference
Browse files
- cfg/exp/chain_inference.yaml +30 -5
- cfg/exp/chain_inference_aug.yaml +30 -5
- cfg/exp/chain_inference_custom.yaml +30 -5
- remfx/callbacks.py +17 -1
- scripts/chain_inference.py +2 -2
- scripts/train.py +10 -1
cfg/exp/chain_inference.yaml
CHANGED
@@ -26,12 +26,37 @@ datamodule:
|
|
26 |
batch_size: 16
|
27 |
num_workers: 8
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
ckpts:
|
30 |
-
RandomPedalboardDistortion:
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
inference_effects_ordering:
|
36 |
- "RandomPedalboardDistortion"
|
37 |
- "RandomPedalboardCompressor"
|
|
|
26 |
batch_size: 16
|
27 |
num_workers: 8
|
28 |
|
29 |
+
dcunet:
|
30 |
+
_target_: remfx.models.RemFX
|
31 |
+
lr: 1e-4
|
32 |
+
lr_beta1: 0.95
|
33 |
+
lr_beta2: 0.999
|
34 |
+
lr_eps: 1e-6
|
35 |
+
lr_weight_decay: 1e-3
|
36 |
+
sample_rate: ${sample_rate}
|
37 |
+
network:
|
38 |
+
_target_: remfx.models.DCUNetModel
|
39 |
+
architecture: "Large-DCUNet-20"
|
40 |
+
stft_kernel_size: 512
|
41 |
+
fix_length_mode: "pad"
|
42 |
+
sample_rate: ${sample_rate}
|
43 |
+
num_bins: 1025
|
44 |
ckpts:
|
45 |
+
RandomPedalboardDistortion:
|
46 |
+
model: ${model}
|
47 |
+
ckpt_path: "ckpts/demucs_distortion.ckpt"
|
48 |
+
RandomPedalboardCompressor:
|
49 |
+
model: ${model}
|
50 |
+
ckpt_path: "ckpts/demucs_compressor.ckpt"
|
51 |
+
RandomPedalboardReverb:
|
52 |
+
model: ${dcunet}
|
53 |
+
ckpt_path: "ckpts/dcunet_reverb.ckpt"
|
54 |
+
RandomPedalboardChorus:
|
55 |
+
model: ${dcunet}
|
56 |
+
ckpt_path: "ckpts/dcunet_chorus.ckpt"
|
57 |
+
RandomPedalboardDelay:
|
58 |
+
model: ${dcunet}
|
59 |
+
ckpt_path: "ckpts/dcunet_delay.ckpt"
|
60 |
inference_effects_ordering:
|
61 |
- "RandomPedalboardDistortion"
|
62 |
- "RandomPedalboardCompressor"
|
cfg/exp/chain_inference_aug.yaml
CHANGED
@@ -26,12 +26,37 @@ datamodule:
|
|
26 |
batch_size: 16
|
27 |
num_workers: 8
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
ckpts:
|
30 |
-
RandomPedalboardDistortion:
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
inference_effects_ordering:
|
36 |
- "RandomPedalboardDistortion"
|
37 |
- "RandomPedalboardCompressor"
|
|
|
26 |
batch_size: 16
|
27 |
num_workers: 8
|
28 |
|
29 |
+
dcunet:
|
30 |
+
_target_: remfx.models.RemFX
|
31 |
+
lr: 1e-4
|
32 |
+
lr_beta1: 0.95
|
33 |
+
lr_beta2: 0.999
|
34 |
+
lr_eps: 1e-6
|
35 |
+
lr_weight_decay: 1e-3
|
36 |
+
sample_rate: ${sample_rate}
|
37 |
+
network:
|
38 |
+
_target_: remfx.models.DCUNetModel
|
39 |
+
architecture: "Large-DCUNet-20"
|
40 |
+
stft_kernel_size: 512
|
41 |
+
fix_length_mode: "pad"
|
42 |
+
sample_rate: ${sample_rate}
|
43 |
+
num_bins: 1025
|
44 |
ckpts:
|
45 |
+
RandomPedalboardDistortion:
|
46 |
+
model: ${model}
|
47 |
+
ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
|
48 |
+
RandomPedalboardCompressor:
|
49 |
+
model: ${model}
|
50 |
+
ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
|
51 |
+
RandomPedalboardReverb:
|
52 |
+
model: ${dcunet}
|
53 |
+
ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
|
54 |
+
RandomPedalboardChorus:
|
55 |
+
model: ${dcunet}
|
56 |
+
ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
|
57 |
+
RandomPedalboardDelay:
|
58 |
+
model: ${dcunet}
|
59 |
+
ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
|
60 |
inference_effects_ordering:
|
61 |
- "RandomPedalboardDistortion"
|
62 |
- "RandomPedalboardCompressor"
|
cfg/exp/chain_inference_custom.yaml
CHANGED
@@ -31,12 +31,37 @@ datamodule:
|
|
31 |
_target_: remfx.datasets.InferenceDataset
|
32 |
root: ${oc.env:DATASET_ROOT}
|
33 |
sample_rate: ${sample_rate}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
ckpts:
|
35 |
-
RandomPedalboardDistortion:
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
inference_effects_ordering:
|
41 |
- "RandomPedalboardDistortion"
|
42 |
- "RandomPedalboardCompressor"
|
|
|
31 |
_target_: remfx.datasets.InferenceDataset
|
32 |
root: ${oc.env:DATASET_ROOT}
|
33 |
sample_rate: ${sample_rate}
|
34 |
+
dcunet:
|
35 |
+
_target_: remfx.models.RemFX
|
36 |
+
lr: 1e-4
|
37 |
+
lr_beta1: 0.95
|
38 |
+
lr_beta2: 0.999
|
39 |
+
lr_eps: 1e-6
|
40 |
+
lr_weight_decay: 1e-3
|
41 |
+
sample_rate: ${sample_rate}
|
42 |
+
network:
|
43 |
+
_target_: remfx.models.DCUNetModel
|
44 |
+
architecture: "Large-DCUNet-20"
|
45 |
+
stft_kernel_size: 512
|
46 |
+
fix_length_mode: "pad"
|
47 |
+
sample_rate: ${sample_rate}
|
48 |
+
num_bins: 1025
|
49 |
ckpts:
|
50 |
+
RandomPedalboardDistortion:
|
51 |
+
model: ${model}
|
52 |
+
ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
|
53 |
+
RandomPedalboardCompressor:
|
54 |
+
model: ${model}
|
55 |
+
ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
|
56 |
+
RandomPedalboardReverb:
|
57 |
+
model: ${dcunet}
|
58 |
+
ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
|
59 |
+
RandomPedalboardChorus:
|
60 |
+
model: ${dcunet}
|
61 |
+
ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
|
62 |
+
RandomPedalboardDelay:
|
63 |
+
model: ${dcunet}
|
64 |
+
ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
|
65 |
inference_effects_ordering:
|
66 |
- "RandomPedalboardDistortion"
|
67 |
- "RandomPedalboardCompressor"
|
remfx/callbacks.py
CHANGED
@@ -4,6 +4,9 @@ from einops import rearrange
|
|
4 |
import torch
|
5 |
import wandb
|
6 |
from torch import Tensor
|
|
|
|
|
|
|
7 |
|
8 |
|
9 |
class AudioCallback(Callback):
|
@@ -42,7 +45,7 @@ class AudioCallback(Callback):
|
|
42 |
def on_validation_batch_start(
|
43 |
self, trainer, pl_module, batch, batch_idx, dataloader_idx
|
44 |
):
|
45 |
-
x, target, _,
|
46 |
# Only run on first batch
|
47 |
if batch_idx == 0 and self.log_audio:
|
48 |
with torch.no_grad():
|
@@ -51,6 +54,19 @@ class AudioCallback(Callback):
|
|
51 |
|
52 |
if type(pl_module) == RemFXChainInference:
|
53 |
y = pl_module.sample(batch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
else:
|
55 |
y = pl_module.model.sample(x)
|
56 |
# Concat samples together for easier viewing in dashboard
|
|
|
4 |
import torch
|
5 |
import wandb
|
6 |
from torch import Tensor
|
7 |
+
from remfx import effects
|
8 |
+
|
9 |
+
ALL_EFFECTS = effects.Pedalboard_Effects
|
10 |
|
11 |
|
12 |
class AudioCallback(Callback):
|
|
|
45 |
def on_validation_batch_start(
|
46 |
self, trainer, pl_module, batch, batch_idx, dataloader_idx
|
47 |
):
|
48 |
+
x, target, _, rem_fx_labels = batch
|
49 |
# Only run on first batch
|
50 |
if batch_idx == 0 and self.log_audio:
|
51 |
with torch.no_grad():
|
|
|
54 |
|
55 |
if type(pl_module) == RemFXChainInference:
|
56 |
y = pl_module.sample(batch)
|
57 |
+
effects_present_name = [
|
58 |
+
[
|
59 |
+
ALL_EFFECTS[i].__name__.replace("RandomPedalboard", "")
|
60 |
+
for i, effect in enumerate(effect_label)
|
61 |
+
if effect == 1.0
|
62 |
+
]
|
63 |
+
for effect_label in rem_fx_labels
|
64 |
+
]
|
65 |
+
for i, label in enumerate(effects_present_name):
|
66 |
+
self.log(f"{'_'.join(label)}", 0.0)
|
67 |
+
# self.log(f"{effects}_{i}", label)
|
68 |
+
# trainer.logger.experiment.log(
|
69 |
+
# {f"effects_{i}": f"{'_'.join(label)}"}
|
70 |
else:
|
71 |
y = pl_module.model.sample(x)
|
72 |
# Concat samples together for easier viewing in dashboard
|
scripts/chain_inference.py
CHANGED
@@ -18,8 +18,8 @@ def main(cfg: DictConfig):
|
|
18 |
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
19 |
models = {}
|
20 |
for effect in cfg.ckpts:
|
21 |
-
|
22 |
-
|
23 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
|
25 |
model.load_state_dict(state_dict)
|
|
|
18 |
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
19 |
models = {}
|
20 |
for effect in cfg.ckpts:
|
21 |
+
model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
|
22 |
+
ckpt_path = cfg.ckpts[effect].ckpt_path
|
23 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
|
25 |
model.load_state_dict(state_dict)
|
scripts/train.py
CHANGED
@@ -18,7 +18,16 @@ def main(cfg: DictConfig):
|
|
18 |
|
19 |
if "ckpt_path" in cfg:
|
20 |
log.info(f"Loading checkpoint from <{cfg.ckpt_path}>.")
|
21 |
-
model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
# Init all callbacks
|
24 |
callbacks = []
|
|
|
18 |
|
19 |
if "ckpt_path" in cfg:
|
20 |
log.info(f"Loading checkpoint from <{cfg.ckpt_path}>.")
|
21 |
+
model.load_from_checkpoint(
|
22 |
+
cfg.ckpt_path,
|
23 |
+
lr=model.lr,
|
24 |
+
lr_beta1=model.lr_beta1,
|
25 |
+
lr_beta2=model.lr_beta2,
|
26 |
+
lr_eps=model.lr_eps,
|
27 |
+
lr_weight_decay=model.lr_weight_decay,
|
28 |
+
sample_rate=model.sample_rate,
|
29 |
+
network=model.model,
|
30 |
+
)
|
31 |
|
32 |
# Init all callbacks
|
33 |
callbacks = []
|