import torch
import numpy as np


# Mirror variants used for test-time augmentation, in the order they are
# evaluated. Each entry lists the axes of the 5D (batch, channel, x, y, z)
# network input that are reversed; a variant is only evaluated when every one
# of its axes appears in ``mirror_axes``. The empty tuple is the unmirrored
# input and is always evaluated.
_MIRROR_VARIANTS = ((), (4,), (3,), (3, 4), (2,), (2, 4), (2, 3), (2, 3, 4))


def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
    """Pad a volume at its upper end so every extent is divisible by a factor.

    Args:
        patient: numpy array (a single 3D channel in this module's use).
        shape_must_be_divisible_by: int applied to every axis, or a per-axis
            sequence / array of ints.
        min_size: optional lower bound on the padded shape; an int (applied to
            every axis) or a per-axis sequence. It is applied after rounding,
            so the result is not re-rounded to the divisor.

    Returns:
        ``(padded, patient.shape)``: ``padded`` holds ``patient`` at index 0
        along every axis and is filled with 0 elsewhere.
    """
    shp = patient.shape
    divisor = np.asarray(shape_must_be_divisible_by)
    if divisor.ndim == 0:
        divisor = np.repeat(divisor, len(shp))
    # Round each extent up to the next multiple of its divisor; extents that
    # already are multiples are kept unchanged.
    new_shp = [((int(s) + int(d) - 1) // int(d)) * int(d) for s, d in zip(shp, divisor)]
    if min_size is not None:
        # np.maximum broadcasts, so an int min_size applies to every axis.
        new_shp = np.maximum(np.array(new_shp), np.array(min_size))
    return reshape_by_padding_upper_coords(patient, new_shp, 0), shp


def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
    """Pad ``image`` at the upper end of every axis up to ``new_shape``.

    The image is placed at index 0 along every axis and the remainder is
    filled with ``pad_value``. Axes where ``new_shape`` is smaller than the
    image keep the image's extent (the image is never cropped).

    Args:
        image: numpy array of any rank.
        new_shape: target shape with one entry per axis of ``image``.
        pad_value: fill value; defaults to the image's first element
            ``image[0, ..., 0]``. It is cast to ``image.dtype``.

    Returns:
        A new array with dtype ``image.dtype``.

    Raises:
        ValueError: if ``new_shape`` does not have one entry per image axis.
    """
    shape = image.shape
    if len(new_shape) != len(shape):
        raise ValueError(
            "new_shape must have one entry per image axis "
            "(got %d entries for a %d-dimensional image)" % (len(new_shape), len(shape)))
    target = tuple(int(max(s, n)) for s, n in zip(shape, new_shape))
    if pad_value is None:
        pad_value = image[(0,) * image.ndim]
    # np.full keeps image.dtype (ones * pad_value could silently upcast).
    res = np.full(target, pad_value, dtype=image.dtype)
    res[tuple(slice(0, s) for s in shape)] = image
    return res


def _flip(arr, axes):
    """Return ``arr`` reversed along every axis in ``axes`` (a view, not a copy)."""
    return np.flip(arr, axis=axes) if axes else arr


def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
                        new_shape_must_be_divisible_by=16, min_size=None, main_device=0,
                        mirror_axes=(2, 3, 4)):
    """Run ``net`` on a padded multi-channel 3D case with optional mirroring TTA.

    Each ``patient_data[i]`` is padded with :func:`pad_patient_3D`; the padded
    channels are stacked into a float32 input of shape
    ``(BATCH_SIZE or 1, C, X', Y', Z')``. The network is evaluated
    ``num_repeats`` times for every enabled mirror variant; each output is
    flipped back to the original orientation and cropped to the unpadded
    spatial shape before aggregation.

    Args:
        net: callable taking a float32 torch tensor on ``main_device`` and
            returning a torch tensor. Its output is assumed to be laid out as
            (batch, classes, x, y, z) with the input's spatial size -- TODO
            confirm against the networks actually used.
        patient_data: array of shape (C, X, Y, Z).
        do_mirroring: if True, evaluate up to 8 mirror variants; otherwise
            only the unmirrored input.
        num_repeats: passes over all variants; repeats only differ when
            ``net`` is stochastic (presumably dropout left active).
        BATCH_SIZE: if given, the input is replicated this many times along
            the batch axis.
        new_shape_must_be_divisible_by, min_size: forwarded to
            :func:`pad_patient_3D`.
        main_device: ``'cpu'`` or a device index passed to ``Tensor.cuda``.
        mirror_axes: axes of the 5D input that may be mirrored; only 2, 3 and
            4 are consulted.

    Returns:
        ``(predicted_segmentation, bayesian_predictions, softmax_pred,
        uncertainty)``: argmax over axis 0 of the mean prediction, all stacked
        (cropped) predictions, their mean over axis 0, and their variance over
        axis 0.

    Raises:
        ValueError: if no prediction was made (e.g. ``num_repeats < 1``).
    """
    with torch.no_grad():
        padded_channels = []
        old_shape = None
        for channel in patient_data:
            # old_shape ends up as the last channel's shape; channels are
            # presumably all the same size.
            padded, old_shape = pad_patient_3D(channel, new_shape_must_be_divisible_by, min_size)
            padded_channels.append(padded[None])
        data = np.vstack(padded_channels)[None].astype(np.float32)  # (1, C, X', Y', Z')
        if BATCH_SIZE is not None:
            data = np.vstack([data] * BATCH_SIZE)

        # Reusable input buffer; its contents are always overwritten before
        # use. torch.rand (not torch.empty) keeps the global RNG stream
        # advancing exactly as before, so seeded stochastic runs reproduce.
        a = torch.rand(data.shape).float()
        if main_device != 'cpu':
            a = a.cuda(main_device)

        variants = _MIRROR_VARIANTS if do_mirroring else _MIRROR_VARIANTS[:1]
        active_variants = [axes for axes in variants
                           if all(ax in mirror_axes for ax in axes)]

        all_preds = []
        for _ in range(num_repeats):
            for axes in active_variants:
                # torch.from_numpy rejects negative strides, hence the
                # contiguous copy of the flipped view.
                a.copy_(torch.from_numpy(np.ascontiguousarray(_flip(data, axes))))
                pred = net(a).detach().cpu().numpy()
                # Un-mirror, then copy: on CPU the output may share memory with
                # ``a`` (e.g. an identity-like net), and the next copy_ into
                # ``a`` would otherwise overwrite predictions already stored.
                all_preds.append(_flip(pred, axes).copy())

    if not all_preds:
        raise ValueError("no predictions were made; num_repeats must be at least 1")
    stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
    softmax_pred = stacked.mean(0)
    predicted_segmentation = softmax_pred.argmax(0)
    uncertainty = stacked.var(0)
    bayesian_predictions = stacked
    return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty