import subprocess, os, sys

# Bootstrap step: install the local "GroundingDINO" directory in editable mode.
# check=True makes subprocess.run raise CalledProcessError on a non-zero exit code.
result = subprocess.run(["pip", "install", "-e", "GroundingDINO"], check=True)
print(f"pip install GroundingDINO = {result}")

# Put the GroundingDINO directory first on the module search path; this has to
# happen before the GroundingDINO imports further down in this file.
sys.path.insert(0, "./GroundingDINO")

# Download the SAM checkpoint only when no file with that name exists in the
# working directory.
# NOTE(review): only existence is checked, so a partially downloaded file from
# an interrupted run would be treated as present — confirm this is acceptable.
if not os.path.exists("./sam_vit_h_4b8939.pth"):
    result = subprocess.run(
        [
            "wget",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        ],
        check=True,
    )
    print(f"wget sam_vit_h_4b8939.pth result = {result}")


import argparse
import random
import warnings
import json

import gradio as gr
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from scipy import ndimage
from PIL import Image
from huggingface_hub import hf_hub_download
from segments.utils import bitmap2file

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import (
    clean_state_dict,
)
from GroundingDINO.groundingdino.util.inference import annotate, predict

# segment anything
from segment_anything import build_sam, SamPredictor

# CLIPSeg
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation


def load_model_hf(model_config_path, repo_id, filename, device):
    """Build a model from a config file and load weights fetched from the Hub.

    Args:
        model_config_path: Path of the config file read by ``SLConfig.fromfile``.
        repo_id: Hugging Face Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside that repository.
        device: Device used to map the checkpoint tensors and to place the model.

    Returns:
        The model, switched to eval mode and moved to ``device``.
    """
    config = SLConfig.fromfile(model_config_path)
    dino = build_model(config)
    # the device is recorded on the config only after the model has been built
    config.device = device

    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
    state = torch.load(checkpoint_path, map_location=device)
    # strict=False: key mismatches end up in the returned report instead of raising
    load_report = dino.load_state_dict(clean_state_dict(state["model"]), strict=False)
    print(f"Model loaded from {checkpoint_path} \n => {load_report}")
    dino.eval()
    return dino.to(device)


def load_image_for_dino(image):
    """Preprocess an image into the tensor that is fed to Grounding DINO.

    The pipeline applies ``T.RandomResize([800], max_size=1333)``, converts the
    result to a tensor and normalizes it with the per-channel mean/std below.

    Args:
        image: Image accepted by the Grounding DINO transforms
            (presumably a PIL image — confirm against the caller).

    Returns:
        The transformed image tensor.
    """
    preprocessing_steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    preprocess = T.Compose(preprocessing_steps)
    # the composed transform takes and returns an (image, target) pair;
    # there is no target here, so None is passed and the second item is dropped
    dino_tensor, _unused_target = preprocess(image, None)
    return dino_tensor


def dino_detection(
    model,
    image,
    image_array,
    category_names,
    category_name_to_id,
    box_threshold,
    text_threshold,
    device,
    visualize=False,
):
    """Detect boxes for the given categories with Grounding DINO.

    Args:
        model: Grounding DINO model passed to ``predict``.
        image: Image that is preprocessed with ``load_image_for_dino``.
        image_array: Array version of the image, only used for visualization.
        category_names: Names joined with " . " to form the text prompt.
        category_name_to_id: Mapping used to look up an id for every returned
            phrase; a phrase that is not a key raises on lookup.
        box_threshold: Box threshold forwarded to ``predict``.
        text_threshold: Text threshold forwarded to ``predict``.
        device: Device for the input tensor and for ``predict``.
        visualize: When true, the third return value is an annotated image
            instead of the phrases.

    Returns:
        ``(boxes, category_ids, phrases)``, or
        ``(boxes, category_ids, visualization)`` when ``visualize`` is true.
    """
    caption = " . ".join(category_names)
    model_input = load_image_for_dino(image).to(device)

    with torch.no_grad():
        boxes, logits, phrases = predict(
            model=model,
            image=model_input,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=device,
            remove_combined=True,
        )
    category_ids = [category_name_to_id[detected] for detected in phrases]

    if not visualize:
        return boxes, category_ids, phrases

    annotated_bgr = annotate(
        image_source=image_array, boxes=boxes, logits=logits, phrases=phrases
    )
    # reverse the channel axis: BGR to RGB
    annotated_rgb = annotated_bgr[..., ::-1]
    return boxes, category_ids, Image.fromarray(annotated_rgb)


