import time
import streamlit as st
import torch
import string


from transformers import BertTokenizer, BertForMaskedLM
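
# Streamlit app for qualitatively probing pretrained BERT masked-language models:
# given a sentence containing [MASK], it shows the model's top predictions for the
# masked position as well as for the [CLS] position, with no fine-tuning.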

st.set_page_config(page_title='Qualitative pretrained model evaluation', page_icon=None, layout='centered', initial_sidebar_state='auto')

# Caching avoids re-downloading and re-loading the model on every interaction.
# allow_output_mutation=True skips Streamlit's re-hashing of the large model object.
@st.cache(allow_output_mutation=True)
def load_bert_model(model_name):
  # Any download/loading error propagates to the caller, which reports it in the UI.
  bert_tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
  bert_model = BertForMaskedLM.from_pretrained(model_name).eval()
  return bert_tokenizer, bert_model

# Turn predicted token ids into a newline-separated list of cleaned-up tokens.
def decode(tokenizer, pred_idx, top_clean):
  # Use a set so the membership test matches whole tokens, not substrings.
  ignore_tokens = set(string.punctuation) | {'[PAD]'}
  tokens = []
  for w in pred_idx:
    token = ''.join(tokenizer.decode(w).split())
    if token not in ignore_tokens and len(token) > 1 and not token.startswith('.') and not token.startswith('['):
      tokens.append(token)
  return '\n'.join(tokens[:top_clean])

# Build the input ids for the sentence and locate the [MASK] position
# (falls back to position 0, i.e. [CLS], when no mask is present).
def encode(tokenizer, text_sentence, add_special_tokens=True):
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
  # If the mask is the last token, append a "." so the model doesn't just predict punctuation.
  if tokenizer.mask_token == text_sentence.split()[-1]:
    text_sentence += ' .'

  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
  if tokenizer.mask_token in text_sentence.split():
    mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
  else:
    mask_idx = 0
  return input_ids, mask_idx

# Run the masked-LM head and decode the top candidates for both the masked
# position and the [CLS] position. bert_tokenizer, bert_model and top_k are
# globals set by the Streamlit script below.
def get_all_predictions(text_sentence, top_clean=5):
  input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
  with torch.no_grad():
    predict = bert_model(input_ids)[0]
  # Over-fetch top_k * 5 candidates so enough remain after filtering in decode().
  bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k * 5).indices.tolist(), top_clean)
  cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k * 5).indices.tolist(), top_clean)
  return {'Masked position': bert, '[CLS]': cls}

def get_bert_prediction(input_text, top_k):
  # Let errors propagate so the Submit handler below can display them.
  return get_all_predictions(input_text, top_clean=int(top_k))
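
# Page header, usage instructions and sidebar controls.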

st.markdown("<h3 style='text-align: center;'>Qualitative evaluation of Pretrained BERT models</h3>", unsafe_allow_html=True)
st.markdown("""
        <small style="font-size:18px; color: #8f8f8f">This app is used to qualitatively examine the performance of pretrained models to do NER , <a href="https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html"><b>with no fine tuning</b></small></a>
        """, unsafe_allow_html=True)
  #st.write("https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html")
st.write("The neighborhood of CLS vectors as well as the model prediction for a blank position are examined")
st.write("To examine model prediction for a position, enter the token [MASK] or <mask>")
st.write("To examine just the [CLS] vector, enter a word/phrase or sentence. Example: eGFR or EGFR or non small cell lung cancer")
top_k = st.sidebar.slider("Select how many predictions do you need", 1 , 50, 20) #some times it is possible to have less words
print(top_k)
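
# Main flow: load the selected model, read the input text, and show predictions on Submit.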

  
start = None
try:
  model_name = st.sidebar.selectbox(label='Select Model to Apply', options=['ajitrajasekharan/biomedical', 'bert-base-cased', 'bert-large-cased', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext', 'allenai/scibert_scivocab_cased'], index=0, key="model_name")
  bert_tokenizer, bert_model = load_bert_model(model_name)
  default_text = "Imatinib is used to [MASK] nsclc"
  input_text = st.text_area(
    label="Enter text below",
    value=default_text,
  )
  if st.button("Submit"):
    with st.spinner("Computing"):
      start = time.time()
      try:
        res = get_bert_prediction(input_text, top_k)
        st.caption("Results in JSON")
        st.json(res)
      except Exception as e:
        st.error("Some error occurred during prediction: " + str(e))
        st.stop()

  if start is not None:
    st.text(f"prediction took {time.time() - start:.2f}s")

except Exception as e:
  st.error("Some error occurred during loading: " + str(e))
  st.stop()

st.write("---")