Spaces:
Runtime error
Runtime error
from functools import partial | |
import math | |
import chex | |
import jax | |
import jax.numpy as jnp | |
from flax.serialization import to_state_dict | |
from jax2d.engine import ( | |
PhysicsEngine, | |
calculate_collision_matrix, | |
calc_inverse_mass_polygon, | |
calc_inverse_mass_circle, | |
calc_inverse_inertia_circle, | |
calc_inverse_inertia_polygon, | |
recalculate_mass_and_inertia, | |
select_shape, | |
) | |
from jax2d.sim_state import SimState, RigidBody, Joint, Thruster | |
from jax2d.maths import rmat | |
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams | |
from kinetix.environment.ued.ued_state import UEDParams | |
from kinetix.environment.ued.util import ( | |
count_roles, | |
is_space_for_joint, | |
make_velocities_zero, | |
sample_dimensions, | |
random_position_on_polygon, | |
random_position_on_circle, | |
get_role, | |
is_space_for_shape, | |
are_there_shapes_present, | |
) | |
from kinetix.util.saving import load_world_state_pickle | |
from flax import struct | |
from kinetix.environment.env import create_empty_env | |
from kinetix.environment.ued.util import make_do_dummy_step | |
def mutate_add_shape(
    rng,
    state: EnvState,
    params: EnvParams,
    static_env_params: StaticEnvParams,
    ued_params: UEDParams,
    force_no_fixate: bool = False,
):
    """Mutate ``state`` by adding one new, unconnected rectangle or circle.

    The shape type is sampled among the types that still have a free slot, its
    role comes from ``get_role`` and its maximum size is scaled by
    ``ued_params.goal_body_size_factor`` for roles 1 and 2. The shape may be
    fixated (inverse mass and inertia set to zero) with a probability that
    grows with its mass, capped at ``ued_params.fixate_chance_max``.

    Args:
        rng: JAX PRNG key.
        state: Environment state to mutate.
        params: Environment parameters (``pixels_per_unit`` is read).
        static_env_params: Static parameters (slot counts, screen size).
        ued_params: UED sampling parameters.
        force_no_fixate: If true, the new shape is never fixated. May be a
            Python bool or a JAX boolean.

    Returns:
        The mutated state, or ``state`` unchanged when ``is_space_for_shape``
        reports no free slot.
    """

    def do_dummy(rng, state):
        return state

    def do_add(rng, state):
        rng, _rng = jax.random.split(rng)
        _rngs = jax.random.split(_rng, 9)

        # Choose rect vs circle among the shape types that still have a free slot.
        space_for_new_rect = state.polygon.active.astype(int).sum() < static_env_params.num_polygons
        space_for_new_circle = state.circle.active.astype(int).sum() < static_env_params.num_circles
        is_rect_p = jnp.array([space_for_new_rect * 1.0, space_for_new_circle * 1.0])
        is_rect = jax.random.choice(_rngs[0], jnp.array([True, False], dtype=bool), p=is_rect_p)

        # argmin over the boolean `active` mask gives the first inactive slot.
        rect_index = jnp.argmin(state.polygon.active)
        circle_index = jnp.argmin(state.circle.active)

        shape_role = get_role(_rngs[1], state, static_env_params)
        max_shape_size = (
            jnp.array([1.0, ued_params.goal_body_size_factor, ued_params.goal_body_size_factor, 1.0])[shape_role]
            * ued_params.max_shape_size
        )
        vertices, half_dimensions, radius = sample_dimensions(
            _rngs[2],
            static_env_params,
            is_rect,
            ued_params,
            max_shape_size=max_shape_size,
        )
        n_vertices = jax.lax.select(ued_params.generate_triangles, jax.random.choice(_rngs[3], jnp.array([3, 4])), 4)

        # Conservative bounding radius so the shape is sampled fully on screen.
        largest = jnp.max(jnp.array([half_dimensions[0] * jnp.sqrt(2), half_dimensions[1] * jnp.sqrt(2), radius]))
        screen_dim_world = (
            static_env_params.screen_dim[0] / params.pixels_per_unit,
            static_env_params.screen_dim[1] / params.pixels_per_unit,
        )
        min_x = largest
        max_x = screen_dim_world[0] - largest
        min_y = largest + 0.4
        max_y = screen_dim_world[1] - largest

        def _full_bounds():
            return min_x, max_x, min_y, max_y

        def _role_side_bounds():
            # Role 1 is restricted to the left half of the screen, role 2 to the right half.
            return jax.lax.switch(
                shape_role,
                [
                    (lambda: (min_x, max_x, min_y, max_y)),
                    (lambda: (min_x, max_x - screen_dim_world[0] / 2, min_y, max_y)),
                    (lambda: (min_x + screen_dim_world[0] / 2, max_x, min_y, max_y)),
                    (lambda: (min_x, max_x, min_y, max_y)),
                ],
            )

        # Use fresh names for the chosen bounds instead of rebinding the ones the closures read.
        lo_x, hi_x, lo_y, hi_y = jax.lax.cond(
            jax.random.uniform(_rngs[4], shape=()) < ued_params.goal_body_opposide_side_chance,
            _role_side_bounds,
            _full_bounds,
        )

        position = jax.random.uniform(_rngs[5], shape=(2,)) * jnp.array(
            [
                hi_x - lo_x,
                hi_y - lo_y,
            ]
        ) + jnp.array([lo_x, lo_y])
        rotation = jax.random.uniform(_rngs[6], shape=()) * 2 * math.pi
        velocity = jnp.array([0.0, 0.0])
        angular_velocity = 0.0
        density = 1.0

        inverse_mass = jax.lax.select(
            is_rect,
            calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)[0],
            calc_inverse_mass_circle(radius, density),
        )
        inverse_inertia = jax.lax.select(
            is_rect,
            calc_inverse_inertia_polygon(vertices, n_vertices, static_env_params, density),
            calc_inverse_inertia_circle(radius, density),
        )

        # Heavier shapes (larger 1 / inverse_mass) are more likely to be fixated.
        fixate_chance = ued_params.fixate_chance_min + (1.0 / inverse_mass) * ued_params.fixate_chance_scale
        fixate_chance = jnp.minimum(fixate_chance, ued_params.fixate_chance_max)
        is_fixated = jax.random.uniform(_rngs[7], shape=()) < fixate_chance
        # Use logical ops: `~` on a Python bool is integer inversion (~False == -1),
        # is deprecated, and would turn `is_fixated` into an int array.
        is_fixated = jnp.logical_and(is_fixated, jnp.logical_not(force_no_fixate))
        inverse_mass = jnp.where(is_fixated, 0.0, inverse_mass)
        inverse_inertia = jnp.where(is_fixated, 0.0, inverse_inertia)

        # Bias fixated shapes towards the bottom half of the screen; an extra
        # `fixate_shape_bottom_bias_special_role` is added for non-zero roles.
        fixate_shape_bottom_bias = (
            ued_params.fixate_shape_bottom_bias + ued_params.fixate_shape_bottom_bias_special_role * (shape_role != 0)
        )
        is_forcing_bottom = jax.random.uniform(_rngs[8]) < fixate_shape_bottom_bias
        half_screen_height = screen_dim_world[1] / 2.0
        position = jax.lax.select(
            is_fixated & is_forcing_bottom & (position[1] >= half_screen_height),
            position.at[1].add(-half_screen_height),
            position,
        )

        # This could be either a rect or a circle; only the selected slot type is written.
        new_rigid_body = RigidBody(
            position=position,
            velocity=velocity,
            inverse_mass=inverse_mass,
            inverse_inertia=inverse_inertia,
            rotation=rotation,
            angular_velocity=angular_velocity,
            radius=radius,
            active=True,
            friction=1.0,
            vertices=vertices,
            n_vertices=n_vertices,
            collision_mode=1,
            restitution=0.0,
        )
        state = state.replace(
            polygon=jax.tree.map(
                lambda x, y: jax.lax.select(is_rect, y.at[rect_index].set(x), y), new_rigid_body, state.polygon
            ),
            circle=jax.tree.map(
                lambda x, y: jax.lax.select(jnp.logical_not(is_rect), y.at[circle_index].set(x), y),
                new_rigid_body,
                state.circle,
            ),
            polygon_shape_roles=jax.lax.select(
                is_rect,
                state.polygon_shape_roles.at[rect_index].set(shape_role),
                state.polygon_shape_roles,
            ),
            circle_shape_roles=jax.lax.select(
                jnp.logical_not(is_rect),
                state.circle_shape_roles.at[circle_index].set(shape_role),
                state.circle_shape_roles,
            ),
        )
        return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)

    return jax.lax.cond(is_space_for_shape(state), do_add, do_dummy, rng, state)
