import os import sys import torch.utils.data as data import torch from torchvision import transforms from torch.autograd import Variable import numpy as np from PIL import Image import torchvision.transforms.functional as TF import random from bert.tokenization_bert import BertTokenizer import h5py from refer.refer import REFER from args import get_parser # Dataset configuration initialization parser = get_parser() args = parser.parse_args() class ReferDataset(data.Dataset): def __init__(self, args, image_transforms=None, target_transforms=None, split='train', eval_mode=False): self.classes = [] self.image_transforms = image_transforms self.target_transform = target_transforms self.split = split self.dataset = args.dataset self.args = args if args.dataset == 'refcocog' and args.split in ['testA', 'testB']: print(f"Easy & Hard Example Experiments - dataset : {args.dataset}, split : {args.split}") from refer.refer_test import REFER self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy) else : from refer.refer import REFER self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy) self.max_tokens = 20 ref_ids = self.refer.getRefIds(split=self.split) img_ids = self.refer.getImgIds(ref_ids) # sent_len experiments import copy new_ref_ids = [] new_refs = {} for ref_id in ref_ids: ref = copy.deepcopy(self.refer.Refs[ref_id]) new_sents = [] for sent in ref['sentences']: words = sent['raw'].split(' ') if len(words)>= args.sent_len.min and len(words) <= args.sent_len.max: # 여기는 수정 new_sents.append(sent) if len(new_sents)>0: ref['sentences'] = new_sents new_refs[ref_id] = ref new_ref_ids.append(ref_id) ref_ids = new_ref_ids self.refer.Refs = new_refs all_imgs = self.refer.Imgs self.imgs = list(all_imgs[i] for i in img_ids) self.ref_ids = ref_ids self.input_ids = [] self.attention_masks = [] self.sent_lens = [] self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer) self.eval_mode = eval_mode # if we are testing on a dataset, test all sentences of an object; # o/w, we are validating during training, randomly sample one sentence for efficiency for r in ref_ids: ref = self.refer.Refs[r] sentences_for_ref = [] attentions_for_ref = [] sent_lens_for_ref = [] for i, (el, sent_id) in enumerate(zip(ref['sentences'], ref['sent_ids'])): sentence_raw = el['raw'] # sent_len = len(sentence_raw.split()) # if sent_len <= self.args.seq_len.min or sent_len >= self.args.seq_len.max : # continue attention_mask = [0] * self.max_tokens padded_input_ids = [0] * self.max_tokens input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True) # truncation of tokens input_ids = input_ids[:self.max_tokens] padded_input_ids[:len(input_ids)] = input_ids attention_mask[:len(input_ids)] = [1]*len(input_ids) sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0)) attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0)) # sent_lens_for_ref.append(torch.tensor(sent_len)) self.input_ids.append(sentences_for_ref) self.attention_masks.append(attentions_for_ref) self.sent_lens.append(sent_lens_for_ref) def get_classes(self): return self.classes def __len__(self): return len(self.ref_ids) def __getitem__(self, index): this_ref_id = self.ref_ids[index] this_img_id = self.refer.getImgIds(this_ref_id) this_img = self.refer.Imgs[this_img_id[0]] img = Image.open(os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])).convert("RGB") ref = self.refer.loadRefs(this_ref_id) ref_mask = np.array(self.refer.getMask(ref[0])['mask']) annot = np.zeros(ref_mask.shape) annot[ref_mask == 1] = 1 annot = Image.fromarray(annot.astype(np.uint8), mode="P") if self.image_transforms is not None: # resize, from PIL to tensor, and mean and std normalization img, target = self.image_transforms(img, annot) if self.eval_mode: embedding = [] att = [] # sent_lens = [] for s in range(len(self.input_ids[index])): e = self.input_ids[index][s] a = self.attention_masks[index][s] # ss = self.sent_lens[index][s] embedding.append(e.unsqueeze(-1)) att.append(a.unsqueeze(-1)) # sent_lens.append(ss) tensor_embeddings = torch.cat(embedding, dim=-1) attention_mask = torch.cat(att, dim=-1) # sent_lens = torch.tensor(sent_lens) else: choice_sent = np.random.choice(len(self.input_ids[index])) tensor_embeddings = self.input_ids[index][choice_sent] attention_mask = self.attention_masks[index][choice_sent] # sent_lens = self.sent_lens[index][choice_sent] return img, target, tensor_embeddings, attention_mask #, sent_lens