import torch
import torch.nn as nn
import torch.nn.functional as F
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
import einops
from modules.util import UpBlock2d, DownBlock2d


def make_coordinate_grid(spatial_size, type):
    """
    Build a dense 3D grid of normalized coordinates.

    Args:
        spatial_size: (d, h, w) sizes of the grid.
        type: tensor type string passed to ``Tensor.type`` (e.g. ``'torch.FloatTensor'``).
    Returns:
        Tensor of shape (d, h, w, 3). The last axis holds (x, y, z): x varies along
        w, y along h and z along d, each mapped linearly from 0..n-1 onto [-1, 1].
        A size-1 axis divides by n - 1 == 0 and therefore yields NaN.
    """
    d, h, w = spatial_size

    def axis(n):
        # Map the integer positions 0..n-1 onto the interval [-1, 1].
        values = torch.arange(n).type(type)
        return 2 * (values / (n - 1)) - 1

    xs, ys, zs = axis(w), axis(h), axis(d)
    grid_shape = (d, h, w)
    # Broadcast each 1D axis over the full volume and stack along a new last dim.
    return torch.stack(
        [
            xs.view(1, 1, -1).expand(*grid_shape),
            ys.view(1, -1, 1).expand(*grid_shape),
            zs.view(-1, 1, 1).expand(*grid_shape),
        ],
        dim=3,
    )


def kp2gaussian_3d(kp, spatial_size, kp_variance):
    """
    Render 3D keypoints as isotropic Gaussian heatmaps on a normalized grid.

    Args:
        kp: tensor of shape [..., 3]. The last axis is an (x, y, z) position in the
            same coordinate frame as ``make_coordinate_grid`` (channel order x, y, z).
        spatial_size: (d, h, w) of the output volume.
        kp_variance: variance of the Gaussian; the squared distance is divided by it.
    Returns:
        Tensor of shape [..., d, h, w] equal to
        exp(-0.5 * ||grid - kp||^2 / kp_variance).
    """
    mean = kp

    # (d, h, w, 3) grid of coordinates, built with the same tensor type as kp.
    coordinate_grid = make_coordinate_grid(spatial_size, mean.type())

    # Shape keypoints as [..., 1, 1, 1, 3] so they broadcast against the grid.
    # This avoids materialising one copy of the grid per keypoint via repeat().
    # reshape (rather than view) also accepts non-contiguous keypoint tensors.
    mean = mean.reshape(tuple(mean.shape[:-1]) + (1, 1, 1, 3))

    mean_sub = coordinate_grid - mean

    return torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)


class ResBlock3d(nn.Module):
    """
    Pre-activation 3D residual block: (BN -> ReLU -> Conv3d) applied twice, then an
    identity skip connection is added.

    The channel count is fixed at ``in_features``. Spatial resolution is preserved
    when ``kernel_size``/``padding`` are chosen accordingly (e.g. 3 / 1).
    """

    def __init__(self, in_features, kernel_size, padding):
        super(ResBlock3d, self).__init__()
        # Both convolutions share the same channel count and geometry.
        conv_kwargs = dict(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.conv1 = nn.Conv3d(**conv_kwargs)
        self.conv2 = nn.Conv3d(**conv_kwargs)
        self.norm1 = BatchNorm3d(in_features, affine=True)
        self.norm2 = BatchNorm3d(in_features, affine=True)

    def forward(self, x):
        residual = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            x = conv(F.relu(norm(x)))
        # In-place add on the freshly computed conv output; the input is untouched.
        x += residual
        return x


class rgb_predictor(nn.Module):
    """
    Lift a 2D feature map into a per-floor, channel-last feature volume.

    A 3x3 convolution maps ``in_channels`` to ``simpled_channel`` channels. These are
    split as (c f) with ``f = floor_num``, so the floor index varies fastest within
    the channel axis. The result is then moved to channel-last layout.
    """

    def __init__(self, in_channels, simpled_channel=128, floor_num=8):
        super(rgb_predictor, self).__init__()
        self.floor_num = floor_num
        self.down_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=simpled_channel,
            kernel_size=3,
            padding=1,
        )

    def forward(self, feature):
        """
        Args:
            feature: warp feature: bs * c * h * w
        Returns:
            rgb: bs * h * w * floor_num * (simpled_channel // floor_num)
        """
        projected = self.down_conv(feature)
        # Split the channels into (c, floor) and go straight to channel-last.
        return einops.rearrange(projected, 'b (c f) h w -> b h w f c', f=self.floor_num)


