Divyanshu Tak commited on
Commit
0ee52bb
·
1 Parent(s): 96acd99

Add BrainIAC Glioma Segmentation app with proper Docker setup

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +37 -0
  2. README.md +22 -5
  3. requirements.txt +18 -0
  4. src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-310.pyc +0 -0
  5. src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-38.pyc +0 -0
  6. src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-39.pyc +0 -0
  7. src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-310.pyc +0 -0
  8. src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-38.pyc +0 -0
  9. src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-39.pyc +0 -0
  10. src/IDH/HD_BET/HD_BET/__pycache__/hd_bet.cpython-310.pyc +0 -0
  11. src/IDH/HD_BET/HD_BET/__pycache__/hd_bet.cpython-38.pyc +0 -0
  12. src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-310.pyc +0 -0
  13. src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-38.pyc +0 -0
  14. src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-39.pyc +0 -0
  15. src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-310.pyc +0 -0
  16. src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-38.pyc +0 -0
  17. src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-39.pyc +0 -0
  18. src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-310.pyc +0 -0
  19. src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-38.pyc +0 -0
  20. src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-39.pyc +0 -0
  21. src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-310.pyc +0 -0
  22. src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-38.pyc +0 -0
  23. src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-39.pyc +0 -0
  24. src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-310.pyc +0 -0
  25. src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-38.pyc +0 -0
  26. src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-39.pyc +0 -0
  27. src/IDH/HD_BET/HD_BET/config.py +121 -0
  28. src/IDH/HD_BET/HD_BET/data_loading.py +121 -0
  29. src/IDH/HD_BET/HD_BET/hd_bet.py +119 -0
  30. src/IDH/HD_BET/HD_BET/network_architecture.py +213 -0
  31. src/IDH/HD_BET/HD_BET/paths.py +6 -0
  32. src/IDH/HD_BET/HD_BET/predict_case.py +126 -0
  33. src/IDH/HD_BET/HD_BET/run.py +117 -0
  34. src/IDH/HD_BET/HD_BET/utils.py +115 -0
  35. src/IDH/HD_BET/__pycache__/config.cpython-310.pyc +0 -0
  36. src/IDH/HD_BET/__pycache__/config.cpython-38.pyc +0 -0
  37. src/IDH/HD_BET/__pycache__/config.cpython-39.pyc +0 -0
  38. src/IDH/HD_BET/__pycache__/data_loading.cpython-310.pyc +0 -0
  39. src/IDH/HD_BET/__pycache__/data_loading.cpython-38.pyc +0 -0
  40. src/IDH/HD_BET/__pycache__/data_loading.cpython-39.pyc +0 -0
  41. src/IDH/HD_BET/__pycache__/hd_bet.cpython-310.pyc +0 -0
  42. src/IDH/HD_BET/__pycache__/hd_bet.cpython-38.pyc +0 -0
  43. src/IDH/HD_BET/__pycache__/network_architecture.cpython-310.pyc +0 -0
  44. src/IDH/HD_BET/__pycache__/network_architecture.cpython-38.pyc +0 -0
  45. src/IDH/HD_BET/__pycache__/network_architecture.cpython-39.pyc +0 -0
  46. src/IDH/HD_BET/__pycache__/paths.cpython-310.pyc +0 -0
  47. src/IDH/HD_BET/__pycache__/paths.cpython-38.pyc +0 -0
  48. src/IDH/HD_BET/__pycache__/paths.cpython-39.pyc +0 -0
  49. src/IDH/HD_BET/__pycache__/predict_case.cpython-310.pyc +0 -0
  50. src/IDH/HD_BET/__pycache__/predict_case.cpython-38.pyc +0 -0
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime as a parent image
2
+ FROM python:3.10-slim
3
+
4
+ # Set the working directory in the container
5
+ WORKDIR /app
6
+
7
+ # Install necessary system dependencies (kept minimal)
8
+ RUN apt-get update && \
9
+ apt-get install -y --no-install-recommends \
10
+ git \
11
+ libgl1 \
12
+ libglib2.0-0 \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Copy the requirements file first to leverage Docker cache
16
+ COPY requirements.txt ./
17
+
18
+ # Install Python packages
19
+ RUN pip install --no-cache-dir -r requirements.txt
20
+
21
+ # Copy the entire project
22
+ COPY . /app/
23
+
24
+ # Create a non-root user (HF Spaces requirement)
25
+ RUN useradd -m -u 1000 user
26
+ USER user
27
+
28
+ # Make sure the user owns the app directory
29
+ COPY --chown=user:user . /app/
30
+
31
+ # Expose Gradio default port
32
+ EXPOSE 7860
33
+
34
+ ENV PYTHONUNBUFFERED=1
35
+
36
+ # Run the app from the src/IDH directory
37
+ CMD ["python", "src/IDH/app_gradio.py"]
README.md CHANGED
@@ -1,11 +1,28 @@
1
  ---
2
  title: BrainIAC Glioma Segmentation
3
- emoji: 🏃
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: docker
7
  pinned: false
8
- license: cc-by-4.0
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: BrainIAC Glioma Segmentation
3
+ emoji: 🧠
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: docker
7
  pinned: false
8
+ license: mit
9
  ---
10
 
11
+ # BrainIAC: Glioma Segmentation
12
+
13
+ A Vision Transformer UNETR model for glioma segmentation from FLAIR MRI scans.
14
+
15
+ ## Features
16
+ - Upload FLAIR MRI NIfTI files
17
+ - Optional preprocessing (debiasing + registration + skull stripping)
18
+ - Interactive slice-by-slice visualization
19
+ - Download preprocessed images and segmentation masks
20
+ - Real-time segmentation statistics
21
+
22
+ ## Usage
23
+ 1. Upload a FLAIR MRI scan (.nii or .nii.gz)
24
+ 2. Optionally enable preprocessing
25
+ 3. Adjust segmentation threshold
26
+ 4. View results and download files
27
+
28
+ *Research use only*
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ monai==1.3.2
2
+ nibabel==5.2.1
3
+ numpy==1.23.5
4
+ pydicom
5
+ PyYAML
6
+ pytorch-lightning==2.3.3
7
+ scipy==1.10.1
8
+ SimpleITK==2.4.0
9
+ torch==2.6.0
10
+ tqdm
11
+ gradio
12
+ pandas
13
+ scikit-image==0.21.0
14
+ opencv-python
15
+ itk-elastix
16
+ dicom2nifti
17
+ einops
18
+ matplotlib
src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-310.pyc ADDED
Binary file (4.15 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-38.pyc ADDED
Binary file (4.13 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/config.cpython-39.pyc ADDED
Binary file (4.19 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-310.pyc ADDED
Binary file (4.47 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-38.pyc ADDED
Binary file (4.48 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/data_loading.cpython-39.pyc ADDED
Binary file (4.46 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/hd_bet.cpython-310.pyc ADDED
Binary file (4.21 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/hd_bet.cpython-38.pyc ADDED
Binary file (4.27 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-310.pyc ADDED
Binary file (6.78 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-38.pyc ADDED
Binary file (6.89 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/network_architecture.cpython-39.pyc ADDED
Binary file (6.84 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-310.pyc ADDED
Binary file (324 Bytes). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-38.pyc ADDED
Binary file (335 Bytes). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/paths.cpython-39.pyc ADDED
Binary file (322 Bytes). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-310.pyc ADDED
Binary file (3.68 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-38.pyc ADDED
Binary file (3.67 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/predict_case.cpython-39.pyc ADDED
Binary file (3.68 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-310.pyc ADDED
Binary file (3.83 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-38.pyc ADDED
Binary file (3.88 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/run.cpython-39.pyc ADDED
Binary file (3.85 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.68 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.85 kB). View file
 
src/IDH/HD_BET/HD_BET/__pycache__/utils.cpython-39.pyc ADDED
Binary file (4.81 kB). View file
 
src/IDH/HD_BET/HD_BET/config.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from HD_BET.utils import SetNetworkToVal, softmax_helper
4
+ from abc import abstractmethod
5
+ from HD_BET.network_architecture import Network
6
+
7
+
8
+ class BaseConfig(object):
9
+ def __init__(self):
10
+ pass
11
+
12
+ @abstractmethod
13
+ def get_split(self, fold, random_state=12345):
14
+ pass
15
+
16
+ @abstractmethod
17
+ def get_network(self, mode="train"):
18
+ pass
19
+
20
+ @abstractmethod
21
+ def get_basic_generators(self, fold):
22
+ pass
23
+
24
+ @abstractmethod
25
+ def get_data_generators(self, fold):
26
+ pass
27
+
28
+ def preprocess(self, data):
29
+ return data
30
+
31
+ def __repr__(self):
32
+ res = ""
33
+ for v in vars(self):
34
+ if not v.startswith("__") and not v.startswith("_") and v != 'dataset':
35
+ res += (v + ": " + str(self.__getattribute__(v)) + "\n")
36
+ return res
37
+
38
+
39
+ class HD_BET_Config(BaseConfig):
40
+ def __init__(self):
41
+ super(HD_BET_Config, self).__init__()
42
+
43
+ self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name
44
+
45
+ # network parameters
46
+ self.net_base_num_layers = 21
47
+ self.BATCH_SIZE = 2
48
+ self.net_do_DS = True
49
+ self.net_dropout_p = 0.0
50
+ self.net_use_inst_norm = True
51
+ self.net_conv_use_bias = True
52
+ self.net_norm_use_affine = True
53
+ self.net_leaky_relu_slope = 1e-1
54
+
55
+ # hyperparameters
56
+ self.INPUT_PATCH_SIZE = (128, 128, 128)
57
+ self.num_classes = 2
58
+ self.selected_data_channels = range(1)
59
+
60
+ # data augmentation
61
+ self.da_mirror_axes = (2, 3, 4)
62
+
63
+ # validation
64
+ self.val_use_DO = False
65
+ self.val_use_train_mode = False # for dropout sampling
66
+ self.val_num_repeats = 1 # only useful if dropout sampling
67
+ self.val_batch_size = 1 # only useful if dropout sampling
68
+ self.val_save_npz = True
69
+ self.val_do_mirroring = True # test time data augmentation via mirroring
70
+ self.val_write_images = True
71
+ self.net_input_must_be_divisible_by = 16 # we could make a network class that has this as a property
72
+ self.val_min_size = self.INPUT_PATCH_SIZE
73
+ self.val_fn = None
74
+
75
+ # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_runnings_
76
+ # stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but unused and defaults
77
+ # to false in 0.4)
78
+ self.val_use_moving_averages = False
79
+
80
+ def get_network(self, train=True, pretrained_weights=None):
81
+ net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
82
+ self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
83
+ self.net_norm_use_affine, True, self.net_do_DS)
84
+
85
+ if pretrained_weights is not None:
86
+ net.load_state_dict(
87
+ torch.load(pretrained_weights, map_location=lambda storage, loc: storage))
88
+
89
+ if train:
90
+ net.train(True)
91
+ else:
92
+ net.train(False)
93
+ net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
94
+ net.do_ds = False
95
+
96
+ optimizer = None
97
+ self.lr_scheduler = None
98
+ return net, optimizer
99
+
100
+ def get_data_generators(self, fold):
101
+ pass
102
+
103
+ def get_split(self, fold, random_state=12345):
104
+ pass
105
+
106
+ def get_basic_generators(self, fold):
107
+ pass
108
+
109
+ def on_epoch_end(self, epoch):
110
+ pass
111
+
112
+ def preprocess(self, data):
113
+ data = np.copy(data)
114
+ for c in range(data.shape[0]):
115
+ data[c] -= data[c].mean()
116
+ data[c] /= data[c].std()
117
+ return data
118
+
119
+
120
+ config = HD_BET_Config
121
+
src/IDH/HD_BET/HD_BET/data_loading.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import SimpleITK as sitk
2
+ import numpy as np
3
+ from skimage.transform import resize
4
+
5
+
6
+ def resize_image(image, old_spacing, new_spacing, order=3):
7
+ new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
8
+ int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
9
+ int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
10
+ return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False)
11
+
12
+
13
+ def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)):
14
+ spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
15
+ image = sitk.GetArrayFromImage(itk_image).astype(float)
16
+
17
+ assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed"
18
+
19
+ if not is_seg:
20
+ if np.any([[i != j] for i, j in zip(spacing, spacing_target)]):
21
+ image = resize_image(image, spacing, spacing_target).astype(np.float32)
22
+
23
+ image -= image.mean()
24
+ image /= image.std()
25
+ else:
26
+ new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))),
27
+ int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))),
28
+ int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))))
29
+ image = resize_segmentation(image, new_shape, 1)
30
+ return image
31
+
32
+
33
+ def load_and_preprocess(mri_file):
34
+ images = {}
35
+ # t1
36
+ images["T1"] = sitk.ReadImage(mri_file)
37
+
38
+ properties_dict = {
39
+ "spacing": images["T1"].GetSpacing(),
40
+ "direction": images["T1"].GetDirection(),
41
+ "size": images["T1"].GetSize(),
42
+ "origin": images["T1"].GetOrigin()
43
+ }
44
+
45
+ for k in images.keys():
46
+ images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5))
47
+
48
+ properties_dict['size_before_cropping'] = images["T1"].shape
49
+
50
+ imgs = []
51
+ for seq in ['T1']:
52
+ imgs.append(images[seq][None])
53
+ all_data = np.vstack(imgs)
54
+ print("image shape after preprocessing: ", str(all_data[0].shape))
55
+ return all_data, properties_dict
56
+
57
+
58
+ def save_segmentation_nifti(segmentation, dct, out_fname, order=1):
59
+ '''
60
+ segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out
61
+ of the original image
62
+
63
+ dct:
64
+ size_before_cropping
65
+ brain_bbox
66
+ size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping
67
+ spacing
68
+ origin
69
+ direction
70
+
71
+ :param segmentation:
72
+ :param dct:
73
+ :param out_fname:
74
+ :return:
75
+ '''
76
+ old_size = dct.get('size_before_cropping')
77
+ bbox = dct.get('brain_bbox')
78
+ if bbox is not None:
79
+ seg_old_size = np.zeros(old_size)
80
+ for c in range(3):
81
+ bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c]))
82
+ seg_old_size[bbox[0][0]:bbox[0][1],
83
+ bbox[1][0]:bbox[1][1],
84
+ bbox[2][0]:bbox[2][1]] = segmentation
85
+ else:
86
+ seg_old_size = segmentation
87
+ if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]):
88
+ seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order)
89
+ else:
90
+ seg_old_spacing = seg_old_size
91
+ seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32))
92
+ seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]])
93
+ seg_resized_itk.SetOrigin(dct['origin'])
94
+ seg_resized_itk.SetDirection(dct['direction'])
95
+ sitk.WriteImage(seg_resized_itk, out_fname)
96
+
97
+
98
+ def resize_segmentation(segmentation, new_shape, order=3, cval=0):
99
+ '''
100
+ Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency
101
+
102
+ Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one
103
+ hot encoding which is resized and transformed back to a segmentation map.
104
+ This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
105
+ :param segmentation:
106
+ :param new_shape:
107
+ :param order:
108
+ :return:
109
+ '''
110
+ tpe = segmentation.dtype
111
+ unique_labels = np.unique(segmentation)
112
+ assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
113
+ if order == 0:
114
+ return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
115
+ else:
116
+ reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
117
+
118
+ for i, c in enumerate(unique_labels):
119
+ reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
120
+ reshaped[reshaped_multihot >= 0.5] = c
121
+ return reshaped
src/IDH/HD_BET/HD_BET/hd_bet.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import sys
5
+ sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")
6
+ from HD_BET.run import run_hd_bet
7
+ from HD_BET.utils import maybe_mkdir_p, subfiles
8
+ import HD_BET
9
+
10
+ def hd_bet(input_file_or_dir,output_file_or_dir,mode,device,tta,pp=1,save_mask=0,overwrite_existing=1):
11
+
12
+ if output_file_or_dir is None:
13
+ output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
14
+ os.path.basename(input_file_or_dir).split(".")[0] + "_bet")
15
+
16
+
17
+ params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
18
+ config_file = os.path.join(HD_BET.__path__[0], "config.py")
19
+
20
+ assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
21
+
22
+ if device == 'cpu':
23
+ pass
24
+ else:
25
+ device = int(device)
26
+
27
+ if os.path.isdir(input_file_or_dir):
28
+ maybe_mkdir_p(output_file_or_dir)
29
+ input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)
30
+
31
+ if len(input_files) == 0:
32
+ raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here")
33
+
34
+ output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
35
+ input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
36
+ else:
37
+ if not output_file_or_dir.endswith('.nii.gz'):
38
+ output_file_or_dir += '.nii.gz'
39
+ assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
40
+
41
+ output_files = [output_file_or_dir]
42
+ input_files = [input_file_or_dir]
43
+
44
+ if tta == 0:
45
+ tta = False
46
+ elif tta == 1:
47
+ tta = True
48
+ else:
49
+ raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))
50
+
51
+ if overwrite_existing == 0:
52
+ overwrite_existing = False
53
+ elif overwrite_existing == 1:
54
+ overwrite_existing = True
55
+ else:
56
+ raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))
57
+
58
+ if pp == 0:
59
+ pp = False
60
+ elif pp == 1:
61
+ pp = True
62
+ else:
63
+ raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
64
+
65
+ if save_mask == 0:
66
+ save_mask = False
67
+ elif save_mask == 1:
68
+ save_mask = True
69
+ else:
70
+ raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
71
+
72
+ run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)
73
+
74
+
75
+ if __name__ == "__main__":
76
+ print("\n########################")
77
+ print("If you are using hd-bet, please cite the following paper:")
78
+ print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W,"
79
+ "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial"
80
+ "neural networks. arXiv preprint arXiv:1901.11341, 2019.")
81
+ print("########################\n")
82
+
83
+ import argparse
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be '
86
+ 'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to '
87
+ 'split 4d sequences into 3d images. If folder: all files ending with .nii.gz '
88
+ 'within that folder will be brain extracted.', required=True, type=str)
89
+ parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder'
90
+ ' will be created', required=False, type=str)
91
+ parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will '
92
+ 'use only one set of parameters whereas accurate will '
93
+ 'use the five sets of parameters that resulted from '
94
+ 'our cross-validation as an ensemble. Default: '
95
+ 'accurate',
96
+ required=False)
97
+ parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. '
98
+ 'Must be either int or str. Use int for GPU id or '
99
+ '\'cpu\' to run on CPU. When using CPU you should '
100
+ 'consider disabling tta. Default for -device is: 0',
101
+ required=False)
102
+ parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation '
103
+ '(mirroring). 1= True, 0=False. Disable this '
104
+ 'if you are using CPU to speed things up! '
105
+ 'Default: 1')
106
+ parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disabe postprocessing (remove all'
107
+ ' but the largest connected component in '
108
+ 'the prediction. Default: 1')
109
+ parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation '
110
+ 'mask will not be '
111
+ 'saved')
112
+ parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't "
113
+ "want to overwrite existing "
114
+ "predictions")
115
+
116
+ args = parser.parse_args()
117
+
118
+ hd_bet(args.input,args.output,args.mode,args.device,args.tta,args.pp,args.save_mask,args.overwrite_existing)
119
+
src/IDH/HD_BET/HD_BET/network_architecture.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from HD_BET.utils import softmax_helper
5
+
6
+
7
+ class EncodingModule(nn.Module):
8
+ def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
9
+ inst_norm_affine=True, lrelu_inplace=True):
10
+ nn.Module.__init__(self)
11
+ self.dropout_p = dropout_p
12
+ self.lrelu_inplace = lrelu_inplace
13
+ self.inst_norm_affine = inst_norm_affine
14
+ self.conv_bias = conv_bias
15
+ self.leakiness = leakiness
16
+ self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
17
+ self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
18
+ self.dropout = nn.Dropout3d(dropout_p)
19
+ self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
20
+ self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
21
+
22
+ def forward(self, x):
23
+ skip = x
24
+ x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
25
+ x = self.conv1(x)
26
+ if self.dropout_p is not None and self.dropout_p > 0:
27
+ x = self.dropout(x)
28
+ x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
29
+ x = self.conv2(x)
30
+ x = x + skip
31
+ return x
32
+
33
+
34
+ class Upsample(nn.Module):
35
+ def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
36
+ super(Upsample, self).__init__()
37
+ self.align_corners = align_corners
38
+ self.mode = mode
39
+ self.scale_factor = scale_factor
40
+ self.size = size
41
+
42
+ def forward(self, x):
43
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
44
+ align_corners=self.align_corners)
45
+
46
+
47
+ class LocalizationModule(nn.Module):
48
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
49
+ lrelu_inplace=True):
50
+ nn.Module.__init__(self)
51
+ self.lrelu_inplace = lrelu_inplace
52
+ self.inst_norm_affine = inst_norm_affine
53
+ self.conv_bias = conv_bias
54
+ self.leakiness = leakiness
55
+ self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
56
+ self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
57
+ self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
58
+ self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
59
+
60
+ def forward(self, x):
61
+ x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
62
+ x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
63
+ return x
64
+
65
+
66
+ class UpsamplingModule(nn.Module):
67
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
68
+ lrelu_inplace=True):
69
+ nn.Module.__init__(self)
70
+ self.lrelu_inplace = lrelu_inplace
71
+ self.inst_norm_affine = inst_norm_affine
72
+ self.conv_bias = conv_bias
73
+ self.leakiness = leakiness
74
+ self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
75
+ self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias)
76
+ self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
77
+
78
+ def forward(self, x):
79
+ x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness,
80
+ inplace=self.lrelu_inplace)
81
+ return x
82
+
83
+
84
+ class DownsamplingModule(nn.Module):
85
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
86
+ lrelu_inplace=True):
87
+ nn.Module.__init__(self)
88
+ self.lrelu_inplace = lrelu_inplace
89
+ self.inst_norm_affine = inst_norm_affine
90
+ self.conv_bias = conv_bias
91
+ self.leakiness = leakiness
92
+ self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
93
+ self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)
94
+
95
+ def forward(self, x):
96
+ x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
97
+ b = self.downsample(x)
98
+ return x, b
99
+
100
+
101
+ class Network(nn.Module):
102
+ def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
103
+ final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
104
+ lrelu_inplace=True, do_ds=True):
105
+ super(Network, self).__init__()
106
+
107
+ self.do_ds = do_ds
108
+ self.lrelu_inplace = lrelu_inplace
109
+ self.inst_norm_affine = inst_norm_affine
110
+ self.conv_bias = conv_bias
111
+ self.leakiness = leakiness
112
+ self.final_nonlin = final_nonlin
113
+ self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)
114
+
115
+ self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
116
+ inst_norm_affine=True, lrelu_inplace=True)
117
+ self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
118
+ inst_norm_affine=True, lrelu_inplace=True)
119
+
120
+ self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
121
+ inst_norm_affine=True, lrelu_inplace=True)
122
+ self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
123
+ inst_norm_affine=True, lrelu_inplace=True)
124
+
125
+ self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
126
+ inst_norm_affine=True, lrelu_inplace=True)
127
+ self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
128
+ inst_norm_affine=True, lrelu_inplace=True)
129
+
130
+ self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
131
+ inst_norm_affine=True, lrelu_inplace=True)
132
+ self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
133
+ inst_norm_affine=True, lrelu_inplace=True)
134
+
135
+ self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
136
+ conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)
137
+
138
+ self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
139
+ self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
140
+ inst_norm_affine=True, lrelu_inplace=True)
141
+
142
+ self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
143
+ inst_norm_affine=True, lrelu_inplace=True)
144
+ self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
145
+ inst_norm_affine=True, lrelu_inplace=True)
146
+
147
+ self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
148
+ inst_norm_affine=True, lrelu_inplace=True)
149
+ self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
150
+ self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
151
+ inst_norm_affine=True, lrelu_inplace=True)
152
+
153
+ self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
154
+ inst_norm_affine=True, lrelu_inplace=True)
155
+ self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
156
+ self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
157
+ inst_norm_affine=True, lrelu_inplace=True)
158
+
159
+ self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
160
+ self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
161
+ self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
162
+ self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
163
+ self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
164
+
165
+ def forward(self, x):
166
+ seg_outputs = []
167
+
168
+ x = self.init_conv(x)
169
+ x = self.context1(x)
170
+
171
+ skip1, x = self.down1(x)
172
+ x = self.context2(x)
173
+
174
+ skip2, x = self.down2(x)
175
+ x = self.context3(x)
176
+
177
+ skip3, x = self.down3(x)
178
+ x = self.context4(x)
179
+
180
+ skip4, x = self.down4(x)
181
+ x = self.context5(x)
182
+
183
+ x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
184
+ x = self.up1(x)
185
+
186
+ x = torch.cat((skip4, x), dim=1)
187
+ x = self.loc1(x)
188
+ x = self.up2(x)
189
+
190
+ x = torch.cat((skip3, x), dim=1)
191
+ x = self.loc2(x)
192
+ loc2_seg = self.final_nonlin(self.loc2_seg(x))
193
+ seg_outputs.append(loc2_seg)
194
+ x = self.up3(x)
195
+
196
+ x = torch.cat((skip2, x), dim=1)
197
+ x = self.loc3(x)
198
+ loc3_seg = self.final_nonlin(self.loc3_seg(x))
199
+ seg_outputs.append(loc3_seg)
200
+ x = self.up4(x)
201
+
202
+ x = torch.cat((skip1, x), dim=1)
203
+ x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
204
+ inplace=self.lrelu_inplace)
205
+ x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
206
+ inplace=self.lrelu_inplace)
207
+ x = self.final_nonlin(self.seg_layer(x))
208
+ seg_outputs.append(x)
209
+
210
+ if self.do_ds:
211
+ return seg_outputs[::-1]
212
+ else:
213
+ return seg_outputs[-1]
src/IDH/HD_BET/HD_BET/paths.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # please refer to the readme on where to get the parameters. Save them in this folder:
4
+ # Original Path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params"
5
+ # Updated path for Docker container:
6
+ folder_with_parameter_files = "/app/IDH/hdbet_model"
src/IDH/HD_BET/HD_BET/predict_case.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
6
+ if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)):
7
+ shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
8
+ shp = patient.shape
9
+ new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0],
10
+ shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1],
11
+ shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]]
12
+ for i in range(len(shp)):
13
+ if shp[i] % shape_must_be_divisible_by[i] == 0:
14
+ new_shp[i] -= shape_must_be_divisible_by[i]
15
+ if min_size is not None:
16
+ new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0)
17
+ return reshape_by_padding_upper_coords(patient, new_shp, 0), shp
18
+
19
+
20
+ def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
21
+ shape = tuple(list(image.shape))
22
+ new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0))
23
+ if pad_value is None:
24
+ if len(shape) == 2:
25
+ pad_value = image[0,0]
26
+ elif len(shape) == 3:
27
+ pad_value = image[0, 0, 0]
28
+ else:
29
+ raise ValueError("Image must be either 2 or 3 dimensional")
30
+ res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
31
+ if len(shape) == 2:
32
+ res[0:0+int(shape[0]), 0:0+int(shape[1])] = image
33
+ elif len(shape) == 3:
34
+ res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image
35
+ return res
36
+
37
+
38
+ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
39
+ new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
40
+ with torch.no_grad():
41
+ pad_res = []
42
+ for i in range(patient_data.shape[0]):
43
+ t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size)
44
+ pad_res.append(t[None])
45
+
46
+ patient_data = np.vstack(pad_res)
47
+
48
+ new_shp = patient_data.shape
49
+
50
+ data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32)
51
+
52
+ data[0] = patient_data
53
+
54
+ if BATCH_SIZE is not None:
55
+ data = np.vstack([data] * BATCH_SIZE)
56
+
57
+ a = torch.rand(data.shape).float()
58
+
59
+ if main_device == 'cpu':
60
+ pass
61
+ else:
62
+ a = a.cuda(main_device)
63
+
64
+ if do_mirroring:
65
+ x = 8
66
+ else:
67
+ x = 1
68
+ all_preds = []
69
+ for i in range(num_repeats):
70
+ for m in range(x):
71
+ data_for_net = np.array(data)
72
+ do_stuff = False
73
+ if m == 0:
74
+ do_stuff = True
75
+ pass
76
+ if m == 1 and (4 in mirror_axes):
77
+ do_stuff = True
78
+ data_for_net = data_for_net[:, :, :, :, ::-1]
79
+ if m == 2 and (3 in mirror_axes):
80
+ do_stuff = True
81
+ data_for_net = data_for_net[:, :, :, ::-1, :]
82
+ if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
83
+ do_stuff = True
84
+ data_for_net = data_for_net[:, :, :, ::-1, ::-1]
85
+ if m == 4 and (2 in mirror_axes):
86
+ do_stuff = True
87
+ data_for_net = data_for_net[:, :, ::-1, :, :]
88
+ if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
89
+ do_stuff = True
90
+ data_for_net = data_for_net[:, :, ::-1, :, ::-1]
91
+ if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
92
+ do_stuff = True
93
+ data_for_net = data_for_net[:, :, ::-1, ::-1, :]
94
+ if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
95
+ do_stuff = True
96
+ data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]
97
+
98
+ if do_stuff:
99
+ _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
100
+ p = net(a) # np.copy is necessary because ::-1 creates just a view i think
101
+ p = p.data.cpu().numpy()
102
+
103
+ if m == 0:
104
+ pass
105
+ if m == 1 and (4 in mirror_axes):
106
+ p = p[:, :, :, :, ::-1]
107
+ if m == 2 and (3 in mirror_axes):
108
+ p = p[:, :, :, ::-1, :]
109
+ if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
110
+ p = p[:, :, :, ::-1, ::-1]
111
+ if m == 4 and (2 in mirror_axes):
112
+ p = p[:, :, ::-1, :, :]
113
+ if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
114
+ p = p[:, :, ::-1, :, ::-1]
115
+ if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
116
+ p = p[:, :, ::-1, ::-1, :]
117
+ if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
118
+ p = p[:, :, ::-1, ::-1, ::-1]
119
+ all_preds.append(p)
120
+
121
+ stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
122
+ predicted_segmentation = stacked.mean(0).argmax(0)
123
+ uncertainty = stacked.var(0)
124
+ bayesian_predictions = stacked
125
+ softmax_pred = stacked.mean(0)
126
+ return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
src/IDH/HD_BET/HD_BET/run.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import SimpleITK as sitk
4
+ from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
5
+ from HD_BET.predict_case import predict_case_3D_net
6
+ import imp
7
+ from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
8
+ import os
9
+ import HD_BET
10
+
11
+
12
+ def apply_bet(img, bet, out_fname):
13
+ img_itk = sitk.ReadImage(img)
14
+ img_npy = sitk.GetArrayFromImage(img_itk)
15
+ img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
16
+ img_npy[img_bet == 0] = 0
17
+ out = sitk.GetImageFromArray(img_npy)
18
+ out.CopyInformation(img_itk)
19
+ sitk.WriteImage(out, out_fname)
20
+
21
+
22
+ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
23
+ postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
24
+ """
25
+
26
+ :param mri_fnames: str or list/tuple of str
27
+ :param output_fnames: str or list/tuple of str. If list: must have the same length as output_fnames
28
+ :param mode: fast or accurate
29
+ :param config_file: config.py
30
+ :param device: either int (for device id) or 'cpu'
31
+ :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
32
+ but the largest predicted connected component. Default False
33
+ :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
34
+ CPU you may want to turn that off to speed things up
35
+ :return:
36
+ """
37
+
38
+ list_of_param_files = []
39
+
40
+ if mode == 'fast':
41
+ params_file = get_params_fname(0)
42
+ maybe_download_parameters(0)
43
+
44
+ list_of_param_files.append(params_file)
45
+ elif mode == 'accurate':
46
+ for i in range(5):
47
+ params_file = get_params_fname(i)
48
+ maybe_download_parameters(i)
49
+
50
+ list_of_param_files.append(params_file)
51
+ else:
52
+ raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)
53
+
54
+ assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"
55
+
56
+ cf = imp.load_source('cf', config_file)
57
+ cf = cf.config()
58
+
59
+ net, _ = cf.get_network(cf.val_use_train_mode, None)
60
+ if device == "cpu":
61
+ net = net.cpu()
62
+ else:
63
+ net.cuda(device)
64
+
65
+ if not isinstance(mri_fnames, (list, tuple)):
66
+ mri_fnames = [mri_fnames]
67
+
68
+ if not isinstance(output_fnames, (list, tuple)):
69
+ output_fnames = [output_fnames]
70
+
71
+ assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"
72
+
73
+ params = []
74
+ for p in list_of_param_files:
75
+ params.append(torch.load(p, map_location=lambda storage, loc: storage))
76
+
77
+ for in_fname, out_fname in zip(mri_fnames, output_fnames):
78
+ mask_fname = out_fname[:-7] + "_mask.nii.gz"
79
+ if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
80
+ print("File:", in_fname)
81
+ print("preprocessing...")
82
+ try:
83
+ data, data_dict = load_and_preprocess(in_fname)
84
+ except RuntimeError:
85
+ print("\nERROR\nCould not read file", in_fname, "\n")
86
+ continue
87
+ except AssertionError as e:
88
+ print(e)
89
+ continue
90
+
91
+ softmax_preds = []
92
+
93
+ print("prediction (CNN id)...")
94
+ for i, p in enumerate(params):
95
+ print(i)
96
+ net.load_state_dict(p)
97
+ net.eval()
98
+ net.apply(SetNetworkToVal(False, False))
99
+ _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
100
+ cf.val_batch_size, cf.net_input_must_be_divisible_by,
101
+ cf.val_min_size, device, cf.da_mirror_axes)
102
+ softmax_preds.append(softmax_pred[None])
103
+
104
+ seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
105
+
106
+ if postprocess:
107
+ seg = postprocess_prediction(seg)
108
+
109
+ print("exporting segmentation...")
110
+ save_segmentation_nifti(seg, data_dict, mask_fname)
111
+
112
+ apply_bet(in_fname, mask_fname, out_fname)
113
+
114
+ if not keep_mask:
115
+ os.remove(mask_fname)
116
+
117
+
src/IDH/HD_BET/HD_BET/utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from urllib.request import urlopen
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ from skimage.morphology import label
6
+ import os
7
+ from HD_BET.paths import folder_with_parameter_files
8
+
9
+
10
+ def get_params_fname(fold):
11
+ return os.path.join(folder_with_parameter_files, "%d.model" % fold)
12
+
13
+
14
+ def maybe_download_parameters(fold=0, force_overwrite=False):
15
+ """
16
+ Downloads the parameters for some fold if it is not present yet.
17
+ :param fold:
18
+ :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
19
+ :return:
20
+ """
21
+
22
+ assert 0 <= fold <= 4, "fold must be between 0 and 4"
23
+
24
+ if not os.path.isdir(folder_with_parameter_files):
25
+ maybe_mkdir_p(folder_with_parameter_files)
26
+
27
+ out_filename = get_params_fname(fold)
28
+
29
+ if force_overwrite and os.path.isfile(out_filename):
30
+ os.remove(out_filename)
31
+
32
+ if not os.path.isfile(out_filename):
33
+ url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
34
+ print("Downloading", url, "...")
35
+ data = urlopen(url).read()
36
+ #out_filename = "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params/0.model"
37
+ with open(out_filename, 'wb') as f:
38
+ f.write(data)
39
+
40
+
41
+ def init_weights(module):
42
+ if isinstance(module, nn.Conv3d):
43
+ module.weight = nn.init.kaiming_normal(module.weight, a=1e-2)
44
+ if module.bias is not None:
45
+ module.bias = nn.init.constant(module.bias, 0)
46
+
47
+
48
+ def softmax_helper(x):
49
+ rpt = [1 for _ in range(len(x.size()))]
50
+ rpt[1] = x.size(1)
51
+ x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
52
+ e_x = torch.exp(x - x_max)
53
+ return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
54
+
55
+
56
+ class SetNetworkToVal(object):
57
+ def __init__(self, use_dropout_sampling=False, norm_use_average=True):
58
+ self.norm_use_average = norm_use_average
59
+ self.use_dropout_sampling = use_dropout_sampling
60
+
61
+ def __call__(self, module):
62
+ if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout):
63
+ module.train(self.use_dropout_sampling)
64
+ elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \
65
+ isinstance(module, nn.InstanceNorm1d) \
66
+ or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \
67
+ isinstance(module, nn.BatchNorm1d):
68
+ module.train(not self.norm_use_average)
69
+
70
+
71
+ def postprocess_prediction(seg):
72
+ # basically look for connected components and choose the largest one, delete everything else
73
+ print("running postprocessing... ")
74
+ mask = seg != 0
75
+ lbls = label(mask, connectivity=mask.ndim)
76
+ lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
77
+ largest_region = np.argmax(lbls_sizes[1:]) + 1
78
+ seg[lbls != largest_region] = 0
79
+ return seg
80
+
81
+
82
+ def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
83
+ if join:
84
+ l = os.path.join
85
+ else:
86
+ l = lambda x, y: y
87
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
88
+ and (prefix is None or i.startswith(prefix))
89
+ and (suffix is None or i.endswith(suffix))]
90
+ if sort:
91
+ res.sort()
92
+ return res
93
+
94
+
95
+ def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
96
+ if join:
97
+ l = os.path.join
98
+ else:
99
+ l = lambda x, y: y
100
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
101
+ and (prefix is None or i.startswith(prefix))
102
+ and (suffix is None or i.endswith(suffix))]
103
+ if sort:
104
+ res.sort()
105
+ return res
106
+
107
+
108
+ subfolders = subdirs # I am tired of confusing those
109
+
110
+
111
+ def maybe_mkdir_p(directory):
112
+ splits = directory.split("/")[1:]
113
+ for i in range(0, len(splits)):
114
+ if not os.path.isdir(os.path.join("", *splits[:i+1])):
115
+ os.mkdir(os.path.join("", *splits[:i+1]))
src/IDH/HD_BET/__pycache__/config.cpython-310.pyc ADDED
Binary file (4.15 kB). View file
 
