import logging
import random
import subprocess
from datetime import datetime

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
import torch.optim
import torch.utils.data
from packaging import version
from omegaconf import OmegaConf


def set_random_seed(seed):
    """Seed every random number generator this project relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, torch's default generator and torch's CUDA
    generators (via ``torch.cuda.manual_seed_all``).

    Args:
        seed: Value forwarded unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def is_logging_process():
    """Tell whether the current process is the one that should emit logs.

    Returns:
        ``True`` when ``torch.distributed`` has not been initialized, or when
        it has and this process has rank 0; ``False`` for every other rank.
    """
    if dist.is_initialized():
        # Distributed run: only rank 0 logs.
        return dist.get_rank() == 0
    # No process group: the single process logs.
    return True


def get_logger(cfg, name=None):
    """Configure logging from ``cfg`` and return a logger on the logging process.

    Args:
        cfg: Config object whose ``job_logging_config`` entry is converted with
            ``OmegaConf.to_container(..., resolve=True)`` and passed to
            ``logging.config.dictConfig``.
        name: Logger name forwarded to ``logging.getLogger``; ``None`` yields
            the root logger.

    Returns:
        The requested ``logging.Logger`` when ``is_logging_process()`` is true,
        otherwise ``None`` (logging is left unconfigured on other ranks).
    """
    if not is_logging_process():
        return None

    # ``import logging`` alone does not load the ``logging.config`` submodule;
    # import it explicitly so this works even if no other module has done so.
    import logging.config

    logging_config = OmegaConf.to_container(cfg.job_logging_config, resolve=True)
    logging.config.dictConfig(logging_config)
    return logging.getLogger(name)


# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
class SyncFunction(torch.autograd.Function):
    """Autograd function that all-gathers a tensor across processes.

    Forward concatenates every rank's tensor along dim 0. Backward sums the
    incoming gradient over all ranks and returns the slice that corresponds
    to this rank's rows of the gathered tensor.
    """

    @staticmethod
    def forward(ctx, tensor):
        """Gather ``tensor`` from all ranks and concatenate along dim 0.

        The receive buffers are created with ``zeros_like(tensor)``, i.e. one
        buffer shaped like the local tensor per rank.
        """
        # Remember the local row count so backward can pick this rank's slice.
        ctx.batch_size = tensor.shape[0]

        world_size = torch.distributed.get_world_size()
        buffers = []
        for _ in range(world_size):
            buffers.append(torch.zeros_like(tensor))

        torch.distributed.all_gather(buffers, tensor)
        return torch.cat(buffers, 0)

    @staticmethod
    def backward(ctx, grad_output):
        """Sum gradients over all ranks and return this rank's slice of rows."""
        # Work on a copy: all_reduce writes its result in place.
        summed_grad = grad_output.clone()
        torch.distributed.all_reduce(
            summed_grad,
            op=torch.distributed.ReduceOp.SUM,
            async_op=False,
        )

        rank = torch.distributed.get_rank()
        start = rank * ctx.batch_size
        stop = start + ctx.batch_size
        return summed_grad[start:stop]


def get_timestamp():
    """Return the current local time formatted as ``YYMMDD-HHMMSS``."""
    stamp_format = "%y%m%d-%H%M%S"
    now = datetime.now()
    # format() on a datetime with a non-empty spec delegates to strftime.
    return format(now, stamp_format)


def get_commit_hash():
    """Return the abbreviated hash of the current git ``HEAD`` as a string.

    Runs ``git rev-parse --short HEAD`` in the current working directory.

    Raises:
        subprocess.CalledProcessError: If git exits with a non-zero status.
    """
    completed = subprocess.run(
        ["git", "rev-parse", "--short", "HEAD"],
        stdout=subprocess.PIPE,
        check=True,
    )
    raw_hash = completed.stdout
    # Drop the trailing newline before decoding the bytes.
    return raw_hash.strip().decode("utf-8")


