# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.models.backbones.hivit import BlockWithRPE, HiViT, PatchMerge
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor


@MODELS.register_module()
class iTPNHiViT(HiViT):
    """HiViT for iTPN pre-training.

    Args:
        arch (str): Architecture setting forwarded to :class:`HiViT`.
            Defaults to 'base'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        inner_patches (int): Inner patch. Defaults to 4.
        stem_mlp_ratio (float): Ratio of MLP hidden dim to embedding dim
            in the first two stages. Defaults to 3.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim in
            the last stage. Defaults to 4.
        qkv_bias (bool): Enable bias for qkv projections if True.
            Defaults to True.
        qk_scale (float, optional): The number of divider after q@k.
            Defaults to None.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output
            weights. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        ape (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to True.
        rpe (bool): If True, add relative position embedding to
            the patch embedding. Defaults to False.
        layer_scale_init_value (float): Layer-scale init values.
            Defaults to 0.
        mask_ratio (float): The ratio of total number of patches to be
            masked. Only used by the ``'pixel'`` reconstruction path.
            Defaults to 0.75.
        reconstruction_type (str): The reconstruction target of
            self-supervised learning, either ``'pixel'`` or ``'clip'``.
            Defaults to 'pixel'.
    """

    def __init__(
        self,
        arch='base',
        img_size: int = 224,
        patch_size: int = 16,
        inner_patches: int = 4,
        stem_mlp_ratio: float = 3.,
        mlp_ratio: float = 4.,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        ape: bool = True,
        rpe: bool = False,
        layer_scale_init_value: float = 0.0,
        mask_ratio: float = 0.75,
        reconstruction_type: str = 'pixel',
        **kwargs,
    ):
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            inner_patches=inner_patches,
            stem_mlp_ratio=stem_mlp_ratio,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            ape=ape,
            rpe=rpe,
            layer_scale_init_value=layer_scale_init_value,
            **kwargs,
        )

        # The absolute position embedding is frozen; in 'pixel' mode
        # ``init_weights`` fills it with a fixed 2D sin-cos embedding.
        self.pos_embed.requires_grad = False
        self.mask_ratio = mask_ratio

        assert reconstruction_type in ['pixel', 'clip'], \
            'iTPN method only support `pixel` and `clip`, ' \
            f'but got `{reconstruction_type}`.'
        self.reconstruction_type = reconstruction_type
        self.num_patches = self.patch_embed.num_patches

        if reconstruction_type == 'clip':
            # Learnable token blended into masked positions in
            # ``forward_clip``.
            self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    def init_weights(self) -> None:
        """Initialize position embedding, patch embedding and mask token.

        In ``'clip'`` mode the mask token is truncated-normal initialized
        and block weights are rescaled by depth. In ``'pixel'`` mode the
        position embedding is set to a fixed 2D sin-cos embedding and the
        patch projection weight is Xavier-uniform initialized.
        """
        super().apply(self._init_weights)

        if self.reconstruction_type == 'clip':
            trunc_normal_(self.mask_token, std=0.02)
            self.rescale_init_weight()
        else:
            pos_embed = build_2d_sincos_position_embedding(
                int(self.num_patches**.5),
                self.pos_embed.shape[-1],
                cls_token=False)
            self.pos_embed.data.copy_(pos_embed.float())

            # Initialize the conv projection like a linear layer by
            # flattening everything except the output-channel dim.
            w = self.patch_embed.proj.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def rescale_init_weight(self) -> None:
        """Rescale the initialized weights.

        For every ``BlockWithRPE`` at (1-based) position ``i`` in
        ``self.blocks``, the attention output projection (if present) and
        the second MLP layer are divided in place by ``sqrt(2 * i)``.
        """

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            if isinstance(layer, BlockWithRPE):
                if layer.attn is not None:
                    rescale(layer.attn.proj.weight.data, layer_id + 1)
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def masking_id(self, batch_size, mask_ratio):
        """Generate random per-sample masking indices.

        Args:
            batch_size (int): Number of samples ``N``.
            mask_ratio (float): Fraction of the ``L`` tokens to remove,
                where ``L`` is ``self.pos_embed.size(1)``.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

            - ``ids_keep``: shape N x int(L * (1 - mask_ratio)), indices of
              the kept tokens.
            - ``ids_restore``: shape N x L, inverse permutation that
              restores the original token order.
            - ``mask``: shape N x L binary mask, 0 is keep, 1 is remove.
        """
        N, L = batch_size, self.pos_embed.size(1)
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(
            N, L, device=self.pos_embed.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=self.pos_embed.device)
        mask[:, :ids_keep.size(1)] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return ids_keep, ids_restore, mask

    def forward_pixel(
        self,
        x: torch.Tensor,
        mask: Optional[bool] = True
    ) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
        """Generate features for masked images (pixel reconstruction).

        The function supports two kinds of forward behaviors. If ``mask``
        is ``None`` or ``False``, it calls ``super().forward()``, which
        extracts features from images without masking. Otherwise it
        generates a random mask and encodes only the visible patches, i.e.
        it runs masked image modeling pre-training.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): Whether to generate a random mask.
                Defaults to True.

        Returns:
            Tuple[Tuple[torch.Tensor, ...], torch.Tensor, torch.Tensor]:
            Hidden features, mask and the ids to restore original image.
            (When masking is disabled, the result of ``super().forward()``
            is returned instead.)

            - ``outs`` (Tuple[torch.Tensor]): features collected before
              each ``PatchMerge`` block, followed by the output of the main
              blocks, which is of shape B x int(L * (1 - mask_ratio)) x C.
            - ``mask`` (torch.Tensor): B x L binary mask, 0 is keep,
              1 is remove.
            - ``ids_restore`` (torch.Tensor): ids to restore original image.
        """
        # NOTE: ``mask is None or False`` would ignore ``mask=False``
        # because ``or False`` is a no-op; test both values explicitly.
        if mask is None or mask is False:
            return super().forward(x)

        B, _, H, W = x.shape
        ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio)

        x = self.patch_embed(x)

        # Keep only the visible patches. ``x`` is 5-D here, so the keep
        # ids are broadcast over all trailing dims.
        x = torch.gather(
            x,
            dim=1,
            index=ids_keep[:, :, None, None, None].expand(
                -1, -1, *x.shape[2:]))

        outs = []
        for blk in self.blocks[:-self.num_main_blocks]:
            if isinstance(blk, PatchMerge):
                outs.append(x)
            x = blk(x)

        x = x[..., 0, 0, :]
        if self.ape:
            pos_embed = self.interpolate_pos_encoding(x, H, W)
            # Select the position embeddings of the visible patches only.
            pos_embed = torch.gather(
                pos_embed.expand(B, -1, -1),
                dim=1,
                index=ids_keep[:, :, None].expand(-1, -1,
                                                  pos_embed.shape[2]),
            )
            x = x + pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks[-self.num_main_blocks:]:
            x = blk(x)

        outs.append(x)

        return (tuple(outs), mask, ids_restore)

    def forward_clip(
            self,
            x: torch.Tensor,
            mask: Optional[Union[bool, torch.Tensor]] = True) -> Tuple:
        """Generate features for masked images (CLIP target reconstruction).

        The function supports two kinds of forward behaviors. If ``mask``
        is ``None`` or ``False``, it calls ``super().forward()``, which
        extracts features from images without masking. Otherwise ``mask``
        must be a tensor: all patches are encoded, but masked positions are
        replaced by the learnable ``mask_token`` before the main blocks.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (torch.Tensor | bool, optional): Per-sample patch mask.
                ``mask.flatten(1)`` must have one entry per token after
                patch merging (B x L); each token is blended as
                ``x * (1 - m) + mask_token * m``, so entries of 1 fully
                replace the token. The default ``True`` is not a tensor and
                is therefore not usable in this mode. ``None`` or ``False``
                disables masking.

        Returns:
            Tuple[torch.Tensor, ...]: Features collected before each
            ``PatchMerge`` block, followed by the output of the main blocks
            (B x L x C). When masking is disabled, the result of
            ``super().forward()`` is returned instead.
        """
        # NOTE: identity checks are required because ``mask`` is normally a
        # tensor here, whose truth value is ambiguous.
        if mask is None or mask is False:
            return super().forward(x)

        B, _, H, W = x.shape
        x = self.patch_embed(x)

        outs = []
        for blk in self.blocks[:-self.num_main_blocks]:
            if isinstance(blk, PatchMerge):
                outs.append(x)
            x = blk(x)

        x = x[..., 0, 0, :]
        B, L, _ = x.shape
        mask_token = self.mask_token.expand(B, L, -1)
        w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
        x = x * (1. - w) + mask_token * w

        if self.ape:
            pos_embed = self.interpolate_pos_encoding(x, H, W)
            x = x + pos_embed
        x = self.pos_drop(x)

        rpe_index = True if self.rpe else None

        for blk in self.blocks[-self.num_main_blocks:]:
            x = blk(x, rpe_index)

        outs.append(x)

        return tuple(outs)

    def forward(self,
                x: torch.Tensor,
                mask: Optional[Union[bool, torch.Tensor]] = True) -> Tuple:
        """Generate features for masked images.

        Dispatches on ``self.reconstruction_type``: ``'pixel'`` calls
        :meth:`forward_pixel` (random masking, visible patches only) and
        ``'clip'`` calls :meth:`forward_clip` (externally supplied mask
        tensor, masked patches replaced by ``mask_token``). In both modes a
        ``mask`` of ``None`` or ``False`` calls ``super().forward()``, which
        extracts features without masking.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool | torch.Tensor, optional): ``True`` (pixel mode) to
                generate a random mask, a mask tensor (clip mode), or
                ``None`` / ``False`` to disable masking. Defaults to True.

        Returns:
            Tuple: In pixel mode, ``(outs, mask, ids_restore)``; in clip
            mode, the tuple of features ``outs``.
        """
        if self.reconstruction_type == 'pixel':
            return self.forward_pixel(x, mask)
        return self.forward_clip(x, mask)


