""" Main training script """

import argparse
import copy
import glob
import os
import random
import functools

import numpy as np
import torch
# torch.multiprocessing.set_sharing_strategy('file_system')
import wandb
from data2 import get_data
from distributed import init_distributed_device, world_info_from_env
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    CPUOffload,
    StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)

from train_utils import train_one_epoch
from transformers import (
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from open_flamingo import create_model_and_transforms
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler
from torch.distributed.optim import ZeroRedundancyOptimizer
import warnings
warnings.filterwarnings("ignore")
import logging
# Root logging setup for the training run. Timestamps use a 24-hour clock
# (%H): a 12-hour %I without an AM/PM marker (%p) would print 01:00 and 13:00
# identically, making log timelines ambiguous.
LOG_FORMAT = '%(asctime)s %(message)s'
LOG_DATEFMT = '%m/%d %H:%M:%S'
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=LOG_DATEFMT,
)

class FakeDataloader:
    """Endless stand-in for a dataloader that produces ``None`` on every step.

    ``main`` passes an instance as ``pile_loader`` when no pile shards are
    configured; presumably the training loop then just receives ``None``
    batches from it (confirm in ``train_one_epoch``). It never raises
    ``StopIteration``.
    """

    def __next__(self):
        placeholder_batch = None
        return placeholder_batch

    def __iter__(self):
        # The object is its own iterator, so iter() and next() share state
        # (there is no state to share, but this matches iterator protocol).
        return self

def random_seed(seed=42, rank=0):
    """Seed torch, NumPy and Python's ``random`` with ``seed + rank``.

    Passing the process rank gives each rank a distinct but reproducible
    stream; ``rank=0`` seeds every caller identically.
    """
    effective_seed = seed + rank
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(effective_seed)


