# This file is copied from https://github.com/rnwzd/FSPBT-Image-Translation/blob/master/original_models.py

# MIT License

# Copyright (c) 2022 Lorenzo Breschi

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

import torchvision
from torchvision import models

import pytorch_lightning as pl

class LeakySoftplus(nn.Module):
    """Smooth activation: ``softplus(x) + negative_slope * logsigmoid(x)``.

    For large positive ``x`` the result approaches ``x``; for large negative
    ``x`` it approaches ``negative_slope * x``.

    Args:
        negative_slope: weight applied to the ``logsigmoid`` term.
    """

    def __init__(self, negative_slope: float = 0.01):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        """Apply the activation element-wise and return a tensor shaped like ``input``."""
        leak = F.logsigmoid(input) * self.negative_slope
        return F.softplus(input) + leak


# Activation instance handed to the Generator's layer builders as `nonlinearity`.
# The commented-out lines below are alternative activations that can be swapped in.
grelu = nn.LeakyReLU(0.2)
#grelu = nn.Softplus()
#grelu = LeakySoftplus(0.2)
#####
# Default generator currently in use:
# conv0 -> conv1 -> conv2 -> resnet_blocks -> upconv2 -> upconv1 -> conv_11 -> (conv_11_a)* -> conv_12 -> (Sigmoid)*
# conv_11_a contains two conv layers.
# "*" marks an optional stage; the model uses skip-connections.
class Generator(pl.LightningModule):
    """Encoder / residual / decoder image generator with skip-connections.

    Pipeline (see ``forward``)::

        conv0 -> conv1 -> conv2 -> residual blocks -> upconv2 -> upconv1
              -> conv_11 -> [conv_11_a] -> conv_12

    ``conv1`` and ``conv2`` each downsample by a stride of 2; ``upconv2`` and
    ``upconv1`` upsample by the same factors. The outputs of ``conv2``,
    ``conv1``, ``conv0`` and the raw input are concatenated back in on the
    way up.

    Args:
        norm_layer: ``None``, ``'batch_norm'`` or ``'instance_norm'``.
        use_bias: whether the encoder / residual / ``conv_11`` convolutions
            carry a bias term.
        resnet_blocks: number of residual blocks after ``conv2``.
        tanh: when true, the output head ends with ``nn.Sigmoid``. The flag
            name is historical; no Tanh is applied anywhere in this class.
        filters: six channel widths. ``[0]``, ``[1]``, ``[2]`` are the outputs
            of conv0 / conv1 / conv2 (``[2]`` is also the residual width),
            ``[4]`` is the output width of both up-convolutions and ``[5]`` the
            width of the head. ``[3]`` only enters the input-channel count of
            ``upconv2``; since that layer receives the residual output
            concatenated with conv2's output, ``filters[3]`` must equal
            ``filters[2]``.
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        append_smoothers: insert the extra two-conv ``conv_11_a`` stage
            between ``conv_11`` and ``conv_12``.
    """

    def __init__(self, norm_layer='batch_norm', use_bias=False, resnet_blocks=7, tanh=True,
                 filters=(32, 64, 128, 128, 128, 64), input_channels=3, output_channels=3,
                 append_smoothers=False):
        super().__init__()
        assert norm_layer in [None, 'batch_norm', 'instance_norm'], \
            "norm_layer should be None, 'batch_norm' or 'instance_norm', not {}".format(
                norm_layer)
        self.norm_layer = None
        if norm_layer == 'batch_norm':
            self.norm_layer = nn.BatchNorm2d
        elif norm_layer == 'instance_norm':
            self.norm_layer = nn.InstanceNorm2d

        self.use_bias = use_bias
        self.resnet_blocks = resnet_blocks
        self.append_smoothers = append_smoothers

        # Down-/up-sampling factors of the two encoder / decoder stages.
        stride1 = 2
        stride2 = 2

        # Settings shared by every layer builder below.
        common = dict(bias=self.use_bias,
                      norm_layer=self.norm_layer,
                      nonlinearity=grelu)

        # ---- encoder -----------------------------------------------------
        self.conv0 = self.relu_layer(in_filters=input_channels, out_filters=filters[0],
                                     kernel_size=7, stride=1, padding=3, **common)
        self.conv1 = self.relu_layer(in_filters=filters[0], out_filters=filters[1],
                                     kernel_size=3, stride=stride1, padding=1, **common)
        self.conv2 = self.relu_layer(in_filters=filters[1], out_filters=filters[2],
                                     kernel_size=3, stride=stride2, padding=1, **common)

        # ---- residual trunk ----------------------------------------------
        self.resnets = nn.ModuleList(
            self.resnet_block(in_filters=filters[2], out_filters=filters[2],
                              kernel_size=3, stride=1, padding=1, **common)
            for _ in range(self.resnet_blocks)
        )

        # ---- decoder (input widths include the concatenated skip tensors) --
        self.upconv2 = self.upconv_layer_upsample_and_conv(
            in_filters=filters[3] + filters[2],
            out_filters=filters[4],
            scale_factor=stride2,
            kernel_size=3, stride=1, padding=1, **common)
        self.upconv1 = self.upconv_layer_upsample_and_conv(
            in_filters=filters[4] + filters[1],
            out_filters=filters[4],
            scale_factor=stride1,
            kernel_size=3, stride=1, padding=1, **common)

        # ---- head ----------------------------------------------------------
        # Receives upconv1's output + conv0's output + the raw input image.
        self.conv_11 = nn.Sequential(
            nn.Conv2d(in_channels=filters[0] + filters[4] + input_channels,
                      out_channels=filters[5],
                      kernel_size=7, stride=1, padding=3, bias=self.use_bias,
                      padding_mode='zeros'),
            grelu
        )

        if self.append_smoothers:
            self.conv_11_a = nn.Sequential(
                nn.Conv2d(filters[5], filters[5], kernel_size=3,
                          bias=self.use_bias, padding=1, padding_mode='zeros'),
                grelu,
                # NOTE: always BatchNorm2d, regardless of `norm_layer` (the
                # original author left a "replace with variable" TODO here).
                # Left unchanged so the set of parameters/buffers is stable.
                nn.BatchNorm2d(num_features=filters[5]),
                nn.Conv2d(filters[5], filters[5], kernel_size=3,
                          bias=self.use_bias, padding=1, padding_mode='zeros'),
                grelu
            )

        projection = nn.Conv2d(filters[5], output_channels, kernel_size=1, stride=1,
                               padding=0, bias=True, padding_mode='zeros')
        if tanh:
            # Despite the flag name the squashing function is a Sigmoid.
            self.conv_12 = nn.Sequential(projection, nn.Sigmoid())
        else:
            self.conv_12 = projection

    def log_tensors(self, logger, tag, img_tensor):
        """Forward ``img_tensor`` to ``logger.experiment.add_images`` under ``tag``."""
        logger.experiment.add_images(tag, img_tensor)

    @staticmethod
    def _crop_to(tensor, reference):
        """Trim trailing rows/columns of ``tensor`` to ``reference``'s spatial size.

        A stride-2, kernel-3, padding-1 convolution maps a side of length H to
        ceil(H / 2), so upsampling by 2 can overshoot the skip tensor by one
        pixel when H is odd. Cropping lets the two be concatenated.
        """
        if tensor.shape[-2:] == reference.shape[-2:]:
            return tensor
        height, width = reference.shape[-2:]
        return tensor[..., :height, :width]

    def forward(self, input, logger=None, **kwargs):
        """Run the generator on a batch of images.

        Args:
            input: image batch with ``input_channels`` channels in dim 1.
            logger: accepted for interface compatibility; not used here.
            **kwargs: accepted and ignored.

        Returns:
            Tensor with ``output_channels`` channels and the same spatial size
            as ``input``.
        """
        # Encoder; every stage is kept for the skip-connections below.
        output_d0 = self.conv0(input)
        output_d1 = self.conv1(output_d0)
        output_d2 = self.conv2(output_d1)

        # Residual trunk: each block's output is added to its own input.
        output = output_d2
        for layer in self.resnets:
            output = layer(output) + output

        # Decoder; crop after each upsampling so odd-sized inputs still line
        # up with the corresponding encoder tensor.
        output_u2 = self.upconv2(torch.cat((output, output_d2), dim=1))
        output_u2 = self._crop_to(output_u2, output_d1)

        output_u1 = self.upconv1(torch.cat((output_u2, output_d1), dim=1))
        output_u1 = self._crop_to(output_u1, output_d0)

        output = torch.cat((output_u1, output_d0, input), dim=1)
        output_11 = self.conv_11(output)

        if self.append_smoothers:
            output_11_a = self.conv_11_a(output_11)
        else:
            output_11_a = output_11

        return self.conv_12(output_11_a)

    def relu_layer(self, in_filters, out_filters, kernel_size, stride, padding, bias,
                   norm_layer, nonlinearity):
        """Build ``conv -> [normalization] -> [nonlinearity]`` as an ``nn.Sequential``.

        ``norm_layer`` is a constructor taking ``num_features`` (or a falsy
        value to skip it); ``nonlinearity`` is a module instance (or falsy).
        """
        out = nn.Sequential()
        out.add_module('conv', nn.Conv2d(in_channels=in_filters,
                                         out_channels=out_filters,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, bias=bias,
                                         padding_mode='zeros'))
        if norm_layer:
            out.add_module('normalization',
                           norm_layer(num_features=out_filters))
        if nonlinearity:
            out.add_module('nonlinearity', nonlinearity)
        return out

    def resnet_block(self, in_filters, out_filters, kernel_size, stride, padding, bias,
                     norm_layer, nonlinearity):
        """Build the body of one residual block (the skip-add happens in ``forward``).

        Layout: ``[nonlinearity] -> conv_0 -> [normalization] -> [nonlinearity] -> conv_1``.
        ``stride`` is applied to both convolutions. ``forward`` adds the
        block's output to its input, which only works when the block preserves
        shape (``in_filters == out_filters`` and ``stride == 1``).
        """
        out = nn.Sequential()
        if nonlinearity:
            out.add_module('nonlinearity_0', nonlinearity)
        out.add_module('conv_0', nn.Conv2d(in_channels=in_filters,
                                           out_channels=out_filters,
                                           kernel_size=kernel_size, stride=stride,
                                           padding=padding, bias=bias,
                                           padding_mode='zeros'))
        if norm_layer:
            out.add_module('normalization',
                           norm_layer(num_features=out_filters))
        if nonlinearity:
            out.add_module('nonlinearity_1', nonlinearity)
        # conv_1 consumes conv_0's output, hence `out_filters` input channels
        # (the original passed `in_filters`, which is only right when the two
        # widths coincide).
        out.add_module('conv_1', nn.Conv2d(in_channels=out_filters,
                                           out_channels=out_filters,
                                           kernel_size=kernel_size, stride=stride,
                                           padding=padding, bias=bias,
                                           padding_mode='zeros'))
        return out

    def upconv_layer_upsample_and_conv(self, in_filters, out_filters, scale_factor, kernel_size, stride, padding, bias,
                                       norm_layer, nonlinearity):
        """Build ``Upsample -> conv -> [normalization] -> [nonlinearity]``.

        NOTE: ``bias`` is accepted but the convolution is always created with
        ``bias=False``. This is kept as-is so that the module's parameter set
        does not change.
        """
        parts = [
            nn.Upsample(scale_factor=scale_factor),
            nn.Conv2d(in_filters, out_filters, kernel_size,
                      stride, padding=padding, bias=False, padding_mode='zeros'),
        ]
        if norm_layer:
            parts.append(norm_layer(num_features=out_filters))
        if nonlinearity:
            parts.append(nonlinearity)
        return nn.Sequential(*parts)




