# Copyright (c) OpenMMLab. All rights reserved.
import imgaug
import imgaug.augmenters as iaa
import mmcv
import numpy as np
from mmdet.core.mask import PolygonMasks
from mmdet.datasets.builder import PIPELINES


class AugmenterBuilder:
    """Turn a nested config description into imgaug augmenter objects.

    Names found in the config are looked up as attributes of
    ``imgaug.augmenters`` and called with the remaining values.
    """

    def __init__(self):
        pass

    def build(self, args, root=True):
        """Recursively convert ``args`` into augmenters or plain values.

        Args:
            args (None | int | float | str | list | dict): The config node.
                ``None`` and scalars are returned untouched. A list at the
                root is a pipeline of augmenter specs; a nested list is
                ``[name, *positional_args]``. A dict holding ``'cls'`` is
                ``{cls: name, **keyword_args}``; any other dict has each of
                its values built recursively.
            root (bool): Whether ``args`` is the outermost node.

        Returns:
            The ``iaa.Sequential`` pipeline, a single augmenter, a dict of
            built values, or ``args`` itself for ``None`` and scalars.

        Raises:
            RuntimeError: If ``args`` has a type that is not handled.
        """
        if args is None:
            return None

        if isinstance(args, (int, float, str)):
            return args

        if isinstance(args, list):
            if root:
                children = [self.build(item, root=False) for item in args]
                return iaa.Sequential(children)
            # Non-root list: the head names the augmenter, the tail holds
            # its positional arguments.
            positional = [self.to_tuple_if_list(item) for item in args[1:]]
            augmenter_cls = getattr(iaa, args[0])
            return augmenter_cls(*positional)

        if isinstance(args, dict):
            if 'cls' not in args:
                return {
                    name: self.build(item, root=False)
                    for name, item in args.items()
                }
            augmenter_cls = getattr(iaa, args['cls'])
            keywords = {
                name: self.to_tuple_if_list(item)
                for name, item in args.items() if name != 'cls'
            }
            return augmenter_cls(**keywords)

        raise RuntimeError('unknown augmenter arg: ' + str(args))

    def to_tuple_if_list(self, obj):
        """Return ``tuple(obj)`` when ``obj`` is a list, else ``obj``."""
        return tuple(obj) if isinstance(obj, list) else obj


@PIPELINES.register_module()
class ImgAug:
    """A wrapper to use imgaug https://github.com/aleju/imgaug.

    Args:
        args ([list[list|dict]]): The augmentation list. For details, please
            refer to the imgaug documentation. Take args=[['Fliplr', 0.5],
            dict(cls='Affine', rotate=[-10, 10]), ['Resize', [0.5, 3.0]]] as
            an example. These args horizontally flip images with probability
            0.5, followed by a random rotation with angles in range
            [-10, 10], and a resize with an independent scale in range
            [0.5, 3.0] for each side of the image.
    """

    def __init__(self, args=None):
        self.augmenter_args = args
        self.augmenter = AugmenterBuilder().build(self.augmenter_args)

    def __call__(self, results):
        """Augment ``results['img']`` and its annotations in place.

        When an augmenter is configured, the image, ``img_shape``, and every
        entry named in ``mask_fields`` / ``bbox_fields`` are updated, and the
        flip markers are set to ``'unknown'``. Otherwise ``results`` is
        returned untouched.

        Args:
            results (dict): Pipeline dict holding at least ``'img'``.

        Returns:
            dict: The same ``results`` object.
        """
        # img is bgr
        source_img = results['img']
        source_shape = source_img.shape

        if not self.augmenter:
            return results

        # Freeze the random state so that the image and its annotations
        # receive the same transformation.
        frozen_aug = self.augmenter.to_deterministic()
        results['img'] = frozen_aug.augment_image(source_img)
        results['img_shape'] = results['img'].shape
        # The applied augmenters are opaque, so flip state cannot be told.
        results['flip'] = 'unknown'
        results['flip_direction'] = 'unknown'

        self.may_augment_annotation(frozen_aug, source_shape,
                                    results['img_shape'], results)
        return results

    def may_augment_annotation(self, aug, shape, target_shape, results):
        """Apply ``aug`` to the polygon masks and boxes in ``results``.

        Args:
            aug: Deterministic augmenter, or ``None`` to do nothing.
            shape (tuple): Image shape before augmentation.
            target_shape (tuple): Image shape after augmentation; its first
                two entries are passed to ``PolygonMasks``.
            results (dict): Pipeline dict with ``mask_fields`` and
                ``bbox_fields`` listing the keys to transform.

        Returns:
            dict: The same ``results`` object.
        """
        if aug is None:
            return results

        # Polygon masks: only overwritten when something was produced.
        for field in results['mask_fields']:
            new_masks = self.may_augment_poly(aug, shape, results[field])
            if len(new_masks) > 0:
                results[field] = PolygonMasks(new_masks, *target_shape[:2])

        # Boxes: an empty array stands in when there are none.
        for field in results['bbox_fields']:
            new_boxes = self.may_augment_poly(
                aug, shape, results[field], mask_flag=False)
            results[field] = np.zeros(0)
            if len(new_boxes) > 0:
                results[field] = np.stack(new_boxes)

        return results

    def may_augment_poly(self, aug, img_shape, polys, mask_flag=True):
        """Transform the vertices of ``polys`` with ``aug``.

        All vertices are gathered into one keypoint list, augmented in a
        single call, and then split back into per-polygon arrays.

        Args:
            aug: Deterministic augmenter providing ``augment_keypoints``.
            img_shape (tuple): Shape of the image before augmentation.
            polys (iterable): Polygons or boxes. With ``mask_flag`` each item
                is indexed with ``[0]`` to obtain the coordinate array.
            mask_flag (bool): Whether ``polys`` holds mask instances.

        Returns:
            list: One flattened coordinate array per input item, wrapped in
            a single-element list when ``mask_flag`` is true.
        """
        all_keypoints = []
        vertex_counts = []
        for item in polys:
            coords = item[0] if mask_flag else item
            coords = coords.reshape(-1, 2)
            vertex_counts.append(coords.shape[0])
            all_keypoints.extend(
                imgaug.Keypoint(pt[0], pt[1]) for pt in coords)

        keypoints_on_image = imgaug.KeypointsOnImage(
            keypoints=all_keypoints, shape=img_shape)
        moved = aug.augment_keypoints([keypoints_on_image])[0].keypoints

        new_polys = []
        offset = 0
        for count in vertex_counts:
            chunk = moved[offset:offset + count]
            offset += count
            flat = np.array([[kp.x, kp.y] for kp in chunk]).flatten()
            new_polys.append([flat] if mask_flag else flat)

        return new_polys

    def __repr__(self):
        return self.__class__.__name__


