import random
import sys
from typing import Dict, List

import numpy as np
import supervision as sv
import torch
import torchvision
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image
from segment_anything import SamPredictor

# make the local tag2text and GroundingDINO packages importable
sys.path.append("tag2text")
sys.path.append("GroundingDINO")

from groundingdino.models import build_model
from groundingdino.util.inference import Model as DinoModel
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from tag2text.inference import inference as tag2text_inference


def load_model_hf(repo_id, filename, ckpt_config_filename, device="cpu"):
    """Download a GroundingDINO config and checkpoint from the Hugging Face Hub and build the model."""
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file)
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location=device)
    model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    model.eval()
    return model
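
# Example usage (sketch; the repo id and filenames below are the ones commonly
# used in Grounded-SAM demos and are assumptions, not pinned by this module):
#
#   grounding_dino_model = load_model_hf(
#       repo_id="ShilongLiu/GroundingDINO",
#       filename="groundingdino_swinb_cogcoor.pth",
#       ckpt_config_filename="GroundingDINO_SwinB.cfg.py",
#       device="cuda" if torch.cuda.is_available() else "cpu",
#   )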


def download_file_hf(repo_id, filename, cache_dir="./cache"):
    """Download a single file from the Hugging Face Hub into a local cache directory."""
    cache_file = hf_hub_download(
        repo_id=repo_id, filename=filename, force_filename=filename, cache_dir=cache_dir
    )
    return cache_file
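
# Example usage (sketch; the repo id and filename are hypothetical placeholders):
#
#   ckpt_path = download_file_hf(repo_id="some-user/some-model", filename="model.pth")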


def transform_image_tag2text(image_pil: Image.Image) -> torch.Tensor:
    """Resize and normalize a PIL image into the 384x384 tensor expected by Tag2Text."""
    transform = T.Compose(
        [
            T.Resize((384, 384)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image = transform(image_pil)  # 3, h, w
    return image
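
# Example usage (sketch):
#
#   image_pil = Image.open("example.jpg").convert("RGB")
#   image_tensor = transform_image_tag2text(image_pil)  # shape: (3, 384, 384)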


def show_anns_sam(anns: List[Dict]):
    """Extracts the mask annotations from the Segment Anything model output and plots them.
    https://github.com/facebookresearch/segment-anything.

    Arguments:
      anns (List[Dict]): Segment Anything model output.

    Returns:
      (np.ndarray): Masked image.
      (np.ndarray): annotation encoding from https://github.com/LUSSeg/ImageNet-S
    """
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
    full_img = None

    for i, ann in enumerate(sorted_anns):
        m = ann["segmentation"]
        if full_img is None:
            full_img = np.zeros((m.shape[0], m.shape[1], 3))
            seg_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
        seg_map[m != 0] = i + 1
        color_mask = np.random.random((1, 3)).tolist()[0]
        full_img[m != 0] = color_mask
    full_img = full_img * 255

    # annotation encoding from https://github.com/LUSSeg/ImageNet-S
    res = np.zeros((seg_map.shape[0], seg_map.shape[1], 3))
    res[:, :, 0] = seg_map % 256
    res[:, :, 1] = seg_map // 256
    res = res.astype(np.float32)
    full_img = np.uint8(full_img)
    return full_img, res
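
# Example usage (sketch; assumes a local SAM ViT-H checkpoint and an HWC uint8
# RGB image in `image`; the checkpoint filename is an assumption):
#
#   from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
#
#   sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
#   anns = SamAutomaticMaskGenerator(sam).generate(image)
#   overlay, encoding = show_anns_sam(anns)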


def show_anns_sv(detections: sv.Detections):
    """Extracts the mask annotations from the Supervision Detections object.
    https://roboflow.github.io/supervision/detection/core/.

    Arguments:
      detections (sv.Detections): Object containing information about the detections.

    Returns:
      (np.ndarray): Masked image.
      (np.ndarray): annotation encoding from https://github.com/LUSSeg/ImageNet-S
    """
    if detections.mask is None:
        return
    full_img = None

    for i in np.flip(np.argsort(detections.area)):
        m = detections.mask[i]
        if full_img is None:
            full_img = np.zeros((m.shape[0], m.shape[1], 3))
            seg_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
        seg_map[m != 0] = i + 1
        color_mask = np.random.random((1, 3)).tolist()[0]
        full_img[m != 0] = color_mask
    full_img = full_img * 255

    # annotation encoding from https://github.com/LUSSeg/ImageNet-S
    res = np.zeros((seg_map.shape[0], seg_map.shape[1], 3))
    res[:, :, 0] = seg_map % 256
    res[:, :, 1] = seg_map // 256
    res = res.astype(np.float32)
    full_img = np.uint8(full_img)
    return full_img, res
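
# Example usage (sketch): attach the masks produced by `segment` (defined below)
# to a supervision Detections object and render the overlay.
#
#   masks, scores = segment(sam_predictor, image, boxes=detections.xyxy)
#   detections.mask = masks
#   overlay, encoding = show_anns_sv(detections)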


def generate_tags(tag2text_model, image, specified_tags, device="cpu"):
    """Generate image tags and caption using Tag2Text model.

    Arguments:
      tag2text_model (nn.Module): Tag2Text model to use for prediction.
      image (Image.Image): Input PIL image; it is resized to 384x384 and
        normalized internally before being passed to the model.
      specified_tags (str): User-specified tags to condition the prediction on.
      device (str): Device to run inference on.

    Returns:
      (List[str]): Predicted image tags.
      (str): Predicted image caption.
    """
    image = transform_image_tag2text(image).unsqueeze(0).to(device)
    res = tag2text_inference(image, tag2text_model, specified_tags)
    tags = res[0].split(" | ")
    caption = res[2]
    return tags, caption
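
# Example usage (sketch; assumes `tag2text_model` was already built from the
# Tag2Text repo; in that repo's demos, passing "None" as specified_tags is
# treated as "no user-provided tags"):
#
#   image_pil = Image.open("example.jpg").convert("RGB")
#   tags, caption = generate_tags(tag2text_model, image_pil, specified_tags="None")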


def detect(
    grounding_dino_model: DinoModel,
    image: np.ndarray,
    caption: str,
    box_threshold: float = 0.3,
    text_threshold: float = 0.25,
    iou_threshold: float = 0.5,
    post_process: bool = True,
):
    """Detect bounding boxes for the given image, using the input caption.

    Arguments:
      grounding_dino_model (DinoModel): The model to use for detection.
      image (np.ndarray): The image for calculating masks. Expects an
        image in HWC uint8 format, with pixel values in [0, 255].
      caption (str): Input caption containing the names of the objects to detect.
        To detect multiple objects, separate the names with '.', e.g. "cat . dog . chair".
      box_threshold (float): Box confidence threshold.
      text_threshold (float): Text confidence threshold.
      iou_threshold (float): IoU threshold used by NMS during post-processing.
      post_process (bool): If True, run NMS to remove duplicate detections.

    Returns:
      (sv.Detections): Object containing information about the detections in the image.
      (List[str]): Predicted phrases.
      (List[str]): Predicted classes.
    """
    detections, phrases = grounding_dino_model.predict_with_caption(
        image=image,
        caption=caption,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    classes = [cls.strip() for cls in caption.split(".")]
    detections.class_id = DinoModel.phrases2classes(phrases=phrases, classes=classes)

    # NMS post process
    if post_process:
        # print(f"Before NMS: {len(detections.xyxy)} boxes")
        nms_idx = (
            torchvision.ops.nms(
                torch.from_numpy(detections.xyxy),
                torch.from_numpy(detections.confidence),
                iou_threshold,
            )
            .numpy()
            .tolist()
        )

        phrases = [phrases[idx] for idx in nms_idx]
        detections.xyxy = detections.xyxy[nms_idx]
        detections.confidence = detections.confidence[nms_idx]
        detections.class_id = detections.class_id[nms_idx]

        # print(f"After NMS: {len(detections.xyxy)} boxes")

    return detections, phrases, classes
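
# Example usage (sketch; the repo id and filenames are assumptions, and the
# groundingdino Model wrapper expects an OpenCV-style BGR image):
#
#   import cv2
#
#   config_path = hf_hub_download("ShilongLiu/GroundingDINO", "GroundingDINO_SwinB.cfg.py")
#   ckpt_path = hf_hub_download("ShilongLiu/GroundingDINO", "groundingdino_swinb_cogcoor.pth")
#   grounding_dino_model = DinoModel(model_config_path=config_path, model_checkpoint_path=ckpt_path)
#
#   image_bgr = cv2.imread("example.jpg")
#   detections, phrases, classes = detect(
#       grounding_dino_model, image_bgr, caption="cat . dog . chair"
#   )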


def segment(sam_model: SamPredictor, image: np.ndarray, boxes: np.ndarray):
    """Predict masks for the given input boxes, using the currently set image.

    Arguments:
      sam_model (SamPredictor): The model to use for mask prediction.
      image (np.ndarray): The image for calculating masks. Expects an
        image in HWC uint8 format, with pixel values in [0, 255].
      boxes (np.ndarray or None): A Bx4 array given a box prompt to the
        model, in XYXY format.
      return_logits (bool): If true, returns un-thresholded masks logits
        instead of a binary mask.

    Returns:
      (torch.Tensor): The output masks in BxCxHxW format, where C is the
        number of masks, and (H, W) is the original image size.
      (torch.Tensor): An array of shape BxC containing the model's
        predictions for the quality of each mask.
      (torch.Tensor): An array of shape BxCxHxW, where C is the number
        of masks and H=W=256. These low res logits can be passed to
        a subsequent iteration as mask input.
    """
    sam_model.set_image(image)
    transformed_boxes = None
    if boxes is not None:
        boxes = torch.from_numpy(boxes)

        transformed_boxes = sam_model.transform.apply_boxes_torch(
            boxes.to(sam_model.device), image.shape[:2]
        )

    masks, scores, _ = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    masks = masks[:, 0, :, :]
    scores = scores[:, 0]
    return masks.cpu().numpy(), scores.cpu().numpy()
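
# Example usage (sketch; assumes a local SAM ViT-H checkpoint and that `image`
# is the RGB counterpart of the frame passed to `detect`; the checkpoint
# filename is an assumption):
#
#   from segment_anything import sam_model_registry
#
#   sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
#   sam_predictor = SamPredictor(sam)
#   masks, scores = segment(sam_predictor, image, boxes=detections.xyxy)
#   detections.mask = masks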


def draw_mask(mask, draw, random_color=False):
    """Draw a binary mask onto a PIL ImageDraw canvas as a translucent overlay."""
    if random_color:
        color = (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
            153,
        )
    else:
        color = (30, 144, 255, 153)

    nonzero_coords = np.transpose(np.nonzero(mask))

    # np.nonzero yields (row, col) pairs; PIL's point() expects (x, y), hence the reversal
    for coord in nonzero_coords:
        draw.point(coord[::-1], fill=color)
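
# Example usage (sketch): composite every predicted mask onto a PIL image.
#
#   from PIL import ImageDraw
#
#   image_pil = Image.fromarray(image).convert("RGBA")
#   mask_layer = Image.new("RGBA", image_pil.size, (0, 0, 0, 0))
#   draw = ImageDraw.Draw(mask_layer)
#   for m in masks:
#       draw_mask(m, draw, random_color=True)
#   result = Image.alpha_composite(image_pil, mask_layer)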