import os
import time
import math

from tqdm import tqdm
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.nn.functional as F
import wandb
from loguru import logger

from utils.dataset import tokenize
from utils.misc import (AverageMeter, ProgressMeter, concat_all_gather,
                        trainMetricGPU)


def return_mask(emb_distance):
    """Build positive/negative pair masks for a (2B, 2B) pairwise matrix.

    The batch is assumed to be ordered so that rows 2i and 2i+1 are a
    positive pair. The diagonal (self-pairs) is also marked positive so it
    is excluded from the negative mask.

    Args:
        emb_distance: square (2B, 2B) tensor; only its shape/device are used.

    Returns:
        (positive_mask, negative_mask): complementary 0/1 float tensors of
        the same shape as ``emb_distance``.
    """
    B_, _ = emb_distance.shape
    positive_mask = torch.zeros_like(emb_distance)
    for i in range(B_ // 2):
        positive_mask[2 * i, 2 * i + 1] = 1
        positive_mask[2 * i + 1, 2 * i] = 1
    positive_mask.fill_diagonal_(1)
    negative_mask = torch.ones_like(emb_distance) - positive_mask
    return positive_mask, negative_mask


def MetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    """Euclidean-distance metric loss over paired embeddings.

    Args:
        embeddings: tensor of shape (2B, C, H*W); mean-pooled over the last
            dim before computing pairwise distances.
        n_pos: chunk size of positive pairs (unused here; kept for API
            compatibility with callers).
        alpha: weight of the positive (pull-together) term; (1 - alpha)
            weights the negative (push-apart) term.
        args: unused; kept for signature parity with AngularMetricLoss.

    Returns:
        Scalar loss: alpha * mean positive distance
        - (1 - alpha) * log(mean negative distance).
    """
    metric_loss = 0
    B_, C, HW = embeddings.shape
    emb = torch.mean(embeddings, dim=-1)           # (2B, C)
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)      # (2B, 2B, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)      # (2B, 2B, C)
    emb_distance = torch.norm(emb_i - emb_j, dim=-1)  # (2B, 2B)
    # Self-distances must be exactly zero; otherwise the gathered batch was
    # permuted and the 2i/2i+1 pairing assumption is broken.
    assert torch.sum(torch.diag(emb_distance[:B_, :B_])) == 0, \
        "Diagonals are not zero. please check the permutation on the batch"
    positive_mask, negative_mask = return_mask(emb_distance)
    # NOTE(review): both terms are normalized by B_**2 (not by the number of
    # masked entries) — preserved from the original implementation.
    positive_loss = torch.sum(emb_distance * positive_mask) / B_**2
    negative_loss = -1.0 * torch.log(
        torch.sum(emb_distance * negative_mask) / B_**2)
    metric_loss = alpha * positive_loss + (1 - alpha) * negative_loss
    return metric_loss


def AngularMetricLoss(embeddings, n_pos, alpha=0.5, args=None, mask=None):
    """Angular (arccos of cosine similarity) metric loss over paired embeddings.

    Positive pairs are pulled together by penalizing their angle squared;
    negative pairs are pushed apart up to ``args.phi_threshold`` radians
    (no penalty once their angle exceeds the threshold).

    Args:
        embeddings: tensor of shape (2B, C, H*W), mean-pooled over H*W.
        n_pos: chunk size of positive pairs (unused; kept for API parity).
        alpha: weight of the positive term vs. (1 - alpha) for the negative.
        args: must provide ``batch_size``, ``ngpus_per_node`` and
            ``phi_threshold``.
        mask: unused; reserved parameter.

    Returns:
        Scalar angular loss.
    """
    geometric_loss = 0
    B_, C, HW = embeddings.shape
    emb = torch.mean(embeddings, dim=-1)           # (2B, C)
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)      # (2B, 2B, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)      # (2B, 2B, C)
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(B_, B_)  # (2B, 2B)
    # Clamp away from +/-1 so acos stays finite and differentiable.
    sim_matrix = torch.clamp(sim_matrix, min=-0.9999, max=0.9999)
    phi = torch.acos(sim_matrix)                   # pairwise angles, (2B, 2B)
    # Embeddings were gathered across GPUs, so the leading dim must be
    # 2 * batch_size * world_size.
    assert (B_ == args.batch_size * 2 * args.ngpus_per_node), \
        "B_ must be 2x batch_size. please check the inputs."
    positive_mask, negative_mask = return_mask(sim_matrix)
    positive_loss = torch.sum((phi**2) * positive_mask) / B_**2
    # Hinge on the angle: only penalize negatives closer than phi_threshold.
    phi_mask = phi < args.phi_threshold
    negative_loss = (args.phi_threshold - phi)**2
    negative_loss = torch.sum(negative_loss * negative_mask * phi_mask) / B_**2
    geometric_loss = alpha * positive_loss + (1 - alpha) * negative_loss
    return geometric_loss


