thecho7's picture
LFS dat
c426e13
raw
history blame
1.05 kB
import copy
import json
# Default configuration values. load_config() fills in every key listed here
# that the user's JSON config file does not provide; nested dicts are merged
# key by key rather than replaced wholesale.
DEFAULTS = {
    "network": "dpn",
    "encoder": "dpn92",
    "model_params": {},
    "optimizer": {
        "batch_size": 32,
        "type": "SGD",  # supported: SGD, Adam
        "momentum": 0.9,
        "weight_decay": 0,
        "clip": 1.0,
        "learning_rate": 0.1,
        "classifier_lr": -1,
        "nesterov": True,
        "schedule": {
            "type": "constant",  # supported: constant, step, multistep, exponential, linear, poly
            "mode": "epoch",  # supported: epoch, step
            "epochs": 10,
            "params": {},
        },
    },
    "normalize": {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
    },
}
def _merge(src, dst):
for k, v in src.items():
if k in dst:
if isinstance(v, dict):
_merge(src[k], dst[k])
else:
dst[k] = v
def load_config(config_file, defaults=DEFAULTS):
    """Load a JSON config file and fill in missing settings from ``defaults``.

    Args:
        config_file: Path to a JSON file. Its top level is expected to be an
            object, since keys from ``defaults`` are inserted into it.
        defaults: Nested dict of fallback values. Keys absent from the file
            are filled in from here (nested dicts merged key by key); values
            present in the file take precedence. Not modified by this call.

    Returns:
        The parsed config dict with defaults merged in.
    """
    # JSON text is UTF-8 by specification; pass the encoding explicitly so
    # decoding does not depend on the platform's locale default.
    with open(config_file, "r", encoding="utf-8") as fd:
        config = json.load(fd)
    _merge(defaults, config)
    return config