Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import pickle | |
| from typing import Any, Dict, Union | |
| import flax.serialization | |
| import flax.serialization | |
| import flax.serialization | |
| import flax.serialization | |
| import flax.serialization | |
| import flax.serialization | |
| import flax.serialization | |
| import jax | |
| import jax.numpy as jnp | |
| import flax | |
| import wandb | |
| from jax2d.engine import ( | |
| calculate_collision_matrix, | |
| get_empty_collision_manifolds, | |
| get_pairwise_interaction_indices, | |
| recalculate_mass_and_inertia, | |
| ) | |
| from jax2d.sim_state import RigidBody, SimState | |
| from kinetix.environment.env_state import EnvState, StaticEnvParams, EnvParams | |
| from flax.traverse_util import flatten_dict, unflatten_dict | |
| from safetensors.flax import save_file, load_file | |
| from kinetix.pcg.pcg import env_state_to_pcg_state | |
| from kinetix.pcg.pcg_state import PCGState | |
| import bz2 | |
def check_if_mass_and_inertia_are_correct(state: SimState, env_params: EnvParams, static_params):
    """Verify that stored inverse masses/inertias match a fresh recomputation from densities.

    The masses are recomputed with ``recalculate_mass_and_inertia`` from
    ``state.polygon_densities`` / ``state.circle_densities`` and compared (within
    ``jnp.allclose`` tolerance) against the values stored in ``state``. Only entries of
    active shapes are compared; inactive slots are zeroed on both sides first.

    Args:
        state: the loaded simulation/env state to validate.
        env_params: unused; kept for interface compatibility.
        static_params: static sizes passed through to ``recalculate_mass_and_inertia``.

    Returns:
        True when every checked quantity matches.

    Raises:
        ValueError: naming the first mismatching quantity and the offending indices.
    """
    new = recalculate_mass_and_inertia(state, static_params, state.polygon_densities, state.circle_densities)

    def _check(recomputed, stored, shape, name):
        active = shape.active
        recomputed = jnp.where(active, recomputed, jnp.zeros_like(recomputed))
        stored = jnp.where(active, stored, jnp.zeros_like(stored))
        if jnp.allclose(recomputed, stored):
            return
        # Report the entries that actually fail the tolerance test above. Exact ``!=`` would
        # also list indices that differ only by float noise and are not the cause.
        mismatched = ~jnp.isclose(recomputed, stored) & active
        idxs = jnp.arange(active.shape[0])[mismatched]
        new_one = recomputed[idxs]
        old_one = stored[idxs]
        raise ValueError(
            f"Error: {name} is not the same after loading. Indexes {idxs} are incorrect. New = {new_one} | Before = {old_one}"
        )

    _check(new.polygon.inverse_mass, state.polygon.inverse_mass, state.polygon, "Polygon inverse mass")
    _check(new.circle.inverse_mass, state.circle.inverse_mass, state.circle, "Circle inverse mass")
    _check(new.polygon.inverse_inertia, state.polygon.inverse_inertia, state.polygon, "Polygon inverse inertia")
    _check(new.circle.inverse_inertia, state.circle.inverse_inertia, state.circle, "Circle inverse inertia")
    return True
def save_pickle(filename, state):
    """Serialise ``state`` to ``filename`` with :mod:`pickle`, overwriting any existing file."""
    with open(filename, mode="wb") as out_file:
        pickle.dump(state, out_file)
def load_pcg_state_pickle(filename):
    """Return the object pickled in ``filename``.

    SECURITY: unpickling can execute arbitrary code; only load files from trusted sources.
    """
    with open(filename, mode="rb") as in_file:
        loaded = pickle.load(in_file)
    return loaded
