import os
from typing import Tuple

import numpy as np
import scipy.io
import torch
import torch.utils.data as data

from auxiliary.utils import normalize, bgr_to_rgb, linear_to_nonlinear, hwc_to_chw
from classes.data.DataAugmenter import DataAugmenter


class ColorCheckerDataset(data.Dataset):
    """Torch dataset over one train/test split defined in ``dataset/folds.mat``.

    Each sample is loaded from preprocessed ``.npy`` files under
    ``dataset/preprocessed/numpy_data`` (image) and
    ``dataset/preprocessed/numpy_labels`` (illuminant).
    """

    def __init__(self, train: bool = True, folds_num: int = 1):
        """
        :param train: if True, use the "tr_split" entry of the folds file and
            apply ``DataAugmenter.augment`` in ``__getitem__``; otherwise use
            "te_split" and only apply ``DataAugmenter.crop``.
        :param folds_num: index selecting which fold to read from the split
            array loaded from ``dataset/folds.mat``.
        """
        self.__train = train
        self.__da = DataAugmenter()

        path_to_folds = os.path.join("dataset", "folds.mat")
        path_to_metadata = os.path.join("dataset", "metadata.txt")
        self.__path_to_data = os.path.join("dataset", "preprocessed", "numpy_data")
        self.__path_to_label = os.path.join("dataset", "preprocessed", "numpy_labels")

        folds = scipy.io.loadmat(path_to_folds)
        img_idx = folds["tr_split" if self.__train else "te_split"][0][folds_num][0]

        # Use a context manager so the metadata file handle is closed
        # deterministically instead of leaking until garbage collection.
        with open(path_to_metadata, 'r') as metadata_file:
            metadata = metadata_file.readlines()

        # Indices are shifted by one before indexing the metadata lines;
        # presumably folds.mat stores MATLAB-style 1-based indices — TODO confirm.
        self.__fold_data = [metadata[i - 1] for i in img_idx]

    def __getitem__(self, index: int) -> Tuple:
        """Load, preprocess and return the sample at ``index``.

        :param index: position within this split's fold data.
        :return: ``(img, illuminant, file_name)`` where ``img`` and
            ``illuminant`` are torch tensors and ``file_name`` is the base name
            (without ``.npy``) used to locate both files on disk.
        """
        # The file name is taken as the second single-space-separated token
        # of the metadata line (assumed metadata.txt layout — verify).
        file_name = self.__fold_data[index].strip().split(' ')[1]

        img = np.array(np.load(os.path.join(self.__path_to_data, file_name + '.npy')), dtype='float32')
        illuminant = np.array(np.load(os.path.join(self.__path_to_label, file_name + '.npy')), dtype='float32')

        if self.__train:
            img, illuminant = self.__da.augment(img, illuminant)
        else:
            img = self.__da.crop(img)

        # Pipeline: normalize -> BGR to RGB -> linear to non-linear -> HWC to CHW
        # (semantics of each step are defined in auxiliary.utils).
        img = hwc_to_chw(linear_to_nonlinear(bgr_to_rgb(normalize(img))))

        # .copy() gives torch.from_numpy a fresh array; presumably needed because
        # the preceding transforms may return non-contiguous/negative-stride views.
        img = torch.from_numpy(img.copy())
        illuminant = torch.from_numpy(illuminant.copy())

        if not self.__train:
            img = img.type(torch.FloatTensor)

        return img, illuminant, file_name

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.__fold_data)