def _count_shape_connections(state, static_env_params):
    """Count the active joints attached to each polygon slot and each circle slot.

    Joint endpoints use a unified index: ``[0, num_polygons)`` are polygons and
    ``num_polygons + i`` is circle ``i``.

    Returns:
        ``(rect_connections, circle_connections)`` float arrays of length
        ``num_polygons`` and ``num_circles``.
    """
    num_polygons = static_env_params.num_polygons
    rect_connections = jnp.zeros(num_polygons)
    circle_connections = jnp.zeros(static_env_params.num_circles)
    for endpoint in (state.joint.a_index, state.joint.b_index):
        weight = jnp.ones(static_env_params.num_joints) * state.joint.active
        rect_connections = rect_connections.at[endpoint].add(weight * (endpoint < num_polygons))
        # For polygon endpoints the shifted index is negative, but those updates carry zero weight.
        circle_connections = circle_connections.at[endpoint - num_polygons].add(weight * (endpoint >= num_polygons))
    return rect_connections, circle_connections


def _are_both_shapes_showing(state, static_env_params, ued_params, idx1, idx2, local_pos1, local_pos2):
    """Heuristically check that two jointed shapes can both stay visible.

    Assuming either shape may rotate fully around the joint, this rejects the
    degenerate pattern of one connected shape sitting entirely inside the other.
    ``idx1``/``idx2`` are unified shape indices (polygons first, then circles)
    and ``local_pos1``/``local_pos2`` are the joint anchors in the local frame of
    the matching shape. Circle-circle pairs always yield False.
    """
    num_polygons = static_env_params.num_polygons

    def _half_extents(r_id):
        # Half width/height of the polygon's axis-aligned local bounding box.
        verts = state.polygon.vertices[r_id]
        half_width = (jnp.max(verts[:, 0]) - jnp.min(verts[:, 0])) / 2.0
        half_height = (jnp.max(verts[:, 1]) - jnp.min(verts[:, 1])) / 2.0
        return half_width, half_height

    def _min_rect_dist(r_id, local_pos):
        # Smaller of the x/y margins between the anchor and the box edges (box centred on the local origin).
        half_width, half_height = _half_extents(r_id)
        return jnp.minimum(half_width - jnp.abs(local_pos[0]), half_height - jnp.abs(local_pos[1]))

    def _max_rect_dist(r_id, local_pos):
        # Distance from the anchor to the farthest corner of the box.
        half_width, half_height = _half_extents(r_id)
        dist_x = jnp.maximum(jnp.abs(half_width - local_pos[0]), jnp.abs(-half_width - local_pos[0]))
        dist_y = jnp.maximum(jnp.abs(half_height - local_pos[1]), jnp.abs(-half_height - local_pos[1]))
        return jnp.sqrt(dist_x * dist_x + dist_y * dist_y)

    def _is_small_shape_showing(small_idx, big_idx, small_local_pos, big_local_pos):
        small_is_poly = small_idx < num_polygons
        big_is_poly = big_idx < num_polygons
        # (circle, circle): never counted as showing.
        cc_result = False
        # (small circle, big rect): the circle must reach past the rect's nearest edge.
        cr_result = (
            _min_rect_dist(big_idx, big_local_pos) + ued_params.connect_visibility_min
            < state.circle.radius[small_idx - num_polygons]
        )
        # (small rect, big circle): the rect's far corner must reach past the circle.
        rc_result = (
            _max_rect_dist(small_idx, small_local_pos)
            > state.circle.radius[big_idx - num_polygons] + ued_params.connect_visibility_min
        )
        # (rect, rect)
        rr_result = (
            _max_rect_dist(small_idx, small_local_pos)
            > _min_rect_dist(big_idx, big_local_pos) + ued_params.connect_visibility_min
        )
        return jax.lax.select(
            small_is_poly,
            jax.lax.select(big_is_poly, rr_result, rc_result),
            jax.lax.select(big_is_poly, cr_result, cc_result),
        )

    return _is_small_shape_showing(idx1, idx2, local_pos1, local_pos2) & _is_small_shape_showing(
        idx2, idx1, local_pos2, local_pos1
    )