def train(train_loader, model, optimizer, scheduler, scaler, epoch, args):
    """Run one distributed AMP training epoch.

    Combines the model's CE loss with optional angular/distance metric
    losses, logs to wandb on rank 0, and reduces loss/IoU/Prec@50 across
    ranks for display.
    """
    batch_time = AverageMeter('Batch', ':2.2f')
    data_time = AverageMeter('Data', ':2.2f')
    lr = AverageMeter('Lr', ':1.6f')
    loss_meter = AverageMeter('Loss', ':2.4f')
    iou_meter = AverageMeter('IoU', ':2.2f')
    pr_meter = AverageMeter('Prec@50', ':2.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, lr, loss_meter, iou_meter, pr_meter],
        prefix="Training: Epoch=[{}/{}] ".format(epoch, args.epochs))

    metric_learning = args.metric_learning
    mix_distance_angular = args.mix_distance_angular
    positive_strength = args.positive_strength
    # Gaussian ramp-up: the angular loss weight grows toward
    # args.metric_loss_weight as epoch approaches args.epochs.
    angular_loss_weight = args.metric_loss_weight * \
        math.exp(-3.0 * (1 - epoch / args.epochs)**2)
    distance_loss_weight = args.distance_loss_weight

    model.train()
    time.sleep(2)
    end = time.time()

    for i, (image, text, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        # data
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True).unsqueeze(1)
        if i == 1:
            print("Original input shape : ",
                  image.shape, text.shape, target.shape)

        # forward
        with amp.autocast():
            pred, target, CE_loss, metric_tensor = model(image, text, target)
            # Gather metric embeddings from all ranks so pair mining sees
            # the full (2 * batch_size * world_size) batch.
            metric_tensor = concat_all_gather(metric_tensor)
            metric_loss = 0
            if metric_learning:
                metric_loss += angular_loss_weight * AngularMetricLoss(
                    metric_tensor, 2, alpha=positive_strength, args=args)
                if mix_distance_angular:
                    metric_loss += distance_loss_weight * MetricLoss(
                        metric_tensor, 2, alpha=positive_strength, args=args)
            # Normalize by the total weight so the loss scale is comparable
            # whether or not the metric terms are enabled.
            loss = (CE_loss + metric_loss) / \
                (1 + angular_loss_weight * metric_learning +
                 distance_loss_weight * metric_learning * mix_distance_angular)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        if args.max_norm:
            # FIX: gradients must be unscaled before clipping; otherwise
            # clip_grad_norm_ measures GradScaler-scaled gradients and
            # max_norm is effectively meaningless. scaler.step() detects the
            # explicit unscale_ and will not unscale twice.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        scaler.step(optimizer)
        scaler.update()

        # metric: reduce across ranks. all_reduce on loss.detach() updates
        # loss's storage in place (detach shares memory with loss).
        iou, pr5 = trainMetricGPU(pred, target, 0.35, 0.5)
        dist.all_reduce(loss.detach())
        dist.all_reduce(iou)
        dist.all_reduce(pr5)
        loss = loss / dist.get_world_size()
        iou = iou / dist.get_world_size()
        pr5 = pr5 / dist.get_world_size()

        loss_meter.update(loss.item(), image.size(0))
        iou_meter.update(iou.item(), image.size(0))
        pr_meter.update(pr5.item(), image.size(0))
        lr.update(scheduler.get_last_lr()[-1])
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            progress.display(i + 1)
            if dist.get_rank() in [-1, 0]:
                wandb.log(
                    {
                        "time/batch": batch_time.val,
                        "time/data": data_time.val,
                        "training/lr": lr.val,
                        "training/loss": loss_meter.val,
                        "training/iou": iou_meter.val,
                        "training/prec@50": pr_meter.val,
                    },
                    step=epoch * len(train_loader) + (i + 1))
    torch.cuda.empty_cache()


