```python
import bisect
from typing import Iterable

from torch.utils.data import Dataset, IterableDataset


class ConcatRepeatDataset(Dataset):
    """Concatenates map-style datasets, repeating each one a given number of times."""

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        # Cumulative number of (repeated) samples after each dataset.
        r, s = [], 0
        for dataset, repeat in zip(sequence, repeats):
            l = len(dataset) * repeat
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        super().__init__()
        self.datasets = list(datasets)
        self.repeats = repeats
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Find which dataset block the global index falls into...
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        # ...and turn it into an index local to that block.
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        dataset = self.datasets[dataset_idx]
        # Modulo maps an index inside the repeated block back onto the underlying dataset.
        return dataset[sample_idx % len(dataset)]
```
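As a quick sanity check, here is a minimal usage sketch. The `TensorDataset` toy datasets and the specific repeat counts are illustrative assumptions, not part of the original listing; they just show how the repeated cumulative sizes and the modulo indexing behave.

```python
import torch
from torch.utils.data import TensorDataset

# Two toy map-style datasets of different sizes (stand-ins for real datasets).
small = TensorDataset(torch.arange(3))   # 3 samples
large = TensorDataset(torch.arange(10))  # 10 samples

# Repeat the small dataset 4 times so it contributes 12 virtual samples
# next to the 10 samples of the large dataset (22 in total).
dataset = ConcatRepeatDataset([small, large], repeats=[4, 1])

print(len(dataset))  # 22
print(dataset[0])    # first sample of `small`
print(dataset[3])    # index 3 wraps back to the first sample of `small`
print(dataset[12])   # first index past the repeated block: first sample of `large`
```

Repeating a small dataset this way oversamples it relative to a larger one without physically copying any data, since only the index arithmetic changes.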