MarcoParola commited on
Commit
560fd79
·
1 Parent(s): d9deb23

redesign the app in two-step questions: 1. show images, 2. ask the question

Browse files
Files changed (3) hide show
  1. app.py +123 -62
  2. data/intel_image/concepts_by_class.csv +7 -0
  3. src/utils.py +4 -1
app.py CHANGED
@@ -7,7 +7,7 @@ import time
7
  from pathlib import Path
8
  from huggingface_hub import CommitScheduler, HfApi
9
 
10
- from src.utils import load_words, load_image_and_saliency, load_example_images
11
  from src.style import css
12
  from src.user import UserID
13
 
@@ -31,29 +31,9 @@ def main():
31
  #user_id = gr.State(load_global_variable())
32
  answers = gr.State([])
33
 
34
- with gr.Row():
35
- target_img_label = gr.Markdown(f"Target image: **{class_names[user_state.value]}**")
36
- gr.Markdown("Grad-cam")
37
- gr.Markdown("Lime")
38
- gr.Markdown("Sidu")
39
- gr.Markdown("Rise")
40
 
41
- with gr.Row():
42
- count = user_state if isinstance(user_state, int) else user_state.value
43
- images = load_image_and_saliency(count, data_dir)
44
- target_img = gr.Image(images[0], elem_classes="main-image")
45
- saliency_gradcam = gr.Image(images[1], elem_classes="main-image")
46
- saliency_lime = gr.Image(images[2], elem_classes="main-image")
47
- saliency_sidu = gr.Image(images[3], elem_classes="main-image")
48
- saliency_rise = gr.Image(images[4], elem_classes="main-image")
49
-
50
- with gr.Row():
51
- dropdown1 = gr.Dropdown(choices=options, label="grad-cam")
52
- dropdown2 = gr.Dropdown(choices=options, label="lime")
53
- dropdown3 = gr.Dropdown(choices=options, label="sidu")
54
- dropdown4 = gr.Dropdown(choices=options, label="rise")
55
-
56
- gr.Markdown("### Image examples of the same class")
57
  with gr.Row():
58
  count = user_state if isinstance(user_state, int) else user_state.value
59
  images = load_example_images(count, data_dir)
@@ -73,42 +53,75 @@ def main():
73
  img14 = gr.Image(images[13])
74
  img15 = gr.Image(images[14])
75
  img16 = gr.Image(images[15])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- submit_button = gr.Button("Submit")
 
 
 
 
 
 
 
 
 
 
78
  finish_button = gr.Button("Finish", visible=False)
79
 
80
- def update_images(dropdown1, dropdown2, dropdown3, dropdown4, user_state):
81
-
82
  count = user_state if isinstance(user_state, int) else user_state.value
83
  if count < config['dataset'][config['dataset']['name']]['n_classes']:
84
  images = load_image_and_saliency(count, data_dir)
85
- target_img = gr.Image(images[0], elem_classes="main-image")
86
- saliency_gradcam = gr.Image(images[1], elem_classes="main-image")
87
- saliency_lime = gr.Image(images[2], elem_classes="main-image")
88
- saliency_sidu = gr.Image(images[3], elem_classes="main-image")
89
- saliency_rise = gr.Image(images[4], elem_classes="main-image")
90
-
91
  # image examples
92
  images = load_example_images(count, data_dir)
93
- img1 = gr.Image(images[0])
94
- img2 = gr.Image(images[1])
95
- img3 = gr.Image(images[2])
96
- img4 = gr.Image(images[3])
97
- img5 = gr.Image(images[4])
98
- img6 = gr.Image(images[5])
99
- img7 = gr.Image(images[6])
100
- img8 = gr.Image(images[7])
101
- img9 = gr.Image(images[8])
102
- img10 = gr.Image(images[9])
103
- img11 = gr.Image(images[10])
104
- img12 = gr.Image(images[11])
105
- img13 = gr.Image(images[12])
106
- img14 = gr.Image(images[13])
107
- img15 = gr.Image(images[14])
108
- img16 = gr.Image(images[15])
109
- return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16
110
  else:
111
- return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def update_state(state):
114
  count = state if isinstance(state, int) else state.value
@@ -116,22 +129,49 @@ def main():
116
 
117
  def update_img_label(state):
118
  count = state if isinstance(state, int) else state.value
119
- return f"### Target image: {class_names[count]}"
 
 
 
 
 
120
 
121
- def update_buttons(state):
122
  count = state if isinstance(state, int) else state.value
