|
import torch |
|
from torch import nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision.models import resnet50 |
|
from torchvision import transforms |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
from transformers import BertTokenizer, BertModel |
|
import os |
|
import json |
|
import numpy as np |
|
from collections import defaultdict |
|
import random |
|
from tqdm.notebook import tqdm |
|
from torchvision import models |
|
from torch.nn.utils.rnn import pad_sequence |
|
import matplotlib.patches as patches |
|
|
|
import math
import time
import requests
import nltk

import cv2
import colorsys
from numpy import asarray
|
|
|
|
|
from transformers import GPT2LMHeadModel, GPT2Config |
|
|
|
from scipy.optimize import linear_sum_assignment |
|
|
|
import sys |
|
sys.path.append("../src") |
|
|
|
from utils import * |
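
# `utils` (in ../src) is expected to provide the DETR-style helpers used below,
# e.g. NestedTensor, nested_tensor_from_tensor_list, is_main_process and box_ops.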
|
|
|
NUM_QUERIES = 40 |
|
feature_size = 256 |
|
token_size = 256 |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
|
|
""" |
|
Various positional encodings for the transformer. |
|
""" |
|
|
|
class PositionEmbeddingSine(nn.Module): |
|
""" |
|
This is a more standard version of the position embedding, very similar to the one |
|
used by the Attention is all you need paper, generalized to work on images. |
|
""" |
|
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): |
|
super().__init__() |
|
self.num_pos_feats = num_pos_feats |
|
self.temperature = temperature |
|
self.normalize = normalize |
|
if scale is not None and normalize is False: |
|
raise ValueError("normalize should be True if scale is passed") |
|
if scale is None: |
|
scale = 2 * math.pi |
|
self.scale = scale |
|
|
|
def forward(self, tensor_list: NestedTensor): |
|
x = tensor_list.tensors |
|
mask = tensor_list.mask |
|
assert mask is not None |
|
not_mask = ~mask |
|
y_embed = not_mask.cumsum(1, dtype=torch.float32) |
|
x_embed = not_mask.cumsum(2, dtype=torch.float32) |
|
if self.normalize: |
|
eps = 1e-6 |
|
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale |
|
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale |
|
|
|
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) |
|
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) |
|
|
|
pos_x = x_embed[:, :, :, None] / dim_t |
|
pos_y = y_embed[:, :, :, None] / dim_t |
|
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) |
|
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) |
|
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) |
|
return pos |
|
|
|
|
|
class PositionEmbeddingLearned(nn.Module): |
|
""" |
|
Absolute pos embedding, learned. |
|
""" |
|
def __init__(self, num_pos_feats=256): |
|
super().__init__() |
|
self.row_embed = nn.Embedding(50, num_pos_feats) |
|
self.col_embed = nn.Embedding(50, num_pos_feats) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
nn.init.uniform_(self.row_embed.weight) |
|
nn.init.uniform_(self.col_embed.weight) |
|
|
|
def forward(self, tensor_list: NestedTensor): |
|
x = tensor_list.tensors |
|
h, w = x.shape[-2:] |
|
i = torch.arange(w, device=x.device) |
|
j = torch.arange(h, device=x.device) |
|
x_emb = self.col_embed(i) |
|
y_emb = self.row_embed(j) |
|
pos = torch.cat([ |
|
x_emb.unsqueeze(0).repeat(h, 1, 1), |
|
y_emb.unsqueeze(1).repeat(1, w, 1), |
|
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) |
|
return pos |
|
|
|
|
|
def build_position_encoding(args): |
|
N_steps = args.hidden_dim // 2 |
|
if args.position_embedding in ('v2', 'sine'): |
|
|
|
position_embedding = PositionEmbeddingSine(N_steps, normalize=True) |
|
elif args.position_embedding in ('v3', 'learned'): |
|
position_embedding = PositionEmbeddingLearned(N_steps) |
|
else: |
|
raise ValueError(f"not supported {args.position_embedding}") |
|
|
|
return position_embedding |
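
# Quick shape check (a sketch, assuming `NestedTensor` from `utils` follows the DETR convention
# NestedTensor(tensors, mask) with mask == True at padded pixels): the sine embedding turns a
# [B, C, H, W] feature map into a [B, 2 * num_pos_feats, H, W] positional tensor.
_demo_feats = torch.randn(2, 256, 16, 20)
_demo_mask = torch.zeros(2, 16, 20, dtype=torch.bool)    # no padded pixels in this toy batch
_demo_pos = PositionEmbeddingSine(num_pos_feats=128, normalize=True)(NestedTensor(_demo_feats, _demo_mask))
print(_demo_pos.shape)                                   # torch.Size([2, 256, 16, 20])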
|
|
|
from collections import OrderedDict |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torchvision |
|
from torch import nn |
|
from torchvision.models._utils import IntermediateLayerGetter |
|
from typing import Dict, List |
|
|
|
|
|
class FrozenBatchNorm2d(torch.nn.Module): |
|
""" |
|
BatchNorm2d where the batch statistics and the affine parameters are fixed. |
|
|
|
Copy-paste from torchvision.misc.ops with added eps before rqsrt, |
|
without which any other models than torchvision.models.resnet[18,34,50,101] |
|
produce nans. |
|
""" |
|
|
|
def __init__(self, n): |
|
super(FrozenBatchNorm2d, self).__init__() |
|
self.register_buffer("weight", torch.ones(n)) |
|
self.register_buffer("bias", torch.zeros(n)) |
|
self.register_buffer("running_mean", torch.zeros(n)) |
|
self.register_buffer("running_var", torch.ones(n)) |
|
|
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, |
|
missing_keys, unexpected_keys, error_msgs): |
|
num_batches_tracked_key = prefix + 'num_batches_tracked' |
|
if num_batches_tracked_key in state_dict: |
|
del state_dict[num_batches_tracked_key] |
|
|
|
super(FrozenBatchNorm2d, self)._load_from_state_dict( |
|
state_dict, prefix, local_metadata, strict, |
|
missing_keys, unexpected_keys, error_msgs) |
|
|
|
def forward(self, x): |
|
|
|
|
|
w = self.weight.reshape(1, -1, 1, 1) |
|
b = self.bias.reshape(1, -1, 1, 1) |
|
rv = self.running_var.reshape(1, -1, 1, 1) |
|
rm = self.running_mean.reshape(1, -1, 1, 1) |
|
eps = 1e-5 |
|
scale = w * (rv + eps).rsqrt() |
|
bias = b - rm * scale |
|
return x * scale + bias |
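
# A small equivalence check (sketch): with identical affine parameters and running statistics,
# FrozenBatchNorm2d matches nn.BatchNorm2d in eval mode (both use eps = 1e-5 here).
_frozen = FrozenBatchNorm2d(8)
_bn = nn.BatchNorm2d(8, eps=1e-5).eval()
with torch.no_grad():
    _bn.weight.copy_(_frozen.weight)
    _bn.bias.copy_(_frozen.bias)
    _bn.running_mean.copy_(_frozen.running_mean)
    _bn.running_var.copy_(_frozen.running_var)
_x = torch.randn(2, 8, 4, 4)
print(torch.allclose(_frozen(_x), _bn(_x), atol=1e-6))   # True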
|
|
|
|
|
class BackboneBase(nn.Module): |
|
|
|
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): |
|
super().__init__() |
|
for name, parameter in backbone.named_parameters(): |
|
if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: |
|
parameter.requires_grad_(False) |
|
if return_interm_layers: |
|
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} |
|
else: |
|
return_layers = {'layer4': "0"} |
|
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) |
|
self.num_channels = num_channels |
|
|
|
def forward(self, tensor_list: NestedTensor): |
|
xs = self.body(tensor_list.tensors) |
|
out: Dict[str, NestedTensor] = {} |
|
for name, x in xs.items(): |
|
m = tensor_list.mask |
|
assert m is not None |
|
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] |
|
out[name] = NestedTensor(x, mask) |
|
return out |
|
|
|
'''
The line `mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]` downsamples the padding mask
so that it matches the spatial resolution of the backbone's output features.

`m` is the padding mask of the batched input images: a boolean tensor of shape [batch, H, W] that is True at padded
pixels and False at valid ones. `m[None].float()` adds a leading dimension (F.interpolate expects a 4-D N x C x H x W
input) and casts to float, F.interpolate() resizes it to the height and width of the feature map, `.to(torch.bool)`
converts it back to a boolean tensor, and `[0]` drops the dimension that was added.

The resulting mask tells the later stages (positional encoding and transformer attention) which positions of the
ResNet-50 feature map correspond to real image content and which correspond to padding, so the padded positions
can be ignored.
'''
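
# Tiny illustration of the downsampling described above: a [B, H, W] boolean padding mask is resized
# with F.interpolate to the feature-map resolution (the sizes below are made up for the example).
_m = torch.zeros(1, 64, 64, dtype=torch.bool)
_m[:, :, 48:] = True                                     # the right-hand strip of the image is padding
_small = F.interpolate(_m[None].float(), size=(8, 8)).to(torch.bool)[0]
print(_small.shape)                                      # torch.Size([1, 8, 8])
print(_small[0, 0])                                      # the last columns stay True (padded)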
|
|
|
class Backbone(BackboneBase): |
|
"""ResNet backbone with frozen BatchNorm.""" |
|
def __init__(self, name: str, |
|
train_backbone: bool, |
|
return_interm_layers: bool, |
|
dilation: bool): |
|
backbone = getattr(torchvision.models, name)( |
|
replace_stride_with_dilation=[False, False, dilation], |
|
pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) |
|
|
|
num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 |
|
super().__init__(backbone, train_backbone, num_channels, return_interm_layers) |
|
|
|
|
|
class Joiner(nn.Sequential): |
|
def __init__(self, backbone, position_embedding): |
|
super().__init__(backbone, position_embedding) |
|
|
|
def forward(self, tensor_list: NestedTensor): |
|
xs = self[0](tensor_list) |
|
out: List[NestedTensor] = [] |
|
pos = [] |
|
for name, x in xs.items(): |
|
out.append(x) |
|
|
|
pos.append(self[1](x).to(x.tensors.dtype)) |
|
|
|
return out, pos |
|
|
|
|
|
def build_backbone(args): |
|
position_embedding = build_position_encoding(args) |
|
train_backbone = args.lr_backbone > 0 |
|
return_interm_layers = args.masks |
|
backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) |
|
model = Joiner(backbone, position_embedding) |
|
model.num_channels = backbone.num_channels |
|
return model |
|
|
|
def get_sinusoid_encoding_table(n_position, d_hid): |
|
def cal_angle(position, hid_idx): |
|
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) |
|
|
|
def get_posi_angle_vec(position): |
|
return [cal_angle(position, hid_j) for hid_j in range(d_hid)] |
|
|
|
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) |
|
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) |
|
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) |
|
return torch.FloatTensor(sinusoid_table) |
|
|
|
class PostProcess(nn.Module): |
|
""" This module converts the model's output into the format expected by the coco api""" |
|
@torch.no_grad() |
|
def forward(self, outputs, target_sizes): |
|
""" Perform the computation |
|
Parameters: |
|
outputs: raw outputs of the model |
|
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image in the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augmentation, but before padding
|
""" |
|
out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] |
|
|
|
assert len(out_logits) == len(target_sizes) |
|
assert target_sizes.shape[1] == 2 |
|
|
|
prob = F.softmax(out_logits, -1) |
|
scores, labels = prob[..., :-1].max(-1) |
|
|
|
|
|
boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) |
|
|
|
img_h, img_w = target_sizes.unbind(1) |
|
scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) |
|
boxes = boxes * scale_fct[:, None, :] |
|
|
|
results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] |
|
|
|
return results |
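
# For reference, the post-processing above relies on box_ops (expected from `utils`); the conversion itself
# is just (cx, cy, w, h) in [0, 1] -> absolute (x0, y0, x1, y1) pixel coordinates, e.g. for a 640x480 image:
_cx, _cy, _w, _h = torch.tensor([[0.5, 0.5, 0.2, 0.4]]).unbind(-1)
_xyxy = torch.stack([_cx - 0.5 * _w, _cy - 0.5 * _h, _cx + 0.5 * _w, _cy + 0.5 * _h], dim=-1)
print(_xyxy * torch.tensor([640, 480, 640, 480]))        # tensor([[256., 144., 384., 336.]])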
|
|
|
|
|
class MLP(nn.Module): |
|
""" Very simple multi-layer perceptron (also called FFN)""" |
|
|
|
def __init__(self, input_dim, hidden_dim, output_dim, num_layers): |
|
super().__init__() |
|
self.num_layers = num_layers |
|
h = [hidden_dim] * (num_layers - 1) |
|
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) |
|
|
|
def forward(self, x): |
|
for i, layer in enumerate(self.layers): |
|
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) |
|
return x |
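
# Example: this 3-layer MLP is used further below as the box head, mapping each decoder hidden state
# to 4 box coordinates.
_box_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
print(_box_head(torch.randn(2, 40, 256)).shape)          # torch.Size([2, 40, 4])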
|
|
|
|
|
def build(args): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
num_classes = 20 if args.dataset_file != 'coco' else 91 |
|
if args.dataset_file == "coco_panoptic": |
|
|
|
|
|
num_classes = 250 |
|
device = torch.device(args.device) |
|
|
|
backbone = build_backbone(args) |
|
|
|
transformer = build_transformer(args) |
|
|
|
model = DETR( |
|
backbone, |
|
transformer, |
|
num_classes=num_classes, |
|
num_queries=args.num_queries, |
|
aux_loss=args.aux_loss, |
|
) |
|
if args.masks: |
|
model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None)) |
|
matcher = build_matcher(args) |
|
weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef} |
|
weight_dict['loss_giou'] = args.giou_loss_coef |
|
if args.masks: |
|
weight_dict["loss_mask"] = args.mask_loss_coef |
|
weight_dict["loss_dice"] = args.dice_loss_coef |
|
|
|
if args.aux_loss: |
|
aux_weight_dict = {} |
|
for i in range(args.dec_layers - 1): |
|
aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) |
|
weight_dict.update(aux_weight_dict) |
|
|
|
losses = ['labels', 'boxes', 'cardinality'] |
|
if args.masks: |
|
losses += ["masks"] |
|
criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, |
|
eos_coef=args.eos_coef, losses=losses) |
|
criterion.to(device) |
|
postprocessors = {'bbox': PostProcess()} |
|
if args.masks: |
|
postprocessors['segm'] = PostProcessSegm() |
|
if args.dataset_file == "coco_panoptic": |
|
is_thing_map = {i: i <= 90 for i in range(201)} |
|
postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) |
|
|
|
return model, criterion, postprocessors |
|
|
|
class Parameters: |
|
def __init__(self): |
|
self.lr = 1e-4 |
|
self.lr_backbone = 1e-5 |
|
self.batch_size = 2 |
|
self.weight_decay = 1e-4 |
|
self.epochs = 300 |
|
self.lr_drop = 200 |
|
self.clip_max_norm = 0.1 |
|
|
|
args = Parameters() |
|
|
|
args.lr=1e-4 |
|
args.lr_backbone=1e-5 |
|
args.batch_size=32 |
|
args.weight_decay=1e-4 |
|
args.epochs=300 |
|
args.lr_drop=200 |
|
args.clip_max_norm=0.1 |
|
|
|
|
|
args.frozen_weights=False |
|
|
|
|
|
args.backbone='resnet50' |
|
args.dilation=False |
|
args.position_embedding='sine' |
|
|
|
|
|
args.enc_layers=6 |
|
args.dec_layers=6 |
|
args.dim_feedforward=2048 |
|
args.hidden_dim=256 |
|
args.dropout=0.1 |
|
args.nheads=8 |
|
args.num_queries=40 |
|
args.pre_norm=True |
|
|
|
|
|
args.masks=False |
|
|
|
|
|
""" |
|
LLMEyeCap Transformer class. |
|
|
|
A DETR (FaceBook) Copy-paste from torch.nn.Transformer with modifications: |
|
* positional encodings are passed in MHattention |
|
* extra LN at the end of encoder is removed |
|
* decoder returns a stack of activations from all decoding layers |
|
|
|
""" |
|
import copy
from typing import Optional, List

from torch import Tensor  # used in the Optional[Tensor] annotations below
|
|
|
|
|
class Transformer(nn.Module): |
|
|
|
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, |
|
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, |
|
activation="relu", normalize_before=False, |
|
return_intermediate_dec=False): |
|
super().__init__() |
|
|
|
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, |
|
dropout, activation, normalize_before) |
|
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None |
|
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) |
|
|
|
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, |
|
dropout, activation, normalize_before) |
|
decoder_norm = nn.LayerNorm(d_model) |
|
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, |
|
return_intermediate=return_intermediate_dec) |
|
|
|
self._reset_parameters() |
|
|
|
self.d_model = d_model |
|
self.nhead = nhead |
|
|
|
def _reset_parameters(self): |
|
for p in self.parameters(): |
|
if p.dim() > 1: |
|
nn.init.xavier_uniform_(p) |
|
|
|
def forward(self, src, mask, query_embed, pos_embed): |
|
|
|
bs, c, h, w = src.shape |
|
src = src.flatten(2).permute(2, 0, 1) |
|
pos_embed = pos_embed.flatten(2).permute(2, 0, 1) |
|
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) |
|
mask = mask.flatten(1) |
|
|
|
tgt = torch.zeros_like(query_embed) |
|
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) |
|
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, |
|
pos=pos_embed, query_pos=query_embed) |
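        # hs is [num_decoder_layers, num_queries, bs, d_model] when return_intermediate_dec=True
        # (a single stacked layer otherwise); it is transposed below to [layers, bs, num_queries, d_model].
        # memory is the encoder output, reshaped back to [bs, d_model, h, w].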
|
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) |
|
|
|
|
|
class TransformerEncoder(nn.Module): |
|
|
|
def __init__(self, encoder_layer, num_layers, norm=None): |
|
super().__init__() |
|
self.layers = _get_clones(encoder_layer, num_layers) |
|
self.num_layers = num_layers |
|
self.norm = norm |
|
|
|
def forward(self, src, |
|
mask: Optional[Tensor] = None, |
|
src_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None): |
|
output = src |
|
|
|
for layer in self.layers: |
|
output = layer(output, src_mask=mask, |
|
src_key_padding_mask=src_key_padding_mask, pos=pos) |
|
|
|
if self.norm is not None: |
|
output = self.norm(output) |
|
|
|
return output |
|
|
|
|
|
class TransformerDecoder(nn.Module): |
|
|
|
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): |
|
super().__init__() |
|
self.layers = _get_clones(decoder_layer, num_layers) |
|
self.num_layers = num_layers |
|
self.norm = norm |
|
self.return_intermediate = return_intermediate |
|
|
|
def forward(self, tgt, memory, |
|
tgt_mask: Optional[Tensor] = None, |
|
memory_mask: Optional[Tensor] = None, |
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None, |
|
query_pos: Optional[Tensor] = None): |
|
output = tgt |
|
|
|
intermediate = [] |
|
|
|
for layer in self.layers: |
|
output = layer(output, memory, tgt_mask=tgt_mask, |
|
memory_mask=memory_mask, |
|
tgt_key_padding_mask=tgt_key_padding_mask, |
|
memory_key_padding_mask=memory_key_padding_mask, |
|
pos=pos, query_pos=query_pos) |
|
if self.return_intermediate: |
|
intermediate.append(self.norm(output)) |
|
|
|
if self.norm is not None: |
|
output = self.norm(output) |
|
if self.return_intermediate: |
|
intermediate.pop() |
|
intermediate.append(output) |
|
|
|
if self.return_intermediate: |
|
return torch.stack(intermediate) |
|
|
|
return output.unsqueeze(0) |
|
|
|
|
|
class TransformerEncoderLayer(nn.Module): |
|
|
|
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, |
|
activation="relu", normalize_before=False): |
|
super().__init__() |
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
|
|
|
self.linear1 = nn.Linear(d_model, dim_feedforward) |
|
self.dropout = nn.Dropout(dropout) |
|
self.linear2 = nn.Linear(dim_feedforward, d_model) |
|
|
|
self.norm1 = nn.LayerNorm(d_model) |
|
self.norm2 = nn.LayerNorm(d_model) |
|
self.dropout1 = nn.Dropout(dropout) |
|
self.dropout2 = nn.Dropout(dropout) |
|
|
|
self.activation = _get_activation_fn(activation) |
|
self.normalize_before = normalize_before |
|
|
|
def with_pos_embed(self, tensor, pos: Optional[Tensor]): |
|
return tensor if pos is None else tensor + pos |
|
|
|
def forward_post(self, |
|
src, |
|
src_mask: Optional[Tensor] = None, |
|
src_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None): |
|
q = k = self.with_pos_embed(src, pos) |
|
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, |
|
key_padding_mask=src_key_padding_mask)[0] |
|
src = src + self.dropout1(src2) |
|
src = self.norm1(src) |
|
src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) |
|
src = src + self.dropout2(src2) |
|
src = self.norm2(src) |
|
return src |
|
|
|
def forward_pre(self, src, |
|
src_mask: Optional[Tensor] = None, |
|
src_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None): |
|
src2 = self.norm1(src) |
|
q = k = self.with_pos_embed(src2, pos) |
|
src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, |
|
key_padding_mask=src_key_padding_mask)[0] |
|
src = src + self.dropout1(src2) |
|
src2 = self.norm2(src) |
|
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) |
|
src = src + self.dropout2(src2) |
|
return src |
|
|
|
def forward(self, src, |
|
src_mask: Optional[Tensor] = None, |
|
src_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None): |
|
if self.normalize_before: |
|
return self.forward_pre(src, src_mask, src_key_padding_mask, pos) |
|
return self.forward_post(src, src_mask, src_key_padding_mask, pos) |
|
|
|
|
|
class TransformerDecoderLayer(nn.Module): |
|
|
|
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, |
|
activation="relu", normalize_before=False): |
|
super().__init__() |
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
|
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
|
|
|
self.linear1 = nn.Linear(d_model, dim_feedforward) |
|
self.dropout = nn.Dropout(dropout) |
|
self.linear2 = nn.Linear(dim_feedforward, d_model) |
|
|
|
self.norm1 = nn.LayerNorm(d_model) |
|
self.norm2 = nn.LayerNorm(d_model) |
|
self.norm3 = nn.LayerNorm(d_model) |
|
self.dropout1 = nn.Dropout(dropout) |
|
self.dropout2 = nn.Dropout(dropout) |
|
self.dropout3 = nn.Dropout(dropout) |
|
|
|
self.activation = _get_activation_fn(activation) |
|
self.normalize_before = normalize_before |
|
|
|
def with_pos_embed(self, tensor, pos: Optional[Tensor]): |
|
return tensor if pos is None else tensor + pos |
|
|
|
def forward_post(self, tgt, memory, |
|
tgt_mask: Optional[Tensor] = None, |
|
memory_mask: Optional[Tensor] = None, |
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None, |
|
query_pos: Optional[Tensor] = None): |
|
q = k = self.with_pos_embed(tgt, query_pos) |
|
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, |
|
key_padding_mask=tgt_key_padding_mask)[0] |
|
tgt = tgt + self.dropout1(tgt2) |
|
tgt = self.norm1(tgt) |
|
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), |
|
key=self.with_pos_embed(memory, pos), |
|
value=memory, attn_mask=memory_mask, |
|
key_padding_mask=memory_key_padding_mask)[0] |
|
tgt = tgt + self.dropout2(tgt2) |
|
tgt = self.norm2(tgt) |
|
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) |
|
tgt = tgt + self.dropout3(tgt2) |
|
tgt = self.norm3(tgt) |
|
return tgt |
|
|
|
def forward_pre(self, tgt, memory, |
|
tgt_mask: Optional[Tensor] = None, |
|
memory_mask: Optional[Tensor] = None, |
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None, |
|
query_pos: Optional[Tensor] = None): |
|
tgt2 = self.norm1(tgt) |
|
q = k = self.with_pos_embed(tgt2, query_pos) |
|
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, |
|
key_padding_mask=tgt_key_padding_mask)[0] |
|
tgt = tgt + self.dropout1(tgt2) |
|
tgt2 = self.norm2(tgt) |
|
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), |
|
key=self.with_pos_embed(memory, pos), |
|
value=memory, attn_mask=memory_mask, |
|
key_padding_mask=memory_key_padding_mask)[0] |
|
tgt = tgt + self.dropout2(tgt2) |
|
tgt2 = self.norm3(tgt) |
|
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) |
|
tgt = tgt + self.dropout3(tgt2) |
|
return tgt |
|
|
|
def forward(self, tgt, memory, |
|
tgt_mask: Optional[Tensor] = None, |
|
memory_mask: Optional[Tensor] = None, |
|
tgt_key_padding_mask: Optional[Tensor] = None, |
|
memory_key_padding_mask: Optional[Tensor] = None, |
|
pos: Optional[Tensor] = None, |
|
query_pos: Optional[Tensor] = None): |
|
if self.normalize_before: |
|
return self.forward_pre(tgt, memory, tgt_mask, memory_mask, |
|
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) |
|
return self.forward_post(tgt, memory, tgt_mask, memory_mask, |
|
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) |
|
|
|
|
|
def _get_clones(module, N): |
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) |
|
|
|
|
|
def build_transformer(args): |
|
return Transformer( |
|
d_model=args.hidden_dim, |
|
dropout=args.dropout, |
|
nhead=args.nheads, |
|
dim_feedforward=args.dim_feedforward, |
|
num_encoder_layers=args.enc_layers, |
|
num_decoder_layers=args.dec_layers, |
|
normalize_before=args.pre_norm, |
|
return_intermediate_dec=True, |
|
) |
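
# Shape sketch with toy sizes (not the training configuration): with return_intermediate_dec=True the
# decoder returns one activation tensor per decoder layer, as described in the docstring above.
_tiny = Transformer(d_model=32, nhead=4, num_encoder_layers=1, num_decoder_layers=2,
                    dim_feedforward=64, return_intermediate_dec=True)
_src = torch.randn(2, 32, 5, 6)                          # [bs, d_model, H, W] projected backbone features
_pad = torch.zeros(2, 5, 6, dtype=torch.bool)            # no padded positions
_queries = torch.randn(10, 32)                           # [num_queries, d_model] query embeddings
_pos = torch.randn(2, 32, 5, 6)                          # positional encoding, same shape as _src
_hs, _memory = _tiny(_src, _pad, _queries, _pos)
print(_hs.shape)                                         # torch.Size([2, 2, 10, 32]) = [layers, bs, queries, d_model]
print(_memory.shape)                                     # torch.Size([2, 32, 5, 6])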
|
|
|
|
|
def _get_activation_fn(activation): |
|
"""Return an activation function given a string""" |
|
if activation == "relu": |
|
return F.relu |
|
if activation == "gelu": |
|
return F.gelu |
|
if activation == "glu": |
|
return F.glu |
|
raise RuntimeError(F"activation should be relu/gelu, not {activation}.") |
|
|
|
|
|
class LLMEyeCap(nn.Module): |
|
|
|
def __init__(self, backbone, transformer, num_queries, vocab_size,pad_token): |
|
|
|
super().__init__() |
|
self.num_queries = num_queries |
|
self.transformer = transformer |
|
self.hidden_dim = transformer.d_model |
|
|
|
self.caption_embed = nn.Linear(self.hidden_dim, vocab_size) |
|
self.bbox_embed = MLP(self.hidden_dim, self.hidden_dim, 4, 3) |
|
|
|
self.query_embed = nn.Embedding(num_queries, self.hidden_dim) |
|
self.input_proj = nn.Conv2d(backbone.num_channels, self.hidden_dim, kernel_size=1) |
|
self.backbone = backbone |
|
''' |
|
self.capdecoder = CaptioningDecoder(detr_decoder_dim=transformer.d_model, token_embedding_dim=transformer.d_model, |
|
vocab_size=vocab_size, num_queries=num_queries, num_layers=6) |
|
''' |
|
        self.capdecoder = CaptionDecoder(feature_size, token_size, vocab_size, num_queries, pad_token).to(device)
|
|
|
|
|
def forward(self, samples: NestedTensor, captions): |
|
|
|
if isinstance(samples, (list, torch.Tensor)): |
|
samples = nested_tensor_from_tensor_list(samples) |
|
|
|
features, pos = self.backbone(samples) |
|
src, mask = features[-1].decompose() |
|
assert mask is not None |
|
hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] |
|
outputs_coord = self.bbox_embed(hs).sigmoid() |
|
|
|
        outputs_captions = self.capdecoder(hs, captions)

        out = {'pred_logits': outputs_captions, 'pred_boxes': outputs_coord[-1]}
|
return out |
|
|
|
def generate_caption(self, image_path, tokenizer, max_length, pad_sos): |
|
|
|
image = Image.open(image_path).convert('RGB') |
|
transform = transforms.Compose([ |
|
transforms.Resize((256, 256)), |
|
transforms.ToTensor(), |
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
]) |
|
|
|
image = transform(image).unsqueeze(0).to(device) |
|
|
|
if isinstance(image, (list, torch.Tensor)): |
|
image = nested_tensor_from_tensor_list(image) |
|
|
|
with torch.no_grad(): |
|
features, pos = self.backbone(image) |
|
src, mask = features[-1].decompose() |
|
assert mask is not None |
|
|
|
hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] |
|
outputs_coord = self.bbox_embed(hs).sigmoid() |
|
|
|
            input_ids = torch.ones((1, self.num_queries, 1), dtype=torch.long, device=device)
|
input_ids.fill_(pad_sos) |
|
|
|
|
|
for i in range(max_length): |
|
outputs_captions = self.capdecoder(hs, input_ids) |
|
predicted_sequences = torch.argmax(outputs_captions, dim=-1) |
|
next_token = predicted_sequences[:, :, -1:] |
|
input_ids = torch.cat((input_ids, next_token), dim=-1) |
|
|
|
|
|
|
|
return outputs_coord[-1], input_ids |
|
|
|
class LLMEyeCapModel(nn.Module): |
|
def __init__(self, num_queries,vocab_size,pad_token): |
|
super(LLMEyeCapModel,self).__init__() |
|
self.num_queries = num_queries |
|
self.vocab_size=vocab_size |
|
self.backbone = build_backbone(args) |
|
self.transformer = build_transformer(args) |
|
|
|
self.model = LLMEyeCap( |
|
self.backbone, |
|
self.transformer, |
|
num_queries=self.num_queries, |
|
vocab_size=self.vocab_size, |
|
pad_token=pad_token |
|
) |
|
|
|
|
|
|
|
|
|
|
|
self.model.num_queries = self.num_queries |
|
|
|
def forward(self,images,captions): |
|
return self.model(images,captions) |
|
|
|
def generate_caption(self, image_path, tokenizer, max_length=20,pad_sos=0): |
|
return self.model.generate_caption(image_path, tokenizer, max_length,pad_sos) |
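
# Example wiring (a sketch; the BERT tokenizer and the image path are assumptions, not fixed by this code):
#   tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#   model = LLMEyeCapModel(num_queries=NUM_QUERIES, vocab_size=tokenizer.vocab_size,
#                          pad_token=tokenizer.pad_token_id).to(device)
#   boxes, token_ids = model.generate_caption("some_image.jpg", tokenizer, max_length=20,
#                                             pad_sos=tokenizer.cls_token_id)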
|
|
|
class CaptionDecoder(nn.Module): |
|
def __init__(self, detr_decoder_dim, token_embedding_dim, vocab_size, num_queries, pad_token, num_layers=6): |
|
super(CaptionDecoder, self).__init__() |
|
|
|
self.detr_decoder_dim = detr_decoder_dim |
|
self.token_embedding_dim = token_embedding_dim |
|
self.vocab_size = vocab_size |
|
self.num_queries = num_queries |
|
self.pad_token = pad_token |
|
|
|
|
|
self.token_embedding = nn.Embedding(vocab_size, token_embedding_dim) |
|
|
|
|
|
config = GPT2Config(vocab_size=vocab_size, n_embd=detr_decoder_dim + token_embedding_dim, n_head=8 ) |
|
self.gpt2 = GPT2LMHeadModel(config) |
|
|
|
self.target_projection = nn.Linear(token_embedding_dim, detr_decoder_dim + token_embedding_dim) |
|
|
|
def forward(self, detr_output, captions): |
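        # detr_output: stacked decoder activations from the DETR-style transformer,
        #              shape [num_decoder_layers, batch, num_queries, detr_decoder_dim]
        # captions:    token ids of shape [batch, num_queries, seq_len]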
|
|
|
|
|
|
|
attention_mask = (captions != self.pad_token).float().to(captions.device) |
|
|
|
|
|
seq_length = captions.size(2) |
|
pos_encoding = get_sinusoid_encoding_table(seq_length, self.token_embedding_dim).to(captions.device) |
|
pos_encoding = pos_encoding.unsqueeze(0).repeat(captions.size(0) * self.num_queries, 1, 1) |
|
|
|
|
|
spatial_embedding = detr_output[-1] |
|
|
|
|
|
token_embeddings = self.token_embedding(captions) |
|
|
|
|
|
spatial_embedding = spatial_embedding.unsqueeze(2) |
|
combined_embedding = torch.cat([spatial_embedding.repeat(1, 1, token_embeddings.size(2), 1), token_embeddings], dim=-1) |
|
|
|
|
|
|
|
memory = combined_embedding.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.detr_decoder_dim + self.token_embedding_dim) |
|
|
|
|
|
|
|
target = token_embeddings.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.token_embedding_dim) |
|
|
|
|
|
|
|
pos_encoding = pos_encoding.permute(1, 0, 2) |
|
target += pos_encoding |
|
|
|
|
|
|
|
|
|
target = self.target_projection(target) |
|
|
|
attention_mask = attention_mask.permute(2, 0, 1).reshape(captions.size(2), -1) |
|
tgt_key_padding_mask = (attention_mask == 0).permute(1,0) |
|
|
|
|
|
inputs_embeds = combined_embedding.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.detr_decoder_dim + self.token_embedding_dim) |
|
|
|
|
|
attention_mask = attention_mask.reshape(-1, captions.size(2)) |
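
        # Shape note: `inputs_embeds` is [seq_len, batch * num_queries, n_embd] and `attention_mask`
        # is reshaped to [batch * num_queries, seq_len]; GPT2LMHeadModel itself treats dimension 0 of
        # inputs_embeds as the batch and dimension 1 as the sequence.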
|
|
|
|
|
outputs = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask) |
|
logits = outputs.logits |
|
|
|
|
|
logits = logits.view(captions.size(2), captions.size(0), self.num_queries, self.vocab_size).permute(1, 2, 0, 3) |
|
|
|
return logits |
|
|
|
|