import copy
import datetime
import gzip
import json
import os
from collections import defaultdict
from hashlib import md5
from typing import Dict, List

import jax
import jax.numpy as jnp
import numpy as np
import wandb
from omegaconf import OmegaConf

from kinetix.environment.env_state import EnvParams, StaticEnvParams
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.util.saving import load_from_json_file


def get_hash_without_seed(config):
    """Hash the config with the seed zeroed out, so runs that differ only by seed share a hash."""
    old_seed = config["seed"]
    config["seed"] = 0
    ans = md5(OmegaConf.to_yaml(config, sort_keys=True).encode()).hexdigest()
    config["seed"] = old_seed
    return ans
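
# Example (illustrative only; the config keys below are hypothetical):
#
#   cfg_a = OmegaConf.create({"seed": 1, "lr": 3e-4})
#   cfg_b = OmegaConf.create({"seed": 2, "lr": 3e-4})
#   assert get_hash_without_seed(cfg_a) == get_hash_without_seed(cfg_b)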


def get_date() -> str:
    return datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")


def generate_params_from_config(config):
    """Build (env_params, static_env_params) from a flat config dict."""
    if config.get("env_size_type", "predefined") == "custom":
        # Custom env sizes load their (static) env params from a saved world file.
        _, static_env_params, env_params = load_from_json_file(os.path.join("worlds", config["custom_path"]))
        return env_params, static_env_params.replace(
            frame_skip=config["frame_skip"],
        )
    env_params = EnvParams()

    static_env_params = StaticEnvParams().replace(
        num_polygons=config["num_polygons"],
        num_circles=config["num_circles"],
        num_joints=config["num_joints"],
        num_thrusters=config["num_thrusters"],
        frame_skip=config["frame_skip"],
        num_motor_bindings=config["num_motor_bindings"],
        num_thruster_bindings=config["num_thruster_bindings"],
    )

    return env_params, static_env_params
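
# Example (a minimal sketch; the counts below are hypothetical, not defaults):
#
#   config = {
#       "num_polygons": 12, "num_circles": 4, "num_joints": 6, "num_thrusters": 2,
#       "frame_skip": 1, "num_motor_bindings": 4, "num_thruster_bindings": 2,
#   }
#   env_params, static_env_params = generate_params_from_config(config)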


def generate_ued_params_from_config(config) -> UEDParams:
    """Build UEDParams, applying a few config-driven overrides."""
    ans = UEDParams()

    if config["env_size_name"] == "s":
        ans = ans.replace(add_shape_n_proposals=1)  # otherwise we get a very weird XLA bug.
    if "fixate_chance_max" in config:
        print("Changing fixate chance max to", config["fixate_chance_max"])
        ans = ans.replace(fixate_chance_max=config["fixate_chance_max"])
    return ans
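
# Example (illustrative; the override value is hypothetical):
#
#   ued_params = generate_ued_params_from_config({"env_size_name": "s", "fixate_chance_max": 0.5})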


def get_eval_level_groups(eval_levels: List[str]) -> Dict[str, np.ndarray]:
    """Group eval-level indices by size prefix and by level type (e.g. handmade/random)."""

    def get_groups(s):
        # The first path component is the size group, e.g. "s" in "s/h4_name".
        group_one = s.split("/")[0]
        # The leading letters of the filename encode the level type ("h" or "r").
        group_two = s.split("/")[1].split("_")[0]
        group_two = "".join([i for i in group_two if not i.isdigit()])
        if group_two == "h":
            group_two = "handmade"
        if group_two == "r":
            group_two = "random"
        return f"{group_one}_all", f"{group_one}_{group_two}"

    indices = defaultdict(list)

    for idx, s in enumerate(eval_levels):
        groups = get_groups(s)
        for group in groups:
            indices[group].append(idx)

    return {group: np.array(idxs) for group, idxs in indices.items()}
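
# Example (illustrative paths, not real level files):
#
#   groups = get_eval_level_groups(["s/h4_thrust", "s/r1_maze", "m/h0_car"])
#   # groups["s_all"] -> array([0, 1]); groups["s_handmade"] -> array([0]);
#   # groups["s_random"] -> array([1]); groups["m_all"] and groups["m_handmade"] -> array([2])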


def normalise_config(config, name, editor_config=False):
    old_config = copy.deepcopy(config)
    keys = ["env", "learning", "model", "misc", "eval", "ued", "env_size", "train_levels"]
    for k in keys:
        if k not in config:
            config[k] = {}
        small_d = config[k]
        del config[k]
        for kk, vv in small_d.items():
            assert kk not in config, kk
            config[kk] = vv

    if not editor_config:
        config["eval_env_size_true"] = config["eval_env_size"]
        if config["num_train_envs"] == 2048 and "Pixels" in config["env_name"]:
            config["num_train_envs"] = 512
        if "SFL" in name and config["env_size_name"] in ["m", "l"]:
            config["eval_num_attempts"] = 6  # to avoid a very weird XLA bug.
        config["hash"] = get_hash_without_seed(config)

        config["random_hash"] = np.random.randint(2**31)

        config["log_save_path"] = f"logs/{config['hash']}/{config['seed']}-{get_date()}"
        os.makedirs(config["log_save_path"], exist_ok=True)
        with open(f"{config['log_save_path']}/config.yaml", "w") as f:
            f.write(OmegaConf.to_yaml(old_config))
        if config["group"] == "auto":
            config["group"] = f"{name}-" + config["group_auto_prefix"] + config["env_name"].replace("Kinetix-", "")
            config["group"] += "-" + str(config["env_size_name"])

        if config["eval_levels"] == ["auto"] or config["eval_levels"] == "auto":
            config["eval_levels"] = config["train_levels_list"]
            print("Using Auto eval levels:", config["eval_levels"])
        config["num_eval_levels"] = len(config["eval_levels"])

        steps = (
            config["num_steps"]
            * config.get("outer_rollout_steps", 1)
            * config["num_train_envs"]
            * (2 if name == "PAIRED" else 1)
        )
        config["num_updates"] = int(config["total_timesteps"]) // steps

        # Human-readable step count for the run name, e.g. "100M" or "5B".
        nsteps = int(config["total_timesteps"] // 1e6)
        letter = "M"
        if nsteps >= 1000:
            nsteps = nsteps // 1000
            letter = "B"
        config["run_name"] = (
            config["env_name"] + f"-{name}-" + str(nsteps) + letter + "-" + str(config["num_train_envs"])
        )

        if config["checkpoint_save_freq"] >= config["num_updates"]:
            config["checkpoint_save_freq"] = config["num_updates"]
    return config
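
# Example of the flattening step (a minimal sketch; real configs have many more keys,
# and the env name is illustrative):
#
#   cfg = OmegaConf.create({"seed": 0, "env": {"env_name": "Kinetix-Entity-v1"}})
#   flat = normalise_config(cfg, "DR", editor_config=True)
#   # flat is now {"seed": 0, "env_name": "Kinetix-Entity-v1"}, with the "env" group flattened away.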


def get_tags(config, name):
    # Per-method tagging (e.g. splitting PLR/ACCEL runs by `use_accel`) is currently
    # disabled; every run is tagged with just its algorithm name.
    return [name]


def init_wandb(config, name) -> "wandb.sdk.wandb_run.Run":
    run = wandb.init(
        config=config,
        project=config["wandb_project"],
        group=config["group"],
        name=config["run_name"],
        entity=config["wandb_entity"],
        mode=config["wandb_mode"],
        tags=get_tags(config, name),
    )
    wandb.define_metric("timing/num_updates")
    wandb.define_metric("timing/num_env_steps")
    wandb.define_metric("*", step_metric="timing/num_env_steps")
    wandb.define_metric("timing/sps", step_metric="timing/num_env_steps")
    return run
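
# With the metric definitions above, any subsequent wandb.log call that includes
# "timing/num_env_steps" has its other metrics plotted against env steps, e.g.
# (metric keys here are illustrative):
#
#   wandb.log({"timing/num_env_steps": 1_000_000, "eval/return": 0.5})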


def save_data_to_local_file(data_to_save, config):
    if not config.get("save_local_data", False):
        return

    def contains_any(substrings, value):
        # True if any of the given substrings occurs in `value`.
        return any(s in value for s in substrings)

    # Drop media entries (videos/images) from the locally-saved data.
    clean_data = {k: v for k, v in data_to_save.items() if not contains_any(["media/", "images/"], k)}

    def _clean(x):
        # Convert JAX/numpy values into JSON-serialisable Python types.
        if isinstance(x, jnp.ndarray):
            return x.tolist()
        elif isinstance(x, jnp.float32):
            if jnp.isnan(x):
                return -float("inf")  # sentinel for NaN values
            return round(float(x) * 1000) / 1000  # keep three decimal places
        elif isinstance(x, jnp.int32):
            return int(x)
        return x

    clean_data = jax.tree_util.tree_map(_clean, clean_data)
    print("Saving this data:", clean_data)
    with open(f"{config['log_save_path']}/data.jsonl", "a+") as f:
        f.write(json.dumps(clean_data) + "\n")
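
# Example (illustrative keys; assumes config["save_local_data"] is truthy and
# config["log_save_path"] exists; `frames` is a stand-in for a video array):
#
#   save_data_to_local_file({"eval/return": jnp.float32(1.25), "media/video": frames}, config)
#   # appends {"eval/return": 1.25} to <log_save_path>/data.jsonl; the "media/" entry is dropped.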


def compress_log_files_after_run(config):
    """Gzip the run's local JSONL log, writing `data.jsonl.gz` alongside the original."""
    fpath = f"{config['log_save_path']}/data.jsonl"
    with open(fpath, "rb") as f_in, gzip.open(fpath + ".gz", "wb") as f_out:
        f_out.writelines(f_in)
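
# The compressed log can be read back line-by-line, e.g.:
#
#   with gzip.open(f"{config['log_save_path']}/data.jsonl.gz", "rt") as f:
#       rows = [json.loads(line) for line in f]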


def get_video_frequency(config, update_step):
    # Scale the base video-logging interval by training progress: x1 in [10%, 30%)
    # of training (most frequent videos), x2 in [30%, 60%), x4 otherwise.
    frac_through_training = update_step / config["num_updates"]
    vid_frequency = (
        config["eval_freq"]
        * config["video_frequency"]
        * jax.lax.select(
            (0.1 <= frac_through_training) & (frac_through_training < 0.3),
            1,
            jax.lax.select(
                (0.3 <= frac_through_training) & (frac_through_training < 0.6),
                2,
                4,
            ),
        )
    )
    return vid_frequency
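
# Example (hypothetical values): with eval_freq=100 and video_frequency=2, videos are
# logged every 200 updates in [10%, 30%) of training, every 400 in [30%, 60%),
# and every 800 otherwise.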