# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

from pdb import set_trace as bb
import os, os.path as osp
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch

from .image_set import ImageSet
from .transforms import instanciate_transforms
from .utils import DatasetWithRng
# Short alias for numpy's matrix inverse.
invh = np.linalg.inv


class ImagePairs (DatasetWithRng):
    """ Base class for a dataset that serves image pairs.

    image_set: image container; only its `get_image(idx)` method is used here.
    pairs: non-empty collection of (idx1, idx2) image-index tuples.
    trf: transform specification, handed to `instanciate_transforms`.
    rng: extra keyword arguments forwarded to `DatasetWithRng.__init__`.
    """
    imgs = None # regular image dataset
    pairs = [] # list of (idx1, idx2), ...

    def __init__(self, image_set, pairs, trf=None, **rng):
        assert image_set and pairs, 'empty images or pairs'
        super().__init__(**rng)
        self.imgs, self.pairs = image_set, pairs
        # the transform shares this dataset's `self.rng`
        # (presumably set up by DatasetWithRng.__init__ -- confirm in .utils)
        self.trf = instanciate_transforms(trf, rng=self.rng)

    def __len__(self):
        """ Number of pairs served by this dataset. """
        return len(self.pairs)

    def __getitem__(self, idx):
        """ Return `((img1, img2), {})` for the idx-th pair.

        `self.trf` is applied to each image separately; when it is falsy the
        images are returned untouched. The second element is an empty dict.
        """
        apply = self.trf or (lambda img: img)
        return tuple(apply(img) for img in self._load_pair(idx)), {}

    def _load_pair(self, idx):
        """ Fetch both images of pair `idx` from `self.imgs`.

        When both indices are equal the image is loaded only once and the
        very same object is returned twice.
        """
        idx1, idx2 = self.pairs[idx]
        first = self.imgs.get_image(idx1)
        if idx1 == idx2:
            return first, first
        return first, self.imgs.get_image(idx2)

    def __repr__(self):
        return f'{self.__class__.__name__}({len(self)} pairs from {self.imgs})'


class StillImagePairs (ImagePairs):
    """ A dataset of 'still' image pairs used for debugging purposes.

    image_set: either an image container or an existing `ImagePairs` dataset
        (in which case its images, and by default its pairs, are reused).
    pairs: optional list of (idx1, idx2); when falsy, it falls back to the
        pairs of the given `ImagePairs`, or to pairing each image with itself.
    rng: extra keyword arguments forwarded to `ImagePairs.__init__`.
    """
    def __init__(self, image_set, pairs=None, **rng):
        if isinstance(image_set, ImagePairs):
            images = image_set.imgs
            pairs = pairs or image_set.pairs
        else:
            images = image_set
            if not pairs:
                # default: every image is paired with itself
                pairs = [(k, k) for k in range(len(image_set))]
        super().__init__(images, pairs, **rng)

    def __getitem__(self, idx):
        """ Return `((img1, img2), dict(homography=H))` for the idx-th pair.

        H is a 3x3 float32 diagonal matrix holding the per-axis ratio
        `img2.size / img1.size`. Unlike the base class, `self.trf` is not
        applied here.
        """
        first, second = self._load_pair(idx)
        # assumes `.size` is a 2-tuple such as PIL's (width, height) -- TODO confirm
        ratio_x, ratio_y = second.size / np.float32(first.size)
        homography = np.diag(np.float32([ratio_x, ratio_y, 1]))
        return (first, second), dict(homography=homography)


class SyntheticImagePairs (StillImagePairs):
    """ A synthetic generator of image pairs.
        Given a normal image dataset, it constructs pairs using random homographies & noise.

    scale: prior image scaling.
    distort: distortion applied independently to (img1,img2) if sym=True else just img2
    sym: (bool) see above.
    """
    def __init__(self, image_set, scale='', distort='', sym=False, **rng):
        super().__init__(image_set, **rng)
        self.symmetric = sym
        # both transforms are instantiated with this dataset's rng
        self.scale = instanciate_transforms(scale, rng=self.rng)
        self.distort = instanciate_transforms(distort, rng=self.rng)

    def __getitem__(self, idx):
        """ Return `((img1, img2), dict(homography=H))` for the idx-th pair.

        Each image travels through the transforms as a dict with keys 'img'
        and 'homography'; H is `homography2 @ inv(homography1)` of the two
        resulting dicts.
        """
        (raw1, raw2), gt = super().__getitem__(idx)

        view1 = dict(img=raw1, homography=np.eye(3, dtype=np.float32))
        if raw1 is raw2:
            # same source image: scale once, then distort a copy for the second view
            view1 = self.scale(view1)
            view2 = self.distort(dict(view1))
            if self.symmetric:
                view1 = self.distort(view1)
        else:
            # distinct source images: the first view is only touched when symmetric
            if self.symmetric:
                view1 = self.distort(self.scale(view1))
            view2 = self.distort(self.scale(dict(img=raw2, **gt)))

        images = (view1['img'], view2['img'])
        relative = view2['homography'] @ invh(view1['homography'])
        return images, dict(homography=relative)

    def __repr__(self):
        def one_line(obj):
            # collapse a multi-line repr onto one line, dropping the first comma
            parts = [ln.strip() for ln in repr(obj).splitlines() if ln]
            return ','.join(parts).replace(',', '', 1)
        return f"{self.__class__.__name__}({len(self)} images, scale={one_line(self.scale)}, distort={one_line(self.distort)})"


