|
import torch |
|
import numpy as np |
|
import SimpleITK as sitk |
|
from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti |
|
from HD_BET.predict_case import predict_case_3D_net |
|
import imp |
|
from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters |
|
import os |
|
import HD_BET |
|
|
|
|
|
def apply_bet(img, bet, out_fname):
    """Mask an image with a brain mask and write the masked image to disk.

    :param img: path of the image to be masked
    :param bet: path of the mask image; every voxel where the mask equals 0 is
        set to 0 in the output
    :param out_fname: path the masked image is written to. Image geometry is
        copied from ``img`` via ``CopyInformation``.
    """
    source_image = sitk.ReadImage(img)
    voxels = sitk.GetArrayFromImage(source_image)

    mask_voxels = sitk.GetArrayFromImage(sitk.ReadImage(bet))
    outside_brain = mask_voxels == 0
    voxels[outside_brain] = 0

    masked_image = sitk.GetImageFromArray(voxels)
    # keep the geometry of the input image so the output overlays it exactly
    masked_image.CopyInformation(source_image)
    sitk.WriteImage(masked_image, out_fname)
|
|
|
|
|
def _get_param_files(mode):
    """Return the parameter file paths for ``mode``, downloading missing ones."""
    if mode == 'fast':
        model_ids = [0]
    elif mode == 'accurate':
        model_ids = range(5)
    else:
        raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)

    param_files = []
    for model_id in model_ids:
        param_files.append(get_params_fname(model_id))
        maybe_download_parameters(model_id)
    return param_files


def _load_config(config_file):
    """Load ``config_file`` as module ``cf`` and return an instance of its ``config`` class."""
    # importlib replacement for the deprecated imp.load_source (imp was removed in Python 3.12).
    # SourceFileLoader accepts any file name, not only ones ending in ".py", like imp did.
    import importlib.machinery
    import importlib.util
    import sys

    loader = importlib.machinery.SourceFileLoader('cf', config_file)
    spec = importlib.util.spec_from_loader('cf', loader)
    module = importlib.util.module_from_spec(spec)
    # imp.load_source registered the module under its name; keep that behaviour
    sys.modules['cf'] = module
    loader.exec_module(module)
    return module.config()


def _mask_fname_for(out_fname):
    """Return the path of the mask file written next to ``out_fname``.

    A trailing ``.nii.gz`` or ``.nii`` extension (case-insensitive) is removed
    before ``_mask.nii.gz`` is appended; other names are used as the stem unchanged.
    """
    lowered = out_fname.lower()
    for extension in (".nii.gz", ".nii"):
        if lowered.endswith(extension):
            return out_fname[:-len(extension)] + "_mask.nii.gz"
    return out_fname + "_mask.nii.gz"


def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
               postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
    """Run brain extraction on one or more MRI files.

    For each input, a brain mask is predicted (averaging the softmax output of
    all networks of the chosen mode) and saved as ``<output stem>_mask.nii.gz``.
    The masked input image is then written to the output file name.

    :param mri_fnames: str or list/tuple of str
    :param output_fnames: str or list/tuple of str. If list: must have the same length as mri_fnames
    :param mode: fast (one network) or accurate (ensemble of five networks)
    :param config_file: config.py
    :param device: either int (for device id) or 'cpu'
    :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
        but the largest predicted connected component. Default False
    :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
        CPU you may want to turn that off to speed things up
    :param keep_mask: if False, the mask file is deleted after the masked image has been written. Default True
    :param overwrite: if False, a case is skipped when its output file exists and (when keep_mask is True) its
        mask file exists as well. Default True
    :return: None. Inputs that cannot be read (RuntimeError) or fail a preprocessing assertion are reported
        and skipped.
    :raises ValueError: if mode is neither 'fast' nor 'accurate'
    """
    list_of_param_files = _get_param_files(mode)
    assert all(os.path.isfile(i) for i in list_of_param_files), "Could not find parameter files"

    cf = _load_config(config_file)

    net, _ = cf.get_network(cf.val_use_train_mode, None)
    if device == "cpu":
        net = net.cpu()
    else:
        net.cuda(device)

    if not isinstance(mri_fnames, (list, tuple)):
        mri_fnames = [mri_fnames]
    if not isinstance(output_fnames, (list, tuple)):
        output_fnames = [output_fnames]
    assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"

    # load every state dict once (onto the CPU); they are swapped into the same network per case
    params = [torch.load(p, map_location=lambda storage, loc: storage) for p in list_of_param_files]

    for in_fname, out_fname in zip(mri_fnames, output_fnames):
        mask_fname = _mask_fname_for(out_fname)
        needs_processing = (overwrite
                            or not (os.path.isfile(mask_fname) and keep_mask)
                            or not os.path.isfile(out_fname))
        if not needs_processing:
            continue

        print("File:", in_fname)
        print("preprocessing...")
        try:
            data, data_dict = load_and_preprocess(in_fname)
        except RuntimeError:
            print("\nERROR\nCould not read file", in_fname, "\n")
            continue
        except AssertionError as e:
            print(e)
            continue

        softmax_preds = []
        print("prediction (CNN id)...")
        for i, p in enumerate(params):
            print(i)
            net.load_state_dict(p)
            net.eval()
            net.apply(SetNetworkToVal(False, False))
            _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
                                                        cf.val_batch_size, cf.net_input_must_be_divisible_by,
                                                        cf.val_min_size, device, cf.da_mirror_axes)
            softmax_preds.append(softmax_pred[None])

        # ensemble: average the stacked softmax outputs over networks, then take the class argmax
        seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)

        if postprocess:
            seg = postprocess_prediction(seg)

        print("exporting segmentation...")
        save_segmentation_nifti(seg, data_dict, mask_fname)

        apply_bet(in_fname, mask_fname, out_fname)

        if not keep_mask:
            os.remove(mask_fname)
|
|
|
|
|
|