from typing import List

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline, CLIPProcessor, CLIPModel

# Header text rendered at the top of the Gradio page via gr.Markdown.
MARKDOWN = """
# Segment Anything Model + MetaCLIP
This is a demo of open-vocabulary image segmentation that combines
[Segment Anything Model](https://github.com/facebookresearch/segment-anything) and
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP).
"""
# Example rows for `gr.Examples`: [image URL, prompt, confidence threshold],
# in the same order as the inputs the examples are wired to.
EXAMPLES = [
    list(row)
    for row in (
        ("https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5),
        ("https://media.roboflow.com/notebooks/examples/dog.jpeg", "building", 0.5),
        ("https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket", 0.5),
        ("https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "coffee", 0.6),
    )
]
# Masks whose area is at most this fraction of the whole image are dropped
# before CLIP scoring (the comparison in `inference` is a strict `>`).
MIN_AREA_THRESHOLD = 0.01
# Prefer the GPU whenever PyTorch reports one is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# transformers mask-generation pipeline around the SAM checkpoint; produces
# the candidate masks that are later filtered with MetaCLIP.
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE,
)

# One checkpoint name for both model and processor so they cannot drift apart.
_METACLIP_CHECKPOINT = "facebook/metaclip-b32-400m"
CLIP_MODEL = CLIPModel.from_pretrained(_METACLIP_CHECKPOINT).to(DEVICE)
CLIP_PROCESSOR = CLIPProcessor.from_pretrained(_METACLIP_CHECKPOINT)

# Red masks drawn over the input photo at the annotator's default opacity.
SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.RED,
    color_lookup=sv.ColorLookup.INDEX,
)
# Fully opaque white masks; used to paint the masks onto a black canvas.
SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.WHITE,
    color_lookup=sv.ColorLookup.INDEX,
    opacity=1,
)


def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    """Generate unlabeled masks for the whole image with the SAM pipeline.

    Args:
        image_rgb_pil: RGB input image.

    Returns:
        ``sv.Detections`` whose ``mask`` stacks every mask returned by the
        mask-generation pipeline and whose ``xyxy`` boxes are derived from
        those masks. No class ids or confidences are attached.
    """
    # points_per_batch controls how many point prompts the pipeline runs
    # per forward pass (speed vs. memory trade-off).
    generated = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
    # presumably a list of (H, W) boolean masks -> stacked to (N, H, W); TODO confirm dtype
    masks = np.array(generated["masks"])
    boxes = sv.mask_to_xyxy(masks=masks)
    return sv.Detections(xyxy=boxes, mask=masks)


def run_clip(image_rgb_pil: Image.Image, text: List[str]) -> np.ndarray:
    """Score an image against candidate captions with MetaCLIP.

    Args:
        image_rgb_pil: RGB image to score.
        text: Candidate captions; the softmax is taken across them.

    Returns:
        Array of shape ``(num_images, len(text))`` holding, for each image,
        probabilities over ``text``. A single image gives one row.
    """
    inputs = CLIP_PROCESSOR(
        text=text,
        images=image_rgb_pil,
        return_tensors="pt",
        padding=True
    ).to(DEVICE)
    # Pure inference: disable autograd so activations are not kept alive for
    # a backward pass that never happens (this is called once per mask).
    with torch.no_grad():
        outputs = CLIP_MODEL(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1)
    return probs.cpu().numpy()


def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    """Gray out every pixel of ``image`` that lies outside ``mask``.

    Pixels where ``mask`` is truthy are copied from ``image``. All other
    pixels are replaced by the color (gray_value, gray_value, gray_value).
    Despite the name, the masked region itself is preserved; only the
    background is overwritten.

    Args:
        image: Image array. Its last axis must broadcast against the
            3-element fill color, so an (H, W, 3) array is expected.
        mask: Per-pixel selector matching ``image``'s first two axes (H, W).
        gray_value: Fill intensity, stored as uint8.

    Returns:
        A new array from ``np.where``. For a uint8 ``image`` it is uint8 with
        the same shape as ``image``.
    """
    fill_color = np.array([gray_value] * 3, dtype=np.uint8)
    # (H, W) -> (H, W, 1) so the selector broadcasts across color channels.
    keep = np.expand_dims(mask, axis=-1)
    return np.where(keep, image, fill_color)


def annotate(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    annotator: sv.MaskAnnotator
) -> Image.Image:
    """Draw ``detections`` onto a copy of ``image_rgb_pil`` with ``annotator``.

    supervision annotators operate on OpenCV-style BGR arrays, so the image
    is converted RGB -> BGR before annotating and back to RGB afterwards.

    Args:
        image_rgb_pil: RGB image to draw on. It is not modified, because
            ``np.array`` makes a copy.
        detections: Detections whose masks are drawn.
        annotator: Mask annotator that performs the drawing.

    Returns:
        A new RGB PIL image containing the annotations.
    """
    # Reversing the channel axis yields a negative-stride view. Copy it into a
    # C-contiguous buffer, because OpenCV-backed drawing code cannot write into
    # such views ("Layout of the output array is incompatible with cv::Mat").
    img_bgr_numpy = np.ascontiguousarray(np.array(image_rgb_pil)[:, :, ::-1])
    annotated_bgr_image = annotator.annotate(
        scene=img_bgr_numpy, detections=detections)
    return Image.fromarray(annotated_bgr_image[:, :, ::-1])


