import os
import requests

# Disable the TorchScript JIT. This is set before `torch` is imported below;
# NOTE(review): presumably PyTorch reads PYTORCH_JIT at import time, so this
# line must stay above the torch import — confirm before reordering imports.
os.environ["PYTORCH_JIT"] = "0"

from einops import rearrange
import gradio as gr 
import numpy as np
import spaces
import torch 
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageOps
from transformers import AutoModel, CLIPImageProcessor
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from segment_anything.modeling.image_encoder import ImageEncoderViT


class RADIOVenc(nn.Module):
    """SAM image encoder that runs RADIO and reuses SAM's original neck.

    Meant to be assigned to ``sam.image_encoder``. RADIO's ``"sam"`` output
    tokens are folded from a flat sequence back into a 2-D grid (the grid is
    H // 16 by W // 16) and then passed through the neck taken from the
    original SAM ``ImageEncoderViT``.
    """

    def __init__(self, radio: nn.Module, img_enc: ImageEncoderViT, img_size: int = 1024):
        """
        Args:
            radio: RADIO model; called on the image batch, and its output is
                indexed with ``"sam"``.
            img_enc: SAM image encoder; only its ``neck`` is kept.
            img_size: stored as ``self.img_size``; not used by ``forward``
                (presumably read by SAM code that expects the attribute).
        """
        super().__init__()
        # Submodules are registered in this order: RADIO first, then the SAM neck.
        self.radio = radio
        self.neck = img_enc.neck
        self.img_size = img_size
        # If the conditioner reports no dtype, forward runs RADIO under bf16 autocast.
        self.dtype = radio.input_conditioner.dtype

    def forward(self, x: torch.Tensor):
        """Encode ``x`` (..., H, W) into SAM-neck features.

        Assumes H and W are multiples of 16, since the token grid is taken to
        be (H // 16) x (W // 16).
        """
        height, width = x.shape[-2:]
        target_dtype = self.dtype
        autocast_on = target_dtype is None
        if not autocast_on:
            x = x.to(dtype=target_dtype)

        with torch.autocast('cuda', dtype=torch.bfloat16, enabled=autocast_on):
            radio_out = self.radio(x)
        sam_tokens = radio_out["sam"].features

        # Tokens are flattened row-major over the patch grid; restore (B, C, rows, cols).
        grid = rearrange(
            sam_tokens,
            'b (h w) c -> b c h w',
            h=height // 16,
            w=width // 16,
        )
        return self.neck(grid)

    
def download_file(url, save_path):
    """Download ``url`` to ``save_path`` unless ``save_path`` already exists.

    The response body is streamed into ``<save_path>.part``, which is renamed
    to ``save_path`` only after the whole body has been written. An
    interrupted download therefore never leaves a truncated file at
    ``save_path``; the existence check would otherwise accept such a file as
    complete on every later run.

    A non-200 response is reported with a message and nothing is written.
    Network errors raised by ``requests`` propagate to the caller.
    """
    # Check if the file already exists
    if os.path.exists(save_path):
        print(f"File already exists at {save_path}. Skipping download.")
        return

    print(f"Downloading from {url}")

    tmp_path = f"{save_path}.part"
    # `timeout` bounds each connect/read wait (not the total transfer time), so a
    # stalled server cannot hang startup forever; `with` releases the connection.
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            print(f"Failed to download file. HTTP Status Code: {response.status_code}")
            return
        try:
            with open(tmp_path, 'wb') as file:
                # 1 MiB chunks: model checkpoints are large, and 1 KiB chunks
                # mean an excessive number of tiny writes.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:  # filter out keep-alive new chunks
                        file.write(chunk)
        except BaseException:
            # Remove the partial file so a retry starts clean, then propagate.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    # Atomic on the same filesystem: save_path is either absent or complete.
    os.replace(tmp_path, save_path)
    print(f"File downloaded successfully and saved as {save_path}")
    

# Image-processor config from the Hugging Face RADIO-L repo; infer_radio uses it
# to turn the padded PIL image into `pixel_values`.
hf_repo = "nvidia/RADIO-L"
image_processor = CLIPImageProcessor.from_pretrained(hf_repo)

model_version = "radio_v2.5-l" # for RADIOv2.5-L model (ViT-L/16)

# Load RADIO from the NVlabs/RADIO torch.hub entry point with the 'sam' adaptor
# enabled: RADIOVenc reads `output["sam"]`, while infer_radio reads `["backbone"]`.
model = torch.hub.load(
    'NVlabs/RADIO',
    'radio_model',
    version=model_version,
    progress=True,
    skip_validation=True,
    adaptor_names='sam')
model.eval()
    
# Original SAM ViT-H checkpoint; download_file skips the download when the file
# already exists locally.
local_sam_checkpoint_path = "sam_vit_h_4b8939.pth"
download_file("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", local_sam_checkpoint_path)    
sam = sam_model_registry["vit_h"](checkpoint=local_sam_checkpoint_path)
# NOTE(review): presumably backs `model.patch_size`, which infer_radio uses to
# size the patch grid; RADIOVenc.forward separately hard-codes 16. Confirm.
model._patch_size = 16
# Swap SAM's image encoder for RADIO + SAM's original neck (see RADIOVenc), so
# SAM's image encoding goes through RADIO.
# NOTE(review): RADIOVenc reads `model.input_conditioner.dtype` here, before
# make_preprocessor_external() runs below; keep this order unless verified.
sam.image_encoder = RADIOVenc(model, sam.image_encoder, img_size=1024)
# infer_radio applies `conditioner` to the inputs explicitly; presumably this call
# also removes that normalization from the model's own forward — confirm.
conditioner = model.make_preprocessor_external()    
# Reuse the conditioner's normalization stats for SAM, scaled by 255.
# Presumably norm_mean/norm_std are in [0, 1] units while SAM normalizes
# 0-255 pixels — confirm.
sam.pixel_mean = conditioner.norm_mean * 255
sam.pixel_std = conditioner.norm_std * 255


def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    """Fit a 3-component PCA projection plus robust per-channel color ranges.

    Args:
        features: (N, C) feature matrix.
        m: inlier threshold, in multiples of each projected channel's median
            absolute deviation (MAD) from its median. Only values strictly
            below the threshold contribute to that channel's min/max.
        remove_first_component: if True, refit the PCA on only the rows whose
            first projected component (min-max normalized over all rows) is
            below 0.2, and compute the color ranges over those rows.

    Returns:
        ``(reduction_mat, rgb_min, rgb_max)``: the (C, 3) projection matrix and
        the (3,) per-channel min/max, cast to ``reduction_mat``'s dtype and
        device. A channel with no inliers (e.g. its MAD is zero, or ``m`` is
        0) falls back to its own full min/max over all rows.
    """
    assert len(features.shape) == 2, "features should be (N, C)"
    reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
    colors = features @ reduction_mat
    if remove_first_component:
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
    else:
        fg_mask = torch.ones_like(colors[:, 0]).bool()

    fg_colors = colors[fg_mask]
    d = torch.abs(fg_colors - torch.median(fg_colors, dim=0).values)
    mdev = torch.median(d, dim=0).values
    # A zero MAD yields inf/nan scores here; those compare False below and are
    # treated as outliers.
    s = d / mdev

    lows, highs = [], []
    for channel in range(3):
        inliers = fg_colors[s[:, channel] < m, channel]
        if inliers.numel() == 0:
            # No robust inliers: use this channel's full range rather than a
            # range shared across all channels.
            inliers = colors[:, channel]
        lows.append(inliers.min())
        highs.append(inliers.max())
    rgb_min = torch.stack(lows)
    rgb_max = torch.stack(highs)

    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)


