import argparse
import contextlib
import logging
import math
import os
import shutil
import time

import matplotlib
import mup
import numpy as np
import torch
import torchvision.transforms.functional as transforms_f
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from einops import rearrange
from lpips import lpips
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import transformers
import traceback

from transformers import (
    default_data_collator,
    get_scheduler,
)
from collections import defaultdict
from data import RawTokenDataset, get_maskgit_collator

from common.eval_utils import decode_tokens, compute_lpips
from genie.st_mask_git import STMaskGIT
from genie.config import GenieConfig

from visualize import decode_latents_wrapper
from skimage import metrics as image_metrics
from matplotlib import pyplot as plt
from datetime import datetime
from accelerate import DistributedDataParallelKwargs

torch.autograd.set_detect_anomaly(True)  # Note: anomaly detection adds significant autograd overhead; enabled here for debugging

# Get current date and time
now = datetime.now()

# Format the datetime object as a string
formatted_date = now.strftime("%Y-%m-%d %H:%M:%S")

torch.set_float32_matmul_precision("medium")
logger = get_logger(__name__)

def parse_args():
    parser = argparse.ArgumentParser(description="Train a spatial-temporal MaskGIT-style model on video generation.")

    # Data
    parser.add_argument(
        "--train_data_dir", type=str, default="data/1x_humanoid_magvit_traj1000_train",
        help="Directory containing tokenized data, should have a `video.bin`, `metadata.json` and `segment_ids.json`."
    )
    parser.add_argument(
        "--val_data_dir", type=str, default="data/1x_humanoid_magvit_traj1000_val",
        help="Directory containing tokenized data, should have a `video.bin`, `metadata.json` and `segment_ids.json`."
    )
    parser.add_argument(
        "--domain", type=str, default="1x_humanoid",
        help="The domain name for the dataset"
    )
    parser.add_argument(
        "--window_size",
        type=int,
        default=12,
        help="Number of frames to in a sequence.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=None,
        help="Difference in frame count between consecutive frames in a sequence.",
    )
    parser.add_argument(
        "--filter_overlaps",
        action="store_true",
        help=(
            "Whether to filter repeated frames in the train dataset (`filter_overlaps` always true for the val set). "
            "Filtering essentially makes the training dataset less correlated but ~16x smaller, "
            "see the `filter_overlaps` argument in `RawTokenDataset` for details."),
        default=True
    )

    # Model
    parser.add_argument(
        "--llama_config",
        type=str,
        help="`transformers.LlamaConfig` json. "
             "E.g. https://huggingface.co/1x-technologies/Llama_1B_v0/blob/main/config.json",
    )
    parser.add_argument(
        "--diffusion",
        action="store_true",
        help="use diffusion model."
    ),
    parser.add_argument(
        "--genie_config",
        type=str,
        help="GenieConfig json."
    )

    parser.add_argument(
        "--warmstart_path",
        type=str,
        default=None,
        help="A path to a checkpoint to warmstart a model from, possibly not trained on the same dataset, "
             "will resize embeddings if needed.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )

    # Training
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=2, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--max_eval_steps",
        type=int,
        default=int(1e10),
        help="Only evaluate on `max_eval_steps` batches of validation data per process, faster.",
    )
    parser.add_argument(
        "--eval_every_n_steps",
        type=int,
        default=1000,
        help="Eval every N training steps.",
    )
    parser.add_argument(
        "--vis_every_n_steps",
        type=int,
        default=20000,
        help="Visualize every N training steps.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=str,
        default="constant_with_warmup",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "custom_cosine"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1.0,
        help="Threshold to clip gradients.",
    )
    parser.add_argument(
        "--attention_dropout",
        type=float,
        default=0.05,
        help="Attention dropout prob.",
    )
    parser.add_argument(
        "--adam_beta_1",
        type=float,
        default=0.9,
    )
    parser.add_argument(
        "--adam_beta_2",
        type=float,
        default=0.95,
    )
    parser.add_argument(
        "--adam_eps",
        type=float,
        default=1e-8,
    )

    # Misc
    parser.add_argument("--output_dir", type=str, required=True, help="Where to store the model checkpoints.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default="10000",
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument(
        "--overfit_first_batch",
        action="store_true",
        help=(
            "Debug option that trains and validates on only the first batch of the training dataset."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help="The integration to report the results and logs to.",
    )
    parser.add_argument(
        "--mu_transfer",
        action="store_true",
        help="If specified, will train with mu transfer reparametrizations. Only supports Llama models.",
        default=True
    )
    parser.add_argument(
        "--no_compile",
        action="store_true",
        help="If specified, will not compile the model.",
        default=True
    )

    parser.add_argument(
        "--run_name",
        type=str,
        default="video_prediction",
        help="",
    )
    parser.add_argument(
        "--cleanup_checkpoints",
        action="store_true",
        help=(
            "Whether to clean up checkpoints (to keep only the last 3) along the training. "),
    )
    parser.add_argument(
        "--save_second_epoch",
        action="store_true",
        help="Whether to checkpoint at the end of the second epoch (1-indexing). This one will not be auto-deleted by cleanup.",
        default=True
    )
    return parser


def save_checkpoint(model, accelerator, args, filename):
    """
    Saves the unwrapped model (HF format) and the Accelerate training state to
    `os.path.join(args.output_dir, filename)`.
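
    Example (as called below during training): `save_checkpoint(model, accelerator, args, f"step_{completed_steps}")`
    writes everything under `{args.output_dir}/step_{completed_steps}/`.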
    """
    unwrapped_model = accelerator.unwrap_model(model)
    save_path = os.path.join(args.output_dir, filename)

    if accelerator.is_main_process:
        unwrapped_model.save_pretrained(
            save_path, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )
        accelerator.save_state(save_path)


@torch.no_grad()
def visualize(accelerator, model, dataloader, window_size, metrics_prefix="train", max_steps=1):
    """
    Visualizes the model's autoregressive generation outputs and logs them to wandb.
    Generation is prompted with ground-truth context frames (causal along the time axis).
    """

    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    if not unwrapped_model.config.jointly_predict_states:
        return
    metrics = defaultdict(list)
    if accelerator.is_main_process:
        lpips_alex = lpips.LPIPS(net="alex")  # Calculate LPIPS w/ AlexNet, the fastest option

    decode_latents = decode_latents_wrapper()  # re-initializing every time to save memory
    unwrapped_model.eval()
    rank = 0
    dataloader_iter = iter(dataloader)
    for step in range(len(dataloader)):
        try:
            batch = next(dataloader_iter)

            # Note: hardcoding 4 image cap for faster inference on small models
            TEST_NUM = 4
            reshaped_labels = rearrange(batch["labels"][:TEST_NUM], "b (t s) -> b t s", t=window_size).to(accelerator.device)  # `s` is really `(h, w)`
            domains = batch["domain"][:TEST_NUM]

            if 'action_ids' in batch:
                action_ids = batch["action_ids"][:TEST_NUM].to(accelerator.device)
            else:
                action_ids = None

            # use the model's configured number of prompt frames as context
            num_prompt_frames = unwrapped_model.config.num_prompt_frames
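            # number of tokens to generate: (h * w) tokens per frame for each frame beyond the prompt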
            num_new_tokens = batch["w"][0] * batch["h"][0] * (window_size - num_prompt_frames)
            prompt_input_ids = rearrange(reshaped_labels[:, :num_prompt_frames], "b t s -> b (t s)")
            outputs = unwrapped_model.generate(input_ids=prompt_input_ids, attention_mask=torch.ones_like(prompt_input_ids),
                                                max_new_tokens=num_new_tokens, min_new_tokens=num_new_tokens,
                                                action_ids=action_ids,
                                                domain=batch["domain"][:TEST_NUM],
                                                w=batch["w"][:TEST_NUM],
                                                h=batch["h"][:TEST_NUM])

            output_tokens = rearrange(outputs, "b (t h w) -> b t h w", t=window_size,
                                    h=batch["h"][0], w=batch["w"][0])
            gtruth_tokens = rearrange(reshaped_labels[:, num_prompt_frames:], "b t (h w) -> b t h w",
                                    h=batch["h"][0], w=batch["w"][0])

            decoded_output = decode_tokens(output_tokens.cpu(), decode_latents)
            decoded_gtruth = decode_tokens(gtruth_tokens.cpu(), decode_latents)

            decoded_output = accelerator.gather(decoded_output.to(accelerator.device)).cpu()
            decoded_gtruth = accelerator.gather(decoded_gtruth.to(accelerator.device)).cpu()

            # As in Genie, we also compute psnr_delta = PSNR(x_t, x_t_hat) - PSNR(x_t, x_t_hat_prime),
            # where x_t_hat_prime is generated from randomly sampled actions.
            # This PSNR difference measures controllability.
            if action_ids is not None:
                random_action_ids = torch.randn_like(action_ids)
                random_action_outputs = unwrapped_model.generate(input_ids=prompt_input_ids, attention_mask=torch.ones_like(prompt_input_ids),
                                                    max_new_tokens=num_new_tokens, min_new_tokens=num_new_tokens,
                                                    action_ids=random_action_ids,
                                                    domain=batch["domain"][:TEST_NUM],
                                                    w=batch["w"][:TEST_NUM],
                                                    h=batch["h"][:TEST_NUM],
                                                    skip_normalization=True)

                random_output_tokens = rearrange(random_action_outputs, "b (t h w) -> b t h w", t=window_size,
                                        h=batch["h"][0], w=batch["w"][0])
                random_output_tokens = decode_tokens(random_output_tokens.cpu(), decode_latents)

                random_output_tokens = accelerator.gather(random_output_tokens.to(accelerator.device)).cpu()
                random_pred_frames_numpy = random_output_tokens[:, num_prompt_frames:].detach().cpu().numpy()


            if accelerator.is_main_process:
                exs_per_fig = 4

                for j in range(0, len(decoded_output), exs_per_fig):
                    fig, axs = plt.subplots(nrows=2 * exs_per_fig, ncols=window_size, figsize=(3 * window_size, 3 * 2 * exs_per_fig))
                    # If len(decoded_output) is not a multiple of 4, make sure to truncate properly
                    for k in range(min(exs_per_fig, len(decoded_output) - j)):
                        for i in range(num_prompt_frames):
                            for ax in (axs[k * 2, i], axs[k * 2 + 1, i]):
                                ax.imshow(transforms_f.to_pil_image(decoded_output[j + k, i]))
                                ax.set_title("Context")
                                ax.axis("off")

                        for i in range(num_prompt_frames, window_size):
                            axs[k * 2, i].imshow(transforms_f.to_pil_image(decoded_gtruth[j + k, i - num_prompt_frames]))
                            axs[k * 2, i].set_title("Ground truth")
                            axs[k * 2 + 1, i].imshow(transforms_f.to_pil_image(decoded_output[j + k, i]))
                            axs[k * 2 + 1, i].set_title("Prediction")
                            for ax in axs[:, i]:
                                ax.axis("off")

                    rank = accelerator.process_index
                    wandb_tracker = accelerator.get_tracker("wandb")
                    # wandb_tracker.log({f"vis_{metrics_prefix}_{j}": fig}, commit=False)
                    wandb_tracker.log({f"{domains[0]}/vis_{metrics_prefix}_{j}": fig}, commit=False)
                    plt.close(fig)

                metrics["ar_lpips"].extend(compute_lpips(decoded_gtruth,  # Note: not parallelizing right now
                                                        decoded_output[:, num_prompt_frames:], lpips_alex))

                gt_frames_numpy = decoded_gtruth.detach().cpu().numpy()
                pred_frames_numpy = decoded_output[:, num_prompt_frames:].detach().cpu().numpy()
                psnr = [image_metrics.peak_signal_noise_ratio(
                    gt_frames_numpy[i] / 255., pred_frames_numpy[i] / 255., data_range=1.0) for i in range(gt_frames_numpy.shape[0])]

                ssim = [np.mean([image_metrics.structural_similarity(
                    gt_frames_numpy[i][j]  / 255., pred_frames_numpy[i][j] / 255., data_range=1.0, channel_axis=0) \
                    for i in range(gt_frames_numpy.shape[0])]) for j in range(gt_frames_numpy.shape[1])]

                # compute additional metrics; `metrics_prefix` is prepended once when logging below
                metrics["ar_psnr"].extend(psnr)
                metrics["ar_ssim"].extend(ssim)
                metrics[f"{batch['domain'][0]}/ar_lpips"].extend(compute_lpips(decoded_gtruth,  # Note: not parallelizing right now
                                                                    decoded_output[:, num_prompt_frames:], lpips_alex))

                if action_ids is not None:
                    # log controllability: PSNR with ground-truth actions minus PSNR with random actions
                    psnr_delta = [psnr[i] - image_metrics.peak_signal_noise_ratio(
                        gt_frames_numpy[i] / 255., random_pred_frames_numpy[i] / 255., data_range=1.0) for i in range(gt_frames_numpy.shape[0])]

                    metrics[f"{metrics_prefix}/ar_psnr_delta"].extend(psnr_delta)

        except Exception as e:
            print("batch failed", traceback.format_exc())

        if step + 1 >= max_steps:
            break

    unwrapped_model.train()
    if accelerator.is_main_process:
        metrics = {f"{metrics_prefix}_{key}": np.mean(val) for key, val in metrics.items() if len(val) > 0}

        print(f"{metrics=}")
        wandb_tracker = accelerator.get_tracker("wandb")
        wandb_tracker.log(metrics, commit=False)

def train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config, config, args):
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0
    resume_step = None
    checkpoint_path = ""

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        try:
            if os.path.exists(args.resume_from_checkpoint + "/pytorch_model.bin"):
                checkpoint_path = args.resume_from_checkpoint
                path = os.path.basename(args.resume_from_checkpoint.rstrip("/"))
            else:
                # Get the most recent checkpoint
                base_path = os.path.dirname(args.resume_from_checkpoint)
                dirs = [os.path.join(base_path, f.name) for f in os.scandir(base_path) if f.is_dir()]
                dirs.sort(key=os.path.getctime)

                # Sorts folders by date modified, most recent checkpoint is the last
                if len(dirs) > 0:
                    path = dirs[-1]
                    checkpoint_path = path
                    path = os.path.basename(checkpoint_path)

            accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")

            if os.path.exists(checkpoint_path):
                # allows finetuning from a checkpoint with a different model structure
                print(f"loading checkpoint from {checkpoint_path}")
                accelerator.load_state(checkpoint_path, strict=False)
                # tied weights not saved so can't load strict, but also no need to tie again
                # Extract `epoch_{i}` or `step_{i}`
                training_difference = os.path.splitext(path)[0]
            else:
                print("No checkpoint found, training from scratch.")
                training_difference = "step_0"

            if "epoch" in training_difference:
                starting_epoch = int(training_difference.replace("epoch_", "")) + 1
                resume_step = None
                num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
                completed_steps = starting_epoch * num_update_steps_per_epoch
            else:
                # need to multiply `gradient_accumulation_steps` to reflect real steps
                resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
                starting_epoch = resume_step // len(train_dataloader)
                completed_steps = resume_step // args.gradient_accumulation_steps
                resume_step -= starting_epoch * len(train_dataloader)

        except Exception as e:
            training_difference = "step_0"
            starting_epoch = 0
            completed_steps = 0
            print("load checkpoint incomplete", traceback.format_exc())

    # update the progress_bar if load from checkpoint
    progress_bar.update(completed_steps)
    loss_info = torch.zeros(2, device=accelerator.device)  # sum, count

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        train_dataloader.set_epoch(epoch)

        # potentially cleanup the previous checkpoints
        if args.cleanup_checkpoints:
            if os.path.exists(args.output_dir):
                dirs = [os.path.join(args.output_dir, f.name) for f in os.scandir(args.output_dir) if f.is_dir()]

                if len(dirs) > 3:  # must keep at least 2 checkpoints for second epoch and most recent one
                    if args.save_second_epoch and os.path.join(args.output_dir, "epoch_1") in dirs:  # never prune second epoch
                        dirs.remove(os.path.join(args.output_dir, "epoch_1"))

                    dirs.sort(key=os.path.getctime)
                    # keep only the 3 most recent checkpoints
                    for stale_path in dirs[:-3]:
                        print(f"removing old checkpoint {stale_path}")
                        shutil.rmtree(stale_path, ignore_errors=True)

        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            active_dataloader = train_dataloader

        _time = time.time()
        dataloader_iter = iter(active_dataloader)

        # Ensure the model is in train mode (eval/visualization inside the loop switches it to eval)
        model.train()
        num_iters_per_epoch = max(len(active_dataloader) - 8, 1)  # skip the last few iterations to avoid ragged final batches

        for step in range(num_iters_per_epoch):
            try:
                train_action_loss = 0
                batch = next(dataloader_iter)
                gradient_accumulation_steps = args.gradient_accumulation_steps
                batch_size = batch["input_ids"].size(0)
                # Manual gradient accumulation: accelerate's built-in accumulation was using too much memory
                is_update_step = (step + 1) % gradient_accumulation_steps == 0
                ctx_manager = contextlib.nullcontext() if is_update_step else accelerator.no_sync(model)
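                # `accelerator.no_sync` skips the DDP gradient all-reduce on accumulation steps,
                # so gradients are only synchronized across processes on the actual update step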

                with ctx_manager:
                    accelerator.wait_for_everyone()
                    outputs = model(**batch)
                    loss = outputs.loss

                    if torch.isfinite(loss).all():
                        loss_info[0] += loss.detach().mean() * batch_size # only video loss
                        if "action_loss" in outputs:
                            train_action_loss = outputs.action_loss.item()
                            loss += config.action_loss_weight * outputs.action_loss

                        loss_info[1] += batch_size
                        accelerator.backward(loss / gradient_accumulation_steps)
                    else:
                        print("Warning: NaN or Inf detected in loss. Skipping backward pass.")
                        dummy_loss = torch.zeros_like(loss, requires_grad=True)
                        accelerator.backward(dummy_loss)

                if not is_update_step:
                    continue

            except Exception as e:
                # avoid final iteration batch concatenation problems
                print("batch failed",  traceback.format_exc())
                continue
            # Everything below only happens on update step
            if args.max_grad_norm is not None:
                accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
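            # Sum the per-process loss totals and example counts so the logged average covers all ranks.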
            loss_info = accelerator.reduce(loss_info)

            avg_train_loss = (loss_info[0] / loss_info[1]).item()  # sum / count
            loss_info *= 0  # reset sum and count
            try:
                perplexity = math.exp(avg_train_loss)
            except OverflowError:
                print("overflow error for perplexity")
                perplexity = float("inf")

            # print(f"{perplexity=} {avg_train_loss=}")
            batch_time = time.time() - _time  # accumulated batch
            rank = accelerator.process_index

            domain_iter = str(batch['domain'][0])
            _time = time.time()
            accelerator.log(
                {
                    "train_perplexity": perplexity,
                    "train_loss": avg_train_loss,
                    "train_action_loss": train_action_loss,
                    f"stat/{domain_iter}_action_loss": train_action_loss / loss_info[1],
                    f"stat/{domain_iter}_train_perplexity": perplexity,
                    f"stat/{domain_iter}_train_loss": avg_train_loss,
                    "epoch": epoch,
                    "update_step": completed_steps,
                    "examples_processed": completed_steps * args.per_device_train_batch_size
                                          * args.gradient_accumulation_steps * accelerator.num_processes,
                    "learning_rate": lr_scheduler.get_last_lr()[0],
                    "flops": (completed_steps + 1) * experiment_config["FLOPs_per_update_step"],
                    "throughput_examples": experiment_config["effective_batch_size"] / batch_time,
                }, step=completed_steps)

            progress_bar.update(1)
            completed_steps += 1


            # print(f"{completed_steps %  args.checkpointing_steps=} {completed_steps=} {args.checkpointing_steps=}")
            if  completed_steps % int(args.checkpointing_steps) == 0:
                print(f"Saving checkpoint at step {completed_steps}!")
                save_checkpoint(model, accelerator, args, f"step_{completed_steps}")

            if completed_steps % args.eval_every_n_steps == 0:
                time.sleep(1)  # brief manual pause before starting eval
                model.eval()
                eval_losses = []

                # Compute token-level accuracy (w/ teacher forcing)
                num_correct = 0
                num_total = 0

                # re-create the eval iterator each pass to avoid data collating issues
                eval_dataloader_iter = iter(eval_dataloader)
                for eval_step in range(args.max_eval_steps):
                    eval_action_loss = 0
                    try:
                        batch = next(eval_dataloader_iter)
                        batch_size = len(batch["input_ids"])  # Last batch might not be full
                        with torch.no_grad():
                            outputs = model(**batch)

                        loss = outputs.loss
                        if "action_loss" in outputs:
                            eval_action_loss = outputs.action_loss.item()
                        eval_losses.append(accelerator.gather_for_metrics(loss.repeat(batch_size)))
                    except Exception as e:
                        print("error:", e)
                        continue

                    if "acc" in outputs:
                        # `num_correct` and `num_total` actually track mean accuracy in this case.
                        num_correct_batch = accelerator.reduce(outputs.acc, reduction="mean").item() * batch_size
                        num_total_batch = batch_size
                        num_correct += num_correct_batch
                        num_total += num_total_batch
                    else:
                        shifted_preds = torch.argmax(outputs.logits[:, :-1, :], dim=-1)
                        shifted_labels = batch["labels"][:, 1:]
                        num_correct_batch = accelerator.gather_for_metrics((shifted_preds == shifted_labels).sum()).sum().item()
                        num_total_batch = accelerator.gather_for_metrics(torch.tensor(torch.numel(shifted_labels),
                                                                device=accelerator.device)).sum().item()
                        num_correct += num_correct_batch
                        num_total += num_total_batch

                    if eval_step >= args.max_eval_steps * args.num_datasets:
                        break

                    try:
                        accelerator.log(
                        {
                            f'stat/{domain_iter}_eval_teacher_acc': num_correct_batch / num_total_batch,
                            f'stat/{domain_iter}_eval_loss': (torch.mean(eval_losses[-1])).item(),
                            f'stat/{domain_iter}_eval_action_loss': eval_action_loss,

                        },
                        step=completed_steps,
                        )
                    except Exception as e:
                        print("log failed", e)
                        continue

                if len(eval_losses) > 0:
                    eval_losses = torch.cat(eval_losses)
                    eval_loss = torch.mean(eval_losses).item()
                    eval_teacher_acc = num_correct / num_total
                    try:
                        perplexity = math.exp(eval_loss)
                    except OverflowError:
                        print("overflow error for perplexity")
                        perplexity = float("inf")
                else:
                    continue

                logger.info(f"{completed_steps=} {perplexity=} {eval_loss=} {eval_teacher_acc=}")
                accelerator.log(
                    {
                        "eval_perplexity": perplexity,
                        "eval_loss": eval_loss,
                        "eval_action_loss": eval_action_loss,
                        "eval_teacher_acc": eval_teacher_acc,
                        "epoch": epoch,
                        "update_step": completed_steps,
                        "examples_processed": completed_steps * args.per_device_train_batch_size
                                              * args.gradient_accumulation_steps * accelerator.num_processes,
                        "flops": completed_steps * experiment_config["FLOPs_per_update_step"],
                    },
                    step=completed_steps,
                )

            if completed_steps % args.vis_every_n_steps == 0 or completed_steps >= args.max_train_steps:
                if "encoder_type" not in experiment_config:
                    experiment_config["encoder_name_or_path"] = "data/magvit2.ckpt"
                    experiment_config["encoder_type"] = "magvit"

                if not args.overfit_first_batch:  # val is same as train otherwise
                    visualize(accelerator, model, eval_dataloader, args.window_size, "val")

                visualize(accelerator, model, train_dataloader, args.window_size, "train")

            if completed_steps >= args.max_train_steps:
                break

        if args.checkpointing_steps == "epoch" or (args.save_second_epoch and epoch == 1):
            save_checkpoint(model, accelerator, args, f"epoch_{epoch}")

    save_checkpoint(model, accelerator, args, f"final_checkpt")
    accelerator.end_training()

def main():
    parser = parse_args()
    args = parser.parse_args()
    assert (args.llama_config is not None) ^ (args.genie_config is not None), \
        "Exactly one of `llama_config` and `genie_config` should be set."

    # Manual gradient accumulation
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(gradient_accumulation_steps=1, log_with=args.report_to, project_dir=args.output_dir, kwargs_handlers=[ddp_kwargs])
    accelerator.init_trackers("video")
    if accelerator.is_main_process:
        accelerator.trackers[0].run.name = formatted_date + "_" + args.run_name

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_info()
        print(f"Rank {accelerator.process_index} assigned to device {torch.cuda.current_device()}")
    else:
        transformers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    accelerator.wait_for_everyone()
    config = GenieConfig.from_pretrained(args.genie_config)
    train_dataset = RawTokenDataset(args.train_data_dir, window_size=args.window_size, name=args.domain,
                                    stride=args.stride, filter_overlaps=args.filter_overlaps,
                                    compute_stride_from_freq_table=(args.stride is None),
                                    use_actions=config.use_actions)

    if not args.overfit_first_batch:
        eval_dataset = RawTokenDataset(args.val_data_dir, window_size=args.window_size, name=args.domain,
                                        stride=args.stride, filter_overlaps=True,
                                        compute_stride_from_freq_table=(args.stride is None),
                                        use_actions=config.use_actions)
    else:
        train_dataset.valid_start_inds = train_dataset.valid_start_inds[:args.per_device_train_batch_size
                                                                         * args.gradient_accumulation_steps
                                                                         * accelerator.num_processes]
        eval_dataset = train_dataset

    assert all(train_dataset.metadata[shared_key] == eval_dataset.metadata[shared_key]
               for shared_key in ("s", "vocab_size", "hz"))

    latent_side_len, vocab_size, hz = [train_dataset.metadata[key] for key in ("s", "vocab_size", "hz")]

    # Note: changing this may affect pre-trained model due to attn scaling
    config.use_mup = args.mu_transfer
    config.image_vocab_size = vocab_size
    config.T = args.window_size
    model = STMaskGIT(config)

    if config.use_actions:
        print(f"Initializing action projectors with {train_dataset.n_action}d action")
        model.init_action_projectors([train_dataset.name], [train_dataset.n_action], [train_dataset.action_stat], config.action_network)

    if args.mu_transfer:
        model.set_mup_shapes(rescale_params=True)
        # model.init_weights()  # might be unnecessary if `rescale_params` is True

    # Optimizer. Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    # scale the base learning rate linearly with the effective batch size (relative to a base of 64), capped at 8x
    effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps \
                           * accelerator.num_processes
    args.learning_rate = args.learning_rate * min(max(1, effective_batch_size / 64), 8)

    opt_class = mup.MuAdamW if args.mu_transfer else torch.optim.AdamW
    optimizer = opt_class(optimizer_grouped_parameters, lr=args.learning_rate,
                          betas=(args.adam_beta_1, args.adam_beta_2), eps=args.adam_eps)

    # DataLoaders creation:
    collate_fn = default_data_collator if args.llama_config is not None else get_maskgit_collator(config)
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=collate_fn,
        batch_size=args.per_device_train_batch_size, num_workers=8, pin_memory=True,
    )

    # Shuffle eval dataset and then set shuffle=False on the dataloader.
    # Shuffling in the dataloader results in reshuffling with each iteration.
    eval_dataset.valid_start_inds = torch.tensor(eval_dataset.valid_start_inds)[
        torch.randperm(len(eval_dataset), generator=torch.Generator().manual_seed(0))
    ].tolist()

    eval_dataloader = DataLoader(
        eval_dataset, shuffle=False, collate_fn=collate_fn,
        batch_size=args.per_device_eval_batch_size, pin_memory=True, num_workers=8,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        if args.max_train_steps < 2000 and args.resume_from_checkpoint is None:  # minimum number of training steps
            args.max_train_steps = 2000
            args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

        overrode_max_train_steps = True

    if args.lr_scheduler_type == "custom_cosine":  # decay to `end_ratio` of the peak learning rate
        def get_lr_wrapper(warmup_steps, max_steps, end_ratio=0.1):
            def get_lr(step):
                if step < warmup_steps:
                    return (step + 1) / warmup_steps

                remaining_steps = max_steps - warmup_steps
                return ((1 + math.cos(math.pi * (step - warmup_steps) / remaining_steps)) / 2) \
                    * (1 - end_ratio) + end_ratio
            return get_lr

        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, get_lr_wrapper(args.num_warmup_steps * accelerator.num_processes,
                                      args.max_train_steps if overrode_max_train_steps
                                      else args.max_train_steps * accelerator.num_processes)
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
            num_training_steps=args.max_train_steps
            if overrode_max_train_steps
            else args.max_train_steps * accelerator.num_processes,
        )

    # Enable gradient checkpointing to save memory
    if args.gradient_checkpointing:
        logger.info("Enabling gradient checkpointing")
        model.gradient_checkpointing_enable()
        model.config.use_cache = False # incompatible with grad checkpointing

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    if not args.no_compile:
        torch._dynamo.config.cache_size_limit = 256
        torch._dynamo.config.optimize_ddp = False  # https://github.com/pytorch/pytorch/issues/104674
        # TODO: https://github.com/pytorch/pytorch/issues/109774#issuecomment-2046633776
        model = torch.compile(model)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
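    # `|` merges the two dicts; on duplicate keys the right operand (the Genie config) takes precedence (Python 3.9+).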
    experiment_config = vars(args) | vars(config)
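    # tokens per training example: latent_side_len^2 tokens per frame, times window_size frames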
    seq_len = latent_side_len**2 * args.window_size
    args.num_datasets = 1

    model_module = model.module if hasattr(model, "module") else model
    experiment_config.update({
        "model_parameters": sum(p.numel() for p in model_module.parameters()),
        "model_parameters_M": round(sum(p.numel() for p in model_module.parameters()) / 1e6),
        "trunk_parameters": sum(p.numel() for p in model_module.decoder.parameters()),
        "trunk_parameters_M": round(sum(p.numel() for p in model_module.decoder.parameters()) / 1e6),
        "seq_len": seq_len,
        "hz": hz / train_dataset.stride,
        "train_data_tokens": len(train_dataset) * seq_len, # only one epoch
        "effective_batch_size": effective_batch_size,
        "effective_batch_size_tokens": effective_batch_size * seq_len,
        "mixed_precision": accelerator.mixed_precision,
        "num_datasets": 1
    })

    print("============================")
    print(f"model parameters: {experiment_config['model_parameters_M']}M")
    print("============================")

    experiment_config["FLOPs_per_update_step"] = 6 * experiment_config["model_parameters"] \
                                                 * experiment_config["effective_batch_size_tokens"]
    accelerator.init_trackers(project_name="video", config=experiment_config)
    train(accelerator, model, optimizer, lr_scheduler, train_dataloader, eval_dataloader, experiment_config, config, args)

if __name__ == "__main__":
    main()