import os
import sys

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import einops
from PIL import Image

# GMFlow is vendored under <parent of this file's directory>/deps/gmflow.
# Put that directory first on sys.path, presumably so the `GMFlow` package
# imported below is resolved from the vendored copy.
_this_file = os.path.abspath(__file__)
parent_dir = os.path.dirname(os.path.dirname(_this_file))
gmflow_dir = os.path.join(parent_dir, 'deps/gmflow')
sys.path.insert(0, gmflow_dir)

from GMFlow.gmflow.gmflow import GMFlow  # noqa: E702 E402 F401
from GMFlow.utils.utils import InputPadder  # noqa: E702 E402


def coords_grid(b, h, w, homogeneous=False, device=None):
    """Return a batch of pixel-coordinate grids.

    Channel 0 holds the column index (x) and channel 1 the row index (y) of
    every pixel; with ``homogeneous=True`` a third channel of ones is added.

    Args:
        b: batch size (the grid is repeated ``b`` times).
        h, w: grid height and width.
        homogeneous: append a channel of ones.
        device: optional device to move the result to.

    Returns:
        float32 tensor of shape [b, 2, h, w] (or [b, 3, h, w]).
    """
    # Broadcast 1-D ranges to [h, w] instead of using torch.meshgrid.
    cols = torch.arange(w).unsqueeze(0).expand(h, w)  # x: value = column
    rows = torch.arange(h).unsqueeze(1).expand(h, w)  # y: value = row

    channels = [cols, rows]
    if homogeneous:
        channels.append(torch.ones(h, w, dtype=cols.dtype))

    single = torch.stack(channels, dim=0).float()  # [2|3, h, w]
    batched = single.unsqueeze(0).repeat(b, 1, 1, 1)  # [b, 2|3, h, w]

    return batched if device is None else batched.to(device)


def bilinear_sample(img,
                    sample_coords,
                    mode='bilinear',
                    padding_mode='zeros',
                    return_mask=False):
    """Sample ``img`` at absolute pixel coordinates via ``F.grid_sample``.

    Args:
        img: [B, C, H_img, W_img] tensor to sample from.
        sample_coords: [B, 2, H, W] (or [B, H, W, 2]) pixel coordinates;
            channel 0 is x (column) and channel 1 is y (row).
        mode, padding_mode: forwarded to ``F.grid_sample``.
        return_mask: also return a boolean [B, H, W] mask that is True where
            the coordinate lies inside the grid after normalization.

    Returns:
        The sampled [B, C, H, W] tensor, plus the mask if ``return_mask``.
    """
    # A [B, H, W, 2] layout is detected by dim 1 not being 2. Note that a
    # [B, H, W, 2] input with H == 2 is ambiguous and is NOT permuted.
    if sample_coords.size(1) != 2:
        sample_coords = sample_coords.permute(0, 3, 1, 2)

    _, _, h, w = sample_coords.shape

    # Normalize to [-1, 1] for align_corners=True (-1/+1 = centers of the
    # first/last pixel). This uses the coordinate grid's own H and W, i.e. it
    # assumes the grid matches img's spatial size (as in flow_warp).
    # A size-1 axis would divide by zero and yield NaN; clamping the
    # denominator maps coordinate 0 to -1, the only pixel on that axis.
    x_grid = 2 * sample_coords[:, 0] / max(w - 1, 1) - 1
    y_grid = 2 * sample_coords[:, 1] / max(h - 1, 1) - 1

    grid = torch.stack([x_grid, y_grid], dim=-1)  # [B, H, W, 2]

    img = F.grid_sample(img,
                        grid,
                        mode=mode,
                        padding_mode=padding_mode,
                        align_corners=True)

    if return_mask:
        mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (
            y_grid <= 1)  # [B, H, W]
        return img, mask

    return img


def flow_warp(feature,
              flow,
              mask=False,
              mode='bilinear',
              padding_mode='zeros'):
    """Backward-warp ``feature`` with a dense displacement field.

    Output pixel (x, y) is sampled from ``feature`` at
    (x + flow[:, 0, y, x], y + flow[:, 1, y, x]).

    Args:
        feature: [B, C, H, W] tensor to warp.
        flow: [B, 2, H, W] pixel displacements (channel 0 = x, 1 = y).
        mask: if True, also return the in-bounds mask from bilinear_sample.
        mode, padding_mode: forwarded to ``F.grid_sample``.
    """
    batch, _, height, width = feature.size()
    assert flow.size(1) == 2

    base = coords_grid(batch, height, width, device=flow.device)
    target = base + flow  # [B, 2, H, W] absolute sample positions
    return bilinear_sample(feature,
                           target,
                           mode=mode,
                           padding_mode=padding_mode,
                           return_mask=mask)


def forward_backward_consistency_check(fwd_flow,
                                       bwd_flow,
                                       alpha=0.01,
                                       beta=0.5,
                                       return_confidence=False):
    """Flag pixels where forward and backward flow fail to cancel out.

    Each flow is added to the opposite flow warped into its frame; for a
    consistent pair the sum is ~0. Following UnFlow
    (https://arxiv.org/abs/1711.07837), a pixel is flagged when that
    residual exceeds ``alpha * (|fwd| + |bwd|) + beta``.

    Args:
        fwd_flow, bwd_flow: [B, 2, H, W] flow tensors.
        alpha, beta: threshold parameters (unused when return_confidence).
        return_confidence: return soft maps ``exp(-residual)`` instead of
            binary masks.

    Returns:
        ``(fwd, bwd)``: two [B, H, W] float tensors -- binary masks
        (1 = inconsistent/occluded) by default, or confidence maps.
    """
    assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4
    assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2

    # Round-trip residual magnitudes, [B, H, W].
    fwd_residual = torch.norm(fwd_flow + flow_warp(bwd_flow, fwd_flow), dim=1)
    bwd_residual = torch.norm(bwd_flow + flow_warp(fwd_flow, bwd_flow), dim=1)

    if return_confidence:
        return torch.exp(-fwd_residual), torch.exp(-bwd_residual)

    magnitude = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1)
    threshold = alpha * magnitude + beta
    fwd_occ = (fwd_residual > threshold).float()
    bwd_occ = (bwd_residual > threshold).float()
    return fwd_occ, bwd_occ


@torch.no_grad()
def get_warped_and_mask(flow_model,
                        image1,
                        image2,
                        image3=None,
                        pixel_consistency=False,
                        return_confidence=False):
    """Estimate bidirectional flow for a frame pair and backward-warp a frame.

    Args:
        flow_model: GMFlow-style model. It is called on the padded pair with
            ``pred_bidir_flow=True``; ``results_dict['flow_preds'][-1]`` is
            indexed [0] for the forward and [1] for the backward flow.
        image1, image2: [C, H, W] tensors (a batch dim is added here). They
            are NOT moved to any device; the caller must place them where
            ``flow_model`` expects. Presumably pixel values in [0, 255]
            given the ``255 * 0.25`` photometric threshold -- TODO confirm.
        image3: optional [1, C, H, W] tensor to warp; defaults to image1.
        pixel_consistency: additionally mark as occluded every pixel whose
            channel-mean absolute difference between image2 and the warped
            image1 exceeds ``255 * 0.25``.
        return_confidence: also return the soft backward confidence map
            from ``forward_backward_consistency_check``.

    Returns:
        ``(warped, bwd_occ, bwd_flow)`` or, with return_confidence,
        ``(warped, bwd_occ, bwd_flow, bwd_conf)``. ``warped`` is image3
        warped by ``bwd_flow``; ``bwd_occ`` is a [1, H, W] float mask
        (1 = occluded); ``bwd_flow`` is [1, 2, H, W]; ``bwd_conf`` is
        [1, H, W].
    """
    frame1, frame2 = image1[None], image2[None]  # [1, C, H, W]
    if image3 is None:
        image3 = frame1

    # Pad to a multiple of 16 for the model, then crop predictions back.
    padder = InputPadder(image1.shape, padding_factor=16)
    padded1, padded2 = padder.pad(frame1, frame2)
    results_dict = flow_model(padded1,
                              padded2,
                              attn_splits_list=[2],
                              corr_radius_list=[-1],
                              prop_radius_list=[-1],
                              pred_bidir_flow=True)
    flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H, W]; 0=fwd, 1=bwd
    fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)  # [1, 2, H, W]
    bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]

    _, bwd_occ = forward_backward_consistency_check(
        fwd_flow, bwd_flow)  # [1, H, W] float

    if pixel_consistency:
        # Photometric check at the original H x W so the combined mask keeps
        # the [1, H, W] shape of bwd_occ and warped_results. Doing it in
        # padded space returned a padded-size mask whenever H or W was not a
        # multiple of 16.
        warped_image1 = flow_warp(frame1, bwd_flow)
        photo_occ = ((frame2 - warped_image1).abs().mean(dim=1) >
                     255 * 0.25).float()
        bwd_occ = torch.clamp(bwd_occ + photo_occ, 0, 1)

    warped_results = flow_warp(image3, bwd_flow)
    if return_confidence:
        _, bwd_conf = forward_backward_consistency_check(
            fwd_flow, bwd_flow, return_confidence=True)  # [1, H, W] float
        return warped_results, bwd_occ, bwd_flow, bwd_conf

    return warped_results, bwd_occ, bwd_flow


