# Copyright (c) OpenMMLab. All rights reserved. """Modified from https://github.com/MichaelFan01/STDC-Seg.""" import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule from mmengine.model import BaseModule, ModuleList, Sequential from mmseg.registry import MODELS from ..utils import resize from .bisenetv1 import AttentionRefinementModule class STDCModule(BaseModule): """STDCModule. Args: in_channels (int): The number of input channels. out_channels (int): The number of output channels before scaling. stride (int): The number of stride for the first conv layer. norm_cfg (dict): Config dict for normalization layer. Default: None. act_cfg (dict): The activation config for conv layers. num_convs (int): Numbers of conv layers. fusion_type (str): Type of fusion operation. Default: 'add'. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. """ def __init__(self, in_channels, out_channels, stride, norm_cfg=None, act_cfg=None, num_convs=4, fusion_type='add', init_cfg=None): super().__init__(init_cfg=init_cfg) assert num_convs > 1 assert fusion_type in ['add', 'cat'] self.stride = stride self.with_downsample = True if self.stride == 2 else False self.fusion_type = fusion_type self.layers = ModuleList() conv_0 = ConvModule( in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg) if self.with_downsample: self.downsample = ConvModule( out_channels // 2, out_channels // 2, kernel_size=3, stride=2, padding=1, groups=out_channels // 2, norm_cfg=norm_cfg, act_cfg=None) if self.fusion_type == 'add': self.layers.append(nn.Sequential(conv_0, self.downsample)) self.skip = Sequential( ConvModule( in_channels, in_channels, kernel_size=3, stride=2, padding=1, groups=in_channels, norm_cfg=norm_cfg, act_cfg=None), ConvModule( in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)) else: self.layers.append(conv_0) self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) else: self.layers.append(conv_0) for i in range(1, num_convs): out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i self.layers.append( ConvModule( out_channels // 2**i, out_channels // out_factor, kernel_size=3, stride=1, padding=1, norm_cfg=norm_cfg, act_cfg=act_cfg)) def forward(self, inputs): if self.fusion_type == 'add': out = self.forward_add(inputs) else: out = self.forward_cat(inputs) return out def forward_add(self, inputs): layer_outputs = [] x = inputs.clone() for layer in self.layers: x = layer(x) layer_outputs.append(x) if self.with_downsample: inputs = self.skip(inputs) return torch.cat(layer_outputs, dim=1) + inputs def forward_cat(self, inputs): x0 = self.layers[0](inputs) layer_outputs = [x0] for i, layer in enumerate(self.layers[1:]): if i == 0: if self.with_downsample: x = layer(self.downsample(x0)) else: x = layer(x0) else: x = layer(x) layer_outputs.append(x) if self.with_downsample: layer_outputs[0] = self.skip(x0) return torch.cat(layer_outputs, dim=1) class FeatureFusionModule(BaseModule): """Feature Fusion Module. This module is different from FeatureFusionModule in BiSeNetV1. It uses two ConvModules in `self.attention` whose inter channel number is calculated by given `scale_factor`, while FeatureFusionModule in BiSeNetV1 only uses one ConvModule in `self.conv_atten`. Args: in_channels (int): The number of input channels. out_channels (int): The number of output channels. scale_factor (int): The number of channel scale factor. Default: 4. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN'). act_cfg (dict): The activation config for conv layers. Default: dict(type='ReLU'). init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. """ def __init__(self, in_channels, out_channels, scale_factor=4, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), init_cfg=None): super().__init__(init_cfg=init_cfg) channels = out_channels // scale_factor self.conv0 = ConvModule( in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg) self.attention = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), ConvModule( out_channels, channels, 1, norm_cfg=None, bias=False, act_cfg=act_cfg), ConvModule( channels, out_channels, 1, norm_cfg=None, bias=False, act_cfg=None), nn.Sigmoid()) def forward(self, spatial_inputs, context_inputs): inputs = torch.cat([spatial_inputs, context_inputs], dim=1) x = self.conv0(inputs) attn = self.attention(x) x_attn = x * attn return x_attn + x @MODELS.register_module() class STDCNet(BaseModule): """This backbone is the implementation of `Rethinking BiSeNet For Real-time Semantic Segmentation `_. Args: stdc_type (int): The type of backbone structure, `STDCNet1` and`STDCNet2` denotes two main backbones in paper, whose FLOPs is 813M and 1446M, respectively. in_channels (int): The num of input_channels. channels (tuple[int]): The output channels for each stage. bottleneck_type (str): The type of STDC Module type, the value must be 'add' or 'cat'. norm_cfg (dict): Config dict for normalization layer. act_cfg (dict): The activation config for conv layers. num_convs (int): Numbers of conv layer at each STDC Module. Default: 4. with_final_conv (bool): Whether add a conv layer at the Module output. Default: True. pretrained (str, optional): Model pretrained path. Default: None. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. Example: >>> import torch >>> stdc_type = 'STDCNet1' >>> in_channels = 3 >>> channels = (32, 64, 256, 512, 1024) >>> bottleneck_type = 'cat' >>> inputs = torch.rand(1, 3, 1024, 2048) >>> self = STDCNet(stdc_type, in_channels, ... channels, bottleneck_type).eval() >>> outputs = self.forward(inputs) >>> for i in range(len(outputs)): ... print(f'outputs[{i}].shape = {outputs[i].shape}') outputs[0].shape = torch.Size([1, 256, 128, 256]) outputs[1].shape = torch.Size([1, 512, 64, 128]) outputs[2].shape = torch.Size([1, 1024, 32, 64]) """ arch_settings = { 'STDCNet1': [(2, 1), (2, 1), (2, 1)], 'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)] } def __init__(self, stdc_type, in_channels, channels, bottleneck_type, norm_cfg, act_cfg, num_convs=4, with_final_conv=False, pretrained=None, init_cfg=None): super().__init__(init_cfg=init_cfg) assert stdc_type in self.arch_settings, \ f'invalid structure {stdc_type} for STDCNet.' assert bottleneck_type in ['add', 'cat'],\ f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}' assert len(channels) == 5,\ f'invalid channels length {len(channels)} for STDCNet.' self.in_channels = in_channels self.channels = channels self.stage_strides = self.arch_settings[stdc_type] self.prtrained = pretrained self.num_convs = num_convs self.with_final_conv = with_final_conv self.stages = ModuleList([ ConvModule( self.in_channels, self.channels[0], kernel_size=3, stride=2, padding=1, norm_cfg=norm_cfg, act_cfg=act_cfg), ConvModule( self.channels[0], self.channels[1], kernel_size=3, stride=2, padding=1, norm_cfg=norm_cfg, act_cfg=act_cfg) ]) # `self.num_shallow_features` is the number of shallow modules in # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper. # They are both not used for following modules like Attention # Refinement Module and Feature Fusion Module. # Thus they would be cut from `outs`. Please refer to Figure 4 # of original paper for more details. self.num_shallow_features = len(self.stages) for strides in self.stage_strides: idx = len(self.stages) - 1 self.stages.append( self._make_stage(self.channels[idx], self.channels[idx + 1], strides, norm_cfg, act_cfg, bottleneck_type)) # After appending, `self.stages` is a ModuleList including several # shallow modules and STDCModules. # (len(self.stages) == # self.num_shallow_features + len(self.stage_strides)) if self.with_final_conv: self.final_conv = ConvModule( self.channels[-1], max(1024, self.channels[-1]), 1, norm_cfg=norm_cfg, act_cfg=act_cfg) def _make_stage(self, in_channels, out_channels, strides, norm_cfg, act_cfg, bottleneck_type): layers = [] for i, stride in enumerate(strides): layers.append( STDCModule( in_channels if i == 0 else out_channels, out_channels, stride, norm_cfg, act_cfg, num_convs=self.num_convs, fusion_type=bottleneck_type)) return Sequential(*layers) def forward(self, x): outs = [] for stage in self.stages: x = stage(x) outs.append(x) if self.with_final_conv: outs[-1] = self.final_conv(outs[-1]) outs = outs[self.num_shallow_features:] return tuple(outs) @MODELS.register_module() class STDCContextPathNet(BaseModule): """STDCNet with Context Path. The `outs` below is a list of three feature maps from deep to shallow, whose height and width is from small to big, respectively. The biggest feature map of `outs` is outputted for `STDCHead`, where Detail Loss would be calculated by Detail Ground-truth. The other two feature maps are used for Attention Refinement Module, respectively. Besides, the biggest feature map of `outs` and the last output of Attention Refinement Module are concatenated for Feature Fusion Module. Then, this fusion feature map `feat_fuse` would be outputted for `decode_head`. More details please refer to Figure 4 of original paper. Args: backbone_cfg (dict): Config dict for stdc backbone. last_in_channels (tuple(int)), The number of channels of last two feature maps from stdc backbone. Default: (1024, 512). out_channels (int): The channels of output feature maps. Default: 128. ffm_cfg (dict): Config dict for Feature Fusion Module. Default: `dict(in_channels=512, out_channels=256, scale_factor=4)`. upsample_mode (str): Algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | ``'trilinear'``. Default: ``'nearest'``. align_corners (str): align_corners argument of F.interpolate. It must be `None` if upsample_mode is ``'nearest'``. Default: None. norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN'). init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. Return: outputs (tuple): The tuple of list of output feature map for auxiliary heads and decoder head. """ def __init__(self, backbone_cfg, last_in_channels=(1024, 512), out_channels=128, ffm_cfg=dict( in_channels=512, out_channels=256, scale_factor=4), upsample_mode='nearest', align_corners=None, norm_cfg=dict(type='BN'), init_cfg=None): super().__init__(init_cfg=init_cfg) self.backbone = MODELS.build(backbone_cfg) self.arms = ModuleList() self.convs = ModuleList() for channels in last_in_channels: self.arms.append(AttentionRefinementModule(channels, out_channels)) self.convs.append( ConvModule( out_channels, out_channels, 3, padding=1, norm_cfg=norm_cfg)) self.conv_avg = ConvModule( last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg) self.ffm = FeatureFusionModule(**ffm_cfg) self.upsample_mode = upsample_mode self.align_corners = align_corners def forward(self, x): outs = list(self.backbone(x)) avg = F.adaptive_avg_pool2d(outs[-1], 1) avg_feat = self.conv_avg(avg) feature_up = resize( avg_feat, size=outs[-1].shape[2:], mode=self.upsample_mode, align_corners=self.align_corners) arms_out = [] for i in range(len(self.arms)): x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up feature_up = resize( x_arm, size=outs[len(outs) - 1 - i - 1].shape[2:], mode=self.upsample_mode, align_corners=self.align_corners) feature_up = self.convs[i](feature_up) arms_out.append(feature_up) feat_fuse = self.ffm(outs[0], arms_out[1]) # The `outputs` has four feature maps. # `outs[0]` is outputted for `STDCHead` auxiliary head. # Two feature maps of `arms_out` are outputted for auxiliary head. # `feat_fuse` is outputted for decoder head. outputs = [outs[0]] + list(arms_out) + [feat_fuse] return tuple(outputs)