import tensorflow as tf

from .feature import ReduceSize


@tf.keras.utils.register_keras_serializable(package="gcvit")
class Stem(tf.keras.layers.Layer):
    """Stem layer: zero-pad, strided 3x3 convolution, then a ReduceSize block.

    Args:
        dim: Number of output filters of the 3x3 projection convolution.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        # Only `dim` is stored here; the sublayers are created lazily in `build`.
        self.dim = dim

    def build(self, input_shape):
        """Create the sublayers, assigned in the order pad, proj, conv_down."""
        keras_layers = tf.keras.layers
        # One pixel of zero padding on each spatial side, ahead of the
        # stride-2 convolution (Conv2D defaults to "valid" padding).
        self.pad = keras_layers.ZeroPadding2D(padding=1, name='pad')
        self.proj = keras_layers.Conv2D(
            filters=self.dim,
            kernel_size=3,
            strides=2,
            name='proj',
        )
        # NOTE(review): ReduceSize lives in .feature; keep_dim=True presumably
        # keeps the channel count unchanged — confirm against its implementation.
        self.conv_down = ReduceSize(keep_dim=True, name='conv_down')
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        """Run ``inputs`` through pad -> proj -> conv_down; extra kwargs are unused."""
        padded = self.pad(inputs)
        projected = self.proj(padded)
        return self.conv_down(projected)

    def get_config(self):
        """Return the base Layer config extended with the ``dim`` argument."""
        return {**super().get_config(), 'dim': self.dim}