File size: 2,079 Bytes
02508fb
 
66c5431
02508fb
66c5431
 
 
02508fb
 
 
 
 
 
 
 
 
66c5431
02508fb
66c5431
02508fb
 
 
 
 
 
 
 
 
 
 
 
 
 
66c5431
d31b749
1a260cd
02508fb
66c5431
 
 
02508fb
 
 
 
 
e4634c2
 
66c5431
 
e4634c2
9659a3a
 
 
 
66c5431
 
e4634c2
66c5431
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import logging

import torch
import torch.nn as nn
from transformers import PreTrainedModel

from segformer_plusplus.model.backbone.mit import MixVisionTransformer  # Backbone
from mix_vision_transformer_config import MySegformerConfig             # Config
from segformer_plusplus.model.head.segformer_head import SegformerHead  # Decode head


class MySegformerForSemanticSegmentation(PreTrainedModel):
    """Semantic-segmentation model wiring a MixVisionTransformer to a SegformerHead.

    The backbone is built entirely from the values stored on ``config``; the
    head consumes the backbone outputs selected by ``config.out_indices`` and
    produces per-class logits.
    """

    config_class = MySegformerConfig
    base_model_prefix = "my_segformer"

    def __init__(self, config):
        """Build the backbone and segmentation head from ``config``.

        Args:
            config: Configuration object providing the backbone
                hyper-parameters read below (``embed_dims``, ``num_stages``,
                ``num_layers``, ``num_heads``, ``patch_sizes``, ``strides``,
                ``sr_ratios``, ``mlp_ratio``, ``qkv_bias``, ``drop_rate``,
                ``attn_drop_rate``, ``drop_path_rate``, ``out_indices``) and,
                optionally, ``num_classes`` (defaults to 19 when absent).
        """
        super().__init__(config)

        # Backbone (MixVisionTransformer), fully driven by the config.
        self.backbone = MixVisionTransformer(
            embed_dims=config.embed_dims,
            num_stages=config.num_stages,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            patch_sizes=config.patch_sizes,
            strides=config.strides,
            sr_ratios=config.sr_ratios,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            drop_rate=config.drop_rate,
            attn_drop_rate=config.attn_drop_rate,
            drop_path_rate=config.drop_path_rate,
            out_indices=config.out_indices
        )

        # Channel widths handed to the head, one entry per backbone stage.
        # NOTE(review): these are hard-coded rather than derived from the
        # config, so they presumably only match one backbone variant --
        # confirm against the MixVisionTransformer output widths before using
        # a config with different embed_dims / num_heads / out_indices.
        in_channels = [64, 128, 320, 512]

        self.segmentation_head = SegformerHead(
            in_channels=in_channels,                # per-stage channel widths
            in_index=list(config.out_indices),      # which feature maps are used
            out_channels=getattr(config, "num_classes", 19),  # number of classes
            dropout_ratio=0.1,
            align_corners=False
        )

        self.post_init()

    def forward(self, x):
        """Run the backbone and the segmentation head on ``x``.

        Args:
            x: Input passed unchanged to the backbone (presumably an image
                batch tensor -- confirm against the backbone's contract).

        Returns:
            dict: ``{"logits": logits}`` where ``logits`` is the output of the
            segmentation head.
        """
        # Backbone -> iterable of feature maps.
        features = self.backbone(x)

        # Feature shapes are a debugging aid only: emit them through logging
        # at DEBUG level instead of printing to stdout on every forward pass.
        log = logging.getLogger(__name__)
        if log.isEnabledFor(logging.DEBUG):
            for i, f in enumerate(features):
                log.debug("Feature %d: shape = %s", i, f.shape)

        # Head -> logits.
        logits = self.segmentation_head(features)

        return {"logits": logits}