Hasanmog committed on
Commit
76047a1
·
verified ·
1 Parent(s): 9cbeebc
Files changed (1)
  1. util/utils.py +222 -96
util/utils.py CHANGED
@@ -1,59 +1,154 @@
- from collections import OrderedDict
- from copy import deepcopy
  import json
  import warnings

- import torch
  import numpy as np

- def slprint(x, name='x'):
      if isinstance(x, (torch.Tensor, np.ndarray)):
-         print(f'{name}.shape:', x.shape)
      elif isinstance(x, (tuple, list)):
-         print('type x:', type(x))
          for i in range(min(10, len(x))):
-             slprint(x[i], f'{name}[{i}]')
      elif isinstance(x, dict):
-         for k,v in x.items():
-             slprint(v, f'{name}[{k}]')
      else:
-         print(f'{name}.type:', type(x))

  def clean_state_dict(state_dict):
      new_state_dict = OrderedDict()
      for k, v in state_dict.items():
-         if k[:7] == 'module.':
              k = k[7:]  # remove `module.`
          new_state_dict[k] = v
      return new_state_dict

- def renorm(img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) \
-         -> torch.FloatTensor:
      # img: tensor(3,H,W) or tensor(B,3,H,W)
      # return: same as img
-     assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
      if img.dim() == 3:
-         assert img.size(0) == 3, 'img.size(0) shoule be 3 but "%d". (%s)' % (img.size(0), str(img.size()))
-         img_perm = img.permute(1,2,0)
          mean = torch.Tensor(mean)
          std = torch.Tensor(std)
          img_res = img_perm * std + mean
-         return img_res.permute(2,0,1)
-     else: # img.dim() == 4
-         assert img.size(1) == 3, 'img.size(1) shoule be 3 but "%d". (%s)' % (img.size(1), str(img.size()))
-         img_perm = img.permute(0,2,3,1)
          mean = torch.Tensor(mean)
          std = torch.Tensor(std)
          img_res = img_perm * std + mean
-         return img_res.permute(0,3,1,2)
-


- class CocoClassMapper():
      def __init__(self) -> None:
-         self.category_map_str = {"1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "10": 10, "11": 11, "13": 12, "14": 13, "15": 14, "16": 15, "17": 16, "18": 17, "19": 18, "20": 19, "21": 20, "22": 21, "23": 22, "24": 23, "25": 24, "27": 25, "28": 26, "31": 27, "32": 28, "33": 29, "34": 30, "35": 31, "36": 32, "37": 33, "38": 34, "39": 35, "40": 36, "41": 37, "42": 38, "43": 39, "44": 40, "46": 41, "47": 42, "48": 43, "49": 44, "50": 45, "51": 46, "52": 47, "53": 48, "54": 49, "55": 50, "56": 51, "57": 52, "58": 53, "59": 54, "60": 55, "61": 56, "62": 57, "63": 58, "64": 59, "65": 60, "67": 61, "70": 62, "72": 63, "73": 64, "74": 65, "75": 66, "76": 67, "77": 68, "78": 69, "79": 70, "80": 71, "81": 72, "82": 73, "84": 74, "85": 75, "86": 76, "87": 77, "88": 78, "89": 79, "90": 80}
-         self.origin2compact_mapper = {int(k):v-1 for k,v in self.category_map_str.items()}
-         self.compact2origin_mapper = {int(v-1):int(k) for k,v in self.category_map_str.items()}

      def origin2compact(self, idx):
          return self.origin2compact_mapper[int(idx)]
@@ -61,19 +156,21 @@ class CocoClassMapper():
      def compact2origin(self, idx):
          return self.compact2origin_mapper[int(idx)]

  def to_device(item, device):
      if isinstance(item, torch.Tensor):
          return item.to(device)
      elif isinstance(item, list):
          return [to_device(i, device) for i in item]
      elif isinstance(item, dict):
-         return {k: to_device(v, device) for k,v in item.items()}
      else:
-         raise NotImplementedError("Call Shilong if you use other containers! type: {}".format(type(item)))
-


- #
  def get_gaussian_mean(x, axis, other_axis, softmax=True):
      """
@@ -99,6 +196,7 @@ def get_gaussian_mean(x, axis, other_axis, softmax=True):
      mean_position = torch.sum(index * u, dim=2)
      return mean_position

  def get_expected_points_from_map(hm, softmax=True):
      """get_gaussian_map_from_points
      B,C,H,W -> B,N,2 float(0, 1) float(0, 1)
@@ -107,71 +205,74 @@ def get_expected_points_from_map(hm, softmax=True):
      Args:
          hm (float): Input images(BxCxHxW)

-     Returns:
          weighted index for axis, BxCx2. float between 0 and 1.

      """
      # hm = 10*hm
-     B,C,H,W = hm.shape
-     y_mean = get_gaussian_mean(hm, 2, 3, softmax=softmax) # B,C
-     x_mean = get_gaussian_mean(hm, 3, 2, softmax=softmax) # B,C
      # return torch.cat((x_mean.unsqueeze(-1), y_mean.unsqueeze(-1)), 2)
      return torch.stack([x_mean, y_mean], dim=2)

  # Positional encoding (section 5.1)
  # borrow from nerf
  class Embedder:
      def __init__(self, **kwargs):
          self.kwargs = kwargs
          self.create_embedding_fn()
-
      def create_embedding_fn(self):
          embed_fns = []
-         d = self.kwargs['input_dims']
          out_dim = 0
-         if self.kwargs['include_input']:
-             embed_fns.append(lambda x : x)
              out_dim += d
-
-         max_freq = self.kwargs['max_freq_log2']
-         N_freqs = self.kwargs['num_freqs']
-
-         if self.kwargs['log_sampling']:
-             freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
          else:
-             freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
-
          for freq in freq_bands:
-             for p_fn in self.kwargs['periodic_fns']:
-                 embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
                  out_dim += d
-
          self.embed_fns = embed_fns
          self.out_dim = out_dim
-
      def embed(self, inputs):
          return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


  def get_embedder(multires, i=0):
      import torch.nn as nn
      if i == -1:
          return nn.Identity(), 3
-
      embed_kwargs = {
-         'include_input' : True,
-         'input_dims' : 3,
-         'max_freq_log2' : multires-1,
-         'num_freqs' : multires,
-         'log_sampling' : True,
-         'periodic_fns' : [torch.sin, torch.cos],
      }
-
      embedder_obj = Embedder(**embed_kwargs)
-     embed = lambda x, eo=embedder_obj : eo.embed(x)
      return embed, embedder_obj.out_dim

- class APOPMeter():
      def __init__(self) -> None:
          self.tp = 0
          self.fp = 0
@@ -195,24 +296,24 @@ class APOPMeter():
          self.tn += tn
          self.tn += fn

  def inverse_sigmoid(x, eps=1e-5):
      x = x.clamp(min=0, max=1)
      x1 = x.clamp(min=eps)
      x2 = (1 - x).clamp(min=eps)
-     return torch.log(x1/x2)

- import argparse
- from util.slconfig import SLConfig
  def get_raw_dict(args):
      """
      return the dicf contained in args.
-
      e.g:
          >>> with open(path, 'w') as f:
                  json.dump(get_raw_dict(args), f, indent=2)
      """
-     if isinstance(args, argparse.Namespace):
-         return vars(args)
      elif isinstance(args, dict):
          return args
      elif isinstance(args, SLConfig):
@@ -227,12 +328,12 @@ def stat_tensors(tensor):
      entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum()

      return {
-         'max': tensor.max(),
-         'min': tensor.min(),
-         'mean': tensor.mean(),
-         'var': tensor.var(),
-         'std': tensor.var() ** 0.5,
-         'entropy': entropy
      }
@@ -272,21 +373,20 @@ class NiceRepr:

      def __nice__(self):
          """str: a "nice" summary string describing this module"""
-         if hasattr(self, '__len__'):
              # It is a common pattern for objects to use __len__ in __nice__
              # As a convenience we define a default __nice__ for these objects
              return str(len(self))
          else:
              # In all other cases force the subclass to overload __nice__
-             raise NotImplementedError(
-                 f'Define the __nice__ method for {self.__class__!r}')

      def __repr__(self):
          """str: the string of the module"""
          try:
              nice = self.__nice__()
              classname = self.__class__.__name__
-             return f'<{classname}({nice}) at {hex(id(self))}>'
          except NotImplementedError as ex:
              warnings.warn(str(ex), category=RuntimeWarning)
              return object.__repr__(self)
@@ -296,13 +396,12 @@ class NiceRepr:
          try:
              classname = self.__class__.__name__
              nice = self.__nice__()
-             return f'<{classname}({nice})>'
          except NotImplementedError as ex:
              warnings.warn(str(ex), category=RuntimeWarning)
              return object.__repr__(self)

-
  def ensure_rng(rng=None):
      """Coerces input into a random number generator.
@@ -333,6 +432,7 @@ def ensure_rng(rng=None):
          rng = rng
      return rng

  def random_boxes(num=1, scale=1, rng=None):
      """Simple version of ``kwimage.Boxes.random``
@@ -377,6 +477,8 @@ class ModelEma(torch.nn.Module):
          self.module = deepcopy(model)
          self.module.eval()

          self.decay = decay
          self.device = device  # perform ema on different device from model if set
          if self.device is not None:
@@ -384,30 +486,33 @@ class ModelEma(torch.nn.Module):

      def _update(self, model, update_fn):
          with torch.no_grad():
-             for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                  if self.device is not None:
                      model_v = model_v.to(device=self.device)
                  ema_v.copy_(update_fn(ema_v, model_v))

      def update(self, model):
-         self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

      def set(self, model):
          self._update(model, update_fn=lambda e, m: m)

- class BestMetricSingle():
-     def __init__(self, init_res=0.0, better='large') -> None:
          self.init_res = init_res
          self.best_res = init_res
          self.best_ep = -1

          self.better = better
-         assert better in ['large', 'small']

      def isbetter(self, new_res, old_res):
-         if self.better == 'large':
              return new_res > old_res
-         if self.better == 'small':
              return new_res < old_res

      def update(self, new_res, ep):
@@ -425,19 +530,18 @@ class BestMetricSingle():

      def summary(self) -> dict:
          return {
-             'best_res': self.best_res,
-             'best_ep': self.best_ep,
          }


- class BestMetricHolder():
-     def __init__(self, init_res=0.0, better='large', use_ema=False) -> None:
          self.best_all = BestMetricSingle(init_res, better)
          self.use_ema = use_ema
          if use_ema:
              self.best_ema = BestMetricSingle(init_res, better)
              self.best_regular = BestMetricSingle(init_res, better)
-

      def update(self, new_res, epoch, is_ema=False):
          """
@@ -458,9 +562,9 @@ class BestMetricHolder():
              return self.best_all.summary()

          res = {}
-         res.update({f'all_{k}':v for k,v in self.best_all.summary().items()})
-         res.update({f'regular_{k}':v for k,v in self.best_regular.summary().items()})
-         res.update({f'ema_{k}':v for k,v in self.best_ema.summary().items()})
          return res

      def __repr__(self) -> str:
@@ -469,8 +573,31 @@ class BestMetricHolder():
      def __str__(self) -> str:
          return self.__repr__()

  def get_phrases_from_posmap(
-     posmap: torch.BoolTensor, tokenized: dict, tokenizer: AutoTokenizer, left_idx: int = 0, right_idx: int = 255
  ):
      assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
      if posmap.dim() == 1:
@@ -480,5 +607,4 @@ def get_phrases_from_posmap(
          token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
          return tokenizer.decode(token_ids)
      else:
-         raise NotImplementedError("posmap must be 1-dim")
-
+ import argparse
  import json
  import warnings
+ from collections import OrderedDict
+ from copy import deepcopy
+ from typing import Any, Dict, List

  import numpy as np
+ import torch
+ from transformers import AutoTokenizer
+
+ from groundingdino.util.slconfig import SLConfig

+
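+ # Debug helper: recursively print the shape/type of tensors, arrays, and containers.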
+ def slprint(x, name="x"):
      if isinstance(x, (torch.Tensor, np.ndarray)):
+         print(f"{name}.shape:", x.shape)
      elif isinstance(x, (tuple, list)):
+         print("type x:", type(x))
          for i in range(min(10, len(x))):
+             slprint(x[i], f"{name}[{i}]")
      elif isinstance(x, dict):
+         for k, v in x.items():
+             slprint(v, f"{name}[{k}]")
      else:
+         print(f"{name}.type:", type(x))
+

  def clean_state_dict(state_dict):
      new_state_dict = OrderedDict()
      for k, v in state_dict.items():
+         if k[:7] == "module.":
              k = k[7:]  # remove `module.`
          new_state_dict[k] = v
      return new_state_dict

+
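+ # Invert ImageNet-style normalization (x * std + mean); accepts (3,H,W) or (B,3,H,W).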
+ def renorm(
+     img: torch.FloatTensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ ) -> torch.FloatTensor:
      # img: tensor(3,H,W) or tensor(B,3,H,W)
      # return: same as img
+     assert img.dim() == 3 or img.dim() == 4, "img.dim() should be 3 or 4 but %d" % img.dim()
      if img.dim() == 3:
+         assert img.size(0) == 3, 'img.size(0) should be 3 but "%d". (%s)' % (
+             img.size(0),
+             str(img.size()),
+         )
+         img_perm = img.permute(1, 2, 0)
          mean = torch.Tensor(mean)
          std = torch.Tensor(std)
          img_res = img_perm * std + mean
+         return img_res.permute(2, 0, 1)
+     else:  # img.dim() == 4
+         assert img.size(1) == 3, 'img.size(1) should be 3 but "%d". (%s)' % (
+             img.size(1),
+             str(img.size()),
+         )
+         img_perm = img.permute(0, 2, 3, 1)
          mean = torch.Tensor(mean)
          std = torch.Tensor(std)
          img_res = img_perm * std + mean
+         return img_res.permute(0, 3, 1, 2)



+ class CocoClassMapper:
      def __init__(self) -> None:
+         self.category_map_str = {
+             "1": 1,
+             "2": 2,
+             "3": 3,
+             "4": 4,
+             "5": 5,
+             "6": 6,
+             "7": 7,
+             "8": 8,
+             "9": 9,
+             "10": 10,
+             "11": 11,
+             "13": 12,
+             "14": 13,
+             "15": 14,
+             "16": 15,
+             "17": 16,
+             "18": 17,
+             "19": 18,
+             "20": 19,
+             "21": 20,
+             "22": 21,
+             "23": 22,
+             "24": 23,
+             "25": 24,
+             "27": 25,
+             "28": 26,
+             "31": 27,
+             "32": 28,
+             "33": 29,
+             "34": 30,
+             "35": 31,
+             "36": 32,
+             "37": 33,
+             "38": 34,
+             "39": 35,
+             "40": 36,
+             "41": 37,
+             "42": 38,
+             "43": 39,
+             "44": 40,
+             "46": 41,
+             "47": 42,
+             "48": 43,
+             "49": 44,
+             "50": 45,
+             "51": 46,
+             "52": 47,
+             "53": 48,
+             "54": 49,
+             "55": 50,
+             "56": 51,
+             "57": 52,
+             "58": 53,
+             "59": 54,
+             "60": 55,
+             "61": 56,
+             "62": 57,
+             "63": 58,
+             "64": 59,
+             "65": 60,
+             "67": 61,
+             "70": 62,
+             "72": 63,
+             "73": 64,
+             "74": 65,
+             "75": 66,
+             "76": 67,
+             "77": 68,
+             "78": 69,
+             "79": 70,
+             "80": 71,
+             "81": 72,
+             "82": 73,
+             "84": 74,
+             "85": 75,
+             "86": 76,
+             "87": 77,
+             "88": 78,
+             "89": 79,
+             "90": 80,
+         }
+         self.origin2compact_mapper = {int(k): v - 1 for k, v in self.category_map_str.items()}
+         self.compact2origin_mapper = {int(v - 1): int(k) for k, v in self.category_map_str.items()}

      def origin2compact(self, idx):
          return self.origin2compact_mapper[int(idx)]

      def compact2origin(self, idx):
          return self.compact2origin_mapper[int(idx)]

+
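+ # Recursively move tensors inside nested lists/dicts onto the given device.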
  def to_device(item, device):
      if isinstance(item, torch.Tensor):
          return item.to(device)
      elif isinstance(item, list):
          return [to_device(i, device) for i in item]
      elif isinstance(item, dict):
+         return {k: to_device(v, device) for k, v in item.items()}
      else:
+         raise NotImplementedError(
+             "Call Shilong if you use other containers! type: {}".format(type(item))
+         )


+ # Soft-argmax: expected (softmax-weighted) index of x along `axis`, after summing out `other_axis`.
  def get_gaussian_mean(x, axis, other_axis, softmax=True):
      """

      mean_position = torch.sum(index * u, dim=2)
      return mean_position

+
  def get_expected_points_from_map(hm, softmax=True):
      """get_gaussian_map_from_points
      B,C,H,W -> B,N,2 float(0, 1) float(0, 1)

      Args:
          hm (float): Input images(BxCxHxW)

+     Returns:
          weighted index for axis, BxCx2. float between 0 and 1.

      """
      # hm = 10*hm
+     B, C, H, W = hm.shape
+     y_mean = get_gaussian_mean(hm, 2, 3, softmax=softmax)  # B,C
+     x_mean = get_gaussian_mean(hm, 3, 2, softmax=softmax)  # B,C
      # return torch.cat((x_mean.unsqueeze(-1), y_mean.unsqueeze(-1)), 2)
      return torch.stack([x_mean, y_mean], dim=2)

+
  # Positional encoding (section 5.1)
  # borrow from nerf
  class Embedder:
      def __init__(self, **kwargs):
          self.kwargs = kwargs
          self.create_embedding_fn()
+
      def create_embedding_fn(self):
          embed_fns = []
+         d = self.kwargs["input_dims"]
          out_dim = 0
+         if self.kwargs["include_input"]:
+             embed_fns.append(lambda x: x)
              out_dim += d
+
+         max_freq = self.kwargs["max_freq_log2"]
+         N_freqs = self.kwargs["num_freqs"]
+
+         if self.kwargs["log_sampling"]:
+             freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)
          else:
+             freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)
+
          for freq in freq_bands:
+             for p_fn in self.kwargs["periodic_fns"]:
+                 embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                  out_dim += d
+
          self.embed_fns = embed_fns
          self.out_dim = out_dim
+
      def embed(self, inputs):
          return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


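+ # NeRF-style positional encoding helper: returns (embed_fn, out_dim); i == -1 disables encoding.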
  def get_embedder(multires, i=0):
      import torch.nn as nn
+
      if i == -1:
          return nn.Identity(), 3
+
      embed_kwargs = {
+         "include_input": True,
+         "input_dims": 3,
+         "max_freq_log2": multires - 1,
+         "num_freqs": multires,
+         "log_sampling": True,
+         "periodic_fns": [torch.sin, torch.cos],
      }
+
      embedder_obj = Embedder(**embed_kwargs)
+     embed = lambda x, eo=embedder_obj: eo.embed(x)
      return embed, embedder_obj.out_dim

+
+ class APOPMeter:
      def __init__(self) -> None:
          self.tp = 0
          self.fp = 0

          self.tn += tn
          self.fn += fn

+
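+ # Numerically stable logit: log(x / (1 - x)) with inputs clamped away from 0 and 1.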
  def inverse_sigmoid(x, eps=1e-5):
      x = x.clamp(min=0, max=1)
      x1 = x.clamp(min=eps)
      x2 = (1 - x).clamp(min=eps)
+     return torch.log(x1 / x2)
+

  def get_raw_dict(args):
      """
      return the dict contained in args.
+
      e.g:
          >>> with open(path, 'w') as f:
                  json.dump(get_raw_dict(args), f, indent=2)
      """
+     if isinstance(args, argparse.Namespace):
+         return vars(args)
      elif isinstance(args, dict):
          return args
      elif isinstance(args, SLConfig):

      entropy = (tensor_sm * torch.log(tensor_sm + 1e-9)).sum()

      return {
+         "max": tensor.max(),
+         "min": tensor.min(),
+         "mean": tensor.mean(),
+         "var": tensor.var(),
+         "std": tensor.var() ** 0.5,
+         "entropy": entropy,
      }


      def __nice__(self):
          """str: a "nice" summary string describing this module"""
+         if hasattr(self, "__len__"):
              # It is a common pattern for objects to use __len__ in __nice__
              # As a convenience we define a default __nice__ for these objects
              return str(len(self))
          else:
              # In all other cases force the subclass to overload __nice__
+             raise NotImplementedError(f"Define the __nice__ method for {self.__class__!r}")

      def __repr__(self):
          """str: the string of the module"""
          try:
              nice = self.__nice__()
              classname = self.__class__.__name__
+             return f"<{classname}({nice}) at {hex(id(self))}>"
          except NotImplementedError as ex:
              warnings.warn(str(ex), category=RuntimeWarning)
              return object.__repr__(self)

          try:
              classname = self.__class__.__name__
              nice = self.__nice__()
+             return f"<{classname}({nice})>"
          except NotImplementedError as ex:
              warnings.warn(str(ex), category=RuntimeWarning)
              return object.__repr__(self)


  def ensure_rng(rng=None):
      """Coerces input into a random number generator.

          rng = rng
      return rng

+
  def random_boxes(num=1, scale=1, rng=None):
      """Simple version of ``kwimage.Boxes.random``

          self.module = deepcopy(model)
          self.module.eval()

+         # import ipdb; ipdb.set_trace()
+
          self.decay = decay
          self.device = device  # perform ema on different device from model if set
          if self.device is not None:

      def _update(self, model, update_fn):
          with torch.no_grad():
+             for ema_v, model_v in zip(
+                 self.module.state_dict().values(), model.state_dict().values()
+             ):
                  if self.device is not None:
                      model_v = model_v.to(device=self.device)
                  ema_v.copy_(update_fn(ema_v, model_v))

      def update(self, model):
+         self._update(model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m)

      def set(self, model):
          self._update(model, update_fn=lambda e, m: m)

+
+ class BestMetricSingle:
504
+ def __init__(self, init_res=0.0, better="large") -> None:
505
  self.init_res = init_res
506
  self.best_res = init_res
507
  self.best_ep = -1
508
 
509
  self.better = better
510
+ assert better in ["large", "small"]
511
 
512
  def isbetter(self, new_res, old_res):
513
+ if self.better == "large":
514
  return new_res > old_res
515
+ if self.better == "small":
516
  return new_res < old_res
517
 
518
  def update(self, new_res, ep):
 
530
 
531
  def summary(self) -> dict:
532
  return {
533
+ "best_res": self.best_res,
534
+ "best_ep": self.best_ep,
535
  }
536
 
537
 
538
+ class BestMetricHolder:
539
+ def __init__(self, init_res=0.0, better="large", use_ema=False) -> None:
540
  self.best_all = BestMetricSingle(init_res, better)
541
  self.use_ema = use_ema
542
  if use_ema:
543
  self.best_ema = BestMetricSingle(init_res, better)
544
  self.best_regular = BestMetricSingle(init_res, better)
 
545
 
546
  def update(self, new_res, epoch, is_ema=False):
547
  """
 
562
  return self.best_all.summary()
563
 
564
  res = {}
565
+ res.update({f"all_{k}": v for k, v in self.best_all.summary().items()})
566
+ res.update({f"regular_{k}": v for k, v in self.best_regular.summary().items()})
567
+ res.update({f"ema_{k}": v for k, v in self.best_ema.summary().items()})
568
  return res
569
 
570
  def __repr__(self) -> str:
 
573
  def __str__(self) -> str:
574
  return self.__repr__()
575
 
576
+
577
+ def targets_to(targets: List[Dict[str, Any]], device):
578
+ """Moves the target dicts to the given device."""
579
+ excluded_keys = [
580
+ "questionId",
581
+ "tokens_positive",
582
+ "strings_positive",
583
+ "tokens",
584
+ "dataset_name",
585
+ "sentence_id",
586
+ "original_img_id",
587
+ "nb_eval",
588
+ "task_id",
589
+ "original_id",
590
+ "token_span",
591
+ "caption",
592
+ "dataset_type",
593
+ ]
594
+ return [
595
+ {k: v.to(device) if k not in excluded_keys else v for k, v in t.items()} for t in targets
596
+ ]
597
+
598
+
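+ # Decode the caption tokens whose posmap entries are True back into a text phrase.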
  def get_phrases_from_posmap(
+     posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer, left_idx: int = 0, right_idx: int = 255
  ):
      assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
      if posmap.dim() == 1:

          token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
          return tokenizer.decode(token_ids)
      else:
+         raise NotImplementedError("posmap must be 1-dim")