# kinet-test / Kinetix / kinetix / render / renderer_symbolic_flat.py
# (repository page residue) tree3po's picture
# (repository page residue) Upload 190 files
# (repository page residue) e0f25ed verified
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax2d import joint
from jax2d.engine import select_shape
from jax2d.maths import rmat
from jax2d.sim_state import RigidBody
from jaxgl.maths import dist_from_line
from jaxgl.renderer import clear_screen, make_renderer
from jaxgl.shaders import (
fragment_shader_quad,
fragment_shader_edged_quad,
make_fragment_shader_texture,
nearest_neighbour,
make_fragment_shader_quad_textured,
)
from kinetix.render.renderer_symbolic_common import (
make_circle_features,
make_joint_features,
make_polygon_features,
make_thruster_features,
)
from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState
from flax import struct
def make_render_symbolic(params, static_params: StaticEnvParams):
    """Build a renderer that flattens an environment state into a 1-D symbolic observation.

    Args:
        params: Environment parameters; forwarded unchanged to the
            ``make_*_features`` helpers.
        static_params: Static environment parameters. ``num_polygons`` and
            ``num_circles`` are read to size the polygon mask and the
            one-hot shape encodings.

    Returns:
        A function ``render_symbolic(state)`` returning a 1-D array made of the
        flattened polygon, circle, joint and thruster features followed by the
        scaled vertical gravity component, clipped to ``[-10, 10]``.
    """

    def render_symbolic(state):
        """Return the flat symbolic observation for ``state``."""
        n_polys = static_params.num_polygons
        nshapes = n_polys + static_params.num_circles

        polygon_features, polygon_mask = make_polygon_features(state, params, static_params)

        # Drop polygons 1, 2 and 3 from the observation (presumably the walls and
        # ceiling, going by the variable name -- TODO confirm). The mask is a
        # concrete NumPy array, so the boolean indexing below yields a fixed-size
        # result. NOTE: this indexing requires num_polygons >= 4.
        mask_to_ignore_walls_ceiling = np.ones(n_polys, dtype=bool)
        mask_to_ignore_walls_ceiling[np.array([1, 2, 3])] = False
        polygon_features = polygon_features[mask_to_ignore_walls_ceiling]
        polygon_mask = polygon_mask[mask_to_ignore_walls_ceiling]

        circle_features, circle_mask = make_circle_features(state, params, static_params)
        joint_features, joint_idxs, joint_mask = make_joint_features(state, params, static_params)
        thruster_features, thruster_idxs, thruster_mask = make_thruster_features(state, params, static_params)

        # Only the first half of the joint rows is kept for the symbolic
        # observation ("for symbolic only have the one").
        two_J = joint_features.shape[0]
        J = two_J // 2

        # Append one-hot encodings of the two shape indices each joint refers to.
        # The one-hot width is nshapes, i.e. it still counts the polygons that
        # were dropped from polygon_features above.
        joint_features = jnp.concatenate(
            [
                joint_features[:J],  # J rows of joint features
                jax.nn.one_hot(joint_idxs[:J, 0], nshapes),  # shape (J, nshapes)
                jax.nn.one_hot(joint_idxs[:J, 1], nshapes),  # shape (J, nshapes)
            ],
            axis=1,
        )

        # Append a one-hot encoding of the shape index each thruster refers to.
        thruster_features = jnp.concatenate(
            [
                thruster_features,
                jax.nn.one_hot(thruster_idxs, nshapes),
            ],
            axis=1,
        )

        # Zero out rows whose mask entry is False, then flatten each group to 1-D.
        polygon_features = jnp.where(polygon_mask[:, None], polygon_features, 0.0).flatten()
        circle_features = jnp.where(circle_mask[:, None], circle_features, 0.0).flatten()
        joint_features = jnp.where(joint_mask[:J, None], joint_features, 0.0).flatten()
        thruster_features = jnp.where(thruster_mask[:, None], thruster_features, 0.0).flatten()

        def _get_manifold_features(manifold):
            """Flatten a collision manifold's fields, zeroing inactive entries.

            NOTE: currently unused -- only referenced by the commented-out
            entries in the observation below; kept so they can be re-enabled.
            """
            collision_mask_features = jnp.concatenate(
                [
                    manifold.normal,
                    jnp.expand_dims(manifold.penetration, axis=-1),
                    manifold.collision_point,
                    jnp.expand_dims(manifold.acc_impulse_normal, axis=-1),
                    jnp.expand_dims(manifold.acc_impulse_tangent, axis=-1),
                ],
                axis=-1,
            )
            return (collision_mask_features * manifold.active[..., None]).flatten()

        obs = jnp.concatenate(
            [
                polygon_features,
                circle_features,
                joint_features,
                thruster_features,
                # Second gravity component, divided by 10.
                jnp.array([state.gravity[1]]) / 10,
                # _get_manifold_features(state.acc_cc_manifolds),
                # _get_manifold_features(state.acc_cr_manifolds),
                # _get_manifold_features(state.acc_rr_manifolds),
            ],
            axis=0,
        )

        # Bounds are passed positionally: the `a_min`/`a_max` keyword names are
        # deprecated in newer JAX releases (replaced by `min`/`max`), while the
        # positional form is accepted by both the old and the new signature.
        obs = jnp.clip(obs, -10.0, 10.0)
        # Sanitise any non-finite values that remain after clipping.
        obs = jnp.nan_to_num(obs)
        return obs

    return render_symbolic