NPRC24 / MiAlgo /classes /data /ColorCheckerDataset.py
Artyom
MiAlgo
82567db verified
import os
from typing import Tuple
import numpy as np
import scipy.io
import torch
import torch.utils.data as data
from auxiliary.utils import normalize, bgr_to_rgb, linear_to_nonlinear, hwc_to_chw
from classes.data.DataAugmenter import DataAugmenter
class ColorCheckerDataset(data.Dataset):
    """Dataset of preprocessed images and illuminant labels for one fold split.

    The items belonging to the requested split are selected at construction
    time from ``dataset/folds.mat`` and ``dataset/metadata.txt``. The arrays
    themselves are loaded lazily from ``dataset/preprocessed`` on each access.
    """

    def __init__(self, train: bool = True, folds_num: int = 1):
        """Select the metadata records that belong to the requested fold.

        Args:
            train: When True the "tr_split" entry of the folds file is used and
                items are augmented on access; otherwise "te_split" is used and
                items are only cropped.
            folds_num: Index of the fold to read from the selected split.
        """
        super().__init__()

        self.__train = train
        self.__da = DataAugmenter()

        path_to_folds = os.path.join("dataset", "folds.mat")
        path_to_metadata = os.path.join("dataset", "metadata.txt")
        self.__path_to_data = os.path.join("dataset", "preprocessed", "numpy_data")
        self.__path_to_label = os.path.join("dataset", "preprocessed", "numpy_labels")

        folds = scipy.io.loadmat(path_to_folds)
        img_idx = folds["tr_split" if self.__train else "te_split"][0][folds_num][0]

        # Read the metadata inside a context manager so the file handle is
        # closed deterministically instead of being left to the garbage collector.
        with open(path_to_metadata, 'r') as metadata_file:
            metadata = metadata_file.readlines()

        # The fold indices are shifted down by one before indexing the metadata
        # lines (presumably because the .mat file stores 1-based indices - verify
        # against the dataset).
        self.__fold_data = [metadata[i - 1] for i in img_idx]

    def __getitem__(self, index: int) -> Tuple:
        """Load one item of the fold.

        Args:
            index: Position of the item within the selected fold.

        Returns:
            A tuple ``(img, illuminant, file_name)`` where ``img`` and
            ``illuminant`` are torch tensors and ``file_name`` is the second
            space-separated field of the item's metadata line.
        """
        file_name = self.__fold_data[index].strip().split(' ')[1]

        img = np.array(np.load(os.path.join(self.__path_to_data, file_name + '.npy')), dtype='float32')
        illuminant = np.array(np.load(os.path.join(self.__path_to_label, file_name + '.npy')), dtype='float32')

        if self.__train:
            img, illuminant = self.__da.augment(img, illuminant)
        else:
            img = self.__da.crop(img)

        img = hwc_to_chw(linear_to_nonlinear(bgr_to_rgb(normalize(img))))

        # Copy before wrapping: the arrays are copied so that torch receives
        # freshly allocated buffers regardless of what the helpers returned.
        img = torch.from_numpy(img.copy())
        illuminant = torch.from_numpy(illuminant.copy())

        if not self.__train:
            img = img.type(torch.FloatTensor)

        return img, illuminant, file_name

    def __len__(self) -> int:
        """Return the number of items in the selected fold."""
        return len(self.__fold_data)