# example of using saved cycleGAN models for image translation
# based on https://machinelearningmastery.com/cyclegan-tutorial-with-keras/
from keras.models import load_model
import numpy as np
import tensorflow_addons as tfa
from scipy.ndimage import zoom
from tqdm import tqdm
import warnings

# NOTE(review): this silences *every* warning process-wide (and for the
# imports below), not only warnings raised by this module.
warnings.filterwarnings("ignore")
from huggingface_hub import hf_hub_download
from skimage.morphology import binary_erosion, binary_dilation
from skimage import draw

# Voxel size the volume is resampled to before inference when ``_resize`` is
# set (axes 0, 1, 2). Presumably the spacing the network was trained at, in
# the same units as the ``dim_*`` arguments -- TODO confirm.
_MODEL_SPACING = (0.333, 0.333, 0.5)


def _rescale_to_uint8(volume):
    """Scale ``volume`` so its maximum maps to 255 and cast it to uint8.

    An all-zero volume is returned as zeros instead of dividing by zero,
    which would produce NaNs whose uint8 conversion is undefined.
    """
    peak = np.max(volume)
    if peak == 0:
        return np.zeros(np.shape(volume), dtype='uint8')
    return ((volume / peak) * 255.0).astype('uint8')


def _percentile_equalize(volume, low=1, high=99):
    """Clip ``volume`` to its [low, high] percentiles and stretch that range to 0..255 uint8."""
    minval = np.percentile(volume, low)
    maxval = np.percentile(volume, high)
    clipped = np.clip(volume, minval, maxval)
    if maxval == minval:
        # Every voxel was clipped to the same value: there is no range to stretch.
        return np.zeros(np.shape(volume), dtype='uint8')
    return (((clipped - minval) / (maxval - minval)) * 255).astype('uint8')


