import os

import torch
import torch.distributed as dist


def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """Re-start a run from a checkpoint.

    Every keyword argument is an object exposing ``load_state_dict`` (model,
    optimizer, scheduler, ...) whose state is restored from the checkpoint
    entry of the same name. Scalars listed in ``run_variables`` (e.g. the
    epoch) are copied back into that dict in place.
    """
    # Bail out early only for local paths; URLs are handled below.
    if not ckp_path.startswith('https') and not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))

    if ckp_path.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            ckp_path, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(ckp_path, map_location='cpu')

    # Restore the state of every object passed as a keyword argument.
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            if key == "model_ema":
                value.ema.load_state_dict(checkpoint[key])
            else:
                value.load_state_dict(checkpoint[key])
        else:
            print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))

    # Re-load variables important for the run (e.g. epoch, best accuracy).
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]
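
# Example usage (a minimal sketch; ``model``, ``optimizer`` and the checkpoint
# path below are assumptions for illustration, not defined in this module):
#
#     to_restore = {"epoch": 0}
#     restart_from_checkpoint(
#         "checkpoint.pth",
#         run_variables=to_restore,
#         model=model,
#         optimizer=optimizer,
#     )
#     start_epoch = to_restore["epoch"]
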
def load_pretrained_weights(model, pretrained_weights, checkpoint_key=None,
                            prefixes=None, drop_head="head"):
    """Load ViT weights into ``model`` from a local file or a URL."""
    if pretrained_weights == '':
        return
    elif pretrained_weights.startswith('https'):
        state_dict = torch.hub.load_state_dict_from_url(
            pretrained_weights, map_location='cpu', check_hash=True)
    else:
        state_dict = torch.load(pretrained_weights, map_location='cpu')

    epoch = state_dict['epoch'] if 'epoch' in state_dict else -1
    # If no checkpoint key was given, try the common ones.
    if not checkpoint_key:
        for key in ['model', 'teacher', 'encoder']:
            if key in state_dict:
                checkpoint_key = key
    print("Load pre-trained checkpoint from: %s[%s] at epoch %d"
          % (pretrained_weights, checkpoint_key, epoch))
    if checkpoint_key:
        state_dict = state_dict[checkpoint_key]

    # Remove the `module.` (DistributedDataParallel) and `backbone.`
    # (multi-crop wrapper) prefixes, and drop the classification head.
    if prefixes is None:
        prefixes = ["module.", "backbone."]
    for prefix in prefixes:
        state_dict = {k.replace(prefix, ""): v
                      for k, v in state_dict.items() if drop_head not in k}
    checkpoint_model = state_dict

    # Interpolate the position embedding in case the patch grid size changed.
    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) of the checkpoint position embedding grid
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) of the new position embedding grid
    new_size = int(num_patches ** 0.5)
    # class token (and distillation token, if present) are kept unchanged
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
    checkpoint_model['pos_embed'] = new_pos_embed

    msg = model.load_state_dict(checkpoint_model, strict=False)
    print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
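
# Example usage (a minimal sketch; ``build_vit`` and the checkpoint path are
# assumptions for illustration, not part of this module):
#
#     vit_model = build_vit()  # any ViT exposing `patch_embed` and `pos_embed`
#     load_pretrained_weights(
#         vit_model,
#         "dino_checkpoint.pth",
#         checkpoint_key="teacher",  # e.g. DINO-style checkpoints keep weights under 'teacher'
#         drop_head="head",
#     )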