BigSalmon commited on
Commit
05ec8c2
·
1 Parent(s): 07bc967

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import pandas as pd
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers.activations import get_activation
8
+ from transformers import AutoTokenizer, AutoModelWithLMHead
9
+
10
+
11
+ st.title('GPT2:')
12
+
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+
15
+ @st.cache(allow_output_mutation=True)
16
+ def get_model():
17
+ tokenizer = AutoTokenizer.from_pretrained("BigSalmon/MASKGPT2")
18
+ model = AutoModelWithLMHead.from_pretrained("BigSalmon/MASKGPT2")
19
+ return model, tokenizer
20
+
21
+ model, tokenizer = get_model()
22
+
23
+ g = """
24
+ ***
25
+
26
+ original: sports teams are profitable for owners. [MASK], their valuations experience a dramatic uptick.
27
+ infill: sports teams are profitable for owners. ( accumulating vast sums / stockpiling treasure / realizing benefits / cashing in / registering robust financials / scoring on balance sheets ), their valuations experience a dramatic uptick.
28
+
29
+ ***
30
+
31
+ original:"""
32
+
33
+ with st.form(key='my_form'):
34
+ prompt = st.text_area(label='Enter sentence', value=g)
35
+ submit_button = st.form_submit_button(label='Submit')
36
+
37
+ if submit_button:
38
+ with torch.no_grad():
39
+ text = tokenizer.encode(prompt)
40
+ myinput, past_key_values = torch.tensor([text]), None
41
+ myinput = myinput
42
+ myinput= myinput.to(device)
43
+ logits, past_key_values = model(myinput, past_key_values = past_key_values, return_dict=False)
44
+ logits = logits[0,-1]
45
+ probabilities = torch.nn.functional.softmax(logits)
46
+ best_logits, best_indices = logits.topk(300)
47
+ best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
48
+ text.append(best_indices[0].item())
49
+ best_probabilities = probabilities[best_indices].tolist()
50
+ words = []
51
+ st.write(best_words)