"""Evaluation script for referring-expression segmentation (RefCOCO / gRefCOCO).

Builds the requested test split, restores a trained checkpoint, runs inference
and reports mean IoU, overall (cumulative) IoU and precision@{0.5 .. 0.9}.
"""
import argparse
import datetime
import gc
import operator
import os
import time
from collections import OrderedDict
from functools import reduce

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision
from torch import nn

import albumentations as A
from albumentations.pytorch import ToTensorV2

import transforms as T
import utils
from bert.modeling_bert import BertModel
from data.utils import COCOVisualization, MosaicVisualization
from lib import segmentation


def get_dataset(image_set, transform, args):
    """Build the evaluation dataset for ``args.dataset``.

    Args:
        image_set: split name (e.g. 'val', 'testA').
        transform: unused — the dataset classes apply their own transforms
            internally; kept for signature compatibility with callers.
        args: parsed command-line/config namespace.

    Returns:
        Tuple ``(dataset, num_classes)``; ``num_classes`` is always 2
        (background vs. referred object).
    """
    if args.dataset == "grefcoco":
        from data.dataset_grefer_mosaic_retrieval import GReferDataset
        ds = GReferDataset(args=args,
                           refer_root=args.refer_data_root,
                           dataset_name=args.dataset,
                           splitby=args.splitBy,
                           split=image_set,
                           image_root=os.path.join(args.refer_data_root,
                                                   'images/train2014'))
    else:
        from data.dataset_refer_bert_mosaic_retrieval import ReferDataset
        ds = ReferDataset(args, split=image_set)

    # Dump sample visualizations for manual inspection (identical in both
    # branches of the original, so hoisted out of the if/else).
    fpath = os.path.join('coco-data-vis-retrieval', args.model_id, image_set)
    MosaicVisualization(ds, fpath)

    num_classes = 2
    return ds, num_classes


# IoU calculation for validation
def IoU(pred, gt):
    """IoU between 2-class logits and a binary ground-truth mask.

    Args:
        pred: logit tensor; ``argmax`` over dim 1 yields the binary mask
            (assumes shape (B, 2, H, W) — consistent with ``num_classes=2``).
        gt: binary mask tensor of matching spatial size.

    Returns:
        ``(iou, intersection, union)`` — ``iou`` is a Python float; the other
        two are scalar tensors suitable for cumulative statistics.
    """
    pred = pred.argmax(1)
    intersection = torch.sum(torch.mul(pred, gt))
    union = torch.sum(torch.add(pred, gt)) - intersection
    # Define IoU as 0 when either term is empty; also guards against 0/0.
    if intersection == 0 or union == 0:
        iou = 0
    else:
        iou = float(intersection) / float(union)
    return iou, intersection, union


def criterion(input, target):
    """Class-weighted cross-entropy; slightly up-weights the foreground class."""
    weight = torch.FloatTensor([0.9, 1.1]).cuda()
    return nn.functional.cross_entropy(input, target, weight=weight)


