# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmseg.registry import MODELS


@MODELS.register_module()
class Feature2Pyramid(nn.Module):
    """Feature2Pyramid.

    A neck that connects a ViT backbone to the decode heads by rescaling
    the backbone features into a feature pyramid.

    Args:
        embed_dim (int): Embedding dimension of the backbone features.
        rescales (list[float]): Rescaling factors applied to the input
            features to build the pyramid. Default: [4, 2, 1, 0.5].
        norm_cfg (dict): Config dict for the normalization layer.
            Default: dict(type='SyncBN', requires_grad=True).
    """
    def __init__(self,
                 embed_dim,
                 rescales=[4, 2, 1, 0.5],
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super().__init__()
        self.rescales = rescales
        self.upsample_4x = None
        for k in self.rescales:
            if k == 4:
                self.upsample_4x = nn.Sequential(
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2),
                    build_norm_layer(norm_cfg, embed_dim)[1],
                    nn.GELU(),
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2),
                )
            elif k == 2:
                self.upsample_2x = nn.Sequential(
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2))
            elif k == 1:
                self.identity = nn.Identity()
            elif k == 0.5:
                self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2)
            elif k == 0.25:
                self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4)
            else:
                raise KeyError(f'invalid {k} for feature2pyramid')
    def forward(self, inputs):
        # Inputs must be ordered consistently with ``self.rescales``.
        assert len(inputs) == len(self.rescales)
        outputs = []
        if self.upsample_4x is not None:
            ops = [
                self.upsample_4x, self.upsample_2x, self.identity,
                self.downsample_2x
            ]
        else:
            ops = [
                self.upsample_2x, self.identity, self.downsample_2x,
                self.downsample_4x
            ]
        for i in range(len(inputs)):
            outputs.append(ops[i](inputs[i]))
        return tuple(outputs)
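

# The sketch below is not part of the upstream module; it is a minimal,
# hypothetical usage example showing how the neck rescales four ViT
# feature maps (assumed shape 2x768x32x32) into a stride-4/8/16/32
# pyramid. 'BN' is used here instead of 'SyncBN' so the sketch runs on
# a single, non-distributed device.
if __name__ == '__main__':
    import torch

    neck = Feature2Pyramid(
        embed_dim=768,
        rescales=[4, 2, 1, 0.5],
        norm_cfg=dict(type='BN', requires_grad=True))
    feats = [torch.rand(2, 768, 32, 32) for _ in range(4)]
    outs = neck(feats)
    for out in outs:
        # Expected spatial sizes: 128, 64, 32 and 16.
        print(out.shape)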