import os
import pathlib

from glob import glob

from argparse import ArgumentParser
import torch
import pytorch_lightning as pl
import numpy as np
import cv2
import random
import math
from torchvision import transforms


def do_training(hparams, model_constructor):
    """Construct a model and fit it with a PyTorch Lightning ``Trainer``.

    The model is built from ``vars(hparams)`` as passed in, i.e. before any of
    the trainer settings below are injected. ``hparams`` is then augmented in
    place with fixed trainer options (all GPUs, DDP, cudnn benchmark, synced
    batch norm), optional resume information, checkpoint callbacks and the
    wandb + TestTube loggers, and handed to ``pl.Trainer.from_argparse_args``.

    Args:
        hparams: parsed ``argparse.Namespace``; mutated in place.
        model_constructor: callable accepting every hparam as a keyword
            argument and returning a LightningModule.
    """
    model = model_constructor(**vars(hparams))

    # Fixed trainer configuration: every visible GPU, DDP, cudnn benchmark.
    hparams.gpus, hparams.accelerator, hparams.benchmark = -1, "ddp", True

    if hparams.dry_run:
        print("Doing a dry run")
        # NOTE(review): uses batch_size as the *number* of batches to
        # overfit; presumably just "a handful" — confirm intent.
        hparams.overfit_batches = hparams.batch_size

    if not hparams.no_resume:
        hparams = set_resume_parameters(hparams)

    # A fresh experiment (or one with no usable checkpoint) starts at version 0.
    if getattr(hparams, "version", None) is None:
        hparams.version = 0

    hparams.sync_batchnorm = True

    exp_name, version = hparams.exp_name, hparams.version
    tt_logger = pl.loggers.TestTubeLogger("checkpoints", name=exp_name, version=version)
    hparams.callbacks = make_checkpoint_callbacks(exp_name, version)
    hparams.logger = [get_wandb_logger(hparams), tt_logger]

    trainer = pl.Trainer.from_argparse_args(hparams)
    trainer.fit(model)
    

def get_default_argument_parser():
    """Return the command-line arguments shared by the training scripts.

    The parser is created with ``add_help=False``, presumably so it can be
    used as a ``parents=[...]`` entry of a script-specific parser that adds
    ``-h`` itself.

    Returns:
        argparse.ArgumentParser: parser defining ``num_nodes``, ``exp_name``
        (required), ``dry_run``, ``no_resume``, ``accumulate_grad_batches``,
        ``max_epochs`` and ``project_name``.
    """
    parser = ArgumentParser(add_help=False)
    parser.add_argument(
        "--num_nodes",
        type=int,
        default=1,
        help="number of nodes for distributed training",
    )

    parser.add_argument(
        "--exp_name", type=str, required=True, help="name your experiment"
    )

    # "--dry_run" is accepted as an alias so the flag matches the underscore
    # style of every other option; "--dry-run" keeps working.
    parser.add_argument(
        "--dry-run",
        "--dry_run",
        dest="dry_run",
        action="store_true",
        default=False,
        help="quick smoke run on a few batches of train/val/test",
    )

    parser.add_argument(
        "--no_resume",
        action="store_true",
        default=False,
        help="do not resume from the latest checkpoint, even if one exists",
    )

    parser.add_argument(
        "--accumulate_grad_batches",
        type=int,
        default=1,
        help="accumulate N batches for gradient computation",
    )

    parser.add_argument(
        "--max_epochs", type=int, default=200, help="maximum number of epochs"
    )

    parser.add_argument(
        "--project_name", type=str, default="lightseg", help="project name for logging"
    )

    return parser


def make_checkpoint_callbacks(exp_name, version, base_path="checkpoints", frequency=1):
    """Build the two ``ModelCheckpoint`` callbacks used for training.

    Both write into ``{base_path}/{exp_name}/version_{version}/checkpoints/``:

    * the first keeps ``last.ckpt`` up to date (``save_last=True``);
    * the second keeps the 3 checkpoints with the highest ``val_acc_epoch``,
      named ``result-{epoch}-{val_acc_epoch:.2f}``.

    Args:
        exp_name: experiment name (directory under ``base_path``).
        version: experiment version; ``None`` is treated as 0.
        base_path: root directory for all experiments.
        frequency: currently unused; kept for interface compatibility.

    Returns:
        list: ``[last_checkpoint_callback, best_val_callback]``.
    """
    if version is None:
        version = 0
    ckpt_dir = f"{base_path}/{exp_name}/version_{version}/checkpoints/"

    latest = pl.callbacks.ModelCheckpoint(
        dirpath=ckpt_dir,
        save_last=True,
        verbose=True,
    )
    best_by_val = pl.callbacks.ModelCheckpoint(
        dirpath=ckpt_dir,
        monitor="val_acc_epoch",
        mode="max",
        save_top_k=3,
        filename="result-{epoch}-{val_acc_epoch:.2f}",
        verbose=True,
    )
    return [latest, best_by_val]


