import contextlib
import copy
import io
import logging
import os

import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.utils.data as data
from fvcore.common.timer import Timer

from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager

# from transformers import BertTokenizer
from bert.tokenization_bert import BertTokenizer

from data.utils import convert_coco_poly_to_mask, build_transform_train, build_transform_test

"""
This file contains a dataset class that parses gRefCOCO-format annotations
into dicts in "Detectron2 format".
"""

logger = logging.getLogger(__name__)

__all__ = ["GReferDataset"]


class GReferDataset(data.Dataset):

    def __init__(self, args, refer_root, dataset_name, splitby, split, image_root,
                 img_format="RGB", merge=True,
                 extra_annotation_keys=None, extra_refer_keys=None):
        self.refer_root = refer_root
        self.dataset_name = dataset_name
        self.splitby = splitby
        self.split = split
        self.image_root = image_root
        self.extra_annotation_keys = extra_annotation_keys
        self.extra_refer_keys = extra_refer_keys
        self.img_format = img_format
        self.merge = merge
        self.max_tokens = 20

        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

        if split == "train":
            self.tfm_gens = build_transform_train(args)
        elif split in ["val", "test", "testA", "testB"]:
            self.tfm_gens = build_transform_test(args)

        if self.dataset_name == 'refcocop':
            self.dataset_name = 'refcoco+'
        if self.dataset_name == 'refcoco' or self.dataset_name == 'refcoco+':
            self.splitby = 'unc'
        if self.dataset_name == 'refcocog':
            assert self.splitby == 'umd' or self.splitby == 'google'

        dataset_id = '_'.join([self.dataset_name, self.splitby, self.split])

        from refer.grefer import G_REFER

        logger.info('Loading dataset {} ({}-{}) ...'.format(self.dataset_name, self.splitby, self.split))
        logger.info('Refcoco root: {}'.format(self.refer_root))

        timer = Timer()
        self.refer_root = PathManager.get_local_path(self.refer_root)
        with contextlib.redirect_stdout(io.StringIO()):
            refer_api = G_REFER(data_root=self.refer_root,
                                dataset=self.dataset_name,
                                splitBy=self.splitby)
        if timer.seconds() > 1:
            logger.info("Loading {} takes {:.2f} seconds.".format(dataset_id, timer.seconds()))

        self.ref_ids = refer_api.getRefIds(split=self.split)
        self.img_ids = refer_api.getImgIds(self.ref_ids)
        self.refs = refer_api.loadRefs(self.ref_ids)
        imgs = [refer_api.loadImgs(ref['image_id'])[0] for ref in self.refs]
        anns = [refer_api.loadAnns(ref['ann_id']) for ref in self.refs]
        self.imgs_refs_anns = list(zip(imgs, self.refs, anns))

        logger.info("Loaded {} images, {} referring object sets in G_RefCOCO format from {}".format(
            len(self.img_ids), len(self.ref_ids), dataset_id))
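        # Each entry of self.dataset_dicts below follows the Detectron2
        # dataset-dict convention: "file_name", "height", "width", "image_id",
        # an "annotations" list of per-instance dicts ("bbox", "bbox_mode",
        # "category_id", "segmentation", "empty"), plus a "sentence" dict for
        # one referring expression ("raw", "sent_id", "ref_id"). One record is
        # emitted per (image, sentence) pair, so an image with three
        # expressions yields three records.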
        self.dataset_dicts = []

        ann_keys = ["iscrowd", "bbox", "category_id"] + (self.extra_annotation_keys or [])
        ref_keys = ["raw", "sent_id"] + (self.extra_refer_keys or [])

        ann_lib = {}
        NT_count = 0
        MT_count = 0
        num_instances_without_valid_segmentation = 0

        for idx, (img_dict, ref_dict, anno_dicts) in enumerate(self.imgs_refs_anns):
            record = {}
            record['id'] = idx
            record["source"] = 'grefcoco'
            record["file_name"] = os.path.join(self.image_root, img_dict["file_name"])
            record["height"] = img_dict["height"]
            record["width"] = img_dict["width"]
            image_id = record["image_id"] = img_dict["id"]

            # Check that the information of image, ann and ref match each other.
            # This fails only when the data parsing logic or the annotation file is buggy.
            assert ref_dict['image_id'] == image_id
            assert ref_dict['split'] == self.split
            if not isinstance(ref_dict['ann_id'], list):
                ref_dict['ann_id'] = [ref_dict['ann_id']]

            # No-target samples
            if None in anno_dicts:
                assert anno_dicts == [None]
                assert ref_dict['ann_id'] == [-1]
                record['empty'] = True
                obj = {key: None for key in ann_keys}
                obj["bbox_mode"] = BoxMode.XYWH_ABS
                obj["empty"] = True
                obj = [obj]

            # Single- or multi-target samples
            else:
                record['empty'] = False
                obj = []
                for anno_dict in anno_dicts:
                    ann_id = anno_dict['id']
                    if anno_dict['iscrowd']:
                        continue
                    assert anno_dict["image_id"] == image_id
                    assert ann_id in ref_dict['ann_id']

                    if ann_id in ann_lib:
                        ann = ann_lib[ann_id]
                    else:
                        ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict}
                        ann["bbox_mode"] = BoxMode.XYWH_ABS
                        ann["empty"] = False

                        segm = anno_dict.get("segmentation", None)
                        assert segm  # either list[list[float]] or dict(RLE)
                        if isinstance(segm, dict):
                            if isinstance(segm["counts"], list):
                                # convert uncompressed RLE to compressed RLE
                                segm = mask_util.frPyObjects(segm, *segm["size"])
                        else:
                            # filter out invalid polygons (< 3 points)
                            segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
                            if len(segm) == 0:
                                num_instances_without_valid_segmentation += 1
                                continue  # ignore this instance
                        ann["segmentation"] = segm
                        ann_lib[ann_id] = ann

                    obj.append(ann)

            record["annotations"] = obj

            # Process referring expressions: emit one record per sentence
            sents = ref_dict['sentences']
            for sent in sents:
                ref_record = record.copy()
                ref = {key: sent[key] for key in ref_keys if key in sent}
                ref["ref_id"] = ref_dict["ref_id"]
                ref_record["sentence"] = ref
                self.dataset_dicts.append(ref_record)
                # if ref_record['empty']:
                #     NT_count += 1
                # else:
                #     MT_count += 1

        # logger.info("NT samples: %d, MT samples: %d", NT_count, MT_count)
        if num_instances_without_valid_segmentation > 0:
            logger.warning("Filtered out {} instances without valid segmentation.".format(
                num_instances_without_valid_segmentation))

        # Debug mode
        # self.dataset_dicts = self.dataset_dicts[:100]

    @staticmethod
    def _merge_masks(x):
        return x.sum(dim=0, keepdim=True).clamp(max=1)
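    # _merge_masks collapses a stack of N binary instance masks of shape
    # (N, H, W) into a single (1, H, W) union mask. For example,
    # torch.tensor([[[1, 0]], [[1, 1]]]) sums to [[[2, 1]]] along dim 0 and
    # clamps to [[[1, 1]]].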
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) # USER: Implement additional transformations if you have other types of data annos = [ utils.transform_instance_annotations(obj, transforms, image_shape) for obj in dataset_dict.pop("annotations") if (obj.get("iscrowd", 0) == 0) and (obj.get("empty", False) == False) ] instances = utils.annotations_to_instances(annos, image_shape) empty = dataset_dict.get("empty", False) if len(instances) > 0: assert (not empty) instances.gt_boxes = instances.gt_masks.get_bounding_boxes() # Generate masks from polygon h, w = instances.image_size assert hasattr(instances, 'gt_masks') gt_masks = instances.gt_masks gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) instances.gt_masks = gt_masks else: assert empty gt_masks = torch.zeros((0, image_shape[0], image_shape[1]), dtype=torch.uint8) instances.gt_masks = gt_masks if self.split == "train" : dataset_dict["instances"] = instances else: dataset_dict["gt_mask"] = gt_masks dataset_dict["empty"] = empty dataset_dict["gt_mask_merged"] = self._merge_masks(gt_masks) if self.merge else None # dataset_dict["gt_mask_merged"] = dataset_dict["gt_mask_merged"].float() # Language data sentence_raw = dataset_dict['sentence']['raw'] attention_mask = [0] * self.max_tokens padded_input_ids = [0] * self.max_tokens input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True) input_ids = input_ids[:self.max_tokens] padded_input_ids[:len(input_ids)] = input_ids attention_mask[:len(input_ids)] = [1] * len(input_ids) dataset_dict['lang_tokens'] = torch.tensor(padded_input_ids).unsqueeze(0) dataset_dict['lang_mask'] = torch.tensor(attention_mask).unsqueeze(0) return dataset_dict["image"].float(), dataset_dict["gt_mask_merged"].squeeze(0).long(), dataset_dict['lang_tokens'], dataset_dict['lang_mask'] def __len__(self): return len(self.dataset_dicts) if __name__ == "__main__": """ Test the COCO json dataset loader. Usage: python -m detectron2.data.datasets.coco \ path/to/json path/to/image_root dataset_name "dataset_name" can be "coco_2014_minival_100", or other pre-registered ones """ from detectron2.utils.logger import setup_logger from detectron2.utils.visualizer import Visualizer import detectron2.data.datasets # noqa # add pre-defined metadata import sys REFCOCO_PATH = '/data2/projects/donghwa/RIS/ReLA/datasets' COCO_TRAIN_2014_IMAGE_ROOT = '/data2/projects/donghwa/RIS/ReLA/datasets/images' REFCOCO_DATASET = 'grefcoco' REFCOCO_SPLITBY = 'unc' REFCOCO_SPLIT = 'train' logger = setup_logger(name=__name__) dicts = load_grefcoco_json(REFCOCO_PATH, REFCOCO_DATASET, REFCOCO_SPLITBY, REFCOCO_SPLIT, COCO_TRAIN_2014_IMAGE_ROOT) logger.info("Done loading {} samples.".format(len(dicts)))