import logging
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from mmcv.cnn import constant_init, trunc_normal_init
from mmcv.runner import _load_checkpoint
from mmdet.models.builder import BACKBONES
from timm.models.layers import trunc_normal_, DropPath

try:
    from navsim.agents.backbones.ops_dcnv3 import modules as opsm
except Exception:
    opsm = None
    print('DCN v3 unsupported, ignored')

logger = logging.getLogger(__name__)


class to_channels_first(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 3, 1, 2)


class to_channels_last(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(0, 2, 3, 1)


def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    layers = []
    if norm_layer == 'BN':
        if in_format == 'channels_last':
            layers.append(to_channels_first())
        layers.append(nn.BatchNorm2d(dim))
        if out_format == 'channels_last':
            layers.append(to_channels_last())
    elif norm_layer == 'LN':
        if in_format == 'channels_first':
            layers.append(to_channels_last())
        layers.append(nn.LayerNorm(dim, eps=eps))
        if out_format == 'channels_first':
            layers.append(to_channels_first())
    else:
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')
    return nn.Sequential(*layers)


def build_act_layer(act_layer):
    if act_layer == 'ReLU':
        return nn.ReLU(inplace=True)
    elif act_layer == 'SiLU':
        return nn.SiLU(inplace=True)
    elif act_layer == 'GELU':
        return nn.GELU()

    raise NotImplementedError(f'build_act_layer does not support {act_layer}')


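# Illustrative usage of the two helpers above (an added sketch, not part of the
# original module): build_norm_layer wraps the norm with layout converters, so
# a BatchNorm2d can be applied to channels-last tensors without manual permutes.
#
#   norm = build_norm_layer(64, 'BN', in_format='channels_last',
#                           out_format='channels_last')
#   y = norm(torch.randn(2, 56, 56, 64))  # stays (B, H, W, C)
#   act = build_act_layer('GELU')

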
class CrossAttention(nn.Module):
    r""" Cross Attention Module
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads. Default: 8
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        attn_head_dim (int, optional): Dimension of attention head.
        out_dim (int, optional): Dimension of output.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 attn_head_dim=None,
                 out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        assert all_head_dim == dim

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, k=None, v=None):
        B, N, C = x.shape
        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = self.k_bias
            v_bias = self.v_bias

        q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
        q = q.reshape(B, N, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1, 4).squeeze(0)

        k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
        k = k.reshape(B, N_k, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1, 4).squeeze(0)

        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
        v = v.reshape(B, N_v, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1, 4).squeeze(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


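# Shape sketch for CrossAttention (an added illustration; the values below are
# hypothetical and assume dim is divisible by num_heads):
#
#   attn = CrossAttention(dim=256, num_heads=8, qkv_bias=True)
#   q_tokens = torch.randn(2, 1, 256)    # (B, N_q, C) queries
#   kv_tokens = torch.randn(2, 49, 256)  # (B, N_kv, C) keys/values
#   out = attn(q_tokens, k=kv_tokens, v=kv_tokens)  # -> (2, 1, 256)

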
class AttentiveBlock(nn.Module):
    r"""Attentive Block
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads. Default: 8
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop (float, optional): Dropout rate. Default: 0.0.
        attn_drop (float, optional): Attention dropout rate. Default: 0.0.
        drop_path (float | tuple[float], optional): Stochastic depth rate.
            Default: 0.0.
        norm_layer (str, optional): Normalization layer. Default: 'LN'.
        attn_head_dim (int, optional): Dimension of attention head. Default: None.
        out_dim (int, optional): Dimension of output. Default: None.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer="LN",
                 attn_head_dim=None,
                 out_dim=None):
        super().__init__()

        self.norm1_q = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.norm1_k = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.norm1_v = build_norm_layer(dim, norm_layer, eps=1e-6)
        self.cross_dcn = CrossAttention(dim,
                                        num_heads=num_heads,
                                        qkv_bias=qkv_bias,
                                        qk_scale=qk_scale,
                                        attn_drop=attn_drop,
                                        proj_drop=drop,
                                        attn_head_dim=attn_head_dim,
                                        out_dim=out_dim)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self,
                x_q,
                x_kv,
                pos_q,
                pos_k,
                bool_masked_pos,
                rel_pos_bias=None):
        x_q = self.norm1_q(x_q + pos_q)
        x_k = self.norm1_k(x_kv + pos_k)
        x_v = self.norm1_v(x_kv)

        x = self.cross_dcn(x_q, k=x_k, v=x_v)

        return x


class AttentionPoolingBlock(AttentiveBlock):

    def forward(self, x):
        x_q = x.mean(1, keepdim=True)
        x_kv = x
        pos_q, pos_k = 0, 0
        x = super().forward(x_q, x_kv, pos_q, pos_k,
                            bool_masked_pos=None,
                            rel_pos_bias=None)
        x = x.squeeze(1)
        return x


class StemLayer(nn.Module):
    r""" Stem layer of InternImage
    Args:
        in_chans (int): number of input channels
        out_chans (int): number of output channels
        act_layer (str): activation layer
        norm_layer (str): normalization layer
    """

    def __init__(self,
                 in_chans=3,
                 out_chans=96,
                 act_layer='GELU',
                 norm_layer='BN'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_chans,
                               out_chans // 2,
                               kernel_size=3,
                               stride=2,
                               padding=1)
        self.norm1 = build_norm_layer(out_chans // 2, norm_layer,
                                      'channels_first', 'channels_first')
        self.act = build_act_layer(act_layer)
        self.conv2 = nn.Conv2d(out_chans // 2,
                               out_chans,
                               kernel_size=3,
                               stride=2,
                               padding=1)
        self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first',
                                      'channels_last')

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x)
        return x


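# Note on the stem (added commentary): the two stride-2 convolutions reduce the
# spatial resolution by 4x and the second norm converts to channels-last, e.g.
#
#   stem = StemLayer(in_chans=3, out_chans=96)
#   y = stem(torch.randn(2, 3, 224, 224))  # -> (2, 56, 56, 96), i.e. (B, H, W, C)

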
class DownsampleLayer(nn.Module):
    r""" Downsample layer of InternImage
    Args:
        channels (int): number of input channels
        norm_layer (str): normalization layer
    """

    def __init__(self, channels, norm_layer='LN'):
        super().__init__()
        self.conv = nn.Conv2d(channels,
                              2 * channels,
                              kernel_size=3,
                              stride=2,
                              padding=1,
                              bias=False)
        self.norm = build_norm_layer(2 * channels, norm_layer,
                                     'channels_first', 'channels_last')

    def forward(self, x):
        x = self.conv(x.permute(0, 3, 1, 2))
        x = self.norm(x)
        return x


class MLPLayer(nn.Module):
    r""" MLP layer of InternImage
    Args:
        in_features (int): number of input features
        hidden_features (int): number of hidden features
        out_features (int): number of output features
        act_layer (str): activation layer
        drop (float): dropout rate
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer='GELU',
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = build_act_layer(act_layer)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class InternImageLayer(nn.Module):
    r""" Basic layer of InternImage
    Args:
        core_op (nn.Module): core operation of InternImage
        channels (int): number of input channels
        groups (list): Groups of each block.
        mlp_ratio (float): ratio of mlp hidden features to input channels
        drop (float): dropout rate
        drop_path (float): drop path rate
        act_layer (str): activation layer
        norm_layer (str): normalization layer
        post_norm (bool): whether to use post normalization
        layer_scale (float): layer scale
        offset_scale (float): offset scale
        with_cp (bool): whether to use checkpoint
    """

    def __init__(self,
                 core_op,
                 channels,
                 groups,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer='GELU',
                 norm_layer='LN',
                 post_norm=False,
                 layer_scale=None,
                 offset_scale=1.0,
                 with_cp=False,
                 dw_kernel_size=None,
                 res_post_norm=False,
                 center_feature_scale=False):
        super().__init__()
        self.channels = channels
        self.groups = groups
        self.mlp_ratio = mlp_ratio
        self.with_cp = with_cp

        self.norm1 = build_norm_layer(channels, 'LN')
        self.post_norm = post_norm
        self.dcn = core_op(
            channels=channels,
            kernel_size=3,
            stride=1,
            pad=1,
            dilation=1,
            group=groups,
            offset_scale=offset_scale,
            act_layer=act_layer,
            norm_layer=norm_layer,
            dw_kernel_size=dw_kernel_size,
            center_feature_scale=center_feature_scale)
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.norm2 = build_norm_layer(channels, 'LN')
        self.mlp = MLPLayer(in_features=channels,
                            hidden_features=int(channels * mlp_ratio),
                            act_layer=act_layer,
                            drop=drop)
        self.layer_scale = layer_scale is not None
        if self.layer_scale:
            self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels),
                                       requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels),
                                       requires_grad=True)
        self.res_post_norm = res_post_norm
        if res_post_norm:
            self.res_post_norm1 = build_norm_layer(channels, 'LN')
            self.res_post_norm2 = build_norm_layer(channels, 'LN')

    def forward(self, x):

        def _inner_forward(x):
            if not self.layer_scale:
                if self.post_norm:
                    x = x + self.drop_path(self.norm1(self.dcn(x)))
                    x = x + self.drop_path(self.norm2(self.mlp(x)))
                elif self.res_post_norm:
                    x = x + self.drop_path(self.res_post_norm1(self.dcn(self.norm1(x))))
                    x = x + self.drop_path(self.res_post_norm2(self.mlp(self.norm2(x))))
                else:
                    x = x + self.drop_path(self.dcn(self.norm1(x)))
                    x = x + self.drop_path(self.mlp(self.norm2(x)))
                return x
            if self.post_norm:
                x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x)))
                x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x)))
            else:
                x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x)))
                x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
            return x

        if self.with_cp and x.requires_grad:
            x = checkpoint.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x


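# Added commentary on _inner_forward above: the layer supports four residual
# arrangements, selected by flags rather than subclassing:
#   - pre-norm (default):      x + drop_path(dcn(norm1(x)))
#   - post_norm:               x + drop_path(norm1(dcn(x)))
#   - res_post_norm:           x + drop_path(res_post_norm1(dcn(norm1(x))))
#   - layer_scale (gamma1/2):  per-channel scaling of the chosen branch
# The same pattern is applied to the MLP branch with norm2 / gamma2.

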
class InternImageBlock(nn.Module):
    r""" Block of InternImage
    Args:
        core_op (nn.Module): core operation of InternImage
        channels (int): number of input channels
        depth (int): depth of the block (number of layers)
        groups (list): Groups of each block.
        mlp_ratio (float): ratio of mlp hidden features to input channels
        drop (float): dropout rate
        drop_path (float): drop path rate
        act_layer (str): activation layer
        norm_layer (str): normalization layer
        post_norm (bool): whether to use post normalization
        layer_scale (float): layer scale
        offset_scale (float): offset scale
        with_cp (bool): whether to use checkpoint
    """

    def __init__(self,
                 core_op,
                 channels,
                 depth,
                 groups,
                 downsample=True,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer='GELU',
                 norm_layer='LN',
                 post_norm=False,
                 offset_scale=1.0,
                 layer_scale=None,
                 with_cp=False,
                 dw_kernel_size=None,
                 post_norm_block_ids=None,
                 res_post_norm=False,
                 center_feature_scale=False):
        super().__init__()
        self.channels = channels
        self.depth = depth
        self.post_norm = post_norm
        self.center_feature_scale = center_feature_scale

        self.blocks = nn.ModuleList([
            InternImageLayer(
                core_op=core_op,
                channels=channels,
                groups=groups,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i] if isinstance(
                    drop_path, list) else drop_path,
                act_layer=act_layer,
                norm_layer=norm_layer,
                post_norm=post_norm,
                layer_scale=layer_scale,
                offset_scale=offset_scale,
                with_cp=with_cp,
                dw_kernel_size=dw_kernel_size,
                res_post_norm=res_post_norm,
                center_feature_scale=center_feature_scale
            ) for i in range(depth)
        ])
        if not self.post_norm or center_feature_scale:
            self.norm = build_norm_layer(channels, 'LN')
        self.post_norm_block_ids = post_norm_block_ids
        if post_norm_block_ids is not None:
            self.post_norms = nn.ModuleList(
                [build_norm_layer(channels, 'LN', eps=1e-6) for _ in post_norm_block_ids]
            )
        self.downsample = DownsampleLayer(
            channels=channels, norm_layer=norm_layer) if downsample else None

    def forward(self, x, return_wo_downsample=False):
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (self.post_norm_block_ids is not None) and (i in self.post_norm_block_ids):
                index = self.post_norm_block_ids.index(i)
                x = self.post_norms[index](x)
        if not self.post_norm or self.center_feature_scale:
            x = self.norm(x)
        if return_wo_downsample:
            x_ = x
        if self.downsample is not None:
            x = self.downsample(x)

        if return_wo_downsample:
            return x, x_
        return x


@BACKBONES.register_module()
class InternImage(nn.Module):
    r""" InternImage
    A PyTorch impl of : `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
        https://arxiv.org/abs/2211.05778
    Args:
        core_op (str): Core operator. Default: 'DCNv3'
        channels (int): Number of channels in the first stage. Default: 64
        depths (list): Depth of each block. Default: [3, 4, 18, 5]
        groups (list): Groups of each block. Default: [3, 6, 12, 24]
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop_rate (float): Probability of an element to be zeroed. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        act_layer (str): Activation layer. Default: 'GELU'
        norm_layer (str): Normalization layer. Default: 'LN'
        layer_scale (bool): Whether to use layer scale. Default: False
        cls_scale (bool): Whether to use class scale. Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False
        dw_kernel_size (int): Size of the dwconv. Default: None
        level2_post_norm (bool): Whether to use level2 post norm. Default: False
        level2_post_norm_block_ids (list): Indexes of post norm blocks. Default: None
        res_post_norm (bool): Whether to use res post norm. Default: False
        center_feature_scale (bool): Whether to use center feature scale. Default: False
    """

    def __init__(self,
                 core_op='DCNv3',
                 channels=320,
                 depths=[6, 6, 32, 6],
                 groups=[10, 20, 40, 80],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 drop_path_type='linear',
                 act_layer='GELU',
                 norm_layer='LN',
                 layer_scale=None,
                 offset_scale=1.0,
                 post_norm=False,
                 with_cp=True,
                 dw_kernel_size=5,
                 level2_post_norm=True,
                 level2_post_norm_block_ids=[5, 11, 17, 23, 29],
                 res_post_norm=True,
                 center_feature_scale=True,
                 out_indices=(2, 3),
                 frozen_stages=2,
                 init_cfg=None,
                 **kwargs):
        super().__init__()
        self.core_op = core_op
        self.num_levels = len(depths)
        self.depths = depths
        self.channels = channels
        self.num_features = int(channels * 2 ** (self.num_levels - 1))
        self.post_norm = post_norm
        self.mlp_ratio = mlp_ratio
        self.init_cfg = init_cfg
        self.out_indices = out_indices
        self.level2_post_norm_block_ids = level2_post_norm_block_ids

        in_chans = 3
        self.patch_embed = StemLayer(in_chans=in_chans,
                                     out_chans=channels,
                                     act_layer=act_layer,
                                     norm_layer=norm_layer)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]
        if drop_path_type == 'uniform':
            for i in range(len(dpr)):
                dpr[i] = drop_path_rate

        self.levels = nn.ModuleList()
        for i in range(self.num_levels):
            post_norm_block_ids = level2_post_norm_block_ids if level2_post_norm and (
                i == 2) else None
            level = InternImageBlock(
                core_op=getattr(opsm, core_op),
                channels=int(channels * 2 ** i),
                depth=depths[i],
                groups=groups[i],
                mlp_ratio=self.mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                act_layer=act_layer,
                norm_layer=norm_layer,
                post_norm=post_norm,
                downsample=(i < self.num_levels - 1),
                layer_scale=layer_scale,
                offset_scale=offset_scale,
                with_cp=with_cp,
                dw_kernel_size=dw_kernel_size,
                post_norm_block_ids=post_norm_block_ids,
                res_post_norm=res_post_norm,
                center_feature_scale=center_feature_scale
            )
            self.levels.append(level)
        self.frozen_stages = frozen_stages
        self.num_layers = len(depths)
        self.apply(self._init_weights)
        self.apply(self._init_deform_weights)
        self._freeze_stages()

    def init_weights(self):
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
        else:
            assert 'checkpoint' in self.init_cfg, \
                f'Only support specifying `Pretrained` in `init_cfg` in ' \
                f'{self.__class__.__name__}'
            ckpt = _load_checkpoint(self.init_cfg['checkpoint'],
                                    logger=logger,
                                    map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            # strip 'backbone.' prefixes left over from detector checkpoints
            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v
                else:
                    state_dict[k] = v

            # strip 'module.' prefixes from (Distributed)DataParallel checkpoints
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # load non-strictly and log missing/unexpected keys
            msg = self.load_state_dict(state_dict, False)
            logger.info(msg)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _init_deform_weights(self, m):
        if isinstance(m, getattr(opsm, self.core_op)):
            m._reset_parameters()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            for level in self.levels[:self.frozen_stages]:
                level.eval()
                for param in level.parameters():
                    param.requires_grad = False

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        seq_out = []
        for level_idx, level in enumerate(self.levels):
            x, x_ = level(x, return_wo_downsample=True)
            if level_idx in self.out_indices:
                seq_out.append(x_.permute(0, 3, 1, 2).contiguous())
        return seq_out
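

# ---------------------------------------------------------------------------
# Added usage sketch (not part of the original module). Building the full
# InternImage backbone requires the compiled DCNv3 ops (`opsm` above), so this
# smoke test only exercises the DCNv3-free submodules; shapes are illustrative.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    stem = StemLayer(in_chans=3, out_chans=96, act_layer='GELU', norm_layer='BN')
    feat = stem(torch.randn(1, 3, 224, 224))  # -> (1, 56, 56, 96), channels-last
    print('stem output:', tuple(feat.shape))

    pool = AttentionPoolingBlock(dim=96, num_heads=8, qkv_bias=True,
                                 norm_layer='LN', out_dim=96)
    tokens = feat.flatten(1, 2)               # (1, 56*56, 96)
    pooled = pool(tokens)                      # (1, 96)
    print('pooled output:', tuple(pooled.shape))

    # The full backbone would be built roughly like this (requires DCNv3):
    #   model = InternImage(core_op='DCNv3', channels=320,
    #                       depths=[6, 6, 32, 6], groups=[10, 20, 40, 80])
    #   feats = model(torch.randn(1, 3, 512, 2048))  # list of (B, C, H, W) maps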