import math
from functools import partial

import jax
import jax.numpy as jnp

from jax2d.engine import PhysicsEngine, calculate_collision_matrix, recalculate_mass_and_inertia, select_shape
from jax2d.sim_state import RigidBody, Thruster
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams


def sample_dimensions(rng, static_env_params: StaticEnvParams, is_rect: bool, ued_params, max_shape_size=None):
    """Sample the size of a new shape.

    Both a rectangle size and a circle radius are drawn. ``is_rect`` selects
    which of the two is randomised and which is pinned to its minimum value.

    Args:
        rng: JAX PRNG key.
        static_env_params: only ``max_shape_size`` is read, and only when the
            ``max_shape_size`` argument is ``None``.
        is_rect: when true, the rectangle half-dimensions are random and the
            radius equals the minimum circle size (0.1). When false, each
            half-dimension starts at the minimum rectangle size (0.05), still
            subject to the optional stretch below, and the radius is random.
        ued_params: must expose ``circle_max_size_coeff``,
            ``large_rect_dim_chance`` and ``large_rect_dim_scale``.
        max_shape_size: optional override for ``static_env_params.max_shape_size``.

    Returns:
        ``(vertices, half_dimensions, radius)``:
        - ``vertices`` is a (4, 2) array of rectangle corners centred on the origin.
        - ``half_dimensions`` has shape (2,).
        - ``radius`` is a scalar.
    """
    if max_shape_size is None:
        max_shape_size = static_env_params.max_shape_size

    # Lower bounds so sampled shapes are never vanishingly small.
    min_rect_size = 0.05
    min_circle_size = 0.1
    # Before stretching, this cap keeps the rectangle's half-diagonal <= max_shape_size / 2.
    cap_rect = max_shape_size / 2.0 / jnp.sqrt(2.0)
    cap_circ = max_shape_size / 2.0 * ued_params.circle_max_size_coeff

    rng, rect_key = jax.random.split(rng)
    rect_unit = jax.lax.select(is_rect, jax.random.uniform(rect_key, shape=(2,)), jnp.zeros(2, dtype=jnp.float32))
    half_dimensions = rect_unit * (cap_rect - min_rect_size) + min_rect_size

    # With probability large_rect_dim_chance, scale one randomly chosen axis by large_rect_dim_scale.
    rng, axis_key, chance_key = jax.random.split(rng, 3)
    stretch_axis = jax.random.randint(axis_key, shape=(), minval=0, maxval=2)
    is_stretched = jax.random.uniform(chance_key) < ued_params.large_rect_dim_chance
    stretch = jax.lax.select(is_stretched, ued_params.large_rect_dim_scale, 1.0)
    half_dimensions = half_dimensions * jnp.ones(2).at[stretch_axis].set(stretch)

    # Corner order: (+,+), (+,-), (-,-), (-,+).
    corner_signs = jnp.array([[1, 1], [1, -1], [-1, -1], [-1, 1]])
    vertices = half_dimensions * corner_signs

    _, radius_key = jax.random.split(rng)
    radius_unit = jax.lax.select(is_rect, jnp.zeros((), dtype=jnp.float32), jax.random.uniform(radius_key, shape=()))
    radius = radius_unit * (cap_circ - min_circle_size) + min_circle_size

    return vertices, half_dimensions, radius


def count_roles(state: EnvState, static_env_params: StaticEnvParams, role: int, include_static_polys=True) -> int:
    """Count active polygons and circles whose shape role equals ``role``.

    If ``include_static_polys`` is false, the first
    ``static_env_params.num_static_fixated_polys`` polygon slots are ignored.

    Returns:
        The combined count. This is computed with array ``.sum()``, so it is
        an array scalar rather than a Python ``int``.
    """

    def _matching(roles, active_mask):
        return jnp.sum((roles == role) * active_mask)

    polygon_mask = state.polygon.active
    if not include_static_polys:
        polygon_mask = polygon_mask.at[: static_env_params.num_static_fixated_polys].set(False)

    polygon_count = _matching(state.polygon_shape_roles, polygon_mask)
    circle_count = _matching(state.circle_shape_roles, state.circle.active)
    return polygon_count + circle_count


