import os
from collections import OrderedDict
from mmcv.runner import _load_checkpoint
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import logging
from functools import partial
from scipy import interpolate
from math import pi
from einops import rearrange, repeat
import warnings
import torch.utils.checkpoint as cp
from mmdet.models.builder import BACKBONES

logger = logging.getLogger(__name__)
BatchNorm2d = torch.nn.BatchNorm2d

# xFormers is used unless the XFORMERS_DISABLED environment variable is set
# (to any value) or the package cannot be imported.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
    else:
        warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Attention)")


class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        """Apply the convolution, then the optional norm and activation."""
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            with warnings.catch_warnings(record=True):
                if x.numel() == 0 and self.training:
                    # https://github.com/pytorch/pytorch/issues/12013
                    assert not isinstance(
                        self.norm, torch.nn.SyncBatchNorm
                    ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    # Pad on the bottom/right so both spatial sizes become multiples of window_size.
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    # Crop away the padding added by window_partition.
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    use_log_interpolation = True

    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        if not use_log_interpolation:
            # Interpolate rel pos.
            rel_pos_resized = F.interpolate(
                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                size=max_rel_dist,
                mode="linear",
            )
            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
        else:
            src_size = rel_pos.shape[0]
            dst_size = max_rel_dist

            # Source sample positions grow geometrically with ratio q, so the
            # source grid is denser near relative distance 0.
            q = 1.0903078
            dis = []
            cur = 1
            for i in range(src_size // 2):
                dis.append(cur)
                cur += q ** (i + 1)

            r_ids = [-_ for _ in reversed(dis)]
            x = r_ids + [0] + dis
            t = dst_size // 2.0
            dx = np.arange(-t, t + 0.1, 1.0)

            # Cubic interpolation is done per channel on CPU with scipy.
            all_rel_pos_bias = []
            for i in range(rel_pos.shape[1]):
                z = rel_pos[:, i].view(src_size).cpu().float().numpy()
                f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
                all_rel_pos_bias.append(
                    torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
            rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn


def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.

    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C)
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        original_datatype = abs_pos.dtype
        new_abs_pos = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2).float(),  # bf16 is not implemented
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        ).to(original_datatype)

        return new_abs_pos.permute(0, 2, 3, 1)
    else:
        return abs_pos.reshape(1, h, w, -1)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
            self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        """Project the image and return channel-last tokens."""
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x


def broadcat(tensors, dim=-1):
    """Concatenate tensors along `dim`, expanding the other dims to match.

    All tensors must have the same number of dimensions, and each dimension
    other than `dim` may take at most two distinct sizes across the inputs;
    each tensor is expanded to the largest one before concatenation.
    """
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all(
        [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


def rotate_half(x):
    """Map each consecutive pair (a, b) along the last dim to (-b, a)."""
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')


def _rope_base_freqs(dim, custom_freqs, freqs_for, theta, max_freq, num_freqs):
    """Return the 1-D base frequency tensor shared by both rotary embedding classes.

    `custom_freqs` is compared against None rather than tested for truthiness,
    because the truth value of a multi-element tensor is ambiguous and raises.
    """
    if custom_freqs is not None:
        return custom_freqs
    if freqs_for == 'lang':
        return 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    if freqs_for == 'pixel':
        return torch.linspace(1., max_freq / 2, dim // 2) * pi
    if freqs_for == 'constant':
        return torch.ones(num_freqs).float()
    raise ValueError(f'unknown modality {freqs_for}')


class VisionRotaryEmbedding(nn.Module):
    """2D rotary position embedding that rotates a slice of the feature dim.

    The cos/sin tables are registered as buffers of shape
    (ft_seq_len, ft_seq_len, 2 * len(freqs) * 2) built from one shared set of
    positions for the two spatial axes.
    """

    def __init__(
            self,
            dim,
            pt_seq_len,
            ft_seq_len=None,
            custom_freqs=None,
            freqs_for='lang',
            theta=10000,
            max_freq=10,
            num_freqs=1,
    ):
        """
        Args:
            dim (int): rotary dimension used to build the base frequencies.
            pt_seq_len (int): sequence length the positions are rescaled to.
            ft_seq_len (int or None): number of positions per axis; defaults
                to `pt_seq_len`.
            custom_freqs (Tensor or None): if given, used as base frequencies.
            freqs_for (str): one of 'lang', 'pixel', 'constant'.
            theta (float): base used by the 'lang' frequencies.
            max_freq (float): upper bound used by the 'pixel' frequencies.
            num_freqs (int): number of frequencies for 'constant'.

        Raises:
            ValueError: if `freqs_for` is not a known modality.
        """
        super().__init__()
        freqs = _rope_base_freqs(dim, custom_freqs, freqs_for, theta, max_freq, num_freqs)

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Positions are rescaled so ft_seq_len positions span [0, pt_seq_len).
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs_h = torch.einsum('..., f -> ... f', t, freqs)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)

        freqs_w = torch.einsum('..., f -> ... f', t, freqs)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)

        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t, start_index=0):
        """Rotate features [start_index, start_index + rot_dim) of `t`; pass the rest through."""
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[
            -1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim=-1)


