import gradio as gr
from datasets import load_dataset
import random
import math
from datasets import load_dataset
import gradio as gr
import os

# Loaded once, at import time: the "validation" split of the
# "glitchbench/GlitchBench" dataset, fetched via `datasets.load_dataset`
# (so startup may need network access or a warm local cache).
# All samples shown in the UI are read from this split.
mydataset_private = load_dataset("glitchbench/GlitchBench")["validation"]
# Number of records available for random sampling.
dataset_size = len(mydataset_private)

# Grid layout passed to make_grid: (number of columns, items stacked per
# column). The UI therefore shows GRID_SIZE[0] * GRID_SIZE[1] samples at once.
GRID_SIZE = (4, 1)


def get_item_data(image_index):
    """Return the record stored at ``image_index`` in the loaded split.

    This is a thin accessor over ``mydataset_private``; out-of-range or
    otherwise invalid indices raise whatever the underlying dataset object
    raises for ``__getitem__``.
    """
    return mydataset_private[image_index]


def show_random_samples():
    """Draw a fresh random batch of dataset items for the grid.

    Picks ``GRID_SIZE[0] * GRID_SIZE[1]`` distinct indices (no repeats within
    one draw) and flattens each selected item into five values, in this
    order: image, source, glitch type, reddit field, description placeholder.

    Returns:
        A flat list with five entries per grid cell. The order matches the
        component order built by ``make_grid``, so it can be fed directly
        to the grid outputs.

    Raises:
        ValueError: if the grid has more cells than the dataset has items
            (``random.sample`` cannot draw more unique indices than exist).
    """
    total = GRID_SIZE[0] * GRID_SIZE[1]
    random_indexes = random.sample(range(dataset_size), total)

    outputs = []
    for index in random_indexes:
        example = get_item_data(index)
        outputs.extend(
            (
                example["image"],
                example["source"],
                example["glitch-type"],
                example["reddit"],
                # Fixed placeholder for the "Description" box; no description
                # field from the dataset record is read here.
                "Secret",
            )
        )
    return outputs


def make_grid(grid_size):
    """Build the sample grid and return its output components in order.

    Args:
        grid_size: A pair ``(n_columns, items_per_column)``. One
            ``gr.Column`` is created for each of the ``grid_size[0]``
            columns. Each column stacks ``grid_size[1]`` items vertically.

    Returns:
        A flat list of components, five per item, in this order:
        image, "Glitch Source", "Glitch Type", "Reddit ID", "Description".
        ``show_random_samples`` produces its values in the same order.
    """
    components = []
    n_columns = grid_size[0]
    items_per_column = grid_size[1]

    with gr.Row():
        for _ in range(n_columns):
            with gr.Column():
                for _ in range(items_per_column):
                    image = gr.Image()
                    # Metadata is collapsed by default so the images dominate.
                    with gr.Accordion("Click for details", open=False):
                        glitch_source = gr.Textbox(label="Glitch Source")
                        glitch_type = gr.Textbox(label="Glitch Type")
                        reddit_id = gr.Textbox(label="Reddit ID")
                        description = gr.Textbox(label="Description")

                    components.extend(
                        [image, glitch_source, glitch_type, reddit_id, description]
                    )

    return components


# Assemble the explorer UI: a "Random Sample" button above the grid of
# sample cards. Clicking the button refills every grid component with a new
# random draw from the dataset.
with gr.Blocks(title="GlitchBench") as browser:
    gr.Markdown("## GlitchBench dataset explorer")

    with gr.Column():
        random_btn = gr.Button("Random Sample")
        with gr.Row():
            grid = make_grid(GRID_SIZE)
    # show_random_samples returns exactly one value per component in `grid`,
    # in the same order, so the component list is passed straight through.
    random_btn.click(show_random_samples, inputs=[], outputs=grid)


# Launched at module level: running (or importing) this file starts the server.
browser.launch()