import numpy as np
import cv2
import os
import tqdm
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from .util import rgb_to_lab, lab_to_rgb


def blend(f, b, a):
    """Alpha-composite ``f`` over ``b``.

    Args:
        f: Foreground term, weighted by ``a``.
        b: Background term, weighted by ``1 - a``.
        a: Blend weight.

    Returns:
        ``f * a + b * (1 - a)``.
    """
    foreground_part = f * a
    background_part = b * (1 - a)
    return foreground_part + background_part


class PatchedHarmonizer(nn.Module):
    """Shift foreground statistics towards the background, per patch and globally.

    The image is split into a ``grid_count`` x ``grid_count`` grid. Every grid
    cell of the foreground is re-centred and re-scaled twice: once with the
    mean/std of that cell and once with the mean/std of the whole image. The
    two results are mixed with the learnable ``grid_weights``.
    """

    def __init__(self, grid_count=1, init_weights=(0.9, 0.1)):
        """Create the module.

        Args:
            grid_count: Number of grid cells along each spatial axis.
            init_weights: Initial values of the two mixing weights; index 0
                scales the patch-local result and index 1 the global result
                (see ``normalize_channel``).
        """
        super(PatchedHarmonizer, self).__init__()
        self.eps = 1e-8
        self.grid_weights = torch.nn.Parameter(
            torch.tensor(init_weights, dtype=torch.float32), requires_grad=True)
        self.grid_count = grid_count

    def _axis_slices(self, length):
        """Return ``grid_count`` slices that partition an axis of ``length``.

        Every cell is ``length // grid_count`` wide except the last one, which
        is left open-ended so it also absorbs the remainder.
        """
        step = length // self.grid_count
        last = self.grid_count - 1
        return [
            slice(i * step, None if i == last else (i + 1) * step)
            for i in range(self.grid_count)
        ]

    def lab_shift(self, x, invert=False):
        """Rescale channel 0 by 2.55 and offset channels 1 and 2 by 128.

        Args:
            x: Tensor indexed as ``x[:, channel, :, :]``.
            invert: When true, apply the inverse mapping (divide channel 0 by
                2.55, subtract 128 from channels 1 and 2).

        Returns:
            A new float32 tensor; ``x`` itself is left untouched.
        """
        # Always work on a private copy: ``x.float()`` hands back the very
        # same tensor when the input is already float32, so the in-place
        # updates below would otherwise overwrite the caller's data.
        x = x.to(torch.float32, copy=True)
        if invert:
            x[:, 0, :, :] /= 2.55
            x[:, 1, :, :] -= 128
            x[:, 2, :, :] -= 128
        else:
            x[:, 0, :, :] *= 2.55
            x[:, 1, :, :] += 128
            x[:, 2, :, :] += 128

        return x

    def get_mean_std(self, img, mask, dim=(2, 3)):
        """Compute mask-weighted mean and standard deviation of ``img``.

        Args:
            img: Input tensor.
            mask: Weights multiplied element-wise with ``img``.
            dim: Axes that are reduced.

        Returns:
            ``(mean, std)``, each with two trailing singleton axes appended.
            ``eps`` is added to the mask sum and under the square root so
            that an empty mask does not divide by zero.
        """
        total = torch.sum(img * mask, dim=dim)
        count = torch.sum(mask, dim=dim)
        mean = (total / (count + self.eps))[:, :, None, None]
        # The mask is applied before squaring, exactly as in the mean above.
        var = torch.sum(((img - mean) * mask) ** 2, dim=dim) / (count + self.eps)
        var = var[:, :, None, None]

        return mean, torch.sqrt(var + self.eps)

    def compute_patch_statistics(self, lab):
        """Compute mean and std of every grid cell of ``lab``.

        Args:
            lab: Tensor whose axes 2 and 3 are split into the grid.

        Returns:
            ``(means, stds)``: two nested lists indexed ``[h][w]`` holding the
            outputs of ``compute_mean_var`` for the corresponding cell.
        """
        row_slices = self._axis_slices(lab.shape[2])
        col_slices = self._axis_slices(lab.shape[3])

        means, stds = [], []
        for rows in row_slices:
            cell_stats = [
                self.compute_mean_var(lab[:, :, rows, cols], dim=[2, 3])
                for cols in col_slices
            ]
            means.append([mean for mean, _ in cell_stats])
            stds.append([std for _, std in cell_stats])

        return means, stds

    def compute_mean_var(self, x, dim=(1, 2)):
        """Return mean and standard deviation of ``x`` over ``dim``.

        Despite the name, the second return value is the square root of
        ``Tensor.var``. Both results get two trailing singleton axes.

        NOTE(review): every call inside this class passes ``dim=[2, 3]``; the
        default reduces axes 1 and 2 instead - confirm that it is intended.
        """
        values = x.float()
        mean = values.mean(dim=dim)[:, :, None, None]
        std = torch.sqrt(values.var(dim=dim))[:, :, None, None]

        return mean, std

    def forward(self, fg_rgb, bg_rgb, alpha, masked_stats=False):
        """Harmonize the foreground and composite it over the background.

        Args:
            fg_rgb: Foreground, indexed as a 4-D tensor (b x C x H x W).
            bg_rgb: Background; resized to the spatial size of ``fg_rgb``.
            alpha: Blend weight handed to ``blend`` (foreground share).
            masked_stats: When true, global statistics come from
                ``get_mean_std`` with mask ``1 - alpha`` for the background
                and an all-ones mask for the foreground; otherwise from
                ``compute_mean_var``.

        Returns:
            ``(composite, fg_harm)``: the blended image and the harmonized
            foreground.
        """
        bg_rgb = F.interpolate(bg_rgb, size=fg_rgb.shape[2:])  # b x C x H x W

        # No colour-space conversion is performed here: statistics are taken
        # directly on the tensors that were passed in.
        bg_lab = bg_rgb
        fg_lab = fg_rgb

        if masked_stats:
            self.bg_global_mean, self.bg_global_var = self.get_mean_std(
                img=bg_lab, mask=(1 - alpha))
            self.fg_global_mean, self.fg_global_var = self.get_mean_std(
                img=fg_lab, mask=torch.ones_like(alpha))
        else:
            self.bg_global_mean, self.bg_global_var = self.compute_mean_var(
                bg_lab, dim=[2, 3])
            self.fg_global_mean, self.fg_global_var = self.compute_mean_var(
                fg_lab, dim=[2, 3])

        # The statistics are stored on ``self`` because ``normalize_channel``
        # reads them from there.
        self.bg_means, self.bg_vars = self.compute_patch_statistics(bg_lab)
        self.fg_means, self.fg_vars = self.compute_patch_statistics(fg_lab)

        fg_harm = self.harmonize(fg_lab)

        # ``bg_rgb`` already has the foreground's spatial size, so no second
        # resize is needed.
        # NOTE(review): the background is divided by 255 while ``fg_harm`` is
        # matched to the statistics of the undivided background - confirm the
        # two operands of the blend are meant to be on different scales.
        bg = bg_rgb / 255.

        composite = blend(fg_harm, bg, alpha)

        return composite, fg_harm

    def harmonize(self, fg):
        """Apply ``normalize_channel`` to every grid cell of ``fg``.

        Requires the statistics that ``forward`` stores on ``self``.

        Returns:
            A tensor shaped and typed like ``fg`` holding the harmonized cells.
        """
        harmonized = torch.zeros_like(fg)
        row_slices = self._axis_slices(fg.shape[2])
        col_slices = self._axis_slices(fg.shape[3])
        for h, rows in enumerate(row_slices):
            for w, cols in enumerate(col_slices):
                harmonized[:, :, rows, cols] = self.normalize_channel(
                    fg[:, :, rows, cols], h, w)

        return harmonized

    def normalize_channel(self, value, h, w):
        """Mix patch-local and global statistic matching for one grid cell.

        Args:
            value: Foreground values of the cell.
            h: Row index of the cell in the grid.
            w: Column index of the cell in the grid.

        Returns:
            ``grid_weights[0] * local + grid_weights[1] * global``, where each
            term re-centres ``value`` on the background mean and rescales it
            by the background/foreground std ratio.
        """
        fg_local_mean, fg_local_var = self.fg_means[h][w], self.fg_vars[h][w]
        bg_local_mean, bg_local_var = self.bg_means[h][w], self.bg_vars[h][w]

        # global2global normalization
        global_ratio = self.bg_global_var / (self.fg_global_var + self.eps)
        normalized_global = (
            (value - self.fg_global_mean) * global_ratio + self.bg_global_mean)

        # local2local normalization
        local_ratio = bg_local_var / (fg_local_var + self.eps)
        normalized_local = (value - fg_local_mean) * local_ratio + bg_local_mean

        return (self.grid_weights[0] * normalized_local
                + self.grid_weights[1] * normalized_global)

    def normalize_fg(self, value):
        """Legacy normalization that is not called anywhere in this class.

        NOTE(review): this method reads ``self.fg_local_mean`` and
        ``self.bg_local_mean``, which this class never assigns, and indexes
        ``grid_weights`` with two slices although the default
        ``init_weights`` produce a 1-D parameter. It looks unusable as
        written - confirm whether it can be removed.
        """
        zeroed_mean = value - \
            (self.fg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        scaled_var = zeroed_mean * \
            (self.bg_global_var/(self.fg_global_var + self.eps))
        normalized_lg = scaled_var + \
            (self.bg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()

        return normalized_lg


class PatchNormalizer(nn.Module):
    """Learnable normalization of a foreground against masked background stats.

    The foreground is standardized with its own masked mean/std and then
    re-expressed with the background std and two alternative means: the
    global background mean and a ``grid_weights``-weighted sum of per-patch
    background means. Both variants get a learnable per-channel scale and
    bias and are mixed with ``weights``.
    """

    def __init__(self, in_channels=3, eps=1e-7, grid_count=1, weights=(0.5, 0.5), init_value=1e-2):
        """Create the module and its parameters.

        Args:
            in_channels: Number of channels of the per-channel parameters.
            eps: Small constant used to stabilise divisions and square roots.
            grid_count: Number of grid cells along each spatial axis.
            weights: Initial mixing weights; index 0 scales the globally
                normalized foreground, index 1 the patch-normalized one.
            init_value: Initial value of every per-channel scale parameter.
        """
        super(PatchNormalizer, self).__init__()
        self.grid_count = grid_count
        self.eps = eps

        def channel_param(factory):
            # Per-channel parameter of shape (1, in_channels, 1, 1).
            return nn.Parameter(
                init_value * factory(in_channels)[None, :, None, None],
                requires_grad=True)

        # Registration order is kept stable on purpose: it defines the order
        # in which parameters are enumerated.
        self.weights = nn.Parameter(
            torch.tensor(weights, dtype=torch.float32), requires_grad=True)
        self.fg_var = channel_param(torch.ones)
        self.fg_bias = channel_param(torch.zeros)
        self.patched_fg_var = channel_param(torch.ones)
        self.patched_fg_bias = channel_param(torch.zeros)
        self.bg_var = channel_param(torch.ones)
        self.bg_bias = channel_param(torch.zeros)
        # Uniform weights over channels and grid cells, shape
        # (1, in_channels, grid_count, grid_count), summing to one.
        self.grid_weights = torch.nn.Parameter(
            torch.ones((in_channels, grid_count, grid_count))[None, :, :, :]
            / (grid_count * grid_count * in_channels),
            requires_grad=True)

    def _axis_slices(self, length):
        """Return ``grid_count`` slices that partition an axis of ``length``.

        Every cell is ``length // grid_count`` wide except the last one, which
        is left open-ended so it also absorbs the remainder.
        """
        step = length // self.grid_count
        last = self.grid_count - 1
        return [
            slice(i * step, None if i == last else (i + 1) * step)
            for i in range(self.grid_count)
        ]

    def local_normalization(self, value):
        """Legacy normalization that is not called anywhere in this class.

        NOTE(review): this method reads ``self.fg_local_mean``,
        ``self.bg_local_mean``, ``self.fg_global_var`` and
        ``self.bg_global_var``, none of which this class assigns, so it
        appears unusable unless they are set externally - confirm whether it
        can be removed.
        """
        zeroed_mean = value - \
            (self.fg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        scaled_var = zeroed_mean * \
            (self.bg_global_var/(self.fg_global_var + self.eps))
        normalized_lg = scaled_var + \
            (self.bg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()

        return normalized_lg

    def get_mean_std(self, img, mask, dim=(2, 3)):
        """Compute mask-weighted mean and standard deviation of ``img``.

        Args:
            img: Input tensor.
            mask: Weights multiplied element-wise with ``img``.
            dim: Axes that are reduced.

        Returns:
            ``(mean, std)``, each with two trailing singleton axes appended.
            ``eps`` is added to the mask sum and under the square root so
            that an empty mask does not divide by zero.
        """
        total = torch.sum(img * mask, dim=dim)
        count = torch.sum(mask, dim=dim)
        mean = (total / (count + self.eps))[:, :, None, None]
        # The mask is applied before squaring, exactly as in the mean above.
        var = torch.sum(((img - mean) * mask) ** 2, dim=dim) / (count + self.eps)
        var = var[:, :, None, None]

        return mean, torch.sqrt(var + self.eps)

    def compute_patch_statistics(self, img, mask):
        """Compute masked mean and std for every grid cell.

        Args:
            img: Tensor whose axes 2 and 3 are split into the grid.
            mask: Mask sliced with the same grid and passed to
                ``get_mean_std``.

        Returns:
            ``(means, stds)``: two stacked tensors whose first two axes index
            the grid row and column and whose remaining axes are the first
            two axes of the per-cell statistics.
        """
        row_slices = self._axis_slices(img.shape[2])
        col_slices = self._axis_slices(img.shape[3])

        means, stds = [], []
        for rows in row_slices:
            cell_stats = [
                self.get_mean_std(
                    img[:, :, rows, cols], mask[:, :, rows, cols], dim=[2, 3])
                for cols in col_slices
            ]
            # Drop the trailing singleton axes before stacking.
            means.append(torch.stack(
                [m.reshape(m.shape[:2]) for m, _ in cell_stats]))
            stds.append(torch.stack(
                [s.reshape(s.shape[:2]) for _, s in cell_stats]))

        return torch.stack(means), torch.stack(stds)

    def compute_mean_var(self, x, dim=(2, 3)):
        """Return mean and standard deviation of ``x`` over ``dim``.

        Despite the name, the second return value is the square root of
        ``Tensor.var``.
        """
        values = x.float()

        return values.mean(dim=dim), torch.sqrt(values.var(dim=dim))

    def forward(self, fg, bg, mask):
        """Normalize ``fg`` towards ``bg`` and composite the two.

        Args:
            fg: Foreground tensor.
            bg: Background tensor.
            mask: Foreground weight; ``1 - mask`` selects the background.

        Returns:
            ``blend(fg_result, bg_normalized, mask)``.
        """
        background_mask = 1 - mask

        # Stored on ``self``; only ``local_means`` is used below.
        self.local_means, self.local_vars = self.compute_patch_statistics(
            bg, background_mask)

        # Background: standardize, then apply the learnable affine.
        bg_mean, bg_var = self.get_mean_std(bg, background_mask)
        bg_normalized = (bg - bg_mean) / bg_var * self.bg_var + self.bg_bias

        # Foreground: standardize with its own masked statistics.
        fg_mean, fg_var = self.get_mean_std(fg, mask)
        unscaled = (fg - fg_mean) / fg_var

        # Weighted sum of the per-patch background means. The permute moves
        # the two grid axes last so they line up with ``grid_weights``.
        mean_patched_back = (
            self.local_means.permute(2, 3, 0, 1) * self.grid_weights
        ).sum(dim=[2, 3])[:, :, None, None]

        rescaled = unscaled * bg_var
        normalized = rescaled + bg_mean
        patch_normalized = rescaled + mean_patched_back

        fg_normalized = normalized * self.fg_var + self.fg_bias
        fg_patch_normalized = (
            patch_normalized * self.patched_fg_var + self.patched_fg_bias)

        fg_result = (self.weights[0] * fg_normalized
                     + self.weights[1] * fg_patch_normalized)
        composite = blend(fg_result, bg_normalized, mask)

        return composite