def random_position_on_triangle(rng, vertices):
    """Sample a point uniformly inside the triangle formed by ``vertices[0:3]``.

    Uses the square-root trick: with u, v ~ U[0, 1),
    ``a + sqrt(u) * ((b - a) + v * (c - b))`` is area-uniform over triangle abc.
    Any rows beyond the third are ignored.
    """
    # https://www.reddit.com/r/godot/comments/mqp29g/how_do_i_get_a_random_position_inside_a_collision/
    a, b, c = vertices[0], vertices[1], vertices[2]
    _, key_u, key_v = jax.random.split(rng, 3)
    u = jax.random.uniform(key_u)
    v = jax.random.uniform(key_v)
    return a + jnp.sqrt(u) * ((b - a) + v * (c - b))


def random_position_on_rectangle(rng, vertices):
    """Sample a point uniformly inside the axis-aligned bounding box of ``vertices[0:4]``.

    For an origin-centred axis-aligned rectangle this is a uniform sample over
    the rectangle itself. Rows beyond the fourth are ignored.
    """
    corners = vertices[:4]
    _, key_x, key_y = jax.random.split(rng, 3)
    fraction = jnp.array([jax.random.uniform(key_x), jax.random.uniform(key_y)])

    low = corners.min(axis=0)
    high = corners.max(axis=0)
    return low + fraction * (high - low)


def random_position_on_polygon(rng, vertices, n_vertices, static_env_params: StaticEnvParams):
    """Sample a local point inside a polygon with at most four vertices.

    Shapes with ``n_vertices <= 3`` use the triangle sampler. Everything else
    uses the bounding-box (rectangle) sampler. Both samples are computed from
    the same key, and ``jax.lax.select`` keeps only one of them.
    """
    assert static_env_params.max_polygon_vertices <= 4, "Only supports up to 4 vertices"
    triangle_point = random_position_on_triangle(rng, vertices)
    rectangle_point = random_position_on_rectangle(rng, vertices)
    is_triangle = n_vertices <= 3
    return jax.lax.select(is_triangle, triangle_point, rectangle_point)


def random_position_on_circle(rng, radius, on_centre_chance):
    """Sample a local position inside a circle of the given radius.

    With probability ``on_centre_chance`` the centre ``(0, 0)`` is returned.
    Otherwise a polar point is drawn:
    - the angle is uniform in ``[0, 2*pi)``;
    - the distance is uniform in ``[0, radius)``.

    Because the distance itself, not its square, is uniform, points cluster
    toward the centre compared with an area-uniform sample.
    """
    centre_key, angle_key, distance_key = jax.random.split(rng, 3)

    snap_to_centre = jax.random.uniform(centre_key) < on_centre_chance

    theta = jax.random.uniform(angle_key, shape=()) * 2 * math.pi
    distance = jax.random.uniform(distance_key, shape=()) * radius
    polar_point = jnp.array([distance * jnp.cos(theta), distance * jnp.sin(theta)])

    centre = jnp.array([0.0, 0.0])
    return jax.lax.select(snap_to_centre, centre, polar_point)


def get_role(rng, state: EnvState, static_env_params: StaticEnvParams, initial_p=None) -> int:
    """Sample a shape role in ``{0, 1, 2, 3}`` for a new shape.

    A ball (role 1) and a goal (role 2) are placed first. While either is
    missing, only the missing role(s) get non-zero weight. Once a role
    exists it gets weight 0, so at most one of each is ever sampled.

    After both exist, the weights are:
    - role 0: 1;
    - role 3 (lava): 1/3 while no lava shape exists, otherwise 0.

    The weights are then multiplied element-wise by ``initial_p``, which
    defaults to all ones.

    Returns:
        The sampled role as a scalar integer array.
    """
    base_weights = jnp.array([1.0, 1.0, 1.0, 1.0]) if initial_p is None else initial_p

    needs_ball, needs_goal, needs_lava = (count_roles(state, static_env_params, role) == 0 for role in (1, 2, 3))

    # Roles other than ball/goal only become possible once both of those are present.
    others_allowed = jnp.logical_not(needs_ball) & jnp.logical_not(needs_goal)
    role_mask = jnp.array([others_allowed, needs_ball, needs_goal, others_allowed * needs_lava / 3])

    return jax.random.choice(rng, jnp.arange(4), p=base_weights * role_mask)


def is_space_for_shape(state: EnvState):
    """Return whether at least one polygon or circle slot is inactive."""
    every_slot = jnp.concatenate([state.polygon.active, state.circle.active])
    return jnp.any(jnp.logical_not(every_slot))


def is_space_for_joint(state: EnvState):
    """Return whether at least one joint slot is inactive."""
    free_joint_slots = jnp.logical_not(state.joint.active)
    return jnp.any(free_joint_slots)