@PIPELINES.register_module()
class EastRandomCrop:
    """Randomly crop a region containing text and fit it on a fixed canvas.

    A crop window is sampled so that its borders fall on rows/columns not
    covered by any polygon's bounding span. The crop is resized with its
    aspect ratio preserved to fit inside ``target_size`` and pasted at the
    top-left corner of a zero-filled canvas. Boxes and polygons are mapped
    to the new coordinates; those lying entirely outside the crop are
    dropped.

    Args:
        target_size (tuple[int]): Output canvas size as ``(width, height)``.
        max_tries (int): Maximum number of sampling attempts before falling
            back to the whole image.
        min_crop_side_ratio (float): Minimum crop side length, as a fraction
            of the corresponding image side.
    """

    def __init__(self,
                 target_size=(640, 640),
                 max_tries=10,
                 min_crop_side_ratio=0.1):
        self.target_size = target_size
        self.max_tries = max_tries
        self.min_crop_side_ratio = min_crop_side_ratio

    def __call__(self, results):
        """Crop the image, boxes and masks held in ``results``.

        Reads ``'img'``, ``'gt_masks'``, ``'bbox_fields'`` and
        ``'mask_fields'``; overwrites ``'img'``, ``'img_shape'`` and every
        listed field, and sets ``'gt_labels'`` to a list of zeros matching
        the polygons kept for ``'gt_masks'``.

        Args:
            results (dict): Pipeline dict.

        Returns:
            dict: The same ``results`` object.
        """
        img = results['img']
        crop_x, crop_y, crop_w, crop_h = self.crop_area(
            img, results['gt_masks'])

        # One uniform scale so the crop fits inside the canvas.
        scale_w = self.target_size[0] / crop_w
        scale_h = self.target_size[1] / crop_h
        scale = min(scale_w, scale_h)
        h = int(crop_h * scale)
        w = int(crop_w * scale)
        padded_img = np.zeros(
            (self.target_size[1], self.target_size[0], img.shape[2]),
            img.dtype)
        padded_img[:h, :w] = mmcv.imresize(
            img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))

        # for bboxes
        for key in results['bbox_fields']:
            lines = []
            for box in results[key]:
                box = box.reshape(2, 2)
                poly = ((box - (crop_x, crop_y)) * scale)
                if not self.is_poly_outside_rect(poly, 0, 0, w, h):
                    lines.append(poly.flatten())
            results[key] = np.array(lines)

        # for masks
        for key in results['mask_fields']:
            polys = []
            for poly in results[key]:
                poly = np.array(poly).reshape(-1, 2)
                poly = ((poly - (crop_x, crop_y)) * scale)
                if not self.is_poly_outside_rect(poly, 0, 0, w, h):
                    polys.append([poly])
            # PolygonMasks takes (height, width) while target_size is
            # (width, height); use the canvas shape to get the order right
            # for non-square targets.
            results[key] = PolygonMasks(polys, *padded_img.shape[:2])
            if key == 'gt_masks':
                results['gt_labels'] = [0] * len(polys)

        results['img'] = padded_img
        results['img_shape'] = padded_img.shape

        return results

    def is_poly_in_rect(self, poly, x, y, w, h):
        """Return True if every vertex lies inside the rectangle.

        Args:
            poly (array-like): Vertices of shape (N, 2).
            x, y, w, h: Left, top, width and height of the rectangle.
        """
        poly = np.array(poly)
        if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
            return False
        if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
            return False
        return True

    def is_poly_outside_rect(self, poly, x, y, w, h):
        """Return True if the polygon's bounding span misses the rectangle.

        Args:
            poly (array-like): Coordinates reshapeable to (N, 2).
            x, y, w, h: Left, top, width and height of the rectangle.
        """
        poly = np.array(poly).reshape(-1, 2)
        if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
            return True
        if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
            return True
        return False

    def split_regions(self, axis):
        """Split sorted indices into runs of consecutive values.

        Args:
            axis (np.ndarray): 1-D array of increasing integer indices.

        Returns:
            list[np.ndarray]: One array per run, including the final run.
            Empty when ``axis`` is empty.
        """
        axis = np.asarray(axis)
        if axis.shape[0] == 0:
            return []
        # A new run starts wherever the step between neighbours is not 1.
        run_starts = np.where(np.diff(axis) != 1)[0] + 1
        return np.split(axis, run_starts)

    def random_select(self, axis, max_size):
        """Draw two values from ``axis`` and return them as (min, max).

        Both values are clipped to ``[0, max_size - 1]``.
        """
        xx = np.random.choice(axis, size=2)
        xmin = np.clip(np.min(xx), 0, max_size - 1)
        xmax = np.clip(np.max(xx), 0, max_size - 1)
        return xmin, xmax

    def region_wise_random_select(self, regions):
        """Pick two regions (with replacement) and one value from each.

        Args:
            regions (list[np.ndarray]): Candidate index runs.

        Returns:
            tuple[int, int]: The two drawn values as (min, max).
        """
        selected_index = np.random.choice(len(regions), 2)
        # Draw a scalar directly; converting a 1-element array with int()
        # is deprecated in recent NumPy releases.
        selected_values = [
            int(np.random.choice(regions[index])) for index in selected_index
        ]
        return min(selected_values), max(selected_values)

    def crop_area(self, img, polys):
        """Sample a crop window that keeps at least one polygon.

        Args:
            img (np.ndarray): Image of shape (H, W, C).
            polys (iterable): Polygons, each reshapeable to (N, 2).

        Returns:
            tuple: ``(x, y, width, height)`` of the crop. The whole image
            is returned when no free row/column exists or when no valid
            window is found within ``max_tries`` attempts.
        """
        h, w, _ = img.shape
        h_array = np.zeros(h, dtype=np.int32)
        w_array = np.zeros(w, dtype=np.int32)
        for points in polys:
            points = np.round(
                points, decimals=0).astype(np.int32).reshape(-1, 2)
            # Clip to the image so that a polygon sticking out past the
            # top/left edge does not produce a negative slice start, which
            # would wrap around and leave the text span unmarked.
            min_x, max_x = np.clip(
                (points[:, 0].min(), points[:, 0].max()), 0, w)
            min_y, max_y = np.clip(
                (points[:, 1].min(), points[:, 1].max()), 0, h)
            w_array[min_x:max_x] = 1
            h_array[min_y:max_y] = 1
        # ensure the cropped area not across a text
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]

        if len(h_axis) == 0 or len(w_axis) == 0:
            return 0, 0, w, h

        h_regions = self.split_regions(h_axis)
        w_regions = self.split_regions(w_axis)

        for _ in range(self.max_tries):
            if len(w_regions) > 1:
                xmin, xmax = self.region_wise_random_select(w_regions)
            else:
                xmin, xmax = self.random_select(w_axis, w)
            if len(h_regions) > 1:
                ymin, ymax = self.region_wise_random_select(h_regions)
            else:
                ymin, ymax = self.random_select(h_axis, h)

            if (xmax - xmin < self.min_crop_side_ratio * w
                    or ymax - ymin < self.min_crop_side_ratio * h):
                # area too small
                continue

            crop_w, crop_h = xmax - xmin, ymax - ymin
            if any(not self.is_poly_outside_rect(poly, xmin, ymin, crop_w,
                                                 crop_h) for poly in polys):
                return xmin, ymin, crop_w, crop_h

        return 0, 0, w, h