Spaces:
Running
Running
Annonymous
commited on
Commit
·
835894d
1
Parent(s):
3d52de0
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn as nn
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
import matplotlib
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from PIL import Image
|
8 |
+
import cv2
|
9 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
+
|
11 |
+
from data_transforms import normal_transforms, no_shift_transforms, ig_transforms, modify_transforms
|
12 |
+
from utils import overlay_heatmap, viz_map, show_image, deprocess, get_ssl_model, fig2img
|
13 |
+
from methods import occlusion, occlusion_context_agnositc, pairwise_occlusion
|
14 |
+
from methods import get_difference
|
15 |
+
from methods import create_mixed_images, averaged_transforms, sailency, smooth_grad
|
16 |
+
from methods import get_sample_dataset, pixel_invariance, get_gradcam, get_interactioncam
|
17 |
+
|
18 |
+
matplotlib.use('Agg')
|
19 |
+
|
20 |
+
def load_model(model_name):
|
21 |
+
|
22 |
+
global network, ssl_model, denorm
|
23 |
+
if model_name == "simclrv2 (1X)":
|
24 |
+
variant = '1x'
|
25 |
+
network = 'simclrv2'
|
26 |
+
denorm = False
|
27 |
+
|
28 |
+
elif model_name == "simclrv2 (2X)":
|
29 |
+
variant = '2x'
|
30 |
+
network = 'simclrv2'
|
31 |
+
denorm = False
|
32 |
+
|
33 |
+
elif model_name == "Barlow Twins":
|
34 |
+
network = 'barlow_twins'
|
35 |
+
variant = None
|
36 |
+
denorm = True
|
37 |
+
|
38 |
+
ssl_model = get_ssl_model(network, variant)
|
39 |
+
|
40 |
+
if network != 'simclrv2':
|
41 |
+
global normal_transforms, no_shift_transforms, ig_transforms
|
42 |
+
normal_transforms, no_shift_transforms, ig_transforms = modify_transforms(normal_transforms, no_shift_transforms, ig_transforms)
|
43 |
+
|
44 |
+
return "Loaded Model Successfully"
|
45 |
+
|
46 |
+
def load_or_augment_images(img1_input, img2_input, use_aug):
|
47 |
+
|
48 |
+
global img_main, img1, img2
|
49 |
+
|
50 |
+
img_main = img1_input.convert('RGB')
|
51 |
+
|
52 |
+
if use_aug:
|
53 |
+
img1 = normal_transforms['pure'](img_main).unsqueeze(0).to(device)
|
54 |
+
img2 = normal_transforms['aug'](img_main).unsqueeze(0).to(device)
|
55 |
+
else:
|
56 |
+
img1 = normal_transforms['pure'](img_main).unsqueeze(0).to(device)
|
57 |
+
img2 = img2_input.convert('RGB')
|
58 |
+
img2 = normal_transforms['pure'](img2).unsqueeze(0).to(device)
|
59 |
+
|
60 |
+
similarity = "Similarity: {:.3f}".format(nn.CosineSimilarity(dim=-1)(ssl_model(img1), ssl_model(img2)).item())
|
61 |
+
|
62 |
+
fig, axs = plt.subplots(1, 2, figsize=(10,10))
|
63 |
+
np.vectorize(lambda ax:ax.axis('off'))(axs)
|
64 |
+
|
65 |
+
axs[0].imshow(show_image(img1, denormalize = denorm))
|
66 |
+
axs[1].imshow(show_image(img2, denormalize = denorm))
|
67 |
+
plt.subplots_adjust(wspace=0.1, hspace = 0)
|
68 |
+
pil_output = fig2img(fig)
|
69 |
+
return pil_output, similarity
|
70 |
+
|
71 |
+
def run_occlusion(w_size, stride):
|
72 |
+
heatmap1, heatmap2 = occlusion(img1, img2, ssl_model, w_size = 64, stride = 8, batch_size = 32)
|
73 |
+
heatmap1_ca, heatmap2_ca = occlusion_context_agnositc(img1, img2, ssl_model, w_size = 64, stride = 8, batch_size = 32)
|
74 |
+
heatmap1_po, heatmap2_po = pairwise_occlusion(img1, img2, ssl_model, batch_size = 32, erase_scale = (0.1, 0.3), erase_ratio = (1, 1.5), num_erases = 100)
|
75 |
+
|
76 |
+
added_image1 = overlay_heatmap(img1, heatmap1, denormalize = denorm)
|
77 |
+
added_image2 = overlay_heatmap(img2, heatmap2, denormalize = denorm)
|
78 |
+
added_image1_ca = overlay_heatmap(img1, heatmap1_ca, denormalize = denorm)
|
79 |
+
added_image2_ca = overlay_heatmap(img2, heatmap2_ca, denormalize = denorm)
|
80 |
+
|
81 |
+
fig, axs = plt.subplots(2, 4, figsize=(20,10))
|
82 |
+
np.vectorize(lambda ax:ax.axis('off'))(axs)
|
83 |
+
|
84 |
+
axs[0, 0].imshow(show_image(img1, denormalize = denorm))
|
85 |
+
axs[0, 1].imshow(added_image1)
|
86 |
+
axs[0, 1].set_title("Conditional Occlusion")
|
87 |
+
axs[0, 2].imshow(added_image1_ca)
|
88 |
+
axs[0, 2].set_title("CA Cond. Occlusion")
|
89 |
+
axs[0, 3].imshow((deprocess(img1, denormalize = denorm) * heatmap1_po[:,:,None]).astype('uint8'))
|
90 |
+
axs[0, 3].set_title("Pairwise Occlusion")
|
91 |
+
axs[1, 0].imshow(show_image(img2, denormalize = denorm))
|
92 |
+
axs[1, 1].imshow(added_image2)
|
93 |
+
axs[1, 2].imshow(added_image2_ca)
|
94 |
+
axs[1, 3].imshow((deprocess(img2, denormalize = denorm) * heatmap2_po[:,:,None]).astype('uint8'))
|
95 |
+
plt.subplots_adjust(wspace=0, hspace = 0.01)
|
96 |
+
pil_output = fig2img(fig)
|
97 |
+
return pil_output
|
98 |
+
|
99 |
+
def get_model_difference(later):
|
100 |
+
|
101 |
+
imagenet_images, ssl_images = get_difference(ssl_model = ssl_model, baseline = 'imagenet', image = img2, lr = 1e4,
|
102 |
+
l2_weight = 0.1, alpha_weight = 1e-7, alpha_power = 6, tv_weight = 1e-8,
|
103 |
+
init_scale = 0.1, network = network)
|
104 |
+
|
105 |
+
fig, axs = plt.subplots(3, 3, figsize=(10,10))
|
106 |
+
np.vectorize(lambda ax:ax.axis('off'))(axs)
|
107 |
+
|
108 |
+
for aa, (in_img, ssl_img) in enumerate(zip(imagenet_images, ssl_images)):
|
109 |
+
axs[aa,0].imshow(deprocess(img2, denormalize = denorm))
|
110 |
+
axs[aa,1].imshow(deprocess(in_img))
|
111 |
+
axs[aa,2].imshow(deprocess(ssl_img))
|
112 |
+
|
113 |
+
axs[0,0].set_title("Original Image")
|
114 |
+
axs[0,1].set_title("Synthesized (cls)")
|
115 |
+
axs[0,2].set_title("Synthesized (contastive)")
|
116 |
+
|
117 |
+
plt.subplots_adjust(wspace=0.01, hspace = 0.01)
|
118 |
+
pil_output = fig2img(fig)
|
119 |
+
return pil_output
|
120 |
+
|
121 |
+
def get_avg_trasforms(transform_type, add_noise, blur_output, guided):
|
122 |
+
|
123 |
+
mixed_images = create_mixed_images(transform_type = transform_type,
|
124 |
+
ig_transforms = ig_transforms,
|
125 |
+
step = 0.1,
|
126 |
+
img_path = img_main,
|
127 |
+
add_noise = add_noise)
|
128 |
+
|
129 |
+
# vanilla gradients (for comparison purposes)
|
130 |
+
sailency1_van, sailency2_van = sailency(guided = guided, ssl_model = ssl_model,
|
131 |
+
img1 = mixed_images[0], img2 = mixed_images[-1],
|
132 |
+
blur_output = blur_output)
|
133 |
+
|
134 |
+
# smooth gradients (for comparison purposes)
|
135 |
+
sailency1_s, sailency2_s = smooth_grad(guided = guided, ssl_model = ssl_model,
|
136 |
+
img1 = mixed_images[0], img2 = mixed_images[-1],
|
137 |
+
blur_output = blur_output, steps = 50)
|
138 |
+
|
139 |
+
# integrated transform
|
140 |
+
sailency1, sailency2 = averaged_transforms(guided = guided, ssl_model = ssl_model,
|
141 |
+
mixed_images = mixed_images,
|
142 |
+
blur_output = blur_output)
|
143 |
+
|
144 |
+
fig, axs = plt.subplots(2, 4, figsize=(20,10))
|
145 |
+
np.vectorize(lambda ax:ax.axis('off'))(axs)
|
146 |
+
|
147 |
+
axs[0,0].imshow(show_image(mixed_images[0], denormalize = denorm))
|
148 |
+
axs[0,1].imshow(show_image(sailency1_van.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet)
|
149 |
+
axs[0,1].imshow(show_image(mixed_images[0], denormalize = denorm), alpha=0.5)
|
150 |
+
axs[0,1].set_title("Vanilla Gradients")
|
151 |
+
axs[0,2].imshow(show_image(sailency1_s.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet)
|
152 |
+
axs[0,2].imshow(show_image(mixed_images[0], denormalize = denorm), alpha=0.5)
|
153 |
+
axs[0,2].set_title("Smooth Gradients")
|
154 |
+
axs[0,3].imshow(show_image(sailency1.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet)
|
155 |
+
axs[0,3].imshow(show_image(mixed_images[0], denormalize = denorm), alpha=0.5)
|
156 |
+
axs[0,3].set_title("Integrated Transform")
|
157 |
+
axs[1,0].imshow(show_image(mixed_images[-1], denormalize = denorm))
|
158 |
+
axs[1,1].imshow(show_image(sailency2_van.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet)
|
159 |
+
axs[1,1].imshow(show_image(mixed_images[-1], denormalize = denorm), alpha=0.5)
|
160 |
+
axs[1,2].imshow(show_image(sailency2_s.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet)
|
161 |
+
axs[1,2].imshow(show_image(mixed_images[-1], denormalize = denorm), alpha=0.5)
|
162 |
+
axs[1,3].imshow(show_image(sailency2.detach(), squeeze = False).squeeze(), cmap = plt.cm.jet)
|
163 |
+
axs[1,3].imshow(show_image(mixed_images[-1], denormalize = denorm), alpha=0.5)
|
164 |
+
|
165 |
+
plt.subplots_adjust(wspace=0.02, hspace = 0.02)
|
166 |
+
pil_output = fig2img(fig)
|
167 |
+
return pil_output
|
168 |
+
|
169 |
+
def get_cams():
|
170 |
+
|
171 |
+
gradcam1, gradcam2 = get_gradcam(ssl_model, img1, img2)
|
172 |
+
intcam1_mean, intcam2_mean = get_interactioncam(ssl_model, img1, img2, reduction = 'mean')
|
173 |
+
intcam1_maxmax, intcam2_maxmax = get_interactioncam(ssl_model, img1, img2, reduction = 'max', grad_interact = True)
|
174 |
+
intcam1_attnmax, intcam2_attnmax = get_interactioncam(ssl_model, img1, img2, reduction = 'attn', grad_interact = True)
|
175 |
+
|
176 |
+
fig, axs = plt.subplots(2, 5, figsize=(20,8))
|
177 |
+
np.vectorize(lambda ax:ax.axis('off'))(axs)
|
178 |
+
|
179 |
+
axs[0,0].imshow(show_image(img1[0], squeeze = False, denormalize = denorm))
|
180 |
+
axs[0,1].imshow(overlay_heatmap(img1, gradcam1, denormalize = denorm))
|
181 |
+
axs[0,1].set_title("Grad-CAM")
|
182 |
+
axs[0,2].imshow(overlay_heatmap(img1, intcam1_mean, denormalize = denorm))
|
183 |
+
axs[0,2].set_title("IntCAM Mean")
|
184 |
+
axs[0,3].imshow(overlay_heatmap(img1, intcam1_maxmax, denormalize = denorm))
|
185 |
+
axs[0,3].set_title("IntCAM Max + IntGradMax")
|
186 |
+
axs[0,4].imshow(overlay_heatmap(img1, intcam1_attnmax, denormalize = denorm))
|
187 |
+
axs[0,4].set_title("IntCAM Attn + IntGradMax")
|
188 |
+
|
189 |
+
axs[1,0].imshow(show_image(img2[0], squeeze = False, denormalize = denorm))
|
190 |
+
axs[1,1].imshow(overlay_heatmap(img2, gradcam2, denormalize = denorm))
|
191 |
+
axs[1,2].imshow(overlay_heatmap(img2, intcam2_mean, denormalize = denorm))
|
192 |
+
axs[1,3].imshow(overlay_heatmap(img2, intcam2_maxmax, denormalize = denorm))
|
193 |
+
axs[1,4].imshow(overlay_heatmap(img2, intcam2_attnmax, denormalize = denorm))
|
194 |
+
|
195 |
+
plt.subplots_adjust(wspace=0.01, hspace = 0.01)
|
196 |
+
pil_output = fig2img(fig)
|
197 |
+
return pil_output
|
198 |
+
|
199 |
+
def get_pixel_invariance():
|
200 |
+
|
201 |
+
data_samples1, data_samples2, data_labels, labels_invariance = get_sample_dataset(img_path = img_main,
|
202 |
+
num_augments = 1000,
|
203 |
+
batch_size = 32,
|
204 |
+
no_shift_transforms = no_shift_transforms,
|
205 |
+
ssl_model = ssl_model,
|
206 |
+
n_components = 10)
|
207 |
+
|
208 |
+
inv_heatmap = pixel_invariance(data_samples1 = data_samples1, data_samples2 = data_samples2, data_labels = data_labels,
|
209 |
+
labels_invariance = labels_invariance, resize_transform = transforms.Resize, size = 64,
|
210 |
+
epochs = 1000, learning_rate = 0.1, l1_weight = 0.2, zero_small_values = True,
|
211 |
+
blur_output = True, nmf_weight = 0)
|
212 |
+
|
213 |
+
inv_heatmap_nmf = pixel_invariance(data_samples1 = data_samples1, data_samples2 = data_samples2, data_labels = data_labels,
|
214 |
+
labels_invariance = labels_invariance, resize_transform = transforms.Resize, size = 64,
|
215 |
+
epochs = 100, learning_rate = 0.1, l1_weight = 0.2, zero_small_values = True,
|
216 |
+
blur_output = True, nmf_weight = 1)
|
217 |
+
|
218 |
+
fig, axs = plt.subplots(1, 2, figsize=(10,5))
|
219 |
+
np.vectorize(lambda ax:ax.axis('off'))(axs)
|
220 |
+
|
221 |
+
axs[0].imshow(viz_map(img_main, inv_heatmap))
|
222 |
+
axs[0].set_title("Heatmap w/o NMF")
|
223 |
+
axs[1].imshow(viz_map(img_main, inv_heatmap_nmf))
|
224 |
+
axs[1].set_title("Heatmap w/ NMF")
|
225 |
+
plt.subplots_adjust(wspace=0.01, hspace = 0.01)
|
226 |
+
|
227 |
+
pil_output = fig2img(fig)
|
228 |
+
return pil_output
|
229 |
+
|
230 |
+
xai = gr.Blocks()
|
231 |
+
|
232 |
+
with xai:
|
233 |
+
gr.Markdown("<h1>Methods for Explaining Contrastive Learning, CVPR 2023 Submission</h1>")
|
234 |
+
gr.Markdown("The interface is simplified as much as possible with only necessary options to select for each method. Please use our Google Colab demo for more flexibility.")
|
235 |
+
|
236 |
+
with gr.Row():
|
237 |
+
model_name = gr.Dropdown(["simclrv2 (1X)", "simclrv2 (2X)", "Barlow Twins"], label="Choose Model and press \"Load Model\"")
|
238 |
+
load_model_button = gr.Button("Load Model")
|
239 |
+
status_or_similarity = gr.inputs.Textbox(label = "Status")
|
240 |
+
with gr.Row():
|
241 |
+
gr.Markdown("You can either load two images or load a single image and augment it to get the second image (in that case please check the \"Use Augmentations\" button). After that, please press on \"Show Images\"")
|
242 |
+
img1 = gr.Image(type='pil', label = "First Image")
|
243 |
+
img2 = gr.Image(type='pil', label = "Second Image")
|
244 |
+
with gr.Row():
|
245 |
+
use_aug = gr.Checkbox(value = False, label = "Use Augmentations")
|
246 |
+
load_images_button = gr.Button("Show Images")
|
247 |
+
|
248 |
+
gr.Markdown("Choose a method from the different tabs. You may leave the default options as they are and press on \"Run\" ")
|
249 |
+
with gr.Row():
|
250 |
+
with gr.Column():
|
251 |
+
with gr.Tabs():
|
252 |
+
with gr.TabItem("Interaction-CAM"):
|
253 |
+
cams_button = gr.Button("Get Heatmaps")
|
254 |
+
with gr.TabItem("Perturbation Methods"):
|
255 |
+
w_size = gr.Number(value = 64, label = "Occlusion Window Size", precision = 0)
|
256 |
+
stride = gr.Number(value = 8, label = "Occlusion Stride", precision = 0)
|
257 |
+
occlusion_button = gr.Button("Get Heatmap")
|
258 |
+
with gr.TabItem("Averaged Transforms"):
|
259 |
+
transform_type = gr.inputs.Radio(label="Data Augment", choices=['color_jitter', 'blur', 'grayscale', 'solarize', 'combine'], default="combine")
|
260 |
+
add_noise = gr.Checkbox(value = True, label = "Add Noise")
|
261 |
+
blur_output = gr.Checkbox(value = True, label = "Blur Output")
|
262 |
+
guided = gr.Checkbox(value = True, label = "Guided Backprop")
|
263 |
+
avgtransform_button = gr.Button("Get Saliency")
|
264 |
+
with gr.TabItem("Pixel Invariance"):
|
265 |
+
gr.Markdown("Note: Invariance map will be obtained for the first image")
|
266 |
+
pixel_invariance_button = gr.Button("Get Invariance Map")
|
267 |
+
with gr.TabItem("Image Synthesization"):
|
268 |
+
baseline = gr.inputs.Radio(label="Compare With", choices=["Supervised Classification"], default="Supervised Classification")
|
269 |
+
modeldiff_button = gr.Button("Compare")
|
270 |
+
|
271 |
+
with gr.Column():
|
272 |
+
output_image = gr.Image(type='pil', show_label = False)
|
273 |
+
|
274 |
+
load_model_button.click(load_model, inputs = model_name, outputs = status_or_similarity)
|
275 |
+
load_images_button.click(load_or_augment_images, inputs = [img1, img2, use_aug], outputs = [output_image, status_or_similarity])
|
276 |
+
occlusion_button.click(run_occlusion, inputs=[w_size,stride], outputs=output_image)
|
277 |
+
modeldiff_button.click(get_model_difference, inputs = baseline, outputs = output_image)
|
278 |
+
avgtransform_button.click(get_avg_trasforms, inputs = [transform_type, add_noise, blur_output, guided], outputs = output_image)
|
279 |
+
cams_button.click(get_cams, inputs = [], outputs = output_image)
|
280 |
+
pixel_invariance_button.click(get_pixel_invariance, inputs = [], outputs = output_image)
|
281 |
+
|
282 |
+
xai.launch()
|
283 |
+
|