|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
import os |
|
import torch |
|
import torch.nn as nn |
|
from transformers.activations import get_activation |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from transformers import AutoTokenizer, AutoModel |
|
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast |
|
|
|
|
|
# Page header rendered at the top of the Streamlit app.
st.title('GPT2:')

# Select the GPU when PyTorch reports one as available, otherwise the CPU.
# NOTE(review): `device` is not referenced by the code visible here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
# `st.cache` is deprecated in newer Streamlit releases; `st.cache_resource` is
# its replacement for shared, unhashable objects such as models. Prefer the
# new API when this Streamlit version has it and fall back to the old one
# otherwise, so the app keeps working on both.
_cache_model_loader = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_model_loader
def get_model(model_name="BigSalmon/FamilyFeud"):
    """Load a causal language model and its tokenizer, cached across reruns.

    Args:
        model_name: Hugging Face Hub id (or local path) passed to both
            ``from_pretrained`` calls. Defaults to the checkpoint this app
            was written for.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
|
|
|
# Load the model/tokenizer pair used by the form below; `get_model` is
# wrapped in a Streamlit cache decorator, so reruns reuse the loaded objects.
model, tokenizer = get_model()

# Default contents of the prompt text area. It holds one worked example (an
# "original:" sentence containing [MASK], then an "infill:" line listing
# candidate fillers as "( a / b / ... )") and ends with an open "original:"
# line for the user to continue.
# NOTE(review): this text is fed to the model verbatim, so its exact
# whitespace is part of the prompt; edit with care.
g = """***



original: sports teams are profitable for owners. [MASK], their valuations experience a dramatic uptick.

infill: sports teams are profitable for owners. ( accumulating vast sums / stockpiling treasure / realizing benefits / cashing in / registering robust financials / scoring on balance sheets ), their valuations experience a dramatic uptick.



***



original:"""
|
|
|
def prefix_format(sentence):
    """Write *sentence* to the page rearranged into the infill prompt layout.

    The text after the first ``[MASK]`` becomes the suffix and the text before
    it the prefix, written as ``<|SUF|> suffix <|PRE|> prefix <|MID|>``.
    Both parts have their whitespace collapsed to single spaces.
    NOTE(review): presumably this is the suffix/prefix/middle format the
    model was trained on; confirm against the model card.

    If *sentence* contains no ``[MASK]``, a hint is written instead.

    Args:
        sentence: Text entered by the user.

    Returns:
        None. Output goes to the page via ``st.write``.
    """
    mask = "[MASK]"
    if mask not in sentence:
        st.write("Add [MASK] to sentence")
        return

    # Split on the marker itself instead of looking it up as a whole word.
    # The old `split()` + `list.index("[MASK]")` approach raised ValueError
    # when the marker was glued to punctuation (e.g. "[MASK]," in the
    # default prompt), even though the substring check above had passed.
    before, _, after = sentence.partition(mask)

    # Re-join with single spaces to keep the original whitespace normalisation.
    prefix = " ".join(before.split())
    suffix = " ".join(after.split())
    st.write(f"<|SUF|> {suffix} <|PRE|> {prefix} <|MID|>")
|
|
|
def _top_next_tokens(prompt, k=250):
    """Return the decoded *k* most likely next tokens after *prompt*, best first.

    Uses the module-level ``model`` and ``tokenizer``. Returns an empty list
    when *prompt* encodes to no tokens, because an empty sequence cannot be
    fed to the model.
    """
    token_ids = tokenizer.encode(prompt)
    if not token_ids:
        return []
    # Build the batch on whatever device the model lives on, so this works
    # whether or not the model has been moved to a GPU. An explicit long dtype
    # is required for embedding lookups.
    input_ids = torch.tensor([token_ids], dtype=torch.long, device=model.device)
    with torch.no_grad():
        # `return_dict=True` + `.logits` does not depend on whether the model
        # also returns a key/value cache (the old tuple unpack did).
        logits = model(input_ids, return_dict=True).logits[0, -1]
    # Never ask topk for more entries than the vocabulary has.
    top_ids = logits.topk(min(k, logits.size(-1))).indices
    return [tokenizer.decode([token_id]) for token_id in top_ids.tolist()]


# Input form: "Submit" shows the most likely next tokens for the prompt;
# the second button shows the prompt rearranged into the infill layout.
with st.form(key='my_form'):
    prompt = st.text_area(label='Enter sentence', value=g)
    submit_button = st.form_submit_button(label='Submit')
    submit_button6 = st.form_submit_button(label='Turn Into Infill Format. Just add [MASK] where you want it infilled')

    if submit_button:
        best_words = _top_next_tokens(prompt)
        if best_words:
            st.write(best_words)
        else:
            st.write("Enter some text first.")

    if submit_button6:
        prefix_format(prompt)