import os

import nibabel as nib
from torch.utils.data import Dataset


class AD_Dataset(Dataset):
    """Image dataset described by a whitespace-separated split file.

    Each non-blank line of ``data_file`` has the form ``<image_name> <label>``,
    where ``<image_name>`` is relative to ``root_dir`` and ``<label>`` is one of
    ``Normal``, ``AD`` or ``MCI``. Images are loaded with ``nibabel.load``.
    (Presumably AD = Alzheimer's disease and MCI = mild cognitive impairment;
    confirm against the data source.)
    """

    # Textual label in the split file -> integer class id returned in samples.
    LABEL_MAP = {'Normal': 0, 'AD': 1, 'MCI': 2}

    def __init__(self, root_dir, data_file, transform=None):
        """
        Args:
            root_dir (string): Directory of all the images.
            data_file (string): File name of the train/test split file.
            transform (callable, optional): Optional transform applied to the
                object returned by ``nibabel.load`` for each sample.
        """
        self.root_dir = root_dir
        self.data_file = data_file
        self.transform = transform
        # Parsed split-file rows, filled on first use so construction does no
        # I/O (matching the original, which only read the file lazily).
        self._entries = None

    def _read_entries(self):
        """Return the whitespace-split, non-blank lines of ``data_file`` (cached)."""
        if self._entries is None:
            # Read once and close the handle; the original re-opened (and
            # leaked) the file on every __len__ / __getitem__ call.
            with open(self.data_file) as split_file:
                self._entries = [line.split() for line in split_file if line.strip()]
        return self._entries

    def __len__(self):
        """Number of non-blank entries in the split file."""
        return len(self._read_entries())

    def __getitem__(self, idx):
        """Load the image at position ``idx`` of the split file.

        Returns:
            dict: ``{'image': <nibabel image, possibly transformed>,
            'label': <int class id from LABEL_MAP>}``.

        Raises:
            IndexError: if ``idx`` is out of range.
            ValueError: if the line has fewer than two fields or an unknown label.
        """
        fields = self._read_entries()[idx]
        if len(fields) < 2:
            raise ValueError(
                f"Malformed line in {self.data_file!r}: expected "
                f"'<image_name> <label>', got {' '.join(fields)!r}"
            )
        img_name, img_label = fields[0], fields[1]

        # Validate the label before the (comparatively expensive) image load;
        # the original hit UnboundLocalError for any unrecognised label.
        try:
            label = self.LABEL_MAP[img_label]
        except KeyError:
            raise ValueError(
                f"Unknown label {img_label!r} for image {img_name!r} in "
                f"{self.data_file!r}; expected one of {sorted(self.LABEL_MAP)}"
            ) from None

        image_path = os.path.join(self.root_dir, img_name)
        image = nib.load(image_path)

        if self.transform:
            image = self.transform(image)

        sample = {'image': image, 'label': label}
        return sample