Spaces:
Sleeping
Sleeping
Commit
·
9325b1e
1
Parent(s):
bd1743b
Add random sampling of datasets to prevent class imbalance
Browse files- cfg/config.yaml +3 -0
- remfx/datasets.py +22 -19
cfg/config.yaml
CHANGED
@@ -56,6 +56,7 @@ datamodule:
|
|
56 |
_target_: remfx.datasets.EffectDatamodule
|
57 |
train_dataset:
|
58 |
_target_: remfx.datasets.EffectDataset
|
|
|
59 |
sample_rate: ${sample_rate}
|
60 |
root: ${oc.env:DATASET_ROOT}
|
61 |
chunk_size: ${chunk_size}
|
@@ -71,6 +72,7 @@ datamodule:
|
|
71 |
render_root: ${render_root}
|
72 |
val_dataset:
|
73 |
_target_: remfx.datasets.EffectDataset
|
|
|
74 |
sample_rate: ${sample_rate}
|
75 |
root: ${oc.env:DATASET_ROOT}
|
76 |
chunk_size: ${chunk_size}
|
@@ -86,6 +88,7 @@ datamodule:
|
|
86 |
render_root: ${render_root}
|
87 |
test_dataset:
|
88 |
_target_: remfx.datasets.EffectDataset
|
|
|
89 |
sample_rate: ${sample_rate}
|
90 |
root: ${oc.env:DATASET_ROOT}
|
91 |
chunk_size: ${chunk_size}
|
|
|
56 |
_target_: remfx.datasets.EffectDatamodule
|
57 |
train_dataset:
|
58 |
_target_: remfx.datasets.EffectDataset
|
59 |
+
total_chunks: 8000
|
60 |
sample_rate: ${sample_rate}
|
61 |
root: ${oc.env:DATASET_ROOT}
|
62 |
chunk_size: ${chunk_size}
|
|
|
72 |
render_root: ${render_root}
|
73 |
val_dataset:
|
74 |
_target_: remfx.datasets.EffectDataset
|
75 |
+
total_chunks: 1000
|
76 |
sample_rate: ${sample_rate}
|
77 |
root: ${oc.env:DATASET_ROOT}
|
78 |
chunk_size: ${chunk_size}
|
|
|
88 |
render_root: ${render_root}
|
89 |
test_dataset:
|
90 |
_target_: remfx.datasets.EffectDataset
|
91 |
+
total_chunks: 1000
|
92 |
sample_rate: ${sample_rate}
|
93 |
root: ${oc.env:DATASET_ROOT}
|
94 |
chunk_size: ${chunk_size}
|
remfx/datasets.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
import shutil
|
6 |
import torchaudio
|
7 |
import pytorch_lightning as pl
|
8 |
-
|
9 |
from tqdm import tqdm
|
10 |
from pathlib import Path
|
11 |
from remfx import effects
|
@@ -81,7 +81,7 @@ def locate_files(root: str, mode: str):
|
|
81 |
for singer_dir in singer_dirs:
|
82 |
files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
|
83 |
print(f"Found {len(files)} files in VocalSet {mode}.")
|
84 |
-
file_list
|
85 |
# ------------------------- GuitarSet -------------------------
|
86 |
guitarset_dir = os.path.join(root, "audio_mono-mic")
|
87 |
if os.path.isdir(guitarset_dir):
|
@@ -92,7 +92,7 @@ def locate_files(root: str, mode: str):
|
|
92 |
if os.path.basename(f).split("_")[0] in guitarset_splits[mode]
|
93 |
]
|
94 |
print(f"Found {len(files)} files in GuitarSet {mode}.")
|
95 |
-
file_list
|
96 |
# ------------------------- IDMT-SMT-GUITAR -------------------------
|
97 |
idmt_smt_guitar_dir = os.path.join(root, "IDMT-SMT-GUITAR_V2")
|
98 |
if os.path.isdir(idmt_smt_guitar_dir):
|
@@ -107,7 +107,7 @@ def locate_files(root: str, mode: str):
|
|
107 |
for f in files
|
108 |
if os.path.basename(f).split("_")[0] in idmt_guitar_splits[mode]
|
109 |
]
|
110 |
-
file_list
|
111 |
print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
|
112 |
# ------------------------- IDMT-SMT-BASS -------------------------
|
113 |
# idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
|
@@ -121,7 +121,7 @@ def locate_files(root: str, mode: str):
|
|
121 |
# for f in files
|
122 |
# if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
|
123 |
# ]
|
124 |
-
# file_list
|
125 |
# print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
|
126 |
# ------------------------- DSD100 ---------------------------------
|
127 |
dsd_100_dir = os.path.join(root, "DSD100")
|
@@ -130,7 +130,7 @@ def locate_files(root: str, mode: str):
|
|
130 |
os.path.join(dsd_100_dir, mode, "**", "*.wav"),
|
131 |
recursive=True,
|
132 |
)
|
133 |
-
file_list
|
134 |
print(f"Found {len(files)} files in DSD100 {mode}.")
|
135 |
# ------------------------- IDMT-SMT-DRUMS -------------------------
|
136 |
idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
|
@@ -141,7 +141,7 @@ def locate_files(root: str, mode: str):
|
|
141 |
for f in files
|
142 |
if os.path.basename(f).split("_")[0] in idmt_drums_splits[mode]
|
143 |
]
|
144 |
-
file_list
|
145 |
print(f"Found {len(files)} files in IDMT-SMT-Drums {mode}.")
|
146 |
|
147 |
return file_list
|
@@ -153,6 +153,7 @@ class EffectDataset(Dataset):
|
|
153 |
root: str,
|
154 |
sample_rate: int,
|
155 |
chunk_size: int = 262144,
|
|
|
156 |
effect_modules: List[Dict[str, torch.nn.Module]] = None,
|
157 |
effects_to_use: List[str] = None,
|
158 |
effects_to_remove: List[str] = None,
|
@@ -170,6 +171,7 @@ class EffectDataset(Dataset):
|
|
170 |
self.root = Path(root)
|
171 |
self.render_root = Path(render_root)
|
172 |
self.chunk_size = chunk_size
|
|
|
173 |
self.sample_rate = sample_rate
|
174 |
self.mode = mode
|
175 |
self.max_kept_effects = max_kept_effects
|
@@ -198,14 +200,17 @@ class EffectDataset(Dataset):
|
|
198 |
sys.exit()
|
199 |
shutil.rmtree(self.proc_root)
|
200 |
|
201 |
-
self.
|
202 |
-
print("Total files:", len(self.files))
|
203 |
print("Processing files...")
|
204 |
if render_files:
|
205 |
# Split audio file into chunks, resample, then apply random effects
|
206 |
self.proc_root.mkdir(parents=True, exist_ok=True)
|
207 |
-
for
|
208 |
-
|
|
|
|
|
|
|
|
|
209 |
for chunk in chunks:
|
210 |
resampled_chunk = torchaudio.functional.resample(
|
211 |
chunk, orig_sr, sample_rate
|
@@ -220,23 +225,21 @@ class EffectDataset(Dataset):
|
|
220 |
dry, wet, dry_effects, wet_effects = self.process_effects(
|
221 |
resampled_chunk
|
222 |
)
|
223 |
-
output_dir = self.proc_root / str(
|
224 |
output_dir.mkdir(exist_ok=True)
|
225 |
torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
|
226 |
torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
|
227 |
torch.save(dry_effects, output_dir / "dry_effects.pt")
|
228 |
torch.save(wet_effects, output_dir / "wet_effects.pt")
|
229 |
-
|
|
|
230 |
else:
|
231 |
-
self.
|
232 |
|
233 |
-
print(
|
234 |
-
f"Found {len(self.files)} {self.mode} files .\n"
|
235 |
-
f"Total chunks: {self.num_chunks}"
|
236 |
-
)
|
237 |
|
238 |
def __len__(self):
|
239 |
-
return self.
|
240 |
|
241 |
def __getitem__(self, idx):
|
242 |
input_file = self.proc_root / str(idx) / "input.wav"
|
|
|
5 |
import shutil
|
6 |
import torchaudio
|
7 |
import pytorch_lightning as pl
|
8 |
+
import random
|
9 |
from tqdm import tqdm
|
10 |
from pathlib import Path
|
11 |
from remfx import effects
|
|
|
81 |
for singer_dir in singer_dirs:
|
82 |
files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
|
83 |
print(f"Found {len(files)} files in VocalSet {mode}.")
|
84 |
+
file_list.append(sorted(files))
|
85 |
# ------------------------- GuitarSet -------------------------
|
86 |
guitarset_dir = os.path.join(root, "audio_mono-mic")
|
87 |
if os.path.isdir(guitarset_dir):
|
|
|
92 |
if os.path.basename(f).split("_")[0] in guitarset_splits[mode]
|
93 |
]
|
94 |
print(f"Found {len(files)} files in GuitarSet {mode}.")
|
95 |
+
file_list.append(sorted(files))
|
96 |
# ------------------------- IDMT-SMT-GUITAR -------------------------
|
97 |
idmt_smt_guitar_dir = os.path.join(root, "IDMT-SMT-GUITAR_V2")
|
98 |
if os.path.isdir(idmt_smt_guitar_dir):
|
|
|
107 |
for f in files
|
108 |
if os.path.basename(f).split("_")[0] in idmt_guitar_splits[mode]
|
109 |
]
|
110 |
+
file_list.append(sorted(files))
|
111 |
print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
|
112 |
# ------------------------- IDMT-SMT-BASS -------------------------
|
113 |
# idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
|
|
|
121 |
# for f in files
|
122 |
# if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
|
123 |
# ]
|
124 |
+
# file_list.append(sorted(files))
|
125 |
# print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
|
126 |
# ------------------------- DSD100 ---------------------------------
|
127 |
dsd_100_dir = os.path.join(root, "DSD100")
|
|
|
130 |
os.path.join(dsd_100_dir, mode, "**", "*.wav"),
|
131 |
recursive=True,
|
132 |
)
|
133 |
+
file_list.append(sorted(files))
|
134 |
print(f"Found {len(files)} files in DSD100 {mode}.")
|
135 |
# ------------------------- IDMT-SMT-DRUMS -------------------------
|
136 |
idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
|
|
|
141 |
for f in files
|
142 |
if os.path.basename(f).split("_")[0] in idmt_drums_splits[mode]
|
143 |
]
|
144 |
+
file_list.append(sorted(files))
|
145 |
print(f"Found {len(files)} files in IDMT-SMT-Drums {mode}.")
|
146 |
|
147 |
return file_list
|
|
|
153 |
root: str,
|
154 |
sample_rate: int,
|
155 |
chunk_size: int = 262144,
|
156 |
+
total_chunks: int = 1000,
|
157 |
effect_modules: List[Dict[str, torch.nn.Module]] = None,
|
158 |
effects_to_use: List[str] = None,
|
159 |
effects_to_remove: List[str] = None,
|
|
|
171 |
self.root = Path(root)
|
172 |
self.render_root = Path(render_root)
|
173 |
self.chunk_size = chunk_size
|
174 |
+
self.total_chunks = total_chunks
|
175 |
self.sample_rate = sample_rate
|
176 |
self.mode = mode
|
177 |
self.max_kept_effects = max_kept_effects
|
|
|
200 |
sys.exit()
|
201 |
shutil.rmtree(self.proc_root)
|
202 |
|
203 |
+
print("Total datasets:", len(self.files))
|
|
|
204 |
print("Processing files...")
|
205 |
if render_files:
|
206 |
# Split audio file into chunks, resample, then apply random effects
|
207 |
self.proc_root.mkdir(parents=True, exist_ok=True)
|
208 |
+
for num_chunk in tqdm(range(self.total_chunks)):
|
209 |
+
random_dataset_choice = random.choice(self.files)
|
210 |
+
random_file_choice = random.choice(random_dataset_choice)
|
211 |
+
chunks, orig_sr = create_sequential_chunks(
|
212 |
+
random_file_choice, self.chunk_size
|
213 |
+
)
|
214 |
for chunk in chunks:
|
215 |
resampled_chunk = torchaudio.functional.resample(
|
216 |
chunk, orig_sr, sample_rate
|
|
|
225 |
dry, wet, dry_effects, wet_effects = self.process_effects(
|
226 |
resampled_chunk
|
227 |
)
|
228 |
+
output_dir = self.proc_root / str(num_chunk)
|
229 |
output_dir.mkdir(exist_ok=True)
|
230 |
torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
|
231 |
torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
|
232 |
torch.save(dry_effects, output_dir / "dry_effects.pt")
|
233 |
torch.save(wet_effects, output_dir / "wet_effects.pt")
|
234 |
+
|
235 |
+
print("Finished rendering")
|
236 |
else:
|
237 |
+
self.total_chunks = len(list(self.proc_root.iterdir()))
|
238 |
|
239 |
+
print("Total chunks:", self.total_chunks)
|
|
|
|
|
|
|
240 |
|
241 |
def __len__(self):
|
242 |
+
return self.total_chunks
|
243 |
|
244 |
def __getitem__(self, idx):
|
245 |
input_file = self.proc_root / str(idx) / "input.wav"
|