File size: 3,801 Bytes
e0f25ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

from jax2d import joint
from jax2d.engine import select_shape
from jax2d.maths import rmat
from jax2d.sim_state import RigidBody
from jaxgl.maths import dist_from_line
from jaxgl.renderer import clear_screen, make_renderer
from jaxgl.shaders import (
    fragment_shader_quad,
    fragment_shader_edged_quad,
    make_fragment_shader_texture,
    nearest_neighbour,
    make_fragment_shader_quad_textured,
)
from kinetix.render.renderer_symbolic_common import (
    make_circle_features,
    make_joint_features,
    make_polygon_features,
    make_thruster_features,
)
from kinetix.environment.env_state import StaticEnvParams, EnvParams, EnvState
from flax import struct


def make_render_symbolic(params, static_params: StaticEnvParams):
    """Build a function that flattens an environment state into a 1-D symbolic observation.

    Args:
        params: Environment parameters; forwarded unchanged to the
            ``make_*_features`` helpers.
        static_params: Static environment parameters. ``num_polygons`` and
            ``num_circles`` are read here, and the object is also forwarded to
            the ``make_*_features`` helpers.

    Returns:
        A function ``render_symbolic(state)`` that returns a 1-D array made of
        the flattened polygon, circle, joint and thruster features followed by
        ``state.gravity[1] / 10``. The result is clipped to ``[-10, 10]`` and
        any NaNs are replaced by 0.

    Note:
        Polygon indices 1, 2 and 3 are always dropped, so
        ``static_params.num_polygons`` must be at least 4; otherwise the NumPy
        index assignment raises ``IndexError`` when ``render_symbolic`` runs.
    """

    def render_symbolic(state):
        n_polys = static_params.num_polygons
        nshapes = n_polys + static_params.num_circles

        polygon_features, polygon_mask = make_polygon_features(state, params, static_params)

        # Static NumPy boolean mask that drops polygons 1-3 (presumably the
        # walls and ceiling, per the variable name -- confirm against the
        # level layout).
        mask_to_ignore_walls_ceiling = np.ones(n_polys, dtype=bool)
        mask_to_ignore_walls_ceiling[np.array([1, 2, 3])] = False

        polygon_features = polygon_features[mask_to_ignore_walls_ceiling]
        polygon_mask = polygon_mask[mask_to_ignore_walls_ceiling]

        circle_features, circle_mask = make_circle_features(state, params, static_params)
        joint_features, joint_idxs, joint_mask = make_joint_features(state, params, static_params)
        thruster_features, thruster_idxs, thruster_mask = make_thruster_features(state, params, static_params)

        # Only the first half of the joint rows is used ("for symbolic only
        # have the one").
        # NOTE(review): presumably make_joint_features emits every joint
        # twice -- confirm.
        two_J = joint_features.shape[0]
        J = two_J // 2

        # NOTE: the one-hot width is `nshapes`, i.e. it still counts the
        # polygons that were filtered out of `polygon_features` above.
        joint_features = jnp.concatenate(
            [
                joint_features[:J],  # J rows of raw joint features
                jax.nn.one_hot(joint_idxs[:J, 0], nshapes),  # shape (J, nshapes)
                jax.nn.one_hot(joint_idxs[:J, 1], nshapes),  # shape (J, nshapes)
            ],
            axis=1,
        )
        thruster_features = jnp.concatenate(
            [
                thruster_features,
                jax.nn.one_hot(thruster_idxs, nshapes),
            ],
            axis=1,
        )

        # Zero the rows whose mask entry is False, then flatten each group.
        polygon_features = jnp.where(polygon_mask[:, None], polygon_features, 0.0).flatten()
        circle_features = jnp.where(circle_mask[:, None], circle_features, 0.0).flatten()
        joint_features = jnp.where(joint_mask[:J, None], joint_features, 0.0).flatten()
        thruster_features = jnp.where(thruster_mask[:, None], thruster_features, 0.0).flatten()

        # Currently unused: its only call sites are the commented-out entries
        # in the observation list below. Kept so they can be re-enabled.
        def _get_manifold_features(manifold):
            collision_mask_features = jnp.concatenate(
                [
                    manifold.normal,
                    jnp.expand_dims(manifold.penetration, axis=-1),
                    manifold.collision_point,
                    jnp.expand_dims(manifold.acc_impulse_normal, axis=-1),
                    jnp.expand_dims(manifold.acc_impulse_tangent, axis=-1),
                ],
                axis=-1,
            )

            # Multiplying by `active` zeroes the features of inactive manifolds.
            return (collision_mask_features * manifold.active[..., None]).flatten()

        obs = jnp.concatenate(
            [
                polygon_features,
                circle_features,
                joint_features,
                thruster_features,
                jnp.array([state.gravity[1]]) / 10,
                # _get_manifold_features(state.acc_cc_manifolds),
                # _get_manifold_features(state.acc_cr_manifolds),
                # _get_manifold_features(state.acc_rr_manifolds),
            ],
            axis=0,
        )

        # Bounds are passed positionally: the `a_min`/`a_max` keywords are
        # deprecated in newer JAX releases (renamed to `min`/`max`), while the
        # positional form is accepted by both the old and new signatures.
        obs = jnp.clip(obs, -10.0, 10.0)
        # NaNs pass through clip unchanged; replace them with 0 afterwards.
        obs = jnp.nan_to_num(obs)
        return obs

    return render_symbolic