Upload metrics.py with huggingface_hub
Browse files- metrics.py +1338 -199
metrics.py
CHANGED
|
@@ -1,19 +1,24 @@
|
|
| 1 |
import re
|
| 2 |
import string
|
| 3 |
import uuid
|
|
|
|
| 4 |
from abc import ABC, abstractmethod
|
| 5 |
from collections import Counter
|
|
|
|
| 6 |
from dataclasses import field
|
|
|
|
| 7 |
from typing import Any, Dict, Generator, List, Optional, Tuple
|
| 8 |
|
| 9 |
import evaluate
|
| 10 |
import numpy
|
| 11 |
import numpy as np
|
| 12 |
from scipy.stats import bootstrap
|
|
|
|
| 13 |
|
| 14 |
from .artifact import Artifact
|
| 15 |
from .dataclass import InternalField, OptionalField
|
| 16 |
from .logging_utils import get_logger
|
|
|
|
| 17 |
from .operator import (
|
| 18 |
MultiStreamOperator,
|
| 19 |
SingleStreamOperator,
|
|
@@ -22,14 +27,17 @@ from .operator import (
|
|
| 22 |
)
|
| 23 |
from .operators import CopyFields
|
| 24 |
from .random_utils import get_seed
|
|
|
|
| 25 |
from .stream import MultiStream, Stream
|
| 26 |
-
from .type_utils import isoftype
|
| 27 |
|
| 28 |
logger = get_logger()
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def abstract_factory():
|
|
@@ -40,6 +48,18 @@ def abstract_field():
|
|
| 40 |
return field(default_factory=abstract_factory)
|
| 41 |
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
class UpdateStream(StreamInstanceOperator):
|
| 44 |
update: dict
|
| 45 |
|
|
@@ -57,6 +77,48 @@ class Metric(Artifact):
|
|
| 57 |
def main_score(self):
|
| 58 |
pass
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
class MetricWithConfidenceInterval(Metric):
|
| 62 |
# The number of resamples used to estimate the confidence intervals of this metric.
|
|
@@ -73,7 +135,12 @@ class MetricWithConfidenceInterval(Metric):
|
|
| 73 |
return np.random.default_rng(hash(get_seed()) & _max_32bit)
|
| 74 |
|
| 75 |
def disable_confidence_interval_calculation(self):
|
|
|
|
| 76 |
self.n_resamples = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
def _can_compute_confidence_intervals(self, num_predictions):
|
| 79 |
return (
|
|
@@ -82,45 +149,117 @@ class MetricWithConfidenceInterval(Metric):
|
|
| 82 |
and num_predictions > 1
|
| 83 |
)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
The instances for which the confidence intervals are computed.
|
| 92 |
"""
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
|
|
|
|
|
|
|
|
|
| 95 |
result = {}
|
| 96 |
|
| 97 |
if not self._can_compute_confidence_intervals(num_predictions=len(instances)):
|
| 98 |
return result
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
| 104 |
for score_name in score_names:
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
ci = bootstrap(
|
| 109 |
-
(
|
| 110 |
-
statistic=
|
| 111 |
n_resamples=self.n_resamples,
|
| 112 |
confidence_level=self.confidence_level,
|
| 113 |
random_state=self.new_random_generator(),
|
| 114 |
).confidence_interval
|
| 115 |
-
|
| 116 |
-
result[f"{
|
|
|
|
| 117 |
if score_name == self.main_score:
|
| 118 |
result["score_ci_low"] = ci.low
|
| 119 |
result["score_ci_high"] = ci.high
|
| 120 |
return result
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
def compute_global_confidence_intervals(
|
| 123 |
-
self, references, predictions,
|
| 124 |
):
|
| 125 |
"""Computed confidence intervals for a set of references and predictions."""
|
| 126 |
random_gen = self.new_random_generator()
|
|
@@ -128,12 +267,12 @@ class MetricWithConfidenceInterval(Metric):
|
|
| 128 |
def statistic(arr, axis):
|
| 129 |
# arr is a 2d array where each row is a resampling, so we
|
| 130 |
# iterate over the rows and compute the metric on each resampling
|
| 131 |
-
def metric(sample_refs, sample_preds,
|
| 132 |
try:
|
| 133 |
return self._compute(
|
| 134 |
references=sample_refs,
|
| 135 |
predictions=sample_preds,
|
| 136 |
-
|
| 137 |
)["score"]
|
| 138 |
except Exception as e:
|
| 139 |
# this happens in edge cases, for example, when the sampling creates a
|
|
@@ -141,40 +280,21 @@ class MetricWithConfidenceInterval(Metric):
|
|
| 141 |
logger.info(f"Warning in {self.__class__.__name__}", e)
|
| 142 |
return np.nan
|
| 143 |
|
|
|
|
| 144 |
scores = numpy.apply_along_axis(
|
| 145 |
lambda x: metric(
|
| 146 |
sample_refs=[references[i] for i in x],
|
| 147 |
sample_preds=[predictions[i] for i in x],
|
| 148 |
-
|
| 149 |
),
|
| 150 |
axis=axis,
|
| 151 |
arr=arr,
|
| 152 |
)
|
| 153 |
|
| 154 |
-
#
|
| 155 |
-
#
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
# edge cases - for example, when the sample contains only empty strings.
|
| 159 |
-
# CI is about the distribution around the statistic (e.g. mean), it doesn't deal with
|
| 160 |
-
# cases in which the metric is not computable. Therefore, we ignore these edge cases
|
| 161 |
-
# as part of the computation of CI. The question is how to implement this policy.
|
| 162 |
-
# Options:
|
| 163 |
-
# 1. skip the errors and return a shorter array => this fails because Scipy demans
|
| 164 |
-
# this callback (i.e. the statistic() callback) to return an array of the same size
|
| 165 |
-
# as the number of resamples
|
| 166 |
-
# 2. Put np.nan for the errors => this fails because in such case the ci itself
|
| 167 |
-
# becomes np.nan. So one edge case can fail the whole CI computation.
|
| 168 |
-
# 3. Replace the errors with a sampling from the successful cases => this is what
|
| 169 |
-
# is implemented.
|
| 170 |
-
error_indices = numpy.isnan(scores)
|
| 171 |
-
n_errors = sum(error_indices)
|
| 172 |
-
if n_errors > 0:
|
| 173 |
-
new_scores = random_gen.choice(scores, n_errors, replace=True)
|
| 174 |
-
scores = scores[~error_indices]
|
| 175 |
-
scores = np.concatenate([scores, new_scores])
|
| 176 |
-
|
| 177 |
-
return scores
|
| 178 |
|
| 179 |
result = {}
|
| 180 |
num_predictions = len(predictions)
|
|
@@ -202,12 +322,15 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 202 |
need to be considered. Accuracy, on the other hand, is just an average of the accuracy of all the instances.
|
| 203 |
"""
|
| 204 |
|
| 205 |
-
n_resamples =
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 208 |
references = []
|
| 209 |
predictions = []
|
| 210 |
-
|
| 211 |
global_score = {}
|
| 212 |
|
| 213 |
instances = []
|
|
@@ -226,31 +349,40 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 226 |
predictions.append(instance_prediction)
|
| 227 |
instances.append(instance)
|
| 228 |
|
| 229 |
-
|
| 230 |
-
instance["
|
| 231 |
)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
if isinstance(self.main_score, str):
|
| 243 |
-
instance_score[self.main_score] =
|
| 244 |
|
| 245 |
instance["score"]["instance"].update(instance_score)
|
| 246 |
|
| 247 |
-
result = self._compute(references, predictions,
|
| 248 |
|
| 249 |
global_score.update(result)
|
| 250 |
|
| 251 |
score_name = global_score["score_name"]
|
| 252 |
confidence_interval = self.compute_global_confidence_intervals(
|
| 253 |
-
references, predictions,
|
| 254 |
)
|
| 255 |
global_score.update(confidence_interval)
|
| 256 |
|
|
@@ -262,9 +394,9 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 262 |
self,
|
| 263 |
references: List[List[str]],
|
| 264 |
predictions: List[str],
|
| 265 |
-
|
| 266 |
) -> dict:
|
| 267 |
-
result = self.compute(references, predictions,
|
| 268 |
result["score"] = result[self.main_score]
|
| 269 |
result["score_name"] = self.main_score
|
| 270 |
return result
|
|
@@ -274,13 +406,25 @@ class GlobalMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 274 |
self,
|
| 275 |
references: List[List[Any]],
|
| 276 |
predictions: List[Any],
|
| 277 |
-
|
| 278 |
) -> dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
pass
|
| 280 |
|
| 281 |
|
| 282 |
class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
| 283 |
-
n_resamples =
|
|
|
|
|
|
|
| 284 |
main_score: str
|
| 285 |
reduction_map: Dict[str, List[str]]
|
| 286 |
|
|
@@ -301,8 +445,8 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 301 |
),
|
| 302 |
)
|
| 303 |
|
| 304 |
-
|
| 305 |
-
instance["
|
| 306 |
for instance in stream
|
| 307 |
]
|
| 308 |
|
|
@@ -310,7 +454,7 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 310 |
instance_scores = self.compute(
|
| 311 |
references=references,
|
| 312 |
predictions=predictions,
|
| 313 |
-
|
| 314 |
)
|
| 315 |
|
| 316 |
# add the score and score_name fields
|
|
@@ -334,8 +478,6 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 334 |
), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
|
| 335 |
|
| 336 |
if reduction == "mean":
|
| 337 |
-
from statistics import mean
|
| 338 |
-
|
| 339 |
for field_name in fields:
|
| 340 |
global_score[field_name] = mean(
|
| 341 |
[
|
|
@@ -347,8 +489,13 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 347 |
global_score["score"] = global_score[field_name]
|
| 348 |
global_score["score_name"] = self.main_score
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
confidence_interval = self.score_based_confidence_interval(
|
| 351 |
-
instances=instances
|
| 352 |
)
|
| 353 |
global_score.update(confidence_interval)
|
| 354 |
|
|
@@ -360,33 +507,217 @@ class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 360 |
self,
|
| 361 |
references: List[List[Any]],
|
| 362 |
predictions: List[Any],
|
| 363 |
-
|
| 364 |
) -> List[Dict[str, Any]]:
|
| 365 |
pass
|
| 366 |
|
| 367 |
|
| 368 |
class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
|
| 371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
@property
|
| 374 |
@abstractmethod
|
| 375 |
def reduction_map(self) -> dict:
|
| 376 |
pass
|
| 377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
global_score = {}
|
| 380 |
instances = []
|
| 381 |
|
| 382 |
for instance in stream:
|
| 383 |
refs, pred = instance["references"], instance["prediction"]
|
| 384 |
-
|
| 385 |
-
instance["additional_inputs"] if "additional_inputs" in instance else {}
|
| 386 |
-
)
|
| 387 |
|
| 388 |
instance_score = self.compute(
|
| 389 |
-
references=refs, prediction=pred,
|
| 390 |
)
|
| 391 |
instance_score["score"] = instance_score[self.main_score]
|
| 392 |
instance_score["score_name"] = self.main_score
|
|
@@ -399,36 +730,100 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
|
| 399 |
|
| 400 |
instances.append(instance)
|
| 401 |
|
| 402 |
-
|
| 403 |
-
assert (
|
| 404 |
-
reduction in self.implemented_reductions
|
| 405 |
-
), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
|
| 406 |
|
| 407 |
-
|
| 408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
global_score[field_name] = mean(scores)
|
| 416 |
-
if field_name == self.main_score:
|
| 417 |
-
global_score["score"] = global_score[field_name]
|
| 418 |
-
global_score["score_name"] = self.main_score
|
| 419 |
|
| 420 |
-
|
| 421 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
)
|
| 423 |
-
global_score.update(confidence_interval)
|
| 424 |
|
| 425 |
-
|
| 426 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
@abstractmethod
|
| 429 |
-
def compute(
|
| 430 |
-
self, references: List[Any], prediction: Any, additional_inputs: Dict
|
| 431 |
-
) -> dict:
|
| 432 |
pass
|
| 433 |
|
| 434 |
|
|
@@ -445,7 +840,7 @@ class Squad(GlobalMetric):
|
|
| 445 |
self,
|
| 446 |
references: List[List[str]],
|
| 447 |
predictions: List[str],
|
| 448 |
-
|
| 449 |
) -> dict:
|
| 450 |
ids = [str(uuid.uuid4()).replace("-", "") for _ in range(len(predictions))]
|
| 451 |
formatted_predictions = [
|
|
@@ -466,9 +861,10 @@ class Squad(GlobalMetric):
|
|
| 466 |
class Accuracy(InstanceMetric):
|
| 467 |
reduction_map = {"mean": ["accuracy"]}
|
| 468 |
main_score = "accuracy"
|
|
|
|
| 469 |
|
| 470 |
def compute(
|
| 471 |
-
self, references: List[Any], prediction: Any,
|
| 472 |
) -> dict:
|
| 473 |
result = {
|
| 474 |
self.main_score: float(
|
|
@@ -483,13 +879,14 @@ class Accuracy(InstanceMetric):
|
|
| 483 |
class StringContainment(InstanceMetric):
|
| 484 |
reduction_map = {"mean": ["string_containment"]}
|
| 485 |
main_score = "string_containment"
|
|
|
|
| 486 |
|
| 487 |
def compute(
|
| 488 |
-
self, references: List[Any], prediction: Any,
|
| 489 |
) -> dict:
|
| 490 |
result = {
|
| 491 |
self.main_score: float(
|
| 492 |
-
any(str(reference) in prediction for reference in references)
|
| 493 |
)
|
| 494 |
}
|
| 495 |
result["score"] = result[self.main_score]
|
|
@@ -505,6 +902,13 @@ class MetricPipeline(MultiStreamOperator, Metric):
|
|
| 505 |
)
|
| 506 |
metric: Metric = None
|
| 507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
def verify(self):
|
| 509 |
assert self.main_score is not None, "main_score is not set"
|
| 510 |
|
|
@@ -569,37 +973,37 @@ class HuggingfaceMetric(GlobalMetric):
|
|
| 569 |
self,
|
| 570 |
references: List[List[Any]],
|
| 571 |
predictions: List[Any],
|
| 572 |
-
|
| 573 |
) -> dict:
|
| 574 |
-
|
| 575 |
for additional_input_field in self.hf_additional_input_fields:
|
| 576 |
assert (
|
| 577 |
-
additional_input_field in
|
| 578 |
-
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in
|
| 579 |
-
|
| 580 |
additional_input[additional_input_field]
|
| 581 |
-
for additional_input in
|
| 582 |
]
|
| 583 |
for additional_input_field in self.hf_additional_input_fields_pass_one_value:
|
| 584 |
assert (
|
| 585 |
-
additional_input_field in
|
| 586 |
-
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in
|
| 587 |
|
| 588 |
values = {
|
| 589 |
additional_input[additional_input_field]
|
| 590 |
-
for additional_input in
|
| 591 |
}
|
| 592 |
assert (
|
| 593 |
len(values) == 1
|
| 594 |
), f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
|
| 595 |
|
| 596 |
-
|
| 597 |
|
| 598 |
-
# add check that all required fields in self.metrics are in
|
| 599 |
result = self.metric.compute(
|
| 600 |
predictions=predictions,
|
| 601 |
references=references,
|
| 602 |
-
**
|
| 603 |
**self.hf_compute_args,
|
| 604 |
)
|
| 605 |
if self.hf_main_score:
|
|
@@ -641,23 +1045,23 @@ class HuggingfaceBulkMetric(BulkInstanceMetric):
|
|
| 641 |
self,
|
| 642 |
references: List[List[str]],
|
| 643 |
predictions: List[str],
|
| 644 |
-
|
| 645 |
) -> List[Dict[str, Any]]:
|
| 646 |
-
|
| 647 |
for additional_input_field in self.hf_additional_input_fields:
|
| 648 |
assert (
|
| 649 |
-
additional_input_field in
|
| 650 |
-
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in
|
| 651 |
-
|
| 652 |
additional_input[additional_input_field]
|
| 653 |
-
for additional_input in
|
| 654 |
]
|
| 655 |
-
# add check that all required fields in self.metrics are in
|
| 656 |
|
| 657 |
scores = self.metric.compute(
|
| 658 |
predictions=predictions,
|
| 659 |
references=references,
|
| 660 |
-
**
|
| 661 |
**self.hf_compute_args,
|
| 662 |
)
|
| 663 |
|
|
@@ -692,7 +1096,7 @@ class F1(GlobalMetric):
|
|
| 692 |
self,
|
| 693 |
references: List[List[str]],
|
| 694 |
predictions: List[str],
|
| 695 |
-
|
| 696 |
) -> dict:
|
| 697 |
assert all(
|
| 698 |
len(reference) == 1 for reference in references
|
|
@@ -714,8 +1118,6 @@ class F1(GlobalMetric):
|
|
| 714 |
average=self.average,
|
| 715 |
)
|
| 716 |
if isinstance(result["f1"], numpy.ndarray):
|
| 717 |
-
from statistics import mean
|
| 718 |
-
|
| 719 |
final_result = {self.main_score: mean(result["f1"])}
|
| 720 |
for i, label in enumerate(labels):
|
| 721 |
final_result["f1_" + self.id_to_str[label]] = result["f1"][i]
|
|
@@ -742,7 +1144,6 @@ class F1MultiLabel(GlobalMetric):
|
|
| 742 |
_metric = None
|
| 743 |
main_score = "f1_macro"
|
| 744 |
average = None # Report per class then aggregate by mean
|
| 745 |
-
classes_to_ignore = ["none"]
|
| 746 |
metric = "f1"
|
| 747 |
|
| 748 |
def prepare(self):
|
|
@@ -767,7 +1168,7 @@ class F1MultiLabel(GlobalMetric):
|
|
| 767 |
self,
|
| 768 |
references: List[List[str]],
|
| 769 |
predictions: List[List[str]],
|
| 770 |
-
|
| 771 |
) -> dict:
|
| 772 |
self.str_to_id = {}
|
| 773 |
self.id_to_str = {}
|
|
@@ -775,13 +1176,9 @@ class F1MultiLabel(GlobalMetric):
|
|
| 775 |
self._validate_references_and_prediction(references, predictions)
|
| 776 |
references = [reference[0] for reference in references]
|
| 777 |
|
| 778 |
-
labels =
|
| 779 |
-
|
| 780 |
-
for lbl in {label for reference in references for label in reference}
|
| 781 |
-
if lbl not in self.classes_to_ignore
|
| 782 |
-
]
|
| 783 |
# if no classes are left then F1 is not defined
|
| 784 |
-
# (e.g. only "none" in references)
|
| 785 |
if len(labels) == 0:
|
| 786 |
return {self.main_score: float("nan")}
|
| 787 |
|
|
@@ -809,8 +1206,6 @@ class F1MultiLabel(GlobalMetric):
|
|
| 809 |
labels=labels_param,
|
| 810 |
)
|
| 811 |
if isinstance(result[self.metric], numpy.ndarray):
|
| 812 |
-
from statistics import mean
|
| 813 |
-
|
| 814 |
assert (
|
| 815 |
len(result[self.metric]) == len(labels)
|
| 816 |
), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
|
|
@@ -883,6 +1278,8 @@ class Rouge(HuggingfaceMetric):
|
|
| 883 |
|
| 884 |
sent_split_newline: bool = True
|
| 885 |
|
|
|
|
|
|
|
| 886 |
def prepare(self):
|
| 887 |
super().prepare()
|
| 888 |
|
|
@@ -895,7 +1292,7 @@ class Rouge(HuggingfaceMetric):
|
|
| 895 |
nltk.download("punkt")
|
| 896 |
self.sent_tokenize = nltk.sent_tokenize
|
| 897 |
|
| 898 |
-
def compute(self, references, predictions,
|
| 899 |
if self.sent_split_newline:
|
| 900 |
predictions = [
|
| 901 |
"\n".join(self.sent_tokenize(prediction.strip()))
|
|
@@ -905,13 +1302,16 @@ class Rouge(HuggingfaceMetric):
|
|
| 905 |
["\n".join(self.sent_tokenize(r.strip())) for r in reference]
|
| 906 |
for reference in references
|
| 907 |
]
|
| 908 |
-
return super().compute(references, predictions,
|
| 909 |
|
| 910 |
|
| 911 |
# Computes char edit distance, ignoring whitespace
|
| 912 |
class CharEditDistanceAccuracy(InstanceMetric):
|
| 913 |
reduction_map = {"mean": ["char_edit_dist_accuracy"]}
|
| 914 |
main_score = "char_edit_dist_accuracy"
|
|
|
|
|
|
|
|
|
|
| 915 |
|
| 916 |
def prepare(self):
|
| 917 |
super().prepare()
|
|
@@ -919,9 +1319,7 @@ class CharEditDistanceAccuracy(InstanceMetric):
|
|
| 919 |
|
| 920 |
self.eval = editdistance.eval
|
| 921 |
|
| 922 |
-
def compute(
|
| 923 |
-
self, references, prediction: str, additional_inputs: List[Dict]
|
| 924 |
-
) -> dict:
|
| 925 |
assert (
|
| 926 |
len(references) == 1
|
| 927 |
), f"Expected only one reference , but received: {references}"
|
|
@@ -939,11 +1337,13 @@ class Wer(HuggingfaceMetric):
|
|
| 939 |
hf_metric_name = "wer"
|
| 940 |
main_score = "wer"
|
| 941 |
|
|
|
|
|
|
|
| 942 |
def compute(
|
| 943 |
self,
|
| 944 |
references: List[List[str]],
|
| 945 |
predictions: List[str],
|
| 946 |
-
|
| 947 |
) -> dict:
|
| 948 |
assert all(
|
| 949 |
len(reference) == 1 for reference in references
|
|
@@ -955,6 +1355,43 @@ class Wer(HuggingfaceMetric):
|
|
| 955 |
return {self.main_score: result}
|
| 956 |
|
| 957 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 958 |
class MatthewsCorrelation(HuggingfaceMetric):
|
| 959 |
hf_metric_name = "matthews_correlation"
|
| 960 |
main_score = "matthews_correlation"
|
|
@@ -970,7 +1407,7 @@ class MatthewsCorrelation(HuggingfaceMetric):
|
|
| 970 |
self,
|
| 971 |
references: List[List[str]],
|
| 972 |
predictions: List[str],
|
| 973 |
-
|
| 974 |
) -> dict:
|
| 975 |
formatted_references = [
|
| 976 |
self.get_str_id(reference[0]) for reference in references
|
|
@@ -983,6 +1420,33 @@ class MatthewsCorrelation(HuggingfaceMetric):
|
|
| 983 |
)
|
| 984 |
|
| 985 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 986 |
class CustomF1(GlobalMetric):
|
| 987 |
main_score = "f1_micro"
|
| 988 |
groups = None
|
|
@@ -1036,9 +1500,9 @@ class CustomF1(GlobalMetric):
|
|
| 1036 |
except ZeroDivisionError:
|
| 1037 |
return self.zero_division
|
| 1038 |
|
| 1039 |
-
def get_groups(self, elements,
|
| 1040 |
groups = set()
|
| 1041 |
-
for sublist, additional_input in zip(elements,
|
| 1042 |
for e in sublist:
|
| 1043 |
if self.should_ignore_element(e, additional_input):
|
| 1044 |
continue
|
|
@@ -1049,7 +1513,7 @@ class CustomF1(GlobalMetric):
|
|
| 1049 |
self,
|
| 1050 |
references: List[List[Any]],
|
| 1051 |
predictions: List[Any],
|
| 1052 |
-
|
| 1053 |
) -> dict:
|
| 1054 |
# in case reference are List[List[List[Any]]] and predictions are List[List[Any]]:
|
| 1055 |
if (
|
|
@@ -1065,12 +1529,12 @@ class CustomF1(GlobalMetric):
|
|
| 1065 |
)
|
| 1066 |
|
| 1067 |
if self.groups is None:
|
| 1068 |
-
groups = self.get_groups(references,
|
| 1069 |
else:
|
| 1070 |
groups = self.groups
|
| 1071 |
groups_statistics = {}
|
| 1072 |
for references_batch, predictions_batch, additional_input in zip(
|
| 1073 |
-
references, predictions,
|
| 1074 |
):
|
| 1075 |
grouped_references = self.group_elements(references_batch, additional_input)
|
| 1076 |
grouped_predictions = self.group_elements(
|
|
@@ -1187,10 +1651,11 @@ class TokenOverlap(InstanceMetric):
|
|
| 1187 |
ci_scores = ["f1", "precision", "recall"]
|
| 1188 |
|
| 1189 |
def compute(
|
| 1190 |
-
self, references: List[Any], prediction: Any,
|
| 1191 |
) -> dict:
|
| 1192 |
results = [
|
| 1193 |
-
self._compute_single_ref(reference, prediction)
|
|
|
|
| 1194 |
]
|
| 1195 |
return {
|
| 1196 |
measure: max(r[i] for r in results)
|
|
@@ -1200,8 +1665,8 @@ class TokenOverlap(InstanceMetric):
|
|
| 1200 |
def _compute_single_ref(
|
| 1201 |
self, reference: Any, prediction: Any
|
| 1202 |
) -> Tuple[float, float, float]:
|
| 1203 |
-
prediction_tokens = normalize_answer(prediction).split()
|
| 1204 |
-
reference_tokens = normalize_answer(reference).split()
|
| 1205 |
common = Counter(prediction_tokens) & Counter(reference_tokens)
|
| 1206 |
num_same = sum(common.values())
|
| 1207 |
if num_same == 0:
|
|
@@ -1221,9 +1686,11 @@ class BertScore(HuggingfaceBulkMetric):
|
|
| 1221 |
ci_scores = ["f1", "precision", "recall"]
|
| 1222 |
model_name: str
|
| 1223 |
|
|
|
|
|
|
|
| 1224 |
def prepare(self):
|
| 1225 |
super().prepare()
|
| 1226 |
-
self.hf_compute_args = {"model_type": self.model_name}
|
| 1227 |
|
| 1228 |
|
| 1229 |
class SentenceBert(BulkInstanceMetric):
|
|
@@ -1233,19 +1700,23 @@ class SentenceBert(BulkInstanceMetric):
|
|
| 1233 |
|
| 1234 |
model_name: str
|
| 1235 |
|
|
|
|
|
|
|
| 1236 |
def prepare(self):
|
| 1237 |
super().prepare()
|
|
|
|
| 1238 |
from sentence_transformers import SentenceTransformer
|
| 1239 |
from sentence_transformers import util as sbert_util
|
| 1240 |
|
| 1241 |
-
self.
|
|
|
|
| 1242 |
self.util = sbert_util
|
| 1243 |
|
| 1244 |
def compute(
|
| 1245 |
self,
|
| 1246 |
references: List[List[Any]],
|
| 1247 |
predictions: List[Any],
|
| 1248 |
-
|
| 1249 |
) -> List[Dict[str, Any]]:
|
| 1250 |
scores = []
|
| 1251 |
|
|
@@ -1260,9 +1731,9 @@ class SentenceBert(BulkInstanceMetric):
|
|
| 1260 |
count += len(ref_group)
|
| 1261 |
|
| 1262 |
# compute s-bert embeddings
|
| 1263 |
-
preds_emb = self.model.encode(predictions)
|
| 1264 |
refs_emb = self.model.encode(
|
| 1265 |
-
[ref for ref_group in references for ref in ref_group]
|
| 1266 |
)
|
| 1267 |
|
| 1268 |
# for each candidate, pick the reference with the highest score
|
|
@@ -1280,17 +1751,23 @@ class Reward(BulkInstanceMetric):
|
|
| 1280 |
|
| 1281 |
model_name: str
|
| 1282 |
|
|
|
|
|
|
|
| 1283 |
def prepare(self):
|
| 1284 |
super().prepare()
|
|
|
|
| 1285 |
from transformers import pipeline
|
| 1286 |
|
| 1287 |
-
|
|
|
|
|
|
|
|
|
|
| 1288 |
|
| 1289 |
def compute(
|
| 1290 |
self,
|
| 1291 |
references: List[List[Any]],
|
| 1292 |
predictions: List[Any],
|
| 1293 |
-
|
| 1294 |
) -> List[Dict[str, Any]]:
|
| 1295 |
# treat the references as the questions and the predictions as answers
|
| 1296 |
# assume a single reference
|
|
@@ -1316,25 +1793,27 @@ class Perplexity(BulkInstanceMetric):
|
|
| 1316 |
batch_size: int = 32
|
| 1317 |
model_name: str
|
| 1318 |
|
|
|
|
|
|
|
| 1319 |
def compute(
|
| 1320 |
self,
|
| 1321 |
references: List[List[Any]],
|
| 1322 |
predictions: List[Any],
|
| 1323 |
-
|
| 1324 |
) -> List[Dict[str, Any]]:
|
| 1325 |
"""Computes the likelihood of generating text Y after text X - P(Y|X).
|
| 1326 |
|
| 1327 |
-
:param
|
| 1328 |
-
:param
|
| 1329 |
|
| 1330 |
-
:return: the likelihood of generating text Y_i after text
|
| 1331 |
"""
|
| 1332 |
sources = []
|
| 1333 |
targets = []
|
| 1334 |
for prediction, instance_references in zip(predictions, references):
|
| 1335 |
for instance_reference in instance_references:
|
| 1336 |
-
sources.append(f"{self.perplexity_prompt} {
|
| 1337 |
-
targets.append(
|
| 1338 |
|
| 1339 |
from transformers import AutoConfig
|
| 1340 |
|
|
@@ -1375,9 +1854,11 @@ class Perplexity(BulkInstanceMetric):
|
|
| 1375 |
from transformers import AutoTokenizer
|
| 1376 |
|
| 1377 |
self.model_name = model_name
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1378 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
| 1379 |
-
self.model = self.model_class().from_pretrained(self.model_name)
|
| 1380 |
-
self.is_cuda = torch.cuda.is_available()
|
| 1381 |
|
| 1382 |
def compute_lm(
|
| 1383 |
self, source: List[str], target: List[str], batch_size: int
|
|
@@ -1470,16 +1951,9 @@ class Perplexity(BulkInstanceMetric):
|
|
| 1470 |
return AutoModelForSeq2SeqLM
|
| 1471 |
|
| 1472 |
def compute_batch(self, tokens_source, tokens_target):
|
| 1473 |
-
tokens_docs_ids = tokens_source["input_ids"]
|
| 1474 |
-
attention = tokens_source["attention_mask"]
|
| 1475 |
-
labels = tokens_target["input_ids"]
|
| 1476 |
-
|
| 1477 |
-
if self.is_cuda:
|
| 1478 |
-
tokens_docs_ids, attention, labels = (
|
| 1479 |
-
tokens_docs_ids.cuda(),
|
| 1480 |
-
attention.cuda(),
|
| 1481 |
-
labels.cuda(),
|
| 1482 |
-
)
|
| 1483 |
|
| 1484 |
logits = self.model(
|
| 1485 |
input_ids=tokens_docs_ids.long(),
|
|
@@ -1519,12 +1993,9 @@ class Perplexity(BulkInstanceMetric):
|
|
| 1519 |
# replace the padding token in the labels by -100
|
| 1520 |
labels[labels == self.tokenizer.pad_token_id] = -100
|
| 1521 |
|
| 1522 |
-
|
| 1523 |
-
|
| 1524 |
-
|
| 1525 |
-
attention.cuda(),
|
| 1526 |
-
labels.cuda(),
|
| 1527 |
-
)
|
| 1528 |
|
| 1529 |
# no need to pass labels as we calculate the loss below per document
|
| 1530 |
model_output = self.model(
|
|
@@ -1558,6 +2029,8 @@ class NDCG(GlobalMetric):
|
|
| 1558 |
|
| 1559 |
main_score = "nDCG"
|
| 1560 |
|
|
|
|
|
|
|
| 1561 |
def prepare(self):
|
| 1562 |
from sklearn.metrics import ndcg_score
|
| 1563 |
|
|
@@ -1568,15 +2041,12 @@ class NDCG(GlobalMetric):
|
|
| 1568 |
self,
|
| 1569 |
references: List[List[Any]],
|
| 1570 |
predictions: List[Any],
|
| 1571 |
-
|
| 1572 |
) -> dict:
|
| 1573 |
from collections import defaultdict
|
| 1574 |
-
from statistics import mean
|
| 1575 |
|
| 1576 |
query_to_predictions_and_references = defaultdict(lambda: [[], []])
|
| 1577 |
-
for reference, pred, inputs_dict in zip(
|
| 1578 |
-
references, predictions, additional_inputs
|
| 1579 |
-
):
|
| 1580 |
query = inputs_dict.get("query")
|
| 1581 |
query_to_predictions_and_references[query][0].append(pred)
|
| 1582 |
query_to_predictions_and_references[query][1].append(reference)
|
|
@@ -1606,9 +2076,7 @@ class NDCG(GlobalMetric):
|
|
| 1606 |
|
| 1607 |
|
| 1608 |
class RetrievalMetric(InstanceMetric):
|
| 1609 |
-
def compute(
|
| 1610 |
-
self, references: List[Any], prediction: Any, additional_inputs: Dict
|
| 1611 |
-
) -> dict:
|
| 1612 |
# digest input
|
| 1613 |
pred_ids: List[Any] = prediction
|
| 1614 |
ref_ids: List[Any] = list(dict.fromkeys(references))
|
|
@@ -1681,6 +2149,7 @@ class RetrievalMetric(InstanceMetric):
|
|
| 1681 |
class MRR(RetrievalMetric):
|
| 1682 |
reduction_map = {"mean": ["mrr"]}
|
| 1683 |
main_score = "mrr"
|
|
|
|
| 1684 |
|
| 1685 |
def _compute(
|
| 1686 |
self,
|
|
@@ -1697,6 +2166,7 @@ class MRR(RetrievalMetric):
|
|
| 1697 |
class MAP(RetrievalMetric):
|
| 1698 |
reduction_map = {"mean": ["map"]}
|
| 1699 |
main_score = "map"
|
|
|
|
| 1700 |
|
| 1701 |
def _compute(
|
| 1702 |
self,
|
|
@@ -1765,3 +2235,672 @@ class KPA(CustomF1):
|
|
| 1765 |
|
| 1766 |
def should_ignore_element(self, element, additional_input):
|
| 1767 |
return element == "none"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import re
|
| 2 |
import string
|
| 3 |
import uuid
|
| 4 |
+
import warnings
|
| 5 |
from abc import ABC, abstractmethod
|
| 6 |
from collections import Counter
|
| 7 |
+
from copy import deepcopy
|
| 8 |
from dataclasses import field
|
| 9 |
+
from statistics import mean
|
| 10 |
from typing import Any, Dict, Generator, List, Optional, Tuple
|
| 11 |
|
| 12 |
import evaluate
|
| 13 |
import numpy
|
| 14 |
import numpy as np
|
| 15 |
from scipy.stats import bootstrap
|
| 16 |
+
from scipy.stats._warnings_errors import DegenerateDataWarning
|
| 17 |
|
| 18 |
from .artifact import Artifact
|
| 19 |
from .dataclass import InternalField, OptionalField
|
| 20 |
from .logging_utils import get_logger
|
| 21 |
+
from .metric_utils import InstanceInput, MetricRequest, MetricResponse
|
| 22 |
from .operator import (
|
| 23 |
MultiStreamOperator,
|
| 24 |
SingleStreamOperator,
|
|
|
|
| 27 |
)
|
| 28 |
from .operators import CopyFields
|
| 29 |
from .random_utils import get_seed
|
| 30 |
+
from .settings_utils import get_settings
|
| 31 |
from .stream import MultiStream, Stream
|
| 32 |
+
from .type_utils import isoftype, to_float_or_default
|
| 33 |
|
| 34 |
logger = get_logger()
|
| 35 |
+
settings = get_settings()
|
| 36 |
+
|
| 37 |
+
warnings.filterwarnings("ignore", category=DegenerateDataWarning)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
warnings.filterwarnings("ignore", category=DegenerateDataWarning)
|
| 41 |
|
| 42 |
|
| 43 |
def abstract_factory():
|
|
|
|
| 48 |
return field(default_factory=abstract_factory)
|
| 49 |
|
| 50 |
|
| 51 |
+
def nan_mean(x):
|
| 52 |
+
import warnings
|
| 53 |
+
|
| 54 |
+
with warnings.catch_warnings():
|
| 55 |
+
# final mean should be mean of scores, ignoring NaN, hence nanmean
|
| 56 |
+
# but if the group function values is NaN for ALL values, nanmean throws a
|
| 57 |
+
# RuntimeWarning that it is calculating the mean of an empty slice (with no non-Nans)
|
| 58 |
+
# this is the desired behavior, but we want to avoid the warning here
|
| 59 |
+
warnings.simplefilter("ignore", category=RuntimeWarning)
|
| 60 |
+
return np.nanmean(x)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
class UpdateStream(StreamInstanceOperator):
|
| 64 |
update: dict
|
| 65 |
|
|
|
|
| 77 |
def main_score(self):
|
| 78 |
pass
|
| 79 |
|
| 80 |
+
def consume_stream(self, stream: Stream):
|
| 81 |
+
references = []
|
| 82 |
+
predictions = []
|
| 83 |
+
additional_inputs = []
|
| 84 |
+
instances = []
|
| 85 |
+
for instance in stream:
|
| 86 |
+
references.append(instance["references"])
|
| 87 |
+
predictions.append(instance["prediction"])
|
| 88 |
+
additional_inputs.append(
|
| 89 |
+
instance["additional_inputs"] if "additional_inputs" in instance else {}
|
| 90 |
+
)
|
| 91 |
+
instances.append(instance)
|
| 92 |
+
return predictions, references, additional_inputs, instances
|
| 93 |
+
|
| 94 |
+
@staticmethod
|
| 95 |
+
def update_instance_scores(instances, instances_scores: List[Dict[str, Any]]):
|
| 96 |
+
for instance, new_scores in zip(instances, instances_scores):
|
| 97 |
+
if "score" not in instance:
|
| 98 |
+
instance["score"] = {}
|
| 99 |
+
scores = instance["score"]
|
| 100 |
+
if "instance" not in scores:
|
| 101 |
+
scores["instance"] = {}
|
| 102 |
+
scores["instance"].update(new_scores)
|
| 103 |
+
|
| 104 |
+
@staticmethod
|
| 105 |
+
def set_global_score(instances, global_score: Dict[str, Any]):
|
| 106 |
+
for instance in instances:
|
| 107 |
+
if "score" not in instance:
|
| 108 |
+
instance["score"] = {}
|
| 109 |
+
scores = instance["score"]
|
| 110 |
+
if "global" not in scores:
|
| 111 |
+
scores["global"] = {}
|
| 112 |
+
scores["global"] = global_score
|
| 113 |
+
|
| 114 |
+
@abstractmethod
|
| 115 |
+
def disable_confidence_interval_calculation(self):
|
| 116 |
+
pass
|
| 117 |
+
|
| 118 |
+
@abstractmethod
|
| 119 |
+
def set_n_resamples(self, n_resample):
|
| 120 |
+
pass
|
| 121 |
+
|
| 122 |
|
| 123 |
class MetricWithConfidenceInterval(Metric):
|
| 124 |
# The number of resamples used to estimate the confidence intervals of this metric.
|
|
|
|
| 135 |
return np.random.default_rng(hash(get_seed()) & _max_32bit)
|
| 136 |
|
| 137 |
def disable_confidence_interval_calculation(self):
|
| 138 |
+
n = self.n_resamples
|
| 139 |
self.n_resamples = None
|
| 140 |
+
return n
|
| 141 |
+
|
| 142 |
+
def set_n_resamples(self, n_resamples):
|
| 143 |
+
self.n_resamples = n_resamples
|
| 144 |
|
| 145 |
def _can_compute_confidence_intervals(self, num_predictions):
|
| 146 |
return (
|
|
|
|
| 149 |
and num_predictions > 1
|
| 150 |
)
|
| 151 |
|
| 152 |
+
@staticmethod
|
| 153 |
+
def average_item_scores(instances: List[dict], score_name: str):
|
| 154 |
+
"""Calculate mean of a set of instance scores (given by score_name), omitting NaN values.
|
| 155 |
|
| 156 |
+
Args:
|
| 157 |
+
instances: list of dicts of each instance's instance scores.
|
| 158 |
+
score_name: score field names to compute the mean for.
|
|
|
|
| 159 |
"""
|
| 160 |
+
return nan_mean(
|
| 161 |
+
[instance["score"]["instance"][score_name] for instance in instances]
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
def score_based_confidence_interval(
|
| 165 |
+
self,
|
| 166 |
+
instances: List[dict],
|
| 167 |
+
score_names: List[str],
|
| 168 |
+
aggregation_func=None,
|
| 169 |
+
ci_score_prefix="",
|
| 170 |
+
):
|
| 171 |
+
"""Compute confidence intervals based on existing scores, already computed on the input instances.
|
| 172 |
+
|
| 173 |
+
Unlike GlobalMetric, this is simply a function of the instance scores (possibly taking into account task_data field),
|
| 174 |
+
so they don't need to be recomputed after every bootstrap draw.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
instances: The instances for which the confidence intervals are computed; should already have the relevant instance scores calculated.
|
| 178 |
+
score_names: List of instance score field names to compute a confidence interval for.
|
| 179 |
+
aggregation_func: A function with arguments instances, field_name; is applied on list of instances (which may include task_data
|
| 180 |
+
field, as well as the prediction and references), and the field_name; default is simply to take the mean field_name from
|
| 181 |
+
instances after resampling, if argument is None.
|
| 182 |
+
ci_score_prefix: An optional string prefix to the score_name in the CI. Useful in cases where the
|
| 183 |
+
aggregation_func is something other than the mean
|
| 184 |
|
| 185 |
+
Returns:
|
| 186 |
+
Dict of confidence interval values
|
| 187 |
+
"""
|
| 188 |
result = {}
|
| 189 |
|
| 190 |
if not self._can_compute_confidence_intervals(num_predictions=len(instances)):
|
| 191 |
return result
|
| 192 |
|
| 193 |
+
ci_score_prefix = str(ci_score_prefix)
|
| 194 |
+
if aggregation_func is None:
|
| 195 |
+
# if aggregation_func is None, we simply take the mean of the resampled instance scores
|
| 196 |
+
# otherwise, the aggregation_func needs to be applied AFTER resampling the instances;
|
| 197 |
+
# that is, re-form the groups, calculate the function, and take the mean of the group scores
|
| 198 |
+
aggregation_func = self.average_item_scores
|
| 199 |
for score_name in score_names:
|
| 200 |
+
# need to redefine the statistic function within the loop because score_name is a loop variable
|
| 201 |
+
def statistic(arr, axis, score_name=score_name):
|
| 202 |
+
# arr is a 2d array where each row is a resampling, so we
|
| 203 |
+
# iterate over the rows and compute the metric on each resampling
|
| 204 |
+
scores = numpy.apply_along_axis(
|
| 205 |
+
lambda resampled_instances: aggregation_func(
|
| 206 |
+
resampled_instances, score_name
|
| 207 |
+
),
|
| 208 |
+
axis=axis,
|
| 209 |
+
arr=arr,
|
| 210 |
+
)
|
| 211 |
+
return self.resample_from_non_nan(scores)
|
| 212 |
+
|
| 213 |
+
# apply bootstrap only on the relevant field
|
| 214 |
ci = bootstrap(
|
| 215 |
+
(instances,),
|
| 216 |
+
statistic=statistic,
|
| 217 |
n_resamples=self.n_resamples,
|
| 218 |
confidence_level=self.confidence_level,
|
| 219 |
random_state=self.new_random_generator(),
|
| 220 |
).confidence_interval
|
| 221 |
+
full_score_name = ci_score_prefix + score_name
|
| 222 |
+
result[f"{full_score_name}_ci_low"] = ci.low
|
| 223 |
+
result[f"{full_score_name}_ci_high"] = ci.high
|
| 224 |
if score_name == self.main_score:
|
| 225 |
result["score_ci_low"] = ci.low
|
| 226 |
result["score_ci_high"] = ci.high
|
| 227 |
return result
|
| 228 |
|
| 229 |
+
def resample_from_non_nan(self, values):
|
| 230 |
+
"""Given an array values, will replace any NaN values with elements resampled with replacement from the non-NaN ones.
|
| 231 |
+
|
| 232 |
+
here we deal with samples on which the metric could not be computed. These are
|
| 233 |
+
edge cases - for example, when the sample contains only empty strings.
|
| 234 |
+
CI is about the distribution around the statistic (e.g. mean), it doesn't deal with
|
| 235 |
+
cases in which the metric is not computable. Therefore, we ignore these edge cases
|
| 236 |
+
as part of the computation of CI.
|
| 237 |
+
|
| 238 |
+
In theory there would be several ways to deal with this:
|
| 239 |
+
1. skip the errors and return a shorter array => this fails because Scipy requires
|
| 240 |
+
this callback (i.e. the statistic() callback) to return an array of the same size
|
| 241 |
+
as the number of resamples
|
| 242 |
+
2. Put np.nan for the errors => this fails because in such case the ci itself
|
| 243 |
+
becomes np.nan. So one edge case can fail the whole CI computation.
|
| 244 |
+
3. Replace the errors with a sampling from the successful cases => this is what is implemented.
|
| 245 |
+
|
| 246 |
+
This resampling makes it so that, if possible, the bca confidence interval returned by bootstrap will not be NaN, since
|
| 247 |
+
bootstrap does not ignore NaNs. However, if there are 0 or 1 non-NaN values, or all non-NaN values are equal,
|
| 248 |
+
the resulting distribution will be degenerate (only one unique value) so the CI will still be NaN since there is
|
| 249 |
+
no variability. In this case, the CI is essentially an interval of length 0 equaling the mean itself.
|
| 250 |
+
"""
|
| 251 |
+
if values.size > 1:
|
| 252 |
+
error_indices = numpy.isnan(values)
|
| 253 |
+
n_errors = sum(error_indices)
|
| 254 |
+
if 0 < n_errors < values.size:
|
| 255 |
+
# replace NaN aggregate scores with random draws from non-NaN scores, so that confidence interval isn't NaN itself
|
| 256 |
+
values[error_indices] = self.new_random_generator().choice(
|
| 257 |
+
values[~error_indices], n_errors, replace=True
|
| 258 |
+
)
|
| 259 |
+
return values
|
| 260 |
+
|
| 261 |
def compute_global_confidence_intervals(
|
| 262 |
+
self, references, predictions, task_data, score_name
|
| 263 |
):
|
| 264 |
"""Computed confidence intervals for a set of references and predictions."""
|
| 265 |
random_gen = self.new_random_generator()
|
|
|
|
| 267 |
def statistic(arr, axis):
|
| 268 |
# arr is a 2d array where each row is a resampling, so we
|
| 269 |
# iterate over the rows and compute the metric on each resampling
|
| 270 |
+
def metric(sample_refs, sample_preds, sample_task_data):
|
| 271 |
try:
|
| 272 |
return self._compute(
|
| 273 |
references=sample_refs,
|
| 274 |
predictions=sample_preds,
|
| 275 |
+
task_data=sample_task_data,
|
| 276 |
)["score"]
|
| 277 |
except Exception as e:
|
| 278 |
# this happens in edge cases, for example, when the sampling creates a
|
|
|
|
| 280 |
logger.info(f"Warning in {self.__class__.__name__}", e)
|
| 281 |
return np.nan
|
| 282 |
|
| 283 |
+
# resample the instance scores, and then return the global score each time
|
| 284 |
scores = numpy.apply_along_axis(
|
| 285 |
lambda x: metric(
|
| 286 |
sample_refs=[references[i] for i in x],
|
| 287 |
sample_preds=[predictions[i] for i in x],
|
| 288 |
+
sample_task_data=[task_data[i] for i in x],
|
| 289 |
),
|
| 290 |
axis=axis,
|
| 291 |
arr=arr,
|
| 292 |
)
|
| 293 |
|
| 294 |
+
# in some resamplings of instances, the global score may be NaN since it cannot be computed;
|
| 295 |
+
# in these cases, the bca confidence interval will be NaN because it does not ignore these values,
|
| 296 |
+
# so we replace any NaN values with those resampled from the non-NaN ones.
|
| 297 |
+
return self.resample_from_non_nan(scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
result = {}
|
| 300 |
num_predictions = len(predictions)
|
|
|
|
| 322 |
need to be considered. Accuracy, on the other hand, is just an average of the accuracy of all the instances.
|
| 323 |
"""
|
| 324 |
|
| 325 |
+
n_resamples: int = OptionalField(
|
| 326 |
+
default_factory=lambda: settings.num_resamples_for_global_metrics
|
| 327 |
+
)
|
| 328 |
+
process_single_instances = True
|
| 329 |
|
| 330 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 331 |
references = []
|
| 332 |
predictions = []
|
| 333 |
+
task_data = []
|
| 334 |
global_score = {}
|
| 335 |
|
| 336 |
instances = []
|
|
|
|
| 349 |
predictions.append(instance_prediction)
|
| 350 |
instances.append(instance)
|
| 351 |
|
| 352 |
+
instance_task_data = (
|
| 353 |
+
instance["task_data"] if "task_data" in instance else {}
|
| 354 |
)
|
| 355 |
+
task_data.append(instance_task_data)
|
| 356 |
+
instance_score = None
|
| 357 |
+
# for backward compatibility
|
| 358 |
+
no_score_value = np.nan
|
| 359 |
+
if self.process_single_instances:
|
| 360 |
+
try:
|
| 361 |
+
instance_score = self._compute(
|
| 362 |
+
[instance_references],
|
| 363 |
+
[instance_prediction],
|
| 364 |
+
[instance_task_data],
|
| 365 |
+
)
|
| 366 |
+
except:
|
| 367 |
+
no_score_value = None
|
| 368 |
+
if not instance_score:
|
| 369 |
+
instance_score = {
|
| 370 |
+
"score": no_score_value,
|
| 371 |
+
"score_name": self.main_score,
|
| 372 |
+
}
|
| 373 |
|
| 374 |
if isinstance(self.main_score, str):
|
| 375 |
+
instance_score[self.main_score] = no_score_value
|
| 376 |
|
| 377 |
instance["score"]["instance"].update(instance_score)
|
| 378 |
|
| 379 |
+
result = self._compute(references, predictions, task_data)
|
| 380 |
|
| 381 |
global_score.update(result)
|
| 382 |
|
| 383 |
score_name = global_score["score_name"]
|
| 384 |
confidence_interval = self.compute_global_confidence_intervals(
|
| 385 |
+
references, predictions, task_data, score_name
|
| 386 |
)
|
| 387 |
global_score.update(confidence_interval)
|
| 388 |
|
|
|
|
| 394 |
self,
|
| 395 |
references: List[List[str]],
|
| 396 |
predictions: List[str],
|
| 397 |
+
task_data: List[Any],
|
| 398 |
) -> dict:
|
| 399 |
+
result = self.compute(references, predictions, task_data)
|
| 400 |
result["score"] = result[self.main_score]
|
| 401 |
result["score_name"] = self.main_score
|
| 402 |
return result
|
|
|
|
| 406 |
self,
|
| 407 |
references: List[List[Any]],
|
| 408 |
predictions: List[Any],
|
| 409 |
+
task_data: List[Any],
|
| 410 |
) -> dict:
|
| 411 |
+
"""Computes a scores dictionary on a list of references, predictions and input.
|
| 412 |
+
|
| 413 |
+
This function is called once per instance, and then another time
|
| 414 |
+
over all data instances.
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
a dictionary of scores that is set as:
|
| 418 |
+
the instance scores when called on a single data instance
|
| 419 |
+
the global score when called on the all data instances
|
| 420 |
+
"""
|
| 421 |
pass
|
| 422 |
|
| 423 |
|
| 424 |
class BulkInstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
| 425 |
+
n_resamples: int = OptionalField(
|
| 426 |
+
default_factory=lambda: settings.num_resamples_for_instance_metrics
|
| 427 |
+
)
|
| 428 |
main_score: str
|
| 429 |
reduction_map: Dict[str, List[str]]
|
| 430 |
|
|
|
|
| 445 |
),
|
| 446 |
)
|
| 447 |
|
| 448 |
+
task_data = [
|
| 449 |
+
instance["task_data"] if "task_data" in instance else {}
|
| 450 |
for instance in stream
|
| 451 |
]
|
| 452 |
|
|
|
|
| 454 |
instance_scores = self.compute(
|
| 455 |
references=references,
|
| 456 |
predictions=predictions,
|
| 457 |
+
task_data=task_data,
|
| 458 |
)
|
| 459 |
|
| 460 |
# add the score and score_name fields
|
|
|
|
| 478 |
), f"Reduction {reduction} is not implemented, use one of {self.implemented_reductions}"
|
| 479 |
|
| 480 |
if reduction == "mean":
|
|
|
|
|
|
|
| 481 |
for field_name in fields:
|
| 482 |
global_score[field_name] = mean(
|
| 483 |
[
|
|
|
|
| 489 |
global_score["score"] = global_score[field_name]
|
| 490 |
global_score["score_name"] = self.main_score
|
| 491 |
|
| 492 |
+
ci_fields = (
|
| 493 |
+
list(set(self.ci_scores))
|
| 494 |
+
if self.ci_scores is not None
|
| 495 |
+
else [self.main_score]
|
| 496 |
+
)
|
| 497 |
confidence_interval = self.score_based_confidence_interval(
|
| 498 |
+
instances=instances, score_names=ci_fields
|
| 499 |
)
|
| 500 |
global_score.update(confidence_interval)
|
| 501 |
|
|
|
|
| 507 |
self,
|
| 508 |
references: List[List[Any]],
|
| 509 |
predictions: List[Any],
|
| 510 |
+
task_data: List[Dict],
|
| 511 |
) -> List[Dict[str, Any]]:
|
| 512 |
pass
|
| 513 |
|
| 514 |
|
| 515 |
class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
|
| 516 |
+
"""Class for metrics for which a global score can be calculated by aggregating the instance scores (possibly with additional instance inputs).
|
| 517 |
+
|
| 518 |
+
InstanceMetric currently allows two reductions:
|
| 519 |
+
1. 'mean', which calculates the mean of instance scores,
|
| 520 |
+
2. 'group_mean', which first applies an aggregation function specified in the reduction_map
|
| 521 |
+
to instance scores grouped by the field grouping_field (which must not be None), and returns the mean
|
| 522 |
+
of the group scores; if grouping_field is None, grouping is disabled.
|
| 523 |
+
See _validate_group_mean_reduction for formatting instructions.
|
| 524 |
+
"""
|
| 525 |
|
| 526 |
+
n_resamples: int = OptionalField(
|
| 527 |
+
default_factory=lambda: settings.num_resamples_for_instance_metrics
|
| 528 |
+
)
|
| 529 |
+
|
| 530 |
+
# some group_mean aggregation functions (3rd element of "agg_func" list in the reduction)
|
| 531 |
+
# only require a list of instance scores (e.g., mean, median, etc.). Others aggregation functions
|
| 532 |
+
# require an additional column (e.g., a subgroup identifier) by which the instance scores will be grouped
|
| 533 |
+
# if subgroup_column is not None, a column by the specified name will be required in task_data
|
| 534 |
+
subgroup_column = None
|
| 535 |
+
implemented_reductions: List[str] = field(
|
| 536 |
+
default_factory=lambda: ["mean", "group_mean"]
|
| 537 |
+
)
|
| 538 |
|
| 539 |
@property
|
| 540 |
@abstractmethod
|
| 541 |
def reduction_map(self) -> dict:
|
| 542 |
pass
|
| 543 |
|
| 544 |
+
def _validate_group_mean_reduction(self, instances: List[dict]):
|
| 545 |
+
"""Ensure that group_mean reduction_map is properly formatted.
|
| 546 |
+
|
| 547 |
+
Example: Apply the variance (np.var) to group Accuracy instance scores. This class would be specified as follows:
|
| 548 |
+
|
| 549 |
+
class GroupVarianceAccuracy(Accuracy):
|
| 550 |
+
reduction_map = {'group_mean': {'agg_func': ['variance', np.var, True]}}
|
| 551 |
+
|
| 552 |
+
reduction_map must be a dict with values containing
|
| 553 |
+
- an 'agg_func' field with value being a 3-element list where
|
| 554 |
+
- 1st element is a string name of the aggregation function (used in naming the CI report)
|
| 555 |
+
- 2nd element is the callable aggregation function
|
| 556 |
+
- 3rd element is a Boolean indicator of whether, during boostrap CI calculation, the groups are to be sampled as single units.
|
| 557 |
+
If True, the group scores are calculated and then resampled. This treats the group units as the unit of
|
| 558 |
+
interest for which the CI is being compared.
|
| 559 |
+
If False, the instances are resampled individually, and the groups determined
|
| 560 |
+
(meaning the groups may be of slightly different size or composition from the original
|
| 561 |
+
depending on the resampling of the instances).
|
| 562 |
+
- Optional: 'score_fields' key with list value containing the string names of fields to apply the aggregation to
|
| 563 |
+
- If not present, the parent class main_score is used.
|
| 564 |
+
|
| 565 |
+
The aggregation function (2nd element of agg_func) can be one of two types:
|
| 566 |
+
1. simple: calculate a summary statistic from a single group of values (e.g. mean, median, etc.).
|
| 567 |
+
This is best suited for cases where the instances are independent of each other, other than belonging to the same group
|
| 568 |
+
2. comparison: requires subgroup_column to be specified. This function conducts
|
| 569 |
+
a comparison between scores for differing values of subgroup_column (e.g., 'original' vs 'paraphrase').
|
| 570 |
+
An example is where the original instance is a question, and the others are various paraphrases
|
| 571 |
+
or perturbations of this question. Here, the function would return, say, a comparison of the instance accuracies
|
| 572 |
+
rather than, say, the average instance accuracy.
|
| 573 |
+
In these cases, we recommend setting the 3rd parameter to be True so that the groups are resampled together.
|
| 574 |
+
|
| 575 |
+
Example:
|
| 576 |
+
class GroupVsBaselineDiffAccuracy(Accuracy):
|
| 577 |
+
subgroup_column = 'variant_type'
|
| 578 |
+
reduction_map = {'group_mean': {'agg_func': ['accuracy_diff', accuracy_diff, True],}}
|
| 579 |
+
|
| 580 |
+
# where the function is defined as
|
| 581 |
+
def accuracy_diff(subgroup_scores_dict, expected_subgroup_types=['original', 'paraphrase']):
|
| 582 |
+
validate_subgroup_types(subgroup_scores_dict, expected_subgroup_types)
|
| 583 |
+
from statistics import mean
|
| 584 |
+
return mean(subgroup_scores_dict['paraphrase']) - mean(subgroup_scores_dict['original'])
|
| 585 |
+
The input dataset should look like:
|
| 586 |
+
|
| 587 |
+
'group_id' 'question' 'variant_type'
|
| 588 |
+
1 'How do you fix a car engine?' 'original'
|
| 589 |
+
1 'What is the best way to fix an engine?' 'paraphrase'
|
| 590 |
+
1 'How do you repair a car engine?' 'paraphrase'
|
| 591 |
+
1 'How do I repair my engine?' 'paraphrase'
|
| 592 |
+
2 'Why are ants eating my food?' 'original'
|
| 593 |
+
"""
|
| 594 |
+
# instances need to all have task_data field with field group_id
|
| 595 |
+
assert all(
|
| 596 |
+
"task_data" in instance for instance in instances
|
| 597 |
+
), "each instance must have an task_data field"
|
| 598 |
+
assert all(
|
| 599 |
+
isinstance(instance["task_data"], dict) for instance in instances
|
| 600 |
+
), "each instance must have an task_data field that is a dict"
|
| 601 |
+
assert all(
|
| 602 |
+
"group_id" in instance["task_data"] for instance in instances
|
| 603 |
+
), "each instance task_data dict must have a key group_id"
|
| 604 |
+
|
| 605 |
+
# validate the reduction_map
|
| 606 |
+
assert (
|
| 607 |
+
"group_mean" in self.reduction_map
|
| 608 |
+
), "reduction_map must have a 'group_mean' key"
|
| 609 |
+
fields = self.reduction_map["group_mean"]
|
| 610 |
+
# for group_mean, expects a dict
|
| 611 |
+
assert isinstance(fields, dict)
|
| 612 |
+
assert (
|
| 613 |
+
"agg_func" in fields
|
| 614 |
+
), "fields should have a key 'agg_func' whose value is a 3-element list of a function name, function definition, and a boolean indicator"
|
| 615 |
+
assert isinstance(
|
| 616 |
+
fields["agg_func"], list
|
| 617 |
+
), "fields['agg_func'] should be a list"
|
| 618 |
+
assert (
|
| 619 |
+
len(fields["agg_func"]) == 3
|
| 620 |
+
), "fields['agg_func'] should be a 3-element list"
|
| 621 |
+
assert isinstance(
|
| 622 |
+
fields["agg_func"][0], str
|
| 623 |
+
), "first item in fields['agg_func'] should be a string name of a function"
|
| 624 |
+
assert callable(
|
| 625 |
+
fields["agg_func"][1]
|
| 626 |
+
), "second item in fields['agg_func'] should be a callable function"
|
| 627 |
+
assert isinstance(
|
| 628 |
+
fields["agg_func"][2], bool
|
| 629 |
+
), "third item in fields['agg_func'] should be a boolean value"
|
| 630 |
+
if "score_fields" in fields:
|
| 631 |
+
assert isinstance(fields["score_fields"], list)
|
| 632 |
+
|
| 633 |
+
# for aggregation functions that use the subgroup_column (expect a dict of lists), check that
|
| 634 |
+
# this field exists
|
| 635 |
+
if self.subgroup_column is not None:
|
| 636 |
+
assert all(
|
| 637 |
+
self.subgroup_column in instance["task_data"] for instance in instances
|
| 638 |
+
), f"each instance task_data dict must have a key {self.subgroup_column}"
|
| 639 |
+
|
| 640 |
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 641 |
+
instances, global_score = self.compute_instance_scores(stream)
|
| 642 |
+
|
| 643 |
+
for reduction_type, reduction_params in self.reduction_map.items():
|
| 644 |
+
assert (
|
| 645 |
+
reduction_type in self.implemented_reductions
|
| 646 |
+
), f"Reduction {reduction_type} is not implemented, use one of {self.implemented_reductions}"
|
| 647 |
+
|
| 648 |
+
field_name_full_prefix = ""
|
| 649 |
+
# used for passing to the bootstrapping, depends on whether the groups are fixed or not
|
| 650 |
+
aggregation_function = self.average_item_scores
|
| 651 |
+
if reduction_type == "mean":
|
| 652 |
+
reduction_fields = list(set(reduction_params))
|
| 653 |
+
# no group reduction, so resample instances individually
|
| 654 |
+
scores_to_resample = instances
|
| 655 |
+
elif reduction_type == "group_mean":
|
| 656 |
+
self._validate_group_mean_reduction(instances=instances)
|
| 657 |
+
reduction_fields = (
|
| 658 |
+
[self.main_score]
|
| 659 |
+
if "score_fields" not in reduction_params
|
| 660 |
+
else list(set(reduction_params["score_fields"]))
|
| 661 |
+
)
|
| 662 |
+
aggregation_function_name = str(reduction_params["agg_func"][0])
|
| 663 |
+
field_name_full_prefix = "group_" + aggregation_function_name + "_"
|
| 664 |
+
do_resample_as_group = reduction_params["agg_func"][2]
|
| 665 |
+
if do_resample_as_group:
|
| 666 |
+
# append fixed_ to name because resamples the groups as fixed units
|
| 667 |
+
field_name_full_prefix = "fixed_" + field_name_full_prefix
|
| 668 |
+
(
|
| 669 |
+
scores_to_resample,
|
| 670 |
+
aggregation_function,
|
| 671 |
+
) = self._set_up_group_mean_aggregation(
|
| 672 |
+
instances, reduction_params, reduction_fields
|
| 673 |
+
)
|
| 674 |
+
else:
|
| 675 |
+
raise ValueError(
|
| 676 |
+
f"Reduction {reduction_type} is not supported, please specify a valid reduction method in reduction_map {self.reduction_map}."
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
# calculate global scores for each reduction field
|
| 680 |
+
for field_name in reduction_fields:
|
| 681 |
+
field_name_full = field_name_full_prefix + field_name
|
| 682 |
+
# if group resampling (3rd element of agg_func parameter) is True, then
|
| 683 |
+
# 1. scores_to_resample are the group scores, and
|
| 684 |
+
# 2. aggregation_function is to take the raw mean
|
| 685 |
+
# if no group resampling (3rd element of agg_func parameter) is False, then
|
| 686 |
+
# 1. scores_to_resample are the original instance scores, and
|
| 687 |
+
# 2. aggregation_function is to apply the group aggregation from the instance scores
|
| 688 |
+
# either way, the application of aggregation_function to scores_to_resample yields the global score
|
| 689 |
+
global_score[field_name_full] = aggregation_function(
|
| 690 |
+
scores_to_resample, field_name
|
| 691 |
+
)
|
| 692 |
+
if field_name == self.main_score:
|
| 693 |
+
global_score["score"] = global_score[field_name_full]
|
| 694 |
+
global_score["score_name"] = field_name_full
|
| 695 |
+
|
| 696 |
+
# need to specify which fields should have CIs calculated for them through ci_scores
|
| 697 |
+
# (will not automatically calculate CIs for fields in reduction map)
|
| 698 |
+
if self.ci_scores is not None:
|
| 699 |
+
confidence_interval = self.score_based_confidence_interval(
|
| 700 |
+
instances=scores_to_resample,
|
| 701 |
+
score_names=list(set(self.ci_scores)),
|
| 702 |
+
ci_score_prefix=field_name_full_prefix,
|
| 703 |
+
aggregation_func=aggregation_function,
|
| 704 |
+
)
|
| 705 |
+
global_score.update(confidence_interval)
|
| 706 |
+
|
| 707 |
+
yield from instances
|
| 708 |
+
|
| 709 |
+
def compute_instance_scores(
|
| 710 |
+
self, stream: Stream, stream_name: Optional[str] = None
|
| 711 |
+
):
|
| 712 |
global_score = {}
|
| 713 |
instances = []
|
| 714 |
|
| 715 |
for instance in stream:
|
| 716 |
refs, pred = instance["references"], instance["prediction"]
|
| 717 |
+
task_data = instance["task_data"] if "task_data" in instance else {}
|
|
|
|
|
|
|
| 718 |
|
| 719 |
instance_score = self.compute(
|
| 720 |
+
references=refs, prediction=pred, task_data=task_data
|
| 721 |
)
|
| 722 |
instance_score["score"] = instance_score[self.main_score]
|
| 723 |
instance_score["score_name"] = self.main_score
|
|
|
|
| 730 |
|
| 731 |
instances.append(instance)
|
| 732 |
|
| 733 |
+
return instances, global_score
|
|
|
|
|
|
|
|
|
|
| 734 |
|
| 735 |
+
def get_group_scores(
|
| 736 |
+
self, instances: List[dict], score_names: List[str], group_aggregation_func
|
| 737 |
+
):
|
| 738 |
+
"""Group scores by the group_id and subgroup_type fields of each instance, and compute group_aggregation_func by group.
|
| 739 |
+
|
| 740 |
+
Args:
|
| 741 |
+
instances: List of observation instances with instance-level scores (fields) computed.
|
| 742 |
+
score_names: List of instance score names in each instance to apply the aggregation function.
|
| 743 |
+
group_aggregation_func: Callable aggregation function accepting a list of numeric scores;
|
| 744 |
+
or, if self.subgroup_column is not None, a dict of subgroup types scores by subgroup_column value.
|
| 745 |
+
callable function returns a single score for the group
|
| 746 |
+
|
| 747 |
+
Returns:
|
| 748 |
+
List of dicts, each corresponding to a group of instances (defined by 'group_id'),
|
| 749 |
+
with an aggregate group score for each score_name
|
| 750 |
+
"""
|
| 751 |
+
from collections import defaultdict
|
| 752 |
|
| 753 |
+
# three-level defaultdict:
|
| 754 |
+
# first is the grouping, second is the field name, the third is the subgroup_type (by default 'default')
|
| 755 |
+
group_to_instance_scores = defaultdict(
|
| 756 |
+
lambda: defaultdict(lambda: defaultdict(list))
|
| 757 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
|
| 759 |
+
# check if function has fields for subgroup_column
|
| 760 |
+
uses_subgroups = self.subgroup_column is not None
|
| 761 |
+
default_subgroup_name = "default"
|
| 762 |
+
# loop through the instances and group the scores
|
| 763 |
+
for instance in instances:
|
| 764 |
+
task_data = instance["task_data"]
|
| 765 |
+
group_key = task_data["group_id"]
|
| 766 |
+
# for functions that do comparisons between subgroup_column groups
|
| 767 |
+
# if function doesn't use subgroup_column, or none is present, set "default" as default value, and pass all scores
|
| 768 |
+
subgroup_type = (
|
| 769 |
+
task_data[self.subgroup_column]
|
| 770 |
+
if uses_subgroups
|
| 771 |
+
else default_subgroup_name
|
| 772 |
+
)
|
| 773 |
+
for score_name in score_names:
|
| 774 |
+
group_to_instance_scores[group_key][score_name][subgroup_type].append(
|
| 775 |
+
instance["score"]["instance"][score_name]
|
| 776 |
)
|
|
|
|
| 777 |
|
| 778 |
+
# if group_aggregation_func expects a subgroup-types score dict, pass it; otherwise pass the default type list of scores
|
| 779 |
+
return [
|
| 780 |
+
{
|
| 781 |
+
"score": {
|
| 782 |
+
"instance": {
|
| 783 |
+
score_name: group_aggregation_func(
|
| 784 |
+
score_dict
|
| 785 |
+
if uses_subgroups
|
| 786 |
+
else score_dict[default_subgroup_name]
|
| 787 |
+
)
|
| 788 |
+
for score_name, score_dict in group_scores.items()
|
| 789 |
+
}
|
| 790 |
+
}
|
| 791 |
+
}
|
| 792 |
+
for group_scores in group_to_instance_scores.values()
|
| 793 |
+
]
|
| 794 |
+
|
| 795 |
+
def _set_up_group_mean_aggregation(
|
| 796 |
+
self, instances, reduction_params, reduction_fields
|
| 797 |
+
):
|
| 798 |
+
group_aggregation_func = reduction_params["agg_func"][1]
|
| 799 |
+
# if treat groups as units
|
| 800 |
+
do_resample_as_group = reduction_params["agg_func"][2]
|
| 801 |
+
if do_resample_as_group:
|
| 802 |
+
# pass the group aggregate---not instance---scores to resample as usual
|
| 803 |
+
aggregation_function = self.average_item_scores
|
| 804 |
+
scores_to_resample = self.get_group_scores(
|
| 805 |
+
instances, reduction_fields, group_aggregation_func
|
| 806 |
+
)
|
| 807 |
+
else:
|
| 808 |
+
# pass the instance scores to resample, and calculate the group aggregation on the resamplings
|
| 809 |
+
scores_to_resample = instances
|
| 810 |
+
|
| 811 |
+
def aggregation_function(
|
| 812 |
+
instances,
|
| 813 |
+
field_name,
|
| 814 |
+
group_aggregation_func=group_aggregation_func,
|
| 815 |
+
):
|
| 816 |
+
group_scores = self.get_group_scores(
|
| 817 |
+
instances, [field_name], group_aggregation_func
|
| 818 |
+
)
|
| 819 |
+
return nan_mean(
|
| 820 |
+
[group["score"]["instance"][field_name] for group in group_scores]
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
return scores_to_resample, aggregation_function
|
| 824 |
|
| 825 |
@abstractmethod
|
| 826 |
+
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
|
|
|
|
|
|
|
| 827 |
pass
|
| 828 |
|
| 829 |
|
|
|
|
| 840 |
self,
|
| 841 |
references: List[List[str]],
|
| 842 |
predictions: List[str],
|
| 843 |
+
task_data: List[Dict],
|
| 844 |
) -> dict:
|
| 845 |
ids = [str(uuid.uuid4()).replace("-", "") for _ in range(len(predictions))]
|
| 846 |
formatted_predictions = [
|
|
|
|
| 861 |
class Accuracy(InstanceMetric):
|
| 862 |
reduction_map = {"mean": ["accuracy"]}
|
| 863 |
main_score = "accuracy"
|
| 864 |
+
ci_scores = ["accuracy"]
|
| 865 |
|
| 866 |
def compute(
|
| 867 |
+
self, references: List[Any], prediction: Any, task_data: List[Dict]
|
| 868 |
) -> dict:
|
| 869 |
result = {
|
| 870 |
self.main_score: float(
|
|
|
|
| 879 |
class StringContainment(InstanceMetric):
|
| 880 |
reduction_map = {"mean": ["string_containment"]}
|
| 881 |
main_score = "string_containment"
|
| 882 |
+
ci_scores = ["string_containment"]
|
| 883 |
|
| 884 |
def compute(
|
| 885 |
+
self, references: List[Any], prediction: Any, task_data: List[Dict]
|
| 886 |
) -> dict:
|
| 887 |
result = {
|
| 888 |
self.main_score: float(
|
| 889 |
+
any(str(reference) in str(prediction) for reference in references)
|
| 890 |
)
|
| 891 |
}
|
| 892 |
result["score"] = result[self.main_score]
|
|
|
|
| 902 |
)
|
| 903 |
metric: Metric = None
|
| 904 |
|
| 905 |
+
def disable_confidence_interval_calculation(self):
|
| 906 |
+
return self.metric.disable_confidence_interval_calculation()
|
| 907 |
+
|
| 908 |
+
def set_n_resamples(self, n_resample):
|
| 909 |
+
if isinstance(self.metric, MetricWithConfidenceInterval):
|
| 910 |
+
self.metric.set_n_resamples(n_resample)
|
| 911 |
+
|
| 912 |
def verify(self):
|
| 913 |
assert self.main_score is not None, "main_score is not set"
|
| 914 |
|
|
|
|
| 973 |
self,
|
| 974 |
references: List[List[Any]],
|
| 975 |
predictions: List[Any],
|
| 976 |
+
task_data: List[Dict],
|
| 977 |
) -> dict:
|
| 978 |
+
passed_task_data = {}
|
| 979 |
for additional_input_field in self.hf_additional_input_fields:
|
| 980 |
assert (
|
| 981 |
+
additional_input_field in task_data[0]
|
| 982 |
+
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
|
| 983 |
+
passed_task_data[additional_input_field] = [
|
| 984 |
additional_input[additional_input_field]
|
| 985 |
+
for additional_input in task_data
|
| 986 |
]
|
| 987 |
for additional_input_field in self.hf_additional_input_fields_pass_one_value:
|
| 988 |
assert (
|
| 989 |
+
additional_input_field in task_data[0]
|
| 990 |
+
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
|
| 991 |
|
| 992 |
values = {
|
| 993 |
additional_input[additional_input_field]
|
| 994 |
+
for additional_input in task_data
|
| 995 |
}
|
| 996 |
assert (
|
| 997 |
len(values) == 1
|
| 998 |
), f"Values of '{additional_input_field}' field required by {__class__.__name__} should all be the same, but have multiple values {values}"
|
| 999 |
|
| 1000 |
+
passed_task_data[additional_input_field] = next(iter(values))
|
| 1001 |
|
| 1002 |
+
# add check that all required fields in self.metrics are in passed_task_data print(passed_task_data)
|
| 1003 |
result = self.metric.compute(
|
| 1004 |
predictions=predictions,
|
| 1005 |
references=references,
|
| 1006 |
+
**passed_task_data,
|
| 1007 |
**self.hf_compute_args,
|
| 1008 |
)
|
| 1009 |
if self.hf_main_score:
|
|
|
|
| 1045 |
self,
|
| 1046 |
references: List[List[str]],
|
| 1047 |
predictions: List[str],
|
| 1048 |
+
task_data: List[Any],
|
| 1049 |
) -> List[Dict[str, Any]]:
|
| 1050 |
+
passed_task_data = {}
|
| 1051 |
for additional_input_field in self.hf_additional_input_fields:
|
| 1052 |
assert (
|
| 1053 |
+
additional_input_field in task_data[0]
|
| 1054 |
+
), f"'{additional_input_field}' field required by {__class__.__name__} is not in passed in task_data: {task_data[0]}"
|
| 1055 |
+
passed_task_data[additional_input_field] = [
|
| 1056 |
additional_input[additional_input_field]
|
| 1057 |
+
for additional_input in task_data
|
| 1058 |
]
|
| 1059 |
+
# add check that all required fields in self.metrics are in passed_task_data
|
| 1060 |
|
| 1061 |
scores = self.metric.compute(
|
| 1062 |
predictions=predictions,
|
| 1063 |
references=references,
|
| 1064 |
+
**passed_task_data,
|
| 1065 |
**self.hf_compute_args,
|
| 1066 |
)
|
| 1067 |
|
|
|
|
| 1096 |
self,
|
| 1097 |
references: List[List[str]],
|
| 1098 |
predictions: List[str],
|
| 1099 |
+
task_data: List[Dict],
|
| 1100 |
) -> dict:
|
| 1101 |
assert all(
|
| 1102 |
len(reference) == 1 for reference in references
|
|
|
|
| 1118 |
average=self.average,
|
| 1119 |
)
|
| 1120 |
if isinstance(result["f1"], numpy.ndarray):
|
|
|
|
|
|
|
| 1121 |
final_result = {self.main_score: mean(result["f1"])}
|
| 1122 |
for i, label in enumerate(labels):
|
| 1123 |
final_result["f1_" + self.id_to_str[label]] = result["f1"][i]
|
|
|
|
| 1144 |
_metric = None
|
| 1145 |
main_score = "f1_macro"
|
| 1146 |
average = None # Report per class then aggregate by mean
|
|
|
|
| 1147 |
metric = "f1"
|
| 1148 |
|
| 1149 |
def prepare(self):
|
|
|
|
| 1168 |
self,
|
| 1169 |
references: List[List[str]],
|
| 1170 |
predictions: List[List[str]],
|
| 1171 |
+
task_data: List[Dict],
|
| 1172 |
) -> dict:
|
| 1173 |
self.str_to_id = {}
|
| 1174 |
self.id_to_str = {}
|
|
|
|
| 1176 |
self._validate_references_and_prediction(references, predictions)
|
| 1177 |
references = [reference[0] for reference in references]
|
| 1178 |
|
| 1179 |
+
labels = list({label for reference in references for label in reference})
|
| 1180 |
+
|
|
|
|
|
|
|
|
|
|
| 1181 |
# if no classes are left then F1 is not defined
|
|
|
|
| 1182 |
if len(labels) == 0:
|
| 1183 |
return {self.main_score: float("nan")}
|
| 1184 |
|
|
|
|
| 1206 |
labels=labels_param,
|
| 1207 |
)
|
| 1208 |
if isinstance(result[self.metric], numpy.ndarray):
|
|
|
|
|
|
|
| 1209 |
assert (
|
| 1210 |
len(result[self.metric]) == len(labels)
|
| 1211 |
), f"F1 result ({result[self.metric]}) has more entries than labels ({labels})"
|
|
|
|
| 1278 |
|
| 1279 |
sent_split_newline: bool = True
|
| 1280 |
|
| 1281 |
+
_requirements_list: List[str] = ["nltk", "rouge_score"]
|
| 1282 |
+
|
| 1283 |
def prepare(self):
|
| 1284 |
super().prepare()
|
| 1285 |
|
|
|
|
| 1292 |
nltk.download("punkt")
|
| 1293 |
self.sent_tokenize = nltk.sent_tokenize
|
| 1294 |
|
| 1295 |
+
def compute(self, references, predictions, task_data: List[Dict]):
|
| 1296 |
if self.sent_split_newline:
|
| 1297 |
predictions = [
|
| 1298 |
"\n".join(self.sent_tokenize(prediction.strip()))
|
|
|
|
| 1302 |
["\n".join(self.sent_tokenize(r.strip())) for r in reference]
|
| 1303 |
for reference in references
|
| 1304 |
]
|
| 1305 |
+
return super().compute(references, predictions, task_data)
|
| 1306 |
|
| 1307 |
|
| 1308 |
# Computes char edit distance, ignoring whitespace
|
| 1309 |
class CharEditDistanceAccuracy(InstanceMetric):
|
| 1310 |
reduction_map = {"mean": ["char_edit_dist_accuracy"]}
|
| 1311 |
main_score = "char_edit_dist_accuracy"
|
| 1312 |
+
ci_scores = ["char_edit_dist_accuracy"]
|
| 1313 |
+
|
| 1314 |
+
_requirements_list: List[str] = ["editdistance"]
|
| 1315 |
|
| 1316 |
def prepare(self):
|
| 1317 |
super().prepare()
|
|
|
|
| 1319 |
|
| 1320 |
self.eval = editdistance.eval
|
| 1321 |
|
| 1322 |
+
def compute(self, references, prediction: str, task_data: List[Dict]) -> dict:
|
|
|
|
|
|
|
| 1323 |
assert (
|
| 1324 |
len(references) == 1
|
| 1325 |
), f"Expected only one reference , but received: {references}"
|
|
|
|
| 1337 |
hf_metric_name = "wer"
|
| 1338 |
main_score = "wer"
|
| 1339 |
|
| 1340 |
+
_requirements_list: List[str] = ["jiwer"]
|
| 1341 |
+
|
| 1342 |
def compute(
|
| 1343 |
self,
|
| 1344 |
references: List[List[str]],
|
| 1345 |
predictions: List[str],
|
| 1346 |
+
task_data: List[Dict],
|
| 1347 |
) -> dict:
|
| 1348 |
assert all(
|
| 1349 |
len(reference) == 1 for reference in references
|
|
|
|
| 1355 |
return {self.main_score: result}
|
| 1356 |
|
| 1357 |
|
| 1358 |
+
class Spearmanr(HuggingfaceMetric):
|
| 1359 |
+
hf_metric_name = "spearmanr"
|
| 1360 |
+
main_score = "spearmanr"
|
| 1361 |
+
process_single_instances = False
|
| 1362 |
+
|
| 1363 |
+
|
| 1364 |
+
class KendallTauMetric(GlobalMetric):
|
| 1365 |
+
main_score = "kendalltau_b"
|
| 1366 |
+
variant = "b"
|
| 1367 |
+
process_single_instances = False
|
| 1368 |
+
|
| 1369 |
+
_requirements_list: List[str] = ["scipy"]
|
| 1370 |
+
|
| 1371 |
+
def prepare(self):
|
| 1372 |
+
from scipy.stats import kendalltau
|
| 1373 |
+
|
| 1374 |
+
self.kendalltau = kendalltau
|
| 1375 |
+
|
| 1376 |
+
def compute(
|
| 1377 |
+
self,
|
| 1378 |
+
references: List[List[str]],
|
| 1379 |
+
predictions: List[str],
|
| 1380 |
+
task_data: List[Dict],
|
| 1381 |
+
) -> dict:
|
| 1382 |
+
if isinstance(references[0], list):
|
| 1383 |
+
references = [reference[0] for reference in references]
|
| 1384 |
+
references = [to_float_or_default(r) for r in references]
|
| 1385 |
+
predictions = [to_float_or_default(p) for p in predictions]
|
| 1386 |
+
|
| 1387 |
+
kendall_results = self.kendalltau(references, predictions, variant=self.variant)
|
| 1388 |
+
corr = kendall_results.correlation
|
| 1389 |
+
return {
|
| 1390 |
+
self.main_score: corr,
|
| 1391 |
+
f"{self.main_score}_p_val": kendall_results.pvalue,
|
| 1392 |
+
}
|
| 1393 |
+
|
| 1394 |
+
|
| 1395 |
class MatthewsCorrelation(HuggingfaceMetric):
|
| 1396 |
hf_metric_name = "matthews_correlation"
|
| 1397 |
main_score = "matthews_correlation"
|
|
|
|
| 1407 |
self,
|
| 1408 |
references: List[List[str]],
|
| 1409 |
predictions: List[str],
|
| 1410 |
+
task_data: List[Dict],
|
| 1411 |
) -> dict:
|
| 1412 |
formatted_references = [
|
| 1413 |
self.get_str_id(reference[0]) for reference in references
|
|
|
|
| 1420 |
)
|
| 1421 |
|
| 1422 |
|
| 1423 |
+
class RocAuc(GlobalMetric):
|
| 1424 |
+
main_score = "roc_auc"
|
| 1425 |
+
process_single_instances = False
|
| 1426 |
+
_requirements_list: List[str] = ["sklearn"]
|
| 1427 |
+
|
| 1428 |
+
def prepare(self):
|
| 1429 |
+
from sklearn import metrics
|
| 1430 |
+
|
| 1431 |
+
self.roc_curve = metrics.roc_curve
|
| 1432 |
+
self.auc = metrics.auc
|
| 1433 |
+
|
| 1434 |
+
def compute(
|
| 1435 |
+
self,
|
| 1436 |
+
references: List[List[str]],
|
| 1437 |
+
predictions: List[str],
|
| 1438 |
+
task_data: List[Dict],
|
| 1439 |
+
) -> dict:
|
| 1440 |
+
if isinstance(references[0], list):
|
| 1441 |
+
references = [reference[0] for reference in references]
|
| 1442 |
+
references = [to_float_or_default(r) for r in references]
|
| 1443 |
+
predictions = [to_float_or_default(p) for p in predictions]
|
| 1444 |
+
|
| 1445 |
+
fpr, tpr, thrs = self.roc_curve(y_true=references, y_score=predictions)
|
| 1446 |
+
roc_auc = self.auc(fpr, tpr)
|
| 1447 |
+
return {self.main_score: roc_auc}
|
| 1448 |
+
|
| 1449 |
+
|
| 1450 |
class CustomF1(GlobalMetric):
|
| 1451 |
main_score = "f1_micro"
|
| 1452 |
groups = None
|
|
|
|
| 1500 |
except ZeroDivisionError:
|
| 1501 |
return self.zero_division
|
| 1502 |
|
| 1503 |
+
def get_groups(self, elements, task_data):
|
| 1504 |
groups = set()
|
| 1505 |
+
for sublist, additional_input in zip(elements, task_data):
|
| 1506 |
for e in sublist:
|
| 1507 |
if self.should_ignore_element(e, additional_input):
|
| 1508 |
continue
|
|
|
|
| 1513 |
self,
|
| 1514 |
references: List[List[Any]],
|
| 1515 |
predictions: List[Any],
|
| 1516 |
+
task_data: List[Dict],
|
| 1517 |
) -> dict:
|
| 1518 |
# in case reference are List[List[List[Any]]] and predictions are List[List[Any]]:
|
| 1519 |
if (
|
|
|
|
| 1529 |
)
|
| 1530 |
|
| 1531 |
if self.groups is None:
|
| 1532 |
+
groups = self.get_groups(references, task_data)
|
| 1533 |
else:
|
| 1534 |
groups = self.groups
|
| 1535 |
groups_statistics = {}
|
| 1536 |
for references_batch, predictions_batch, additional_input in zip(
|
| 1537 |
+
references, predictions, task_data
|
| 1538 |
):
|
| 1539 |
grouped_references = self.group_elements(references_batch, additional_input)
|
| 1540 |
grouped_predictions = self.group_elements(
|
|
|
|
| 1651 |
ci_scores = ["f1", "precision", "recall"]
|
| 1652 |
|
| 1653 |
def compute(
|
| 1654 |
+
self, references: List[Any], prediction: Any, task_data: List[Dict]
|
| 1655 |
) -> dict:
|
| 1656 |
results = [
|
| 1657 |
+
self._compute_single_ref(str(reference), str(prediction))
|
| 1658 |
+
for reference in references
|
| 1659 |
]
|
| 1660 |
return {
|
| 1661 |
measure: max(r[i] for r in results)
|
|
|
|
| 1665 |
def _compute_single_ref(
|
| 1666 |
self, reference: Any, prediction: Any
|
| 1667 |
) -> Tuple[float, float, float]:
|
| 1668 |
+
prediction_tokens = normalize_answer(str(prediction)).split()
|
| 1669 |
+
reference_tokens = normalize_answer(str(reference)).split()
|
| 1670 |
common = Counter(prediction_tokens) & Counter(reference_tokens)
|
| 1671 |
num_same = sum(common.values())
|
| 1672 |
if num_same == 0:
|
|
|
|
| 1686 |
ci_scores = ["f1", "precision", "recall"]
|
| 1687 |
model_name: str
|
| 1688 |
|
| 1689 |
+
_requirements_list: List[str] = ["bert_score"]
|
| 1690 |
+
|
| 1691 |
def prepare(self):
|
| 1692 |
super().prepare()
|
| 1693 |
+
self.hf_compute_args = {"model_type": self.model_name, "batch_size": 16}
|
| 1694 |
|
| 1695 |
|
| 1696 |
class SentenceBert(BulkInstanceMetric):
|
|
|
|
| 1700 |
|
| 1701 |
model_name: str
|
| 1702 |
|
| 1703 |
+
_requirements_list: List[str] = ["sentence_transformers"]
|
| 1704 |
+
|
| 1705 |
def prepare(self):
|
| 1706 |
super().prepare()
|
| 1707 |
+
import torch
|
| 1708 |
from sentence_transformers import SentenceTransformer
|
| 1709 |
from sentence_transformers import util as sbert_util
|
| 1710 |
|
| 1711 |
+
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 1712 |
+
self.model = SentenceTransformer(self.model_name, device=self.device)
|
| 1713 |
self.util = sbert_util
|
| 1714 |
|
| 1715 |
def compute(
|
| 1716 |
self,
|
| 1717 |
references: List[List[Any]],
|
| 1718 |
predictions: List[Any],
|
| 1719 |
+
task_data: List[Dict],
|
| 1720 |
) -> List[Dict[str, Any]]:
|
| 1721 |
scores = []
|
| 1722 |
|
|
|
|
| 1731 |
count += len(ref_group)
|
| 1732 |
|
| 1733 |
# compute s-bert embeddings
|
| 1734 |
+
preds_emb = self.model.encode(predictions, device=self.device)
|
| 1735 |
refs_emb = self.model.encode(
|
| 1736 |
+
[ref for ref_group in references for ref in ref_group], device=self.device
|
| 1737 |
)
|
| 1738 |
|
| 1739 |
# for each candidate, pick the reference with the highest score
|
|
|
|
| 1751 |
|
| 1752 |
model_name: str
|
| 1753 |
|
| 1754 |
+
_requirements_list: List[str] = ["transformers"]
|
| 1755 |
+
|
| 1756 |
def prepare(self):
|
| 1757 |
super().prepare()
|
| 1758 |
+
import torch
|
| 1759 |
from transformers import pipeline
|
| 1760 |
|
| 1761 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 1762 |
+
self.pipe = pipeline(
|
| 1763 |
+
"text-classification", model=self.model_name, device=device
|
| 1764 |
+
)
|
| 1765 |
|
| 1766 |
def compute(
|
| 1767 |
self,
|
| 1768 |
references: List[List[Any]],
|
| 1769 |
predictions: List[Any],
|
| 1770 |
+
task_data: List[Dict],
|
| 1771 |
) -> List[Dict[str, Any]]:
|
| 1772 |
# treat the references as the questions and the predictions as answers
|
| 1773 |
# assume a single reference
|
|
|
|
| 1793 |
batch_size: int = 32
|
| 1794 |
model_name: str
|
| 1795 |
|
| 1796 |
+
_requirements_list: List[str] = ["transformers"]
|
| 1797 |
+
|
| 1798 |
def compute(
|
| 1799 |
self,
|
| 1800 |
references: List[List[Any]],
|
| 1801 |
predictions: List[Any],
|
| 1802 |
+
task_data: List[Dict],
|
| 1803 |
) -> List[Dict[str, Any]]:
|
| 1804 |
"""Computes the likelihood of generating text Y after text X - P(Y|X).
|
| 1805 |
|
| 1806 |
+
:param predictions: the list of Y texts = the targets of the generation
|
| 1807 |
+
:param references: the list of list of X texts = the sources of the generation
|
| 1808 |
|
| 1809 |
+
:return: the likelihood of generating text Y_i after each text X_i_j = P(Y_i|X_i_1), ..., P(Y_i|X_i_n) for every i.
|
| 1810 |
"""
|
| 1811 |
sources = []
|
| 1812 |
targets = []
|
| 1813 |
for prediction, instance_references in zip(predictions, references):
|
| 1814 |
for instance_reference in instance_references:
|
| 1815 |
+
sources.append(f"{self.perplexity_prompt} {instance_reference}")
|
| 1816 |
+
targets.append(prediction)
|
| 1817 |
|
| 1818 |
from transformers import AutoConfig
|
| 1819 |
|
|
|
|
| 1854 |
from transformers import AutoTokenizer
|
| 1855 |
|
| 1856 |
self.model_name = model_name
|
| 1857 |
+
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 1858 |
+
self.model = (
|
| 1859 |
+
self.model_class().from_pretrained(self.model_name).to(self.device)
|
| 1860 |
+
)
|
| 1861 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
|
|
|
|
|
|
| 1862 |
|
| 1863 |
def compute_lm(
|
| 1864 |
self, source: List[str], target: List[str], batch_size: int
|
|
|
|
| 1951 |
return AutoModelForSeq2SeqLM
|
| 1952 |
|
| 1953 |
def compute_batch(self, tokens_source, tokens_target):
|
| 1954 |
+
tokens_docs_ids = tokens_source["input_ids"].to(self.device)
|
| 1955 |
+
attention = tokens_source["attention_mask"].to(self.device)
|
| 1956 |
+
labels = tokens_target["input_ids"].to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1957 |
|
| 1958 |
logits = self.model(
|
| 1959 |
input_ids=tokens_docs_ids.long(),
|
|
|
|
| 1993 |
# replace the padding token in the labels by -100
|
| 1994 |
labels[labels == self.tokenizer.pad_token_id] = -100
|
| 1995 |
|
| 1996 |
+
tokens = tokens.to(self.device)
|
| 1997 |
+
attention = attention.to(self.device)
|
| 1998 |
+
labels = labels.to(self.device)
|
|
|
|
|
|
|
|
|
|
| 1999 |
|
| 2000 |
# no need to pass labels as we calculate the loss below per document
|
| 2001 |
model_output = self.model(
|
|
|
|
| 2029 |
|
| 2030 |
main_score = "nDCG"
|
| 2031 |
|
| 2032 |
+
_requirements_list: List[str] = ["sklearn"]
|
| 2033 |
+
|
| 2034 |
def prepare(self):
|
| 2035 |
from sklearn.metrics import ndcg_score
|
| 2036 |
|
|
|
|
| 2041 |
self,
|
| 2042 |
references: List[List[Any]],
|
| 2043 |
predictions: List[Any],
|
| 2044 |
+
task_data: List[Any],
|
| 2045 |
) -> dict:
|
| 2046 |
from collections import defaultdict
|
|
|
|
| 2047 |
|
| 2048 |
query_to_predictions_and_references = defaultdict(lambda: [[], []])
|
| 2049 |
+
for reference, pred, inputs_dict in zip(references, predictions, task_data):
|
|
|
|
|
|
|
| 2050 |
query = inputs_dict.get("query")
|
| 2051 |
query_to_predictions_and_references[query][0].append(pred)
|
| 2052 |
query_to_predictions_and_references[query][1].append(reference)
|
|
|
|
| 2076 |
|
| 2077 |
|
| 2078 |
class RetrievalMetric(InstanceMetric):
|
| 2079 |
+
def compute(self, references: List[Any], prediction: Any, task_data: Dict) -> dict:
|
|
|
|
|
|
|
| 2080 |
# digest input
|
| 2081 |
pred_ids: List[Any] = prediction
|
| 2082 |
ref_ids: List[Any] = list(dict.fromkeys(references))
|
|
|
|
| 2149 |
class MRR(RetrievalMetric):
|
| 2150 |
reduction_map = {"mean": ["mrr"]}
|
| 2151 |
main_score = "mrr"
|
| 2152 |
+
ci_scores = ["mrr"]
|
| 2153 |
|
| 2154 |
def _compute(
|
| 2155 |
self,
|
|
|
|
| 2166 |
class MAP(RetrievalMetric):
|
| 2167 |
reduction_map = {"mean": ["map"]}
|
| 2168 |
main_score = "map"
|
| 2169 |
+
ci_scores = ["map"]
|
| 2170 |
|
| 2171 |
def _compute(
|
| 2172 |
self,
|
|
|
|
| 2235 |
|
| 2236 |
def should_ignore_element(self, element, additional_input):
|
| 2237 |
return element == "none"
|
| 2238 |
+
|
| 2239 |
+
|
| 2240 |
+
class RemoteMetric(SingleStreamOperator, Metric):
|
| 2241 |
+
"""A metric that runs another metric remotely.
|
| 2242 |
+
|
| 2243 |
+
main_score: the score updated by this metric.
|
| 2244 |
+
endpoint: the remote host that supports the remote metric execution.
|
| 2245 |
+
metric_name: the name of the metric that is executed remotely.
|
| 2246 |
+
api_key: optional, passed to the remote metric with the input, allows secure authentication.
|
| 2247 |
+
"""
|
| 2248 |
+
|
| 2249 |
+
main_score: str = None
|
| 2250 |
+
endpoint: str
|
| 2251 |
+
metric_name: str
|
| 2252 |
+
api_key: str = None
|
| 2253 |
+
|
| 2254 |
+
@staticmethod
|
| 2255 |
+
def wrap_inner_metric_pipeline_metric(
|
| 2256 |
+
metric_pipeline: MetricPipeline, remote_metrics_endpoint: str
|
| 2257 |
+
) -> MetricPipeline:
|
| 2258 |
+
"""Wrap the inner metric in a MetricPipeline with a RemoteMetric.
|
| 2259 |
+
|
| 2260 |
+
When executing the returned MetricPipeline, the inner metric will be computed
|
| 2261 |
+
remotely (pre and post processing steps in the MetricPipeline will be computed locally).
|
| 2262 |
+
"""
|
| 2263 |
+
local_inner_metric = metric_pipeline.metric
|
| 2264 |
+
metric_pipeline = deepcopy(
|
| 2265 |
+
metric_pipeline
|
| 2266 |
+
) # To avoid unintentional changes to the catalog contents
|
| 2267 |
+
metric_pipeline.metric = RemoteMetric(
|
| 2268 |
+
main_score=local_inner_metric.main_score,
|
| 2269 |
+
metric_name=local_inner_metric.artifact_identifier,
|
| 2270 |
+
endpoint=remote_metrics_endpoint,
|
| 2271 |
+
)
|
| 2272 |
+
return metric_pipeline
|
| 2273 |
+
|
| 2274 |
+
def get_metric_url(self) -> str:
|
| 2275 |
+
return f"{self.endpoint}/{self.metric_name}"
|
| 2276 |
+
|
| 2277 |
+
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
| 2278 |
+
predictions, references, additional_inputs, instances = self.consume_stream(
|
| 2279 |
+
stream
|
| 2280 |
+
)
|
| 2281 |
+
metric_request = self.create_metric_request(
|
| 2282 |
+
predictions, references, additional_inputs
|
| 2283 |
+
)
|
| 2284 |
+
metric_response = self.get_metric_response(metric_request)
|
| 2285 |
+
self.update_instance_scores(instances, metric_response.instances_scores)
|
| 2286 |
+
self.set_global_score(instances, metric_response.global_score)
|
| 2287 |
+
yield from instances
|
| 2288 |
+
|
| 2289 |
+
@staticmethod
|
| 2290 |
+
def create_metric_request(predictions, references, additional_inputs):
|
| 2291 |
+
instance_inputs = [
|
| 2292 |
+
InstanceInput(
|
| 2293 |
+
prediction=prediction,
|
| 2294 |
+
references=reference,
|
| 2295 |
+
additional_inputs=additional_input,
|
| 2296 |
+
)
|
| 2297 |
+
for prediction, reference, additional_input in zip(
|
| 2298 |
+
predictions, references, additional_inputs
|
| 2299 |
+
)
|
| 2300 |
+
]
|
| 2301 |
+
return MetricRequest(instance_inputs=instance_inputs)
|
| 2302 |
+
|
| 2303 |
+
def get_metric_response(self, metric_request: MetricRequest) -> MetricResponse:
|
| 2304 |
+
import requests
|
| 2305 |
+
|
| 2306 |
+
response = requests.post(
|
| 2307 |
+
url=self.get_metric_url(),
|
| 2308 |
+
json=metric_request.to_dict(),
|
| 2309 |
+
headers={"Authorization": f"Bearer {self.api_key}"},
|
| 2310 |
+
)
|
| 2311 |
+
response.raise_for_status()
|
| 2312 |
+
response_json = response.json()
|
| 2313 |
+
return MetricResponse(**response_json)
|
| 2314 |
+
|
| 2315 |
+
def disable_confidence_interval_calculation(self):
|
| 2316 |
+
"""Confidence intervals are always disabled for RemoteMetric.
|
| 2317 |
+
|
| 2318 |
+
No need to do anything.
|
| 2319 |
+
"""
|
| 2320 |
+
pass
|
| 2321 |
+
|
| 2322 |
+
def set_n_resamples(self, n_resample):
|
| 2323 |
+
"""Since confidence intervals are always disabled for remote metrics, this is a no-op."""
|
| 2324 |
+
pass
|
| 2325 |
+
|
| 2326 |
+
|
| 2327 |
+
def validate_subgroup_types(
|
| 2328 |
+
subgroup_scores_dict: Dict[str, List],
|
| 2329 |
+
control_subgroup_types: List[str],
|
| 2330 |
+
comparison_subgroup_types: List[str],
|
| 2331 |
+
):
|
| 2332 |
+
"""Validate a dict of subgroup type instance score lists, and subgroup type lists.
|
| 2333 |
+
|
| 2334 |
+
Args:
|
| 2335 |
+
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
| 2336 |
+
control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
|
| 2337 |
+
comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
|
| 2338 |
+
to be compared to the control group.
|
| 2339 |
+
|
| 2340 |
+
Returns:
|
| 2341 |
+
dict with all NaN scores removed; control_subgroup_types and comparison_subgroup_types will have non-unique elements removed
|
| 2342 |
+
"""
|
| 2343 |
+
# note: subgroup_scores_dict is already a defaultdict of lists, so don't need to check that keys in control_ and comparison_subgroup_types exist in it
|
| 2344 |
+
# remove any NaNs
|
| 2345 |
+
subgroup_scores_dict.update(
|
| 2346 |
+
{
|
| 2347 |
+
subgroup_name: [score for score in score_list if not np.isnan(score)]
|
| 2348 |
+
for subgroup_name, score_list in subgroup_scores_dict.items()
|
| 2349 |
+
}
|
| 2350 |
+
)
|
| 2351 |
+
assert isinstance(
|
| 2352 |
+
control_subgroup_types, list
|
| 2353 |
+
), "control_subgroup_types must be a list"
|
| 2354 |
+
assert isinstance(
|
| 2355 |
+
comparison_subgroup_types, list
|
| 2356 |
+
), "comparison_subgroup_types must be a list"
|
| 2357 |
+
# make sure each list is unique, so that labels aren't double-counted
|
| 2358 |
+
control_subgroup_types = list(set(control_subgroup_types))
|
| 2359 |
+
comparison_subgroup_types = list(set(comparison_subgroup_types))
|
| 2360 |
+
|
| 2361 |
+
return subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
|
| 2362 |
+
|
| 2363 |
+
|
| 2364 |
+
def performance_drop_rate(
|
| 2365 |
+
subgroup_scores_dict: Dict[str, List],
|
| 2366 |
+
control_subgroup_types: List[str],
|
| 2367 |
+
comparison_subgroup_types: List[str],
|
| 2368 |
+
):
|
| 2369 |
+
"""Percentage decrease of mean performance on test elements relative to that on a baseline (control).
|
| 2370 |
+
|
| 2371 |
+
from https://arxiv.org/pdf/2306.04528.pdf.
|
| 2372 |
+
|
| 2373 |
+
Args:
|
| 2374 |
+
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
| 2375 |
+
control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
|
| 2376 |
+
comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
|
| 2377 |
+
to be compared to the control group.
|
| 2378 |
+
|
| 2379 |
+
Returns:
|
| 2380 |
+
numeric PDR metric.
|
| 2381 |
+
If only one element (no test set) or the first is 0 (percentage change is undefined) return NaN
|
| 2382 |
+
otherwise, calculate PDR
|
| 2383 |
+
"""
|
| 2384 |
+
(
|
| 2385 |
+
subgroup_scores_dict,
|
| 2386 |
+
control_subgroup_types,
|
| 2387 |
+
comparison_subgroup_types,
|
| 2388 |
+
) = validate_subgroup_types(
|
| 2389 |
+
subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
|
| 2390 |
+
)
|
| 2391 |
+
|
| 2392 |
+
# combine all scores from each label (if there are more than 1 in each group) into a list
|
| 2393 |
+
group_scores_list = [
|
| 2394 |
+
np.concatenate(
|
| 2395 |
+
[subgroup_scores_dict[subgroup_name] for subgroup_name in name_list]
|
| 2396 |
+
)
|
| 2397 |
+
for name_list in [control_subgroup_types, comparison_subgroup_types]
|
| 2398 |
+
]
|
| 2399 |
+
if any(len(scores) == 0 for scores in group_scores_list):
|
| 2400 |
+
# no comparison can be made since there is not at least one score per type
|
| 2401 |
+
return np.nan
|
| 2402 |
+
control_mean = mean(group_scores_list[0])
|
| 2403 |
+
comparison_mean = mean(group_scores_list[1])
|
| 2404 |
+
if control_mean == 0:
|
| 2405 |
+
# return 0 if comparison is also 0
|
| 2406 |
+
if comparison_mean == 0:
|
| 2407 |
+
return 0
|
| 2408 |
+
return np.nan
|
| 2409 |
+
# otherwise, take the percentage change (which may also be 0)
|
| 2410 |
+
return 1 - comparison_mean / control_mean
|
| 2411 |
+
|
| 2412 |
+
|
| 2413 |
+
def interpret_effect_size(x: float):
|
| 2414 |
+
"""Return a string rule-of-thumb interpretation of an effect size value, as defined by Cohen/Sawilowsky.
|
| 2415 |
+
|
| 2416 |
+
See https://en.wikipedia.org/wiki/Effect_size;
|
| 2417 |
+
Cohen, Jacob (1988). Statistical Power Analysis for the Behavioral Sciences; and
|
| 2418 |
+
Sawilowsky, S (2009). "New effect size rules of thumb". Journal of Modern Applied Statistical Methods. 8 (2): 467-474.
|
| 2419 |
+
|
| 2420 |
+
Value has interpretation of
|
| 2421 |
+
- essentially 0 if |x| < 0.01
|
| 2422 |
+
- very small if 0.01 <= |x| < 0.2
|
| 2423 |
+
- small difference if 0.2 <= |x| < 0.5
|
| 2424 |
+
- a medium difference if 0.5 <= |x| < 0.8
|
| 2425 |
+
- a large difference if 0.8 <= |x| < 1.2
|
| 2426 |
+
- a very large difference if 1.2 <= |x| < 2.0
|
| 2427 |
+
- a huge difference if 2.0 <= |x|
|
| 2428 |
+
|
| 2429 |
+
Args:
|
| 2430 |
+
x: float effect size value
|
| 2431 |
+
|
| 2432 |
+
Returns:
|
| 2433 |
+
string interpretation
|
| 2434 |
+
"""
|
| 2435 |
+
import pandas as pd
|
| 2436 |
+
|
| 2437 |
+
# assign a label according to threshold of the absolute value
|
| 2438 |
+
return pd.cut(
|
| 2439 |
+
x=[np.abs(x)],
|
| 2440 |
+
right=False,
|
| 2441 |
+
bins=[-1, 0.01, 0.2, 0.5, 0.8, 1.2, 2.0, np.Inf],
|
| 2442 |
+
labels=[
|
| 2443 |
+
"essentially zero",
|
| 2444 |
+
"very small",
|
| 2445 |
+
"small",
|
| 2446 |
+
"medium",
|
| 2447 |
+
"large",
|
| 2448 |
+
"very large",
|
| 2449 |
+
"huge",
|
| 2450 |
+
],
|
| 2451 |
+
)[0]
|
| 2452 |
+
|
| 2453 |
+
|
| 2454 |
+
def normalized_cohens_h(
|
| 2455 |
+
subgroup_scores_dict: Dict[str, List],
|
| 2456 |
+
control_subgroup_types: List[str],
|
| 2457 |
+
comparison_subgroup_types: List[str],
|
| 2458 |
+
interpret=False,
|
| 2459 |
+
):
|
| 2460 |
+
"""Cohen's h effect size between two proportions, normalized to interval [-1,1].
|
| 2461 |
+
|
| 2462 |
+
Allows for change-type metric when the baseline is 0 (percentage change, and thus PDR, is undefined)
|
| 2463 |
+
https://en.wikipedia.org/wiki/Cohen%27s_h
|
| 2464 |
+
|
| 2465 |
+
Cohen's h effect size metric between two proportions p2 and p1 is 2 * (arcsin(sqrt(p2)) - arcsin(sqrt(p1))).
|
| 2466 |
+
h in -pi, pi, with +/-pi representing the largest increase/decrease (p1=0, p2=1), or (p1=1, p2=0).
|
| 2467 |
+
h=0 is no change. Unlike percentage change, h is defined even if the baseline (p1) is 0.
|
| 2468 |
+
Assumes the scores are in [0,1], either continuous or binary; hence taking the average of a group of scores yields a proportion..
|
| 2469 |
+
Calculates the change in the average of the other_scores relative to the average of the baseline_scores. We rescale this to [-1,1] from [-pi,pi] for clarity, where +- 1 are the most extreme changes, and 0 is no change
|
| 2470 |
+
|
| 2471 |
+
Interpretation: the original unscaled Cohen's h can be interpreted according to function interpret_effect_size
|
| 2472 |
+
|
| 2473 |
+
Thus, the rule of interpreting the effect of the normalized value is to use the same thresholds divided by pi
|
| 2474 |
+
- essentially 0 if |norm h| < 0.0031831
|
| 2475 |
+
- very small if 0.0031831 <= |norm h| < 0.06366198
|
| 2476 |
+
- small difference if 0.06366198 <= |norm h| < 0.15915494
|
| 2477 |
+
- a medium difference if 0.15915494 <= |norm h| < 0.25464791
|
| 2478 |
+
- a large difference if 0.25464791 <= |norm h| < 0.38197186
|
| 2479 |
+
- a very large difference if 0.38197186 <= |norm h| < 0.63661977
|
| 2480 |
+
- a huge difference if 0.63661977 <= |norm h|
|
| 2481 |
+
Args:
|
| 2482 |
+
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
| 2483 |
+
control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
|
| 2484 |
+
comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
|
| 2485 |
+
to be compared to the control group.
|
| 2486 |
+
interpret: boolean, whether to interpret the significance of the score or not
|
| 2487 |
+
Returns:
|
| 2488 |
+
float score between -1 and 1, and a string interpretation if interpret=True
|
| 2489 |
+
"""
|
| 2490 |
+
(
|
| 2491 |
+
subgroup_scores_dict,
|
| 2492 |
+
control_subgroup_types,
|
| 2493 |
+
comparison_subgroup_types,
|
| 2494 |
+
) = validate_subgroup_types(
|
| 2495 |
+
subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
|
| 2496 |
+
)
|
| 2497 |
+
|
| 2498 |
+
# requires scores to be in [0,1]
|
| 2499 |
+
for subgroup_name, score_list in subgroup_scores_dict.items():
|
| 2500 |
+
assert all(
|
| 2501 |
+
0 <= score <= 1 for score in score_list
|
| 2502 |
+
), f"all {subgroup_name} scores must be in [0,1]"
|
| 2503 |
+
|
| 2504 |
+
# combine all scores from each label (if there are more than 1 in each group) into a list
|
| 2505 |
+
group_scores_list = [
|
| 2506 |
+
np.concatenate(
|
| 2507 |
+
[subgroup_scores_dict[subgroup_name] for subgroup_name in name_list]
|
| 2508 |
+
)
|
| 2509 |
+
for name_list in [control_subgroup_types, comparison_subgroup_types]
|
| 2510 |
+
]
|
| 2511 |
+
|
| 2512 |
+
if any(len(scores) == 0 for scores in group_scores_list):
|
| 2513 |
+
# no comparison can be made since there is not at least one score per type
|
| 2514 |
+
h, norm_h = np.nan, np.nan
|
| 2515 |
+
else:
|
| 2516 |
+
control_mean = mean(group_scores_list[0])
|
| 2517 |
+
comparison_mean = mean(group_scores_list[1])
|
| 2518 |
+
h = 2 * (np.arcsin(np.sqrt(comparison_mean)) - np.arcsin(np.sqrt(control_mean)))
|
| 2519 |
+
norm_h = np.clip(a=h / np.pi, a_min=-1, a_max=1)
|
| 2520 |
+
|
| 2521 |
+
if not interpret:
|
| 2522 |
+
return norm_h
|
| 2523 |
+
|
| 2524 |
+
return norm_h, interpret_effect_size(h)
|
| 2525 |
+
|
| 2526 |
+
|
| 2527 |
+
def normalized_hedges_g(
|
| 2528 |
+
subgroup_scores_dict: Dict[str, List[float]],
|
| 2529 |
+
control_subgroup_types: List[str],
|
| 2530 |
+
comparison_subgroup_types: List[str],
|
| 2531 |
+
interpret=False,
|
| 2532 |
+
):
|
| 2533 |
+
"""Hedge's g effect size between mean of two samples, normalized to interval [-1,1]. Better than Cohen's d for small sample sizes.
|
| 2534 |
+
|
| 2535 |
+
Takes into account the variances within the samples, not just the means.
|
| 2536 |
+
|
| 2537 |
+
Args:
|
| 2538 |
+
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
| 2539 |
+
control_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the control (baseline) group
|
| 2540 |
+
comparison_subgroup_types: list of subgroup types (potential keys of subgroup_scores_dict) that are the group
|
| 2541 |
+
to be compared to the control group.
|
| 2542 |
+
interpret: boolean, whether to interpret the significance of the score or not
|
| 2543 |
+
Returns:
|
| 2544 |
+
float score between -1 and 1, and a string interpretation if interpret=True
|
| 2545 |
+
"""
|
| 2546 |
+
(
|
| 2547 |
+
subgroup_scores_dict,
|
| 2548 |
+
control_subgroup_types,
|
| 2549 |
+
comparison_subgroup_types,
|
| 2550 |
+
) = validate_subgroup_types(
|
| 2551 |
+
subgroup_scores_dict, control_subgroup_types, comparison_subgroup_types
|
| 2552 |
+
)
|
| 2553 |
+
|
| 2554 |
+
# combine all scores from each label (if there are more than 1 in each group) into a list
|
| 2555 |
+
group_scores_list = [
|
| 2556 |
+
np.concatenate(
|
| 2557 |
+
[subgroup_scores_dict[subgroup_name] for subgroup_name in name_list]
|
| 2558 |
+
)
|
| 2559 |
+
for name_list in [control_subgroup_types, comparison_subgroup_types]
|
| 2560 |
+
]
|
| 2561 |
+
|
| 2562 |
+
group_n = [len(scores) for scores in group_scores_list]
|
| 2563 |
+
if any(nn == 0 for nn in group_n) or all(nn <= 1 for nn in group_n):
|
| 2564 |
+
# if at least one sample size is 0 for one type, no comparison can be made at all
|
| 2565 |
+
# if both sample sizes are 1, then the denominator is undefined since divide by n1 + n2 - 2
|
| 2566 |
+
# so require at least one sample to have > 1 observation, and both to have >= 1.
|
| 2567 |
+
g, norm_g = np.nan, np.nan
|
| 2568 |
+
else:
|
| 2569 |
+
# otherwise, calculate the variances
|
| 2570 |
+
group_mean = [mean(scores) for scores in group_scores_list]
|
| 2571 |
+
# sample variance with 1 degree of freedom (denominator n-1); if n=1, return 0 since otherwise throws an error
|
| 2572 |
+
group_var = [
|
| 2573 |
+
0.0 if nn == 1 else np.var(scores, ddof=1)
|
| 2574 |
+
for scores, nn in zip(group_scores_list, group_n)
|
| 2575 |
+
]
|
| 2576 |
+
var_total = sum([(nn - 1) * vv for vv, nn in zip(group_var, group_n)])
|
| 2577 |
+
pooled_sd = np.sqrt(var_total / (sum(group_n) - 2))
|
| 2578 |
+
|
| 2579 |
+
max_absolute_value = 5
|
| 2580 |
+
gmd = float(group_mean[1] - group_mean[0])
|
| 2581 |
+
|
| 2582 |
+
if gmd == 0:
|
| 2583 |
+
# if exactly the same, return 0
|
| 2584 |
+
g = 0.0
|
| 2585 |
+
else:
|
| 2586 |
+
try:
|
| 2587 |
+
g = gmd / pooled_sd
|
| 2588 |
+
except ZeroDivisionError:
|
| 2589 |
+
# return a large effect size to avoid explosion if there is zero variance
|
| 2590 |
+
g = np.sign(gmd) * max_absolute_value
|
| 2591 |
+
|
| 2592 |
+
n = sum(group_n)
|
| 2593 |
+
if 3 < n < 50:
|
| 2594 |
+
# small sample adjustment see https://www.itl.nist.gov/div898/software/dataplot/refman2/auxillar/hedgeg.htm
|
| 2595 |
+
# the multiplier is 0 if n <= 3
|
| 2596 |
+
g *= ((n - 3) / (n - 2.25)) * np.sqrt((n - 2) / n)
|
| 2597 |
+
# clip it at a very large value so it doesn't become infinite if the variance (denominator) is very small or 0
|
| 2598 |
+
g = float(np.clip(a=g, a_min=-1 * max_absolute_value, a_max=max_absolute_value))
|
| 2599 |
+
norm_g = g / max_absolute_value
|
| 2600 |
+
|
| 2601 |
+
if not interpret:
|
| 2602 |
+
return norm_g
|
| 2603 |
+
return norm_g, interpret_effect_size(g)
|
| 2604 |
+
|
| 2605 |
+
|
| 2606 |
+
def mean_subgroup_score(
|
| 2607 |
+
subgroup_scores_dict: Dict[str, List], subgroup_types: List[str]
|
| 2608 |
+
):
|
| 2609 |
+
"""Return the mean instance score for a subset (possibly a single type) of variants (not a comparison).
|
| 2610 |
+
|
| 2611 |
+
Args:
|
| 2612 |
+
subgroup_scores_dict: dict where keys are subgroup types and values are lists of instance scores.
|
| 2613 |
+
subgroup_types: the keys (subgroup types) for which the average will be computed.
|
| 2614 |
+
|
| 2615 |
+
Returns:
|
| 2616 |
+
float score
|
| 2617 |
+
"""
|
| 2618 |
+
subgroup_scores_dict, subgroup_types, _ = validate_subgroup_types(
|
| 2619 |
+
subgroup_scores_dict, subgroup_types, []
|
| 2620 |
+
)
|
| 2621 |
+
|
| 2622 |
+
# combine all desired subgroup scores
|
| 2623 |
+
score_list = np.concatenate(
|
| 2624 |
+
[subgroup_scores_dict[subgroup_name] for subgroup_name in subgroup_types]
|
| 2625 |
+
)
|
| 2626 |
+
if len(score_list) == 0:
|
| 2627 |
+
# no scores to use
|
| 2628 |
+
return np.nan
|
| 2629 |
+
return mean(score_list)
|
| 2630 |
+
|
| 2631 |
+
|
| 2632 |
+
# metrics using mean reduction
|
| 2633 |
+
class GroupMeanAccuracy(Accuracy):
|
| 2634 |
+
reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, False]}}
|
| 2635 |
+
|
| 2636 |
+
|
| 2637 |
+
class FixedGroupMeanAccuracy(Accuracy):
|
| 2638 |
+
# the same as GroupMeanAccuracy, except the groups are fixed and are resampled together
|
| 2639 |
+
reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, True]}}
|
| 2640 |
+
|
| 2641 |
+
|
| 2642 |
+
# same as above, now using StringContainment
|
| 2643 |
+
class GroupMeanStringContainment(StringContainment):
|
| 2644 |
+
reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, False]}}
|
| 2645 |
+
|
| 2646 |
+
|
| 2647 |
+
class FixedGroupMeanStringContainment(StringContainment):
|
| 2648 |
+
# the same as GroupMeanStringContainment, except the groups are fixed and are resampled together
|
| 2649 |
+
reduction_map = {"group_mean": {"agg_func": ["mean", nan_mean, True]}}
|
| 2650 |
+
|
| 2651 |
+
|
| 2652 |
+
# take only the (fixed) group mean of baseline or other (paraphrases) scores
|
| 2653 |
+
class FixedGroupMeanBaselineAccuracy(Accuracy):
|
| 2654 |
+
subgroup_column = "variant_type"
|
| 2655 |
+
# take mean of "original" variants only
|
| 2656 |
+
reduction_map = {
|
| 2657 |
+
"group_mean": {
|
| 2658 |
+
"agg_func": [
|
| 2659 |
+
"mean_baseline",
|
| 2660 |
+
lambda scd: mean_subgroup_score(
|
| 2661 |
+
subgroup_scores_dict=scd, subgroup_types=["original"]
|
| 2662 |
+
),
|
| 2663 |
+
True,
|
| 2664 |
+
],
|
| 2665 |
+
}
|
| 2666 |
+
}
|
| 2667 |
+
|
| 2668 |
+
|
| 2669 |
+
class FixedGroupMeanParaphraseAccuracy(Accuracy):
|
| 2670 |
+
subgroup_column = "variant_type"
|
| 2671 |
+
# take mean of "paraphrase" variants only
|
| 2672 |
+
reduction_map = {
|
| 2673 |
+
"group_mean": {
|
| 2674 |
+
"agg_func": [
|
| 2675 |
+
"mean_paraphrase",
|
| 2676 |
+
lambda scd: mean_subgroup_score(
|
| 2677 |
+
subgroup_scores_dict=scd, subgroup_types=["paraphrase"]
|
| 2678 |
+
),
|
| 2679 |
+
True,
|
| 2680 |
+
],
|
| 2681 |
+
}
|
| 2682 |
+
}
|
| 2683 |
+
|
| 2684 |
+
|
| 2685 |
+
# same as above but using StringContainment
|
| 2686 |
+
class FixedGroupMeanBaselineStringContainment(StringContainment):
|
| 2687 |
+
subgroup_column = "variant_type"
|
| 2688 |
+
# take mean of "original" variants only
|
| 2689 |
+
reduction_map = {
|
| 2690 |
+
"group_mean": {
|
| 2691 |
+
"agg_func": [
|
| 2692 |
+
"mean_baseline",
|
| 2693 |
+
lambda scd: mean_subgroup_score(
|
| 2694 |
+
subgroup_scores_dict=scd, subgroup_types=["original"]
|
| 2695 |
+
),
|
| 2696 |
+
True,
|
| 2697 |
+
],
|
| 2698 |
+
}
|
| 2699 |
+
}
|
| 2700 |
+
|
| 2701 |
+
|
| 2702 |
+
class FixedGroupMeanParaphraseStringContainment(StringContainment):
|
| 2703 |
+
subgroup_column = "variant_type"
|
| 2704 |
+
# take mean of "paraphrase" variants only
|
| 2705 |
+
reduction_map = {
|
| 2706 |
+
"group_mean": {
|
| 2707 |
+
"agg_func": [
|
| 2708 |
+
"mean_paraphrase",
|
| 2709 |
+
lambda scd: mean_subgroup_score(
|
| 2710 |
+
subgroup_scores_dict=scd, subgroup_types=["paraphrase"]
|
| 2711 |
+
),
|
| 2712 |
+
True,
|
| 2713 |
+
],
|
| 2714 |
+
}
|
| 2715 |
+
}
|
| 2716 |
+
|
| 2717 |
+
|
| 2718 |
+
# using PDR
|
| 2719 |
+
class FixedGroupPDRParaphraseAccuracy(Accuracy):
|
| 2720 |
+
subgroup_column = "variant_type"
|
| 2721 |
+
reduction_map = {
|
| 2722 |
+
"group_mean": {
|
| 2723 |
+
"agg_func": [
|
| 2724 |
+
"pdr_paraphrase",
|
| 2725 |
+
lambda scd: performance_drop_rate(
|
| 2726 |
+
subgroup_scores_dict=scd,
|
| 2727 |
+
control_subgroup_types=["original"],
|
| 2728 |
+
comparison_subgroup_types=["paraphrase"],
|
| 2729 |
+
),
|
| 2730 |
+
True,
|
| 2731 |
+
],
|
| 2732 |
+
}
|
| 2733 |
+
}
|
| 2734 |
+
|
| 2735 |
+
|
| 2736 |
+
class FixedGroupPDRParaphraseStringContainment(StringContainment):
|
| 2737 |
+
subgroup_column = "variant_type"
|
| 2738 |
+
reduction_map = {
|
| 2739 |
+
"group_mean": {
|
| 2740 |
+
"agg_func": [
|
| 2741 |
+
"pdr_paraphrase",
|
| 2742 |
+
lambda scd: performance_drop_rate(
|
| 2743 |
+
subgroup_scores_dict=scd,
|
| 2744 |
+
control_subgroup_types=["original"],
|
| 2745 |
+
comparison_subgroup_types=["paraphrase"],
|
| 2746 |
+
),
|
| 2747 |
+
True,
|
| 2748 |
+
],
|
| 2749 |
+
}
|
| 2750 |
+
}
|
| 2751 |
+
|
| 2752 |
+
|
| 2753 |
+
class GroupMeanTokenOverlap(TokenOverlap):
|
| 2754 |
+
reduction_map = {
|
| 2755 |
+
"group_mean": {
|
| 2756 |
+
"agg_func": ["mean", nan_mean, False],
|
| 2757 |
+
"score_fields": ["f1", "precision", "recall"],
|
| 2758 |
+
}
|
| 2759 |
+
}
|
| 2760 |
+
|
| 2761 |
+
|
| 2762 |
+
# using Cohens's h for proportions
|
| 2763 |
+
class FixedGroupNormCohensHParaphraseAccuracy(Accuracy):
|
| 2764 |
+
subgroup_column = "variant_type"
|
| 2765 |
+
reduction_map = {
|
| 2766 |
+
"group_mean": {
|
| 2767 |
+
"agg_func": [
|
| 2768 |
+
"norm_cohens_h_paraphrase",
|
| 2769 |
+
lambda scd: normalized_cohens_h(
|
| 2770 |
+
subgroup_scores_dict=scd,
|
| 2771 |
+
control_subgroup_types=["original"],
|
| 2772 |
+
comparison_subgroup_types=["paraphrase"],
|
| 2773 |
+
),
|
| 2774 |
+
True,
|
| 2775 |
+
],
|
| 2776 |
+
}
|
| 2777 |
+
}
|
| 2778 |
+
|
| 2779 |
+
|
| 2780 |
+
class FixedGroupNormCohensHParaphraseStringContainment(StringContainment):
|
| 2781 |
+
subgroup_column = "variant_type"
|
| 2782 |
+
reduction_map = {
|
| 2783 |
+
"group_mean": {
|
| 2784 |
+
"agg_func": [
|
| 2785 |
+
"norm_cohens_h_paraphrase",
|
| 2786 |
+
lambda scd: normalized_cohens_h(
|
| 2787 |
+
subgroup_scores_dict=scd,
|
| 2788 |
+
control_subgroup_types=["original"],
|
| 2789 |
+
comparison_subgroup_types=["paraphrase"],
|
| 2790 |
+
),
|
| 2791 |
+
True,
|
| 2792 |
+
],
|
| 2793 |
+
}
|
| 2794 |
+
}
|
| 2795 |
+
|
| 2796 |
+
|
| 2797 |
+
# using Hedges' g (takes into account internal variation in group scores)
|
| 2798 |
+
class FixedGroupNormHedgesGParaphraseAccuracy(Accuracy):
|
| 2799 |
+
subgroup_column = "variant_type"
|
| 2800 |
+
reduction_map = {
|
| 2801 |
+
"group_mean": {
|
| 2802 |
+
"agg_func": [
|
| 2803 |
+
"norm_hedges_g_paraphrase",
|
| 2804 |
+
lambda scd: normalized_hedges_g(
|
| 2805 |
+
subgroup_scores_dict=scd,
|
| 2806 |
+
control_subgroup_types=["original"],
|
| 2807 |
+
comparison_subgroup_types=["paraphrase"],
|
| 2808 |
+
),
|
| 2809 |
+
True,
|
| 2810 |
+
],
|
| 2811 |
+
}
|
| 2812 |
+
}
|
| 2813 |
+
|
| 2814 |
+
|
| 2815 |
+
class FixedGroupNormHedgesGParaphraseStringContainment(StringContainment):
|
| 2816 |
+
subgroup_column = "variant_type"
|
| 2817 |
+
reduction_map = {
|
| 2818 |
+
"group_mean": {
|
| 2819 |
+
"agg_func": [
|
| 2820 |
+
"norm_hedges_g_paraphrase",
|
| 2821 |
+
lambda scd: normalized_hedges_g(
|
| 2822 |
+
subgroup_scores_dict=scd,
|
| 2823 |
+
control_subgroup_types=["original"],
|
| 2824 |
+
comparison_subgroup_types=["paraphrase"],
|
| 2825 |
+
),
|
| 2826 |
+
True,
|
| 2827 |
+
],
|
| 2828 |
+
}
|
| 2829 |
+
}
|
| 2830 |
+
|
| 2831 |
+
|
| 2832 |
+
# for above metrics, take absolute value of group score first; this measures variation in either direction
|
| 2833 |
+
class FixedGroupAbsvalNormCohensHParaphraseAccuracy(Accuracy):
|
| 2834 |
+
subgroup_column = "variant_type"
|
| 2835 |
+
reduction_map = {
|
| 2836 |
+
"group_mean": {
|
| 2837 |
+
"agg_func": [
|
| 2838 |
+
"absval_norm_cohens_h_paraphrase",
|
| 2839 |
+
lambda scd: np.abs(
|
| 2840 |
+
normalized_cohens_h(
|
| 2841 |
+
subgroup_scores_dict=scd,
|
| 2842 |
+
control_subgroup_types=["original"],
|
| 2843 |
+
comparison_subgroup_types=["paraphrase"],
|
| 2844 |
+
)
|
| 2845 |
+
),
|
| 2846 |
+
True,
|
| 2847 |
+
],
|
| 2848 |
+
}
|
| 2849 |
+
}
|
| 2850 |
+
|
| 2851 |
+
|
| 2852 |
+
class FixedGroupAbsvalNormCohensHParaphraseStringContainment(StringContainment):
|
| 2853 |
+
subgroup_column = "variant_type"
|
| 2854 |
+
reduction_map = {
|
| 2855 |
+
"group_mean": {
|
| 2856 |
+
"agg_func": [
|
| 2857 |
+
"absval_norm_cohens_h_paraphrase",
|
| 2858 |
+
lambda scd: np.abs(
|
| 2859 |
+
normalized_cohens_h(
|
| 2860 |
+
subgroup_scores_dict=scd,
|
| 2861 |
+
control_subgroup_types=["original"],
|
| 2862 |
+
comparison_subgroup_types=["paraphrase"],
|
| 2863 |
+
)
|
| 2864 |
+
),
|
| 2865 |
+
True,
|
| 2866 |
+
],
|
| 2867 |
+
}
|
| 2868 |
+
}
|
| 2869 |
+
|
| 2870 |
+
|
| 2871 |
+
class FixedGroupAbsvalNormHedgesGParaphraseAccuracy(Accuracy):
|
| 2872 |
+
subgroup_column = "variant_type"
|
| 2873 |
+
reduction_map = {
|
| 2874 |
+
"group_mean": {
|
| 2875 |
+
"agg_func": [
|
| 2876 |
+
"absval_norm_hedges_g_paraphrase",
|
| 2877 |
+
lambda scd: np.abs(
|
| 2878 |
+
normalized_hedges_g(
|
| 2879 |
+
subgroup_scores_dict=scd,
|
| 2880 |
+
control_subgroup_types=["original"],
|
| 2881 |
+
comparison_subgroup_types=["paraphrase"],
|
| 2882 |
+
)
|
| 2883 |
+
),
|
| 2884 |
+
True,
|
| 2885 |
+
],
|
| 2886 |
+
}
|
| 2887 |
+
}
|
| 2888 |
+
|
| 2889 |
+
|
| 2890 |
+
class FixedGroupAbsvalNormHedgesGParaphraseStringContainment(StringContainment):
|
| 2891 |
+
subgroup_column = "variant_type"
|
| 2892 |
+
reduction_map = {
|
| 2893 |
+
"group_mean": {
|
| 2894 |
+
"agg_func": [
|
| 2895 |
+
"absval_norm_hedges_g_paraphrase",
|
| 2896 |
+
lambda scd: np.abs(
|
| 2897 |
+
normalized_hedges_g(
|
| 2898 |
+
subgroup_scores_dict=scd,
|
| 2899 |
+
control_subgroup_types=["original"],
|
| 2900 |
+
comparison_subgroup_types=["paraphrase"],
|
| 2901 |
+
)
|
| 2902 |
+
),
|
| 2903 |
+
True,
|
| 2904 |
+
],
|
| 2905 |
+
}
|
| 2906 |
+
}
|