File size: 2,297 Bytes
75f0b7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from collections import OrderedDict

from spiga.data.loaders.dl_config import DatabaseStruct

# Direct-download links for the pretrained SPIGA checkpoints, keyed by dataset
# name. Every checkpoint lives on Google Drive, so only the file id differs.
# "cofw68" intentionally shares the "300wprivate" file id.
_GDRIVE_DOWNLOAD_PREFIX = "https://drive.google.com/uc?export=download&confirm=yes&id="
_GDRIVE_FILE_IDS = (
    ("wflw", "1h0qA5ysKorpeDNRXe9oYkVcVe8UYyzP7"),
    ("300wpublic", "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC"),
    ("300wprivate", "1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM"),
    ("merlrav", "1GKS1x0tpsTVivPZUk_yrSiMhwEAcAkg6"),
    ("cofw68", "1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM"),
)
MODELS_URL = {
    dataset: _GDRIVE_DOWNLOAD_PREFIX + file_id
    for dataset, file_id in _GDRIVE_FILE_IDS
}


class ModelConfig(object):
    """Configuration of a SPIGA model: weights, pretreatment and output sizes.

    Args:
        dataset_name: Optional dataset key (e.g. ``"wflw"``). When given, the
            dataset-dependent fields are filled in by ``update_with_dataset``.
        load_model_url: When True, ``update_with_dataset`` also sets
            ``model_weights_url`` from ``MODELS_URL``.
    """

    def __init__(self, dataset_name=None, load_model_url=True):
        # Model configuration
        self.model_weights = None
        self.model_weights_path = "./"
        self.load_model_url = load_model_url
        self.model_weights_url = None
        # Pretreatment
        self.focal_ratio = 1.5  # Camera matrix focal length ratio.
        self.target_dist = 1.6  # Target distance zoom in/out around face.
        self.image_size = (256, 256)
        # Outputs
        self.ftmap_size = (64, 64)
        # Dataset
        self.dataset = None

        if dataset_name is not None:
            self.update_with_dataset(dataset_name)

    def update_with_dataset(self, dataset_name):
        """Fill in the dataset-dependent fields for ``dataset_name``.

        Sets ``dataset`` to ``DatabaseStruct(dataset_name)`` and
        ``model_weights`` to ``"spiga_<dataset_name>.pt"``. If
        ``load_model_url`` is True, it also sets ``model_weights_url`` to
        ``MODELS_URL[dataset_name]``.

        Raises:
            KeyError: if ``load_model_url`` is True and ``dataset_name`` has
                no entry in ``MODELS_URL``.
        """
        config_dict = {
            "dataset": DatabaseStruct(dataset_name),
            "model_weights": "spiga_%s.pt" % dataset_name,
        }

        if dataset_name == "cofw68":  # Test only
            # Reuses the 300W-private checkpoint; MODELS_URL maps "cofw68"
            # to the same download file.
            config_dict["model_weights"] = "spiga_300wprivate.pt"

        if self.load_model_url:
            try:
                config_dict["model_weights_url"] = MODELS_URL[dataset_name]
            except KeyError:
                # Same exception type as before, but with an actionable message.
                raise KeyError(
                    "No pretrained weights URL for dataset %r; available: %s"
                    % (dataset_name, ", ".join(MODELS_URL))
                ) from None

        self.update(config_dict)

    def update(self, params_dict):
        """Set configuration attributes from the mapping ``params_dict``.

        A key is accepted if it is already an instance attribute, or if the
        object exposes it as a non-callable attribute (e.g. a property or a
        plain class attribute). Names of methods are rejected, so a config
        entry cannot shadow ``update``, ``state_dict`` and so on.

        Every key is checked before any attribute is assigned. A rejected
        ``params_dict`` therefore leaves the config unchanged.

        Raises:
            Warning: for the first unknown key, with the message
                ``"Unknown option: <key>: <value>"``.
        """
        instance_attrs = vars(self)
        for k, v in params_dict.items():
            if k in instance_attrs:
                continue
            # Class-level callables are methods (or dunders like __class__);
            # overwriting them would break the object.
            if hasattr(self, k) and not callable(getattr(type(self), k, None)):
                continue
            raise Warning("Unknown option: {}: {}".format(k, v))

        for k, v in params_dict.items():
            setattr(self, k, v)

    def state_dict(self):
        """Return the public instance attributes as an ``OrderedDict``.

        Attributes whose name starts with ``"_"`` are skipped. Entries appear
        in the order of the instance ``__dict__``.
        """
        state_dict = OrderedDict()
        for k in vars(self):
            if not k.startswith("_"):
                state_dict[k] = getattr(self, k)
        return state_dict