from functools import partial
import time
from enum import IntEnum
from typing import Tuple

import chex
import hydra
import jax
import jax.numpy as jnp
import numpy as np
from omegaconf import OmegaConf
import optax
from flax import core, struct
from flax.training.train_state import TrainState as BaseTrainState

import wandb
from kinetix.environment.ued.distributions import (
    create_random_starting_distribution,
)
from kinetix.environment.ued.ued import (
    make_mutate_env,
    make_reset_train_function_with_mutations,
    make_vmapped_filtered_level_sampler,
)
from kinetix.environment.ued.ued import (
    make_mutate_env,
    make_reset_train_function_with_list_of_levels,
    make_reset_train_function_with_mutations,
)
from kinetix.util.config import (
    generate_ued_params_from_config,
    get_video_frequency,
    init_wandb,
    normalise_config,
    save_data_to_local_file,
    generate_params_from_config,
    get_eval_level_groups,
)
from jaxued.environments.underspecified_env import EnvState
from jaxued.level_sampler import LevelSampler
from jaxued.utils import compute_max_returns, max_mc, positive_value_loss
from flax.serialization import to_state_dict

import sys

sys.path.append("experiments")
from kinetix.environment.env import make_kinetix_env_from_name
from kinetix.environment.env_state import StaticEnvParams
from kinetix.environment.wrappers import (
    UnderspecifiedToGymnaxWrapper,
    LogWrapper,
    DenseRewardWrapper,
    AutoReplayWrapper,
)
from kinetix.models import make_network_from_config
from kinetix.render.renderer_pixels import make_render_pixels
from kinetix.models.actor_critic import ScannedRNN
from kinetix.util.learning import (
    general_eval,
    get_eval_levels,
    no_op_and_random_rollout,
    sample_trajectories_and_learn,
)
from kinetix.util.saving import (
    load_train_state_from_wandb_artifact_path,
    save_model_to_wandb,
)


class UpdateState(IntEnum):
    """Which kind of update a training step performed.

    Stored in ``TrainState.update_state`` and in the per-step metrics, where the logger
    uses it to split statistics into DR / REPLAY / MUTATION groups.
    """

    DR = 0  # trained on freshly sampled (randomly generated) levels
    REPLAY = 1  # trained on levels sampled from the level buffer
    MUTATE = 2  # trained on mutations of the last replay batch (ACCEL)


def get_level_complexity_metrics(all_levels: EnvState, static_env_params: StaticEnvParams):
    """Average simple structural statistics over a batch of levels.

    Each statistic is computed per level (via ``jax.vmap`` over the leading axis of
    ``all_levels``) and then averaged across the batch.

    Args:
        all_levels: A batch of levels stacked along the leading axis.
        static_env_params: Static parameters; ``num_static_fixated_polys`` leading polygon
            slots are excluded from the shape count.

    Returns:
        Dict mapping ``"complexity/..."`` keys to batch-mean scalars.
    """
    num_fixed_polys = static_env_params.num_static_fixated_polys

    def _count_active_with_role(level, role):
        # Active polygons plus active circles whose shape role equals `role`.
        poly_hits = (level.polygon_shape_roles == role) * level.polygon.active
        circle_hits = (level.circle_shape_roles == role) * level.circle.active
        return poly_hits.sum() + circle_hits.sum()

    def _stats_for_level(level):
        joint_active = level.joint.active
        fixed = level.joint.is_fixed_joint
        shape_count = level.polygon.active[num_fixed_polys:].sum() + level.circle.active.sum()
        return {
            "complexity/num_shapes": shape_count,
            "complexity/num_joints": joint_active.sum(),
            "complexity/num_thrusters": level.thruster.active.sum(),
            "complexity/num_rjoints": (joint_active * jnp.logical_not(fixed)).sum(),
            "complexity/num_fjoints": (joint_active * (fixed)).sum(),
            "complexity/has_ball": _count_active_with_role(level, 1),
            "complexity/has_goal": _count_active_with_role(level, 2),
        }

    per_level_stats = jax.vmap(_stats_for_level)(all_levels)
    return jax.tree.map(jnp.mean, per_level_stats)


def get_ued_score_metrics(all_ued_scores):
    """Summarise the three UED score arrays as mean / max / min scalars.

    Args:
        all_ued_scores: Exactly three score arrays, in the order
            ``(MaxMC, PVL, Learnability)``.

    Returns:
        Dict with ``ued_scores/<name>/Mean`` plus ``ued_scores_additional/<name>/Max``
        and ``.../Min`` for each score.
    """
    mc, pvl, learn = all_ued_scores
    labelled = (("MaxMC", mc), ("PVL", pvl), ("Learnability", learn))

    metrics = {}
    for label, values in labelled:
        metrics.update(
            {
                f"ued_scores/{label}/Mean": values.mean(),
                f"ued_scores_additional/{label}/Max": values.max(),
                f"ued_scores_additional/{label}/Min": values.min(),
            }
        )
    return metrics


class TrainState(BaseTrainState):
    """flax ``TrainState`` extended with the UED level buffer and per-branch bookkeeping.

    Every extra field is a pytree node, so it can be updated with ``replace`` inside jitted
    code (``train_step`` increments the counters and overwrites the ``*_last_*`` fields).
    """

    # Level-buffer state created by ``LevelSampler.initialize`` and updated by
    # ``insert_batch`` / ``update_batch`` / ``sample_replay_levels``.
    sampler: core.FrozenDict[str, chex.ArrayTree] = struct.field(pytree_node=True)
    # Which branch (an ``UpdateState`` value) produced the most recent update; initialised to 0 (DR).
    update_state: UpdateState = struct.field(pytree_node=True)
    # === Below is used for logging ===
    # (exception: ``replay_last_level_batch`` is also the parent batch that the MUTATE branch mutates)
    # Number of updates performed by each branch so far.
    num_dr_updates: int
    num_replay_updates: int
    num_mutation_updates: int

    # Per-level scores of the most recent batch from each branch; placeholder shape is (num_train_envs,).
    dr_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
    replay_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)
    mutation_last_level_batch_scores: chex.ArrayTree = struct.field(pytree_node=True)

    # The most recent batch of levels from each branch, stacked along the leading axis.
    dr_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
    replay_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)
    mutation_last_level_batch: chex.ArrayTree = struct.field(pytree_node=True)

    # Rollout of the first environment in the most recent batch from each branch,
    # stored as a tuple (per-step env states, per-step dones).
    dr_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
    replay_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)
    mutation_last_rollout_batch: chex.ArrayTree = struct.field(pytree_node=True)


