File size: 3,999 Bytes
9cc3eb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from functools import partial
from typing import Literal

import torch
import torch.nn as nn
from mmdet.registry import MODELS

from mmengine.model import BaseModule
from mmengine.logging import MMLogger

from ext.sam import ImageEncoderViT
from ext.meta.sam_meta import meta_dict, checkpoint_dict
from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class SAMBackbone(BaseModule):
    """SAM ViT image encoder wrapped as an mmdet backbone.

    The encoder's transformer blocks are regrouped into four sequential
    stages, split after each global-attention block index from
    ``meta_dict[model_name]['encoder_global_attn_indexes']``.  ``forward``
    returns a 4-tuple of NCHW feature maps (all stages share the same
    spatial resolution — ViT has no downsampling between blocks); only the
    last output is passed through SAM's neck projection.
    """

    def __init__(
            self,
            model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
            fix: bool = True,
            init_cfg=None,
    ):
        """Build the encoder and load pretrained weights.

        Args:
            model_name: SAM ViT variant; keys into ``meta_dict`` /
                ``checkpoint_dict``.
            fix: If True, freeze all parameters and pin the module to eval
                mode (``train()`` becomes a no-op that keeps eval mode).
            init_cfg: Required. ``type`` must be ``'sam_pretrain'`` (load the
                official SAM checkpoint named by ``checkpoint`` into the raw
                encoder) or ``'Pretrained'`` (load ``checkpoint`` as a path,
                stripping ``prefix``, into this wrapper's state dict).

        Raises:
            AssertionError: If ``init_cfg`` is missing or has an unsupported
                ``type``.
        """
        # Two separate asserts: evaluating f"{init_cfg['type']}" in the
        # message of a combined assert would raise TypeError (None is not
        # subscriptable) instead of a readable AssertionError when
        # init_cfg is None.
        assert init_cfg is not None, 'init_cfg is required.'
        assert init_cfg['type'] in \
               ['sam_pretrain', 'Pretrained'], f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        # Pass init_cfg=None to BaseModule so its automatic weight-init
        # machinery is bypassed; checkpoint loading is done explicitly below.
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        backbone_meta = meta_dict[model_name]

        backbone = ImageEncoderViT(
            depth=backbone_meta['encoder_depth'],
            embed_dim=backbone_meta['encoder_embed_dim'],
            num_heads=backbone_meta['encoder_num_heads'],
            patch_size=backbone_meta['vit_patch_size'],
            img_size=backbone_meta['image_size'],
            global_attn_indexes=backbone_meta['encoder_global_attn_indexes'],
            out_chans=backbone_meta['prompt_embed_dim'],
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            qkv_bias=True,
            use_rel_pos=True,
            mlp_ratio=4,
            window_size=14,
        )
        # 'sam_pretrain' weights must be loaded into the raw encoder BEFORE
        # its blocks are re-parented into stage modules below.
        if self.init_cfg['type'] == 'sam_pretrain':
            checkpoint_path = checkpoint_dict[pretrained]
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix='image_encoder')
            backbone.load_state_dict(state_dict, strict=True)

        self.stem = backbone.patch_embed
        self.pos_embed = backbone.pos_embed

        # Split the flat block list into stages ending at each global
        # attention index (inclusive); register each stage as layer{i+1}.
        self.res_layers = []
        last_pos = 0
        for idx, cur_pos in enumerate(backbone_meta['encoder_global_attn_indexes']):
            blocks = backbone.blocks[last_pos:cur_pos + 1]
            layer_name = f'layer{idx + 1}'
            self.add_module(layer_name, nn.Sequential(*blocks))
            self.res_layers.append(layer_name)
            last_pos = cur_pos + 1

        self.out_proj = backbone.neck

        # 'Pretrained' weights target this wrapper's module tree, so they
        # are loaded AFTER the stages are assembled.
        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        self.model_name = model_name
        self.fix = fix
        self.model_type = 'vit'
        self.output_channels = None  # not inferred here; left for callers to set/query
        self.out_indices = (0, 1, 2, 3)
        if self.fix:
            self.train(mode=False)
            # Names are not needed; iterate parameters directly to freeze.
            for param in self.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Log the init config; actual loading already happened in __init__."""
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        """Respect ``fix``: a frozen backbone is forced to stay in eval mode."""
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self

    def forward_func(self, x):
        """Run the encoder and collect per-stage features.

        Stage outputs are NHWC inside the ViT; each collected feature is
        permuted to NCHW. The final feature is additionally projected by
        SAM's neck.
        """
        x = self.stem(x)
        x = x + self.pos_embed
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x.permute(0, 3, 1, 2).contiguous())
        outs[-1] = self.out_proj(outs[-1])
        return tuple(outs)

    def forward(self, x):
        """Forward pass; skips autograd bookkeeping when frozen."""
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)
        return outs