Upload directory
Browse files
aligners/differentiable_face_aligner/dfa/utils/model_utils.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
def remove_prefix(state_dict, prefix):
    """Return a copy of *state_dict* with *prefix* stripped from key names.

    Old-style models are saved with every parameter name sharing the common
    prefix 'module.' (added by DataParallel wrapping); this removes it so the
    keys match a bare model's state_dict.
    """
    print('remove prefix \'{}\''.format(prefix))
    stripped = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped
|
8 |
+
|
9 |
+
def check_keys(model, pretrained_state_dict):
    """Report key overlap between a checkpoint and *model*'s state_dict.

    Prints counts of missing, unused, and shared parameter names, and
    asserts that at least one checkpoint key matches the model.
    Returns True when the overlap is non-empty.
    """
    checkpoint_keys = set(pretrained_state_dict)
    expected_keys = set(model.state_dict())
    shared_keys = expected_keys & checkpoint_keys
    surplus_keys = checkpoint_keys - expected_keys
    absent_keys = expected_keys - checkpoint_keys
    print('Missing keys:{}'.format(len(absent_keys)))
    print('Unused checkpoint keys:{}'.format(len(surplus_keys)))
    print('Used keys:{}'.format(len(shared_keys)))
    assert len(shared_keys) > 0, 'load NONE from pretrained checkpoint'
    return True
|
20 |
+
|
21 |
+
def load_model(model, pretrained_path, load_to_cpu):
    """Load pretrained weights from *pretrained_path* into *model*.

    Args:
        model: the torch.nn.Module to populate.
        pretrained_path: filesystem path of the checkpoint file.
        load_to_cpu: if truthy, map all tensors to CPU; otherwise map them
            onto the current CUDA device.

    Returns:
        The same *model* instance with weights loaded (strict=False, so
        non-matching keys are silently skipped by load_state_dict).

    Note:
        torch.load unpickles arbitrary Python objects — only load
        checkpoints from trusted sources.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    # Some checkpoints wrap the weights under a 'state_dict' key; unwrap first,
    # then strip the DataParallel 'module.' prefix in a single place.
    if "state_dict" in pretrained_dict:
        pretrained_dict = pretrained_dict['state_dict']
    pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model
|
35 |
+
|
36 |
+
|