|
import os |
|
import sys |
|
import json |
|
import torch.utils.data as data |
|
import torch |
|
import itertools |
|
import numpy as np |
|
from PIL import Image |
|
import pdb |
|
import copy |
|
from random import choice |
|
from bert.tokenization_bert import BertTokenizer |
|
|
|
from refer.refer_zom import ZREFER |
|
import copy |
|
import random |
|
import torch |
|
from collections import defaultdict |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
from args import get_parser |
|
import random |
|
|
|
# Build the project's argument parser and parse the command line once, when
# this module is imported.  Both `parser` and `args` remain module-level names.
# NOTE(review): parsing at import time means importing this module consumes
# sys.argv -- confirm that every importer expects that.
parser = get_parser()
args = parser.parse_args()
|
|
|
|
|
class Referzom_Dataset(data.Dataset):
    """Sentence-level referring-segmentation dataset backed by ``ZREFER``.

    One item corresponds to one sentence id collected from the refs of the
    requested split (``__len__`` is the number of collected sentence ids).
    Sentences are tokenized once in ``__init__`` and padded/truncated to
    ``self.max_tokens`` tokens.
    """

    # Directory the image files are opened from in ``__getitem__``.
    IMAGE_DIR = '/data2/dataset/COCO2014/trainval2014/'

    def __init__(self,
                 args,
                 image_transforms=None,
                 target_transforms=None,
                 split='train',
                 eval_mode=False):
        """Load refs for ``split`` and pre-tokenize every sentence.

        Args:
            args: Namespace providing ``refer_data_root``, ``dataset``,
                ``splitBy``, ``bert_tokenizer``, ``metric_learning``,
                ``exclude_multiobj`` and ``metric_mode``.  ``hn_prob`` is
                read only when metric learning is on and ``eval_mode`` is
                off.  ``hp_selection`` is optional (see below).
            image_transforms: Optional callable applied as
                ``image_transforms(img, annot) -> (img, target)``.
            target_transforms: Stored as ``self.target_transform``; not used
                by the code in this class.
            split: Split name passed to ``ZREFER.getRefIds``.
            eval_mode: When true, hard-positive metadata is not loaded and
                ``__getitem__`` takes the evaluation path.
        """
        self.classes = []
        self.image_transforms = image_transforms
        self.target_transform = target_transforms
        self.split = split
        self.refer = ZREFER(args.refer_data_root, args.dataset, args.splitBy)
        self.dataset_type = args.dataset
        self.max_tokens = 20

        ref_ids = self.refer.getRefIds(split=self.split)
        self.img_ids = self.refer.getImgIds(ref_ids)

        all_imgs = self.refer.Imgs
        self.imgs = [all_imgs[i] for i in self.img_ids]
        self.ref_ids = ref_ids

        self.input_ids = []
        self.attention_masks = []
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

        self.ROOT = '/data2/dataset/RefCOCO/VRIS'
        self.metric_learning = args.metric_learning
        self.exclude_multiobj = args.exclude_multiobj
        self.metric_mode = args.metric_mode
        self.exclude_position = False
        # `_get_hardpos_verb` reads this attribute, so it must always exist.
        # Any value other than 'strict' (including the None fallback) pools
        # the phrases of all sentences of a ref.
        self.hp_selection = getattr(args, 'hp_selection', None)

        if self.metric_learning and not eval_mode:
            self.hardneg_prob = args.hn_prob
            self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
            self.hardpos_meta, self.hardneg_meta = self._load_metadata()
        else:
            self.hardneg_prob = 0.0
            self.multi_obj_ref_ids = None
            self.hardpos_meta, self.hardneg_meta = None, None

        self.eval_mode = eval_mode

        self.zero_sent_id_list = []
        self.one_sent_id_list = []
        self.all_sent_id_list = []
        self.sent_2_refid = {}

        # Pass 1: index every sentence id and split the ids by the ref's
        # 'source' field ('zero' vs. everything else).
        for r in ref_ids:
            ref = self.refer.loadRefs(r)
            # Refs without a 'source' key are treated as non-zero, matching
            # the 'not_zero' fallback used in __getitem__.
            source_type = ref[0].get('source', 'not_zero')

            for sent_dict in ref[0]['sentences']:
                sent_id = sent_dict['sent_id']
                self.sent_2_refid[sent_id] = r
                self.all_sent_id_list.append(sent_id)
                if source_type == 'zero':
                    self.zero_sent_id_list.append(sent_id)
                else:
                    self.one_sent_id_list.append(sent_id)

        # Pass 2: tokenize every raw sentence.  The lists are flat (one
        # (1, max_tokens) tensor per sentence) and are filled in the same
        # ref/sentence order as `all_sent_id_list` above.
        for r in ref_ids:
            ref = self.refer.Refs[r]
            for el in ref['sentences']:
                token_ids, token_mask = self._tokenize(el['raw'])
                self.input_ids.append(token_ids)
                self.attention_masks.append(token_mask)

    def get_classes(self):
        """Return ``self.classes`` (initialised as an empty list)."""
        return self.classes

    def _tokenize(self, sentence):
        """Encode ``sentence`` to fixed-length id and mask tensors.

        The token ids are truncated to ``self.max_tokens`` and right-padded
        with zeros; the mask holds 1 for real tokens and 0 for padding.

        Returns:
            ``(input_ids, attention_mask)``, each of shape
            ``(1, self.max_tokens)``.
        """
        input_ids = self.tokenizer.encode(text=sentence, add_special_tokens=True)
        input_ids = input_ids[:self.max_tokens]

        n_pad = self.max_tokens - len(input_ids)
        padded_input_ids = list(input_ids) + [0] * n_pad
        attention_mask = [1] * len(input_ids) + [0] * n_pad

        return (torch.tensor(padded_input_ids).unsqueeze(0),
                torch.tensor(attention_mask).unsqueeze(0))

    def _load_multi_obj_ref_ids(self):
        """Read the list of ref ids to exclude from hard-positive sampling.

        Returns:
            ``None`` when neither ``exclude_position`` nor
            ``exclude_multiobj`` is set, otherwise a list of ints read from
            the matching text file under ``self.ROOT`` (one id per line).
        """
        if self.exclude_position:
            filename = 'multiobj_ov2_nopos.txt'
        elif self.exclude_multiobj:
            filename = 'multiobj_ov3.txt'
        else:
            return None

        multiobj_path = os.path.join(self.ROOT, filename)
        with open(multiobj_path, 'r') as f:
            # Skip blank lines (e.g. a trailing newline) instead of failing
            # on int('').
            return [int(line) for line in f if line.strip()]

    def _load_metadata(self):
        """Load the hard-positive metadata JSON from ``self.ROOT``.

        Returns:
            ``(hardpos_json, hardneg_json)``.  No hard-negative metadata is
            loaded by this method, so the second element is always ``None``.
        """
        hardpos_path = os.path.join(self.ROOT, 'verb_ext_text_example_refzom.json')
        with open(hardpos_path, 'r', encoding='utf-8') as f:
            hardpos_json = json.load(f)

        # Previously this was bound only when "hardpos_only" was part of
        # `metric_mode`, so every other mode failed on the return statement.
        hardneg_json = None
        return hardpos_json, hardneg_json

    def _get_hardpos_verb(self, ref, seg_id, sent_idx):
        """Pick a random hard-positive phrase for ref ``seg_id``.

        Args:
            ref: Unused here; kept for interface compatibility.
            seg_id: Ref id; looked up in the metadata as ``str(seg_id)``.
            sent_idx: Position into the ref's metadata keys, used only when
                ``self.hp_selection == 'strict'``.

        Returns:
            A phrase drawn with ``random.choice``, or ``''`` when the ref is
            in the exclusion list, has no metadata, or has no phrases.
        """
        # The exclusion list is None when no exclusion file was loaded.
        excluded = self.multi_obj_ref_ids
        if excluded is not None and seg_id in excluded:
            return ''

        hardpos_dict = self.hardpos_meta.get(str(seg_id), {})

        if getattr(self, 'hp_selection', None) == 'strict':
            sent_id_list = list(hardpos_dict.keys())
            try:
                picked_key = sent_id_list[sent_idx]
            except IndexError:
                # No metadata entry at that position for this ref.
                return ''
            cur_hardpos = hardpos_dict[picked_key].get('phrases', [])
        else:
            cur_hardpos = list(itertools.chain.from_iterable(
                hardpos_dict[sid].get('phrases', []) for sid in hardpos_dict))

        if cur_hardpos:
            return random.choice(cur_hardpos)
        return ''

    def __len__(self):
        """Number of sentences collected for the split."""
        return len(self.all_sent_id_list)

    def __getitem__(self, index):
        """Return the sample for the ``index``-th sentence.

        Returns:
            Evaluation mode, or training without metric learning:
            ``(img, target, source_type, tensor_embeddings, attention_mask)``.
            Training with metric learning additionally appends
            ``(pos_sent, pos_attn_mask, pos_type)`` where ``pos_type`` is
            ``'hardpos'`` when a hard-positive phrase was tokenized and
            ``'zero'`` (all-zero tensors) otherwise.
        """
        sent_id = self.all_sent_id_list[index]
        this_ref_id = self.sent_2_refid[sent_id]

        this_img_id = self.refer.getImgIds(this_ref_id)
        this_img = self.refer.Imgs[this_img_id[0]]

        img = Image.open(os.path.join(self.IMAGE_DIR, this_img['file_name'])).convert("RGB")

        ref = self.refer.loadRefs(this_ref_id)
        if self.dataset_type == 'ref-zom':
            source_type = ref[0]['source']
        else:
            source_type = 'not_zero'

        # Binarize the mask: pixels equal to 1 stay 1, everything else is 0.
        ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
        annot = np.zeros(ref_mask.shape)
        annot[ref_mask == 1] = 1
        annot = Image.fromarray(annot.astype(np.uint8), mode="P")

        if self.image_transforms is not None:
            img, target = self.image_transforms(img, annot)
        else:
            # Without transforms the raw mask image is the target; `target`
            # used to be left unbound on this path.
            target = annot

        if self.eval_mode:
            # Each stored entry is (1, max_tokens); its rows become columns
            # of a (max_tokens, n_rows) tensor.
            embedding = [row.unsqueeze(-1) for row in self.input_ids[index]]
            att = [row.unsqueeze(-1) for row in self.attention_masks[index]]

            tensor_embeddings = torch.cat(embedding, dim=-1)
            attention_mask = torch.cat(att, dim=-1)
            return img, target, source_type, tensor_embeddings, attention_mask

        # Entries built in __init__ have a leading dimension of 1, so this
        # draw yields 0 for them.
        # NOTE(review): with hp_selection == 'strict' that means the first
        # metadata key of the ref is always used -- confirm this is intended.
        choice_sent = np.random.choice(len(self.input_ids[index]))
        tensor_embeddings = self.input_ids[index][choice_sent]
        attention_mask = self.attention_masks[index][choice_sent]

        if not self.metric_learning:
            return img, target, source_type, tensor_embeddings, attention_mask

        # Defaults for "no hard positive available"; `pos_type` must be
        # bound even when the branch below is skipped.
        pos_sent = torch.zeros_like(tensor_embeddings)
        pos_attn_mask = torch.zeros_like(attention_mask)
        pos_type = 'zero'

        if 'hardpos_' in self.metric_mode or self.hardneg_prob == 0.0:
            if 'refined' in self.metric_mode:
                pos_sent_picked = self._get_hardpos_verb(ref, this_ref_id, choice_sent)
            else:
                ref_entries = self.hardpos_meta.get(str(this_ref_id), {})
                candidates = [s for s in ref_entries.values() if s is not None]
                # An empty pool means "no hard positive", not a crash.
                pos_sent_picked = random.choice(candidates) if candidates else ''

            if pos_sent_picked:
                pos_type = 'hardpos'
                pos_sent, pos_attn_mask = self._tokenize(pos_sent_picked)
                # _tokenize returns (1, max_tokens); drop the leading axis.
                pos_sent = pos_sent.squeeze(0)
                pos_attn_mask = pos_attn_mask.squeeze(0)

        return (img, target, source_type, tensor_embeddings, attention_mask,
                pos_sent, pos_attn_mask, pos_type)
