# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from mmpretrain.models.backbones import MixMIMTransformer
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor


@MODELS.register_module()
class MixMIMPretrainTransformer(MixMIMTransformer):
    """MixMIM backbone for MixMIM pre-training.

    A PyTorch implementation of `MixMIM: Mixed and Masked Image
    Modeling for Efficient Visual Representation Learning
    <https://arxiv.org/abs/2205.13137>`_.

    Args:
        arch (str | dict): MixMIM architecture. If a string, choose from
            'base', 'large' and 'huge'. If a dict, it should have the
            following keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (List[int]): The number of transformer blocks in
              each of the four stages.
            - **num_heads** (List[int]): The number of attention heads in
              each of the four stages.

            Defaults to 'base'.
        mlp_ratio (float): The MLP expansion ratio in the FFN.
            Defaults to 4.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The number of input channels. Defaults to 3.
        window_size (list): The height and width of the attention window of
            each stage. Defaults to [14, 14, 14, 7].
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        attn_drop_rate (float): Attention drop rate. Defaults to 0.
        use_checkpoint (bool): Whether use the checkpoint to reduce GPU memory
            cost. Defaults to False.
        mask_ratio (float): The base ratio of total number of patches to be
            masked. Defaults to 0.5.
        range_mask_ratio (float): The range to randomly extend the base mask
            ratio by; the actual ratio is sampled uniformly from
            ``[mask_ratio, mask_ratio + range_mask_ratio]``. Defaults to 0.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
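
    Example:
        >>> # A minimal usage sketch (illustrative rather than a tested
        >>> # doctest): run a masked forward pass with the default arch.
        >>> import torch
        >>> module = MixMIMPretrainTransformer(arch='base')
        >>> module.init_weights()
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> latent, mask_s4 = module(inputs)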
    """

    def __init__(self,
                 arch: Union[str, dict] = 'base',
                 mlp_ratio: float = 4,
                 img_size: int = 224,
                 patch_size: int = 4,
                 in_channels: int = 3,
                 window_size: List = [14, 14, 14, 7],
                 qkv_bias: bool = True,
                 patch_cfg: dict = dict(),
                 norm_cfg: dict = dict(type='LN'),
                 drop_rate: float = 0.0,
                 drop_path_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 use_checkpoint: bool = False,
                 mask_ratio: float = 0.5,
                 range_mask_ratio: float = 0.0,
                 init_cfg: Optional[dict] = None) -> None:

        super().__init__(
            arch=arch,
            mlp_ratio=mlp_ratio,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            window_size=window_size,
            qkv_bias=qkv_bias,
            patch_cfg=patch_cfg,
            norm_cfg=norm_cfg,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            attn_drop_rate=attn_drop_rate,
            use_checkpoint=use_checkpoint,
            init_cfg=init_cfg)

        self.mask_ratio = mask_ratio
        self.range_mask_ratio = range_mask_ratio

    def init_weights(self):
        """Initialize position embedding, patch embedding."""
        super(MixMIMTransformer, self).init_weights()

        pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.absolute_pos_embed.shape[-1],
            cls_token=False)
        self.absolute_pos_embed.data.copy_(pos_embed.float())

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # We use xavier_uniform following the official JAX ViT.
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self,
                       x: torch.Tensor,
                       mask_ratio: float = 0.5) -> Tuple[torch.Tensor, ...]:
        """Generate the mask for MixMIM Pretraining.

        Args:
            x (torch.Tensor): Image with data augmentation applied, which is
                of shape B x L x C.
            mask_ratio (float): The mask ratio of total patches.
                Defaults to 0.5.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                - mask_s1 (torch.Tensor): mask with stride of
                  self.encoder_stride // 8.
                - mask_s2 (torch.Tensor): mask with stride of
                  self.encoder_stride // 4.
                - mask_s3 (torch.Tensor): mask with stride of
                  self.encoder_stride // 2.
                - mask (torch.Tensor): mask with stride of
                  self.encoder_stride.
        """

        B, C, H, W = x.shape
        out_H = H // self.encoder_stride
        out_W = W // self.encoder_stride
        s3_H, s3_W = out_H * 2, out_W * 2
        s2_H, s2_W = out_H * 4, out_W * 4
        s1_H, s1_W = out_H * 8, out_W * 8

        seq_l = out_H * out_W
        # use a shared mask for all images in the batch
        mask = torch.zeros([1, 1, seq_l], device=x.device)

        mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio)
        noise = torch.rand(1, 1, seq_l, device=x.device)  # noise in [0, 1]
        # argsort in ascending order; the first ``seq_l * mask_ratio``
        # positions (smallest noise) are marked as masked (1)
        mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
        mask.scatter_(2, mask_idx, 1)
        mask = mask.reshape(1, 1, out_H, out_W)
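        # Upsample the stage-4 mask with nearest-neighbour interpolation to
        # get the masks of the three earlier, higher-resolution stages
        # (strides encoder_stride // 8, // 4 and // 2).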
        mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest')
        mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest')
        mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest')

        mask = mask.reshape(1, out_H * out_W, 1).contiguous()
        mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous()
        mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous()
        mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous()

        return mask_s1, mask_s2, mask_s3, mask

    def forward(self,
                x: torch.Tensor,
                mask: Optional[bool] = True) -> Tuple[torch.Tensor, ...]:
        """Generate features for masked images.

        This function generates the mixing mask, mixes each image with its
        batch-flipped counterpart accordingly, and returns the hidden
        features of the mixed token sequence.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): Whether to perform the masked (mixed)
                forward pass. If ``None`` or ``False``, the plain backbone
                forward is used instead. Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
              - x (torch.Tensor): hidden features, which is of shape
                B x L x C.
              - mask_s4 (torch.Tensor): the mask tensor for the last layer.
        """
        if mask is None or mask is False:
            return super().forward(x)

        else:
            mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking(
                x, self.mask_ratio)

            x, _ = self.patch_embed(x)

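            # Mix each token sequence with its batch-flipped counterpart:
            # positions where the mask is 1 take tokens from the other
            # image, forming the mixed input used by MixMIM.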
            x = x * (1. - mask_s1) + x.flip(0) * mask_s1
            x = x + self.absolute_pos_embed
            x = self.drop_after_pos(x)

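            # Each stage receives the mask at its own token resolution; the
            # attention layers use it to keep tokens originating from the
            # two mixed images in separate attention groups.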
            for idx, layer in enumerate(self.layers):
                if idx == 0:
                    x = layer(x, attn_mask=mask_s1)
                elif idx == 1:
                    x = layer(x, attn_mask=mask_s2)
                elif idx == 2:
                    x = layer(x, attn_mask=mask_s3)
                elif idx == 3:
                    x = layer(x, attn_mask=mask_s4)

            x = self.norm(x)

            return x, mask_s4


@MODELS.register_module()
class MixMIM(BaseSelfSupervisor):
    """MixMIM.

    Implementation of `MixMIM: Mixed and Masked Image Modeling for Efficient
    Visual Representation Learning <https://arxiv.org/abs/2205.13137>`_.
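
    Example:
        >>> # A minimal config sketch. The neck and head type names are
        >>> # assumed to be the MixMIM components registered in mmpretrain;
        >>> # the field values are illustrative, not a verified recipe.
        >>> model = dict(
        ...     type='MixMIM',
        ...     backbone=dict(type='MixMIMPretrainTransformer', arch='base'),
        ...     neck=dict(type='MixMIMPretrainDecoder', encoder_stride=32),
        ...     head=dict(
        ...         type='MixMIMPretrainHead',
        ...         loss=dict(type='PixelReconstructionLoss',
        ...                   criterion='L2')))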
    """

    def __init__(self,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 init_cfg: Optional[dict] = None):

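        # The reconstruction head operates patch-by-patch, so its patch size
        # is tied to the encoder stride of the decoder neck.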
        head.update(dict(patch_size=neck['encoder_stride']))
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

    def extract_feat(self, inputs: torch.Tensor):
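        """Extract features from the input images without masking."""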
        return self.backbone(inputs, mask=None)

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        latent, mask = self.backbone(inputs)
        x_rec = self.neck(latent, mask)
        loss = self.head.loss(x_rec, inputs, mask)
        losses = dict(loss=loss)
        return losses