|
|
|
|
|
from keras.models import load_model |
|
import numpy as np |
|
import tensorflow_addons as tfa |
|
from scipy.ndimage import zoom |
|
from tqdm import tqdm |
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
def _scale_peak_to_uint8(volume):
    """Scale *volume* so that its maximum maps to 255 and cast to uint8.

    A volume whose maximum is 0 has no peak to scale by; it is returned as
    an all-zero uint8 array instead of dividing by zero (which would yield
    NaN/inf and an undefined uint8 cast).
    """
    peak = np.max(volume)
    if peak == 0:
        return np.zeros(np.shape(volume), dtype='uint8')
    return ((volume / peak) * 255).astype('uint8')


def _stretch_percentiles_to_uint8(volume, low_pct, high_pct):
    """Clip *volume* to the given percentiles and stretch the result to 0-255.

    When both percentiles coincide the clipped volume is constant, so an
    all-zero uint8 array is returned instead of computing 0/0.
    """
    low = np.percentile(volume, low_pct)
    high = np.percentile(volume, high_pct)
    if high == low:
        return np.zeros(np.shape(volume), dtype='uint8')
    clipped = np.clip(volume, low, high)
    return (((clipped - low) / (high - low)) * 255).astype('uint8')


def _pad_to_patch_grid(volume, patch_xy, patch_z):
    """Embed *volume* in the corner of a larger array sized to the patch grid.

    Axes 0 and 1 are grown to a multiple of ``patch_xy`` and axis 2 to a
    multiple of ``patch_z``. The size formula always adds padding, even when
    the input is already an exact multiple. The padded region is filled with
    random integers in [1, 50), so it is not reproducible unless the caller
    seeds NumPy's global RNG.
    """
    rows, cols, depth = np.shape(volume)[0], np.shape(volume)[1], np.shape(volume)[2]
    padded_shape = (
        int((rows / patch_xy) + 1) * patch_xy,
        int((cols / patch_xy) + 1) * patch_xy,
        int((depth / patch_z) + 1) * patch_z,
    )
    padded = np.random.randint(1, 50, padded_shape)
    padded[0:rows, 0:cols, 0:depth] = volume
    return padded


def _vote_over_patches(model, volume, n_slices, patch_xy, patch_z, step_xy, step_z):
    """Run *model* on overlapping patches of *volume* and majority-vote a mask.

    Each patch prediction is thresholded at 0.5; every voxel collects one
    foreground or background vote per patch covering it. A voxel is labelled
    1 when foreground votes >= background votes (ties, including voxels no
    patch covered, go to foreground).
    """
    foreground_votes = np.zeros(volume.shape, dtype='uint8')
    background_votes = np.zeros(volume.shape, dtype='uint8')

    # One progress tick per axis-0 window position. The count is derived from
    # the step (not the patch size) so the bar total matches the loop.
    starts_0 = range(0, volume.shape[0] - patch_xy + 1, step_xy)
    starts_1 = range(0, volume.shape[1] - patch_xy + 1, step_xy)
    starts_2 = range(0, volume.shape[2] - patch_z + 1, step_z)

    with tqdm(total=len(starts_0)) as pbar:
        for i in starts_0:
            for j in starts_1:
                for k in starts_2:
                    window = (
                        slice(i, i + patch_xy),
                        slice(j, j + patch_xy),
                        slice(k, k + patch_z),
                    )

                    # Patch is moved to (z, axis0, axis1, channel) layout
                    # before being written into the batch buffer.
                    patch = volume[window].transpose(2, 0, 1)
                    patch = np.expand_dims(patch, axis=-1)

                    # Buffer depth comes from patch_size[0]; the assignment
                    # fails to broadcast if it differs from _patch_size_z.
                    batch = np.zeros((1, n_slices, patch_xy, patch_xy, 1), dtype='float32')
                    batch[0, :] = (patch - 127.5) / 127.5  # 0..255 -> -1..1

                    prediction = model.predict(batch)
                    prediction = (prediction + 1) / 2  # -1..1 -> 0..1

                    # Drop batch/channel axes and return to volume layout.
                    prediction = prediction[0, :, :, :, 0].transpose(1, 2, 0)
                    binary = ((prediction > 0.5) * 1).astype('uint8')

                    foreground_votes[window] = foreground_votes[window] + binary
                    background_votes[window] = background_votes[window] + (1 - binary)
            pbar.update(1)

    return (foreground_votes >= background_votes) * 1


