OptVQ / optvq /data /dataset.py
BorelTHU's picture
initiate
223d932
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------
import os
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from .preprocessor import normalize_params
class ImageNetDataset(Dataset):
def __init__(self, root, transform=None, convert_to_numpy: bool = True, post_normalize: str = "plain"):
self.root = root
self.transform = transform
self.convert_to_numpy = convert_to_numpy
self.post_normalize = transforms.Normalize(
**normalize_params[post_normalize]
)
# find classes
classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
# make dataset
self.samples = []
self.extensions = []
for target_class in sorted(class_to_idx.keys()):
class_idx = class_to_idx[target_class]
target_dir = os.path.join(root, target_class)
if not os.path.isdir(target_dir):
continue
for fname in sorted(os.listdir(target_dir)):
path = os.path.join(target_dir, fname)
item = (path, class_idx)
self.samples.append(item)
ext = path.split(".")[-1]
if ext not in self.extensions:
self.extensions.append(ext)
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
path, label = self.samples[index]
image = Image.open(path)
if not image.mode == "RGB":
image = image.convert("RGB")
if self.convert_to_numpy:
image = np.array(image).astype("uint8")
# image augmentation
image = self.transform(image=image)["image"]
# to tensor and normalize
image = (image / 255).astype(np.float32)
image = torch.from_numpy(image).permute(2, 0, 1)
image = self.post_normalize(image)
return image, label