File size: 4,938 Bytes
e0f25ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from cmath import rect
from functools import partial

import jax
import jax.numpy as jnp
from flax import struct
from jax2d.engine import get_pairwise_interaction_indices
from kinetix.environment.env_state import EnvState
from kinetix.render.renderer_symbolic_common import (
    make_circle_features,
    make_joint_features,
    make_polygon_features,
    make_thruster_features,
    make_unified_shape_features,
)


@struct.dataclass
class EntityObservation:
    """Entity-based observation produced by ``make_render_entities``.

    Bundles per-entity feature arrays, their validity masks, the pairwise
    attention masks over shapes, and the index arrays returned alongside the
    joint and thruster features.
    """

    # Per-entity feature arrays. ``circles`` and ``polygons`` are the outputs of
    # ``make_circle_features`` / ``make_polygon_features`` with one extra
    # trailing column holding ``state.gravity[1] / 10``.
    circles: jnp.ndarray
    polygons: jnp.ndarray
    joints: jnp.ndarray
    thrusters: jnp.ndarray

    # Validity masks returned by the corresponding ``make_*_features`` helper
    # (presumably one boolean per entity slot -- confirm against the helpers).
    circle_mask: jnp.ndarray
    polygon_mask: jnp.ndarray
    joint_mask: jnp.ndarray
    thruster_mask: jnp.ndarray
    # Four (N, N) masks stacked on axis 0, where N is the length of
    # ``polygon_mask`` concatenated with ``circle_mask``. Layer order: all valid
    # shape pairs, negated ``state.collision_matrix``, joint-connected pairs,
    # pairs with an active collision manifold.
    attention_mask: jnp.ndarray
    # collision_mask: jnp.ndarray

    # ``joint_indexes[:, 0]`` / ``joint_indexes[:, 1]`` are used as (row, column)
    # positions in the joint layer of ``attention_mask``.
    joint_indexes: jnp.ndarray
    # Returned by ``make_thruster_features``; presumably the shape each thruster
    # is attached to -- TODO confirm.
    thruster_indexes: jnp.ndarray


def make_render_entities(params, static_params):
    """Create the entity-based observation function for a fixed environment layout.

    The pairwise index tables returned by ``get_pairwise_interaction_indices``
    are fetched once here. Circle indices in those tables are offset by
    ``static_params.num_polygons`` so that they address the combined
    ``[polygons, circles]`` ordering used for the attention masks.

    Args:
        params: Forwarded unchanged to the ``make_*_features`` helpers.
        static_params: Supplies ``num_polygons`` and ``num_circles``; also
            forwarded to the feature helpers and the pair-index lookup.

    Returns:
        A function ``render_entities(state)`` mapping an ``EnvState`` to an
        ``EntityObservation``.
    """
    polygon_count = static_params.num_polygons
    _, _, _, cc_pairs, cr_pairs, rr_pairs = get_pairwise_interaction_indices(static_params)
    # Only column 0 of the circle-rect table is shifted (presumably the circle
    # index -- confirm against jax2d); both columns of circle-circle are shifted.
    cr_pairs = cr_pairs.at[:, 0].add(polygon_count)
    cc_pairs = cc_pairs + polygon_count

    def render_entities(state: EnvState):
        """Convert ``state`` into an ``EntityObservation``."""
        # Replace NaNs in every leaf of the state before extracting features.
        state = jax.tree_util.tree_map(jnp.nan_to_num, state)

        joint_features, joint_indexes, joint_mask = make_joint_features(state, params, static_params)
        thruster_features, thruster_indexes, thruster_mask = make_thruster_features(state, params, static_params)
        poly_nodes, poly_mask = make_polygon_features(state, params, static_params)
        circle_nodes, circle_mask = make_circle_features(state, params, static_params)

        # Second gravity component, scaled by 1/10, appended to every shape row.
        gravity_feature = state.gravity[1] / 10

        def _with_gravity(nodes):
            gravity_column = jnp.zeros((nodes.shape[0], 1)) + gravity_feature
            return jnp.concatenate([nodes, gravity_column], axis=-1)

        poly_nodes = _with_gravity(poly_nodes)
        circle_nodes = _with_gravity(circle_nodes)

        # One validity flag per shape: polygons first, then circles.
        shape_valid = jnp.concatenate([poly_mask, circle_mask], axis=0)
        num_shapes = static_params.num_polygons + static_params.num_circles

        def _pairwise_masks(valid):
            """Stack the four (N, N) attention masks for ``N = valid.shape[0]``."""
            n = valid.shape[0]
            # An entry survives only if both its row and its column are valid.
            both_valid = valid[:, None] & valid[None, :]

            # Layer 0: identity (self-attention) with the shape block fully connected.
            all_pairs = jnp.eye(n, n, dtype=bool)
            all_pairs = all_pairs.at[:num_shapes, :num_shapes].set(jnp.ones((num_shapes, num_shapes), dtype=bool))
            all_pairs = all_pairs & both_valid

            # Layer 1: pairs that are NOT flagged in the collision matrix.
            non_colliding = jnp.logical_not(state.collision_matrix) & both_valid

            # Layer 2: pairs listed in joint_indexes.
            # NOTE(review): only [col0, col1] is written, so this layer is
            # symmetric only if joint_indexes lists both orders -- verify.
            joint_linked = (
                jnp.zeros((n, n), dtype=bool)
                .at[joint_indexes[:, 0], joint_indexes[:, 1]]
                .set(True)
                .at[0, 0]
                .set(False)  # invalid joints have indices of (0, 0)
            )
            joint_linked = joint_linked & both_valid

            # Layer 3: pairs with an active accumulated collision manifold. The
            # zero template takes its shape and dtype from the masked collision matrix.
            in_contact = jnp.zeros_like(state.collision_matrix & both_valid)
            # Rect-rect manifolds carry two activity flags per pair; either one counts.
            rr_active = jnp.logical_or(
                state.acc_rr_manifolds.active[..., 0],
                state.acc_rr_manifolds.active[..., 1],
            )
            # NOTE(review): as with joints, only the [col0, col1] direction is set.
            for pairs, active in (
                (rr_pairs, rr_active),
                (cr_pairs, state.acc_cr_manifolds.active),
                (cc_pairs, state.acc_cc_manifolds.active),
            ):
                in_contact = in_contact.at[pairs[:, 0], pairs[:, 1]].set(active)
            in_contact = in_contact & both_valid

            return jnp.stack([all_pairs, non_colliding, joint_linked, in_contact], axis=0)

        return EntityObservation(
            circles=circle_nodes,
            polygons=poly_nodes,
            joints=joint_features,
            thrusters=thruster_features,
            circle_mask=circle_mask,
            polygon_mask=poly_mask,
            joint_mask=joint_mask,
            thruster_mask=thruster_mask,
            attention_mask=_pairwise_masks(shape_valid),
            joint_indexes=joint_indexes,
            thruster_indexes=thruster_indexes,
        )

    return render_entities