tokenizer / app.py
alonsosilva's picture
Add app
7780b98
import html
import random

import pandas as pd
import solara
from transformers import AutoTokenizer, AutoModelForCausalLM
# Shared GPT-2 tokenizer used for encoding, decoding and the vocabulary table.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# NOTE(review): `model` is not referenced by any code visible in this file, and
# loading it costs startup time and memory. Kept in case other code imports it;
# remove if nothing does.
model = AutoModelForCausalLM.from_pretrained('gpt2')

# Reactive values bound to the three text inputs rendered by Page().
text = solara.reactive("Example text is here")
text2 = solara.reactive("")
text3 = solara.reactive("")

# Dataframe mapping every token ID to the text it decodes to.
# The ID range is derived from the tokenizer (len() includes added tokens)
# instead of hard-coding GPT-2's 50257, so the table stays complete if a
# different checkpoint is loaded above.
_vocab_ids = range(len(tokenizer))
df = pd.DataFrame({
    "token ID": _vocab_ids,
    "token": [tokenizer.decode([i]) for i in _vocab_ids],
})
def _token_color(position):
    """Return a 6-digit hex colour (no leading '#') for a token position.

    A private ``random.Random`` seeded with *position* gives the same colour
    sequence that ``random.seed(position)`` would, without resetting the
    process-wide RNG on every render.
    """
    rng = random.Random(position)
    return "".join(rng.choice("0123456789ABCDEF") for _ in range(6))


def _render_token_spans(token_ids):
    """Build the two HTML strings displayed for an encoded text.

    Args:
        token_ids: token IDs as returned by ``tokenizer.encode``.

    Returns:
        ``(id_html, chip_html)``. ``id_html`` lists each ID in its colour.
        ``chip_html`` shows each token's decoded text on a coloured background
        with its ID as a small label underneath. Colours depend only on the
        token's position in the sequence.
    """
    id_parts = []
    chip_parts = []
    for position, token_id in enumerate(token_ids):
        color = _token_color(position)
        # Escape the decoded text so tokens like "<" or "&" are shown literally
        # instead of being interpreted as HTML by the Markdown renderer.
        token_text = html.escape(tokenizer.decode([token_id]))
        id_parts.append(
            f" <span style='font-family: cursive;color: #{color}'>{token_id}</span>"
        )
        chip_parts.append(
            ' <span style="padding: 5px; border-right: 3px solid white; '
            "line-height: 3em; font-family: courier; "
            f'background-color: #{color}; color: white; position: relative;">'
            '<span style="position: absolute; top: 5.5ch; line-height: 1em; '
            f'left: -0.5px; font-size: 0.45em">{token_id}</span>{token_text}</span>'
        )
    return "".join(id_parts), "".join(chip_parts)


def _parse_token_ids(raw, vocab_size):
    """Parse whitespace-separated token IDs typed by the user.

    Args:
        raw: the text typed by the user.
        vocab_size: number of valid IDs; valid IDs are in ``[0, vocab_size)``.

    Returns:
        ``(ids, invalid)``. ``ids`` holds the valid IDs in input order.
        ``invalid`` holds the pieces that are not integers or are out of range.
        Reporting bad pieces instead of raising keeps a half-typed value (the
        input updates on every keystroke) from breaking the whole page.
    """
    ids = []
    invalid = []
    for piece in raw.split():
        try:
            token_id = int(piece)
        except ValueError:
            invalid.append(piece)
            continue
        if 0 <= token_id < vocab_size:
            ids.append(token_id)
        else:
            invalid.append(piece)
    return ids, invalid


@solara.component
def Page():
    """Render the tokenizer playground.

    Three sections, each driven by a module-level reactive value:

    * ``text``: encoded with ``tokenizer`` and shown as coloured IDs, a token
      count, and coloured token chips.
    * ``text2``: whitespace-separated token IDs decoded back to text.
    * ``text3``: a prefix used to filter ``df`` (one row per vocabulary entry).
    """
    with solara.Column(margin=10):
        # A space after '#' is required for a heading under CommonMark.
        solara.Markdown("# GPT token encoder and decoder")
        solara.InputText("Enter text to tokenize it:", value=text, continuous_update=True)
        token_ids = tokenizer.encode(text.value)
        id_html, chip_html = _render_token_spans(token_ids)
        solara.Markdown(id_html)
        n_tokens = len(token_ids)
        solara.Markdown(f"{n_tokens} token" if n_tokens == 1 else f"{n_tokens} tokens")
        solara.Markdown(chip_html)

        solara.InputText("Or convert space separated tokens to text:", value=text2, continuous_update=True)
        vocab_size = len(tokenizer)
        ids, invalid = _parse_token_ids(text2.value, vocab_size)
        if invalid:
            solara.Error(
                f"Invalid token ID(s): {', '.join(invalid)}. "
                f"Enter integers from 0 to {vocab_size - 1}, separated by spaces."
            )
        else:
            solara.Markdown(tokenizer.decode(ids))

        solara.Markdown("## Search tokens")
        solara.InputText("Search for a token:", value=text3, continuous_update=True)
        df_subset = df[df["token"].str.startswith(text3.value)]
        solara.Markdown(f"{df_subset.shape[0]:,} results")
        solara.DataFrame(df_subset, items_per_page=10)


# Module-level call kept from the original; presumably it lets the page
# display when run as a notebook/Space cell. Confirm before removing.
Page()