import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
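
# Load the corpus and build the character-level vocabulary: char2idx maps
# each character to an integer id, and idx2char maps ids back to characters.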
text = open('shakespeare.txt', 'r').read()
vocab = sorted(set(text))
char2idx = {c: i for i, c in enumerate(vocab)}
idx2char = np.array(vocab)
text_as_int = np.array([char2idx[c] for c in text])
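
# Slice the encoded text into chunks of seq_length + 1 characters; each chunk
# later becomes an (input, target) pair shifted by one character.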
seq_length = 100
examples_per_epoch = len(text) // (seq_length + 1)
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)
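
# For "Hello", the input is "Hell" and the target is "ello": at every
# position the model is trained to predict the next character.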
def split_input_target(chunk):
    input_text = chunk[:-1]
    target_text = chunk[1:]
    return input_text, target_text

dataset = sequences.map(split_input_target)
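
# Shuffle with a bounded buffer (tf.data streams rather than shuffling the
# whole corpus in memory) and batch; drop_remainder keeps a fixed batch
# shape, which the stateful GRU below requires.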
BATCH_SIZE = 64
BUFFER_SIZE = 10000
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
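
# Model hyperparameters: embedding width and GRU state size.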
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024

def build_model(batch_size):
    # stateful=True fixes the batch dimension, so the model is built once
    # with the training batch size and rebuilt with batch size 1 for
    # generation; the original code passed the batch-64 training model to
    # generate_text, which fails on single-sequence input.
    return keras.Sequential([
        layers.Embedding(vocab_size, embedding_dim,
                         batch_input_shape=[batch_size, None]),
        layers.GRU(rnn_units, return_sequences=True, stateful=True),
        layers.Dense(vocab_size)
    ])

model = build_model(BATCH_SIZE)
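
# Labels are integer character ids and the Dense layer outputs raw logits,
# hence sparse categorical cross-entropy with from_logits=True.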
def loss(labels, logits):
    return keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

model.compile(optimizer='adam', loss=loss)
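
# Autoregressive sampling: feed a seed string, then repeatedly sample the
# next character and feed it back in. The stateful GRU carries its hidden
# state across calls, so only the newest character is passed each step.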
def generate_text(model, start_string):
    num_generate = 50
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)
    text_generated = []

    # Lower temperature -> more predictable text; higher -> more surprising.
    temperature = 1.0

    model.reset_states()

    for _ in range(num_generate):
        predictions = model(input_eval)
        # Drop the batch dimension: [1, steps, vocab] -> [steps, vocab].
        predictions = tf.squeeze(predictions, 0)

        # Sample from the scaled logits of the last time step instead of
        # taking an argmax, which tends to get stuck in repetitive loops.
        predictions = predictions / temperature
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx2char[predicted_id])

    return start_string + ''.join(text_generated)
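
# Train, then rebuild the network with batch size 1 and copy the trained
# weights into it so generate_text can run on a single sequence.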
EPOCHS = 1
model.fit(dataset, epochs=EPOCHS)

gen_model = build_model(batch_size=1)
gen_model.set_weights(model.get_weights())

start_string = 'ROMEO: '
print(generate_text(gen_model, start_string))