# kinetix/render/renderer_symbolic_common.py
# Source: kinet-test / Kinetix (uploaded by tree3po, "Upload 190 files", revision e0f25ed).
import jax
from jax2d.sim_state import RigidBody
import jax.numpy as jnp
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
def _get_base_shape_features(
    density: jnp.ndarray, roles: jnp.ndarray, shapes: RigidBody, env_params: EnvParams
) -> jnp.ndarray:
    """Build the feature columns shared by every shape kind (circle or polygon).

    Args:
        density: One density value per shape (per row of ``shapes``).
        roles: Integer role id per shape, one-hot encoded with
            ``env_params.num_shape_roles`` classes.
        shapes: Batched rigid bodies whose per-shape fields are read below.
        env_params: Supplies ``num_shape_roles``.

    Returns:
        A 2-D array with one row per shape. Columns, in order: position,
        velocity, inverse mass, inverse inertia, density,
        tanh(angular_velocity / 10), one-hot role, sin(rotation),
        cos(rotation), friction, restitution.
    """

    def as_column(values: jnp.ndarray) -> jnp.ndarray:
        # Lift a per-shape vector into an (N, 1) column for concatenation.
        return values[:, None]

    heading = shapes.rotation
    columns = [
        shapes.position,
        shapes.velocity,
        as_column(shapes.inverse_mass),
        as_column(shapes.inverse_inertia),
        as_column(density),
        # Angular velocity is squashed into (-1, 1) with a fixed scale of 10.
        as_column(jnp.tanh(shapes.angular_velocity / 10)),
        jax.nn.one_hot(roles, env_params.num_shape_roles),
        # Rotation is fed as (sin, cos) rather than as the raw angle.
        as_column(jnp.sin(heading)),
        as_column(jnp.cos(heading)),
        as_column(shapes.friction),
        as_column(shapes.restitution),
    ]
    return jnp.concatenate(columns, axis=1)
def add_circle_features(
    base_features: jnp.ndarray, shapes: RigidBody, env_params: EnvParams, static_env_params: StaticEnvParams
):
    """Append the circle-specific columns to ``base_features``.

    Two columns are added after the existing ones: the shape radius, then a
    constant 1.0 marker (the "circle" flag). ``env_params`` and
    ``static_env_params`` are not used here; the signature mirrors
    ``add_polygon_features``.
    """
    radius_col = jnp.expand_dims(shapes.radius, axis=1)
    circle_flag = jnp.ones_like(base_features[:, :1])  # 1.0 marks a circle
    return jnp.concatenate([base_features, radius_col, circle_flag], axis=1)
def make_circle_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return ``(node_features, mask)`` for every circle slot in ``state``.

    ``node_features`` is the shared base feature set followed by the
    circle-specific columns; ``mask`` is ``state.circle.active``.
    """
    circles = state.circle
    shared = _get_base_shape_features(
        state.circle_densities, state.circle_shape_roles, circles, env_params
    )
    features = add_circle_features(shared, circles, env_params, static_env_params)
    return features, circles.active
def add_polygon_features(
    base_features: jnp.ndarray, shapes: RigidBody, env_params: EnvParams, static_env_params: StaticEnvParams
):
    """Append the polygon-specific columns to ``base_features``.

    Columns added, in order:
      * a constant 0.0 marker (the "polygon" flag),
      * the flattened vertex array, where every vertex slot at or beyond a
        shape's ``n_vertices`` is replaced by -1,
      * a flag that is true when the shape has at most 3 vertices.

    ``env_params`` is not used; ``static_env_params`` supplies
    ``max_polygon_vertices``.
    """
    slot = jnp.arange(static_env_params.max_polygon_vertices)
    # Broadcast each vertex-slot index against that shape's vertex count.
    in_use = slot[None, :, None] < shapes.n_vertices[:, None, None]
    padded = jnp.where(in_use, shapes.vertices, jnp.zeros_like(shapes.vertices) - 1)
    n_shapes = padded.shape[0]
    polygon_flag = jnp.zeros_like(base_features[:, :1])  # 0.0 marks a polygon
    is_triangle = (shapes.n_vertices <= 3)[:, None]
    return jnp.concatenate(
        [base_features, polygon_flag, padded.reshape((n_shapes, -1)), is_triangle],
        axis=1,
    )
def make_polygon_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return ``(node_features, mask)`` for every polygon slot in ``state``.

    ``node_features`` is the shared base feature set followed by the
    polygon-specific columns; ``mask`` is ``state.polygon.active``.
    """
    polygons = state.polygon
    shared = _get_base_shape_features(
        state.polygon_densities, state.polygon_shape_roles, polygons, env_params
    )
    return add_polygon_features(shared, polygons, env_params, static_env_params), polygons.active