class FlowCalc:
    """GMFlow wrapper that computes, caches and applies backward flow.

    ``get_flow`` / ``get_mask`` take HWC numpy images, convert them to float
    tensors on CUDA (where the model lives) and optionally cache results on
    disk. ``warp`` applies a flow tensor to a numpy image.
    """

    def __init__(self, model_path='./weights/gmflow_sintel-0c07dcb3.pth'):
        """Build a GMFlow model on CUDA and load its weights.

        Args:
            model_path: checkpoint file holding either a raw state dict or a
                dict with the state dict under ``'model'``.

        NOTE: ``torch.load`` unpickles the file -- only load trusted
        checkpoints. Weights are loaded with ``strict=False``, so missing or
        unexpected keys are silently ignored.
        """
        flow_model = GMFlow(
            feature_channels=128,
            num_scales=1,
            upsample_factor=8,
            num_head=1,
            attention_type='swin',
            ffn_dim_expansion=4,
            num_transformer_layers=6,
        ).to('cuda')

        # Keep checkpoint storages on the CPU; load_state_dict copies them
        # into the (CUDA) model parameters.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
        flow_model.load_state_dict(weights, strict=False)
        flow_model.eval()
        self.model = flow_model

    @torch.no_grad()
    def _predict_and_cache(self, image1, image2, save_path):
        """Run GMFlow on an image pair and optionally write results to disk.

        Args:
            image1, image2: HWC numpy images of the same size.
            save_path: if not None, the backward flow is written with
                ``np.save(save_path, ...)`` and the occlusion mask as a 0/255
                PNG at save_path with its extension replaced by ``.png``.

        Returns:
            ``(bwd_flow, bwd_occ)``. ``bwd_flow`` is a [1, 2, H, W] CUDA
            tensor. ``bwd_occ`` is the [1, H, W] float CUDA tensor from
            forward_backward_consistency_check when save_path is None,
            otherwise the [H, W, 1] int64 numpy array of 0/255 values that
            was written to disk.
        """
        tensor1 = torch.from_numpy(image1).permute(2, 0, 1).float()
        tensor2 = torch.from_numpy(image2).permute(2, 0, 1).float()
        padder = InputPadder(tensor1.shape, padding_factor=8)
        tensor1, tensor2 = padder.pad(tensor1[None].cuda(),
                                      tensor2[None].cuda())
        results_dict = self.model(tensor1,
                                  tensor2,
                                  attn_splits_list=[2],
                                  corr_radius_list=[-1],
                                  prop_radius_list=[-1],
                                  pred_bidir_flow=True)
        flow_pr = results_dict['flow_preds'][-1]  # [B, 2, H, W]; 0=fwd, 1=bwd
        fwd_flow = padder.unpad(flow_pr[0]).unsqueeze(0)  # [1, 2, H, W]
        bwd_flow = padder.unpad(flow_pr[1]).unsqueeze(0)  # [1, 2, H, W]
        _, bwd_occ = forward_backward_consistency_check(
            fwd_flow, bwd_flow)  # [1, H, W] float
        if save_path is not None:
            # NOTE: np.save appends '.npy' when save_path lacks it; then
            # get_flow's os.path.exists(save_path) cache check never hits.
            np.save(save_path, bwd_flow.cpu().numpy())
            mask_path = os.path.splitext(save_path)[0] + '.png'
            bwd_occ = bwd_occ.cpu().permute(1, 2, 0).to(
                torch.long).numpy() * 255
            # Values are exactly 0 or 255; write an explicit 8-bit image
            # instead of relying on imwrite's implicit depth conversion.
            cv2.imwrite(mask_path, bwd_occ.astype(np.uint8))
        return bwd_flow, bwd_occ

    @torch.no_grad()
    def get_flow(self, image1, image2, save_path=None):
        """Return the backward flow (GMFlow ``flow_preds[-1][1]``).

        If ``save_path`` exists it is loaded with ``read_flow`` (a CPU
        tensor) and the model is not run. Otherwise the flow is computed
        (a [1, 2, H, W] CUDA tensor) and, when ``save_path`` is given, the
        flow and its occlusion mask are cached next to it.
        """
        if save_path is not None and os.path.exists(save_path):
            return read_flow(save_path)

        bwd_flow, _ = self._predict_and_cache(image1, image2, save_path)
        return bwd_flow

    @torch.no_grad()
    def get_mask(self, image1, image2, save_path=None):
        """Return the backward occlusion mask for an image pair.

        The return type depends on the path taken:
        - cached PNG exists: ``read_mask`` result ([H, W] grayscale array);
        - computed with ``save_path``: [H, W, 1] int64 numpy array of 0/255
          (the flow and mask are also written to disk);
        - computed without ``save_path``: [1, H, W] float CUDA tensor,
          1 = occluded.
        """
        if save_path is not None:
            mask_path = os.path.splitext(save_path)[0] + '.png'
            if os.path.exists(mask_path):
                return read_mask(mask_path)

        _, bwd_occ = self._predict_and_cache(image1, image2, save_path)
        return bwd_occ

    def warp(self, img, flow, mode='bilinear'):
        """Backward-warp a numpy image (HWC or HW) with ``flow``.

        Args:
            img: numpy image, [H, W, C] or [H, W].
            flow: [1, 2, H, W] flow tensor, on any device.
            mode: ``grid_sample`` interpolation mode.

        Returns:
            Warped numpy image with the same dtype and rank as ``img``
            (float results are cast back with ``Tensor.to``, i.e. truncated
            for integer dtypes).
        """
        expand = len(img.shape) == 2
        if expand:
            img = np.expand_dims(img, 2)

        tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
        dtype = tensor.dtype
        # flow may be on CUDA (fresh get_flow) or CPU (read_flow cache);
        # grid_sample needs the image and the sampling grid on one device.
        tensor = tensor.to(device=flow.device, dtype=torch.float)
        res = flow_warp(tensor, flow, mode=mode)
        res = res.to(dtype)[0].cpu().permute(1, 2, 0).numpy()
        if expand:
            res = res[:, :, 0]
        return res


def read_flow(save_path):
    """Load a flow array written with ``np.save`` and return it as a tensor.

    The tensor lives on the CPU and shares memory with the loaded array.
    """
    return torch.from_numpy(np.load(save_path))


def read_mask(save_path):
    """Load the occlusion mask PNG stored alongside ``save_path``.

    The mask path is ``save_path`` with its extension replaced by ``.png``,
    so passing either the flow path or the ``.png`` path itself works.

    Returns:
        The image converted to single-channel grayscale (BGR -> GRAY).

    Raises:
        FileNotFoundError: if the PNG cannot be read. ``cv2.imread`` returns
            None for missing or undecodable files instead of raising, which
            previously surfaced as an opaque ``cvtColor`` assertion.
    """
    mask_path = os.path.splitext(save_path)[0] + '.png'
    mask = cv2.imread(mask_path)
    if mask is None:
        raise FileNotFoundError(f'could not read mask image: {mask_path}')
    return cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)


# Module-level shared instance. Constructing it at import time builds GMFlow
# on CUDA and loads './weights/gmflow_sintel-0c07dcb3.pth' (resolved against
# the current working directory). Importing this module therefore fails
# without a CUDA device or that checkpoint file.
flow_calc = FlowCalc()