def predict_mask(image: np.ndarray, dim_x: float, dim_y: float, dim_z: float,
                 _resize: bool = True, norm_: bool = True, mode_: str = 'test',
                 patch_size: tuple = (64, 128, 128, 1), _step: int = 64,
                 _step_z: int = 32, _patch_size_z: int = 64) -> np.ndarray:
    """Segment a 3-D volume with the pretrained CycleGAN generator.

    The weights file ``CycleGANVesselSegmentation.h5`` is fetched from the
    Hugging Face repo ``Hemaxi/3DCycleGAN`` and loaded on every call. The
    volume is rescaled to 0-255, optionally percentile-normalised and
    resampled, padded to the patch grid, predicted patch by patch with
    overlapping windows, majority-voted, and finally brought back to the
    input shape.

    Args:
        image: 3-D array indexed as ``image[a0, a1, a2]``; axis 2 is the one
            tiled with ``_patch_size_z`` / ``_step_z``.
        dim_x, dim_y, dim_z: Per-axis sizes used to build the resampling
            factors ``dim_x/0.333``, ``dim_y/0.333`` and ``dim_z/0.5``
            (presumably voxel sizes in the same unit as 0.333 / 0.5 -
            confirm with the caller).
        _resize: When true, resample before prediction and resample the mask
            back with the inverse factors afterwards.
        norm_: When true, clip to the 1st-99th percentile range and stretch
            to 0-255.
        mode_: Only printed; it does not alter the computation.
        patch_size: ``patch_size[0]`` is the depth of the model input buffer
            and ``patch_size[1]`` the in-plane patch edge. Other entries are
            unused here.
        _step: Window stride along axes 0 and 1.
        _step_z: Window stride along axis 2.
        _patch_size_z: Patch extent along axis 2; must equal
            ``patch_size[0]`` for the patch to fit the input buffer.

    Returns:
        uint8 array with the same shape as ``image``. Non-empty masks are
        scaled so their maximum is 255; an empty prediction yields all zeros.

    Raises:
        ValueError: If ``_step`` or ``_step_z`` is not positive (the original
            loop would never terminate in that case).
    """
    if _step <= 0 or _step_z <= 0:
        raise ValueError('_step and _step_z must be positive')

    custom_objects = {'InstanceNormalization': tfa.layers.InstanceNormalization}
    model_path = hf_hub_download(repo_id="Hemaxi/3DCycleGAN", filename="CycleGANVesselSegmentation.h5")
    generator = load_model(model_path, custom_objects)

    print('Mode: {}'.format(mode_))

    patch_xy = patch_size[1]
    n_slices = patch_size[0]

    image = _scale_peak_to_uint8(image)

    print('Image Shape: {}'.format(image.shape))
    print('----------------------------------------')

    # Remember the caller's shape so the mask can be returned at that size.
    input_shape = (np.shape(image)[0], np.shape(image)[1], np.shape(image)[2])

    if norm_:
        image = _stretch_percentiles_to_uint8(image, 1, 99)

    if _resize:
        image = zoom(image, (dim_x/0.333, dim_y/0.333, dim_z/0.5), order=3, mode='nearest')
        image = _scale_peak_to_uint8(image)

    # Shape actually fed to the network (after optional resampling).
    working_shape = (np.shape(image)[0], np.shape(image)[1], np.shape(image)[2])

    image = _pad_to_patch_grid(image, patch_xy, _patch_size_z)

    final_mask = _vote_over_patches(
        generator, image, n_slices, patch_xy, _patch_size_z, _step, _step_z
    )

    # Discard the padding again, both for the log line and for the mask.
    image = image[0:working_shape[0], 0:working_shape[1], 0:working_shape[2]]
    print('Image Shape: {}'.format(image.shape))
    print('----------------------------------------')

    final_mask = final_mask[0:working_shape[0], 0:working_shape[1], 0:working_shape[2]]

    if _resize:
        # NOTE(review): cubic interpolation (order=3) on a binary mask is
        # kept as in the original; order=0 may be the safer choice - confirm.
        final_mask = zoom(final_mask, (0.333/dim_x, 0.333/dim_y, 0.5/dim_z), order=3, mode='nearest')
        final_mask = (final_mask*255.0).astype('uint8')

    # The round-trip resampling may be off by a voxel or so; copy the
    # overlapping region into a zero canvas of exactly the input shape.
    overlap = tuple(
        slice(0, min(wanted, got))
        for wanted, got in zip(input_shape, np.shape(final_mask)[:3])
    )
    canvas = np.zeros(input_shape).astype('uint8')
    canvas[overlap] = final_mask[overlap]
    final_mask = canvas

    print('Mask Shape: {}'.format(final_mask.shape))
    print('----------------------------------------')

    # An empty prediction stays all-zero rather than becoming 0/0.
    return _scale_peak_to_uint8(final_mask)