import logging
import math
import os

import mup
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
from transformers import (
    default_data_collator,
    get_scheduler,
)
import wandb

from cont_data import RawFeatureDataset, get_maskgit_collator_feature
from genie.config import DiffusionGenieConfig
from genie.st_mar import STMAR

from datetime import datetime
from accelerate import DistributedDataParallelKwargs
from common import data_sampler
import yaml
from train_diffusion import parse_args, train

# Timestamp captured once at import time; main() prefixes the tracker run name with it.
now = datetime.now()
formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")

# accelerate's logger wrapper (supports `main_process_only=...` on log calls).
logger = get_logger(__name__)

# Process-wide torch settings, applied as soon as this module is imported.
# "medium" lets float32 matmuls use faster reduced-precision internal math.
torch.set_float32_matmul_precision("medium")
# NOTE(review): anomaly detection adds substantial autograd overhead; likely meant for debugging only.
torch.autograd.set_detect_anomaly(True)

def parse_args_multi():
    """Parse the command line for multi-dataset training.

    Takes the parser built by ``train_diffusion.parse_args()``, registers the
    options specific to this script on it, and parses the process arguments.

    Returns:
        argparse.Namespace: all parsed options (base + multi-dataset ones).
    """
    # Here train_diffusion.parse_args() acts as a parser factory: its result is
    # extended with add_argument() and only parsed at the end.
    base_parser = parse_args()

    # (flag, add_argument keyword options), registered in this order.
    extra_arguments = (
        ("--train_split", {
            "type": str,
            "default": "experiments/datasplit/dataset2.yaml",
            "help": "Config files for using multiple datasets.",
        }),
        ("--num_episodes_per_dataset", {
            "type": int,
            "default": 1000000,
            "help": "Maximum number of trajectories per dataset",
        }),
        ("--image_maskgit_path", {
            "type": str,
            "default": None,
            "help": "Optional path to the official MaskGIT checkpoint. "
                    "If specified, will copy relevant weights from the checkpoint. "
                    "These weights will have a different (hard-coded) warmup schedule.",
        }),
        ("--action_network", {
            "type": str,
            "default": None,
            "choices": ["concat", "cross_attention"],  # TODO: add other methods (resampler_concat, modulate, etc)
            "help": "If specified, will override the action in the config. Helps reduce the number of config jsons.",
        }),
    )
    for flag, options in extra_arguments:
        base_parser.add_argument(flag, **options)

    return base_parser.parse_args()


