#%%
import os
import json
import random
import itertools
from typing import List, Union

import cv2
import lmdb
import numpy as np
import pyarrow as pa
import torch
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2

from .simple_tokenizer import SimpleTokenizer as _Tokenizer

info = {
    'refcoco': {
        'train': 42404,
        'val': 3811,
        'val-test': 3811,
        'testA': 1975,
        'testB': 1810
    },
    'refcoco+': {
        'train': 42278,
        'val': 3805,
        'val-test': 3805,
        'testA': 1975,
        'testB': 1798
    },
    'refcocog_u': {
        'train': 42226,
        'val': 2573,
        'val-test': 2573,
        'test': 5023,
        'test_0-5_verb': 572,
        'test_0-5_static': 1688,
        'test_6-7_verb': 949,
        'test_6-7_static': 1240,
        'test_8-10_verb': 1523,
        'test_8-10_static': 1194,
        'test_11-20_verb': 1768,
        'test_11-20_static': 584,
        'test_abl_motion': 267,
        'test_abl_static': 267
    },
    'refcocog_g': {
        'train': 44822,
        'val': 5000,
        'val-test': 5000
    }
}
_tokenizer = _Tokenizer()


#%%
def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s).

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize.
    context_length : int
        The context length to use; all CLIP models use 77 as the context length.
    truncate : bool
        Whether to truncate the text in case its encoding is longer than the
        context length.

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens,
    shape = [number of input strings, context_length].
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


def loads_pyarrow(buf):
    """
    Args:
        buf: the output of `dumps`.
    """
    return pa.deserialize(buf)


class RefDataset(Dataset):
    def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
                 word_length, args):
        super(RefDataset, self).__init__()
        self.lmdb_dir = lmdb_dir
        self.mask_dir = mask_dir
        self.dataset = dataset
        self.split = split
        self.mode = mode
        self.input_size = (input_size, input_size)
        self.word_length = word_length
        self.mean = torch.tensor([0.48145466, 0.4578275,
                                  0.40821073]).reshape(3, 1, 1)
        self.std = torch.tensor([0.26862954, 0.26130258,
                                 0.27577711]).reshape(3, 1, 1)
        self.length = info[dataset][split]
        self.env = None
        self.exclude_position = args.exclude_pos
        self.metric_learning = args.metric_learning
        self.hardpos_rigid = args.hardpos_rigid
        self.resize_bg1 = A.Compose(
            [A.Resize(input_size, input_size, always_apply=True)])

        if self.metric_learning:
            if self.hardpos_rigid and self.exclude_position:
                multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj_nopos.txt'
                with open(multiobj_path, 'r') as f:
                    self.multi_obj_ref_ids = [
                        int(line.strip()) for line in f.readlines()
                    ]
            elif self.hardpos_rigid:
                multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj.txt'
                with open(multiobj_path, 'r') as f:
                    self.multi_obj_ref_ids = [
                        int(line.strip()) for line in f.readlines()
                    ]
            else:
                self.multi_obj_ref_ids = None

            path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/llama3-demo/llama3/hardpos_verbphrase_0906upd.json'
            with open(path, 'r', encoding='utf-8') as f:
                self.metadata = json.load(f)
        else:
            self.metadata = None

    def _init_db(self):
        self.env = lmdb.open(self.lmdb_dir,
                             subdir=os.path.isdir(self.lmdb_dir),
                             readonly=True,
                             lock=False,
                             readahead=False,
                             meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Delay loading LMDB data until after initialization:
        # https://github.com/chainer/chainermn/issues/129
        if self.env is None:
            self._init_db()
        env = self.env
        with env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        ref = loads_pyarrow(byteflow)
        # img
        ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
                               cv2.IMREAD_COLOR)
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        # mask
        seg_id = ref['seg_id']
        mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png')
        mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
                            cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.
        # image resizing
        resized = self.resize_bg1(image=img, mask=mask)
        imgs, masks = [resized['image']], [resized['mask']]
        img = imgs[0]
        mask = masks[0]
        mask = mask.astype(np.uint8)
        mask[mask > 0] = 1
        # image transform
        img_size = img.shape[:2]
        mat, mat_inv = self.getTransformMat(img_size, True)
        img = cv2.warpAffine(
            img,
            mat,
            self.input_size,
            flags=cv2.INTER_CUBIC,
            borderValue=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255])
        # sentences
        sents = ref['sents']
        n_sentences = ref['num_sents']

        if self.mode == 'train':
            # mask transform
            mask = cv2.warpAffine(mask,
                                  mat,
                                  self.input_size,
                                  flags=cv2.INTER_LINEAR,
                                  borderValue=0.)
            # If metric learning is enabled, select two positive sentences:
            # an annotated expression paired with either another annotation
            # or one of its pre-generated hard positives from self.metadata.
            if self.metric_learning:
                if self.hardpos_rigid and seg_id in self.multi_obj_ref_ids:
                    if n_sentences > 1:
                        idx = np.random.choice(ref['num_sents'], 2, replace=False)
                        sent = [sents[i] for i in idx]
                    else:
                        sent = [sents[0], sents[0]]
                else:
                    # Process the hard-positive metadata for this referent.
                    hardpos_dict = self.metadata[str(ref['seg_id'])]
                    hardpos_list = list(itertools.chain(*hardpos_dict.values()))
                    sent_id_list = list(hardpos_dict.keys())
                    if n_sentences > 1:
                        if self.hardpos_rigid:
                            idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
                            cur_hardpos = hardpos_dict[sent_id_list[idx]]
                            if len(cur_hardpos) == 0:
                                idx = np.random.choice(ref['num_sents'], 2, replace=False)
                                sent = [sents[i] for i in idx]
                            else:
                                hardpos_choice = random.choice(cur_hardpos)
                                sent = [sents[idx], hardpos_choice]
                                random.shuffle(sent)
                        else:
                            if len(hardpos_list) == 0:
                                idx = np.random.choice(ref['num_sents'], 2, replace=False)
                                sent = [sents[i] for i in idx]
                            else:
                                idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
                                hardpos_choice = random.choice(hardpos_list)
                                sent = [sents[idx], hardpos_choice]
                                random.shuffle(sent)
                    # If there is only one annotated sentence, duplicate it.
                    else:
                        if len(hardpos_list) == 0:
                            sent = [sents[0], sents[0]]
                        else:
                            hardpos_choice = random.choice(hardpos_list)
                            sent = [sents[0], hardpos_choice]
                            random.shuffle(sent)
                    # print(f"Generated sentences: {sent}")
            else:
                idx = np.random.choice(ref['num_sents'], 1, replace=False)
                sent = sents[idx]
            word_vec = tokenize(sent, self.word_length, True).squeeze(0)
            img, mask = self.convert(img, mask)
            # params = {
            #     'ori_img': ori_img,
            #     'seg_id': seg_id,
            #     'mask_dir': mask_dir,
            #     'inverse': mat_inv,
            #     'ori_size': np.array(img_size),
            #     'sents': sents
            # }
            return img, word_vec, mask
        elif self.mode == 'val':
            # sentence -> vector
            sent = sents[0]
            word_vec = tokenize(sent, self.word_length, True).squeeze(0)
            img = self.convert(img)[0]
            params = {
                'mask_dir': mask_dir,
                'inverse': mat_inv,
                'ori_size': np.array(img_size)
            }
            return img, word_vec, mask, params
        else:
            # sentence -> vector
            img = self.convert(img)[0]
            params = {
                'ori_img': ori_img,
                'seg_id': seg_id,
                'mask_dir': mask_dir,
                'inverse': mat_inv,
                'ori_size': np.array(img_size),
                'sents': sents
            }
            return img, mask, params

    def getTransformMat(self, img_size, inverse=False):
        ori_h, ori_w = img_size
        inp_h, inp_w = self.input_size
        scale = min(inp_h / ori_h, inp_w / ori_w)
        new_h, new_w = ori_h * scale, ori_w * scale
        bias_x, bias_y = (inp_w - new_w) / 2., (inp_h - new_h) / 2.

        src = np.array([[0, 0], [ori_w, 0], [0, ori_h]], np.float32)
        dst = np.array([[bias_x, bias_y], [new_w + bias_x, bias_y],
                        [bias_x, new_h + bias_y]], np.float32)

        mat = cv2.getAffineTransform(src, dst)
        if inverse:
            mat_inv = cv2.getAffineTransform(dst, src)
            return mat, mat_inv
        return mat, None

    def convert(self, img, mask=None):
        # Image ToTensor & Normalize
        img = torch.from_numpy(img.transpose((2, 0, 1)))
        if not isinstance(img, torch.FloatTensor):
            img = img.float()
        img.div_(255.).sub_(self.mean).div_(self.std)
        # Mask ToTensor
        if mask is not None:
            mask = torch.from_numpy(mask)
            if not isinstance(mask, torch.FloatTensor):
                mask = mask.float()
        return img, mask

    def __repr__(self):
        return self.__class__.__name__ + "(" + \
            f"db_path={self.lmdb_dir}, " + \
            f"dataset={self.dataset}, " + \
            f"split={self.split}, " + \
            f"mode={self.mode}, " + \
            f"input_size={self.input_size}, " + \
            f"word_length={self.word_length}" + ")"

    # def get_length(self):
    #     return self.length

    # def get_sample(self, idx):
    #     return self.__getitem__(idx)
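# A minimal usage sketch. The LMDB/mask paths are placeholders, `args` is
# assumed to only need the three flags read in __init__, and input_size=416 /
# word_length=17 are merely example values. Run as a module
# (e.g. `python -m <package>.<module>`) so the relative import of
# simple_tokenizer resolves.
if __name__ == '__main__':
    from argparse import Namespace
    from torch.utils.data import DataLoader

    # tokenize pads (with zeros) or truncates to the requested length,
    # so a single expression yields a (1, word_length) LongTensor.
    word_vec = tokenize("the man in the red jacket", 17, truncate=True)
    print(word_vec.shape)  # torch.Size([1, 17])

    args = Namespace(exclude_pos=False, metric_learning=False,
                     hardpos_rigid=False)
    dataset = RefDataset(lmdb_dir='/path/to/refcoco/train.lmdb',
                         mask_dir='/path/to/refcoco/masks',
                         dataset='refcoco',
                         split='train',
                         mode='train',
                         input_size=416,
                         word_length=17,
                         args=args)
    print(dataset)  # __repr__ summarizes the configuration

    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    img, word_vec, mask = next(iter(loader))
    print(img.shape, word_vec.shape, mask.shape)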