Spaces:
Running
Running
File size: 8,274 Bytes
f5288df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import os
import torch
import pandas as pd
from torch.utils.data import Dataset
import nibabel as nib
from monai.transforms import Affined, RandGaussianNoised, Rand3DElasticd, AdjustContrastd, ScaleIntensityd, ToTensord, Resized, RandRotate90d, Resize, RandGaussianSmoothd, GaussianSmoothd, Rotate90d, StdShiftIntensityd, RandAdjustContrastd, Flipd
import random
import numpy as np
#######################################
## 3D SYNC TRANSFORM
#######################################
class NormalSynchronizedTransform3D:
    """Deterministic validation preprocessing for a list of 3D scans.

    Each scan is resized to ``image_size``, intensity-scaled to ``[0, 1]``
    and converted to a tensor. No random augmentation is applied.

    Args:
        image_size: Target spatial size handed to ``Resized``.
        max_rotation: Stored on the instance; not used by ``__call__``.
        translate_range: Stored on the instance; not used by ``__call__``.
        scale_range: Stored on the instance; not used by ``__call__``.
        apply_prob: Stored on the instance; not used by ``__call__``.
    """

    def __init__(
        self,
        image_size=(128, 128, 128),
        max_rotation=40,
        translate_range=0.2,
        scale_range=(0.9, 1.3),
        apply_prob=0.5,
    ):
        self.image_size = image_size
        # The augmentation-style arguments below are kept only so the
        # constructor stays call-compatible; __call__ never reads them.
        self.max_rotation = max_rotation
        self.translate_range = translate_range
        self.scale_range = scale_range
        self.apply_prob = apply_prob

    def __call__(self, scan_list):
        """Preprocess every scan and stack the results.

        Args:
            scan_list: Iterable of scans. The dataset classes in this file
                pass channel-first tensors built with ``unsqueeze(0)``.
                NOTE(review): other callers are assumed to follow the same
                layout -- confirm.

        Returns:
            ``torch.stack`` of the processed scans, each passed through
            ``.squeeze()`` first (so all size-1 dimensions are dropped).
        """
        keys = ["image"]
        steps = (
            # Honour the configured size instead of a hard-coded 128^3.
            Resized(keys=keys, spatial_size=self.image_size),
            ScaleIntensityd(keys=keys, minv=0.0, maxv=1.0),  # intensity scaling
            ToTensord(keys=keys),  # convert to tensor
        )

        processed = []
        for scan in scan_list:
            sample = {"image": scan}
            for step in steps:
                sample = step(sample)
            processed.append(sample["image"].squeeze())
        return torch.stack(processed)
class MedicalImageDatasetBalancedIntensity3D(Dataset):
    """Validation dataset that loads one or more NIfTI scans per CSV row.

    The CSV must provide the columns ``pat_id``, ``scandate`` and ``label``.
    ``scandate`` may hold several dates joined by ``-``; one file named
    ``{pat_id}_{scandate}.nii.gz`` is loaded from ``root_dir`` for each date.

    Args:
        csv_path: Path (or file-like object) accepted by ``pandas.read_csv``.
        root_dir: Directory containing the ``.nii.gz`` files.
        transform: Optional callable taking the list of loaded scans and
            returning the transformed batch. When ``None`` (the default) a
            ``NormalSynchronizedTransform3D()`` is used.
    """

    def __init__(self, csv_path, root_dir, transform=None):
        # Read the identifiers as strings so leading zeros are preserved.
        self.dataframe = pd.read_csv(csv_path, dtype={"pat_id": str, "scandate": str})
        self.root_dir = root_dir
        # Previously the argument was silently discarded; honour it when
        # given and fall back to the validation transform otherwise.
        self.transform = transform if transform is not None else NormalSynchronizedTransform3D()

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load, transform and package the scans of row ``idx``.

        Returns:
            dict with keys ``"image"`` (output of ``self.transform``),
            ``"label"`` (float32 tensor) and ``"pat_id"`` (str).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Load the NIfTI files referenced by this CSV row.
        pat_id = str(self.dataframe.loc[idx, "pat_id"])
        scan_dates = str(self.dataframe.loc[idx, "scandate"])
        label = self.dataframe.loc[idx, "label"]

        scan_list = []
        for scandate in scan_dates.split("-"):
            img_name = os.path.join(self.root_dir, f"{pat_id}_{scandate}.nii.gz")
            volume = nib.load(img_name).get_fdata()
            # unsqueeze(0) adds a leading dimension of size 1 in front of
            # the loaded volume before it is handed to the transform.
            scan_list.append(torch.tensor(volume, dtype=torch.float32).unsqueeze(0))

        # Package into a dictionary for the validation loader.
        transformed_scans = self.transform(scan_list)
        return {
            "image": transformed_scans,
            "label": torch.tensor(label, dtype=torch.float32),
            "pat_id": pat_id,
        }
