import os
import sys
import torch.utils.data as data
import torch
from torchvision import transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
import random
from bert.tokenization_bert import BertTokenizer

import h5py
from refer.refer import REFER

from args import get_parser

# Dataset configuration initialization
parser = get_parser()
args = parser.parse_args()


class ReferDataset(data.Dataset):

    def __init__(self,
                 args,
                 image_transforms=None,
                 target_transforms=None,
                 split='train',
                 eval_mode=False):

        self.classes = []
        self.image_transforms = image_transforms
        self.target_transform = target_transforms
        self.split = split
        self.dataset = args.dataset
        self.args = args

        # This loader is only defined for the RefCOCOg motion/static ablation
        # splits: the metadata JSON lists the refs (and their sentences) that
        # belong to each split.
        if args.dataset == 'refcocog' and args.split in ['motion', 'static']:
            import json
            print(f"Easy & Hard Example Experiments - dataset: {args.dataset}, split: {args.split}")
            if args.split == 'motion':
                meta_fp = '/data2/projects/chaeyun/LAVT-RIS/test_ablation_motion.json'
            else:
                meta_fp = '/data2/projects/chaeyun/LAVT-RIS/test_ablation_static.json'
            with open(meta_fp, 'r') as f:
                ref_metas = json.load(f)
        else:
            # Fail loudly instead of hitting a NameError on ref_metas below.
            raise ValueError("ReferDataset expects dataset 'refcocog' with split 'motion' or 'static'")

        self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)
        self.max_tokens = 20

        # motion, static split binning
        self.input_ids = []
        self.attention_masks = []
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)
        self.ref_ids = []
        self.eval_mode = eval_mode
        self.refer_ctmz = {}  # segment_id -> raw metadata entry, looked up in __getitem__

        for ref in ref_metas:
            sentences_for_ref = []
            attentions_for_ref = []

            for i, sents in enumerate(ref['sentences']):
                sentence_raw = sents['sent']
                attention_mask = [0] * self.max_tokens
                padded_input_ids = [0] * self.max_tokens

                # Tokenize, truncate to max_tokens, zero-pad, and mark the
                # real tokens in the attention mask.
                input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)
                input_ids = input_ids[:self.max_tokens]
                padded_input_ids[:len(input_ids)] = input_ids
                attention_mask[:len(input_ids)] = [1] * len(input_ids)

                sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0))
                attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0))

            self.input_ids.append(sentences_for_ref)
            self.attention_masks.append(attentions_for_ref)
            self.ref_ids.append(ref['segment_id'])
            if ref['segment_id'] not in self.refer_ctmz:
                self.refer_ctmz[ref['segment_id']] = ref

        img_ids = self.refer.getImgIds(self.ref_ids)
        all_imgs = self.refer.Imgs
        self.imgs = list(all_imgs[i] for i in img_ids)

    def get_classes(self):
        return self.classes

    def __len__(self):
        return len(self.ref_ids)

    def __getitem__(self, index):
        this_ref_id = self.ref_ids[index]
        this_img_id = self.refer.getImgIds(this_ref_id)
        this_img = self.refer.Imgs[this_img_id[0]]

        IMAGE_DIR = '/data2/dataset/COCO2014/train2014/'
        img = Image.open(os.path.join(IMAGE_DIR, this_img['file_name'])).convert("RGB")

        ref_orig = self.refer.loadRefs(this_ref_id)
        ref = self.refer_ctmz[this_ref_id]  # ablation metadata entry for this ref

        # Binary segmentation mask of the referred object, as a paletted PIL image.
        ref_mask = np.array(self.refer.getMask(ref_orig[0])['mask'])
        annot = np.zeros(ref_mask.shape)
        annot[ref_mask == 1] = 1
        annot = Image.fromarray(annot.astype(np.uint8), mode="P")

        if self.image_transforms is not None:
            # resize, from PIL to tensor, and mean and std normalization
            img, target = self.image_transforms(img, annot)
        else:
            target = annot  # fall back to the raw PIL mask so `target` is always defined

        if self.eval_mode:
            # Evaluation: return every sentence for this ref, stacked along
            # the last dimension.
            embedding = []
            att = []
            for s in range(len(self.input_ids[index])):
                e = self.input_ids[index][s]
                a = self.attention_masks[index][s]
                embedding.append(e.unsqueeze(-1))
                att.append(a.unsqueeze(-1))

            tensor_embeddings = torch.cat(embedding, dim=-1)
            attention_mask = torch.cat(att, dim=-1)
        else:
            # Training: sample one sentence at random.
            choice_sent = np.random.choice(len(self.input_ids[index]))
            tensor_embeddings = self.input_ids[index][choice_sent]
            attention_mask = self.attention_masks[index][choice_sent]

        return img, target, tensor_embeddings, attention_mask
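
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the class above). The paired transform
# below is a hypothetical stand-in for the repo's own joint image/mask
# transform pipeline; swap in the project's transforms and pass the argparse
# flags this file already expects (--dataset refcocog, --split motion/static,
# --refer_data_root, --splitBy, --bert_tokenizer).
# ---------------------------------------------------------------------------
def _demo_paired_transform(img, annot):
    # Resize image and mask together, tensorize, and normalize the image
    # with ImageNet statistics (assumed values; adjust to the project's).
    img = TF.resize(img, [480, 480])
    annot = TF.resize(annot, [480, 480],
                      interpolation=transforms.InterpolationMode.NEAREST)
    img = TF.normalize(TF.to_tensor(img),
                       mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
    target = torch.as_tensor(np.array(annot), dtype=torch.long)
    return img, target


if __name__ == '__main__':
    # eval_mode=True stacks every sentence of a reference along the last
    # dimension, so evaluation conventionally runs with batch_size=1.
    dataset = ReferDataset(args,
                           image_transforms=_demo_paired_transform,
                           split=args.split,
                           eval_mode=True)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    img, target, embeddings, attention_mask = next(iter(loader))
    print(img.shape, target.shape, embeddings.shape, attention_mask.shape)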