Divyanshu Tak
Add BrainIAC Glioma Segmentation app with proper Docker setup
0ee52bb
import torch
import numpy as np
import SimpleITK as sitk
from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
from HD_BET.predict_case import predict_case_3D_net
import imp
from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
import os
import HD_BET
def apply_bet(img, bet, out_fname):
    """Write a copy of image ``img`` with every voxel outside mask ``bet`` set to 0.

    :param img: path of the image to be masked
    :param bet: path of the mask image; voxels where the mask equals 0 are cleared
    :param out_fname: path the masked image is written to. The output receives the
        meta-information of ``img`` via ``CopyInformation``.
    """
    source_image = sitk.ReadImage(img)
    voxels = sitk.GetArrayFromImage(source_image)
    mask_voxels = sitk.GetArrayFromImage(sitk.ReadImage(bet))
    outside_brain = mask_voxels == 0
    voxels[outside_brain] = 0

    masked_image = sitk.GetImageFromArray(voxels)
    # GetImageFromArray yields a bare image; restore the source's meta-information.
    masked_image.CopyInformation(source_image)
    sitk.WriteImage(masked_image, out_fname)
def _param_files_for_mode(mode):
    """Return the parameter-file paths needed for ``mode``, fetching missing ones.

    :param mode: 'fast' (one network) or 'accurate' (ensemble of five networks)
    :return: list of parameter-file paths, in network-id order
    :raises ValueError: if ``mode`` is neither 'fast' nor 'accurate'
    """
    if mode == 'fast':
        model_ids = range(1)
    elif mode == 'accurate':
        model_ids = range(5)
    else:
        raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)
    param_files = []
    for i in model_ids:
        params_file = get_params_fname(i)
        maybe_download_parameters(i)
        param_files.append(params_file)
    return param_files


def _load_config(config_file):
    """Execute ``config_file`` as module 'cf' and return the object built by its ``config()``."""
    # importlib replaces the deprecated ``imp.load_source`` (``imp`` was removed in
    # Python 3.12). SourceFileLoader accepts any file extension, as imp.load_source did,
    # and the module is registered in sys.modules under the same name imp used.
    import importlib.machinery
    import importlib.util
    import sys

    loader = importlib.machinery.SourceFileLoader('cf', config_file)
    spec = importlib.util.spec_from_loader('cf', loader)
    module = importlib.util.module_from_spec(spec)
    sys.modules['cf'] = module
    loader.exec_module(module)
    return module.config()


def _mask_fname_for(out_fname):
    """Return the path of the brain mask written next to ``out_fname``.

    A '.nii.gz' or '.nii' suffix is replaced by '_mask.nii.gz'. Any other name simply
    gets '_mask.nii.gz' appended, rather than having 7 characters blindly chopped off.
    """
    for ext in (".nii.gz", ".nii"):
        if out_fname.endswith(ext):
            return out_fname[:-len(ext)] + "_mask.nii.gz"
    return out_fname + "_mask.nii.gz"


def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
               postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
    """Run HD-BET brain extraction on one or more MRI files.

    For every input, a brain mask is predicted and saved next to the output as
    '<output stem>_mask.nii.gz'. The input with all voxels outside that mask zeroed
    is then written to the output path.

    :param mri_fnames: str or list/tuple of str
    :param output_fnames: str or list/tuple of str. If list: must have the same length as mri_fnames
    :param mode: fast or accurate
    :param config_file: config.py
    :param device: either int (for device id) or 'cpu'
    :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
    but the largest predicted connected component. Default False
    :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
    CPU you may want to turn that off to speed things up
    :param keep_mask: if False, the mask file is deleted after the masked image is written. Default True
    :param overwrite: if False, a case is skipped when its output (and, if keep_mask, its mask) already exist.
    Default True
    :return: None
    :raises ValueError: if mode is unknown
    :raises AssertionError: if parameter files are missing or the input/output lists differ in length
    """
    list_of_param_files = _param_files_for_mode(mode)
    assert all(os.path.isfile(f) for f in list_of_param_files), "Could not find parameter files"

    cf = _load_config(config_file)
    net, _ = cf.get_network(cf.val_use_train_mode, None)
    if device == "cpu":
        net = net.cpu()
    else:
        net.cuda(device)

    if not isinstance(mri_fnames, (list, tuple)):
        mri_fnames = [mri_fnames]
    if not isinstance(output_fnames, (list, tuple)):
        output_fnames = [output_fnames]
    assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"

    # Load every state dict once (onto CPU storage); each is swapped into ``net`` per case.
    params = [torch.load(p, map_location=lambda storage, loc: storage) for p in list_of_param_files]

    for in_fname, out_fname in zip(mri_fnames, output_fnames):
        mask_fname = _mask_fname_for(out_fname)
        # A case is done when its output exists and, only if the mask is meant to be
        # kept, its mask exists too. (The previous check always reprocessed when
        # keep_mask was False, because the deleted mask was treated as missing.)
        outputs_exist = os.path.isfile(out_fname) and (not keep_mask or os.path.isfile(mask_fname))
        if not overwrite and outputs_exist:
            continue

        print("File:", in_fname)
        print("preprocessing...")
        try:
            data, data_dict = load_and_preprocess(in_fname)
        except RuntimeError:
            print("\nERROR\nCould not read file", in_fname, "\n")
            continue
        except AssertionError as e:
            print(e)
            continue

        softmax_preds = []
        print("prediction (CNN id)...")
        for i, p in enumerate(params):
            print(i)
            net.load_state_dict(p)
            net.eval()
            net.apply(SetNetworkToVal(False, False))
            _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
                                                        cf.val_batch_size, cf.net_input_must_be_divisible_by,
                                                        cf.val_min_size, device, cf.da_mirror_axes)
            # Add a leading axis so the predictions can be stacked across networks.
            softmax_preds.append(softmax_pred[None])
        # Average over the ensemble, then argmax over axis 0 of the averaged
        # prediction (presumably the class axis -- TODO confirm).
        seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
        if postprocess:
            seg = postprocess_prediction(seg)

        print("exporting segmentation...")
        save_segmentation_nifti(seg, data_dict, mask_fname)
        apply_bet(in_fname, mask_fname, out_fname)
        if not keep_mask:
            os.remove(mask_fname)