import gradio as gr
import os
from huggingface_hub import hf_hub_download
from pathlib import Path
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
import json
import torch

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS for padding
model.eval()  # inference only; disable dropout for deterministic logits

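# Offline indexing: run every text description from index.json through GPT-2 once and
# keep the resulting logits in memory as a fixed-size representation keyed by the video
# UUID. (The commented-out lines below are an alternative that caches each tensor to disk.)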
# temp_folder = 'temp'
# os.makedirs(temp_folder, exist_ok=True)
logits_index = {}
json_file = 'index.json'
with open(json_file, 'r') as file:
    data = json.load(file)
for key, value in data.items():
    text_description = value['text_description']
    inputs = tokenizer(text_description, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
    with torch.no_grad():  # only the logits are needed, so skip gradient tracking
        outputs = model(**inputs)
    logits_index[key] = outputs.logits
    # torch.save(outputs.logits, os.path.join(temp_folder, f"{key}.pt"))


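# Query-time scoring: encode the query the same way and rank the indexed entries by the
# elementwise product of the two logit tensors summed over all positions (a raw,
# unnormalised dot-product similarity).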
def search_index(query):
    inputs = tokenizer(query, return_tensors="pt", padding="max_length", max_length=32, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)

    max_similarity = float('-inf')
    max_similarity_uuid = None
    # for file in os.listdir(temp_folder):
    #     uuid = file.split('.')[0]
    #     logits = torch.load(os.path.join(temp_folder, file))
    for uuid, logits in logits_index.items():
        similarity = (outputs.logits * logits).sum().item()
        if similarity > max_similarity:
            max_similarity = similarity
            max_similarity_uuid = uuid

    gr.Info(f"Found the most similar video with UUID: {max_similarity_uuid}. \n Downloading the video...")
    return max_similarity_uuid


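# Fetch the matching clip from the Hugging Face dataset repo; hf_hub_download caches the
# file locally and returns the cached path.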
def download_video(uuid):
    dataset_name = "quchenyuan/360x_dataset_LR"
    dataset_path = "360_dataset/binocular/"
    video_filename = f"{uuid}.mp4"

    # Local folder used by the (currently commented-out) storage-quota logic below.
    storage_dir = Path("videos")
    storage_dir.mkdir(exist_ok=True)

    # storage_limit = 40 * 1024 * 1024 * 1024
    # current_storage = sum(f.stat().st_size for f in storage_dir.glob('*') if f.is_file())
    # if current_storage + os.path.getsize(video_filename) > storage_limit:
    #     oldest_file = min(storage_dir.glob('*'), key=os.path.getmtime)
    #     oldest_file.unlink()

    # The repo is a dataset, so repo_type must be set; the returned path points at the
    # locally cached file, which is what gets handed back to Gradio.
    downloaded_file_path = hf_hub_download(dataset_name, dataset_path + video_filename, repo_type="dataset")

    return downloaded_file_path


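# Glue for the Gradio button: look up the best-matching UUID, download the clip, and
# return one value per video slot (all three panes currently show the same file).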
def search_and_show_video(query):
    uuid = search_index(query)
    video_path = download_video(uuid)
    # The click handler is wired to three Video outputs, so return one value per output.
    return video_path, video_path, video_path


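# Gradio Blocks UI: a single query textbox, a search button, and three video panes.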
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                gr.HTML("<h1><i>360+x</i> : A Panoptic Multi-modal Scene Understanding Dataset</h1>")
            with gr.Row():
                gr.HTML("<p><a href='https://x360dataset.github.io/'>Official Website</a>    <a href='https://arxiv.org/abs/2404.00989'>Paper</a></p>")
            with gr.Row():
                gr.HTML("<h2>Search for a video by entering a query below:</h2>")
            with gr.Row():
                search_input = gr.Textbox(label="Query", placeholder="Enter a query to search for a video.")
            with gr.Row():
                with gr.Column():
                    video_output_1 = gr.Video()
                with gr.Column():
                    video_output_2 = gr.Video()
                with gr.Column():
                    video_output_3 = gr.Video()
            with gr.Row():
                submit_button = gr.Button(value="Search")

        submit_button.click(search_and_show_video, inputs=search_input,
                            outputs=[video_output_1, video_output_2, video_output_3])
    demo.launch()