src/IDH/HD_BET/__pycache__/config.cpython-38.pyc ADDED
Binary file (4.13 kB). View file
 
src/IDH/HD_BET/__pycache__/config.cpython-39.pyc ADDED
Binary file (4.19 kB). View file
 
src/IDH/HD_BET/__pycache__/data_loading.cpython-310.pyc ADDED
Binary file (4.47 kB). View file
 
src/IDH/HD_BET/__pycache__/data_loading.cpython-38.pyc ADDED
Binary file (4.48 kB). View file
 
src/IDH/HD_BET/__pycache__/data_loading.cpython-39.pyc ADDED
Binary file (4.46 kB). View file
 
src/IDH/HD_BET/__pycache__/hd_bet.cpython-310.pyc ADDED
Binary file (4.21 kB). View file
 
src/IDH/HD_BET/__pycache__/hd_bet.cpython-38.pyc ADDED
Binary file (4.27 kB). View file
 
src/IDH/HD_BET/__pycache__/network_architecture.cpython-310.pyc ADDED
Binary file (6.78 kB). View file
 
src/IDH/HD_BET/__pycache__/network_architecture.cpython-38.pyc ADDED
Binary file (6.89 kB). View file
 
src/IDH/HD_BET/__pycache__/network_architecture.cpython-39.pyc ADDED
Binary file (6.84 kB). View file
 
src/IDH/HD_BET/__pycache__/paths.cpython-310.pyc ADDED
Binary file (324 Bytes). View file
 
src/IDH/HD_BET/__pycache__/paths.cpython-38.pyc ADDED
Binary file (335 Bytes). View file
 
src/IDH/HD_BET/__pycache__/paths.cpython-39.pyc ADDED
Binary file (322 Bytes). View file
 
src/IDH/HD_BET/__pycache__/predict_case.cpython-310.pyc ADDED
Binary file (3.68 kB). View file
 
src/IDH/HD_BET/__pycache__/predict_case.cpython-38.pyc ADDED
Binary file (3.67 kB). View file