Divyanshu Tak
Initial commit of BrainIAC Docker application
f5288df
import os
import torch
import pandas as pd
from torch.utils.data import Dataset
import nibabel as nib
from monai.transforms import Affined, RandGaussianNoised, Rand3DElasticd, AdjustContrastd, ScaleIntensityd, ToTensord, Resized, RandRotate90d, Resize, RandGaussianSmoothd, GaussianSmoothd, Rotate90d, StdShiftIntensityd, RandAdjustContrastd, Flipd
import random
import numpy as np
#######################################
## 3D SYNC TRANSFORM
#######################################
class NormalSynchronizedTransform3D:
    """Vanilla (validation) preprocessing applied identically to every scan.

    Each scan is resized to ``image_size``, min-max scaled to ``[0, 1]`` with
    ``ScaleIntensityd`` and converted to a tensor. No random augmentation is
    applied.

    Args:
        image_size: spatial size every scan is resized to.
        max_rotation, translate_range, scale_range, apply_prob: accepted and
            stored on the instance, but not used by this transform.
    """

    def __init__(self, image_size=(128, 128, 128), max_rotation=40, translate_range=0.2,
                 scale_range=(0.9, 1.3), apply_prob=0.5):
        self.image_size = image_size
        # The remaining attributes are stored only; __call__ does not read them.
        self.max_rotation = max_rotation
        self.translate_range = translate_range
        self.scale_range = scale_range
        self.apply_prob = apply_prob

    def __call__(self, scan_list):
        """Preprocess each scan in ``scan_list`` and stack the results.

        Args:
            scan_list: sequence of image tensors. The dataset classes in this
                module pass tensors with a leading axis of size 1.

        Returns:
            ``torch.Tensor`` stacking the squeezed, preprocessed scans along a
            new first dimension, one entry per element of ``scan_list``.
        """
        # Resize to the configured size instead of a hard-coded 128^3.
        resize_transform = Resized(keys=["image"], spatial_size=self.image_size)
        scale_transform = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0)  # min-max to [0, 1]
        tensor_transform = ToTensord(keys=["image"])
        pipeline = (resize_transform, scale_transform, tensor_transform)

        transformed_scans = []
        for scan in scan_list:
            sample = {"image": scan}
            for transform in pipeline:
                sample = transform(sample)
            # squeeze() drops every singleton axis, including the channel axis.
            transformed_scans.append(sample["image"].squeeze())
        return torch.stack(transformed_scans)