def get_latest_version(folder):
    """Return the highest N for which a ``version_N`` directory exists in ``folder``.

    Only directories are considered (the glob pattern ends in ``/``). Entries
    whose suffix after ``version_`` is not a plain non-negative integer
    (e.g. ``version_old``) are ignored instead of aborting the search.

    Args:
        folder: experiment directory, e.g. ``./checkpoints/<exp_name>``.

    Returns:
        int | None: the largest version number, or ``None`` if there is none.
    """
    versions = []
    for path in glob(f"{folder}/version_*/"):
        suffix = pathlib.PurePath(path).name.partition("_")[2]
        # Skip stray directories rather than crashing in int().
        if suffix.isdecimal():
            versions.append(int(suffix))

    return max(versions, default=None)


def get_latest_checkpoint(exp_name, version):
    """Find the newest checkpoint of ``exp_name``, searching down from ``version``.

    Versions ``version, version - 1, ..., 0`` are scanned in order. Within
    ``./checkpoints/<exp_name>/version_<v>/checkpoints``, ``last.ckpt`` wins
    if present. Otherwise the ``epoch=*.ckpt`` file with the latest
    ``os.path.getctime`` is used. The first version that yields a checkpoint
    ends the search.

    Args:
        exp_name: experiment name.
        version: highest version to inspect; ``None`` or a negative value
            means there is nothing to search.

    Returns:
        tuple: ``(checkpoint_path, version)``, or ``(None, None)`` when no
        checkpoint exists at or below ``version``.
    """
    if version is None:
        return None, None

    while version > -1:
        folder = os.path.join(
            ".", "checkpoints", exp_name, f"version_{version}", "checkpoints"
        )

        last = os.path.join(folder, "last.ckpt")
        if os.path.exists(last):
            return last, version

        chkpts = glob(os.path.join(folder, "epoch=*.ckpt"))
        if chkpts:
            return max(chkpts, key=os.path.getctime), version

        version -= 1

    # Also reached when the starting version is negative (previously this
    # path hit an unbound local and raised NameError).
    return None, None


def set_resume_parameters(hparams):
    """Point ``hparams`` at the latest checkpoint of ``hparams.exp_name``, if any.

    When at least one ``version_N`` directory exists under
    ``./checkpoints/<exp_name>``, this sets ``hparams.resume_from_checkpoint``
    and ``hparams.version`` from ``get_latest_checkpoint``. Both may be ``None``
    if no checkpoint file was found. It also loads ``hparams.wandb_id`` from that
    version's ``wandb_id`` file when present. Without any version directory,
    ``hparams`` is returned untouched.

    Args:
        hparams: ``argparse.Namespace`` with at least ``exp_name``; mutated.

    Returns:
        The same ``hparams`` object.
    """
    version = get_latest_version(f"./checkpoints/{hparams.exp_name}")
    if version is None:
        return hparams

    latest, version = get_latest_checkpoint(hparams.exp_name, version)
    if latest is None:
        print(f"No checkpoint found for {hparams.exp_name}; starting from scratch")
    else:
        print(f"Resuming checkpoint {latest}, exp_version={version}")

    hparams.resume_from_checkpoint = latest
    hparams.version = version

    # Must be an f-string: the literal "{hparams.exp_name}" path never exists.
    wandb_file = f"checkpoints/{hparams.exp_name}/version_{version}/wandb_id"
    if os.path.exists(wandb_file):
        with open(wandb_file, "r") as f:
            hparams.wandb_id = f.read().strip()

    return hparams


