Anonymous committed
Commit cddd431 · 1 Parent(s): feafd17
Upload 3 files
- data_transforms.py +96 -0
- methods.py +578 -0
- utils.py +101 -0
data_transforms.py
ADDED
@@ -0,0 +1,96 @@
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image, ImageOps, ImageFilter
import random

def add_normalization_to_transform(unnormalized_transforms):
    """Adds ImageNet normalization to all transforms"""
    normalized_transform = {}
    for key, value in unnormalized_transforms.items():
        normalized_transform[key] = transforms.Compose([value,
                                                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                             std=[0.229, 0.224, 0.225])])
    return normalized_transform

def modify_transforms(normal_transforms, no_shift_transforms, ig_transforms):
    normal_transforms = add_normalization_to_transform(normal_transforms)
    no_shift_transforms = add_normalization_to_transform(no_shift_transforms)
    ig_transforms = add_normalization_to_transform(ig_transforms)
    return normal_transforms, no_shift_transforms, ig_transforms

class Solarization(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img

# no ImageNet normalization for SimCLRv2
pure_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor()])

aug_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                    transforms.RandomHorizontalFlip(p=0.5),
                                    transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
                                    transforms.RandomGrayscale(p=0.2),
                                    transforms.RandomApply([transforms.GaussianBlur(kernel_size=(21, 21), sigma=(0.1, 2.0))], p=0.5),
                                    transforms.ToTensor()])

ig_pure_transform = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor()])

ig_transform_colorjitter = transforms.Compose([transforms.Resize(256),
                                               transforms.CenterCrop(224),
                                               transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.4)], p=1),
                                               transforms.ToTensor()])

ig_transform_blur = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.RandomApply([transforms.GaussianBlur(kernel_size=(11, 11), sigma=(5, 5))], p=1),
                                        transforms.ToTensor()])

ig_transform_solarize = transforms.Compose([transforms.Resize(256),
                                            transforms.CenterCrop(224),
                                            Solarization(p=1.0),
                                            transforms.ToTensor()])

ig_transform_grayscale = transforms.Compose([transforms.Resize(256),
                                             transforms.CenterCrop(224),
                                             transforms.RandomGrayscale(p=1),
                                             transforms.ToTensor()])

ig_transform_combine = transforms.Compose([transforms.Resize(256),
                                           transforms.CenterCrop(224),
                                           transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
                                           transforms.RandomGrayscale(p=0.2),
                                           transforms.RandomApply([transforms.GaussianBlur(kernel_size=(21, 21), sigma=(0.1, 2.0))], p=0.5),
                                           transforms.ToTensor()])

pure_transform_no_shift = transforms.Compose([transforms.Resize((224, 224)),
                                              transforms.ToTensor()])

aug_transform_no_shift = transforms.Compose([transforms.Resize((224, 224)),
                                             transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
                                             transforms.RandomGrayscale(p=0.2),
                                             transforms.ToTensor()])

normal_transforms = {'pure': pure_transform,
                     'aug': aug_transform}

no_shift_transforms = {'pure': pure_transform_no_shift,
                       'aug': aug_transform_no_shift}

ig_transforms = {'pure': ig_pure_transform,
                 'color_jitter': ig_transform_colorjitter,
                 'blur': ig_transform_blur,
                 'grayscale': ig_transform_grayscale,
                 'solarize': ig_transform_solarize,
                 'combine': ig_transform_combine}
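Usage note (not part of the uploaded file): a minimal sketch of how the dictionaries above appear intended to be used, assuming a local image path ('example.jpg' is a placeholder). Calling modify_transforms is optional; it wraps every transform with ImageNet normalization.

# Hedged usage sketch for data_transforms.py
from PIL import Image
from data_transforms import normal_transforms, no_shift_transforms, ig_transforms, modify_transforms

# Optionally add ImageNet normalization to every transform dictionary.
normal_t, no_shift_t, ig_t = modify_transforms(normal_transforms, no_shift_transforms, ig_transforms)

img = Image.open('example.jpg').convert('RGB')   # placeholder path
view_pure = normal_t['pure'](img).unsqueeze(0)   # deterministic 224x224 tensor
view_aug = normal_t['aug'](img).unsqueeze(0)     # randomly augmented view
view_blur = ig_t['blur'](img).unsqueeze(0)       # blur-only view used for interpolation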
methods.py
ADDED
@@ -0,0 +1,578 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
from PIL import Image
from sklearn.decomposition import NMF

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def relu_hook_function(module, grad_in, grad_out):
    # Guided-backprop hook: only let positive gradients flow back through ReLUs.
    if isinstance(module, nn.ReLU):
        return (F.relu(grad_in[0]),)

def blur_sailency(input_image):
    # Smooth a saliency map with a Gaussian blur.
    return torchvision.transforms.functional.gaussian_blur(input_image, kernel_size=[11, 11], sigma=[5, 5])

def occlusion(img1, img2, model, w_size = 64, stride = 8, batch_size = 32):
    # Slide a w_size x w_size occluder over both views and record how much the
    # similarity to the other (unoccluded) view drops at each position.
    measure = nn.CosineSimilarity(dim=-1)
    output_size = int(((img2.size(-1) - w_size) / stride) + 1)
    out1_condition, out2_condition = model(img1), model(img2)
    images1 = []
    images2 = []

    for i in range(output_size):
        for j in range(output_size):
            start_i, start_j = i * stride, j * stride
            image1 = img1.clone().detach()
            image2 = img2.clone().detach()
            image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            images1.append(image1)
            images2.append(image2)

    images1 = torch.cat(images1, dim=0).to(device)
    images2 = torch.cat(images2, dim=0).to(device)

    score_map1 = []
    score_map2 = []

    assert images1.shape[0] == images2.shape[0]

    for b in range(0, images2.shape[0], batch_size):
        with torch.no_grad():
            out1 = model(images1[b : b + batch_size, :])
            out2 = model(images2[b : b + batch_size, :])

        score_map1.append(measure(out1, out2_condition))  # try torch.mm(out2_condition, out1.t())[0]
        score_map2.append(measure(out1_condition, out2))  # try torch.mm(out1_condition, out2.t())[0]

    score_map1 = torch.cat(score_map1, dim = 0)
    score_map2 = torch.cat(score_map2, dim = 0)
    assert images2.shape[0] == score_map2.shape[0] == score_map1.shape[0]

    heatmap1 = score_map1.view(output_size, output_size).cpu().detach().numpy()
    heatmap2 = score_map2.view(output_size, output_size).cpu().detach().numpy()
    base_score = measure(out1_condition, out2_condition)

    heatmap1 = (heatmap1 - base_score.item()) * -1  # or base_score.item() - heatmap1. The higher the drop, the better
    heatmap2 = (heatmap2 - base_score.item()) * -1  # or base_score.item() - heatmap2. The higher the drop, the better

    return heatmap1, heatmap2

def occlusion_context_agnositc(img1, img2, model, w_size = 64, stride = 8, batch_size = 32):
    # Occlusion variant that isolates each patch: features of "patch only" minus
    # features of "image without the patch" are compared against the other view.
    measure = nn.CosineSimilarity(dim=-1)
    output_size = int(((img2.size(-1) - w_size) / stride) + 1)
    out1_condition, out2_condition = model(img1), model(img2)

    images1_occlude_mask = []
    images2_occlude_mask = []

    for i in range(output_size):
        for j in range(output_size):
            start_i, start_j = i * stride, j * stride
            image1 = img1.clone().detach()
            image2 = img2.clone().detach()
            image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = 0
            images1_occlude_mask.append(image1)
            images2_occlude_mask.append(image2)

    images1_occlude_mask = torch.cat(images1_occlude_mask, dim=0).to(device)
    images2_occlude_mask = torch.cat(images2_occlude_mask, dim=0).to(device)

    images1_occlude_backround = []
    images2_occlude_backround = []

    copy_img1 = img1.clone().detach()
    copy_img2 = img2.clone().detach()

    for i in range(output_size):
        for j in range(output_size):
            start_i, start_j = i * stride, j * stride

            image1 = torch.zeros_like(img1)
            image2 = torch.zeros_like(img2)

            image1[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = copy_img1[:, :, start_i : start_i + w_size, start_j : start_j + w_size]
            image2[:, :, start_i : start_i + w_size, start_j : start_j + w_size] = copy_img2[:, :, start_i : start_i + w_size, start_j : start_j + w_size]

            images1_occlude_backround.append(image1)
            images2_occlude_backround.append(image2)

    images1_occlude_backround = torch.cat(images1_occlude_backround, dim=0).to(device)
    images2_occlude_backround = torch.cat(images2_occlude_backround, dim=0).to(device)

    score_map1 = []
    score_map2 = []

    assert images1_occlude_mask.shape[0] == images2_occlude_mask.shape[0]

    for b in range(0, images1_occlude_mask.shape[0], batch_size):
        with torch.no_grad():
            out1_mask = model(images1_occlude_mask[b : b + batch_size, :])
            out2_mask = model(images2_occlude_mask[b : b + batch_size, :])
            out1_backround = model(images1_occlude_backround[b : b + batch_size, :])
            out2_backround = model(images2_occlude_backround[b : b + batch_size, :])

        out1 = out1_backround - out1_mask
        out2 = out2_backround - out2_mask
        score_map1.append(measure(out1, out2_condition))  # or torch.mm(out2_condition, out1.t())[0]
        score_map2.append(measure(out1_condition, out2))  # or torch.mm(out1_condition, out2.t())[0]

    score_map1 = torch.cat(score_map1, dim = 0)
    score_map2 = torch.cat(score_map2, dim = 0)
    assert images1_occlude_mask.shape[0] == images2_occlude_mask.shape[0] == score_map2.shape[0] == score_map1.shape[0]

    heatmap1 = score_map1.view(output_size, output_size).cpu().detach().numpy()
    heatmap2 = score_map2.view(output_size, output_size).cpu().detach().numpy()

    heatmap1 = (heatmap1 - heatmap1.min()) / (heatmap1.max() - heatmap1.min())
    heatmap2 = (heatmap2 - heatmap2.min()) / (heatmap2.max() - heatmap2.min())

    return heatmap1, heatmap2

def pairwise_occlusion(img1, img2, model, batch_size, erase_scale, erase_ratio, num_erases):
    # Randomly erase rectangles in both views and attribute the resulting similarity
    # drop to the erased pixels of the view with the smaller feature norm.
    measure = nn.CosineSimilarity(dim=-1)
    out1_condition, out2_condition = model(img1), model(img2)
    baseline = measure(out1_condition, out2_condition).detach()
    # a bit sensitive to scale and ratio. erase_scale is from (scale[0] * 100) % to (scale[1] * 100) %
    random_erase = transforms.RandomErasing(p=1.0, scale=erase_scale, ratio=erase_ratio)

    image1 = img1.clone().detach()
    image2 = img2.clone().detach()
    images1 = []
    images2 = []

    for _ in range(num_erases):
        images1.append(random_erase(image1))
        images2.append(random_erase(image2))

    images1 = torch.cat(images1, dim=0).to(device)
    images2 = torch.cat(images2, dim=0).to(device)

    sims = []
    weights1 = []
    weights2 = []

    for b in range(0, images2.shape[0], batch_size):
        with torch.no_grad():
            out1 = model(images1[b : b + batch_size, :])
            out2 = model(images2[b : b + batch_size, :])
            sims.append(measure(out1, out2))
            weights1.append(out1.norm(dim=-1))
            weights2.append(out2.norm(dim=-1))

    sims = torch.cat(sims, dim = 0)
    weights1, weights2 = torch.cat(weights1, dim = 0).cpu().numpy(), torch.cat(weights2, dim = 0).cpu().numpy()
    weights = list(zip(weights1, weights2))
    sims = baseline - sims  # the higher the drop, the better
    sims = F.softmax(sims, dim = -1)
    sims = sims.cpu().numpy()

    assert sims.shape[0] == images1.shape[0] == images2.shape[0]
    A1 = np.zeros((224, 224))
    A2 = np.zeros((224, 224))

    for n in range(images1.shape[0]):
        im1_2d = images1[n].cpu().numpy().transpose((1, 2, 0)).sum(axis=-1)
        im2_2d = images2[n].cpu().numpy().transpose((1, 2, 0)).sum(axis=-1)

        joint_similarity = sims[n]
        weight = weights[n]

        if weight[0] < weight[1]:
            A1[im1_2d == 0] += joint_similarity
        else:
            A2[im2_2d == 0] += joint_similarity

    A1 = A1 / (np.max(A1) + 1e-9)
    A2 = A2 / (np.max(A2) + 1e-9)

    return A1, A2

def tv_reg(img, l1 = True):
    # Total-variation regularizer (L1 or squared L2) over an image batch.
    diff_i = (img[:, :, :, 1:] - img[:, :, :, :-1])
    diff_j = (img[:, :, 1:, :] - img[:, :, :-1, :])

    if l1:
        return diff_i.abs().sum() + diff_j.abs().sum()
    else:
        return diff_i.pow(2).sum() + diff_j.pow(2).sum()

def synthesize(ssl_model, model_type, img1, img_cls_layer, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network):
    # Feature inversion: optimize a random image so that its activations at a chosen
    # layer match those of img1, for either an ImageNet-supervised or the SSL encoder.
    if model_type == 'imagenet':
        reduce_lr = False
        model = torchvision.models.resnet50(pretrained=True)
        model = list(model.children())[:img_cls_layer]
        model = nn.Sequential(*model).to(device)
        model.eval()
    else:
        reduce_lr = True
        shift_layer = 3 if network == 'simclrv2' else 0
        equivalent_layer = img_cls_layer - shift_layer
        model = list(ssl_model.encoder.net.children())[:equivalent_layer]
        model = nn.Sequential(*model).to(device)
        model.eval()

    opt_img = (init_scale * torch.randn(1, 3, 224, 224)).to(device).requires_grad_()
    target_feats = model(img1).detach()
    optimizer = torch.optim.SGD([opt_img], lr=lr, momentum=0.9)

    for i in range(201):
        opt_img.data = opt_img.data.clip(0, 1)
        optimizer.zero_grad()
        output = model(opt_img)
        l2_loss = l2_weight * ((output - target_feats) ** 2).sum() / (target_feats ** 2).sum()
        reg_alpha = alpha_weight * (opt_img ** alpha_power).sum()
        reg_total_variation = tv_weight * tv_reg(opt_img, l1 = False)
        loss = l2_loss + reg_alpha + reg_total_variation
        loss.backward()
        optimizer.step()

        if reduce_lr and i % 40 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 1/10

    return opt_img

def get_difference(ssl_model, baseline, image, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network):
    # Invert features at layers 4-6 for both the baseline model and the SSL encoder.
    imagenet_images = []
    ssl_images = []

    for lay in range(4, 7):
        image_net_image = synthesize(ssl_model, baseline, image, lay, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network).detach().clone()
        ssl_image = synthesize(ssl_model, 'ssl', image, lay, lr, l2_weight, alpha_weight, alpha_power, tv_weight, init_scale, network).detach().clone()
        imagenet_images.append(image_net_image)
        ssl_images.append(ssl_image)

    return imagenet_images, ssl_images

def create_mixed_images(transform_type, ig_transforms, step, img_path, add_noise):
    # Build an interpolation path from the 'pure' view to a transformed view of the
    # same image, optionally adding noise to the interior points of the path.
    img = Image.open(img_path).convert('RGB')
    img1 = ig_transforms['pure'](img).unsqueeze(0).to(device)
    img2 = ig_transforms[transform_type](img).unsqueeze(0).to(device)

    lambdas = np.arange(1, 0, -step)
    mixed_images = []
    for l, lam in enumerate(lambdas):
        mixed_img = lam * img1 + (1 - lam) * img2
        mixed_images.append(mixed_img)

    if add_noise:
        sigma = 0.15 / (torch.max(img1) - torch.min(img1)).item()
        mixed_images = [im + torch.zeros_like(im).normal_(0, sigma) if (n > 0) and (n < len(mixed_images) - 1) else im for n, im in enumerate(mixed_images)]

    return mixed_images

def averaged_transforms(guided, ssl_model, mixed_images, blur_output):
    # Average the input gradients of the similarity score along the interpolation path.
    measure = nn.CosineSimilarity(dim=-1)

    if guided:
        handles = []
        for i, module in enumerate(ssl_model.modules()):
            if isinstance(module, nn.ReLU):
                handles.append(module.register_backward_hook(relu_hook_function))

    grads1 = []
    grads2 = []

    for xbar_image in mixed_images[1:]:
        input_image1 = mixed_images[0].clone().requires_grad_()
        input_image2 = xbar_image.clone().requires_grad_()

        if input_image1.grad is not None:
            input_image1.grad.data.zero_()
            input_image2.grad.data.zero_()

        score = measure(ssl_model(input_image1), ssl_model(input_image2))
        score.backward()
        grads1.append(input_image1.grad.data)
        grads2.append(input_image2.grad.data)

    grads1 = torch.cat(grads1).mean(0).unsqueeze(0)
    grads2 = torch.cat(grads2).mean(0).unsqueeze(0)

    sailency1, _ = torch.max((mixed_images[0] * grads1).abs(), dim=1)
    sailency2, _ = torch.max((mixed_images[-1] * grads2).abs(), dim=1)

    if guided:  # remove handles after finishing
        for handle in handles:
            handle.remove()

    if blur_output:
        sailency1 = blur_sailency(sailency1)
        sailency2 = blur_sailency(sailency2)

    return sailency1, sailency2

def sailency(guided, ssl_model, img1, img2, blur_output):
    # Plain input-gradient saliency of the similarity between the two views.
    measure = nn.CosineSimilarity(dim=-1)

    if guided:
        handles = []
        for i, module in enumerate(ssl_model.modules()):
            if isinstance(module, nn.ReLU):
                handles.append(module.register_backward_hook(relu_hook_function))

    input_image1 = img1.clone().requires_grad_()
    input_image2 = img2.clone().requires_grad_()
    score = measure(ssl_model(input_image1), ssl_model(input_image2))
    score.backward()
    grads1 = input_image1.grad.data
    grads2 = input_image2.grad.data
    sailency1, _ = torch.max((img1 * grads1).abs(), dim=1)
    sailency2, _ = torch.max((img2 * grads2).abs(), dim=1)

    if guided:  # remove handles after finishing
        for handle in handles:
            handle.remove()

    if blur_output:
        sailency1 = blur_sailency(sailency1)
        sailency2 = blur_sailency(sailency2)

    return sailency1, sailency2

def smooth_grad(guided, ssl_model, img1, img2, blur_output, steps = 50):
    # SmoothGrad: average input gradients over several noisy copies of both views.
    measure = nn.CosineSimilarity(dim=-1)
    sigma = 0.15 / (torch.max(img1) - torch.min(img1)).item()

    if guided:
        handles = []
        for i, module in enumerate(ssl_model.modules()):
            if isinstance(module, nn.ReLU):
                handles.append(module.register_backward_hook(relu_hook_function))

    noise_images1 = []
    noise_images2 = []

    for _ in range(steps):
        noise = torch.zeros_like(img1).normal_(0, sigma)
        noise_images1.append(img1 + noise)
        noise_images2.append(img2 + noise)

    grads1 = []
    grads2 = []

    for n1, n2 in zip(noise_images1, noise_images2):
        input_image1 = n1.clone().requires_grad_()
        input_image2 = n2.clone().requires_grad_()

        if input_image1.grad is not None:
            input_image1.grad.data.zero_()
            input_image2.grad.data.zero_()

        score = measure(ssl_model(input_image1), ssl_model(input_image2))
        score.backward()
        grads1.append(input_image1.grad.data)
        grads2.append(input_image2.grad.data)

    grads1 = torch.cat(grads1).mean(0).unsqueeze(0)
    grads2 = torch.cat(grads2).mean(0).unsqueeze(0)
    sailency1, _ = torch.max((img1 * grads1).abs(), dim=1)
    sailency2, _ = torch.max((img2 * grads2).abs(), dim=1)

    if guided:  # remove handles after finishing
        for handle in handles:
            handle.remove()

    if blur_output:
        sailency1 = blur_sailency(sailency1)
        sailency2 = blur_sailency(sailency2)

    return sailency1, sailency2

def get_sample_dataset(img_path, num_augments, batch_size, no_shift_transforms, ssl_model, n_components):
    # Build (pure view, augmented view, similarity label) samples for one image;
    # NMF on the augmented-view features provides an additional invariance target.
    measure = nn.CosineSimilarity(dim=-1)
    img = Image.open(img_path).convert('RGB')
    no_shift_aug = transforms.Compose([no_shift_transforms['aug'],
                                       transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3))])

    augments2 = [no_shift_aug(img).unsqueeze(0) for _ in range(num_augments)]
    data_samples1 = no_shift_transforms['pure'](img).unsqueeze(0).expand(num_augments, -1, -1, -1).to(device)
    data_samples2 = torch.cat(augments2).to(device)

    labels = []
    feats_invariance = []

    for b in range(0, data_samples1.shape[0], batch_size):
        with torch.no_grad():
            out1 = ssl_model(data_samples1[b : b + batch_size, :])
            out2 = ssl_model(data_samples2[b : b + batch_size, :])
            labels.append(measure(out1, out2))
            feats_invariance.append(F.relu(out2))

    data_labels = torch.cat(labels).unsqueeze(-1).to(device)
    feats_invariance = torch.cat(feats_invariance).to(device)
    nmf_model = NMF(n_components=n_components, init='random')
    # (T, 2048) = W.H = (2048, N) . (N, T), where H is the matrix representing the features of each transform
    H = nmf_model.fit_transform(feats_invariance.cpu().numpy())
    labels_invariance = torch.from_numpy(H.mean(1)).unsqueeze(-1).to(device)

    return data_samples1, data_samples2, data_labels, labels_invariance

def pixel_invariance(data_samples1, data_samples2, data_labels, labels_invariance, resize_transform, size, epochs, learning_rate, l1_weight, zero_small_values, blur_output, nmf_weight):
    """
    size: resize the image to that when training the surrogate. Later we upsize
    epochs: number of epochs to train the surrogate model
    learning_rate: learning rate to train the surrogate model
    l1_weight: if not None, enables l1 regularization (sparsity)
    """
    x1 = resize_transform((size, size))(data_samples1)  # (num_samples, 3, size, size)
    x2 = resize_transform((size, size))(data_samples2)  # (num_samples, 3, size, size)

    x1 = x1.reshape(x1.size(0), -1).to(device)
    x2 = x2.reshape(x2.size(0), -1).to(device)

    surrogate = nn.Linear(size * size * 3, 1).to(device)

    criterion = nn.BCEWithLogitsLoss(reduction = 'sum')
    invariance_criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(surrogate.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        pred1, pred2 = surrogate(x1), surrogate(x2)
        preds = (pred1 + pred2) / 2
        loss = criterion(preds, data_labels)
        loss += nmf_weight * invariance_criterion(torch.sigmoid(preds), labels_invariance)

        if l1_weight is not None:
            loss += l1_weight * sum(p.abs().sum() for p in surrogate.parameters())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    heatmap = surrogate.weight.reshape(3, size, size)
    heatmap, _ = torch.max(heatmap, 0)
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

    if zero_small_values:
        heatmap[heatmap < 0.5] = 0

    if blur_output:
        heatmap = blur_sailency(heatmap.unsqueeze(0)).squeeze(0)

    return heatmap

class GradCAM(nn.Module):
    # Grad-CAM wrapper: caches the encoder feature maps of both views and their
    # gradients with respect to the cosine similarity of the projected embeddings.

    def __init__(self, ssl_model):
        super(GradCAM, self).__init__()

        self.gradients = {}
        self.features = {}

        self.feature_extractor = ssl_model.encoder.net
        self.contrastive_head = ssl_model.contrastive_head
        self.measure = nn.CosineSimilarity(dim=-1)

    def save_grads(self, img_index):

        def hook(grad):
            self.gradients[img_index] = grad.detach()

        return hook

    def save_features(self, img_index, feats):
        self.features[img_index] = feats.detach()

    def forward(self, img1, img2):

        features1 = self.feature_extractor(img1)
        features2 = self.feature_extractor(img2)

        self.save_features('1', features1)
        self.save_features('2', features2)

        h1 = features1.register_hook(self.save_grads('1'))
        h2 = features2.register_hook(self.save_grads('2'))

        out1, out2 = features1.mean(dim=[2, 3]), features2.mean(dim=[2, 3])
        out1, out2 = self.contrastive_head(out1), self.contrastive_head(out2)
        score = self.measure(out1, out2)

        return score

def weight_activation(feats, grads):
    # Weight feature maps by their (positive) gradients and sum over channels.
    cam = feats * F.relu(grads)
    cam = torch.sum(cam, dim=1).squeeze().cpu().detach().numpy()
    return cam

def get_gradcam(ssl_model, img1, img2):

    grad_cam = GradCAM(ssl_model).to(device)
    score = grad_cam(img1, img2)
    grad_cam.zero_grad()
    score.backward()

    cam1 = weight_activation(grad_cam.features['1'], grad_cam.gradients['1'])
    cam2 = weight_activation(grad_cam.features['2'], grad_cam.gradients['2'])
    return cam1, cam2

def get_interactioncam(ssl_model, img1, img2, reduction, grad_interact = False):
    # Interaction-CAM: reweight each view's feature maps by a joint weight computed
    # from both views (mean, max, or attention pooling) before applying Grad-CAM.
    grad_cam = GradCAM(ssl_model).to(device)
    score = grad_cam(img1, img2)
    grad_cam.zero_grad()
    score.backward()

    G1 = grad_cam.gradients['1']
    G2 = grad_cam.gradients['2']

    if grad_interact:
        B, D, H, W = G1.size()
        G1_ = G1.permute(0, 2, 3, 1).view(B, H * W, D)
        G2_ = G2.permute(0, 2, 3, 1).view(B, H * W, D)
        G_ = torch.bmm(G1_.permute(0, 2, 1), G2_)  # (B, D, D)
        G1, _ = torch.max(G_, dim = -1)  # (B, D)
        G2, _ = torch.max(G_, dim = 1)  # (B, D)
        G1 = G1.unsqueeze(-1).unsqueeze(-1)
        G2 = G2.unsqueeze(-1).unsqueeze(-1)

    if reduction == 'mean':
        joint_weight = grad_cam.features['1'].mean([2, 3]) * grad_cam.features['2'].mean([2, 3])
    elif reduction == 'max':
        max_pooled1 = F.max_pool2d(grad_cam.features['1'], kernel_size=grad_cam.features['1'].size()[2:]).squeeze(-1).squeeze(-1)
        max_pooled2 = F.max_pool2d(grad_cam.features['2'], kernel_size=grad_cam.features['2'].size()[2:]).squeeze(-1).squeeze(-1)
        joint_weight = max_pooled1 * max_pooled2
    else:
        B, D, H, W = grad_cam.features['1'].size()
        reshaped1 = grad_cam.features['1'].permute(0, 2, 3, 1).reshape(B, H * W, D)
        reshaped2 = grad_cam.features['2'].permute(0, 2, 3, 1).reshape(B, H * W, D)
        features1_query, features2_query = reshaped1.mean(1).unsqueeze(1), reshaped2.mean(1).unsqueeze(1)
        attn1 = (features1_query @ reshaped1.transpose(-2, -1)).softmax(dim=-1)
        attn2 = (features2_query @ reshaped2.transpose(-2, -1)).softmax(dim=-1)
        att_reduced1 = (attn1 @ reshaped1).squeeze(1)
        att_reduced2 = (attn2 @ reshaped2).squeeze(1)
        joint_weight = att_reduced1 * att_reduced2

    joint_weight = joint_weight.unsqueeze(-1).unsqueeze(-1).expand_as(grad_cam.features['1'])

    feats1 = grad_cam.features['1'] * joint_weight
    feats2 = grad_cam.features['2'] * joint_weight

    cam1 = weight_activation(feats1, G1)
    cam2 = weight_activation(feats2, G2)

    return cam1, cam2
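Usage note (not part of the uploaded file): a sketch of how these methods appear intended to be chained, assuming get_ssl_model from utils.py returns an encoder that maps a (1, 3, 224, 224) batch to an embedding, and that 'dog.jpg' is a placeholder path; the parameter values are illustrative only.

# Hedged usage sketch for methods.py
from data_transforms import ig_transforms
from methods import device, create_mixed_images, averaged_transforms, occlusion, get_gradcam, get_interactioncam
from utils import get_ssl_model

ssl_model = get_ssl_model('simclrv2', '1x').to(device)

# Interpolate between a clean view and a blurred view, then average gradients along the path.
mixed = create_mixed_images('blur', ig_transforms, step=0.1, img_path='dog.jpg', add_noise=False)
sal1, sal2 = averaged_transforms(guided=True, ssl_model=ssl_model, mixed_images=mixed, blur_output=True)

# Occlusion heatmaps and (Interaction-)Grad-CAM for the two endpoint views.
img1, img2 = mixed[0], mixed[-1]
heat1, heat2 = occlusion(img1, img2, ssl_model, w_size=64, stride=8, batch_size=32)
cam1, cam2 = get_gradcam(ssl_model, img1, img2)
icam1, icam2 = get_interactioncam(ssl_model, img1, img2, reduction='mean', grad_interact=False)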
utils.py
ADDED
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import random
import cv2
import io
from ssl_models.simclr2 import get_simclr2_model
from ssl_models.barlow_twins import get_barlow_twins_model
from ssl_models.simsiam import get_simsiam
from ssl_models.dino import get_dino_model_without_loss, get_dino_model_with_loss

def get_ssl_model(network, variant):

    if network == 'simclrv2':
        if variant == '1x':
            ssl_model = get_simclr2_model('r50_1x_sk0_ema.pth').eval()
        else:
            ssl_model = get_simclr2_model('r50_2x_sk0_ema.pth').eval()
    elif network == 'barlow_twins':
        ssl_model = get_barlow_twins_model().eval()
    elif network == 'simsiam':
        ssl_model = get_simsiam().eval()
    elif network == 'dino':
        ssl_model = get_dino_model_without_loss().eval()
    elif network == 'dino+loss':
        ssl_model, dino_score = get_dino_model_with_loss()
        ssl_model = ssl_model.eval()

    return ssl_model

def overlay_heatmap(img, heatmap, denormalize = False):
    loaded_img = img.squeeze(0).cpu().numpy().transpose((1, 2, 0))

    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        loaded_img = std * loaded_img + mean

    loaded_img = (loaded_img.clip(0, 1) * 255).astype(np.uint8)
    cam = heatmap / heatmap.max()
    cam = cv2.resize(cam, (224, 224))
    cam = np.uint8(255 * cam)
    cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET)  # jet: blue --> red
    cam = cv2.cvtColor(cam, cv2.COLOR_BGR2RGB)
    added_image = cv2.addWeighted(cam, 0.5, loaded_img, 0.5, 0)
    return added_image

def viz_map(img_path, heatmap):
    "For pixel invariance"
    img = np.array(Image.open(img_path).resize((224, 224)))
    width, height, _ = img.shape
    cam = heatmap.detach().cpu().numpy()
    cam = cam / cam.max()
    cam = cv2.resize(cam, (height, width))
    heatmap = np.uint8(255 * cam)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    added_image = cv2.addWeighted(heatmap, 0.5, img, 0.7, 0)
    return added_image

def show_image(x, squeeze = True, denormalize = False):

    if squeeze:
        x = x.squeeze(0)

    x = x.cpu().numpy().transpose((1, 2, 0))

    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        x = std * x + mean

    return x.clip(0, 1)

def deprocess(inp, to_numpy = True, to_PIL = False, denormalize = False):

    if to_numpy:
        inp = inp.detach().cpu().numpy()

    inp = inp.squeeze(0).transpose((1, 2, 0))

    if denormalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean

    inp = (inp.clip(0, 1) * 255).astype(np.uint8)

    if to_PIL:
        return Image.fromarray(inp)
    return inp

def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it"""
    buf = io.BytesIO()
    fig.savefig(buf, bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    img = Image.open(buf)
    return img
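Usage note (not part of the uploaded file): a sketch of overlaying a heatmap from methods.py on its input image and exporting the figure. The matplotlib calls are illustrative, and img1 / cam1 refer to the methods.py sketch above.

# Hedged usage sketch for utils.py
import matplotlib.pyplot as plt
from utils import overlay_heatmap, show_image, fig2img

overlaid = overlay_heatmap(img1, cam1, denormalize=False)   # img1, cam1 from the methods.py sketch

fig, axes = plt.subplots(1, 2)
axes[0].imshow(show_image(img1))   # original view, clipped to [0, 1]
axes[1].imshow(overlaid)           # heatmap blended over the image
for ax in axes:
    ax.axis('off')
pil_image = fig2img(fig)           # convert the Matplotlib figure to a PIL image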