|
|
|
|
|
|
|
|
|
|
|
import models |
|
import timm |
|
import os |
|
import torch |
|
import argparse |
|
import cv2 |
|
import numpy as np |
|
import torch.nn.functional as F |
|
import torchvision.transforms.functional as TransF |
|
from torchvision import transforms |
|
from einops import rearrange |
|
import random |
|
from timm.models import load_checkpoint |
|
from torchvision.utils import draw_segmentation_masks |
|
|
|
# Human-readable class names read from "imagenet1k_id_to_label.txt", in file order.
# Each non-blank line is expected to look like "<id>:<label>"; only the text after
# the FIRST ':' is kept (whitespace right after the colon is preserved, because
# only the whole line is stripped).
object_categories = []

with open("imagenet1k_id_to_label.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines instead of failing the tuple unpacking below.
            continue
        # maxsplit=1 so a label that itself contains ':' does not break unpacking.
        _, val = line.split(":", 1)
        object_categories.append(val)
|
|
|
|
|
class PredictionArgs:
    """Plain container for the settings read by ``generate_visualization``."""

    def __init__(self,
                 model,
                 checkpoint,
                 image,
                 shape=224,
                 stage=0,
                 block=0,
                 head=1,
                 resize_img=False,
                 alpha=0.5):
        """
        Bundle all arguments required for model prediction and visualization.

        Args:
            model: `str` name of a model registered with timm, e.g. 'coc_tiny',
                'coc_small', 'coc_medium'. Checked against `timm.list_models()`.
            checkpoint: `str` path of the model checkpoint to load.
            image: the image itself as an `np.ndarray` (NOT a path). It is later
                resized with `cv2.resize` and has its last axis reversed, so a
                channel-last 3-channel array is expected.
            shape: `int` side length of the square input image.
                NOTE(review): `generate_visualization` does not forward this
                value; the visible preprocessing resizes to 224.
            stage: `int` index of the visualized stage, 0-3 (looked up as
                `model.network[stage * 2]`).
            block: `int` index of the visualized block inside that stage. It is
                used directly as an index, so -1 selects the last block.
            head: `int` index of the visualized head, e.g. 0-3 or 0-7 depending
                on the model.
            resize_img: `bool` whether to resize the image to the feature-map size.
                NOTE(review): not read anywhere in the visible code.
            alpha: `float` transparency of the drawn masks, 0-1.

        Raises:
            AssertionError: if `model` is not one of `timm.list_models()`.
        """
        self.model = model
        self.checkpoint = checkpoint
        self.image = image
        self.shape = shape
        self.stage = stage
        self.block = block
        self.head = head
        self.resize_img = resize_img
        self.alpha = alpha

        assert self.model in timm.list_models(), "Please use a timm pre-trined model, see timm.list_models()"
|
|
|
|
|
def _preprocess(raw_image, size=224):
    """Resize an image and convert it into a normalized model-input tensor.

    Args:
        raw_image: channel-last 3-channel NumPy image. Its channel axis is
            reversed before normalization. NOTE(review): that flip presumably
            turns OpenCV-style BGR into RGB; confirm the caller's channel order.
        size: side length of the square resize. Defaults to 224, the value that
            used to be hard-coded here.

    Returns:
        ``(image, resized)``:
        - ``image``: a ``3 x size x size`` float tensor, channel-reversed and
          normalized with the usual ImageNet mean/std.
        - ``resized``: the resized NumPy array in the ORIGINAL channel order
          (not flipped).
    """
    resized = cv2.resize(raw_image, (size, size))

    to_model_input = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # `[..., ::-1]` is a negative-stride view, which torch cannot wrap directly;
    # `.copy()` materializes a contiguous array for ToTensor.
    image = to_model_input(resized[..., ::-1].copy())
    return image, resized
|
|
|
|
|
def pairwise_cos_sim(x1: torch.Tensor, x2: torch.Tensor):
    """
    Return the pair-wise cosine-similarity matrix between two sets of vectors.

    Vectors are L2-normalized along the last dim with `F.normalize`, whose eps
    makes an all-zero vector yield similarity 0 rather than NaN.

    :param x1: [..., M, D] (e.g. [B, M, D])
    :param x2: [..., N, D] (e.g. [B, N, D]); leading dims must broadcast with x1's
    :return: similarity matrix [..., M, N]; entry (m, n) is cos(x1[..., m, :], x2[..., n, :])
    """
    x1 = F.normalize(x1, dim=-1)
    x2 = F.normalize(x2, dim=-1)
    # Swapping only the last two dims works for any number of leading batch dims
    # (and is identical to permute(0, 2, 1) for 3-D inputs).
    return torch.matmul(x1, x2.transpose(-2, -1))
|
|
|
|
|
|
|
def get_attention_score(self, input, output):
    """Forward hook that records a token mixer's point-to-center assignment.

    Registered with ``register_forward_hook`` on a block's ``token_mixer``, so
    ``self`` is that module, ``input`` is the tuple of positional inputs it
    received and ``output`` (unused) is what it returned. The hook recomputes
    the similarity between feature points and proposed centers (presumably
    mirroring the module's own forward — confirm against the model code). It
    then stores a one-hot "most similar center" mask in the module-level
    global ``attention``.

    The stored tensor has shape
    ``(fold_w * fold_h * num_centers, heads, fold_w * w, fold_h * h)``. It holds
    one mask per (fold region, center) pair. Each mask covers the full feature
    map but can only be non-zero inside its own fold region.

    NOTE(review): assumes a batch size of 1. The ``h0=self.heads`` rearrange
    below only matches when the leading dim is heads * fold regions.
    """
    global attention

    x = self.f(input[0])
    x = rearrange(x, "b (e c) w h -> (b e) c w h", e=self.heads)

    # The feature map is only split into fold regions when BOTH factors exceed 1;
    # otherwise it is a single region, so use fold factors of 1 from here on.
    if self.fold_w > 1 and self.fold_h > 1:
        fold_w, fold_h = self.fold_w, self.fold_h
        _, _, w0, h0 = x.shape
        assert w0 % fold_w == 0 and h0 % fold_h == 0, \
            f"Ensure the feature map size ({w0}*{h0}) can be divided by fold {self.fold_w}*{self.fold_h}"
        x = rearrange(x, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=fold_w, f2=fold_h)
    else:
        fold_w, fold_h = 1, 1

    # w, h are the spatial size of a single fold region.
    b, c, w, h = x.shape
    centers = self.centers_proposal(x)
    sim = torch.sigmoid(
        self.sim_beta
        + self.sim_alpha * pairwise_cos_sim(
            centers.reshape(b, c, -1).permute(0, 2, 1),
            x.reshape(b, c, -1).permute(0, 2, 1),
        )
    )  # (b, num_centers, w * h)

    # One-hot over the center axis: each point is marked only for its most
    # similar center.
    _, sim_max_idx = sim.max(dim=1, keepdim=True)
    mask = torch.zeros_like(sim)
    mask.scatter_(1, sim_max_idx, 1.)

    mask = mask.reshape(mask.shape[0], mask.shape[1], w, h)
    mask = rearrange(mask, "(h0 f1 f2) m w h -> h0 (f1 f2) m w h",
                     h0=self.heads, f1=fold_w, f2=fold_h)

    # Paste every (region, center) mask into a full-size canvas.
    mask_list = []
    for i in range(fold_w):
        for j in range(fold_h):
            # "(f1 f2)" flattens with f2 varying fastest, so region (i, j)
            # lives at i * fold_h + j.
            region = i * fold_h + j
            for k in range(mask.shape[2]):
                canvas = torch.zeros(self.heads, w * fold_w, h * fold_h)
                canvas[:, i * w:(i + 1) * w, j * h:(j + 1) * h] = mask[:, region, k, :, :]
                mask_list.append(canvas.unsqueeze(dim=0))

    attention = torch.concat(mask_list, dim=0).detach()
|
|
|
|
|
def generate_visualization(args):
    """Classify ``args.image`` and overlay one layer's cluster assignments on it.

    Steps:
    1. Build the timm model ``args.model`` and load ``args.checkpoint``.
    2. Attach ``get_attention_score`` as a forward hook on the token mixer of
       block ``args.block`` in stage ``args.stage``.
    3. Run a single inference pass. The hook publishes its masks through the
       module-level global ``attention``.
    4. Take the masks of head ``args.head``, resize them to the image,
       threshold them at 0.5 and draw them with transparency ``args.alpha``.

    Args:
        args: a ``PredictionArgs``-like object.

    Returns:
        ``(img_with_masks, possibility)``:
        - ``img_with_masks``: the overlay as an ``H x W x 3`` uint8 NumPy array.
        - ``possibility``: the top-1 softmax probability in percent, formatted
          with three decimals.

    Raises:
        ValueError: if ``args.checkpoint`` is empty or None.
    """
    image, raw_image = _preprocess(args.image)
    image = image.unsqueeze(dim=0)  # add a batch dimension

    # Fail before building (and possibly downloading) the model.
    if not args.checkpoint:
        raise ValueError("Checkpoint doesn't exist at specified path: {}".format(args.checkpoint))

    model = timm.create_model(model_name=args.model, pretrained=True)
    load_checkpoint(model, args.checkpoint, True)  # third positional arg is `use_ema` in timm
    print(f"\n\n==> Loaded checkpoint")
    model.eval()  # inference: disable dropout / drop-path style training behavior

    # Stages appear to sit at even indices of `model.network` — confirm.
    token_mixer = model.network[args.stage * 2][args.block].token_mixer
    hook = token_mixer.register_forward_hook(get_attention_score)
    try:
        with torch.no_grad():
            out = model(image)
    finally:
        hook.remove()

    if isinstance(out, tuple):
        out = out[0]
    possibility = "{:.3f}".format((torch.softmax(out, dim=1).max() * 100).item())

    # NOTE(review): `_preprocess` reverses the channel order for the model but
    # returns the un-flipped resized image, which is what gets drawn here.
    img = torch.tensor(raw_image).permute(2, 0, 1)

    # `attention` is the global written by get_attention_score during the
    # forward pass above; select one head and resize to the image.
    head_masks = attention[:, args.head, :, :].unsqueeze(dim=0)
    head_masks = F.interpolate(head_masks, (img.shape[-2], img.shape[-1]))
    mask = head_masks.squeeze(dim=0) > 0.5

    colors = ["brown", "green", "deepskyblue", "blue", "darkgreen", "darkcyan", "coral", "aliceblue",
              "white", "black", "beige", "red", "tomato", "yellowgreen", "violet", "mediumseagreen"]
    num_masks = mask.shape[0]
    if num_masks == 4:
        colors = colors[0:4]
    elif num_masks > 4:
        # Repeat the palette enough times that every mask gets a color
        # (identical to `num_masks // 16` repeats when num_masks is a multiple of 16).
        colors = colors * ((num_masks + len(colors) - 1) // len(colors))
    # A private generator seeded with 123 gives the same permutation as
    # `random.seed(123); random.shuffle(...)` without resetting the global RNG.
    random.Random(123).shuffle(colors)

    img_with_masks = draw_segmentation_masks(img, masks=mask, alpha=args.alpha, colors=colors)
    img_with_masks = np.asarray(TransF.to_pil_image(img_with_masks.detach()))
    return img_with_masks, possibility
|
|