def filter_detections(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    prompt: str,
    confidence: float
) -> sv.Detections:
    """Keep only the detections that MetaCLIP matches to ``prompt``.

    For each mask, the image is cropped to the mask's bounding box and pixels
    outside the mask are grayed out. MetaCLIP then scores the crop against
    "a picture of {prompt}" vs. "a picture of background". A detection is
    kept when the prompt caption's probability is strictly greater than
    ``confidence``.

    Args:
        image_rgb_pil: Full RGB image the detections were produced from.
        detections: Candidate detections; ``mask`` must be populated.
        prompt: Free-text description of what to keep.
        confidence: Probability threshold for the prompt caption.

    Returns:
        The subset of ``detections`` that passed. This may be empty, including
        when ``detections`` itself is empty.
    """
    img_rgb_numpy = np.array(image_rgb_pil)
    text = [f"a picture of {prompt}", "a picture of background"]
    filtering_mask = []

    for xyxy, mask in zip(detections.xyxy, detections.mask):
        crop = sv.crop_image(image=img_rgb_numpy, xyxy=xyxy)
        mask_crop = sv.crop_image(image=mask, xyxy=xyxy)
        masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
        probs = run_clip(image_rgb_pil=Image.fromarray(masked_crop), text=text)
        # Row 0 is the single crop; column 0 is the prompt caption.
        filtering_mask.append(probs[0][0] > confidence)

    # Explicit bool dtype: with zero detections, np.array([]) would be float64,
    # which is not a valid index and would make the selection below raise.
    return detections[np.array(filtering_mask, dtype=bool)]


def inference(
    image_rgb_pil: Image.Image,
    prompt: str,
    confidence: float
) -> List[Image.Image]:
    """Segment the regions of an image that match a text prompt.

    Pipeline: SAM proposes masks; masks covering MIN_AREA_THRESHOLD or less
    of the image are dropped; MetaCLIP keeps the masks that match
    ``prompt`` above ``confidence``.

    Args:
        image_rgb_pil: RGB input image (``None`` if the user submitted none).
        prompt: Free-text description of what to segment.
        confidence: Threshold on MetaCLIP's prompt-vs-background probability.

    Returns:
        Two images: the input with the kept masks drawn by
        SEMITRANSPARENT_MASK_ANNOTATOR, and a black canvas of the same size
        with the kept masks drawn by SOLID_MASK_ANNOTATOR.

    Raises:
        gr.Error: if no image was provided.
    """
    # Gradio passes None when the form is submitted without an image. Surface
    # a readable message instead of an AttributeError on `.size`.
    if image_rgb_pil is None:
        raise gr.Error("Please upload an image before submitting.")

    width, height = image_rgb_pil.size
    area = width * height

    detections = run_sam(image_rgb_pil)
    # Discard tiny masks before the comparatively expensive per-mask CLIP pass.
    detections = detections[detections.area / area > MIN_AREA_THRESHOLD]
    detections = filter_detections(
        image_rgb_pil=image_rgb_pil,
        detections=detections,
        prompt=prompt,
        confidence=confidence)

    blank_image = Image.new("RGB", (width, height), "black")
    return [
        annotate(
            image_rgb_pil=image_rgb_pil,
            detections=detections,
            annotator=SEMITRANSPARENT_MASK_ANNOTATOR),
        annotate(
            image_rgb_pil=blank_image,
            detections=detections,
            annotator=SOLID_MASK_ANNOTATOR)
    ]


# Gradio UI. Components are laid out in the order they are created inside the
# Blocks context, so the statement order below is what defines the page:
# header, then a row with the inputs column (left) and result gallery (right),
# then a row of examples.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            # type='pil' hands `inference` a PIL image; image_mode='RGB'
            # matches the RGB handling in the pipeline functions.
            input_image = gr.Image(
                image_mode='RGB', type='pil', height=500)
            prompt_text = gr.Textbox(
                label="Prompt", value="dog")
            # `filter_detections` keeps a mask when P(prompt caption) > this
            # value in a two-way softmax against "a picture of background",
            # so the 0.5 minimum means "prompt must beat background".
            confidence_slider = gr.Slider(
                label="Confidence", minimum=0.5, maximum=1.0, step=0.05, value=0.6)
            submit_button = gr.Button("Submit")
        # Receives the two images returned by `inference`.
        gallery = gr.Gallery(label="Result", object_fit="scale-down", preview=True)
    with gr.Row():
        # Example rows must follow the same order as `inputs`.
        # cache_examples=True asks Gradio to precompute `inference` outputs
        # for every example, which runs the full model pipeline ahead of time.
        gr.Examples(
            examples=EXAMPLES,
            fn=inference,
            inputs=[input_image, prompt_text, confidence_slider],
            outputs=[gallery],
            cache_examples=True,
            run_on_click=True
        )

    # Main entry point: run the pipeline on the current form values.
    submit_button.click(
        inference,
        inputs=[input_image, prompt_text, confidence_slider],
        outputs=gallery)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module (tests, tooling) does not block on a running server.
if __name__ == "__main__":
    demo.launch(debug=False, show_error=True)