from datasets import load_dataset
from functools import partial
from pandas import DataFrame
import earthview as ev
import utils
import gradio as gr 
import tqdm
import os

# Debug switch selecting where the viewer's items come from:
#   False     -> ev.load_dataset(...) (normal operation)
#   "random"  -> random numpy arrays and dummy metadata, no dataset access
#   "samples" -> ev.load_parquet(subset, batch_size=...)
DEBUG = False      # False, "random", "samples"

if DEBUG == "random":
    # numpy is only needed to synthesise random images in "random" mode.
    import numpy as np

def open_dataset(dataset, subset, split, batch_size, shard, only_rgb, state):
    """Open ``subset`` of ``dataset`` and fetch the first batch of images.

    Args:
        dataset: dataset name forwarded to ``ev.load_dataset``.
        subset: subset name (e.g. ``"satellogic"``).
        split: dataset split, e.g. ``"train"``.
        batch_size: number of items to fetch for the first batch.
        shard: shard index to load, or ``-1`` to load the whole dataset.
        only_rgb: forwarded to ``get_images``; when true only RGB images
            are shown.
        state: session dict; its ``"subset"`` and ``"dsi"`` (dataset
            iterator) entries are (re)set here.

    Returns:
        ``(shard_slider_update, images, metadata_frame, state)`` — the shard
        slider update, followed by the two values produced by ``get_images``,
        followed by the updated state.

    Raises:
        ValueError: if the module-level ``DEBUG`` switch has an unknown value.
    """
    nshards = ev.get_nshards(subset)

    # Gradio number/slider widgets may deliver floats. Shard indices and
    # batch sizes must be integers (range()/trange() reject floats).
    shard = int(shard)
    batch_size = int(batch_size)

    shards = None if shard == -1 else [shard]

    if DEBUG == "random":
        ds = range(batch_size)
    elif DEBUG == "samples":
        ds = ev.load_parquet(subset, batch_size=batch_size)
    elif not DEBUG:
        ds = ev.load_dataset(subset, dataset=dataset, split=split, shards=shards, cache_dir="dataset")
    else:
        # Previously an unknown mode left ``ds`` unbound and failed later with
        # an UnboundLocalError; fail loudly with the actual cause instead.
        raise ValueError(f"Unknown DEBUG mode: {DEBUG!r}")

    state["subset"] = subset
    state["dsi"] = iter(ds)
    return (
        gr.update(label=f"Shard (max {nshards})", value=shard, maximum=nshards),
        *get_images(batch_size, only_rgb, state),
        state,
    )
    
def _item_images(subset, item, only_rgb):
    """Return the images to display for one item decoded by ``ev.item_to_images``.

    The ``"rgb"`` images come first. The extra bands are appended only when
    ``only_rgb`` is false. ``sentinel_1`` is shown via its ``"10m"`` images
    regardless of ``only_rgb``. An unrecognised subset yields no images.
    """
    images = []
    if subset == "satellogic":
        images.extend(item["rgb"])
        if not only_rgb:
            images.extend(item["1m"])
    elif subset == "sentinel_1":
        images.extend(item["10m"])
    elif subset == "sentinel_2":
        images.extend(item["rgb"])
        if not only_rgb:
            images.extend(item["10m"])
            images.extend(item["20m"])
            images.extend(item["scl"])
    elif subset == "neon":
        images.extend(item["rgb"])
        if not only_rgb:
            images.extend(item["chm"])
            images.extend(item["1m"])
    return images


