import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
metadata = MetadataCatalog.get('coco_2017_train_panoptic')


class SemanticSAMPredictor:
    def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100):
        """
        thresh: iou thresh to filter low confidence objects
        text_size: resize the input image short edge for the model to process
        hole_scale: fill in small holes as in SAM
        island_scale: remove small regions as in SAM
        """
        self.model = model
        self.thresh = thresh
        self.text_size = hole_scale
        self.hole_scale = hole_scale
        self.island_scale = island_scale
        self.point = None

    def predict(self, image_ori, image, point=None):
        """
        produce up to 6 prediction results for each click
        """
        width = image_ori.shape[0]
        height = image_ori.shape[1]

        data = {"image": image, "height": height, "width": width}
        # import ipdb; ipdb.set_trace()
        if point is None:
            point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda()
        else:
            point = torch.tensor(point).cuda()
            point_ = point
            point = point_.clone()
            point[0, 0] = point_[0, 0]
            point[0, 1] = point_[0, 1]
            # point = point[:, [1, 0]]
            point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1)

        self.point = point[:, :2].clone()*(torch.tensor([width, height]).to(point))

        data['targets'] = [dict()]
        data['targets'][0]['points'] = point
        data['targets'][0]['pb'] = point.new_tensor([0.])

        batch_inputs = [data]
        masks, ious = self.model.model.evaluate_demo(batch_inputs)

        return masks, ious

    def process_multi_mask(self, masks, ious, image_ori):
        pred_masks_poses = masks
        reses = []
        ious = ious[0, 0]
        ids = torch.argsort(ious, descending=True)

        text_res = ''
        mask_ls = []
        ious_res = []
        areas = []
        for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])):
            iou = round(float(iou), 2)
            texts = f'{iou}'
            mask = (pred_masks_pos > 0.0).cpu().numpy()
            area = mask.sum()
            conti = False
            if iou < self.thresh:
                conti = True
            for m in mask_ls:
                if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95:
                    conti = True
                    break
            if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []:
                conti = False
            if conti:
                continue
            ious_res.append(iou)
            mask_ls.append(mask)
            areas.append(area)
            mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes")
            mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands")
            mask = (mask).astype(np.float)
            out_txt = texts
            visual = Visualizer(image_ori, metadata=metadata)
            color = [0., 0., 1.0]
            demo = visual.draw_binary_mask(mask, color=color, text=texts)
            res = demo.get_image()
            point_x0 = max(0, int(self.point[0, 0]) - 3)
            point_x1 = min(image_ori.shape[1], int(self.point[0, 0]) + 3)
            point_y0 = max(0, int(self.point[0, 1]) - 3)
            point_y1 = min(image_ori.shape[0], int(self.point[0, 1]) + 3)
            res[point_y0:point_y1, point_x0:point_x1, 0] = 255
            res[point_y0:point_y1, point_x0:point_x1, 1] = 0
            res[point_y0:point_y1, point_x0:point_x1, 2] = 0
            reses.append(Image.fromarray(res))
            text_res = text_res + ';' + out_txt
        ids = list(torch.argsort(torch.tensor(areas), descending=False))
        ids = [int(i) for i in ids]

        torch.cuda.empty_cache()

        return reses, [reses[i] for i in ids]

    def predict_masks(self, image_ori, image, point=None):
        masks, ious = self.predict(image_ori, image, point)
        return self.process_multi_mask(masks, ious, image_ori)

    @staticmethod
    def remove_small_regions(
            mask: np.ndarray, area_thresh: float, mode: str
    ) -> Tuple[np.ndarray, bool]:
        """
        Removes small disconnected regions and holes in a mask. Returns the
        mask and an indicator of if the mask has been modified.
        """
        import cv2  # type: ignore

        assert mode in ["holes", "islands"]
        correct_holes = mode == "holes"
        working_mask = (correct_holes ^ mask).astype(np.uint8)
        n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
        sizes = stats[:, -1][1:]  # Row 0 is background label
        small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
        if len(small_regions) == 0:
            return mask, False
        fill_labels = [0] + small_regions
        if not correct_holes:
            fill_labels = [i for i in range(n_labels) if i not in fill_labels]
            # If every region is below threshold, keep largest
            if len(fill_labels) == 0:
                fill_labels = [int(np.argmax(sizes)) + 1]
        mask = np.isin(regions, fill_labels)
        return mask, True