import os import torch import pandas as pd from torch.utils.data import Dataset import nibabel as nib from monai.transforms import Affined, RandGaussianNoised, Rand3DElasticd, AdjustContrastd, ScaleIntensityd, ToTensord, Resized, RandRotate90d, Resize, RandGaussianSmoothd, GaussianSmoothd, Rotate90d, StdShiftIntensityd, RandAdjustContrastd, Flipd import random import numpy as np ####################################### ## 3D SYNC TRANSFORM ####################################### class NormalSynchronizedTransform3D: """ Vanilla Validation Transforms""" def __init__(self, image_size=(128,128,128), max_rotation=40, translate_range=0.2, scale_range=(0.9, 1.3), apply_prob=0.5): self.image_size = image_size self.max_rotation = max_rotation self.translate_range = translate_range self.scale_range = scale_range self.apply_prob = apply_prob def __call__(self, scan_list): transformed_scans = [] resize_transform = Resized(spatial_size=(128,128,128), keys=["image"]) scale_transform = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0) # Intensity scaling tensor_transform = ToTensord(keys=["image"]) # Convert to tensor for scan in scan_list: sample = {"image": scan} sample = resize_transform(sample) sample = scale_transform(sample) sample = tensor_transform(sample) transformed_scans.append(sample["image"].squeeze()) return torch.stack(transformed_scans) class MedicalImageDatasetBalancedIntensity3D(Dataset): """ Validation Dataset class """ def __init__(self, csv_path, root_dir, transform=None): self.dataframe = pd.read_csv(csv_path, dtype={"pat_id":str, "scandate":str}) self.root_dir = root_dir self.transform = NormalSynchronizedTransform3D() def __len__(self): return len(self.dataframe) def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() ## load the niftis from csv pat_id = str(self.dataframe.loc[idx, 'pat_id']) scan_dates = str(self.dataframe.loc[idx, 'scandate']) label = self.dataframe.loc[idx, 'label'] scandates = scan_dates.split('-') scan_list = [] for scandate in scandates: img_name = os.path.join(self.root_dir , f"{pat_id}_{scandate}.nii.gz") scan = nib.load(img_name).get_fdata() scan_list.append(torch.tensor(scan, dtype=torch.float32).unsqueeze(0)) ## package into a dictionary for val loader transformed_scans = self.transform(scan_list) sample = {"image": transformed_scans, "label": torch.tensor(label, dtype=torch.float32), "pat_id": pat_id} return sample class SynchronizedTransform3D: """ Trainign Augmentation method """ def __init__(self, image_size=(128,128,128), max_rotation=0.34, translate_range=15, scale_range=(0.9, 1.3), apply_prob=0.5, gaussian_sigma_range=(0.25, 1.5), gaussian_noise_std_range=(0.05, 0.09)): self.image_size = image_size self.max_rotation = max_rotation self.translate_range = translate_range self.scale_range = scale_range self.apply_prob = apply_prob self.gaussian_sigma_range = gaussian_sigma_range self.gaussian_noise_std_range = gaussian_noise_std_range def __call__(self, scan_list): transformed_scans = [] rotate_params = (random.uniform(-self.max_rotation, self.max_rotation),) * 3 if random.random() < self.apply_prob else (0, 0, 0) translate_params = tuple([random.uniform(-self.translate_range, self.translate_range) for _ in range(3)]) if random.random() < self.apply_prob else (0, 0, 0) scale_params = tuple([random.uniform(self.scale_range[0], self.scale_range[1]) for _ in range(3)]) if random.random() < self.apply_prob else (1, 1, 1) gaussian_sigma = tuple([random.uniform(self.gaussian_sigma_range[0], self.gaussian_sigma_range[1]) for _ in range(3)]) if random.random() < self.apply_prob else None gaussian_noise_std = random.uniform(self.gaussian_noise_std_range[0], self.gaussian_noise_std_range[1]) if random.random() < self.apply_prob else None flip_axes = (0,1) if random.random() < self.apply_prob else None # Determine if and along which axes to flip flip_x = 0 if random.random() < self.apply_prob else None flip_y = 1 if random.random() < self.apply_prob else None flip_z = 2 if random.random() < self.apply_prob else None offset = random.randint(50,100) if random.random() < self.apply_prob else None gammafactor = random.uniform(0.5,2.0) if random.random() < self.apply_prob else 1 affine_transform = Affined(keys=["image"], rotate_params=rotate_params, translate_params=translate_params, scale_params=scale_params, padding_mode='zeros') gaussian_blur_transform = GaussianSmoothd(keys=["image"], sigma=gaussian_sigma) if gaussian_sigma else None gaussian_noise_transform = RandGaussianNoised(keys=["image"], std=gaussian_noise_std, prob=1.0, mean=0.0, sample_std=False) if gaussian_noise_std else None #flip_transform = Rotate90d(keys=["image"], k=1, spatial_axes=flip_axes) if flip_axes else None flip_x_transform = Flipd(keys=["image"], spatial_axis=flip_x) if flip_x else None flip_y_transform = Flipd(keys=["image"], spatial_axis=flip_y) if flip_y else None flip_z_transform = Flipd(keys=["image"], spatial_axis=flip_z) if flip_z else None resize_transform = Resized(spatial_size=(128,128,128), keys=["image"]) scale_transform = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0) # Intensity scaling tensor_transform = ToTensord(keys=["image"]) # Convert to tensor shift_intensity = StdShiftIntensityd(keys = ["image"], factor = offset, nonzero=True) adjust_contrast = AdjustContrastd(keys = ["image"], gamma = gammafactor) for scan in scan_list: sample = {"image": scan} sample = resize_transform(sample) sample = affine_transform(sample) if flip_x_transform: sample = flip_x_transform(sample) if flip_y_transform: sample = flip_y_transform(sample) if flip_z_transform: sample = flip_z_transform(sample) if gaussian_blur_transform: sample = gaussian_blur_transform(sample) if offset: sample = shift_intensity(sample) sample = scale_transform(sample) sample = adjust_contrast(sample) if gaussian_noise_transform: sample = gaussian_noise_transform(sample) sample = tensor_transform(sample) transformed_scans.append(sample["image"].squeeze()) return torch.stack(transformed_scans) class TransformationMedicalImageDatasetBalancedIntensity3D(Dataset): """ Training Dataset class """ def __init__(self, csv_path, root_dir, transform=None): self.dataframe = pd.read_csv(csv_path, dtype={"pat_id":str, "scandate":str}) self.root_dir = root_dir self.transform = SynchronizedTransform3D() # calls training augmentations def __len__(self): return len(self.dataframe) def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() ## load the niftis from csv pat_id = str(self.dataframe.loc[idx, 'pat_id']) scan_dates = str(self.dataframe.loc[idx, 'scandate']) label = self.dataframe.loc[idx, 'label'] scandates = scan_dates.split('-') scan_list = [] for scandate in scandates: img_name = os.path.join(self.root_dir , f"{pat_id}_{scandate}.nii.gz") #f"{pat_id}_{scandate}.nii.gz") scan = nib.load(img_name).get_fdata() scan_list.append(torch.tensor(scan, dtype=torch.float32).unsqueeze(0)) # package into a monai type dictionary transformed_scans = self.transform(scan_list) sample = {"image": transformed_scans, "label": torch.tensor(label, dtype=torch.float32), "pat_id": pat_id} return sample