# test_mosaic_retrieval.py  (VRIS_vip / LAVT-RIS / donghwa)
import argparse  # used by the legacy parse_args() helper below
import datetime
import gc
import operator
import os
import time
from collections import OrderedDict
from functools import reduce

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision
from torch import nn

import albumentations as A
from albumentations.pytorch import ToTensorV2

import transforms as T
import utils
from bert.modeling_bert import BertModel
from data.utils import MosaicVisualization, COCOVisualization
from lib import segmentation

def get_dataset(image_set, transform, args):
    # `transform` is accepted for signature compatibility but unused here:
    # both dataset classes build their own transforms internally.
    if args.dataset == "grefcoco":
        # from data.dataset_grefer import GReferDataset
        from data.dataset_grefer_mosaic_retrieval import GReferDataset
        ds = GReferDataset(args=args,
                           refer_root=args.refer_data_root,
                           dataset_name=args.dataset,
                           splitby=args.splitBy,
                           split=image_set,
                           image_root=os.path.join(args.refer_data_root, 'images/train2014'))
    else:
        # from data.dataset_refer_bert import ReferDataset
        from data.dataset_refer_bert_mosaic_retrieval import ReferDataset
        ds = ReferDataset(args, split=image_set)

    # Dump retrieval mosaics for visual inspection (same path for both datasets).
    fpath = os.path.join('coco-data-vis-retrieval', args.model_id, image_set)
    MosaicVisualization(ds, fpath)

    num_classes = 2
    return ds, num_classes
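
# A minimal usage sketch (illustrative; mirrors the call in main() below):
#   ds, n_cls = get_dataset('val', get_transform(args=args), args=args)
# Note that building the dataset also writes mosaic visualizations under
# 'coco-data-vis-retrieval/<model_id>/<split>' as a side effect.
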
# IoU calculation for validation
def IoU(pred, gt):
    # pred: (B, 2, H, W) two-class logits; gt: (B, H, W) binary mask
    pred = pred.argmax(1)
intersection = torch.sum(torch.mul(pred, gt))
union = torch.sum(torch.add(pred, gt)) - intersection
if intersection == 0 or union == 0:
iou = 0
else:
iou = float(intersection) / float(union)
return iou, intersection, union
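
# A self-contained, illustrative example of the tensor shapes IoU() expects
# (helper added for exposition; it is not called anywhere in the pipeline):
def _iou_example():
    pred = torch.zeros(1, 2, 4, 4)  # (B, 2, H, W) two-class logits
    pred[:, 0] = 1.0                # background logit everywhere
    pred[:, 1, :2, :2] = 2.0        # foreground wins only in the top-left 2x2
    gt = torch.zeros(1, 4, 4, dtype=torch.long)
    gt[:, :2, :] = 1                # ground truth covers the top half (8 px)
    iou, I, U = IoU(pred, gt)       # I = 4, U = 8, so iou = 0.5
    return iou
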
# NOTE: this torchvision-style get_transform is shadowed by the albumentations
# version defined further below (the later definition wins at import time).
def get_transform(args):
    transforms = [T.Resize(args.img_size, args.img_size),
                  T.ToTensor(),
                  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
    return T.Compose(transforms)
def criterion(input, target):
weight = torch.FloatTensor([0.9, 1.1]).cuda()
return nn.functional.cross_entropy(input, target, weight=weight)
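
# A hedged sketch of the shapes criterion() consumes (illustrative only;
# needs a CUDA device because the class weights are created with .cuda()):
def _criterion_example():
    logits = torch.randn(2, 2, 8, 8).cuda()         # (B, 2, H, W) raw scores
    target = torch.randint(0, 2, (2, 8, 8)).cuda()  # (B, H, W) labels in {0, 1}
    return criterion(logits, target)                # foreground weighted 1.1 vs 0.9
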
def evaluate(model, data_loader, bert_model=None):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
total_its = 0
acc_ious = 0
# evaluation variables
cum_I, cum_U = 0, 0
eval_seg_iou_list = [.5, .6, .7, .8, .9]
seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32)
seg_total = 0
mean_IoU = []
with torch.no_grad():
for data in metric_logger.log_every(data_loader, 100, header):
total_its += 1
image, target, sentences, attentions = data['image'], data['seg_target'], data['sentence'], data['attn_mask']
image, target, sentences, attentions = image.cuda(non_blocking=True),\
target.cuda(non_blocking=True),\
sentences.cuda(non_blocking=True),\
attentions.cuda(non_blocking=True)
sentences = sentences.squeeze(1)
attentions = attentions.squeeze(1)
if bert_model is not None:
last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
embedding = last_hidden_states.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy
attentions = attentions.unsqueeze(dim=-1) # (B, N_l, 1)
output = model(image, embedding, l_mask=attentions)
else:
output = model(image, sentences, l_mask=attentions)
iou, I, U = IoU(output, target)
acc_ious += iou
mean_IoU.append(iou)
cum_I += I
cum_U += U
for n_eval_iou in range(len(eval_seg_iou_list)):
eval_seg_iou = eval_seg_iou_list[n_eval_iou]
seg_correct[n_eval_iou] += (iou >= eval_seg_iou)
seg_total += 1
# for GRES
iou = acc_ious / total_its
mean_IoU = np.array(mean_IoU)
mIoU = np.mean(mean_IoU)
print('Final results:')
print('Mean IoU is %.2f\n' % (mIoU * 100.))
results_str = ''
for n_eval_iou in range(len(eval_seg_iou_list)):
results_str += ' precision@%s = %.2f\n' % \
(str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total)
results_str += ' overall IoU = %.2f\n' % (cum_I * 100. / cum_U)
print(results_str)
return 100 * iou, 100 * cum_I / cum_U
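
# Note on the two returned values: 100 * iou is the mean of per-iteration IoUs
# (each batch weighted equally), while 100 * cum_I / cum_U pools intersections
# and unions over all pixels (overall IoU), which weights large objects more.
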
def get_transform(args):
    # NOTE: this later definition overrides the torchvision-style get_transform
    # above; the albumentations pipeline is what actually runs.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    transforms = A.Compose([
        A.Resize(args.img_size, args.img_size, always_apply=True),
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ])
    return transforms
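
# A hedged usage sketch for the albumentations pipeline above (illustrative;
# the SimpleNamespace stands in for the real parsed args):
def _transform_example(size=480):
    from types import SimpleNamespace
    fake_args = SimpleNamespace(img_size=size)
    tfm = get_transform(fake_args)
    img = np.random.randint(0, 256, (600, 400, 3), dtype=np.uint8)  # HWC uint8
    out = tfm(image=img)['image']  # CHW float tensor of shape (3, size, size)
    return out
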
def computeIoU(pred_seg, gd_seg):
I = np.sum(np.logical_and(pred_seg, gd_seg))
U = np.sum(np.logical_or(pred_seg, gd_seg))
return I, U
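
# A small worked example for computeIoU() on binary numpy masks (illustrative):
def _compute_iou_example():
    pred_seg = np.array([[1, 1], [0, 0]])
    gd_seg = np.array([[1, 0], [1, 0]])
    I, U = computeIoU(pred_seg, gd_seg)  # I = 1, U = 3
    return I / U                         # IoU = 1/3
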
def main(args):
    device = 'cuda'  # kept for reference; tensors are moved with .cuda() in evaluate()
    dataset_test, _ = get_dataset(args.split, get_transform(args=args), args=args)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers)
# model initialization
print(args.model)
model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights,
args=args)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# model.cuda()
# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
# single_model = model.module
# make sure embedding layer size matches
model.text_encoder.resize_token_embeddings(len(dataset_test.tokenizer))
model.cuda()
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=True)
# evaluate(model, data_loader_test, device=device)
iou, overallIoU = evaluate(model, data_loader_test)
print('Average object IoU {}'.format(iou))
print('Overall IoU {}'.format(overallIoU))
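
# Typical invocation (hypothetical flag names; see args.get_parser() for the
# authoritative argument set):
#   python test_mosaic_retrieval.py --dataset grefcoco --split val \
#       --resume path/to/checkpoint.pth --config path/to/xxx.yaml
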
# NOTE: legacy helper, currently unused; __main__ below builds args via
# args.get_parser() instead. The imports it needs are made explicit so it
# would run if the commented-out arguments were re-enabled.
def parse_args():
    parser = argparse.ArgumentParser(description='RefCOCO Test')
    # parser.add_argument("--local_rank",
    #                     type=int,
    #                     help='local rank for DistributedDataParallel')
    # parser.add_argument('--config',
    #                     default='path to xxx.yaml',
    #                     type=str,
    #                     help='config file')
    # parser.add_argument('--opts',
    #                     default=None,
    #                     nargs=argparse.REMAINDER,
    #                     help='override some settings in the config.')
    args = parser.parse_args()
    assert args.config is not None  # requires the --config argument above to be re-enabled
    from omegaconf import OmegaConf  # local import: only this legacy path needs omegaconf
    cfg = OmegaConf.load(args.config)
    cfg['local_rank'] = args.local_rank
    return cfg
if __name__ == "__main__":
from args import get_parser
parser = get_parser()
args = parser.parse_args()
    if args.config is not None:
        from config.utils import CfgNode
        cn = CfgNode(CfgNode.load_yaml_with_base(args.config))
        for k, v in cn.items():
            if not hasattr(args, k):
                print('Warning: key %s not in args' % k)
            setattr(args, k, v)
        # Re-parse so explicit command-line flags override YAML-sourced values.
        args = parser.parse_args(namespace=args)
print(args)
print(f'Image size: {args.img_size}')
main(args)