# File size: 11,811 Bytes
# Revision: 8d82201
import contextlib
import io
import logging
import numpy as np
import os
import random
import copy
import pycocotools.mask as mask_util
from fvcore.common.timer import Timer
from PIL import Image
import torch.utils.data as data
from detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
from detectron2.utils.file_io import PathManager
import time
import copy
import logging
import torch
from detectron2.config import configurable
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
# from transformers import BertTokenizer
from bert.tokenization_bert import BertTokenizer
from pycocotools import mask as coco_mask
from data.utils import convert_coco_poly_to_mask, build_transform_train, build_transform_test
"""
This file contains functions to parse RefCOCO-format annotations into dicts in "Detectron2 format".
"""
# Module-level logger shared by the dataset class and the script entry point.
logger = logging.getLogger(__name__)
# Public API of this module. The previous entry, "load_refcoco_json", is not
# defined anywhere in this file, which makes `from <module> import *` fail
# with AttributeError; the dataset class is the only top-level definition.
__all__ = ["GReferDataset"]
class GReferDataset(data.Dataset):
    """Map-style dataset of referring expressions loaded through ``G_REFER``.

    Every sentence of every reference in the requested split becomes one
    sample. ``__getitem__`` returns the transformed image, the merged target
    mask, and the padded BERT token ids / attention mask of the sentence.

    Args:
        args: Options object. ``args.bert_tokenizer`` is passed to
            ``BertTokenizer.from_pretrained``; the whole object is handed to
            ``build_transform_train`` / ``build_transform_test``.
        refer_root: Passed to ``G_REFER`` as ``data_root`` (after
            ``PathManager.get_local_path``).
        dataset_name: Dataset identifier; ``'refcocop'`` is rewritten to
            ``'refcoco+'``.
        splitby: Passed to ``G_REFER`` as ``splitBy``.
        split: Split name used to select references; ``"train"`` selects the
            training transforms, anything else the test transforms.
        image_root: Directory joined with every image's ``file_name``.
        img_format: Image format passed to ``utils.read_image``.
        merge: Whether per-instance masks are collapsed into a single mask.
        extra_annotation_keys: Extra keys copied from each annotation dict.
        extra_refer_keys: Extra keys copied from each sentence dict.
    """

    def __init__(self,
                 args, refer_root, dataset_name, splitby, split, image_root,
                 img_format="RGB", merge=True,
                 extra_annotation_keys=None, extra_refer_keys=None):
        self.refer_root = refer_root
        self.dataset_name = dataset_name
        self.splitby = splitby
        self.split = split
        self.image_root = image_root
        self.extra_annotation_keys = extra_annotation_keys
        self.extra_refer_keys = extra_refer_keys
        self.img_format = img_format
        self.merge = merge
        self.max_tokens = 20
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

        if split == "train":
            self.tfm_gens = build_transform_train(args)
        elif split in ["val", "test", "testA", "testB"]:
            self.tfm_gens = build_transform_test(args)
        else:
            # Without this branch `self.tfm_gens` stayed undefined and the
            # first __getitem__ call failed with AttributeError. __getitem__
            # already treats every non-"train" split as evaluation, so the
            # test transforms are the consistent fallback.
            logger.warning("Unrecognized split %r; using test transforms.", split)
            self.tfm_gens = build_transform_test(args)

        if self.dataset_name == 'refcocop':
            self.dataset_name = 'refcoco+'
        if self.dataset_name in ('refcoco', 'refcoco+') and self.splitby != 'unc':
            # This used to be the bare comparison `self.splitby == 'unc'`,
            # which has no effect. Only warn: whether other values should be
            # overridden or rejected cannot be determined from this file.
            logger.warning("splitby=%r given for dataset %s; 'unc' was expected.",
                           self.splitby, self.dataset_name)
        if self.dataset_name == 'refcocog':
            assert self.splitby == 'umd' or self.splitby == 'google'

        dataset_id = '_'.join([self.dataset_name, self.splitby, self.split])

        # Kept as a local import, as in the original code.
        from refer.grefer import G_REFER
        logger.info('Loading dataset {} ({}-{}) ...'.format(self.dataset_name, self.splitby, self.split))
        logger.info('Refcoco root: {}'.format(self.refer_root))
        timer = Timer()
        self.refer_root = PathManager.get_local_path(self.refer_root)
        # Swallow whatever G_REFER prints to stdout while loading.
        with contextlib.redirect_stdout(io.StringIO()):
            refer_api = G_REFER(data_root=self.refer_root,
                                dataset=self.dataset_name,
                                splitBy=self.splitby)
        if timer.seconds() > 1:
            logger.info("Loading {} takes {:.2f} seconds.".format(dataset_id, timer.seconds()))

        self.ref_ids = refer_api.getRefIds(split=self.split)
        self.img_ids = refer_api.getImgIds(self.ref_ids)
        self.refs = refer_api.loadRefs(self.ref_ids)
        imgs = [refer_api.loadImgs(ref['image_id'])[0] for ref in self.refs]
        anns = [refer_api.loadAnns(ref['ann_id']) for ref in self.refs]
        self.imgs_refs_anns = list(zip(imgs, self.refs, anns))
        logger.info("Loaded {} images, {} referring object sets in G_RefCOCO format from {}".format(
            len(self.img_ids), len(self.ref_ids), dataset_id))

        self.dataset_dicts = []
        ann_keys = ["iscrowd", "bbox", "category_id"] + (self.extra_annotation_keys or [])
        ref_keys = ["raw", "sent_id"] + (self.extra_refer_keys or [])
        ann_lib = {}  # ann_id -> parsed annotation, shared between references
        nt_count = 0  # sentences with no target
        mt_count = 0  # sentences with at least one annotated target
        # Was incremented without ever being initialised, which raised
        # UnboundLocalError on the first annotation lacking a valid polygon.
        num_instances_without_valid_segmentation = 0

        for idx, (img_dict, ref_dict, anno_dicts) in enumerate(self.imgs_refs_anns):
            record = {}
            record['id'] = idx
            record["source"] = 'grefcoco'
            record["file_name"] = os.path.join(self.image_root, img_dict["file_name"])
            record["height"] = img_dict["height"]
            record["width"] = img_dict["width"]
            image_id = record["image_id"] = img_dict["id"]

            # Check that information of image, ann and ref match each other
            # This fails only when the data parsing logic or the annotation file is buggy.
            assert ref_dict['image_id'] == image_id
            assert ref_dict['split'] == self.split
            if not isinstance(ref_dict['ann_id'], list):
                ref_dict['ann_id'] = [ref_dict['ann_id']]

            if None in anno_dicts:
                # No target samples
                assert anno_dicts == [None]
                assert ref_dict['ann_id'] == [-1]
                record['empty'] = True
                placeholder = dict.fromkeys(ann_keys)
                placeholder["bbox_mode"] = BoxMode.XYWH_ABS
                placeholder["empty"] = True
                obj = [placeholder]
            else:
                # Multi target samples
                record['empty'] = False
                obj = []
                for anno_dict in anno_dicts:
                    ann_id = anno_dict['id']
                    if anno_dict['iscrowd']:
                        continue
                    assert anno_dict["image_id"] == image_id
                    assert ann_id in ref_dict['ann_id']

                    if ann_id in ann_lib:
                        ann = ann_lib[ann_id]
                    else:
                        ann = {key: anno_dict[key] for key in ann_keys if key in anno_dict}
                        ann["bbox_mode"] = BoxMode.XYWH_ABS
                        ann["empty"] = False

                        segm = anno_dict.get("segmentation", None)
                        assert segm  # either list[list[float]] or dict(RLE)
                        segm = self._normalize_segmentation(segm)
                        if segm is None:
                            # Not cached in ann_lib, so the same annotation is
                            # counted again if another reference points to it.
                            num_instances_without_valid_segmentation += 1
                            continue  # ignore this instance
                        ann["segmentation"] = segm
                        ann_lib[ann_id] = ann
                    obj.append(ann)
                # NOTE(review): if every annotation was skipped above, `obj`
                # is empty while record['empty'] is False, and __getitem__
                # will hit `assert empty` for this record.

            record["annotations"] = obj

            # Process referring expressions: one sample per sentence. The
            # copy is shallow; __getitem__ deep-copies before mutating.
            for sent in ref_dict['sentences']:
                ref_record = record.copy()
                ref = {key: sent[key] for key in ref_keys if key in sent}
                ref["ref_id"] = ref_dict["ref_id"]
                ref_record["sentence"] = ref
                self.dataset_dicts.append(ref_record)
                if ref_record['empty']:
                    nt_count += 1
                else:
                    mt_count += 1

        logger.info("NT samples: %d, MT samples: %d", nt_count, mt_count)
        if num_instances_without_valid_segmentation > 0:
            logger.warning("Skipped %d annotation(s) without a valid segmentation.",
                           num_instances_without_valid_segmentation)

    @staticmethod
    def _normalize_segmentation(segm):
        """Return a usable segmentation, or ``None`` if no polygon is valid.

        A dict (RLE) whose ``counts`` is a list is converted with
        ``mask_util.frPyObjects``; any other dict is returned unchanged.
        A polygon list keeps only polygons with an even number of
        coordinates and at least 6 of them (i.e. >= 3 points).
        """
        if isinstance(segm, dict):
            if isinstance(segm["counts"], list):
                # convert to compressed RLE
                segm = mask_util.frPyObjects(segm, *segm["size"])
            return segm
        # filter out invalid polygons (< 3 points)
        polygons = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
        return polygons or None

    @staticmethod
    def _merge_masks(x):
        """Collapse masks along dim 0 into one mask (kept as dim of size 1), capped at 1."""
        return x.sum(dim=0, keepdim=True).clamp(max=1)

    def __getitem__(self, index):
        """Load, transform and tokenize the sample at ``index``.

        Returns:
            A 4-tuple ``(image, target, lang_tokens, lang_mask)``:
            ``image`` is the transformed image as a float tensor with the
            channel axis moved first; ``target`` is the merged mask with its
            leading axis squeezed, as ``long``; ``lang_tokens`` and
            ``lang_mask`` are ``(1, max_tokens)`` tensors holding the padded
            token ids and the matching attention mask.
        """
        # Deep copy: the record is modified below and is shared with
        # self.dataset_dicts (and its annotation dicts with other records).
        dataset_dict = copy.deepcopy(self.dataset_dicts[index])
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        # TODO: get padding mask
        # by feeding a "segmentation mask" to the same transforms
        padding_mask = np.ones(image.shape[:2])
        image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        # the crop transformation has default padding value 0 for segmentation
        padding_mask = transforms.apply_segmentation(padding_mask)
        padding_mask = ~padding_mask.astype(bool)
        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask))

        # Drop crowd annotations and the placeholder of no-target samples.
        annos = [
            utils.transform_instance_annotations(obj, transforms, image_shape)
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0 and not obj.get("empty", False)
        ]
        instances = utils.annotations_to_instances(annos, image_shape)
        empty = dataset_dict.get("empty", False)

        if len(instances) > 0:
            assert (not empty)
            assert hasattr(instances, 'gt_masks')
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            # Generate masks from polygon
            h, w = instances.image_size
            gt_masks = convert_coco_poly_to_mask(instances.gt_masks.polygons, h, w)
            instances.gt_masks = gt_masks
        else:
            assert empty
            # No-target sample: zero instances at the transformed image size.
            gt_masks = torch.zeros((0, image_shape[0], image_shape[1]), dtype=torch.uint8)
            instances.gt_masks = gt_masks

        if self.split == "train":
            dataset_dict["instances"] = instances
        else:
            dataset_dict["gt_mask"] = gt_masks
        dataset_dict["empty"] = empty
        # NOTE(review): with merge=False this is None and the return
        # statement below raises AttributeError; the intended output for
        # that mode is not visible in this file.
        dataset_dict["gt_mask_merged"] = self._merge_masks(gt_masks) if self.merge else None

        # Language data: token ids truncated to max_tokens and zero-padded,
        # with an attention mask marking the real tokens.
        sentence_raw = dataset_dict['sentence']['raw']
        attention_mask = [0] * self.max_tokens
        padded_input_ids = [0] * self.max_tokens
        input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)
        input_ids = input_ids[:self.max_tokens]
        padded_input_ids[:len(input_ids)] = input_ids
        attention_mask[:len(input_ids)] = [1] * len(input_ids)
        dataset_dict['lang_tokens'] = torch.tensor(padded_input_ids).unsqueeze(0)
        dataset_dict['lang_mask'] = torch.tensor(attention_mask).unsqueeze(0)

        return (dataset_dict["image"].float(),
                dataset_dict["gt_mask_merged"].squeeze(0).long(),
                dataset_dict['lang_tokens'],
                dataset_dict['lang_mask'])

    def __len__(self):
        """Number of (reference, sentence) samples."""
        return len(self.dataset_dicts)
