"""
Based on PureJaxRL Implementation of PPO
"""

import os
import sys
import time
import typing
from functools import partial
from typing import NamedTuple

import chex
import hydra
import jax
import jax.experimental
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax.training.train_state import TrainState
from kinetix.environment.ued.ued import make_reset_train_function_with_mutations, make_vmapped_filtered_level_sampler
from kinetix.environment.ued.ued import (
    make_reset_train_function_with_list_of_levels,
    make_reset_train_function_with_mutations,
)
from kinetix.util.config import (
    generate_ued_params_from_config,
    init_wandb,
    normalise_config,
    generate_params_from_config,
    get_eval_level_groups,
)
from jaxued.environments.underspecified_env import EnvParams, EnvState, Observation, UnderspecifiedEnv
from omegaconf import OmegaConf
from PIL import Image
from flax.serialization import to_state_dict

import wandb
from kinetix.environment.env import make_kinetix_env_from_name
from kinetix.environment.wrappers import (
    AutoReplayWrapper,
    DenseRewardWrapper,
    LogWrapper,
    UnderspecifiedToGymnaxWrapper,
)
from kinetix.models import make_network_from_config
from kinetix.models.actor_critic import ScannedRNN
from kinetix.render.renderer_pixels import make_render_pixels
from kinetix.util.learning import general_eval, get_eval_levels
from kinetix.util.saving import (
    load_train_state_from_wandb_artifact_path,
    save_model_to_wandb,
)

sys.path.append("ued")
from flax.traverse_util import flatten_dict, unflatten_dict
from safetensors.flax import load_file, save_file


def save_params(params: typing.Dict, filename: typing.Union[str, os.PathLike]) -> None:
    """Serialise a (possibly nested) parameter dict to a safetensors file.

    ``save_file`` takes a flat name -> array mapping, so nested keys are first
    joined with "," into single string keys. ``load_params`` undoes this.

    Args:
        params: Nested mapping of parameter arrays.
        filename: Destination path of the safetensors file.
    """
    save_file(flatten_dict(params, sep=","), filename)


def load_params(filename: typing.Union[str, os.PathLike]) -> typing.Dict:
    """Read parameters written by ``save_params`` back into a nested dict.

    The flat, ","-joined keys stored in the safetensors file are split back
    into the original nested structure.

    Args:
        filename: Path of the safetensors file to read.

    Returns:
        The nested parameter mapping.
    """
    return unflatten_dict(load_file(filename), sep=",")


class Transition(NamedTuple):
    """One rollout step for a batch of environments.

    The rollout ``_env_step`` functions build one of these per step; after
    ``jax.lax.scan`` every field gains a leading time axis.

    Note the two done flags:
      * ``global_done`` is the done flag returned by ``env.step`` for *this*
        step; it is what GAE uses to stop bootstrapping.
      * ``done`` is the done flag from the *previous* step, i.e. the flag that
        was fed to the recurrent network together with ``obs`` (and is fed
        again when the network is re-run during the update).
    """

    global_done: jnp.ndarray  # done returned by env.step at this step
    done: jnp.ndarray  # done from the previous step (RNN reset input)
    action: jnp.ndarray  # action sampled from the policy
    value: jnp.ndarray  # critic value for ``obs``
    reward: jnp.ndarray  # reward returned by env.step
    log_prob: jnp.ndarray  # log-probability of ``action`` under the behaviour policy
    obs: jnp.ndarray  # observation the action was chosen from
    # Per-step info returned by env.step. It is a mapping of metric name ->
    # array (it is indexed like ``info["GoalR"]``), not a single array.
    info: dict


class RolloutBatch(NamedTuple):
    """Container for per-step rollout data together with GAE outputs.

    NOTE(review): no use of this class is visible in this part of the file;
    confirm it is still needed. Field order is part of its positional
    constructor and must not change.
    """

    obs: jnp.ndarray  # observations
    actions: jnp.ndarray  # actions taken
    rewards: jnp.ndarray  # rewards received
    dones: jnp.ndarray  # episode-termination flags
    log_probs: jnp.ndarray  # behaviour-policy log-probabilities
    values: jnp.ndarray  # critic value estimates
    targets: jnp.ndarray  # value targets (presumably advantages + values)
    advantages: jnp.ndarray  # advantage estimates
    mask: jnp.ndarray  # validity mask for the entries above