# region PPO helper functions

# endregion


def train_state_to_log_dict(train_state: TrainState, level_sampler: LevelSampler) -> dict:
    """Summarise the parts of ``train_state`` that are needed for logging.

    Returning a handful of scalars, rather than the whole (large) ``train_state``, avoids
    copying the full state to the host every time we log.

    Level-buffer statistics only consider the filled slots, i.e. the first
    ``sampler["size"]`` entries of ``sampler["scores"]``. Unfilled slots are excluded with
    ``jnp.where`` rather than by multiplying with a 0/1 mask. A non-finite placeholder
    score multiplied by 0 is NaN and would poison the sums.

    Args:
        train_state: Current training state; its ``sampler`` dict and update counters are read.
        level_sampler: The ``LevelSampler`` that owns ``train_state.sampler``; provides
            ``capacity`` and ``level_weights``.

    Returns:
        A dict with two entries:
            ``"log"``: scalar buffer statistics, logged to wandb as-is.
            ``"info"``: the number of DR / replay / mutation updates performed so far.
    """
    sampler = train_state.sampler
    scores = sampler["scores"]
    filled = jnp.arange(level_sampler.capacity) < sampler["size"]
    # Avoid dividing by zero while the buffer is still empty.
    num_filled = jnp.maximum(filled.sum(), 1)
    filled_scores = jnp.where(filled, scores, 0.0)
    return {
        "log": {
            "level_sampler/size": sampler["size"],
            "level_sampler/episode_count": sampler["episode_count"],
            # -inf while the buffer is still empty.
            "level_sampler/max_score": jnp.where(filled, scores, -jnp.inf).max(),
            "level_sampler/weighted_score": (filled_scores * level_sampler.level_weights(sampler)).sum(),
            "level_sampler/mean_score": filled_scores.sum() / num_filled,
        },
        "info": {
            "num_dr_updates": train_state.num_dr_updates,
            "num_replay_updates": train_state.num_replay_updates,
            "num_mutation_updates": train_state.num_mutation_updates,
        },
    }


def compute_learnability(config, done, reward, info, num_envs, max_episodes: int = 50):
    """Compute the learnability score ``p * (1 - p)`` for each environment.

    ``p`` is the fraction of *completed* episodes in the rollout that reached the goal.
    Success of an episode is ``info["GoalR"]`` summed over its steps. An environment with no
    completed episode gets ``p = 0``, and therefore learnability 0, rather than NaN. This
    matches the ``num_success / max(num_episodes, 1)`` convention used in ``compute_all_scores``.

    Args:
        config: Must contain ``"num_steps"`` and ``"outer_rollout_steps"``; their product
            is the rollout length.
        done: Episode-termination flags indexed ``[time, env]``.
        reward: Per-step rewards indexed ``[time, env]``.
        info: Dict of per-step arrays indexed ``[time, env]``; must contain ``"GoalR"``.
        num_envs: Number of parallel environments.
        max_episodes: Maximum number of completed episodes counted per environment; later
            episodes are ignored. Defaults to 50, the previously hard-coded cap.

    Returns:
        ``(learnability, num_episodes, num_success)`` with shapes ``(num_envs,)``,
        ``(num_envs,)`` and ``(num_envs, 1)``. The last is transposed so the agent axis
        comes last.
    """
    num_agents = 1
    rollout_length = config["num_steps"] * config["outer_rollout_steps"]

    @partial(jax.vmap, in_axes=(None, 1, 1, 1))
    @partial(jax.jit, static_argnums=(0,))
    def _calc_outcomes_by_agent(max_steps: int, dones, returns, info):
        idxs = jnp.arange(max_steps)

        @partial(jax.vmap, in_axes=(0, 0))
        def _ep_outcomes(start_idx, end_idx):
            # Steps of the episode that ends at `end_idx`; padding entries
            # (end_idx == max_steps) select no steps at all.
            mask = (idxs > start_idx) & (idxs <= end_idx) & (end_idx != max_steps)
            ep_return = jnp.sum(returns * mask)
            success = jnp.sum(info["GoalR"] * mask)
            return ep_return, success, end_idx - start_idx

        # Indices of the first `max_episodes` terminal steps, padded with `max_steps`.
        # Take column 0 explicitly; `.squeeze()` would collapse to a scalar when max_episodes == 1.
        done_idxs = jnp.argwhere(dones, size=max_episodes, fill_value=max_steps)[:, 0]
        start_idxs = jnp.concatenate([jnp.array([-1]), done_idxs[:-1]])
        ep_return, success, length = _ep_outcomes(start_idxs, done_idxs)

        mask_done = done_idxs != max_steps
        num_episodes = mask_done.sum()
        # Guarded denominator: no completed episodes -> rates of 0 instead of NaN.
        denom = jnp.maximum(num_episodes, 1)
        num_success = success.sum(where=mask_done)
        return {
            "ep_return": ep_return.sum(where=mask_done) / denom,
            "num_episodes": num_episodes,
            "num_success": num_success,
            "success_rate": num_success / denom,
            "ep_len": length.sum(where=mask_done) / denom,
        }

    outcomes = _calc_outcomes_by_agent(rollout_length, done, reward, info)
    success_by_env = outcomes["success_rate"].reshape((num_agents, num_envs))
    learnability_by_env = (success_by_env * (1 - success_by_env)).sum(axis=0)

    return (
        learnability_by_env,
        outcomes["num_episodes"].reshape(num_agents, num_envs).sum(axis=0),
        outcomes["num_success"].reshape(num_agents, num_envs).T,  # so agents is at the end.
    )


def compute_score(
    config: dict, dones: chex.Array, values: chex.Array, max_returns: chex.Array, reward, info, advantages: chex.Array
) -> chex.Array:
    """Return the per-level score selected by ``config["score_function"]``.

    Supported values are ``"MaxMC"``, ``"pvl"`` and ``"learnability"``; anything else
    raises ``ValueError``. Only the selected score is computed.
    """
    chosen = config["score_function"]

    if chosen == "MaxMC":
        return max_mc(dones, values, max_returns)

    if chosen == "pvl":
        return positive_value_loss(dones, advantages)

    if chosen == "learnability":
        # compute_learnability also returns episode / success counts, unused here.
        per_level_learnability, _, _ = compute_learnability(
            config, dones, reward, info, config["num_train_envs"]
        )
        return per_level_learnability

    raise ValueError(f"Unknown score function: {config['score_function']}")


