VRIS_vip / tools /load_pretrained_weights.py
dianecy's picture
Add files using upload-large-folder tool
91e3dad verified
raw
history blame
441 Bytes
import torch
def pre_trained_model_to_finetune(checkpoint, args):
    """Remove the ``class_embed`` entries from a pre-trained checkpoint.

    Only the classification-head parameters are dropped, because the
    fine-tuning dataset has a different number of classes; every other
    entry of the state dict is kept as is.

    Args:
        checkpoint: Mapping with a ``'model'`` entry that holds the state
            dict to be pruned.
        args: Object exposing ``dec_layers`` (an int) and ``two_stage``
            (a truthy/falsy flag).

    Returns:
        The ``checkpoint['model']`` mapping itself, modified in place, with
        ``class_embed.<i>.weight`` and ``class_embed.<i>.bias`` removed for
        every ``i`` in ``range(num_layers)``.

    Raises:
        KeyError: If ``'model'`` or any of the expected ``class_embed`` keys
            is missing from the checkpoint.
    """
    state_dict = checkpoint['model']
    # only delete the class_embed since the finetuned dataset has different num_classes
    # One extra index is removed when ``args.two_stage`` is set
    # (presumably the additional two-stage head -- confirm against the model).
    num_layers = args.dec_layers + 1 if args.two_stage else args.dec_layers
    for layer_idx in range(num_layers):
        del state_dict["class_embed.{}.weight".format(layer_idx)]
        del state_dict["class_embed.{}.bias".format(layer_idx)]
    return state_dict