Spaces:
Runtime error
Runtime error
File size: 16,815 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 |
# Copyright (c) OpenMMLab. All rights reserved.
# Originally from https://github.com/visual-attention-network/segnext
# Licensed under the Apache License, Version 2.0 (the "License")
import math
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule
from mmengine.model.weight_init import (constant_init, normal_init,
trunc_normal_init)
from mmseg.registry import MODELS
class Mlp(BaseModule):
    """Convolutional feed-forward (MLP) module.

    The input is passed through a 1x1 convolution, a 3x3 depth-wise
    convolution, an activation, and a second 1x1 convolution. The same
    dropout layer is applied after the activation and after the final
    projection.

    Args:
        in_features (int): The dimension of input features.
        hidden_features (int, optional): The dimension of hidden features.
            Falls back to ``in_features`` when falsy. Defaults: None.
        out_features (int, optional): The dimension of output features.
            Falls back to ``in_features`` when falsy. Defaults: None.
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
        drop (float): The dropout rate used in the MLP block.
            Defaults: 0.0.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_cfg=dict(type='GELU'),
                 drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        # Point-wise expansion to the hidden width.
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
        # Depth-wise 3x3 convolution (one group per hidden channel).
        self.dwconv = nn.Conv2d(
            in_channels=hidden_features,
            out_channels=hidden_features,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features,
            bias=True)
        self.act = build_activation_layer(act_cfg)
        # Point-wise projection to the output width.
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Forward function."""
        hidden = self.act(self.dwconv(self.fc1(x)))
        hidden = self.drop(hidden)
        return self.drop(self.fc2(hidden))