def sam_masks_from_dino_boxes(predictor, image_array, boxes, device):
    """Prompt SAM with detection boxes and return the masks it predicts.

    Args:
        predictor: SAM predictor; ``generate_panoptic_mask`` calls
            ``set_image`` on it before calling this function.
        image_array: Image array with three dimensions; the first two are
            read as height and width.
        boxes: Boxes in (cx, cy, w, h) form, as converted by
            ``box_ops.box_cxcywh_to_xyxy`` (assumed normalized to [0, 1],
            since they are scaled by the image size below — confirm).
        device: Device the transformed boxes are moved to.

    Returns:
        The first of the three values returned by ``predictor.predict_torch``.
    """
    height, width, _channels = image_array.shape

    # scale the corner-format boxes to pixel coordinates
    pixel_scale = torch.Tensor([width, height, width, height])
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * pixel_scale

    sam_boxes = predictor.transform.apply_boxes_torch(boxes_xyxy, (height, width))
    masks, _scores, _low_res_logits = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=sam_boxes.to(device),
        multimask_output=False,
    )
    return masks


def preds_to_semantic_inds(preds, threshold):
    """Pick, per pixel, the index of the best-scoring mask or 0 for "unlabeled".

    A constant row filled with ``threshold`` is placed in front of the masks,
    so a pixel is only assigned to mask ``k`` (index ``k + 1``) when that
    mask's score is the largest and exceeds the threshold.

    Args:
        preds: Tensor whose first dimension indexes the masks and whose last
            two dimensions are the spatial grid.
        threshold: Score of the dummy "unlabeled" class; int or float.

    Returns:
        Integer tensor of shape ``(preds.shape[-2], preds.shape[-1])`` where
        0 means unlabeled and ``k + 1`` refers to ``preds[k]``. The tensor is
        allocated without an explicit device, i.e. on torch's default device.
    """
    num_masks = preds.shape[0]
    flat_preds = preds.reshape((num_masks, -1))

    # Initialize a dummy "unlabeled" row with the threshold.
    # The dtype is pinned because torch.full infers an integer dtype from an
    # integer fill value (e.g. a threshold of 0 or 1), which would truncate
    # every score copied into the table below.
    flat_preds_with_threshold = torch.full(
        (num_masks + 1, flat_preds.shape[-1]),
        threshold,
        dtype=torch.get_default_dtype(),
    )
    flat_preds_with_threshold[1 : num_masks + 1, :] = flat_preds

    # Get the top mask index for each pixel
    semantic_inds = torch.topk(flat_preds_with_threshold, 1, dim=0).indices.reshape(
        (preds.shape[-2], preds.shape[-1])
    )

    return semantic_inds