def get_grouped_params(model, args):
    """Split ``model``'s parameters into weight-decay and no-weight-decay groups.

    A parameter is decayed unless its lower-cased name contains one of
    ``norm``, ``bn``, ``bias``, ``embed``, ``wte`` or ``flat_param``. The
    ``flat_param`` marker presumably targets parameters flattened by FSDP, so
    under FSDP only modules left unwrapped receive decay (confirm against the
    torch version in use). Parameters with ``requires_grad=False`` are
    included as well.

    Each parameter's group membership is logged on the main process.

    Args:
        model: an ``nn.Module`` (possibly DDP/FSDP-wrapped).
        args: namespace providing ``weight_decay``.

    Returns:
        A two-element list of optimizer param-group dicts: the decayed group
        with ``args.weight_decay`` first, then the undecayed group with 0.0.
    """
    no_decay_markers = ("norm", "bn", "bias", "embed", "wte", "flat_param")

    def apply_decay(name):
        lowered = name.lower()
        return not any(marker in lowered for marker in no_decay_markers)

    # Query the rank once rather than once per parameter. Without an
    # initialised process group (e.g. single-process use) act as rank 0
    # instead of raising from get_rank().
    dist = torch.distributed
    distributed_ready = dist.is_available() and dist.is_initialized()
    is_main = (not distributed_ready) or dist.get_rank() == 0

    params_with_wd, params_without_wd = [], []
    for name, param in model.named_parameters():
        if apply_decay(name):
            if is_main:
                logging.info(f"with wd: {name}")
            params_with_wd.append(param)
        else:
            if is_main:
                logging.info(f"without wd: {name}")
            params_without_wd.append(param)
    return [
        {"params": params_with_wd, "weight_decay": args.weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]


def lambda_policy_fn(module):
    """Return True for leaf modules that own a trainable ``weight``.

    Used as the ``lambda_fn`` of ``lambda_auto_wrap_policy`` in the LoRA/FSDP
    path. A module qualifies only if it has no child modules, its ``weight``
    attribute is not None, and that weight requires grad.
    """
    has_children = any(True for _ in module.named_children())
    if has_children:
        return False
    weight = getattr(module, "weight", None)
    if weight is None:
        return False
    return bool(weight.requires_grad)


def lambda_auto_wrap_policy(
    module: torch.nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn,
) -> bool:
    """
    Auto-wrap policy that delegates the wrap decision to a user predicate.

    If ``lambda_fn(submodule)`` returns ``True``, the submodule is wrapped as
    an FSDP unit.

    Return whether a module should be wrapped during auto wrapping.

    The first three parameters are required by :func:`_recursive_wrap`.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``False``, then this function must decide whether
            ``module`` should be wrapped as an FSDP instance or not. If
            ``True``, then the function is still recursing down the module
            tree as a part of the DFS.
        nonwrapped_numel (int): Parameter numel not yet wrapped (unused).
        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
            this module will be wrapped.
    """
    # While descending the tree, always keep recursing; only the final
    # wrap/no-wrap question is handed to the predicate, whose result is
    # returned unchanged.
    return True if recurse else lambda_fn(module)


def _build_arg_parser():
    """Return the ``argparse`` parser defining every CLI option of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_encoder_path", default="ViT-B-16", type=str)
    parser.add_argument("--vision_encoder_pretrained", default="laion2b_s34b_b88k", type=str)
    parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
    parser.add_argument(
        "--tokenizer_path",
        default="facebook/opt-1.3b",
        type=str,
        help="path to tokenizer",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default="openflamingo3B",
        help="used to name saving directory and wandb run",
    )
    parser.add_argument("--use_media_placement_augmentation", action="store_true")
    parser.add_argument("--offline", action="store_true")
    parser.add_argument("--num_steps", type=int, default=300000)
    parser.add_argument(
        "--logging_steps", type=int, default=10, help="log loss every n steps"
    )
    # Batch size for each data source.
    parser.add_argument("--batch_size_mmc4", type=int, default=128)
    parser.add_argument("--batch_size_laion", type=int, default=128)
    parser.add_argument("--batch_size_pile", type=int, default=128)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
        default=None,
    )
    parser.add_argument(
        "--delete_previous_checkpoint",
        action="store_true",
        help="delete previous checkpoint when saving new checkpoint",
    )
    parser.add_argument(
        "--laion_shards",
        type=str,
        help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument(
        "--mmc4_shards",
        type=str,
        help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument(
        "--pile_shards",
        type=str,
        default=None,
        help="path to pile shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning_rate", default=1e-4, type=float)
    parser.add_argument(
        "--lr_scheduler",
        default="constant",
        type=str,
        help="constant, linear, or cosine",
    )
    parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_laion", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_pile", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_det", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_rel", type=float, default=1.0)
    parser.add_argument("--loss_multiplier_attn", type=float, default=1.0)
    parser.add_argument("--warmup_steps", default=5000, type=int)
    # Weight decay is only applied to the YOLOX head when using FSDP.
    # https://medium.com/@huanghaian123/optimize-and-accelerate-yolox-with-rtmdet-hyps-in-mmyolo-80fc06d61159
    parser.add_argument("--weight_decay", default=0.05, type=float)
    parser.add_argument(
        "--precision",
        choices=["amp_fp16", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
        default="fp32",
        help="Floating point precision.",
    )
    # data args
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--dataset_resampled", action="store_true")
    # distributed training args
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    # wandb args
    parser.add_argument("--report_to_wandb", default=False, action="store_true")
    parser.add_argument("--wandb_project", type=str)
    parser.add_argument("--wandb_entity", type=str)
    parser.add_argument(
        "--save_checkpoints_to_wandb",
        default=False,
        action="store_true",
        help="save checkpoints to wandb",
    )
    parser.add_argument("--checkpoint_activations", default=False, action="store_true")
    parser.add_argument("--freeze_vision_encoder", default=False, action="store_true")
    parser.add_argument(
        "--mmc4_textsim_threshold",
        default=30,
        type=float,
        help="threshold for filtering images in mmc4 based on image-text similarity",
    )
    parser.add_argument("--location_token_num", default=1000, type=int)
    parser.add_argument("--vis_embed_size", type=int, required=False)
    parser.add_argument("--save_interval", default=1000, type=int, required=False)
    parser.add_argument("--skip_delete_pattern", default=1500, type=int, required=False)
    parser.add_argument("--ddp", default=False, action="store_true")
    parser.add_argument("--pile_freq", default=1, type=int, required=False)
    parser.add_argument("--restart", default=False, action="store_true")
    parser.add_argument("--lora", default=False, action="store_true")
    parser.add_argument("--lora_r", default=16, type=int, required=False)
    parser.add_argument("--single", default=False, action="store_true")

    # Finetune
    parser.add_argument("--instruct", default=False, action="store_true")
    parser.add_argument("--fix-ffn", default=False, action="store_true")
    parser.add_argument("--prob_ground", default=1.0, type=float, required=False)
    parser.add_argument("--optimizer", default="adamw", type=str, required=False)
    parser.add_argument("--add_visual_token", default=False, action="store_true")
    parser.add_argument("--use_format_v2", default=False, action="store_true")
    parser.add_argument("--use_sam", default=None, type=str, required=False)
    parser.add_argument("--max-length", default=608, type=int, required=False)
    parser.add_argument("--image-size", default=256, type=int, required=False)
    parser.add_argument("--reset_llm", default=False, action="store_true")
    parser.add_argument("--add_box", default=False, action="store_true")
    parser.add_argument("--add_pe", default=False, action="store_true")
    parser.add_argument("--only_grounded_sample", default=False, action="store_true")
    parser.add_argument("--expand", default=False, action="store_true")
    parser.add_argument("--delete_contained", default=False, action="store_true")

    parser.add_argument("--relation", default=False, action="store_true")
    parser.add_argument("--attn_reg", default="l1", type=str, required=False)
    parser.add_argument("--enhance_data", default=False, action="store_true")
    parser.add_argument("--no_visual", default=False, action="store_true")
    parser.add_argument("--no_previsual", default=False, action="store_true")
    parser.add_argument("--roi_align", default=False, action="store_true")
    parser.add_argument("--roi_output_size", default=4, type=int, required=False)
    parser.add_argument("--apply_mask", default=False, action="store_true")
    parser.add_argument("--longer_previsual", default=False, action="store_true")
    return parser


def main():
    """Run training from command-line options.

    Steps:

    1. Parse and validate the CLI flags.
    2. Initialise the distributed device.
    3. Build the model and wrap it in DDP (``--ddp``) or FSDP.
    4. Create the datasets, optimizer, LR scheduler and grad scaler.
    5. Optionally resume from the newest ``checkpoint_*.pt`` in
       ``args.run_name`` or from ``--resume_from_checkpoint``.
    6. Hand everything to ``train_one_epoch``.

    Checkpoint loading on the FSDP path is lenient:

    * Entries whose shape differs from the current model are replaced by the
      model's own weights.
    * Entries the model does not have are skipped.
    * Everything else is loaded with ``strict=False``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    # parser.error() rather than assert: these checks must survive `python -O`.
    if args.use_media_placement_augmentation:
        parser.error("Do not enable use_media_placement_augmentation")
    if args.no_previsual and not args.no_visual:
        parser.error("no_previsual MUST come with no_visual")
    if args.enhance_data:
        parser.error("dont enable enhance_data")

    if args.offline:
        os.environ["WANDB_MODE"] = "offline"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"

    args.local_rank, args.rank, args.world_size = world_info_from_env()
    print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
    device_id = init_distributed_device(args)

    # Same seed on every rank while the model is built (presumably so all
    # ranks start from identical initial weights).
    random_seed(args.seed)
    model, image_processor, tokenizer, args.vis_embed_size = create_model_and_transforms(
        args.vision_encoder_path,
        args.vision_encoder_pretrained,
        args.lm_path,
        args.tokenizer_path if args.tokenizer_path else args.lm_path,
        use_local_files=args.offline,
        use_media_placement_augmentation=args.use_media_placement_augmentation,
        checkpoint_activations=args.checkpoint_activations,
        freeze_vision_encoder=args.freeze_vision_encoder,
        location_token_num=args.location_token_num,
        lora=args.lora,
        lora_r=args.lora_r,
        fix_ffn=args.fix_ffn,
        add_visual_token=args.add_visual_token,
        add_box=args.add_box,
        add_pe=args.add_pe,
        add_relation=args.relation,
        use_format_v2=args.use_format_v2,
        use_sam=args.use_sam,
        enhance_data=args.enhance_data,
        roi_align=args.roi_align,
        roi_output_size=args.roi_output_size,
        apply_mask=args.apply_mask,
    )
    # --reset_llm: remember the freshly initialised language-model weights
    # (state_dict references, not copies) so they can overwrite the
    # checkpoint's lang_encoder entries when resuming.
    llm_state_dict = None
    if args.reset_llm:
        llm_state_dict = model.lang_encoder.state_dict()
    if args.rank == 0:
        print(args)
        print(image_processor)

    # Per-rank seed from here on.
    random_seed(args.seed, args.rank)

    if args.rank == 0 and args.report_to_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            config=vars(args),
        )

    device_id = args.rank % torch.cuda.device_count()
    if args.ddp:
        print("use ddp mode")
        model = model.to(device_id)
        model = DDP(model)
    else:
        fpSixteen = MixedPrecision(
            param_dtype=torch.float16,
            # Gradient communication precision.
            reduce_dtype=torch.float16,
        )
        from open_clip.transformer import ResidualAttentionBlock
        from open_flamingo.src.flamingo_lm import FlamingoLayer
        from transformers.models.opt.modeling_opt import OPTAttention
        from segment_anything.modeling.image_encoder import Block
        transformer_layer_cls = [
            FlamingoLayer,
            ResidualAttentionBlock,
            Block,
        ]
        if args.fix_ffn:
            transformer_layer_cls.append(OPTAttention)
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=transformer_layer_cls,
        )
        if args.lora:
            from torch.distributed.fsdp.wrap import _or_policy
            # Additionally wrap every trainable leaf module (the LoRA layers).
            lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
            auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
            ignored_modules = [model.vision_encoder]
        else:
            ignored_modules = [model.detection_head]
        if args.add_pe:
            ignored_modules += [model.pos_enc]
        model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=fpSixteen,
            device_id=torch.cuda.current_device(),
            ignored_modules=ignored_modules,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        )
        model = model.to(device_id)

    pile_dataset = None
    if args.instruct:
        laion_dataset = get_data(args, image_processor, tokenizer, "instruct")
    else:
        laion_dataset = get_data(args, image_processor, tokenizer, "ground_image_text")
    if args.pile_shards is not None:
        pile_dataset = get_data(args, image_processor, tokenizer, "pile")

    optim_groups = get_grouped_params(model, args)
    if args.ddp:
        optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
    elif args.optimizer == "adamw":
        print("use adamw")
        optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
    elif args.optimizer == "sgd":
        print("use sgd...")
        # NOTE: built from model.parameters(), so the weight-decay grouping
        # computed above is not applied to SGD.
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    else:
        raise NotImplementedError(f"unsupported optimizer: {args.optimizer}")

    total_training_steps = args.num_steps

    if args.rank == 0:
        logging.info(f"Total training steps: {total_training_steps}")

    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )
    scaler = GradScaler() if args.ddp else ShardedGradScaler()
    total_laion_token = 0
    total_pile_token = 0
    total_laion_sample = 0
    total_step = 0

    # An existing checkpoint for this run takes precedence over
    # --resume_from_checkpoint and disables --restart.
    if os.path.exists(f"{args.run_name}"):
        checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            if args.rank == 0:
                logging.info(f"Found no checkpoints for run {args.run_name}.")
        else:
            # Newest checkpoint = largest numeric suffix of checkpoint_<N>.pt.
            args.resume_from_checkpoint = sorted(
                checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
            if args.rank == 0:
                logging.info(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
            args.restart = False
            if args.rank == 0:
                logging.info("do not restart because an existed checkpoint is found")
    if args.resume_from_checkpoint is not None:
        if args.rank == 0:
            logging.info(f"Loading checkpoint from {args.resume_from_checkpoint}")
        # NOTE(review): torch.load unpickles arbitrary objects; only resume
        # from trusted checkpoint files.
        checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
        torch.distributed.barrier()
        if args.ddp:
            model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
        else:
            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
                ckpt_state_dict = checkpoint["model_state_dict"]
                if args.reset_llm:
                    for key in ckpt_state_dict:
                        if key.startswith("lang_encoder"):
                            if args.rank == 0:
                                logging.info(f"reset {key}")
                            llm_key = key.replace("lang_encoder.", "")
                            ckpt_state_dict[key] = llm_state_dict[llm_key]
                model_state_dict = model.state_dict()
                for key in ckpt_state_dict.keys():
                    if key not in model_state_dict:
                        # load_state_dict(strict=False) below ignores
                        # unexpected keys; don't crash on them here.
                        if args.rank == 0:
                            logging.info(f"{key}: not found in model, skipped shape check")
                        continue
                    if model_state_dict[key].shape != ckpt_state_dict[key].shape:
                        if args.rank == 0:
                            logging.info(f'{key}: shape mismatched! {model_state_dict[key].shape} vs {ckpt_state_dict[key].shape}')
                        # Keep the model's current weights for mismatched entries.
                        ckpt_state_dict[key] = model_state_dict[key].clone()
                del model_state_dict
                model.load_state_dict(ckpt_state_dict, strict=False)
        if not args.restart:
            # Optimizer and scaler states are intentionally not restored.
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
            total_laion_token = checkpoint.get("total_laion_token", 0)
            total_pile_token = checkpoint.get("total_pile_token", 0)
            total_laion_sample = checkpoint.get("total_laion_sample", 0)
            total_step = checkpoint.get("total_step", 0)
            if args.rank == 0:
                logging.info("load training statistics...")
        else:
            if args.rank == 0:
                logging.info("restart training / finetuning. only load model weight...")
        del checkpoint
        llm_state_dict = None
        torch.cuda.empty_cache()
        torch.distributed.barrier()
    # Release the saved LLM weights even when nothing was resumed; otherwise
    # they would stay referenced for the whole training run below.
    llm_state_dict = None

    model.train()
    if args.rank == 0:
        os.makedirs(args.run_name, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join(args.run_name, "tblog"))
    else:
        writer = None

    laion_dataset.set_epoch(total_step)
    laion_loader = laion_dataset.dataloader
    if pile_dataset is not None:
        pile_dataset.set_epoch(total_step)
        pile_loader = pile_dataset.dataloader
    else:
        pile_loader = FakeDataloader()
    train_one_epoch(
        args=args,
        model=model,
        tokenizer=tokenizer,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        laion_loader=laion_loader,
        pile_loader=pile_loader,
        device_id=device_id,
        writer=writer,
        scaler=scaler,
        optim_groups=optim_groups,
        total_laion_token=total_laion_token,
        total_pile_token=total_pile_token,
        total_laion_sample=total_laion_sample,
        total_step=total_step,
    )

if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()