"""Referring-expression dataset with optional 2x2 mosaic augmentation.

Defines :class:`ReferDataset`, which reads REFER annotations, tokenises every
sentence with BERT and, for the training split, can stitch four images into a
single mosaic around a crossing point.
"""
import os
import sys
import cv2
import math
import glob
import json
import random
import pickle
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFilter
from bert.tokenization_bert import BertTokenizer

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch, gc
import torch.utils.data as data

import lmdb
import pyarrow as pa

import warnings
from .utils import get_warmup_value

warnings.simplefilter(action='ignore', category=FutureWarning)
gc.collect()
torch.cuda.empty_cache()


def loads_pyarrow(buf):
    """Deserialize a pyarrow-serialized byte buffer read from LMDB.

    NOTE(review): ``pa.deserialize`` is deprecated in recent pyarrow releases;
    confirm the pinned pyarrow version still ships it before upgrading.
    """
    return pa.deserialize(buf)


class ReferDataset(data.Dataset):
    """Dataset yielding (image, segmentation target, tokenised sentence) items.

    Each index corresponds to one referring-expression id of ``split``.  For
    the training split with ``args.aug.num_bgs == 4`` an item may be a 2x2
    mosaic built from the indexed reference plus three companions, which are
    either random references ("random") or images listed in an LMDB lookup
    ("retrieval").

    Args:
        args: configuration object.  Attributes read here: ``aug``,
            ``img_size``, ``dataset``, ``split``, ``splitBy``,
            ``refer_data_root``.
        split: REFER split name passed to ``getRefIds``.
        eval_mode: when true, ``__getitem__`` returns every sentence of the
            reference stacked on a trailing axis instead of a single one.

    Raises:
        ValueError: if no LMDB file is configured for ``args.dataset`` /
            ``args.splitBy``.
    """

    def __init__(self, args, split='train', eval_mode=False):
        self.classes = []
        self.args = args
        self.split = split
        self.aug = args.aug
        self.img_sz = args.img_size

        each_img_sz = int(args.img_size / math.sqrt(self.aug.num_bgs))
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        self.resize_bg1 = A.Compose([
            A.Resize(args.img_size, args.img_size, always_apply=True)])
        # Not used by __getitem__ below; kept because it is a public attribute.
        self.resize_bg4 = A.Compose(
            [A.Resize(each_img_sz, each_img_sz, always_apply=True)],
            additional_targets={'image1': 'image', 'image2': 'image', 'image3': 'image',
                                'mask1': 'mask', 'mask2': 'mask', 'mask3': 'mask'})
        self.transforms = A.Compose([
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ])

        # Load annotations.
        # NOTE(review): this test reads ``args.split`` while the rest of the
        # class uses the ``split`` argument - confirm they are meant to differ.
        if args.dataset == 'refcocog' and args.split in ['testA', 'testB']:
            print(f"Easy & Hard Example Experiments - dataset : {args.dataset}, split : {args.split}")
            from refer.refer_test import REFER
        else:
            from refer.refer import REFER
        self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)

        ref_ids = self.refer.getRefIds(split=self.split)
        img_ids = self.refer.getImgIds(ref_ids)
        all_imgs = self.refer.Imgs
        self.imgs = [all_imgs[i] for i in img_ids]
        self.ref_ids = ref_ids
        self.ref_id2idx = dict(zip(ref_ids, range(len(ref_ids))))
        self.ref_idx2id = dict(zip(range(len(ref_ids)), ref_ids))

        # Tokenizer setting: only the BERT tokenizer is wired up.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.max_tokens = 20
        self.eval_mode = eval_mode

        # input_ids[i][j] / attention_masks[i][j]: zero-padded token ids and
        # 0/1 mask (length max_tokens) of sentence j of reference index i.
        self.input_ids = []
        self.attention_masks = []
        for r in ref_ids:
            ref = self.refer.Refs[r]
            sentences_for_ref = []
            attentions_for_ref = []
            for el, _sent_id in zip(ref['sentences'], ref['sent_ids']):
                input_ids = self.tokenizer.encode(
                    text=el['raw'], add_special_tokens=True,
                    max_length=self.max_tokens, truncation=True)
                padded_input_ids = [0] * self.max_tokens
                padded_input_ids[:len(input_ids)] = input_ids
                attention_mask = [0] * self.max_tokens
                attention_mask[:len(input_ids)] = [1] * len(input_ids)
                sentences_for_ref.append(padded_input_ids)
                attentions_for_ref.append(attention_mask)
            self.input_ids.append(sentences_for_ref)
            self.attention_masks.append(attentions_for_ref)

        if self.aug.blur:
            self.blur = ImageFilter.GaussianBlur(100)

        # Locate the LMDB file used by retrieval-based mosaics.
        if args.dataset == 'refcoco':
            lmdb_path = '/data2/dataset/RefCOCO/logit_db/refcoco/refcoco_logit.lmdb'
        elif args.dataset == 'refcoco+':
            lmdb_path = '/data2/dataset/RefCOCO/logit_db/refcoco+/refcocop_logit.lmdb'
        elif args.dataset == 'refcocog' and args.splitBy == 'umd':
            lmdb_path = '/data2/dataset/RefCOCO/logit_db/refcocog_u/refcocog_u_logit.lmdb'
        else:
            # Previously this fell through and failed later with an unbound
            # ``lmdb_path``; fail here with an explicit message instead.
            raise ValueError(
                f"No retrieval LMDB configured for dataset={args.dataset!r}, "
                f"splitBy={args.splitBy!r}")

        self.lmdb_env = lmdb.open(
            lmdb_path, subdir=False, max_readers=32, readonly=True,
            lock=False, readahead=False, meminit=False)
        with self.lmdb_env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

        # Compared against aug.retrieval_epoch in __getitem__; nothing in this
        # class advances it (presumably the training loop does - confirm).
        self.epoch = 0
        np.random.seed()

    def get_classes(self):
        """Return the (empty by construction) class list."""
        return self.classes

    def __len__(self):
        """Return the number of reference ids in this split."""
        return len(self.ref_ids)

    def _choose_layout(self):
        """Pick the sample layout; return ``(choice, num_bgs)``.

        ``choice`` is 'one', 'random' or 'retrieval'.  Anything other than the
        training split with ``aug.num_bgs == 4`` always yields ``('one', 1)``.
        """
        if self.split != 'train' or self.aug.num_bgs != 4:
            return 'one', 1
        aug_prob = self.aug.aug_prob
        if self.epoch < self.aug.retrieval_epoch:
            # Before retrieval_epoch: single image or random mosaic.
            choice = np.random.choice(['one', 'random'], p=[1 - aug_prob, aug_prob])
        else:
            # Afterwards: single image, random mosaic or retrieval mosaic.
            rand_prob = self.aug.rand_prob
            retr_prob = self.aug.retr_prob
            choice = np.random.choice(
                ['one', 'random', 'retrieval'],
                p=[1 - (rand_prob + retr_prob), rand_prob, retr_prob])
        choice = str(choice)
        return choice, (1 if choice == 'one' else 4)

    def _load_retrieval_candidates(self, ref_id):
        """Read and deserialize the LMDB record stored under ``keys[ref_id]``.

        The record is used as a mapping whose values are sliced and sampled as
        image ids (presumably ranked candidates per sentence - confirm against
        the code that wrote the database).
        """
        with self.lmdb_env.begin(write=False) as txn:
            return loads_pyarrow(txn.get(self.keys[ref_id]))

    def _load_image(self, img_id):
        """Load image ``img_id`` from ``refer.IMAGE_DIR`` as an RGB array."""
        img_info = self.refer.Imgs[img_id]
        img_path = os.path.join(self.refer.IMAGE_DIR, img_info['file_name'])
        with Image.open(img_path) as pil_img:
            return np.array(pil_img.convert("RGB"))

    def _select_target(self, choice, ref_ids, sents, sent_idxs, insert_idx, sent_idx):
        """Decide which cell supplies the mask; return ``(cell, sentence index)``.

        In retrieval mosaics only the inserted cell has a reference whose
        sentences were tokenised, so the target is pinned to it.

        Raises:
            ValueError: for an unknown ``aug.tgt_selection`` value.
        """
        mode = self.aug.tgt_selection
        if mode not in ('random', 'longest', 'fixed'):
            raise ValueError(f"Unknown aug.tgt_selection: {mode!r}")
        if mode == 'fixed' or choice == 'retrieval':
            return insert_idx, sent_idx
        if mode == 'random':
            target_idx = int(np.random.choice(range(len(ref_ids))))
            ref_idx = self.ref_id2idx[ref_ids[target_idx]]
            return target_idx, int(np.random.choice(len(self.input_ids[ref_idx])))
        # 'longest': the cell whose raw sentence has the most characters.
        target_idx = int(np.argmax([len(s) for s in sents]))
        return target_idx, sent_idxs[target_idx]

    @staticmethod
    def _resize_cell(image, mask, height, width):
        """Resize one mosaic cell; a zero-sized cell becomes an empty array."""
        height, width = int(height), int(width)
        if height == 0 or width == 0:
            # Keep the source dtypes so empty cells do not promote the
            # concatenated mosaic to float.
            return (np.zeros([height, width, 3], dtype=image.dtype),
                    np.zeros([height, width], dtype=mask.dtype))
        resized = A.Compose([A.Resize(height, width, always_apply=True)])(
            image=image, mask=mask)
        return resized['image'], resized['mask']

    @staticmethod
    def _stitch(cells, num_bgs):
        """Concatenate ``cells`` (row-major order) into one square grid."""
        per_row = int(math.sqrt(num_bgs))
        rows = [np.concatenate(cells[r * per_row:(r + 1) * per_row], axis=1)
                for r in range(per_row)]
        return np.concatenate(rows, axis=0)

    def __getitem__(self, index):
        """Build one sample.

        Returns:
            dict with keys 'image' (normalised image tensor), 'seg_target'
            (long tensor, values 0/1), 'sentence' and 'attn_mask' (token ids
            and attention mask; shape (1, max_tokens), or
            (1, max_tokens, num_sentences) when ``eval_mode`` is set).
        """
        anchor_ref_id = self.ref_idx2id[index]
        choice, num_bgs = self._choose_layout()

        target_sent_idx = int(np.random.choice(len(self.input_ids[index])))
        insert_idx = int(np.random.choice(range(num_bgs)))

        img_ids = None
        ref_ids, sent_idxs, sents = [], [], []
        if num_bgs != 1:
            if choice == 'retrieval':
                # The LMDB is only needed here, so it is read lazily.
                candidates = self._load_retrieval_candidates(anchor_ref_id)
                # Assumes the record's key order matches the reference's
                # sentence order - TODO confirm.
                sent_id = list(candidates.keys())[target_sent_idx]
                img_ids = list(np.random.choice(
                    candidates[sent_id][:self.aug.top_k], size=num_bgs - 1, replace=True))
                img_ids.insert(insert_idx, self.refer.Refs[anchor_ref_id]['image_id'])
            ref_ids = [int(r) for r in
                       np.random.choice(self.ref_ids, size=num_bgs - 1, replace=False)]
            sent_idxs = [int(np.random.choice(len(self.refer.Refs[r]['sentences'])))
                         for r in ref_ids]
            sents = [self.refer.Refs[r]['sentences'][s]['raw']
                     for r, s in zip(ref_ids, sent_idxs)]

        # Plain lists are used on purpose: np.insert on a fixed-width string
        # array truncates the inserted sentence, which skews 'longest'.
        ref_ids.insert(insert_idx, anchor_ref_id)
        sents.insert(insert_idx,
                     self.refer.Refs[anchor_ref_id]['sentences'][target_sent_idx]['raw'])
        sent_idxs.insert(insert_idx, target_sent_idx)

        # Pick the target cell; the sentence must come from that cell's ref so
        # that it matches the mask selected below.
        target_idx, target_sent_idx = self._select_target(
            choice, ref_ids, sents, sent_idxs, insert_idx, target_sent_idx)
        target_ref_id = ref_ids[target_idx]

        # Load images and masks, one per cell.
        imgs, masks = [], []
        if choice == 'retrieval':
            for cell_idx, img_id in enumerate(img_ids):
                imgs.append(self._load_image(img_id))
                if cell_idx == insert_idx:
                    # The inserted cell must carry the mask of the indexed
                    # reference, not of whichever ref is listed first.
                    ref = self.refer.Refs[anchor_ref_id]
                else:
                    ref = self.refer.imgToRefs[img_id][0]
                masks.append(np.array(self.refer.getMask(ref)['mask']))
        else:
            for ref_id in ref_ids:
                img_id = self.refer.getImgIds([ref_id])[0]
                imgs.append(self._load_image(img_id))
                ref = self.refer.loadRefs(ref_ids=[ref_id])[0]
                masks.append(np.array(self.refer.getMask(ref)['mask']))

        # Resize and, for mosaics, stitch the four cells together.
        if num_bgs == 1:
            resized = self.resize_bg1(image=imgs[0], mask=masks[0])
            img, mask = resized['image'], resized['mask']
        else:
            if self.aug.move_crs_pnt:
                crs_y = np.random.randint(0, self.img_sz + 1)
                crs_x = np.random.randint(0, self.img_sz + 1)
            else:
                # Centre of the canvas (was hard-coded to 480//2).
                crs_y = crs_x = self.img_sz // 2

            # (height, width) of top-left, top-right, bottom-left, bottom-right.
            cell_sizes = [
                (crs_y, crs_x),
                (crs_y, self.img_sz - crs_x),
                (self.img_sz - crs_y, crs_x),
                (self.img_sz - crs_y, self.img_sz - crs_x),
            ]
            cells = [self._resize_cell(im, mk, h, w)
                     for im, mk, (h, w) in zip(imgs, masks, cell_sizes)]
            imgs = [c[0] for c in cells]
            masks = [c[1] for c in cells]

            # Scale-effect ablation: blur every cell except the inserted one.
            # Empty cells are skipped because PIL cannot build them.
            if self.aug.blur:
                imgs = [np.asarray(Image.fromarray(x).filter(self.blur))
                        if i != insert_idx and x.size else x
                        for i, x in enumerate(imgs)]

            img = self._stitch(imgs, num_bgs).astype(np.uint8)
            # Only the target cell keeps its mask; the others are zeroed.
            mask = self._stitch(
                [m if i == target_idx else np.zeros_like(m) for i, m in enumerate(masks)],
                num_bgs)

        mask = mask.astype(np.uint8)
        mask[mask > 0] = 1

        item = self.transforms(image=img, mask=mask)
        img_tensor = item['image']
        target = item['mask'].long()

        target_ref_idx = self.ref_id2idx[target_ref_id]
        if self.eval_mode:
            # Stack every sentence of the reference on a trailing axis.
            tensor_embeddings = torch.stack(
                [torch.tensor(ids) for ids in self.input_ids[target_ref_idx]],
                dim=-1).unsqueeze(0)
            attention_mask = torch.stack(
                [torch.tensor(att) for att in self.attention_masks[target_ref_idx]],
                dim=-1).unsqueeze(0)
        else:
            tensor_embeddings = torch.tensor(
                self.input_ids[target_ref_idx][target_sent_idx]).unsqueeze(0)
            attention_mask = torch.tensor(
                self.attention_masks[target_ref_idx][target_sent_idx]).unsqueeze(0)

        return {
            'image': img_tensor,
            'seg_target': target,
            'sentence': tensor_embeddings,
            'attn_mask': attention_mask,
        }