# Copyright (2024) Tsinghua University, Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging

from omegaconf import OmegaConf


class Config:
    """Configuration loaded from a file and merged with user overrides.

    The file at ``args.cfg_path`` is loaded with OmegaConf and then merged
    with the overrides found in ``args.options``; the merged result is
    stored on ``self.config``.

    Args:
        args: Object exposing ``cfg_path`` (path passed to
            ``OmegaConf.load``) and ``options`` (override list or ``None``,
            see ``_convert_to_dot_list`` for the accepted forms).
    """

    def __init__(self, args):
        self.args = args

        # Overrides are passed second to merge so they are applied on top of
        # the values read from the file.
        user_config = self._build_opt_list(self.args.options)
        file_config = OmegaConf.load(self.args.cfg_path)
        self.config = OmegaConf.merge(file_config, user_config)

    def _convert_to_dot_list(self, opts):
        """Normalize override options to a list of ``"key=value"`` strings.

        Two input forms are accepted, decided by looking at the first
        element only:

        * ``["a.b=1", "c=2"]`` -- already ``key=value``; returned unchanged.
        * ``["a.b", "1", "c", "2"]`` -- alternating keys and values; paired
          up into ``["a.b=1", "c=2"]``.

        Args:
            opts: Sequence of option tokens, or ``None``.

        Returns:
            A list of ``"key=value"`` strings (empty when ``opts`` is
            ``None`` or empty).
        """
        if not opts:
            return []

        if "=" in opts[0]:
            return opts

        if len(opts) % 2:
            # zip() below stops at the shorter slice, so the trailing key
            # would otherwise disappear without any trace.
            logging.warning(
                "Ignoring option %r: it has no matching value.", opts[-1]
            )

        return [f"{key}={value}" for key, value in zip(opts[0::2], opts[1::2])]

    def _build_opt_list(self, opts):
        """Build an OmegaConf config from the override options in ``opts``."""
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    def pretty_print(self):
        """Log the ``run``, ``datasets`` and ``model`` sections as JSON."""
        logging.info("\n=====  Running Parameters    =====")
        logging.info(self._convert_node_to_json(self.config.run))

        logging.info("\n======  Dataset Attributes  ======")
        logging.info(self._convert_node_to_json(self.config.datasets))

        logging.info("\n======  Model Attributes  ======")
        logging.info(self._convert_node_to_json(self.config.model))

    def _convert_node_to_json(self, node):
        """Return ``node`` as an indented, key-sorted JSON string.

        Interpolations are resolved (``resolve=True``) before serializing.
        """
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        """Return the merged config as plain Python containers.

        Unlike ``_convert_node_to_json``, ``resolve`` is not passed here, so
        OmegaConf's default handling of interpolations applies.
        """
        return OmegaConf.to_container(self.config)