File size: 9,811 Bytes
9b3f85e
 
 
 
 
e07d9b3
9b3f85e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64c0095
 
9b3f85e
 
 
66cb0dd
 
 
 
49c438a
66cb0dd
 
 
dbef072
66cb0dd
7696dbe
 
 
 
406ea5e
6ac0651
 
 
 
 
 
7696dbe
4514277
 
 
5ba6bf7
 
91bb663
38c9666
 
7696dbe
4514277
 
 
7696dbe
4514277
7696dbe
4514277
7696dbe
 
 
 
 
 
 
 
dfd9053
4514277
66cb0dd
 
 
 
49c438a
66cb0dd
 
 
 
 
 
 
 
 
2279c2d
 
49c438a
2279c2d
 
 
 
 
 
 
 
 
7696dbe
9d30896
 
9b3f85e
 
7257855
 
9b3f85e
 
 
 
 
 
 
 
 
 
c81fbdf
 
 
 
 
 
9b3f85e
 
 
c81fbdf
 
9b3f85e
 
 
 
a5eabc6
9b3f85e
 
5909bfa
9b3f85e
 
 
 
 
 
 
7c66b75
23100d7
49fb715
23100d7
 
 
 
b35f5de
49fb715
23100d7
9b3f85e
 
96ed4cd
 
 
 
 
 
 
 
 
 
5fb6975
9d30896
27e8365
0cf944c
27e8365
 
 
 
49c438a
27e8365
96ed4cd
27e8365
 
 
96ed4cd
 
c803c34
96ed4cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34ce820
96ed4cd
 
 
 
 
 
 
 
 
 
 
 
9b3f85e
 
 
 
c803c34
bc50c42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cf944c
 
7d30024
0cf944c
 
 
 
bc50c42
 
fc3f34a
bc50c42
c81fbdf
 
 
fc3f34a
c81fbdf
 
fc3f34a
c81fbdf
 
 
9b3f85e
 
 
66cb0dd
 
 
 
49c438a
66cb0dd
 
 
 
 
 
 
 
 
9b3f85e
 
c81fbdf
 
9b3f85e
 
 
f7b6ec5
9b3f85e
 
 
 
 
613784c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
import pathlib

import gradio as gr
import torch

from transformers import AutoFeatureExtractor, DetrForObjectDetection
from visualization import visualize_attention_map, visualize_prediction
from style import css, description, title

from PIL import Image



def make_prediction(img, feature_extractor, model):
    """Run one forward pass of the detector on *img* and post-process it.

    Args:
        img: Input image. Its ``size`` attribute is read as ``(width, height)``
            (PIL convention; the UI creates it with ``gr.Image(type="pil")``)
            and reversed to ``(height, width)`` for post-processing.
        feature_extractor: Callable preprocessor that also exposes
            ``post_process(outputs, target_sizes)``; in this app it comes from
            ``AutoFeatureExtractor.from_pretrained``.
        model: Detection model called with the preprocessed inputs; in this
            app a ``DetrForObjectDetection``.

    Returns:
        A 3-tuple ``(detections, decoder_attentions, encoder_attentions)``:
        the post-processed result for the single input image, followed by the
        attention outputs taken verbatim from the model.
    """
    inputs = feature_extractor(img, return_tensors="pt")
    # Inference only: disable autograd so no computation graph / activations
    # are retained for every request.
    with torch.no_grad():
        # Ask for attentions explicitly so the keys read below exist even when
        # the checkpoint's config does not enable ``output_attentions``.
        outputs = model(**inputs, output_attentions=True)
    # PIL reports (width, height); post_process expects (height, width).
    target_sizes = torch.tensor([tuple(reversed(img.size))])
    processed_outputs = feature_extractor.post_process(outputs, target_sizes)
    return (
        processed_outputs[0],
        outputs["decoder_attentions"],
        outputs["encoder_attentions"],
    )


