Upload 5 files
- LLMEyeCap_01.bin +3 -0
- model.py +893 -0
- train.py +642 -0
- tuto.ipynb +0 -0
- utils.py +569 -0
LLMEyeCap_01.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d53f80ed02bdee05882919aa81232ebf8e1af0510bfb6388dc6e616ce57db2a3
size 445770457
model.py
ADDED
@@ -0,0 +1,893 @@
import torch
from torch import nn, Tensor  # Tensor is used in the type hints below
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel
import os
import json
import numpy as np
from collections import defaultdict
import random
from tqdm.notebook import tqdm
from torchvision import models
from torch.nn.utils.rnn import pad_sequence
import matplotlib.patches as patches

import math
import time
import os
from PIL import Image
import requests
import nltk

import os
import cv2
import colorsys
from numpy import asarray
import math


from transformers import GPT2LMHeadModel, GPT2Config

from scipy.optimize import linear_sum_assignment

import sys
sys.path.append("../src")

from utils import *

NUM_QUERIES = 40
feature_size = 256   # for ResNet50
token_size = 256     # for GPT-2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# minimal updates here

"""
Various positional encodings for the transformer.
"""

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.row_embed.weight)
        nn.init.uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        i = torch.arange(w, device=x.device)
        j = torch.arange(h, device=x.device)
        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)
        pos = torch.cat([
            x_emb.unsqueeze(0).repeat(h, 1, 1),
            y_emb.unsqueeze(1).repeat(1, w, 1),
        ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
        return pos


def build_position_encoding(args):
    N_steps = args.hidden_dim // 2
    if args.position_embedding in ('v2', 'sine'):
        # TODO find a better way of exposing other arguments
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif args.position_embedding in ('v3', 'learned'):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {args.position_embedding}")

    return position_embedding

from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out

'''
The line mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] applies a mask to the output
features from the backbone. The mask is used to indicate which pixels in the image are valid.

The mask is a tensor of the same size as the output features. The m[None].float()
operation adds a leading dimension, giving a tensor of size 1 x B x H x W. The F.interpolate()
operation then resizes the mask to the same spatial size as the output features. The to(torch.bool) operation converts the
mask back to a binary tensor, and the [0] operation drops the leading dimension again, yielding one mask per output
feature map.

The mask of a feature extracted from a ResNet50 backbone is a binary tensor that indicates which pixels in the image
are valid. The valid pixels are those that are not padded. The mask lets the rest of the model ignore the padded
pixels when attending over the extracted features.
'''
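# Editor's illustrative sketch (not part of the original file): the mask resizing
# described above, on made-up shapes. A 256x256 padding mask is shrunk to the 8x8
# resolution of a ResNet50 layer4 feature map with the same F.interpolate call.
def _demo_mask_interpolation():
    m = torch.zeros(2, 256, 256, dtype=torch.bool)      # True marks padded pixels
    m[:, :, 200:] = True                                 # right-hand columns are padding
    x = torch.randn(2, 2048, 8, 8)                       # stand-in backbone feature map
    mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
    return mask.shape                                    # torch.Size([2, 8, 8])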

class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
        # ==> todo weights=ResNet50_Weights.DEFAULT)
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos


def build_backbone(args):
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.masks
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model

def get_sinusoid_encoding_table(n_position, d_hid):
    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(sinusoid_table)

class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""
    @torch.no_grad()
    def forward(self, outputs, target_sizes):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = F.softmax(out_logits, -1)
        scores, labels = prob[..., :-1].max(-1)

        # convert to [x0, y0, x1, y1] format
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]

        return results


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def build(args):
    # the `num_classes` naming here is somewhat misleading.
    # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
    # is the maximum id for a class in your dataset. For example,
    # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
    # As another example, for a dataset that has a single class with id 1,
    # you should pass `num_classes` to be 2 (max_obj_id + 1).
    # For more details on this, check the following discussion
    # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
    num_classes = 20 if args.dataset_file != 'coco' else 91
    if args.dataset_file == "coco_panoptic":
        # for panoptic, we just add a num_classes that is large enough to hold
        # max_obj_id + 1, but the exact value doesn't really matter
        num_classes = 250
    device = torch.device(args.device)

    backbone = build_backbone(args)

    transformer = build_transformer(args)

    model = DETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
    )
    if args.masks:
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
    matcher = build_matcher(args)
    weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef
    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
    # TODO this is a hack
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    losses = ['labels', 'boxes', 'cardinality']
    if args.masks:
        losses += ["masks"]
    criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,
                             eos_coef=args.eos_coef, losses=losses)
    criterion.to(device)
    postprocessors = {'bbox': PostProcess()}
    if args.masks:
        postprocessors['segm'] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            is_thing_map = {i: i <= 90 for i in range(201)}
            postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)

    return model, criterion, postprocessors

class Parameters:
    def __init__(self):
        self.lr = 1e-4
        self.lr_backbone = 1e-5
        self.batch_size = 2
        self.weight_decay = 1e-4
        self.epochs = 300
        self.lr_drop = 200
        self.clip_max_norm = 0.1

args = Parameters()

args.lr = 1e-4
args.lr_backbone = 1e-5
args.batch_size = 32
args.weight_decay = 1e-4
args.epochs = 300
args.lr_drop = 200
args.clip_max_norm = 0.1  # type=float, help='gradient clipping max norm')

# Model parameters
args.frozen_weights = False  # type=str, default=None, help="Path to the pretrained model. If set, only the mask head will be trained")

# * Backbone
args.backbone = 'resnet50'  # type=str, help="Name of the convolutional backbone to use")
args.dilation = False  # action='store_true', help="If true, we replace stride with dilation in the last convolutional block (DC5)")
args.position_embedding = 'sine'  # type=str, choices=('sine', 'learned'), help="Type of positional embedding to use on top of the image features")

# * Transformer
args.enc_layers = 6  # type=int, help="Number of encoding layers in the transformer")
args.dec_layers = 6  # type=int, help="Number of decoding layers in the transformer")
args.dim_feedforward = 2048  # ===> type=int, help="Intermediate size of the feedforward layers in the transformer blocks")
args.hidden_dim = 256  # ===> type=int, help="Size of the embeddings (dimension of the transformer)")
args.dropout = 0.1  # type=float, help="Dropout applied in the transformer")
args.nheads = 8  # type=int, help="Number of attention heads inside the transformer's attentions")
args.num_queries = 40  # type=int, help="Number of query slots")
args.pre_norm = True  # action='store_true')

# * Segmentation
args.masks = False  # action='store_true', help="Train segmentation head if the flag is provided")


"""
LLMEyeCap Transformer class.

A DETR (Facebook) copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


class LLMEyeCap(nn.Module):  # Novel Object Captioning V 0.1

    def __init__(self, backbone, transformer, num_queries, vocab_size, pad_token):

        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.hidden_dim = transformer.d_model

        self.caption_embed = nn.Linear(self.hidden_dim, vocab_size)
        self.bbox_embed = MLP(self.hidden_dim, self.hidden_dim, 4, 3)

        self.query_embed = nn.Embedding(num_queries, self.hidden_dim)
        self.input_proj = nn.Conv2d(backbone.num_channels, self.hidden_dim, kernel_size=1)
        self.backbone = backbone
        '''
        self.capdecoder = CaptioningDecoder(detr_decoder_dim=transformer.d_model, token_embedding_dim=transformer.d_model,
                                            vocab_size=vocab_size, num_queries=num_queries, num_layers=6)
        '''
        self.capdecoder = CaptionDecoder(feature_size, token_size, vocab_size, num_queries, pad_token).to(device)


    def forward(self, samples: NestedTensor, captions):

        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)

        features, pos = self.backbone(samples)  # features + position embedding
        src, mask = features[-1].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
        outputs_coord = self.bbox_embed(hs).sigmoid()

        outputs_captions = self.capdecoder(hs, captions)
        # predicted_sequences = torch.argmax(outputs_captions, dim=-1)

        out = {'pred_logits': outputs_captions, 'pred_boxes': outputs_coord[-1]}
        return out

    def generate_caption(self, image_path, tokenizer, max_length, pad_sos):

        image = Image.open(image_path).convert('RGB')
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        image = transform(image).unsqueeze(0).to(device)

        if isinstance(image, (list, torch.Tensor)):
            image = nested_tensor_from_tensor_list(image)

        with torch.no_grad():
            features, pos = self.backbone(image)  # features + position embedding
            src, mask = features[-1].decompose()
            assert mask is not None

            hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
            outputs_coord = self.bbox_embed(hs).sigmoid()

            input_ids = torch.ones((1, 40, 1), dtype=torch.long, device=device)
            input_ids.fill_(pad_sos)

            for i in range(max_length):
                outputs_captions = self.capdecoder(hs, input_ids)
                predicted_sequences = torch.argmax(outputs_captions, dim=-1)
                next_token = predicted_sequences[:, :, -1:]  # take the last token from the sequence
                input_ids = torch.cat((input_ids, next_token), dim=-1)

            # caption = tokenizer.detokenize(input_ids[0].tolist())  #, skip_special_tokens=True)

        return outputs_coord[-1], input_ids  # caption[-1]

class LLMEyeCapModel(nn.Module):
    def __init__(self, num_queries, vocab_size, pad_token):
        super(LLMEyeCapModel, self).__init__()
        self.num_queries = num_queries
        self.vocab_size = vocab_size
        self.backbone = build_backbone(args)
        self.transformer = build_transformer(args)

        self.model = LLMEyeCap(
            self.backbone,
            self.transformer,
            num_queries=self.num_queries,
            vocab_size=self.vocab_size,
            pad_token=pad_token
        )

        # self.in_features = self.caption_embed.in_features

        # self.model.class_embed = nn.Linear(in_features=self.in_features, out_features=self.num_classes)

        self.model.num_queries = self.num_queries

    def forward(self, images, captions):
        return self.model(images, captions)

    def generate_caption(self, image_path, tokenizer, max_length=20, pad_sos=0):
        return self.model.generate_caption(image_path, tokenizer, max_length, pad_sos)

class CaptionDecoder(nn.Module):
    def __init__(self, detr_decoder_dim, token_embedding_dim, vocab_size, num_queries, pad_token, num_layers=6):
        super(CaptionDecoder, self).__init__()

        self.detr_decoder_dim = detr_decoder_dim
        self.token_embedding_dim = token_embedding_dim
        self.vocab_size = vocab_size
        self.num_queries = num_queries
        self.pad_token = pad_token

        # Token embedding layer
        self.token_embedding = nn.Embedding(vocab_size, token_embedding_dim)

        # Initialize GPT-2
        config = GPT2Config(vocab_size=vocab_size, n_embd=detr_decoder_dim + token_embedding_dim, n_head=8)
        self.gpt2 = GPT2LMHeadModel(config)

        self.target_projection = nn.Linear(token_embedding_dim, detr_decoder_dim + token_embedding_dim)

    def forward(self, detr_output, captions):

        # Create an attention mask with shape [batch_size, num_queries, sequence_length]
        attention_mask = (captions != self.pad_token).float().to(captions.device)  # [batch_size, num_queries, sequence_length]

        seq_length = captions.size(2)
        pos_encoding = get_sinusoid_encoding_table(seq_length, self.token_embedding_dim).to(captions.device)
        pos_encoding = pos_encoding.unsqueeze(0).repeat(captions.size(0) * self.num_queries, 1, 1)

        # Get the last layer's output from the DETR decoder
        spatial_embedding = detr_output[-1]  # [batch_size, num_queries, detr_decoder_dim]

        # Get token embeddings
        token_embeddings = self.token_embedding(captions)  # [batch_size, num_queries, seq_length, token_embedding_dim]

        # Repeat the spatial embedding for each token in the sequence and concatenate
        spatial_embedding = spatial_embedding.unsqueeze(2)  # Add seq_length dimension: [batch_size, num_queries, 1, detr_decoder_dim]
        combined_embedding = torch.cat([spatial_embedding.repeat(1, 1, token_embeddings.size(2), 1), token_embeddings], dim=-1)
        # combined_embedding shape: [batch_size, num_queries, seq_length, detr_decoder_dim + token_embedding_dim]

        # Prepare the memory for the transformer decoder
        memory = combined_embedding.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.detr_decoder_dim + self.token_embedding_dim)
        # memory shape: [seq_length, batch_size*num_queries, detr_decoder_dim + token_embedding_dim]

        # Prepare the target for the transformer decoder (using token embeddings)
        target = token_embeddings.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.token_embedding_dim)
        # target shape: [seq_length, batch_size*num_queries, token_embedding_dim]

        pos_encoding = pos_encoding.permute(1, 0, 2)
        target += pos_encoding

        # Project target to the required dimension
        target = self.target_projection(target)

        attention_mask = attention_mask.permute(2, 0, 1).reshape(captions.size(2), -1)
        tgt_key_padding_mask = (attention_mask == 0).permute(1, 0)

        # Prepare the inputs for GPT-2
        inputs_embeds = combined_embedding.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.detr_decoder_dim + self.token_embedding_dim)

        # Reshape attention_mask for GPT-2. Flatten the batch_size and num_queries dimensions.
        attention_mask = attention_mask.reshape(-1, captions.size(2))  # New shape: [batch_size * num_queries, sequence_length]

        # Pass through GPT-2
        outputs = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        logits = outputs.logits

        # Reshape logits to match the original shape
        logits = logits.view(captions.size(2), captions.size(0), self.num_queries, self.vocab_size).permute(1, 2, 0, 3)

        return logits
train.py
ADDED
@@ -0,0 +1,642 @@
import torch
from torch import nn
import torch.nn.functional as F  # needed for F.pad in train_fn below
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel
import os
import json
import numpy as np
from collections import defaultdict
import random
from tqdm.notebook import tqdm
from torchvision import models
from torch.nn.utils.rnn import pad_sequence
import matplotlib.patches as patches

import math
import time
import os
from PIL import Image
import requests
import nltk

import os
import cv2
import colorsys
from numpy import asarray
import math


from transformers import GPT2LMHeadModel, GPT2Config

from transformers import BertTokenizer


from scipy.optimize import linear_sum_assignment




class CocoDataset(Dataset):
    def __init__(self, root_dir, annotation_file, instance_file, max_objects=40, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.max_objects = max_objects
        self.img_cache = dict()  # Cache for images

        # Load instance file only once
        with open(instance_file, 'r') as file:
            data = json.load(file)
            instances = data['annotations']
            categories = data['categories']

        with open(annotation_file, 'r') as file:
            annotations = json.load(file)['annotations']

        self.image_captions = defaultdict(list)
        for annotation in annotations:
            img_id = annotation['image_id']
            self.image_captions[img_id].append(annotation['caption'])

        self.image_instances = defaultdict(list)
        self.category_id_to_name = {category['id']: category['name'] for category in categories}

        for instance in instances:
            img_id = instance['image_id']
            bbox = instance['bbox']
            category_id = instance['category_id']
            self.image_instances[img_id].append((bbox, category_id))

        self.img_ids = list(self.image_captions.keys())

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, index):
        img_id = self.img_ids[index]
        img_path = os.path.join(self.root_dir, f'{str(img_id).zfill(12)}.jpg')

        # Use cached image if available
        if img_id in self.img_cache:
            img = self.img_cache[img_id]
        else:
            img = Image.open(img_path).convert("RGB")
            self.img_cache[img_id] = img

        captions = self.image_captions[img_id]
        caption = random.choice(captions)

        annotations = self.image_instances[img_id]
        bboxes = []
        labels = []
        for obbox, label_id in annotations:
            bbox = torch.tensor(obbox)  # Convert to PyTorch tensor immediately
            bbox[0] = bbox[0] / img.width + (bbox[2] / img.width) / 2
            bbox[1] = bbox[1] / img.height + (bbox[3] / img.height) / 2
            bbox[2] = bbox[2] / img.width
            bbox[3] = bbox[3] / img.height
            label = self.category_id_to_name[label_id]
            bboxes.append(bbox)
            labels.append(label)

        bboxes.append(torch.tensor([0.5, 0.5, 1, 1]))
        labels.append(caption)

        total_boxes = len(bboxes)

        if total_boxes < 40:
            for _ in range(40 - total_boxes):
                bboxes.append(torch.tensor([0, 0, 0, 0]))
                labels.append("na")
        else:
            bboxes = bboxes[:40]
            labels = labels[:40]

        if self.transform:
            img = self.transform(img)

        return img, bboxes, labels

# Define the image transforms
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def custom_collate(batch):
    images, boxes_list, labels_list = zip(*batch)

    # Convert list of PIL images to a single PyTorch tensor
    stacked_images = torch.stack(images)

    # Convert list of list of boxes to a list of PyTorch tensors
    stacked_boxes = [torch.stack([box.clone().detach() for box in boxes]) for boxes in boxes_list]

    # Since labels are strings, we can keep them as a list of lists
    # labels_list is already in the desired format

    return stacked_images, stacked_boxes, labels_list

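# Editor's sketch (not part of the original file): wiring CocoDataset and
# custom_collate above into a DataLoader. The COCO paths and the batch size
# are placeholder assumptions.
def _build_train_loader(batch_size=32):
    dataset = CocoDataset(root_dir="coco/train2017",
                          annotation_file="coco/annotations/captions_train2017.json",
                          instance_file="coco/annotations/instances_train2017.json",
                          transform=transform)
    # Each batch yields: images [B, 3, 256, 256], a list of B box tensors of shape
    # [40, 4], and a tuple of B label lists (40 strings each: names, one caption, "na" padding).
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      collate_fn=custom_collate)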
def train_fn(data_loader, model, criterion, optimizer, device, scheduler, epoch):
    model.train()
    criterion.train()
    summary_loss = AverageMeter()

    tk0 = tqdm(data_loader, total=len(data_loader) - 1)

    for step, (images, bboxes, captions) in enumerate(tk0):

        try:
            flattened_captions = [caption for sublist in captions for caption in sublist]
            captions = tokenizer(flattened_captions, padding=True, return_tensors="pt", truncation=True)
            captions = captions["input_ids"]
            input_ids = captions.reshape(batch_size, num_queries, -1).to(device)
            min_length = 2
        except RuntimeError as e:
            print("Reshape failed:", e)
            continue

        '''
        min_length = 2
        if input_ids.size(-1) < min_length:
            padding_needed = min_length - input_ids.size(-1)
            input_ids = F.pad(input_ids, (0, padding_needed), 'constant', PAD_TOKEN)

        targets = build_targets(bboxes, input_ids[:, :, 1:])
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        images = list(image.to(device) for image in images)

        output = model(images, input_ids[:, :, :-1])
        '''

        min_length = 2
        if input_ids.size(-1) < min_length:
            padding_needed = min_length - input_ids.size(-1)
            input_ids = F.pad(input_ids, (0, padding_needed), 'constant', PAD_TOKEN)

        # input_ids = captions["input_ids"]
        # input_ids = input_ids.reshape(batch_size, num_queries, -1).to(device)

        targets = build_targets(bboxes, input_ids[:, :, 1:])

        # targets = build_targets(bboxes, captions[:, :, 1:])

        images = list(image.to(device) for image in images)

        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        output = model(images, input_ids[:, :, :-1])

        loss_dict = criterion(output, targets)
        weight_dict = criterion.weight_dict

        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # Detach and delete tensors
        loss_dict = {k: v.detach() for k, v in loss_dict.items()}

        del images, bboxes, captions, output, targets, loss_dict
        torch.cuda.empty_cache()  # Clear cache

        summary_loss.update(losses.item(), BATCH_SIZE)
        tk0.set_postfix(loss=summary_loss.avg)

    return summary_loss

class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        # out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]

        out_prob = outputs["pred_logits"].flatten(0, 2).softmax(-1)  # [batch_size * num_queries * seq_length, vocab_size]

        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost between boxes
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class.mean() + self.cost_giou * cost_giou
        # C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]



def build_matcher(args):
    return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)

296 |
+
class SetCriterion(nn.Module):
    """This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, vocab_size, matcher, weight_dict, eos_coef, losses, pad_token):
        """Create the criterion.
        Parameters:
            vocab_size: size of the caption vocabulary (it plays the role of DETR's number of
                object categories, omitting the special no-object category)
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            pad_token: token id ignored by the classification cross-entropy (caption padding)
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.pad_token = pad_token
        empty_weight = torch.ones(self.vocab_size)
        # empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_token)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=False):
        """Classification loss (NLL) for sequences
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes, seq_length]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        batch_size, num_boxes, sequence_length, _ = src_logits.size()

        # Get the indices for the permutation
        batch_idx, src_idx = self._get_src_permutation_idx(indices)

        target_classes = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])

        # Ensure the target classes are valid
        assert (target_classes >= 0).all() and (target_classes < self.vocab_size).all(), "Invalid token index in target!"

        # loss_ce = criterion(outputs.reshape(-1, vocab_size), captions.view(-1))
        loss_ce = self.criterion(src_logits.reshape(batch_size * num_boxes * sequence_length, -1), target_classes.reshape(-1))

        # loss_ce = torchmetrics.functional.smooth_cross_entropy(src_logits[batch_idx], target_classes, ignore_index=PAD_TOKEN)
        losses = {'loss_ce': loss_ce}

        return losses

        '''
        criterion = nn.CrossEntropyLoss(ignore_index=self.PAD_TOKEN)
        loss_ce = criterion(src_logits, target_classes_for_loss)
        losses = {'loss_ce': loss_ce}
        '''

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)

        card_pred = card_pred.sum(dim=1)

        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)

        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')

        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(generalized_box_iou(
            box_cxcywh_to_xyxy(src_boxes),
            box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        # upsample predictions to the target size
        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                mode="bilinear", align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(src_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'masks': self.loss_masks
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
        """This performs the loss computation.
        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                     The expected keys in each dict depend on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # print("indice len", len(indices), "len(indices[0])", len(indices[0]))
        # print("shape indices 0 0", indices[0][0].shape, "shape indices 0 1", indices[0][1].shape)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
        # print("num_boxes", num_boxes)

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
        '''
        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)
        '''
        return losses

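# Usage sketch (illustrative only): how the per-loss dictionary returned by
# SetCriterion.forward is combined into a single scalar, mirroring the eval loop below.
# The weights, eos_coef and loss list here are placeholder assumptions, not the
# training configuration.
def _set_criterion_example(outputs, targets, vocab_size, pad_token_id):
    matcher = HungarianMatcher()
    weights = {'loss_ce': 1, 'loss_bbox': 1, 'loss_giou': 1}
    criterion = SetCriterion(vocab_size, matcher=matcher, weight_dict=weights,
                             eos_coef=0.1, losses=['labels', 'boxes', 'cardinality'],
                             pad_token=pad_token_id)
    loss_dict = criterion(outputs, targets)
    total_loss = sum(loss_dict[k] * weights[k] for k in loss_dict if k in weights)
    return total_loss, loss_dict
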
def eval_fn(data_loader, model, criterion, device):
    model.eval()
    criterion.eval()
    summary_loss = AverageMeter()

    with torch.no_grad():

        # tk0 = tqdm(data_loader, total=len(data_loader))
        # for step, (images, bboxes, captions) in enumerate(tk0):
        # pbar = tqdm(range(len(data_loader)))

        tk0 = tqdm(data_loader, total=len(data_loader) - 1)
        for step, (images, bboxes, captions) in enumerate(tk0):

            try:
                flattened_captions = [caption for sublist in captions for caption in sublist]
                captions = tokenizer(flattened_captions, padding=True, return_tensors="pt", truncation=True)
                captions = captions["input_ids"]
                input_ids = captions.reshape(batch_size, num_queries, -1).to(device)
                min_length = 2
            except RuntimeError as e:
                print("Reshape failed:", e)
                continue

            if input_ids.size(-1) < min_length:
                padding_needed = min_length - input_ids.size(-1)
                input_ids = F.pad(input_ids, (0, padding_needed), 'constant', PAD_TOKEN)

            # input_ids = captions["input_ids"]
            # input_ids = input_ids.reshape(batch_size, num_queries, -1).to(device)

            targets = build_targets(bboxes, input_ids[:, :, 1:])
            # targets = build_targets(bboxes, captions[:, :, 1:])

            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            output = model(images, input_ids[:, :, :-1])

            loss_dict = criterion(output, targets)
            weight_dict = criterion.weight_dict

            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

            summary_loss.update(losses.item(), BATCH_SIZE)

            # Detach and delete tensors
            loss_dict = {k: v.detach() for k, v in loss_dict.items()}
            del images, bboxes, captions, output, targets, loss_dict
            torch.cuda.empty_cache()  # Clear cache

            tk0.set_postfix(loss=summary_loss.avg)
        # data_loader.on_epoch_end()

    return summary_loss

def build_targets(bboxes, captions):
    targets = []
    for i, (bbox, caption) in enumerate(zip(bboxes, captions)):
        target = {
            "boxes": bbox,
            "labels": caption,
        }
        targets.append(target)
    return targets

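# Illustrative example (shapes are assumptions): build_targets simply pairs each image's
# boxes with the token ids of its per-box captions.
def _build_targets_example():
    bboxes = [torch.rand(5, 4), torch.rand(5, 4)]   # one [num_queries, 4] tensor per image
    captions = torch.randint(0, 30522, (2, 5, 6))   # [batch, num_queries, seq_len] token ids
    targets = build_targets(bboxes, captions)
    # targets[0]["boxes"].shape == (5, 4); targets[0]["labels"].shape == (5, 6)
    return targets
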
if __name__ == "__main__":

    # Create the datasets
    train_dataset = CocoDataset(root_dir="../data/coco91/train2017",
                                annotation_file="../data/coco91/annotations/captions_train2017.json",
                                instance_file="../data/coco91/annotations/instances_train2017.json",
                                transform=transform)
    val_dataset = CocoDataset(root_dir="../data/coco91/val2017",
                              annotation_file="../data/coco91/annotations/captions_val2017.json",
                              instance_file="../data/coco91/annotations/instances_val2017.json",
                              transform=transform)

    batch_size = 4
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=custom_collate)

    # Initialize the BERT tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Get the padding token ID
    # PAD_TOKEN = tokenizer.pad_token
    PAD_TOKEN = tokenizer.pad_token_id

    # Get the start-of-sequence token ID
    # For BERT, the start-of-sequence token is usually the [CLS] token
    # start_of_sequence_token = tokenizer.cls_token
    PAD_SOS = tokenizer.cls_token_id

    # Get the vocabulary size
    vocab_size = tokenizer.vocab_size

    print(f"Pad token ID: {PAD_TOKEN}")
    print(f"Start of Sequence token ID: {PAD_SOS}")
    print(f"Vocab size: {vocab_size}")

    matcher = HungarianMatcher()

    weight_dict = {'loss_ce': 1, 'loss_bbox': 1, 'loss_giou': 1}

    losses = ['labels', 'boxes', 'cardinality']

    criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)

    model = LLMEyaCapModel(num_queries=NUM_QUERIES, vocab_size=vocab_size)
    model = model.to(device)

    criterion = SetCriterion(vocab_size, matcher=matcher, weight_dict=weight_dict,
                             eos_coef=NULL_CLASS_COEF, losses=losses, pad_token=PAD_TOKEN)
    criterion = criterion.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    best_loss = 10**5

    LR = 2e-6
    # LR = 2e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)  # , weight_decay=0.0001)
    EPOCHS = 1
    num_queries = NUM_QUERIES
    batch_size = 4

    for epoch in range(EPOCHS):
        time_start = time.time()
        train_loss = train_fn(train_loader, model, criterion, optimizer, device, scheduler=None, epoch=epoch)
        valid_loss = eval_fn(val_loader, model, criterion, device)

        elapsed = time.time() - time_start
        chk_name = f'LLMEyeCap_01_e{epoch}.bin'
        torch.save(model.state_dict(), chk_name)
        print(f"[Epoch {epoch+1:2d} / {EPOCHS:2d}] Train loss: {train_loss.avg:.3f}. Val loss: {valid_loss.avg:.3f} --> {chk_name} [{elapsed/60:.0f} mins]")

        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            print(f'Best model found in epoch {epoch+1}........Saving Model')
            torch.save(model.state_dict(), 'LLMEyeCap_01_model.bin')
tuto.ipynb ADDED
The diff for this file is too large to render. See raw diff.
utils.py ADDED
@@ -0,0 +1,569 @@
from torchvision.ops.boxes import box_area


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou_2(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union

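# Usage sketch (illustrative): the two conversions above are inverses of each other for
# well-formed boxes. The values are arbitrary.
def _box_format_example():
    boxes_xyxy = torch.tensor([[10.0, 10.0, 50.0, 30.0]])
    boxes_cxcywh = box_xyxy_to_cxcywh(boxes_xyxy)   # -> [[30., 20., 40., 20.]]
    assert torch.allclose(box_cxcywh_to_xyxy(boxes_cxcywh), boxes_xyxy)
    return boxes_cxcywh
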
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou_2(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area

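# Usage sketch (illustrative): GIoU is bounded above by IoU and becomes increasingly
# negative as boxes move apart; boxes are in [x0, y0, x1, y1] format.
def _generalized_box_iou_example():
    a = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
    b = torch.tensor([[1.0, 1.0, 3.0, 3.0],   # overlaps a
                      [4.0, 4.0, 5.0, 5.0]])  # disjoint from a
    giou = generalized_box_iou(a, b)          # shape [1, 2]; giou[0, 0] > giou[0, 1]
    return giou
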
def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.

    Returns a [N, 4] tensors, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)

"""
|
85 |
+
Misc functions, including distributed helpers.
|
86 |
+
|
87 |
+
Mostly copy-paste from torchvision references.
|
88 |
+
"""
|
89 |
+
import os
|
90 |
+
import subprocess
|
91 |
+
import time
|
92 |
+
from collections import defaultdict, deque
|
93 |
+
import datetime
|
94 |
+
import pickle
|
95 |
+
from packaging import version
|
96 |
+
from typing import Optional, List
|
97 |
+
|
98 |
+
import torch
|
99 |
+
import torch.distributed as dist
|
100 |
+
from torch import Tensor
|
101 |
+
|
102 |
+
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
103 |
+
import torchvision
|
104 |
+
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
105 |
+
from torchvision.ops import _new_empty_tensor
|
106 |
+
from torchvision.ops.misc import _output_size
|
107 |
+
|
108 |
+
|
109 |
+
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)

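# Usage sketch (illustrative): SmoothedValue keeps a sliding window of recent values
# plus running totals for the global average.
def _smoothed_value_example():
    sv = SmoothedValue(window_size=3, fmt="{median:.2f} ({global_avg:.2f})")
    for v in [1.0, 2.0, 3.0, 4.0]:
        sv.update(v)
    # window holds [2.0, 3.0, 4.0] -> median 3.0; global_avg over all four values is 2.5
    return str(sv)
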
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list

def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict

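# Usage sketch (illustrative): in a single-process run reduce_dict is a no-op and returns
# the input unchanged; under torch.distributed it averages (or sums) each value across ranks.
def _reduce_dict_example(loss_dict):
    # e.g. loss_dict = {"loss_ce": torch.tensor(0.7), "loss_bbox": torch.tensor(0.2)}
    return reduce_dict(loss_dict, average=True)
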
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))

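# Usage sketch (illustrative): MetricLogger wraps SmoothedValue meters and can iterate
# over a loader while printing progress every `print_freq` steps. The loss value here is
# a placeholder.
def _metric_logger_example(data_loader):
    logger = MetricLogger(delimiter="  ")
    for batch in logger.log_every(data_loader, print_freq=10, header="Train:"):
        logger.update(loss=0.5)  # replace with a real scalar loss
    return logger
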
def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommited changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message

def collate_fn(batch):
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes

class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)

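# Usage sketch (illustrative): images of different sizes are zero-padded to a common
# shape and a boolean mask marks the padded area (True = padding).
def _nested_tensor_example():
    imgs = [torch.rand(3, 200, 300), torch.rand(3, 250, 280)]
    nested = nested_tensor_from_tensor_list(imgs)
    padded, mask = nested.decompose()
    # padded.shape == (2, 3, 250, 300); mask.shape == (2, 250, 300)
    return padded, mask
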
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)

def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)

def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)

@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    if output.dim() == 1:
        output = output.unsqueeze(0)

    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


'''
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
'''

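# Usage sketch (illustrative): top-k accuracy on a tiny batch of logits.
def _accuracy_example():
    logits = torch.tensor([[0.1, 0.7, 0.2],
                           [0.6, 0.3, 0.1]])
    target = torch.tensor([1, 0])
    top1, top2 = accuracy(logits, target, topk=(1, 2))
    # both targets are the arg-max of their row -> top1 == top2 == 100.0
    return top1, top2
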
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    if version.parse(torchvision.__version__) < version.parse('0.7'):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)