Spaces:
Runtime error
Runtime error
Commit
·
e0a5f6f
1
Parent(s):
1b89540
Add vocalset and FAD
Browse files- config.yaml +10 -4
- config_guitarset.yaml +52 -0
- exp/umx.yaml +8 -1
- remfx/datasets.py +109 -0
- remfx/models.py +33 -1
- shell_vars.sh +1 -1
config.yaml
CHANGED
@@ -19,13 +19,19 @@ callbacks:
|
|
19 |
filename: '{epoch:02d}-{valid_loss:.3f}'
|
20 |
|
21 |
datamodule:
|
22 |
-
_target_: remfx.datasets.
|
23 |
-
|
24 |
-
_target_: remfx.datasets.
|
25 |
sample_rate: ${sample_rate}
|
26 |
root: ${oc.env:DATASET_ROOT}
|
27 |
chunk_size_in_sec: 6
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
batch_size: 16
|
30 |
num_workers: 8
|
31 |
pin_memory: True
|
|
|
19 |
filename: '{epoch:02d}-{valid_loss:.3f}'
|
20 |
|
21 |
datamodule:
|
22 |
+
_target_: remfx.datasets.VocalSetDatamodule
|
23 |
+
train_dataset:
|
24 |
+
_target_: remfx.datasets.VocalSet
|
25 |
sample_rate: ${sample_rate}
|
26 |
root: ${oc.env:DATASET_ROOT}
|
27 |
chunk_size_in_sec: 6
|
28 |
+
mode: "train"
|
29 |
+
val_dataset:
|
30 |
+
_target_: remfx.datasets.VocalSet
|
31 |
+
sample_rate: ${sample_rate}
|
32 |
+
root: ${oc.env:DATASET_ROOT}
|
33 |
+
chunk_size_in_sec: 6
|
34 |
+
mode: "val"
|
35 |
batch_size: 16
|
36 |
num_workers: 8
|
37 |
pin_memory: True
|
config_guitarset.yaml
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- _self_
|
3 |
+
- exp: null
|
4 |
+
seed: 12345
|
5 |
+
train: True
|
6 |
+
sample_rate: 48000
|
7 |
+
logs_dir: "./logs"
|
8 |
+
log_every_n_steps: 1000
|
9 |
+
|
10 |
+
callbacks:
|
11 |
+
model_checkpoint:
|
12 |
+
_target_: pytorch_lightning.callbacks.ModelCheckpoint
|
13 |
+
monitor: "valid_loss" # name of the logged metric which determines when model is improving
|
14 |
+
save_top_k: 1 # save k best models (determined by above metric)
|
15 |
+
save_last: True # additionaly always save model from last epoch
|
16 |
+
mode: "min" # can be "max" or "min"
|
17 |
+
verbose: False
|
18 |
+
dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
|
19 |
+
filename: '{epoch:02d}-{valid_loss:.3f}'
|
20 |
+
|
21 |
+
datamodule:
|
22 |
+
_target_: remfx.datasets.Datamodule
|
23 |
+
dataset:
|
24 |
+
_target_: remfx.datasets.GuitarSet
|
25 |
+
sample_rate: ${sample_rate}
|
26 |
+
root: ${oc.env:DATASET_ROOT}
|
27 |
+
chunk_size_in_sec: 6
|
28 |
+
val_split: 0.2
|
29 |
+
batch_size: 16
|
30 |
+
num_workers: 8
|
31 |
+
pin_memory: True
|
32 |
+
persistent_workers: True
|
33 |
+
|
34 |
+
logger:
|
35 |
+
_target_: pytorch_lightning.loggers.WandbLogger
|
36 |
+
project: ${oc.env:WANDB_PROJECT}
|
37 |
+
entity: ${oc.env:WANDB_ENTITY}
|
38 |
+
# offline: False # set True to store all logs only locally
|
39 |
+
job_type: "train"
|
40 |
+
group: ""
|
41 |
+
save_dir: "."
|
42 |
+
|
43 |
+
trainer:
|
44 |
+
_target_: pytorch_lightning.Trainer
|
45 |
+
precision: 32 # Precision used for tensors, default `32`
|
46 |
+
min_epochs: 0
|
47 |
+
max_epochs: -1
|
48 |
+
enable_model_summary: False
|
49 |
+
log_every_n_steps: 1 # Logs metrics every N batches
|
50 |
+
accumulate_grad_batches: 1
|
51 |
+
accelerator: null
|
52 |
+
devices: 1
|
exp/umx.yaml
CHANGED
@@ -15,7 +15,14 @@ model:
|
|
15 |
alpha: 0.3
|
16 |
sample_rate: ${sample_rate}
|
17 |
datamodule:
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
effect_types:
|
20 |
Distortion:
|
21 |
_target_: remfx.effects.RandomPedalboardDistortion
|
|
|
15 |
alpha: 0.3
|
16 |
sample_rate: ${sample_rate}
|
17 |
datamodule:
|
18 |
+
train_dataset:
|
19 |
+
effect_types:
|
20 |
+
Distortion:
|
21 |
+
_target_: remfx.effects.RandomPedalboardDistortion
|
22 |
+
sample_rate: ${sample_rate}
|
23 |
+
min_drive_db: -10
|
24 |
+
max_drive_db: 50
|
25 |
+
val_dataset:
|
26 |
effect_types:
|
27 |
Distortion:
|
28 |
_target_: remfx.effects.RandomPedalboardDistortion
|
remfx/datasets.py
CHANGED
@@ -20,6 +20,7 @@ from pedalboard import (
|
|
20 |
|
21 |
# https://zenodo.org/record/7044411/ -> GuitarFX
|
22 |
# https://zenodo.org/record/3371780 -> GuitarSet
|
|
|
23 |
|
24 |
deterministic_effects = {
|
25 |
"Distortion": Pedalboard([Distortion()]),
|
@@ -173,6 +174,74 @@ class GuitarSet(Dataset):
|
|
173 |
return (normalized_input, normalized_target, effect_name)
|
174 |
|
175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
def create_random_chunks(
|
177 |
audio_file: str, chunk_size: int, num_chunks: int
|
178 |
) -> Tuple[List[Tuple[int, int]], int]:
|
@@ -249,3 +318,43 @@ class Datamodule(pl.LightningDataModule):
|
|
249 |
pin_memory=self.pin_memory,
|
250 |
shuffle=False,
|
251 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
# https://zenodo.org/record/7044411/ -> GuitarFX
|
22 |
# https://zenodo.org/record/3371780 -> GuitarSet
|
23 |
+
# https://zenodo.org/record/1193957 -> VocalSet
|
24 |
|
25 |
deterministic_effects = {
|
26 |
"Distortion": Pedalboard([Distortion()]),
|
|
|
174 |
return (normalized_input, normalized_target, effect_name)
|
175 |
|
176 |
|
177 |
+
class VocalSet(Dataset):
|
178 |
+
def __init__(
|
179 |
+
self,
|
180 |
+
root: str,
|
181 |
+
sample_rate: int,
|
182 |
+
chunk_size_in_sec: int = 3,
|
183 |
+
effect_types: List[torch.nn.Module] = None,
|
184 |
+
mode: str = "train",
|
185 |
+
):
|
186 |
+
super().__init__()
|
187 |
+
self.chunks = []
|
188 |
+
self.song_idx = []
|
189 |
+
self.root = Path(root)
|
190 |
+
self.chunk_size_in_sec = chunk_size_in_sec
|
191 |
+
self.sample_rate = sample_rate
|
192 |
+
self.mode = mode
|
193 |
+
|
194 |
+
mode_path = self.root / self.mode
|
195 |
+
self.files = sorted(list(mode_path.glob("./**/*.wav")))
|
196 |
+
for i, audio_file in enumerate(self.files):
|
197 |
+
chunk_starts, orig_sr = create_sequential_chunks(
|
198 |
+
audio_file, self.chunk_size_in_sec
|
199 |
+
)
|
200 |
+
self.chunks += chunk_starts
|
201 |
+
self.song_idx += [i] * len(chunk_starts)
|
202 |
+
print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
|
203 |
+
self.resampler = T.Resample(orig_sr, sample_rate)
|
204 |
+
self.effect_types = effect_types
|
205 |
+
self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
|
206 |
+
|
207 |
+
def __len__(self):
|
208 |
+
return len(self.chunks)
|
209 |
+
|
210 |
+
def __getitem__(self, idx):
|
211 |
+
# Load and effect audio
|
212 |
+
song_idx = self.song_idx[idx]
|
213 |
+
x, sr = torchaudio.load(self.files[song_idx])
|
214 |
+
chunk_start = self.chunks[idx]
|
215 |
+
chunk_size_in_samples = self.chunk_size_in_sec * sr
|
216 |
+
x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
|
217 |
+
resampled_x = self.resampler(x)
|
218 |
+
# Reset chunk size to be new sample rate
|
219 |
+
chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
|
220 |
+
# Pad to chunk_size if needed
|
221 |
+
if resampled_x.shape[-1] < chunk_size_in_samples:
|
222 |
+
resampled_x = F.pad(
|
223 |
+
resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
|
224 |
+
)
|
225 |
+
|
226 |
+
# Add random effect if train
|
227 |
+
if self.mode == "train":
|
228 |
+
random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
|
229 |
+
effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
|
230 |
+
effect = self.effect_types[effect_name]
|
231 |
+
effected_input = effect(resampled_x)
|
232 |
+
else:
|
233 |
+
# deterministic static effect for eval
|
234 |
+
effect_idx = idx % len(self.effect_types.keys())
|
235 |
+
effect_name = list(self.effect_types.keys())[effect_idx]
|
236 |
+
effect = deterministic_effects[effect_name]
|
237 |
+
effected_input = torch.from_numpy(
|
238 |
+
effect(resampled_x.numpy(), self.sample_rate)
|
239 |
+
)
|
240 |
+
normalized_input = self.normalize(effected_input)
|
241 |
+
normalized_target = self.normalize(resampled_x)
|
242 |
+
return (normalized_input, normalized_target, effect_name)
|
243 |
+
|
244 |
+
|
245 |
def create_random_chunks(
|
246 |
audio_file: str, chunk_size: int, num_chunks: int
|
247 |
) -> Tuple[List[Tuple[int, int]], int]:
|
|
|
318 |
pin_memory=self.pin_memory,
|
319 |
shuffle=False,
|
320 |
)
|
321 |
+
|
322 |
+
|
323 |
+
class VocalSetDatamodule(pl.LightningDataModule):
|
324 |
+
def __init__(
|
325 |
+
self,
|
326 |
+
train_dataset,
|
327 |
+
val_dataset,
|
328 |
+
*,
|
329 |
+
batch_size: int,
|
330 |
+
num_workers: int,
|
331 |
+
pin_memory: bool = False,
|
332 |
+
**kwargs: int,
|
333 |
+
) -> None:
|
334 |
+
super().__init__()
|
335 |
+
self.train_dataset = train_dataset
|
336 |
+
self.val_dataset = val_dataset
|
337 |
+
self.batch_size = batch_size
|
338 |
+
self.num_workers = num_workers
|
339 |
+
self.pin_memory = pin_memory
|
340 |
+
|
341 |
+
def setup(self, stage: Any = None) -> None:
|
342 |
+
pass
|
343 |
+
|
344 |
+
def train_dataloader(self) -> DataLoader:
|
345 |
+
return DataLoader(
|
346 |
+
dataset=self.train_dataset,
|
347 |
+
batch_size=self.batch_size,
|
348 |
+
num_workers=self.num_workers,
|
349 |
+
pin_memory=self.pin_memory,
|
350 |
+
shuffle=True,
|
351 |
+
)
|
352 |
+
|
353 |
+
def val_dataloader(self) -> DataLoader:
|
354 |
+
return DataLoader(
|
355 |
+
dataset=self.val_dataset,
|
356 |
+
batch_size=self.batch_size,
|
357 |
+
num_workers=self.num_workers,
|
358 |
+
pin_memory=self.pin_memory,
|
359 |
+
shuffle=False,
|
360 |
+
)
|
remfx/models.py
CHANGED
@@ -7,11 +7,43 @@ from audio_diffusion_pytorch import DiffusionModel
|
|
7 |
from auraloss.time import SISDRLoss
|
8 |
from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
|
9 |
from torch.nn import L1Loss
|
|
|
|
|
10 |
|
11 |
from umx.openunmix.model import OpenUnmix, Separator
|
12 |
from torchaudio.models import HDemucs
|
13 |
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
class RemFXModel(pl.LightningModule):
|
16 |
def __init__(
|
17 |
self,
|
@@ -35,7 +67,7 @@ class RemFXModel(pl.LightningModule):
|
|
35 |
{
|
36 |
"SISDR": SISDRLoss(),
|
37 |
"STFT": STFTLoss(),
|
38 |
-
"
|
39 |
}
|
40 |
)
|
41 |
# Log first batch metrics input vs output only once
|
|
|
7 |
from auraloss.time import SISDRLoss
|
8 |
from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
|
9 |
from torch.nn import L1Loss
|
10 |
+
from frechet_audio_distance import FrechetAudioDistance
|
11 |
+
import numpy as np
|
12 |
|
13 |
from umx.openunmix.model import OpenUnmix, Separator
|
14 |
from torchaudio.models import HDemucs
|
15 |
|
16 |
|
17 |
+
class FADLoss(torch.nn.Module):
|
18 |
+
def __init__(self, sample_rate: float):
|
19 |
+
super().__init__()
|
20 |
+
self.fad = FrechetAudioDistance(
|
21 |
+
use_pca=False, use_activation=False, verbose=False
|
22 |
+
)
|
23 |
+
self.sr = sample_rate
|
24 |
+
|
25 |
+
def forward(self, audio_background, audio_eval):
|
26 |
+
embds_background = []
|
27 |
+
embds_eval = []
|
28 |
+
for sample in audio_background:
|
29 |
+
embd = self.fad.model.forward(sample.T.detach().numpy(), self.sr)
|
30 |
+
embds_background.append(embd.cpu().detach().numpy())
|
31 |
+
for sample in audio_eval:
|
32 |
+
embd = self.fad.model.forward(sample.T.detach().numpy(), self.sr)
|
33 |
+
embds_eval.append(embd.cpu().detach().numpy())
|
34 |
+
embds_background = np.concatenate(embds_background, axis=0)
|
35 |
+
embds_eval = np.concatenate(embds_eval, axis=0)
|
36 |
+
mu_background, sigma_background = self.fad.calculate_embd_statistics(
|
37 |
+
embds_background
|
38 |
+
)
|
39 |
+
mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
|
40 |
+
|
41 |
+
fad_score = self.fad.calculate_frechet_distance(
|
42 |
+
mu_background, sigma_background, mu_eval, sigma_eval
|
43 |
+
)
|
44 |
+
return fad_score
|
45 |
+
|
46 |
+
|
47 |
class RemFXModel(pl.LightningModule):
|
48 |
def __init__(
|
49 |
self,
|
|
|
67 |
{
|
68 |
"SISDR": SISDRLoss(),
|
69 |
"STFT": STFTLoss(),
|
70 |
+
"FAD": FADLoss(sample_rate=sample_rate),
|
71 |
}
|
72 |
)
|
73 |
# Log first batch metrics input vs output only once
|
shell_vars.sh
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
export DATASET_ROOT="./data/
|
2 |
export WANDB_PROJECT="RemFX"
|
3 |
export WANDB_ENTITY="mattricesound"
|
|
|
1 |
+
export DATASET_ROOT="./data/VocalSet"
|
2 |
export WANDB_PROJECT="RemFX"
|
3 |
export WANDB_ENTITY="mattricesound"
|