class VisionRotaryEmbeddingFast(nn.Module):
    """2D rotary position embedding with flattened (ft_seq_len ** 2, D) tables.

    Unlike `VisionRotaryEmbedding`, the whole last dimension of the input is
    rotated and the spatial grid is flattened to a single token axis.
    """

    def __init__(
            self,
            dim,
            pt_seq_len=16,
            ft_seq_len=None,
            custom_freqs=None,
            freqs_for='lang',
            theta=10000,
            max_freq=10,
            num_freqs=1,
    ):
        """
        Args:
            dim (int): rotary dimension used to build the base frequencies.
            pt_seq_len (int): sequence length the positions are rescaled to.
            ft_seq_len (int or None): number of positions per axis; defaults
                to `pt_seq_len`.
            custom_freqs (Tensor or None): if given, used as base frequencies.
            freqs_for (str): one of 'lang', 'pixel', 'constant'.
            theta (float): base used by the 'lang' frequencies.
            max_freq (float): upper bound used by the 'pixel' frequencies.
            num_freqs (int): number of frequencies for 'constant'.

        Raises:
            ValueError: if `freqs_for` is not a known modality.
        """
        super().__init__()
        freqs = _rope_base_freqs(dim, custom_freqs, freqs_for, theta, max_freq, num_freqs)

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        # Positions are rescaled so ft_seq_len positions span [0, pt_seq_len).
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)

        # Flatten the (h, w) grid into a single token axis.
        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)

        print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t):
        """Apply the rotation to `t`, whose last two dims must broadcast with the tables."""
        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        """
        Args:
            num_features (int): number of channels.
            eps (float): value added to the variance for numerical stability.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        """Normalize `x` with the fixed statistics and affine parameters."""
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Load buffers, filling in identity running stats for old checkpoints."""
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silent the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res


class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        """
        Args:
            normalized_shape (int): number of channels.
            eps (float): value added to the variance for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalize over dim 1 and apply the per-channel affine transform."""
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable.
        This method sets all parameters to `requires_grad=False`,
        and convert all BatchNorm layers to FrozenBatchNorm

        Returns:
            the block itself
        """
        for p in self.parameters():
            p.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self


def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN, nnSyncBN, LN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.
        out_channels (int): number of channels passed to the norm constructor.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "LN": lambda channels: LayerNorm(channels)
        }[norm]
    return norm(out_channels)


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        """
        Args:
            drop_prob (float or None): probability of dropping a sample's
                path; None is treated as 0 (no dropping).
        """
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Randomly zero whole samples and rescale the rest by 1 / keep_prob."""
        # `not self.drop_prob` covers both 0.0 and the default None, which
        # would otherwise fail on `1 - None` in training mode.
        if not self.drop_prob or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        output = x.div(keep_prob) * random_tensor
        return output