class SynchronizedTransform3D:
    """Training augmentation that shares one set of parameters per call.

    All augmentation parameters are sampled once at the start of
    ``__call__`` and the resulting transforms are applied to every scan in
    the list, so the scans of one sample receive the same geometric and
    intensity changes.

    Args:
        image_size: Target spatial size handed to ``Resized``.
        max_rotation: Bound of the uniform range for the rotation value
            passed to ``Affined(rotate_params=...)``; the same sampled value
            is used for all three entries. NOTE(review): MONAI documents
            ``rotate_params`` in radians -- confirm 0.34 is meant as radians.
        translate_range: Bound of the uniform range for each of the three
            ``translate_params`` entries.
        scale_range: ``(low, high)`` uniform range for each of the three
            ``scale_params`` entries.
        apply_prob: Probability with which each individual augmentation
            (rotation, translation, scaling, blur, noise, each axis flip,
            intensity shift, gamma) is enabled.
        gaussian_sigma_range: ``(low, high)`` uniform range for each of the
            three ``GaussianSmoothd`` sigmas.
        gaussian_noise_std_range: ``(low, high)`` uniform range for the
            ``RandGaussianNoised`` standard deviation.
    """

    def __init__(
        self,
        image_size=(128, 128, 128),
        max_rotation=0.34,
        translate_range=15,
        scale_range=(0.9, 1.3),
        apply_prob=0.5,
        gaussian_sigma_range=(0.25, 1.5),
        gaussian_noise_std_range=(0.05, 0.09),
    ):
        self.image_size = image_size
        self.max_rotation = max_rotation
        self.translate_range = translate_range
        self.scale_range = scale_range
        self.apply_prob = apply_prob
        self.gaussian_sigma_range = gaussian_sigma_range
        self.gaussian_noise_std_range = gaussian_noise_std_range

    def _maybe(self, sampler, default=None):
        """Return ``sampler()`` with probability ``apply_prob``, else ``default``."""
        # The coin flip is drawn before the value, as in the inline form
        # ``value if random.random() < p else default``.
        return sampler() if random.random() < self.apply_prob else default

    def __call__(self, scan_list):
        """Augment every scan with the same sampled parameters and stack them.

        Args:
            scan_list: Iterable of scans. The dataset classes in this file
                pass channel-first tensors built with ``unsqueeze(0)``.
                NOTE(review): other callers are assumed to follow the same
                layout -- confirm.

        Returns:
            ``torch.stack`` of the augmented scans, each passed through
            ``.squeeze()`` first.
        """
        keys = ["image"]

        # ---- sample the parameters once for the whole scan list ----
        rotate_params = self._maybe(
            lambda: (random.uniform(-self.max_rotation, self.max_rotation),) * 3,
            (0, 0, 0),
        )
        translate_params = self._maybe(
            lambda: tuple(
                random.uniform(-self.translate_range, self.translate_range)
                for _ in range(3)
            ),
            (0, 0, 0),
        )
        scale_params = self._maybe(
            lambda: tuple(random.uniform(*self.scale_range) for _ in range(3)),
            (1, 1, 1),
        )
        gaussian_sigma = self._maybe(
            lambda: tuple(
                random.uniform(*self.gaussian_sigma_range) for _ in range(3)
            )
        )
        gaussian_noise_std = self._maybe(
            lambda: random.uniform(*self.gaussian_noise_std_range)
        )
        # One independent coin flip per spatial axis. Axis 0 is a valid
        # choice, so selection must not rely on the truthiness of the axis
        # index (that made the axis-0 flip unreachable before).
        flip_axes = [axis for axis in (0, 1, 2) if random.random() < self.apply_prob]
        offset = self._maybe(lambda: random.randint(50, 100))
        gammafactor = self._maybe(lambda: random.uniform(0.5, 2.0), 1)

        # ---- assemble the pipeline in its fixed application order ----
        pipeline = [
            Resized(keys=keys, spatial_size=self.image_size),
            Affined(
                keys=keys,
                rotate_params=rotate_params,
                translate_params=translate_params,
                scale_params=scale_params,
                padding_mode="zeros",
            ),
        ]
        pipeline.extend(Flipd(keys=keys, spatial_axis=axis) for axis in flip_axes)
        if gaussian_sigma is not None:
            pipeline.append(GaussianSmoothd(keys=keys, sigma=gaussian_sigma))
        if offset is not None:
            pipeline.append(StdShiftIntensityd(keys=keys, factor=offset, nonzero=True))
        pipeline.append(ScaleIntensityd(keys=keys, minv=0.0, maxv=1.0))  # intensity scaling
        # Gamma adjustment is always applied; gamma is 1 when not sampled.
        pipeline.append(AdjustContrastd(keys=keys, gamma=gammafactor))
        if gaussian_noise_std is not None:
            # The std is shared across scans. NOTE(review): the noise values
            # themselves are presumably drawn anew by MONAI for each scan.
            pipeline.append(
                RandGaussianNoised(
                    keys=keys,
                    std=gaussian_noise_std,
                    prob=1.0,
                    mean=0.0,
                    sample_std=False,
                )
            )
        pipeline.append(ToTensord(keys=keys))  # convert to tensor

        # ---- apply the identical pipeline to every scan ----
        transformed_scans = []
        for scan in scan_list:
            sample = {"image": scan}
            for step in pipeline:
                sample = step(sample)
            transformed_scans.append(sample["image"].squeeze())
        return torch.stack(transformed_scans)
