Spaces:
Sleeping
Sleeping
import os | |
import numpy as np | |
import skimage.transform as trans | |
from skimage.color import rgb2gray | |
from unet.unet import unet | |
from unet.unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup, UNet_3Plus_DeepSup_CGM | |
def _to_rgb(image):
    """Return *image* as an ``(H, W, 3)`` array.

    Grayscale images (2-D, or 3-D with fewer than three channels) have their
    first channel replicated three times; channels beyond the third (e.g. an
    alpha channel) are dropped. RGB input passes through with unchanged
    values. This gives ``rgb2gray`` and the colour overlay the 3-channel
    layout they both require.

    Raises:
        ValueError: if *image* is neither 2-D nor 3-D.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[..., np.newaxis]
    elif image.ndim != 3:
        raise ValueError(
            f"expected a 2-D or 3-D image array, got shape {image.shape}"
        )
    if image.shape[-1] < 3:
        # Grayscale, optionally with alpha: replicate the intensity channel.
        return np.repeat(image[..., :1], 3, axis=-1)
    return image[..., :3]


def _load_model(version, input_shape, output_channels, model_path="weights"):
    """Build U-Net variant *version* ('v0'..'v3') and load its pretrained weights.

    Weights are read from ``<model_path>/unet<version>_sgd500_neptune.hdf5``.
    """
    model_file = os.path.join(model_path, f"unet{version}_sgd500_neptune.hdf5")
    if version == "v0":
        # The plain U-Net builder is called with only the weights path.
        return unet(model_file)
    # The UNet 3+ builders also take the input shape and output channel count.
    architectures = {
        "v1": UNet_3Plus,
        "v2": UNet_3Plus_DeepSup,
        "v3": UNet_3Plus_DeepSup_CGM,
    }
    return architectures[version](input_shape, output_channels, model_file)


def predict_model(input, unet_type):
    """Segment *input* with a pretrained U-Net variant and overlay the mask.

    Args:
        input: Image array accepted by ``skimage.transform.resize``. 2-D
            grayscale, single-channel, RGB and RGBA layouts are accepted
            (see ``_to_rgb``).
        unet_type: ``'v0'`` (U-Net), ``'v1'`` (UNet 3+), ``'v2'`` (UNet 3+
            with deep supervision) or ``'v3'`` (UNet 3+ with CGM). Any other
            value falls back to ``'v3'``.

    Returns:
        The input resized to 256x256 as a float RGB array. Pixels whose
        predicted value is >= 0.5 are painted blue (``[0, 0, 1]``).

    Raises:
        ValueError: if *input* is neither 2-D nor 3-D.
    """
    h, w = 256, 256
    input_shape = [h, w, 1]
    output_channels = 1

    # skimage's resize takes (rows, cols). It also converts integer images
    # to floats in [0, 1], so the overlay colour below uses that scale.
    resized = trans.resize(_to_rgb(input), (h, w))
    # Copy so that painting the mask can never write into the caller's array.
    overlay = resized.copy()
    batch = rgb2gray(resized).reshape(1, h, w, 1)

    # Normalise once so the model that is loaded and the way its output is
    # read always agree (unknown types previously loaded v3 but skipped v3
    # post-processing).
    version = unet_type if unet_type in ("v0", "v1", "v2", "v3") else "v3"
    model = _load_model(version, input_shape, output_channels)
    results = model.predict(batch)

    # The deep-supervision variants return several outputs; the first one is
    # used. Presumably each output is shaped (batch, h, w, channels); the
    # [0, :, :, 0] indexing relies on that.
    probs = results[0] if version in ("v2", "v3") else results
    mask = np.asarray(probs)[0, :, :, 0] >= 0.5

    overlay[mask] = [0.0, 0.0, 1.0]
    return overlay