Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import math | |
import random | |
from typing import Iterator, Optional, Sized | |
import numpy as np | |
from mmengine.dataset import ClassBalancedDataset, ConcatDataset | |
from mmengine.dist import get_dist_info, sync_random_seed | |
from torch.utils.data import Sampler | |
from mmdet.registry import DATA_SAMPLERS | |
from ..base_video_dataset import BaseVideoDataset | |
class TrackImgSampler(Sampler):
    """Sampler that provides image-level sampling outputs for video datasets
    in tracking tasks. It can be used in both distributed and non-distributed
    environments.

    If using the default sampler in pytorch, the subsequent data receiver will
    get one video, which is not desired in some cases:
    (Take a non-distributed environment as an example)
    1. In test mode, we want only one image to be fed into the data pipeline.
    This is in consideration of memory usage since feeding the whole video
    commonly requires a large amount of memory (>=20G on MOTChallenge17
    dataset), which is not available in some machines.
    2. In training mode, we may want to make sure all the images in one video
    are randomly sampled once in one epoch and this can not be guaranteed in
    the default sampler in pytorch.

    Every emitted index is a ``(video_ind, frame_ind)`` tuple.

    Args:
        dataset (Sized): Dataset used for sampling. Must be a
            ``BaseVideoDataset`` or a ``ConcatDataset`` /
            ``ClassBalancedDataset`` wrapping one.
        seed (int, optional): random seed used to shuffle the sampler. This
            number should be identical across all processes in the distributed
            group. If None, the seed is obtained from ``sync_random_seed()``.
            Defaults to None.

    Raises:
        ValueError: In test mode, if fewer videos than processes are loaded.
    """

    def __init__(
        self,
        dataset: Sized,
        seed: Optional[int] = None,
    ) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size
        self.epoch = 0
        if seed is None:
            self.seed = sync_random_seed()
        else:
            self.seed = seed

        self.dataset = dataset
        # Training mode: flat list of (video_ind, frame_ind) tuples.
        # Test mode: one such list per rank, i.e. a list of lists.
        self.indices = []

        # Hard code here to handle different dataset wrapper
        if isinstance(self.dataset, ConcatDataset):
            cat_datasets = self.dataset.datasets
            assert isinstance(
                cat_datasets[0], BaseVideoDataset
            ), f'expected BaseVideoDataset, but got {type(cat_datasets[0])}'
            self.test_mode = cat_datasets[0].test_mode
            # The message must be parenthesised: without the parentheses the
            # assert statement ends on its first line and the rest of the
            # message is silently dropped.
            assert not self.test_mode, (
                "'ConcatDataset' should not exist in "
                'test mode')
            # NOTE(review): video_ind restarts from 0 for every sub-dataset,
            # so the emitted indices are local to each sub-dataset rather
            # than offset into the concatenation -- confirm this is what the
            # consumer of these indices expects.
            for sub_dataset in cat_datasets:
                for video_ind in range(len(sub_dataset)):
                    self.indices.extend(
                        self._video_frame_indices(sub_dataset, video_ind))
        elif isinstance(self.dataset, ClassBalancedDataset):
            ori_dataset = self.dataset.dataset
            assert isinstance(
                ori_dataset, BaseVideoDataset
            ), f'expected BaseVideoDataset, but got {type(ori_dataset)}'
            self.test_mode = ori_dataset.test_mode
            assert not self.test_mode, (
                "'ClassBalancedDataset' should not "
                'exist in test mode')
            # ``repeat_indices`` may list a video several times; each
            # occurrence contributes all of its frames again.
            for index in self.dataset.repeat_indices:
                self.indices.extend(
                    self._video_frame_indices(ori_dataset, index))
        else:
            assert isinstance(self.dataset, BaseVideoDataset), (
                'TrackImgSampler is only supported in BaseVideoDataset or '
                'dataset wrapper: ClassBalancedDataset and ConcatDataset, but '
                f'got {type(self.dataset)} ')
            self.test_mode = self.dataset.test_mode
            num_videos = len(self.dataset)

            if self.test_mode:
                # in test mode, the images belong to the same video must be put
                # on the same device.
                if num_videos < self.world_size:
                    raise ValueError(f'only {num_videos} videos loaded,'
                                     f'but {self.world_size} gpus were given.')
                chunks = np.array_split(
                    list(range(num_videos)), self.world_size)
                for videos_inds in chunks:
                    indices_chunk = []
                    for video_ind in videos_inds:
                        indices_chunk.extend(
                            self._video_frame_indices(self.dataset, video_ind))
                    self.indices.append(indices_chunk)
            else:
                for video_ind in range(num_videos):
                    self.indices.extend(
                        self._video_frame_indices(self.dataset, video_ind))

        if self.test_mode:
            # Each rank iterates exactly over its own chunk; no padding.
            self.num_samples = len(self.indices[self.rank])
            self.total_size = sum(
                [len(index_list) for index_list in self.indices])
        else:
            # Round up so that every rank yields the same number of samples.
            self.num_samples = int(
                math.ceil(len(self.indices) * 1.0 / self.world_size))
            self.total_size = self.num_samples * self.world_size

    @staticmethod
    def _video_frame_indices(dataset, video_ind) -> list:
        """Return ``(video_ind, frame_ind)`` for every frame of one video.

        The number of frames is ``dataset.get_len_per_video(video_ind)``.
        """
        return [(video_ind, frame_ind)
                for frame_ind in range(dataset.get_len_per_video(video_ind))]

    def __iter__(self) -> Iterator:
        """Iterate over the ``(video_ind, frame_ind)`` tuples of this rank.

        In test mode the pre-computed chunk of the current rank is returned
        in order. Otherwise all indices are shuffled with a generator seeded
        by ``epoch + seed``, padded up to ``total_size`` and strided by rank.
        """
        if self.test_mode:
            # in test mode, the order of frames can not be shuffled.
            indices = self.indices[self.rank]
        else:
            # deterministically shuffle based on epoch
            rng = random.Random(self.epoch + self.seed)
            # ``sample`` returns a new list, so ``self.indices`` is untouched.
            indices = rng.sample(self.indices, len(self.indices))

            # add extra samples to make it evenly divisible. The shuffled
            # list is repeated as often as needed so that padding also works
            # when there are fewer samples than processes.
            padding_size = self.total_size - len(indices)
            if padding_size > 0 and indices:
                repeats = math.ceil(padding_size / len(indices))
                indices += (indices * repeats)[:padding_size]
            assert len(indices) == self.total_size

            # subsample
            indices = indices[self.rank:self.total_size:self.world_size]
            assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        """Return the number of samples yielded on the current rank."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch that is added to the seed when shuffling.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch