import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import chain

from transformers import SegformerForSemanticSegmentation, Mask2FormerForUniversalSegmentation

device = 'cpu'

class EncoderDecoder(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        prefix=None,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Optional preprocessing layer applied before the encoder, e.g.
        # nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=True). It is passed in
        # by the caller rather than built as a default argument: a module in
        # the signature would be created once and shared across all instances.
        self.prefix = prefix

    def forward(self, x):
        if self.prefix is not None:
            x = self.prefix(x)
        x = self.encoder(x)["hidden_states"]  # transformers-style backbone output
        return self.decoder(x)
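
# A minimal wiring sketch (not part of the original model code): the stub
# below mimics a HuggingFace-style backbone that returns a dict containing
# "hidden_states", which is an assumption about what `encoder` must provide.
class _StubHiddenStateEncoder(nn.Module):
    def forward(self, x):
        return {"hidden_states": [x]}


def _encoder_decoder_demo():
    model = EncoderDecoder(
        encoder=_StubHiddenStateEncoder(),
        decoder=nn.Identity(),  # hypothetical decoder consuming the hidden states
        prefix=nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=True),
    )
    hidden = model(torch.randn(1, 3, 8, 8))
    print(len(hidden), hidden[0].shape)  # 1 torch.Size([1, 3, 8, 8])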
    

def conv2d_relu(input_filters, output_filters, kernel_size=3, bias=True):
    return nn.Sequential(
        nn.Conv2d(input_filters, output_filters, kernel_size=kernel_size, padding=kernel_size//2, bias=bias),
        nn.LeakyReLU(0.2, inplace=True),
        nn.BatchNorm2d(output_filters)
    )

def up_and_add(x, y):
    return F.interpolate(x, size=(y.size(2), y.size(3)), mode='bilinear', align_corners=True) + y

class FPN_fuse(nn.Module):
    def __init__(self, feature_channels=(256, 512, 1024, 2048), fpn_out=256):
        super(FPN_fuse, self).__init__()
        assert feature_channels[0] == fpn_out
        self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1)
                                      for ft_size in feature_channels[1:]])
        # Build independent smooth convs: multiplying a one-element list would
        # repeat the *same* module object, so all levels would share weights.
        self.smooth_conv = nn.ModuleList([nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)
                                          for _ in range(len(feature_channels) - 1)])
        self.conv_fusion = nn.Sequential(
            nn.Conv2d(2*fpn_out, fpn_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(fpn_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, features):
        # Expected ordering (as used by StyleMatte): lowest-resolution level
        # first, with features[-1] the highest-resolution map; in the
        # StyleMatte configuration every level already has fpn_out channels.
        features[:-1] = [conv1x1(feature) for feature, conv1x1 in zip(features[:-1], self.conv1x1)]
        feature = up_and_add(self.smooth_conv[0](features[0]), features[1])
        for i in range(2, len(features)):
            feature = up_and_add(self.smooth_conv[i - 1](feature), features[i])

        H, W = features[-1].size(2), features[-1].size(3)
        x = [feature, features[-1]]
        x = [F.interpolate(x_el, size=(H, W), mode='bilinear', align_corners=True) for x_el in x]

        x = self.conv_fusion(torch.cat(x, dim=1))
        return x
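
# Shape sanity check for FPN_fuse under the StyleMatte configuration (four
# 256-channel levels, lowest resolution first); the spatial sizes are
# illustrative assumptions, not values from the original repo.
def _fpn_fuse_demo():
    fpn = FPN_fuse(feature_channels=[256, 256, 256, 256], fpn_out=256)
    feats = [torch.randn(1, 256, s, s) for s in (8, 16, 32, 64)]
    print(fpn(feats).shape)  # torch.Size([1, 256, 64, 64])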

class PSPModule(nn.Module):
    # The original implementation uses precise RoI pooling
    # instead of adaptive average pooling.
    def __init__(self, in_channels, bin_sizes=(1, 2, 4, 6)):
        super(PSPModule, self).__init__()
        out_channels = in_channels // len(bin_sizes)
        self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s) 
                                                        for b_s in bin_sizes])
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), in_channels, 
                                    kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1)
        )

    def _make_stages(self, in_channels, out_channels, bin_sz):
        prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        bn = nn.BatchNorm2d(out_channels)
        relu = nn.ReLU(inplace=True)
        return nn.Sequential(prior, conv, bn, relu)
    
    def forward(self, features):
        h, w = features.size()[2], features.size()[3]
        pyramids = [features]
        pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear', 
                                        align_corners=True) for stage in self.stages])
        output = self.bottleneck(torch.cat(pyramids, dim=1))
        return output
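
# Quick shape check: PSPModule preserves both the channel count and the
# spatial size of its input (the sizes below are illustrative).
def _psp_module_demo():
    psp = PSPModule(in_channels=768)
    print(psp(torch.randn(1, 768, 16, 16)).shape)  # torch.Size([1, 768, 16, 16])
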
class UperNet_swin(nn.Module):
    # Implementing only the object path
    def __init__(self, backbone, pretrained=True):
        super(UperNet_swin, self).__init__()

        self.backbone = backbone
        feature_channels = [192, 384, 768, 768]
        self.PPN = PSPModule(feature_channels[-1])
        self.FPN = FPN_fuse(feature_channels, fpn_out=feature_channels[0])
        self.head = nn.Conv2d(feature_channels[0], 1, kernel_size=3, padding=1)

    def forward(self, x):
        input_size = (x.size()[2], x.size()[3])
        # Assumes the backbone returns spatial (B, C, H, W) hidden states.
        features = self.backbone(x)["hidden_states"]
        features[-1] = self.PPN(features[-1])
        x = self.head(self.FPN(features))

        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        return x

    def get_backbone_params(self):
        return self.backbone.parameters()

    def get_decoder_params(self):
        return chain(self.PPN.parameters(), self.FPN.parameters(), self.head.parameters())
    
class UnetDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels=(3, 192, 384, 768, 768),
        decoder_channels=(512, 256, 128, 64),
        n_blocks=4,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        in_channels = [head_channels] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:]) + [0]

        out_channels = decoder_channels

        if center:
            self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
        blocks = [
            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
        ]
        self.blocks = nn.ModuleList(blocks)
        upscale_factor = 4
        self.matting_head = nn.Sequential(
            nn.Conv2d(decoder_channels[-1], 1, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=upscale_factor),
        )

    def preprocess_features(self, x):
        # Reshape (B, N, C) token sequences into (B, C, H, W) feature maps,
        # assuming each token grid is square (N = H * W with H == W).
        features = []
        for out_tensor in x:
            bs, n, f = out_tensor.size()
            h = int(n ** 0.5)
            feature = out_tensor.view(-1, h, h, f).permute(0, 3, 1, 2).contiguous()
            features.append(feature)
        return features

    def forward(self, features):
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder

        features = self.preprocess_features(features)

        head = features[0]
        skips = features[1:]

        x = self.center(head)
        for i, decoder_block in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = decoder_block(x, skip)
        x = self.matting_head(x)
        x = 1 - F.relu(1 - x)  # clamp the matte from above at 1, keeping it in [0, 1]
        return x
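
# Hypothetical shape walk-through (the sizes are assumptions, not from the
# original repo): five encoder levels as (B, N, C) token grids, shallowest
# first; the 3-channel stem entry is dropped inside forward().
def _unet_decoder_demo():
    dec = UnetDecoder()
    sizes = (128, 64, 32, 16, 16)
    chans = (3, 192, 384, 768, 768)
    feats = [torch.randn(1, s * s, c) for s, c in zip(sizes, chans)]
    print(dec(feats).shape)  # torch.Size([1, 1, 512, 512])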


class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        self.conv1 = conv2d_relu(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3
        )
        self.conv2 = conv2d_relu(
            out_channels,
            out_channels,
            kernel_size=3,
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_channels = skip_channels

    def forward(self, x, skip=None):
        # Upsample unless the skip connection already matches x spatially
        # (with token-grid encoders the two sizes can coincide).
        if skip is None or x.shape[-1] != skip.shape[-1]:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class CenterBlock(nn.Sequential):
    # use_batchnorm is accepted for compatibility with the UnetDecoder call
    # site; conv2d_relu applies BatchNorm unconditionally.
    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        conv1 = conv2d_relu(
            in_channels,
            out_channels,
            kernel_size=3,
        )
        conv2 = conv2d_relu(
            out_channels,
            out_channels,
            kernel_size=3,
        )
        super().__init__(conv1, conv2)



class SegForm(nn.Module):
    def __init__(self):
        super(SegForm, self).__init__()
#         configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
#         configuration.num_labels = 1 ## set output as 1 
#         self.model = SegformerForSemanticSegmentation(config=configuration)

        self.model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0", num_labels=1, ignore_mismatched_sizes=True
        )

    def forward(self, image):
        img_segs = self.model(image)
        # SegFormer predicts logits at 1/4 of the input resolution;
        # upsample back to full size.
        upsampled_logits = F.interpolate(img_segs.logits, scale_factor=4, mode='nearest')
        return upsampled_logits
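
# Hedged usage sketch (downloads the "nvidia/mit-b0" weights on first run):
def _segform_demo():
    model = SegForm().eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 512, 512))
    print(out.shape)  # torch.Size([1, 1, 512, 512])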
 

class StyleMatte(nn.Module):
    def __init__(self):
        super(StyleMatte, self).__init__()
        self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256], fpn_out=256)
        self.pixel_decoder = Mask2FormerForUniversalSegmentation.from_pretrained(
            "facebook/mask2former-swin-tiny-coco-instance"
        ).base_model.pixel_level_module
        self.fgf = FastGuidedFilter()
        self.conv = nn.Conv2d(256, 1, kernel_size=3, padding=1)
        # ImageNet normalization buffers, kept for reference but currently disabled:
        # self.mean = torch.Tensor([0.43216, 0.394666, 0.37645]).float().view(-1, 1, 1)
        # self.register_buffer('image_net_mean', self.mean)
        # self.std = torch.Tensor([0.22803, 0.22145, 0.216989]).float().view(-1, 1, 1)
        # self.register_buffer('image_net_std', self.std)
    def forward(self, image, normalize=False):
        # Normalization is currently disabled; see the commented buffers above.
        # if normalize:
        #     image.sub_(self.get_buffer("image_net_mean")).div_(self.get_buffer("image_net_std"))

        decoder_out = self.pixel_decoder(image)
        decoder_states = list(decoder_out.decoder_hidden_states)
        decoder_states.append(decoder_out.decoder_last_hidden_state)
        out_pure = self.fpn(decoder_states)

        # Low-resolution grayscale copy of the input, used as the guide
        # for the fast guided filter.
        image_lr = F.interpolate(image.mean(1, keepdim=True),
                                 scale_factor=0.25,
                                 mode='bicubic',
                                 align_corners=True)
        out = self.conv(out_pure)
        # Transfer the coarse matte to full resolution, guided by the image.
        out = self.fgf(image_lr, out, image.mean(1, keepdim=True))
        
        return torch.sigmoid(out)

    def get_training_params(self):
        # Train only the FPN and the final 1-channel conv; the pretrained
        # pixel decoder stays frozen (the guided filter has no parameters).
        return list(self.fpn.parameters()) + list(self.conv.parameters())
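
# Hypothetical training setup (an assumption, not from the original file):
# optimize only the parameters exposed by get_training_params above.
def _stylematte_optimizer_sketch(model):
    return torch.optim.Adam(model.get_training_params(), lr=1e-4)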
    
class GuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(GuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)


    def forward(self, x, y):
        n_x, c_x, h_x, w_x = x.size()
        n_y, c_y, h_y, w_y = y.size()

        assert n_x == n_y
        assert c_x == 1 or c_x == c_y
        assert h_x == h_y and w_x == w_y
        assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1

        # N: per-pixel window counts, from box-filtering an all-ones image
        N = self.boxfilter(torch.ones((1, 1, h_x, w_x), dtype=x.dtype, device=x.device))

        # mean_x
        mean_x = self.boxfilter(x) / N
        # mean_y
        mean_y = self.boxfilter(y) / N
        # cov_xy
        cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
        # var_x
        var_x = self.boxfilter(x * x) / N - mean_x * mean_x

        # A
        A = cov_xy / (var_x + self.eps)
        # b
        b = mean_y - A * mean_x

        # mean_A; mean_b
        mean_A = self.boxfilter(A) / N
        mean_b = self.boxfilter(b) / N

        return mean_A * x + mean_b
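
# Shape check for the full-resolution guided filter; the guide x and target y
# share the same (illustrative) size, which must exceed the 2r+1 window.
def _guided_filter_demo():
    gf = GuidedFilter(r=2)
    x = torch.rand(1, 1, 32, 32)  # guide image
    y = torch.rand(1, 1, 32, 32)  # signal to be filtered
    print(gf(x, y).shape)  # torch.Size([1, 1, 32, 32])
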
class FastGuidedFilter(nn.Module):
    def __init__(self, r=1, eps=1e-8):
        super(FastGuidedFilter, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)


    def forward(self, lr_x, lr_y, hr_x):
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1

        ## N
        N = self.boxfilter(torch.ones((1, 1, h_lrx, w_lrx), dtype=lr_x.dtype, device=lr_x.device))

        ## mean_x
        mean_x = self.boxfilter(lr_x) / N
        ## mean_y
        mean_y = self.boxfilter(lr_y) / N
        ## cov_xy
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        ## var_x
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x

        ## A
        A = cov_xy / (var_x + self.eps)
        ## b
        b = mean_y - A * mean_x

        ## upsample A and b to the high-resolution guide
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)

        return mean_A*hr_x+mean_b
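
# Usage sketch mirroring StyleMatte.forward: filter a low-resolution matte
# with a low-resolution guide, then transfer the result to the
# high-resolution guide (all sizes are illustrative).
def _fast_guided_filter_demo():
    fgf = FastGuidedFilter(r=1)
    lr_x = torch.rand(1, 1, 64, 64)  # low-res guide
    lr_y = torch.rand(1, 1, 64, 64)  # low-res matte
    hr_x = F.interpolate(lr_x, scale_factor=4, mode='bilinear', align_corners=True)
    print(fgf(lr_x, lr_y, hr_x).shape)  # torch.Size([1, 1, 256, 256])
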
class DeepGuidedFilterRefiner(nn.Module):
    def __init__(self, hid_channels=16):
        super().__init__()
        self.box_filter = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False, groups=4)
        self.box_filter.weight.data[...] = 1 / 9
        self.conv = nn.Sequential(
            nn.Conv2d(4 * 2 + hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, hid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(hid_channels),
            nn.ReLU(True),
            nn.Conv2d(hid_channels, 4, kernel_size=1, bias=True)
        )
        
    def forward(self, fine_src, base_src, base_fgr, base_pha, base_hid):
        fine_x = torch.cat([fine_src, fine_src.mean(1, keepdim=True)], dim=1)
        base_x = torch.cat([base_src, base_src.mean(1, keepdim=True)], dim=1)
        base_y = torch.cat([base_fgr, base_pha], dim=1)
        
        mean_x = self.box_filter(base_x)
        mean_y = self.box_filter(base_y)
        cov_xy = self.box_filter(base_x * base_y) - mean_x * mean_y
        var_x  = self.box_filter(base_x * base_x) - mean_x * mean_x
        
        A = self.conv(torch.cat([cov_xy, var_x, base_hid], dim=1))
        b = mean_y - A * mean_x
        
        H, W = fine_src.shape[2:]
        A = F.interpolate(A, (H, W), mode='bilinear', align_corners=False)
        b = F.interpolate(b, (H, W), mode='bilinear', align_corners=False)
        
        out = A * fine_x + b
        fgr, pha = out.split([3, 1], dim=1)
        return fgr, pha
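
# Shape walk-through with assumed tensor roles: a full-resolution source, its
# 4x-downsampled copy, base foreground/alpha predictions, and a 16-channel
# hidden state from a hypothetical base network.
def _refiner_demo():
    ref = DeepGuidedFilterRefiner(hid_channels=16)
    fine_src = torch.rand(1, 3, 256, 256)
    base_src = F.interpolate(fine_src, scale_factor=0.25, mode='bilinear', align_corners=False)
    fgr, pha = ref(fine_src, base_src,
                   torch.rand(1, 3, 64, 64),   # base foreground
                   torch.rand(1, 1, 64, 64),   # base alpha
                   torch.rand(1, 16, 64, 64))  # base hidden features
    print(fgr.shape, pha.shape)  # (1, 3, 256, 256) (1, 1, 256, 256)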
    
def diff_x(input, r):
    assert input.dim() == 4

    left   = input[:, :,         r:2 * r + 1]
    middle = input[:, :, 2 * r + 1:         ] - input[:, :,           :-2 * r - 1]
    right  = input[:, :,        -1:         ] - input[:, :, -2 * r - 1:    -r - 1]

    output = torch.cat([left, middle, right], dim=2)

    return output

def diff_y(input, r):
    assert input.dim() == 4

    left   = input[:, :, :,         r:2 * r + 1]
    middle = input[:, :, :, 2 * r + 1:         ] - input[:, :, :,           :-2 * r - 1]
    right  = input[:, :, :,        -1:         ] - input[:, :, :, -2 * r - 1:    -r - 1]

    output = torch.cat([left, middle, right], dim=3)

    return output

class BoxFilter(nn.Module):
    def __init__(self, r):
        super(BoxFilter, self).__init__()

        self.r = r

    def forward(self, x):
        assert x.dim() == 4
        # Windowed sums over a (2r+1)x(2r+1) box via integral images:
        # cumulative sums along each axis followed by shifted differences.
        return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
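
# Sanity check (illustrative): BoxFilter returns windowed *sums*, so dividing
# by the same filter applied to an all-ones tensor yields the local mean,
# matching how the guided filters above normalize by N.
def _box_filter_demo():
    bf = BoxFilter(r=2)
    x = torch.full((1, 1, 10, 10), 3.0)
    n = bf(torch.ones(1, 1, 10, 10))  # per-pixel window counts
    print((bf(x) / n).mean().item())  # 3.0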
    
if __name__ == '__main__':
    model = StyleMatte().to(device)
    out = model(torch.randn(1, 3, 640, 480).to(device))
    print(out.shape)