# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
from mmpretrain.registry import MODELS


@MODELS.register_module()
class BEiTV2Neck(BaseModule):
"""Neck for BEiTV2 Pre-training.
This module construct the decoder for the final prediction.
Args:
        num_layers (int): Number of encoder layers in the neck. Defaults to 2.
        early_layers (int): The layer index of the early output from the
            backbone. Defaults to 9.
        backbone_arch (str | dict): Vision Transformer architecture, either a
            preset name or a dict of settings. Defaults to 'base'.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): The initialization value for the
learnable scaling of attention and FFN. Defaults to 0.1.
        use_rel_pos_bias (bool): Whether to use a unique relative position
            bias in each layer. If False, use the shared relative position
            bias defined in the backbone. Defaults to False.
        norm_cfg (dict): Config dict for the normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
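
    Example:
        A minimal usage sketch. The shapes assume a 'base' backbone with 196
        patch tokens plus one cls token, and ``rel_pos_bias=None`` assumes
        the neck layers carry no relative position bias of their own (the
        default here)::

            >>> import torch
            >>> neck = BEiTV2Neck(num_layers=2, backbone_arch='base')
            >>> early = torch.rand(1, 197, 768)  # early-layer features
            >>> final = torch.rand(1, 197, 768)  # final-layer features
            >>> x, x_cls_pt = neck((early, final), rel_pos_bias=None)
            >>> x.shape, x_cls_pt.shape
            (torch.Size([1, 196, 768]), torch.Size([1, 196, 768]))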
"""
arch_zoo = {
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'depth': 12,
'num_heads': 12,
'feedforward_channels': 3072,
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'depth': 24,
'num_heads': 16,
'feedforward_channels': 4096,
}),
    }

    def __init__(
self,
num_layers: int = 2,
early_layers: int = 9,
backbone_arch: str = 'base',
drop_rate: float = 0.,
drop_path_rate: float = 0.,
layer_scale_init_value: float = 0.1,
use_rel_pos_bias: bool = False,
norm_cfg: dict = dict(type='LN', eps=1e-6),
init_cfg: Optional[Union[dict, List[dict]]] = dict(
type='TruncNormal', layer='Linear', std=0.02, bias=0)
) -> None:
super().__init__(init_cfg=init_cfg)
if isinstance(backbone_arch, str):
backbone_arch = backbone_arch.lower()
assert backbone_arch in set(self.arch_zoo), \
(f'Arch {backbone_arch} is not in default archs '
f'{set(self.arch_zoo)}')
self.arch_settings = self.arch_zoo[backbone_arch]
else:
            essential_keys = {
                'embed_dims', 'depth', 'num_heads', 'feedforward_channels'
            }
assert isinstance(backbone_arch, dict) and essential_keys <= set(
backbone_arch
), f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = backbone_arch

        self.early_layers = early_layers
        depth = self.arch_settings['depth']

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate,
                          max(depth, early_layers + num_layers))
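        # With e.g. depth=12, early_layers=9 and num_layers=2, dpr is a
        # 12-point ramp from 0 to drop_path_rate, and the neck layers below
        # pick up dpr[9] and dpr[10], continuing the backbone's schedule.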
self.patch_aggregation = nn.ModuleList()
for i in range(early_layers, early_layers + num_layers):
_layer_cfg = dict(
embed_dims=self.arch_settings['embed_dims'],
num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.arch_settings[
                    'feedforward_channels'],
drop_rate=drop_rate,
drop_path_rate=dpr[i],
norm_cfg=norm_cfg,
layer_scale_init_value=layer_scale_init_value,
window_size=None,
use_rel_pos_bias=use_rel_pos_bias)
self.patch_aggregation.append(
BEiTTransformerEncoderLayer(**_layer_cfg))

        self.rescale_patch_aggregation_init_weight()

        embed_dims = self.arch_settings['embed_dims']
        _, norm = build_norm_layer(norm_cfg, embed_dims)
        self.add_module('norm', norm)

    def rescale_patch_aggregation_init_weight(self):
"""Rescale the initialized weights."""
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.patch_aggregation):
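            # layer_id restarts at 0 in the neck, so offset by
            # early_layers + 1 to continue the backbone's absolute depth in
            # BEiT-style depth-dependent rescaling (divide by sqrt(2*depth)).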
rescale(layer.attn.proj.weight.data,
self.early_layers + layer_id + 1)
rescale(layer.ffn.layers[1].weight.data,
self.early_layers + layer_id + 1)

    def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor,
                **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the latent prediction and final prediction.

        Args:
            inputs (Tuple[torch.Tensor]): Features of tokens.
            rel_pos_bias (torch.Tensor): Shared relative position bias table.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:

            - ``x``: The final-layer features from the backbone, normed in
              ``BEiTV2Neck``.
            - ``x_cls_pt``: The early-state features, composed of the final
              layer's cls_token and the early-layer patch tokens from the
              backbone, refined by the patch aggregation layers of the neck.
        """
early_states, x = inputs[0], inputs[1]
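        # Splice the final layer's cls_token onto the early-layer patch
        # tokens; this combined sequence is what the neck layers refine.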
x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1)
for layer in self.patch_aggregation:
x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias)

        # shared norm
        x, x_cls_pt = self.norm(x), self.norm(x_cls_pt)

        # remove cls_token
        x = x[:, 1:]
        x_cls_pt = x_cls_pt[:, 1:]

        return x, x_cls_pt