CYF200127's picture
Upload 116 files
5e9bd47 verified
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Pix2Seq model and criterion classes.
"""
import torch
import torch.nn.functional as F
from torch import nn
from .misc import nested_tensor_from_tensor_list
from .backbone import build_backbone
from .transformer import build_transformer
class Pix2Seq(nn.Module):
    """Pix2Seq module that performs object detection.

    The backbone extracts image features, a 1x1 convolution followed by group
    normalisation projects the last feature level to the transformer width,
    and the transformer either processes a provided target sequence or
    generates one from scratch.
    """

    def __init__(self, backbone, transformer):
        """Initialize the model.

        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py.
                Its ``num_channels`` attribute sets the projection input width.
            transformer: torch module of the transformer architecture. See
                transformer.py. Its ``d_model`` attribute sets the projection
                output width.
        """
        super().__init__()
        # Sub-modules are registered in the same order as before so that the
        # parameter / state_dict ordering is unchanged.
        self.transformer = transformer
        hidden_dim = transformer.d_model
        # GroupNorm is built with 32 groups, so hidden_dim has to be a
        # multiple of 32.
        self.input_proj = nn.Sequential(
            nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=(1, 1)),
            nn.GroupNorm(32, hidden_dim),
        )
        self.backbone = backbone

    def forward(self, image_tensor, targets=None, max_len=500):
        """Run the model on a batch of images.

        Parameters:
            image_tensor: the input batch. A ``list`` or a plain
                ``torch.Tensor`` is first converted with
                ``nested_tensor_from_tensor_list``; anything else is handed to
                the backbone as is. Per the original documentation this is a
                NestedTensor holding batched images
                [batch_size x 3 x H x W] and a binary mask
                [batch_size x H x W] containing 1 on padded pixels.
            targets: optional pair ``(input_seq, input_len)``. Only
                ``input_seq`` is used here; its first position along dim 1 is
                dropped before it is passed to the transformer.
            max_len: forwarded to the transformer as ``max_len`` only when
                ``targets`` is None.

        Returns:
            When ``targets`` is given: the transformer output with the last
            position along dim 1 removed (documented upstream as logits of
            shape [batch_size, num_sequence, num_vocal] -- TODO confirm
            against transformer.py).
            When ``targets`` is None: the tuple
            ``(output_seqs, output_scores)`` returned by the transformer.
        """
        if isinstance(image_tensor, (list, torch.Tensor)):
            image_tensor = nested_tensor_from_tensor_list(image_tensor)

        features, pos = self.backbone(image_tensor)
        # Only the last feature level (and its positional encoding) is used.
        src, mask = features[-1].decompose()
        assert mask is not None, "backbone features must carry a mask"

        # The backbone's padding mask is deliberately replaced by an all-False
        # boolean mask of the same shape and device, so no position is masked.
        mask = torch.zeros_like(mask, dtype=torch.bool)
        src = self.input_proj(src)

        if targets is None:
            output_seqs, output_scores = self.transformer(
                src, None, mask, pos[-1], max_len=max_len)
            return output_seqs, output_scores

        # targets is still required to be a two-element pair; the second
        # element (sequence lengths) is not needed in this method.
        input_seq, _ = targets
        output_logits = self.transformer(src, input_seq[:, 1:], mask, pos[-1])
        return output_logits[:, :-1]
def build_pix2seq_model(args, tokenizer):
    """Assemble a Pix2Seq model and optionally restore it from a checkpoint.

    Parameters:
        args: configuration object passed to ``build_backbone`` and
            ``build_transformer``. Its ``pix2seq_ckpt`` attribute is either
            None or the location of a checkpoint to load.
        tokenizer: passed through to ``build_transformer``.

    Returns:
        The constructed ``Pix2Seq`` module. If ``args.pix2seq_ckpt`` is not
        None, the weights stored under the checkpoint's ``'model'`` key have
        been loaded into it.
    """
    # Backbone first, then transformer -- same construction order as before.
    model = Pix2Seq(build_backbone(args), build_transformer(args, tokenizer))

    if args.pix2seq_ckpt is None:
        return model

    # NOTE(review): torch.load unpickles the file; only point pix2seq_ckpt at
    # checkpoints from a trusted source.
    state = torch.load(args.pix2seq_ckpt, map_location='cpu')
    model.load_state_dict(state['model'])
    return model