|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
from PIL import Image |
|
import numpy as np |
|
|
|
import torch |
|
from torch.utils.data import Dataset |
|
from torchvision import transforms |
|
|
|
from .preprocessor import normalize_params |
|
|
|
class ImageNetDataset(Dataset):
    """Image-classification dataset read from a ``root/<class>/<file>`` tree.

    Every immediate sub-directory of ``root`` is treated as one class. Class
    indices are assigned by enumerating the sorted sub-directory names, and
    every entry inside a class directory becomes one ``(path, class_idx)``
    sample (entries are not filtered by type or extension).

    Attributes:
        root: Dataset root directory as passed in.
        transform: Optional callable applied to each image (see ``__init__``).
        convert_to_numpy: Whether the PIL image is turned into a ``uint8``
            array before ``transform`` is applied.
        post_normalize: ``transforms.Normalize`` instance applied last.
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to integer label.
        samples: List of ``(path, class_idx)`` tuples, ordered by class name
            and then by file name.
        extensions: Distinct file extensions (without the dot) in first-seen
            order.
    """

    def __init__(self, root, transform=None, convert_to_numpy: bool = True, post_normalize: str = "plain"):
        """Index all samples found under ``root``.

        Args:
            root: Directory containing one sub-directory per class.
            transform: Optional callable invoked as ``transform(image=img)``
                that must return a mapping with an ``"image"`` key. When
                ``None`` the image is passed through unchanged.
            convert_to_numpy: Convert the loaded PIL image to a ``uint8``
                numpy array before calling ``transform``.
            post_normalize: Key into ``normalize_params``; the selected entry
                is unpacked as keyword arguments of ``transforms.Normalize``.

        Raises:
            KeyError: If ``post_normalize`` is not a key of
                ``normalize_params``.
            OSError: If ``root`` cannot be listed.
        """
        self.root = root
        self.transform = transform
        self.convert_to_numpy = convert_to_numpy
        self.post_normalize = transforms.Normalize(**normalize_params[post_normalize])

        # Close the scandir iterator deterministically instead of relying on
        # garbage collection.
        with os.scandir(root) as entries:
            self.classes = sorted(entry.name for entry in entries if entry.is_dir())
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        self.samples = []
        self.extensions = []
        # ``self.classes`` is already sorted, so the dict iterates in the same
        # order the labels were assigned.
        for target_class, class_idx in self.class_to_idx.items():
            target_dir = os.path.join(root, target_class)
            if not os.path.isdir(target_dir):
                # The directory vanished between the scan and now; skip it.
                continue
            for fname in sorted(os.listdir(target_dir)):
                self.samples.append((os.path.join(target_dir, fname), class_idx))
                # Derive the extension from the file name only: splitting the
                # full path would pick up dots in directory names (e.g. a
                # root of "./data") for files that have no extension.
                if "." not in fname:
                    continue
                ext = fname.rsplit(".", 1)[-1]
                if ext not in self.extensions:
                    self.extensions.append(ext)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load, transform and normalize the sample at ``index``.

        Args:
            index: Position in ``self.samples``.

        Returns:
            A ``(image, label)`` tuple. ``image`` is the result of
            ``post_normalize`` applied to a float32 tensor obtained by
            scaling the pixel values by 1/255 and moving the last axis to
            the front; ``label`` is the integer class index.

        Note:
            After the optional transform the image must be a numpy array
            with three axes (presumably height, width, channel -- confirm
            against the transforms in use).
        """
        path, label = self.samples[index]

        # ``Image.open`` is lazy and keeps the file open. ``convert`` forces
        # the pixel data to be decoded and returns an independent image, so
        # the file can be closed as soon as the ``with`` block exits.
        with Image.open(path) as opened:
            image = opened.convert("RGB")

        if self.convert_to_numpy:
            image = np.array(image).astype("uint8")

        # ``transform`` defaults to None; skip it rather than calling None.
        if self.transform is not None:
            image = self.transform(image=image)["image"]

        image = (image / 255).astype(np.float32)
        image = torch.from_numpy(image).permute(2, 0, 1)
        image = self.post_normalize(image)
        return image, label
|
|