import argparse
import functools
import json
import os
import sys
import tempfile

import cv2
import numpy as np
import supervision as sv
from groundingdino.util.inference import Model as DinoModel
from imutils import paths
from PIL import Image
from segment_anything import sam_model_registry
from segment_anything import SamAutomaticMaskGenerator
from segment_anything import SamPredictor
from supervision.detection.utils import xywh_to_xyxy
from tqdm import tqdm

sys.path.append("tag2text")

from tag2text.models import tag2text
from config import *
from utils import detect, download_file_hf, segment, generate_tags, show_anns_sv


def process(
    tag2text_model,
    grounding_dino_model,
    sam_predictor,
    sam_automask_generator,
    image_path,
    task,
    prompt,
    box_threshold,
    text_threshold,
    iou_threshold,
    kernel_size=2,
    expand_mask=False,
    device="cuda",
    output_dir=None,
    save_ann=True,
    save_mask=False,
):
    """Tag, detect and/or segment one image and write its metadata as JSON.

    Depending on ``task`` and ``prompt`` the image is tagged (to build a
    prompt), run through box detection and/or segmented. Annotated images are
    written to ``output_dir`` when requested, and a metadata dictionary with
    the keys ``image``, ``annotations`` and ``assets`` (plus ``prompt`` when
    one is available) is dumped to a JSON file.

    Args:
        tag2text_model: Model handed to ``generate_tags``. Tagging is skipped
            when this is ``None``.
        grounding_dino_model: Model handed to ``detect``; only used when a
            prompt is available.
        sam_predictor: Predictor handed to ``segment`` together with the
            detected boxes.
        sam_automask_generator: Object whose ``generate(image)`` is called
            when there are no detections to segment.
        image_path: Path of the image to process.
        task: ``"auto"``, ``"detect"`` / ``"detection"`` or ``"segment"``.
        prompt: Detection prompt. When empty and detection is requested, the
            generated tags joined with ``" . "`` are used instead.
        box_threshold: Forwarded to ``detect``.
        text_threshold: Forwarded to ``detect``.
        iou_threshold: Forwarded to ``detect``.
        kernel_size: Radius ``k`` of the elliptical structuring element; its
            side length is ``2 * k + 1``.
        expand_mask: Dilate the masks when true. When false, box-prompted
            masks are morphologically closed and automatically generated
            masks are left untouched.
        device: Forwarded to ``generate_tags``.
        output_dir: Directory for all output files. When falsy no images are
            written and the metadata goes to a temporary ``.json`` file.
        save_ann: Save the detection / mask / annotated images.
        save_mask: Save every individual mask as its own PNG.

    Returns:
        Path of the JSON file holding the metadata.

    Raises:
        ValueError: Wraps any exception raised while processing; the original
            exception is chained as ``__cause__``.
    """
    # The command line names the detection task "detect" while this module
    # has always checked for "detection"; treat both spellings alike.
    detection_tasks = ("auto", "detect", "detection")
    segmentation_tasks = ("auto", "segment")

    detections = None
    metadata = {"image": {}, "annotations": [], "assets": {}}

    if save_mask:
        metadata["assets"]["intermediate_mask"] = []

    try:
        # Load the image; the file handle is released as soon as the RGB copy
        # has been created.
        with Image.open(image_path) as source_image:
            image_pil = source_image.convert("RGB")
        image = np.array(image_pil)
        # Untouched copy: `image` is handed to the annotators below, whereas
        # segmentation and the cutout must see the original pixels.
        orig_image = image.copy()

        # Extract image metadata
        filename = os.path.basename(image_path)
        basename = os.path.splitext(filename)[0]
        h, w = image.shape[:2]
        metadata["image"]["file_name"] = filename
        metadata["image"]["width"] = w
        metadata["image"]["height"] = h

        # Generate tags to use as the prompt when none was supplied. Without a
        # tagging model there is nothing to generate them with.
        if (
            task in detection_tasks
            and prompt == ""
            and tag2text_model is not None
        ):
            tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
            prompt = " . ".join(tags)
            metadata["image"]["caption"] = caption
            metadata["image"]["tags"] = tags

        if prompt:
            metadata["prompt"] = prompt

        # Detect boxes
        if prompt != "":
            detections, phrases, classes = detect(
                grounding_dino_model,
                image,
                caption=prompt,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
                iou_threshold=iou_threshold,
                post_process=True,
            )

            # Save detection image
            if output_dir and save_ann:
                box_annotator = sv.BoxAnnotator()
                labels = [
                    f"{phrase} {confidence:0.2f}"
                    for phrase, confidence in zip(phrases, detections.confidence)
                ]
                box_image = box_annotator.annotate(
                    scene=image, detections=detections, labels=labels
                )
                box_image_path = os.path.join(output_dir, basename + "_detect.png")
                metadata["assets"]["detection"] = box_image_path
                Image.fromarray(box_image).save(box_image_path)

        # Segmentation
        if task in segmentation_tasks:
            kernel = cv2.getStructuringElement(
                cv2.MORPH_ELLIPSE, (2 * kernel_size + 1, 2 * kernel_size + 1)
            )
            # Falsy when no detection ran; presumably also when detection
            # returned no boxes, in which case automatic masks are generated.
            if detections:
                masks, scores = segment(
                    sam_predictor, image=orig_image, boxes=detections.xyxy
                )
                if expand_mask:
                    masks = [
                        cv2.dilate(mask.astype(np.uint8), kernel) for mask in masks
                    ]
                else:
                    masks = [
                        cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
                        for mask in masks
                    ]
                # Store the masks as one array, like the automatic branch does.
                detections.mask = np.array(masks)
                # Union of all masks. `np.any` cannot overflow the way a uint8
                # sum can, and avoids the removed `np.bool` alias.
                binary_mask = np.any(detections.mask, axis=0)
            else:
                generated_masks = sam_automask_generator.generate(orig_image)
                # Largest masks first.
                sorted_generated_masks = sorted(
                    generated_masks, key=lambda x: x["area"], reverse=True
                )

                xywh = np.array([item["bbox"] for item in sorted_generated_masks])
                scores = np.array(
                    [item["predicted_iou"] for item in sorted_generated_masks]
                )
                if expand_mask:
                    mask_stack = np.array(
                        [
                            cv2.dilate(item["segmentation"].astype(np.uint8), kernel)
                            for item in sorted_generated_masks
                        ]
                    )
                else:
                    mask_stack = np.array(
                        [item["segmentation"] for item in sorted_generated_masks]
                    )
                detections = sv.Detections(
                    xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask_stack
                )
                # No cutout is produced for automatically generated masks.
                binary_mask = None

            # Save annotated image
            if output_dir and save_ann:
                mask_annotator = sv.MaskAnnotator()
                mask_image, res = show_anns_sv(detections)
                annotated_image = mask_annotator.annotate(image, detections=detections)

                mask_image_path = os.path.join(output_dir, basename + "_mask.png")
                metadata["assets"]["mask"] = mask_image_path
                Image.fromarray(mask_image).save(mask_image_path)

                # Save annotation encoding from https://github.com/LUSSeg/ImageNet-S
                mask_enc_path = os.path.join(output_dir, basename + "_mask_enc.npy")
                np.save(mask_enc_path, res)
                metadata["assets"]["mask_enc"] = mask_enc_path

                if binary_mask is not None:
                    # Keep only the pixels covered by at least one mask.
                    cutout_image = np.expand_dims(binary_mask, axis=-1) * orig_image
                    cutout_image_path = os.path.join(
                        output_dir, basename + "_cutout.png"
                    )
                    metadata["assets"]["cutout"] = cutout_image_path
                    Image.fromarray(cutout_image).save(cutout_image_path)

                annotated_image_path = os.path.join(
                    output_dir, basename + "_annotate.png"
                )
                metadata["assets"]["annotate"] = annotated_image_path
                Image.fromarray(annotated_image).save(annotated_image_path)

        # Collect per-detection metadata
        if detections:
            per_detection = zip(detections, detections.area, detections.box_area)
            for index, (fields, area, box_area) in enumerate(per_detection):
                xyxy, mask, confidence, _, _ = fields
                annotation_id = index + 1
                annotation = {
                    "id": annotation_id,
                    "bbox": [int(x) for x in xyxy],
                    "box_area": float(box_area),
                }
                # Automatically generated masks were built without
                # confidences, hence without labels; a confidence of exactly
                # 0.0 is still a valid detection.
                if confidence is not None:
                    annotation["confidence"] = float(confidence)
                    annotation["label"] = phrases[index]
                if mask is not None:
                    annotation["area"] = int(area)
                    annotation["predicted_iou"] = float(scores[index])
                metadata["annotations"].append(annotation)

                if output_dir and save_mask and mask is not None:
                    # One file per annotation, named after the annotation id.
                    mask_image_path = os.path.join(
                        output_dir, f"{basename}_mask_{annotation_id}.png"
                    )
                    metadata["assets"]["intermediate_mask"].append(mask_image_path)
                    # Cast first: a boolean mask times 255 is not uint8 and
                    # cannot be turned into an image.
                    Image.fromarray(mask.astype(np.uint8) * 255).save(
                        mask_image_path
                    )

        if output_dir:
            meta_file_path = os.path.join(output_dir, basename + "_meta.json")
            with open(meta_file_path, "w") as fp:
                json.dump(metadata, fp)
        else:
            # No output directory: hand the metadata back through a temporary
            # file that outlives this call (delete=False).
            with tempfile.NamedTemporaryFile(
                mode="w", delete=False, suffix=".json"
            ) as fp:
                json.dump(metadata, fp)
                meta_file_path = fp.name

        return meta_file_path
    except Exception as error:
        # Callers rely on ValueError; chain the cause to keep the traceback.
        raise ValueError(f"global exception: {error}") from error


