mattricesound committed on
Commit 30c1b93
Parents: e0aa67f ca6b6f7

Merge pull request #37 from mhrice/dsd100-dataset
cfg/config.yaml CHANGED
@@ -53,9 +53,10 @@ callbacks:
     _target_: remfx.callbacks.MetricCallback
 
 datamodule:
-  _target_: remfx.datasets.VocalSetDatamodule
+  _target_: remfx.datasets.EffectDatamodule
   train_dataset:
-    _target_: remfx.datasets.VocalSet
+    _target_: remfx.datasets.EffectDataset
+    total_chunks: 8000
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
     chunk_size: ${chunk_size}
@@ -70,7 +71,8 @@ datamodule:
     render_files: ${render_files}
     render_root: ${render_root}
   val_dataset:
-    _target_: remfx.datasets.VocalSet
+    _target_: remfx.datasets.EffectDataset
+    total_chunks: 1000
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
     chunk_size: ${chunk_size}
@@ -85,7 +87,8 @@ datamodule:
     render_files: ${render_files}
     render_root: ${render_root}
   test_dataset:
-    _target_: remfx.datasets.VocalSet
+    _target_: remfx.datasets.EffectDataset
+    total_chunks: 1000
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
     chunk_size: ${chunk_size}
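For context, the `_target_` keys above are standard Hydra instantiation hooks: at runtime the dotted path is imported and called with the remaining keys as keyword arguments, which is how this config becomes an EffectDatamodule/EffectDataset. A minimal sketch of the mechanism only (torch.optim.SGD is used as a stand-in target so the snippet runs without the RemFX datasets on disk):

# Sketch of Hydra's "_target_" instantiation, as used by the datamodule nodes above.
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

node = OmegaConf.create({"_target_": "torch.optim.SGD", "lr": 0.01})
params = [torch.nn.Parameter(torch.zeros(1))]
opt = instantiate(node, params=params)  # -> torch.optim.SGD(params, lr=0.01)
print(type(opt).__name__)               # "SGD"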
cfg/effects/all.yaml CHANGED
@@ -4,13 +4,19 @@ effects:
   chorus:
     _target_: remfx.effects.RandomPedalboardChorus
     sample_rate: ${sample_rate}
+    min_rate_hz: 0.25
+    max_rate_hz: 1.5
+    min_feedback: 0.1
+    max_feedback: 0.4
     min_depth: 0.2
-    min_mix: 0.3
+    max_depth: 0.6
+    min_mix: 0.15
+    max_mix: 0.4
   distortion:
     _target_: remfx.effects.RandomPedalboardDistortion
     sample_rate: ${sample_rate}
-    min_drive_db: 10
-    max_drive_db: 50
+    min_drive_db: 7
+    max_drive_db: 25
   compressor:
     _target_: remfx.effects.RandomPedalboardCompressor
     sample_rate: ${sample_rate}
@@ -26,7 +32,7 @@ effects:
     min_damping: 0.2
     max_damping: 1.0
     min_wet_dry: 0.2
-    max_wet_dry: 0.8
+    max_wet_dry: 0.6
     min_width: 0.2
     max_width: 1.0
   delay:
@@ -35,6 +41,6 @@ effects:
     min_delay_seconds: 0.1
     max_delay_sconds: 1.0
     min_feedback: 0.05
-    max_feedback: 0.6
-    min_mix: 0.2
-    max_mix: 0.7
+    max_feedback: 0.3
+    min_mix: 0.1
+    max_mix: 0.35
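These min/max keys bound how the RandomPedalboard* effects draw their parameters. The remfx.effects implementation is not part of this diff, so purely as an illustration of the new chorus ranges, a sketch using the pedalboard library directly (parameter names follow pedalboard.Chorus; the uniform sampling scheme is an assumption):

# Illustration only: a chorus with parameters drawn from the ranges configured above.
import random
import numpy as np
from pedalboard import Chorus

def random_chorus(audio: np.ndarray, sample_rate: float) -> np.ndarray:
    fx = Chorus(
        rate_hz=random.uniform(0.25, 1.5),
        depth=random.uniform(0.2, 0.6),
        feedback=random.uniform(0.1, 0.4),
        mix=random.uniform(0.15, 0.4),
    )
    return fx(audio, sample_rate)

dry = 0.1 * np.random.randn(1, 48000).astype(np.float32)  # 1 s of mono noise
wet = random_chorus(dry, 48000.0)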
cfg/exp/default.yaml CHANGED
@@ -1,6 +1,6 @@
 # @package _global_
 defaults:
-  - override /model: demucs
+  - override /model: umx
   - override /effects: all
 seed: 12345
 sample_rate: 48000
cfg/exp/dist.yaml ADDED
@@ -0,0 +1,29 @@
+# @package _global_
+defaults:
+  - override /model: umx
+  - override /effects: all
+seed: 12345
+sample_rate: 48000
+chunk_size: 262144 # 5.5s
+logs_dir: "./logs"
+render_files: True
+render_root: "/scratch/EffectSet"
+accelerator: "gpu"
+log_audio: True
+# Effects
+max_kept_effects: 5
+max_removed_effects: -1
+shuffle_kept_effects: True
+shuffle_removed_effects: False
+num_classes: 5
+effects_to_use:
+  - compressor
+  - distortion
+  - reverb
+  - chorus
+  - delay
+effects_to_remove:
+  - distortion
+datamodule:
+  batch_size: 16
+  num_workers: 8
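A quick way to check that the new experiment file resolves against the rest of the config tree; the config path and the "+exp=dist" override syntax are assumptions based on the cfg/ layout in this PR, not something the diff itself shows:

# Assumed layout: cfg/config.yaml plus the cfg/exp/dist.yaml added above.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="cfg"):
    cfg = compose(config_name="config", overrides=["+exp=dist"])

print(cfg.render_root)  # expected "/scratch/EffectSet" since dist.yaml is @package _global_
print(OmegaConf.to_yaml(cfg.datamodule, resolve=False))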
cfg/model/dcunet.yaml CHANGED
@@ -9,16 +9,8 @@ model:
   sample_rate: ${sample_rate}
   network:
     _target_: remfx.models.DCUNetModel
-    spec_dim: 257
-    hidden_dim: 768
-    filter_len: 512
-    hop_len: 64
-    block_layers: 4
-    layers: 4
-    kernel_size: 3
-    refine_layers: 1
-    is_mask: True
-    norm: 'ins'
-    act: 'comp'
+    architecture: "DCUNet-10"
+    stft_kernel_size: 512
+    fix_length_mode: "pad"
     sample_rate: ${sample_rate}
     num_bins: 1025
cfg/model/dptnet.yaml CHANGED
@@ -9,12 +9,14 @@ model:
   sample_rate: ${sample_rate}
   network:
     _target_: remfx.models.DPTNetModel
-    enc_dim: 256
-    feature_dim: 64
-    hidden_dim: 128
-    layer: 6
-    segment_size: 250
-    nspk: 1
-    win_len: 2
+    n_src: 1
+    in_chan: 64
+    out_chan: 64
+    chunk_size: 100
+    n_repeats: 2
+    fb_name: "free"
+    kernel_size: 16
+    n_filters: 64
+    stride: 8
     sample_rate: ${sample_rate}
     num_bins: 1025
cfg/model/tcn.yaml CHANGED
@@ -11,12 +11,12 @@ model:
     _target_: remfx.models.TCNModel
     ninputs: 1
     noutputs: 1
-    nblocks: 4
+    nblocks: 20
     channel_growth: 0
-    channel_width: 32
-    kernel_size: 13
+    channel_width: 64
+    kernel_size: 7
     stack_size: 10
-    dilation_growth: 10
+    dilation_growth: 2
     condition: False
     latent_dim: 2
     norm_type: "identity"
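The new TCN settings trade a few very wide blocks for many narrow ones while keeping a similar receptive field. Assuming the usual per-block dilation of dilation_growth ** (block_index % stack_size) (the exact schedule lives in TCN.compute_receptive_field, which is not part of this diff), the numbers work out roughly as follows:

# Rough receptive-field estimate for the old vs. new TCN hyperparameters.
def receptive_field(nblocks: int, kernel_size: int, dilation_growth: int, stack_size: int) -> int:
    rf = 1
    for n in range(nblocks):
        dilation = dilation_growth ** (n % stack_size)
        rf += (kernel_size - 1) * dilation
    return rf

old = receptive_field(nblocks=4, kernel_size=13, dilation_growth=10, stack_size=10)
new = receptive_field(nblocks=20, kernel_size=7, dilation_growth=2, stack_size=10)
print(old, old / 48000)  # 13333 samples, ~0.28 s at 48 kHz
print(new, new / 48000)  # 12277 samples, ~0.26 s at 48 kHz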
remfx/datasets.py CHANGED
@@ -5,20 +5,20 @@ import torch
 import shutil
 import torchaudio
 import pytorch_lightning as pl
-
+import random
 from tqdm import tqdm
 from pathlib import Path
 from remfx import effects
 from ordered_set import OrderedSet
 from typing import Any, List, Dict
 from torch.utils.data import Dataset, DataLoader
-from remfx.utils import create_sequential_chunks
+from remfx.utils import select_random_chunk


 # https://zenodo.org/record/1193957 -> VocalSet

 ALL_EFFECTS = effects.Pedalboard_Effects
-print(ALL_EFFECTS)
+# print(ALL_EFFECTS)


 vocalset_splits = {
@@ -55,6 +55,11 @@ idmt_bass_splits = {
     "val": ["VIF"],
     "test": ["VIS"],
 }
+dsd_100_splits = {
+    "train": ["train"],
+    "val": ["val"],
+    "test": ["test"],
+}
 idmt_drums_splits = {
     "train": ["WaveDrum02", "TechnoDrum01"],
     "val": ["RealDrum01"],
@@ -76,7 +81,7 @@ def locate_files(root: str, mode: str):
         for singer_dir in singer_dirs:
             files += glob.glob(os.path.join(singer_dir, "**", "**", "*.wav"))
         print(f"Found {len(files)} files in VocalSet {mode}.")
-        file_list += sorted(files)
+        file_list.append(sorted(files))
     # ------------------------- GuitarSet -------------------------
     guitarset_dir = os.path.join(root, "audio_mono-mic")
     if os.path.isdir(guitarset_dir):
@@ -87,37 +92,46 @@ def locate_files(root: str, mode: str):
             if os.path.basename(f).split("_")[0] in guitarset_splits[mode]
         ]
         print(f"Found {len(files)} files in GuitarSet {mode}.")
-        file_list += sorted(files)
-    # ------------------------- IDMT-SMT-GUITAR -------------------------
-    idmt_smt_guitar_dir = os.path.join(root, "IDMT-SMT-GUITAR_V2")
-    if os.path.isdir(idmt_smt_guitar_dir):
-        files = glob.glob(
-            os.path.join(
-                idmt_smt_guitar_dir, "IDMT-SMT-GUITAR_V2", "dataset4", "**", "*.wav"
-            ),
-            recursive=True,
-        )
-        files = [
-            f
-            for f in files
-            if os.path.basename(f).split("_")[0] in idmt_guitar_splits[mode]
-        ]
-        file_list += sorted(files)
-        print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
+        file_list.append(sorted(files))
+    # # ------------------------- IDMT-SMT-GUITAR -------------------------
+    # idmt_smt_guitar_dir = os.path.join(root, "IDMT-SMT-GUITAR_V2")
+    # if os.path.isdir(idmt_smt_guitar_dir):
+    #     files = glob.glob(
+    #         os.path.join(
+    #             idmt_smt_guitar_dir, "IDMT-SMT-GUITAR_V2", "dataset4", "**", "*.wav"
+    #         ),
+    #         recursive=True,
+    #     )
+    #     files = [
+    #         f
+    #         for f in files
+    #         if os.path.basename(f).split("_")[0] in idmt_guitar_splits[mode]
+    #     ]
+    #     file_list.append(sorted(files))
+    #     print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
     # ------------------------- IDMT-SMT-BASS -------------------------
-    idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
-    if os.path.isdir(idmt_smt_bass_dir):
+    # idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
+    # if os.path.isdir(idmt_smt_bass_dir):
+    #     files = glob.glob(
+    #         os.path.join(idmt_smt_bass_dir, "**", "*.wav"),
+    #         recursive=True,
+    #     )
+    #     files = [
+    #         f
+    #         for f in files
+    #         if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
+    #     ]
+    #     file_list.append(sorted(files))
+    #     print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
+    # ------------------------- DSD100 ---------------------------------
+    dsd_100_dir = os.path.join(root, "DSD100")
+    if os.path.isdir(dsd_100_dir):
         files = glob.glob(
-            os.path.join(idmt_smt_bass_dir, "**", "*.wav"),
+            os.path.join(dsd_100_dir, mode, "**", "*.wav"),
             recursive=True,
         )
-        files = [
-            f
-            for f in files
-            if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
-        ]
-        file_list += sorted(files)
-        print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
+        file_list.append(sorted(files))
+        print(f"Found {len(files)} files in DSD100 {mode}.")
     # ------------------------- IDMT-SMT-DRUMS -------------------------
     idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
     if os.path.isdir(idmt_smt_drums_dir):
@@ -127,18 +141,19 @@ def locate_files(root: str, mode: str):
             for f in files
             if os.path.basename(f).split("_")[0] in idmt_drums_splits[mode]
         ]
-        file_list += sorted(files)
+        file_list.append(sorted(files))
         print(f"Found {len(files)} files in IDMT-SMT-Drums {mode}.")

     return file_list


-class VocalSet(Dataset):
+class EffectDataset(Dataset):
     def __init__(
         self,
         root: str,
         sample_rate: int,
         chunk_size: int = 262144,
+        total_chunks: int = 1000,
         effect_modules: List[Dict[str, torch.nn.Module]] = None,
         effects_to_use: List[str] = None,
         effects_to_remove: List[str] = None,
@@ -156,6 +171,7 @@ class VocalSet(Dataset):
         self.root = Path(root)
         self.render_root = Path(render_root)
         self.chunk_size = chunk_size
+        self.total_chunks = total_chunks
         self.sample_rate = sample_rate
         self.mode = mode
         self.max_kept_effects = max_kept_effects
@@ -184,42 +200,40 @@ class VocalSet(Dataset):
                     sys.exit()
                 shutil.rmtree(self.proc_root)

-        self.num_chunks = 0
-        print("Total files:", len(self.files))
+        print("Total datasets:", len(self.files))
         print("Processing files...")
         if render_files:
             # Split audio file into chunks, resample, then apply random effects
             self.proc_root.mkdir(parents=True, exist_ok=True)
-            for audio_file in tqdm(self.files, total=len(self.files)):
-                chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
-                for chunk in chunks:
-                    resampled_chunk = torchaudio.functional.resample(
-                        chunk, orig_sr, sample_rate
-                    )
-                    if resampled_chunk.shape[-1] < chunk_size:
-                        # Skip if chunk is too small
-                        continue
-
-                    dry, wet, dry_effects, wet_effects = self.process_effects(
-                        resampled_chunk
-                    )
-                    output_dir = self.proc_root / str(self.num_chunks)
-                    output_dir.mkdir(exist_ok=True)
-                    torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
-                    torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
-                    torch.save(dry_effects, output_dir / "dry_effects.pt")
-                    torch.save(wet_effects, output_dir / "wet_effects.pt")
-                    self.num_chunks += 1
+            for num_chunk in tqdm(range(self.total_chunks)):
+                chunk = None
+                random_dataset_choice = random.choice(self.files)
+                while chunk is None:
+                    random_file_choice = random.choice(random_dataset_choice)
+                    chunk = select_random_chunk(
+                        random_file_choice, self.chunk_size, self.sample_rate
+                    )
+
+                # Sum to mono
+                if chunk.shape[0] > 1:
+                    chunk = chunk.sum(0, keepdim=True)
+
+                dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
+                output_dir = self.proc_root / str(num_chunk)
+                output_dir.mkdir(exist_ok=True)
+                torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
+                torchaudio.save(output_dir / "target.wav", dry, self.sample_rate)
+                torch.save(dry_effects, output_dir / "dry_effects.pt")
+                torch.save(wet_effects, output_dir / "wet_effects.pt")
+
+            print("Finished rendering")
         else:
-            self.num_chunks = len(list(self.proc_root.iterdir()))
-
-        print(
-            f"Found {len(self.files)} {self.mode} files .\n"
-            f"Total chunks: {self.num_chunks}"
-        )
+            self.total_chunks = len(list(self.proc_root.iterdir()))
+
+        print("Total chunks:", self.total_chunks)

     def __len__(self):
-        return self.num_chunks
+        return self.total_chunks

     def __getitem__(self, idx):
         input_file = self.proc_root / str(idx) / "input.wav"
@@ -281,7 +295,7 @@ class VocalSet(Dataset):

         # Up to max_kept_effects
         if self.max_kept_effects != -1:
-            num_kept_effects = int(torch.rand(1).item() * (self.max_kept_effects)) + 1
+            num_kept_effects = int(torch.rand(1).item() * (self.max_kept_effects))
         else:
             num_kept_effects = len(self.effects_to_keep)
         effect_indices = effect_indices[:num_kept_effects]
@@ -292,7 +306,8 @@ class VocalSet(Dataset):
         # Apply
         dry_labels = []
         for effect in effects_to_apply:
-            dry = effect(dry)
+            # Normalize in-between effects
+            dry = self.normalize(effect(dry))
            dry_labels.append(ALL_EFFECTS.index(type(effect)))

         # Apply effects_to_remove
@@ -315,7 +330,8 @@ class VocalSet(Dataset):

         wet_labels = []
         for effect in effects_to_apply:
-            wet = effect(wet)
+            # Normalize in-between effects
+            wet = self.normalize(effect(wet))
            wet_labels.append(ALL_EFFECTS.index(type(effect)))

         wet_labels_tensor = torch.zeros(len(ALL_EFFECTS))
@@ -334,7 +350,7 @@ class VocalSet(Dataset):
         return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor


-class VocalSetDatamodule(pl.LightningDataModule):
+class EffectDatamodule(pl.LightningDataModule):
     def __init__(
         self,
         train_dataset,
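The core change above replaces exhaustive sequential chunking with a fixed budget of randomly drawn chunks. As a compact, self-contained restatement of that rendering loop (the file lists and process_effects callable below are placeholders, not the real EffectDataset internals; only the sampling/retry/mono-summing logic mirrors the diff):

# Sketch of the new rendering loop in EffectDataset.__init__: draw total_chunks
# chunks by picking a random dataset, then random files from it until
# select_random_chunk returns a usable (non-silent, long-enough) chunk.
import random
import torch
from remfx.utils import select_random_chunk

def render_chunks(file_lists, total_chunks, chunk_size, sample_rate, process_effects):
    rendered = []
    for _ in range(total_chunks):
        chunk = None
        dataset_files = random.choice(file_lists)   # e.g. VocalSet vs. GuitarSet vs. DSD100
        while chunk is None:                        # retry on silent or too-short files
            chunk = select_random_chunk(
                random.choice(dataset_files), chunk_size, sample_rate
            )
        if chunk.shape[0] > 1:                      # sum to mono, as in the diff
            chunk = chunk.sum(0, keepdim=True)
        rendered.append(process_effects(chunk))
    return rendered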
remfx/dcunet.py DELETED
@@ -1,649 +0,0 @@
1
- # Adapted from https://github.com/AppleHolic/source_separation/tree/master/source_separation
2
-
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- import numpy as np
8
- from torch.nn.init import calculate_gain
9
- from typing import Tuple
10
- from scipy.signal import get_window
11
- from librosa.util import pad_center
12
- from remfx.utils import single, concat_complex
13
-
14
-
15
- class ComplexConvBlock(nn.Module):
16
- """
17
- Convolution block
18
- """
19
-
20
- def __init__(
21
- self,
22
- in_channels: int,
23
- out_channels: int,
24
- kernel_size: int,
25
- padding: int = 0,
26
- layers: int = 4,
27
- bn_func=nn.BatchNorm1d,
28
- act_func=nn.LeakyReLU,
29
- skip_res: bool = False,
30
- ):
31
- super().__init__()
32
- # modules
33
- self.blocks = nn.ModuleList()
34
- self.skip_res = skip_res
35
-
36
- for idx in range(layers):
37
- in_ = in_channels if idx == 0 else out_channels
38
- self.blocks.append(
39
- nn.Sequential(
40
- *[
41
- bn_func(in_),
42
- act_func(),
43
- ComplexConv1d(in_, out_channels, kernel_size, padding=padding),
44
- ]
45
- )
46
- )
47
-
48
- def forward(self, x: torch.tensor) -> torch.tensor:
49
- temp = x
50
- for idx, block in enumerate(self.blocks):
51
- x = block(x)
52
-
53
- if temp.size() != x.size() or self.skip_res:
54
- return x
55
- else:
56
- return x + temp
57
-
58
-
59
- class SpectrogramUnet(nn.Module):
60
- def __init__(
61
- self,
62
- spec_dim: int,
63
- hidden_dim: int,
64
- filter_len: int,
65
- hop_len: int,
66
- layers: int = 3,
67
- block_layers: int = 3,
68
- kernel_size: int = 5,
69
- is_mask: bool = False,
70
- norm: str = "bn",
71
- act: str = "tanh",
72
- ):
73
- super().__init__()
74
- self.layers = layers
75
- self.is_mask = is_mask
76
-
77
- # stft modules
78
- self.stft = STFT(filter_len, hop_len)
79
-
80
- if norm == "bn":
81
- self.bn_func = nn.BatchNorm1d
82
- elif norm == "ins":
83
- self.bn_func = lambda x: nn.InstanceNorm1d(x, affine=True)
84
- else:
85
- raise NotImplementedError("{} is not implemented !".format(norm))
86
-
87
- if act == "tanh":
88
- self.act_func = nn.Tanh
89
- self.act_out = nn.Tanh
90
- elif act == "comp":
91
- self.act_func = ComplexActLayer
92
- self.act_out = lambda: ComplexActLayer(is_out=True)
93
- else:
94
- raise NotImplementedError("{} is not implemented !".format(act))
95
-
96
- # prev conv
97
- self.prev_conv = ComplexConv1d(spec_dim * 2, hidden_dim, 1)
98
-
99
- # down
100
- self.down = nn.ModuleList()
101
- self.down_pool = nn.MaxPool1d(3, stride=2, padding=1)
102
- for idx in range(self.layers):
103
- block = ComplexConvBlock(
104
- hidden_dim,
105
- hidden_dim,
106
- kernel_size=kernel_size,
107
- padding=kernel_size // 2,
108
- bn_func=self.bn_func,
109
- act_func=self.act_func,
110
- layers=block_layers,
111
- )
112
- self.down.append(block)
113
-
114
- # up
115
- self.up = nn.ModuleList()
116
- for idx in range(self.layers):
117
- in_c = hidden_dim if idx == 0 else hidden_dim * 2
118
- self.up.append(
119
- nn.Sequential(
120
- ComplexConvBlock(
121
- in_c,
122
- hidden_dim,
123
- kernel_size=kernel_size,
124
- padding=kernel_size // 2,
125
- bn_func=self.bn_func,
126
- act_func=self.act_func,
127
- layers=block_layers,
128
- ),
129
- self.bn_func(hidden_dim),
130
- self.act_func(),
131
- ComplexTransposedConv1d(
132
- hidden_dim, hidden_dim, kernel_size=2, stride=2
133
- ),
134
- )
135
- )
136
-
137
- # out_conv
138
- self.out_conv = nn.Sequential(
139
- ComplexConvBlock(
140
- hidden_dim * 2,
141
- spec_dim * 2,
142
- kernel_size=kernel_size,
143
- padding=kernel_size // 2,
144
- bn_func=self.bn_func,
145
- act_func=self.act_func,
146
- ),
147
- self.bn_func(spec_dim * 2),
148
- self.act_func(),
149
- )
150
-
151
- # refine conv
152
- self.refine_conv = nn.Sequential(
153
- ComplexConvBlock(
154
- spec_dim * 4,
155
- spec_dim * 2,
156
- kernel_size=kernel_size,
157
- padding=kernel_size // 2,
158
- bn_func=self.bn_func,
159
- act_func=self.act_func,
160
- ),
161
- self.bn_func(spec_dim * 2),
162
- self.act_func(),
163
- )
164
-
165
- def log_stft(self, wav):
166
- # stft
167
- mag, phase = self.stft.transform(wav)
168
- return torch.log(mag + 1), phase
169
-
170
- def exp_istft(self, log_mag, phase):
171
- # exp
172
- mag = np.e**log_mag - 1
173
- # istft
174
- wav = self.stft.inverse(mag, phase)
175
- return wav
176
-
177
- def adjust_diff(self, x, target):
178
- size_diff = target.size()[-1] - x.size()[-1]
179
- assert size_diff >= 0
180
- if size_diff > 0:
181
- x = F.pad(
182
- x.unsqueeze(1), (size_diff // 2, size_diff // 2), "reflect"
183
- ).squeeze(1)
184
- return x
185
-
186
- def masking(self, mag, phase, origin_mag, origin_phase):
187
- abs_mag = torch.abs(mag)
188
- mag_mask = torch.tanh(abs_mag)
189
- phase_mask = mag / abs_mag
190
-
191
- # masking
192
- mag = mag_mask * origin_mag
193
- phase = phase_mask * (origin_phase + phase)
194
- return mag, phase
195
-
196
- def forward(self, wav):
197
- # stft
198
- origin_mag, origin_phase = self.log_stft(wav)
199
- origin_x = torch.cat([origin_mag, origin_phase], dim=1)
200
-
201
- # prev
202
- x = self.prev_conv(origin_x)
203
-
204
- # body
205
- # down
206
- down_cache = []
207
- for idx, block in enumerate(self.down):
208
- x = block(x)
209
- down_cache.append(x)
210
- x = self.down_pool(x)
211
-
212
- # up
213
- for idx, block in enumerate(self.up):
214
- x = block(x)
215
- res = F.interpolate(
216
- down_cache[self.layers - (idx + 1)],
217
- size=[x.size()[2]],
218
- mode="linear",
219
- align_corners=False,
220
- )
221
- x = concat_complex(x, res, dim=1)
222
-
223
- # match spec dimension
224
- x = self.out_conv(x)
225
- if origin_mag.size(2) != x.size(2):
226
- x = F.interpolate(
227
- x, size=[origin_mag.size(2)], mode="linear", align_corners=False
228
- )
229
-
230
- # refine
231
- x = self.refine_conv(concat_complex(x, origin_x))
232
-
233
- def to_wav(stft):
234
- mag, phase = stft.chunk(2, 1)
235
- if self.is_mask:
236
- mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
237
- out = self.exp_istft(mag, phase)
238
- out = self.adjust_diff(out, wav)
239
- return out
240
-
241
- refine_wav = to_wav(x)
242
-
243
- return refine_wav
244
-
245
-
246
- class RefineSpectrogramUnet(SpectrogramUnet):
247
- def __init__(
248
- self,
249
- spec_dim: int,
250
- hidden_dim: int,
251
- filter_len: int,
252
- hop_len: int,
253
- layers: int = 4,
254
- block_layers: int = 4,
255
- kernel_size: int = 3,
256
- is_mask: bool = True,
257
- norm: str = "ins",
258
- act: str = "comp",
259
- refine_layers: int = 1,
260
- add_spec_results: bool = False,
261
- ):
262
- super().__init__(
263
- spec_dim,
264
- hidden_dim,
265
- filter_len,
266
- hop_len,
267
- layers,
268
- block_layers,
269
- kernel_size,
270
- is_mask,
271
- norm,
272
- act,
273
- )
274
- self.add_spec_results = add_spec_results
275
- # refine conv
276
- self.refine_conv = nn.ModuleList(
277
- [
278
- nn.Sequential(
279
- ComplexConvBlock(
280
- spec_dim * 2,
281
- spec_dim * 2,
282
- kernel_size=kernel_size,
283
- padding=kernel_size // 2,
284
- bn_func=self.bn_func,
285
- act_func=self.act_func,
286
- ),
287
- self.bn_func(spec_dim * 2),
288
- self.act_func(),
289
- )
290
- ]
291
- * refine_layers
292
- )
293
-
294
- def forward(self, wav):
295
- # stft
296
- origin_mag, origin_phase = self.log_stft(wav)
297
- origin_x = torch.cat([origin_mag, origin_phase], dim=1)
298
-
299
- # prev
300
- x = self.prev_conv(origin_x)
301
-
302
- # body
303
- # down
304
- down_cache = []
305
- for idx, block in enumerate(self.down):
306
- x = block(x)
307
- down_cache.append(x)
308
- x = self.down_pool(x)
309
-
310
- # up
311
- for idx, block in enumerate(self.up):
312
- x = block(x)
313
- res = F.interpolate(
314
- down_cache[self.layers - (idx + 1)],
315
- size=[x.size()[2]],
316
- mode="linear",
317
- align_corners=False,
318
- )
319
- x = concat_complex(x, res, dim=1)
320
-
321
- # match spec dimension
322
- x = self.out_conv(x)
323
- if origin_mag.size(2) != x.size(2):
324
- x = F.interpolate(
325
- x, size=[origin_mag.size(2)], mode="linear", align_corners=False
326
- )
327
-
328
- # refine
329
- for idx, refine_module in enumerate(self.refine_conv):
330
- x = refine_module(x)
331
- mag, phase = x.chunk(2, 1)
332
- mag, phase = self.masking(mag, phase, origin_mag, origin_phase)
333
- if idx < len(self.refine_conv) - 1:
334
- x = torch.cat([mag, phase], dim=1)
335
-
336
- # clamp phase
337
- phase = phase.clamp(-np.pi, np.pi)
338
-
339
- out = self.exp_istft(mag, phase)
340
- out = self.adjust_diff(out, wav)
341
-
342
- if self.add_spec_results:
343
- out = (out, mag, phase)
344
-
345
- return out
346
-
347
-
348
- class _ComplexConvNd(nn.Module):
349
- """
350
- Implement Complex Convolution
351
- A: real weight
352
- B: img weight
353
- """
354
-
355
- def __init__(
356
- self,
357
- in_channels,
358
- out_channels,
359
- kernel_size,
360
- stride,
361
- padding,
362
- dilation,
363
- transposed,
364
- output_padding,
365
- ):
366
- super().__init__()
367
- self.in_channels = in_channels
368
- self.out_channels = out_channels
369
- self.kernel_size = kernel_size
370
- self.stride = stride
371
- self.padding = padding
372
- self.dilation = dilation
373
- self.output_padding = output_padding
374
- self.transposed = transposed
375
-
376
- self.A = self.make_weight(in_channels, out_channels, kernel_size)
377
- self.B = self.make_weight(in_channels, out_channels, kernel_size)
378
-
379
- self.reset_parameters()
380
-
381
- def make_weight(self, in_ch, out_ch, kernel_size):
382
- if self.transposed:
383
- tensor = nn.Parameter(torch.Tensor(in_ch, out_ch // 2, *kernel_size))
384
- else:
385
- tensor = nn.Parameter(torch.Tensor(out_ch, in_ch // 2, *kernel_size))
386
- return tensor
387
-
388
- def reset_parameters(self):
389
- # init real weight
390
- fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.A)
391
-
392
- # init A
393
- gain = calculate_gain("leaky_relu", 0)
394
- std = gain / np.sqrt(fan_in)
395
- bound = np.sqrt(3.0) * std
396
-
397
- with torch.no_grad():
398
- # TODO: find more stable initial values
399
- self.A.uniform_(-bound * (1 / (np.pi**2)), bound * (1 / (np.pi**2)))
400
- #
401
- # B is initialized by pi
402
- # -pi and pi is too big, so it is powed by -1
403
- self.B.uniform_(-1 / np.pi, 1 / np.pi)
404
-
405
-
406
- class ComplexConv1d(_ComplexConvNd):
407
- """
408
- Complex Convolution 1d
409
- """
410
-
411
- def __init__(
412
- self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1
413
- ):
414
- kernel_size = single(kernel_size)
415
- stride = single(stride)
416
- # edit padding
417
- padding = padding
418
- dilation = single(dilation)
419
- super(ComplexConv1d, self).__init__(
420
- in_channels,
421
- out_channels,
422
- kernel_size,
423
- stride,
424
- padding,
425
- dilation,
426
- False,
427
- single(0),
428
- )
429
-
430
- def forward(self, x):
431
- """
432
- Implemented complex convolution using combining 'grouped convolution' and
433
- 'real / img weight'
434
- :param x: data (N, C, T) C is concatenated with C/2 real channels and C/2 idea channels
435
- :return: complex conved result
436
- """
437
- # adopt reflect padding
438
- if self.padding:
439
- x = F.pad(x, (self.padding, self.padding), "reflect")
440
-
441
- # forward real
442
- real_part = F.conv1d(
443
- x,
444
- self.A,
445
- None,
446
- stride=self.stride,
447
- padding=0,
448
- dilation=self.dilation,
449
- groups=2,
450
- )
451
-
452
- # forward idea
453
- spl = self.in_channels // 2
454
- weight_B = torch.cat([self.B[:spl].data * (-1), self.B[spl:].data])
455
- idea_part = F.conv1d(
456
- x,
457
- weight_B,
458
- None,
459
- stride=self.stride,
460
- padding=0,
461
- dilation=self.dilation,
462
- groups=2,
463
- )
464
-
465
- return real_part + idea_part
466
-
467
-
468
- class ComplexTransposedConv1d(_ComplexConvNd):
469
- """
470
- Complex Transposed Convolution 1d
471
- """
472
-
473
- def __init__(
474
- self,
475
- in_channels,
476
- out_channels,
477
- kernel_size,
478
- stride=1,
479
- padding=0,
480
- output_padding=0,
481
- dilation=1,
482
- ):
483
- kernel_size = single(kernel_size)
484
- stride = single(stride)
485
- padding = padding
486
- dilation = single(dilation)
487
- super().__init__(
488
- in_channels,
489
- out_channels,
490
- kernel_size,
491
- stride,
492
- padding,
493
- dilation,
494
- True,
495
- output_padding,
496
- )
497
-
498
- def forward(self, x, output_size=None):
499
- """
500
- Implemented complex transposed convolution using combining 'grouped convolution'
501
- and 'real / img weight'
502
- :param x: data (N, C, T) C is concatenated with C/2 real channels and C/2 idea channels
503
- :return: complex transposed convolution result
504
- """
505
- # forward real
506
- if self.padding:
507
- x = F.pad(x, (self.padding, self.padding), "reflect")
508
-
509
- real_part = F.conv_transpose1d(
510
- x,
511
- self.A,
512
- None,
513
- stride=self.stride,
514
- padding=0,
515
- dilation=self.dilation,
516
- groups=2,
517
- )
518
-
519
- # forward idea
520
- spl = self.out_channels // 2
521
- weight_B = torch.cat([self.B[:spl] * (-1), self.B[spl:]])
522
- idea_part = F.conv_transpose1d(
523
- x,
524
- weight_B,
525
- None,
526
- stride=self.stride,
527
- padding=0,
528
- dilation=self.dilation,
529
- groups=2,
530
- )
531
-
532
- if self.output_padding:
533
- real_part = F.pad(
534
- real_part, (self.output_padding, self.output_padding), "reflect"
535
- )
536
- idea_part = F.pad(
537
- idea_part, (self.output_padding, self.output_padding), "reflect"
538
- )
539
-
540
- return real_part + idea_part
541
-
542
-
543
- class ComplexActLayer(nn.Module):
544
- """
545
- Activation differently 'real' part and 'img' part
546
- In implemented DCUnet on this repository, Real part is activated to log space.
547
- And Phase(img) part, it is distributed in [-pi, pi]...
548
- """
549
-
550
- def forward(self, x):
551
- real, img = x.chunk(2, 1)
552
- return torch.cat([F.leaky_relu(real), torch.tanh(img) * np.pi], dim=1)
553
-
554
-
555
- class STFT(nn.Module):
556
- """
557
- Re-construct stft for calculating backward operation
558
- refer on : https://github.com/pseeth/torch-stft/blob/master/torch_stft/stft.py
559
- """
560
-
561
- def __init__(
562
- self,
563
- filter_length: int = 1024,
564
- hop_length: int = 512,
565
- win_length: int = None,
566
- window: str = "hann",
567
- ):
568
- super().__init__()
569
- self.filter_length = filter_length
570
- self.hop_length = hop_length
571
- self.win_length = win_length if win_length else filter_length
572
- self.window = window
573
- self.pad_amount = self.filter_length // 2
574
-
575
- # make fft window
576
- assert filter_length >= self.win_length
577
- # get window and zero center pad it to filter_length
578
- fft_window = get_window(window, self.win_length, fftbins=True)
579
- fft_window = pad_center(fft_window, filter_length)
580
- fft_window = torch.from_numpy(fft_window).float()
581
-
582
- # calculate fourer_basis
583
- cut_off = int((self.filter_length / 2 + 1))
584
- fourier_basis = np.fft.fft(np.eye(self.filter_length))
585
- fourier_basis = np.vstack(
586
- [np.real(fourier_basis[:cut_off, :]), np.imag(fourier_basis[:cut_off, :])]
587
- )
588
-
589
- # make forward & inverse basis
590
- self.register_buffer("square_window", fft_window**2)
591
-
592
- forward_basis = torch.FloatTensor(fourier_basis[:, np.newaxis, :]) * fft_window
593
- inverse_basis = (
594
- torch.FloatTensor(
595
- np.linalg.pinv(self.filter_length / self.hop_length * fourier_basis).T[
596
- :, np.newaxis, :
597
- ]
598
- )
599
- * fft_window
600
- )
601
- # torch.pinverse has a bug, so at this time, it is separated into two parts..
602
- self.register_buffer("forward_basis", forward_basis)
603
- self.register_buffer("inverse_basis", inverse_basis)
604
-
605
- def transform(self, wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
606
- # reflect padding
607
- wav = wav.unsqueeze(1).unsqueeze(1)
608
- wav = F.pad(
609
- wav, (self.pad_amount, self.pad_amount, 0, 0), mode="reflect"
610
- ).squeeze(1)
611
-
612
- # conv
613
- forward_trans = F.conv1d(
614
- wav, self.forward_basis, stride=self.hop_length, padding=0
615
- )
616
- real_part, imag_part = forward_trans.chunk(2, 1)
617
-
618
- return torch.sqrt(real_part**2 + imag_part**2), torch.atan2(
619
- imag_part.data, real_part.data
620
- )
621
-
622
- def inverse(
623
- self, magnitude: torch.Tensor, phase: torch.Tensor, eps: float = 1e-9
624
- ) -> torch.Tensor:
625
- comp = torch.cat(
626
- [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
627
- )
628
- inverse_transform = F.conv_transpose1d(
629
- comp, self.inverse_basis, stride=self.hop_length, padding=0
630
- )
631
-
632
- # remove window effect
633
- n_frames = comp.size(-1)
634
- inverse_size = inverse_transform.size(-1)
635
-
636
- window_filter = torch.ones(1, 1, n_frames).type_as(inverse_transform)
637
-
638
- weight = self.square_window[: self.filter_length].unsqueeze(0).unsqueeze(0)
639
- window_filter = F.conv_transpose1d(
640
- window_filter, weight, stride=self.hop_length, padding=0
641
- )
642
- window_filter = window_filter.squeeze()[:inverse_size] + eps
643
-
644
- inverse_transform /= window_filter
645
-
646
- # scale by hop ratio
647
- inverse_transform *= self.filter_length / self.hop_length
648
-
649
- return inverse_transform[..., self.pad_amount : -self.pad_amount].squeeze(1)
remfx/dptnet.py DELETED
@@ -1,459 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch.nn.modules.container import ModuleList
5
- from torch.nn.modules.activation import MultiheadAttention
6
- from torch.nn.modules.dropout import Dropout
7
- from torch.nn.modules.linear import Linear
8
- from torch.nn.modules.rnn import LSTM
9
- from torch.nn.modules.normalization import LayerNorm
10
- from torch.autograd import Variable
11
- import copy
12
- import math
13
-
14
-
15
- # adapted from https://github.com/ujscjj/DPTNet
16
-
17
-
18
- class DPTNet_base(nn.Module):
19
- def __init__(
20
- self,
21
- enc_dim,
22
- feature_dim,
23
- hidden_dim,
24
- layer,
25
- segment_size=250,
26
- nspk=2,
27
- win_len=2,
28
- ):
29
- super().__init__()
30
- # parameters
31
- self.window = win_len
32
- self.stride = self.window // 2
33
-
34
- self.enc_dim = enc_dim
35
- self.feature_dim = feature_dim
36
- self.hidden_dim = hidden_dim
37
- self.segment_size = segment_size
38
-
39
- self.layer = layer
40
- self.num_spk = nspk
41
- self.eps = 1e-8
42
-
43
- self.dpt_encoder = DPTEncoder(
44
- n_filters=enc_dim,
45
- window_size=win_len,
46
- )
47
- self.enc_LN = nn.GroupNorm(1, self.enc_dim, eps=1e-8)
48
- self.dpt_separation = DPTSeparation(
49
- self.enc_dim,
50
- self.feature_dim,
51
- self.hidden_dim,
52
- self.num_spk,
53
- self.layer,
54
- self.segment_size,
55
- )
56
-
57
- self.mask_conv1x1 = nn.Conv1d(self.feature_dim, self.enc_dim, 1, bias=False)
58
- self.decoder = DPTDecoder(n_filters=enc_dim, window_size=win_len)
59
-
60
- def forward(self, mix):
61
- """
62
- mix: shape (batch, T)
63
- """
64
- batch_size = mix.shape[0]
65
- mix = self.dpt_encoder(mix) # (B, E, L)
66
-
67
- score_ = self.enc_LN(mix) # B, E, L
68
- score_ = self.dpt_separation(score_) # B, nspk, T, N
69
- score_ = (
70
- score_.view(batch_size * self.num_spk, -1, self.feature_dim)
71
- .transpose(1, 2)
72
- .contiguous()
73
- ) # B*nspk, N, T
74
- score = self.mask_conv1x1(score_) # [B*nspk, N, L] -> [B*nspk, E, L]
75
- score = score.view(
76
- batch_size, self.num_spk, self.enc_dim, -1
77
- ) # [B*nspk, E, L] -> [B, nspk, E, L]
78
- est_mask = F.relu(score)
79
-
80
- est_source = self.decoder(
81
- mix, est_mask
82
- ) # [B, E, L] + [B, nspk, E, L]--> [B, nspk, T]
83
-
84
- return est_source
85
-
86
-
87
- class DPTEncoder(nn.Module):
88
- def __init__(self, n_filters: int = 64, window_size: int = 2):
89
- super().__init__()
90
- self.conv = nn.Conv1d(
91
- 1, n_filters, kernel_size=window_size, stride=window_size // 2, bias=False
92
- )
93
-
94
- def forward(self, x):
95
- x = x.unsqueeze(1)
96
- x = F.relu(self.conv(x))
97
- return x
98
-
99
-
100
- class TransformerEncoderLayer(torch.nn.Module):
101
- def __init__(
102
- self, d_model, nhead, hidden_size, dim_feedforward, dropout, activation="relu"
103
- ):
104
- super(TransformerEncoderLayer, self).__init__()
105
- self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
106
-
107
- # Implementation of improved part
108
- self.lstm = LSTM(d_model, hidden_size, 1, bidirectional=True)
109
- self.dropout = Dropout(dropout)
110
- self.linear = Linear(hidden_size * 2, d_model)
111
-
112
- self.norm1 = LayerNorm(d_model)
113
- self.norm2 = LayerNorm(d_model)
114
- self.dropout1 = Dropout(dropout)
115
- self.dropout2 = Dropout(dropout)
116
-
117
- self.activation = _get_activation_fn(activation)
118
-
119
- def __setstate__(self, state):
120
- if "activation" not in state:
121
- state["activation"] = F.relu
122
- super(TransformerEncoderLayer, self).__setstate__(state)
123
-
124
- def forward(self, src, src_mask=None, src_key_padding_mask=None):
125
- r"""Pass the input through the encoder layer.
126
- Args:
127
- src: the sequnce to the encoder layer (required).
128
- src_mask: the mask for the src sequence (optional).
129
- src_key_padding_mask: the mask for the src keys per batch (optional).
130
- Shape:
131
- see the docs in Transformer class.
132
- """
133
- src2 = self.self_attn(
134
- src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
135
- )[0]
136
- src = src + self.dropout1(src2)
137
- src = self.norm1(src)
138
- src2 = self.linear(self.dropout(self.activation(self.lstm(src)[0])))
139
- src = src + self.dropout2(src2)
140
- src = self.norm2(src)
141
- return src
142
-
143
-
144
- def _get_clones(module, N):
145
- return ModuleList([copy.deepcopy(module) for i in range(N)])
146
-
147
-
148
- def _get_activation_fn(activation):
149
- if activation == "relu":
150
- return F.relu
151
- elif activation == "gelu":
152
- return F.gelu
153
-
154
- raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
155
-
156
-
157
- class SingleTransformer(nn.Module):
158
- """
159
- Container module for a single Transformer layer.
160
- args: input_size: int, dimension of the input feature.
161
- The input should have shape (batch, seq_len, input_size).
162
- """
163
-
164
- def __init__(self, input_size, hidden_size, dropout):
165
- super(SingleTransformer, self).__init__()
166
- self.transformer = TransformerEncoderLayer(
167
- d_model=input_size,
168
- nhead=4,
169
- hidden_size=hidden_size,
170
- dim_feedforward=hidden_size * 2,
171
- dropout=dropout,
172
- )
173
-
174
- def forward(self, input):
175
- # input shape: batch, seq, dim
176
- output = input
177
- transformer_output = (
178
- self.transformer(output.permute(1, 0, 2).contiguous())
179
- .permute(1, 0, 2)
180
- .contiguous()
181
- )
182
- return transformer_output
183
-
184
-
185
- # dual-path transformer
186
- class DPT(nn.Module):
187
- """
188
- Deep dual-path transformer.
189
- args:
190
- input_size: int, dimension of the input feature. The input should have shape
191
- (batch, seq_len, input_size).
192
- hidden_size: int, dimension of the hidden state.
193
- output_size: int, dimension of the output size.
194
- num_layers: int, number of stacked Transformer layers. Default is 1.
195
- dropout: float, dropout ratio. Default is 0.
196
- """
197
-
198
- def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0):
199
- super(DPT, self).__init__()
200
-
201
- self.input_size = input_size
202
- self.output_size = output_size
203
- self.hidden_size = hidden_size
204
-
205
- # dual-path transformer
206
- self.row_transformer = nn.ModuleList([])
207
- self.col_transformer = nn.ModuleList([])
208
- for i in range(num_layers):
209
- self.row_transformer.append(
210
- SingleTransformer(input_size, hidden_size, dropout)
211
- )
212
- self.col_transformer.append(
213
- SingleTransformer(input_size, hidden_size, dropout)
214
- )
215
-
216
- # output layer
217
- self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1))
218
-
219
- def forward(self, input):
220
- # input shape: batch, N, dim1, dim2
221
- # apply transformer on dim1 first and then dim2
222
- # output shape: B, output_size, dim1, dim2
223
- # input = input.to(device)
224
- batch_size, _, dim1, dim2 = input.shape
225
- output = input
226
- for i in range(len(self.row_transformer)):
227
- row_input = (
228
- output.permute(0, 3, 2, 1)
229
- .contiguous()
230
- .view(batch_size * dim2, dim1, -1)
231
- ) # B*dim2, dim1, N
232
- row_output = self.row_transformer[i](row_input) # B*dim2, dim1, H
233
- row_output = (
234
- row_output.view(batch_size, dim2, dim1, -1)
235
- .permute(0, 3, 2, 1)
236
- .contiguous()
237
- ) # B, N, dim1, dim2
238
- output = row_output
239
-
240
- col_input = (
241
- output.permute(0, 2, 3, 1)
242
- .contiguous()
243
- .view(batch_size * dim1, dim2, -1)
244
- ) # B*dim1, dim2, N
245
- col_output = self.col_transformer[i](col_input) # B*dim1, dim2, H
246
- col_output = (
247
- col_output.view(batch_size, dim1, dim2, -1)
248
- .permute(0, 3, 1, 2)
249
- .contiguous()
250
- ) # B, N, dim1, dim2
251
- output = col_output
252
-
253
- output = self.output(output) # B, output_size, dim1, dim2
254
-
255
- return output
256
-
257
-
258
- # base module for deep DPT
259
- class DPT_base(nn.Module):
260
- def __init__(
261
- self, input_dim, feature_dim, hidden_dim, num_spk=2, layer=6, segment_size=250
262
- ):
263
- super(DPT_base, self).__init__()
264
-
265
- self.input_dim = input_dim
266
- self.feature_dim = feature_dim
267
- self.hidden_dim = hidden_dim
268
-
269
- self.layer = layer
270
- self.segment_size = segment_size
271
- self.num_spk = num_spk
272
-
273
- self.eps = 1e-8
274
-
275
- # bottleneck
276
- self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, bias=False)
277
-
278
- # DPT model
279
- self.DPT = DPT(
280
- self.feature_dim,
281
- self.hidden_dim,
282
- self.feature_dim * self.num_spk,
283
- num_layers=layer,
284
- )
285
-
286
- def pad_segment(self, input, segment_size):
287
- # input is the features: (B, N, T)
288
- batch_size, dim, seq_len = input.shape
289
- segment_stride = segment_size // 2
290
-
291
- rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
292
- if rest > 0:
293
- pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
294
- input = torch.cat([input, pad], 2)
295
-
296
- pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(
297
- input.type()
298
- )
299
- input = torch.cat([pad_aux, input, pad_aux], 2)
300
-
301
- return input, rest
302
-
303
- def split_feature(self, input, segment_size):
304
- # split the feature into chunks of segment size
305
- # input is the features: (B, N, T)
306
-
307
- input, rest = self.pad_segment(input, segment_size)
308
- batch_size, dim, seq_len = input.shape
309
- segment_stride = segment_size // 2
310
-
311
- segments1 = (
312
- input[:, :, :-segment_stride]
313
- .contiguous()
314
- .view(batch_size, dim, -1, segment_size)
315
- )
316
- segments2 = (
317
- input[:, :, segment_stride:]
318
- .contiguous()
319
- .view(batch_size, dim, -1, segment_size)
320
- )
321
- segments = (
322
- torch.cat([segments1, segments2], 3)
323
- .view(batch_size, dim, -1, segment_size)
324
- .transpose(2, 3)
325
- )
326
-
327
- return segments.contiguous(), rest
328
-
329
- def merge_feature(self, input, rest):
330
- # merge the splitted features into full utterance
331
- # input is the features: (B, N, L, K)
332
-
333
- batch_size, dim, segment_size, _ = input.shape
334
- segment_stride = segment_size // 2
335
- input = (
336
- input.transpose(2, 3)
337
- .contiguous()
338
- .view(batch_size, dim, -1, segment_size * 2)
339
- ) # B, N, K, L
340
-
341
- input1 = (
342
- input[:, :, :, :segment_size]
343
- .contiguous()
344
- .view(batch_size, dim, -1)[:, :, segment_stride:]
345
- )
346
- input2 = (
347
- input[:, :, :, segment_size:]
348
- .contiguous()
349
- .view(batch_size, dim, -1)[:, :, :-segment_stride]
350
- )
351
-
352
- output = input1 + input2
353
- if rest > 0:
354
- output = output[:, :, :-rest]
355
-
356
- return output.contiguous() # B, N, T
357
-
358
- def forward(self, input):
359
- pass
360
-
361
-
362
- class DPTSeparation(DPT_base):
363
- def __init__(self, *args, **kwargs):
364
- super(DPTSeparation, self).__init__(*args, **kwargs)
365
-
366
- # gated output layer
367
- self.output = nn.Sequential(
368
- nn.Conv1d(self.feature_dim, self.feature_dim, 1), nn.Tanh()
369
- )
370
- self.output_gate = nn.Sequential(
371
- nn.Conv1d(self.feature_dim, self.feature_dim, 1), nn.Sigmoid()
372
- )
373
-
374
- def forward(self, input):
375
- # input = input.to(device)
376
- # input: (B, E, T)
377
- batch_size, E, seq_length = input.shape
378
-
379
- enc_feature = self.BN(input) # (B, E, L)-->(B, N, L)
380
- # split the encoder output into overlapped, longer segments
381
- enc_segments, enc_rest = self.split_feature(
382
- enc_feature, self.segment_size
383
- ) # B, N, L, K: L is the segment_size
384
- # print('enc_segments.shape {}'.format(enc_segments.shape))
385
- # pass to DPT
386
- output = self.DPT(enc_segments).view(
387
- batch_size * self.num_spk, self.feature_dim, self.segment_size, -1
388
- ) # B*nspk, N, L, K
389
-
390
- # overlap-and-add of the outputs
391
- output = self.merge_feature(output, enc_rest) # B*nspk, N, T
392
-
393
- # gated output layer for filter generation
394
- bf_filter = self.output(output) * self.output_gate(output) # B*nspk, K, T
395
- bf_filter = (
396
- bf_filter.transpose(1, 2)
397
- .contiguous()
398
- .view(batch_size, self.num_spk, -1, self.feature_dim)
399
- ) # B, nspk, T, N
400
-
401
- return bf_filter
402
-
403
-
404
- class DPTDecoder(nn.Module):
405
- def __init__(self, n_filters: int = 64, window_size: int = 2):
406
- super().__init__()
407
- self.W = window_size
408
- self.basis_signals = nn.Linear(n_filters, window_size, bias=False)
409
-
410
- def forward(self, mixture, mask):
411
- """
412
- mixture: (batch, n_filters, L)
413
- mask: (batch, sources, n_filters, L)
414
- """
415
- source_w = torch.unsqueeze(mixture, 1) * mask # [B, C, E, L]
416
- source_w = torch.transpose(source_w, 2, 3) # [B, C, L, E]
417
- # S = DV
418
- est_source = self.basis_signals(source_w) # [B, C, L, W]
419
- est_source = overlap_and_add(est_source, self.W // 2) # B x C x T
420
- return est_source
421
-
422
-
423
- def overlap_and_add(signal, frame_step):
424
- """Reconstructs a signal from a framed representation.
425
- Adds potentially overlapping frames of a signal with shape
426
- `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
427
- The resulting tensor has shape `[..., output_size]` where
428
- output_size = (frames - 1) * frame_step + frame_length
429
- Args:
430
- signal: A [..., frames, frame_length] Tensor.
431
- All dimensions may be unknown, and rank must be at least 2.
432
- frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
433
- Returns:
434
- A Tensor with shape [..., output_size] containing the overlap-added frames of signal's
435
- inner-most two dimensions.
436
- output_size = (frames - 1) * frame_step + frame_length
437
- Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
438
- """
439
- outer_dimensions = signal.size()[:-2]
440
- frames, frame_length = signal.size()[-2:]
441
-
442
- subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
443
- subframe_step = frame_step // subframe_length
444
- subframes_per_frame = frame_length // subframe_length
445
- output_size = frame_step * (frames - 1) + frame_length
446
- output_subframes = output_size // subframe_length
447
-
448
- subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)
449
-
450
- frame = torch.arange(0, output_subframes).unfold(
451
- 0, subframes_per_frame, subframe_step
452
- )
453
- frame = signal.new_tensor(frame).long() # signal may in GPU or CPU
454
- frame = frame.contiguous().view(-1)
455
-
456
- result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
457
- result.index_add_(-2, frame, subframe_signal)
458
- result = result.view(*outer_dimensions, -1)
459
- return result
remfx/models.py CHANGED
@@ -2,7 +2,6 @@ import torch
2
  import torchmetrics
3
  import pytorch_lightning as pl
4
  from torch import Tensor, nn
5
- from torch.nn import functional as F
6
  from torchaudio.models import HDemucs
7
  from audio_diffusion_pytorch import DiffusionModel
8
  from auraloss.time import SISDRLoss
@@ -14,6 +13,7 @@ from remfx.dptnet import DPTNet_base
14
  from remfx.dcunet import RefineSpectrogramUnet
15
  from remfx.tcn import TCN
16
  from remfx.utils import causal_crop
 
17
 
18
 
19
  class RemFX(pl.LightningModule):
@@ -85,6 +85,9 @@ class RemFX(pl.LightningModule):
85
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
86
 
87
  loss, output = self.model((x, y))
 
 
 
88
  self.log(f"{mode}_loss", loss)
89
  # Metric logging
90
  with torch.no_grad():
@@ -195,7 +198,7 @@ class DiffusionGenerationModel(nn.Module):
195
  class DPTNetModel(nn.Module):
196
  def __init__(self, sample_rate, num_bins, **kwargs):
197
  super().__init__()
198
- self.model = DPTNet_base(**kwargs)
199
  self.num_bins = num_bins
200
  self.mrstftloss = MultiResolutionSTFTLoss(
201
  n_bins=self.num_bins, sample_rate=sample_rate
@@ -215,7 +218,7 @@ class DPTNetModel(nn.Module):
215
  class DCUNetModel(nn.Module):
216
  def __init__(self, sample_rate, num_bins, **kwargs):
217
  super().__init__()
218
- self.model = RefineSpectrogramUnet(**kwargs)
219
  self.mrstftloss = MultiResolutionSTFTLoss(
220
  n_bins=num_bins, sample_rate=sample_rate
221
  )
@@ -223,7 +226,7 @@ class DCUNetModel(nn.Module):
223
 
224
  def forward(self, batch):
225
  x, target = batch
226
- output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
227
  # Crop target to match output
228
  if output.shape[-1] < target.shape[-1]:
229
  target = causal_crop(target, output.shape[-1])
@@ -231,7 +234,7 @@ class DCUNetModel(nn.Module):
231
  return loss, output
232
 
233
  def sample(self, x: Tensor) -> Tensor:
234
- output = self.model(x.squeeze(1)).unsqueeze(1) # B x 1 x T
235
  return output
236
 
237
 
 
2
  import torchmetrics
3
  import pytorch_lightning as pl
4
  from torch import Tensor, nn
 
5
  from torchaudio.models import HDemucs
6
  from audio_diffusion_pytorch import DiffusionModel
7
  from auraloss.time import SISDRLoss
 
13
  from remfx.dcunet import RefineSpectrogramUnet
14
  from remfx.tcn import TCN
15
  from remfx.utils import causal_crop
16
+ import asteroid
17
 
18
 
19
  class RemFX(pl.LightningModule):
 
85
  x, y, _, _ = batch # x, y = (B, C, T), (B, C, T)
86
 
87
  loss, output = self.model((x, y))
88
+ # Crop target to match output
89
+ if output.shape[-1] < y.shape[-1]:
90
+ y = causal_crop(y, output.shape[-1])
91
  self.log(f"{mode}_loss", loss)
92
  # Metric logging
93
  with torch.no_grad():
 
198
  class DPTNetModel(nn.Module):
199
  def __init__(self, sample_rate, num_bins, **kwargs):
200
  super().__init__()
201
+ self.model = asteroid.models.dptnet.DPTNet(**kwargs)
202
  self.num_bins = num_bins
203
  self.mrstftloss = MultiResolutionSTFTLoss(
204
  n_bins=self.num_bins, sample_rate=sample_rate
 
218
  class DCUNetModel(nn.Module):
219
  def __init__(self, sample_rate, num_bins, **kwargs):
220
  super().__init__()
221
+ self.model = asteroid.models.DCUNet(**kwargs)
222
  self.mrstftloss = MultiResolutionSTFTLoss(
223
  n_bins=num_bins, sample_rate=sample_rate
224
  )
 
226
 
227
  def forward(self, batch):
228
  x, target = batch
229
+ output = self.model(x.squeeze(1)) # B x T
230
  # Crop target to match output
231
  if output.shape[-1] < target.shape[-1]:
232
  target = causal_crop(target, output.shape[-1])
 
234
  return loss, output
235
 
236
  def sample(self, x: Tensor) -> Tensor:
237
+ output = self.model(x.squeeze(1)) # B x T
238
  return output
239
 
240
 
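The remfx/models.py changes above move DPTNetModel and DCUNetModel onto asteroid's implementations (asteroid.models.dptnet.DPTNet and asteroid.models.DCUNet, configured via the Hydra **kwargs) and crop the target whenever a model returns fewer samples than it was fed. Below is a minimal sketch of that crop-to-match pattern only; the causal_crop here is a plausible stand-in that keeps the trailing samples and may differ from the one in remfx.utils, and the tensor shapes are made up for illustration.

import torch

def causal_crop(x: torch.Tensor, length: int) -> torch.Tensor:
    # Plausible stand-in: keep the last `length` samples along the time axis.
    return x[..., -length:]

batch, channels, samples = 4, 1, 48000
output = torch.randn(batch, channels, samples - 512)  # hypothetical shorter model output
y = torch.randn(batch, channels, samples)             # ground-truth target

# Crop target to match output, mirroring the new lines in the RemFX step
# and in DCUNetModel.forward above.
if output.shape[-1] < y.shape[-1]:
    y = causal_crop(y, output.shape[-1])
assert y.shape == output.shape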
remfx/tcn.py CHANGED
@@ -128,10 +128,7 @@ class TCN(nn.Module):
128
  x_in = x
129
  for _, block in enumerate(self.process_blocks):
130
  x = block(x)
131
- # y_hat = torch.tanh(self.output(x))
132
- x_in = causal_crop(x_in, x.shape[-1])
133
- gain_ln = self.output(x)
134
- y_hat = torch.tanh(gain_ln * x_in)
135
  return y_hat
136
 
137
  def compute_receptive_field(self):
 
128
  x_in = x
129
  for _, block in enumerate(self.process_blocks):
130
  x = block(x)
131
+ y_hat = torch.tanh(self.output(x))
 
 
 
132
  return y_hat
133
 
134
  def compute_receptive_field(self):
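The remfx/tcn.py hunk above drops the learned-gain head (tanh of gain_ln times the cropped input) in favor of a direct tanh on the output layer. A rough, self-contained sketch of the new forward tail follows; the Conv1d output layer and its channel count are assumptions, since their definitions sit outside this hunk.

import torch
from torch import nn

class TinyTCNHead(nn.Module):
    # Simplified stand-in for the end of TCN.forward after this change.
    def __init__(self, hidden_channels: int = 32):
        super().__init__()
        self.output = nn.Conv1d(hidden_channels, 1, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, hidden_channels, T) features from the processing blocks
        return torch.tanh(self.output(x))  # direct bounded output, no input gain

print(TinyTCNHead()(torch.randn(2, 32, 1024)).shape)  # torch.Size([2, 1, 1024])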
remfx/utils.py CHANGED
@@ -127,10 +127,10 @@ def create_random_chunks(
127
 
128
 
129
  def create_sequential_chunks(
130
- audio_file: str, chunk_size: int
131
- ) -> Tuple[List[Tuple[int, int]], int]:
132
- """Create sequential chunks of size chunk_size (seconds) from an audio file.
133
- Return sample_index of start of each chunk and original sr
134
  """
135
  chunks = []
136
  audio, sr = torchaudio.load(audio_file)
@@ -138,8 +138,31 @@ def create_sequential_chunks(
138
  for start in chunk_starts:
139
  if start + chunk_size > audio.shape[-1]:
140
  break
141
- chunks.append(audio[:, start : start + chunk_size])
142
- return chunks, sr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
 
145
  def spectrogram(
 
127
 
128
 
129
  def create_sequential_chunks(
130
+ audio_file: str, chunk_size: int, sample_rate: int
131
+ ) -> List[torch.Tensor]:
132
+ """Create sequential chunks of size chunk_size from an audio file.
133
+ Return each chunk
134
  """
135
  chunks = []
136
  audio, sr = torchaudio.load(audio_file)
 
138
  for start in chunk_starts:
139
  if start + chunk_size > audio.shape[-1]:
140
  break
141
+ chunk = audio[:, start : start + chunk_size]
142
+ resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
143
+ # Skip chunks that are too short
144
+ if resampled_chunk.shape[-1] < chunk_size:
145
+ continue
146
+ chunks.append(chunk)
147
+ return chunks
148
+
149
+
150
+ def select_random_chunk(
151
+ audio_file: str, chunk_size: int, sample_rate: int
152
+ ) -> List[torch.Tensor]:
153
+ """Select random chunk of size chunk_size (samples) from an audio file."""
154
+ audio, sr = torchaudio.load(audio_file)
155
+ new_chunk_size = int(chunk_size * (sr / sample_rate))
156
+ if new_chunk_size >= audio.shape[-1]:
157
+ return None
158
+ max_len = audio.shape[-1] - new_chunk_size
159
+ random_start = torch.randint(0, max_len, (1,)).item()
160
+ chunk = audio[:, random_start : random_start + new_chunk_size]
161
+ # Skip if energy too low
162
+ if torch.mean(torch.abs(chunk)) < 1e-6:
163
+ return None
164
+ resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)
165
+ return resampled_chunk
166
 
167
 
168
  def spectrogram(
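remfx/utils.py now returns ready-to-use tensors from its chunk helpers: create_sequential_chunks yields fixed-size chunks, and the new select_random_chunk returns a single resampled chunk, or None when the file is too short or the pick is near-silent. A hedged usage sketch for the random variant; the path and chunk length are placeholders, and the retry cap is only there to avoid looping forever on an unsuitable file.

from remfx.utils import select_random_chunk

sample_rate = 48000                  # target rate after resampling
chunk_size = 262144                  # chunk length in samples at the target rate (placeholder)
audio_file = "path/to/clip.wav"      # placeholder path

chunk = None
for _ in range(100):                 # helper returns None for short or low-energy picks
    chunk = select_random_chunk(audio_file, chunk_size, sample_rate)
    if chunk is not None:
        break
if chunk is not None:
    print(chunk.shape)               # (channels, ~chunk_size) at sample_rate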
scripts/download.py CHANGED
@@ -1,8 +1,6 @@
1
  import os
2
- import sys
3
- import glob
4
- import torch
5
  import argparse
 
6
 
7
 
8
  def download_zip_dataset(dataset_url: str, output_dir: str):
@@ -26,8 +24,42 @@ def process_dataset(dataset_dir: str, output_dir: str):
26
  pass
27
  elif dataset_dir == "IDMT-SMT-DRUMS-V2":
28
  pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  else:
30
- raise NotImplemented(f"Invalid dataset_dir = {dataset_dir}.")
31
 
32
 
33
  if __name__ == "__main__":
@@ -38,7 +70,7 @@ if __name__ == "__main__":
38
  "vocalset",
39
  "guitarset",
40
  "idmt-smt-guitar",
41
- "idmt-smt-bass",
42
  "idmt-smt-drums",
43
  ],
44
  nargs="+",
@@ -49,10 +81,11 @@ if __name__ == "__main__":
49
  "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
50
  "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
51
  "IDMT-SMT-GUITAR_V2": "https://zenodo.org/record/7544110/files/IDMT-SMT-GUITAR_V2.zip",
52
- "IDMT-SMT-BASS": "https://zenodo.org/record/7188892/files/IDMT-SMT-BASS.zip",
53
  "IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
54
  }
55
 
56
  for dataset_name, dataset_url in dataset_urls.items():
57
  if dataset_name in args.dataset_names:
58
  download_zip_dataset(dataset_url, "~/data/remfx-data")
 
 
1
  import os
 
 
 
2
  import argparse
3
+ import shutil
4
 
5
 
6
  def download_zip_dataset(dataset_url: str, output_dir: str):
 
24
  pass
25
  elif dataset_dir == "IDMT-SMT-DRUMS-V2":
26
  pass
27
+ elif dataset_dir == "DSD100":
28
+ shutil.rmtree(os.path.join(output_dir, dataset_dir, "Mixtures"))
29
+ for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Dev")):
30
+ source = os.path.join(output_dir, dataset_dir, "Sources", "Dev", dir)
31
+ shutil.move(source, os.path.join(output_dir, dataset_dir))
32
+ shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Dev"))
33
+ for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Test")):
34
+ source = os.path.join(output_dir, dataset_dir, "Sources", "Test", dir)
35
+ shutil.move(source, os.path.join(output_dir, dataset_dir))
36
+ shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Test"))
37
+ shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources"))
38
+
39
+ os.mkdir(os.path.join(output_dir, dataset_dir, "train"))
40
+ os.mkdir(os.path.join(output_dir, dataset_dir, "val"))
41
+ os.mkdir(os.path.join(output_dir, dataset_dir, "test"))
42
+ files = os.listdir(os.path.join(output_dir, dataset_dir))
43
+
44
+ num = 0
45
+ for dir in files:
46
+ if not os.path.isdir(os.path.join(output_dir, dataset_dir, dir)):
47
+ continue
48
+ if dir == "train" or dir == "val" or dir == "test":
49
+ continue
50
+ source = os.path.join(output_dir, dataset_dir, dir, "bass.wav")
51
+ if num < 80:
52
+ dest = os.path.join(output_dir, dataset_dir, "train", f"{num}.wav")
53
+ elif num < 90:
54
+ dest = os.path.join(output_dir, dataset_dir, "val", f"{num}.wav")
55
+ else:
56
+ dest = os.path.join(output_dir, dataset_dir, "test", f"{num}.wav")
57
+ shutil.move(source, dest)
58
+ shutil.rmtree(os.path.join(output_dir, dataset_dir, dir))
59
+ num += 1
60
+
61
  else:
62
+ raise NotImplementedError(f"Invalid dataset_dir = {dataset_dir}.")
63
 
64
 
65
  if __name__ == "__main__":
 
70
  "vocalset",
71
  "guitarset",
72
  "idmt-smt-guitar",
73
+ "dsd100",
74
  "idmt-smt-drums",
75
  ],
76
  nargs="+",
 
81
  "vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
82
  "guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
83
  "IDMT-SMT-GUITAR_V2": "https://zenodo.org/record/7544110/files/IDMT-SMT-GUITAR_V2.zip",
84
+ "DSD100": "http://liutkus.net/DSD100.zip",
85
  "IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
86
  }
87
 
88
  for dataset_name, dataset_url in dataset_urls.items():
89
  if dataset_name in args.dataset_names:
90
  download_zip_dataset(dataset_url, "~/data/remfx-data")
91
+ process_dataset(dataset_name, "~/data/remfx-data")
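The new DSD100 branch in scripts/download.py flattens the Dev and Test source folders and routes each song's bass.wav into train/val/test by its enumeration order. A tiny sketch of just that split rule, using the same thresholds as the script over DSD100's 100 songs (50 Dev + 50 Test); this is an illustration, not the script itself.

# 80/10/10 split by song index, matching the num < 80 / num < 90 checks above.
def split_for_index(num: int) -> str:
    if num < 80:
        return "train"
    if num < 90:
        return "val"
    return "test"

counts = {"train": 0, "val": 0, "test": 0}
for n in range(100):
    counts[split_for_index(n)] += 1
print(counts)  # {'train': 80, 'val': 10, 'test': 10}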
setup.py CHANGED
@@ -48,6 +48,7 @@ setup(
48
  "pedalboard",
49
  "frechet_audio_distance",
50
  "ordered-set",
 
51
  ],
52
  include_package_data=True,
53
  license="Apache License 2.0",
 
48
  "pedalboard",
49
  "frechet_audio_distance",
50
  "ordered-set",
51
+ "asteroid",
52
  ],
53
  include_package_data=True,
54
  license="Apache License 2.0",
shell_vars.sh CHANGED
@@ -1,3 +1,3 @@
1
- export DATASET_ROOT="./data/VocalSet"
2
  export WANDB_PROJECT="RemFX"
3
  export WANDB_ENTITY="mattricesound"
 
1
+ export DATASET_ROOT="./data/remfx-data"
2
  export WANDB_PROJECT="RemFX"
3
  export WANDB_ENTITY="mattricesound"
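shell_vars.sh now points DATASET_ROOT at the shared ./data/remfx-data directory instead of a VocalSet-specific path. A small sketch of picking the variable up from Python, assuming the file has been sourced so it is present in the environment; the fallback default here is only illustrative.

import os

dataset_root = os.environ.get("DATASET_ROOT", "./data/remfx-data")
print(os.path.abspath(dataset_root))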