@MODELS.register_module()
class iTPN(BaseSelfSupervisor):
    """iTPN.

    Implementation of iTPN: Integrally Pre-Trained Transformer Pyramid
    Networks.
    """

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features from images without masking."""
        return self.backbone(inputs, mask=None)

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images. In ``'clip'`` mode this
                is indexed as ``inputs[0]`` (image fed to the backbone) and
                ``inputs[1]`` (image fed to the target generator).
            data_samples (List[DataSample]): All elements required
                during the forward function. In ``'clip'`` mode each sample
                must provide a ``mask`` attribute.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        if self.backbone.reconstruction_type == 'pixel':
            latent, mask, ids_restore = self.backbone(inputs)
            pred = self.neck(latent, ids_restore)
            loss = self.head.loss(pred, inputs, mask)
        else:
            mask = torch.stack(
                [data_sample.mask for data_sample in data_samples])

            img_latent = self.backbone(inputs[0], mask)

            # inputs[1] is the target image
            with torch.no_grad():
                target = self.target_generator(inputs[1])[0]
                target = target.detach()

            # iTPN contains a neck module
            feats = self.neck(img_latent)
            # NOTE(review): the first target token is dropped, presumably a
            # class token — confirm against the target generator.
            loss = self.head.loss(feats, target[:, 1:, :], mask)

        losses = dict(loss=loss)
        return losses