def are_there_shapes_present(state: EnvState, static_env_params: StaticEnvParams):
    """Return whether any shape is active, ignoring the fixed static polygons.

    Polygon slots come first in the concatenated mask. The first
    ``num_static_fixated_polys`` entries are therefore cleared before checking.
    """
    active_mask = jnp.concatenate([state.polygon.active, state.circle.active])
    non_static_mask = active_mask.at[: static_env_params.num_static_fixated_polys].set(False)
    return jnp.any(non_static_mask)


@partial(jax.jit, static_argnums=(2, 9))
def add_rigidbody_to_state(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    position: jnp.ndarray,
    vertices: jnp.ndarray,
    n_vertices: int,
    radius: float,
    shape_role: int,
    density: float = 1,
    is_circle: bool = False,
):
    """Write a new rigid body into the first free circle or polygon slot.

    ``static_env_params`` and ``is_circle`` are static under ``jax.jit``.
    ``is_circle`` chooses whether the circle or the polygon arrays are written.

    The target slot is ``argmin`` of that array's ``active`` mask, i.e. the
    first inactive slot. If no slot is inactive, the state is returned unchanged.

    After the body, its density and its shape role are written:
    - the collision matrix is recomputed from ``state.joint``;
    - ``recalculate_mass_and_inertia`` is called with the updated densities.

    ``env_params`` is not read here.

    Returns:
        The updated (or unchanged) state.
    """
    new_body = RigidBody(
        position=position,
        velocity=jnp.array([0.0, 0.0]),
        rotation=0.0,
        angular_velocity=0.0,
        # Placeholder mass terms; presumably overwritten by recalculate_mass_and_inertia below.
        inverse_mass=1.0,
        inverse_inertia=1.0,
        radius=radius,
        vertices=vertices,
        n_vertices=n_vertices,
        active=True,
        friction=1.0,
        restitution=0.0,
        collision_mode=1,
    )

    target_bodies = state.circle if is_circle else state.polygon
    slot = jnp.argmin(target_bodies.active)

    def _write_at_slot(stacked, value):
        return stacked.at[slot].set(value)

    def _insert(state):
        if is_circle:
            updates = dict(
                circle=jax.tree.map(_write_at_slot, state.circle, new_body),
                circle_densities=state.circle_densities.at[slot].set(density),
                circle_shape_roles=state.circle_shape_roles.at[slot].set(shape_role),
            )
        else:
            updates = dict(
                polygon=jax.tree.map(_write_at_slot, state.polygon, new_body),
                polygon_densities=state.polygon_densities.at[slot].set(density),
                polygon_shape_roles=state.polygon_shape_roles.at[slot].set(shape_role),
            )
        state = state.replace(**updates)
        state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint))
        return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities)

    def _keep(state):
        return state

    has_free_slot = jnp.logical_not(target_bodies.active).sum() > 0
    return jax.lax.cond(has_free_slot, _insert, _keep, state)


def rectangle_vertices(half_dim):
    """Return the four corners of an origin-centred axis-aligned rectangle.

    ``half_dim`` holds the half-width and half-height. Corners are ordered
    (+,+), (+,-), (-,-), (-,+), which is clockwise when y points up.
    """
    corner_signs = ((1, 1), (1, -1), (-1, -1), (-1, 1))
    return jnp.stack([half_dim * jnp.array(sign) for sign in corner_signs])


# More Manual Control
@partial(jax.jit, static_argnums=(2,))
def add_rectangle_to_state(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    position: jnp.ndarray,
    width: float,
    height: float,
    shape_role: int,
    density: float = 1,
):
    """Add a ``width`` x ``height`` rectangle body at ``position``.

    The local vertices are ``(+/-width/2, +/-height/2)``. This is a thin
    wrapper over ``add_rigidbody_to_state`` with four vertices and zero
    radius. The state is unchanged if no polygon slot is free.
    """
    half_extent = jnp.array([width, height]) / 2
    return add_rigidbody_to_state(
        state,
        env_params,
        static_env_params,
        position,
        vertices=rectangle_vertices(half_extent),
        n_vertices=4,
        radius=0.0,
        shape_role=shape_role,
        density=density,
        is_circle=False,
    )


@partial(jax.jit, static_argnums=(2,))
def add_circle_to_state(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    position: jnp.ndarray,
    radius: float,
    shape_role: int,
    density: float = 1,
):
    """Add a circle body of the given ``radius`` at ``position``.

    This is a thin wrapper over ``add_rigidbody_to_state`` with
    ``is_circle=True``. The state is unchanged if no circle slot is free.
    """
    # Circles have no polygon vertices; a single zero vector is passed as a placeholder.
    placeholder_vertices = jnp.array([0.0, 0.0])
    return add_rigidbody_to_state(
        state,
        env_params,
        static_env_params,
        position,
        vertices=placeholder_vertices,
        n_vertices=0,
        radius=radius,
        shape_role=shape_role,
        density=density,
        is_circle=True,
    )


