|
import collections |
|
|
|
# Checkpoint state-dict keys that are excluded before loading: load_generator
# drops them from the generator entries, and load_checkpoint drops them on its
# positional-matching path. NOTE(review): presumably parameters of an older
# model revision that the current model no longer defines — confirm.
OLD_KEYS = [
    # Generator (netG) entries.
    'netG._logvarD.bias',
    'netG._logvarD.weight',
    'netG._logvarE.bias',
    'netG._logvarE.weight',
    'netG._muD.bias',
    'netG._muD.weight',
    'netG._muE.bias',
    'netG._muE.weight',
    # Discriminator (netD) entries.
    'netD.embed.weight',
    'netD.embed.u0',
    'netD.embed.sv0',
    'netD.embed.bias',
]
|
|
|
|
|
def load_generator(model, checkpoint, ignore_keys=None):
    """Load only the generator (``netG``) entries of a checkpoint into ``model``.

    Args:
        model: Object exposing a ``netG`` sub-module with ``load_state_dict``.
        checkpoint: Either the state dict itself (a ``collections.OrderedDict``)
            or a mapping that stores the state dict under the ``"model"`` key.
        ignore_keys: Full checkpoint keys (including the ``"netG."`` prefix)
            to skip. ``None`` (the default) means the module-level ``OLD_KEYS``.

    Returns:
        ``model``, after ``model.netG.load_state_dict`` (strict) has been called
        with the ``netG.``-prefixed entries, prefix removed.
    """
    # An OrderedDict is taken as the state dict itself; any other mapping that
    # carries a "model" entry is treated as a wrapper around the state dict.
    if not isinstance(checkpoint, collections.OrderedDict) and "model" in checkpoint:
        checkpoint = checkpoint["model"]

    prefix = "netG."
    # Build the exclusion set once instead of scanning a list for every key.
    ignore = frozenset(OLD_KEYS if ignore_keys is None else ignore_keys)

    # Strip only the *leading* prefix: str.replace would also rewrite a
    # "netG." occurring deeper in the key (e.g. a nested sub-module), and
    # matching "netG." rather than "netG" keeps keys like "netG_ema.*" out.
    generator_state = {
        key[len(prefix):]: value
        for key, value in checkpoint.items()
        if key.startswith(prefix) and key not in ignore
    }

    model.netG.load_state_dict(generator_state)
    return model
|
|
|
|
|
def _merge_legacy_checkpoint(model_state, checkpoint):
    """Copy a 241-key legacy checkpoint's tensors into ``model_state`` in place."""
    for key, value in checkpoint.items():
        if key in model_state:
            model_state[key] = value
        elif "netG." + key in model_state:
            # Key was saved without the "netG." prefix of the generator.
            model_state["netG." + key] = value

    # Feat_Encoder tensors are paired purely by order of appearance in each
    # dict, not by name; zip drops any surplus on either side.
    # NOTE(review): presumably the encoder was renamed between versions — confirm.
    ckpt_keys = [k for k in checkpoint if "Feat_Encoder" in k]
    model_keys = [k for k in model_state if "Feat_Encoder" in k]
    for ckpt_key, model_key in zip(ckpt_keys, model_keys):
        model_state[model_key] = checkpoint[ckpt_key]

    return model_state


def _match_state_by_position(model_state, checkpoint, ignore):
    """Rename checkpoint tensors to the model's keys by position, requiring equal shapes."""
    filtered = [(k, v) for k, v in checkpoint.items() if k not in ignore]

    # Explicit checks (not `assert`) so validation survives `python -O`;
    # AssertionError is kept for callers that relied on the former asserts.
    if len(filtered) != len(model_state):
        raise AssertionError(
            f"checkpoint has {len(filtered)} entries after dropping ignored "
            f"keys, but the model state dict has {len(model_state)}"
        )

    matched = {
        model_key: ckpt_value
        for (_, ckpt_value), (model_key, model_value) in zip(filtered, model_state.items())
        if ckpt_value.shape == model_value.shape
    }
    if len(matched) != len(model_state):
        raise AssertionError(
            f"{len(model_state) - len(matched)} checkpoint tensor(s) do not match "
            f"the shape of the model entry at the same position"
        )
    return matched


def load_checkpoint(model, checkpoint, ignore_keys=None):
    """Load a checkpoint into ``model``, mapping its keys onto the model's.

    Two strategies are used:

    * A checkpoint with exactly 241 entries is merged by name into the model's
      current state dict (also trying a ``"netG."`` prefix), and its
      ``Feat_Encoder`` tensors are paired with the model's by order.
      NOTE(review): 241 presumably identifies one legacy checkpoint layout — confirm.
    * Any other checkpoint, after dropping ``ignore_keys``, is matched to the
      model's state dict position by position; every shape must agree.

    Args:
        model: Object with ``state_dict()`` and ``load_state_dict(..., strict=...)``.
        checkpoint: Either the state dict itself (a ``collections.OrderedDict``)
            or a mapping that stores the state dict under the ``"model"`` key.
        ignore_keys: Checkpoint keys to drop before positional matching (unused
            on the 241-entry path). ``None`` means the module-level ``OLD_KEYS``.

    Returns:
        ``model``, after ``model.load_state_dict(..., strict=False)``.

    Raises:
        AssertionError: On the positional path, if the entry counts differ or
            any checkpoint tensor's shape differs from its model counterpart.
    """
    # An OrderedDict is taken as the state dict itself; any other mapping that
    # carries a "model" entry is treated as a wrapper around the state dict.
    if not isinstance(checkpoint, collections.OrderedDict) and "model" in checkpoint:
        checkpoint = checkpoint["model"]

    model_state = model.state_dict()

    if len(checkpoint) == 241:
        state = _merge_legacy_checkpoint(model_state, checkpoint)
    else:
        ignore = frozenset(OLD_KEYS if ignore_keys is None else ignore_keys)
        state = _match_state_by_position(model_state, checkpoint, ignore)

    model.load_state_dict(state, strict=False)
    return model