class DDP(DistributedDataParallel):
    """
    DistributedDataParallel variant whose forward call is routed to the wrapped
    module's ``training_step``, ``test_step`` or ``validation_step`` instead of
    the module's own ``forward``.

    The step is chosen from ``self.module.training`` and ``self.module.testing``;
    the wrapped module is therefore expected to expose a ``testing`` attribute
    and the three step methods.
    """

    def _run_module_step(self, inputs, kwargs):
        # Dispatch the scattered inputs of the first device to the step method
        # that matches the wrapped module's current mode.
        module = self.module
        if module.training:
            return module.training_step(*inputs[0], **kwargs[0])
        if module.testing:
            return module.test_step(*inputs[0], **kwargs[0])
        return module.validation_step(*inputs[0], **kwargs[0])

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        """Run the DDP bookkeeping around the module's train/test/val step.

        Two code paths exist, selected by the installed torch release: one for
        torch < 1.11 and one for torch >= 1.11.
        NOTE(review): both paths call private DistributedDataParallel/reducer
        APIs; confirm they still exist in the torch version actually deployed.

        Args:
            *inputs: Positional arguments, scattered to ``self.device_ids`` and
                forwarded to the selected step method.
            **kwargs: Keyword arguments, scattered and forwarded likewise.

        Returns:
            Whatever the selected step method returns (re-assembled after
            passing through ``_DDPSink`` when that path is taken).
        """
        # Parse the whole version string and compare only the numeric release.
        # Slicing the string (e.g. "2.0.1+cu117"[:6] -> "2.0.1+") produced
        # invalid versions that either raised or compared as "older than 1.11".
        torch_release = version.parse(str(torch.__version__)).release
        if torch_release[:2] < (1, 11):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            assert len(self.device_ids) == 1
            output = self._run_module_step(inputs, kwargs)
            if torch.is_grad_enabled():
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        else:
            # Imported lazily: these private helpers are only needed on this path.
            from torch.nn.parallel.distributed import (
                Join,
                _DDPSink,
                _tree_flatten_with_rref,
                _tree_unflatten_with_rref,
            )
            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this process has not joined, if
                # needed
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(
                        work, self._divide_by_initial_world_size
                    )

                # Calling _rebuild_buckets before forward computation,
                # It may allocate new buckets before deallocating old buckets
                # inside _rebuild_buckets. To save peak memory usage,
                # call _rebuild_buckets before the peak memory usage increases
                # during forward computation.
                # This should be called only once during whole training period.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    logging.info("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                # Sync buffers before the step if that is the configured location.
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    # Notify joined ranks whether they should sync in backwards pass or not.
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                output = self._run_module_step(inputs, kwargs)

                # Sync buffers after the step if that is the configured location.
                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # We'll return the output object verbatim since it is a freeform
                    # object. We need to find any tensors in this object, though,
                    # because we need to figure out which parameters were used during
                    # this forward pass, to ensure we short circuit reduction for any
                    # unused parameters. Only if `find_unused_parameters` is set.
                    if self.find_unused_parameters and not self.static_graph:
                        # Do not need to populate this for static graph.
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

            # TODO: DDPSink is currently enabled for unused parameter detection and
            # static graph training for first iteration.
            if (self.find_unused_parameters and not self.static_graph) or (
                    self.static_graph and self.num_iterations == 1
            ):
                state_dict = {
                    'static_graph': self.static_graph,
                    'num_iterations': self.num_iterations,
                }

                output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
                    output
                )
                output_placeholders = [None] * len(output_tensor_list)
                # Do not touch tensors that have no grad_fn, which can cause issues
                # such as https://github.com/pytorch/pytorch/issues/60733
                for i, leaf in enumerate(output_tensor_list):
                    if torch.is_tensor(leaf) and leaf.grad_fn is None:
                        output_placeholders[i] = leaf

                # When find_unused_parameters=True, makes tensors which require grad
                # run through the DDPSink backward pass. When not all outputs are
                # used in loss, this makes those corresponding tensors receive
                # undefined gradient which the reducer then handles to ensure
                # param.grad field is not touched and we don't error out.
                passthrough_tensor_list = _DDPSink.apply(
                    self.reducer,
                    state_dict,
                    *output_tensor_list,
                )
                # Fill every slot not already claimed by a grad_fn-less tensor.
                for i, placeholder in enumerate(output_placeholders):
                    if placeholder is None:
                        output_placeholders[i] = passthrough_tensor_list[i]

                # Reconstruct output data structure.
                output = _tree_unflatten_with_rref(
                    output_placeholders, treespec, output_is_rref
                )
        return output