# ResNet-101 checkpoints the app can load, keyed by the number of
# encoder/decoder layers. For this backbone the repo id is taken from this
# table alone; the other hyper-parameters are not part of the id.
_R101_REPOS = {
    6: "polejowska/detr-r101-official",
    4: "polejowska/detr-r101-cd45rb-8ah-4l",
    12: "polejowska/detr-r101-cd45rb-8ah-12l",
}


def construct_model_name(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation=None
):
    """Build the Hugging Face Hub repo id of a DETR checkpoint from UI settings.

    Args:
        experiment_type: Accepted for call-site symmetry; it does not affect
            the returned id.
        convbase: ``"RESNET-50"`` or ``"RESNET-101"``.
        attention_heads_num: Number of attention heads (``-{n}ah`` segment).
        enc_dec_layers: Number of encoder/decoder layers (``-{n}l`` segment).
        ffn_dim: Feed-forward dimension; 1024 and 4096 add a suffix, any other
            value adds nothing.
        act_func: ``"GeLU"`` adds a suffix; anything else adds nothing.
        d_model: Hidden size; any value other than 256 adds ``-{d_model}d``.
        dilation: ``"True"`` (as produced by the UI dropdown) or ``True``
            selects the dilation variant; any other value does not.

    Returns:
        The repo id, e.g. ``"polejowska/detr-r50-cd45rb-8ah-6l"``.

    Raises:
        ValueError: If ``convbase`` is not a supported backbone, or if a
            ResNet-101 model is requested with a layer count that has no
            checkpoint in ``_R101_REPOS``. Without this check a malformed id
            such as ``"polejowska/-cd45rb-..."`` would be produced.
    """
    if convbase == "RESNET-101":
        try:
            return _R101_REPOS[enc_dec_layers]
        except KeyError:
            raise ValueError(
                f"No RESNET-101 checkpoint for {enc_dec_layers!r} "
                f"encoder/decoder layers; expected one of {sorted(_R101_REPOS)}"
            ) from None
    if convbase != "RESNET-50":
        raise ValueError(f"Unsupported convolutional backbone: {convbase!r}")

    parts = [
        "polejowska/detr-r50",
        "-cd45rb",
        f"-{attention_heads_num}ah",
        f"-{enc_dec_layers}l",
    ]
    if attention_heads_num == 1:
        parts.append("-corrected")
    if d_model != 256:
        parts.append(f"-{d_model}d")
    if ffn_dim == 1024:
        parts.append("-1024ffn")
    elif ffn_dim == 4096:
        # The "correcetd" spelling is kept verbatim: it is part of the repo id
        # this code has always requested, so "fixing" it would point elsewhere.
        parts.append("-4096ffn-correcetd")
    if act_func == "GeLU":
        parts.append("-gelu-corrected")
    # The UI dropdown yields the string "True"; a real bool is accepted too.
    if dilation is True or dilation == "True":
        parts.append("-dilation-corrected")
    return "".join(parts)


# Fixed checkpoints for the "Reproducability check (N)" experiments. The keys
# must match the experiment-type dropdown choices byte-for-byte (including the
# "Reproducability" spelling used in the UI).
_REPRODUCIBILITY_REPOS = {
    "Reproducability check (1)": "polejowska/detr-r50-cd45rb-all-2ah",
    "Reproducability check (2)": "polejowska/detr-r50-cd45rb-all-4ah",
    "Reproducability check (3)": "polejowska/detr-r50-cd45rb-all-8ah",
    "Reproducability check (4)": "polejowska/detr-r50-cd45rb-all-16ah",
}


