File size: 1,829 Bytes
fa0f216 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import collections
# Checkpoint parameter names that are discarded before loading: load_generator
# filters them out, as does the non-default (positional) branch of
# load_checkpoint. Presumably these are leftovers from an earlier version of
# the architecture -- TODO confirm.
OLD_KEYS = [
    # generator (netG) entries
    'netG._logvarD.bias',
    'netG._logvarD.weight',
    'netG._logvarE.bias',
    'netG._logvarE.weight',
    'netG._muD.bias',
    'netG._muD.weight',
    'netG._muE.bias',
    'netG._muE.weight',
    # discriminator (netD) embedding entries
    'netD.embed.weight',
    'netD.embed.u0',
    'netD.embed.sv0',
    'netD.embed.bias',
]
def load_generator(model, checkpoint):
    """Load the generator (``netG``) weights from a checkpoint into ``model.netG``.

    Only checkpoint entries whose key starts with the ``"netG."`` prefix are
    used. The prefix is stripped, so the names are relative to ``model.netG``.
    Keys listed in ``OLD_KEYS`` are skipped.

    Args:
        model: object exposing a ``netG`` attribute with a ``load_state_dict``
            method (presumably a ``torch.nn.Module`` -- TODO confirm).
        checkpoint: either a state dict (``collections.OrderedDict``) or a
            mapping that wraps one under the ``"model"`` key. A plain ``dict``
            without a ``"model"`` key is treated as a state dict itself.

    Returns:
        ``model``, after ``model.netG.load_state_dict`` has been called with
        its default ``strict`` setting.
    """
    # Only unwrap when a "model" entry actually exists. The previous
    # type-only check raised KeyError for state dicts stored as plain dicts.
    if not isinstance(checkpoint, collections.OrderedDict) and "model" in checkpoint:
        checkpoint = checkpoint["model"]

    prefix = "netG."
    # Match and strip exactly the leading "netG." prefix. str.replace would
    # also remove "netG." occurrences deeper inside a key, and
    # startswith("netG") would wrongly accept keys such as "netG_ema.*".
    generator_state = {
        key[len(prefix):]: value
        for key, value in checkpoint.items()
        if key.startswith(prefix) and key not in OLD_KEYS
    }
    model.netG.load_state_dict(generator_state)
    return model
def load_checkpoint(model, checkpoint):
    """Load a full-model checkpoint into ``model`` (non-strictly).

    Two checkpoint layouts are handled.

    * Exactly 241 entries (the "default" layout). Each entry is copied into
      ``model``'s own state dict by name, trying the key as-is first and then
      with a ``"netG."`` prefix. Entries matching neither name are ignored.
      In addition, entries whose key contains ``"Feat_Encoder"`` are paired
      positionally, in iteration order, with the model's ``Feat_Encoder``
      entries. This lets them load even when their names differ.
    * Any other size. Keys listed in ``OLD_KEYS`` are dropped, and the
      remaining entries are mapped onto the model's state-dict entries purely
      by position. The entry counts and every value's ``.shape`` must match.

    Args:
        model: object providing ``state_dict()`` and
            ``load_state_dict(state, strict=...)`` (presumably a
            ``torch.nn.Module`` -- TODO confirm).
        checkpoint: either a state dict (``collections.OrderedDict``) or a
            mapping that wraps one under the ``"model"`` key. A plain ``dict``
            without a ``"model"`` key is treated as a state dict itself.

    Returns:
        ``model``, after ``model.load_state_dict(..., strict=False)``.

    Raises:
        ValueError: in the positional layout, if the number of remaining
            checkpoint entries differs from the model's, or if any paired
            shapes differ. These checks used ``assert`` before and were
            silently skipped under ``python -O``.
    """
    # Only unwrap when a "model" entry actually exists. The previous
    # type-only check raised KeyError for state dicts stored as plain dicts.
    if not isinstance(checkpoint, collections.OrderedDict) and "model" in checkpoint:
        checkpoint = checkpoint["model"]
    model_state = model.state_dict()

    # 241 is the entry count of the "default" checkpoint layout.
    if len(checkpoint) == 241:
        for key, value in checkpoint.items():
            if key in model_state:
                model_state[key] = value
            elif "netG." + key in model_state:
                model_state["netG." + key] = value
        # Feat_Encoder entries may be named differently in the checkpoint and
        # in the model, so pair them by order instead of by name. zip() stops
        # at the shorter list, which leaves surplus entries on either side
        # untouched.
        ckpt_feat_keys = [k for k in checkpoint if "Feat_Encoder" in k]
        model_feat_keys = [k for k in model_state if "Feat_Encoder" in k]
        for ckpt_key, model_key in zip(ckpt_feat_keys, model_feat_keys):
            model_state[model_key] = checkpoint[ckpt_key]
        state_to_load = model_state
    else:
        remaining = [(k, v) for k, v in checkpoint.items() if k not in OLD_KEYS]
        # The mapping below is purely positional, so a count mismatch means
        # entries would be silently dropped by zip(). Refuse instead.
        if len(remaining) != len(model_state):
            raise ValueError(
                f"checkpoint has {len(remaining)} entries after dropping "
                f"OLD_KEYS, but the model state dict has {len(model_state)}"
            )
        state_to_load = {}
        for (ckpt_key, value), (model_key, current) in zip(remaining, model_state.items()):
            if value.shape != current.shape:
                raise ValueError(
                    f"shape mismatch mapping checkpoint entry {ckpt_key!r} onto "
                    f"model entry {model_key!r}: {value.shape} vs {current.shape}"
                )
            state_to_load[model_key] = value

    model.load_state_dict(state_to_load, strict=False)
    return model