Spaces:
Runtime error
Runtime error
| # ========================================================== | |
| # Modified from mmcv | |
| # ========================================================== | |
| import ast | |
| import os | |
| import os.path as osp | |
| import shutil | |
| import sys | |
| import tempfile | |
| from argparse import Action | |
| from importlib import import_module | |
| from addict import Dict | |
| from yapf.yapflib.yapf_api import FormatCode | |
# Config key whose value names the base config file(s) (a path or a list of
# paths, relative to the child config's directory) to inherit from.
BASE_KEY = "_base_"
# Marker inside a child dict: when truthy, the child dict replaces the
# corresponding base dict instead of being merged into it.
DELETE_KEY = "_delete_"
# Top-level config keys that are rejected because SLConfig exposes attributes
# or methods with these names.
RESERVED_KEYS = [
    "filename",
    "text",
    "pretty_text",
    "get",
    "dump",
    "merge_from_dict",
]
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Raise ``FileNotFoundError`` unless ``filename`` is an existing regular file.

    Args:
        filename: Path to check.
        msg_tmpl: ``str.format`` template for the error message; its single
            ``{}`` placeholder receives ``filename``.

    Raises:
        FileNotFoundError: If ``osp.isfile(filename)`` is false.
    """
    if osp.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))
class ConfigDict(Dict):
    """addict ``Dict`` whose missing keys raise instead of being auto-created.

    Item access on a missing key raises ``KeyError``; attribute access on a
    missing key raises ``AttributeError`` with the usual Python message.
    """

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            return super(ConfigDict, self).__getattr__(name)
        except KeyError:
            # Translate the lookup failure into the exception type that
            # attribute access is expected to raise.
            missing = AttributeError(
                f"'{self.__class__.__name__}' object has no " f"attribute '{name}'"
            )
        raise missing
class SLConfig(object):
    """Attribute-accessible configuration built from a dict or a config file.

    ``.py`` configs are imported as a module; ``.yml``/``.yaml``/``.json``
    configs are loaded with ``slio.slload``. A config may inherit from one or
    more files listed under ``_base_``: keys in the child override the bases,
    and a child dict containing ``_delete_=True`` replaces the base dict
    instead of being merged into it.

    ref: mmcv.utils.config

    Example:
        >>> cfg = SLConfig(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = SLConfig.fromfile('tests/data/config/a.py')
        >>> cfg.item4
        'test'
    """

    @staticmethod
    def _validate_py_syntax(filename):
        """Raise ``SyntaxError`` if the Python source at ``filename`` does not parse."""
        with open(filename) as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError as err:
            raise SyntaxError("There are syntax errors in config " f"file {filename}") from err

    @staticmethod
    def _file2dict(filename):
        """Load a config file into a dict, resolving ``_base_`` inheritance.

        Args:
            filename (str): Path to a ``.py``, ``.yml``, ``.yaml`` or ``.json``
                file. ``~`` is expanded and the path is made absolute.

        Returns:
            tuple[dict, str]: The merged config dict, and the text of every file
            involved (bases first), each preceded by a line holding its path.

        Raises:
            FileNotFoundError: If ``filename`` or one of its bases is missing.
            SyntaxError: If a ``.py`` config does not parse.
            KeyError: If two base files define the same top-level key.
            IOError: For unsupported file extensions.
        """
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if filename.lower().endswith(".py"):
            # Fail fast on syntax errors before touching sys.path / sys.modules.
            SLConfig._validate_py_syntax(filename)
            with tempfile.TemporaryDirectory() as temp_config_dir:
                # Import a copy under a random temp-file module name rather
                # than the config's own name.
                temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py")
                temp_config_name = osp.basename(temp_config_file.name)
                if os.name == "nt":
                    # Presumably because Windows cannot reopen a still-open
                    # NamedTemporaryFile, which copyfile below needs to do.
                    temp_config_file.close()
                shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                try:
                    mod = import_module(temp_module_name)
                    cfg_dict = {
                        name: value
                        for name, value in mod.__dict__.items()
                        if not name.startswith("__")
                    }
                finally:
                    # Undo the global side effects even when the import fails.
                    if temp_config_dir in sys.path:
                        sys.path.remove(temp_config_dir)
                    sys.modules.pop(temp_module_name, None)
                    temp_config_file.close()
        elif filename.lower().endswith((".yml", ".yaml", ".json")):
            from .slio import slload

            cfg_dict = slload(filename)
        else:
            raise IOError("Only py/yml/yaml/json type are supported now!")

        cfg_text = filename + "\n"
        with open(filename, "r") as f:
            cfg_text += f.read()

        # Resolve base config(s) listed under BASE_KEY, then merge this file on top.
        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError("Duplicate key is not allowed among bases")
                    # TODO Allow the duplicate key while warnning user
                base_cfg_dict.update(c)

            base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # The child's text goes last, after all of its bases.
            cfg_text_list.append(cfg_text)
            cfg_text = "\n".join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        """Recursively merge ``a`` into a copy of ``b`` and return the copy.

        Values in ``a`` overwrite those in ``b``. ``b`` is shallow-copied at
        every level that is merged, so it is not modified in place. When ``b``
        is a list, keys of ``a`` are interpreted as integer indices.

        Note: for a dict value in ``a`` whose key also exists in ``b``, the
        ``_delete_`` entry is popped from that dict, so ``a`` can be mutated.
        If it was truthy, the value replaces ``b[k]`` instead of merging.

        Args:
            a: The overriding value. Non-dict values are returned as-is.
            b (dict | list): The base container.

        Returns:
            The merged container (or ``a`` itself when ``a`` is not a dict).

        Raises:
            TypeError: If a child dict would merge into a non-dict, non-list base
                value, or a key cannot be used as a list index.
        """
        if not isinstance(a, dict):
            return a

        b = b.copy()
        for k, v in a.items():
            if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
                if not isinstance(b[k], dict) and not isinstance(b[k], list):
                    raise TypeError(
                        f"{k}={v} in child config cannot inherit from base "
                        f"because {k} is a dict in the child config but is of "
                        f"type {type(b[k])} in base config. You may set "
                        f"`{DELETE_KEY}=True` to ignore the base config"
                    )
                b[k] = SLConfig._merge_a_into_b(v, b[k])
            elif isinstance(b, list):
                try:
                    index = int(k)
                except (TypeError, ValueError) as err:
                    raise TypeError(
                        f"b is a list, " f"index {k} should be an int when input but {type(k)}"
                    ) from err
                b[index] = SLConfig._merge_a_into_b(v, b[index])
            else:
                b[k] = v

        return b

    @staticmethod
    def fromfile(filename):
        """Build an ``SLConfig`` from a config file (see ``_file2dict``)."""
        cfg_dict, cfg_text = SLConfig._file2dict(filename)
        return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f"{key} is reserved for config file")

        # Bypass our own __setattr__, which would store into _cfg_dict.
        super(SLConfig, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
        super(SLConfig, self).__setattr__("_filename", filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, "r") as f:
                text = f.read()
        else:
            text = ""
        super(SLConfig, self).__setattr__("_text", text)

    @property
    def filename(self):
        return self._filename

    @property
    def text(self):
        return self._text

    @property
    def pretty_text(self):
        """The config rendered as Python source and formatted with yapf."""
        indent = 4

        def _indent(s_, num_spaces):
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f"{k_str}: {v_str}"
            else:
                attr_str = f"{str(k)}={v_str}"
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # check if all items in the list are dict
            if all(isinstance(_, dict) for _ in v):
                v_str = "[\n"
                v_str += "\n".join(
                    f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
                ).rstrip(",")
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f"{k_str}: {v_str}"
                else:
                    attr_str = f"{str(k)}={v_str}"
                attr_str = _indent(attr_str, indent) + "]"
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            contain_invalid_identifier = False
            for key_name in dict_str:
                contain_invalid_identifier |= not str(key_name).isidentifier()
            return contain_invalid_identifier

        def _format_dict(input_dict, outest_level=False):
            r = ""
            s = []

            # Keys that are not identifiers cannot be keyword args, so fall
            # back to a {'key': value} literal for this level.
            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += "{"
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = "" if outest_level or is_last else ","
                if isinstance(v, dict):
                    v_str = "\n" + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f"{k_str}: dict({v_str}"
                    else:
                        attr_str = f"{str(k)}=dict({v_str}"
                    attr_str = _indent(attr_str, indent) + ")" + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += "\n".join(s)
            if use_mapping:
                r += "}"
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style="pep8",
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True,
        )
        text, _ = FormatCode(text, style_config=yapf_style, verify=True)

        return text

    def __repr__(self):
        return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        # Only called for names not found normally. If _cfg_dict itself is
        # missing (instance created via __new__, e.g. while unpickling), fail
        # with AttributeError instead of recursing forever.
        if name == "_cfg_dict":
            raise AttributeError(name)
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self, file=None):
        """Return ``pretty_text`` when ``file`` is None, else write it to ``file``."""
        if file is None:
            return self.pretty_text
        else:
            with open(file, "w") as f:
                f.write(self.pretty_text)

    def merge_from_dict(self, options):
        """Merge a dict of dotted keys into this config.

        Merge the dict parsed by MultipleKVAction into this cfg.

        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp':True}
            >>> cfg = SLConfig(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(SLConfig, cfg).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(
            ...     model=dict(backbone=dict(type='ResNet', depth=50, with_cp=True)))

        Args:
            options (dict): dict of configs to merge from. Keys are dotted
                paths, e.g. ``'model.backbone.depth'``.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split(".")
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(SLConfig, self).__getattribute__("_cfg_dict")
        super(SLConfig, self).__setattr__(
            "_cfg_dict", SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict)
        )

    # for multiprocess: explicit pickling state. Without __getstate__, Python
    # >= 3.11 pickles __dict__, which the old __setstate__ passed to __init__
    # as if it were the config itself.
    def __getstate__(self):
        return (self._cfg_dict, self._filename, self._text)

    def __setstate__(self, state):
        if isinstance(state, tuple):
            cfg_dict, filename, text = state
            # Restore directly so __init__ does not re-read `filename` from disk.
            super(SLConfig, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
            super(SLConfig, self).__setattr__("_filename", filename)
            super(SLConfig, self).__setattr__("_text", text)
        else:
            # A bare config mapping is still accepted and passed to __init__.
            self.__init__(state)

    def copy(self):
        return SLConfig(self._cfg_dict.copy())

    def deepcopy(self):
        return SLConfig(self._cfg_dict.deepcopy())
class DictAction(Action):
    """
    argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options should
    be passed as comma separated values, i.e KEY=V1,V2,V3
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Convert a string token to int, float, bool or None when it looks like one.

        Tried in order: ``int``, ``float``, ``"true"``/``"false"`` and
        ``"none"``/``"null"`` (both case-insensitive). Anything else is
        returned unchanged.
        """
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        lowered = val.lower()
        if lowered in ("true", "false"):
            return lowered == "true"
        if lowered in ("none", "null"):
            return None
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        """Store ``{KEY: value}`` for each ``KEY=VALUE`` token on ``namespace.dest``.

        A VALUE containing commas becomes a list of converted items.

        Raises:
            argparse.ArgumentError: If a token has no ``=``; argparse turns
                this into a normal usage error.
        """
        from argparse import ArgumentError

        options = {}
        for kv in values:
            key, sep, val = kv.partition("=")
            if not sep:
                raise ArgumentError(self, f"expected KEY=VALUE, got {kv!r}")
            parsed = [self._parse_int_float_bool(v) for v in val.split(",")]
            options[key] = parsed[0] if len(parsed) == 1 else parsed
        setattr(namespace, self.dest, options)