import cv2
import numpy as np
import pandas as pd
import pkg_resources as pkg
import torch
import math
from typing import Tuple
from data_utils.image_utils import _get_width_and_height


def points_to_xyxy(coords: np.ndarray) -> list:
    """
    Return the axis-aligned bounding box [x1, y1, x2, y2] of a set of points.

    Parameters:
    - coords: sequence of points; point[0] is x and point[1] is y.

    Returns:
    - [min_x, min_y, max_x, max_y]. Raises ValueError (from min/max) when
      coords is empty.
    """
    x_values = [point[0] for point in coords]
    y_values = [point[1] for point in coords]
    left, right = min(x_values), max(x_values)
    top, bottom = min(y_values), max(y_values)
    return [left, top, right, bottom]


def xyxy2xywh(x):
    """
    Convert boxes from corner form (x1, y1, x2, y2) to center form (cx, cy, w, h).

    The last axis holds the 4 coordinates. A torch.Tensor input is cloned and
    a numpy input is copied with np.copy, so `x` itself is never modified.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    left, top, right, bottom = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    out[..., 0] = (left + right) / 2  # center x
    out[..., 1] = (top + bottom) / 2  # center y
    out[..., 2] = right - left  # width
    out[..., 3] = bottom - top  # height
    return out


def xywh2xyxy(x):
    """
    Convert boxes from center form (cx, cy, w, h) to corner form (x1, y1, x2, y2).

    The last axis holds the 4 coordinates. A torch.Tensor input is cloned and
    a numpy input is copied with np.copy, so `x` itself is never modified.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    cx, cy, w, h = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    out[..., 0] = cx - w / 2  # top-left x
    out[..., 1] = cy - h / 2  # top-left y
    out[..., 2] = cx + w / 2  # bottom-right x
    out[..., 3] = cy + h / 2  # bottom-right y
    return out


def is_abox_in_bbox(abox_coords, bbox_coords):
    """
    Return True if abox lies entirely inside bbox (shared edges allowed).

    Both boxes are (x1, y1, x2, y2).
    """
    inside_left_top = bbox_coords[0] <= abox_coords[0] and bbox_coords[1] <= abox_coords[1]
    return bool(
        inside_left_top
        and abox_coords[2] <= bbox_coords[2]
        and abox_coords[3] <= bbox_coords[3]
    )


def calculate_aspect_ratio(box):
    """
    Return width / height of an (x1, y1, x2, y2) box.

    1e-8 is added to the height so a zero-height box does not divide by zero.
    """
    return (box[2] - box[0]) / ((box[3] - box[1]) + 1e-8)


def get_box_shape(box, threshold=0.1):
    """
    Classify an (x1, y1, x2, y2) box as square, horizontal or vertical.

    - box: (x1, y1, x2, y2); its width / height ratio comes from
      calculate_aspect_ratio.
    - threshold (float): a box counts as "square" when |1 - ratio| < threshold.
      Default is 0.1.
    Returns:
    - str: "square", "horizontal" (ratio > 1) or "vertical" (ratio < 1).
      Returns None when none of these hold, e.g. ratio exactly 1 with
      threshold <= 0, or a NaN ratio.
    """
    ratio = calculate_aspect_ratio(box)
    if abs(1 - ratio) < threshold:
        return "square"
    if ratio > 1:
        return "horizontal"
    if ratio < 1:
        return "vertical"
    return None


def calculate_aspect_ratio_loss(predicted_box, gt_box):
    """
    Return how much the aspect ratios of predicted_box and gt_box differ.

    The absolute difference d of the two width/height ratios is mapped into
    [0, 1) with 2 * atan(d) / pi. Larger values mean a larger mismatch.
    """
    difference = abs(
        calculate_aspect_ratio(gt_box) - calculate_aspect_ratio(predicted_box)
    )
    return 2 * math.atan(difference) / math.pi


def clip_boxes(boxes, shape):
    """
    Clamp xyxy boxes in place to an image of size shape = (height, width).

    torch tensors are clamped column by column with clamp_. numpy arrays are
    clipped with grouped x / y column assignments. Returns None.
    """
    height, width = shape[0], shape[1]
    if not isinstance(boxes, torch.Tensor):
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, width)  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, height)  # y1, y2
        return
    for column, limit in ((0, width), (1, height), (2, width), (3, height)):
        boxes[..., column].clamp_(0, limit)


