File size: 7,843 Bytes
b7137b1
 
 
 
afd47cf
a03cf87
 
6f5d2d2
d1357c0
167f69d
b7137b1
 
3b3fa96
 
b7137b1
 
 
6f4ba26
 
b7137b1
 
 
 
 
 
 
 
 
b9fa3c7
b7137b1
 
 
08b9f95
ecce248
 
b7137b1
 
 
 
 
1f4b5d0
 
b7137b1
1f4b5d0
1c53eb1
c6d5fcb
 
 
 
1f4b5d0
b7137b1
f8dc81b
b7137b1
1f4b5d0
 
b7137b1
 
08b9f95
 
293e817
 
2406613
b9fa3c7
2406613
b7137b1
f8dc81b
b7137b1
1c53eb1
f8dc81b
b7137b1
 
 
b9f419a
 
f8dc81b
b9f419a
a03cf87
 
167f69d
 
b9f419a
 
 
f8dc81b
b9f419a
 
 
 
 
 
 
 
3f2b07b
7190b6a
d549833
7190b6a
d549833
3f2b07b
7190b6a
d549833
7190b6a
d549833
caf6c21
815063d
4b91aef
6f5d2d2
57cbad4
caf6c21
c273b5f
49400fb
 
 
caf6c21
 
 
3f2b07b
0e5769d
49400fb
 
 
0e5769d
caf6c21
 
b9f419a
77d733c
d1357c0
181a8b0
7190b6a
6f5d2d2
 
b7137b1
6f5d2d2
0a31bde
a340b6b
 
de7c5ee
a340b6b
6f5d2d2
 
5ce6476
b7137b1
 
6f5d2d2
 
 
0a31bde
6f5d2d2
167f69d
b7137b1
b9f419a
3b3fa96
b7137b1
3b3fa96
6f5d2d2
0a31bde
d1357c0
 
 
 
 
 
 
 
3f2b07b
 
 
 
 
b27f63f
 
d57b0b0
fde541a
b9f419a
b7137b1
6f5d2d2
 
 
b7137b1
6f5d2d2
 
 
316bafd
 
6f5d2d2
b7137b1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import time
import streamlit as st
import torch
import string

# Tokenizer/model pair; None until load_bert_model() has been called.
bert_tokenizer = bert_model = None
# Initial value of the "how many predictions" sidebar slider.
top_k = 20
# Default model identifier passed to from_pretrained().
model_name = 'ajitrajasekharan/biomedical'

from transformers import BertTokenizer, BertForMaskedLM

# Browser-tab title and page layout for the app.
st.set_page_config(page_title='Qualitative pretrained model evaluation', page_icon=None, layout='centered', initial_sidebar_state='auto')

@st.cache(allow_output_mutation=True)
def load_bert_model(model_name):
  """Load a BERT tokenizer and masked-LM model by name.

  The result is cached by Streamlit, so repeated calls with the same name
  reuse the loaded objects. ``allow_output_mutation=True`` tells st.cache not
  to hash the returned (large) model object on every call.

  Args:
    model_name: identifier passed to ``from_pretrained``.

  Returns:
    ``(tokenizer, model)`` with the model switched to eval mode.

  Raises:
    Whatever ``from_pretrained`` raises when the model cannot be loaded. The
    error is propagated instead of returning None, because callers unpack the
    result into two names and a None return only produced an unrelated
    "cannot unpack" TypeError.
  """
  tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
  model = BertForMaskedLM.from_pretrained(model_name).eval()
  return tokenizer, model



  
def decode(tokenizer, pred_idx, top_clean):
  """Turn predicted token ids into a newline-separated list of readable tokens.

  Each id is decoded and its internal whitespace removed. The result is
  dropped when it is one character or shorter, starts with '.' or '[', or
  occurs as a substring of ``string.punctuation``. At most ``top_clean``
  survivors are returned, in the order of ``pred_idx``.
  """
  punctuation = string.punctuation

  def _keep(candidate):
    # `punctuation` is a str, so `in` here is a substring test.
    if len(candidate) <= 1 or candidate in punctuation:
      return False
    return not candidate.startswith(('.', '['))

  cleaned = ("".join(tokenizer.decode(idx).split()) for idx in pred_idx)
  kept = [candidate for candidate in cleaned if _keep(candidate)]
  return "\n".join(kept[:top_clean])

