"""Distributed training and validation script for a language-conditioned
segmentation model, with an optional angular contrastive loss computed on
paired ("hard positive") sentences."""

import datetime
import gc
import operator
import os
import time
from collections import OrderedDict
from functools import reduce

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriter

import transforms as T
import utils
from bert.modeling_bert import BertModel
from data.dataset_refer_bert_rev import ReferDataset
from data.dataset_refer_zom import Referzom_Dataset, Refzom_DistributedSampler
from lib import segmentation

# Degrees-per-radian factor used to convert the angular margin. Kept at the
# original precision so loss values stay reproducible.
_DEG_PER_RAD = 57.2958


def get_dataset(image_set, transform, args, eval_mode=False):
    """Build the dataset for split ``image_set``.

    Args:
        image_set: Split name forwarded to the dataset class.
        transform: Image transform pipeline forwarded as ``image_transforms``.
        args: Parsed options; ``args.dataset`` selects the dataset class.
        eval_mode: Forwarded to ``Referzom_Dataset`` only.

    Returns:
        ``(dataset, num_classes)`` where ``num_classes`` is always 2.
    """
    if args.dataset == 'ref-zom':
        ds = Referzom_Dataset(args,
                              split=image_set,
                              image_transforms=transform,
                              target_transforms=None,
                              eval_mode=eval_mode)
    else:
        # NOTE: this branch ignores ``eval_mode`` and derives the flag from
        # the split name instead.
        ds = ReferDataset(args,
                          split=image_set,
                          image_transforms=transform,
                          target_transforms=None,
                          eval_mode=image_set == 'val')
    num_classes = 2
    return ds, num_classes


# IoU calculation for validation
def IoU(pred, gt):
    """Return ``(iou, intersection, union)`` for class logits and a mask.

    ``pred`` is reduced with ``argmax(1)`` before comparison with ``gt``.
    ``iou`` is a Python number (0 when intersection or union is zero);
    ``intersection`` and ``union`` are the summed tensors.
    """
    pred = pred.argmax(1)
    intersection = torch.sum(torch.mul(pred, gt))
    union = torch.sum(torch.add(pred, gt)) - intersection
    if intersection == 0 or union == 0:
        iou = 0
    else:
        iou = float(intersection) / float(union)
    return iou, intersection, union


