mattricesound committed on
Commit
5e4e307
·
1 Parent(s): 056a44f

Improve speed of selecting random chunks

Browse files
Files changed (2) hide show
  1. remfx/datasets.py +10 -8
  2. remfx/utils.py +15 -1
remfx/datasets.py CHANGED
@@ -12,7 +12,7 @@ from remfx import effects
12
  from ordered_set import OrderedSet
13
  from typing import Any, List, Dict
14
  from torch.utils.data import Dataset, DataLoader
15
- from remfx.utils import create_sequential_chunks
16
 
17
 
18
  # https://zenodo.org/record/1193957 -> VocalSet
@@ -205,21 +205,23 @@ class EffectDataset(Dataset):
205
  if render_files:
206
  # Split audio file into chunks, resample, then apply random effects
207
  self.proc_root.mkdir(parents=True, exist_ok=True)
 
208
  for num_chunk in tqdm(range(self.total_chunks)):
209
- chunks = []
210
- while len(chunks) == 0:
211
  random_dataset_choice = random.choice(self.files)
212
  random_file_choice = random.choice(random_dataset_choice)
213
- chunks = create_sequential_chunks(
 
 
214
  random_file_choice, self.chunk_size, self.sample_rate
215
  )
216
- random_chunk = random.choice(chunks)
217
 
218
  # Sum to mono
219
- if random_chunk.shape[0] > 1:
220
- random_chunk = random_chunk.sum(0, keepdim=True)
221
 
222
- dry, wet, dry_effects, wet_effects = self.process_effects(random_chunk)
223
  output_dir = self.proc_root / str(num_chunk)
224
  output_dir.mkdir(exist_ok=True)
225
  torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
 
12
  from ordered_set import OrderedSet
13
  from typing import Any, List, Dict
14
  from torch.utils.data import Dataset, DataLoader
15
+ from remfx.utils import select_random_chunk
16
 
17
 
18
  # https://zenodo.org/record/1193957 -> VocalSet
 
205
  if render_files:
206
  # Split audio file into chunks, resample, then apply random effects
207
  self.proc_root.mkdir(parents=True, exist_ok=True)
208
+ bad_files = set()
209
  for num_chunk in tqdm(range(self.total_chunks)):
210
+ chunk = None
211
+ while chunk is None:
212
  random_dataset_choice = random.choice(self.files)
213
  random_file_choice = random.choice(random_dataset_choice)
214
+ if random_file_choice in bad_files:
215
+ continue
216
+ chunk = select_random_chunk(
217
  random_file_choice, self.chunk_size, self.sample_rate
218
  )
 
219
 
220
  # Sum to mono
221
+ if chunk.shape[0] > 1:
222
+ chunk = chunk.sum(0, keepdim=True)
223
 
224
+ dry, wet, dry_effects, wet_effects = self.process_effects(chunk)
225
  output_dir = self.proc_root / str(num_chunk)
226
  output_dir.mkdir(exist_ok=True)
227
  torchaudio.save(output_dir / "input.wav", wet, self.sample_rate)
remfx/utils.py CHANGED
@@ -129,7 +129,7 @@ def create_random_chunks(
129
  def create_sequential_chunks(
130
  audio_file: str, chunk_size: int, sample_rate: int
131
  ) -> List[torch.Tensor]:
132
- """Create sequential chunks of size chunk_size (seconds) from an audio file.
133
  Return sample_index of start of each chunk and original sr
134
  """
135
  chunks = []
@@ -147,6 +147,20 @@ def create_sequential_chunks(
147
  return chunks
148
 
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  def spectrogram(
151
  x: torch.Tensor,
152
  window: torch.Tensor,
 
129
  def create_sequential_chunks(
130
  audio_file: str, chunk_size: int, sample_rate: int
131
  ) -> List[torch.Tensor]:
132
+ """Create sequential chunks of size chunk_size from an audio file.
133
  Return sample_index of start of each chunk and original sr
134
  """
135
  chunks = []
 
147
  return chunks
148
 
149
 
150
def select_random_chunk(
    audio_file: str, chunk_size: int, sample_rate: int
) -> "torch.Tensor | None":
    """Load an audio file and return one randomly positioned chunk.

    The chunk is cut from the file at its native sample rate and then
    resampled to ``sample_rate``, so only the selected span is resampled
    rather than the whole file.

    Args:
        audio_file: Path of the audio file to read with ``torchaudio.load``.
        chunk_size: Length of the returned chunk, in samples at
            ``sample_rate`` (the target rate).
        sample_rate: Target sample rate of the returned chunk.

    Returns:
        A tensor whose last dimension holds exactly ``chunk_size`` samples
        at ``sample_rate``, or ``None`` when the file is too short to supply
        a full chunk. Callers are expected to retry with another file when
        ``None`` is returned.
    """
    audio, sr = torchaudio.load(audio_file)
    num_samples = audio.shape[-1]

    # chunk_size is expressed at the target rate; convert it to the number of
    # samples needed at the file's native rate. Ceiling division guarantees
    # the resampled chunk is never shorter than chunk_size.
    src_chunk_size = (chunk_size * sr + sample_rate - 1) // sample_rate

    max_start = num_samples - src_chunk_size
    if max_start < 0:
        # File is shorter than one chunk: signal "no chunk" instead of letting
        # torch.randint raise on an empty range.
        return None

    # The upper bound of torch.randint is exclusive, so add 1 to allow the
    # last valid start position (and the exact-length file, where it is 0).
    start = torch.randint(0, max_start + 1, (1,)).item()
    chunk = audio[:, start : start + src_chunk_size]
    resampled_chunk = torchaudio.functional.resample(chunk, sr, sample_rate)

    # Rounding in the rate conversion can yield an extra trailing sample;
    # trim so every returned chunk has the same length.
    return resampled_chunk[..., :chunk_size]
162
+
163
+
164
  def spectrogram(
165
  x: torch.Tensor,
166
  window: torch.Tensor,