BigSalmon commited on
Commit
b743d9d
·
1 Parent(s): dcb539c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -33,21 +33,34 @@ infill: sports teams are profitable for owners. ( accumulating vast sums / stock
33
 
34
  original:"""
35
 
 
 
 
 
 
 
 
 
 
 
36
  with st.form(key='my_form'):
37
  prompt = st.text_area(label='Enter sentence', value=g)
38
  submit_button = st.form_submit_button(label='Submit')
 
39
  if submit_button:
40
  with torch.no_grad():
41
  text = tokenizer.encode(prompt)
42
  myinput, past_key_values = torch.tensor([text]), None
43
  myinput = myinput
44
- myinput= myinput.to(device)
45
  logits, past_key_values = model(myinput, past_key_values = past_key_values, return_dict=False)
46
  logits = logits[0,-1]
47
  probabilities = torch.nn.functional.softmax(logits)
48
- best_logits, best_indices = logits.topk(300)
49
  best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
50
  text.append(best_indices[0].item())
51
  best_probabilities = probabilities[best_indices].tolist()
52
  words = []
53
- st.write(best_words)
 
 
 
33
 
34
  original:"""
35
 
36
+ def prefix_format(sentence):
37
+ words = sentence.split()
38
+ if "[MASK]" in sentence:
39
+ words2 = words.index("[MASK]")
40
+ #print(words2)
41
+ output = ("<|SUF|> " + ' '.join(words[words2+1:]) + " <|PRE|> " + ' '.join(words[:words2]) + " <|MID|>")
42
+ st.write(output)
43
+ else:
44
+ st.write("Add [MASK] to sentence")
45
+
46
  with st.form(key='my_form'):
47
  prompt = st.text_area(label='Enter sentence', value=g)
48
  submit_button = st.form_submit_button(label='Submit')
49
+ submit_button6 = st.form_submit_button(label='Turn Into Infill Format. Just add [MASK] where you want it infilled')
50
  if submit_button:
51
  with torch.no_grad():
52
  text = tokenizer.encode(prompt)
53
  myinput, past_key_values = torch.tensor([text]), None
54
  myinput = myinput
55
+ myinput= myinput
56
  logits, past_key_values = model(myinput, past_key_values = past_key_values, return_dict=False)
57
  logits = logits[0,-1]
58
  probabilities = torch.nn.functional.softmax(logits)
59
+ best_logits, best_indices = logits.topk(250)
60
  best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
61
  text.append(best_indices[0].item())
62
  best_probabilities = probabilities[best_indices].tolist()
63
  words = []
64
+ st.write(best_words)
65
+ if submit_button6:
66
+ prefix_format(prompt)