"""
This file parses gRefCOCO-format annotations into dicts in "Detectron2 format"
and serves them through the GReferDataset class below.
"""
import contextlib
import copy
import io
import logging
import math
import os
import random
import time

import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.utils.data as data
from fvcore.common.timer import Timer
from PIL import Image, ImageDraw, ImageFilter

from detectron2.config import configurable
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
from detectron2.utils.file_io import PathManager

from bert.tokenization_bert import BertTokenizer
from pycocotools import mask as coco_mask
from data.utils import convert_coco_poly_to_mask, build_transform_train, build_transform_test
from .utils import cosine_annealing
import albumentations as A
from albumentations.pytorch import ToTensorV2
import lmdb
import pyarrow as pa

def loads_pyarrow(buf):
    # Deserialize a raw LMDB value written with pyarrow.
    return pa.deserialize(buf)
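

# Write-side counterpart, shown for reference only (this file never writes the
# LMDB). Note that pa.serialize/pa.deserialize are deprecated in recent pyarrow
# releases, so this assumes the pinned pyarrow version that produced the database.
# def dumps_pyarrow(obj):
#     return pa.serialize(obj).to_buffer()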


logger = logging.getLogger(__name__)

__all__ = ["GReferDataset"]
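
# Shape of each entry in GReferDataset.dataset_dicts built below; values are
# illustrative only, not taken from a real sample:
#
#   {
#       "id": 0,                      # running index over (img, ref, anns) triples
#       "source": "grefcoco",
#       "file_name": "<image_root>/COCO_train2014_....jpg",
#       "height": 480, "width": 640,
#       "image_id": 123,
#       "empty": False,               # True for "no-target" expressions
#       "annotations": [{"iscrowd": 0, "bbox": [...], "category_id": 1,
#                        "bbox_mode": BoxMode.XYWH_ABS, "empty": False,
#                        "segmentation": [...]}],
#       "sentence": {"raw": "the left zebra", "sent_id": 0, "ref_id": 42},
#   }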


class GReferDataset(data.Dataset):
    def __init__(self,
                 args, refer_root, dataset_name, splitby, split, image_root,
                 img_format="RGB", merge=True,
                 extra_annotation_keys=None, extra_refer_keys=None):

        self.args = args  # the epoch schedule in __getitem__ reads args.epochs
        self.refer_root = refer_root
        self.dataset_name = dataset_name
        self.splitby = splitby
        self.split = split
        self.image_root = image_root
        self.extra_annotation_keys = extra_annotation_keys
        self.extra_refer_keys = extra_refer_keys
        self.img_format = img_format
        self.merge = merge
        self.max_tokens = 20
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

        if split == "train":
            self.tfm_gens = build_transform_train(args)
        elif split in ["val", "test", "testA", "testB"]:
            self.tfm_gens = build_transform_test(args)

        if self.dataset_name == 'refcocop':
            self.dataset_name = 'refcoco+'
        if self.dataset_name == 'refcoco' or self.dataset_name == 'refcoco+':
            # refcoco/refcoco+ only ship with the 'unc' split.
            assert self.splitby == 'unc'
        if self.dataset_name == 'refcocog':
            assert self.splitby == 'umd' or self.splitby == 'google'

        dataset_id = '_'.join([self.dataset_name, self.splitby, self.split])

        from refer.grefer import G_REFER
        logger.info('Loading dataset {} ({}-{}) ...'.format(self.dataset_name, self.splitby, self.split))
        logger.info('Refcoco root: {}'.format(self.refer_root))
        timer = Timer()
        self.refer_root = PathManager.get_local_path(self.refer_root)
        # G_REFER prints while indexing; keep the log clean.
        with contextlib.redirect_stdout(io.StringIO()):
            self.refer = G_REFER(data_root=self.refer_root,
                                 dataset=self.dataset_name,
                                 splitBy=self.splitby)
        if timer.seconds() > 1:
            logger.info("Loading {} takes {:.2f} seconds.".format(dataset_id, timer.seconds()))

        self.ref_ids = self.refer.getRefIds(split=self.split)
        self.img_ids = self.refer.getImgIds(self.ref_ids)
        self.refs = self.refer.loadRefs(self.ref_ids)
        imgs = [self.refer.loadImgs(ref['image_id'])[0] for ref in self.refs]
        anns = [self.refer.loadAnns(ref['ann_id']) for ref in self.refs]
        self.imgs_refs_anns = list(zip(imgs, self.refs, anns))

        logger.info("Loaded {} images, {} referring object sets in G_RefCOCO format from {}".format(
            len(self.img_ids), len(self.ref_ids), dataset_id))

        self.dataset_dicts = []

        ann_keys = ["iscrowd", "bbox", "category_id"] + (self.extra_annotation_keys or [])
        ref_keys = ["raw", "sent_id"] + (self.extra_refer_keys or [])

        # Cache annotation dicts so an annotation shared by several refs is built once.
        ann_lib = {}

        num_instances_without_valid_segmentation = 0
        NT_count = 0
        MT_count = 0

        for idx, (img_dict, ref_dict, anno_dicts) in enumerate(self.imgs_refs_anns):
            record = {}
            record['id'] = idx
            record["source"] = 'grefcoco'
            record["file_name"] = os.path.join(self.image_root, img_dict["file_name"])
            record["height"] = img_dict["height"]
            record["width"] = img_dict["width"]
            image_id = record["image_id"] = img_dict["id"]

            assert ref_dict['image_id'] == image_id
            assert ref_dict['split'] == self.split
            if not isinstance(ref_dict['ann_id'], list):
                ref_dict['ann_id'] = [ref_dict['ann_id']]

            # "No-target" ref: the expression matches nothing in the image.
            if None in anno_dicts:
                assert anno_dicts == [None]
                assert ref_dict['ann_id'] == [-1]
                record['empty'] = True
                obj = {key: None for key in ann_keys}
                obj["bbox_mode"] = BoxMode.XYWH_ABS
                obj["empty"] = True
                obj = [obj]

            else:
                record['empty'] = False
                obj = []
                for anno_dict in anno_dicts:
                    ann_id = anno_dict['id']
                    if anno_dict['iscrowd']:
                        continue
                    assert anno_dict["image_id"] == image_id
                    assert ann_id in ref_dict['ann_id']

                    if ann_id in ann_lib:
                        ann = ann_lib[ann_id]
                    else:
                        ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict}
                        ann["bbox_mode"] = BoxMode.XYWH_ABS
                        ann["empty"] = False

                        segm = anno_dict.get("segmentation", None)
                        assert segm
                        if isinstance(segm, dict):
                            if isinstance(segm["counts"], list):
                                # Convert uncompressed RLE to compressed RLE.
                                segm = mask_util.frPyObjects(segm, *segm["size"])
                        else:
                            # Polygons: drop invalid ones (fewer than 3 points).
                            segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
                            if len(segm) == 0:
                                num_instances_without_valid_segmentation += 1
                                continue
                        ann["segmentation"] = segm
                        ann_lib[ann_id] = ann

                    obj.append(ann)

            record["annotations"] = obj

            # Emit one dataset dict per sentence of this ref.
            sents = ref_dict['sentences']
            for sent in sents:
                ref_record = record.copy()
                ref = {key: sent[key] for key in ref_keys if key in sent}
                ref["ref_id"] = ref_dict["ref_id"]
                ref_record["sentence"] = ref
                self.dataset_dicts.append(ref_record)

        if num_instances_without_valid_segmentation > 0:
            logger.warning("Filtered out {} instances without valid segmentation.".format(
                num_instances_without_valid_segmentation))

        self.classes = []
        self.aug = args.aug
        self.bert_type = args.bert_tokenizer

        self.img_sz = args.img_size
        # Side length of one tile when the canvas is split into num_bgs tiles.
        each_img_sz = int(args.img_size / math.sqrt(self.aug.num_bgs))
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        self.resize_bg1 = A.Compose([
            A.Resize(args.img_size, args.img_size, always_apply=True)])

        self.resize_bg4 = A.Compose([
            A.Resize(each_img_sz, each_img_sz, always_apply=True)],
            additional_targets={'image1': 'image', 'image2': 'image', 'image3': 'image',
                                'mask1': 'mask', 'mask2': 'mask', 'mask3': 'mask'})

        self.transforms = A.Compose([
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ])

        all_imgs = self.refer.Imgs
        self.imgs = list(all_imgs[i] for i in self.img_ids)

        self.ref_id2idx = dict(zip(self.ref_ids, range(len(self.ref_ids))))
        self.ref_idx2id = dict(zip(range(len(self.ref_ids)), self.ref_ids))
        self.img2refs = self.refer.imgToRefs

        self.is_train = split == "train"

        # Pre-tokenize every sentence of every ref: token ids padded to
        # max_tokens, with matching attention masks.
        self.input_ids = []
        self.attention_masks = []
        for i, r in enumerate(self.ref_ids):
            ref = self.refer.Refs[r]

            sentences_for_ref = []
            attentions_for_ref = []
            for j, (el, sent_id) in enumerate(zip(ref['sentences'], ref['sent_ids'])):
                sentence_raw = el['raw']
                input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True,
                                                  max_length=self.max_tokens, truncation=True)

                padded_input_ids = [0] * self.max_tokens
                padded_input_ids[:len(input_ids)] = input_ids
                attention_mask = [0] * self.max_tokens
                attention_mask[:len(input_ids)] = [1] * len(input_ids)
                sentences_for_ref.append(padded_input_ids)
                attentions_for_ref.append(attention_mask)

            self.input_ids.append(sentences_for_ref)
            self.attention_masks.append(attentions_for_ref)
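
        # Worked example with hypothetical token ids (max_tokens == 20):
        #   "the left zebra" -> encode(...) == [101, 1996, 2187, 29145, 102]  # [CLS] ... [SEP]
        #   padded_input_ids == [101, 1996, 2187, 29145, 102, 0, 0, ..., 0]   # length 20
        #   attention_mask   == [1,   1,    1,    1,     1,   0, 0, ..., 0]   # 1s mark real tokens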

        if self.aug.blur:
            # Radius-100 Gaussian blur, applied later to non-target tiles.
            self.blur = ImageFilter.GaussianBlur(100)

        # LMDB of retrieval candidates (sentence id -> candidate image ids).
        lmdb_path = '/data2/dataset/RefCOCO/logit_db/grefcoco/grefcoco.lmdb'

        self.lmdb_env = lmdb.open(
            lmdb_path, subdir=False, max_readers=32,
            readonly=True, lock=False,
            readahead=False, meminit=False
        )
        with self.lmdb_env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

        self.epoch = 0
        np.random.seed()

    @staticmethod
    def _merge_masks(x):
        # Union of instance masks: (N, H, W) -> (1, H, W) with values in {0, 1}.
        return x.sum(dim=0, keepdim=True).clamp(max=1)

    def __len__(self):
        return len(self.dataset_dicts)

    def __getitem__(self, index):
        refid = self.ref_idx2id[index]
        with self.lmdb_env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[refid])
        lmdb_dict = loads_pyarrow(byteflow)

        dataset_dict = copy.deepcopy(self.dataset_dicts[index])
        img_id = dataset_dict["image_id"]
        index = dataset_dict["id"]

        # Decide how many background tiles to compose for this sample.
        if self.split == 'train':
            if self.aug.num_bgs == 4:
                aug_prob = self.aug.aug_prob
                if self.epoch < self.aug.retrieval_epoch:
                    num_bgs = np.random.choice([1, 4], p=[1 - aug_prob, aug_prob])
                else:
                    # cosine_annealing is assumed to decay from lrate_max towards 0
                    # over n_epochs; the freed-up probability mass shifts from
                    # random mosaics to retrieval-based mosaics.
                    rand_prob = cosine_annealing(epoch=self.epoch - self.aug.retrieval_epoch,
                                                 n_epochs=self.args.epochs - self.aug.retrieval_epoch,
                                                 n_cycles=1, lrate_max=aug_prob)
                    retr_prob = aug_prob - rand_prob
                    choice = np.random.choice(['one', 'random', 'retrieval'],
                                              p=[1 - aug_prob, rand_prob, retr_prob])
                    if choice == 'one':
                        num_bgs = 1
                    else:
                        num_bgs = 4
            else:
                num_bgs = 1
        else:
            num_bgs = 1

        target_sent_idx = np.random.choice(len(self.input_ids[index]))
        ref_id = self.ref_idx2id[index]
        insert_idx = np.random.choice(range(num_bgs))

        if num_bgs == 1:
            # Single tile: only the target ref's own image is used.
            ref_ids = [ref_id]  # the pre-retrieval loading loop below iterates ref_ids
            sent_idxs = []
            sents = np.array([], dtype='str')
            img_ids = [self.refer.Refs[ref_id]['image_id']]
        else:
            if self.epoch >= self.aug.retrieval_epoch:
                # Retrieval mode: distractor images pre-computed for this sentence.
                sent_id = list(lmdb_dict.keys())[target_sent_idx]
                img_ids = list(np.random.choice(lmdb_dict[sent_id], size=num_bgs - 1, replace=True))
                img_ids = np.insert(img_ids, insert_idx, self.refer.Refs[ref_id]['image_id'])

            # Random distractor refs (used before the retrieval epoch).
            ref_ids = list(np.random.choice(self.ref_ids, size=num_bgs - 1, replace=False))
            sent_idxs = [np.random.choice(len(self.refer.Refs[r]['sentences'])) for r in ref_ids]
            sents = np.array([self.refer.Refs[r]['sentences'][sent_idxs[i]]['raw']
                              for i, r in enumerate(ref_ids)], dtype='str')

            ref_ids = np.insert(ref_ids, insert_idx, self.ref_idx2id[index]).astype(int)
            sents = np.insert(sents, insert_idx,
                              self.refer.Refs[ref_ids[insert_idx]]['sentences'][target_sent_idx]['raw'])
            sent_idxs = np.insert(sent_idxs, insert_idx, target_sent_idx).astype(int)

        if self.aug.tgt_selection == 'random':
            target_idx = np.random.choice(range(num_bgs))
            target_ref_idx = self.ref_id2idx[ref_ids[target_idx]]
            target_sent_idx = int(np.random.choice(len(self.input_ids[target_ref_idx])))
        elif self.aug.tgt_selection == 'longest':
            # The tile whose sentence is longest becomes the target.
            target_idx = np.argmax(list(map(len, sents)))
            target_sent_idx = sent_idxs[target_idx]
        elif self.aug.tgt_selection == 'fixed':
            target_idx = insert_idx

        target_ref_id = self.ref_idx2id[index]

        # Load the image and reference mask for every tile.
        imgs, masks = [], []
        if self.epoch >= self.aug.retrieval_epoch:
            for img_id in img_ids:
                img_info = self.refer.Imgs[img_id]
                img_path = os.path.join(self.refer.IMAGE_DIR, img_info['file_name'])
                img = Image.open(img_path).convert("RGB")
                imgs.append(np.array(img))
                ref = self.refer.imgToRefs[img_id][0]
                mask = self.refer.getMaskByRef(ref, ref['ref_id'], self.merge)['mask']
                masks.append(mask)
        else:
            for ref_id in ref_ids:
                img_id = self.refer.getImgIds([ref_id])[0]
                img_info = self.refer.Imgs[img_id]
                img_path = os.path.join(self.refer.IMAGE_DIR, img_info['file_name'])
                img = Image.open(img_path).convert("RGB")
                imgs.append(np.array(img))
                ref = self.refer.loadRefs(ref_ids=[ref_id])
                mask = self.refer.getMaskByRef(ref[0], ref_id, self.merge)['mask']
                masks.append(mask)

        if num_bgs == 1:
            resized = self.resize_bg1(image=imgs[0], mask=masks[0])
            imgs, masks = [resized['image']], [resized['mask']]
            img = imgs[0]
        else:
            # Pick the cross point where the four tiles meet.
            if self.aug.move_crs_pnt:
                crs_y = np.random.randint(0, self.img_sz + 1)
                crs_x = np.random.randint(0, self.img_sz + 1)
            else:
                # Fixed center; the 480 here assumes img_size == 480.
                crs_y = 480 // 2
                crs_x = 480 // 2

            # Resize each tile to its quadrant of the canvas. A quadrant collapses
            # to an empty array when the cross point lies on the canvas border.
            quad_sizes = [(crs_y, crs_x),
                          (crs_y, self.img_sz - crs_x),
                          (self.img_sz - crs_y, crs_x),
                          (self.img_sz - crs_y, self.img_sz - crs_x)]
            resized_imgs, resized_masks = [], []
            for (h, w), im, m in zip(quad_sizes, imgs, masks):
                if h == 0 or w == 0:
                    resized_imgs.append(np.zeros([h, w, 3]))
                    resized_masks.append(np.zeros([h, w]))
                else:
                    temp = A.Compose([A.Resize(h, w, always_apply=True)])(image=im, mask=m)
                    resized_imgs.append(temp['image'])
                    resized_masks.append(temp['mask'])
            imgs, masks = resized_imgs, resized_masks

        if self.aug.blur:
            # Blur every tile except the one holding the target.
            imgs = [np.asarray(Image.fromarray(x).filter(self.blur)) if i != insert_idx else x
                    for i, x in enumerate(imgs)]

        # Stitch the tiles row by row into a single mosaic image.
        num_rows = num_cols = int(math.sqrt(num_bgs))
        idxs = [(i * num_cols, i * num_cols + num_cols) for i in range(num_rows)]
        img = [np.concatenate(imgs[_from:_to], axis=1) for (_from, _to) in idxs]
        img = np.concatenate(img, axis=0).astype(np.uint8)

        # Full-canvas mask per tile: the tile's mask padded with zeros elsewhere.
        masks_arr = []
        for bg_idx in range(num_bgs):
            mask = masks[bg_idx]
            temp = [mask if idx == bg_idx else np.zeros_like(masks[idx]) for idx in range(num_bgs)]
            mask = [np.concatenate(temp[_from:_to], axis=1) for (_from, _to) in idxs]
            mask = np.concatenate(mask, axis=0).astype(np.int32)
            masks_arr.append(mask)
        masks = masks_arr

        mask = masks[target_idx]
        mask = mask.astype(np.uint8)
        mask[mask > 0] = 1

        item = self.transforms(image=img, mask=mask)
        img_tensor = item['image']
        target = item['mask'].long()

        target_ref_idx = self.ref_id2idx[target_ref_id]

        padded_input_ids = self.input_ids[target_ref_idx][target_sent_idx]
        tensor_embeddings = torch.tensor(padded_input_ids).unsqueeze(0)
        attention_mask = self.attention_masks[target_ref_idx][target_sent_idx]
        attention_mask = torch.tensor(attention_mask).unsqueeze(0)

        empty = dataset_dict.get("empty", False)
        dataset_dict["empty"] = empty
        dataset_dict['image'] = img_tensor
        dataset_dict['gt_masks'] = target.unsqueeze(0)
        dataset_dict['lang_tokens'] = tensor_embeddings
        dataset_dict['lang_mask'] = attention_mask
        dataset_dict["gt_mask_merged"] = target.unsqueeze(0)

        item = {
            'image': img_tensor.float(),
            'seg_target': target.long(),
            'sentence': tensor_embeddings,
            'attn_mask': attention_mask,
        }
        return item
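

# Typical consumption sketch. `args` must provide bert_tokenizer, img_size,
# epochs and an `aug` namespace with num_bgs/aug_prob/retrieval_epoch/blur/
# move_crs_pnt/tgt_selection (inferred from the attribute accesses above);
# the batch size and worker count below are placeholders.
#
#   dataset = GReferDataset(args, refer_root, 'grefcoco', 'unc', 'train', image_root)
#   loader = data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
#   for batch in loader:
#       batch['image']       # (B, 3, img_size, img_size) float tensor
#       batch['seg_target']  # (B, img_size, img_size) long tensor
#       batch['sentence']    # (B, 1, 20) padded BERT token ids
#       batch['attn_mask']   # (B, 1, 20) attention mask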


if __name__ == "__main__":
    """
    Smoke-test the gRefCOCO dataset loader.

    Requires the local dataset paths below, plus the hard-coded LMDB opened in
    __init__. GReferDataset also needs a fully populated `args` namespace, so
    the values used here are placeholders for one setup, not defaults.
    """
    from detectron2.utils.logger import setup_logger
    from types import SimpleNamespace

    REFCOCO_PATH = '/data2/projects/donghwa/RIS/ReLA/datasets'
    COCO_TRAIN_2014_IMAGE_ROOT = '/data2/projects/donghwa/RIS/ReLA/datasets/images'
    REFCOCO_DATASET = 'grefcoco'
    REFCOCO_SPLITBY = 'unc'
    REFCOCO_SPLIT = 'train'

    logger = setup_logger(name=__name__)
    args = SimpleNamespace(
        bert_tokenizer='bert-base-uncased', img_size=480, epochs=50,
        aug=SimpleNamespace(num_bgs=1, aug_prob=0.5, retrieval_epoch=0, blur=False,
                            move_crs_pnt=False, tgt_selection='fixed'),
    )
    dataset = GReferDataset(args, REFCOCO_PATH, REFCOCO_DATASET, REFCOCO_SPLITBY,
                            REFCOCO_SPLIT, COCO_TRAIN_2014_IMAGE_ROOT)
    logger.info("Done loading {} samples.".format(len(dataset)))