123
  max_images = config['dataset'][config['dataset']['name']]['n_classes']
124
  finish_button = gr.Button("Finish", visible=(count == max_images-1))
125
  submit_button = gr.Button("Submit", visible=(count != max_images-1))
126
- return submit_button, finish_button
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  def update_dropdowns():
129
- dp1 = gr.Dropdown(choices=options, value=options[0], label="grad-cam")
130
- dp2 = gr.Dropdown(choices=options, value=options[0], label="lime")
131
- dp3 = gr.Dropdown(choices=options, value=options[0], label="sidu")
132
- dp4 = gr.Dropdown(choices=options, value=options[0], label="rise")
133
  return dp1, dp2, dp3, dp4
134
 
 
 
 
 
 
 
135
  def redirect():
136
  pass
137
 
@@ -145,7 +185,7 @@ def main():
145
  JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
146
  JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{uuid4()}.json"
147
  scheduler = CommitScheduler(
148
- repo_id="example-space-to-dataset-json",
149
  repo_type="dataset",
150
  folder_path=JSON_DATASET_DIR,
151
  path_in_repo="data",
@@ -196,19 +236,40 @@ def main():
196
  update_img_label,
197
  inputs=user_state,
198
  outputs=target_img_label
 
 
 
 
199
  ).then(
200
  update_buttons,
 
 
 
 
 
 
 
 
 
 
 
 
201
  inputs=user_state,
202
- outputs={submit_button, finish_button}
203
  ).then(
204
- update_images,
205
  inputs=[dropdown1, dropdown2, dropdown3, dropdown4, user_state],
206
- outputs={target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16},
 
 
 
 
207
  ).then(
208
  update_dropdowns,
209
  outputs={dropdown1, dropdown2, dropdown3, dropdown4}
210
  )
211
-
 
