Spaces:
Runtime error
Runtime error
File size: 6,517 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
from mmseg.models.backbones import ResNet
from opencd.registry import MODELS
@MODELS.register_module()
class IA_ResNet(ResNet):
    """ResNet backbone with per-stage interaction between two inputs.

    Both inputs are passed through the same stem and the same residual
    stages (weights are shared). After every stage the two feature maps are
    handed to that stage's interaction layer, and for every stage listed in
    `out_indices` the two resulting maps are concatenated along the channel
    dimension and collected as an output.

    Args:
        interaction_cfg (Sequence[dict]): Interaction strategy for each
            stage. Its length must equal `num_stages`. A `None` entry is
            replaced by `dict(type='TwoIdentity')`. See
            `opencd/models/utils/interaction_layer.py` for the available
            layers. Default: (None, None, None, None).
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Number of stem channels. Default: 64.
        base_channels (int): Number of base channels of res layer.
            Default: 64.
        num_stages (int): Resnet stages, normally 4. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: (1, 2, 2, 2).
        dilations (Sequence[int]): Dilation of each stage.
            Default: (1, 1, 1, 1).
        out_indices (Sequence[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        style (str): `pytorch` or `caffe`. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the
            stride-two layer is the first 1x1 conv layer.
            Default: 'pytorch'.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): Dictionary to construct and config conv
            layer. When conv_cfg is None, cfg will be set to
            dict(type='Conv2d'). Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (dict | None): Dictionary to construct and config DCN conv
            layer. When dcn is not None, conv_cfg must be None.
            Default: None.
        stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each
            stage. The length of stage_with_dcn is equal to num_stages.
            Default: (False, False, False, False).
        plugins (list[dict]): List of plugins for stages, each dict
            contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - position (str, required): Position inside block to insert
              plugin, options: 'after_conv1', 'after_conv2',
              'after_conv3'.
            - stages (tuple[bool], optional): Stages to apply plugin,
              length should be same as 'num_stages'.

            Default: None.
        multi_grid (Sequence[int]|None): Multi grid dilation rates of last
            stage. Default: None.
        contract_dilation (bool): Whether contract first dilation of each
            layer. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm
            layer in resblocks to let them behave as identity.
            Default: True.
        pretrained (str, optional): model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config
            dict. Default: None.

    Example:
        >>> from opencd.models import IA_ResNet
        >>> import torch
        >>> self = IA_ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs, inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 128, 8, 8)
        (1, 256, 4, 4)
        (1, 512, 2, 2)
        (1, 1024, 1, 1)
    """

    def __init__(self,
                 interaction_cfg=(None, None, None, None),
                 **kwargs):
        # Everything except `interaction_cfg` is forwarded to the parent
        # ResNet constructor untouched.
        super().__init__(**kwargs)
        assert len(interaction_cfg) == self.num_stages, \
            'The length of the `interaction_cfg` should be same as the `num_stages`.'

        # One interaction ("cross-correlation") layer per stage; a missing
        # config falls back to the `TwoIdentity` layer.
        interaction_layers = [
            MODELS.build(dict(type='TwoIdentity') if cfg is None else cfg)
            for cfg in interaction_cfg
        ]
        self.ccs = nn.ModuleList(interaction_layers)

    def forward(self, x1, x2):
        """Run both inputs through the shared backbone.

        Args:
            x1: First input, fed to the stem.
            x2: Second input, fed to the same stem and stages as `x1`.

        Returns:
            tuple: One tensor per stage index contained in
            `self.out_indices`, each the channel-wise (dim=1) concatenation
            of the two branches after that stage's interaction layer.
        """

        def _run_stem(image):
            # Same stem logic for both branches: deep stem if enabled,
            # otherwise conv -> norm -> relu, followed by max pooling.
            if not self.deep_stem:
                feat = self.relu(self.norm1(self.conv1(image)))
            else:
                feat = self.stem(image)
            return self.maxpool(feat)

        # The first branch is always processed before the second one.
        feat1 = _run_stem(x1)
        feat2 = _run_stem(x2)

        fused = []
        for stage_idx, stage_name in enumerate(self.res_layers):
            stage = getattr(self, stage_name)
            feat1 = stage(feat1)
            feat2 = stage(feat2)
            # Let the two branches exchange information for this stage.
            feat1, feat2 = self.ccs[stage_idx](feat1, feat2)
            if stage_idx in self.out_indices:
                fused.append(torch.cat((feat1, feat2), dim=1))
        return tuple(fused)
@MODELS.register_module()
class IA_ResNetV1c(IA_ResNet):
    """ResNetV1c flavour of the interaction ResNet backbone.

    Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
    in the input stem with three 3x3 convs. For more details please refer to
    `Bag of Tricks for Image Classification with Convolutional Neural
    Networks <https://arxiv.org/abs/1812.01187>`_.

    The constructor always passes `deep_stem=True` and `avg_down=False` to
    `IA_ResNet`; every other keyword argument is forwarded unchanged.
    """

    def __init__(self, **kwargs):
        super().__init__(deep_stem=True, avg_down=False, **kwargs)
@MODELS.register_module()
class IA_ResNetV1d(IA_ResNet):
    """ResNetV1d flavour of the interaction ResNet backbone.

    Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv
    in the input stem with three 3x3 convs. And in the downsampling block, a
    2x2 avg_pool with stride 2 is added before conv, whose stride is changed
    to 1.

    The constructor always passes `deep_stem=True` and `avg_down=True` to
    `IA_ResNet`; every other keyword argument is forwarded unchanged.
    """

    def __init__(self, **kwargs):
        super().__init__(deep_stem=True, avg_down=True, **kwargs)