from typing import Optional, List
import torch
import torch.distributed as dist
import torch.nn as nn
from mmdet.registry import MODELS
from mmengine.model import BaseModule
from mmengine.dist import get_dist_info
from mmengine.logging import MMLogger
import ext.open_clip as open_clip
from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class OpenCLIPBackbone(BaseModule):
"""OpenCLIPBackbone,
Please refer to:
https://github.com/mlfoundations/open_clip/tree/5f7892b672b21e6853d0f6c11b18dda9bcf36c8d#pretrained-model-interface
for the supported models and checkpoints.
"""
STAGES = 4

    def __init__(
self,
img_size: int = 1024,
model_name: str = '',
fix: bool = True,
fix_layers: Optional[List] = None,
init_cfg=None,
):
assert init_cfg is not None and init_cfg['type'] in ['clip_pretrain', 'image_pretrain', 'Pretrained'], \
f"{init_cfg['type']} is not supported."
pretrained = init_cfg['checkpoint']
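        # Bypass BaseModule's automatic weight initialization; checkpoints are
        # loaded explicitly below, and init_cfg is kept only for bookkeeping.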
super().__init__(init_cfg=None)
self.init_cfg = init_cfg
self.logger = MMLogger.get_current_instance()
rank, world_size = get_dist_info()
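        # Under distributed training, let rank 0 download and cache the
        # pretrained weights first; the barrier makes the other ranks wait and
        # then load from the local cache instead of racing on the download.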
if world_size > 1:
if rank == 0:
if init_cfg['type'] == 'clip_pretrain':
_ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
return_transform=False, logger=self.logger)
elif init_cfg['type'] == 'image_pretrain':
_ = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)
else:
pass
dist.barrier()
# Get the clip model
if init_cfg['type'] == 'clip_pretrain':
clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained,
return_transform=False, logger=self.logger)
elif init_cfg['type'] == 'image_pretrain':
clip_model = open_clip.create_model(model_name, pretrained_image=True, logger=self.logger)
elif init_cfg['type'] == 'Pretrained':
clip_model = open_clip.create_model(model_name, pretrained_image=False, logger=self.logger)
else:
raise NotImplementedError
self.out_indices = (0, 1, 2, 3)
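
        # Per-stage channel widths (and, for ResNets, the spatial size of the
        # final feature map at the model's native input resolution) depend on
        # the model variant; ConvNeXt variants leave feat_size at 0.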
model_name_lower = model_name.lower()
if 'convnext_' in model_name_lower:
model_type = 'convnext'
if '_base' in model_name_lower:
output_channels = [128, 256, 512, 1024]
feat_size = 0
elif '_large' in model_name_lower:
output_channels = [192, 384, 768, 1536]
feat_size = 0
elif '_xxlarge' in model_name_lower:
output_channels = [384, 768, 1536, 3072]
feat_size = 0
else:
raise NotImplementedError(f"{model_name} not supported yet.")
elif 'rn' in model_name_lower:
model_type = 'resnet'
if model_name_lower.replace('-quickgelu', '') in ['rn50', 'rn101']:
output_channels = [256, 512, 1024, 2048]
feat_size = 7
elif model_name_lower == 'rn50x4':
output_channels = [320, 640, 1280, 2560]
feat_size = 9
elif model_name_lower == 'rn50x16':
output_channels = [384, 768, 1536, 3072]
feat_size = 12
elif model_name_lower == 'rn50x64':
output_channels = [512, 1024, 2048, 4096]
feat_size = 14
else:
raise NotImplementedError(f"{model_name} not supported yet.")
else:
raise NotImplementedError(f"{model_name} not supported yet.")
self.model_name = model_name
self.fix = fix
self.model_type = model_type
self.output_channels = output_channels
self.feat_size = feat_size
# Get the visual model
if self.model_type == 'resnet':
self.stem = nn.Sequential(*[
clip_model.visual.conv1, clip_model.visual.bn1, clip_model.visual.act1,
clip_model.visual.conv2, clip_model.visual.bn2, clip_model.visual.act2,
clip_model.visual.conv3, clip_model.visual.bn3, clip_model.visual.act3,
])
elif self.model_type == 'convnext':
self.stem = clip_model.visual.trunk.stem
else:
raise ValueError
if self.model_type == 'resnet':
self.avgpool = clip_model.visual.avgpool
elif self.model_type == 'convnext':
self.avgpool = nn.Identity()
else:
raise ValueError
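
        # Register the four stages as named child modules (layer1..layer4) so
        # they appear in state_dict() and can be frozen individually.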
self.res_layers = []
for i in range(self.STAGES):
if self.model_type == 'resnet':
layer_name = f'layer{i + 1}'
layer = getattr(clip_model.visual, layer_name)
elif self.model_type == 'convnext':
layer_name = f'layer{i + 1}'
layer = clip_model.visual.trunk.stages[i]
else:
raise ValueError
self.add_module(layer_name, layer)
self.res_layers.append(layer_name)
if self.model_type == 'resnet':
self.norm_pre = nn.Identity()
elif self.model_type == 'convnext':
self.norm_pre = clip_model.visual.trunk.norm_pre
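
        # The head maps pooled visual features into the CLIP embedding space:
        # the attention pool for ResNets, or the trunk head followed by the
        # projection for ConvNeXt.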
if self.model_type == 'resnet':
self.head = clip_model.visual.attnpool
elif self.model_type == 'convnext':
self.head = nn.Sequential(*[
clip_model.visual.trunk.head,
clip_model.visual.head,
])
if self.init_cfg['type'] == 'Pretrained':
checkpoint_path = pretrained
state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
self.load_state_dict(state_dict, strict=True)
self.fix_layers = fix_layers
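
        # Freezing policy: fix=True freezes the entire backbone; otherwise
        # only norm_pre, the head, and the stage indices listed in fix_layers
        # are frozen.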
if not self.fix:
self.train()
for name, param in self.norm_pre.named_parameters():
param.requires_grad = False
for name, param in self.head.named_parameters():
param.requires_grad = False
if self.fix_layers is not None:
for i, layer_name in enumerate(self.res_layers):
if i in self.fix_layers:
res_layer = getattr(self, layer_name)
for name, param in res_layer.named_parameters():
param.requires_grad = False
if self.fix:
self.train(mode=False)
for name, param in self.named_parameters():
param.requires_grad = False

    def init_weights(self):
self.logger.info(f"Init Config for {self.model_name}")
self.logger.info(self.init_cfg)
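
    # The runner may call .train() on the whole model; override it so frozen
    # parts stay in eval mode (e.g. BatchNorm running statistics are not
    # updated on the frozen stages).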
    def train(self, mode: bool = True) -> torch.nn.Module:
if not isinstance(mode, bool):
raise ValueError("training mode is expected to be boolean")
if self.fix:
super().train(mode=False)
else:
super().train(mode=mode)
if self.fix_layers is not None:
for i, layer_name in enumerate(self.res_layers):
if i in self.fix_layers:
res_layer = getattr(self, layer_name)
res_layer.train(mode=False)
return self
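
    # Run stem -> avgpool -> the four stages, collecting each stage's output
    # into a multi-scale feature pyramid.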
def forward_func(self, x):
x = self.stem(x)
x = self.avgpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x).contiguous()
if i in self.out_indices:
outs.append(x)
return tuple(outs)
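
    # For ConvNeXt, stage-4 features must pass through norm_pre before they
    # match the feature space CLIP pools from; ResNet features are used as-is.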
def get_clip_feature(self, backbone_feat):
if self.model_type == 'resnet':
return backbone_feat
elif self.model_type == 'convnext':
return self.norm_pre(backbone_feat)
raise NotImplementedError
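
    # Project pooled backbone features into the CLIP joint embedding space so
    # they can be compared against text embeddings.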
def forward_feat(self, features):
if self.model_type == 'convnext':
batch, num_query, channel = features.shape
features = features.reshape(batch * num_query, channel, 1, 1)
features = self.head(features)
return features.view(batch, num_query, features.shape[-1])
        elif self.model_type == 'resnet':
            # features: [num_query, channels, H, W]; the attention pool
            # projects them to the CLIP embedding dimension.
            features = self.head(features)
            return features

    def forward(self, x):
if self.fix:
with torch.no_grad():
outs = self.forward_func(x)
else:
outs = self.forward_func(x)
return outs

    def get_text_model(self):
return OpenCLIPBackboneText(
self.model_name,
init_cfg=self.init_cfg
)


@MODELS.register_module()
class OpenCLIPBackboneText(BaseModule):
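    """Text tower counterpart of OpenCLIPBackbone.

    Wraps the OpenCLIP text encoder (tokenizer, token embedding, transformer,
    final LayerNorm, and text projection) behind a simple forward(text) call
    that returns one embedding per input string.
    """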

    def __init__(
self,
model_name: str = '',
init_cfg=None,
):
assert init_cfg is not None and init_cfg['type'] == 'clip_pretrain', f"{init_cfg['type']} is not supported."
pretrained = init_cfg['checkpoint']
super().__init__(init_cfg=None)
self.init_cfg = init_cfg
self.logger = MMLogger.get_current_instance()
rank, world_size = get_dist_info()
if world_size > 1:
if rank == 0:
_ = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False,
logger=self.logger)
else:
pass
dist.barrier()
# Get the clip model
clip_model = open_clip.create_model_from_pretrained(model_name, pretrained=pretrained, return_transform=False,
logger=self.logger)
# Get the textual model
self.text_tokenizer = open_clip.get_tokenizer(model_name)
self.text_transformer = clip_model.transformer
self.text_token_embedding = clip_model.token_embedding
self.text_pe = clip_model.positional_embedding
self.text_ln_final = clip_model.ln_final
self.text_proj = clip_model.text_projection
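        # Keep the causal attention mask as a buffer so it moves with the
        # module across devices without being a learnable parameter.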
self.register_buffer('text_attn_mask', clip_model.attn_mask)
self.param_dtype = torch.float32
self.model_name = model_name

    def init_weights(self):
self.logger.info(f"Init Config for {self.model_name}")
self.logger.info(self.init_cfg)

    # Copied from
# https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L343
@torch.no_grad()
def forward(self, text):
text_tokens = self.text_tokenizer(text).to(device=self.text_proj.device)
x = self.text_token_embedding(text_tokens).to(self.param_dtype)
x = x + self.text_pe.to(self.param_dtype)
x = x.permute(1, 0, 2)
x = self.text_transformer(x, attn_mask=self.text_attn_mask)
x = x.permute(1, 0, 2)
x = self.text_ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text_tokens.argmax(dim=-1)] @ self.text_proj
return x
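

# Minimal usage sketch (illustrative): the model name and checkpoint tag are
# assumptions that must match an entry in the open_clip pretrained registry;
# weights are downloaded on first use.
if __name__ == '__main__':
    backbone = OpenCLIPBackbone(
        model_name='convnext_large_d_320',
        fix=True,
        init_cfg=dict(type='clip_pretrain',
                      checkpoint='laion2b_s29b_b131k_ft_soup'),
    )
    images = torch.rand(2, 3, 1024, 1024)
    feats = backbone(images)  # 4-level pyramid; channels (192, 384, 768, 1536)
    for i, feat in enumerate(feats):
        print(f'stage {i}: {tuple(feat.shape)}')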