def _pad_with_noise(volume, patch_xy, patch_z):
    """Pad ``volume`` at the far end of each axis with random integers in [1, 50).

    Axes 0 and 1 grow to ``(size // patch_xy + 1) * patch_xy`` and axis 2 to
    ``(size // patch_z + 1) * patch_z``, i.e. always by at least one voxel and
    at most one full patch. The noise comes from NumPy's global random state,
    so predictions for voxels near the far borders may differ between calls
    unless the caller seeds ``np.random``.
    """
    padded_shape = (
        (volume.shape[0] // patch_xy + 1) * patch_xy,
        (volume.shape[1] // patch_xy + 1) * patch_xy,
        (volume.shape[2] // patch_z + 1) * patch_z,
    )
    padded = np.random.randint(1, 50, padded_shape)
    padded[:volume.shape[0], :volume.shape[1], :volume.shape[2]] = volume
    return padded


def _window_starts(length, window, step):
    """Return start offsets of ``window``-sized windows that cover ``[0, length)``.

    Windows advance by ``step``; if that grid does not end exactly at
    ``length``, one extra window aligned with the end is appended so that no
    position is left without a prediction.
    """
    if length < window:
        return []
    starts = list(range(0, length - window + 1, step))
    if starts[-1] != length - window:
        starts.append(length - window)
    return starts


def _predict_foreground(model, volume, patch_xy, nbslices, patch_z, step_xy, step_z):
    """Run ``model`` over overlapping patches of ``volume`` and majority-vote the result.

    Every patch prediction casts one vote per voxel (foreground or
    background). A voxel is foreground when it received at least as many
    foreground votes as background votes (ties count as foreground).
    Returns a boolean array with ``volume``'s shape.
    """
    starts_0 = _window_starts(volume.shape[0], patch_xy, step_xy)
    starts_1 = _window_starts(volume.shape[1], patch_xy, step_xy)
    starts_2 = _window_starts(volume.shape[2], patch_z, step_z)

    # Upper bound on how many windows can cover one voxel (ceil(window/step)
    # per axis, plus one end-aligned window); it picks an accumulator wide
    # enough that dense overlaps (small steps) cannot overflow.
    max_votes = (-(-patch_xy // step_xy) + 1) ** 2 * (-(-patch_z // step_z) + 1)
    vote_dtype = np.int16 if max_votes <= np.iinfo(np.int16).max else np.int32
    # Running (foreground votes - background votes) per voxel.
    votes = np.zeros(volume.shape, dtype=vote_dtype)

    with tqdm(total=len(starts_0)) as pbar:
        for i in starts_0:
            for j in starts_1:
                for k in starts_2:
                    window = (
                        slice(i, i + patch_xy),
                        slice(j, j + patch_xy),
                        slice(k, k + patch_z),
                    )
                    # Move axis 2 to the front and add a channel axis so the patch
                    # matches the (1, nbslices, patch, patch, 1) network input.
                    patch = volume[window].transpose(2, 0, 1)[..., np.newaxis]
                    B_real = np.zeros((1, nbslices, patch_xy, patch_xy, 1), dtype='float32')
                    B_real[0, :] = (patch - 127.5) / 127.5  # [0, 255] -> [-1, 1]
                    prediction = model.predict(B_real)
                    prediction = (prediction + 1) / 2  # from [-1, 1] to [0, 1]
                    prediction = prediction[0, :, :, :, 0].transpose(1, 2, 0)
                    is_foreground = prediction > 0.5
                    votes[window] += np.where(is_foreground, 1, -1).astype(vote_dtype)
            pbar.update(1)
    return votes >= 0


def _close_mask(mask):
    """Fill small holes in a 0/1 mask by morphological closing; return a 0/255 uint8 array."""
    footprint = draw.ellipsoid(9, 9, 3, spacing=(1, 1, 1), levelset=False).astype('uint8')
    # Drop the outermost one-voxel shell of the ellipsoid grid
    # (presumably the border padding skimage adds -- TODO confirm).
    footprint = footprint[1:-1, 1:-1, 1:-1]
    dilated = binary_dilation(mask, footprint)
    closed = binary_erosion(dilated, footprint)
    return (closed * 255.0).astype('uint8')


def predict_mask(image, dim_x, dim_y, dim_z, _resize=True, norm_=True, mode_='test',
                 patch_size=(64, 128, 128, 1), _step=64, _step_z=32, _patch_size_z=64):
    """Segment a 3D volume with the pretrained 3D CycleGAN generator.

    The generator ``CycleGANVesselSegmentation.h5`` is downloaded from the
    Hugging Face Hub repo ``Hemaxi/3DCycleGAN`` and applied over a sliding
    window. Overlapping patch predictions are combined by majority vote,
    the mask is brought back to the input's shape, and small holes are
    filled with a morphological closing.

    Args:
        image: 3D intensity volume (only its first three axes are used for
            the output shape).
        dim_x, dim_y, dim_z: voxel size of ``image`` along axes 0, 1 and 2;
            only used when ``_resize`` is true. Presumably in the same units
            as ``_MODEL_SPACING`` (likely micrometres) -- TODO confirm.
        _resize: resample to ``_MODEL_SPACING`` before inference and back
            to the original grid afterwards.
        norm_: clip to the 1st/99th intensity percentiles and stretch to 0..255.
        mode_: only printed.
        patch_size: ``patch_size[0]`` is the number of slices fed to the
            network and ``patch_size[1]`` the window edge along axes 0 and 1;
            the other entries are unused.
        _step: window stride along axes 0 and 1.
        _step_z: window stride along axis 2.
        _patch_size_z: window depth along axis 2. Should equal
            ``patch_size[0]``, since each patch is copied into a network
            input with that many slices.

    Returns:
        uint8 array of shape ``image.shape[:3]``: 255 for foreground, 0 otherwise.

    Raises:
        ValueError: if ``_step`` or ``_step_z`` is not positive (the sliding
            window would never advance).
    """
    if _step <= 0 or _step_z <= 0:
        raise ValueError(
            '_step and _step_z must be positive, got {} and {}'.format(_step, _step_z))

    custom_objects = {'InstanceNormalization': tfa.layers.InstanceNormalization}
    # Download the model from the Hugging Face Model Hub and load it.
    model_path = hf_hub_download(repo_id="Hemaxi/3DCycleGAN",
                                 filename="CycleGANVesselSegmentation.h5")
    model_BtoA = load_model(model_path, custom_objects)

    print('Mode: {}'.format(mode_))
    patch_xy = patch_size[1]
    nbslices = patch_size[0]

    image = _rescale_to_uint8(image)
    print('Image Shape: {}'.format(image.shape))
    print('----------------------------------------')
    initial_shape = image.shape[:3]

    if norm_:
        image = _percentile_equalize(image, 1, 99)

    voxel_size = (dim_x, dim_y, dim_z)
    if _resize:
        to_model_grid = tuple(dim / target for dim, target in zip(voxel_size, _MODEL_SPACING))
        image = zoom(image, to_model_grid, order=3, mode='nearest')
        image = _rescale_to_uint8(image)
    resized_shape = image.shape

    padded = _pad_with_noise(image, patch_xy, _patch_size_z)
    foreground = _predict_foreground(model_BtoA, padded, patch_xy, nbslices,
                                     _patch_size_z, _step, _step_z)
    del padded

    print('Image Shape: {}'.format(resized_shape))
    print('----------------------------------------')
    mask = foreground[:resized_shape[0], :resized_shape[1], :resized_shape[2]]

    if _resize:
        # Resample the binary mask in floating point and threshold at 0.5, so the
        # result does not depend on how spline output is cast to an integer dtype.
        to_input_grid = tuple(target / dim for dim, target in zip(voxel_size, _MODEL_SPACING))
        mask = zoom(mask.astype('float32'), to_input_grid, order=3, mode='nearest') > 0.5

    # Resampling can be off by a voxel; crop / zero-pad back to the input shape.
    restored = np.zeros(initial_shape, dtype='uint8')
    overlap = tuple(slice(0, min(a, b)) for a, b in zip(initial_shape, mask.shape))
    restored[overlap] = mask[overlap]
    print('Mask Shape: {}'.format(restored.shape))
    print('----------------------------------------')

    # Closing operation to fill small holes.
    return _close_mask(restored)