from typing import Optional

import cv2
import numpy as np
from PIL import Image


def export_mask(
    masks,
    autogenerated: Optional[bool] = False,
    random_color: Optional[bool] = True,
    smoothen_contours: Optional[bool] = True,
) -> Image.Image:
    """Render segmentation masks into a single colored PIL image.

    Args:
        masks: When ``autogenerated`` is falsy, an array of shape
            ``(num_masks, 1, H, W)`` (the singleton axis 1 is squeezed away);
            a pixel belongs to mask ``i`` when ``masks[i]`` is ``> 0`` there,
            and later masks are painted over earlier ones. When
            ``autogenerated`` is truthy, a non-empty sequence of dicts, each
            with an ``"area"`` key and a 2-D ``"segmentation"`` mask.
        autogenerated: Selects which of the two ``masks`` formats above is used.
        random_color: Stacked-array format only: give each mask a random color
            (``True``) or one shared fixed blue (``False``). Autogenerated
            annotations always get random colors.
        smoothen_contours: Also draw each mask's outer contour via ``smoothen``.

    Returns:
        For the stacked-array format, an RGB image on a black background.
        For the autogenerated format, an RGBA image where masks are drawn at
        50% alpha over a fully transparent background.

    Raises:
        IndexError: If ``autogenerated`` is truthy and ``masks`` is empty.
    """
    if autogenerated:
        return _render_annotations(masks, smoothen_contours)
    return _render_mask_stack(masks, random_color, smoothen_contours)


def _render_mask_stack(masks, random_color, smoothen_contours) -> Image.Image:
    """Render an ``(N, 1, H, W)`` mask stack as an RGB image (later masks on top)."""
    num_masks, _, h, w = masks.shape
    masks = masks.squeeze(axis=1)

    # Label map: 0 = background, i + 1 = mask i. int32 instead of uint8 so the
    # labels cannot overflow when there are more than 255 masks.
    labels = np.zeros((h, w), dtype=np.int32)
    for i in range(num_masks):
        # Threshold directly: casting floats (e.g. logits) to uint8 first would
        # truncate fractions and wrap/undefine negative values.
        labels[masks[i] > 0] = i + 1

    if random_color:
        colors = np.random.rand(num_masks, 3)  # Random colors for each mask
    else:
        # np.tile keeps shape (num_masks, 3) even when num_masks == 0.
        colors = np.tile(np.array([30 / 255, 144 / 255, 255 / 255]), (num_masks, 1))

    # Palette row 0 is black (background); a single lookup colors every pixel.
    # Assigning float values into the uint8 palette truncates them.
    palette = np.zeros((num_masks + 1, 3), dtype=np.uint8)
    palette[1:] = colors * 255
    color_image = palette[labels]

    if smoothen_contours:
        contour_layer = np.zeros((h, w, 4), dtype=np.float32)
        for label in range(1, num_masks + 1):
            contour_layer = smoothen(
                (labels == label).astype(np.uint8), contour_layer
            )

        contour_rgb = (contour_layer[:, :, :3] * 255).astype(np.uint8)
        blended = np.asarray(
            Image.blend(
                Image.fromarray(color_image),
                Image.fromarray(contour_rgb),
                alpha=0.6,
            )
        )
        # Apply the blend only where a contour was drawn. Blending the whole
        # frame with the mostly-black contour layer would dim every mask pixel
        # to 40% of its color.
        on_contour = np.any(contour_layer != 0, axis=-1)
        color_image[on_contour] = blended[on_contour]

    return Image.fromarray(color_image)


def _render_annotations(anns, smoothen_contours) -> Image.Image:
    """Render a sequence of ``{"area", "segmentation"}`` dicts as an RGBA image."""
    # Paint largest first; smaller annotations are painted later and stay on top.
    sorted_anns = sorted(anns, key=lambda ann: ann["area"], reverse=True)
    height, width = sorted_anns[0]["segmentation"].shape[:2]

    # Float RGBA canvas in [0, 1]; alpha 0 leaves unmasked pixels transparent.
    canvas = np.ones((height, width, 4))
    canvas[:, :, 3] = 0

    for ann in sorted_anns:
        # Cast to bool so the mask selects pixels even if stored as 0/1 ints
        # (an integer array would be interpreted as row indices instead).
        seg = np.asarray(ann["segmentation"], dtype=bool)
        canvas[seg] = np.concatenate([np.random.random(3), [0.5]])

        if smoothen_contours:
            canvas = smoothen(seg, canvas)

    return Image.fromarray((canvas * 255).astype(np.uint8))


def smoothen(
    mask: np.ndarray,
    image: np.ndarray,
    *,
    color: tuple = (0, 0, 1, 0.4),
    thickness: int = 1,
    epsilon: float = 0.01,
) -> np.ndarray:
    """Draw the outer contours of ``mask`` onto ``image`` and return it.

    Contours are found with ``cv2.RETR_EXTERNAL`` (outer boundaries only),
    simplified with ``cv2.approxPolyDP`` and drawn with ``cv2.drawContours``,
    which draws into ``image`` itself.

    Args:
        mask: 2-D mask; any nonzero value counts as foreground after the cast
            to uint8.
        image: Array to draw on. Its channel count should match ``color``
            (4 for the default RGBA-style color).
        color: Per-channel color passed to ``cv2.drawContours``.
        thickness: Contour line thickness in pixels.
        epsilon: ``approxPolyDP`` tolerance in pixels. The default is tiny,
            so outlines are only very lightly simplified.

    Returns:
        The image returned by ``cv2.drawContours``, with the contours drawn.
    """
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns
    # (image, contours, hierarchy). Index -2 selects the contours on both.
    contours = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
    )[-2]
    contours = [
        cv2.approxPolyDP(contour, epsilon=epsilon, closed=True)
        for contour in contours
    ]
    return cv2.drawContours(image, contours, -1, color, thickness=thickness)