Update amplify.py

amplify.py CHANGED (+15 −16)
@@ -124,13 +124,13 @@ class EncoderBlock(nn.Module):

        self.ffn_dropout = nn.Dropout(config.dropout_prob)

-    def forward(self, x: torch.Tensor,
-        attn, contact = self._att_block(self.attention_norm(x),
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
+        attn, contact = self._att_block(self.attention_norm(x), attention_mask, freqs_cis, output_attentions)
        x = x + attn
        x = x + self._ff_block(self.ffn_norm(x))
        return x, contact

-    def _att_block(self, x: torch.Tensor,
+    def _att_block(self, x: torch.Tensor, attention_mask: torch.Tensor, freqs_cis: torch.Tensor, output_attentions: bool):
        batch_size, seq_len, _ = x.shape
        xq, xk, xv = self.q(x), self.k(x), self.v(x)
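For context, the block above keeps a pre-norm residual layout: each sublayer normalizes its input and its output is added back onto the stream. A minimal, self-contained sketch of that pattern in plain PyTorch follows; the module and parameter names are illustrative, not the repo's exact classes.

import torch
import torch.nn as nn

class PreNormResidualBlock(nn.Module):
    """Sketch of the x + sublayer(norm(x)) layout used by EncoderBlock.forward."""

    def __init__(self, dim: int, hidden: int, num_heads: int = 4):
        super().__init__()
        self.attention_norm = nn.LayerNorm(dim)
        self.ffn_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.attention_norm(x)
        attn_out, _ = self.attn(h, h, h)     # self-attention on the normalized input
        x = x + attn_out                     # residual around attention
        x = x + self.ffn(self.ffn_norm(x))   # residual around the feed-forward block
        return x

block = PreNormResidualBlock(dim=16, hidden=64)
y = block(torch.randn(2, 5, 16))             # (batch, seq_len, dim)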
@@ -144,15 +144,15 @@ class EncoderBlock(nn.Module):
            query=xq,
            key=xk,
            value=xv,
-            attn_bias=
+            attn_bias=attention_mask,
            p=self.config.dropout_prob if self.training else 0,
        )

        _attn = None
        if output_attentions:
            _attn = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
-            if
-            _attn = _attn +
+            if attention_mask is not None:
+                _attn = _attn + attention_mask
            _attn = _attn.softmax(-1)

        return self.resid_dropout(self.wo(attn.view(batch_size, seq_len, self.config.num_attention_heads * self.d_head))), _attn
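The output_attentions branch above recomputes softmax(QKᵀ/√d + mask) explicitly, since the fused attention call that produces `attn` does not expose per-head weights. A small self-contained sketch of that explicit computation with an additive mask; the helper name and tensor sizes are illustrative only.

import torch

def explicit_attention_weights(xq, xk, attention_mask=None):
    # xq, xk: (batch, seq_len, num_heads, d_head), as in _att_block above.
    # attention_mask: additive bias broadcastable to (batch, num_heads, seq_len, seq_len),
    # with 0 for visible positions and -inf for masked ones.
    scores = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
    if attention_mask is not None:
        scores = scores + attention_mask
    return scores.softmax(-1)  # (batch, num_heads, seq_len, seq_len)

# Masked key positions receive exactly zero weight after the softmax.
q = torch.randn(1, 4, 2, 8)
k = torch.randn(1, 4, 2, 8)
bias = torch.zeros(1, 1, 4, 4)
bias[..., 2:] = float("-inf")   # hide the last two key positions
weights = explicit_attention_weights(q, k, bias)
assert torch.all(weights[..., 2:] == 0)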
@@ -203,28 +203,28 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
        # Initialize weights and apply final processing
        self.post_init()

-    def forward(self,
+    def forward(self, input_ids, attention_mask=None, output_hidden_states=False, output_attentions=False, **kwargs):
        # Initialize
        hidden_states, attentions = [], []

        # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
-        if
-
+        if attention_mask is not None and not torch.all(attention_mask == 0):
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
        else:
-
+            attention_mask = None

        # RoPE
-        self.freqs_cis = self.freqs_cis.to(
-        freqs_cis = self.freqs_cis[:
+        self.freqs_cis = self.freqs_cis.to(input_ids.device, non_blocking=True)
+        freqs_cis = self.freqs_cis[: input_ids.shape[1]]

        # Embedding
-        x = self.encoder(
+        x = self.encoder(input_ids)
        if self.config.layer_norm_after_embedding:
            x = self.layer_norm_1(x)

        # Transformer encoder
        for layer in self.transformer_encoder:
-            x, attn = layer(x,
+            x, attn = layer(x, attention_mask, freqs_cis, output_attentions)
            if output_hidden_states:
                hidden_states.append(x)
            if output_attentions:
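The forward above turns a per-sequence mask of shape (Batch, Length) into a per-head additive bias of shape (Batch, Heads, Length, Length) before handing it to each encoder layer. A standalone illustration of that expand-and-repeat step; the head count and the 0 / -inf additive convention are assumptions read off the surrounding code.

import torch

num_attention_heads = 2  # assumed here; the model reads this from its config

# Additive padding mask: 0 where tokens are attended, -inf where they are padding.
attention_mask = torch.tensor([[0.0, 0.0, float("-inf"), float("-inf")]])  # (Batch=1, Length=4)

# (Batch, Length) -> (Batch, 1, 1, Length) -> (Batch, Heads, Length, Length)
expanded = attention_mask.unsqueeze(1).unsqueeze(1).repeat(
    1, num_attention_heads, attention_mask.size(-1), 1
)
print(expanded.shape)  # torch.Size([1, 2, 4, 4])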
@@ -234,5 +234,4 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
        logits = self.decoder(self.layer_norm_2(x) if self.config.layer_norm_before_last_layer else x)

        # Return logits or the output of the last hidden layer
-        return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
-
+        return MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
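The model returns its results wrapped in a MaskedLMOutput. Assuming this is the Hugging Face transformers dataclass (which the HF-style surrounding code suggests), here is a toy sketch of how the returned fields are laid out and read; all tensor shapes are illustrative only.

import torch
from transformers.modeling_outputs import MaskedLMOutput

# Dummy tensors standing in for real model outputs (shapes are illustrative).
logits = torch.randn(1, 6, 27)                              # (Batch, Length, vocab_size)
hidden_states = [torch.randn(1, 6, 320) for _ in range(2)]  # one entry per encoder layer
attentions = [torch.randn(1, 2, 6, 6) for _ in range(2)]    # (Batch, Heads, Length, Length)

out = MaskedLMOutput(logits=logits, hidden_states=hidden_states, attentions=attentions)
print(out.logits.shape)         # torch.Size([1, 6, 27])
print(len(out.hidden_states))   # 2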