import contextlib
import copy
import io
import logging
import os

import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.utils.data as data
from fvcore.common.timer import Timer

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager

# from transformers import BertTokenizer
from bert.tokenization_bert import BertTokenizer
from data.utils import convert_coco_poly_to_mask, build_transform_train, build_transform_test
"""
This file contains functions to parse RefCOCO-format annotations into dicts in "Detectron2 format".
"""
logger = logging.getLogger(__name__)
__all__ = ["load_refcoco_json"]
class GReferDataset(data.Dataset):

    def __init__(self,
                 args, refer_root, dataset_name, splitby, split, image_root,
                 img_format="RGB", merge=True,
                 extra_annotation_keys=None, extra_refer_keys=None):

        self.refer_root = refer_root
        self.dataset_name = dataset_name
        self.splitby = splitby
        self.split = split
        self.image_root = image_root
        self.extra_annotation_keys = extra_annotation_keys
        self.extra_refer_keys = extra_refer_keys
        self.img_format = img_format
        self.merge = merge
        self.max_tokens = 20

        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

        if split == "train":
            self.tfm_gens = build_transform_train(args)
        elif split in ["val", "test", "testA", "testB"]:
            self.tfm_gens = build_transform_test(args)
        if self.dataset_name == 'refcocop':
            self.dataset_name = 'refcoco+'
        if self.dataset_name in ('refcoco', 'refcoco+'):
            self.splitby = 'unc'  # refcoco/refcoco+ only provide the 'unc' split
        if self.dataset_name == 'refcocog':
            assert self.splitby in ('umd', 'google')

        dataset_id = '_'.join([self.dataset_name, self.splitby, self.split])
        from refer.grefer import G_REFER

        logger.info('Loading dataset {} ({}-{}) ...'.format(self.dataset_name, self.splitby, self.split))
        logger.info('Refcoco root: {}'.format(self.refer_root))

        timer = Timer()
        self.refer_root = PathManager.get_local_path(self.refer_root)
        with contextlib.redirect_stdout(io.StringIO()):
            refer_api = G_REFER(data_root=self.refer_root,
                                dataset=self.dataset_name,
                                splitBy=self.splitby)
        if timer.seconds() > 1:
            logger.info("Loading {} takes {:.2f} seconds.".format(dataset_id, timer.seconds()))

        self.ref_ids = refer_api.getRefIds(split=self.split)
        self.img_ids = refer_api.getImgIds(self.ref_ids)
        self.refs = refer_api.loadRefs(self.ref_ids)
        imgs = [refer_api.loadImgs(ref['image_id'])[0] for ref in self.refs]
        anns = [refer_api.loadAnns(ref['ann_id']) for ref in self.refs]
        self.imgs_refs_anns = list(zip(imgs, self.refs, anns))

        logger.info("Loaded {} images, {} referring object sets in G_RefCOCO format from {}".format(
            len(self.img_ids), len(self.ref_ids), dataset_id))
        self.dataset_dicts = []

        ann_keys = ["iscrowd", "bbox", "category_id"] + (self.extra_annotation_keys or [])
        ref_keys = ["raw", "sent_id"] + (self.extra_refer_keys or [])

        ann_lib = {}
        num_instances_without_valid_segmentation = 0
        NT_count = 0
        MT_count = 0
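
        # Build one record per (image, ref, annotations) triple. Parsed
        # annotations are cached in `ann_lib` by ann_id, so an object shared
        # by several referring expressions is converted only once.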
        for idx, (img_dict, ref_dict, anno_dicts) in enumerate(self.imgs_refs_anns):
            record = {}
            record['id'] = idx
            record["source"] = 'grefcoco'
            record["file_name"] = os.path.join(self.image_root, img_dict["file_name"])
            record["height"] = img_dict["height"]
            record["width"] = img_dict["width"]
            image_id = record["image_id"] = img_dict["id"]

            # Check that the image, ann, and ref information match each other.
            # This fails only when the data parsing logic or the annotation file is buggy.
            assert ref_dict['image_id'] == image_id
            assert ref_dict['split'] == self.split
            if not isinstance(ref_dict['ann_id'], list):
                ref_dict['ann_id'] = [ref_dict['ann_id']]

            # No-target samples
            if None in anno_dicts:
                assert anno_dicts == [None]
                assert ref_dict['ann_id'] == [-1]
                record['empty'] = True
                obj = {key: None for key in ann_keys}
                obj["bbox_mode"] = BoxMode.XYWH_ABS
                obj["empty"] = True
                obj = [obj]
            # (Multi-)target samples
            else:
                record['empty'] = False
                obj = []
                for anno_dict in anno_dicts:
                    ann_id = anno_dict['id']
                    if anno_dict['iscrowd']:
                        continue
                    assert anno_dict["image_id"] == image_id
                    assert ann_id in ref_dict['ann_id']

                    if ann_id in ann_lib:
                        ann = ann_lib[ann_id]
                    else:
                        ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict}
                        ann["bbox_mode"] = BoxMode.XYWH_ABS
                        ann["empty"] = False

                        segm = anno_dict.get("segmentation", None)
                        assert segm  # either list[list[float]] or dict(RLE)
                        if isinstance(segm, dict):
                            if isinstance(segm["counts"], list):
                                # convert to compressed RLE
                                segm = mask_util.frPyObjects(segm, *segm["size"])
                        else:
                            # filter out invalid polygons (< 3 points)
                            segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
                            if len(segm) == 0:
                                num_instances_without_valid_segmentation += 1
                                continue  # ignore this instance
                        ann["segmentation"] = segm
                        ann_lib[ann_id] = ann

                    obj.append(ann)

            record["annotations"] = obj
            # Process referring expressions
            sents = ref_dict['sentences']
            for sent in sents:
                ref_record = record.copy()
                ref = {key: sent[key] for key in ref_keys if key in sent}
                ref["ref_id"] = ref_dict["ref_id"]
                ref_record["sentence"] = ref
                self.dataset_dicts.append(ref_record)
                # if ref_record['empty']:
                #     NT_count += 1
                # else:
                #     MT_count += 1

        # logger.info("NT samples: %d, MT samples: %d", NT_count, MT_count)
        if num_instances_without_valid_segmentation > 0:
            logger.warning("Filtered out {} instances without valid segmentation.".format(
                num_instances_without_valid_segmentation))

        # Debug mode
        # return self.dataset_dicts[:100]
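
    # Collapses N per-instance masks of shape (N, H, W) into a single binary
    # foreground mask of shape (1, H, W): overlapping pixels sum above 1 and
    # are clamped back to 1, i.e. the union of all target instances.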
    @staticmethod
    def _merge_masks(x):
        return x.sum(dim=0, keepdim=True).clamp(max=1)
    def __getitem__(self, index):
        dataset_dict = copy.deepcopy(self.dataset_dicts[index])  # it will be modified by code below

        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        # TODO: get padding mask
        # by feeding a "segmentation mask" to the same transforms
        padding_mask = np.ones(image.shape[:2])
        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        # the crop transformation has default padding value 0 for segmentation
        padding_mask = transforms.apply_segmentation(padding_mask)
        padding_mask = ~padding_mask.astype(bool)
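        # Pixels introduced by padding come out as 0 in the transformed mask;
        # the inversion above marks them True, so downstream code can tell real
        # image content (False) apart from padded regions (True).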

        image_shape = image.shape[:2]  # h, w

        # PyTorch's dataloader is efficient on torch.Tensor due to shared memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))
        # USER: Implement additional transformations if you have other types of data
        annos = [
            utils.transform_instance_annotations(obj, transforms, image_shape)
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0 and not obj.get("empty", False)
        ]
        instances = utils.annotations_to_instances(annos, image_shape)
        empty = dataset_dict.get("empty", False)

        if len(instances) > 0:
            assert not empty
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            # Generate masks from polygons
            h, w = instances.image_size
            assert hasattr(instances, 'gt_masks')
            gt_masks = instances.gt_masks
            gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
            instances.gt_masks = gt_masks
        else:
            assert empty
            gt_masks = torch.zeros((0, image_shape[0], image_shape[1]), dtype=torch.uint8)
            instances.gt_masks = gt_masks
        if self.split == "train":
            dataset_dict["instances"] = instances
        else:
            dataset_dict["gt_mask"] = gt_masks
            dataset_dict["empty"] = empty

        dataset_dict["gt_mask_merged"] = self._merge_masks(gt_masks) if self.merge else None
        # dataset_dict["gt_mask_merged"] = dataset_dict["gt_mask_merged"].float()
        # Language data
        sentence_raw = dataset_dict['sentence']['raw']
        attention_mask = [0] * self.max_tokens
        padded_input_ids = [0] * self.max_tokens

        input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)
        input_ids = input_ids[:self.max_tokens]  # truncate to the fixed token budget
        padded_input_ids[:len(input_ids)] = input_ids
        attention_mask[:len(input_ids)] = [1] * len(input_ids)

        dataset_dict['lang_tokens'] = torch.tensor(padded_input_ids).unsqueeze(0)
        dataset_dict['lang_mask'] = torch.tensor(attention_mask).unsqueeze(0)

        return (dataset_dict["image"].float(),
                dataset_dict["gt_mask_merged"].squeeze(0).long(),
                dataset_dict['lang_tokens'],
                dataset_dict['lang_mask'])
    def __len__(self):
        return len(self.dataset_dicts)
if __name__ == "__main__":
"""
Test the COCO json dataset loader.
Usage:
python -m detectron2.data.datasets.coco \
path/to/json path/to/image_root dataset_name
"dataset_name" can be "coco_2014_minival_100", or other
pre-registered ones
"""
from detectron2.utils.logger import setup_logger
from detectron2.utils.visualizer import Visualizer
import detectron2.data.datasets # noqa # add pre-defined metadata
import sys
REFCOCO_PATH = '/data2/projects/donghwa/RIS/ReLA/datasets'
COCO_TRAIN_2014_IMAGE_ROOT = '/data2/projects/donghwa/RIS/ReLA/datasets/images'
REFCOCO_DATASET = 'grefcoco'
REFCOCO_SPLITBY = 'unc'
REFCOCO_SPLIT = 'train'
logger = setup_logger(name=__name__)
dicts = load_grefcoco_json(REFCOCO_PATH, REFCOCO_DATASET, REFCOCO_SPLITBY, REFCOCO_SPLIT, COCO_TRAIN_2014_IMAGE_ROOT)
logger.info("Done loading {} samples.".format(len(dicts)))