Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch.nn as nn | |
from mmcv.cnn import build_norm_layer | |
from mmseg.registry import MODELS | |
# NOTE(review): ``MODELS`` is imported in this file but never used; this neck
# is presumably meant to carry ``@MODELS.register_module()`` so configs can
# build it by type name. Confirm before adding it -- registering a name that
# is already registered elsewhere would fail at import time.
class Feature2Pyramid(nn.Module):
    """Feature2Pyramid.

    A neck structure that connects a ViT backbone to decoder heads by
    resampling each backbone feature map by its own multiple, producing a
    feature pyramid. Every op keeps the channel count at ``embed_dim``.

    Args:
        embed_dim (int): Embedding dimension (channel count) of the input
            feature maps, and therefore of the outputs.
        rescales (Sequence[float]): Sampling multiple applied to the input at
            the same index. Each value must be one of 4, 2, 1, 0.5 or 0.25;
            any order and any subset are accepted. Default: (4, 2, 1, 0.5).
        norm_cfg (dict | None): Config dict passed to ``build_norm_layer`` for
            the normalization layer inside the 4x upsampling branch (only used
            when 4 is in ``rescales``). ``None`` means
            dict(type='SyncBN', requires_grad=True). Default: None.

    Raises:
        KeyError: If ``rescales`` contains an unsupported multiple.
    """

    # Supported multiples and the attribute that holds each op. The attribute
    # names are also the state_dict key prefixes, so they must stay stable for
    # existing checkpoints to load.
    _OP_NAMES = ((4, 'upsample_4x'), (2, 'upsample_2x'), (1, 'identity'),
                 (0.5, 'downsample_2x'), (0.25, 'downsample_4x'))

    def __init__(self,
                 embed_dim,
                 rescales=(4, 2, 1, 0.5),
                 norm_cfg=None):
        super().__init__()
        if norm_cfg is None:
            norm_cfg = dict(type='SyncBN', requires_grad=True)
        # Copy so that later mutation of the caller's sequence cannot
        # desynchronise ``rescales`` from the ops built below.
        self.rescales = list(rescales)
        # Kept for backward compatibility: external code may compare this
        # attribute against ``None``.
        self.upsample_4x = None
        for k in self.rescales:
            name = self._op_name(k)  # raises KeyError for unsupported k
            setattr(self, name, self._make_op(name, embed_dim, norm_cfg))

    @classmethod
    def _op_name(cls, rescale):
        """Return the attribute name of the op implementing ``rescale``."""
        for value, name in cls._OP_NAMES:
            if rescale == value:
                return name
        raise KeyError(f'invalid {rescale} for feature2pyramid')

    @staticmethod
    def _make_op(name, embed_dim, norm_cfg):
        """Build the resampling module stored under attribute ``name``."""
        if name == 'upsample_4x':
            # Two stride-2 transposed convs, with norm + GELU in between.
            return nn.Sequential(
                nn.ConvTranspose2d(
                    embed_dim, embed_dim, kernel_size=2, stride=2),
                build_norm_layer(norm_cfg, embed_dim)[1],
                nn.GELU(),
                nn.ConvTranspose2d(
                    embed_dim, embed_dim, kernel_size=2, stride=2),
            )
        if name == 'upsample_2x':
            return nn.Sequential(
                nn.ConvTranspose2d(
                    embed_dim, embed_dim, kernel_size=2, stride=2))
        if name == 'identity':
            return nn.Identity()
        if name == 'downsample_2x':
            return nn.MaxPool2d(kernel_size=2, stride=2)
        return nn.MaxPool2d(kernel_size=4, stride=4)

    def forward(self, inputs):
        """Resample ``inputs[i]`` by ``rescales[i]``.

        Args:
            inputs (Sequence[Tensor]): One feature map per entry of
                ``rescales``. Presumably (N, embed_dim, H, W) -- TODO confirm
                against the backbone.

        Returns:
            tuple[Tensor]: The resampled feature maps, in input order.
        """
        assert len(inputs) == len(self.rescales), (
            f'expected {len(self.rescales)} inputs, got {len(inputs)}')
        # Look ops up by each input's own multiple instead of assuming one of
        # two fixed orderings.
        ops = [getattr(self, self._op_name(k)) for k in self.rescales]
        return tuple(op(feat) for op, feat in zip(ops, inputs))