# SegformerPlusPlus / modeling_my_segformer.py
# Author: Tim77777767
# Created for Hugging Face compatibility (commit 02508fb)
from transformers import PreTrainedModel
import torch
import torch.nn as nn
from segformer_plusplus.utils import resize
from segformer_plusplus.model.backbone.mit import MixVisionTransformer # deine Backbone-Importierung
from mix_vision_transformer_config import MySegformerConfig # deine Config-Importierung
# Head-Implementierung (etwas vereinfacht und angepasst)
class SegformerHead(nn.Module):
    """Lightweight all-MLP SegFormer decode head (simplified).

    Projects each multi-scale backbone feature map to a common channel
    width with a 1x1 ConvModule, resizes every projection to the spatial
    size of the first (highest-resolution) feature map, fuses them with a
    1x1 conv, and predicts per-pixel class logits.
    """

    def __init__(self,
                 in_channels=(64, 128, 256, 512),  # per-stage channels; must match the backbone output
                 in_index=(0, 1, 2, 3),
                 channels=256,
                 dropout_ratio=0.1,
                 out_channels=19,  # number of segmentation classes
                 norm_cfg=None,
                 align_corners=False,
                 interpolate_mode='bilinear'):
        """Build the per-stage projection convs and the fusion conv.

        Args:
            in_channels: channel count of each input feature map.
            in_index: indices of the backbone outputs to use (stored but
                not applied in ``forward``; all inputs are consumed).
            channels: common embedding width after projection.
            dropout_ratio: 2D dropout applied before the classifier
                (disabled when <= 0).
            out_channels: number of output classes.
            norm_cfg: normalization config forwarded to ConvModule.
            align_corners: passed through to ``resize``.
            interpolate_mode: interpolation mode passed to ``resize``.

        Raises:
            ValueError: if ``in_channels`` and ``in_index`` differ in length.
        """
        super().__init__()
        self.in_channels = in_channels
        self.in_index = in_index
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.align_corners = align_corners
        self.interpolate_mode = interpolate_mode
        self.act_cfg = dict(type='ReLU')

        # Final 1x1 classifier and optional pre-classifier dropout.
        self.conv_seg = nn.Conv2d(channels, out_channels, kernel_size=1)
        self.dropout = nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else None

        num_inputs = len(in_channels)
        if num_inputs != len(in_index):
            raise ValueError(
                'in_channels and in_index must have the same length, '
                f'got {num_inputs} and {len(in_index)}')

        # Imported lazily to mirror the project layout (ConvModule lives
        # in segformer_plusplus.utils.activation).
        from segformer_plusplus.utils.activation import ConvModule

        # One 1x1 projection per backbone stage.
        self.convs = nn.ModuleList(
            ConvModule(
                in_channels=in_channels[i],
                out_channels=channels,
                kernel_size=1,
                stride=1,
                norm_cfg=norm_cfg,
                act_cfg=self.act_cfg)
            for i in range(num_inputs))

        # Fuses the concatenated per-stage projections back to `channels`.
        self.fusion_conv = ConvModule(
            in_channels=channels * num_inputs,
            out_channels=channels,
            kernel_size=1,
            norm_cfg=norm_cfg)

    def cls_seg(self, feat):
        """Apply optional dropout, then the 1x1 classifier conv."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        return self.conv_seg(feat)

    def forward(self, inputs):
        """Fuse the multi-scale features and return class logits.

        Args:
            inputs: sequence of feature maps, one per backbone stage;
                ``inputs[0]`` defines the output spatial resolution.

        Returns:
            Logits tensor of shape (N, out_channels, H0, W0) where
            (H0, W0) is the spatial size of ``inputs[0]``.
        """
        target_size = inputs[0].shape[2:]
        outs = [
            resize(
                input=conv(x),
                size=target_size,
                mode=self.interpolate_mode,
                align_corners=self.align_corners)
            for x, conv in zip(inputs, self.convs)
        ]
        fused = self.fusion_conv(torch.cat(outs, dim=1))
        return self.cls_seg(fused)
class MySegformerForSemanticSegmentation(PreTrainedModel):
    """SegFormer-style semantic segmentation model wrapped for Hugging Face.

    Combines a MixVisionTransformer backbone with a SegformerHead decode
    head. ``forward`` returns raw per-pixel class logits (no loss
    computation, no post-resizing to the input resolution).
    """

    config_class = MySegformerConfig
    base_model_prefix = "my_segformer"

    def __init__(self, config):
        """Build backbone and decode head from *config*.

        Args:
            config: MySegformerConfig carrying the backbone hyperparameters
                and optionally ``num_classes`` (defaults to 19).
        """
        super().__init__(config)

        # Backbone is configured entirely from the HF config object.
        self.backbone = MixVisionTransformer(
            embed_dims=config.embed_dims,
            num_stages=config.num_stages,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            patch_sizes=config.patch_sizes,
            strides=config.strides,
            sr_ratios=config.sr_ratios,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            drop_path_rate=config.drop_path_rate,
            out_indices=config.out_indices
        )

        # NOTE(review): in_channels is hard-coded and must match the
        # per-stage channel counts the backbone actually emits — TODO
        # derive these from the config instead of fixing them here.
        self.segmentation_head = SegformerHead(
            in_channels=[64, 128, 256, 512],
            out_channels=getattr(config, 'num_classes', 19),
            dropout_ratio=0.1,
            align_corners=False
        )

        # Standard HF hook: weight init + any post-construction setup.
        self.post_init()

    def forward(self, x):
        """Run backbone then decode head.

        Args:
            x: input image batch tensor.

        Returns:
            Class-logit tensor at the decode head's output resolution
            (the highest-resolution backbone feature map, not the input
            image size).
        """
        features = self.backbone(x)
        return self.segmentation_head(features)