dianecy committed · verified
Commit 8377130 · 1 Parent(s): 599450c

Upload folder using huggingface_hub
engine/.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__/
engine/__init__.py ADDED
File without changes
engine/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (144 Bytes)
engine/__pycache__/engine.cpython-39.pyc ADDED
Binary file (7.4 kB)
engine/__pycache__/engine_verbonly.cpython-39.pyc ADDED
Binary file (5.62 kB)
engine/__pycache__/engine_verbonly_hardneg.cpython-39.pyc ADDED
Binary file (5.65 kB)
engine/engine.py ADDED
@@ -0,0 +1,394 @@
+ import os
+ import time
+ import math
+ from tqdm import tqdm
+ import cv2
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.cuda.amp as amp
+ import torch.distributed as dist
+ import torch.nn.functional as F
+ import wandb
+ from loguru import logger
+ from utils.dataset import tokenize
+ from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather,
+                         trainMetricGPU)
+
+
+ def return_mask(emb_distance):
+     B_, B_ = emb_distance.shape
+     positive_mask = torch.zeros_like(emb_distance)
+     for i in range(B_//2):
+         positive_mask[2*i, 2*i+1] = 1
+         positive_mask[2*i+1, 2*i] = 1
+     positive_mask.fill_diagonal_(1)
+     negative_mask = torch.ones_like(emb_distance) - positive_mask
+
+     return positive_mask, negative_mask
+
+
+ def MetricLoss(embeddings, n_pos, alpha=0.5, args=None):
+     # embeddings: ((2*B), C, (H*W))
+     # n_pos : chunk size of positive pairs
+     # args: args
+     # returns: loss
+     metric_loss = 0
+
+     # flatten embeddings
+     B_, C, HW = embeddings.shape
+     emb = torch.mean(embeddings, dim=-1)  # (2*B, C)
+     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (2*B, 2*B, C)
+     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (2*B, 2*B, C)
+     emb_distance = torch.norm(emb_i - emb_j, dim=-1)  # (2*B, 2*B)
+     assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
+         "Diagonals are not zero. Please check the permutation on the batch"
+     # print("distance matrix : ", emb_distance)
+
+     positive_mask, negative_mask = return_mask(emb_distance)
+     positive_loss = torch.sum(emb_distance * positive_mask) / B_**2  # B_
+
+     # negative pairs and loss
+     # negative_mask = torch.ones_like(emb_distance) - positive_mask
+     negative_loss = -1.0 * torch.log(torch.sum(emb_distance * negative_mask) / B_**2)  # (B_**2 - 2*B_)
+
+     # print(positive_mask, negative_mask)
+
+     metric_loss = alpha * positive_loss + (1-alpha) * negative_loss
+
+     return metric_loss
+
+
+ def AngularMetricLoss(embeddings, n_pos, alpha=0.5, args=None, mask=None):
+     # embeddings: ((2*B), C, (H*W))
+     # n_pos : chunk size of positive pairs
+     # args: args
+     # returns: loss
+     geometric_loss = 0
+
+     # flatten embeddings
+     B_, C, HW = embeddings.shape
+     emb = torch.mean(embeddings, dim=-1)  # (2*B, C)
+     emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (2*B, 2*B, C)
+     emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (2*B, 2*B, C)
+     sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
+     sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (2*B, 2*B)
+     sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
+     # print("similarity matrix : ", sim_matrix)
+     phi = torch.acos(sim_matrix)  # (2*B, 2*B)
+     # print("phi matrix : ", phi)
+     # print(args.batch_size, B_)
+     assert (B_ == args.batch_size * 2 * args.ngpus_per_node), \
+         "B_ must be 2x batch_size. Please check the inputs."
+
+     # positive pairs and loss
+     positive_mask, negative_mask = return_mask(sim_matrix)
+     # positive_mask = torch.zeros_like(sim_matrix)
+     # for i in range(B_//2):
+     #     positive_mask[2*i, 2*i+1] = 1
+     #     positive_mask[2*i+1, 2*i] = 1
+     # positive_mask.fill_diagonal_(1)
+     positive_loss = torch.sum((phi**2) * positive_mask) / B_**2
+
+     # negative pairs and loss
+     # negative_mask = torch.ones_like(sim_matrix) - positive_mask
+     phi_mask = phi < args.phi_threshold
+     negative_loss = (args.phi_threshold - phi)**2
+     # print(negative_mask * phi_mask)
+     negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / B_**2
+
+     # print("pos loss, neg loss : ", positive_loss, negative_loss)
+
+     geometric_loss = alpha * positive_loss + (1-alpha) * negative_loss
+
+     return geometric_loss
+
+
+ def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
+     batch_time = AverageMeter('Batch', ':2.2f')
+     data_time = AverageMeter('Data', ':2.2f')
+     lr = AverageMeter('Lr', ':1.6f')
+     loss_meter = AverageMeter('Loss', ':2.4f')
+     iou_meter = AverageMeter('IoU', ':2.2f')
+     pr_meter = AverageMeter('Prec@50', ':2.2f')
+     progress = ProgressMeter(
+         len(train_loader),
+         [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
+         prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))
+     metric_learning = args.metric_learning
+     # mix_distance_angular = args.mix_distance_angular
+     # positive_strength = args.positive_strength
+     # angular_loss_weight = args.metric_loss_weight * math.exp(-3.0 * (1-epoch/args.epochs)**2)
+     # print("epoch : ", epoch, ", angular loss weight : ", angular_loss_weight)
+     # distance_loss_weight = args.distance_loss_weight
+
+     model.train()
+     time.sleep(2)
+     end = time.time()
+
+     # size_list = [320, 352, 384, 416, 448, 480, 512]
+     # idx = np.random.choice(len(size_list))
+     # new_size = size_list[idx]
+
+     for i, (image, text, target) in enumerate(train_loader):
+         data_time.update(time.time() - end)
+         # data
+         image = image.cuda(non_blocking=True)
+         text = text.cuda(non_blocking=True)
+         target = target.cuda(non_blocking=True).unsqueeze(1)
+
+         # # multi-scale training
+         # image = F.interpolate(image, size=(new_size, new_size), mode='bilinear')
+
+         # masking when params exists
+         # mask_tensor = torch.tensor([True if params[i] else False for i in range(len(params))], dtype=torch.bool)
+
+         # forward
+         with amp.autocast():
+             pred, target, loss = model(image, text, target)
+             # pred, target, CE_loss, metric_tensor = model(image, text, target)
+
+         # gather tensors
+         # metric_tensor = concat_all_gather(metric_tensor)
+
+         # get metric loss
+         # print("gathered tensor shape : ", metric_tensor.shape)
+         # metric_loss = 0
+         # if metric_learning:
+         #     metric_loss += \
+         #         angular_loss_weight * AngularMetricLoss(metric_tensor, 2, alpha=positive_strength, args=args)  # , mask=mask_tensor)
+         #     if mix_distance_angular:
+         #         metric_loss += \
+         #             distance_loss_weight * MetricLoss(metric_tensor, 2, alpha=positive_strength, args=args)  # , mask=mask_tensor)
+         #     loss = (CE_loss + metric_loss) / \
+         #         (1 + angular_loss_weight*metric_learning + \
+         #          distance_loss_weight*metric_learning*mix_distance_angular)
+         # else :
+         #     loss = CE_loss
+
+         # backward
+         optimizer.zero_grad()
+         scaler.scale(loss).backward()
+         # loss.backward()
+         if args.max_norm:
+             torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
+         # optimizer.step()
+         scaler.step(optimizer)
+         scaler.update()
+         # dist.barrier()
+
+         # metric
+         iou, pr5 = trainMetricGPU(pred, target, 0.35, 0.5)
+         dist.all_reduce(loss.detach())
+         dist.all_reduce(iou)
+         dist.all_reduce(pr5)
+         loss = loss / dist.get_world_size()
+         iou = iou / dist.get_world_size()
+         pr5 = pr5 / dist.get_world_size()
+
+         loss_meter.update(loss.item(), image.size(0))
+         iou_meter.update(iou.item(), image.size(0))
+         pr_meter.update(pr5.item(), image.size(0))
+         lr.update(scheduler.get_last_lr()[-1])
+         batch_time.update(time.time() - end)
+         end = time.time()
+
+         if (i + 1) % args.print_freq == 0:
+             progress.display(i + 1)
+             if dist.get_rank() in [-1, 0]:
+                 wandb.log(
+                     {
+                         "time/batch": batch_time.val,
+                         "time/data": data_time.val,
+                         "training/lr": lr.val,
+                         "training/loss": loss_meter.val,
+                         "training/iou": iou_meter.val,
+                         "training/prec@50": pr_meter.val,
+                     },
+                     step=epoch * len(train_loader) + (i + 1))
+     torch.cuda.empty_cache()
+
+
+ @torch.no_grad()
+ def validate(val_loader, model, epoch, args):
+     iou_list = []
+     I_list = []
+     U_list = []
+     model.eval()
+     time.sleep(16)
+     for imgs, texts, masks, param in val_loader:
+         # data
+         imgs = imgs.cuda(non_blocking=True)
+         texts = texts.cuda(non_blocking=True)
+         # inference
+         preds = model(imgs, texts)
+         preds = torch.sigmoid(preds)
+         if preds.shape[-2:] != imgs.shape[-2:]:
+             preds = F.interpolate(preds,
+                                   size=imgs.shape[-2:],
+                                   mode='bicubic',
+                                   align_corners=True).squeeze(1)
+         # process one batch
+         # for pred, mask_dir, mat, ori_size in zip(preds, param['mask_dir'],
+         #                                          param['inverse'],
+         #                                          param['ori_size']):
+         #     h, w = np.array(ori_size)
+         #     mat = np.array(mat)
+         #     pred = pred.cpu().numpy()
+         #     pred = cv2.warpAffine(pred, mat, (w, h),
+         #                           flags=cv2.INTER_CUBIC,
+         #                           borderValue=0.)
+         #     pred = np.array(pred > 0.35)
+         #     mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
+         #     mask = mask / 255.
+         #     # iou
+         #     inter = np.logical_and(pred, mask)
+         #     union = np.logical_or(pred, mask)
+         #     iou = np.sum(inter) / (np.sum(union) + 1e-6)
+         #     iou_list.append(iou)
+         #     I_list.append(inter)
+         #     U_list.append(union)
+         for pred, mask in zip(preds, masks):
+             # h, w = np.array(ori_size)
+             # mat = np.array(mat)
+             pred = pred.cpu().numpy()
+             # pred = cv2.warpAffine(pred, mat, (w, h),
+             #                       flags=cv2.INTER_CUBIC,
+             #                       borderValue=0.)
+             pred = np.array(pred > 0.35)
+             # mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
+             # mask = mask / 255.
+             mask = mask.numpy()
+             # iou
+             inter = np.logical_and(pred, mask)
+             union = np.logical_or(pred, mask)
+             iou = np.sum(inter) / (np.sum(union) + 1e-6)
+             I_list.append(inter)
+             U_list.append(union)
+             iou_list.append(iou)
+
+     iou_list = np.stack(iou_list)
+     iou_list = torch.from_numpy(iou_list).to(imgs.device)
+     iou_list = concat_all_gather(iou_list)
+
+     I_list = np.stack(I_list)
+     I_list = torch.from_numpy(I_list).to(imgs.device)
+     I_list = concat_all_gather(I_list)
+
+     U_list = np.stack(U_list)
+     U_list = torch.from_numpy(U_list).to(imgs.device)
+     U_list = concat_all_gather(U_list)
+
+     overall_I = I_list.sum().item()
+     overall_U = U_list.sum().item()
+     overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero
+
+
+     prec_list = []
+     for thres in torch.arange(0.5, 1.0, 0.1):
+         tmp = (iou_list > thres).float().mean()
+         prec_list.append(tmp)
+     iou = iou_list.mean()
+     prec = {}
+     temp = ' '
+     for i, thres in enumerate(range(5, 10)):
+         key = 'Pr@{}'.format(thres * 10)
+         value = prec_list[i].item()
+         prec[key] = value
+         temp += "{}: {:.2f} ".format(key, 100. * value)
+     head = 'Evaluation: Epoch=[{}/{}] IoU={:.2f} OIoU={:.4f}'.format(
+         epoch, args.epochs, 100. * iou.item(), 100. * overall_IoU)
+     logger.info(head + temp)
+
+     # return three results : mIoU, oIoU and prec results
+     torch.cuda.empty_cache()
+     return iou.item(), overall_IoU, prec
+
+
+ @torch.no_grad()
+ def inference(test_loader, model, args):
+     iou_list = []
+     I_list = []
+     U_list = []
+
+     tbar = tqdm(test_loader, desc='Inference:', ncols=100)
+     model.eval()
+     time.sleep(2)
+     for img, mask, param in tbar:
+         # data
+         # img = img.cuda(non_blocking=True)
+         # mask = cv2.imread(param['mask_dir'][0], flags=cv2.IMREAD_GRAYSCALE)
+         img = img.cuda(non_blocking=True)
+         mask = mask[0].cpu().numpy()
+
+         # dump image & mask
+         if args.visualize:
+             seg_id = param['seg_id'][0].cpu().numpy()
+             img_name = '{}-img.jpg'.format(seg_id)
+             mask_name = '{}-mask.png'.format(seg_id)
+             cv2.imwrite(filename=os.path.join(args.vis_dir, img_name),
+                         img=param['ori_img'][0].cpu().numpy())
+             cv2.imwrite(filename=os.path.join(args.vis_dir, mask_name),
+                         img=mask)
+         # multiple sentences
+         for sent in param['sents']:
+             # mask = mask / 255.
+             text = tokenize(sent, args.word_len, True)
+             text = text.cuda(non_blocking=True)
+             # inference
+             pred = model(img, text)
+             pred = torch.sigmoid(pred)
+             if pred.shape[-2:] != img.shape[-2:]:
+                 pred = F.interpolate(pred,
+                                      size=img.shape[-2:],
+                                      mode='bicubic',
+                                      align_corners=True).squeeze()
+             # process one sentence
+             # h, w = param['ori_size'].numpy()[0]
+             # mat = param['inverse'].numpy()[0]
+             pred = pred.cpu().numpy()
+             # pred = cv2.warpAffine(pred, mat, (w, h),
+             #                       flags=cv2.INTER_CUBIC,
+             #                       borderValue=0.)
+             pred = np.array(pred > 0.35)
+             # iou
+             inter = np.logical_and(pred, mask)
+             union = np.logical_or(pred, mask)
+             iou = np.sum(inter) / (np.sum(union) + 1e-6)
+             iou_list.append(iou)
+             I_list.append(inter)
+             U_list.append(union)
+             # dump prediction
+             if args.visualize:
+                 pred = np.array(pred*255, dtype=np.uint8)
+                 sent = "_".join(sent[0].split(" "))
+                 pred_name = '{}-iou={:.2f}-{}.png'.format(seg_id, iou*100, sent)
+                 cv2.imwrite(filename=os.path.join(args.vis_dir, pred_name),
+                             img=pred)
+     logger.info('=> Metric Calculation <=')
+     iou_list = np.stack(iou_list)
+     iou_list = torch.from_numpy(iou_list).to(img.device)
+
+     I_list = np.stack(I_list)
+     I_list = torch.from_numpy(I_list).to(img.device)
+     U_list = np.stack(U_list)
+     U_list = torch.from_numpy(U_list).to(img.device)
+     overall_I = I_list.sum().item()
+     overall_U = U_list.sum().item()
+     overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero
+
+     prec_list = []
+     for thres in torch.arange(0.5, 1.0, 0.1):
+         tmp = (iou_list > thres).float().mean()
+         prec_list.append(tmp)
+     iou = iou_list.mean()
+     prec = {}
+     for i, thres in enumerate(range(5, 10)):
+         key = 'Pr@{}'.format(thres*10)
+         value = prec_list[i].item()
+         prec[key] = value
+     logger.info('IoU={:.2f} OIoU={:.4f}'.format(100.*iou.item(), 100. * overall_IoU))
+     for k, v in prec.items():
+         logger.info('{}: {:.2f}.'.format(k, 100.*v))
+
+     return iou.item(), overall_IoU, prec
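
A note on the pairing convention used by `return_mask`, `MetricLoss`, and `AngularMetricLoss` above: the gathered batch is assumed to interleave each sample with its positive partner, i.e. rows ordered `[x0, x0_pos, x1, x1_pos, ...]`, so positives sit in 2x2 blocks along the diagonal. A minimal sketch on a toy 4x4 matrix (assuming the repo root is on `PYTHONPATH` so `engine.engine` imports):

```python
# Toy check of return_mask on a batch of B_ = 4 rows laid out as
# [x0, x0_pos, x1, x1_pos]; any square matrix works as input.
import torch
from engine.engine import return_mask  # assumes repo root on PYTHONPATH

dist_matrix = torch.rand(4, 4)
pos, neg = return_mask(dist_matrix)
print(pos)
# tensor([[1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [0., 0., 1., 1.],
#         [0., 0., 1., 1.]])
# neg = 1 - pos: every cross-pair entry is treated as a negative.
```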
engine/engine_cy.py ADDED
@@ -0,0 +1,285 @@
+ import os
+ import time
+ from tqdm import tqdm
+ import cv2
+ import numpy as np
+ import torch
+ import pdb
+ import torch.cuda.amp as amp
+ import torch.distributed as dist
+ import torch.nn.functional as F
+ import wandb
+ from loguru import logger
+ from utils.dataset import tokenize
+ from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather,
+                         trainMetricGPU)
+
+ ## todo : add oIoU metric
+
+ def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
+     # torch.autograd.set_detect_anomaly(True)
+     batch_time = AverageMeter('Batch', ':2.2f')
+     data_time = AverageMeter('Data', ':2.2f')
+     lr = AverageMeter('Lr', ':1.6f')
+     loss_meter = AverageMeter('Loss', ':2.4f')
+     iou_meter = AverageMeter('IoU', ':2.2f')
+     pr_meter = AverageMeter('Prec@50', ':2.2f')
+     progress = ProgressMeter(
+         len(train_loader),
+         [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
+         prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))
+
+     model.train()
+     time.sleep(2)
+     end = time.time()
+
+     # size_list = [320, 352, 384, 416, 448, 480, 512]
+     # idx = np.random.choice(len(size_list))
+     # new_size = size_list[idx]
+
+     for i, (image, text, target) in enumerate(train_loader):
+         data_time.update(time.time() - end)
+
+         # data
+         image = image.cuda(non_blocking=True)
+         text = text.cuda(non_blocking=True)
+         target = target.cuda(non_blocking=True).unsqueeze(1)
+         # # multi-scale training
+         # image = F.interpolate(image, size=(new_size, new_size), mode='bilinear')
+
+         # forward
+         with amp.autocast():
+             pred, target, loss = model(image, text, target)
+
+         # backward
+         optimizer.zero_grad()
+         scaler.scale(loss).backward()
+         if args.max_norm:
+             torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
+
+
+         # for name, param in model.named_parameters():
+         #     if param.grad is not None:
+         #         if torch.isinf(param.grad).any() or torch.isnan(param.grad).any():
+         #             print(f"Inf/NaN in gradients: {name}")
+
+         # if args.max_norm:
+         #     torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
+         scaler.step(optimizer)
+         scaler.update()
+
+         # metric
+         iou, pr5 = trainMetricGPU(pred, target, 0.35, 0.5)
+         dist.all_reduce(loss.detach())
+         dist.all_reduce(iou)
+         dist.all_reduce(pr5)
+         loss = loss / dist.get_world_size()
+         iou = iou / dist.get_world_size()
+         pr5 = pr5 / dist.get_world_size()
+
+         loss_meter.update(loss.item(), image.size(0))
+         iou_meter.update(iou.item(), image.size(0))
+         pr_meter.update(pr5.item(), image.size(0))
+         lr.update(scheduler.get_last_lr()[-1])
+         batch_time.update(time.time() - end)
+         end = time.time()
+
+         if (i + 1) % args.print_freq == 0:
+             progress.display(i + 1)
+             if dist.get_rank() in [-1, 0]:
+                 wandb.log(
+                     {
+                         "time/batch": batch_time.val,
+                         "time/data": data_time.val,
+                         "training/lr": lr.val,
+                         "training/loss": loss_meter.val,
+                         "training/iou": iou_meter.val,
+                         "training/prec@50": pr_meter.val,
+                     },
+                     step=epoch * len(train_loader) + (i + 1))
+
+
+ @torch.no_grad()
+ def validate(val_loader, model, epoch, args):
+     iou_list = []
+     I_list = []
+     U_list = []
+     model.eval()
+     time.sleep(2)
+     for imgs, texts, masks, param in val_loader:
+         # data
+         imgs = imgs.cuda(non_blocking=True)
+         texts = texts.cuda(non_blocking=True)
+         # inference
+         preds = model(imgs, texts)
+         preds = torch.sigmoid(preds)
+         if preds.shape[-2:] != imgs.shape[-2:]:
+             preds = F.interpolate(preds,
+                                   size=imgs.shape[-2:],
+                                   mode='bicubic',
+                                   align_corners=True).squeeze(1)
+         # process one batch
+         # for pred, mask_dir, mat, ori_size in zip(preds, param['mask_dir'],
+         #                                          param['inverse'],
+         #                                          param['ori_size']):
+         #     h, w = np.array(ori_size)
+         #     mat = np.array(mat)
+         #     pred = pred.cpu().numpy()
+         #     pred = cv2.warpAffine(pred, mat, (w, h),
+         #                           flags=cv2.INTER_CUBIC,
+         #                           borderValue=0.)
+         #     pred = np.array(pred > 0.35)
+         #     mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
+         #     mask = mask / 255.
+         #     # iou
+         #     inter = np.logical_and(pred, mask)
+         #     union = np.logical_or(pred, mask)
+         #     iou = np.sum(inter) / (np.sum(union) + 1e-6)
+         #     iou_list.append(iou)
+         #     I_list.append(inter)
+         #     U_list.append(union)
+         for pred, mask in zip(preds, masks):
+             # h, w = np.array(ori_size)
+             # mat = np.array(mat)
+             pred = pred.cpu().numpy()
+             # pred = cv2.warpAffine(pred, mat, (w, h),
+             #                       flags=cv2.INTER_CUBIC,
+             #                       borderValue=0.)
+             pred = np.array(pred > 0.35)
+             # mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
+             # mask = mask / 255.
+             mask = mask.numpy()
+             # iou
+             inter = np.logical_and(pred, mask)
+             union = np.logical_or(pred, mask)
+             iou = np.sum(inter) / (np.sum(union) + 1e-6)
+             I_list.append(inter)
+             U_list.append(union)
+             iou_list.append(iou)
+
+     iou_list = np.stack(iou_list)
+     iou_list = torch.from_numpy(iou_list).to(imgs.device)
+     iou_list = concat_all_gather(iou_list)
+
+     I_list = np.stack(I_list)
+     I_list = torch.from_numpy(I_list).to(imgs.device)
+     I_list = concat_all_gather(I_list)
+
+     U_list = np.stack(U_list)
+     U_list = torch.from_numpy(U_list).to(imgs.device)
+     U_list = concat_all_gather(U_list)
+
+     overall_I = I_list.sum().item()
+     overall_U = U_list.sum().item()
+     overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero
+
+
+     prec_list = []
+     for thres in torch.arange(0.5, 1.0, 0.1):
+         tmp = (iou_list > thres).float().mean()
+         prec_list.append(tmp)
+     iou = iou_list.mean()
+     prec = {}
+     temp = ' '
+     for i, thres in enumerate(range(5, 10)):
+         key = 'Pr@{}'.format(thres * 10)
+         value = prec_list[i].item()
+         prec[key] = value
+         temp += "{}: {:.2f} ".format(key, 100. * value)
+     head = 'Evaluation: Epoch=[{}/{}] IoU={:.2f} OIoU={:.4f}'.format(
+         epoch, args.epochs, 100. * iou.item(), 100. * overall_IoU)
+     logger.info(head + temp)
+     # print(head)
+
+     # return three results : mIoU, oIoU and prec results
+     return iou.item(), overall_IoU, prec
+
+
+ @torch.no_grad()
+ def inference(test_loader, model, args):
+     iou_list = []
+     I_list = []
+     U_list = []
+
+     tbar = tqdm(test_loader, desc='Inference:', ncols=100)
+     model.eval()
+     time.sleep(2)
+     for img, mask, param in tbar:
+         # data
+         # img = img.cuda(non_blocking=True)
+         # mask = cv2.imread(param['mask_dir'][0], flags=cv2.IMREAD_GRAYSCALE)
+         img = img.cuda(non_blocking=True)
+         mask = mask[0].cpu().numpy()
+
+         # dump image & mask
+         if args.visualize:
+             seg_id = param['seg_id'][0].cpu().numpy()
+             img_name = '{}-img.jpg'.format(seg_id)
+             mask_name = '{}-mask.png'.format(seg_id)
+             cv2.imwrite(filename=os.path.join(args.vis_dir, img_name),
+                         img=param['ori_img'][0].cpu().numpy())
+             cv2.imwrite(filename=os.path.join(args.vis_dir, mask_name),
+                         img=mask)
+         # multiple sentences
+         for sent in param['sents']:
+             # mask = mask / 255.
+             text = tokenize(sent, args.word_len, True)
+             text = text.cuda(non_blocking=True)
+             # inference
+             pred = model(img, text)
+             pred = torch.sigmoid(pred)
+             if pred.shape[-2:] != img.shape[-2:]:
+                 pred = F.interpolate(pred,
+                                      size=img.shape[-2:],
+                                      mode='bicubic',
+                                      align_corners=True).squeeze()
+             # process one sentence
+             # h, w = param['ori_size'].numpy()[0]
+             # mat = param['inverse'].numpy()[0]
+             pred = pred.cpu().numpy()
+             # pred = cv2.warpAffine(pred, mat, (w, h),
+             #                       flags=cv2.INTER_CUBIC,
+             #                       borderValue=0.)
+             pred = np.array(pred > 0.35)
+             # iou
+             inter = np.logical_and(pred, mask)
+             union = np.logical_or(pred, mask)
+             iou = np.sum(inter) / (np.sum(union) + 1e-6)
+             iou_list.append(iou)
+             I_list.append(inter)
+             U_list.append(union)
+             # dump prediction
+             if args.visualize:
+                 pred = np.array(pred*255, dtype=np.uint8)
+                 sent = "_".join(sent[0].split(" "))
+                 pred_name = '{}-iou={:.2f}-{}.png'.format(seg_id, iou*100, sent)
+                 cv2.imwrite(filename=os.path.join(args.vis_dir, pred_name),
+                             img=pred)
+     logger.info('=> Metric Calculation <=')
+     iou_list = np.stack(iou_list)
+     iou_list = torch.from_numpy(iou_list).to(img.device)
+
+     I_list = np.stack(I_list)
+     I_list = torch.from_numpy(I_list).to(img.device)
+     U_list = np.stack(U_list)
+     U_list = torch.from_numpy(U_list).to(img.device)
+     overall_I = I_list.sum().item()
+     overall_U = U_list.sum().item()
+     overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero
+
+     prec_list = []
+     for thres in torch.arange(0.5, 1.0, 0.1):
+         tmp = (iou_list > thres).float().mean()
+         prec_list.append(tmp)
+     iou = iou_list.mean()
+     prec = {}
+     for i, thres in enumerate(range(5, 10)):
+         key = 'Pr@{}'.format(thres*10)
+         value = prec_list[i].item()
+         prec[key] = value
+     logger.info('IoU={:.2f} OIoU={:.4f}'.format(100.*iou.item(), 100. * overall_IoU))
+     print('IoU={:.2f} OIoU={:.4f}'.format(100.*iou.item(), 100. * overall_IoU))
+     for k, v in prec.items():
+         logger.info('{}: {:.2f}.'.format(k, 100.*v))
+
+     return iou.item(), overall_IoU, prec
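
One caveat shared by the `train()` loops in this commit: `clip_grad_norm_` runs before the gradients are unscaled, so the clip threshold is applied to gradients that are still multiplied by the `GradScaler` loss scale. The `torch.cuda.amp` docs recommend calling `scaler.unscale_(optimizer)` first. A sketch of that ordering; `amp_step` is a hypothetical helper, not a function in this repo:

```python
# Sketch of the gradient-clipping order recommended by the torch.cuda.amp
# docs; `amp_step` is a hypothetical helper, not part of this repo.
import torch

def amp_step(model, loss, optimizer, scaler, max_norm=0.0):
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    if max_norm:
        scaler.unscale_(optimizer)  # restore true gradient magnitudes first
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)  # automatically skips the update on inf/nan grads
    scaler.update()
```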
engine/engine_verbonly.py ADDED
@@ -0,0 +1,286 @@
+ import os
+ import time
+ from tqdm import tqdm
+ import cv2
+ import numpy as np
+ import torch
+ import pdb
+ import torch.cuda.amp as amp
+ import torch.distributed as dist
+ import torch.nn.functional as F
+ import wandb
+ from loguru import logger
+ from utils.dataset_verbonly import tokenize
+ from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather,
+                         trainMetricGPU)
+
+ ## todo : add oIoU metric
+
+ def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
+     # torch.autograd.set_detect_anomaly(True)
+     batch_time = AverageMeter('Batch', ':2.2f')
+     data_time = AverageMeter('Data', ':2.2f')
+     lr = AverageMeter('Lr', ':1.6f')
+     loss_meter = AverageMeter('Loss', ':2.4f')
+     iou_meter = AverageMeter('IoU', ':2.2f')
+     pr_meter = AverageMeter('Prec@50', ':2.2f')
+     progress = ProgressMeter(
+         len(train_loader),
+         [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
+         prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))
+
+     model.train()
+     time.sleep(2)
+     end = time.time()
+
+     # size_list = [320, 352, 384, 416, 448, 480, 512]
+     # idx = np.random.choice(len(size_list))
+     # new_size = size_list[idx]
+
+     for i, (image, text, target, hardpos) in enumerate(train_loader):
+         data_time.update(time.time() - end)
+
+         # data
+         image = image.cuda(non_blocking=True)
+         text = text.cuda(non_blocking=True)
+         target = target.cuda(non_blocking=True).unsqueeze(1)
+         hardpos = hardpos.cuda(non_blocking=True)
+         # # multi-scale training
+         # image = F.interpolate(image, size=(new_size, new_size), mode='bilinear')
+
+         # forward
+         with amp.autocast():
+             pred, target, loss = model(image, text, target, hardpos)
+
+         # backward
+         optimizer.zero_grad()
+         scaler.scale(loss).backward()
+         if args.max_norm:
+             torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
+
+
+         # for name, param in model.named_parameters():
+         #     if param.grad is not None:
+         #         if torch.isinf(param.grad).any() or torch.isnan(param.grad).any():
+         #             print(f"Inf/NaN in gradients: {name}")
+
+         # if args.max_norm:
+         #     torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
+         scaler.step(optimizer)
+         scaler.update()
+
+         # metric
+         iou, pr5 = trainMetricGPU(pred, target, 0.35, 0.5)
+         dist.all_reduce(loss.detach())
+         dist.all_reduce(iou)
+         dist.all_reduce(pr5)
+         loss = loss / dist.get_world_size()
+         iou = iou / dist.get_world_size()
+         pr5 = pr5 / dist.get_world_size()
+
+         loss_meter.update(loss.item(), image.size(0))
+         iou_meter.update(iou.item(), image.size(0))
+         pr_meter.update(pr5.item(), image.size(0))
+         lr.update(scheduler.get_last_lr()[-1])
+         batch_time.update(time.time() - end)
+         end = time.time()
+
+         if (i + 1) % args.print_freq == 0:
+             progress.display(i + 1)
+             if dist.get_rank() in [-1, 0]:
+                 wandb.log(
+                     {
+                         "time/batch": batch_time.val,
+                         "time/data": data_time.val,
+                         "training/lr": lr.val,
+                         "training/loss": loss_meter.val,
+                         "training/iou": iou_meter.val,
+                         "training/prec@50": pr_meter.val,
+                     },
+                     step=epoch * len(train_loader) + (i + 1))
+
+
+ @torch.no_grad()
+ def validate(val_loader, model, epoch, args):
+     iou_list = []
+     I_list = []
+     U_list = []
+     model.eval()
+     time.sleep(2)
+     for imgs, texts, masks, param in val_loader:
+         # data
+         imgs = imgs.cuda(non_blocking=True)
+         texts = texts.cuda(non_blocking=True)
+         # inference
+         preds = model(imgs, texts)
+         preds = torch.sigmoid(preds)
+         if preds.shape[-2:] != imgs.shape[-2:]:
+             preds = F.interpolate(preds,
+                                   size=imgs.shape[-2:],
+                                   mode='bicubic',
+                                   align_corners=True).squeeze(1)
+         # process one batch
+         # for pred, mask_dir, mat, ori_size in zip(preds, param['mask_dir'],
+         #                                          param['inverse'],
+         #                                          param['ori_size']):
+         #     h, w = np.array(ori_size)
+         #     mat = np.array(mat)
+         #     pred = pred.cpu().numpy()
+         #     pred = cv2.warpAffine(pred, mat, (w, h),
+         #                           flags=cv2.INTER_CUBIC,
+         #                           borderValue=0.)
+         #     pred = np.array(pred > 0.35)
+         #     mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
+         #     mask = mask / 255.
+         #     # iou
+         #     inter = np.logical_and(pred, mask)
+         #     union = np.logical_or(pred, mask)
+         #     iou = np.sum(inter) / (np.sum(union) + 1e-6)
+         #     iou_list.append(iou)
+         #     I_list.append(inter)
+         #     U_list.append(union)
+         for pred, mask in zip(preds, masks):
+             # h, w = np.array(ori_size)
+             # mat = np.array(mat)
+             pred = pred.cpu().numpy()
+             # pred = cv2.warpAffine(pred, mat, (w, h),
+             #                       flags=cv2.INTER_CUBIC,
+             #                       borderValue=0.)
+             pred = np.array(pred > 0.35)
+             # mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
+             # mask = mask / 255.
+             mask = mask.numpy()
+             # iou
+             inter = np.logical_and(pred, mask)
+             union = np.logical_or(pred, mask)
+             iou = np.sum(inter) / (np.sum(union) + 1e-6)
+             I_list.append(inter)
+             U_list.append(union)
+             iou_list.append(iou)
+
+     iou_list = np.stack(iou_list)
+     iou_list = torch.from_numpy(iou_list).to(imgs.device)
+     iou_list = concat_all_gather(iou_list)
+
+     I_list = np.stack(I_list)
+     I_list = torch.from_numpy(I_list).to(imgs.device)
+     I_list = concat_all_gather(I_list)
+
+     U_list = np.stack(U_list)
+     U_list = torch.from_numpy(U_list).to(imgs.device)
+     U_list = concat_all_gather(U_list)
+
+     overall_I = I_list.sum().item()
+     overall_U = U_list.sum().item()
+     overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero
+
+
+     prec_list = []
+     for thres in torch.arange(0.5, 1.0, 0.1):
+         tmp = (iou_list > thres).float().mean()
+         prec_list.append(tmp)
+     iou = iou_list.mean()
+     prec = {}
+     temp = ' '
+     for i, thres in enumerate(range(5, 10)):
+         key = 'Pr@{}'.format(thres * 10)
+         value = prec_list[i].item()
+         prec[key] = value
+         temp += "{}: {:.2f} ".format(key, 100. * value)
+     head = 'Evaluation: Epoch=[{}/{}] IoU={:.2f} OIoU={:.4f}'.format(
+         epoch, args.epochs, 100. * iou.item(), 100. * overall_IoU)
+     logger.info(head + temp)
+     # print(head)
+
+     # return three results : mIoU, oIoU and prec results
+     return iou.item(), overall_IoU, prec
+
+
+ @torch.no_grad()
+ def inference(test_loader, model, args):
+     iou_list = []
+     I_list = []
+     U_list = []
+
+     tbar = tqdm(test_loader, desc='Inference:', ncols=100)
+     model.eval()
+     time.sleep(2)
+     for img, mask, param in tbar:
+         # data
+         # img = img.cuda(non_blocking=True)
+         # mask = cv2.imread(param['mask_dir'][0], flags=cv2.IMREAD_GRAYSCALE)
+         img = img.cuda(non_blocking=True)
+         mask = mask[0].cpu().numpy()
+
+         # dump image & mask
+         if args.visualize:
+             seg_id = param['seg_id'][0].cpu().numpy()
+             img_name = '{}-img.jpg'.format(seg_id)
+             mask_name = '{}-mask.png'.format(seg_id)
+             cv2.imwrite(filename=os.path.join(args.vis_dir, img_name),
+                         img=param['ori_img'][0].cpu().numpy())
+             cv2.imwrite(filename=os.path.join(args.vis_dir, mask_name),
+                         img=mask)
+         # multiple sentences
+         for sent in param['sents']:
+             # mask = mask / 255.
+             text = tokenize(sent, args.word_len, True)
+             text = text.cuda(non_blocking=True)
+             # inference
+             pred = model(img, text)
+             pred = torch.sigmoid(pred)
+             if pred.shape[-2:] != img.shape[-2:]:
+                 pred = F.interpolate(pred,
+                                      size=img.shape[-2:],
+                                      mode='bicubic',
+                                      align_corners=True).squeeze()
+             # process one sentence
+             # h, w = param['ori_size'].numpy()[0]
+             # mat = param['inverse'].numpy()[0]
+             pred = pred.cpu().numpy()
+             # pred = cv2.warpAffine(pred, mat, (w, h),
+             #                       flags=cv2.INTER_CUBIC,
+             #                       borderValue=0.)
+             pred = np.array(pred > 0.35)
+             # iou
+             inter = np.logical_and(pred, mask)
+             union = np.logical_or(pred, mask)
+             iou = np.sum(inter) / (np.sum(union) + 1e-6)
+             iou_list.append(iou)
+             I_list.append(inter)
+             U_list.append(union)
+             # dump prediction
+             if args.visualize:
+                 pred = np.array(pred*255, dtype=np.uint8)
+                 sent = "_".join(sent[0].split(" "))
+                 pred_name = '{}-iou={:.2f}-{}.png'.format(seg_id, iou*100, sent)
+                 cv2.imwrite(filename=os.path.join(args.vis_dir, pred_name),
+                             img=pred)
+     logger.info('=> Metric Calculation <=')
+     iou_list = np.stack(iou_list)
+     iou_list = torch.from_numpy(iou_list).to(img.device)
+
+     I_list = np.stack(I_list)
+     I_list = torch.from_numpy(I_list).to(img.device)
+     U_list = np.stack(U_list)
+     U_list = torch.from_numpy(U_list).to(img.device)
+     overall_I = I_list.sum().item()
+     overall_U = U_list.sum().item()
+     overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero
+
+     prec_list = []
+     for thres in torch.arange(0.5, 1.0, 0.1):
+         tmp = (iou_list > thres).float().mean()
+         prec_list.append(tmp)
+     iou = iou_list.mean()
+     prec = {}
+     for i, thres in enumerate(range(5, 10)):
+         key = 'Pr@{}'.format(thres*10)
+         value = prec_list[i].item()
+         prec[key] = value
+     logger.info('IoU={:.2f} OIoU={:.4f}'.format(100.*iou.item(), 100. * overall_IoU))
+     print('IoU={:.2f} OIoU={:.4f}'.format(100.*iou.item(), 100. * overall_IoU))
+     for k, v in prec.items():
+         logger.info('{}: {:.2f}.'.format(k, 100.*v))
+
+     return iou.item(), overall_IoU, prec
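
The `## todo : add oIoU metric` note is handled by the `I_list`/`U_list` bookkeeping above: `validate()` reports both mIoU (the mean of per-sample IoUs) and oIoU (intersections and unions pooled over the whole split before dividing). The two can differ substantially because oIoU weights large objects more heavily. A toy illustration with made-up pixel counts:

```python
# mIoU vs. oIoU on two fabricated samples: one large, well-segmented
# object and one small, poorly-segmented one.
import numpy as np

inters = np.array([40.0, 2.0])   # per-sample intersection pixel counts
unions = np.array([50.0, 10.0])  # per-sample union pixel counts

miou = np.mean(inters / (unions + 1e-6))     # (0.8 + 0.2) / 2 = 0.5
oiou = inters.sum() / (unions.sum() + 1e-6)  # 42 / 60 = 0.7
print(f"mIoU={miou:.2f}  oIoU={oiou:.2f}")
```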
engine/engine_verbonly_hardneg.py ADDED
@@ -0,0 +1,287 @@
+ import os
+ import time
+ from tqdm import tqdm
+ import cv2
+ import numpy as np
+ import torch
+ import pdb
+ import torch.cuda.amp as amp
+ import torch.distributed as dist
+ import torch.nn.functional as F
+ import wandb
+ from loguru import logger
+ from utils.dataset_verbonly import tokenize
+ from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather,
+                         trainMetricGPU)
+
+ ## todo : add oIoU metric
+
+ def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
+     # torch.autograd.set_detect_anomaly(True)
+     batch_time = AverageMeter('Batch', ':2.2f')
+     data_time = AverageMeter('Data', ':2.2f')
+     lr = AverageMeter('Lr', ':1.6f')
+     loss_meter = AverageMeter('Loss', ':2.4f')
+     iou_meter = AverageMeter('IoU', ':2.2f')
+     pr_meter = AverageMeter('Prec@50', ':2.2f')
+     progress = ProgressMeter(
+         len(train_loader),
+         [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
+         prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))
+
+     model.train()
+     time.sleep(2)
+     end = time.time()
+
+     # size_list = [320, 352, 384, 416, 448, 480, 512]
+     # idx = np.random.choice(len(size_list))
+     # new_size = size_list[idx]
+
+     for i, (image, text, target, hardpos, hardneg) in enumerate(train_loader):
+         data_time.update(time.time() - end)
+
+         # data
+         image = image.cuda(non_blocking=True)
+         text = text.cuda(non_blocking=True)
+         hardpos = hardpos.cuda(non_blocking=True)
+         hardneg = hardneg.cuda(non_blocking=True)
+         target = target.cuda(non_blocking=True).unsqueeze(1)
+         # # multi-scale training
+         # image = F.interpolate(image, size=(new_size, new_size), mode='bilinear')
+
+         # forward
+         with amp.autocast():
+             pred, target, loss = model(image, text, target, hardpos, hardneg)
+
+         # backward
+         optimizer.zero_grad()
+         scaler.scale(loss).backward()
+         if args.max_norm:
+             torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
+
+
+         # for name, param in model.named_parameters():
+         #     if param.grad is not None:
+         #         if torch.isinf(param.grad).any() or torch.isnan(param.grad).any():
+         #             print(f"Inf/NaN in gradients: {name}")
+
+         # if args.max_norm:
+         #     torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
+         scaler.step(optimizer)
+         scaler.update()
+
+         # metric
+         iou, pr5 = trainMetricGPU(pred, target, 0.35, 0.5)
+         dist.all_reduce(loss.detach())
+         dist.all_reduce(iou)
+         dist.all_reduce(pr5)
+         loss = loss / dist.get_world_size()
+         iou = iou / dist.get_world_size()
+         pr5 = pr5 / dist.get_world_size()
+
+         loss_meter.update(loss.item(), image.size(0))
+         iou_meter.update(iou.item(), image.size(0))
+         pr_meter.update(pr5.item(), image.size(0))
+         lr.update(scheduler.get_last_lr()[-1])
+         batch_time.update(time.time() - end)
+         end = time.time()
+
+         if (i + 1) % args.print_freq == 0:
+             progress.display(i + 1)
+             if dist.get_rank() in [-1, 0]:
+                 wandb.log(
+                     {
+                         "time/batch": batch_time.val,
+                         "time/data": data_time.val,
+                         "training/lr": lr.val,
+                         "training/loss": loss_meter.val,
+                         "training/iou": iou_meter.val,
+                         "training/prec@50": pr_meter.val,
+                     },
+                     step=epoch * len(train_loader) + (i + 1))
+
+
+ @torch.no_grad()
+ def validate(val_loader, model, epoch, args):
+     iou_list = []
+     I_list = []
+     U_list = []
+     model.eval()
+     time.sleep(2)
+     for imgs, texts, masks, param in val_loader:
+         # data
+         imgs = imgs.cuda(non_blocking=True)
+         texts = texts.cuda(non_blocking=True)
+         # inference
+         preds = model(imgs, texts)
+         preds = torch.sigmoid(preds)
+         if preds.shape[-2:] != imgs.shape[-2:]:
+             preds = F.interpolate(preds,
+                                   size=imgs.shape[-2:],
+                                   mode='bicubic',
+                                   align_corners=True).squeeze(1)
+         # process one batch
+         # for pred, mask_dir, mat, ori_size in zip(preds, param['mask_dir'],
+         #                                          param['inverse'],
+         #                                          param['ori_size']):
+         #     h, w = np.array(ori_size)
+         #     mat = np.array(mat)
+         #     pred = pred.cpu().numpy()
+         #     pred = cv2.warpAffine(pred, mat, (w, h),
+         #                           flags=cv2.INTER_CUBIC,
+         #                           borderValue=0.)
+         #     pred = np.array(pred > 0.35)
+         #     mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
+         #     mask = mask / 255.
+         #     # iou
+         #     inter = np.logical_and(pred, mask)
+         #     union = np.logical_or(pred, mask)
+         #     iou = np.sum(inter) / (np.sum(union) + 1e-6)
+         #     iou_list.append(iou)
+         #     I_list.append(inter)
+         #     U_list.append(union)
+         for pred, mask in zip(preds, masks):
+             # h, w = np.array(ori_size)
+             # mat = np.array(mat)
+             pred = pred.cpu().numpy()
+             # pred = cv2.warpAffine(pred, mat, (w, h),
+             #                       flags=cv2.INTER_CUBIC,
+             #                       borderValue=0.)
+             pred = np.array(pred > 0.35)
+             # mask = cv2.imread(mask_dir, flags=cv2.IMREAD_GRAYSCALE)
+             # mask = mask / 255.
+             mask = mask.numpy()
+             # iou
+             inter = np.logical_and(pred, mask)
+             union = np.logical_or(pred, mask)
+             iou = np.sum(inter) / (np.sum(union) + 1e-6)
+             I_list.append(inter)
+             U_list.append(union)
+             iou_list.append(iou)
+
+     iou_list = np.stack(iou_list)
+     iou_list = torch.from_numpy(iou_list).to(imgs.device)
+     iou_list = concat_all_gather(iou_list)
+
+     I_list = np.stack(I_list)
+     I_list = torch.from_numpy(I_list).to(imgs.device)
+     I_list = concat_all_gather(I_list)
+
+     U_list = np.stack(U_list)
+     U_list = torch.from_numpy(U_list).to(imgs.device)
+     U_list = concat_all_gather(U_list)
+
+     overall_I = I_list.sum().item()
+     overall_U = U_list.sum().item()
+     overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero
+
+
+     prec_list = []
+     for thres in torch.arange(0.5, 1.0, 0.1):
+         tmp = (iou_list > thres).float().mean()
+         prec_list.append(tmp)
+     iou = iou_list.mean()
+     prec = {}
+     temp = ' '
+     for i, thres in enumerate(range(5, 10)):
+         key = 'Pr@{}'.format(thres * 10)
+         value = prec_list[i].item()
+         prec[key] = value
+         temp += "{}: {:.2f} ".format(key, 100. * value)
+     head = 'Evaluation: Epoch=[{}/{}] IoU={:.2f} OIoU={:.4f}'.format(
+         epoch, args.epochs, 100. * iou.item(), 100. * overall_IoU)
+     logger.info(head + temp)
+     # print(head)
+
+     # return three results : mIoU, oIoU and prec results
+     return iou.item(), overall_IoU, prec
+
+
+ @torch.no_grad()
+ def inference(test_loader, model, args):
+     iou_list = []
+     I_list = []
+     U_list = []
+
+     tbar = tqdm(test_loader, desc='Inference:', ncols=100)
+     model.eval()
+     time.sleep(2)
+     for img, mask, param in tbar:
+         # data
+         # img = img.cuda(non_blocking=True)
+         # mask = cv2.imread(param['mask_dir'][0], flags=cv2.IMREAD_GRAYSCALE)
+         img = img.cuda(non_blocking=True)
+         mask = mask[0].cpu().numpy()
+
+         # dump image & mask
+         if args.visualize:
+             seg_id = param['seg_id'][0].cpu().numpy()
+             img_name = '{}-img.jpg'.format(seg_id)
+             mask_name = '{}-mask.png'.format(seg_id)
+             cv2.imwrite(filename=os.path.join(args.vis_dir, img_name),
+                         img=param['ori_img'][0].cpu().numpy())
+             cv2.imwrite(filename=os.path.join(args.vis_dir, mask_name),
+                         img=mask)
+         # multiple sentences
+         for sent in param['sents']:
+             # mask = mask / 255.
+             text = tokenize(sent, args.word_len, True)
+             text = text.cuda(non_blocking=True)
+             # inference
+             pred = model(img, text)
+             pred = torch.sigmoid(pred)
+             if pred.shape[-2:] != img.shape[-2:]:
+                 pred = F.interpolate(pred,
+                                      size=img.shape[-2:],
+                                      mode='bicubic',
+                                      align_corners=True).squeeze()
+             # process one sentence
+             # h, w = param['ori_size'].numpy()[0]
+             # mat = param['inverse'].numpy()[0]
+             pred = pred.cpu().numpy()
+             # pred = cv2.warpAffine(pred, mat, (w, h),
+             #                       flags=cv2.INTER_CUBIC,
+             #                       borderValue=0.)
+             pred = np.array(pred > 0.35)
+             # iou
+             inter = np.logical_and(pred, mask)
+             union = np.logical_or(pred, mask)
+             iou = np.sum(inter) / (np.sum(union) + 1e-6)
+             iou_list.append(iou)
+             I_list.append(inter)
+             U_list.append(union)
+             # dump prediction
+             if args.visualize:
+                 pred = np.array(pred*255, dtype=np.uint8)
+                 sent = "_".join(sent[0].split(" "))
+                 pred_name = '{}-iou={:.2f}-{}.png'.format(seg_id, iou*100, sent)
+                 cv2.imwrite(filename=os.path.join(args.vis_dir, pred_name),
+                             img=pred)
+     logger.info('=> Metric Calculation <=')
+     iou_list = np.stack(iou_list)
+     iou_list = torch.from_numpy(iou_list).to(img.device)
+
+     I_list = np.stack(I_list)
+     I_list = torch.from_numpy(I_list).to(img.device)
+     U_list = np.stack(U_list)
+     U_list = torch.from_numpy(U_list).to(img.device)
+     overall_I = I_list.sum().item()
+     overall_U = U_list.sum().item()
+     overall_IoU = overall_I / (overall_U + 1e-6)  # to avoid division by zero
+
+     prec_list = []
+     for thres in torch.arange(0.5, 1.0, 0.1):
+         tmp = (iou_list > thres).float().mean()
+         prec_list.append(tmp)
+     iou = iou_list.mean()
+     prec = {}
+     for i, thres in enumerate(range(5, 10)):
+         key = 'Pr@{}'.format(thres*10)
+         value = prec_list[i].item()
+         prec[key] = value
+     logger.info('IoU={:.2f} OIoU={:.4f}'.format(100.*iou.item(), 100. * overall_IoU))
+     print('IoU={:.2f} OIoU={:.4f}'.format(100.*iou.item(), 100. * overall_IoU))
+     for k, v in prec.items():
+         logger.info('{}: {:.2f}.'.format(k, 100.*v))
+
+     return iou.item(), overall_IoU, prec
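
All four engines end with the same Pr@X summary: the fraction of samples whose IoU exceeds each threshold from 0.5 to 0.9. A self-contained sketch with toy IoU values, mirroring the threshold loop in `validate()`/`inference()`:

```python
# Toy Pr@X computation mirroring the threshold loop used above.
import torch

iou_list = torch.tensor([0.45, 0.55, 0.75, 0.95])
prec = {}
for i, thres in enumerate(torch.arange(0.5, 1.0, 0.1)):
    key = 'Pr@{}'.format(50 + 10 * i)  # Pr@50 ... Pr@90
    prec[key] = (iou_list > thres).float().mean().item()
print(prec)
# {'Pr@50': 0.75, 'Pr@60': 0.5, 'Pr@70': 0.5, 'Pr@80': 0.25, 'Pr@90': 0.25}
```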