""" |
|
Pix2Seq model and criterion classes. |
|
""" |
|
import torch
import torch.nn.functional as F
from torch import nn

from .misc import nested_tensor_from_tensor_list
from .backbone import build_backbone
from .transformer import build_transformer
|
class Pix2Seq(nn.Module):
    """ This is the Pix2Seq module that performs object detection """

    def __init__(self, backbone, transformer):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
        """
        super().__init__()
        self.transformer = transformer
        hidden_dim = transformer.d_model
        # 1x1 convolution + GroupNorm projecting backbone feature channels
        # down to the transformer's hidden dimension.
        self.input_proj = nn.Sequential(
            nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=(1, 1)),
            nn.GroupNorm(32, hidden_dim))
        self.backbone = backbone
|
    def forward(self, image_tensor, targets=None, max_len=500):
        """
        image_tensor:
            The forward expects a NestedTensor, which consists of:
            - image_tensor.tensors: batched images, of shape [batch_size x 3 x H x W]
            - image_tensor.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

        In training mode (targets is not None), it returns:
            - output_logits: the classification logits over the vocabulary for each
              sequence position. Shape = [batch_size, num_sequence, num_vocal]
        In inference mode (targets is None), it returns:
            - output_seqs, output_scores: the decoded token sequences (up to
              max_len tokens) and their per-token scores.
        """
        if isinstance(image_tensor, (list, torch.Tensor)):
            image_tensor = nested_tensor_from_tensor_list(image_tensor)
        features, pos = self.backbone(image_tensor)

        src, mask = features[-1].decompose()
        assert mask is not None
        # Discard the padding mask: the transformer attends to every spatial
        # position, padded or not.
        mask = torch.zeros_like(mask).bool()

        src = self.input_proj(src)
        if targets is not None:
            input_seq, input_len = targets
            # Teacher forcing: the decoder consumes the target sequence without
            # its first token, and the final logit is trimmed so predictions
            # align with the remaining targets.
            output_logits = self.transformer(src, input_seq[:, 1:], mask, pos[-1])
            return output_logits[:, :-1]
        else:
            output_seqs, output_scores = self.transformer(src, None, mask, pos[-1], max_len=max_len)
            return output_seqs, output_scores
|
|
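# Hedged usage sketch for Pix2Seq.forward (illustrative only; `model`,
# `images`, `input_seq`, and `input_len` are hypothetical stand-ins for
# objects produced by the surrounding training pipeline, not names defined
# in this repo).
def _forward_modes_example(model, images, input_seq, input_len):
    # Training: pass targets as an (input_seq, input_len) tuple to get
    # vocabulary logits of shape [batch_size, num_sequence, num_vocal].
    logits = model(images, targets=(input_seq, input_len))
    # Inference: omit targets to autoregressively decode up to max_len tokens.
    seqs, scores = model(images, max_len=500)
    return logits, seqs, scores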
|
|
|
def build_pix2seq_model(args, tokenizer):
    backbone = build_backbone(args)
    transformer = build_transformer(args, tokenizer)

    model = Pix2Seq(backbone, transformer)

    # Optionally warm-start from a pre-trained Pix2Seq checkpoint.
    if args.pix2seq_ckpt is not None:
        checkpoint = torch.load(args.pix2seq_ckpt, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

    return model
|
|
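# Hedged build/inference sketch (illustrative only). Everything below is an
# assumption: the real `args` carries the backbone/transformer options
# consumed by build_backbone/build_transformer, with field names defined by
# the project's argument parser, and the real tokenizer comes from the caller.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(pix2seq_ckpt=None)  # plus backbone/transformer options
    tokenizer = None  # stand-in for the tokenizer built by the caller

    model = build_pix2seq_model(args, tokenizer)
    model.eval()

    # forward() wraps a list of 3xHxW images into a NestedTensor internally.
    images = [torch.rand(3, 384, 384)]
    with torch.no_grad():
        output_seqs, output_scores = model(images, max_len=500)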