tree3po's picture
Upload 46 files
581eeac verified
from functools import partial
import math
import chex
import jax
import jax.numpy as jnp
from flax.serialization import to_state_dict
from jax2d.engine import (
PhysicsEngine,
calculate_collision_matrix,
calc_inverse_mass_polygon,
calc_inverse_mass_circle,
calc_inverse_inertia_circle,
calc_inverse_inertia_polygon,
recalculate_mass_and_inertia,
select_shape,
)
from jax2d.sim_state import SimState, RigidBody, Joint, Thruster
from jax2d.maths import rmat
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.environment.ued.util import (
count_roles,
is_space_for_joint,
make_velocities_zero,
sample_dimensions,
random_position_on_polygon,
random_position_on_circle,
get_role,
is_space_for_shape,
are_there_shapes_present,
)
from kinetix.util.saving import load_world_state_pickle
from flax import struct
from kinetix.environment.env import create_empty_env
from kinetix.environment.ued.util import make_do_dummy_step
@partial(jax.jit, static_argnums=(3, 4))
def mutate_add_shape(
    rng,
    state: EnvState,
    params: EnvParams,
    static_env_params: StaticEnvParams,
    ued_params: UEDParams,
    force_no_fixate: bool = False,
):
    """Add one new, unconnected polygon or circle to the level if a slot is free.

    The shape type is sampled among the types that still have an inactive slot,
    a role is drawn with `get_role`, dimensions come from `sample_dimensions`,
    and the body is placed uniformly at random inside the screen, inset by the
    shape's largest extent. The body may be fixated (inverse mass and inverse
    inertia set to zero); fixated bodies may additionally be shifted into the
    lower half of the screen.

    Args:
        rng: JAX PRNG key.
        state: Level to mutate.
        params: Environment parameters; only `pixels_per_unit` is read here.
        static_env_params: Static sizes (slot counts, screen dimensions). Static under jit.
        ued_params: Generation hyper-parameters. Static under jit.
        force_no_fixate: When true, the new shape is never fixated.

    Returns:
        The updated state, or `state` unchanged when `is_space_for_shape(state)` is false.
    """

    def do_dummy(rng, state):
        # No free slot: leave the level untouched.
        return state

    def do_add(rng, state):
        rng, _rng = jax.random.split(rng)
        _rngs = jax.random.split(_rng, 9)

        # Only shape types that still have an inactive slot can be sampled.
        space_for_new_rect = state.polygon.active.astype(int).sum() < static_env_params.num_polygons
        space_for_new_circle = state.circle.active.astype(int).sum() < static_env_params.num_circles
        is_rect_p = jnp.array([space_for_new_rect * 1.0, space_for_new_circle * 1.0])
        is_rect = jax.random.choice(_rngs[0], jnp.array([True, False], dtype=bool), p=is_rect_p)

        # First inactive slot of each type (argmin over the boolean `active` mask).
        rect_index = jnp.argmin(state.polygon.active)
        circle_index = jnp.argmin(state.circle.active)

        shape_role = get_role(_rngs[1], state, static_env_params)
        # Roles 1 and 2 have their maximum size scaled by `goal_body_size_factor`.
        max_shape_size = (
            jnp.array([1.0, ued_params.goal_body_size_factor, ued_params.goal_body_size_factor, 1.0])[shape_role]
            * ued_params.max_shape_size
        )
        vertices, half_dimensions, radius = sample_dimensions(
            _rngs[2],
            static_env_params,
            is_rect,
            ued_params,
            max_shape_size=max_shape_size,
        )
        n_vertices = jax.lax.select(ued_params.generate_triangles, jax.random.choice(_rngs[3], jnp.array([3, 4])), 4)

        # Conservative extent of the shape, used to keep it inside the screen.
        largest = jnp.max(jnp.array([half_dimensions[0] * jnp.sqrt(2), half_dimensions[1] * jnp.sqrt(2), radius]))
        screen_dim_world = (
            static_env_params.screen_dim[0] / params.pixels_per_unit,
            static_env_params.screen_dim[1] / params.pixels_per_unit,
        )
        min_x = largest
        max_x = screen_dim_world[0] - largest
        min_y = largest + 0.4  # NOTE(review): extra 0.4 is presumably clearance above the floor - confirm.
        max_y = screen_dim_world[1] - largest

        def _og_minmax():
            return min_x, max_x, min_y, max_y

        def _opposite_minmax():
            # Role 1 is restricted to the left half and role 2 to the right half of the screen.
            return jax.lax.switch(
                shape_role,
                [
                    (lambda: (min_x, max_x, min_y, max_y)),
                    (lambda: (min_x, max_x - screen_dim_world[0] / 2, min_y, max_y)),
                    (lambda: (min_x + screen_dim_world[0] / 2, max_x, min_y, max_y)),
                    (lambda: (min_x, max_x, min_y, max_y)),
                ],
            )

        min_x, max_x, min_y, max_y = jax.lax.cond(
            jax.random.uniform(_rngs[4], shape=()) < ued_params.goal_body_opposide_side_chance,
            _opposite_minmax,
            _og_minmax,
        )
        position = jax.random.uniform(_rngs[5], shape=(2,)) * jnp.array(
            [
                max_x - min_x,
                max_y - min_y,
            ]
        ) + jnp.array([min_x, min_y])
        rotation = jax.random.uniform(_rngs[6], shape=()) * 2 * math.pi
        velocity = jnp.array([0.0, 0.0])
        angular_velocity = 0.0
        density = 1.0
        inverse_mass = jax.lax.select(
            is_rect,
            calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)[0],
            calc_inverse_mass_circle(radius, density),
        )
        inverse_inertia = jax.lax.select(
            is_rect,
            calc_inverse_inertia_polygon(vertices, n_vertices, static_env_params, density),
            calc_inverse_inertia_circle(radius, density),
        )

        # The fixation chance grows with mass (1 / inverse_mass) and is capped.
        fixate_chance = ued_params.fixate_chance_min + (1.0 / inverse_mass) * ued_params.fixate_chance_scale
        fixate_chance = jnp.minimum(fixate_chance, ued_params.fixate_chance_max)
        is_fixated = jax.random.uniform(_rngs[7], shape=()) < fixate_chance
        # `force_no_fixate` is a plain Python bool when the default is used, so `~` would be
        # integer inversion (~False == -1); logical_not handles both bools and traced arrays.
        is_fixated = jnp.logical_and(is_fixated, jnp.logical_not(force_no_fixate))
        inverse_mass *= 1 - is_fixated
        inverse_inertia *= 1 - is_fixated

        # We want to bias fixated shapes to starting nearer the bottom half of the screen
        fixate_shape_bottom_bias = (
            ued_params.fixate_shape_bottom_bias + ued_params.fixate_shape_bottom_bias_special_role * (shape_role != 0)
        )
        is_forcing_bottom = jax.random.uniform(_rngs[8]) < fixate_shape_bottom_bias
        half_screen_height = (static_env_params.screen_dim[1] / params.pixels_per_unit) / 2.0
        position = jax.lax.select(
            is_fixated & is_forcing_bottom & (position[1] >= half_screen_height),
            position.at[1].add(-half_screen_height),
            position,
        )

        # This could be either a rect or a circle
        new_rigid_body = RigidBody(
            position=position,
            velocity=velocity,
            inverse_mass=inverse_mass,
            inverse_inertia=inverse_inertia,
            rotation=rotation,
            angular_velocity=angular_velocity,
            radius=radius,
            active=True,
            friction=1.0,
            vertices=vertices,
            n_vertices=n_vertices,
            collision_mode=1,
            restitution=0.0,
        )

        # Write the body and its role into the slot of the sampled type only.
        state = state.replace(
            polygon=jax.tree.map(
                lambda x, y: jax.lax.select(is_rect, y.at[rect_index].set(x), y), new_rigid_body, state.polygon
            ),
            circle=jax.tree.map(
                lambda x, y: jax.lax.select(jnp.logical_not(is_rect), y.at[circle_index].set(x), y),
                new_rigid_body,
                state.circle,
            ),
            polygon_shape_roles=jax.lax.select(
                is_rect,
                state.polygon_shape_roles.at[rect_index].set(shape_role),
                state.polygon_shape_roles,
            ),
            circle_shape_roles=jax.lax.select(
                jnp.logical_not(is_rect),
                state.circle_shape_roles.at[circle_index].set(shape_role),
                state.circle_shape_roles,
            ),
        )
        return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)

    return jax.lax.cond(is_space_for_shape(state), do_add, do_dummy, rng, state)
