import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional
from timm.models.layers import trunc_normal_


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        # Normalize over the channel dimension (dim 1) of an NCHW tensor.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
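

# Usage sketch for LayerNorm (illustrative shapes only; not part of the module API):
#     ln = LayerNorm(16)
#     y = ln(torch.randn(2, 16, 8, 8))   # stats taken over the 16-channel dimension
#     assert y.shape == (2, 16, 8, 8)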


def conv_block(
    din: int,
    dout: int,
    k: int,
    act: nn.Module = nn.Identity,
    depthwise: bool = False,
    bottleneck: bool = False,
    batchnorm: bool = False,
    separable: bool = False,
    batchnorm_type: str = "bn",
    **kwargs
):
    if batchnorm_type == "bn":
        bn_fun = lambda x: nn.BatchNorm2d(x, eps=1e-2)
    elif batchnorm_type == "frn":
        bn_fun = lambda x: FRN(x)
    else:
        raise NotImplementedError(f"batchnorm_type {batchnorm_type} not implemented")

    if not depthwise:
        d1 = din // 4 if bottleneck else din
        kwargs = kwargs.copy()
        # NOTE: bias and batchnorm are currently forced on in this branch, overriding the
        # corresponding arguments (the original expressions were `not batchnorm`).
        kwargs["bias"] = True
        batchnorm = True
        mods = []
        mods.append(nn.Conv2d(d1, dout, k, **kwargs))
        if batchnorm:
            mods.append(bn_fun(dout))
        mods.append(act())
        if bottleneck:
            # The 1 x 1 bottleneck reduces din -> din // 4 before the k x k conv above,
            # so it has to come first in the module list.
            blayer = [nn.Conv2d(din, din // 4, 1, bias=(not batchnorm))]
            if batchnorm:
                blayer.append(LayerNorm(din // 4, eps=1e-2))
            mods = blayer + mods
    elif separable:
        kwargs = kwargs.copy()
        kwargs2 = kwargs.copy()
        kwargs3 = kwargs.copy()
        kwargs["groups"] = din
        kwargs["padding"] = [0, k//2]
        kwargs["bias"] = True
        kwargs2["groups"] = din
        kwargs2["padding"] = [k//2, 0]
        kwargs2["bias"] = False
        kwargs3["groups"] = 1
        kwargs3["padding"] = 0
        kwargs3["bias"] = (not batchnorm)
        # Factorized depthwise (1 x k) then (k x 1) convs, followed by a 1 x 1 pointwise conv.
        mods = [
            nn.Conv2d(din, din, (1, k), **kwargs),
            nn.Conv2d(din, din, (k, 1), **kwargs2),
            nn.Conv2d(din, dout, 1, **kwargs3),
        ]
        if batchnorm:
            mods.append(bn_fun(dout))
        mods.append(act())
    else:
        kwargs = kwargs.copy()
        kwargs2 = kwargs.copy()
        kwargs["groups"] = din
        kwargs["padding"] = k//2
        kwargs["bias"] = True
        kwargs2["groups"] = 1
        kwargs2["padding"] = 0
        kwargs2["bias"] = True  # 
        mods = [
            nn.Conv2d(din, din, (k, k), **kwargs),
            nn.Conv2d(din, dout, 1, **kwargs2)
        ]
        if batchnorm:
            mods.append(bn_fun(dout))
        mods.append(act())
    mod = nn.Sequential(*mods)
    return mod
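

# Usage sketch for conv_block (illustrative channel counts and kernel size):
#     block = conv_block(32, 64, 3, act=nn.SiLU, padding=1, batchnorm=True)
#     y = block(torch.randn(2, 32, 16, 16))   # -> (2, 64, 16, 16)
#     # Depthwise-separable variant: (1, k) and (k, 1) depthwise convs + 1 x 1 pointwise conv.
#     dw = conv_block(32, 64, 3, act=nn.SiLU, depthwise=True, separable=True)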


class FRN(nn.Module):
    def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
        """
        weight = gamma, bias = beta
        beta, gamma:
            Variables of shape [1, 1, 1, C]. if TensorFlow
            Variables of shape [1, C, 1, 1]. if PyTorch
        eps: A scalar constant or learnable variable.
        """
        super(FRN, self).__init__()

        self.num_features = num_features
        self.init_eps = eps
        self.is_eps_leanable = is_eps_leanable

        self.weight = nn.parameter.Parameter(
            torch.Tensor(1, num_features, 1, 1), requires_grad=True)
        self.bias = nn.parameter.Parameter(
            torch.Tensor(1, num_features, 1, 1), requires_grad=True)
        if is_eps_leanable:
            self.eps = nn.parameter.Parameter(torch.Tensor(1), requires_grad=True)
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.is_eps_leanable:
            nn.init.constant_(self.eps, self.init_eps)

    def extra_repr(self):
        return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)

    def forward(self, x):
        """
        Axis convention: (B, H, W, C) in TensorFlow, (B, C, H, W) in PyTorch.
        Reference TensorFlow code:
            nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
            x = x * tf.rsqrt(nu2 + tf.abs(eps))
            # The reference also applies the TLU activation max(y, tau);
            # that step is omitted here.
            return tf.maximum(gamma * x + beta, tau)
        """
        # Compute the mean norm of activations per channel.
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)

        # Perform FRN.
        x = x * torch.rsqrt(nu2 + self.eps.abs())

        # Scale and Bias
        x = self.weight * x + self.bias
        return x
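

# FRN in effect computes y = weight * x * rsqrt(mean(x^2 over H, W) + |eps|) + bias per channel.
# Usage sketch (illustrative shapes):
#     frn = FRN(16)
#     y = frn(torch.randn(2, 16, 8, 8))   # -> (2, 16, 8, 8)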



class res_layers(nn.Module):
    def __init__(self, d: int, n: int, k: int, act: nn.Module, **kwargs):
        super().__init__()
        # NOTE: extra kwargs are accepted for interface compatibility but are not forwarded
        # to the residual conv blocks below.
        self.convs = nn.ModuleList(
            [conv_block(d, d, k, padding=k//2, batchnorm=True, act=nn.Identity) for _ in range(n)]
        )
        self.acts = nn.ModuleList([act() for _ in range(n)])

    def forward(self, inputs: Tensor):
        x = inputs
        for f, act in zip(self.convs, self.acts):
            x = act(f(x) + x)
        return x


def down_layer(d: int, k: int, act: nn.Module, num_res=0, pool=nn.MaxPool2d, **kwargs):
    modules = [
        pool(2, 2, 0), conv_block(d // 2, d, k, act, **kwargs)
    ]
    if num_res > 0:
        modules.append(res_layers(d, num_res, k, act, **kwargs))
    return nn.Sequential(*modules)


class up_layer(nn.Module):
    def __init__(self, d: int, k: int, act: nn.Module, num_res=0, **kwargs):
        super().__init__()
        modules = []
        if num_res > 0:
            modules.append(res_layers(2 * d, num_res, k, act, **kwargs))
        modules.append(nn.UpsamplingBilinear2d(scale_factor=2))
        modules.append(conv_block(2 * d, d, k, act, **kwargs))
        self.net1 = nn.Sequential(*modules)
        self.net2 = conv_block(2 * d, d, k, act, **kwargs)

    def forward(self, x: Tensor, z: Tensor):
        x = self.net1(x)
        x = torch.cat([x, z], 1)
        x = self.net2(x)
        return x


class UNetEncoder(nn.Module):
    def __init__(
        self,
        cin,
        cout,
        ksize=3,
        depth=3,
        num_res=1,
        groups: int = 1,
        n_hidden=16,
        act=nn.SiLU,
        pool='MaxPool2d',
        dropout: bool = False,
        bottleneck: bool = False,
        depthwise: bool = False,
        separable: bool = False,
        batchnorm: bool = False,
        batchnorm_type: str = "bn",
    ) -> None:
        super().__init__()
        self.cin = cin
        self.cout = cout
        self.num_res = num_res
        self.depth = depth
        self.n_hidden = n_hidden
        self.ksize = ksize
        self.groups = groups
        d, k = n_hidden, ksize
        self.act = act
        self.batchnorm = batchnorm
        act = getattr(nn, act) if isinstance(act, str) else act
        pool = getattr(nn, pool) if isinstance(pool, str) else pool

        kwargs = dict(
            bias=True,
            groups=groups,
            padding=k // 2,
            padding_mode="replicate",
            bottleneck=bottleneck,
            depthwise=depthwise,
            separable=separable,
            batchnorm=batchnorm,
            batchnorm_type=batchnorm_type,
        )

        kwargs1 = kwargs.copy()
        kwargs1["groups"] = 1
        self.first = nn.Sequential(
            nn.Conv2d(cin, d, k, padding=k//2),
            # LayerNorm(d, eps=1e-6),
            FRN(d, eps=1e-2),
            act(),
            # conv_block(cin, d, k, act, **kwargs1),
            res_layers(d, num_res, k, act, **kwargs)
        )

        self.down = nn.ModuleList()
        for _ in range(depth):
            d *= 2
            layer = down_layer(d, k, act, num_res=num_res, pool=pool, **kwargs)
            self.down.append(layer)

        if dropout:
            self.dropout = nn.Dropout2d(0.5)

        self.up = nn.ModuleList()
        for _ in range(depth):
            d //= 2
            layer = up_layer(d, k, act, num_res=num_res, **kwargs)
            self.up.append(layer)

        kwargsf = kwargs.copy()
        kwargsf["groups"] = 1
        self.final = nn.Sequential(
            res_layers(d, num_res, k, act, **kwargs),
            nn.Conv2d(d, cout, k, padding=k//2)
            # conv_block(d, cout, k, **kwargsf)
        )
        self.apply(self._init_weights)

    def forward(self, inputs: Tensor):
        x = inputs
        x = self.first(x)

        # downsampling
        intermediate_outputs = [x]
        for _, fdown in enumerate(self.down):
            x = fdown(x)
            intermediate_outputs.append(x)

        # dropout
        if hasattr(self, "dropout"):
            x = self.dropout(x)

        # upsampling
        for i, fup in enumerate(self.up):
            z = intermediate_outputs[-(i + 2)]
            x = fup(x, z)

        # final layers and desired size
        x = self.final(x)

        return x
    
    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
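

# Usage sketch for UNetEncoder (illustrative; spatial dims should be divisible by 2 ** depth
# so that the skip connections line up after pooling and upsampling):
#     enc = UNetEncoder(cin=1, cout=8, depth=3, n_hidden=16)
#     y = enc(torch.randn(2, 1, 64, 64))   # -> (2, 8, 64, 64)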


class Decoder(nn.Module):
    def __init__(self, din, dout, n_hidden=64, n_res=0, act=nn.SiLU, offset=False, collapse_dims=True, loss_type="mse", fit_sigma: bool = False, **kwargs):
        super().__init__()
        if offset:
            self.spatial_emb = SpatialEmbeddings(din, n_hidden, n_hidden)
        self.din = din
        self.dout = dout
        self.collapse_dims = collapse_dims
        self.hin = n_hidden if offset else din
        self.head = nn.Sequential(
            conv_block(self.hin, n_hidden, 1, act, **kwargs),
            *[conv_block(n_hidden, n_hidden, 1, act, **kwargs) for _ in range(n_res)],
            nn.Conv2d(n_hidden, dout, 1),
        )
        self.loss_type = loss_type
        if loss_type == "mse":
            self.log_sig = nn.Parameter(torch.tensor(0.0), requires_grad=fit_sigma)

        # self.head = conv_block(n_hidden, n_hidden, 1, act, **kwargs)

    def forward_offset(self, latent, offset):
        # Mix the latent channels with position-dependent weights predicted from the offsets.
        spatial = self.spatial_emb(offset)   # (bs, D, D1)
        bs, D, D1 = spatial.shape
        *_, nr, nc = latent.shape
        spatial = spatial.view(bs, D, D1, 1, 1)
        latent = latent.view(bs, D, 1, nr, nc)
        v = (latent * spatial).view(bs, D, D1, nr, nc)
        out = v.sum(1)   # (bs, D1, nr, nc)
        out = self.head(out)
        return out

    def forward(self, latent, offset=None):
        if offset is not None:
            out = self.forward_offset(latent, offset)
        else:
            out = self.head(latent)
        if self.collapse_dims:
            out = out.squeeze(1)
        return out

    def loss_binary(
        self,
        tgt: Tensor,
        C: Optional[Tensor] = None,
        M: Optional[Tensor] = None,
    ):
        if M is None:
            M = torch.ones_like(tgt)
        L = self(C)
        if len(L.shape) == 0:
            L = torch.ones_like(tgt) * L
        # L = L.clip(-5.0, 5.0)
        # loss = torch.where(
        #     L >= 0.0,
        #     torch.log(1.0 + torch.exp(-L)) + (1.0 - tgt) * L,
        #     torch.log(1.0 + torch.exp(L)) - tgt * L,
        # )
        loss = F.binary_cross_entropy_with_logits(L, tgt, reduction='none')
        loss = (loss * M).sum() / tgt.shape[0]
        return L, loss
    
    def loss_mse(
        self,
        Y: Tensor,
        C: Optional[Tensor] = None,
        A: Optional[Tensor] = None,
        M: Optional[Tensor] = None,
    ):
        if M is None:
            M = torch.ones_like(Y)
        Yhat = self.forward(C, A) * M
        # return Yhat, 0.0
        sig = (1e-4 + self.log_sig).exp()
        prec = 1.0 / (1e-4 + sig ** 2)
        err = 0.5 * prec * F.mse_loss(Yhat, Y, reduction="none") + self.log_sig
        prior = sig - self.log_sig
        # err = 0.5 * F.mse_loss(Yhat, Y, reduction="none") * M
        loss = ((err * M).sum() + prior)
        loss = loss / Y.shape[0]
        return Yhat, loss

    def loss(self, *args, **kwargs):
        if self.loss_type == "binary":
            l = self.loss_binary(*args, **kwargs)
        elif self.loss_type == "mse":
            l = self.loss_mse(*args, **kwargs)
        else:
            raise NotImplementedError(f"loss_type {self.loss_type} not implemented")
        return l


class SpatialEmbeddings(nn.Module):
    def __init__(self, dout, nl, n_hidden=64, act=nn.SiLU, n_chunks=1):
        super().__init__()
        self.D = dout
        self.nl = nl
        act = getattr(nn, act) if isinstance(act, str) else act
        self.dec = nn.Sequential(
            nn.Linear(2, n_hidden),
            act(),
            nn.Linear(n_hidden, n_hidden),
            act(),
            nn.Linear(n_hidden, dout * nl),
        )

    def forward(self, inputs):
        u = inputs
        # Only the raw 2-D coordinates are used; the Fourier-feature encodings below remain
        # commented out.
        u = (
            [u] #+
            # [torch.sin(u / (1000 ** (2 * k / 12))) for k in range(0, 11, 2)]
            # + [torch.cos(u / (1000 ** (2 * k / 12))) for k in range(0, 11, 2)]
            # [torch.sin(2 * np.pi * u / k) for k in range(1, 12, 2)] +
            # [torch.cos(2 * np.pi * u / k) for k in range(1, 12, 2)]
        )
        u = torch.cat(u, -1)
        out = self.dec(u)
        out = out.view(-1, self.D, self.nl)
        return out
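

if __name__ == "__main__":
    # Minimal smoke test: a sketch only, with arbitrary shapes and hyper-parameters,
    # wiring UNetEncoder features into a Decoder head and evaluating the MSE loss.
    enc = UNetEncoder(cin=1, cout=8, depth=2, n_hidden=16)
    dec = Decoder(din=8, dout=1, n_hidden=32, loss_type="mse")

    x = torch.randn(2, 1, 32, 32)         # spatial dims divisible by 2 ** depth
    latent = enc(x)                       # (2, 8, 32, 32)
    y = torch.randn(2, 32, 32)            # dout == 1 is squeezed away by collapse_dims
    yhat, loss = dec.loss(y, C=latent)
    print(yhat.shape, float(loss))

    # Position-dependent mixing weights from 2-D offsets (used by Decoder when offset=True).
    emb = SpatialEmbeddings(dout=8, nl=4)
    w = emb(torch.randn(2, 2))            # -> (2, 8, 4)
    print(w.shape)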