# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule
from mmpretrain.models.backbones.hivit import BlockWithRPE
from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import build_2d_sincos_position_embedding
class PatchSplit(nn.Module):
"""The up-sample module used in neck (transformer pyramid network)
Args:
dim (int): the input dimension (channel number).
fpn_dim (int): the fpn dimension (channel number).
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
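
    Example:
        A minimal sketch; the input shape follows ``forward`` below and the
        argument values are illustrative only:

        >>> import torch
        >>> split = PatchSplit(dim=512, fpn_dim=256, norm_cfg=dict(type='LN'))
        >>> tokens = torch.randn(2, 196, 1, 1, 512)  # (B, N, H, W, C)
        >>> split(tokens).shape
        torch.Size([2, 196, 2, 2, 256])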
"""
def __init__(self, dim, fpn_dim, norm_cfg):
super().__init__()
_, self.norm = build_norm_layer(norm_cfg, dim)
self.reduction = nn.Linear(dim, fpn_dim * 4, bias=False)
self.fpn_dim = fpn_dim
def forward(self, x):
B, N, H, W, C = x.shape
x = self.norm(x)
x = self.reduction(x)
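        # Rearrange each token's 4 * fpn_dim channels into a 2x2 spatial
        # neighborhood (a pixel-shuffle-style up-sample), doubling H and W.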
x = x.reshape(B, N, H, W, 2, 2,
self.fpn_dim).permute(0, 1, 2, 4, 3, 5,
6).reshape(B, N, 2 * H, 2 * W,
self.fpn_dim)
return x
@MODELS.register_module()
class iTPNPretrainDecoder(BaseModule):
"""The neck module of iTPN (transformer pyramid network).
Args:
num_patches (int): The number of total patches. Defaults to 196.
patch_size (int): Image patch size. Defaults to 16.
in_chans (int): The channel of input image. Defaults to 3.
embed_dim (int): Encoder's embedding dimension. Defaults to 512.
fpn_dim (int): The fpn dimension (channel number).
fpn_depth (int): The layer number of feature pyramid.
decoder_embed_dim (int): Decoder's embedding dimension.
Defaults to 512.
decoder_depth (int): The depth of decoder. Defaults to 8.
decoder_num_heads (int): Number of attention heads of decoder.
Defaults to 16.
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
Defaults to 4.
norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
reconstruction_type (str): The itpn supports 2 kinds of supervisions.
Defaults to 'pixel'.
num_outs (int): The output number of neck (transformer pyramid
network). Defaults to 3.
predict_feature_dim (int): The output dimension to supervision.
Defaults to None.
init_cfg (Union[List[dict], dict], optional): Initialization config
dict. Defaults to None.
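
    Example:
        A minimal sketch of the ``pixel`` branch. The three input levels and
        their shapes are assumptions inferred from ``forward`` below
        (stride-4, stride-8 and stride-16 features of the visible patches),
        not an official usage snippet:

        >>> import torch
        >>> neck = iTPNPretrainDecoder(num_patches=196, embed_dim=512)
        >>> neck.init_weights()
        >>> B, L = 2, 49  # 49 visible patches out of 196
        >>> feats = [
        ...     torch.randn(B, L, 4, 4, 128),  # stride-4 (embed_dim // 4)
        ...     torch.randn(B, L, 2, 2, 256),  # stride-8 (embed_dim // 2)
        ...     torch.randn(B, L, 512),        # stride-16 visible tokens
        ... ]
        >>> ids_restore = torch.argsort(torch.rand(B, 196), dim=1)
        >>> neck(feats, ids_restore).shape
        torch.Size([2, 196, 768])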
"""
def __init__(self,
num_patches: int = 196,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 512,
fpn_dim: int = 256,
fpn_depth: int = 2,
decoder_embed_dim: int = 512,
decoder_depth: int = 6,
decoder_num_heads: int = 16,
mlp_ratio: int = 4,
norm_cfg: dict = dict(type='LN', eps=1e-6),
reconstruction_type: str = 'pixel',
num_outs: int = 3,
qkv_bias: bool = True,
                 qk_scale: Optional[float] = None,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
                 predict_feature_dim: Optional[int] = None,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.num_patches = num_patches
        assert reconstruction_type in ['pixel', 'clip'], \
            'iTPN only supports `pixel` and `clip` reconstruction, ' \
            f'but got `{reconstruction_type}`.'
self.reconstruction_type = reconstruction_type
self.num_outs = num_outs
self.build_transformer_pyramid(
num_outs=num_outs,
embed_dim=embed_dim,
fpn_dim=fpn_dim,
fpn_depth=fpn_depth,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
rpe=False,
norm_cfg=norm_cfg,
)
        # Project each pyramid level to the decoder embedding space; after
        # flattening the inner 2x2 / 4x4 patches, every level has the same
        # width and the levels can be summed in ``forward``.
self.decoder_embed = nn.ModuleList()
self.decoder_embed.append(
nn.Sequential(
nn.LayerNorm(fpn_dim),
nn.Linear(fpn_dim, decoder_embed_dim, bias=True),
))
if self.num_outs >= 2:
self.decoder_embed.append(
nn.Sequential(
nn.LayerNorm(fpn_dim),
nn.Linear(fpn_dim, decoder_embed_dim // 4, bias=True),
))
if self.num_outs >= 3:
self.decoder_embed.append(
nn.Sequential(
nn.LayerNorm(fpn_dim),
nn.Linear(fpn_dim, decoder_embed_dim // 16, bias=True),
))
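        # Two kinds of supervision: 'pixel' builds an MAE-style transformer
        # decoder to reconstruct raw patch pixels, while 'clip' only needs a
        # final norm because the loss is computed downstream against feature
        # targets (e.g. CLIP tokens).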
if reconstruction_type == 'pixel':
self.mask_token = nn.Parameter(
torch.zeros(1, 1, decoder_embed_dim))
            # Create a new position embedding that is different from the
            # encoder's and is not learnable (fixed sin-cos, see
            # ``init_weights``).
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches, decoder_embed_dim),
requires_grad=False)
self.decoder_blocks = nn.ModuleList([
TransformerEncoderLayer(
decoder_embed_dim,
decoder_num_heads,
int(mlp_ratio * decoder_embed_dim),
qkv_bias=True,
norm_cfg=norm_cfg) for _ in range(decoder_depth)
])
self.decoder_norm_name, decoder_norm = build_norm_layer(
norm_cfg, decoder_embed_dim, postfix=1)
self.add_module(self.decoder_norm_name, decoder_norm)
# Used to map features to pixels
if predict_feature_dim is None:
predict_feature_dim = patch_size**2 * in_chans
self.decoder_pred = nn.Linear(
decoder_embed_dim, predict_feature_dim, bias=True)
else:
_, norm = build_norm_layer(norm_cfg, embed_dim)
self.add_module('norm', norm)
def build_transformer_pyramid(self,
num_outs=3,
embed_dim=512,
fpn_dim=256,
fpn_depth=2,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
rpe=False,
norm_cfg=None):
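        """Build the top-down feature pyramid: lateral linear projections,
        ``PatchSplit`` up-sampling and MLP-only ``BlockWithRPE`` stages
        (``num_heads=0``) that fuse the stride-16, 8 and 4 levels."""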
Hp = None
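        # Channel widths of the stride-4 / 8 / 16 levels; these mirror the
        # stage widths of a HiViT-style encoder (embed_dim // 4, // 2, x 1).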
mlvl_dims = {'4': embed_dim // 4, '8': embed_dim // 2, '16': embed_dim}
if num_outs > 1:
if embed_dim != fpn_dim:
self.align_dim_16tofpn = nn.Linear(embed_dim, fpn_dim)
else:
self.align_dim_16tofpn = None
self.fpn_modules = nn.ModuleList()
self.fpn_modules.append(
BlockWithRPE(
Hp,
fpn_dim,
0,
mlp_ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0.,
rpe=rpe,
norm_cfg=norm_cfg))
self.fpn_modules.append(
BlockWithRPE(
Hp,
fpn_dim,
0,
mlp_ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0.,
rpe=False,
norm_cfg=norm_cfg,
))
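            # Top-down path from stride 16 to stride 8: lateral projection of
            # the stride-8 features, PatchSplit up-sampling of the stride-16
            # tokens, then ``fpn_depth`` MLP blocks.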
self.align_dim_16to8 = nn.Linear(
mlvl_dims['8'], fpn_dim, bias=False)
self.split_16to8 = PatchSplit(mlvl_dims['16'], fpn_dim, norm_cfg)
self.block_16to8 = nn.Sequential(*[
BlockWithRPE(
Hp,
fpn_dim,
0,
mlp_ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0.,
rpe=rpe,
norm_cfg=norm_cfg,
) for _ in range(fpn_depth)
])
if num_outs > 2:
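                # Top-down path from stride 8 to stride 4, mirroring the
                # 16 -> 8 path above.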
self.align_dim_8to4 = nn.Linear(
mlvl_dims['4'], fpn_dim, bias=False)
self.split_8to4 = PatchSplit(fpn_dim, fpn_dim, norm_cfg)
self.block_8to4 = nn.Sequential(*[
BlockWithRPE(
Hp,
fpn_dim,
0,
mlp_ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0.,
rpe=rpe,
norm_cfg=norm_cfg,
) for _ in range(fpn_depth)
])
self.fpn_modules.append(
BlockWithRPE(
Hp,
fpn_dim,
0,
mlp_ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0.,
rpe=rpe,
norm_cfg=norm_cfg))
def init_weights(self) -> None:
"""Initialize position embedding and mask token of MAE decoder."""
super().init_weights()
if self.reconstruction_type == 'pixel':
decoder_pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.decoder_pos_embed.shape[-1],
cls_token=False)
self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())
torch.nn.init.normal_(self.mask_token, std=.02)
else:
self.rescale_init_weight()
def rescale_init_weight(self) -> None:
"""Rescale the initialized weights."""
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.fpn_modules):
if isinstance(layer, BlockWithRPE):
if layer.attn is not None:
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
@property
def decoder_norm(self):
"""The normalization layer of decoder."""
return getattr(self, self.decoder_norm_name)
def forward(self,
                x: List[torch.Tensor],
ids_restore: torch.Tensor = None) -> torch.Tensor:
"""The forward function.
The process computes the visible patches' features vectors and the mask
tokens to output feature vectors, which will be used for
reconstruction.
Args:
x (torch.Tensor): hidden features, which is of shape
B x (L * mask_ratio) x C.
ids_restore (torch.Tensor): ids to restore original image.
Returns:
torch.Tensor: The reconstructed feature vectors, which is of
shape B x (num_patches) x C.
"""
features = x[:2]
x = x[-1]
B, L, _ = x.shape
x = x[..., None, None, :]
        Hp = Wp = int(math.sqrt(L))
outs = [x] if self.align_dim_16tofpn is None else [
self.align_dim_16tofpn(x)
]
if self.num_outs >= 2:
x = self.block_16to8(
self.split_16to8(x) + self.align_dim_16to8(features[1]))
outs.append(x)
if self.num_outs >= 3:
x = self.block_8to4(
self.split_8to4(x) + self.align_dim_8to4(features[0]))
outs.append(x)
if self.num_outs > 3:
outs = [
out.reshape(B, Hp, Wp, *out.shape[-3:]).permute(
0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * out.shape[-3],
Wp * out.shape[-2]).contiguous()
for out in outs
]
if self.num_outs >= 4:
outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))
if self.num_outs >= 5:
outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))
for i, out in enumerate(outs):
out = self.fpn_modules[i](out)
outs[i] = out
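        # 'pixel' branch: project every level to the decoder width, append
        # mask tokens, restore the original patch order and decode to pixels.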
if self.reconstruction_type == 'pixel':
feats = []
for feat, layer in zip(outs, self.decoder_embed):
x = layer(feat).reshape(B, L, -1)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x = torch.cat([x, mask_tokens], dim=1)
x = torch.gather(
x,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
feats.append(x)
x = feats.pop(0)
# add pos embed
x = x + self.decoder_pos_embed
            for feat in feats:
                x = x + feat
# apply Transformer blocks
            for blk in self.decoder_blocks:
                x = blk(x)
x = self.decoder_norm(x)
x = self.decoder_pred(x)
return x
else:
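            # 'clip' branch: fuse the pyramid levels by summation and apply
            # the final norm; the loss is computed downstream on these
            # features.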
feats = []
for feat, layer in zip(outs, self.decoder_embed):
x = layer(feat).reshape(B, L, -1)
feats.append(x)
x = feats.pop(0)
            for feat in feats:
                x = x + feat
x = self.norm(x)
return x