def compute_all_scores(
    config: dict,
    dones: chex.Array,
    values: chex.Array,
    max_returns: chex.Array,
    reward,
    info,
    advantages: chex.Array,
    return_success_rate=False,
):
    """Compute every UED score and pick the one named by ``config["score_function"]``.

    All three scores (MaxMC, PVL, learnability) are always computed so that they can be
    logged, regardless of which one drives the level buffer.

    Returns:
        ``(main_score, (mc, pvl, learnability))``. When ``return_success_rate`` is true, the
        inner tuple has a fourth element: per-level success rate, i.e. successes divided by
        ``max(completed episodes, 1)``.

    Raises:
        ValueError: if ``config["score_function"]`` is not ``"MaxMC"``, ``"pvl"`` or ``"learnability"``.
    """
    mc = max_mc(dones, values, max_returns)
    pvl = positive_value_loss(dones, advantages)
    learnability, num_episodes, num_success = compute_learnability(
        config, dones, reward, info, config["num_train_envs"]
    )

    for score_name, candidate in (("MaxMC", mc), ("pvl", pvl), ("learnability", learnability)):
        if config["score_function"] == score_name:
            main_score = candidate
            break
    else:
        raise ValueError(f"Unknown score function: {config['score_function']}")

    all_scores = (mc, pvl, learnability)
    if not return_success_rate:
        return main_score, all_scores

    success_rate = num_success.squeeze(1) / jnp.maximum(num_episodes, 1)
    return main_score, all_scores + (success_rate,)


