|
""" EfficientNet / MobileNetV3 Blocks and Builder |
|
|
|
Copyright 2020 Ross Wightman |
|
""" |
|
import math
import re

from copy import deepcopy
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from .conv2d_layers import *
from geffnet.activations import *
|
|
|
__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible', |
|
'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', |
|
'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def', |
|
'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT' |
|
] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# BatchNorm defaults matching the TF reference implementations. Note that PyTorch's
# BatchNorm `momentum` corresponds to (1 - decay) in TF, hence 1 - 0.99 below.
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
|
BN_EPS_TF_DEFAULT = 1e-3 |
|
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) |
|
|
|
|
|
def get_bn_args_tf(): |
|
return _BN_ARGS_TF.copy() |
|
|
|
|
|
def resolve_bn_args(kwargs): |
|
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} |
|
bn_momentum = kwargs.pop('bn_momentum', None) |
|
if bn_momentum is not None: |
|
bn_args['momentum'] = bn_momentum |
|
bn_eps = kwargs.pop('bn_eps', None) |
|
if bn_eps is not None: |
|
bn_args['eps'] = bn_eps |
|
return bn_args |
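
# Usage sketch (illustrative only, not called in this module): resolve_bn_args pops its
# keys, so the passed kwargs dict is mutated in place.
#
#   kwargs = dict(bn_tf=True, bn_momentum=0.01)
#   bn_args = resolve_bn_args(kwargs)   # -> {'momentum': 0.01, 'eps': 1e-3}
#   # 'bn_tf' and 'bn_momentum' have been popped from kwargs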
|
|
|
|
|
_SE_ARGS_DEFAULT = dict( |
|
gate_fn=sigmoid, |
|
act_layer=None, |
|
reduce_mid=False, |
|
divisor=1) |
|
|
|
|
|
def resolve_se_args(kwargs, in_chs, act_layer=None): |
|
se_kwargs = kwargs.copy() if kwargs is not None else {} |
|
|
|
for k, v in _SE_ARGS_DEFAULT.items(): |
|
se_kwargs.setdefault(k, v) |
|
|
|
if not se_kwargs.pop('reduce_mid'): |
|
se_kwargs['reduced_base_chs'] = in_chs |
|
|
|
if se_kwargs['act_layer'] is None: |
|
assert act_layer is not None |
|
se_kwargs['act_layer'] = act_layer |
|
return se_kwargs |
|
|
|
|
|
def resolve_act_layer(kwargs, default='relu'): |
|
act_layer = kwargs.pop('act_layer', default) |
|
if isinstance(act_layer, str): |
|
act_layer = get_act_layer(act_layer) |
|
return act_layer |
|
|
|
|
|
def make_divisible(v: int, divisor: int = 8, min_value: int = None): |
|
min_value = min_value or divisor |
|
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) |
|
if new_v < 0.9 * v: |
|
new_v += divisor |
|
return new_v |
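
# Rounding sketch: make_divisible snaps a channel count to a multiple of `divisor`,
# never dropping more than ~10% below the requested value.
#
#   make_divisible(37)   # -> 40
#   make_divisible(30)   # -> 32
#   make_divisible(100)  # -> 104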
|
|
|
|
|
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): |
|
"""Round number of filters based on depth multiplier.""" |
|
if not multiplier: |
|
return channels |
|
channels *= multiplier |
|
return make_divisible(channels, divisor, channel_min) |
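
# Width-scaling sketch (EfficientNet-style channel multiplier):
#
#   round_channels(32, multiplier=1.2)   # 32 * 1.2 = 38.4 -> rounded up to 40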
|
|
|
|
|
def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.): |
|
"""Apply drop connect.""" |
|
if not training: |
|
return inputs |
|
|
|
keep_prob = 1 - drop_connect_rate |
|
random_tensor = keep_prob + torch.rand( |
|
(inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) |
|
random_tensor.floor_() |
|
output = inputs.div(keep_prob) * random_tensor |
|
return output |
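
# Behaviour sketch: during training each sample in the batch is either zeroed or kept and
# scaled by 1 / keep_prob, so the expected value of the output matches the input.
#
#   x = torch.randn(8, 32, 14, 14)
#   y = drop_connect(x, training=True, drop_connect_rate=0.2)   # ~80% of samples survive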
|
|
|
|
|
class SqueezeExcite(nn.Module): |
|
|
|
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1): |
|
super(SqueezeExcite, self).__init__() |
|
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) |
|
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) |
|
self.act1 = act_layer(inplace=True) |
|
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) |
|
self.gate_fn = gate_fn |
|
|
|
def forward(self, x): |
|
x_se = x.mean((2, 3), keepdim=True) |
|
x_se = self.conv_reduce(x_se) |
|
x_se = self.act1(x_se) |
|
x_se = self.conv_expand(x_se) |
|
x = x * self.gate_fn(x_se) |
|
return x |
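
# Usage sketch (shapes only): squeeze spatially, then gate per-channel.
#
#   se = SqueezeExcite(in_chs=64, se_ratio=0.25)   # 16 channels in the bottleneck
#   out = se(torch.randn(2, 64, 28, 28))           # -> (2, 64, 28, 28)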
|
|
|
|
|
class ConvBnAct(nn.Module): |
|
def __init__(self, in_chs, out_chs, kernel_size, |
|
stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None): |
|
super(ConvBnAct, self).__init__() |
|
assert stride in [1, 2] |
|
norm_kwargs = norm_kwargs or {} |
|
self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type) |
|
self.bn1 = norm_layer(out_chs, **norm_kwargs) |
|
self.act1 = act_layer(inplace=True) |
|
|
|
def forward(self, x): |
|
x = self.conv(x) |
|
x = self.bn1(x) |
|
x = self.act1(x) |
|
return x |
|
|
|
|
|
class DepthwiseSeparableConv(nn.Module): |
|
""" DepthwiseSeparable block |
|
Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion |
|
factor of 1.0. This is an alternative to having a IR with optional first pw conv. |
|
""" |
|
def __init__(self, in_chs, out_chs, dw_kernel_size=3, |
|
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, |
|
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, |
|
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): |
|
super(DepthwiseSeparableConv, self).__init__() |
|
assert stride in [1, 2] |
|
norm_kwargs = norm_kwargs or {} |
|
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip |
|
self.drop_connect_rate = drop_connect_rate |
|
|
|
self.conv_dw = select_conv2d( |
|
in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) |
|
self.bn1 = norm_layer(in_chs, **norm_kwargs) |
|
self.act1 = act_layer(inplace=True) |
|
|
|
|
|
if se_ratio is not None and se_ratio > 0.: |
|
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) |
|
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) |
|
else: |
|
self.se = nn.Identity() |
|
|
|
self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) |
|
self.bn2 = norm_layer(out_chs, **norm_kwargs) |
|
self.act2 = act_layer(inplace=True) if pw_act else nn.Identity() |
|
|
|
def forward(self, x): |
|
residual = x |
|
|
|
x = self.conv_dw(x) |
|
x = self.bn1(x) |
|
x = self.act1(x) |
|
|
|
x = self.se(x) |
|
|
|
x = self.conv_pw(x) |
|
x = self.bn2(x) |
|
x = self.act2(x) |
|
|
|
if self.has_residual: |
|
if self.drop_connect_rate > 0.: |
|
x = drop_connect(x, self.training, self.drop_connect_rate) |
|
x += residual |
|
return x |
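
# Sketch: a 'ds' block such as the one decoded from 'ds_r1_k3_s1_e1_c16_se0.25' maps to
# roughly this construction (activation and norm args are normally filled in by the builder):
#
#   block = DepthwiseSeparableConv(32, 16, dw_kernel_size=3, stride=1,
#                                  act_layer=nn.ReLU, se_ratio=0.25)
#   out = block(torch.randn(1, 32, 112, 112))   # -> (1, 16, 112, 112)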
|
|
|
|
|
class InvertedResidual(nn.Module): |
|
""" Inverted residual block w/ optional SE""" |
|
|
|
def __init__(self, in_chs, out_chs, dw_kernel_size=3, |
|
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, |
|
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, |
|
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, |
|
conv_kwargs=None, drop_connect_rate=0.): |
|
super(InvertedResidual, self).__init__() |
|
norm_kwargs = norm_kwargs or {} |
|
conv_kwargs = conv_kwargs or {} |
|
mid_chs: int = make_divisible(in_chs * exp_ratio) |
|
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip |
|
self.drop_connect_rate = drop_connect_rate |
|
|
|
|
|
self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) |
|
self.bn1 = norm_layer(mid_chs, **norm_kwargs) |
|
self.act1 = act_layer(inplace=True) |
|
|
|
|
|
self.conv_dw = select_conv2d( |
|
mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs) |
|
self.bn2 = norm_layer(mid_chs, **norm_kwargs) |
|
self.act2 = act_layer(inplace=True) |
|
|
|
|
|
if se_ratio is not None and se_ratio > 0.: |
|
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) |
|
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) |
|
else: |
|
self.se = nn.Identity() |
|
|
|
|
|
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) |
|
self.bn3 = norm_layer(out_chs, **norm_kwargs) |
|
|
|
def forward(self, x): |
|
residual = x |
|
|
|
|
|
x = self.conv_pw(x) |
|
x = self.bn1(x) |
|
x = self.act1(x) |
|
|
|
|
|
x = self.conv_dw(x) |
|
x = self.bn2(x) |
|
x = self.act2(x) |
|
|
|
|
|
x = self.se(x) |
|
|
|
|
|
x = self.conv_pwl(x) |
|
x = self.bn3(x) |
|
|
|
if self.has_residual: |
|
if self.drop_connect_rate > 0.: |
|
x = drop_connect(x, self.training, self.drop_connect_rate) |
|
x += residual |
|
return x |
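
# Sketch: an MBConv-style 'ir' block, e.g. as decoded from 'ir_r2_k3_s2_e6_c24_se0.25':
#
#   block = InvertedResidual(16, 24, dw_kernel_size=3, stride=2, exp_ratio=6.0,
#                            act_layer=nn.ReLU, se_ratio=0.25)
#   out = block(torch.randn(1, 16, 112, 112))   # -> (1, 24, 56, 56), no residual (stride 2)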
|
|
|
|
|
class CondConvResidual(InvertedResidual): |
|
""" Inverted residual block w/ CondConv routing""" |
|
|
|
def __init__(self, in_chs, out_chs, dw_kernel_size=3, |
|
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, |
|
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, |
|
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, |
|
num_experts=0, drop_connect_rate=0.): |
|
|
|
self.num_experts = num_experts |
|
conv_kwargs = dict(num_experts=self.num_experts) |
|
|
|
super(CondConvResidual, self).__init__( |
|
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type, |
|
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, |
|
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, |
|
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, |
|
drop_connect_rate=drop_connect_rate) |
|
|
|
self.routing_fn = nn.Linear(in_chs, self.num_experts) |
|
|
|
def forward(self, x): |
|
residual = x |
|
|
|
|
|
pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) |
|
routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) |
|
|
|
|
|
x = self.conv_pw(x, routing_weights) |
|
x = self.bn1(x) |
|
x = self.act1(x) |
|
|
|
|
|
x = self.conv_dw(x, routing_weights) |
|
x = self.bn2(x) |
|
x = self.act2(x) |
|
|
|
|
|
x = self.se(x) |
|
|
|
|
|
x = self.conv_pwl(x, routing_weights) |
|
x = self.bn3(x) |
|
|
|
if self.has_residual: |
|
if self.drop_connect_rate > 0.: |
|
x = drop_connect(x, self.training, self.drop_connect_rate) |
|
x += residual |
|
return x |
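
# Sketch: identical to InvertedResidual except each conv mixes `num_experts` kernels
# per sample, weighted by a sigmoid routing over globally pooled input features.
#
#   block = CondConvResidual(16, 24, stride=1, exp_ratio=6.0, act_layer=nn.ReLU,
#                            se_ratio=0.25, num_experts=4)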
|
|
|
|
|
class EdgeResidual(nn.Module): |
|
""" EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride""" |
|
|
|
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, |
|
stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, |
|
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): |
|
super(EdgeResidual, self).__init__() |
|
norm_kwargs = norm_kwargs or {} |
|
mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio) |
|
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip |
|
self.drop_connect_rate = drop_connect_rate |
|
|
|
|
|
self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) |
|
self.bn1 = norm_layer(mid_chs, **norm_kwargs) |
|
self.act1 = act_layer(inplace=True) |
|
|
|
|
|
if se_ratio is not None and se_ratio > 0.: |
|
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) |
|
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) |
|
else: |
|
self.se = nn.Identity() |
|
|
|
|
|
self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type) |
|
        self.bn2 = norm_layer(out_chs, **norm_kwargs)
|
|
|
def forward(self, x): |
|
residual = x |
|
|
|
|
|
x = self.conv_exp(x) |
|
x = self.bn1(x) |
|
x = self.act1(x) |
|
|
|
|
|
x = self.se(x) |
|
|
|
|
|
x = self.conv_pwl(x) |
|
x = self.bn2(x) |
|
|
|
if self.has_residual: |
|
if self.drop_connect_rate > 0.: |
|
x = drop_connect(x, self.training, self.drop_connect_rate) |
|
x += residual |
|
|
|
return x |
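
# Sketch: an EdgeTPU 'er' block expands with a full (e.g. 3x3) conv instead of a 1x1,
# then projects with a strided pointwise-linear conv:
#
#   block = EdgeResidual(24, 32, exp_kernel_size=3, exp_ratio=4.0, stride=2, act_layer=nn.ReLU)
#   out = block(torch.randn(1, 24, 56, 56))   # -> (1, 32, 28, 28)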
|
|
|
|
|
class EfficientNetBuilder: |
|
""" Build Trunk Blocks for Efficient/Mobile Networks |
|
|
|
This ended up being somewhat of a cross between |
|
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py |
|
and |
|
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py |
|
|
|
""" |
|
|
|
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, |
|
pad_type='', act_layer=None, se_kwargs=None, |
|
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): |
|
self.channel_multiplier = channel_multiplier |
|
self.channel_divisor = channel_divisor |
|
self.channel_min = channel_min |
|
self.pad_type = pad_type |
|
self.act_layer = act_layer |
|
self.se_kwargs = se_kwargs |
|
self.norm_layer = norm_layer |
|
self.norm_kwargs = norm_kwargs |
|
self.drop_connect_rate = drop_connect_rate |
|
|
|
|
|
self.in_chs = None |
|
self.block_idx = 0 |
|
self.block_count = 0 |
|
|
|
def _round_channels(self, chs): |
|
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) |
|
|
|
def _make_block(self, ba): |
|
bt = ba.pop('block_type') |
|
ba['in_chs'] = self.in_chs |
|
ba['out_chs'] = self._round_channels(ba['out_chs']) |
|
if 'fake_in_chs' in ba and ba['fake_in_chs']: |
|
|
|
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) |
|
ba['norm_layer'] = self.norm_layer |
|
ba['norm_kwargs'] = self.norm_kwargs |
|
ba['pad_type'] = self.pad_type |
|
|
|
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer |
|
assert ba['act_layer'] is not None |
|
if bt == 'ir': |
|
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count |
|
ba['se_kwargs'] = self.se_kwargs |
|
if ba.get('num_experts', 0) > 0: |
|
block = CondConvResidual(**ba) |
|
else: |
|
block = InvertedResidual(**ba) |
|
elif bt == 'ds' or bt == 'dsa': |
|
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count |
|
ba['se_kwargs'] = self.se_kwargs |
|
block = DepthwiseSeparableConv(**ba) |
|
elif bt == 'er': |
|
ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count |
|
ba['se_kwargs'] = self.se_kwargs |
|
block = EdgeResidual(**ba) |
|
elif bt == 'cn': |
|
block = ConvBnAct(**ba) |
|
else: |
|
            assert False, 'Unknown block type (%s) while building model.' % bt
|
self.in_chs = ba['out_chs'] |
|
return block |
|
|
|
def _make_stack(self, stack_args): |
|
blocks = [] |
|
|
|
for i, ba in enumerate(stack_args): |
|
if i >= 1: |
|
|
|
ba['stride'] = 1 |
|
block = self._make_block(ba) |
|
blocks.append(block) |
|
self.block_idx += 1 |
|
return nn.Sequential(*blocks) |
|
|
|
def __call__(self, in_chs, block_args): |
|
""" Build the blocks |
|
Args: |
|
in_chs: Number of input-channels passed to first block |
|
block_args: A list of lists, outer list defines stages, inner |
|
list contains strings defining block configuration(s) |
|
Return: |
|
List of block stacks (each stack wrapped in nn.Sequential) |
|
""" |
|
self.in_chs = in_chs |
|
self.block_count = sum([len(x) for x in block_args]) |
|
self.block_idx = 0 |
|
blocks = [] |
|
|
|
for stack_idx, stack in enumerate(block_args): |
|
assert isinstance(stack, list) |
|
stack = self._make_stack(stack) |
|
blocks.append(stack) |
|
return blocks |
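
# End-to-end sketch: decode an arch definition and build the trunk stages. The two-stage
# arch_def below is illustrative only, not a real model definition.
#
#   arch_def = [
#       ['ds_r1_k3_s1_e1_c16'],
#       ['ir_r2_k3_s2_e4_c24'],
#   ]
#   builder = EfficientNetBuilder(act_layer=get_act_layer('relu'))
#   stages = builder(16, decode_arch_def(arch_def))      # list of two nn.Sequential stacks
#   out = nn.Sequential(*stages)(torch.randn(1, 16, 112, 112))   # -> (1, 24, 56, 56)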
|
|
|
|
|
def _parse_ksize(ss): |
|
if ss.isdigit(): |
|
return int(ss) |
|
else: |
|
return [int(k) for k in ss.split('.')] |
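
# '3' -> 3, while a mixed-kernel spec such as '3.5.7' -> [3, 5, 7]
# (dot-separated kernel sizes are used for MixNet-style mixed depthwise convs).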
|
|
|
|
|
def _decode_block_str(block_str): |
|
""" Decode block definition string |
|
|
|
Gets a list of block arg (dicts) through a string notation of arguments. |
|
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip |
|
|
|
All args can exist in any order with the exception of the leading string which |
|
is assumed to indicate the block type. |
|
|
|
    leading string - block type (
      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act,
      er = EdgeResidual, cn = ConvBnAct)
|
r - number of repeat blocks, |
|
k - kernel size, |
|
s - strides (1-9), |
|
e - expansion ratio, |
|
c - output channels, |
|
se - squeeze/excitation ratio |
|
n - activation fn ('re', 'r6', 'hs', or 'sw') |
|
Args: |
|
block_str: a string representation of block arguments. |
|
Returns: |
|
A list of block args (dicts) |
|
Raises: |
|
        ValueError: if the string def is not properly specified (TODO)
|
""" |
|
assert isinstance(block_str, str) |
|
ops = block_str.split('_') |
|
block_type = ops[0] |
|
ops = ops[1:] |
|
options = {} |
|
noskip = False |
|
for op in ops: |
|
|
|
if op == 'noskip': |
|
noskip = True |
|
elif op.startswith('n'): |
|
|
|
key = op[0] |
|
v = op[1:] |
|
if v == 're': |
|
value = get_act_layer('relu') |
|
elif v == 'r6': |
|
value = get_act_layer('relu6') |
|
elif v == 'hs': |
|
value = get_act_layer('hard_swish') |
|
elif v == 'sw': |
|
value = get_act_layer('swish') |
|
else: |
|
continue |
|
options[key] = value |
|
else: |
|
|
|
splits = re.split(r'(\d.*)', op) |
|
if len(splits) >= 2: |
|
key, value = splits[:2] |
|
options[key] = value |
|
|
|
|
|
act_layer = options['n'] if 'n' in options else None |
|
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 |
|
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 |
|
fake_in_chs = int(options['fc']) if 'fc' in options else 0 |
|
|
|
num_repeat = int(options['r']) |
|
|
|
if block_type == 'ir': |
|
block_args = dict( |
|
block_type=block_type, |
|
dw_kernel_size=_parse_ksize(options['k']), |
|
exp_kernel_size=exp_kernel_size, |
|
pw_kernel_size=pw_kernel_size, |
|
out_chs=int(options['c']), |
|
exp_ratio=float(options['e']), |
|
se_ratio=float(options['se']) if 'se' in options else None, |
|
stride=int(options['s']), |
|
act_layer=act_layer, |
|
noskip=noskip, |
|
) |
|
if 'cc' in options: |
|
block_args['num_experts'] = int(options['cc']) |
|
elif block_type == 'ds' or block_type == 'dsa': |
|
block_args = dict( |
|
block_type=block_type, |
|
dw_kernel_size=_parse_ksize(options['k']), |
|
pw_kernel_size=pw_kernel_size, |
|
out_chs=int(options['c']), |
|
se_ratio=float(options['se']) if 'se' in options else None, |
|
stride=int(options['s']), |
|
act_layer=act_layer, |
|
pw_act=block_type == 'dsa', |
|
noskip=block_type == 'dsa' or noskip, |
|
) |
|
elif block_type == 'er': |
|
block_args = dict( |
|
block_type=block_type, |
|
exp_kernel_size=_parse_ksize(options['k']), |
|
pw_kernel_size=pw_kernel_size, |
|
out_chs=int(options['c']), |
|
exp_ratio=float(options['e']), |
|
fake_in_chs=fake_in_chs, |
|
se_ratio=float(options['se']) if 'se' in options else None, |
|
stride=int(options['s']), |
|
act_layer=act_layer, |
|
noskip=noskip, |
|
) |
|
elif block_type == 'cn': |
|
block_args = dict( |
|
block_type=block_type, |
|
kernel_size=int(options['k']), |
|
out_chs=int(options['c']), |
|
stride=int(options['s']), |
|
act_layer=act_layer, |
|
) |
|
else: |
|
assert False, 'Unknown block type (%s)' % block_type |
|
|
|
return block_args, num_repeat |
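
# Decoding sketch:
#
#   ba, rep = _decode_block_str('ir_r2_k3_s2_e6_c24_se0.25')
#   # rep == 2 and ba == dict(block_type='ir', dw_kernel_size=3, exp_kernel_size=1,
#   #                         pw_kernel_size=1, out_chs=24, exp_ratio=6.0, se_ratio=0.25,
#   #                         stride=2, act_layer=None, noskip=False)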
|
|
|
|
|
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): |
|
""" Per-stage depth scaling |
|
Scales the block repeats in each stage. This depth scaling impl maintains |
|
compatibility with the EfficientNet scaling method, while allowing sensible |
|
scaling for other models that may have multiple block arg definitions in each stage. |
|
""" |
|
|
|
|
|
|
|
    # There may be multiple block arg definitions per stage, so sum the repeats first.
    num_repeat = sum(repeats)
|
if depth_trunc == 'round': |
|
|
|
|
|
|
|
        # Truncating via 'round' keeps stages with few repeats proportionally smaller for longer.
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
|
else: |
|
|
|
|
|
        # The EfficientNet default truncates via 'ceil'; any multiplier > 1.0 deepens every stage.
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
|
|
|
|
|
|
|
|
|
    # Distribute the scaled repeat count proportionally across the block definitions in the
    # stage, walking in reverse so the first block is the least likely to gain extra repeats.
    repeats_scaled = []
|
for r in repeats[::-1]: |
|
rs = max(1, round((r / num_repeat * num_repeat_scaled))) |
|
repeats_scaled.append(rs) |
|
num_repeat -= r |
|
num_repeat_scaled -= rs |
|
repeats_scaled = repeats_scaled[::-1] |
|
|
|
|
|
sa_scaled = [] |
|
for ba, rep in zip(stack_args, repeats_scaled): |
|
sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) |
|
return sa_scaled |
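
# Worked example: a stage with repeats [2] and depth_multiplier=1.2 under 'ceil' scales to
# ceil(2 * 1.2) = 3 repeats, i.e. the single block arg dict is deep-copied three times.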
|
|
|
|
|
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): |
|
arch_args = [] |
|
for stack_idx, block_strings in enumerate(arch_def): |
|
assert isinstance(block_strings, list) |
|
stack_args = [] |
|
repeats = [] |
|
for block_str in block_strings: |
|
assert isinstance(block_str, str) |
|
ba, rep = _decode_block_str(block_str) |
|
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: |
|
ba['num_experts'] *= experts_multiplier |
|
stack_args.append(ba) |
|
repeats.append(rep) |
|
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): |
|
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) |
|
else: |
|
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) |
|
return arch_args |
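
# Note: with fix_first_last=True the first and last stages keep their original depth
# (multiplier 1.0), which is useful when stem/head stages should not be deepened.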
|
|
|
|
|
def initialize_weight_goog(m, n='', fix_group_fanout=True): |
|
|
|
|
|
if isinstance(m, CondConv2d): |
|
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels |
|
if fix_group_fanout: |
|
fan_out //= m.groups |
|
init_weight_fn = get_condconv_initializer( |
|
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) |
|
init_weight_fn(m.weight) |
|
if m.bias is not None: |
|
m.bias.data.zero_() |
|
elif isinstance(m, nn.Conv2d): |
|
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels |
|
if fix_group_fanout: |
|
fan_out //= m.groups |
|
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) |
|
if m.bias is not None: |
|
m.bias.data.zero_() |
|
elif isinstance(m, nn.BatchNorm2d): |
|
m.weight.data.fill_(1.0) |
|
m.bias.data.zero_() |
|
elif isinstance(m, nn.Linear): |
|
fan_out = m.weight.size(0) |
|
fan_in = 0 |
|
if 'routing_fn' in n: |
|
fan_in = m.weight.size(1) |
|
init_range = 1.0 / math.sqrt(fan_in + fan_out) |
|
m.weight.data.uniform_(-init_range, init_range) |
|
m.bias.data.zero_() |
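
# Typical application (sketch, assuming `model` is a constructed network):
#
#   for n, m in model.named_modules():
#       initialize_weight_goog(m, n)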
|
|
|
|
|
def initialize_weight_default(m, n=''): |
|
if isinstance(m, CondConv2d): |
|
init_fn = get_condconv_initializer(partial( |
|
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) |
|
init_fn(m.weight) |
|
elif isinstance(m, nn.Conv2d): |
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') |
|
elif isinstance(m, nn.BatchNorm2d): |
|
m.weight.data.fill_(1.0) |
|
m.bias.data.zero_() |
|
elif isinstance(m, nn.Linear): |
|
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') |
|
|