denoising / datasets /__init__.py
BorisovMaksim's picture
rewrote demucs model
e2b0b28
raw
history blame
615 Bytes
from typing import Tuple
from torch.utils.data import Dataset
from datasets.valentini import Valentini
from transforms import Transform
# Registry of available datasets: maps the dataset name used as the key under
# cfg['dataset'] to the class that get_datasets() instantiates for it.
DATASETS_POOL = {'valentini': Valentini}
def get_datasets(cfg) -> Tuple[Dataset, Dataset]:
    """Build the training and validation datasets described by ``cfg``.

    Only the first entry of ``cfg['dataset']`` is used: its key selects the
    dataset class from ``DATASETS_POOL`` and its value is the keyword-argument
    mapping forwarded to that class. The mapping must contain
    ``'sample_rate'``, which is additionally passed to ``Transform`` as
    ``input_sample_rate``; every entry of ``cfg['dataloader']`` is forwarded
    to ``Transform`` as keyword arguments too.

    Returns:
        A ``(train_dataset, valid_dataset)`` pair, constructed with
        ``valid=False`` and ``valid=True`` respectively. Both datasets share
        the same ``Transform`` instance.
    """
    dataset_entries = list(cfg['dataset'].items())
    name, dataset_params = dataset_entries[0]

    transform = Transform(
        input_sample_rate=dataset_params['sample_rate'],
        **cfg['dataloader'],
    )

    def build(is_valid):
        # The registry lookup happens per call, after the transform exists,
        # so the order of failures matches a straight-line implementation.
        return DATASETS_POOL[name](valid=is_valid, transform=transform, **dataset_params)

    return build(False), build(True)