import functools
import os

import gradio as gr
import pandas as pd
from gradio.themes import colors
from transformers import AutoTokenizer

# Presumably disables the Hugging Face `tokenizers` library's internal
# parallelism (and its fork-related warning) — confirm if this matters.
os.environ.update(TOKENIZERS_PARALLELISM="false")


@functools.lru_cache(maxsize=16)
def _load_tokenizer(model_id: str):
    """Load and memoize the tokenizer for *model_id*.

    ``AutoTokenizer.from_pretrained`` re-resolves and re-reads the tokenizer
    files on every call, so caching avoids repeating that work for each
    request. The cache is bounded because ``model_id`` comes from the UI.
    """
    return AutoTokenizer.from_pretrained(model_id)


def inference(
        text="",
        model_id="openai/clip-vit-large-patch14",
        progress=gr.Progress()
) -> tuple[list[tuple[str, str]], list[list[str]], pd.DataFrame]:
    """Tokenize *text* with the tokenizer of *model_id* and summarize it.

    Args:
        text: Text to tokenize. Empty (or ``None``) input short-circuits.
        model_id: Hugging Face model id whose tokenizer is loaded.
        progress: Gradio progress tracker (Gradio injects it because the
            default value is ``gr.Progress()``).

    Returns:
        A 3-tuple of:
        - ``(token_text, token_id)`` pairs, one per input id returned by the
          tokenizer, for ``gr.HighlightedText``;
        - ``[name, token]`` pairs from ``tokenizer.special_tokens_map``
          (the ``additional_special_tokens`` entry is skipped);
        - a one-row DataFrame with "Char Count" and "Token Count".
        For empty input: two empty lists and an empty DataFrame.
    """
    if not text:
        return [], [], pd.DataFrame()

    progress(0, desc='Loading tokenizer...')
    tokenizer = _load_tokenizer(model_id)

    progress(0.5, desc='Tokenizing text...')
    # Plain Python lists are all that is needed here; requesting 'pt' tensors
    # would make PyTorch a hard dependency just to convert back to a list.
    input_ids = tokenizer(text)['input_ids']

    # Decode every id on its own so each token is displayed as its own span.
    token_pairs = [(tokenizer.decode([id_]), str(id_)) for id_ in input_ids]

    pos_count = pd.DataFrame({
        "Char Count": [len(text)],
        "Token Count": [len(token_pairs)],
    })

    # additional_special_tokens is skipped — presumably because its value is
    # a list of tokens rather than a single token.
    special_tokens = [
        [str(name), str(token)]
        for name, token in tokenizer.special_tokens_map.items()
        if name != 'additional_special_tokens'
    ]

    return token_pairs, special_tokens, pos_count


if __name__ == '__main__':
    # Tokenizers selectable in the UI; the CLIP tokenizer is preselected.
    model_choices = [
        "openai/clip-vit-large-patch14",
        "google/gemma-7b",
        "google/gemma-2-9b",
        "google-bert/bert-base-uncased",
        "google/flan-t5-base",
        "openai-community/gpt2",
        "rinna/japanese-gpt-1b",
        "cyberagent/open-calm-7b",
    ]

    # Input components, in the same order as `inference`'s parameters.
    text_input = gr.Textbox(label="Text")
    model_input = gr.Dropdown(
        label="Model",
        choices=model_choices,
        value="openai/clip-vit-large-patch14",
    )

    # Output components, in the same order as `inference`'s return tuple.
    token_view = gr.HighlightedText(label="Highlighted Text")
    special_token_view = gr.HighlightedText(
        label="Special Tokens",
        combine_adjacent=True,
        adjacent_separator=' / ',
    )
    count_table = gr.Dataframe(label="Position Count")

    # Each example is [text, model_id].
    examples = [
        ["When I told my computer I needed a break, it froze.", "openai/clip-vit-large-patch14"],
        [
            "Yesterday, I thought my cat was studying for her degree in philosophy because she sat on my book, "
            "but turns out she was just trying to hatch a plot to steal my dinner.",
            "openai/clip-vit-large-patch14",
        ],
        [
            "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?",
            "google/flan-t5-base",
        ],
        [
            "In my home country, it's a custom to say 'いただきマサチューセッツ' before we start eating a meal.",
            "google/gemma-7b",
        ],
        ["日本で一番高い山は富士山ですが、二番目に高い山は何ですか?", "rinna/japanese-gpt-1b"],
    ]

    demo = gr.Interface(
        fn=inference,
        inputs=[text_input, model_input],
        outputs=[token_view, special_token_view, count_table],
        examples=examples,
        cache_examples=True,
        title="TokenVisor 👀",
        description="Visualize how the Tokenizer used in Hugging Face's Transformers library tokenizes text.",
        theme=gr.Theme(primary_hue=colors.green, secondary_hue=colors.yellow),
        allow_flagging="never",
    )
    demo.queue().launch()