# ovsam/app/models/sam_backbone.py
# Author: Haobo Yuan — "Add model" (commit 9cc3eb2)
from functools import partial
from typing import Literal
import torch
import torch.nn as nn
from mmdet.registry import MODELS
from mmengine.model import BaseModule
from mmengine.logging import MMLogger
from ext.sam import ImageEncoderViT
from ext.meta.sam_meta import meta_dict, checkpoint_dict
from utils.load_checkpoint import load_checkpoint_with_prefix
@MODELS.register_module()
class SAMBackbone(BaseModule):
    """SAM (Segment Anything) ViT image encoder wrapped as an mmdet backbone.

    The monolithic ``ImageEncoderViT`` is re-packaged into sequential stages,
    split after each global-attention block, so intermediate features can be
    returned like a standard multi-stage backbone.

    Args:
        model_name: SAM variant to build; keys into ``meta_dict``.
        fix: If True, freeze the backbone (permanent eval mode, no gradients).
        init_cfg: Required. Either
            ``dict(type='sam_pretrain', checkpoint=<key of checkpoint_dict>)``
            to load official SAM weights into the raw encoder, or
            ``dict(type='Pretrained', checkpoint=<path>, prefix=<str>)``
            to load a project checkpoint into this wrapped module.
    """

    def __init__(
            self,
            model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
            fix: bool = True,
            init_cfg=None,
    ):
        # Validate init_cfg in two steps: the original combined assert raised
        # ``TypeError: 'NoneType' object is not subscriptable`` while building
        # its f-string message whenever init_cfg was None, hiding the real
        # problem. Checking ``is not None`` separately yields a clear error.
        assert init_cfg is not None, 'init_cfg is required for SAMBackbone.'
        assert init_cfg['type'] in ('sam_pretrain', 'Pretrained'), \
            f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        # Pass init_cfg=None to BaseModule so mmengine's automatic weight
        # initialization does not run; checkpoints are loaded explicitly below.
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        backbone_meta = meta_dict[model_name]
        backbone = ImageEncoderViT(
            depth=backbone_meta['encoder_depth'],
            embed_dim=backbone_meta['encoder_embed_dim'],
            num_heads=backbone_meta['encoder_num_heads'],
            patch_size=backbone_meta['vit_patch_size'],
            img_size=backbone_meta['image_size'],
            global_attn_indexes=backbone_meta['encoder_global_attn_indexes'],
            out_chans=backbone_meta['prompt_embed_dim'],
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            qkv_bias=True,
            use_rel_pos=True,
            mlp_ratio=4,
            window_size=14,
        )
        if self.init_cfg['type'] == 'sam_pretrain':
            # Official SAM weights must be loaded into the raw encoder BEFORE
            # it is split into stages (keys match ImageEncoderViT's layout).
            checkpoint_path = checkpoint_dict[pretrained]
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix='image_encoder')
            backbone.load_state_dict(state_dict, strict=True)

        self.stem = backbone.patch_embed
        self.pos_embed = backbone.pos_embed

        # Split the transformer blocks into stages, cutting after each
        # global-attention block so every stage ends with one.
        self.res_layers = []
        last_pos = 0
        for idx, cur_pos in enumerate(backbone_meta['encoder_global_attn_indexes']):
            blocks = backbone.blocks[last_pos:cur_pos + 1]
            layer_name = f'layer{idx + 1}'
            self.add_module(layer_name, nn.Sequential(*blocks))
            self.res_layers.append(layer_name)
            last_pos = cur_pos + 1
        self.out_proj = backbone.neck

        if self.init_cfg['type'] == 'Pretrained':
            # Project checkpoints are saved against this wrapped layout, so
            # they are loaded AFTER the stages have been re-assembled.
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        self.model_name = model_name
        self.fix = fix
        self.model_type = 'vit'
        self.output_channels = None
        # NOTE(review): assumes exactly four global-attention indexes (true
        # for all official SAM variants), so every stage output is kept.
        self.out_indices = (0, 1, 2, 3)
        if self.fix:
            self.train(mode=False)
            # Names are not needed here; iterate parameters directly.
            for param in self.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Intentional no-op: checkpoints are already loaded in __init__.

        Overriding prevents mmengine from re-initializing the weights; only
        the init config is logged for traceability.
        """
        self.logger.info(f"Init Config for {self.model_name}")
        self.logger.info(self.init_cfg)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        """Override to pin a frozen backbone in eval mode regardless of mode."""
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        super().train(mode=False if self.fix else mode)
        return self

    def forward_func(self, x):
        """Run the encoder and return a tuple of per-stage feature maps.

        The patch embedding yields channel-last features (hence the permute);
        each stage output is converted to (B, C, H, W). Only the LAST output
        additionally passes through the SAM neck (``out_proj``); earlier
        outputs keep the raw embedding width.
        """
        x = self.stem(x)
        x = x + self.pos_embed
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x.permute(0, 3, 1, 2).contiguous())
        outs[-1] = self.out_proj(outs[-1])
        return tuple(outs)

    def forward(self, x):
        """Forward pass; gradients are disabled when the backbone is frozen."""
        if self.fix:
            with torch.no_grad():
                outs = self.forward_func(x)
        else:
            outs = self.forward_func(x)
        return outs