COC-VIZ / cluster_visualize.py
CurHarsh's picture
Update cluster_visualize.py
9822ff6
# --------------------------------------------------------
# This file is a modified version of https://github.com/ma-xu/Context-Cluster/blob/main/cluster_visualize.py
# It is modified in order to make it compatible with Gradio.
# --------------------------------------------------------
import models
import timm
import os
import torch
import argparse
import cv2
import numpy as np
import torch.nn.functional as F
import torchvision.transforms.functional as TransF
from torchvision import transforms
from einops import rearrange
import random
from timm.models import load_checkpoint
from torchvision.utils import draw_segmentation_masks
# Class labels read from "imagenet1k_id_to_label.txt", one "id:label" entry per line,
# kept in file order (presumably ordered by class id -- verify against the label file).
object_categories = []
with open("imagenet1k_id_to_label.txt", "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank/trailing lines instead of crashing the unpack below.
            continue
        # Split on the first ':' only, so a label that itself contains ':' is kept whole.
        _, val = line.split(":", 1)
        object_categories.append(val)
class PredictionArgs:
    def __init__(self,
                 model,
                 checkpoint,
                 image,
                 shape=224,
                 stage=0,
                 block=0,
                 head=1,
                 resize_img=False,
                 alpha=0.5):
        """
        Container for all the arguments required for model prediction/visualization.

        Args:
            model: `str` name of a model registered with timm,
                e.g. 'coc_tiny', 'coc_small', 'coc_medium'.
            checkpoint: `str` path of the model checkpoint.
            image: `np.ndarray` holding the input image itself (H x W x C), not a path.
            shape: `int` side length of the square input image.
            stage: `int` index of the visualized stage, 0-3.
            block: `int` index of the visualized block within that stage; -1 selects the last block.
            head: `int` index of the visualized head, 0-3 or 0-7 depending on the model.
            resize_img: `bool` whether to resize the image to the feature-map size.
            alpha: `float` transparency of the drawn masks, 0-1.

        NOTE(review): `shape` and `resize_img` are stored but not read by
        `generate_visualization` in this file.

        Raises:
            AssertionError: if `model` is not in `timm.list_models()`.
        """
        self.model = model
        self.checkpoint = checkpoint
        self.image = image
        self.shape = shape
        self.stage = stage
        self.block = block
        self.head = head
        self.resize_img = resize_img
        self.alpha = alpha
        assert self.model in timm.list_models(), "Please use a timm pre-trained model, see timm.list_models()"
# Preprocessing
def _preprocess(raw_image, shape=224):
    """Resize an image and turn it into a normalized model-input tensor.

    Args:
        raw_image: image array laid out as (H, W, C). The channel axis is reversed
            before normalization, which matches cv2's BGR convention.
            NOTE(review): if the Gradio caller passes RGB arrays, this reversal
            feeds BGR to the model -- confirm the caller's channel order.
        shape: side length of the square resize target. Defaults to 224, the
            previously hard-coded size.

    Returns:
        (image, raw_image): `image` is a float tensor [C, shape, shape] normalized
        with the ImageNet mean/std; `raw_image` is the resized array in its
        original channel order.
    """
    raw_image = cv2.resize(raw_image, (shape, shape))
    to_tensor = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Reverse the channel axis; .copy() removes the negative stride that torch
    # cannot wrap directly.
    image = to_tensor(raw_image[..., ::-1].copy())
    return image, raw_image
def pairwise_cos_sim(x1: torch.Tensor, x2: torch.Tensor):
    """
    Cosine similarity between every row of x1 and every row of x2, batch by batch.
    :param x1: [B,M,D]
    :param x2: [B,N,D]
    :return: similarity matrix [B,M,N]
    """
    # After L2-normalizing the feature axis, a plain dot product is the cosine.
    lhs_unit = F.normalize(x1, dim=-1)              # [B,M,D]
    rhs_unit_t = F.normalize(x2, dim=-1).permute(0, 2, 1)  # [B,D,N]
    return torch.matmul(lhs_unit, rhs_unit_t)
# forward hook function
def get_attention_score(self, input, output):
    """Forward hook recording which center each point of a cluster block is assigned to.

    Meant to be attached with `module.register_forward_hook`. It recomputes the
    point/center similarity on the hooked input and stores the hard assignment in
    the module-level global `attention`. That is a detached float tensor of shape
    [fold_w * fold_h * M, heads, fold_w * w, fold_h * h]: one full-size binary map
    per (fold region, center) pair, zero outside that region.

    Args:
        self: the hooked module. It must provide `f`, `heads`, `fold_w`, `fold_h`,
            `centers_proposal`, `sim_alpha` and `sim_beta`.
        input: tuple of the module's positional inputs; `input[0]` is the feature map.
        output: the module's output (unused).
    """
    global attention
    x = self.f(input[0])  # input tensor in a tuple
    x = rearrange(x, "b (e c) w h -> (b e) c w h", e=self.heads)
    # The map is split into local regions only when BOTH fold factors exceed 1.
    # Otherwise it is a single 1x1 region, and the un-folding below must use 1x1
    # too (the old code used self.fold_* there and crashed in that case).
    if self.fold_w > 1 and self.fold_h > 1:
        fold_w, fold_h = self.fold_w, self.fold_h
        _, _, w0, h0 = x.shape
        assert w0 % fold_w == 0 and h0 % fold_h == 0, \
            f"Ensure the feature map size ({w0}*{h0}) can be divided by fold {self.fold_w}*{self.fold_h}"
        x = rearrange(x, "b c (f1 w) (f2 h) -> (b f1 f2) c w h", f1=fold_w, f2=fold_h)
    else:
        fold_w, fold_h = 1, 1
    _, _, w, h = x.shape
    centers = self.centers_proposal(x)  # [b,c,C_W,C_H]; M = C_W*C_H, N = w*h
    b, c, _, _ = centers.shape
    sim = torch.sigmoid(
        self.sim_beta
        + self.sim_alpha * pairwise_cos_sim(
            centers.reshape(b, c, -1).permute(0, 2, 1),
            x.reshape(b, c, -1).permute(0, 2, 1),
        )
    )  # [B,M,N]
    # Solely assign each point to its single most similar center (one-hot over M).
    _, sim_max_idx = sim.max(dim=1, keepdim=True)
    mask = torch.zeros_like(sim)
    mask.scatter_(1, sim_max_idx, 1.)  # binary, [B,M,N]
    mask = mask.reshape(mask.shape[0], mask.shape[1], w, h)  # [(head*fold*fold), M, w, h]
    mask = rearrange(mask, "(h0 f1 f2) m w h -> h0 (f1 f2) m w h",
                     h0=self.heads, f1=fold_w, f2=fold_h)  # [head, (fold*fold), M, w, h]
    attention = _expand_fold_masks(mask, fold_w, fold_h).detach()


def _expand_fold_masks(mask, fold_w, fold_h):
    """Paste every per-region center mask into a zero canvas the size of the full map.

    Args:
        mask: tensor [heads, fold_w * fold_h, M, w, h].
        fold_w: number of regions along the first spatial axis.
        fold_h: number of regions along the second spatial axis.

    Returns:
        Tensor [fold_w * fold_h * M, heads, fold_w * w, fold_h * h], ordered by
        region row, then region column, then center.
    """
    heads, _, num_centers, w, h = mask.shape
    mask_list = []
    for i in range(fold_w):
        for j in range(fold_h):
            # The "(f1 f2)" axis is flattened with f2 varying fastest, so region
            # (i, j) lives at i * fold_h + j. The old code used i * fold_w + j,
            # which is wrong whenever fold_w != fold_h.
            region = i * fold_h + j
            for k in range(num_centers):
                temp = torch.zeros(heads, w * fold_w, h * fold_h)
                temp[:, i * w:(i + 1) * w, j * h:(j + 1) * h] = mask[:, region, k, :, :]
                mask_list.append(temp.unsqueeze(dim=0))  # [1, heads, W, H]
    return torch.cat(mask_list, dim=0)  # [n, heads, W, H]
def generate_visualization(args):
    """Run the model on `args.image` and overlay one head's cluster assignment.

    Args:
        args: a `PredictionArgs`. Reads `image`, `model`, `checkpoint`, `stage`,
            `block`, `head` and `alpha`.

    Returns:
        (img_with_masks, possibility): the overlay as an `np.ndarray` laid out as
        (H, W, C), and the top-1 softmax probability as a percentage string with
        3 decimals.

    Raises:
        Exception: if `args.checkpoint` is empty or None.
    """
    global attention
    # Fail fast, before creating the model, when no checkpoint path was given.
    if not args.checkpoint:
        raise Exception("No checkpoint path was specified: {!r}".format(args.checkpoint))
    image, raw_image = _preprocess(args.image)
    image = image.unsqueeze(dim=0)
    model = timm.create_model(model_name=args.model, pretrained=True)
    load_checkpoint(model, args.checkpoint, True)
    print("\n\n==> Loaded checkpoint")
    # nn.Module starts in training mode; switch to inference behavior.
    model.eval()
    # NOTE(review): `stage * 2` presumably skips the layers interleaved between
    # stages in `model.network` -- verify against the model definition.
    hook = model.network[args.stage * 2][args.block].token_mixer.register_forward_hook(get_attention_score)
    try:
        with torch.no_grad():
            out = model(image)  # the hook sets the global `attention`
    finally:
        hook.remove()
    if isinstance(out, tuple):
        out = out[0]
    possibility = "{:.3f}".format((torch.softmax(out, dim=1).max() * 100).item())
    img = torch.tensor(raw_image).permute(2, 0, 1)  # HWC -> CHW
    # Keep only the selected head: [n, W, H].
    attention = attention[:, args.head, :, :]
    # Resize the maps to the image size (nearest-neighbor), then re-binarize.
    mask = F.interpolate(attention.unsqueeze(dim=0), (img.shape[-2], img.shape[-1])).squeeze(dim=0)
    mask = mask > 0.5
    img_with_masks = draw_segmentation_masks(
        img, masks=mask, alpha=args.alpha, colors=_mask_colors(mask.shape[0])
    )
    img_with_masks = TransF.to_pil_image(img_with_masks.detach())
    return np.asarray(img_with_masks), possibility


def _mask_colors(num_masks):
    """Return a deterministically shuffled list of color names covering `num_masks` masks."""
    # randomly selected some good colors.
    palette = ["brown", "green", "deepskyblue", "blue", "darkgreen", "darkcyan", "coral", "aliceblue",
               "white", "black", "beige", "red", "tomato", "yellowgreen", "violet", "mediumseagreen"]
    if num_masks == 4:
        colors = palette[0:4]
    elif num_masks > 4:
        # Ceil-repeat the palette so there is always at least one color per mask.
        # The old `num_masks // 16` gave 0 colors for 5-15 masks and too few for
        # non-multiples of 16.
        colors = palette * -(-num_masks // len(palette))
    else:
        colors = palette
    # A private seeded RNG gives the same order as random.seed(123) + random.shuffle
    # without resetting the global random state.
    random.Random(123).shuffle(colors)
    return colors