File size: 5,330 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize


class CascadeFeatureFusion(BaseModule):
    """Cascade Feature Fusion (CFF) unit of ICNet.

    The low resolution input is resized to the spatial size of the high
    resolution input, both are projected to ``out_channels`` by their own
    conv module, and the two projections are summed and passed through ReLU.

    Args:
        low_channels (int): Channel count of the low resolution feature map.
        high_channels (int): Channel count of the high resolution feature
            map.
        out_channels (int): Channel count of both outputs.
        conv_cfg (dict): Config used to build the conv layers.
            Default: None.
        norm_cfg (dict): Config used to build the norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config used to build the activation layers.
            Default: dict(type='ReLU').
        align_corners (bool): Forwarded as ``align_corners`` when resizing
            the low resolution input. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (Tensor): The fused output tensor of shape
            (N, out_channels, H, W).
        x_low (Tensor): The projected low resolution branch of shape
            (N, out_channels, H, W), intended for Cascade Label Guidance
            in auxiliary heads.
    """

    def __init__(self,
                 low_channels,
                 high_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.align_corners = align_corners

        # Layer configs shared by the two projection branches.
        layer_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Low resolution branch: 3x3 conv with dilation 2 and padding 2.
        self.conv_low = ConvModule(
            in_channels=low_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=2,
            dilation=2,
            **layer_cfg)
        # High resolution branch: 1x1 projection.
        self.conv_high = ConvModule(
            in_channels=high_channels,
            out_channels=out_channels,
            kernel_size=1,
            **layer_cfg)

    def forward(self, x_low, x_high):
        """Fuse ``x_low`` into ``x_high``.

        Returns:
            tuple[Tensor, Tensor]: The fused feature map and the projected
            low resolution feature map (before fusion).
        """
        # Bring the low resolution map to the spatial size of `x_high`.
        target_size = x_high.shape[2:]
        upsampled = resize(
            x_low,
            size=target_size,
            mode='bilinear',
            align_corners=self.align_corners)

        # Note: Unlike the original paper, the auxiliary-head feature is the
        # output of `self.conv_low` rather than of a separate 1x1 conv
        # classifier.
        low_feat = self.conv_low(upsampled)
        high_feat = self.conv_high(x_high)

        # The sum is a new tensor, so the in-place ReLU leaves `low_feat`
        # untouched.
        fused = F.relu(low_feat + high_feat, inplace=True)
        return fused, low_feat


@MODELS.register_module()
class ICNeck(BaseModule):
    """ICNet for Real-Time Semantic Segmentation on High-Resolution Images.

    This neck is the implementation of the cascade feature fusion stage of
    `ICNet <https://arxiv.org/abs/1704.08545>`_. It chains two
    :class:`CascadeFeatureFusion` units: the first fuses the third input
    into the second, the second fuses that result into the first input.

    Args:
        in_channels (Sequence[int]): Channel counts of the three input
            feature maps, in the same order as the maps passed to
            ``forward`` (``x_sub1``, ``x_sub2``, ``x_sub4``). Must have
            length 3. Default: (64, 256, 256).
        out_channels (int): The numbers of output feature channels.
            Default: 128.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Dictionary to construct and config act layer.
            Default: dict(type='ReLU').
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=(64, 256, 256),
                 out_channels=128,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Adjacent string literals are used instead of a backslash
        # continuation so the message does not contain source indentation.
        assert len(in_channels) == 3, (
            'Length of input channels '
            'must be 3!')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners
        # Fuses the third input (low resolution side) into the second input.
        self.cff_24 = CascadeFeatureFusion(
            self.in_channels[2],
            self.in_channels[1],
            self.out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)

        # Fuses the output of `cff_24` (which has `out_channels` channels)
        # into the first input.
        self.cff_12 = CascadeFeatureFusion(
            self.out_channels,
            self.in_channels[0],
            self.out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)

    def forward(self, inputs):
        """Run the two cascade feature fusion stages.

        Args:
            inputs (Sequence[Tensor]): Exactly three feature maps, unpacked
                as ``(x_sub1, x_sub2, x_sub4)``.

        Returns:
            tuple[Tensor, Tensor, Tensor]: ``(x_24, x_12, x_cff_12)``.
            ``x_24`` and ``x_12`` are the low resolution branch outputs of
            the two fusion units; ``x_cff_12`` is the final fused feature.
        """
        assert len(inputs) == 3, (
            'Length of input feature '
            'maps must be 3!')

        x_sub1, x_sub2, x_sub4 = inputs
        x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2)
        x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1)
        # Note: `x_cff_12` is used for decode_head,
        # `x_24` and `x_12` are used for auxiliary head.
        return x_24, x_12, x_cff_12