Gil-Simas commited on
Commit
5f71887
·
1 Parent(s): c62f523

refactor to reinit class

Browse files
Files changed (1) hide show
  1. user-friendly-metrics.py +64 -14
user-friendly-metrics.py CHANGED
@@ -60,6 +60,20 @@ Args:
60
  class UserFriendlyMetrics(evaluate.Metric):
61
  """TODO: Short description of my evaluation module."""
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def _info(self):
64
  # TODO: Specifies the evaluate.EvaluationModuleInfo object
65
  return evaluate.MetricInfo(
@@ -89,23 +103,59 @@ class UserFriendlyMetrics(evaluate.Metric):
89
  # TODO: Download external resources if needed
90
  pass
91
 
92
- def compute_from_payload(self, payload, **kwargs):
93
- return self._compute(payload, **kwargs)
94
 
95
- def _compute(
96
- self,
97
- payload,
98
- max_iou: float = 0.5,
99
- filter={},
100
- recognition_thresholds=[0.3, 0.5, 0.8],
101
- debug: bool = False,
102
- ):
103
- """Returns the scores"""
104
- # TODO: Compute the different scores of the module
105
  return calculate_from_payload(
106
- payload, max_iou, filter, recognition_thresholds, debug
 
 
 
 
107
  )
108
- # return calculate(predictions, references, max_iou)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  def wandb(
111
  self,
 
60
  class UserFriendlyMetrics(evaluate.Metric):
61
  """TODO: Short description of my evaluation module."""
62
 
63
+ def __init__(self,
64
+ iou_threshold = 1e-10,
65
+ filter = {"name": "area", "ranges": [("all", [0, 1e5**2])]},
66
+ payload: Payload = None,,
67
+ recognition_thresholds: list = [0.3, 0.5, 0.8],
68
+ **kwargs
69
+ ):
70
+ super().__init__(**kwargs)
71
+
72
+ self.iou_threshold = iou_threshold
73
+ self.filter = filter
74
+ self.payload = payload
75
+ self.recognition_thresholds = recognition_thresholds
76
+
77
  def _info(self):
78
  # TODO: Specifies the evaluate.EvaluationModuleInfo object
79
  return evaluate.MetricInfo(
 
103
  # TODO: Download external resources if needed
104
  pass
105
 
106
+ # def compute_from_payload(self, payload, **kwargs):
107
+
108
 
109
+ def _compute(self, **kwargs):
110
+
 
 
 
 
 
 
 
 
111
  return calculate_from_payload(
112
+ self.payload,
113
+ self.iou_threshold,
114
+ self.filter,
115
+ self.recognition_thresholds,
116
+ **kwargs
117
  )
118
+
119
+ def compute_from_payload(self,
120
+ payload: Payload,
121
+ **kwargs):
122
+
123
+ results = {}
124
+
125
+ for model_name in payload.models:
126
+ results[model_name] = {"overall": {}, "per_sequence": {}}
127
+
128
+ # per-sequence loop
129
+ for seq_name, sequence in payload.sequences.items():
130
+ # create new payload only with specific sequence and model
131
+ sequence_payload = Payload(
132
+ dataset=payload.dataset,
133
+ gt_field_name=payload.gt_field_name,
134
+ models=[model_name],
135
+ sequences={seq_name: sequence}
136
+ )
137
+ module = UserFriendlyMetrics(
138
+ iou_threshold=self.iou_threshold,
139
+ filter=self.filter,
140
+ payload=sequence_payload
141
+ recognition_thresholds=self.recognition_thresholds
142
+ )
143
+ results[model_name]["per_sequence"][seq_name] = module.compute()[model_name]["metrics"]
144
+
145
+ # overall per-model loop
146
+ model_payload = Payload(
147
+ dataset=payload.dataset,
148
+ gt_field_name=payload.gt_field_name,
149
+ models=[model_name],
150
+ sequences=payload.sequences
151
+ )
152
+ module = UserFriendlyMetrics(
153
+ iou_threshold=self.iou_threshold,
154
+ filter=self.filter,
155
+ payload=model_payload
156
+ recognition_thresholds=self.recognition_thresholds
157
+ )
158
+ results[model_name]["overall"] = module.compute()[model_name]["metrics"]
159
 
160
  def wandb(
161
  self,