|
import nibabel as nib |
|
import os |
|
from torch.utils.data import Dataset |
|
|
|
class AD_Dataset(Dataset):
    """Image dataset described by a whitespace-separated split file.

    Each non-blank line of ``data_file`` names an image file (relative to
    ``root_dir``) followed by its label, one of ``'Normal'``, ``'AD'`` or
    ``'MCI'`` (presumably healthy control / Alzheimer's disease / mild
    cognitive impairment -- confirm with the data owner). Images are opened
    with ``nib.load`` and each sample is returned as a dict
    ``{'image': <loaded, optionally transformed image>, 'label': int}``.
    """

    # Label string as written in the split file -> integer class index.
    LABEL_TO_INDEX = {'Normal': 0, 'AD': 1, 'MCI': 2}

    def __init__(self, root_dir, data_file, transform=None):
        """
        Args:
            root_dir (string): Directory of all the images.
            data_file (string): File name of the train/test split file.
            transform (callable, optional): Optional transform applied to the
                object returned by ``nib.load`` before it is put in the sample.

        Raises:
            OSError: if ``data_file`` cannot be opened.
        """
        self.root_dir = root_dir
        self.data_file = data_file
        self.transform = transform
        # Parse the split file once, instead of re-opening and re-reading it
        # on every __len__/__getitem__ call (which leaked file handles and made
        # an epoch quadratic in the file size). Blank lines are dropped so the
        # reported length equals the number of indexable entries.
        self._entries = self._read_entries(data_file)

    @staticmethod
    def _read_entries(data_file):
        """Return the whitespace-split fields of every non-blank line."""
        with open(data_file) as f:
            return [fields for fields in (line.split() for line in f) if fields]

    @classmethod
    def _encode_label(cls, img_label):
        """Map a label string to its class index.

        Raises:
            ValueError: if ``img_label`` is not a known label (previously this
                surfaced as an UnboundLocalError).
        """
        try:
            return cls.LABEL_TO_INDEX[img_label]
        except KeyError:
            raise ValueError(
                f"Unknown label {img_label!r}; expected one of "
                f"{sorted(cls.LABEL_TO_INDEX)}"
            ) from None

    def __len__(self):
        """Return the number of non-blank entries in the split file."""
        return len(self._entries)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``{'image', 'label'}``.

        Raises:
            IndexError: if ``idx`` is out of range.
            ValueError: if the entry has fewer than two fields or an unknown label.
        """
        fields = self._entries[idx]
        if len(fields) < 2:
            raise ValueError(
                f"Malformed line in {self.data_file!r}: expected "
                f"'<image name> <label>', got {' '.join(fields)!r}"
            )
        img_name, img_label = fields[0], fields[1]
        # Validate the label before touching the (potentially large) image file.
        label = self._encode_label(img_label)

        image_path = os.path.join(self.root_dir, img_name)
        image = nib.load(image_path)

        if self.transform is not None:
            image = self.transform(image)

        return {'image': image, 'label': label}