# Copyright (c) OpenMMLab. All rights reserved. # Modified from https://github.com/microsoft/GLIP/blob/main/maskrcnn_benchmark/utils/fuse_helper.py # noqa # and https://github.com/microsoft/GLIP/blob/main/maskrcnn_benchmark/modeling/rpn/modeling_bert.py # noqa import math from typing import Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from mmcv.cnn.bricks import DropPath from torch import Tensor try: from transformers import BertConfig, BertPreTrainedModel from transformers.modeling_utils import apply_chunking_to_forward from transformers.models.bert.modeling_bert import \ BertAttention as HFBertAttention from transformers.models.bert.modeling_bert import \ BertIntermediate as HFBertIntermediate from transformers.models.bert.modeling_bert import \ BertOutput as HFBertOutput except ImportError: BertConfig = None BertPreTrainedModel = object apply_chunking_to_forward = None HFBertAttention = object HFBertIntermediate = object HFBertOutput = object MAX_CLAMP_VALUE = 50000 def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor: """Permute and then flatten a tensor, from size (N, A, C, H, W) to (N, H * W * A, C). Args: layer (Tensor): Tensor of shape (N, C, H, W). N (int): Batch size. A (int): Number of attention heads. C (int): Number of channels. H (int): Height of feature map. W (int): Width of feature map. Returns: Tensor: A Tensor of shape (N, H * W * A, C). """ layer = layer.view(N, A, C, H, W) layer = layer.permute(0, 3, 4, 1, 2) layer = layer.reshape(N, -1, C) return layer def clamp_values(vector: Tensor) -> Tensor: """Clamp the values of a vector to the range [-MAX_CLAMP_VALUE, MAX_CLAMP_VALUE]. Args: vector (Tensor): Tensor of shape (N, C, H, W). Returns: Tensor: A Tensor of shape (N, C, H, W) with clamped values. """ vector = torch.clamp(vector, min=-MAX_CLAMP_VALUE, max=MAX_CLAMP_VALUE) return vector class BiMultiHeadAttention(nn.Module): """Bidirectional fusion Multi-Head Attention layer. Args: v_dim (int): The dimension of the vision input. l_dim (int): The dimension of the language input. embed_dim (int): The embedding dimension for the attention operation. num_heads (int): The number of attention heads. dropout (float, optional): The dropout probability. Defaults to 0.1. """ def __init__(self, v_dim: int, l_dim: int, embed_dim: int, num_heads: int, dropout: float = 0.1): super(BiMultiHeadAttention, self).__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.v_dim = v_dim self.l_dim = l_dim assert ( self.head_dim * self.num_heads == self.embed_dim ), 'embed_dim must be divisible by num_heads ' \ f'(got `embed_dim`: {self.embed_dim} ' \ f'and `num_heads`: {self.num_heads}).' self.scale = self.head_dim**(-0.5) self.dropout = dropout self.v_proj = nn.Linear(self.v_dim, self.embed_dim) self.l_proj = nn.Linear(self.l_dim, self.embed_dim) self.values_v_proj = nn.Linear(self.v_dim, self.embed_dim) self.values_l_proj = nn.Linear(self.l_dim, self.embed_dim) self.out_v_proj = nn.Linear(self.embed_dim, self.v_dim) self.out_l_proj = nn.Linear(self.embed_dim, self.l_dim) self.stable_softmax_2d = False self.clamp_min_for_underflow = True self.clamp_max_for_overflow = True self._reset_parameters() def _shape(self, tensor: Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def _reset_parameters(self): nn.init.xavier_uniform_(self.v_proj.weight) self.v_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.l_proj.weight) self.l_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.values_v_proj.weight) self.values_v_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.values_l_proj.weight) self.values_l_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.out_v_proj.weight) self.out_v_proj.bias.data.fill_(0) nn.init.xavier_uniform_(self.out_l_proj.weight) self.out_l_proj.bias.data.fill_(0) def forward( self, vision: Tensor, lang: Tensor, attention_mask_v: Optional[Tensor] = None, attention_mask_l: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: bsz, tgt_len, _ = vision.size() query_states = self.v_proj(vision) * self.scale key_states = self._shape(self.l_proj(lang), -1, bsz) value_v_states = self._shape(self.values_v_proj(vision), -1, bsz) value_l_states = self._shape(self.values_l_proj(lang), -1, bsz) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_v_states = value_v_states.view(*proj_shape) value_l_states = value_l_states.view(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f'Attention weights should be of ' f'size {(bsz * self.num_heads, tgt_len, src_len)}, ' f'but is {attn_weights.size()}') if self.stable_softmax_2d: attn_weights = attn_weights - attn_weights.max() if self.clamp_min_for_underflow: # Do not increase -50000, data type half has quite limited range attn_weights = torch.clamp(attn_weights, min=-MAX_CLAMP_VALUE) if self.clamp_max_for_overflow: # Do not increase 50000, data type half has quite limited range attn_weights = torch.clamp(attn_weights, max=MAX_CLAMP_VALUE) attn_weights_T = attn_weights.transpose(1, 2) attn_weights_l = ( attn_weights_T - torch.max(attn_weights_T, dim=-1, keepdim=True)[0]) if self.clamp_min_for_underflow: # Do not increase -50000, data type half has quite limited range attn_weights_l = torch.clamp(attn_weights_l, min=-MAX_CLAMP_VALUE) if self.clamp_max_for_overflow: # Do not increase 50000, data type half has quite limited range attn_weights_l = torch.clamp(attn_weights_l, max=MAX_CLAMP_VALUE) if attention_mask_v is not None: attention_mask_v = ( attention_mask_v[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)) attn_weights_l.masked_fill_(attention_mask_v, float('-inf')) attn_weights_l = attn_weights_l.softmax(dim=-1) if attention_mask_l is not None: assert (attention_mask_l.dim() == 2) attention_mask = attention_mask_l.unsqueeze(1).unsqueeze(1) attention_mask = attention_mask.expand(bsz, 1, tgt_len, src_len) attention_mask = attention_mask.masked_fill( attention_mask == 0, -9e15) if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError('Attention mask should be of ' f'size {(bsz, 1, tgt_len, src_len)}') attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights_v = nn.functional.softmax(attn_weights, dim=-1) attn_probs_v = F.dropout( attn_weights_v, p=self.dropout, training=self.training) attn_probs_l = F.dropout( attn_weights_l, p=self.dropout, training=self.training) attn_output_v = torch.bmm(attn_probs_v, value_l_states) attn_output_l = torch.bmm(attn_probs_l, value_v_states) if attn_output_v.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( '`attn_output_v` should be of ' f'size {(bsz, self.num_heads, tgt_len, self.head_dim)}, ' f'but is {attn_output_v.size()}') if attn_output_l.size() != (bsz * self.num_heads, src_len, self.head_dim): raise ValueError( '`attn_output_l` should be of size ' f'{(bsz, self.num_heads, src_len, self.head_dim)}, ' f'but is {attn_output_l.size()}') attn_output_v = attn_output_v.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output_v = attn_output_v.transpose(1, 2) attn_output_v = attn_output_v.reshape(bsz, tgt_len, self.embed_dim) attn_output_l = attn_output_l.view(bsz, self.num_heads, src_len, self.head_dim) attn_output_l = attn_output_l.transpose(1, 2) attn_output_l = attn_output_l.reshape(bsz, src_len, self.embed_dim) attn_output_v = self.out_v_proj(attn_output_v) attn_output_l = self.out_l_proj(attn_output_l) return attn_output_v, attn_output_l class BiAttentionBlock(nn.Module): """BiAttentionBlock Module: First, multi-level visual features are concat; Then the concat visual feature and lang feature are fused by attention; Finally the newly visual feature are split into multi levels. Args: v_dim (int): The dimension of the visual features. l_dim (int): The dimension of the language feature. embed_dim (int): The embedding dimension for the attention operation. num_heads (int): The number of attention heads. dropout (float, optional): The dropout probability. Defaults to 0.1. drop_path (float, optional): The drop path probability. Defaults to 0.0. init_values (float, optional): The initial value for the scaling parameter. Defaults to 1e-4. """ def __init__(self, v_dim: int, l_dim: int, embed_dim: int, num_heads: int, dropout: float = 0.1, drop_path: float = .0, init_values: float = 1e-4): super().__init__() # pre layer norm self.layer_norm_v = nn.LayerNorm(v_dim) self.layer_norm_l = nn.LayerNorm(l_dim) self.attn = BiMultiHeadAttention( v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout) # add layer scale for training stability self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() self.gamma_v = nn.Parameter( init_values * torch.ones(v_dim), requires_grad=True) self.gamma_l = nn.Parameter( init_values * torch.ones(l_dim), requires_grad=True) def forward(self, vf0: Tensor, vf1: Tensor, vf2: Tensor, vf3: Tensor, vf4: Tensor, lang_feature: Tensor, attention_mask_l=None): visual_features = [vf0, vf1, vf2, vf3, vf4] size_per_level, visual_features_flatten = [], [] for i, feat_per_level in enumerate(visual_features): bs, c, h, w = feat_per_level.shape size_per_level.append([h, w]) feat = permute_and_flatten(feat_per_level, bs, -1, c, h, w) visual_features_flatten.append(feat) visual_features_flatten = torch.cat(visual_features_flatten, dim=1) new_v, new_lang_feature = self.single_attention_call( visual_features_flatten, lang_feature, attention_mask_l=attention_mask_l) # [bs, N, C] -> [bs, C, N] new_v = new_v.transpose(1, 2).contiguous() start = 0 # fvfs is mean fusion_visual_features fvfs = [] for (h, w) in size_per_level: new_v_per_level = new_v[:, :, start:start + h * w].view(bs, -1, h, w).contiguous() fvfs.append(new_v_per_level) start += h * w return fvfs[0], fvfs[1], fvfs[2], fvfs[3], fvfs[4], new_lang_feature def single_attention_call( self, visual: Tensor, lang: Tensor, attention_mask_v: Optional[Tensor] = None, attention_mask_l: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """Perform a single attention call between the visual and language inputs. Args: visual (Tensor): The visual input tensor. lang (Tensor): The language input tensor. attention_mask_v (Optional[Tensor]): An optional attention mask tensor for the visual input. attention_mask_l (Optional[Tensor]): An optional attention mask tensor for the language input. Returns: Tuple[Tensor, Tensor]: A tuple containing the updated visual and language tensors after the attention call. """ visual = self.layer_norm_v(visual) lang = self.layer_norm_l(lang) delta_v, delta_l = self.attn( visual, lang, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l) # visual, lang = visual + delta_v, l + delta_l visual = visual + self.drop_path(self.gamma_v * delta_v) lang = lang + self.drop_path(self.gamma_l * delta_l) return visual, lang class SingleScaleBiAttentionBlock(BiAttentionBlock): """This is a single-scale implementation of `BiAttentionBlock`. The only differenece between it and `BiAttentionBlock` is that the `forward` function of `SingleScaleBiAttentionBlock` only accepts a single flatten visual feature map, while the `forward` function in `BiAttentionBlock` accepts multiple visual feature maps. """ def forward(self, visual_feature: Tensor, lang_feature: Tensor, attention_mask_v=None, attention_mask_l=None): """Single-scale forward pass. Args: visual_feature (Tensor): The visual input tensor. Tensor of shape (bs, patch_len, ch). lang_feature (Tensor): The language input tensor. Tensor of shape (bs, text_len, ch). attention_mask_v (_type_, optional): Visual feature attention mask. Defaults to None. attention_mask_l (_type_, optional): Language feature attention mask.Defaults to None. """ new_v, new_lang_feature = self.single_attention_call( visual_feature, lang_feature, attention_mask_v=attention_mask_v, attention_mask_l=attention_mask_l) return new_v, new_lang_feature class VLFuse(nn.Module): """Early Fusion Module. Args: v_dim (int): Dimension of visual features. l_dim (int): Dimension of language features. embed_dim (int): The embedding dimension for the attention operation. num_heads (int): Number of attention heads. dropout (float): Dropout probability. drop_path (float): Drop path probability. use_checkpoint (bool): Whether to use PyTorch's checkpoint function. """ def __init__(self, v_dim: int = 256, l_dim: int = 768, embed_dim: int = 2048, num_heads: int = 8, dropout: float = 0.1, drop_path: float = 0.0, use_checkpoint: bool = False): super().__init__() self.use_checkpoint = use_checkpoint self.b_attn = BiAttentionBlock( v_dim=v_dim, l_dim=l_dim, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, drop_path=drop_path, init_values=1.0 / 6.0) def forward(self, x: dict) -> dict: """Forward pass of the VLFuse module.""" visual_features = x['visual'] language_dict_features = x['lang'] if self.use_checkpoint: # vf is mean visual_features # checkpoint does not allow complex data structures as input, # such as list, so we must split them. vf0, vf1, vf2, vf3, vf4, language_features = checkpoint.checkpoint( self.b_attn, *visual_features, language_dict_features['hidden'], language_dict_features['masks']) else: vf0, vf1, vf2, vf3, vf4, language_features = self.b_attn( *visual_features, language_dict_features['hidden'], language_dict_features['masks']) language_dict_features['hidden'] = language_features fused_language_dict_features = language_dict_features features_dict = { 'visual': [vf0, vf1, vf2, vf3, vf4], 'lang': fused_language_dict_features } return features_dict class BertEncoderLayer(BertPreTrainedModel): """A modified version of the `BertLayer` class from the `transformers.models.bert.modeling_bert` module. Args: config (:class:`~transformers.BertConfig`): The configuration object that contains various parameters for the model. clamp_min_for_underflow (bool, optional): Whether to clamp the minimum value of the hidden states to prevent underflow. Defaults to `False`. clamp_max_for_overflow (bool, optional): Whether to clamp the maximum value of the hidden states to prevent overflow. Defaults to `False`. """ def __init__(self, config: BertConfig, clamp_min_for_underflow: bool = False, clamp_max_for_overflow: bool = False): super().__init__(config) self.config = config self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = BertAttention(config, clamp_min_for_underflow, clamp_max_for_overflow) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) def forward( self, inputs: Dict[str, Dict[str, torch.Tensor]] ) -> Dict[str, Dict[str, torch.Tensor]]: """Applies the BertEncoderLayer to the input features.""" language_dict_features = inputs['lang'] hidden_states = language_dict_features['hidden'] attention_mask = language_dict_features['masks'] device = hidden_states.device input_shape = hidden_states.size()[:-1] extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_shape, device) self_attention_outputs = self.attention( hidden_states, extended_attention_mask, None, output_attentions=False, past_key_value=None) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output) outputs = (layer_output, ) + outputs hidden_states = outputs[0] language_dict_features['hidden'] = hidden_states features_dict = { 'visual': inputs['visual'], 'lang': language_dict_features } return features_dict def feed_forward_chunk(self, attention_output: Tensor) -> Tensor: """Applies the intermediate and output layers of the BertEncoderLayer to a chunk of the input sequence.""" intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output # The following code is the same as the Huggingface code, # with the only difference being the additional clamp operation. class BertSelfAttention(nn.Module): """BERT self-attention layer from Huggingface transformers. Compared to the BertSelfAttention of Huggingface, only add the clamp. Args: config (:class:`~transformers.BertConfig`): The configuration object that contains various parameters for the model. clamp_min_for_underflow (bool, optional): Whether to clamp the minimum value of the hidden states to prevent underflow. Defaults to `False`. clamp_max_for_overflow (bool, optional): Whether to clamp the maximum value of the hidden states to prevent overflow. Defaults to `False`. """ def __init__(self, config: BertConfig, clamp_min_for_underflow: bool = False, clamp_max_for_overflow: bool = False): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and \ not hasattr(config, 'embedding_size'): raise ValueError(f'The hidden size ({config.hidden_size}) is ' 'not a multiple of the number of attention ' f'heads ({config.num_attention_heads})') self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * \ self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.position_embedding_type = getattr(config, 'position_embedding_type', 'absolute') if self.position_embedding_type == 'relative_key' or \ self.position_embedding_type == 'relative_key_query': self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding( 2 * config.max_position_embeddings - 1, self.attention_head_size) self.clamp_min_for_underflow = clamp_min_for_underflow self.clamp_max_for_overflow = clamp_max_for_overflow self.is_decoder = config.is_decoder def transpose_for_scores(self, x: Tensor) -> Tensor: """Transpose the dimensions of `x`.""" new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states: Tensor, attention_mask: Optional[Tensor] = None, head_mask: Optional[Tensor] = None, encoder_hidden_states: Optional[Tensor] = None, encoder_attention_mask: Optional[Tensor] = None, past_key_value: Optional[Tuple[Tensor, Tensor]] = None, output_attentions: bool = False, ) -> Tuple[Tensor, ...]: """Perform a forward pass through the BERT self-attention layer.""" mixed_query_layer = self.query(hidden_states) # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions key_layer = past_key_value[0] value_layer = past_key_value[1] attention_mask = encoder_attention_mask elif is_cross_attention: key_layer = self.transpose_for_scores( self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores( self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) if self.is_decoder: past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" # to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == 'relative_key' or \ self.position_embedding_type == 'relative_key_query': seq_length = hidden_states.size()[1] position_ids_l = torch.arange( seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange( seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding( distance + self.max_position_embeddings - 1) positional_embedding = positional_embedding.to( dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == 'relative_key': relative_position_scores = torch.einsum( 'bhld,lrd->bhlr', query_layer, positional_embedding) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == 'relative_key_query': relative_position_scores_query = torch.einsum( 'bhld,lrd->bhlr', query_layer, positional_embedding) relative_position_scores_key = torch.einsum( 'bhrd,lrd->bhlr', key_layer, positional_embedding) attention_scores = attention_scores + \ relative_position_scores_query + \ relative_position_scores_key attention_scores = attention_scores / math.sqrt( self.attention_head_size) if self.clamp_min_for_underflow: attention_scores = torch.clamp( attention_scores, min=-MAX_CLAMP_VALUE ) # Do not increase -50000, data type half has quite limited range if self.clamp_max_for_overflow: attention_scores = torch.clamp( attention_scores, max=MAX_CLAMP_VALUE ) # Do not increase 50000, data type half has quite limited range if attention_mask is not None: # Apply the attention mask is # (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.Softmax(dim=-1)(attention_scores) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs = attention_probs * head_mask context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + ( self.all_head_size, ) context_layer = context_layer.view(*new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer, ) if self.is_decoder: outputs = outputs + (past_key_value, ) return outputs class BertAttention(HFBertAttention): """BertAttention is made up of self-attention and intermediate+output. Compared to the BertAttention of Huggingface, only add the clamp. Args: config (:class:`~transformers.BertConfig`): The configuration object that contains various parameters for the model. clamp_min_for_underflow (bool, optional): Whether to clamp the minimum value of the hidden states to prevent underflow. Defaults to `False`. clamp_max_for_overflow (bool, optional): Whether to clamp the maximum value of the hidden states to prevent overflow. Defaults to `False`. """ def __init__(self, config: BertConfig, clamp_min_for_underflow: bool = False, clamp_max_for_overflow: bool = False): super().__init__(config) self.self = BertSelfAttention(config, clamp_min_for_underflow, clamp_max_for_overflow) class BertIntermediate(HFBertIntermediate): """Modified from transformers.models.bert.modeling_bert.BertIntermediate. Compared to the BertIntermediate of Huggingface, only add the clamp. """ def forward(self, hidden_states: Tensor) -> Tensor: hidden_states = self.dense(hidden_states) hidden_states = clamp_values(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) hidden_states = clamp_values(hidden_states) return hidden_states class BertOutput(HFBertOutput): """Modified from transformers.models.bert.modeling_bert.BertOutput. Compared to the BertOutput of Huggingface, only add the clamp. """ def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = clamp_values(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) hidden_states = clamp_values(hidden_states) return hidden_states