File size: 3,706 Bytes
2d5548b
4b03ca9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d5548b
4b03ca9
 
2d5548b
4b03ca9
2d5548b
4b03ca9
2d5548b
4b03ca9
2d5548b
4b03ca9
 
2d5548b
4b03ca9
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
# Import libraries
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Load the text data and vectorize it at the character level.
# A `with` block closes the file deterministically, and an explicit encoding
# avoids depending on the platform's locale default.
with open('shakespeare.txt', 'r', encoding='utf-8') as text_file:
  text = text_file.read() # Whole corpus as a single string
vocab = sorted(set(text)) # Unique characters, sorted so the id assignment is deterministic
char2idx = {c: i for i, c in enumerate(vocab)} # character -> integer id
idx2char = np.array(vocab) # integer id -> character (indexable by id)
text_as_int = np.array([char2idx[c] for c in text]) # Entire corpus encoded as integer ids

# Cut the integer-encoded corpus into fixed-length chunks for training.
seq_length = 100 # Characters of context per training example
# Each chunk holds seq_length + 1 ids: the input plus one extra id for the shifted target.
# NOTE(review): examples_per_epoch is not referenced in the visible code.
examples_per_epoch = len(text) // (seq_length + 1)
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int) # One dataset element per character id
sequences = char_dataset.batch(
  seq_length + 1,
  drop_remainder=True, # Drop the trailing partial chunk so every sequence has the same length
)

def split_input_target(chunk): # Define a function to split the input and target
  """Return ``(chunk[:-1], chunk[1:])``: an input sequence and its one-step-shifted target.

  For a chunk of n + 1 items, position i of the target is the item that
  follows position i of the input, i.e. the next character to predict.
  """
  return chunk[:-1], chunk[1:]

# Batching hyperparameters.
BATCH_SIZE = 64 # Sequences per batch; the model below is built with this exact batch size
BUFFER_SIZE = 10000 # Number of elements held in the shuffle buffer

# Turn each chunk into an (input, target) pair, then shuffle and group into batches.
# drop_remainder=True guarantees every batch has exactly BATCH_SIZE rows.
dataset = (
  sequences
  .map(split_input_target)
  .shuffle(BUFFER_SIZE)
  .batch(BATCH_SIZE, drop_remainder=True)
)

# Define the model
vocab_size = len(vocab) # One output logit per distinct character
embedding_dim = 256 # Width of each character embedding vector
rnn_units = 1024 # Size of the GRU hidden state

# Embedding -> GRU -> Dense: next-character logits at every time step.
# The batch dimension is pinned to BATCH_SIZE because a stateful GRU keeps one
# recurrent state per batch row between calls.
# NOTE(review): stateful=True carries state from one batch into the next, but
# the training dataset is shuffled, so consecutive batches are not
# continuations of each other -- confirm this is intended.
model = keras.Sequential()
model.add(layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[BATCH_SIZE, None]))
model.add(layers.GRU(rnn_units, return_sequences=True, stateful=True))
model.add(layers.Dense(vocab_size)) # No activation: outputs are raw logits

# Define the loss function
def loss(labels, logits):
  """Sparse categorical cross-entropy computed directly on raw logits.

  ``labels`` are expected to be integer class (character) ids. ``logits`` are
  unnormalized scores (the model's final Dense layer has no softmax), hence
  ``from_logits=True``.
  """
  per_position_loss = keras.losses.sparse_categorical_crossentropy(
    y_true=labels, y_pred=logits, from_logits=True)
  return per_position_loss

# Compile the model: Adam optimizer, minimizing the logit-based loss defined above.
model.compile(loss=loss, optimizer='adam')

# Define a function to generate text
def generate_text(model, start_string, num_generate=50, temperature=1.0):
  """Continue ``start_string`` by sampling characters one at a time from ``model``.

  The training model has a fixed batch dimension of BATCH_SIZE and a stateful
  GRU, so it cannot be fed a single prompt directly: its recurrent state has
  BATCH_SIZE rows, and the output would not have a batch dimension of 1. A
  batch-size-1 copy of the same architecture is therefore built here and loaded
  with ``model``'s weights. The trained model and its recurrent state are left
  untouched.

  Args:
    model: The trained Keras model. Its weights must match an
      Embedding(vocab_size, embedding_dim) -> GRU(rnn_units) -> Dense(vocab_size)
      stack, i.e. the module-level architecture.
    start_string: Non-empty prompt; every character must be a key of ``char2idx``.
    num_generate: Number of characters to sample after the prompt.
    temperature: Logits are divided by this before sampling. Values below 1
      sharpen the distribution (more predictable text), values above 1 flatten
      it (more random text). Must be positive.

  Returns:
    ``start_string`` followed by the ``num_generate`` sampled characters.

  Raises:
    ValueError: If ``start_string`` is empty or ``temperature`` is not positive.
    KeyError: If ``start_string`` contains a character missing from ``char2idx``.
  """
  if not start_string:
    raise ValueError('start_string must contain at least one character')
  if temperature <= 0:
    raise ValueError('temperature must be positive')

  # Same architecture as the training model, but with a batch dimension of 1.
  sampler = keras.Sequential([
    layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[1, None]),
    layers.GRU(rnn_units, return_sequences=True, stateful=True),
    layers.Dense(vocab_size),
  ])
  sampler.set_weights(model.get_weights())
  sampler.reset_states() # Start generation from a zero recurrent state

  # Encode the prompt and add a leading batch dimension of size 1.
  input_eval = tf.expand_dims([char2idx[s] for s in start_string], 0)
  text_generated = []

  for _ in range(num_generate):
    predictions = sampler(input_eval) # Logits for every time step of the input
    predictions = tf.squeeze(predictions, 0) # Remove the batch dimension
    predictions = predictions / temperature
    # Sample one id from the logits of the last time step only.
    predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

    # The stateful GRU carries the earlier context, so only the new character is fed next.
    input_eval = tf.expand_dims([predicted_id], 0)
    text_generated.append(idx2char[predicted_id])

  return start_string + ''.join(text_generated)

# Train the model, printing a short generated sample after every epoch.
EPOCHS = 1 # Number of epochs to train

for epoch in range(EPOCHS):
  print('Epoch {}'.format(epoch + 1))
  # One epoch per fit() call so that a sample can be printed between epochs.
  model.fit(dataset, epochs=1)
  start_string = 'ROMEO: ' # Prompt that seeds the generated sample
  print(generate_text(model, start_string))