# Copyright (c) Open-CD. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from torch.nn import functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils import checkpoint as cp

from mmseg.models.utils import SELayer, make_divisible

from opencd.registry import MODELS

class AsymGlobalAttn(BaseModule):
    """Global context attention built from asymmetric (strip) convolutions.

    A 1x1 conv followed by (1, k) and (k, 1) depthwise convs gathers
    long-range context along both axes; the result gates a value branch and
    is added back to the input with a learnable layer scale.
    """

    def __init__(self, dim, strip_kernel_size=21):
        super().__init__()
        self.norm = build_norm_layer(
            dict(type='mmpretrain.LN2d', eps=1e-6), dim)[1]
        self.global_ = nn.Sequential(
            nn.Conv2d(dim, dim, 1),
            nn.Conv2d(dim, dim, (1, strip_kernel_size),
                      padding=(0, (strip_kernel_size - 1) // 2), groups=dim),
            nn.Conv2d(dim, dim, (strip_kernel_size, 1),
                      padding=((strip_kernel_size - 1) // 2, 0), groups=dim),
        )
        self.v = nn.Conv2d(dim, dim, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.layer_scale = nn.Parameter(
            1e-6 * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        identity = x
        a = self.global_(x)
        x = a * self.v(x)
        x = self.proj(x)
        x = self.norm(x)
        x = self.layer_scale.unsqueeze(-1).unsqueeze(-1) * x + identity
        return x

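# A minimal usage sketch: AsymGlobalAttn preserves the spatial resolution, so
# it can be appended after any stage whose channel count matches `dim`
# (building it assumes mmpretrain is installed so 'mmpretrain.LN2d' resolves).
#
#   attn = AsymGlobalAttn(dim=48, strip_kernel_size=21)
#   feat = torch.rand(2, 48, 64, 64)
#   out = attn(feat)  # torch.Size([2, 48, 64, 64])
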
class PriorAttention(BaseModule):
    """Channel attention that re-weights two temporal branches.

    The channel-wise statistics of the absolute difference ``|x1 - x2|`` are
    squeezed and expanded into one sigmoid gate per path, which is applied
    residually to each branch.
    """

    def __init__(self,
                 channels,
                 num_paths=2,
                 attn_channels=None,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super(PriorAttention, self).__init__()
        self.num_paths = num_paths  # only `2` is supported
        attn_channels = attn_channels or channels // 16
        attn_channels = max(attn_channels, 8)
        self.fc_reduce = nn.Conv2d(
            channels, attn_channels, kernel_size=1, bias=False)
        self.bn = build_norm_layer(norm_cfg, attn_channels)[1]
        self.act = build_activation_layer(act_cfg)
        self.fc_select = nn.Conv2d(
            attn_channels, channels * num_paths, kernel_size=1, bias=False)

    def forward(self, x1, x2):
        x = torch.abs(x1 - x2)
        attn = x.mean((2, 3), keepdim=True)
        attn = self.fc_reduce(attn)
        attn = self.bn(attn)
        attn = self.act(attn)
        attn = self.fc_select(attn)
        B, C, H, W = attn.shape
        attn1, attn2 = attn.reshape(
            B, self.num_paths, C // self.num_paths, H, W).transpose(0, 1)
        attn1 = torch.sigmoid(attn1)
        attn2 = torch.sigmoid(attn2)
        return x1 * attn1 + x1, x2 * attn2 + x2

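# A minimal usage sketch: both inputs must share the same shape; the per-path
# sigmoid gates are computed from |x1 - x2| and broadcast over H and W.
#
#   interact = PriorAttention(channels=64)
#   t1, t2 = torch.rand(2, 64, 32, 32), torch.rand(2, 64, 32, 32)
#   y1, y2 = interact(t1, t2)  # each torch.Size([2, 64, 32, 32])
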
class StemBlock(BaseModule):
    """Siamese inverted-residual stem block.

    Both temporal inputs share the expansion and depthwise convolutions,
    interact through :class:`PriorAttention`, and are then projected back by
    a shared 1x1 convolution.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the middle (first) 3x3 convolution.
        expand_ratio (int): Adjusts number of channels of the hidden layer
            by this amount.
        dilation (int): Dilation rate of depthwise conv. Default: 1.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Returns:
        tuple[Tensor, Tensor]: The two output tensors.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 expand_ratio,
                 dilation=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 with_cp=False,
                 **kwargs):
        super(StemBlock, self).__init__()
        self.stride = stride
        assert stride in [1, 2], \
            f'stride must be in [1, 2], but received {stride}.'
        self.with_cp = with_cp
        self.use_res_connect = self.stride == 1 and in_channels == out_channels
        hidden_dim = int(round(in_channels * expand_ratio))

        layers = []
        if expand_ratio != 1:
            layers.append(
                ConvModule(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **kwargs))
        layers.extend([
            ConvModule(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                stride=stride,
                padding=dilation,
                dilation=dilation,
                groups=hidden_dim,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                **kwargs),
        ])
        self.conv = nn.Sequential(*layers)
        self.interact = PriorAttention(channels=hidden_dim)
        self.post_conv = ConvModule(
            in_channels=hidden_dim,
            out_channels=out_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None,
            **kwargs)

    def forward(self, x):
        x1, x2 = x
        identity_x1 = x1
        identity_x2 = x2
        x1 = self.conv(x1)
        x2 = self.conv(x2)
        x1, x2 = self.interact(x1, x2)
        x1 = self.post_conv(x1)
        x2 = self.post_conv(x2)

        if self.use_res_connect:
            x1 = x1 + identity_x1
            x2 = x2 + identity_x2
        return x1, x2

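# A minimal usage sketch: StemBlock takes a tuple of two same-shaped feature
# maps and returns the interacted pair, so stacked blocks can be chained with
# nn.Sequential as done in PriorFusion below.
#
#   block = StemBlock(in_channels=16, out_channels=16, stride=1, expand_ratio=4)
#   f1, f2 = torch.rand(2, 16, 64, 64), torch.rand(2, 16, 64, 64)
#   o1, o2 = block((f1, f2))  # each torch.Size([2, 16, 64, 64])
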
class PriorFusion(BaseModule):
    """Fuse the two temporal branches after a stack of stem blocks.

    Returns both the concatenated early features (2 * channels) and the
    fused single-stream features (channels).
    """

    def __init__(self, channels, stack_nums=2):
        super().__init__()
        self.stem = nn.Sequential(
            *[StemBlock(
                in_channels=channels,
                out_channels=channels,
                stride=1,
                expand_ratio=4) for _ in range(stack_nums)])
        self.pseudo_fusion = nn.Sequential(
            nn.Conv2d(channels * 2, channels * 2, 3,
                      padding=1, groups=channels * 2),
            build_norm_layer(
                dict(type='mmpretrain.LN2d', eps=1e-6), channels * 2)[1],
            nn.GELU(),
            nn.Conv2d(channels * 2, channels, 3, padding=1, groups=channels),
        )

    def forward(self, x1, x2):
        identity_x1 = x1
        identity_x2 = x2
        x1, x2 = self.stem((x1, x2))
        x1 = x1 + identity_x1
        x2 = x2 + identity_x2

        early_x = torch.cat([x1, x2], dim=1)
        x = self.pseudo_fusion(early_x)
        return early_x, x

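# A minimal usage sketch: with channels=16 the block returns a 32-channel
# "early" tensor and a 16-channel fused tensor at the same resolution.
#
#   fusion = PriorFusion(channels=16, stack_nums=2)
#   f1, f2 = torch.rand(2, 16, 128, 128), torch.rand(2, 16, 128, 128)
#   early_x, x = fusion(f1, f2)  # (2, 32, 128, 128) and (2, 16, 128, 128)
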
class TinyBlock(BaseModule):
    """InvertedResidual block (MobileNetV2-style) with optional SE attention.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the middle (first) 3x3 convolution.
        expand_ratio (int): Adjusts number of channels of the hidden layer
            by this amount.
        dilation (int): Dilation rate of depthwise conv. Default: 1.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        with_se (bool): Insert an SELayer after the depthwise convolution.
            Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 expand_ratio,
                 dilation=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 with_cp=False,
                 with_se=False,
                 **kwargs):
        super(TinyBlock, self).__init__()
        self.stride = stride
        assert stride in [1, 2], \
            f'stride must be in [1, 2], but received {stride}.'
        self.with_cp = with_cp
        self.use_res_connect = self.stride == 1 and in_channels == out_channels
        hidden_dim = int(round(in_channels * expand_ratio))

        layers = []
        attention_layer = SELayer(hidden_dim) if with_se else nn.Identity()
        if expand_ratio != 1:
            layers.append(
                ConvModule(
                    in_channels=in_channels,
                    out_channels=hidden_dim,
                    kernel_size=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **kwargs))
        layers.extend([
            ConvModule(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                kernel_size=3,
                stride=stride,
                padding=dilation,
                dilation=dilation,
                groups=hidden_dim,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                **kwargs),
            attention_layer,
            ConvModule(
                in_channels=hidden_dim,
                out_channels=out_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
                **kwargs)
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):

        def _inner_forward(x):
            if self.use_res_connect:
                return x + self.conv(x)
            else:
                return self.conv(x)

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        return out

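# A minimal usage sketch: with stride=2 the spatial size is halved; the
# residual connection is only used when stride == 1 and the channel count
# stays the same.
#
#   block = TinyBlock(in_channels=16, out_channels=24, stride=2, expand_ratio=6)
#   y = block(torch.rand(2, 16, 64, 64))  # torch.Size([2, 24, 32, 32])
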
@MODELS.register_module()
class TinyNet(BaseModule):
    """TinyNet backbone.

    A lightweight Siamese backbone for bi-temporal change detection. The two
    input images share a stem, interact in ``PriorFusion`` and are then
    processed as a single fused stream by stacked ``TinyBlock`` stages,
    optionally followed by ``AsymGlobalAttn``.

    Args:
        output_early_x (bool): Output the early features before fusion as the
            first element. Defaults to False.
        arch (str): The model's architecture. It should be one of the
            architectures in ``TinyNet.change_extractor_settings``.
            Defaults to 'B'.
        stem_stack_nums (int): The number of stacked stem blocks.
            Defaults to 2.
        use_global (Sequence[bool]): Whether to append ``AsymGlobalAttn``
            after each stage. Defaults to (True, True, True, True).
        strip_kernel_size (Sequence[int]): The strip kernel size of
            ``AsymGlobalAttn`` for each stage. Defaults to (41, 31, 21, 11).
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        strides (Sequence[int]): Strides of the first block of each layer.
            Default: (1, 2, 2, 2).
        dilations (Sequence[int]): Dilation of each layer.
            Default: (1, 1, 1, 1).
        out_indices (Sequence[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: expand_ratio, channel, num_blocks.
    change_extractor_settings = {
        'S': [[4, 16, 2], [6, 24, 2], [6, 32, 3], [6, 48, 1]],
        'B': [[4, 16, 2], [6, 24, 2], [6, 32, 3], [6, 48, 1]],
        'L': [[4, 16, 2], [6, 24, 2], [6, 32, 6], [6, 48, 1]],
    }

    def __init__(self,
                 output_early_x=False,
                 arch='B',
                 stem_stack_nums=2,
                 use_global=(True, True, True, True),
                 strip_kernel_size=(41, 31, 21, 11),
                 widen_factor=1.,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 norm_eval=False,
                 with_cp=False,
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.arch_settings = self.change_extractor_settings[arch]
        self.pretrained = pretrained
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be specified at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is None:
            if init_cfg is None:
                self.init_cfg = [
                    dict(type='Kaiming', layer='Conv2d'),
                    dict(
                        type='Constant',
                        val=1,
                        layer=['_BatchNorm', 'GroupNorm'])
                ]
        else:
            raise TypeError('pretrained must be a str or None')

        self.widen_factor = widen_factor
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == len(self.arch_settings)
        self.out_indices = out_indices
        for index in out_indices:
            if index not in range(0, 7):
                raise ValueError('the item in out_indices must be in '
                                 f'range(0, 7). But received {index}')

        if frozen_stages not in range(-1, 7):
            raise ValueError('frozen_stages must be in range(-1, 7). '
                             f'But received {frozen_stages}')
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.in_channels = make_divisible(16 * widen_factor, 8)

        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.fusion_block = PriorFusion(self.in_channels, stem_stack_nums)

        self.layers = []
        self.use_global = use_global
        self.strip_kernel_size = strip_kernel_size

        for i, layer_cfg in enumerate(self.arch_settings):
            expand_ratio, channel, num_blocks = layer_cfg
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = make_divisible(channel * widen_factor, 8)
            inverted_res_layer = self.make_layer(
                out_channels=out_channels,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                expand_ratio=expand_ratio,
                use_global=use_global[i],
                strip_kernel_size=self.strip_kernel_size[i])
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(layer_name)

        self.output_early_x = output_early_x

    def make_layer(self, out_channels, num_blocks, stride, dilation,
                   expand_ratio, use_global, strip_kernel_size):
        """Stack TinyBlock blocks to build a layer, optionally followed by
        an AsymGlobalAttn.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): Number of blocks.
            stride (int): Stride of the first block.
            dilation (int): Dilation of the first block.
            expand_ratio (int): Expand the number of channels of the
                hidden layer in TinyBlock by this ratio.
            use_global (bool): Whether to append an ``AsymGlobalAttn``
                after the stacked blocks.
            strip_kernel_size (int): Strip kernel size of the appended
                ``AsymGlobalAttn``.
        """
        layers = []
        for i in range(num_blocks):
            layers.append(
                TinyBlock(
                    self.in_channels,
                    out_channels,
                    stride if i == 0 else 1,
                    expand_ratio=expand_ratio,
                    dilation=dilation if i == 0 else 1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.in_channels = out_channels

        # global context attention after the stage
        if use_global:
            layers.append(
                AsymGlobalAttn(out_channels, strip_kernel_size))

        return nn.Sequential(*layers)

    def forward(self, x1, x2):
        x1 = self.conv1(x1)
        x2 = self.conv1(x2)

        early_x, x = self.fusion_block(x1, x2)

        if self.output_early_x:
            outs = [early_x]
        else:
            outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                layer = getattr(self, f'layer{i}')
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(TinyNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

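# A minimal smoke-test sketch. It assumes the surrounding Open-CD stack
# (mmengine, mmcv, mmseg, mmpretrain, opencd) is installed so that the
# 'mmpretrain.LN2d' norm layer and the MODELS registry resolve.
if __name__ == '__main__':
    model = TinyNet(arch='B', widen_factor=1., out_indices=(0, 1, 2, 3))
    model.eval()
    t1 = torch.rand(1, 3, 256, 256)  # image at time 1
    t2 = torch.rand(1, 3, 256, 256)  # image at time 2
    with torch.no_grad():
        feats = model(t1, t2)
    # Expected feature pyramid with widen_factor=1.0:
    # (1, 16, 128, 128), (1, 24, 64, 64), (1, 32, 32, 32), (1, 48, 16, 16)
    for f in feats:
        print(tuple(f.shape))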