# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import argparse
import json
import os
import time
import traceback
from typing import Optional

import numpy as np
from tqdm import tqdm

from datasets.encode_openx_dataset import MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES, get_shard_inds, VAL_RATIO, \
    process_dataset_step, DATA_FREQ_TABLE
from datasets.extern.ego4d import ego4d_dataset_size, ego4d_dataset_generator
from datasets.extern.egoexo4d import egoexo4d_dataset_size, egoexo4d_dataset_generator
from datasets.extern.robomimic import robomimic_dataset_generator, robomimic_dataset_size
from . import utils


# Help text for the command-line interface; parse_args() passes it to argparse as the
# parser description. Surrounding blank lines are removed by the trailing .strip().
SCRIPT_DESCRIPTION="""
Similar to encode_openx_dataset.py except for non-OpenX datasets.
Again, each split can be partitioned into multiple shards,
which is useful for parallelized encoding across GPUs.

Example usage:
    CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_extern_dataset --dataset_name egoexo4d --data_split train --num_shards 1000 --curr_shard_rank 400

Untested usage (SVD tokenizer):
CUDA_VISIBLE_DEVICES=0 python -m datasets.encode_extern_dataset --dataset_name robomimic --data_split val --no_quantization --encoder_type temporalvae --encoder_name_or_path 'stabilityai/stable-video-diffusion-img2vid'
""".strip()

# Registry of supported datasets: dataset name -> (generator, size function).
# In encode_dataset_split() the size function is called with no arguments to get the
# number of examples, and the generator is called with a range of example indices and
# iterated to obtain episodes. The keys are also the allowed values of --dataset_name.
DATASET_TO_GEN_AND_SIZE = {
    "ego4d": (ego4d_dataset_generator, ego4d_dataset_size),
    "egoexo4d": (egoexo4d_dataset_generator, egoexo4d_dataset_size),
    "robomimic": (robomimic_dataset_generator, robomimic_dataset_size),
}


