# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

from mmseg.registry import MODELS
from ..utils import resize


class SpatialPath(BaseModule):
    """Spatial Path to preserve the spatial size of the original input image
    and encode rich spatial information.

    Args:
        in_channels (int): The number of channels of the input
            image. Default: 3.
        num_channels (Tuple[int]): The number of channels of
            each layer in Spatial Path.
            Default: (64, 64, 64, 128).
    Returns:
        x (torch.Tensor): Feature map for Feature Fusion Module.
    """

    def __init__(self,
                 in_channels=3,
                 num_channels=(64, 64, 64, 128),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert len(num_channels) == 4, 'Length of input channels \
            of Spatial Path must be 4!'

        self.layers = []
        for i in range(len(num_channels)):
            layer_name = f'layer{i + 1}'
            self.layers.append(layer_name)
            if i == 0:
                # First layer: 7x7 conv with stride 2 halves the resolution.
                self.add_module(
                    layer_name,
                    ConvModule(
                        in_channels=in_channels,
                        out_channels=num_channels[i],
                        kernel_size=7,
                        stride=2,
                        padding=3,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
            elif i == len(num_channels) - 1:
                # Last layer: 1x1 conv adjusts channels without downsampling.
                self.add_module(
                    layer_name,
                    ConvModule(
                        in_channels=num_channels[i - 1],
                        out_channels=num_channels[i],
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
            else:
                # Middle layers: 3x3 convs with stride 2 keep downsampling.
                self.add_module(
                    layer_name,
                    ConvModule(
                        in_channels=num_channels[i - 1],
                        out_channels=num_channels[i],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))

    def forward(self, x):
        for layer_name in self.layers:
            layer_stage = getattr(self, layer_name)
            x = layer_stage(x)
        return x
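

# Illustrative shape check for SpatialPath (a minimal sketch, not part of the
# module API; the 512x1024 input size below is an arbitrary assumption). With
# three stride-2 convs the output is 1/8 of the input resolution:
#
#   >>> sp = SpatialPath(in_channels=3, num_channels=(64, 64, 64, 128))
#   >>> sp(torch.rand(1, 3, 512, 1024)).shape
#   torch.Size([1, 128, 64, 128])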


class AttentionRefinementModule(BaseModule):
    """Attention Refinement Module (ARM) to refine the features of each stage.

    Args:
        in_channels (int): The number of input channels.
        out_channel (int): The number of output channels.
    Returns:
        x_out (torch.Tensor): Feature map of Attention Refinement Module.
    """

    def __init__(self,
                 in_channels,
                 out_channel,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv_layer = ConvModule(
            in_channels=in_channels,
            out_channels=out_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Channel attention: global average pooling, a 1x1 conv and a
        # sigmoid gate.
        self.atten_conv_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                in_channels=out_channel,
                out_channels=out_channel,
                kernel_size=1,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None), nn.Sigmoid())

    def forward(self, x):
        x = self.conv_layer(x)
        x_atten = self.atten_conv_layer(x)
        x_out = x * x_atten
        return x_out
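

# Illustrative sketch of how Context Path uses this module (hypothetical
# shapes, assuming a 512-channel 1/32 backbone feature refined to 128
# channels; the spatial size is unchanged):
#
#   >>> arm = AttentionRefinementModule(512, 128)
#   >>> arm(torch.rand(1, 512, 16, 32)).shape
#   torch.Size([1, 128, 16, 32])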


class ContextPath(BaseModule):
    """Context Path to provide a sufficient receptive field.

    Args:
        backbone_cfg (dict): Config of the backbone of Context Path.
        context_channels (Tuple[int]): The number of channels of
            various modules in Context Path.
            Default: (128, 256, 512).
        align_corners (bool, optional): The align_corners argument of
            resize operation. Default: False.
    Returns:
        x_16_up, x_32_up (torch.Tensor, torch.Tensor): Two feature maps
            upsampled from the 1/16 and 1/32 downsampled feature maps.
            These two feature maps are used for Feature Fusion Module
            and Auxiliary Head.
    """

    def __init__(self,
                 backbone_cfg,
                 context_channels=(128, 256, 512),
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert len(context_channels) == 3, 'Length of input channels \
            of Context Path must be 3!'

        self.backbone = MODELS.build(backbone_cfg)
        self.align_corners = align_corners
        self.arm16 = AttentionRefinementModule(context_channels[1],
                                               context_channels[0])
        self.arm32 = AttentionRefinementModule(context_channels[2],
                                               context_channels[0])
        self.conv_head32 = ConvModule(
            in_channels=context_channels[0],
            out_channels=context_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.conv_head16 = ConvModule(
            in_channels=context_channels[0],
            out_channels=context_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # Global context: average-pool the deepest feature to 1x1 and
        # compress it to context_channels[0] channels.
        self.gap_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                in_channels=context_channels[2],
                out_channels=context_channels[0],
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):
        x_4, x_8, x_16, x_32 = self.backbone(x)
        x_gap = self.gap_conv(x_32)

        x_32_arm = self.arm32(x_32)
        x_32_sum = x_32_arm + x_gap
        x_32_up = resize(input=x_32_sum, size=x_16.shape[2:], mode='nearest')
        x_32_up = self.conv_head32(x_32_up)

        x_16_arm = self.arm16(x_16)
        x_16_sum = x_16_arm + x_32_up
        x_16_up = resize(input=x_16_sum, size=x_8.shape[2:], mode='nearest')
        x_16_up = self.conv_head16(x_16_up)

        return x_16_up, x_32_up
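

# A minimal sketch of building the context path alone (illustrative; the
# ResNet-18 config below is an assumption and must match context_channels,
# i.e. the last two backbone stages must output 256 and 512 channels):
#
#   >>> backbone_cfg = dict(
#   ...     type='ResNet', depth=18, num_stages=4, out_indices=(0, 1, 2, 3))
#   >>> cp = ContextPath(backbone_cfg, context_channels=(128, 256, 512))
#   >>> x_16_up, x_32_up = cp(torch.rand(1, 3, 512, 1024))
#   >>> x_16_up.shape, x_32_up.shape  # 1/8 and 1/16 of the input
#   (torch.Size([1, 128, 64, 128]), torch.Size([1, 128, 32, 64]))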


class FeatureFusionModule(BaseModule):
    """Feature Fusion Module to fuse the low-level output feature of Spatial
    Path and the high-level output feature of Context Path.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
    Returns:
        x_out (torch.Tensor): Feature map of Feature Fusion Module.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.conv_atten = nn.Sequential(
            ConvModule(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg), nn.Sigmoid())

    def forward(self, x_sp, x_cp):
        x_concat = torch.cat([x_sp, x_cp], dim=1)
        x_fuse = self.conv1(x_concat)

        x_atten = self.gap(x_fuse)
        # Note: the paper uses no BN and one more 1x1 conv here.
        x_atten = self.conv_atten(x_atten)

        x_atten = x_fuse * x_atten
        x_out = x_atten + x_fuse
        return x_out
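

# Illustrative fusion sketch with the default BiSeNetV1 channel layout
# (assumed shapes: a 128-channel Spatial Path feature and a 128-channel
# Context Path feature, both at 1/8 resolution):
#
#   >>> ffm = FeatureFusionModule(in_channels=256, out_channels=256)
#   >>> x_sp = torch.rand(1, 128, 64, 128)  # Spatial Path output
#   >>> x_cp = torch.rand(1, 128, 64, 128)  # Context Path 1/8 output
#   >>> ffm(x_sp, x_cp).shape
#   torch.Size([1, 256, 64, 128])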


@MODELS.register_module()
class BiSeNetV1(BaseModule):
    """BiSeNetV1 backbone.

    This backbone is the implementation of `BiSeNet: Bilateral
    Segmentation Network for Real-time Semantic
    Segmentation <https://arxiv.org/abs/1808.00897>`_.

    Args:
        backbone_cfg (dict): Config of the backbone of Context Path.
        in_channels (int): The number of channels of the input
            image. Default: 3.
        spatial_channels (Tuple[int]): The number of channels of
            various layers in Spatial Path.
            Default: (64, 64, 64, 128).
        context_channels (Tuple[int]): The number of channels of
            various modules in Context Path.
            Default: (128, 256, 512).
        out_indices (Tuple[int] | int, optional): Output from which stages.
            Default: (0, 1, 2).
        align_corners (bool, optional): The align_corners argument of
            resize operation in Context Path. Default: False.
        out_channels (int): The number of channels of output.
            It must be the same as `in_channels` of decode_head.
            Default: 256.
    """

    def __init__(self,
                 backbone_cfg,
                 in_channels=3,
                 spatial_channels=(64, 64, 64, 128),
                 context_channels=(128, 256, 512),
                 out_indices=(0, 1, 2),
                 align_corners=False,
                 out_channels=256,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert len(spatial_channels) == 4, 'Length of input channels \
            of Spatial Path must be 4!'
        assert len(context_channels) == 3, 'Length of input channels \
            of Context Path must be 3!'

        self.out_indices = out_indices
        self.align_corners = align_corners
        self.context_path = ContextPath(backbone_cfg, context_channels,
                                        self.align_corners)
        self.spatial_path = SpatialPath(in_channels, spatial_channels)
        self.ffm = FeatureFusionModule(context_channels[1], out_channels)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

    def forward(self, x):
        # stole refactoring code from Coin Cheung, thanks
        x_context8, x_context16 = self.context_path(x)
        x_spatial = self.spatial_path(x)

        x_fuse = self.ffm(x_spatial, x_context8)

        outs = [x_fuse, x_context8, x_context16]
        outs = [outs[i] for i in self.out_indices]
        return tuple(outs)
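

# A minimal end-to-end sketch (illustrative, not part of this module; the
# ResNet-18 backbone config and the 512x1024 input size are assumptions):
#
#   >>> backbone_cfg = dict(
#   ...     type='ResNet', depth=18, num_stages=4, out_indices=(0, 1, 2, 3))
#   >>> model = BiSeNetV1(backbone_cfg=backbone_cfg)
#   >>> outs = model(torch.rand(1, 3, 512, 1024))
#   >>> [o.shape for o in outs]
#   [torch.Size([1, 256, 64, 128]),   # FFM output, fed to the decode head
#    torch.Size([1, 128, 64, 128]),   # 1/8 context feature (auxiliary head)
#    torch.Size([1, 128, 32, 64])]    # 1/16 context feature (auxiliary head)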