Spaces:
Running
Running
File size: 3,183 Bytes
a00b67a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import math
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
# Modified version from woosungchoi's original implementation
class SingleTrackSet(Dataset):
    """Map-style dataset that slices one stereo track into fixed-size,
    zero-padded analysis windows for windowed source-separation inference.

    Each item is a ``(2, window_length)`` tensor containing ``true_samples``
    of audio flanked by ``trim_length`` samples of context (zero-padded at
    the track edges).
    """

    def __init__(self, track, hop_length, num_frame=128, target_name="vocals"):
        """
        Args:
            track: stereo waveform tensor of shape ``(2, num_samples)``.
            hop_length: STFT hop size in samples.
            num_frame: number of STFT frames per analysis window.
            target_name: source to separate; must be one of ``source_names``.
        """
        assert len(track.shape) == 2
        assert track.shape[0] == 2  # check stereo audio
        self.hop_length = hop_length
        # Samples covered by one analysis window (e.g. 1024 * 127 = 130048).
        self.window_length = hop_length * (num_frame - 1)
        # Context samples discarded at each edge of a window (e.g. 5120).
        self.trim_length = self.get_trim_length(self.hop_length)
        # Net samples each chunk contributes to the output (e.g. 119808).
        self.true_samples = self.window_length - 2 * self.trim_length

        self.lengths = [track.shape[1]]  # track lengths (in sample level)
        self.source_names = [
            "vocals",
            "drums",
            "bass",
            "other",
        ]  # == self.musdb_train.targets_names[:-2]
        self.target_names = [target_name]
        self.num_tracks = 1
        # Chunks per track; example: 44.1kHz 180sec audio => [67].
        # (Redundant in-function `import math` removed; math is imported at
        # module level.)
        num_chunks = [
            math.ceil(length / self.true_samples) for length in self.lengths
        ]
        # Cumulative chunk counts: chunk idx belongs to track i iff
        # idx < acc_chunk_final_ids[i] and idx >= acc_chunk_final_ids[i-1].
        self.acc_chunk_final_ids = [
            sum(num_chunks[: i + 1]) for i in range(self.num_tracks)
        ]  # [67]
        self.cache_mode = True
        self.cache = {0: {"linear_mixture": track}}

    def __len__(self):
        """One item per (chunk, target) pair; single track/target => chunk count."""
        return self.acc_chunk_final_ids[-1] * len(self.target_names)  # 67

    def __getitem__(self, idx):
        """Return the zero-padded window for chunk ``idx``, shape (2, window_length).

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        target_offset = idx % len(self.target_names)  # 0
        idx = idx // len(self.target_names)
        # Currently unused in the lookup below; kept for interface parity
        # with the multi-target original.
        target_name = self.target_names[target_offset]  # 'vocals'

        mixture_idx, start_pos = self.idx_to_track_offset(idx)
        if mixture_idx is None:
            # BUGFIX: map-style datasets must raise IndexError (not
            # StopIteration) so sequence iteration terminates correctly
            # (cf. PEP 479 — StopIteration must not leak as control flow).
            raise IndexError(f"dataset index {idx} is out of range")

        length = self.true_samples
        left_padding_num = right_padding_num = self.trim_length  # 5120
        mixture_length = self.lengths[mixture_idx]
        if start_pos + length > mixture_length:  # last chunk: pad the tail
            right_padding_num += self.true_samples - (mixture_length - start_pos)
            length = None  # slice to end of track in get_audio

        mixture = self.get_audio(mixture_idx, "linear_mixture", start_pos, length)
        mixture = F.pad(mixture, (left_padding_num, right_padding_num), "constant", 0)
        return mixture

    def idx_to_track_offset(self, idx):
        """Map a global chunk index to (track index, sample offset).

        Returns ``(None, None)`` when ``idx`` exceeds the total chunk count.
        """
        for i, last_chunk in enumerate(self.acc_chunk_final_ids):
            if idx < last_chunk:
                if i != 0:
                    offset = (idx - self.acc_chunk_final_ids[i - 1]) * self.true_samples
                else:
                    offset = idx * self.true_samples
                return i, offset
        return None, None

    def get_audio(self, idx, target_name, pos=0, length=None):
        """Return ``length`` samples of the cached source starting at ``pos``;
        slice to the end of the track when ``length`` is None."""
        track = self.cache[idx][target_name]
        return track[:, pos : pos + length] if length is not None else track[:, pos:]

    def get_trim_length(self, hop_length, min_trim=5000):
        """Smallest multiple of ``hop_length`` that is >= ``min_trim`` samples."""
        trim_per_hop = math.ceil(min_trim / hop_length)
        trim_length = trim_per_hop * hop_length
        assert trim_per_hop > 1  # hop must be well below min_trim
        return trim_length
|