# ovsam/app/models/transformer_neck.py

from functools import partial
from typing import Tuple, List, Optional

import torch
from torch import Tensor, nn

from mmengine.model import BaseModule, normal_init
from mmdet.registry import MODELS
from mmdet.models.layers import PatchEmbed

from ext.meta.sam_meta import checkpoint_dict
from ext.sam.common import LayerNorm2d
from ext.sam.image_encoder import Block

from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class MultiLayerTransformerNeck(BaseModule):
    """Fuse multi-scale backbone features into a single stride-16 token map and
    refine it with SAM-style transformer blocks before a convolutional neck."""

    STRIDE = 16

    def __init__(
            self,
            input_size: Tuple[int, int],
            in_channels: List[int],
            embed_channels: int,
            out_channels: int,
            layer_ids: Tuple[int] = (0, 1, 2, 3),
            strides: Tuple[int] = (4, 8, 16, 32),
            embedding_path: Optional[str] = None,
            fix: bool = False,
            init_cfg=None
    ) -> None:
        # init_cfg is handled manually at the end of __init__, so it is not
        # forwarded to BaseModule.
        super().__init__(init_cfg=None)
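
        # Every selected input level is resampled onto a common stride-16 token
        # grid, matching the token resolution of the SAM ViT image encoder.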
        self.transformer_size = (input_size[0] // self.STRIDE, input_size[1] // self.STRIDE)
        self.layer_ids = layer_ids

        self.patch_embeds = nn.ModuleList()
        for idx, in_ch in enumerate(in_channels):
            if idx in layer_ids:
                if strides[idx] > self.STRIDE:
                    # Coarser level: upsample to stride 16 with a transposed conv.
                    patch_embed = PatchEmbed(
                        conv_type=nn.ConvTranspose2d,
                        in_channels=in_ch,
                        embed_dims=embed_channels,
                        kernel_size=strides[idx] // self.STRIDE,
                        stride=strides[idx] // self.STRIDE,
                        input_size=(input_size[0] // strides[idx], input_size[1] // strides[idx])
                    )
                else:
                    # Finer (or equal) level: project down to stride 16 with a strided conv.
                    patch_embed = PatchEmbed(
                        in_channels=in_ch,
                        embed_dims=embed_channels,
                        kernel_size=self.STRIDE // strides[idx],
                        stride=self.STRIDE // strides[idx],
                        input_size=(input_size[0] // strides[idx], input_size[1] // strides[idx])
                    )
                self.patch_embeds.append(patch_embed)
            else:
                self.patch_embeds.append(nn.Identity())
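
        # Positional embedding: either taken from a SAM image-encoder checkpoint
        # (embedding_path of the form 'sam_<name>', resolved via `checkpoint_dict`)
        # or left as zeros to be overwritten when a full checkpoint is loaded later.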
        if embedding_path is not None:
            assert embedding_path.startswith('sam_')
            embedding_ckpt = embedding_path.split('_', maxsplit=1)[1]
            path = checkpoint_dict[embedding_ckpt]
            state_dict = load_checkpoint_with_prefix(path, prefix='image_encoder')
            pos_embed = state_dict['pos_embed']
        else:
            # For loading from checkpoint
            pos_embed = torch.zeros(1, input_size[0] // self.STRIDE, input_size[1] // self.STRIDE, embed_channels)
        self.register_buffer('pos_embed', pos_embed)
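
        # A learnable embedding tags each fused level; the fused map is then
        # refined by five SAM-style ViT blocks (windowed attention except the last).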
        self.level_encoding = nn.Embedding(len(layer_ids), embed_channels)

        depth = 5
        global_attn_indexes = [4]
        window_size = 14

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_channels,
                num_heads=16,
                mlp_ratio=4,
                qkv_bias=True,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                use_rel_pos=True,
                rel_pos_zero_init=True,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=self.transformer_size,
            )
            self.blocks.append(block)
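
        # SAM-style output neck: 1x1 projection then a 3x3 conv, each followed
        # by a channel-wise LayerNorm2d.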
        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_channels),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_channels),
        )

        self.fix = fix
        if self.fix:
            # Freeze the neck: keep it in eval mode and stop all gradients.
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

        if init_cfg is not None:
            # Manual 'Pretrained' loading; see the init_cfg note above.
            assert init_cfg['type'] == 'Pretrained'
            checkpoint_path = init_cfg['checkpoint']
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)
            self._is_init = True

    def init_weights(self):
        normal_init(self.level_encoding, mean=0, std=1)

    def train(self: torch.nn.Module, mode: bool = True) -> torch.nn.Module:
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        if self.fix:
            # A fixed neck ignores the requested mode and always stays in eval mode.
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self

    def forward(self, inputs: Tuple[Tensor]) -> Tensor:
        input_embeddings = []
        level_cnt = 0
        for idx, feat in enumerate(inputs):
            if idx not in self.layer_ids:
                continue
            # (B, C, H, W) -> (B, L, C) tokens, reshaped back to (B, H16, W16, C),
            # with the learnable per-level embedding added.
            feat, size = self.patch_embeds[idx](feat)
            feat = feat.unflatten(1, size)
            feat = feat + self.level_encoding.weight[level_cnt]
            input_embeddings.append(feat)
            level_cnt += 1

        # Sum the levels, add the positional embedding, and run the transformer blocks.
        feat = sum(input_embeddings)
        feat = feat + self.pos_embed
        for block in self.blocks:
            feat = block(feat)

        # Back to channel-first layout for the convolutional neck.
        feat = feat.permute(0, 3, 1, 2).contiguous()
        feat = self.neck(feat)
        return feat
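

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# Assumes a backbone that produces four feature maps at strides (4, 8, 16, 32)
# for a 1024x1024 input; the channel widths below are hypothetical placeholders,
# not values prescribed by this file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    neck = MultiLayerTransformerNeck(
        input_size=(1024, 1024),
        in_channels=[256, 512, 1024, 2048],  # hypothetical backbone widths
        embed_channels=768,
        out_channels=256,
    )
    feats = tuple(
        torch.randn(1, c, 1024 // s, 1024 // s)
        for c, s in zip([256, 512, 1024, 2048], (4, 8, 16, 32))
    )
    out = neck(feats)
    print(out.shape)  # expected: torch.Size([1, 256, 64, 64]), i.e. stride 16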