import os
import torch
import pandas as pd 
from torch.utils.data import Dataset
import nibabel as nib 
from monai.transforms import (
    Affined,
    AdjustContrastd,
    Flipd,
    GaussianSmoothd,
    RandGaussianNoised,
    Resized,
    Rotate90d,
    ScaleIntensityd,
    StdShiftIntensityd,
    ToTensord,
)
import random 
import numpy as np

    
#######################################
##      3D SYNC TRANSFORM
#######################################

class NormalSynchronizedTransform3D:
    """ Vanilla Validation Transforms"""

    def __init__(self, image_size=(128,128,128), max_rotation=40, translate_range=0.2, scale_range=(0.9, 1.3), apply_prob=0.5):
        self.image_size = image_size
        self.max_rotation = max_rotation
        self.translate_range = translate_range
        self.scale_range = scale_range
        self.apply_prob = apply_prob

    def __call__(self, scan_list):
        transformed_scans = []
        resize_transform = Resized(keys=["image"], spatial_size=self.image_size)
        scale_transform = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0)  # Intensity scaling
        tensor_transform = ToTensord(keys=["image"])  # Convert to tensor

        for scan in scan_list:
            sample = {"image": scan}
            sample = resize_transform(sample)
            sample = scale_transform(sample)
            sample = tensor_transform(sample)
            transformed_scans.append(sample["image"].squeeze())
            
        return torch.stack(transformed_scans)
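
# --- Hedged usage sketch (not part of the original pipeline) ---------------
# Shows the expected input/output shapes of the validation transform. The
# helper name and the dummy volume shapes are illustrative assumptions only.
def _demo_validation_transform():
    dummy_scans = [torch.rand(1, 96, 96, 96) for _ in range(2)]  # channel-first dummy volumes
    transform = NormalSynchronizedTransform3D()
    stacked = transform(dummy_scans)
    # Each scan is resized to 128^3, intensity-scaled to [0, 1], then stacked:
    # stacked.shape == torch.Size([2, 128, 128, 128])
    return stacked
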

class MedicalImageDatasetBalancedIntensity3D(Dataset):
    """ Validation Dataset class """

    def __init__(self, csv_path, root_dir, transform=None):
        self.dataframe = pd.read_csv(csv_path, dtype={"pat_id":str, "scandate":str})
        self.root_dir = root_dir
        self.transform = transform if transform is not None else NormalSynchronizedTransform3D()

    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        ## load the NIfTIs listed in the csv
        pat_id = str(self.dataframe.loc[idx, 'pat_id'])
        scan_dates = str(self.dataframe.loc[idx, 'scandate'])
        label = self.dataframe.loc[idx, 'label']
        scandates = scan_dates.split('-')
        scan_list = []


        for scandate in scandates:
            img_name = os.path.join(self.root_dir, f"{pat_id}_{scandate}.nii.gz")
            scan = nib.load(img_name).get_fdata()
            scan_list.append(torch.tensor(scan, dtype=torch.float32).unsqueeze(0))  # add channel dim

        ## package into a dictionary for val loader
        transformed_scans = self.transform(scan_list)
        sample = {"image": transformed_scans, "label": torch.tensor(label, dtype=torch.float32), "pat_id": pat_id}  
        return sample


class SynchronizedTransform3D:
    """ Trainign Augmentation method """

    def __init__(self, image_size=(128,128,128), max_rotation=0.34, translate_range=15, scale_range=(0.9, 1.3), apply_prob=0.5, gaussian_sigma_range=(0.25, 1.5), gaussian_noise_std_range=(0.05, 0.09)):
        self.image_size = image_size
        self.max_rotation = max_rotation
        self.translate_range = translate_range
        self.scale_range = scale_range
        self.apply_prob = apply_prob
        self.gaussian_sigma_range = gaussian_sigma_range
        self.gaussian_noise_std_range = gaussian_noise_std_range

    def __call__(self, scan_list):
        transformed_scans = []
        rotate_params = (random.uniform(-self.max_rotation, self.max_rotation),) * 3 if random.random() < self.apply_prob else (0, 0, 0)
        translate_params = tuple([random.uniform(-self.translate_range, self.translate_range) for _ in range(3)]) if random.random() < self.apply_prob else (0, 0, 0)
        scale_params = tuple([random.uniform(self.scale_range[0], self.scale_range[1]) for _ in range(3)]) if random.random() < self.apply_prob else (1, 1, 1)
        gaussian_sigma = tuple([random.uniform(self.gaussian_sigma_range[0], self.gaussian_sigma_range[1]) for _ in range(3)]) if random.random() < self.apply_prob else None
        gaussian_noise_std = random.uniform(self.gaussian_noise_std_range[0], self.gaussian_noise_std_range[1]) if random.random() < self.apply_prob else None
        flip_axes = (0,1) if random.random() < self.apply_prob else None  # only used by the commented-out Rotate90d alternative below
        flip_x = 0 if random.random() < self.apply_prob else None
        flip_y = 1 if random.random() < self.apply_prob else None
        flip_z = 2 if random.random() < self.apply_prob else None
        offset = random.randint(50,100) if random.random() < self.apply_prob else None 
        gammafactor = random.uniform(0.5,2.0) if random.random() < self.apply_prob else 1 

        affine_transform = Affined(keys=["image"], rotate_params=rotate_params, translate_params=translate_params, scale_params=scale_params, padding_mode='zeros')
        gaussian_blur_transform = GaussianSmoothd(keys=["image"], sigma=gaussian_sigma) if gaussian_sigma else None
        gaussian_noise_transform = RandGaussianNoised(keys=["image"], std=gaussian_noise_std, prob=1.0, mean=0.0, sample_std=False) if gaussian_noise_std else None
        #flip_transform = Rotate90d(keys=["image"], k=1, spatial_axes=flip_axes) if flip_axes else None
        # compare against None explicitly: 0 is a valid flip axis
        flip_x_transform = Flipd(keys=["image"], spatial_axis=flip_x) if flip_x is not None else None
        flip_y_transform = Flipd(keys=["image"], spatial_axis=flip_y) if flip_y is not None else None
        flip_z_transform = Flipd(keys=["image"], spatial_axis=flip_z) if flip_z is not None else None
        resize_transform = Resized(keys=["image"], spatial_size=self.image_size)
        scale_transform = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0)  # Intensity scaling
        tensor_transform = ToTensord(keys=["image"])  # Convert to tensor
        shift_intensity = StdShiftIntensityd(keys=["image"], factor=offset, nonzero=True) if offset is not None else None
        adjust_contrast = AdjustContrastd(keys=["image"], gamma=gammafactor)

        for scan in scan_list:
            sample = {"image": scan}
            sample = resize_transform(sample)
            sample = affine_transform(sample)
            if flip_x_transform:
                sample = flip_x_transform(sample)
            if flip_y_transform:
                sample = flip_y_transform(sample)
            if flip_z_transform:
                sample = flip_z_transform(sample)
            if gaussian_blur_transform:
                sample = gaussian_blur_transform(sample) 
            if shift_intensity:
                sample = shift_intensity(sample)
            sample = scale_transform(sample)
            sample = adjust_contrast(sample)
            if gaussian_noise_transform:
                sample = gaussian_noise_transform(sample)
            sample = tensor_transform(sample)
            transformed_scans.append(sample["image"].squeeze())
            
        return torch.stack(transformed_scans)
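
# --- Hedged usage sketch (not part of the original pipeline) ---------------
# Demonstrates that one set of randomly drawn parameters is applied to every
# scan in the list, i.e. longitudinal scans stay spatially aligned under
# augmentation. Names and dummy shapes are illustrative assumptions only.
def _demo_training_transform():
    dummy_scans = [torch.rand(1, 96, 96, 96) for _ in range(2)]
    transform = SynchronizedTransform3D(apply_prob=0.5)
    augmented = transform(dummy_scans)
    # augmented.shape == torch.Size([2, 128, 128, 128])
    return augmented
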


class TransformationMedicalImageDatasetBalancedIntensity3D(Dataset):
    """ Training Dataset class """

    def __init__(self, csv_path, root_dir, transform=None):
        self.dataframe = pd.read_csv(csv_path, dtype={"pat_id":str, "scandate":str})
        self.root_dir = root_dir
        self.transform = transform if transform is not None else SynchronizedTransform3D()  # training augmentations

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        ## load the niftis from csv
        pat_id = str(self.dataframe.loc[idx, 'pat_id'])
        scan_dates = str(self.dataframe.loc[idx, 'scandate'])
        label = self.dataframe.loc[idx, 'label']
        scandates = scan_dates.split('-')
        scan_list = []


        for scandate in scandates:
            img_name = os.path.join(self.root_dir, f"{pat_id}_{scandate}.nii.gz")
            scan = nib.load(img_name).get_fdata()
            scan_list.append(torch.tensor(scan, dtype=torch.float32).unsqueeze(0))  # add channel dim

        ## package into a dictionary for the train loader
        transformed_scans = self.transform(scan_list) 
        sample = {"image": transformed_scans, "label": torch.tensor(label, dtype=torch.float32), "pat_id": pat_id}  
        return sample
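

# --- Hedged end-to-end sketch (not part of the original code) --------------
# Shows how the two dataset classes are typically wired into DataLoaders.
# The CSV/NIfTI paths are placeholders; the CSV is assumed to contain
# 'pat_id', 'scandate' (dates joined by '-') and 'label' columns, matching
# what __getitem__ reads above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    train_dataset = TransformationMedicalImageDatasetBalancedIntensity3D(
        csv_path="train_split.csv",       # placeholder path
        root_dir="/path/to/nifti_scans",  # placeholder path
    )
    val_dataset = MedicalImageDatasetBalancedIntensity3D(
        csv_path="val_split.csv",         # placeholder path
        root_dir="/path/to/nifti_scans",  # placeholder path
    )

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4)

    batch = next(iter(train_loader))
    print(batch["image"].shape, batch["label"].shape, batch["pat_id"])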