import time
from contextlib import suppress
import numpy as np

import torch
from tqdm import tqdm
import datetime
import os
import gc
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)

from torch.utils.tensorboard import SummaryWriter
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d %I:%M:%S',
)

def get_cast_dtype(precision: str):
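    """Map a precision string to the torch dtype used for casting inputs (None means full precision)."""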
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype


def get_autocast(precision):
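    """Return a context-manager factory: CUDA autocast for AMP precisions, otherwise a no-op (suppress)."""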
    if precision == "amp_fp16":
        return lambda: torch.cuda.amp.autocast(dtype=torch.float16)
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress


def get_sync(model, flag):
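    """Return suppress when flag is True (sync gradients as usual), otherwise model.no_sync() to skip the all-reduce while accumulating."""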
    if flag:
        return suppress
    else:
        return lambda: model.no_sync()


def train_one_epoch(
    args,
    model,
    laion_loader,
    pile_loader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    writer: SummaryWriter,
    optim_groups,
    scaler,
    total_laion_token: int,
    total_pile_token: int,
    total_laion_sample: int,
    total_step: int,
):
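    """Train for one epoch over the paired LAION and Pile dataloaders.

    Every args.pile_freq steps the pile (text-only) batch is concatenated to the
    LAION batch; labels mask the grounding special tokens with -100. Gradients are
    accumulated over args.gradient_accumulation_steps micro-batches, losses are
    logged to TensorBoard every args.logging_steps, and checkpoints are saved every
    args.save_interval steps.
    """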
    world_size = torch.distributed.get_world_size()
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

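    # Look up the vocabulary ids of the grounding special tokens ([-1] takes the last sub-token id).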
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    if args.add_box:
        box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
        endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
        endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
    if args.use_format_v2:
        prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
        previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    if args.rank == 0:
        logging.info(f"train from: {total_step} step")
    model.train()
    # loop through dataloader
    last_logging_step = total_step
    last_save_step = total_step
    for num_steps, (batch_laion, batch_pile) in tqdm(
        enumerate(zip(laion_loader, pile_loader)),
        disable=args.rank != 0 or "SLURM_PROCID" in os.environ,
        total=args.num_steps * args.gradient_accumulation_steps,
        initial=total_step * args.gradient_accumulation_steps,
    ):
        #### LAION FORWARD PASS ####
        images = (
            batch_laion[0]
            .to(device_id, dtype=cast_dtype, non_blocking=True)
            .unsqueeze(1)
            .unsqueeze(1)
        )
        image_nums = batch_laion[1]
        image_start_index_list = batch_laion[2]

        # TODO: OPT model: input_ids is not started with </s> while input_ids2 is?
        input_ids = batch_laion[3].to(device_id, non_blocking=True).long()
        attention_mask = batch_laion[4].to(device_id, dtype=cast_dtype, non_blocking=True)
        added_bbox_list = [x.to(device_id) for x in batch_laion[5]]  # batch_laion[5] is a list of tensors; move each to the device
        total_laion_token += int(attention_mask.sum().long()) * world_size
        total_laion_sample += sum(image_nums) * world_size

        labels = input_ids.clone()
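        # Mask positions that should not contribute to the LM loss; -100 is ignored by CrossEntropyLoss.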
        if args.add_box:
            labels[input_ids == visual_token_id] = -100
            labels[input_ids == box_token_id] = -100
            labels[input_ids == endofattr_token_id] = -100
            if args.use_format_v2:
                labels[input_ids == previsual_token_id] = -100
                labels[input_ids == prebox_token_id] = -100
                labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
                labels[torch.roll(input_ids == box_token_id, 1)] = -100
        labels[:, 0] = -100
        labels[input_ids == tokenizer.pad_token_id] = -100
        labels[input_ids == media_token_id] = -100
        labels[input_ids == endofmedia_token_id] = -100
        labels = labels.to(device_id)
        current_laion_num = input_ids.shape[0]

        #### PILE FORWARD PASS ####
        if batch_pile is not None and batch_pile[0] is not None and batch_pile[1] is not None:
            input_ids2 = batch_pile[0].to(device_id, non_blocking=True).long()
            attention_mask2 = batch_pile[1].to(device_id, dtype=cast_dtype, non_blocking=True)
            input_length = input_ids.shape[-1]
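            # Right-pad the pile batch to the LAION sequence length so the two batches can be concatenated.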

            pad_len = input_length - input_ids2.shape[1]
            pad_ids = torch.full((input_ids2.shape[0], pad_len), tokenizer.pad_token_id,
                                 device=input_ids2.device, dtype=input_ids2.dtype)
            pad_mask = torch.zeros((attention_mask2.shape[0], pad_len),
                                   device=attention_mask2.device, dtype=attention_mask2.dtype)
            input_ids2 = torch.cat([input_ids2, pad_ids], dim=-1)
            attention_mask2 = torch.cat([attention_mask2, pad_mask], dim=-1)

            labels2 = input_ids2.clone()
            labels2[labels2 == tokenizer.pad_token_id] = -100
            labels2[:, 0] = -100
            labels2 = labels2.to(device_id)

            if (num_steps != 0 and num_steps % args.pile_freq == 0) or args.pile_freq == 1:
                image_nums = image_nums + [0] * len(input_ids2)
                image_start_index_list = image_start_index_list + [[]] * len(input_ids2)
                input_ids = torch.cat([input_ids, input_ids2], dim=0)
                attention_mask = torch.cat([attention_mask, attention_mask2], dim=0)
                labels = torch.cat([labels, labels2], dim=0)
                total_pile_token += int(attention_mask2.sum().long()) * world_size
            else:
                del input_ids2
                del attention_mask2
                del labels2

        if args.instruct:
            answer_token_id = tokenizer(" Answer").input_ids[0]
            answer_token_loc = (input_ids == answer_token_id).nonzero()
            for batch_idx, idx in answer_token_loc:
                labels[batch_idx][:idx+2] = -100
        
        if args.relation and not args.instruct:
            relations = batch_laion[6]
        else:
            relations = None
        if len(added_bbox_list) == 0:
            added_bbox_list = None
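        # Step the optimizer only once every args.gradient_accumulation_steps micro-batches.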
        update_flag = (num_steps != 0 and num_steps % args.gradient_accumulation_steps == 0) or args.gradient_accumulation_steps == 1
        # do_sync = get_sync(model, update_flag)
        with autocast():
            # NOTE: per-token losses are required below, i.e. the HF modeling files were
            # modified to use CrossEntropyLoss(reduction="none"):
            #   /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/codegen/modeling_codegen.py
            #   /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/opt/modeling_opt.py
            outputs = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=added_bbox_list,
                add_box=args.add_box,
                relations=relations,
            )
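            # outputs.loss holds per-token losses; average each sample over its non-zero positions,
            # then split the batch back into its LAION and pile parts.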
            loss_total = outputs.loss.reshape(labels.shape[0], -1)
            loss_sample = loss_total.sum(-1) / (loss_total != 0).sum(-1)
            loss_sample_for_laion = loss_sample[:current_laion_num]
            nan_mask = torch.isnan(loss_sample_for_laion)
            if nan_mask.sum() > 0:
                logging.warning(f"caption NaN: {nan_mask}")
            if nan_mask.sum() == len(loss_sample_for_laion) or not model.valid:
                logging.info("WARNING: skip this caption loss due to some error")
                loss_laion = torch.tensor(0.0).cuda()
            else:
                loss_laion = loss_sample_for_laion[~nan_mask].mean()
            loss_caption = loss_laion
            divided_loss_laion = loss_laion / args.gradient_accumulation_steps
            if current_laion_num != loss_sample.shape[0]:
                loss_pile = loss_sample[current_laion_num:].mean()
            else:
                loss_pile = torch.tensor(0.0).cuda()
            divided_loss_pile = loss_pile / args.gradient_accumulation_steps

            if "detection_losses" in outputs:
                loss_det = outputs["detection_losses"]["loss"]
                loss_iou = outputs["detection_losses"]["loss_iou"]
                loss_obj = outputs["detection_losses"]["loss_obj"]
                loss_cls = outputs["detection_losses"]["loss_cls"]
            else:
                loss_det = torch.tensor(0.0).cuda()
                loss_iou = torch.tensor(0.0).cuda()
                loss_obj = torch.tensor(0.0).cuda()
                loss_cls = torch.tensor(0.0).cuda()

            if "loss_dict" in outputs:
                visual_loss_iou = outputs["loss_dict"][0]["loss_iou"]
                previsual_loss_iou = outputs["loss_dict"][1]["loss_iou"]
                visual_loss_obj = outputs["loss_dict"][0]["loss_obj"]
                previsual_loss_obj = outputs["loss_dict"][1]["loss_obj"]
            else:
                visual_loss_iou = torch.tensor(0.0).cuda()
                previsual_loss_iou = torch.tensor(0.0).cuda()
                visual_loss_obj = torch.tensor(0.0).cuda()
                previsual_loss_obj = torch.tensor(0.0).cuda()

            divided_loss_det = loss_det / args.gradient_accumulation_steps
            loss_rel = outputs.get("rel_loss", torch.tensor(0.0).cuda())
            divided_loss_rel = loss_rel / args.gradient_accumulation_steps
            loss = (
                divided_loss_laion * args.loss_multiplier_laion +
                divided_loss_pile * args.loss_multiplier_pile +
                divided_loss_det * args.loss_multiplier_det +
                divided_loss_rel * args.loss_multiplier_rel
            )

        scaler.scale(loss).backward()

        # for logging only
        loss = (
            loss_laion * args.loss_multiplier_laion
            + loss_pile * args.loss_multiplier_pile
            + loss_det * args.loss_multiplier_det
            + loss_rel * args.loss_multiplier_rel
        ).detach()

        # step optimizer and log
        if update_flag:
            #### MASK GRADIENTS FOR EMBEDDINGS ####
            # Note (anas): Do not apply weight decay to embeddings as it will break this function.
            # ! not an important point
            # if args.ddp:
            #     def mask_embedding(m):
            #         if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
            #             zero_mask = torch.zeros_like(m.weight.grad)
            #             zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
            #             zero_mask[endofmedia_token_id] = torch.ones_like(zero_mask[endofmedia_token_id])
            #             m.weight.grad = m.weight.grad * zero_mask
            #     model.apply(mask_embedding)
            total_step += 1
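            # Unscale gradients before clipping so the norm is computed on the true gradient values.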
            scaler.unscale_(optimizer)
            if args.ddp:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            else:
                model.clip_grad_norm_(1.0)
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()
            # https://github.com/facebookresearch/fairscale/issues/627
            model.zero_grad(set_to_none=True)

        if args.rank == 0 and total_step % args.logging_steps == 0 and total_step != last_logging_step:
            last_logging_step = total_step
            global_step = total_step
            lr = optimizer.param_groups[0]["lr"]
            writer.add_scalar("lr", lr, global_step)
            writer.add_scalar("scale", scaler.get_scale(), global_step)
            writer.add_scalar("loss_groundcaption", loss_laion.item(), global_step)
            writer.add_scalar("loss_laion", loss_caption.item(), global_step)
            writer.add_scalar("loss_pile", loss_pile.item(), global_step)
            writer.add_scalar("loss", loss.item(), global_step)
            writer.add_scalar("loss_det", loss_det.item(), global_step)
            writer.add_scalar("loss_iou", loss_iou.item(), global_step)
            writer.add_scalar("loss_obj", loss_obj.item(), global_step)
            writer.add_scalar("loss_cls", loss_cls.item(), global_step)
            if loss_rel.item() != 0:
                writer.add_scalar("loss_rel", loss_rel.item(), global_step)
            if args.use_format_v2:
                writer.add_scalar("loss_iou_visual", visual_loss_iou.item(), global_step)
                writer.add_scalar("loss_obj_visual", visual_loss_obj.item(), global_step)
                writer.add_scalar("loss_iou_previsual", previsual_loss_iou.item(), global_step)
                writer.add_scalar("loss_obj_previsual", previsual_loss_obj.item(), global_step)

            global_sample_num = total_laion_sample
            writer.add_scalar("loss_groundcaption_vs_sample_num", loss_laion.item(), global_sample_num)
            writer.add_scalar("loss_laion_vs_sample_num", loss_caption.item(), global_sample_num)
            writer.add_scalar("loss_pile_vs_sample_num", loss_pile.item(), global_sample_num)
            writer.add_scalar("loss_vs_sample_num", loss.item(), global_sample_num)
            writer.add_scalar("loss_det_vs_sample_num", loss_det.item(), global_sample_num)
            writer.add_scalar("loss_iou_vs_sample_num", loss_iou.item(), global_sample_num)
            writer.add_scalar("loss_obj_vs_sample_num", loss_obj.item(), global_sample_num)
            if loss_rel.item() != 0:
                writer.add_scalar("loss_rel_vs_sample_num", loss_rel.item(), global_sample_num)
            writer.add_scalar("lr_vs_sample_num", optimizer.param_groups[0]["lr"], global_sample_num)

            writer.add_scalar("loss_groundcaption_vs_token", loss_laion.item(), total_laion_token)
            writer.add_scalar("loss_laion_vs_token", loss_caption.item(), total_laion_token)
            writer.add_scalar("loss_pile_vs_token", loss_pile.item(), total_pile_token)
            writer.add_scalar("loss_det_vs_token", loss_det.item(), total_laion_token)
            writer.add_scalar("loss_iou_vs_token", loss_iou.item(), total_laion_token)
            writer.add_scalar("loss_obj_vs_token", loss_obj.item(), total_laion_token)
            writer.add_scalar("loss_cls_vs_token", loss_cls.item(), total_laion_token)
            if loss_rel.item() != 0:
                writer.add_scalar("loss_rel_vs_token", loss_rel.item(), total_laion_token)

            total_token = total_laion_token + total_pile_token
            writer.add_scalar("sample_num", global_sample_num, global_step)
            writer.add_scalar("total_laion_token", total_laion_token, global_step)
            writer.add_scalar("total_pile_token", total_pile_token, global_step)
            writer.add_scalar("total_token", total_token, global_step)
            logging.info(
                f"[{global_step}][{total_laion_sample}][{total_token}]. "
                f"total: {loss.item():.3f} // laion: {loss_caption.item():.3f} // pile: {loss_pile.item():.3f} // "
                f"iou: {loss_iou.item():.4f} // obj: {loss_obj.item():.4f} // "
                f"previsual_obj: {previsual_loss_obj.item():.4f} // visual_obj: {visual_loss_obj.item():.4f} // "
                f"previsual_iou: {previsual_loss_iou.item():.4f} // visual_iou: {visual_loss_iou.item():.4f} // "
                f"lr: {lr:.2e} // scale: {scaler.get_scale()}"
            )

        if total_step % args.save_interval == 0 and total_step != last_save_step:
            last_save_step = total_step
            torch.distributed.barrier()
            if args.ddp:
                cpu_state = model.state_dict()
                # if args.rank == 0:
                #     optimizer_state = optimizer.state_dict()
            else:
                save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
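                # Gather the full (unsharded) state dict onto CPU, materialized only on rank 0.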
                with FSDP.state_dict_type(
                    model, StateDictType.FULL_STATE_DICT, save_policy
                ):
                    cpu_state = model.state_dict()
                torch.distributed.barrier()
                # https://pytorch.org/docs/1.12/fsdp.html
                # need to pass optim_groups as optim_input
                # optimizer_state = FSDP.full_optim_state_dict(model, optimizer, optim_input=optim_groups)
            if args.rank == 0:
                checkpoint_dict = {
                    "model_state_dict": cpu_state,
                    # "optimizer_state_dict": optimizer_state,
                    "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                    "scaler_state_dict": scaler.state_dict(),
                    "total_pile_token": total_pile_token,
                    "total_laion_token": total_laion_token,
                    "total_laion_sample": total_laion_sample,
                    "total_step": total_step,
                }
                logging.info(f"Saving checkpoint to {args.run_name}/checkpoint_{total_step}.pt")
                torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{total_step}.pt")
                del checkpoint_dict
                if (
                    args.delete_previous_checkpoint
                    and total_step - args.save_interval > 0
                    and (total_step - args.save_interval) % args.skip_delete_pattern != 0
                ):
                    try:
                        os.remove(f"{args.run_name}/checkpoint_{total_step - args.save_interval}.pt")
                    except OSError:
                        pass
            torch.distributed.barrier()


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count