Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import math | |
| import random | |
| from torch.utils.data import ConcatDataset, Dataset | |
| from torch.utils.data.sampler import ( | |
| BatchSampler, | |
| RandomSampler, | |
| Sampler, | |
| SequentialSampler, | |
| ) | |
class ScheduledSampler(Sampler):
    """A sampler that draws indices from a ``ConcatDataset`` one batch at a time.

    Indices are grouped into batches of ``batch_size`` *within* each
    sub-dataset, so every batch holds samples from a single sub-dataset. The
    batches of all sub-datasets are then shuffled together and flattened into
    a single index stream.

    For any ``loader_type`` other than ``"valid"`` the trailing partial batch of
    each sub-dataset is dropped. For ``"valid"`` it is kept, so every index of
    the concatenated dataset is yielded exactly once.

    Args:
        concat_dataset (ConcatDataset): a concatenated dataset consisting of all
            datasets. Sub-datasets smaller than ``batch_size`` must provide
            ``get_dataset_name()`` (used in the warning message).
        batch_size (int): batch size; must be a positive integer.
        holistic_shuffle (bool): if True, indices inside each sub-dataset are
            drawn with ``RandomSampler``; otherwise with ``SequentialSampler``.
        logger (logging.Logger, optional): logger used to warn about
            sub-datasets smaller than ``batch_size``. When None, this module's
            logger is used.
        loader_type (str): loader name, "train" by default; only ``"valid"``
            changes behavior (keeps partial batches, suppresses warnings).

    Raises:
        ValueError: if ``concat_dataset`` is not a ``ConcatDataset``,
            ``batch_size`` is not a positive int, or ``holistic_shuffle`` is
            not a bool.

    Usage:
        With ``batch_size=3``, ``holistic_shuffle=False``, a non-"valid" loader
        and three sub-datasets of three samples each (global indices 0-2, 3-5
        and 6-8), one possible iteration order is (batch order is random)::

            [3, 4, 5, 0, 1, 2, 6, 7, 8]
    """

    def __init__(
        self,
        concat_dataset,
        batch_size,
        holistic_shuffle,
        logger=None,
        loader_type="train",
    ):
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    type(concat_dataset)
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(type(batch_size))
            )
        # A non-positive batch size would otherwise surface later as a
        # ZeroDivisionError in __len__ (or an empty iteration).
        if batch_size < 1:
            raise ValueError(
                "batch_size must be a positive integer, but got {}".format(batch_size)
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    type(holistic_shuffle)
                )
            )
        self.concat_dataset = concat_dataset
        self.batch_size = batch_size
        self.holistic_shuffle = holistic_shuffle
        self.type = loader_type

        # Validation keeps partial batches, so small datasets are harmless there.
        if loader_type != "valid":
            self._warn_small_datasets(logger)

    def _warn_small_datasets(self, logger):
        """Log a warning for every sub-dataset with fewer samples than one batch."""
        if logger is None:
            # The default argument used to crash here with AttributeError.
            import logging

            logger = logging.getLogger(__name__)
        for dataset in self.concat_dataset.datasets:
            dataset_len = len(dataset)
            if dataset_len < self.batch_size:
                logger.warning(
                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
                        self.type,
                        dataset.get_dataset_name(),
                        dataset_len,
                        self.batch_size,
                    )
                )

    def __len__(self):
        """Return the number of indices ``__iter__`` yields."""
        if self.type == "valid":
            # Partial batches are kept for validation, so every index is yielded.
            return len(self.concat_dataset)
        # Otherwise only full batches survive (drop last per sub-dataset).
        num_of_batches = sum(
            len(dataset) // self.batch_size for dataset in self.concat_dataset.datasets
        )
        return num_of_batches * self.batch_size

    def __iter__(self):
        """Yield global indices, batch-grouped per sub-dataset, in shuffled batch order."""
        # Global offset of each sub-dataset, e.g. [0, 200, 400].
        offsets = [0] + self.concat_dataset.cumulative_sizes[:-1]
        output_batches = []
        for offset, dataset in zip(offsets, self.concat_dataset.datasets):
            index_sampler = (
                RandomSampler(dataset)
                if self.holistic_shuffle
                else SequentialSampler(dataset)
            )
            cur_batch = []
            for idx in index_sampler:
                cur_batch.append(idx + offset)
                if len(cur_batch) == self.batch_size:
                    output_batches.append(cur_batch)
                    cur_batch = []
            # Validation keeps the trailing partial batch; training drops it.
            if self.type == "valid" and cur_batch:
                output_batches.append(cur_batch)
        # Shuffle batch order (not batch contents) so sub-datasets interleave.
        random.shuffle(output_batches)
        return iter([idx for batch in output_batches for idx in batch])
def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
    """Create the index sampler and the batch sampler for one data loader.

    Args:
        concat_dataset: the concatenated dataset to sample from.
        cfg: config object; reads ``cfg.train.batch_size``,
            ``cfg.train.sampler.holistic_shuffle`` and, for non-"valid"
            loaders, ``cfg.train.sampler.drop_last``.
        logger: logger forwarded to ``ScheduledSampler`` for its warnings.
        loader_type: loader name; "valid" disables ``drop_last``.

    Returns:
        A ``(sampler, batch_sampler)`` tuple, where ``batch_sampler`` wraps
        ``sampler`` with the same batch size.
    """
    train_cfg = cfg.train
    batch_size = train_cfg.batch_size
    sampler_cfg = train_cfg.sampler

    sampler = ScheduledSampler(
        concat_dataset,
        batch_size,
        sampler_cfg.holistic_shuffle,
        logger,
        loader_type,
    )

    # Validation always keeps the final incomplete batch.
    drop_last = False if loader_type == "valid" else sampler_cfg.drop_last
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    return sampler, batch_sampler