def _download_with_wget(url, destination):
    """Fetch ``url`` into ``destination`` with wget, raising if it fails."""
    import subprocess  # local import: only needed when weights are missing

    try:
        # An argument list keeps the shell from interpreting the URL or path.
        subprocess.run(["wget", url, "-O", destination], check=True)
    except (OSError, subprocess.CalledProcessError):
        # wget creates the -O target before transferring data; remove the
        # partial file so a later run does not mistake it for valid weights.
        if os.path.exists(destination):
            os.remove(destination)
        raise


def main(args: argparse.Namespace) -> None:
    """Load the models required by the requested task and process the input.

    Args:
        args: Parsed command-line arguments. ``kernel_size`` and
            ``expand_mask`` are optional and default to ``2`` and ``False``
            when the namespace does not carry them.

    Raises:
        ValueError: If ``args.input`` does not exist, or when ``process``
            fails on an image.
    """
    device = args.device
    prompt = args.prompt
    # The command line offers "detect", while the checks below (and in
    # `process`) use "detection"; translate so the choice is honoured.
    task = "detection" if args.task == "detect" else args.task

    box_threshold = args.box_threshold
    text_threshold = args.text_threshold
    iou_threshold = args.iou_threshold
    kernel_size = getattr(args, "kernel_size", 2)
    expand_mask = getattr(args, "expand_mask", False)
    save_ann = not args.no_save_ann
    save_mask = args.save_mask

    # Validate the input before any model is downloaded or loaded, so a wrong
    # path fails immediately.
    if not os.path.exists(args.input):
        raise ValueError(f"The input path doesn't exist: {args.input}")
    if os.path.isdir(args.input):
        # Materialise the listing so the progress bar knows its total.
        image_paths = list(paths.list_images(args.input))
    else:
        image_paths = [args.input]

    tag2text_model = None
    grounding_dino_model = None
    sam_predictor = None
    sam_automask_generator = None

    # Tag2Text is only needed to build a prompt when none was given.
    if task in ["auto", "detection"] and prompt == "":
        print("Loading Tag2Text model...")
        tag2text_type = args.tag2text_type
        tag2text_checkpoint = os.path.join(
            abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
        )
        if not os.path.exists(tag2text_checkpoint):
            print(f"Downloading weights for Tag2Text {tag2text_type} model")
            _download_with_wget(
                tag2text_dict[tag2text_type]["checkpoint_url"], tag2text_checkpoint
            )
        tag2text_model = tag2text.tag2text_caption(
            pretrained=tag2text_checkpoint,
            image_size=384,
            vit="swin_b",
            delete_tag_index=delete_tag_index,
        )
        # threshold for tagging
        # we reduce the threshold to obtain more tags
        tag2text_model.threshold = 0.64
        tag2text_model.to(device)
        tag2text_model.eval()

    if task in ["auto", "detection"] or prompt != "":
        print("Loading Grounding Dino model...")
        dino_type = args.dino_type
        dino_checkpoint = os.path.join(
            abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
        )
        dino_config_file = os.path.join(
            abs_weight_dir, dino_dict[dino_type]["config_file"]
        )
        # Both the checkpoint and its config are required; fetch whichever is
        # missing.
        missing_dino_files = [
            dino_dict[dino_type][key]
            for key, local_path in (
                ("checkpoint_file", dino_checkpoint),
                ("config_file", dino_config_file),
            )
            if not os.path.exists(local_path)
        ]
        if missing_dino_files:
            print(f"Downloading weights for Grounding Dino {dino_type} model")
            dino_repo_id = dino_dict[dino_type]["repo_id"]
            for missing_file in missing_dino_files:
                download_file_hf(
                    repo_id=dino_repo_id,
                    filename=missing_file,
                    cache_dir=weight_dir,
                )
        grounding_dino_model = DinoModel(
            model_config_path=dino_config_file,
            model_checkpoint_path=dino_checkpoint,
            device=device,
        )

    if task in ["auto", "segment"]:
        print("Loading SAM...")
        sam_type = args.sam_type
        sam_checkpoint = os.path.join(
            abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
        )
        if not os.path.exists(sam_checkpoint):
            print(f"Downloading weights for SAM {sam_type}")
            _download_with_wget(sam_dict[sam_type]["checkpoint_url"], sam_checkpoint)
        sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
        sam.to(device=device)
        # The predictor and the automatic generator share the same SAM model.
        sam_predictor = SamPredictor(sam)
        sam_automask_generator = SamAutomaticMaskGenerator(sam)

    os.makedirs(args.output, exist_ok=True)

    with tqdm(image_paths) as pbar:
        for image_path in pbar:
            pbar.set_postfix_str(f"Processing {image_path}")
            process(
                tag2text_model=tag2text_model,
                grounding_dino_model=grounding_dino_model,
                sam_predictor=sam_predictor,
                sam_automask_generator=sam_automask_generator,
                image_path=image_path,
                task=task,
                prompt=prompt,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
                iou_threshold=iou_threshold,
                kernel_size=kernel_size,
                expand_mask=expand_mask,
                device=device,
                output_dir=args.output,
                save_ann=save_ann,
                save_mask=save_mask,
            )