def evaluate_rnn(
    rng: chex.PRNGKey,
    env: UnderspecifiedEnv,
    env_params: EnvParams,
    train_state: TrainState,
    init_hstate: chex.ArrayTree,
    init_obs: Observation,
    init_env_state: EnvState,
    max_episode_length: int,
    keep_states=True,
) -> tuple[typing.Optional[chex.ArrayTree], chex.Array, chex.Array, chex.ArrayTree]:
    """Roll the recurrent policy out on a batch of levels for a fixed number of steps.

    Every level is stepped for exactly ``max_episode_length`` steps (there is no
    early exit). A per-level mask records whether the level's first episode has
    ended yet, which is used to measure episode lengths.

    Args:
        rng: PRNG key; split each step into an action-sampling key and an env-step key.
        env: Environment whose ``step`` is vmapped over levels.
        env_params: Environment parameters shared by all levels.
        train_state: Provides ``apply_fn`` and ``params`` of the recurrent actor-critic.
        init_hstate: Initial recurrent state, batched over levels
            (presumably (num_levels, hidden) — TODO confirm).
        init_obs: Initial observations; every leaf has leading axis num_levels.
        init_env_state: Initial environment states, leading axis num_levels.
        max_episode_length: Number of steps to scan for.
        keep_states: If False, environment states are not collected and
            ``states`` is returned as ``None``.

    Returns:
        ``(states, rewards, episode_lengths, infos)`` where
          * ``states``: env state *before* each step, leading axes
            (max_episode_length, num_levels), or ``None`` if ``keep_states`` is False;
          * ``rewards``: (max_episode_length, num_levels); not masked after a
            level's first episode ends;
          * ``episode_lengths``: (num_levels,) int32, number of steps up to and
            including the first ``done`` (``max_episode_length`` if none occurs);
          * ``infos``: per-step info pytree from ``env.step``, leading axes
            (max_episode_length, num_levels).
    """
    # Number of levels = leading axis of any observation leaf.
    num_levels = jax.tree_util.tree_leaves(init_obs)[0].shape[0]

    def step(carry, _):
        rng, hstate, obs, state, done, mask, episode_length = carry
        rng, rng_action, rng_step = jax.random.split(rng, 3)

        # The network consumes sequences, so add a length-1 time axis.
        x = jax.tree.map(lambda x: x[None, ...], (obs, done))
        hstate, pi, _ = train_state.apply_fn(train_state.params, hstate, x)
        action = pi.sample(seed=rng_action).squeeze(0)

        obs, next_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
            jax.random.split(rng_step, num_levels), state, action, env_params
        )

        # `mask` stays True until a level's first episode terminates; the
        # terminating step itself still counts towards the episode length.
        next_mask = mask & ~done
        episode_length += mask

        # Emitting None instead of the state means scan stacks nothing for it.
        recorded_state = state if keep_states else None
        return (rng, hstate, obs, next_state, done, next_mask, episode_length), (recorded_state, reward, info)

    (_, _, _, _, _, _, episode_lengths), (states, rewards, infos) = jax.lax.scan(
        step,
        (
            rng,
            init_hstate,
            init_obs,
            init_env_state,
            jnp.zeros(num_levels, dtype=bool),  # no level starts out done
            jnp.ones(num_levels, dtype=bool),  # every level starts in its first episode
            jnp.zeros(num_levels, dtype=jnp.int32),
        ),
        None,
        length=max_episode_length,
    )

    return states, rewards, episode_lengths, infos