class StemConv(BaseModule):
    """Stem Block at the beginning of Semantic Branch.

    Two 3x3 convolutions with stride 2 are stacked; the first is followed by
    a norm layer and an activation, the second by a norm layer only. The
    result is flattened into a token sequence.

    Args:
        in_channels (int): The dimension of input channels.
        out_channels (int): The dimension of output channels. The first
            convolution produces ``out_channels // 2`` channels.
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Defaults: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        mid_channels = out_channels // 2
        # Both convolutions share the same 3x3 / stride-2 / pad-1 geometry.
        conv_kwargs = dict(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        layers = [
            nn.Conv2d(in_channels, mid_channels, **conv_kwargs),
            build_norm_layer(norm_cfg, mid_channels)[1],
            build_activation_layer(act_cfg),
            nn.Conv2d(mid_channels, out_channels, **conv_kwargs),
            build_norm_layer(norm_cfg, out_channels)[1],
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        """Forward function.

        Returns:
            tuple: The flattened features with the spatial axes merged and
            moved in front of the channel axis, followed by the height and
            width of the feature map before flattening.
        """
        feat = self.proj(x)
        # Unpacking two values keeps the requirement of a 4-D feature map.
        height, width = feat.shape[2:]
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, height, width
class MSCAAttention(BaseModule):
    """Attention Module in Multi-Scale Convolutional Attention Module (MSCA).

    A depth-wise convolution is followed by three parallel branches, each a
    pair of depth-wise strip convolutions whose second kernel is the
    reversed (transposed) form of the first. The branch outputs are summed
    with the base features, mixed by a 1x1 convolution, and used to
    re-weight the input element-wise.

    Args:
        channels (int): The dimension of channels.
        kernel_sizes (list): The size of attention
            kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        paddings (list): The number of
            corresponding padding value in attention module.
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
    """

    def __init__(self,
                 channels,
                 kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 paddings=[2, [0, 3], [0, 5], [0, 10]]):
        super().__init__()
        self.conv0 = nn.Conv2d(
            channels,
            channels,
            kernel_size=kernel_sizes[0],
            padding=paddings[0],
            groups=channels)

        branch_specs = zip(kernel_sizes[1:], paddings[1:])
        for idx, (strip_kernel, strip_pad) in enumerate(branch_specs):
            # Each branch registers ``conv{idx}_1`` with the given
            # orientation and ``conv{idx}_2`` with the reversed one.
            strip_convs = (
                (f'conv{idx}_1', strip_kernel, strip_pad),
                (f'conv{idx}_2', strip_kernel[::-1], strip_pad[::-1]),
            )
            for name, kernel, pad in strip_convs:
                strip_conv = nn.Conv2d(
                    channels,
                    channels,
                    tuple(kernel),
                    padding=pad,
                    groups=channels)
                self.add_module(name, strip_conv)

        self.conv3 = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """Forward function."""
        identity = x.clone()
        base = self.conv0(x)

        # Multi-Scale Feature extraction
        branch_0 = self.conv0_2(self.conv0_1(base))
        branch_1 = self.conv1_2(self.conv1_1(base))
        branch_2 = self.conv2_2(self.conv2_1(base))
        attn = base + branch_0 + branch_1 + branch_2

        # Channel Mixing
        attn = self.conv3(attn)

        # Convolutional Attention
        return attn * identity
class MSCASpatialAttention(BaseModule):
    """Spatial Attention Module in Multi-Scale Convolutional Attention Module
    (MSCA).

    The input goes through a 1x1 projection, an activation, the
    :class:`MSCAAttention` gating unit and a second 1x1 projection; the
    original input is then added back as a residual.

    Args:
        in_channels (int): The dimension of channels.
        attention_kernel_sizes (list): The size of attention
            kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): The number of
            corresponding padding value in attention module.
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.proj_1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.activation = build_activation_layer(act_cfg)
        self.spatial_gating_unit = MSCAAttention(
            in_channels,
            kernel_sizes=attention_kernel_sizes,
            paddings=attention_kernel_paddings)
        self.proj_2 = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Forward function."""
        residual = x.clone()
        gated = self.spatial_gating_unit(self.activation(self.proj_1(x)))
        return self.proj_2(gated) + residual
class MSCABlock(BaseModule):
    """Basic Multi-Scale Convolutional Attention Block.

    It leverages the large-kernel attention (LKA) mechanism to build both
    channel and spatial attention. In each branch, it uses two depth-wise
    strip convolutions to approximate standard depth-wise convolutions with
    large kernels. The kernel size for each branch is set to 7, 11, and 21,
    respectively.

    The block holds two residual sub-layers (attention, then MLP). Each is
    preceded by a norm layer and scaled by a learnable per-channel factor
    initialised to 1e-2 before the optional drop-path.

    Args:
        channels (int): The dimension of channels.
        attention_kernel_sizes (list): The size of attention
            kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): The number of
            corresponding padding value in attention module.
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
        mlp_ratio (float): The ratio of multiple input dimension to
            calculate hidden feature in MLP layer. Defaults: 4.0.
        drop (float): The dropout rate used in the MLP block.
            Defaults: 0.0.
        drop_path (float): The ratio of drop paths.
            Defaults: 0.0.
        act_cfg (dict): Config dict for activation layer in block.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Defaults: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 channels,
                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        # Attention sub-layer.
        self.norm1 = build_norm_layer(norm_cfg, channels)[1]
        self.attn = MSCASpatialAttention(channels, attention_kernel_sizes,
                                         attention_kernel_paddings, act_cfg)
        # Stochastic depth is only instantiated for a positive rate.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # MLP sub-layer.
        self.norm2 = build_norm_layer(norm_cfg, channels)[1]
        self.mlp = Mlp(
            in_features=channels,
            hidden_features=int(channels * mlp_ratio),
            act_cfg=act_cfg,
            drop=drop)

        # Learnable per-channel scales for the two residual branches.
        init_scale = 1e-2
        self.layer_scale_1 = nn.Parameter(
            init_scale * torch.ones(channels), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            init_scale * torch.ones(channels), requires_grad=True)

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x (Tensor): Token sequence of shape (B, N, C).
            H (int): Height used to restore the 2-D layout.
            W (int): Width used to restore the 2-D layout.

        Returns:
            Tensor: Output token sequence of shape (B, N, C).
        """
        batch, num_tokens, dim = x.shape
        # Tokens -> (B, C, H, W) feature map for the convolutional layers.
        feat = x.permute(0, 2, 1).view(batch, dim, H, W)

        scale_attn = self.layer_scale_1[:, None, None]
        feat = feat + self.drop_path(scale_attn * self.attn(self.norm1(feat)))

        scale_mlp = self.layer_scale_2[:, None, None]
        feat = feat + self.drop_path(scale_mlp * self.mlp(self.norm2(feat)))

        # Feature map -> tokens.
        return feat.view(batch, dim, num_tokens).permute(0, 2, 1)
class OverlapPatchEmbed(BaseModule):
    """Image to Patch Embedding.

    A single convolution (padding ``patch_size // 2``) followed by a norm
    layer; the output is flattened into a token sequence.

    Args:
        patch_size (int): The patch size, used as the convolution kernel
            size. Defaults: 7.
        stride (int): Stride of the convolutional layer.
            Default: 4.
        in_channels (int): The number of input channels.
            Defaults: 3.
        embed_dim (int): The dimensions of embedding.
            Defaults: 768.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults: dict(type='SyncBN', requires_grad=True).
    """

    def __init__(self,
                 patch_size=7,
                 stride=4,
                 in_channels=3,
                 embed_dim=768,
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        half_patch = patch_size // 2
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=half_patch)
        self.norm = build_norm_layer(norm_cfg, embed_dim)[1]

    def forward(self, x):
        """Forward function.

        Returns:
            tuple: The normalised features with the spatial axes merged and
            moved in front of the channel axis, followed by the height and
            width of the feature map before flattening.
        """
        feat = self.proj(x)
        # Unpacking two values keeps the requirement of a 4-D feature map.
        height, width = feat.shape[2:]
        tokens = self.norm(feat).flatten(2).transpose(1, 2)
        return tokens, height, width
@MODELS.register_module()
class MSCAN(BaseModule):
    """SegNeXt Multi-Scale Convolutional Attention Network (MSCAN) backbone.

    This backbone is the implementation of `SegNeXt: Rethinking
    Convolutional Attention Design for Semantic
    Segmentation <https://arxiv.org/abs/2209.08575>`_.
    Inspiration from https://github.com/visual-attention-network/segnext.

    Args:
        in_channels (int): The number of input channels. Defaults: 3.
        embed_dims (list[int]): Embedding dimension of each stage.
            Defaults: [64, 128, 256, 512].
        mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim.
            Defaults: [4, 4, 4, 4].
        drop_rate (float): Dropout rate. Defaults: 0.
        drop_path_rate (float): Stochastic depth rate. Defaults: 0.
        depths (list[int]): Number of MSCA blocks in each stage.
            Default: [3, 4, 6, 3].
        num_stages (int): MSCAN stages. Default: 4.
        attention_kernel_sizes (list): Size of attention kernel in
            Attention Module (Figure 2(b) of original paper).
            Defaults: [5, [1, 7], [1, 11], [1, 21]].
        attention_kernel_paddings (list): Size of attention paddings
            in Attention Module (Figure 2(b) of original paper).
            Defaults: [2, [0, 3], [0, 5], [0, 10]].
        act_cfg (dict): Config dict for the activation layers of the
            MSCA blocks. Defaults: dict(type='GELU').
        norm_cfg (dict): Config of norm layers.
            Defaults: dict(type='SyncBN', requires_grad=True).
        pretrained (str, optional): model pretrained path.
            Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 embed_dims=[64, 128, 256, 512],
                 mlp_ratios=[4, 4, 4, 4],
                 drop_rate=0.,
                 drop_path_rate=0.,
                 depths=[3, 4, 6, 3],
                 num_stages=4,
                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 pretrained=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')

        self.depths = depths
        self.num_stages = num_stages

        # Stochastic depth decay rule: the drop-path rate grows linearly
        # from 0 to ``drop_path_rate`` over all blocks of all stages.
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]
        cur = 0

        for i in range(num_stages):
            if i == 0:
                # The stem consumes the raw input, so it must honour
                # ``in_channels`` rather than assume 3-channel images.
                # NOTE(review): ``act_cfg`` is not forwarded to the stem,
                # which therefore keeps its own default activation --
                # confirm this is intended.
                patch_embed = StemConv(
                    in_channels, embed_dims[0], norm_cfg=norm_cfg)
            else:
                # Later stages downsample the previous stage's features
                # with a 3x3, stride-2 patch embedding.
                patch_embed = OverlapPatchEmbed(
                    patch_size=3,
                    stride=2,
                    in_channels=embed_dims[i - 1],
                    embed_dim=embed_dims[i],
                    norm_cfg=norm_cfg)

            block = nn.ModuleList([
                MSCABlock(
                    channels=embed_dims[i],
                    attention_kernel_sizes=attention_kernel_sizes,
                    attention_kernel_paddings=attention_kernel_paddings,
                    mlp_ratio=mlp_ratios[i],
                    drop=drop_rate,
                    drop_path=dpr[cur + j],
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg) for j in range(depths[i])
            ])
            norm = nn.LayerNorm(embed_dims[i])
            cur += depths[i]

            # Stage modules are stored under 1-based attribute names.
            setattr(self, f'patch_embed{i + 1}', patch_embed)
            setattr(self, f'block{i + 1}', block)
            setattr(self, f'norm{i + 1}', norm)

    def init_weights(self):
        """Initialize modules of MSCAN.

        Without an ``init_cfg`` the layers are initialised in place:
        truncated normal for ``nn.Linear``, constants for ``nn.LayerNorm``
        and a fan-out based normal distribution for ``nn.Conv2d``.
        Otherwise initialisation is delegated to the parent class.
        """
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    # fan_out = kernel area * output channels per group.
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
        else:
            super().init_weights()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input batch; its first dimension is the batch size.

        Returns:
            list[Tensor]: One feature map per stage, each laid out as
            (B, C, H, W).
        """
        B = x.shape[0]
        outs = []

        for i in range(self.num_stages):
            patch_embed = getattr(self, f'patch_embed{i + 1}')
            block = getattr(self, f'block{i + 1}')
            norm = getattr(self, f'norm{i + 1}')

            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            # Tokens (B, H*W, C) -> feature map (B, C, H, W) for the next
            # stage and for the output list.
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs
|