def encode(tokenizer, text_sentence, add_special_tokens=True):
  """Tokenize *text_sentence* for the masked-LM model.

  ``<mask>`` in the input is accepted as an alias for the tokenizer's own
  mask token.

  Args:
    tokenizer: object providing ``mask_token``, ``mask_token_id``,
      ``tokenize`` and ``encode``.
    text_sentence: raw user input.
    add_special_tokens: forwarded to ``tokenizer.encode``.

  Returns:
    ``(input_ids, mask_idx, tokenized_text)``: a ``(1, seq_len)`` tensor of
    token ids, the index of the first mask token within that sequence (0 when
    the sequence contains none), and the output of ``tokenizer.tokenize``.
  """
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
  # Use the tokenizer that was passed in (not a module global) so the tokens
  # reported match the ids actually fed to the model.
  tokenized_text = tokenizer.tokenize(text_sentence)
  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
  # Find the mask in the encoded ids rather than in text.split(): a mask glued
  # to punctuation ("[MASK]." / "[MASK]'s") is not a separate whitespace word,
  # but it is still encoded as mask_token_id.
  mask_positions = torch.where(input_ids[0] == tokenizer.mask_token_id)[0]
  mask_idx = int(mask_positions[0]) if mask_positions.numel() > 0 else 0
  return input_ids, mask_idx, tokenized_text

def get_all_predictions(text_sentence, model_name,top_clean=5):
  """Run the loaded BERT model on *text_sentence* and summarize its predictions.

  Returns a dict with the input, its tokenization, ``results_count`` (the
  module-level ``top_k``), the model name, and up to ``top_clean`` cleaned
  tokens predicted at the [CLS] position. When the raw input contains
  ``[MASK]`` or ``<mask>``, a 'Masked position' entry is added (before
  '[CLS]') holding the predictions at the mask index found by encode().
  """
  input_ids, mask_idx, tokenized_text = encode(bert_tokenizer, text_sentence)
  with torch.no_grad():
    logits = bert_model(input_ids)[0]
  # Over-fetch candidates because decode() filters out short/punctuation/
  # bracketed tokens before keeping the first top_clean.
  candidate_count = top_k * 5

  def _top_tokens(position):
    ranked = logits[0, position, :].topk(candidate_count).indices.tolist()
    return decode(bert_tokenizer, ranked, top_clean)

  summary = {
      'Input sentence': text_sentence,
      'Tokenized text': tokenized_text,
      'results_count': top_k,
      'Model': model_name,
  }
  if "[MASK]" in text_sentence or "<mask>" in text_sentence:
    summary['Masked position'] = _top_tokens(mask_idx)
  summary['[CLS]'] = _top_tokens(0)
  return summary

def get_bert_prediction(input_text,top_k,model_name):
  """Return the prediction summary for *input_text*, keeping *top_k* tokens per position.

  Exceptions are deliberately not caught here. Swallowing them made this
  function return None, so the caller rendered ``null`` and its own error
  handler (which reports the failure to the user) never ran.
  """
  return get_all_predictions(input_text, model_name, top_clean=int(top_k))
    
 
def run_test(sent,top_k,model_name):
  """Predict for *sent* and render the JSON result plus the elapsed time.

  Loads *model_name* first when no model has been loaded yet. Any prediction
  error is shown with st.error and the script run is stopped.
  """
  global bert_tokenizer
  global bert_model
  if bert_tokenizer is None:
    bert_tokenizer, bert_model = load_bert_model(model_name)
  with st.spinner("Computing"):
    start = time.time()
    try:
      res = get_bert_prediction(sent, top_k, model_name)
      st.caption("Results in JSON")
      st.json(res)
    except Exception as e:
      # Separate the fixed prefix from the exception text.
      st.error("Some error occurred during prediction: " + str(e))
      st.stop()
  st.text(f"prediction took {time.time() - start:.2f}s")
    
def _current_settings():
  """Return ``(top_k, model_name)`` as currently selected in the UI.

  Every interaction re-runs this script, which re-executes the module-level
  ``top_k = ...`` / ``model_name = ...`` assignments. Values stored in those
  globals by an earlier callback are therefore gone by the next callback.
  Widget values and explicit st.session_state entries survive reruns, so
  read those first and fall back to the module globals.
  """
  count = int(st.session_state.get('my_slider', top_k))
  name = st.session_state.get('selected_model', model_name)
  return count, name


def _run_with_current_settings(text):
  """Load the currently selected model (cached) and run the prediction on *text*."""
  global bert_tokenizer, bert_model
  count, name = _current_settings()
  try:
    # load_bert_model is cached, so this is cheap once the model is loaded.
    bert_tokenizer, bert_model = load_bert_model(name)
  except Exception as e:
    st.error("Some error occurred during loading: " + str(e))
    return
  run_test(text, count, name)


def on_text_change():
  """Callback for the free-text box: predict on the typed text."""
  _run_with_current_settings(st.session_state.my_text)


def on_option_change():
  """Callback for the example-sentence picker: predict on the chosen sentence."""
  _run_with_current_settings(st.session_state.my_choice)