def get_wandb_logger(hparams):
    """Create a ``WandbLogger`` whose run id is persisted next to the checkpoints.

    The id lives in ``checkpoints/<exp_name>/version_<version>/wandb_id``. If
    that file holds a non-empty id, the logger is created with it so the same
    wandb run id is reused. Otherwise a new run is created and its id is
    written to the file for the next restart.

    Side effect: sets ``hparams.wandb_id`` to the id read from disk, or to
    ``None`` when no id was stored.

    Args:
        hparams: namespace with ``exp_name``, ``version`` and ``project_name``.

    Returns:
        pl.loggers.WandbLogger
    """
    exp_dir = f"checkpoints/{hparams.exp_name}/version_{hparams.version}/"
    id_file = os.path.join(exp_dir, "wandb_id")

    hparams.wandb_id = None
    if os.path.exists(id_file):
        with open(id_file) as f:
            # An empty/whitespace-only file counts as "no id stored".
            hparams.wandb_id = f.read().strip() or None

    logger = pl.loggers.WandbLogger(
        save_dir="checkpoints",
        project=hparams.project_name,
        name=hparams.exp_name,
        id=hparams.wandb_id,
    )

    os.makedirs(exp_dir, exist_ok=True)

    if hparams.wandb_id is None:
        # Touching ``experiment`` makes wandb create the run, so that
        # ``logger.version`` holds its freshly assigned id.
        _ = logger.experiment
        run_id = logger.version
        # Persist whenever no id was loaded. The directory may already exist
        # (e.g. created by other loggers/callbacks) without an id file.
        if run_id is not None:
            with open(id_file, "w") as f:
                f.write(str(run_id))

    return logger


class Resize(object):
    """Resize a sample dict to a given size (width, height).

    The sample is a dict holding an ``"image"`` array and, optionally,
    ``"disparity"``, ``"depth"`` and ``"mask"`` arrays. It is modified in
    place and returned.
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
        letter_box=False,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as little as possible.  (Output size might be smaller than given size.)
                Defaults to "lower_bound".
            image_interpolation_method (int, optional):
                OpenCV interpolation flag used when resizing ``"image"``.
                Targets (disparity, depth, mask) always use nearest-neighbour.
                Defaults to cv2.INTER_AREA.
            letter_box (bool, optional):
                True: after resizing, zero-pad image and targets so they are
                centred in a (height, width) canvas. Only dimensions smaller
                than the target are padded, so this mainly matters with
                keep_aspect_ratio=True and resize_method="upper_bound".
                Defaults to False.
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method
        self.__letter_box = letter_box

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Round ``x`` to a multiple of ``ensure_multiple_of``.

        Rounds to the nearest multiple, then floors if that exceeds
        ``max_val``, then ceils if the result is below ``min_val``.
        """
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        """Return the output ``(width, height)`` for an input of the given size.

        Raises:
            ValueError: if ``resize_method`` is not one of the supported values.
        """
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def make_letter_box(self, sample):
        """Zero-pad ``sample`` so it is centred in a (height, width) canvas.

        An odd leftover pixel goes to the bottom/right edge, so the result
        matches the target size exactly. Dimensions already at or above the
        target size are not padded (negative borders would make OpenCV fail).
        """
        pad_h = max(self.__height - sample.shape[0], 0)
        pad_w = max(self.__width - sample.shape[1], 0)
        top = pad_h // 2
        bottom = pad_h - top
        left = pad_w // 2
        right = pad_w - left
        sample = cv2.copyMakeBorder(
            sample, top, bottom, left, right, cv2.BORDER_CONSTANT, None, 0
        )
        return sample

    def __call__(self, sample):
        """Resize ``sample`` in place and return it.

        ``"image"`` uses the configured interpolation. When ``resize_target``
        is set, any of ``"disparity"``, ``"depth"`` and ``"mask"`` present in
        the sample are resized with nearest-neighbour. The mask is returned
        as a bool array.
        """
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__letter_box:
            sample["image"] = self.make_letter_box(sample["image"])

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

                if self.__letter_box:
                    sample["disparity"] = self.make_letter_box(sample["disparity"])

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

                if self.__letter_box:
                    sample["depth"] = self.make_letter_box(sample["depth"])

            # Guarded like the other targets so mask-less samples don't KeyError.
            if "mask" in sample:
                # Resized as float32, presumably because cv2.resize does not
                # accept bool arrays; converted back to bool afterwards.
                sample["mask"] = cv2.resize(
                    sample["mask"].astype(np.float32),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

                if self.__letter_box:
                    sample["mask"] = self.make_letter_box(sample["mask"])

                sample["mask"] = sample["mask"].astype(bool)

        return sample