|
import os |
|
from typing import Tuple |
|
|
|
import numpy as np |
|
import scipy.io |
|
import torch |
|
import torch.utils.data as data |
|
|
|
from auxiliary.utils import normalize, bgr_to_rgb, linear_to_nonlinear, hwc_to_chw |
|
from classes.data.DataAugmenter import DataAugmenter |
|
|
|
|
|
class ColorCheckerDataset(data.Dataset):
    """Dataset of preprocessed images and their illuminant labels, split by folds.

    Expected layout under ``path_to_dataset``:
        folds.mat                     -- per-fold train ("tr_split") / test ("te_split") image indices
        metadata.txt                  -- one line per image; the 2nd whitespace-separated token is the file name
        preprocessed/numpy_data/      -- <file_name>.npy image arrays
        preprocessed/numpy_labels/    -- <file_name>.npy illuminant arrays
    """

    def __init__(self, train: bool = True, folds_num: int = 1, path_to_dataset: str = "dataset"):
        """Select the images belonging to one fold.

        Args:
            train: if True, use the fold's training split and augment samples;
                otherwise use the test split and only crop samples.
            folds_num: zero-based index of the fold to load from ``folds.mat``.
            path_to_dataset: root directory of the dataset (defaults to the
                original hard-coded ``"dataset"``, relative to the working directory).
        """
        self.__train = train
        self.__da = DataAugmenter()

        path_to_folds = os.path.join(path_to_dataset, "folds.mat")
        path_to_metadata = os.path.join(path_to_dataset, "metadata.txt")
        self.__path_to_data = os.path.join(path_to_dataset, "preprocessed", "numpy_data")
        self.__path_to_label = os.path.join(path_to_dataset, "preprocessed", "numpy_labels")

        folds = scipy.io.loadmat(path_to_folds)
        # loadmat returns the split as a nested object array; [0][folds_num][0]
        # picks out the index vector of the requested fold.
        img_idx = folds["tr_split" if self.__train else "te_split"][0][folds_num][0]

        # Close the metadata file deterministically instead of leaking the handle.
        with open(path_to_metadata, "r") as metadata_file:
            metadata = metadata_file.readlines()

        # The stored indices are 1-based (hence "- 1"); int() makes list indexing
        # work regardless of the numpy dtype the .mat file was saved with.
        self.__fold_data = [metadata[int(i) - 1] for i in img_idx]

    @staticmethod
    def _file_name_from_metadata(line: str) -> str:
        """Return the file-name field (2nd whitespace-separated token) of a metadata line.

        split() without an argument tolerates repeated spaces/tabs and the
        trailing newline, unlike split(' ').
        """
        return line.split()[1]

    def __getitem__(self, index: int) -> Tuple:
        """Load and preprocess the sample at ``index``.

        Returns:
            A tuple ``(img, illuminant, file_name)``: ``img`` is a float32 tensor
            produced by normalize -> bgr_to_rgb -> linear_to_nonlinear ->
            hwc_to_chw, ``illuminant`` is the label tensor, and ``file_name`` is
            the sample's name taken from the metadata file.
        """
        file_name = self._file_name_from_metadata(self.__fold_data[index])
        img = np.array(np.load(os.path.join(self.__path_to_data, file_name + '.npy')), dtype='float32')
        illuminant = np.array(np.load(os.path.join(self.__path_to_label, file_name + '.npy')), dtype='float32')

        if self.__train:
            # Augmentation may modify the label as well as the image.
            img, illuminant = self.__da.augment(img, illuminant)
        else:
            img = self.__da.crop(img)

        img = hwc_to_chw(linear_to_nonlinear(bgr_to_rgb(normalize(img))))

        # .copy() yields a fresh array with positive strides: torch.from_numpy
        # cannot wrap negatively-strided views, which channel-flipping helpers may produce.
        # Cast to float32 in both modes so train and test batches share a dtype.
        img = torch.from_numpy(img.copy()).float()
        illuminant = torch.from_numpy(illuminant.copy())

        return img, illuminant, file_name

    def __len__(self) -> int:
        """Return the number of images in the selected fold/split."""
        return len(self.__fold_data)
|
|