|
from transformers import PreTrainedModel |
|
import torch |
|
import torch.nn as nn |
|
from segformer_plusplus.utils import resize |
|
from segformer_plusplus.model.backbone.mit import MixVisionTransformer |
|
from mix_vision_transformer_config import MySegformerConfig |
|
|
|
|
|
class SegformerHead(nn.Module):
    """All-MLP SegFormer decode head.

    Each multi-scale feature map is projected to ``channels`` with a 1x1
    ConvModule, resized to the spatial size of the first input, fused by
    concatenation plus a 1x1 ConvModule, and classified with a final 1x1
    convolution.

    Args:
        in_channels (sequence[int]): Channel count of each input feature map.
        in_index (sequence[int]): Indices of the feature maps to use. Must
            match ``in_channels`` in length. NOTE(review): stored but never
            consulted in ``forward`` — all inputs are consumed in order;
            confirm callers pass exactly the selected maps.
        channels (int): Common embedding width after per-scale projection.
        dropout_ratio (float): Dropout probability applied before the
            classifier; values <= 0 disable dropout entirely.
        out_channels (int): Number of output classes (logit channels).
        norm_cfg (dict | None): Normalization config forwarded to ConvModule.
        align_corners (bool): ``align_corners`` argument passed to ``resize``.
        interpolate_mode (str): Interpolation mode passed to ``resize``.
    """

    def __init__(self,
                 in_channels=(64, 128, 256, 512),  # tuples: no mutable default args
                 in_index=(0, 1, 2, 3),
                 channels=256,
                 dropout_ratio=0.1,
                 out_channels=19,
                 norm_cfg=None,
                 align_corners=False,
                 interpolate_mode='bilinear'):
        super().__init__()
        self.in_channels = in_channels
        self.in_index = in_index
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.align_corners = align_corners
        self.interpolate_mode = interpolate_mode

        self.act_cfg = dict(type='ReLU')
        # Final 1x1 classifier producing per-pixel class logits.
        self.conv_seg = nn.Conv2d(channels, out_channels, kernel_size=1)
        self.dropout = nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else None

        num_inputs = len(in_channels)
        assert num_inputs == len(in_index), (
            'in_channels and in_index must have the same length')

        # Imported at call time rather than module top — presumably to avoid
        # a circular import within segformer_plusplus; confirm before moving.
        from segformer_plusplus.utils.activation import ConvModule

        # One 1x1 projection per input scale, mapping each to `channels`.
        self.convs = nn.ModuleList(
            ConvModule(
                in_channels=in_channels[i],
                out_channels=channels,
                kernel_size=1,
                stride=1,
                norm_cfg=norm_cfg,
                act_cfg=self.act_cfg)
            for i in range(num_inputs))

        # Fuses the concatenated per-scale projections back down to `channels`.
        self.fusion_conv = ConvModule(
            in_channels=channels * num_inputs,
            out_channels=channels,
            kernel_size=1,
            norm_cfg=norm_cfg)

    def cls_seg(self, feat):
        """Apply optional dropout, then the 1x1 classifier convolution."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        return self.conv_seg(feat)

    def forward(self, inputs):
        """Fuse multi-scale features and return per-pixel class logits.

        Args:
            inputs (sequence[Tensor]): One feature map per scale; the first
                entry defines the output spatial resolution.

        Returns:
            Tensor: Logits of shape (N, out_channels, H0, W0), where
            (H0, W0) is the spatial size of ``inputs[0]``.
        """
        target_size = inputs[0].shape[2:]
        # Project every scale to `channels`, then upsample to the common size.
        outs = [
            resize(
                input=conv(x),
                size=target_size,
                mode=self.interpolate_mode,
                align_corners=self.align_corners)
            for x, conv in zip(inputs, self.convs)
        ]
        out = self.fusion_conv(torch.cat(outs, dim=1))
        return self.cls_seg(out)
|
|
|
|
|
class MySegformerForSemanticSegmentation(PreTrainedModel):
    """SegFormer-style semantic segmentation model in transformers clothing.

    Wires a MixVisionTransformer backbone to a SegformerHead decoder and
    exposes the pair through the HuggingFace ``PreTrainedModel`` interface
    (config-driven construction, ``post_init`` weight setup).
    """

    config_class = MySegformerConfig
    base_model_prefix = "my_segformer"

    def __init__(self, config):
        """Build backbone and decode head from ``config``.

        Args:
            config (MySegformerConfig): Hyper-parameters for the backbone;
                may optionally define ``num_classes`` (defaults to 19,
                the Cityscapes class count, when absent).
        """
        super().__init__(config)

        self.backbone = MixVisionTransformer(
            embed_dims=config.embed_dims,
            num_stages=config.num_stages,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            patch_sizes=config.patch_sizes,
            strides=config.strides,
            sr_ratios=config.sr_ratios,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            drop_path_rate=config.drop_path_rate,
            out_indices=config.out_indices
        )

        # NOTE(review): in_channels is hard-coded and must match the channel
        # widths this backbone configuration actually emits — verify whenever
        # embed_dims/num_heads change in the config.
        self.segmentation_head = SegformerHead(
            in_channels=[64, 128, 256, 512],
            # getattr-with-default replaces the hasattr ternary (same behavior).
            out_channels=getattr(config, 'num_classes', 19),
            dropout_ratio=0.1,
            align_corners=False
        )

        # transformers hook: runs weight initialization / final setup.
        self.post_init()

    def forward(self, x):
        """Run the backbone, then decode its features into class logits.

        Args:
            x (Tensor): Batch of input images, shape (N, C, H, W).

        Returns:
            Tensor: Per-pixel class logits at the decode head's resolution.
        """
        features = self.backbone(x)
        return self.segmentation_head(features)
|
|