def _load_domain_datasets(args, config, accelerator, domains_list, shared_keys):
    """Load the train/eval feature datasets for each domain, skipping domains that fail.

    A domain is committed to the outputs only after every loading step for it has
    succeeded. All returned lists therefore stay index-aligned with
    ``loaded_domains``; a failure can never leave a stale or partial entry behind.

    Args:
        args: parsed CLI namespace (window, stride, overfit and batch options are read).
        config: model config; ``config.use_actions`` controls whether action stats are collected.
        accelerator: used for ``num_processes`` when truncating in overfit mode.
        domains_list: domain names to try, in order.
        shared_keys: metadata keys that must agree between a domain's train and eval splits.

    Returns:
        Tuple ``(loaded_domains, train_datasets, val_datasets, train_num_samples,
        val_num_samples, action_dimensions, action_stats)``. ``action_stats`` is
        only populated when ``config.use_actions`` is true.

    Raises:
        RuntimeError: if no domain could be loaded at all.
        AssertionError: if a domain's train/eval metadata disagree on ``shared_keys``.
    """
    loaded_domains, train_datasets, val_datasets = [], [], []
    train_num_samples, val_num_samples = [], []
    action_dimensions, action_stats = [], []

    for domain in domains_list:
        # train_data_dir = f"data/{domain}_vae_traj500_train" # {args.num_episodes_per_dataset}
        # val_data_dir = f"data/{domain}_vae_traj500_val"
        # train_data_dir = f"data/{domain}_vae_traj{args.num_episodes_per_dataset}_train"
        # val_data_dir = f"data/{domain}_vae_traj{args.num_episodes_per_dataset}_val"
        train_data_dir = f"data/{domain}_noquant_temporalvae_shard0_of_1_train"  # {args.num_episodes_per_dataset}
        val_data_dir = f"data/{domain}_noquant_temporalvae_shard0_of_1_val"

        # Best-effort: a domain whose data cannot be loaded is logged and skipped.
        # NOTE(review): if a domain fails on only some ranks, the ranks will disagree on
        # the dataset mixture; consider failing hard in multi-process runs.
        try:
            train_dataset = RawFeatureDataset(train_data_dir, window_size=args.window_size,
                                              stride=args.stride, filter_overlaps=args.filter_overlaps,
                                              max_traj_num=args.num_episodes_per_dataset,
                                              use_actions=config.use_actions, domain=domain)

            if not args.overfit_first_batch:
                eval_dataset = RawFeatureDataset(val_data_dir, window_size=args.window_size,
                                                 stride=args.stride, filter_overlaps=True,
                                                 use_actions=config.use_actions, domain=domain)
            else:
                # Overfit mode: keep exactly one effective (global) batch and evaluate on it too.
                train_dataset.valid_start_inds = train_dataset.valid_start_inds[:args.per_device_train_batch_size
                                                                                * args.gradient_accumulation_steps
                                                                                * accelerator.num_processes]
                eval_dataset = train_dataset

            # Shuffle the eval indices once with a fixed seed, so the evaluation order is
            # random but identical across runs and iterations.
            eval_dataset.valid_start_inds = torch.tensor(eval_dataset.valid_start_inds)[
                torch.randperm(len(eval_dataset), generator=torch.Generator().manual_seed(0))
            ].tolist()

            n_action = train_dataset.n_action
            action_stat = train_dataset.action_stat if config.use_actions else None
        except Exception:
            logger.exception("Failed to load datasets for domain %r; skipping it.", domain,
                             main_process_only=False)
            continue

        assert all(train_dataset.metadata.get(shared_key) == eval_dataset.metadata.get(shared_key)
                   for shared_key in shared_keys), \
            f"Train/val metadata mismatch for domain {domain!r}"  # TODO: check this across all datasets

        # Commit the domain only now that every step succeeded, keeping all lists aligned.
        loaded_domains.append(domain)
        train_datasets.append(train_dataset)
        val_datasets.append(eval_dataset)
        # Sizes are recorded after any overfit truncation so they match what ConcatDataset sees.
        train_num_samples.append(len(train_dataset))
        val_num_samples.append(len(eval_dataset))
        action_dimensions.append(n_action)
        if config.use_actions:
            action_stats.append(action_stat)

    if not loaded_domains:
        raise RuntimeError("No dataset domain could be loaded; see the errors logged above.")

    return (loaded_domains, train_datasets, val_datasets, train_num_samples,
            val_num_samples, action_dimensions, action_stats)


