Commit
·
08b9f95
1
Parent(s):
6f4ba26
Update app.py
Browse files
app.py
CHANGED
@@ -25,7 +25,7 @@ def decode(tokenizer, pred_idx, top_clean):
|
|
25 |
tokens = []
|
26 |
for w in pred_idx:
|
27 |
token = ''.join(tokenizer.decode(w).split())
|
28 |
-
if token not in ignore_tokens:
|
29 |
#tokens.append(token.replace('##', ''))
|
30 |
tokens.append(token)
|
31 |
return '\n'.join(tokens[:top_clean])
|
@@ -48,8 +48,8 @@ def get_all_predictions(text_sentence, top_clean=5):
|
|
48 |
input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
|
49 |
with torch.no_grad():
|
50 |
predict = bert_model(input_ids)[0]
|
51 |
-
bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*
|
52 |
-
cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*
|
53 |
return {'bert': bert,'[CLS]':cls}
|
54 |
|
55 |
def get_bert_prediction(input_text,top_k):
|
@@ -67,7 +67,7 @@ st.markdown("""
|
|
67 |
st.write("Incomplete. Work in progress...")
|
68 |
#st.write("https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html")
|
69 |
st.write("CLS vectors as well as the model prediction for a blank position are examined")
|
70 |
-
top_k =
|
71 |
print(top_k)
|
72 |
|
73 |
|
|
|
25 |
tokens = []
|
26 |
for w in pred_idx:
|
27 |
token = ''.join(tokenizer.decode(w).split())
|
28 |
+
if token not in ignore_tokens and len(token) > 1 and not token.startswith('.') and not token.startswith('['):
|
29 |
#tokens.append(token.replace('##', ''))
|
30 |
tokens.append(token)
|
31 |
return '\n'.join(tokens[:top_clean])
|
|
|
48 |
input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
|
49 |
with torch.no_grad():
|
50 |
predict = bert_model(input_ids)[0]
|
51 |
+
bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*5).indices.tolist(), top_clean)
|
52 |
+
cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
|
53 |
return {'bert': bert,'[CLS]':cls}
|
54 |
|
55 |
def get_bert_prediction(input_text,top_k):
|
|
|
67 |
st.write("Incomplete. Work in progress...")
|
68 |
#st.write("https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html")
|
69 |
st.write("CLS vectors as well as the model prediction for a blank position are examined")
|
70 |
+
top_k = 20
|
71 |
print(top_k)
|
72 |
|
73 |
|