#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
import sys

import datasets
import numpy as np
import torch
import transformers
from aac_metrics import evaluate
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoTokenizer,
    BartConfig,
    get_inverse_sqrt_schedule,
    get_scheduler,
)

from data.collator import DataCollatorForEnClapBart
from data.preprocess import Preprocessor
from modeling.enclap_bart import EnClapBartForConditionalGeneration

logger = get_logger(__name__)
metric_list = ["meteor", "spider"]


def main():
    # Load Configuration
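    # The script takes a single positional argument: the path to an OmegaConf YAML
    # config, e.g. `python train.py configs/train.yaml` (the config path shown here
    # is only illustrative).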
    cfg_path = sys.argv[1]
    args = OmegaConf.load(cfg_path)

    # Set up tracking and DDP kwargs for the Accelerator
    accelerator_log_kwargs = {}
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir

    # Initialize Accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        split_batches=args.split_batches,
        kwargs_handlers=[ddp_kwargs],
        **accelerator_log_kwargs,
    )
    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            with open(os.path.join(args.output_dir, "args.yaml"), "w") as f:
                OmegaConf.save(args, f)
    accelerator.wait_for_everyone()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    file_handler = logging.FileHandler(os.path.join(args.output_dir, "train_log.txt"))
    logger.logger.addHandler(file_handler)
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Get the datasets
    data_files = {}
    data_files_eval = {}
    if args.train_file is not None:
        data_files["train"] = args.train_file
    if args.validation_file is not None:
        data_files_eval["validation"] = args.validation_file

    extension = args.train_file.split(".")[-1]
    raw_datasets = load_dataset(extension, data_files=data_files)
    raw_datasets_eval = load_dataset(extension, data_files=data_files_eval)

    # Load pretrained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
    if args.config_name_or_path is not None:
        config = BartConfig.from_pretrained(args.config_name_or_path)
    else:
        config = None

    if args.model_name_or_path is not None:
        if config is None:
            model = EnClapBartForConditionalGeneration.from_pretrained(
                args.model_name_or_path
            )
        else:
            model = EnClapBartForConditionalGeneration.from_pretrained(
                args.model_name_or_path, config=config
            )
    else:
        model = EnClapBartForConditionalGeneration(config=config)

    # Set the maximum target length used for validation generation
    if args.val_max_target_length is None:
        args.val_max_target_length = args.max_target_length

    # Set max encodec length based on the shape of the positional encoding
    max_encodec_length = model.config.max_position_embeddings - 2
    label_pad_token_id = (
        -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )
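    # The project-local Preprocessor (data/preprocess.py) is expected to load the
    # EnCodec codes and CLAP embeddings from the given base paths, tokenize the
    # captions, and prepare the evaluation references.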
    preprocessor = Preprocessor(
        args.encodec_base_path,
        args.clap_base_path,
        tokenizer,
        model.config.max_position_embeddings,
        args.encodec_masking_prob,
        args.encodec_masking_span,
        label_pad_token_id,
        model.config.encodec_vocab_size,
        args.eval_num_captions,
    )

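    # Run the dataset mapping on the main process first so the other processes can
    # reuse the cached preprocessed dataset.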
    with accelerator.main_process_first():
        train_dataset = raw_datasets["train"].map(
            preprocessor.preprocess_train,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
        train_dataset.set_format(
            "pt",
            columns=[
                "input_ids",
                "attention_mask",
                "clap",
                "labels",
                "decoder_attention_mask",
            ],
        )

        # Preprocess the validation split.
        eval_dataset = raw_datasets_eval["validation"].map(
            preprocessor.preprocess_eval,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
        eval_dataset.set_format(
            "pt",
            columns=["input_ids", "attention_mask", "clap"],
            output_all_columns=True,
        )

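    # The training collator additionally applies EnCodec span masking; the
    # validation collator leaves the codes untouched.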
    train_data_collator = DataCollatorForEnClapBart(
        tokenizer=tokenizer,
        model=model,
        return_tensors="pt",
        label_pad_token_id=label_pad_token_id,
        max_length=max_encodec_length,
        encodec_masking_prob=args.encodec_masking_prob,
        encodec_masking_span=args.encodec_masking_span,
    )
    valid_data_collator = DataCollatorForEnClapBart(
        tokenizer=tokenizer,
        model=model,
        return_tensors="pt",
        label_pad_token_id=label_pad_token_id,
        max_length=max_encodec_length,
    )

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=train_data_collator,
        batch_size=args.per_device_train_batch_size,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        collate_fn=valid_data_collator,
        batch_size=args.per_device_eval_batch_size,
    )

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

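    # transformers' inverse-sqrt schedule takes an optional `timescale`, so it is
    # built directly when `time_scale` is configured; otherwise fall back to the
    # generic get_scheduler factory.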
    if args.lr_scheduler_type == "inverse_sqrt" and hasattr(args, "time_scale"):
        lr_scheduler = get_inverse_sqrt_schedule(
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps,
            timescale=args.time_scale,
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps,
            num_training_steps=args.max_train_steps,
        )

    # Prepare everything with our `accelerator`.
    (
        model,
        optimizer,
        train_dataloader,
        eval_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how often we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and str(checkpointing_steps).isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # The trackers initialize automatically on the main process.
    if args.with_tracking:
        accelerator.init_trackers(args.logging_dir)

    # Train!
    total_batch_size = (
        args.per_device_train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )

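    # With split_batches=True, each process receives a slice of one batch, so the
    # num_processes factor is removed from the effective batch size.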
    if args.split_batches:
        total_batch_size = int(total_batch_size / accelerator.num_processes)

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    completed_steps = 0
    starting_epoch = 0
    resume_step = None
    # Potentially load in the weights and states from a previous save
    if not args.overwrite_output_dir and os.path.exists(
        os.path.join(args.output_dir, "checkpoints")
    ):
        if args.resume_from_checkpoint is not None:
            accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
            accelerator.load_state(args.resume_from_checkpoint)
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = [
                f
                for f in os.scandir(os.path.join(args.output_dir, "checkpoints"))
                if f.is_dir()
            ]
            # Sort by creation time; the most recent checkpoint is last
            dirs.sort(key=os.path.getctime)
            path = dirs[-1].name
            accelerator.print(f"Resumed from checkpoint: {dirs[-1].path}")
            accelerator.load_state(dirs[-1].path)
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = starting_epoch * num_update_steps_per_epoch
        else:
            # need to multiply `gradient_accumulation_steps` to reflect real steps
            resume_step = (
                int(training_difference.replace("step_", ""))
                * args.gradient_accumulation_steps
            )
            starting_epoch = resume_step // len(train_dataloader)
            resume_step -= starting_epoch * len(train_dataloader)
            completed_steps = resume_step // args.gradient_accumulation_steps

    # Initialize running loss accumulators for logging
    if args.with_tracking:
        total_loss = 0
        logging_loss = 0
        before_epoch_loss = 0

        if args.encodec_masking_prob > 0:
            total_encodec_loss = 0
            logging_encodec_loss = 0
            before_epoch_encodec_loss = 0

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if (
            args.resume_from_checkpoint
            and epoch == starting_epoch
            and resume_step is not None
        ):
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader, resume_step
            )
        else:
            active_dataloader = train_dataloader
        logger.info(f"***** Running epoch {epoch} *****")
        epoch_iterator = tqdm(
            active_dataloader,
            desc="Training",
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
            colour="CYAN",
        )
        for step, batch in enumerate(epoch_iterator):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # Keep a running total of the losses for logging
                if args.with_tracking:
                    total_loss += outputs.lm_loss.item()
                    if args.encodec_masking_prob > 0:
                        if outputs.encodec_loss is not None:
                            total_encodec_loss += outputs.encodec_loss.item()
                accelerator.backward(loss)
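                # Clip only when gradients are synchronized, i.e. at the end of an
                # accumulation cycle.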
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model.parameters(), max_norm=args.max_grad_norm
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    completed_steps += 1
                    # Loss accumulators only exist when tracking is enabled
                    if args.with_tracking:
                        # Add loss information to tqdm
                        epoch_iterator.set_postfix(loss=total_loss / completed_steps)

                        if completed_steps % args.logging_steps == 0:
                            train_log = {
                                "train/learning_rate": lr_scheduler.get_last_lr()[0]
                            }
                            train_log["train/loss"] = (
                                total_loss - logging_loss
                            ) / args.logging_steps
                            logging_loss = total_loss
                            if args.encodec_masking_prob > 0:
                                train_log["train/encodec_loss"] = (
                                    total_encodec_loss - logging_encodec_loss
                                ) / args.logging_steps
                                logging_encodec_loss = total_encodec_loss
                            accelerator.log(train_log, step=completed_steps)

            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps }"
                    if args.output_dir is not None:
                        output_dir = os.path.join(
                            args.output_dir, "checkpoints", output_dir
                        )
                    accelerator.save_state(output_dir)

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        gen_kwargs = {
            "max_length": args.val_max_target_length,
        }
        predictions = []
        references = []
        eval_iterator = tqdm(
            eval_dataloader,
            desc="Validation",
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
            colour="MAGENTA",
        )
        for step, batch in enumerate(eval_iterator):
            with torch.no_grad():
                batch["input_ids"] = batch["input_ids"].cuda()
                batch["clap"] = batch["clap"].cuda()
                batch["attention_mask"] = batch["attention_mask"].cuda()
                batch["eos_mask"] = batch["eos_mask"].cuda()

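                # Unwrap the model so generate() can be called directly (the DDP
                # wrapper does not expose it); `clap` and `eos_mask` are extra
                # inputs expected by the EnClap BART model.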
                generated_tokens = accelerator.unwrap_model(model).generate(
                    batch["input_ids"],
                    clap=batch["clap"],
                    attention_mask=batch["attention_mask"],
                    eos_mask=batch["eos_mask"],
                    **gen_kwargs,
                )

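                # Pad generations to a common length across processes before decoding.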
                generated_tokens = accelerator.pad_across_processes(
                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                )
                generated_tokens = generated_tokens.cpu().numpy()
                captions = batch["captions"]

                if isinstance(generated_tokens, tuple):
                    generated_tokens = generated_tokens[0]
                decoded_preds = tokenizer.batch_decode(
                    generated_tokens, skip_special_tokens=True
                )

                predictions.extend(decoded_preds)
                references.extend(captions)

        logger.info("Evaluating predictions...")
        result = evaluate(predictions, references, metrics=metric_list)

        # Gather the corpus-level scores across processes
        result = {k: v.to(accelerator.device) for k, v in result[0].items()}
        result = accelerator.gather_for_metrics(result)
        # Log the average of metrics among the processes
        if accelerator.num_processes > 1:
            result = {f"eval/{k}": round(v.mean().item(), 4) for k, v in result.items()}
        else:
            result = {f"eval/{k}": round(v.item(), 4) for k, v in result.items()}
        logger.info(result)

        if args.with_tracking:
            result["train/epoch_train_loss"] = (total_loss - before_epoch_loss) / len(
                train_dataloader
            )
            result["train/steps"] = completed_steps
            before_epoch_loss = total_loss
            if args.encodec_masking_prob > 0:
                result["train/epoch_encodec_loss"] = (
                    total_encodec_loss - before_epoch_encodec_loss
                ) / len(train_dataloader)
                before_epoch_encodec_loss = total_encodec_loss
            accelerator.log(result, step=epoch)

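        # Optionally save a checkpoint (and the model config) at the end of each epoch.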
        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, "checkpoints", output_dir)
            accelerator.save_state(output_dir)
            if accelerator.is_main_process:
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.config.save_pretrained(output_dir)

    if args.output_dir is not None:
        save_dir = os.path.join(args.output_dir, "final")
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            save_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(save_dir)


if __name__ == "__main__":
    main()