def expand_env_state(env_state: EnvState, static_env_params: StaticEnvParams, ignore_collision_matrix=False):
    """Pad ``env_state`` so its entity arrays match the sizes in ``static_env_params``.

    Polygons, circles, joints and thrusters, together with their per-entity side arrays
    (shape roles, highlight flags, densities, motor bindings, ``motor_auto`` and thruster
    bindings), are extended along axis 0 with zero-filled slots.

    When polygons are added, joint ``a_index``/``b_index`` and thruster ``object_index``
    values at or above the original polygon count are shifted up by the number of polygons
    added. Those indices presumably refer to circles, which sit after the polygons.

    The accumulated collision manifolds are always replaced by empty ones sized for
    ``static_env_params``. The collision matrix is recomputed from the joints unless
    ``ignore_collision_matrix`` is True, in which case the existing matrix is kept unchanged.

    Raises:
        Exception: if ``env_state`` holds more polygons, circles, joints or thrusters than
            ``static_env_params`` allows.
    """
    num_rects = len(env_state.polygon.position)
    num_circles = len(env_state.circle.position)
    num_joints = len(env_state.joint.a_index)
    num_thrusters = len(env_state.thruster.object_index)

    # Thrusters are checked too; previously an oversized thruster array passed through silently.
    if (
        num_rects > static_env_params.num_polygons
        or num_circles > static_env_params.num_circles
        or num_joints > static_env_params.num_joints
        or num_thrusters > static_env_params.num_thrusters
    ):
        raise Exception(
            f"The current static_env_params is too small to accommodate the loaded env_state (needs num_rects={num_rects}, num_circles={num_circles}, num_joints={num_joints}, num_thrusters={num_thrusters} but current is {static_env_params.num_polygons}, {static_env_params.num_circles}, {static_env_params.num_joints}, {static_env_params.num_thrusters})."
        )

    def _add_dummy(num_to_add, obj):
        # Append ``num_to_add`` zero rows along axis 0 of every leaf, preserving each leaf's dtype.
        return jax.tree.map(
            lambda current: jnp.concatenate(
                [current, jnp.zeros((num_to_add, *current.shape[1:]), dtype=current.dtype)], axis=0
            ),
            obj,
        )

    added_rects = static_env_params.num_polygons - num_rects
    if added_rects > 0:
        env_state = env_state.replace(
            polygon=_add_dummy(added_rects, env_state.polygon),
            polygon_shape_roles=_add_dummy(added_rects, env_state.polygon_shape_roles),
            polygon_highlighted=_add_dummy(added_rects, env_state.polygon_highlighted),
            polygon_densities=_add_dummy(added_rects, env_state.polygon_densities),
        )

    added_circles = static_env_params.num_circles - num_circles
    if added_circles > 0:
        env_state = env_state.replace(
            circle=_add_dummy(added_circles, env_state.circle),
            circle_shape_roles=_add_dummy(added_circles, env_state.circle_shape_roles),
            circle_highlighted=_add_dummy(added_circles, env_state.circle_highlighted),
            circle_densities=_add_dummy(added_circles, env_state.circle_densities),
        )

    added_joints = static_env_params.num_joints - num_joints
    if added_joints > 0:
        env_state = env_state.replace(
            joint=_add_dummy(added_joints, env_state.joint),
            motor_bindings=_add_dummy(added_joints, env_state.motor_bindings),
            motor_auto=_add_dummy(added_joints, env_state.motor_auto),
        )

    added_thrusters = static_env_params.num_thrusters - num_thrusters
    if added_thrusters > 0:
        env_state = env_state.replace(
            thruster=_add_dummy(added_thrusters, env_state.thruster),
            thruster_bindings=_add_dummy(added_thrusters, env_state.thruster_bindings),
        )

    if added_rects > 0:
        # Indices past the original polygons move up by the number of polygon slots inserted.
        def _modify_index(old_indices):
            return jnp.where(old_indices >= num_rects, old_indices + added_rects, old_indices)

        env_state = env_state.replace(
            joint=env_state.joint.replace(
                a_index=_modify_index(env_state.joint.a_index),
                b_index=_modify_index(env_state.joint.b_index),
            ),
            thruster=env_state.thruster.replace(
                object_index=_modify_index(env_state.thruster.object_index),
            ),
        )

    # Always rebuilt (even when nothing was padded) so they match static_env_params.
    acc_rr_manifolds, acc_cr_manifolds, acc_cc_manifolds = get_empty_collision_manifolds(static_env_params)
    if ignore_collision_matrix:
        collision_matrix = env_state.collision_matrix
    else:
        collision_matrix = calculate_collision_matrix(static_env_params, env_state.joint)
    return env_state.replace(
        collision_matrix=collision_matrix,
        acc_rr_manifolds=acc_rr_manifolds,
        acc_cr_manifolds=acc_cr_manifolds,
        acc_cc_manifolds=acc_cc_manifolds,
    )
