from collections.abc import Mapping
from pathlib import Path
from typing import Any

import pytorch_lightning as pl
import torch
import torchvision
import wandb


class EvalSaveCallback(pl.Callback):
    """Save each validation/test batch and its step outputs to disk.

    For every validation and test batch whose step returned a non-empty
    mapping, ``torch.save`` writes a dict containing the batch's ``image``
    (stored under ``"fpv"``), ``seg_masks`` and ``name`` entries plus the
    step's ``output`` and ``valid_bev`` entries to
    ``<save_dir>/<batch_idx:06d>_<name[0]>.pt``.

    The filename depends only on ``batch_idx`` and the first sample name, so a
    later validation/test pass overwrites the files written by an earlier one.
    """

    def __init__(self, save_dir: Path | str) -> None:
        """
        Args:
            save_dir: directory that receives the ``.pt`` files; it is created
                on the first save if it does not exist yet.
        """
        super().__init__()
        # Normalize so plain strings also work with the ``/`` operator in ``save``.
        self.save_dir = Path(save_dir)

    def save(self, outputs, batch, batch_idx):
        """Write one batch and its step outputs to a single ``.pt`` file.

        Args:
            outputs: step outputs; must contain ``"output"`` and ``"valid_bev"``.
            batch: must contain ``"image"``, ``"seg_masks"`` and ``"name"``.
                ``name[0]`` is embedded in the filename -- presumably ``name``
                is the collated list of sample names (TODO confirm).
            batch_idx: batch index, zero-padded to six digits in the filename.
        """
        name = batch['name']

        # Create lazily so constructing the callback touches no filesystem.
        self.save_dir.mkdir(parents=True, exist_ok=True)
        filename = self.save_dir / f"{batch_idx:06d}_{name[0]}.pt"
        torch.save({
            "fpv": batch['image'],
            "seg_masks": batch['seg_masks'],
            'name': name,
            "output": outputs["output"],
            "valid_bev": outputs["valid_bev"],
        }, filename)

    @staticmethod
    def _has_outputs(outputs: object) -> bool:
        """Return True if ``outputs`` is a non-empty mapping that can be saved."""
        # A bare ``not outputs`` raises "Boolean value of Tensor with more than
        # one value is ambiguous" when a step returns a tensor. ``save`` needs
        # key access anyway, so anything that is not a mapping is skipped.
        return isinstance(outputs, Mapping) and len(outputs) > 0

    def on_test_batch_end(self, trainer: pl.Trainer,
                          pl_module: pl.LightningModule,
                          outputs: torch.Tensor | Any | None,
                          batch: Any,
                          batch_idx: int,
                          dataloader_idx: int = 0) -> None:
        """Save the test batch when the test step returned savable outputs."""
        if self._has_outputs(outputs):
            self.save(outputs, batch, batch_idx)

    def on_validation_batch_end(self, trainer: pl.Trainer,
                                pl_module: pl.LightningModule,
                                outputs: torch.Tensor | Any | None,
                                batch: Any,
                                batch_idx: int,
                                dataloader_idx: int = 0) -> None:
        """Save the validation batch when the step returned savable outputs."""
        if self._has_outputs(outputs):
            self.save(outputs, batch, batch_idx)


class ImageLoggerCallback(pl.Callback):
    """Log input images, per-class ground-truth masks and predictions to W&B.

    On the first batch (``batch_idx == 0``) of each training and validation
    epoch, the module is re-run on that batch under ``torch.no_grad()``. A list
    of ``wandb.Image`` grids is then logged through
    ``trainer.logger.experiment.log`` under ``"train/images"`` or
    ``"val/images"``. Nothing is logged when the trainer has no logger.
    """

    def __init__(self, num_classes):
        """
        Args:
            num_classes: number of class channels to visualize
                (channels ``0 .. num_classes - 1``).
        """
        super().__init__()
        self.num_classes = num_classes

    def log_image(self, trainer, pl_module, outputs, batch, batch_idx, mode="train"):
        """Build image grids for one batch and log them under ``"<mode>/images"``.

        Args:
            trainer: trainer whose ``logger.experiment`` receives the images.
                It is presumably a W&B run, since ``wandb.Image`` objects are
                logged.
            pl_module: unused; kept for interface compatibility.
            outputs: mapping with ``"output"`` (rank-4, per the ``permute``
                below) and ``"valid_bev"``. It is not modified.
            batch: mapping with ``"image"`` and ``"seg_masks"``.
            batch_idx: unused; kept for interface compatibility.
            mode: key prefix, e.g. ``"train"`` or ``"val"``.
        """
        if trainer.logger is None:
            return

        fpv_grid = torchvision.utils.make_grid(
            batch["image"], nrow=8, normalize=False)
        images = [wandb.Image(fpv_grid, caption="fpv")]

        # permute() returns a view; clone so the masking below does not write
        # through into the caller's outputs["output"] tensor.
        pred = outputs['output'].permute(0, 2, 3, 1).clone()
        # Zero predictions wherever valid_bev is 0. The last channel of
        # valid_bev is dropped so the mask lines up with pred's channels.
        # This assumes valid_bev is channels-last with one extra trailing
        # channel (TODO confirm).
        pred[outputs["valid_bev"][..., :-1] == 0] = 0
        pred = (pred > 0.5).float().permute(0, 3, 1, 2)

        for i in range(self.num_classes):
            # seg_masks is indexed on its last axis (classes last); unsqueeze(1)
            # adds the channel dim make_grid expects for a batch of images.
            gt_class_i_grid = torchvision.utils.make_grid(
                batch['seg_masks'][..., i].unsqueeze(1),
                nrow=8, normalize=False, pad_value=0)
            pred_class_i_grid = torchvision.utils.make_grid(
                pred[:, i].unsqueeze(1), nrow=8, normalize=False, pad_value=0)
            images += [
                wandb.Image(gt_class_i_grid, caption=f"gt_class_{i}"),
                wandb.Image(pred_class_i_grid, caption=f"pred_class_{i}"),
            ]

        trainer.logger.experiment.log({"{}/images".format(mode): images})

    def on_validation_batch_end(self, trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx,
                                dataloader_idx: int = 0):
        """Log images for the first validation batch.

        ``dataloader_idx`` is accepted because Lightning may pass it (e.g. when
        several validation dataloaders are configured); it is otherwise unused.
        """
        if batch_idx != 0 or trainer.logger is None:
            return
        with torch.no_grad():
            preds = pl_module(batch)
        self.log_image(trainer, pl_module, preds,
                       batch, batch_idx, mode="val")

    def on_train_batch_end(self, trainer, pl_module: pl.LightningModule, outputs, batch, batch_idx):
        """Log eval-mode predictions for the first training batch of each epoch."""
        if batch_idx != 0 or trainer.logger is None:
            return

        # Record every submodule's mode. A blanket ``train()`` afterwards would
        # also switch submodules that were deliberately kept in eval mode.
        saved_modes = [(module, module.training) for module in pl_module.modules()]
        pl_module.eval()
        try:
            with torch.no_grad():
                preds = pl_module(batch)
            self.log_image(trainer, pl_module, preds,
                           batch, batch_idx, mode="train")
        finally:
            # Restore even if logging raises, so training never continues in eval mode.
            for module, was_training in saved_modes:
                module.training = was_training