Spaces:
Runtime error
Runtime error
import jax | |
from jax2d.sim_state import RigidBody | |
import jax.numpy as jnp | |
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams | |
def _get_base_shape_features(
    density: jnp.ndarray, roles: jnp.ndarray, shapes: RigidBody, env_params: EnvParams
) -> jnp.ndarray:
    """Assemble the feature columns shared by every shape type.

    Columns, in order: position, velocity, inverse mass, inverse inertia,
    density, ``tanh(angular_velocity / 10)``, one-hot role (with
    ``env_params.num_shape_roles`` classes), ``sin(rotation)``,
    ``cos(rotation)``, friction, restitution.

    Args:
        density: Per-shape densities; a feature axis is inserted at axis 1.
        roles: Per-shape role ids passed to ``jax.nn.one_hot``.
        shapes: Rigid bodies whose fields are read. Presumably one row per
            shape with 2-D ``position``/``velocity`` -- TODO confirm against
            ``RigidBody``.
        env_params: Only ``num_shape_roles`` is read.

    Returns:
        The columns above joined along axis 1.
    """

    def _column(values):
        # Per-shape scalar fields need an explicit feature axis before joining.
        return jnp.expand_dims(values, axis=1)

    rotation_sin = jnp.sin(shapes.rotation)
    rotation_cos = jnp.cos(shapes.rotation)
    # Angular velocity is divided by 10 and passed through tanh before use.
    squashed_spin = jnp.tanh(shapes.angular_velocity / 10)
    role_one_hot = jax.nn.one_hot(roles, env_params.num_shape_roles)

    feature_columns = [
        shapes.position,
        shapes.velocity,
        _column(shapes.inverse_mass),
        _column(shapes.inverse_inertia),
        _column(density),
        _column(squashed_spin),
        role_one_hot,
        _column(rotation_sin),
        _column(rotation_cos),
        _column(shapes.friction),
        _column(shapes.restitution),
    ]
    return jnp.concatenate(feature_columns, axis=1)
def add_circle_features(
    base_features: jnp.ndarray, shapes: RigidBody, env_params: EnvParams, static_env_params: StaticEnvParams
):
    """Append the radius column and a constant-one flag column.

    Args:
        base_features: Existing per-shape features; extended along axis 1.
        shapes: Rigid bodies; only ``radius`` is read.
        env_params: Accepted but not read in this function.
        static_env_params: Accepted but not read in this function.

    Returns:
        ``base_features`` followed by the radius column and a column of ones.
    """
    radius_column = shapes.radius[:, None]
    # "One for circle": a column of ones with the same row count and dtype
    # as the incoming features.
    circle_flag = jnp.ones_like(base_features[:, :1])
    return jnp.concatenate([base_features, radius_column, circle_flag], axis=1)
