#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch
import torch.nn.functional as F
import torch.distributed as dist
import functools
import logging
import pickle
import numpy as np
from collections import OrderedDict
from torch.autograd import Function

__all__ = ['is_dist_initialized',
           'get_world_size',
           'get_rank',
           'new_group',
           'destroy_process_group',
           'barrier',
           'broadcast',
           'all_reduce',
           'reduce',
           'gather',
           'all_gather',
           'reduce_dict',
           'get_global_gloo_group',
           'generalized_all_gather',
           'generalized_gather',
           'scatter',
           'reduce_scatter',
           'send',
           'recv',
           'isend',
           'irecv',
           'shared_random_seed',
           'diff_all_gather',
           'diff_all_reduce',
           'diff_scatter',
           'diff_copy',
           'spherical_kmeans',
           'sinkhorn']

#-------------------------------- Distributed operations --------------------------------#

def is_dist_initialized():
    return dist.is_available() and dist.is_initialized()

def get_world_size(group=None):
    return dist.get_world_size(group) if is_dist_initialized() else 1

def get_rank(group=None):
    return dist.get_rank(group) if is_dist_initialized() else 0

def new_group(ranks=None, **kwargs):
    if is_dist_initialized():
        return dist.new_group(ranks, **kwargs)
    return None

def destroy_process_group():
    if is_dist_initialized():
        dist.destroy_process_group()

def barrier(group=None, **kwargs):
    if get_world_size(group) > 1:
        dist.barrier(group, **kwargs)

def broadcast(tensor, src, group=None, **kwargs):
    if get_world_size(group) > 1:
        return dist.broadcast(tensor, src, group, **kwargs)

def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs):
    if get_world_size(group) > 1:
        return dist.all_reduce(tensor, op, group, **kwargs)

def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs):
    if get_world_size(group) > 1:
        return dist.reduce(tensor, dst, op, group, **kwargs)

def gather(tensor, dst=0, group=None, **kwargs):
    rank = get_rank()  # global rank
    world_size = get_world_size(group)
    if world_size == 1:
        return [tensor]
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] if rank == dst else None
    dist.gather(tensor, tensor_list, dst, group, **kwargs)
    return tensor_list

def all_gather(tensor, uniform_size=True, group=None, **kwargs):
    r"""All-gather a tensor across processes. When ``uniform_size`` is False,
    per-rank tensors may have different shapes; they are padded to the largest
    size before gathering and cropped back to their original shapes afterwards.
    """
    world_size = get_world_size(group)
    if world_size == 1:
        return [tensor]
    assert tensor.is_contiguous(), 'ops.all_gather requires the tensor to be contiguous()'
    
    if uniform_size:
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, group, **kwargs)
        return tensor_list
    else:
        # collect tensor shapes across GPUs
        shape = tuple(tensor.shape)
        shape_list = generalized_all_gather(shape, group)

        # flatten the tensor
        tensor = tensor.reshape(-1)
        size = int(np.prod(shape))
        size_list = [int(np.prod(u)) for u in shape_list]
        max_size = max(size_list)

        # pad to maximum size
        if size != max_size:
            padding = tensor.new_zeros(max_size - size)
            tensor = torch.cat([tensor, padding], dim=0)
        
        # all_gather
        tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, group, **kwargs)

        # reshape tensors
        tensor_list = [t[:n].view(s) for t, n, s in zip(
            tensor_list, size_list, shape_list)]
        return tensor_list
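
# Illustrative usage sketch (not part of the public API): gather per-rank feature
# tensors whose batch sizes may differ, assuming the default process group has
# already been initialized elsewhere.
def _example_all_gather_ragged(feats):
    # feats: [local_batch, dim] tensor; local_batch may differ across ranks
    gathered = all_gather(feats.contiguous(), uniform_size=False)
    # concatenate into a single [global_batch, dim] tensor in rank order
    return torch.cat(gathered, dim=0)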

@torch.no_grad()
def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
    r"""Reduce a dict of same-shaped tensors (typically scalar losses) across
    processes so that every rank receives the same reduced values.
    """
    assert reduction in ['mean', 'sum']
    world_size = get_world_size(group)
    if world_size == 1:
        return input_dict
    
    # ensure that the orders of keys are consistent across processes
    if isinstance(input_dict, OrderedDict):
        keys = list(input_dict.keys())
    else:
        keys = sorted(input_dict.keys())
    vals = [input_dict[key] for key in keys]
    vals = torch.stack(vals, dim=0)
    dist.reduce(vals, dst=0, group=group, **kwargs)
    if dist.get_rank(group) == 0 and reduction == 'mean':
        vals /= world_size
    dist.broadcast(vals, src=0, group=group, **kwargs)
    reduced_dict = type(input_dict)([
        (key, val) for key, val in zip(keys, vals)])
    return reduced_dict
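
# Illustrative usage sketch (not part of the public API): average a dict of scalar
# loss tensors across ranks for logging; the loss names are hypothetical and the
# process group is assumed to be initialized elsewhere.
def _example_reduce_losses(loss_cls, loss_reg):
    losses = {'loss_cls': loss_cls.detach(), 'loss_reg': loss_reg.detach()}
    # every rank receives the same world-averaged values
    return reduce_dict(losses, reduction='mean')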

@functools.lru_cache()
def get_global_gloo_group():
    backend = dist.get_backend()
    assert backend in ['gloo', 'nccl']
    if backend == 'nccl':
        return dist.new_group(backend='gloo')
    else:
        return dist.group.WORLD

def _serialize_to_tensor(data, group):
    backend = dist.get_backend(group)
    assert backend in ['gloo', 'nccl']
    device = torch.device('cpu' if backend == 'gloo' else 'cuda')

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            'Rank {} trying to all-gather {:.2f} GB of data on device '
            '{}'.format(get_rank(), len(buffer) / (1024 ** 3), device))
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor

def _pad_to_largest_tensor(tensor, group):
    world_size = dist.get_world_size(group=group)
    assert world_size >= 1, \
        'gather/all_gather must be called from ranks within ' \
        'the given group!'
    local_size = torch.tensor(
        [tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [torch.zeros(
        [1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)]

    # gather tensors and compute the maximum size
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # pad tensors to the same size
    if local_size != max_size:
        padding = torch.zeros(
            (max_size - local_size, ),
            dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor

def generalized_all_gather(data, group=None):
    r"""All-gather arbitrary picklable objects. Objects are pickled and exchanged
    over a gloo group when the default backend is nccl. Returns a list with one
    entry per rank, in rank order.
    """
    if get_world_size(group) == 1:
        return [data]
    if group is None:
        group = get_global_gloo_group()
    
    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving tensors from all ranks
    tensor_list = [torch.empty(
        (max_size, ), dtype=torch.uint8, device=tensor.device)
        for _ in size_list]
    dist.all_gather(tensor_list, tensor, group=group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
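
# Illustrative usage sketch (not part of the public API): gather arbitrary picklable
# metadata from every rank; the dictionary contents are hypothetical.
def _example_gather_metadata(num_samples):
    metadata = {'rank': get_rank(), 'num_samples': num_samples}
    # returns a list with one entry per rank, in rank order
    return generalized_all_gather(metadata)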

def generalized_gather(data, dst=0, group=None):
    r"""Gather arbitrary picklable objects onto rank ``dst``; other ranks receive
    an empty list.
    """
    world_size = get_world_size(group)
    if world_size == 1:
        return [data]
    if group is None:
        group = get_global_gloo_group()
    rank = dist.get_rank()  # global rank

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving tensors from all ranks to dst
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [torch.empty(
            (max_size, ), dtype=torch.uint8, device=tensor.device)
            for _ in size_list]
        dist.gather(tensor, tensor_list, dst=dst, group=group)

        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []

def scatter(data, scatter_list=None, src=0, group=None, **kwargs):
    r"""NOTE: only supports CPU tensor communication.
    """
    if get_world_size(group) > 1:
        return dist.scatter(data, scatter_list, src, group, **kwargs)

def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, **kwargs):
    if get_world_size(group) > 1:
        return dist.reduce_scatter(output, input_list, op, group, **kwargs)

def send(tensor, dst, group=None, **kwargs):
    if get_world_size(group) > 1:
        assert tensor.is_contiguous(), 'ops.send requires the tensor to be contiguous()'
        return dist.send(tensor, dst, group, **kwargs)

def recv(tensor, src=None, group=None, **kwargs):
    if get_world_size(group) > 1:
        assert tensor.is_contiguous(), 'ops.recv requires the tensor to be contiguous()'
        return dist.recv(tensor, src, group, **kwargs)

def isend(tensor, dst, group=None, **kwargs):
    if get_world_size(group) > 1:
        assert tensor.is_contiguous(), 'ops.isend requires the tensor to be contiguous()'
        return dist.isend(tensor, dst, group, **kwargs)

def irecv(tensor, src=None, group=None, **kwargs):
    if get_world_size(group) > 1:
        assert tensor.is_contiguous(), 'ops.irecv requires the tensor to be contiguous()'
        return dist.irecv(tensor, src, group, **kwargs)

def shared_random_seed(group=None):
    seed = np.random.randint(2 ** 31)
    all_seeds = generalized_all_gather(seed, group)
    return all_seeds[0]
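
# Illustrative usage sketch (not part of the public API): seed the per-process RNGs
# with a value that is identical on every rank, e.g. for synchronized shuffling.
def _example_seed_everything():
    seed = shared_random_seed()
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed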

#-------------------------------- Differentiable operations --------------------------------#

def _all_gather(x):
    if get_world_size() == 1:
        return x
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    tensors = [torch.empty_like(x) for _ in range(world_size)]
    tensors[rank] = x
    dist.all_gather(tensors, x)
    return torch.cat(tensors, dim=0).contiguous()

def _all_reduce(x):
    if get_world_size() == 1:
        return x
    dist.all_reduce(x)
    return x

def _split(x):
    if get_world_size() == 1:
        return x
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return x.chunk(world_size, dim=0)[rank].contiguous()

class DiffAllGather(Function):
    r"""Differentiable all-gather.
    """
    @staticmethod
    def symbolic(graph, input):
        return _all_gather(input)
    
    @staticmethod
    def forward(ctx, input):
        return _all_gather(input)
    
    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)

class DiffAllReduce(Function):
    r"""Differentiable all-reducd.
    """
    @staticmethod
    def symbolic(graph, input):
        return _all_reduce(input)
    
    @staticmethod
    def forward(ctx, input):
        return _all_reduce(input)
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

class DiffScatter(Function):
    r"""Differentiable scatter.
    """
    @staticmethod
    def symbolic(graph, input):
        return _split(input)
    
    @staticmethod
    def forward(ctx, input):
        return _split(input)
    
    @staticmethod
    def backward(ctx, grad_output):
        return _all_gather(grad_output)

class DiffCopy(Function):
    r"""Differentiable copy that reduces all gradients during backward.
    """
    @staticmethod
    def symbolic(graph, input):
        return input
    
    @staticmethod
    def forward(ctx, input):
        return input
    
    @staticmethod
    def backward(ctx, grad_output):
        return _all_reduce(grad_output)

diff_all_gather = DiffAllGather.apply
diff_all_reduce = DiffAllReduce.apply
diff_scatter = DiffScatter.apply
diff_copy = DiffCopy.apply
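
# Illustrative usage sketch (not part of the public API): build a global similarity
# matrix for a contrastive-style loss while keeping gradients flowing back through
# the gather to the local features.
def _example_global_similarity(local_feats):
    # [world_size * local_batch, dim]; backward routes each rank its own slice of grads
    global_feats = diff_all_gather(local_feats.contiguous())
    return torch.mm(local_feats, global_feats.t())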

#-------------------------------- Distributed algorithms --------------------------------#

@torch.no_grad()
def spherical_kmeans(feats, num_clusters, num_iters=10):
    r"""Distributed spherical k-means. ``feats`` is the local [n, c] shard of
    (L2-normalized) features on each rank; cluster statistics are synchronized
    across ranks via all_reduce. Returns the cluster centers, per-sample
    assignments and similarity scores.
    """
    k, n, c = num_clusters, *feats.size()
    ones = feats.new_ones(n, dtype=torch.long)

    # distributed settings
    rank = get_rank()
    world_size = get_world_size()

    # init clusters
    rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))]
    clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k]

    # variables
    new_clusters = feats.new_zeros(k, c)
    counts = feats.new_zeros(k, dtype=torch.long)

    # iterative Expectation-Maximization
    for step in range(num_iters + 1):
        # Expectation step
        simmat = torch.mm(feats, clusters.t())
        scores, assigns = simmat.max(dim=1)
        if step == num_iters:
            break

        # Maximization step
        new_clusters.zero_().scatter_add_(0, assigns.unsqueeze(1).repeat(1, c), feats)
        all_reduce(new_clusters)

        counts.zero_()
        counts.index_add_(0, assigns, ones)
        all_reduce(counts)

        mask = (counts > 0)
        clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1)
        clusters = F.normalize(clusters, p=2, dim=1)
    return clusters, assigns, scores
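
# Illustrative usage sketch (not part of the public API): cluster L2-normalized
# features that are sharded across ranks; the number of clusters is arbitrary.
def _example_cluster_features(feats, num_clusters=128):
    feats = F.normalize(feats, p=2, dim=1)
    clusters, assigns, scores = spherical_kmeans(feats, num_clusters, num_iters=10)
    return clusters, assigns, scores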

@torch.no_grad()
def sinkhorn(Q, eps=0.5, num_iters=3):
    r"""Distributed Sinkhorn-Knopp normalization. ``Q`` holds the local
    [batch, K] scores on each rank; marginals are synchronized across ranks so
    the returned soft assignments are balanced over the global batch.
    """
    # normalize Q
    Q = torch.exp(Q / eps).t()
    sum_Q = Q.sum()
    all_reduce(sum_Q)
    Q /= sum_Q

    # variables
    n, m = Q.size()
    u = Q.new_zeros(n)
    r = Q.new_ones(n) / n
    c = Q.new_ones(m) / (m * get_world_size())

    # iterative update
    cur_sum = Q.sum(dim=1)
    all_reduce(cur_sum)
    for i in range(num_iters):
        u = cur_sum
        Q *= (r / u).unsqueeze(1)
        Q *= (c / Q.sum(dim=0)).unsqueeze(0)
        cur_sum = Q.sum(dim=1)
        all_reduce(cur_sum)
    return (Q / Q.sum(dim=0, keepdim=True)).t().float()
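
# Illustrative usage sketch (not part of the public API): turn per-sample prototype
# scores into balanced soft assignments (in the spirit of SwAV-style self-labeling);
# `logits` is a hypothetical [local_batch, num_prototypes] tensor on each rank.
def _example_soft_assignments(logits):
    # soft assignment matrix of the same [local_batch, num_prototypes] shape
    return sinkhorn(logits, eps=0.05, num_iters=3)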