# Script entry point: a manual smoke test that loads one split and logs how
# many samples were produced.
if __name__ == "__main__":
    # NOTE(review): the usage text in the string below refers to detectron2's
    # COCO loader command line, not to this file; it looks like a leftover
    # and should be confirmed or rewritten. The string is a plain expression
    # statement and has no runtime effect.
    """
Test the COCO json dataset loader.
Usage:
python -m detectron2.data.datasets.coco \
path/to/json path/to/image_root dataset_name
"dataset_name" can be "coco_2014_minival_100", or other
pre-registered ones
"""
    from detectron2.utils.logger import setup_logger
    # NOTE(review): Visualizer and sys are imported but not used in this block.
    from detectron2.utils.visualizer import Visualizer
    import detectron2.data.datasets # noqa # add pre-defined metadata
    import sys
    # Hard-coded, machine-specific locations and split selection for the test run.
    REFCOCO_PATH = '/data2/projects/donghwa/RIS/ReLA/datasets'
    COCO_TRAIN_2014_IMAGE_ROOT = '/data2/projects/donghwa/RIS/ReLA/datasets/images'
    REFCOCO_DATASET = 'grefcoco'
    REFCOCO_SPLITBY = 'unc'
    REFCOCO_SPLIT = 'train'
    # Rebinds the module-level `logger` to one configured by detectron2.
    logger = setup_logger(name=__name__)
    # NOTE(review): `load_grefcoco_json` is neither defined nor imported in
    # the visible part of this file, so this call raises NameError unless it
    # is provided elsewhere. The visible loader is the GReferDataset class,
    # whose constructor additionally needs an `args` object — confirm the
    # intended replacement.
    dicts = load_grefcoco_json(REFCOCO_PATH, REFCOCO_DATASET, REFCOCO_SPLITBY, REFCOCO_SPLIT, COCO_TRAIN_2014_IMAGE_ROOT)
    logger.info("Done loading {} samples.".format(len(dicts)))
|