import warnings

import torch.nn as nn

from additional_modules.segformer.backbone import mit_b0
from additional_modules.segformer.segformer_head import SegFormerHead
from additional_modules.segformer.base_model import BaseSegmentor


class EncoderDecoder(BaseSegmentor):
    """SegFormer encoder-decoder segmentor (``mit_b0`` + ``SegFormerHead``).

    The model is a fixed pairing of a ``mit_b0`` backbone and a
    ``SegFormerHead`` decode head whose input configuration
    (``in_channels=[32, 64, 160, 256]``, ``feature_strides=[4, 8, 16, 32]``)
    is hard-coded to match that backbone.

    Unlike a general encoder-decoder segmentor, this class does not build a
    neck or an auxiliary head: the ``neck`` and ``auxiliary_head`` arguments
    are accepted only for signature compatibility and are ignored (a
    ``UserWarning`` is emitted if either is given).
    """

    def __init__(self,
                 neck=None,
                 num_classes=256,
                 auxiliary_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 norm_cfg=None):
        """Build the backbone and decode head, then initialize weights.

        Args:
            neck: Ignored; no neck is built. A warning is emitted if not None.
            num_classes (int): Number of classes passed to the decode head.
                Defaults to 256.
            auxiliary_head: Ignored; no auxiliary head is built. A warning is
                emitted if not None.
            train_cfg: Stored as ``self.train_cfg``; not read by this class.
            test_cfg: Stored as ``self.test_cfg``; not read by this class.
            pretrained (str, optional): Forwarded to ``init_weights`` (and
                from there to the backbone). Defaults to None.
            norm_cfg (dict, optional): Normalization config for the decode
                head. Defaults to ``dict(type='SyncBN', requires_grad=True)``,
                which was previously hard-coded.
        """
        super().__init__()
        # These arguments used to be dropped silently; make that visible so a
        # caller expecting a neck / auxiliary head is not misled.
        for arg_name, value in (('neck', neck),
                                ('auxiliary_head', auxiliary_head)):
            if value is not None:
                warnings.warn(
                    f"EncoderDecoder ignores the '{arg_name}' argument; "
                    "only the mit_b0 backbone and SegFormerHead are built.",
                    UserWarning,
                    stacklevel=2)
        self.num_classes = num_classes
        self.backbone = mit_b0()
        self._init_decode_head(norm_cfg=norm_cfg)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.init_weights(pretrained=pretrained)

        assert self.with_decode_head

    def _init_decode_head(self, norm_cfg=None):
        """Initialize ``decode_head`` and mirror its settings onto ``self``.

        Args:
            norm_cfg (dict, optional): Normalization config for the head.
                When None, ``dict(type='SyncBN', requires_grad=True)`` is used.
        """
        if norm_cfg is None:
            # Fresh dict per call (no shared mutable default). SyncBN is the
            # original setting; other configs presumably follow the same
            # mmcv-style ``type`` convention -- confirm against SegFormerHead.
            norm_cfg = dict(type='SyncBN', requires_grad=True)
        self.decode_head = SegFormerHead(
            in_channels=[32, 64, 160, 256],
            in_index=[0, 1, 2, 3],
            feature_strides=[4, 8, 16, 32],
            channels=128,
            dropout_ratio=0.1,
            num_classes=self.num_classes,
            norm_cfg=norm_cfg,
            align_corners=False,
            decoder_params=dict(embed_dim=256),
        )
        # Keep the segmentor's view consistent with what the head was built with.
        self.align_corners = self.decode_head.align_corners
        self.num_classes = self.decode_head.num_classes

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone and heads.

        Args:
            pretrained (str, optional): Path to pre-trained weights, forwarded
                to the base class and the backbone. Defaults to None.
        """
        super().init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        self.decode_head.init_weights()
        # No auxiliary head is ever assigned by this class, so this branch is
        # presumably dead unless a subclass sets ``self.auxiliary_head``.
        if self.with_auxiliary_head:
            if isinstance(self.auxiliary_head, nn.ModuleList):
                for aux_head in self.auxiliary_head:
                    aux_head.init_weights()
            else:
                self.auxiliary_head.init_weights()

    def extract_feat(self, img):
        """Run the backbone (and a neck, if one exists) on ``img``.

        Args:
            img (Tensor): Input images.

        Returns:
            Whatever the backbone (or neck) returns -- presumably a list of
            multi-level feature maps indexed by the head's ``in_index``.
        """
        x = self.backbone(img)
        # This class never builds a neck; the check is kept for subclasses.
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward(self, img):
        """Run the full model: backbone features followed by the decode head.

        No loss is computed and the output is not resized back to the input
        resolution here.

        Args:
            img (Tensor): Input images.

        Returns:
            Tensor: The raw output of ``self.decode_head`` -- presumably
            per-class logits with ``self.num_classes`` channels.
        """
        x = self.extract_feat(img)
        x = self.decode_head(x)
        return x