def detect_objects(
    experiment_type,
    convbase,
    attention_heads_num,
    enc_dec_layers,
    ffn_dim,
    act_func,
    d_model,
    dilation,
    image_input,
    threshold=0.7,
    display_mask=False,
    img_input_mask=None
):
    """Gradio click handler: pick a checkpoint, run detection, render results.

    For ``"Parameters verification"`` the checkpoint id is derived from the
    hyper-parameter arguments via ``construct_model_name``. For the
    reproducibility experiments a fixed checkpoint is looked up in
    ``_REPRODUCIBILITY_REPOS`` and the hyper-parameter arguments are ignored.

    Args:
        experiment_type: One of the experiment-type dropdown choices.
        convbase, attention_heads_num, enc_dec_layers, ffn_dim, act_func,
        d_model, dilation: Forwarded to ``construct_model_name``.
        image_input: The input image (``gr.Image(type="pil")`` in the UI).
        threshold: Forwarded to ``visualize_prediction`` as ``threshold``
            (presumably the minimum detection score to draw — confirm there).
        display_mask: Forwarded to ``visualize_prediction``.
        img_input_mask: Forwarded to ``visualize_prediction`` as ``mask``.

    Returns:
        ``(prediction_image, decoder_attention_image, encoder_attention_image)``
        as produced by the visualization helpers.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message.
        ValueError: If ``experiment_type`` is not a known experiment.
    """
    if image_input is None:
        raise gr.Error("Please upload or select an image first.")

    if experiment_type == "Parameters verification":
        model_repo = construct_model_name(
            experiment_type, convbase, attention_heads_num, enc_dec_layers,
            ffn_dim, act_func, d_model, dilation,
        )
    elif experiment_type in _REPRODUCIBILITY_REPOS:
        model_repo = _REPRODUCIBILITY_REPOS[experiment_type]
    else:
        # Without this branch ``model_repo`` would be unbound below.
        raise ValueError(f"Unknown experiment type: {experiment_type!r}")

    model = DetrForObjectDetection.from_pretrained(model_repo)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_repo)

    (
        processed_outputs,
        decoder_attention_map,
        encoder_attention_map,
    ) = make_prediction(image_input, feature_extractor, model)

    viz_img = visualize_prediction(
        pil_img=image_input,
        output_dict=processed_outputs,
        threshold=threshold,
        id2label=model.config.id2label,
        display_mask=display_mask,
        mask=img_input_mask
    )
    decoder_attention_map_img = visualize_attention_map(
        image_input, decoder_attention_map
    )
    encoder_attention_map_img = visualize_attention_map(
        image_input, encoder_attention_map
    )

    return (
        viz_img,
        decoder_attention_map_img,
        encoder_attention_map_img,
    )


def set_example_image(example: list):
    """Turn a clicked ``gr.Dataset`` sample into values for the two image inputs.

    Args:
        example: The selected sample; element 0 is the image path and
            element 1 is the matching mask path.

    Returns:
        A pair of ``gr.Image`` objects: the image first, then the mask.
    """
    image_path = example[0]
    mask_path = example[1]
    return gr.Image(value=image_path), gr.Image(value=mask_path)


def _example_samples(root="cd45rb_test_imgs"):
    """Return ``[image_path, mask_path]`` pairs for the bundled example images.

    Every ``*_HE.png`` file found recursively under *root* (sorted by path) is
    paired with the file in the same directory whose name has the trailing
    ``_HE.png`` replaced by ``_mask.png``. Only the file name is rewritten, so
    a directory whose name contains ``_HE`` is left intact.
    """
    suffix = "_HE.png"
    return [
        [
            path.as_posix(),
            path.with_name(path.name[: -len(suffix)] + "_mask.png").as_posix(),
        ]
        for path in sorted(pathlib.Path(root).rglob(f"*{suffix}"))
    ]


