Spaces:
Sleeping
Sleeping
add log_per_sequence as option
Browse files- user-friendly-metrics.py +25 -20
user-friendly-metrics.py
CHANGED
@@ -115,6 +115,7 @@ class UserFriendlyMetrics(evaluate.Metric):
|
|
115 |
wandb_project="user_friendly_metrics",
|
116 |
log_plots: bool = True,
|
117 |
debug: bool = False,
|
|
|
118 |
):
|
119 |
"""
|
120 |
Logs metrics to Weights and Biases (wandb) for tracking and visualization, including categorized bar charts for overall metrics.
|
@@ -140,9 +141,12 @@ class UserFriendlyMetrics(evaluate.Metric):
|
|
140 |
self.wandb_run(result = result,
|
141 |
wandb_run_name = wandb_run_name,
|
142 |
wandb_project = wandb_project,
|
143 |
-
debug = debug
|
|
|
|
|
|
|
144 |
|
145 |
-
def wandb_run(self, result, wandb_run_name, wandb_project, debug, wandb_section = None, log_plots = True):
|
146 |
|
147 |
run = wandb.init(
|
148 |
project = wandb_project,
|
@@ -208,24 +212,25 @@ class UserFriendlyMetrics(evaluate.Metric):
|
|
208 |
}
|
209 |
)
|
210 |
|
211 |
-
if
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
for
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
|
|
229 |
|
230 |
if debug:
|
231 |
print("\nDebug Mode: Logging Summary and History")
|
|
|
115 |
wandb_project="user_friendly_metrics",
|
116 |
log_plots: bool = True,
|
117 |
debug: bool = False,
|
118 |
+
log_per_sequence = False
|
119 |
):
|
120 |
"""
|
121 |
Logs metrics to Weights and Biases (wandb) for tracking and visualization, including categorized bar charts for overall metrics.
|
|
|
141 |
self.wandb_run(result = result,
|
142 |
wandb_run_name = wandb_run_name,
|
143 |
wandb_project = wandb_project,
|
144 |
+
debug = debug
|
145 |
+
wandb_section = wandb_section,
|
146 |
+
log_plots = log_plots,
|
147 |
+
log_per_sequence = log_per_sequence)
|
148 |
|
149 |
+
def wandb_run(self, result, wandb_run_name, wandb_project, debug, wandb_section = None, log_plots = True, log_per_sequence = False):
|
150 |
|
151 |
run = wandb.init(
|
152 |
project = wandb_project,
|
|
|
212 |
}
|
213 |
)
|
214 |
|
215 |
+
if log_per_sequence:
|
216 |
+
if "per_sequence" in result:
|
217 |
+
sorted_sequences = sorted(
|
218 |
+
result["per_sequence"].items(),
|
219 |
+
key=lambda x: next(iter(x[1].values()), {}).get("all", {}).get("recall", 0),
|
220 |
+
reverse=True, # Set to True for descending order
|
221 |
+
)
|
222 |
+
|
223 |
+
for sequence_name, sequence_data in sorted_sequences:
|
224 |
+
for metric, value in sequence_data["all"].items():
|
225 |
+
log_key = (
|
226 |
+
f"{wandb_section}/per_sequence/{sequence_name}/{metric}"
|
227 |
+
if wandb_section
|
228 |
+
else f"per_sequence/{sequence_name}/{metric}"
|
229 |
+
)
|
230 |
+
run.log({log_key: value})
|
231 |
+
if debug:
|
232 |
+
print(f" {log_key} = {value}")
|
233 |
+
print("----------------------------------------------------")
|
234 |
|
235 |
if debug:
|
236 |
print("\nDebug Mode: Logging Summary and History")
|