dianecy committed
Commit fce6bfe (verified)
1 Parent(s): 10842b4

Upload folder using huggingface_hub
utils/.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__/
utils/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .simple_tokenizer import SimpleTokenizer
2
+ from .config import *
3
+ from .dataset import *
4
+ from .misc import *
utils/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
utils/config.py ADDED
@@ -0,0 +1,157 @@
1
+ # -----------------------------------------------------------------------------
2
+ # Functions for parsing args
3
+ # -----------------------------------------------------------------------------
4
+ import copy
5
+ import os
6
+ from ast import literal_eval
7
+
8
+ import yaml
9
+
10
+
11
+ class CfgNode(dict):
12
+ """
13
+ CfgNode represents an internal node in the configuration tree. It's a simple
14
+ dict-like container that allows for attribute-based access to keys.
15
+ """
16
+ def __init__(self, init_dict=None, key_list=None, new_allowed=False):
17
+ # Recursively convert nested dictionaries in init_dict into CfgNodes
18
+ init_dict = {} if init_dict is None else init_dict
19
+ key_list = [] if key_list is None else key_list
20
+ for k, v in init_dict.items():
21
+ if type(v) is dict:
22
+ # Convert dict to CfgNode
23
+ init_dict[k] = CfgNode(v, key_list=key_list + [k])
24
+ super(CfgNode, self).__init__(init_dict)
25
+
26
+ def __getattr__(self, name):
27
+ if name in self:
28
+ return self[name]
29
+ else:
30
+ raise AttributeError(name)
31
+
32
+ def __setattr__(self, name, value):
33
+ self[name] = value
34
+
35
+ def __str__(self):
36
+ def _indent(s_, num_spaces):
37
+ s = s_.split("\n")
38
+ if len(s) == 1:
39
+ return s_
40
+ first = s.pop(0)
41
+ s = [(num_spaces * " ") + line for line in s]
42
+ s = "\n".join(s)
43
+ s = first + "\n" + s
44
+ return s
45
+
46
+ r = ""
47
+ s = []
48
+ for k, v in sorted(self.items()):
49
+ separator = "\n" if isinstance(v, CfgNode) else " "
50
+ attr_str = "{}:{}{}".format(str(k), separator, str(v))
51
+ attr_str = _indent(attr_str, 2)
52
+ s.append(attr_str)
53
+ r += "\n".join(s)
54
+ return r
55
+
56
+ def __repr__(self):
57
+ return "{}({})".format(self.__class__.__name__,
58
+ super(CfgNode, self).__repr__())
59
+
60
+
61
+ def load_cfg_from_cfg_file(file):
62
+ cfg = {}
63
+ assert os.path.isfile(file) and file.endswith('.yaml'), \
64
+ '{} is not a yaml file'.format(file)
65
+
66
+ with open(file, 'r') as f:
67
+ cfg_from_file = yaml.safe_load(f)
68
+
69
+ for key in cfg_from_file:
70
+ for k, v in cfg_from_file[key].items():
71
+ cfg[k] = v
72
+
73
+ cfg = CfgNode(cfg)
74
+ return cfg
75
+
76
+
77
+ def merge_cfg_from_list(cfg, cfg_list):
78
+ new_cfg = copy.deepcopy(cfg)
79
+ assert len(cfg_list) % 2 == 0
80
+ for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
81
+ subkey = full_key.split('.')[-1]
82
+ assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
83
+ value = _decode_cfg_value(v)
84
+ value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey,
85
+ full_key)
86
+ setattr(new_cfg, subkey, value)
87
+
88
+ return new_cfg
89
+
90
+
91
+ def _decode_cfg_value(v):
92
+ """Decodes a raw config value (e.g., from a yaml config files or command
93
+ line argument) into a Python object.
94
+ """
95
+ # All remaining processing is only applied to strings
96
+ if not isinstance(v, str):
97
+ return v
98
+ # Try to interpret `v` as a:
99
+ # string, number, tuple, list, dict, boolean, or None
100
+ try:
101
+ v = literal_eval(v)
102
+ # The following two excepts allow v to pass through when it represents a
103
+ # string.
104
+ #
105
+ # Longer explanation:
106
+ # The type of v is always a string (before calling literal_eval), but
107
+ # sometimes it *represents* a string and other times a data structure, like
108
+ # a list. In the case that v represents a string, what we got back from the
109
+ # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
110
+ # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
111
+ # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
112
+ # will raise a SyntaxError.
113
+ except ValueError:
114
+ pass
115
+ except SyntaxError:
116
+ pass
117
+ return v
118
+
119
+
120
+ def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
121
+ """Checks that `replacement`, which is intended to replace `original` is of
122
+ the right type. The type is correct if it matches exactly or is one of a few
123
+ cases in which the type can be easily coerced.
124
+ """
125
+ original_type = type(original)
126
+ replacement_type = type(replacement)
127
+
128
+ # The types must match (with some exceptions)
129
+ if replacement_type == original_type:
130
+ return replacement
131
+
132
+ # Cast replacement from from_type to to_type if the replacement and original
133
+ # types match from_type and to_type
134
+ def conditional_cast(from_type, to_type):
135
+ if replacement_type == from_type and original_type == to_type:
136
+ return True, to_type(replacement)
137
+ else:
138
+ return False, None
139
+
140
+ # Conditionally casts
141
+ # list <-> tuple
142
+ casts = [(tuple, list), (list, tuple)]
143
+ # For py2: allow converting from str (bytes) to a unicode string
144
+ try:
145
+ casts.append((str, unicode)) # noqa: F821
146
+ except Exception:
147
+ pass
148
+
149
+ for (from_type, to_type) in casts:
150
+ converted, converted_value = conditional_cast(from_type, to_type)
151
+ if converted:
152
+ return converted_value
153
+
154
+ raise ValueError(
155
+ "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
156
+ "key: {}".format(original_type, replacement_type, original,
157
+ replacement, full_key))
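
Taken together, the helpers above load a flat CfgNode from a sectioned YAML file and then apply dotted command-line overrides. A minimal usage sketch (the YAML file name, its sections, and the override keys below are hypothetical):

import sys
from utils.config import load_cfg_from_cfg_file, merge_cfg_from_list

# config.yaml (hypothetical) might contain:
# DATA:
#   dataset: refcoco
#   input_size: 416
# TRAIN:
#   lr: 0.0001
cfg = load_cfg_from_cfg_file('config.yaml')    # flattens section contents into one CfgNode

# optional overrides, e.g. `python train.py DATA.input_size 480 TRAIN.lr 0.0005`
if len(sys.argv) > 1:
    cfg = merge_cfg_from_list(cfg, sys.argv[1:])

print(cfg.dataset, cfg.input_size, cfg.lr)     # attribute-style access provided by CfgNode

Note that merge_cfg_from_list looks up only the leaf key (the part after the last dot), so leaf names must be unique across YAML sections.
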
utils/dataset.py ADDED
@@ -0,0 +1,341 @@
1
+ #%%
2
+ import os
3
+ from typing import List, Union
4
+ import json
5
+ import cv2
6
+ import lmdb
7
+ import random
8
+ import numpy as np
9
+ import pyarrow as pa
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+ import itertools
13
+ import albumentations as A
14
+ from albumentations.pytorch import ToTensorV2
15
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
16
+
17
+ info = {
18
+ 'refcoco': {
19
+ 'train': 42404,
20
+ 'val': 3811,
21
+ 'val-test': 3811,
22
+ 'testA': 1975,
23
+ 'testB': 1810
24
+ },
25
+ 'refcoco+': {
26
+ 'train': 42278,
27
+ 'val': 3805,
28
+ 'val-test': 3805,
29
+ 'testA': 1975,
30
+ 'testB': 1798
31
+ },
32
+ 'refcocog_u': {
33
+ 'train': 42226,
34
+ 'val': 2573,
35
+ 'val-test': 2573,
36
+ 'test': 5023,
37
+ 'test_0-5_verb' : 572,
38
+ 'test_0-5_static' : 1688,
39
+ 'test_6-7_verb' : 949,
40
+ 'test_6-7_static' : 1240,
41
+ 'test_8-10_verb' : 1523,
42
+ 'test_8-10_static' : 1194,
43
+ 'test_11-20_verb' : 1768,
44
+ 'test_11-20_static' : 584,
45
+ 'test_abl_motion' : 267,
46
+ 'test_abl_static' : 267
47
+ },
48
+ 'refcocog_g': {
49
+ 'train': 44822,
50
+ 'val': 5000,
51
+ 'val-test': 5000
52
+ }
53
+ }
54
+ _tokenizer = _Tokenizer()
55
+
56
+ #%%
57
+ def tokenize(texts: Union[str, List[str]],
58
+ context_length: int = 77,
59
+ truncate: bool = False) -> torch.LongTensor:
60
+ """
61
+ Returns the tokenized representation of given input string(s)
62
+
63
+ Parameters
64
+ ----------
65
+ texts : Union[str, List[str]]
66
+ An input string or a list of input strings to tokenize
67
+
68
+ context_length : int
69
+ The context length to use; all CLIP models use 77 as the context length
70
+
71
+ truncate: bool
72
+ Whether to truncate the text in case its encoding is longer than the context length
73
+
74
+ Returns
75
+ -------
76
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
77
+ """
78
+ if isinstance(texts, str):
79
+ texts = [texts]
80
+
81
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
82
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
83
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
84
+ for text in texts]
85
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
86
+
87
+ for i, tokens in enumerate(all_tokens):
88
+ if len(tokens) > context_length:
89
+ if truncate:
90
+ tokens = tokens[:context_length]
91
+ tokens[-1] = eot_token
92
+ else:
93
+ raise RuntimeError(
94
+ f"Input {texts[i]} is too long for context length {context_length}"
95
+ )
96
+ result[i, :len(tokens)] = torch.tensor(tokens)
97
+
98
+ return result
99
+
100
+
101
+ def loads_pyarrow(buf):
102
+ """
103
+ Args:
104
+ buf: the output of `dumps`.
105
+ """
106
+ return pa.deserialize(buf)
107
+
108
+
109
+ class RefDataset(Dataset):
110
+ def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
111
+ word_length, args):
112
+ super(RefDataset, self).__init__()
113
+ self.lmdb_dir = lmdb_dir
114
+ self.mask_dir = mask_dir
115
+ self.dataset = dataset
116
+ self.split = split
117
+ self.mode = mode
118
+ self.input_size = (input_size, input_size)
119
+ self.word_length = word_length
120
+ self.mean = torch.tensor([0.48145466, 0.4578275,
121
+ 0.40821073]).reshape(3, 1, 1)
122
+ self.std = torch.tensor([0.26862954, 0.26130258,
123
+ 0.27577711]).reshape(3, 1, 1)
124
+ self.length = info[dataset][split]
125
+ self.env = None
126
+ self.exclude_position = args.exclude_pos
127
+ self.metric_learning = args.metric_learning
128
+ self.hardpos_rigid = args.hardpos_rigid
129
+ self.resize_bg1 = A.Compose([
130
+ A.Resize(input_size, input_size, always_apply=True)])
131
+ if self.metric_learning :
132
+ if self.hardpos_rigid and self.exclude_position :
133
+ multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj_nopos.txt'
134
+ with open(multiobj_path, 'r') as f:
135
+ self.multi_obj_ref_ids = [int(line.strip()) for line in f.readlines()]
136
+ elif self.hardpos_rigid :
137
+ multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj.txt'
138
+ with open(multiobj_path, 'r') as f:
139
+ self.multi_obj_ref_ids = [int(line.strip()) for line in f.readlines()]
140
+ else :
141
+ self.multi_obj_ref_ids = None
142
+
143
+ path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/llama3-demo/llama3/hardpos_verbphrase_0906upd.json'
144
+ with open(path, 'r', encoding='utf-8') as f:
145
+ self.metadata = json.load(f)
146
+ else :
147
+ self.metadata = None
148
+
149
+ def _init_db(self):
150
+ self.env = lmdb.open(self.lmdb_dir,
151
+ subdir=os.path.isdir(self.lmdb_dir),
152
+ readonly=True,
153
+ lock=False,
154
+ readahead=False,
155
+ meminit=False)
156
+ with self.env.begin(write=False) as txn:
157
+ self.length = loads_pyarrow(txn.get(b'__len__'))
158
+ self.keys = loads_pyarrow(txn.get(b'__keys__'))
159
+
160
+ def __len__(self):
161
+ return self.length
162
+
163
+ def __getitem__(self, index):
164
+ # Delay loading LMDB data until after initialization: https://github.com/chainer/chainermn/issues/129
165
+ if self.env is None:
166
+ self._init_db()
167
+ env = self.env
168
+ with env.begin(write=False) as txn:
169
+ byteflow = txn.get(self.keys[index])
170
+ ref = loads_pyarrow(byteflow)
171
+ # img
172
+ ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
173
+ cv2.IMREAD_COLOR)
174
+ img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
175
+
176
+ # mask
177
+ seg_id = ref['seg_id']
178
+ mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png')
179
+
180
+ mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
181
+ cv2.IMREAD_GRAYSCALE)
182
+ mask = mask / 255.
183
+
184
+
185
+ # image resizing
186
+ resized = self.resize_bg1(image=img, mask=mask)
187
+ imgs, masks = [resized['image']], [resized['mask']]
188
+ img = imgs[0]
189
+ mask = masks[0]
190
+ mask = mask.astype(np.uint8)
191
+ mask[mask>0] = 1
192
+
193
+ # image transform
194
+ img_size = img.shape[:2]
195
+ mat, mat_inv = self.getTransformMat(img_size, True)
196
+ img = cv2.warpAffine(
197
+ img,
198
+ mat,
199
+ self.input_size,
200
+ flags=cv2.INTER_CUBIC,
201
+ borderValue=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255])
202
+
203
+ # sentences
204
+ sents = ref['sents']
205
+ n_sentences = ref['num_sents']
206
+
207
+ if self.mode == 'train':
208
+ # mask transform
209
+ mask = cv2.warpAffine(mask,
210
+ mat,
211
+ self.input_size,
212
+ flags=cv2.INTER_LINEAR,
213
+ borderValue=0.)
214
+
215
+ # if metric learning, select 2 positive sentences
216
+ if self.metric_learning:
217
+ if self.hardpos_rigid and seg_id in self.multi_obj_ref_ids:
218
+ if n_sentences > 1:
219
+ idx = np.random.choice(ref['num_sents'], 2, replace=False)
220
+ sent = [sents[i] for i in idx]
221
+ else:
222
+ sent = [sents[0], sents[0]]
223
+ else:
224
+ # Added processing hardpos data
225
+ hardpos_dict = self.metadata[str(ref['seg_id'])]
226
+ hardpos_list = list(itertools.chain(*hardpos_dict.values()))
227
+ sent_id_list = list(hardpos_dict.keys())
228
+
229
+ if n_sentences > 1:
230
+ if self.hardpos_rigid :
231
+ idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
232
+ cur_hardpos = hardpos_dict[sent_id_list[idx]]
233
+ if len(cur_hardpos) == 0 :
234
+ idx = np.random.choice(ref['num_sents'], 2, replace=False)
235
+ sent = [sents[i] for i in idx]
236
+ else :
237
+ hardpos_choice = random.choice(cur_hardpos)
238
+ sent = [sents[idx], hardpos_choice]
239
+ random.shuffle(sent)
240
+ else :
241
+ if len(hardpos_list) == 0 :
242
+ idx = np.random.choice(ref['num_sents'], 2, replace=False)
243
+ sent = [sents[i] for i in idx]
244
+ else :
245
+ idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
246
+ hardpos_choice = random.choice(hardpos_list)
247
+ sent = [sents[idx], hardpos_choice]
248
+ random.shuffle(sent)
249
+ # if there's only one, duplicate it
250
+ else:
251
+ if len(hardpos_list) == 0 :
252
+ sent = [sents[0], sents[0]]
253
+ else :
254
+ hardpos_choice = random.choice(hardpos_list)
255
+ sent = [sents[0], hardpos_choice]
256
+ random.shuffle(sent)
257
+ # print(f"Generated sentences: {sent}")
258
+ else:
259
+ idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
260
+ sent = sents[idx]
261
+ word_vec = tokenize(sent, self.word_length, True).squeeze(0)
262
+ img, mask = self.convert(img, mask)
263
+
264
+ # params = {
265
+ # 'ori_img': ori_img,
266
+ # 'seg_id': seg_id,
267
+ # 'mask_dir': mask_dir,
268
+ # 'inverse': mat_inv,
269
+ # 'ori_size': np.array(img_size),
270
+ # 'sents': sents
271
+ # }
272
+ return img, word_vec, mask
273
+
274
+ elif self.mode == 'val':
275
+ # sentence -> vector
276
+ sent = sents[0]
277
+ word_vec = tokenize(sent, self.word_length, True).squeeze(0)
278
+ img = self.convert(img)[0]
279
+ params = {
280
+ 'mask_dir': mask_dir,
281
+ 'inverse': mat_inv,
282
+ 'ori_size': np.array(img_size)
283
+ }
284
+ return img, word_vec, mask, params
285
+ else:
286
+ # sentence -> vector
287
+ img = self.convert(img)[0]
288
+ params = {
289
+ 'ori_img': ori_img,
290
+ 'seg_id': seg_id,
291
+ 'mask_dir': mask_dir,
292
+ 'inverse': mat_inv,
293
+ 'ori_size': np.array(img_size),
294
+ 'sents': sents
295
+ }
296
+ return img, mask, params
297
+
298
+ def getTransformMat(self, img_size, inverse=False):
299
+ ori_h, ori_w = img_size
300
+ inp_h, inp_w = self.input_size
301
+ scale = min(inp_h / ori_h, inp_w / ori_w)
302
+ new_h, new_w = ori_h * scale, ori_w * scale
303
+ bias_x, bias_y = (inp_w - new_w) / 2., (inp_h - new_h) / 2.
304
+
305
+ src = np.array([[0, 0], [ori_w, 0], [0, ori_h]], np.float32)
306
+ dst = np.array([[bias_x, bias_y], [new_w + bias_x, bias_y],
307
+ [bias_x, new_h + bias_y]], np.float32)
308
+
309
+ mat = cv2.getAffineTransform(src, dst)
310
+ if inverse:
311
+ mat_inv = cv2.getAffineTransform(dst, src)
312
+ return mat, mat_inv
313
+ return mat, None
314
+
315
+ def convert(self, img, mask=None):
316
+ # Image ToTensor & Normalize
317
+ img = torch.from_numpy(img.transpose((2, 0, 1)))
318
+ if not isinstance(img, torch.FloatTensor):
319
+ img = img.float()
320
+ img.div_(255.).sub_(self.mean).div_(self.std)
321
+ # Mask ToTensor
322
+ if mask is not None:
323
+ mask = torch.from_numpy(mask)
324
+ if not isinstance(mask, torch.FloatTensor):
325
+ mask = mask.float()
326
+ return img, mask
327
+
328
+ def __repr__(self):
329
+ return self.__class__.__name__ + "(" + \
330
+ f"db_path={self.lmdb_dir}, " + \
331
+ f"dataset={self.dataset}, " + \
332
+ f"split={self.split}, " + \
333
+ f"mode={self.mode}, " + \
334
+ f"input_size={self.input_size}, " + \
335
+ f"word_length={self.word_length}"
336
+
337
+ # def get_length(self):
338
+ # return self.length
339
+
340
+ # def get_sample(self, idx):
341
+ # return self.__getitem__(idx)
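
A hedged construction sketch for the dataset above; the LMDB and mask paths are placeholders, and `args` carries only the three flags the constructor reads (metric learning is off here, so the hard-coded metadata paths are never touched):

from types import SimpleNamespace
from torch.utils.data import DataLoader
from utils.dataset import RefDataset

# only these attributes of `args` are read by __init__
args = SimpleNamespace(exclude_pos=False, metric_learning=False, hardpos_rigid=False)

train_set = RefDataset(lmdb_dir='path/to/refcoco/train.lmdb',   # placeholder paths
                       mask_dir='path/to/refcoco/masks',
                       dataset='refcoco', split='train', mode='train',
                       input_size=416, word_length=17, args=args)

loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
img, word_vec, mask = next(iter(loader))   # train mode yields (image, token ids, mask)
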
utils/dataset_verbonly.py ADDED
@@ -0,0 +1,358 @@
1
+ #%%
2
+ import os
3
+ from typing import List, Union
4
+ import json
5
+ import cv2
6
+ import lmdb
7
+ import random
8
+ import numpy as np
9
+ import pyarrow as pa
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+ import itertools
13
+ import albumentations as A
14
+ from albumentations.pytorch import ToTensorV2
15
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
16
+
17
+ info = {
18
+ 'refcoco': {
19
+ 'train': 42404,
20
+ 'val': 3811,
21
+ 'val-test': 3811,
22
+ 'testA': 1975,
23
+ 'testB': 1810
24
+ },
25
+ 'refcoco+': {
26
+ 'train': 42278,
27
+ 'val': 3805,
28
+ 'val-test': 3805,
29
+ 'testA': 1975,
30
+ 'testB': 1798
31
+ },
32
+ 'refcocog_u': {
33
+ 'train': 42226,
34
+ 'val': 2573,
35
+ 'val-test': 2573,
36
+ 'test': 5023,
37
+ },
38
+ 'refcocog_g': {
39
+ 'train': 44822,
40
+ 'val': 5000,
41
+ 'val-test': 5000
42
+ }
43
+ }
44
+ _tokenizer = _Tokenizer()
45
+
46
+ #%%
47
+ def tokenize(texts: Union[str, List[str]],
48
+ context_length: int = 77,
49
+ truncate: bool = False) -> torch.LongTensor:
50
+ """
51
+ Returns the tokenized representation of given input string(s)
52
+
53
+ Parameters
54
+ ----------
55
+ texts : Union[str, List[str]]
56
+ An input string or a list of input strings to tokenize
57
+
58
+ context_length : int
59
+ The context length to use; all CLIP models use 77 as the context length
60
+
61
+ truncate: bool
62
+ Whether to truncate the text in case its encoding is longer than the context length
63
+
64
+ Returns
65
+ -------
66
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
67
+ """
68
+ if isinstance(texts, str):
69
+ texts = [texts]
70
+
71
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
72
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
73
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
74
+ for text in texts]
75
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
76
+
77
+ for i, tokens in enumerate(all_tokens):
78
+ if len(tokens) > context_length:
79
+ if truncate:
80
+ tokens = tokens[:context_length]
81
+ tokens[-1] = eot_token
82
+ else:
83
+ raise RuntimeError(
84
+ f"Input {texts[i]} is too long for context length {context_length}"
85
+ )
86
+ result[i, :len(tokens)] = torch.tensor(tokens)
87
+
88
+ return result
89
+
90
+
91
+ def loads_pyarrow(buf):
92
+ """
93
+ Args:
94
+ buf: the output of `dumps`.
95
+ """
96
+ return pa.deserialize(buf)
97
+
98
+
99
+ class RefDataset(Dataset):
100
+ def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
101
+ word_length, args):
102
+ super(RefDataset, self).__init__()
103
+ self.lmdb_dir = lmdb_dir
104
+ self.mask_dir = mask_dir
105
+ self.dataset = dataset
106
+ self.split = split
107
+ self.mode = mode
108
+ self.input_size = (input_size, input_size)
109
+ self.word_length = word_length
110
+ self.mean = torch.tensor([0.48145466, 0.4578275,
111
+ 0.40821073]).reshape(3, 1, 1)
112
+ self.std = torch.tensor([0.26862954, 0.26130258,
113
+ 0.27577711]).reshape(3, 1, 1)
114
+ self.length = info[dataset][split]
115
+ self.env = None
116
+
117
+ self.exclude_position = args.exclude_pos
118
+ self.metric_learning = args.metric_learning
119
+ self.exclude_multiobj = args.exclude_multiobj
120
+ self.metric_mode = args.metric_mode
121
+
122
+ self.resize_bg1 = A.Compose([
123
+ A.Resize(input_size, input_size, always_apply=True)])
124
+ if self.metric_learning:
125
+ self.hardneg_prob = args.hn_prob # hard negative sampling probability
126
+ self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
127
+ self.hardpos_meta, self.hardneg_meta = self._load_metadata()
128
+ else:
129
+ self.hardneg_prob = 0.0
130
+ self.multi_obj_ref_ids = None
131
+ self.hardpos_meta, self.hardneg_meta = None, None
132
+
133
+ def _load_multi_obj_ref_ids(self):
134
+ # Load multi-object reference IDs based on configurations
135
+ if not self.exclude_multiobj and not self.exclude_position :
136
+ return None
137
+ elif self.exclude_position:
138
+ multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj_ov2_nopos.txt'
139
+ elif self.exclude_multiobj :
140
+ multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj_ov3.txt'
141
+ with open(multiobj_path, 'r') as f:
142
+ return [int(line.strip()) for line in f.readlines()]
143
+
144
+ def _load_metadata(self):
145
+ # Load metadata for hard positive verb phrases, hard negative queries
146
+ hardpos_path = '/data2/projects/chaeyun/VerbCentric_RIS/hardpos_verbphrase_0906upd.json'
147
+ hardneg_path = '/data2/projects/chaeyun/VerbCentric_RIS/hardneg_verb.json'
148
+
149
+ with open(hardpos_path, 'r', encoding='utf-8') as f:
150
+ hardpos_json = json.load(f)
151
+ if self.metric_mode == "hardpos_only" :
152
+ hardneg_json = None
153
+ else :
154
+ with open(hardneg_path, 'r', encoding='utf-8') as q:
155
+ hardneg_json = json.load(q)
156
+ return hardpos_json, hardneg_json
157
+
158
+
159
+ def _init_db(self):
160
+ self.env = lmdb.open(self.lmdb_dir,
161
+ subdir=os.path.isdir(self.lmdb_dir),
162
+ readonly=True,
163
+ lock=False,
164
+ readahead=False,
165
+ meminit=False)
166
+ with self.env.begin(write=False) as txn:
167
+ self.length = loads_pyarrow(txn.get(b'__len__'))
168
+ self.keys = loads_pyarrow(txn.get(b'__keys__'))
169
+
170
+ def __len__(self):
171
+ return self.length
172
+
173
+ def __getitem__(self, index):
174
+ # Delay loading LMDB data until after initialization: https://github.com/chainer/chainermn/issues/129
175
+ if self.env is None:
176
+ self._init_db()
177
+ env = self.env
178
+ with env.begin(write=False) as txn:
179
+ byteflow = txn.get(self.keys[index])
180
+ ref = loads_pyarrow(byteflow)
181
+ # img
182
+ ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
183
+ cv2.IMREAD_COLOR)
184
+ img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
185
+
186
+ # mask
187
+ seg_id = ref['seg_id']
188
+ mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png')
189
+
190
+ mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
191
+ cv2.IMREAD_GRAYSCALE)
192
+ mask = mask / 255.
193
+
194
+
195
+ # image resizing
196
+ resized = self.resize_bg1(image=img, mask=mask)
197
+ imgs, masks = [resized['image']], [resized['mask']]
198
+ img = imgs[0]
199
+ mask = masks[0]
200
+ mask = mask.astype(np.uint8)
201
+ mask[mask>0] = 1
202
+
203
+ # image transform
204
+ img_size = img.shape[:2]
205
+ mat, mat_inv = self.getTransformMat(img_size, True)
206
+ img = cv2.warpAffine(
207
+ img,
208
+ mat,
209
+ self.input_size,
210
+ flags=cv2.INTER_CUBIC,
211
+ borderValue=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255])
212
+
213
+ # sentences
214
+ sents = ref['sents']
215
+ n_sentences = ref['num_sents']
216
+
217
+ if self.mode == 'train':
218
+ # mask transform
219
+ mask = cv2.warpAffine(mask,
220
+ mat,
221
+ self.input_size,
222
+ flags=cv2.INTER_LINEAR,
223
+ borderValue=0.)
224
+
225
+ # if metric learning, assign hard positive verb phrase if applicable
226
+ idx = np.random.choice(n_sentences, 1, replace=False)[0]
227
+ sent = sents[idx]
228
+ raw_hardpos, hardpos = self._get_hardpos_verb(ref, seg_id, idx)
229
+ img, mask = self.convert(img, mask)
230
+ word_vec = tokenize(sent, self.word_length, True).squeeze(0)
231
+
232
+ if self.metric_mode == "hardpos_only" :
233
+ return img, word_vec, mask, hardpos
234
+
235
+ else :
236
+ choice = np.random.choice(['hn', 'no_hn'], p=[self.hardneg_prob, 1 - self.hardneg_prob])
237
+ if choice == 'hn' and raw_hardpos :
238
+ raw_hardneg, hardneg = self._get_hardneg_verb(ref, seg_id, idx)
239
+ else :
240
+ hardneg = torch.zeros(self.word_length, dtype=torch.long)
241
+ return img, word_vec, mask, hardpos, hardneg
242
+
243
+ elif self.mode == 'val':
244
+ # sentence -> vector
245
+ sent = sents[0]
246
+ word_vec = tokenize(sent, self.word_length, True).squeeze(0)
247
+ img = self.convert(img)[0]
248
+ params = {
249
+ 'mask_dir': mask_dir,
250
+ 'inverse': mat_inv,
251
+ 'ori_size': np.array(img_size)
252
+ }
253
+ return img, word_vec, mask, params
254
+ else:
255
+ # sentence -> vector
256
+ img = self.convert(img)[0]
257
+ params = {
258
+ 'ori_img': ori_img,
259
+ 'seg_id': seg_id,
260
+ 'mask_dir': mask_dir,
261
+ 'inverse': mat_inv,
262
+ 'ori_size': np.array(img_size),
263
+ 'sents': sents
264
+ }
265
+ return img, mask, params
266
+
267
+
268
+ def _get_hardneg_verb(self, ref, seg_id, sent_idx):
269
+ """
270
+ Handle the logic for selecting hard negative verb phrases during metric learning.
271
+ Returns the raw hard-negative phrase and its tokenized form if applicable.
272
+ """
273
+
274
+ # Extract metadata for hard negatives if present
275
+ hardneg_dict = self.hardneg_meta.get(str(seg_id), {})
276
+ sent_id_list = list(hardneg_dict.keys())
277
+
278
+ cur_hardneg = hardneg_dict.get(sent_id_list[sent_idx], [])
279
+ if cur_hardneg:
280
+ # Assign a hard negative verb phrase if available
281
+ raw_verb_hardneg = random.choice(cur_hardneg)
282
+ verb_hardneg = tokenize(raw_verb_hardneg, self.word_length, True).squeeze(0)
283
+ return raw_verb_hardneg, verb_hardneg
284
+
285
+ verb_hardneg = torch.zeros(self.word_length, dtype=torch.long)
286
+ return '', verb_hardneg
287
+
288
+
289
+
290
+ def _get_hardpos_verb(self, ref, seg_id, sent_idx):
291
+ """
292
+ Handle the logic for selecting hard positive verb phrases during metric learning.
293
+ Returns the raw verb phrase and its tokenized form if applicable.
294
+ """
295
+ # If the object appears multiple times, no hard positive is used
296
+ if seg_id in self.multi_obj_ref_ids:
297
+ verb_hardpos = torch.zeros(self.word_length, dtype=torch.long)
298
+ return '', verb_hardpos
299
+
300
+ # Extract metadata for hard positives if present
301
+ hardpos_dict = self.hardpos_meta.get(str(seg_id), {})
302
+ sent_id_list = list(hardpos_dict.keys())
303
+ # cur_hardpos = hardpos_dict.get(sent_id_list[sent_idx], [])
304
+ cur_hardpos = list(itertools.chain(*hardpos_dict.values()))
305
+ if cur_hardpos:
306
+ # Assign a hard positive verb phrase if available
307
+ raw_verb = random.choice(cur_hardpos)
308
+ verb_hardpos = tokenize(raw_verb, self.word_length, True).squeeze(0)
309
+ return raw_verb, verb_hardpos
310
+
311
+ verb_hardpos = torch.zeros(self.word_length, dtype=torch.long)
312
+ return '', verb_hardpos
313
+
314
+
315
+ def getTransformMat(self, img_size, inverse=False):
316
+ ori_h, ori_w = img_size
317
+ inp_h, inp_w = self.input_size
318
+ scale = min(inp_h / ori_h, inp_w / ori_w)
319
+ new_h, new_w = ori_h * scale, ori_w * scale
320
+ bias_x, bias_y = (inp_w - new_w) / 2., (inp_h - new_h) / 2.
321
+
322
+ src = np.array([[0, 0], [ori_w, 0], [0, ori_h]], np.float32)
323
+ dst = np.array([[bias_x, bias_y], [new_w + bias_x, bias_y],
324
+ [bias_x, new_h + bias_y]], np.float32)
325
+
326
+ mat = cv2.getAffineTransform(src, dst)
327
+ if inverse:
328
+ mat_inv = cv2.getAffineTransform(dst, src)
329
+ return mat, mat_inv
330
+ return mat, None
331
+
332
+ def convert(self, img, mask=None):
333
+ # Image ToTensor & Normalize
334
+ img = torch.from_numpy(img.transpose((2, 0, 1)))
335
+ if not isinstance(img, torch.FloatTensor):
336
+ img = img.float()
337
+ img.div_(255.).sub_(self.mean).div_(self.std)
338
+ # Mask ToTensor
339
+ if mask is not None:
340
+ mask = torch.from_numpy(mask)
341
+ if not isinstance(mask, torch.FloatTensor):
342
+ mask = mask.float()
343
+ return img, mask
344
+
345
+ def __repr__(self):
346
+ return self.__class__.__name__ + "(" + \
347
+ f"db_path={self.lmdb_dir}, " + \
348
+ f"dataset={self.dataset}, " + \
349
+ f"split={self.split}, " + \
350
+ f"mode={self.mode}, " + \
351
+ f"input_size={self.input_size}, " + \
352
+ f"word_length={self.word_length}"
353
+
354
+ # def get_length(self):
355
+ # return self.length
356
+
357
+ # def get_sample(self, idx):
358
+ # return self.__getitem__(idx)
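
This variant returns four items in train mode when `metric_mode == 'hardpos_only'` and five otherwise (the extra tensors hold the tokenized hard-positive phrase and, with probability `hn_prob`, a hard negative). A usage sketch under the same placeholder-path caveat as the previous file; `exclude_pos=True` is assumed so that the multi-object id list is actually loaded:

from types import SimpleNamespace
from torch.utils.data import DataLoader
from utils.dataset_verbonly import RefDataset

args = SimpleNamespace(exclude_pos=True, exclude_multiobj=False,
                       metric_learning=True, metric_mode='hardpos_only', hn_prob=0.5)

train_set = RefDataset(lmdb_dir='path/to/refcocog_u/train.lmdb',   # placeholder paths
                       mask_dir='path/to/refcocog_u/masks',
                       dataset='refcocog_u', split='train', mode='train',
                       input_size=416, word_length=17, args=args)
loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

for batch in loader:
    if args.metric_mode == 'hardpos_only':
        img, word_vec, mask, hardpos = batch           # hardpos: zeros when no verb phrase exists
    else:
        img, word_vec, mask, hardpos, hardneg = batch  # hardneg: zeros unless a hard negative was drawn
    break
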
utils/misc.py ADDED
@@ -0,0 +1,294 @@
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ from PIL import Image
5
+ from loguru import logger
6
+ import sys
7
+ import inspect
8
+
9
+ import torch
10
+ from torch import nn
11
+ import torch.distributed as dist
12
+
13
+
14
+ def init_random_seed(seed=None, device='cuda', rank=0, world_size=1):
15
+ """Initialize random seed."""
16
+ if seed is not None:
17
+ return seed
18
+
19
+ # Make sure all ranks share the same random seed to prevent
20
+ # some potential bugs. Please refer to
21
+ # https://github.com/open-mmlab/mmdetection/issues/6339
22
+ seed = np.random.randint(2**31)
23
+ if world_size == 1:
24
+ return seed
25
+
26
+ if rank == 0:
27
+ random_num = torch.tensor(seed, dtype=torch.int32, device=device)
28
+ else:
29
+ random_num = torch.tensor(0, dtype=torch.int32, device=device)
30
+ dist.broadcast(random_num, src=0)
31
+ return random_num.item()
32
+
33
+
34
+ def set_random_seed(seed, deterministic=False):
35
+ """Set random seed."""
36
+ random.seed(seed)
37
+ np.random.seed(seed)
38
+ torch.manual_seed(seed)
39
+ torch.cuda.manual_seed_all(seed)
40
+ if deterministic:
41
+ torch.backends.cudnn.deterministic = True
42
+ torch.backends.cudnn.benchmark = False
43
+
44
+
45
+ @torch.no_grad()
46
+ def concat_all_gather(tensor):
47
+ """
48
+ Performs all_gather operation on the provided tensors.
49
+ *** Warning ***: torch.distributed.all_gather has no gradient.
50
+ """
51
+ tensor = tensor.contiguous()
52
+ tensors_gather = [
53
+ torch.ones_like(tensor)
54
+ for _ in range(torch.distributed.get_world_size())
55
+ ]
56
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
57
+
58
+ output = torch.cat(tensors_gather, dim=0)
59
+ return output
60
+
61
+
62
+ def worker_init_fn(worker_id, num_workers, rank, seed):
63
+ # The seed of each worker equals to
64
+ # num_worker * rank + worker_id + user_seed
65
+ worker_seed = num_workers * rank + worker_id + seed
66
+ np.random.seed(worker_seed)
67
+ random.seed(worker_seed)
68
+
69
+
70
+ class AverageMeter(object):
71
+ """Computes and stores the average and current value"""
72
+
73
+ def __init__(self, name, fmt=":f"):
74
+ self.name = name
75
+ self.fmt = fmt
76
+ self.reset()
77
+
78
+ def reset(self):
79
+ self.val = 0
80
+ self.avg = 0
81
+ self.sum = 0
82
+ self.count = 0
83
+
84
+ def update(self, val, n=1):
85
+ self.val = val
86
+ self.sum += val * n
87
+ self.count += n
88
+ self.avg = self.sum / self.count
89
+
90
+ def __str__(self):
91
+ if self.name == "Lr":
92
+ fmtstr = "{name}={val" + self.fmt + "}"
93
+ else:
94
+ fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})"
95
+ return fmtstr.format(**self.__dict__)
96
+
97
+
98
+ class ProgressMeter(object):
99
+ def __init__(self, num_batches, meters, prefix=""):
100
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
101
+ self.meters = meters
102
+ self.prefix = prefix
103
+
104
+ def display(self, batch):
105
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
106
+ entries += [str(meter) for meter in self.meters]
107
+ logger.info(" ".join(entries))
108
+
109
+ def _get_batch_fmtstr(self, num_batches):
110
+ num_digits = len(str(num_batches // 1))
111
+ fmt = "{:" + str(num_digits) + "d}"
112
+ return "[" + fmt + "/" + fmt.format(num_batches) + "]"
113
+
114
+
115
+ def trainMetricGPU(output, target, threshold=0.35, pr_iou=0.5):
116
+ assert (output.dim() in [2, 3, 4])
117
+ assert output.shape == target.shape
118
+ output = output.flatten(1)
119
+ target = target.flatten(1)
120
+ output = torch.sigmoid(output)
121
+ output[output < threshold] = 0.
122
+ output[output >= threshold] = 1.
123
+ # inter & union
124
+ inter = (output.bool() & target.bool()).sum(dim=1) # b
125
+ union = (output.bool() | target.bool()).sum(dim=1) # b
126
+ ious = inter / (union + 1e-6) # 0 ~ 1
127
+ # iou & pr@5
128
+ iou = ious.mean()
129
+ prec = (ious > pr_iou).float().mean()
130
+ return 100. * iou, 100. * prec
131
+
132
+
133
+ def ValMetricGPU(output, target, threshold=0.35):
134
+ assert output.size(0) == 1
135
+ output = output.flatten(1)
136
+ target = target.flatten(1)
137
+ output = torch.sigmoid(output)
138
+ output[output < threshold] = 0.
139
+ output[output >= threshold] = 1.
140
+ # inter & union
141
+ inter = (output.bool() & target.bool()).sum(dim=1) # b
142
+ union = (output.bool() | target.bool()).sum(dim=1) # b
143
+ ious = inter / (union + 1e-6) # 0 ~ 1
144
+ return ious
145
+
146
+
147
+ def intersectionAndUnionGPU(output, target, K, threshold=0.5):
148
+ # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
149
+ assert (output.dim() in [1, 2, 3])
150
+ assert output.shape == target.shape
151
+ output = output.view(-1)
152
+ target = target.view(-1)
153
+
154
+ output = torch.sigmoid(output)
155
+ output[output < threshold] = 0.
156
+ output[output >= threshold] = 1.
157
+
158
+ intersection = output[output == target]
159
+ area_intersection = torch.histc(intersection.float(),
160
+ bins=K,
161
+ min=0,
162
+ max=K - 1)
163
+ area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1)
164
+ area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1)
165
+ area_union = area_output + area_target - area_intersection
166
+ return area_intersection[1], area_union[1]
167
+
168
+
169
+ def group_weight(weight_group, module, lr):
170
+ group_decay = []
171
+ group_no_decay = []
172
+ for m in module.modules():
173
+ if isinstance(m, nn.Linear):
174
+ group_decay.append(m.weight)
175
+ if m.bias is not None:
176
+ group_no_decay.append(m.bias)
177
+ elif isinstance(m, nn.modules.conv._ConvNd):
178
+ group_decay.append(m.weight)
179
+ if m.bias is not None:
180
+ group_no_decay.append(m.bias)
181
+ elif isinstance(m, nn.modules.batchnorm._BatchNorm):
182
+ if m.weight is not None:
183
+ group_no_decay.append(m.weight)
184
+ if m.bias is not None:
185
+ group_no_decay.append(m.bias)
186
+ assert len(list(
187
+ module.parameters())) == len(group_decay) + len(group_no_decay)
188
+ weight_group.append(dict(params=group_decay, lr=lr))
189
+ weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
190
+ return weight_group
191
+
192
+
193
+ def colorize(gray, palette):
194
+ # gray: numpy array of the label and 1*3N size list palette
195
+ color = Image.fromarray(gray.astype(np.uint8)).convert('P')
196
+ color.putpalette(palette)
197
+ return color
198
+
199
+
200
+ def find_free_port():
201
+ import socket
202
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
203
+ # Binding to port 0 will cause the OS to find an available port for us
204
+ sock.bind(("", 0))
205
+ port = sock.getsockname()[1]
206
+ sock.close()
207
+ # NOTE: there is still a chance the port could be taken by other processes.
208
+ return port
209
+
210
+
211
+ def get_caller_name(depth=0):
212
+ """
213
+ Args:
214
+ depth (int): Depth of caller context, use 0 for caller depth.
215
+ Default value: 0.
216
+
217
+ Returns:
218
+ str: module name of the caller
219
+ """
220
+ # the following logic is a little bit faster than inspect.stack() logic
221
+ frame = inspect.currentframe().f_back
222
+ for _ in range(depth):
223
+ frame = frame.f_back
224
+
225
+ return frame.f_globals["__name__"]
226
+
227
+
228
+ class StreamToLoguru:
229
+ """
230
+ stream object that redirects writes to a logger instance.
231
+ """
232
+ def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
233
+ """
234
+ Args:
235
+ level(str): log level string of loguru. Default value: "INFO".
236
+ caller_names(tuple): caller names of redirected module.
237
+ Default value: (apex, pycocotools).
238
+ """
239
+ self.level = level
240
+ self.linebuf = ""
241
+ self.caller_names = caller_names
242
+
243
+ def write(self, buf):
244
+ full_name = get_caller_name(depth=1)
245
+ module_name = full_name.rsplit(".", maxsplit=-1)[0]
246
+ if module_name in self.caller_names:
247
+ for line in buf.rstrip().splitlines():
248
+ # use caller level log
249
+ logger.opt(depth=2).log(self.level, line.rstrip())
250
+ else:
251
+ sys.__stdout__.write(buf)
252
+
253
+ def flush(self):
254
+ pass
255
+
256
+
257
+ def redirect_sys_output(log_level="INFO"):
258
+ redirect_logger = StreamToLoguru(log_level)
259
+ sys.stderr = redirect_logger
260
+ sys.stdout = redirect_logger
261
+
262
+
263
+ def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
264
+ """setup logger for training and testing.
265
+ Args:
266
+ save_dir(str): location to save log file
267
+ distributed_rank(int): device rank when multi-gpu environment
268
+ filename (string): log save name.
269
+ mode(str): log file write mode, `append` or `override`. default is `a`.
270
+
271
+ Return:
272
+ logger instance.
273
+ """
274
+ loguru_format = (
275
+ "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
276
+ "<level>{level: <8}</level> | "
277
+ "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")
278
+
279
+ logger.remove()
280
+ save_file = os.path.join(save_dir, filename)
281
+ if mode == "o" and os.path.exists(save_file):
282
+ os.remove(save_file)
283
+ # only keep logger in rank0 process
284
+ if distributed_rank == 0:
285
+ logger.add(
286
+ sys.stderr,
287
+ format=loguru_format,
288
+ level="INFO",
289
+ enqueue=True,
290
+ )
291
+ logger.add(save_file)
292
+
293
+ # redirect stdout/stderr to loguru
294
+ redirect_sys_output("INFO")
utils/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
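
A round-trip sketch for the tokenizer, together with the fixed-length wrapper defined in utils/dataset.py (the example sentence is arbitrary):

from utils.simple_tokenizer import SimpleTokenizer
from utils.dataset import tokenize

tok = SimpleTokenizer()                       # uses bpe_simple_vocab_16e6.txt.gz shipped in this folder
ids = tok.encode("the man riding a horse")    # raw BPE ids, no start/end tokens
print(tok.decode(ids))                        # "the man riding a horse " (each '</w>' becomes a space)

# tokenize() adds <|startoftext|>/<|endoftext|> and zero-pads to a fixed context length
word_vec = tokenize("the man riding a horse", context_length=17, truncate=True)
print(word_vec.shape)                         # torch.Size([1, 17])
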