mattricesound committed on
Commit
9325b1e
·
1 Parent(s): bd1743b

Add random sampling of datasets to prevent class imbalance

Browse files
Files changed (2) hide show
  1. cfg/config.yaml +3 -0
  2. remfx/datasets.py +22 -19
cfg/config.yaml CHANGED
@@ -56,6 +56,7 @@ datamodule:
56
  _target_: remfx.datasets.EffectDatamodule
57
  train_dataset:
58
  _target_: remfx.datasets.EffectDataset
 
59
  sample_rate: ${sample_rate}
60
  root: ${oc.env:DATASET_ROOT}
61
  chunk_size: ${chunk_size}
@@ -71,6 +72,7 @@ datamodule:
71
  render_root: ${render_root}
72
  val_dataset:
73
  _target_: remfx.datasets.EffectDataset
 
74
  sample_rate: ${sample_rate}
75
  root: ${oc.env:DATASET_ROOT}
76
  chunk_size: ${chunk_size}
@@ -86,6 +88,7 @@ datamodule:
86
  render_root: ${render_root}
87
  test_dataset:
88
  _target_: remfx.datasets.EffectDataset
 
89
  sample_rate: ${sample_rate}
90
  root: ${oc.env:DATASET_ROOT}
91
  chunk_size: ${chunk_size}
 
56
  _target_: remfx.datasets.EffectDatamodule
57
  train_dataset:
58
  _target_: remfx.datasets.EffectDataset
59
+ total_chunks: 8000
60
  sample_rate: ${sample_rate}
61
  root: ${oc.env:DATASET_ROOT}
62
  chunk_size: ${chunk_size}
 
72
  render_root: ${render_root}
73
  val_dataset:
74
  _target_: remfx.datasets.EffectDataset
75
+ total_chunks: 1000
76
  sample_rate: ${sample_rate}
77
  root: ${oc.env:DATASET_ROOT}
78
  chunk_size: ${chunk_size}
 
88
  render_root: ${render_root}
89
  test_dataset:
90
  _target_: remfx.datasets.EffectDataset
91
+ total_chunks: 1000
92
  sample_rate: ${sample_rate}
93
  root: ${oc.env:DATASET_ROOT}
94
  chunk_size: ${chunk_size}
remfx/datasets.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import shutil
6
  import torchaudio
7
  import pytorch_lightning as pl
8
-
9
  from tqdm import tqdm
10
  from pathlib import Path
11
  from remfx import effects
@@ -81,7 +81,7 @@ def locate_files(root: str, mode: str):
81
  for singer_dir in singer_dirs:
82
  files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
83
  print(f"Found {len(files)} files in VocalSet {mode}.")
84
- file_list += sorted(files)
85
  # ------------------------- GuitarSet -------------------------
86
  guitarset_dir = os.path.join(root, "audio_mono-mic")
87
  if os.path.isdir(guitarset_dir):
@@ -92,7 +92,7 @@ def locate_files(root: str, mode: str):
92
  if os.path.basename(f).split("_")[0] in guitarset_splits[mode]
93
  ]
94
  print(f"Found {len(files)} files in GuitarSet {mode}.")
95
- file_list += sorted(files)
96
  # ------------------------- IDMT-SMT-GUITAR -------------------------
97
  idmt_smt_guitar_dir = os.path.join(root, "IDMT-SMT-GUITAR_V2")
98
  if os.path.isdir(idmt_smt_guitar_dir):
@@ -107,7 +107,7 @@ def locate_files(root: str, mode: str):
107
  for f in files
108
  if os.path.basename(f).split("_")[0] in idmt_guitar_splits[mode]
109
  ]
110
- file_list += sorted(files)
111
  print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
112
  # ------------------------- IDMT-SMT-BASS -------------------------
113
  # idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
@@ -121,7 +121,7 @@ def locate_files(root: str, mode: str):
121
  # for f in files
122
  # if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
123
  # ]
124
- # file_list += sorted(files)
125
  # print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
126
  # ------------------------- DSD100 ---------------------------------
127
  dsd_100_dir = os.path.join(root, "DSD100")
@@ -130,7 +130,7 @@ def locate_files(root: str, mode: str):
130
  os.path.join(dsd_100_dir, mode, "**", "*.wav"),
131
  recursive=True,
132
  )
133
- file_list += sorted(files)
134
  print(f"Found {len(files)} files in DSD100 {mode}.")
135
  # ------------------------- IDMT-SMT-DRUMS -------------------------
136
  idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
@@ -141,7 +141,7 @@ def locate_files(root: str, mode: str):
141
  for f in files
142
  if os.path.basename(f).split("_")[0] in idmt_drums_splits[mode]
143
  ]
144
- file_list += sorted(files)
145
  print(f"Found {len(files)} files in IDMT-SMT-Drums {mode}.")
146
 
147
  return file_list
@@ -153,6 +153,7 @@ class EffectDataset(Dataset):
153
  root: str,
154
  sample_rate: int,
155
  chunk_size: int = 262144,
 
156
  effect_modules: List[Dict[str, torch.nn.Module]] = None,
157
  effects_to_use: List[str] = None,
158
  effects_to_remove: List[str] = None,
@@ -170,6 +171,7 @@ class EffectDataset(Dataset):
170
  self.root = Path(root)
171
  self.render_root = Path(render_root)
172
  self.chunk_size = chunk_size
 
173
  self.sample_rate = sample_rate
174
  self.mode = mode
175
  self.max_kept_effects = max_kept_effects
@@ -198,14 +200,17 @@ class EffectDataset(Dataset):
198
  sys.exit()
199
  shutil.rmtree(self.proc_root)
200
 
201
- self.num_chunks = 0
202
- print("Total files:", len(self.files))
203
  print("Processing files...")