def evaluate(model, data_loader, bert_model=None):
    """Run inference over ``data_loader`` and print/return IoU statistics.

    Args:
        model: segmentation model; called as ``model(image, lang, l_mask=...)``.
        bert_model: optional external text encoder. When given, sentences are
            embedded here and the embedding is passed to ``model``; otherwise
            the raw token ids are passed through (model encodes internally).

    Returns:
        ``(mean_iou_pct, overall_iou_pct)`` — per-sample mean IoU and
        cumulative-I/cumulative-U, both scaled to percent.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'
    total_its = 0
    acc_ious = 0

    # evaluation variables
    cum_I, cum_U = 0, 0
    eval_seg_iou_list = [.5, .6, .7, .8, .9]
    seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
    seg_total = 0
    mean_IoU = []

    with torch.no_grad():
        for data in metric_logger.log_every(data_loader, 100, header):
            total_its += 1
            image, target, sentences, attentions = (data['image'],
                                                    data['seg_target'],
                                                    data['sentence'],
                                                    data['attn_mask'])
            image = image.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            sentences = sentences.cuda(non_blocking=True)
            attentions = attentions.cuda(non_blocking=True)
            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)

            if bert_model is not None:
                last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
                # (B, 768, N_l) to make Conv1d happy
                embedding = last_hidden_states.permute(0, 2, 1)
                attentions = attentions.unsqueeze(dim=-1)  # (B, N_l, 1)
                output = model(image, embedding, l_mask=attentions)
            else:
                output = model(image, sentences, l_mask=attentions)

            iou, I, U = IoU(output, target)
            acc_ious += iou
            mean_IoU.append(iou)
            cum_I += I
            cum_U += U
            for thr_idx, thr in enumerate(eval_seg_iou_list):
                seg_correct[thr_idx] += (iou >= thr)
            seg_total += 1

    # for GRES
    iou = acc_ious / total_its
    mIoU = np.mean(np.array(mean_IoU))
    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))
    results_str = ''
    for thr_idx, thr in enumerate(eval_seg_iou_list):
        results_str += ' precision@%s = %.2f\n' % \
            (str(thr), seg_correct[thr_idx] * 100. / seg_total)
    results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
    print(results_str)

    return 100 * iou, 100 * cum_I / cum_U


def get_transform(args):
    """Resize / normalize / to-tensor pipeline (albumentations).

    NOTE(review): the original file defined an earlier ``get_transform``
    built on the local ``transforms`` module; it was silently shadowed by
    this definition and has been removed as dead code.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    transforms = A.Compose([
        A.Resize(args.img_size, args.img_size, always_apply=True),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])
    return transforms


def computeIoU(pred_seg, gd_seg):
    """Return (intersection, union) pixel counts of two binary numpy masks."""
    I = np.sum(np.logical_and(pred_seg, gd_seg))
    U = np.sum(np.logical_or(pred_seg, gd_seg))
    return I, U


def main(args):
    """Build the test loader, restore the checkpoint and run evaluation."""
    dataset_test, _ = get_dataset(args.split, get_transform(args=args), args=args)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler,
        num_workers=args.workers)

    # model initialization
    print(args.model)
    model = segmentation.__dict__[args.model](
        pretrained=args.pretrained_swin_weights, args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Make sure the embedding layer size matches the tokenizer vocabulary
    # (extra tokens may have been added at training time).
    model.text_encoder.resize_token_embeddings(len(dataset_test.tokenizer))
    model.cuda()

    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=True)

    iou, overallIoU = evaluate(model, data_loader_test)
    print('Average object IoU {}'.format(iou))
    print('Overall IoU {}'.format(overallIoU))


def parse_args():
    """Legacy OmegaConf-based CLI parser.

    NOTE(review): dead code — the ``__main__`` block below builds its parser
    via ``args.get_parser`` instead, and ``OmegaConf`` is referenced here
    without ever being imported (NameError if called). All of its arguments
    were commented out in the original. Kept only for reference; fix or
    delete before use.
    """
    parser = argparse.ArgumentParser(description='RefCOCO Test')
    args = parser.parse_args()
    assert args.config is not None
    cfg = OmegaConf.load(args.config)  # noqa: F821 — OmegaConf never imported
    cfg['local_rank'] = args.local_rank
    return cfg


if __name__ == "__main__":
    from args import get_parser

    parser = get_parser()
    args = parser.parse_args()
    if args.config is not None:
        # Merge YAML config keys the CLI parser does not define, warn on
        # unknown keys, then re-parse so explicit CLI flags take precedence.
        from config.utils import CfgNode
        cn = CfgNode(CfgNode.load_yaml_with_base(args.config))
        for k, v in cn.items():
            if not hasattr(args, k):
                print('Warning: key %s not in args' % k)
            setattr(args, k, v)
        args = parser.parse_args(namespace=args)
    print(args)
    print(f'Image size: {args.img_size}')
    main(args)