import tensorflow as tf
import tensorflow.python.keras.backend as K
from tensorflow.python.eager import context
from tensorflow.python.ops import (
    gen_math_ops,
    math_ops,
    sparse_ops,
    standard_ops,
)


def l2normalize(v, eps=1e-12):
    """Divide ``v`` by its norm, with ``eps`` added to the denominator.

    Args:
        v: Tensor to rescale.
        eps: Small constant added to ``tf.norm(v)`` so the division stays
            finite when the norm is zero.

    Returns:
        ``v / (tf.norm(v) + eps)``.
    """
    magnitude = tf.norm(v)
    return v / (magnitude + eps)


class ConvSN2D(tf.keras.layers.Conv2D):
    """``Conv2D`` whose kernel is rescaled by a power-iteration estimate of its
    largest singular value on every call.

    Args:
        filters: Forwarded to ``tf.keras.layers.Conv2D``.
        kernel_size: Forwarded to ``tf.keras.layers.Conv2D``.
        power_iterations: Number of power-iteration steps per call. Must be
            at least 1.
        datatype: Stored on the instance as ``self.datatype``; this class does
            not read it anywhere else.
        **kwargs: Forwarded to ``tf.keras.layers.Conv2D``.

    Raises:
        ValueError: If ``power_iterations`` is smaller than 1.
    """

    def __init__(self, filters, kernel_size, power_iterations=1, datatype=tf.float32, **kwargs):
        super(ConvSN2D, self).__init__(filters, kernel_size, **kwargs)
        # With zero iterations the power-iteration loop never runs, so the
        # vector needed to compute sigma would not exist at call time.
        if power_iterations < 1:
            raise ValueError(
                "power_iterations must be >= 1, got {}".format(power_iterations)
            )
        self.power_iterations = power_iterations
        self.datatype = datatype

    def build(self, input_shape):
        """Build the parent convolution, then add the persistent vector ``u``."""
        super(ConvSN2D, self).build(input_shape)

        # One entry per slice along the kernel's last axis; non-trainable
        # because it is overwritten by `assign` in compute_spectral_norm.
        self.u = self.add_weight(
            name=self.name + "_u",
            shape=(1, self.kernel.shape.as_list()[-1]),
            initializer=tf.initializers.RandomNormal(0, 1),
            trainable=False,
            dtype=self.dtype,
        )

    def compute_spectral_norm(self, W, new_u, W_shape):
        """Return ``W`` divided by its estimated spectral norm, reshaped to ``W_shape``.

        Args:
            W: Kernel flattened to two dimensions, ``(-1, W_shape[-1])``.
            new_u: Starting row vector for the power iteration.
            W_shape: Shape the normalized kernel is reshaped back to.

        Returns:
            The normalized kernel with shape ``W_shape``. ``self.u`` is
            assigned the final iterate as a control dependency of the reshape.
        """
        for _ in range(self.power_iterations):
            new_v = l2normalize(tf.matmul(new_u, tf.transpose(W)))
            new_u = l2normalize(tf.matmul(new_v, W))

        # sigma = v W u^T, a 1x1 tensor that broadcasts over W.
        sigma = tf.matmul(tf.matmul(new_v, W), tf.transpose(new_u))
        W_bar = W / sigma

        # Tie the update of the stored vector to the tensor that is returned.
        with tf.control_dependencies([self.u.assign(new_u)]):
            W_bar = tf.reshape(W_bar, W_shape)

        return W_bar

    def call(self, inputs):
        """Convolve ``inputs`` with the normalized kernel, then apply bias and activation."""
        W_shape = self.kernel.shape.as_list()
        # Collapse every axis except the last so the kernel is a 2-D matrix.
        W_reshaped = tf.reshape(self.kernel, (-1, W_shape[-1]))
        new_kernel = self.compute_spectral_norm(W_reshaped, self.u, W_shape)
        outputs = self._convolution_op(inputs, new_kernel)

        if self.use_bias:
            if self.data_format == "channels_first":
                outputs = tf.nn.bias_add(outputs, self.bias, data_format="NCHW")
            else:
                outputs = tf.nn.bias_add(outputs, self.bias, data_format="NHWC")
        if self.activation is not None:
            return self.activation(outputs)

        return outputs


class DenseSN(tf.keras.layers.Dense):
    """``Dense`` layer whose kernel is rescaled by a one-step power-iteration
    estimate of its largest singular value on every call.

    Args:
        datatype: Stored on the instance as ``self.datatype``.
        **kwargs: Forwarded to ``tf.keras.layers.Dense``. Because ``datatype``
            is the first positional parameter, ``units`` has to be passed by
            keyword.
    """

    def __init__(self, datatype=tf.float32, **kwargs):
        super(DenseSN, self).__init__(**kwargs)
        self.datatype = datatype

    def build(self, input_shape):
        """Build the parent dense layer, then add the persistent vector ``u``."""
        super(DenseSN, self).build(input_shape)

        # `u` is multiplied with the kernel in compute_spectral_norm, so it is
        # created with the layer's own dtype (as ConvSN2D does) rather than
        # `self.datatype`, which may differ from the dtype of the kernel
        # created by the parent build.
        self.u = self.add_weight(
            name=self.name + "_u",
            shape=(1, self.kernel.shape.as_list()[-1]),
            initializer=tf.initializers.RandomNormal(0, 1),
            trainable=False,
            dtype=self.dtype,
        )

    def compute_spectral_norm(self, W, new_u, W_shape):
        """Return ``W`` divided by its estimated spectral norm, reshaped to ``W_shape``.

        Args:
            W: Kernel as a 2-D matrix, ``(-1, W_shape[-1])``.
            new_u: Starting row vector for the single power-iteration step.
            W_shape: Shape the normalized kernel is reshaped back to.

        Returns:
            The normalized kernel with shape ``W_shape``. ``self.u`` is
            assigned the updated vector as a control dependency of the reshape.
        """
        new_v = l2normalize(tf.matmul(new_u, tf.transpose(W)))
        new_u = l2normalize(tf.matmul(new_v, W))
        # sigma = v W u^T, a 1x1 tensor that broadcasts over W.
        sigma = tf.matmul(tf.matmul(new_v, W), tf.transpose(new_u))
        W_bar = W / sigma
        # Tie the update of the stored vector to the tensor that is returned.
        with tf.control_dependencies([self.u.assign(new_u)]):
            W_bar = tf.reshape(W_bar, W_shape)
        return W_bar

    def call(self, inputs):
        """Multiply ``inputs`` by the normalized kernel, then apply bias and activation."""
        W_shape = self.kernel.shape.as_list()
        W_reshaped = tf.reshape(self.kernel, (-1, W_shape[-1]))
        new_kernel = self.compute_spectral_norm(W_reshaped, self.u, W_shape)
        rank = len(inputs.shape)
        if rank > 2:
            # Contract the last input axis with the first kernel axis.
            outputs = standard_ops.tensordot(inputs, new_kernel, [[rank - 1], [0]])
            if not context.executing_eagerly():
                # Restore the static shape in graph mode.
                shape = inputs.shape.as_list()
                output_shape = shape[:-1] + [self.units]
                outputs.set_shape(output_shape)
        else:
            inputs = math_ops.cast(inputs, self._compute_dtype)
            if K.is_sparse(inputs):
                outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, new_kernel)
            else:
                outputs = gen_math_ops.mat_mul(inputs, new_kernel)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs


class AddNoise(tf.keras.layers.Layer):
    """Add standard-normal noise, scaled by one trainable scalar, to the input.

    Args:
        datatype: dtype of the generated noise tensor.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, datatype=tf.float32, **kwargs):
        super(AddNoise, self).__init__(**kwargs)
        self.datatype = datatype

    def build(self, input_shape):
        """Create the trainable noise scale, initialized to zero."""
        self.b = self.add_weight(
            shape=[
                1,
            ],
            initializer=tf.keras.initializers.zeros(),
            trainable=True,
            name="noise_weight",
        )

    def call(self, inputs):
        """Return ``inputs + b * noise``.

        The noise has size 1 on its last axis, so one sample is broadcast
        across the last axis of ``inputs``. Indexing axes 0-2 and appending a
        fourth axis means ``inputs`` is expected to be rank 4.
        """
        static_shape = inputs.shape
        dynamic_shape = tf.shape(inputs)
        # Use the static size when it is known; otherwise read it at run time
        # so inputs with undefined sizes on axes 1 and 2 are supported too.
        dim1 = static_shape[1] if static_shape[1] is not None else dynamic_shape[1]
        dim2 = static_shape[2] if static_shape[2] is not None else dynamic_shape[2]
        rand = tf.random.normal(
            [dynamic_shape[0], dim1, dim2, 1],
            mean=0.0,
            stddev=1.0,
            dtype=self.datatype,
        )
        output = inputs + self.b * rand
        return output


class PosEnc(tf.keras.layers.Layer):
    """Append a normalized position index as one extra trailing feature.

    Args:
        datatype: dtype of the generated position tensor. It must match the
            dtype of the inputs for the final concatenation to succeed.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, datatype=tf.float32, **kwargs):
        super(PosEnc, self).__init__(**kwargs)
        self.datatype = datatype

    def call(self, inputs):
        """Concatenate ``index / size`` along axis -3 onto the last axis of ``inputs``.

        The sizes of axes -3 and -2 are read from the static shape, so they
        must be known. The position tensor is rank 4, so ``inputs`` must be
        rank 4 as well for the concatenation.
        """
        steps = inputs.shape[-3]
        # Index 0..steps-1 laid out on axis 1, then copied across axis -2.
        pos = tf.repeat(
            tf.reshape(tf.range(steps, dtype=tf.int32), [1, -1, 1, 1]),
            inputs.shape[-2],
            -2,
        )
        # Copy across the batch axis using the run-time batch size.
        pos = tf.repeat(pos, tf.shape(inputs)[0], 0)
        # Numerator and denominator share one dtype so the division is valid.
        pos = tf.cast(pos, self.datatype) / tf.cast(steps, self.datatype)
        # NOTE(review): the original comment gave the result layout as
        # [bs,1,hop,2]; with a size-1 axis -3 every position would be 0.
        # Confirm which axis is meant to carry the position.
        return tf.concat([inputs, pos], -1)


def flatten_hw(x, data_format="channels_last"):
    if data_format == "channels_last":
        x = tf.transpose(x, perm=[0, 3, 1, 2])  # Convert to `channels_first`

    old_shape = tf.shape(x)