def make_unified_shape_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return polygon and circle features stacked into one node set.

    Every shape, whatever its kind, gets the base features, then the polygon
    columns, then the circle columns, so all rows share one column layout.
    Rows are ordered polygons first, then circles. The returned mask
    concatenates ``state.polygon.active`` and ``state.circle.active`` in that
    same order.

    NOTE(review): ``add_polygon_features`` and ``add_circle_features`` are
    both applied to every row. As a result, the "polygon" marker column is
    0.0 and the "circle" marker column is 1.0 for polygons and circles alike,
    so neither marker tells shape kinds apart here. This is kept unchanged to
    preserve observation values; confirm whether that is intended.
    """

    def _encode(densities, roles, bodies):
        # Base features -> polygon columns -> circle columns, for one shape kind.
        feats = _get_base_shape_features(densities, roles, bodies, env_params)
        feats = add_polygon_features(feats, bodies, env_params, static_env_params)
        return add_circle_features(feats, bodies, env_params, static_env_params)

    polygon_rows = _encode(state.polygon_densities, state.polygon_shape_roles, state.polygon)
    circle_rows = _encode(state.circle_densities, state.circle_shape_roles, state.circle)
    features = jnp.concatenate([polygon_rows, circle_rows], axis=0)
    mask = jnp.concatenate([state.polygon.active, state.circle.active], axis=0)
    return features, mask
def make_joint_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Build directed edge features for every joint slot in ``state.joint``.

    Each joint is emitted twice, once per direction, so every output has
    2 * J rows, where J is the number of joint slots. ``env_params`` is not
    used.

    Returns:
        joint_features: (2 * J, K) edge features. Rows [0, J) describe the
            b -> a direction (from = b_relative_pos, to = a_relative_pos).
            Rows [J, 2 * J) describe a -> b.
        indexes: (2 * J, 2) shape-index pairs aligned row-for-row with
            ``joint_features``: [b_index, a_index] for the first J rows and
            [a_index, b_index] for the rest. Inactive joints get [0, 0].
        mask: (2 * J,) ``joints.active`` repeated once per direction.
    """
    def _create_joint_features(joints):
        """Return ``(features, indexes, mask)`` for the batched ``joints``."""
        J = joints.active.shape[0]  # number of joint slots

        def _create_1way_joint_features(direction):
            """Features for one direction of every joint, shape (J, K).

            ``direction`` is a scalar bool: False -> (from = b, to = a),
            True -> (from = a, to = b).
            """
            from_pos = jax.lax.select(direction, joints.a_relative_pos, joints.b_relative_pos)
            to_pos = jax.lax.select(direction, joints.b_relative_pos, joints.a_relative_pos)
            # NOTE(review): the rotation terms below do not depend on
            # ``direction`` (they are not negated for the reversed edge).
            # Confirm this is intended.
            rotation_sin, rotation_cos = jnp.sin(joints.rotation), jnp.cos(joints.rotation)
            # Limit-related quantities are zeroed for joints without joint limits.
            rotation_max_sin = jnp.sin(joints.max_rotation) * joints.motor_has_joint_limits
            rotation_max_cos = jnp.cos(joints.max_rotation) * joints.motor_has_joint_limits
            rotation_min_sin = jnp.sin(joints.min_rotation) * joints.motor_has_joint_limits
            rotation_min_cos = jnp.cos(joints.min_rotation) * joints.motor_has_joint_limits
            # Signed distance from the current rotation to each limit.
            rotation_diff_max = (joints.max_rotation - joints.rotation) * joints.motor_has_joint_limits
            rotation_diff_min = (joints.min_rotation - joints.rotation) * joints.motor_has_joint_limits
            # Columns present for every joint (fixed or revolute).
            base_features = jnp.concatenate(
                [
                    (joints.active * 1.0)[:, None],
                    (joints.is_fixed_joint * 1.0)[:, None],  # J, 1
                    from_pos,
                    to_pos,
                    rotation_sin[:, None],
                    rotation_cos[:, None],
                ],
                axis=1,
            )
            # Motor / joint-limit columns. The whole block is multiplied by
            # (1 - is_fixed_joint), so it is all zeros for fixed joints.
            rjoint_features = (
                jnp.concatenate(
                    [
                        joints.motor_speed[:, None],
                        joints.motor_power[:, None],
                        (joints.motor_on * 1.0)[:, None],
                        (joints.motor_has_joint_limits * 1.0)[:, None],
                        jax.nn.one_hot(state.motor_bindings, num_classes=static_env_params.num_motor_bindings),
                        rotation_min_sin[:, None],
                        rotation_min_cos[:, None],
                        rotation_max_sin[:, None],
                        rotation_max_cos[:, None],
                        rotation_diff_min[:, None],
                        rotation_diff_max[:, None],
                    ],
                    axis=1,
                )
                * (1.0 - (joints.is_fixed_joint * 1.0))[:, None]
            )
            return jnp.concatenate([base_features, rjoint_features], axis=1)

        # Evaluate both directions at once: (2, J, K); index 0 is b -> a, index 1 is a -> b.
        joint_features = jax.vmap(_create_1way_joint_features)(jnp.array([False, True]))
        # Each of these is (J, 2). NOTE: despite the names, ``indexes_from``
        # holds the [from, to] pairs for the b -> a rows and ``indexes_to``
        # holds them for the a -> b rows.
        indexes_from = jnp.concatenate([joints.b_index[:, None], joints.a_index[:, None]], axis=1)
        indexes_to = jnp.concatenate([joints.a_index[:, None], joints.b_index[:, None]], axis=1)
        # Inactive joints point at shape index 0 in both slots.
        indexes_from = jnp.where(joints.active[:, None], indexes_from, jnp.zeros_like(indexes_from))
        indexes_to = jnp.where(joints.active[:, None], indexes_to, jnp.zeros_like(indexes_to))
        indexes = jnp.concatenate([indexes_from, indexes_to], axis=0)
        mask = jnp.concatenate([joints.active, joints.active], axis=0)
        # (2, J, K) -> (2 * J, K); row order matches ``indexes`` and ``mask``.
        return joint_features.reshape((2 * J, -1)), indexes, mask

    return _create_joint_features(state.joint)
def make_thruster_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Build per-thruster node features.

    Returns:
        features: (T, K) array. Columns, in order: active flag, relative
            position, one-hot thruster binding, sin(rotation),
            cos(rotation), power.
        indexes: ``state.thruster.object_index``, shape (T,).
        mask: ``state.thruster.active``, shape (T,).

    ``env_params`` is not used.
    """
    thrusters = state.thruster
    angle = thrusters.rotation
    binding_one_hot = jax.nn.one_hot(
        state.thruster_bindings, num_classes=static_env_params.num_thruster_bindings
    )
    features = jnp.concatenate(
        [
            (thrusters.active * 1.0)[:, None],
            thrusters.relative_position,
            binding_one_hot,
            jnp.sin(angle)[:, None],
            jnp.cos(angle)[:, None],
            thrusters.power[:, None],
        ],
        axis=1,
    )
    return features, thrusters.object_index, thrusters.active