def get_pca_map(
    feature_map: torch.Tensor,
    img_size,
    interpolation="bicubic",
    return_pca_stats=False,
    pca_stats=None,
):
    """Project a feature map onto 3 PCA components and render it as an RGB array.

    Args:
        feature_map: (1, h, w, C) features of a single image; an unbatched
            (h, w, C) map is also accepted.
        img_size: output spatial size, passed as ``size`` to ``F.interpolate``
            (e.g. ``(H, W)``).
        interpolation: ``F.interpolate`` mode used to resize to ``img_size``.
        return_pca_stats: if True, also return the stats that were used.
        pca_stats: optional ``(reduct_mat, color_min, color_max)``, e.g. from
            a previous call; when None they are fitted with ``get_robust_pca``.

    Returns:
        An (H, W, 3) numpy array; values are clamped to [0, 1] before
        resizing. If ``return_pca_stats`` is set, the return value is
        ``(array, (reduct_mat, color_min, color_max))``.
    """
    if feature_map.ndim == 3:
        # (h, w, C) -> (1, h, w, C). Checking the rank, not shape[0], keeps this
        # correct when h happens to be 1.
        feature_map = feature_map[None]
    if pca_stats is None:
        reduct_mat, color_min, color_max = get_robust_pca(
            feature_map.reshape(-1, feature_map.shape[-1])
        )
    else:
        reduct_mat, color_min, color_max = pca_stats
    pca_color = feature_map @ reduct_mat
    # Map each channel's robust range to [0, 1]; values outside it are clipped.
    pca_color = (pca_color - color_min) / (color_max - color_min)
    pca_color = pca_color.clamp(0, 1)
    pca_color = F.interpolate(
        pca_color.permute(0, 3, 1, 2),
        size=img_size,
        mode=interpolation,
    ).permute(0, 2, 3, 1)
    pca_color = pca_color.cpu().numpy().squeeze(0)
    if return_pca_stats:
        return pca_color, (reduct_mat, color_min, color_max)
    return pca_color


def pad_image_to_multiple_of(image, multiple=16):
    """Pad a PIL image with black borders so both sides become multiples of ``multiple``.

    The padding on each axis is split between the two opposite sides; when it
    is odd, the extra pixel goes to the right / bottom.
    """
    def _padding(length):
        # Distance from `length` up to the next multiple of `multiple`.
        return (length + multiple - 1) // multiple * multiple - length

    img_w, img_h = image.size
    extra_x = _padding(img_w)
    extra_y = _padding(img_h)

    left_pad = extra_x // 2
    top_pad = extra_y // 2
    border = (left_pad, top_pad, extra_x - left_pad, extra_y - top_pad)

    return ImageOps.expand(image, border, fill='black')


def center_crop_resize(image, size=(1024, 1024)):
    """Crop the largest centered square out of a PIL image and resize it.

    Args:
        image: PIL image.
        size: target ``(width, height)``; resampling uses LANCZOS.

    Returns:
        The cropped and resized PIL image.
    """
    img_w, img_h = image.size
    side = min(img_w, img_h)
    # Symmetric box around the center; coordinates may be half-pixel floats.
    box = (
        (img_w - side) / 2,
        (img_h - side) / 2,
        (img_w + side) / 2,
        (img_h + side) / 2,
    )
    return image.crop(box).resize(size, Image.LANCZOS)


def visualize_anns(orig_image: np.ndarray, anns):
    """Blend SAM mask annotations over an image.

    Each annotation gets a random color at 0.35 opacity. Mask pixels that have
    a non-mask pixel (or the image edge) within their 5x5 neighborhood are
    drawn at full opacity, which outlines the mask. Masks are painted from
    largest to smallest area, so smaller masks end up on top.

    Args:
        orig_image: image to draw on. It is divided by 255 and blended with
            3-channel colors, so it is expected to be (H, W, 3) in 0-255.
        anns: dicts with a boolean ``'segmentation'`` array and an ``'area'``.

    Returns:
        The blended uint8 image, or ``orig_image`` itself if ``anns`` is empty.
    """
    if len(anns) == 0:
        return orig_image

    by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    seg_shape = by_area[0]['segmentation'].shape

    # RGBA canvas: white color, fully transparent until a mask is painted.
    rgba = np.ones((seg_shape[0], seg_shape[1], 4), dtype=np.float32)
    rgba[:, :, 3] = 0

    box_filter = torch.ones(1, 1, 5, 5, dtype=torch.float32)

    for ann in by_area:
        seg = ann['segmentation']
        fill = np.concatenate([np.random.random(3), [0.35]])

        # 5x5 neighbor count; anything below 25 touches a non-mask pixel or the border.
        seg_t = torch.as_tensor(seg).reshape(1, 1, *seg.shape).float()
        neighbour_count = F.conv2d(seg_t, box_filter, padding=2)
        on_edge = (neighbour_count < 25).flatten(0, 2).numpy()

        rgba[seg] = fill
        rgba[seg & on_edge, 3] *= 1.0 / 0.35

    rgb = rgba[..., :3]
    alpha = rgba[..., -1:]
    base = orig_image.astype(np.float32) / 255
    blended = alpha * rgb + (1 - alpha) * base
    return (blended * 255).astype(np.uint8)