class SwiGLU(nn.Module):
    """Gated feed-forward block: w3(ffn_ln(act(w1(x)) * w2(x))) followed by dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False
                 ):
        """
        Args:
            in_features (int): input feature dimension.
            hidden_features (int or None): hidden dimension; defaults to `in_features`.
            out_features (int or None): output dimension; defaults to `in_features`.
            act_layer (callable): constructor of the gate activation.
            drop (float): dropout probability applied to the output.
            norm_layer (callable): constructor of the inner norm used when `subln` is True.
            subln (bool): If True, normalize the gated hidden features before `w3`.
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)

        self.act = act_layer()
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = self.act(x1) * x2
        x = self.ffn_ln(hidden)
        x = self.w3(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """Multi-head self-attention over [B, H, W, C] tokens with optional rotary embedding."""

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            qk_scale=None,
            attn_head_dim=None,
            norm_layer=nn.LayerNorm,
            rope=None,
            flash_attn=True,
            xformers_attn=True,
            subln=False
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query and value
                (the key projection never has a bias).
            qk_scale (float or None): Overrides the default `head_dim ** -0.5`
                scale in the non-xFormers path.
            attn_head_dim (int or None): Overrides the per-head dimension.
            norm_layer (callable): constructor of the inner norm used when `subln` is True.
            rope (nn.Module or None): rotary embedding applied to q and k; skipped if None.
            flash_attn (bool): accepted for config compatibility; currently unused.
            xformers_attn (bool): If True and xFormers is importable, use
                `memory_efficient_attention`; otherwise use the softmax path.
            subln (bool): If True, normalize the attention output before the projection.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.subln = subln
        self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.v_proj = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.rope = rope
        self.proj = nn.Linear(all_head_dim, dim)
        self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
        self.xformers_attn = xformers_attn

    def forward(self, x):
        """
        Args:
            x (Tensor): tokens with shape [B, H, W, C].

        Returns:
            Tensor with shape [B, H, W, C].
        """
        B, H, W, C = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(B, -1, C)
        N = H * W

        q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
        k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
        v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
        k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

        # Rotary embedding is optional: the constructor default is rope=None.
        if self.rope is not None:
            q = self.rope(q).type_as(v)
            k = self.rope(k).type_as(v)

        # Only take the xFormers path when the import actually succeeded;
        # otherwise `memory_efficient_attention` is not defined.
        if self.xformers_attn and XFORMERS_AVAILABLE:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)

            x = memory_efficient_attention(q, k, v)
            x = x.reshape(B, N, -1)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))
            attn = attn.softmax(dim=-1).type_as(x)
            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x = self.inner_attn_ln(x)
        x = self.proj(x)
        x = x.reshape(B, H, W, C)

        return x


class ResBottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block without the last activation layer.
    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            bottleneck_channels,
            norm="LN",
            act_layer=nn.GELU,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            norm (str or callable): normalization for all conv layers.
                See :func:`get_norm` for supported format.
            act_layer (callable): activation for all conv layers.
        """
        super().__init__(in_channels, out_channels, 1)

        # NOTE: forward() runs the children in registration order, so the
        # order of the assignments below defines the computation.
        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = get_norm(norm, bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            padding=1,
            bias=False,
        )
        self.norm2 = get_norm(norm, bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = get_norm(norm, out_channels)

        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # zero init last norm layer.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    def forward(self, x):
        """Run all child layers in order and add the input as a residual."""
        out = x
        for layer in self.children():
            out = layer(out)

        out = x + out
        return out


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4 * 2 / 3,
            qkv_bias=True,
            drop_path=0.0,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            window_size=0,
            use_residual_block=False,
            rope=None,
            flash_attn=False,
            xformers_attn=True,
            subln=False,
            with_cp=True,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query and value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            window_size (int): Window size for window attention blocks. If it equals 0, then not
                use window attention.
            use_residual_block (bool): If True, use a residual block after the MLP block.
            rope (nn.Module or None): rotary embedding passed to the attention layer.
            flash_attn (bool): forwarded to `Attention`, where it is currently unused.
            xformers_attn (bool): forwarded to `Attention`.
            subln (bool): forwarded to `Attention` (the MLP always uses its inner norm).
            with_cp (bool): If True, use activation checkpointing in training mode.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            rope=rope,
            flash_attn=flash_attn,
            xformers_attn=xformers_attn,
            subln=subln
        )

        self.with_cp = with_cp
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = SwiGLU(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            subln=True,
            norm_layer=norm_layer,
        )

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Use a residual block with bottleneck channel as dim // 2
            self.residual = ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                norm="LN",
            )

    def _forward(self, x):
        """Attention + MLP (+ optional conv residual block) on [B, H, W, C] tokens."""
        shortcut = x
        x = self.norm1(x)

        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if self.use_residual_block:
            # The conv block works on NCHW, the tokens are NHWC.
            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        return x

    def forward(self, x):
        """Run the block, with activation checkpointing when enabled and training."""
        if self.with_cp and self.training:
            x = cp.checkpoint(self._forward, x)
        else:
            x = self._forward(x)
        return x


@BACKBONES.register_module()
class EVAViT(nn.Module):
    """
    This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
            self,
            img_size=1024,
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4 * 2 / 3,
            qkv_bias=True,
            drop_path_rate=0.0,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            use_abs_pos=True,
            pt_hw_seq_len=16,
            intp_freq=True,
            window_size=0,
            global_window_size=0,
            window_block_indexes=(),
            residual_block_indexes=(),
            pretrain_img_size=224,
            pretrain_use_cls_token=True,
            out_feature="last_feat",
            subln=False,
            flash_attn=False,
            xformers_attn=True,
            with_cp=True,
            frozen=False,
    ):
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query and value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            pt_hw_seq_len (int): `pt_seq_len` passed to the rotary embeddings.
            intp_freq (bool): If True, build the rotary tables for the window
                size / `img_size // patch_size` instead of `pt_hw_seq_len`.
            window_size (int): Window size for window attention blocks.
            global_window_size (int): Window size used by the blocks that are
                not listed in `window_block_indexes`.
            window_block_indexes (list): Indexes for blocks using window attention.
            residual_block_indexes (list): Indexes for blocks using conv propagation.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use class token.
            out_feature (str): name of the feature from the last block.
            subln (bool): forwarded to every block.
            flash_attn (bool): forwarded to every block; currently unused there.
            xformers_attn (bool): forwarded to every block.
            with_cp (bool): If True, use activation checkpointing in the blocks.
            frozen (bool): If True, switch to eval mode and disable gradients
                for all parameters at construction time.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.frozen = frozen

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # Each spatial axis gets half of the head dim, hence the extra // 2.
        half_head_dim = embed_dim // num_heads // 2
        hw_seq_len = img_size // patch_size

        self.rope_win = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=window_size if intp_freq else None,
        )
        self.rope_glb = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=hw_seq_len if intp_freq else None,
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                window_size=window_size if i in window_block_indexes else global_window_size,
                use_residual_block=i in residual_block_indexes,
                rope=self.rope_win if i in window_block_indexes else self.rope_glb,
                flash_attn=flash_attn,
                xformers_attn=xformers_attn,
                subln=subln,
                with_cp=with_cp,
            )

            self.blocks.append(block)

        self._out_feature_channels = {out_feature: embed_dim}
        self._out_feature_strides = {out_feature: patch_size}
        self._out_features = [out_feature]

        self.adapter = None

        if self.pos_embed is not None:
            nn.init.normal_(self.pos_embed, std=0.02)

        # MIN SHI: I disable the weight initialization since they will be automatically loaded
        # **However, they will cause problems (deepspeed + bf16)**
        # self.apply(self._init_weights)
        self._freeze_stages()

    def _freeze_stages(self):
        """Put the model in eval mode and disable all gradients when `frozen` is set."""
        # NOTE(review): a later `.train()` call on this module would switch it
        # back to training mode; confirm whether callers rely on that.
        if self.frozen:
            self.eval()
            for m in self.parameters():
                m.requires_grad = False

    def init_weights(self, ckpt_path):
        """Load backbone weights from a checkpoint, non-strictly.

        The state dict is taken from the checkpoint's 'state_dict' or 'model'
        entry when present, otherwise the checkpoint itself is used. A leading
        'backbone.' or 'img_backbone.' prefix is stripped from each key,
        entries whose key contains 'rope' are dropped, and finally a leading
        'module.' prefix is stripped.

        Args:
            ckpt_path (str): path passed to mmcv's `_load_checkpoint`.
        """
        ckpt = _load_checkpoint(ckpt_path, logger=logger, map_location='cpu')
        if 'state_dict' in ckpt:
            _state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            _state_dict = ckpt['model']
        else:
            _state_dict = ckpt

        state_dict = OrderedDict()
        for k, v in _state_dict.items():
            for prefix in ('backbone.', 'img_backbone.'):
                if k.startswith(prefix):
                    k = k[len(prefix):]
                    break
            state_dict[k] = v

        print(f'Before deleting rope keys, {len(state_dict)}')
        # Filter on the actual key instead of a fixed-width slice, so the
        # removal works regardless of which prefix (if any) the key carried.
        state_dict = OrderedDict(
            (k, v) for k, v in state_dict.items() if 'rope' not in k
        )
        print(f'After deleting rope keys, {len(state_dict)}')

        # strip prefix of state_dict; done per key so keys without the prefix
        # are left intact and an empty checkpoint does not raise.
        state_dict = OrderedDict(
            (k[len('module.'):] if k.startswith('module.') else k, v)
            for k, v in state_dict.items()
        )

        # load state_dict
        meg = self.load_state_dict(state_dict, strict=False)
        print(meg)

    def forward(self, x, ):
        """
        Args:
            x (Tensor): input images with shape (B, C, H, W).

        Returns:
            list[Tensor]: a single-element list with the last block's
            features in (B, C, H, W) layout.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )

        for blk in self.blocks:
            x = blk(x)  # b, h, w, c
        x = x.permute(0, 3, 1, 2)  # b, c, h, w

        return [x, ]