import html
import itertools
import os

# On Hugging Face Spaces (SPACE_ID is set), point the Hugging Face caches at the
# Space's persistent /data storage. This has to happen before transformers is
# imported, because the cache location is resolved when the library is loaded.
if os.getenv("SPACE_ID"):
    USE_HF_SPACE = True
    os.environ["HF_HOME"] = "/data/.huggingface"
    os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
else:
    USE_HF_SPACE = False

import pandas as pd
import streamlit as st
from transformers import AutoTokenizer

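# Default tokenizer model (overridable via the DEFAULT_TOKENIZER_NAME environment
# variable) and multilingual sample text.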
DEFAULT_TOKENIZER_NAME = os.environ.get(
    "DEFAULT_TOKENIZER_NAME", "tohoku-nlp/bert-base-japanese-v3"
)
DEFAULT_TEXT = """
hello world!
こんにちは、世界!
你好,世界
""".strip()

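# Border colors for the token visualization: DEFAULT_COLOR marks tokens outside any
# subword group, while COLORS_CYCLE alternates between consecutive subword groups.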
DEFAULT_COLOR = "gray"
COLORS_CYCLE = [
    "yellow",
    "cyan",
]


def color_cycle_generator():
    """Return an endless iterator over the subword-group highlight colors."""
    return itertools.cycle(COLORS_CYCLE)


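# Cache the tokenizer per model name so repeated reruns reuse the loaded instance.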
@st.cache_resource
def get_tokenizer(tokenizer_name: str = DEFAULT_TOKENIZER_NAME):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return tokenizer


def main():
    st.set_page_config(
        page_title="TokenViz: AutoTokenizer Visualization Tool",
        layout="centered",
        initial_sidebar_state="auto",
    )

    st.title("TokenViz: AutoTokenizer Visualization Tool")
    st.text_input(
        "AutoTokenizer model name", key="tokenizer_name", value=DEFAULT_TOKENIZER_NAME
    )
    tokenizer = None
    if st.session_state.tokenizer_name:
        tokenizer = get_tokenizer(st.session_state.tokenizer_name)
    st.text_input("subword prefix", key="subword_prefix", value="##")
    st.text_area("text", key="text", height=200, value=DEFAULT_TEXT)
    # Submit
    if st.button("tokenize"):
        if tokenizer is None:
            st.warning("Enter an AutoTokenizer model name before tokenizing.")
            st.stop()
        text = st.session_state.text.strip()
        subword_prefix = st.session_state.subword_prefix.strip()
        token_ids = tokenizer.encode(text, add_special_tokens=True)
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        total_tokens = len(tokens)
        token_table_df = pd.DataFrame(
            {
                "token_id": token_ids,
                "token": tokens,
            }
        )

        st.subheader("visualized tokens")
        st.markdown(f"total tokens: **{total_tokens}**")
        tab_main, tab_token_table = st.tabs(["tokens", "table"])

        color_gen = color_cycle_generator()
        with tab_main:
            current_subword_color = next(color_gen)
            token_html = ""
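            # Render each token as a bordered <span>. Tokens that form one subword
            # group (a head token followed by subword_prefix continuations) share a
            # border color so the group reads as a unit.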
            for idx, (token_id, token) in enumerate(zip(token_ids, tokens)):
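                # Without a subword prefix there is no grouping to highlight.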
                if len(subword_prefix) == 0:
                    token_border = f"1px solid {DEFAULT_COLOR}"
                else:
                    current_token_is_subword = token.startswith(subword_prefix)
                    next_token_is_subword = idx + 1 < total_tokens and tokens[
                        idx + 1
                    ].startswith(subword_prefix)

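                    # A head token followed by a subword starts a new group, so
                    # rotate to the next highlight color.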
                    if next_token_is_subword and not current_token_is_subword:
                        current_subword_color = next(color_gen)

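                    # Tokens inside a group share its color; standalone tokens keep
                    # the default border.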
                    if current_token_is_subword or next_token_is_subword:
                        token_border = f"1px solid {current_subword_color}"
                    else:
                        token_border = f"1px solid {DEFAULT_COLOR}"

                # Escape the token so raw &, <, and > render literally in the HTML.
                html_escaped_token = html.escape(token)
                token_html += f'<span title="{str(token_id)}" style="border: {token_border}; border-radius: 3px; padding: 2px; margin: 2px;">{html_escaped_token}</span>'
            st.html(
                f"<p style='line-height:2em;'>{token_html}</p>",
            )

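            # Comma-separated token ids, shown with st.code so they are easy to copy.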
            st.subheader("token_ids")

            token_ids_str = ",".join(map(str, token_ids))
            st.code(token_ids_str)

        with tab_token_table:
            st.table(token_table_df)


if __name__ == "__main__":
    main()