Spaces:
Running
on
Zero
Running
on
Zero
"""Dataloader configuration.""" | |
from __future__ import annotations | |
from collections.abc import Sequence | |
from ml_collections import ConfigDict, FieldReference | |
from vis4d.common.typing import GenericFunc | |
from vis4d.config import class_config | |
from vis4d.data.data_pipe import DataPipe | |
from vis4d.data.loader import ( | |
DEFAULT_COLLATE_KEYS, | |
build_inference_dataloaders, | |
build_train_dataloader, | |
default_collate, | |
) | |
from vis4d.data.transforms.to_tensor import ToTensor | |
from .callable import get_callable_cfg | |
def get_train_dataloader_cfg(
    datasets_cfg: ConfigDict | list[ConfigDict],
    samples_per_gpu: int | FieldReference = 1,
    workers_per_gpu: int | FieldReference = 1,
    batchprocess_cfg: ConfigDict | None = None,
    collate_fn: GenericFunc = default_collate,
    collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
    sensors: Sequence[str] | None = None,
    pin_memory: bool | FieldReference = True,
    shuffle: bool | FieldReference = True,
    aspect_ratio_grouping: bool | FieldReference = False,
) -> ConfigDict:
    """Build the configuration of the training dataloader.

    Args:
        datasets_cfg (ConfigDict | list[ConfigDict]): Config of a single
            dataset, or a list of dataset configs. A list is wrapped into a
            DataPipe config; anything else is forwarded unchanged.
        samples_per_gpu (int | FieldReference, optional): Number of samples
            each GPU processes. Defaults to 1.
        workers_per_gpu (int | FieldReference, optional): Number of workers
            spawned per GPU. Defaults to 1.
        batchprocess_cfg (ConfigDict, optional): Config of the batch
            processing operations. Defaults to None, in which case a
            ToTensor config is used.
        collate_fn (GenericFunc, optional): Collate function to use. It is
            stored in the config through get_callable_cfg. Defaults to
            default_collate.
        collate_keys (Sequence[str], optional): Keys to collate. Defaults to
            DEFAULT_COLLATE_KEYS.
        sensors (Sequence[str], optional): Sensors to collate. Defaults to
            None.
        pin_memory (bool | FieldReference, optional): Whether to pin memory.
            Defaults to True.
        shuffle (bool | FieldReference, optional): Whether to shuffle the
            dataset. Defaults to True.
        aspect_ratio_grouping (bool | FieldReference, optional): Whether to
            group the samples by aspect ratio. Defaults to False.

    Returns:
        ConfigDict: Configuration that can be instantiated as a dataloader.
    """
    # Fall back to plain tensor conversion when no batch processing is given.
    batch_transform_cfg = (
        class_config(ToTensor)
        if batchprocess_cfg is None
        else batchprocess_cfg
    )

    # Several dataset configs are combined into a single DataPipe config.
    train_dataset_cfg = (
        class_config(DataPipe, datasets=datasets_cfg)
        if isinstance(datasets_cfg, list)
        else datasets_cfg
    )

    loader_kwargs = {
        "dataset": train_dataset_cfg,
        "samples_per_gpu": samples_per_gpu,
        "workers_per_gpu": workers_per_gpu,
        "batchprocess_fn": batch_transform_cfg,
        "collate_fn": get_callable_cfg(collate_fn),
        "collate_keys": collate_keys,
        "sensors": sensors,
        "pin_memory": pin_memory,
        "shuffle": shuffle,
        "aspect_ratio_grouping": aspect_ratio_grouping,
    }
    return class_config(build_train_dataloader, **loader_kwargs)
def get_inference_dataloaders_cfg(
    datasets_cfg: ConfigDict | list[ConfigDict],
    samples_per_gpu: int | FieldReference = 1,
    workers_per_gpu: int | FieldReference = 1,
    video_based_inference: bool | FieldReference = False,
    batchprocess_cfg: ConfigDict | None = None,
    collate_fn: GenericFunc = default_collate,
    collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
    sensors: Sequence[str] | None = None,
) -> ConfigDict:
    """Build the configuration of the inference dataloaders.

    Args:
        datasets_cfg (ConfigDict | list[ConfigDict]): Config of a single
            dataset, or a list of dataset configs. It is forwarded unchanged
            as the ``datasets`` argument.
        samples_per_gpu (int | FieldReference, optional): Number of samples
            each GPU processes per batch. Defaults to 1.
        workers_per_gpu (int | FieldReference, optional): Number of workers
            each GPU spawns. Defaults to 1.
        video_based_inference (bool | FieldReference, optional): Whether to
            split the dataset by sequences. Defaults to False.
        batchprocess_cfg (ConfigDict, optional): Config of the batch
            processing operations. Defaults to None, in which case a
            ToTensor config is used.
        collate_fn (GenericFunc, optional): Collate function used to stack
            the batch. It is stored in the config through get_callable_cfg.
            Defaults to default_collate.
        collate_keys (Sequence[str], optional): Keys to collate. Defaults to
            DEFAULT_COLLATE_KEYS.
        sensors (Sequence[str], optional): Sensors to collate. Defaults to
            None.

    Returns:
        ConfigDict: The dataloader configuration.
    """
    # Fall back to plain tensor conversion when no batch processing is given.
    batch_transform_cfg = (
        class_config(ToTensor)
        if batchprocess_cfg is None
        else batchprocess_cfg
    )

    loader_kwargs = {
        "datasets": datasets_cfg,
        "samples_per_gpu": samples_per_gpu,
        "workers_per_gpu": workers_per_gpu,
        "video_based_inference": video_based_inference,
        "batchprocess_fn": batch_transform_cfg,
        "collate_fn": get_callable_cfg(collate_fn),
        "collate_keys": collate_keys,
        "sensors": sensors,
    }
    return class_config(build_inference_dataloaders, **loader_kwargs)