def encode_dataset_split(
    extern_dataset_name: str,
    split: str,
    max_episodes: Optional[int],
    original_res: bool,
    no_quantization: bool,
    curr_shard_rank: int,
    num_shards: int,
    root_dir: str,
    encoder_type: str,
    encoder_name_or_path: str,
    dataset_postfix: str = "",
    no_encoding: bool = False,
):
    """
    Encodes (e.g. tokenizes) one shard of one split of a non-OpenX dataset and writes it to disk.
    The data written to disk can be used to load a `RawTokenDataset` (or the continuous version.)

    On success the output directory contains `video.bin`, `actions/actions.bin`,
    `segment_ids.bin` and `metadata.json`. If no step could be encoded, only an
    `error.json` with status "empty_shard" is written.

    Args:
        extern_dataset_name: key of `DATASET_TO_GEN_AND_SIZE` selecting the dataset
            (surrounding whitespace is stripped).
        split: either "train" or "val". The first `num_val_examples` examples are
            reserved for "val"; the remaining ones form "train".
        max_episodes: the maximum number of trajectories to include in the dataset
            (applied before the train/val split). `None` means no limit.
        original_res: if True, will maintain original resolution of the video rather than
            resizing it (forwarded to `process_dataset_step` as `keep_res`).
        no_quantization: if True, will not perform quantization step in image encoder.
        curr_shard_rank: the (0-indexed) shard of the split to encode.
        num_shards: the number of shards the split is partitioned into.
        root_dir: directory under which the output directory is created.
        encoder_type: type of the image encoder/tokenizer; "magvit" or "temporalvae".
        encoder_name_or_path: path or name of the image encoder, forwarded to
            `process_dataset_step` and recorded in the metadata.
        dataset_postfix: inserted into the output dirname between the encoder type and the split.
        no_encoding: forwarded to `process_dataset_step`; also adds a "_noencoding"
            suffix to the output dirname.

    Raises:
        KeyError: if `extern_dataset_name` is not a supported dataset.
        AssertionError: if the dataset has too few examples to leave any for training.
        NotImplementedError: if `split` is not "train"/"val", or if `encoder_type` is
            unsupported (the latter is only detected when writing the metadata).
    """
    extern_dataset_name = extern_dataset_name.strip()  # never modified
    suffixed_dataset_name = extern_dataset_name  # will modify later

    if original_res:
        suffixed_dataset_name = f"{suffixed_dataset_name}_originalres"
    if no_quantization:
        suffixed_dataset_name = f"{suffixed_dataset_name}_noquant"
    if no_encoding:
        suffixed_dataset_name = f"{suffixed_dataset_name}_noencoding"
    save_dirname = "_".join([suffixed_dataset_name, encoder_type, dataset_postfix, split])
    dataset_path = os.path.join(root_dir, save_dirname)
    print("=" * 25)
    print(f"{dataset_path=}")
    utils.mkdir_if_missing(dataset_path)

    # Load data
    generator, size_func = DATASET_TO_GEN_AND_SIZE[extern_dataset_name]
    num_examples = size_func()
    if max_episodes is not None:
        num_examples = min(num_examples, max_episodes)  # clip num_examples

    # We will only operate on a subset of the training examples, depending on:
    #      1) The split (train/val). Some examples are reserved for the other split.
    #      2) Sharding
    assert num_examples > MIN_VAL_EXAMPLES  # non-positive number of train examples otherwise
    num_val_examples = np.clip(int(VAL_RATIO * num_examples), MIN_VAL_EXAMPLES, MAX_VAL_EXAMPLES)

    if split == "train":  # first_ind inclusive, last_ind exclusive
        first_split_ind, last_split_ind = num_val_examples, num_examples
    elif split == "val":
        first_split_ind, last_split_ind = 0, num_val_examples
    else:
        raise NotImplementedError(f"{split=}")

    first_shard_ind, last_shard_ind = get_shard_inds(first_split_ind, last_split_ind, curr_shard_rank, num_shards)
    print(f"Total number of examples in {suffixed_dataset_name}: {num_examples}")
    print(f"Number of examples for {split=}, shard {curr_shard_rank} of {num_shards}: "
          f"{last_shard_ind - first_shard_ind}. {first_shard_ind=} {last_shard_ind=}")

    ##### Encode data #####
    traj_lens = []  # only used to print statistics
    videos = []  # NOTE: videos/actions for the entire shard are stored in RAM until the end
    actions = []
    segment_ids = []

    # Episodes are requested from the generator in fixed-size batches of indices.
    max_batch_per_loading = 10
    timeout_seconds = 86400 * 2  # 2 day timeout
    timed_out = False
    pbar = tqdm(range(first_shard_ind, last_shard_ind, max_batch_per_loading), position=0, leave=True)
    start_time = time.time()

    for start_idx in pbar:
        end_idx = min(start_idx + max_batch_per_loading, last_shard_ind)
        pbar.set_description(f"{suffixed_dataset_name} caching episodes: {start_idx}:{end_idx}")
        ds = generator(range(start_idx, end_idx))

        for chunk_idx, episode in enumerate(tqdm(ds, position=1, leave=False)):
            segment_id = start_idx + chunk_idx

            # Buffer the episode and commit it only if every step succeeds, so a failure
            # mid-episode cannot leave videos/actions/segment_ids with different lengths.
            episode_videos = []
            episode_actions = []
            try:
                for step_data in episode["steps"]:
                    dataset_step = process_dataset_step(
                        step_data,
                        encoder_type=encoder_type,
                        encoder_name_or_path=encoder_name_or_path,
                        keep_res=original_res,
                        quantize=not no_quantization,
                        no_encoding=no_encoding
                    )
                    episode_videos.append(dataset_step["image"])
                    episode_actions.append(dataset_step["action"])
            except Exception:
                # Best effort: skip the broken episode. `Exception` (not a bare except)
                # lets KeyboardInterrupt/SystemExit still abort the run.
                print("-" * 25)
                print(f"Add episode failed: {segment_id=}", traceback.format_exc(), suffixed_dataset_name)
            else:
                if episode_videos:
                    videos.extend(episode_videos)
                    actions.extend(episode_actions)
                    segment_ids.extend([segment_id] * len(episode_videos))
                    traj_lens.append(len(episode_videos))  # number of steps in this trajectory

            if time.time() - start_time > timeout_seconds:
                print(f"Writing dataset {suffixed_dataset_name} timed out")
                timed_out = True
                break

        if timed_out:
            # Leave the batch loop as well; whatever was encoded so far is written below.
            break

    if not videos:
        print("Empty shard!")
        with open(f"{dataset_path}/error.json", "w") as f:
            json.dump({"status": "empty_shard"}, f)

        return

    if no_quantization:
        num_channels, height, width = videos[-1].shape[:3]  # num_channels is not actually stored in metadata
    else:
        height, width = videos[-1].shape[:2]
        num_channels = None

    ##### Write videos, actions, segment_ids, and metadata #####
    # align format to save segment_ids.bin, video.bin, actions/actions.bin, metadata.json
    # save videos
    videos = np.stack(videos, axis=0)
    videos.tofile(f'{dataset_path}/video.bin')

    # save action
    utils.mkdir_if_missing(f'{dataset_path}/actions')
    actions = np.stack(actions, axis=0)
    actions = actions.astype(np.float32)
    actions.tofile(f'{dataset_path}/actions/actions.bin')

    # save segment_ids (one entry per stored step, mapping it to its trajectory index)
    segment_ids = np.array(segment_ids)
    segment_ids = segment_ids.astype(np.int32)
    segment_ids.tofile(f'{dataset_path}/segment_ids.bin')

    # save metadata
    if encoder_type == "magvit":
        vocab_size = int(2 ** 18)
    elif encoder_type == "temporalvae":
        vocab_size = None
    else:
        raise NotImplementedError(f"{encoder_type=}")

    with open(f'{dataset_path}/metadata.json', 'w') as f:  # Technically only need to save most of this data for shard 0
        json.dump({
            "token_dtype": str(np.dtype(videos.dtype)),
            "action_dim": actions[0].shape[-1],
            "s": 16,  # NOTE(review): meaning of "s" is not evident from this file - confirm with the loader
            "h": height,
            "w": width,
            "vocab_size": vocab_size,
            "hz": DATA_FREQ_TABLE.get(extern_dataset_name, 1),  # to be loaded from the data code
            "encoder_name_or_path": encoder_name_or_path,
            "encoder_type": encoder_type,
            "num_images": len(videos),
            "latent_channels": num_channels,
            "name": extern_dataset_name,
        }, f)

    print(f"{len(traj_lens)=} {np.mean(traj_lens)=} {np.sum(traj_lens)=}")
    print(f"Dataset creation time: {time.time() - start_time:.3f}")