def get_images(batch_size, only_rgb, state):
    """Read the next ``batch_size`` items from the dataset opened in ``state``.

    Args:
        batch_size: number of items to read. It is coerced to ``int`` because
            ``gr.Number`` may deliver a float.
        only_rgb: when true, only the RGB images of each item are returned.
        state: session dict populated by ``open_dataset`` with ``"subset"``
            and ``"dsi"`` (the dataset iterator).

    Returns:
        ``(images, metadata_frame)``: a flat list of images for the gallery,
        and a DataFrame with one metadata row per item read. Fewer than
        ``batch_size`` rows are returned once the iterator is exhausted.

    Raises:
        gr.Error: if no dataset has been loaded into ``state`` yet.
    """
    try:
        subset = state["subset"]
    except KeyError:
        # The KeyError is an implementation detail; show only the user message.
        raise gr.Error("You need to load a Dataset first") from None

    images = []
    metadatas = []

    for _ in tqdm.trange(int(batch_size), desc="Getting images"):
        if DEBUG == "random":
            images.append(np.random.randint(0,255,(384,384,3)))
            if not only_rgb:
                images.append(np.random.randint(0,255,(100,100,3)))
            metadatas.append({"bounds":[[1,1,4,4]], })
            continue

        try:
            item = next(state["dsi"])
        except StopIteration:
            break
        item = ev.item_to_images(subset, item)
        images.extend(_item_images(subset, item, only_rgb))

        metadata = item["metadata"]
        map_link = utils.get_google_map_link(item, subset)
        # "_blank" is the keyword that opens a new tab. The former
        # "about:_blank" was treated as a named window and reused for every
        # click. noopener prevents the map page from scripting this one.
        metadata["map"] = f'<a href="{map_link}" target="_blank" rel="noopener noreferrer">🧭</a>'
        metadatas.append(metadata)

    return images, DataFrame(metadatas)

def update_shape(columns):
    """Return a Gradio update that sets a component's ``columns`` property."""
    changes = {"columns": columns}
    return gr.update(**changes)

def new_state():
    """Create a ``gr.State`` component initialised with an empty dict."""
    initial = {}
    return gr.State(value=initial)

if __name__ == "__main__":
    with gr.Blocks(title="EarthView Viewer", fill_height = True) as demo:
        # Dict shared by open_dataset/get_images ("subset", "dsi").
        state = new_state()

        gr.Markdown(f"# Viewer for [{ev.DATASET}](https://huggingface.co/datasets/satellogic/EarthView) Dataset")
        # Components created with render=False are referenced by event handlers
        # below before they are placed in the layout via .render().
        # precision=0 makes gr.Number deliver ints, which range()/trange()
        # and shard indexing require.
        batch_size = gr.Number(10, label = "Batch Size", precision=0, render=False)
        shard = gr.Slider(label="Shard", minimum=0, maximum=10000, step=1, render=False)
        table = gr.DataFrame(render = False, datatype="html")

        gallery = gr.Gallery(
            label=ev.DATASET,
            interactive=False,
            object_fit="scale-down",
            columns=5, render=False)

        with gr.Row():
            dataset = gr.Textbox(label="Dataset", value=ev.DATASET, interactive=False)
            subset = gr.Dropdown(choices=ev.get_subsets(), label="Subset", value="satellogic", )
            split = gr.Textbox(label="Split", value="train")
            initial_shard = gr.Number(label = "Initial shard", value=10, precision=0, info="-1 for whole dataset")
            only_rgb = gr.Checkbox(label="Only RGB", value=True)

            gr.Button("Load (minutes)").click(
                open_dataset,
                inputs=[dataset, subset, split, batch_size, initial_shard, only_rgb, state],
                outputs=[shard, gallery, table, state])

        gallery.render()
        
        with gr.Row():
            batch_size.render()

            columns = gr.Number(5, label="Columns", precision=0)

            columns.change(update_shape, [columns], [gallery])

        with gr.Row():
            shard.render()
            # Releasing the slider re-opens the dataset at the chosen shard.
            shard.release(
                open_dataset,
                inputs=[dataset, subset, split, batch_size, shard, only_rgb, state],
                outputs=[shard, gallery, table, state])

            btn = gr.Button("Next Batch (same shard)", scale=0)
            btn.click(get_images, [batch_size, only_rgb, state], [gallery, table])
        
        table.render()

    demo.launch(show_api=False)