Commit 7173e65
Parent(s): 78dbfc8

Render effected chunks to avoid bottlenecks

Files changed:
- README.md +4 -1
- cfg/config.yaml +2 -0
- cfg/effects/{compressor.yaml → compression.yaml} +0 -0
- cfg/effects/distortion.yaml +2 -2
- remfx/datasets.py +66 -39
- remfx/models.py +1 -0
- shell_vars.sh +1 -0
README.md
CHANGED
@@ -32,4 +32,7 @@ Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' train
 - `chorus`
 - `compressor`
 - `distortion`
-- `reverb`
+- `reverb`
+
+## Misc.
+To skip rendering files, add `+datamodule.train_dataset.render_files=False +datamodule.val_dataset.render_files=False` to the command-line
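If the chunks have already been rendered in a previous run, these overrides are simply appended to the usual training command, for example `python scripts/train.py +exp=umx_distortion +datamodule.train_dataset.render_files=False +datamodule.val_dataset.render_files=False` (the `+exp=umx_distortion` part reuses the README's existing example; only the two `render_files` overrides come from this commit).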
cfg/config.yaml
CHANGED
@@ -26,6 +26,7 @@ datamodule:
     _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
+    output_root: ${oc.env:OUTPUT_ROOT}
     chunk_size_in_sec: 6
     mode: "train"
     effect_types: ${effects.train_effects}
@@ -33,6 +34,7 @@ datamodule:
     _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
+    output_root: ${oc.env:OUTPUT_ROOT}
    chunk_size_in_sec: 6
     mode: "val"
     effect_types: ${effects.val_effects}
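The new `output_root` key resolves through OmegaConf's built-in `oc.env` resolver, so `OUTPUT_ROOT` must be set in the environment when the config is used. A minimal sketch of that resolution, assuming OmegaConf 2.1+ and a placeholder path:

```python
# Minimal sketch of how ${oc.env:OUTPUT_ROOT} resolves (OmegaConf >= 2.1 assumed);
# the path below is a placeholder, not a value from this commit.
import os
from omegaconf import OmegaConf

os.environ["OUTPUT_ROOT"] = "/tmp/VocalSet/processed"  # hypothetical
cfg = OmegaConf.create({"output_root": "${oc.env:OUTPUT_ROOT}"})
print(cfg.output_root)  # resolved lazily on access -> /tmp/VocalSet/processed
```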
cfg/effects/{compressor.yaml → compression.yaml}
RENAMED
File without changes
cfg/effects/distortion.yaml
CHANGED
@@ -10,5 +10,5 @@ effects:
   Distortion:
     _target_: remfx.effects.RandomPedalboardDistortion
     sample_rate: ${sample_rate}
-    min_drive_db:
-    max_drive_db:
+    min_drive_db: 30
+    max_drive_db: 30
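For reference, a sketch of how Hydra turns this node into an effect instance. It assumes `remfx` is importable and that `RandomPedalboardDistortion` accepts exactly the keyword arguments the config lists; the sample rate is a placeholder standing in for `${sample_rate}`.

```python
# Sketch only: instantiating the distortion effect from a config node with Hydra.
# The class path and argument names come from cfg/effects/distortion.yaml.
from hydra.utils import instantiate
from omegaconf import OmegaConf

node = OmegaConf.create(
    {
        "_target_": "remfx.effects.RandomPedalboardDistortion",
        "sample_rate": 48000,  # placeholder for ${sample_rate}
        "min_drive_db": 30,
        "max_drive_db": 30,
    }
)
distortion = instantiate(node)  # builds the RandomPedalboardDistortion module
```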
remfx/datasets.py
CHANGED
@@ -17,6 +17,7 @@ from pedalboard import (
     Distortion,
     Limiter,
 )
+from tqdm import tqdm
 
 # https://zenodo.org/record/7044411/ -> GuitarFX
 # https://zenodo.org/record/3371780 -> GuitarSet
@@ -181,6 +182,8 @@ class VocalSet(Dataset):
         sample_rate: int,
         chunk_size_in_sec: int = 3,
         effect_types: List[torch.nn.Module] = None,
+        render_files: bool = True,
+        output_root: str = "processed",
         mode: str = "train",
     ):
         super().__init__()
@@ -193,49 +196,68 @@ class VocalSet(Dataset):
 
         mode_path = self.root / self.mode
         self.files = sorted(list(mode_path.glob("./**/*.wav")))
-        for i, audio_file in enumerate(self.files):
-            chunk_starts, orig_sr = create_sequential_chunks(
-                audio_file, self.chunk_size_in_sec
-            )
-            self.chunks += chunk_starts
-            self.song_idx += [i] * len(chunk_starts)
-        print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
-        self.resampler = T.Resample(orig_sr, sample_rate)
-        self.effect_types = effect_types
         self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
+        self.effect_types = effect_types
 
-        chunk_start = self.chunks[idx]
-        chunk_size_in_samples = self.chunk_size_in_sec * sr
-        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
-        resampled_x = self.resampler(x)
-        # Reset chunk size to be new sample rate
-        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-        # Pad to chunk_size if needed
-        if resampled_x.shape[-1] < chunk_size_in_samples:
-            resampled_x = F.pad(
-                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
-            )
-            effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
-        else:
-            # deterministic effect for eval
-            effect_idx = idx % len(self.effect_types.keys())
-        effect_name = list(self.effect_types.keys())[int(effect_idx)]
-        effect = self.effect_types[effect_name]
-        effected_input = effect(resampled_x)
+        self.output_root = Path(output_root)
+        output_mode_path = output_root / self.mode
+
+        self.num_chunks = 0
+        print("Total files:", len(self.files))
+        print("Processing files...")
+        if render_files:
+            if not output_root.exists():
+                output_root.mkdir()
+            if not output_mode_path.exists():
+                output_mode_path.mkdir()
+            for i, audio_file in tqdm(enumerate(self.files)):
+                chunks, orig_sr = create_sequential_chunks(
+                    audio_file, self.chunk_size_in_sec
+                )
+                for chunk in chunks:
+                    resampled_chunk = torchaudio.functional.resample(
+                        chunk, orig_sr, sample_rate
+                    )
+                    chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
+                    if resampled_chunk.shape[-1] < chunk_size_in_samples:
+                        resampled_chunk = F.pad(
+                            resampled_chunk,
+                            (0, chunk_size_in_samples - resampled_chunk.shape[1]),
+                        )
+                    effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
+                    effect_name = list(self.effect_types.keys())[int(effect_idx)]
+                    effect = self.effect_types[effect_name]
+                    effected_input = effect(resampled_chunk)
+                    normalized_input = self.normalize(effected_input)
+                    normalized_target = self.normalize(resampled_chunk)
+
+                    output_dir = output_mode_path / str(self.num_chunks)
+                    output_dir.mkdir(exist_ok=True)
+                    torchaudio.save(
+                        output_dir / "input.wav", normalized_input, self.sample_rate
+                    )
+                    torchaudio.save(
+                        output_dir / "target.wav", normalized_target, self.sample_rate
+                    )
+                    self.num_chunks += 1
+        else:
+            self.num_chunks = len(list(output_mode_path.glob("./**/*.wav")))
 
+        print(
+            f"Found {len(self.files)} {self.mode} files .\n"
+            f"Total chunks: {self.num_chunks}"
+        )
 
+    def __len__(self):
+        return self.num_chunks
 
+    def __getitem__(self, idx):
+        # Load audio
+        input_file = self.root / "processed" / self.mode / str(idx) / "input.wav"
+        target_file = self.root / "processed" / self.mode / str(idx) / "target.wav"
+        input, sr = torchaudio.load(input_file)
+        target, sr = torchaudio.load(target_file)
+        return (input, target, "")
 
 
 def create_random_chunks(
@@ -262,10 +284,15 @@ def create_sequential_chunks(
     """Create sequential chunks of size chunk_size (seconds) from an audio file.
     Return sample_index of start of each chunk and original sr
     """
+    chunks = []
     audio, sr = torchaudio.load(audio_file)
     chunk_size_in_samples = chunk_size * sr
     chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
+    for start in chunk_starts:
+        if start + chunk_size_in_samples > audio.shape[-1]:
+            break
+        chunks.append(audio[:, start : start + chunk_size_in_samples])
+    return chunks, sr
 
 
 class Datamodule(pl.LightningDataModule):
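The net effect is that effected/target chunk pairs are rendered to disk once, and `__getitem__` only reads WAV files back. A usage sketch of the new flow follows; the `root`, `output_root`, and sample-rate values are placeholders chosen for illustration (the real values come from cfg/config.yaml and the environment), `output_root` is pointed at `<root>/processed` so that the loading path used in `__getitem__` matches the rendering path, and it is passed as a `Path` to match how `__init__` uses it.

```python
# Illustrative only: exercises the VocalSet dataset as changed in this commit.
# root / output_root / sample_rate are placeholders; the effect dict mirrors
# cfg/effects/distortion.yaml.
from pathlib import Path

from remfx.datasets import VocalSet
from remfx.effects import RandomPedalboardDistortion

sample_rate = 48000  # placeholder for ${sample_rate}
effect_types = {
    "Distortion": RandomPedalboardDistortion(
        sample_rate=sample_rate, min_drive_db=30, max_drive_db=30
    ),
}

# First run: render every chunk to ./data/VocalSet/processed/train/<idx>/{input,target}.wav
train_set = VocalSet(
    root="./data/VocalSet",
    sample_rate=sample_rate,
    chunk_size_in_sec=6,
    effect_types=effect_types,
    render_files=True,
    output_root=Path("./data/VocalSet/processed"),
    mode="train",
)

# Later runs: skip rendering and index the cached chunks directly.
cached_set = VocalSet(
    root="./data/VocalSet",
    sample_rate=sample_rate,
    chunk_size_in_sec=6,
    effect_types=effect_types,
    render_files=False,
    output_root=Path("./data/VocalSet/processed"),
    mode="train",
)
input_audio, target_audio, _ = cached_set[0]
```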
remfx/models.py
CHANGED
@@ -20,6 +20,7 @@ class FADLoss(torch.nn.Module):
         self.fad = FrechetAudioDistance(
             use_pca=False, use_activation=False, verbose=False
         )
+        self.fad.model = self.fad.model.to("cpu")
         self.sr = sample_rate
 
     def forward(self, audio_background, audio_eval):
shell_vars.sh
CHANGED
@@ -1,3 +1,4 @@
 export DATASET_ROOT="./data/VocalSet"
+export DATASET_ROOT="/scratch/VocalSet/processed"
 export WANDB_PROJECT="RemFX"
 export WANDB_ENTITY="mattricesound"