File size: 8,159 Bytes
5793f6d
 
2c80634
8092894
da0e039
 
e0df362
5793f6d
da0e039
8092894
 
2c80634
5793f6d
 
2c80634
da0e039
8092894
 
2c80634
5793f6d
2c80634
 
 
2826168
92e3db3
2c80634
 
da0e039
 
 
 
 
2c80634
 
8092894
 
 
 
 
 
 
 
2c80634
 
 
 
 
 
 
 
8092894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5793f6d
2c80634
 
 
 
8092894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c80634
8092894
2c80634
 
 
 
5793f6d
8092894
 
 
 
2c80634
 
8092894
 
 
2c80634
 
da0e039
 
 
 
 
 
92e3db3
 
 
 
da0e039
e0df362
da0e039
 
 
 
e0df362
 
 
 
 
c2780e1
e0df362
 
92e3db3
 
 
 
 
 
2c80634
 
 
 
92e3db3
 
 
 
8092894
 
 
 
2c80634
 
 
 
 
 
 
8092894
 
 
 
2c80634
92e3db3
 
 
 
da0e039
92e3db3
 
8092894
2826168
2c80634
5793f6d
 
2c80634
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import json
import os
import random
import time
import uuid
from pathlib import Path

import gradio as gr
import yaml
from huggingface_hub import CommitScheduler, HfApi

from src.utils import load_words, load_image_and_saliency, load_example_images
from src.style import css
from src.user import UserID

