File size: 5,656 Bytes
9b33fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""ImageNet 1k dataset."""

from __future__ import annotations

import os
import pickle
import tarfile
from collections.abc import Sequence

import numpy as np

from vis4d.common.logging import rank_zero_info
from vis4d.common.time import Timer
from vis4d.common.typing import ArgsType
from vis4d.data.const import CommonKeys as K
from vis4d.data.typing import DictData

from .base import Dataset
from .util import im_decode, to_onehot


class ImageNet(Dataset):
    """ImageNet 1K dataset."""

    DESCRIPTION = """ImageNet is a large visual database designed for use in
        visual object recognition software research."""
    HOMEPAGE = "http://www.image-net.org/"
    PAPER = "http://www.image-net.org/papers/imagenet_cvpr09.pdf"
    LICENSE = "http://www.image-net.org/terms-of-use"

    KEYS = [K.images, K.categories]

    def __init__(
        self,
        data_root: str,
        keys_to_load: Sequence[str] = (K.images, K.categories),
        split: str = "train",
        num_classes: int = 1000,
        use_sample_lists: bool = False,
        **kwargs: ArgsType,
    ) -> None:
        """Initialize ImageNet dataset.

        Args:
            data_root (str): Path to root directory of dataset.
            keys_to_load (list[str], optional): List of keys to load. Defaults
                to (K.images, K.categories).
            split (str, optional): Dataset split to load. Defaults to "train".
            num_classes (int, optional): Number of classes to load. Defaults to
                1000.
            use_sample_lists (bool, optional): Whether to use sample lists for
                loading the dataset. Defaults to False.

        Raises:
            ValueError: If the number of tar files found for the split does
                not match ``num_classes``, or if a loaded sample list is
                inconsistent with ``num_classes``.

        NOTE: The dataset is expected to be in the following format:
            data_root
            β”œβ”€β”€ train.pkl  # Sample lists for training set (optional)
            β”œβ”€β”€ val.pkl    # Sample lists for validation set (optional)
            β”œβ”€β”€ train
            β”‚   β”œβ”€β”€ n01440764.tar
            β”‚   β”œβ”€β”€ ...
            └── val
                β”œβ”€β”€ n01440764.tar
                β”œβ”€β”€ ...
            With each tar file containing the images of a single class. The
            images are expected to be in ".JPEG" extension.

            Currently, we are not using the DataBackend for loading the tars to
            avoid keeping too many file pointers open at the same time.
        """
        super().__init__(**kwargs)
        self.data_root = data_root
        self.keys_to_load = keys_to_load
        self.split = split
        self.num_classes = num_classes
        self.use_sample_lists = use_sample_lists
        # (tar member, class index) pairs, one per image; populated below.
        self.data_infos: list[tuple[tarfile.TarInfo, int]] = []
        # Sorted tar file names; list position == class index.
        self._classes: list[str] = []
        self._load_data_infos()

    def _load_data_infos(self) -> None:
        """Load data infos from disk.

        Discovers the per-class tar files for the chosen split, then either
        loads a pre-generated sample list (pickle) or scans every tar's
        members on the fly.

        Raises:
            ValueError: If the number of tar files does not match
                ``num_classes``, or if the sample list is inconsistent with
                ``num_classes``.
        """
        timer = Timer()
        # Each class is stored as one tar file. Sort so the class index <->
        # tar file mapping is stable across runs and filesystems.
        self._classes = sorted(
            file
            for file in os.listdir(os.path.join(self.data_root, self.split))
            if file.endswith(".tar")
        )
        # Explicit error instead of `assert` so the check survives
        # `python -O` (asserts are stripped under optimization). ValueError
        # matches the sample-list validation below.
        if len(self._classes) != self.num_classes:
            raise ValueError(
                f"Expected {self.num_classes} classes, but found "
                f"{len(self._classes)} tar files."
            )

        sample_list_path = os.path.join(self.data_root, f"{self.split}.pkl")
        if self.use_sample_lists and os.path.exists(sample_list_path):
            # SECURITY NOTE: pickle.load can execute arbitrary code on
            # malicious input; only use sample lists from trusted sources.
            with open(sample_list_path, "rb") as f:
                sample_list = pickle.load(f)[0]
            # Sanity check: the last sample must belong to the last class
            # (sample lists are ordered by class index).
            if sample_list[-1][1] == self.num_classes - 1:
                self.data_infos = sample_list
            else:
                raise ValueError(
                    "Sample list does not match the number of classes. "
                    "Please regenerate the sample list or set "
                    "use_sample_lists=False."
                )
        # If sample lists are not available, generate them on the fly.
        else:
            for class_idx, file in enumerate(self._classes):
                with tarfile.open(
                    os.path.join(self.data_root, self.split, file)
                ) as f:
                    for member in f.getmembers():
                        if member.isfile() and member.name.endswith(".JPEG"):
                            self.data_infos.append((member, class_idx))

        rank_zero_info(f"Loading {self} takes {timer.time():.2f} seconds.")

    def __len__(self) -> int:
        """Return length of dataset."""
        return len(self.data_infos)

    def __getitem__(self, idx: int) -> DictData:
        """Convert single element at given index into Vis4D data format.

        Args:
            idx (int): Index into ``self.data_infos``.

        Returns:
            DictData: Dict with the requested keys; images are float32 with a
                leading batch dimension of 1, categories are one-hot encoded.
        """
        member, class_idx = self.data_infos[idx]
        # Re-open the class tar per access instead of caching file handles,
        # to avoid keeping ~1000 file pointers open simultaneously.
        with tarfile.open(
            os.path.join(self.data_root, self.split, self._classes[class_idx]),
            mode="r:*",  # read-only, transparent compression detection
        ) as f:
            im_bytes = f.extractfile(member)
            assert im_bytes is not None, f"Could not extract {member.name}!"
            image = im_decode(im_bytes.read())

        data_dict: DictData = {}
        if K.images in self.keys_to_load:
            # Add a leading batch dimension: (H, W, C) -> (1, H, W, C).
            data_dict[K.images] = np.ascontiguousarray(
                image, dtype=np.float32
            )[np.newaxis, ...]
            image_hw = image.shape[:2]
            data_dict[K.input_hw] = image_hw
            data_dict[K.original_hw] = image_hw
        if K.categories in self.keys_to_load:
            data_dict[K.categories] = to_onehot(
                np.array(class_idx, dtype=np.int64), self.num_classes
            )
        return data_dict