def is_box_overlap(box1, box2):
    """
    Return True if two (x1, y1, x2, y2) boxes overlap or touch.

    The boxes are disjoint only when they are strictly separated along x or
    along y. Shared edges count as overlap.
    """
    separated_x = box1[0] > box2[2] or box1[2] < box2[0]
    separated_y = box1[1] > box2[3] or box1[3] < box2[1]
    return not (separated_x or separated_y)


def intersection_area(box1, box2):
    """
    Return the area shared by two bounding boxes.

    Parameters:
    - box1, box2: 4-element sequences in (x1, y1, x2, y2) format.

    Returns:
    - area: overlap width * overlap height, each clamped at 0, so disjoint
      boxes give 0.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    return overlap_w * overlap_h


def bbox_iou(box1, box2, GIoU=False, DIoU=False, CIoU=False, CIoU2=False, eps=1e-7):
    """
    Calculate IoU, or one of its GIoU / DIoU / CIoU / CIoU2 variants.

    Parameters:
    - box1, box2: 4-element sequences in (x1, y1, x2, y2) format.
    - GIoU, DIoU, CIoU, CIoU2: select a variant. When several are set the
      priority is GIoU > DIoU > CIoU2 > CIoU. With none set, plain IoU is
      returned.
    - eps: small constant added to the heights, the union, the enclosing-box
      area and the squared enclosing diagonal to avoid division by zero.

    Returns:
    - The selected metric rounded to 4 decimal places. CIoU2 is
      IoU - (GIoU penalty + DIoU penalty + CIoU aspect-ratio term).
    """
    b1_x1, b1_y1, b1_x2, b1_y2 = box1
    b2_x1, b2_y1, b2_x2, b2_y2 = box2
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area (same computation as intersection_area).
    inter_w = max(0, min(b1_x2, b2_x2) - max(b1_x1, b2_x1))
    inter_h = max(0, min(b1_y2, b2_y2) - max(b1_y1, b2_y1))
    inter = inter_w * inter_h

    union = w1 * h1 + w2 * h2 - inter + eps
    iou = inter / union

    if not (GIoU or DIoU or CIoU or CIoU2):
        return round(iou, 4)  # IoU

    # Smallest enclosing (convex) box.
    cw = max(b1_x2, b2_x2) - min(b1_x1, b2_x1)
    ch = max(b1_y2, b2_y2) - min(b1_y1, b2_y1)
    c_area = cw * ch + eps
    giou_penalty = (c_area - union) / c_area
    if GIoU:  # GIoU https://arxiv.org/pdf/1902.09630.pdf
        return round(iou - giou_penalty, 4)

    # Distance / Complete IoU https://arxiv.org/abs/1911.08287v1
    # DIoU, CIoU and CIoU2 all need the center-distance penalty, so CIoU2 must
    # reach this point too.
    rho2 = (
        (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
        + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
    ) / 4  # squared distance between the box centers
    c2 = cw**2 + ch**2 + eps  # squared diagonal of the enclosing box
    diou_penalty = rho2 / c2
    if DIoU:
        return round(iou - diou_penalty, 4)

    # Only CIoU or CIoU2 is set from here on.
    v = (4 / math.pi**2) * ((np.arctan(w2 / h2) - np.arctan(w1 / h1)) ** 2)
    alpha = v / (v - iou + (1 + eps))
    if CIoU2:
        ciou2_penalty = giou_penalty + diou_penalty + alpha * v
        return round(iou - ciou2_penalty, 4)  # CIoU2
    ciou_penalty = diou_penalty + alpha * v
    return round(iou - ciou_penalty, 4)  # CIoU


def rotate_around_point(x, y, pivot_x, pivot_y, degrees) -> Tuple[int, int]:
    """
    Rotate the point (x, y) by `degrees` around (pivot_x, pivot_y).

    With (dx, dy) the offset from the pivot, this applies
    x' = cos*dx - sin*dy and y' = sin*dx + cos*dy. That is counterclockwise
    for y-up axes, which appears clockwise in image coordinates where y grows
    downward. Each result is truncated toward zero with int().

    Returns:
    - (new_x, new_y) as ints.
    """
    theta = np.radians(degrees)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)

    dx = x - pivot_x
    dy = y - pivot_y
    new_x = pivot_x + cos_t * dx - sin_t * dy
    new_y = pivot_y + sin_t * dx + cos_t * dy

    return int(new_x), int(new_y)


def rotate_box_coordinates_on_pivot(x1, y1, x2, y2, degrees, pivot_x, pivot_y):
    """
    Rotate box (x1, y1, x2, y2) by `degrees` around (pivot_x, pivot_y) and
    return the axis-aligned box that encloses the four rotated corners.

    The rotation is the same as in rotate_around_point:
    x' = cos*dx - sin*dy, y' = sin*dx + cos*dy. That is counterclockwise for
    y-up axes, which appears clockwise in image coordinates where y grows
    downward.

    Returns:
    - (new_x1, new_y1, new_x2, new_y2), each truncated toward zero with int().
    """
    radians = np.radians(degrees)
    rotation_matrix = np.array(
        [[np.cos(radians), -np.sin(radians)], [np.sin(radians), np.cos(radians)]]
    )

    # The four corners as (dx, dy) offsets from the pivot.
    corners = np.array(
        [
            [x1 - pivot_x, y1 - pivot_y],
            [x2 - pivot_x, y1 - pivot_y],
            [x2 - pivot_x, y2 - pivot_y],
            [x1 - pivot_x, y2 - pivot_y],
        ]
    )

    # Row-vector form: each row [dx, dy] becomes [dx', dy'].
    rotated = np.dot(corners, rotation_matrix.T)

    # Shift back by the pivot. The columns are (x, y), so the offset is
    # (pivot_x, pivot_y) in that order.
    rotated = rotated + np.array([pivot_x, pivot_y])

    new_x1, new_y1 = rotated.min(axis=0)
    new_x2, new_y2 = rotated.max(axis=0)

    return int(new_x1), int(new_y1), int(new_x2), int(new_y2)


def bbox_iou_torch(
    box1, box2, xywh=False, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
):
    """
    Torch IoU (or GIoU / DIoU / CIoU) of box1 (1, 4) against box2 (n, 4).

    Boxes are xyxy, or (cx, cy, w, h) when xywh=True. Coordinates are split
    with chunk(4, -1), which keeps a trailing singleton dimension, so the
    result broadcasts to shape (..., 1). Heights are clamped to eps only in
    the xyxy branch. Priority when several flags are set: CIoU > DIoU > GIoU.
    The CIoU trade-off weight alpha is computed under torch.no_grad().
    """
    if xywh:
        (cx1, cy1, w1, h1), (cx2, cy2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        half_w1, half_h1, half_w2, half_h2 = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2 = cx1 - half_w1, cx1 + half_w1
        b1_y1, b1_y2 = cy1 - half_h1, cy1 + half_h1
        b2_x1, b2_x2 = cx2 - half_w2, cx2 + half_w2
        b2_y1, b2_y2 = cy2 - half_h2, cy2 + half_h2
    else:
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    overlap_w = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0)
    overlap_h = (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
    inter = overlap_w * overlap_h

    union = w1 * h1 + w2 * h2 - inter + eps
    iou = inter / union

    if not (CIoU or DIoU or GIoU):
        return iou  # IoU

    # Smallest enclosing (convex) box.
    cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)
    ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)

    if not (CIoU or DIoU):
        # GIoU https://arxiv.org/pdf/1902.09630.pdf
        c_area = cw * ch + eps
        return iou - (c_area - union) / c_area

    # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
    c2 = cw**2 + ch**2 + eps  # squared enclosing diagonal
    rho2 = (
        (b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2
        + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2
    ) / 4  # squared center distance
    if not CIoU:
        return iou - rho2 / c2  # DIoU

    # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
    v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
    with torch.no_grad():
        alpha = v / (v - iou + (1 + eps))
    return iou - (rho2 / c2 + v * alpha)  # CIoU


def generate_random_box(width_range, height_range, canvas_size=100):
    """
    Generate random bounding box coordinates (x1, y1, x2, y2) with random width and height.

    The box is placed inside a square canvas of side `canvas_size`.

    Parameters:
    - width_range: (min_width, max_width). The width is drawn with
      np.random.randint, so max_width is exclusive.
    - height_range: (min_height, max_height). max_height is exclusive.
    - canvas_size: side length of the square canvas (default 100).

    Returns:
    - box: tuple (x1, y1, x2, y2) with 0 <= x1, 0 <= y1 and
      x2, y2 < canvas_size.

    Raises:
    - ValueError: if max_width or max_height exceeds canvas_size. A box that
      large could not be placed, and np.random.randint would fail on some
      draws but not others.
    """
    min_width, max_width = width_range
    min_height, max_height = height_range
    if max_width > canvas_size or max_height > canvas_size:
        raise ValueError(
            f"width_range {width_range} / height_range {height_range} do not fit "
            f"in a canvas of size {canvas_size}"
        )

    width = np.random.randint(min_width, max_width)
    height = np.random.randint(min_height, max_height)

    # x1 is drawn from [0, canvas_size - width), so x2 = x1 + width < canvas_size.
    x1 = np.random.randint(0, canvas_size - width)
    y1 = np.random.randint(0, canvas_size - height)
    x2 = x1 + width
    y2 = y1 + height

    return x1, y1, x2, y2


def mask_to_bboxes(mask, margin_rate=2, pixel_thresh=300) -> pd.DataFrame:
    """
    Extract padded bounding boxes of the connected components in `mask`.

    Components are labelled with cv2.connectedComponentsWithStats using
    4-connectivity. The first stats row (label 0, the background in OpenCV)
    is skipped, and components with fewer than `pixel_thresh` pixels are
    dropped. Each remaining box is grown on every side by

        margin = int(sqrt(pixel_count * min(w, h) / (w * h)) * margin_rate)

    pixels and then clipped to the image.

    Args:
    - mask: image accepted by cv2.connectedComponentsWithStats (presumably a
      single-channel uint8 binary mask; confirm against callers).
    - margin_rate: multiplier applied to the margin.
    - pixel_thresh: minimum component size (pixel count) to keep.

    Returns:
    - DataFrame with integer columns bbox_x1, bbox_y1, bbox_x2, bbox_y2.
      Before padding, x2 = x1 + width and y2 = y1 + height. Rows are sorted by
      (x1, y1) when the mask is at least as wide as it is tall, and by (y1, x1)
      otherwise. The frame is empty (with those columns) when no component is
      kept.
    """
    _, _, stats, _ = cv2.connectedComponentsWithStats(image=mask, connectivity=4)
    components = pd.DataFrame(
        stats[1:, :], columns=["bbox_x1", "bbox_y1", "width", "height", "pixel_count"]
    )
    components = components[components["pixel_count"].ge(pixel_thresh)]
    img_width, img_height = _get_width_and_height(mask)

    x1 = components["bbox_x1"]
    y1 = components["bbox_y1"]
    width = components["width"]
    height = components["height"]

    # Column-wise in float64, for two reasons:
    # - pixel_count * min(w, h) cannot overflow the int32 stats dtype.
    # - An empty selection simply yields empty columns.
    w_f = width.astype(np.float64)
    h_f = height.astype(np.float64)
    pixel_f = components["pixel_count"].astype(np.float64)
    # astype(int64) truncates toward zero, like int().
    margin = (
        np.sqrt(pixel_f * np.minimum(w_f, h_f) / (w_f * h_f)) * margin_rate
    ).astype(np.int64)

    bboxes = pd.DataFrame(
        {
            "bbox_x1": (x1 - margin).clip(lower=0),
            "bbox_y1": (y1 - margin).clip(lower=0),
            "bbox_x2": (x1 + width + margin).clip(upper=img_width),
            "bbox_y2": (y1 + height + margin).clip(upper=img_height),
        }
    )

    # Reading order: left-to-right for landscape masks, top-to-bottom otherwise.
    if img_width >= img_height:
        sort_keys = ["bbox_x1", "bbox_y1"]
    else:
        sort_keys = ["bbox_y1", "bbox_x1"]
    return bboxes.sort_values(by=sort_keys)


def bbox_to_mask(bboxes: list, mask_size):
    """
    Creates a mask image based on bounding box coordinates.

    Args:
    - bboxes: iterable of integer (x_min, y_min, x_max, y_max) boxes. x_max
      and y_max are exclusive because they are used as slice stops.
    - mask_size: shape passed to np.zeros. mask_size[0] is the height and
      mask_size[1] is the width.

    Returns:
    - uint8 mask that is 0 everywhere except 255 inside the clipped boxes.
      Boxes lying entirely outside the mask leave it unchanged.
    """
    # Start from an all-black mask of the requested shape.
    mask = np.zeros(mask_size, dtype=np.uint8)

    for x_min, y_min, x_max, y_max in bboxes:
        height, width = mask_size[0], mask_size[1]
        # Clamp every coordinate into [0, size]. Clamping only the lower edge
        # of x_min/y_min and the upper edge of x_max/y_max is not enough:
        # numpy reads a negative x_max/y_max as "count from the end" and would
        # fill a region on the opposite side of the mask.
        x_min = min(max(0, x_min), width)
        y_min = min(max(0, y_min), height)
        x_max = min(max(0, x_max), width)
        y_max = min(max(0, y_max), height)

        # Fill the (possibly empty) box region with white.
        mask[y_min:y_max, x_min:x_max] = 255

    return mask


def move_box_a_to_center_of_box_b(A, B):
    """
    Translate box A so that its center coincides with the center of box B.

    Both boxes are (left, top, right, bottom). A keeps its width and height.
    Returns the new (left, top, right, bottom) of A as a tuple.
    """
    left_a, top_a, right_a, bottom_a = A
    left_b, top_b, right_b, bottom_b = B

    half_w = (right_a - left_a) / 2
    half_h = (bottom_a - top_a) / 2
    center_x = (left_b + right_b) / 2
    center_y = (top_b + bottom_b) / 2

    return (center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h)


def scale_bboxes(bboxes, max_x, max_y, x_scale_factor=1.2, y_scale_factor=1.05):
    """
    Scale every box about its own center and clip the result to the image.

    Args:
    - bboxes: DataFrame with columns bbox_x1, bbox_y1, bbox_x2, bbox_y2,
      ori_content and predicted_lang. It is NOT modified.
    - max_x, max_y: upper limits for the scaled x2 / y2.
    - x_scale_factor, y_scale_factor: multipliers for width and height.

    Returns:
    - New DataFrame (same index) with columns ori_content, bbox_x1, bbox_y1,
      bbox_x2, bbox_y2, predicted_lang. The coordinates are cast with
      astype(int).
    """
    # Center, then the scaled width/height, computed as local Series so the
    # caller's frame gets no helper columns.
    cx = (bboxes["bbox_x1"] + bboxes["bbox_x2"]) / 2
    cy = (bboxes["bbox_y1"] + bboxes["bbox_y2"]) / 2
    new_width = (bboxes["bbox_x2"] - bboxes["bbox_x1"]) * x_scale_factor
    new_height = (bboxes["bbox_y2"] - bboxes["bbox_y1"]) * y_scale_factor

    new_bboxes = bboxes[["ori_content", "predicted_lang"]].copy()
    # x1/y1 are clipped only at 0, and x2/y2 only at max_x/max_y.
    new_bboxes["bbox_x1"] = (cx - new_width / 2).clip(lower=0).astype(int)
    new_bboxes["bbox_y1"] = (cy - new_height / 2).clip(lower=0).astype(int)
    new_bboxes["bbox_x2"] = (cx + new_width / 2).clip(upper=max_x).astype(int)
    new_bboxes["bbox_y2"] = (cy + new_height / 2).clip(upper=max_y).astype(int)

    return new_bboxes[
        ["ori_content", "bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2", "predicted_lang"]
    ]


if __name__ == "__main__":
    # Demo: compare the IoU variants on two random boxes.
    # generate_random_box places boxes on a 100-pixel canvas by default, so the
    # size ranges must stay below 100; otherwise it raises ValueError.
    w_range = (10, 50)
    h_range = (10, 50)

    box1 = generate_random_box(w_range, h_range)
    box2 = generate_random_box(w_range, h_range)

    print(f"box1 coors : {box1}")
    print(f"box2 coors : {box2}")

    print(f"intersection area : {intersection_area(box1, box2)}")
    iou = bbox_iou(box1, box2)
    giou = bbox_iou(box1, box2, GIoU=True)
    diou = bbox_iou(box1, box2, DIoU=True)
    ciou = bbox_iou(box1, box2, CIoU=True)
    print(iou, giou, diou, ciou)