Gil-Simas commited on
Commit
538cc7b
·
verified ·
1 Parent(s): f5c606e

add log_per_sequence as option

Browse files
Files changed (1) hide show
  1. user-friendly-metrics.py +25 -20
user-friendly-metrics.py CHANGED
@@ -115,6 +115,7 @@ class UserFriendlyMetrics(evaluate.Metric):
115
  wandb_project="user_friendly_metrics",
116
  log_plots: bool = True,
117
  debug: bool = False,
 
118
  ):
119
  """
120
  Logs metrics to Weights and Biases (wandb) for tracking and visualization, including categorized bar charts for overall metrics.
@@ -140,9 +141,12 @@ class UserFriendlyMetrics(evaluate.Metric):
140
  self.wandb_run(result = result,
141
  wandb_run_name = wandb_run_name,
142
  wandb_project = wandb_project,
143
- debug = debug)
 
 
 
144
 
145
- def wandb_run(self, result, wandb_run_name, wandb_project, debug, wandb_section = None, log_plots = True):
146
 
147
  run = wandb.init(
148
  project = wandb_project,
@@ -208,24 +212,25 @@ class UserFriendlyMetrics(evaluate.Metric):
208
  }
209
  )
210
 
211
- if "per_sequence" in result:
212
- sorted_sequences = sorted(
213
- result["per_sequence"].items(),
214
- key=lambda x: next(iter(x[1].values()), {}).get("all", {}).get("recall", 0),
215
- reverse=True, # Set to True for descending order
216
- )
217
-
218
- for sequence_name, sequence_data in sorted_sequences:
219
- for metric, value in sequence_data["all"].items():
220
- log_key = (
221
- f"{wandb_section}/per_sequence/{sequence_name}/{metric}"
222
- if wandb_section
223
- else f"per_sequence/{sequence_name}/{metric}"
224
- )
225
- run.log({log_key: value})
226
- if debug:
227
- print(f" {log_key} = {value}")
228
- print("----------------------------------------------------")
 
229
 
230
  if debug:
231
  print("\nDebug Mode: Logging Summary and History")
 
115
  wandb_project="user_friendly_metrics",
116
  log_plots: bool = True,
117
  debug: bool = False,
118
+ log_per_sequence = False
119
  ):
120
  """
121
  Logs metrics to Weights and Biases (wandb) for tracking and visualization, including categorized bar charts for overall metrics.
 
141
  self.wandb_run(result = result,
142
  wandb_run_name = wandb_run_name,
143
  wandb_project = wandb_project,
144
+ debug = debug
145
+ wandb_section = wandb_section,
146
+ log_plots = log_plots,
147
+ log_per_sequence = log_per_sequence)
148
 
149
+ def wandb_run(self, result, wandb_run_name, wandb_project, debug, wandb_section = None, log_plots = True, log_per_sequence = False):
150
 
151
  run = wandb.init(
152
  project = wandb_project,
 
212
  }
213
  )
214
 
215
+ if log_per_sequence:
216
+ if "per_sequence" in result:
217
+ sorted_sequences = sorted(
218
+ result["per_sequence"].items(),
219
+ key=lambda x: next(iter(x[1].values()), {}).get("all", {}).get("recall", 0),
220
+ reverse=True, # Set to True for descending order
221
+ )
222
+
223
+ for sequence_name, sequence_data in sorted_sequences:
224
+ for metric, value in sequence_data["all"].items():
225
+ log_key = (
226
+ f"{wandb_section}/per_sequence/{sequence_name}/{metric}"
227
+ if wandb_section
228
+ else f"per_sequence/{sequence_name}/{metric}"
229
+ )
230
+ run.log({log_key: value})
231
+ if debug:
232
+ print(f" {log_key} = {value}")
233
+ print("----------------------------------------------------")
234
 
235
  if debug:
236
  print("\nDebug Mode: Logging Summary and History")