File size: 2,139 Bytes
3368fe8
 
 
 
 
6bf4d42
 
3368fe8
 
 
89eead3
3368fe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

import os
import numpy as np
import skimage.transform as trans
from skimage.color import rgb2gray
from unet.unet import unet
from unet.unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup, UNet_3Plus_DeepSup_CGM


def predict_model(input, unet_type):
    """Segment an image with a pretrained U-Net variant and overlay the mask.

    The image is resized to 256x256, converted to grayscale and passed through
    the model selected by ``unet_type``. Pixels whose predicted value is at
    least 0.5 are painted blue on the resized colour copy of the image.

    Args:
        input: Image array accepted by ``skimage.transform.resize``:
            (H, W) or (H, W, 1) grayscale, (H, W, 3) RGB, or (H, W, 4) RGBA
            (the alpha channel is dropped).
        unet_type: 'v0' (U-Net), 'v1' (UNet 3+), 'v2' (UNet 3+ with deep
            supervision) or 'v3' (UNet 3+ with CGM). Any other value falls
            back to the 'v3' model.

    Returns:
        A (256, 256, 3) float array: the resized image (values in [0, 1] for
        integer-typed input) with segmented pixels set to pure blue.
    """
    model_path = "weights"
    h, w = 256, 256
    input_shape = [h, w, 1]
    output_channels = 1
    threshold = 0.5

    # skimage's output_shape is (rows, cols), i.e. (height, width).
    img = trans.resize(input, (h, w))

    # Normalise to a 3-channel image for the overlay plus a grayscale plane
    # for the model, so grayscale and RGBA inputs work as well as RGB.
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    if img.ndim == 2:
        gray = img
        result_img = np.stack([img] * 3, axis=-1)
    else:
        result_img = img[..., :3]  # drop alpha; rgb2gray needs exactly 3 channels
        gray = rgb2gray(result_img)
    batch = gray.reshape(1, h, w, 1)

    # Select and load the pretrained model; unknown types fall back to v3.
    version = unet_type if unet_type in ('v0', 'v1', 'v2', 'v3') else 'v3'
    model_file = os.path.join(model_path, f"unet{version}_sgd500_neptune.hdf5")
    if version == 'v0':
        model = unet(model_file)
    elif version == 'v1':
        model = UNet_3Plus(input_shape, output_channels, model_file)
    elif version == 'v2':
        model = UNet_3Plus_DeepSup(input_shape, output_channels, model_file)
    else:
        model = UNet_3Plus_DeepSup_CGM(input_shape, output_channels, model_file)

    results = model.predict(batch)

    # For the deep-supervision variants only element 0 of the prediction is
    # used (presumably the main segmentation head -- confirm against unet_3plus).
    if version in ('v2', 'v3'):
        results = results[0]

    # Binarise batch item 0, channel 0 of the prediction.
    mask = np.asarray(results)[0, :, :, 0] >= threshold

    # resize() yields a float image, so the overlay colour must be on the same
    # [0, 1] scale; 255 would be far out of range for float-image consumers.
    result_img[mask] = (0.0, 0.0, 1.0)

    return result_img