import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import timm

from .inference import GreedySearch, BeamSearch
from .transformer import TransformerDecoder, Embeddings


class Encoder(nn.Module):

    def __init__(self, args, pretrained=False):
        super().__init__()
        model_name = args.encoder
        self.model_name = model_name
        if model_name.startswith('resnet'):
            self.model_type = 'resnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features  # encoder_dim
            self.cnn.global_pool = nn.Identity()
            self.cnn.fc = nn.Identity()
        elif model_name.startswith('swin'):
            self.model_type = 'swin'
            self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False,
                                                 use_checkpoint=args.use_checkpoint)
            self.n_features = self.transformer.num_features
            self.transformer.head = nn.Identity()
        elif 'efficientnet' in model_name:
            self.model_type = 'efficientnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features
            self.cnn.global_pool = nn.Identity()
            self.cnn.classifier = nn.Identity()
        else:
            raise NotImplementedError(f"Unsupported encoder: {model_name}")

    def swin_forward(self, transformer, x):
        x = transformer.patch_embed(x)
        if transformer.absolute_pos_embed is not None:
            x = x + transformer.absolute_pos_embed
        x = transformer.pos_drop(x)

        def layer_forward(layer, x, hiddens):
            for blk in layer.blocks:
                if not torch.jit.is_scripting() and layer.use_checkpoint:
                    x = torch.utils.checkpoint.checkpoint(blk, x)
                else:
                    x = blk(x)
            H, W = layer.input_resolution
            B, L, C = x.shape
            hiddens.append(x.view(B, H, W, C))
            if layer.downsample is not None:
                x = layer.downsample(x)
            return x, hiddens

        hiddens = []
        for layer in transformer.layers:
            x, hiddens = layer_forward(layer, x, hiddens)
        x = transformer.norm(x)  # B L C
        hiddens[-1] = x.view_as(hiddens[-1])
        return x, hiddens

    def forward(self, x, refs=None):
        if self.model_type in ['resnet', 'efficientnet']:
            features = self.cnn(x)
            features = features.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
            hiddens = []
        elif self.model_type == 'swin':
            if 'patch' in self.model_name:
                features, hiddens = self.swin_forward(self.transformer, x)
            else:
                features, hiddens = self.transformer(x)
        else:
            raise NotImplementedError(f"Unsupported model type: {self.model_type}")
        return features, hiddens


class TransformerDecoderBase(nn.Module):

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.enc_trans_layer = nn.Sequential(
            nn.Linear(args.encoder_dim, args.dec_hidden_size)
            # nn.LayerNorm(args.dec_hidden_size, eps=1e-6)
        )
        self.enc_pos_emb = nn.Embedding(144, args.encoder_dim) if args.enc_pos_emb else None
        self.decoder = TransformerDecoder(
            num_layers=args.dec_num_layers,
            d_model=args.dec_hidden_size,
            heads=args.dec_attn_heads,
            d_ff=args.dec_hidden_size * 4,
            copy_attn=False,
            self_attn_type="scaled-dot",
            dropout=args.hidden_dropout,
            attention_dropout=args.attn_dropout,
            max_relative_positions=args.max_relative_positions,
            aan_useffn=False,
            full_context_alignment=False,
            alignment_layer=0,
            alignment_heads=0,
            pos_ffn_activation_fn='gelu')

    def enc_transform(self, encoder_out):
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        # Flatten spatial dims: (batch_size, num_pixels, encoder_dim)
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)
        max_len = encoder_out.size(1)
        device = encoder_out.device
        if self.enc_pos_emb is not None:
            pos_emb = self.enc_pos_emb(torch.arange(max_len, device=device)).unsqueeze(0)
            encoder_out = encoder_out + pos_emb
        encoder_out = self.enc_trans_layer(encoder_out)
        return encoder_out
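
# Illustrative addition, not part of the original module: a shape sketch for
# enc_transform. All hyperparameter values below are assumptions chosen for the
# demo, not the project's actual defaults.
def _demo_enc_transform():
    """Show how ``enc_transform`` flattens a (B, H, W, C) feature map into a
    (B, H*W, dec_hidden_size) memory bank."""
    from types import SimpleNamespace
    args = SimpleNamespace(
        encoder_dim=1024, dec_hidden_size=256, enc_pos_emb=False,
        dec_num_layers=6, dec_attn_heads=8, hidden_dropout=0.1,
        attn_dropout=0.1, max_relative_positions=0)
    base = TransformerDecoderBase(args)
    features = torch.randn(2, 12, 12, args.encoder_dim)  # e.g. permuted CNN output
    memory_bank = base.enc_transform(features)
    # 12x12 spatial positions are flattened to 144 tokens of width dec_hidden_size;
    # note enc_pos_emb, when enabled, supports at most 144 positions.
    assert memory_bank.shape == (2, 12 * 12, args.dec_hidden_size)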
class TransformerDecoderAR(TransformerDecoderBase):
    """Autoregressive Transformer decoder."""

    def __init__(self, args, tokenizer):
        super().__init__(args)
        self.tokenizer = tokenizer
        self.vocab_size = len(self.tokenizer)
        self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True)
        self.embeddings = Embeddings(
            word_vec_size=args.dec_hidden_size,
            word_vocab_size=self.vocab_size,
            word_padding_idx=tokenizer.PAD_ID,
            position_encoding=True,
            dropout=args.hidden_dropout)

    def dec_embedding(self, tgt, step=None):
        pad_idx = self.embeddings.word_padding_idx
        tgt_pad_mask = tgt.data.eq(pad_idx).transpose(1, 2)  # [B, 1, T_tgt]
        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # batch x len x embedding_dim
        return emb, tgt_pad_mask

    def forward(self, encoder_out, labels, label_lengths):
        """Training forward pass with teacher forcing."""
        memory_bank = self.enc_transform(encoder_out)
        tgt = labels.unsqueeze(-1)  # (b, t, 1)
        tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
        dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask)
        logits = self.output_layer(dec_out)  # (b, t, h) -> (b, t, v)
        # Shift so that logits at position i predict the label at position i+1.
        return logits[:, :-1], labels[:, 1:], dec_out

    def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256):
        """Inference: greedy search when beam_size == 1, beam search otherwise."""
        memory_bank = self.enc_transform(encoder_out)
        batch_size = memory_bank.size(0)
        if beam_size == 1:
            decode_strategy = GreedySearch(
                sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length,
                max_length=max_length, pad=self.tokenizer.PAD_ID, bos=self.tokenizer.SOS_ID,
                eos=self.tokenizer.EOS_ID, return_attention=False, return_hidden=True)
        else:
            decode_strategy = BeamSearch(
                beam_size=beam_size, n_best=n_best, batch_size=batch_size, min_length=min_length,
                max_length=max_length, pad=self.tokenizer.PAD_ID, bos=self.tokenizer.SOS_ID,
                eos=self.tokenizer.EOS_ID, return_attention=False)

        # adapted from onmt.translate.translator
        results = {
            "predictions": None,
            "scores": None,
            "attention": None
        }

        # (2) prep decode_strategy. Possibly repeat src objects.
        _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank)

        # (3) Begin decoding step by step:
        for step in range(decode_strategy.max_length):
            tgt = decode_strategy.current_predictions.view(-1, 1, 1)
            tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
            dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank,
                                                 tgt_pad_mask=tgt_pad_mask, step=step)

            attn = dec_attn.get("std", None)

            dec_logits = self.output_layer(dec_out)  # [b, t, h] => [b, t, v]
            dec_logits = dec_logits.squeeze(1)
            log_probs = F.log_softmax(dec_logits, dim=-1)

            if self.tokenizer.output_constraint:
                output_mask = [self.tokenizer.get_output_mask(id) for id in tgt.view(-1).tolist()]
                output_mask = torch.tensor(output_mask, device=log_probs.device)
                log_probs.masked_fill_(output_mask, -10000)

            decode_strategy.advance(log_probs, attn, dec_out)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices
            if any_finished:
                # Reorder states.
                memory_bank = memory_bank.index_select(0, select_indices)
                self.map_state(lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = decode_strategy.scores
        results["predictions"] = decode_strategy.predictions
        results["attention"] = decode_strategy.attention
        results["hidden"] = decode_strategy.hidden

        return results["predictions"], results["scores"], results["hidden"]

    # adapted from onmt.decoders.transformer
    def map_state(self, fn):
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.decoder.state["cache"] is not None:
            _recursive_map(self.decoder.state["cache"])
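
# Usage sketch (an illustrative addition, not original code): greedy decoding is
# simply the beam_size=1 path of TransformerDecoderAR.decode. ``decoder_ar`` and
# ``encoder_out`` are assumed to be supplied by the caller.
def _demo_greedy_decode(decoder_ar, encoder_out):
    """``decoder_ar`` is a TransformerDecoderAR; ``encoder_out`` is an encoder
    feature tensor, e.g. (B, H, W, C) from Encoder."""
    with torch.no_grad():
        predictions, scores, hidden = decoder_ar.decode(
            encoder_out, beam_size=1, n_best=1, max_length=256)
    # predictions[i] holds the n_best token-id sequences for sample i;
    # scores[i] holds the matching sequence log-probabilities.
    return predictions, scores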
class Decoder(nn.Module):
    """Wrapper holding one TransformerDecoderAR per output format."""

    def __init__(self, args, tokenizer):
        super().__init__()
        self.args = args
        self.formats = args.formats
        self.tokenizer = tokenizer
        decoder = {}
        for format_ in args.formats:
            decoder[format_] = TransformerDecoderAR(args, tokenizer[format_])
        self.decoder = nn.ModuleDict(decoder)

    def forward(self, encoder_out, hiddens, refs):
        results = {}
        for format_ in self.formats:
            labels, label_lengths = refs[format_]
            results[format_] = self.decoder[format_](encoder_out, labels, label_lengths)
        return results

    def decode(self, encoder_out, hiddens, refs=None, beam_size=1, n_best=1):
        results = {}
        predictions = {}
        beam_predictions = {}
        for format_ in self.formats:
            max_len = self.tokenizer[format_].max_len
            results[format_] = self.decoder[format_].decode(encoder_out, beam_size, n_best, max_length=max_len)
            outputs, scores, *_ = results[format_]
            beam_preds = [[self.tokenizer[format_].sequence_to_data(x.tolist()) for x in pred]
                          for pred in outputs]
            beam_predictions[format_] = (beam_preds, scores)
            # Top-1 candidate per sample.
            predictions[format_] = [preds[0] for preds in beam_preds]
        return predictions, beam_predictions
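
# End-to-end sketch (an illustrative addition, not original code): wiring Encoder
# and Decoder together. ``args`` must carry the fields used above (encoder,
# use_checkpoint, formats, dec_* sizes); ``tokenizer`` is a dict keyed by format;
# ``images`` and ``refs`` are assumed to come from the data pipeline.
def _demo_end_to_end(args, tokenizer, images, refs):
    encoder = Encoder(args, pretrained=False)
    args.encoder_dim = encoder.n_features  # keep decoder input width in sync
    decoder = Decoder(args, tokenizer)
    features, hiddens = encoder(images)
    # Teacher-forced training pass: per-format (logits, shifted labels, dec_out).
    results = decoder(features, hiddens, refs)
    # Inference pass: per-format decoded data plus beam candidates with scores.
    predictions, beam_predictions = decoder.decode(features, hiddens, beam_size=1)
    return results, predictions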