# GPT2Mask / app.py
# BigSalmon's picture
# Update app.py
# bf92c11
# raw
# history blame
# 2.72 kB
import streamlit as st
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
from transformers.activations import get_activation
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModel
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
# Page heading for the Streamlit app.
st.title('GPT2:')
# Pick the GPU when PyTorch reports one, otherwise stay on the CPU.
# NOTE(review): nothing visible in this file moves the model or its inputs
# onto this device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
@st.cache(allow_output_mutation=True)
def get_model():
    """Load the causal language model and its tokenizer.

    Both are fetched from the ``BigSalmon/FamilyFeud`` checkpoint.  The
    ``st.cache`` decorator keeps the loaded objects between script reruns;
    ``allow_output_mutation=True`` tells Streamlit not to guard the cached
    objects against mutation.

    Returns:
        A ``(model, tokenizer)`` tuple, in that order.
    """
    # Other checkpoints that were left commented out here:
    #   BigSalmon/MASKGPT2 (AutoTokenizer / AutoModelForCausalLM)
    #   CarperAI/FIM-NeoX-1.3B tokenizer with
    #   BigSalmon/FormalInformalConcise-FIM-NeoX-1.3B (GPTNeoX classes)
    checkpoint = "BigSalmon/FamilyFeud"
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return loaded_model, loaded_tokenizer
# Fetch the model/tokenizer pair; get_model is wrapped in st.cache, so the
# loaded objects are reused across Streamlit reruns of this script.
model, tokenizer = get_model()
# Default text for the prompt box: one original/infill example pair followed
# by an open "original:" line.
g = "\n".join(
    [
        "***",
        "original: sports teams are profitable for owners. [MASK], their valuations experience a dramatic uptick.",
        "infill: sports teams are profitable for owners. ( accumulating vast sums / stockpiling treasure / realizing benefits / cashing in / registering robust financials / scoring on balance sheets ), their valuations experience a dramatic uptick.",
        "***",
        "original:",
    ]
)
def prefix_format(sentence):
    """Rewrite a ``[MASK]`` sentence as an infill prompt and display it.

    The text after the first ``[MASK]`` becomes the suffix and the text
    before it the prefix.  Both are whitespace-normalised and written to the
    page as ``<|SUF|> suffix <|PRE|> prefix <|MID|>``.  When the sentence
    contains no ``[MASK]``, a hint is written instead.

    Args:
        sentence: Raw text taken from the prompt box.

    Returns:
        None.  The result is only rendered through ``st.write``.
    """
    # Split on the marker itself rather than looking it up among the
    # whitespace-separated words: a marker attached to punctuation (for
    # example "[MASK],") is not a standalone word, so a list.index lookup
    # would raise ValueError even though the substring check had passed.
    before, marker, after = sentence.partition("[MASK]")
    if not marker:
        st.write("Add [MASK] to sentence")
        return

    # split()/join collapses runs of whitespace and newlines to single spaces.
    suffix = ' '.join(after.split())
    prefix = ' '.join(before.split())
    st.write("<|SUF|> " + suffix + " <|PRE|> " + prefix + " <|MID|>")
# Prompt form with two actions: "Submit" lists the most likely next tokens for
# the prompt, the second button converts a "[MASK]" sentence into the infill
# prompt layout.
with st.form(key='my_form'):
    prompt = st.text_area(label='Enter sentence', value=g)
    submit_button = st.form_submit_button(label='Submit')
    submit_button6 = st.form_submit_button(label='Turn Into Infill Format. Just add [MASK] where you want it infilled')

if submit_button:
    text = tokenizer.encode(prompt)
    if not text:
        # An empty token list would become an empty float tensor, which the
        # model cannot consume; ask for input instead of failing.
        st.write("Enter a sentence first")
    else:
        with torch.no_grad():
            myinput = torch.tensor([text])
            logits, past_key_values = model(myinput, past_key_values=None, return_dict=False)
            # Keep only the scores for the token following the prompt.
            logits = logits[0, -1]
            # Explicit dim: the implicit-dimension form of softmax is deprecated.
            probabilities = torch.nn.functional.softmax(logits, dim=-1)
            # Never ask topk for more entries than the vocabulary holds.
            top_k = min(250, logits.size(-1))
            best_logits, best_indices = logits.topk(top_k)
            best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
            # Probabilities aligned with best_words; computed but not displayed.
            best_probabilities = probabilities[best_indices].tolist()
        st.write(best_words)

if submit_button6:
    prefix_format(prompt)