import streamlit as st
import io
import os
import yaml
import pyarrow
import tokenizers

# Explicitly enable Hugging Face tokenizers parallelism. Setting the variable
# either way also silences the warning tokenizers prints when the process
# forks with the variable unset.
os.environ.update(TOKENIZERS_PARALLELISM="true")

# Wide layout lets the page use the full browser width; Streamlit requires
# set_page_config to be the first Streamlit command executed by the script.
st.set_page_config(layout="wide")

@st.cache_resource
def from_library():
    from retro_reader import RetroReader
    from retro_reader import constants as C
    return C, RetroReader

C, RetroReader = from_library()

# Streamlit ``hash_funcs`` overrides: when the cache hasher meets an instance of
# one of these types, it hashes the lambda's constant return value instead of
# the object itself.
# NOTE(review): the functions decorated with these overrides take no
# arguments, so the overrides appear to have no effect; this looks like a
# holdover from the legacy ``st.cache`` API.
my_hash_func = dict(
    [
        (io.TextIOWrapper, lambda _: None),
        (pyarrow.lib.Buffer, lambda _: 0),
        (tokenizers.Tokenizer, lambda _: None),
        (tokenizers.AddedToken, lambda _: None),
    ]
)

@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_base_model():
    """Build the English ELECTRA-base Retro Reader from its inference config.

    Wrapped in ``st.cache_resource``, so the model object is created once and
    then reused across reruns and sessions.

    NOTE(review): this function takes no arguments, so ``hash_funcs`` appears
    to have no effect here. The config path is relative and is presumably
    resolved against the directory Streamlit is launched from; confirm this.
    """
    return RetroReader.load(
        config_file="configs/inference_en_electra_base.yaml",
    )

@st.cache_resource(hash_funcs=my_hash_func)
def load_en_electra_large_model():
    """Build the English ELECTRA-large Retro Reader from its inference config.

    Wrapped in ``st.cache_resource``, so the model object is created once and
    then reused across reruns and sessions.

    NOTE(review): this function takes no arguments, so ``hash_funcs`` appears
    to have no effect here. The config path is relative and is presumably
    resolved against the directory Streamlit is launched from; confirm this.
    """
    return RetroReader.load(
        config_file="configs/inference_en_electra_large.yaml",
    )

# Checkpoint name (as shown in the model selectbox) -> loaded Retro Reader.
# Both loaders are called here on every script run; their st.cache_resource
# caches mean only the first run actually builds the models.
# NOTE(review): this loads the large model even if the user never selects it.
RETRO_READER_HOST = dict(
    zip(
        (
            "google/electra-base-discriminator",
            "google/electra-large-discriminator",
        ),
        (
            load_en_electra_base_model(),
            load_en_electra_large_model(),
        ),
    )
)

def _select_top_predictions(nbest_preds, top_k=10, qid="id-01"):
    """Return the ``top_k`` highest-probability entries of ``nbest_preds``.

    Args:
        nbest_preds: Either an iterable of prediction dicts, each with at least
            a ``"probability"`` key, or a mapping from question id to such an
            iterable. ``None`` is treated as "no predictions".
        top_k: Maximum number of predictions to return; values <= 0 yield ``[]``.
        qid: Question id to look up when ``nbest_preds`` is a mapping.

    Returns:
        A new list ordered by descending probability. Ties keep input order.
    """
    import heapq
    from collections.abc import Mapping

    if nbest_preds is None:
        return []
    if isinstance(nbest_preds, Mapping):
        # A missing id yields "nothing to show" rather than a KeyError mid-render.
        nbest_preds = nbest_preds.get(qid) or []
    # For top_k >= 0 this is documented as equivalent to
    # sorted(..., reverse=True)[:top_k]. Unlike that slice, it also returns []
    # for a negative top_k instead of "all but the last |top_k|".
    return heapq.nlargest(top_k, nbest_preds, key=lambda pred: pred["probability"])


def _format_prediction(rank, pred):
    """Return one Markdown line: bold rank and text, then probability as a percent."""
    # An empty span is presumably the model's "no answer" prediction. Label it
    # the same way main() labels an empty final answer; an empty string would
    # also leave the bold markers unbalanced.
    text = pred["text"] or "No answer"
    return f"**{rank}. {text}** - Probability: {pred['probability'] * 100:.2f}%"


