import math
from typing import Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn.bricks import Swish, build_norm_layer
from torch.nn import functional as F
from torch.nn.init import _calculate_fan_in_and_fan_out, trunc_normal_

from mmdet.registry import MODELS
from mmdet.utils import OptConfigType


def variance_scaling_trunc(tensor, gain=1.):
    """Initialize ``tensor`` in place with fan-in scaled truncated-normal
    values.

    The standard deviation is divided by 0.87962566103423978, the standard
    deviation of a unit normal truncated to (-2, 2), so that the truncated
    samples keep the requested variance.
    """
    fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
    gain /= max(1.0, fan_in)
    std = math.sqrt(gain) / .87962566103423978
    return trunc_normal_(tensor, 0., std)
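# Illustrative usage (a sketch, not part of the original module): initialise a
# conv weight in place with fan-in scaled truncated-normal values.
#   conv = nn.Conv2d(3, 8, kernel_size=3)
#   variance_scaling_trunc(conv.weight)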


@MODELS.register_module()
class Conv2dSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style 'SAME' padding.

    The padding is computed dynamically in ``forward`` from the input size,
    so that, for unit dilation, the output spatial size is
    ``ceil(input_size / stride)``. The ``padding`` argument is ignored: the
    parent convolution is built with zero padding and the input is padded
    explicitly before the convolution.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0,
                         dilation, groups, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        img_h, img_w = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        # Total padding needed along each axis so that the output size
        # becomes ceil(input_size / stride).
        extra_w = (math.ceil(img_w / self.stride[1]) -
                   1) * self.stride[1] - img_w + kernel_w
        extra_h = (math.ceil(img_h / self.stride[0]) -
                   1) * self.stride[0] - img_h + kernel_h

        # Split the padding as evenly as possible, putting the extra pixel
        # on the right/bottom when the total is odd.
        left = extra_w // 2
        right = extra_w - left
        top = extra_h // 2
        bottom = extra_h - top
        x = F.pad(x, [left, right, top, bottom])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
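# Illustrative sanity check (a sketch, not part of the original module): with
# stride 2 the output spatial size follows ceil(in / stride).
#   conv = Conv2dSamePadding(3, 8, kernel_size=3, stride=2, bias=False)
#   conv(torch.randn(1, 3, 15, 15)).shape  # torch.Size([1, 8, 8, 8])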


class MaxPool2dSamePadding(nn.Module):
    """Max pooling with TensorFlow-style 'SAME' padding.

    The input is padded in ``forward`` so that the pooled output has spatial
    size ``ceil(input_size / stride)``.
    """

    def __init__(self,
                 kernel_size: Union[int, Tuple[int, int]] = 3,
                 stride: Union[int, Tuple[int, int]] = 2,
                 **kwargs):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size, stride, **kwargs)
        self.stride = self.pool.stride
        self.kernel_size = self.pool.kernel_size

        # Normalize scalar arguments to (h, w) pairs.
        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2

    def forward(self, x):
        h, w = x.shape[-2:]

        # Total padding along the horizontal (width) and vertical (height)
        # axes so that the output size becomes ceil(input_size / stride).
        extra_h = (math.ceil(w / self.stride[1]) -
                   1) * self.stride[1] - w + self.kernel_size[1]
        extra_v = (math.ceil(h / self.stride[0]) -
                   1) * self.stride[0] - h + self.kernel_size[0]

        left = extra_h // 2
        right = extra_h - left
        top = extra_v // 2
        bottom = extra_v - top

        x = F.pad(x, [left, right, top, bottom])
        x = self.pool(x)

        return x
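# Illustrative sanity check (a sketch, not part of the original module):
#   pool = MaxPool2dSamePadding(kernel_size=3, stride=2)
#   pool(torch.randn(1, 8, 15, 15)).shape  # torch.Size([1, 8, 8, 8])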


class DepthWiseConvBlock(nn.Module):
    """Depthwise separable convolution block.

    A 3x3 depthwise convolution followed by a 1x1 pointwise convolution,
    with optional normalization and, when ``conv_bn_act_pattern`` is True,
    a Swish activation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        apply_norm: bool = True,
        conv_bn_act_pattern: bool = False,
        norm_cfg: OptConfigType = dict(type='BN', momentum=1e-2, eps=1e-3)
    ) -> None:
        super(DepthWiseConvBlock, self).__init__()
        # Depthwise 3x3 conv: one filter per input channel
        # (groups=in_channels); spatial size is preserved by 'SAME' padding.
        self.depthwise_conv = Conv2dSamePadding(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            groups=in_channels,
            bias=False)
        # Pointwise 1x1 conv: mixes channels and maps to out_channels.
        self.pointwise_conv = Conv2dSamePadding(
            in_channels, out_channels, kernel_size=1, stride=1)

        self.apply_norm = apply_norm
        if self.apply_norm:
            self.bn = build_norm_layer(norm_cfg, num_features=out_channels)[1]

        self.apply_activation = conv_bn_act_pattern
        if self.apply_activation:
            self.swish = Swish()

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        if self.apply_norm:
            x = self.bn(x)
        if self.apply_activation:
            x = self.swish(x)

        return x
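# Illustrative usage (a sketch with assumed channel sizes, not part of the
# original module): spatial size is preserved, channels map 64 -> 88.
#   block = DepthWiseConvBlock(64, 88)
#   block(torch.randn(1, 64, 32, 32)).shape  # torch.Size([1, 88, 32, 32])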


class DownChannelBlock(nn.Module):
    """1x1 convolution block that changes the channel dimension.

    Optionally followed by normalization and, when ``conv_bn_act_pattern``
    is True, a Swish activation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        apply_norm: bool = True,
        conv_bn_act_pattern: bool = False,
        norm_cfg: OptConfigType = dict(type='BN', momentum=1e-2, eps=1e-3)
    ) -> None:
        super(DownChannelBlock, self).__init__()
        self.down_conv = Conv2dSamePadding(in_channels, out_channels, 1)
        self.apply_norm = apply_norm
        if self.apply_norm:
            self.bn = build_norm_layer(norm_cfg, num_features=out_channels)[1]
        self.apply_activation = conv_bn_act_pattern
        if self.apply_activation:
            self.swish = Swish()

    def forward(self, x):
        x = self.down_conv(x)
        if self.apply_norm:
            x = self.bn(x)
        if self.apply_activation:
            x = self.swish(x)

        return x
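# Illustrative usage (a sketch with assumed channel sizes, not part of the
# original module): project a feature map to a smaller channel width.
#   block = DownChannelBlock(256, 64)
#   block(torch.randn(1, 256, 32, 32)).shape  # torch.Size([1, 64, 32, 32])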