Spaces:
Sleeping
Sleeping
Commit
·
652f240
1
Parent(s):
30c1d67
Add effects chain inference code
Browse files- cfg/config.yaml +10 -3
- cfg/exp/chain_inference.yaml +35 -0
- cfg/exp/{dist.yaml → distortion.yaml} +0 -0
- remfx/callbacks.py +5 -1
- remfx/models.py +70 -0
- scripts/chain_inference.py +61 -0
cfg/config.yaml
CHANGED
@@ -51,7 +51,7 @@ datamodule:
|
|
51 |
_target_: remfx.datasets.EffectDatamodule
|
52 |
train_dataset:
|
53 |
_target_: remfx.datasets.EffectDataset
|
54 |
-
total_chunks:
|
55 |
sample_rate: ${sample_rate}
|
56 |
root: ${oc.env:DATASET_ROOT}
|
57 |
chunk_size: ${chunk_size}
|
@@ -67,7 +67,7 @@ datamodule:
|
|
67 |
render_root: ${render_root}
|
68 |
val_dataset:
|
69 |
_target_: remfx.datasets.EffectDataset
|
70 |
-
total_chunks:
|
71 |
sample_rate: ${sample_rate}
|
72 |
root: ${oc.env:DATASET_ROOT}
|
73 |
chunk_size: ${chunk_size}
|
@@ -83,7 +83,7 @@ datamodule:
|
|
83 |
render_root: ${render_root}
|
84 |
test_dataset:
|
85 |
_target_: remfx.datasets.EffectDataset
|
86 |
-
total_chunks:
|
87 |
sample_rate: ${sample_rate}
|
88 |
root: ${oc.env:DATASET_ROOT}
|
89 |
chunk_size: ${chunk_size}
|
@@ -124,3 +124,10 @@ trainer:
|
|
124 |
devices: 1
|
125 |
gradient_clip_val: 10.0
|
126 |
max_steps: 50000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
_target_: remfx.datasets.EffectDatamodule
|
52 |
train_dataset:
|
53 |
_target_: remfx.datasets.EffectDataset
|
54 |
+
total_chunks: 80
|
55 |
sample_rate: ${sample_rate}
|
56 |
root: ${oc.env:DATASET_ROOT}
|
57 |
chunk_size: ${chunk_size}
|
|
|
67 |
render_root: ${render_root}
|
68 |
val_dataset:
|
69 |
_target_: remfx.datasets.EffectDataset
|
70 |
+
total_chunks: 10
|
71 |
sample_rate: ${sample_rate}
|
72 |
root: ${oc.env:DATASET_ROOT}
|
73 |
chunk_size: ${chunk_size}
|
|
|
83 |
render_root: ${render_root}
|
84 |
test_dataset:
|
85 |
_target_: remfx.datasets.EffectDataset
|
86 |
+
total_chunks: 10
|
87 |
sample_rate: ${sample_rate}
|
88 |
root: ${oc.env:DATASET_ROOT}
|
89 |
chunk_size: ${chunk_size}
|
|
|
124 |
devices: 1
|
125 |
gradient_clip_val: 10.0
|
126 |
max_steps: 50000
|
127 |
+
ckpts:
|
128 |
+
RandomPedalboardChorus: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
|
129 |
+
RandomPedalboardDelay: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
|
130 |
+
RandomPedalboardDistortion: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
|
131 |
+
RandomPedalboardCompressor: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
|
132 |
+
RandomPedalboardReverb: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
|
133 |
+
num_bins: 1025
|
cfg/exp/chain_inference.yaml
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
defaults:
|
3 |
+
- override /model: demucs
|
4 |
+
- override /effects: all
|
5 |
+
seed: 12345
|
6 |
+
sample_rate: 48000
|
7 |
+
chunk_size: 262144 # 5.5s
|
8 |
+
logs_dir: "./logs"
|
9 |
+
render_files: True
|
10 |
+
render_root: "/scratch/EffectSet"
|
11 |
+
accelerator: "gpu"
|
12 |
+
log_audio: True
|
13 |
+
# Effects
|
14 |
+
num_kept_effects: [0,0] # [min, max]
|
15 |
+
num_removed_effects: [0,5] # [min, max]
|
16 |
+
shuffle_kept_effects: True
|
17 |
+
shuffle_removed_effects: True
|
18 |
+
num_classes: 5
|
19 |
+
effects_to_keep:
|
20 |
+
effects_to_remove:
|
21 |
+
- distortion
|
22 |
+
- compressor
|
23 |
+
- reverb
|
24 |
+
- chorus
|
25 |
+
- delay
|
26 |
+
datamodule:
|
27 |
+
batch_size: 16
|
28 |
+
num_workers: 8
|
29 |
+
ckpts:
|
30 |
+
RandomPedalboardChorus: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
|
31 |
+
RandomPedalboardDelay: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
|
32 |
+
RandomPedalboardDistortion: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
|
33 |
+
RandomPedalboardCompressor: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
|
34 |
+
RandomPedalboardReverb: "/Users/matthewrice/Developer/remfx/ckpts/demucs_5-5/last.ckpt"
|
35 |
+
num_bins: 1025
|
cfg/exp/{dist.yaml → distortion.yaml}
RENAMED
File without changes
|
remfx/callbacks.py
CHANGED
@@ -4,6 +4,7 @@ from einops import rearrange
|
|
4 |
import torch
|
5 |
import wandb
|
6 |
from torch import Tensor
|
|
|
7 |
|
8 |
|
9 |
class AudioCallback(Callback):
|
@@ -46,7 +47,10 @@ class AudioCallback(Callback):
|
|
46 |
# Only run on first batch
|
47 |
if batch_idx == 0 and self.log_audio:
|
48 |
with torch.no_grad():
|
49 |
-
|
|
|
|
|
|
|
50 |
# Concat samples together for easier viewing in dashboard
|
51 |
# 2 seconds of silence between each sample
|
52 |
silence = torch.zeros_like(x)
|
|
|
4 |
import torch
|
5 |
import wandb
|
6 |
from torch import Tensor
|
7 |
+
from remfx.models import RemFXChainInference
|
8 |
|
9 |
|
10 |
class AudioCallback(Callback):
|
|
|
47 |
# Only run on first batch
|
48 |
if batch_idx == 0 and self.log_audio:
|
49 |
with torch.no_grad():
|
50 |
+
if type(pl_module) == RemFXChainInference:
|
51 |
+
y = pl_module.sample(batch)
|
52 |
+
else:
|
53 |
+
y = pl_module.model.sample(x)
|
54 |
# Concat samples together for easier viewing in dashboard
|
55 |
# 2 seconds of silence between each sample
|
56 |
silence = torch.zeros_like(x)
|
remfx/models.py
CHANGED
@@ -11,8 +11,78 @@ from umx.openunmix.model import OpenUnmix, Separator
|
|
11 |
from remfx.utils import FADLoss, spectrogram
|
12 |
from remfx.tcn import TCN
|
13 |
from remfx.utils import causal_crop
|
|
|
14 |
import asteroid
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
class RemFX(pl.LightningModule):
|
18 |
def __init__(
|
|
|
11 |
from remfx.utils import FADLoss, spectrogram
|
12 |
from remfx.tcn import TCN
|
13 |
from remfx.utils import causal_crop
|
14 |
+
from remfx import effects
|
15 |
import asteroid
|
16 |
|
17 |
+
# Indexable collection of effect classes; label position i corresponds to
# ALL_EFFECTS[i], and each entry's __name__ is used to look up its model.
ALL_EFFECTS = effects.Pedalboard_Effects
|
18 |
+
|
19 |
+
|
20 |
+
class RemFXChainInference(pl.LightningModule):
    """Test-time module that removes a chain of effects with per-effect models.

    For every example in a batch, the removal labels select entries of
    ``ALL_EFFECTS``; the model stored under each selected effect's class name
    is then applied in turn (in ``ALL_EFFECTS`` index order) through its
    ``.model.sample`` method.

    Args:
        models: Mapping from effect class name (``effect.__name__``) to a
            model exposing ``.model.sample``.
        sample_rate: Sample rate forwarded to the STFT loss and the FAD metric.
        num_bins: Forwarded as ``n_bins`` to ``MultiResolutionSTFTLoss``.
    """

    def __init__(self, models, sample_rate, num_bins):
        super().__init__()
        # A plain dict is invisible to nn.Module, so the sub-models would not
        # follow this module to the trainer's device nor be switched to eval
        # mode. Register them whenever they are real modules; otherwise keep
        # the mapping untouched so duck-typed models still work.
        if isinstance(models, nn.ModuleDict):
            self.model = models
        elif all(isinstance(m, nn.Module) for m in models.values()):
            self.model = nn.ModuleDict(dict(models))
        else:
            self.model = models
        self.mrstftloss = MultiResolutionSTFTLoss(
            n_bins=num_bins, sample_rate=sample_rate
        )
        self.l1loss = nn.L1Loss()
        self.metrics = nn.ModuleDict(
            {
                "SISDR": SISDRLoss(),
                "STFT": MultiResolutionSTFTLoss(),
                "FAD": FADLoss(sample_rate=sample_rate),
            }
        )

    def forward(self, batch):
        """Run each example through the models for its labelled effects.

        Args:
            batch: Tuple ``(x, y, _, rem_fx_labels)``. ``x`` is the input
                audio, ``y`` the target, and ``rem_fx_labels`` holds one label
                vector per example in which entries equal to ``1.0`` mark the
                effects to remove.

        Returns:
            Tuple ``(loss, output)`` where ``output`` is the stacked processed
            audio and ``loss`` is the multi-resolution STFT loss plus 100
            times the L1 loss against the target.
        """
        x, y, _, rem_fx_labels = batch
        # Translate every label vector into the list of effect classes it marks
        effect_chains = [
            [ALL_EFFECTS[i] for i, flag in enumerate(label) if flag == 1.0]
            for label in rem_fx_labels
        ]
        output = []
        with torch.no_grad():
            for elem, effect_chain in zip(x, effect_chains):
                elem = elem.unsqueeze(0)  # Add batch dim
                for effect in effect_chain:
                    # Get correct model based on effect name. This is a bit
                    # hacky. Then sample the model.
                    elem = self.model[effect.__name__].model.sample(elem)
                output.append(elem.squeeze(0))
        output = torch.stack(output)

        # Crop the target before computing the loss so that a model returning
        # fewer samples than it was given does not cause a shape mismatch.
        if output.shape[-1] < y.shape[-1]:
            y = causal_crop(y, output.shape[-1])
        loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
        return loss, output

    def test_step(self, batch, batch_idx):
        """Log the test loss and every metric for one batch.

        Args:
            batch: Tuple ``(x, y, _, _)`` with x, y = (B, C, T), (B, C, T).
            batch_idx: Index of the batch (unused).
        """
        _, y, _, _ = batch

        loss, output = self.forward(batch)
        # Crop target to match output
        if output.shape[-1] < y.shape[-1]:
            y = causal_crop(y, output.shape[-1])
        self.log("test_loss", loss)
        # Metric logging
        with torch.no_grad():
            for metric in self.metrics:
                # SISDR returns negative values, so negate them
                negate = -1 if metric == "SISDR" else 1
                self.log(
                    f"test_{metric}",
                    negate * self.metrics[metric](output, y),
                    on_step=False,
                    on_epoch=True,
                    logger=True,
                    prog_bar=True,
                    sync_dist=True,
                )

    def sample(self, batch):
        """Return only the processed audio produced by ``forward``."""
        return self.forward(batch)[1]
|
85 |
+
|
86 |
|
87 |
class RemFX(pl.LightningModule):
|
88 |
def __init__(
|
scripts/chain_inference.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import hydra
|
3 |
+
from omegaconf import DictConfig
|
4 |
+
import remfx.utils as utils
|
5 |
+
from pytorch_lightning.utilities.model_summary import ModelSummary
|
6 |
+
import torch
|
7 |
+
from remfx.models import RemFXChainInference
|
8 |
+
|
9 |
+
# Module-level logger obtained from the project's logging helper
log = utils.get_logger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
@hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
def main(cfg: DictConfig):
    """Run the test loop for chained effect removal.

    One model is instantiated from ``cfg.model`` for every entry of
    ``cfg.ckpts`` and loaded from that entry's checkpoint. The models are
    wrapped in ``RemFXChainInference`` and evaluated with ``trainer.test``.

    Args:
        cfg: Hydra config. Reads ``seed``, ``datamodule``, ``model``,
            ``ckpts``, ``logger``, ``trainer``, ``sample_rate``, ``num_bins``
            and, when present, ``callbacks``.

    Raises:
        ValueError: If ``cfg.ckpts`` is empty.
    """
    # Apply seed for reproducibility. Compare against None because 0 is a
    # valid seed and must not be skipped.
    if cfg.seed is not None:
        pl.seed_everything(cfg.seed)
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
    datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
    log.info(f"Instantiating model <{cfg.model._target_}>.")
    if not cfg.ckpts:
        # Fail early with a clear message instead of a NameError further down
        raise ValueError("cfg.ckpts must map at least one effect to a checkpoint")
    models = {}
    for effect, ckpt_path in cfg.ckpts.items():
        model = hydra.utils.instantiate(cfg.model, _convert_="partial")
        checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
        model.load_state_dict(checkpoint["state_dict"])
        # These models are only used for inference
        model.eval()
        models[effect] = model

    callbacks = []
    if "callbacks" in cfg:
        for _, cb_conf in cfg["callbacks"].items():
            if "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>.")
                callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))

    logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
    trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )
    log.info("Logging hyperparameters!")
    # NOTE: `model` is the last model loaded in the loop above, so the
    # hyperparameter log and the summary describe that single model only.
    utils.log_hyperparameters(
        config=cfg,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )
    summary = ModelSummary(model)
    print(summary)

    inference_model = RemFXChainInference(
        models, sample_rate=cfg.sample_rate, num_bins=cfg.num_bins
    )
    trainer.test(model=inference_model, datamodule=datamodule)
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
main()
|