|
from .mask_predictor import Decoder |
|
from .backbone import MultiModalSwinTransformer |
|
from .backbone_resnet import MultiModalResNet |
|
from ._utils import DMMI |
|
from ._utils_mracl import DMMI_MRACL |
|
|
|
__all__ = ['dmmi_swin', 'dmmi_res', 'dmmi_swin_hardpos_only'] |
|
|
|
|
|
def _segm_dmmi_swin(pretrained, args): |
|
|
|
if args.swin_type == 'tiny': |
|
embed_dim = 96 |
|
depths = [2, 2, 6, 2] |
|
num_heads = [3, 6, 12, 24] |
|
elif args.swin_type == 'small': |
|
embed_dim = 96 |
|
depths = [2, 2, 18, 2] |
|
num_heads = [3, 6, 12, 24] |
|
elif args.swin_type == 'base': |
|
embed_dim = 128 |
|
depths = [2, 2, 18, 2] |
|
num_heads = [4, 8, 16, 32] |
|
elif args.swin_type == 'large': |
|
embed_dim = 192 |
|
depths = [2, 2, 18, 2] |
|
num_heads = [6, 12, 24, 48] |
|
else: |
|
assert False |
|
|
|
if 'window12' in pretrained or args.window12: |
|
print('Window size 12!') |
|
window_size = 12 |
|
else: |
|
window_size = 7 |
|
|
|
if args.mha: |
|
mha = args.mha.split('-') |
|
mha = [int(a) for a in mha] |
|
else: |
|
mha = [1, 1, 1, 1] |
|
|
|
out_indices = (0, 1, 2, 3) |
|
backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads, |
|
window_size=window_size, |
|
ape=False, drop_path_rate=0.3, patch_norm=True, |
|
out_indices=out_indices, |
|
use_checkpoint=False, num_heads_fusion=mha, |
|
fusion_drop=args.fusion_drop |
|
) |
|
if pretrained: |
|
print('Initializing Multi-modal Swin Transformer weights from ' + pretrained) |
|
backbone.init_weights(pretrained=pretrained) |
|
else: |
|
print('Randomly initialize Multi-modal Swin Transformer weights.') |
|
backbone.init_weights() |
|
|
|
if 'hardpos_only' in args.model : |
|
model = DMMI_MRACL(backbone, Decoder(8*embed_dim)) |
|
else : |
|
model = DMMI(backbone, Decoder(8*embed_dim)) |
|
|
|
return model |
|
|
|
def _load_model_dmmi_swin(pretrained, args): |
|
model = _segm_dmmi_swin(pretrained, args) |
|
return model |
|
|
|
def dmmi_swin(pretrained='', args=None): |
|
return _load_model_dmmi_swin(pretrained, args) |
|
|
|
def dmmi_swin_hardpos_only(pretrained='', args=None): |
|
return _load_model_dmmi_swin(pretrained, args) |
|
|
|
|
|
|
|
|
|
def _segm_dmmi_res(pretrained, args): |
|
backbone = MultiModalResNet(pretrained) |
|
|
|
model = DMMI(backbone, Decoder(2048)) |
|
return model |
|
|
|
|
|
def _load_model_dmmi_res(pretrained, args): |
|
model = _segm_dmmi_res(pretrained, args) |
|
return model |
|
|
|
|
|
def dmmi_res(pretrained='', args=None): |
|
return _load_model_dmmi_res(pretrained, args) |