import glob
import gradio as gr
import pandas as pd
import faiss
import clip
import torch
from huggingface_hub import hf_hub_download, snapshot_download
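
# Gradio demo: embed a text query with CLIP and retrieve the most similar
# captions from a prebuilt FAISS index hosted on the Hugging Face Hub.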

title = r"""
<h1 align="center" id="space-title"> 🔍 Search Similar Text/Image in the Dataset</h1>
"""

description = r"""

Find text in the dataset that is similar to your query with this demo. Currently, only text search is supported.<br>
This demo searches a subset of [danbooru22](https://huggingface.co/datasets/animelover/danbooru2022) or [DiffusionDB](https://huggingface.co/datasets/poloclub/diffusiondb) instead of [LAION](https://laion.ai/blog/laion-400-open-dataset/), because LAION is currently unavailable.
<br>
Image search will be added once LAION becomes available again.

The code is based on [clip-retrieval](https://github.com/rom1504/clip-retrieval) and [autofaiss](https://github.com/criteo/autofaiss).

"""

# Alternative: load the index and metadata from local files
# INDEX_DIR = "dataset/diffusiondb/text_index_folder"
# IND = faiss.read_index(f"{INDEX_DIR}/text.index")
# TEXT_LIST = pd.concat(
#     pd.read_parquet(file) for file in glob.glob(f"{INDEX_DIR}/metadata/*.parquet")
# )['caption'].tolist()
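
# For reference, an index like this can be built offline with autofaiss. The
# following is only a sketch under assumed paths; the exact settings used to
# build the index in the Eun02/text_image_faiss_index repo are not shown here.
#
#   from autofaiss import build_index
#
#   build_index(
#       embeddings="embeddings_folder",       # folder of .npy CLIP embeddings
#       index_path="text.index",              # output FAISS index file
#       index_infos_path="index_infos.json",  # stats about the built index
#   )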

def download_all_index(dataset_dict):
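    """Download every dataset's index and metadata to warm the local cache.

    The loaded index and caption list are discarded here; only the files
    left on disk are needed later.
    """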
    for k in dataset_dict:
        load_faiss_index(k)

def load_faiss_index(dataset):
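    """Fetch the FAISS index and caption metadata for `dataset` from the Hub,
    then load them. Returns the index and the list of captions aligned with it.
    """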
    index_dir = "data/faiss_index"
    subfolder = DATASET_NAME[dataset]  # map display name to repo subfolder

    # Download the FAISS index file
    hf_hub_download(
        repo_id="Eun02/text_image_faiss_index",
        subfolder=subfolder,
        filename="text.index",
        repo_type="dataset",
        local_dir=index_dir,
    )

    # Download the caption metadata (fnmatch-style "*" also matches "/",
    # so this pattern covers the parquet files under metadata/)
    snapshot_download(
        repo_id="Eun02/text_image_faiss_index",
        allow_patterns=f"{subfolder}/*.parquet",
        repo_type="dataset",
        local_dir=index_dir,
    )

    index = faiss.read_index(f"{index_dir}/{subfolder}/text.index")
    text_list = pd.concat(
        pd.read_parquet(file) for file in sorted(glob.glob(f"{index_dir}/{subfolder}/metadata/*.parquet"))
    )['caption'].tolist()

    return index, text_list

def change_index(dataset):
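    """Swap the global index and caption list when a different dataset is selected."""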
    global INDEX, TEXT_LIST, PREV_DATASET
    if PREV_DATASET != dataset:
        gr.Info("Loading index...")
        INDEX, TEXT_LIST = load_faiss_index(dataset)
        PREV_DATASET = dataset
        gr.Info("Done!")
    return None

@torch.inference_mode()
def get_emb(text, device="cpu"):
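    """Tokenize and encode `text` with CLIP; return an L2-normalized
    float32 embedding of shape (1, dim), as FAISS expects.
    """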
    text_tokens = clip.tokenize([text], truncate=True)
    text_features = CLIP_MODEL.encode_text(text_tokens.to(device))
    text_features /= text_features.norm(dim=-1, keepdim=True)
    text_embeddings = text_features.cpu().numpy().astype('float32')
    return text_embeddings

@torch.inference_mode()
def search_text(top_k, show_score, numbering_prefix, output_file, query_text):
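    """Retrieve the `top_k` captions closest to `query_text` from the active
    index and format them for display, optionally writing them to a text file.
    """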
    if query_text is None or query_text == "":
        raise gr.Error("Query text is missing")
    
    text_embeddings = get_emb(query_text, device)
    scores, indices = INDEX.search(text_embeddings, int(top_k))  # FAISS expects an integer k
    scores, indices = scores[0], indices[0]

    # Deduplicate results while preserving rank order; skip empty captions
    result_list = []
    for score, ind in zip(scores, indices):
        item_str = TEXT_LIST[ind].strip()
        if item_str == "":
            continue
        if (item_str, score) not in result_list:
            result_list.append((item_str, score))

    # Format the results into a single display string
    result_str = ""
    for count, (item_str, score) in enumerate(result_list):
        if numbering_prefix:
            item_str = f"######################  {count+1}  ######################\n {item_str}"
        if show_score:
            item_str += f", {score:0.2f}"
        result_str += f"{item_str}\n"
            
    # file_name = query_text.replace(" ", "_")
    # if show_score:
    #     file_name += "_score"
    output_path = None
    if output_file:
        file_name = "output"
        output_path = f"./{file_name}.txt"
        with open(output_path, "w") as f:
            f.write(result_str)
    
    return result_str, output_path


# Load the CLIP model used to embed queries; it should match the model
# that produced the embeddings stored in the index
device = "cpu"
CLIP_MODEL, _ = clip.load("ViT-B/32", device=device)

# Map dropdown display names to subfolders in the index repository
DATASET_NAME = {
    "danbooru22": "booru22_000-300",
    "DiffusionDB": "diffusiondb",
}

DEFAULT_DATASET = "danbooru22"
PREV_DATASET = DEFAULT_DATASET  # tracks which dataset's index is currently loaded

# Download the index files for every dataset up front
download_all_index(DATASET_NAME)

# Load default index
INDEX, TEXT_LIST = load_faiss_index(DEFAULT_DATASET)


with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        dataset = gr.Dropdown(label="dataset", choices=list(DATASET_NAME), value=DEFAULT_DATASET)
        top_k = gr.Slider(label="top k", minimum=1, maximum=20, value=8, step=1)
        with gr.Column():
            show_score = gr.Checkbox(label="Show score", value=False)
            numbering_prefix = gr.Checkbox(label="Add numbering prefix", value=True)
            output_file = gr.Checkbox(label="Return text file", value=True)
    query_text = gr.Textbox(label="query text")
    btn = gr.Button("Search")
    result_text = gr.Textbox(label="retrieved text", interactive=False)
    result_file = gr.File(label="output file", visible=True)
    
    # Alternatively, reload the index as soon as the dropdown selection changes:
    # dataset.change(change_index, dataset, None)

    # On click: ensure the selected dataset's index is loaded, then run the search
    btn.click(
        fn=change_index,
        inputs=[dataset],
        outputs=[result_text],
    ).success(
        fn=search_text,
        inputs=[top_k, show_score, numbering_prefix, output_file, query_text],
        outputs=[result_text, result_file],
    )

demo.launch()