# Source: uploaded by tree3po ("Upload 46 files", commit 581eeac, verified).
from functools import partial
import math
import os
import chex
import jax
import jax.numpy as jnp
from flax.serialization import to_state_dict
from jax2d.engine import (
calculate_collision_matrix,
calc_inverse_mass_polygon,
calc_inverse_mass_circle,
calc_inverse_inertia_circle,
calc_inverse_inertia_polygon,
recalculate_mass_and_inertia,
select_shape,
PhysicsEngine,
)
from jax2d.sim_state import SimState, RigidBody, Joint, Thruster
from jax2d.maths import rmat
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
from kinetix.environment.ued.distributions import (
create_vmapped_filtered_distribution,
sample_kinetix_level,
)
from kinetix.environment.ued.mutators import (
make_mutate_change_shape_rotation,
make_mutate_change_shape_size,
mutate_add_connected_shape_proper,
mutate_add_shape,
mutate_add_connected_shape,
mutate_change_shape_location,
mutate_remove_joint,
mutate_remove_shape,
mutate_swap_role,
mutate_toggle_fixture,
mutate_add_thruster,
mutate_remove_thruster,
mutate_change_gravity,
)
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.environment.utils import permute_pcg_state
from kinetix.pcg.pcg import env_state_to_pcg_state, sample_pcg_state
from kinetix.util.config import generate_ued_params_from_config, generate_params_from_config
from kinetix.util.saving import get_pcg_state_from_json, load_pcg_state_pickle, load_world_state_pickle, stack_list_of_pytrees, expand_env_state
from flax import struct
from kinetix.environment.env import create_empty_env
from kinetix.util.learning import BASE_DIR, general_eval, get_eval_levels
def make_mutate_env(static_env_params: StaticEnvParams, params: EnvParams, ued_params: UEDParams):
    """Build a level mutator that applies ``n`` randomly chosen mutations in sequence.

    At every step one mutation is drawn from a fixed menu. The draw is weighted by
    state-dependent eligibility, e.g. shape-adding mutations only get weight when
    some polygon or circle slot is inactive. The chosen mutation is then applied
    to the level.

    Args:
        static_env_params: Static environment params, forwarded to every mutation.
        params: Environment params, forwarded to every mutation (and used to build
            the size / rotation mutators).
        ued_params: UED params, forwarded to every mutation.

    Returns:
        ``mutate_level(rng, level, n=1)``, which runs ``n`` mutation steps under
        ``jax.lax.scan`` and returns the resulting level.
    """
    mutate_size = make_mutate_change_shape_size(params, static_env_params)
    mutate_rot = make_mutate_change_shape_rotation(params, static_env_params)

    def _bind(mutation):
        # Fix the trailing arguments so every lax.switch branch has the same
        # (rng, level) signature.
        def bound(key, state):
            return mutation(key, state, params, static_env_params, ued_params)

        return bound

    # The order here must line up with the weight vector built in ``_step``.
    branches = [
        _bind(fn)
        for fn in (
            mutate_add_shape,
            mutate_add_connected_shape_proper,
            mutate_remove_joint,
            mutate_remove_shape,
            mutate_swap_role,
            mutate_add_thruster,
            mutate_remove_thruster,
            mutate_toggle_fixture,
            mutate_size,
            mutate_change_shape_location,
            mutate_rot,
        )
    ]

    def mutate_level(rng, level: EnvState, n=1):
        def _step(carry: tuple[chex.PRNGKey, EnvState], _):
            key, state = carry
            key, choice_key, apply_key = jax.random.split(key, 3)

            has_free_polygon = jnp.logical_not(state.polygon.active).sum() > 0
            has_free_circle = jnp.logical_not(state.circle.active).sum() > 0
            has_free_joint = jnp.logical_not(state.joint.active).sum() > 0
            has_free_thruster = jnp.logical_not(state.thruster.active).sum() > 0
            has_active_thruster = state.thruster.active.sum() > 0

            can_add_shape = has_free_polygon | has_free_circle
            can_add_joint = can_add_shape & has_free_joint

            # Unnormalized weights, one per entry of ``branches``; a 0.0 weight
            # means that mutation is never selected.
            weights = jnp.array(
                [
                    can_add_shape * 1.0,  # mutate_add_shape
                    can_add_joint * 1.0,  # mutate_add_connected_shape_proper
                    0.0,  # mutate_remove_joint (disabled)
                    0.0,  # mutate_remove_shape (disabled)
                    1.0,  # mutate_swap_role
                    has_free_thruster * 1.0,  # mutate_add_thruster
                    has_active_thruster * 1.0,  # mutate_remove_thruster
                    0.1,  # mutate_toggle_fixture
                    1.0,  # mutate_size
                    1.0,  # mutate_change_shape_location
                    1.0,  # mutate_rot
                ]
            )
            chosen = jax.random.choice(choice_key, jnp.arange(len(branches)), (), p=weights)
            state = jax.lax.switch(chosen, branches, apply_key, state)
            return (key, state), None

        (_, mutated), _ = jax.lax.scan(_step, (rng, level), None, length=n)
        return mutated

    return mutate_level
