from typing import List

import os
import cv2
import supervision as sv
import numpy as np
import gradio as gr
import torch

from transformers import pipeline
from PIL import Image

# Adapter class exposing the segment-anything style mask-generator interface
# on top of a Hugging Face "mask-generation" pipeline.
class SamAutomaticMaskGenerator:
    """Wrap a HF ``mask-generation`` pipeline behind a ``generate()`` API.

    The wrapper converts the pipeline's raw output into the list-of-dicts
    format that ``sv.Detections.from_sam`` expects (each entry carrying a
    boolean ``'segmentation'`` mask and its integer ``'area'``).
    """

    def __init__(self, sam_pipeline):
        # Callable pipeline: takes a PIL image, returns a dict whose
        # 'masks' entry is a sequence of per-object binary masks.
        self.sam_pipeline = sam_pipeline

    def generate(self, image_rgb):
        """Run SAM on an RGB numpy image.

        Parameters
        ----------
        image_rgb : np.ndarray
            HxWx3 uint8 RGB image.

        Returns
        -------
        list[dict]
            One dict per detected object with ``'segmentation'`` (bool HxW
            mask) and ``'area'`` (pixel count) keys — the format required
            by ``sv.Detections.from_sam``.
        """
        # The HF pipeline expects a PIL image, not a numpy array.
        image_pil = Image.fromarray(image_rgb)
        outputs = self.sam_pipeline(image_pil, points_per_batch=32)
        # BUG FIX: the previous implementation returned a bare
        # ``np.array(outputs['masks'], dtype=np.uint8)``, but
        # ``sv.Detections.from_sam`` sorts its input by ``x['area']`` and
        # stacks ``x['segmentation']`` — it needs dicts, not an ndarray.
        sam_result = []
        for mask in outputs["masks"]:
            seg = np.asarray(mask, dtype=bool)
            sam_result.append({"segmentation": seg, "area": int(seg.sum())})
        return sam_result

# SAM model configuration: prefer GPU when available — the ViT-Large SAM
# checkpoint is very slow on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE: instantiating the pipeline downloads/loads the model at import time.
sam_pipeline = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE
)
# Example image URLs offered in the Gradio UI (one inner list per example,
# matching the single-input signature of process_image).
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg"],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg"]
]


# Single shared generator instance reused across all requests.
mask_generator = SamAutomaticMaskGenerator(sam_pipeline)

# Process an uploaded image: run SAM segmentation and overlay the masks.
def process_image(image_pil):
    """Segment *image_pil* with SAM and return (original, annotated) images.

    Parameters
    ----------
    image_pil : PIL.Image.Image
        Image supplied by the Gradio component (any mode).

    Returns
    -------
    tuple[PIL.Image.Image, PIL.Image.Image]
        The RGB-normalized input and the mask-annotated result.
    """
    # ROBUSTNESS FIX: Gradio may hand over RGBA (PNG with alpha) or
    # grayscale images; force 3-channel RGB so the RGB->BGR conversion
    # below cannot fail or mis-handle channels.
    image_pil = image_pil.convert("RGB")
    image_rgb = np.array(image_pil)
    # supervision's annotator draws on OpenCV-style BGR arrays.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

    # Generate masks and annotate the image (INDEX lookup assigns each
    # mask its own color).
    sam_result = mask_generator.generate(image_rgb)
    mask_annotator = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
    detections = sv.Detections.from_sam(sam_result=sam_result)
    annotated_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)

    # Convert back to RGB and to PIL for display in Gradio.
    annotated_image_rgb = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
    annotated_image_pil = Image.fromarray(annotated_image_rgb)

    return image_pil, annotated_image_pil

# Build the Gradio interface: image input on the left, original and
# segmented outputs on the right, plus clickable examples.
with gr.Blocks() as demo:
    gr.Markdown("# SAM - Segmentación de Imágenes")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Cargar Imagen")
            submit_button = gr.Button("Segmentar")
        with gr.Column():
            original_image = gr.Image(type="pil", label="Imagen Original")
            segmented_image = gr.Image(type="pil", label="Imagen Segmentada")
    
    # Wire the button: process_image returns (original, annotated), which
    # maps positionally onto the two output components.
    submit_button.click(
        process_image, 
        inputs=input_image, 
        outputs=[original_image, segmented_image]
    )
    with gr.Row():
        # run_on_click=True segments an example immediately when clicked;
        # cache_examples=False avoids running the model at startup.
        gr.Examples(
            examples=EXAMPLES,
            fn=process_image,
            inputs=[input_image],
            outputs=[original_image, segmented_image],
            cache_examples=False,
            run_on_click=True
        )
# debug=True surfaces tracebacks in the browser while developing.
demo.launch(debug=True)