import os

import torch
import torch.distributed as dist


def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """
    Re-start a training run from a checkpoint: restore the state dict of every
    object passed as a keyword argument (model, optimizer, scaler, ...) and
    reload the scalar run variables (epoch, best metric, ...).
    """
    is_url = ckp_path.startswith('https')
    if not is_url and not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))
    if is_url:
        checkpoint = torch.hub.load_state_dict_from_url(
            ckp_path, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(ckp_path, map_location='cpu')

    # restore the state dict of every object passed as a keyword argument
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            if key == "model_ema":
                value.ema.load_state_dict(checkpoint[key])
            else:
                value.load_state_dict(checkpoint[key])
        else:
            print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))

    # reload variables important for the run (e.g. epoch, best metric)
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]

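# Example usage (a sketch: `output_dir`, `model`, `optimizer` and `fp16_scaler`
# are hypothetical objects from a typical training script, not defined here):
#
#   to_restore = {"epoch": 0}
#   restart_from_checkpoint(
#       os.path.join(output_dir, "checkpoint.pth"),
#       run_variables=to_restore,
#       model=model,
#       optimizer=optimizer,
#       fp16_scaler=fp16_scaler,
#   )
#   start_epoch = to_restore["epoch"]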


def load_pretrained_weights(model, pretrained_weights, checkpoint_key=None, prefixes=None, drop_head="head"):
    """Load pre-trained ViT weights, stripping wrapper prefixes and dropping the projection head."""
    if pretrained_weights == '':
        return
    elif pretrained_weights.startswith('https'):
        state_dict = torch.hub.load_state_dict_from_url(
            pretrained_weights, map_location='cpu', check_hash=True)
    else:
        state_dict = torch.load(pretrained_weights, map_location='cpu')

    epoch = state_dict.get('epoch', -1)
    # auto-detect the sub-dictionary holding the model weights if no key is given
    if not checkpoint_key:
        for key in ['model', 'teacher', 'encoder']:
            if key in state_dict:
                checkpoint_key = key
                break

    print("Load pre-trained checkpoint from: %s[%s] at epoch %d" % (pretrained_weights, checkpoint_key, epoch))

    if checkpoint_key:
        state_dict = state_dict[checkpoint_key]
    # strip the `module.` prefix added by DDP and the `backbone.` prefix added by
    # the multi-crop wrapper, and drop the projection head weights
    if prefixes is None:
        prefixes = ["module.", "backbone."]
    for prefix in prefixes:
        state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items() if drop_head not in k}
    checkpoint_model = state_dict
    
    # interpolate the position embedding if the checkpoint was trained at a
    # different input resolution than the current model
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) of the checkpoint position embedding grid
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) of the new position embedding grid
        new_size = int(num_patches ** 0.5)
        if orig_size != new_size:
            # class_token and dist_token are kept unchanged
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed
    
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
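

# Example usage (a sketch, not part of the training pipeline: the checkpoint path
# is a placeholder and `timm` is only an assumed way to build a ViT that exposes
# `patch_embed` and `pos_embed`; any compatible ViT implementation works):
#
#   import timm
#   model = timm.create_model("vit_base_patch16_224", num_classes=0)
#   load_pretrained_weights(model, "checkpoints/pretrained_vit.pth",
#                           checkpoint_key="teacher")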