@partial(jax.jit, static_argnums=(3, 4))
def mutate_add_connected_shape(
    rng,
    state: EnvState,
    params: EnvParams,
    static_env_params: StaticEnvParams,
    ued_params: UEDParams,
    force_rjoint: bool = False,
):
    """Add a new shape that is attached by a joint to an existing shape.

    A target shape is sampled among the active, non-static shapes (biased by
    fixation and by how many joints each shape already has), a new polygon or
    circle is created, and the new shape is positioned so that a random anchor
    on it coincides with a random anchor on the target. The two are linked by
    either a revolute joint (with a randomly configured motor) or a fixed joint.

    Args:
        rng: JAX PRNG key.
        state: Level to mutate.
        params: Environment parameters (not read by this function).
        static_env_params: Static sizes (slot counts etc.). Static under jit.
        ued_params: Generation hyper-parameters. Static under jit.
        force_rjoint: When true, the new joint is always a revolute joint.

    Returns:
        A `(state, valid)` pair. `valid` is the result of the visibility
        heuristic on the new pair of shapes; it is `False` (with `state`
        unchanged) when there is no free shape slot, no free joint slot, or no
        existing shape to connect to.
    """

    def do_dummy(rng, state):
        return state, False

    def do_add(rng, state):
        rng, _rng = jax.random.split(rng)
        _rngs = jax.random.split(_rng, 21)

        # Select a random index amongst the currently active shapes.
        p_rect = state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False)
        p_circle = state.circle.active
        p_rect = p_rect.astype(jnp.float32)
        p_circle = p_circle.astype(jnp.float32)

        # Fixated targets (inverse_mass == 0) are weighted by `connect_to_fixated_prob_coeff`.
        p_rect *= (state.polygon.inverse_mass == 0) * ued_params.connect_to_fixated_prob_coeff + (
            state.polygon.inverse_mass != 0
        ) * 1.0
        p_circle *= (state.circle.inverse_mass == 0) * ued_params.connect_to_fixated_prob_coeff + (
            state.circle.inverse_mass != 0
        ) * 1.0

        # Bias based on number of existing connections
        rect_connections = jnp.zeros(static_env_params.num_polygons)
        circle_connections = jnp.zeros(static_env_params.num_circles)
        rect_connections = rect_connections.at[state.joint.a_index].add(
            jnp.ones(static_env_params.num_joints)
            * state.joint.active
            * (state.joint.a_index < static_env_params.num_polygons)
        )
        rect_connections = rect_connections.at[state.joint.b_index].add(
            jnp.ones(static_env_params.num_joints)
            * state.joint.active
            * (state.joint.b_index < static_env_params.num_polygons)
        )
        circle_connections = circle_connections.at[state.joint.a_index - static_env_params.num_polygons].add(
            jnp.ones(static_env_params.num_joints)
            * state.joint.active
            * (state.joint.a_index >= static_env_params.num_polygons)
        )
        circle_connections = circle_connections.at[state.joint.b_index - static_env_params.num_polygons].add(
            jnp.ones(static_env_params.num_joints)
            * state.joint.active
            * (state.joint.b_index >= static_env_params.num_polygons)
        )

        # Rectangles can have up to 2 connections
        p_rect *= (-rect_connections + 2.0) / 2.0
        p_rect = jnp.maximum(p_rect, 0.0)

        # Circles can have 1 connection
        p_circle *= circle_connections == 0

        # To sample a target rect/circle, we have to have at least one.
        target_rect_p = jnp.array(
            [
                (state.polygon.active.astype(int).sum() > static_env_params.num_static_fixated_polys) * 1.0,
                (state.circle.active.astype(int).sum() > 0) * 1.0,
            ]
        )

        # Don't connect to a circle if no connection-free ones exist
        target_rect_p = target_rect_p.at[1].mul(p_circle.sum() > 0)

        space_for_new_rect = state.polygon.active.astype(int).sum() < static_env_params.num_polygons
        space_for_new_circle = state.circle.active.astype(int).sum() < static_env_params.num_circles

        # If no polygon slot is free the new shape is a circle, so the target is forced to be a polygon.
        is_target_rect = jax.random.choice(_rngs[0], jnp.array([True, False], dtype=bool), p=target_rect_p) | (
            ~space_for_new_rect
        )
        is_rect_p = jnp.array([space_for_new_rect * 1.0, space_for_new_circle * 1.0])
        # A circle target forces the new shape to be a polygon whenever a polygon slot is free.
        is_rect = jax.random.choice(_rngs[1], jnp.array([True, False], dtype=bool), p=is_rect_p) | (
            ~is_target_rect & space_for_new_rect
        )

        # First inactive slot of the chosen type.
        shape_index = jax.lax.select(
            is_rect,
            jnp.argmin(state.polygon.active),
            jnp.argmin(state.circle.active),
        )

        vertices, half_dimensions, radius = sample_dimensions(
            _rngs[2], static_env_params, is_rect, ued_params, max_shape_size=ued_params.max_shape_size
        )
        n_vertices = jax.lax.select(ued_params.generate_triangles, jax.random.choice(_rngs[3], jnp.array([3, 4])), 4)
        rotation = jax.random.uniform(_rngs[4], shape=()) * 2 * math.pi
        velocity = jnp.array([0.0, 0.0])
        angular_velocity = 0.0
        density = 1.0
        inverse_mass = jax.lax.select(
            is_rect,
            calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)[0],
            calc_inverse_mass_circle(radius, density),
        )
        inverse_inertia = jax.lax.select(
            is_rect,
            calc_inverse_inertia_polygon(vertices, n_vertices, static_env_params, density),
            calc_inverse_inertia_circle(radius, density),
        )

        # Joint
        current_num_rjoints = (jnp.logical_not(state.joint.is_fixed_joint) * state.joint.active).sum()
        # Revolute with probability 0.5, or always when forced or when below `min_rjoints_bias`.
        is_rjoint = jnp.logical_or(
            jnp.logical_or(jax.random.uniform(_rngs[5]) < 0.5, force_rjoint),
            current_num_rjoints < ued_params.min_rjoints_bias,
        )
        joint_index = jnp.argmin(state.joint.active)

        local_joint_position_rect = random_position_on_polygon(_rngs[6], vertices, n_vertices, static_env_params)
        local_joint_position_circle = random_position_on_circle(_rngs[7], radius, on_centre_chance=1.0)
        local_joint_position = jax.lax.select(is_rect, local_joint_position_rect, local_joint_position_circle)

        # Fall back to every active shape if the connection biases zeroed out all weights.
        p_rect = jax.lax.select(p_rect.sum() == 0, state.polygon.active.astype(jnp.float32), p_rect)
        p_circle = jax.lax.select(p_circle.sum() == 0, state.circle.active.astype(jnp.float32), p_circle)
        target_index = jax.lax.select(
            is_target_rect,
            jax.random.choice(
                _rngs[8],
                jnp.arange(static_env_params.num_polygons),
                p=p_rect,
            ),
            jax.random.choice(
                _rngs[9],
                jnp.arange(static_env_params.num_circles),
                p=p_circle,
            ),
        )
        # Unified indexing: polygons first, then circles offset by `num_polygons`.
        unified_target_index = target_index + jnp.logical_not(is_target_rect) * static_env_params.num_polygons
        target_shape = select_shape(state, unified_target_index, static_env_params)

        target_joint_position_rect = random_position_on_polygon(
            _rngs[10], state.polygon.vertices[target_index], state.polygon.n_vertices[target_index], static_env_params
        )
        target_joint_position_circle = random_position_on_circle(
            _rngs[11], state.circle.radius[target_index], on_centre_chance=1.0
        )
        target_joint_position = jax.lax.select(is_target_rect, target_joint_position_rect, target_joint_position_circle)

        # Calculate the world position of the new shape
        # We know the rotation of the new shape. We also know the position of the current shape, which we want to remain fixed.
        # Set `position` such that local_joint_position is the same as `target_joint_position`
        global_joint_pos = target_shape.position + jnp.matmul(rmat(target_shape.rotation), target_joint_position)
        position = global_joint_pos - jnp.matmul(rmat(rotation), local_joint_position)

        # For polygons, shift position, anchor and vertices by the offset returned alongside the inverse mass.
        _, pos_diff = calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)
        position = jax.lax.select(is_rect, position + pos_diff, position)
        local_joint_position = jax.lax.select(is_rect, local_joint_position - pos_diff, local_joint_position)
        vertices = jax.lax.select(is_rect, vertices - pos_diff[None], vertices)

        target_role = jax.lax.select(
            is_target_rect, state.polygon_shape_roles[target_index], state.circle_shape_roles[target_index]
        )
        # We cannot have role 1 and role 2 being connected.
        p = jnp.array([1.0, 1.0, 1.0, 1.0])
        # If role is 0, keep all probs at 1, otherwise set the target role's complement to 0 prob
        # 3 - role turns 1 to 2 and 2 to 1
        # If the target role is three, we set everything to zero except for the default
        p = jax.lax.select(
            target_role == 0,
            p,
            jax.lax.select(
                target_role <= 2,
                p.at[3 - target_role].set(False).at[3].set(False),
                (p.at[2].set(False).at[1].set(False)),
            ),
        )
        shape_role = get_role(_rngs[12], state, static_env_params, initial_p=p)

        # This could be either a rect or a circle
        new_rigid_body = RigidBody(
            position=position,
            velocity=velocity,
            inverse_mass=inverse_mass,
            inverse_inertia=inverse_inertia,
            rotation=rotation,
            angular_velocity=angular_velocity,
            radius=radius,
            active=True,
            friction=1.0,
            vertices=vertices,
            n_vertices=n_vertices,
            collision_mode=1,
            restitution=0.0,
        )

        # Change the shape indices such that a_index is less than b_index
        a_index = shape_index + (1 - is_rect) * static_env_params.num_polygons
        b_index = target_index + (1 - is_target_rect) * static_env_params.num_polygons
        should_swap = a_index > b_index
        # NOTE: after this point `local_joint_position` belongs to shape `a_index` and
        # `target_joint_position` to shape `b_index`, whichever of them is the new shape.
        a_index, b_index, local_joint_position, target_joint_position, shape_a, shape_b = jax.lax.cond(
            should_swap,
            lambda x: (x[1], x[0], x[3], x[2], x[5], x[4]),  # pairwise swap
            lambda x: x,
            (a_index, b_index, local_joint_position, target_joint_position, new_rigid_body, target_shape),
        )

        motor_on = jax.random.uniform(_rngs[13], shape=()) < ued_params.motor_on_chance
        joint_colour = jax.random.randint(_rngs[14], shape=(), minval=0, maxval=static_env_params.num_motor_bindings)
        joint_rotation = shape_b.rotation - shape_a.rotation
        motor_speed = jax.random.uniform(
            _rngs[15], shape=(), minval=ued_params.motor_min_speed, maxval=ued_params.motor_max_speed
        )
        motor_power = jax.random.uniform(
            _rngs[16], shape=(), minval=ued_params.motor_min_power, maxval=ued_params.motor_max_power
        )
        wheel_power = jax.random.uniform(
            _rngs[20], shape=(), minval=ued_params.motor_min_power, maxval=ued_params.wheel_max_power
        )
        # High-powered wheels break the physics engine - this is a temporary fix
        motor_power = jax.lax.select(is_rect & is_target_rect, motor_power, wheel_power)

        # Joint limits are only used when both shapes are polygons.
        motor_has_joint_limits = jax.random.uniform(_rngs[17], shape=()) < ued_params.joint_limit_chance
        motor_has_joint_limits &= is_rect & is_target_rect
        joint_limit_min = (
            jax.random.uniform(_rngs[18], shape=(), minval=-ued_params.joint_limit_max, maxval=0.0)
            * motor_has_joint_limits
        )
        joint_limit_max = (
            jax.random.uniform(_rngs[19], shape=(), minval=0.0, maxval=ued_params.joint_limit_max)
            * motor_has_joint_limits
        )

        # Both joint variants are built; `is_rjoint` picks which one is written to the state.
        rjoint = Joint(
            a_index=a_index,
            b_index=b_index,
            a_relative_pos=local_joint_position,
            b_relative_pos=target_joint_position,
            global_position=global_joint_pos,
            active=True,
            motor_speed=motor_speed,
            motor_power=motor_power,
            motor_on=motor_on,
            motor_has_joint_limits=motor_has_joint_limits,
            min_rotation=joint_limit_min,
            max_rotation=joint_limit_max,
            is_fixed_joint=False,
            rotation=0.0,
            acc_impulse=jnp.zeros((2,), dtype=jnp.float32),
            acc_r_impulse=jnp.zeros((), dtype=jnp.float32),
        )
        fjoint = Joint(
            a_index=a_index,
            b_index=b_index,
            a_relative_pos=local_joint_position,
            b_relative_pos=target_joint_position,
            global_position=global_joint_pos,
            active=True,
            rotation=joint_rotation,
            acc_impulse=jnp.zeros((2,), dtype=jnp.float32),
            acc_r_impulse=jnp.zeros((), dtype=jnp.float32),
            is_fixed_joint=True,
            motor_has_joint_limits=False,
            min_rotation=0.0,
            max_rotation=0.0,
            motor_on=False,
            motor_power=0.0,
            motor_speed=0.0,
        )

        state = state.replace(
            polygon=jax.tree.map(
                lambda x, y: jax.lax.select(is_rect, y.at[shape_index].set(x), y), new_rigid_body, state.polygon
            ),
            circle=jax.tree.map(
                lambda x, y: jax.lax.select(jnp.logical_not(is_rect), y.at[shape_index].set(x), y),
                new_rigid_body,
                state.circle,
            ),
            joint=jax.tree.map(
                lambda rj, fj, y: jax.lax.select(is_rjoint, y.at[joint_index].set(rj), y.at[joint_index].set(fj)),
                rjoint,
                fjoint,
                state.joint,
            ),
            polygon_shape_roles=jax.lax.select(
                is_rect,
                state.polygon_shape_roles.at[shape_index].set(shape_role),
                state.polygon_shape_roles,
            ),
            circle_shape_roles=jax.lax.select(
                jnp.logical_not(is_rect),
                state.circle_shape_roles.at[shape_index].set(shape_role),
                state.circle_shape_roles,
            ),
            motor_bindings=state.motor_bindings.at[joint_index].set(joint_colour),
        )

        # We need the new collision matrix.
        state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint))
        state = recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)

        # Was this a valid addition?
        # We calculate whether (assuming the possibility of 360 degree rotation around the joint)
        # both shapes can be visible
        # This is to remove the common degenerate pattern of connected shapes being fully inside each other
        def _get_min_rect_dist(r_id, local_pos):
            # Smallest distance from the anchor to an edge of the polygon's axis-aligned extent.
            rect: RigidBody = jax.tree.map(lambda x: x[r_id], state.polygon)
            half_width = (jnp.max(rect.vertices[:, 0]) - jnp.min(rect.vertices[:, 0])) / 2.0
            half_height = (jnp.max(rect.vertices[:, 1]) - jnp.min(rect.vertices[:, 1])) / 2.0
            dist_x = half_width - jnp.abs(local_pos[0])
            dist_y = half_height - jnp.abs(local_pos[1])
            return jnp.minimum(dist_x, dist_y)

        def _get_max_rect_dist(r_id, local_pos):
            # Distance from the anchor to the farthest corner of the polygon's axis-aligned extent.
            rect: RigidBody = jax.tree.map(lambda x: x[r_id], state.polygon)
            half_width = (jnp.max(rect.vertices[:, 0]) - jnp.min(rect.vertices[:, 0])) / 2.0
            half_height = (jnp.max(rect.vertices[:, 1]) - jnp.min(rect.vertices[:, 1])) / 2.0
            dist_x = jnp.maximum(
                jnp.abs(half_width - local_pos[0]),
                jnp.abs(-half_width - local_pos[0]),
            )
            dist_y = jnp.maximum(
                jnp.abs(half_height - local_pos[1]),
                jnp.abs(-half_height - local_pos[1]),
            )
            return jnp.sqrt(dist_x * dist_x + dist_y * dist_y)

        def are_both_shapes_showing(idx1, idx2, local_pos1, local_pos2):
            def _is_small_shape_showing(small_idx, big_idx, small_local_pos, big_local_pos):
                small_is_poly = small_idx < static_env_params.num_polygons
                big_is_poly = big_idx < static_env_params.num_polygons

                # All four type combinations are evaluated; the applicable one is selected below.
                # CC
                cc_result = False

                # CR
                cr_r_dist = _get_min_rect_dist(big_idx, big_local_pos)
                cr_result = (
                    cr_r_dist + ued_params.connect_visibility_min
                    < state.circle.radius[small_idx - static_env_params.num_polygons]
                )

                # RC
                rc_r_dist = _get_max_rect_dist(small_idx, small_local_pos)
                rc_result = (
                    rc_r_dist
                    > state.circle.radius[big_idx - static_env_params.num_polygons] + ued_params.connect_visibility_min
                )

                # RR
                rr_small_dist = _get_max_rect_dist(small_idx, small_local_pos)
                rr_big_dist = _get_min_rect_dist(big_idx, big_local_pos)
                rr_result = rr_small_dist > rr_big_dist + ued_params.connect_visibility_min

                # Select
                return jax.lax.select(
                    small_is_poly,
                    jax.lax.select(big_is_poly, rr_result, rc_result),
                    jax.lax.select(big_is_poly, cr_result, cc_result),
                )

            # Are both shapes showing?
            return _is_small_shape_showing(idx1, idx2, local_pos1, local_pos2) & _is_small_shape_showing(
                idx2, idx1, local_pos2, local_pos1
            )

        # Use the post-swap indices: the anchor variables were swapped together with
        # a_index / b_index above, so pairing them with the pre-swap shape/target indices
        # would test each shape against the other shape's anchor whenever a swap happened.
        valid = are_both_shapes_showing(a_index, b_index, local_joint_position, target_joint_position)
        return state, valid

    # To add a connected shape, we must have both at least one existing shape and space
    return jax.lax.cond(
        is_space_for_shape(state) & are_there_shapes_present(state, static_env_params) & is_space_for_joint(state),
        do_add,
        do_dummy,
        rng,
        state,
    )
