hichem-abdellali committed on
Commit
d70cd24
·
verified ·
1 Parent(s): fcbfc13

dummy _compute() output

Browse files
Files changed (1) hide show
  1. ref-metric.py +146 -119
ref-metric.py CHANGED
@@ -99,126 +99,153 @@ class UserFriendlyMetrics(evaluate.Metric):
99
  ):
100
  """Returns the scores"""
101
  # TODO: Compute the different scores of the module
102
- return calculate_from_payload(
103
- payload, max_iou, filters, recognition_thresholds, debug
104
- )
105
  # return calculate(predictions, references, max_iou)
106
 
107
- def wandb(
108
- self,
109
- results,
110
- wandb_section: str = None,
111
- wandb_project="user_friendly_metrics",
112
- log_plots: bool = True,
113
- debug: bool = False,
114
- ):
115
- """
116
- Logs metrics to Weights and Biases (wandb) for tracking and visualization, including categorized bar charts for global metrics.
117
-
118
- Args:
119
- results (dict): Results dictionary with 'global' and 'per_sequence' keys.
120
- wandb_section (str, optional): W&B section for metric grouping. Defaults to None.
121
- wandb_project (str, optional): The name of the wandb project. Defaults to 'user_friendly_metrics'.
122
- log_plots (bool, optional): Generates categorized bar charts for global metrics. Defaults to True.
123
- debug (bool, optional): Logs detailed summaries and histories to the terminal console. Defaults to False.
124
- """
125
-
126
- current_datetime = datetime.datetime.now()
127
- formatted_datetime = current_datetime.strftime("%Y-%m-%d_%H-%M-%S")
128
- wandb.login(key=os.getenv("WANDB_API_KEY"))
129
-
130
- run = wandb.init(
131
- project=wandb_project,
132
- name=f"evaluation-{formatted_datetime}",
133
- reinit=True,
134
- settings=wandb.Settings(silent=not debug),
135
- )
136
-
137
- categories = {
138
- "user_friendly_metrics": {
139
- "mostly_tracked_score_0.3",
140
- "mostly_tracked_score_0.5",
141
- "mostly_tracked_score_0.8",
142
- },
143
- "evaluation_metrics_dev": {
144
- "f1",
145
- "recall",
146
- "precision",
147
- },
148
- "user_friendly_metrics_dev": {
149
- "mostly_tracked_count_0.3",
150
- "mostly_tracked_count_0.5",
151
- "mostly_tracked_count_0.8",
152
- "unique_obj_count",
153
- },
154
- "predictions_summary": {
155
- "fp",
156
- "tp",
157
- "fn",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  },
159
- }
160
-
161
- chart_data = {key: [] for key in categories.keys()}
162
-
163
- # Log global metrics
164
- if "global" in results:
165
- for global_key, global_metrics in results["global"].items():
166
- for metric, value in global_metrics["all"].items():
167
- log_key = (
168
- f"{wandb_section}/global/{global_key}/{metric}"
169
- if wandb_section
170
- else f"global/{global_key}/{metric}"
171
- )
172
- run.log({log_key: value})
173
-
174
- if debug:
175
- print(f" {log_key} = {value}")
176
-
177
- for category, metrics in categories.items():
178
- if metric in metrics:
179
- chart_data[category].append([metric, value])
180
- print("----------------------------------------------------")
181
-
182
- if log_plots:
183
- for category, data in chart_data.items():
184
- if data:
185
- table_data = [[label, value] for label, value in data]
186
- table = wandb.Table(data=table_data, columns=["metrics", "value"])
187
- run.log(
188
- {
189
- f"{category}_bar_chart": wandb.plot.bar(
190
- table,
191
- "metrics",
192
- "value",
193
- title=f"{category.replace('_', ' ').title()}",
194
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  }
196
- )
197
-
198
- if "per_sequence" in results:
199
- sorted_sequences = sorted(
200
- results["per_sequence"].items(),
201
- key=lambda x: next(iter(x[1].values()), {}).get("all", {}).get("f1", 0),
202
- reverse=True, # Set to True for descending order
203
- )
204
-
205
- for sequence_name, sequence_data in sorted_sequences:
206
- for seq_key, seq_metrics in sequence_data.items():
207
- for metric, value in seq_metrics["all"].items():
208
- log_key = (
209
- f"{wandb_section}/per_sequence/{sequence_name}/{seq_key}/{metric}"
210
- if wandb_section
211
- else f"per_sequence/{sequence_name}/{seq_key}/{metric}"
212
- )
213
- run.log({log_key: value})
214
- if debug:
215
- print(f" {log_key} = {value}")
216
- print("----------------------------------------------------")
217
-
218
- if debug:
219
- print("\nDebug Mode: Logging Summary and History")
220
- print(f"Results Summary:\n{results}")
221
- print(f"WandB Settings:\n{run.settings}")
222
- print("All metrics have been logged.")
223
-
224
- run.finish()
 
99
  ):
