ajitrajasekharan commited on
Commit
f264b44
·
1 Parent(s): 1c53eb1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -44,7 +44,8 @@ def get_all_predictions(text_sentence, top_clean=5):
44
  with torch.no_grad():
45
  predict = bert_model(input_ids)[0]
46
  bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)
47
- return {'bert': bert}
 
48
 
49
  def get_bert_prediction(input_text,top_k):
50
  try:
 
44
  with torch.no_grad():
45
  predict = bert_model(input_ids)[0]
46
  bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)
47
+ cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k).indices.tolist(), top_clean)
48
+ return {'bert': bert,'[CLS]':cls}
49
 
50
  def get_bert_prediction(input_text,top_k):
51
  try: