import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip.transformer import VisionTransformer

from .gem_utils import SelfSelfAttention, GEMResidualBlock, modified_vit_forward


class GEMWrapper(nn.Module):
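    """Wrap an open_clip CLIP model for GEM: the attention in the last
    `depth - 1` visual transformer blocks is replaced by self-self
    attention so the wrapper can produce per-prompt, patch-level
    image-text heat maps.
    """
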
    def __init__(self, model, tokenizer, depth=7, ss_attn_iter=1, ss_attn_temp=None):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.depth = depth
        self.ss_attn_iter = ss_attn_iter
        self.ss_attn_temp = ss_attn_temp
        self.patch_size = self.model.visual.patch_size[0]
        self.apply_gem()

    def apply_gem(self):
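        """Swap the attention of the last `depth - 1` residual blocks for
        SelfSelfAttention (reusing the original qkv/out_proj weights), wrap
        those blocks in GEMResidualBlock, and bind modified_vit_forward so
        the visual encoder returns both GEM and original features.
        """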
        for i in range(1, self.depth):
            # Extract info from the original ViT
            num_heads = self.model.visual.transformer.resblocks[-i].attn.num_heads
            dim = int(self.model.visual.transformer.resblocks[-i].attn.head_dim * num_heads)
            qkv_bias = True
            # Init the self-self attention layer
            ss_attn = SelfSelfAttention(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                        ss_attn_iter=self.ss_attn_iter, ss_attn_temp=self.ss_attn_temp)
            # Copy necessary weights
            ss_attn.qkv.weight.data = self.model.visual.transformer.resblocks[-i].attn.in_proj_weight.clone()
            ss_attn.qkv.bias.data = self.model.visual.transformer.resblocks[-i].attn.in_proj_bias.clone()
            ss_attn.proj.weight.data = self.model.visual.transformer.resblocks[-i].attn.out_proj.weight.clone()
            ss_attn.proj.bias.data = self.model.visual.transformer.resblocks[-i].attn.out_proj.bias.clone()
            # Swap the original Attention with our SelfSelfAttention
            self.model.visual.transformer.resblocks[-i].attn = ss_attn
            # Wrap Residual block to handle SelfSelfAttention outputs
            self.model.visual.transformer.resblocks[-i] = GEMResidualBlock(self.model.visual.transformer.resblocks[-i])
        # Modify ViT's forward function
        self.model.visual.forward = modified_vit_forward.__get__(self.model.visual, VisionTransformer)
        return

    def encode_text(self, text: list):
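        """Embed each class name with the template 'a photo of a {cls}.'
        and L2-normalize. Returns a tensor of shape [1, num_prompt, dim].
        """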
        prompts = [f'a photo of a {cls}.' for cls in text]
        tokenized_prompts = self.tokenizer(prompts).to(self.model.visual.proj.device)
        text_embedding = self.model.encode_text(tokenized_prompts)
        text_embedding = F.normalize(text_embedding, dim=-1)
        return text_embedding.unsqueeze(0)

    def min_max(self, logits):
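        """Min-max normalize each per-prompt heat map over its spatial
        dimensions so that values lie in [0, 1].
        """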
        B, num_prompt = logits.shape[:2]
        logits_min = logits.reshape(B, num_prompt, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
        logits_max = logits.reshape(B, num_prompt, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
        logits = (logits - logits_min) / (logits_max - logits_min)
        return logits

    def forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool = False):
        """
        :param image: torch.Tensor [1, 3, H, W]
        :param text: list[str] - class names / prompts
        :param normalize: bool - if True performs min-max normalization on the heat maps
        :param return_ori: bool - if True uses the features from the original visual encoder
        :return: torch.Tensor [1, num_prompt, H, W] image-text matching heat maps
        """
        # Image
        H, W = image.shape[-2:]
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [1, N, dim]

        # Text
        text_embeddings = self.encode_text(text)  # [1, num_prompt, dim]

        # Image-Text matching
        img_txt_matching = image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [1, N, num_prompt]
        img_txt_matching = rearrange(img_txt_matching, 'b (h w) c -> b c h w',
                                     h=H // self.patch_size, w=W // self.patch_size)  # [1, num_prompt, h, w]

        # Interpolate
        img_txt_matching = F.interpolate(img_txt_matching, size=(H, W), mode='bilinear')  # [1, num_prompt, H, W]

        # Heat Maps
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)
        return img_txt_matching

    def batched_forward(self, image: torch.Tensor, text: list, normalize: bool = True, return_ori: bool = False):
        """
        :param image: torch.Tensor [B, 3, H, W]
        :param text: list[list[str]] - one list of prompts per image
        :param normalize: bool - if True performs min-max normalization on the heat maps
        :param return_ori: bool - if True uses the features from the original visual encoder
        :return: list of B tensors, the i-th of shape [len(text[i]), H, W] with per-prompt heat maps
        """
        L = len(text)
        cumm_idx = np.cumsum([len(t) for t in text]).tolist()
        B, _, H, W = image.shape
        assert B == L, f'Number of prompt lists L: {L} should match the number of images B: {B}.'

        # Image
        feat_gem, feat_ori = self.model.visual(image)
        image_feat = feat_ori if return_ori else feat_gem
        image_feat = F.normalize(image_feat, dim=-1)  # [B, N, dim]

        # Text
        flatten_text = [t for sub_text in text for t in sub_text]
        text_embeddings = self.encode_text(flatten_text)  # [1, total_num_prompts, dim]

        # Image-Text matching
        img_txt_matching = 100 * image_feat[:, 1:] @ text_embeddings.transpose(-1, -2)  # [B, N, num_prompt]
        img_txt_matching = rearrange(img_txt_matching, 'b (h w) c -> b c h w',
                                     h=H // self.patch_size, w=W // self.patch_size)  # [B, num_prompt, h, w]

        # Interpolate
        img_txt_matching = F.interpolate(img_txt_matching, size=(H, W), mode='bilinear')  # [B, num_prompt, H, W]

        # Heat Maps
        if normalize:
            img_txt_matching = self.min_max(img_txt_matching)  # [B, num_prompt, H, W]

        # Unflatten: split along the prompt dimension and keep each image's own prompts
        img_txt_matching = torch.tensor_split(img_txt_matching, cumm_idx[:-1], dim=1)
        img_txt_matching = [itm[i] for i, itm in enumerate(img_txt_matching)]
        return img_txt_matching
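
# ---------------------------------------------------------------------------
# Minimal usage sketch: build an open_clip ViT, wrap it with GEMWrapper and
# compute per-prompt heat maps for a dummy image. The model name and
# pretrained tag below are assumptions chosen for illustration; any open_clip
# ViT checkpoint should work the same way.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import open_clip

    model_name = 'ViT-B-16'           # assumed architecture
    pretrained = 'laion2b_s34b_b88k'  # assumed pretrained tag
    clip_model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    tokenizer = open_clip.get_tokenizer(model_name)

    gem = GEMWrapper(model=clip_model, tokenizer=tokenizer, depth=7)
    gem.eval()

    # Random tensor standing in for a preprocessed image batch [1, 3, H, W].
    image = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        heatmaps = gem(image, text=['cat', 'dog'])  # [1, 2, 224, 224]
    print(heatmaps.shape)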