# VRIS_vip / RIS-DMMI / lib/segmentation.py
# Uploaded by dianecy with huggingface_hub ("Upload folder using huggingface_hub"),
# commit 0b32e3c (verified). Original file size: 3.06 kB.
from .mask_predictor import Decoder
from .backbone import MultiModalSwinTransformer
from .backbone_resnet import MultiModalResNet
from ._utils import DMMI
from ._utils_mracl import DMMI_MRACL
# Model factory functions exported by ``from <this module> import *``.
__all__ = [
    'dmmi_swin',
    'dmmi_res',
    'dmmi_swin_hardpos_only',
]
# DMMI based on swin-transformer
# Swin-Transformer variant name -> (embed_dim, depths, num_heads) passed to
# MultiModalSwinTransformer. Stored as tuples so each call gets fresh lists.
_SWIN_CONFIGS = {
    'tiny': (96, (2, 2, 6, 2), (3, 6, 12, 24)),
    'small': (96, (2, 2, 18, 2), (3, 6, 12, 24)),
    'base': (128, (2, 2, 18, 2), (4, 8, 16, 32)),
    'large': (192, (2, 2, 18, 2), (6, 12, 24, 48)),
}


def _swin_arch(swin_type):
    """Return ``(embed_dim, depths, num_heads)`` for a Swin variant name.

    Raises:
        ValueError: if ``swin_type`` is not a key of ``_SWIN_CONFIGS``.
    """
    # An explicit exception instead of ``assert False``: asserts are stripped
    # under ``python -O``, after which the caller failed later with a NameError.
    try:
        embed_dim, depths, num_heads = _SWIN_CONFIGS[swin_type]
    except KeyError:
        raise ValueError(
            f"Unsupported swin_type {swin_type!r}; "
            f"expected one of {sorted(_SWIN_CONFIGS)}"
        ) from None
    return embed_dim, list(depths), list(num_heads)


def _parse_mha(spec):
    """Parse a dash-separated fusion-head spec (e.g. ``'1-1-1-1'``) into ints.

    A falsy spec (empty string or None) falls back to ``[1, 1, 1, 1]``.
    """
    if not spec:
        return [1, 1, 1, 1]
    return [int(part) for part in spec.split('-')]


def _segm_dmmi_swin(pretrained, args):
    """Build a DMMI segmentation model on a multi-modal Swin Transformer backbone.

    Args:
        pretrained: checkpoint location forwarded to ``backbone.init_weights``;
            a falsy value means random initialization. If it contains the
            substring ``'window12'``, a window size of 12 is used.
        args: namespace read for ``swin_type``, ``window12``, ``mha``,
            ``fusion_drop`` and ``model``.

    Returns:
        A ``DMMI_MRACL`` model when ``args.model`` contains ``'hardpos_only'``,
        otherwise a ``DMMI`` model, each wrapping the backbone and a ``Decoder``.

    Raises:
        ValueError: if ``args.swin_type`` is not 'tiny', 'small', 'base' or 'large'.
    """
    embed_dim, depths, num_heads = _swin_arch(args.swin_type)

    # args.window12 added for test.py because state_dict is loaded after model
    # initialization. ``pretrained or ''`` guards against pretrained=None.
    if 'window12' in (pretrained or '') or args.window12:
        print('Window size 12!')
        window_size = 12
    else:
        window_size = 7

    mha = _parse_mha(args.mha)

    backbone = MultiModalSwinTransformer(
        embed_dim=embed_dim, depths=depths, num_heads=num_heads,
        window_size=window_size,
        ape=False, drop_path_rate=0.3, patch_norm=True,
        out_indices=(0, 1, 2, 3),
        use_checkpoint=False, num_heads_fusion=mha,
        fusion_drop=args.fusion_drop,
    )
    if pretrained:
        print('Initializing Multi-modal Swin Transformer weights from ' + pretrained)
        backbone.init_weights(pretrained=pretrained)
    else:
        print('Randomly initialize Multi-modal Swin Transformer weights.')
        backbone.init_weights()

    # Decoder width 8 * embed_dim: presumably the channel count of the last
    # backbone stage (embed_dim doubled three times) -- confirm against Decoder.
    head_cls = DMMI_MRACL if 'hardpos_only' in args.model else DMMI
    return head_cls(backbone, Decoder(8 * embed_dim))
def _load_model_dmmi_swin(pretrained, args):
    """Construct the Swin-based DMMI model by delegating to ``_segm_dmmi_swin``."""
    return _segm_dmmi_swin(pretrained, args)
def dmmi_swin(pretrained='', args=None):
    """Public factory for the Swin-Transformer DMMI model.

    ``pretrained`` (default ``''``) and ``args`` are passed through unchanged
    to ``_load_model_dmmi_swin``.
    """
    built = _load_model_dmmi_swin(pretrained=pretrained, args=args)
    return built
def dmmi_swin_hardpos_only(pretrained='', args=None):
    """Public factory for the 'hardpos_only' Swin-Transformer DMMI variant.

    This follows the same construction path as ``dmmi_swin``. The
    ``DMMI_MRACL`` head is chosen inside the Swin builder only when
    ``args.model`` contains ``'hardpos_only'``, not by which factory is called.
    """
    built = _load_model_dmmi_swin(pretrained=pretrained, args=args)
    return built
#############################################
# DMMI based on resnet
def _segm_dmmi_res(pretrained, args):
    """Build a DMMI model on a multi-modal ResNet backbone.

    ``pretrained`` is handed straight to ``MultiModalResNet``. ``args`` is
    accepted for signature parity with the Swin builder but is not used.
    """
    # 2048 is presumably the channel width of the backbone's final stage --
    # confirm against MultiModalResNet. Backbone is constructed before Decoder.
    return DMMI(MultiModalResNet(pretrained), Decoder(2048))
def _load_model_dmmi_res(pretrained, args):
    """Construct the ResNet-based DMMI model by delegating to ``_segm_dmmi_res``."""
    return _segm_dmmi_res(pretrained, args)
def dmmi_res(pretrained='', args=None):
    """Public factory for the ResNet DMMI model.

    ``pretrained`` (default ``''``) and ``args`` are passed through unchanged
    to ``_load_model_dmmi_res``.
    """
    built = _load_model_dmmi_res(pretrained=pretrained, args=args)
    return built