from typing import Optional, List

import torch
import torch.distributed as dist
import torch.nn as nn
from mmdet.registry import MODELS

from mmengine.model import BaseModule
from mmengine.dist import get_dist_info
from mmengine.logging import MMLogger

import ext.open_clip as open_clip
from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class OpenCLIPBackbone(BaseModule):
    """OpenCLIP visual backbone.

    Please refer to:
    https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface
    for the supported models and checkpoints.
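
    A minimal, illustrative config (the model name and checkpoint tag below are
    assumptions; pick any pair listed in the open_clip pretrained interface)::

        backbone = OpenCLIPBackbone(
            model_name='convnext_large_d_320',
            fix=True,
            init_cfg=dict(type='clip_pretrain',
                          checkpoint='laion2b_s29b_b131k_ft_soup'),
        )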
    """
    STAGES = 4

    def __init__(
            self,
            img_size: int = 1024,
            model_name: str = '',
            fix: bool = True,
            fix_layers: Optional[List] = None,
            init_cfg=None,
    ):
        assert init_cfg is not None and init_cfg['type'] in ['clip_pretrain', 'image_pretrain', 'Pretrained'], \
            f"init_cfg {init_cfg} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
        rank, world_size = get_dist_info()

        if world_size > 1:
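            # Rank 0 builds the model first so the pretrained weights are downloaded/cached once;
            # the other ranks wait at the barrier and then load from the local cache.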
            if rank == 0:
                if init_cfg['type'] == 'clip_pretrain':
                    _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
                                                               return_transform=False, logger=self.logger)
                elif init_cfg['type'] == 'image_pretrain':
                    _ = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)

            dist.barrier()

        # Get the clip model
        if init_cfg['type'] == 'clip_pretrain':
            clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
                                                                return_transform=False, logger=self.logger)
        elif init_cfg['type'] == 'image_pretrain':
            clip_model = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)
        elif init_cfg['type'] == 'Pretrained':
            clip_model = open_clip.create_model(model_name, pretrained_image=False, logger=self.logger)
        else:
            raise NotImplementedError

        self.out_indices = (0, 1, 2, 3)
        model_name_lower = model_name.lower()
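        # Derive the per-stage output channels (and, for ResNets, the attention-pool grid size) from the model name.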
        if 'convnext_' in model_name_lower:
            model_type = 'convnext'
            if '_base' in model_name_lower:
                output_channels = [128, 256, 512, 1024]
                feat_size = 0
            elif '_large' in model_name_lower:
                output_channels = [192, 384, 768, 1536]
                feat_size = 0
            elif '_xxlarge' in model_name_lower:
                output_channels = [384, 768, 1536, 3072]
                feat_size = 0
            else:
                raise NotImplementedError(f"{model_name} not supported yet.")
        elif 'rn' in model_name_lower:
            model_type = 'resnet'
            if model_name_lower.replace('-quickgelu', '') in ['rn50', 'rn101']:
                output_channels = [256, 512, 1024, 2048]
                feat_size = 7
            elif model_name_lower == 'rn50x4':
                output_channels = [320, 640, 1280, 2560]
                feat_size = 9
            elif model_name_lower == 'rn50x16':
                output_channels = [384, 768, 1536, 3072]
                feat_size = 12
            elif model_name_lower == 'rn50x64':
                output_channels = [512, 1024, 2048, 4096]
                feat_size = 14
            else:
                raise NotImplementedError(f"{model_name} not supported yet.")
        else:
            raise NotImplementedError(f"{model_name} not supported yet.")

        self.model_name = model_name
        self.fix = fix
        self.model_type = model_type
        self.output_channels = output_channels
        self.feat_size = feat_size

        # Get the visual model
        if self.model_type == 'resnet':
            self.stem = nn.Sequential(*[
                clip_model.visual.conv1, clip_model.visual.bn1, clip_model.visual.act1,
                clip_model.visual.conv2, clip_model.visual.bn2, clip_model.visual.act2,
                clip_model.visual.conv3, clip_model.visual.bn3, clip_model.visual.act3,
            ])
        elif self.model_type == 'convnext':
            self.stem = clip_model.visual.trunk.stem
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        if self.model_type == 'resnet':
            self.avgpool = clip_model.visual.avgpool
        elif self.model_type == 'convnext':
            self.avgpool = nn.Identity()
        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        self.res_layers = []
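        # Register each stage as a named child module so its parameters are tracked and appear in state_dict.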
        for i in range(self.STAGES):
            if self.model_type == 'resnet':
                layer_name = f'layer{i + 1}'
                layer = getattr(clip_model.visual, layer_name)
            elif self.model_type == 'convnext':
                layer_name = f'layer{i + 1}'
                layer = clip_model.visual.trunk.stages[i]
            else:
                raise ValueError(f"Unsupported model type: {self.model_type}")
            self.add_module(layer_name, layer)
            self.res_layers.append(layer_name)

        if self.model_type == 'resnet':
            self.norm_pre = nn.Identity()
        elif self.model_type == 'convnext':
            self.norm_pre = clip_model.visual.trunk.norm_pre

        if self.model_type == 'resnet':
            self.head = clip_model.visual.attnpool
        elif self.model_type == 'convnext':
            self.head = nn.Sequential(*[
                clip_model.visual.trunk.head,
                clip_model.visual.head,
            ])

        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        self.fix_layers = fix_layers

        if not self.fix:
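            # Train the backbone, but keep norm_pre, the CLIP head, and any stages listed in fix_layers frozen.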
            self.train()
            for name, param in self.norm_pre.named_parameters():
                param.requires_grad = False
            for name, param in self.head.named_parameters():
                param.requires_grad = False
            if self.fix_layers is not None:
                for i, layer_name in enumerate(self.res_layers):
                    if i in self.fix_layers:
                        res_layer = getattr(self, layer_name)
                        for name, param in res_layer.named_parameters():
                            param.requires_grad = False

        if self.fix:
            self.train(mode=False)
            for name, param in self.named_parameters():
                param.requires_grad = False

    def init_weights(self):
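        """Pretrained weights are already loaded in __init__; only log the init config here."""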
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
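        """Keep the backbone in eval mode when fixed; otherwise train, keeping any fix_layers stages in eval mode."""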
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
            if self.fix_layers is not None:
                for i, layer_name in enumerate(self.res_layers):
                    if i in self.fix_layers:
                        res_layer = getattr(self, layer_name)
                        res_layer.train(mode=False)
        return self

    def forward_func(self, x):
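        """Run stem, avgpool and the four stages, returning the per-stage feature maps."""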
        x = self.stem(x)
        x = self.avgpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x).contiguous()
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def get_clip_feature(self, backbone_feat):
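        """Map a backbone feature to the space expected by the CLIP head (applies norm_pre for ConvNeXt)."""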
        if self.model_type == 'resnet':
            return backbone_feat
        elif self.model_type == 'convnext':
            return self.norm_pre(backbone_feat)
        raise NotImplementedError

    def forward_feat(self, features):
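        """Project pooled features through the CLIP visual head into the joint embedding space."""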
        if self.model_type == 'convnext':
            batch, num_query, channel = features.shape
            features = features.reshape(batch * num_query, channel, 1, 1)
            features = self.head(features)
            return features.view(batch, num_query, features.shape[-1])
        elif self.model_type == 'resnet':
            num_query, channel, _, _ = features.shape
            features = self.head(features)
            return features

    def forward(self, x):
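        """Extract multi-scale features, without gradients when the backbone is fixed."""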
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)
        return outs

    def get_text_model(self):
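        """Build the matching text encoder with the same model name and init_cfg."""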
        return OpenCLIPBackboneText(
            self.model_name,
            init_cfg=self.init_cfg
        )


@MODELS.register_module()
class OpenCLIPBackboneText(BaseModule):
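    """Text tower of an OpenCLIP model; forward() tokenizes prompts and returns text embeddings."""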
    def __init__(
            self,
            model_name: str = '',
            init_cfg=None,
    ):
        assert init_cfg is not None and init_cfg['type'] == 'clip_pretrain', f"init_cfg {init_cfg} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
        rank, world_size = get_dist_info()

        if world_size > 1:
            if rank == 0:
                _ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False,
                                                           logger=self.logger)
            dist.barrier()

        # Get the clip model
        clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False,
                                                            logger=self.logger)

        # Get the textual model
        self.text_tokenizer = open_clip.get_tokenizer(model_name)
        self.text_transformer = clip_model.transformer
        self.text_token_embedding = clip_model.token_embedding
        self.text_pe = clip_model.positional_embedding
        self.text_ln_final = clip_model.ln_final
        self.text_proj = clip_model.text_projection

        self.register_buffer('text_attn_mask', clip_model.attn_mask)

        self.param_dtype = torch.float32
        self.model_name = model_name

    def init_weights(self):
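        """Pretrained weights are already loaded in __init__; only log the init config here."""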
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    # Copied from
    # https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L343
    @torch.no_grad()
    def forward(self, text):
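        """Tokenize the input strings and return the CLIP text embeddings (runs under no_grad)."""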
        text_tokens = self.text_tokenizer(text).to(device=self.text_proj.device)
        x = self.text_token_embedding(text_tokens).to(self.param_dtype)
        x = x + self.text_pe.to(self.param_dtype)
        x = x.permute(1, 0, 2)
        x = self.text_transformer(x, attn_mask=self.text_attn_mask)
        x = x.permute(1, 0, 2)
        x = self.text_ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text_tokens.argmax(dim=-1)] @ self.text_proj
        return x