# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""VisionTransformer models."""

import math
from typing import Optional, Tuple

from absl import logging
import tensorflow as tf, tf_keras

from official.modeling import activations
from official.vision.modeling.backbones import factory
from official.vision.modeling.backbones.vit_specs import VIT_SPECS
from official.vision.modeling.layers import nn_blocks
from official.vision.modeling.layers import nn_layers

layers = tf_keras.layers


class AddPositionEmbs(layers.Layer):
  """Adds (optionally learned) positional embeddings to the inputs."""

  def __init__(self,
               posemb_init: Optional[tf_keras.initializers.Initializer] = None,
               posemb_origin_shape: Optional[Tuple[int, int]] = None,
               posemb_target_shape: Optional[Tuple[int, int]] = None,
               **kwargs):
    """Constructs the positional embedding module.

    The logic of this module is: the learnable positional embedding length is
    determined by inputs_shape or posemb_origin_shape (if provided) during
    construction. If posemb_target_shape is provided and differs from the
    positional embedding length, the embeddings are interpolated during the
    forward call.

    Args:
      posemb_init: The positional embedding initializer.
      posemb_origin_shape: The intended positional embedding shape.
      posemb_target_shape: The potential target shape the positional embedding
        may be interpolated to.
      **kwargs: other args.
    """
    super().__init__(**kwargs)
    self.posemb_init = posemb_init
    self.posemb_origin_shape = posemb_origin_shape
    self.posemb_target_shape = posemb_target_shape

  def build(self, inputs_shape):
    if self.posemb_origin_shape is not None:
      pos_emb_length = (
          self.posemb_origin_shape[0] * self.posemb_origin_shape[1])
    else:
      pos_emb_length = inputs_shape[1]
    pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
    self.pos_embedding = self.add_weight(
        'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

  def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int],
                   to_shape: Tuple[int, int]) -> tf.Tensor:
    """Interpolates the positional embeddings."""
    logging.info('Interpolating positional embedding from shape %s to %s',
                 from_shape, to_shape)
    grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
    # NOTE: Using BILINEAR interpolation by default.
    grid_emb = tf.image.resize(grid_emb, to_shape)
    return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])

  def call(self, inputs, inputs_positions=None):
    del inputs_positions
    pos_embedding = self.pos_embedding
    # inputs.shape is (batch_size, seq_len, emb_dim).
    if inputs.shape[1] != pos_embedding.shape[1]:
      pos_embedding = self._interpolate(
          pos_embedding,
          from_shape=self.posemb_origin_shape,
          to_shape=self.posemb_target_shape)
    pos_embedding = tf.cast(pos_embedding, inputs.dtype)

    return inputs + pos_embedding
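

# Example (illustrative sketch, not part of the original module): with a
# learned 14x14 grid (e.g. a 224x224 input at patch size 16) and a 16x16
# target grid, the stored embeddings are bilinearly resized on the forward
# pass because the input sequence length (256) differs from the stored
# length (196). All shapes below are made up for illustration.
#
#   pos_embed = AddPositionEmbs(
#       posemb_init=tf_keras.initializers.RandomNormal(stddev=0.02),
#       posemb_origin_shape=(14, 14),
#       posemb_target_shape=(16, 16))
#   tokens = tf.zeros([2, 16 * 16, 768])  # (batch_size, seq_len, emb_dim)
#   outputs = pos_embed(tokens)           # shape: (2, 256, 768)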


class TokenLayer(layers.Layer):
  """A simple layer to wrap token parameters."""

  def build(self, inputs_shape):
    self.cls = self.add_weight(
        'cls', (1, 1, inputs_shape[-1]), initializer='zeros')

  def call(self, inputs):
    cls = tf.cast(self.cls, inputs.dtype)
    cls = cls + tf.zeros_like(inputs[:, 0:1])  # A hacky way to tile.
    x = tf.concat([cls, inputs], axis=1)
    return x
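

# Example (illustrative sketch, not part of the original module): TokenLayer
# prepends a single learnable classification token, so a sequence of N patch
# tokens becomes N + 1 tokens.
#
#   patch_tokens = tf.zeros([2, 196, 768])
#   with_cls = TokenLayer(name='cls')(patch_tokens)  # shape: (2, 197, 768)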


class Encoder(layers.Layer):
  """Transformer Encoder."""

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               init_stochastic_depth_rate=0.0,
               kernel_initializer='glorot_uniform',
               add_pos_embed=True,
               pos_embed_origin_shape=None,
               pos_embed_target_shape=None,
               layer_scale_init_value=0.0,
               transformer_partition_dims=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._kernel_initializer = kernel_initializer
    self._add_pos_embed = add_pos_embed
    self._pos_embed_origin_shape = pos_embed_origin_shape
    self._pos_embed_target_shape = pos_embed_target_shape
    self._layer_scale_init_value = layer_scale_init_value
    self._transformer_partition_dims = transformer_partition_dims

  def build(self, input_shape):
    if self._add_pos_embed:
      self._pos_embed = AddPositionEmbs(
          posemb_init=tf_keras.initializers.RandomNormal(stddev=0.02),
          posemb_origin_shape=self._pos_embed_origin_shape,
          posemb_target_shape=self._pos_embed_target_shape,
          name='posembed_input')
    self._dropout = layers.Dropout(rate=self._dropout_rate)

    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
    for i in range(self._num_layers):
      encoder_layer = nn_blocks.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 1, self._num_layers),
          norm_epsilon=1e-6,
          layer_scale_init_value=self._layer_scale_init_value,
          transformer_partition_dims=self._transformer_partition_dims)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = inputs
    if self._add_pos_embed:
      x = self._pos_embed(x, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)

    for encoder_layer in self._encoder_layers:
      x = encoder_layer(x, training=training)
    x = self._norm(x)
    return x

  def get_config(self):
    config = super().get_config()
    updates = {
        'num_layers': self._num_layers,
        'mlp_dim': self._mlp_dim,
        'num_heads': self._num_heads,
        'dropout_rate': self._dropout_rate,
        'attention_dropout_rate': self._attention_dropout_rate,
        'kernel_regularizer': self._kernel_regularizer,
        'inputs_positions': self._inputs_positions,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'kernel_initializer': self._kernel_initializer,
        'add_pos_embed': self._add_pos_embed,
        'pos_embed_origin_shape': self._pos_embed_origin_shape,
        'pos_embed_target_shape': self._pos_embed_target_shape,
        'layer_scale_init_value': self._layer_scale_init_value,
        'transformer_partition_dims': self._transformer_partition_dims,
    }
    config.update(updates)
    return config
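

# Example (illustrative sketch, not part of the original module): running a
# small encoder over a batch of patch tokens. The hyperparameters below are
# arbitrary and chosen only to keep the example light; positional embeddings
# are added with length inferred from the input sequence.
#
#   encoder = Encoder(num_layers=2, mlp_dim=512, num_heads=4)
#   tokens = tf.zeros([2, 196, 128])
#   encoded = encoder(tokens, training=False)  # shape: (2, 196, 128)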


class VisionTransformer(tf_keras.Model):
  """Class to build VisionTransformer family model."""

  def __init__(
      self,
      mlp_dim=3072,
      num_heads=12,
      num_layers=12,
      attention_dropout_rate=0.0,
      dropout_rate=0.1,
      init_stochastic_depth_rate=0.0,
      input_specs=layers.InputSpec(shape=[None, None, None, 3]),
      patch_size=16,
      hidden_size=768,
      representation_size=0,
      pooler='token',
      kernel_regularizer=None,
      original_init: bool = True,
      output_encoded_tokens: bool = True,
      output_2d_feature_maps: bool = False,
      pos_embed_shape: Optional[Tuple[int, int]] = None,
      layer_scale_init_value: float = 0.0,
      transformer_partition_dims: Optional[Tuple[int, int, int, int]] = None,
  ):
    """VisionTransformer initialization function."""
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._patch_size = patch_size

    inputs = tf_keras.Input(shape=input_specs.shape[1:])

    x = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
            inputs)
    if tf_keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
    else:
      rows_axis, cols_axis = (2, 3)
      # The reshape below assumes the data_format is 'channels_last', so
      # transpose to that. Once the data is flattened by the reshape, the
      # data_format is irrelevant, so no need to update
      # tf_keras.backend.image_data_format.
      x = tf.transpose(x, perm=[0, 2, 3, 1])

    pos_embed_target_shape = (x.shape[rows_axis], x.shape[cols_axis])
    feat_h = input_specs.shape[rows_axis] // patch_size
    feat_w = input_specs.shape[cols_axis] // patch_size
    seq_len = feat_h * feat_w
    x = tf.reshape(x, [-1, seq_len, hidden_size])

    # If we want to add a class token, add it here.
    if pooler == 'token':
      x = TokenLayer(name='cls')(x)

    x = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate,
        pos_embed_origin_shape=pos_embed_shape,
        pos_embed_target_shape=pos_embed_target_shape,
        layer_scale_init_value=layer_scale_init_value)(
            x)

    if pooler == 'token':
      output_feature = x[:, 1:]
      x = x[:, 0]
    elif pooler == 'gap':
      output_feature = x
      x = tf.reduce_mean(x, axis=1)
    elif pooler == 'none':
      output_feature = x
      x = tf.identity(x, name='encoded_tokens')
    else:
      raise ValueError(f'unrecognized pooler type: {pooler}')

    endpoints = {}
    if output_2d_feature_maps:
      # Use the closest feature level.
      feat_level = round(math.log2(patch_size))
      logging.info(
          'VisionTransformer patch size %d and feature level: %d',
          patch_size,
          feat_level,
      )
      endpoints[str(feat_level)] = tf.reshape(
          output_feature, [-1, feat_h, feat_w, x.shape.as_list()[-1]])

    # Don't include `pre_logits` or `encoded_tokens` to support decoders.
    self._output_specs = {k: v.shape for k, v in endpoints.items()}

    if representation_size:
      x = layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal' if original_init else 'he_uniform',
      )(x)
      x = tf.nn.tanh(x)
    else:
      x = tf.identity(x, name='pre_logits')

    if pooler == 'none':
      if output_encoded_tokens:
        endpoints['encoded_tokens'] = x
    else:
      endpoints['pre_logits'] = tf.reshape(
          x, [-1, 1, 1, representation_size or hidden_size])

    super().__init__(inputs=inputs, outputs=endpoints)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
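

# Example (illustrative sketch, not part of the original module): with the
# default 'token' pooler and representation_size=0, the functional model
# exposes a single 'pre_logits' endpoint of shape [batch, 1, 1, hidden_size].
#
#   model = VisionTransformer(
#       input_specs=layers.InputSpec(shape=[None, 224, 224, 3]))
#   outputs = model(tf.zeros([1, 224, 224, 3]))
#   # outputs['pre_logits'].shape == (1, 1, 1, 768)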


@factory.register_backbone_builder('vit')
def build_vit(input_specs,
              backbone_config,
              norm_activation_config,
              l2_regularizer=None):
  """Builds a ViT model."""
  del norm_activation_config
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'vit', (f'Inconsistent backbone type '
                                  f'{backbone_type}')
  backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])

  logging.info(
      (
          'ViT specs: mlp_dim=%d, num_heads=%d, num_layers=%d, '
          'patch_size=%d, hidden_size=%d, representation_size=%d.'
      ),
      backbone_cfg.transformer.mlp_dim,
      backbone_cfg.transformer.num_heads,
      backbone_cfg.transformer.num_layers,
      backbone_cfg.patch_size,
      backbone_cfg.hidden_size,
      backbone_cfg.representation_size,
  )

  return VisionTransformer(
      mlp_dim=backbone_cfg.transformer.mlp_dim,
      num_heads=backbone_cfg.transformer.num_heads,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      pooler=backbone_cfg.pooler,
      kernel_regularizer=l2_regularizer,
      original_init=backbone_cfg.original_init,
      output_encoded_tokens=backbone_cfg.output_encoded_tokens,
      output_2d_feature_maps=backbone_cfg.output_2d_feature_maps,
      layer_scale_init_value=backbone_cfg.layer_scale_init_value,
      pos_embed_shape=backbone_cfg.pos_embed_shape,
      transformer_partition_dims=backbone_cfg.transformer_partition_dims)
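

# Example (illustrative sketch, not part of the original module): build_vit is
# normally reached through the backbone factory rather than called directly.
# The config classes referenced below (official.vision.configs.backbones
# Backbone and its `vit` field) and the spec name 'vit-b16' are assumptions
# about the surrounding config library, not definitions from this file.
#
#   from official.vision.configs import backbones as backbones_cfg
#
#   backbone_config = backbones_cfg.Backbone(
#       type='vit',
#       vit=backbones_cfg.VisionTransformer(model_name='vit-b16'))
#   model = build_vit(
#       input_specs=layers.InputSpec(shape=[None, 224, 224, 3]),
#       backbone_config=backbone_config,
#       norm_activation_config=None,
#       l2_regularizer=None)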