diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..e58e78c478639f2f6d74b44964cfcf73e4e962e9
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,37 @@
+# Use an official Python runtime as a parent image
+FROM python:3.10-slim
+
+# Set the working directory in the container
+WORKDIR /app
+
+# Install necessary system dependencies (kept minimal)
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
+    git \
+    libgl1 \
+    libglib2.0-0 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy the requirements file first to leverage Docker cache
+COPY requirements.txt ./
+
+# Install Python packages
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the entire project
+COPY . /app/
+
+# Create a non-root user (HF Spaces requirement)
+RUN useradd -m -u 1000 user
+USER user
+
+# Make sure the user owns the app directory
+COPY --chown=user:user . /app/
+
+# Expose Gradio default port
+EXPOSE 7860
+
+ENV PYTHONUNBUFFERED=1
+
+# Run the app from the src/IDH directory
+CMD ["python", "src/IDH/app_gradio.py"]
\ No newline at end of file
diff --git a/README.md b/README.md
index 1ac9046c0279f3f9399b2627830d9cbd33866804..09d1a29b42687cd6690f335af0e623f9ea004d4b 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,28 @@
 ---
 title: BrainIAC Glioma Segmentation
-emoji: đ
-colorFrom: red
-colorTo: yellow
+emoji: 🧠
+colorFrom: blue
+colorTo: red
 sdk: docker
 pinned: false
-license: cc-by-4.0
+license: mit
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# BrainIAC: Glioma Segmentation
+
+A Vision Transformer UNETR model for glioma segmentation from FLAIR MRI scans.
+
+## Features
+- Upload FLAIR MRI NIfTI files
+- Optional preprocessing (debiasing + registration + skull stripping)
+- Interactive slice-by-slice visualization
+- Download preprocessed images and segmentation masks
+- Real-time segmentation statistics
+
+## Usage
+1. Upload a FLAIR MRI scan (.nii or .nii.gz)
+2. Optionally enable preprocessing
+3. Adjust segmentation threshold
+4. View results and download files
+
+*Research use only*
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..02c167ed13eecb98e5567102e78f5e726b4196fe
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,18 @@
+monai==1.3.2
+nibabel==5.2.1
+numpy==1.23.5
+pydicom
+PyYAML
+pytorch-lightning==2.3.3
+scipy==1.10.1
+SimpleITK==2.4.0
+torch==2.6.0
+tqdm
+gradio
+pandas
+scikit-image==0.21.0
+opencv-python
+itk-elastix
+dicom2nifti
+einops
+matplotlib
\ No newline at end of file
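For orientation: the Dockerfile's `CMD` expects `src/IDH/app_gradio.py` to start a Gradio server on the exposed port. The real app is not part of this diff; the following is only a hypothetical sketch of such an entry point, with illustrative names throughout:

```python
# Hypothetical sketch of the Gradio entry point the Dockerfile CMD launches.
# The actual src/IDH/app_gradio.py is not shown in this diff.
import gradio as gr

def segment(nifti_file):
    # placeholder: the Space would run preprocessing + UNETR segmentation here
    return f"received {nifti_file.name}"

demo = gr.Interface(fn=segment,
                    inputs=gr.File(label="FLAIR MRI (.nii/.nii.gz)"),
                    outputs="text")
# Bind to 0.0.0.0:7860 so the container's EXPOSEd port is reachable
demo.launch(server_name="0.0.0.0", server_port=7860)
```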
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-310.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd5fefbfdd6c33ded5680391725b1501da66b8d6
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-310.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-38.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61cf361fe68cd40f9c85c2fb100d27c9d1cfb4c0
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-38.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-39.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0804f212002980f7ee1b2de3b9a49d5d7b34e05
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-39.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-310.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..84c28b095603940c02cc2ea9fa91b0b0446c10eb
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-310.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-38.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04949544425bca1aba1f446b3fdb054012d51adc
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-38.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-39.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c87337e83582ef7e61d9c7a027040f52afef166
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-39.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/hd_bet.cpython-310.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/hd_bet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ffe26540ea77664aa42c19aee6a76b18156a53d5
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/hd_bet.cpython-310.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/hd_bet.cpython-38.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/hd_bet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81aa4cfb5eb39f270424b98107c40d28744386eb
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/hd_bet.cpython-38.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-310.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e496ed582831f96a58feb9e3865c4081ab9d11db
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-310.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-38.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c20f2e3c729378c589e22eff4e8ec151d20b470
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-38.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-39.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5516dcf581d4e55ef49ee14ff2dbac3860f0086
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-39.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-310.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac20a90e2d0069f0ca86e7dabb67a0baa1e93440
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-310.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-38.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..567f623c1fa55d81bc9c1d4001d6aea1ab274f5f
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-38.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-39.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a2228880835eb9a907395309369a846da5010f2
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-39.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-310.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c3511e128de506944ecb85ca4fed03b31b5531dc
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-310.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-38.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a8f633c24280649275fb7bb7182991c07ced8eb
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-38.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-39.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..294527e04fcfd5d3be038dc1c8712e925d3352ff
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-39.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-310.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0cce40e0b34a33a6d930014fc91d903d4892fb5
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-310.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-38.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5747d903a2b450fbeb16df2ec1f1d3680fcc6a48
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-38.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-39.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..51e286390fe2b53d201a211aee37bf7b952fcfc3
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-39.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-310.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4388f005b876e63a702375b68a13143344f4dff9
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-310.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-38.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad8b6ff21be7f57044f1a5e0f76501545f0fb97d
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-38.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-39.pyc b/src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe6e8221daae04ee32f9159e87904a914887135d
Binary files /dev/null and b/src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-39.pyc differ
diff --git a/src/IDH/HD_BET/HD_BET/config.py b/src/IDH/HD_BET/HD_BET/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..870951e5c9059fb9e20d6143e68266732f19234e
--- /dev/null
+++ b/src/IDH/HD_BET/HD_BET/config.py
@@ -0,0 +1,121 @@
+import numpy as np
+import torch
+from HD_BET.utils import SetNetworkToVal, softmax_helper
+from abc import abstractmethod
+from HD_BET.network_architecture import Network
+
+
+class BaseConfig(object):
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def get_split(self, fold, random_state=12345):
+        pass
+
+    @abstractmethod
+    def get_network(self, mode="train"):
+        pass
+
+    @abstractmethod
+    def get_basic_generators(self, fold):
+        pass
+
+    @abstractmethod
+    def get_data_generators(self, fold):
+        pass
+
+    def preprocess(self, data):
+        return data
+
+    def __repr__(self):
+        res = ""
+        for v in vars(self):
+            if not v.startswith("__") and not v.startswith("_") and v != 'dataset':
+                res += (v + ": " + str(self.__getattribute__(v)) + "\n")
+        return res
+
+
+class HD_BET_Config(BaseConfig):
+    def __init__(self):
+        super(HD_BET_Config, self).__init__()
+
+        self.EXPERIMENT_NAME = self.__class__.__name__  # just a generic experiment name
+
+        # network parameters
+        self.net_base_num_layers = 21
+        self.BATCH_SIZE = 2
+        self.net_do_DS = True
+        self.net_dropout_p = 0.0
+        self.net_use_inst_norm = True
+        self.net_conv_use_bias = True
+        self.net_norm_use_affine = True
+        self.net_leaky_relu_slope = 1e-1
+
+        # hyperparameters
+        self.INPUT_PATCH_SIZE = (128, 128, 128)
+        self.num_classes = 2
+        self.selected_data_channels = range(1)
+
+        # data augmentation
+        self.da_mirror_axes = (2, 3, 4)
+
+        # validation
+        self.val_use_DO = False
+        self.val_use_train_mode = False  # for dropout sampling
+        self.val_num_repeats = 1  # only useful if dropout sampling
+        self.val_batch_size = 1  # only useful if dropout sampling
+        self.val_save_npz = True
+        self.val_do_mirroring = True  # test time data augmentation via mirroring
+        self.val_write_images = True
+        self.net_input_must_be_divisible_by = 16  # we could make a network class that has this as a property
+        self.val_min_size = self.INPUT_PATCH_SIZE
+        self.val_fn = None
+
+        # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_runnings_
+        # stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but unused and defaults
+        # to false in 0.4)
+        self.val_use_moving_averages = False
+
+    def get_network(self, train=True, pretrained_weights=None):
+        net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
+                      self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
+                      self.net_norm_use_affine, True, self.net_do_DS)
+
+        if pretrained_weights is not None:
+            net.load_state_dict(
+                torch.load(pretrained_weights, map_location=lambda storage, loc: storage))
+
+        if train:
+            net.train(True)
+        else:
+            net.train(False)
+            net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
+            net.do_ds = False
+
+        optimizer = None
+        self.lr_scheduler = None
+        return net, optimizer
+
+    def get_data_generators(self, fold):
+        pass
+
+    def get_split(self, fold, random_state=12345):
+        pass
+
+    def get_basic_generators(self, fold):
+        pass
+
+    def on_epoch_end(self, epoch):
+        pass
+
+    def preprocess(self, data):
+        data = np.copy(data)
+        for c in range(data.shape[0]):
+            data[c] -= data[c].mean()
+            data[c] /= data[c].std()
+        return data
+
+
+config = HD_BET_Config
+
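This config module is what `run.py` below loads via `imp.load_source`. A minimal usage sketch, assuming the vendored `HD_BET` package is importable:

```python
# Minimal sketch (assumes HD_BET is on the Python path)
from HD_BET.config import config

cf = config()                         # instantiate HD_BET_Config
net, _ = cf.get_network(train=False)  # eval-mode Network; the optimizer slot is always None
print(cf)                             # __repr__ lists the hyperparameters set above
```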
diff --git a/src/IDH/HD_BET/HD_BET/data_loading.py b/src/IDH/HD_BET/HD_BET/data_loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ec4be63a8186b65bfb390770fefa6217b5dd2c5
--- /dev/null
+++ b/src/IDH/HD_BET/HD_BET/data_loading.py
@@ -0,0 +1,121 @@
+import SimpleITK as sitk
+import numpy as np
+from skimage.transform import resize
+
+
+def resize_image(image, old_spacing, new_spacing, order=3):
+    new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
+                 int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
+                 int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
+    return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False)
+
+
+def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)):
+    spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
+    image = sitk.GetArrayFromImage(itk_image).astype(float)
+
+    assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed"
+
+    if not is_seg:
+        if np.any([[i != j] for i, j in zip(spacing, spacing_target)]):
+            image = resize_image(image, spacing, spacing_target).astype(np.float32)
+
+        image -= image.mean()
+        image /= image.std()
+    else:
+        new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))),
+                     int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))),
+                     int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))))
+        image = resize_segmentation(image, new_shape, 1)
+    return image
+
+
+def load_and_preprocess(mri_file):
+    images = {}
+    # t1
+    images["T1"] = sitk.ReadImage(mri_file)
+
+    properties_dict = {
+        "spacing": images["T1"].GetSpacing(),
+        "direction": images["T1"].GetDirection(),
+        "size": images["T1"].GetSize(),
+        "origin": images["T1"].GetOrigin()
+    }
+
+    for k in images.keys():
+        images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5))
+
+    properties_dict['size_before_cropping'] = images["T1"].shape
+
+    imgs = []
+    for seq in ['T1']:
+        imgs.append(images[seq][None])
+    all_data = np.vstack(imgs)
+    print("image shape after preprocessing: ", str(all_data[0].shape))
+    return all_data, properties_dict
+
+
+def save_segmentation_nifti(segmentation, dct, out_fname, order=1):
+    '''
+    segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out
+    of the original image
+
+    dct:
+    size_before_cropping
+    brain_bbox
+    size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping
+    spacing
+    origin
+    direction
+
+    :param segmentation:
+    :param dct:
+    :param out_fname:
+    :return:
+    '''
+    old_size = dct.get('size_before_cropping')
+    bbox = dct.get('brain_bbox')
+    if bbox is not None:
+        seg_old_size = np.zeros(old_size)
+        for c in range(3):
+            bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c]))
+        seg_old_size[bbox[0][0]:bbox[0][1],
+                     bbox[1][0]:bbox[1][1],
+                     bbox[2][0]:bbox[2][1]] = segmentation
+    else:
+        seg_old_size = segmentation
+    if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]):
+        seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order)
+    else:
+        seg_old_spacing = seg_old_size
+    seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32))
+    seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]])
+    seg_resized_itk.SetOrigin(dct['origin'])
+    seg_resized_itk.SetDirection(dct['direction'])
+    sitk.WriteImage(seg_resized_itk, out_fname)
+
+
+def resize_segmentation(segmentation, new_shape, order=3, cval=0):
+    '''
+    Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency
+
+    Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one
+    hot encoding which is resized and transformed back to a segmentation map.
+    This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
+    :param segmentation:
+    :param new_shape:
+    :param order:
+    :return:
+    '''
+    tpe = segmentation.dtype
+    unique_labels = np.unique(segmentation)
+    assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
+    if order == 0:
+        return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
+    else:
+        reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
+
+        for i, c in enumerate(unique_labels):
+            reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
+            reshaped[reshaped_multihot >= 0.5] = c
+        return reshaped
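`resize_segmentation` avoids inventing intermediate labels by resizing each label's binary mask separately and thresholding at 0.5. A small check of that property, assuming the module above is importable:

```python
import numpy as np
from HD_BET.data_loading import resize_segmentation  # assumes the module above is importable

seg = np.zeros((4, 4, 4), dtype=np.int32)
seg[2:, 2:, 2:] = 2                          # only labels 0 and 2 present
up = resize_segmentation(seg, (8, 8, 8), order=1)
assert set(np.unique(up)) <= {0, 2}          # no spurious label 1 from interpolation
```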
diff --git a/src/IDH/HD_BET/HD_BET/hd_bet.py b/src/IDH/HD_BET/HD_BET/hd_bet.py
new file mode 100644
index 0000000000000000000000000000000000000000..128575b6cfb4bdd98bf417ed598f905ef4896fd1
--- /dev/null
+++ b/src/IDH/HD_BET/HD_BET/hd_bet.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+
+import os
+import sys
+sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")
+from HD_BET.run import run_hd_bet
+from HD_BET.utils import maybe_mkdir_p, subfiles
+import HD_BET
+
+def hd_bet(input_file_or_dir,output_file_or_dir,mode,device,tta,pp=1,save_mask=0,overwrite_existing=1):
+
+    if output_file_or_dir is None:
+        output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
+                                          os.path.basename(input_file_or_dir).split(".")[0] + "_bet")
+
+
+    params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
+    config_file = os.path.join(HD_BET.__path__[0], "config.py")
+
+    assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
+
+    if device == 'cpu':
+        pass
+    else:
+        device = int(device)
+
+    if os.path.isdir(input_file_or_dir):
+        maybe_mkdir_p(output_file_or_dir)
+        input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)
+
+        if len(input_files) == 0:
+            raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here")
+
+        output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
+        input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
+    else:
+        if not output_file_or_dir.endswith('.nii.gz'):
+            output_file_or_dir += '.nii.gz'
+            assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
+
+        output_files = [output_file_or_dir]
+        input_files = [input_file_or_dir]
+
+    if tta == 0:
+        tta = False
+    elif tta == 1:
+        tta = True
+    else:
+        raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))
+
+    if overwrite_existing == 0:
+        overwrite_existing = False
+    elif overwrite_existing == 1:
+        overwrite_existing = True
+    else:
+        raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))
+
+    if pp == 0:
+        pp = False
+    elif pp == 1:
+        pp = True
+    else:
+        raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
+
+    if save_mask == 0:
+        save_mask = False
+    elif save_mask == 1:
+        save_mask = True
+    else:
+        raise ValueError("Unknown value for save_mask: %s. Expected: 0 or 1" % str(save_mask))
+
+    run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)
+
+
+if __name__ == "__main__":
+    print("\n########################")
+    print("If you are using hd-bet, please cite the following paper:")
+    print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W,"
+          "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial"
+          "neural networks. arXiv preprint arXiv:1901.11341, 2019.")
+    print("########################\n")
+
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be '
+                                              'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to '
+                                              'split 4d sequences into 3d images. If folder: all files ending with .nii.gz '
+                                              'within that folder will be brain extracted.', required=True, type=str)
+    parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder'
+                                               ' will be created', required=False, type=str)
+    parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will '
+                                                                    'use only one set of parameters whereas accurate will '
+                                                                    'use the five sets of parameters that resulted from '
+                                                                    'our cross-validation as an ensemble. Default: '
+                                                                    'accurate',
+                        required=False)
+    parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. '
+                                                               'Must be either int or str. Use int for GPU id or '
+                                                               '\'cpu\' to run on CPU. When using CPU you should '
+                                                               'consider disabling tta. Default for -device is: 0',
+                        required=False)
+    parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation '
+                                                                          '(mirroring). 1= True, 0=False. Disable this '
+                                                                          'if you are using CPU to speed things up! '
+                                                                          'Default: 1')
+    parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disable postprocessing (remove all'
+                                                                         ' but the largest connected component in '
+                                                                         'the prediction. Default: 1')
+    parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation '
+                                                                                       'mask will not be '
+                                                                                       'saved')
+    parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't "
+                                                                                          "want to overwrite existing "
+                                                                                          "predictions")
+
+    args = parser.parse_args()
+
+    hd_bet(args.input,args.output,args.mode,args.device,args.tta,args.pp,args.save_mask,args.overwrite_existing)
+
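The same entry point can also be called programmatically. A hedged example matching the signature above (file names are illustrative):

```python
from HD_BET.hd_bet import hd_bet  # assumes the module above is importable

# CPU-only, no test-time augmentation; file names are illustrative
hd_bet("sub-01_flair.nii.gz", "sub-01_flair_bet.nii.gz",
       mode="fast", device="cpu", tta=0, pp=1, save_mask=1, overwrite_existing=1)
```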
diff --git a/src/IDH/HD_BET/HD_BET/network_architecture.py b/src/IDH/HD_BET/HD_BET/network_architecture.py
new file mode 100644
index 0000000000000000000000000000000000000000..0824aa10839024368ad8ab38c637ce81aa9327e5
--- /dev/null
+++ b/src/IDH/HD_BET/HD_BET/network_architecture.py
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from HD_BET.utils import softmax_helper
+
+
+class EncodingModule(nn.Module):
+    def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
+                 inst_norm_affine=True, lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.dropout_p = dropout_p
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
+        self.dropout = nn.Dropout3d(dropout_p)
+        self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
+
+    def forward(self, x):
+        skip = x
+        x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = self.conv1(x)
+        if self.dropout_p is not None and self.dropout_p > 0:
+            x = self.dropout(x)
+        x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = self.conv2(x)
+        x = x + skip
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
+        super(Upsample, self).__init__()
+        self.align_corners = align_corners
+        self.mode = mode
+        self.scale_factor = scale_factor
+        self.size = size
+
+    def forward(self, x):
+        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
+                                         align_corners=self.align_corners)
+
+
+class LocalizationModule(nn.Module):
+    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
+        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
+        self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        return x
+
+
+class UpsamplingModule(nn.Module):
+    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
+        self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias)
+        self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness,
+                         inplace=self.lrelu_inplace)
+        return x
+
+
+class DownsamplingModule(nn.Module):
+    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        b = self.downsample(x)
+        return x, b
+
+
+class Network(nn.Module):
+    def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
+                 final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True, do_ds=True):
+        super(Network, self).__init__()
+
+        self.do_ds = do_ds
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.final_nonlin = final_nonlin
+        self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)
+
+        self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
+                                       conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)
+
+        self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
+        self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
+        self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
+        self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
+        self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
+        self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
+        self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
+        self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
+
+    def forward(self, x):
+        seg_outputs = []
+
+        x = self.init_conv(x)
+        x = self.context1(x)
+
+        skip1, x = self.down1(x)
+        x = self.context2(x)
+
+        skip2, x = self.down2(x)
+        x = self.context3(x)
+
+        skip3, x = self.down3(x)
+        x = self.context4(x)
+
+        skip4, x = self.down4(x)
+        x = self.context5(x)
+
+        x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = self.up1(x)
+
+        x = torch.cat((skip4, x), dim=1)
+        x = self.loc1(x)
+        x = self.up2(x)
+
+        x = torch.cat((skip3, x), dim=1)
+        x = self.loc2(x)
+        loc2_seg = self.final_nonlin(self.loc2_seg(x))
+        seg_outputs.append(loc2_seg)
+        x = self.up3(x)
+
+        x = torch.cat((skip2, x), dim=1)
+        x = self.loc3(x)
+        loc3_seg = self.final_nonlin(self.loc3_seg(x))
+        seg_outputs.append(loc3_seg)
+        x = self.up4(x)
+
+        x = torch.cat((skip1, x), dim=1)
+        x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
+                         inplace=self.lrelu_inplace)
+        x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
+                         inplace=self.lrelu_inplace)
+        x = self.final_nonlin(self.seg_layer(x))
+        seg_outputs.append(x)
+
+        if self.do_ds:
+            return seg_outputs[::-1]
+        else:
+            return seg_outputs[-1]
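A quick shape check of this architecture, assuming the module is importable and the input's spatial dimensions are divisible by 16 (as the config above requires):

```python
import torch
from HD_BET.network_architecture import Network  # assumes the module above is importable

net = Network(num_classes=2, num_input_channels=1, base_filters=16, do_ds=False).eval()
with torch.no_grad():
    y = net(torch.randn(1, 1, 64, 64, 64))  # spatial dims must be divisible by 16
print(y.shape)  # torch.Size([1, 2, 64, 64, 64]) -- softmax probabilities over 2 classes
```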
diff --git a/src/IDH/HD_BET/HD_BET/paths.py b/src/IDH/HD_BET/HD_BET/paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d125b947184c03d1a12ebd77e39a75a730c4594
--- /dev/null
+++ b/src/IDH/HD_BET/HD_BET/paths.py
@@ -0,0 +1,6 @@
+import os
+
+# please refer to the readme on where to get the parameters. Save them in this folder:
+# Original Path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params"
+# Updated path for Docker container:
+folder_with_parameter_files = "/app/IDH/hdbet_model"
diff --git a/src/IDH/HD_BET/HD_BET/predict_case.py b/src/IDH/HD_BET/HD_BET/predict_case.py
new file mode 100644
index 0000000000000000000000000000000000000000..559c66739ae890f7e985e072eb49ce0ee0484978
--- /dev/null
+++ b/src/IDH/HD_BET/HD_BET/predict_case.py
@@ -0,0 +1,126 @@
+import torch
+import numpy as np
+
+
+def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
+    if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)):
+        shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
+    shp = patient.shape
+    new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0],
+               shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1],
+               shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]]
+    for i in range(len(shp)):
+        if shp[i] % shape_must_be_divisible_by[i] == 0:
+            new_shp[i] -= shape_must_be_divisible_by[i]
+    if min_size is not None:
+        new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0)
+    return reshape_by_padding_upper_coords(patient, new_shp, 0), shp
+
+
+def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
+    shape = tuple(list(image.shape))
+    new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0))
+    if pad_value is None:
+        if len(shape) == 2:
+            pad_value = image[0,0]
+        elif len(shape) == 3:
+            pad_value = image[0, 0, 0]
+        else:
+            raise ValueError("Image must be either 2 or 3 dimensional")
+    res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
+    if len(shape) == 2:
+        res[0:0+int(shape[0]), 0:0+int(shape[1])] = image
+    elif len(shape) == 3:
+        res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image
+    return res
+
+
+def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
+                        new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
+    with torch.no_grad():
+        pad_res = []
+        for i in range(patient_data.shape[0]):
+            t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size)
+            pad_res.append(t[None])
+
+        patient_data = np.vstack(pad_res)
+
+        new_shp = patient_data.shape
+
+        data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32)
+
+        data[0] = patient_data
+
+        if BATCH_SIZE is not None:
+            data = np.vstack([data] * BATCH_SIZE)
+
+        a = torch.rand(data.shape).float()
+
+        if main_device == 'cpu':
+            pass
+        else:
+            a = a.cuda(main_device)
+
+        if do_mirroring:
+            x = 8
+        else:
+            x = 1
+        all_preds = []
+        for i in range(num_repeats):
+            for m in range(x):
+                data_for_net = np.array(data)
+                do_stuff = False
+                if m == 0:
+                    do_stuff = True
+                    pass
+                if m == 1 and (4 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, :, :, ::-1]
+                if m == 2 and (3 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, :, ::-1, :]
+                if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, :, ::-1, ::-1]
+                if m == 4 and (2 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, :, :]
+                if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, :, ::-1]
+                if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, ::-1, :]
+                if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]
+
+                if do_stuff:
+                    _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
+                    p = net(a)  # np.copy is necessary because ::-1 creates just a view i think
+                    p = p.data.cpu().numpy()
+
+                    if m == 0:
+                        pass
+                    if m == 1 and (4 in mirror_axes):
+                        p = p[:, :, :, :, ::-1]
+                    if m == 2 and (3 in mirror_axes):
+                        p = p[:, :, :, ::-1, :]
+                    if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
+                        p = p[:, :, :, ::-1, ::-1]
+                    if m == 4 and (2 in mirror_axes):
+                        p = p[:, :, ::-1, :, :]
+                    if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
+                        p = p[:, :, ::-1, :, ::-1]
+                    if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
+                        p = p[:, :, ::-1, ::-1, :]
+                    if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
+                        p = p[:, :, ::-1, ::-1, ::-1]
+                    all_preds.append(p)
+
+        stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
+        predicted_segmentation = stacked.mean(0).argmax(0)
+        uncertainty = stacked.var(0)
+        bayesian_predictions = stacked
+        softmax_pred = stacked.mean(0)
+        return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
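The eight `m` cases above enumerate every combination of flips over axes (2, 3, 4), and each prediction is un-flipped with the same slicing before averaging. A tiny demonstration of that round trip:

```python
import numpy as np

# flip/un-flip round trip as used for mirror TTA above
data = np.random.rand(1, 1, 8, 8, 8)
flipped = data[:, :, ::-1, :, :]   # e.g. m == 4: mirror along axis 2
assert np.array_equal(flipped[:, :, ::-1, :, :], data)  # same flip restores the input
```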
diff --git a/src/IDH/HD_BET/HD_BET/run.py b/src/IDH/HD_BET/HD_BET/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..858934d8f67175df508884e9030f8d38ba0d07cf
--- /dev/null
+++ b/src/IDH/HD_BET/HD_BET/run.py
@@ -0,0 +1,117 @@
+import torch
+import numpy as np
+import SimpleITK as sitk
+from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
+from HD_BET.predict_case import predict_case_3D_net
+import imp
+from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
+import os
+import HD_BET
+
+
+def apply_bet(img, bet, out_fname):
+    img_itk = sitk.ReadImage(img)
+    img_npy = sitk.GetArrayFromImage(img_itk)
+    img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
+    img_npy[img_bet == 0] = 0
+    out = sitk.GetImageFromArray(img_npy)
+    out.CopyInformation(img_itk)
+    sitk.WriteImage(out, out_fname)
+
+
+def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
+               postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
+    """
+
+    :param mri_fnames: str or list/tuple of str
+    :param output_fnames: str or list/tuple of str. If list: must have the same length as output_fnames
+    :param mode: fast or accurate
+    :param config_file: config.py
+    :param device: either int (for device id) or 'cpu'
+    :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
+    but the largest predicted connected component. Default False
+    :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
+    CPU you may want to turn that off to speed things up
+    :return:
+    """
+
+    list_of_param_files = []
+
+    if mode == 'fast':
+        params_file = get_params_fname(0)
+        maybe_download_parameters(0)
+
+        list_of_param_files.append(params_file)
+    elif mode == 'accurate':
+        for i in range(5):
+            params_file = get_params_fname(i)
+            maybe_download_parameters(i)
+
+            list_of_param_files.append(params_file)
+    else:
+        raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)
+
+    assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"
+
+    cf = imp.load_source('cf', config_file)
+    cf = cf.config()
+
+    net, _ = cf.get_network(cf.val_use_train_mode, None)
+    if device == "cpu":
+        net = net.cpu()
+    else:
+        net.cuda(device)
+
+    if not isinstance(mri_fnames, (list, tuple)):
+        mri_fnames = [mri_fnames]
+
+    if not isinstance(output_fnames, (list, tuple)):
+        output_fnames = [output_fnames]
+
+    assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"
+
+    params = []
+    for p in list_of_param_files:
+        params.append(torch.load(p, map_location=lambda storage, loc: storage))
+
+    for in_fname, out_fname in zip(mri_fnames, output_fnames):
+        mask_fname = out_fname[:-7] + "_mask.nii.gz"
+        if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
+            print("File:", in_fname)
+            print("preprocessing...")
+            try:
+                data, data_dict = load_and_preprocess(in_fname)
+            except RuntimeError:
+                print("\nERROR\nCould not read file", in_fname, "\n")
+                continue
+            except AssertionError as e:
+                print(e)
+                continue
+
+            softmax_preds = []
+
+            print("prediction (CNN id)...")
+            for i, p in enumerate(params):
+                print(i)
+                net.load_state_dict(p)
+                net.eval()
+                net.apply(SetNetworkToVal(False, False))
+                _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
+                                                            cf.val_batch_size, cf.net_input_must_be_divisible_by,
+                                                            cf.val_min_size, device, cf.da_mirror_axes)
+                softmax_preds.append(softmax_pred[None])
+
+            seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
+
+            if postprocess:
+                seg = postprocess_prediction(seg)
+
+            print("exporting segmentation...")
+            save_segmentation_nifti(seg, data_dict, mask_fname)
+
+            apply_bet(in_fname, mask_fname, out_fname)
+
+            if not keep_mask:
+                os.remove(mask_fname)
+
+
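`run_hd_bet` is presumably the API the Space calls for skull stripping. A minimal usage sketch matching the signature above (file names are illustrative; `'fast'` uses a single fold instead of the 5-fold ensemble):

```python
from HD_BET.run import run_hd_bet  # assumes the module above is importable

# Illustrative file names; CPU with TTA off keeps the run-time manageable
run_hd_bet("sub-01_flair.nii.gz", "sub-01_flair_bet.nii.gz",
           mode="fast", device="cpu", do_tta=False)
```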
diff --git a/src/IDH/HD_BET/HD_BET/utils.py b/src/IDH/HD_BET/HD_BET/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ba72a3d4d70accfd1fdc313a2f80b8d4c4c6eea
--- /dev/null
+++ b/src/IDH/HD_BET/HD_BET/utils.py
@@ -0,0 +1,115 @@
+from urllib.request import urlopen
+import torch
+from torch import nn
+import numpy as np
+from skimage.morphology import label
+import os
+from HD_BET.paths import folder_with_parameter_files
+
+
+def get_params_fname(fold):
+    return os.path.join(folder_with_parameter_files, "%d.model" % fold)
+
+
+def maybe_download_parameters(fold=0, force_overwrite=False):
+    """
+    Downloads the parameters for some fold if it is not present yet.
+    :param fold:
+    :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
+    :return:
+    """
+
+    assert 0 <= fold <= 4, "fold must be between 0 and 4"
+
+    if not os.path.isdir(folder_with_parameter_files):
+        maybe_mkdir_p(folder_with_parameter_files)
+
+    out_filename = get_params_fname(fold)
+
+    if force_overwrite and os.path.isfile(out_filename):
+        os.remove(out_filename)
+
+    if not os.path.isfile(out_filename):
+        url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
+        print("Downloading", url, "...")
+        data = urlopen(url).read()
+        #out_filename = "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params/0.model"
+        with open(out_filename, 'wb') as f:
+            f.write(data)
+
+
+def init_weights(module):
+    if isinstance(module, nn.Conv3d):
+        module.weight = nn.init.kaiming_normal(module.weight, a=1e-2)
+        if module.bias is not None:
+            module.bias = nn.init.constant(module.bias, 0)
+
+
+def softmax_helper(x):
+    rpt = [1 for _ in range(len(x.size()))]
+    rpt[1] = x.size(1)
+    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
+    e_x = torch.exp(x - x_max)
+    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
+
+
+class SetNetworkToVal(object):
+    def __init__(self, use_dropout_sampling=False, norm_use_average=True):
+        self.norm_use_average = norm_use_average
+        self.use_dropout_sampling = use_dropout_sampling
+
+    def __call__(self, module):
+        if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout):
+            module.train(self.use_dropout_sampling)
+        elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \
+                isinstance(module, nn.InstanceNorm1d) \
+                or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \
+                isinstance(module, nn.BatchNorm1d):
+            module.train(not self.norm_use_average)
+
+
+def postprocess_prediction(seg):
+    # basically look for connected components and choose the largest one, delete everything else
+    print("running postprocessing... ")
+    mask = seg != 0
+    lbls = label(mask, connectivity=mask.ndim)
+    lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
+    largest_region = np.argmax(lbls_sizes[1:]) + 1
+    seg[lbls != largest_region] = 0
+    return seg
+
+
+def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
+    if join:
+        l = os.path.join
+    else:
+        l = lambda x, y: y
+    res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
+           and (prefix is None or i.startswith(prefix))
+           and (suffix is None or i.endswith(suffix))]
+    if sort:
+        res.sort()
+    return res
+
+
+def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
+    if join:
+        l = os.path.join
+    else:
+        l = lambda x, y: y
+    res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
+           and (prefix is None or i.startswith(prefix))
+           and (suffix is None or i.endswith(suffix))]
+    if sort:
+        res.sort()
+    return res
+
+
+subfolders = subdirs  # I am tired of confusing those
+
+
+def maybe_mkdir_p(directory):
+    splits = directory.split("/")[1:]
+    for i in range(0, len(splits)):
+        if not os.path.isdir(os.path.join("", *splits[:i+1])):
+            os.mkdir(os.path.join("", *splits[:i+1]))
index 0000000000000000000000000000000000000000..81aa4cfb5eb39f270424b98107c40d28744386eb Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/hd_bet.cpython-38.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/network_architecture.cpython-310.pyc b/src/IDH/HD_BET/__pycache__/network_architecture.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e496ed582831f96a58feb9e3865c4081ab9d11db Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/network_architecture.cpython-310.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/network_architecture.cpython-38.pyc b/src/IDH/HD_BET/__pycache__/network_architecture.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c20f2e3c729378c589e22eff4e8ec151d20b470 Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/network_architecture.cpython-38.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/network_architecture.cpython-39.pyc b/src/IDH/HD_BET/__pycache__/network_architecture.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5516dcf581d4e55ef49ee14ff2dbac3860f0086 Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/network_architecture.cpython-39.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/paths.cpython-310.pyc b/src/IDH/HD_BET/__pycache__/paths.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac20a90e2d0069f0ca86e7dabb67a0baa1e93440 Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/paths.cpython-310.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/paths.cpython-38.pyc b/src/IDH/HD_BET/__pycache__/paths.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..567f623c1fa55d81bc9c1d4001d6aea1ab274f5f Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/paths.cpython-38.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/paths.cpython-39.pyc b/src/IDH/HD_BET/__pycache__/paths.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a2228880835eb9a907395309369a846da5010f2 Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/paths.cpython-39.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/predict_case.cpython-310.pyc b/src/IDH/HD_BET/__pycache__/predict_case.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3511e128de506944ecb85ca4fed03b31b5531dc Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/predict_case.cpython-310.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/predict_case.cpython-38.pyc b/src/IDH/HD_BET/__pycache__/predict_case.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a8f633c24280649275fb7bb7182991c07ced8eb Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/predict_case.cpython-38.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/predict_case.cpython-39.pyc b/src/IDH/HD_BET/__pycache__/predict_case.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..294527e04fcfd5d3be038dc1c8712e925d3352ff Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/predict_case.cpython-39.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/run.cpython-310.pyc b/src/IDH/HD_BET/__pycache__/run.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0cce40e0b34a33a6d930014fc91d903d4892fb5 Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/run.cpython-310.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/run.cpython-38.pyc b/src/IDH/HD_BET/__pycache__/run.cpython-38.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..5747d903a2b450fbeb16df2ec1f1d3680fcc6a48 Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/run.cpython-38.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/run.cpython-39.pyc b/src/IDH/HD_BET/__pycache__/run.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51e286390fe2b53d201a211aee37bf7b952fcfc3 Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/run.cpython-39.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/utils.cpython-310.pyc b/src/IDH/HD_BET/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4388f005b876e63a702375b68a13143344f4dff9 Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/utils.cpython-310.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/utils.cpython-38.pyc b/src/IDH/HD_BET/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad8b6ff21be7f57044f1a5e0f76501545f0fb97d Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/utils.cpython-38.pyc differ diff --git a/src/IDH/HD_BET/__pycache__/utils.cpython-39.pyc b/src/IDH/HD_BET/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe6e8221daae04ee32f9159e87904a914887135d Binary files /dev/null and b/src/IDH/HD_BET/__pycache__/utils.cpython-39.pyc differ diff --git a/src/IDH/HD_BET/config.py b/src/IDH/HD_BET/config.py new file mode 100644 index 0000000000000000000000000000000000000000..870951e5c9059fb9e20d6143e68266732f19234e --- /dev/null +++ b/src/IDH/HD_BET/config.py @@ -0,0 +1,121 @@ +import numpy as np +import torch +from HD_BET.utils import SetNetworkToVal, softmax_helper +from abc import abstractmethod +from HD_BET.network_architecture import Network + + +class BaseConfig(object): + def __init__(self): + pass + + @abstractmethod + def get_split(self, fold, random_state=12345): + pass + + @abstractmethod + def get_network(self, mode="train"): + pass + + @abstractmethod + def get_basic_generators(self, fold): + pass + + @abstractmethod + def get_data_generators(self, fold): + pass + + def preprocess(self, data): + return data + + def __repr__(self): + res = "" + for v in vars(self): + if not v.startswith("__") and not v.startswith("_") and v != 'dataset': + res += (v + ": " + str(self.__getattribute__(v)) + "\n") + return res + + +class HD_BET_Config(BaseConfig): + def __init__(self): + super(HD_BET_Config, self).__init__() + + self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name + + # network parameters + self.net_base_num_layers = 21 + self.BATCH_SIZE = 2 + self.net_do_DS = True + self.net_dropout_p = 0.0 + self.net_use_inst_norm = True + self.net_conv_use_bias = True + self.net_norm_use_affine = True + self.net_leaky_relu_slope = 1e-1 + + # hyperparameters + self.INPUT_PATCH_SIZE = (128, 128, 128) + self.num_classes = 2 + self.selected_data_channels = range(1) + + # data augmentation + self.da_mirror_axes = (2, 3, 4) + + # validation + self.val_use_DO = False + self.val_use_train_mode = False # for dropout sampling + self.val_num_repeats = 1 # only useful if dropout sampling + self.val_batch_size = 1 # only useful if dropout sampling + self.val_save_npz = True + self.val_do_mirroring = True # test time data augmentation via mirroring + self.val_write_images = True + self.net_input_must_be_divisible_by = 16 # we could make a network class that has this as a property + self.val_min_size = self.INPUT_PATCH_SIZE + self.val_fn = None + + # 
CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_runnings_ + # stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but unused and defaults + # to false in 0.4) + self.val_use_moving_averages = False + + def get_network(self, train=True, pretrained_weights=None): + net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers, + self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias, + self.net_norm_use_affine, True, self.net_do_DS) + + if pretrained_weights is not None: + net.load_state_dict( + torch.load(pretrained_weights, map_location=lambda storage, loc: storage)) + + if train: + net.train(True) + else: + net.train(False) + net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages)) + net.do_ds = False + + optimizer = None + self.lr_scheduler = None + return net, optimizer + + def get_data_generators(self, fold): + pass + + def get_split(self, fold, random_state=12345): + pass + + def get_basic_generators(self, fold): + pass + + def on_epoch_end(self, epoch): + pass + + def preprocess(self, data): + data = np.copy(data) + for c in range(data.shape[0]): + data[c] -= data[c].mean() + data[c] /= data[c].std() + return data + + +config = HD_BET_Config + diff --git a/src/IDH/HD_BET/data_loading.py b/src/IDH/HD_BET/data_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec4be63a8186b65bfb390770fefa6217b5dd2c5 --- /dev/null +++ b/src/IDH/HD_BET/data_loading.py @@ -0,0 +1,121 @@ +import SimpleITK as sitk +import numpy as np +from skimage.transform import resize + + +def resize_image(image, old_spacing, new_spacing, order=3): + new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), + int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), + int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) + return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False) + + +def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)): + spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]] + image = sitk.GetArrayFromImage(itk_image).astype(float) + + assert len(image.shape) == 3, "The image has unsupported number of dimensions. 
Only 3D images are allowed" + + if not is_seg: + if np.any([[i != j] for i, j in zip(spacing, spacing_target)]): + image = resize_image(image, spacing, spacing_target).astype(np.float32) + + image -= image.mean() + image /= image.std() + else: + new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))), + int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))), + int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2])))) + image = resize_segmentation(image, new_shape, 1) + return image + + +def load_and_preprocess(mri_file): + images = {} + # t1 + images["T1"] = sitk.ReadImage(mri_file) + + properties_dict = { + "spacing": images["T1"].GetSpacing(), + "direction": images["T1"].GetDirection(), + "size": images["T1"].GetSize(), + "origin": images["T1"].GetOrigin() + } + + for k in images.keys(): + images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5)) + + properties_dict['size_before_cropping'] = images["T1"].shape + + imgs = [] + for seq in ['T1']: + imgs.append(images[seq][None]) + all_data = np.vstack(imgs) + print("image shape after preprocessing: ", str(all_data[0].shape)) + return all_data, properties_dict + + +def save_segmentation_nifti(segmentation, dct, out_fname, order=1): + ''' + segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out + of the original image + + dct: + size_before_cropping + brain_bbox + size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping + spacing + origin + direction + + :param segmentation: + :param dct: + :param out_fname: + :return: + ''' + old_size = dct.get('size_before_cropping') + bbox = dct.get('brain_bbox') + if bbox is not None: + seg_old_size = np.zeros(old_size) + for c in range(3): + bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c])) + seg_old_size[bbox[0][0]:bbox[0][1], + bbox[1][0]:bbox[1][1], + bbox[2][0]:bbox[2][1]] = segmentation + else: + seg_old_size = segmentation + if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]): + seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order) + else: + seg_old_spacing = seg_old_size + seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32)) + seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]]) + seg_resized_itk.SetOrigin(dct['origin']) + seg_resized_itk.SetDirection(dct['direction']) + sitk.WriteImage(seg_resized_itk, out_fname) + + +def resize_segmentation(segmentation, new_shape, order=3, cval=0): + ''' + Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency + + Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one + hot encoding which is resized and transformed back to a segmentation map. 
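As a quick illustration of the one-hot trick this docstring describes (an editorial sketch, standalone and not part of the patch; it assumes only numpy and scikit-image):

```python
import numpy as np
from skimage.transform import resize

seg = np.zeros((4, 4, 4), dtype=np.uint8)
seg[2:, 2:, 2:] = 2  # labels present: 0 and 2; label 1 does not exist

# Interpolating raw label values can invent the non-existent label 1
naive = resize(seg.astype(float), (8, 8, 8), order=1, mode="edge", anti_aliasing=False)
print(np.unique(np.round(naive)))   # typically [0. 1. 2.]

# Resizing each label as its own binary mask only ever reproduces existing labels
out = np.zeros((8, 8, 8), dtype=seg.dtype)
for c in np.unique(seg):
    m = resize((seg == c).astype(float), (8, 8, 8), order=1, mode="edge", anti_aliasing=False)
    out[m >= 0.5] = c
print(np.unique(out))               # [0 2]
```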
+    This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
+    :param segmentation:
+    :param new_shape:
+    :param order:
+    :return:
+    '''
+    tpe = segmentation.dtype
+    unique_labels = np.unique(segmentation)
+    assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
+    if order == 0:
+        return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
+    else:
+        reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
+
+        for i, c in enumerate(unique_labels):
+            reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
+            reshaped[reshaped_multihot >= 0.5] = c
+        return reshaped
diff --git a/src/IDH/HD_BET/hd_bet.py b/src/IDH/HD_BET/hd_bet.py
new file mode 100644
index 0000000000000000000000000000000000000000..128575b6cfb4bdd98bf417ed598f905ef4896fd1
--- /dev/null
+++ b/src/IDH/HD_BET/hd_bet.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+
+import os
+import sys
+sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")
+from HD_BET.run import run_hd_bet
+from HD_BET.utils import maybe_mkdir_p, subfiles
+import HD_BET
+
+
+def hd_bet(input_file_or_dir, output_file_or_dir, mode, device, tta, pp=1, save_mask=0, overwrite_existing=1):
+
+    if output_file_or_dir is None:
+        output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
+                                          os.path.basename(input_file_or_dir).split(".")[0] + "_bet")
+
+    params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
+    config_file = os.path.join(HD_BET.__path__[0], "config.py")
+
+    assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
+
+    if device == 'cpu':
+        pass
+    else:
+        device = int(device)
+
+    if os.path.isdir(input_file_or_dir):
+        maybe_mkdir_p(output_file_or_dir)
+        input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)
+
+        if len(input_files) == 0:
+            raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here")
+
+        output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
+        input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
+    else:
+        if not output_file_or_dir.endswith('.nii.gz'):
+            output_file_or_dir += '.nii.gz'
+            assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
+
+        output_files = [output_file_or_dir]
+        input_files = [input_file_or_dir]
+
+    if tta == 0:
+        tta = False
+    elif tta == 1:
+        tta = True
+    else:
+        raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))
+
+    if overwrite_existing == 0:
+        overwrite_existing = False
+    elif overwrite_existing == 1:
+        overwrite_existing = True
+    else:
+        raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))
+
+    if pp == 0:
+        pp = False
+    elif pp == 1:
+        pp = True
+    else:
+        raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
+
+    if save_mask == 0:
+        save_mask = False
+    elif save_mask == 1:
+        save_mask = True
+    else:
+        raise ValueError("Unknown value for save_mask: %s. Expected: 0 or 1" % str(save_mask))
+
+    run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)
+
+
+if __name__ == "__main__":
+    print("\n########################")
+    print("If you are using hd-bet, please cite the following paper:")
+    print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W, "
+          "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial "
+          "neural networks. arXiv preprint arXiv:1901.11341, 2019.")
+    print("########################\n")
+
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be '
+                                              'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to '
+                                              'split 4d sequences into 3d images. If folder: all files ending with .nii.gz '
+                                              'within that folder will be brain extracted.', required=True, type=str)
+    parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder'
+                                               ' will be created', required=False, type=str)
+    parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will '
+                                                                    'use only one set of parameters whereas accurate will '
+                                                                    'use the five sets of parameters that resulted from '
+                                                                    'our cross-validation as an ensemble. Default: '
+                                                                    'accurate',
+                        required=False)
+    parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. '
+                                                               'Must be either int or str. Use int for GPU id or '
+                                                               '\'cpu\' to run on CPU. When using CPU you should '
+                                                               'consider disabling tta. Default for -device is: 0',
+                        required=False)
+    parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation '
+                                                                          '(mirroring). 1= True, 0=False. Disable this '
+                                                                          'if you are using CPU to speed things up! '
+                                                                          'Default: 1')
+    parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disable postprocessing (remove all'
+                                                                         ' but the largest connected component in '
+                                                                         'the prediction. 
Default: 1') + parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation ' + 'mask will not be ' + 'saved') + parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't " + "want to overwrite existing " + "predictions") + + args = parser.parse_args() + + hd_bet(args.input,args.output,args.mode,args.device,args.tta,args.pp,args.save_mask,args.overwrite_existing) + diff --git a/src/IDH/HD_BET/network_architecture.py b/src/IDH/HD_BET/network_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..0824aa10839024368ad8ab38c637ce81aa9327e5 --- /dev/null +++ b/src/IDH/HD_BET/network_architecture.py @@ -0,0 +1,213 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from HD_BET.utils import softmax_helper + + +class EncodingModule(nn.Module): + def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True): + nn.Module.__init__(self) + self.dropout_p = dropout_p + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) + self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias) + self.dropout = nn.Dropout3d(dropout_p) + self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) + self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias) + + def forward(self, x): + skip = x + x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + x = self.conv1(x) + if self.dropout_p is not None and self.dropout_p > 0: + x = self.dropout(x) + x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + x = self.conv2(x) + x = x + skip + return x + + +class Upsample(nn.Module): + def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True): + super(Upsample, self).__init__() + self.align_corners = align_corners + self.mode = mode + self.scale_factor = scale_factor + self.size = size + + def forward(self, x): + return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, + align_corners=self.align_corners) + + +class LocalizationModule(nn.Module): + def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, + lrelu_inplace=True): + nn.Module.__init__(self) + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias) + self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) + self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias) + self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True) + + def forward(self, x): + x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + return x + + +class UpsamplingModule(nn.Module): + def __init__(self, in_channels, out_channels, 
leakiness=1e-2, conv_bias=True, inst_norm_affine=True, + lrelu_inplace=True): + nn.Module.__init__(self) + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True) + self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias) + self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True) + + def forward(self, x): + x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness, + inplace=self.lrelu_inplace) + return x + + +class DownsamplingModule(nn.Module): + def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, + lrelu_inplace=True): + nn.Module.__init__(self) + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) + self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias) + + def forward(self, x): + x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + b = self.downsample(x) + return x, b + + +class Network(nn.Module): + def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3, + final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, + lrelu_inplace=True, do_ds=True): + super(Network, self).__init__() + + self.do_ds = do_ds + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.final_nonlin = final_nonlin + self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias) + + self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + + self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + + self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + + self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + + self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2, + conv_bias=True, inst_norm_affine=True, lrelu_inplace=True) + + self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True) + self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, 
lrelu_inplace=True) + + self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + + self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False) + self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + + self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False) + self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True, + inst_norm_affine=True, lrelu_inplace=True) + + self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias) + self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True) + self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias) + self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True) + self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False) + + def forward(self, x): + seg_outputs = [] + + x = self.init_conv(x) + x = self.context1(x) + + skip1, x = self.down1(x) + x = self.context2(x) + + skip2, x = self.down2(x) + x = self.context3(x) + + skip3, x = self.down3(x) + x = self.context4(x) + + skip4, x = self.down4(x) + x = self.context5(x) + + x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + x = self.up1(x) + + x = torch.cat((skip4, x), dim=1) + x = self.loc1(x) + x = self.up2(x) + + x = torch.cat((skip3, x), dim=1) + x = self.loc2(x) + loc2_seg = self.final_nonlin(self.loc2_seg(x)) + seg_outputs.append(loc2_seg) + x = self.up3(x) + + x = torch.cat((skip2, x), dim=1) + x = self.loc3(x) + loc3_seg = self.final_nonlin(self.loc3_seg(x)) + seg_outputs.append(loc3_seg) + x = self.up4(x) + + x = torch.cat((skip1, x), dim=1) + x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness, + inplace=self.lrelu_inplace) + x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness, + inplace=self.lrelu_inplace) + x = self.final_nonlin(self.seg_layer(x)) + seg_outputs.append(x) + + if self.do_ds: + return seg_outputs[::-1] + else: + return seg_outputs[-1] diff --git a/src/IDH/HD_BET/paths.py b/src/IDH/HD_BET/paths.py new file mode 100644 index 0000000000000000000000000000000000000000..9d125b947184c03d1a12ebd77e39a75a730c4594 --- /dev/null +++ b/src/IDH/HD_BET/paths.py @@ -0,0 +1,6 @@ +import os + +# please refer to the readme on where to get the parameters. 
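The decoder above emits segmentation heads at three scales when deep supervision is enabled; a minimal sketch of that behaviour (assuming src/IDH is on PYTHONPATH so the HD_BET package imports resolve):

```python
import torch
from HD_BET.network_architecture import Network

# 2 classes and 1 input channel, as in HD_BET_Config
net = Network(num_classes=2, num_input_channels=1, base_filters=16, do_ds=True)
net.eval()

with torch.no_grad():
    x = torch.randn(1, 1, 64, 64, 64)  # spatial dims must be divisible by 16
    outs = net(x)                      # list of heads, full resolution first
    print([tuple(o.shape) for o in outs])
    # -> [(1, 2, 64, 64, 64), (1, 2, 32, 32, 32), (1, 2, 16, 16, 16)]

    net.do_ds = False                  # inference: single full-resolution output
    print(tuple(net(x).shape))         # -> (1, 2, 64, 64, 64)
```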
Save them in this folder: +# Original Path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params" +# Updated path for Docker container: +folder_with_parameter_files = "/app/IDH/hdbet_model" diff --git a/src/IDH/HD_BET/predict_case.py b/src/IDH/HD_BET/predict_case.py new file mode 100644 index 0000000000000000000000000000000000000000..559c66739ae890f7e985e072eb49ce0ee0484978 --- /dev/null +++ b/src/IDH/HD_BET/predict_case.py @@ -0,0 +1,126 @@ +import torch +import numpy as np + + +def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None): + if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)): + shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3 + shp = patient.shape + new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0], + shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1], + shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]] + for i in range(len(shp)): + if shp[i] % shape_must_be_divisible_by[i] == 0: + new_shp[i] -= shape_must_be_divisible_by[i] + if min_size is not None: + new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0) + return reshape_by_padding_upper_coords(patient, new_shp, 0), shp + + +def reshape_by_padding_upper_coords(image, new_shape, pad_value=None): + shape = tuple(list(image.shape)) + new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0)) + if pad_value is None: + if len(shape) == 2: + pad_value = image[0,0] + elif len(shape) == 3: + pad_value = image[0, 0, 0] + else: + raise ValueError("Image must be either 2 or 3 dimensional") + res = np.ones(list(new_shape), dtype=image.dtype) * pad_value + if len(shape) == 2: + res[0:0+int(shape[0]), 0:0+int(shape[1])] = image + elif len(shape) == 3: + res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image + return res + + +def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None, + new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)): + with torch.no_grad(): + pad_res = [] + for i in range(patient_data.shape[0]): + t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size) + pad_res.append(t[None]) + + patient_data = np.vstack(pad_res) + + new_shp = patient_data.shape + + data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32) + + data[0] = patient_data + + if BATCH_SIZE is not None: + data = np.vstack([data] * BATCH_SIZE) + + a = torch.rand(data.shape).float() + + if main_device == 'cpu': + pass + else: + a = a.cuda(main_device) + + if do_mirroring: + x = 8 + else: + x = 1 + all_preds = [] + for i in range(num_repeats): + for m in range(x): + data_for_net = np.array(data) + do_stuff = False + if m == 0: + do_stuff = True + pass + if m == 1 and (4 in mirror_axes): + do_stuff = True + data_for_net = data_for_net[:, :, :, :, ::-1] + if m == 2 and (3 in mirror_axes): + do_stuff = True + data_for_net = data_for_net[:, :, :, ::-1, :] + if m == 3 and (4 in mirror_axes) and (3 in mirror_axes): + do_stuff = True + data_for_net = data_for_net[:, :, :, ::-1, ::-1] + if m == 4 and (2 in mirror_axes): + do_stuff = True + data_for_net = data_for_net[:, :, ::-1, :, :] + if m == 5 and (2 in mirror_axes) and (4 in mirror_axes): + do_stuff = True + data_for_net = data_for_net[:, :, ::-1, :, ::-1] + if m == 6 and (2 in mirror_axes) and (3 in mirror_axes): + do_stuff = 
True
+                    data_for_net = data_for_net[:, :, ::-1, ::-1, :]
+                if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]
+
+                if do_stuff:
+                    # np.copy is necessary because ::-1 slicing only creates a view
+                    _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
+                    p = net(a)
+                    p = p.data.cpu().numpy()
+
+                    if m == 0:
+                        pass
+                    if m == 1 and (4 in mirror_axes):
+                        p = p[:, :, :, :, ::-1]
+                    if m == 2 and (3 in mirror_axes):
+                        p = p[:, :, :, ::-1, :]
+                    if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
+                        p = p[:, :, :, ::-1, ::-1]
+                    if m == 4 and (2 in mirror_axes):
+                        p = p[:, :, ::-1, :, :]
+                    if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
+                        p = p[:, :, ::-1, :, ::-1]
+                    if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
+                        p = p[:, :, ::-1, ::-1, :]
+                    if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
+                        p = p[:, :, ::-1, ::-1, ::-1]
+                    all_preds.append(p)
+
+        stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
+        predicted_segmentation = stacked.mean(0).argmax(0)
+        uncertainty = stacked.var(0)
+        bayesian_predictions = stacked
+        softmax_pred = stacked.mean(0)
+        return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
diff --git a/src/IDH/HD_BET/run.py b/src/IDH/HD_BET/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..858934d8f67175df508884e9030f8d38ba0d07cf
--- /dev/null
+++ b/src/IDH/HD_BET/run.py
@@ -0,0 +1,117 @@
+import torch
+import numpy as np
+import SimpleITK as sitk
+from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
+from HD_BET.predict_case import predict_case_3D_net
+import imp  # deprecated stdlib module; works on the Python 3.10 base image, removed in 3.12
+from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
+import os
+import HD_BET
+
+
+def apply_bet(img, bet, out_fname):
+    img_itk = sitk.ReadImage(img)
+    img_npy = sitk.GetArrayFromImage(img_itk)
+    img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
+    img_npy[img_bet == 0] = 0
+    out = sitk.GetImageFromArray(img_npy)
+    out.CopyInformation(img_itk)
+    sitk.WriteImage(out, out_fname)
+
+
+def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
+               postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
+    """
+
+    :param mri_fnames: str or list/tuple of str
+    :param output_fnames: str or list/tuple of str. If list: must have the same length as mri_fnames
+    :param mode: fast or accurate
+    :param config_file: config.py
+    :param device: either int (for device id) or 'cpu'
+    :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
+    but the largest predicted connected component. Default False
+    :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
+    CPU you may want to turn that off to speed things up
+    :return:
+    """
+
+    list_of_param_files = []
+
+    if mode == 'fast':
+        params_file = get_params_fname(0)
+        maybe_download_parameters(0)
+
+        list_of_param_files.append(params_file)
+    elif mode == 'accurate':
+        for i in range(5):
+            params_file = get_params_fname(i)
+            maybe_download_parameters(i)
+
+            list_of_param_files.append(params_file)
+    else:
+        raise ValueError("Unknown value for mode: %s. 
Expected: fast or accurate" % mode) + + assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files" + + cf = imp.load_source('cf', config_file) + cf = cf.config() + + net, _ = cf.get_network(cf.val_use_train_mode, None) + if device == "cpu": + net = net.cpu() + else: + net.cuda(device) + + if not isinstance(mri_fnames, (list, tuple)): + mri_fnames = [mri_fnames] + + if not isinstance(output_fnames, (list, tuple)): + output_fnames = [output_fnames] + + assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length" + + params = [] + for p in list_of_param_files: + params.append(torch.load(p, map_location=lambda storage, loc: storage)) + + for in_fname, out_fname in zip(mri_fnames, output_fnames): + mask_fname = out_fname[:-7] + "_mask.nii.gz" + if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)): + print("File:", in_fname) + print("preprocessing...") + try: + data, data_dict = load_and_preprocess(in_fname) + except RuntimeError: + print("\nERROR\nCould not read file", in_fname, "\n") + continue + except AssertionError as e: + print(e) + continue + + softmax_preds = [] + + print("prediction (CNN id)...") + for i, p in enumerate(params): + print(i) + net.load_state_dict(p) + net.eval() + net.apply(SetNetworkToVal(False, False)) + _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats, + cf.val_batch_size, cf.net_input_must_be_divisible_by, + cf.val_min_size, device, cf.da_mirror_axes) + softmax_preds.append(softmax_pred[None]) + + seg = np.argmax(np.vstack(softmax_preds).mean(0), 0) + + if postprocess: + seg = postprocess_prediction(seg) + + print("exporting segmentation...") + save_segmentation_nifti(seg, data_dict, mask_fname) + + apply_bet(in_fname, mask_fname, out_fname) + + if not keep_mask: + os.remove(mask_fname) + + diff --git a/src/IDH/HD_BET/utils.py b/src/IDH/HD_BET/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ba72a3d4d70accfd1fdc313a2f80b8d4c4c6eea --- /dev/null +++ b/src/IDH/HD_BET/utils.py @@ -0,0 +1,115 @@ +from urllib.request import urlopen +import torch +from torch import nn +import numpy as np +from skimage.morphology import label +import os +from HD_BET.paths import folder_with_parameter_files + + +def get_params_fname(fold): + return os.path.join(folder_with_parameter_files, "%d.model" % fold) + + +def maybe_download_parameters(fold=0, force_overwrite=False): + """ + Downloads the parameters for some fold if it is not present yet. 
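For scripted use, run_hd_bet (defined above) can be driven directly; a hedged sketch with placeholder paths (it assumes the .model parameter files are already present in folder_with_parameter_files or downloadable):

```python
from HD_BET.run import run_hd_bet

run_hd_bet(
    mri_fnames="/data/case01_flair.nii.gz",         # placeholder input path
    output_fnames="/data/case01_flair_bet.nii.gz",  # placeholder output path
    mode="fast",        # single fold instead of the 5-fold ensemble
    device="cpu",
    do_tta=False,       # mirroring TTA is slow on CPU
    postprocess=True,   # keep only the largest connected component
    keep_mask=True,     # also keep the *_mask.nii.gz file
)
```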
:param fold:
+    :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
+    :return:
+    """
+
+    assert 0 <= fold <= 4, "fold must be between 0 and 4"
+
+    if not os.path.isdir(folder_with_parameter_files):
+        maybe_mkdir_p(folder_with_parameter_files)
+
+    out_filename = get_params_fname(fold)
+
+    if force_overwrite and os.path.isfile(out_filename):
+        os.remove(out_filename)
+
+    if not os.path.isfile(out_filename):
+        url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
+        print("Downloading", url, "...")
+        data = urlopen(url).read()
+        with open(out_filename, 'wb') as f:
+            f.write(data)
+
+
+def init_weights(module):
+    # use the in-place initializers with a trailing underscore; the old
+    # kaiming_normal/constant aliases have been deprecated since PyTorch 0.4
+    if isinstance(module, nn.Conv3d):
+        nn.init.kaiming_normal_(module.weight, a=1e-2)
+        if module.bias is not None:
+            nn.init.constant_(module.bias, 0)
+
+
+def softmax_helper(x):
+    rpt = [1 for _ in range(len(x.size()))]
+    rpt[1] = x.size(1)
+    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
+    e_x = torch.exp(x - x_max)
+    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
+
+
+class SetNetworkToVal(object):
+    def __init__(self, use_dropout_sampling=False, norm_use_average=True):
+        self.norm_use_average = norm_use_average
+        self.use_dropout_sampling = use_dropout_sampling
+
+    def __call__(self, module):
+        if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout):
+            module.train(self.use_dropout_sampling)
+        elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \
+                isinstance(module, nn.InstanceNorm1d) \
+                or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \
+                isinstance(module, nn.BatchNorm1d):
+            module.train(not self.norm_use_average)
+
+
+def postprocess_prediction(seg):
+    # look for connected components and keep only the largest one, deleting everything else
+    print("running postprocessing... 
") + mask = seg != 0 + lbls = label(mask, connectivity=mask.ndim) + lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)] + largest_region = np.argmax(lbls_sizes[1:]) + 1 + seg[lbls != largest_region] = 0 + return seg + + +def subdirs(folder, join=True, prefix=None, suffix=None, sort=True): + if join: + l = os.path.join + else: + l = lambda x, y: y + res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i)) + and (prefix is None or i.startswith(prefix)) + and (suffix is None or i.endswith(suffix))] + if sort: + res.sort() + return res + + +def subfiles(folder, join=True, prefix=None, suffix=None, sort=True): + if join: + l = os.path.join + else: + l = lambda x, y: y + res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i)) + and (prefix is None or i.startswith(prefix)) + and (suffix is None or i.endswith(suffix))] + if sort: + res.sort() + return res + + +subfolders = subdirs # I am tired of confusing those + + +def maybe_mkdir_p(directory): + splits = directory.split("/")[1:] + for i in range(0, len(splits)): + if not os.path.isdir(os.path.join("", *splits[:i+1])): + os.mkdir(os.path.join("", *splits[:i+1])) diff --git a/src/IDH/app_gradio.py b/src/IDH/app_gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..181d0206cf4faee341eb129a5584e0ac6e42066f --- /dev/null +++ b/src/IDH/app_gradio.py @@ -0,0 +1,1015 @@ +import os +import yaml +import torch +import nibabel as nib +import numpy as np +import gradio as gr +from typing import Tuple +import tempfile +import shutil +import matplotlib.pyplot as plt +import matplotlib +matplotlib.use('Agg') # Use non-interactive backend +import cv2 # For Gaussian Blur +import io # For saving plots to memory +import base64 # For encoding plots +import uuid # For unique IDs +import traceback # For detailed error printing + +import SimpleITK as sitk +import itk +from scipy.signal import medfilt +import skimage.filters + +from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, Resized, NormalizeIntensityd, ToTensord, EnsureTyped +from monai.inferers import sliding_window_inference + +from model import ViTUNETRSegmentationModel + +# Optional HD-BET import (packaged locally like in MCI app) +try: + from HD_BET.run import run_hd_bet + from HD_BET.hd_bet import hd_bet +except Exception as e: + print(f"Warning: HD_BET not available: {e}") + run_hd_bet = None + hd_bet = None + + +APP_DIR = os.path.dirname(__file__) +TEMPLATE_DIR = os.path.join(APP_DIR, "golden_image", "mni_templates") +PARAMS_RIGID_PATH = os.path.join(TEMPLATE_DIR, "Parameters_Rigid.txt") +DEFAULT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "temp_head.nii.gz") +FLAIR_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_04.5-18.5_t2w.nii") +HD_BET_CONFIG_PATH = os.path.join(APP_DIR, "HD_BET", "config.py") +HD_BET_MODEL_DIR = os.path.join(APP_DIR, "hdbet_model") + + +def load_config() -> dict: + cfg_path = os.path.join(APP_DIR, "config.yml") + if os.path.exists(cfg_path): + with open(cfg_path, "r") as f: + return yaml.safe_load(f) + # Defaults + return { + "gpu": {"device": "cpu"}, + "infer": { + "checkpoints": "./checkpoints/idh_model.ckpt", + "simclr_checkpoint": None, + "threshold": 0.5, + "image_size": [96, 96, 96], + }, + } + + +def build_model(cfg: dict): + device = torch.device(cfg.get("gpu", {}).get("device", "cpu")) + infer_cfg = cfg.get("infer", {}) + model_cfg = cfg.get("model", {}) + simclr_path = None#os.path.join(APP_DIR, infer_cfg.get("simclr_checkpoint", "")) + ckpt_path = 
os.path.join(APP_DIR, infer_cfg.get("checkpoints", "")) + + model = ViTUNETRSegmentationModel( + simclr_ckpt_path=None, + img_size=tuple(model_cfg.get("img_size", [96, 96, 96])), + in_channels=model_cfg.get("in_channels", 1), + out_channels=model_cfg.get("out_channels", 1) + ) + + # Load finetuned checkpoint (Lightning or plain state_dict) + if os.path.exists(ckpt_path): + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("model."): + new_state_dict[key[len("model."):]] = value + else: + new_state_dict[key] = value + else: + new_state_dict = checkpoint + model.load_state_dict(new_state_dict, strict=False) + else: + print(f"Warning: Segmentation checkpoint not found at {ckpt_path}. Model will use backbone-only weights.") + + model.to(device) + model.eval() + return model, device + + +# ---------------- Preprocessing (Registration + Enhancement + Skull Stripping) ---------------- + +def bias_field_correction(img_array: np.ndarray) -> np.ndarray: + image = sitk.GetImageFromArray(img_array.astype(np.float32)) + if image.GetPixelID() != sitk.sitkFloat32: + image = sitk.Cast(image, sitk.sitkFloat32) + maskImage = sitk.OtsuThreshold(image, 0, 1, 200) + corrector = sitk.N4BiasFieldCorrectionImageFilter() + numberFittingLevels = 4 + max_iters = [min(50 * (2 ** i), 200) for i in range(numberFittingLevels)] + corrector.SetMaximumNumberOfIterations(max_iters) + corrected_image = corrector.Execute(image, maskImage) + return sitk.GetArrayFromImage(corrected_image) + + +def denoise(volume: np.ndarray, kernel_size: int = 3) -> np.ndarray: + return medfilt(volume, kernel_size) + + +def rescale_intensity(volume: np.ndarray, percentils=[0.5, 99.5], bins_num=256) -> np.ndarray: + volume_float = volume.astype(np.float32) + try: + t = skimage.filters.threshold_otsu(volume_float, nbins=256) + volume_masked = np.copy(volume_float) + volume_masked[volume_masked < t] = 0 + obj_volume = volume_masked[np.where(volume_masked > 0)] + except ValueError: + obj_volume = volume_float.flatten() + if obj_volume.size == 0: + obj_volume = volume_float.flatten() + min_value = np.min(obj_volume) + max_value = np.max(obj_volume) + else: + min_value = np.percentile(obj_volume, percentils[0]) + max_value = np.percentile(obj_volume, percentils[1]) + denom = max_value - min_value + if denom < 1e-6: + denom = 1e-6 + if bins_num == 0: + output_volume = (volume_float - min_value) / denom + output_volume = np.clip(output_volume, 0.0, 1.0) + else: + output_volume = np.round((volume_float - min_value) / denom * (bins_num - 1)) + output_volume = np.clip(output_volume, 0, bins_num - 1) + return output_volume.astype(np.float32) + + +def equalize_hist(volume: np.ndarray, bins_num=256) -> np.ndarray: + mask = volume > 1e-6 + obj_volume = volume[mask] + if obj_volume.size == 0: + return volume + hist, bins = np.histogram(obj_volume, bins_num, range=(obj_volume.min(), obj_volume.max())) + cdf = hist.cumsum() + cdf_normalized = (bins_num - 1) * cdf / float(cdf[-1]) + equalized_obj_volume = np.interp(obj_volume, bins[:-1], cdf_normalized) + equalized_volume = np.copy(volume) + equalized_volume[mask] = equalized_obj_volume + return equalized_volume.astype(np.float32) + + +def run_enhance_on_file(input_nifti_path: str, output_nifti_path: str): + """ + Simplified enhancement - just copy the file since N4 is now done in registration. 
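For reference, the N4 correction that moved into the registration step boils down to the following (an illustrative sketch; the file names are placeholders):

```python
import SimpleITK as sitk

img = sitk.ReadImage("flair.nii.gz", sitk.sitkFloat32)  # placeholder path

# Convenience form used in register_image_sitk below:
corrected = sitk.N4BiasFieldCorrection(img)

# Explicit form used in bias_field_correction above, with an Otsu foreground mask:
mask = sitk.OtsuThreshold(img, 0, 1, 200)
n4 = sitk.N4BiasFieldCorrectionImageFilter()
n4.SetMaximumNumberOfIterations([50, 100, 200, 200])  # min(50 * 2**i, 200) per fitting level
corrected_explicit = n4.Execute(img, mask)
sitk.WriteImage(corrected_explicit, "flair_n4.nii.gz")
```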
+ This maintains compatibility with the existing preprocessing pipeline. + """ + print(f"Enhancement step (N4 already applied during registration): {input_nifti_path}") + # Since N4 bias correction is now handled in registration, just copy the file + import shutil + shutil.copy2(input_nifti_path, output_nifti_path) + print(f"Enhancement complete (passthrough): {output_nifti_path}") + + +def register_image_sitk(input_nifti_path: str, output_nifti_path: str, template_path: str, interp_type='linear'): + """ + MRI registration with SimpleITK matching the provided script approach. + + Args: + input_nifti_path: Path to input NIfTI file + output_nifti_path: Path to save registered output + template_path: Path to template image + interp_type: Interpolation type ('linear', 'bspline', 'nearest_neighbor') + """ + print(f"Registering {input_nifti_path} to template {template_path}") + + # Read template and moving images + fixed_img = sitk.ReadImage(template_path, sitk.sitkFloat32) + moving_img = sitk.ReadImage(input_nifti_path, sitk.sitkFloat32) + + # Apply N4 bias correction to moving image + moving_img = sitk.N4BiasFieldCorrection(moving_img) + + # Resample fixed image to 1mm isotropic + old_size = fixed_img.GetSize() + old_spacing = fixed_img.GetSpacing() + new_spacing = (1, 1, 1) + new_size = [ + int(round((old_size[0] * old_spacing[0]) / float(new_spacing[0]))), + int(round((old_size[1] * old_spacing[1]) / float(new_spacing[1]))), + int(round((old_size[2] * old_spacing[2]) / float(new_spacing[2]))) + ] + + # Set interpolation type + if interp_type == 'linear': + interp_type = sitk.sitkLinear + elif interp_type == 'bspline': + interp_type = sitk.sitkBSpline + elif interp_type == 'nearest_neighbor': + interp_type = sitk.sitkNearestNeighbor + else: + interp_type = sitk.sitkLinear + + # Resample fixed image + resample = sitk.ResampleImageFilter() + resample.SetOutputSpacing(new_spacing) + resample.SetSize(new_size) + resample.SetOutputOrigin(fixed_img.GetOrigin()) + resample.SetOutputDirection(fixed_img.GetDirection()) + resample.SetInterpolator(interp_type) + resample.SetDefaultPixelValue(fixed_img.GetPixelIDValue()) + resample.SetOutputPixelType(sitk.sitkFloat32) + fixed_img = resample.Execute(fixed_img) + + # Initialize transform + transform = sitk.CenteredTransformInitializer( + fixed_img, + moving_img, + sitk.Euler3DTransform(), + sitk.CenteredTransformInitializerFilter.GEOMETRY) + + # Set up registration method + registration_method = sitk.ImageRegistrationMethod() + registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50) + registration_method.SetMetricSamplingStrategy(registration_method.RANDOM) + registration_method.SetMetricSamplingPercentage(0.01) + registration_method.SetInterpolator(sitk.sitkLinear) + registration_method.SetOptimizerAsGradientDescent( + learningRate=1.0, + numberOfIterations=100, + convergenceMinimumValue=1e-6, + convergenceWindowSize=10) + registration_method.SetOptimizerScalesFromPhysicalShift() + registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1]) + registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0]) + registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn() + registration_method.SetInitialTransform(transform) + + # Execute registration + final_transform = registration_method.Execute(fixed_img, moving_img) + + # Apply transform and save registered image + moving_img_resampled = sitk.Resample( + moving_img, + fixed_img, + final_transform, + sitk.sitkLinear, + 0.0, + moving_img.GetPixelID()) + + 
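+    # The setup above performs a rigid (Euler3DTransform) registration driven by
+    # Mattes mutual information sampled at 1% of voxels, optimised by gradient
+    # descent over a 3-level pyramid (shrink factors 4/2/1, smoothing sigmas 2/1/0).
+    # The final transform is applied with linear interpolation and the resampled
+    # moving image is written out: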
sitk.WriteImage(moving_img_resampled, output_nifti_path) + print(f"Registration complete. Saved to: {output_nifti_path}") + + +def register_image(input_nifti_path: str, output_nifti_path: str): + """Wrapper to maintain compatibility - now uses SimpleITK registration.""" + if not os.path.exists(DEFAULT_TEMPLATE_PATH): + raise FileNotFoundError(f"Template file missing: {DEFAULT_TEMPLATE_PATH}") + register_image_sitk(input_nifti_path, output_nifti_path, DEFAULT_TEMPLATE_PATH) + + +def run_skull_stripping(input_nifti_path: str, output_dir: str): + """ + Brain extraction using HD-BET direct integration matching the script approach. + + Args: + input_nifti_path: Path to input NIfTI file + output_dir: Directory to save skull-stripped output + + Returns: + tuple: (output_file_path, output_mask_path) + """ + print(f"Running HD-BET skull stripping on {input_nifti_path}") + + if hd_bet is None: + raise RuntimeError("HD-BET not available. Please include HD_BET and hdbet_model in src/IDH.") + + if not os.path.exists(HD_BET_MODEL_DIR): + raise FileNotFoundError(f"HD-BET models not found at {HD_BET_MODEL_DIR}") + + os.makedirs(output_dir, exist_ok=True) + + # Get base filename and prepare HD-BET compatible naming + base_name = os.path.basename(input_nifti_path).replace('.nii.gz', '').replace('.nii', '') + + # HD-BET expects files with _0000 suffix - create temporary file if needed + temp_input_dir = os.path.join(output_dir, "temp_input") + os.makedirs(temp_input_dir, exist_ok=True) + + # Copy input file with _0000 suffix for HD-BET + temp_input_path = os.path.join(temp_input_dir, f"{base_name}_0000.nii.gz") + shutil.copy2(input_nifti_path, temp_input_path) + + # Set device + device = "0" if torch.cuda.is_available() else "cpu" + + try: + # Also try setting the specific model file path + model_file = os.path.join(HD_BET_MODEL_DIR, '0.model') + + if os.path.exists(model_file): + print(f"Local model file exists at: {model_file}") + else: + print(f"Warning: Model file not found at: {model_file}") + # List directory contents for debugging + if os.path.exists(HD_BET_MODEL_DIR): + print(f"Contents of {HD_BET_MODEL_DIR}: {os.listdir(HD_BET_MODEL_DIR)}") + else: + print(f"Directory {HD_BET_MODEL_DIR} does not exist") + + # Run HD-BET directly on the temporary directory + print(f"Running hd_bet with input_dir: {temp_input_dir}, output_dir: {output_dir}") + hd_bet(temp_input_dir, output_dir, device=device, mode='fast', tta=0) + + # HD-BET outputs files with original naming convention + output_file_path = os.path.join(output_dir, f"{base_name}_0000.nii.gz") + output_mask_path = os.path.join(output_dir, f"{base_name}_0000_mask.nii.gz") + + # Rename to expected format for compatibility + final_output_path = os.path.join(output_dir, f"{base_name}_bet.nii.gz") + final_mask_path = os.path.join(output_dir, f"{base_name}_bet_mask.nii.gz") + + if os.path.exists(output_file_path): + shutil.move(output_file_path, final_output_path) + if os.path.exists(output_mask_path): + shutil.move(output_mask_path, final_mask_path) + + # Clean up temporary directory + shutil.rmtree(temp_input_dir, ignore_errors=True) + + if not os.path.exists(final_output_path): + raise RuntimeError(f"HD-BET did not produce output file: {final_output_path}") + + print(f"Skull stripping complete. 
Output saved to: {final_output_path}") + return final_output_path, final_mask_path + + except Exception as e: + # Clean up on error + shutil.rmtree(temp_input_dir, ignore_errors=True) + raise RuntimeError(f"HD-BET skull stripping failed: {str(e)}") + + +# ---------------- Visualization Functions ---------------- + +def create_segmentation_plots(input_data_3d, seg_mask_3d, slice_index): + """Create segmentation visualization plots: Input, Mask, and Overlay.""" + print(f"Generating segmentation plots for slice index: {slice_index}") + + if any(data is None for data in [input_data_3d, seg_mask_3d]): + return None, None, None + + # Check bounds - using axis 2 for axial slices + if not (0 <= slice_index < input_data_3d.shape[2]): + print(f"Error: Slice index {slice_index} out of bounds (0-{input_data_3d.shape[2]-1}).") + return None, None, None + + def save_plot_to_numpy(fig): + with io.BytesIO() as buf: + fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75) + plt.close(fig) + buf.seek(0) + img_arr = plt.imread(buf, format='png') + return (img_arr * 255).astype(np.uint8) + + try: + # Extract axial slices - using axis 2 (last dimension) + input_slice = input_data_3d[:, :, slice_index] + mask_slice = seg_mask_3d[:, :, slice_index] + + # Normalize input slice + def normalize_slice(slice_data, volume_data): + p1, p99 = np.percentile(volume_data, (1, 99)) + denom = max(p99 - p1, 1e-6) + return np.clip((slice_data - p1) / denom, 0, 1) + + input_slice_norm = normalize_slice(input_slice, input_data_3d) + + # Create plots + plots = [] + + # Input Image + fig1, ax1 = plt.subplots(figsize=(6, 6)) + ax1.imshow(input_slice_norm, cmap='gray', interpolation='none', origin='lower') + ax1.axis('off') + ax1.set_title('Input Image', fontsize=14, color='white', pad=10) + plots.append(save_plot_to_numpy(fig1)) + + # Segmentation Mask + fig2, ax2 = plt.subplots(figsize=(6, 6)) + ax2.imshow(mask_slice, cmap='hot', interpolation='none', origin='lower', vmin=0, vmax=1) + ax2.axis('off') + ax2.set_title('Segmentation Mask', fontsize=14, color='white', pad=10) + plots.append(save_plot_to_numpy(fig2)) + + # Overlay + fig3, ax3 = plt.subplots(figsize=(6, 6)) + ax3.imshow(input_slice_norm, cmap='gray', interpolation='none', origin='lower') + # Create mask overlay with transparency - using red colormap + mask_overlay = np.ma.masked_where(mask_slice < 0.5, mask_slice) + ax3.imshow(mask_overlay, cmap='Reds', interpolation='none', origin='lower', alpha=0.7, vmin=0, vmax=1) + ax3.axis('off') + ax3.set_title('Overlay', fontsize=14, color='white', pad=10) + plots.append(save_plot_to_numpy(fig3)) + + print(f"Generated 3 segmentation plots successfully for axial slice {slice_index}.") + return tuple(plots) + + except Exception as e: + print(f"Error generating segmentation plots for slice {slice_index}: {e}") + traceback.print_exc() + return tuple([None] * 3) + + +# ---------------- Saliency Generation (Legacy - keeping for reference) ---------------- + +def extract_attention_map(vit_model, image, layer_idx=-1, img_size=(96, 96, 96), patch_size=16): + """ + Extracts the attention map from a Vision Transformer (ViT) model. + + This function wraps the attention blocks of the ViT to capture the attention + weights during a forward pass. It then processes these weights to generate + a 3D saliency map corresponding to the model's focus on the input image. + """ + attention_maps = {} + original_attns = {} + + # A wrapper class to intercept and store attention weights from a ViT block. 
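+    # The wrapper below re-runs the block's fused qkv projection, recomputes the
+    # per-head attention softmax(q @ k^T * scale), and caches it in .attn_weights;
+    # the original attention modules are restored afterwards in the finally-block.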
+ class AttentionWithWeights(torch.nn.Module): + def __init__(self, original_attn_module): + super().__init__() + self.original_attn_module = original_attn_module + self.attn_weights = None + + def forward(self, x): + # The original implementation of the attention module may not return + # the attention weights. This wrapper recalculates them to ensure they + # are captured. This is based on the standard ViT attention mechanism. + output = self.original_attn_module(x) + if hasattr(self.original_attn_module, 'qkv'): + qkv = self.original_attn_module.qkv(x) + batch_size, seq_len, _ = x.shape + # Assuming qkv has been fused and has shape (batch_size, seq_len, 3 * num_heads * head_dim) + qkv = qkv.reshape(batch_size, seq_len, 3, self.original_attn_module.num_heads, -1) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.original_attn_module.scale + self.attn_weights = attn.softmax(dim=-1) + return output + + # Store original attention modules and replace with wrappers + for i, block in enumerate(vit_model.blocks): + if hasattr(block, 'attn'): + original_attns[i] = block.attn + block.attn = AttentionWithWeights(block.attn) + + try: + # Perform a forward pass to execute the wrapped modules and capture weights + with torch.no_grad(): + _ = vit_model(image) + + # Collect the captured attention weights from each block + for i, block in enumerate(vit_model.blocks): + if hasattr(block.attn, 'attn_weights') and block.attn.attn_weights is not None: + attention_maps[f"layer_{i}"] = block.attn.attn_weights.detach() + + finally: + # Restore original attention modules + for i, original_attn in original_attns.items(): + vit_model.blocks[i].attn = original_attn + + if not attention_maps: + raise RuntimeError("Could not extract any attention maps. Please check the ViT model structure.") + + # Select the attention map from the specified layer + if layer_idx < 0: + layer_idx = len(attention_maps) + layer_idx + layer_name = f"layer_{layer_idx}" + if layer_name not in attention_maps: + raise ValueError(f"Layer {layer_idx} not found. Available layers: {list(attention_maps.keys())}") + + layer_attn = attention_maps[layer_name] + # Average attention across all heads + head_attn = layer_attn[0].mean(dim=0) + # Get attention from the [CLS] token to all other image patches + cls_attn = head_attn[0, 1:] + + # Reshape the 1D attention vector into a 3D volume + patches_per_dim = img_size[0] // patch_size + total_patches = patches_per_dim ** 3 + + # Pad or truncate if the number of patches doesn't align + if cls_attn.shape[0] != total_patches: + if cls_attn.shape[0] > total_patches: + cls_attn = cls_attn[:total_patches] + else: + padded = torch.zeros(total_patches, device=cls_attn.device) + padded[:cls_attn.shape[0]] = cls_attn + cls_attn = padded + + cls_attn_3d = cls_attn.reshape(patches_per_dim, patches_per_dim, patches_per_dim) + cls_attn_3d = cls_attn_3d.unsqueeze(0).unsqueeze(0) # Add batch and channel dims + + # Upsample the attention map to the full image resolution + upsampled_attn = torch.nn.functional.interpolate( + cls_attn_3d, + size=img_size, + mode='trilinear', + align_corners=False + ).squeeze() + + # Normalize the map to [0, 1] for visualization + upsampled_attn = upsampled_attn.cpu().numpy() + upsampled_attn = (upsampled_attn - upsampled_attn.min()) / (upsampled_attn.max() - upsampled_attn.min()) + return upsampled_attn + + +def generate_saliency_dual(model, input_tensor, layer_idx=-1): + """ + Generate saliency maps for dual-input IDH model. 
+ + Args: + model: The complete IDH model + input_tensor: Dual input tensor (batch_size, 2, C, D, H, W) + layer_idx: ViT layer to visualize + + Returns: + tuple: (flair_input_3d, t1c_input_3d, flair_saliency_3d) + """ + print("Generating saliency maps for dual input...") + + try: + # Extract individual images from dual input + # input_tensor shape: [batch_size, 2, C, D, H, W] + flair_tensor = input_tensor[:, 0] # [batch, C, D, H, W] + t1c_tensor = input_tensor[:, 1] # [batch, C, D, H, W] + + # Get the ViT backbone + vit_model = model.backbone.backbone + + # Generate attention map only for FLAIR + flair_attn = extract_attention_map(vit_model, flair_tensor, layer_idx) + + # Convert input tensors to numpy for visualization + flair_input_3d = flair_tensor.squeeze().cpu().detach().numpy() + t1c_input_3d = t1c_tensor.squeeze().cpu().detach().numpy() + + print("Saliency maps generated successfully.") + return flair_input_3d, t1c_input_3d, flair_attn + + except Exception as e: + print(f"Error during saliency generation: {e}") + traceback.print_exc() + return None, None, None + + +# ---------------- Visualization Functions ---------------- + +def create_slice_plots_dual(flair_data_3d, t1c_data_3d, flair_saliency_3d, slice_index): + """Create slice plots for simplified dual input visualization: T1c, FLAIR, FLAIR attention.""" + print(f"Generating plots for slice index: {slice_index}") + + if any(data is None for data in [flair_data_3d, t1c_data_3d, flair_saliency_3d]): + return None, None, None + + # Check bounds - using axis 2 for axial slices + if not (0 <= slice_index < flair_data_3d.shape[2]): + print(f"Error: Slice index {slice_index} out of bounds (0-{flair_data_3d.shape[2]-1}).") + return None, None, None + + def save_plot_to_numpy(fig): + with io.BytesIO() as buf: + fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75) + plt.close(fig) + buf.seek(0) + img_arr = plt.imread(buf, format='png') + return (img_arr * 255).astype(np.uint8) + + try: + # Extract axial slices - using axis 2 (last dimension) + flair_slice = flair_data_3d[:, :, slice_index] + t1c_slice = t1c_data_3d[:, :, slice_index] + flair_saliency_slice = flair_saliency_3d[:, :, slice_index] + + # Normalize input slices + def normalize_slice(slice_data, volume_data): + p1, p99 = np.percentile(volume_data, (1, 99)) + denom = max(p99 - p1, 1e-6) + return np.clip((slice_data - p1) / denom, 0, 1) + + flair_slice_norm = normalize_slice(flair_slice, flair_data_3d) + t1c_slice_norm = normalize_slice(t1c_slice, t1c_data_3d) + + # Process saliency slice + def process_saliency_slice(saliency_slice, saliency_volume): + saliency_slice = np.copy(saliency_slice) + saliency_slice[saliency_slice < 0] = 0 + saliency_slice_blurred = cv2.GaussianBlur(saliency_slice, (15, 15), 0) + s_max = max(np.max(saliency_volume[saliency_volume >= 0]), 1e-6) + saliency_slice_norm = saliency_slice_blurred / s_max + return np.where(saliency_slice_norm > 0.0, saliency_slice_norm, 0) + + flair_sal_processed = process_saliency_slice(flair_saliency_slice, flair_saliency_3d) + + # Create plots + plots = [] + + # T1c Input + fig1, ax1 = plt.subplots(figsize=(6, 6)) + ax1.imshow(t1c_slice_norm, cmap='gray', interpolation='none', origin='lower') + ax1.axis('off') + ax1.set_title('T1c Input', fontsize=14, color='white', pad=10) + plots.append(save_plot_to_numpy(fig1)) + + # FLAIR Input + fig2, ax2 = plt.subplots(figsize=(6, 6)) + ax2.imshow(flair_slice_norm, cmap='gray', interpolation='none', origin='lower') + ax2.axis('off') + ax2.set_title('FLAIR Input', 
fontsize=14, color='white', pad=10) + plots.append(save_plot_to_numpy(fig2)) + + # FLAIR Attention + fig3, ax3 = plt.subplots(figsize=(6, 6)) + ax3.imshow(flair_sal_processed, cmap='magma', interpolation='none', origin='lower', vmin=0) + ax3.axis('off') + ax3.set_title('FLAIR Attention', fontsize=14, color='white', pad=10) + plots.append(save_plot_to_numpy(fig3)) + + print(f"Generated 3 plots successfully for axial slice {slice_index}.") + return tuple(plots) + + except Exception as e: + print(f"Error generating plots for slice {slice_index}: {e}") + traceback.print_exc() + return tuple([None] * 3) + + +# ---------------- Inference ---------------- + +def get_validation_transform(image_size: Tuple[int, int, int]): + return Compose([ + LoadImaged(keys=["image"]), + EnsureChannelFirstd(keys=["image"]), + Resized(keys=["image"], spatial_size=tuple(image_size), mode="trilinear"), + NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True), + EnsureTyped(keys=["image"]), + ToTensord(keys=["image"]), + ]) + + +def preprocess_nifti(image_path: str, image_size: Tuple[int, int, int], device: torch.device) -> torch.Tensor: + transform = get_validation_transform(image_size) + sample = {"image": image_path} + sample = transform(sample) + image = sample["image"].unsqueeze(0).to(device) # Add batch dimension + return image + + +def save_nifti_for_download(data_array: np.ndarray, reference_path: str, output_path: str, affine=None): + """ + Save a numpy array as NIfTI file for download, preserving spatial information from reference. + + Args: + data_array: 3D numpy array to save + reference_path: Path to reference NIfTI file for header info + output_path: Path where to save the output file + affine: Optional affine matrix, if None will use reference + """ + try: + # Load reference image to get header and affine + ref_img = nib.load(reference_path) + + if affine is None: + affine = ref_img.affine + + # Create new NIfTI image + new_img = nib.Nifti1Image(data_array, affine, ref_img.header) + + # Save the file + nib.save(new_img, output_path) + print(f"Saved NIfTI file: {output_path}") + return output_path + + except Exception as e: + print(f"Error saving NIfTI file: {e}") + return None + + +def predict_segmentation(input_file, threshold: float, do_preprocess: bool, cfg: dict, model, device): + try: + if input_file is None: + return {"error": "Please upload a NIfTI file (.nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "mask_paths": None, "num_slices": 0}, None, None + + input_path = input_file.name if hasattr(input_file, 'name') else input_file + + if not (input_path.endswith(".nii") or input_path.endswith(".nii.gz")): + return {"error": "Input must be a NIfTI file (.nii or .nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "mask_paths": None, "num_slices": 0}, None, None + + work_dir = tempfile.mkdtemp() + final_input_path = input_path + + try: + # Optional preprocessing pipeline for FLAIR + if do_preprocess: + # Registration to FLAIR template + reg_path = os.path.join(work_dir, "flair_registered.nii.gz") + register_image_sitk(input_path, reg_path, FLAIR_TEMPLATE_PATH) + # Enhancement + enh_path = os.path.join(work_dir, "flair_enhanced.nii.gz") + run_enhance_on_file(reg_path, enh_path) + # Skull stripping + skullstrip_dir = os.path.join(work_dir, "skullstripped") + bet_path, _ = run_skull_stripping(enh_path, skullstrip_dir) + final_input_path = bet_path + + # Inference + image_size = cfg.get("infer", {}).get("image_size", [96, 96, 96]) + 
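+            # sliding_window_inference (below) tiles the volume into roi_size patches
+            # with 50% overlap and blends the overlapping logits, so inputs larger
+            # than the training patch size are segmented without border artifacts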
training_cfg = cfg.get("training", {}) + input_tensor = preprocess_nifti(final_input_path, image_size, device) + + with torch.no_grad(): + # Use sliding window inference for better results + seg_logits = sliding_window_inference( + inputs=input_tensor, + roi_size=tuple(image_size), + sw_batch_size=training_cfg.get("sw_batch_size", 2), + predictor=model, + overlap=0.5 + ) + # Apply sigmoid and threshold to get binary mask + seg_prob = torch.sigmoid(seg_logits) + seg_mask = (seg_prob > threshold).float() + + # Convert to numpy for visualization + input_3d = input_tensor.squeeze().cpu().detach().numpy() + seg_prob_3d = seg_prob.squeeze().cpu().detach().numpy() + seg_mask_3d = seg_mask.squeeze().cpu().detach().numpy() + + # Calculate statistics + total_voxels = np.prod(seg_mask_3d.shape) + segmented_voxels = int(np.sum(seg_mask_3d)) + segmentation_percentage = (segmented_voxels / total_voxels) * 100 + + prediction_result = { + "segmented_voxels": segmented_voxels, + "total_voxels": total_voxels, + "segmentation_percentage": float(segmentation_percentage), + "threshold": float(threshold), + "preprocessing": bool(do_preprocess), + "max_probability": float(np.max(seg_prob_3d)), + "mean_probability": float(np.mean(seg_prob_3d)) + } + + # Initialize visualization outputs + input_img = seg_mask_img = overlay_img = None + slider_update = gr.Slider(visible=False) + viz_state = {"input_paths": None, "mask_paths": None, "num_slices": 0} + + # Initialize download files + download_preprocessed = None + download_mask = None + + # Generate visualizations + print("--- Generating Visualizations ---") + try: + num_slices = input_3d.shape[2] # Use axis 2 for axial slices + center_slice_index = num_slices // 2 + + # Save numpy arrays for slider callback + unique_id = str(uuid.uuid4()) + temp_paths = [] + for name, data in [("input", input_3d), ("seg_prob", seg_prob_3d), ("seg_mask", seg_mask_3d)]: + path = os.path.join(work_dir, f"{unique_id}_{name}.npy") + np.save(path, data) + temp_paths.append(path) + + # Generate initial plots for center slice + plots = create_segmentation_plots(input_3d, seg_mask_3d, center_slice_index) + if plots and all(p is not None for p in plots): + input_img, seg_mask_img, overlay_img = plots + + # Update state and slider + viz_state = { + "input_paths": [temp_paths[0]], # [input] + "mask_paths": temp_paths[1:], # [seg_prob, seg_mask] + "num_slices": num_slices + } + slider_update = gr.Slider(value=center_slice_index, minimum=0, maximum=num_slices-1, step=1, label="Select Slice", visible=True) + print("--- Visualization Generation Complete ---") + + except Exception as e: + print(f"Error during visualization generation: {e}") + traceback.print_exc() + + # Generate downloadable files + print("--- Generating Download Files ---") + try: + # Create download filenames + base_name = os.path.splitext(os.path.basename(input_path))[0] + if base_name.endswith('.nii'): + base_name = os.path.splitext(base_name)[0] + + # Save preprocessed image (the actual array that was fed to the model) + preprocessed_download_path = os.path.join(work_dir, f"{base_name}_preprocessed.nii.gz") + # Save the preprocessed numpy array that was actually used for inference + saved_preprocessed_path = save_nifti_for_download( + input_3d, # This is the preprocessed array that was visualized + input_path, # Use original input as reference for header/affine + preprocessed_download_path + ) + if saved_preprocessed_path: + download_preprocessed = gr.File(value=saved_preprocessed_path, visible=True, label="Download 
Preprocessed Image") + + # Save segmentation mask + mask_download_path = os.path.join(work_dir, f"{base_name}_segmentation_mask.nii.gz") + saved_mask_path = save_nifti_for_download( + seg_mask_3d, + final_input_path, + mask_download_path + ) + if saved_mask_path: + download_mask = gr.File(value=saved_mask_path, visible=True, label="Download Segmentation Mask") + + print("--- Download Files Generated ---") + + except Exception as e: + print(f"Error generating download files: {e}") + traceback.print_exc() + + return (prediction_result, input_img, seg_mask_img, overlay_img, slider_update, viz_state, download_preprocessed, download_mask) + + except Exception as e: + shutil.rmtree(work_dir, ignore_errors=True) + return {"error": f"Processing failed: {str(e)}"}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "mask_paths": None, "num_slices": 0}, None, None + + except Exception as e: + return {"error": str(e)}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "mask_paths": None, "num_slices": 0}, None, None + + +def update_slice_viewer_segmentation(slice_index, current_state): + """Update slice viewer for segmentation visualization.""" + input_paths = current_state.get("input_paths", []) + mask_paths = current_state.get("mask_paths", []) + + if not input_paths or not mask_paths or len(input_paths) != 1 or len(mask_paths) != 2: + print(f"Warning: Invalid state for slice viewer update: {current_state}") + return None, None, None + + try: + # Load numpy arrays + input_3d = np.load(input_paths[0]) + seg_mask_3d = np.load(mask_paths[1]) # Use the binary mask, not probabilities + + # Validate slice index + slice_index = int(slice_index) + if not (0 <= slice_index < input_3d.shape[2]): # Use axis 2 for axial slices + print(f"Warning: Invalid slice index {slice_index}") + return None, None, None + + # Generate new plots + plots = create_segmentation_plots(input_3d, seg_mask_3d, slice_index) + return plots if plots else tuple([None] * 3) + + except Exception as e: + print(f"Error updating slice viewer for index {slice_index}: {e}") + traceback.print_exc() + return tuple([None] * 3) + + +def build_interface(): + cfg = load_config() + model, device = build_model(cfg) + default_threshold = float(cfg.get("infer", {}).get("threshold", 0.5)) + + with gr.Blocks(title="BrainIAC: Glioma Segmentation", css=""" +#header-row { + min-height: 150px; + align-items: center; +} +.logo-img img { + height: 150px; + object-fit: contain; +} +""") as demo: + # --- Header with Logos --- + with gr.Row(elem_id="header-row"): + with gr.Column(scale=1): + gr.Image(os.path.join(APP_DIR, "static/images/kannlab.png"), + show_label=False, interactive=False, + show_download_button=False, + container=False, + elem_classes=["logo-img"]) + with gr.Column(scale=3): + gr.Markdown( + "
## BrainIAC: Glioma Segmentation\nUpload a FLAIR MRI (.nii or .nii.gz), optionally preprocess, and view the predicted tumor mask. *Research use only.*"
+                )
+
+        # --- Input Controls ---
+        with gr.Row():
+            with gr.Column():
+                input_file = gr.File(label="FLAIR MRI (.nii or .nii.gz)")
+                preprocess_checkbox = gr.Checkbox(value=False, label="Apply preprocessing (registration + enhancement + skull stripping)")
+                threshold_input = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=default_threshold, label="Segmentation threshold")
+                predict_btn = gr.Button("Run Segmentation", variant="primary")
+            with gr.Column():
+                output_json = gr.JSON(label="Segmentation Statistics")
+                download_preprocessed_btn = gr.File(label="Download Preprocessed Image", visible=False)
+                download_mask_btn = gr.File(label="Download Segmentation Mask", visible=False)
+
+        # --- Slice Viewer ---
+        viz_state = gr.State({"input_paths": None, "mask_paths": None, "num_slices": 0})
+        slice_slider = gr.Slider(minimum=0, maximum=1, step=1, label="Select Slice", visible=False)
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("Input Image")
+                input_img = gr.Image(label="Input Image", type="numpy", show_label=False)
+            with gr.Column():
+                gr.Markdown("Segmentation Mask")
+                seg_mask_img = gr.Image(label="Segmentation Mask", type="numpy", show_label=False)
+            with gr.Column():
+                gr.Markdown("Overlay")
+                overlay_img = gr.Image(label="Overlay", type="numpy", show_label=False)
") + overlay_img = gr.Image(label="Overlay", type="numpy", show_label=False) + + # Wire components + predict_btn.click( + fn=lambda f, prep, thr: predict_segmentation(f, thr, prep, cfg, model, device), + inputs=[input_file, preprocess_checkbox, threshold_input], + outputs=[output_json, input_img, seg_mask_img, overlay_img, slice_slider, viz_state, download_preprocessed_btn, download_mask_btn], + ) + + slice_slider.change( + fn=update_slice_viewer_segmentation, + inputs=[slice_slider, viz_state], + outputs=[input_img, seg_mask_img, overlay_img] + ) + + return demo + + +if __name__ == "__main__": + iface = build_interface() + iface.launch(server_name="0.0.0.0", server_port=7860) \ No newline at end of file diff --git a/src/IDH/checkpoints/segmentation.ckpt b/src/IDH/checkpoints/segmentation.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..fd5688fdc82b74110cb35943f0c6d9c1a1bbdd7a --- /dev/null +++ b/src/IDH/checkpoints/segmentation.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:54b02a9741be9b01a36c4e0bc260910a20b45dba19e2d250865fced83044506f +size 724673723 diff --git a/src/IDH/config.yml b/src/IDH/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..9824ae82a31b5b4ca39bb4939762f32a03034405 --- /dev/null +++ b/src/IDH/config.yml @@ -0,0 +1,13 @@ +gpu: + device: cpu +infer: + checkpoints: ./checkpoints/segmentation.ckpt + simclr_checkpoint: ./checkpoints/simclr_vitb.ckpt + threshold: 0.5 + image_size: [96, 96, 96] +model: + img_size: [96, 96, 96] + in_channels: 1 + out_channels: 1 +training: + sw_batch_size: 2 \ No newline at end of file diff --git a/src/IDH/golden_image/mni_templates/Parameters_Rigid.txt b/src/IDH/golden_image/mni_templates/Parameters_Rigid.txt new file mode 100644 index 0000000000000000000000000000000000000000..19d729f7e970a3683fe06b2b59a24f27916a516a --- /dev/null +++ b/src/IDH/golden_image/mni_templates/Parameters_Rigid.txt @@ -0,0 +1,141 @@ +// Example parameter file for rotation registration +// C-style comments: // + +// The internal pixel type, used for internal computations +// Leave to float in general. +// NB: this is not the type of the input images! The pixel +// type of the input images is automatically read from the +// images themselves. +// This setting can be changed to "short" to save some memory +// in case of very large 3D images. +(FixedInternalImagePixelType "float") +(MovingInternalImagePixelType "float") + +// **************** Main Components ************************** + +// The following components should usually be left as they are: +(Registration "MultiResolutionRegistration") +(Interpolator "BSplineInterpolator") +(ResampleInterpolator "FinalBSplineInterpolator") +(Resampler "DefaultResampler") + +// These may be changed to Fixed/MovingSmoothingImagePyramid. +// See the manual. +(FixedImagePyramid "FixedRecursiveImagePyramid") +(MovingImagePyramid "MovingRecursiveImagePyramid") + +// The following components are most important: +// The optimizer AdaptiveStochasticGradientDescent (ASGD) works +// quite ok in general. The Transform and Metric are important +// and need to be chosen careful for each application. See manual. +(Optimizer "AdaptiveStochasticGradientDescent") +(Transform "EulerTransform") +(Metric "AdvancedMattesMutualInformation") + +// ***************** Transformation ************************** + +// Scales the rotations compared to the translations, to make +// sure they are in the same range. 
diff --git a/src/IDH/golden_image/mni_templates/Parameters_Rigid.txt b/src/IDH/golden_image/mni_templates/Parameters_Rigid.txt
new file mode 100644
index 0000000000000000000000000000000000000000..19d729f7e970a3683fe06b2b59a24f27916a516a
--- /dev/null
+++ b/src/IDH/golden_image/mni_templates/Parameters_Rigid.txt
@@ -0,0 +1,141 @@
+// Example parameter file for rotation registration
+// C-style comments: //
+
+// The internal pixel type, used for internal computations
+// Leave to float in general.
+// NB: this is not the type of the input images! The pixel
+// type of the input images is automatically read from the
+// images themselves.
+// This setting can be changed to "short" to save some memory
+// in case of very large 3D images.
+(FixedInternalImagePixelType "float")
+(MovingInternalImagePixelType "float")
+
+// **************** Main Components **************************
+
+// The following components should usually be left as they are:
+(Registration "MultiResolutionRegistration")
+(Interpolator "BSplineInterpolator")
+(ResampleInterpolator "FinalBSplineInterpolator")
+(Resampler "DefaultResampler")
+
+// These may be changed to Fixed/MovingSmoothingImagePyramid.
+// See the manual.
+(FixedImagePyramid "FixedRecursiveImagePyramid")
+(MovingImagePyramid "MovingRecursiveImagePyramid")
+
+// The following components are most important:
+// The optimizer AdaptiveStochasticGradientDescent (ASGD) works
+// quite well in general. The Transform and Metric are important
+// and need to be chosen carefully for each application. See manual.
+(Optimizer "AdaptiveStochasticGradientDescent")
+(Transform "EulerTransform")
+(Metric "AdvancedMattesMutualInformation")
+
+// ***************** Transformation **************************
+
+// Scales the rotations compared to the translations, to make
+// sure they are in the same range. In general, it's best to
+// use automatic scales estimation:
+(AutomaticScalesEstimation "true")
+
+// Automatically guess an initial translation by aligning the
+// geometric centers of the fixed and moving images.
+(AutomaticTransformInitialization "true")
+
+// Whether transforms are combined by composition or by addition.
+// In general, Compose is the best option in most cases.
+// It does not influence the results very much.
+(HowToCombineTransforms "Compose")
+
+// ******************* Similarity measure *********************
+
+// Number of grey level bins in each resolution level,
+// for the mutual information. 16 or 32 usually works fine.
+// You could also employ a hierarchical strategy:
+//(NumberOfHistogramBins 16 32 64)
+(NumberOfHistogramBins 32)
+
+// If you use a mask, this option is important.
+// If the mask serves as region of interest, set it to false.
+// If the mask indicates which pixels are valid, then set it to true.
+// If you do not use a mask, the option doesn't matter.
+(ErodeMask "false")
+
+// ******************** Multiresolution **********************
+
+// The number of resolutions. 1 is only enough if the expected
+// deformations are small. 3 or 4 mostly works fine. For large
+// images and large deformations, 5 or 6 may even be useful.
+(NumberOfResolutions 4)
+
+// The downsampling/blurring factors for the image pyramids.
+// By default, the images are downsampled by a factor of 2
+// compared to the next resolution.
+// So, in 2D, with 4 resolutions, the following schedule is used:
+//(ImagePyramidSchedule 8 8 4 4 2 2 1 1 )
+// And in 3D:
+//(ImagePyramidSchedule 8 8 8 4 4 4 2 2 2 1 1 1 )
+// You can specify any schedule, for example:
+//(ImagePyramidSchedule 4 4 4 3 2 1 1 1 )
+// Make sure that the number of elements equals the number
+// of resolutions times the image dimension.
+
+// ******************* Optimizer ****************************
+
+// Maximum number of iterations in each resolution level:
+// 200-500 usually works fine for rigid registration.
+// For more robustness, you may increase this to 1000-2000.
+(MaximumNumberOfIterations 250)
+
+// The step size of the optimizer, in mm. By default the voxel size is used,
+// which usually works well. In case of unusually high-resolution images
+// (e.g. histology) it is necessary to increase this value a bit, to the size
+// of the "smallest visible structure" in the image:
+//(MaximumStepLength 1.0)
+
+// **************** Image sampling **********************
+
+// Number of spatial samples used to compute the mutual
+// information (and its derivative) in each iteration.
+// With an AdaptiveStochasticGradientDescent optimizer,
+// in combination with the two options below, around 2000
+// samples may already suffice.
+(NumberOfSpatialSamples 2048)
+
+// Refresh these spatial samples in every iteration, and select
+// them randomly. See the manual for information on other sampling
+// strategies.
+(NewSamplesEveryIteration "true")
+(ImageSampler "Random")
+
+// ************* Interpolation and Resampling ****************
+
+// Order of B-Spline interpolation used during registration/optimisation.
+// It may improve accuracy if you set this to 3. Never use 0.
+// An order of 1 gives linear interpolation. This is in most
+// applications a good choice.
+(BSplineInterpolationOrder 1)
+
+// Order of B-Spline interpolation used for applying the final
+// deformation.
+// 3 gives good accuracy; recommended in most cases.
+// 1 gives worse accuracy (linear interpolation).
+// 0 gives worst accuracy, but is appropriate for binary images
+// (masks, segmentations); equivalent to nearest-neighbor interpolation.
+(FinalBSplineInterpolationOrder 3)
+
+// Default pixel value for pixels that come from outside the picture:
+(DefaultPixelValue 0)
+
+// Choose whether to generate the deformed moving image.
+// You can save some time by setting this to false, if you are
+// only interested in the final (nonrigidly) deformed moving image,
+// for example.
+(WriteResultImage "true")
+
+// The pixel type and format of the resulting deformed moving image
+(ResultImagePixelType "short")
+(ResultImageFormat "mhd")
+
+
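The preprocessing step in `app_gradio.py` registers the uploaded FLAIR to a template using this parameter file via the `register_image_sitk` helper; below is a minimal sketch of the equivalent `itk-elastix` call (from `requirements.txt`). The file paths, including the template filename, are illustrative, and the helper's actual implementation may differ:

```python
# Sketch only: rigid registration driven by the Parameters_Rigid.txt file above.
import itk

fixed = itk.imread("src/IDH/golden_image/mni_templates/flair_template.nii.gz", itk.F)  # template path assumed
moving = itk.imread("flair.nii.gz", itk.F)

params = itk.ParameterObject.New()
params.AddParameterFile("src/IDH/golden_image/mni_templates/Parameters_Rigid.txt")

# Returns the resampled moving image and the fitted transform parameters
registered, transform_params = itk.elastix_registration_method(
    fixed, moving, parameter_object=params
)
itk.imwrite(registered, "flair_registered.nii.gz")
```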
diff --git a/src/IDH/hdbet_model/0.model b/src/IDH/hdbet_model/0.model
new file mode 100644
index 0000000000000000000000000000000000000000..23d2336bed49651cb402e47ec75d81588aa5ce8d
--- /dev/null
+++ b/src/IDH/hdbet_model/0.model
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f75233753c4750672815e2b7a86db754995ae44b8f1cd77bccfc37becd2d83c
+size 65443735
diff --git a/src/IDH/hdbet_model/hdbet_model/0.model b/src/IDH/hdbet_model/hdbet_model/0.model
new file mode 100644
index 0000000000000000000000000000000000000000..23d2336bed49651cb402e47ec75d81588aa5ce8d
--- /dev/null
+++ b/src/IDH/hdbet_model/hdbet_model/0.model
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f75233753c4750672815e2b7a86db754995ae44b8f1cd77bccfc37becd2d83c
+size 65443735
diff --git a/src/IDH/model.py b/src/IDH/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff4a21896be9a68e5ec33a3985600b8f94e596b9
--- /dev/null
+++ b/src/IDH/model.py
@@ -0,0 +1,53 @@
+import os
+
+import torch
+import torch.nn as nn
+from monai.networks.nets import ViT, UNETR
+
+
+class ViTUNETRSegmentationModel(nn.Module):
+    def __init__(self, simclr_ckpt_path: str, img_size=(96, 96, 96), in_channels=1, out_channels=1):
+        super().__init__()
+        # ViT backbone (matches the UNETR encoder configuration below)
+        self.vit = ViT(
+            in_channels=in_channels,
+            img_size=img_size,
+            patch_size=(16, 16, 16),
+            hidden_size=768,
+            mlp_dim=3072,
+            num_layers=12,
+            num_heads=12,
+            save_attn=False,
+        )
+
+        # Load SimCLR-pretrained backbone weights if a checkpoint is available
+        if simclr_ckpt_path and os.path.exists(simclr_ckpt_path):
+            ckpt = torch.load(simclr_ckpt_path, map_location='cpu', weights_only=False)
+            state_dict = ckpt.get('state_dict', ckpt)
+            # Strip the "backbone." prefix used in the SimCLR checkpoint keys
+            backbone_state_dict = {k[len('backbone.'):]: v for k, v in state_dict.items() if k.startswith('backbone.')}
+            missing, unexpected = self.vit.load_state_dict(backbone_state_dict, strict=False)
+            print(f"Loaded SimCLR backbone weights. Missing: {len(missing)}, Unexpected: {len(unexpected)}")
+        else:
+            print("Warning: SimCLR checkpoint not found or not provided. Using randomly initialized backbone.")
+
+        # UNETR decoder
+        self.unetr = UNETR(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            img_size=img_size,
+            feature_size=16,
+            hidden_size=768,
+            mlp_dim=3072,
+            num_heads=12,
+            norm_name='instance',
+            res_block=True,
+            dropout_rate=0.0
+        )
+
+        # Transfer ViT weights into the UNETR encoder
+        self.unetr.vit.load_state_dict(self.vit.state_dict(), strict=True)
+        print("ViT backbone weights transferred to UNETR encoder")
+
+    def forward(self, x):
+        return self.unetr(x)
\ No newline at end of file
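To exercise the model outside the app, here is a minimal sketch, assuming it is run from `src/IDH` so that `model` imports, with shapes taken from `config.yml`; the checkpoint path is illustrative, and a missing SimCLR checkpoint simply triggers the random-init fallback above:

```python
# Sketch only: dummy forward pass through ViTUNETRSegmentationModel on CPU.
import torch
from model import ViTUNETRSegmentationModel

model = ViTUNETRSegmentationModel(
    simclr_ckpt_path="./checkpoints/simclr_vitb.ckpt",  # optional; random init if absent
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=1,
).eval()

with torch.no_grad():
    logits = model(torch.zeros(1, 1, 96, 96, 96))  # (B, C, D, H, W)

print(logits.shape)  # torch.Size([1, 1, 96, 96, 96]); UNETR returns full-resolution logits
```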