|
|
|
""" |
|
Pix2Seq Transformer class. |
|
|
|
Copy-paste from torch.nn.Transformer with modifications: |
|
* positional encodings are passed into the multi-head attention layers |
|
* extra LN at the end of encoder is removed |
|
* decoder returns a stack of activations from all decoding layers |
|
""" |
|
import copy |
|
from typing import Optional, List |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn, Tensor |
|
from .attention_layer import Attention |
|
|
|
|
|
class Transformer(nn.Module):
    """Pix2Seq encoder-decoder transformer that decodes a token sequence from image features.

    In training mode the decoder is teacher-forced on ``input_seq`` and the
    module returns vocabulary logits.  In eval mode tokens are decoded greedily
    one step at a time, re-using the per-layer ``pre_kv`` state returned by the
    decoder, and the module returns the predicted tokens with their scores.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu", normalize_before=False, num_vocal=2094,
                 pred_eos=False, tokenizer=None):
        """Build the encoder/decoder stacks, the token embeddings and the classifier.

        Args:
            d_model: width of all embeddings and hidden states.
            nhead: number of attention heads.
            num_encoder_layers: number of encoder layers.
            num_decoder_layers: number of decoder layers.
            dim_feedforward: hidden width of the feed-forward sub-blocks.
            dropout: dropout probability used throughout.
            activation: name understood by ``_get_activation_fn``.
            normalize_before: use pre-norm layers (and a final encoder norm).
            num_vocal: size of the output vocabulary.
            pred_eos: if true, decoding may stop on ``tokenizer.EOS_ID``.
            tokenizer: object providing ``EOS_ID``, ``output_constraint`` and
                ``update_states_and_masks``; required for inference.
        """
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
        # NOTE: Xavier init runs here, i.e. before the classifier and the
        # embeddings below exist, so those keep PyTorch's default init.
        self._reset_parameters()

        self.num_vocal = num_vocal
        self.vocal_classifier = nn.Linear(d_model, num_vocal)
        # Single learned start embedding that opens every decoded sequence.
        self.det_embed = nn.Embedding(1, d_model)
        # The last two vocabulary ids have no input embedding.
        self.vocal_embed = nn.Embedding(self.num_vocal - 2, d_model)
        self.pred_eos = pred_eos

        self.d_model = d_model
        self.nhead = nhead
        self.num_decoder_layers = num_decoder_layers
        self.tokenizer = tokenizer

    def _reset_parameters(self):
        """Xavier-initialise every currently registered parameter with more than one dim."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, input_seq, mask, pos_embed, max_len=500):
        """Encode the image features and decode a token sequence.

        Args:
            src: image features, shape [B, C, H, W].
            input_seq: integer token ids, shape [B, L].  Only read in training
                mode (teacher forcing); ignored at inference.
            mask: padding mask, shape [B, H, W].
            pos_embed: positional encoding, shape [B, C, H, W].
            max_len: maximum number of decoding steps at inference.

        Returns:
            Training mode: logits produced by ``vocal_classifier`` for the
            start token followed by ``input_seq`` (batch-first).
            Eval mode: ``(pred_seq, pred_scores)``, two lists with one 1-D
            tensor per batch element holding the predicted token ids and
            their log-probabilities.
        """
        bs = src.shape[0]
        # Flatten the spatial grid into a sequence: [H*W, B, C].
        src = src.flatten(2).permute(2, 0, 1)
        mask = mask.flatten(1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        # Empty per-layer state handed to the decoder on the first step.
        pre_kv = [torch.as_tensor([[], []], device=memory.device)
                  for _ in range(self.num_decoder_layers)]

        if self.training:
            return self._decode_teacher_forced(input_seq, memory, mask, pos_embed, pre_kv, bs)
        return self._decode_greedy(memory, mask, pos_embed, pre_kv, bs, max_len)

    def _decode_teacher_forced(self, input_seq, memory, mask, pos_embed, pre_kv, bs):
        """Run the decoder once over the whole target sequence and return logits."""
        # Ids outside the embedding table are clamped to its last entry.
        input_seq = input_seq.clamp(max=self.num_vocal - 3)
        start_embed = self.det_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
        input_embed = torch.cat([start_embed, self.vocal_embed(input_seq)], dim=1)
        input_embed = input_embed.transpose(0, 1)
        num_seq = input_embed.shape[0]
        # Causal mask (True above the diagonal), built directly on the target
        # device instead of on the CPU followed by a copy.
        self_attn_mask = torch.ones(
            (num_seq, num_seq), device=input_embed.device).triu(diagonal=1).bool()
        hs, pre_kv = self.decoder(
            input_embed,
            memory,
            memory_key_padding_mask=mask,
            pos=pos_embed,
            pre_kv_list=pre_kv,
            self_attn_mask=self_attn_mask)

        return self.vocal_classifier(hs.transpose(0, 1))

    def _decode_greedy(self, memory, mask, pos_embed, pre_kv, bs, max_len):
        """Greedily decode up to ``max_len`` tokens per batch element."""
        device = memory.device
        end = torch.zeros(bs, dtype=torch.bool, device=device)
        end_lens = torch.zeros(bs, dtype=torch.long, device=device)
        input_embed = self.det_embed.weight.unsqueeze(0).repeat(bs, 1, 1).transpose(0, 1)
        states, pred_token = [None] * bs, [None] * bs
        pred_seq, pred_scores = [], []
        for seq_i in range(max_len):
            hs, pre_kv = self.decoder(
                input_embed,
                memory,
                memory_key_padding_mask=mask,
                pos=pos_embed,
                pre_kv_list=pre_kv)

            logits = self.vocal_classifier(hs.transpose(0, 1))
            log_probs = F.log_softmax(logits, dim=-1)
            if self.tokenizer.output_constraint:
                states, output_masks = self.tokenizer.update_states_and_masks(states, pred_token)
                output_masks = torch.tensor(output_masks, device=logits.device).unsqueeze(1)
                log_probs.masked_fill_(output_masks, -10000)
            if not self.pred_eos:
                # EOS must never win when the model is not supposed to stop.
                log_probs[:, :, self.tokenizer.EOS_ID] = -10000

            score, pred_token = log_probs.max(dim=-1)
            pred_seq.append(pred_token)
            pred_scores.append(score)

            if self.pred_eos:
                stop_state = pred_token.squeeze(1).eq(self.tokenizer.EOS_ID)
                # Record the step of the *first* EOS only.
                end_lens += seq_i * (~end * stop_state)
                end = (stop_state + end).bool()
                if end.all() and seq_i > 4:
                    break

            # The next input can only be a token that has an input embedding.
            token = log_probs[:, :, :self.num_vocal - 2].argmax(dim=-1)
            input_embed = self.vocal_embed(token.transpose(0, 1))

        num_steps = len(pred_seq)
        if self.pred_eos:
            # Sequences that never produced EOS keep everything that was
            # generated; previously their length stayed 0 and the whole
            # prediction was silently dropped.
            end_lens = end_lens.masked_fill(~end, num_steps)
        else:
            end_lens = end_lens.fill_(max_len)
        lengths = end_lens.tolist()
        pred_seq = torch.cat(pred_seq, dim=1)
        pred_seq = [seq[:length] for length, seq in zip(lengths, pred_seq)]
        pred_scores = torch.cat(pred_scores, dim=1)
        pred_scores = [scores[:length] for length, scores in zip(lengths, pred_scores)]
        return pred_seq, pred_scores
|
|
|
|
|
class TransformerEncoder(nn.Module):
    """Sequential stack of cloned encoder layers with an optional closing norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        """Clone ``encoder_layer`` ``num_layers`` times; ``norm`` may be ``None``."""
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Feed ``src`` through every layer in order, then through ``norm`` if set.

        ``src_key_padding_mask`` and ``pos`` are forwarded unchanged to each layer.
        """
        hidden = src
        for block in self.layers:
            hidden = block(hidden, src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is None:
            return hidden
        return self.norm(hidden)
|
|
|
|
|
class TransformerDecoder(nn.Module):
    """Sequential stack of cloned decoder layers with an optional closing norm.

    Each layer receives and returns a ``pre_kv`` state (presumably cached
    self-attention keys/values -- defined by the layer's attention module),
    which lets the caller decode step by step.
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        """Clone ``decoder_layer`` ``num_layers`` times; ``norm`` may be ``None``."""
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt, memory, memory_key_padding_mask, pos, pre_kv_list=None, self_attn_mask=None):
        """Run ``tgt`` through every decoder layer.

        Args:
            tgt: decoder input.
            memory: encoder output attended to by each layer.
            memory_key_padding_mask: padding mask for ``memory``.
            pos: positional encoding added to the cross-attention keys.
            pre_kv_list: one state entry per layer from the previous decoding
                step, or ``None`` to start every layer without a state.
            self_attn_mask: optional mask for the self-attention.

        Returns:
            ``(output, cur_kv_list)``: the (optionally normalised) output and
            the updated per-layer states, in layer order.
        """
        if pre_kv_list is None:
            # The signature advertises ``None`` as the default, but zipping
            # over ``None`` raises TypeError; hand each layer its own
            # ``pre_kv=None`` default instead.
            pre_kv_list = [None] * len(self.layers)

        output = tgt
        cur_kv_list = []
        for layer, pre_kv in zip(self.layers, pre_kv_list):
            output, cur_kv = layer(
                output,
                memory,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                self_attn_mask=self_attn_mask,
                pre_kv=pre_kv)
            cur_kv_list.append(cur_kv)

        if self.norm is not None:
            output = self.norm(output)

        return output, cur_kv_list
|
|
|
|
|
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a two-layer feed-forward net.

    ``normalize_before`` selects between pre-norm (LayerNorm in front of each
    sub-block) and post-norm (LayerNorm after each residual sum).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward sub-block.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One norm and one residual dropout per sub-block.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` untouched when ``pos`` is ``None``."""
        if pos is None:
            return tensor
        return tensor + pos

    def _self_attention(self, x, key_padding_mask, pos):
        """Self-attention branch: positions go into query/key only, not the value."""
        query_key = self.with_pos_embed(x, pos)
        attended = self.self_attn(
            query_key, query_key, value=x, key_padding_mask=key_padding_mask)[0]
        return self.dropout1(attended)

    def _feed_forward(self, x):
        """Feed-forward branch including its residual dropout."""
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.dropout2(self.linear2(hidden))

    def forward_post(self,
                     src,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        """Post-norm variant: normalise after each residual addition."""
        src = self.norm1(src + self._self_attention(src, src_key_padding_mask, pos))
        return self.norm2(src + self._feed_forward(src))

    def forward_pre(self, src,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        """Pre-norm variant: normalise the input of each sub-block."""
        src = src + self._self_attention(self.norm1(src), src_key_padding_mask, pos)
        return src + self._feed_forward(self.norm2(src))

    def forward(self, src,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Dispatch to the pre- or post-norm implementation."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(src, src_key_padding_mask, pos)
|
|
|
|
|
class TransformerDecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention over memory, feed-forward.

    The self-attention module takes and returns a ``pre_kv`` state, which this
    layer hands back to the caller alongside its output.  ``normalize_before``
    selects between the pre-norm and post-norm arrangement.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = Attention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward sub-block.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One norm and one residual dropout per sub-block.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` untouched when ``pos`` is ``None``."""
        if pos is None:
            return tensor
        return tensor + pos

    def _cross_attention(self, query, memory, memory_key_padding_mask, pos):
        """Attend from ``query`` to ``memory``; positions are added to the keys only."""
        attended = self.multihead_attn(
            query=query,
            key=self.with_pos_embed(memory, pos),
            value=memory,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        return self.dropout2(attended)

    def _feed_forward(self, x):
        """Feed-forward branch including its residual dropout."""
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.dropout3(self.linear2(hidden))

    def forward_post(
        self,
        tgt,
        memory,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        self_attn_mask: Optional[Tensor] = None,
        pre_kv=None,
    ):
        """Post-norm variant: normalise after each residual addition."""
        attended, cur_kv = self.self_attn(tgt, pre_kv=pre_kv, attn_mask=self_attn_mask)
        tgt = self.norm1(tgt + self.dropout1(attended))
        tgt = self.norm2(
            tgt + self._cross_attention(tgt, memory, memory_key_padding_mask, pos))
        tgt = self.norm3(tgt + self._feed_forward(tgt))
        return tgt, cur_kv

    def forward_pre(
        self,
        tgt,
        memory,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        self_attn_mask: Optional[Tensor] = None,
        pre_kv=None,
    ):
        """Pre-norm variant: normalise the input of each sub-block."""
        attended, cur_kv = self.self_attn(
            self.norm1(tgt), pre_kv=pre_kv, attn_mask=self_attn_mask)
        tgt = tgt + self.dropout1(attended)
        tgt = tgt + self._cross_attention(
            self.norm2(tgt), memory, memory_key_padding_mask, pos)
        tgt = tgt + self._feed_forward(self.norm3(tgt))
        return tgt, cur_kv

    def forward(
        self,
        tgt,
        memory,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        self_attn_mask: Optional[Tensor] = None,
        pre_kv=None,
    ):
        """Dispatch to the pre- or post-norm implementation."""
        run = self.forward_pre if self.normalize_before else self.forward_post
        return run(tgt, memory, memory_key_padding_mask, pos, self_attn_mask, pre_kv)
|
|
|
|
|
class MLP(nn.Module):
    """Plain multi-layer perceptron (FFN) with ReLU between its linear layers."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """Create ``num_layers`` linear layers: input -> hidden ... hidden -> output."""
        super().__init__()
        self.num_layers = num_layers
        # Layer widths from input to output; consecutive pairs define each Linear.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, x):
        """Apply every layer in order, with ReLU after all but the last one."""
        final_index = self.num_layers - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            if index < final_index:
                x = F.relu(x)
        return x
|
|
|
|
|
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)
|
|
|
|
|
def build_transformer(args, tokenizer):
    """Create a ``Transformer`` from the parsed options in ``args``.

    The vocabulary size is ``len(tokenizer)``; the tokenizer itself is also
    handed to the model.
    """
    vocab_size = len(tokenizer)
    settings = dict(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        num_vocal=vocab_size,
        pred_eos=args.pred_eos,
        tokenizer=tokenizer,
    )
    return Transformer(**settings)
|
|
|
|
|
def _get_activation_fn(activation):
    """Return the ``torch.nn.functional`` activation named by ``activation``.

    Args:
        activation: one of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        The matching function from ``torch.nn.functional``.

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    # NOTE(review): F.glu halves the last dimension, which presumably does not
    # fit the feed-forward layers in this file -- confirm before relying on "glu".
    supported = {
        "relu": F.relu,
        "gelu": F.gelu,
        "glu": F.glu,
    }
    try:
        return supported[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments so they still raise RuntimeError.
        # The message now lists every accepted name (it used to omit "glu").
        raise RuntimeError(
            f"activation should be relu/gelu/glu, not {activation}.") from None
|
|