Upload metrics.py with huggingface_hub
metrics.py  CHANGED  (+164 -10)
@@ -16,7 +16,7 @@ from scipy.stats import bootstrap
 from scipy.stats._warnings_errors import DegenerateDataWarning
 
 from .artifact import Artifact
-from .dataclass import AbstractField, InternalField, OptionalField
+from .dataclass import AbstractField, InternalField, NonPositionalField, OptionalField
 from .logging_utils import get_logger
 from .metric_utils import InstanceInput, MetricRequest, MetricResponse
 from .operator import (
@@ -648,6 +648,9 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
 
     reduction_map: Dict[str, List[str]] = AbstractField()
 
+    reference_field: str = NonPositionalField(default="references")
+    prediction_field: str = NonPositionalField(default="prediction")
+
     def _validate_group_mean_reduction(self, instances: List[dict]):
         """Ensure that group_mean reduction_map is properly formatted.
 
@@ -827,10 +830,21 @@ class InstanceMetric(SingleStreamOperator, MetricWithConfidenceInterval):
         instances = []
 
         for instance in stream:
-            refs, pred = instance["references"], instance["prediction"]
+            task_data = instance["task_data"] if "task_data" in instance else {}
+
+            if self.reference_field == "references":
+                refs = instance["references"]
+            else:
+                refs = task_data[self.reference_field]
+            if not isinstance(refs, list):
+                refs = [refs]
+            if self.prediction_field == "prediction":
+                pred = instance["prediction"]
+            else:
+                pred = task_data[self.prediction_field]
+
             self._validate_prediction(pred)
             self._validate_reference(refs)
-            task_data = instance["task_data"] if "task_data" in instance else {}
 
             instance_score = self.compute(
                 references=refs, prediction=pred, task_data=task_data
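The two NonPositionalField attributes added above let an InstanceMetric read its inputs from task_data instead of the top-level "references"/"prediction" entries, wrapping a non-list reference in a list. A minimal sketch of how a subclass might opt in; the class name, the task_data field names, and the unitxt.metrics import path are illustrative assumptions, not part of this commit:

# Hypothetical illustration only, not code from this commit.
from typing import Dict, List

from unitxt.metrics import InstanceMetric  # assumed import path


class ExactMatchFromTaskData(InstanceMetric):
    main_score = "exact_match"
    reduction_map = {"mean": ["exact_match"]}
    prediction_type = "str"

    # New in this commit: read inputs from task_data fields instead of
    # instance["prediction"] / instance["references"].
    prediction_field = "model_answer"
    reference_field = "gold_answer"

    def compute(self, references: List[str], prediction: str, task_data: Dict) -> dict:
        # references arrives as a list even if task_data["gold_answer"] held a
        # single string, because process() wraps non-list values in a list.
        return {"exact_match": float(prediction in references)}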
@@ -1033,7 +1047,6 @@ class MetricPipeline(MultiStreamOperator, Metric):
                 [f"score/instance/{self.main_score}", "score/instance/score"],
                 [f"score/global/{self.main_score}", "score/global/score"],
             ],
-            use_query=True,
         )
 
     def process(self, multi_stream: MultiStream) -> MultiStream:
@@ -1447,13 +1460,15 @@ class Rouge(HuggingfaceMetric):
 
 
 # Computes char edit distance, ignoring whitespace
-class CharEditDistanceAccuracy(InstanceMetric):
-    reduction_map = {"mean": ["char_edit_dist_accuracy"]}
-    main_score = "char_edit_dist_accuracy"
-    ci_scores = ["char_edit_dist_accuracy"]
+class CharEditDistance(InstanceMetric):
+    main_score = "char_edit_distance"
+    reduction_map = {"mean": [main_score]}
+    ci_scores = [main_score]
     prediction_type = "str"
     single_reference_per_prediction = True
 
+    accuracy_metric = False
+
    _requirements_list: List[str] = ["editdistance"]
 
    def prepare(self):
@@ -1467,9 +1482,21 @@ class CharEditDistanceAccuracy(InstanceMetric):
         formatted_reference = "".join(references[0].split())
         max_length = max(len(formatted_reference), len(formatted_prediction))
         if max_length == 0:
-            return {"char_edit_dist_accuracy": 0.0}
+            return {self.main_score: 0.0}
         edit_dist = self.eval(formatted_reference, formatted_prediction)
-        return {"char_edit_dist_accuracy": (1 - edit_dist / max_length)}
+        if self.accuracy_metric:
+            score = 1 - edit_dist / max_length
+        else:
+            score = edit_dist
+        return {self.main_score: score}
+
+
+class CharEditDistanceAccuracy(CharEditDistance):
+    main_score = "char_edit_dist_accuracy"
+    reduction_map = {"mean": [main_score]}
+    ci_scores = [main_score]
+
+    accuracy_metric = True
 
 
 class Wer(HuggingfaceMetric):
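The refactor above folds the old accuracy-only metric into a shared CharEditDistance base: with accuracy_metric = False the raw whitespace-ignoring edit distance is reported, while CharEditDistanceAccuracy keeps the previous accuracy behaviour. A standalone sketch of the arithmetic, using the editdistance package directly on made-up strings:

# Standalone sketch of the scoring logic above; the strings are invented examples.
import editdistance

prediction, reference = "kitt en", "sitting"
formatted_prediction = "".join(prediction.split())   # "kitten" (whitespace ignored)
formatted_reference = "".join(reference.split())     # "sitting"

max_length = max(len(formatted_reference), len(formatted_prediction))     # 7
edit_dist = editdistance.eval(formatted_reference, formatted_prediction)  # 3

char_edit_distance = edit_dist                         # CharEditDistance score: 3
char_edit_dist_accuracy = 1 - edit_dist / max_length   # CharEditDistanceAccuracy score: ~0.571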
@@ -1853,6 +1880,8 @@ class BertScore(HuggingfaceBulkMetric):
     ci_scores = ["f1", "precision", "recall"]
     model_name: str
 
+    prediction_type = "str"
+
     _requirements_list: List[str] = ["bert_score"]
 
     def prepare(self):
@@ -1949,6 +1978,38 @@ class Reward(BulkInstanceMetric):
         return self.pipe(inputs, batch_size=self.batch_size)
 
 
+class Detector(BulkInstanceMetric):
+    reduction_map = {"mean": ["score"]}
+    main_score = "score"
+    batch_size: int = 32
+
+    prediction_type = "str"
+
+    model_name: str
+
+    _requirements_list: List[str] = ["transformers", "torch"]
+
+    def prepare(self):
+        super().prepare()
+        import torch
+        from transformers import pipeline
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        self.pipe = pipeline(
+            "text-classification", model=self.model_name, device=device
+        )
+
+    def compute(
+        self,
+        references: List[List[Any]],
+        predictions: List[Any],
+        task_data: List[Dict],
+    ) -> List[Dict[str, Any]]:
+        # compute the metric
+        # add function_to_apply="none" to disable sigmoid
+        return self.pipe(predictions, batch_size=self.batch_size)
+
+
 class LlamaIndexCorrectness(InstanceMetric):
     """LlamaIndex based metric class for evaluating correctness."""
 
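The new Detector simply forwards each prediction through a Hugging Face text-classification pipeline and reports that pipeline's per-instance output. A standalone sketch of what prepare()/compute() do, using transformers directly; the model name is an assumed example, not part of this commit:

# Standalone sketch mirroring Detector.prepare() and Detector.compute().
import torch
from transformers import pipeline

model_name = "facebook/roberta-hate-speech-dynabench-r4-target"  # assumed example model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
pipe = pipeline("text-classification", model=model_name, device=device)

predictions = ["some model generated text"]
outputs = pipe(predictions, batch_size=32)
# e.g. [{"label": "nothate", "score": 0.97}] -- the dict each instance receives as its score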
@@ -3320,6 +3381,99 @@ class BinaryMaxAccuracy(GlobalMetric):
     return {self.main_score: best_acc, "best_thr_max_acc": best_thr}
 
 
+######################
+# RerankRecallMetric #
+
+
+def pytrec_eval_at_k(results, qrels, at_k, metric_name):
+    import pandas as pd
+    import pytrec_eval
+
+    metric = {}
+
+    for k in at_k:
+        metric[f"{metric_name}@{k}"] = 0.0
+
+    metric_string = f"{metric_name}." + ",".join([str(k) for k in at_k])
+    # print('metric_string = ', metric_string)
+    evaluator = pytrec_eval.RelevanceEvaluator(
+        qrels, {"ndcg", metric_string}
+    )  # {map_string, ndcg_string, recall_string, precision_string})
+    scores = evaluator.evaluate(results)
+    scores = pd.DataFrame(scores).transpose()
+
+    keys = []
+    column_map = {}
+    for k in at_k:
+        keys.append(f"{metric_name}_{k}")
+        column_map[f"{metric_name}_{k}"] = k
+    scores[keys].rename(columns=column_map)
+
+    return scores
+
+
+class RerankRecall(GlobalMetric):
+    """RerankRecall: measures the quality of reranking with respect to ground truth ranking scores.
+
+    This metric measures ranking performance across a dataset. The
+    references for a query will have a score of 1 for the gold passage
+    and 0 for all other passages. The model returns scores in [0,1]
+    for each passage,query pair. This metric measures recall at k by
+    testing that the predicted score for the gold passage,query pair
+    is at least the k'th highest for all passages for that query. A
+    query receives 1 if so, and 0 if not. The 1's and 0's are
+    averaged across the dataset.
+
+    query_id_field selects the field containing the query id for an instance.
+    passage_id_field selects the field containing the passage id for an instance.
+    at_k selects the value of k used to compute recall.
+
+    """
+
+    main_score = "recall_at_5"
+    query_id_field: str = "query_id"
+    passage_id_field: str = "passage_id"
+    at_k: List[int] = [1, 2, 5]
+
+    # This doesn't seem to make sense
+    n_resamples = None
+
+    _requirements_list: List[str] = ["pandas", "pytrec_eval"]
+
+    def compute(
+        self,
+        references: List[List[str]],
+        predictions: List[str],
+        task_data: List[Dict],
+    ):
+        # Collect relevance score and ref per query/passage pair
+        results = {}
+        qrels = {}
+        for ref, pred, data in zip(references, predictions, task_data):
+            qid = data[self.query_id_field]
+            pid = data[self.passage_id_field]
+            if qid not in results:
+                results[qid] = {}
+                qrels[qid] = {}
+            # Convert string-wrapped float to regular float
+            try:
+                results[qid][pid] = float(pred)
+            except ValueError:
+                # Card testing feeds nonnumeric values in, so catch that.
+                results[qid][pid] = np.nan
+
+            # There's always a single reference per pid/qid pair
+            qrels[qid][pid] = int(ref[0])
+
+        # Compute recall @ 5
+        scores = pytrec_eval_at_k(results, qrels, self.at_k, "recall")
+        # print(scores.describe())
+        # pytrec returns numpy float32
+        return {
+            f"recall_at_{i}": float(scores[f"recall_{i}"].mean()) for i in self.at_k
+        }
+
+
 KO_ERROR_MESSAGE = """
 
 Additional dependencies required. To install them, run:
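To make the qrels/results layout that RerankRecall.compute builds in the hunk above concrete, here is a small self-contained recall@k computation with pytrec_eval, mirroring the measure string and result keys used there; the query/passage ids and scores are invented for the example:

# Toy illustration of the data layout RerankRecall hands to pytrec_eval.
import pytrec_eval

# qrels: 1 for the gold passage of each query, 0 for every other passage.
qrels = {"q1": {"p1": 1, "p2": 0, "p3": 0}}
# results: model scores in [0, 1] for each (query, passage) pair.
results = {"q1": {"p1": 0.40, "p2": 0.90, "p3": 0.10}}

evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"recall.1,2,5"})
scores = evaluator.evaluate(results)

# The gold passage p1 ranks second, so recall@1 is 0.0 while recall@2 and recall@5 are 1.0.
print(scores["q1"]["recall_1"], scores["q1"]["recall_2"], scores["q1"]["recall_5"])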