class TransformationMedicalImageDatasetBalancedIntensity3D(Dataset):
    """Training dataset that loads one or more NIfTI scans per CSV row.

    The CSV must provide the columns ``pat_id``, ``scandate`` and ``label``.
    ``scandate`` may hold several dates joined by ``-``; one file named
    ``{pat_id}_{scandate}.nii.gz`` is loaded from ``root_dir`` for each date.

    Args:
        csv_path: Path (or file-like object) accepted by ``pandas.read_csv``.
        root_dir: Directory containing the ``.nii.gz`` files.
        transform: Optional callable taking the list of loaded scans and
            returning the transformed batch. When ``None`` (the default) a
            ``SynchronizedTransform3D()`` with training augmentations is used.
    """

    def __init__(self, csv_path, root_dir, transform=None):
        # Read the identifiers as strings so leading zeros are preserved.
        self.dataframe = pd.read_csv(csv_path, dtype={"pat_id": str, "scandate": str})
        self.root_dir = root_dir
        # Previously the argument was silently discarded; honour it when
        # given and fall back to the training augmentations otherwise.
        self.transform = transform if transform is not None else SynchronizedTransform3D()

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load, augment and package the scans of row ``idx``.

        Returns:
            dict with keys ``"image"`` (output of ``self.transform``),
            ``"label"`` (float32 tensor) and ``"pat_id"`` (str).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Load the NIfTI files referenced by this CSV row.
        pat_id = str(self.dataframe.loc[idx, "pat_id"])
        scan_dates = str(self.dataframe.loc[idx, "scandate"])
        label = self.dataframe.loc[idx, "label"]

        scan_list = []
        for scandate in scan_dates.split("-"):
            img_name = os.path.join(self.root_dir, f"{pat_id}_{scandate}.nii.gz")
            volume = nib.load(img_name).get_fdata()
            # unsqueeze(0) adds a leading dimension of size 1 in front of
            # the loaded volume before it is handed to the transform.
            scan_list.append(torch.tensor(volume, dtype=torch.float32).unsqueeze(0))

        # Package into a MONAI-style dictionary.
        transformed_scans = self.transform(scan_list)
        return {
            "image": transformed_scans,
            "label": torch.tensor(label, dtype=torch.float32),
            "pat_id": pat_id,
        }
|