import cv2
import numpy as np
import torch

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer import Trainer
from torch import Tensor


@torch.no_grad()
def unnormalize(
    images: Tensor,
    mean: tuple[float, ...] = (0.5, 0.5, 0.5),
    std: tuple[float, ...] = (0.5, 0.5, 0.5),
) -> Tensor:
    """Reverts the normalization transformation applied before ViT.

    Computes ``x * std[c] + mean[c]`` for every channel ``c`` on a copy of the
    batch, so the input tensor is left untouched.

    Args:
        images (Tensor): a batch of images indexed as (batch, channel, height, width)
        mean (tuple[float, ...]): the per-channel means used for normalization -
            defaults to (0.5, 0.5, 0.5)
        std (tuple[float, ...]): the per-channel stds used for normalization -
            defaults to (0.5, 0.5, 0.5)

    Returns:
        the un-normalized batch of images

    Note:
        Channels are paired with ``mean``/``std`` via ``zip``, so any channel
        beyond ``min(len(mean), len(std))`` is copied unchanged.
    """
    unnormalized_images = images.clone()
    for i, (m, s) in enumerate(zip(mean, std)):
        # In-place ops on the clone: x -> x * s + m for channel i.
        unnormalized_images[:, i, :, :].mul_(s).add_(m)

    return unnormalized_images


