import argparse
import json
import os

import cv2
import numpy as np
from tqdm import tqdm

from refer import REFER
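
# Build referring-expression annotation files with the REFER API: for each
# split, dump one JSON file of boxes, categories, image names and sentences,
# and optionally one binary PNG mask per reference. Illustrative invocation
# (the script filename here is just a placeholder):
#   python prepare_dataset.py --data_root <REFER_DATA_ROOT> \
#       --output_dir <OUTPUT_DIR> --dataset refcoco --split unc --generate_mask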

parser = argparse.ArgumentParser(description='Data preparation')
parser.add_argument('--data_root', type=str)
parser.add_argument('--output_dir', type=str)
parser.add_argument('--dataset',
                    type=str,
                    choices=['refcoco', 'refcoco+', 'refcocog', 'refclef'],
                    default='refcoco')
parser.add_argument('--split', type=str, default='umd')
parser.add_argument('--generate_mask', action='store_true')
args = parser.parse_args()

img_path = os.path.join(args.data_root, 'images', 'train2014')

h, w = (416, 416)

refer = REFER(args.data_root, args.dataset, args.split)

print('dataset [%s_%s] contains: ' % (args.dataset, args.split))
ref_ids = refer.getRefIds()
image_ids = refer.getImgIds()
print('%s expressions for %s refs in %s images.' %
      (len(refer.Sents), len(ref_ids), len(image_ids)))

print('\nAmong them:')
if args.dataset == 'refclef':
    if args.split == 'unc':
        splits = ['train', 'val', 'testA', 'testB', 'testC']
    else:
        splits = ['train', 'val', 'test']
elif args.dataset == 'refcoco':
    splits = ['train', 'val', 'testA', 'testB']
elif args.dataset == 'refcoco+':
    splits = ['train', 'val', 'testA', 'testB']
elif args.dataset == 'refcocog':
    splits = ['train', 'val', 'test']

for split in splits:
    ref_ids = refer.getRefIds(split=split)
    print('%s refs are in split [%s].' % (len(ref_ids), split))


def cat_process(cat):
    # Remap the non-contiguous COCO category ids (1-90, with gaps)
    # to a contiguous 0-79 label range.
    if 1 <= cat <= 11:
        cat = cat - 1
    elif 13 <= cat <= 25:
        cat = cat - 2
    elif 27 <= cat <= 28:
        cat = cat - 3
    elif 31 <= cat <= 44:
        cat = cat - 5
    elif 46 <= cat <= 65:
        cat = cat - 6
    elif cat == 67:
        cat = cat - 7
    elif cat == 70:
        cat = cat - 9
    elif 72 <= cat <= 82:
        cat = cat - 10
    elif 84 <= cat <= 90:
        cat = cat - 11
    return cat
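# For instance, cat_process(1) maps to 0 and cat_process(90) maps to 79,
# assuming the standard COCO id layout (1 = person, ..., 90 = toothbrush).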


def bbox_process(bbox):
    # Convert a COCO-style [x, y, w, h] box into integer
    # [x_min, y_min, x_max, y_max] corner coordinates.
    x_min = int(bbox[0])
    y_min = int(bbox[1])
    x_max = x_min + int(bbox[2])
    y_max = y_min + int(bbox[3])
    return [x_min, y_min, x_max, y_max]
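# For example, bbox_process([10, 20, 30, 40]) returns [10, 20, 40, 60].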


def prepare_dataset(dataset, splits, output_dir, generate_mask=False):
    ann_path = os.path.join(output_dir, 'anns', dataset)
    mask_path = os.path.join(output_dir, 'masks', dataset)
    if not os.path.exists(ann_path):
        os.makedirs(ann_path)
    if not os.path.exists(mask_path):
        os.makedirs(mask_path)

    for split in splits:
        dataset_array = []
        ref_ids = refer.getRefIds(split=split)
        print('Processing split: {} - Len: {}'.format(split, len(ref_ids)))
        for ref_id in tqdm(ref_ids):
            ref_dict = {}

            ref = refer.Refs[ref_id]
            bbox = refer.getRefBox(ref_id)
            sentences = ref['sentences']
            image_info = refer.loadImgs(image_ids=ref['image_id'])[0]
            cat = cat_process(ref['category_id'])
            image_name = image_info['file_name']
            # These refclef images are skipped entirely.
            if dataset == 'refclef' and image_name in [
                    '19579.jpg', '17975.jpg', '19575.jpg'
            ]:
                continue
            box_info = bbox_process(bbox)

            ref_dict['bbox'] = box_info
            ref_dict['cat'] = cat
            ref_dict['segment_id'] = ref_id
            ref_dict['img_name'] = image_name

            if generate_mask:
                # Save the binary segmentation mask as an 8-bit PNG.
                cv2.imwrite(os.path.join(mask_path, str(ref_id) + '.png'),
                            refer.getMask(ref)['mask'] * 255)

            sent_dict = []
            for idx, sent in enumerate(sentences):
                sent_dict.append({
                    'idx': idx,
                    'sent_id': sent['sent_id'],
                    'sent': sent['sent'].strip()
                })

            ref_dict['sentences'] = sent_dict
            ref_dict['sentences_num'] = len(sent_dict)

            dataset_array.append(ref_dict)
        print('Dumping json file...')
        with open(os.path.join(ann_path, split + '.json'), 'w') as f:
            json.dump(dataset_array, f)


prepare_dataset(args.dataset, splits, args.output_dir, args.generate_mask)
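
# The call above writes, per split:
#   <output_dir>/anns/<dataset>/<split>.json - a list of records with keys
#       'bbox', 'cat', 'segment_id', 'img_name', 'sentences', 'sentences_num'
#   <output_dir>/masks/<dataset>/<segment_id>.png - only with --generate_mask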