# Copyright (c) Open-CD. All rights reserved.
from mmseg.models.backbones.resnest import Bottleneck
from mmseg.models.utils import ResLayer

from opencd.registry import MODELS
from .interaction_resnet import IA_ResNetV1d


@MODELS.register_module()
class IA_ResNeSt(IA_ResNetV1d):
    """Interaction ResNeSt backbone.

    Implementation of the backbone described in the paper
    "ResNeSt: Split-Attention Networks", built on top of
    ``IA_ResNetV1d``.

    Args:
        groups (int): Number of groups of Bottleneck. Default: 1.
        base_width (int): Base width of Bottleneck. Default: 4.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Keyword arguments forwarded unchanged to the parent
            class constructor.
    """

    # Each entry maps a key to (block class, number of blocks per stage).
    # NOTE(review): presumably the key is the network depth selected through
    # the parent class -- confirm against ``IA_ResNetV1d``.
    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3)),
    }

    def __init__(self,
                 groups=1,
                 base_width=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        """Store the ResNeSt-specific options, then run the parent setup."""
        # These attributes are read by ``make_res_layer``, so they are
        # assigned before the parent constructor runs.
        # NOTE(review): presumably the parent constructor builds its stages
        # through ``make_res_layer`` -- confirm before reordering.
        self.avg_down_stride = avg_down_stride
        self.reduction_factor = reduction_factor
        self.radix = radix
        self.base_width = base_width
        self.groups = groups
        super().__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``.

        Args:
            kwargs (dict): Per-stage keyword arguments passed on to
                ``ResLayer`` together with the options stored on ``self``.

        Returns:
            ResLayer: The layer built from the combined keyword arguments.
        """
        # ``base_channels`` is not assigned in this class; it is expected to
        # already exist on the instance when this method is called.
        stage_options = dict(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride,
        )
        # Both mappings are unpacked in the call itself (not merged first) so
        # a key present in both still raises ``TypeError`` instead of being
        # silently overridden.
        return ResLayer(**stage_options, **kwargs)