import os
import sys
import json
import random

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

from bert.tokenization_bert import BertTokenizer
from refer.refer import REFER
from args import get_parser

# Dataset configuration initialization
# parser = get_parser()
# args = parser.parse_args()


class ReferDataset(data.Dataset):
    """Referring-segmentation dataset built on top of a REFER annotation index.

    Each item corresponds to one referring expression id (``ref_id``). In
    eval mode, every sentence of the ref is returned, stacked on a trailing
    axis. In train mode, one sentence is sampled at random. When
    ``args.metric_learning`` is set, a hard-positive (and optionally a
    hard-negative) phrase is also sampled.
    """

    # Location of the metric-learning metadata files (hard positives /
    # negatives, multi-object id lists). Class-level so it can be overridden.
    ROOT = '/data2/dataset/RefCOCO/VRIS'
    # Directory that ``this_img['file_name']`` is resolved against.
    IMAGE_DIR = '/data2/dataset/COCO2014/trainval2014/'

    def __init__(self,
                 args,
                 image_transforms=None,
                 target_transforms=None,
                 split='train',
                 eval_mode=False):
        """Build the dataset index and pre-tokenize every sentence.

        Args:
            args: config namespace; reads refer_data_root, dataset, splitBy,
                bert_tokenizer, metric_learning, exclude_multiobj,
                metric_mode and (when metric learning in train mode) hn_prob.
            image_transforms: optional callable ``(img, mask) -> (img, target)``.
            target_transforms: stored as ``self.target_transform``; not
                applied anywhere in this class.
            split: REFER split name passed to ``getRefIds``.
            eval_mode: if True, ``__getitem__`` returns all sentences of a ref.
        """
        self.classes = []
        self.image_transforms = image_transforms
        self.target_transform = target_transforms
        self.split = split
        self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)

        self.max_tokens = 20

        ref_ids = self.refer.getRefIds(split=self.split)
        img_ids = self.refer.getImgIds(ref_ids)

        all_imgs = self.refer.Imgs
        self.imgs = [all_imgs[i] for i in img_ids]
        self.ref_ids = ref_ids

        self.input_ids = []
        self.attention_masks = []
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

        # Metric-learning configuration.
        self.metric_learning = args.metric_learning
        self.exclude_multiobj = args.exclude_multiobj
        self.metric_mode = args.metric_mode
        self.exclude_position = False

        # Hard-positive/negative metadata is only needed when training.
        if self.metric_learning and not eval_mode:
            self.hardneg_prob = args.hn_prob
            self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
            self.hardpos_meta, self.hardneg_meta = self._load_metadata()
        else:
            self.hardneg_prob = 0.0
            self.multi_obj_ref_ids = None
            self.hardpos_meta, self.hardneg_meta = None, None

        self.eval_mode = eval_mode

        # Pre-tokenize every sentence of every ref. __getitem__ later either
        # uses all of them (eval) or samples one (train).
        for ref_id in ref_ids:
            ref = self.refer.Refs[ref_id]
            sentences_for_ref = []
            attentions_for_ref = []
            for el in ref['sentences']:
                ids, mask = self._tokenize(el['raw'])
                sentences_for_ref.append(ids)
                attentions_for_ref.append(mask)
            self.input_ids.append(sentences_for_ref)
            self.attention_masks.append(attentions_for_ref)

    def _tokenize(self, sentence):
        """Tokenize ``sentence`` into fixed-length BERT ids and attention mask.

        Tokens beyond ``self.max_tokens`` are truncated; shorter inputs are
        zero-padded. Both returned tensors have shape ``(1, max_tokens)``.
        """
        attention_mask = [0] * self.max_tokens
        padded_input_ids = [0] * self.max_tokens

        input_ids = self.tokenizer.encode(text=sentence, add_special_tokens=True)

        # truncation of tokens
        input_ids = input_ids[:self.max_tokens]

        padded_input_ids[:len(input_ids)] = input_ids
        attention_mask[:len(input_ids)] = [1] * len(input_ids)

        # match shape as (1, max_tokens)
        return torch.tensor(padded_input_ids).unsqueeze(0), torch.tensor(attention_mask).unsqueeze(0)

    def _load_multi_obj_ref_ids(self):
        """Return the list of multi-object ref ids to use, or None if disabled.

        The file is chosen by ``exclude_position`` (checked first) or
        ``exclude_multiobj``. Each non-blank line of the file is an integer ref id.
        """
        if not self.exclude_multiobj and not self.exclude_position:
            return None
        elif self.exclude_position:
            multiobj_path = os.path.join(self.ROOT, 'multiobj_ov2_nopos.txt')
        elif self.exclude_multiobj:
            multiobj_path = os.path.join(self.ROOT, 'multiobj_ov3.txt')
        with open(multiobj_path, 'r') as f:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # letting int('') raise ValueError.
            return [int(line.strip()) for line in f if line.strip()]

    def _load_metadata(self):
        """Load hard-positive and hard-negative phrase metadata.

        Returns:
            (hardpos_json, hardneg_json). ``hardneg_json`` is None when
            ``metric_mode`` contains "hardpos_only".
        """
        # The hard-positive file depends on the metric mode.
        if 'refined' in self.metric_mode or 'hardneg' in self.metric_mode:
            hardpos_path = os.path.join(self.ROOT, 'hardpos_verdict_gref_v4.json')
        else:
            hardpos_path = os.path.join(self.ROOT, 'hardpos_verbphrase_0906upd.json')
        # Hard-negative file; skipped entirely in "hardpos_only" modes.
        hardneg_path = os.path.join(self.ROOT, 'hardneg_verb.json')

        with open(hardpos_path, 'r', encoding='utf-8') as f:
            hardpos_json = json.load(f)
        if "hardpos_only" in self.metric_mode:
            hardneg_json = None
        else:
            with open(hardneg_path, 'r', encoding='utf-8') as q:
                hardneg_json = json.load(q)
        return hardpos_json, hardneg_json

    def _sample_sentence(self, candidates):
        """Randomly pick one usable phrase from a metadata entry and tokenize it.

        Args:
            candidates: mapping whose values are phrases (or None).

        Returns:
            ``(input_ids, attention_mask)`` from ``_tokenize``, or None when
            there is no non-None candidate or the picked phrase is empty.
        """
        phrases = [s for s in candidates.values() if s is not None]
        if not phrases:
            # Previously random.choice([]) raised IndexError here.
            return None
        picked = random.choice(phrases)
        if not picked:
            return None
        return self._tokenize(picked)

    def get_classes(self):
        return self.classes

    def __len__(self):
        return len(self.ref_ids)

    def __getitem__(self, index):
        """Return image, mask target and sentence tensors for ref ``index``.

        Eval mode returns ``(img, target, embeddings, attention_mask)``,
        where the sentence tensors stack every sentence on a trailing axis.
        Train mode additionally returns hard-positive and hard-negative
        sentence tensors, which stay all-zero when none is sampled.
        """
        this_ref_id = self.ref_ids[index]
        this_img_id = self.refer.getImgIds(this_ref_id)
        this_img = self.refer.Imgs[this_img_id[0]]

        img_path = os.path.join(self.IMAGE_DIR, this_img['file_name'])
        # Close the underlying file once the RGB copy has been made.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        ref = self.refer.loadRefs(this_ref_id)

        ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
        annot = np.zeros(ref_mask.shape)
        annot[ref_mask == 1] = 1

        annot = Image.fromarray(annot.astype(np.uint8), mode="P")

        if self.image_transforms is not None:
            # resize, from PIL to tensor, and mean and std normalization
            img, target = self.image_transforms(img, annot)
        else:
            # Without transforms, return the raw PIL mask rather than
            # failing on an undefined ``target``.
            target = annot

        if self.eval_mode:
            # Each stored tensor is (1, max_tokens). Stack all sentences of
            # this ref on a new trailing axis: (1, max_tokens, n_sentences).
            tensor_embeddings = torch.cat(
                [e.unsqueeze(-1) for e in self.input_ids[index]], dim=-1)
            attention_mask = torch.cat(
                [a.unsqueeze(-1) for a in self.attention_masks[index]], dim=-1)
            return img, target, tensor_embeddings, attention_mask

        # Train phase: sample one sentence for efficiency.
        choice_sent = np.random.choice(len(self.input_ids[index]))
        tensor_embeddings = self.input_ids[index][choice_sent]
        attention_mask = self.attention_masks[index][choice_sent]

        # All-zero placeholders mean "no hard positive/negative sampled".
        pos_sent = torch.zeros_like(tensor_embeddings)
        neg_sent = torch.zeros_like(tensor_embeddings)
        pos_attn_mask = torch.zeros_like(attention_mask)
        neg_attn_mask = torch.zeros_like(attention_mask)

        if self.metric_learning:
            ref_key = str(this_ref_id)

            picked = self._sample_sentence(self.hardpos_meta[ref_key])
            if picked is not None:
                pos_sent, pos_attn_mask = picked

            # Hard negatives are only drawn outside "hardpos_" modes and when
            # a non-zero probability is configured.
            use_hardneg = not ('hardpos_' in self.metric_mode or self.hardneg_prob == 0.0)
            if use_hardneg and random.random() < self.hardneg_prob:
                picked = self._sample_sentence(self.hardneg_meta[ref_key])
                if picked is not None:
                    neg_sent, neg_attn_mask = picked

        return img, target, tensor_embeddings, attention_mask, pos_sent, pos_attn_mask, neg_sent, neg_attn_mask