class sigma_predictor(nn.Module):
    """
    Predict a per-floor density feature volume from a 2D feature map.

    A 3x3 convolution maps ``in_channels`` to ``simpled_channel`` channels. These are
    split into ``floor_num`` depth slices of ``simpled_channel // floor_num`` channels
    each, and refined by a stack of 3D residual blocks.
    """

    def __init__(self, in_channels, simpled_channel=128, floor_num=8):
        super(sigma_predictor, self).__init__()
        if simpled_channel % floor_num != 0:
            raise ValueError(
                f"simpled_channel ({simpled_channel}) must be divisible by "
                f"floor_num ({floor_num})"
            )
        self.floor_num = floor_num
        self.down_conv = nn.Conv2d(in_channels=in_channels, out_channels=simpled_channel, kernel_size=3, padding=1)

        # The 3D blocks see the per-floor channel count produced by the
        # '(c f)' split in forward(), so their width must be derived from it.
        channels = simpled_channel // floor_num
        # NOTE(review): these are torch.nn.BatchNorm3d, not the synchronized
        # BatchNorm3d used inside ResBlock3d -- confirm this mix is intended.
        self.res_conv3d = nn.Sequential(
            ResBlock3d(channels, 3, 1),
            nn.BatchNorm3d(channels),
            ResBlock3d(channels, 3, 1),
            nn.BatchNorm3d(channels),
            ResBlock3d(channels, 3, 1),
            nn.BatchNorm3d(channels)
        )

    def forward(self, feature):
        """
        Args:
            feature: bs * c_in * h * w feature map (consumed by a Conv2d).
        Returns:
            sigma: bs * h * w * floor * (simpled_channel // floor_num), channel-last.
            point_dict: {'sigma_map': the same volume before the channel-last
                rearrange, i.e. bs * (simpled_channel // floor_num) * floor * h * w}.
        """
        heatmap = self.down_conv(feature)
        heatmap = einops.rearrange(heatmap, "b (c f) h w -> b c f h w", f=self.floor_num)
        heatmap = self.res_conv3d(heatmap)
        sigma = einops.rearrange(heatmap, "b c f h w -> b h w f c")

        point_dict = {'sigma_map': heatmap}
        return sigma, point_dict