def parse_args():
    """
    Parses the command-line arguments of the encoding script.

    Returns:
        An `argparse.Namespace` with the attributes `dataset_name`, `data_split`,
        `episode_cnt` (None when not given), `original_res`, `no_quantization`,
        `num_shards`, `curr_shard_rank`, `root_dir`, `encoder_type`,
        `encoder_name_or_path` and `no_encoding`.
    """
    parser = argparse.ArgumentParser(
        description=SCRIPT_DESCRIPTION,
        # The description contains example command lines; the default formatter would
        # re-wrap them into a single paragraph in `--help`.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "--dataset_name", type=str, required=True, choices=DATASET_TO_GEN_AND_SIZE.keys(),
        help="The name of the non-OpenX dataset to encode."
    )
    parser.add_argument(
        "--data_split", type=str, choices=["train", "val"], required=True,
        help="The split of the dataset to create."
    )
    parser.add_argument(
        "--episode_cnt", type=int,
        help="If specified, will limit the maximum number of trajectories to encode."
    )
    parser.add_argument(
        "--original_res", action='store_true',
        help="Maintain original resolution of the video rather than resizing it to 256x256."
    )
    parser.add_argument(
        "--no_quantization", action='store_true',
        help="Skip quantization step in visual encoder."
    )
    parser.add_argument(
        "--num_shards", type=int, default=1,
        help="The number of shards to partition the train/val dataset into."
    )
    parser.add_argument(
        "--curr_shard_rank", type=int, default=0,
        help="The (0-indexed) shard number to encode."
    )
    parser.add_argument(
        "--root_dir", type=str, default="data",
        help="The root directory to write all datasets to."
    )
    parser.add_argument(
        "--encoder_type", type=str, default="magvit", choices=["magvit", "temporalvae"],
        help="Type of the image tokenizer."
    )
    parser.add_argument(
        "--encoder_name_or_path", type=str, default="data/magvit2.ckpt",
        help="The path or name of the image encoder."
    )
    parser.add_argument(
        "--no_encoding", action='store_true',
        help="Preserve the groundtruth raw images to compute metrics in validation."
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the CLI, fix the random seed, then encode the requested shard.
    args = parse_args()
    utils.set_seed(233)

    # The output dirname always records the shard; the episode cap is prepended when given.
    shard_tag = f"shard{args.curr_shard_rank}_of_{args.num_shards}"
    dataset_postfix = shard_tag if args.episode_cnt is None else f"max{args.episode_cnt}_{shard_tag}"

    encode_kwargs = {
        "extern_dataset_name": args.dataset_name,
        "split": args.data_split,
        "max_episodes": args.episode_cnt,
        "dataset_postfix": dataset_postfix,
        "original_res": args.original_res,
        "no_quantization": args.no_quantization,
        "num_shards": args.num_shards,
        "curr_shard_rank": args.curr_shard_rank,
        "root_dir": args.root_dir,
        "encoder_type": args.encoder_type,
        "encoder_name_or_path": args.encoder_name_or_path,
        "no_encoding": args.no_encoding,
    }
    encode_dataset_split(**encode_kwargs)