Gil-Simas committed
Commit 4e6e22b · 1 Parent(s): ad46376

refactor compute and compute_to_payload

Files changed (1)
  1. user-friendly-metrics.py +72 -35

user-friendly-metrics.py CHANGED
@@ -17,7 +17,8 @@ import os
 
 import datasets
 import evaluate
-from seametrics.user_friendly.utils import calculate_from_payload
+from seametrics.user_friendly.utils import payload_to_uf_metrics, UFM
+from seametrics.payload import Payload
 
 import wandb
 
@@ -55,10 +56,22 @@ Args:
         Default is 0.5.
 """
 
-
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class UserFriendlyMetrics(evaluate.Metric):
     """TODO: Short description of my evaluation module."""
+    def __init__(
+        self,
+        iou_threshold: float = 1e-10,
+        recognition_thresholds=[0.3, 0.5, 0.8],
+        filter_dict={"name": "area", "ranges": [("all", [0, 1e5**2])]},
+        **kwargs):
+        super().__init__(**kwargs)
+
+        # save parameters for later
+        self.iou_threshold = iou_threshold
+        self.filter_dict = filter_dict
+        self.recognition_thresholds = recognition_thresholds
+
 
     def _info(self):
         # TODO: Specifies the evaluate.EvaluationModuleInfo object
@@ -89,30 +102,41 @@ class UserFriendlyMetrics(evaluate.Metric):
         # TODO: Download external resources if needed
         pass
 
-    # def compute_from_payload(self, payload, **kwargs):
+    def _compute(
+        self,
+        predictions,
+        references,
+    ):
+
+        results = {}
+        filter_ranges = self.filter_dict["ranges"]
 
+        for filter_range in filter_ranges:
+
+            filter_range_name = filter_range[0]
+            range_results = {}
+
+            for sequence_predictions, sequence_references in zip(predictions, references):
+
+                ufm = UFM(
+                    iou_threshold=self.iou_threshold,
+                    recognition_thresholds=self.recognition_thresholds
+                )
+
+                sequence_range_results = ufm.calculate(
+                    sequence_predictions,
+                    sequence_references[filter_range_name],
+                )
 
-    def _compute(self,
-                 payload: Payload,
-                 iou_threshold: float = 1e-10,
-                 filter={"name": "area", "ranges": [("all", [0, 1e5**2])]},
-                 recognition_thresholds=[0.3, 0.5, 0.8],
-                 **kwargs):
+                range_results = sum_dicts(range_results, sequence_range_results)
+
+            results[filter_range_name] = ufm.realize_metrics(range_results, self.recognition_thresholds)
 
-        return calculate_from_payload(
-            payload,
-            iou_threshold,
-            filter,
-            recognition_thresholds,
-            **kwargs
-        )
+        return results
 
     def compute_from_payload(self,
                              payload: Payload,
-                             iou_threshold: float = 1e-10,
-                             filter={"name": "area", "ranges": [("all", [0, 1e5**2])]},
-                             recognition_thresholds=[0.3, 0.5, 0.8],
-                             **kwargs):
+                             ):
 
         results = {}
 
@@ -128,28 +152,21 @@ class UserFriendlyMetrics(evaluate.Metric):
                 models=[model_name],
                 sequences={seq_name: sequence}
             )
-            module = UserFriendlyMetrics(
-                iou_threshold=iou_threshold,
-                filter=filter,
-                payload=sequence_payload
-                recognition_thresholds=recognition_thresholds
-            )
-            results[model_name]["per_sequence"][seq_name] = module.compute()[model_name]["metrics"]
+
+            predictions, references = payload_to_uf_metrics(payload, model_name=model_name, filter_dict=self.filter_dict)
 
-        # overall per-model loop
+            results[model_name]["per_sequence"][seq_name] = self._compute(predictions, references)
+
+        # overall
         model_payload = Payload(
             dataset=payload.dataset,
             gt_field_name=payload.gt_field_name,
             models=[model_name],
             sequences=payload.sequences
         )
-        module = UserFriendlyMetrics(
-            iou_threshold=iou_threshold,
-            filter=filter,
-            payload=model_payload
-            recognition_thresholds=recognition_thresholds
-        )
-        results[model_name]["overall"] = module.compute()[model_name]["metrics"]
+        predictions, references = payload_to_uf_metrics(payload, model_name=model_name, filter_dict=self.filter_dict)
+
+        results[model_name]["overall"] = self._compute(predictions, references)
 
         return results
 
@@ -285,3 +302,23 @@ class UserFriendlyMetrics(evaluate.Metric):
         print("All metrics have been logged.")
 
         run.finish()
+
+def sum_dicts(*dicts):
+    """
+    Sums multiple dictionaries with depth one. If keys overlap, their values are summed.
+    If keys are unique, they are simply included in the result.
+
+    Args:
+        *dicts: Any number of dictionaries to be summed.
+
+    Returns:
+        A single dictionary with the summed values.
+    """
+    result = {}
+    for d in dicts:
+        for key, value in d.items():
+            if key in result:
+                result[key] += value
+            else:
+                result[key] = value
+    return result
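After this refactor the module is configured once in __init__ and then driven through compute_from_payload. The sketch below shows how that flow might be exercised; only the constructor arguments, the Payload fields, and the compute_from_payload call are taken from this commit, while the dataset handle, ground-truth field name, model name, and sequence dict are hypothetical placeholders.

# Usage sketch (hypothetical inputs; not part of the committed file).
from seametrics.payload import Payload

metric = UserFriendlyMetrics(
    iou_threshold=1e-10,
    recognition_thresholds=[0.3, 0.5, 0.8],
    filter_dict={"name": "area", "ranges": [("all", [0, 1e5**2])]},
)

payload = Payload(
    dataset=my_dataset,            # hypothetical dataset handle
    gt_field_name="ground_truth",  # hypothetical ground-truth field name
    models=["my_model"],           # hypothetical model name
    sequences=my_sequences,        # hypothetical {sequence_name: sequence} dict
)

# results["my_model"]["per_sequence"][<seq_name>] and results["my_model"]["overall"]
# each hold metrics keyed by filter range name (e.g. "all").
results = metric.compute_from_payload(payload)

# sum_dicts accumulates flat per-sequence count dicts before realize_metrics, e.g.:
# sum_dicts({"tp": 10, "fp": 2}, {"tp": 7, "fp": 5}) -> {"tp": 17, "fp": 7}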