@hydra.main(version_base=None, config_path="../configs", config_name="sfl")
def main(config):
    time_start = time.time()
    config = OmegaConf.to_container(config)
    config = normalise_config(config, "SFL" if config["ued"]["sampled_envs_ratio"] > 0 else "SFL-DR")
    env_params, static_env_params = generate_params_from_config(config)
    config["env_params"] = to_state_dict(env_params)
    config["static_env_params"] = to_state_dict(static_env_params)
    run = init_wandb(config, "SFL")

    rng = jax.random.PRNGKey(config["seed"])

    config["num_envs_from_sampled"] = int(config["num_train_envs"] * config["sampled_envs_ratio"])
    config["num_envs_to_generate"] = int(config["num_train_envs"] * (1 - config["sampled_envs_ratio"]))
    assert (config["num_envs_from_sampled"] + config["num_envs_to_generate"]) == config["num_train_envs"]

    def make_env(static_env_params):
        env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
        env = AutoReplayWrapper(env)
        env = UnderspecifiedToGymnaxWrapper(env)
        env = DenseRewardWrapper(env, dense_reward_scale=config["dense_reward_scale"])
        env = LogWrapper(env)
        return env

    env = make_env(static_env_params)

    if config["train_level_mode"] == "list":
        sample_random_level = make_reset_train_function_with_list_of_levels(
            config, config["train_levels"], static_env_params, make_pcg_state=False, is_loading_train_levels=True
        )
    elif config["train_level_mode"] == "random":
        sample_random_level = make_reset_train_function_with_mutations(
            env.physics_engine, env_params, static_env_params, config, make_pcg_state=False
        )
    else:
        raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")

    sample_random_levels = make_vmapped_filtered_level_sampler(
        sample_random_level, env_params, static_env_params, config, make_pcg_state=False, env=env
    )
    _, eval_static_env_params = generate_params_from_config(
        config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
    )
    eval_env = make_env(eval_static_env_params)
    ued_params = generate_ued_params_from_config(config)

    def make_render_fn(static_env_params):
        render_fn_inner = make_render_pixels(env_params, static_env_params)
        render_fn = lambda x: render_fn_inner(x).transpose(1, 0, 2)[::-1]
        return render_fn

    render_fn = make_render_fn(static_env_params)
    render_fn_eval = make_render_fn(eval_static_env_params)

    NUM_EVAL_DR_LEVELS = 200
    key_to_sample_dr_eval_set = jax.random.PRNGKey(100)
    DR_EVAL_LEVELS = sample_random_levels(key_to_sample_dr_eval_set, NUM_EVAL_DR_LEVELS)

    print("Hello here num steps is ", config["num_steps"])
    print("CONFIG is ", config)

    config["total_timesteps"] = config["num_updates"] * config["num_steps"] * config["num_train_envs"]
    config["minibatch_size"] = config["num_train_envs"] * config["num_steps"] // config["num_minibatches"]
    config["clip_eps"] = config["clip_eps"]

    config["env_name"] = config["env_name"]
    network = make_network_from_config(env, env_params, config)

    def linear_schedule(count):
        count = count // (config["num_minibatches"] * config["update_epochs"])
        frac = 1.0 - count / config["num_updates"]
        return config["lr"] * frac

    # INIT NETWORK
    rng, _rng = jax.random.split(rng)
    train_envs = 32  # To not run out of memory, the initial sample size does not matter.
    obs, _ = env.reset_to_level(rng, sample_random_level(rng), env_params)
    obs = jax.tree.map(
        lambda x: jnp.repeat(jnp.repeat(x[None, ...], train_envs, axis=0)[None, ...], 256, axis=0),
        obs,
    )
    init_x = (obs, jnp.zeros((256, train_envs)))
    init_hstate = ScannedRNN.initialize_carry(train_envs)
    network_params = network.init(_rng, init_hstate, init_x)
    if config["anneal_lr"]:
        tx = optax.chain(
            optax.clip_by_global_norm(config["max_grad_norm"]),
            optax.adam(learning_rate=linear_schedule, eps=1e-5),
        )
    else:
        tx = optax.chain(
            optax.clip_by_global_norm(config["max_grad_norm"]),
            optax.adam(config["lr"], eps=1e-5),
        )
    train_state = TrainState.create(
        apply_fn=network.apply,
        params=network_params,
        tx=tx,
    )
    if config["load_from_checkpoint"] != None:
        print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
        train_state = load_train_state_from_wandb_artifact_path(
            train_state,
            config["load_from_checkpoint"],
            load_only_params=config["load_only_params"],
            legacy=config["load_legacy_checkpoint"],
        )

    rng, _rng = jax.random.split(rng)

    # INIT ENV
    rng, _rng, _rng2 = jax.random.split(rng, 3)
    rng_reset = jax.random.split(_rng, config["num_train_envs"])

    new_levels = sample_random_levels(_rng2, config["num_train_envs"])
    obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)

    start_state = env_state
    init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])

    @jax.jit
    def log_buffer_learnability(rng, train_state, instances):
        BATCH_SIZE = config["num_to_save"]
        BATCH_ACTORS = BATCH_SIZE

        def _batch_step(unused, rng):
            def _env_step(runner_state, unused):
                env_state, start_state, last_obs, last_done, hstate, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                obs_batch = last_obs
                ac_in = (
                    jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
                    last_done[np.newaxis, :],
                )
                hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
                action = pi.sample(seed=_rng).squeeze()
                log_prob = pi.log_prob(action)
                env_act = action

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["num_to_save"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
                    rng_step, env_state, env_act, env_params
                )
                done_batch = done

                transition = Transition(
                    done,
                    last_done,
                    action.squeeze(),
                    value.squeeze(),
                    reward,
                    log_prob.squeeze(),
                    obs_batch,
                    info,
                )
                runner_state = (env_state, start_state, obsv, done_batch, hstate, rng)
                return runner_state, transition

            @partial(jax.vmap, in_axes=(None, 1, 1, 1))
            @partial(jax.jit, static_argnums=(0,))
            def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
                idxs = jnp.arange(max_steps)

                @partial(jax.vmap, in_axes=(0, 0))
                def __ep_outcomes(start_idx, end_idx):
                    mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
                    r = jnp.sum(returns * mask)
                    goal_r = info["GoalR"]  # (returns > 0) * 1.0
                    success = jnp.sum(goal_r * mask)
                    l = end_idx - start_idx
                    return r, success, l

                done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
                mask_done = jnp.where(done_idxs == max_steps, 0, 1)
                ep_return, success, length = __ep_outcomes(
                    jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
                )

                return {
                    "ep_return": ep_return.mean(where=mask_done),
                    "num_episodes": mask_done.sum(),
                    "success_rate": success.mean(where=mask_done),
                    "ep_len": length.mean(where=mask_done),
                }

            # sample envs
            rng, _rng, _rng2 = jax.random.split(rng, 3)
            rng_reset = jax.random.split(_rng, config["num_to_save"])
            rng_levels = jax.random.split(_rng2, config["num_to_save"])
            # obsv, env_state = jax.vmap(sample_random_level, in_axes=(0,))(reset_rng)
            # new_levels = jax.vmap(sample_random_level)(rng_levels)
            obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, instances, env_params)
            # env_instances = new_levels
            init_hstate = ScannedRNN.initialize_carry(
                BATCH_ACTORS,
            )

            runner_state = (env_state, env_state, obsv, jnp.zeros((BATCH_ACTORS), dtype=bool), init_hstate, rng)
            runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["rollout_steps"])
            done_by_env = traj_batch.done.reshape((-1, config["num_to_save"]))
            reward_by_env = traj_batch.reward.reshape((-1, config["num_to_save"]))
            # info_by_actor = jax.tree.map(lambda x: x.swapaxes(2, 1).reshape((-1, BATCH_ACTORS)), traj_batch.info)
            o = _calc_outcomes_by_agent(config["rollout_steps"], traj_batch.done, traj_batch.reward, traj_batch.info)
            success_by_env = o["success_rate"].reshape((1, config["num_to_save"]))
            learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
            return None, (learnability_by_env, success_by_env.sum(axis=0))

        rngs = jax.random.split(rng, 1)
        _, (learnability, success_by_env) = jax.lax.scan(_batch_step, None, rngs, 1)
        return learnability[0], success_by_env[0]

    num_eval_levels = len(config["eval_levels"])
    all_eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)

    eval_group_indices = get_eval_level_groups(config["eval_levels"])
    print("group indices", eval_group_indices)

    @jax.jit
    def get_learnability_set(rng, network_params):

        BATCH_ACTORS = config["batch_size"]

        def _batch_step(unused, rng):
            def _env_step(runner_state, unused):
                env_state, start_state, last_obs, last_done, hstate, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                obs_batch = last_obs
                ac_in = (
                    jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
                    last_done[np.newaxis, :],
                )
                hstate, pi, value = network.apply(network_params, hstate, ac_in)
                action = pi.sample(seed=_rng).squeeze()
                log_prob = pi.log_prob(action)
                env_act = action

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["batch_size"])
                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
                    rng_step, env_state, env_act, env_params
                )
                done_batch = done

                transition = Transition(
                    done,
                    last_done,
                    action.squeeze(),
                    value.squeeze(),
                    reward,
                    log_prob.squeeze(),
                    obs_batch,
                    info,
                )
                runner_state = (env_state, start_state, obsv, done_batch, hstate, rng)
                return runner_state, transition

            @partial(jax.vmap, in_axes=(None, 1, 1, 1))
            @partial(jax.jit, static_argnums=(0,))
            def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
                idxs = jnp.arange(max_steps)

                @partial(jax.vmap, in_axes=(0, 0))
                def __ep_outcomes(start_idx, end_idx):
                    mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
                    r = jnp.sum(returns * mask)
                    goal_r = info["GoalR"]  # (returns > 0) * 1.0
                    success = jnp.sum(goal_r * mask)
                    l = end_idx - start_idx
                    return r, success, l

                done_idxs = jnp.argwhere(dones, size=50, fill_value=max_steps).squeeze()
                mask_done = jnp.where(done_idxs == max_steps, 0, 1)
                ep_return, success, length = __ep_outcomes(
                    jnp.concatenate([jnp.array([-1]), done_idxs[:-1]]), done_idxs
                )

                return {
                    "ep_return": ep_return.mean(where=mask_done),
                    "num_episodes": mask_done.sum(),
                    "success_rate": success.mean(where=mask_done),
                    "ep_len": length.mean(where=mask_done),
                }

            # sample envs
            rng, _rng, _rng2 = jax.random.split(rng, 3)
            rng_reset = jax.random.split(_rng, config["batch_size"])
            new_levels = sample_random_levels(_rng2, config["batch_size"])
            obsv, env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)
            env_instances = new_levels
            init_hstate = ScannedRNN.initialize_carry(
                BATCH_ACTORS,
            )

            runner_state = (env_state, env_state, obsv, jnp.zeros((BATCH_ACTORS), dtype=bool), init_hstate, rng)
            runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["rollout_steps"])
            done_by_env = traj_batch.done.reshape((-1, config["batch_size"]))
            reward_by_env = traj_batch.reward.reshape((-1, config["batch_size"]))
            # info_by_actor = jax.tree.map(lambda x: x.swapaxes(2, 1).reshape((-1, BATCH_ACTORS)), traj_batch.info)
            o = _calc_outcomes_by_agent(config["rollout_steps"], traj_batch.done, traj_batch.reward, traj_batch.info)
            success_by_env = o["success_rate"].reshape((1, config["batch_size"]))
            learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)
            return None, (learnability_by_env, success_by_env.sum(axis=0), env_instances)

        if config["sampled_envs_ratio"] == 0.0:
            print("Not doing any rollouts because sampled_envs_ratio is 0.0")
            # Here we have zero envs, so we can literally just sample random ones because there is no point.
            top_instances = sample_random_levels(_rng, config["num_to_save"])
            top_success = top_learn = learnability = success_rates = jnp.zeros(config["num_to_save"])
        else:
            rngs = jax.random.split(rng, config["num_batches"])
            _, (learnability, success_rates, env_instances) = jax.lax.scan(
                _batch_step, None, rngs, config["num_batches"]
            )

            flat_env_instances = jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), env_instances)
            learnability = learnability.flatten() + success_rates.flatten() * 0.001
            top_1000 = jnp.argsort(learnability)[-config["num_to_save"] :]

            top_1000_instances = jax.tree.map(lambda x: x.at[top_1000].get(), flat_env_instances)
            top_learn, top_instances = learnability.at[top_1000].get(), top_1000_instances
            top_success = success_rates.at[top_1000].get()

        if config["put_eval_levels_in_buffer"]:
            top_instances = jax.tree.map(
                lambda all, new: jnp.concatenate([all[:-num_eval_levels], new], axis=0),
                top_instances,
                all_eval_levels.env_state,
            )

        log = {
            "learnability/learnability_sampled_mean": learnability.mean(),
            "learnability/learnability_sampled_median": jnp.median(learnability),
            "learnability/learnability_sampled_min": learnability.min(),
            "learnability/learnability_sampled_max": learnability.max(),
            "learnability/learnability_selected_mean": top_learn.mean(),
            "learnability/learnability_selected_median": jnp.median(top_learn),
            "learnability/learnability_selected_min": top_learn.min(),
            "learnability/learnability_selected_max": top_learn.max(),
            "learnability/solve_rate_sampled_mean": top_success.mean(),
            "learnability/solve_rate_sampled_median": jnp.median(top_success),
            "learnability/solve_rate_sampled_min": top_success.min(),
            "learnability/solve_rate_sampled_max": top_success.max(),
            "learnability/solve_rate_selected_mean": success_rates.mean(),
            "learnability/solve_rate_selected_median": jnp.median(success_rates),
            "learnability/solve_rate_selected_min": success_rates.min(),
            "learnability/solve_rate_selected_max": success_rates.max(),
        }

        return top_learn, top_instances, log

    def eval(rng: chex.PRNGKey, train_state: TrainState, keep_states=True):
        """
        This evaluates the current policy on the set of evaluation levels specified by config["eval_levels"].
        It returns (states, cum_rewards, episode_lengths), with shapes (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,)
        """
        num_levels = len(config["eval_levels"])
        # eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
        return general_eval(
            rng,
            eval_env,
            env_params,
            train_state,
            all_eval_levels,
            env_params.max_timesteps,
            num_levels,
            keep_states=keep_states,
            return_trajectories=True,
        )

    def eval_on_dr_levels(rng: chex.PRNGKey, train_state: TrainState, keep_states=False):
        return general_eval(
            rng,
            env,
            env_params,
            train_state,
            DR_EVAL_LEVELS,
            env_params.max_timesteps,
            NUM_EVAL_DR_LEVELS,
            keep_states=keep_states,
        )

    def eval_on_top_learnable_levels(rng: chex.PRNGKey, train_state: TrainState, levels, keep_states=True):
        N = 5
        return general_eval(
            rng,
            env,
            env_params,
            train_state,
            jax.tree.map(lambda x: x[:N], levels),
            env_params.max_timesteps,
            N,
            keep_states=keep_states,
        )

    # TRAIN LOOP
    def train_step(runner_state_instances, unused):
        # COLLECT TRAJECTORIES
        runner_state, instances = runner_state_instances
        num_env_instances = instances.polygon.position.shape[0]

        def _env_step(runner_state, unused):
            train_state, env_state, start_state, last_obs, last_done, hstate, update_steps, rng = runner_state

            # SELECT ACTION
            rng, _rng = jax.random.split(rng)
            obs_batch = last_obs
            ac_in = (
                jax.tree.map(lambda x: x[np.newaxis, :], obs_batch),
                last_done[np.newaxis, :],
            )
            hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
            action = pi.sample(seed=_rng).squeeze()
            log_prob = pi.log_prob(action)
            env_act = action

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            rng_step = jax.random.split(_rng, config["num_train_envs"])
            obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
                rng_step, env_state, env_act, env_params
            )
            done_batch = done
            transition = Transition(
                done,
                last_done,
                action.squeeze(),
                value.squeeze(),
                reward,
                log_prob.squeeze(),
                obs_batch,
                info,
            )
            runner_state = (train_state, env_state, start_state, obsv, done_batch, hstate, update_steps, rng)
            return runner_state, (transition)

        initial_hstate = runner_state[-3]
        runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["num_steps"])

        # CALCULATE ADVANTAGE
        train_state, env_state, start_state, last_obs, last_done, hstate, update_steps, rng = runner_state
        last_obs_batch = last_obs  # batchify(last_obs, env.agents, config["num_train_envs"])
        ac_in = (
            jax.tree.map(lambda x: x[np.newaxis, :], last_obs_batch),
            last_done[np.newaxis, :],
        )
        _, _, last_val = network.apply(train_state.params, hstate, ac_in)
        last_val = last_val.squeeze()

        def _calculate_gae(traj_batch, last_val):
            def _get_advantages(gae_and_next_value, transition: Transition):
                gae, next_value = gae_and_next_value
                done, value, reward = (
                    transition.global_done,
                    transition.value,
                    transition.reward,
                )
                delta = reward + config["gamma"] * next_value * (1 - done) - value
                gae = delta + config["gamma"] * config["gae_lambda"] * (1 - done) * gae
                return (gae, value), gae

            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            return advantages, advantages + traj_batch.value

        advantages, targets = _calculate_gae(traj_batch, last_val)

        # UPDATE NETWORK
        def _update_epoch(update_state, unused):
            def _update_minbatch(train_state, batch_info):
                init_hstate, traj_batch, advantages, targets = batch_info

                def _loss_fn_masked(params, init_hstate, traj_batch, gae, targets):

                    # RERUN NETWORK
                    _, pi, value = network.apply(
                        params,
                        jax.tree.map(lambda x: x.transpose(), init_hstate),
                        (traj_batch.obs, traj_batch.done),
                    )
                    log_prob = pi.log_prob(traj_batch.action)

                    # CALCULATE VALUE LOSS
                    value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
                        -config["clip_eps"], config["clip_eps"]
                    )
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(value_pred_clipped - targets)
                    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped)
                    critic_loss = config["vf_coef"] * value_loss.mean()

                    # CALCULATE ACTOR LOSS
                    logratio = log_prob - traj_batch.log_prob
                    ratio = jnp.exp(logratio)
                    # if env.do_sep_reward: gae = gae.sum(axis=-1)
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    loss_actor1 = ratio * gae
                    loss_actor2 = (
                        jnp.clip(
                            ratio,
                            1.0 - config["clip_eps"],
                            1.0 + config["clip_eps"],
                        )
                        * gae
                    )
                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                    loss_actor = loss_actor.mean()
                    entropy = pi.entropy().mean()

                    approx_kl = jax.lax.stop_gradient(((ratio - 1) - logratio).mean())
                    clipfrac = jax.lax.stop_gradient((jnp.abs(ratio - 1) > config["clip_eps"]).mean())

                    total_loss = loss_actor + critic_loss - config["ent_coef"] * entropy
                    return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clipfrac)

                grad_fn = jax.value_and_grad(_loss_fn_masked, has_aux=True)
                total_loss, grads = grad_fn(train_state.params, init_hstate, traj_batch, advantages, targets)
                train_state = train_state.apply_gradients(grads=grads)
                return train_state, total_loss

            (
                train_state,
                init_hstate,
                traj_batch,
                advantages,
                targets,
                rng,
            ) = update_state
            rng, _rng = jax.random.split(rng)

            init_hstate = jax.tree.map(lambda x: jnp.reshape(x, (256, config["num_train_envs"])), init_hstate)
            batch = (
                init_hstate,
                traj_batch,
                advantages.squeeze(),
                targets.squeeze(),
            )
            permutation = jax.random.permutation(_rng, config["num_train_envs"])

            shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=1), batch)

            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.swapaxes(
                    jnp.reshape(
                        x,
                        [x.shape[0], config["num_minibatches"], -1] + list(x.shape[2:]),
                    ),
                    1,
                    0,
                ),
                shuffled_batch,
            )

            train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
            # total_loss = jax.tree.map(lambda x: x.mean(), total_loss)
            update_state = (
                train_state,
                init_hstate,
                traj_batch,
                advantages,
                targets,
                rng,
            )
            return update_state, total_loss

        # init_hstate = initial_hstate[None, :].squeeze().transpose()
        init_hstate = jax.tree.map(lambda x: x[None, :].squeeze().transpose(), initial_hstate)
        update_state = (
            train_state,
            init_hstate,
            traj_batch,
            advantages,
            targets,
            rng,
        )
        update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["update_epochs"])
        train_state = update_state[0]
        metric = traj_batch.info
        metric = jax.tree.map(
            lambda x: x.reshape((config["num_steps"], config["num_train_envs"])),  # , env.num_agents
            traj_batch.info,
        )
        rng = update_state[-1]

        def callback(metric):
            dones = metric["dones"]
            wandb.log(
                {
                    "episode_return": (metric["returned_episode_returns"] * dones).sum() / jnp.maximum(1, dones.sum()),
                    "episode_solved": (metric["returned_episode_solved"] * dones).sum() / jnp.maximum(1, dones.sum()),
                    "episode_length": (metric["returned_episode_lengths"] * dones).sum() / jnp.maximum(1, dones.sum()),
                    "timing/num_env_steps": int(
                        int(metric["update_steps"]) * int(config["num_train_envs"]) * int(config["num_steps"])
                    ),
                    "timing/num_updates": metric["update_steps"],
                    **metric["loss_info"],
                }
            )

        loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
        metric["loss_info"] = {
            "loss/total_loss": loss_info[0],
            "loss/value_loss": loss_info[1][0],
            "loss/policy_loss": loss_info[1][1],
            "loss/entropy_loss": loss_info[1][2],
        }
        metric["dones"] = traj_batch.done
        metric["update_steps"] = update_steps
        jax.experimental.io_callback(callback, None, metric)

        # SAMPLE NEW ENVS
        rng, _rng, _rng2 = jax.random.split(rng, 3)
        rng_reset = jax.random.split(_rng, config["num_envs_to_generate"])

        new_levels = sample_random_levels(_rng2, config["num_envs_to_generate"])
        obsv_gen, env_state_gen = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(rng_reset, new_levels, env_params)

        rng, _rng, _rng2 = jax.random.split(rng, 3)
        sampled_env_instances_idxs = jax.random.randint(_rng, (config["num_envs_from_sampled"],), 0, num_env_instances)
        sampled_env_instances = jax.tree.map(lambda x: x.at[sampled_env_instances_idxs].get(), instances)
        myrng = jax.random.split(_rng2, config["num_envs_from_sampled"])
        obsv_sampled, env_state_sampled = jax.vmap(env.reset_to_level, in_axes=(0, 0))(myrng, sampled_env_instances)

        obsv = jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), obsv_gen, obsv_sampled)
        env_state = jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), env_state_gen, env_state_sampled)

        start_state = env_state
        hstate = ScannedRNN.initialize_carry(config["num_train_envs"])

        update_steps = update_steps + 1
        runner_state = (
            train_state,
            env_state,
            start_state,
            obsv,
            jnp.zeros((config["num_train_envs"]), dtype=bool),
            hstate,
            update_steps,
            rng,
        )
        return (runner_state, instances), metric

    def log_buffer(learnability, levels, epoch):
        """Render ``levels`` in a two-row grid, each panel titled with its learnability score.

        Args:
            learnability: Per-level scores; ``learnability[i]`` is shown above level ``i``.
            levels: Batched level pytree (leading axis indexes levels) accepted by ``render_fn``.
            epoch: Unused; kept so existing callers keep working.

        Returns:
            ``{"maps": wandb.Image}`` ready to merge into a wandb log dict.
        """
        num_samples = levels.polygon.position.shape[0]
        rows = 2
        # Ceiling division so every level gets a panel (floor division dropped the last
        # level for odd counts and produced 0 columns for fewer than `rows` levels).
        cols = max(1, (num_samples + rows - 1) // rows)
        fig, axes = plt.subplots(rows, cols, figsize=(20, 10))
        axes = np.asarray(axes).flatten()
        all_imgs = np.asarray(jax.vmap(render_fn)(levels))
        for i, ax in enumerate(axes):
            ax.set_xticks([])
            ax.set_yticks([])
            if i >= num_samples:
                # Grid has more panels than levels; leave the spare ones blank.
                ax.axis("off")
                continue
            ax.imshow(all_imgs[i] / 255.0)
            ax.set_title(f"learnability: {float(learnability[i]):.3f}")
            ax.set_aspect("equal", "box")

        plt.tight_layout()
        fig.canvas.draw()
        # `tostring_rgb` is deprecated (removed in Matplotlib 3.10); read the RGBA buffer
        # instead and drop alpha. Convert before closing so the buffer is still valid.
        im = Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert("RGB")
        plt.close(fig)
        return {"maps": wandb.Image(im)}

    @jax.jit
    def train_and_eval_step(runner_state, eval_rng):
        """Train for ``config["eval_freq"]`` updates on a fresh learnability set, then evaluate.

        1. Build the training level set with ``get_learnability_set`` from the current params.
        2. Scan ``train_step`` ``config["eval_freq"]`` times over ``(runner_state, instances)``.
        3. Evaluate the updated train state on:
           - the fixed eval levels (``eval``);
           - sampled levels (``eval_on_dr_levels``);
           - the last ``num_top`` levels of the learnability set (``eval_on_top_learnable_levels``).
           Rollouts of the first attempt are rendered for logging.

        Args:
            runner_state: Runner tuple; element 0 is the train state and element -2 the
                update counter (as read below).
            eval_rng: PRNG key, split into an independent key per stage.

        Returns:
            ``(runner_state, (top_scores, top_instances), test_metrics)``.
            - ``top_scores`` are the learnability scores at the same positions as
              ``top_instances``. This assumes ``get_learnability_set`` returns scores
              aligned index-for-index with its instances — TODO confirm.
            - ``test_metrics`` is the metrics dict consumed by ``log_eval``.
        """
        # Number of levels taken from the end of the learnability set for the
        # "top learnable" evaluation and for the returned buffer snapshot.
        num_top = 5
        learnability_rng, eval_singleton_rng, eval_sampled_rng, top_learnable_rng = jax.random.split(eval_rng, 4)

        # TRAIN
        learnability_scores, instances, test_metrics = get_learnability_set(learnability_rng, runner_state[0].params)

        if config["log_learnability_before_after"]:
            learn_scores_before, success_score_before = log_buffer_learnability(
                learnability_rng, runner_state[0], instances
            )

        # Executes at trace time only, since this function is jitted.
        print("instance size", sum(x.size for x in jax.tree_util.tree_leaves(instances)))

        (runner_state, _), _ = jax.lax.scan(train_step, (runner_state, instances), None, config["eval_freq"])
        train_state = runner_state[0]

        if config["log_learnability_before_after"]:
            learn_scores_after, success_score_after = log_buffer_learnability(learnability_rng, train_state, instances)

        # EVAL on the fixed eval levels
        _, rng_eval = jax.random.split(eval_singleton_rng)
        (states, cum_rewards, _, episode_lengths, eval_infos), (eval_dones, _) = jax.vmap(eval, (0, None))(
            jax.random.split(rng_eval, config["eval_num_attempts"]), train_state
        )
        all_eval_eplens = episode_lengths

        eval_returns = cum_rewards.mean(axis=0)  # (num_eval_levels,)
        eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
            1, eval_dones.sum(axis=1)
        )
        eval_solves = eval_solves.mean(axis=0)
        # Only the first attempt is rendered.
        states, episode_lengths = jax.tree_util.tree_map(lambda x: x[0], (states, episode_lengths))
        images = jax.vmap(jax.vmap(render_fn_eval))(states.env_state.env_state.env_state)
        # WandB expects the colour channel before the image dimensions for animations.
        frames = images.transpose(0, 1, 4, 2, 3)

        test_metrics["eval_returns"] = eval_returns
        test_metrics["eval_ep_lengths"] = episode_lengths
        test_metrics["eval_animation"] = (frames, episode_lengths)

        # EVAL on sampled levels (own key, so it is independent of the other evals)
        _, dr_cum_rewards, _, dr_episode_lengths, dr_infos = jax.vmap(eval_on_dr_levels, (0, None))(
            jax.random.split(eval_sampled_rng, config["eval_num_attempts"]), train_state
        )
        dr_dones = dr_infos["returned_episode"]
        dr_solves = (dr_infos["returned_episode_solved"] * dr_dones).sum(axis=1) / jnp.maximum(
            1, dr_dones.sum(axis=1)
        )
        test_metrics["eval/mean_eval_return_sampled"] = dr_cum_rewards.mean(axis=0).mean()
        # Reduce to a scalar, like the other "mean_*_sampled" metrics.
        test_metrics["eval/mean_eval_solve_rate_sampled"] = dr_solves.mean()
        test_metrics["eval/mean_eval_eplen_sampled"] = dr_episode_lengths.mean(axis=0).mean()

        # Per-level arrays used by log_eval for grouped aggregates, then dropped before logging.
        test_metrics["to_remove"] = {
            "eval_return": eval_returns,
            "eval_solve_rate": eval_solves,
            "eval_eplen": all_eval_eplens,
        }
        for i, name in enumerate(config["eval_levels"]):
            test_metrics[f"eval_avg_return/{name}"] = eval_returns[i]
            test_metrics[f"eval_avg_solve_rate/{name}"] = eval_solves[i]
        test_metrics["eval/mean_eval_return"] = eval_returns.mean()
        test_metrics["eval/mean_eval_solve_rate"] = eval_solves.mean()
        test_metrics["eval/mean_eval_eplen"] = all_eval_eplens.mean()
        # Counter after this round of training.
        test_metrics["update_count"] = runner_state[-2]

        # EVAL on the last `num_top` levels of the learnability set
        top_instances = jax.tree.map(lambda x: x[-num_top:], instances)
        tl_states, tl_cum_rewards, _, tl_episode_lengths, _ = jax.vmap(
            eval_on_top_learnable_levels, (0, None, None)
        )(jax.random.split(top_learnable_rng, config["eval_num_attempts"]), train_state, top_instances)
        # Only the first attempt is rendered; cumulative rewards are kept for every attempt.
        tl_states, tl_first_lengths = jax.tree_util.tree_map(lambda x: x[0], (tl_states, tl_episode_lengths))
        tl_images = jax.vmap(jax.vmap(render_fn))(tl_states.env_state.env_state.env_state)
        test_metrics["top_learnable_animation"] = (
            tl_images.transpose(0, 1, 4, 2, 3),
            tl_first_lengths,
            tl_cum_rewards,
        )

        if config["log_learnability_before_after"]:

            def single(x, name):
                return {
                    f"{name}_mean": x.mean(),
                    f"{name}_std": x.std(),
                    f"{name}_min": x.min(),
                    f"{name}_max": x.max(),
                    f"{name}_median": jnp.median(x),
                }

            test_metrics["learnability_log_v2/"] = {
                **single(learn_scores_before, "learnability_before"),
                **single(learn_scores_after, "learnability_after"),
                **single(success_score_before, "success_score_before"),
                **single(success_score_after, "success_score_after"),
            }

        # Slice the scores with the same window as `top_instances` so that score i
        # belongs to level i (the old `[-20:]` window paired them with other levels).
        top_scores = learnability_scores[-num_top:]
        return runner_state, (top_scores, top_instances), test_metrics

    rng, _rng = jax.random.split(rng)
    runner_state = (
        train_state,
        env_state,
        start_state,
        obsv,
        jnp.zeros((config["num_train_envs"]), dtype=bool),
        init_hstate,
        0,
        _rng,
    )

    def log_eval(stats):
        """Assemble the wandb entry for one eval round and log it.

        Mutates ``stats``:
        - pops ``"to_remove"`` (per-level arrays used only for the grouped aggregates);
        - merges in timing, aggregate and video entries before ``wandb.log``.

        Args:
            stats: Metrics dict containing ``"to_remove"``, ``"update_count"``,
                ``"time_delta"``, ``"eval_animation"`` and ``"top_learnable_animation"``.
        """
        to_remove = stats.pop("to_remove")

        def _aggregate_per_size(values, name):
            # Mean of `values` over each group of eval-level indices.
            return {f"{name}_{group_name}": values[indices].mean() for group_name, indices in eval_group_indices.items()}

        # Convert to a Python int before multiplying. JAX integers are 32-bit unless
        # x64 mode is enabled, so the env-step product could overflow on device.
        update_count = int(stats["update_count"])
        env_steps = update_count * config["num_train_envs"] * config["num_steps"]
        env_steps_delta = config["eval_freq"] * config["num_train_envs"] * config["num_steps"]
        time_now = time.time()
        log_dict = {
            "timing/num_updates": update_count,
            "timing/num_env_steps": env_steps,
            "timing/sps": env_steps_delta / stats["time_delta"],
            "timing/sps_agg": env_steps / (time_now - time_start),
        }
        log_dict.update(_aggregate_per_size(to_remove["eval_return"], "eval_aggregate/return"))
        log_dict.update(_aggregate_per_size(to_remove["eval_solve_rate"], "eval_aggregate/solve_rate"))

        # Eval-level videos: frames are indexed [step, level, ...]; trim each to its episode length.
        eval_frames, eval_lengths = stats["eval_animation"]
        eval_frames = np.asarray(eval_frames)
        for i, level_name in enumerate(config["eval_levels"]):
            episode_length = int(eval_lengths[i])
            log_dict[f"media/eval_video_{level_name}"] = wandb.Video(
                eval_frames[:episode_length, i].astype(np.uint8), fps=15, caption=f"(len {episode_length})"
            )

        # Top-learnable videos: one per level, captioned with every attempt's cumulative reward.
        tl_frames, tl_lengths, tl_rewards = stats["top_learnable_animation"]
        tl_frames = np.asarray(tl_frames)
        tl_rewards = np.asarray(tl_rewards)
        for j in range(len(tl_lengths)):
            episode_length = int(tl_lengths[j])
            rr = "|".join(f"{r:<.2f}" for r in tl_rewards[:, j])
            log_dict[f"media/tl_animation_{j}"] = wandb.Video(
                tl_frames[:episode_length, j].astype(np.uint8), fps=15, caption=f"(len {episode_length})\n{rr}"
            )

        stats.update(log_dict)
        wandb.log(stats, step=update_count)

    checkpoint_steps = config["checkpoint_save_freq"]
    assert config["num_updates"] % config["eval_freq"] == 0, "num_updates must be divisible by eval_freq"

    for eval_step in range(int(config["num_updates"] // config["eval_freq"])):
        start_time = time.time()
        rng, eval_rng = jax.random.split(rng)
        runner_state, instances, metrics = train_and_eval_step(runner_state, eval_rng)
        curr_time = time.time()
        metrics.update(log_buffer(*instances, metrics["update_count"]))
        metrics["time_delta"] = curr_time - start_time
        metrics["steps_per_section"] = (config["eval_freq"] * config["num_steps"] * config["num_train_envs"]) / metrics[
            "time_delta"
        ]
        log_eval(metrics)
        if ((eval_step + 1) * config["eval_freq"]) % checkpoint_steps == 0:
            if config["save_path"] is not None:
                steps = int(metrics["update_count"]) * int(config["num_train_envs"]) * int(config["num_steps"])
                # save_params_to_wandb(runner_state[0].params, steps, config)
                save_model_to_wandb(runner_state[0], steps, config)

    if config["save_path"] is not None:
        # save_params_to_wandb(runner_state[0].params, config["total_timesteps"], config)
        save_model_to_wandb(runner_state[0], config["total_timesteps"], config)


if __name__ == "__main__":
    # For debugging, this call can be wrapped in `with jax.disable_jit():`.
    main()