# Copyright (c) OpenMMLab. All rights reserved. import torch.nn.functional as F from mmcv.cnn import ConvModule from mmengine.model import BaseModule from mmseg.registry import MODELS from ..utils import resize class CascadeFeatureFusion(BaseModule): """Cascade Feature Fusion Unit in ICNet. Args: low_channels (int): The number of input channels for low resolution feature map. high_channels (int): The number of input channels for high resolution feature map. out_channels (int): The number of output channels. conv_cfg (dict): Dictionary to construct and config conv layer. Default: None. norm_cfg (dict): Dictionary to construct and config norm layer. Default: dict(type='BN'). act_cfg (dict): Dictionary to construct and config act layer. Default: dict(type='ReLU'). align_corners (bool): align_corners argument of F.interpolate. Default: False. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. Returns: x (Tensor): The output tensor of shape (N, out_channels, H, W). x_low (Tensor): The output tensor of shape (N, out_channels, H, W) for Cascade Label Guidance in auxiliary heads. """ def __init__(self, low_channels, high_channels, out_channels, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), align_corners=False, init_cfg=None): super().__init__(init_cfg=init_cfg) self.align_corners = align_corners self.conv_low = ConvModule( low_channels, out_channels, 3, padding=2, dilation=2, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) self.conv_high = ConvModule( high_channels, out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg) def forward(self, x_low, x_high): x_low = resize( x_low, size=x_high.size()[2:], mode='bilinear', align_corners=self.align_corners) # Note: Different from original paper, `x_low` is underwent # `self.conv_low` rather than another 1x1 conv classifier # before being used for auxiliary head. x_low = self.conv_low(x_low) x_high = self.conv_high(x_high) x = x_low + x_high x = F.relu(x, inplace=True) return x, x_low @MODELS.register_module() class ICNeck(BaseModule): """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. This head is the implementation of `ICHead `_. Args: in_channels (int): The number of input image channels. Default: 3. out_channels (int): The numbers of output feature channels. Default: 128. conv_cfg (dict): Dictionary to construct and config conv layer. Default: None. norm_cfg (dict): Dictionary to construct and config norm layer. Default: dict(type='BN'). act_cfg (dict): Dictionary to construct and config act layer. Default: dict(type='ReLU'). align_corners (bool): align_corners argument of F.interpolate. Default: False. init_cfg (dict or list[dict], optional): Initialization config dict. Default: None. """ def __init__(self, in_channels=(64, 256, 256), out_channels=128, conv_cfg=None, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), align_corners=False, init_cfg=None): super().__init__(init_cfg=init_cfg) assert len(in_channels) == 3, 'Length of input channels \ must be 3!' self.in_channels = in_channels self.out_channels = out_channels self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg self.align_corners = align_corners self.cff_24 = CascadeFeatureFusion( self.in_channels[2], self.in_channels[1], self.out_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, align_corners=self.align_corners) self.cff_12 = CascadeFeatureFusion( self.out_channels, self.in_channels[0], self.out_channels, conv_cfg=self.conv_cfg, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg, align_corners=self.align_corners) def forward(self, inputs): assert len(inputs) == 3, 'Length of input feature \ maps must be 3!' x_sub1, x_sub2, x_sub4 = inputs x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2) x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1) # Note: `x_cff_12` is used for decode_head, # `x_24` and `x_12` are used for auxiliary head. return x_24, x_12, x_cff_12