# fatchecker / predict_unet.py
# Last commit 89eead3 by bumble-bee: changed model path in predict_unet
import os
import numpy as np
import skimage.transform as trans
from skimage.color import rgb2gray
from unet.unet import unet
from unet.unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup, UNet_3Plus_DeepSup_CGM
# Weight-file stem for each supported U-Net variant. Any unrecognised
# ``unet_type`` falls back to "v3", as the original if/else chain did.
_UNET_WEIGHTS = {
    "v0": "unetv0_sgd500_neptune",
    "v1": "unetv1_sgd500_neptune",
    "v2": "unetv2_sgd500_neptune",
    "v3": "unetv3_sgd500_neptune",
}

# Variants whose ``predict`` output is indexed with ``[0]`` before use
# (deep-supervision models); the first entry is taken as the final mask.
_DEEP_SUPERVISION_TYPES = ("v2", "v3")

# Overlay colour (pure blue). ``resize`` returns float images, in [0, 1]
# for integer inputs, so the colour is expressed on that scale rather than
# as 0-255.
_SEG_COLOR = (0.0, 0.0, 1.0)


def _as_rgb(image):
    """Return ``image`` as an (H, W, 3) array.

    Grayscale images ((H, W) or (H, W, 1)) are replicated across three
    channels; a fourth (alpha) channel, if present, is dropped.
    """
    if image.ndim == 2:
        return np.stack([image] * 3, axis=-1)
    if image.ndim == 3 and image.shape[-1] == 1:
        return np.repeat(image, 3, axis=-1)
    if image.ndim == 3 and image.shape[-1] == 4:
        return image[..., :3]
    return image


def _load_model(unet_type, model_path, input_shape, output_channels):
    """Build the U-Net variant ``unet_type`` with its pretrained weights.

    ``unet_type`` must already be one of the keys of ``_UNET_WEIGHTS``.
    """
    # NOTE(review): a new model is constructed from its weight file on every
    # call to predict_model; consider caching per unet_type if this is
    # called repeatedly (e.g. from a web handler).
    model_file = os.path.join(model_path, _UNET_WEIGHTS[unet_type] + ".hdf5")
    if unet_type == "v0":
        return unet(model_file)
    builders = {
        "v1": UNet_3Plus,
        "v2": UNet_3Plus_DeepSup,
        "v3": UNet_3Plus_DeepSup_CGM,
    }
    return builders[unet_type](input_shape, output_channels, model_file)


def predict_model(input, unet_type, *, model_path="weights", threshold=0.5):
    """Segment an image with a pretrained U-Net and overlay the mask in blue.

    Args:
        input: Image array accepted by ``skimage.transform.resize``: RGB
            (H, W, 3), RGBA (H, W, 4; alpha is dropped) or grayscale
            (H, W) / (H, W, 1). (The name shadows the builtin but is kept
            for backward compatibility with keyword callers.)
        unet_type: Model variant -- 'v0' U-Net, 'v1' UNet3+, 'v2' UNet3+
            with deep supervision, 'v3' UNet3+ with CGM. Any other value
            falls back to 'v3'.
        model_path: Directory containing the ``<name>.hdf5`` weight files.
        threshold: Predicted values ``>= threshold`` are labelled
            foreground.

    Returns:
        A (256, 256, 3) float image -- the resized input (values in [0, 1]
        for integer inputs) with foreground pixels set to ``_SEG_COLOR``.
    """
    h, w = 256, 256
    input_shape = [h, w, 1]
    output_channels = 1

    # resize takes (rows, cols) and, with preserve_range=False (default),
    # converts integer images to float in [0, 1].
    rgb = _as_rgb(trans.resize(input, (h, w)))
    # The model consumes a batch of one single-channel image: (1, h, w, 1).
    gray = rgb2gray(rgb).reshape(1, h, w, 1)

    # Normalise once so model choice and output handling always agree.
    if unet_type not in _UNET_WEIGHTS:
        unet_type = "v3"
    model = _load_model(unet_type, model_path, input_shape, output_channels)

    results = model.predict(gray)
    if unet_type in _DEEP_SUPERVISION_TYPES:
        results = results[0]
    mask = np.asarray(results)[0, :, :, 0] >= threshold

    overlay = rgb.copy()
    overlay[mask] = _SEG_COLOR
    return overlay