import math
import wandb
from pathlib import Path
from typing import Optional, Union
from omegaconf import DictConfig
from hydra.utils import instantiate
import logging
import torch
import time
from functools import partial
from torch import nn, Tensor, autograd
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from torchvision.utils import make_grid

from ..utils import printer, compute_grad_norm

SNAPSHOT_KEYS = {"EPOCH", "STEP", "OPTIMIZER", "LR_SCHEDULER", "MODEL", "LOSS"}


class VqvaeTrainer:
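    """DDP trainer for a VQ-VAE style model.

    Handles temperature-annealed training, periodic validation, wandb logging of
    losses and reconstructions, and snapshot / best-model checkpointing.
    """
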
    def __init__(
        self,
        device: int,
        model: nn.Module,
        log: logging.Logger,
        exp_dir: Path,
        snapshot: Optional[Path] = None,
        model_weights: Optional[Path] = None,  # evaluation only; mutually exclusive with snapshot
    ) -> None:
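        """
        Args:
            device: local CUDA device index (also used as the rank for rank-0 logging guards).
            model: the model to train; wrapped in DDP below.
            log: logger used for per-step progress messages.
            exp_dir: experiment directory where models and snapshots are written.
            snapshot: optional path to a training snapshot to resume from.
            model_weights: optional path to bare model weights (testing only);
                mutually exclusive with ``snapshot``.
        """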
        self.device = device
        self.log = log
        self.exp_dir = exp_dir
        assert (
            snapshot is None or model_weights is None
        ), "Snapshot and model weights cannot be set at the same time."

        self.model = model
        # Default state; overwritten below when resuming from a snapshot.
        self.snapshot = None
        self.start_epoch = 0
        self.global_step = 0
        if snapshot is not None and snapshot.is_file():
            self.snapshot = self.load_snapshot(snapshot)
            self.model.load_state_dict(self.snapshot["MODEL"])
            self.start_epoch = self.snapshot["EPOCH"]
            self.global_step = self.snapshot["STEP"]
        elif model_weights is not None and model_weights.is_file():
            self.load_model(model_weights)

        self.model = self.model.to(device)
        self.model = DDP(self.model, device_ids=[device])

        # https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113
        torch.cuda.set_device(device)  # master gpu takes up extra memory
        torch.cuda.empty_cache()

    def train_epoch(
        self,
        epoch: int,
        starting_temp: float,
        anneal_rate: float,
        temp_min: float,
        grad_clip: Optional[float] = None,
    ):
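        """Run one training epoch.

        The sampling temperature is annealed exponentially per optimizer step,
        temp = max(starting_temp * exp(-anneal_rate * global_step), temp_min),
        and reconstructions are periodically logged to wandb on rank 0.

        Returns the summed (un-averaged) loss and the number of samples seen, so
        that they can be reduced across ranks before averaging.
        """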
        start = time.time()
        total_loss = 0.0
        total_samples = 0

        # load data from dataloader
        for i, obj in enumerate(self.train_dataloader):
            if isinstance(obj, Tensor):
                img = obj.to(self.device)
            elif isinstance(obj, (list, tuple)):
                img = obj[0].to(self.device)
            else:
                raise ValueError(f"Unrecognized object type {type(obj)}")

            # temperature annealing
            self.temp = max(
                starting_temp * math.exp(-anneal_rate * self.global_step), temp_min
            )

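            # Anomaly detection adds noticeable overhead, but surfaces NaN/Inf values
            # in the backward pass with a traceback pointing at the offending op.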
            with autograd.detect_anomaly():
                loss, soft_recons = self.model(
                    img, return_loss=True, return_recons=True, temp=self.temp
                )

                self.optimizer.zero_grad()
                loss.backward()
                if grad_clip:
                    nn.utils.clip_grad_norm_(
                        self.model.parameters(), max_norm=grad_clip
                    )
                self.optimizer.step()

            loss = loss.detach().cpu()
            total_loss += loss * img.shape[0]
            total_samples += img.shape[0]

            self.lr_scheduler.step()
            self.global_step += 1

            if i % 10 == 0:
                grad_norm = compute_grad_norm(self.model)
                lr = self.optimizer.param_groups[0]["lr"]
                elapsed = time.time() - start
                self.log.info(
                    printer(
                        self.device,
                        f"Epoch {epoch} Step {i + 1}/{len(self.train_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f}) | Grad norm {grad_norm:.3f} | {total_samples / elapsed:4.1f} images/s | lr {lr:5.1e} | Temp {self.temp:.2e}",
                    )
                )

            # visualize reconstruction images
            if i % 100 == 0 and self.device == 0:
                lr = self.optimizer.param_groups[0]["lr"]
                k = 4  # num of images saved for visualization
                # No gradients are needed for the visualization forward passes.
                with torch.no_grad():
                    codes = self.model.module.get_codebook_indices(img[:k])
                    hard_recons = self.model.module.decode(codes)

                img = img[:k].detach().cpu()
                soft_recons = soft_recons[:k].detach().cpu()
                codes = codes.flatten(start_dim=1).detach().cpu()
                hard_recons = hard_recons.detach().cpu()

                make_vis = partial(make_grid, nrow=int(math.sqrt(k)), normalize=True)
                img, soft_recons, hard_recons = map(
                    make_vis, (img, soft_recons, hard_recons)
                )

                log_info = {
                    "epoch": epoch,
                    "train_loss": loss,
                    "temperature": self.temp,
                    "learning rate": lr,
                    "original images": wandb.Image(
                        img, caption=f"step: {self.global_step}"
                    ),
                    "soft reconstruction": wandb.Image(
                        soft_recons, caption=f"step: {self.global_step}"
                    ),
                    "hard reconstruction": wandb.Image(
                        hard_recons, caption=f"step: {self.global_step}"
                    ),
                    "codebook_indices": wandb.Histogram(codes),
                }

                wandb.log(
                    log_info,
                    step=self.global_step,
                )

        return total_loss, total_samples

    def train(
        self,
        train_dataloader: DataLoader,
        valid_dataloader: DataLoader,
        train_cfg: DictConfig,
        valid_cfg: DictConfig,
    ):
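        """Full training loop.

        Builds the optimizer and LR scheduler from the Hydra config, optionally
        restores their state from a snapshot, and alternates training epochs with
        validation. Rank 0 logs epoch-level metrics and saves snapshots and the
        best model.
        """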
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.optimizer = instantiate(
            train_cfg.optimizer, params=self.model.parameters()
        )

        self.lr_scheduler = instantiate(
            train_cfg.lr_scheduler, optimizer=self.optimizer
        )

        if self.snapshot is not None:
            self.optimizer.load_state_dict(self.snapshot["OPTIMIZER"])
            self.lr_scheduler.load_state_dict(self.snapshot["LR_SCHEDULER"])

        best_loss = float("inf")
        self.model.train()
        # global_step and start_epoch are set in __init__ (and restored from a snapshot
        # when resuming), so they are deliberately not reset here.
        self.temp = train_cfg.starting_temp  # initial value; annealed per step in train_epoch
        for epoch in range(self.start_epoch, train_cfg.epochs):
            train_dataloader.sampler.set_epoch(epoch)
            epoch_loss, epoch_samples = self.train_epoch(
                epoch,
                starting_temp=train_cfg.starting_temp,
                anneal_rate=train_cfg.temp_anneal_rate,
                temp_min=train_cfg.temp_min,
                grad_clip=train_cfg.grad_clip,
            )

            torch.cuda.empty_cache()

            valid_loss, valid_samples = self.valid(valid_cfg)

            # reduce per-rank loss sums and sample counts to gpu 0 for global averaging
            training_info = torch.tensor(
                [epoch_loss, epoch_samples, valid_loss, valid_samples],
                device=self.device,
            )

            dist.reduce(
                training_info,
                dst=0,
                op=dist.ReduceOp.SUM,
            )

            if self.device == 0:
                grad_norm = compute_grad_norm(self.model)
                epoch_loss, epoch_samples, valid_loss, valid_samples = training_info.tolist()
                epoch_loss, valid_loss = (
                    epoch_loss / epoch_samples,
                    valid_loss / valid_samples,
                )

                log_info = {
                    "train loss (epoch)": epoch_loss,
                    "valid loss (epoch)": valid_loss,
                    "train_samples": epoch_samples,
                    "valid_samples": valid_samples,
                    "grad_norm": grad_norm,
                }

                wandb.log(
                    log_info,
                    step=self.global_step,
                )

                if epoch % train_cfg.save_every == 0:
                    self.save_snapshot(epoch, best_loss)
                if valid_loss < best_loss:
                    self.save_model(epoch)
                    best_loss = valid_loss

    def valid(self, cfg: DictConfig):
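        """Evaluate on the validation set at the current annealed temperature.

        The ``cfg`` argument is currently unused. Returns the summed loss and
        sample count for cross-rank reduction.
        """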
        total_samples = 0
        total_loss = 0.0

        self.model.eval()
        for i, obj in enumerate(self.valid_dataloader):
            if isinstance(obj, Tensor):
                img = obj.to(self.device)
            elif isinstance(obj, (list, tuple)):
                img = obj[0].to(self.device)
            else:
                raise ValueError(f"Unrecognized object type {type(obj)}")

            with torch.no_grad():
                loss = self.model(
                    img, return_loss=True, return_recons=False, temp=self.temp
                )

            loss = loss.detach().cpu()
            total_loss += loss * img.shape[0]
            total_samples += img.shape[0]

            if i % 10 == 0:
                self.log.info(
                    printer(
                        self.device,
                        f"Valid: Step {i + 1}/{len(self.valid_dataloader)} | Loss {loss:.4f} ({total_loss / total_samples:.4f})",
                    )
                )

        # Restore training mode; eval() above would otherwise persist into later epochs.
        self.model.train()
        return total_loss, total_samples

    def save_model(self, epoch: int):
        model_dir = Path(self.exp_dir) / "model"
        model_dir.mkdir(parents=True, exist_ok=True)
        filename = model_dir / f"epoch{epoch}_model.pt"
        torch.save(self.model.module.state_dict(), filename)
        self.log.info(printer(self.device, f"Saving model to {filename}"))
        # Also keep the latest best weights under a stable name.
        torch.save(self.model.module.state_dict(), model_dir / "best.pt")

    def load_model(self, path: Union[str, Path]):
        self.log.info(printer(self.device, f"Loading model from {path}"))
        self.model.load_state_dict(torch.load(path, map_location="cpu"))

    def save_snapshot(self, epoch: int, best_loss: float):
        state_info = {
            "EPOCH": epoch + 1,
            "STEP": self.global_step,
            "OPTIMIZER": self.optimizer.state_dict(),
            "LR_SCHEDULER": self.lr_scheduler.state_dict(),
            "MODEL": self.model.module.state_dict(),
            "LOSS": best_loss,
        }

        snapshot_dir = Path(self.exp_dir) / "snapshot"
        snapshot_dir.mkdir(parents=True, exist_ok=True)
        snapshot_path = snapshot_dir / f"epoch{epoch}_snapshot.pt"
        torch.save(state_info, snapshot_path)

        self.log.info(printer(self.device, f"Saving snapshot to {snapshot_path}"))

    def load_snapshot(self, path: Path):
        self.log.info(printer(self.device, f"Loading snapshot from {path}"))
        snapshot = torch.load(path, map_location="cpu")
        assert SNAPSHOT_KEYS.issubset(snapshot.keys())
        return snapshot
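
# Illustrative usage sketch, not an entry point of this module. It assumes a torchrun
# launch and a Hydra config with `model`, `train`, and `valid` sections; every name not
# defined in this file (os, cfg, log, exp_dir, the dataloaders, snapshot_path) is
# hypothetical:
#
#   dist.init_process_group(backend="nccl")
#   device = int(os.environ["LOCAL_RANK"])
#   trainer = VqvaeTrainer(device, instantiate(cfg.model), log, exp_dir, snapshot=snapshot_path)
#   trainer.train(train_dataloader, valid_dataloader, cfg.train, cfg.valid)
#   dist.destroy_process_group()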