kinet-test / Kinetix /kinetix /render /renderer_pixels.py
tree3po's picture
Upload 190 files
e0f25ed verified
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax2d import joint
from jax2d.engine import select_shape
from jax2d.maths import rmat
from jax2d.sim_state import RigidBody
from jaxgl.maths import dist_from_line
from jaxgl.renderer import clear_screen, make_renderer
from jaxgl.shaders import (
fragment_shader_quad,
fragment_shader_edged_quad,
make_fragment_shader_texture,
nearest_neighbour,
make_fragment_shader_quad_textured,
)
from kinetix.render.textures import (
THRUSTER_TEXTURE_16_RGBA,
RJOINT_TEXTURE_6_RGBA,
FJOINT_TEXTURE_6_RGBA,
)
from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState
from flax import struct
def make_render_pixels(
    params,
    static_params: StaticEnvParams,
):
    """Build a jitted function that draws an ``EnvState`` into a pixel buffer.

    All sizes, colour tables, fragment shaders and renderers are created once
    here; the returned ``render_pixels(state)`` closure only assembles the
    per-shape uniforms and runs the draw passes.

    Args:
        params: Environment parameters. Only ``params.pixels_per_unit`` is read.
        static_params: Static environment parameters. This function reads
            ``screen_dim``, ``downscale``, ``max_shape_size``,
            ``max_polygon_vertices`` and ``num_polygons``, and forwards the
            whole object to ``select_shape``.

    Returns:
        A ``jax.jit``-compiled function ``render_pixels(state)`` that returns
        the drawn buffer with the off-screen margin cropped away. Colours are
        written on a 0-255 scale.
        NOTE(review): the buffer layout is whatever ``clear_screen`` produces;
        presumably (width, height, 3) - confirm against jaxgl.
    """
    screen_dim = static_params.screen_dim
    downscale = static_params.downscale
    # Texture side lengths in texels.
    # NOTE(review): presumably these match the `_6_` / `_16_` in the imported texture names - confirm.
    joint_tex_size = 6
    thruster_tex_size = 16
    # Colour for shapes with zero inverse mass/inertia that have no special role (see _get_colour).
    FIXATED_COLOUR = jnp.array([80, 80, 80])
    # Tint table shared by joints and thrusters. It is indexed with `binding + 1`,
    # and joints whose motor is off or that are fixed always get index 0.
    JOINT_COLOURS = jnp.array(
        [
            # [0, 0, 255],
            [255, 255, 255], # white
            [255, 255, 0], # yellow
            [255, 0, 255], # purple/magenta
            [0, 255, 255], # cyan
            [255, 153, 51], # orange
        ]
    )
    def colour_thruster_texture(colour):
        """Return the thruster texture with the RGB of its first 9 rows scaled by ``colour / 255``."""
        return THRUSTER_TEXTURE_16_RGBA.at[:9, :, :3].mul(colour[None, None, :] / 255.0)
    # One tinted thruster texture per entry of JOINT_COLOURS.
    coloured_thruster_textures = jax.vmap(colour_thruster_texture)(JOINT_COLOURS)
    # Base fill colour for a shape, indexed by its role id.
    ROLE_COLOURS = jnp.array(
        [
            [160.0, 160.0, 160.0], # None
            [0.0, 204.0, 0.0], # Green: The ball
            [0.0, 102.0, 204.0], # Blue: The goal
            [255.0, 102.0, 102.0], # Red: Death Objects
        ]
    )
    BACKGROUND_COLOUR = jnp.array([255.0, 255.0, 255.0])
    def _get_colour(shape_role, inverse_inertia):
        """Pick the fill colour for a shape from its role and whether it is fixated.

        A shape counts as fixated when ``inverse_inertia == 0``. Note that the
        callers in ``render_pixels`` pass ``inverse_mass`` (or the literal 0 for
        the floor) for this argument.
        """
        base_colour = ROLE_COLOURS[shape_role]
        f = (inverse_inertia == 0) * 1
        is_not_normal = (shape_role != 0) * 1
        # Four-entry lookup table indexed by 2 * fixated + has_role.
        return jnp.array(
            [
                base_colour, # not fixated, no role
                base_colour, # not fixated, has a role
                FIXATED_COLOUR, # fixated, no role
                base_colour * 0.5, # fixated, has a role: role colour at half brightness
            ]
        )[2 * f + is_not_normal]
    # Pixels per unit distance
    ppud = params.pixels_per_unit // downscale
    downscaled_screen_dim = (screen_dim[0] // downscale, screen_dim[1] // downscale)
    # The working buffer is the downscaled screen plus a margin of
    # `max_shape_size * ppud` pixels on every side; the margin is cropped off
    # again at the end of render_pixels.
    # NOTE(review): presumably the margin lets patches of shapes near the border stay inside the buffer - confirm.
    full_screen_size = (
        downscaled_screen_dim[0] + (static_params.max_shape_size * 2 * ppud),
        downscaled_screen_dim[1] + (static_params.max_shape_size * 2 * ppud),
    )
    cleared_screen = clear_screen(full_screen_size, BACKGROUND_COLOUR)
    def _world_space_to_pixel_space(x):
        """Map world coordinates to working-buffer pixels (offset by the margin, scaled by ``ppud``)."""
        return (x + static_params.max_shape_size) * ppud
    def fragment_shader_kinetix_circle(position, current_frag, unit_position, uniform):
        """Shade one fragment of a circle: filled disc, black rim and a black rotation marker.

        ``uniform`` is ``(centre, radius, rotation, colour, mask)``. Fragments
        outside the disc, or belonging to a masked-out circle, keep
        ``current_frag``. ``unit_position`` is unused.
        """
        centre, radius, rotation, colour, mask = uniform
        dist = jnp.sqrt(jnp.square(position - centre).sum())
        inside = dist <= radius
        # Rim: everything further than `radius - 2` from the centre.
        on_edge = dist > radius - 2
        # TODO - precompute?
        normal = jnp.array([jnp.sin(rotation), -jnp.cos(rotation)])
        # `dist` is reused here for the distance to the line through the centre along `normal`.
        # NOTE(review): relies on jaxgl's dist_from_line(point, a, b) semantics - confirm.
        dist = dist_from_line(position, centre, centre + normal)
        # Only the half of that line where dot(normal, position - centre) <= 0 is drawn.
        on_edge |= (dist < 1) & (jnp.dot(normal, position - centre) <= 0)
        # Rim and marker pixels are black; the rest of the disc uses the fill colour.
        fragment = jax.lax.select(on_edge, jnp.zeros(3), colour)
        return jax.lax.select(inside & mask, fragment, current_frag)
    def fragment_shader_kinetix_joint(position, current_frag, unit_position, uniform):
        """Shade one fragment of a joint icon by sampling, tinting and blending its texture.

        ``uniform`` is ``(texture, colour, mask)``. ``position`` is unused.
        """
        texture, colour, mask = uniform
        # Scale the patch-relative position up to texel coordinates.
        tex_coord = (
            jnp.array(
                [
                    joint_tex_size * unit_position[0],
                    joint_tex_size * unit_position[1],
                ]
            )
            - 0.5
        )
        tex_frag = nearest_neighbour(texture, tex_coord)
        # Channel 3 is used as the blend weight; zero it for masked-out joints.
        tex_frag = tex_frag.at[3].mul(mask)
        tex_frag = tex_frag.at[:3].mul(colour / 255.0)
        # Blend over the existing pixel (assumes the weight lies in [0, 1] - TODO confirm).
        tex_frag = (tex_frag[3] * tex_frag[:3]) + ((1.0 - tex_frag[3]) * current_frag)
        return tex_frag
    thruster_pixel_size = thruster_tex_size // downscale
    # Patch side: the diagonal of the thruster square (side * sqrt(2)), truncated, plus one.
    # NOTE(review): presumably sized so the texture fits inside the patch at any rotation - confirm.
    thruster_pixel_size_diagonal = (thruster_pixel_size * np.sqrt(2)).astype(jnp.int32) + 1
    def fragment_shader_kinetix_thruster(fragment_position, current_frag, unit_position, uniform):
        """Shade one fragment of a rotated thruster icon.

        ``uniform`` is ``(thruster_position, rotation, texture, mask)``.
        ``unit_position`` is unused; the texture coordinate is derived from the
        fragment's offset to the thruster position instead.
        """
        thruster_position, rotation, texture, mask = uniform
        # Rotate the offset by -rotation, scale by the icon size and shift by 0.5.
        tex_position = jnp.matmul(rmat(-rotation), (fragment_position - thruster_position)) / thruster_pixel_size + 0.5
        # Discard fragments that fall outside the unit square of the texture.
        mask &= (tex_position[0] >= 0) & (tex_position[0] <= 1) & (tex_position[1] >= 0) & (tex_position[1] <= 1)
        # Small offset added to the texel coordinate before sampling.
        # NOTE(review): presumably avoids landing exactly on a texel boundary - confirm.
        eps = 0.001
        tex_coord = (
            jnp.array(
                [
                    thruster_tex_size * tex_position[0],
                    thruster_tex_size * tex_position[1],
                ]
            )
            - 0.5
            + eps
        )
        tex_frag = nearest_neighbour(texture, tex_coord)
        # Channel 3 is used as the blend weight; zero it for masked-out fragments.
        tex_frag = tex_frag.at[3].mul(mask)
        tex_frag = (tex_frag[3] * tex_frag[:3]) + ((1.0 - tex_frag[3]) * current_frag)
        return tex_frag
    # Polygons and circles share a square patch of side `max_shape_size * ppud`.
    patch_size_1d = static_params.max_shape_size * ppud
    patch_size = (patch_size_1d, patch_size_1d)
    circle_renderer = make_renderer(full_screen_size, fragment_shader_kinetix_circle, patch_size, batched=True)
    quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, patch_size, batched=True)
    # Non-batched renderer whose patch spans the whole downscaled screen; used for the floor only.
    big_quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, downscaled_screen_dim)
    joint_pixel_size = joint_tex_size // downscale
    joint_renderer = make_renderer(
        full_screen_size, fragment_shader_kinetix_joint, (joint_pixel_size, joint_pixel_size), batched=True
    )
    thruster_renderer = make_renderer(
        full_screen_size,
        fragment_shader_kinetix_thruster,
        (thruster_pixel_size_diagonal, thruster_pixel_size_diagonal),
        batched=True,
    )
    @jax.jit
    def render_pixels(state: EnvState):
        """Draw ``state`` and return the buffer with the margin removed.

        Passes run in this order: floor, polygons, circles, joints, thrusters.
        Each pass receives the buffer produced by the previous one.
        """
        pixels = cleared_screen
        # Floor
        # Polygon 0 is treated as the floor. Its vertices are only translated
        # (no rotation is applied) and it is coloured as a fixated shape.
        floor_uniform = (
            _world_space_to_pixel_space(state.polygon.position[0, None, :] + state.polygon.vertices[0]),
            _get_colour(state.polygon_shape_roles[0], 0),
            jnp.zeros(3),
            True,
        )
        pixels = big_quad_renderer(pixels, _world_space_to_pixel_space(jnp.zeros(2, dtype=jnp.int32)), floor_uniform)
        # Rectangles
        # All polygons (index 0 included) go through the batched quad renderer.
        # Each patch starts half a `max_shape_size` before the polygon position.
        rectangle_patch_positions = _world_space_to_pixel_space(
            state.polygon.position - (static_params.max_shape_size / 2.0)
        ).astype(jnp.int32)
        # One rotation matrix per polygon, repeated per vertex so it can be applied with a nested vmap.
        rectangle_rmats = jax.vmap(rmat)(state.polygon.rotation)
        rectangle_rmats = jnp.repeat(rectangle_rmats[:, None, :, :], repeats=static_params.max_polygon_vertices, axis=1)
        rectangle_vertices_pixel_space = _world_space_to_pixel_space(
            state.polygon.position[:, None, :] + jax.vmap(jax.vmap(jnp.matmul))(rectangle_rmats, state.polygon.vertices)
        )
        rectangle_colours = jax.vmap(_get_colour)(state.polygon_shape_roles, state.polygon.inverse_mass)
        # Black edges for every polygon.
        rectangle_edge_colours = jnp.zeros((static_params.num_polygons, 3))
        rectangle_uniforms = (
            rectangle_vertices_pixel_space,
            rectangle_colours,
            rectangle_edge_colours,
            state.polygon.active,
        )
        pixels = quad_renderer(pixels, rectangle_patch_positions, rectangle_uniforms)
        # Circles
        circle_positions_pixel_space = _world_space_to_pixel_space(state.circle.position)
        circle_radii_pixel_space = state.circle.radius * ppud
        circle_patch_positions = _world_space_to_pixel_space(
            state.circle.position - (static_params.max_shape_size / 2.0)
        ).astype(jnp.int32)
        circle_colours = jax.vmap(_get_colour)(state.circle_shape_roles, state.circle.inverse_mass)
        circle_uniforms = (
            circle_positions_pixel_space,
            circle_radii_pixel_space,
            state.circle.rotation,
            circle_colours,
            state.circle.active,
        )
        pixels = circle_renderer(pixels, circle_patch_positions, circle_uniforms)
        # Joints
        # Centre each joint patch on the joint's pixel position.
        joint_patch_positions = jnp.round(
            _world_space_to_pixel_space(state.joint.global_position) - (joint_pixel_size // 2)
        ).astype(jnp.int32)
        # Fixed joints use FJOINT_TEXTURE_6_RGBA, all other joints RJOINT_TEXTURE_6_RGBA.
        joint_textures = jax.vmap(jax.lax.select, in_axes=(0, None, None))(
            state.joint.is_fixed_joint, FJOINT_TEXTURE_6_RGBA, RJOINT_TEXTURE_6_RGBA
        )
        # Index 0 when the motor is off or the joint is fixed, otherwise `motor_binding + 1`.
        joint_colours = JOINT_COLOURS[
            (state.motor_bindings + 1) * (state.joint.motor_on & (~state.joint.is_fixed_joint))
        ]
        joint_uniforms = (joint_textures, joint_colours, state.joint.active)
        pixels = joint_renderer(pixels, joint_patch_positions, joint_uniforms)
        # Thrusters
        thruster_positions = jnp.round(_world_space_to_pixel_space(state.thruster.global_position)).astype(jnp.int32)
        thruster_patch_positions = thruster_positions - (thruster_pixel_size_diagonal // 2)
        thruster_textures = coloured_thruster_textures[state.thruster_bindings + 1]
        # Drawn rotation = the thruster's own rotation plus the rotation of the shape
        # that `select_shape` returns for `state.thruster.object_index`.
        thruster_rotations = (
            state.thruster.rotation
            + jax.vmap(select_shape, in_axes=(None, 0, None))(
                state, state.thruster.object_index, static_params
            ).rotation
        )
        thruster_uniforms = (thruster_positions, thruster_rotations, thruster_textures, state.thruster.active)
        pixels = thruster_renderer(pixels, thruster_patch_positions, thruster_uniforms)
        # Crop out the sides
        # Removes the `max_shape_size * ppud` margin that was added to the working buffer.
        crop_amount = static_params.max_shape_size * ppud
        return pixels[crop_amount:-crop_amount, crop_amount:-crop_amount]
    return render_pixels
@struct.dataclass
class PixelsObservation:
    """Pixel observation returned by the function that ``make_render_pixels_rl`` builds."""

    # Rendered frame; `make_render_pixels_rl` stores the renderer output divided by 255.
    image: jnp.ndarray
    # Extra non-image features; `make_render_pixels_rl` stores a length-1 array
    # holding `state.gravity[1] / 10.0`.
    global_info: jnp.ndarray
def make_render_pixels_rl(params, static_params: StaticEnvParams):
    """Create an observation function that wraps the pixel renderer.

    Args:
        params: Environment parameters, forwarded to ``make_render_pixels``.
        static_params: Static environment parameters, forwarded to
            ``make_render_pixels``.

    Returns:
        A function ``inner(state)`` that renders ``state`` and returns a
        ``PixelsObservation`` holding the frame divided by 255 and a
        one-element array with ``state.gravity[1] / 10.0``.
    """
    render_pixels = make_render_pixels(params, static_params)

    def inner(state):
        # The renderer writes colours on a 0-255 scale; rescale by 1/255.
        normalised_frame = render_pixels(state) / 255.0
        # Second gravity component, divided by ten.
        scaled_gravity = state.gravity[1] / 10.0
        return PixelsObservation(
            image=normalised_frame,
            global_info=jnp.array([scaled_gravity]),
        )

    return inner