import os

import torch
import torch.distributed as dist


def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """
    Restart training from a checkpoint.

    Each keyword argument is an object exposing ``load_state_dict`` (model,
    optimizer, scheduler, ...) whose state is restored from the checkpoint
    entry of the same name. ``run_variables`` is a dict whose values are
    overwritten in place with the matching checkpoint entries (e.g. the epoch).
    """
    # URLs are handled below, so only bail out early for missing local files.
    if not ckp_path.startswith('https') and not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))

    # Load the checkpoint on CPU to avoid GPU memory spikes.
    if ckp_path.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            ckp_path, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(ckp_path, map_location='cpu')

    # Restore the state of every object passed as a keyword argument.
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            if key == "model_ema":
                value.ema.load_state_dict(checkpoint[key])
            else:
                value.load_state_dict(checkpoint[key])
        else:
            print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))

    # Reload variables important for the run (e.g. epoch, best metric).
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]
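
# Hedged usage sketch (illustrative assumption, not part of the original file):
# resuming a run from a local checkpoint. The checkpoint entry names
# ("state_dict", "optimizer") and the run variable "epoch" are placeholders and
# must match whatever keys the training script actually saved.
#
#   to_restore = {"epoch": 0}
#   restart_from_checkpoint(
#       "checkpoint.pth",
#       run_variables=to_restore,
#       state_dict=model,        # any object exposing load_state_dict()
#       optimizer=optimizer,
#   )
#   start_epoch = to_restore["epoch"]
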
def load_pretrained_weights(model, pretrained_weights, checkpoint_key=None,
                            prefixes=None, drop_head="head"):
    """Load pretrained ViT weights into `model`, dropping the head and
    interpolating the position embedding to the model's patch grid."""
    if pretrained_weights == '':
        return
    elif pretrained_weights.startswith('https'):
        state_dict = torch.hub.load_state_dict_from_url(
            pretrained_weights, map_location='cpu', check_hash=True)
    else:
        state_dict = torch.load(pretrained_weights, map_location='cpu')

    epoch = state_dict['epoch'] if 'epoch' in state_dict else -1

    # If no checkpoint key is given, look for the usual sub-dictionaries.
    if not checkpoint_key:
        for key in ['model', 'teacher', 'encoder']:
            if key in state_dict:
                checkpoint_key = key
    print("Load pre-trained checkpoint from: %s[%s] at epoch %d" % (pretrained_weights, checkpoint_key, epoch))
    if checkpoint_key:
        state_dict = state_dict[checkpoint_key]

    # Remove the `module.` (DataParallel) and `backbone.` (multi-crop wrapper)
    # prefixes, and drop the classification head weights.
    if prefixes is None:
        prefixes = ["module.", "backbone."]
    for prefix in prefixes:
        state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items() if drop_head not in k}
    checkpoint_model = state_dict

    # Interpolate the position embedding to the model's number of patches.
    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) of the checkpoint position embedding grid (extra tokens excluded)
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) of the new position embedding grid
    new_size = int(num_patches ** 0.5)
    # class_token and dist_token are kept unchanged
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
    checkpoint_model['pos_embed'] = new_pos_embed

    msg = model.load_state_dict(checkpoint_model, strict=False)
    print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
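
# Hedged usage sketch (illustrative assumption, not part of the original file):
# loading DINO-style pretrained weights into a timm ViT. The timm model name,
# checkpoint path, and checkpoint_key are placeholders; any ViT exposing
# `patch_embed.num_patches` and `pos_embed` should work.
#
#   import timm
#   model = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=0)
#   load_pretrained_weights(model, "dino_checkpoint.pth", checkpoint_key="teacher")
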