import copy
import datetime
import gzip
import json
import os
from collections import defaultdict
from hashlib import md5
from typing import Dict, List, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from numpy import isin  # noqa: F401 -- unused here; kept in case other parts of the project rely on it
from omegaconf import OmegaConf
from pandas import isna  # noqa: F401 -- unused here; kept in case other parts of the project rely on it
import wandb

from kinetix.environment.env_state import EnvParams, StaticEnvParams
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.util.saving import load_from_json_file


def get_hash_without_seed(config) -> str:
    """Return an md5 hex digest of the config's YAML dump with the seed zeroed.

    The seed is temporarily set to 0 so that runs differing only in their
    seed share the same hash; the original seed is restored before returning
    (the config is mutated in place during the call).
    """
    old_seed = config["seed"]
    config["seed"] = 0
    ans = md5(OmegaConf.to_yaml(config, sort_keys=True).encode()).hexdigest()
    config["seed"] = old_seed
    return ans


def get_date() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``."""
    return datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")


def generate_params_from_config(config):
    """Build ``(EnvParams, StaticEnvParams)`` from a flat config mapping.

    If ``env_size_type`` is ``"custom"``, the params are loaded from a saved
    world JSON under ``worlds/`` (only ``frame_skip`` is overridden);
    otherwise they are constructed from the individual size fields.
    """
    if config.get("env_size_type", "predefined") == "custom":
        # Must load env params from a file.
        _, static_env_params, env_params = load_from_json_file(os.path.join("worlds", config["custom_path"]))
        return env_params, static_env_params.replace(
            frame_skip=config["frame_skip"],
        )

    env_params = EnvParams()
    static_env_params = StaticEnvParams().replace(
        num_polygons=config["num_polygons"],
        num_circles=config["num_circles"],
        num_joints=config["num_joints"],
        num_thrusters=config["num_thrusters"],
        frame_skip=config["frame_skip"],
        num_motor_bindings=config["num_motor_bindings"],
        num_thruster_bindings=config["num_thruster_bindings"],
    )
    return env_params, static_env_params


def generate_ued_params_from_config(config) -> UEDParams:
    """Build ``UEDParams`` with config-dependent overrides applied."""
    ans = UEDParams()
    if config["env_size_name"] == "s":
        ans = ans.replace(add_shape_n_proposals=1)  # otherwise we get a very weird XLA bug.
    if "fixate_chance_max" in config:
        print("Changing fixate chance max to", config["fixate_chance_max"])
        ans = ans.replace(fixate_chance_max=config["fixate_chance_max"])
    return ans


def get_eval_level_groups(eval_levels: List[str]) -> Dict[str, np.ndarray]:
    """Group eval level names and return each group's level indices.

    Level names look like ``<size>/<kind><digits>_...``; each level lands in
    two groups: ``<size>_all`` and ``<size>_<kind>``, with the short kinds
    ``"h"``/``"r"`` expanded to ``"handmade"``/``"random"``.

    NOTE(review): the return annotation previously claimed
    ``List[Tuple[str, str]]``; the function actually returns a dict mapping
    group name to an integer numpy array of indices into ``eval_levels``.
    """

    def get_groups(s):
        # This is the size group.
        group_one = s.split("/")[0]
        # Kind group: strip digits from the token before the first underscore.
        group_two = s.split("/")[1].split("_")[0]
        group_two = "".join([i for i in group_two if not i.isdigit()])
        if group_two == "h":
            group_two = "handmade"
        if group_two == "r":
            group_two = "random"
        return f"{group_one}_all", f"{group_one}_{group_two}"

    indices = defaultdict(list)
    for idx, s in enumerate(eval_levels):
        groups = get_groups(s)
        for group in groups:
            indices[group].append(idx)

    return {group: np.array(idxs) for group, idxs in indices.items()}


def normalise_config(config, name, editor_config=False):
    """Flatten and normalise a nested run config in place, returning it.

    Flattens the known sub-sections into the top level (asserting no key
    collisions), derives run metadata (hash, log path, wandb group/run name,
    update counts), and writes a copy of the original config to disk.

    Args:
        config: the nested config mapping (mutated and returned).
        name: algorithm name (e.g. "SFL", "PAIRED"), used in derived fields.
        editor_config: when True, skip the training-specific normalisation.
    """
    old_config = copy.deepcopy(config)
    # Flatten these sub-dicts into the top level; keys must not collide.
    keys = ["env", "learning", "model", "misc", "eval", "ued", "env_size", "train_levels"]
    for k in keys:
        if k not in config:
            config[k] = {}
        small_d = config[k]
        del config[k]
        for kk, vv in small_d.items():
            assert kk not in config, kk
            config[kk] = vv

    if not editor_config:
        config["eval_env_size_true"] = config["eval_env_size"]
        # Pixel observations are memory-heavy: shrink the default env count.
        if config["num_train_envs"] == 2048 and "Pixels" in config["env_name"]:
            config["num_train_envs"] = 512
        if "SFL" in name and config["env_size_name"] in ["m", "l"]:
            config["eval_num_attempts"] = 6  # to avoid a very weird XLA bug.
        config["hash"] = get_hash_without_seed(config)
        config["random_hash"] = np.random.randint(2**31)
        config["log_save_path"] = f"logs/{config['hash']}/{config['seed']}-{get_date()}"
        os.makedirs(config["log_save_path"], exist_ok=True)
        # Persist the original (pre-flattening) config alongside the logs.
        with open(f"{config['log_save_path']}/config.yaml", "w") as f:
            f.write(OmegaConf.to_yaml(old_config))
        if config["group"] == "auto":
            config["group"] = f"{name}-" + config["group_auto_prefix"] + config["env_name"].replace("Kinetix-", "")
            config["group"] += "-" + str(config["env_size_name"])
        if config["eval_levels"] == ["auto"] or config["eval_levels"] == "auto":
            config["eval_levels"] = config["train_levels_list"]
            print("Using Auto eval levels:", config["eval_levels"])
        config["num_eval_levels"] = len(config["eval_levels"])
        # Env steps consumed per update; PAIRED runs two populations.
        steps = (
            config["num_steps"]
            * config.get("outer_rollout_steps", 1)
            * config["num_train_envs"]
            * (2 if name == "PAIRED" else 1)
        )
        config["num_updates"] = int(config["total_timesteps"]) // steps
        # Human-readable step count for the run name, in millions/billions.
        nsteps = int(config["total_timesteps"] // 1e6)
        letter = "M"
        if nsteps >= 1000:
            nsteps = nsteps // 1000
            letter = "B"
        config["run_name"] = (
            config["env_name"] + f"-{name}-" + str(nsteps) + letter + "-" + str(config["num_train_envs"])
        )
        # Guarantee at least one checkpoint is saved.
        if config["checkpoint_save_freq"] >= config["num_updates"]:
            config["checkpoint_save_freq"] = config["num_updates"]
    return config


def get_tags(config, name):
    """Return the wandb tags for this run.

    FIX(review): an unconditional ``return [name]`` on the first line made
    the algorithm-specific tagging below unreachable dead code; it has been
    removed so PLR/ACCEL/DR runs get their proper extra tag again.
    """
    tags = [name]
    if name in ["PLR", "ACCEL", "DR"]:
        if config["use_accel"]:
            tags.append("ACCEL")
        else:
            tags.append("PLR")
    return tags


def init_wandb(config, name) -> wandb.run:
    """Initialise a wandb run and define the custom step metrics."""
    run = wandb.init(
        config=config,
        project=config["wandb_project"],
        group=config["group"],
        name=config["run_name"],
        entity=config["wandb_entity"],
        mode=config["wandb_mode"],
        tags=get_tags(config, name),
    )
    wandb.define_metric("timing/num_updates")
    wandb.define_metric("timing/num_env_steps")
    # All metrics are plotted against env steps by default.
    wandb.define_metric("*", step_metric="timing/num_env_steps")
    wandb.define_metric("timing/sps", step_metric="timing/num_env_steps")
    return run


def save_data_to_local_file(data_to_save, config):
    """Append a JSON-serialisable snapshot of ``data_to_save`` to the run's
    ``data.jsonl`` log file, skipping media entries.

    No-op unless ``save_local_data`` is truthy in the config.
    """
    if not config.get("save_local_data", False):
        return

    def reverse_in(li, value):
        # True if any prefix in ``li`` occurs as a substring of ``value``.
        for i, v in enumerate(li):
            if v in value:
                return True
        return False

    clean_data = {k: v for k, v in data_to_save.items() if not reverse_in(["media/", "images/"], k)}

    def _clean(x):
        # Convert JAX/NumPy leaves to plain JSON-serialisable Python values.
        # NOTE(review): ``jnp.float32``/``jnp.int32`` are numpy scalar types,
        # so these branches match numpy scalars -- confirm leaves are scalars
        # rather than 0-d arrays in all call sites.
        if isinstance(x, jnp.ndarray):
            return x.tolist()
        elif isinstance(x, jnp.float32):
            if jnp.isnan(x):
                return -float("inf")  # JSON has no NaN; use a sentinel.
            return round(float(x) * 1000) / 1000  # keep 3 decimal places
        elif isinstance(x, jnp.int32):
            return int(x)
        return x

    # jax.tree_map is deprecated/removed in modern JAX; tree_util is the
    # stable spelling with identical semantics.
    clean_data = jax.tree_util.tree_map(_clean, clean_data)
    print("Saving this data:", clean_data)
    with open(f"{config['log_save_path']}/data.jsonl", "a+") as f:
        f.write(json.dumps(clean_data) + "\n")


def compress_log_files_after_run(config):
    """Gzip the run's ``data.jsonl`` next to the original (kept) file."""
    fpath = f"{config['log_save_path']}/data.jsonl"
    with open(fpath, "rb") as f_in, gzip.open(fpath + ".gz", "wb") as f_out:
        f_out.writelines(f_in)


def get_video_frequency(config, update_step):
    """Return how often (in updates) to log videos at this point in training.

    Videos are logged most frequently between 10%-30% of training, half as
    often between 30%-60%, and a quarter as often otherwise. Both
    ``lax.select`` branches are evaluated eagerly; this is jit-compatible.
    """
    frac_through_training = update_step / config["num_updates"]
    vid_frequency = (
        config["eval_freq"]
        * config["video_frequency"]
        * jax.lax.select(
            (0.1 <= frac_through_training) & (frac_through_training < 0.3),
            1,
            jax.lax.select(
                (0.3 <= frac_through_training) & (frac_through_training < 0.6),
                2,
                4,
            ),
        )
    )
    return vid_frequency