# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized

import torch
from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS
from torch.utils.data import Sampler
class MultiDataSampler(Sampler):
    """Weighted sampler over a multi-source (concatenated) dataset for both
    distributed and non-distributed environments.

    It has several differences from the PyTorch ``DistributedSampler`` as
    below:

    1. This sampler supports non-distributed environment.
    2. The round up behaviors are a little different.

       - If ``round_up=True``, this sampler will add extra samples to make the
         number of samples is evenly divisible by the world size. And
         this behavior is the same as the ``DistributedSampler`` with
         ``drop_last=False``.
       - If ``round_up=False``, this sampler won't remove or add any samples
         while the ``DistributedSampler`` with ``drop_last=True`` will remove
         tail samples.

    Args:
        dataset (Sized): The dataset. It is expected to expose a ``datasets``
            attribute holding the sub-datasets (e.g. a ``ConcatDataset``).
        dataset_ratio (Sequence[int]): The ratios of different datasets.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Defaults to None.
        round_up (bool): Whether to add extra samples to make the number of
            samples evenly divisible by the world size. Defaults to True.
    """

    def __init__(self,
                 dataset: Sized,
                 dataset_ratio: Sequence[int],
                 seed: Optional[int] = None,
                 round_up: bool = True) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.dataset_ratio = dataset_ratio
        if seed is None:
            # Share one random seed across all ranks so every process
            # draws the same global index sequence.
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.round_up = round_up

        if self.round_up:
            self.num_samples = math.ceil(len(self.dataset) / world_size)
            self.total_size = self.num_samples * self.world_size
        else:
            self.num_samples = math.ceil(
                (len(self.dataset) - rank) / world_size)
            self.total_size = len(self.dataset)

        # Per-sample weight for sub-dataset ``i``:
        #   max(sizes) / sizes[i] * ratio[i] / sum(ratios)
        # so smaller sub-datasets are up-weighted and the requested mixing
        # ratios are respected during sampling.
        self.sizes = [len(sub_dataset) for sub_dataset in self.dataset.datasets]
        dataset_weight = [
            torch.ones(size) * max(self.sizes) / size * ratio /
            sum(self.dataset_ratio)
            for ratio, size in zip(self.dataset_ratio, self.sizes)
        ]
        self.weights = torch.cat(dataset_weight)

    def __iter__(self) -> Iterator[int]:
        """Iterate the weighted, shuffled sample indices of this rank."""
        # Deterministically shuffle based on epoch and seed so every rank
        # draws the identical global sequence before sharding.
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.multinomial(
            self.weights, len(self.weights), generator=g,
            replacement=True).tolist()

        # Add extra samples (by repeating the sequence) to make it evenly
        # divisible across ranks.
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]

        # Subsample this rank's strided shard.
        indices = indices[self.rank:self.total_size:self.world_size]
        return iter(indices)

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch