# Copyright (c) OpenMMLab. All rights reserved.
import math

import mmcv
import numpy as np
import torch
import torchvision.transforms.functional as TF
from mmcv.runner.dist_utils import get_dist_info
from mmdet.datasets.builder import PIPELINES
from PIL import Image
from shapely.geometry import Polygon
from shapely.geometry import box as shapely_box

import mmocr.utils as utils
from mmocr.datasets.pipelines.crop import warp_img


@PIPELINES.register_module()
class ResizeOCR:
    """Image resizing and padding for OCR.

    Args:
        height (int | tuple(int)): Image height after resizing.
        min_width (none | int | tuple(int)): Image minimum width
            after resizing.
        max_width (none | int | tuple(int)): Image maximum width
            after resizing.
        keep_aspect_ratio (bool): Keep image aspect ratio if True
            during resizing. Otherwise resize to the size height *
            max_width.
        img_pad_value (int): Scalar to fill padding area.
        width_downsample_ratio (float): Downsample ratio in horizontal
            direction from input image to output feature. When the aspect
            ratio is kept, the new width is snapped to a multiple of
            ``int(1 / width_downsample_ratio)``.
        backend (str | None): The image resize backend type. Options are `cv2`,
            `pillow`, `None`. If backend is None, the global imread_backend
            specified by ``mmcv.use_backend()`` will be used. Default: None.

    Note:
        When ``height`` is a tuple, ``min_width`` and ``max_width`` must be
        tuples of the same length. One (height, min_width, max_width) triple
        is chosen per distributed rank as ``rank % len(height)``.
    """

    def __init__(self,
                 height,
                 min_width=None,
                 max_width=None,
                 keep_aspect_ratio=True,
                 img_pad_value=0,
                 width_downsample_ratio=1.0 / 16,
                 backend=None):
        assert isinstance(height, (int, tuple))
        assert utils.is_none_or_type(min_width, (int, tuple))
        assert utils.is_none_or_type(max_width, (int, tuple))
        if not keep_aspect_ratio:
            assert max_width is not None, ('"max_width" must assigned '
                                           'if "keep_aspect_ratio" is False')
        assert isinstance(img_pad_value, int)
        if isinstance(height, tuple):
            assert isinstance(min_width, tuple)
            assert isinstance(max_width, tuple)
            assert len(height) == len(min_width) == len(max_width)

        self.height = height
        self.min_width = min_width
        self.max_width = max_width
        self.keep_aspect_ratio = keep_aspect_ratio
        self.img_pad_value = img_pad_value
        self.width_downsample_ratio = width_downsample_ratio
        self.backend = backend

    @staticmethod
    def _aspect_width(dst_height, ori_height, ori_width, width_divisor):
        """Return the aspect-preserving width for ``dst_height``, snapped to
        the nearest positive multiple of ``width_divisor``."""
        new_width = math.ceil(float(dst_height) / ori_height * ori_width)
        if new_width % width_divisor != 0:
            # Never snap down to 0: a zero width would make the subsequent
            # resize fail for very narrow images when no min_width is set.
            new_width = max(1, round(new_width / width_divisor)) * width_divisor
        return new_width

    def __call__(self, results):
        """Resize (and optionally right-pad) ``results['img']``.

        Returns:
            dict: ``results`` with ``img``, ``img_shape``, ``resize_shape``,
            ``pad_shape`` and ``valid_ratio`` updated. ``valid_ratio`` is
            ``min(1, new_width / max_width)`` when a max width applies,
            otherwise 1.0.
        """
        if isinstance(self.height, int):
            dst_height = self.height
            dst_min_width = self.min_width
            dst_max_width = self.max_width
        else:
            # Multi-scale resize used in distributed training.
            # Choose one (height, width) pair for one rank id.
            rank, _ = get_dist_info()
            idx = rank % len(self.height)
            dst_height = self.height[idx]
            dst_min_width = self.min_width[idx]
            dst_max_width = self.max_width[idx]

        ori_height, ori_width = results['img_shape'][:2]
        valid_ratio = 1.0

        if self.keep_aspect_ratio:
            width_divisor = int(1 / self.width_downsample_ratio)
            new_width = self._aspect_width(dst_height, ori_height, ori_width,
                                           width_divisor)
            if dst_min_width is not None:
                new_width = max(dst_min_width, new_width)
            if dst_max_width is not None:
                valid_ratio = min(1.0, 1.0 * new_width / dst_max_width)
                resize_width = min(dst_max_width, new_width)
                img_resize = mmcv.imresize(
                    results['img'], (resize_width, dst_height),
                    backend=self.backend)
                resize_shape = img_resize.shape
                pad_shape = img_resize.shape
                if new_width < dst_max_width:
                    # Pad on the right/bottom up to the fixed max width.
                    img_resize = mmcv.impad(
                        img_resize,
                        shape=(dst_height, dst_max_width),
                        pad_val=self.img_pad_value)
                    pad_shape = img_resize.shape
            else:
                img_resize = mmcv.imresize(
                    results['img'], (new_width, dst_height),
                    backend=self.backend)
                resize_shape = img_resize.shape
                pad_shape = img_resize.shape
        else:
            img_resize = mmcv.imresize(
                results['img'], (dst_max_width, dst_height),
                backend=self.backend)
            resize_shape = img_resize.shape
            pad_shape = img_resize.shape

        results['img'] = img_resize
        results['img_shape'] = resize_shape
        results['resize_shape'] = resize_shape
        results['pad_shape'] = pad_shape
        results['valid_ratio'] = valid_ratio

        return results


@PIPELINES.register_module()
class ToTensorOCR:
    """Turn ``results['img']`` (``PIL Image`` or ``numpy.ndarray``) into a
    tensor with ``torchvision.transforms.functional.to_tensor``."""

    def __init__(self):
        """This transform takes no configuration."""

    def __call__(self, results):
        source = results['img']
        # A copy is converted rather than the original object; presumably
        # this avoids handing views (e.g. negative-stride numpy slices) to
        # the tensor conversion.
        results['img'] = TF.to_tensor(source.copy())
        return results


@PIPELINES.register_module()
class NormalizeOCR:
    """Normalize a tensor image channel-wise by ``mean`` and ``std``.

    The statistics used are also recorded in ``results['img_norm_cfg']``.

    Args:
        mean (sequence): Per-channel means.
        std (sequence): Per-channel standard deviations.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, results):
        mean, std = self.mean, self.std
        results['img'] = TF.normalize(results['img'], mean, std)
        results['img_norm_cfg'] = {'mean': mean, 'std': std}
        return results


@PIPELINES.register_module()
class OnlineCropOCR:
    """Crop text areas from whole image with bounding box jitter. If no bbox is
    given, return directly.

    Args:
        box_keys (list[str] | None): Keys in ``results['img_info']`` which
            correspond to RoI bbox coordinates. Defaults to
            ``['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4']``.
        jitter_prob (float): The probability of box jitter.
        max_jitter_ratio_x (float): Maximum horizontal jitter ratio
            relative to height.
        max_jitter_ratio_y (float): Maximum vertical jitter ratio
            relative to height.
    """

    def __init__(self,
                 box_keys=None,
                 jitter_prob=0.5,
                 max_jitter_ratio_x=0.05,
                 max_jitter_ratio_y=0.02):
        if box_keys is None:
            # Built per instance instead of a shared mutable default list.
            box_keys = ['x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4']
        assert utils.is_type_list(box_keys, str)
        assert 0 <= jitter_prob <= 1
        assert 0 <= max_jitter_ratio_x <= 1
        assert 0 <= max_jitter_ratio_y <= 1

        self.box_keys = box_keys
        self.jitter_prob = jitter_prob
        self.max_jitter_ratio_x = max_jitter_ratio_x
        self.max_jitter_ratio_y = max_jitter_ratio_y

    def _sample_jitter_flag(self):
        """Return True with probability ``self.jitter_prob``."""
        # ``random()`` is uniform on [0, 1), so ``< p`` holds with
        # probability p (always for p=1, never for p=0).
        return bool(np.random.random() < self.jitter_prob)

    def __call__(self, results):
        """Crop ``results['img']`` to the box stored in ``img_info``.

        ``results`` is returned unchanged when ``img_info`` is absent or
        lacks any of ``self.box_keys``.
        """
        if 'img_info' not in results:
            return results

        img_info = results['img_info']
        if any(key not in img_info for key in self.box_keys):
            return results

        box = [float(img_info[key]) for key in self.box_keys]

        kwargs = dict(
            jitter_flag=self._sample_jitter_flag(),
            jitter_ratio_x=self.max_jitter_ratio_x,
            jitter_ratio_y=self.max_jitter_ratio_y)
        crop_img = warp_img(results['img'], box, **kwargs)

        results['img'] = crop_img
        results['img_shape'] = crop_img.shape

        return results


@PIPELINES.register_module()
class FancyPCA:
    """PCA-based colour augmentation ("fancy PCA") from the paper
    ``Imagenet Classification With Deep Convolutional Neural Networks``.

    A random per-channel offset, built from principal components of RGB
    values (ImageNet statistics by default), is added to the whole image.

    Args:
        eig_vec (Tensor | None): 3x3 matrix whose rows are the principal
            components; row ``j`` is scaled by ``eig_val[0, j]`` and a random
            weight. Defaults to the ImageNet components.
        eig_val (Tensor | None): 1x3 matrix of eigenvalues. Defaults to the
            ImageNet eigenvalues.
    """

    def __init__(self, eig_vec=None, eig_val=None):
        if eig_vec is not None:
            self.eig_vec = eig_vec
        else:
            # Columns of this literal are the components, hence the transpose.
            self.eig_vec = torch.Tensor([
                [-0.5675, +0.7192, +0.4009],
                [-0.5808, -0.0045, -0.8140],
                [-0.5836, -0.6948, +0.4203],
            ]).t()
        if eig_val is not None:
            self.eig_val = eig_val
        else:
            self.eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])

    def pca(self, tensor):
        """Add a random PCA colour offset to a channel-first 3-channel
        tensor."""
        assert tensor.size(0) == 3
        weights = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
        offset = (self.eig_val * weights).mm(self.eig_vec)
        return tensor + offset.view(3, 1, 1)

    def __call__(self, results):
        results['img'] = self.pca(results['img'])
        return results

    def __repr__(self):
        return type(self).__name__


@PIPELINES.register_module()
class RandomPaddingOCR:
    """Randomly pad an image on every side (``edge`` padding mode) and shift
    character box coordinates by the left/top padding.

    Args:
        max_ratio (list[float] | None): Upper bounds of the padding ratios
            for [left, top, right, bottom]. Left/right paddings are relative
            to the image width, top/bottom paddings to the image height.
            Defaults to ``[0.1, 0.2, 0.1, 0.2]``.
        box_type (None|str): Character box type. If not none,
            should be either 'char_rects' or 'char_quads', with
            'char_rects' for rectangle with ``xyxy`` style and
            'char_quads' for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
    """

    def __init__(self, max_ratio=None, box_type=None):
        if max_ratio is not None:
            assert utils.is_type_list(max_ratio, float)
            assert len(max_ratio) == 4
        else:
            max_ratio = [0.1, 0.2, 0.1, 0.2]
        assert box_type is None or box_type in ('char_rects', 'char_quads')

        self.max_ratio = max_ratio
        self.box_type = box_type

    def __call__(self, results):
        height, width = results['img_shape'][:2]

        # One uniform sample per side, drawn in left, top, right, bottom order.
        extents = (width, height, width, height)
        left, top, right, bottom = [
            round(np.random.uniform(0, ratio) * extent)
            for ratio, extent in zip(self.max_ratio, extents)
        ]

        padded = mmcv.impad(
            results['img'],
            padding=(left, top, right, bottom),
            padding_mode='edge')
        results['img'] = padded
        results['img_shape'] = padded.shape

        if self.box_type is None:
            return results

        # Boxes are flat [x, y, x, y, ...] lists; shift every point in place.
        num_points = 2 if self.box_type == 'char_rects' else 4
        for char_box in results['ann_info'][self.box_type]:
            for k in range(num_points):
                char_box[2 * k] += left
                char_box[2 * k + 1] += top

        return results

    def __repr__(self):
        return type(self).__name__


@PIPELINES.register_module()
class RandomRotateImageBox:
    """Rotate augmentation for segmentation based text recognition.

    The image is rotated about its centre by a random angle and the character
    boxes are rotated with it. Boxes that leave the image (or keep less than
    70% of their area inside it) are dropped together with their characters.

    Args:
        min_angle (int): Minimum rotation angle for image and box.
        max_angle (int): Maximum rotation angle for image and box.
        box_type (str): Character box type, should be either
            'char_rects' or 'char_quads', with 'char_rects'
            for rectangle with ``xyxy`` style and 'char_quads'
            for quadrangle with ``x1y1x2y2x3y3x4y4`` style.
    """

    def __init__(self, min_angle=-10, max_angle=10, box_type='char_quads'):
        assert box_type in ('char_rects', 'char_quads')

        self.min_angle = min_angle
        self.max_angle = max_angle
        self.box_type = box_type

    def __call__(self, results):
        in_img = results['img']
        in_chars = results['ann_info']['chars']
        in_boxes = results['ann_info'][self.box_type]

        # ``.size`` is read as (width, height), i.e. a PIL image is expected.
        img_width, img_height = in_img.size
        rotate_center = [img_width / 2., img_height / 2.]

        # Clamp the sampled angle to atan(height / width), the angle of the
        # image diagonal, in degrees.
        tan_temp_max_angle = rotate_center[1] / rotate_center[0]
        temp_max_angle = np.arctan(tan_temp_max_angle) * 180. / np.pi

        random_angle = np.random.uniform(
            max(self.min_angle, -temp_max_angle),
            min(self.max_angle, temp_max_angle))
        random_angle_radian = random_angle * np.pi / 180.

        img_box = shapely_box(0, 0, img_width, img_height)

        # ``resample=False`` is no longer passed: that argument is deprecated
        # in newer torchvision (and dropped later), and nearest-neighbour,
        # which it selected, is the default interpolation in old and new APIs.
        out_img = TF.rotate(
            in_img, random_angle, expand=False, center=rotate_center)

        out_boxes, out_chars = self.rotate_bbox(
            in_boxes,
            in_chars,
            random_angle_radian,
            rotate_center,
            img_box,
            box_type=self.box_type)

        results['img'] = out_img
        results['ann_info']['chars'] = out_chars
        results['ann_info'][self.box_type] = out_boxes

        return results

    @staticmethod
    def rotate_bbox(boxes, chars, angle, center, img_box,
                    box_type='char_quads'):
        """Rotate boxes and keep those mostly inside ``img_box``.

        Args:
            boxes (list): Flat coordinate lists, one per character.
            chars (list): Characters aligned with ``boxes`` by index.
            angle (float): Rotation angle in radians.
            center (list[float]): Rotation centre ``[x, y]``.
            img_box: Shapely geometry covering the image.
            box_type (str): 'char_quads' boxes (``x1y1...x4y4``) are rotated
                point by point and returned as quads. 'char_rects' boxes
                (``xyxy``) have their four corners rotated and are returned as
                the axis-aligned ``xyxy`` bounding box of the result, because
                two points alone cannot form a polygon.

        Returns:
            tuple[list, list]: Kept boxes and their characters.
        """
        out_boxes = []
        out_chars = []
        for idx, bbox in enumerate(boxes):
            if box_type == 'char_rects':
                x_min, y_min, x_max, y_max = bbox[:4]
                points = [[x_min, y_min], [x_max, y_min], [x_max, y_max],
                          [x_min, y_max]]
            else:
                points = [[bbox[2 * i], bbox[2 * i + 1]]
                          for i in range(len(bbox) // 2)]
            rotated = [
                RandomRotateImageBox.rotate_point(point, angle, center)
                for point in points
            ]

            poly = Polygon(rotated).buffer(0)
            if not poly.is_valid:
                continue
            if not img_box.intersects(poly) or img_box.touches(poly):
                continue
            # Interiors overlap here, so ``poly.area`` is positive.
            intersect_ratio = img_box.intersection(poly).area / poly.area
            if intersect_ratio < 0.7:
                continue

            if box_type == 'char_rects':
                xs = [p[0] for p in rotated]
                ys = [p[1] for p in rotated]
                out_boxes.append([min(xs), min(ys), max(xs), max(ys)])
            else:
                out_boxes.append([coord for p in rotated for coord in p])
            out_chars.append(chars[idx])

        return out_boxes, out_chars

    @staticmethod
    def rotate_point(point, angle, center):
        """Rotate ``point`` about ``center`` by ``-angle`` radians in the
        usual x-right / y-down formula and return ``[new_x, new_y]``."""
        cos_theta = math.cos(-angle)
        sin_theta = math.sin(-angle)
        c_x = center[0]
        c_y = center[1]
        new_x = (point[0] - c_x) * cos_theta - (point[1] -
                                                c_y) * sin_theta + c_x
        new_y = (point[0] - c_x) * sin_theta + (point[1] -
                                                c_y) * cos_theta + c_y

        return [new_x, new_y]


@PIPELINES.register_module()
class OpencvToPil:
    """Convert ``numpy.ndarray`` (bgr) to ``PIL Image`` (rgb).

    Only 3-D (H, W, C) arrays have their channel order reversed. 2-D
    (grayscale) arrays are converted unchanged.
    """

    def __init__(self, **kwargs):
        # Extra config keys are accepted and ignored.
        pass

    def __call__(self, results):
        img = results['img']
        if img.ndim == 3:
            # BGR -> RGB. On a 2-D array ``[..., ::-1]`` would reverse the
            # width axis and mirror the image, so it is skipped there.
            img = img[..., ::-1]
        results['img'] = Image.fromarray(img)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str


@PIPELINES.register_module()
class PilToOpencv:
    """Convert ``PIL Image`` (rgb) to ``numpy.ndarray`` (bgr).

    Only 3-D (H, W, C) arrays have their channel order reversed. 2-D
    (grayscale) arrays are returned unchanged.
    """

    def __init__(self, **kwargs):
        # Extra config keys are accepted and ignored.
        pass

    def __call__(self, results):
        img = np.asarray(results['img'])
        if img.ndim == 3:
            # RGB -> BGR. On a 2-D array ``[..., ::-1]`` would reverse the
            # width axis and mirror the image, so it is skipped there.
            img = img[..., ::-1]
        # Materialize a C-contiguous array rather than a negative-stride view,
        # which some consumers (e.g. ``torch.from_numpy``) reject.
        results['img'] = np.ascontiguousarray(img)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        return repr_str