def display_top_predictions(nbest_preds, top_k=10, *, qid="id-01"):
    """Render the ``top_k`` most probable n-best predictions as Markdown.

    Args:
        nbest_preds: A list of prediction dicts (``"text"`` and ``"probability"``
            keys), or a mapping from question id to such a list.
        top_k: How many predictions to show at most.
        qid: Question id used when ``nbest_preds`` is a mapping. The demo only
            ever submits a single question under ``"id-01"``.
    """
    top_preds = _select_top_predictions(nbest_preds, top_k=top_k, qid=qid)
    st.markdown("### Top Predictions")
    if not top_preds:
        st.write("No predictions available.")
        return
    for rank, pred in enumerate(top_preds, 1):
        st.markdown(_format_prediction(rank, pred))

def _render_sidebar():
    """Draw the sidebar: project blurb, logo image and contributor list."""
    st.sidebar.title("📝 Welcome to Retro Reader")
    st.sidebar.write("""
    MRC-RetroReader is a machine reading comprehension (MRC) model designed for reading comprehension tasks. The model leverages advanced neural network architectures to provide high accuracy in understanding and responding to textual queries.
    """)
    image_url = "img.jpg"  # Replace this URL with your actual image URL or local path
    st.sidebar.image(image_url, use_column_width=True)
    st.sidebar.title("Contributors")
    st.sidebar.write("""
    - Phan Van Hoang
    - Pham Long Khanh
    """)


def _render_results(outputs, return_submodule_outputs):
    """Show the final answer and, when requested, the n-best predictions.

    Args:
        outputs: The value returned by ``retro_reader(...)``. Index 0 is indexed
            by question id (the demo uses ``"id-01"``). When submodule outputs
            were requested, ``outputs[2:]`` unpacks as
            ``(score_ext, nbest_preds, score_diff)``.
        return_submodule_outputs: Whether ``outputs`` carries the extra entries.
    """
    answer = outputs[0]["id-01"] or "No answer"
    st.markdown("## 📜 Results")
    st.write(answer)
    if return_submodule_outputs:
        # Unpacking (rather than indexing) also checks the expected tuple length.
        _score_ext, nbest_preds, _score_diff = outputs[2:]
        display_top_predictions(nbest_preds)


def main():
    """Streamlit entry point: pick a model, read a query/context, show answers."""
    _render_sidebar()

    st.title("Retrospective Reader Demo")
    st.markdown("## Model name🚨")
    option = st.selectbox(
        label="Choose the model used in retro reader",
        options=(
            "[1] google/electra-base-discriminator",
            "[2] google/electra-large-discriminator"
        ),
        index=1,
    )
    # Options read "[N] <checkpoint>"; the checkpoint name keys RETRO_READER_HOST.
    _, model_name = option.split(" ", 1)
    retro_reader = RETRO_READER_HOST[model_name]

    lang_prefix = "EN"  # selects the EN_* example/help constants from C
    height = 200
    return_submodule_outputs = True

    with st.form(key="my_form"):
        st.markdown("## Type your query ❓")
        query = st.text_input(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_QUERY"),
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_QUERY_HELP_TEXT"),
        )
        st.markdown("## Type your context 💬")
        context = st.text_area(
            label="",
            value=getattr(C, f"{lang_prefix}_EXAMPLE_CONTEXTS"),
            height=height,
            max_chars=None,
            help=getattr(C, f"{lang_prefix}_CONTEXT_HELP_TEXT"),
        )
        submit_button = st.form_submit_button(label="Submit")

    if not submit_button:
        return
    # Don't send blank input to the model; tell the user instead.
    if not query.strip() or not context.strip():
        st.warning("Please provide both a query and a context.")
        return

    with st.spinner("🕒 Please wait.."):
        outputs = retro_reader(
            query=query,
            context=context,
            return_submodule_outputs=return_submodule_outputs,
        )
    _render_results(outputs, return_submodule_outputs)


if __name__ == "__main__":
    main()