# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from core import functional as myF
from core.pixel_desc import PixelDesc
from tools.common import mkdir_for, todevice, cudnn_benchmark, nparray, image, image_with_trf
from tools.viz import dbgfig, show_correspondences


def arg_parser():
    import argparse
    parser = argparse.ArgumentParser('SingleScalePUMP on GPU with PyTorch')

    parser.add_argument('--img1', required=True, help='path to img1')
    parser.add_argument('--img2', required=True, help='path to img2')
    parser.add_argument('--resize', type=int, default=512, nargs='+', help='prior downsize of img1 and img2')

    parser.add_argument('--output', default=None, help='output path for correspondences')

    parser.add_argument('--levels', type=int, default=99, help='number of pyramid levels')
    parser.add_argument('--min-shape', type=int, default=5, help='minimum size of corr maps')
    parser.add_argument('--nlpow', type=float, default=1.5, help='non-linear activation power in [1,2]')
    parser.add_argument('--border', type=float, default=0.9, help='border invariance level in [0,1]')
    parser.add_argument('--dtype', default='float16', choices='float16 float32 float64'.split())

    parser.add_argument('--desc', default='PUMP-stytrf', help='checkpoint name')
    parser.add_argument('--first-level', choices='torch'.split(), default='torch')
    parser.add_argument('--activation', choices='torch'.split(), default='torch')
    parser.add_argument('--forward', choices='torch cuda cuda-lowmem'.split(), default='cuda-lowmem')
    parser.add_argument('--backward', choices='python torch cuda'.split(), default='cuda')
    parser.add_argument('--reciprocal', choices='cpu cuda'.split(), default='cpu')

    parser.add_argument('--post-filter', default=None, const=True, nargs='?', help='post-filtering (See post_filter.py)')

    parser.add_argument('--verbose', type=int, default=0, help='verbosity')
    parser.add_argument('--device', default='cuda', help='gpu device')
    parser.add_argument('--dbg', nargs='*', default=(), help='debug options')

    return parser


class SingleScalePUMP (nn.Module):
    def __init__(self, levels = 9, nlpow = 1.4, cutoff = 1, 
                 border_inv=0.9, min_shape=5, renorm=(),
                 pixel_desc = None, dtype = torch.float32,
                 verbose = True ):
        super().__init__()
        self.levels = levels
        self.min_shape = min_shape
        self.nlpow = nlpow
        self.border_inv = border_inv
        assert pixel_desc, 'Requires a pixel descriptor'
        self.pixel_desc = pixel_desc.configure(self)
        self.dtype = dtype
        self.verbose = verbose

    @torch.no_grad()
    def forward(self, img1, img2, ret='corres', dbg=()):
        with cudnn_benchmark(False):
            # compute descriptors
            (img1, img2), pixel_descs, trfs = self.extract_descs(img1, img2, dtype=self.dtype)

            # backward and forward passes
            pixel_corr = self.first_level(*pixel_descs, dbg=dbg)
            pixel_corr = self.backward_pass(self.forward_pass(pixel_corr, dbg=dbg), dbg=dbg)

            # recover correspondences
            corres = myF.best_correspondences( pixel_corr )

        if dbgfig('corres', dbg): viz_correspondences(img1[0], img2[0], *corres, fig='last')
        corres = [(myF.affmul(trfs,pos),score) for pos, score in corres] # rectify scaling etc.
        if ret == 'raw': return corres, trfs
        return self.reciprocal(*corres)

    def extract_descs(self, img1, img2, dtype=None):
        img1, sca1 = self.demultiplex_img_trf(img1)
        img2, sca2 = self.demultiplex_img_trf(img2)
        desc1, trf1 = self.pixel_desc(img1)
        desc2, trf2 = self.pixel_desc(img2)
        return (img1, img2), (desc1.type(dtype), desc2.type(dtype)), (sca1@trf1, sca2@trf2)

    def demultiplex_img_trf(self, img, **kw):
        return img if isinstance(img, tuple) else (img, torch.eye(3, device=img.device))

    def forward_pass(self, pixel_corr, dbg=()):
        weights = None
        if isinstance(pixel_corr, tuple):
            pixel_corr, weights = pixel_corr

        # first-level with activation
        if self.verbose: print(f'  Pyramid level {0} shape={tuple(pixel_corr.shape)}')
        pyramid = [ self.activation(0,pixel_corr) ]
        if dbgfig(f'corr0', dbg): viz_correlation_maps(*from_stack('img1','img2'), pyramid[0], fig='last')

        for level in range(1, self.levels+1):
            upper, weights = self.forward_level(level, pyramid[-1], weights)
            if weights.sum() == 0: break # img1 has become too small

            # activation
            pyramid.append( self.activation(level,upper) )

            if self.verbose: print(f'  Pyramid level {level} shape={tuple(upper.shape)}')
            if dbgfig(f'corr{level}', dbg): viz_correlation_maps(*from_stack('img1','img2'), upper, level=level, fig='last')
            if min(upper.shape[-2:]) <= self.min_shape: break # img2 has become too small

        return pyramid

    def forward_level(self, level, corr, weights):
        # max-pooling
        pooled = F.max_pool2d(corr, 3, padding=1, stride=2)

        # sparse conv
        return myF.sparse_conv(level, pooled, weights, norm=self.border_inv)

    def backward_pass(self, pyramid, dbg=()):
        # same than forward in reverse order
        for level in range(len(pyramid)-1, 0, -1):
            lower = self.backward_level(level, pyramid)
            # assert not torch.isnan(lower).any(), bb()
            if self.verbose: print(f'  Pyramid level {level-1} shape={tuple(lower.shape)}')
            del pyramid[-1] # free memory
            if dbgfig(f'corr{level}-bw', dbg): viz_correlation_maps(img1, img2, lower, fig='last')
        return pyramid[0]

    def backward_level(self, level, pyramid):
        # reverse sparse-coonv
        pooled = myF.sparse_conv(level, pyramid[level], reverse=True)

        # reverse max-pool and add to lower level
        return myF.max_unpool(pooled, pyramid[level-1])

    def activation(self, level, corr):
        assert 1 <= self.nlpow <= 3
        corr.clamp_(min=0).pow_(self.nlpow)
        return corr

    def first_level(self, desc1, desc2, dbg=()):
        assert desc1.ndim == desc2.ndim == 4
        assert len(desc1) == len(desc2) == 1, "not implemented"
        H1, W1 = desc1.shape[-2:]
        H2, W2 = desc2.shape[-2:]

        patches = F.unfold(desc1, 4, stride=4) # C*4*4, H1*W1//16
        B, C, N = patches.shape
        # rearrange(patches, 'B (C Kh Kw) H1W1 -> B H1W1 C Kh Kw', Kh=4, Kw=4)
        patches = patches.permute(0, 2, 1).view(B, H1W1, C//16, 4, 4)

        corr, norms = myF.normalized_corr(patches[0], desc2[0], ret_norms=True)
        if dbgfig('ncc',dbg):
            for j in range(0,len(corr),9):
              for i in range(9):
                pl.subplot(3,3,i+1).cla()
                i += j
                pl.imshow(corr[i], vmin=0.9, vmax=1)
                pl.plot(2+(i%16)*4, 2+(i//16)*4,'xr', ms=10)
              bb()
        return corr.view(H1//4, W1//4, H2+1, W2+1), (norms.view(H1//4, W1//4)>0).float()

    def reciprocal(self, corres1, corres2 ):
        corres1, corres2 = todevice(corres1, 'cpu'), todevice(corres2, 'cpu')
        return myF.reciprocal(self, corres1, corres2)


class Main:
    def __init__(self):
        self.post_filtering = False

    def run_from_args(self, args):
        device = args.device
        self.matcher = self.build_matcher(args, device)
        if args.post_filter:
            self.post_filtering = {} if args.post_filter is True else eval(f'dict({args.post_filter})')

        corres = self(*self.load_images(args, device), dbg=set(args.dbg))

        if args.output:
            self.save_output( args.output, corres )

    def run_from_args_with_images(self, img1, img2, args):
        device = args.device
        self.matcher = self.build_matcher(args, device)
        if args.post_filter:
            self.post_filtering = {} if args.post_filter is True else eval(f'dict({args.post_filter})')
        
        if isinstance(args.resize, int): # user can provide 2 separate sizes for each image
            args.resize = (args.resize, args.resize)

        if len(args.resize) == 1: 
            args.resize = 2 * args.resize

        images = []
        for imgx, size in zip([img1, img2], args.resize):
            img = torch.from_numpy(np.array(imgx.convert('RGB'))).permute(2,0,1).to(device)
            img = myF.imresize(img, size)
            images.append( img )
        
        corres = self(*images, dbg=set(args.dbg))
        
        if args.output:
            self.save_output( args.output, corres )
        
        return corres 
        

    @staticmethod
    def get_options( args ):
        # configure the pipeline
        pixel_desc = PixelDesc(path=f'checkpoints/{args.desc}.pt')
        return dict(levels=args.levels, min_shape=args.min_shape, border_inv=args.border, nlpow=args.nlpow,
                    pixel_desc=pixel_desc, dtype=eval(f'torch.{args.dtype}'), verbose=args.verbose)

    @staticmethod
    def tune_matcher( args, matcher, device ):
        if device == 'cpu': 
            matcher.dtype = torch.float32
            args.forward = 'torch'
            args.backward = 'torch'
            args.reciprocal = 'cpu'

        if args.forward == 'cuda':       type(matcher).forward_level = myF.forward_cuda
        if args.forward == 'cuda-lowmem':type(matcher).forward_level = myF.forward_cuda_lowmem
        if args.backward == 'python':    type(matcher).backward_pass = legacy.backward_python
        if args.backward == 'cuda':      type(matcher).backward_level = myF.backward_cuda
        if args.reciprocal == 'cuda':    type(matcher).reciprocal = myF.reciprocal

        return matcher.to(device)

    @staticmethod
    def build_matcher(args, device):
        options = Main.get_options(args)
        matcher = SingleScalePUMP(**options)
        return Main.tune_matcher(args, matcher, device)

    def __call__(self, *imgs, dbg=()):
        corres = self.matcher( *imgs, dbg=dbg).cpu().numpy()
        if self.post_filtering is not False: 
            corres = self.post_filter( imgs, corres )

        if 'print' in dbg: print(corres)
        if dbgfig('viz',dbg):   show_correspondences(*imgs, corres)
        return corres

    @staticmethod
    def load_images( args, device='cpu' ):
        def read_image(impath):
            try:
                from torchvision.io.image import read_image, ImageReadMode
                return read_image(impath, mode=ImageReadMode.RGB)
            except RuntimeError:
                from PIL import Image
                return torch.from_numpy(np.array(Image.open(impath).convert('RGB'))).permute(2,0,1)

        if isinstance(args.resize, int): # user can provide 2 separate sizes for each image
            args.resize = (args.resize, args.resize)

        if len(args.resize) == 1: 
            args.resize = 2 * args.resize

        images = []
        for impath, size in zip([args.img1, args.img2], args.resize):
            img = read_image(impath).to(device)
            img = myF.imresize(img, size)
            images.append( img )
        return images

    def post_filter(self, imgs, corres ):
        from post_filter import filter_corres
        return filter_corres(*map(image_with_trf,imgs), corres, **self.post_filtering)

    def save_output(self, output_path, corres ):
        mkdir_for( output_path )
        np.savez(open(output_path,'wb'), corres=corres)



if __name__ == '__main__':
    Main().run_from_args(arg_parser().parse_args())