|
import contextlib
import copy
import io
import logging
import os

import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.utils.data as data
from fvcore.common.timer import Timer

from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

from bert.tokenization_bert import BertTokenizer
from data.utils import convert_coco_poly_to_mask, build_transform_train, build_transform_test
|
""" |
|
This file contains functions to parse RefCOCO-format annotations into dicts in "Detectron2 format". |
|
""" |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
__all__ = ["GReferDataset"]
|
|
|
class GReferDataset(data.Dataset): |
|
def __init__(self, |
|
args, refer_root, dataset_name, splitby, split, image_root, |
|
img_format="RGB", merge=True, |
|
extra_annotation_keys=None, extra_refer_keys=None): |
|
|
|
self.refer_root = refer_root |
|
self.dataset_name = dataset_name |
|
self.splitby = splitby |
|
self.split = split |
|
self.image_root = image_root |
|
self.extra_annotation_keys = extra_annotation_keys |
|
self.extra_refer_keys = extra_refer_keys |
|
self.img_format = img_format |
|
self.merge = merge |
|
self.max_tokens = 20 |
|
self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer) |
|
|
|
if split == "train": |
|
self.tfm_gens = build_transform_train(args) |
|
elif split in ["val", "test", "testA", "testB"]: |
|
self.tfm_gens = build_transform_test(args) |
|
|
|
        if self.dataset_name == 'refcocop':
            self.dataset_name = 'refcoco+'
        if self.dataset_name in ('refcoco', 'refcoco+'):
            assert self.splitby == 'unc'
        if self.dataset_name == 'refcocog':
            assert self.splitby in ('umd', 'google')
|
|
|
dataset_id = '_'.join([self.dataset_name, self.splitby, self.split]) |
|
|
|
from refer.grefer import G_REFER |
|
logger.info('Loading dataset {} ({}-{}) ...'.format(self.dataset_name, self.splitby, self.split)) |
|
logger.info('Refcoco root: {}'.format(self.refer_root)) |
|
timer = Timer() |
|
self.refer_root = PathManager.get_local_path(self.refer_root) |
|
with contextlib.redirect_stdout(io.StringIO()): |
|
refer_api = G_REFER(data_root=self.refer_root, |
|
dataset=self.dataset_name, |
|
splitBy=self.splitby) |
|
if timer.seconds() > 1: |
|
logger.info("Loading {} takes {:.2f} seconds.".format(dataset_id, timer.seconds())) |
|
|
|
self.ref_ids = refer_api.getRefIds(split=self.split) |
|
self.img_ids = refer_api.getImgIds(self.ref_ids) |
|
self.refs = refer_api.loadRefs(self.ref_ids) |
|
imgs = [refer_api.loadImgs(ref['image_id'])[0] for ref in self.refs] |
|
anns = [refer_api.loadAnns(ref['ann_id']) for ref in self.refs] |
|
self.imgs_refs_anns = list(zip(imgs, self.refs, anns)) |
|
|
|
logger.info("Loaded {} images, {} referring object sets in G_RefCOCO format from {}".format(len(self.img_ids), len(self.ref_ids), dataset_id)) |
|
|
|
|
|
self.dataset_dicts = [] |
|
|
|
ann_keys = ["iscrowd", "bbox", "category_id"] + (self.extra_annotation_keys or []) |
|
ref_keys = ["raw", "sent_id"] + (self.extra_refer_keys or []) |
|
|
|
        ann_lib = {}
        num_instances_without_valid_segmentation = 0
|
|
|
for idx, (img_dict, ref_dict, anno_dicts) in enumerate(self.imgs_refs_anns): |
|
record = {} |
|
record['id'] = idx |
|
record["source"] = 'grefcoco' |
|
record["file_name"] = os.path.join(self.image_root, img_dict["file_name"]) |
|
record["height"] = img_dict["height"] |
|
record["width"] = img_dict["width"] |
|
image_id = record["image_id"] = img_dict["id"] |
|
|
|
|
|
|
|
assert ref_dict['image_id'] == image_id |
|
assert ref_dict['split'] == self.split |
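            # gRefCOCO refs may point to several annotations (multi-target) or to
            # none at all (no-target); normalize ann_id to a list either way.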
|
if not isinstance(ref_dict['ann_id'], list): |
|
ref_dict['ann_id'] = [ref_dict['ann_id']] |
|
|
|
|
|
if None in anno_dicts: |
|
assert anno_dicts == [None] |
|
assert ref_dict['ann_id'] == [-1] |
|
record['empty'] = True |
|
                obj = {key: None for key in ann_keys}
|
obj["bbox_mode"] = BoxMode.XYWH_ABS |
|
obj["empty"] = True |
|
obj = [obj] |
|
|
|
|
|
else: |
|
record['empty'] = False |
|
obj = [] |
|
for anno_dict in anno_dicts: |
|
ann_id = anno_dict['id'] |
|
if anno_dict['iscrowd']: |
|
continue |
|
assert anno_dict["image_id"] == image_id |
|
assert ann_id in ref_dict['ann_id'] |
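                    # Annotations can be shared across refs; cache each parsed
                    # annotation dict so every ann_id is processed only once.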
|
|
|
if ann_id in ann_lib: |
|
ann = ann_lib[ann_id] |
|
else: |
|
ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict} |
|
ann["bbox_mode"] = BoxMode.XYWH_ABS |
|
ann["empty"] = False |
|
|
|
segm = anno_dict.get("segmentation", None) |
|
assert segm |
|
if isinstance(segm, dict): |
|
if isinstance(segm["counts"], list): |
|
|
|
segm = mask_util.frPyObjects(segm, *segm["size"]) |
|
else: |
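                            # polygons: keep only valid ones (at least 3 points,
                            # i.e. 6 coordinates, and an even coordinate count)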
|
|
|
segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6] |
|
if len(segm) == 0: |
|
num_instances_without_valid_segmentation += 1 |
|
continue |
|
ann["segmentation"] = segm |
|
ann_lib[ann_id] = ann |
|
|
|
obj.append(ann) |
|
|
|
record["annotations"] = obj |
|
|
|
|
|
sents = ref_dict['sentences'] |
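            # Flatten: emit one dataset dict per referring sentence.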
|
for sent in sents: |
|
ref_record = record.copy() |
|
ref = {key: sent[key] for key in ref_keys if key in sent} |
|
ref["ref_id"] = ref_dict["ref_id"] |
|
ref_record["sentence"] = ref |
|
self.dataset_dicts.append(ref_record) |
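
        # Mirror detectron2's COCO loader: report how many instances were
        # dropped for lacking a valid polygon segmentation.
        if num_instances_without_valid_segmentation > 0:
            logger.warning(
                "Filtered out {} instances without valid segmentation.".format(
                    num_instances_without_valid_segmentation
                )
            )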
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
def _merge_masks(x): |
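        # Union all instance masks: (N, H, W) -> (1, H, W) binary mask.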
|
return x.sum(dim=0, keepdim=True).clamp(max=1) |
|
|
|
|
|
def __getitem__(self, index): |
|
|
|
dataset_dict = copy.deepcopy(self.dataset_dicts[index]) |
|
|
|
image = utils.read_image(dataset_dict["file_name"], format=self.img_format) |
|
utils.check_image_size(dataset_dict, image) |
|
|
|
|
|
|
|
padding_mask = np.ones(image.shape[:2]) |
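        # Track which pixels the transforms add as padding: start from all-ones,
        # warp alongside the image, then invert so padded pixels become True.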
|
image, transforms = T.apply_transform_gens(self.tfm_gens, image) |
|
|
|
padding_mask = transforms.apply_segmentation(padding_mask) |
|
padding_mask = ~padding_mask.astype(bool) |
|
|
|
image_shape = image.shape[:2] |
|
|
|
|
|
|
|
|
|
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) |
|
dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) |
|
|
|
|
|
        annos = [
            utils.transform_instance_annotations(obj, transforms, image_shape)
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0 and not obj.get("empty", False)
        ]
|
instances = utils.annotations_to_instances(annos, image_shape) |
|
|
|
empty = dataset_dict.get("empty", False) |
|
|
|
if len(instances) > 0: |
|
assert (not empty) |
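            # Boxes are recomputed from the transformed masks, since the
            # augmentations can invalidate the original boxes.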
|
instances.gt_boxes = instances.gt_masks.get_bounding_boxes() |
|
|
|
h, w = instances.image_size |
|
assert hasattr(instances, 'gt_masks') |
|
gt_masks = instances.gt_masks |
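            # Rasterize polygon masks into fixed-size binary mask tensors.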
|
gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) |
|
instances.gt_masks = gt_masks |
|
else: |
|
assert empty |
|
gt_masks = torch.zeros((0, image_shape[0], image_shape[1]), dtype=torch.uint8) |
|
instances.gt_masks = gt_masks |
|
|
|
if self.split == "train" : |
|
dataset_dict["instances"] = instances |
|
else: |
|
dataset_dict["gt_mask"] = gt_masks |
|
|
|
dataset_dict["empty"] = empty |
|
dataset_dict["gt_mask_merged"] = self._merge_masks(gt_masks) if self.merge else None |
|
|
|
|
|
|
|
sentence_raw = dataset_dict['sentence']['raw'] |
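        # Tokenize the expression, then pad/truncate to max_tokens; the attention
        # mask is 1 for real tokens and 0 for padding.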
|
attention_mask = [0] * self.max_tokens |
|
padded_input_ids = [0] * self.max_tokens |
|
|
|
input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True) |
|
|
|
input_ids = input_ids[:self.max_tokens] |
|
padded_input_ids[:len(input_ids)] = input_ids |
|
|
|
attention_mask[:len(input_ids)] = [1] * len(input_ids) |
|
|
|
dataset_dict['lang_tokens'] = torch.tensor(padded_input_ids).unsqueeze(0) |
|
dataset_dict['lang_mask'] = torch.tensor(attention_mask).unsqueeze(0) |
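
        # NOTE: this return path assumes merge=True; with merge=False,
        # gt_mask_merged is None and the .squeeze(0) below would fail.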
|
|
|
return dataset_dict["image"].float(), dataset_dict["gt_mask_merged"].squeeze(0).long(), dataset_dict['lang_tokens'], dataset_dict['lang_mask'] |
|
|
|
|
|
def __len__(self): |
|
return len(self.dataset_dicts) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
""" |
|
Test the COCO json dataset loader. |
|
|
|
Usage: |
|
python -m detectron2.data.datasets.coco \ |
|
path/to/json path/to/image_root dataset_name |
|
|
|
"dataset_name" can be "coco_2014_minival_100", or other |
|
pre-registered ones |
|
""" |
|
from detectron2.utils.logger import setup_logger |
|
from detectron2.utils.visualizer import Visualizer |
|
import detectron2.data.datasets |
|
import sys |
|
|
|
REFCOCO_PATH = '/data2/projects/donghwa/RIS/ReLA/datasets' |
|
COCO_TRAIN_2014_IMAGE_ROOT = '/data2/projects/donghwa/RIS/ReLA/datasets/images' |
|
REFCOCO_DATASET = 'grefcoco' |
|
REFCOCO_SPLITBY = 'unc' |
|
REFCOCO_SPLIT = 'train' |
|
|
|
logger = setup_logger(name=__name__) |
|
dicts = load_grefcoco_json(REFCOCO_PATH, REFCOCO_DATASET, REFCOCO_SPLITBY, REFCOCO_SPLIT, COCO_TRAIN_2014_IMAGE_ROOT) |
|
logger.info("Done loading {} samples.".format(len(dicts))) |
|
|