100
  """Returns the scores"""
101
  # TODO: Compute the different scores of the module
102
+ return dummy_values()
 
 
103
  # return calculate(predictions, references, max_iou)
104
 
105
+ def dummy_values():
106
+ return {
107
+ "model_1": {
108
+ "overall": {
109
+ "all": {
110
+ "tp": 50,
111
+ "fp": 20,
112
+ "fn": 10,
113
+ "precision": 0.71,
114
+ "recall": 0.83,
115
+ "f1": 0.76
116
+ },
117
+ "small": {
118
+ "tp": 15,
119
+ "fp": 5,
120
+ "fn": 2,
121
+ "precision": 0.75,
122
+ "recall": 0.88,
123
+ "f1": 0.81
124
+ },
125
+ "medium": {
126
+ "tp": 25,
127
+ "fp": 10,
128
+ "fn": 5,
129
+ "precision": 0.71,
130
+ "recall": 0.83,
131
+ "f1": 0.76
132
+ },
133
+ "large": {
134
+ "tp": 10,
135
+ "fp": 5,
136
+ "fn": 3,
137
+ "precision": 0.67,
138
+ "recall": 0.77,
139
+ "f1": 0.71
140
+ }
141
+ },
142
+ "per_sequence": {
143
+ "sequence_1": {
144
+ "all": {
145
+ "tp": 30,
146
+ "fp": 15,
147
+ "fn": 7,
148
+ "precision": 0.67,
149
+ "recall": 0.81,
150
+ "f1": 0.73
151
+ },
152
+ "small": {
153
+ "tp": 10,
154
+ "fp": 3,
155
+ "fn": 1,
156
+ "precision": 0.77,
157
+ "recall": 0.91,
158
+ "f1": 0.83
159
+ },
160
+ "medium": {
161
+ "tp": 15,
162
+ "fp": 7,
163
+ "fn": 2,
164
+ "precision": 0.68,
165
+ "recall": 0.88,
166
+ "f1": 0.77
167
+ },
168
+ "large": {
169
+ "tp": 5,
170
+ "fp": 2,
171
+ "fn": 1,
172
+ "precision": 0.71,
173
+ "recall": 0.83,
174
+ "f1": 0.76
175
+ }
176
+ }
177
+ }
178
  },
179
+ "model_2": {
180
+ "overall": {
181
+ "all": {
182
+ "tp": 60,
183
+ "fp": 25,
184
+ "fn": 15,
185
+ "precision": 0.71,
186
+ "recall": 0.80,
187
+ "f1": 0.75
188
+ },
189
+ "small": {
190
+ "tp": 20,
191
+ "fp": 6,
192
+ "fn": 3,
193
+ "precision": 0.77,
194
+ "recall": 0.87,
195
+ "f1": 0.82
196
+ },
197
+ "medium": {
198
+ "tp": 30,
199
+ "fp": 12,
200
+ "fn": 5,
201
+ "precision": 0.71,
202
+ "recall": 0.86,
203
+ "f1": 0.78
204
+ },
205
+ "large": {
206
+ "tp": 10,
207
+ "fp": 7,
208
+ "fn": 5,
209
+ "precision": 0.59,
210
+ "recall": 0.67,
211
+ "f1": 0.63
212
+ }
213
+ },
214
+ "per_sequence": {
215
+ "sequence_1": {
216
+ "all": {
217
+ "tp": 40,
218
+ "fp": 18,
219
+ "fn": 8,
220
+ "precision": 0.69,
221
+ "recall": 0.83,
222
+ "f1": 0.75
223
+ },
224
+ "small": {
225
+ "tp": 12,
226
+ "fp": 4,
227
+ "fn": 2,
228
+ "precision": 0.75,
229
+ "recall": 0.86,
230
+ "f1": 0.80
231
+ },
232
+ "medium": {
233
+ "tp": 20,
234
+ "fp": 8,
235
+ "fn": 3,
236
+ "precision": 0.71,
237
+ "recall": 0.87,
238
+ "f1": 0.78
239
+ },
240
+ "large": {
241
+ "tp": 8,
242
+ "fp": 6,
243
+ "fn": 3,
244
+ "precision": 0.57,
245
+ "recall": 0.73,
246
+ "f1": 0.64
247
  }
248
+ }
249
+ }
250
+ }
251
+ }