from functools import partial
from typing import Literal

import torch
import torch.nn as nn
from mmdet.registry import MODELS
from mmengine.logging import MMLogger
from mmengine.model import BaseModule

from ext.meta.sam_meta import checkpoint_dict, meta_dict
from ext.sam import ImageEncoderViT
from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class SAMBackbone(BaseModule):
    def __init__(
            self,
            model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
            fix: bool = True,
            init_cfg=None,
    ):
        assert init_cfg is not None and init_cfg['type'] in \
               ['sam_pretrain', 'Pretrained'], f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        # Checkpoints are loaded manually below, so bypass BaseModule's own
        # init_cfg handling; keep a copy for logging in init_weights().
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()
        # Build the full SAM image encoder from the size-specific meta info.
        backbone_meta = meta_dict[model_name]
        backbone = ImageEncoderViT(
            depth=backbone_meta['encoder_depth'],
            embed_dim=backbone_meta['encoder_embed_dim'],
            num_heads=backbone_meta['encoder_num_heads'],
            patch_size=backbone_meta['vit_patch_size'],
            img_size=backbone_meta['image_size'],
            global_attn_indexes=backbone_meta['encoder_global_attn_indexes'],
            out_chans=backbone_meta['prompt_embed_dim'],
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            qkv_bias=True,
            use_rel_pos=True,
            mlp_ratio=4,
            window_size=14,
        )
        if self.init_cfg['type'] == 'sam_pretrain':
            # 'checkpoint' names an official SAM release; extract only the
            # image-encoder weights from the full SAM state dict.
            checkpoint_path = checkpoint_dict[pretrained]
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix='image_encoder')
            backbone.load_state_dict(state_dict, strict=True)
        self.stem = backbone.patch_embed
        self.pos_embed = backbone.pos_embed

        # Regroup the ViT blocks into stages, each ending at a global-attention
        # block, so the backbone exposes ResNet-style res_layers.
        self.res_layers = []
        last_pos = 0
        for idx, cur_pos in enumerate(backbone_meta['encoder_global_attn_indexes']):
            blocks = backbone.blocks[last_pos:cur_pos + 1]
            layer_name = f'layer{idx + 1}'
            self.add_module(layer_name, nn.Sequential(*blocks))
            self.res_layers.append(layer_name)
            last_pos = cur_pos + 1
        self.out_proj = backbone.neck
        if self.init_cfg['type'] == 'Pretrained':
            # A checkpoint that already matches this module's layout;
            # 'prefix' selects the backbone's sub-dict of the state dict.
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)
        self.model_name = model_name
        self.fix = fix
        self.model_type = 'vit'
        self.output_channels = None
        self.out_indices = (0, 1, 2, 3)

        if self.fix:
            # Freeze the backbone: eval mode plus no gradient updates.
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

    def init_weights(self):
        # Checkpoints were already loaded in __init__; just log the config.
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    def train(self, mode: bool = True) -> 'SAMBackbone':
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        # A fixed backbone stays in eval mode regardless of the requested mode.
        super().train(mode=False if self.fix else mode)
        return self

    def forward_func(self, x):
        x = self.stem(x)
        x = x + self.pos_embed
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                # ViT blocks run channels-last (B, H, W, C); emit (B, C, H, W).
                outs.append(x.permute(0, 3, 1, 2).contiguous())
        # Only the final stage passes through the SAM neck projection.
        outs[-1] = self.out_proj(outs[-1])
        return tuple(outs)

    def forward(self, x):
        # Skip autograd bookkeeping entirely when the backbone is frozen.
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)
        return outs
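

# Usage sketch (illustrative, not part of the module): build the backbone from
# a SAM checkpoint and run a dummy forward pass. The 'vit_h' checkpoint key
# and the 1024x1024 input size are assumptions based on meta_dict and
# checkpoint_dict above; adjust them to the entries actually registered there.
if __name__ == '__main__':
    backbone = SAMBackbone(
        model_name='vit_h',
        fix=True,
        init_cfg=dict(type='sam_pretrain', checkpoint='vit_h'),
    )
    dummy = torch.randn(1, 3, 1024, 1024)
    feats = backbone(dummy)
    for i, feat in enumerate(feats):
        # Each stage yields a (B, C, H, W) feature map; only the last one has
        # been projected to prompt_embed_dim channels by the SAM neck.
        print(f'stage {i + 1}: {tuple(feat.shape)}')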