import os
import sys
import json
import itertools
import random

import numpy as np
import torch
import torch.utils.data as data
from torch.autograd import Variable
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
import h5py

from bert.tokenization_bert import BertTokenizer
from refer.refer import REFER
from args import get_parser


class ReferDataset(data.Dataset):
    """Referring-expression segmentation dataset built on the REFER API.

    Each item corresponds to one referred object (one ``ref_id`` of ``split``).
    In eval mode an item carries every tokenized expression of the object; in
    train mode one expression is sampled at random and, when
    ``args.metric_learning`` is set, it is returned together with a
    hard-positive phrase (and, in hard-negative modes, a hard-negative phrase)
    tokenized the same way.
    """

    # Directory holding the metric-learning metadata (multi-object id lists,
    # hard-positive / hard-negative JSON files).
    ROOT = '/data2/projects/seunghoon/VerbRIS/VerbCentric_CY/datasets/VRIS'
    # Directory the images listed in ``self.refer.Imgs`` are read from.
    IMAGE_DIR = '/data2/dataset/COCO2014/trainval2014/'

    def __init__(self, args, image_transforms=None, target_transforms=None,
                 split='train', eval_mode=False):
        """Load REFER annotations for ``split`` and pre-tokenize every expression.

        Args:
            args: config namespace; reads refer_data_root, dataset, splitBy,
                bert_tokenizer, metric_learning, exclude_multiobj, metric_mode,
                hp_selection and (when metric learning in train mode) hn_prob.
            image_transforms: callable ``(img, mask) -> (img, target)`` applied
                jointly to the image and its mask, or None.
            target_transforms: stored as ``self.target_transform``; not used here.
            split: REFER split name passed to ``getRefIds``.
            eval_mode: if True, ``__getitem__`` returns all expressions of a ref.
        """
        self.classes = []
        self.image_transforms = image_transforms
        self.target_transform = target_transforms
        self.split = split
        self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)

        # Every expression is padded / truncated to this many BERT tokens.
        self.max_tokens = 20

        ref_ids = self.refer.getRefIds(split=self.split)
        img_ids = self.refer.getImgIds(ref_ids)

        all_imgs = self.refer.Imgs
        self.imgs = [all_imgs[i] for i in img_ids]
        self.ref_ids = ref_ids

        self.input_ids = []
        self.attention_masks = []
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

        # Metric-learning configuration.
        self.metric_learning = args.metric_learning
        self.exclude_multiobj = args.exclude_multiobj
        self.metric_mode = args.metric_mode
        self.exclude_position = False
        self.hp_selection = args.hp_selection

        if self.metric_learning and not eval_mode:
            self.hardneg_prob = args.hn_prob
            self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
            self.hardpos_meta, self.hardneg_meta = self._load_metadata()
        else:
            self.hardneg_prob = 0.0
            self.multi_obj_ref_ids = None
            self.hardpos_meta, self.hardneg_meta = None, None

        self.eval_mode = eval_mode

        # Pre-tokenize every expression of every ref; self.input_ids[i] and
        # self.attention_masks[i] are lists of (1, max_tokens) tensors for ref i.
        for r in ref_ids:
            ref = self.refer.Refs[r]
            sentences_for_ref = []
            attentions_for_ref = []
            # zip() keeps the original pairing with sent_ids (it stops at the
            # shorter of the two lists).
            for sentence, _sent_id in zip(ref['sentences'], ref['sent_ids']):
                token_ids, token_mask = self._tokenize(sentence['raw'])
                sentences_for_ref.append(token_ids)
                attentions_for_ref.append(token_mask)

            self.input_ids.append(sentences_for_ref)
            self.attention_masks.append(attentions_for_ref)

    def _tokenize(self, sentence):
        """Tokenize ``sentence`` with special tokens, truncate/pad to ``max_tokens``.

        Returns:
            (input_ids, attention_mask), each a tensor of shape (1, max_tokens);
            padding positions hold 0 in both.
        """
        attention_mask = [0] * self.max_tokens
        padded_input_ids = [0] * self.max_tokens

        input_ids = self.tokenizer.encode(text=sentence, add_special_tokens=True)

        # Truncation of tokens.
        input_ids = input_ids[:self.max_tokens]

        padded_input_ids[:len(input_ids)] = input_ids
        attention_mask[:len(input_ids)] = [1] * len(input_ids)

        # Match shape as (1, max_tokens).
        return torch.tensor(padded_input_ids).unsqueeze(0), torch.tensor(attention_mask).unsqueeze(0)

    def _load_multi_obj_ref_ids(self):
        """Return the set of multi-object ref ids to exclude, or None when no exclusion is configured.

        ``exclude_position`` takes precedence over ``exclude_multiobj``. Blank
        lines in the id file are ignored.
        """
        if self.exclude_position:
            filename = 'multiobj_ov2_nopos.txt'
        elif self.exclude_multiobj:
            filename = 'multiobj_ov3.txt'
        else:
            return None

        with open(os.path.join(self.ROOT, filename), 'r') as f:
            return {int(line) for line in f if line.strip()}

    def _load_metadata(self):
        """Load hard-positive and hard-negative metadata JSON.

        The hard-positive file depends on ``metric_mode`` ('refined' / 'hardneg'
        select the verdict file). The hard-negative JSON is None when
        ``metric_mode`` contains 'hardpos_only'.

        Returns:
            (hardpos_json, hardneg_json)
        """
        if 'refined' in self.metric_mode or 'hardneg' in self.metric_mode:
            hardpos_path = os.path.join(self.ROOT, 'hardpos_verdict_gref_v4.json')
        else:
            hardpos_path = os.path.join(self.ROOT, 'hardpos_verbphrase_0906upd.json')
        # NOTE(review): an earlier comment said "do not use hardneg_path", yet it
        # is loaded whenever metric_mode does not contain 'hardpos_only'.
        hardneg_path = os.path.join(self.ROOT, 'hardneg_verb.json')

        with open(hardpos_path, 'r', encoding='utf-8') as f:
            hardpos_json = json.load(f)

        if "hardpos_only" in self.metric_mode:
            hardneg_json = None
        else:
            with open(hardneg_path, 'r', encoding='utf-8') as q:
                hardneg_json = json.load(q)
        return hardpos_json, hardneg_json

    def _get_hardpos_verb(self, ref, seg_id, sent_idx):
        """Pick a random hard-positive phrase for ``seg_id`` from the refined metadata.

        Args:
            ref: REFER ref list for this item (currently unused).
            seg_id: ref id; looked up as ``str(seg_id)`` in ``hardpos_meta``.
            sent_idx: index of the sampled expression; in 'strict' mode only the
                phrases of the metadata entry at that position are eligible.

        Returns:
            A phrase string, or '' when the ref is excluded as multi-object or no
            phrase is available.
        """
        if self.multi_obj_ref_ids and seg_id in self.multi_obj_ref_ids:
            return ''

        # Values may be missing or null in the JSON; treat both as "no phrases".
        hardpos_dict = (self.hardpos_meta or {}).get(str(seg_id)) or {}
        if self.hp_selection == 'strict':
            sent_keys = list(hardpos_dict)
            entry = hardpos_dict.get(sent_keys[sent_idx]) if sent_idx < len(sent_keys) else None
            cur_hardpos = (entry or {}).get('phrases') or []
        else:
            cur_hardpos = list(itertools.chain.from_iterable(
                (entry or {}).get('phrases') or [] for entry in hardpos_dict.values()
            ))

        return random.choice(cur_hardpos) if cur_hardpos else ''

    @staticmethod
    def _pick_meta_sentence(meta, ref_id):
        """Return a random non-None value of ``meta[str(ref_id)]``, or '' if there is none."""
        if not meta:
            return ''
        candidates = [s for s in (meta.get(str(ref_id)) or {}).values() if s is not None]
        return random.choice(candidates) if candidates else ''

    def _load_image_and_target(self, ref_id):
        """Load the RGB image and binary mask for ``ref_id`` and apply ``image_transforms``.

        Returns:
            (img, target, ref) where ``ref`` is ``self.refer.loadRefs(ref_id)``.
            Without ``image_transforms`` the PIL image and the palette-mode PIL
            mask (1 on the object, 0 elsewhere) are returned untransformed.
        """
        img_id = self.refer.getImgIds(ref_id)[0]
        img_info = self.refer.Imgs[img_id]
        with Image.open(os.path.join(self.IMAGE_DIR, img_info['file_name'])) as raw_img:
            img = raw_img.convert("RGB")

        ref = self.refer.loadRefs(ref_id)
        ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
        annot = np.zeros(ref_mask.shape)
        annot[ref_mask == 1] = 1
        annot = Image.fromarray(annot.astype(np.uint8), mode="P")

        if self.image_transforms is None:
            return img, annot, ref
        # Resize, PIL -> tensor, mean/std normalization (applied jointly to the mask).
        img, target = self.image_transforms(img, annot)
        return img, target, ref

    def get_classes(self):
        return self.classes

    def __len__(self):
        return len(self.ref_ids)

    def __getitem__(self, index):
        """Return one sample for ref ``self.ref_ids[index]``.

        Eval mode: ``(img, target, embeddings, attention_mask)`` where the last two
        stack every expression of the ref along a trailing axis.
        Train mode: one randomly sampled expression, i.e.
        ``(img, target, embeddings, attention_mask)``; with metric learning this
        is followed by ``(pos_sent, pos_attn_mask)`` and, in hard-negative modes,
        ``(neg_sent, neg_attn_mask)``. Missing phrases are all-zero tensors.
        """
        this_ref_id = self.ref_ids[index]
        img, target, ref = self._load_image_and_target(this_ref_id)

        if self.eval_mode:
            # Testing: evaluate all sentences of the object.
            embedding = [e.unsqueeze(-1) for e in self.input_ids[index]]
            att = [a.unsqueeze(-1) for a in self.attention_masks[index]]
            tensor_embeddings = torch.cat(embedding, dim=-1)
            attention_mask = torch.cat(att, dim=-1)
            return img, target, tensor_embeddings, attention_mask

        # Training: randomly sample one sentence for efficiency.
        choice_sent = np.random.choice(len(self.input_ids[index]))
        tensor_embeddings = self.input_ids[index][choice_sent]
        attention_mask = self.attention_masks[index][choice_sent]

        if not self.metric_learning:
            return img, target, tensor_embeddings, attention_mask

        pos_sent = torch.zeros_like(tensor_embeddings)
        pos_attn_mask = torch.zeros_like(attention_mask)

        if 'hardpos_' in self.metric_mode or self.hardneg_prob == 0.0:
            if 'refined' in self.metric_mode:
                pos_sent_picked = self._get_hardpos_verb(ref, this_ref_id, choice_sent)
            else:
                pos_sent_picked = self._pick_meta_sentence(self.hardpos_meta, this_ref_id)
            if pos_sent_picked:
                pos_sent, pos_attn_mask = self._tokenize(pos_sent_picked)
            return img, target, tensor_embeddings, attention_mask, pos_sent, pos_attn_mask

        neg_sent = torch.zeros_like(tensor_embeddings)
        neg_attn_mask = torch.zeros_like(attention_mask)

        # NOTE(review): with 'hardneg' in metric_mode the hard-positive file is the
        # verdict JSON, which _get_hardpos_verb reads as nested {'phrases': [...]}
        # entries; here its values are tokenized directly — confirm the schema.
        pos_sent_picked = self._pick_meta_sentence(self.hardpos_meta, this_ref_id)
        if pos_sent_picked:
            pos_sent, pos_attn_mask = self._tokenize(pos_sent_picked)

        if random.random() < self.hardneg_prob:
            neg_sent_picked = self._pick_meta_sentence(self.hardneg_meta, this_ref_id)
            if neg_sent_picked:
                neg_sent, neg_attn_mask = self._tokenize(neg_sent_picked)

        return (img, target, tensor_embeddings, attention_mask,
                pos_sent, pos_attn_mask, neg_sent, neg_attn_mask)