import nibabel as nib
import os
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import random
import torch

# Axis pairs averaged over to reduce the 3-D volume to a 1-D mean-intensity
# profile along the remaining axis.
NON_AX = (0, 1)
NON_COR = (0, 2)
NON_SAG = (1, 2)


class AD_Standard_3DRandomPatch(Dataset):
    """labeled Faces in the Wild dataset."""
    
    def __init__(self, root_dir, data_file):
        """
        Args:
            root_dir (string): Directory of all the images.
            data_file (string): File name of the train/test split file.
        """
        self.root_dir = root_dir
        self.data_file = data_file
    
    def __len__(self):
        with open(self.data_file) as df:
            summation = sum(1 for line in df)
        return summation
    
    def __getitem__(self, idx):
        with open(self.data_file) as df:
            lines = df.readlines()
            lst = lines[idx].split()
            img_name = lst[0]
            image_path = os.path.join(self.root_dir, img_name)
            image = nib.load(image_path)

            # get_fdata() replaces nibabel's deprecated get_data() accessor
            image_array = np.asarray(image.get_fdata())
            patch_samples = getRandomPatches(image_array)
            patch_dict = {"patch": patch_samples}
        return patch_dict
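
# Note: each line of data_file is expected to begin with a NIfTI file name
# (joined with root_dir); any further whitespace-separated fields, e.g. a
# label, are ignored by __getitem__. Hypothetical example line:
#   subj_001.nii 1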


def customToTensor(pic):
    # Convert a numpy array into a float32 tensor with a leading channel axis.
    if isinstance(pic, np.ndarray):
        img = torch.from_numpy(pic)
        img = torch.unsqueeze(img, 0)
        return img.float()
    raise TypeError("customToTensor expects a numpy.ndarray, got {}".format(type(pic)))

def getRandomPatches(image_array):
    patches = []
    # Mean intensity collapsed over each axis pair gives a 1-D profile along
    # the remaining axis; the first and last positive entries bound the region
    # that actually contains tissue.
    mean_ax = image_array.mean(axis=NON_AX)
    mean_cor = image_array.mean(axis=NON_COR)
    mean_sag = image_array.mean(axis=NON_SAG)

    nonzero_ax = np.where(mean_ax > 0)[0]
    nonzero_cor = np.where(mean_cor > 0)[0]
    nonzero_sag = np.where(mean_sag > 0)[0]

    first_ax, last_ax = int(nonzero_ax[0]), int(nonzero_ax[-1])
    first_cor, last_cor = int(nonzero_cor[0]), int(nonzero_cor[-1])
    first_sag, last_sag = int(nonzero_sag[0]), int(nonzero_sag[-1])

    # Narrow the axial range to skip slices near the edges of the volume.
    first_ax = first_ax + 20
    last_ax = last_ax - 5

    # Draw candidate centre coordinates for the 7x7x7 patches.
    ax_samples = [random.randint(first_ax - 3, last_ax - 3) for _ in range(10000)]
    cor_samples = [random.randint(first_cor - 3, last_cor - 3) for _ in range(10000)]
    sag_samples = [random.randint(first_sag - 3, last_sag - 3) for _ in range(10000)]

    # Keep the first 1000 candidates, resampling any patch that is all zeros.
    for i in range(1000):
        ax_i = ax_samples[i]
        cor_i = cor_samples[i]
        sag_i = sag_samples[i]
        patch = image_array[ax_i - 3:ax_i + 4, cor_i - 3:cor_i + 4, sag_i - 3:sag_i + 4]
        while patch.sum() == 0:
            ax_ni = random.randint(first_ax - 3, last_ax - 4)
            cor_ni = random.randint(first_cor - 3, last_cor - 4)
            sag_ni = random.randint(first_sag - 3, last_sag - 4)
            patch = image_array[ax_ni - 3:ax_ni + 4, cor_ni - 3:cor_ni + 4, sag_ni - 3:sag_ni + 4]
        patch = customToTensor(patch)
        patches.append(patch)
    return patches


# plt.imshow(array[i][3,:,:], cmap = 'gray')
# plt.savefig('./section.png', dpi=100)
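

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the root_dir and data_file paths below
# are placeholders, and the split file follows the format described above).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dataset = AD_Standard_3DRandomPatch(root_dir="./data", data_file="train_split.txt")
    print("subjects:", len(dataset))
    sample = dataset[0]
    patches = sample["patch"]  # list of 1000 tensors, each of shape (1, 7, 7, 7)
    print("patches per subject:", len(patches), "patch shape:", tuple(patches[0].shape))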