Commit: 133e1dc
Parent(s): 7173f20
Add shuffling effect order, all effects present for chain_inference to cfg
cfg/exp/chain_inference.yaml
CHANGED
@@ -63,4 +63,6 @@ inference_effects_ordering:
   - "RandomPedalboardReverb"
   - "RandomPedalboardChorus"
   - "RandomPedalboardDelay"
-num_bins: 1025
+num_bins: 1025
+inference_effects_shuffle: False
+inference_use_all_effect_models: False
cfg/exp/chain_inference_aug.yaml
CHANGED
@@ -63,4 +63,6 @@ inference_effects_ordering:
   - "RandomPedalboardReverb"
   - "RandomPedalboardChorus"
   - "RandomPedalboardDelay"
-num_bins: 1025
+num_bins: 1025
+inference_effects_shuffle: False
+inference_use_all_effect_models: False
cfg/exp/chain_inference_aug_classifier.yaml
CHANGED
@@ -82,4 +82,6 @@ inference_effects_ordering:
   - "RandomPedalboardReverb"
   - "RandomPedalboardChorus"
   - "RandomPedalboardDelay"
-num_bins: 1025
+num_bins: 1025
+inference_effects_shuffle: False
+inference_use_all_effect_models: False
cfg/exp/chain_inference_custom.yaml
CHANGED
@@ -68,4 +68,6 @@ inference_effects_ordering:
   - "RandomPedalboardReverb"
   - "RandomPedalboardChorus"
   - "RandomPedalboardDelay"
-num_bins: 1025
+num_bins: 1025
+inference_effects_shuffle: False
+inference_use_all_effect_models: False
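All four experiment configs gain the same two keys, both defaulting to False. As a minimal standalone sketch (not from the repo) of how scripts/chain_inference.py would see them once Hydra composes the experiment config into the DictConfig passed to main, assuming the keys sit at the top level of the composed config:

from omegaconf import OmegaConf

# Stand-in for the composed Hydra config; only the keys relevant here are shown.
cfg = OmegaConf.create(
    {
        "num_bins": 1025,
        "inference_effects_shuffle": False,
        "inference_use_all_effect_models": False,
    }
)

print(cfg.inference_effects_shuffle)        # False
print(cfg.inference_use_all_effect_models)  # False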
remfx/models.py
CHANGED
@@ -16,12 +16,22 @@ from remfx.callbacks import log_wandb_audio_batch
 from einops import rearrange
 from remfx import effects
 import asteroid
+import random
 
 ALL_EFFECTS = effects.Pedalboard_Effects
 
 
 class RemFXChainInference(pl.LightningModule):
-    def __init__(self, models, sample_rate, num_bins, effect_order, classifier=None):
+    def __init__(
+        self,
+        models,
+        sample_rate,
+        num_bins,
+        effect_order,
+        classifier=None,
+        shuffle_effect_order=False,
+        use_all_effect_models=False,
+    ):
         super().__init__()
         self.model = models
         self.mrstftloss = MultiResolutionSTFTLoss(
@@ -37,7 +47,9 @@ class RemFXChainInference(pl.LightningModule):
         self.sample_rate = sample_rate
         self.effect_order = effect_order
         self.classifier = classifier
+        self.shuffle_effect_order = shuffle_effect_order
         self.output_str = "IN_SISDR,OUT_SISDR,IN_STFT,OUT_STFT\n"
+        self.use_all_effect_models = use_all_effect_models
 
     def forward(self, batch, batch_idx, order=None):
         x, y, _, rem_fx_labels = batch
@@ -46,36 +58,45 @@ class RemFXChainInference(pl.LightningModule):
             effects_order = order
         else:
             effects_order = self.effect_order
-
+
         # Use classifier labels
         if self.classifier:
             threshold = 0.5
             with torch.no_grad():
                 labels = torch.sigmoid(self.classifier(x))
                 rem_fx_labels = torch.where(labels > threshold, 1.0, 0.0)
-        effects_present = [
-            [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
-            for effect_label in rem_fx_labels
-        ]
-
+        if self.use_all_effect_models:
+            effects_present = [
+                [ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect]
+                for effect_label in rem_fx_labels
+            ]
+        else:
+            effects_present = [
+                [
+                    ALL_EFFECTS[i]
+                    for i, effect in enumerate(effect_label)
+                    if effect == 1.0
+                ]
+                for effect_label in rem_fx_labels
+            ]
         output = []
-        input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
-        target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
-
-        log_wandb_audio_batch(
-            logger=self.logger,
-            id="input_effected_audio",
-            samples=input_samples.cpu(),
-            sampling_rate=self.sample_rate,
-            caption="Input Data",
-        )
-        log_wandb_audio_batch(
-            logger=self.logger,
-            id="target_audio",
-            samples=target_samples.cpu(),
-            sampling_rate=self.sample_rate,
-            caption="Target Data",
-        )
+        # input_samples = rearrange(x, "b c t -> c (b t)").unsqueeze(0)
+        # target_samples = rearrange(y, "b c t -> c (b t)").unsqueeze(0)
+
+        # log_wandb_audio_batch(
+        #     logger=self.logger,
+        #     id="input_effected_audio",
+        #     samples=input_samples.cpu(),
+        #     sampling_rate=self.sample_rate,
+        #     caption="Input Data",
+        # )
+        # log_wandb_audio_batch(
+        #     logger=self.logger,
+        #     id="target_audio",
+        #     samples=target_samples.cpu(),
+        #     sampling_rate=self.sample_rate,
+        #     caption="Target Data",
+        # )
         with torch.no_grad():
             for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
                 elem = elem.unsqueeze(0)  # Add batch dim
@@ -111,7 +132,6 @@ class RemFXChainInference(pl.LightningModule):
                 # )
                 output.append(elem.squeeze(0))
         output = torch.stack(output)
-        output_samples = rearrange(output, "b c t -> c (b t)").unsqueeze(0)
 
         # log_wandb_audio_batch(
         #     logger=self.logger,
@@ -125,8 +145,9 @@ class RemFXChainInference(pl.LightningModule):
 
     def test_step(self, batch, batch_idx):
         x, y, _, _ = batch  # x, y = (B, C, T), (B, C, T)
-        # Random order
-        # random.shuffle(self.effect_order)
+        if self.shuffle_effect_order:
+            # Random order
+            random.shuffle(self.effect_order)
         loss, output = self.forward(batch, batch_idx, order=self.effect_order)
         # Crop target to match output
         if output.shape[-1] < y.shape[-1]:
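One behavioural detail in the diff above: random.shuffle permutes self.effect_order in place, so with shuffling enabled each test_step starts from the order left by the previous step rather than from the configured order. A minimal standalone sketch of that step (effect names taken from the config diffs above; the full ordering list is longer):

import random

# Tail of the effect ordering as it appears in the configs.
effect_order = [
    "RandomPedalboardReverb",
    "RandomPedalboardChorus",
    "RandomPedalboardDelay",
]

random.shuffle(effect_order)  # shuffles the list in place, uniformly at random
print(effect_order)           # e.g. ['RandomPedalboardDelay', 'RandomPedalboardReverb', ...]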
scripts/chain_inference.py
CHANGED
@@ -65,6 +65,8 @@ def main(cfg: DictConfig):
         num_bins=cfg.num_bins,
         effect_order=cfg.inference_effects_ordering,
         classifier=classifier,
+        shuffle_effect_order=cfg.inference_effects_shuffle,
+        use_all_effect_models=cfg.inference_use_all_effect_models,
     )
     trainer.test(model=inference_model, datamodule=datamodule)
 
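With Hydra's standard key=value override syntax, the new behaviour can presumably be toggled per run without editing the configs, e.g. by appending inference_effects_shuffle=True inference_use_all_effect_models=True to the usual chain-inference command. Both keys default to False, so existing runs keep the fixed effect order and the thresholded effect selection.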