import os
import typing

import torch
import torch.distributed as dist
from torch.nn.parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel

from ..data.datasets import ResumableDistributedSampler as DistributedSampler
from ..data.datasets import ResumableSequentialSampler as SequentialSampler


class Accelerator:  # pragma: no cover
    """This class is used to prepare models and dataloaders for
    usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
    prepare the respective objects. In the case of models, they are moved to
    the appropriate GPU and SyncBatchNorm is applied to them. In the case of
    dataloaders, a sampler is created and the dataloader is initialized with
    that sampler.

    If the world size is 1, prepare_model and prepare_dataloader are
    no-ops. If the environment variable ``LOCAL_RANK`` is not set, then the
    script was launched without ``torchrun``, and ``DataParallel``
    will be used instead of ``DistributedDataParallel`` (not recommended), if
    the world size (number of GPUs) is greater than 1.

    Parameters
    ----------
    amp : bool, optional
        Whether or not to enable automatic mixed precision, by default False
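
    Examples
    --------
    A minimal, illustrative sketch of single-process usage; the
    ``torch.nn.Linear`` module below is just a placeholder model:

    >>> accel = Accelerator(amp=False)
    >>> with accel:
    ...     model = accel.prepare_model(torch.nn.Linear(8, 2))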
    """

    def __init__(self, amp: bool = False):
        local_rank = os.getenv("LOCAL_RANK", None)
        self.world_size = torch.cuda.device_count()

        self.use_ddp = self.world_size > 1 and local_rank is not None
        self.use_dp = self.world_size > 1 and local_rank is None
        self.device = "cpu" if self.world_size == 0 else "cuda"

        if self.use_ddp:
            local_rank = int(local_rank)
            dist.init_process_group(
                "nccl",
                init_method="env://",
                world_size=self.world_size,
                rank=local_rank,
            )

        self.local_rank = 0 if local_rank is None else int(local_rank)
        self.amp = amp

        class DummyScaler:
            """No-op stand-in for ``GradScaler`` when AMP is disabled, so the
            training loop can call scaler methods unconditionally."""

            def __init__(self):
                pass

            def step(self, optimizer):
                optimizer.step()

            def scale(self, loss):
                return loss

            def unscale_(self, optimizer):
                return optimizer

            def update(self):
                pass

        self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
        self.device_ctx = (
            torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
        )

    def __enter__(self):
        if self.device_ctx is not None:
            self.device_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_ctx is not None:
            self.device_ctx.__exit__(exc_type, exc_value, traceback)

    def prepare_model(self, model: torch.nn.Module, **kwargs):
        """Prepares model for DDP or DP. The model is moved to
        the device of the correct rank.

        Parameters
        ----------
        model : torch.nn.Module
            Model that is converted for DDP or DP.

        Returns
        -------
        torch.nn.Module
            Wrapped model, or original model if DDP and DP are turned off.
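
        Examples
        --------
        An illustrative sketch; ``torch.nn.Linear`` stands in for a real
        model:

        >>> accel = Accelerator()
        >>> model = accel.prepare_model(torch.nn.Linear(16, 4))  # placeholder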
        """
        model = model.to(self.device)
        if self.use_ddp:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DistributedDataParallel(
                model, device_ids=[self.local_rank], **kwargs
            )
        elif self.use_dp:
            model = DataParallel(model, **kwargs)
        return model

    # Automatic mixed-precision utilities
    def autocast(self, *args, **kwargs):
        """Context manager for autocasting. Arguments
        go to ``torch.cuda.amp.autocast``.
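
        Examples
        --------
        An illustrative sketch of a forward pass under autocast; the model
        and input shapes are placeholders:

        >>> accel = Accelerator(amp=True)
        >>> model = accel.prepare_model(torch.nn.Linear(16, 4))
        >>> x = torch.randn(2, 16, device=accel.device)
        >>> with accel.autocast():
        ...     y = model(x)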
        """
        return torch.cuda.amp.autocast(self.amp, *args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Backwards pass, after scaling the loss if ``amp`` is
        enabled.

        Parameters
        ----------
        loss : torch.Tensor
            Loss value.
        """
        self.scaler.scale(loss).backward()

    def step(self, optimizer: torch.optim.Optimizer):
        """Steps the optimizer, using a ``scaler`` if ``amp`` is
        enabled.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer to step forward.
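
        Examples
        --------
        An illustrative sketch of one optimization step routed through the
        scaler; the model, data, and loss below are placeholders:

        >>> accel = Accelerator(amp=True)
        >>> model = accel.prepare_model(torch.nn.Linear(16, 4))
        >>> optimizer = torch.optim.Adam(model.parameters())
        >>> x = torch.randn(2, 16, device=accel.device)
        >>> with accel.autocast():
        ...     loss = model(x).pow(2).mean()
        >>> accel.backward(loss)
        >>> accel.step(optimizer)
        >>> accel.update()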
        """
        self.scaler.step(optimizer)

    def update(self):
        """Updates the scale factor."""
        self.scaler.update()

    def prepare_dataloader(
        self, dataset: typing.Iterable, start_idx: int = None, **kwargs
    ):
        """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
        enabled.

        Parameters
        ----------
        dataset : typing.Iterable
            Dataset to build Dataloader around.
        start_idx : int, optional
            Start index for the sampler, useful when resuming from a
            checkpoint, by default None

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader built around the dataset with the appropriate sampler;
            batch size and number of workers are divided across ranks when
            DDP is enabled.
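
        Examples
        --------
        An illustrative sketch using a small in-memory dataset;
        ``TensorDataset`` is a placeholder for a real dataset:

        >>> accel = Accelerator()
        >>> dataset = torch.utils.data.TensorDataset(torch.randn(32, 8))
        >>> loader = accel.prepare_dataloader(dataset, batch_size=4)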
        """

        if self.use_ddp:
            sampler = DistributedSampler(
                dataset,
                start_idx,
                num_replicas=self.world_size,
                rank=self.local_rank,
            )
            if "num_workers" in kwargs:
                kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
            kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
        else:
            sampler = SequentialSampler(dataset, start_idx)

        dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
        return dataloader

    @staticmethod
    def unwrap(model):
        """Unwraps the model if it was wrapped in DDP or DP, otherwise
        just returns the model. Use this to unwrap the model returned by
        :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
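
        Examples
        --------
        An illustrative sketch; the wrapped module below is a placeholder:

        >>> accel = Accelerator()
        >>> wrapped = accel.prepare_model(torch.nn.Linear(8, 2))
        >>> base = Accelerator.unwrap(wrapped)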
        """
        if hasattr(model, "module"):
            return model.module
        return model