import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from base_bert import *
from everything import *
class BertSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = config.hidden_size // config.num_attention_heads
self.all_head_size = self.num_attention_heads * self.attention_head_size
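    # Note: the integer division above assumes hidden_size is divisible by
    # num_attention_heads (true for the standard BERT configs), so
    # all_head_size == hidden_size and each Q/K/V projection below maps
    # hidden_size -> hidden_size before the heads are split apart.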
# Initialize the linear transformation layers for key, value, query.
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
    # This dropout is applied to the normalized attention scores, following the
    # original Transformer implementation. Although it is a bit unusual, we
    # empirically observe that it yields better performance.
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transform(self, x, linear_layer):
    # The corresponding linear_layer for k, v, or q is used to project the hidden_state (x).
bs, seq_len = x.shape[:2]
proj = linear_layer(x)
    # Next, we need to produce multiple heads for the proj. This is done by splitting the
    # hidden state into self.num_attention_heads heads, each of size self.attention_head_size.
proj = proj.view(bs, seq_len, self.num_attention_heads, self.attention_head_size)
# By proper transpose, we have proj of size [bs, num_attention_heads, seq_len, attention_head_size].
proj = proj.transpose(1, 2)
return proj
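  # Shape walkthrough for transform() (illustrative values, not read from the config):
  # with hidden_size = 768 and num_attention_heads = 12, attention_head_size = 64, so
  #   x:    [bs, seq_len, 768]
  #   proj: [bs, seq_len, 768] --view--> [bs, seq_len, 12, 64]
  #                            --transpose(1, 2)--> [bs, 12, seq_len, 64]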
def attention(self, key, query, value, attention_mask):
"""
key, query, value: [batch_size, num_attention_heads, seq_len, attention_head_size]
attention_mask: [batch_size, 1, 1, seq_len], masks padding tokens in the input.
"""
d_k = query.size(-1) # attention_head_size
attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d_k)
# attention_scores shape: [batch_size, num_attention_heads, seq_len, seq_len]
# Apply attention mask
attention_scores = attention_scores + attention_mask
# Normalize scores with softmax and apply dropout.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
context = torch.matmul(attention_probs, value)
# context shape: [batch_size, num_attention_heads, seq_len, attention_head_size]
# Concatenate all attention heads to recover original shape: [batch_size, seq_len, hidden_size]
context = context.transpose(1, 2).contiguous()
context = context.view(context.size(0), context.size(1), -1)
return context
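  # In equation form, attention() above computes scaled dot-product attention per head,
  #   context = softmax(Q K^T / sqrt(d_k) + mask) V,
  # with dropout applied to the softmax output. The additive mask of shape
  # [bs, 1, 1, seq_len] broadcasts over the head and query dimensions, so padding
  # keys receive (near-)zero attention weight for every query position.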
def forward(self, hidden_states, attention_mask):
"""
hidden_states: [bs, seq_len, hidden_size]
attention_mask: [bs, 1, 1, seq_len]
    output: [bs, seq_len, hidden_size]
"""
# First, we have to generate the key, value, query for each token for multi-head attention
# using self.transform (more details inside the function).
# Size of *_layer is [bs, num_attention_heads, seq_len, attention_head_size].
key_layer = self.transform(hidden_states, self.key)
value_layer = self.transform(hidden_states, self.value)
query_layer = self.transform(hidden_states, self.query)
# Calculate the multi-head attention.
attn_value = self.attention(key_layer, query_layer, value_layer, attention_mask)
return attn_value
class BertLayer(nn.Module):
def __init__(self, config):
super().__init__()
# Multi-head attention.
self.self_attention = BertSelfAttention(config)
# Add-norm for multi-head attention.
self.attention_dense = nn.Linear(config.hidden_size, config.hidden_size)
self.attention_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
# Feed forward.
self.interm_dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.interm_af = F.gelu
# Add-norm for feed forward.
self.out_dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.out_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.out_dropout = nn.Dropout(config.hidden_dropout_prob)
def add_norm(self, input, output, dense_layer, dropout, ln_layer):
    transformed_output = dense_layer(output)          # Transform the sub-layer output with dense_layer.
    transformed_output = dropout(transformed_output)  # Apply dropout.
    added_output = input + transformed_output         # Residual connection: add the sub-layer input.
    normalized_output = ln_layer(added_output)        # Apply layer normalization.
return normalized_output
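  # add_norm() implements the Transformer "Add & Norm" step in its post-layer-norm
  # form: LayerNorm(input + Dropout(Dense(output))), where `output` is the raw
  # sub-layer result and `input` is the residual branch.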
def forward(self, hidden_states, attention_mask):
# 1. Multi-head attention
attention_output = self.self_attention(hidden_states, attention_mask)
# 2. Add-norm after attention
attention_output = self.add_norm(
hidden_states,
attention_output,
self.attention_dense,
self.attention_dropout,
self.attention_layer_norm
)
# 3. Feed-forward network
intermediate_output = self.interm_af(self.interm_dense(attention_output))
# 4. Add-norm after feed-forward
layer_output = self.add_norm(
attention_output,
intermediate_output,
self.out_dense,
self.out_dropout,
self.out_layer_norm
)
return layer_output
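  # A BertLayer therefore applies two residual sub-layers in sequence:
  #   1. multi-head self-attention, folded into the first add_norm call, and
  #   2. a position-wise feed-forward network (interm_dense -> GELU -> out_dense),
  #      whose down-projection and residual live in the second add_norm call.
  # The hidden state keeps shape [bs, seq_len, hidden_size] throughout.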
class BertModel(BertPreTrainedModel):
"""
The BERT model returns the final embeddings for each token in a sentence.
The model consists of:
1. Embedding layers (used in self.embed).
2. A stack of n BERT layers (used in self.encode).
3. A linear transformation layer for the [CLS] token (used in self.forward, as given).
"""
def __init__(self, config):
super().__init__(config)
self.config = config
# Embedding layers.
self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.pos_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.tk_type_embedding = nn.Embedding(config.type_vocab_size, config.hidden_size)
self.embed_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
    # Register position_ids of shape [1, max_position_embeddings] as a buffer because it is a constant.
position_ids = torch.arange(config.max_position_embeddings).unsqueeze(0)
self.register_buffer('position_ids', position_ids)
# BERT encoder.
self.bert_layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
# [CLS] token transformations.
self.pooler_dense = nn.Linear(config.hidden_size, config.hidden_size)
self.pooler_af = nn.Tanh()
self.init_weights()
def embed(self, input_ids):
input_shape = input_ids.size()
seq_length = input_shape[1]
inputs_embeds = self.word_embedding(input_ids)
pos_ids = self.position_ids[:, :seq_length]
pos_embeds = self.pos_embedding(pos_ids)
# Since we are not considering token type, this embedding is just a placeholder.
tk_type_ids = torch.zeros(input_shape, dtype=torch.long, device=input_ids.device)
tk_type_embeds = self.tk_type_embedding(tk_type_ids)
embeddings = inputs_embeds + pos_embeds + tk_type_embeds
embeddings = self.embed_layer_norm(embeddings)
embeddings = self.embed_dropout(embeddings)
return embeddings
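  # The final token embedding is the element-wise sum
  #   word_embedding + position_embedding + token_type_embedding,
  # followed by LayerNorm and dropout. Token-type ids are hard-coded to zero above,
  # so the token-type term adds the same learned vector to every position.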
def encode(self, hidden_states, attention_mask):
"""
hidden_states: the output from the embedding layer [batch_size, seq_len, hidden_size]
attention_mask: [batch_size, seq_len]
"""
# Get the extended attention mask for self-attention.
# Returns extended_attention_mask of size [batch_size, 1, 1, seq_len].
# Distinguishes between non-padding tokens (with a value of 0) and padding tokens
# (with a value of a large negative number).
extended_attention_mask: torch.Tensor = get_extended_attention_mask(attention_mask, self.dtype)
# Pass the hidden states through the encoder layers.
for i, layer_module in enumerate(self.bert_layers):
# Feed the encoding from the last bert_layer to the next.
hidden_states = layer_module(hidden_states, extended_attention_mask)
return hidden_states
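  # Illustrative example of the extended mask: for an attention_mask row [1, 1, 1, 0],
  # get_extended_attention_mask yields additive values of roughly [0, 0, 0, -large]
  # (the exact negative constant depends on self.dtype), so the softmax inside each
  # attention head assigns ~0 weight to the padded position.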
def forward(self, input_ids, attention_mask):
"""
input_ids: [batch_size, seq_len], seq_len is the max length of the batch
attention_mask: same size as input_ids, 1 represents non-padding tokens, 0 represents padding tokens
"""
# Get the embedding for each input token.
embedding_output = self.embed(input_ids=input_ids)
# Feed to a transformer (a stack of BertLayers).
sequence_output = self.encode(embedding_output, attention_mask=attention_mask)
# Get cls token hidden state.
first_tk = sequence_output[:, 0]
first_tk = self.pooler_dense(first_tk)
first_tk = self.pooler_af(first_tk)
return {'last_hidden_state': sequence_output, 'pooler_output': first_tk}
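# --- Minimal smoke test (illustrative only) ---------------------------------
# This sketch exercises a single BertLayer with a toy config built from a
# SimpleNamespace; the field values below are arbitrary small numbers, not taken
# from any pretrained checkpoint. BertModel itself is not constructed here
# because it additionally depends on BertPreTrainedModel from base_bert.
if __name__ == "__main__":
  from types import SimpleNamespace

  toy_config = SimpleNamespace(
    hidden_size=64,
    num_attention_heads=4,
    intermediate_size=256,
    attention_probs_dropout_prob=0.1,
    hidden_dropout_prob=0.1,
    layer_norm_eps=1e-12,
  )
  layer = BertLayer(toy_config)
  layer.eval()  # disable dropout for a deterministic forward pass

  bs, seq_len = 2, 8
  hidden_states = torch.randn(bs, seq_len, toy_config.hidden_size)
  # An all-zero additive mask means "no padding anywhere"; padded positions would
  # instead carry a large negative value (see get_extended_attention_mask).
  attention_mask = torch.zeros(bs, 1, 1, seq_len)

  out = layer(hidden_states, attention_mask)
  print(out.shape)  # expected: torch.Size([2, 8, 64])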