Christian J. Steinmetz committed
Commit a3e84f7 · 1 Parent(s): 99511d7

rework dataset generation for better logging of added effects

Files changed (2)
  1. remfx/datasets.py +77 -28
  2. remfx/effects.py +3 -3
remfx/datasets.py CHANGED
@@ -1,21 +1,49 @@
+import os
+import sys
+import glob
 import torch
-from torch.utils.data import Dataset, DataLoader
-import torch.nn.functional as F
+import shutil
 import torchaudio
-from pathlib import Path
 import pytorch_lightning as pl
-import sys
-from typing import Any, List, Dict
-from remfx import effects
+import torch.nn.functional as F
+
 from tqdm import tqdm
-from remfx.utils import create_sequential_chunks
-import shutil
+from pathlib import Path
+from remfx import effects
 from ordered_set import OrderedSet
+from typing import Any, List, Dict
+from torch.utils.data import Dataset, DataLoader
+from remfx.utils import create_sequential_chunks


 # https://zenodo.org/record/1193957 -> VocalSet

 ALL_EFFECTS = effects.Pedalboard_Effects
+print(ALL_EFFECTS)
+
+
+singer_splits = {
+    "train": [
+        "male1",
+        "male2",
+        "male3",
+        "male4",
+        "male5",
+        "male6",
+        "male7",
+        "male8",
+        "male9",
+        "female1",
+        "female2",
+        "female3",
+        "female4",
+        "female5",
+        "female6",
+        "female7",
+    ],
+    "val": ["male10", "female8"],
+    "test": ["male11", "female9"],
+}


 class VocalSet(Dataset):
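Note: the new singer_splits dict fixes the train/val/test partition at the singer level (16/2/2 of VocalSet's 20 singers), so no singer's recordings leak across splits. A minimal sanity check, as a sketch (the import path assumes singer_splits stays module-level in remfx/datasets.py):

from itertools import combinations

from remfx.datasets import singer_splits

# splits must be pairwise disjoint so no singer leaks across train/val/test
for a, b in combinations(singer_splits.values(), 2):
    assert not set(a) & set(b)

# VocalSet has 20 singers (11 male, 9 female); all are assigned exactly once
assert sum(len(v) for v in singer_splits.values()) == 20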
@@ -43,8 +71,6 @@ class VocalSet(Dataset):
         self.chunk_size = chunk_size
         self.sample_rate = sample_rate
         self.mode = mode
-        mode_path = self.root / self.mode
-        self.files = sorted(list(mode_path.glob("./**/*.wav")))
         self.max_kept_effects = max_kept_effects
         self.max_removed_effects = max_removed_effects
         self.effects_to_use = effects_to_use
@@ -53,11 +79,20 @@ class VocalSet(Dataset):
         self.effects = effect_modules
         self.shuffle_kept_effects = shuffle_kept_effects
         self.shuffle_removed_effects = shuffle_removed_effects
-
         effects_string = "_".join(self.effects_to_use + ["_"] + self.effects_to_remove)
         self.effects_to_keep = self.validate_effect_input()
         self.proc_root = self.render_root / "processed" / effects_string / self.mode

+        # find all singer directories
+        singer_dirs = glob.glob(os.path.join(self.root, "data_by_singer", "*"))
+        singer_dirs = [
+            sd for sd in singer_dirs if os.path.basename(sd) in singer_splits[mode]
+        ]
+        self.files = []
+        for singer_dir in singer_dirs:
+            self.files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
+        self.files = sorted(self.files)
+
         if self.proc_root.exists() and len(list(self.proc_root.iterdir())) > 0:
             print("Found processed files.")
             if render_files:
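Note: without recursive=True, glob.glob treats each ** like a plain *, so the pattern above matches .wav files exactly two directory levels below each singer directory; this works only if VocalSet's folder layout is exactly that deep. A depth-agnostic alternative, sketched with pathlib (not what the commit does):

from pathlib import Path

files = []
for singer_dir in singer_dirs:
    # rglob("*.wav") matches at any depth below singer_dir
    files.extend(str(p) for p in Path(singer_dir).rglob("*.wav"))
files = sorted(files)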
@@ -86,12 +121,15 @@ class VocalSet(Dataset):
                     # Skip if chunk is too small
                     continue

-                x, y, effect = self.process_effects(resampled_chunk)
+                dry, wet, dry_effects, wet_effects = self.process_effects(
+                    resampled_chunk
+                )
                 output_dir = self.proc_root / str(self.num_chunks)
                 output_dir.mkdir(exist_ok=True)
-                torchaudio.save(output_dir / "input.wav", x, self.sample_rate)
-                torchaudio.save(output_dir / "target.wav", y, self.sample_rate)
-                torch.save(effect, output_dir / "effect.pt")
+                torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
+                torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
+                torch.save(dry_effects, output_dir / "dry_effects.pt")
+                torch.save(wet_effects, output_dir / "wet_effects.pt")
                 self.num_chunks += 1
         else:
             self.num_chunks = len(list(self.proc_root.iterdir()))
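Each rendered chunk now gets its own numbered directory holding the wet/dry pair plus one label tensor per effect group, roughly:

processed/<effects_string>/<mode>/
    0/
        input.wav        # wet: kept effects plus the effects to remove
        target.wav       # dry: kept effects only
        dry_effects.pt   # multi-hot labels of the kept effects
        wet_effects.pt   # multi-hot labels of the effects to remove
    1/
        ...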
@@ -107,10 +145,11 @@ class VocalSet(Dataset):
     def __getitem__(self, idx):
         input_file = self.proc_root / str(idx) / "input.wav"
         target_file = self.proc_root / str(idx) / "target.wav"
-        effect_name = torch.load(self.proc_root / str(idx) / "effect.pt")
+        dry_effect_names = torch.load(self.proc_root / str(idx) / "dry_effects.pt")
+        wet_effect_names = torch.load(self.proc_root / str(idx) / "wet_effects.pt")
         input, sr = torchaudio.load(input_file)
         target, sr = torchaudio.load(target_file)
-        return (input, target, effect_name)
+        return (input, target, dry_effect_names, wet_effect_names)

     def validate_effect_input(self):
         for effect in self.effects.values():
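Since __getitem__ now returns a 4-tuple, any consumer that unpacked (input, target, effect_name) needs updating. A sketch of the new batch shape, assuming dataset is an already-constructed VocalSet instance (the constructor arguments are configuration-specific and elided):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=16, shuffle=True)
for x, y, dry_fx, wet_fx in loader:
    # x, y: (batch, channels, samples) wet input and dry target audio
    # dry_fx, wet_fx: (batch, len(ALL_EFFECTS)) multi-hot label tensors
    ...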
@@ -154,27 +193,29 @@ class VocalSet(Dataset):
         return kept_fx

     def process_effects(self, dry: torch.Tensor):
-        labels = []
-
         # Apply Kept Effects
         # Shuffle effects if specified
         if self.shuffle_kept_effects:
             effect_indices = torch.randperm(len(self.effects_to_keep))
         else:
             effect_indices = torch.arange(len(self.effects_to_keep))
+
         # Up to max_kept_effects
         if self.max_kept_effects != -1:
             num_kept_effects = int(torch.rand(1).item() * (self.max_kept_effects)) + 1
         else:
             num_kept_effects = len(self.effects_to_keep)
         effect_indices = effect_indices[:num_kept_effects]
+        print(effect_indices)
+
         # Index in effect settings
         effect_names_to_apply = [self.effects_to_keep[i] for i in effect_indices]
         effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
         # Apply
+        dry_labels = []
         for effect in effects_to_apply:
             dry = effect(dry)
-            labels.append(ALL_EFFECTS.index(type(effect)))
+            dry_labels.append(ALL_EFFECTS.index(type(effect)))

         # Apply effects_to_remove
         # Shuffle effects if specified
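Note: ALL_EFFECTS.index(type(effect)) maps each applied effect instance back to its class position in Pedalboard_Effects, so the label vocabulary is simply the ordering of that list. For example:

from remfx import effects

ALL_EFFECTS = effects.Pedalboard_Effects
# index by class, mirroring ALL_EFFECTS.index(type(effect)) in the loop above
print(ALL_EFFECTS.index(effects.RandomPedalboardChorus))  # 1: second entry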
@@ -185,9 +226,7 @@ class VocalSet(Dataset):
             effect_indices = torch.arange(len(self.effects_to_remove))
         # Up to max_removed_effects
         if self.max_removed_effects != -1:
-            num_kept_effects = (
-                int(torch.rand(1).item() * (self.max_removed_effects)) + 1
-            )
+            num_kept_effects = int(torch.rand(1).item() * (self.max_removed_effects))
         else:
             num_kept_effects = len(self.effects_to_remove)
         effect_indices = effect_indices[: self.max_removed_effects]
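Note: dropping the + 1 makes the draw uniform over {0, ..., max_removed_effects - 1}, so zero removed effects is now possible (the kept-effects branch above still adds 1). As in the old code, the slice that follows uses self.max_removed_effects rather than the sampled count. An equivalent integer draw, as a sketch:

import torch

max_removed_effects = 5  # illustrative value

# the commit's formulation: floor of a uniform float in [0, max_removed_effects)
n = int(torch.rand(1).item() * max_removed_effects)

# equivalent direct integer sampling over {0, ..., max_removed_effects - 1}
n = int(torch.randint(0, max_removed_effects, (1,)).item())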
@@ -195,17 +234,27 @@ class VocalSet(Dataset):
         effect_names_to_apply = [self.effects_to_remove[i] for i in effect_indices]
         effects_to_apply = [self.effects[i] for i in effect_names_to_apply]
         # Apply
+
+        wet_labels = []
         for effect in effects_to_apply:
             wet = effect(wet)
-            labels.append(ALL_EFFECTS.index(type(effect)))
+            wet_labels.append(ALL_EFFECTS.index(type(effect)))
+
+        wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+        dry_labels_tensor = torch.zeros(len(ALL_EFFECTS))
+
+        for label_idx in wet_labels:
+            wet_labels_tensor[label_idx] = 1.0
+
+        for label_idx in dry_labels:
+            dry_labels_tensor[label_idx] = 1.0

-        # Convert labels to one-hot
-        one_hot = F.one_hot(torch.tensor(labels), num_classes=len(ALL_EFFECTS))
-        effects_present = torch.sum(one_hot, dim=0).float()
+        # effects_present = torch.sum(one_hot, dim=0).float()
+        print(dry_labels_tensor, wet_labels_tensor)
         # Normalize
         normalized_dry = self.normalize(dry)
         normalized_wet = self.normalize(wet)
-        return normalized_dry, normalized_wet, effects_present
+        return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor


 class VocalSetDatamodule(pl.LightningDataModule):
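The two fill loops replace the earlier F.one_hot(...).sum(0) construction and now produce one multi-hot vector per effect group instead of a single combined one. A vectorized equivalent, as a sketch with illustrative label indices:

import torch
import torch.nn.functional as F

num_effects = 4      # len(ALL_EFFECTS) after the effects.py change below
wet_labels = [0, 2]  # illustrative output of ALL_EFFECTS.index(...)

wet_labels_tensor = torch.zeros(num_effects)
wet_labels_tensor[torch.tensor(wet_labels, dtype=torch.long)] = 1.0

# same result via one_hot + sum, clamped in case an effect repeats
alt = F.one_hot(torch.tensor(wet_labels), num_classes=num_effects).sum(0).clamp(max=1).float()
assert torch.equal(wet_labels_tensor, alt)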
 
remfx/effects.py CHANGED
@@ -701,9 +701,9 @@ class RandomAudioEffectsChannel(torch.nn.Module):
 Pedalboard_Effects = [
     RandomPedalboardReverb,
     RandomPedalboardChorus,
-    RandomPedalboardDelay,
+    # RandomPedalboardDelay,
     RandomPedalboardDistortion,
     RandomPedalboardCompressor,
-    RandomPedalboardPhaser,
-    RandomPedalboardLimiter,
+    # RandomPedalboardPhaser,
+    # RandomPedalboardLimiter,
 ]
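Note: ALL_EFFECTS in remfx/datasets.py is this list, so commenting out three entries shrinks the label vectors from 7 to 4 slots and shifts the indices of RandomPedalboardDistortion and RandomPedalboardCompressor; label tensors rendered before this change are therefore not index-compatible. A quick check:

from remfx import effects

print(len(effects.Pedalboard_Effects))                     # 4
print([fx.__name__ for fx in effects.Pedalboard_Effects])  # remaining effect classes, in label order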