def expand_pcg_state(pcg_state: PCGState, static_env_params):
    """Pad every env state inside ``pcg_state`` up to the sizes in ``static_env_params``.

    ``env_state``, ``env_state_max`` and ``env_state_pcg_mask`` are each padded with
    ``expand_env_state``. The mask's collision matrix is then reset to an all-False boolean
    matrix shaped like the padded ``env_state.collision_matrix``.

    ``tied_together`` is re-embedded into a ``(num_shapes, num_shapes)`` boolean matrix, where
    ``num_shapes`` is the padded polygon count plus the padded circle count.
    """
    new_pcg_state = pcg_state.replace(
        env_state=expand_env_state(pcg_state.env_state, static_env_params),
        env_state_max=expand_env_state(pcg_state.env_state_max, static_env_params),
        env_state_pcg_mask=expand_env_state(
            pcg_state.env_state_pcg_mask, static_env_params, ignore_collision_matrix=True
        ),
    )
    new_pcg_state = new_pcg_state.replace(
        env_state_pcg_mask=new_pcg_state.env_state_pcg_mask.replace(
            collision_matrix=jnp.zeros_like(new_pcg_state.env_state.collision_matrix, dtype=bool),
        )
    )

    old_num_rects = pcg_state.env_state.polygon.active.shape[0]
    new_num_rects = new_pcg_state.env_state.polygon.active.shape[0]
    num_shapes = new_num_rects + new_pcg_state.env_state.circle.active.shape[0]

    # Shape indices are polygon-first, then circles: num_shapes is polygons + circles, and
    # expand_env_state shifts indices >= the old polygon count for joints/thrusters. Padding
    # polygons therefore moves every circle index up by ``shift``. Copying the old matrix into
    # the top-left corner would misplace all circle rows and columns.
    shift = new_num_rects - old_num_rects
    old_tied = pcg_state.tied_together

    def _remap(n):
        idx = jnp.arange(n)
        return jnp.where(idx >= old_num_rects, idx + shift, idx)

    rows = _remap(old_tied.shape[0])
    cols = _remap(old_tied.shape[1])
    tied_together = (
        jnp.zeros((num_shapes, num_shapes), dtype=bool).at[rows[:, None], cols[None, :]].set(old_tied)
    )
    return new_pcg_state.replace(tied_together=tied_together)
def load_world_state_pickle(filename, params=None, static_env_params=None):
    """Load a pickled env state, sanitise it, validate it and pad it to the static sizes.

    NaN/inf values in every leaf are replaced via ``jnp.nan_to_num``. Inverse masses and
    inertias are then checked against a recomputation (raising ``ValueError`` on mismatch),
    and the state is padded with ``expand_env_state``. ``params`` and ``static_env_params``
    fall back to fresh ``EnvParams()`` / ``StaticEnvParams()`` when not given.

    SECURITY: unpickling can execute arbitrary code; only load files from trusted sources.
    """
    static_params = static_env_params if static_env_params else StaticEnvParams()
    with open(filename, "rb") as handle:
        raw_state: SimState = pickle.load(handle)
    cleaned = jax.tree.map(jnp.nan_to_num, raw_state)

    check_if_mass_and_inertia_are_correct(cleaned, params if params else EnvParams(), static_params)

    return expand_env_state(cleaned, static_params)
def stack_list_of_pytrees(list_of_pytrees):
    """Stack structurally identical pytrees leaf-wise along a new leading axis.

    Args:
        list_of_pytrees: non-empty sequence of pytrees that all share one tree structure.

    Returns:
        A pytree of that structure whose leaves have shape
        ``(len(list_of_pytrees), *leaf.shape)``.

    Raises:
        ValueError: if ``list_of_pytrees`` is empty.
    """
    if not list_of_pytrees:
        raise ValueError("stack_list_of_pytrees needs at least one pytree")
    first, *rest = list_of_pytrees
    # One jnp.stack per leaf. The previous version re-concatenated the growing result for
    # every element, copying O(n^2) data, and used the deprecated/removed jax.tree_map.
    return jax.tree.map(lambda *leaves: jnp.stack(leaves), first, *rest)
def get_pcg_state_from_json(json_filename) -> PCGState:
    """Load a level JSON file and wrap its env state in a ``PCGState``.

    The static and dynamic params loaded alongside the env state are discarded.
    """
    loaded_env_state, _static_params, _env_params = load_from_json_file(json_filename)
    pcg_state = env_state_to_pcg_state(loaded_env_state)
    return pcg_state
