import jax
from jax2d.sim_state import RigidBody
import jax.numpy as jnp
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams


def _get_base_shape_features(
    density: jnp.ndarray, roles: jnp.ndarray, shapes: RigidBody, env_params: EnvParams
) -> jnp.ndarray:
    """Assemble the feature columns shared by polygons and circles.

    Column order: position, velocity, inverse mass, inverse inertia, density,
    tanh(angular_velocity / 10), one-hot role, sin(rotation), cos(rotation),
    friction, restitution.

    Args:
        density: Per-shape density values; becomes a single column.
        roles: Per-shape role ids, one-hot encoded with
            ``env_params.num_shape_roles`` classes.
        shapes: Rigid bodies whose attributes supply the remaining columns.
        env_params: Only ``num_shape_roles`` is read.

    Returns:
        All columns concatenated along axis 1.
    """

    # NOTE(review): assumes every scalar attribute is 1-D with one entry per
    # shape, so that adding a trailing axis yields a single column -- confirm.
    def as_column(values):
        return values[:, None]

    angle = shapes.rotation
    # tanh keeps the scaled angular velocity bounded to [-1, 1].
    squashed_spin = jnp.tanh(shapes.angular_velocity / 10)
    role_one_hot = jax.nn.one_hot(roles, env_params.num_shape_roles)

    columns = (
        shapes.position,
        shapes.velocity,
        as_column(shapes.inverse_mass),
        as_column(shapes.inverse_inertia),
        as_column(density),
        as_column(squashed_spin),
        role_one_hot,
        as_column(jnp.sin(angle)),
        as_column(jnp.cos(angle)),
        as_column(shapes.friction),
        as_column(shapes.restitution),
    )
    return jnp.concatenate(columns, axis=1)


def add_circle_features(
    base_features: jnp.ndarray, shapes: RigidBody, env_params: EnvParams, static_env_params: StaticEnvParams
):
    """Append the circle-specific columns to ``base_features``.

    Two columns are added on the right: ``shapes.radius`` followed by a
    column of ones that marks the row as a circle.

    Args:
        base_features: Existing per-shape feature matrix (one row per shape).
        shapes: Rigid bodies; only ``radius`` is read.
        env_params: Unused here.
        static_env_params: Unused here.

    Returns:
        ``base_features`` with the two extra columns concatenated on axis 1.
    """
    radius_column = jnp.expand_dims(shapes.radius, axis=1)
    # Built from a slice of base_features so the flag has matching rows/dtype.
    circle_flag = jnp.ones_like(base_features[:, :1])
    return jnp.concatenate((base_features, radius_column, circle_flag), axis=1)


def make_circle_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Build the node features for the circles in ``state``.

    Returns:
        A pair ``(node_features, active)``: the shared shape features with the
        circle-specific columns appended, and ``state.circle.active``.
    """
    circles = state.circle
    shared = _get_base_shape_features(state.circle_densities, state.circle_shape_roles, circles, env_params)
    return add_circle_features(shared, circles, env_params, static_env_params), circles.active


def add_polygon_features(
    base_features: jnp.ndarray, shapes: RigidBody, env_params: EnvParams, static_env_params: StaticEnvParams
):
    """Append the polygon-specific columns to ``base_features``.

    Added columns, in order: a column of zeros marking the row as a polygon,
    the flattened vertex coordinates (slots at or beyond ``n_vertices`` are
    replaced by -1), and a flag that is true when ``n_vertices <= 3``.

    Args:
        base_features: Existing per-shape feature matrix (one row per shape).
        shapes: Rigid bodies; ``vertices`` and ``n_vertices`` are read.
        env_params: Unused here.
        static_env_params: Supplies ``max_polygon_vertices``.

    Returns:
        ``base_features`` with the extra columns concatenated on axis 1.
    """
    slot_ids = jnp.arange(static_env_params.max_polygon_vertices)
    # Broadcast (1, slots, 1) against (shapes, 1, 1): true where the slot is used.
    slot_in_use = slot_ids[None, :, None] < shapes.n_vertices[:, None, None]
    filler = jnp.zeros_like(shapes.vertices) - 1
    masked_vertices = jnp.where(slot_in_use, shapes.vertices, filler)

    num_shapes = masked_vertices.shape[0]
    polygon_flag = jnp.zeros_like(base_features[:, :1])
    flat_vertices = masked_vertices.reshape((num_shapes, -1))
    at_most_three = (shapes.n_vertices <= 3)[:, None]

    return jnp.concatenate((base_features, polygon_flag, flat_vertices, at_most_three), axis=1)


def make_polygon_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Build the node features for the polygons in ``state``.

    Returns:
        A pair ``(node_features, active)``: the shared shape features with the
        polygon-specific columns appended, and ``state.polygon.active``.
    """
    polygons = state.polygon
    shared = _get_base_shape_features(state.polygon_densities, state.polygon_shape_roles, polygons, env_params)
    return add_polygon_features(shared, polygons, env_params, static_env_params), polygons.active


