from transformers import PreTrainedModel
import torch
import torch.nn as nn

from segformer_plusplus.utils import resize
from segformer_plusplus.model.backbone.mit import MixVisionTransformer  # your backbone import
from mix_vision_transformer_config import MySegformerConfig  # your config import


# Head implementation (slightly simplified and adapted)
class SegformerHead(nn.Module):
    """All-MLP SegFormer decode head: projects each backbone stage to a common
    channel width, upsamples to the highest resolution, fuses, and classifies."""

    def __init__(self,
                 in_channels=[64, 128, 256, 512],  # adjust to the backbone's output channels!
                 in_index=[0, 1, 2, 3],
                 channels=256,
                 dropout_ratio=0.1,
                 out_channels=19,  # number of classes, adjust!
                 norm_cfg=None,
                 align_corners=False,
                 interpolate_mode='bilinear'):
        super().__init__()
        self.in_channels = in_channels
        self.in_index = in_index
        self.channels = channels
        self.dropout_ratio = dropout_ratio
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.align_corners = align_corners
        self.interpolate_mode = interpolate_mode
        self.act_cfg = dict(type='ReLU')

        self.conv_seg = nn.Conv2d(channels, out_channels, kernel_size=1)
        self.dropout = nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else None

        num_inputs = len(in_channels)
        assert num_inputs == len(in_index)

        from segformer_plusplus.utils.activation import ConvModule

        # One 1x1 projection per backbone stage.
        self.convs = nn.ModuleList()
        for i in range(num_inputs):
            self.convs.append(
                ConvModule(
                    in_channels=in_channels[i],
                    out_channels=channels,
                    kernel_size=1,
                    stride=1,
                    norm_cfg=norm_cfg,
                    act_cfg=self.act_cfg))

        # Fuses the concatenated, upsampled stage features.
        self.fusion_conv = ConvModule(
            in_channels=channels * num_inputs,
            out_channels=channels,
            kernel_size=1,
            norm_cfg=norm_cfg)

    def cls_seg(self, feat):
        if self.dropout is not None:
            feat = self.dropout(feat)
        return self.conv_seg(feat)

    def forward(self, inputs):
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            outs.append(
                resize(
                    input=conv(x),
                    size=inputs[0].shape[2:],
                    mode=self.interpolate_mode,
                    align_corners=self.align_corners))

        out = self.fusion_conv(torch.cat(outs, dim=1))
        out = self.cls_seg(out)
        return out


class MySegformerForSemanticSegmentation(PreTrainedModel):
    config_class = MySegformerConfig
    base_model_prefix = "my_segformer"

    def __init__(self, config):
        super().__init__(config)

        # Initialize the backbone with parameters taken from the config.
        self.backbone = MixVisionTransformer(
            embed_dims=config.embed_dims,
            num_stages=config.num_stages,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            patch_sizes=config.patch_sizes,
            strides=config.strides,
            sr_ratios=config.sr_ratios,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            drop_path_rate=config.drop_path_rate,
            out_indices=config.out_indices
        )

        # Initialize the head; take out_channels from the config or set it explicitly.
        self.segmentation_head = SegformerHead(
            in_channels=[64, 128, 256, 512],  # <- adjust to whatever the backbone actually outputs!
            out_channels=config.num_classes if hasattr(config, 'num_classes') else 19,
            dropout_ratio=0.1,
            align_corners=False
        )

        self.post_init()

    def forward(self, x):
        features = self.backbone(x)
        segmentation_output = self.segmentation_head(features)
        return segmentation_output
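

# Minimal usage sketch (not part of the original snippet).
# Assumptions: MySegformerConfig accepts the attributes read in __init__ above
# as keyword arguments, and the chosen backbone settings produce per-stage
# output channels matching SegformerHead's in_channels ([64, 128, 256, 512]).
# The concrete values below are illustrative placeholders only.
if __name__ == "__main__":
    config = MySegformerConfig(
        embed_dims=64,
        num_stages=4,
        num_layers=[3, 4, 6, 3],
        num_heads=[1, 2, 4, 8],
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        out_indices=(0, 1, 2, 3),
        num_classes=19,
    )

    model = MySegformerForSemanticSegmentation(config)
    model.eval()

    # Dummy forward pass: the head fuses all stages at the resolution of the
    # first (highest-resolution) feature map, so the logits come out smaller
    # than the input (typically 1/4 of its spatial size).
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 512, 512))
    print(logits.shape)  # e.g. torch.Size([1, 19, 128, 128])

    # Because the model subclasses PreTrainedModel, the standard Hugging Face
    # save/load round trip is available:
    model.save_pretrained("my_segformer_checkpoint")
    restored = MySegformerForSemanticSegmentation.from_pretrained("my_segformer_checkpoint")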