"""
    reference: https://github.com/xuebinqin/DIS
"""

import PIL.Image
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torch.autograd import Variable
from torchvision import transforms
from torchvision.transforms.functional import normalize

from .models import ISNetDIS

# Helpers
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


class GOSNormalize:
    """
    Callable transform that normalizes an image tensor channel-wise.

    Thin wrapper around ``torchvision.transforms.functional.normalize`` so the
    operation can be used as a step inside ``transforms.Compose``.

    Args:
        mean: Per-channel means forwarded to ``normalize``.
        std: Per-channel standard deviations forwarded to ``normalize``.
    """

    # Defaults are tuples rather than lists: a list default is a single object
    # shared by every instance, so mutating one instance's ``mean``/``std``
    # would silently change the defaults of all later instances.
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        """Return ``image`` normalized with the configured ``mean`` and ``std``."""
        return normalize(image, self.mean, self.std)


def im_preprocess(im, size):
    """
    Convert an H x W (x C) image array into a C x H x W tensor, optionally resized.

    Args:
        im: NumPy image array, either 2-D ``(H, W)`` or 3-D ``(H, W, C)``.
            2-D and single-channel inputs are replicated to three channels.
        size: Target ``(height, width)``. When it has fewer than two entries
            no resizing is performed.

    Returns:
        A tuple ``(tensor, original_hw)``. ``original_hw`` is ``im.shape[0:2]``.
        ``tensor`` is the channel-first image: ``float32`` at the original
        resolution when no resize is requested, otherwise bilinearly resized
        to ``size`` and cast to ``uint8``.
    """
    if len(im.shape) < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    original_hw = im.shape[0:2]

    # H x W x C -> C x H x W
    im_tensor = torch.tensor(im.copy(), dtype=torch.float32).permute(2, 0, 1)
    if len(size) < 2:
        return im_tensor, original_hw

    # F.upsample is deprecated; F.interpolate with align_corners=False is what
    # it dispatched to for bilinear mode, so the result is unchanged.
    resized = F.interpolate(
        im_tensor.unsqueeze(0), size, mode="bilinear", align_corners=False
    )
    return resized.squeeze(0).type(torch.uint8), original_hw


class IsNetPipeLine:
    """
    Inference pipeline that segments a PIL image with an ``ISNetDIS`` model.

    Calling an instance returns the input image with the predicted mask
    attached as its alpha channel, together with the mask itself.

    Args:
        model_path: Optional path to a checkpoint whose state dict is loaded
            into the model. When ``None`` the model keeps its initial weights.
        model_digit: ``"half"`` runs the network in half precision (batch-norm
            layers are kept in float32); any other value, including the
            default ``"full"``, runs it in float32.
    """

    def __init__(self, model_path=None, model_digit="full"):
        self.model_digit = model_digit
        self.model = ISNetDIS()
        # Spatial size every image is resized to before it is fed to the model.
        self.cache_size = [1024, 1024]
        self.transform = transforms.Compose([
            GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        ])

        # Build Model
        self.build_model(model_path)

    def load_image(self, image: PIL.Image.Image):
        """
        Turn a PIL image into a normalized model input batch.

        Returns:
            A tuple ``(batch, shape)`` where ``batch`` is the image resized to
            ``self.cache_size``, scaled by 1/255, passed through
            ``self.transform`` and given a leading batch dimension, and
            ``shape`` is a ``1 x 2`` tensor holding the original
            ``(height, width)``.
        """
        im = np.array(image.convert("RGB"))
        im, im_shp = im_preprocess(im, self.cache_size)
        im = torch.divide(im, 255.0)
        shape = torch.from_numpy(np.array(im_shp))
        # make a batch of image, shape
        return self.transform(im).unsqueeze(0), shape.unsqueeze(0)

    def build_model(self, model_path=None):
        """
        Load weights (if given), apply the precision setting and move the
        model to ``device`` in eval mode.
        """
        if model_path is not None:
            # NOTE: torch.load unpickles the file, so only checkpoints from a
            # trusted source should be passed here.
            self.model.load_state_dict(torch.load(model_path, map_location=device))

        # convert to half precision, keeping batch-norm layers in float32
        if self.model_digit == "half":
            self.model.half()
            for layer in self.model.modules():
                if isinstance(layer, nn.BatchNorm2d):
                    layer.float()
        self.model.to(device)
        self.model.eval()

    def __call__(self, image: PIL.Image.Image):
        """
        Segment ``image``.

        Returns:
            A list ``[rgba_image, mask_image]``: an RGB copy of the input
            with the predicted mask as its alpha channel, and the mask as an
            ``"L"`` mode image at the input's resolution.
        """
        image_tensor, orig_size = self.load_image(image)
        mask = self.predict(image_tensor, orig_size)

        pil_mask = Image.fromarray(mask).convert('L')

        # convert() already returns a new image, so the caller's image is
        # not modified by putalpha().
        im_rgba = image.convert("RGB")
        im_rgba.putalpha(pil_mask)

        return [im_rgba, pil_mask]

    def predict(self, inputs_val: torch.Tensor, shapes_val):
        """
        Given a preprocessed image batch, predict the mask.

        Args:
            inputs_val: Model input batch as produced by ``load_image``.
            shapes_val: Indexable whose first entry holds the original
                ``(height, width)`` the mask is resized back to.

        Returns:
            A ``uint8`` NumPy array with the min-max normalized prediction
            scaled to ``[0, 255]``. If the prediction is constant (nothing to
            normalize) an all-zero mask is returned.
        """
        # Match the input precision to what build_model() did to the model.
        dtype = torch.float16 if self.model_digit == "half" else torch.float32
        out_h, out_w = int(shapes_val[0][0]), int(shapes_val[0][1])

        # Inference only: do not record autograd state.
        with torch.no_grad():
            inputs_val = inputs_val.to(device=device, dtype=dtype)

            # First element of the model output is indexed again below; the
            # upstream DIS code describes it as a list of 6 results.
            ds_val = self.model(inputs_val)[0]

            # B x 1 x H x W (per upstream); take the first prediction, which
            # upstream describes as the most accurate, for the first sample
            # in the batch, keeping a leading batch dimension.
            pred_val = ds_val[0][0:1]

            # recover the prediction spatial size to the original image size
            pred_val = F.interpolate(
                pred_val, (out_h, out_w), mode='bilinear', align_corners=False)
            # Drop only the batch and (singleton) channel dimensions so that
            # images that are 1 pixel high or wide still give a 2-D mask.
            pred_val = pred_val.squeeze(0).squeeze(0)

            ma = torch.max(pred_val)
            mi = torch.min(pred_val)
            span = ma - mi
            if span > 0:
                pred_val = (pred_val - mi) / span  # max = 1
            else:
                # Constant prediction: avoid 0/0 producing a NaN mask.
                pred_val = torch.zeros_like(pred_val)

            mask = (pred_val.cpu().numpy() * 255).astype(np.uint8)

        if device == 'cuda':
            torch.cuda.empty_cache()
        return mask  # it is the mask we need


# a = IsNetPipeLine(model_path="save_models/isnet.pth")
# input_image = Image.open("image_0mx.png")
# rgb, mask = a(input_image)
#
# rgb.save("rgb.png")
# mask.save("mask.png")