from cmath import rect
from functools import partial

import jax
import jax.numpy as jnp
from flax import struct
from jax2d.engine import get_pairwise_interaction_indices
from kinetix.environment.env_state import EnvState
from kinetix.render.renderer_symbolic_common import (
    make_circle_features,
    make_joint_features,
    make_polygon_features,
    make_thruster_features,
    make_unified_shape_features,
)


@struct.dataclass
class EntityObservation:
    """Entity-based observation produced by ``make_render_entities``.

    Each entity type (circles, polygons, joints, thrusters) has a feature array
    plus a matching validity mask. ``attention_mask`` holds stacked pairwise
    masks over shapes. ``joint_indexes`` / ``thruster_indexes`` are passed
    through from the symbolic feature helpers.

    NOTE: field order matters. It defines the positional constructor order and
    the pytree leaf order, so do not reorder fields.
    """

    # Per-entity feature arrays. make_render_entities appends one extra gravity
    # column to the circle and polygon features.
    circles: jnp.ndarray
    polygons: jnp.ndarray
    joints: jnp.ndarray
    thrusters: jnp.ndarray

    # Validity masks for the feature arrays above (presumably one bool per row; TODO confirm).
    circle_mask: jnp.ndarray
    polygon_mask: jnp.ndarray
    joint_mask: jnp.ndarray
    thruster_mask: jnp.ndarray
    # Stacked pairwise shape masks over the combined shape index space (polygons
    # first, then circles). make_render_entities builds four planes here.
    attention_mask: jnp.ndarray
    # collision_mask: jnp.ndarray  (disabled; not part of the pytree)

    # Index pairs for joints; column 0 and 1 are used as row/column indices into
    # the combined shape index space when building attention_mask.
    joint_indexes: jnp.ndarray
    # Thruster index data from make_thruster_features (layout not visible here; TODO confirm).
    thruster_indexes: jnp.ndarray


def make_render_entities(params, static_params):
    """Build a function that renders an ``EnvState`` into an ``EntityObservation``.

    The shape-pair index arrays from ``get_pairwise_interaction_indices`` depend
    only on ``static_params``, so they are computed once here. They are then
    re-indexed into a single combined shape index space, in which all polygons
    come first, followed by all circles.

    Args:
        params: Environment parameters, forwarded to the symbolic feature builders.
        static_params: Static environment parameters. ``num_polygons`` and
            ``num_circles`` determine the size of the combined shape index space.

    Returns:
        A function ``render_entities(state) -> EntityObservation``.
    """
    _, _, _, circle_circle_pairs, circle_rect_pairs, rect_rect_pairs = get_pairwise_interaction_indices(static_params)
    # Circles sit after the polygons in the combined index space, so every circle
    # index is shifted by num_polygons. Only column 0 of circle_rect_pairs is
    # shifted; it is treated as the circle index, and column 1 as the polygon index.
    circle_rect_pairs = circle_rect_pairs.at[:, 0].add(static_params.num_polygons)
    circle_circle_pairs = circle_circle_pairs + static_params.num_polygons
    # Static Python int: number of shapes (polygons + circles).
    num_shapes = static_params.num_polygons + static_params.num_circles

    def render_entities(state: EnvState):
        """Render ``state`` into entity features, validity masks and pairwise attention masks."""
        # Replace NaN/inf in every leaf of the state so downstream features stay finite.
        state = jax.tree_util.tree_map(jnp.nan_to_num, state)

        joint_features, joint_indexes, joint_mask = make_joint_features(state, params, static_params)
        thruster_features, thruster_indexes, thruster_mask = make_thruster_features(state, params, static_params)

        poly_nodes, poly_mask = make_polygon_features(state, params, static_params)
        circle_nodes, circle_mask = make_circle_features(state, params, static_params)

        def _add_grav(nodes):
            # Append the vertical gravity component (scaled by 1/10) as one extra
            # feature column on every shape's embedding.
            grav_column = jnp.zeros((nodes.shape[0], 1)) + state.gravity[1] / 10
            return jnp.concatenate([nodes, grav_column], axis=-1)

        poly_nodes = _add_grav(poly_nodes)
        circle_nodes = _add_grav(circle_nodes)

        # Validity of every shape in the combined index space: polygons, then circles.
        # Length is num_polygons + num_circles; joints and thrusters are NOT included.
        mask_flat_shapes = jnp.concatenate([poly_mask, circle_mask], axis=0)

        def make_n_squared_mask(val):
            """Build the stacked pairwise attention masks over shapes.

            Args:
                val: Boolean validity mask with one entry per shape (polygons, then circles).

            Returns:
                Array of shape (4, n, n), with planes in this order:
                  0. full connectivity among the shapes, plus self-attention;
                  1. ``logical_not(state.collision_matrix)`` ("multi-hop" connectivity);
                  2. direct joint (one-hop) connectivity taken from ``joint_indexes``;
                  3. pairs with an active accumulated collision manifold.
                In every plane, the rows and columns of invalid shapes are False.
            """
            n = val.shape[0]
            # A pair is valid only if both of its shapes are valid.
            pair_valid = val[:, None] & val[None, :]

            # Plane 0: everything attends to itself; the shapes are fully connected.
            full_mask = jnp.eye(n, n, dtype=bool).at[:num_shapes, :num_shapes].set(True)

            # Plane 1: pairs the physics engine excludes from collision.
            # NOTE(review): presumably these are shapes linked (possibly via chains of
            # joints) -- confirm against jax2d's collision_matrix semantics.
            multi_hop_connected = jnp.logical_not(state.collision_matrix)

            # Plane 2: shapes directly connected by a joint.
            one_hop_connected = jnp.zeros((n, n), dtype=bool)
            one_hop_connected = one_hop_connected.at[joint_indexes[:, 0], joint_indexes[:, 1]].set(True)
            one_hop_connected = one_hop_connected.at[0, 0].set(False)  # invalid joints have indices of (0, 0)

            # Plane 3: shape pairs whose accumulated collision manifold is active.
            # A rect-rect manifold has two contact points; it counts if either is active.
            rr_active = jnp.logical_or(state.acc_rr_manifolds.active[..., 0], state.acc_rr_manifolds.active[..., 1])
            collision_manifold_mask = jnp.zeros((n, n), dtype=bool)
            # NOTE(review): only the (pairs[:, 0], pairs[:, 1]) entries are written,
            # never their transposes, so this plane is not symmetric. Confirm this
            # is intended before changing it; doing so would alter model inputs.
            for pairs, active in (
                (rect_rect_pairs, rr_active),
                (circle_rect_pairs, state.acc_cr_manifolds.active),
                (circle_circle_pairs, state.acc_cc_manifolds.active),
            ):
                collision_manifold_mask = collision_manifold_mask.at[pairs[:, 0], pairs[:, 1]].set(active)

            return jnp.stack(
                [
                    full_mask & pair_valid,
                    multi_hop_connected & pair_valid,
                    one_hop_connected & pair_valid,
                    collision_manifold_mask & pair_valid,
                ],
                axis=0,
            )

        mask_n_squared = make_n_squared_mask(mask_flat_shapes)

        return EntityObservation(
            circles=circle_nodes,
            polygons=poly_nodes,
            joints=joint_features,
            thrusters=thruster_features,
            circle_mask=circle_mask,
            polygon_mask=poly_mask,
            joint_mask=joint_mask,
            thruster_mask=thruster_mask,
            attention_mask=mask_n_squared,
            joint_indexes=joint_indexes,
            thruster_indexes=thruster_indexes,
        )

    return render_entities