@partial(jax.jit, static_argnums=(2,))
def add_thruster_to_object(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    shape_index: int,
    rotation: float,
    colour: int,
    thruster_power_multiplier: float,
):
    """Attach a thruster to the shape at ``shape_index``.

    The thruster is written into the first inactive thruster slot, i.e.
    ``argmin`` of ``state.thruster.active``. Nothing changes if the target
    shape is inactive or every thruster slot is already in use.

    Args:
        state: environment state to modify.
        env_params: not read here.
        static_env_params: static parameters, forwarded to ``select_shape``.
        shape_index: index of the target shape, as understood by ``select_shape``.
        rotation: thruster rotation.
        colour: value stored in ``thruster_bindings`` for the new slot.
        thruster_power_multiplier: thruster power is the shape's mass times this.

    Returns:
        The updated (or unchanged) state.
    """

    def _skip(state):
        return state

    def _attach(state: EnvState):
        slot = jnp.argmin(state.thruster.active)
        shape = select_shape(state, shape_index, static_env_params)

        # Mass = 1 / inverse_mass. A zero inverse mass is replaced by 1.0 to avoid dividing by zero.
        mass = 1.0 / jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass)

        thruster = Thruster(
            object_index=shape_index,
            active=True,
            # Zero offset from the shape's position: a bit of a hack but reasonable.
            relative_position=jnp.array([0.0, 0.0]),
            rotation=rotation,
            power=mass * thruster_power_multiplier,
            global_position=shape.position,
        )

        return state.replace(
            # jax.tree_map is deprecated/removed in recent JAX; jax.tree.map matches the rest of this file.
            thruster=jax.tree.map(lambda slots, value: slots.at[slot].set(value), state.thruster, thruster),
            thruster_bindings=state.thruster_bindings.at[slot].set(colour),
        )

    target_is_active = select_shape(state, shape_index, static_env_params).active
    has_free_slot = jnp.logical_not(state.thruster.active).sum() > 0
    return jax.lax.cond(target_is_active & has_free_slot, _attach, _skip, state)


def make_velocities_zero(state: EnvState):
    """Return a copy of ``state`` with all polygon and circle velocities multiplied by zero.

    Both linear (``velocity``) and angular (``angular_velocity``) components
    are affected. All other body fields are left untouched.
    """

    def _stilled(bodies):
        return bodies.replace(
            velocity=bodies.velocity * 0,
            angular_velocity=bodies.angular_velocity * 0,
        )

    return state.replace(polygon=_stilled(state.polygon), circle=_stilled(state.circle))


def make_do_dummy_step(
    params: EnvParams, static_sim_params: StaticEnvParams, zero_collisions=True, zero_velocities=True
):
    """Build a function that advances a state by one gravity-free physics step.

    The returned ``do_dummy_step(state)`` does the following:
      1. Temporarily sets gravity to zero and, if ``zero_collisions`` is set,
         clears the collision matrix.
      2. Runs one jitted ``PhysicsEngine.step`` with an all-zero action vector
         (one entry per joint plus one per thruster).
      3. Restores the original gravity and collision matrix.
      4. If ``zero_velocities`` is set, zeroes all polygon and circle velocities.

    Args:
        params: environment parameters passed to the engine step.
        static_sim_params: static parameters used to build the engine and size the action.
        zero_collisions: clear the collision matrix for the duration of the step.
        zero_velocities: zero body velocities after the step.

    Returns:
        The ``do_dummy_step`` callable.
    """
    env = PhysicsEngine(static_sim_params)

    @jax.jit
    def _step_fn(state):
        state, _ = env.step(state, params, jnp.zeros((static_sim_params.num_joints + static_sim_params.num_thrusters,)))
        return state

    def do_dummy_step(state: EnvState) -> EnvState:
        og_col = state.collision_matrix
        g = state.gravity
        # `& False` clears the matrix. `& True` leaves it unchanged, assuming a boolean matrix (TODO confirm dtype).
        state = state.replace(
            collision_matrix=state.collision_matrix & (not zero_collisions), gravity=state.gravity * 0
        )
        state = _step_fn(state)
        state = state.replace(gravity=g, collision_matrix=og_col)
        if zero_velocities:
            state = make_velocities_zero(state)
        return state

    return do_dummy_step