"""Dataloader configuration."""

from __future__ import annotations

from collections.abc import Sequence

from ml_collections import ConfigDict, FieldReference

from vis4d.common.typing import GenericFunc
from vis4d.config import class_config
from vis4d.data.data_pipe import DataPipe
from vis4d.data.loader import (
    DEFAULT_COLLATE_KEYS,
    build_inference_dataloaders,
    build_train_dataloader,
    default_collate,
)
from vis4d.data.transforms.to_tensor import ToTensor

from .callable import get_callable_cfg


def get_train_dataloader_cfg(
    datasets_cfg: ConfigDict | list[ConfigDict],
    samples_per_gpu: int | FieldReference = 1,
    workers_per_gpu: int | FieldReference = 1,
    batchprocess_cfg: ConfigDict | None = None,
    collate_fn: GenericFunc = default_collate,
    collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
    sensors: Sequence[str] | None = None,
    pin_memory: bool | FieldReference = True,
    shuffle: bool | FieldReference = True,
    aspect_ratio_grouping: bool | FieldReference = False,
) -> ConfigDict:
    """Creates dataloader configuration given dataset and preprocessing.

    Args:
        datasets_cfg (ConfigDict | list[ConfigDict]): The configuration of a
            single dataset or a list of dataset configurations. If a list is
            given, it is wrapped into a DataPipe.
        samples_per_gpu (int | FieldReference, optional): How many samples each
            GPU will process. Defaults to 1.
        workers_per_gpu (int | FieldReference, optional): How many workers to
            spawn per GPU. Defaults to 1.
        batchprocess_cfg (ConfigDict, optional): The config that contains the
            batch processing operations. Defaults to None, in which case
            ToTensor is used.
        collate_fn (GenericFunc, optional): The collate function to use.
            Defaults to default_collate.
        collate_keys (Sequence[str], optional): The keys to collate. Defaults
            to DEFAULT_COLLATE_KEYS.
        sensors (Sequence[str], optional): The sensors to collate. Defaults to
            None.
        pin_memory (bool | FieldReference, optional): Whether to pin memory.
            Defaults to True.
        shuffle (bool | FieldReference, optional): Whether to shuffle the
            dataset. Defaults to True.
        aspect_ratio_grouping (bool | FieldReference, optional): Whether to
            group the samples by aspect ratio. Defaults to False.

    Returns:
        ConfigDict: Configuration that can be instantiated as a dataloader.
    """
    if batchprocess_cfg is None:
        batchprocess_cfg = class_config(ToTensor)

    if isinstance(datasets_cfg, list):
        dataset = class_config(DataPipe, datasets=datasets_cfg)
    else:
        dataset = datasets_cfg

    return class_config(
        build_train_dataloader,
        dataset=dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=workers_per_gpu,
        batchprocess_fn=batchprocess_cfg,
        collate_fn=get_callable_cfg(collate_fn),
        collate_keys=collate_keys,
        sensors=sensors,
        pin_memory=pin_memory,
        shuffle=shuffle,
        aspect_ratio_grouping=aspect_ratio_grouping,
    )
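
# Example usage (a minimal sketch): the dataset class path and data_root
# below are illustrative placeholders, not part of this module; any vis4d
# dataset config can be passed. A list of dataset configs is wrapped into a
# DataPipe automatically.
#
#     dataset_cfg = class_config(
#         "vis4d.data.datasets.coco.COCO", data_root="data/coco"
#     )
#     train_dl_cfg = get_train_dataloader_cfg(
#         datasets_cfg=dataset_cfg,
#         samples_per_gpu=2,
#         workers_per_gpu=4,
#     )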


def get_inference_dataloaders_cfg(
    datasets_cfg: ConfigDict | list[ConfigDict],
    samples_per_gpu: int | FieldReference = 1,
    workers_per_gpu: int | FieldReference = 1,
    video_based_inference: bool | FieldReference = False,
    batchprocess_cfg: ConfigDict | None = None,
    collate_fn: GenericFunc = default_collate,
    collate_keys: Sequence[str] = DEFAULT_COLLATE_KEYS,
    sensors: Sequence[str] | None = None,
) -> ConfigDict:
    """Creates dataloader configuration given dataset for inference.

    Args:
        datasets_cfg (ConfigDict | list[ConfigDict]): The configuration of a
            single dataset or a list of dataset configurations.
        samples_per_gpu (int | FieldReference, optional): How many samples each
            GPU will process per batch. Defaults to 1.
        workers_per_gpu (int | FieldReference, optional): How many workers each
            GPU will spawn. Defaults to 1.
        video_based_inference (bool | FieldReference, optional): Whether to
            split the dataset by sequences. Defaults to False.
        batchprocess_cfg (ConfigDict, optional): The config that contains the
            batch processing operations. Defaults to None, in which case
            ToTensor is used.
        collate_fn (GenericFunc, optional): The collate function that will be
            used to stack the batch. Defaults to default_collate.
        collate_keys (Sequence[str], optional): The keys to collate. Defaults
            to DEFAULT_COLLATE_KEYS.
        sensors (Sequence[str], optional): The sensors to collate. Defaults to
            None.

    Returns:
        ConfigDict: The dataloader configuration.
    """
    if batchprocess_cfg is None:
        batchprocess_cfg = class_config(ToTensor)

    return class_config(
        build_inference_dataloaders,
        datasets=datasets_cfg,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=workers_per_gpu,
        video_based_inference=video_based_inference,
        batchprocess_fn=batchprocess_cfg,
        collate_fn=get_callable_cfg(collate_fn),
        collate_keys=collate_keys,
        sensors=sensors,
    )
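
# Example usage (a minimal sketch, mirroring the training example above):
# test_dataset_cfg stands in for any vis4d dataset config. Once instantiated,
# the returned config builds one inference dataloader per dataset.
#
#     test_dl_cfg = get_inference_dataloaders_cfg(
#         datasets_cfg=test_dataset_cfg,
#         samples_per_gpu=1,
#         workers_per_gpu=1,
#         video_based_inference=True,
#     )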