# Evgeny Zhukov
# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
# 2ba4412
import argparse
import copy
import json
import os
import re

import yaml

import utils.logging as logging
# Module-level logger obtained from the project's utils.logging helper
# (imported above as ``logging``); used by Config to report which file is loaded.
logger = logging.get_logger(__name__)
# Matches strings written in scientific notation such as "1e-4", "5E+3" or
# "2.5e-6".  PyYAML's SafeLoader (YAML 1.1 rules) does not resolve e.g. "1e-4"
# to a float, so such values reach Config._update_dict as str.
_SCI_NOTATION_RE = re.compile(r"[+-]?(?:[0-9]+\.?[0-9]*|\.[0-9]+)[eE][+-]?[0-9]+")


class Config(object):
    """Hierarchical configuration built from YAML files and CLI overrides.

    The root instance (``load=True``) parses the command line, reads the YAML
    file given by ``--cfg`` (following ``_BASE`` / ``_BASE_RUN`` /
    ``_BASE_MODEL`` references to parent files), applies the trailing
    ``KEY VALUE`` overrides and finally copies the parsed CLI arguments into
    the dict.  Every key becomes an attribute; nested dicts become nested
    ``Config`` nodes (``load=False``), so values are read as
    ``cfg.SECTION.key``.
    """

    def __init__(self, load=True, cfg_dict=None, cfg_level=None):
        """Build a config node.

        Args:
            load: when True, parse CLI arguments and load the YAML file named
                by ``--cfg``; the ``cfg_dict`` argument is then ignored.
            cfg_dict: mapping to wrap when ``load`` is False (used for nested
                sections).  ``None`` is treated as an empty mapping.
            cfg_level: name of this node, recorded in ``self._level``.
        """
        self._level = "cfg" + ("." + cfg_level if cfg_level is not None else "")
        if load:
            self.args = self._parse_args()
            logger.info("Loading config from {}.".format(self.args.cfg_file))
            self.need_initialization = True
            # One parse is enough: merging the file with a second copy of
            # itself (and re-running the overrides) changes nothing.
            cfg_dict = self._load_yaml(self.args)
            cfg_dict = self._update_from_args(cfg_dict)
        elif cfg_dict is None:
            cfg_dict = {}
        # Stored on every node, not only the root, so dump()/repr() also work
        # on nested sections.
        self.cfg_dict = cfg_dict
        self._update_dict(cfg_dict)

    def _parse_args(self):
        """Parse command-line options.

        Returns:
            argparse.Namespace with ``cfg_file`` (from ``--cfg``),
            ``init_method``, ``debug`` and ``opts`` -- every remaining token,
            later consumed pairwise as ``KEY VALUE`` overrides.
        """
        parser = argparse.ArgumentParser(
            description="Argparser for configuring [code base name to think of] codebase"
        )
        parser.add_argument(
            "--cfg",
            dest="cfg_file",
            help="Path to the configuration file",
            default='configs/UniAnimate_infer.yaml',
        )
        parser.add_argument(
            "--init_method",
            help="Initialization method, includes TCP or shared file-system",
            default="tcp://localhost:9999",
            type=str,
        )
        parser.add_argument(
            '--debug',
            action='store_true',
            default=False,
            help='Into debug information',
        )
        parser.add_argument(
            "opts",
            help="other configurations",
            default=None,
            nargs=argparse.REMAINDER,
        )
        return parser.parse_args()

    def _path_join(self, path_list):
        """Join path components with '/' (no normalisation); [] gives ""."""
        return '/'.join(path_list)

    def _update_from_args(self, cfg_dict):
        """Copy every parsed CLI argument into ``cfg_dict`` (same-named keys are
        overwritten) and return it."""
        cfg_dict.update(vars(self.args))
        return cfg_dict

    @staticmethod
    def _read_yaml(path):
        """Parse one YAML file with PyYAML's SafeLoader; an empty file gives {}."""
        with open(path, 'r') as f:
            return yaml.load(f.read(), Loader=yaml.SafeLoader) or {}

    def _initialize_cfg(self):
        """Load the default ``configs/base.yaml`` once.

        Looks for ``./configs/base.yaml`` first, then for ``configs/base.yaml``
        inside the parent of the directory that holds this source file.  Not
        called by ``__init__``.

        Returns:
            The parsed dict on the first call, ``None`` on later calls.
        """
        if not self.need_initialization:
            return None
        self.need_initialization = False
        base_path = "./configs/base.yaml"
        if not os.path.exists(base_path):
            # Absolute path, so the fallback does not depend on the CWD.
            repo_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
            base_path = os.path.join(repo_root, "configs", "base.yaml")
        return self._read_yaml(base_path)

    def _resolve_base_path(self, args, file_name, base_ref):
        """Turn a ``_BASE*`` reference into a path that can be opened.

        * Second character '.' (e.g. "../base.yaml"): each ".." climbs one
          directory from the folder of ``file_name``, the file holding the
          reference.
        * Leading "./": resolved against the folder of the top-level
          ``args.cfg_file``.
        * Anything else is returned unchanged (CWD-relative or absolute).
        """
        if base_ref[1:2] == '.':
            up = base_ref.count('..')
            return self._path_join(file_name.split('/')[:(-1 - up)] + base_ref.split('/')[up:])
        if base_ref.startswith("./"):
            cfg_name = args.cfg_file.split('/')[-1]
            # Strip only the trailing file name, keeping the trailing '/'.
            cfg_dir = args.cfg_file[:len(args.cfg_file) - len(cfg_name)]
            return cfg_dir + base_ref[2:]
        return base_ref

    def _load_yaml(self, args, file_name=""):
        """Load a YAML config, recursively resolving its base files.

        Args:
            args: parsed CLI namespace; ``args.cfg_file`` is the top-level
                config and ``args.opts`` the ``KEY VALUE`` overrides.  When
                ``file_name`` is empty, ``args.cfg_file`` may be rewritten in
                place (a leading component equal to the CWD's name becomes
                "./").
            file_name: path of a base file to read; "" means ``args.cfg_file``.

        A file without ``_BASE``, ``_BASE_RUN`` or ``_BASE_MODEL`` only gets the
        lenient override pass (``_merge_cfg_from_command_update``).  Otherwise
        ``_BASE`` is loaded and this file merged on top of it; if there is no
        ``_BASE``, ``_BASE_RUN`` and then ``_BASE_MODEL`` are applied in turn.
        The strict override pass (``_merge_cfg_from_command``) runs last.

        Returns:
            The merged config dict.
        """
        assert args.cfg_file is not None
        if file_name != "":
            # reading from a base file referenced by another config
            cfg = self._read_yaml(file_name)
        else:
            cwd_name = os.getcwd().split("/")[-1]
            if cwd_name and cwd_name == args.cfg_file.split("/")[0]:
                # Rewrite only the leading component: replacing every
                # occurrence would also hit e.g. the file name in
                # "UniAnimate/configs/UniAnimate_infer.yaml".
                args.cfg_file = args.cfg_file.replace(cwd_name, "./", 1)
            cfg = self._read_yaml(args.cfg_file)
            file_name = args.cfg_file

        if not any(k in cfg for k in ("_BASE_RUN", "_BASE_MODEL", "_BASE")):
            # No parent file: this is the bottom of a base chain.
            return self._merge_cfg_from_command_update(args, cfg)

        if "_BASE" in cfg:
            cfg_base = self._load_yaml(args, self._resolve_base_path(args, file_name, cfg["_BASE"]))
            cfg = self._merge_cfg_from_base(cfg_base, cfg)
        else:
            if "_BASE_RUN" in cfg:
                cfg_base = self._load_yaml(args, self._resolve_base_path(args, file_name, cfg["_BASE_RUN"]))
                # preserve_base keeps _BASE_MODEL so it is still seen below.
                cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True)
            if "_BASE_MODEL" in cfg:
                cfg_base = self._load_yaml(args, self._resolve_base_path(args, file_name, cfg["_BASE_MODEL"]))
                cfg = self._merge_cfg_from_base(cfg_base, cfg)
        return self._merge_cfg_from_command(args, cfg)

    def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False):
        """Recursively merge ``cfg_new`` into ``cfg_base`` (mutated and returned).

        Values from ``cfg_new`` win.  When both sides hold a dict the merge
        recurses (nested calls use ``preserve_base=False``).  Keys missing from
        ``cfg_base`` are added, except base-file references (string keys
        starting with "_BASE"), which are dropped unless ``preserve_base``.
        """
        for k, v in cfg_new.items():
            if k in cfg_base:
                if isinstance(v, dict) and isinstance(cfg_base[k], dict):
                    self._merge_cfg_from_base(cfg_base[k], v)
                else:
                    # scalars, lists, or a dict replacing a non-dict base value
                    cfg_base[k] = v
            elif preserve_base or not (isinstance(k, str) and k.startswith("_BASE")):
                cfg_base[k] = v
        return cfg_base

    def _merge_cfg_from_command_update(self, args, cfg):
        """Apply ``KEY VALUE`` overrides from ``args.opts`` without validation.

        Used for configs that have no base file; unknown keys are added.  A
        dotted key ("SECTION.key") whose full path already exists in ``cfg``
        updates that nested entry; any other key is stored verbatim at the top
        level.  Values stay as the raw command-line strings.  ``cfg`` is
        mutated and returned.
        """
        if not args.opts:
            return cfg
        assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
            args.opts, len(args.opts)
        )
        for key, val in zip(args.opts[0::2], args.opts[1::2]):
            *parents, leaf = key.split('.')
            node = cfg
            for part in parents:
                node = node.get(part) if isinstance(node, dict) else None
            if parents and isinstance(node, dict) and leaf in node:
                # A flat "SECTION.key" entry would be unreachable as
                # cfg.SECTION.key, so update the nested value instead.
                node[leaf] = val
            else:
                cfg[key] = val
        return cfg

    def _merge_cfg_from_command(self, args, cfg):
        """Apply ``KEY VALUE`` overrides from ``args.opts`` with validation.

        Keys may be dotted with at most four components ("A.B.C.D").  Every
        component must already exist in ``cfg`` and every intermediate value
        must be a dict, otherwise AssertionError is raised.  Values stay as the
        raw command-line strings.  ``cfg`` is mutated and returned.
        """
        if not args.opts:
            return cfg
        assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format(
            args.opts, len(args.opts)
        )
        for key, val in zip(args.opts[0::2], args.opts[1::2]):
            key_split = key.split('.')
            # maximum supported depth 3 (i.e. up to four components)
            assert len(key_split) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format(
                len(key_split)
            )
            *parents, leaf = key_split
            node = cfg
            for depth, part in enumerate(key_split):
                assert isinstance(node, dict) and part in node, 'Non-existant key: {}.'.format(
                    key_split[0] if depth == 0 else key
                )
                if depth < len(parents):
                    node = node[part]
            node[leaf] = val
        return cfg

    def _update_dict(self, cfg_dict):
        """Expose every entry of ``cfg_dict`` as an attribute of this object.

        Nested dicts become child ``Config`` nodes (``load=False``).  Strings
        that fully match scientific notation (``_SCI_NOTATION_RE``) are
        converted to float; other values are stored unchanged.  ``cfg_dict``
        itself is not modified.
        """
        def recur(key, elem):
            if isinstance(elem, dict):
                return key, Config(load=False, cfg_dict=elem, cfg_level=key)
            # Require a full-string match: checking only characters 1-2 for
            # "e-" also fires on words such as "re-run" and crashes float().
            if isinstance(elem, str) and _SCI_NOTATION_RE.fullmatch(elem):
                elem = float(elem)
            return key, elem

        self.__dict__.update(dict(recur(k, v) for k, v in cfg_dict.items()))

    def get_args(self):
        """Return the parsed CLI namespace (only set on a node built with load=True)."""
        return self.args

    def __repr__(self):
        return "{}\n".format(self.dump())

    def dump(self):
        """Return this node's raw config dict as indented JSON."""
        return json.dumps(self.cfg_dict, indent=2)

    def deep_copy(self):
        """Return an independent deep copy of this config node."""
        return copy.deepcopy(self)
if __name__ == '__main__':
    # Debug entry point: build the root config from the command line
    # (--cfg, --init_method, --debug and trailing KEY VALUE overrides) and
    # print its DATA section.
    cfg = Config(load=True)
    data_section = cfg.DATA
    print(data_section)