@torch.no_grad()
def smoothen(mask: Tensor, patch_size: int = 16) -> Tensor:
    """Smoothens a mask by downsampling it and re-upsampling it
    with bi-linear interpolation.

    The mask is first shrunk to roughly one value per ``patch_size`` x
    ``patch_size`` patch (nearest-neighbour), then resized back to its
    original resolution with bilinear interpolation, which blurs the hard
    patch boundaries.

    Args:
        mask (Tensor): a 2D float torch tensor with values in [0, 1]
        patch_size (int): the patch size in pixels

    Returns:
        a smoothened mask at the pixel level, with the same (height, width)
        shape and the same device as ``mask``
    """
    device = mask.device
    height, width = mask.shape

    # OpenCV's ``dsize`` argument is (width, height) -- the reverse of the
    # tensor's (rows, cols) order. Passing (height, width) would transpose
    # non-square masks. Clamp to at least one cell so masks smaller than
    # ``patch_size`` don't request an empty (invalid) target size.
    downsampled = cv2.resize(
        mask.cpu().numpy(),
        (max(1, width // patch_size), max(1, height // patch_size)),
        interpolation=cv2.INTER_NEAREST,
    )
    upsampled = cv2.resize(downsampled, (width, height), interpolation=cv2.INTER_LINEAR)

    # ``upsampled`` is a fresh array owned only by us, so sharing its memory is safe.
    return torch.from_numpy(upsampled).to(device)


@torch.no_grad()
def draw_mask_on_image(image: Tensor, mask: Tensor) -> Tensor:
    """Dims an image by scaling it element-wise with a mask.

    Where the mask is 1 the pixel is kept as-is, where it is 0 the pixel
    becomes black, and intermediate mask values scale the pixel proportionally.

    Args:
        image (Tensor): a float torch tensor with values in [0, 1]
        mask (Tensor): a float torch tensor with values in [0, 1]; it must be
            broadcastable against ``image``

    Returns:
        the image with parts of it dimmed according to the mask
    """
    return torch.mul(image, mask)


@torch.no_grad()
def draw_heatmap_on_image(
    image: Tensor,
    mask: Tensor,
    colormap: int = cv2.COLORMAP_JET,
) -> Tensor:
    """Blends a colour-mapped rendering of a mask onto an image.

    The mask is rendered with an OpenCV colormap and added to the image. The
    sum is then divided by its global maximum, so the brightest value becomes 1.

    Args:
        image (Tensor): a float torch tensor with values in [0, 1], laid out
            channels-first as (channels, height, width)
        mask (Tensor): a float torch tensor with values in [0, 1], laid out
            as (height, width)
        colormap (int): the OpenCV colormap to be used

    Returns:
        the image with the heatmap overlaid, channels-first, on the same
        device as ``image``
    """
    target_device = image.device

    # OpenCV operates on host-side numpy arrays in channels-last order.
    image_hwc = image.permute(1, 2, 0).cpu().numpy()
    mask_hw = mask.cpu().numpy()

    # applyColorMap takes 8-bit input and emits BGR; convert to RGB floats in [0, 1].
    colored_bgr = cv2.applyColorMap(np.uint8(255 * mask_hw), colormap)
    colored_rgb = cv2.cvtColor(colored_bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255

    # Additive blend, rescaled by the global maximum.
    blended = image_hwc + colored_rgb
    blended = blended / blended.max()

    overlay = torch.tensor(blended)
    return overlay.permute(2, 0, 1).to(target_device)


def _prepare_samples(
    images: Tensor, masks: Tensor
) -> tuple[tuple[Tensor, ...], list[float]]:
    """Prepares the samples for the masking/heatmap visualization.

    Args:
        images (Tensor): a batch of normalized images indexed as
            (batch, channels, height, width). They are un-normalized here with
            mean = std = 0.5 per channel, so they are expected to have been
            normalized that way.
        masks (Tensor): a batch of 2D float masks with values in [0, 1],
            one per image

    Returns:
        a tuple of per-sample image triplets (img, masked, heatmap) -- each
        triplet stacked vertically into a single channels-first tensor -- and
        the corresponding masking percentages (percentage of the smoothened
        mask that is masked out, i.e. ``100 * (1 - mean(mask))``)

    Raises:
        ValueError: if the number of images and masks differ
    """
    if len(images) != len(masks):
        raise ValueError(
            f"Got {len(images)} images but {len(masks)} masks; expected one mask per image"
        )
    if len(images) == 0:
        return (), []

    num_channels = images[0].shape[0]

    # Smoothen masks
    smooth_masks = [smoothen(m) for m in masks]

    # Un-normalize images; grayscale images are repeated to 3 channels so they
    # can be blended with the (RGB) heatmap.
    if num_channels == 1:
        rgb_images = [
            torch.repeat_interleave(img, 3, 0)
            for img in unnormalize(images, mean=(0.5,), std=(0.5,))
        ]
    else:
        rgb_images = list(unnormalize(images))

    # Build one vertical triplet (image, masked image, heatmap) per sample,
    # stacked along the height axis.
    samples = tuple(
        torch.cat(
            [
                image,
                draw_mask_on_image(image, mask),
                draw_heatmap_on_image(image, mask),
            ],
            dim=1,
        )
        for image, mask in zip(rgb_images, smooth_masks)
    )

    # Compute masking percentages (one reduction per mask, no re-stacking).
    masked_pixels_percentages = [
        100 * (1 - mask.mean().item()) for mask in smooth_masks
    ]

    return samples, masked_pixels_percentages


def log_masks(images: Tensor, masks: Tensor, key: str, logger: WandbLogger):
    """Logs a batch of images together with their masks to WandB.

    Every logged entry is a vertical triplet (original image, masked image,
    heatmap) whose caption reports how much of the image is masked out.

    Args:
        images (Tensor): a batch of normalized images; they are un-normalized
            internally before drawing
        masks (Tensor): a float torch tensor with values in [0, 1], one mask
            per image
        key (str): the key to log the images with
        logger (WandbLogger): the logger to log the images to
    """
    triplets, percentages = _prepare_samples(images, masks)
    captions = [f"Masking: {pct:.2f}% " for pct in percentages]

    logger.log_image(key=key, images=list(triplets), caption=captions)


class DrawMaskCallback(Callback):
    """Periodically logs VisionDiffMask masks for a fixed set of sample images."""

    def __init__(
        self,
        samples: tuple[Tensor, Tensor],
        log_every_n_steps: int = 200,
        key: str = "",
    ):
        """A callback that logs VisionDiffMask masks for the sample images to WandB.

        Args:
            samples (tuple[Tensor, Tensor]): an (images, labels) pair, e.g. one
                batch from a DataLoader; ``samples[0]`` holds the images and
                ``samples[1]`` the matching labels (any indexable pair works)
            log_every_n_steps (int): the interval in steps to log the masks to WandB
            key (str): the key to log the images with (allows for multiple batches)

        Raises:
            ValueError: if ``log_every_n_steps`` is not positive
        """
        super().__init__()
        if log_every_n_steps <= 0:
            raise ValueError(
                f"log_every_n_steps must be positive, got {log_every_n_steps}"
            )
        self.images = torch.stack(list(samples[0]))
        self.labels = [label.item() for label in samples[1]]
        self.log_every_n_steps = log_every_n_steps
        self.key = key

    def _log_masks(self, trainer: Trainer, pl_module: LightningModule):
        """Predicts masks for the stored sample images and logs them to WandB."""
        # Predict in eval mode, then restore whatever mode the module was in
        # before -- even if get_mask raises -- instead of forcing train mode.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                outputs = pl_module.get_mask(self.images)
        finally:
            pl_module.train(was_training)

        # Unnest outputs
        masks = outputs["mask"]
        kl_divs = outputs["kl_div"]
        pred_classes = outputs["pred_class"].cpu()

        # Prepare masked samples for logging
        samples, masked_pixels_percentages = _prepare_samples(self.images, masks)

        id2label = pl_module.model.config.id2label

        # Log with wandb
        trainer.logger.log_image(
            key="DiffMask " + self.key,
            images=list(samples),
            caption=[
                f"Masking: {masked_pixels_percentage:.2f}% "
                f"\n KL-divergence: {kl_div:.4f} "
                f"\n Class: {id2label[label]} "
                f"\n Predicted Class: {id2label[pred_class.item()]}"
                for masked_pixels_percentage, kl_div, label, pred_class in zip(
                    masked_pixels_percentages, kl_divs, self.labels, pred_classes
                )
            ],
        )

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule):
        """Moves the sample images to the module's device and logs a first set of masks."""
        # Transfer sample images to correct device
        self.images = self.images.to(pl_module.device)

        # Log sample images
        self._log_masks(trainer, pl_module)

    def on_train_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: dict,
        batch: tuple[Tensor, Tensor],
        batch_idx: int,
        unused: int = 0,
    ):
        """Logs masks every ``log_every_n_steps`` training batches."""
        # NOTE(review): batch_idx is presumably the per-epoch batch index in
        # Lightning, so this counts batches within an epoch rather than global
        # steps -- confirm that is the intended cadence.
        if batch_idx % self.log_every_n_steps == 0:
            self._log_masks(trainer, pl_module)