class MedicalImageDatasetBalancedIntensity3D(Dataset):
    """Validation dataset: one sample per CSV row, built from one or more scans.

    The CSV is read with ``pat_id`` and ``scandate`` as strings, and
    ``__getitem__`` reads the ``pat_id``, ``scandate`` and ``label`` columns.
    ``scandate`` may hold several dates joined by ``-``. Each date is loaded
    from ``os.path.join(root_dir, f"{pat_id}_{date}.nii.gz")``.

    Args:
        csv_path: path to the CSV index.
        root_dir: directory holding the NIfTI files.
        transform: callable applied to the list of loaded scans. When ``None``
            (the default), ``NormalSynchronizedTransform3D()`` is used.
    """

    def __init__(self, csv_path, root_dir, transform=None):
        # Parse IDs/dates as strings (not numbers) so file names are rebuilt
        # exactly, e.g. leading zeros are preserved.
        self.dataframe = pd.read_csv(csv_path, dtype={"pat_id": str, "scandate": str})
        self.root_dir = root_dir
        # Honour a caller-supplied transform; fall back to the validation pipeline.
        self.transform = transform if transform is not None else NormalSynchronizedTransform3D()

    def __len__(self):
        return len(self.dataframe)

    def _load_scans(self, pat_id, scan_dates):
        """Load ``{pat_id}_{date}.nii.gz`` for each ``-``-separated date as float32 tensors.

        Each tensor gets a new leading axis of size 1 via ``unsqueeze(0)``.
        """
        scan_list = []
        for scandate in scan_dates.split("-"):
            img_name = os.path.join(self.root_dir, f"{pat_id}_{scandate}.nii.gz")
            scan = nib.load(img_name).get_fdata()
            scan_list.append(torch.tensor(scan, dtype=torch.float32).unsqueeze(0))
        return scan_list

    def __getitem__(self, idx):
        """Return ``{"image", "label", "pat_id"}`` for row ``idx``.

        ``image`` is whatever ``self.transform`` returns for the scan list,
        ``label`` is a float32 tensor and ``pat_id`` is a string.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        pat_id = str(self.dataframe.loc[idx, "pat_id"])
        scan_dates = str(self.dataframe.loc[idx, "scandate"])
        label = self.dataframe.loc[idx, "label"]

        transformed_scans = self.transform(self._load_scans(pat_id, scan_dates))
        return {
            "image": transformed_scans,
            "label": torch.tensor(label, dtype=torch.float32),
            "pat_id": pat_id,
        }
class SynchronizedTransform3D:
    """Training augmentation applied with one shared parameter set to every scan.

    Random parameters are drawn once per ``__call__``, so every scan in the list
    receives the same resize, affine, flip, blur, intensity-shift and contrast
    settings. Each augmentation is enabled independently with probability
    ``apply_prob``. Gaussian noise shares one standard deviation across scans,
    but the noise field itself is drawn by ``RandGaussianNoised`` each time it
    is applied, i.e. separately for each scan.

    Args:
        image_size: spatial size every scan is resized to before augmenting.
        max_rotation: rotation angle drawn from ``U(-max_rotation, max_rotation)``.
        translate_range: per-axis translation drawn from
            ``U(-translate_range, translate_range)``.
        scale_range: ``(low, high)`` bounds of the per-axis scale factors.
        apply_prob: probability that each individual augmentation is enabled.
        gaussian_sigma_range: ``(low, high)`` bounds of the per-axis blur sigma.
        gaussian_noise_std_range: ``(low, high)`` bounds of the noise std.
    """

    def __init__(self, image_size=(128, 128, 128), max_rotation=0.34, translate_range=15,
                 scale_range=(0.9, 1.3), apply_prob=0.5, gaussian_sigma_range=(0.25, 1.5),
                 gaussian_noise_std_range=(0.05, 0.09)):
        self.image_size = image_size
        self.max_rotation = max_rotation
        self.translate_range = translate_range
        self.scale_range = scale_range
        self.apply_prob = apply_prob
        self.gaussian_sigma_range = gaussian_sigma_range
        self.gaussian_noise_std_range = gaussian_noise_std_range

    def _sample_params(self):
        """Draw one set of augmentation parameters for a single ``__call__``.

        Returns:
            dict with keys ``rotate``, ``translate``, ``scale``, ``sigma``,
            ``noise_std``, ``flip_axes``, ``offset`` and ``gamma``. Disabled
            augmentations get a neutral value: ``(0, 0, 0)`` rotation and
            translation, ``(1, 1, 1)`` scale, ``None`` for sigma/noise/offset,
            an empty ``flip_axes`` tuple, or ``gamma == 1``.
        """
        prob = self.apply_prob

        def maybe(sampler, default):
            # Coin flip first; the sampler only runs when the augmentation is on.
            return sampler() if random.random() < prob else default

        # NOTE(review): one angle is reused for all three rotation axes, as in the
        # original implementation; confirm whether per-axis angles were intended.
        rotate = maybe(lambda: (random.uniform(-self.max_rotation, self.max_rotation),) * 3, (0, 0, 0))
        translate = maybe(
            lambda: tuple(random.uniform(-self.translate_range, self.translate_range) for _ in range(3)),
            (0, 0, 0),
        )
        scale = maybe(
            lambda: tuple(random.uniform(self.scale_range[0], self.scale_range[1]) for _ in range(3)),
            (1, 1, 1),
        )
        sigma = maybe(
            lambda: tuple(
                random.uniform(self.gaussian_sigma_range[0], self.gaussian_sigma_range[1]) for _ in range(3)
            ),
            None,
        )
        noise_std = maybe(
            lambda: random.uniform(self.gaussian_noise_std_range[0], self.gaussian_noise_std_range[1]),
            None,
        )
        # Each spatial axis is flipped independently. Axis 0 is a real axis, so it
        # must not be tested for truthiness (that previously disabled x-flips).
        flip_axes = tuple(axis for axis in range(3) if random.random() < prob)
        offset = maybe(lambda: random.randint(50, 100), None)
        gamma = maybe(lambda: random.uniform(0.5, 2.0), 1)

        return {
            "rotate": rotate,
            "translate": translate,
            "scale": scale,
            "sigma": sigma,
            "noise_std": noise_std,
            "flip_axes": flip_axes,
            "offset": offset,
            "gamma": gamma,
        }

    def _build_pipeline(self, params):
        """Turn sampled ``params`` into the ordered list of dict transforms to apply."""
        pipeline = [
            Resized(keys=["image"], spatial_size=self.image_size),
            Affined(
                keys=["image"],
                rotate_params=params["rotate"],
                translate_params=params["translate"],
                scale_params=params["scale"],
                padding_mode="zeros",
            ),
        ]
        pipeline.extend(Flipd(keys=["image"], spatial_axis=axis) for axis in params["flip_axes"])
        if params["sigma"] is not None:
            pipeline.append(GaussianSmoothd(keys=["image"], sigma=params["sigma"]))
        if params["offset"] is not None:
            pipeline.append(StdShiftIntensityd(keys=["image"], factor=params["offset"], nonzero=True))
        pipeline.append(ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0))  # min-max to [0, 1]
        # Always applied; gamma is 1 when the contrast augmentation is disabled.
        pipeline.append(AdjustContrastd(keys=["image"], gamma=params["gamma"]))
        if params["noise_std"] is not None:
            pipeline.append(
                RandGaussianNoised(keys=["image"], std=params["noise_std"], prob=1.0, mean=0.0, sample_std=False)
            )
        pipeline.append(ToTensord(keys=["image"]))
        return pipeline

    def __call__(self, scan_list):
        """Augment every scan in ``scan_list`` with one shared parameter set.

        Args:
            scan_list: sequence of image tensors. The dataset classes in this
                module pass tensors with a leading axis of size 1.

        Returns:
            ``torch.Tensor`` stacking the squeezed, augmented scans along a new
            first dimension, one entry per element of ``scan_list``.
        """
        pipeline = self._build_pipeline(self._sample_params())
        transformed_scans = []
        for scan in scan_list:
            sample = {"image": scan}
            for transform in pipeline:
                sample = transform(sample)
            # squeeze() drops every singleton axis, including the channel axis.
            transformed_scans.append(sample["image"].squeeze())
        return torch.stack(transformed_scans)
class TransformationMedicalImageDatasetBalancedIntensity3D(Dataset):
    """Training dataset: one augmented sample per CSV row, built from one or more scans.

    The CSV is read with ``pat_id`` and ``scandate`` as strings, and
    ``__getitem__`` reads the ``pat_id``, ``scandate`` and ``label`` columns.
    ``scandate`` may hold several dates joined by ``-``. Each date is loaded
    from ``os.path.join(root_dir, f"{pat_id}_{date}.nii.gz")``.

    Args:
        csv_path: path to the CSV index.
        root_dir: directory holding the NIfTI files.
        transform: callable applied to the list of loaded scans. When ``None``
            (the default), the training augmentation ``SynchronizedTransform3D()``
            is used.
    """

    def __init__(self, csv_path, root_dir, transform=None):
        # Parse IDs/dates as strings (not numbers) so file names are rebuilt
        # exactly, e.g. leading zeros are preserved.
        self.dataframe = pd.read_csv(csv_path, dtype={"pat_id": str, "scandate": str})
        self.root_dir = root_dir
        # Honour a caller-supplied transform; fall back to the training augmentations.
        self.transform = transform if transform is not None else SynchronizedTransform3D()

    def __len__(self):
        return len(self.dataframe)

    def _load_scans(self, pat_id, scan_dates):
        """Load ``{pat_id}_{date}.nii.gz`` for each ``-``-separated date as float32 tensors.

        Each tensor gets a new leading axis of size 1 via ``unsqueeze(0)``.
        """
        scan_list = []
        for scandate in scan_dates.split("-"):
            img_name = os.path.join(self.root_dir, f"{pat_id}_{scandate}.nii.gz")
            scan = nib.load(img_name).get_fdata()
            scan_list.append(torch.tensor(scan, dtype=torch.float32).unsqueeze(0))
        return scan_list

    def __getitem__(self, idx):
        """Return ``{"image", "label", "pat_id"}`` for row ``idx``.

        ``image`` is whatever ``self.transform`` returns for the scan list,
        ``label`` is a float32 tensor and ``pat_id`` is a string.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        pat_id = str(self.dataframe.loc[idx, "pat_id"])
        scan_dates = str(self.dataframe.loc[idx, "scandate"])
        label = self.dataframe.loc[idx, "label"]

        transformed_scans = self.transform(self._load_scans(pat_id, scan_dates))
        return {
            "image": transformed_scans,
            "label": torch.tensor(label, dtype=torch.float32),
            "pat_id": pat_id,
        }