Spaces:
Runtime error
Runtime error
from functools import partial | |
import math | |
import os | |
import chex | |
import jax | |
import jax.numpy as jnp | |
from flax.serialization import to_state_dict | |
from jax2d.engine import ( | |
calculate_collision_matrix, | |
calc_inverse_mass_polygon, | |
calc_inverse_mass_circle, | |
calc_inverse_inertia_circle, | |
calc_inverse_inertia_polygon, | |
recalculate_mass_and_inertia, | |
select_shape, | |
PhysicsEngine, | |
) | |
from jax2d.sim_state import SimState, RigidBody, Joint, Thruster | |
from jax2d.maths import rmat | |
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams | |
from kinetix.environment.ued.distributions import ( | |
create_vmapped_filtered_distribution, | |
sample_kinetix_level, | |
) | |
from kinetix.environment.ued.mutators import ( | |
make_mutate_change_shape_rotation, | |
make_mutate_change_shape_size, | |
mutate_add_connected_shape_proper, | |
mutate_add_shape, | |
mutate_add_connected_shape, | |
mutate_change_shape_location, | |
mutate_remove_joint, | |
mutate_remove_shape, | |
mutate_swap_role, | |
mutate_toggle_fixture, | |
mutate_add_thruster, | |
mutate_remove_thruster, | |
mutate_change_gravity, | |
) | |
from kinetix.environment.ued.ued_state import UEDParams | |
from kinetix.environment.utils import permute_pcg_state | |
from kinetix.pcg.pcg import env_state_to_pcg_state, sample_pcg_state | |
from kinetix.util.config import generate_ued_params_from_config, generate_params_from_config | |
from kinetix.util.saving import get_pcg_state_from_json, load_pcg_state_pickle, load_world_state_pickle, stack_list_of_pytrees, expand_env_state | |
from flax import struct | |
from kinetix.environment.env import create_empty_env | |
from kinetix.util.learning import BASE_DIR, general_eval, get_eval_levels | |
def make_mutate_env(static_env_params: StaticEnvParams, params: EnvParams, ued_params: UEDParams):
    """Build a level-mutation function.

    Args:
        static_env_params: static environment parameters, forwarded to every mutation.
        params: environment parameters, forwarded to every mutation.
        ued_params: UED parameters, forwarded to every mutation.

    Returns:
        ``mutate_level(rng, level, n=1)``. It applies ``n`` mutations in sequence via
        ``jax.lax.scan``. At each step one mutation is sampled with weights that depend
        on the current level (see ``_mutation_weights``) and is then applied.
    """
    mutate_size = make_mutate_change_shape_size(params, static_env_params)
    mutate_rot = make_mutate_change_shape_rotation(params, static_env_params)

    def _bind(mutation):
        # lax.switch calls every branch with the same operands (rng, level), so close
        # over the fixed parameters that each mutation additionally expects.
        def bound(rng, level):
            return mutation(rng, level, params, static_env_params, ued_params)

        return bound

    # The mutation table is loop-invariant, so it is built once here rather than inside
    # the scanned step. Its order MUST match the weight vector in `_mutation_weights`.
    mutations = [
        _bind(mutation)
        for mutation in (
            mutate_add_shape,
            mutate_add_connected_shape_proper,
            mutate_remove_joint,
            mutate_remove_shape,
            mutate_swap_role,
            mutate_add_thruster,
            mutate_remove_thruster,
            mutate_toggle_fixture,
            mutate_size,
            mutate_change_shape_location,
            mutate_rot,
        )
    ]

    def _mutation_weights(level: EnvState):
        """Return relative (unnormalised) sampling weights, one per entry of ``mutations``."""
        any_polygons_free = jnp.logical_not(level.polygon.active).sum() > 0
        any_circles_free = jnp.logical_not(level.circle.active).sum() > 0
        any_joints_free = jnp.logical_not(level.joint.active).sum() > 0
        any_thrusters_free = jnp.logical_not(level.thruster.active).sum() > 0
        has_any_thrusters = level.thruster.active.sum() > 0
        can_add_shape = any_polygons_free | any_circles_free
        can_add_joint = can_add_shape & any_joints_free
        return jnp.array(
            [
                can_add_shape * 1.0,  # mutate_add_shape
                can_add_joint * 1.0,  # mutate_add_connected_shape_proper
                0.0,  # mutate_remove_joint (disabled)
                0.0,  # mutate_remove_shape (disabled)
                1.0,  # mutate_swap_role
                any_thrusters_free * 1.0,  # mutate_add_thruster
                has_any_thrusters * 1.0,  # mutate_remove_thruster
                0.1,  # mutate_toggle_fixture
                1.0,  # mutate_size
                1.0,  # mutate_change_shape_location
                1.0,  # mutate_rot
            ]
        )

    def mutate_level(rng, level: EnvState, n=1):
        def step(carry: tuple[chex.PRNGKey, EnvState], _):
            rng, level = carry
            rng, choice_rng, mutate_rng = jax.random.split(rng, 3)
            weights = _mutation_weights(level)
            # Shapes are static, so a table/weight mismatch is caught at trace time.
            assert weights.shape == (len(mutations),), "mutation table and weights are out of sync"
            index = jax.random.choice(choice_rng, jnp.arange(len(mutations)), (), p=weights)
            level = jax.lax.switch(index, mutations, mutate_rng, level)
            return (rng, level), None

        (_, level), _ = jax.lax.scan(step, (rng, level), None, length=n)
        return level

    return mutate_level