212
  finish_button.click(
213
  add_answer, inputs=[dropdown1, dropdown2, dropdown3, dropdown4, answers],outputs=answers
214
  ).then(
 
7
  from pathlib import Path
8
  from huggingface_hub import CommitScheduler, HfApi
9
 
10
+ from src.utils import load_words, load_image_and_saliency, load_example_images, load_csv_concepts
11
  from src.style import css
12
  from src.user import UserID
13
 
 
31
  #user_id = gr.State(load_global_variable())
32
  answers = gr.State([])
33
 
34
+ concepts = load_csv_concepts(data_dir)
 
 
 
 
 
35
 
36
+ gr.Markdown("### Image examples")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  with gr.Row():
38
  count = user_state if isinstance(user_state, int) else user_state.value
39
  images = load_example_images(count, data_dir)
 
53
  img14 = gr.Image(images[13])
54
  img15 = gr.Image(images[14])
55
  img16 = gr.Image(images[15])
56
+
57
+ with gr.Row():
58
+ target_img_label = gr.Markdown(f"Target image: **{class_names[user_state.value]}**")
59
+ gr.Markdown("Grad-cam")
60
+ gr.Markdown("Lime")
61
+ gr.Markdown("Sidu")
62
+ gr.Markdown("Rise")
63
+
64
+ with gr.Row():
65
+ count = user_state if isinstance(user_state, int) else user_state.value
66
+ images = load_image_and_saliency(count, data_dir)
67
+ target_img = gr.Image(images[0], elem_classes="main-image delay", visible=False)
68
+ saliency_gradcam = gr.Image(images[1], elem_classes="main-image", visible=False)
69
+ saliency_lime = gr.Image(images[2], elem_classes="main-image", visible=False)
70
+ saliency_sidu = gr.Image(images[4], elem_classes="main-image", visible=False)
71
+ saliency_rise = gr.Image(images[3], elem_classes="main-image", visible=False)
72
 
73
+ count = user_state if isinstance(user_state, int) else user_state.value
74
+ row = concepts.iloc[count]
75
+ question = gr.Markdown(f"### Sort the following saliency maps according to which of them better highlights the following concepts: {row[1]}, {row[2]} , {row[3]}?", visible=False)
76
+ with gr.Row():
77
+ dropdown1 = gr.Dropdown(choices=options, label="grad-cam", visible=False)
78
+ dropdown2 = gr.Dropdown(choices=options, label="lime", visible=False)
79
+ dropdown3 = gr.Dropdown(choices=options, label="sidu", visible=False)
80
+ dropdown4 = gr.Dropdown(choices=options, label="rise", visible=False)
81
+
82
+ continue_button = gr.Button("Continue")
83
+ submit_button = gr.Button("Submit", visible=False)
84
  finish_button = gr.Button("Finish", visible=False)
85
 
86
+ def update_images(user_state):
 
87
  count = user_state if isinstance(user_state, int) else user_state.value
88
  if count < config['dataset'][config['dataset']['name']]['n_classes']:
89
  images = load_image_and_saliency(count, data_dir)
90
+
 
 
 
 
 
91
  # image examples
92
  images = load_example_images(count, data_dir)
93
+ img1 = gr.Image(images[0], visible=True)
94
+ img2 = gr.Image(images[1], visible=True)
95
+ img3 = gr.Image(images[2], visible=True)
96
+ img4 = gr.Image(images[3], visible=True)
97
+ img5 = gr.Image(images[4], visible=True)
98
+ img6 = gr.Image(images[5], visible=True)
99
+ img7 = gr.Image(images[6], visible=True)
100
+ img8 = gr.Image(images[7], visible=True)
101
+ img9 = gr.Image(images[8], visible=True)
102
+ img10 = gr.Image(images[9], visible=True)
103
+ img11 = gr.Image(images[10], visible=True)
104
+ img12 = gr.Image(images[11], visible=True)
105
+ img13 = gr.Image(images[12], visible=True)
106
+ img14 = gr.Image(images[13], visible=True)
107
+ img15 = gr.Image(images[14], visible=True)
108
+ img16 = gr.Image(images[15], visible=True)
109
+ return img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16
110
  else:
111
+ return img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16
112
+
113
+ def update_saliencies(dropdown1, dropdown2, dropdown3, dropdown4, user_state):
114
+ count = user_state if isinstance(user_state, int) else user_state.value
115
+ if count < config['dataset'][config['dataset']['name']]['n_classes']:
116
+ images = load_image_and_saliency(count, data_dir)
117
+ target_img = gr.Image(images[0], elem_classes="main-image", visible=True)
118
+ saliency_gradcam = gr.Image(images[1], elem_classes="main-image", visible=True)
119
+ saliency_lime = gr.Image(images[2], elem_classes="main-image", visible=True)
120
+ saliency_sidu = gr.Image(images[4], elem_classes="main-image", visible=True)
121
+ saliency_rise = gr.Image(images[3], elem_classes="main-image", visible=True)
122
+ return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu
123
+ else:
124
+ return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu
125
 
126
  def update_state(state):
127
  count = state if isinstance(state, int) else state.value
 
129
 
130
  def update_img_label(state):
131
  count = state if isinstance(state, int) else state.value
132
+ return f" Target image: **{class_names[count]}**"
133
+
134
+ def update_buttons():
135
+ submit_button = gr.Button("Submit", visible=False)
136
+ continue_button = gr.Button("Continue", visible=True)
137
+ return continue_button, submit_button
138
 
139
+ def show_view(state):
140
  count = state if isinstance(state, int) else state.value
141
  max_images = config['dataset'][config['dataset']['name']]['n_classes']
142
  finish_button = gr.Button("Finish", visible=(count == max_images-1))
143
  submit_button = gr.Button("Submit", visible=(count != max_images-1))
144
+ continue_button = gr.Button("Continue", visible=False)
145
+ return continue_button, submit_button, finish_button
146
+
147
+
148
+ def hide_view():
149
+ target_img = gr.Image(images[0], elem_classes="main-image", visible=False)
150
+ saliency_gradcam = gr.Image(images[1], elem_classes="main-image", visible=False)
151
+ saliency_lime = gr.Image(images[2], elem_classes="main-image", visible=False)
152
+ saliency_sidu = gr.Image(images[4], elem_classes="main-image", visible=False)
153
+ saliency_rise = gr.Image(images[3], elem_classes="main-image", visible=False)
154
+ question = gr.Markdown(f"### Sort the following saliency maps according to which of them better highlights the following concepts: {row[1]}, {row[2]} , {row[3]}?", visible=False)
155
+ dropdown1 = gr.Dropdown(choices=options, label="grad-cam", visible=False)
156
+ dropdown2 = gr.Dropdown(choices=options, label="lime", visible=False)
157
+ dropdown3 = gr.Dropdown(choices=options, label="sidu", visible=False)
158
+ dropdown4 = gr.Dropdown(choices=options, label="rise", visible=False)
159
+ return target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise, question, dropdown1, dropdown2, dropdown3, dropdown4
160
+
161
 
162
  def update_dropdowns():
163
+ dp1 = gr.Dropdown(choices=options, value=options[0], label="grad-cam", visible=True)
164
+ dp2 = gr.Dropdown(choices=options, value=options[0], label="lime", visible=True)
165
+ dp3 = gr.Dropdown(choices=options, value=options[0], label="sidu", visible=True)
166
+ dp4 = gr.Dropdown(choices=options, value=options[0], label="rise", visible=True)
167
  return dp1, dp2, dp3, dp4
168
 
169
+ def update_questions(state):
170
+ concepts = load_csv_concepts(data_dir)
171
+ count = state if isinstance(state, int) else state.value
172
+ row = concepts.iloc[count]
173
+ return gr.Markdown(f"### Sort the following saliency maps according to which of them better highlights the following concepts: {row[1]}, {row[2]} , {row[3]}?", visible=True)
174
+
175
  def redirect():
176
  pass
177
 
 
185
  JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
186
  JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{uuid4()}.json"
187
  scheduler = CommitScheduler(
188
+ repo_id=f"results_{config['dataset']['name']}_{config['results']['exp1_dir']}", # The repo id
189
  repo_type="dataset",
190
  folder_path=JSON_DATASET_DIR,
191
  path_in_repo="data",
 
236
  update_img_label,
237
  inputs=user_state,
238
  outputs=target_img_label
239
+ ).then(
240
+ update_images,
241
+ inputs=user_state,
242
+ outputs=[img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16]
243
  ).then(
244
  update_buttons,
245
+ outputs={continue_button, submit_button}
246
+ ).then(
247
+ hide_view,
248
+ outputs={target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise, question, dropdown1, dropdown2, dropdown3, dropdown4}
249
+ )
250
+
251
+ continue_button.click(
252
+ show_view,
253
+ inputs=user_state,
254
+ outputs={continue_button, submit_button, finish_button}
255
+ ).then(
256
+ update_img_label,
257
  inputs=user_state,
258
+ outputs=target_img_label
259
  ).then(
260
+ update_saliencies,
261
  inputs=[dropdown1, dropdown2, dropdown3, dropdown4, user_state],
262
+ outputs={target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise},
263
+ ).then(
264
+ update_questions,
265
+ inputs=user_state,
266
+ outputs=question
267
  ).then(
268
  update_dropdowns,
269
  outputs={dropdown1, dropdown2, dropdown3, dropdown4}
270
  )
271
+
272
+
273
  finish_button.click(
274
  add_answer, inputs=[dropdown1, dropdown2, dropdown3, dropdown4, answers],outputs=answers
275
  ).then(
data/intel_image/concepts_by_class.csv ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ class, concept1, concept2, concept3, concept4, concept5, concept6, concept7, concept8, concept9, concept10, concept11, concept12, concept13, concept14, concept15, concept16
2
+ buildings, Roof, Window, Facade, Wall, Boat, Tree, Sky, Car, Streetlights, Sidewalk, Beach, Vegetation, Water, Mountain Peak, Rock, Ice
3
+ forest, Vegetation, Tree, Water, Sidewalk, Facade, Sky, Beach, Wall, Rock, Window, Ice, Roof, Streetlights, Car, Mountain Peak, Boat
4
+ glacier, Ice, Rock, Mountain Peak, Water, Wall, Beach, Sky, Vegetation, Sidewalk, Facade, Roof, Tree, Window, Boat, Streetlights, Car
5
+ mountain, Mountain Peak, Rock, Vegetation, Sky, Tree, Ice, Water, Beach, Wall, Facade, Roof, Boat, Sidewalk, Window, Streetlights, Car
6
+ sea, Water, Boat, Beach, Sky, Rock, Sidewalk, Wall, Ice, Roof, Vegetation, Facade, Mountain Peak, Tree, Streetlights, Window, Car
7
+ street, Car, Streetlights, Sidewalk, Boat, Wall, Facade, Tree, Roof, Beach, Sky, Window, Vegetation, Water, Rock, Mountain Peak, Ice
src/utils.py CHANGED
@@ -33,7 +33,10 @@ def load_words(idx):
33
  return words
34
 
35
 
36
-
 
 
 
37
 
38
 
39
 
 
33
  return words
34
 
35
 
36
+ def load_csv_concepts(data_dir):
37
+ # Load data from csv
38
+ data = pd.read_csv(os.path.join(data_dir, 'concepts_by_class.csv'))
39
+ return data
40
 
41
 
42