# VRIS_vip / LAVT-RIS / donghwa / dataset_grefer_mosaic_retrieval.py
import contextlib
import io
import logging
import numpy as np
import os
import random
import copy
import pycocotools.mask as mask_util
from fvcore.common.timer import Timer
from PIL import Image
import torch.utils.data as data
from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
from detectron2.utils.file_io import PathManager
import time
import torch
import math
from detectron2.config import configurable
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
# from transformers import BertTokenizer
from bert.tokenization_bert import BertTokenizer
from pycocotools import mask as coco_mask
from data.utils import convert_coco_poly_to_mask, build_transform_train, build_transform_test
from .utils import cosine_annealing
"""
This file contains functions to parse RefCOCO-format annotations into dicts in "Detectron2 format".
"""
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageDraw, ImageFilter
import lmdb
import pyarrow as pa
def loads_pyarrow(buf):
    # Deserialize a buffer written with the legacy pyarrow serialization API.
    return pa.deserialize(buf)
logger = logging.getLogger(__name__)
__all__ = ["load_refcoco_json"]
class GReferDataset(data.Dataset):
def __init__(self,
args, refer_root, dataset_name, splitby, split, image_root,
img_format="RGB", merge=True,
extra_annotation_keys=None, extra_refer_keys=None):
self.refer_root = refer_root
self.dataset_name = dataset_name
self.splitby = splitby
self.split = split
self.image_root = image_root
self.extra_annotation_keys = extra_annotation_keys
self.extra_refer_keys = extra_refer_keys
self.img_format = img_format
self.merge = merge
        self.args = args  # kept so that __getitem__ can read args.epochs for the annealing schedule
        self.max_tokens = 20
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)
if split == "train":
self.tfm_gens = build_transform_train(args)
elif split in ["val", "test", "testA", "testB"]:
self.tfm_gens = build_transform_test(args)
if self.dataset_name == 'refcocop':
self.dataset_name = 'refcoco+'
        if self.dataset_name == 'refcoco' or self.dataset_name == 'refcoco+':
            self.splitby = 'unc'  # refcoco / refcoco+ only ship the 'unc' split
if self.dataset_name == 'refcocog':
assert self.splitby == 'umd' or self.splitby == 'google'
dataset_id = '_'.join([self.dataset_name, self.splitby, self.split])
from refer.grefer import G_REFER
logger.info('Loading dataset {} ({}-{}) ...'.format(self.dataset_name, self.splitby, self.split))
logger.info('Refcoco root: {}'.format(self.refer_root))
timer = Timer()
self.refer_root = PathManager.get_local_path(self.refer_root)
with contextlib.redirect_stdout(io.StringIO()):
self.refer = G_REFER(data_root=self.refer_root,
dataset=self.dataset_name,
splitBy=self.splitby)
if timer.seconds() > 1:
logger.info("Loading {} takes {:.2f} seconds.".format(dataset_id, timer.seconds()))
self.ref_ids = self.refer.getRefIds(split=self.split)
self.img_ids = self.refer.getImgIds(self.ref_ids)
self.refs = self.refer.loadRefs(self.ref_ids)
imgs = [self.refer.loadImgs(ref['image_id'])[0] for ref in self.refs]
anns = [self.refer.loadAnns(ref['ann_id']) for ref in self.refs]
self.imgs_refs_anns = list(zip(imgs, self.refs, anns))
logger.info("Loaded {} images, {} referring object sets in G_RefCOCO format from {}".format(len(self.img_ids), len(self.ref_ids), dataset_id))
self.dataset_dicts = []
ann_keys = ["iscrowd", "bbox", "category_id"] + (self.extra_annotation_keys or [])
ref_keys = ["raw", "sent_id"] + (self.extra_refer_keys or [])
        ann_lib = {}
        NT_count = 0
        MT_count = 0
        num_instances_without_valid_segmentation = 0
for idx, (img_dict, ref_dict, anno_dicts) in enumerate(self.imgs_refs_anns):
record = {}
record['id'] = idx
record["source"] = 'grefcoco'
record["file_name"] = os.path.join(self.image_root, img_dict["file_name"])
record["height"] = img_dict["height"]
record["width"] = img_dict["width"]
image_id = record["image_id"] = img_dict["id"]
# Check that information of image, ann and ref match each other
# This fails only when the data parsing logic or the annotation file is buggy.
assert ref_dict['image_id'] == image_id
assert ref_dict['split'] == self.split
if not isinstance(ref_dict['ann_id'], list):
ref_dict['ann_id'] = [ref_dict['ann_id']]
# No target samples
if None in anno_dicts:
assert anno_dicts == [None]
assert ref_dict['ann_id'] == [-1]
record['empty'] = True
                obj = {key: None for key in ann_keys}
obj["bbox_mode"] = BoxMode.XYWH_ABS
obj["empty"] = True
obj = [obj]
# Multi target samples
else:
record['empty'] = False
obj = []
for anno_dict in anno_dicts:
ann_id = anno_dict['id']
if anno_dict['iscrowd']:
continue
assert anno_dict["image_id"] == image_id
assert ann_id in ref_dict['ann_id']
if ann_id in ann_lib:
ann = ann_lib[ann_id]
else:
ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict}
ann["bbox_mode"] = BoxMode.XYWH_ABS
ann["empty"] = False
segm = anno_dict.get("segmentation", None)
assert segm # either list[list[float]] or dict(RLE)
if isinstance(segm, dict):
if isinstance(segm["counts"], list):
# convert to compressed RLE
segm = mask_util.frPyObjects(segm, *segm["size"])
else:
# filter out invalid polygons (< 3 points)
segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
if len(segm) == 0:
num_instances_without_valid_segmentation += 1
continue # ignore this instance
ann["segmentation"] = segm
ann_lib[ann_id] = ann
obj.append(ann)
record["annotations"] = obj
# Process referring expressions
sents = ref_dict['sentences']
for sent in sents:
ref_record = record.copy()
ref = {key: sent[key] for key in ref_keys if key in sent}
ref["ref_id"] = ref_dict["ref_id"]
ref_record["sentence"] = ref
self.dataset_dicts.append(ref_record)
# if ref_record['empty']:
# NT_count += 1
# else:
# MT_count += 1
# logger.info("NT samples: %d, MT samples: %d", NT_count, MT_count)
# Debug mode
# return self.dataset_dicts[:100]
# grefcoco
self.classes = []
self.aug = args.aug
self.bert_type = args.bert_tokenizer
self.img_sz = args.img_size
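        # Each mosaic tile is img_size / sqrt(num_bgs) per side, e.g. 480 -> 240x240 tiles for a 2x2 mosaic.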
each_img_sz = int(args.img_size/math.sqrt(self.aug.num_bgs))
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
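        # resize_bg1 resizes a single full-frame image/mask; resize_bg4 resizes up to four
        # tile-sized image/mask pairs at once via albumentations' additional_targets.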
self.resize_bg1 = A.Compose([
A.Resize(args.img_size, args.img_size, always_apply=True)])
self.resize_bg4 = A.Compose([
A.Resize(each_img_sz, each_img_sz, always_apply=True)],
additional_targets={'image1': 'image', 'image2': 'image', 'image3': 'image',
'mask1': 'mask', 'mask2': 'mask', 'mask3': 'mask',})
        self.transforms = A.Compose([
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ])
# ref_ids = self.refer.getRefIds(split=self.split)
# img_ids = self.refer.getImgIds(ref_ids)
all_imgs = self.refer.Imgs
self.imgs = list(all_imgs[i] for i in self.img_ids)
# self.ref_ids = ref_ids#[:500]
self.ref_id2idx = dict(zip(self.ref_ids, range(len(self.ref_ids))))
self.ref_idx2id = dict(zip(range(len(self.ref_ids)), self.ref_ids))
self.img2refs = self.refer.imgToRefs
# self.tokenizer.add_special_tokens({'additional_special_tokens': task_tokens})
# self.tokenizer.add_tokens(position_tokens)
        # When testing, evaluate every sentence of an object; during training/validation,
        # randomly sample one sentence per object for efficiency.
self.max_tokens = 20
self.is_train = True if split == "train" else False
self.input_ids = []
self.attention_masks = []
for i, r in enumerate(self.ref_ids):
ref = self.refer.Refs[r]
sentences_for_ref = []
attentions_for_ref = []
for j, (el, sent_id) in enumerate(zip(ref['sentences'], ref['sent_ids'])):
sentence_raw = el['raw']
input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True, max_length=self.max_tokens, truncation=True)
#input_ids = input_ids[:self.max_tokens]
padded_input_ids = [0] * self.max_tokens
padded_input_ids[:len(input_ids)] = input_ids
attention_mask = [0] * self.max_tokens
attention_mask[:len(input_ids)] = [1]*len(input_ids)
sentences_for_ref.append(padded_input_ids)
attentions_for_ref.append(attention_mask)
self.input_ids.append(sentences_for_ref)
self.attention_masks.append(attentions_for_ref)
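        # Padding example: a sentence that tokenizes to 7 ids is stored as those 7 ids followed
        # by 13 zeros (max_tokens=20), with attention mask [1]*7 + [0]*13.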
if self.aug.blur:
self.blur = ImageFilter.GaussianBlur(100)
        # Load LMDB retrieval data (tentative: path is hard-coded)
        lmdb_path = '/data2/dataset/RefCOCO/logit_db/grefcoco/grefcoco.lmdb'
# self.lmdb_dict = read_from_lmdb(lmdb_dir)
self.lmdb_env = lmdb.open(
lmdb_path, subdir=False, max_readers=32,
readonly=True, lock=False,
readahead=False, meminit=False
)
with self.lmdb_env.begin(write=False) as txn:
self.length = loads_pyarrow(txn.get(b'__len__'))
self.keys = loads_pyarrow(txn.get(b'__keys__'))
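        # Assumed LMDB layout: '__len__' / '__keys__' store the record count and per-ref keys,
        # and each record deserializes to a dict mapping sent_id -> candidate image ids used
        # for the retrieval-based mosaic below.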
self.epoch = 0
np.random.seed()
@staticmethod
    def _merge_masks(x):
        # Collapse a stack of instance masks (N, H, W) into one binary mask (1, H, W).
        return x.sum(dim=0, keepdim=True).clamp(max=1)
def __len__(self):
return len(self.dataset_dicts)
def __getitem__(self, index):
refid = self.ref_idx2id[index]
with self.lmdb_env.begin(write=False) as txn:
byteflow = txn.get(self.keys[refid])
lmdb_dict = loads_pyarrow(byteflow)
dataset_dict = copy.deepcopy(self.dataset_dicts[index])
        img_id = dataset_dict["image_id"]
        index = dataset_dict["id"]  # per-ref index (aligned with self.ref_ids), replacing the per-sentence index
# decide mosaic size
if self.split=='train':
if self.aug.num_bgs==4:
                aug_prob = self.aug.aug_prob  # e.g. 0.6
                if self.epoch < self.aug.retrieval_epoch:
                    # warm-up epochs: choose between a plain image and a random 2x2 mosaic
                    num_bgs = np.random.choice([1, 4], p=[1 - aug_prob, aug_prob])
                else:
                    # after warm-up, split aug_prob between random and retrieval mosaics;
                    # rand_prob follows a cosine schedule capped at aug_prob and retr_prob takes the remainder
                    rand_prob = cosine_annealing(epoch=self.epoch - self.aug.retrieval_epoch,
                                                 n_epochs=self.args.epochs - self.aug.retrieval_epoch,
                                                 n_cycles=1, lrate_max=aug_prob)
                    retr_prob = aug_prob - rand_prob
                    choice = np.random.choice(['one', 'random', 'retrieval'], p=[1 - aug_prob, rand_prob, retr_prob])
if choice == 'one':
num_bgs = 1
else :
num_bgs = 4
else:
num_bgs = 1
else:
num_bgs = 1
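        # `cosine_annealing` (imported from .utils) is assumed to follow the usual cosine schedule,
        # roughly lrate_max/2 * (cos(pi * epoch / n_epochs) + 1) for a single cycle, i.e. it starts
        # near lrate_max and decays towards 0 over the remaining epochs.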
target_sent_idx = np.random.choice(len(self.input_ids[index]))
ref_id = self.ref_idx2id[index]
insert_idx = np.random.choice(range(num_bgs))
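        # insert_idx is the mosaic slot reserved for this sample's own (referred) image. For mosaics,
        # the remaining slots are filled either with LMDB-retrieved images (after retrieval_epoch) or
        # with randomly sampled refs, and the anchor ref/sentence is spliced back in at insert_idx.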
if num_bgs==1:
ref_ids = []
sent_idxs = []
sents = np.array([], dtype='str')
img_ids = [self.refer.Refs[ref_id]['image_id']]
else:
if self.epoch >= self.aug.retrieval_epoch :
sent_id = list(lmdb_dict.keys())[target_sent_idx]
img_ids = list(np.random.choice(lmdb_dict[sent_id], size=num_bgs-1, replace=True))
img_ids = np.insert(img_ids, insert_idx, self.refer.Refs[ref_id]['image_id'])
ref_ids = list(np.random.choice(self.ref_ids, size=num_bgs-1, replace=False))
sent_idxs = [np.random.choice(len(self.refer.Refs[r]['sentences'])) for r in ref_ids]
sents = np.array([self.refer.Refs[r]['sentences'][sent_idxs[i]]['raw'] for i, r in enumerate(ref_ids)], dtype='str')
ref_ids = np.insert(ref_ids, insert_idx, self.ref_idx2id[index]).astype(int)
sents = np.insert(sents, insert_idx,
self.refer.Refs[ref_ids[insert_idx]]['sentences'][target_sent_idx]['raw'])
sent_idxs = np.insert(sent_idxs, insert_idx, target_sent_idx).astype(int)
# pick a target origin
if self.aug.tgt_selection == 'random':
target_idx = np.random.choice(range(num_bgs))
target_ref_idx = self.ref_id2idx[ref_ids[target_idx]]
target_sent_idx = int(np.random.choice(len(self.input_ids[target_ref_idx])))
elif self.aug.tgt_selection == 'longest':
target_idx = np.argmax(list(map(len, sents)))
target_sent_idx = sent_idxs[target_idx]
elif self.aug.tgt_selection == 'fixed':
target_idx = insert_idx
# target_ref_id = ref_ids[target_idx]
target_ref_id = self.ref_idx2id[index]
# load items
imgs, masks = [], []
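        # Load each tile's image and referring mask; the retrieval branch indexes by image id,
        # the random branch indexes by ref id.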
if self.epoch >= self.aug.retrieval_epoch :
for img_id in img_ids:
img_info = self.refer.Imgs[img_id]
img_path = os.path.join(self.refer.IMAGE_DIR, img_info['file_name'])
img = Image.open(img_path).convert("RGB")
imgs.append(np.array(img))
ref = self.refer.imgToRefs[img_id][0]
# if self.dataset_name in ['refcoco', 'refcoco+', 'refcocog']:
# mask = np.array(self.refer.getMask(ref)['mask'])
# elif self.dataset_name in ['grefcoco'] :
mask = self.refer.getMaskByRef(ref, ref['ref_id'], self.merge)['mask']
masks.append(mask)
else :
for ref_id in ref_ids:
img_id = self.refer.getImgIds([ref_id])[0]
img_info = self.refer.Imgs[img_id]
img_path = os.path.join(self.refer.IMAGE_DIR, img_info['file_name'])
img = Image.open(img_path).convert("RGB")
imgs.append(np.array(img))
ref = self.refer.loadRefs(ref_ids=[ref_id])
# if self.dataset_name in ['refcoco', 'refcoco+', 'refcocog']:
# mask = np.array(self.refer.getMask(ref[0])['mask'])
# elif self.dataset_name in ['grefcoco'] :
mask = self.refer.getMaskByRef(ref[0], ref_id, self.merge)['mask']
masks.append(mask)
# image resize and apply 4in1 augmentation
if num_bgs==1:
resized = self.resize_bg1(image=imgs[0], mask=masks[0])
imgs, masks = [resized['image']], [resized['mask']]
img = imgs[0]
else:
if self.aug.move_crs_pnt:
crs_y = np.random.randint(0, self.img_sz+1)
crs_x = np.random.randint(0, self.img_sz+1)
            else:
                crs_y = 480 // 2  # fixed central crossing point (assumes an image size of 480)
                crs_x = 480 // 2
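            # (crs_x, crs_y) is the mosaic crossing point: the four tiles are resized to fill the
            # top-left, top-right, bottom-left and bottom-right regions around it, and a zero-width
            # or zero-height region gets an empty array so the concatenation below still lines up.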
if crs_y==0 or crs_x==0:
img1 = np.zeros([0,crs_x,3]) if crs_y==0 else np.zeros([crs_y,0,3])
mask1 = np.zeros([0,crs_x]) if crs_y==0 else np.zeros([crs_y,0])
else:
resize_bg1 = A.Compose([A.Resize(crs_y, crs_x, always_apply=True)])
temp = resize_bg1(image=imgs[0], mask=masks[0])
img1 = temp['image']
mask1 = temp['mask']
if crs_y==0 or crs_x==self.img_sz:
img2 = np.zeros([0,self.img_sz-crs_x,3]) if crs_y==0 \
else np.zeros([crs_y,0,3])
mask2 = np.zeros([0,self.img_sz-crs_x]) if crs_y==0 \
else np.zeros([crs_y,0])
else:
resize_bg2 = A.Compose([
A.Resize(crs_y, self.img_sz-crs_x, always_apply=True)])
temp = resize_bg2(image=imgs[1], mask=masks[1])
img2 = temp['image']
mask2 = temp['mask']
if crs_y==self.img_sz or crs_x==0:
img3 = np.zeros([0,crs_x,3]) if crs_y==self.img_sz \
else np.zeros([self.img_sz-crs_y,0,3])
mask3 = np.zeros([0,crs_x]) if crs_y==self.img_sz \
else np.zeros([self.img_sz-crs_y,0])
else:
resize_bg3 = A.Compose([
A.Resize(self.img_sz-crs_y, crs_x, always_apply=True)])
temp = resize_bg3(image=imgs[2], mask=masks[2])
img3 = temp['image']
mask3 = temp['mask']
if crs_y==self.img_sz or crs_x==self.img_sz:
img4 = np.zeros([0,self.img_sz-crs_x,3]) if crs_y==self.img_sz \
else np.zeros([self.img_sz-crs_y,0,3])
mask4 = np.zeros([0,self.img_sz-crs_x]) if crs_y==self.img_sz \
else np.zeros([self.img_sz-crs_y,0])
else:
resize_bg4 = A.Compose([
A.Resize(self.img_sz-crs_y,
self.img_sz-crs_x, always_apply=True)])
temp = resize_bg4(image=imgs[3], mask=masks[3])
img4 = temp['image']
mask4 = temp['mask']
imgs = [img1, img2, img3, img4]
masks = [mask1, mask2, mask3, mask4]
# scale effect ablation
if self.aug.blur:
imgs = [np.asarray(Image.fromarray(x).filter(self.blur)) if i!=insert_idx else x for i, x in enumerate(imgs)]
num_rows = num_cols = int(math.sqrt(num_bgs))
idxs = [(i*num_cols,i*num_cols+num_cols) for i in range(num_rows)]
img = [np.concatenate(imgs[_from:_to], axis=1) for (_from, _to) in idxs]
img = np.concatenate(img, axis=0).astype(np.uint8)
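            # Tiles are concatenated row by row, then rows are stacked vertically into the mosaic image.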
masks_arr = []
for bg_idx in range(num_bgs):
mask = masks[bg_idx]
temp = [mask if idx==bg_idx else np.zeros_like(masks[idx]) for idx in range(num_bgs)]
mask = [np.concatenate(temp[_from:_to], axis=1) for (_from, _to) in idxs]
mask = np.concatenate(mask, axis=0).astype(np.int32)
masks_arr.append(mask)
masks = masks_arr
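            # Each entry of `masks` is now a full-canvas mask that is zero everywhere outside its own tile.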
mask = masks[target_idx]
mask = mask.astype(np.uint8)
mask[mask>0] = 1
item = self.transforms(image=img, mask=mask)
img_tensor = item['image']
target = item['mask'].long()
target_ref_idx = self.ref_id2idx[target_ref_id]
# if self.is_train:
# embedding = []
# att = []
# for s in range(len(self.input_ids[target_ref_idx])):
# padded_input_ids = self.input_ids[target_ref_idx][s]
# tensor_embeddings = torch.tensor(padded_input_ids).unsqueeze(0)
# attention_mask = self.attention_masks[target_ref_idx][s]
# attention_mask = torch.tensor(attention_mask).unsqueeze(0)
# embedding.append(tensor_embeddings.unsqueeze(-1))
# att.append(attention_mask.unsqueeze(-1))
# tensor_embeddings = torch.cat(embedding, dim=-1)
# attention_mask = torch.cat(att, dim=-1)
# else:
padded_input_ids = self.input_ids[target_ref_idx][target_sent_idx]
tensor_embeddings = torch.tensor(padded_input_ids).unsqueeze(0)
attention_mask = self.attention_masks[target_ref_idx][target_sent_idx]
attention_mask = torch.tensor(attention_mask).unsqueeze(0)
empty = dataset_dict.get("empty", False)
dataset_dict["empty"] = empty
dataset_dict['image'] = img_tensor
dataset_dict['gt_masks'] = target.unsqueeze(0)
dataset_dict['lang_tokens'] = tensor_embeddings
dataset_dict['lang_mask'] = attention_mask
# dataset_dict["gt_mask_merged"] = self._merge_masks(target) if self.merge else None
dataset_dict["gt_mask_merged"] = target.unsqueeze(0)
item = {
'image': img_tensor.float(),
'seg_target': target.long(),
'sentence': tensor_embeddings,
'attn_mask': attention_mask,
}
return item
if __name__ == "__main__":
"""
Test the COCO json dataset loader.
Usage:
python -m detectron2.data.datasets.coco \
path/to/json path/to/image_root dataset_name
"dataset_name" can be "coco_2014_minival_100", or other
pre-registered ones
"""
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
import detectron2.data.datasets # noqa # add pre-defined metadata
import sys
REFCOCO_PATH = '/data2/projects/donghwa/RIS/ReLA/datasets'
COCO_TRAIN_2014_IMAGE_ROOT = '/data2/projects/donghwa/RIS/ReLA/datasets/images'
REFCOCO_DATASET = 'grefcoco'
REFCOCO_SPLITBY = 'unc'
REFCOCO_SPLIT = 'train'
logger = setup_logger(name=__name__)
    # NOTE: load_grefcoco_json is not defined in this module; it is presumably provided by the
    # ReLA / detectron2-style grefcoco loader that this file was adapted from.
    dicts = load_grefcoco_json(REFCOCO_PATH, REFCOCO_DATASET, REFCOCO_SPLITBY, REFCOCO_SPLIT, COCO_TRAIN_2014_IMAGE_ROOT)
logger.info("Done loading {} samples.".format(len(dicts)))