def mutate_add_connected_shape(
    rng,
    state: EnvState,
    params: EnvParams,
    static_env_params: StaticEnvParams,
    ued_params: UEDParams,
    force_rjoint: bool = False,
):
    """Add a new shape joined to a randomly chosen existing shape.

    Target weights start from the active, non-static shapes. Fixated targets are
    reweighted by ``connect_to_fixated_prob_coeff``. Targets with more existing
    joints are down-weighted: polygon weight reaches zero at two connections and
    circle weight at one. If all weights of the chosen type are zero, the target
    is drawn uniformly over active shapes of that type. The new shape is placed
    so that its joint anchor coincides with the target's anchor. It is joined by
    either a rotating joint (possibly motorised, possibly with limits) or a
    fixed joint.

    Args:
        rng: JAX PRNG key.
        state: Environment state to mutate.
        params: Environment parameters (not used by this mutator).
        static_env_params: Static parameters (slot counts, static polygon count).
        ued_params: UED sampling parameters.
        force_rjoint: If true, always create a rotating (non-fixed) joint.

    Returns:
        ``(new_state, valid)``. ``valid`` is false when one connected shape
        would (heuristically) hide the other. Returns ``(state, False)`` when
        there is no free shape/joint slot or no shape to connect to.
    """

    def do_dummy(rng, state):
        return state, False

    def do_add(rng, state):
        rng, _rng = jax.random.split(rng)
        _rngs = jax.random.split(_rng, 21)
        num_polygons = static_env_params.num_polygons

        # ---- Target selection weights ----
        # Candidate targets are active shapes; static fixated polygons are excluded.
        p_rect = state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False)
        p_rect = p_rect.astype(jnp.float32)
        p_circle = state.circle.active.astype(jnp.float32)
        # Fixated shapes (inverse_mass == 0) are reweighted by connect_to_fixated_prob_coeff.
        p_rect *= (state.polygon.inverse_mass == 0) * ued_params.connect_to_fixated_prob_coeff + (
            state.polygon.inverse_mass != 0
        ) * 1.0
        p_circle *= (state.circle.inverse_mass == 0) * ued_params.connect_to_fixated_prob_coeff + (
            state.circle.inverse_mass != 0
        ) * 1.0

        # Bias based on the number of existing connections.
        rect_connections, circle_connections = _count_shape_connections(state, static_env_params)
        # Rectangles can have up to 2 connections.
        p_rect *= (-rect_connections + 2.0) / 2.0
        p_rect = jnp.maximum(p_rect, 0.0)
        # Circles can have 1 connection.
        p_circle *= circle_connections == 0

        # ---- Target type and new-shape type ----
        # To sample a target rect/circle, at least one must exist.
        target_rect_p = jnp.array(
            [
                (state.polygon.active.astype(int).sum() > static_env_params.num_static_fixated_polys) * 1.0,
                (state.circle.active.astype(int).sum() > 0) * 1.0,
            ]
        )
        # Don't connect to a circle if no connection-free one exists.
        target_rect_p = target_rect_p.at[1].mul(p_circle.sum() > 0)

        space_for_new_rect = state.polygon.active.astype(int).sum() < num_polygons
        space_for_new_circle = state.circle.active.astype(int).sum() < static_env_params.num_circles
        # With no free polygon slot the new shape will be a circle; the target is then
        # forced to be a rect (presumably so that two circles are never joined).
        is_target_rect = jax.random.choice(_rngs[0], jnp.array([True, False], dtype=bool), p=target_rect_p) | (
            ~space_for_new_rect
        )
        is_rect_p = jnp.array([space_for_new_rect * 1.0, space_for_new_circle * 1.0])
        # A circle target gets a rect whenever there is room for one.
        is_rect = jax.random.choice(_rngs[1], jnp.array([True, False], dtype=bool), p=is_rect_p) | (
            ~is_target_rect & space_for_new_rect
        )
        # First inactive slot of the chosen type.
        shape_index = jax.lax.select(
            is_rect,
            jnp.argmin(state.polygon.active),
            jnp.argmin(state.circle.active),
        )

        # ---- New body geometry and mass ----
        vertices, _half_dimensions, radius = sample_dimensions(
            _rngs[2], static_env_params, is_rect, ued_params, max_shape_size=ued_params.max_shape_size
        )
        n_vertices = jax.lax.select(ued_params.generate_triangles, jax.random.choice(_rngs[3], jnp.array([3, 4])), 4)
        rotation = jax.random.uniform(_rngs[4], shape=()) * 2 * math.pi
        velocity = jnp.array([0.0, 0.0])
        angular_velocity = 0.0
        density = 1.0
        inverse_mass = jax.lax.select(
            is_rect,
            calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)[0],
            calc_inverse_mass_circle(radius, density),
        )
        inverse_inertia = jax.lax.select(
            is_rect,
            calc_inverse_inertia_polygon(vertices, n_vertices, static_env_params, density),
            calc_inverse_inertia_circle(radius, density),
        )

        # ---- Joint type and anchor points ----
        current_num_rjoints = (jnp.logical_not(state.joint.is_fixed_joint) * state.joint.active).sum()
        is_rjoint = jnp.logical_or(
            jnp.logical_or(jax.random.uniform(_rngs[5]) < 0.5, force_rjoint),
            current_num_rjoints < ued_params.min_rjoints_bias,
        )
        joint_index = jnp.argmin(state.joint.active)

        local_joint_position_rect = random_position_on_polygon(_rngs[6], vertices, n_vertices, static_env_params)
        local_joint_position_circle = random_position_on_circle(_rngs[7], radius, on_centre_chance=1.0)
        local_joint_position = jax.lax.select(is_rect, local_joint_position_rect, local_joint_position_circle)

        # If every weight is zero, fall back to a uniform choice over active shapes.
        p_rect = jax.lax.select(p_rect.sum() == 0, state.polygon.active.astype(jnp.float32), p_rect)
        p_circle = jax.lax.select(p_circle.sum() == 0, state.circle.active.astype(jnp.float32), p_circle)
        target_index = jax.lax.select(
            is_target_rect,
            jax.random.choice(_rngs[8], jnp.arange(num_polygons), p=p_rect),
            jax.random.choice(_rngs[9], jnp.arange(static_env_params.num_circles), p=p_circle),
        )
        unified_target_index = target_index + jnp.logical_not(is_target_rect) * num_polygons
        target_shape = select_shape(state, unified_target_index, static_env_params)

        target_joint_position_rect = random_position_on_polygon(
            _rngs[10], state.polygon.vertices[target_index], state.polygon.n_vertices[target_index], static_env_params
        )
        target_joint_position_circle = random_position_on_circle(
            _rngs[11], state.circle.radius[target_index], on_centre_chance=1.0
        )
        target_joint_position = jax.lax.select(is_target_rect, target_joint_position_rect, target_joint_position_circle)

        # ---- Place the new shape ----
        # The target stays fixed; choose `position` so that the new shape's local
        # anchor maps onto the target's anchor in world space.
        global_joint_pos = target_shape.position + jnp.matmul(rmat(target_shape.rotation), target_joint_position)
        position = global_joint_pos - jnp.matmul(rmat(rotation), local_joint_position)

        # Re-centre polygon vertices on `pos_diff` (presumably the polygon's centroid
        # offset). `pos_diff` is in the local frame, so the world position must move
        # by the *rotated* offset for the anchor to stay on `global_joint_pos`.
        _, pos_diff = calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)
        position = jax.lax.select(is_rect, position + jnp.matmul(rmat(rotation), pos_diff), position)
        local_joint_position = jax.lax.select(is_rect, local_joint_position - pos_diff, local_joint_position)
        vertices = jax.lax.select(is_rect, vertices - pos_diff[None], vertices)

        # ---- Role of the new shape ----
        target_role = jax.lax.select(
            is_target_rect, state.polygon_shape_roles[target_index], state.circle_shape_roles[target_index]
        )
        # Roles 1 and 2 must never be connected. Role-0 targets allow every role.
        # Role-1/2 targets forbid the complementary role (3 - role, i.e. 1 <-> 2) and role 3.
        # Role-3 targets forbid roles 1 and 2.
        p = jnp.array([1.0, 1.0, 1.0, 1.0])
        p = jax.lax.select(
            target_role == 0,
            p,
            jax.lax.select(
                target_role <= 2,
                p.at[3 - target_role].set(False).at[3].set(False),
                (p.at[2].set(False).at[1].set(False)),
            ),
        )
        shape_role = get_role(_rngs[12], state, static_env_params, initial_p=p)

        # This could be either a rect or a circle; only the selected slot type is written.
        new_rigid_body = RigidBody(
            position=position,
            velocity=velocity,
            inverse_mass=inverse_mass,
            inverse_inertia=inverse_inertia,
            rotation=rotation,
            angular_velocity=angular_velocity,
            radius=radius,
            active=True,
            friction=1.0,
            vertices=vertices,
            n_vertices=n_vertices,
            collision_mode=1,
            restitution=0.0,
        )

        # ---- Joint ----
        # Order the endpoints so that a_index < b_index, swapping anchors and bodies with them.
        a_index = shape_index + (1 - is_rect) * num_polygons
        b_index = target_index + (1 - is_target_rect) * num_polygons
        should_swap = a_index > b_index
        a_index, b_index, a_relative_pos, b_relative_pos, shape_a, shape_b = jax.lax.cond(
            should_swap,
            lambda x: (x[1], x[0], x[3], x[2], x[5], x[4]),  # pairwise swap
            lambda x: x,
            (a_index, b_index, local_joint_position, target_joint_position, new_rigid_body, target_shape),
        )

        motor_on = jax.random.uniform(_rngs[13], shape=()) < ued_params.motor_on_chance
        joint_colour = jax.random.randint(_rngs[14], shape=(), minval=0, maxval=static_env_params.num_motor_bindings)
        joint_rotation = shape_b.rotation - shape_a.rotation
        motor_speed = jax.random.uniform(
            _rngs[15], shape=(), minval=ued_params.motor_min_speed, maxval=ued_params.motor_max_speed
        )
        motor_power = jax.random.uniform(
            _rngs[16], shape=(), minval=ued_params.motor_min_power, maxval=ued_params.motor_max_power
        )
        wheel_power = jax.random.uniform(
            _rngs[20], shape=(), minval=ued_params.motor_min_power, maxval=ued_params.wheel_max_power
        )
        # High-powered wheels break the physics engine - this is a temporary fix.
        motor_power = jax.lax.select(is_rect & is_target_rect, motor_power, wheel_power)
        motor_has_joint_limits = jax.random.uniform(_rngs[17], shape=()) < ued_params.joint_limit_chance
        motor_has_joint_limits &= is_rect & is_target_rect
        joint_limit_min = (
            jax.random.uniform(_rngs[18], shape=(), minval=-ued_params.joint_limit_max, maxval=0.0)
            * motor_has_joint_limits
        )
        joint_limit_max = (
            jax.random.uniform(_rngs[19], shape=(), minval=0.0, maxval=ued_params.joint_limit_max)
            * motor_has_joint_limits
        )

        rjoint = Joint(
            a_index=a_index,
            b_index=b_index,
            a_relative_pos=a_relative_pos,
            b_relative_pos=b_relative_pos,
            global_position=global_joint_pos,
            active=True,
            motor_speed=motor_speed,
            motor_power=motor_power,
            motor_on=motor_on,
            motor_has_joint_limits=motor_has_joint_limits,
            min_rotation=joint_limit_min,
            max_rotation=joint_limit_max,
            is_fixed_joint=False,
            rotation=0.0,
            acc_impulse=jnp.zeros((2,), dtype=jnp.float32),
            acc_r_impulse=jnp.zeros((), dtype=jnp.float32),
        )
        fjoint = Joint(
            a_index=a_index,
            b_index=b_index,
            a_relative_pos=a_relative_pos,
            b_relative_pos=b_relative_pos,
            global_position=global_joint_pos,
            active=True,
            rotation=joint_rotation,
            acc_impulse=jnp.zeros((2,), dtype=jnp.float32),
            acc_r_impulse=jnp.zeros((), dtype=jnp.float32),
            is_fixed_joint=True,
            motor_has_joint_limits=False,
            min_rotation=0.0,
            max_rotation=0.0,
            motor_on=False,
            motor_power=0.0,
            motor_speed=0.0,
        )

        state = state.replace(
            polygon=jax.tree.map(
                lambda x, y: jax.lax.select(is_rect, y.at[shape_index].set(x), y), new_rigid_body, state.polygon
            ),
            circle=jax.tree.map(
                lambda x, y: jax.lax.select(jnp.logical_not(is_rect), y.at[shape_index].set(x), y),
                new_rigid_body,
                state.circle,
            ),
            joint=jax.tree.map(
                lambda rj, fj, y: jax.lax.select(is_rjoint, y.at[joint_index].set(rj), y.at[joint_index].set(fj)),
                rjoint,
                fjoint,
                state.joint,
            ),
            polygon_shape_roles=jax.lax.select(
                is_rect,
                state.polygon_shape_roles.at[shape_index].set(shape_role),
                state.polygon_shape_roles,
            ),
            circle_shape_roles=jax.lax.select(
                jnp.logical_not(is_rect),
                state.circle_shape_roles.at[shape_index].set(shape_role),
                state.circle_shape_roles,
            ),
            motor_bindings=state.motor_bindings.at[joint_index].set(joint_colour),
        )
        # The new joint changes which shapes may collide.
        state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint))
        state = recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)

        # Reject the degenerate case of one connected shape hidden inside the other.
        # Use the post-swap endpoints so that each index is paired with its own anchor.
        valid = _are_both_shapes_showing(
            state, static_env_params, ued_params, a_index, b_index, a_relative_pos, b_relative_pos
        )
        return state, valid

    # To add a connected shape we need an existing shape, a free shape slot and a free joint slot.
    return jax.lax.cond(
        is_space_for_shape(state) & are_there_shapes_present(state, static_env_params) & is_space_for_joint(state),
        do_add,
        do_dummy,
        rng,
        state,
    )