def clipseg_segmentation(
    processor, model, image, category_names, background_threshold, device
):
    """Run CLIPSeg once per category name and derive a per-pixel labelling.

    Args:
        processor: CLIPSeg processor that builds the model inputs.
        model: CLIPSeg segmentation model.
        image: Input image; ``image.size`` is read as (width, height).
        category_names: Text prompts, one per category.
        background_threshold: Threshold handed to ``preds_to_semantic_inds``.
        device: Device the model inputs are moved to.

    Returns:
        ``(preds, semantic_inds)``: the sigmoid scores upscaled to the image
        size, and the indices computed from them by ``preds_to_semantic_inds``.
    """
    num_prompts = len(category_names)
    # the same image is paired with every text prompt
    batch = processor(
        text=category_names,
        images=[image] * num_prompts,
        padding="max_length",
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        model_outputs = model(**batch)
    logits = model_outputs.logits
    # 2-D logits get a leading dimension so the code below always sees 3-D
    if logits.dim() == 2:
        logits = logits.unsqueeze(0)

    # resize the outputs to (height, width) of the input image
    target_size = (image.size[1], image.size[0])
    upscaled_logits = nn.functional.interpolate(
        logits.unsqueeze(1), size=target_size, mode="bilinear"
    )
    preds = torch.sigmoid(upscaled_logits.squeeze(dim=1))
    return preds, preds_to_semantic_inds(preds, background_threshold)


def semantic_inds_to_shrunken_bool_masks(
    semantic_inds, shrink_kernel_size, num_categories
):
    """Turn an index map into one boolean mask per category, optionally eroded.

    Args:
        semantic_inds: Tensor of category indices (converted with ``.numpy()``).
        shrink_kernel_size: Side length of the square erosion structure; a
            value of 0 (or less) leaves the masks untouched.
        num_categories: Number of masks to produce, for indices
            ``0 .. num_categories - 1``.

    Returns:
        Boolean tensor of shape ``(num_categories, *semantic_inds.shape)``.
    """
    erosion_structure = np.ones((shrink_kernel_size, shrink_kernel_size))
    should_erode = shrink_kernel_size > 0

    masks = torch.zeros((num_categories, *semantic_inds.shape), dtype=bool)
    for category_index in range(num_categories):
        category_mask = (semantic_inds == category_index).numpy()
        if should_erode:
            # erosion pulls the mask away from its borders
            category_mask = ndimage.binary_erosion(
                category_mask, structure=erosion_structure
            )
        masks[category_index] = torch.from_numpy(category_mask)

    return masks


def clip_and_shrink_preds(semantic_inds, preds, shrink_kernel_size, num_categories):
    """Zero out each prediction outside its own (shrunken) category mask.

    Args:
        semantic_inds: Index map where 0 is unlabeled and ``k + 1`` refers to
            ``preds[k]``.
        preds: Per-category score maps.
        shrink_kernel_size: Erosion size forwarded to
            ``semantic_inds_to_shrunken_bool_masks``.
        num_categories: Number of masks to build, including the unlabeled one.

    Returns:
        ``(clipped_preds, relative_sizes)``: the masked scores, and for every
        labeled category its mask size divided by the largest mask size (the
        raw sizes are returned when the largest size is 0).
    """
    # convert semantic_inds to shrunken bool masks
    category_masks = semantic_inds_to_shrunken_bool_masks(
        semantic_inds, shrink_kernel_size, num_categories
    ).to(preds.device)
    # index 0 is the unlabeled mask and is skipped everywhere below
    labeled_indices = range(1, category_masks.size(0))

    sizes = [torch.sum(category_masks[k].int()).item() for k in labeled_indices]
    largest = max(sizes)
    if largest > 0:
        relative_sizes = [size / largest for size in sizes]
    else:
        relative_sizes = sizes

    # use bool masks to clip preds
    clipped_preds = torch.zeros_like(preds)
    for k in labeled_indices:
        clipped_preds[k - 1] = preds[k - 1] * category_masks[k].float()

    return clipped_preds, relative_sizes


def sample_points_based_on_preds(preds, N):
    """Draw ``N`` pixel positions with probability proportional to ``preds``.

    Args:
        preds: 2-D array of non-negative weights, one per pixel.
        N: Number of points to draw (with replacement).

    Returns:
        List of ``(col, row)`` tuples.
    """
    num_rows, num_cols = preds.shape
    pixel_weights = preds.ravel()
    flat_indices = np.arange(num_rows * num_cols)

    # weighted draw over the flattened pixel indices
    drawn = random.choices(flat_indices, weights=pixel_weights, k=N)

    # flat index -> (col, row)
    points = []
    for flat_index in drawn:
        row, col = divmod(flat_index, num_cols)
        points.append((col, row))
    return points


def upsample_pred(pred, image_source):
    """Upsample a square prediction map to the resolution of the source image.

    The map is first resized to a square whose side is the larger image
    dimension, then the rows/columns beyond the image height/width are cut
    off at the end of each axis.

    Args:
        pred: 3-D tensor; a leading batch dimension is added before resizing.
        image_source: Array whose ``shape[0]`` and ``shape[1]`` are the target
            height and width.

    Returns:
        Tensor with the batch dimension removed again and its last two
        dimensions equal to ``(height, width)`` of ``image_source``.
    """
    original_height = image_source.shape[0]
    original_width = image_source.shape[1]
    larger_dim = max(original_height, original_width)

    # upsample the tensor to the larger dimension
    upsampled_tensor = F.interpolate(
        pred.unsqueeze(dim=0),
        size=(larger_dim, larger_dim),
        mode="bilinear",
        align_corners=False,
    )

    # remove the padding (at the end) to get the original image resolution.
    # Cropping with the exact integer sizes handles portrait and landscape
    # images alike and avoids rounding a float aspect ratio.
    upsampled_tensor = upsampled_tensor[:, :, :original_height, :original_width]
    return upsampled_tensor.squeeze(dim=1)


def sam_mask_from_points(predictor, image_array, points):
    """Prompt SAM with positive points and return an image-sized score map.

    Args:
        predictor: SAM predictor; ``generate_panoptic_mask`` calls
            ``set_image`` on it before calling this function.
        image_array: Image array; only used to determine the output size.
        points: Sequence of point coordinates, each treated as a positive
            prompt.

    Returns:
        The per-pixel maximum of the sigmoid of the returned logits,
        upsampled with ``upsample_pred``.
    """
    point_coords = np.array(points)
    # we only sample positive points, so labels are all 1
    positive_labels = np.ones(len(points))

    # predict_torch is deliberately not used here (the original author
    # reported that it did not seem to work for this case)
    _masks, _scores, low_res_logits = predictor.predict(
        point_coords=point_coords,
        point_labels=positive_labels,
    )

    # max over the first axis (the candidate masks returned for the prompt)
    probabilities = torch.sigmoid(torch.tensor(low_res_logits))
    best_per_pixel = torch.max(probabilities, dim=0, keepdim=True).values

    # the logits are lower resolution than the image -> upsample back
    return upsample_pred(best_per_pixel, image_array)


def inds_to_segments_format(
    panoptic_inds, thing_category_ids, stuff_category_names, category_name_to_id
):
    """Convert panoptic indices into a segmentation bitmap plus annotations.

    Args:
        panoptic_inds: Tensor of segment ids (converted with ``.numpy()``).
        thing_category_ids: Category id of every detected thing, in the order
            their segment ids were assigned.
        stuff_category_names: Names of the stuff categories; segment id
            ``k + 1`` belongs to ``stuff_category_names[k]``.
        category_name_to_id: Mapping from category name to category id.

    Returns:
        ``(segmentation_bitmap, annotations)``: the image opened from the
        output of ``bitmap2file``, and a list of
        ``{"id": ..., "category_id": ...}`` dicts.
    """
    inds_array = panoptic_inds.numpy().astype(np.uint32)
    segmentation_bitmap = Image.open(
        bitmap2file(inds_array, is_segmentation_bitmap=True)
    )

    stuff_category_ids = [category_name_to_id[name] for name in stuff_category_names]
    present_ids = np.unique(inds_array)

    annotations = []
    # stuff segments are only listed when their id occurs in the bitmap
    for segment_id, category_id in enumerate(stuff_category_ids, start=1):
        if segment_id in present_ids:
            annotations.append({"id": segment_id, "category_id": category_id})

    # thing segments are numbered after the stuff segments and always listed
    first_thing_id = len(stuff_category_names) + 1
    for offset, thing_category_id in enumerate(thing_category_ids):
        annotations.append(
            {"id": first_thing_id + offset, "category_id": thing_category_id}
        )

    return segmentation_bitmap, annotations


def _split_category_names(names_string):
    """Split a comma-separated string into stripped, non-empty category names."""
    if not names_string:
        return []
    stripped_names = (name.strip() for name in names_string.split(","))
    return [name for name in stripped_names if name]


def generate_panoptic_mask(
    image,
    thing_category_names_string,
    stuff_category_names_string,
    dino_box_threshold=0.3,
    dino_text_threshold=0.25,
    segmentation_background_threshold=0.1,
    shrink_kernel_size=20,
    num_samples_factor=1000,
    task_attributes_json="",
):
    """Produce a panoptic segmentation of ``image`` for the given categories.

    "Thing" categories (with instances) are detected with Grounding DINO and
    segmented with SAM from the boxes. "Stuff" categories (without instances)
    are roughly segmented with CLIPSeg, after which points sampled from the
    rough masks are used to prompt SAM. The function reads the module-level
    ``sam_predictor``, ``dino_model``, ``clipseg_processor``,
    ``clipseg_model`` and ``device``.

    Args:
        image: Input image; converted to RGB before use.
        thing_category_names_string: Comma-separated thing category names.
            Blank entries (including an entirely empty string) are ignored.
        stuff_category_names_string: Comma-separated stuff category names.
            Blank entries (including an entirely empty string) are ignored.
        dino_box_threshold: Box threshold for Grounding DINO.
        dino_text_threshold: Text threshold for Grounding DINO.
        segmentation_background_threshold: Score below which a pixel stays
            unlabeled.
        shrink_kernel_size: Erosion size applied to the rough stuff masks
            before sampling points.
        num_samples_factor: Number of points sampled for the largest stuff
            category; other categories scale with their relative size.
        task_attributes_json: Optional JSON with a ``"categories"`` list of
            objects carrying ``"name"``, ``"id"`` and optionally
            ``"has_instances"``. When non-blank it replaces both name strings.

    Returns:
        ``((image_array, subsection_label_pairs), segmentation_bitmap,
        annotations_json)``.
    """
    if task_attributes_json and task_attributes_json.strip():
        task_attributes = json.loads(task_attributes_json)
        categories = task_attributes["categories"]
        category_name_to_id = {
            category["name"]: category["id"] for category in categories
        }
        # split the categories into "stuff" categories (regions w/o instances)
        # and "thing" categories (objects/instances)
        stuff_category_names = []
        thing_category_names = []
        for category in categories:
            if category.get("has_instances"):
                thing_category_names.append(category["name"])
            else:
                stuff_category_names.append(category["name"])
    else:
        # parse inputs; blank names are dropped so that an empty textbox yields
        # an empty list and the corresponding stage below is really skipped
        thing_category_names = _split_category_names(thing_category_names_string)
        stuff_category_names = _split_category_names(stuff_category_names_string)
        category_names = thing_category_names + stuff_category_names
        category_name_to_id = {
            category_name: i for i, category_name in enumerate(category_names)
        }

    image = image.convert("RGB")
    image_array = np.asarray(image)

    # compute SAM image embedding
    sam_predictor.set_image(image_array)

    # detect boxes for "thing" categories using Grounding DINO
    thing_category_ids = []
    thing_masks = []
    thing_boxes = []
    detected_thing_category_names = []
    if len(thing_category_names) > 0:
        thing_boxes, thing_category_ids, detected_thing_category_names = dino_detection(
            dino_model,
            image,
            image_array,
            thing_category_names,
            category_name_to_id,
            dino_box_threshold,
            dino_text_threshold,
            device,
        )
        if len(thing_boxes) > 0:
            # get segmentation masks for the thing boxes
            thing_masks = sam_masks_from_dino_boxes(
                sam_predictor, image_array, thing_boxes, device
            )
    if len(stuff_category_names) > 0:
        # get rough segmentation masks for "stuff" categories using CLIPSeg
        clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
            clipseg_processor,
            clipseg_model,
            image,
            stuff_category_names,
            segmentation_background_threshold,
            device,
        )
        # remove things from stuff masks
        clipseg_semantic_inds_without_things = clipseg_semantic_inds.clone()
        if len(thing_boxes) > 0:
            combined_things_mask = torch.any(thing_masks, dim=0)
            clipseg_semantic_inds_without_things[combined_things_mask[0]] = 0
        # clip CLIPSeg preds based on non-overlapping semantic segmentation inds
        # (+ optionally shrink the mask of each category);
        # also returns the relative size of each category
        clipseg_clipped_preds, relative_sizes = clip_and_shrink_preds(
            clipseg_semantic_inds_without_things,
            clipseg_preds,
            shrink_kernel_size,
            len(stuff_category_names) + 1,
        )
        # get finer segmentation masks for the "stuff" categories using SAM
        sam_preds = torch.zeros_like(clipseg_clipped_preds)
        for i in range(clipseg_clipped_preds.shape[0]):
            clipseg_pred = clipseg_clipped_preds[i]
            # for each "stuff" category, sample points in the rough segmentation mask
            num_samples = int(relative_sizes[i] * num_samples_factor)
            if num_samples == 0:
                continue
            points = sample_points_based_on_preds(
                clipseg_pred.cpu().numpy(), num_samples
            )
            if len(points) == 0:
                continue
            # use SAM to get mask for points
            pred = sam_mask_from_points(sam_predictor, image_array, points)
            sam_preds[i] = pred
        sam_semantic_inds = preds_to_semantic_inds(
            sam_preds, segmentation_background_threshold
        )

    # combine the thing inds and the stuff inds into panoptic inds
    panoptic_inds = (
        sam_semantic_inds.clone()
        if len(stuff_category_names) > 0
        else torch.zeros(image_array.shape[0], image_array.shape[1], dtype=torch.long)
    )
    # segment ids: 0 = unlabeled, 1..len(stuff) = stuff, then one id per thing mask
    ind = len(stuff_category_names) + 1
    for thing_mask in thing_masks:
        # overlay thing mask on panoptic inds
        panoptic_inds[thing_mask.squeeze(dim=0)] = ind
        ind += 1

    # kernel size 0 -> plain per-id boolean masks without erosion
    panoptic_bool_masks = (
        semantic_inds_to_shrunken_bool_masks(panoptic_inds, 0, ind + 1)
        .numpy()
        .astype(int)
    )
    panoptic_names = (
        ["unlabeled"] + stuff_category_names + detected_thing_category_names
    )
    subsection_label_pairs = [
        (panoptic_bool_masks[i], panoptic_name)
        for i, panoptic_name in enumerate(panoptic_names)
    ]

    segmentation_bitmap, annotations = inds_to_segments_format(
        panoptic_inds, thing_category_ids, stuff_category_names, category_name_to_id
    )
    annotations_json = json.dumps(annotations)

    return (image_array, subsection_label_pairs), segmentation_bitmap, annotations_json


