# NOTE(review): the three lines that were here ("Spaces:" / "Runtime error" x2)
# were page-scrape residue, not part of the module.
"""Base dataset classes.""" | |
import logging | |
import math | |
import random | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torchaudio | |
from torch.utils.data import Dataset | |
from torch.utils.data.dataset import T_co | |
LOGGER = logging.getLogger(__name__) | |
SAMPLING_RATE = 16_000 | |
APPLY_NORMALIZATION = True | |
APPLY_TRIMMING = True | |
APPLY_PADDING = True | |
FRAMES_NUMBER = 480_000 # <- originally 64_600 | |
SOX_SILENCE = [ | |
# trim all silence that is longer than 0.2s and louder than 1% volume (relative to the file) | |
# from beginning and middle/end | |
["silence", "1", "0.2", "1%", "-1", "0.2", "1%"], | |
] | |
class SimpleAudioFakeDataset(Dataset):
    """Base dataset for audio-deepfake samples.

    Subclasses are expected to populate ``self.samples`` (a DataFrame with
    ``path``/``label``/``attack_type`` columns, or a list of such tuples)
    and to set ``self.seed`` / ``self.partition_ratio`` before calling
    ``split_samples``.
    """

    def __init__(
        self,
        subset,
        transform=None,
        return_label: bool = True,
        return_meta: bool = False,
    ):
        """
        Args:
            subset: one of "train", "test", "val" — the partition this
                instance serves when ``split_samples`` is called.
            transform: optional transform (stored; not applied here).
            return_label: if True, ``__getitem__`` appends a binary label
                (1 = bonafide, 0 = spoof).
            return_meta: if True, ``__getitem__`` appends a metadata tuple
                (attack_type, path, subset, real_sec_length).
        """
        self.transform = transform
        self.samples = pd.DataFrame()
        self.subset = subset
        self.allowed_attacks = None
        self.partition_ratio = None
        self.seed = None
        self.return_label = return_label
        self.return_meta = return_meta

    def split_samples(self, samples_list):
        """Deterministically shuffle and partition samples; return this
        instance's subset.

        The input is sorted first so the seeded shuffle is reproducible
        regardless of incoming order, then split by ``self.partition_ratio``
        = (train fraction, test fraction); the remainder is "val".
        """
        if isinstance(samples_list, pd.DataFrame):
            samples_list = samples_list.sort_values(by=list(samples_list.columns))
            samples_list = samples_list.sample(frac=1, random_state=self.seed)
        else:
            samples_list = sorted(samples_list)
            # Use a local RNG so the global `random` state is not clobbered.
            # random.Random(seed).shuffle yields the same permutation as
            # random.seed(seed); random.shuffle.
            random.Random(self.seed).shuffle(samples_list)

        train_frac, test_frac = self.partition_ratio
        subsets = np.split(
            samples_list,
            [
                int(train_frac * len(samples_list)),
                int((train_frac + test_frac) * len(samples_list)),
            ],
        )
        return dict(zip(["train", "test", "val"], subsets))[self.subset]

    def df2tuples(self):
        """Convert ``self.samples`` from a DataFrame to a list of
        (path, label, attack_type) tuples, in place, and return it."""
        self.samples = [
            (str(row["path"]), row["label"], row["attack_type"])
            for _, row in self.samples.iterrows()
        ]
        return self.samples

    def __getitem__(self, index) -> T_co:
        """Load and preprocess the sample at ``index``.

        Returns a list of [waveform, sample_rate], optionally extended with
        the binary label and a (attack_type, path, subset, seconds) tuple.
        """
        if isinstance(self.samples, pd.DataFrame):
            sample = self.samples.iloc[index]
            path = str(sample["path"])
            label = sample["label"]
            attack_type = sample["attack_type"]
            # Bug fix: the previous check `type(x) != str and math.isnan(x)`
            # raised TypeError for non-float, non-str values (e.g. None).
            # pd.isna handles NaN/None/NaT uniformly and is False for strings.
            if pd.isna(attack_type):
                attack_type = "N/A"
        else:
            path, label, attack_type = self.samples[index]

        waveform, sample_rate = torchaudio.load(path, normalize=APPLY_NORMALIZATION)
        # Duration of the raw file, before trimming/padding.
        real_sec_length = len(waveform[0]) / sample_rate

        waveform, sample_rate = apply_preprocessing(waveform, sample_rate)

        return_data = [waveform, sample_rate]
        if self.return_label:
            label = 1 if label == "bonafide" else 0
            return_data.append(label)
        if self.return_meta:
            return_data.append(
                (
                    attack_type,
                    path,
                    self.subset,
                    real_sec_length,
                )
            )
        return return_data

    def __len__(self):
        return len(self.samples)