@hydra.main(version_base=None, config_path="../configs", config_name="plr")
def main(config=None):
    my_name = "PLR"
    config = OmegaConf.to_container(config)
    if config["ued"]["replay_prob"] == 0.0:
        my_name = "DR"
    elif config["ued"]["use_accel"]:
        my_name = "ACCEL"

    time_start = time.time()
    config = normalise_config(config, my_name)
    env_params, static_env_params = generate_params_from_config(config)
    config["env_params"] = to_state_dict(env_params)
    config["static_env_params"] = to_state_dict(static_env_params)

    run = init_wandb(config, my_name)
    config = wandb.config
    time_prev = time.time()

    def log_eval(stats, train_state_info):
        nonlocal time_prev
        print(f"Logging update: {stats['update_count']}")
        total_loss = jnp.mean(stats["losses"][0])
        if jnp.isnan(total_loss):
            print("NaN loss, skipping logging")
            raise ValueError("NaN loss")

        # generic stats
        env_steps = int(
            int(stats["update_count"]) * config["num_train_envs"] * config["num_steps"] * config["outer_rollout_steps"]
        )
        env_steps_delta = (
            config["eval_freq"] * config["num_train_envs"] * config["num_steps"] * config["outer_rollout_steps"]
        )
        time_now = time.time()
        log_dict = {
            "timing/num_updates": stats["update_count"],
            "timing/num_env_steps": env_steps,
            "timing/sps": env_steps_delta / (time_now - time_prev),
            "timing/sps_agg": env_steps / (time_now - time_start),
            "loss/total_loss": jnp.mean(stats["losses"][0]),
            "loss/value_loss": jnp.mean(stats["losses"][1][0]),
            "loss/policy_loss": jnp.mean(stats["losses"][1][1]),
            "loss/entropy_loss": jnp.mean(stats["losses"][1][2]),
        }
        time_prev = time_now

        # evaluation performance

        returns = stats["eval_returns"]
        log_dict.update({"eval/mean_eval_return": returns.mean()})
        log_dict.update({"eval/mean_eval_learnability": stats["eval_learn"].mean()})
        log_dict.update({"eval/mean_eval_solve_rate": stats["eval_solves"].mean()})
        log_dict.update({"eval/mean_eval_eplen": stats["eval_ep_lengths"].mean()})
        for i in range(config["num_eval_levels"]):
            log_dict[f"eval_avg_return/{config['eval_levels'][i]}"] = returns[i]
            log_dict[f"eval_avg_learnability/{config['eval_levels'][i]}"] = stats["eval_learn"][i]
            log_dict[f"eval_avg_solve_rate/{config['eval_levels'][i]}"] = stats["eval_solves"][i]
            log_dict[f"eval_avg_episode_length/{config['eval_levels'][i]}"] = stats["eval_ep_lengths"][i]
            log_dict[f"eval_get_max_eplen/{config['eval_levels'][i]}"] = stats["eval_get_max_eplen"][i]
            log_dict[f"episode_return_bigger_than_negative/{config['eval_levels'][i]}"] = stats[
                "episode_return_bigger_than_negative"
            ][i]

        def _aggregate_per_size(values, name):
            to_return = {}
            for group_name, indices in eval_group_indices.items():
                to_return[f"{name}_{group_name}"] = values[indices].mean()
            return to_return

        log_dict.update(_aggregate_per_size(returns, "eval_aggregate/return"))
        log_dict.update(_aggregate_per_size(stats["eval_solves"], "eval_aggregate/solve_rate"))

        if config["EVAL_ON_SAMPLED"]:
            log_dict.update({"eval/mean_eval_return_sampled": stats["eval_dr_returns"].mean()})
            log_dict.update({"eval/mean_eval_solve_rate_sampled": stats["eval_dr_solve_rates"].mean()})
            log_dict.update({"eval/mean_eval_eplen_sampled": stats["eval_dr_eplen"].mean()})

        # level sampler
        log_dict.update(train_state_info["log"])

        # images
        log_dict.update(
            {
                "images/highest_scoring_level": wandb.Image(
                    np.array(stats["highest_scoring_level"]), caption="Highest scoring level"
                )
            }
        )
        log_dict.update(
            {
                "images/highest_weighted_level": wandb.Image(
                    np.array(stats["highest_weighted_level"]), caption="Highest weighted level"
                )
            }
        )

        for s in ["dr", "replay", "mutation"]:
            if train_state_info["info"][f"num_{s}_updates"] > 0:
                log_dict.update(
                    {
                        f"images/{s}_levels": [
                            wandb.Image(np.array(image), caption=f"{score}")
                            for image, score in zip(stats[f"{s}_levels"], stats[f"{s}_scores"])
                        ]
                    }
                )
                if stats["log_videos"]:
                    # animations
                    rollout_ep = stats[f"{s}_ep_len"]
                    arr = np.array(stats[f"{s}_rollout"][:rollout_ep])
                    log_dict.update(
                        {
                            f"media/{s}_eval": wandb.Video(
                                arr.astype(np.uint8), fps=15, caption=f"{s.capitalize()} (len {rollout_ep})"
                            )
                        }
                    )
                #  * 255

        # DR, Replay and Mutate Returns
        dr_inds = (stats["update_state"] == UpdateState.DR).nonzero()[0]
        rep_inds = (stats["update_state"] == UpdateState.REPLAY).nonzero()[0]
        mut_inds = (stats["update_state"] == UpdateState.MUTATE).nonzero()[0]

        for name, inds in [
            ("DR", dr_inds),
            ("REPLAY", rep_inds),
            ("MUTATION", mut_inds),
        ]:
            if len(inds) > 0:
                log_dict.update(
                    {
                        f"{name}/episode_return": stats["episode_return"][inds].mean(),
                        f"{name}/mean_eplen": stats["returned_episode_lengths"][inds].mean(),
                        f"{name}/mean_success": stats["returned_episode_solved"][inds].mean(),
                        f"{name}/noop_return": stats["noop_returns"][inds].mean(),
                        f"{name}/noop_eplen": stats["noop_eplen"][inds].mean(),
                        f"{name}/noop_success": stats["noop_success"][inds].mean(),
                        f"{name}/random_return": stats["random_returns"][inds].mean(),
                        f"{name}/random_eplen": stats["random_eplen"][inds].mean(),
                        f"{name}/random_success": stats["random_success"][inds].mean(),
                    }
                )
                for k in stats:
                    if "complexity/" in k:
                        k2 = "complexity/" + name + "_" + k.replace("complexity/", "")
                        log_dict.update({k2: stats[k][inds].mean()})
                    if "ued_scores/" in k:
                        k2 = "ued_scores/" + name + "_" + k.replace("ued_scores/", "")
                        log_dict.update({k2: stats[k][inds].mean()})

        # Eval rollout animations
        if stats["log_videos"]:
            for i in range((config["num_eval_levels"])):
                frames, episode_length = stats["eval_animation"][0][:, i], stats["eval_animation"][1][i]
                frames = np.array(frames[:episode_length])
                log_dict.update(
                    {
                        f"media/eval_video_{config['eval_levels'][i]}": wandb.Video(
                            frames.astype(np.uint8), fps=15, caption=f"Len ({episode_length})"
                        )
                    }
                )

        wandb.log(log_dict)

    def get_all_metrics(
        rng,
        losses,
        info,
        init_env_state,
        init_obs,
        dones,
        grads,
        all_ued_scores,
        new_levels,
    ):
        noop_returns, noop_len, noop_success, random_returns, random_lens, random_success = no_op_and_random_rollout(
            env,
            env_params,
            rng,
            init_obs,
            init_env_state,
            config["num_train_envs"],
            config["num_steps"] * config["outer_rollout_steps"],
        )
        metrics = (
            {
                "losses": jax.tree_util.tree_map(lambda x: x.mean(), losses),
                "returned_episode_lengths": (info["returned_episode_lengths"] * dones).sum()
                / jnp.maximum(1, dones.sum()),
                "max_episode_length": info["returned_episode_lengths"].max(),
                "levels_played": init_env_state.env_state.env_state,
                "episode_return": (info["returned_episode_returns"] * dones).sum() / jnp.maximum(1, dones.sum()),
                "episode_return_v2": (info["returned_episode_returns"] * info["returned_episode"]).sum()
                / jnp.maximum(1, info["returned_episode"].sum()),
                "grad_norms": grads.mean(),
                "noop_returns": noop_returns,
                "noop_eplen": noop_len,
                "noop_success": noop_success,
                "random_returns": random_returns,
                "random_eplen": random_lens,
                "random_success": random_success,
                "returned_episode_solved": (info["returned_episode_solved"] * dones).sum()
                / jnp.maximum(1, dones.sum()),
            }
            | get_level_complexity_metrics(new_levels, static_env_params)
            | get_ued_score_metrics(all_ued_scores)
        )
        return metrics

    # Setup the environment.
    def make_env(static_env_params):
        env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
        env = AutoReplayWrapper(env)
        env = UnderspecifiedToGymnaxWrapper(env)
        env = DenseRewardWrapper(env, dense_reward_scale=config["dense_reward_scale"])
        env = LogWrapper(env)
        return env

    env = make_env(static_env_params)

    if config["train_level_mode"] == "list":
        sample_random_level = make_reset_train_function_with_list_of_levels(
            config, config["train_levels_list"], static_env_params, make_pcg_state=False, is_loading_train_levels=True
        )
    elif config["train_level_mode"] == "random":
        sample_random_level = make_reset_train_function_with_mutations(
            env.physics_engine, env_params, static_env_params, config, make_pcg_state=False
        )
    else:
        raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")

    if config["use_accel"] and config["accel_start_from_empty"]:

        def make_sample_random_level():
            def inner(rng):
                def _inner_accel(rng):
                    return create_random_starting_distribution(
                        rng, env_params, static_env_params, ued_params, config["env_size_name"], controllable=True
                    )

                def _inner_accel_not_controllable(rng):
                    return create_random_starting_distribution(
                        rng, env_params, static_env_params, ued_params, config["env_size_name"], controllable=False
                    )

                rng, _rng = jax.random.split(rng)
                return _inner_accel(_rng)

            return inner

        sample_random_level = make_sample_random_level()

    sample_random_levels = make_vmapped_filtered_level_sampler(
        sample_random_level, env_params, static_env_params, config, make_pcg_state=False, env=env
    )

    def generate_world():
        raise NotImplementedError
        pass

    def generate_eval_world(rng, env_params, static_env_params, level_idx):
        # jax.random.split(jax.random.PRNGKey(101), num_levels), env_params, static_env_params, jnp.arange(num_levels)

        raise NotImplementedError

    _, eval_static_env_params = generate_params_from_config(
        config["eval_env_size_true"] | {"frame_skip": config["frame_skip"]}
    )
    eval_env = make_env(eval_static_env_params)
    ued_params = generate_ued_params_from_config(config)

    mutate_world = make_mutate_env(static_env_params, env_params, ued_params)

    def make_render_fn(static_env_params):
        render_fn_inner = make_render_pixels(env_params, static_env_params)
        render_fn = lambda x: render_fn_inner(x).transpose(1, 0, 2)[::-1]
        return render_fn

    render_fn = make_render_fn(static_env_params)
    render_fn_eval = make_render_fn(eval_static_env_params)
    if config["EVAL_ON_SAMPLED"]:
        NUM_EVAL_DR_LEVELS = 200
        key_to_sample_dr_eval_set = jax.random.PRNGKey(100)
        DR_EVAL_LEVELS = sample_random_levels(key_to_sample_dr_eval_set, NUM_EVAL_DR_LEVELS)

    # And the level sampler
    level_sampler = LevelSampler(
        capacity=config["level_buffer_capacity"],
        replay_prob=config["replay_prob"],
        staleness_coeff=config["staleness_coeff"],
        minimum_fill_ratio=config["minimum_fill_ratio"],
        prioritization=config["prioritization"],
        prioritization_params={"temperature": config["temperature"], "k": config["topk_k"]},
        duplicate_check=config["buffer_duplicate_check"],
    )

    @jax.jit
    def create_train_state(rng) -> TrainState:
        # Creates the train state
        def linear_schedule(count):
            frac = 1.0 - (count // (config["num_minibatches"] * config["update_epochs"])) / (
                config["num_updates"] * config["outer_rollout_steps"]
            )
            return config["lr"] * frac

        rng, _rng = jax.random.split(rng)
        init_state = jax.tree.map(lambda x: x[0], sample_random_levels(_rng, 1))

        rng, _rng = jax.random.split(rng)
        obs, _ = env.reset_to_level(_rng, init_state, env_params)
        ns = config["num_steps"] * config["outer_rollout_steps"]
        obs = jax.tree.map(
            lambda x: jnp.repeat(jnp.repeat(x[None, ...], config["num_train_envs"], axis=0)[None, ...], ns, axis=0),
            obs,
        )
        init_x = (obs, jnp.zeros((ns, config["num_train_envs"]), dtype=jnp.bool_))
        network = make_network_from_config(env, env_params, config)
        rng, _rng = jax.random.split(rng)
        network_params = network.init(_rng, ScannedRNN.initialize_carry(config["num_train_envs"]), init_x)

        if config["anneal_lr"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adam(config["lr"], eps=1e-5),
            )

        pholder_level = jax.tree.map(lambda x: x[0], sample_random_levels(jax.random.PRNGKey(0), 1))
        sampler = level_sampler.initialize(pholder_level, {"max_return": -jnp.inf})
        pholder_level_batch = jax.tree_util.tree_map(
            lambda x: jnp.array([x]).repeat(config["num_train_envs"], axis=0), pholder_level
        )
        pholder_rollout_batch = (
            jax.tree.map(
                lambda x: jnp.repeat(
                    jnp.expand_dims(x, 0), repeats=config["num_steps"] * config["outer_rollout_steps"], axis=0
                ),
                init_state,
            ),
            init_x[1][:, 0],
        )

        pholder_level_batch_scores = jnp.zeros((config["num_train_envs"],), dtype=jnp.float32)
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
            sampler=sampler,
            update_state=0,
            num_dr_updates=0,
            num_replay_updates=0,
            num_mutation_updates=0,
            dr_last_level_batch_scores=pholder_level_batch_scores,
            replay_last_level_batch_scores=pholder_level_batch_scores,
            mutation_last_level_batch_scores=pholder_level_batch_scores,
            dr_last_level_batch=pholder_level_batch,
            replay_last_level_batch=pholder_level_batch,
            mutation_last_level_batch=pholder_level_batch,
            dr_last_rollout_batch=pholder_rollout_batch,
            replay_last_rollout_batch=pholder_rollout_batch,
            mutation_last_rollout_batch=pholder_rollout_batch,
        )

        if config["load_from_checkpoint"] != None:
            print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
            train_state = load_train_state_from_wandb_artifact_path(
                train_state,
                config["load_from_checkpoint"],
                load_only_params=config["load_only_params"],
                legacy=config["load_legacy_checkpoint"],
            )
        return train_state

    all_eval_levels = get_eval_levels(config["eval_levels"], eval_env.static_env_params)
    eval_group_indices = get_eval_level_groups(config["eval_levels"])

    @jax.jit
    def train_step(carry: Tuple[chex.PRNGKey, TrainState], _):
        """
        This is the main training loop. It basically calls either `on_new_levels`, `on_replay_levels`, or `on_mutate_levels` at every step.
        """

        def on_new_levels(rng: chex.PRNGKey, train_state: TrainState):
            """
            Samples new (randomly-generated) levels and evaluates the policy on these. It also then adds the levels to the level buffer if they have high-enough scores.
            The agent is updated on these trajectories iff `config["exploratory_grad_updates"]` is True.
            """
            sampler = train_state.sampler

            # Reset
            rng, rng_levels, rng_reset = jax.random.split(rng, 3)
            new_levels = sample_random_levels(rng_levels, config["num_train_envs"])
            init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
                jax.random.split(rng_reset, config["num_train_envs"]), new_levels, env_params
            )
            init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
            # Rollout
            (
                (rng, train_state, new_hstate, last_obs, last_env_state),
                (
                    obs,
                    actions,
                    rewards,
                    dones,
                    log_probs,
                    values,
                    info,
                    advantages,
                    targets,
                    losses,
                    grads,
                    rollout_states,
                ),
            ) = sample_trajectories_and_learn(
                env,
                env_params,
                config,
                rng,
                train_state,
                init_hstate,
                init_obs,
                init_env_state,
                update_grad=config["exploratory_grad_updates"],
                return_states=True,
            )
            max_returns = compute_max_returns(dones, rewards)
            scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
            sampler, _ = level_sampler.insert_batch(sampler, new_levels, scores, {"max_return": max_returns})
            rng, _rng = jax.random.split(rng)
            metrics = {
                "update_state": UpdateState.DR,
            } | get_all_metrics(_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, new_levels)

            train_state = train_state.replace(
                sampler=sampler,
                update_state=UpdateState.DR,
                num_dr_updates=train_state.num_dr_updates + 1,
                dr_last_level_batch=new_levels,
                dr_last_level_batch_scores=scores,
                dr_last_rollout_batch=jax.tree.map(
                    lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
                ),
            )
            return (rng, train_state), metrics

        def on_replay_levels(rng: chex.PRNGKey, train_state: TrainState):
            """
            This samples levels from the level buffer, and updates the policy on them.
            """
            sampler = train_state.sampler

            # Collect trajectories on replay levels
            rng, rng_levels, rng_reset = jax.random.split(rng, 3)
            sampler, (level_inds, levels) = level_sampler.sample_replay_levels(
                sampler, rng_levels, config["num_train_envs"]
            )
            init_obs, init_env_state = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))(
                jax.random.split(rng_reset, config["num_train_envs"]), levels, env_params
            )
            init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
            (
                (rng, train_state, new_hstate, last_obs, last_env_state),
                (
                    obs,
                    actions,
                    rewards,
                    dones,
                    log_probs,
                    values,
                    info,
                    advantages,
                    targets,
                    losses,
                    grads,
                    rollout_states,
                ),
            ) = sample_trajectories_and_learn(
                env,
                env_params,
                config,
                rng,
                train_state,
                init_hstate,
                init_obs,
                init_env_state,
                update_grad=True,
                return_states=True,
            )

            max_returns = jnp.maximum(
                level_sampler.get_levels_extra(sampler, level_inds)["max_return"], compute_max_returns(dones, rewards)
            )
            scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
            sampler = level_sampler.update_batch(sampler, level_inds, scores, {"max_return": max_returns})

            rng, _rng = jax.random.split(rng)
            metrics = {
                "update_state": UpdateState.REPLAY,
            } | get_all_metrics(_rng, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, levels)
            train_state = train_state.replace(
                sampler=sampler,
                update_state=UpdateState.REPLAY,
                num_replay_updates=train_state.num_replay_updates + 1,
                replay_last_level_batch=levels,
                replay_last_level_batch_scores=scores,
                replay_last_rollout_batch=jax.tree.map(
                    lambda x: x[:, 0], (rollout_states.env_state.env_state.env_state, dones)
                ),
            )
            return (rng, train_state), metrics

        def on_mutate_levels(rng: chex.PRNGKey, train_state: TrainState):
            """
            Mutation branch of `train_step` (only registered when `config["use_accel"]` is set).

            Edits every level of the last replay batch with `mutate_world` (using `config["num_edits"]`), rolls the
            policy out on the resulting child levels, scores them with `compute_all_scores`, and potentially adds
            them to the level buffer via `level_sampler.insert_batch`. Policy gradients are applied only if
            `config["exploratory_grad_updates"]` is truthy.

            Returns:
                `((rng, train_state), metrics)`. The new train state has `update_state` set back to
                `UpdateState.DR`, `num_mutation_updates` incremented, and the children, their scores and one
                rollout stored in the `mutation_last_*` fields.
            """
            num_envs = config["num_train_envs"]
            # Snapshot the buffer before the rollout; the children are inserted into this snapshot.
            sampler = train_state.sampler
            rng, mutate_key, reset_key = jax.random.split(rng, 3)

            # Build the children from the previous replay batch.
            parent_levels = train_state.replay_last_level_batch
            mutate_world_batched = jax.vmap(mutate_world, (0, 0, None))
            child_levels = mutate_world_batched(
                jax.random.split(mutate_key, num_envs), parent_levels, config["num_edits"]
            )
            reset_batched = jax.vmap(env.reset_to_level, in_axes=(0, 0, None))
            init_obs, init_env_state = reset_batched(jax.random.split(reset_key, num_envs), child_levels, env_params)

            # Roll out on the children; gradients are only applied for exploratory updates.
            carry_out, trajectory = sample_trajectories_and_learn(
                env,
                env_params,
                config,
                rng,
                train_state,
                ScannedRNN.initialize_carry(num_envs),
                init_obs,
                init_env_state,
                update_grad=config["exploratory_grad_updates"],
                return_states=True,
            )
            rng, train_state, _, _, _ = carry_out
            (_, _, rewards, dones, _, values, info, advantages, _, losses, grads, rollout_states) = trajectory

            # Score the children and offer them to the buffer.
            max_returns = compute_max_returns(dones, rewards)
            scores, all_ued_scores = compute_all_scores(config, dones, values, max_returns, rewards, info, advantages)
            sampler, _ = level_sampler.insert_batch(sampler, child_levels, scores, {"max_return": max_returns})

            rng, metrics_key = jax.random.split(rng)
            metrics = {
                "update_state": UpdateState.MUTATE,
                **get_all_metrics(
                    metrics_key, losses, info, init_env_state, init_obs, dones, grads, all_ued_scores, child_levels
                ),
            }

            # Keep index 0 along axis 1 (presumably the env axis) of the states and dones for video logging.
            logged_rollout = jax.tree.map(
                lambda leaf: leaf[:, 0], (rollout_states.env_state.env_state.env_state, dones)
            )
            train_state = train_state.replace(
                sampler=sampler,
                update_state=UpdateState.DR,
                num_mutation_updates=train_state.num_mutation_updates + 1,
                mutation_last_level_batch=child_levels,
                mutation_last_level_batch_scores=scores,
                mutation_last_rollout_batch=logged_rollout,
            )
            return (rng, train_state), metrics

        rng, train_state = carry
        rng, rng_replay = jax.random.split(rng)

        # The train step makes a decision on which branch to take, either on_new, on_replay or on_mutate.
        # on_mutate is only called if the replay branch has been taken before (as it uses `train_state.update_state`).
        branches = [
            on_new_levels,
            on_replay_levels,
        ]
        if config["use_accel"]:
            s = train_state.update_state
            branch = (1 - s) * level_sampler.sample_replay_decision(train_state.sampler, rng_replay) + 2 * s
            branches.append(on_mutate_levels)
        else:
            branch = level_sampler.sample_replay_decision(train_state.sampler, rng_replay).astype(int)

        return jax.lax.switch(branch, branches, rng, train_state)

    @partial(jax.jit, static_argnums=(2,))
    def eval(rng: chex.PRNGKey, train_state: TrainState, keep_states=True):
        """
        Evaluate the current policy on the evaluation levels specified by config["eval_levels"].

        Runs `general_eval` with `eval_env` on all `config["num_eval_levels"]` levels in `all_eval_levels`,
        for up to `env_params.max_timesteps` steps, with `return_trajectories=True`. `keep_states` is a static
        (jit argnum 2) flag forwarded to `general_eval`.

        NOTE: this name shadows the builtin `eval` in the enclosing scope; it is kept because callers use it.

        Returns:
            `((states, cum_rewards, done_idx, episode_lengths, infos), (dones, rewards))`. This is how `_get_eval`
            in `train_and_eval_step` unpacks it after vmapping over attempts.
            Presumably `states` is (num_steps, num_eval_levels, ...) and `cum_rewards` / `episode_lengths` are
            (num_eval_levels,), with `dones` / `rewards` time-major. TODO: confirm against `general_eval`.
        """
        num_levels = config["num_eval_levels"]
        return general_eval(
            rng,
            eval_env,
            env_params,
            train_state,
            all_eval_levels,
            env_params.max_timesteps,
            num_levels,
            keep_states=keep_states,
            return_trajectories=True,
        )

    @partial(jax.jit, static_argnums=(2,))
    def eval_on_dr_levels(rng: chex.PRNGKey, train_state: TrainState, keep_states=False):
        """
        Evaluate the current policy on the `DR_EVAL_LEVELS` level batch.

        Thin wrapper around `general_eval` using the training `env` (not `eval_env`). It covers
        `NUM_EVAL_DR_LEVELS` levels for up to `env_params.max_timesteps` steps. `keep_states` is static
        (jit argnum 2) and is forwarded unchanged.
        """
        horizon = env_params.max_timesteps
        eval_args = (rng, env, env_params, train_state, DR_EVAL_LEVELS, horizon, NUM_EVAL_DR_LEVELS)
        return general_eval(*eval_args, keep_states=keep_states)

    @jax.jit
    def train_and_eval_step(runner_state, _):
        """
        Run `config["eval_freq"]` iterations of `train_step`, then evaluate and log the policy.

        Args:
            runner_state: the `(rng, train_state)` scan carry.
            _: unused scan input.

        Returns:
            `((rng, train_state), {"update_count": update_count})`, where `update_count` is the total number of
            DR + replay + mutation updates so far. Every other metric is handed to `log_eval` through
            `jax.debug.callback` instead of being returned.
        """
        # Train
        (rng, train_state), metrics = jax.lax.scan(train_step, runner_state, None, config["eval_freq"])

        # Eval
        metrics["update_count"] = (
            train_state.num_dr_updates + train_state.num_replay_updates + train_state.num_mutation_updates
        )

        vid_frequency = get_video_frequency(config, metrics["update_count"])
        should_log_videos = metrics["update_count"] % vid_frequency == 0

        def _compute_eval_learnability(dones, rewards, infos):
            """Return p * (1 - p), where p is the success rate with counts pooled over the vmapped attempt axis."""

            @jax.vmap
            def _single(d, r, i):
                _, num_eps, num_succ = compute_learnability(config, d, r, i, config["num_eval_levels"])
                return num_eps, num_succ.squeeze(-1)

            num_eps, num_succ = _single(dones, rewards, infos)
            num_eps, num_succ = num_eps.sum(axis=0), num_succ.sum(axis=0)
            success_rate = num_succ / jnp.maximum(1, num_eps)

            return success_rate * (1 - success_rate)

        @jax.jit
        def _get_eval(rng):
            """Evaluate on the eval levels (and optionally the DR eval levels); return metrics and raw results."""
            metrics = {}
            rng, rng_eval = jax.random.split(rng)
            (states, cum_rewards, done_idx, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(
                eval, (0, None)
            )(jax.random.split(rng_eval, config["eval_num_attempts"]), train_state)

            # learnability here of the holdout set:
            eval_learn = _compute_eval_learnability(eval_dones, eval_rewards, eval_infos)
            # Collect Metrics
            eval_returns = cum_rewards.mean(axis=0)  # (num_eval_levels,)
            eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
                1, eval_dones.sum(axis=1)
            )
            eval_solves = eval_solves.mean(axis=0)
            metrics["eval_returns"] = eval_returns
            metrics["eval_ep_lengths"] = episode_lengths.mean(axis=0)
            metrics["eval_learn"] = eval_learn
            metrics["eval_solves"] = eval_solves

            metrics["eval_get_max_eplen"] = (episode_lengths == env_params.max_timesteps).mean(axis=0)
            metrics["episode_return_bigger_than_negative"] = (cum_rewards > -0.4).mean(axis=0)

            if config["EVAL_ON_SAMPLED"]:
                # Draw a fresh key: reusing `rng_eval` would give the DR-level rollouts exactly the same
                # per-attempt keys as the eval-level rollouts above.
                rng, rng_dr = jax.random.split(rng)
                states_dr, cum_rewards_dr, done_idx_dr, episode_lengths_dr, infos_dr = jax.vmap(
                    eval_on_dr_levels, (0, None)
                )(jax.random.split(rng_dr, config["eval_num_attempts"]), train_state)

                eval_dr_returns = cum_rewards_dr.mean(axis=0).mean()
                eval_dr_eplen = episode_lengths_dr.mean(axis=0).mean()

                my_eval_dones = infos_dr["returned_episode"]
                eval_dr_solves = (infos_dr["returned_episode_solved"] * my_eval_dones).sum(axis=1) / jnp.maximum(
                    1, my_eval_dones.sum(axis=1)
                )

                metrics["eval_dr_returns"] = eval_dr_returns
                metrics["eval_dr_eplen"] = eval_dr_eplen
                metrics["eval_dr_solve_rates"] = eval_dr_solves
            return metrics, states, episode_lengths, cum_rewards

        @jax.jit
        def _get_videos(rng, states, episode_lengths, cum_rewards):
            """Render eval and training-rollout videos (the "log videos" branch of the `lax.cond` below)."""
            metrics = {"log_videos": True}

            # Only the first eval attempt is rendered.
            states, episode_lengths = jax.tree_util.tree_map(
                lambda x: x[0], (states, episode_lengths)
            )  # (num_steps, num_eval_levels, ...), (num_eval_levels,)
            images = jax.vmap(jax.vmap(render_fn_eval))(
                states.env_state.env_state.env_state
            )  # (num_steps, num_eval_levels, ...)
            frames = images.transpose(
                0, 1, 4, 2, 3
            )  # WandB expects color channel before image dimensions when dealing with animations for some reason

            @jax.jit
            def _get_video(rollout_batch):
                states = rollout_batch[0]
                images = jax.vmap(render_fn)(states)  # dimensions are (steps, x, y, 3)
                return (
                    images.transpose(0, 3, 1, 2),
                    rollout_batch[1][:].argmax(),
                )

            # rollouts
            metrics["dr_rollout"], metrics["dr_ep_len"] = _get_video(train_state.dr_last_rollout_batch)
            metrics["replay_rollout"], metrics["replay_ep_len"] = _get_video(train_state.replay_last_rollout_batch)
            metrics["mutation_rollout"], metrics["mutation_ep_len"] = _get_video(
                train_state.mutation_last_rollout_batch
            )

            metrics["eval_animation"] = (frames, episode_lengths)

            metrics["eval_returns_video"] = cum_rewards[0]
            metrics["eval_len_video"] = episode_lengths

            return metrics

        @jax.jit
        def _get_dummy_videos(rng, states, episode_lengths, cum_rewards):
            """Zero placeholders with the same structure/shapes as `_get_videos` (the "no videos" `lax.cond` branch)."""
            n_eval = config["num_eval_levels"]
            nsteps = env_params.max_timesteps
            nsteps2 = config["outer_rollout_steps"] * config["num_steps"]
            img_size = (
                env.static_env_params.screen_dim[0] // env.static_env_params.downscale,
                env.static_env_params.screen_dim[1] // env.static_env_params.downscale,
            )
            return {
                "log_videos": False,
                "dr_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
                "dr_ep_len": jnp.zeros((), jnp.int32),
                "replay_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
                "replay_ep_len": jnp.zeros((), jnp.int32),
                "mutation_rollout": jnp.zeros((nsteps2, 3, *img_size), jnp.float32),
                "mutation_ep_len": jnp.zeros((), jnp.int32),
                "eval_animation": (
                    jnp.zeros((nsteps, n_eval, 3, *img_size), jnp.float32),
                    jnp.zeros((n_eval,), jnp.int32),
                ),
                "eval_returns_video": jnp.zeros((n_eval,), jnp.float32),
                "eval_len_video": jnp.zeros((n_eval,), jnp.int32),
            }

        rng, rng_eval, rng_vid = jax.random.split(rng, 3)

        metrics_eval, states, episode_lengths, cum_rewards = _get_eval(rng_eval)
        metrics = {
            **metrics,
            **metrics_eval,
            **jax.lax.cond(
                should_log_videos, _get_videos, _get_dummy_videos, rng_vid, states, episode_lengths, cum_rewards
            ),
        }
        max_num_images = 8

        top_regret_ones = max_num_images // 2
        bot_regret_ones = max_num_images - top_regret_ones

        @jax.jit
        def get_values(level_batch, scores):
            """Render the lowest- and highest-scoring levels of a batch and return them with their scores."""
            args = jnp.argsort(scores)  # low scores are at the start, high scores are at the end

            low_scores = args[:bot_regret_ones]
            high_scores = args[-top_regret_ones:]

            low_levels = jax.tree.map(lambda x: x[low_scores], level_batch)
            high_levels = jax.tree.map(lambda x: x[high_scores], level_batch)

            low_scores = scores[low_scores]
            high_scores = scores[high_scores]
            # now concatenate:
            return jax.vmap(render_fn)(
                jax.tree.map(lambda x, y: jnp.concatenate([x, y], axis=0), low_levels, high_levels)
            ), jnp.concatenate([low_scores, high_scores], axis=0)

        metrics["dr_levels"], metrics["dr_scores"] = get_values(
            train_state.dr_last_level_batch, train_state.dr_last_level_batch_scores
        )
        metrics["replay_levels"], metrics["replay_scores"] = get_values(
            train_state.replay_last_level_batch, train_state.replay_last_level_batch_scores
        )
        metrics["mutation_levels"], metrics["mutation_scores"] = get_values(
            train_state.mutation_last_level_batch, train_state.mutation_last_level_batch_scores
        )

        def _episode_length(rollout_dones):
            # Index of the first `done` in the logged rollout, or the rollout's full length if no step is done.
            # The logged rollout spans outer_rollout_steps * num_steps steps (cf. `nsteps2` above), so falling
            # back to config["num_steps"] under-reports; testing `any()` also stops a done at step 0 from being
            # mistaken for "never done".
            return jnp.where(rollout_dones.any(), rollout_dones.argmax(), rollout_dones.shape[0])

        metrics["dr_ep_len"] = _episode_length(train_state.dr_last_rollout_batch[1])
        metrics["replay_ep_len"] = _episode_length(train_state.replay_last_rollout_batch[1])
        metrics["mutation_ep_len"] = _episode_length(train_state.mutation_last_rollout_batch[1])

        highest_scoring_level = level_sampler.get_levels(train_state.sampler, train_state.sampler["scores"].argmax())
        highest_weighted_level = level_sampler.get_levels(
            train_state.sampler, level_sampler.level_weights(train_state.sampler).argmax()
        )

        metrics["highest_scoring_level"] = render_fn(highest_scoring_level)
        metrics["highest_weighted_level"] = render_fn(highest_weighted_level)

        # Log the post-training state so the train-state stats describe the same state as `metrics`
        # (`runner_state[1]` is the state from before this round's training scan).
        jax.debug.callback(log_eval, metrics, train_state_to_log_dict(train_state, level_sampler))
        return (rng, train_state), {"update_count": metrics["update_count"]}

    def log_checkpoint(update_count, train_state):
        """
        Host-side callback that saves `train_state` with `save_model_to_wandb`.

        Does nothing unless `config["save_path"]` is set and `config["checkpoint_save_freq"] > 1`.
        The step recorded with the checkpoint is
        `update_count * num_train_envs * num_steps * outer_rollout_steps`.
        """
        if config["save_path"] is None or not (config["checkpoint_save_freq"] > 1):
            return
        steps = int(update_count)
        for key in ("num_train_envs", "num_steps", "outer_rollout_steps"):
            steps *= int(config[key])
        save_model_to_wandb(train_state, steps, config)

    def train_eval_and_checkpoint_step(runner_state, _):
        """
        Run `checkpoint_save_freq // eval_freq` rounds of `train_and_eval_step`, then call `log_checkpoint` on the host.

        `log_checkpoint` receives the update count reported by the last round and the resulting train state.
        Returns the updated `(rng, train_state)` carry and the per-round metrics stacked by the scan.
        """
        num_rounds = config["checkpoint_save_freq"] // config["eval_freq"]
        new_runner_state, round_metrics = jax.lax.scan(
            train_and_eval_step, runner_state, xs=jnp.arange(num_rounds)
        )
        latest_update_count = round_metrics["update_count"][-1]
        jax.debug.callback(log_checkpoint, latest_update_count, new_runner_state[1])
        return new_runner_state, round_metrics

    # Set up the train states
    rng = jax.random.PRNGKey(config["seed"])
    rng_init, rng_train = jax.random.split(rng)

    train_state = create_train_state(rng_init)
    runner_state = (rng_train, train_state)

    runner_state, metrics = jax.lax.scan(
        train_eval_and_checkpoint_step,
        runner_state,
        xs=jnp.arange((config["num_updates"]) // (config["checkpoint_save_freq"])),
    )

    if config["save_path"] is not None:
        # save_params_to_wandb(runner_state[1].params, config["total_timesteps"], config)
        save_model_to_wandb(runner_state[1], config["total_timesteps"], config, is_final=True)

    return runner_state[1]


if __name__ == "__main__":
    # Script entry point: call `main` only when this file is executed directly, not when it is imported.
    main()