# Grounding DINO config (local checkout) and checkpoint (Hugging Face Hub),
# plus the SAM checkpoint expected in the working directory.
config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swint_ogc.pth"
sam_checkpoint = "./sam_vit_h_4b8939.pth"

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

if device != "cpu":
    # Probe for GroundingDINO's compiled extension; failure is only warned about.
    # `except Exception` keeps this best-effort without also swallowing
    # KeyboardInterrupt / SystemExit the way a bare `except:` would.
    try:
        from GroundingDINO.groundingdino import _C
    except Exception:
        warnings.warn(
            "Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!"
        )

# initialize groundingdino model
dino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, device)

# initialize SAM
sam = build_sam(checkpoint=sam_checkpoint)
sam.to(device=device)
sam_predictor = SamPredictor(sam)

# initialize CLIPSeg (processor + model from the same pretrained name)
clipseg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clipseg_model = CLIPSegForImageSegmentation.from_pretrained(
    "CIDAS/clipseg-rd64-refined"
)
clipseg_model.to(device)

# Script entry point: parse the CLI flags, build the Gradio UI and launch it.
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Panoptic Segment Anything demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    args = parser.parse_args()

    print(f"args = {args}")

    block = gr.Blocks(title="Panoptic Segment Anything").queue()
    with block:
        with gr.Column():
            title = gr.Markdown(
                "# [Panoptic Segment Anything](https://github.com/segments-ai/panoptic-segment-anything)"
            )
            description = gr.Markdown(
                "Demo for zero-shot panoptic segmentation using Segment Anything, Grounding DINO, and CLIPSeg."
            )
            with gr.Row():
                # left column: inputs and advanced options
                with gr.Column():
                    input_image = gr.Image(sources=["upload"], type="pil")
                    thing_category_names_string = gr.Textbox(
                        label="Thing categories (i.e. categories with instances), comma-separated",
                        placeholder="E.g. car, bus, person",
                    )
                    stuff_category_names_string = gr.Textbox(
                        label="Stuff categories (i.e. categories without instances), comma-separated",
                        placeholder="E.g. sky, road, buildings",
                    )
                    run_button = gr.Button(value="Run")
                    # the slider defaults below equal the keyword defaults of
                    # generate_panoptic_mask
                    with gr.Accordion("Advanced options", open=False):
                        box_threshold = gr.Slider(
                            label="Grounding DINO box threshold",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.3,
                            step=0.001,
                        )
                        text_threshold = gr.Slider(
                            label="Grounding DINO text threshold",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.25,
                            step=0.001,
                        )
                        segmentation_background_threshold = gr.Slider(
                            label="Segmentation background threshold (under this threshold, a pixel is considered background/unlabeled)",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.1,
                            step=0.001,
                        )
                        shrink_kernel_size = gr.Slider(
                            label="Shrink kernel size (how much to shrink the mask before sampling points)",
                            minimum=0,
                            maximum=100,
                            value=20,
                            step=1,
                        )
                        num_samples_factor = gr.Slider(
                            label="Number of samples factor (how many points to sample in the largest category)",
                            minimum=0,
                            maximum=1000,
                            value=1000,
                            step=1,
                        )
                        task_attributes_json = gr.Textbox(
                            label="Task attributes JSON",
                        )

                # right column: outputs
                with gr.Column():
                    annotated_image = gr.AnnotatedImage()
                    with gr.Accordion("Segmentation bitmap", open=False):
                        segmentation_bitmap_text = gr.Markdown(
                            """
The segmentation bitmap is a 32-bit RGBA png image which contains the segmentation masks.
The alpha channel is set to 255, and the remaining 24-bit values in the RGB channels correspond to the object ids in the annotations list.
Unlabeled regions have a value of 0.
Because of the large dynamic range, the segmentation bitmap appears black in the image viewer.
"""
                        )
                        segmentation_bitmap = gr.Image(
                            type="pil", label="Segmentation bitmap"
                        )
                        annotations_json = gr.Textbox(
                            label="Annotations JSON",
                        )

            # The examples supply only the image and the two category strings,
            # so generate_panoptic_mask runs with its keyword defaults for the
            # remaining parameters. The example image paths are relative.
            examples = gr.Examples(
                examples=[
                    [
                        "a2d2.png",
                        "car, bus, person",
                        "road, sky, buildings, sidewalk",
                    ],
                    [
                        "bxl.png",
                        "car, tram, motorcycle, person",
                        "road, buildings, sky",
                    ],
                ],
                fn=generate_panoptic_mask,
                inputs=[
                    input_image,
                    thing_category_names_string,
                    stuff_category_names_string,
                ],
                outputs=[annotated_image, segmentation_bitmap, annotations_json],
                cache_examples=True,
            )

        # The inputs are listed in the positional parameter order of
        # generate_panoptic_mask; the outputs match its three return values.
        run_button.click(
            fn=generate_panoptic_mask,
            inputs=[
                input_image,
                thing_category_names_string,
                stuff_category_names_string,
                box_threshold,
                text_threshold,
                segmentation_background_threshold,
                shrink_kernel_size,
                num_samples_factor,
                task_attributes_json,
            ],
            outputs=[annotated_image, segmentation_bitmap, annotations_json],
            api_name="segment",
        )

    # server_name="0.0.0.0" listens on all network interfaces
    block.launch(server_name="0.0.0.0", debug=args.debug, share=args.share)