# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
from mmseg.models.backbones import ResNet

from opencd.registry import MODELS


@MODELS.register_module()
class IA_ResNet(ResNet):
    """Interaction ResNet backbone.

    Two inputs are passed through the same stem and the same residual
    stages. After every stage the pair of feature maps is handed to an
    interaction layer, and for each stage listed in ``out_indices`` the two
    resulting maps are concatenated along the channel dimension.

    Args:
        interaction_cfg (Sequence[dict]): Interaction strategies for the
            stages. The length should be the same as `num_stages`.
            The details can be found in
            `opencd/models/utils/interaction_layer.py`.
            Default: (None, None, None, None).
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Number of stem channels. Default: 64.
        base_channels (int): Number of base channels of res layer.
            Default: 64.
        num_stages (int): Resnet stages, normally 4. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: (1, 2, 2, 2).
        dilations (Sequence[int]): Dilation of each stage.
            Default: (1, 1, 1, 1).
        out_indices (Sequence[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        style (str): `pytorch` or `caffe`. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the
            stride-two layer is the first 1x1 conv layer.
            Default: 'pytorch'.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): Dictionary to construct and config conv
            layer. When conv_cfg is None, cfg will be set to
            dict(type='Conv2d'). Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (dict | None): Dictionary to construct and config DCN conv
            layer. When dcn is not None, conv_cfg must be None.
            Default: None.
        stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each
            stage. The length of stage_with_dcn is equal to num_stages.
            Default: (False, False, False, False).
        plugins (list[dict]): List of plugins for stages, each dict
            contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - position (str, required): Position inside block to insert
              plugin, options: 'after_conv1', 'after_conv2',
              'after_conv3'.
            - stages (tuple[bool], optional): Stages to apply plugin,
              length should be same as 'num_stages'.

            Default: None.
        multi_grid (Sequence[int] | None): Multi grid dilation rates of the
            last stage. Default: None.
        contract_dilation (bool): Whether contract first dilation of each
            layer. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm
            layer in resblocks to let them behave as identity.
            Default: True.
        pretrained (str, optional): model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config
            dict. Default: None.

    Example:
        >>> from opencd.models import IA_ResNet
        >>> import torch
        >>> self = IA_ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs, inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 128, 8, 8)
        (1, 256, 4, 4)
        (1, 512, 2, 2)
        (1, 1024, 1, 1)
    """

    def __init__(self, interaction_cfg=(None, None, None, None), **kwargs):
        super().__init__(**kwargs)
        assert self.num_stages == len(interaction_cfg), \
            'The length of the `interaction_cfg` should be same as the `num_stages`.'

        # One interaction ("cross-correlation") layer per stage; a stage
        # whose config is None gets the registry's 'TwoIdentity' layer.
        self.ccs = nn.ModuleList([
            MODELS.build(
                dict(type='TwoIdentity') if stage_cfg is None else stage_cfg)
            for stage_cfg in interaction_cfg
        ])

    def forward(self, x1, x2):
        """Forward function.

        Args:
            x1: First input, fed to the shared stem and residual stages.
            x2: Second input, processed by the same modules as ``x1``.

        Returns:
            tuple: For every stage index in ``self.out_indices``, the two
            post-interaction feature maps concatenated along dim 1.
        """

        def _apply_stem(feat):
            # Both inputs go through the very same stem modules.
            if self.deep_stem:
                feat = self.stem(feat)
            else:
                feat = self.relu(self.norm1(self.conv1(feat)))
            return self.maxpool(feat)

        x1 = _apply_stem(x1)
        x2 = _apply_stem(x2)

        fused_feats = []
        for stage_idx, stage_name in enumerate(self.res_layers):
            stage = getattr(self, stage_name)
            # Same stage for both branches, then let them interact.
            x1, x2 = self.ccs[stage_idx](stage(x1), stage(x2))
            if stage_idx in self.out_indices:
                fused_feats.append(torch.cat([x1, x2], dim=1))
        return tuple(fused_feats)


@MODELS.register_module()
class IA_ResNetV1c(IA_ResNet):
    """Interaction variant of ResNetV1c.

    Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
    in the input stem with three 3x3 convs. For more details please refer to
    `Bag of Tricks for Image Classification with Convolutional Neural
    Networks <https://arxiv.org/abs/1812.01187>`_.
    """

    def __init__(self, **kwargs):
        super().__init__(deep_stem=True, avg_down=False, **kwargs)


@MODELS.register_module()
class IA_ResNetV1d(IA_ResNet):
    """Interaction variant of ResNetV1d.

    Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv
    in the input stem with three 3x3 convs. And in the downsampling block, a
    2x2 avg_pool with stride 2 is added before conv, whose stride is changed
    to 1. For more details please refer to
    `Bag of Tricks for Image Classification with Convolutional Neural
    Networks <https://arxiv.org/abs/1812.01187>`_.
    """

    def __init__(self, **kwargs):
        super().__init__(deep_stem=True, avg_down=True, **kwargs)