|
|
|
|
|
|
|
|
|
class Refzom_DistributedSampler(DistributedSampler):
    """Distributed sampler that spreads 'zero' sentences evenly among the rest.

    The dataset must expose ``one_sent_id_list``, ``zero_sent_id_list`` and
    ``all_sent_id_list``.  The sentence ids are interleaved with
    ``insert_evenly``, mapped to their positions in ``all_sent_id_list`` and
    then strided across ranks (``indices[rank::num_replicas]``).
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        # Work on copies so the dataset's own lists are never reordered.
        self.one_id_list = list(dataset.one_sent_id_list)
        self.zero_id_list = list(dataset.zero_sent_id_list)
        self.sent_ids_list = dataset.all_sent_id_list

        # Kept as attributes for backward compatibility; refreshed on every
        # call to __iter__.
        self.sent_id, self.indices = self._build_order(self.epoch)

    def _build_order(self, epoch):
        """Return ``(sentence ids, dataset indices)`` for ``epoch``.

        The shuffle is seeded with ``seed + epoch`` so every rank builds the
        same permutation; the strided split in ``__iter__`` relies on that
        to give each rank a different slice of one shared order.
        """
        one_ids = list(self.one_id_list)
        zero_ids = list(self.zero_id_list)
        if self.shuffle:
            rng = random.Random(getattr(self, 'seed', 0) + epoch)
            rng.shuffle(one_ids)
            rng.shuffle(zero_ids)

        sent_ids = self.insert_evenly(zero_ids, one_ids)
        return sent_ids, self.get_positions(self.sent_ids_list, sent_ids)

    def get_positions(self, list_a, list_b):
        """Return, for each item of ``list_b``, its index in ``list_a``.

        Raises:
            KeyError: if an item of ``list_b`` does not occur in ``list_a``.
        """
        position_dict = {value: index for index, value in enumerate(list_a)}
        return [position_dict[item] for item in list_b]

    def insert_evenly(self, list_a, list_b):
        """Insert each item of ``list_a`` after a fixed-size block of ``list_b``.

        The block size is ``len(list_b) // len(list_a)``; items of ``list_b``
        left over after the last block are appended at the end.  With an
        empty ``list_a`` the result is simply a copy of ``list_b``.
        """
        len_a = len(list_a)
        if len_a == 0:
            # Nothing to interleave; also avoids dividing by zero below.
            return list(list_b)

        block_size = len(list_b) // len_a

        result = []
        for i, item in enumerate(list_a):
            result.extend(list_b[i * block_size:(i + 1) * block_size])
            result.append(item)

        result.extend(list_b[len_a * block_size:])
        return result

    def __iter__(self):
        """Yield this rank's dataset indices for the current epoch."""
        # Rebuild per epoch so that set_epoch() changes the order.
        self.sent_id, self.indices = self._build_order(self.epoch)
        indices = list(self.indices)

        # Pad by repeating from the front so every rank yields the same
        # number of samples as the inherited __len__ reports.
        shortfall = self.total_size - len(indices)
        if indices and shortfall > 0:
            repeats = -(-shortfall // len(indices))
            indices += (indices * repeats)[:shortfall]

        return iter(indices[self.rank::self.num_replicas])