|
|
|
|
|
|
|
import fvcore.nn.weight_init as weight_init
|
|
from torch import nn
|
|
|
|
from .batch_norm import FrozenBatchNorm2d, get_norm
|
|
from .wrappers import Conv2d
|
|
|
|
|
|
"""
|
|
CNN building blocks.
|
|
"""
|
|
|
|
|
|
class CNNBlockBase(nn.Module):
    """
    Base class for CNN blocks that are described by their input channels,
    output channels and a stride.

    The `forward()` method of a block takes and returns NCHW tensors. It is
    free to run any computation, provided the result agrees with the
    channels and stride the block was declared with.

    Attributes:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        Subclasses should accept these same arguments in their own `__init__`.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        # Record the block specification as plain attributes.
        self.in_channels, self.out_channels, self.stride = in_channels, out_channels, stride

    def freeze(self):
        """
        Make this block not trainable.

        Every parameter of the block gets `requires_grad=False`, and the
        BatchNorm layers in the block are converted to FrozenBatchNorm.

        Returns:
            the block itself
        """
        for param in self.parameters():
            param.requires_grad = False
        # The return value of the conversion is not used here; `self` is
        # what gets handed back to the caller.
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self
|
|
|
|
|
|
class DepthwiseSeparableConv2d(nn.Module):
    """
    A kxk depthwise convolution followed by a 1x1 convolution.

    :paper:`xception` applies norm & activation on the second conv only,
    whereas :paper:`mobilenet` applies norm & activation on both convs.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        *,
        norm1=None,
        activation1=None,
        norm2=None,
        activation2=None,
    ):
        """
        Args:
            norm1, norm2 (str or callable): normalization for the two conv layers.
            activation1, activation2 (callable(Tensor) -> Tensor): activation
                function for the two conv layers.
        """
        super().__init__()

        # Depthwise stage: one group per input channel, channel count kept.
        # The conv bias is enabled only when no norm is given for the stage.
        depthwise_settings = dict(
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=not norm1,
            norm=get_norm(norm1, in_channels),
            activation=activation1,
        )
        self.depthwise = Conv2d(in_channels, in_channels, **depthwise_settings)

        # Pointwise stage: 1x1 conv mapping in_channels -> out_channels,
        # with the same bias-only-without-norm rule.
        pointwise_settings = dict(
            kernel_size=1,
            bias=not norm2,
            norm=get_norm(norm2, out_channels),
            activation=activation2,
        )
        self.pointwise = Conv2d(in_channels, out_channels, **pointwise_settings)

        # Default initialization, depthwise first and pointwise second.
        for conv in (self.depthwise, self.pointwise):
            weight_init.c2_msra_fill(conv)

    def forward(self, x):
        """Run the depthwise conv on `x`, then the pointwise conv on its output."""
        depthwise_out = self.depthwise(x)
        return self.pointwise(depthwise_out)
|
|
|