@torch.no_grad()
def validate(val_loader, model, epoch, args):
    """Distributed validation pass.

    Computes per-sample IoU plus overall IoU (sum of intersections over sum
    of unions) and Pr@50..Pr@90, gathering results from all ranks.

    Returns:
        (mean IoU, overall IoU, dict of Pr@k precisions).
    """
    iou_list = []
    I_list = []
    U_list = []
    model.eval()
    time.sleep(16)
    for imgs, texts, masks, param in val_loader:
        # data
        imgs = imgs.cuda(non_blocking=True)
        texts = texts.cuda(non_blocking=True)
        # inference
        preds = model(imgs, texts)
        preds = torch.sigmoid(preds)
        if preds.shape[-2:] != imgs.shape[-2:]:
            # NOTE(review): squeeze(1) happens only on this branch; when the
            # shapes already match, preds keeps its channel dim — verify the
            # model output shape if this path is ever skipped.
            preds = F.interpolate(preds,
                                  size=imgs.shape[-2:],
                                  mode='bicubic',
                                  align_corners=True).squeeze(1)
        # process one batch
        for pred, mask in zip(preds, masks):
            pred = pred.cpu().numpy()
            pred = np.array(pred > 0.35)
            mask = mask.numpy()
            # iou
            inter = np.logical_and(pred, mask)
            union = np.logical_or(pred, mask)
            iou = np.sum(inter) / (np.sum(union) + 1e-6)
            I_list.append(inter)
            U_list.append(union)
            iou_list.append(iou)

    # Gather per-sample stats from every rank.
    iou_list = np.stack(iou_list)
    iou_list = torch.from_numpy(iou_list).to(imgs.device)
    iou_list = concat_all_gather(iou_list)
    I_list = np.stack(I_list)
    I_list = torch.from_numpy(I_list).to(imgs.device)
    I_list = concat_all_gather(I_list)
    U_list = np.stack(U_list)
    U_list = torch.from_numpy(U_list).to(imgs.device)
    U_list = concat_all_gather(U_list)

    overall_I = I_list.sum().item()
    overall_U = U_list.sum().item()
    overall_IoU = overall_I / (overall_U + 1e-6)  # avoid division by zero

    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (iou_list > thres).float().mean()
        prec_list.append(tmp)
    iou = iou_list.mean()
    prec = {}
    temp = ' '
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres * 10)
        value = prec_list[i].item()
        prec[key] = value
        temp += "{}: {:.2f} ".format(key, 100. * value)
    head = 'Evaluation: Epoch=[{}/{}] IoU={:.2f} OIoU={:.4f}'.format(
        epoch, args.epochs, 100. * iou.item(), 100. * overall_IoU)
    logger.info(head + temp)
    torch.cuda.empty_cache()
    # return three results: mIoU, oIoU and prec results
    return iou.item(), overall_IoU, prec


@torch.no_grad()
def inference(test_loader, model, args):
    """Single-process test-set inference with optional visualization dumps.

    Iterates every image and each of its referring sentences, thresholds
    predictions at 0.35, and reports mean IoU, overall IoU and Pr@50..Pr@90.

    Returns:
        (mean IoU, overall IoU, dict of Pr@k precisions).
    """
    iou_list = []
    I_list = []
    U_list = []
    tbar = tqdm(test_loader, desc='Inference:', ncols=100)
    model.eval()
    time.sleep(2)
    for img, mask, param in tbar:
        # data
        img = img.cuda(non_blocking=True)
        mask = mask[0].cpu().numpy()
        # dump image & mask
        if args.visualize:
            seg_id = param['seg_id'][0].cpu().numpy()
            img_name = '{}-img.jpg'.format(seg_id)
            mask_name = '{}-mask.png'.format(seg_id)
            cv2.imwrite(filename=os.path.join(args.vis_dir, img_name),
                        img=param['ori_img'][0].cpu().numpy())
            cv2.imwrite(filename=os.path.join(args.vis_dir, mask_name),
                        img=mask)
        # multiple sentences per image
        for sent in param['sents']:
            text = tokenize(sent, args.word_len, True)
            text = text.cuda(non_blocking=True)
            # inference
            pred = model(img, text)
            pred = torch.sigmoid(pred)
            if pred.shape[-2:] != img.shape[-2:]:
                pred = F.interpolate(pred,
                                     size=img.shape[-2:],
                                     mode='bicubic',
                                     align_corners=True).squeeze()
            # process one sentence
            pred = pred.cpu().numpy()
            pred = np.array(pred > 0.35)
            # iou
            inter = np.logical_and(pred, mask)
            union = np.logical_or(pred, mask)
            iou = np.sum(inter) / (np.sum(union) + 1e-6)
            iou_list.append(iou)
            I_list.append(inter)
            U_list.append(union)
            # dump prediction
            if args.visualize:
                pred = np.array(pred * 255, dtype=np.uint8)
                sent = "_".join(sent[0].split(" "))
                pred_name = '{}-iou={:.2f}-{}.png'.format(
                    seg_id, iou * 100, sent)
                cv2.imwrite(filename=os.path.join(args.vis_dir, pred_name),
                            img=pred)

    logger.info('=> Metric Calculation <=')
    iou_list = np.stack(iou_list)
    iou_list = torch.from_numpy(iou_list).to(img.device)
    I_list = np.stack(I_list)
    I_list = torch.from_numpy(I_list).to(img.device)
    U_list = np.stack(U_list)
    U_list = torch.from_numpy(U_list).to(img.device)
    overall_I = I_list.sum().item()
    overall_U = U_list.sum().item()
    overall_IoU = overall_I / (overall_U + 1e-6)  # avoid division by zero
    prec_list = []
    for thres in torch.arange(0.5, 1.0, 0.1):
        tmp = (iou_list > thres).float().mean()
        prec_list.append(tmp)
    iou = iou_list.mean()
    prec = {}
    for i, thres in enumerate(range(5, 10)):
        key = 'Pr@{}'.format(thres * 10)
        value = prec_list[i].item()
        prec[key] = value
    logger.info('IoU={:.2f} OIoU={:.4f}'.format(100. * iou.item(),
                                                100. * overall_IoU))
    for k, v in prec.items():
        logger.info('{}: {:.2f}.'.format(k, 100. * v))
    return iou.item(), overall_IoU, prec