"""ImageNet 1k dataset."""
from __future__ import annotations
import os
import pickle
import tarfile
from collections.abc import Sequence
import numpy as np
from vis4d.common.logging import rank_zero_info
from vis4d.common.time import Timer
from vis4d.common.typing import ArgsType
from vis4d.data.const import CommonKeys as K
from vis4d.data.typing import DictData
from .base import Dataset
from .util import im_decode, to_onehot
class ImageNet(Dataset):
"""ImageNet 1K dataset."""
DESCRIPTION = """ImageNet is a large visual database designed for use in
visual object recognition software research."""
HOMEPAGE = "http://www.image-net.org/"
PAPER = "http://www.image-net.org/papers/imagenet_cvpr09.pdf"
LICENSE = "http://www.image-net.org/terms-of-use"
KEYS = [K.images, K.categories]
def __init__(
self,
data_root: str,
keys_to_load: Sequence[str] = (K.images, K.categories),
split: str = "train",
num_classes: int = 1000,
use_sample_lists: bool = False,
**kwargs: ArgsType,
) -> None:
"""Initialize ImageNet dataset.
Args:
data_root (str): Path to root directory of dataset.
keys_to_load (list[str], optional): List of keys to load. Defaults
to (K.images, K.categories).
split (str, optional): Dataset split to load. Defaults to "train".
num_classes (int, optional): Number of classes to load. Defaults to
1000.
use_sample_lists (bool, optional): Whether to use sample lists for
loading the dataset. Defaults to False.
NOTE: The dataset is expected to be in the following format:
data_root
βββ train.pkl # Sample lists for training set (optional)
βββ val.pkl # Sample lists for validation set (optional)
βββ train
β βββ n01440764.tar
β βββ ...
βββ val
βββ n01440764.tar
βββ ...
With each tar file containing the images of a single class. The
images are expected to be in ".JPEG" extension.
Currently, we are not using the DataBackend for loading the tars to
avoid keeping too many file pointers open at the same time.
"""
super().__init__(**kwargs)
self.data_root = data_root
self.keys_to_load = keys_to_load
self.split = split
self.num_classes = num_classes
self.use_sample_lists = use_sample_lists
self.data_infos: list[tuple[tarfile.TarInfo, int]] = []
self._classes: list[str] = []
self._load_data_infos()
def _load_data_infos(self) -> None:
"""Load data infos from disk."""
timer = Timer()
# Load tar files
for file in os.listdir(os.path.join(self.data_root, self.split)):
if file.endswith(".tar"):
self._classes.append(file)
assert len(self._classes) == self.num_classes, (
f"Expected {self.num_classes} classes, but found "
f"{len(self._classes)} tar files."
)
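# Sorting the tar file names (WordNet synset IDs) makes the class index
# assignment deterministic, independent of the file system's listing order.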
self._classes = sorted(self._classes)
sample_list_path = os.path.join(self.data_root, f"{self.split}.pkl")
if self.use_sample_lists and os.path.exists(sample_list_path):
with open(sample_list_path, "rb") as f:
sample_list = pickle.load(f)[0]
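# The pickle is expected to store the sample list as its first element:
# a list of (tarfile.TarInfo, class_index) tuples ordered by class, so
# the last entry's class index should equal num_classes - 1.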
if sample_list[-1][1] == self.num_classes - 1:
self.data_infos = sample_list
else:
raise ValueError(
"Sample list does not match the number of classes. "
"Please regenerate the sample list or set "
"use_sample_lists=False."
)
# If sample lists are not available, generate them on the fly.
else:
for class_idx, file in enumerate(self._classes):
with tarfile.open(
os.path.join(self.data_root, self.split, file)
) as f:
members = f.getmembers()
for member in members:
if member.isfile() and member.name.endswith(".JPEG"):
self.data_infos.append((member, class_idx))
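# A rough sketch of how a matching sample list could be generated offline
# (an assumption inferred from the loading code above, not an official tool):
#
#     import os, pickle, tarfile
#
#     def build_sample_list(data_root: str, split: str) -> None:
#         """Write {split}.pkl holding a list of (TarInfo, class_idx) tuples."""
#         classes = sorted(
#             f for f in os.listdir(os.path.join(data_root, split))
#             if f.endswith(".tar")
#         )
#         samples = []
#         for class_idx, tar_name in enumerate(classes):
#             with tarfile.open(os.path.join(data_root, split, tar_name)) as tf:
#                 for member in tf.getmembers():
#                     if member.isfile() and member.name.endswith(".JPEG"):
#                         samples.append((member, class_idx))
#         with open(os.path.join(data_root, f"{split}.pkl"), "wb") as pkl:
#             pickle.dump([samples], pkl)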
rank_zero_info(f"Loading {self} takes {timer.time():.2f} seconds.")
def __len__(self) -> int:
"""Return length of dataset."""
return len(self.data_infos)
def __getitem__(self, idx: int) -> DictData:
"""Convert single element at given index into Vis4D data format."""
member, class_idx = self.data_infos[idx]
with tarfile.open(
os.path.join(self.data_root, self.split, self._classes[class_idx]),
mode="r:*", # unexclusive read mode
) as f:
im_bytes = f.extractfile(member)
assert im_bytes is not None, f"Could not extract {member.name}!"
image = im_decode(im_bytes.read())
data_dict: DictData = {}
if K.images in self.keys_to_load:
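# The decoded image is stored as a float32 array with a leading batch
# dimension, typically of shape (1, H, W, C).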
data_dict[K.images] = np.ascontiguousarray(
image, dtype=np.float32
)[np.newaxis, ...]
image_hw = image.shape[:2]
data_dict[K.input_hw] = image_hw
data_dict[K.original_hw] = image_hw
if K.categories in self.keys_to_load:
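# The class index is converted to a one-hot vector of length num_classes.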
data_dict[K.categories] = to_onehot(
np.array(class_idx, dtype=np.int64), self.num_classes
)
return data_dict