def make_circle_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Build the circle node features and return them with ``state.circle.active``.

    The shared base features are computed from the circle densities, roles
    and bodies, then extended by ``add_circle_features``.

    Returns:
        ``(node_features, state.circle.active)``.
    """
    circles = state.circle
    shared = _get_base_shape_features(state.circle_densities, state.circle_shape_roles, circles, env_params)
    return add_circle_features(shared, circles, env_params, static_env_params), circles.active
def add_polygon_features(
    base_features: jnp.ndarray, shapes: RigidBody, env_params: EnvParams, static_env_params: StaticEnvParams
):
    """Append polygon-specific columns to ``base_features``.

    Appended columns, in order: a constant-zero flag, the flattened vertex
    array (slots at or beyond ``n_vertices`` replaced by -1), and the
    ``n_vertices <= 3`` indicator.

    Args:
        base_features: Existing per-shape features; extended along axis 1.
        shapes: Rigid bodies; ``vertices`` and ``n_vertices`` are read.
        env_params: Accepted but not read in this function.
        static_env_params: Supplies ``max_polygon_vertices``.

    Returns:
        ``base_features`` followed by the columns listed above.
    """
    # Compare each vertex slot number against the per-shape vertex count;
    # the added axes let the comparison broadcast over shapes.vertices.
    slot_ids = jnp.arange(static_env_params.max_polygon_vertices)[None, :, None]
    slot_in_use = slot_ids < shapes.n_vertices[:, None, None]
    unused_fill = jnp.zeros_like(shapes.vertices) - 1
    padded_vertices = jnp.where(slot_in_use, shapes.vertices, unused_fill)

    # "Zero for polygon": a column of zeros matching the incoming features.
    polygon_flag = jnp.zeros_like(base_features[:, :1])
    flat_vertices = padded_vertices.reshape((padded_vertices.shape[0], -1))
    at_most_three_vertices = jnp.expand_dims((shapes.n_vertices <= 3), axis=1)

    return jnp.concatenate(
        [base_features, polygon_flag, flat_vertices, at_most_three_vertices],
        axis=1,
    )
def make_polygon_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Build the polygon node features and return them with ``state.polygon.active``.

    The shared base features are computed from the polygon densities, roles
    and bodies, then extended by ``add_polygon_features``.

    Returns:
        ``(node_features, state.polygon.active)``.
    """
    polygons = state.polygon
    shared = _get_base_shape_features(state.polygon_densities, state.polygon_shape_roles, polygons, env_params)
    return add_polygon_features(shared, polygons, env_params, static_env_params), polygons.active
