# GPT2Mask / app.py
import streamlit as st
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
st.title('GPT2:')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@st.cache(allow_output_mutation=True)
def get_model():
    # Load the tokenizer and model once and reuse them across Streamlit reruns.
    tokenizer = AutoTokenizer.from_pretrained("BigSalmon/MASKGPT2")
    model = AutoModelForCausalLM.from_pretrained("BigSalmon/InformalToFormalLincoln80Paraphrase")
    #model = AutoModelForCausalLM.from_pretrained("BigSalmon/MASKGPT2")
    return model, tokenizer

model, tokenizer = get_model()
model = model.to(device)  # keep the model on the same device as the inputs below
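
# Few-shot prompt pre-filled in the text box: one completed "original: ... / infill: ..."
# pair, followed by a fresh "original:" line for the user's own sentence. The model's
# most likely next tokens (displayed after submission) serve as candidate fillers for
# the [MASK] / infill span.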
g = """***
original: sports teams are profitable for owners. [MASK], their valuations experience a dramatic uptick.
infill: sports teams are profitable for owners. ( accumulating vast sums / stockpiling treasure / realizing benefits / cashing in / registering robust financials / scoring on balance sheets ), their valuations experience a dramatic uptick.
***
original:"""
with st.form(key='my_form'):
    prompt = st.text_area(label='Enter sentence', value=g)
    submit_button = st.form_submit_button(label='Submit')
    if submit_button:
        with torch.no_grad():
            # Encode the prompt and run a single forward pass.
            text = tokenizer.encode(prompt)
            myinput, past_key_values = torch.tensor([text]), None
            myinput = myinput.to(device)
            logits, past_key_values = model(myinput, past_key_values=past_key_values, return_dict=False)
            # Keep only the distribution over the next token after the prompt.
            logits = logits[0, -1]
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            # Take the 300 most likely next tokens as candidate infills.
            best_logits, best_indices = logits.topk(300)
            best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
            text.append(best_indices[0].item())
            best_probabilities = probabilities[best_indices].tolist()
            st.write(best_words)
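            # Not part of the original app: a minimal sketch (assuming best_words and
            # best_probabilities as computed above) that pairs each candidate token with
            # its probability so the ranking is easier to read.
            ranked = pd.DataFrame({"token": best_words, "probability": best_probabilities})
            st.dataframe(ranked)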