Upload folder using huggingface_hub
- utils/.gitignore +1 -0
- utils/__init__.py +4 -0
- utils/bpe_simple_vocab_16e6.txt.gz +3 -0
- utils/config.py +157 -0
- utils/dataset.py +341 -0
- utils/dataset_verbonly.py +358 -0
- utils/misc.py +294 -0
- utils/simple_tokenizer.py +132 -0
utils/.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__/
utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .simple_tokenizer import SimpleTokenizer
from .config import *
from .dataset import *
from .misc import *
utils/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
utils/config.py
ADDED
@@ -0,0 +1,157 @@
# -----------------------------------------------------------------------------
# Functions for parsing args
# -----------------------------------------------------------------------------
import copy
import os
from ast import literal_eval

import yaml


class CfgNode(dict):
    """
    CfgNode represents an internal node in the configuration tree. It's a simple
    dict-like container that allows for attribute-based access to keys.
    """
    def __init__(self, init_dict=None, key_list=None, new_allowed=False):
        # Recursively convert nested dictionaries in init_dict into CfgNodes
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        for k, v in init_dict.items():
            if type(v) is dict:
                # Convert dict to CfgNode
                init_dict[k] = CfgNode(v, key_list=key_list + [k])
        super(CfgNode, self).__init__(init_dict)

    def __getattr__(self, name):
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __str__(self):
        def _indent(s_, num_spaces):
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        r = ""
        s = []
        for k, v in sorted(self.items()):
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            attr_str = _indent(attr_str, 2)
            s.append(attr_str)
        r += "\n".join(s)
        return r

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__,
                               super(CfgNode, self).__repr__())


def load_cfg_from_cfg_file(file):
    cfg = {}
    assert os.path.isfile(file) and file.endswith('.yaml'), \
        '{} is not a yaml file'.format(file)

    with open(file, 'r') as f:
        cfg_from_file = yaml.safe_load(f)

    # Flatten the top-level YAML groups into a single key space
    for key in cfg_from_file:
        for k, v in cfg_from_file[key].items():
            cfg[k] = v

    cfg = CfgNode(cfg)
    return cfg


def merge_cfg_from_list(cfg, cfg_list):
    new_cfg = copy.deepcopy(cfg)
    assert len(cfg_list) % 2 == 0
    for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
        subkey = full_key.split('.')[-1]
        assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
        value = _decode_cfg_value(v)
        value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey,
                                                 full_key)
        setattr(new_cfg, subkey, value)

    return new_cfg


def _decode_cfg_value(v):
    """Decodes a raw config value (e.g., from a yaml config file or command
    line argument) into a Python object.
    """
    # All remaining processing is only applied to strings
    if not isinstance(v, str):
        return v
    # Try to interpret `v` as a:
    # string, number, tuple, list, dict, boolean, or None
    try:
        v = literal_eval(v)
    # The following two excepts allow v to pass through when it represents a
    # string.
    #
    # Longer explanation:
    # The type of v is always a string (before calling literal_eval), but
    # sometimes it *represents* a string and other times a data structure, like
    # a list. In the case that v represents a string, what we got back from the
    # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
    # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
    # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
    # will raise a SyntaxError.
    except ValueError:
        pass
    except SyntaxError:
        pass
    return v


def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
    """Checks that `replacement`, which is intended to replace `original`, is of
    the right type. The type is correct if it matches exactly or is one of a few
    cases in which the type can be easily coerced.
    """
    original_type = type(original)
    replacement_type = type(replacement)

    # The types must match (with some exceptions)
    if replacement_type == original_type:
        return replacement

    # Cast replacement from from_type to to_type if the replacement and original
    # types match from_type and to_type
    def conditional_cast(from_type, to_type):
        if replacement_type == from_type and original_type == to_type:
            return True, to_type(replacement)
        else:
            return False, None

    # Conditionally casts
    # list <-> tuple
    casts = [(tuple, list), (list, tuple)]
    # For py2: allow converting from str (bytes) to a unicode string
    try:
        casts.append((str, unicode))  # noqa: F821
    except Exception:
        pass

    for (from_type, to_type) in casts:
        converted, converted_value = conditional_cast(from_type, to_type)
        if converted:
            return converted_value

    raise ValueError(
        "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
        "key: {}".format(original_type, replacement_type, original,
                         replacement, full_key))
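
A minimal usage sketch for the two public helpers above. The YAML filename, the group name `TRAIN`, and the keys `lr` / `batch_size` are hypothetical; the override keys must already exist in the flattened config, since merge_cfg_from_list asserts on unknown keys and coerces override values to the original types.

# Hypothetical example: load a YAML config, then apply CLI-style overrides.
# Assumes config/train.yaml has a top-level group (e.g. TRAIN) whose leaf
# keys include 'lr' (float) and 'batch_size' (int).
from utils.config import load_cfg_from_cfg_file, merge_cfg_from_list

cfg = load_cfg_from_cfg_file('config/train.yaml')
# Overrides are flat [key, value, key, value, ...] pairs; values are strings
# that _decode_cfg_value parses with ast.literal_eval. Only the last dotted
# component is used as the lookup key.
cfg = merge_cfg_from_list(cfg, ['TRAIN.lr', '0.0001', 'TRAIN.batch_size', '32'])
print(cfg.lr, cfg.batch_size)  # attribute access via CfgNode.__getattr__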
utils/dataset.py
ADDED
@@ -0,0 +1,341 @@
#%%
import os
from typing import List, Union
import json
import cv2
import lmdb
import random
import numpy as np
import pyarrow as pa
import torch
from torch.utils.data import Dataset
import itertools
import albumentations as A
from albumentations.pytorch import ToTensorV2
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

info = {
    'refcoco': {
        'train': 42404,
        'val': 3811,
        'val-test': 3811,
        'testA': 1975,
        'testB': 1810
    },
    'refcoco+': {
        'train': 42278,
        'val': 3805,
        'val-test': 3805,
        'testA': 1975,
        'testB': 1798
    },
    'refcocog_u': {
        'train': 42226,
        'val': 2573,
        'val-test': 2573,
        'test': 5023,
        'test_0-5_verb': 572,
        'test_0-5_static': 1688,
        'test_6-7_verb': 949,
        'test_6-7_static': 1240,
        'test_8-10_verb': 1523,
        'test_8-10_static': 1194,
        'test_11-20_verb': 1768,
        'test_11-20_static': 584,
        'test_abl_motion': 267,
        'test_abl_static': 267
    },
    'refcocog_g': {
        'train': 44822,
        'val': 5000,
        'val-test': 5000
    }
}
_tokenizer = _Tokenizer()

#%%
def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate : bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


def loads_pyarrow(buf):
    """
    Args:
        buf: the output of `dumps`.
    """
    return pa.deserialize(buf)


class RefDataset(Dataset):
    def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
                 word_length, args):
        super(RefDataset, self).__init__()
        self.lmdb_dir = lmdb_dir
        self.mask_dir = mask_dir
        self.dataset = dataset
        self.split = split
        self.mode = mode
        self.input_size = (input_size, input_size)
        self.word_length = word_length
        self.mean = torch.tensor([0.48145466, 0.4578275,
                                  0.40821073]).reshape(3, 1, 1)
        self.std = torch.tensor([0.26862954, 0.26130258,
                                 0.27577711]).reshape(3, 1, 1)
        self.length = info[dataset][split]
        self.env = None
        self.exclude_position = args.exclude_pos
        self.metric_learning = args.metric_learning
        self.hardpos_rigid = args.hardpos_rigid
        self.resize_bg1 = A.Compose([
            A.Resize(input_size, input_size, always_apply=True)])
        if self.metric_learning:
            if self.hardpos_rigid and self.exclude_position:
                multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj_nopos.txt'
                with open(multiobj_path, 'r') as f:
                    self.multi_obj_ref_ids = [int(line.strip()) for line in f.readlines()]
            elif self.hardpos_rigid:
                multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj.txt'
                with open(multiobj_path, 'r') as f:
                    self.multi_obj_ref_ids = [int(line.strip()) for line in f.readlines()]
            else:
                self.multi_obj_ref_ids = None

            path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/llama3-demo/llama3/hardpos_verbphrase_0906upd.json'
            with open(path, 'r', encoding='utf-8') as f:
                self.metadata = json.load(f)
        else:
            self.metadata = None

    def _init_db(self):
        self.env = lmdb.open(self.lmdb_dir,
                             subdir=os.path.isdir(self.lmdb_dir),
                             readonly=True,
                             lock=False,
                             readahead=False,
                             meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Delay loading LMDB data until after initialization: https://github.com/chainer/chainermn/issues/129
        if self.env is None:
            self._init_db()
        env = self.env
        with env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        ref = loads_pyarrow(byteflow)
        # img
        ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
                               cv2.IMREAD_COLOR)
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)

        # mask
        seg_id = ref['seg_id']
        mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png')

        mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
                            cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.

        # image resizing
        resized = self.resize_bg1(image=img, mask=mask)
        imgs, masks = [resized['image']], [resized['mask']]
        img = imgs[0]
        mask = masks[0]
        mask = mask.astype(np.uint8)
        mask[mask > 0] = 1

        # image transform
        img_size = img.shape[:2]
        mat, mat_inv = self.getTransformMat(img_size, True)
        img = cv2.warpAffine(
            img,
            mat,
            self.input_size,
            flags=cv2.INTER_CUBIC,
            borderValue=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255])

        # sentences
        sents = ref['sents']
        n_sentences = ref['num_sents']

        if self.mode == 'train':
            # mask transform
            mask = cv2.warpAffine(mask,
                                  mat,
                                  self.input_size,
                                  flags=cv2.INTER_LINEAR,
                                  borderValue=0.)

            # if metric learning, select 2 positive sentences
            if self.metric_learning:
                if self.hardpos_rigid and seg_id in self.multi_obj_ref_ids:
                    if n_sentences > 1:
                        idx = np.random.choice(ref['num_sents'], 2, replace=False)
                        sent = [sents[i] for i in idx]
                    else:
                        sent = [sents[0], sents[0]]
                else:
                    # Process hard-positive metadata for this reference
                    hardpos_dict = self.metadata[str(ref['seg_id'])]
                    hardpos_list = list(itertools.chain(*hardpos_dict.values()))
                    sent_id_list = list(hardpos_dict.keys())

                    if n_sentences > 1:
                        if self.hardpos_rigid:
                            idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
                            cur_hardpos = hardpos_dict[sent_id_list[idx]]
                            if len(cur_hardpos) == 0:
                                idx = np.random.choice(ref['num_sents'], 2, replace=False)
                                sent = [sents[i] for i in idx]
                            else:
                                hardpos_choice = random.choice(cur_hardpos)
                                sent = [sents[idx], hardpos_choice]
                                random.shuffle(sent)
                        else:
                            if len(hardpos_list) == 0:
                                idx = np.random.choice(ref['num_sents'], 2, replace=False)
                                sent = [sents[i] for i in idx]
                            else:
                                idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
                                hardpos_choice = random.choice(hardpos_list)
                                sent = [sents[idx], hardpos_choice]
                                random.shuffle(sent)
                    # if there's only one sentence, duplicate it
                    else:
                        if len(hardpos_list) == 0:
                            sent = [sents[0], sents[0]]
                        else:
                            hardpos_choice = random.choice(hardpos_list)
                            sent = [sents[0], hardpos_choice]
                            random.shuffle(sent)
                    # print(f"Generated sentences: {sent}")
            else:
                # take a scalar index so plain-list indexing works
                idx = np.random.choice(ref['num_sents'], 1, replace=False)[0]
                sent = sents[idx]
            word_vec = tokenize(sent, self.word_length, True).squeeze(0)
            img, mask = self.convert(img, mask)

            # params = {
            #     'ori_img': ori_img,
            #     'seg_id': seg_id,
            #     'mask_dir': mask_dir,
            #     'inverse': mat_inv,
            #     'ori_size': np.array(img_size),
            #     'sents': sents
            # }
            return img, word_vec, mask

        elif self.mode == 'val':
            # sentence -> vector
            sent = sents[0]
            word_vec = tokenize(sent, self.word_length, True).squeeze(0)
            img = self.convert(img)[0]
            params = {
                'mask_dir': mask_dir,
                'inverse': mat_inv,
                'ori_size': np.array(img_size)
            }
            return img, word_vec, mask, params
        else:
            # sentence -> vector
            img = self.convert(img)[0]
            params = {
                'ori_img': ori_img,
                'seg_id': seg_id,
                'mask_dir': mask_dir,
                'inverse': mat_inv,
                'ori_size': np.array(img_size),
                'sents': sents
            }
            return img, mask, params

    def getTransformMat(self, img_size, inverse=False):
        ori_h, ori_w = img_size
        inp_h, inp_w = self.input_size
        scale = min(inp_h / ori_h, inp_w / ori_w)
        new_h, new_w = ori_h * scale, ori_w * scale
        bias_x, bias_y = (inp_w - new_w) / 2., (inp_h - new_h) / 2.

        src = np.array([[0, 0], [ori_w, 0], [0, ori_h]], np.float32)
        dst = np.array([[bias_x, bias_y], [new_w + bias_x, bias_y],
                        [bias_x, new_h + bias_y]], np.float32)

        mat = cv2.getAffineTransform(src, dst)
        if inverse:
            mat_inv = cv2.getAffineTransform(dst, src)
            return mat, mat_inv
        return mat, None

    def convert(self, img, mask=None):
        # Image ToTensor & Normalize
        img = torch.from_numpy(img.transpose((2, 0, 1)))
        if not isinstance(img, torch.FloatTensor):
            img = img.float()
        img.div_(255.).sub_(self.mean).div_(self.std)
        # Mask ToTensor
        if mask is not None:
            mask = torch.from_numpy(mask)
            if not isinstance(mask, torch.FloatTensor):
                mask = mask.float()
        return img, mask

    def __repr__(self):
        return self.__class__.__name__ + "(" + \
            f"db_path={self.lmdb_dir}, " + \
            f"dataset={self.dataset}, " + \
            f"split={self.split}, " + \
            f"mode={self.mode}, " + \
            f"input_size={self.input_size}, " + \
            f"word_length={self.word_length}"

    # def get_length(self):
    #     return self.length

    # def get_sample(self, idx):
    #     return self.__getitem__(idx)
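
A sketch of how RefDataset might be wired into a DataLoader. The LMDB/mask paths, input_size, word_length, and the args fields are placeholders; the constructor above reads exclude_pos, metric_learning, and hardpos_rigid from args, and with metric_learning disabled no metadata files are needed.

# Hypothetical wiring; paths and args values are placeholders.
from types import SimpleNamespace
from torch.utils.data import DataLoader
from utils.dataset import RefDataset

args = SimpleNamespace(exclude_pos=False, metric_learning=False,
                       hardpos_rigid=False)
train_set = RefDataset(lmdb_dir='datasets/lmdb/refcoco/train.lmdb',
                       mask_dir='datasets/masks/refcoco',
                       dataset='refcoco', split='train', mode='train',
                       input_size=416, word_length=17, args=args)
# In train mode each item is (img, word_vec, mask). The LMDB env is opened
# lazily inside __getitem__, so num_workers > 0 is safe.
loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
img, word_vec, mask = next(iter(loader))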
utils/dataset_verbonly.py
ADDED
@@ -0,0 +1,358 @@
#%%
import os
from typing import List, Union
import json
import cv2
import lmdb
import random
import numpy as np
import pyarrow as pa
import torch
from torch.utils.data import Dataset
import itertools
import albumentations as A
from albumentations.pytorch import ToTensorV2
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

info = {
    'refcoco': {
        'train': 42404,
        'val': 3811,
        'val-test': 3811,
        'testA': 1975,
        'testB': 1810
    },
    'refcoco+': {
        'train': 42278,
        'val': 3805,
        'val-test': 3805,
        'testA': 1975,
        'testB': 1798
    },
    'refcocog_u': {
        'train': 42226,
        'val': 2573,
        'val-test': 2573,
        'test': 5023,
    },
    'refcocog_g': {
        'train': 44822,
        'val': 5000,
        'val-test': 5000
    }
}
_tokenizer = _Tokenizer()

#%%
def tokenize(texts: Union[str, List[str]],
             context_length: int = 77,
             truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate : bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token]
                  for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(
                    f"Input {texts[i]} is too long for context length {context_length}"
                )
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


def loads_pyarrow(buf):
    """
    Args:
        buf: the output of `dumps`.
    """
    return pa.deserialize(buf)


class RefDataset(Dataset):
    def __init__(self, lmdb_dir, mask_dir, dataset, split, mode, input_size,
                 word_length, args):
        super(RefDataset, self).__init__()
        self.lmdb_dir = lmdb_dir
        self.mask_dir = mask_dir
        self.dataset = dataset
        self.split = split
        self.mode = mode
        self.input_size = (input_size, input_size)
        self.word_length = word_length
        self.mean = torch.tensor([0.48145466, 0.4578275,
                                  0.40821073]).reshape(3, 1, 1)
        self.std = torch.tensor([0.26862954, 0.26130258,
                                 0.27577711]).reshape(3, 1, 1)
        self.length = info[dataset][split]
        self.env = None

        self.exclude_position = args.exclude_pos
        self.metric_learning = args.metric_learning
        self.exclude_multiobj = args.exclude_multiobj
        self.metric_mode = args.metric_mode

        self.resize_bg1 = A.Compose([
            A.Resize(input_size, input_size, always_apply=True)])
        if self.metric_learning:
            self.hardneg_prob = args.hn_prob  # hard negative sampling probability
            self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
            self.hardpos_meta, self.hardneg_meta = self._load_metadata()
        else:
            self.hardneg_prob = 0.0
            self.multi_obj_ref_ids = None
            self.hardpos_meta, self.hardneg_meta = None, None

    def _load_multi_obj_ref_ids(self):
        # Load multi-object reference IDs based on configurations
        if not self.exclude_multiobj and not self.exclude_position:
            return None
        elif self.exclude_position:
            multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj_ov2_nopos.txt'
        elif self.exclude_multiobj:
            multiobj_path = '/home/chaeyun/data/projects/chaeyun/RIS/CRIS.pytorch/multiobj_ov3.txt'
        with open(multiobj_path, 'r') as f:
            return [int(line.strip()) for line in f.readlines()]

    def _load_metadata(self):
        # Load metadata for hard positive verb phrases, hard negative queries
        hardpos_path = '/data2/projects/chaeyun/VerbCentric_RIS/hardpos_verbphrase_0906upd.json'
        hardneg_path = '/data2/projects/chaeyun/VerbCentric_RIS/hardneg_verb.json'

        with open(hardpos_path, 'r', encoding='utf-8') as f:
            hardpos_json = json.load(f)
        if self.metric_mode == "hardpos_only":
            hardneg_json = None
        else:
            with open(hardneg_path, 'r', encoding='utf-8') as q:
                hardneg_json = json.load(q)
        return hardpos_json, hardneg_json

    def _init_db(self):
        self.env = lmdb.open(self.lmdb_dir,
                             subdir=os.path.isdir(self.lmdb_dir),
                             readonly=True,
                             lock=False,
                             readahead=False,
                             meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Delay loading LMDB data until after initialization: https://github.com/chainer/chainermn/issues/129
        if self.env is None:
            self._init_db()
        env = self.env
        with env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        ref = loads_pyarrow(byteflow)
        # img
        ori_img = cv2.imdecode(np.frombuffer(ref['img'], np.uint8),
                               cv2.IMREAD_COLOR)
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)

        # mask
        seg_id = ref['seg_id']
        mask_dir = os.path.join(self.mask_dir, str(seg_id) + '.png')

        mask = cv2.imdecode(np.frombuffer(ref['mask'], np.uint8),
                            cv2.IMREAD_GRAYSCALE)
        mask = mask / 255.

        # image resizing
        resized = self.resize_bg1(image=img, mask=mask)
        imgs, masks = [resized['image']], [resized['mask']]
        img = imgs[0]
        mask = masks[0]
        mask = mask.astype(np.uint8)
        mask[mask > 0] = 1

        # image transform
        img_size = img.shape[:2]
        mat, mat_inv = self.getTransformMat(img_size, True)
        img = cv2.warpAffine(
            img,
            mat,
            self.input_size,
            flags=cv2.INTER_CUBIC,
            borderValue=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255])

        # sentences
        sents = ref['sents']
        n_sentences = ref['num_sents']

        if self.mode == 'train':
            # mask transform
            mask = cv2.warpAffine(mask,
                                  mat,
                                  self.input_size,
                                  flags=cv2.INTER_LINEAR,
                                  borderValue=0.)

            # if metric learning, assign hard positive verb phrase if applicable
            idx = np.random.choice(n_sentences, 1, replace=False)[0]
            sent = sents[idx]
            raw_hardpos, hardpos = self._get_hardpos_verb(ref, seg_id, idx)
            img, mask = self.convert(img, mask)
            word_vec = tokenize(sent, self.word_length, True).squeeze(0)

            if self.metric_mode == "hardpos_only":
                return img, word_vec, mask, hardpos
            else:
                choice = np.random.choice(['hn', 'no_hn'], p=[self.hardneg_prob, 1 - self.hardneg_prob])
                if choice == 'hn' and raw_hardpos:
                    raw_hardneg, hardneg = self._get_hardneg_verb(ref, seg_id, idx)
                else:
                    hardneg = torch.zeros(self.word_length, dtype=torch.long)
                return img, word_vec, mask, hardpos, hardneg

        elif self.mode == 'val':
            # sentence -> vector
            sent = sents[0]
            word_vec = tokenize(sent, self.word_length, True).squeeze(0)
            img = self.convert(img)[0]
            params = {
                'mask_dir': mask_dir,
                'inverse': mat_inv,
                'ori_size': np.array(img_size)
            }
            return img, word_vec, mask, params
        else:
            # sentence -> vector
            img = self.convert(img)[0]
            params = {
                'ori_img': ori_img,
                'seg_id': seg_id,
                'mask_dir': mask_dir,
                'inverse': mat_inv,
                'ori_size': np.array(img_size),
                'sents': sents
            }
            return img, mask, params

    def _get_hardneg_verb(self, ref, seg_id, sent_idx):
        """
        Handle the logic for selecting hard negative verb phrases during metric learning.
        Returns the raw phrase and the tokenized phrase if applicable.
        """
        # Extract metadata for hard negatives if present
        hardneg_dict = self.hardneg_meta.get(str(seg_id), {})
        sent_id_list = list(hardneg_dict.keys())

        cur_hardneg = hardneg_dict.get(sent_id_list[sent_idx], [])
        if cur_hardneg:
            # Assign a hard negative verb phrase if available
            raw_verb_hardneg = random.choice(cur_hardneg)
            verb_hardneg = tokenize(raw_verb_hardneg, self.word_length, True).squeeze(0)
            return raw_verb_hardneg, verb_hardneg

        verb_hardneg = torch.zeros(self.word_length, dtype=torch.long)
        return '', verb_hardneg

    def _get_hardpos_verb(self, ref, seg_id, sent_idx):
        """
        Handle the logic for selecting hard positive verb phrases during metric learning.
        Returns the raw phrase and the tokenized phrase if applicable.
        """
        # If the object appears multiple times, no hard positive is used
        if seg_id in self.multi_obj_ref_ids:
            verb_hardpos = torch.zeros(self.word_length, dtype=torch.long)
            return '', verb_hardpos

        # Extract metadata for hard positives if present
        hardpos_dict = self.hardpos_meta.get(str(seg_id), {})
        sent_id_list = list(hardpos_dict.keys())
        # cur_hardpos = hardpos_dict.get(sent_id_list[sent_idx], [])
        cur_hardpos = list(itertools.chain(*hardpos_dict.values()))
        if cur_hardpos:
            # Assign a hard positive verb phrase if available
            raw_verb = random.choice(cur_hardpos)
            verb_hardpos = tokenize(raw_verb, self.word_length, True).squeeze(0)
            return raw_verb, verb_hardpos

        verb_hardpos = torch.zeros(self.word_length, dtype=torch.long)
        return '', verb_hardpos

    def getTransformMat(self, img_size, inverse=False):
        ori_h, ori_w = img_size
        inp_h, inp_w = self.input_size
        scale = min(inp_h / ori_h, inp_w / ori_w)
        new_h, new_w = ori_h * scale, ori_w * scale
        bias_x, bias_y = (inp_w - new_w) / 2., (inp_h - new_h) / 2.

        src = np.array([[0, 0], [ori_w, 0], [0, ori_h]], np.float32)
        dst = np.array([[bias_x, bias_y], [new_w + bias_x, bias_y],
                        [bias_x, new_h + bias_y]], np.float32)

        mat = cv2.getAffineTransform(src, dst)
        if inverse:
            mat_inv = cv2.getAffineTransform(dst, src)
            return mat, mat_inv
        return mat, None

    def convert(self, img, mask=None):
        # Image ToTensor & Normalize
        img = torch.from_numpy(img.transpose((2, 0, 1)))
        if not isinstance(img, torch.FloatTensor):
            img = img.float()
        img.div_(255.).sub_(self.mean).div_(self.std)
        # Mask ToTensor
        if mask is not None:
            mask = torch.from_numpy(mask)
            if not isinstance(mask, torch.FloatTensor):
                mask = mask.float()
        return img, mask

    def __repr__(self):
        return self.__class__.__name__ + "(" + \
            f"db_path={self.lmdb_dir}, " + \
            f"dataset={self.dataset}, " + \
            f"split={self.split}, " + \
            f"mode={self.mode}, " + \
            f"input_size={self.input_size}, " + \
            f"word_length={self.word_length}"

    # def get_length(self):
    #     return self.length

    # def get_sample(self, idx):
    #     return self.__getitem__(idx)
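
In train mode this variant returns extra token tensors: hardpos, plus hardneg unless metric_mode == 'hardpos_only'; a zero tensor of length word_length stands in when no phrase is available. Note that train mode always calls _get_hardpos_verb, which indexes self.multi_obj_ref_ids, so this dataset is evidently meant to run with args.metric_learning enabled. A sketch of how a training loop might filter those zero sentinels before a metric-learning term; text_encoder and contrastive_loss are placeholders, not part of the code above.

# Hypothetical training-loop fragment (placeholder encoder and loss).
img, word_vec, mask, hardpos, hardneg = batch
valid = hardpos.abs().sum(dim=1) > 0           # zero rows mean "no hard positive"
if valid.any():
    anchor = text_encoder(word_vec[valid])
    positive = text_encoder(hardpos[valid])
    negative = text_encoder(hardneg[valid])    # zero rows possible here as well
    loss_metric = contrastive_loss(anchor, positive, negative)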
utils/misc.py
ADDED
@@ -0,0 +1,294 @@
import os
import random
import numpy as np
from PIL import Image
from loguru import logger
import sys
import inspect

import torch
from torch import nn
import torch.distributed as dist


def init_random_seed(seed=None, device='cuda', rank=0, world_size=1):
    """Initialize random seed."""
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def set_random_seed(seed, deterministic=False):
    """Set random seed."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensor = tensor.contiguous()
    tensors_gather = [
        torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals
    # num_workers * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        if self.name == "Lr":
            fmtstr = "{name}={val" + self.fmt + "}"
        else:
            fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        logger.info(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def trainMetricGPU(output, target, threshold=0.35, pr_iou=0.5):
    assert (output.dim() in [2, 3, 4])
    assert output.shape == target.shape
    output = output.flatten(1)
    target = target.flatten(1)
    output = torch.sigmoid(output)
    output[output < threshold] = 0.
    output[output >= threshold] = 1.
    # inter & union
    inter = (output.bool() & target.bool()).sum(dim=1)  # b
    union = (output.bool() | target.bool()).sum(dim=1)  # b
    ious = inter / (union + 1e-6)  # 0 ~ 1
    # iou & pr@5
    iou = ious.mean()
    prec = (ious > pr_iou).float().mean()
    return 100. * iou, 100. * prec


def ValMetricGPU(output, target, threshold=0.35):
    assert output.size(0) == 1
    output = output.flatten(1)
    target = target.flatten(1)
    output = torch.sigmoid(output)
    output[output < threshold] = 0.
    output[output >= threshold] = 1.
    # inter & union
    inter = (output.bool() & target.bool()).sum(dim=1)  # b
    union = (output.bool() | target.bool()).sum(dim=1)  # b
    ious = inter / (union + 1e-6)  # 0 ~ 1
    return ious


def intersectionAndUnionGPU(output, target, K, threshold=0.5):
    # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
    assert (output.dim() in [1, 2, 3])
    assert output.shape == target.shape
    output = output.view(-1)
    target = target.view(-1)

    output = torch.sigmoid(output)
    output[output < threshold] = 0.
    output[output >= threshold] = 1.

    intersection = output[output == target]
    area_intersection = torch.histc(intersection.float(),
                                    bins=K,
                                    min=0,
                                    max=K - 1)
    area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1)
    area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection[1], area_union[1]


def group_weight(weight_group, module, lr):
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.modules.conv._ConvNd):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.modules.batchnorm._BatchNorm):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
    assert len(list(
        module.parameters())) == len(group_decay) + len(group_no_decay)
    weight_group.append(dict(params=group_decay, lr=lr))
    weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
    return weight_group


def colorize(gray, palette):
    # gray: numpy array of the label and 1*3N size list palette
    color = Image.fromarray(gray.astype(np.uint8)).convert('P')
    color.putpalette(palette)
    return color


def find_free_port():
    import socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def get_caller_name(depth=0):
    """
    Args:
        depth (int): Depth of caller context, use 0 for caller depth.
            Default value: 0.

    Returns:
        str: module name of the caller
    """
    # the following logic is a little bit faster than inspect.stack() logic
    frame = inspect.currentframe().f_back
    for _ in range(depth):
        frame = frame.f_back

    return frame.f_globals["__name__"]


class StreamToLoguru:
    """
    Stream object that redirects writes to a logger instance.
    """
    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): caller names of redirected modules.
                Default value: (apex, pycocotools).
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        full_name = get_caller_name(depth=1)
        module_name = full_name.rsplit(".", maxsplit=-1)[0]
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
                # use caller level log
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)

    def flush(self):
        pass


def redirect_sys_output(log_level="INFO"):
    redirect_logger = StreamToLoguru(log_level)
    sys.stderr = redirect_logger
    sys.stdout = redirect_logger


def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """Setup logger for training and testing.
    Args:
        save_dir(str): location to save log file
        distributed_rank(int): device rank in a multi-gpu environment
        filename (string): log save name.
        mode(str): log file write mode, `append` or `override`. default is `a`.

    Return:
        logger instance.
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")

    logger.remove()
    save_file = os.path.join(save_dir, filename)
    if mode == "o" and os.path.exists(save_file):
        os.remove(save_file)
    # only keep logger in rank0 process
    if distributed_rank == 0:
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # redirect stdout/stderr to loguru
    redirect_sys_output("INFO")
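
A short sketch of the seeding, logging, and metering utilities together. The save directory, seed, and meter values are placeholders; setup_logger writes into save_dir, so the directory must exist first.

# Hypothetical usage of the helpers above.
import os
from utils.misc import set_random_seed, setup_logger, AverageMeter, ProgressMeter

set_random_seed(42, deterministic=True)
os.makedirs('exp/logs', exist_ok=True)
setup_logger('exp/logs', distributed_rank=0, filename='train.log', mode='a')

iou_meter = AverageMeter('IoU', ':2.2f')
progress = ProgressMeter(num_batches=100, meters=[iou_meter], prefix='Epoch [1]')
iou_meter.update(73.5, n=32)   # batch IoU (%), weighted by batch size
progress.display(batch=10)     # logs something like "Epoch [1][ 10/100] IoU=73.50 (73.50)"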
utils/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
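
A quick roundtrip with the tokenizer. It needs ftfy and regex installed and the bpe_simple_vocab_16e6.txt.gz file (the LFS pointer above) next to simple_tokenizer.py. Because encode lower-cases and decode rebuilds spaces from the '</w>' markers, the roundtrip is whitespace-normalized rather than byte-identical.

# Roundtrip sketch.
from utils.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("A man riding a horse")
print(ids)               # BPE token ids (no SOT/EOT; dataset.tokenize adds those)
print(tok.decode(ids))   # roughly "a man riding a horse "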