Commit
·
f264b44
1
Parent(s):
1c53eb1
Update app.py
Browse files
app.py
CHANGED
@@ -44,7 +44,8 @@ def get_all_predictions(text_sentence, top_clean=5):
|
|
44 |
with torch.no_grad():
|
45 |
predict = bert_model(input_ids)[0]
|
46 |
bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)
|
47 |
-
|
|
|
48 |
|
49 |
def get_bert_prediction(input_text,top_k):
|
50 |
try:
|
|
|
44 |
with torch.no_grad():
|
45 |
predict = bert_model(input_ids)[0]
|
46 |
bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)
|
47 |
+
cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k).indices.tolist(), top_clean)
|
48 |
+
return {'bert': bert,'[CLS]':cls}
|
49 |
|
50 |
def get_bert_prediction(input_text,top_k):
|
51 |
try:
|