Spaces:
Runtime error
Runtime error
File size: 10,132 Bytes
e0f25ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 |
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax2d import joint
from jax2d.engine import select_shape
from jax2d.maths import rmat
from jax2d.sim_state import RigidBody
from jaxgl.maths import dist_from_line
from jaxgl.renderer import clear_screen, make_renderer
from jaxgl.shaders import (
fragment_shader_quad,
fragment_shader_edged_quad,
make_fragment_shader_texture,
nearest_neighbour,
make_fragment_shader_quad_textured,
)
from kinetix.render.textures import (
THRUSTER_TEXTURE_16_RGBA,
RJOINT_TEXTURE_6_RGBA,
FJOINT_TEXTURE_6_RGBA,
)
from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState
from flax import struct
def make_render_pixels(
    params,
    static_params: StaticEnvParams,
):
    """Build a jitted renderer that rasterises an ``EnvState`` into pixels.

    Everything is drawn onto an oversized canvas that is padded by
    ``static_params.max_shape_size * ppud`` pixels on every side. Fixed-size
    shape patches near the screen border therefore stay inside the canvas.
    The padding is cropped off before the image is returned.

    Args:
        params: Environment parameters. Only ``params.pixels_per_unit`` is read.
        static_params: Static parameters. Reads ``screen_dim``, ``downscale``,
            ``max_shape_size``, ``max_polygon_vertices`` and ``num_polygons``.

    Returns:
        A ``jax.jit``-compiled function ``render_pixels(state)``. Its output's
        first two dimensions are
        ``(screen_dim[0] // downscale, screen_dim[1] // downscale)``.
        Colours are on a 0-255 scale. The trailing channel dimension is
        whatever ``jaxgl.renderer.clear_screen`` produces (presumably RGB;
        confirm).
    """
    screen_dim = static_params.screen_dim
    downscale = static_params.downscale

    # Native (pre-downscale) side lengths of the square joint / thruster sprites.
    joint_tex_size = 6
    thruster_tex_size = 16

    FIXATED_COLOUR = jnp.array([80, 80, 80])
    # Row 0 is used for joints without an active motor. Motor joints and
    # thrusters index this table with ``binding + 1``.
    JOINT_COLOURS = jnp.array(
        [
            [255, 255, 255],  # white
            [255, 255, 0],  # yellow
            [255, 0, 255],  # purple/magenta
            [0, 255, 255],  # cyan
            [255, 153, 51],  # orange
        ]
    )

    def colour_thruster_texture(colour):
        """Tint the RGB channels of the first 9 sprite rows by ``colour / 255``."""
        return THRUSTER_TEXTURE_16_RGBA.at[:9, :, :3].mul(colour[None, None, :] / 255.0)

    # One pre-tinted thruster sprite per row of JOINT_COLOURS.
    coloured_thruster_textures = jax.vmap(colour_thruster_texture)(JOINT_COLOURS)

    ROLE_COLOURS = jnp.array(
        [
            [160.0, 160.0, 160.0],  # None
            [0.0, 204.0, 0.0],  # Green: The ball
            [0.0, 102.0, 204.0],  # Blue: The goal
            [255.0, 102.0, 102.0],  # Red: Death Objects
        ]
    )

    BACKGROUND_COLOUR = jnp.array([255.0, 255.0, 255.0])

    def _get_colour(shape_role, inverse_mass):
        """Return a shape's fill colour from its role and inverse mass.

        A zero inverse mass marks the shape as fixed. The lookup is:

        * role 0, dynamic  -> ROLE_COLOURS[0]
        * role >0, dynamic -> ROLE_COLOURS[role]
        * role 0, fixed    -> FIXATED_COLOUR
        * role >0, fixed   -> ROLE_COLOURS[role] * 0.5 (darkened)
        """
        base_colour = ROLE_COLOURS[shape_role]
        is_fixed = (inverse_mass == 0) * 1
        has_role = (shape_role != 0) * 1
        return jnp.array(
            [
                base_colour,
                base_colour,
                FIXATED_COLOUR,
                base_colour * 0.5,
            ]
        )[2 * is_fixed + has_role]

    # Pixels per unit distance
    ppud = params.pixels_per_unit // downscale
    downscaled_screen_dim = (screen_dim[0] // downscale, screen_dim[1] // downscale)
    # Canvas = visible screen + max_shape_size world units of padding per side.
    full_screen_size = (
        downscaled_screen_dim[0] + (static_params.max_shape_size * 2 * ppud),
        downscaled_screen_dim[1] + (static_params.max_shape_size * 2 * ppud),
    )

    cleared_screen = clear_screen(full_screen_size, BACKGROUND_COLOUR)

    def _world_space_to_pixel_space(x):
        """Map world coordinates to padded-canvas pixel coordinates."""
        return (x + static_params.max_shape_size) * ppud

    def fragment_shader_kinetix_circle(position, current_frag, unit_position, uniform):
        """Shade a circle: fill colour, black outline and a black rotation marker."""
        centre, radius, rotation, colour, mask = uniform

        dist = jnp.sqrt(jnp.square(position - centre).sum())
        inside = dist <= radius
        # Outline: the outermost ~2 pixels of the disc.
        on_edge = dist > radius - 2

        # Rotation marker: fragments within 1 pixel of the line through the
        # centre along `normal`, restricted to the half where
        # dot(normal, position - centre) <= 0.
        # TODO - precompute?
        normal = jnp.array([jnp.sin(rotation), -jnp.cos(rotation)])
        dist_to_line = dist_from_line(position, centre, centre + normal)
        on_edge |= (dist_to_line < 1) & (jnp.dot(normal, position - centre) <= 0)

        fragment = jax.lax.select(on_edge, jnp.zeros(3), colour)
        return jax.lax.select(inside & mask, fragment, current_frag)

    def fragment_shader_kinetix_joint(position, current_frag, unit_position, uniform):
        """Alpha-composite a tinted joint sprite over the current fragment."""
        texture, colour, mask = uniform
        # The -0.5 offset centres texel sampling.
        tex_coord = (
            jnp.array(
                [
                    joint_tex_size * unit_position[0],
                    joint_tex_size * unit_position[1],
                ]
            )
            - 0.5
        )

        tex_frag = nearest_neighbour(texture, tex_coord)
        # Inactive joints get zero alpha and so leave the fragment unchanged.
        tex_frag = tex_frag.at[3].mul(mask)
        tex_frag = tex_frag.at[:3].mul(colour / 255.0)
        tex_frag = (tex_frag[3] * tex_frag[:3]) + ((1.0 - tex_frag[3]) * current_frag)
        return tex_frag

    thruster_pixel_size = thruster_tex_size // downscale
    # A rotated sprite can span up to sqrt(2) times its side length, so the
    # render patch is sized to the diagonal (+1 to cover the truncation).
    thruster_pixel_size_diagonal = (thruster_pixel_size * np.sqrt(2)).astype(jnp.int32) + 1

    def fragment_shader_kinetix_thruster(fragment_position, current_frag, unit_position, uniform):
        """Alpha-composite a rotated thruster sprite over the current fragment."""
        thruster_position, rotation, texture, mask = uniform
        # Undo the thruster's rotation and normalise into [0, 1]^2 sprite
        # space, with the thruster position at the sprite centre (0.5, 0.5).
        tex_position = jnp.matmul(rmat(-rotation), (fragment_position - thruster_position)) / thruster_pixel_size + 0.5

        # Fragments in the diagonal-sized patch but outside the rotated sprite.
        mask &= (tex_position[0] >= 0) & (tex_position[0] <= 1) & (tex_position[1] >= 0) & (tex_position[1] <= 1)

        # Small nudge, presumably to keep coordinates off exact texel
        # boundaries during nearest-neighbour rounding.
        eps = 0.001
        tex_coord = (
            jnp.array(
                [
                    thruster_tex_size * tex_position[0],
                    thruster_tex_size * tex_position[1],
                ]
            )
            - 0.5
            + eps
        )

        tex_frag = nearest_neighbour(texture, tex_coord)
        tex_frag = tex_frag.at[3].mul(mask)
        tex_frag = (tex_frag[3] * tex_frag[:3]) + ((1.0 - tex_frag[3]) * current_frag)
        return tex_frag

    # Circles and rectangles are rendered in square patches of max_shape_size world units.
    patch_size_1d = static_params.max_shape_size * ppud
    patch_size = (patch_size_1d, patch_size_1d)

    circle_renderer = make_renderer(full_screen_size, fragment_shader_kinetix_circle, patch_size, batched=True)
    quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, patch_size, batched=True)
    # Single (non-batched) patch covering the whole visible screen, used for the floor.
    big_quad_renderer = make_renderer(full_screen_size, fragment_shader_edged_quad, downscaled_screen_dim)

    joint_pixel_size = joint_tex_size // downscale
    joint_renderer = make_renderer(
        full_screen_size, fragment_shader_kinetix_joint, (joint_pixel_size, joint_pixel_size), batched=True
    )

    thruster_renderer = make_renderer(
        full_screen_size,
        fragment_shader_kinetix_thruster,
        (thruster_pixel_size_diagonal, thruster_pixel_size_diagonal),
        batched=True,
    )

    @jax.jit
    def render_pixels(state: EnvState):
        """Render ``state``: floor, rectangles, circles, joints, then thrusters."""
        pixels = cleared_screen

        # Floor: polygon 0, drawn unrotated with the screen-sized quad renderer
        # and always coloured as fixed (inverse mass 0). It is presumably too
        # large for a max_shape_size patch.
        floor_uniform = (
            _world_space_to_pixel_space(state.polygon.position[0, None, :] + state.polygon.vertices[0]),
            _get_colour(state.polygon_shape_roles[0], 0),
            jnp.zeros(3),
            True,
        )

        pixels = big_quad_renderer(pixels, _world_space_to_pixel_space(jnp.zeros(2, dtype=jnp.int32)), floor_uniform)

        # Rectangles: one patch per polygon, positioned so it is centred on
        # the polygon position.
        rectangle_patch_positions = _world_space_to_pixel_space(
            state.polygon.position - (static_params.max_shape_size / 2.0)
        ).astype(jnp.int32)

        # Rotate every vertex by its polygon's rotation matrix.
        rectangle_rmats = jax.vmap(rmat)(state.polygon.rotation)
        rectangle_rmats = jnp.repeat(rectangle_rmats[:, None, :, :], repeats=static_params.max_polygon_vertices, axis=1)
        rectangle_vertices_pixel_space = _world_space_to_pixel_space(
            state.polygon.position[:, None, :] + jax.vmap(jax.vmap(jnp.matmul))(rectangle_rmats, state.polygon.vertices)
        )
        rectangle_colours = jax.vmap(_get_colour)(state.polygon_shape_roles, state.polygon.inverse_mass)
        rectangle_edge_colours = jnp.zeros((static_params.num_polygons, 3))

        rectangle_uniforms = (
            rectangle_vertices_pixel_space,
            rectangle_colours,
            rectangle_edge_colours,
            state.polygon.active,
        )

        pixels = quad_renderer(pixels, rectangle_patch_positions, rectangle_uniforms)

        # Circles
        circle_positions_pixel_space = _world_space_to_pixel_space(state.circle.position)
        circle_radii_pixel_space = state.circle.radius * ppud
        circle_patch_positions = _world_space_to_pixel_space(
            state.circle.position - (static_params.max_shape_size / 2.0)
        ).astype(jnp.int32)

        circle_colours = jax.vmap(_get_colour)(state.circle_shape_roles, state.circle.inverse_mass)

        circle_uniforms = (
            circle_positions_pixel_space,
            circle_radii_pixel_space,
            state.circle.rotation,
            circle_colours,
            state.circle.active,
        )

        pixels = circle_renderer(pixels, circle_patch_positions, circle_uniforms)

        # Joints: fixed joints use the F sprite and all others the R sprite.
        # The colour is JOINT_COLOURS[0] unless a motor is on (and the joint is
        # not fixed), in which case it is JOINT_COLOURS[binding + 1].
        joint_patch_positions = jnp.round(
            _world_space_to_pixel_space(state.joint.global_position) - (joint_pixel_size // 2)
        ).astype(jnp.int32)

        joint_textures = jax.vmap(jax.lax.select, in_axes=(0, None, None))(
            state.joint.is_fixed_joint, FJOINT_TEXTURE_6_RGBA, RJOINT_TEXTURE_6_RGBA
        )

        joint_colours = JOINT_COLOURS[
            (state.motor_bindings + 1) * (state.joint.motor_on & (~state.joint.is_fixed_joint))
        ]

        joint_uniforms = (joint_textures, joint_colours, state.joint.active)

        pixels = joint_renderer(pixels, joint_patch_positions, joint_uniforms)

        # Thrusters: the world rotation is the thruster's own rotation plus
        # the rotation of the shape it is attached to.
        thruster_positions = jnp.round(_world_space_to_pixel_space(state.thruster.global_position)).astype(jnp.int32)
        thruster_patch_positions = thruster_positions - (thruster_pixel_size_diagonal // 2)
        thruster_textures = coloured_thruster_textures[state.thruster_bindings + 1]
        thruster_rotations = (
            state.thruster.rotation
            + jax.vmap(select_shape, in_axes=(None, 0, None))(
                state, state.thruster.object_index, static_params
            ).rotation
        )
        thruster_uniforms = (thruster_positions, thruster_rotations, thruster_textures, state.thruster.active)

        pixels = thruster_renderer(pixels, thruster_patch_positions, thruster_uniforms)

        # Crop the padding off, keeping exactly the downscaled screen.
        # An explicit end index is used rather than `-crop_amount`, because a
        # zero crop would turn `[0:-0]` into an empty slice.
        crop_amount = static_params.max_shape_size * ppud
        return pixels[
            crop_amount : crop_amount + downscaled_screen_dim[0],
            crop_amount : crop_amount + downscaled_screen_dim[1],
        ]

    return render_pixels
@struct.dataclass
class PixelsObservation:
    """Observation produced by ``make_render_pixels_rl``.

    Attributes:
        image: The rendered frame, divided by 255.0 by ``make_render_pixels_rl``.
        global_info: Extra non-visual features. ``make_render_pixels_rl`` fills
            it with the single value ``state.gravity[1] / 10.0``.
    """

    image: jnp.ndarray
    global_info: jnp.ndarray
def make_render_pixels_rl(params, static_params: StaticEnvParams):
    """Wrap ``make_render_pixels`` so it emits ``PixelsObservation`` objects.

    The returned callable takes an environment state and returns:

    * ``image``: the rendered frame divided by 255.0 (the renderer's colours
      are on a 0-255 scale).
    * ``global_info``: a length-1 array holding ``state.gravity[1] / 10.0``.
    """
    render = make_render_pixels(params, static_params)

    def observe(state):
        frame = render(state) / 255.0
        scaled_gravity_y = state.gravity[1] / 10.0
        return PixelsObservation(image=frame, global_info=jnp.array([scaled_gravity_y]))

    return observe
|