def my_load_file(filename):
    """Return the object pickled inside the bz2-compressed file ``filename``.

    This is the counterpart of ``my_save_file``. The file handle is closed before
    returning; previously the ``BZ2File`` was left open.

    SECURITY: unpickling can execute arbitrary code; only load trusted checkpoints.
    """
    with bz2.BZ2File(filename, "rb") as compressed:
        return pickle.load(compressed)
def my_save_file(obj, filename):
    """Pickle ``obj`` into a bz2-compressed file at ``filename``, overwriting it if present."""
    with bz2.open(filename, "wb") as compressed:
        pickle.dump(obj, compressed)
def save_params(params: Dict, filename: Union[str, os.PathLike]) -> None:
    """Persist a parameter pytree to ``filename`` as a bz2-compressed pickle."""
    return my_save_file(params, filename)
def load_params(filename: Union[str, os.PathLike], legacy=False) -> Dict:
    """Load a parameter pytree saved by ``save_params``.

    Args:
        filename: path to a ``.pbz2`` checkpoint (``str`` or ``os.PathLike``).
        legacy: if True, read the sibling safetensors file instead. ``full_model.pbz2``
            maps to ``model.safetensors``, and any other ``.pbz2`` suffix becomes
            ``.safetensors``. Its flat ``","``-joined keys are unflattened back into a
            nested dict.

    Returns:
        The loaded (nested) parameter dictionary.
    """
    if legacy:
        # os.fspath accepts the documented PathLike inputs; str.replace alone crashed on them.
        path = os.fspath(filename)
        path = path.replace("full_model.pbz2", "model.safetensors")
        path = path.replace(".pbz2", ".safetensors")
        return unflatten_dict(load_file(path), sep=",")
    return my_load_file(filename)
def load_params_from_wandb_artifact_path(checkpoint_name, legacy=False):
    """Download W&B artifact ``checkpoint_name`` and load the ``model.pbz2`` inside it.

    ``legacy`` is forwarded to ``load_params``, which then reads ``model.safetensors`` instead.
    """
    artifact_dir = wandb.Api().artifact(checkpoint_name).download()
    return load_params(f"{artifact_dir}/model.pbz2", legacy=legacy)
def save_params_to_wandb(params, timesteps, config):
    """Save ``params`` to ``<save_path>/<run>/model.pbz2`` and upload it as a W&B artifact.

    ``<run>`` is ``<run_name>-<random_hash>-<timesteps>``. When
    ``config["checkpoint_human_numbers"]`` is truthy, timesteps is shown rounded in
    billions (e.g. ``"2B"``).

    NOTE(review): a second ``save_params_to_wandb`` is defined later in this module. Being
    defined later, it replaces this one at import time, so this body is never called as
    written. Consolidate or rename one of them.
    """
    if config["checkpoint_human_numbers"]:
        stamp = str(round(timesteps / 1e9)) + "B"
    else:
        stamp = str(timesteps)
    run_name = "-".join([config["run_name"], str(config["random_hash"]), stamp])
    save_dir = os.path.join(config["save_path"], run_name)
    os.makedirs(save_dir, exist_ok=True)

    model_path = f"{save_dir}/model.pbz2"
    save_params(params, model_path)

    artifact = wandb.Artifact(f"{run_name}-checkpoint", type="checkpoint")
    artifact.add_file(model_path)
    artifact.save()
    print(f"Parameters of model saved in {model_path}")
def load_params_wandb_artifact_path_full_model(checkpoint_name):
    """Download W&B artifact ``checkpoint_name`` and return the ``"params"`` of its full model dump."""
    artifact_dir = wandb.Api().artifact(checkpoint_name).download()
    full_model = load_params(f"{artifact_dir}/full_model.pbz2")
    return full_model["params"]
def load_train_state_from_wandb_artifact_path(train_state, checkpoint_name, load_only_params=False, legacy=False):
    """Restore ``train_state`` from the ``full_model.pbz2`` in W&B artifact ``checkpoint_name``.

    With ``legacy=True``, the whole loaded dict (read from the safetensors file) becomes
    ``params``. Otherwise ``params`` is restored, and so is ``opt_state`` unless
    ``load_only_params`` is set. The saved ``step`` is not restored.
    """
    artifact_dir = wandb.Api().artifact(checkpoint_name).download()
    loaded = load_params(artifact_dir + "/full_model.pbz2", legacy=legacy)
    if legacy:
        return train_state.replace(params=loaded)

    updates = {"params": loaded["params"]}
    if not load_only_params:
        updates["opt_state"] = loaded["opt_state"]
    return train_state.replace(**updates)