def main():
    """Train an STMAR diffusion model on a weighted mixture of datasets.

    The domains to use are read from the ``--train_split`` YAML file. Its
    ``domains`` key is a comma-separated string. One train/eval dataset is built
    per domain. Batches are drawn across domains by ``MultiTaskBatchSampler``.
    The prepared model, optimizer, dataloaders and scheduler are handed to
    ``train_diffusion.train``.

    Raises:
        ValueError: unless exactly one of ``--llama_config`` / ``--genie_config`` is set.
        NotImplementedError: if the config requests ``drop_action_ratio > 0``.
        RuntimeError: if none of the listed domains could be loaded.
    """
    args = parse_args_multi()
    # NOTE(review): the rest of this function always reads `args.genie_config`.
    if (args.llama_config is not None) == (args.genie_config is not None):
        raise ValueError("Exactly one of `llama_config` and `genie_config` should be set.")

    # Accelerator's own accumulation is disabled (steps=1); accumulation is
    # presumably done manually in `train` via `args.gradient_accumulation_steps`.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(gradient_accumulation_steps=1, log_with=args.report_to,
                              even_batches=False, project_dir=args.output_dir, kwargs_handlers=[ddp_kwargs])
    # Trackers are started early so the run can be named and the dataset-mixture plots
    # logged; they are initialized again with the full config right before training.
    accelerator.init_trackers("video")

    if accelerator.is_main_process:
        # NOTE(review): assumes the first tracker exposes `.run` (e.g. wandb).
        accelerator.trackers[0].run.name = formatted_date + "_" + args.run_name

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    accelerator.wait_for_everyone()

    # create multiple datasets
    with open(args.train_split, "r", encoding="utf-8") as file:
        datasplit = yaml.safe_load(file)

    config = DiffusionGenieConfig.from_pretrained(args.genie_config)
    if config.drop_action_ratio > 0:
        # Checked once up front; inside the per-domain error handling it would be swallowed.
        raise NotImplementedError("drop_action_ratio > 0 is not supported for multi-dataset training")

    # 'domains' is a comma-separated string; empty entries (e.g. a trailing comma) are ignored.
    domains_list = [domain.strip() for domain in datasplit['domains'].split(',') if domain.strip()]

    shared_keys = ("s", "h", "w", "vocab_size", "latent_channels",
                   "encoder_type", "encoder_name_or_path", "quantized")  # TODO: check train/val hz per dataset?
    (loaded_domains, train_datasets, val_datasets, dataset_num_samples,
     val_dataset_num_samples, action_dimensions, action_stats) = _load_domain_datasets(
        args, config, accelerator, domains_list, shared_keys)

    print("dataset_num_samples:", dataset_num_samples)

    # Metadata is taken from the last loaded training dataset; a missing key is simply not
    # stored so defaults can be filled in later.  # TODO: handle missing keys / compare across datasets
    last_train_dataset = train_datasets[-1]
    shared_metadata = {shared_key: last_train_dataset.metadata[shared_key]
                       for shared_key in shared_keys if shared_key in last_train_dataset.metadata}

    config.use_mup = args.mu_transfer  # Note: changing this may affect pre-trained model due to attn scaling
    config.image_vocab_size = None
    config.T = args.window_size
    config.S = shared_metadata["h"] * shared_metadata["w"]  # TODO: make STMaskGIT use h and w instead of S
    config.vae_embed_dim = shared_metadata["latent_channels"]
    if args.action_network is not None:
        print("Using action network", args.action_network)
        config.action_network = args.action_network

    model = STMAR(config)

    if config.use_actions:
        # Only domains that actually loaded, so names line up with the per-domain action lists.
        model.init_action_projectors(loaded_domains, action_dimensions, action_stats, config.action_network)

    if args.image_maskgit_path is not None:
        model.init_weights()
        model.load_pretrained_image_weights(args.image_maskgit_path)
        if args.mu_transfer:
            model.set_mup_shapes(rescale_params=False)
    elif args.mu_transfer:
        model.set_mup_shapes(rescale_params=True)
        # model.init_weights()  # might be unnecessary if `rescale_params` is True

    # Optimizer. Split weights in two groups, one with weight decay and the other not.
    opt_class = mup.MuAdamW if args.mu_transfer else torch.optim.AdamW
    # Scale the base learning rate linearly with the effective batch size (relative to 64), capped at 8x.
    effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps \
                           * accelerator.num_processes
    args.learning_rate = args.learning_rate * min(max(1, effective_batch_size / 64), 8)

    no_decay = ["bias", "layer_norm.weight"]

    pretrained_params = {  # more accurately the params we want lower lr for, some weights like pos_embed_TSC are pre-trained but not treated as lower lr
        param_name
        for param_name, _ in model.named_parameters()
        if any(term in param_name for term in ("spatial_attn.qkv", "spatial_attn.proj", "mlp"))
    } if args.image_maskgit_path is not None else set()
    # Give pre-trained weights 10x lower learning rate
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay) and n not in pretrained_params],
            "weight_decay": args.weight_decay,
            "lr": args.learning_rate,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay) and n not in pretrained_params],
            "weight_decay": 0.0,
            "lr": args.learning_rate,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay) and n in pretrained_params],
            "weight_decay": args.weight_decay,
            "lr": args.learning_rate * 0.1,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay) and n in pretrained_params],
            "weight_decay": 0.0,
            "lr": args.learning_rate * 0.1,
        },
    ]

    optimizer = opt_class(optimizer_grouped_parameters, lr=args.learning_rate,
                          betas=(args.adam_beta_1, args.adam_beta_2), eps=args.adam_eps)

    # DataLoaders creation:
    collate_fn = default_data_collator if args.llama_config is not None else get_maskgit_collator_feature(config)
    combined_dataset = torch.utils.data.ConcatDataset(train_datasets)

    batch_sampler = data_sampler.MultiTaskBatchSampler(
        dataset_num_samples,
        batch_size=args.per_device_train_batch_size,
        temperature=3.  # the higher the more flat the distribution
    )
    dataset_traj_image = data_sampler.make_dataset_pie_plot(loaded_domains, dataset_num_samples)
    accelerator.log(({"dataset_mixture": wandb.Image(dataset_traj_image)}), log_kwargs={"wandb": {"commit": False}})
    dataset_weights = batch_sampler.generate_tasks_distribution().cpu().numpy()
    dataset_weight_image = data_sampler.make_dataset_pie_plot(loaded_domains, dataset_weights)
    accelerator.log(({"dataset_mixture_weight": wandb.Image(dataset_weight_image)}), log_kwargs={"wandb": {"commit": False}})

    train_dataloader = DataLoader(combined_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn,
                                  num_workers=24, pin_memory=False)

    batch_val_sampler = data_sampler.MultiTaskBatchSampler(
        val_dataset_num_samples,
        batch_size=args.per_device_train_batch_size,
        temperature=4.  # the higher the more flat the distribution
    )

    combined_val_dataset = torch.utils.data.ConcatDataset(val_datasets)
    eval_dataloader = DataLoader(combined_val_dataset, batch_sampler=batch_val_sampler, collate_fn=collate_fn,
                                 num_workers=24, pin_memory=False)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
        if args.max_train_steps < 2000 and args.resume_from_checkpoint is None:  # minimal number of training steps
            args.max_train_steps = 2000
            args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if args.lr_scheduler_type == "custom_cosine":  # decay to `end_ratio` of the peak learning rate
        def get_lr_wrapper(warmup_steps, max_steps, end_ratio=0.1):
            """Return a LambdaLR multiplier: linear warmup, then cosine decay to `end_ratio`."""
            def get_lr(step):
                if step < warmup_steps:
                    return (step + 1) / warmup_steps

                # Clamp so the multiplier stays at `end_ratio` past `max_steps` instead of the
                # cosine rising again, and avoid dividing by zero for a zero-length decay phase.
                remaining_steps = max(1, max_steps - warmup_steps)
                elapsed = min(step - warmup_steps, remaining_steps)
                return ((1 + math.cos(math.pi * elapsed / remaining_steps)) / 2) \
                    * (1 - end_ratio) + end_ratio
            return get_lr

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, get_lr_wrapper(args.num_warmup_steps * accelerator.num_processes,
                                      args.max_train_steps if overrode_max_train_steps
                                      else args.max_train_steps * accelerator.num_processes)
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
            num_training_steps=args.max_train_steps
            if overrode_max_train_steps
            else args.max_train_steps * accelerator.num_processes,
        )

    # Enable gradient checkpointing to save memory
    if args.gradient_checkpointing:
        logger.info("Enabling gradient checkpointing")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # incompatible with grad checkpointing

    # Prepare everything with our `accelerator`.
    accelerator.wait_for_everyone()
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    if not args.no_compile:
        torch._dynamo.config.cache_size_limit = 256
        torch._dynamo.config.optimize_ddp = False  # https://github.com/pytorch/pytorch/issues/104674
        # TODO: https://github.com/pytorch/pytorch/issues/109774#issuecomment-2046633776
        model = torch.compile(model)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # Store the full configuration (CLI args + model config + derived stats) with the trackers.
    experiment_config = vars(args) | vars(config)

    seq_len = shared_metadata["h"] * shared_metadata["w"] * args.window_size
    args.num_datasets = len(train_datasets)
    model_module = model.module if hasattr(model, "module") else model
    num_model_params = sum(p.numel() for p in model.parameters())
    num_trunk_params = sum(p.numel() for p in model_module.decoder.parameters())

    experiment_config.update(shared_metadata | {
        "model_parameters": num_model_params,
        "model_parameters_M": round(num_model_params / 1e6),
        "trunk_parameters": num_trunk_params,
        "trunk_parameters_M": round(num_trunk_params / 1e6),
        "seq_len": seq_len,
        # Count tokens over all domains, not just the last loaded dataset.
        "train_data_tokens": len(combined_dataset) * seq_len,
        "effective_batch_size": effective_batch_size,
        "effective_batch_size_tokens": effective_batch_size * seq_len,
        "mixed_precision": accelerator.mixed_precision,
        "num_datasets": args.num_datasets
    })

    # Standard ~6 * params * tokens estimate of training FLOPs per optimizer update.
    experiment_config["FLOPs_per_update_step"] = 6 * experiment_config["model_parameters"] \
                                                 * experiment_config["effective_batch_size_tokens"]

    accelerator.init_trackers(project_name="video", config=experiment_config)

    # Train!
    train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config, config, args)


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()