def make_create_eval_env(
    level_paths=(
        "worlds/eval/eval_0610_car1",
        "worlds/eval/eval_0610_car2",
        "worlds/eval/eval_0628_ball_left",
        "worlds/eval/eval_0628_ball_right",
        "worlds/eval/eval_0628_hard_car_obstacle",
        "worlds/eval/eval_0628_swingup",
    ),
):
    """Build a function that returns one of a fixed set of pickled evaluation levels.

    Args:
        level_paths: paths passed to ``load_world_state_pickle``. They are loaded eagerly,
            once, in order. The default is the six eval levels that used to be hard-coded.

    Returns:
        ``_create_eval_env(rng, env_params, static_env_params, index)``. It returns the level
        loaded from ``level_paths[index]``, selected with ``jax.lax.switch``, so ``index`` may
        be traced and out-of-range values are clamped. ``rng``, ``env_params`` and
        ``static_env_params`` are unused; they are kept for interface compatibility.

    Raises:
        ValueError: if ``level_paths`` is empty (``lax.switch`` needs at least one branch).
    """
    eval_levels = [load_world_state_pickle(path) for path in level_paths]
    if not eval_levels:
        raise ValueError("make_create_eval_env needs at least one level path")

    def _constant(level):
        # Give each branch its own binding. A bare `lambda: level` inside a comprehension
        # would late-bind, and every branch would return the last level.
        return lambda: level

    branches = [_constant(level) for level in eval_levels]

    def _create_eval_env(rng, env_params, static_env_params, index):
        return jax.lax.switch(index, branches)

    return _create_eval_env
def make_reset_train_function_with_mutations(
    engine: PhysicsEngine, env_params: EnvParams, static_env_params: StaticEnvParams, config, make_pcg_state=True
):
    """Return a ``reset(rng)`` closure that samples a fresh Kinetix level.

    UED parameters are derived from ``config`` once, when the closure is built.
    ``config["env_size_name"]`` is read on each call. When ``make_pcg_state`` is truthy,
    the sampled level is converted with ``env_state_to_pcg_state``; otherwise it is
    returned as produced by ``sample_kinetix_level``.
    """
    ued_params = generate_ued_params_from_config(config)

    def reset(rng):
        sampled = sample_kinetix_level(
            rng,
            engine,
            env_params,
            static_env_params,
            ued_params,
            env_size_name=config["env_size_name"],
        )
        return env_state_to_pcg_state(sampled) if make_pcg_state else sampled

    return reset
