Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule | |
from mmengine.model import BaseModule | |
from mmseg.registry import MODELS | |
from ..utils import resize | |
class CascadeFeatureFusion(BaseModule):
    """Cascade Feature Fusion (CFF) unit of ICNet.

    The low resolution feature map is bilinearly resized to the spatial size
    of the high resolution one, passed through a dilated 3x3 conv, and summed
    with the 1x1-projected high resolution feature map. A ReLU is applied to
    the sum.

    Args:
        low_channels (int): Number of input channels of the low resolution
            feature map.
        high_channels (int): Number of input channels of the high resolution
            feature map.
        out_channels (int): Number of output channels.
        conv_cfg (dict): Config used to build the conv layers.
            Default: None.
        norm_cfg (dict): Config used to build the norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config used to build the activation layers.
            Default: dict(type='ReLU').
        align_corners (bool): ``align_corners`` argument forwarded to the
            resize call. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        x (Tensor): The fused output tensor of shape (N, out_channels, H, W).
        x_low (Tensor): The output of ``conv_low`` of shape
            (N, out_channels, H, W), meant for Cascade Label Guidance in
            auxiliary heads.
    """

    def __init__(self,
                 low_channels,
                 high_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.align_corners = align_corners

        # Both branches are built from the same conv/norm/act configuration.
        layer_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Low resolution branch: 3x3 conv with dilation 2 (padding 2).
        self.conv_low = ConvModule(
            low_channels,
            out_channels,
            kernel_size=3,
            padding=2,
            dilation=2,
            **layer_cfg)
        # High resolution branch: 1x1 projection to `out_channels`.
        self.conv_high = ConvModule(
            high_channels, out_channels, kernel_size=1, **layer_cfg)

    def forward(self, x_low, x_high):
        """Fuse a low resolution feature map into a high resolution one.

        Args:
            x_low (Tensor): Low resolution feature map.
            x_high (Tensor): High resolution feature map; its spatial size
                (dims 2 and onward) is the resize target for ``x_low``.

        Returns:
            tuple[Tensor, Tensor]: The ReLU-activated sum of both branches,
            and the ``conv_low`` output on its own.
        """
        target_size = x_high.shape[2:]
        upsampled = resize(
            x_low,
            size=target_size,
            mode='bilinear',
            align_corners=self.align_corners)

        # Note: Unlike the original paper, the low resolution branch goes
        # through `self.conv_low` instead of a separate 1x1 conv classifier
        # before it is handed to the auxiliary head.
        low_feat = self.conv_low(upsampled)
        high_feat = self.conv_high(x_high)

        # The sum is a new tensor, so the in-place ReLU leaves `low_feat`
        # untouched.
        fused = F.relu(low_feat + high_feat, inplace=True)
        return fused, low_feat
# NOTE(review): `MODELS` is imported by this file but no
# `@MODELS.register_module()` decorator is visible on this class; confirm
# whether registration is intended (it may have been lost upstream).
class ICNeck(BaseModule):
    """ICNet for Real-Time Semantic Segmentation on High-Resolution Images.

    This neck implements the cascade feature fusion stage of `ICNet
    <https://arxiv.org/abs/1704.08545>`_. It chains two
    :class:`CascadeFeatureFusion` units: the first fuses the third input
    into the second, the second fuses that result into the first input.

    Args:
        in_channels (Sequence[int]): Channel numbers of the three input
            feature maps, in the same order as the inputs passed to
            ``forward``. Must have length 3. Default: (64, 256, 256).
        out_channels (int): The number of output feature channels.
            Default: 128.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Dictionary to construct and config act layer.
            Default: dict(type='ReLU').
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=(64, 256, 256),
                 out_channels=128,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Single-line message: a backslash continuation inside the literal
        # would embed the next line's indentation into the error text.
        assert len(in_channels) == 3, 'Length of input channels must be 3!'

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # First fusion: in_channels[2] is the low resolution input,
        # in_channels[1] the high resolution one.
        self.cff_24 = CascadeFeatureFusion(
            self.in_channels[2],
            self.in_channels[1],
            self.out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        # Second fusion: the output of `cff_24` (out_channels) is the low
        # resolution input, in_channels[0] the high resolution one.
        self.cff_12 = CascadeFeatureFusion(
            self.out_channels,
            self.in_channels[0],
            self.out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)

    def forward(self, inputs):
        """Run both cascade feature fusion stages.

        Args:
            inputs (Sequence[Tensor]): Exactly three feature maps, unpacked
                as ``(x_sub1, x_sub2, x_sub4)``.

        Returns:
            tuple[Tensor, Tensor, Tensor]: ``(x_24, x_12, x_cff_12)`` where
            ``x_24`` and ``x_12`` are the low resolution branch outputs of
            the first and second fusion unit, and ``x_cff_12`` is the final
            fused feature map.
        """
        assert len(inputs) == 3, 'Length of input feature maps must be 3!'

        x_sub1, x_sub2, x_sub4 = inputs
        x_cff_24, x_24 = self.cff_24(x_sub4, x_sub2)
        x_cff_12, x_12 = self.cff_12(x_cff_24, x_sub1)
        # Note: `x_cff_12` is used for decode_head,
        # `x_24` and `x_12` are used for auxiliary head.
        return x_24, x_12, x_cff_12