import streamlit as st
import numpy as np
import pandas as pd
import torch

st.markdown("Hello!")


def _cache_resource(func):
    """Cache ``func``'s return value across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching every model below would be unpickled again each time.
    Prefers ``st.cache_resource`` and falls back to the legacy
    ``st.cache(allow_output_mutation=True)`` on Streamlit versions without it.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


def _torch_load(path):
    """Load a fully pickled object (model or tokenizer) from ``path`` onto the CPU."""
    # NOTE(review): these .pth files are full pickles; unpickling can execute
    # arbitrary code, so only load trusted files. torch >= 2.6 defaults to
    # weights_only=True, which rejects such pickles, hence the explicit False.
    try:
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # older torch (< 1.13) has no ``weights_only`` parameter
        return torch.load(path, map_location='cpu')


@_cache_resource
def _load_models():
    """Load both masked LMs, the classifier and the tokenizer; put models in eval mode."""
    mlm_positive = _torch_load("bert_mlm_positive.pth")
    mlm_negative = _torch_load("bert_mlm_negative.pth")
    classifier = _torch_load("bert_classifier.pth")
    tok = _torch_load("tokenizer.pth")
    # inference only: switch the models out of training mode
    for model in (mlm_positive, mlm_negative, classifier):
        model.eval()
    return mlm_positive, mlm_negative, classifier, tok


bert_mlm_positive, bert_mlm_negative, bert_classifier, tokenizer = _load_models()


user_input = st.text_input("Please enter your thoughts:")

def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """Propose rewrites of ``sentence`` by substituting single tokens.

    - Tokenize the sentence with the BERT tokenizer.
    - For every non-special token, compute
      ``(p_positive(token) + epsilon) / (p_negative(token) + epsilon)`` using
      the probabilities the two masked LMs assign to the token actually present.
    - Pick the ``num_tokens`` tokens with the LOWEST ratio, i.e. the tokens the
      negative model prefers most relative to the positive one.
    - At each picked position, substitute each of the ``k_best`` most likely
      tokens under ``bert_mlm_positive``.

    :param sentence: input text.
    :param num_tokens: how many positions to rewrite; clipped to the number of
        non-special tokens.
    :param k_best: how many substitutes to try per position.
    :param epsilon: smoothing added to both probabilities before dividing.
    :return: a list of candidate strings, one per (position, substitute) pair
        (up to ``k_best * num_tokens``); empty if nothing can be replaced.
    """
    sentence_ix = tokenizer(sentence, return_tensors="pt")
    input_ids = sentence_ix['input_ids'][0]
    length = len(input_ids)

    # we can't replace more tokens than we have ([CLS] and [SEP] are excluded)
    num_tokens = min(num_tokens, length - 2)
    if num_tokens <= 0 or k_best <= 0:
        # guard: np.argpartition(...)[-0:] would otherwise select *everything*
        return []

    # inference only: don't build an autograd graph
    with torch.no_grad():
        probs_positive = bert_mlm_positive(**sentence_ix).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**sentence_ix).logits.softmax(dim=-1)[0]
    # ^-- shape is [seq_length, vocab_size]

    # probability each model assigns to the token actually present at each position
    positions = torch.arange(length)
    p_tokens_positive = probs_positive[positions, input_ids]
    p_tokens_negative = probs_negative[positions, input_ids]

    ratio = (p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)
    ratio = ratio[1:-1].numpy()  # do not change [CLS] and [SEP]; len is length - 2

    # argpartition of -ratio leaves the num_tokens largest values of -ratio
    # (i.e. the smallest ratios, the most "negative" tokens) at the end
    ind = np.argpartition(-ratio, -num_tokens)[-num_tokens:]

    base_ids = input_ids.tolist()
    replacements = []
    for i in ind:
        pos = int(i) + 1  # +1 because [CLS] is the 0th token
        # distribution of the positive MLM at *this* position
        tokens_probs = probs_positive[pos].numpy()
        prob_ind_top_k = np.argpartition(tokens_probs, -k_best)[-k_best:]
        for new_token in prob_ind_top_k:
            new_tokens = list(base_ids)
            new_tokens[pos] = int(new_token)
            # drop [CLS]/[SEP] before converting back to text
            replacements.append(
                tokenizer.convert_tokens_to_string(
                    tokenizer.convert_ids_to_tokens(new_tokens)[1:-1]
                )
            )

    return replacements


def get_sent_score(sentence):
    """Return the classifier's raw class-1 logit for ``sentence``.

    Class 1 is the "negative" class, so this is used as a negativity score
    (lower is treated as better by the caller). The value is an unnormalized
    logit returned as a 0-d NumPy array.
    """
    sentence_ix = tokenizer(sentence, return_tensors="pt")
    # inference only: skip building the autograd graph
    with torch.no_grad():
        logits = bert_classifier(**sentence_ix).logits
    # negative is class 1
    return logits[0][1].numpy()


if user_input.strip():
    st.markdown(f"Original sentence negativity: {get_sent_score(user_input)}")

    num_iter = 5  # rounds of rewriting
    M = 3  # number of best candidates kept after each round
    num_tokens = 3  # tokens substituted per candidate
    k_best = 3  # substitutes tried per token

    # Greedy beam search: expand every kept sentence into its single-token
    # rewrites, score them, and keep the M least negative ones.
    fix_list = [user_input]
    for _ in range(num_iter):
        candidates = []
        for cur_sent in fix_list:
            candidates.extend(
                get_replacements(cur_sent, num_tokens=num_tokens, k_best=k_best)
            )
        # different rewrite paths can yield the same string; keep first occurrence
        candidates = list(dict.fromkeys(candidates))
        if not candidates:
            break  # nothing left to rewrite; keep the current beam
        replacements = pd.DataFrame(candidates, columns=['new_sentence'])
        replacements['new_scores'] = replacements['new_sentence'].apply(get_sent_score)
        replacements = replacements.nsmallest(M, 'new_scores')
        fix_list = replacements.new_sentence.to_list()

    for new_sentence in fix_list:
        st.markdown("New sentence:")
        st.markdown(f"{new_sentence}")
        st.markdown(f"New sentence negativity: {get_sent_score(new_sentence)}")