import json
import os
from functools import lru_cache
from io import BytesIO

import cv2
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image

from unidepth.models import UniDepthV2


# Load model configurations and initialize model
def load_model(config_path, model_path, encoder, device):
    """Build a UniDepthV2 model from a JSON config and load checkpoint weights.

    Args:
        config_path: Path to the JSON file whose parsed contents are passed
            to the ``UniDepthV2`` constructor.
        model_path: Path to a file readable by ``torch.load``; its ``'model'``
            entry is loaded as the state dict (strictly: every key must match).
        encoder: Not read by this function; kept for interface compatibility.
        device: Device string/object used both to map the checkpoint tensors
            and as the model's final device.

    Returns:
        The model moved to ``device`` and switched to eval mode.
    """
    # Explicit encoding: open()'s default is locale-dependent (PEP 597).
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    model = UniDepthV2(config)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model'], strict=True)

    return model.to(device).eval()

# Inference function
@lru_cache(maxsize=1)
def _load_model_cached(config_path, model_path, encoder, device, checkpoint_mtime):
    """Return ``load_model(...)``, memoised across requests.

    Only the most recently used model is kept (``maxsize=1``). ``checkpoint_mtime``
    is not used in the body: it is part of the cache key so a checkpoint that is
    rewritten on disk gets loaded again instead of serving stale weights.
    """
    return load_model(config_path, model_path, encoder, device)


def _image_to_chw_tensor(image, device):
    """Convert an (H, W), (H, W, 3) or (H, W, 4) image array to a (3, H, W) tensor on ``device``."""
    arr = np.array(image)  # copy, so the caller's buffer is never shared with torch
    if arr.ndim == 2:
        # Grayscale: replicate the single channel into RGB.
        arr = np.stack([arr] * 3, axis=-1)
    elif arr.ndim == 3 and arr.shape[2] == 4:
        # RGBA: drop the alpha channel.
        arr = arr[:, :, :3]
    if arr.ndim != 3 or arr.shape[2] != 3:
        raise ValueError(f"Expected an RGB image, got an array of shape {arr.shape}")
    return torch.from_numpy(np.ascontiguousarray(arr)).permute(2, 0, 1).to(device)


def _colorize_depth(depth, cmap):
    """Map ``depth`` linearly from its min..max onto ``cmap``; return an (H, W, 3) uint8 array."""
    min_depth = float(depth.min())
    span = float(depth.max()) - min_depth
    # A constant depth map would otherwise divide by zero and yield NaN colours.
    normalized = (depth - min_depth) / span if span > 0 else np.zeros_like(depth)
    return (cmap(normalized)[:, :, :3] * 255).astype(np.uint8)


def _render_colorbar(cmap, min_depth, max_depth, width):
    """Render a horizontal depth colorbar as an RGB PIL image exactly ``width`` pixels wide."""
    # Use the object-oriented Figure API rather than pyplot: no global figure
    # state and no GUI backend involvement when called from server threads.
    fig = Figure(figsize=(6, 0.4))
    ax = fig.subplots()
    fig.subplots_adjust(bottom=0.5)

    norm = matplotlib.colors.Normalize(vmin=min_depth, vmax=max_depth)
    mappable = matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    fig.colorbar(mappable, cax=ax, orientation='horizontal', label='Depth (meters)')

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
    buf.seek(0)
    with Image.open(buf) as png:
        colorbar = png.convert('RGB')

    # Scale to the depth map's width so the bar is neither cropped nor dwarfed.
    height = max(1, round(colorbar.height * width / colorbar.width))
    return colorbar.resize((width, height), Image.LANCZOS)


def depth_estimation(image, model_path, encoder='vits'):
    """Estimate metric depth for ``image`` and render it as a colour-mapped picture.

    The output is the depth map coloured with the ``Spectral`` colormap, with a
    horizontal colorbar labelled "Depth (meters)" pasted underneath it.

    Args:
        image: Input picture as delivered by ``gr.Image(type="numpy")``.
            Grayscale (H, W) and RGBA (H, W, 4) arrays are converted to RGB.
        model_path: Ignored -- the checkpoint is always read from
            ``checkpoint/latest.pth``; the parameter is kept for compatibility.
        encoder: Forwarded to ``load_model``.

    Returns:
        An RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: If no image is supplied, the checkpoint file is missing, or
            inference/rendering fails. Gradio displays the message in the UI
            (returning a string to an Image output cannot be rendered).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config_path = 'configs/config_v2_vits14.json'

    # NOTE: the checkpoint location is fixed; the model_path argument is unused.
    model_path = "checkpoint/latest.pth"
    if not os.path.exists(model_path):
        raise gr.Error("Model checkpoint not found. Please upload a valid model path.")

    try:
        model = _load_model_cached(
            config_path, model_path, encoder, device, os.path.getmtime(model_path)
        )

        rgb = _image_to_chw_tensor(image, device)  # C, H, W
        with torch.no_grad():
            predictions = model.infer(rgb)
        # .cpu() is required: .numpy() cannot read tensors that live on CUDA.
        depth = predictions["depth"].squeeze().cpu().numpy()

        cmap = matplotlib.colormaps.get_cmap('Spectral')
        depth_img = Image.fromarray(_colorize_depth(depth, cmap))
        colorbar = _render_colorbar(
            cmap, float(depth.min()), float(depth.max()), depth_img.width
        )

        # Stack the depth map on top of the colorbar on a white canvas.
        canvas = Image.new(
            'RGB', (depth_img.width, depth_img.height + colorbar.height), (255, 255, 255)
        )
        canvas.paste(depth_img, (0, 0))
        canvas.paste(colorbar, (0, depth_img.height))
        return canvas

    except Exception as e:
        raise gr.Error(f"Error occurred: {str(e)}") from e

# Gradio Interface
def main():
    """Build the Gradio interface for depth estimation and launch the web server."""

    def _predict(image, encoder):
        # Gradio passes input values positionally. Without this wrapper the
        # dropdown value would land in depth_estimation's ``model_path``
        # parameter instead of ``encoder``.
        # NOTE: depth_estimation uses a fixed ViT-S/14 config, so the encoder
        # choice does not currently select a different model.
        return depth_estimation(image, "checkpoint/latest.pth", encoder=encoder)

    iface = gr.Interface(
        fn=_predict,
        inputs=[
            gr.Image(type="numpy", label="Input Image"),
            gr.Dropdown(choices=['vits', 'vitb', 'vitl', 'vitg'], value='vits', label='Encoder'),
        ],
        outputs=[
            gr.Image(type="pil", label="Predicted Depth"),
        ],
        title="Metric Depth Estimation",
        description="Upload an image to get its estimated depth map using UniDepthV2.",
    )

    iface.launch()


if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    main()