def save_params_to_wandb(params, timesteps, config):
    """Save ``params`` as ``params.pbz2`` through ``save_dict_to_wandb`` and upload it to W&B.

    NOTE(review): this later definition replaces the earlier ``save_params_to_wandb`` in this
    module at import time. It writes ``params.pbz2``, whereas
    ``load_params_from_wandb_artifact_path`` reads ``model.pbz2``. It also ignores
    ``config["checkpoint_human_numbers"]``. Confirm which behavior is intended.
    """
    return save_dict_to_wandb(params, timesteps, config, name="params")
def save_dict_to_wandb(dict, timesteps, config, name):
    """Save ``dict`` to ``<save_path>/<run>/<name>.pbz2`` and upload it as a W&B artifact.

    ``<run>`` is ``<run_name>-<random_hash>-<N>B``, where ``N`` is ``timesteps`` in billions,
    rounded (e.g. ``2.4e9`` -> ``"2B"``). The artifact is named ``<run>-checkpoint`` with
    type ``"checkpoint"``.
    """
    # ``dict`` shadows the builtin but is kept as the parameter name for keyword callers.
    billions = str(round(timesteps / 1e9)) + "B"
    run_name = "-".join([config["run_name"], str(config["random_hash"]), billions])
    save_dir = os.path.join(config["save_path"], run_name)
    os.makedirs(save_dir, exist_ok=True)

    file_path = f"{save_dir}/{name}.pbz2"
    save_params(dict, file_path)

    artifact = wandb.Artifact(f"{run_name}-checkpoint", type="checkpoint")
    artifact.add_file(file_path)
    artifact.save()
    print(f"Parameters of model saved in {file_path}")
def save_model_to_wandb(train_state, timesteps, config, is_final=False):
    """Upload step, params and optimiser state of ``train_state`` as ``full_model.pbz2``.

    With ``config["economical_saving"]`` truthy, only steps 2048, 10240, 40960 and 81920
    (or the final save) are uploaded; other steps just print a notice.
    """
    checkpoint = {"step": train_state.step, "params": train_state.params, "opt_state": train_state.opt_state}
    step = int(train_state.step)
    save_every_time = not config["economical_saving"]
    if save_every_time or step in (2048, 10240, 40960, 81920) or is_final:
        save_dict_to_wandb(checkpoint, timesteps, config, "full_model")
    else:
        print("Not saving model because step is", step)
