ajitrajasekharan committed on
Commit
08b9f95
·
1 Parent(s): 6f4ba26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -25,7 +25,7 @@ def decode(tokenizer, pred_idx, top_clean):
25
  tokens = []
26
  for w in pred_idx:
27
  token = ''.join(tokenizer.decode(w).split())
28
- if token not in ignore_tokens:
29
  #tokens.append(token.replace('##', ''))
30
  tokens.append(token)
31
  return '\n'.join(tokens[:top_clean])
@@ -48,8 +48,8 @@ def get_all_predictions(text_sentence, top_clean=5):
48
  input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
49
  with torch.no_grad():
50
  predict = bert_model(input_ids)[0]
51
- bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*2).indices.tolist(), top_clean)
52
- cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*2).indices.tolist(), top_clean)
53
  return {'bert': bert,'[CLS]':cls}
54
 
55
  def get_bert_prediction(input_text,top_k):
@@ -67,7 +67,7 @@ st.markdown("""
67
  st.write("Incomplete. Work in progress...")
68
  #st.write("https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html")
69
  st.write("CLS vectors as well as the model prediction for a blank position are examined")
70
- top_k = 10
71
  print(top_k)
72
 
73
 
 
25
  tokens = []
26
  for w in pred_idx:
27
  token = ''.join(tokenizer.decode(w).split())
28
+ if token not in ignore_tokens and len(token) > 1 and not token.startswith('.') and not token.startswith('['):
29
  #tokens.append(token.replace('##', ''))
30
  tokens.append(token)
31
  return '\n'.join(tokens[:top_clean])
 
48
  input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
49
  with torch.no_grad():
50
  predict = bert_model(input_ids)[0]
51
+ bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*5).indices.tolist(), top_clean)
52
+ cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
53
  return {'bert': bert,'[CLS]':cls}
54
 
55
  def get_bert_prediction(input_text,top_k):
 
67
  st.write("Incomplete. Work in progress...")
68
  #st.write("https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html")
69
  st.write("CLS vectors as well as the model prediction for a blank position are examined")
70
+ top_k = 20
71
  print(top_k)
72
 
73