# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_norm_layer

from mmseg.registry import MODELS


@MODELS.register_module()
class Feature2Pyramid(nn.Module):
    """Feature2Pyramid.

    A neck structure that connects a ViT backbone and decoder heads by
    resampling each input feature map with its own scale factor.

    Args:
        embed_dim (int): Embedding dimension, used as both the input and
            the output channel count of the transposed convolutions.
        rescales (list[float]): Sampling multiple applied to each input
            level, in order. Supported values are 4, 2, 1, 0.5 and 0.25.
            Default: [4, 2, 1, 0.5].
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN', requires_grad=True).

    Raises:
        KeyError: If ``rescales`` contains an unsupported value.
    """

    # Scale factor -> name of the submodule attribute implementing it.
    # The attribute names are kept stable so that existing checkpoints
    # (state_dict keys) continue to load.
    _RESCALE_TO_OP = {
        4: 'upsample_4x',
        2: 'upsample_2x',
        1: 'identity',
        0.5: 'downsample_2x',
        0.25: 'downsample_4x',
    }

    def __init__(self,
                 embed_dim,
                 rescales=[4, 2, 1, 0.5],
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        # NOTE: the default arguments are shared objects; they are only
        # read here, never mutated.
        self.rescales = rescales
        # Kept as an explicit attribute (None when 4x is not requested)
        # for backward compatibility with code that inspects it.
        self.upsample_4x = None
        for k in self.rescales:
            if k == 4:
                # Two stacked stride-2 transposed convolutions give 4x.
                self.upsample_4x = nn.Sequential(
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2),
                    build_norm_layer(norm_cfg, embed_dim)[1],
                    nn.GELU(),
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2),
                )
            elif k == 2:
                self.upsample_2x = nn.Sequential(
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2))
            elif k == 1:
                self.identity = nn.Identity()
            elif k == 0.5:
                self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2)
            elif k == 0.25:
                self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4)
            else:
                raise KeyError(f'invalid {k} for feature2pyramid')

    def forward(self, inputs):
        """Resample every input level by its configured scale factor.

        Args:
            inputs (Sequence[Tensor]): One feature map per entry of
                ``rescales``, in the same order. Presumably 4D
                (N, C, H, W) tensors with C == embed_dim -- confirm
                against the backbone.

        Returns:
            tuple[Tensor]: The resampled feature maps, in input order.
        """
        assert len(inputs) == len(self.rescales)
        # Resolve the op for each level from its own rescale value instead
        # of assuming one of two fixed layouts; this keeps the standard
        # configurations unchanged while making any validated subset or
        # ordering of scales work.
        ops = [getattr(self, self._RESCALE_TO_OP[k]) for k in self.rescales]
        return tuple(op(feat) for op, feat in zip(ops, inputs))