Spaces:
Runtime error
Runtime error
File size: 5,330 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmseg.registry import MODELS
from ..utils import resize
class CascadeFeatureFusion(BaseModule):
    """Cascade Feature Fusion (CFF) unit of ICNet.

    The low resolution feature map is bilinearly resized to the spatial
    size of the high resolution one and passed through a 3x3 dilated conv
    (dilation 2, padding 2). The high resolution feature map is passed
    through a 1x1 conv. Both results have ``out_channels`` channels; they
    are summed and passed through ReLU.

    Args:
        low_channels (int): The number of input channels for
            low resolution feature map.
        high_channels (int): The number of input channels for
            high resolution feature map.
        out_channels (int): The number of output channels.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Dictionary to construct and config act layer.
            Default: dict(type='ReLU').
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (Tensor): The fused output tensor of shape
            (N, out_channels, H, W).
        x_low (Tensor): The output tensor of shape (N, out_channels, H, W)
            for Cascade Label Guidance in auxiliary heads.
    """

    def __init__(self,
                 low_channels,
                 high_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.align_corners = align_corners

        # Layer settings shared by both projection branches.
        shared_cfg = dict(
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # 3x3 conv with dilation 2 applied to the resized low branch.
        self.conv_low = ConvModule(
            low_channels, out_channels, 3, padding=2, dilation=2,
            **shared_cfg)
        # 1x1 projection applied to the high resolution branch.
        self.conv_high = ConvModule(
            high_channels, out_channels, 1, **shared_cfg)

    def forward(self, x_low, x_high):
        """Fuse a low resolution feature map into a high resolution one.

        Args:
            x_low (Tensor): Low resolution feature map.
            x_high (Tensor): High resolution feature map; its last two
                dimensions define the target spatial size.

        Returns:
            tuple[Tensor, Tensor]: The fused feature map after ReLU, and
            the projected low resolution branch taken before the sum.
        """
        target_size = x_high.size()[2:]
        upsampled = resize(
            x_low,
            size=target_size,
            mode='bilinear',
            align_corners=self.align_corners)

        # Note: Unlike the original paper, the low branch goes through
        # `self.conv_low` rather than a separate 1x1 conv classifier
        # before it is handed to the auxiliary head.
        low_feat = self.conv_low(upsampled)
        high_feat = self.conv_high(x_high)

        fused = F.relu(low_feat + high_feat, inplace=True)
        return fused, low_feat
@MODELS.register_module()
class ICNeck(BaseModule):
    """Neck of ICNet: fuses three feature maps with cascaded CFF units.

    This neck implements the cascade feature fusion stage of `ICNet for
    Real-Time Semantic Segmentation on High-Resolution Images
    <https://arxiv.org/abs/1704.08545>`_.

    Two :class:`CascadeFeatureFusion` units are chained: ``cff_24`` fuses
    the third input into the second, and ``cff_12`` fuses that result into
    the first.

    Args:
        in_channels (Sequence[int]): The numbers of channels of the three
            input feature maps, in the same order as the ``inputs`` passed
            to :meth:`forward`. Must have length 3.
            Default: (64, 256, 256).
        out_channels (int): The numbers of output feature channels.
            Default: 128.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Dictionary to construct and config act layer.
            Default: dict(type='ReLU').
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=(64, 256, 256),
                 out_channels=128,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Kept as an assertion so callers that expect AssertionError keep
        # working. The message is a single literal so that its text does
        # not depend on how the source happens to be indented.
        assert len(in_channels) == 3, (
            'Length of input channels must be 3!')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # Fuses the third input (low branch) into the second (high branch).
        self.cff_24 = CascadeFeatureFusion(
            self.in_channels[2],
            self.in_channels[1],
            self.out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)

        # Fuses the output of `cff_24` (already `out_channels` wide) into
        # the first input.
        self.cff_12 = CascadeFeatureFusion(
            self.out_channels,
            self.in_channels[0],
            self.out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)

    def forward(self, inputs):
        """Run the two cascaded fusion units.

        Args:
            inputs (Sequence[Tensor]): Exactly three feature maps, unpacked
                as ``(x_sub1, x_sub2, x_sub4)``. Presumably these are the
                full, 1/2 and 1/4 resolution branches of ICNet - confirm
                against the backbone that feeds this neck.

        Returns:
            tuple[Tensor, Tensor, Tensor]: ``(x_24, x_12, x_cff_12)``.
            ``x_24`` and ``x_12`` are the low branch outputs of ``cff_24``
            and ``cff_12`` and are used for the auxiliary heads;
            ``x_cff_12`` is the final fused map used for the decode head.
        """
        assert len(inputs) == 3, (
            'Length of input feature maps must be 3!')
        x_sub1, x_sub2, x_sub4 = inputs

        x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2)
        x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1)

        # Note: `x_cff_12` is used for decode_head,
        # `x_24` and `x_12` are used for auxiliary head.
        return x_24, x_12, x_cff_12
|