def make_unified_shape_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Build one feature matrix covering polygons and circles with a common layout.

    Every shape, regardless of type, receives the shared columns, then the
    polygon-specific columns, then the circle-specific columns, so polygon and
    circle rows have the same width. Polygon rows come first, circle rows after.

    Returns:
        A pair ``(node_features, active)`` where both are the polygon entries
        followed by the circle entries along axis 0.
    """

    # NOTE(review): because both add_* helpers run on both shape types, the
    # polygon flag column is 0 and the circle flag column is 1 on every row,
    # so those two columns do not distinguish polygons from circles here --
    # confirm this is intended.
    def _all_columns(densities, roles, bodies):
        # Shared columns -> polygon columns -> circle columns, same order for both types.
        feats = _get_base_shape_features(densities, roles, bodies, env_params)
        feats = add_polygon_features(feats, bodies, env_params, static_env_params)
        return add_circle_features(feats, bodies, env_params, static_env_params)

    polygon_rows = _all_columns(state.polygon_densities, state.polygon_shape_roles, state.polygon)
    circle_rows = _all_columns(state.circle_densities, state.circle_shape_roles, state.circle)

    node_features = jnp.concatenate([polygon_rows, circle_rows], axis=0)
    active = jnp.concatenate([state.polygon.active, state.circle.active], axis=0)
    return node_features, active


def make_joint_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Build directed edge features for the joints in ``state``.

    Every joint produces two rows, one per direction. With J joints, the first
    J rows take body b's anchor as the source and body a's as the target; the
    next J rows are the reverse. The index pairs follow the same ordering:
    ``[b_index, a_index]`` for the first J rows, ``[a_index, b_index]`` after.

    Args:
        state: Supplies ``joint`` and ``motor_bindings``.
        env_params: Unused here.
        static_env_params: Supplies ``num_motor_bindings`` for the one-hot.

    Returns:
        ``(joint_features, indexes, mask)`` of shape
        ``(2 * J, K)``, ``(2 * J, 2)``, ``(2 * J,)``.
    """
    joints = state.joint
    num_joints = joints.active.shape[0]
    limited = joints.motor_has_joint_limits

    def column(values):
        return values[:, None]

    def one_direction(a_is_source):
        # The flag only swaps which anchor is listed first; everything else is shared.
        source_anchor = jax.lax.select(a_is_source, joints.a_relative_pos, joints.b_relative_pos)
        target_anchor = jax.lax.select(a_is_source, joints.b_relative_pos, joints.a_relative_pos)

        general = jnp.concatenate(
            [
                column(joints.active * 1.0),
                column(joints.is_fixed_joint * 1.0),
                source_anchor,
                target_anchor,
                column(jnp.sin(joints.rotation)),
                column(jnp.cos(joints.rotation)),
            ],
            axis=1,
        )

        # Limit-related entries are multiplied by the has-limits flag, so they
        # are zero for joints without limits.
        motor = jnp.concatenate(
            [
                column(joints.motor_speed),
                column(joints.motor_power),
                column(joints.motor_on * 1.0),
                column(limited * 1.0),
                jax.nn.one_hot(state.motor_bindings, num_classes=static_env_params.num_motor_bindings),
                column(jnp.sin(joints.min_rotation) * limited),
                column(jnp.cos(joints.min_rotation) * limited),
                column(jnp.sin(joints.max_rotation) * limited),
                column(jnp.cos(joints.max_rotation) * limited),
                column((joints.min_rotation - joints.rotation) * limited),
                column((joints.max_rotation - joints.rotation) * limited),
            ],
            axis=1,
        )

        # The whole motor section is zeroed for fixed joints.
        not_fixed = 1.0 - (joints.is_fixed_joint * 1.0)
        return jnp.concatenate([general, motor * column(not_fixed)], axis=1)

    # (2, J, K): direction False first, then True.
    both_directions = jax.vmap(one_direction)(jnp.array([False, True]))
    joint_features = both_directions.reshape((2 * num_joints, -1))

    # (J, 2) index pairs, ordered to match the two direction blocks above.
    b_then_a = jnp.stack([joints.b_index, joints.a_index], axis=1)
    a_then_b = jnp.stack([joints.a_index, joints.b_index], axis=1)
    pairs = jnp.concatenate([b_then_a, a_then_b], axis=0)

    mask = jnp.concatenate([joints.active, joints.active], axis=0)
    # Inactive joints get their index pair replaced by zeros.
    indexes = jnp.where(mask[:, None], pairs, jnp.zeros_like(pairs))

    return joint_features, indexes, mask


def make_thruster_features(
    state: EnvState, env_params: EnvParams, static_env_params: StaticEnvParams
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Build the per-thruster feature matrix.

    Column order: active flag, relative position, one-hot thruster binding,
    sin(rotation), cos(rotation), power.

    Args:
        state: Supplies ``thruster`` and ``thruster_bindings``.
        env_params: Unused here.
        static_env_params: Supplies ``num_thruster_bindings`` for the one-hot.

    Returns:
        ``(thruster_features, indexes, mask)`` of shape ``(T, K)``, ``(T,)``,
        ``(T,)``, where ``indexes`` is ``state.thruster.object_index`` and
        ``mask`` is ``state.thruster.active``.
    """
    thrusters = state.thruster
    binding_one_hot = jax.nn.one_hot(
        state.thruster_bindings, num_classes=static_env_params.num_thruster_bindings
    )
    angle = thrusters.rotation

    thruster_features = jnp.concatenate(
        (
            jnp.expand_dims(thrusters.active * 1.0, axis=1),
            thrusters.relative_position,
            binding_one_hot,
            jnp.expand_dims(jnp.sin(angle), axis=1),
            jnp.expand_dims(jnp.cos(angle), axis=1),
            jnp.expand_dims(thrusters.power, axis=1),
        ),
        axis=1,
    )
    return thruster_features, thrusters.object_index, thrusters.active