# Copyright (c) OpenMMLab. All rights reserved.
# Modified from official impl https://github.com/apple/ml-mobileone/blob/main/mobileone.py # noqa: E501
from typing import Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmengine.model import BaseModule, ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm

from mmpretrain.registry import MODELS
from ..utils.se_layer import SELayer
from .base_backbone import BaseBackbone


class MobileOneBlock(BaseModule):
    """MobileOne block for MobileOne backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        kernel_size (int): The kernel size of the convs in the block. If the
            kernel size is larger than 1, there will be a ``branch_scale`` in
            the block.
        num_convs (int): Number of the convolution branches in the block.
        stride (int): Stride of convolution layers. Defaults to 1.
        padding (int): Padding of the convolution layers. Defaults to 1.
        dilation (int): Dilation of the convolution layers. Defaults to 1.
        groups (int): Groups of the convolution layers. Defaults to 1.
        se_cfg (None or dict): The configuration of the SE module.
            Defaults to None.
        conv_cfg (dict, optional): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): Configuration to construct and config norm layer.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='ReLU')``.
        deploy (bool): Whether the model structure is in the deployment mode.
            Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 num_convs: int,
                 stride: int = 1,
                 padding: int = 1,
                 dilation: int = 1,
                 groups: int = 1,
                 se_cfg: Optional[dict] = None,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: Optional[dict] = dict(type='BN'),
                 act_cfg: Optional[dict] = dict(type='ReLU'),
                 deploy: bool = False,
                 init_cfg: Optional[dict] = None):
        super(MobileOneBlock, self).__init__(init_cfg)

        assert se_cfg is None or isinstance(se_cfg, dict)
        if se_cfg is not None:
            self.se = SELayer(channels=out_channels, **se_cfg)
        else:
            self.se = nn.Identity()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.num_conv_branches = num_convs
        self.stride = stride
        self.padding = padding
        self.se_cfg = se_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.deploy = deploy
        self.groups = groups
        self.dilation = dilation

        if deploy:
            self.branch_reparam = build_conv_layer(
                conv_cfg,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                groups=self.groups,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=True)
        else:
            # Add a BN-only identity shortcut when the input and output
            # shapes are the same.
            if out_channels == in_channels and stride == 1:
                self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
            else:
                self.branch_norm = None

            self.branch_scale = None
            if kernel_size > 1:
                self.branch_scale = self.create_conv_bn(kernel_size=1)

            self.branch_conv_list = ModuleList()
            for _ in range(num_convs):
                self.branch_conv_list.append(
                    self.create_conv_bn(
                        kernel_size=kernel_size,
                        padding=padding,
                        dilation=dilation))

        self.act = build_activation_layer(act_cfg)

    def create_conv_bn(self, kernel_size, dilation=1, padding=0):
        """Create a (conv + bn) Sequential layer."""
        conv_bn = Sequential()
        conv_bn.add_module(
            'conv',
            build_conv_layer(
                self.conv_cfg,
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                groups=self.groups,
                stride=self.stride,
                dilation=dilation,
                padding=padding,
                bias=False))
        conv_bn.add_module(
            'norm',
            build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1])
        return conv_bn

    def forward(self, x):

        def _inner_forward(inputs):
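            # In deploy mode the whole block has already been folded into a
            # single conv (``branch_reparam``); otherwise the output is the
            # sum of the BN-only identity branch, the 1x1 scale branch (only
            # for kernels larger than 1) and the ``num_convs`` conv-bn
            # branches.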
            if self.deploy:
                return self.branch_reparam(inputs)

            inner_out = 0
            if self.branch_norm is not None:
                inner_out = self.branch_norm(inputs)

            if self.branch_scale is not None:
                inner_out += self.branch_scale(inputs)

            for branch_conv in self.branch_conv_list:
                inner_out += branch_conv(inputs)

            return inner_out

        return self.act(self.se(_inner_forward(x)))

    def switch_to_deploy(self):
        """Switch the model structure from training mode to deployment mode."""
        if self.deploy:
            return
        assert self.norm_cfg['type'] == 'BN', \
            "Switch is not allowed when norm_cfg['type'] != 'BN'."

        reparam_weight, reparam_bias = self.reparameterize()
        self.branch_reparam = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True)
        self.branch_reparam.weight.data = reparam_weight
        self.branch_reparam.bias.data = reparam_bias

        for param in self.parameters():
            param.detach_()
        delattr(self, 'branch_conv_list')
        if hasattr(self, 'branch_scale'):
            delattr(self, 'branch_scale')
        delattr(self, 'branch_norm')

        self.deploy = True

    def reparameterize(self):
        """Fuse all the parameters of all branches.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all
                branches. The first element is the weights and the second is
                the bias.
        """
        weight_conv, bias_conv = 0, 0
        for branch_conv in self.branch_conv_list:
            weight, bias = self._fuse_conv_bn(branch_conv)
            weight_conv += weight
            bias_conv += bias

        weight_scale, bias_scale = 0, 0
        if self.branch_scale is not None:
            weight_scale, bias_scale = self._fuse_conv_bn(self.branch_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.kernel_size // 2
            weight_scale = F.pad(weight_scale, [pad, pad, pad, pad])

        weight_norm, bias_norm = 0, 0
        if self.branch_norm:
            tmp_conv_bn = self._norm_to_conv(self.branch_norm)
            weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)

        return (weight_conv + weight_scale + weight_norm,
                bias_conv + bias_scale + bias_norm)

    def _fuse_conv_bn(self, branch):
        """Fuse the parameters in a branch with a conv and bn.

        Args:
            branch (mmengine.model.Sequential): A branch with conv and bn.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
                fusing the parameters of conv and bn in one branch.
                The first element is the weight and the second is the bias.
        """
        if branch is None:
            return 0, 0
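
        # BN folding, stated for reference: with a bias-free conv,
        #     bn(conv(x)) = gamma * (conv(x) - mean) / std + beta
        # where ``std = sqrt(var + eps)``, i.e. a single conv whose weight
        # is ``weight * gamma / std`` and whose bias is
        # ``beta - mean * gamma / std``.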
        kernel = branch.conv.weight
        running_mean = branch.norm.running_mean
        running_var = branch.norm.running_var
        gamma = branch.norm.weight
        beta = branch.norm.bias
        eps = branch.norm.eps

        std = (running_var + eps).sqrt()
        fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * kernel
        fused_bias = beta - running_mean * gamma / std

        return fused_weight, fused_bias

    def _norm_to_conv(self, branch_norm):
        """Convert a norm layer to a conv-bn sequence whose convolution has
        kernel size ``self.kernel_size``.

        Args:
            branch_norm (nn.BatchNorm2d): A branch with only a bn in the
                block.

        Returns:
            (mmengine.model.Sequential): A sequential with conv and bn.
        """
        input_dim = self.in_channels // self.groups
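        # Build an identity convolution: a single 1 at the kernel centre of
        # each channel's slice within its group, so the BN-only shortcut
        # becomes a conv-bn pair that ``_fuse_conv_bn`` can fold like any
        # other branch.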
        conv_weight = torch.zeros(
            (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
            dtype=branch_norm.weight.dtype)

        for i in range(self.in_channels):
            conv_weight[i, i % input_dim, self.kernel_size // 2,
                        self.kernel_size // 2] = 1
        conv_weight = conv_weight.to(branch_norm.weight.device)

        tmp_conv = self.create_conv_bn(kernel_size=self.kernel_size)
        tmp_conv.conv.weight.data = conv_weight
        tmp_conv.norm = branch_norm
        return tmp_conv


@MODELS.register_module()
class MobileOne(BaseBackbone):
    """MobileOne backbone.

    A PyTorch implementation of `MobileOne: An Improved One millisecond
    Mobile Backbone <https://arxiv.org/pdf/2206.04040.pdf>`_.

    Args:
        arch (str | dict): MobileOne architecture. If a string, choose from
            's0', 's1', 's2', 's3' and 's4'. If a dict, it should have the
            following keys:

            - num_blocks (Sequence[int]): Number of blocks in each stage.
            - width_factor (Sequence[float]): Width factor in each stage.
            - num_conv_branches (Sequence[int]): Number of conv branches
              in each stage.
            - num_se_blocks (Sequence[int]): Number of SE layers in each
              stage; the SE layers are placed in the last blocks of each
              stage.

            Defaults to 's0'.
        in_channels (int): Number of input image channels. Defaults to 3.
        out_indices (Sequence[int] | int): Output from which stages.
            Defaults to ``(3, )``.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters. Defaults to -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='ReLU')``.
        se_cfg (dict): Config dict for the SE layers selected by
            ``num_se_blocks``. Defaults to ``dict(ratio=16)``.
        deploy (bool): Whether to switch the model structure to deployment
            mode. Defaults to False.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.

    Example:
        >>> from mmpretrain.models import MobileOne
        >>> import torch
        >>> x = torch.rand(1, 3, 224, 224)
        >>> model = MobileOne("s0", out_indices=(0, 1, 2, 3))
        >>> model.eval()
        >>> outputs = model(x)
        >>> for out in outputs:
        ...     print(tuple(out.shape))
        (1, 48, 56, 56)
        (1, 128, 28, 28)
        (1, 256, 14, 14)
        (1, 1024, 7, 7)
    """

    arch_zoo = {
        's0':
        dict(
            num_blocks=[2, 8, 10, 1],
            width_factor=[0.75, 1.0, 1.0, 2.0],
            num_conv_branches=[4, 4, 4, 4],
            num_se_blocks=[0, 0, 0, 0]),
        's1':
        dict(
            num_blocks=[2, 8, 10, 1],
            width_factor=[1.5, 1.5, 2.0, 2.5],
            num_conv_branches=[1, 1, 1, 1],
            num_se_blocks=[0, 0, 0, 0]),
        's2':
        dict(
            num_blocks=[2, 8, 10, 1],
            width_factor=[1.5, 2.0, 2.5, 4.0],
            num_conv_branches=[1, 1, 1, 1],
            num_se_blocks=[0, 0, 0, 0]),
        's3':
        dict(
            num_blocks=[2, 8, 10, 1],
            width_factor=[2.0, 2.5, 3.0, 4.0],
            num_conv_branches=[1, 1, 1, 1],
            num_se_blocks=[0, 0, 0, 0]),
        's4':
        dict(
            num_blocks=[2, 8, 10, 1],
            width_factor=[3.0, 3.5, 3.5, 4.0],
            num_conv_branches=[1, 1, 1, 1],
            num_se_blocks=[0, 0, 5, 1])
    }
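
    # A worked check of the width arithmetic (not a config): with the stage
    # base widths [64, 128, 256, 512] used in ``__init__``, the 's0' factors
    # [0.75, 1.0, 1.0, 2.0] give stage widths [48, 128, 256, 1024], matching
    # the shapes printed in the docstring example.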

    def __init__(self,
                 arch,
                 in_channels=3,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 se_cfg=dict(ratio=16),
                 deploy=False,
                 norm_eval=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(type='Constant', val=1, layer=['_BatchNorm'])
                 ]):
        super(MobileOne, self).__init__(init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_zoo, f'"arch": "{arch}"' \
                f' is not one of the {list(self.arch_zoo.keys())}'
            arch = self.arch_zoo[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        self.arch = arch
        for k, value in self.arch.items():
            assert isinstance(value, list) and len(value) == 4, \
                f'the value of {k} in arch must be a list with 4 items.'

        self.in_channels = in_channels
        self.deploy = deploy
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.se_cfg = se_cfg
        self.act_cfg = act_cfg

        base_channels = [64, 128, 256, 512]
        channels = min(64,
                       int(base_channels[0] * self.arch['width_factor'][0]))
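        # e.g. for 's0': min(64, int(64 * 0.75)) == 48 stem channels.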
        self.stage0 = MobileOneBlock(
            self.in_channels,
            channels,
            stride=2,
            kernel_size=3,
            num_convs=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            deploy=deploy)

        self.in_planes = channels
        self.stages = []
        for i, num_blocks in enumerate(self.arch['num_blocks']):
            planes = int(base_channels[i] * self.arch['width_factor'][i])

            stage = self._make_stage(planes, num_blocks,
                                     arch['num_se_blocks'][i],
                                     arch['num_conv_branches'][i])

            stage_name = f'stage{i + 1}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = len(self.stages) + index
            assert 0 <= out_indices[i] < len(self.stages), \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

    def _make_stage(self, planes, num_blocks, num_se, num_conv_branches):
        strides = [2] + [1] * (num_blocks - 1)
        if num_se > num_blocks:
            raise ValueError('Number of SE blocks cannot '
                             'exceed number of layers.')
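        # Each of the ``num_blocks`` steps below appends two MobileOneBlocks:
        # a 3x3 depthwise block followed by a 1x1 pointwise block. The last
        # ``num_se`` steps attach the SE layer to both of their blocks.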
        blocks = []
        for i in range(num_blocks):
            use_se = False
            if i >= (num_blocks - num_se):
                use_se = True

            blocks.append(
                # Depthwise conv
                MobileOneBlock(
                    in_channels=self.in_planes,
                    out_channels=self.in_planes,
                    kernel_size=3,
                    num_convs=num_conv_branches,
                    stride=strides[i],
                    padding=1,
                    groups=self.in_planes,
                    se_cfg=self.se_cfg if use_se else None,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    deploy=self.deploy))

            blocks.append(
                # Pointwise conv
                MobileOneBlock(
                    in_channels=self.in_planes,
                    out_channels=planes,
                    kernel_size=1,
                    num_convs=num_conv_branches,
                    stride=1,
                    padding=0,
                    se_cfg=self.se_cfg if use_se else None,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    deploy=self.deploy))

            self.in_planes = planes

        return Sequential(*blocks)

    def forward(self, x):
        x = self.stage0(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.stage0.eval()
            for param in self.stage0.parameters():
                param.requires_grad = False

        for i in range(self.frozen_stages):
            stage = getattr(self, f'stage{i+1}')
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch the model into training or evaluation mode, keeping the
        frozen stages (and the norm layers, if ``norm_eval``) in eval mode."""
        super(MobileOne, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def switch_to_deploy(self):
        """Switch the model to deploy mode, which has fewer parameters and
        less computation."""
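        # Illustrative usage (not executed here): load trained weights, call
        # ``model.switch_to_deploy()`` once, then export or run inference;
        # the outputs are unchanged, only the multi-branch structure is
        # folded away.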
        for m in self.modules():
            if isinstance(m, MobileOneBlock):
                m.switch_to_deploy()
        self.deploy = True