import tensorflow as tf


class ReflectionPad1d(tf.keras.layers.Layer):
    """Reflect-pad axis 1 of a rank-4 tensor by the same amount on both sides."""

    def __init__(self, padding):
        """Store the pad width.

        Args:
            padding: number of elements added before and after axis 1.
        """
        super().__init__()
        self.padding = padding

    def call(self, x):
        """Return ``x`` reflect-padded along axis 1; other axes are untouched.

        The padding spec has four entries, so ``x`` must be rank 4.
        NOTE(review): presumably laid out as (batch, time, 1, channels) --
        confirm against the caller.
        """
        pad_width = self.padding
        paddings = [
            [0, 0],
            [pad_width, pad_width],
            [0, 0],
            [0, 0],
        ]
        return tf.pad(x, paddings, mode="REFLECT")


class ResidualStack(tf.keras.layers.Layer):
    """Stack of dilated-convolution residual blocks with 1x1 conv shortcuts.

    Block ``idx`` applies, in order: LeakyReLU(0.2), reflection padding,
    a ``(kernel_size, 1)`` convolution dilated by ``kernel_size ** idx`` along
    axis 1, LeakyReLU(0.2), and a ``(1, 1)`` convolution. The block output is
    added to a 1x1 convolution of the block input.
    """

    def __init__(self, channels, num_res_blocks, kernel_size, name):
        """Create the residual blocks and their shortcut convolutions.

        Args:
            channels: number of filters of every convolution in the stack.
            num_res_blocks: number of residual blocks (and shortcuts).
            kernel_size: size of the dilated kernel along axis 1; must be odd.
            name: Keras layer name forwarded to the base class.

        Raises:
            AssertionError: if ``kernel_size`` is even.
        """
        super().__init__(name=name)

        assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
        half_kernel = (kernel_size - 1) // 2

        # Blocks are created before the shortcuts so that layer creation
        # order (and therefore Keras auto-naming / weight order) is stable.
        self.blocks = [
            self._make_block(channels, kernel_size, half_kernel, block_idx)
            for block_idx in range(num_res_blocks)
        ]
        self.shortcuts = [
            tf.keras.layers.Conv2D(filters=channels,
                                   kernel_size=1,
                                   use_bias=True,
                                   name=f'shortcuts.{i}')
            for i in range(num_res_blocks)
        ]

    @staticmethod
    def _make_block(channels, kernel_size, half_kernel, idx):
        """Build the ordered list of layers forming residual block ``idx``."""
        # Index used in the name of the first convolution of a block; the
        # second convolution is named with this index plus two.
        num_layers = 2
        dilation = kernel_size**idx
        # A 'valid' dilated conv shrinks axis 1 by (kernel_size - 1) * dilation;
        # padding half of that on each side keeps the length unchanged.
        pad_width = half_kernel * dilation
        return [
            tf.keras.layers.LeakyReLU(0.2),
            ReflectionPad1d(pad_width),
            tf.keras.layers.Conv2D(filters=channels,
                                   kernel_size=(kernel_size, 1),
                                   dilation_rate=(dilation, 1),
                                   use_bias=True,
                                   padding='valid',
                                   name=f'blocks.{idx}.{num_layers}'),
            tf.keras.layers.LeakyReLU(0.2),
            tf.keras.layers.Conv2D(filters=channels,
                                   kernel_size=(1, 1),
                                   use_bias=True,
                                   name=f'blocks.{idx}.{num_layers + 2}'),
        ]

    def call(self, x):
        """Run ``x`` through every residual block in order and return the result."""
        out = x
        for layers, shortcut in zip(self.blocks, self.shortcuts):
            # The shortcut branch sees the block input, not the activated one.
            skip = shortcut(out)
            for layer in layers:
                out = layer(out)
            out = out + skip
        return out