with gr.Blocks(css=css) as app:
    gr.Markdown(title)

    with gr.Tabs():
        with gr.TabItem("Image upload and detections visualization"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        # These choice strings are matched verbatim by
                        # detect_objects; keep both in sync.
                        experiment_type = gr.Dropdown(
                            value="Parameters verification",
                            choices=[
                                "Parameters verification",
                                "Reproducability check (1)",
                                "Reproducability check (2)",
                                "Reproducability check (3)",
                                "Reproducability check (4)",
                            ],
                            label="Select an experiment type",
                            show_label=True,
                        )
                    with gr.Row():
                        convbase = gr.Dropdown(
                            value="RESNET-50",
                            choices=[
                                "RESNET-50",
                                "RESNET-101",
                            ],
                            label="Select a base model for convolution part",
                            show_label=True,
                        )
                    with gr.Row():
                        attention_heads_num = gr.Dropdown(
                            value=8,
                            choices=[1, 2, 4, 8, 16],
                            label="The number of attention heads in encoder and decoder",
                            show_label=True,
                        )
                    with gr.Row():
                        enc_dec_layers = gr.Dropdown(
                            value=6,
                            choices=[4, 6, 12],
                            label="The number of layers in encoder and decoder",
                            show_label=True,
                        )
                    with gr.Row():
                        ffn_dim = gr.Dropdown(
                            value=2048,
                            choices=[1024, 2048, 4096],
                            label="Select FFN dimension",
                            show_label=True,
                        )
                    with gr.Row():
                        act_func = gr.Dropdown(
                            value="ReLU",
                            choices=[
                                "ReLU",
                                "GeLU",
                            ],
                            label="Select an activation function",
                            show_label=True,
                        )
                    with gr.Row():
                        d_model = gr.Dropdown(
                            value=256,
                            choices=[128, 256, 512],
                            label="Select a hidden size",
                            show_label=True,
                        )
                    with gr.Row():
                        # String choices: construct_model_name compares
                        # against the string "True".
                        dilation = gr.Dropdown(
                            value="False",
                            choices=[
                                "True",
                                "False",
                            ],
                            label="Use dilation",
                            show_label=True,
                        )
                    with gr.Row():
                        slider_input = gr.Slider(
                            minimum=0.2, maximum=1, value=0.7, label="Prediction threshold"
                        )

                with gr.Column():
                    with gr.Row():
                        img_input = gr.Image(type="pil")
                    # Hidden: only filled from the example dataset so the
                    # detection handler can receive the matching mask.
                    img_input_mask = gr.Image(type="pil", visible=False)
                    with gr.Row():
                        example_images = gr.Dataset(
                            components=[img_input, img_input_mask],
                            samples=_example_samples(),
                            samples_per_page=2,
                        )
                    with gr.Row():
                        display_mask = gr.Checkbox(
                            label="Display masks",
                        )
                    with gr.Row():
                        detect_button = gr.Button("Detect leukocytes")

                    with gr.Row():
                        with gr.Column():
                            img_output_from_upload = gr.Image(width=900, height=900)

        with gr.TabItem("Attentions visualization"):
            gr.Markdown("""Encoder attentions""")
            with gr.Row():
                encoder_att_map_output = gr.Image(width=850, height=850)
            gr.Markdown("""Decoder attentions""")
            with gr.Row():
                decoder_att_map_output = gr.Image(width=850, height=850)
        with gr.TabItem("Dataset details"):
            with gr.Row():
                gr.Markdown(description)

    # Input order must match detect_objects' positional parameters, and output
    # order must match its return tuple.
    detect_button.click(
        detect_objects,
        inputs=[
            experiment_type,
            convbase,
            attention_heads_num,
            enc_dec_layers,
            ffn_dim,
            act_func,
            d_model,
            dilation,
            img_input,
            slider_input,
            display_mask,
            img_input_mask
        ],
        outputs=[
            img_output_from_upload,
            decoder_att_map_output,
            encoder_att_map_output,
        ],
        queue=True,
    )

    example_images.click(
        fn=set_example_image, inputs=[example_images], outputs=[img_input, img_input_mask],
        show_progress=True
    )

# Launch after the Blocks context has closed so the layout is fully finalized.
app.launch()