class CatImagePairs (DatasetWithRng):
    """ Concatenation of several ImagePairs datasets

    Global pair indices run through the datasets in the order they were
    given: `_pair_offsets[b]` is the global index of the first pair of
    dataset `b`, and `npairs` is the total number of pairs.

    pair_datasets: the `ImagePairs` instances to concatenate.
    seed: forwarded to `DatasetWithRng.__init__`. Note that the default value
        is evaluated only once, when this class definition is executed.
    """
    def __init__(self, *pair_datasets, seed=torch.initial_seed()):
        assert all(isinstance(db, ImagePairs) for db in pair_datasets)
        self.pair_datasets = pair_datasets
        DatasetWithRng.__init__(self, seed=seed) # init last
        self._init()

    def _init(self):
        """ Compute the cumulative offsets and the total number of pairs. """
        self._pair_offsets = np.cumsum([0] + [len(db) for db in self.pair_datasets])
        self.npairs = self._pair_offsets[-1]

    def __len__(self):
        return self.npairs

    def __repr__(self):
        lines = [f"{type(self).__name__}({len(self)} pairs"]
        for i,db in enumerate(self.pair_datasets):
            npairs = self._pair_offsets[i+1] - self._pair_offsets[i]
            lines.append(f'\n\t{npairs} from '+str(db).replace("\n"," "))
        return ','.join(lines) + ')'

    def __getitem__(self, idx):
        """ Return item `idx` from whichever dataset owns that global index.

        Raises IndexError when `idx` is outside `[0, len(self))`.
        """
        b, i = self._which(idx)
        return self.pair_datasets[b].__getitem__(i)

    def _which(self, i):
        """ Map global pair index `i` to `(dataset index, index within it)`.

        Raises IndexError when `i` is outside `[0, npairs)`.
        """
        # Validate the pair index itself. The previous check compared the
        # *dataset* index against the number of pairs, which could reject
        # valid indices and let invalid ones through.
        if not 0 <= i < self.npairs:
            raise IndexError('Bad pair index %d, not in [0, %d)' % (i, self.npairs))
        # side='right' sends an index equal to an offset to the dataset that
        # starts there (and skips over datasets that own no index at all)
        pos = np.searchsorted(self._pair_offsets, i, side='right')-1
        return pos, i - self._pair_offsets[pos]

    def _call(self, func, i, *args, **kwargs):
        """ Call method `func` of the dataset owning global index `i`,
            passing it the local index followed by the remaining arguments.
        """
        b, j = self._which(i)
        return getattr(self.pair_datasets[b], func)(j, *args, **kwargs)

    def init_worker(self, tid):
        """ Forward `init_worker(tid)` to every concatenated dataset. """
        for db in self.pair_datasets:
            db.init_worker(tid)


class BalancedCatImagePairs (CatImagePairs):
    """ Balanced concatenation of several ImagePairs datasets

    Every dataset is given a fixed number of slots, regardless of its own
    length. Two calling conventions are accepted:

      BalancedCatImagePairs(npairs, db1, db2, ...)
          `npairs` slots (or the largest dataset length when npairs == 0)
          are split evenly between the datasets.

      BalancedCatImagePairs(n1, db1, n2, db2, n3, db3, ...)
          dataset `k` gets exactly `n_k` slots.

    Remaining keyword arguments are forwarded to `CatImagePairs.__init__`.
    """
    def __init__(self, npairs=0, *pair_datasets, **kw):
        assert isinstance(npairs, int) and npairs >= 0, 'BalancedCatImagePairs(npairs != int)'
        assert len(pair_datasets) > 0, 'no dataset provided'

        # an int in second position signals the interleaved (n1, db1, n2, db2, ...) form
        interleaved = len(pair_datasets) >= 3 and isinstance(pair_datasets[1], int)
        if interleaved:
            assert len(pair_datasets) % 2 == 1
            flat = [npairs, *pair_datasets]
            counts, pair_datasets = flat[0::2], flat[1::2]
            assert all(isinstance(n, int) for n in counts)
            self._pair_offsets = np.cumsum([0] + counts)
            self.npairs = self._pair_offsets[-1]
        else:
            self.npairs = npairs or max(len(db) for db in pair_datasets)
            self._pair_offsets = np.linspace(0, self.npairs, len(pair_datasets)+1).astype(int)
        # offsets must exist beforehand: the parent constructor ends by calling self._init()
        CatImagePairs.__init__(self, *pair_datasets, **kw)

    def set_epoch(self, epoch):
        """ Re-seed from the epoch number, then rebuild the permutations. """
        DatasetWithRng.init_worker(self, epoch) # random seed only depends on the epoch
        self._init() # reset permutations for this epoch

    def init_worker(self, tid):
        """ Forward `init_worker(tid)` to the concatenated datasets. """
        super().init_worker(tid)

    def _init(self):
        """ Build, for every dataset, the list of local indices filling its slots.

        A dataset's indices are repeated (by doubling) until they cover its
        number of slots. Unless `self.seed` is falsy, everything beyond the
        last whole repetition that fits in the slots is shuffled, and the
        result is then cut down to the number of slots.
        """
        perms = self._perms = []
        for k, db in enumerate(self.pair_datasets):
            size = len(db)
            assert size, 'cannot balance if there is an empty dataset'
            nslots = self._pair_offsets[k+1] - self._pair_offsets[k]
            candidates = np.arange(size)
            while len(candidates) < nslots:
                candidates = np.concatenate([candidates, candidates])
            if self.seed: # if not seed, then no shuffle
                whole_passes = nslots // size
                # relies on shuffle working in place on this slice (a view of `candidates`)
                self.rng.shuffle(candidates[whole_passes * size:])
            perms.append(candidates[:nslots])

    def _which(self, i):
        """ Like the parent mapping, but the local slot number is translated
            through this epoch's permutation of the owning dataset.
        """
        b, slot = super()._which(i)
        return b, self._perms[b][slot]


class UnsupervisedPairs (ImagePairs):
    """ Unsupervised image pairs obtained from SfM

    img_set: an `ImageSet`; the entries of its `imgs` attribute are the image
        names that the pair file refers to.
    pair_file_path: text file with one pair per line, written as two
        whitespace-separated image names. Correspondences are looked up in
        the 'corres' directory located next to this file.
    """
    def __init__(self, img_set, pair_file_path):
        # plain message instead of dropping into the debugger on a bad argument
        assert isinstance(img_set, ImageSet), \
            f'img_set must be an ImageSet, got {type(img_set).__name__}'
        self.pair_list = self._parse_pair_list(pair_file_path)
        self.corres_dir = osp.join(osp.split(pair_file_path)[0], 'corres')

        # translate image names into positions within img_set.imgs
        tag_to_idx = {n:i for i,n in enumerate(img_set.imgs)}
        img_indices = lambda pair: tuple([tag_to_idx[n] for n in pair])
        super().__init__(img_set, [img_indices(pair) for pair in self.pair_list])

    def __repr__(self):
        return f"{type(self).__name__}({len(self)} pairs from {self.imgs})"

    def _parse_pair_list(self, pair_file_path):
        """ Read the pair file and return a list of (name1, name2) tuples.

        Raises IOError if a line does not hold exactly two fields.
        """
        # the context manager guarantees the file handle is closed
        with open(pair_file_path) as f:
            rows = f.read().splitlines()

        res = []
        for lineno, row in enumerate(rows, 1):
            fields = row.split()
            if len(fields) != 2:
                raise IOError(f'{pair_file_path}:{lineno}: expected 2 image names '
                              f'per line, got {len(fields)}')
            res.append((fields[0], fields[1]))
        return res

    def get_corres_path(self, pair_idx):
        """ Path of the .npy file associated with pair `pair_idx`:
            '<corres_dir>/<basename1>_<basename2>.npy'.
        """
        img1, img2 = [osp.basename(self.imgs.imgs[i]) for i in self.pairs[pair_idx]]
        return osp.join(self.corres_dir, f'{img1}_{img2}.npy')

    def get_corres(self, pair_idx):
        """ Load the content of the .npy file of pair `pair_idx` with `np.load`. """
        return np.load(self.get_corres_path(pair_idx))

    def __getitem__(self, idx):
        """ Return `((img1, img2), dict(corres=...))`; `self.trf` is not applied. """
        img1, img2 = self._load_pair(idx)
        return (img1, img2), dict(corres=self.get_corres(idx))


if __name__ == '__main__':
    # Demo: build a balanced mix of three pair datasets and display random pairs.
    from datasets import *
    from tools.viz import show_random_pairs

    tilting = 'RandomTilting(0.5)'
    web_pairs = SyntheticImagePairs(RandomWebImages(0,52), distort=tilting)
    sfm_synth_pairs = SyntheticImagePairs(SfM120k_Images(), distort=tilting)
    sfm_pairs = SfM120k_Pairs()

    # interleaved form: each count is followed by the dataset it applies to
    db = BalancedCatImagePairs(
                3125, web_pairs,
                4875, sfm_synth_pairs,
                8000, sfm_pairs)

    show_random_pairs(db)