import itertools
import json
import os
import random
from typing import List, Union

import albumentations as A
import cv2
import lmdb
import numpy as np
import pyarrow as pa
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset

from .simple_tokenizer import SimpleTokenizer as _Tokenizer

# Number of samples per split for each supported referring-expression dataset.
info = {
    'refcoco': {
        'train': 42404,
        'val': 3811,
        'val-test': 3811,
        'testA': 1975,
        'testB': 1810
    },
    'refcoco+': {
        'train': 42278,
        'val': 3805,
        'val-test': 3805,
        'testA': 1975,
        'testB': 1798
    },
    'refcocog_u': {
        'train': 42226,
        'val': 2573,
        'val-test': 2573,
        'test': 5023,
        'test_0-5_verb': 572,
        'test_0-5_static': 1688,
        'test_6-7_verb': 949,
        'test_6-7_static': 1240,
        'test_8-10_verb': 1523,
        'test_8-10_static': 1194,
        'test_11-20_verb': 1768,
        'test_11-20_static': 584,
        'test_abl_motion': 267,
        'test_abl_static': 267
    },
    'refcocog_g': {
        'train': 44822,
        'val': 5000,
        'val-test': 5000
    }
}
_tokenizer = _Tokenizer()


def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of the given input string(s).

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize.

    context_length : int
        The context length to use; all CLIP models use 77 as the context length.

    truncate : bool
        Whether to truncate the text if its encoding is longer than the context length.

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens,
    shape = [number of input strings, context_length].
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
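
# Illustrative usage (a sketch, not part of the training pipeline): tokenizing two
# phrases yields a LongTensor of shape [2, context_length], zero-padded on the right.
#   word_vecs = tokenize(["the man on the left", "a dog jumping"], 17, truncate=True)
#   word_vecs.shape  # torch.Size([2, 17])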


def loads_pyarrow(buf):
    """
    Deserialize a buffer produced by pyarrow's legacy serialization.

    Args:
        buf: the output of `dumps`.
    """
    # NOTE: `pyarrow.deserialize` is deprecated in recent pyarrow releases, so this
    # assumes an older pyarrow version that still provides the legacy API.
    return pa.deserialize(buf)
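

# Minimal alternative sketch (an assumption, not part of the original data pipeline):
# if the LMDB records were instead written with `pickle.dumps`, this loader could be
# used in place of `loads_pyarrow` without touching the rest of the dataset class.
def loads_pickle(buf):
    import pickle  # local import: only needed if this fallback is ever used
    return pickle.loads(buf)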


class RefDataset(Dataset):
    def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
                 word_length, args):
        super(RefDataset, self).__init__()
        self.lmdb_dir = lmdb_dir
        self.mask_dir = mask_dir
        self.dataset = dataset
        self.split = split
        self.mode = mode
        self.input_size = (input_size, input_size)
        self.word_length = word_length
        # CLIP image normalization statistics.
        self.mean = torch.tensor([0.48145466, 0.4578275,
                                  0.40821073]).reshape(3, 1, 1)
        self.std = torch.tensor([0.26862954, 0.26130258,
                                 0.27577711]).reshape(3, 1, 1)
        self.length = info[dataset][split]
        self.env = None
        self.exclude_position = args.exclude_pos
        self.metric_learning = args.metric_learning
        self.hardpos_rigid = args.hardpos_rigid
        self.resize_bg1 = A.Compose([
            A.Resize(input_size, input_size, always_apply=True)])
        if self.metric_learning:
            # NOTE: hard-coded paths below; adjust them to your environment.
            if self.hardpos_rigid:
                # IDs of references whose image contains multiple objects; used to
                # decide when hard-positive sampling is allowed.
                if self.exclude_position:
                    multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj_nopos.txt'
                else:
                    multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj.txt'
                with open(multiobj_path, 'r') as f:
                    self.multi_obj_ref_ids = [int(line.strip()) for line in f.readlines()]
            else:
                self.multi_obj_ref_ids = None

            # Hard-positive phrases generated per reference (see the format note below).
            path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/llama3-demo/llama3/hardpos_verbphrase_0906upd.json'
            with open(path, 'r', encoding='utf-8') as f:
                self.metadata = json.load(f)
        else:
            self.metadata = None
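
    # Expected layout of the hard-positive metadata JSON, inferred from its use in
    # __getitem__ (illustrative keys/values; the file itself appears to be generated
    # offline with LLaMA-3, judging by the path above):
    #   {
    #       "<seg_id>": {
    #           "<sent_id>": ["hard-positive phrase 1", "hard-positive phrase 2", ...],
    #           ...
    #       },
    #       ...
    #   }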

    def _init_db(self):
        # Open the LMDB environment lazily so that each DataLoader worker gets its
        # own handle instead of sharing one across processes.
        self.env = lmdb.open(self.lmdb_dir,
                             subdir=os.path.isdir(self.lmdb_dir),
                             readonly=True,
                             lock=False,
                             readahead=False,
                             meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Open the LMDB environment on first access (per worker).
        if self.env is None:
            self._init_db()
        env = self.env
        with env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        ref = loads_pyarrow(byteflow)

        # Decode the image and convert BGR (OpenCV default) to RGB.
        ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
                               cv2.IMREAD_COLOR)
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)

        # Path to the ground-truth mask on disk (used at evaluation time).
        seg_id = ref['seg_id']
        mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png')

        # Decode the stored mask and binarize it.
        mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
                            cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.

        resized = self.resize_bg1(image=img, mask=mask)
        img, mask = resized['image'], resized['mask']
        mask = mask.astype(np.uint8)
        mask[mask > 0] = 1

        # Letterbox-style affine transform into the network input size; borders are
        # padded with the CLIP mean color.
        img_size = img.shape[:2]
        mat, mat_inv = self.getTransformMat(img_size, True)
        img = cv2.warpAffine(
            img,
            mat,
            self.input_size,
            flags=cv2.INTER_CUBIC,
            borderValue=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255])

        sents = ref['sents']
        n_sentences = ref['num_sents']

        if self.mode == 'train':
            # Apply the same affine transform to the mask.
            mask = cv2.warpAffine(mask,
                                  mat,
                                  self.input_size,
                                  flags=cv2.INTER_LINEAR,
                                  borderValue=0.)

            if self.metric_learning:
                # Metric learning needs a pair of sentences per sample: either two
                # annotated expressions or one expression plus a hard positive.
                if self.hardpos_rigid and seg_id in self.multi_obj_ref_ids:
                    # Multi-object image: only annotated expressions are used.
                    if n_sentences > 1:
                        idx = np.random.choice(ref['num_sents'], 2, replace=False)
                        sent = [sents[i] for i in idx]
                    else:
                        sent = [sents[0], sents[0]]
                else:
                    hardpos_dict = self.metadata[str(ref['seg_id'])]
                    hardpos_list = list(itertools.chain(*hardpos_dict.values()))
                    sent_id_list = list(hardpos_dict.keys())

                    if n_sentences > 1:
                        if self.hardpos_rigid:
                            # Pick one expression and a hard positive tied to it.
                            idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
                            cur_hardpos = hardpos_dict[sent_id_list[idx]]
                            if len(cur_hardpos) == 0:
                                idx = np.random.choice(ref['num_sents'], 2, replace=False)
                                sent = [sents[i] for i in idx]
                            else:
                                hardpos_choice = random.choice(cur_hardpos)
                                sent = [sents[idx], hardpos_choice]
                                random.shuffle(sent)
                        else:
                            # Pick one expression and any hard positive of this reference.
                            if len(hardpos_list) == 0:
                                idx = np.random.choice(ref['num_sents'], 2, replace=False)
                                sent = [sents[i] for i in idx]
                            else:
                                idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
                                hardpos_choice = random.choice(hardpos_list)
                                sent = [sents[idx], hardpos_choice]
                                random.shuffle(sent)
                    else:
                        # Single annotated expression: duplicate it or pair it with a
                        # hard positive.
                        if len(hardpos_list) == 0:
                            sent = [sents[0], sents[0]]
                        else:
                            hardpos_choice = random.choice(hardpos_list)
                            sent = [sents[0], hardpos_choice]
                            random.shuffle(sent)
            else:
                # Standard training: sample a single expression.
                idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
                sent = sents[idx]
            word_vec = tokenize(sent, self.word_length, True).squeeze(0)
            img, mask = self.convert(img, mask)
            return img, word_vec, mask

        elif self.mode == 'val':
            sent = sents[0]
            word_vec = tokenize(sent, self.word_length, True).squeeze(0)
            img = self.convert(img)[0]
            params = {
                'mask_dir': mask_dir,
                'inverse': mat_inv,
                'ori_size': np.array(img_size)
            }
            return img, word_vec, mask, params
        else:
            # Test mode: all expressions are returned in params and tokenized by the
            # evaluation loop.
            img = self.convert(img)[0]
            params = {
                'ori_img': ori_img,
                'seg_id': seg_id,
                'mask_dir': mask_dir,
                'inverse': mat_inv,
                'ori_size': np.array(img_size),
                'sents': sents
            }
            return img, mask, params

    def getTransformMat(self, img_size, inverse=False):
        # Build the affine matrix that letterboxes an (ori_h, ori_w) image into
        # self.input_size: uniform scaling plus centering, with padded borders.
        ori_h, ori_w = img_size
        inp_h, inp_w = self.input_size
        scale = min(inp_h / ori_h, inp_w / ori_w)
        new_h, new_w = ori_h * scale, ori_w * scale
        bias_x, bias_y = (inp_w - new_w) / 2., (inp_h - new_h) / 2.

        src = np.array([[0, 0], [ori_w, 0], [0, ori_h]], np.float32)
        dst = np.array([[bias_x, bias_y], [new_w + bias_x, bias_y],
                        [bias_x, new_h + bias_y]], np.float32)

        mat = cv2.getAffineTransform(src, dst)
        if inverse:
            mat_inv = cv2.getAffineTransform(dst, src)
            return mat, mat_inv
        return mat, None
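
    # Worked example (illustrative numbers only): for a 480x640 (h, w) image and
    # input_size 416, scale = min(416/480, 416/640) = 0.65, so the content is resized
    # to 312x416 and centered with bias_x = 0, bias_y = (416 - 312) / 2 = 52 px.
    # Note that __getitem__ already resizes the image to input_size beforehand, so
    # there the transform reduces to (nearly) the identity mapping.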

    def convert(self, img, mask=None):
        # HWC uint8 -> CHW float, then normalize with the CLIP statistics.
        img = torch.from_numpy(img.transpose((2, 0, 1)))
        if not isinstance(img, torch.FloatTensor):
            img = img.float()
        img.div_(255.).sub_(self.mean).div_(self.std)

        if mask is not None:
            mask = torch.from_numpy(mask)
            if not isinstance(mask, torch.FloatTensor):
                mask = mask.float()
        return img, mask

    def __repr__(self):
        return self.__class__.__name__ + "(" + \
            f"db_path={self.lmdb_dir}, " + \
            f"dataset={self.dataset}, " + \
            f"split={self.split}, " + \
            f"mode={self.mode}, " + \
            f"input_size={self.input_size}, " + \
            f"word_length={self.word_length}" + ")"
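

# Illustrative usage sketch (assumptions: the LMDB/mask paths and the args fields
# below are placeholders, not values shipped with this repository):
#   from argparse import Namespace
#   args = Namespace(exclude_pos=False, metric_learning=False, hardpos_rigid=False)
#   dataset = RefDataset(lmdb_dir='path/to/refcoco/train.lmdb',
#                        mask_dir='path/to/refcoco/masks',
#                        dataset='refcoco', split='train', mode='train',
#                        input_size=416, word_length=17, args=args)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True,
#                                        num_workers=4, pin_memory=True)
#   img, word_vec, mask = dataset[0]  # shapes: [3, 416, 416], [17], [416, 416]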