def on_results_count_change():
  """Callback for the results-count slider: record and announce the new count."""
  global top_k
  top_k = int(st.session_state.my_slider)
  st.info("Results count changed " + str(top_k))


def _select_model(name, message):
  """Load model *name* and remember it as the active model across reruns."""
  global model_name
  global bert_tokenizer
  global bert_model
  name = name.strip()
  if not name:
    # Clearing the custom-model box should not try to load a model named ''.
    return
  st.info(message + name)
  try:
    bert_tokenizer, bert_model = load_bert_model(name)
  except Exception as e:
    st.error("Some error occurred during loading: " + str(e))
    return
  model_name = name
  # Module globals are reset on rerun; session_state keeps the choice.
  st.session_state['selected_model'] = name


def on_model_change1():
  """Callback for the sidebar model selectbox."""
  _select_model(st.session_state.my_model1, "Pre-selected model chosen: ")


def on_model_change2():
  """Callback for the custom model-name text box."""
  _select_model(st.session_state.my_model2, "Custom model chosen: ")
  
# Example inputs for the sentence picker. The first entry is '', so with the
# selectbox's default index nothing is pre-selected.
SAMPLE_SENTENCES = (
    '',
    "[MASK] who lives in New York and works for XCorp suffers from Parkinson's",
    "Lou Gehrig who lives in [MASK] and works for XCorp suffers from Parkinson's",
    "Lou Gehrig who lives in New York and works for [MASK] suffers from Parkinson's",
    "Lou Gehrig who lives in New York and works for XCorp suffers from [MASK]",
    "[MASK] who lives in New York and works for XCorp suffers from Lou Gehrig's",
    "Parkinson who lives in [MASK] and works for XCorp suffers from Lou Gehrig's",
    "Parkinson who lives in New York and works for [MASK] suffers from Lou Gehrig's",
    "Parkinson who lives in New York and works for XCorp suffers from [MASK]",
    "Lou Gehrig",
    "Parkinson",
    "Lou Gehrig's is a [MASK]",
    "Parkinson is a [MASK]",
    "New York is a [MASK]",
    "New York",
    "XCorp",
    "XCorp is a [MASK]",
    "acute lymphoblastic leukemia",
    "acute lymphoblastic leukemia is a [MASK]",
)


def init_selectbox():
  """Render the example-sentence picker; selecting an entry triggers on_option_change."""
  st.selectbox(
     'Choose any of these sentences or type any text below',
     SAMPLE_SENTENCES, on_change=on_option_change, key='my_choice')
  


def main():
  """Lay out the page: intro text, sidebar controls and the input widgets.

  Predictions are produced by the widget callbacks (on_option_change /
  on_text_change). main itself only renders the page and makes sure a model
  has been loaded.
  """
  global bert_tokenizer
  global bert_model

  st.markdown("<h3 style='text-align: center;'>Qualitative evaluation of  any pretrained BERT model</h3>", unsafe_allow_html=True)
  st.markdown("""
        <small style="font-size:18px; color: #7f7f7f">Pretrained BERT models can be used as is, <a href="https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html"><b>with no fine tuning to perform tasks like NER</b></a> <i>ideally if both fill-mask and CLS predictions are good, or minimally if fill-mask predictions are adequate</i></small>
        """, unsafe_allow_html=True)
  st.write("This app can be used to examine both model prediction for a masked  position as well as the neighborhood of CLS vector")
  st.write("   - To examine model prediction for a position, enter the token [MASK] or <mask>")
  st.write("   - To examine just the [CLS] vector, enter a word/phrase or sentence. Example: eGFR or EGFR or non small cell lung cancer")
  # decode() filters tokens, so fewer results than requested may be shown.
  st.sidebar.slider("Select how many predictions do you need", 1 , 50, top_k,key='my_slider',on_change=on_results_count_change)

  try:
    st.sidebar.selectbox(label='Select Model to Apply',  options=['ajitrajasekharan/biomedical', 'bert-base-cased','bert-large-cased','microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext','allenai/scibert_scivocab_cased','dmis-lab/biobert-v1.1'], index=0,  key = "my_model1",on_change=on_model_change1)
    init_selectbox()
    st.text_input("Enter text below", "",on_change=on_text_change,key='my_text')
    st.text_input("Model not listed on left? Type the model name (fill-mask BERT models only)", "",key="my_model2",on_change=on_model_change2)
    if bert_tokenizer is None:
      bert_tokenizer, bert_model = load_bert_model(model_name)
  except Exception as e:
    # Separate the fixed prefix from the exception text.
    st.error("Some error occurred during loading: " + str(e))
    st.stop()

  st.write("---")
  
 

if __name__ == "__main__":
  # Script entry point: build the page.
  main()