# NOTE: `relu` is rebound to a fresh nn.LeakyReLU(0.2) a few lines below, before
# the Discriminator is defined, so this alias has no lasting effect.
relu = grelu

#####
# Default discriminator
#####

# Activation instance shared by all discriminator blocks.
relu = nn.LeakyReLU(0.2)


class Discriminator(nn.Module):
    """Fully convolutional discriminator that returns a spatial map of scores.

    The network is ``conv0`` (no norm), ``n_layers - 1`` strided blocks
    ``conv_1 .. conv_{n-1}``, one stride-1 block ``conv_{n}`` and a final
    stride-1 ``conv_out`` without norm or activation. The channel width doubles
    per block and is capped at ``8 * num_filters``.

    Args:
        num_filters: output channels of ``conv0``; base of the width schedule.
        input_channels: channels of the input image.
        n_layers: controls the number of blocks as described above.
        norm_layer: ``'batch_norm'`` selects ``nn.BatchNorm2d``; any other
            value (including ``None``) selects ``nn.InstanceNorm2d``.
        use_bias: whether the convolutions carry a bias term.
    """

    def __init__(self, num_filters=12, input_channels=3, n_layers=2,
                 norm_layer='instance_norm', use_bias=True):
        super().__init__()

        self.num_filters = num_filters
        self.input_channels = input_channels
        self.use_bias = use_bias

        # Anything other than 'batch_norm' falls back to instance norm.
        if norm_layer == 'batch_norm':
            self.norm_layer = nn.BatchNorm2d
        else:
            self.norm_layer = nn.InstanceNorm2d

        self.net = self.make_net(
            n_layers, self.input_channels, 1, 4, 2, self.use_bias)

    def make_net(self, n, flt_in, flt_out=1, k=4, stride=2, bias=True):
        """Assemble the discriminator as an ``nn.Sequential``.

        Args:
            n: number of width-doubling stages (see class docstring).
            flt_in: input channels of the first convolution.
            flt_out: output channels of the final ``conv_out`` layer.
            k: kernel size used by every convolution.
            stride: stride of ``conv0`` and of blocks ``conv_1 .. conv_{n-1}``;
                the last two layers always use stride 1.
            bias: whether the convolutions carry a bias term.

        Returns:
            The assembled ``nn.Sequential``.
        """
        padding = 1
        model = nn.Sequential()

        model.add_module('conv0', self.make_block(
            flt_in, self.num_filters, k, stride, padding, bias, None, relu))

        # Strided blocks 1 .. n-1, doubling the width each time (capped at x8).
        flt_mult = 1
        for idx in range(1, n):
            flt_mult_prev, flt_mult = flt_mult, min(2 ** idx, 8)
            model.add_module('conv_%d' % idx, self.make_block(
                self.num_filters * flt_mult_prev, self.num_filters * flt_mult,
                k, stride, padding, bias, self.norm_layer, relu))

        # One more widening block, this time with stride 1.
        flt_mult_prev, flt_mult = flt_mult, min(2 ** n, 8)
        model.add_module('conv_%d' % n, self.make_block(
            self.num_filters * flt_mult_prev, self.num_filters * flt_mult,
            k, 1, padding, bias, self.norm_layer, relu))

        # Final projection: no norm, no activation. Uses `flt_out`, which the
        # original accepted but ignored in favour of a hard-coded 1.
        model.add_module('conv_out', self.make_block(
            self.num_filters * flt_mult, flt_out, k, 1, padding, bias, None, None))
        return model

    def make_block(self, flt_in, flt_out, k, stride, padding, bias, norm, relu):
        """Build ``conv -> [norm] -> [relu]``.

        ``norm`` is a constructor taking the channel count, or ``None``.
        ``relu`` is an activation module instance, or ``None``.
        """
        m = nn.Sequential()
        m.add_module('conv', nn.Conv2d(flt_in, flt_out, k,
                                       stride=stride, padding=padding, bias=bias,
                                       padding_mode='zeros'))
        if norm is not None:
            m.add_module('norm', norm(flt_out))
        if relu is not None:
            m.add_module('relu', relu)
        return m

    def forward(self, x):
        """Return the raw (un-pooled, un-squashed) output of the network for ``x``."""
        return self.net(x)


