ikozlov90 committed
Commit 133312a · 1 Parent(s): e13f1f0

Update add.py

Files changed (1)
  1. add.py +82 -3
add.py CHANGED
@@ -1,12 +1,91 @@
  import streamlit as st
+ import numpy as np
+ import pandas as pd
  import torch

  st.markdown("Hello!")

- model, tokenizer = torch.load("address", map_location='cpu')
+ bert_mlm_positive = torch.load("bert_mlm_positive.pth", map_location='cpu')
+ bert_mlm_negative = torch.load("bert_mlm_negative.pth", map_location='cpu')
+ bert_classifier = torch.load("bert_classifier.pth", map_location='cpu')
+ tokenizer = torch.load("tokenizer.pth", map_location='cpu')

+
+ bert_mlm_positive.eval()
+ bert_mlm_negative.eval()
+ bert_classifier.eval()
+
+
  user_input = st.text_input("Please enter your thoughts:")

+ def get_replacements(sentence: str, num_tokens, k_best, epsilon=1e-3):
+     """
+     - split the sentence into tokens using the INGSOC-approved BERT tokenizer
+     - find :num_tokens: tokens with the lowest positive-to-negative probability ratio
+     - replace them with :k_best: words according to bert_mlm_positive
+     :return: a list of all possible strings (up to k_best * num_tokens)
+     """
+     sentence_ix = tokenizer(sentence, return_tensors="pt")
+     length = len(sentence_ix['input_ids'][0])
+
+     # we can't replace more tokens than we have
+     num_tokens = min(num_tokens, length - 2)
+
+     probs_positive = bert_mlm_positive(**sentence_ix).logits.softmax(dim=-1)[0]
+     probs_negative = bert_mlm_negative(**sentence_ix).logits.softmax(dim=-1)[0]
+     # ^-- shape is [seq_length, vocab_size]
+
+     # probabilities each MLM assigns to the original tokens
+     p_tokens_positive = probs_positive[torch.arange(length), sentence_ix['input_ids'][0]]
+     p_tokens_negative = probs_negative[torch.arange(length), sentence_ix['input_ids'][0]]
+
+     ratio = (p_tokens_positive + epsilon) / (p_tokens_negative + epsilon)
+     ratio = ratio[1:-1].detach().numpy()  # do not change [CLS] and [SEP]
+     # ratio len is length - 2
+
+     replacements = []
+     # indices of the num_tokens tokens with the lowest ratio
+     ind = np.argpartition(-ratio, -num_tokens)[-num_tokens:]
+     # for each token find k_best replacements
+     for i in ind:
+         # probabilities of candidate tokens at this position under bert_mlm_positive;
+         # note that we need i + 1, since [CLS] is the 0th token
+         tokens_probs = probs_positive[i + 1, :].detach().numpy()
+         prob_ind_top_k = np.argpartition(tokens_probs, -k_best)[-k_best:]
+         for new_token in prob_ind_top_k:
+             new_tokens = tokenizer.encode(sentence)
+             new_tokens[i + 1] = int(new_token)
+             replacements.append(tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(new_tokens)[1:-1]))
+
+     return replacements
+
+
+ def get_sent_score(sentence):
+     sentence_ix = tokenizer(sentence, return_tensors="pt")
+     # negative is class 1
+     return bert_classifier(**sentence_ix).logits[0][1].detach().numpy()
+
+
  if len(user_input.split()) > 0:
-     print(output)
-     st.markdown(f"{repr(model)}")
+     st.markdown(f"Original sentence negativity: {get_sent_score(user_input)}")
+
+     num_iter = 5
+     M = 3
+     num_tokens = 3
+     k_best = 3
+
+     fix_list = [user_input]
+     for j in range(num_iter):
+         replacements = []
+         for cur_sent in fix_list:
+             replacements.extend(get_replacements(cur_sent, num_tokens=num_tokens, k_best=k_best))
+         replacements = pd.DataFrame(replacements, columns=['new_sentence'])
+         replacements['new_scores'] = replacements['new_sentence'].apply(get_sent_score)
+         replacements = replacements.nsmallest(M, 'new_scores')
+         fix_list = replacements.new_sentence.to_list()
+
+     for new_sentence in fix_list:
+         st.markdown("New sentence:")
+         st.markdown(f"{new_sentence}")
+         st.markdown(f"New sentence negativity: {get_sent_score(new_sentence)}")
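The bare torch.load calls at the top of the updated add.py only return ready-to-use model and tokenizer objects if the .pth files were produced by pickling whole objects with torch.save, not by saving state_dicts. Below is a minimal sketch of how such files could be created, assuming the two MLMs are transformers BertForMaskedLM models fine-tuned elsewhere on positive and negative text, the classifier is a BertForSequenceClassification, and bert-base-uncased is the starting checkpoint; none of these names are confirmed by this commit.

    import torch
    from transformers import AutoTokenizer, BertForMaskedLM, BertForSequenceClassification

    # Assumed setup: the fine-tuning itself happens elsewhere and is not shown here.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    bert_mlm_positive = BertForMaskedLM.from_pretrained("bert-base-uncased")   # fine-tune on positive text
    bert_mlm_negative = BertForMaskedLM.from_pretrained("bert-base-uncased")   # fine-tune on negative text
    bert_classifier = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

    # Saving the whole objects (not state_dicts) is what lets add.py restore them
    # with plain torch.load(..., map_location='cpu') calls.
    torch.save(bert_mlm_positive, "bert_mlm_positive.pth")
    torch.save(bert_mlm_negative, "bert_mlm_negative.pth")
    torch.save(bert_classifier, "bert_classifier.pth")
    torch.save(tokenizer, "tokenizer.pth")

With the checkpoints in place next to the script, the app is started with: streamlit run add.py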
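The token-selection line in get_replacements relies on a small NumPy idiom: negating the ratio array and taking the last num_tokens entries of argpartition gives the indices of the smallest ratios, i.e. the positions where bert_mlm_negative likes the original token much more than bert_mlm_positive does. A quick self-contained check with toy numbers (not taken from the app):

    import numpy as np

    ratio = np.array([0.9, 0.1, 2.0, 0.4, 1.5])   # toy positive/negative probability ratios
    k = 2

    # same selection as in get_replacements
    ind = np.argpartition(-ratio, -k)[-k:]

    print(sorted(ind.tolist()))   # [1, 3] -> positions of the two smallest ratios
    print(np.sort(ratio[ind]))    # [0.1 0.4]

Within the selected group argpartition leaves the order unspecified, which is why the check sorts before printing.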