|
import tensorflow as tf
|
|
|
|
def sdp_attention(query, key, value, mask):
|
|
matmul_qk = tf.matmul(query, key, transpose_b=True)
|
|
depth = tf.cast(tf.shape(key)[-1], tf.float32)
|
|
logits = matmul_qk / tf.math.sqrt(depth)
|
|
if mask is not None:
|
|
logits += mask * -1e9
|
|
attention_weights = tf.nn.softmax(logits, axis=-1)
|
|
output = tf.matmul(attention_weights, value)
|
|
return output
|
|
|
|
class MultiHeadAttention(tf.keras.layers.Layer):
|
|
def __init__(self, num_heads, d_model, **kwargs):
|
|
assert d_model % num_heads == 0
|
|
super(MultiHeadAttention, self).__init__(**kwargs)
|
|
self.num_heads = num_heads
|
|
self.d_model = d_model
|
|
|
|
self.depth = self.d_model // self.num_heads
|
|
|
|
self.query_dense = tf.keras.layers.Dense(self.d_model)
|
|
self.key_dense = tf.keras.layers.Dense(self.d_model)
|
|
self.value_dense = tf.keras.layers.Dense(self.d_model)
|
|
|
|
self.dense = tf.keras.layers.Dense(self.d_model)
|
|
|
|
|
|
def get_config(self):
|
|
config = super(MultiHeadAttention, self).get_config()
|
|
config.update({"num_heads": self.num_heads, "d_model": self.d_model})
|
|
return config
|
|
|
|
def split_heads(self, inputs: tf.Tensor, batch_size: int):
|
|
inputs = tf.keras.layers.Lambda(
|
|
lambda inputs: tf.reshape(
|
|
inputs, shape=(batch_size, -1, self.num_heads, self.depth))
|
|
)(inputs)
|
|
return tf.keras.layers.Lambda(
|
|
lambda inputs: tf.transpose(inputs, perm=[0, 2, 1, 3])
|
|
)(inputs)
|
|
|
|
def call(self, inputs: tf.Tensor):
|
|
query, key, value, mask = (
|
|
inputs["query"],
|
|
inputs["key"],
|
|
inputs["value"],
|
|
inputs["mask"],
|
|
)
|
|
batch_size = tf.shape(query)[0]
|
|
|
|
query = self.query_dense(query)
|
|
key = self.key_dense(key)
|
|
value = self.value_dense(value)
|
|
|
|
query = self.split_heads(query, batch_size)
|
|
key = self.split_heads(key, batch_size)
|
|
value = self.split_heads(value, batch_size)
|
|
|
|
scaled_attention = sdp_attention(query, key, value, mask)
|
|
scaled_attention = tf.keras.layers.Lambda(
|
|
lambda scaled_attention: tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
|
|
)(scaled_attention)
|
|
|
|
concat_attention = tf.keras.layers.Lambda(
|
|
lambda scaled_attention: tf.reshape(
|
|
scaled_attention, (batch_size, -1, self.d_model)
|
|
)
|
|
)(scaled_attention)
|
|
|
|
outputs = self.dense(concat_attention)
|
|
|
|
return outputs
|
|
|
|
|
|
def create_padding_mask(x):
|
|
mask = tf.cast(tf.math.equal(x, 0), dtype=tf.float32)
|
|
return mask[:, tf.newaxis, tf.newaxis, :]
|
|
|
|
|
|
def create_look_ahead_mask(x):
|
|
seq_len = tf.shape(x)[1]
|
|
look_ahead_mask = 1 - tf.linalg.band_part(
|
|
tf.ones((seq_len, seq_len), dtype=tf.float32), -1, 0
|
|
)
|
|
padding_mask = create_padding_mask(x)
|
|
return tf.maximum(look_ahead_mask, padding_mask)
|
|
|
|
|
|
class PositionalEncoding(tf.keras.layers.Layer):
|
|
def __init__(self, position: int, d_model: int, **kwargs):
|
|
super(PositionalEncoding, self).__init__(**kwargs)
|
|
self.position = position
|
|
self.d_model = d_model
|
|
self.pos_encoding = self.positional_encoding(position, d_model)
|
|
|
|
def get_config(self):
|
|
config = super(PositionalEncoding, self).get_config()
|
|
config.update({"position": self.position, "d_model": self.d_model})
|
|
return config
|
|
|
|
def get_angles(self, position: tf.Tensor, i: tf.Tensor, d_model: tf.Tensor):
|
|
angles = 1 / tf.pow(10000, (2 * (i // 2)) / d_model)
|
|
return position * angles
|
|
|
|
def positional_encoding(self, position: int, d_model: int):
|
|
angle_rads = self.get_angles(
|
|
position=tf.cast(tf.range(position)[:, tf.newaxis], dtype=tf.float32),
|
|
i=tf.cast(tf.range(d_model)[tf.newaxis, :], dtype=tf.float32),
|
|
d_model=tf.cast(d_model, dtype=tf.float32),
|
|
)
|
|
sines = tf.math.sin(angle_rads[:, 0::2])
|
|
cosines = tf.math.cos(angle_rads[:, 1::2])
|
|
|
|
pos_encoding = tf.concat([sines, cosines], axis=-1)
|
|
pos_encoding = pos_encoding[tf.newaxis, ...]
|
|
return pos_encoding
|
|
|
|
def call(self, inputs: tf.Tensor):
|
|
return inputs + self.pos_encoding[:, : tf.shape(inputs)[1], :]
|
|
|
|
def encoder_layer(hparams, name: str = "encoder_layer"):
|
|
inputs = tf.keras.Input(shape=(None, hparams.d_model), name="inputs")
|
|
padding_mask = tf.keras.Input(shape=(1, 1, None), name="padding_mask")
|
|
|
|
attention = MultiHeadAttention(
|
|
num_heads=hparams.num_heads, d_model=hparams.d_model, name="attention"
|
|
)({"query": inputs, "key": inputs, "value": inputs, "mask": padding_mask})
|
|
attention = tf.keras.layers.Dropout(hparams.dropout)(attention)
|
|
attention += tf.cast(inputs, dtype=tf.float32)
|
|
attention = tf.keras.layers.LayerNormalization(epsilon=1e-6)(attention)
|
|
|
|
outputs = tf.keras.layers.Dense(hparams.num_units, activation=hparams.activation)(
|
|
attention
|
|
)
|
|
outputs = tf.keras.layers.Dense(hparams.d_model)(outputs)
|
|
outputs = tf.keras.layers.Dropout(hparams.dropout)(outputs)
|
|
outputs += attention
|
|
outputs = tf.keras.layers.LayerNormalization(epsilon=1e-6)(outputs)
|
|
|
|
return tf.keras.Model(inputs=[inputs, padding_mask], outputs=outputs, name=name)
|
|
|
|
def encoder(hparams, name: str = "encoder"):
|
|
inputs = tf.keras.Input(shape=(None,), name="inputs")
|
|
padding_mask = tf.keras.Input(shape=(1, 1, None), name="padding_mask")
|
|
|
|
embeddings = tf.keras.layers.Embedding(hparams.vocab_size, hparams.d_model)(inputs)
|
|
embeddings *= tf.math.sqrt(tf.cast(hparams.d_model, dtype=tf.float32))
|
|
embeddings = PositionalEncoding(
|
|
position=hparams.vocab_size, d_model=hparams.d_model
|
|
)(embeddings)
|
|
|
|
outputs = tf.keras.layers.Dropout(hparams.dropout)(embeddings)
|
|
|
|
for i in range(hparams.num_layers):
|
|
outputs = encoder_layer(hparams, name=f"encoder_layer_{i}")(
|
|
[outputs, padding_mask]
|
|
)
|
|
|
|
return tf.keras.Model(inputs=[inputs, padding_mask], outputs=outputs, name=name)
|
|
|
|
def decoder_layer(hparams, name: str = "decoder_layer"):
|
|
inputs = tf.keras.Input(shape=(None, hparams.d_model), name="inputs")
|
|
enc_outputs = tf.keras.Input(shape=(None, hparams.d_model), name="encoder_outputs")
|
|
look_ahead_mask = tf.keras.Input(shape=(1, None, None), name="look_ahead_mask")
|
|
padding_mask = tf.keras.Input(shape=(1, 1, None), name="padding_mask")
|
|
|
|
attention1 = MultiHeadAttention(
|
|
num_heads=hparams.num_heads, d_model=hparams.d_model, name="attention_1"
|
|
)(
|
|
inputs={
|
|
"query": inputs,
|
|
"key": inputs,
|
|
"value": inputs,
|
|
"mask": look_ahead_mask,
|
|
}
|
|
)
|
|
attention1 += tf.cast(inputs, dtype=tf.float32)
|
|
attention1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(attention1)
|
|
|
|
attention2 = MultiHeadAttention(
|
|
num_heads=hparams.num_heads, d_model=hparams.d_model, name="attention_2"
|
|
)(
|
|
inputs={
|
|
"query": attention1,
|
|
"key": enc_outputs,
|
|
"value": enc_outputs,
|
|
"mask": padding_mask,
|
|
}
|
|
)
|
|
attention2 = tf.keras.layers.Dropout(hparams.dropout)(attention2)
|
|
attention2 += attention1
|
|
attention2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(
|
|
attention2 + attention1
|
|
)
|
|
|
|
outputs = tf.keras.layers.Dense(hparams.num_units, activation=hparams.activation)(
|
|
attention2
|
|
)
|
|
outputs = tf.keras.layers.Dense(hparams.d_model)(outputs)
|
|
outputs = tf.keras.layers.Dropout(hparams.dropout)(outputs)
|
|
outputs += attention2
|
|
outputs = tf.keras.layers.LayerNormalization(epsilon=1e-6)(outputs)
|
|
|
|
return tf.keras.Model(
|
|
inputs=[inputs, enc_outputs, look_ahead_mask, padding_mask],
|
|
outputs=outputs,
|
|
name=name,
|
|
)
|
|
|
|
def decoder(hparams, name: str = "decoder"):
|
|
inputs = tf.keras.Input(shape=(None,), name="inputs")
|
|
enc_outputs = tf.keras.Input(shape=(None, hparams.d_model), name="encoder_outputs")
|
|
look_ahead_mask = tf.keras.Input(shape=(1, None, None), name="look_ahead_mask")
|
|
padding_mask = tf.keras.Input(shape=(1, 1, None), name="padding_mask")
|
|
|
|
embeddings = tf.keras.layers.Embedding(hparams.vocab_size, hparams.d_model)(inputs)
|
|
embeddings *= tf.math.sqrt(tf.cast(hparams.d_model, dtype=tf.float32))
|
|
embeddings = PositionalEncoding(
|
|
position=hparams.vocab_size, d_model=hparams.d_model
|
|
)(embeddings)
|
|
|
|
outputs = tf.keras.layers.Dropout(hparams.dropout)(embeddings)
|
|
|
|
for i in range(hparams.num_layers):
|
|
outputs = decoder_layer(
|
|
hparams,
|
|
name="decoder_layer_{}".format(i),
|
|
)(inputs=[outputs, enc_outputs, look_ahead_mask, padding_mask])
|
|
|
|
return tf.keras.Model(
|
|
inputs=[inputs, enc_outputs, look_ahead_mask, padding_mask],
|
|
outputs=outputs,
|
|
name=name,
|
|
)
|
|
|
|
|
|
|
|
def transformer(hparams, name: str = "transformer"):
|
|
inputs = tf.keras.Input(shape=(None,), name="inputs")
|
|
dec_inputs = tf.keras.Input(shape=(None,), name="dec_inputs")
|
|
|
|
enc_padding_mask = tf.keras.layers.Lambda(
|
|
create_padding_mask, output_shape=(1, 1, None), name="enc_padding_mask"
|
|
)(inputs)
|
|
|
|
look_ahead_mask = tf.keras.layers.Lambda(
|
|
create_look_ahead_mask, output_shape=(1, None, None), name="look_ahead_mask"
|
|
)(dec_inputs)
|
|
|
|
dec_padding_mask = tf.keras.layers.Lambda(
|
|
create_padding_mask, output_shape=(1, 1, None), name="dec_padding_mask"
|
|
)(inputs)
|
|
|
|
enc_outputs = encoder(hparams)(inputs=[inputs, enc_padding_mask])
|
|
|
|
dec_outputs = decoder(hparams)(
|
|
inputs=[dec_inputs, enc_outputs, look_ahead_mask, dec_padding_mask]
|
|
)
|
|
|
|
outputs = tf.keras.layers.Dense(hparams.vocab_size, name="outputs")(dec_outputs)
|
|
|
|
return tf.keras.Model(inputs=[inputs, dec_inputs], outputs=outputs, name=name)
|
|
|
|
|
|
|