#####
# Perceptual feature extractor (class is named VGG19 but loads SqueezeNet 1.1)
#####
class PerceptualVGG19(nn.Module):
    """Frozen feature extractor used to compare images in feature space.

    Despite the class name, the backbone loaded here is torchvision's
    SqueezeNet 1.1 (the VGG19 line is commented out). ``forward`` returns the
    activations of the selected ``model.features`` layers, each flattened per
    sample and concatenated along dim 1.

    Args:
        feature_layers: indices into ``model.features`` whose outputs are
            collected.
        use_normalization: when true, inputs are rescaled with ``(x + 1) / 2``
            and then standardised with the per-channel ``mean`` / ``std``
            before being fed to the backbone.
    """

    def __init__(self, feature_layers=(0, 3, 5), use_normalization=False):
        super().__init__()
        # model = models.vgg19(pretrained=True)
        # NOTE(review): newer torchvision releases prefer the `weights=` argument
        # over `pretrained=`; kept so older versions keep working.
        model = models.squeezenet1_1(pretrained=True)
        model.float()
        model.eval()

        self.model = model
        # Private copy so a caller's (or the default) sequence is never shared.
        self.feature_layers = list(feature_layers)

        # Per-channel statistics (the usual ImageNet values).
        self.mean = torch.FloatTensor([0.485, 0.456, 0.406])
        self.std = torch.FloatTensor([0.229, 0.224, 0.225])
        # Broadcastable (1, 3, 1, 1) copies of the above, created lazily on
        # the device of the first input seen by `normalize`.
        self.mean_tensor = None
        self.std_tensor = None

        self.use_normalization = use_normalization

        # The backbone is a fixed feature extractor: no gradients for weights.
        for param in self.parameters():
            param.requires_grad = False

    def normalize(self, x):
        """Standardise ``x`` per channel, or return it untouched when disabled.

        The statistics are kept as ``(1, 3, 1, 1)`` tensors and broadcast, so
        batches of different shapes can be normalised by the same instance.
        They are rebuilt whenever the input lives on a different device.
        """
        if not self.use_normalization:
            return x

        if self.mean_tensor is None or self.mean_tensor.device != x.device:
            self.mean_tensor = self.mean.to(x.device).view(1, 3, 1, 1)
            self.std_tensor = self.std.to(x.device).view(1, 3, 1, 1)

        # presumably inputs are in [-1, 1] and this maps them to [0, 1] -- TODO confirm
        x = (x + 1) / 2
        return (x - self.mean_tensor) / self.std_tensor

    def run(self, x):
        """Feed ``x`` through ``model.features`` and gather the selected layers.

        Only layers up to ``max(feature_layers)`` are evaluated.

        Returns:
            Tensor of shape ``(batch, total_features)``.

        Raises:
            ValueError: if ``feature_layers`` is empty.
        """
        if not self.feature_layers:
            raise ValueError("feature_layers must contain at least one layer index")

        wanted = set(self.feature_layers)
        collected = []
        h = x
        for idx in range(max(wanted) + 1):
            h = self.model.features[idx](h)
            if idx in wanted:
                # Clone so that a later in-place layer cannot overwrite the
                # captured activation.
                collected.append(h.clone().view(h.size(0), -1))

        return torch.cat(collected, dim=1)

    def forward(self, x):
        """Normalise ``x`` (if enabled) and return its concatenated features."""
        return self.run(self.normalize(x))