import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from transformers import AutoTokenizer
from transformers import TFAutoModelForQuestionAnswering
from datasets import Dataset
import streamlit as st




# loading saved roberta-base tokenizer to tokenize the text into input IDs that model can make sense of.
model_checkpoint = "Modfiededition/roberta-fine-tuned-tweet-sentiment-extractor"

@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_tokenizer():
    return AutoTokenizer.from_pretrained(model_checkpoint )
tokenizer = load_tokenizer()  

@st.cache(allow_output_mutation=True, suppress_st_warning=True)
def load_model():
    return TFAutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
model = load_model()


#prompts
st.title("Tweet Sentiment Extractor...")

# take text/tweet input
textbox = st.text_area('Write your text in this box:', '',height=100,  max_chars=500 )
option = st.selectbox(
     'Pick the sentiment',
         ('positive', 'negative', 'neutral'))

python_dict = {"text":[textbox], "sentiment":[option]}

dataset = Dataset.from_dict(python_dict)

MAX_LENGTH = 105

button = st.button('Click here to extract the word/phrase from the text with the given sentiment: {0}..'.format(option))


if button:
     if not textbox:
          st.markdown("#### " +"Please write something in the above textbox..")
     else:
          with st.spinner('In progress.......'):
            
               def process_data(examples):
                    questions = examples["sentiment"]
                    context = examples["text"]
                    inputs = tokenizer(
                           questions,
                           context,
                           max_length = MAX_LENGTH,
                           padding="max_length",
                           return_offsets_mapping = True,   
                    )
                    # Assigning None values to all offset mapping of tokens which are not the context tokens.
                    for i in range(len(inputs["input_ids"])):
                         offset = inputs["offset_mapping"][i]
                         sequence_ids = inputs.sequence_ids(i)
                         inputs["offset_mapping"][i] = [
                               o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
                         ]
                    return inputs
                   
               processed_raw_data = dataset.map(
                       process_data,
                       batched = True
               )
               tf_raw_dataset = processed_raw_data.to_tf_dataset(
                       columns=["input_ids", "attention_mask"],
                       shuffle=False,
                       batch_size=1,
                   )
               
               # final predictions.
               outputs = model.predict(tf_raw_dataset)
               start_logits = outputs.start_logits
               end_logits = outputs.end_logits
                   
                   # Post Processing.
                   # Using start_logits and end_logits to generate the final answer from the given context.
               n_best = 20
               
               def predict_answers(inputs):
                    predicted_answer = []
                    for i in range(len(inputs["offset_mapping"])):
                         start_logit = inputs["start_logits"][i]
                         end_logit = inputs["end_logits"][i]
                         context = inputs["text"][i]
                         offset = inputs["offset_mapping"][i]
                         start_indexes = np.argsort(start_logit)[-1: -n_best - 1:-1].tolist()
                         end_indexes = np.argsort(end_logit)[-1: -n_best - 1: -1].tolist()
               
                         flag = False
                         for start_index in start_indexes:
                              for end_index in end_indexes:
                                   # skip answer that are not in the context.
                                   if offset[start_index] is None or offset[end_index] is None:
                                        continue
                                   # skip answer with length that is either < 0
                                   if end_index < start_index:
                                        continue
                                   flag = True
                                   answer = context[offset[start_index][0]: offset[end_index][1]]
                                   predicted_answer.append(answer)
                                   break
                              if flag:
                                   break
                         if not flag:
                              predicted_answer.append(answer)
                    return {"predicted_answer":predicted_answer}
                   
               processed_raw_data.set_format("pandas")
                   
               processed_raw_df =  processed_raw_data[:]
               processed_raw_df["start_logits"] = start_logits.tolist()
               processed_raw_df["end_logits"] = end_logits.tolist()
               processed_raw_df["text"] = python_dict["text"]
                   
               final_data = Dataset.from_pandas(processed_raw_df)
               final_data = final_data.map(predict_answers,batched=True)
              
            
          
          st.markdown("## " +final_data["predicted_answer"][0])