import gradio as gr
import numpy as np
import onnxruntime
import torch
from torch.nn import functional as F

from transforms import ResizeLongestSide

IMAGE_SIZE = 1024  # SAM's image encoder expects 1024x1024 inputs

def preprocess_image(image):
    """Resize, normalize, and pad the image to the 1024x1024 input expected by the SAM encoder."""
    transform = ResizeLongestSide(IMAGE_SIZE)
    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device="cpu")
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

    # Normalize with the pixel mean/std SAM was trained with
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    x = (input_image_torch - pixel_mean) / pixel_std

    # Pad the bottom/right so the tensor is IMAGE_SIZE x IMAGE_SIZE
    h, w = x.shape[-2:]
    padh = IMAGE_SIZE - h
    padw = IMAGE_SIZE - w
    x = F.pad(x, (0, padw, 0, padh))
    return x.numpy()

def prepare_inputs(image_embedding, input_point, image_shape):
    """Build the input dict for the SAM ONNX decoder from a single foreground point."""
    transform = ResizeLongestSide(IMAGE_SIZE)

    # Label 1 marks a foreground point; the extra (0, 0) point with label -1 is the
    # padding point the exported decoder expects.
    input_label = np.array([1])
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

    # Map the point coordinates into the resized 1024-pixel coordinate space
    onnx_coord = transform.apply_coords(onnx_coord, image_shape).astype(np.float32)

    # No previous mask is supplied, so pass an empty low-res mask and has_mask_input = 0
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    decoder_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image_shape, dtype=np.float32),
    }
    return decoder_inputs

# Quantized SAM encoder and decoder exported to ONNX
enc_session = onnxruntime.InferenceSession("encoder-quant.onnx")
dec_session = onnxruntime.InferenceSession("decoder-quant.onnx")
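# Note (assumption, not from the original app): InferenceSession also accepts an explicit
# provider list to pin the backend, e.g.
# onnxruntime.InferenceSession("encoder-quant.onnx", providers=["CPUExecutionProvider"]).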

def predict_image(img):
    """Run the encoder once on the uploaded image and segment at the image centre."""
    x = preprocess_image(img)

    encoder_inputs = {
        "x": x,
    }

    # The encoder is the expensive step; its embedding is cached in gr.State and reused
    # for every subsequent click.
    output = enc_session.run(None, encoder_inputs)
    image_embedding = output[0]

    # Use the centre of the photo as the initial prompt point
    middle_of_photo = np.array([[img.shape[1] / 2, img.shape[0] / 2]])

    decoder_inputs = prepare_inputs(image_embedding, middle_of_photo, img.shape[:2])
    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)

    # Clip negative logits and scale to [0, 1] so the mask displays as an image
    masks = masks[0][0]
    masks[masks < 0] = 0
    masks = masks / np.max(masks)
    return masks, image_embedding, img.shape[:2]

def segment_image(image_embedding, shape, evt: gr.SelectData):
    """Re-run only the decoder, using the clicked pixel as the prompt point."""
    image_embedding = np.array(image_embedding)
    # evt.index holds the (x, y) pixel that was clicked on the output image
    clicked_point = np.array([evt.index])
    decoder_inputs = prepare_inputs(image_embedding, clicked_point, shape)
    masks, _, low_res_logits = dec_session.run(None, decoder_inputs)

    # Clip negative logits and scale to [0, 1] so the mask displays as an image
    masks = masks[0][0]
    masks[masks < 0] = 0
    masks = masks / np.max(masks)
    return masks

with gr.Blocks() as demo:
    gr.Markdown("# SAM quantized (Segment Anything Model)")
    markdown = """
    This is a demo of the SAM model, which is a model for segmenting anything in an image. 
    It returns segmentation mask of the image that's overlapping with the clicked point.

    The model is quantized using ONNX Runtime
    """

    gr.Markdown(markdown)

    # Cache the encoder output and original image shape between click events
    embedding = gr.State()
    shape = gr.State()
    with gr.Row():
        with gr.Column():
            inputs = gr.Image()
            start_segmentation = gr.Button("Segment")

        with gr.Column():
            outputs = gr.Image(label="Segmentation Mask")

    # "Segment" runs the full pipeline: encoder, then decoder prompted at the image centre
    start_segmentation.click(
        predict_image,
        inputs,
        [outputs, embedding, shape],
    )

    # Clicking on the mask re-runs only the decoder at the clicked point
    outputs.select(
        segment_image,
        [embedding, shape],
        outputs,
    )

demo.launch()
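# Note (assumption, not part of the original app): demo.launch(share=True) would create a
# temporary public link, and demo.launch(server_name="0.0.0.0") would listen on all interfaces.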