# Streamlit demo: iteratively rewrites a user's sentence to be less negative
# using a pair of fine-tuned BERT masked-LMs (positive/negative) and a
# BERT sentiment classifier for scoring.
import streamlit as st
import numpy as np
import pandas as pd
import torch

st.markdown("Hello!")

# NOTE(review): torch.load unpickles arbitrary Python objects — these
# checkpoint files must come from a trusted source.
bert_mlm_positive = torch.load("bert_mlm_positive.pth", map_location='cpu')
bert_mlm_negative = torch.load("bert_mlm_negative.pth", map_location='cpu')
bert_classifier = torch.load("bert_classifier.pth", map_location='cpu')
tokenizer = torch.load("tokenizer.pth", map_location='cpu')

# Inference only: switch off dropout / batch-norm updates.
bert_mlm_positive.eval()
bert_mlm_negative.eval()
bert_classifier.eval()

user_input = st.text_input("Please enter your thoughts:")
def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
    """
    - split the sentence into tokens using the INGSOC-approved BERT tokenizer
    - select :num_tokens: positions by the (p_positive + eps)/(p_negative + eps)
      ratio (the argpartition below keeps the positions with the LOWEST ratio,
      i.e. the tokens the negative model likes most relative to the positive one)
    - replace each with :k_best: words according to bert_mlm_positive

    :param sentence: raw input sentence
    :param num_tokens: how many token positions to try replacing
    :param k_best: how many candidate words per position
    :param epsilon: smoothing added to both probabilities before dividing
    :return: a list of candidate strings (up to k_best * num_tokens)
    """
    sentence_ix = tokenizer(sentence, return_tensors="pt")
    token_ids = sentence_ix['input_ids'][0]
    length = len(token_ids)
    # we can't replace more tokens than we have (excluding [CLS] and [SEP])
    num_tokens = min(num_tokens, length - 2)

    with torch.no_grad():
        probs_positive = bert_mlm_positive(**sentence_ix).logits.softmax(dim=-1)[0]
        probs_negative = bert_mlm_negative(**sentence_ix).logits.softmax(dim=-1)[0]
    # ^-- shape is [seq_length, vocab_size]

    # Probability each model assigns to the tokens actually in the sentence.
    p_tokens_positive = probs_positive[torch.arange(length), token_ids]
    p_tokens_negative = probs_negative[torch.arange(length), token_ids]
    ratio = (p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)
    ratio = ratio[1:-1].detach().numpy()  # do not change [CLS] and [SEP]
    # ratio indices are offset by 1 relative to input_ids ([CLS] is 0th)

    replacements = []
    # positions selected for replacement (num_tokens smallest ratios)
    ind = np.argpartition(-ratio, -num_tokens)[-num_tokens:]
    # encode once; copied per candidate below instead of re-tokenizing
    base_tokens = tokenizer.encode(sentence)
    for i in ind:
        # BUG FIX: original read probs_positive[ind + 1, :][0], which indexed
        # with the whole `ind` array and then took row 0 — every iteration
        # used the FIRST selected position's distribution. We need THIS
        # position's row, shifted by 1 for [CLS].
        tokens_probs = probs_positive[i + 1].detach().numpy()
        prob_ind_top_k = np.argpartition(tokens_probs, -k_best)[-k_best:]
        for new_token in prob_ind_top_k:
            new_tokens = list(base_tokens)
            new_tokens[i + 1] = new_token
            replacements.append(tokenizer.convert_tokens_to_string(
                tokenizer.convert_ids_to_tokens(new_tokens)[1:-1]))
    return replacements
def get_sent_score(sentence):
    """Return the classifier's negativity score for *sentence*.

    :param sentence: raw input string
    :return: logit of class 1 ("negative") as a numpy scalar — lower is better
    """
    sentence_ix = tokenizer(sentence, return_tensors="pt")
    # negative is class 1
    return bert_classifier(**sentence_ix).logits[0][1].detach().numpy()
if user_input.split():
    st.markdown(f"Original sentence negativity: {get_sent_score(user_input)}")

    # Beam search over token replacements:
    num_iter = 5     # rewrite iterations
    M = 3            # beam width: candidates kept each iteration
    num_tokens = 3   # token positions replaced per candidate
    k_best = 3       # replacement words tried per position

    fix_list = [user_input]
    for _ in range(num_iter):
        candidates = []
        for cur_sent in fix_list:
            candidates.extend(
                get_replacements(cur_sent, num_tokens=num_tokens, k_best=k_best))
        scored = pd.DataFrame(candidates, columns=['new_sentence'])
        scored['new_scores'] = scored['new_sentence'].apply(get_sent_score)
        # keep the M least-negative candidates as the next beam
        fix_list = scored.nsmallest(M, 'new_scores').new_sentence.to_list()

    for new_sentence in fix_list:
        # FIX: originals were f-strings with no placeholders
        st.markdown("New sentence:")
        st.markdown(new_sentence)
        st.markdown(f"New sentence negativity: {get_sent_score(new_sentence)}")