import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from .nn import mean_flat
#from . import dist_util
import functools

class AdversarialLoss(nn.Module):
    def __init__(self,  gan_type='WGAN_GP', gan_k=1, 
        lr_dis=1e-5 ):

        super(AdversarialLoss, self).__init__()
       
        self.gan_type = gan_type
        self.gan_k = gan_k
 
        model = NLayerDiscriminator().cuda()
        self.discriminator =  DDP(
                model,
                device_ids=[torch.device('cuda')],
                output_device=torch.device('cuda'),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
 
        if (gan_type in ['WGAN_GP', 'GAN']):
            self.optimizer = optim.Adam(
                self.discriminator.parameters(),
                 lr=lr_dis
            )
   
    def forward(self, fake, real):
        fake_detach = fake.detach()
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.discriminator(fake_detach)
            d_real = self.discriminator(real)
            if (self.gan_type.find('WGAN') >= 0):
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    epsilon = torch.rand(real.size(0), 1, 1, 1).cuda()
                    epsilon = epsilon.expand(real.size())
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.discriminator(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
 
            # print('d loss:', loss_d)
            # Discriminator update
            loss_d.backward()
            self.optimizer.step()

        d_fake_for_g = self.discriminator(fake)
        if (self.gan_type.find('WGAN') >= 0):
            loss_g = -d_fake_for_g 
    
        # Generator loss
        return  mean_flat(loss_g)



def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                     stride=stride, padding=1, bias=True)


def conv7x7(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=7, 
                     stride=stride, padding=3, bias=True)


class Discriminator(nn.Module):
    def __init__(self, ):
        super(Discriminator, self).__init__()
        self.conv1 = conv7x7(3, 32)
        self.norm1 = nn.InstanceNorm2d(32, affine=True)
        self.LReLU1 = nn.LeakyReLU(0.2)

        self.conv2 = conv3x3(32, 32, 2)
        self.norm2 = nn.InstanceNorm2d(32, affine=True)
        self.LReLU2 = nn.LeakyReLU(0.2)

        self.conv3 = conv3x3(32, 64)
        self.norm3 = nn.InstanceNorm2d(64, affine=True)
        self.LReLU3 = nn.LeakyReLU(0.2)

        self.conv4 = conv3x3(64, 64, 2)
        self.norm4 = nn.InstanceNorm2d(64, affine=True)
        self.LReLU4 = nn.LeakyReLU(0.2)

        self.conv5 = conv3x3(64, 128)
        self.norm5 = nn.InstanceNorm2d(128, affine=True)
        self.LReLU5 = nn.LeakyReLU(0.2)

        self.conv6 = conv3x3(128, 128, 2)
        self.norm6 = nn.InstanceNorm2d(128, affine=True)
        self.LReLU6 = nn.LeakyReLU(0.2)

        self.conv7 = conv3x3(128, 256)
        self.norm7 = nn.InstanceNorm2d(256, affine=True)
        self.LReLU7 = nn.LeakyReLU(0.2)

        self.conv8 = conv3x3(256, 256, 2)
        self.norm8 = nn.InstanceNorm2d(256, affine=True)
        self.LReLU8 = nn.LeakyReLU(0.2)

        self.conv9 = conv3x3(256, 512)
        self.norm9 = nn.InstanceNorm2d(512, affine=True)
        self.LReLU9 = nn.LeakyReLU(0.2)

        self.conv10 = conv3x3(512, 512, 2)
        self.norm10 = nn.InstanceNorm2d(512, affine=True)
        self.LReLU10 = nn.LeakyReLU(0.2)

        self.conv11 = conv3x3(512, 32)
        self.norm11 = nn.InstanceNorm2d(32, affine=True)
        self.LReLU11 = nn.LeakyReLU(0.2)
        self.conv12 = conv3x3(32, 1)
 
        
    def forward(self, x):
        x = self.LReLU1(self.norm1(self.conv1(x)))
        x = self.LReLU2(self.norm2(self.conv2(x)))
        x = self.LReLU3(self.norm3(self.conv3(x)))
        x = self.LReLU4(self.norm4(self.conv4(x)))
        x = self.LReLU5(self.norm5(self.conv5(x)))
        x = self.LReLU6(self.norm6(self.conv6(x)))
        x = self.LReLU7(self.norm7(self.conv7(x)))
        x = self.LReLU8(self.norm8(self.conv8(x)))
        x = self.LReLU9(self.norm9(self.conv9(x)))
        x = self.LReLU10(self.norm10(self.conv10(x)))
        
        x = self.LReLU11(self.norm11(self.conv11(x)))
        x = self.conv12(x)
 
        return x        



def get_norm_layer(norm_type='instance'):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none
    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        def norm_layer(x): return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc=3, ndf=64, n_layers=3 ):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        norm_layer = get_norm_layer(norm_type='instance')
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)