class MultiHeadNeRFModel(torch.nn.Module):
    """
    Two-branch MLP applied independently to every (bs, h, w, floor) sample.

    The sigma branch (layer1 -> layer2) yields a scalar density (layer3_1) and a
    feature vector (layer3_2). The rgb branch embeds the rgb encoding (layer3_3).
    Both features are concatenated and decoded (layer4 -> layer6) to a 256-d output.
    """

    def __init__(self, hidden_size=128, num_encoding_rgb=16, num_encoding_sigma=16):
        super(MultiHeadNeRFModel, self).__init__()
        self.xyz_encoding_dims = num_encoding_sigma
        self.viewdir_encoding_dims = num_encoding_rgb

        # Input layer (default: 16 -> 128)
        self.layer1 = torch.nn.Linear(self.xyz_encoding_dims, hidden_size)
        # Layer 2 (default: 128 -> 128)
        self.layer2 = torch.nn.Linear(hidden_size, hidden_size)
        # Layer 3_1 (default: 128 -> 1): predicts the raw volume density ("sigma")
        self.layer3_1 = torch.nn.Linear(hidden_size, 1)
        # Layer 3_2 (default: 128 -> 32): feature vector from the sigma branch
        self.layer3_2 = torch.nn.Linear(hidden_size, hidden_size // 4)
        # Layer 3_3 (default: 16 -> 128): embeds the rgb encoding
        self.layer3_3 = torch.nn.Linear(self.viewdir_encoding_dims, hidden_size)

        # Layer 4 (default: 32 + 128 -> 128)
        self.layer4 = torch.nn.Linear(
            hidden_size // 4 + hidden_size, hidden_size
        )
        # Layer 5 (default: 128 -> 128)
        self.layer5 = torch.nn.Linear(hidden_size, hidden_size)
        # Layer 6 (default: 128 -> 256): colour feature that RenderModel volume-renders
        # and then decodes to RGB with its mini_decoder.
        self.layer6 = torch.nn.Linear(hidden_size, 256)

        # Short hand for torch.nn.functional.relu
        self.relu = torch.nn.functional.relu

    def forward(self, rgb_in, sigma_in):
        """
        Args:
            rgb_in: [bs, h, w, floor, num_encoding_rgb] tensor (rgb_predictor output).
            sigma_in: [bs, h, w, floor, num_encoding_sigma] tensor (sigma_predictor output).
        Returns:
            (x, sigma): x is [bs, h, w, floor, 256]; sigma is [bs, h, w, floor, 1]
            with no activation applied.
        Raises:
            ValueError: if rgb_in is not 5-dimensional.
        """
        if rgb_in.dim() != 5:
            raise ValueError(
                f"rgb_in must be 5-D [bs, h, w, floor, c], got shape {tuple(rgb_in.shape)}"
            )
        out = self.relu(self.layer1(sigma_in))
        out = self.relu(self.layer2(out))
        sigma = self.layer3_1(out)
        feat_sigma = self.relu(self.layer3_2(out))
        feat_rgb = self.relu(self.layer3_3(rgb_in))
        x = torch.cat((feat_sigma, feat_rgb), dim=-1)
        x = self.relu(self.layer4(x))
        x = self.relu(self.layer5(x))
        x = self.layer6(x)
        return x, sigma


def volume_render(rgb_pred, sigma_pred):
    """
    Alpha-composite per-floor colours along the floor axis (dim 3).

    Discrete volume rendering with unit sample spacing:
        C = sum_i T_i * (1 - exp(-s_i)) * rgb_i,   T_i = exp(-sum_{j<i} s_j)
    where s = relu(sigma_pred). T_i is exclusive: a sample attenuates the samples
    behind it, not itself. The weights therefore sum to 1 - exp(-sum_i s_i).

    Args:
        rgb_pred: result of Nerf, [bs, h, w, floor, rgb_channel]
        sigma_pred: result of Nerf, [bs, h, w, floor, sigma_channel]; must broadcast
            against rgb_pred (e.g. sigma_channel == 1).
    Returns:
        Tensor of shape [bs, rgb_channel, h, w].
    """
    sigma = torch.nn.functional.relu(sigma_pred)
    # Exclusive prefix sum over floors: optical depth accumulated before sample i.
    optical_depth = torch.cumsum(sigma, dim=3) - sigma
    weights = torch.exp(-optical_depth) * (1 - torch.exp(-sigma))
    c = (weights * rgb_pred).sum(dim=3)
    # [bs, h, w, c] -> [bs, c, h, w]
    return c.permute(0, 3, 1, 2)


class RenderModel(nn.Module):
    """
    Feature map -> per-floor rgb/sigma encodings -> NeRF MLP -> volume render -> decoder.

    forward() returns a dict with:
        'render': sigmoid of the volume-rendered feature map, [bs, 256, h, w].
        'mini_pred': mini_decoder output (two UpBlock2d stages, ending in 3 channels
            and a Sigmoid; resolution depends on UpBlock2d -- presumably upsampled,
            TODO confirm).
        'point_pred': the dict returned by sigma_predictor.
    """

    def __init__(self, in_channels, simpled_channel_rgb, simpled_channel_sigma, floor_num, hidden_size):
        super(RenderModel, self).__init__()
        # Submodules are created in a fixed order; this keeps parameter
        # initialisation and state_dict layout stable.
        self.rgb_predict = rgb_predictor(
            in_channels=in_channels,
            simpled_channel=simpled_channel_rgb,
            floor_num=floor_num,
        )
        self.sigma_predict = sigma_predictor(
            in_channels=in_channels,
            simpled_channel=simpled_channel_sigma,
            floor_num=floor_num,
        )
        # Per-floor encoding widths left after each predictor's (c f) channel split.
        self.nerf_module = MultiHeadNeRFModel(
            hidden_size=hidden_size,
            num_encoding_rgb=simpled_channel_rgb // floor_num,
            num_encoding_sigma=simpled_channel_sigma // floor_num,
        )
        self.mini_decoder = nn.Sequential(
            UpBlock2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            UpBlock2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, feature):
        rgb_features = self.rgb_predict(feature)
        sigma_features, point_dict = self.sigma_predict(feature)
        rgb_out, sigma_out = self.nerf_module(rgb_features, sigma_features)
        render_result = torch.sigmoid(volume_render(rgb_out, sigma_out))
        return {
            'render': render_result,
            'mini_pred': self.mini_decoder(render_result),
            'point_pred': point_dict,
        }