Imagroune committed on
Commit
5f6da0d
·
1 Parent(s): bd2522f

Upload 5 files

Files changed (5)
  1. LLMEyeCap_01.bin +3 -0
  2. model.py +893 -0
  3. train.py +642 -0
  4. tuto.ipynb +0 -0
  5. utils.py +569 -0
LLMEyeCap_01.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d53f80ed02bdee05882919aa81232ebf8e1af0510bfb6388dc6e616ce57db2a3
3
+ size 445770457
model.py ADDED
@@ -0,0 +1,893 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torchvision.models import resnet50
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ from transformers import BertTokenizer, BertModel
9
+ import os
10
+ import json
11
+ import numpy as np
12
+ from collections import defaultdict
13
+ import random
14
+ from tqdm.notebook import tqdm
15
+ from torchvision import models
16
+ from torch.nn.utils.rnn import pad_sequence
17
+ import matplotlib.patches as patches
18
+
19
+ import math
20
+ import time
21
+ import os
22
+ from PIL import Image
23
+ import requests
24
+ import nltk
25
+
26
+ import os
27
+ import cv2
28
+ import colorsys
29
+ from numpy import asarray
30
+ import math
31
+
32
+
33
+ from transformers import GPT2LMHeadModel, GPT2Config
34
+
35
+ from scipy.optimize import linear_sum_assignment
36
+
37
+ import sys
38
+ sys.path.append("../src")
39
+
40
+ from utils import *
41
+
42
+ NUM_QUERIES = 40
43
+ feature_size = 256 # For ResNet50
44
+ token_size = 256 # For GPT-2
45
+
46
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+
48
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
49
+ # minimal updates here
50
+
51
+ """
52
+ Various positional encodings for the transformer.
53
+ """
54
+
55
+ class PositionEmbeddingSine(nn.Module):
56
+ """
57
+ This is a more standard version of the position embedding, very similar to the one
58
+ used by the Attention is all you need paper, generalized to work on images.
59
+ """
60
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
61
+ super().__init__()
62
+ self.num_pos_feats = num_pos_feats
63
+ self.temperature = temperature
64
+ self.normalize = normalize
65
+ if scale is not None and normalize is False:
66
+ raise ValueError("normalize should be True if scale is passed")
67
+ if scale is None:
68
+ scale = 2 * math.pi
69
+ self.scale = scale
70
+
71
+ def forward(self, tensor_list: NestedTensor):
72
+ x = tensor_list.tensors
73
+ mask = tensor_list.mask
74
+ assert mask is not None
75
+ not_mask = ~mask
76
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
77
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
78
+ if self.normalize:
79
+ eps = 1e-6
80
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
81
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
82
+
83
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
84
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
85
+
86
+ pos_x = x_embed[:, :, :, None] / dim_t
87
+ pos_y = y_embed[:, :, :, None] / dim_t
88
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
89
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
90
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
91
+ return pos
92
+
93
+
94
+ class PositionEmbeddingLearned(nn.Module):
95
+ """
96
+ Absolute pos embedding, learned.
97
+ """
98
+ def __init__(self, num_pos_feats=256):
99
+ super().__init__()
100
+ self.row_embed = nn.Embedding(50, num_pos_feats)
101
+ self.col_embed = nn.Embedding(50, num_pos_feats)
102
+ self.reset_parameters()
103
+
104
+ def reset_parameters(self):
105
+ nn.init.uniform_(self.row_embed.weight)
106
+ nn.init.uniform_(self.col_embed.weight)
107
+
108
+ def forward(self, tensor_list: NestedTensor):
109
+ x = tensor_list.tensors
110
+ h, w = x.shape[-2:]
111
+ i = torch.arange(w, device=x.device)
112
+ j = torch.arange(h, device=x.device)
113
+ x_emb = self.col_embed(i)
114
+ y_emb = self.row_embed(j)
115
+ pos = torch.cat([
116
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
117
+ y_emb.unsqueeze(1).repeat(1, w, 1),
118
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
119
+ return pos
120
+
121
+
122
+ def build_position_encoding(args):
123
+ N_steps = args.hidden_dim // 2
124
+ if args.position_embedding in ('v2', 'sine'):
125
+ # TODO find a better way of exposing other arguments
126
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
127
+ elif args.position_embedding in ('v3', 'learned'):
128
+ position_embedding = PositionEmbeddingLearned(N_steps)
129
+ else:
130
+ raise ValueError(f"not supported {args.position_embedding}")
131
+
132
+ return position_embedding
133
+
134
+ from collections import OrderedDict
135
+
136
+ import torch
137
+ import torch.nn.functional as F
138
+ import torchvision
139
+ from torch import nn
140
+ from torchvision.models._utils import IntermediateLayerGetter
141
+ from typing import Dict, List
142
+
143
+
144
+ class FrozenBatchNorm2d(torch.nn.Module):
145
+ """
146
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
147
+
148
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt,
149
+ without which any other models than torchvision.models.resnet[18,34,50,101]
150
+ produce nans.
151
+ """
152
+
153
+ def __init__(self, n):
154
+ super(FrozenBatchNorm2d, self).__init__()
155
+ self.register_buffer("weight", torch.ones(n))
156
+ self.register_buffer("bias", torch.zeros(n))
157
+ self.register_buffer("running_mean", torch.zeros(n))
158
+ self.register_buffer("running_var", torch.ones(n))
159
+
160
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
161
+ missing_keys, unexpected_keys, error_msgs):
162
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
163
+ if num_batches_tracked_key in state_dict:
164
+ del state_dict[num_batches_tracked_key]
165
+
166
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
167
+ state_dict, prefix, local_metadata, strict,
168
+ missing_keys, unexpected_keys, error_msgs)
169
+
170
+ def forward(self, x):
171
+ # move reshapes to the beginning
172
+ # to make it fuser-friendly
173
+ w = self.weight.reshape(1, -1, 1, 1)
174
+ b = self.bias.reshape(1, -1, 1, 1)
175
+ rv = self.running_var.reshape(1, -1, 1, 1)
176
+ rm = self.running_mean.reshape(1, -1, 1, 1)
177
+ eps = 1e-5
178
+ scale = w * (rv + eps).rsqrt()
179
+ bias = b - rm * scale
180
+ return x * scale + bias
181
+
182
+
183
+ class BackboneBase(nn.Module):
184
+
185
+ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
186
+ super().__init__()
187
+ for name, parameter in backbone.named_parameters():
188
+ if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
189
+ parameter.requires_grad_(False)
190
+ if return_interm_layers:
191
+ return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
192
+ else:
193
+ return_layers = {'layer4': "0"}
194
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
195
+ self.num_channels = num_channels
196
+
197
+ def forward(self, tensor_list: NestedTensor):
198
+ xs = self.body(tensor_list.tensors)
199
+ out: Dict[str, NestedTensor] = {}
200
+ for name, x in xs.items():
201
+ m = tensor_list.mask
202
+ assert m is not None
203
+ mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
204
+ out[name] = NestedTensor(x, mask)
205
+ return out
206
+
207
+ '''
208
+ The line mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] builds the padding mask for the output
209
+ features of the backbone. The mask is used to indicate which pixels in the image are valid.
210
+
211
+
212
+ The mask is a boolean tensor with one entry per pixel of the padded input batch. The m[None].float()
213
+ operation adds a leading dimension, giving a tensor of shape 1 x N x H x W. The F.interpolate()
214
+ operation then resizes that tensor to the spatial size of the output feature map. The to(torch.bool) operation converts the
215
+ result back to a binary tensor, and the [0] indexing removes the leading dimension again, yielding one downsampled mask
216
+ per image in the batch.
217
+
218
+ The mask attached to a feature map extracted from the ResNet50 backbone is a binary tensor that marks which positions
219
+ correspond to real image content and which correspond to padding. It is later used as the key padding mask of the
220
+ transformer, so that attention ignores the padded positions.
221
+
222
+ '''
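As a quick shape check of what this expression does, here is a standalone sketch (not part of the committed file; the image and feature-map sizes are made up for illustration):

import torch
import torch.nn.functional as F

# Batch of 2 boolean padding masks at full image resolution (True marks padded pixels).
m = torch.zeros(2, 800, 800, dtype=torch.bool)
m[0, :, 600:] = True  # pretend the first image only occupies the left 600 columns

# Backbone feature map for the same batch, e.g. a ResNet50 layer4 output of spatial size 25x25.
x = torch.randn(2, 2048, 25, 25)

# Same expression as in BackboneBase.forward: add a leading dim, resize, cast back, drop the dim.
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
print(mask.shape)  # torch.Size([2, 25, 25])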
223
+
224
+ class Backbone(BackboneBase):
225
+ """ResNet backbone with frozen BatchNorm."""
226
+ def __init__(self, name: str,
227
+ train_backbone: bool,
228
+ return_interm_layers: bool,
229
+ dilation: bool):
230
+ backbone = getattr(torchvision.models, name)(
231
+ replace_stride_with_dilation=[False, False, dilation],
232
+ pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
233
+ # ==> todo weights=ResNet50_Weights.DEFAULT)
234
+ num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
235
+ super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
236
+
237
+
238
+ class Joiner(nn.Sequential):
239
+ def __init__(self, backbone, position_embedding):
240
+ super().__init__(backbone, position_embedding)
241
+
242
+ def forward(self, tensor_list: NestedTensor):
243
+ xs = self[0](tensor_list)
244
+ out: List[NestedTensor] = []
245
+ pos = []
246
+ for name, x in xs.items():
247
+ out.append(x)
248
+ # position encoding
249
+ pos.append(self[1](x).to(x.tensors.dtype))
250
+
251
+ return out, pos
252
+
253
+
254
+ def build_backbone(args):
255
+ position_embedding = build_position_encoding(args)
256
+ train_backbone = args.lr_backbone > 0
257
+ return_interm_layers = args.masks
258
+ backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
259
+ model = Joiner(backbone, position_embedding)
260
+ model.num_channels = backbone.num_channels
261
+ return model
262
+
263
+ def get_sinusoid_encoding_table(n_position, d_hid):
264
+ def cal_angle(position, hid_idx):
265
+ return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
266
+
267
+ def get_posi_angle_vec(position):
268
+ return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
269
+
270
+ sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
271
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
272
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
273
+ return torch.FloatTensor(sinusoid_table)
274
+
275
+ class PostProcess(nn.Module):
276
+ """ This module converts the model's output into the format expected by the coco api"""
277
+ @torch.no_grad()
278
+ def forward(self, outputs, target_sizes):
279
+ """ Perform the computation
280
+ Parameters:
281
+ outputs: raw outputs of the model
282
+ target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
283
+ For evaluation, this must be the original image size (before any data augmentation)
284
+ For visualization, this should be the image size after data augment, but before padding
285
+ """
286
+ out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
287
+
288
+ assert len(out_logits) == len(target_sizes)
289
+ assert target_sizes.shape[1] == 2
290
+
291
+ prob = F.softmax(out_logits, -1)
292
+ scores, labels = prob[..., :-1].max(-1)
293
+
294
+ # convert to [x0, y0, x1, y1] format
295
+ boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
296
+ # and from relative [0, 1] to absolute [0, height] coordinates
297
+ img_h, img_w = target_sizes.unbind(1)
298
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
299
+ boxes = boxes * scale_fct[:, None, :]
300
+
301
+ results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
302
+
303
+ return results
304
+
305
+
306
+ class MLP(nn.Module):
307
+ """ Very simple multi-layer perceptron (also called FFN)"""
308
+
309
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
310
+ super().__init__()
311
+ self.num_layers = num_layers
312
+ h = [hidden_dim] * (num_layers - 1)
313
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
314
+
315
+ def forward(self, x):
316
+ for i, layer in enumerate(self.layers):
317
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
318
+ return x
319
+
320
+
321
+ def build(args):
322
+ # the `num_classes` naming here is somewhat misleading.
323
+ # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
324
+ # is the maximum id for a class in your dataset. For example,
325
+ # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
326
+ # As another example, for a dataset that has a single class with id 1,
327
+ # you should pass `num_classes` to be 2 (max_obj_id + 1).
328
+ # For more details on this, check the following discussion
329
+ # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
330
+ num_classes = 20 if args.dataset_file != 'coco' else 91
331
+ if args.dataset_file == "coco_panoptic":
332
+ # for panoptic, we just add a num_classes that is large enough to hold
333
+ # max_obj_id + 1, but the exact value doesn't really matter
334
+ num_classes = 250
335
+ device = torch.device(args.device)
336
+
337
+ backbone = build_backbone(args)
338
+
339
+ transformer = build_transformer(args)
340
+
341
+ model = DETR(
342
+ backbone,
343
+ transformer,
344
+ num_classes=num_classes,
345
+ num_queries=args.num_queries,
346
+ aux_loss=args.aux_loss,
347
+ )
348
+ if args.masks:
349
+ model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
350
+ matcher = build_matcher(args)
351
+ weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
352
+ weight_dict['loss_giou'] = args.giou_loss_coef
353
+ if args.masks:
354
+ weight_dict["loss_mask"] = args.mask_loss_coef
355
+ weight_dict["loss_dice"] = args.dice_loss_coef
356
+ # TODO this is a hack
357
+ if args.aux_loss:
358
+ aux_weight_dict = {}
359
+ for i in range(args.dec_layers - 1):
360
+ aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
361
+ weight_dict.update(aux_weight_dict)
362
+
363
+ losses = ['labels', 'boxes', 'cardinality']
364
+ if args.masks:
365
+ losses += ["masks"]
366
+ criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,
367
+ eos_coef=args.eos_coef, losses=losses)
368
+ criterion.to(device)
369
+ postprocessors = {'bbox': PostProcess()}
370
+ if args.masks:
371
+ postprocessors['segm'] = PostProcessSegm()
372
+ if args.dataset_file == "coco_panoptic":
373
+ is_thing_map = {i: i <= 90 for i in range(201)}
374
+ postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)
375
+
376
+ return model, criterion, postprocessors
377
+
378
+ class Parameters:
379
+ def __init__(self):
380
+ self.lr = 1e-4
381
+ self.lr_backbone = 1e-5
382
+ self.batch_size = 2
383
+ self.weight_decay = 1e-4
384
+ self.epochs = 300
385
+ self.lr_drop = 200
386
+ self.clip_max_norm = 0.1
387
+
388
+ args = Parameters()
389
+
390
+ args.lr=1e-4
391
+ args.lr_backbone=1e-5
392
+ args.batch_size=32
393
+ args.weight_decay=1e-4
394
+ args.epochs=300
395
+ args.lr_drop=200
396
+ args.clip_max_norm=0.1 # type=float, help='gradient clipping max norm')
397
+
398
+ # Model parameters
399
+ args.frozen_weights=False # ', type=str, default=None, # help="Path to the pretrained model. If set, only the mask head will be trained")
400
+
401
+ # * Backbone
402
+ args.backbone='resnet50' # type=str, # help="Name of the convolutional backbone to use")
403
+ args.dilation=False # ', action='store_true', # help="If true, we replace stride with dilation in the last convolutional block (DC5)")
404
+ args.position_embedding='sine' # type=str, choices=('sine', 'learned'), help="Type of positional embedding to use on top of the image features")
405
+
406
+ # * Transformer
407
+ args.enc_layers=6 # type=int, help="Number of encoding layers in the transformer")
408
+ args.dec_layers=6 # type=int, help="Number of decoding layers in the transformer")
409
+ args.dim_feedforward=2048 # ===> type=int, help="Intermediate size of the feedforward layers in the transformer blocks")
410
+ args.hidden_dim=256 # ===> type=int, help="Size of the embeddings (dimension of the transformer)")
411
+ args.dropout=0.1 #type=float, help="Dropout applied in the transformer")
412
+ args.nheads=8 #type=int, help="Number of attention heads inside the transformer's attentions")
413
+ args.num_queries=40 #type=int, help="Number of query slots")
414
+ args.pre_norm=True # ', action='store_true')
415
+
416
+ # * Segmentation
417
+ args.masks=False #, action='store_true', help="Train segmentation head if the flag is provided")
418
+
419
+
420
+ """
421
+ LLMEyeCap Transformer class.
422
+
423
+ A DETR (Facebook) copy-paste from torch.nn.Transformer with modifications:
424
+ * positional encodings are passed in MHattention
425
+ * extra LN at the end of encoder is removed
426
+ * decoder returns a stack of activations from all decoding layers
427
+
428
+ """
429
+ import copy
430
+ from typing import Optional, List
431
+
432
+
433
+ class Transformer(nn.Module):
434
+
435
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
436
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
437
+ activation="relu", normalize_before=False,
438
+ return_intermediate_dec=False):
439
+ super().__init__()
440
+
441
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
442
+ dropout, activation, normalize_before)
443
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
444
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
445
+
446
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
447
+ dropout, activation, normalize_before)
448
+ decoder_norm = nn.LayerNorm(d_model)
449
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
450
+ return_intermediate=return_intermediate_dec)
451
+
452
+ self._reset_parameters()
453
+
454
+ self.d_model = d_model
455
+ self.nhead = nhead
456
+
457
+ def _reset_parameters(self):
458
+ for p in self.parameters():
459
+ if p.dim() > 1:
460
+ nn.init.xavier_uniform_(p)
461
+
462
+ def forward(self, src, mask, query_embed, pos_embed):
463
+ # flatten NxCxHxW to HWxNxC
464
+ bs, c, h, w = src.shape
465
+ src = src.flatten(2).permute(2, 0, 1)
466
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
467
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
468
+ mask = mask.flatten(1)
469
+
470
+ tgt = torch.zeros_like(query_embed)
471
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
472
+ hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
473
+ pos=pos_embed, query_pos=query_embed)
474
+ return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
475
+
476
+
477
+ class TransformerEncoder(nn.Module):
478
+
479
+ def __init__(self, encoder_layer, num_layers, norm=None):
480
+ super().__init__()
481
+ self.layers = _get_clones(encoder_layer, num_layers)
482
+ self.num_layers = num_layers
483
+ self.norm = norm
484
+
485
+ def forward(self, src,
486
+ mask: Optional[Tensor] = None,
487
+ src_key_padding_mask: Optional[Tensor] = None,
488
+ pos: Optional[Tensor] = None):
489
+ output = src
490
+
491
+ for layer in self.layers:
492
+ output = layer(output, src_mask=mask,
493
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
494
+
495
+ if self.norm is not None:
496
+ output = self.norm(output)
497
+
498
+ return output
499
+
500
+
501
+ class TransformerDecoder(nn.Module):
502
+
503
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
504
+ super().__init__()
505
+ self.layers = _get_clones(decoder_layer, num_layers)
506
+ self.num_layers = num_layers
507
+ self.norm = norm
508
+ self.return_intermediate = return_intermediate
509
+
510
+ def forward(self, tgt, memory,
511
+ tgt_mask: Optional[Tensor] = None,
512
+ memory_mask: Optional[Tensor] = None,
513
+ tgt_key_padding_mask: Optional[Tensor] = None,
514
+ memory_key_padding_mask: Optional[Tensor] = None,
515
+ pos: Optional[Tensor] = None,
516
+ query_pos: Optional[Tensor] = None):
517
+ output = tgt
518
+
519
+ intermediate = []
520
+
521
+ for layer in self.layers:
522
+ output = layer(output, memory, tgt_mask=tgt_mask,
523
+ memory_mask=memory_mask,
524
+ tgt_key_padding_mask=tgt_key_padding_mask,
525
+ memory_key_padding_mask=memory_key_padding_mask,
526
+ pos=pos, query_pos=query_pos)
527
+ if self.return_intermediate:
528
+ intermediate.append(self.norm(output))
529
+
530
+ if self.norm is not None:
531
+ output = self.norm(output)
532
+ if self.return_intermediate:
533
+ intermediate.pop()
534
+ intermediate.append(output)
535
+
536
+ if self.return_intermediate:
537
+ return torch.stack(intermediate)
538
+
539
+ return output.unsqueeze(0)
540
+
541
+
542
+ class TransformerEncoderLayer(nn.Module):
543
+
544
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
545
+ activation="relu", normalize_before=False):
546
+ super().__init__()
547
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
548
+ # Implementation of Feedforward model
549
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
550
+ self.dropout = nn.Dropout(dropout)
551
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
552
+
553
+ self.norm1 = nn.LayerNorm(d_model)
554
+ self.norm2 = nn.LayerNorm(d_model)
555
+ self.dropout1 = nn.Dropout(dropout)
556
+ self.dropout2 = nn.Dropout(dropout)
557
+
558
+ self.activation = _get_activation_fn(activation)
559
+ self.normalize_before = normalize_before
560
+
561
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
562
+ return tensor if pos is None else tensor + pos
563
+
564
+ def forward_post(self,
565
+ src,
566
+ src_mask: Optional[Tensor] = None,
567
+ src_key_padding_mask: Optional[Tensor] = None,
568
+ pos: Optional[Tensor] = None):
569
+ q = k = self.with_pos_embed(src, pos)
570
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
571
+ key_padding_mask=src_key_padding_mask)[0]
572
+ src = src + self.dropout1(src2)
573
+ src = self.norm1(src)
574
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
575
+ src = src + self.dropout2(src2)
576
+ src = self.norm2(src)
577
+ return src
578
+
579
+ def forward_pre(self, src,
580
+ src_mask: Optional[Tensor] = None,
581
+ src_key_padding_mask: Optional[Tensor] = None,
582
+ pos: Optional[Tensor] = None):
583
+ src2 = self.norm1(src)
584
+ q = k = self.with_pos_embed(src2, pos)
585
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
586
+ key_padding_mask=src_key_padding_mask)[0]
587
+ src = src + self.dropout1(src2)
588
+ src2 = self.norm2(src)
589
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
590
+ src = src + self.dropout2(src2)
591
+ return src
592
+
593
+ def forward(self, src,
594
+ src_mask: Optional[Tensor] = None,
595
+ src_key_padding_mask: Optional[Tensor] = None,
596
+ pos: Optional[Tensor] = None):
597
+ if self.normalize_before:
598
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
599
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
600
+
601
+
602
+ class TransformerDecoderLayer(nn.Module):
603
+
604
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
605
+ activation="relu", normalize_before=False):
606
+ super().__init__()
607
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
608
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
609
+ # Implementation of Feedforward model
610
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
611
+ self.dropout = nn.Dropout(dropout)
612
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
613
+
614
+ self.norm1 = nn.LayerNorm(d_model)
615
+ self.norm2 = nn.LayerNorm(d_model)
616
+ self.norm3 = nn.LayerNorm(d_model)
617
+ self.dropout1 = nn.Dropout(dropout)
618
+ self.dropout2 = nn.Dropout(dropout)
619
+ self.dropout3 = nn.Dropout(dropout)
620
+
621
+ self.activation = _get_activation_fn(activation)
622
+ self.normalize_before = normalize_before
623
+
624
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
625
+ return tensor if pos is None else tensor + pos
626
+
627
+ def forward_post(self, tgt, memory,
628
+ tgt_mask: Optional[Tensor] = None,
629
+ memory_mask: Optional[Tensor] = None,
630
+ tgt_key_padding_mask: Optional[Tensor] = None,
631
+ memory_key_padding_mask: Optional[Tensor] = None,
632
+ pos: Optional[Tensor] = None,
633
+ query_pos: Optional[Tensor] = None):
634
+ q = k = self.with_pos_embed(tgt, query_pos)
635
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
636
+ key_padding_mask=tgt_key_padding_mask)[0]
637
+ tgt = tgt + self.dropout1(tgt2)
638
+ tgt = self.norm1(tgt)
639
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
640
+ key=self.with_pos_embed(memory, pos),
641
+ value=memory, attn_mask=memory_mask,
642
+ key_padding_mask=memory_key_padding_mask)[0]
643
+ tgt = tgt + self.dropout2(tgt2)
644
+ tgt = self.norm2(tgt)
645
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
646
+ tgt = tgt + self.dropout3(tgt2)
647
+ tgt = self.norm3(tgt)
648
+ return tgt
649
+
650
+ def forward_pre(self, tgt, memory,
651
+ tgt_mask: Optional[Tensor] = None,
652
+ memory_mask: Optional[Tensor] = None,
653
+ tgt_key_padding_mask: Optional[Tensor] = None,
654
+ memory_key_padding_mask: Optional[Tensor] = None,
655
+ pos: Optional[Tensor] = None,
656
+ query_pos: Optional[Tensor] = None):
657
+ tgt2 = self.norm1(tgt)
658
+ q = k = self.with_pos_embed(tgt2, query_pos)
659
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
660
+ key_padding_mask=tgt_key_padding_mask)[0]
661
+ tgt = tgt + self.dropout1(tgt2)
662
+ tgt2 = self.norm2(tgt)
663
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
664
+ key=self.with_pos_embed(memory, pos),
665
+ value=memory, attn_mask=memory_mask,
666
+ key_padding_mask=memory_key_padding_mask)[0]
667
+ tgt = tgt + self.dropout2(tgt2)
668
+ tgt2 = self.norm3(tgt)
669
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
670
+ tgt = tgt + self.dropout3(tgt2)
671
+ return tgt
672
+
673
+ def forward(self, tgt, memory,
674
+ tgt_mask: Optional[Tensor] = None,
675
+ memory_mask: Optional[Tensor] = None,
676
+ tgt_key_padding_mask: Optional[Tensor] = None,
677
+ memory_key_padding_mask: Optional[Tensor] = None,
678
+ pos: Optional[Tensor] = None,
679
+ query_pos: Optional[Tensor] = None):
680
+ if self.normalize_before:
681
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
682
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
683
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
684
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
685
+
686
+
687
+ def _get_clones(module, N):
688
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
689
+
690
+
691
+ def build_transformer(args):
692
+ return Transformer(
693
+ d_model=args.hidden_dim,
694
+ dropout=args.dropout,
695
+ nhead=args.nheads,
696
+ dim_feedforward=args.dim_feedforward,
697
+ num_encoder_layers=args.enc_layers,
698
+ num_decoder_layers=args.dec_layers,
699
+ normalize_before=args.pre_norm,
700
+ return_intermediate_dec=True,
701
+ )
702
+
703
+
704
+ def _get_activation_fn(activation):
705
+ """Return an activation function given a string"""
706
+ if activation == "relu":
707
+ return F.relu
708
+ if activation == "gelu":
709
+ return F.gelu
710
+ if activation == "glu":
711
+ return F.glu
712
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
713
+
714
+
715
+ class LLMEyeCap(nn.Module): # Im Novel Object Captioning V 0.1
716
+
717
+ def __init__(self, backbone, transformer, num_queries, vocab_size,pad_token):
718
+
719
+ super().__init__()
720
+ self.num_queries = num_queries
721
+ self.transformer = transformer
722
+ self.hidden_dim = transformer.d_model
723
+
724
+ self.caption_embed = nn.Linear(self.hidden_dim, vocab_size)
725
+ self.bbox_embed = MLP(self.hidden_dim, self.hidden_dim, 4, 3)
726
+
727
+ self.query_embed = nn.Embedding(num_queries, self.hidden_dim)
728
+ self.input_proj = nn.Conv2d(backbone.num_channels, self.hidden_dim, kernel_size=1)
729
+ self.backbone = backbone
730
+ '''
731
+ self.capdecoder = CaptioningDecoder(detr_decoder_dim=transformer.d_model, token_embedding_dim=transformer.d_model,
732
+ vocab_size=vocab_size, num_queries=num_queries, num_layers=6)
733
+ '''
734
+ self.capdecoder = CaptionDecoder(feature_size, token_size, vocab_size,num_queries,pad_token ).to(device)
735
+
736
+
737
+ def forward(self, samples: NestedTensor, captions):
738
+
739
+ if isinstance(samples, (list, torch.Tensor)):
740
+ samples = nested_tensor_from_tensor_list(samples)
741
+
742
+ features, pos = self.backbone(samples) # features + position embedding
743
+ src, mask = features[-1].decompose()
744
+ assert mask is not None
745
+ hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
746
+ outputs_coord = self.bbox_embed(hs).sigmoid()
747
+
748
+ outputs_captions=self.capdecoder(hs,captions)
749
+ # predicted_sequences = torch.argmax(outputs_captions, dim=-1)
750
+
751
+ out = {'pred_logits': outputs_captions , 'pred_boxes': outputs_coord[-1]}
752
+ return out
753
+
754
+ def generate_caption(self, image_path, tokenizer, max_length, pad_sos):
755
+
756
+ image = Image.open(image_path).convert('RGB')
757
+ transform = transforms.Compose([
758
+ transforms.Resize((256, 256)),
759
+ transforms.ToTensor(),
760
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
761
+ ])
762
+
763
+ image = transform(image).unsqueeze(0).to(device)
764
+
765
+ if isinstance(image, (list, torch.Tensor)):
766
+ image = nested_tensor_from_tensor_list(image)
767
+
768
+ with torch.no_grad():
769
+ features, pos = self.backbone(image) # features + position embedding
770
+ src, mask = features[-1].decompose()
771
+ assert mask is not None
772
+
773
+ hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
774
+ outputs_coord = self.bbox_embed(hs).sigmoid()
775
+
776
+ input_ids = torch.ones((1, 40, 1), dtype=torch.long, device=device)
777
+ input_ids.fill_(pad_sos)
778
+
779
+
780
+ for i in range(max_length):
781
+ outputs_captions = self.capdecoder(hs, input_ids)
782
+ predicted_sequences = torch.argmax(outputs_captions, dim=-1)
783
+ next_token = predicted_sequences[:, :, -1:] # take the last token from the sequence
784
+ input_ids = torch.cat((input_ids, next_token), dim=-1)
785
+
786
+ #caption = tokenizer.detokenize(input_ids[0].tolist()) #, skip_special_tokens=True)
787
+
788
+ return outputs_coord[-1], input_ids # caption[-1]
789
+
790
+ class LLMEyeCapModel(nn.Module):
791
+ def __init__(self, num_queries,vocab_size,pad_token):
792
+ super(LLMEyeCapModel,self).__init__()
793
+ self.num_queries = num_queries
794
+ self.vocab_size=vocab_size
795
+ self.backbone = build_backbone(args)
796
+ self.transformer = build_transformer(args)
797
+
798
+ self.model = LLMEyeCap(
799
+ self.backbone,
800
+ self.transformer,
801
+ num_queries=self.num_queries,
802
+ vocab_size=self.vocab_size,
803
+ pad_token=pad_token
804
+ )
805
+
806
+ # self.in_features = self.caption_embed.in_features
807
+
808
+ # self.model.class_embed = nn.Linear(in_features=self.in_features,out_features=self.num_classes)
809
+
810
+ self.model.num_queries = self.num_queries
811
+
812
+ def forward(self,images,captions):
813
+ return self.model(images,captions)
814
+
815
+ def generate_caption(self, image_path, tokenizer, max_length=20,pad_sos=0):
816
+ return self.model.generate_caption(image_path, tokenizer, max_length,pad_sos)
817
+
818
+ class CaptionDecoder(nn.Module):
819
+ def __init__(self, detr_decoder_dim, token_embedding_dim, vocab_size, num_queries, pad_token, num_layers=6):
820
+ super(CaptionDecoder, self).__init__()
821
+
822
+ self.detr_decoder_dim = detr_decoder_dim
823
+ self.token_embedding_dim = token_embedding_dim
824
+ self.vocab_size = vocab_size
825
+ self.num_queries = num_queries
826
+ self.pad_token = pad_token
827
+
828
+ # Token embedding layer
829
+ self.token_embedding = nn.Embedding(vocab_size, token_embedding_dim)
830
+
831
+ # Initialize GPT-2
832
+ config = GPT2Config(vocab_size=vocab_size, n_embd=detr_decoder_dim + token_embedding_dim, n_head=8 )
833
+ self.gpt2 = GPT2LMHeadModel(config)
834
+
835
+ self.target_projection = nn.Linear(token_embedding_dim, detr_decoder_dim + token_embedding_dim)
836
+
837
+ def forward(self, detr_output, captions):
838
+
839
+
840
+ # Create an attention mask with shape [batch_size, num_queries, sequence_length]
841
+ attention_mask = (captions != self.pad_token).float().to(captions.device) # [batch_size, num_queries, sequence_length]
842
+
843
+
844
+ seq_length = captions.size(2)
845
+ pos_encoding = get_sinusoid_encoding_table(seq_length, self.token_embedding_dim).to(captions.device)
846
+ pos_encoding = pos_encoding.unsqueeze(0).repeat(captions.size(0) * self.num_queries, 1, 1)
847
+
848
+ # Get the last layer's output from the DETR decoder
849
+ spatial_embedding = detr_output[-1] # [batch_size, num_queries, detr_decoder_dim]
850
+
851
+ # Get token embeddings
852
+ token_embeddings = self.token_embedding(captions) # [batch_size, num_queries, seq_length, token_embedding_dim]
853
+
854
+ # Repeat the spatial embedding for each token in the sequence and concatenate
855
+ spatial_embedding = spatial_embedding.unsqueeze(2) # Add seq_length dimension: [batch_size, num_queries, 1, detr_decoder_dim]
856
+ combined_embedding = torch.cat([spatial_embedding.repeat(1, 1, token_embeddings.size(2), 1), token_embeddings], dim=-1)
857
+ # combined_embedding shape: [batch_size, num_queries, seq_length, detr_decoder_dim + token_embedding_dim]
858
+
859
+ # Prepare the memory for the transformer decoder
860
+ memory = combined_embedding.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.detr_decoder_dim + self.token_embedding_dim)
861
+ # memory shape: [seq_length, batch_size*num_queries, detr_decoder_dim + token_embedding_dim]
862
+
863
+ # Prepare the target for the transformer decoder (using token embeddings)
864
+ target = token_embeddings.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.token_embedding_dim)
865
+ # target shape: [seq_length, batch_size*num_queries, token_embedding_dim]
866
+
867
+
868
+ pos_encoding = pos_encoding.permute(1, 0, 2)
869
+ target += pos_encoding
870
+
871
+
872
+ # Project target to the required dimension
873
+
874
+ target = self.target_projection(target)
875
+
876
+ attention_mask = attention_mask.permute(2, 0, 1).reshape(captions.size(2), -1)
877
+ tgt_key_padding_mask = (attention_mask == 0).permute(1,0)
878
+
879
+ # Prepare the inputs for GPT-2
880
+ inputs_embeds = combined_embedding.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.detr_decoder_dim + self.token_embedding_dim)
881
+
882
+ # Reshape attention_mask for GPT-2. Flatten the batch_size and num_queries dimensions.
883
+ attention_mask = attention_mask.reshape(-1, captions.size(2)) # New shape: [batch_size * num_queries, sequence_length]
884
+
885
+ # Pass through GPT-2
886
+ outputs = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
887
+ logits = outputs.logits
888
+
889
+ # Reshape logits to match the original shape
890
+ logits = logits.view(captions.size(2), captions.size(0), self.num_queries, self.vocab_size).permute(1, 2, 0, 3)
891
+
892
+ return logits
893
+
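To make the shape bookkeeping in CaptionDecoder.forward easier to follow, here is a self-contained sketch of the embedding combination it performs: each query's spatial embedding is broadcast along the token axis, concatenated with the token embeddings, and flattened into one sequence per (image, query) pair. All sizes below are illustrative, not taken from a real run:

import torch

batch, queries, seq, d_spatial, d_token = 2, 40, 12, 256, 256

spatial = torch.randn(batch, queries, d_spatial)      # last DETR decoder layer output
tokens = torch.randn(batch, queries, seq, d_token)    # embedded caption tokens

# Broadcast the spatial embedding over the token axis and concatenate along the feature axis.
combined = torch.cat([spatial.unsqueeze(2).repeat(1, 1, seq, 1), tokens], dim=-1)
print(combined.shape)  # torch.Size([2, 40, 12, 512])

# Flattened the way the decoder feeds GPT-2: one "sentence" per (image, query) pair.
inputs_embeds = combined.permute(2, 0, 1, 3).reshape(seq, batch * queries, d_spatial + d_token)
print(inputs_embeds.shape)  # torch.Size([12, 80, 512])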
train.py ADDED
@@ -0,0 +1,642 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torchvision.models import resnet50
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ from transformers import BertTokenizer, BertModel
9
+ import os
10
+ import json
11
+ import numpy as np
12
+ from collections import defaultdict
13
+ import random
14
+ from tqdm.notebook import tqdm
15
+ from torchvision import models
16
+ from torch.nn.utils.rnn import pad_sequence
17
+ import matplotlib.patches as patches
18
+
19
+ import math
20
+ import time
21
+ import os
22
+ from PIL import Image
23
+ import requests
24
+ import nltk
25
+
26
+ import os
27
+ import cv2
28
+ import colorsys
29
+ from numpy import asarray
30
+ import math
31
+
32
+
33
+ from transformers import GPT2LMHeadModel, GPT2Config
34
+
35
+ from transformers import BertTokenizer
36
+
37
+
38
+ from scipy.optimize import linear_sum_assignment
39
+
40
+
41
+
42
+
43
+ class CocoDataset(Dataset):
44
+ def __init__(self, root_dir, annotation_file, instance_file, max_objects=40, transform=None):
45
+ self.root_dir = root_dir
46
+ self.transform = transform
47
+ self.max_objects = max_objects
48
+ self.img_cache = dict() # Cache for images
49
+
50
+ # Load instance file only once
51
+ with open(instance_file, 'r') as file:
52
+ data = json.load(file)
53
+ instances = data['annotations']
54
+ categories = data['categories']
55
+
56
+ with open(annotation_file, 'r') as file:
57
+ annotations = json.load(file)['annotations']
58
+
59
+ self.image_captions = defaultdict(list)
60
+ for annotation in annotations:
61
+ img_id = annotation['image_id']
62
+ self.image_captions[img_id].append(annotation['caption'])
63
+
64
+ self.image_instances = defaultdict(list)
65
+ self.category_id_to_name = {category['id']: category['name'] for category in categories}
66
+
67
+ for instance in instances:
68
+ img_id = instance['image_id']
69
+ bbox = instance['bbox']
70
+ category_id = instance['category_id']
71
+ self.image_instances[img_id].append((bbox, category_id))
72
+
73
+ self.img_ids = list(self.image_captions.keys())
74
+
75
+ def __len__(self):
76
+ return len(self.img_ids)
77
+
78
+ def __getitem__(self, index):
79
+ img_id = self.img_ids[index]
80
+ img_path = os.path.join(self.root_dir, f'{str(img_id).zfill(12)}.jpg')
81
+
82
+ # Use cached image if available
83
+
84
+ if img_id in self.img_cache:
85
+ img = self.img_cache[img_id]
86
+ else:
87
+ img = Image.open(img_path).convert("RGB")
88
+ self.img_cache[img_id] = img
89
+
90
+
91
+ captions = self.image_captions[img_id]
92
+ caption = random.choice(captions)
93
+
94
+ annotations = self.image_instances[img_id]
95
+ bboxes = []
96
+ labels = []
97
+ for obbox, label_id in annotations:
98
+ bbox = torch.tensor(obbox) # Convert to PyTorch tensor immediately
99
+ bbox[0] = bbox[0] / img.width + (bbox[2] / img.width)/2
100
+ bbox[1] = bbox[1] / img.height + (bbox[3] / img.height)/2
101
+ bbox[2] = bbox[2] / img.width
102
+ bbox[3] = bbox[3] / img.height
103
+ label = self.category_id_to_name[label_id]
104
+ bboxes.append(bbox)
105
+ labels.append(label)
106
+
107
+ bboxes.append(torch.tensor([0.5, 0.5, 1, 1]))
108
+ labels.append(caption)
109
+
110
+ total_boxes = len(bboxes)
111
+
112
+ if total_boxes < 40:
113
+ for _ in range(40-total_boxes):
114
+ bboxes.append(torch.tensor([0, 0, 0 ,0]))
115
+ labels.append("na")
116
+ else:
117
+ bboxes = bboxes[:40]
118
+ labels = labels[:40]
119
+
120
+ if self.transform:
121
+ img = self.transform(img)
122
+
123
+ return img, bboxes, labels
124
+
125
+ # Define the image transforms
126
+ transform = transforms.Compose([
127
+ transforms.Resize((256, 256)),
128
+ transforms.ToTensor(),
129
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
130
+ ])
131
+
132
+
133
+ def custom_collate(batch):
134
+ images, boxes_list, labels_list = zip(*batch)
135
+
136
+ # Convert list of PIL images to a single PyTorch tensor
137
+ stacked_images = torch.stack(images)
138
+
139
+ # Convert list of list of boxes to a list of PyTorch tensors
140
+ stacked_boxes = [torch.stack([box.clone().detach() for box in boxes]) for boxes in boxes_list]
141
+
142
+
143
+ # Since labels are strings, we can keep them as a list of lists
144
+ # labels_list is already in the desired format
145
+
146
+ return stacked_images, stacked_boxes, labels_list
147
+
148
+
149
+
150
+
151
+
152
+
153
+
154
+
155
+ def train_fn(data_loader, model, criterion, optimizer, device, scheduler, epoch):
156
+ model.train()
157
+ criterion.train()
158
+ summary_loss = AverageMeter()
159
+
160
+ tk0 = tqdm(data_loader, total=len(data_loader)-1)
161
+
162
+ for step, (images, bboxes, captions) in enumerate(tk0):
163
+
164
+ try:
165
+ flattened_captions = [caption for sublist in captions for caption in sublist]
166
+ captions = tokenizer(flattened_captions, padding=True, return_tensors="pt", truncation=True)
167
+ captions = captions["input_ids"]
168
+ input_ids = captions.reshape(batch_size, num_queries, -1).to(device)
169
+ min_length = 2
170
+ except RuntimeError as e:
171
+ print("Reshape failed:", e)
172
+ continue
173
+
174
+ '''
175
+ min_length = 2
176
+ if input_ids.size(-1) < min_length:
177
+ padding_needed = min_length - input_ids.size(-1)
178
+ input_ids = F.pad(input_ids, (0, padding_needed), 'constant', PAD_TOKEN)
179
+
180
+ targets = build_targets(bboxes, input_ids[:, :, 1:])
181
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
182
+
183
+ images = list(image.to(device) for image in images)
184
+
185
+
186
+ output = model(images,input_ids[:, :,:-1])
187
+ '''
188
+
189
+ min_length = 2
190
+ if input_ids.size(-1) < min_length:
191
+ padding_needed = min_length - input_ids.size(-1)
192
+ input_ids = F.pad(input_ids, (0, padding_needed), 'constant', PAD_TOKEN)
193
+
194
+ # input_ids = captions["input_ids"]
195
+ # input_ids = input_ids.reshape(batch_size, num_queries, -1).to(device)
196
+
197
+ targets = build_targets(bboxes, input_ids[:, :, 1:])
198
+
199
+ #targets = build_targets(bboxes, captions[:,:,1:])
200
+
201
+ images = list(image.to(device) for image in images)
202
+
203
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
204
+
205
+
206
+ output = model(images,input_ids[:,:,:-1])
207
+
208
+ loss_dict = criterion(output, targets)
209
+ weight_dict = criterion.weight_dict
210
+
211
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
212
+
213
+ optimizer.zero_grad()
214
+ losses.backward()
215
+ optimizer.step()
216
+
217
+ if scheduler is not None:
218
+ scheduler.step()
219
+
220
+ # Detach and delete tensors
221
+ loss_dict = {k: v.detach() for k, v in loss_dict.items()}
222
+
223
+ del images, bboxes, captions, output, targets, loss_dict
224
+ torch.cuda.empty_cache() # Clear cache
225
+
226
+ summary_loss.update(losses.item(),BATCH_SIZE)
227
+ tk0.set_postfix(loss=summary_loss.avg)
228
+
229
+
230
+ return summary_loss
231
+ class HungarianMatcher(nn.Module):
232
+ """This class computes an assignment between the targets and the predictions of the network
233
+
234
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
235
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
236
+ while the others are un-matched (and thus treated as non-objects).
237
+ """
238
+
239
+ def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
240
+ """Creates the matcher
241
+
242
+ Params:
243
+ cost_class: This is the relative weight of the classification error in the matching cost
244
+ cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
245
+ cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
246
+ """
247
+ super().__init__()
248
+ self.cost_class = cost_class
249
+ self.cost_bbox = cost_bbox
250
+ self.cost_giou = cost_giou
251
+ assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
252
+
253
+ @torch.no_grad()
254
+ def forward(self, outputs, targets):
255
+ bs, num_queries = outputs["pred_logits"].shape[:2]
256
+
257
+ # We flatten to compute the cost matrices in a batch
258
+ # out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
259
+
260
+ out_prob = outputs["pred_logits"].flatten(0,2 ).softmax(-1) # [batch_size * num_queries * seq_length, vocab_size ]
261
+
262
+ out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
263
+
264
+ # Also concat the target labels and boxes
265
+ tgt_ids = torch.cat([v["labels"] for v in targets])
266
+ tgt_bbox = torch.cat([v["boxes"] for v in targets])
267
+
268
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
269
+ # but approximate it in 1 - proba[target class].
270
+ # The 1 is a constant that doesn't change the matching, it can be omitted.
271
+
272
+ cost_class = -out_prob[:, tgt_ids]
273
+
274
+ # Compute the L1 cost between boxes
275
+ cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
276
+
277
+ # Compute the giou cost between boxes
278
+ cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
279
+
280
+ # Final cost matrix
281
+ C = self.cost_bbox * cost_bbox + self.cost_class * cost_class.mean() + self.cost_giou * cost_giou
282
+ #C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
283
+ C = C.view(bs, num_queries, -1).cpu()
284
+
285
+ sizes = [len(v["boxes"]) for v in targets]
286
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
287
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
288
+
289
+
290
+
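To see what the matcher returns, here is a toy, self-contained example of the assignment step it delegates to scipy (the costs are invented; in the real matcher they mix caption, L1 box, and GIoU costs):

import numpy as np
from scipy.optimize import linear_sum_assignment

# One image: 3 query predictions vs. 2 ground-truth objects.
cost = np.array([
    [0.9, 0.1],   # query 0 is cheap to match to target 1
    [0.4, 0.8],   # query 1 is cheap to match to target 0
    [0.7, 0.6],   # query 2 will stay unmatched (only 2 targets exist)
])
row_ind, col_ind = linear_sum_assignment(cost)
print([(int(r), int(c)) for r, c in zip(row_ind, col_ind)])  # [(0, 1), (1, 0)]; query 2 is treated as no-object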
291
+ def build_matcher(args):
292
+ return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)
293
+
294
+
295
+
296
+ class SetCriterion(nn.Module):
297
+ """ This class computes the loss for DETR.
298
+ The process happens in two steps:
299
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
300
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
301
+ """
302
+ def __init__(self, vocab_size, matcher, weight_dict, eos_coef, losses,pad_token):
303
+ """ Create the criterion.
304
+ Parameters:
305
+ vocab_size: size of the caption vocabulary, used in place of DETR's number of object categories (the special no-object category is omitted)
306
+ matcher: module able to compute a matching between targets and proposals
307
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
308
+ eos_coef: relative classification weight applied to the no-object category
309
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
310
+ """
311
+ super().__init__()
312
+ self.vocab_size = vocab_size
313
+ self.matcher = matcher
314
+ self.weight_dict = weight_dict
315
+ self.eos_coef = eos_coef
316
+ self.losses = losses
317
+ self.pad_token=pad_token
318
+ empty_weight = torch.ones(self.vocab_size)
319
+ # empty_weight[-1] = self.eos_coef
320
+ self.register_buffer('empty_weight', empty_weight)
321
+ self.criterion = nn.CrossEntropyLoss(ignore_index=pad_token)
322
+
323
+
324
+ def loss_labels(self, outputs, targets, indices, num_boxes, log=False):
325
+
326
+ """Classification loss (NLL) for sequences
327
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes, seq_length]
328
+ """
329
+ assert 'pred_logits' in outputs
330
+ src_logits = outputs['pred_logits']
331
+ batch_size, num_boxes , sequence_length, _ = src_logits.size()
332
+
333
+ # Get the indices for the permutation
334
+ batch_idx, src_idx = self._get_src_permutation_idx(indices)
335
+
336
+ target_classes = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
337
+
338
+ # Ensure the target classes are valid
339
+ assert (target_classes >= 0).all() and (target_classes < self.vocab_size).all(), "Invalid token index in target!"
340
+
341
+ # loss_ce = criterion(outputs.reshape(-1, vocab_size), captions.view(-1))
342
+ loss_ce = self.criterion(src_logits.reshape(batch_size * num_boxes * sequence_length, -1), target_classes.reshape(-1))
343
+
344
+
345
+ # loss_ce = torchmetrics.functional.smooth_cross_entropy(src_logits[batch_idx], target_classes, ignore_index=PAD_TOKEN)
346
+ losses = {'loss_ce': loss_ce}
347
+
348
+ return losses
349
+
350
+
351
+
352
+
353
+ '''
354
+ criterion = nn.CrossEntropyLoss(ignore_index=self.PAD_TOKEN)
355
+ loss_ce = criterion(src_logits, target_classes_for_loss)
356
+ losses = {'loss_ce': loss_ce}
357
+ '''
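For reference, a self-contained sketch of the flattened cross-entropy used in loss_labels above: logits for every (image, query, token) position are reshaped to 2-D and padding targets are excluded through ignore_index. The sizes and the PAD id are illustrative:

import torch
from torch import nn

batch, queries, seq, vocab, PAD = 2, 40, 12, 1000, 0

logits = torch.randn(batch, queries, seq, vocab)
targets = torch.randint(1, vocab, (batch, queries, seq))
targets[:, :, -3:] = PAD                                  # pretend the last tokens are padding

criterion = nn.CrossEntropyLoss(ignore_index=PAD)
loss = criterion(logits.reshape(batch * queries * seq, vocab), targets.reshape(-1))
print(loss.item())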
358
+
359
+ @torch.no_grad()
360
+ def loss_cardinality(self, outputs, targets, indices, num_boxes):
361
+ """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
362
+ This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
363
+ """
364
+ pred_logits = outputs['pred_logits']
365
+ device = pred_logits.device
366
+ tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
367
+ # Count the number of predictions that are NOT "no-object" (which is the last class)
368
+ card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
369
+
370
+ card_pred = card_pred.sum(dim=1)
371
+
372
+ card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
373
+ losses = {'cardinality_error': card_err}
374
+ return losses
375
+
376
+ def loss_boxes(self, outputs, targets, indices, num_boxes):
377
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
378
+ targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
379
+ The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
380
+ """
381
+ assert 'pred_boxes' in outputs
382
+ idx = self._get_src_permutation_idx(indices)
383
+
384
+ src_boxes = outputs['pred_boxes'][idx]
385
+ target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
386
+
387
+ loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
388
+
389
+ losses = {}
390
+ losses['loss_bbox'] = loss_bbox.sum() / num_boxes
391
+
392
+ loss_giou = 1 - torch.diag(generalized_box_iou(
393
+ box_cxcywh_to_xyxy(src_boxes),
394
+ box_cxcywh_to_xyxy(target_boxes)))
395
+ losses['loss_giou'] = loss_giou.sum() / num_boxes
396
+ return losses
397
+
398
+ def loss_masks(self, outputs, targets, indices, num_boxes):
399
+ """Compute the losses related to the masks: the focal loss and the dice loss.
400
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
401
+ """
402
+ assert "pred_masks" in outputs
403
+
404
+ src_idx = self._get_src_permutation_idx(indices)
405
+ tgt_idx = self._get_tgt_permutation_idx(indices)
406
+ src_masks = outputs["pred_masks"]
407
+ src_masks = src_masks[src_idx]
408
+ masks = [t["masks"] for t in targets]
409
+ # TODO use valid to mask invalid areas due to padding in loss
410
+ target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
411
+ target_masks = target_masks.to(src_masks)
412
+ target_masks = target_masks[tgt_idx]
413
+
414
+ # upsample predictions to the target size
415
+ src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
416
+ mode="bilinear", align_corners=False)
417
+ src_masks = src_masks[:, 0].flatten(1)
418
+
419
+ target_masks = target_masks.flatten(1)
420
+ target_masks = target_masks.view(src_masks.shape)
421
+ losses = {
422
+ "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
423
+ "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
424
+ }
425
+ return losses
426
+
427
+ def _get_src_permutation_idx(self, indices):
428
+ # permute predictions following indices
429
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
430
+ src_idx = torch.cat([src for (src, _) in indices])
431
+ return batch_idx, src_idx
432
+
433
+ def _get_tgt_permutation_idx(self, indices):
434
+ # permute targets following indices
435
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
436
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
437
+ return batch_idx, tgt_idx
438
+
439
+ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
440
+ loss_map = {
441
+ 'labels': self.loss_labels,
442
+ 'cardinality': self.loss_cardinality,
443
+ 'boxes': self.loss_boxes,
444
+ 'masks': self.loss_masks
445
+ }
446
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
447
+ return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
448
+
449
+ def forward(self, outputs, targets):
450
+ """ This performs the loss computation.
451
+ Parameters:
452
+ outputs: dict of tensors, see the output specification of the model for the format
453
+ targets: list of dicts, such that len(targets) == batch_size.
454
+ The expected keys in each dict depends on the losses applied, see each loss' doc
455
+ """
456
+
457
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
458
+
459
+ # Retrieve the matching between the outputs of the last layer and the targets
460
+ indices = self.matcher(outputs_without_aux, targets)
461
+
462
+ # print("indice len", len(indices), "len (indices[0]) ", len (indices[0]))
463
+ # print( " shape indices 0 0 ", indices [0][0].shape , " shape indices 0 1 ", indices [0][1].shape)
464
+
465
+ # Compute the average number of target boxes across all nodes, for normalization purposes
466
+ num_boxes = sum(len(t["labels"]) for t in targets)
467
+ num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
468
+ if is_dist_avail_and_initialized():
469
+ torch.distributed.all_reduce(num_boxes)
470
+ num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
471
+ # print("num_boxes",num_boxes)
472
+ # Compute all the requested losses
473
+ losses = {}
474
+ for loss in self.losses:
475
+ losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
476
+ '''
477
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
478
+ if 'aux_outputs' in outputs:
479
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
480
+ indices = self.matcher(aux_outputs, targets)
481
+ for loss in self.losses:
482
+ if loss == 'masks':
483
+ # Intermediate masks losses are too costly to compute, we ignore them.
484
+ continue
485
+ kwargs = {}
486
+ if loss == 'labels':
487
+ # Logging is enabled only for the last layer
488
+ kwargs = {'log': False}
489
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
490
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
491
+ losses.update(l_dict)
492
+ '''
493
+ return losses
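+
+ # The returned dict holds one entry per requested loss (e.g. 'loss_ce', 'loss_bbox', 'loss_giou',
+ # plus any logging-only terms absent from weight_dict); the caller is expected to weight and sum them:
+ #   total = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
+ # as done in eval_fn below.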
494
+
495
+ def eval_fn(data_loader, model, criterion, device):
496
+ model.eval()
497
+ criterion.eval()
498
+ summary_loss = AverageMeter()
499
+
500
+ with torch.no_grad():
501
+
502
+ #tk0 = tqdm(data_loader, total=len(data_loader))
503
+ #for step, (images, bboxes, captions) in enumerate(tk0):
504
+ #pbar = tqdm(range(len(data_loader)))
505
+
506
+ tk0 = tqdm(data_loader, total=len(data_loader))
507
+ for step, (images, bboxes, captions) in enumerate(tk0):
508
+
509
+ try:
510
+ flattened_captions = [caption for sublist in captions for caption in sublist]
511
+ captions = tokenizer(flattened_captions, padding=True, return_tensors="pt", truncation=True)
512
+ captions = captions["input_ids"]
513
+ input_ids = captions.reshape(batch_size, num_queries, -1).to(device)
514
+ min_length = 2
515
+ except RuntimeError as e:
516
+ print("Reshape failed:", e)
517
+ continue
518
+
519
+ if input_ids.size(-1) < min_length:
520
+ padding_needed = min_length - input_ids.size(-1)
521
+ input_ids = F.pad(input_ids, (0, padding_needed), 'constant', PAD_TOKEN)
522
+
523
+ # input_ids = captions["input_ids"]
524
+ # input_ids = input_ids.reshape(batch_size, num_queries, -1).to(device)
525
+
526
+ targets = build_targets(bboxes, input_ids[:, :, 1:])
527
+
528
+ #targets = build_targets(bboxes, captions[:,:,1:])
529
+
530
+ images = list(image.to(device) for image in images)
531
+
532
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
533
+
534
+
535
+ output = model(images, input_ids[:, :, :-1])
536
+
537
+
538
+ loss_dict = criterion(output, targets)
539
+ weight_dict = criterion.weight_dict
540
+
541
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
542
+
543
+
544
+ summary_loss.update(losses.item(), BATCH_SIZE)
545
+
546
+ #
547
+
548
+ # Detach and delete tensors
549
+ loss_dict = {k: v.detach() for k, v in loss_dict.items()}
550
+
551
+ del images, bboxes, captions, output, targets, loss_dict
552
+ torch.cuda.empty_cache() # Clear cache
553
+
554
+ tk0.set_postfix(loss=summary_loss.avg)
555
+ #data_loader.on_epoch_end()
556
+
557
+ return summary_loss
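+
+ # Note: eval_fn relies on module-level names: tokenizer, batch_size, num_queries and PAD_TOKEN are
+ # set in the __main__ block below, while BATCH_SIZE and the AverageMeter helper are assumed to be
+ # defined earlier in this file, so it is meant to be run from this script rather than imported alone.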
558
+
559
+ def build_targets(bboxes, captions):
560
+ targets = []
561
+ for i, (bbox, caption) in enumerate(zip(bboxes, captions)):
562
+ target = {
563
+ "boxes": bbox,
564
+ "labels": caption,
565
+ }
566
+ targets.append(target)
567
+ return targets
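+
+ # Expected structure (a sketch, assuming normalised cxcywh boxes and tokenised captions as labels):
+ #   targets[i] = {"boxes":  FloatTensor[num_queries, 4],        # (cx, cy, w, h) for image i
+ #                 "labels": LongTensor[num_queries, seq_len]}   # caption token ids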
568
+
569
+ if __name__ == "__main__":
570
+
571
+ # Create the datasets
572
+ train_dataset = CocoDataset(root_dir="../data/coco91/train2017",
573
+ annotation_file="../data/coco91/annotations/captions_train2017.json",
574
+ instance_file="../data/coco91/annotations/instances_train2017.json",
575
+ transform=transform)
576
+ val_dataset = CocoDataset(root_dir="../data/coco91/val2017", annotation_file="../data/coco91/annotations/captions_val2017.json",
577
+ instance_file="../data/coco91/annotations/instances_val2017.json",
578
+ transform=transform)
579
+
580
+
581
+ batch_size=4
582
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
583
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=custom_collate)
584
+
585
+ # Initialize the BERT tokenizer
586
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
587
+
588
+ # Get the padding token and its ID
589
+ #PAD_TOKEN = tokenizer.pad_token
590
+ PAD_TOKEN = tokenizer.pad_token_id
591
+
592
+ # Get the start-of-sequence token and its ID
593
+ # For BERT, the start-of-sequence token is usually the [CLS] token
594
+ #start_of_sequence_token = tokenizer.cls_token
595
+ PAD_SOS = tokenizer.cls_token_id
596
+
597
+ # Get the vocabulary size
598
+ vocab_size = tokenizer.vocab_size
599
+
600
+ print(f"Pad token: {PAD_TOKEN}")
601
+ print(f"Start of Sequence token: {PAD_SOS}, ID: {PAD_SOS}")
602
+ print(f"Vocab size: {vocab_size}")
603
+
604
+ matcher = HungarianMatcher()
605
+
606
+ weight_dict = {'loss_ce': 1, 'loss_bbox': 1, 'loss_giou': 1}
607
+
608
+ losses = ['labels', 'boxes', 'cardinality']
609
+
610
+ # criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)  # unused here: superseded by SetCriterion below
611
+
612
+ model = LLMEyaCapModel(num_queries=NUM_QUERIES, vocab_size=vocab_size)
613
+ model = model.to(device)
614
+
615
+ criterion = SetCriterion(vocab_size, matcher=matcher, weight_dict=weight_dict, eos_coef = NULL_CLASS_COEF, losses=losses)
616
+ criterion = criterion.to(device)
617
+
618
+ LR = 2e-6
619
+ #LR = 2e-4
620
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LR)  #, weight_decay=0.0001)
621
+
622
+ best_loss = 10**5
623
+
624
+
625
+ EPOCHS=1
626
+ num_queries=NUM_QUERIES
627
+ batch_size=4
628
+
629
+ for epoch in range(EPOCHS):
630
+ time_start = time.time()
631
+ train_loss = train_fn(train_loader, model, criterion, optimizer, device, scheduler=None, epoch=epoch)
632
+ valid_loss = eval_fn(val_loader, model, criterion, device)
633
+
634
+ elapsed = time.time() - time_start
635
+ chk_name = f'LLMEyeCap_01_e{epoch}.bin'
636
+ torch.save(model.state_dict(), chk_name)
637
+ print(f"[Epoch {epoch+1:2d} / {EPOCHS:2d}] Train loss: {train_loss.avg:.3f}. Val loss: {valid_loss.avg:.3f} --> {chk_name} [{elapsed/60:.0f} mins]")
638
+
639
+ if valid_loss.avg < best_loss:
640
+ best_loss = valid_loss.avg
641
+ print(f'Best model found in epoch {epoch+1}........Saving Model')
642
+ torch.save(model.state_dict(), 'LLMEyeCap_01_model.bin')
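+
+ # Minimal reload sketch (illustrative, mirroring the construction above; not part of the training run):
+ #   model = LLMEyaCapModel(num_queries=NUM_QUERIES, vocab_size=vocab_size)
+ #   model.load_state_dict(torch.load('LLMEyeCap_01_model.bin', map_location=device))
+ #   model.eval()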
tuto.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,569 @@
 
1
+ from torchvision.ops.boxes import box_area
2
+
3
+
4
+ def box_cxcywh_to_xyxy(x):
5
+ x_c, y_c, w, h = x.unbind(-1)
6
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
7
+ (x_c + 0.5 * w), (y_c + 0.5 * h)]
8
+ return torch.stack(b, dim=-1)
9
+
10
+
11
+ def box_xyxy_to_cxcywh(x):
12
+ x0, y0, x1, y1 = x.unbind(-1)
13
+ b = [(x0 + x1) / 2, (y0 + y1) / 2,
14
+ (x1 - x0), (y1 - y0)]
15
+ return torch.stack(b, dim=-1)
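+
+ # Quick round-trip example (illustrative):
+ #   box_cxcywh_to_xyxy(torch.tensor([0.5, 0.5, 0.2, 0.4]))  ->  tensor([0.4, 0.3, 0.6, 0.7])
+ #   box_xyxy_to_cxcywh(torch.tensor([0.4, 0.3, 0.6, 0.7]))  ->  tensor([0.5, 0.5, 0.2, 0.4])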
16
+
17
+
18
+ # modified from torchvision to also return the union
19
+ def box_iou_2(boxes1, boxes2):
20
+ area1 = box_area(boxes1)
21
+ area2 = box_area(boxes2)
22
+
23
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
24
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
25
+
26
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
27
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
28
+
29
+ union = area1[:, None] + area2 - inter
30
+
31
+ iou = inter / union
32
+ return iou, union
33
+
34
+
35
+ def generalized_box_iou(boxes1, boxes2):
36
+ """
37
+ Generalized IoU from https://giou.stanford.edu/
38
+
39
+ The boxes should be in [x0, y0, x1, y1] format
40
+
41
+ Returns a [N, M] pairwise matrix, where N = len(boxes1)
42
+ and M = len(boxes2)
43
+ """
44
+ # degenerate boxes gives inf / nan results
45
+ # so do an early check
46
+ assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
47
+ assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
48
+ iou, union = box_iou_2(boxes1, boxes2)
49
+
50
+ lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
51
+ rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
52
+
53
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
54
+ area = wh[:, :, 0] * wh[:, :, 1]
55
+
56
+ return iou - (area - union) / area
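+
+ # In formula form: GIoU(A, B) = IoU(A, B) - |C \ (A ∪ B)| / |C|, where C is the smallest box enclosing
+ # A and B; values lie in [-1, 1], with 1 for a perfect match. Here `area` is |C| and `union` is |A ∪ B|,
+ # hence the returned matrix iou - (area - union) / area.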
57
+
58
+
59
+ def masks_to_boxes(masks):
60
+ """Compute the bounding boxes around the provided masks
61
+
62
+ The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
63
+
64
+ Returns a [N, 4] tensors, with the boxes in xyxy format
65
+ """
66
+ if masks.numel() == 0:
67
+ return torch.zeros((0, 4), device=masks.device)
68
+
69
+ h, w = masks.shape[-2:]
70
+
71
+ y = torch.arange(0, h, dtype=torch.float)
72
+ x = torch.arange(0, w, dtype=torch.float)
73
+ y, x = torch.meshgrid(y, x)
74
+
75
+ x_mask = (masks * x.unsqueeze(0))
76
+ x_max = x_mask.flatten(1).max(-1)[0]
77
+ x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
78
+
79
+ y_mask = (masks * y.unsqueeze(0))
80
+ y_max = y_mask.flatten(1).max(-1)[0]
81
+ y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
82
+
83
+ return torch.stack([x_min, y_min, x_max, y_max], 1)
84
+ """
85
+ Misc functions, including distributed helpers.
86
+
87
+ Mostly copy-paste from torchvision references.
88
+ """
89
+ import os
90
+ import subprocess
91
+ import time
92
+ from collections import defaultdict, deque
93
+ import datetime
94
+ import pickle
95
+ from packaging import version
96
+ from typing import Optional, List
97
+
98
+ import torch
99
+ import torch.distributed as dist
100
+ from torch import Tensor
101
+
102
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
103
+ import torchvision
104
+ if version.parse(torchvision.__version__) < version.parse('0.7'):
105
+ from torchvision.ops import _new_empty_tensor
106
+ from torchvision.ops.misc import _output_size
107
+
108
+
109
+ class SmoothedValue(object):
110
+ """Track a series of values and provide access to smoothed values over a
111
+ window or the global series average.
112
+ """
113
+
114
+ def __init__(self, window_size=20, fmt=None):
115
+ if fmt is None:
116
+ fmt = "{median:.4f} ({global_avg:.4f})"
117
+ self.deque = deque(maxlen=window_size)
118
+ self.total = 0.0
119
+ self.count = 0
120
+ self.fmt = fmt
121
+
122
+ def update(self, value, n=1):
123
+ self.deque.append(value)
124
+ self.count += n
125
+ self.total += value * n
126
+
127
+ def synchronize_between_processes(self):
128
+ """
129
+ Warning: does not synchronize the deque!
130
+ """
131
+ if not is_dist_avail_and_initialized():
132
+ return
133
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
134
+ dist.barrier()
135
+ dist.all_reduce(t)
136
+ t = t.tolist()
137
+ self.count = int(t[0])
138
+ self.total = t[1]
139
+
140
+ @property
141
+ def median(self):
142
+ d = torch.tensor(list(self.deque))
143
+ return d.median().item()
144
+
145
+ @property
146
+ def avg(self):
147
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
148
+ return d.mean().item()
149
+
150
+ @property
151
+ def global_avg(self):
152
+ return self.total / self.count
153
+
154
+ @property
155
+ def max(self):
156
+ return max(self.deque)
157
+
158
+ @property
159
+ def value(self):
160
+ return self.deque[-1]
161
+
162
+ def __str__(self):
163
+ return self.fmt.format(
164
+ median=self.median,
165
+ avg=self.avg,
166
+ global_avg=self.global_avg,
167
+ max=self.max,
168
+ value=self.value)
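+
+ # Typical usage (illustrative):
+ #   meter = SmoothedValue(window_size=20)
+ #   meter.update(0.73); meter.update(0.68)
+ #   str(meter)  ->  "0.6800 (0.7050)"   # windowed median (torch.median picks the lower middle) and global average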
169
+
170
+
171
+ def all_gather(data):
172
+ """
173
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
174
+ Args:
175
+ data: any picklable object
176
+ Returns:
177
+ list[data]: list of data gathered from each rank
178
+ """
179
+ world_size = get_world_size()
180
+ if world_size == 1:
181
+ return [data]
182
+
183
+ # serialized to a Tensor
184
+ buffer = pickle.dumps(data)
185
+ storage = torch.ByteStorage.from_buffer(buffer)
186
+ tensor = torch.ByteTensor(storage).to("cuda")
187
+
188
+ # obtain Tensor size of each rank
189
+ local_size = torch.tensor([tensor.numel()], device="cuda")
190
+ size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
191
+ dist.all_gather(size_list, local_size)
192
+ size_list = [int(size.item()) for size in size_list]
193
+ max_size = max(size_list)
194
+
195
+ # receiving Tensor from all ranks
196
+ # we pad the tensor because torch all_gather does not support
197
+ # gathering tensors of different shapes
198
+ tensor_list = []
199
+ for _ in size_list:
200
+ tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
201
+ if local_size != max_size:
202
+ padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
203
+ tensor = torch.cat((tensor, padding), dim=0)
204
+ dist.all_gather(tensor_list, tensor)
205
+
206
+ data_list = []
207
+ for size, tensor in zip(size_list, tensor_list):
208
+ buffer = tensor.cpu().numpy().tobytes()[:size]
209
+ data_list.append(pickle.loads(buffer))
210
+
211
+ return data_list
212
+
213
+
214
+ def reduce_dict(input_dict, average=True):
215
+ """
216
+ Args:
217
+ input_dict (dict): all the values will be reduced
218
+ average (bool): whether to do average or sum
219
+ Reduce the values in the dictionary from all processes so that all processes
220
+ have the averaged results. Returns a dict with the same fields as
221
+ input_dict, after reduction.
222
+ """
223
+ world_size = get_world_size()
224
+ if world_size < 2:
225
+ return input_dict
226
+ with torch.no_grad():
227
+ names = []
228
+ values = []
229
+ # sort the keys so that they are consistent across processes
230
+ for k in sorted(input_dict.keys()):
231
+ names.append(k)
232
+ values.append(input_dict[k])
233
+ values = torch.stack(values, dim=0)
234
+ dist.all_reduce(values)
235
+ if average:
236
+ values /= world_size
237
+ reduced_dict = {k: v for k, v in zip(names, values)}
238
+ return reduced_dict
239
+
240
+
241
+ class MetricLogger(object):
242
+ def __init__(self, delimiter="\t"):
243
+ self.meters = defaultdict(SmoothedValue)
244
+ self.delimiter = delimiter
245
+
246
+ def update(self, **kwargs):
247
+ for k, v in kwargs.items():
248
+ if isinstance(v, torch.Tensor):
249
+ v = v.item()
250
+ assert isinstance(v, (float, int))
251
+ self.meters[k].update(v)
252
+
253
+ def __getattr__(self, attr):
254
+ if attr in self.meters:
255
+ return self.meters[attr]
256
+ if attr in self.__dict__:
257
+ return self.__dict__[attr]
258
+ raise AttributeError("'{}' object has no attribute '{}'".format(
259
+ type(self).__name__, attr))
260
+
261
+ def __str__(self):
262
+ loss_str = []
263
+ for name, meter in self.meters.items():
264
+ loss_str.append(
265
+ "{}: {}".format(name, str(meter))
266
+ )
267
+ return self.delimiter.join(loss_str)
268
+
269
+ def synchronize_between_processes(self):
270
+ for meter in self.meters.values():
271
+ meter.synchronize_between_processes()
272
+
273
+ def add_meter(self, name, meter):
274
+ self.meters[name] = meter
275
+
276
+ def log_every(self, iterable, print_freq, header=None):
277
+ i = 0
278
+ if not header:
279
+ header = ''
280
+ start_time = time.time()
281
+ end = time.time()
282
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
283
+ data_time = SmoothedValue(fmt='{avg:.4f}')
284
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
285
+ if torch.cuda.is_available():
286
+ log_msg = self.delimiter.join([
287
+ header,
288
+ '[{0' + space_fmt + '}/{1}]',
289
+ 'eta: {eta}',
290
+ '{meters}',
291
+ 'time: {time}',
292
+ 'data: {data}',
293
+ 'max mem: {memory:.0f}'
294
+ ])
295
+ else:
296
+ log_msg = self.delimiter.join([
297
+ header,
298
+ '[{0' + space_fmt + '}/{1}]',
299
+ 'eta: {eta}',
300
+ '{meters}',
301
+ 'time: {time}',
302
+ 'data: {data}'
303
+ ])
304
+ MB = 1024.0 * 1024.0
305
+ for obj in iterable:
306
+ data_time.update(time.time() - end)
307
+ yield obj
308
+ iter_time.update(time.time() - end)
309
+ if i % print_freq == 0 or i == len(iterable) - 1:
310
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
311
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
312
+ if torch.cuda.is_available():
313
+ print(log_msg.format(
314
+ i, len(iterable), eta=eta_string,
315
+ meters=str(self),
316
+ time=str(iter_time), data=str(data_time),
317
+ memory=torch.cuda.max_memory_allocated() / MB))
318
+ else:
319
+ print(log_msg.format(
320
+ i, len(iterable), eta=eta_string,
321
+ meters=str(self),
322
+ time=str(iter_time), data=str(data_time)))
323
+ i += 1
324
+ end = time.time()
325
+ total_time = time.time() - start_time
326
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
327
+ print('{} Total time: {} ({:.4f} s / it)'.format(
328
+ header, total_time_str, total_time / len(iterable)))
329
+
330
+
331
+ def get_sha():
332
+ cwd = os.path.dirname(os.path.abspath(__file__))
333
+
334
+ def _run(command):
335
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
336
+ sha = 'N/A'
337
+ diff = "clean"
338
+ branch = 'N/A'
339
+ try:
340
+ sha = _run(['git', 'rev-parse', 'HEAD'])
341
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
342
+ diff = _run(['git', 'diff-index', 'HEAD'])
343
+ diff = "has uncommited changes" if diff else "clean"
344
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
345
+ except Exception:
346
+ pass
347
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
348
+ return message
349
+
350
+
351
+ def collate_fn(batch):
352
+ batch = list(zip(*batch))
353
+ batch[0] = nested_tensor_from_tensor_list(batch[0])
354
+ return tuple(batch)
355
+
356
+
357
+ def _max_by_axis(the_list):
358
+ # type: (List[List[int]]) -> List[int]
359
+ maxes = the_list[0]
360
+ for sublist in the_list[1:]:
361
+ for index, item in enumerate(sublist):
362
+ maxes[index] = max(maxes[index], item)
363
+ return maxes
364
+
365
+
366
+ class NestedTensor(object):
367
+ def __init__(self, tensors, mask: Optional[Tensor]):
368
+ self.tensors = tensors
369
+ self.mask = mask
370
+
371
+ def to(self, device):
372
+ # type: (Device) -> NestedTensor # noqa
373
+ cast_tensor = self.tensors.to(device)
374
+ mask = self.mask
375
+ if mask is not None:
376
+ assert mask is not None
377
+ cast_mask = mask.to(device)
378
+ else:
379
+ cast_mask = None
380
+ return NestedTensor(cast_tensor, cast_mask)
381
+
382
+ def decompose(self):
383
+ return self.tensors, self.mask
384
+
385
+ def __repr__(self):
386
+ return str(self.tensors)
387
+
388
+
389
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
390
+ # TODO make this more general
391
+ if tensor_list[0].ndim == 3:
392
+ if torchvision._is_tracing():
393
+ # nested_tensor_from_tensor_list() does not export well to ONNX
394
+ # call _onnx_nested_tensor_from_tensor_list() instead
395
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
396
+
397
+ # TODO make it support different-sized images
398
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
399
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
400
+ batch_shape = [len(tensor_list)] + max_size
401
+ b, c, h, w = batch_shape
402
+ dtype = tensor_list[0].dtype
403
+ device = tensor_list[0].device
404
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
405
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
406
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
407
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
408
+ m[: img.shape[1], :img.shape[2]] = False
409
+ else:
410
+ raise ValueError('not supported')
411
+ return NestedTensor(tensor, mask)
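+
+ # Padding semantics (illustrative): for two images of shape [3, 200, 300] and [3, 180, 320], the batched
+ # tensor has shape [2, 3, 200, 320] and the mask has shape [2, 200, 320], with mask == True exactly on
+ # the padded pixels of each image.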
412
+
413
+
414
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
415
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
416
+ @torch.jit.unused
417
+ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
418
+ max_size = []
419
+ for i in range(tensor_list[0].dim()):
420
+ max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
421
+ max_size.append(max_size_i)
422
+ max_size = tuple(max_size)
423
+
424
+ # work around for
425
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
426
+ # m[: img.shape[1], :img.shape[2]] = False
427
+ # which is not yet supported in onnx
428
+ padded_imgs = []
429
+ padded_masks = []
430
+ for img in tensor_list:
431
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
432
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
433
+ padded_imgs.append(padded_img)
434
+
435
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
436
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
437
+ padded_masks.append(padded_mask.to(torch.bool))
438
+
439
+ tensor = torch.stack(padded_imgs)
440
+ mask = torch.stack(padded_masks)
441
+
442
+ return NestedTensor(tensor, mask=mask)
443
+
444
+
445
+ def setup_for_distributed(is_master):
446
+ """
447
+ This function disables printing when not in master process
448
+ """
449
+ import builtins as __builtin__
450
+ builtin_print = __builtin__.print
451
+
452
+ def print(*args, **kwargs):
453
+ force = kwargs.pop('force', False)
454
+ if is_master or force:
455
+ builtin_print(*args, **kwargs)
456
+
457
+ __builtin__.print = print
458
+
459
+
460
+ def is_dist_avail_and_initialized():
461
+ if not dist.is_available():
462
+ return False
463
+ if not dist.is_initialized():
464
+ return False
465
+ return True
466
+
467
+
468
+ def get_world_size():
469
+ if not is_dist_avail_and_initialized():
470
+ return 1
471
+ return dist.get_world_size()
472
+
473
+
474
+ def get_rank():
475
+ if not is_dist_avail_and_initialized():
476
+ return 0
477
+ return dist.get_rank()
478
+
479
+
480
+ def is_main_process():
481
+ return get_rank() == 0
482
+
483
+
484
+ def save_on_master(*args, **kwargs):
485
+ if is_main_process():
486
+ torch.save(*args, **kwargs)
487
+
488
+
489
+ def init_distributed_mode(args):
490
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
491
+ args.rank = int(os.environ["RANK"])
492
+ args.world_size = int(os.environ['WORLD_SIZE'])
493
+ args.gpu = int(os.environ['LOCAL_RANK'])
494
+ elif 'SLURM_PROCID' in os.environ:
495
+ args.rank = int(os.environ['SLURM_PROCID'])
496
+ args.gpu = args.rank % torch.cuda.device_count()
497
+ else:
498
+ print('Not using distributed mode')
499
+ args.distributed = False
500
+ return
501
+
502
+ args.distributed = True
503
+
504
+ torch.cuda.set_device(args.gpu)
505
+ args.dist_backend = 'nccl'
506
+ print('| distributed init (rank {}): {}'.format(
507
+ args.rank, args.dist_url), flush=True)
508
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
509
+ world_size=args.world_size, rank=args.rank)
510
+ torch.distributed.barrier()
511
+ setup_for_distributed(args.rank == 0)
512
+
513
+
514
+ @torch.no_grad()
515
+ def accuracy(output, target, topk=(1,)):
516
+ if output.dim() == 1:
517
+ output = output.unsqueeze(0)
518
+
519
+ maxk = max(topk)
520
+ batch_size = target.size(0)
521
+
522
+ _, pred = output.topk(maxk, 1, True, True)
523
+ pred = pred.t()
524
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
525
+
526
+ res = []
527
+ for k in topk:
528
+ correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
529
+ res.append(correct_k.mul_(100.0 / batch_size))
530
+ return res
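+
+ # Usage sketch (illustrative): with `output` of shape [B, num_classes] (logits) and integer `target` of shape [B],
+ #   top1, top5 = accuracy(output, target, topk=(1, 5))
+ # returns one 1-element tensor per k, holding the percentage of samples whose target is among the top-k predictions.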
531
+
532
+
533
+ '''
534
+ def accuracy(output, target, topk=(1,)):
535
+ """Computes the precision@k for the specified values of k"""
536
+ if target.numel() == 0:
537
+ return [torch.zeros([], device=output.device)]
538
+ maxk = max(topk)
539
+ batch_size = target.size(0)
540
+
541
+ _, pred = output.topk(maxk, 1, True, True)
542
+ pred = pred.t()
543
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
544
+
545
+ res = []
546
+ for k in topk:
547
+ correct_k = correct[:k].view(-1).float().sum(0)
548
+ res.append(correct_k.mul_(100.0 / batch_size))
549
+ return res
550
+ '''
551
+
552
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
553
+ # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
554
+ """
555
+ Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
556
+ This will eventually be supported natively by PyTorch, and this
557
+ class can go away.
558
+ """
559
+ if version.parse(torchvision.__version__) < version.parse('0.7'):
560
+ if input.numel() > 0:
561
+ return torch.nn.functional.interpolate(
562
+ input, size, scale_factor, mode, align_corners
563
+ )
564
+
565
+ output_shape = _output_size(2, input, size, scale_factor)
566
+ output_shape = list(input.shape[:-2]) + list(output_shape)
567
+ return _new_empty_tensor(input, output_shape)
568
+ else:
569
+ return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)