Upload folder using huggingface_hub

- api.py +9 -2
- metric_utils.py +5 -1
- metrics.py +435 -19
- operators.py +34 -18
- serializers.py +20 -1
- task.py +6 -1
- templates.py +91 -10
- types.py +9 -0
- version.py +1 -1
api.py
CHANGED

@@ -7,6 +7,7 @@ from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 from .artifact import fetch_artifact
 from .card import TaskCard
 from .dataset_utils import get_dataset_artifact
+from .error_utils import UnitxtError
 from .inference import (
     InferenceEngine,
     LogProbInferenceEngine,

@@ -198,8 +199,14 @@ def load_dataset(
 ).with_transform(loads_instance)
 
 
-def evaluate(predictions, data) -> EvaluationResults:
-    return _compute(predictions=predictions, references=data)
+def evaluate(
+    predictions, dataset: Union[Dataset, IterableDataset] = None, data=None
+) -> EvaluationResults:
+    if dataset is None and data is None:
+        raise UnitxtError(message="Specify 'dataset' in evaluate")
+    if data is not None:
+        dataset = data  # for backward compatibility
+    return _compute(predictions=predictions, references=dataset)
 
 
 def post_process(predictions, data) -> List[Dict[str, Any]]:
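For callers, the change looks like this (a sketch; the card and template names are illustrative, not part of this commit):

from unitxt.api import load_dataset, evaluate

# Illustrative recipe; any unitxt card/template pair works the same way.
dataset = load_dataset(
    card="cards.wnli",
    template="templates.classification.multi_class.relation.default",
)["test"]
predictions = ["entailment"] * len(dataset)

results = evaluate(predictions, dataset=dataset)  # new preferred keyword
results = evaluate(predictions, data=dataset)     # deprecated spelling, still accepted
# evaluate(predictions) with neither argument now raises UnitxtError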
metric_utils.py
CHANGED

@@ -38,7 +38,11 @@ constants = get_constants()
 
 
 def nan_mean(scores):
-    return mean(score for score in scores if score == score)
+    result = mean(score for score in scores if score == score)
+    try:
+        return float(result)
+    except:
+        return result
 
 
 class FromPredictionsAndOriginalData(StreamInitializerOperator):
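The `score == score` filter works because NaN is the only float that compares unequal to itself, and the new float(...) cast normalizes numpy scalars to plain Python floats. A quick standalone check:

from statistics import mean

def nan_mean(scores):
    # NaN != NaN, so the comprehension silently drops NaN entries.
    result = mean(score for score in scores if score == score)
    try:
        return float(result)
    except:
        return result

print(nan_mean([1.0, float("nan"), 0.0]))  # 0.5 -- the NaN is ignored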
metrics.py
CHANGED

@@ -7,10 +7,10 @@ import string
 import uuid
 import warnings
 from abc import ABC, abstractmethod
-from collections import Counter, defaultdict
+from collections import Counter, defaultdict, namedtuple
 from dataclasses import field
 from functools import lru_cache
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union
 
 import numpy
 import numpy as np

@@ -317,6 +317,398 @@ class Metric(Artifact):
         instance["score"]["global"].pop(score_ci)
 
 
+def new_random_generator():
+    # The np.random.default_rng expects a 32-bit int, while hash(..) can return a 64-bit integer.
+    # So use '& MAX_32BIT' to get a 32-bit seed.
+    _max_32bit = 2**32 - 1
+    return np.random.default_rng(hash(get_seed()) & _max_32bit)
+
+
+class ConfidenceIntervalMixin(Artifact):
+    n_resamples: int = 1000
+    confidence_level: float = 0.95
+    ci_score_names: List[str] = None
+
+    @abstractmethod
+    def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
+        pass
+
+    def get_statistic(self, data: List[Any], score_names: List[str]):
+        def statistic_function(indices, axis=0):
+            # indices might be a 1D or 2D array, depending on bootstrap internals
+            # For simplicity, ensure we handle them as 1D.
+            indices = np.atleast_1d(indices).astype(int)
+
+            # Gather the subset
+            sample = [data[i] for i in indices]
+
+            # Compute metrics on this sample
+            scores = self._sample_to_scores(sample)
+
+            # Return them in consistent order
+            return np.array([scores[m] for m in score_names])
+
+        return statistic_function
+
+    def bootstrap(self, data: List[Any], score_names: List[str]):
+        if self.ci_score_names is not None:
+            score_names = self.ci_score_names
+
+        intervals = bootstrap(
+            (np.arange(len(data)),),
+            statistic=self.get_statistic(data, score_names),
+            n_resamples=self.n_resamples,
+            confidence_level=self.confidence_level,
+            random_state=new_random_generator(),
+            paired=False,
+            vectorized=False,  # set to True if your statistic function is vectorized
+            method="BCa",
+        ).confidence_interval
+
+        result = {}
+        for i, metric in enumerate(score_names):
+            result[f"{metric}_ci_low"] = float(intervals.low[i])
+            result[f"{metric}_ci_high"] = float(intervals.high[i])
+
+        return result
+
+
+from typing import Generic, TypeVar, NamedTuple
+from dataclasses import dataclass
+
+IntermediateType = TypeVar("IntermediateType")
+PredictionType = TypeVar("PredictionType")
+
+
+class EvaluationInput(tuple, Generic[PredictionType]):
+    def __new__(
+        cls,
+        prediction: PredictionType,
+        references: List[PredictionType],
+        task_data: Dict[str, Any],
+    ) -> "EvaluationInput[PredictionType]":
+        return super().__new__(cls, (prediction, references, task_data))
+
+
+def is_original_key(key):
+    if (
+        key.endswith("_ci_low")
+        or key.endswith("_ci_high")
+        or key == "score"
+        or key == "num_of_instances"
+        or key == "score_name"
+    ):
+        return False
+    return True
+
+
+class MapReduceMetric(
+    StreamOperator,
+    Metric,
+    ConfidenceIntervalMixin,
+    Generic[PredictionType, IntermediateType],
+):
+    score_prefix = ""
+    reference_field: str = NonPositionalField(default="references")
+    prediction_field: str = NonPositionalField(default="prediction")
+
+    def map(
+        self,
+        prediction: PredictionType,
+        references: List[PredictionType],
+        task_data: Dict[str, Any],
+    ) -> IntermediateType:
+        raise NotImplementedError()
+
+    def reduce_one(self, intermidate: IntermediateType):
+        return self.reduce([intermidate])
+
+    @abstractmethod
+    def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
+        return {}
+
+    def disable_confidence_interval_calculation(self):
+        self.n_resamples = None
+
+    def annotate_scores(self, scores):
+        scores = {
+            **{self.score_prefix + key: val for key, val in scores.items()},
+            "score_name": self.score_prefix + self.main_score,
+            "score": scores[self.main_score],
+        }
+        for level in ["high", "low"]:
+            if f"{self.main_score}_ci_{level}" in scores:
+                scores[f"score_ci_{level}"] = scores[f"{self.main_score}_ci_{level}"]
+        return scores
+
+    def _sample_to_scores(self, sample: List[Any]) -> Dict[str, Any]:
+        return self.reduce(sample)
+
+    def reduce_and_bootstrap(
+        self, intermediates: List[IntermediateType]
+    ) -> Dict[str, Any]:
+        scores = self.reduce(intermediates)
+        score_names = [k for k, v in scores.items() if isinstance(v, float)]
+        if self.n_resamples is None:
+            return scores
+        intervals = self.bootstrap(intermediates, score_names)
+        return {**scores, **intervals}
+
+    def _instance_to_evaluation_input(
+        self, instance: Dict[str, Any]
+    ) -> EvaluationInput[PredictionType]:
+        instance = self.verify_instance(instance)
+
+        task_data = instance.get("task_data", {})
+
+        if self.reference_field == "references":
+            references = instance["references"]
+        else:
+            references = task_data[self.reference_field]
+            if not isinstance(references, list):
+                references = [references]
+        if self.prediction_field == "prediction":
+            prediction = instance["prediction"]
+        else:
+            prediction = task_data[self.prediction_field]
+
+        self._validate_prediction(prediction)
+        self._validate_reference(references)
+
+        return EvaluationInput[PredictionType](
+            prediction=prediction, references=references, task_data=task_data
+        )
+
+    def _instances_stream_to_evaluation_inputs(
+        self, stream: Stream
+    ) -> Generator[EvaluationInput[PredictionType], None, None]:
+        for instance in stream:
+            yield self._instance_to_evaluation_input(instance)
+
+    def map_stream(
+        self,
+        evaluation_inputs_stream: Generator[
+            EvaluationInput[PredictionType], None, None
+        ],
+    ):
+        intermediates = []
+        for prediction, references, task_data in evaluation_inputs_stream:
+            intermediate = self.map(
+                prediction=prediction, references=references, task_data=task_data
+            )
+
+            intermediates.append(intermediate)
+        return intermediates
+
+    def process(self, stream: Stream, stream_name: Optional[str] = None):
+        instances_scores, global_scores = self.compute(stream, stream_name)
+        for i, (instance, instance_scores) in enumerate(zip(stream, instances_scores)):
+            previous_score = instance.get("score", {"global": {}, "instance": {}})
+
+            if i == 0:
+                for key in global_scores:
+                    if is_original_key(key) and key in previous_score["global"]:
+                        UnitxtWarning(
+                            message=f"Metric '{key}' that has just been evaluated with value {global_scores[key]}, is already recorded "
+                            f"to have value {previous_score['global'][key]} by a previous metric evaluation on this instance or stream. "
+                            f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
+                            f"which will yield, in this case, a score named: 'my_second_{key}')",
+                            additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
+                        )
+
+            global_scores = {**previous_score["global"], **global_scores}
+            instance_scores = {**previous_score["instance"], **instance_scores}
+
+            yield {
+                **instance,
+                "score": {"global": global_scores, "instance": instance_scores},
+            }
+
+    def compute(self, stream: Stream, stream_name: Optional[str] = None):
+        evaluation_inputs_stream = self._instances_stream_to_evaluation_inputs(stream)
+        intermediates_list = self.map_stream(evaluation_inputs_stream)
+
+        instances_scores = []
+        for intermediate in intermediates_list:
+            instance_score = self.reduce_one(intermediate)
+            instance_score = self.annotate_scores(instance_score)
+            instances_scores.append(instance_score)
+
+        global_scores = self.reduce_and_bootstrap(intermediates_list)
+        global_scores = self.annotate_scores(global_scores)
+
+        global_scores["num_of_instances"] = len(intermediates_list)
+
+        return instances_scores, global_scores
+
+
+def get_index_or_default(lst, item, default=-1):
+    try:
+        return lst.index(item)
+    except ValueError:
+        return default
+
+
+class AggregationReduction(Artifact, Generic[IntermediateType]):
+    def reduce(self, intermidates: List[IntermediateType]) -> Dict[str, Any]:
+        pass
+
+
+class DictReduction(AggregationReduction[Dict[str, float]]):
+    def reduce_list(self, lst: List[float]):
+        pass
+
+    def reduce(self, intermidates: List[Dict[str, float]]):
+        lists = {}
+        for intermidate in intermidates:
+            for key, val in intermidate.items():
+                if key not in lists:
+                    lists[key] = []
+                lists[key].append(val)
+
+        result = {}
+        for key, val_list in lists.items():
+            result[key] = self.reduce_list(val_list)
+        return result
+
+
+class MeanReduction(DictReduction):
+    def reduce_list(self, lst: List[float]):
+        return nan_mean(lst)
+
+
+class MaxReduction(DictReduction):
+    def reduce_list(self, lst: List[float]):
+        return float(nan_max(lst))
+
+
+class ReductionInstanceMetric(
+    MapReduceMetric[PredictionType, IntermediateType],
+    Generic[PredictionType, IntermediateType],
+):
+    reduction: AggregationReduction[IntermediateType]
+
+    def reduce(self, intermediates: List[IntermediateType]) -> Dict[str, Any]:
+        return self.reduction.reduce(intermediates)
+
+    def reduce_one(self, intermidate: IntermediateType):
+        return recursive_copy(intermidate)
+
+
+class AccuracyFast(ReductionInstanceMetric[str, Dict[str, float]]):
+    main_score = "accuracy"
+    reduction = MeanReduction()
+
+    def map(
+        self, prediction: str, references: List[str], task_data: Dict[str, Any]
+    ) -> Dict[str, float]:
+        return {
+            self.main_score: float(
+                str(prediction) in [str(reference) for reference in references]
+            )
+        }
+
+
+class F1Fast(MapReduceMetric[str, Tuple[int, int]]):
+    main_score = "f1"
+    averages: List[Literal["f1", "macro", "micro", "per_class"]] = [
+        "f1",
+        "micro",
+        "macro",
+        "per_class",
+    ]
+    ignore_punc: bool = True
+    ignore_case: bool = True
+    _requirements_list = ["scikit-learn", "regex"]
+
+    def prepare(self):
+        super().prepare()
+        from sklearn.metrics import f1_score
+
+        self._metric = f1_score
+        import regex
+        from functools import partial
+
+        self.remove_punc = partial(regex.compile(r"\p{P}+").sub, "")
+
+    def get_str_id(self, str):
+        if str not in self.str_to_id:
+            id = len(self.str_to_id)
+            self.str_to_id[str] = id
+            self.id_to_str[id] = str
+        return self.str_to_id[str]
+
+    def map_stream(
+        self, evaluation_inputs_stream: Generator[EvaluationInput[str], None, None]
+    ):
+        self.str_to_id = {}
+        self.id_to_str = {}
+        return super().map_stream(evaluation_inputs_stream)
+
+    def map(
+        self, prediction: str, references: List[str], task_data: Dict[str, Any]
+    ) -> Tuple[int, int]:
+        reference_index = self.get_str_id(references[0])
+        prediction_index = self.get_str_id(prediction)
+
+        return prediction_index, reference_index
+
+    def reduce(self, intermediates: List[Tuple[int, int]]) -> Dict[str, Any]:
+        y_true = []
+        y_pred = []
+        labels = set()
+        for pred_idx, ref_idx in intermediates:
+            y_pred.append(pred_idx)
+            y_true.append(ref_idx)
+            labels.add(ref_idx)
+
+        labels = list(labels)
+        result = {}
+
+        if "f1" in self.averages:
+            result["f1"] = float(
+                self._metric(
+                    y_true,
+                    y_pred,
+                    average="macro",
+                    labels=labels,
+                    zero_division=0,
+                )
+            )
+
+        if "micro" in self.averages:
+            result["f1_micro"] = float(
+                self._metric(
+                    y_true,
+                    y_pred,
+                    average="micro",
+                    labels=labels,
+                    zero_division=0,
+                )
+            )
+
+        if "macro" in self.averages:
+            result["f1_macro"] = float(
+                self._metric(
+                    y_true,
+                    y_pred,
+                    average="macro",
+                    labels=labels,
+                    zero_division=0,
+                )
+            )
+
+        if "per_class" in self.averages:
+            f1_per_class = self._metric(
+                y_true, y_pred, average=None, labels=list(labels), zero_division=0
+            )
+            for label, score in zip(labels, f1_per_class):
+                class_name = self.id_to_str[label]
+                result[f"f1_{class_name}"] = float(score)
+
+        return result
+
+
 class MetricWithConfidenceInterval(Metric):
     # The number of resamples used to estimate the confidence intervals of this metric.
     # Use None to disable confidence interval computation.
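The block above splits evaluation into a per-instance map step and an aggregating reduce step; the bootstrap then resamples the stored intermediates to attach confidence intervals. A standalone sketch of that contract (plain Python, no unitxt base classes, CI omitted):

from typing import Any, Dict, List

class ToyAccuracy:
    main_score = "accuracy"

    def map(self, prediction: str, references: List[str], task_data: Dict[str, Any]) -> Dict[str, float]:
        # Per-instance intermediate, mirroring AccuracyFast.map above.
        return {self.main_score: float(str(prediction) in [str(r) for r in references])}

    def reduce(self, intermediates: List[Dict[str, float]]) -> Dict[str, Any]:
        # Global score: mean over intermediates; the bootstrap calls this on resamples.
        values = [inter[self.main_score] for inter in intermediates]
        return {self.main_score: sum(values) / len(values)}

metric = ToyAccuracy()
intermediates = [metric.map(p, refs, {}) for p, refs in [("7", ["7"]), ("8", ["9"])]]
print(metric.reduce(intermediates))  # {'accuracy': 0.5}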
@@ -539,10 +931,10 @@ class MetricWithConfidenceInterval(Metric):
                 confidence_level=self.confidence_level,
                 random_state=random_gen,
             ).confidence_interval
-            result["score_ci_low"] = ci.low
-            result["score_ci_high"] = ci.high
-            result[f"{score_name}_ci_low"] = ci.low
-            result[f"{score_name}_ci_high"] = ci.high
+            result["score_ci_low"] = float(ci.low)
+            result["score_ci_high"] = float(ci.high)
+            result[f"{score_name}_ci_low"] = float(ci.low)
+            result[f"{score_name}_ci_high"] = float(ci.high)
         return result
 
 

@@ -1732,7 +2124,7 @@ class HuggingfaceMetric(GlobalMetric):
             **self.hf_compute_args,
         )
         if self.hf_main_score:
-            result[self.main_score] = result[self.hf_main_score]
+            result[self.main_score] = float(result[self.hf_main_score])
             del result[self.hf_main_score]
         if self.scale != 1.0:
             assert (

@@ -1752,6 +2144,8 @@ class HuggingfaceMetric(GlobalMetric):
                     result[key], float
                 ), "Scaled field '{key}' is not float: {result[key]}"
                 result[key] /= self.scale
+        if self.main_score in result:
+            result[self.main_score] = float(result[self.main_score])
         return result
 
 

@@ -1837,17 +2231,49 @@ class HuggingfaceInstanceMetric(InstanceMetric):
         return score
 
 
+class MeteorFast(ReductionInstanceMetric[str, Dict[str, float]]):
+    main_score = "meteor"
+    reduction = MeanReduction()
+    _requirements_list: List[str] = ["nltk>=3.6.6"]
+    alpha: float = 0.9
+    beta: int = 3
+    gamma: float = 0.5
+
+    def prepare(self):
+        super().prepare()
+        import nltk
+
+        nltk.download("wordnet", quiet=True)
+        nltk.download("omw-1.4", quiet=True)
+        from nltk import word_tokenize
+        from nltk.translate import meteor_score
+
+        self.word_tokenize = word_tokenize
+        self.meteor_score = meteor_score
+
+    def map(
+        self, prediction: str, references: List[str], task_data: Dict[str, Any]
+    ) -> Dict[str, float]:
+        score = self.meteor_score.meteor_score(
+            [self.word_tokenize(ref) for ref in references],
+            self.word_tokenize(prediction),
+            alpha=self.alpha,
+            beta=self.beta,
+            gamma=self.gamma,
+        )
+        return {self.main_score: score}
+
+
 class Meteor(InstanceMetric):
     main_score = "meteor"
     ci_scores = ["meteor"]
     reduction_map = {"mean": ["meteor"]}
     prediction_type = str
 
-    _requirements_list: List[str] = ["nltk"]
+    _requirements_list: List[str] = ["nltk>=3.6.6"]
     alpha: float = 0.9
     beta: int = 3
     gamma: float = 0.5
-    # unitxt uses nltk version >= 3.8
 
     def prepare(self):
         super().prepare()

@@ -1861,16 +2287,6 @@ class Meteor(InstanceMetric):
         self.word_tokenize = word_tokenize
         self.meteor_score = meteor_score
 
-    def verify(self):
-        import importlib.metadata as importlib_metadata
-
-        from datasets.config import version
-
-        nltk_version = version.parse(importlib_metadata.version("nltk"))
-        assert nltk_version >= version.Version(
-            "3.6.6"
-        ), "nltk version must be at least 3.6.6"
-
     def compute(self, references, prediction, task_data):
         score = self.meteor_score.meteor_score(
            [self.word_tokenize(ref) for ref in references],
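F1Fast delegates the aggregation to scikit-learn's f1_score over the integer ids produced by get_str_id; the underlying call behaves like this (standard scikit-learn API, shown on toy label ids):

from sklearn.metrics import f1_score

y_true = [0, 0, 1, 1]  # integer ids, as produced by F1Fast.get_str_id
y_pred = [0, 1, 1, 1]

print(f1_score(y_true, y_pred, average="macro", labels=[0, 1], zero_division=0))
print(f1_score(y_true, y_pred, average="micro", labels=[0, 1], zero_division=0))
print(f1_score(y_true, y_pred, average=None, labels=[0, 1], zero_division=0))  # per class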
operators.py
CHANGED

@@ -55,6 +55,7 @@ from typing import (
     Generator,
     Iterable,
     List,
+    Literal,
     Optional,
     Tuple,
     Union,

@@ -1633,6 +1634,12 @@ class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
     yield from stream
 
 
+def update_scores_of_stream_instances(stream: Stream, scores: List[dict]) -> Generator:
+    for instance, score in zip(stream, scores):
+        instance["score"] = recursive_copy(score)
+        yield instance
+
+
 class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
     """Applies metric operators to a stream based on a metric field specified in each instance.
 

@@ -1647,13 +1654,6 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         from .metrics import Metric, MetricsList
 
-        def update_scores_of_stream_instances(
-            stream: Stream, scores: List[dict]
-        ) -> Generator:
-            for instance, score in zip(stream, scores):
-                instance["score"] = recursive_copy(score)
-                yield instance
-
         # to be populated only when two or more metrics
         accumulated_scores = []
 

@@ -1680,29 +1680,28 @@ class ApplyMetric(StreamOperator, ArtifactFetcherMixin):
                 f"Operator {metric_name} must be a Metric or MetricsList"
             )
 
+        for metric in metrics_list:
+            if not self.calc_confidence_intervals:
+                metric.disable_confidence_interval_calculation()
         # Each metric operator computes its score and then sets the main score, overwriting
         # the previous main score value (if any). So, we need to reverse the order of the listed metrics.
         # This will cause the first listed metric to run last, and the main score will be set
         # by the first listed metric (as desired).
         metrics_list = list(reversed(metrics_list))
 
-        for metric_no, metric in enumerate(metrics_list):
-            if not self.calc_confidence_intervals:
-                metric.disable_confidence_interval_calculation()
-
-            if metric_no > 0:
-                # update input stream with accumulated scores
+        for i, metric in enumerate(metrics_list):
+            if i == 0:  # first metric
+                multi_stream = MultiStream({"tmp": stream})
+            else:  # metrics with previous scores
                 reusable_generator = ReusableGenerator(
                     generator=update_scores_of_stream_instances,
                     gen_kwargs={"stream": stream, "scores": accumulated_scores},
                 )
                 multi_stream = MultiStream.from_generators({"tmp": reusable_generator})
-            else:
-                multi_stream = MultiStream.from_iterables({"tmp": stream})
+
             multi_stream = metric(multi_stream)
-
-            if metric_no < len(metrics_list) - 1:
-                # updating accumulated_scores
+
+            if i < len(metrics_list) - 1:  # not the last metric
                 accumulated_scores = []
                 for inst in multi_stream["tmp"]:
                     accumulated_scores.append(recursive_copy(inst["score"]))

@@ -2214,3 +2213,20 @@ class CollateInstances(StreamOperator):
                 f"batch_size must be an integer equal to or greater than 1. "
                 f"Got: {self.batch_size}."
             )
+
+
+class WikipediaFetcher(FieldOperator):
+    mode: Literal["summary", "text"] = "text"
+    _requirements_list = ["Wikipedia-API"]
+
+    def prepare(self):
+        super().prepare()
+        import wikipediaapi
+
+        self.wikipedia = wikipediaapi.Wikipedia("Unitxt")
+
+    def process_value(self, value: Any) -> Any:
+        title = value.split("/")[-1]
+        page = self.wikipedia.page(title)
+
+        return {"title": page.title, "body": getattr(page, self.mode)}
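A usage sketch for the new operator (assumes the Wikipedia-API package, network access, and that unitxt runs prepare() on construction as it does for other artifacts; the URL is illustrative):

from unitxt.operators import WikipediaFetcher

fetcher = WikipediaFetcher(mode="summary")  # or mode="text", the default
doc = fetcher.process_value("https://en.wikipedia.org/wiki/Python_(programming_language)")
print(doc["title"])       # page title as resolved by Wikipedia
print(doc["body"][:100])  # summary text, since mode="summary"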
serializers.py
CHANGED

@@ -7,7 +7,7 @@ from .dataclass import AbstractField, Field
 from .operators import InstanceFieldOperator
 from .settings_utils import get_constants
 from .type_utils import isoftype, to_type_string
-from .types import Dialog, Image, Number, Table, Video
+from .types import Dialog, Document, Image, MultiDocument, Number, Table, Video
 
 constants = get_constants()
 

@@ -127,9 +127,28 @@ class VideoSerializer(ImageSerializer):
         return "".join(serialized_images)
 
 
+class DocumentSerializer(SingleTypeSerializer):
+    serialized_type = Document
+
+    def serialize(self, value: Document, instance: Dict[str, Any]) -> str:
+        return f"# {value['title']}\n\n{value['body']}"
+
+
+class MultiDocumentSerializer(DocumentSerializer):
+    serialized_type = MultiDocument
+
+    def serialize(self, value: MultiDocument, instance: Dict[str, Any]) -> str:
+        documents = []
+        for document in value:
+            documents.append(super().serialize(document, instance))
+        return "\n\n".join(documents)
+
+
 class MultiTypeSerializer(Serializer):
     serializers: List[SingleTypeSerializer] = Field(
         default_factory=lambda: [
+            DocumentSerializer(),
+            MultiDocumentSerializer(),
             ImageSerializer(),
             VideoSerializer(),
             TableSerializer(),
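The serialized form is easy to see without any unitxt machinery; this mirrors the two serialize bodies above (sample documents are illustrative):

doc = {"title": "Eiffel Tower", "body": "The Eiffel Tower is a landmark in Paris."}
docs = [doc, {"title": "Big Ben", "body": "Big Ben is a clock tower in London."}]

# DocumentSerializer: markdown-style title heading, blank line, body.
print(f"# {doc['title']}\n\n{doc['body']}")

# MultiDocumentSerializer: the per-document renderings joined by blank lines.
print("\n\n".join(f"# {d['title']}\n\n{d['body']}" for d in docs))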
task.py
CHANGED

@@ -116,13 +116,18 @@ class Task(InstanceOperator, ArtifactFetcherMixin):
             self.prediction_type
         )
 
+        if hasattr(self, "inputs") and self.inputs is not None:
+            self.inputs = self.input_fields
+
+        if hasattr(self, "outputs") and self.outputs is not None:
+            self.outputs = self.reference_fields
+
     def task_deprecations(self):
         if hasattr(self, "inputs") and self.inputs is not None:
             depr_message = (
                 "The 'inputs' field is deprecated. Please use 'input_fields' instead."
             )
             warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
-
         if hasattr(self, "outputs") and self.outputs is not None:
             depr_message = "The 'outputs' field is deprecated. Please use 'reference_fields' instead."
             warnings.warn(depr_message, DeprecationWarning, stacklevel=2)
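The added lines apply the usual deprecation-alias pattern: keep the old attribute pointing at its replacement so code that still reads it sees current values. A standalone sketch (not unitxt's actual Task class):

import warnings

class TaskLike:
    def __init__(self, input_fields=None, inputs=None):
        if inputs is not None:
            warnings.warn(
                "The 'inputs' field is deprecated. Please use 'input_fields' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            input_fields = inputs
        self.input_fields = input_fields
        # Keep the deprecated alias in sync with its replacement.
        self.inputs = self.input_fields

task = TaskLike(inputs={"question": str})
assert task.inputs == task.input_fields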
templates.py
CHANGED

@@ -495,7 +495,31 @@ class PairwiseComparativeRatingTemplate(InputOutputTemplate):
 
 
 class MultipleChoiceTemplate(InputFormatTemplate):
-    """Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer."""
+    """Formats the input that specifies a multiple-choice question, with a list of possible answers to choose from, and identifies the correct answer.
+
+    Args:
+        target_prefix (str): Optional prefix that can be added before the target label in
+            generated prompts or outputs.
+        choices_field (str): The key under which the multiple choices are stored in the
+            input and reference dictionaries.
+        target_field (str): The key under which the correct choice is stored in the
+            reference dictionary (can be integer index or textual label).
+        choices_separator (str): A string used to join formatted choices (e.g. ", ").
+        source_choice_format (str): A Python format string used for displaying each choice
+            in the input fields (e.g. "{choice_numeral}. {choice_text}").
+        target_choice_format (str): A Python format string used for displaying each choice
+            in the target or final output (e.g. "{choice_numeral}").
+        enumerator (str): Determines how choice numerals are enumerated. Possible values
+            include "capitals", "lowercase", "numbers", or "roman".
+        shuffle_choices (bool): If True, shuffle the choices. The shuffling seed can be
+            set with `shuffle_choices_seed`.
+        shuffle_choices_seed (int, optional): If provided, the choices are shuffled with
+            this fixed integer seed for reproducibility.
+        sort_choices_by_length (bool): If True, sorts choices by their length (ascending).
+        sort_choices_alphabetically (bool): If True, sorts choices in alphabetical order.
+        reverse_choices (bool): If True, reverses the order of the choices after any
+            sorting has been applied. Defaults to False to preserve backward compatibility.
+    """
 
     target_prefix: str = ""
     choices_field: str = "choices"

@@ -504,7 +528,13 @@ class MultipleChoiceTemplate(InputFormatTemplate):
     source_choice_format: str = "{choice_numeral}. {choice_text}"
     target_choice_format: str = "{choice_numeral}"
     enumerator: str = "capitals"
+
     shuffle_choices: bool = False
+    shuffle_choices_seed: int = None
+    sort_choices_by_length: bool = False
+    sort_choices_alphabetically: bool = False
+    reverse_choices: bool = False  # False by default for backward-compat
+    place_correct_choice_position: int = None
 
     def prepare(self):
         super().prepare()

@@ -538,6 +568,31 @@ class MultipleChoiceTemplate(InputFormatTemplate):
         "XX",
     ]
 
+    def verify(self):
+        super().verify()
+        if self.shuffle_choices and (
+            self.sort_choices_by_length
+            or self.sort_choices_alphabetically
+            or self.reverse_choices
+            or self.place_correct_choice_position is not None
+        ):
+            raise UnitxtError(
+                "You cannot combine shuffle_choices with sorting or reversing flags."
+            )
+
+        if self.sort_choices_by_length and self.sort_choices_alphabetically:
+            raise UnitxtError(
+                "You cannot combine both sort_choices_by_length and sort_choices_alphabetically simultaneously."
+            )
+        if self.place_correct_choice_position is not None and (
+            self.sort_choices_by_length
+            or self.sort_choices_alphabetically
+            or self.reverse_choices
+        ):
+            raise UnitxtError(
+                "You cannot combine place_correct_choice_position with sorting or reversing flags."
+            )
+
     def inputs_to_choices(self, data: Dict[str, Any], choice_format: str) -> str:
         choices = data[self.choices_field]
         enumrated_choices = []

@@ -612,18 +667,44 @@ class MultipleChoiceTemplate(InputFormatTemplate):
     def preprocess_input_and_reference_fields(
         self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
-        if self.shuffle_choices:
-            target_index = self.outputs_to_target_index(reference_fields)
-            original_label_choice = reference_fields[self.choices_field][target_index]
-            choices = input_fields[self.choices_field]
-            random_generator = new_random_generator({**input_fields})
+        if (
+            not self.shuffle_choices
+            and not self.sort_choices_by_length
+            and not self.sort_choices_alphabetically
+            and not self.reverse_choices
+            and self.place_correct_choice_position is None
+        ):
+            return input_fields, reference_fields
+
+        choices = input_fields[self.choices_field]
+        target_index = self.outputs_to_target_index(reference_fields)
+        original_label_choice = reference_fields[self.choices_field][target_index]
 
-            # shuffle the choices in place
+        if self.sort_choices_by_length:
+            choices.sort(key=len)
+        if self.sort_choices_alphabetically:
+            choices.sort()
+        if self.reverse_choices:
+            choices.reverse()
+        if self.shuffle_choices:
+            random_generator = new_random_generator(
+                self.shuffle_choices_seed
+                if self.shuffle_choices_seed is not None
+                else {**input_fields}
+            )
             random_generator.shuffle(choices)
-            input_fields[self.choices_field] = choices
+        if self.place_correct_choice_position is not None:
+            if not 0 <= self.place_correct_choice_position < len(choices):
+                raise ValueError(
+                    f"place_correct_choice_position={self.place_correct_choice_position} out of range (0..{len(choices) - 1})."
+                )
+            choices.remove(original_label_choice)
+            choices.insert(self.place_correct_choice_position, original_label_choice)
 
-            reference_fields[self.choices_field] = choices
-            reference_fields[self.target_field] = choices.index(original_label_choice)
+        # Update both input_fields and reference_fields once at the end
+        input_fields[self.choices_field] = choices
+        reference_fields[self.choices_field] = choices
+        reference_fields[self.target_field] = choices.index(original_label_choice)
 
         return input_fields, reference_fields
 
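A hypothetical instantiation showing the new knobs (field names follow the diff; the input_format string is illustrative):

from unitxt.templates import MultipleChoiceTemplate

template = MultipleChoiceTemplate(
    input_format="{question}\nChoices:\n{choices}",
    target_field="answer",
    shuffle_choices=True,
    shuffle_choices_seed=42,  # reproducible order across runs
)

# place_correct_choice_position pins the gold answer at a fixed index instead:
pinned = MultipleChoiceTemplate(
    input_format="{question}\nChoices:\n{choices}",
    target_field="answer",
    place_correct_choice_position=0,
)

# Conflicting combinations are rejected by the new verify(), e.g.:
# MultipleChoiceTemplate(..., shuffle_choices=True, reverse_choices=True)  -> UnitxtError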
types.py
CHANGED

@@ -26,6 +26,13 @@ class Image(TypedDict):
     format: str
 
 
+class Document(TypedDict):
+    title: str
+    body: str
+
+
+MultiDocument = NewType("MultiDocument", List[Document])
+
 Video = NewType("Video", List[Image])
 
 

@@ -46,4 +53,6 @@ register_type(Table)
 register_type(Audio)
 register_type(Image)
 register_type(Video)
+register_type(Document)
+register_type(MultiDocument)
 register_type(RagResponse)
version.py
CHANGED

@@ -1 +1 @@
-version = "1.16.1"
+version = "1.16.2"
|