def import_env_state_from_json(json_file: dict[str, Any]) -> tuple[EnvState, StaticEnvParams, EnvParams]:
    """Rebuild ``(env_state, static_env_params, env_params)`` from a decoded level dictionary.

    ``json_file`` has top-level keys ``"env_state"``, ``"env_params"`` and
    ``"static_env_params"``, in the layout written by ``export_env_state_to_json``. Dicts keyed
    ``"0"``, ``"1"``, ... are converted back into arrays. ``json_file`` is not modified.

    ``downscale`` and ``screen_dim`` in the static params, and ``max_timesteps`` in the env
    params, come from the current defaults rather than from the file. The collision matrix
    is recomputed from the loaded joints.
    """
    from kinetix.environment.env import create_empty_env

    def normalise(k, v):
        # Serialised lists arrive as {"0": ..., "1": ...}; turn those back into arrays.
        if k == "screen_dim":
            return v
        if isinstance(v, dict) and "0" in v:
            return jnp.array([normalise(k, v[str(i)]) for i in range(len(v))])
        return v

    def normalise_all(d):
        return {k: normalise(k, v) for k, v in d.items()}

    def astype(x, like):
        # asarray (unlike jnp.astype) also accepts the plain nested lists that .tolist() produced.
        return jnp.asarray(x, dtype=like.dtype)

    def load_entry(batched, index, entry):
        # Deserialise one JSON entry into the structure of ``batched[index]``.
        template = jax.tree.map(lambda x: x[index], batched)
        return flax.serialization.from_state_dict(template, normalise_all(entry))

    def store_entry(batched, index, new):
        # Return ``batched`` with slot ``index`` overwritten by ``new`` (cast to each leaf's dtype).
        return jax.tree.map(lambda full, item: full.at[index].set(astype(item, full)), batched, new)

    env_state = json_file["env_state"]
    env_params = json_file["env_params"]
    static_env_params = json_file["static_env_params"]

    env_params_target = EnvParams()
    static_env_params_target = StaticEnvParams()

    new_env_params = flax.serialization.from_state_dict(env_params_target, normalise_all(env_params))

    norm_static = normalise_all(static_env_params)
    norm_static["downscale"] = static_env_params_target.downscale
    new_static_env_params = flax.serialization.from_state_dict(static_env_params_target, norm_static)
    new_static_env_params = new_static_env_params.replace(screen_dim=static_env_params_target.screen_dim)

    env_state_target = create_empty_env(new_static_env_params)

    def _load_rigidbody(env_state_target, i, is_poly):
        # Shallow copy so the pops below do not mutate the caller's JSON dict.
        to_load_from: dict[str, Any] = dict(env_state["polygon" if is_poly else "circle"][i])
        role = to_load_from.pop("role")
        density = to_load_from.pop("density")
        to_load_from.pop("highlighted", None)  # highlight flags are not restored
        if is_poly:
            new_obj = load_entry(env_state_target.polygon, i, to_load_from)
            return env_state_target.replace(
                polygon_shape_roles=env_state_target.polygon_shape_roles.at[i].set(role),
                polygon_densities=env_state_target.polygon_densities.at[i].set(density),
                polygon=store_entry(env_state_target.polygon, i, new_obj),
            )
        new_obj = load_entry(env_state_target.circle, i, to_load_from)
        return env_state_target.replace(
            circle_shape_roles=env_state_target.circle_shape_roles.at[i].set(role),
            circle_densities=env_state_target.circle_densities.at[i].set(density),
            circle=store_entry(env_state_target.circle, i, new_obj),
        )

    for i in range(new_static_env_params.num_circles):
        env_state_target = _load_rigidbody(env_state_target, i, False)
    for i in range(new_static_env_params.num_polygons):
        env_state_target = _load_rigidbody(env_state_target, i, True)

    for i in range(new_static_env_params.num_joints):
        to_load_from = dict(env_state["joint"][i])
        motor_binding = to_load_from.pop("motor_binding")
        new_obj = load_entry(env_state_target.joint, i, to_load_from)
        env_state_target = env_state_target.replace(
            joint=store_entry(env_state_target.joint, i, new_obj),
            motor_bindings=env_state_target.motor_bindings.at[i].set(motor_binding),
        )

    for i in range(new_static_env_params.num_thrusters):
        to_load_from = dict(env_state["thruster"][i])
        thruster_binding = to_load_from.pop("thruster_binding")
        new_obj = load_entry(env_state_target.thruster, i, to_load_from)
        env_state_target = env_state_target.replace(
            thruster=store_entry(env_state_target.thruster, i, new_obj),
            thruster_bindings=env_state_target.thruster_bindings.at[i].set(thruster_binding),
        )

    # Rect-rect manifolds hold two contacts per pair. The exporter writes contact 0 and
    # contact 1 of pair i as consecutive entries 2*i and 2*i + 1. Previously, entries i and
    # i+1 were read, one contact was broadcast into both slots, and the second update was
    # discarded.
    rr = env_state_target.acc_rr_manifolds
    for i in range(rr.active.shape[0]):
        for j in range(2):
            entry = env_state["acc_rr_manifolds"][2 * i + j]
            rr = store_entry(rr, (i, j), load_entry(rr, (i, j), entry))

    cr = env_state_target.acc_cr_manifolds
    for i in range(cr.active.shape[0]):
        cr = store_entry(cr, i, load_entry(cr, i, env_state["acc_cr_manifolds"][i]))

    cc = env_state_target.acc_cc_manifolds
    for i in range(cc.active.shape[0]):
        cc = store_entry(cc, i, load_entry(cc, i, env_state["acc_cc_manifolds"][i]))

    # The stored collision matrix is not read back; it is always derived from the joints.
    env_state_target = env_state_target.replace(
        acc_rr_manifolds=rr,
        acc_cr_manifolds=cr,
        acc_cc_manifolds=cc,
        collision_matrix=calculate_collision_matrix(new_static_env_params, env_state_target.joint),
    )

    return (
        env_state_target,
        new_static_env_params,
        new_env_params.replace(max_timesteps=env_params_target.max_timesteps),
    )