if __name__ == "__main__":
    # Make sure the weights directory exists before anything is downloaded.
    os.makedirs(abs_weight_dir, exist_ok=True)

    parser = argparse.ArgumentParser(
        description=(
            "Runs automatic detection and mask generation on an input image or directory of images"
        )
    )

    # Input / output locations
    parser.add_argument(
        "--input",
        "-i",
        type=str,
        required=True,
        help="Path to either a single input image or folder of images.",
    )

    parser.add_argument(
        "--output",
        "-o",
        type=str,
        required=True,
        help="Path to the directory where masks will be output.",
    )

    # Model variants; the available choices and defaults come from config.
    parser.add_argument(
        "--sam-type",
        type=str,
        default=default_sam,
        choices=sam_dict.keys(),
        help="The type of SA model use for segmentation.",
    )

    parser.add_argument(
        "--tag2text-type",
        type=str,
        default=default_tag2text,
        choices=tag2text_dict.keys(),
        help="The type of Tag2Text model use for tags and caption generation.",
    )

    parser.add_argument(
        "--dino-type",
        type=str,
        default=default_dino,
        choices=dino_dict.keys(),
        help="The type of Grounding Dino model use for promptable object detection.",
    )

    # What to run and with which prompt
    parser.add_argument(
        "--task",
        help="Task to run",
        default="auto",
        choices=["auto", "detect", "segment"],
        type=str,
    )
    parser.add_argument(
        "--prompt",
        help="Detection prompt",
        default="",
        type=str,
    )

    # Detection thresholds
    parser.add_argument(
        "--box-threshold", type=float, default=0.25, help="box threshold"
    )
    parser.add_argument(
        "--text-threshold", type=float, default=0.2, help="text threshold"
    )
    parser.add_argument(
        "--iou-threshold", type=float, default=0.5, help="iou threshold"
    )

    # Mask post-processing
    parser.add_argument(
        "--kernel-size",
        type=int,
        default=2,
        choices=range(1, 6),
        help="kernel size use for smoothing/expanding segment masks",
    )
    parser.add_argument(
        "--expand-mask",
        action="store_true",
        default=False,
        help="If set, expand segment masks for smoother output.",
    )

    # Output switches. --no-save-ann is a negative flag: passing it DISABLES
    # saving of the annotated images.
    parser.add_argument(
        "--no-save-ann",
        action="store_true",
        default=False,
        help=(
            "If set, do not save the original image with blended masks and "
            "detection boxes."
        ),
    )
    parser.add_argument(
        "--save-mask",
        action="store_true",
        default=False,
        help="If set, save all intermediate masks.",
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="The device to run generation on."
    )
    args = parser.parse_args()
    main(args)