|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import timm |
|
|
|
from .inference import GreedySearch, BeamSearch |
|
from .transformer import TransformerDecoder, Embeddings |
|
|
|
|
|
class Encoder(nn.Module):
    """Image encoder built on a timm backbone (ResNet, Swin or EfficientNet family).

    The backbone's pooling / classification head is replaced with ``nn.Identity``
    so ``forward`` returns per-location features rather than class logits.
    """

    def __init__(self, args, pretrained=False):
        """Create the backbone selected by ``args.encoder``.

        Args:
            args: configuration object. ``args.encoder`` is the timm model name;
                ``args.use_checkpoint`` is additionally read for Swin models.
            pretrained: forwarded to ``timm.create_model``.

        Raises:
            NotImplementedError: if ``args.encoder`` is not a supported family.
        """
        super().__init__()
        model_name = args.encoder
        self.model_name = model_name
        if model_name.startswith('resnet'):
            self.model_type = 'resnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features
            # Remove pooling and classifier so the spatial feature map is kept.
            self.cnn.global_pool = nn.Identity()
            self.cnn.fc = nn.Identity()
        elif model_name.startswith('swin'):
            self.model_type = 'swin'
            self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False,
                                                 use_checkpoint=args.use_checkpoint)
            self.n_features = self.transformer.num_features
            self.transformer.head = nn.Identity()
        elif 'efficientnet' in model_name:
            self.model_type = 'efficientnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features
            # Same as ResNet, but here the classifier attribute is ``classifier``.
            self.cnn.global_pool = nn.Identity()
            self.cnn.classifier = nn.Identity()
        else:
            # ``raise NotImplemented`` fails with TypeError (NotImplemented is a
            # constant, not an exception); raise the real exception instead.
            raise NotImplementedError(f"Unsupported encoder: {model_name!r}")

    def swin_forward(self, transformer, x):
        """Run a Swin transformer stage by stage, collecting per-stage feature maps.

        Args:
            transformer: Swin model exposing ``patch_embed``, ``absolute_pos_embed``,
                ``pos_drop``, ``layers`` and ``norm``.
            x: input batch passed to ``transformer.patch_embed``.

        Returns:
            (x, hiddens): the normalized token features after the last stage, and a
            list with one (B, H, W, C) tensor per stage, taken before that stage's
            downsampling. The last entry is replaced by the normalized output.
        """
        # Explicit import: ``import torch`` alone does not guarantee that the
        # ``torch.utils.checkpoint`` submodule is loaded.
        from torch.utils import checkpoint as torch_checkpoint

        def layer_forward(layer, x, hiddens):
            # Run one stage's blocks, record the output before downsampling.
            for blk in layer.blocks:
                if not torch.jit.is_scripting() and layer.use_checkpoint:
                    x = torch_checkpoint.checkpoint(blk, x)
                else:
                    x = blk(x)
            H, W = layer.input_resolution
            B, L, C = x.shape
            hiddens.append(x.view(B, H, W, C))
            if layer.downsample is not None:
                x = layer.downsample(x)
            return x, hiddens

        x = transformer.patch_embed(x)
        if transformer.absolute_pos_embed is not None:
            x = x + transformer.absolute_pos_embed
        x = transformer.pos_drop(x)

        hiddens = []
        for layer in transformer.layers:
            x, hiddens = layer_forward(layer, x, hiddens)
        x = transformer.norm(x)
        # Expose the normalized final features in the same (B, H, W, C) layout.
        hiddens[-1] = x.view_as(hiddens[-1])
        return x, hiddens

    def forward(self, x, refs=None):
        """Encode an image batch.

        Args:
            x: input image batch.
            refs: unused; accepted for interface compatibility.

        Returns:
            (features, hiddens). For CNN backbones ``features`` is the feature map
            permuted with ``permute(0, 2, 3, 1)`` (channels last) and ``hiddens``
            is an empty list.

        Raises:
            NotImplementedError: if ``self.model_type`` is not recognized.
        """
        if self.model_type in ['resnet', 'efficientnet']:
            features = self.cnn(x)
            features = features.permute(0, 2, 3, 1)  # channels-last
            hiddens = []
        elif self.model_type == 'swin':
            if 'patch' in self.model_name:
                features, hiddens = self.swin_forward(self.transformer, x)
            else:
                # NOTE(review): this assumes the Swin model returns a
                # (features, hiddens) pair; a stock timm Swin returns a single
                # tensor — confirm which model variant reaches this branch.
                features, hiddens = self.transformer(x)
        else:
            raise NotImplementedError(f"Unsupported model type: {self.model_type!r}")
        return features, hiddens