@partial(jax.jit, static_argnums=(3, 4))
def mutate_add_connected_shape_proper(
    rng,
    state: EnvState,
    params: EnvParams,
    static_env_params: StaticEnvParams,
    ued_params: UEDParams,
    force_rjoint: bool = False,
):
    """Run `mutate_add_connected_shape` and return only the resulting state.

    The second element returned by `mutate_add_connected_shape` (the validity
    flag) is discarded, so this has the same single-state return shape as the
    other mutators in this file.
    """
    mutated_state, _valid = mutate_add_connected_shape(
        rng, state, params, static_env_params, ued_params, force_rjoint=force_rjoint
    )
    return mutated_state
@partial(jax.jit, static_argnums=(3, 4))
def mutate_remove_shape(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Deactivate one random removable shape together with its joints and thrusters.

    Shapes are addressed with a unified index (polygons first, then circles).
    The first `num_static_fixated_polys` polygons are never removed. Returns
    the state unchanged when nothing can be removed.
    """
    n_polys = static_env_params.num_polygons
    every_shape_active = jnp.concatenate([state.polygon.active, state.circle.active])
    can_remove_mask = every_shape_active.at[: static_env_params.num_static_fixated_polys].set(False)

    def _leave_as_is(rng, state):
        return state

    def _remove_one(rng, state: EnvState):
        _, sub_key = jax.random.split(rng)
        keys = jax.random.split(sub_key, 2)
        weights = can_remove_mask.astype(jnp.float32)
        victim = jax.random.choice(keys[0], jnp.arange(can_remove_mask.shape[0]), p=weights)
        victim_is_polygon = victim < n_polys

        # Only the array of the matching shape type is actually modified.
        polygon_active = jax.lax.select(
            victim_is_polygon, state.polygon.active.at[victim].set(False), state.polygon.active
        )
        circle_active = jax.lax.select(
            jnp.logical_not(victim_is_polygon),
            state.circle.active.at[victim - n_polys].set(False),
            state.circle.active,
        )
        state = state.replace(
            polygon=state.polygon.replace(active=polygon_active),
            circle=state.circle.replace(active=circle_active),
        )

        # Anything attached to the removed shape has to go as well.
        joint_touches_victim = (state.joint.a_index == victim) | (state.joint.b_index == victim)
        thruster_on_victim = state.thruster.object_index == victim
        joint_active = jnp.where(joint_touches_victim, False, state.joint.active)
        thruster_active = jnp.where(thruster_on_victim, False, state.thruster.active)
        state = state.replace(
            joint=state.joint.replace(active=joint_active),
            thruster=state.thruster.replace(active=thruster_active),
        )

        # The set of active joints changed, so the collision matrix is rebuilt.
        new_collision_matrix = calculate_collision_matrix(static_env_params, state.joint)
        return state.replace(collision_matrix=new_collision_matrix)

    return jax.lax.cond(can_remove_mask.sum() > 0, _remove_one, _leave_as_is, rng, state)
@partial(jax.jit, static_argnums=(3, 4))
def mutate_remove_joint(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Deactivate one randomly chosen active joint and rebuild the collision matrix.

    Returns the state unchanged when no joint is active.
    """
    can_remove_mask = state.joint.active

    def _leave_as_is(rng, state):
        return state

    def _remove_one(rng, state):
        _, sub_key = jax.random.split(rng)
        keys = jax.random.split(sub_key, 2)
        weights = can_remove_mask.astype(jnp.float32)
        victim = jax.random.choice(keys[0], jnp.arange(can_remove_mask.shape[0]), p=weights)

        remaining_active = state.joint.active.at[victim].set(False)
        state = state.replace(joint=state.joint.replace(active=remaining_active))

        # The collision matrix is derived from the joints, so it has to be recomputed.
        new_collision_matrix = calculate_collision_matrix(static_env_params, state.joint)
        return state.replace(collision_matrix=new_collision_matrix)

    return jax.lax.cond(can_remove_mask.sum() > 0, _remove_one, _leave_as_is, rng, state)
@partial(jax.jit, static_argnums=(3, 4))
def mutate_swap_role(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Exchange the roles of two random active shapes that currently have different roles.

    The first `num_static_fixated_polys` polygons are never picked. Returns the
    state unchanged unless at least two distinct roles are present (as counted
    by `count_roles` with static polygons excluded).
    """
    n_polys = static_env_params.num_polygons
    n_shapes = static_env_params.num_polygons + static_env_params.num_circles

    def _count_role(env_state, static_params, role):
        return count_roles(env_state, static_params, role, include_static_polys=False)

    role_counts = jax.vmap(_count_role, (None, None, 0))(state, static_env_params, jnp.arange(4))
    are_there_multiple_roles = (role_counts > 0).sum() > 1

    def _leave_as_is(rng, state):
        return state

    def _swap(rng, state):
        _, sub_key = jax.random.split(rng)
        keys = jax.random.split(sub_key, 2)

        all_roles = jnp.concatenate([state.polygon_shape_roles, state.circle_shape_roles])
        active_mask = jnp.concatenate([state.polygon.active, state.circle.active])
        weights = active_mask.astype(jnp.float32).at[: static_env_params.num_static_fixated_polys].set(0.0)

        first_idx = jax.random.choice(keys[0], jnp.arange(n_shapes), p=weights)
        first_role = all_roles[first_idx]

        # The second shape must have a role different from the first one.
        weights = jnp.where(all_roles == first_role, 0.0, weights)
        second_idx = jax.random.choice(keys[1], jnp.arange(n_shapes), p=weights)
        second_role = all_roles[second_idx]

        def _assign(current, unified_idx, new_role):
            # Write `new_role` into whichever role array holds `unified_idx`.
            in_polygons = unified_idx < n_polys
            return current.replace(
                polygon_shape_roles=jax.lax.select(
                    in_polygons,
                    current.polygon_shape_roles.at[unified_idx].set(new_role),
                    current.polygon_shape_roles,
                ),
                circle_shape_roles=jax.lax.select(
                    jnp.logical_not(in_polygons),
                    current.circle_shape_roles.at[unified_idx - n_polys].set(new_role),
                    current.circle_shape_roles,
                ),
            )

        state = _assign(state, first_idx, second_role)
        state = _assign(state, second_idx, first_role)
        return state

    return jax.lax.cond(are_there_multiple_roles, _swap, _leave_as_is, rng, state)
@partial(jax.jit, static_argnums=(3, 4))
def mutate_toggle_fixture(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Toggle one random active shape between fixated and free.

    A shape is treated as fixated when its inverse inertia is zero. The first
    `num_static_fixated_polys` polygons are never picked. Returns the state
    unchanged when no shape is eligible.
    """
    n_polys = static_env_params.num_polygons
    every_shape_active = jnp.concatenate([state.polygon.active, state.circle.active])
    can_toggle_mask = every_shape_active.at[: static_env_params.num_static_fixated_polys].set(False)

    def _leave_as_is(rng, state):
        return state

    def _toggle(rng, state: EnvState):
        _, sub_key = jax.random.split(rng)
        keys = jax.random.split(sub_key, 2)
        weights = can_toggle_mask.astype(jnp.float32)
        chosen = jax.random.choice(keys[0], jnp.arange(can_toggle_mask.shape[0]), p=weights)

        chosen_is_polygon = chosen < n_polys
        chosen_is_circle = jnp.logical_not(chosen_is_polygon)
        circle_slot = chosen - n_polys

        current_inverse_inertia = jax.lax.select(
            chosen_is_polygon,
            state.polygon.inverse_inertia[chosen],
            state.circle.inverse_inertia[circle_slot],
        )
        # 1.0 if the shape is currently fixated (non-zero placeholder, real values are
        # recomputed below); 0.0 if it is currently free, which fixates it.
        new_inverse = (current_inverse_inertia == 0.0) * 1.0

        def _write(values, slot, enabled):
            return jax.lax.select(enabled, values.at[slot].set(new_inverse), values)

        state = state.replace(
            polygon=state.polygon.replace(
                inverse_inertia=_write(state.polygon.inverse_inertia, chosen, chosen_is_polygon),
                inverse_mass=_write(state.polygon.inverse_mass, chosen, chosen_is_polygon),
            ),
            circle=state.circle.replace(
                inverse_inertia=_write(state.circle.inverse_inertia, circle_slot, chosen_is_circle),
                inverse_mass=_write(state.circle.inverse_mass, circle_slot, chosen_is_circle),
            ),
        )
        return recalculate_mass_and_inertia(
            state, static_env_params, state.polygon_densities, state.circle_densities
        )

    return jax.lax.cond(can_toggle_mask.sum() > 0, _toggle, _leave_as_is, rng, state)
@partial(jax.jit, static_argnums=(3, 4))
def mutate_add_thruster(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Attach a new thruster to a random active, non-fixated shape.

    The thruster goes into the first inactive thruster slot. Its anchor is a
    random position on the chosen shape, its direction is either aligned with
    the line through the shape's local origin (with probability
    `ued_params.thruster_align_com_prob`) or uniformly random, and its power is
    scaled by the shape's mass.

    Args:
        rng: JAX PRNG key.
        state: Level to mutate.
        params: Environment parameters (not read by this function).
        static_env_params: Static sizes (slot counts etc.). Static under jit.
        ued_params: Generation hyper-parameters. Static under jit.

    Returns:
        The updated state, or `state` unchanged when there is no eligible shape
        or no free thruster slot.
    """
    # Unified shape indexing: polygons first, then circles.
    is_fixated = jnp.concatenate([state.polygon.inverse_mass == 0, state.circle.inverse_mass == 0])
    is_active = jnp.concatenate([state.polygon.active, state.circle.active])
    can_add_mask = jnp.logical_and(is_active, jnp.logical_not(is_fixated))

    def dummy(rng, state):
        return state

    def do_add(rng, state: EnvState):
        rng, _rng = jax.random.split(rng)
        _rngs = jax.random.split(_rng, 10)
        p = can_add_mask.astype(jnp.float32)
        shape_index = jax.random.choice(_rngs[0], jnp.arange(can_add_mask.shape[0]), p=p)
        # First inactive thruster slot (argmin over the boolean `active` mask).
        thruster_idx = jnp.argmin(state.thruster.active)
        shape = select_shape(state, shape_index, static_env_params)
        position_to_add_thruster = jax.lax.select(
            shape_index < static_env_params.num_polygons,
            random_position_on_polygon(_rngs[1], shape.vertices, shape.n_vertices, static_env_params),
            random_position_on_circle(_rngs[2], shape.radius, on_centre_chance=0.0),
        )

        # Point either along or against the anchor's offset from the shape's local origin.
        direction_to_com = ((jax.random.uniform(_rngs[3]) > 0.5) * 2 - 1) * position_to_add_thruster
        # A zero offset has no direction; fall back to the +x axis.
        direction_to_com = jax.lax.select(
            jnp.linalg.norm(direction_to_com) == 0.0, jnp.array([1.0, 0.0]), direction_to_com
        )
        thruster_angle = jax.lax.select(
            jax.random.uniform(_rngs[4]) < ued_params.thruster_align_com_prob,
            jnp.atan2(direction_to_com[1], direction_to_com[0]),
            jax.random.uniform(
                _rngs[5],
                (),
            )
            * 2
            * jnp.pi,
        )
        # Base power is uniform in [0.5, 2.0).
        thruster_power = jax.random.uniform(_rngs[6]) * 1.5 + 0.5
        thruster = Thruster(
            object_index=shape_index,
            active=True,
            relative_position=position_to_add_thruster,
            rotation=thruster_angle,
            # Scale by mass (1 / inverse_mass); the select guards against division by zero.
            power=1.0
            / jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass)
            * ued_params.thruster_power_multiplier
            * thruster_power,
            global_position=shape.position + jnp.matmul(rmat(shape.rotation), position_to_add_thruster),
        )
        thruster_colour = jax.random.randint(
            _rngs[7], shape=(), minval=0, maxval=static_env_params.num_thruster_bindings
        )
        state = state.replace(
            # `jax.tree.map` matches the other mutators; `jax.tree_map` is deprecated.
            thruster=jax.tree.map(lambda y, x: y.at[thruster_idx].set(x), state.thruster, thruster),
            thruster_bindings=state.thruster_bindings.at[thruster_idx].set(thruster_colour),
        )
        return state

    return jax.lax.cond(
        jnp.logical_and((can_add_mask.sum() > 0), (jnp.logical_not(state.thruster.active).sum() > 0)),
        do_add,
        dummy,
        rng,
        state,
    )
@partial(jax.jit, static_argnums=(3, 4))
def mutate_change_gravity(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Replace the level's gravity vector.

    With probability 0.5 gravity becomes `[0, -9.8]`; otherwise the vertical
    component is drawn uniformly between -9.8 and 0. The horizontal component
    is always 0.
    """
    _, sub_key = jax.random.split(rng)
    keys = jax.random.split(sub_key, 2)

    use_standard = jax.random.uniform(keys[0]) < 0.5
    standard_gravity = jnp.array([0.0, -9.8])
    sampled_vertical = jax.random.uniform(keys[1], minval=-9.8, maxval=0)
    sampled_gravity = jnp.array([0.0, sampled_vertical])

    return state.replace(gravity=jax.lax.select(use_standard, standard_gravity, sampled_gravity))
@partial(jax.jit, static_argnums=(3, 4))
def mutate_remove_thruster(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Deactivate one randomly chosen active thruster.

    Returns the state unchanged when no thruster is active.
    """
    are_there_thrusters = state.thruster.active

    def _leave_as_is(rng, state):
        return state

    def _deactivate_one(rng, state):
        _, sub_key = jax.random.split(rng)
        keys = jax.random.split(sub_key, 2)
        # Only active thrusters get non-zero sampling weight.
        weights = are_there_thrusters.astype(jnp.float32)
        victim = jax.random.choice(keys[0], jnp.arange(are_there_thrusters.shape[0]), p=weights)
        remaining_active = state.thruster.active.at[victim].set(False)
        return state.replace(thruster=state.thruster.replace(active=remaining_active))

    return jax.lax.cond(are_there_thrusters.sum() > 0, _deactivate_one, _leave_as_is, rng, state)
def make_mutate_change_shape_size(params, static_env_params):
    """Build a jitted mutator that resizes one randomly chosen shape.

    The returned function samples new dimensions with `sample_dimensions`,
    rescales the chosen polygon's vertices (or replaces the chosen circle's
    radius), rescales the anchors of active joints attached to that shape,
    runs five dummy steps built from the `params` / `static_env_params` given
    here, and finally recomputes mass and inertia.
    """
    settle_step = make_do_dummy_step(params, static_env_params)

    @partial(jax.jit, static_argnums=(3, 4))
    def mutate_change_shape_size(
        rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
    ):
        n_polys = static_env_params.num_polygons
        # The first `num_static_fixated_polys` polygons are never resized.
        resizable_polys = state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False)
        shape_active = jnp.concatenate([resizable_polys, state.circle.active])

        def _leave_as_is(rng, state):
            return state

        def _resize(rng, state):
            _, sub_key = jax.random.split(rng)
            keys = jax.random.split(sub_key, 10)
            weights = shape_active.astype(jnp.float32)
            shape_idx = jax.random.choice(keys[0], jnp.arange(shape_active.shape[0]), p=weights)
            is_rect = shape_idx < n_polys
            circle_slot = shape_idx - n_polys

            sampled_vertices, _, sampled_radius = sample_dimensions(
                keys[1], static_env_params, is_rect, ued_params, max_shape_size=ued_params.max_shape_size
            )

            # Match one reference corner (smallest x, ties broken by y) between the old
            # and the newly sampled polygon to get a per-axis scale factor.
            old_vertices = state.polygon.vertices[shape_idx]
            new_corner = jnp.argmin(sampled_vertices[:, 0] * 100 + sampled_vertices[:, 1])
            old_corner = jnp.argmin(old_vertices[:, 0] * 100 + old_vertices[:, 1])
            scale_rect = (sampled_vertices[new_corner]) / (old_vertices[old_corner])
            scale_circle = sampled_radius / state.circle.radius[circle_slot]
            rescaled_vertices = old_vertices * scale_rect
            scale = jax.lax.select(
                is_rect,
                scale_rect,
                jnp.array([scale_circle, scale_circle]),
            )

            # Column masks of the active joints whose a / b side is the resized shape.
            is_a = ((state.joint.a_index == shape_idx) & state.joint.active)[:, None]
            is_b = ((state.joint.b_index == shape_idx) & state.joint.active)[:, None]

            def _rescale_anchor(relative_pos, mask):
                return (relative_pos * scale[None]) * mask + (1 - mask) * relative_pos

            new_joint = state.joint.replace(
                a_relative_pos=_rescale_anchor(state.joint.a_relative_pos, is_a),
                b_relative_pos=_rescale_anchor(state.joint.b_relative_pos, is_b),
            )
            new_polygon = state.polygon.replace(
                vertices=jax.lax.select(
                    is_rect, state.polygon.vertices.at[shape_idx].set(rescaled_vertices), state.polygon.vertices
                ),
            )
            new_circle = state.circle.replace(
                radius=jax.lax.select(
                    jnp.logical_not(is_rect),
                    state.circle.radius.at[circle_slot].set(sampled_radius),
                    state.circle.radius,
                )
            )
            state = state.replace(joint=new_joint, polygon=new_polygon, circle=new_circle)

            # Run five dummy steps on the resized level.
            state, _ = jax.lax.scan(lambda carry, _: (settle_step(carry), None), state, jnp.arange(5))

            return recalculate_mass_and_inertia(
                state, static_env_params, state.polygon_densities, state.circle_densities
            )

        return jax.lax.cond(shape_active.sum() > 0, _resize, _leave_as_is, rng, state)

    return mutate_change_shape_size
@partial(jax.jit, static_argnums=(3, 4))
def mutate_change_shape_location(
    rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
):
    """Translate one randomly chosen shape, together with its group, by a random offset.

    A shape is sampled uniformly among the active circles and the active
    polygons that are not in the first ``num_static_fixated_polys`` slots.
    Every shape whose ``collision_matrix`` entry against the chosen shape is 0
    (again excluding the fixated polygon slots) is moved by the same offset,
    drawn uniformly from [-0.5, 0.5) on each axis.
    NOTE(review): presumably a 0 entry marks shapes that are jointed together
    and therefore must move as one unit -- confirm against the engine.

    After the tentative move, the combined extents of all active, non-fixated
    shapes are compared with the allowed area and the moved shapes are shifted
    back by the amount of overshoot on each side.  Polygon extents are taken
    from the stored vertices without applying the polygon rotation.

    Args:
        rng: PRNG key.
        state: Environment state to mutate.
        params: Only ``pixels_per_unit`` is read here.
        static_env_params: Static sizes (polygon count, vertex count, screen size).
        ued_params: Not used by this mutation.

    Returns:
        The result of ``recalculate_mass_and_inertia`` on the updated state, or
        ``state`` unchanged when no eligible shape is active.
    """
    num_polys = static_env_params.num_polygons
    num_fixated = static_env_params.num_static_fixated_polys

    shape_active = jnp.concatenate([state.polygon.active.at[:num_fixated].set(False), state.circle.active])

    def keep_state(key, env_state):
        return env_state

    def shift_group(key, env_state):
        _, subkey = jax.random.split(key)
        keys = jax.random.split(subkey, 10)

        weights = shape_active.astype(jnp.float32)
        chosen = jax.random.choice(keys[0], jnp.arange(shape_active.shape[0]), p=weights)
        offset = jax.random.uniform(keys[1], shape=(2,)) - 0.5  # [-0.5, 0.5]

        polygons = env_state.polygon
        circles = env_state.circle
        all_positions = jnp.concatenate([polygons.position, circles.position])

        # Shapes that move along with the chosen one; fixated polygon slots never move.
        moves = (env_state.collision_matrix[chosen] == 0).at[:num_fixated].set(False)
        moves_col = moves[:, None]

        # Apply the offset tentatively; it is corrected below if anything ends up out of bounds.
        tentative = all_positions * (1 - moves_col) + moves_col * (all_positions + offset[None])

        poly_pos = tentative[:num_polys]
        circ_pos = tentative[num_polys:]
        radii = circles.radius
        vertex_valid = jnp.arange(static_env_params.max_polygon_vertices)[None] < polygons.n_vertices[:, None]
        poly_counts = polygons.active.at[:num_fixated].set(False)
        circ_counts = circles.active

        def lowest(axis):
            # Smallest coordinate reached along `axis` by any counted polygon or circle.
            local = jnp.min(polygons.vertices[:, :, axis], where=vertex_valid, initial=0, axis=1)
            from_polys = jnp.min(poly_pos[:, axis] + local, where=poly_counts, initial=jnp.inf)
            from_circles = jnp.min(circ_pos[:, axis] - radii, where=circ_counts, initial=jnp.inf)
            return jnp.minimum(from_polys, from_circles)

        def highest(axis):
            # Largest coordinate reached along `axis` by any counted polygon or circle.
            local = jnp.max(polygons.vertices[:, :, axis], where=vertex_valid, initial=0, axis=1)
            from_polys = jnp.max(poly_pos[:, axis] + local, where=poly_counts, initial=-jnp.inf)
            from_circles = jnp.max(circ_pos[:, axis] + radii, where=circ_counts, initial=-jnp.inf)
            return jnp.maximum(from_polys, from_circles)

        min_x, max_x = lowest(0), highest(0)
        min_y, max_y = lowest(1), highest(1)

        limit_x = static_env_params.screen_dim[0] / params.pixels_per_unit
        limit_y = static_env_params.screen_dim[1] / params.pixels_per_unit

        overshoot_left = jnp.maximum(0, 0 - min_x)
        overshoot_right = jnp.maximum(0, max_x - limit_x)
        overshoot_bottom = jnp.maximum(0, 0.4 - min_y)  # lower limit of 0.4 is for the floor
        overshoot_top = jnp.maximum(0, max_y - limit_y)

        # Push only the moved shapes back inside by the amount of overshoot.
        correction = jnp.array(
            [
                overshoot_left - overshoot_right,
                overshoot_bottom - overshoot_top,
            ]
        )
        corrected = tentative + correction[None] * moves_col

        env_state = env_state.replace(
            polygon=polygons.replace(position=corrected[:num_polys]),
            circle=circles.replace(position=corrected[num_polys:]),
        )

        return recalculate_mass_and_inertia(
            env_state, static_env_params, env_state.polygon_densities, env_state.circle_densities
        )

    return jax.lax.cond(shape_active.sum() > 0, shift_group, keep_state, rng, state)
def make_mutate_change_shape_rotation(params, static_env_params):
    """Build the jitted mutation that rotates one randomly chosen shape.

    The dummy-step function is created once here from ``params`` and
    ``static_env_params`` and captured by the returned mutation.

    Returns:
        ``mutate_change_shape_rotation(rng, state, params, static_env_params, ued_params)``.
    """
    do_dummy_step = make_do_dummy_step(params, static_env_params)

    @partial(jax.jit, static_argnums=(3, 4))
    def mutate_change_shape_rotation(
        rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams
    ):
        """Add a random angle in [0, pi/2) to the rotation of one active shape.

        The shape is sampled uniformly among the active circles and the active
        polygons outside the first ``num_static_fixated_polys`` slots.  Active
        fixed joints attached to the shape have their stored ``rotation``
        adjusted by the same angle (subtracted when the shape is the joint's
        ``a`` body, added when it is the ``b`` body).  Afterwards five dummy
        steps are run and mass/inertia are recalculated.
        NOTE(review): the dummy steps presumably let the physics settle after
        the edit -- confirm against ``make_do_dummy_step``.

        ``params`` and ``ued_params`` are not read by this mutation.  If no
        eligible shape is active, ``state`` is returned unchanged.
        """
        num_polys = static_env_params.num_polygons
        num_fixated = static_env_params.num_static_fixated_polys

        shape_active = jnp.concatenate([state.polygon.active.at[:num_fixated].set(False), state.circle.active])

        def keep_state(key, env_state):
            return env_state

        def rotate_one(key, env_state):
            _, subkey = jax.random.split(key)
            keys = jax.random.split(subkey, 10)

            weights = shape_active.astype(jnp.float32)
            chosen = jax.random.choice(keys[0], jnp.arange(shape_active.shape[0]), p=weights)
            picked_polygon = chosen < num_polys
            angle = jax.random.uniform(keys[1], shape=()) * math.pi / 2

            joints = env_state.joint
            polygons = env_state.polygon
            circles = env_state.circle

            # Active fixed joints that have the chosen shape as their `a` / `b` body.
            fixed_on_a = (joints.a_index == chosen) & joints.is_fixed_joint & joints.active
            fixed_on_b = (joints.b_index == chosen) & joints.is_fixed_joint & joints.active

            # The `a` case takes precedence over the `b` case, as in a nested select.
            rotation_if_b = jax.lax.select(fixed_on_b, joints.rotation + angle, joints.rotation)
            joint_rotation = jax.lax.select(fixed_on_a, joints.rotation - angle, rotation_if_b)

            # Only one of the two shape tables is actually updated, depending on
            # which kind of shape was chosen; the other update is discarded.
            polygon_rotation = jax.lax.select(
                picked_polygon,
                polygons.rotation.at[chosen].add(angle),
                polygons.rotation,
            )
            circle_rotation = jax.lax.select(
                jnp.logical_not(picked_polygon),
                circles.rotation.at[chosen - num_polys].add(angle),
                circles.rotation,
            )

            env_state = env_state.replace(
                joint=joints.replace(rotation=joint_rotation),
                polygon=polygons.replace(rotation=polygon_rotation),
                circle=circles.replace(rotation=circle_rotation),
            )

            def step_once(carry, _):
                return do_dummy_step(carry), None

            env_state, _ = jax.lax.scan(step_once, env_state, jnp.arange(5))

            return recalculate_mass_and_inertia(
                env_state, static_env_params, env_state.polygon_densities, env_state.circle_densities
            )

        return jax.lax.cond(shape_active.sum() > 0, rotate_one, keep_state, rng, state)

    return mutate_change_shape_rotation