def apply_preprocessing(
    waveform,
    sample_rate,
):
    """Resample to SAMPLING_RATE, force mono, then trim/pad according to the
    module-level APPLY_* switches.

    Returns the processed (waveform, sample_rate) pair.
    """
    # Resample unless already at the target rate, or resampling is
    # disabled via SAMPLING_RATE == -1.
    if SAMPLING_RATE not in (sample_rate, -1):
        waveform, sample_rate = resample_wave(waveform, sample_rate, SAMPLING_RATE)

    # Multi-channel audio: keep only the first channel.
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform[:1, ...]

    if APPLY_TRIMMING:
        # Drop long silences from the beginning / middle / end.
        waveform, sample_rate = apply_trim(waveform, sample_rate)

    if APPLY_PADDING:
        # Repeat-pad (or cut) to exactly FRAMES_NUMBER samples.
        waveform = apply_pad(waveform, FRAMES_NUMBER)

    return waveform, sample_rate
def resample_wave(waveform, sample_rate, target_sample_rate):
    """Resample an in-memory waveform to `target_sample_rate` via sox's
    `rate` effect; returns (waveform, sample_rate)."""
    effects = [["rate", str(target_sample_rate)]]
    return torchaudio.sox_effects.apply_effects_tensor(waveform, sample_rate, effects)
def resample_file(path, target_sample_rate, normalize=True):
    """Load an audio file from `path`, resampled to `target_sample_rate`
    via sox's `rate` effect; returns (waveform, sample_rate)."""
    effects = [["rate", str(target_sample_rate)]]
    return torchaudio.sox_effects.apply_effects_file(path, effects, normalize=normalize)
def apply_trim(waveform, sample_rate):
    """Strip long silences with the SOX_SILENCE effect chain.

    Falls back to the untouched input when trimming would leave an empty
    signal (e.g. a fully silent clip).
    """
    trimmed, trimmed_rate = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate, SOX_SILENCE
    )
    # Only accept the trimmed version if any samples survived.
    if trimmed.size()[1] > 0:
        return trimmed, trimmed_rate
    return waveform, sample_rate
def apply_pad(waveform, cut):
    """Cut or repeat-pad a waveform to exactly `cut` samples.

    Args:
        waveform: 1-D tensor, or 2-D mono tensor of shape (1, n).
        cut: target number of samples.

    Returns:
        A 1-D tensor of length `cut`.

    Raises:
        ValueError: if the waveform is empty (nothing to repeat).
    """
    waveform = waveform.squeeze(0)
    waveform_len = waveform.shape[0]

    # Long enough already: just cut.
    if waveform_len >= cut:
        return waveform[:cut]

    if waveform_len == 0:
        # The repeat count below would divide by zero; fail with a clear error.
        raise ValueError("Cannot pad an empty waveform")

    # Exact integer ceiling. The previous `int(cut / waveform_len) + 1`
    # relied on float division (imprecise for very large values) and tiled
    # one repeat more than necessary when `cut` is an exact multiple.
    num_repeats = -(-cut // waveform_len)
    return torch.tile(waveform, (num_repeats,))[:cut]