Merge pull request #40 from mhrice/classifier-inference
Files changed:
- README.md +4 -0
- cfg/exp/chain_inference.yaml +3 -1
- cfg/exp/chain_inference_aug.yaml +3 -1
- cfg/exp/chain_inference_aug_classifier.yaml +87 -0
- cfg/exp/chain_inference_custom.yaml +3 -1
- remfx/callbacks.py +0 -3
- remfx/classifier.py +5 -3
- remfx/datasets.py +50 -44
- remfx/models.py +82 -34
- remfx/utils.py +1 -1
- scripts/chain_inference.py +14 -1
README.md
CHANGED
@@ -47,6 +47,9 @@ see `cfg/exp/default.yaml` for an example.
 - `reverb`
 - `delay`
 
+## Chain Inference
+`python scripts/chain_inference.py +exp=chain_inference`
+
 ## Run inference on directory
 Assumes directory is structured as
 - root
@@ -64,6 +67,7 @@ Change root path in `shell_vars.sh` and `source shell_vars.sh`
 `python scripts/chain_inference.py +exp=chain_inference_custom`
 
 
+
 ## Misc.
 By default, files are rendered to `input_dir / processed / {string_of_effects} / {train|val|test}`.
 
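The experiment file added below follows the same invocation pattern, so presumably it is run as `python scripts/chain_inference.py +exp=chain_inference_aug_classifier` (a hedged inference from the README's `+exp=` convention; the README itself does not list this command).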
cfg/exp/chain_inference.yaml
CHANGED
@@ -63,4 +63,6 @@ inference_effects_ordering:
   - "RandomPedalboardReverb"
   - "RandomPedalboardChorus"
   - "RandomPedalboardDelay"
-num_bins: 1025
+num_bins: 1025
+inference_effects_shuffle: False
+inference_use_all_effect_models: False
cfg/exp/chain_inference_aug.yaml
CHANGED
@@ -63,4 +63,6 @@ inference_effects_ordering:
   - "RandomPedalboardReverb"
   - "RandomPedalboardChorus"
   - "RandomPedalboardDelay"
-num_bins: 1025
+num_bins: 1025
+inference_effects_shuffle: False
+inference_use_all_effect_models: False
cfg/exp/chain_inference_aug_classifier.yaml
ADDED
@@ -0,0 +1,87 @@
+# @package _global_
+defaults:
+  - override /model: demucs
+  - override /effects: all
+seed: 12345
+sample_rate: 48000
+chunk_size: 262144 # 5.5s
+logs_dir: "./logs"
+render_root: "/scratch/EffectSet"
+accelerator: "gpu"
+log_audio: True
+# Effects
+num_kept_effects: [0,0] # [min, max]
+num_removed_effects: [0,5] # [min, max]
+shuffle_kept_effects: True
+shuffle_removed_effects: True
+num_classes: 5
+effects_to_keep:
+effects_to_remove:
+  - distortion
+  - compressor
+  - reverb
+  - chorus
+  - delay
+datamodule:
+  batch_size: 16
+  num_workers: 8
+
+dcunet:
+  _target_: remfx.models.RemFX
+  lr: 1e-4
+  lr_beta1: 0.95
+  lr_beta2: 0.999
+  lr_eps: 1e-6
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  network:
+    _target_: remfx.models.DCUNetModel
+    architecture: "Large-DCUNet-20"
+    stft_kernel_size: 512
+    fix_length_mode: "pad"
+    sample_rate: ${sample_rate}
+    num_bins: 1025
+
+classifier:
+  _target_: remfx.models.FXClassifier
+  lr: 3e-4
+  lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
+  mixup: False
+  network:
+    _target_: remfx.classifier.Cnn14
+    num_classes: ${num_classes}
+    n_fft: 2048
+    hop_length: 512
+    n_mels: 128
+    sample_rate: ${sample_rate}
+    model_sample_rate: ${sample_rate}
+    specaugment: False
+classifier_ckpt: "ckpts/classifier.ckpt"
+
+ckpts:
+  RandomPedalboardDistortion:
+    model: ${model}
+    ckpt_path: "ckpts/demucs_distortion_aug.ckpt"
+  RandomPedalboardCompressor:
+    model: ${model}
+    ckpt_path: "ckpts/demucs_compressor_aug.ckpt"
+  RandomPedalboardReverb:
+    model: ${dcunet}
+    ckpt_path: "ckpts/dcunet_reverb_aug.ckpt"
+  RandomPedalboardChorus:
+    model: ${dcunet}
+    ckpt_path: "ckpts/dcunet_chorus_aug.ckpt"
+  RandomPedalboardDelay:
+    model: ${dcunet}
+    ckpt_path: "ckpts/dcunet_delay_aug.ckpt"
+
+inference_effects_ordering:
+  - "RandomPedalboardDistortion"
+  - "RandomPedalboardCompressor"
+  - "RandomPedalboardReverb"
+  - "RandomPedalboardChorus"
+  - "RandomPedalboardDelay"
+num_bins: 1025
+inference_effects_shuffle: False
+inference_use_all_effect_models: False
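This file wires one checkpoint per effect class to either the Demucs model config or the inline `dcunet` config, plus a Cnn14 classifier checkpoint. A hedged sketch of composing it outside the script; the `config_path` and base config name ("config") are assumptions inferred from the `+exp=` override pattern, not shown in this diff:

```python
# Compose the experiment config with Hydra and inspect the per-effect
# checkpoint table it defines.
from hydra import compose, initialize

with initialize(version_base=None, config_path="cfg"):
    cfg = compose(
        config_name="config",
        overrides=["+exp=chain_inference_aug_classifier"],
    )

# Each entry under cfg.ckpts pairs a model config with its checkpoint path.
for effect, entry in cfg.ckpts.items():
    print(effect, entry.ckpt_path)
```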
cfg/exp/chain_inference_custom.yaml
CHANGED
@@ -68,4 +68,6 @@ inference_effects_ordering:
   - "RandomPedalboardReverb"
   - "RandomPedalboardChorus"
   - "RandomPedalboardDelay"
-num_bins: 1025
+num_bins: 1025
+inference_effects_shuffle: False
+inference_use_all_effect_models: False
remfx/callbacks.py
CHANGED
@@ -64,9 +64,6 @@ class AudioCallback(Callback):
             ]
             for i, label in enumerate(effects_present_name):
                 self.log(f"{'_'.join(label)}", 0.0)
-                # self.log(f"{effects}_{i}", label)
-                # trainer.logger.experiment.log(
-                #     {f"effects_{i}": f"{'_'.join(label)}"}
         else:
             y = pl_module.model.sample(x)
         # Concat samples together for easier viewing in dashboard
remfx/classifier.py
CHANGED
@@ -1,9 +1,11 @@
 import torch
 import torchaudio
 import torch.nn as nn
-
-import hearbaseline
-
+
+# import hearbaseline
+
+# import hearbaseline.vggish
+# import hearbaseline.wav2vec2
 
 import wav2clip_hear
 import panns_hear
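The `hearbaseline` imports are commented out rather than deleted. If the intent is to keep the dependency optional (an assumption on my part, not stated in the commit), a guarded import would keep the module importable without it:

```python
# Hypothetical alternative to commenting the imports out: degrade gracefully
# when the optional `hearbaseline` package is not installed.
try:
    import hearbaseline
    import hearbaseline.vggish
    import hearbaseline.wav2vec2
except ImportError:
    # Dependency missing; classifiers that need it remain unavailable.
    hearbaseline = None
```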
remfx/datasets.py
CHANGED
@@ -13,10 +13,10 @@ from typing import Any, List, Dict
 from torch.utils.data import Dataset, DataLoader
 from remfx.utils import select_random_chunk
 import multiprocessing
+from auraloss.freq import MultiResolutionSTFTLoss
 
 
-
-
+STFT_THRESH = 1e-3
 ALL_EFFECTS = effect_lib.Pedalboard_Effects
 # print(ALL_EFFECTS)
 
@@ -404,6 +404,7 @@ class EffectDataset(Dataset):
         self.effects_to_keep = [] if effects_to_keep is None else effects_to_keep
         self.effects_to_remove = [] if effects_to_remove is None else effects_to_remove
         self.normalize = effect_lib.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+        self.mrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate)
         self.effects = effect_modules
         self.shuffle_kept_effects = shuffle_kept_effects
         self.shuffle_removed_effects = shuffle_removed_effects
@@ -471,7 +472,6 @@ class EffectDataset(Dataset):
             chunk = select_random_chunk(
                 random_file_choice, self.chunk_size, self.sample_rate
            )
-
            # Sum to mono
            if chunk.shape[0] > 1:
                chunk = chunk.sum(0, keepdim=True)
@@ -568,46 +568,52 @@ class EffectDataset(Dataset):
         # Index in effect settings
         effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
         effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
-        # Apply
-        dry_labels = []
-        for effect in effects_to_apply:
-            # Normalize in-between effects
-            dry = self.normalize(effect(dry))
-            dry_labels.append(ALL_EFFECTS.index(type(effect)))
-
-        # Apply effects_to_remove
-        # Shuffle effects if specified
-        if self.shuffle_removed_effects:
-            effect_indices = torch.randperm(len(self.effects_to_remove))
-        else:
-            effect_indices = torch.arange(len(self.effects_to_remove))
-        wet = torch.clone(dry)
-        r1 = self.num_removed_effects[0]
-        r2 = self.num_removed_effects[1]
-        num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
-        effect_indices = effect_indices[:num_removed_effects]
-        # Index in effect settings
-        effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
-        effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
-        # Apply
-        wet_labels = []
-        for effect in effects_to_apply:
-            # Normalize in-between effects
-            wet = self.normalize(effect(wet))
-            wet_labels.append(ALL_EFFECTS.index(type(effect)))
-
-        wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
-        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
-
-        for label_idx in wet_labels:
-            wet_labels_tensor[label_idx] = 1.0
-
-        for label_idx in dry_labels:
-            dry_labels_tensor[label_idx] = 1.0
-
-        # Normalize
-        normalized_dry = self.normalize(dry)
-        normalized_wet = self.normalize(wet)
+        # stft comparison
+        stft = 0
+        while stft < STFT_THRESH:
+            # Apply
+            dry_labels = []
+            for effect in effects_to_apply:
+                # Normalize in-between effects
+                dry = self.normalize(effect(dry))
+                dry_labels.append(ALL_EFFECTS.index(type(effect)))
+
+            # Apply effects_to_remove
+            # Shuffle effects if specified
+            if self.shuffle_removed_effects:
+                effect_indices = torch.randperm(len(self.effects_to_remove))
+            else:
+                effect_indices = torch.arange(len(self.effects_to_remove))
+            wet = torch.clone(dry)
+            r1 = self.num_removed_effects[0]
+            r2 = self.num_removed_effects[1]
+            num_removed_effects = torch.round((r1 - r2) * torch.rand(1) + r2).int()
+            effect_indices = effect_indices[:num_removed_effects]
+            # Index in effect settings
+            effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
+            effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
+            # Apply
+            wet_labels = []
+            for effect in effects_to_apply:
+                # Normalize in-between effects
+                wet = self.normalize(effect(wet))
+                wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+            wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+            dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+            for label_idx in wet_labels:
+                wet_labels_tensor[label_idx] = 1.0
+
+            for label_idx in dry_labels:
+                dry_labels_tensor[label_idx] = 1.0
+
+            # Normalize
+            normalized_dry = self.normalize(dry)
+            normalized_wet = self.normalize(wet)
+
+            # Check STFT, pick different effects if necessary
+            stft = self.mrstft(normalized_wet, normalized_dry)
         return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
 
 
@@ -692,7 +698,7 @@ class EffectDatamodule(pl.LightningDataModule):
     def test_dataloader(self) -> DataLoader:
         return DataLoader(
             dataset=self.test_dataset,
-            batch_size=
+            batch_size=1,  # Use small, consistent batch size for testing
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             shuffle=False,
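The new `while stft < STFT_THRESH` loop keeps re-drawing the removed-effect chain until the wet render differs measurably from the dry one, using a multi-resolution STFT distance as the test, so a "remove these effects" training pair is never a near-no-op. A self-contained sketch of that accept/reject pattern, with a stand-in `render()` in place of the dataset's effect modules and loudness normalization:

```python
import torch
from auraloss.freq import MultiResolutionSTFTLoss

STFT_THRESH = 1e-3  # same threshold the commit adds
mrstft = MultiResolutionSTFTLoss(sample_rate=48000)

def render(dry: torch.Tensor) -> torch.Tensor:
    # Stand-in for applying a randomly chosen effect chain; the real code
    # draws effects from self.effects and normalizes loudness in between.
    return dry + 0.01 * torch.randn_like(dry)

dry = torch.randn(1, 1, 48000)  # (batch, channels, samples)
stft = 0.0
while stft < STFT_THRESH:
    # Re-draw the wet signal until it is audibly different from the dry one.
    wet = render(dry)
    stft = mrstft(wet, dry)
```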
remfx/models.py
CHANGED
@@ -16,12 +16,22 @@ from remfx.callbacks import log_wandb_audio_batch
 from einops import rearrange
 from remfx import effects
 import asteroid
+import random
 
 ALL_EFFECTS = effects.Pedalboard_Effects
 
 
 class RemFXChainInference(pl.LightningModule):
-    def __init__(self, models, sample_rate, num_bins, effect_order):
+    def __init__(
+        self,
+        models,
+        sample_rate,
+        num_bins,
+        effect_order,
+        classifier=None,
+        shuffle_effect_order=False,
+        use_all_effect_models=False,
+    ):
         super().__init__()
         self.model = models
         self.mrstftloss = MultiResolutionSTFTLoss(
@@ -36,6 +46,10 @@ class RemFXChainInference(pl.LightningModule):
         )
         self.sample_rate = sample_rate
         self.effect_order = effect_order
+        self.classifier = classifier
+        self.shuffle_effect_order = shuffle_effect_order
+        self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
+        self.use_all_effect_models = use_all_effect_models
 
     def forward(self, batch, batch_idx, order=None):
         x, y, _, rem_fx_labels = batch
@@ -44,28 +58,46 @@ class RemFXChainInference(pl.LightningModule):
             effects_order = order
         else:
             effects_order = self.effect_order
-        effects_present = [
-            [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
-            for effect_label in rem_fx_labels
-        ]
+
+        # Use classifier labels
+        if self.classifier:
+            threshold = 0.5
+            with torch.no_grad():
+                labels = torch.sigmoid(self.classifier(x))
+                rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
+        if self.use_all_effect_models:
+            effects_present = [
+                [ALL_EFFECTS[i] for i, effect in enumerate(effect_label)]
+                for effect_label in rem_fx_labels
+            ]
+        else:
+            effects_present = [
+                [
+                    ALL_EFFECTS[i]
+                    for i, effect in enumerate(effect_label)
+                    if effect == 1.0
+                ]
+                for effect_label in rem_fx_labels
+            ]
+
         output = []
-        input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
-        target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
-
-        log_wandb_audio_batch(
-            logger=self.logger,
-            id="input_effected_audio",
-            samples=input_samples.cpu(),
-            sampling_rate=self.sample_rate,
-            caption="Input Data",
-        )
-        log_wandb_audio_batch(
-            logger=self.logger,
-            id="target_audio",
-            samples=target_samples.cpu(),
-            sampling_rate=self.sample_rate,
-            caption="Target Data",
-        )
+        # input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
+        # target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
+
+        # log_wandb_audio_batch(
+        #     logger=self.logger,
+        #     id="input_effected_audio",
+        #     samples=input_samples.cpu(),
+        #     sampling_rate=self.sample_rate,
+        #     caption="Input Data",
+        # )
+        # log_wandb_audio_batch(
+        #     logger=self.logger,
+        #     id="target_audio",
+        #     samples=target_samples.cpu(),
+        #     sampling_rate=self.sample_rate,
+        #     caption="Target Data",
+        # )
         with torch.no_grad():
             for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
                 elem = elem.unsqueeze(0)  # Add batch dim
@@ -101,22 +133,22 @@ class RemFXChainInference(pl.LightningModule):
                 # )
                 output.append(elem.squeeze(0))
         output = torch.stack(output)
-
-        output_samples = rearrange(output, "b c t -> c (b t)").unsqueeze(0)
-        log_wandb_audio_batch(
-            logger=self.logger,
-            id="output_audio",
-            samples=output_samples.cpu(),
-            sampling_rate=self.sample_rate,
-            caption="Output Data",
-        )
+
+        # log_wandb_audio_batch(
+        #     logger=self.logger,
+        #     id="output_audio",
+        #     samples=output_samples.cpu(),
+        #     sampling_rate=self.sample_rate,
+        #     caption="Output Data",
+        # )
         loss = self.mrstftloss(output, y) + self.l1loss(output, y) * 100
         return loss, output
 
     def test_step(self, batch, batch_idx):
         x, y, _, _ = batch  # x, y = (B, C, T), (B, C, T)
-
-
+        if self.shuffle_effect_order:
+            # Random order
+            random.shuffle(self.effect_order)
         loss, output = self.forward(batch, batch_idx, order=self.effect_order)
         # Crop target to match output
         if output.shape[-1] < y.shape[-1]:
@@ -148,8 +180,16 @@ class RemFXChainInference(pl.LightningModule):
                 prog_bar=True,
                 sync_dist=True,
             )
+            # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
+            # print(f"test_{metric}", negate * self.metrics[metric](output, y))
+            self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
+        self.output_str += "\n"
         return loss
 
+    def on_test_end(self) -> None:
+        with open("output.csv", "w") as f:
+            f.write(self.output_str)
+
     def sample(self, batch):
         return self.forward(batch, 0)[1]
 
@@ -181,6 +221,7 @@ class RemFX(pl.LightningModule):
         )
         # Log first batch metrics input vs output only once
         self.log_train_audio = True
+        self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
 
     @property
     def device(self):
@@ -257,9 +298,16 @@ class RemFX(pl.LightningModule):
                 prog_bar=True,
                 sync_dist=True,
             )
-
+            # print(f"Input_{metric}", negate * self.metrics[metric](x, y))
+            # print(f"test_{metric}", negate * self.metrics[metric](output, y))
+            self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
+        self.output_str += "\n"
         return loss
 
+    def on_test_end(self) -> None:
+        with open("output.csv", "w") as f:
+            f.write(self.output_str)
+
 
 class OpenUnmixModel(nn.Module):
     def __init__(
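The classifier path replaces ground-truth removal labels with predictions: sigmoid outputs thresholded at 0.5 become a multi-hot vector, which is then mapped to the effect models to run (or to all of them when `use_all_effect_models` is set). A minimal sketch of that decision step, with stand-in effect names and random logits in place of the real Cnn14 output:

```python
import torch

ALL_EFFECTS = ["distortion", "compressor", "reverb", "chorus", "delay"]  # stand-ins
logits = torch.randn(2, len(ALL_EFFECTS))  # pretend classifier output, batch of 2

threshold = 0.5
labels = torch.sigmoid(logits)
rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)

# Same comprehension the commit adds: one list of detected effects per item.
effects_present = [
    [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
    for effect_label in rem_fx_labels
]
print(effects_present)
```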
remfx/utils.py
CHANGED
@@ -159,7 +159,7 @@ def select_random_chunk(
     random_start = torch.randint(0, max_len, (1,)).item()
     chunk = audio[:, random_start : random_start + new_chunk_size]
     # Skip if energy too low
-    if torch.mean(torch.abs(chunk)) < 1e-
+    if torch.mean(torch.abs(chunk)) < 1e-4:
         return None
     resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
     return resampled_chunk
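The repaired threshold makes the silence gate concrete: chunks whose mean absolute amplitude falls below 1e-4 are rejected and the caller draws again. The same check as a standalone helper:

```python
import torch

def is_audible(chunk: torch.Tensor, thresh: float = 1e-4) -> bool:
    # Mirrors the repaired condition: mean |x| below the threshold means the
    # randomly selected chunk is (near-)silent and should be skipped.
    return torch.mean(torch.abs(chunk)).item() >= thresh

print(is_audible(torch.zeros(1, 48000)))        # False: silent
print(is_audible(0.1 * torch.randn(1, 48000)))  # True: has energy
```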
scripts/chain_inference.py
CHANGED
@@ -15,7 +15,7 @@ def main(cfg: DictConfig):
     pl.seed_everything(cfg.seed)
     log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
     datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
-    log.info(
+    log.info("Instantiating Chain Inference Models")
     models = {}
     for effect in cfg.ckpts:
         model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
@@ -26,6 +26,16 @@ def main(cfg: DictConfig):
         model.to(device)
         models[effect] = model
 
+    classifier = None
+    if "classifier" in cfg:
+        log.info(f"Instantiating classifier <{cfg.classifier._target_}>.")
+        classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
+        ckpt_path = cfg.classifier_ckpt
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
+        classifier.load_state_dict(state_dict)
+        classifier.to(device)
+
     callbacks = []
     if "callbacks" in cfg:
         for _, cb_conf in cfg["callbacks"].items():
@@ -54,6 +64,9 @@ def main(cfg: DictConfig):
         sample_rate=cfg.sample_rate,
         num_bins=cfg.num_bins,
         effect_order=cfg.inference_effects_ordering,
+        classifier=classifier,
+        shuffle_effect_order=cfg.inference_effects_shuffle,
+        use_all_effect_models=cfg.inference_use_all_effect_models,
     )
     trainer.test(model=inference_model, datamodule=datamodule)
 
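Because the script only builds the classifier when the composed config defines one (`"classifier" in cfg`), the older chain-inference experiments keep working unchanged. Each test run now also dumps per-item metrics: the `on_test_end` hook added in remfx/models.py writes the accumulated IN/OUT SI-SDR and STFT values to `output.csv` in the working directory. A quick way to read it back; note the implementation leaves a trailing comma on data rows, so `csv` yields an empty final field:

```python
import csv

# Parse the metrics file written by on_test_end.
with open("output.csv") as f:
    rows = list(csv.reader(f))

print(rows[0])  # ['IN_SISDR', 'OUT_SISDR', 'IN_STFT', 'OUT_STFT']
for row in rows[1:]:
    if row:  # skip the blank trailing line, if any
        in_sisdr, out_sisdr, in_stft, out_stft = (float(v) for v in row[:4])
        print(in_sisdr, out_sisdr, in_stft, out_stft)
```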