Spaces:
Sleeping
Sleeping
Commit
·
8092894
1
Parent(s):
15dceff
implement user space, counter system to assign ad id to each user
Browse files- app.py +95 -117
- config/config.yaml +16 -3
- src/style.py +6 -0
- src/user.py +18 -0
- src/utils.py +20 -27
app.py
CHANGED
@@ -1,67 +1,42 @@
|
|
1 |
import gradio as gr
|
2 |
import yaml
|
3 |
-
from src.utils import load_words, save_results, load_global_variable, load_saliencies
|
4 |
-
from src.style import css
|
5 |
import random
|
|
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
"https://picsum.photos/202",
|
11 |
-
"https://picsum.photos/203",
|
12 |
-
"https://picsum.photos/204",
|
13 |
-
"https://picsum.photos/205",
|
14 |
-
"https://picsum.photos/206",
|
15 |
-
"https://picsum.photos/207",
|
16 |
-
"https://picsum.photos/208",
|
17 |
-
"https://picsum.photos/209",
|
18 |
-
"https://picsum.photos/210",
|
19 |
-
"https://picsum.photos/211",
|
20 |
-
"https://picsum.photos/212",
|
21 |
-
"https://picsum.photos/213",
|
22 |
-
"https://picsum.photos/214",
|
23 |
-
]
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
def update_img_count(state):
|
28 |
-
count = state
|
29 |
-
print('oooooooo', count)
|
30 |
-
return gr.State(count + 1)
|
31 |
|
32 |
def main():
|
33 |
config = yaml.safe_load(open("config/config.yaml"))
|
34 |
-
|
35 |
-
global_var = load_global_variable()
|
36 |
-
|
37 |
-
#images = load_images(global_var)
|
38 |
-
#saliency = load_saliencies(global_var)
|
39 |
words = ['grad-cam', 'lime', 'sidu', 'rise']
|
40 |
options = ['1', '2', '3', '4']
|
|
|
|
|
|
|
41 |
|
42 |
with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
|
43 |
# Main App Components
|
44 |
title = gr.Markdown("# Saliency evaluation - experiment 1")
|
45 |
user_state = gr.State(0)
|
46 |
-
|
47 |
-
#user_counter = gr.Textbox(str(global_var), visible=False)
|
48 |
-
#img_counter = gr.Textbox(str(0), visible=False)
|
49 |
|
50 |
with gr.Row():
|
51 |
-
gr.Markdown("### Target image")
|
52 |
gr.Markdown("### Grad-cam")
|
53 |
gr.Markdown("### Lime")
|
54 |
gr.Markdown("### Sidu")
|
55 |
gr.Markdown("### Rise")
|
56 |
|
57 |
with gr.Row():
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
saliency_sidu = gr.Image(
|
64 |
-
|
|
|
65 |
with gr.Row():
|
66 |
dropdown1 = gr.Dropdown(choices=options, label="grad-cam")
|
67 |
dropdown2 = gr.Dropdown(choices=options, label="lime")
|
@@ -71,103 +46,98 @@ def main():
|
|
71 |
gr.Markdown("### Image examples of the same class")
|
72 |
with gr.Row():
|
73 |
# generate random integer value
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
submit_button = gr.Button("Submit")
|
95 |
finish_button = gr.Button("Finish", visible=False)
|
96 |
|
97 |
def update_images(dropdown1, dropdown2, dropdown3, dropdown4, user_state):
|
98 |
|
99 |
-
|
100 |
-
#str_dropdowns = str(dropdowns)
|
101 |
-
# remove the curly braces
|
102 |
-
#dropdowns = str_dropdowns[1:-1]
|
103 |
-
#dropdowns = [r.split(":")[1].strip().replace("'", "") for r in dropdowns.split(",")]
|
104 |
-
|
105 |
print('dropdowns', dropdown1, dropdown2, dropdown3, dropdown4)
|
106 |
-
|
107 |
rank = [dropdown1,dropdown2,dropdown3,dropdown4]
|
108 |
print('rank', rank)
|
109 |
-
# image target and saliency images
|
110 |
-
target_img = gr.Image(random_images[random.randint(0, 5)])
|
111 |
-
saliency_gradcam = gr.Image(random_images[random.randint(0, 5)])
|
112 |
-
saliency_lime = gr.Image(random_images[random.randint(0, 5)])
|
113 |
-
saliency_rise = gr.Image(random_images[random.randint(0, 5)])
|
114 |
-
saliency_sidu = gr.Image(random_images[random.randint(0, 5)])
|
115 |
-
|
116 |
-
# image examples
|
117 |
-
img1 = gr.Image(random_images[random.randint(0, 5)])
|
118 |
-
img2 = gr.Image(random_images[random.randint(0, 5)])
|
119 |
-
img3 = gr.Image(random_images[random.randint(0, 5)])
|
120 |
-
img4 = gr.Image(random_images[random.randint(0, 5)])
|
121 |
-
img5 = gr.Image(random_images[random.randint(0, 5)])
|
122 |
-
img6 = gr.Image(random_images[random.randint(0, 5)])
|
123 |
-
img7 = gr.Image(random_images[random.randint(0, 5)])
|
124 |
-
img8 = gr.Image(random_images[random.randint(0, 5)])
|
125 |
-
img9 = gr.Image(random_images[random.randint(0, 5)])
|
126 |
-
img10 = gr.Image(random_images[random.randint(0, 5)])
|
127 |
-
img11 = gr.Image(random_images[random.randint(0, 5)])
|
128 |
-
img12 = gr.Image(random_images[random.randint(0, 5)])
|
129 |
-
img13 = gr.Image(random_images[random.randint(0, 5)])
|
130 |
-
img14 = gr.Image(random_images[random.randint(0, 5)])
|
131 |
-
img15 = gr.Image(random_images[random.randint(0, 5)])
|
132 |
-
img16 = gr.Image(random_images[random.randint(0, 5)])
|
133 |
-
img17 = gr.Image(random_images[random.randint(0, 5)])
|
134 |
-
img18 = gr.Image(random_images[random.randint(0, 5)])
|
135 |
|
136 |
-
|
137 |
-
if
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
else:
|
142 |
-
|
143 |
-
submit_button.visible = True
|
144 |
|
145 |
-
|
146 |
-
return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16, img17, img18
|
147 |
-
|
148 |
def update_state(state):
|
149 |
-
|
150 |
count = state if isinstance(state, int) else state.value
|
151 |
-
print('\n\ncount', count)
|
152 |
return gr.State(count + 1)
|
153 |
|
|
|
|
|
|
|
|
|
154 |
def update_buttons(state):
|
155 |
count = state if isinstance(state, int) else state.value
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
submit_button = gr.Button("Submit", visible=False)
|
160 |
-
else:
|
161 |
-
finish_button = gr.Button("Finish", visible=False)
|
162 |
-
submit_button = gr.Button("Submit", visible=True)
|
163 |
-
|
164 |
return submit_button, finish_button
|
165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
submit_button.click(
|
168 |
update_state,
|
169 |
inputs=user_state,
|
170 |
outputs=user_state
|
|
|
|
|
|
|
|
|
171 |
).then(
|
172 |
update_buttons,
|
173 |
inputs=user_state,
|
@@ -175,7 +145,11 @@ def main():
|
|
175 |
).then(
|
176 |
update_images,
|
177 |
inputs=[dropdown1, dropdown2, dropdown3, dropdown4, user_state],
|
178 |
-
outputs={target_img, saliency_gradcam, saliency_lime,
|
|
|
|
|
|
|
|
|
179 |
)
|
180 |
|
181 |
def redirect():
|
@@ -183,7 +157,11 @@ def main():
|
|
183 |
|
184 |
finish_button.click(redirect, js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'")
|
185 |
|
186 |
-
|
|
|
|
|
|
|
|
|
187 |
|
188 |
demo.launch()
|
189 |
|
|
|
1 |
import gradio as gr
|
2 |
import yaml
|
|
|
|
|
3 |
import random
|
4 |
+
import os
|
5 |
|
6 |
+
from src.utils import load_words, save_results, load_image_and_saliency, load_example_images
|
7 |
+
from src.style import css
|
8 |
+
from src.user import UserID
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def main():
|
11 |
config = yaml.safe_load(open("config/config.yaml"))
|
|
|
|
|
|
|
|
|
|
|
12 |
words = ['grad-cam', 'lime', 'sidu', 'rise']
|
13 |
options = ['1', '2', '3', '4']
|
14 |
+
class_names = config['dataset'][config['dataset']['name']]['class_names']
|
15 |
+
data_dir = os.path.join(config['dataset']['path'], config['dataset']['name'])
|
16 |
+
id_generator = UserID()
|
17 |
|
18 |
with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
|
19 |
# Main App Components
|
20 |
title = gr.Markdown("# Saliency evaluation - experiment 1")
|
21 |
user_state = gr.State(0)
|
22 |
+
user_id = gr.State(0)
|
|
|
|
|
23 |
|
24 |
with gr.Row():
|
25 |
+
target_img_label = gr.Markdown(f"### Target image: {class_names[user_state.value]}")
|
26 |
gr.Markdown("### Grad-cam")
|
27 |
gr.Markdown("### Lime")
|
28 |
gr.Markdown("### Sidu")
|
29 |
gr.Markdown("### Rise")
|
30 |
|
31 |
with gr.Row():
|
32 |
+
count = user_state if isinstance(user_state, int) else user_state.value
|
33 |
+
images = load_image_and_saliency(count, data_dir)
|
34 |
+
target_img = gr.Image(images[0], elem_classes="main-image")
|
35 |
+
saliency_gradcam = gr.Image(images[1], elem_classes="main-image")
|
36 |
+
saliency_lime = gr.Image(images[2], elem_classes="main-image")
|
37 |
+
saliency_sidu = gr.Image(images[3], elem_classes="main-image")
|
38 |
+
saliency_rise = gr.Image(images[4], elem_classes="main-image")
|
39 |
+
|
40 |
with gr.Row():
|
41 |
dropdown1 = gr.Dropdown(choices=options, label="grad-cam")
|
42 |
dropdown2 = gr.Dropdown(choices=options, label="lime")
|
|
|
46 |
gr.Markdown("### Image examples of the same class")
|
47 |
with gr.Row():
|
48 |
# generate random integer value
|
49 |
+
count = user_state if isinstance(user_state, int) else user_state.value
|
50 |
+
images = load_example_images(count, data_dir)
|
51 |
+
img1 = gr.Image(images[0])
|
52 |
+
img2 = gr.Image(images[1])
|
53 |
+
img3 = gr.Image(images[2])
|
54 |
+
img4 = gr.Image(images[3])
|
55 |
+
img5 = gr.Image(images[4])
|
56 |
+
img6 = gr.Image(images[5])
|
57 |
+
img7 = gr.Image(images[6])
|
58 |
+
img8 = gr.Image(images[7])
|
59 |
+
img9 = gr.Image(images[8])
|
60 |
+
img10 = gr.Image(images[9])
|
61 |
+
img11 = gr.Image(images[10])
|
62 |
+
img12 = gr.Image(images[11])
|
63 |
+
img13 = gr.Image(images[12])
|
64 |
+
img14 = gr.Image(images[13])
|
65 |
+
img15 = gr.Image(images[14])
|
66 |
+
img16 = gr.Image(images[15])
|
67 |
+
|
|
|
68 |
submit_button = gr.Button("Submit")
|
69 |
finish_button = gr.Button("Finish", visible=False)
|
70 |
|
71 |
def update_images(dropdown1, dropdown2, dropdown3, dropdown4, user_state):
|
72 |
|
73 |
+
|
|
|
|
|
|
|
|
|
|
|
74 |
print('dropdowns', dropdown1, dropdown2, dropdown3, dropdown4)
|
|
|
75 |
rank = [dropdown1,dropdown2,dropdown3,dropdown4]
|
76 |
print('rank', rank)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
+
# image target and saliency images
|
79 |
+
count = user_state if isinstance(user_state, int) else user_state.value
|
80 |
+
print(count, config['dataset'][config['dataset']['name']]['n_classes'])
|
81 |
+
if count < config['dataset'][config['dataset']['name']]['n_classes']:
|
82 |
+
images = load_image_and_saliency(count, data_dir)
|
83 |
+
target_img = gr.Image(images[0], elem_classes="main-image")
|
84 |
+
saliency_gradcam = gr.Image(images[1], elem_classes="main-image")
|
85 |
+
saliency_lime = gr.Image(images[2], elem_classes="main-image")
|
86 |
+
saliency_sidu = gr.Image(images[3], elem_classes="main-image")
|
87 |
+
saliency_rise = gr.Image(images[4], elem_classes="main-image")
|
88 |
+
|
89 |
+
# image examples
|
90 |
+
images = load_example_images(count, data_dir)
|
91 |
+
img1 = gr.Image(images[0])
|
92 |
+
img2 = gr.Image(images[1])
|
93 |
+
img3 = gr.Image(images[2])
|
94 |
+
img4 = gr.Image(images[3])
|
95 |
+
img5 = gr.Image(images[4])
|
96 |
+
img6 = gr.Image(images[5])
|
97 |
+
img7 = gr.Image(images[6])
|
98 |
+
img8 = gr.Image(images[7])
|
99 |
+
img9 = gr.Image(images[8])
|
100 |
+
img10 = gr.Image(images[9])
|
101 |
+
img11 = gr.Image(images[10])
|
102 |
+
img12 = gr.Image(images[11])
|
103 |
+
img13 = gr.Image(images[12])
|
104 |
+
img14 = gr.Image(images[13])
|
105 |
+
img15 = gr.Image(images[14])
|
106 |
+
img16 = gr.Image(images[15])
|
107 |
+
return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16
|
108 |
else:
|
109 |
+
return target_img, saliency_gradcam, saliency_lime, saliency_rise, saliency_sidu, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16
|
|
|
110 |
|
|
|
|
|
|
|
111 |
def update_state(state):
|
|
|
112 |
count = state if isinstance(state, int) else state.value
|
|
|
113 |
return gr.State(count + 1)
|
114 |
|
115 |
+
def update_img_label(state):
|
116 |
+
count = state if isinstance(state, int) else state.value
|
117 |
+
return f"### Target image: {class_names[count]}"
|
118 |
+
|
119 |
def update_buttons(state):
|
120 |
count = state if isinstance(state, int) else state.value
|
121 |
+
max_images = config['dataset'][config['dataset']['name']]['n_classes']
|
122 |
+
finish_button = gr.Button("Finish", visible=(count == max_images-1))
|
123 |
+
submit_button = gr.Button("Submit", visible=(count != max_images-1))
|
|
|
|
|
|
|
|
|
|
|
124 |
return submit_button, finish_button
|
125 |
|
126 |
+
def update_dropdowns(dropdowns):
|
127 |
+
dropdown1 = gr.Dropdown(choices=options, label="grad-cam")
|
128 |
+
dropdown2 = gr.Dropdown(choices=options, label="lime")
|
129 |
+
dropdown3 = gr.Dropdown(choices=options, label="sidu")
|
130 |
+
dropdown4 = gr.Dropdown(choices=options, label="rise")
|
131 |
+
return dropdown1, dropdown2, dropdown3, dropdown4
|
132 |
|
133 |
submit_button.click(
|
134 |
update_state,
|
135 |
inputs=user_state,
|
136 |
outputs=user_state
|
137 |
+
).then(
|
138 |
+
update_img_label,
|
139 |
+
inputs=user_state,
|
140 |
+
outputs=target_img_label
|
141 |
).then(
|
142 |
update_buttons,
|
143 |
inputs=user_state,
|
|
|
145 |
).then(
|
146 |
update_images,
|
147 |
inputs=[dropdown1, dropdown2, dropdown3, dropdown4, user_state],
|
148 |
+
outputs={target_img, saliency_gradcam, saliency_lime, saliency_sidu, saliency_rise, img1, img2, img3, img4, img5, img6, img7, img8, img9, img10, img11, img12, img13, img14, img15, img16},
|
149 |
+
).then(
|
150 |
+
update_dropdowns,
|
151 |
+
inputs={dropdown1, dropdown2, dropdown3, dropdown4},
|
152 |
+
outputs={dropdown1, dropdown2, dropdown3, dropdown4}
|
153 |
)
|
154 |
|
155 |
def redirect():
|
|
|
157 |
|
158 |
finish_button.click(redirect, js="window.location = 'https://marcoparola.github.io/saliency-evaluation-app/end'")
|
159 |
|
160 |
+
def init(request: gr.Request):
|
161 |
+
user_id.value = id_generator.increment()
|
162 |
+
return user_id
|
163 |
+
|
164 |
+
demo.load(init, inputs=None, outputs=user_id)
|
165 |
|
166 |
demo.launch()
|
167 |
|
config/config.yaml
CHANGED
@@ -2,12 +2,25 @@ data_dir: data
|
|
2 |
image_dir: images
|
3 |
saliency_dir: saliency
|
4 |
|
|
|
|
|
|
|
5 |
repo_id: "MarcoParola/saliency-evaluation"
|
6 |
|
7 |
dataset:
|
8 |
-
|
|
|
|
|
9 |
n_classes: 6
|
10 |
-
class_names: ['
|
11 |
imagenette:
|
12 |
n_classes: 10
|
13 |
-
class_names: ['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
image_dir: images
|
3 |
saliency_dir: saliency
|
4 |
|
5 |
+
gui:
|
6 |
+
max_img_examples: 16
|
7 |
+
|
8 |
repo_id: "MarcoParola/saliency-evaluation"
|
9 |
|
10 |
dataset:
|
11 |
+
name: intel_image
|
12 |
+
path: data
|
13 |
+
intel_image:
|
14 |
n_classes: 6
|
15 |
+
class_names: ['BUILDING', 'FOREST', 'GLACIER', 'MOUNTAIN', 'SEA', 'STREET']
|
16 |
imagenette:
|
17 |
n_classes: 10
|
18 |
+
class_names: ['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
saliency_methods:
|
23 |
+
- gradcam
|
24 |
+
- lime
|
25 |
+
- sidu
|
26 |
+
- rise
|
src/style.py
CHANGED
@@ -3,6 +3,12 @@ css = """
|
|
3 |
height: 300px;
|
4 |
}
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
.gallery-textlabel > * {
|
7 |
h2 {
|
8 |
font-weight: medium;
|
|
|
3 |
height: 300px;
|
4 |
}
|
5 |
|
6 |
+
.main-image {
|
7 |
+
width: 200px;
|
8 |
+
height: 200px;
|
9 |
+
object-fit: cover;
|
10 |
+
}
|
11 |
+
|
12 |
.gallery-textlabel > * {
|
13 |
h2 {
|
14 |
font-weight: medium;
|
src/user.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
from threading import Lock
|
4 |
+
|
5 |
+
class UserID:
|
6 |
+
def __init__(self):
|
7 |
+
self.lock = Lock()
|
8 |
+
self.counter = 0
|
9 |
+
if os.path.exists('global_variable.csv'):
|
10 |
+
df = pd.read_csv('global_variable.csv')
|
11 |
+
self.counter = df['value'][0]
|
12 |
+
|
13 |
+
def increment(self):
|
14 |
+
with self.lock:
|
15 |
+
self.counter += 1
|
16 |
+
df = pd.DataFrame({'value': [self.counter]})
|
17 |
+
df.to_csv('global_variable.csv', index=False)
|
18 |
+
return self.counter
|
src/utils.py
CHANGED
@@ -2,42 +2,35 @@ import os
|
|
2 |
import pandas as pd
|
3 |
from huggingface_hub import HfApi, HfFolder
|
4 |
import yaml
|
5 |
-
|
6 |
|
7 |
config = yaml.safe_load(open("./config/config.yaml"))
|
8 |
|
9 |
-
# Function to load global variable from CSV
|
10 |
-
def load_global_variable():
|
11 |
-
global_counter = 0
|
12 |
-
if os.path.exists('global_variable.csv'):
|
13 |
-
df = pd.read_csv('global_variable.csv')
|
14 |
-
global_counter = df['value'][0]
|
15 |
-
|
16 |
-
print('global_counter', global_counter)
|
17 |
-
|
18 |
-
global_counter += 1
|
19 |
-
df = pd.DataFrame({'value': [global_counter]})
|
20 |
-
df.to_csv('global_variable.csv', index=False)
|
21 |
-
return global_counter
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
def
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
28 |
return images
|
29 |
|
30 |
-
def load_saliencies(global_var):
|
31 |
-
image_dir = os.path.join(config["data_dir"], config["image_dir"])
|
32 |
-
#images = [f"image_{global_var}_{i}.jpg" for i in range(10)]
|
33 |
-
saliencies = [os.path.join(image_dir, f) for f in os.listdir(image_dir) if os.path.isfile(os.path.join(image_dir, f))]
|
34 |
-
# select first 5 saliencies
|
35 |
-
saliencies = saliencies[:5]
|
36 |
-
return saliencies
|
37 |
|
38 |
# Function to load words based on global variable
|
39 |
-
def load_words(
|
40 |
-
words = [f"word_{
|
41 |
return words
|
42 |
|
43 |
# Function to save results and increment global variable
|
|
|
2 |
import pandas as pd
|
3 |
from huggingface_hub import HfApi, HfFolder
|
4 |
import yaml
|
5 |
+
import numpy as np
|
6 |
|
7 |
config = yaml.safe_load(open("./config/config.yaml"))
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
def load_image_and_saliency(class_idx, data_dir):
|
11 |
+
path = os.path.join(data_dir, 'images', str(class_idx))
|
12 |
+
images = os.listdir(path)
|
13 |
+
# pick a random image
|
14 |
+
id = np.random.randint(0, len(images))
|
15 |
+
image = os.path.join(path, images[id])
|
16 |
+
gradcam_image = os.path.join(data_dir, 'saliency', 'gradcam', images[id])
|
17 |
+
lime_image = os.path.join(data_dir, 'saliency', 'lime', images[id])
|
18 |
+
sidu_image = os.path.join(data_dir, 'saliency', 'sidu', images[id])
|
19 |
+
rise_image = os.path.join(data_dir, 'saliency', 'rise', images[id])
|
20 |
+
return image, gradcam_image, lime_image, sidu_image, rise_image
|
21 |
|
22 |
+
def load_example_images(class_idx, data_dir, max_images=16):
|
23 |
+
path = os.path.join(data_dir, 'images', str(class_idx))
|
24 |
+
images = os.listdir(path)
|
25 |
+
# pick max_images random images
|
26 |
+
ids = np.random.choice(len(images), max_images, replace=False)
|
27 |
+
images = [os.path.join(path, images[id]) for id in ids]
|
28 |
return images
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
# Function to load words based on global variable
|
32 |
+
def load_words(idx):
|
33 |
+
words = [f"word_{idx}_{i}" for i in range(20)]
|
34 |
return words
|
35 |
|
36 |
# Function to save results and increment global variable
|