Gent (PG/R - Comp Sci & Elec Eng) commited on
Commit
fe6bd89
·
1 Parent(s): bc8c24d

load weights

Browse files
Files changed (1) hide show
  1. utils.py +2 -1
utils.py CHANGED
@@ -48,7 +48,8 @@ def load_pretrained_weights(model, pretrained_weights, checkpoint_key=None, pref
48
 
49
  print("Load pre-trained checkpoint from: %s[%s] at %d epoch" % (pretrained_weights, checkpoint_key, epoch))
50
 
51
- state_dict = state_dict[checkpoint_key]
 
52
  # remove `module.` prefix
53
  if prefixes is None: prefixes= ["module.","backbone."]
54
  for prefix in prefixes:
 
48
 
49
  print("Load pre-trained checkpoint from: %s[%s] at %d epoch" % (pretrained_weights, checkpoint_key, epoch))
50
 
51
+ if checkpoint_key:
52
+ state_dict = state_dict[checkpoint_key]
53
  # remove `module.` prefix
54
  if prefixes is None: prefixes= ["module.","backbone."]
55
  for prefix in prefixes: