import os

import gdown as gdown
import nltk
import streamlit as st
from nltk.tokenize import sent_tokenize

from source.pipeline import MultiLabelPipeline, inputs_to_dataset


def download_models(ids):
    """
    Download the sentence tokenizer data and every model file missing locally.

    Each model is stored as ``model/<name>.pt``; a model whose file already
    exists is not downloaded again.

    :param ids: dict mapping a model name to its Google Drive file id
    :return: None
    """

    # Download sentence tokenizer
    nltk.download('punkt')

    # Make sure the target directory exists (e.g. on a fresh checkout) instead
    # of relying on the downloader to create it.
    os.makedirs("model", exist_ok=True)

    # Download model from drive if not stored locally
    for key, file_id in ids.items():
        target = f"model/{key}.pt"
        if not os.path.isfile(target):
            url = f"https://drive.google.com/uc?id={file_id}"
            gdown.download(url=url, output=target)


@st.cache
def load_labels():
    """
    Build the list of emotion label names.
    :return: a new list of label strings, in a fixed order
    """

    emotions = (
        "admiration", "amusement", "anger", "annoyance",
        "approval", "caring", "confusion", "curiosity",
        "desire", "disappointment", "disapproval", "disgust",
        "embarrassment", "excitement", "fear", "gratitude",
        "grief", "joy", "love", "nervousness",
        "optimism", "pride", "realization", "relief",
        "remorse", "sadness", "surprise", "neutral",
    )

    return list(emotions)


@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """
    Build the classification pipeline for a model file (cached by Streamlit).
    :param model_path: path to the model file
    :return: the ``MultiLabelPipeline`` created for ``model_path``
    """

    return MultiLabelPipeline(model_path=model_path)


# Page header
st.set_page_config(layout="centered")
st.title("Multiclass Emotion Classification")
st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")

# Flip to True to show a notice instead of the app
maintenance = False
if maintenance:
    st.write("Unavailable for now (file downloads limit). ")
else:
    # Available models (name -> drive id read from the app secrets) and labels
    ids = {'perceiver-go-emotions': st.secrets['model']}
    labels = load_labels()

    # Fetch any model file that is not stored locally yet
    download_models(ids)

    # Show the labels
    st.markdown(f"__Labels:__ {', '.join(labels)}")

    # Two columns: text input on the left, options on the right
    left, right = st.columns([4, 2])
    inputs = left.text_area(
        '',
        max_chars=4096,
        value='This is a space about multiclass emotion classification. Write '
              'something here to see what happens!',
    )
    model_path = right.selectbox('', options=list(ids), index=0, help='Model to use. ')
    split = right.checkbox('Split into sentences', value=True)
    model = load_model(model_path=f"model/{model_path}.pt")
    right.write(model.device)

    # Run the model only when the input has non-whitespace content
    if inputs.strip():
        with st.spinner('Processing text... This may take a while.'):
            # Either one entry per sentence or the whole text as a single entry
            texts = sent_tokenize(inputs) if split else [inputs]
            left.write(model(inputs_to_dataset(texts), batch_size=1))