|
|
|
|
|
class TransformerDecoderBase(nn.Module):
    """Shared decoder scaffolding: encoder-feature projection plus a transformer decoder."""

    def __init__(self, args):
        """Build the encoder projection, optional positional embedding and decoder stack.

        Args:
            args: configuration object providing ``encoder_dim``,
                ``dec_hidden_size``, ``enc_pos_emb``, ``dec_num_layers``,
                ``dec_attn_heads``, ``hidden_dropout``, ``attn_dropout`` and
                ``max_relative_positions``.
        """
        super().__init__()
        self.args = args

        # Linear projection from the encoder feature size to the decoder width.
        self.enc_trans_layer = nn.Sequential(
            nn.Linear(args.encoder_dim, args.dec_hidden_size)
        )
        # Optional learned positional embedding for the flattened encoder
        # features. The table has 144 rows, so at most 144 positions are supported.
        self.enc_pos_emb = nn.Embedding(144, args.encoder_dim) if args.enc_pos_emb else None

        self.decoder = TransformerDecoder(
            num_layers=args.dec_num_layers,
            d_model=args.dec_hidden_size,
            heads=args.dec_attn_heads,
            d_ff=args.dec_hidden_size * 4,
            copy_attn=False,
            self_attn_type="scaled-dot",
            dropout=args.hidden_dropout,
            attention_dropout=args.attn_dropout,
            max_relative_positions=args.max_relative_positions,
            aan_useffn=False,
            full_context_alignment=False,
            alignment_layer=0,
            alignment_heads=0,
            pos_ffn_activation_fn='gelu'
        )

    def enc_transform(self, encoder_out):
        """Flatten encoder features into a sequence and project to the decoder width.

        Args:
            encoder_out: tensor whose first dim is the batch and last dim is the
                feature size; every dim in between is flattened into one sequence
                dim (e.g. a channels-last (B, H, W, C) map becomes (B, H*W, C)).

        Returns:
            Tensor of shape (batch, seq_len, dec_hidden_size).
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        # ``reshape`` rather than ``view``: permuted (non-contiguous) feature maps
        # may not be viewable, in which case ``reshape`` copies instead of failing.
        encoder_out = encoder_out.reshape(batch_size, -1, encoder_dim)
        if self.enc_pos_emb is not None:
            positions = torch.arange(encoder_out.size(1), device=encoder_out.device)
            encoder_out = encoder_out + self.enc_pos_emb(positions).unsqueeze(0)
        return self.enc_trans_layer(encoder_out)
|
|
|
|
|
class TransformerDecoderAR(TransformerDecoderBase):
    """Autoregressive transformer decoder producing token sequences for one output format."""

    def __init__(self, args, tokenizer):
        """Add token embeddings and the vocabulary projection on top of the base decoder.

        Args:
            args: configuration object (see ``TransformerDecoderBase``).
            tokenizer: supports ``len()`` and exposes ``PAD_ID``, ``SOS_ID``,
                ``EOS_ID``, ``output_constraint`` and ``get_output_mask``.
        """
        super().__init__(args)
        self.tokenizer = tokenizer
        self.vocab_size = len(self.tokenizer)
        self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True)
        self.embeddings = Embeddings(
            word_vec_size=args.dec_hidden_size,
            word_vocab_size=self.vocab_size,
            word_padding_idx=tokenizer.PAD_ID,
            position_encoding=True,
            dropout=args.hidden_dropout)

    def dec_embedding(self, tgt, step=None):
        """Embed target token ids and build the matching padding mask.

        Args:
            tgt: integer token ids; callers in this class pass (batch, seq_len, 1).
            step: decoding step forwarded to the embeddings (``None`` in training).

        Returns:
            (emb, tgt_pad_mask): the 3-D embedding, and a boolean mask that is True
            at padding positions with dims 1 and 2 of ``tgt`` swapped.
        """
        pad_idx = self.embeddings.word_padding_idx
        tgt_pad_mask = tgt.eq(pad_idx).transpose(1, 2)
        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3
        return emb, tgt_pad_mask

    def forward(self, encoder_out, labels, label_lengths):
        """Teacher-forced training pass.

        Args:
            encoder_out: encoder features in any layout accepted by
                ``enc_transform`` (batch first, features last; 3-D or 4-D).
            labels: (batch, seq_len) target token ids.
            label_lengths: unused; kept for interface compatibility.

        Returns:
            (logits, targets, dec_out): logits for positions ``0..T-2``, the labels
            shifted left by one (positions ``1..T-1``) so ``logits[:, t]`` is scored
            against ``targets[:, t]``, and the raw decoder outputs.
        """
        # No shape unpacking of ``encoder_out`` here: ``enc_transform`` flattens any
        # batch-first/feature-last layout, including 4-D CNN feature maps.
        memory_bank = self.enc_transform(encoder_out)

        tgt = labels.unsqueeze(-1)
        tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
        dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask)

        logits = self.output_layer(dec_out)
        return logits[:, :-1], labels[:, 1:], dec_out

    def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256):
        """Generate sequences with greedy search (``beam_size == 1``) or beam search.

        Args:
            encoder_out: encoder features (batch first, features last).
            beam_size: number of beams; 1 selects greedy decoding.
            n_best: number of hypotheses kept per example (beam search only).
            min_length: minimum generated length.
            max_length: maximum number of decoding steps.

        Returns:
            (predictions, scores, hidden) as reported by the decode strategy.
        """
        batch_size = encoder_out.size(0)
        memory_bank = self.enc_transform(encoder_out)

        use_beam = beam_size != 1
        if not use_beam:
            decode_strategy = GreedySearch(
                sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length, max_length=max_length,
                pad=self.tokenizer.PAD_ID, bos=self.tokenizer.SOS_ID, eos=self.tokenizer.EOS_ID,
                return_attention=False, return_hidden=True)
        else:
            decode_strategy = BeamSearch(
                beam_size=beam_size, n_best=n_best, batch_size=batch_size, min_length=min_length, max_length=max_length,
                pad=self.tokenizer.PAD_ID, bos=self.tokenizer.SOS_ID, eos=self.tokenizer.EOS_ID,
                return_attention=False)

        _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank)

        for step in range(decode_strategy.max_length):
            tgt = decode_strategy.current_predictions.view(-1, 1, 1)
            tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
            dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank,
                                                 tgt_pad_mask=tgt_pad_mask, step=step)

            attn = dec_attn.get("std", None)

            dec_logits = self.output_layer(dec_out).squeeze(1)
            log_probs = F.log_softmax(dec_logits, dim=-1)

            if self.tokenizer.output_constraint:
                # Mask is computed from each hypothesis's previous token; masked
                # entries are pushed to a very low log-probability.
                output_mask = [self.tokenizer.get_output_mask(token_id) for token_id in tgt.view(-1).tolist()]
                output_mask = torch.tensor(output_mask, device=log_probs.device)
                log_probs.masked_fill_(output_mask, -10000)

            decode_strategy.advance(log_probs, attn, dec_out)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices
            if any_finished:
                # Finished examples were dropped: keep only the surviving rows.
                memory_bank = memory_bank.index_select(0, select_indices)
            if use_beam or any_finished:
                # Beam search re-ranks hypotheses every step, so the decoder's
                # cached state must be reordered to follow ``select_indices`` even
                # when nothing finished (as in OpenNMT's ``parallel_paths > 1``).
                self.map_state(lambda state, dim: state.index_select(dim, select_indices))

        return decode_strategy.predictions, decode_strategy.scores, decode_strategy.hidden

    def map_state(self, fn):
        """Apply ``fn(tensor, batch_dim)`` in place to every tensor in the decoder cache.

        Nested dicts are walked recursively; ``None`` entries are left untouched.
        Does nothing when the cache is ``None``.
        """
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.decoder.state["cache"] is not None:
            _recursive_map(self.decoder.state["cache"])
|
|
|
|
|
class Decoder(nn.Module):
    """Holds one autoregressive decoder per output format and dispatches to each."""

    def __init__(self, args, tokenizer):
        """Create a ``TransformerDecoderAR`` for every name in ``args.formats``.

        ``tokenizer`` is indexed by format name to give each decoder its tokenizer.
        """
        super().__init__()
        self.args = args
        self.formats = args.formats
        self.tokenizer = tokenizer
        self.decoder = nn.ModuleDict(
            {fmt: TransformerDecoderAR(args, tokenizer[fmt]) for fmt in args.formats}
        )

    def forward(self, encoder_out, hiddens, refs):
        """Run the teacher-forced pass of every format's decoder.

        ``refs[fmt]`` must unpack into ``(labels, label_lengths)``; ``hiddens`` is
        not used. Returns a dict mapping each format to its decoder's output.
        """
        outputs = {}
        for fmt in self.formats:
            fmt_labels, fmt_lengths = refs[fmt]
            outputs[fmt] = self.decoder[fmt](encoder_out, fmt_labels, fmt_lengths)
        return outputs

    def decode(self, encoder_out, hiddens, refs=None, beam_size=1, n_best=1):
        """Decode every format and convert token sequences with its tokenizer.

        ``hiddens`` and ``refs`` are not used. Returns ``(predictions,
        beam_predictions)``: ``predictions[fmt]`` holds the first candidate per
        example, and ``beam_predictions[fmt]`` is ``(all_candidates, scores)``.
        """
        predictions, beam_predictions = {}, {}
        for fmt in self.formats:
            fmt_tokenizer = self.tokenizer[fmt]
            sequences, seq_scores, *_ = self.decoder[fmt].decode(
                encoder_out, beam_size, n_best, max_length=fmt_tokenizer.max_len)
            converted = []
            for candidates in sequences:
                converted.append([fmt_tokenizer.sequence_to_data(seq.tolist()) for seq in candidates])
            beam_predictions[fmt] = (converted, seq_scores)
            predictions[fmt] = [per_example[0] for per_example in converted]
        return predictions, beam_predictions
|
|