def make_vmapped_filtered_level_sampler(
    level_sampler, env_params: EnvParams, static_env_params: StaticEnvParams, config, make_pcg_state, env
):
    """Wrap ``create_vmapped_filtered_distribution`` in a ``reset(rng, n_samples)`` closure.

    UED parameters are derived from ``config`` once, when the closure is built. The
    filtering options (``filter_levels``, ``level_filter_sample_ratio``, ``env_size_name``,
    ``level_filter_n_steps``) are read from ``config`` on every call. When
    ``make_pcg_state`` is truthy, the result is converted with ``env_state_to_pcg_state``.
    """
    ued_params = generate_ued_params_from_config(config)

    def reset(rng, n_samples):
        filter_options = (
            config["filter_levels"],
            config["level_filter_sample_ratio"],
            config["env_size_name"],
            config["level_filter_n_steps"],
        )
        sampled = create_vmapped_filtered_distribution(
            rng, level_sampler, env_params, static_env_params, ued_params, n_samples, env, *filter_options
        )
        if not make_pcg_state:
            return sampled
        return env_state_to_pcg_state(sampled)

    return reset
def make_reset_train_function_with_list_of_levels(
    config, levels, static_env_params, make_pcg_state=True, is_loading_train_levels=False
):
    """Return a ``reset(rng)`` closure that picks one of ``levels`` uniformly at random.

    Levels are loaded once, up front:

    * ``config["load_train_levels_legacy"]``: each name is read as a JSON file under
      ``BASE_DIR`` (``.json`` is appended when missing) and the results are stacked.
    * otherwise, via ``get_eval_levels``. Unless ``is_loading_train_levels`` is set,
      ``static_env_params`` is first replaced by the one generated from
      ``config["eval_env_size_true"]`` merged with ``config["frame_skip"]``.

    On each call the chosen level is optionally permuted
    (``config["permute_state_during_training"]``). When ``make_pcg_state`` is falsy,
    a concrete state is sampled from it with ``sample_pcg_state``.
    """
    assert len(levels) > 0, "Need to provide at least one level to train on"

    if config["load_train_levels_legacy"]:

        def _json_path(name):
            suffix = "" if name.endswith(".json") else ".json"
            return os.path.join(BASE_DIR, name + suffix)

        stacked_levels = stack_list_of_pytrees([get_pcg_state_from_json(_json_path(name)) for name in levels])
    else:
        if not is_loading_train_levels:
            _, static_env_params = generate_params_from_config(
                config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
            )
        stacked_levels = get_eval_levels(levels, static_env_params)

    def reset(rng):
        rng, index_rng, sample_rng = jax.random.split(rng, 3)
        level_index = jax.random.randint(index_rng, (), 0, len(levels))
        state = jax.tree.map(lambda leaf: leaf[level_index], stacked_levels)
        if config["permute_state_during_training"]:
            state = permute_pcg_state(rng, state, static_env_params)
        if make_pcg_state:
            return state
        return sample_pcg_state(sample_rng, state, params=None, static_params=static_env_params)

    return reset
# The standalone mutation operators, each called as
# fn(rng, state, env_params, static_env_params, ued_params) (see test_ued).
# Not referenced elsewhere in the visible code; keep the order stable in case
# external callers index into it (NOTE(review): confirm before reordering).
ALL_MUTATION_FNS = [
    mutate_add_shape,
    mutate_add_connected_shape,
    mutate_remove_joint,
    mutate_swap_role,
    mutate_toggle_fixture,
    mutate_add_thruster,
    mutate_remove_thruster,
    mutate_remove_shape,
    mutate_change_gravity,
]
def test_ued():
    """Smoke-test a sequence of mutation operators on an empty level.

    Each mutation is applied to the output of the previous one. Each call gets its own
    independent PRNG key, so the operators' random choices are not correlated.
    """
    env_params = EnvParams()
    static_env_params = StaticEnvParams()
    ued_params = UEDParams()
    mutations = (
        mutate_add_shape,
        mutate_add_connected_shape,
        mutate_remove_shape,
        mutate_remove_joint,
        mutate_swap_role,
        mutate_toggle_fixture,
    )
    rng = jax.random.PRNGKey(0)
    state = create_empty_env(env_params, static_env_params)
    for mutation, key in zip(mutations, jax.random.split(rng, len(mutations))):
        state = mutation(key, state, env_params, static_env_params, ued_params)
    print("Successfully did this")
if __name__ == "__main__":
    # Running this module directly executes the mutation smoke test.
    test_ued()