def make_unified_shape_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Build one feature matrix covering polygons and circles together.

    Both shape sets go through the same pipeline: base features, then
    ``add_polygon_features``, then ``add_circle_features``. Polygon rows are
    placed before circle rows.

    Returns:
        ``(features, active)``, where both are the polygon arrays followed by
        the circle arrays along axis 0.
    """

    def _full_feature_set(densities, roles, shapes):
        # Every shape receives the polygon columns and then the circle columns.
        features = _get_base_shape_features(densities, roles, shapes, env_params)
        features = add_polygon_features(features, shapes, env_params, static_env_params)
        return add_circle_features(features, shapes, env_params, static_env_params)

    polygon_rows = _full_feature_set(state.polygon_densities, state.polygon_shape_roles, state.polygon)
    circle_rows = _full_feature_set(state.circle_densities, state.circle_shape_roles, state.circle)

    all_rows = jnp.concatenate([polygon_rows, circle_rows], axis=0)
    all_active = jnp.concatenate([state.polygon.active, state.circle.active], axis=0)
    return all_rows, all_active
def make_joint_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Build directed joint features, endpoint index pairs and a mask.

    Each joint in ``state.joint`` is emitted twice, once per direction: all
    ``J`` joints in the b -> a orientation first, then all ``J`` joints in
    the a -> b orientation.

    Args:
        state: ``joint`` and ``motor_bindings`` are read.
        env_params: Accepted but not read in this function.
        static_env_params: Supplies ``num_motor_bindings``.

    Returns:
        ``(joint_features, indexes, mask)`` of shape
        ``(2 * J, K)``, ``(2 * J, 2)``, ``(2 * J,)``.
    """
    joints = state.joint
    num_joints = joints.active.shape[0]

    def _column(values):
        return values[:, None]

    has_limits = joints.motor_has_joint_limits
    is_fixed = joints.is_fixed_joint * 1.0

    # Limit-related terms are multiplied by the has-limits flag, so they are
    # zero for joints without limits.
    limit_max_sin = jnp.sin(joints.max_rotation) * has_limits
    limit_max_cos = jnp.cos(joints.max_rotation) * has_limits
    limit_min_sin = jnp.sin(joints.min_rotation) * has_limits
    limit_min_cos = jnp.cos(joints.min_rotation) * has_limits
    headroom_to_max = (joints.max_rotation - joints.rotation) * has_limits
    headroom_to_min = (joints.min_rotation - joints.rotation) * has_limits

    motor_binding_one_hot = jax.nn.one_hot(
        state.motor_bindings, num_classes=static_env_params.num_motor_bindings
    )

    # Motor/limit columns are multiplied by (1 - is_fixed), which zeroes the
    # whole group for fixed joints.
    motor_and_limit_features = jnp.concatenate(
        [
            _column(joints.motor_speed),
            _column(joints.motor_power),
            _column(joints.motor_on * 1.0),
            _column(has_limits * 1.0),
            motor_binding_one_hot,
            _column(limit_min_sin),
            _column(limit_min_cos),
            _column(limit_max_sin),
            _column(limit_max_cos),
            _column(headroom_to_min),
            _column(headroom_to_max),
        ],
        axis=1,
    ) * _column(1.0 - is_fixed)

    active_column = _column(joints.active * 1.0)
    fixed_column = _column(is_fixed)  # J, 1
    rotation_sin_column = _column(jnp.sin(joints.rotation))
    rotation_cos_column = _column(jnp.cos(joints.rotation))

    def _one_direction(direction):
        # Only the anchor ordering depends on the direction flag:
        # True -> (a, b), False -> (b, a).
        source_anchor = jax.lax.select(direction, joints.a_relative_pos, joints.b_relative_pos)
        target_anchor = jax.lax.select(direction, joints.b_relative_pos, joints.a_relative_pos)
        common_features = jnp.concatenate(
            [
                active_column,
                fixed_column,
                source_anchor,
                target_anchor,
                rotation_sin_column,
                rotation_cos_column,
            ],
            axis=1,
        )
        return jnp.concatenate([common_features, motor_and_limit_features], axis=1)

    # 2, J, A -- direction False (b -> a) comes first, matching the index order below.
    directed_features = jax.vmap(_one_direction)(jnp.array([False, True]))

    # J, 2 each
    b_then_a = jnp.concatenate([_column(joints.b_index), _column(joints.a_index)], axis=1)
    a_then_b = jnp.concatenate([_column(joints.a_index), _column(joints.b_index)], axis=1)

    # Inactive joints get both endpoint indexes replaced by 0.
    is_active = _column(joints.active)
    b_then_a = jnp.where(is_active, b_then_a, jnp.zeros_like(b_then_a))
    a_then_b = jnp.where(is_active, a_then_b, jnp.zeros_like(a_then_b))

    indexes = jnp.concatenate([b_then_a, a_then_b], axis=0)
    mask = jnp.concatenate([joints.active, joints.active], axis=0)
    return directed_features.reshape((2 * num_joints, -1)), indexes, mask
def make_thruster_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Build thruster features and return them with object indexes and a mask.

    Feature columns, in order: active flag (as float), relative position,
    one-hot thruster binding (``static_env_params.num_thruster_bindings``
    classes), ``sin(rotation)``, ``cos(rotation)``, power.

    Args:
        state: ``thruster`` and ``thruster_bindings`` are read.
        env_params: Accepted but not read in this function.
        static_env_params: Supplies ``num_thruster_bindings``.

    Returns:
        ``(thruster_features, indexes, mask)``: the feature matrix,
        ``state.thruster.object_index`` and ``state.thruster.active``.
        Per the original note these are shaped ``(T, K)``, ``(T,)``, ``(T,)``.
    """
    thrusters = state.thruster
    rotation_sin = jnp.sin(thrusters.rotation)
    rotation_cos = jnp.cos(thrusters.rotation)
    binding_one_hot = jax.nn.one_hot(
        state.thruster_bindings, num_classes=static_env_params.num_thruster_bindings
    )

    thruster_features = jnp.concatenate(
        [
            (thrusters.active * 1.0)[:, None],
            thrusters.relative_position,
            binding_one_hot,
            rotation_sin[:, None],
            rotation_cos[:, None],
            thrusters.power[:, None],
        ],
        axis=1,
    )
    return thruster_features, thrusters.object_index, thrusters.active