def export_env_state_to_json(
    filename: str, env_state: EnvState, static_env_params: StaticEnvParams, env_params: EnvParams
):
    """Write ``env_state`` and its params to ``filename`` as JSON and return the written dict.

    The top-level keys are ``"env_state"``, ``"env_params"`` and ``"static_env_params"``.
    Every slot up to the static limits is written, whether active or not. Array values are
    converted with ``.tolist()`` and passed through ``flax.serialization.to_state_dict``.
    For each rect-rect pair, contact 0 and contact 1 are written as two consecutive entries.
    """

    def _to_python(x):
        # jnp.ndarray is an abstract base class, so the former ``type(x) == jnp.ndarray`` check
        # never matched and arrays reached json.dump unconverted. isinstance matches real arrays.
        return x.tolist() if isinstance(x, jnp.ndarray) else x

    def _slot(batched, index):
        # State dict for one slot of a batched struct, with leaves as plain Python values.
        return flax.serialization.to_state_dict(jax.tree.map(lambda x: x[index].tolist(), batched))

    json_to_save = {
        "polygon": [],
        "circle": [],
        "joint": [],
        "thruster": [],
        "collision_matrix": flax.serialization.to_state_dict(env_state.collision_matrix.tolist()),
        "acc_rr_manifolds": [],
        "acc_cr_manifolds": [],
        "acc_cc_manifolds": [],
        "gravity": flax.serialization.to_state_dict(env_state.gravity.tolist()),
    }

    def _rigidbody_to_json(index: int, is_poly):
        if is_poly:
            bodies, roles = env_state.polygon, env_state.polygon_shape_roles
            densities, highlighted = env_state.polygon_densities, env_state.polygon_highlighted
        else:
            bodies, roles = env_state.circle, env_state.circle_shape_roles
            densities, highlighted = env_state.circle_densities, env_state.circle_highlighted
        d = _slot(bodies, index)
        d["role"] = roles[index].tolist()
        d["density"] = densities[index].tolist()
        d["highlighted"] = highlighted[index].tolist()
        return d

    def _joint_to_json(i):
        d = _slot(env_state.joint, i)
        d["motor_binding"] = env_state.motor_bindings[i].tolist()
        return d

    def _thruster_to_json(i):
        d = _slot(env_state.thruster, i)
        d["thruster_binding"] = env_state.thruster_bindings[i].tolist()
        return d

    for i in range(static_env_params.num_circles):
        json_to_save["circle"].append(_rigidbody_to_json(i, False))
    for i in range(static_env_params.num_polygons):
        json_to_save["polygon"].append(_rigidbody_to_json(i, True))
    for i in range(static_env_params.num_joints):
        json_to_save["joint"].append(_joint_to_json(i))
    for i in range(static_env_params.num_thrusters):
        json_to_save["thruster"].append(_thruster_to_json(i))

    ncc, ncr, nrr, circle_circle_pairs, circle_rect_pairs, rect_rect_pairs = get_pairwise_interaction_indices(
        static_env_params
    )
    for i in range(nrr):
        json_to_save["acc_rr_manifolds"].append(_slot(env_state.acc_rr_manifolds, (i, 0)))
        json_to_save["acc_rr_manifolds"].append(_slot(env_state.acc_rr_manifolds, (i, 1)))
    for i in range(ncr):
        json_to_save["acc_cr_manifolds"].append(_slot(env_state.acc_cr_manifolds, i))
    for i in range(ncc):
        json_to_save["acc_cc_manifolds"].append(_slot(env_state.acc_cc_manifolds, i))

    to_save = {
        "env_state": json_to_save,
        "env_params": flax.serialization.to_state_dict(jax.tree.map(_to_python, env_params)),
        "static_env_params": flax.serialization.to_state_dict(jax.tree.map(_to_python, static_env_params)),
    }

    with open(filename, "w") as f:
        json.dump(to_save, f)
    return to_save
def load_from_json_file(filename):
    """Read the level JSON at ``filename`` and rebuild ``(env_state, static_env_params, env_params)``."""
    with open(filename, "r") as level_file:
        level = json.load(level_file)
    return import_env_state_from_json(level)
if __name__ == "__main__":
    # Running this module directly performs no action.
    pass