204
  if render_files:
205
  # Split audio file into chunks, resample, then apply random effects
206
  self.proc_root.mkdir(parents=True, exist_ok=True)
207
- for audio_file in tqdm(self.files, total=len(self.files)):
208
- chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
 
 
 
 
209
  for chunk in chunks:
210
  resampled_chunk = torchaudio.functional.resample(
211
  chunk, orig_sr, sample_rate
@@ -220,23 +225,21 @@ class EffectDataset(Dataset):
220
  dry, wet, dry_effects, wet_effects = self.process_effects(
221
  resampled_chunk
222
  )
223
- output_dir = self.proc_root / str(self.num_chunks)
224
  output_dir.mkdir(exist_ok=True)
225
  torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
226
  torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
227
  torch.save(dry_effects, output_dir / "dry_effects.pt")
228
  torch.save(wet_effects, output_dir / "wet_effects.pt")
229
- self.num_chunks += 1
 
230
  else:
231
- self.num_chunks = len(list(self.proc_root.iterdir()))
232
 
233
- print(
234
- f"Found {len(self.files)} {self.mode} files .\n"
235
- f"Total chunks: {self.num_chunks}"
236
- )
237
 
238
  def __len__(self):
239
- return self.num_chunks
240
 
241
  def __getitem__(self, idx):
242
  input_file = self.proc_root / str(idx) / "input.wav"
 
5
  import shutil
6
  import torchaudio
7
  import pytorch_lightning as pl
8
+ import random
9
  from tqdm import tqdm
10
  from pathlib import Path
11
  from remfx import effects
 
81
  for singer_dir in singer_dirs:
82
  files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
83
  print(f"Found {len(files)} files in VocalSet {mode}.")
84
+ file_list.append(sorted(files))
85
  # ------------------------- GuitarSet -------------------------
86
  guitarset_dir = os.path.join(root, "audio_mono-mic")
87
  if os.path.isdir(guitarset_dir):
 
92
  if os.path.basename(f).split("_")[0] in guitarset_splits[mode]
93
  ]
94
  print(f"Found {len(files)} files in GuitarSet {mode}.")
95
+ file_list.append(sorted(files))
96
  # ------------------------- IDMT-SMT-GUITAR -------------------------
97
  idmt_smt_guitar_dir = os.path.join(root, "IDMT-SMT-GUITAR_V2")
98
  if os.path.isdir(idmt_smt_guitar_dir):
 
107
  for f in files
108
  if os.path.basename(f).split("_")[0] in idmt_guitar_splits[mode]
109
  ]
110
+ file_list.append(sorted(files))
111
  print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
112
  # ------------------------- IDMT-SMT-BASS -------------------------
113
  # idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
 
121
  # for f in files
122
  # if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
123
  # ]
124
+ # file_list.append(sorted(files))
125
  # print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
126
  # ------------------------- DSD100 ---------------------------------
127
  dsd_100_dir = os.path.join(root, "DSD100")
 
130
  os.path.join(dsd_100_dir, mode, "**", "*.wav"),
131
  recursive=True,
132
  )
133
+ file_list.append(sorted(files))
134
  print(f"Found {len(files)} files in DSD100 {mode}.")
135
  # ------------------------- IDMT-SMT-DRUMS -------------------------
136
  idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
 
141
  for f in files
142
  if os.path.basename(f).split("_")[0] in idmt_drums_splits[mode]
143
  ]
144
+ file_list.append(sorted(files))
145
  print(f"Found {len(files)} files in IDMT-SMT-Drums {mode}.")
146
 
147
  return file_list
 
153
  root: str,
154
  sample_rate: int,
155
  chunk_size: int = 262144,
156
+ total_chunks: int = 1000,
157
  effect_modules: List[Dict[str, torch.nn.Module]] = None,
158
  effects_to_use: List[str] = None,
159
  effects_to_remove: List[str] = None,
 
171
  self.root = Path(root)
172
  self.render_root = Path(render_root)
173
  self.chunk_size = chunk_size
174
+ self.total_chunks = total_chunks
175
  self.sample_rate = sample_rate
176
  self.mode = mode
177
  self.max_kept_effects = max_kept_effects
 
200
  sys.exit()
201
  shutil.rmtree(self.proc_root)
202
 
203
+ print("Total datasets:", len(self.files))
 
204
  print("Processing files...")
205
  if render_files:
206
  # Split audio file into chunks, resample, then apply random effects
207
  self.proc_root.mkdir(parents=True, exist_ok=True)
208
+ for num_chunk in tqdm(range(self.total_chunks)):
209
+ random_dataset_choice = random.choice(self.files)
210
+ random_file_choice = random.choice(random_dataset_choice)
211
+ chunks, orig_sr = create_sequential_chunks(
212
+ random_file_choice, self.chunk_size
213
+ )
214
  for chunk in chunks:
215
  resampled_chunk = torchaudio.functional.resample(
216
  chunk, orig_sr, sample_rate
 
225
  dry, wet, dry_effects, wet_effects = self.process_effects(
226
  resampled_chunk
227
  )
228
+ output_dir = self.proc_root / str(num_chunk)
229
  output_dir.mkdir(exist_ok=True)
230
  torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
231
  torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
232
  torch.save(dry_effects, output_dir / "dry_effects.pt")
233
  torch.save(wet_effects, output_dir / "wet_effects.pt")
234
+
235
+ print("Finished rendering")
236
  else:
237
+ self.total_chunks = len(list(self.proc_root.iterdir()))
238
 
239
+ print("Total chunks:", self.total_chunks)
 
 
 
240
 
241
  def __len__(self):
242
+ return self.total_chunks
243
 
244
  def __getitem__(self, idx):
245
  input_file = self.proc_root / str(idx) / "input.wav"