def make_create_eval_env():
    """Load the fixed evaluation levels and return a selector over them.

    The six levels are loaded once, when this factory is called, via
    ``load_world_state_pickle``. NOTE(review): the paths are relative. Whether
    they resolve against the CWD or a base directory depends on
    ``load_world_state_pickle``; confirm.

    Returns:
        ``_create_eval_env(rng, env_params, static_env_params, index)``. It
        returns the loaded level at position ``index`` in ``level_paths`` below
        (0 = car1 ... 5 = swingup). Selection goes through ``jax.lax.switch``,
        so ``index`` may be a traced integer.
    """
    level_paths = (
        "worlds/eval/eval_0610_car1",
        "worlds/eval/eval_0610_car2",
        "worlds/eval/eval_0628_ball_left",
        "worlds/eval/eval_0628_ball_right",
        "worlds/eval/eval_0628_hard_car_obstacle",
        "worlds/eval/eval_0628_swingup",
    )
    eval_levels = [load_world_state_pickle(path) for path in level_paths]

    def _constant(value):
        # One closure per level. A bare ``lambda: level`` inside a comprehension
        # would late-bind, and every branch would return the last level.
        return lambda: value

    branches = [_constant(level) for level in eval_levels]

    def _create_eval_env(rng, env_params, static_env_params, index):
        # rng / env_params / static_env_params are unused here; presumably kept
        # for call-site compatibility with other level-creation callbacks.
        # Per the JAX docs, lax.switch clamps an out-of-range index into
        # [0, len(branches) - 1].
        return jax.lax.switch(index, branches)

    return _create_eval_env
def make_reset_train_function_with_mutations(
    engine: PhysicsEngine, env_params: EnvParams, static_env_params: StaticEnvParams, config, make_pcg_state=True
):
    """Build ``reset(rng)``, which samples a new random Kinetix level on each call.

    UED params are derived from ``config`` once, when the factory runs.
    ``config["env_size_name"]`` is read on each call to ``reset``.

    Args:
        engine: Physics engine forwarded to ``sample_kinetix_level``.
        env_params: Environment params forwarded to the sampler.
        static_env_params: Static environment params forwarded to the sampler.
        config: Config mapping; must provide ``"env_size_name"`` plus whatever
            ``generate_ued_params_from_config`` reads.
        make_pcg_state: When truthy, the sampled level is converted with
            ``env_state_to_pcg_state`` before being returned.

    Returns:
        ``reset(rng)`` returning the (optionally converted) sampled level.
    """
    ued_params = generate_ued_params_from_config(config)

    def reset(rng):
        sampled = sample_kinetix_level(
            rng,
            engine,
            env_params,
            static_env_params,
            ued_params,
            env_size_name=config["env_size_name"],
        )
        return env_state_to_pcg_state(sampled) if make_pcg_state else sampled

    return reset
def make_vmapped_filtered_level_sampler(
    level_sampler, env_params: EnvParams, static_env_params: StaticEnvParams, config, make_pcg_state, env
):
    """Build ``reset(rng, n_samples)``, which draws a batch of filtered levels.

    Sampling is delegated to ``create_vmapped_filtered_distribution``. The
    filter settings are read from ``config`` on every call.

    Args:
        level_sampler: Sampler forwarded to the filtered distribution.
        env_params: Environment params, forwarded unchanged.
        static_env_params: Static environment params, forwarded unchanged.
        config: Config mapping providing ``"filter_levels"``,
            ``"level_filter_sample_ratio"``, ``"env_size_name"`` and
            ``"level_filter_n_steps"``, plus whatever
            ``generate_ued_params_from_config`` reads.
        make_pcg_state: When truthy, the batch is converted with
            ``env_state_to_pcg_state`` before being returned.
        env: Environment forwarded to the filtered distribution.

    Returns:
        ``reset(rng, n_samples)`` returning the (optionally converted) batch.
    """
    ued_params = generate_ued_params_from_config(config)

    def reset(rng, n_samples):
        filter_args = (
            config["filter_levels"],
            config["level_filter_sample_ratio"],
            config["env_size_name"],
            config["level_filter_n_steps"],
        )
        batch = create_vmapped_filtered_distribution(
            rng, level_sampler, env_params, static_env_params, ued_params, n_samples, env, *filter_args
        )
        if not make_pcg_state:
            return batch
        return env_state_to_pcg_state(batch)

    return reset