def mutate_add_connected_shape_proper(
    rng,
    state: EnvState,
    params: EnvParams,
    static_env_params: StaticEnvParams,
    ued_params: UEDParams,
    force_rjoint: bool = False,
):
    """Run ``mutate_add_connected_shape`` and return only the resulting state.

    The validity flag produced alongside the state is discarded, so this
    wrapper maps a state to a state.
    """
    new_state, _valid = mutate_add_connected_shape(
        rng,
        state,
        params,
        static_env_params,
        ued_params,
        force_rjoint=force_rjoint,
    )
    return new_state
def mutate_remove_shape(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Deactivate one uniformly chosen removable shape, plus its joints and thrusters.

    Any active shape except the first ``num_static_fixated_polys`` polygons may
    be chosen. Joints whose ``a_index`` or ``b_index`` equals the removed shape's
    unified index, and thrusters whose ``object_index`` equals it, are
    deactivated, and the collision matrix is rebuilt. Returns ``state``
    unchanged when no shape is removable.
    """
    num_polygons = static_env_params.num_polygons
    removable = jnp.concatenate([state.polygon.active, state.circle.active])
    removable = removable.at[: static_env_params.num_static_fixated_polys].set(False)

    def _no_op(rng, state):
        return state

    def _remove(rng, state: EnvState):
        _, subkey = jax.random.split(rng)
        pick_key = jax.random.split(subkey, 2)[0]
        weights = removable.astype(jnp.float32)
        victim = jax.random.choice(pick_key, jnp.arange(removable.shape[0]), p=weights)

        victim_is_rect = victim < num_polygons
        victim_is_circle = jnp.logical_not(victim_is_rect)
        polygon_active = jax.lax.select(
            victim_is_rect, state.polygon.active.at[victim].set(False), state.polygon.active
        )
        circle_active = jax.lax.select(
            victim_is_circle,
            state.circle.active.at[victim - num_polygons].set(False),
            state.circle.active,
        )

        # Anything referencing the removed shape's unified index is switched off too.
        touches_victim = (state.joint.a_index == victim) | (state.joint.b_index == victim)
        mounted_on_victim = state.thruster.object_index == victim
        joint_active = jnp.where(touches_victim, False, state.joint.active)
        thruster_active = jnp.where(mounted_on_victim, False, state.thruster.active)

        state = state.replace(
            polygon=state.polygon.replace(active=polygon_active),
            circle=state.circle.replace(active=circle_active),
            joint=state.joint.replace(active=joint_active),
            thruster=state.thruster.replace(active=thruster_active),
        )
        return state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint))

    return jax.lax.cond(removable.sum() > 0, _remove, _no_op, rng, state)
def mutate_remove_joint(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Deactivate one uniformly chosen active joint and rebuild the collision matrix.

    Returns ``state`` unchanged when no joint is active.
    """
    joint_active = state.joint.active

    def _no_op(rng, state):
        return state

    def _remove(rng, state):
        _, subkey = jax.random.split(rng)
        pick_key = jax.random.split(subkey, 2)[0]
        weights = joint_active.astype(jnp.float32)
        chosen = jax.random.choice(pick_key, jnp.arange(joint_active.shape[0]), p=weights)
        joints = state.joint.replace(active=state.joint.active.at[chosen].set(False))
        # The collision matrix depends on the joints, so recompute it from the updated set.
        return state.replace(
            joint=joints,
            collision_matrix=calculate_collision_matrix(static_env_params, joints),
        )

    return jax.lax.cond(joint_active.sum() > 0, _remove, _no_op, rng, state)
def mutate_swap_role(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Swap the roles of two non-static active shapes that carry different roles.

    Only runs when ``count_roles`` (ignoring static polygons) reports more than
    one of the four roles in use; otherwise ``state`` is returned unchanged.
    """
    num_polygons = static_env_params.num_polygons
    num_shapes = num_polygons + static_env_params.num_circles

    def _count(*args):
        return count_roles(*args, include_static_polys=False)

    role_counts = jax.vmap(_count, (None, None, 0))(state, static_env_params, jnp.arange(4))
    has_several_roles = (role_counts > 0).sum() > 1

    def _no_op(rng, state):
        return state

    def _assign_role(state, idx, role):
        # `idx` is a unified index: polygons first, then circles.
        is_rect = idx < num_polygons
        return state.replace(
            polygon_shape_roles=jax.lax.select(
                is_rect, state.polygon_shape_roles.at[idx].set(role), state.polygon_shape_roles
            ),
            circle_shape_roles=jax.lax.select(
                jnp.logical_not(is_rect),
                state.circle_shape_roles.at[idx - num_polygons].set(role),
                state.circle_shape_roles,
            ),
        )

    def _swap(rng, state):
        _, subkey = jax.random.split(rng)
        key_first, key_second = jax.random.split(subkey, 2)
        roles = jnp.concatenate([state.polygon_shape_roles, state.circle_shape_roles])
        weights = (
            jnp.concatenate([state.polygon.active, state.circle.active])
            .astype(jnp.float32)
            .at[: static_env_params.num_static_fixated_polys]
            .set(0.0)
        )
        first = jax.random.choice(key_first, jnp.arange(num_shapes), p=weights)
        first_role = roles[first]
        # The second shape must carry a different role from the first.
        weights = jnp.where(roles == first_role, 0.0, weights)
        second = jax.random.choice(key_second, jnp.arange(num_shapes), p=weights)
        second_role = roles[second]

        state = _assign_role(state, first, second_role)
        return _assign_role(state, second, first_role)

    return jax.lax.cond(has_several_roles, _swap, _no_op, rng, state)
def mutate_toggle_fixture(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Toggle whether one randomly chosen non-static active shape is fixated.

    A shape is considered fixated when its inverse inertia is 0. A fixated
    shape gets inverse mass and inverse inertia 1.0; a free shape gets 0.0.
    ``recalculate_mass_and_inertia`` is then applied to the whole state.
    Returns ``state`` unchanged when no shape can be toggled.
    """
    num_polygons = static_env_params.num_polygons
    toggleable = (
        jnp.concatenate([state.polygon.active, state.circle.active])
        .at[: static_env_params.num_static_fixated_polys]
        .set(False)
    )

    def _no_op(rng, state):
        return state

    def _toggle(rng, state: EnvState):
        _, subkey = jax.random.split(rng)
        pick_key = jax.random.split(subkey, 2)[0]
        weights = toggleable.astype(jnp.float32)
        idx = jax.random.choice(pick_key, jnp.arange(toggleable.shape[0]), p=weights)
        is_rect = idx < num_polygons
        is_circle = jnp.logical_not(is_rect)
        circle_idx = idx - num_polygons

        currently_fixed = (
            jax.lax.select(
                is_rect,
                state.polygon.inverse_inertia[idx],
                state.circle.inverse_inertia[circle_idx],
            )
            == 0.0
        )
        # 1.0 un-fixates the shape (before recalculation); 0.0 fixates it.
        new_value = currently_fixed * 1.0

        def _write(values, flag, index):
            return jax.lax.select(flag, values.at[index].set(new_value), values)

        state = state.replace(
            polygon=state.polygon.replace(
                inverse_inertia=_write(state.polygon.inverse_inertia, is_rect, idx),
                inverse_mass=_write(state.polygon.inverse_mass, is_rect, idx),
            ),
            circle=state.circle.replace(
                inverse_inertia=_write(state.circle.inverse_inertia, is_circle, circle_idx),
                inverse_mass=_write(state.circle.inverse_mass, is_circle, circle_idx),
            ),
        )
        return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)

    return jax.lax.cond(toggleable.sum() > 0, _toggle, _no_op, rng, state)
def mutate_add_thruster(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Attach a new thruster to a randomly chosen active, non-fixated shape.

    The mount point comes from ``random_position_on_polygon`` or
    ``random_position_on_circle`` (with ``on_centre_chance=0.0``). With
    probability ``ued_params.thruster_align_com_prob`` the thruster is aimed
    along +/- its mount position (sign chosen at random); otherwise its angle is
    uniform in [0, 2*pi). Its power is the shape's mass times
    ``ued_params.thruster_power_multiplier`` times a factor drawn from [0.5, 2.0).

    Returns ``state`` unchanged when no eligible shape or no free thruster slot exists.
    """
    is_fixated = jnp.concatenate([state.polygon.inverse_mass == 0, state.circle.inverse_mass == 0])
    is_active = jnp.concatenate([state.polygon.active, state.circle.active])
    # Only active shapes that can move may receive a thruster.
    can_add_mask = jnp.logical_and(is_active, jnp.logical_not(is_fixated))
    has_free_thruster_slot = jnp.logical_not(state.thruster.active).sum() > 0

    def dummy(rng, state):
        return state

    def do_add(rng, state: EnvState):
        rng, _rng = jax.random.split(rng)
        _rngs = jax.random.split(_rng, 10)

        p = can_add_mask.astype(jnp.float32)
        shape_index = jax.random.choice(_rngs[0], jnp.arange(can_add_mask.shape[0]), p=p)
        thruster_idx = jnp.argmin(state.thruster.active)  # first inactive thruster slot
        shape = select_shape(state, shape_index, static_env_params)

        position_to_add_thruster = jax.lax.select(
            shape_index < static_env_params.num_polygons,
            random_position_on_polygon(_rngs[1], shape.vertices, shape.n_vertices, static_env_params),
            random_position_on_circle(_rngs[2], shape.radius, on_centre_chance=0.0),
        )

        # Direction is +/- the mount position, with the sign chosen at random.
        direction_to_com = ((jax.random.uniform(_rngs[3]) > 0.5) * 2 - 1) * position_to_add_thruster
        # A mount exactly at the local origin has no direction; default to +x.
        direction_to_com = jax.lax.select(
            jnp.linalg.norm(direction_to_com) == 0.0, jnp.array([1.0, 0.0]), direction_to_com
        )
        thruster_angle = jax.lax.select(
            jax.random.uniform(_rngs[4]) < ued_params.thruster_align_com_prob,
            jnp.arctan2(direction_to_com[1], direction_to_com[0]),
            jax.random.uniform(_rngs[5], ()) * 2 * jnp.pi,
        )
        thruster_power = jax.random.uniform(_rngs[6]) * 1.5 + 0.5
        # Mass = 1 / inverse_mass, guarding the (already excluded) fixated case against division by zero.
        shape_mass = 1.0 / jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass)

        thruster = Thruster(
            object_index=shape_index,
            active=True,
            relative_position=position_to_add_thruster,
            rotation=thruster_angle,
            power=shape_mass * ued_params.thruster_power_multiplier * thruster_power,
            global_position=shape.position + jnp.matmul(rmat(shape.rotation), position_to_add_thruster),
        )
        thruster_colour = jax.random.randint(
            _rngs[7], shape=(), minval=0, maxval=static_env_params.num_thruster_bindings
        )
        # `jax.tree_map` is deprecated/removed in recent JAX; `jax.tree.map` is used throughout this file.
        return state.replace(
            thruster=jax.tree.map(lambda y, x: y.at[thruster_idx].set(x), state.thruster, thruster),
            thruster_bindings=state.thruster_bindings.at[thruster_idx].set(thruster_colour),
        )

    return jax.lax.cond(
        jnp.logical_and(can_add_mask.sum() > 0, has_free_thruster_slot),
        do_add,
        dummy,
        rng,
        state,
    )
def mutate_change_gravity(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Resample the vertical gravity component.

    With probability 0.5 gravity becomes ``[0.0, -9.8]``; otherwise its y
    component is drawn uniformly from [-9.8, 0). The x component is always 0.
    """
    _, subkey = jax.random.split(rng)
    coin_key, strength_key = jax.random.split(subkey, 2)
    use_default = jax.random.uniform(coin_key) < 0.5
    sampled_y = jax.random.uniform(strength_key, minval=-9.8, maxval=0)
    gravity = jax.lax.select(
        use_default,
        jnp.array([0.0, -9.8]),
        jnp.array([0.0, sampled_y]),
    )
    return state.replace(gravity=gravity)
def mutate_remove_thruster(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Deactivate one uniformly chosen active thruster.

    Returns ``state`` unchanged when no thruster is active.
    """
    thruster_active = state.thruster.active

    def _no_op(rng, state):
        return state

    def _remove(rng, state):
        _, subkey = jax.random.split(rng)
        pick_key = jax.random.split(subkey, 2)[0]
        weights = thruster_active.astype(jnp.float32)
        chosen = jax.random.choice(pick_key, jnp.arange(thruster_active.shape[0]), p=weights)
        new_active = state.thruster.active.at[chosen].set(False)
        return state.replace(thruster=state.thruster.replace(active=new_active))

    return jax.lax.cond(thruster_active.sum() > 0, _remove, _no_op, rng, state)
def make_mutate_change_shape_size(params, static_env_params):
    """Return a mutator that resizes one randomly chosen shape.

    The ``params`` / ``static_env_params`` given to this factory are used only to build the dummy-step function
    via ``make_do_dummy_step``. The returned mutator receives, and otherwise uses, its own copies.
    """
    settle_step = make_do_dummy_step(params, static_env_params)

    def mutate_change_shape_size(
        rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
    ):
        """Resize one eligible shape and rescale the joint anchors attached to it.

        A shape is sampled uniformly among the active circles and the active polygons, excluding the first
        ``static_env_params.num_static_fixated_polys`` polygons. A fresh shape is then drawn with
        ``sample_dimensions``:
          * polygon: its vertices are scaled per axis by the ratio between a reference corner of the new shape
            and the matching corner of the old one;
          * circle: it takes the new radius.
        Anchor positions of active joints attached to the shape are scaled by the same factor.
        ``do_dummy_step`` is then run five times (presumably to let the scene settle), and mass/inertia are
        recomputed. ``state`` is returned unchanged when no eligible shape is active.
        """
        n_polys = static_env_params.num_polygons
        # The first num_static_fixated_polys polygons are never resized.
        eligible = jnp.concatenate(
            [state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active]
        )

        def corner_index(verts):
            # Reference corner: the vertex minimising 100 * x + y, i.e. roughly smallest x first, then smallest y.
            return jnp.argmin(verts[:, 0] * 100 + verts[:, 1])

        def keep(key, env_state):
            return env_state

        def resize(key, env_state):
            _, subkey = jax.random.split(key)
            keys = jax.random.split(subkey, 10)
            shape_idx = jax.random.choice(keys[0], jnp.arange(eligible.shape[0]), p=eligible.astype(jnp.float32))
            is_rect = shape_idx < n_polys
            circle_idx = shape_idx - n_polys

            sampled_verts, _, new_radius = sample_dimensions(
                keys[1], static_env_params, is_rect, ued_params, max_shape_size=ued_params.max_shape_size
            )

            # Both the polygon and the circle scale are always computed. The one that does not apply is
            # discarded by the selects below.
            old_verts = env_state.polygon.vertices[shape_idx]
            rect_scale = sampled_verts[corner_index(sampled_verts)] / old_verts[corner_index(old_verts)]
            circle_scale = new_radius / env_state.circle.radius[circle_idx]
            scale = jax.lax.select(is_rect, rect_scale, jnp.array([circle_scale, circle_scale]))
            resized_verts = old_verts * rect_scale

            joint = env_state.joint

            def rescale_anchor(anchor, attached_idx):
                # Scale the anchors of active joints attached to the shape; leave every other anchor as it was.
                hit = ((attached_idx == shape_idx) & joint.active)[:, None]
                return (anchor * scale[None]) * hit + (1 - hit) * anchor

            polygon_vertices = jax.lax.select(
                is_rect, env_state.polygon.vertices.at[shape_idx].set(resized_verts), env_state.polygon.vertices
            )
            circle_radii = jax.lax.select(
                jnp.logical_not(is_rect),
                env_state.circle.radius.at[circle_idx].set(new_radius),
                env_state.circle.radius,
            )
            env_state = env_state.replace(
                joint=joint.replace(
                    a_relative_pos=rescale_anchor(joint.a_relative_pos, joint.a_index),
                    b_relative_pos=rescale_anchor(joint.b_relative_pos, joint.b_index),
                ),
                polygon=env_state.polygon.replace(vertices=polygon_vertices),
                circle=env_state.circle.replace(radius=circle_radii),
            )

            env_state, _ = jax.lax.scan(lambda s, _: (settle_step(s), None), env_state, None, length=5)
            return recalculate_mass_and_inertia(
                env_state, static_env_params, env_state.polygon_densities, env_state.circle_densities
            )

        return jax.lax.cond(eligible.sum() > 0, resize, keep, rng, state)

    return mutate_change_shape_size
def mutate_change_shape_location(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Translate one randomly chosen shape, together with the shapes that move with it.

    Steps:
      1. Sample a shape uniformly among the active circles and the active polygons, excluding the first
         ``static_env_params.num_static_fixated_polys`` polygons.
      2. Shift every shape whose entry in the chosen shape's row of ``state.collision_matrix`` is 0 (fixated
         polygons excepted) by one shared random offset; each component is drawn from ``[-0.5, 0.5)``.
      3. If the moved shapes now extend past the left, right or top edge of the world
         (``screen_dim / params.pixels_per_unit``), or below ``y = 0.4`` (the floor), shift the whole moved
         group back by the overshoot.
      4. Recompute mass and inertia.

    Only ``params.pixels_per_unit`` is read from ``params``; ``ued_params`` is unused.
    Returns ``state`` unchanged when no eligible shape is active.
    """
    num_polys = static_env_params.num_polygons
    num_fixed = static_env_params.num_static_fixated_polys
    # Candidates: every active circle and every active polygon except the static fixated ones.
    shape_active = jnp.concatenate([state.polygon.active.at[:num_fixed].set(False), state.circle.active])

    def dummy(rng, state):
        return state

    def do_change(rng, state):
        rng, _rng = jax.random.split(rng)
        rngs = jax.random.split(_rng, 10)
        p = shape_active.astype(jnp.float32)
        shape_idx = jax.random.choice(rngs[0], jnp.arange(shape_active.shape[0]), p=p)
        delta_pos = jax.random.uniform(rngs[1], shape=(2,)) - 0.5  # each component in [-0.5, 0.5)

        positions = jnp.concatenate([state.polygon.position, state.circle.position])
        # Shapes with a 0 entry in the chosen shape's collision-matrix row move with it; fixated polygons never do.
        # NOTE(review): presumably this is the chosen shape plus everything jointed to it; confirm against
        # calculate_collision_matrix.
        move_mask = (state.collision_matrix[shape_idx] == 0).at[:num_fixed].set(False)
        tentative = jnp.where(move_mask[:, None], positions + delta_pos[None], positions)

        polys = state.polygon
        p_pos = tentative[:num_polys]
        c_pos = tentative[num_polys:]
        radius = state.circle.radius
        vertex_mask = jnp.arange(static_env_params.max_polygon_vertices)[None] < polys.n_vertices[:, None]
        # Only shapes that were moved can have *become* out of bounds, so only they determine the correction.
        # Including every active shape would let an unrelated shape that is already out of bounds shove the
        # moved group around.
        rect_mask = polys.active & move_mask[:num_polys]
        circ_mask = state.circle.active & move_mask[num_polys:]

        def axis_extent(axis):
            """Return (lowest, highest) world coordinate along ``axis`` over the masked shapes.

            Returns (+inf, -inf) when no shape is masked in, which yields a zero correction below.
            NOTE(review): polygon vertices are added in their local frame, without applying the body rotation,
            so for rotated polygons this extent is approximate.
            """
            poly_lo = p_pos[:, axis] + jnp.min(polys.vertices[:, :, axis], where=vertex_mask, initial=0, axis=1)
            poly_hi = p_pos[:, axis] + jnp.max(polys.vertices[:, :, axis], where=vertex_mask, initial=0, axis=1)
            lo = jnp.minimum(
                jnp.min(poly_lo, where=rect_mask, initial=jnp.inf),
                jnp.min(c_pos[:, axis] - radius, where=circ_mask, initial=jnp.inf),
            )
            hi = jnp.maximum(
                jnp.max(poly_hi, where=rect_mask, initial=-jnp.inf),
                jnp.max(c_pos[:, axis] + radius, where=circ_mask, initial=-jnp.inf),
            )
            return lo, hi

        min_x, max_x = axis_extent(0)
        min_y, max_y = axis_extent(1)
        world_width = static_env_params.screen_dim[0] / params.pixels_per_unit
        world_height = static_env_params.screen_dim[1] / params.pixels_per_unit
        oob_left = jnp.maximum(0, 0 - min_x)
        oob_right = jnp.maximum(0, max_x - world_width)
        oob_down = jnp.maximum(0, 0.4 - min_y)  # 0.4 keeps shapes clear of the floor
        oob_up = jnp.maximum(0, max_y - world_height)

        # Push the moved group back inside the bounds by the overshoot; unmoved shapes are untouched.
        correction = jnp.array([oob_left - oob_right, oob_down - oob_up])
        positions = tentative + correction[None] * move_mask[:, None]

        state = state.replace(
            polygon=state.polygon.replace(position=positions[:num_polys]),
            circle=state.circle.replace(position=positions[num_polys:]),
        )
        return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)

    return jax.lax.cond(shape_active.sum() > 0, do_change, dummy, rng, state)
def make_mutate_change_shape_rotation(params, static_env_params):
    """Return a mutator that rotates one randomly chosen shape by a random angle in ``[0, pi/2)``.

    The ``params`` / ``static_env_params`` given to this factory are used only to build the dummy-step function
    via ``make_do_dummy_step``. The returned mutator receives, and otherwise uses, its own copies.
    """
    settle_step = make_do_dummy_step(params, static_env_params)

    def mutate_change_shape_rotation(
        rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
    ):
        """Rotate one eligible shape and adjust the fixed joints attached to it.

        A shape is sampled uniformly among the active circles and the active polygons, excluding the first
        ``static_env_params.num_static_fixated_polys`` polygons. It is rotated by ``delta``, drawn from
        ``[0, pi/2)``.

        For every active fixed joint on that shape, the joint's stored ``rotation`` is shifted:
          * by ``-delta`` when the shape is body ``a``;
          * by ``+delta`` when it is body ``b``.
        NOTE(review): presumably this preserves the joint's target relative angle; confirm against jax2d.

        ``do_dummy_step`` is then run five times and mass/inertia are recomputed. ``state`` is returned
        unchanged when no eligible shape is active.
        """
        n_polys = static_env_params.num_polygons
        # The first num_static_fixated_polys polygons are never rotated.
        eligible = jnp.concatenate(
            [state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active]
        )

        def keep(key, env_state):
            return env_state

        def rotate(key, env_state):
            _, subkey = jax.random.split(key)
            keys = jax.random.split(subkey, 10)
            shape_idx = jax.random.choice(keys[0], jnp.arange(eligible.shape[0]), p=eligible.astype(jnp.float32))
            is_rect = shape_idx < n_polys
            circle_idx = shape_idx - n_polys
            delta = jax.random.uniform(keys[1], shape=()) * math.pi / 2

            joint = env_state.joint
            live_fixed = joint.is_fixed_joint & joint.active
            shape_is_a = (joint.a_index == shape_idx) & live_fixed
            shape_is_b = (joint.b_index == shape_idx) & live_fixed

            # Body-a membership takes precedence over body-b, matching a nested select.
            rotation_if_b = jax.lax.select(shape_is_b, joint.rotation + delta, joint.rotation)
            joint_rotation = jax.lax.select(shape_is_a, joint.rotation - delta, rotation_if_b)

            polygon_rotation = jax.lax.select(
                is_rect, env_state.polygon.rotation.at[shape_idx].add(delta), env_state.polygon.rotation
            )
            circle_rotation = jax.lax.select(
                jnp.logical_not(is_rect),
                env_state.circle.rotation.at[circle_idx].add(delta),
                env_state.circle.rotation,
            )
            env_state = env_state.replace(
                joint=joint.replace(rotation=joint_rotation),
                polygon=env_state.polygon.replace(rotation=polygon_rotation),
                circle=env_state.circle.replace(rotation=circle_rotation),
            )

            env_state, _ = jax.lax.scan(lambda s, _: (settle_step(s), None), env_state, None, length=5)
            return recalculate_mass_and_inertia(
                env_state, static_env_params, env_state.polygon_densities, env_state.circle_densities
            )

        return jax.lax.cond(eligible.sum() > 0, rotate, keep, rng, state)

    return mutate_change_shape_rotation