def main():
    """Build and launch the Gradio app for saliency evaluation, experiment 1.

    For each class of the configured dataset the participant sees a target
    image, its four saliency maps (Grad-CAM, LIME, SIDU, RISE) and 16 example
    images of the same class, and ranks the four maps with dropdowns.
    "Submit" stores the ranking and advances to the next class. On the last
    class only "Finish" is shown. It stores the final ranking, saves all
    answers, and then (only if saving succeeded) redirects the browser to
    the end page.

    Configuration is read from ``config/config.yaml``. The keys used are
    ``dataset.name``, ``dataset.path``, ``dataset.<name>.class_names``,
    ``dataset.<name>.n_classes`` and ``results.exp1_dir``. The optional
    ``results.repo_id`` key names a Hub dataset repo to upload results to.
    """
    # Use a context manager so the config file handle is closed.
    with open("config/config.yaml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    dataset_name = config['dataset']['name']
    dataset_cfg = config['dataset'][dataset_name]
    class_names = dataset_cfg['class_names']
    n_classes = dataset_cfg['n_classes']
    data_dir = os.path.join(config['dataset']['path'], dataset_name)

    # Saliency methods in display order; also used as the dropdown labels.
    methods = ['grad-cam', 'lime', 'sidu', 'rise']
    options = ['-', '1', '2', '3', '4']
    n_examples = 16
    end_page_url = "https://marcoparola.github.io/saliency-evaluation-app/end"

    def target_label(index):
        """Return the Markdown caption naming the class shown at *index*."""
        return f"Target image: **{class_names[index]}**"

    def is_last(index):
        """Return True when *index* is the last class (show Finish, hide Submit)."""
        return index >= n_classes - 1

    def load_view(index):
        """Return (main images, example images) for class *index*.

        Main images are the target followed by one map per entry of
        ``methods``, in the same order as ``load_image_and_saliency`` yields
        them.
        """
        saliency = load_image_and_saliency(index, data_dir)
        examples = load_example_images(index, data_dir)
        main_vals = [saliency[i] for i in range(len(methods) + 1)]
        example_vals = [examples[i] for i in range(n_examples)]
        return main_vals, example_vals

    def save_results(history):
        """Persist one participant's rankings.

        Appends a single JSON line to ``json_dataset/<results.exp1_dir>``.
        The record holds a random ``user_id``, a Unix ``timestamp``, and an
        ``answer`` object. ``answer`` maps each class index (a string key in
        JSON) to that class's ranking list. If ``results.repo_id`` is
        configured, the file is also uploaded to that Hub dataset repo.
        """
        record = {
            # A random id is unique per submission; a timestamp alone can
            # collide between concurrent participants.
            "user_id": uuid.uuid4().hex,
            "timestamp": time.time(),
            "answer": dict(enumerate(history)),
        }
        results_path = Path("json_dataset") / config['results']['exp1_dir']
        results_path.parent.mkdir(parents=True, exist_ok=True)
        with results_path.open("a", encoding="utf-8") as results_file:
            results_file.write(json.dumps(record) + "\n")

        # NOTE(review): `results.repo_id` is an optional config key; without
        # it, results are only kept on local disk.
        repo_id = config['results'].get('repo_id')
        if repo_id:
            HfApi().upload_file(
                path_or_fileobj=str(results_path),
                path_in_repo=results_path.name,
                repo_id=repo_id,
                repo_type="dataset",
            )

    with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
        gr.Markdown("# Saliency evaluation - experiment 1")
        # Per-session index of the class currently being ranked.
        user_state = gr.State(0)
        # Per-session list of rankings, one [grad-cam, lime, sidu, rise] entry per class.
        answers = gr.State([])

        first_main, first_examples = load_view(0)

        with gr.Row():
            target_img_label = gr.Markdown(target_label(0))
            gr.Markdown("Grad-cam")
            gr.Markdown("Lime")
            gr.Markdown("Sidu")
            gr.Markdown("Rise")

        with gr.Row():
            # Target image followed by the Grad-CAM, LIME, SIDU and RISE maps.
            main_images = [gr.Image(img, elem_classes="main-image") for img in first_main]

        with gr.Row():
            dropdowns = [gr.Dropdown(choices=options, label=method) for method in methods]

        gr.Markdown("### Image examples of the same class")
        with gr.Row():
            example_images = [gr.Image(img) for img in first_examples]

        # Visibility must be correct from the start (e.g. a single-class dataset).
        submit_button = gr.Button("Submit", visible=not is_last(0))
        finish_button = gr.Button("Finish", visible=is_last(0))

        def on_submit(*values):
            """Record the current ranking, advance to the next class, refresh the UI.

            One handler returns every output in a fixed list order. This
            avoids intermediate UI states and the arbitrary ordering of
            set-valued outputs.
            """
            *ranking, history, index = values
            history = history + [list(ranking)]
            index += 1
            last = is_last(index)
            if index < n_classes:
                main_vals, example_vals = load_view(index)
                label = target_label(index)
            else:
                # Defensive: Submit is hidden on the last class. If we get
                # here anyway, leave every image and the label unchanged.
                main_vals = [gr.update()] * len(main_images)
                example_vals = [gr.update()] * len(example_images)
                label = gr.update()
            return [
                history,
                index,
                label,
                gr.update(visible=not last),
                gr.update(visible=last),
                *main_vals,
                *example_vals,
                # Reset every dropdown to the neutral "-" choice.
                *([options[0]] * len(dropdowns)),
            ]

        submit_button.click(
            on_submit,
            inputs=[*dropdowns, answers, user_state],
            outputs=[
                answers,
                user_state,
                target_img_label,
                submit_button,
                finish_button,
                *main_images,
                *example_images,
                *dropdowns,
            ],
        )

        def on_finish(*values):
            """Record the final ranking and save all of this session's answers."""
            *ranking, history = values
            history = history + [list(ranking)]
            save_results(history)
            return history

        # `.success` redirects only if saving did not raise, so a failed save
        # is reported in the UI instead of silently losing the answers.
        finish_button.click(
            on_finish,
            inputs=[*dropdowns, answers],
            outputs=answers,
        ).success(
            None,
            js=f"() => {{ window.location.href = '{end_page_url}'; }}",
        )

    demo.launch()

if __name__ == "__main__":
    # Executed only when run as a script (not on import): build and serve the app.
    main()