def make_reset_train_function_with_list_of_levels(config, levels, static_env_params, make_pcg_state=True,
                                                  is_loading_train_levels=False):
    """Build ``reset(rng)``, which picks one of a fixed list of levels uniformly at random.

    Loading path:
      * ``config["load_train_levels_legacy"]``: each name is read as JSON from
        ``BASE_DIR`` (``.json`` appended if missing), then the results are
        stacked.
      * ``is_loading_train_levels``: loaded with ``get_eval_levels`` using the
        given ``static_env_params``.
      * otherwise: ``static_env_params`` is rebuilt from
        ``config["eval_env_size_true"]`` (with ``frame_skip`` overridden) before
        calling ``get_eval_levels``. ``reset`` also uses the rebuilt value.

    Args:
        config: Config mapping (keys read as above, plus
            ``"permute_state_during_training"`` on each ``reset`` call).
        levels: Non-empty sequence of level names.
        static_env_params: Static env params. They may be replaced, see above.
        make_pcg_state: When falsy, ``reset`` passes the selected state through
            ``sample_pcg_state`` before returning it.
        is_loading_train_levels: Selects the loading path described above.

    Returns:
        ``reset(rng)`` returning the selected (optionally permuted / sampled) state.
    """
    assert len(levels) > 0, "Need to provide at least one level to train on"

    def _with_json_suffix(name):
        return name if name.endswith(".json") else name + ".json"

    if config["load_train_levels_legacy"]:
        stacked_levels = stack_list_of_pytrees(
            [get_pcg_state_from_json(os.path.join(BASE_DIR, _with_json_suffix(name))) for name in levels]
        )
    else:
        if not is_loading_train_levels:
            _, static_env_params = generate_params_from_config(
                config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
            )
        stacked_levels = get_eval_levels(levels, static_env_params)

    def reset(rng):
        rng, index_key, sample_key = jax.random.split(rng, 3)
        level_index = jax.random.randint(index_key, (), 0, len(levels))
        # Every leaf of the stacked pytree has a leading "level" axis; pick one slice.
        state = jax.tree.map(lambda leaf: leaf[level_index], stacked_levels)

        if config["permute_state_during_training"]:
            state = permute_pcg_state(rng, state, static_env_params)
        if not make_pcg_state:
            state = sample_pcg_state(sample_key, state, params=None, static_params=static_env_params)
        return state

    return reset
# Module-level list of level mutators. The ones exercised in this file are
# called as fn(rng, state, env_params, static_env_params, ued_params);
# presumably mutate_change_gravity shares that signature (TODO confirm).
# NOTE(review): this list is not the menu used by make_mutate_env. It uses
# mutate_add_connected_shape rather than mutate_add_connected_shape_proper,
# includes mutate_change_gravity, and omits the size / location / rotation
# mutators. No consumer is visible in this file; confirm which list callers
# expect.
ALL_MUTATION_FNS = [
    mutate_add_shape,
    mutate_add_connected_shape,
    mutate_remove_joint,
    mutate_swap_role,
    mutate_toggle_fixture,
    mutate_add_thruster,
    mutate_remove_thruster,
    mutate_remove_shape,
    mutate_change_gravity,
]
def test_ued():
    """Smoke test: apply a fixed sequence of mutators to an empty level.

    Uses default ``EnvParams`` / ``StaticEnvParams`` / ``UEDParams``. Each
    mutator gets its own PRNG subkey, because JAX keys should not be reused.
    The test passes if no mutator raises.
    """
    env_params = EnvParams()
    static_env_params = StaticEnvParams()
    ued_params = UEDParams()

    mutations = (
        mutate_add_shape,
        mutate_add_connected_shape,
        mutate_remove_shape,
        mutate_remove_joint,
        mutate_swap_role,
        mutate_toggle_fixture,
    )
    keys = jax.random.split(jax.random.PRNGKey(0), len(mutations))

    state = create_empty_env(env_params, static_env_params)
    for key, mutate in zip(keys, mutations):
        state = mutate(key, state, env_params, static_env_params, ued_params)
    print("Successfully did this")


if __name__ == "__main__":
    test_ued()