@spaces.GPU
def infer_radio(image):
    """Run RADIO on ``image`` and build the Gradio outputs.

    Args:
        image: PIL image supplied by the Gradio input (None if nothing was provided).

    Returns:
        ``(pca_viz, overlay, shape_str)``:
        ``pca_viz`` is the PCA projection of RADIO backbone features at the
        padded image size; ``overlay`` shows automatic masks from the
        RADIO-backed SAM, drawn over the input; ``shape_str`` is the string
        form of the backbone feature tensor shape.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    # visualize_anns blends the image with 3-channel colors, so normalize the
    # mode (e.g. RGBA / L uploads) up front.
    image = image.convert("RGB")

    model.cuda()
    conditioner.cuda()
    sam.cuda()
    sam_generator = SamAutomaticMaskGenerator(sam, output_mode="binary_mask")

    # PCA feature visualization.
    # The patch grid below is derived from the padded size.
    # NOTE(review): this assumes image_processor does not resize the padded image — confirm.
    padded_image = pad_image_to_multiple_of(image, multiple=256)
    width, height = padded_image.size
    pixel_values = image_processor(images=padded_image, return_tensors='pt').pixel_values
    pixel_values = pixel_values.to(torch.bfloat16).cuda()

    # Inference only: skip building an autograd graph for the RADIO forward pass.
    with torch.no_grad():
        pixel_values = conditioner(pixel_values)
        _, features = model(pixel_values)["backbone"]

    num_rows = height // model.patch_size
    num_cols = width // model.patch_size

    features = features.detach()
    features = rearrange(features, 'b (h w) c -> b h w c', h=num_rows, w=num_cols).float()

    pca_viz = get_pca_map(features, (height, width), interpolation='bilinear')

    # SAM mask visualization on the original (unpadded) image.
    image_array = np.array(image)
    print("image size", image_array.shape)
    masks = sam_generator.generate(image_array)
    overlay = visualize_anns(image_array, masks)

    return pca_viz, overlay, f"{features.shape}"



title = """RADIO: Reduce All Domains Into One"""

# Markdown shown above the demo.
description = """
# RADIO

[AM-RADIO](https://github.com/NVlabs/RADIO) is a framework to distill Large Vision Foundation models into a single one.
RADIO, a new vision foundation model, excels across visual domains, serving as a superior replacement for vision backbones.
Integrating CLIP variants, DINOv2, and SAM through distillation, it preserves unique features like text grounding and segmentation correspondence.
Outperforming teachers in ImageNet zero-shot (+6.8%), kNN (+2.39%), and linear probing segmentation (+3.8%) and vision-language models (LLaVA 1.5 up to 1.5%), it scales to any resolution and supports non-square images.

# Instructions

Paste an image into the input box or pick one from the gallery of examples and then click the "Submit" button.
The RADIO backbone features are processed with a PCA projection to 3 channels and displayed as RGB channels.
The SAM features are processed using the SAM decoder and shown as an overlay on top of the input image.
"""

# infer_radio receives the upload as a PIL image.
inputs = [
    gr.Image(type="pil")
]

# Order matches the tuple returned by infer_radio: (pca_viz, overlay, shape string).
outputs = [
    gr.Image(label="PCA Feature Visualization"),
    gr.Image(label="SAM Masks"),
    gr.Textbox(label="Feature Shape"),
]

# Create the Gradio interface
demo = gr.Interface(
    fn=infer_radio,
    inputs=inputs,
    examples="./samples/",
    outputs=outputs,
    title=title,
    description=description,
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch()