# -*- coding: utf-8 -*-
"""
@File    :   visualizer.py
@Time    :   2022/04/05 11:39:33
@Author  :   Shilong Liu 
@Contact :   slongliu86@gmail.com
"""

import datetime
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import transforms
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
from pycocotools import mask as maskUtils


def renorm(
    img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
) -> torch.FloatTensor:
    """Undo a per-channel ``(x - mean) / std`` normalisation.

    Computes ``img * std + mean`` channel-wise.

    Args:
        img: tensor of shape ``(3, H, W)`` or ``(B, 3, H, W)``.
        mean: per-channel mean that was subtracted during normalisation.
            The default lists are only read, never mutated.
        std: per-channel standard deviation that was divided out.

    Returns:
        A new tensor with the same shape as ``img``.

    Raises:
        AssertionError: if ``img`` is not 3-D/4-D or its channel axis is not
            of size 3.
    """
    assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
    # The channel axis is 0 for a single image and 1 for a batch.
    channel_dim = 0 if img.dim() == 3 else 1
    assert img.size(channel_dim) == 3, 'img.size(%d) shoule be 3 but "%d". (%s)' % (
        channel_dim,
        img.size(channel_dim),
        str(img.size()),
    )

    # Build the statistics on the same device as ``img`` so that images living
    # on an accelerator do not trigger a device-mismatch error.  float32 is
    # kept to match what ``torch.Tensor(list)`` produced previously.
    mean_t = torch.as_tensor(mean, dtype=torch.float32, device=img.device)
    std_t = torch.as_tensor(std, dtype=torch.float32, device=img.device)

    # Shape (C, 1, 1) broadcasts over (C, H, W) as well as (B, C, H, W), which
    # removes the need for the permute -> scale -> permute round trip.
    return img * std_t.reshape(-1, 1, 1) + mean_t.reshape(-1, 1, 1)


class ColorMap:
    """Turn a single-channel uint8 map into an overlay of one fixed colour.

    The colour given at construction fills the leading channels of every
    pixel; the input map is appended as the last channel.
    """

    def __init__(self, basergb=[255, 255, 0]):
        # np.array copies the argument, so the default list is never shared.
        self.basergb = np.array(basergb)

    def __call__(self, attnmap):
        """Build the overlay for ``attnmap``.

        Args:
            attnmap: ``np.uint8`` array of shape ``(h, w)``.

        Returns:
            ``np.uint8`` array of shape ``(h, w, 4)`` for a 3-element base
            colour: the base colour followed by ``attnmap`` as last channel.

        Raises:
            AssertionError: if ``attnmap`` is not of dtype ``np.uint8``.
        """
        assert attnmap.dtype == np.uint8
        rows, cols = attnmap.shape

        # Every pixel gets the same base colour: (h, w, 3).
        colour_plane = np.broadcast_to(
            self.basergb[None, None], (rows, cols) + self.basergb.shape
        )
        # The map itself becomes the trailing channel: (h, w, 1).
        alpha_plane = attnmap[..., None]

        stacked = np.concatenate((colour_plane, alpha_plane), axis=-1)
        return stacked.astype(np.uint8)


def rainbow_text(x, y, ls, lc, **kw):
    """
    Write the strings in ``ls`` one after another on a single horizontal line,
    drawing ``ls[i]`` in colour ``lc[i]``.

    The first string is anchored at data coordinates ``(x, y)`` of the current
    axes; each following string is shifted right by the rendered width of the
    one before it. Every string is padded with one space on each side.

    Extra keyword arguments are forwarded unchanged to ``plt.text`` (font
    size, family, ...).
    """
    cursor = plt.gca().transData
    figure = plt.gcf()
    plt.show()

    for word, colour in zip(ls, lc):
        label = plt.text(x, y, " " + word + " ", color=colour, transform=cursor, **kw)
        # Render the label now so that its extent in pixels can be measured.
        label.draw(figure.canvas.get_renderer())
        extent = label.get_window_extent()
        # The next label starts where this one ends ("dots" = pixels).
        cursor = transforms.offset_copy(label._transform, x=extent.width, units="dots")

    # A vertical layout would work the same way: pass rotation=90,
    # va="bottom", ha="center" to plt.text and offset by y=extent.height
    # instead of x=extent.width.


class COCOVisualizer:
    """Draw detection targets and COCO-style annotations with matplotlib.

    All drawing is done on pyplot's current figure / axes.
    """

    def __init__(self, coco=None, tokenlizer=None) -> None:
        """Remember the optional COCO API object.

        Args:
            coco: object exposing ``imgs`` and ``loadCats`` (presumably a
                ``pycocotools.coco.COCO`` instance -- confirm with callers).
                It is only needed by :meth:`showAnns` for uncompressed-RLE
                masks and for keypoint skeletons.
            tokenlizer: accepted for call-site compatibility; not used.
        """
        self.coco = coco

    def visualize(self, img, tgt, caption=None, dpi=180, savedir="vis"):
        """Render ``img`` with the targets in ``tgt`` and save it as a PNG.

        Args:
            img: normalised image tensor of shape ``(3, H, W)``; it is passed
                through ``renorm`` before display.
            tgt: target dict (see :meth:`addtgt`) or ``None``. ``image_id``
                is used for the file name when present, otherwise ``0``.
            caption: optional string inserted into the file name.
            dpi: resolution of the created figure.
            savedir: directory the image is written to (created if missing).

        The file is named ``<savedir>/[<caption>-]<image_id>-<timestamp>.png``
        and its path is printed.
        """
        fig = plt.figure(dpi=dpi)
        try:
            plt.rcParams["font.size"] = "5"
            ax = plt.gca()
            # detach()/cpu() are no-ops for the documented CPU input and keep
            # matplotlib's array conversion from failing on other tensors.
            img = renorm(img.detach().cpu()).permute(1, 2, 0)
            ax.imshow(img)

            self.addtgt(tgt)

            if tgt is None or "image_id" not in tgt:
                image_id = 0
            else:
                image_id = tgt["image_id"]

            timestamp = str(datetime.datetime.now()).replace(" ", "-")
            if caption is None:
                savename = "{}/{}-{}.png".format(savedir, int(image_id), timestamp)
            else:
                savename = "{}/{}-{}-{}.png".format(savedir, caption, int(image_id), timestamp)
            print("savename: {}".format(savename))
            os.makedirs(os.path.dirname(savename), exist_ok=True)
            plt.savefig(savename)
        finally:
            # Always release the figure, even when drawing or saving failed.
            plt.close(fig)

    def addtgt(self, tgt):
        """Draw the contents of ``tgt`` on the current axes.

        Args:
            tgt: ``None`` or a dict. Recognised keys:

                * ``boxes``: tensor ``(N, 4)`` of ``cx, cy, w, h`` boxes
                  normalised to ``[0, 1]`` (they are scaled by ``size``).
                * ``size``: ``(H, W)``; required when ``boxes`` is present.
                * ``strings_positive`` + ``labels``: per-box word lists and
                  category ids, rendered as ``"<id>:<words>"``.
                * ``box_label``: per-box label, rendered with ``str``.
                * ``caption``: used as the axes title.
                * ``attn``: a ``(attn_map, basergb)`` tuple or a list of
                  them; each map is min-max normalised and overlaid.

        Without ``boxes`` only the caption (if any) is drawn. ``tgt`` is not
        modified.
        """
        ax = plt.gca()

        # Nothing to draw: visualize() allows tgt to be None.
        if tgt is None:
            ax.set_axis_off()
            return

        # Without boxes only the caption title applies.
        if "boxes" not in tgt:
            if "caption" in tgt:
                ax.set_title(tgt["caption"], wrap=True)
            ax.set_axis_off()
            return

        H, W = tgt["size"]
        numbox = tgt["boxes"].shape[0]
        scale = torch.Tensor([W, H, W, H])

        color = []
        polygons = []
        boxes = []
        for box in tgt["boxes"].cpu():
            # Normalised cxcywh -> absolute xywh (top-left corner).
            unnormbbox = box * scale
            unnormbbox[:2] -= unnormbbox[2:] / 2
            bbox_x, bbox_y, bbox_w, bbox_h = unnormbbox.tolist()
            boxes.append([bbox_x, bbox_y, bbox_w, bbox_h])
            polygons.append(Polygon(self._box_corners(bbox_x, bbox_y, bbox_w, bbox_h)))
            color.append((np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0])

        if polygons:
            # Faint fill first, then a solid outline in the same colour.
            ax.add_collection(
                PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.1)
            )
            ax.add_collection(
                PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
            )

        if "strings_positive" in tgt and len(tgt["strings_positive"]) > 0:
            num_strings = len(tgt["strings_positive"])
            assert (
                num_strings == numbox
            ), f"len(strings_positive)={num_strings} != number of boxes={numbox}"
            captions = [
                str(int(tgt["labels"][idx])) + ":" + " ".join(strlist)
                for idx, strlist in enumerate(tgt["strings_positive"])
            ]
            self._draw_box_labels(ax, boxes, color, captions)

        if "box_label" in tgt:
            num_labels = len(tgt["box_label"])
            assert (
                num_labels == numbox
            ), f"len(box_label)={num_labels} != number of boxes={numbox}"
            self._draw_box_labels(ax, boxes, color, [str(bl) for bl in tgt["box_label"]])

        if "caption" in tgt:
            ax.set_title(tgt["caption"], wrap=True)

        if "attn" in tgt:
            # Accept a single (map, colour) pair without touching the caller's dict.
            attn_items = tgt["attn"]
            if isinstance(attn_items, tuple):
                attn_items = [attn_items]
            for attn_map, basergb in attn_items:
                # Min-max normalise to [0, 1); the epsilon avoids dividing by
                # zero on a constant map.
                attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min() + 1e-3)
                attn_map = (attn_map * 255).astype(np.uint8)
                ax.imshow(ColorMap(basergb)(attn_map))
        ax.set_axis_off()

    def showAnns(self, anns, draw_bbox=False):
        """
        Display the specified annotations.

        The annotation kind is decided from the first element: entries with
        ``segmentation`` or ``keypoints`` are drawn on the current axes,
        entries with ``caption`` are printed.

        :param anns (array of object): annotations to display
        :param draw_bbox (bool): also outline each annotation's ``bbox``
        :return: ``0`` when ``anns`` is empty (kept for backward
            compatibility), otherwise ``None``
        :raises Exception: if the first annotation has none of the keys above
        :raises ValueError: if a COCO object is needed (uncompressed-RLE mask
            or keypoints) but none was given to the constructor
        """
        # len() rather than truthiness: ``anns`` may be a numpy object array.
        if len(anns) == 0:
            return 0

        first = anns[0]
        if "segmentation" in first or "keypoints" in first:
            self._show_instances(anns, draw_bbox)
        elif "caption" in first:
            for ann in anns:
                print(ann["caption"])
        else:
            raise Exception("datasetType not supported")

    @staticmethod
    def _box_corners(bbox_x, bbox_y, bbox_w, bbox_h):
        """Return the four corners of an xywh box as a ``(4, 2)`` array."""
        return np.array(
            [
                [bbox_x, bbox_y],
                [bbox_x, bbox_y + bbox_h],
                [bbox_x + bbox_w, bbox_y + bbox_h],
                [bbox_x + bbox_w, bbox_y],
            ]
        ).reshape((4, 2))

    @staticmethod
    def _draw_box_labels(ax, boxes, colors, strings):
        """Write one string at the top-left corner of each box."""
        for (bbox_x, bbox_y, _bbox_w, _bbox_h), box_color, text in zip(boxes, colors, strings):
            ax.text(
                bbox_x,
                bbox_y,
                text,
                color="black",
                bbox={"facecolor": box_color, "alpha": 0.6, "pad": 1},
            )

    def _require_coco(self):
        """Return ``self.coco`` or fail with an explicit message."""
        if self.coco is None:
            raise ValueError(
                "COCOVisualizer needs a COCO object (coco=...) to draw "
                "uncompressed-RLE masks or keypoint skeletons"
            )
        return self.coco

    def _mask_overlay(self, ann):
        """Return an ``(h, w, 4)`` float RGBA image for a mask annotation."""
        seg = ann["segmentation"]
        if isinstance(seg["counts"], list):
            # Uncompressed RLE: the image size is needed to encode it.
            info = self._require_coco().imgs[ann["image_id"]]
            rle = maskUtils.frPyObjects([seg], info["height"], info["width"])
        else:
            rle = [seg]
        m = maskUtils.decode(rle)

        if ann["iscrowd"] == 1:
            color_mask = np.array([2.0, 166.0, 101.0]) / 255
        else:
            color_mask = np.random.random((1, 3)).tolist()[0]
        img = np.ones((m.shape[0], m.shape[1], 3)) * np.asarray(color_mask)
        # The mask, at half strength, becomes the alpha channel.
        return np.dstack((img, m * 0.5))

    def _show_instances(self, anns, draw_bbox):
        """Draw polygons, masks, keypoints and optional boxes for ``anns``."""
        ax = plt.gca()
        ax.set_autoscale_on(False)
        polygons = []
        color = []
        for ann in anns:
            c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
            if "segmentation" in ann:
                if isinstance(ann["segmentation"], list):
                    # Polygon format: flat [x0, y0, x1, y1, ...] lists.
                    for seg in ann["segmentation"]:
                        poly = np.array(seg).reshape((len(seg) // 2, 2))
                        polygons.append(Polygon(poly))
                        color.append(c)
                else:
                    # RLE mask format.
                    ax.imshow(self._mask_overlay(ann))

            if "keypoints" in ann and isinstance(ann["keypoints"], list):
                # Turn the skeleton into zero-based indices.
                category = self._require_coco().loadCats(ann["category_id"])[0]
                sks = np.array(category["skeleton"]) - 1
                # Keypoints are stored as flat (x, y, v) triplets.
                kp = np.array(ann["keypoints"])
                x = kp[0::3]
                y = kp[1::3]
                v = kp[2::3]
                for sk in sks:
                    if np.all(v[sk] > 0):
                        plt.plot(x[sk], y[sk], linewidth=3, color=c)
                # Points with v > 0 get a black edge ...
                plt.plot(
                    x[v > 0],
                    y[v > 0],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor="k",
                    markeredgewidth=2,
                )
                # ... and those with v > 1 are redrawn with a coloured edge.
                plt.plot(
                    x[v > 1],
                    y[v > 1],
                    "o",
                    markersize=8,
                    markerfacecolor=c,
                    markeredgecolor=c,
                    markeredgewidth=2,
                )

            if draw_bbox:
                bbox_x, bbox_y, bbox_w, bbox_h = ann["bbox"]
                polygons.append(Polygon(self._box_corners(bbox_x, bbox_y, bbox_w, bbox_h)))
                color.append(c)

        p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
        ax.add_collection(p)