# Source: VRIS_vip / LAVT-RIS / data / dataset_refer_bert_mostat.py
# Uploaded by dianecy via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 8d82201 (verified), 4.9 kB.
import os
import sys
import torch.utils.data as data
import torch
from torchvision import transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
import random
from bert.tokenization_bert import BertTokenizer
import h5py
from refer.refer import REFER
from args import get_parser
# Dataset configuration initialization.
# NOTE(review): this parses the command line at import time, so merely
# importing this module calls get_parser().parse_args(). If get_parser returns
# an argparse.ArgumentParser (presumably; confirm in args.py), importing from a
# process whose argv contains unknown flags will exit. The module-level `args`
# is not used by ReferDataset, which receives its own `args` argument.
parser = get_parser()
args = parser.parse_args()
class ReferDataset(data.Dataset):
    """Referring-segmentation dataset for the refcocog motion/static ablation splits.

    The reference list is read from a JSON metadata file selected by
    ``args.split``; it is not taken from the REFER split annotations. Each
    metadata entry must provide:

    - ``'segment_id'``, used as the REFER ref id;
    - ``'sentences'``, a list of dicts holding a ``'sent'`` string.

    Every sentence is BERT-tokenized once, up front. Each one is
    truncated/zero-padded to ``max_tokens`` and stored as a
    ``(1, max_tokens)`` id tensor plus a matching 0/1 attention-mask tensor.
    """

    # Machine-specific locations from the original experiment setup. They are
    # class attributes so another environment can override them (e.g. in a
    # subclass) without editing the method bodies.
    IMAGE_DIR = '/data2/dataset/COCO2014/train2014/'
    ABLATION_META_FILES = {
        'motion': '/data2/projects/chaeyun/LAVT-RIS/test_ablation_motion.json',
        'static': '/data2/projects/chaeyun/LAVT-RIS/test_ablation_static.json',
    }

    def __init__(self,
                 args,
                 image_transforms=None,
                 target_transforms=None,
                 split='train',
                 eval_mode=False):
        """Load the ablation metadata and pre-tokenize every referring sentence.

        Args:
            args: configuration namespace. Reads ``dataset``, ``split``,
                ``refer_data_root``, ``splitBy`` and ``bert_tokenizer``.
            image_transforms: optional callable ``(img, annot) -> (img, target)``.
            target_transforms: stored as ``self.target_transform``; it is not
                applied anywhere in this class.
            split: stored as ``self.split``. NOTE: the metadata file is chosen
                by ``args.split``, not by this argument.
            eval_mode: if true, ``__getitem__`` returns every sentence of a
                reference; otherwise it returns one sentence chosen at random.

        Raises:
            ValueError: if ``args`` does not select ``refcocog`` with a
                ``'motion'`` or ``'static'`` split. No metadata exists for
                other configurations.
        """
        import json

        self.classes = []
        self.image_transforms = image_transforms
        self.target_transform = target_transforms
        self.split = split
        self.dataset = args.dataset
        self.args = args

        # Fail fast: without a metadata file there is no reference list at all.
        if args.dataset != 'refcocog' or args.split not in self.ABLATION_META_FILES:
            raise ValueError(
                "ReferDataset only supports dataset='refcocog' with split in "
                f"{sorted(self.ABLATION_META_FILES)}; got dataset={args.dataset!r}, "
                f"split={args.split!r}")

        print(f"Easy & Hard Example Experiments - dataset : {args.dataset}, split : {args.split}")
        meta_fp = self.ABLATION_META_FILES[args.split]
        with open(meta_fp, 'r') as f:
            ref_metas = json.load(f)

        self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)
        self.max_tokens = 20
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)
        self.eval_mode = eval_mode

        # Parallel per-reference lists (motion/static split binning):
        # input_ids[k] and attention_masks[k] hold one tensor per sentence of
        # the reference whose id is ref_ids[k].
        self.input_ids = []
        self.attention_masks = []
        self.ref_ids = []
        # segment_id -> metadata entry (the first occurrence wins).
        self.refer_ctmz = {}
        for ref in ref_metas:
            sentences_for_ref = []
            attentions_for_ref = []
            for sent in ref['sentences']:
                token_ids = self.tokenizer.encode(text=sent['sent'], add_special_tokens=True)
                padded_ids, mask = self._pad_and_mask(token_ids, self.max_tokens)
                sentences_for_ref.append(padded_ids)
                attentions_for_ref.append(mask)
            self.input_ids.append(sentences_for_ref)
            self.attention_masks.append(attentions_for_ref)
            self.ref_ids.append(ref['segment_id'])
            self.refer_ctmz.setdefault(ref['segment_id'], ref)

        img_ids = self.refer.getImgIds(self.ref_ids)
        all_imgs = self.refer.Imgs
        self.imgs = [all_imgs[i] for i in img_ids]

    @staticmethod
    def _pad_and_mask(token_ids, max_tokens):
        """Truncate/zero-pad ``token_ids`` to ``max_tokens``.

        Returns:
            ``(ids, mask)``: two ``(1, max_tokens)`` integer tensors. ``mask``
            is 1 over the real tokens and 0 over the padding.
        """
        kept = list(token_ids[:max_tokens])
        n_pad = max_tokens - len(kept)
        padded = kept + [0] * n_pad
        mask = [1] * len(kept) + [0] * n_pad
        return torch.tensor(padded).unsqueeze(0), torch.tensor(mask).unsqueeze(0)

    @staticmethod
    def _stack_sentences(tensors):
        """Concatenate per-sentence tensors along a new trailing axis.

        Each ``(1, max_tokens)`` tensor becomes a column of a
        ``(1, max_tokens, n_sentences)`` result.
        """
        return torch.cat([t.unsqueeze(-1) for t in tensors], dim=-1)

    def get_classes(self):
        """Return the (always empty) class list kept for interface compatibility."""
        return self.classes

    def __len__(self):
        """Return the number of references (not sentences)."""
        return len(self.ref_ids)

    def __getitem__(self, index):
        """Load the image and mask for reference ``index``, plus its sentence tensors.

        Returns:
            ``(img, target, tensor_embeddings, attention_mask)``.

            - Eval mode: the sentence tensors stack all sentences, with shape
              ``(1, max_tokens, n_sentences)``.
            - Otherwise: a single randomly chosen sentence, with shape
              ``(1, max_tokens)``.

            Without ``image_transforms``, ``img`` is the PIL RGB image and
            ``target`` is the PIL palette mask.
        """
        this_ref_id = self.ref_ids[index]
        this_img_id = self.refer.getImgIds(this_ref_id)
        this_img = self.refer.Imgs[this_img_id[0]]
        # convert() returns an independent image, so the file can be closed here.
        with Image.open(os.path.join(self.IMAGE_DIR, this_img['file_name'])) as raw_img:
            img = raw_img.convert("RGB")

        ref_orig = self.refer.loadRefs(this_ref_id)
        ref_mask = np.array(self.refer.getMask(ref_orig[0])['mask'])
        annot = np.zeros(ref_mask.shape)
        annot[ref_mask == 1] = 1
        annot = Image.fromarray(annot.astype(np.uint8), mode="P")

        # Default to the raw mask so a dataset built without transforms does
        # not fail on an unbound `target`.
        target = annot
        if self.image_transforms is not None:
            # resize, from PIL to tensor, and mean and std normalization
            img, target = self.image_transforms(img, annot)

        if self.eval_mode:
            tensor_embeddings = self._stack_sentences(self.input_ids[index])
            attention_mask = self._stack_sentences(self.attention_masks[index])
        else:
            choice_sent = np.random.choice(len(self.input_ids[index]))
            tensor_embeddings = self.input_ids[index][choice_sent]
            attention_mask = self.attention_masks[index][choice_sent]
        return img, target, tensor_embeddings, attention_mask