def get_transform(args):
    """Return the resize / to-tensor / normalize pipeline for ``args.img_size``."""
    transforms = [
        T.Resize(args.img_size, args.img_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return T.Compose(transforms)


def criterion(input, target):
    """Class-weighted cross entropy with weights ``[0.9, 1.1]``.

    The weight tensor is created on ``input``'s device, so the function no
    longer requires CUDA; on a GPU input the result is unchanged.
    """
    weight = torch.tensor([0.9, 1.1], dtype=torch.float32, device=input.device)
    return nn.functional.cross_entropy(input, target, weight=weight)


def return_mask(emb_distance, verb_mask=None):
    """Build positive / negative pair masks for a square similarity matrix.

    The diagonal is always positive. Additional positives are:

    * when the matrix is smaller than ``verb_mask`` (rows were filtered down
      to the paired entries): every consecutive pair ``(2i, 2i + 1)``;
    * otherwise: ``(i, i + 1)`` for each position ``i`` where
      ``verb_mask[i] == 1``, after which the walk skips two entries.

    Despite the ``None`` default, ``verb_mask`` must be provided because
    ``len(verb_mask)`` is evaluated unconditionally.

    Returns:
        ``(positive_mask, negative_mask)``, both shaped like ``emb_distance``,
        with ``negative_mask = 1 - positive_mask``.
    """
    _, B_ = emb_distance.shape
    positive_mask = torch.zeros_like(emb_distance)
    positive_mask.fill_diagonal_(1)  # an entry is always positive with itself

    if B_ < len(verb_mask):
        # Matrix holds only the paired entries: (0, 1), (2, 3), ...
        first = torch.arange(0, B_ - 1, 2, device=emb_distance.device)
        positive_mask[first, first + 1] = 1
        positive_mask[first + 1, first] = 1
    else:
        # Mixed batch: walk the mask and link each flagged entry to its
        # immediate successor.
        i = 0
        while i < B_:
            if verb_mask[i] == 1:
                positive_mask[i, i + 1] = 1
                positive_mask[i + 1, i] = 1
                i += 2
            else:
                i += 1

    negative_mask = torch.ones_like(emb_distance) - positive_mask
    return positive_mask, negative_mask


def _pooled_embeddings(total_fq, verb_mask, verbonly):
    """Spatially average a 4-D feature tensor into one vector per sample.

    With ``verbonly`` only the rows selected by ``verb_mask`` are kept and
    their count must be even.
    """
    _, C, _, _ = total_fq.shape
    if verbonly:
        selected = total_fq[verb_mask]
        emb = torch.mean(selected, dim=(-1, -2)).reshape(selected.shape[0], C)
        assert emb.shape[0] % 2 == 0, f"Embedding count {emb.shape[0]} is not divisible by 2."
    else:
        # Pool over both spatial axes; pooling over the last axis alone left
        # a 3-D tensor that the pairwise similarity step could not handle.
        emb = torch.mean(total_fq, dim=(-1, -2))
    return emb


def _pairwise_cosine_similarity(emb):
    """Return the (n, n) cosine-similarity matrix of ``emb``, clamped to
    [-0.9999, 0.9999] so that ``acos`` stays finite and differentiable."""
    n = emb.shape[0]
    # expand() produces broadcast views instead of materialising n copies.
    emb_i = emb.unsqueeze(1).expand(-1, n, -1)  # (n, n, C)
    emb_j = emb.unsqueeze(0).expand(n, -1, -1)  # (n, n, C)
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j).reshape(n, n)
    return torch.clamp(sim_matrix, min=-0.9999, max=0.9999)


def UniAngularContrastLoss(total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Contrastive loss with an additive angular margin on cosine logits.

    Positive pairs (see ``return_mask``) have ``m`` degrees added to their
    angle before the cosine is taken; logits are divided by ``tau``.

    Args:
        total_fq: 4-D feature tensor ``(B, C, H, W)``.
        verb_mask: Boolean mask over the batch marking paired entries.
        alpha: Unused; kept for interface compatibility.
        verbonly: Restrict the loss to entries selected by ``verb_mask``.
        m: Angular margin in degrees.
        tau: Softmax temperature.
        args: Unused; kept for interface compatibility.

    Returns:
        Scalar loss tensor; ``0.0`` when no embedding is selected.
    """
    emb = _pooled_embeddings(total_fq, verb_mask, verbonly)
    sim_matrix = _pairwise_cosine_similarity(emb)
    positive_mask, _ = return_mask(sim_matrix, verb_mask)

    if len(positive_mask) == 0:
        return torch.tensor(0.0, device=total_fq.device)

    is_positive = positive_mask.bool()
    sim_matrix_with_margin = sim_matrix.clone()
    sim_matrix_with_margin[is_positive] = torch.cos(
        torch.acos(sim_matrix[is_positive]) + m / _DEG_PER_RAD)

    exp_logits = torch.exp(sim_matrix_with_margin / tau)
    pos_exp_logits = (exp_logits * positive_mask).sum(dim=-1)
    total_exp_logits = exp_logits.sum(dim=-1)
    positive_loss = -torch.log(pos_exp_logits / total_exp_logits)
    return positive_loss.mean()


def UniAngularLogitContrastLoss(total_fq, verb_mask, alpha=0.5, verbonly=True, m=0.5, tau=0.05, args=None):
    """Contrastive loss whose logits are angles rather than cosines.

    Each similarity is mapped to ``pi / 2 - acos(sim)``; positive pairs have
    the margin (``m`` degrees, converted to radians) subtracted before the
    logits are divided by ``tau``.

    Args:
        total_fq: 4-D feature tensor ``(B, C, H, W)``.
        verb_mask: Boolean mask over the batch marking paired entries.
        alpha: Unused; kept for interface compatibility.
        verbonly: Restrict the loss to entries selected by ``verb_mask``.
        m: Angular margin in degrees.
        tau: Softmax temperature.
        args: Unused; kept for interface compatibility.

    Returns:
        Scalar loss tensor; ``0.0`` when no embedding is selected (the mean
        of an empty tensor would otherwise be NaN).
    """
    emb = _pooled_embeddings(total_fq, verb_mask, verbonly)
    if emb.shape[0] == 0:
        return torch.tensor(0.0, device=total_fq.device)

    sim_matrix = _pairwise_cosine_similarity(emb)
    margin_in_radians = m / _DEG_PER_RAD  # convert degrees to radians
    theta_matrix = (torch.pi / 2) - torch.acos(sim_matrix)

    positive_mask, _ = return_mask(sim_matrix, verb_mask)
    theta_with_margin = theta_matrix.clone()
    theta_with_margin[positive_mask.bool()] -= margin_in_radians

    exp_logits = torch.exp(theta_with_margin / tau)
    pos_exp_logits = (exp_logits * positive_mask).sum(dim=-1)
    total_exp_logits = exp_logits.sum(dim=-1)
    loss = -torch.log(pos_exp_logits / total_exp_logits)
    return loss.mean()


def evaluate(model, data_loader, bert_model, args=None):
    """Run validation and print / return IoU statistics.

    Samples whose ``source_type[0]`` is ``'zero'`` are scored by whether the
    prediction is entirely background; all other samples contribute to mean
    IoU, cumulative (overall) IoU and precision at IoU thresholds 0.5-0.9.

    Args:
        model: Segmentation model.
        data_loader: Loader yielding
            ``(image, target, source_type, sentences, attentions)``.
        bert_model: External language model, or ``None`` when the
            segmentation model embeds the sentences itself.
        args: Optional parsed options. When given, the zero-target accuracy
            line is printed iff ``args.dataset == 'ref-zom'``; when omitted it
            is printed iff any zero-target sample was seen. (Previously a
            module-level ``args`` global was required.)

    Returns:
        ``(mIoU, overall_iou_percent, precs)``. Statistics with no
        contributing samples are reported as 0 instead of dividing by zero.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # evaluation variables
    cum_I, cum_U = 0, 0
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    mean_IoU = []
    mean_acc = []

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            image, target, source_type, sentences, attentions = data
            image = image.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            sentences = sentences.cuda(non_blocking=True)
            attentions = attentions.cuda(non_blocking=True)

            sentences = sentences.squeeze(-1)
            attentions = attentions.squeeze(-1)

            if bert_model is not None:
                last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
                embedding = last_hidden_states.permute(0, 2, 1)  # swap the last two axes
                attentions = attentions.unsqueeze(-1)  # add a trailing axis
                output = model(image, embedding, l_mask=attentions)
            else:
                output = model(image, sentences, l_mask=attentions, is_train=False)

            if source_type[0] == 'zero':
                # Zero-target sample: correct only if nothing is predicted.
                pred = output.argmax(1)
                incorrect_num = torch.sum(pred).item()
                mean_acc.append(1 if incorrect_num == 0 else 0)
            else:
                this_iou, I, U = IoU(output, target)
                mean_IoU.append(this_iou)
                cum_I += I
                cum_U += U
                for n_eval_iou, eval_seg_iou in enumerate(eval_seg_iou_list):
                    seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou)
                seg_total += 1

    # Guard every ratio against an empty denominator.
    mIoU = np.mean(mean_IoU) if mean_IoU else 0.0
    zero_acc = np.mean(mean_acc) if mean_acc else 0.0
    if seg_total > 0 and cum_U != 0:
        overall_iou = 100 * cum_I / cum_U
    else:
        overall_iou = 0.0
    precs = [seg_correct[n_eval_iou] * 100. / max(seg_total, 1)
             for n_eval_iou in range(len(eval_seg_iou_list))]

    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''
    for eval_seg_iou, prec in zip(eval_seg_iou_list, precs):
        results_str += ' precision@%s = %.2f\n' % (str(eval_seg_iou), prec)
    results_str += ' overall IoU = %.2f\n' % overall_iou
    results_str += ' mean IoU = %.2f\n' % (mIoU * 100.)
    print(results_str)

    if args is not None:
        report_zero_acc = args.dataset == 'ref-zom'
    else:
        report_zero_acc = bool(mean_acc)
    if report_zero_acc:
        print('Mean accuracy for one-to-zero sample is %.2f\n' % (zero_acc * 100))

    return mIoU, overall_iou, precs


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, bert_model, metric_learning=False, args=None):
    """Train ``model`` for one epoch.

    Each batch is expanded so that every sample flagged as a hard positive is
    followed by a duplicate (same image and target) carrying its positive
    sentence. Cross entropy is computed on the original samples only
    (``cl_masks``); the metric loss, when enabled, is computed on the paired
    samples (``verb_masks``).

    Args:
        model, criterion, optimizer, lr_scheduler: Training components.
        data_loader: Loader yielding ``(image, target, source_type,
            sentences, attentions, pos_sent, pos_attn_mask, pos_type)``.
        epoch: Epoch index, used for the log header.
        print_freq: Logging interval in iterations.
        iterations: Running iteration counter; returned incremented.
        bert_model: External language model, or ``None``.
        metric_learning: Enable the angular metric loss. It is only applied
            when the model returns metric tensors (``bert_model is None``)
            and the batch has at least 3 hard positives.
        args: Parsed options; required in practice (``metric_loss_weight``,
            ``addzero``, ``margin_value`` and ``temperature`` are read).

    Returns:
        ``(iterations, loss_log)`` with ``loss_log['loss']`` the global
        average loss of the epoch.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    header = 'Epoch: [{}]'.format(epoch)

    mlw = args.metric_loss_weight if metric_learning else 0

    for data in metric_logger.log_every(data_loader, print_freq, header):
        image, target, source_type, sentences, attentions, pos_sent, pos_attn_mask, pos_type = data

        # Both fields are converted to arrays so the string comparisons below
        # are element-wise; comparing a plain list with a string yields a
        # single False and would disable every hard positive.
        source_type = np.asarray(source_type)
        pos_type = np.asarray(pos_type)
        is_hardpos = pos_type == 'hardpos'
        if args.addzero:
            hardpos_flag = is_hardpos
        else:
            # default option for training: only pair non-zero-target samples
            hardpos_flag = (source_type != 'zero') & is_hardpos

        sentences = sentences.squeeze(1)
        attentions = attentions.squeeze(1)
        pos_sent = pos_sent.squeeze(1)
        pos_attn_mask = pos_attn_mask.squeeze(1)

        # Expand the batch: original sample, then (if flagged) its positive.
        verb_masks = []
        cl_masks = []
        images = []
        targets = []
        sentences_ = []
        attentions_ = []
        for idx in range(len(image)):
            images.append(image[idx])
            targets.append(target[idx])
            sentences_.append(sentences[idx])
            attentions_.append(attentions[idx])
            if hardpos_flag[idx]:
                verb_masks.extend([1, 1])
                cl_masks.extend([1, 0])  # the duplicate is excluded from CE
                images.append(image[idx])
                targets.append(target[idx])
                sentences_.append(pos_sent[idx])
                attentions_.append(pos_attn_mask[idx])
            else:
                verb_masks.append(0)
                cl_masks.append(1)

        image = torch.stack(images).cuda(non_blocking=True)
        target = torch.stack(targets).cuda(non_blocking=True)
        sentences = torch.stack(sentences_).cuda(non_blocking=True)
        attentions = torch.stack(attentions_).cuda(non_blocking=True)
        verb_masks = torch.tensor(verb_masks, dtype=torch.bool, device='cuda')
        cl_masks = torch.tensor(cl_masks, dtype=torch.bool, device='cuda')

        # Only the model-internal text encoder path returns metric tensors.
        metric_tensors = None
        if bert_model is not None:
            last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
            embedding = last_hidden_states.permute(0, 2, 1)  # swap the last two axes
            attentions = attentions.unsqueeze(dim=-1)  # add a trailing axis
            output = model(image, embedding, l_mask=attentions)
        else:
            output, metric_tensors = model(image, sentences, l_mask=attentions)

        ce_loss = criterion(output[cl_masks], target[cl_masks])

        metric_loss = 0
        divn = 1
        if metric_learning and metric_tensors is not None and int(hardpos_flag.sum()) >= 3:
            metric_loss = UniAngularLogitContrastLoss(metric_tensors, verb_masks,
                                                      m=args.margin_value,
                                                      tau=args.temperature,
                                                      verbonly=True, args=args)
            divn += mlw  # normalise by (1 + mlw)

        loss = (ce_loss + metric_loss * mlw) / divn

        optimizer.zero_grad()  # set_to_none=True is only available in pytorch 1.6+
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        torch.cuda.synchronize()
        iterations += 1
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

        del image, target, sentences, attentions, loss, output, data
        if bert_model is not None:
            del last_hidden_states, embedding

        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    loss_log = {
        'loss': metric_logger.meters['loss'].global_avg
    }
    return iterations, loss_log


def main(args):
    """Build datasets, models and optimizer, then train and validate.

    After every epoch the model is evaluated; whenever overall IoU improves,
    a checkpoint is written to
    ``<args.output_dir>/model_best_<args.model_id>.pth``. Validation metrics
    and the training loss are logged to TensorBoard from the main process.
    """
    # Only the main process logs, so only it creates the writer.
    writer = None
    if utils.is_main_process():
        writer = SummaryWriter('./experiments/{}/{}'.format(
            "_".join([args.dataset, args.splitBy]), args.model_id))

    dataset, num_classes = get_dataset("train",
                                       get_transform(args=args),
                                       args=args,
                                       eval_mode=False)
    dataset_test, _ = get_dataset(args.split,
                                  get_transform(args=args),
                                  args=args,
                                  eval_mode=True)

    # batch sampler
    print(f"local rank {args.local_rank} / global rank {utils.get_rank()} successfully built train dataset.")
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.dataset == 'ref-zom':
        train_sampler = Refzom_DistributedSampler(dataset, num_replicas=num_tasks,
                                                  rank=global_rank, shuffle=True)
    else:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    # data loader
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers,
        pin_memory=args.pin_mem, drop_last=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers)

    # model initialization
    print(args.model)
    model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights,
                                              args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                      find_unused_parameters=True)
    single_model = model.module

    if args.model != 'lavt_one':
        model_class = BertModel
        bert_model = model_class.from_pretrained(args.ck_bert)
        # a work-around for a bug in Transformers = 3.0.2 that appears for DistributedDataParallel
        bert_model.pooler = None
        bert_model.cuda()
        bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
        bert_model = torch.nn.parallel.DistributedDataParallel(bert_model,
                                                               device_ids=[args.local_rank])
        single_bert_model = bert_model.module
    else:
        bert_model = None
        single_bert_model = None

    # resume training
    if args.resume:
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints you trust.
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])
        if args.model != 'lavt_one':
            single_bert_model.load_state_dict(checkpoint['bert_model'])

    # parameters to optimize
    backbone_no_decay = list()
    backbone_decay = list()
    for name, m in single_model.backbone.named_parameters():
        if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
            backbone_no_decay.append(m)
        else:
            backbone_decay.append(m)

    # The text encoder is either the external BERT or the one embedded in
    # the segmentation model; only its first 10 encoder layers are optimized.
    if args.model != 'lavt_one':
        text_encoder = single_bert_model
    else:
        text_encoder = single_model.text_encoder
    params_to_optimize = [
        {'params': backbone_no_decay, 'weight_decay': 0.0},
        {'params': backbone_decay},
        {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
        # the following are the parameters of bert
        {"params": reduce(operator.concat,
                          [[p for p in text_encoder.encoder.layer[i].parameters()
                            if p.requires_grad] for i in range(10)])},
    ]

    # optimizer
    optimizer = torch.optim.AdamW(params_to_optimize,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad)

    # learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)

    # housekeeping
    start_time = time.time()
    iterations = 0
    best_oIoU = -0.1

    # resume training (optimizer, lr scheduler, and the epoch)
    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        resume_epoch = checkpoint['epoch']
        # Restore the best score when the checkpoint carries it, so a worse
        # epoch after resuming cannot overwrite the best model. Older
        # checkpoints without the key keep the previous behaviour.
        best_oIoU = checkpoint.get('best_oIoU', best_oIoU)
    else:
        resume_epoch = -999

    # training loops
    for epoch in range(max(0, resume_epoch + 1), args.epochs):
        data_loader.sampler.set_epoch(epoch)
        iterations, loss_log = train_one_epoch(model, criterion, optimizer, data_loader,
                                               lr_scheduler, epoch, args.print_freq,
                                               iterations, bert_model,
                                               metric_learning=args.metric_learning,
                                               args=args)
        iou, overallIoU, precs = evaluate(model, data_loader_test, bert_model, args=args)

        print('Average object IoU {}'.format(iou))
        print('Overall IoU {}'.format(overallIoU))

        save_checkpoint = (best_oIoU < overallIoU)
        if save_checkpoint:
            print('Better epoch: {}\n'.format(epoch))
            dict_to_save = {'model': single_model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'epoch': epoch,
                            'args': args,
                            'lr_scheduler': lr_scheduler.state_dict(),
                            'best_oIoU': float(overallIoU)}
            if single_bert_model is not None:
                dict_to_save['bert_model'] = single_bert_model.state_dict()

            utils.save_on_master(dict_to_save,
                                 os.path.join(args.output_dir,
                                              'model_best_{}.pth'.format(args.model_id)))
            best_oIoU = overallIoU

        if writer is not None:
            writer.add_scalar('val/mIoU', iou, epoch)
            writer.add_scalar('val/oIoU', overallIoU, epoch)
            writer.add_scalar('val/Prec/50', precs[0], epoch)
            writer.add_scalar('val/Prec/60', precs[1], epoch)
            writer.add_scalar('val/Prec/70', precs[2], epoch)
            writer.add_scalar('val/Prec/80', precs[3], epoch)
            writer.add_scalar('val/Prec/90', precs[4], epoch)
            writer.add_scalar('train/loss', loss_log['loss'], epoch)
            writer.flush()

    # summarize
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    if writer is not None:
        writer.close()


if __name__ == "__main__":
    from args import get_parser

    parser = get_parser()
    args = parser.parse_args()

    # set up distributed learning
    if "LOCAL_RANK" in os.environ:
        local_rank = int(os.environ["LOCAL_RANK"])
    else:
        local_rank = 0  # Default value for non-distributed mode
    print(f"Local Rank: {local_rank}, World Size: {os.environ.get('WORLD_SIZE', '1')}")

    utils.init_distributed_mode(args)
    print('Image size: {}'.format(str(args.img_size)))
    print('Metric Learning Ops')
    print('metric learning flag : ', args.metric_learning)
    print('metric loss weight : ', args.metric_loss_weight)
    print('metric mode and hardpos selection : ', args.metric_mode, args.hp_selection)
    print('margin value : ', args.margin_value)
    print('temperature : ', args.temperature)
    print('add zero in ACE loss : ', args.addzero)
    print(args)
    main(args)