# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from torch import nn

from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import PromptMultiheadAttention
from .mae_neck import MAEPretrainDecoder


class PromptTransformerEncoderLayer(TransformerEncoderLayer):
"""Prompt Transformer Encoder Layer for MILAN.
This module is specific for the prompt encoder in MILAN. It will not update
the visible tokens from the encoder.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): The dropout probability applied after the
            feed-forward layer. Defaults to 0.0.
        attn_drop_rate (float): The dropout rate of the attention layer.
            Defaults to 0.0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
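
    Example:
        A minimal shape-check sketch, not from the original file; the
        sizes are illustrative assumptions (196 patches in total, 49 of
        them visible, plus one cls token)::

            >>> import torch
            >>> layer = PromptTransformerEncoderLayer(
            ...     embed_dims=512, num_heads=16, feedforward_channels=2048)
            >>> mask_tokens = torch.rand(2, 147, 512)   # N x L_m x C
            >>> visible = torch.rand(2, 50, 512)        # N x (L_v + 1) x C
            >>> ids_restore = torch.argsort(torch.rand(2, 196), dim=1)
            >>> layer(mask_tokens, visible, ids_restore).shape
            torch.Size([2, 147, 512])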
"""

    def __init__(self,
embed_dims: int,
num_heads: int,
                 feedforward_channels: int,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
num_fcs: int = 2,
qkv_bias: bool = True,
act_cfg: dict = dict(type='GELU'),
norm_cfg: dict = dict(type='LN'),
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=feedforward_channels,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=init_cfg)

        self.attn = PromptMultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias)

    def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor,
ids_restore: torch.Tensor) -> torch.Tensor:
"""Forward function for `PromptMultiheadAttention`.
Args:
x (torch.Tensor): Mask token features with shape N x L_m x C.
visible_tokens (torch.Tensor): The visible tokens features from
encoder with shape N x L_v x C.
ids_restore (torch.Tensor): The ids of all tokens in the original
image with shape N x L.

        Returns:
            torch.Tensor: Output features with shape N x L_m x C.
        """
x = x + self.attn(self.norm1(x), visible_tokens, ids_restore)
x = self.ffn(self.norm2(x), identity=x)
return x


@MODELS.register_module()
class MILANPretrainDecoder(MAEPretrainDecoder):
"""Prompt decoder for MILAN.
This decoder is used in MILAN pretraining, which will not update these
visible tokens from the encoder.
Args:
num_patches (int): The number of total patches. Defaults to 196.
patch_size (int): Image patch size. Defaults to 16.
in_chans (int): The channel of input image. Defaults to 3.
embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
decoder_embed_dim (int): Decoder's embedding dimension.
Defaults to 512.
decoder_depth (int): The depth of decoder. Defaults to 8.
decoder_num_heads (int): Number of attention heads of decoder.
Defaults to 16.
predict_feature_dim (int): The dimension of the feature to be
predicted. Defaults to 512.
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
Defaults to 4.
norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
init_cfg (Union[List[dict], dict], optional): Initialization config
dict. Defaults to None.
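
    Example:
        A minimal shape-check sketch, not from the original file; the
        index tensors below mimic MAE-style random masking with
        illustrative sizes (196 patches, 49 of them visible)::

            >>> import torch
            >>> decoder = MILANPretrainDecoder()
            >>> ids_shuffle = torch.argsort(torch.rand(2, 196), dim=1)
            >>> ids_restore = torch.argsort(ids_shuffle, dim=1)
            >>> ids_keep, ids_dump = ids_shuffle[:, :49], ids_shuffle[:, 49:]
            >>> x = torch.rand(2, 50, 1024)  # visible tokens + cls token
            >>> decoder(x, ids_restore, ids_keep, ids_dump).shape
            torch.Size([2, 197, 512])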
"""

    def __init__(self,
num_patches: int = 196,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 1024,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
predict_feature_dim: int = 512,
mlp_ratio: int = 4,
norm_cfg: dict = dict(type='LN', eps=1e-6),
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
num_patches=num_patches,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth,
decoder_num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
norm_cfg=norm_cfg,
init_cfg=init_cfg)

        # map the decoder features to the dimension of the CLIP features
        # used as the reconstruction target
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, predict_feature_dim, bias=True)

        # use the prompt transformer encoder layer instead of the
        # conventional transformer encoder layer
self.decoder_blocks = nn.ModuleList([
PromptTransformerEncoderLayer(
decoder_embed_dim,
decoder_num_heads,
int(mlp_ratio * decoder_embed_dim),
qkv_bias=True,
norm_cfg=norm_cfg) for _ in range(decoder_depth)
])

    def forward(self, x: torch.Tensor, ids_restore: torch.Tensor,
ids_keep: torch.Tensor,
ids_dump: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
x (torch.Tensor): The input features, which is of shape (N, L, C).
ids_restore (torch.Tensor): The indices to restore these tokens
to the original image.
ids_keep (torch.Tensor): The indices of tokens to be kept.
ids_dump (torch.Tensor): The indices of tokens to be masked.

        Returns:
            torch.Tensor: The reconstructed features, which are of shape
                (N, L, C).
"""
# embed tokens
x = self.decoder_embed(x)

        # append mask tokens to the sequence; the `+ 1` accounts for the
        # cls token kept at the front of x
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
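        # unshuffle to the original patch order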
x_ = torch.gather(
x_,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1)

        # add pos embed
        x = x + self.decoder_pos_embed

        # split the sequence into the visible tokens (with the cls token)
        # and the mask tokens
visible_tokens = torch.cat([
x[:, :1, :],
torch.gather(
x[:, 1:, :],
dim=1,
index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
],
dim=1)
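        # gather the mask tokens; they are the queries updated by the
        # decoder blocks below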
x = torch.gather(
x[:, 1:, :],
dim=1,
index=ids_dump.unsqueeze(-1).repeat(1, 1, x.shape[-1]))

        for blk in self.decoder_blocks:
x = blk(x, visible_tokens, ids_restore)

        # full sequence recovery: put the visible and mask tokens back in
        # the original order
x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1)
x_ = torch.gather(
x_,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1,
x.shape[-1])) # unshuffle
x = torch.cat([visible_tokens[:, :1, :], x_], dim=1)
x = self.decoder_norm(x)

        # predictor projection
x = self.decoder_pred(x)
return x