Upload operators.py with huggingface_hub

operators.py CHANGED (+625 -235)
@@ -1,11 +1,47 @@
+"""This section describes unitxt operators.
+
+Operators: Building Blocks of Unitxt Processing Pipelines
+==============================================================
+
+Within the Unitxt framework, operators serve as the foundational elements used to assemble processing pipelines.
+Each operator is designed to perform specific manipulations on dictionary structures within a stream.
+These operators are callable entities that receive a MultiStream as input.
+The output is a MultiStream, augmented with the operator's manipulations, which are then systematically applied to each instance in the stream when pulled.
+
+Creating Custom Operators
+-------------------------------
+To enhance the functionality of Unitxt, users are encouraged to develop custom operators.
+This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
+The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.
+
+General or Specialized Operators
+--------------------------------
+Some operators are specialized in specific tasks, such as:
+
+- :class:`loaders<unitxt.loaders>` for loading data.
+- :class:`splitters<unitxt.splitters>` for fixing data splits.
+
+Other specialized operators are used by unitxt internally:
+
+- :class:`templates<unitxt.templates>` for verbalizing data examples.
+- :class:`formats<unitxt.formats>` for preparing data for models.
+
+The rest of this section is dedicated to general operators.
+
+General Operators List:
+------------------------
+"""
 import collections
 import importlib
+import operator
+import os
 import uuid
 from abc import abstractmethod
 from collections import Counter
 from copy import deepcopy
 from dataclasses import field
 from itertools import zip_longest
+from random import Random
 from typing import (
     Any,
     Callable,
@@ -32,17 +68,20 @@ from .operator import (
     StreamInstanceOperator,
     StreamSource,
 )
-from .random_utils import …
+from .random_utils import new_random_generator
 from .stream import Stream
 from .text_utils import nested_tuple_to_string
+from .type_utils import isoftype
 from .utils import flatten_dict


 class FromIterables(StreamInitializerOperator):
-    """Creates a MultiStream from iterables.
+    """Creates a MultiStream from a dict of named iterables.

-    Args:
-        iterables (Dict[str, Iterable]): A dictionary where each key-value pair represents a stream name and its corresponding iterable.
+    Example:
+        operator = FromIterables()
+        ms = operator.process(iterables)
     """

     def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
@@ -50,6 +89,19 @@ class FromIterables(StreamInitializerOperator):
 

 class IterableSource(StreamSource):
+    """Creates a MultiStream from a dict of named iterables.
+
+    It is a callable.
+
+    Args:
+        iterables (Dict[str, Iterable]): A dictionary mapping stream names to iterables.
+
+    Example:
+        operator = IterableSource(input_dict)
+        ms = operator()
+
+    """
+
     iterables: Dict[str, Iterable]

     def __call__(self) -> MultiStream:
@@ -57,7 +109,7 @@ class IterableSource(StreamSource):
 

 class MapInstanceValues(StreamInstanceOperator):
-    """A class used to map instance values into …
+    """A class used to map instance values into other values.

     This class is a type of StreamInstanceOperator,
     it maps values of instances in a stream using predefined mappers.
@@ -87,6 +139,11 @@ class MapInstanceValues(StreamInstanceOperator):
     To ensure that all values of field 'a' are mapped in every instance, use strict=True.
     Input instance {"a":"3", "b": 2} will raise an exception per the above call,
     because "3" is not a key in the mapper of "a".
+
+    MapInstanceValues(mappers={"a": {str([1,2,3,4]): 'All', str([]): 'None'}}, strict=True)
+    replaces a list [1,2,3,4] with the string 'All' and an empty list by string 'None'.
+    Note that mapped values are defined by their string representation, so mapped values
+    must be converted to strings.
     """

     mappers: Dict[str, Dict[str, str]]
@@ -115,34 +172,31 @@ class MapInstanceValues(StreamInstanceOperator):
             raise ValueError(
                 f"'process_every_field' == True is allowed only when all fields which have mappers, i.e., {list(self.mappers.keys())} are lists. Instance = {instance}"
             )
-            if isinstance(value, list):
-                …
-                dict_set(instance, key, value, use_dpath=self.use_query)
-            else:  # field is a list, and process_every_value == False
-                if self.strict:  # whole lists can not be mapped by a string-to-something mapper
-                    raise KeyError(
-                        f"A whole list ({value}) in the instance can not be mapped by a field mapper."
-                    )
-            else:  # value is not a list, implying process_every_value == False
-                value = str(value)  # make sure the value is a string
-                if self.strict and (value not in mapper):
-                    raise KeyError(
-                        f"value '{value}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
-                    )
-                if value in mapper:
-                    dict_set(instance, key, mapper[value], use_dpath=self.use_query)
+            if isinstance(value, list) and self.process_every_value:
+                for i, val in enumerate(value):
+                    value[i] = self.get_mapped_value(instance, key, mapper, val)
+            else:
+                value = self.get_mapped_value(instance, key, mapper, value)
+            dict_set(
+                instance,
+                key,
+                value,
+                use_dpath=self.use_query,
+            )

         return instance

+    def get_mapped_value(self, instance, key, mapper, val):
+        val_as_str = str(val)  # make sure the value is a string
+        if self.strict and (val_as_str not in mapper):
+            raise KeyError(
+                f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
+            )
+        # By default deep copy the value in mapper to avoid shared modifications
+        if val_as_str in mapper:
+            return deepcopy(mapper[val_as_str])
+        return val
+

 class FlattenInstances(StreamInstanceOperator):
     """Flattens each instance in a stream, making nested dictionary entries into top-level entries.
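
A minimal usage sketch of the reworked MapInstanceValues (an illustration, not part of the commit; it assumes the operator can be applied to a single instance through its process method):

    # hypothetical illustration: map string labels to encoded values
    from unitxt.operators import MapInstanceValues

    mapper = MapInstanceValues(mappers={"label": {"positive": "1", "negative": "0"}})
    print(mapper.process({"label": "positive"}))  # -> {"label": "1"}
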
@@ -182,6 +236,7 @@ class AddFields(StreamInstanceOperator):
     # Add a 'classes' field on a given list, prevent modification of original list
     # from changing the instance.
     AddFields(fields={"classes": alist}), use_deepcopy=True)
+    # if now alist is modified, still the instances remain intact.
     """

     fields: Dict[str, object]
@@ -204,7 +259,7 @@ class AddFields(StreamInstanceOperator):
 

 class RemoveFields(StreamInstanceOperator):
-    """Remove specified fields …
+    """Remove specified fields from each instance in a stream.

     Args:
         fields (List[str]): The fields to remove from each instance.
@@ -221,19 +276,32 @@ class RemoveFields(StreamInstanceOperator):
 

 class FieldOperator(StreamInstanceOperator):
-    """A general stream that processes the values of a field (or multiple ones.
+    """A general stream instance operator that processes the values of a field (or multiple ones).

     Args:
-        field (Optional[str]): The field to process, if only a single one is passed Defaults to None
-        to_field (Optional[str]): Field name to save, if only one field is …
-        …
+        field (Optional[str]): The field to process, if only a single one is passed. Defaults to None
+        to_field (Optional[str]): Field name to save result into, if only one field is processed, if None is passed the
+            operation would happen in-place and its result would replace the value of "field". Defaults to None
+        field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]): Mapping from names of fields to process,
+            to names of fields to save the results into. Inner List, if used, should be of length 2.
+            A field is processed by feeding its value into method 'process_value' and storing the result in to_field that
+            is mapped to the field.
+            When the type of argument 'field_to_field' is List, the order by which the fields are processed is their order
+            in the (outer) List. But when the type of argument 'field_to_field' is Dict, there is no uniquely determined
+            order. The end result might depend on that order if either (1) two different fields are mapped to the same
+            to_field, or (2) a field shows both as a key and as a value in different mappings.
+            The operator throws an AssertionError in either of these cases.
+            field_to_field defaults to None
         process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
         use_query (bool): Whether to use dpath style queries. Defaults to False.
+
+    Note: if 'field' and 'to_field' (or both members of a pair in 'field_to_field') are equal (or share a common
+    prefix if 'use_query'=True), then the result of the operation is saved within 'field'
     """

     field: Optional[str] = None
     to_field: Optional[str] = None
-    field_to_field: Optional[Union[List[ …
+    field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
     process_every_value: bool = False
     use_query: bool = False
     get_default: Any = None
@@ -250,25 +318,67 @@ class FieldOperator(StreamInstanceOperator):
         ), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
         assert (
             self.field is None or self.field_to_field is None
-        ), f"Can not apply operator both on {self.field} and on the …
+        ), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
+        assert self._field_to_field, f"the from and to fields must be defined or implied from the other inputs got: {self._field_to_field}"
         assert (
-            self._field_to_field
-        ), f"…
+            len(self._field_to_field) > 0
+        ), f"input argument 'field_to_field' should convey at least one field to process. Got {self.field_to_field}"
+        # self._field_to_field is built explicitly by pairs, or copied from argument 'field_to_field'
+        if self.field_to_field is None:
+            return
+        # for backward compatibility also allow list of tuples of two strings
+        if isoftype(self.field_to_field, List[List[str]]) or isoftype(
+            self.field_to_field, List[Tuple[str, str]]
+        ):
+            for pair in self._field_to_field:
+                assert (
+                    len(pair) == 2
+                ), f"when 'field_to_field' is defined as a list of lists, the inner lists should all be of length 2. {self.field_to_field}"
+            # order of field processing is uniquely determined by the input field_to_field when a list
+            return
+        if isoftype(self.field_to_field, Dict[str, str]):
+            if len(self.field_to_field) < 2:
+                return
+            for ff, tt in self.field_to_field.items():
+                for f, t in self.field_to_field.items():
+                    if f == ff:
+                        continue
+                    assert (
+                        t != ff
+                    ), f"In input argument 'field_to_field': {self.field_to_field}, field {f} is mapped to field {t}, while the latter is mapped to {tt}. Whether {f} or {t} is processed first might impact end result."
+                    assert (
+                        tt != t
+                    ), f"In input argument 'field_to_field': {self.field_to_field}, two different fields: {ff} and {f} are mapped to field {tt}. Whether {ff} or {f} is processed last might impact end result."
+            return
+        raise ValueError(
+            f"Input argument 'field_to_field': {self.field_to_field} is neither of type List[List[str]] nor of type Dict[str, str]."
+        )

     @abstractmethod
     def process_value(self, value: Any) -> Any:
         pass

     def prepare(self):
-        …
-        …
+        super().prepare()
+
+        # prepare is invoked before verify, hence must make some checks here, before the changes done here
+        assert (
+            (self.field is None) != (self.field_to_field is None)
+        ), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
+        assert (
+            self.to_field is None or self.field_to_field is None
+        ), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"
+
         if self.field_to_field is None:
-            self._field_to_field = [ …
+            self._field_to_field = [
+                (self.field, self.to_field if self.to_field is not None else self.field)
+            ]
         else:
-            …
+            self._field_to_field = (
+                list(self.field_to_field.items())
+                if isinstance(self.field_to_field, dict)
+                else self.field_to_field
+            )

     def process(
         self, instance: Dict[str, Any], stream_name: Optional[str] = None
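
A sketch of the FieldOperator contract that the verify/prepare logic above enforces (an illustration, not part of the commit): a subclass implements only process_value, and the base class resolves 'field', 'to_field' and 'field_to_field'.

    # hypothetical illustration: an uppercasing field operator
    from typing import Any
    from unitxt.operators import FieldOperator

    class UpperCase(FieldOperator):
        def process_value(self, value: Any) -> Any:
            return str(value).upper()

    op = UpperCase(field="text", to_field="text_upper")
    print(op.process({"text": "hello"}))  # -> {"text": "hello", "text_upper": "HELLO"}
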
@@ -295,7 +405,7 @@ class FieldOperator(StreamInstanceOperator):
                 raise ValueError(
                     f"Failed to process '{from_field}' from {instance} due to : {e}"
                 ) from e
-            if …
+            if is_subpath(from_field, to_field) or is_subpath(to_field, from_field):
                 dict_delete(instance, from_field)
             dict_set(
                 instance,
@@ -308,7 +418,25 @@ class FieldOperator(StreamInstanceOperator):
 

 class RenameFields(FieldOperator):
-    """Renames fields.
+    """Renames fields.
+
+    Move value from one field to another, potentially, if 'use_query'=True, from one branch into another.
+    Remove the from field, potentially part of it in case of use_query.
+
+    Examples:
+        RenameFields(field_to_field={"b": "c"})
+        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]
+
+        RenameFields(field_to_field={"b": "c/d"}, use_query=True)
+        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]
+
+        RenameFields(field_to_field={"b": "b/d"}, use_query=True)
+        will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]
+
+        RenameFields(field_to_field={"b/c/e": "b/d"}, use_query=True)
+        will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]
+
+    """

     def process_value(self, value: Any) -> Any:
         return value
@@ -317,20 +445,31 @@ class RenameFields(FieldOperator):
         self, instance: Dict[str, Any], stream_name: Optional[str] = None
     ) -> Dict[str, Any]:
         res = super().process(instance=instance, stream_name=stream_name)
-        …
+        for from_field, to_field in self._field_to_field:
+            if (not is_subpath(from_field, to_field)) and (
+                not is_subpath(to_field, from_field)
+            ):
+                dict_delete(res, from_field)
+                if self.use_query:
+                    from_field_components = list(
+                        os.path.normpath(from_field).split(os.path.sep)
+                    )
+                    while len(from_field_components) > 1:
+                        from_field_components.pop()
+                        parent = dict_get(res, os.path.sep.join(from_field_components))
+                        if isinstance(parent, dict) and not parent:
+                            dict_delete(res, os.path.sep.join(from_field_components))
+                        else:
+                            break
+
         return res


 class AddConstant(FieldOperator):
-    """Adds a …
+    """Adds a constant, being argument 'add', to the processed value.

     Args:
-        add: …
+        add: the constant to add.
     """

     add: Any
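
A usage sketch of RenameFields with a dpath-style target, mirroring the docstring examples above (an illustration, not part of the commit):

    # hypothetical illustration: move a top-level field into a nested branch
    from unitxt.operators import RenameFields

    op = RenameFields(field_to_field={"b": "c/d"}, use_query=True)
    print(op.process({"a": 1, "b": 2}))  # -> {"a": 1, "c": {"d": 2}}
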
@@ -396,19 +535,15 @@ class Augmentor(StreamInstanceOperator):
                 default="",
                 not_exist_ok=False,
             )
-        except …
+        except ValueError as e:
             raise TypeError(f"Failed to get {field_name} from {instance}") from e

-        …
-        except Exception as e:
-            raise RuntimeError(
-                f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
-            ) from e
+        try:
+            new_value = self.process_value(old_value)
+        except Exception as e:
+            raise RuntimeError(
+                f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
+            ) from e
         dict_set(instance, field_name, new_value, use_dpath=True, not_exist_ok=True)
         return instance
@@ -433,90 +568,146 @@ class AugmentWhitespace(Augmentor):
         words = re.split(r"(\s+)", value)
         new_value = ""

+        random_generator = new_random_generator(sub_seed=value)
         for word in words:
             if word.isspace():
-                new_value += …
+                new_value += random_generator.choice(
                     ["\n", "\t", " "]
-                ) * …
+                ) * random_generator.randint(1, 3)
             else:
                 new_value += word
         return new_value


-class …
-    r"""Augments the input by appending to it a randomly selected (typically, whitespace) …
+class AugmentPrefixSuffix(Augmentor):
+    r"""Augments the input by prepending and appending to it a randomly selected (typically, whitespace) patterns.

     Args:
-        suffixes : the potential (typically, whitespace) patterns to select from.
+        prefixes, suffixes (list or dict) : the potential (typically, whitespace) patterns to select from.
             The dictionary version allows to specify relative weights of the different patterns.
-        …
+        prefix_len, suffix_len (positive int) : The added prefix or suffix will be of length
+            prefix_len or suffix_len, respectively, repetitions of the randomly selected patterns.
+        remove_existing_whitespaces : allows to first clean any existing leading and trailing whitespaces.
+            The strings made of repetitions of the selected pattern(s) are then prepended and/or appended to the potentially
+            trimmed input.
+        If only one of prefixes/suffixes is needed, set the other to None.

     Examples:
-        …
+        To prepend the input with a prefix made of 4 '\n'-s or '\t'-s, employ
+        AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)
+        To append the input with a suffix made of 3 '\n'-s or '\t'-s, with triple '\n' suffixes
+        being preferred over triple '\t', at 2:1 ratio, employ
+        AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)
+        which will append '\n'-s twice as often as '\t'-s.

     """

-    …
+    prefixes: Optional[Union[List[str], Dict[str, int]]] = {
+        " ": 20,
+        "\\t": 10,
+        "\\n": 40,
+        "": 30,
+    }
+    prefix_len: Optional[int] = 3
+    suffixes: Optional[Union[List[str], Dict[str, int]]] = {
+        " ": 20,
+        "\\t": 10,
+        "\\n": 40,
+        "": 30,
+    }
+    suffix_len: Optional[int] = 3
+    remove_existing_whitespaces: Optional[bool] = False

     def verify(self):
         assert (
-            …
-        ), …
-        …
-        ), f"suffixes should be a list of strings, whereas member {k!s} is of type {type(k)}"
+            self.prefixes or self.suffixes
+        ), "At least one of prefixes/suffixes should be not None."
+        for arg, arg_name in zip(
+            [self.prefixes, self.suffixes], ["prefixes", "suffixes"]
+        ):
+            assert (
+                arg is None or isoftype(arg, List[str]) or isoftype(arg, Dict[str, int])
+            ), f"Argument {arg_name} should be either None or a list of strings or a dictionary str->int. {arg} is none of the above."
+        assert (
+            self.prefix_len > 0
+        ), f"prefix_len must be positive, got {self.prefix_len}"
+        assert (
+            self.suffix_len > 0
+        ), f"suffix_len must be positive, got {self.suffix_len}"
+        super().verify()

-        …
+    def _calculate_distributions(self, prefs_or_suffs):
+        if prefs_or_suffs is None:
+            return None, None
+        patterns = (
+            prefs_or_suffs
+            if isinstance(prefs_or_suffs, list)
+            else [k for k, v in prefs_or_suffs.items()]
         )
         total_weight = (
-            len( …
-            if isinstance( …
-            else sum([v for k, v in …
+            len(patterns)
+            if isinstance(prefs_or_suffs, list)
+            else sum([v for k, v in prefs_or_suffs.items()])
         )
-        …
-            [1.0 / total_weight] * len( …
-            if isinstance( …
-            else [float( …
+        weights = (
+            [1.0 / total_weight] * len(patterns)
+            if isinstance(prefs_or_suffs, list)
+            else [float(prefs_or_suffs[p]) / total_weight for p in patterns]
         )
-        …
+        return patterns, weights
+
+    def prepare(self):
+        # Being an artifact, prepare is invoked before verify. Here we need verify before the actions
+        self.verify()
+        self._prefix_pattern_distribution = {"length": self.prefix_len}
+        self._suffix_pattern_distribution = {"length": self.suffix_len}
+
+        (
+            self._prefix_pattern_distribution["patterns"],
+            self._prefix_pattern_distribution["weights"],
+        ) = self._calculate_distributions(self.prefixes)
+        (
+            self._suffix_pattern_distribution["patterns"],
+            self._suffix_pattern_distribution["weights"],
+        ) = self._calculate_distributions(self.suffixes)
+        super().prepare()
+
+    def _get_random_pattern(
+        self, pattern_distribution, random_generator: Random
+    ) -> str:
+        string_to_add = ""
+        if pattern_distribution["patterns"]:
+            string_to_add = "".join(
+                random_generator.choices(
+                    pattern_distribution["patterns"],
+                    pattern_distribution["weights"],
+                    k=pattern_distribution["length"],
+                )
+            )
+        return string_to_add

     def process_value(self, value: Any) -> Any:
         assert value is not None, "input value should not be None"
         new_value = str(value)
-        if self.…
-            new_value = new_value.…
-        …
+        if self.remove_existing_whitespaces:
+            new_value = new_value.strip()
+        random_generator = new_random_generator(sub_seed=value)
+        prefix = self._get_random_pattern(
+            self._prefix_pattern_distribution, random_generator
+        )
+        suffix = self._get_random_pattern(
+            self._suffix_pattern_distribution, random_generator
+        )
+        return prefix + new_value + suffix


 class ShuffleFieldValues(FieldOperator):
-    """Shuffles …
+    """Shuffles a list of values found in a field."""

     def process_value(self, value: Any) -> Any:
         res = list(value)
-        …
+        random_generator = new_random_generator(sub_seed=res)
+        random_generator.shuffle(res)
         return res
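
A usage sketch of the new AugmentPrefixSuffix, taken from its docstring (an illustration, not part of the commit; process_value is the method the Augmentor machinery above invokes per field value):

    # hypothetical illustration: append a 3-pattern suffix, '\n' twice as likely as '\t'
    from unitxt.operators import AugmentPrefixSuffix

    op = AugmentPrefixSuffix(
        augment_model_input=True, suffixes={"\n": 2, "\t": 1}, suffix_len=3, prefixes=None
    )
    print(repr(op.process_value("some text")))
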
@@ -621,9 +812,18 @@ class ListFieldValues(StreamInstanceOperator):
 

 class ZipFieldValues(StreamInstanceOperator):
-    """Zips values of multiple fields similar to list(zip(*fields)).
+    """Zips values of multiple fields in a given instance, similar to list(zip(*fields)).
+
+    The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
+    are zipped, and stored into 'to_field'.

-    …
+    If 'longest'=False, the length of the zipped result is determined by the shortest input value.
+    If 'longest'=True, the length of the zipped result is determined by the longest input, padding shorter
+    inputs with None -s.
+
+    """
+
+    fields: List[str]
     to_field: str
     longest: bool = False
     use_query: bool = False
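
A usage sketch of ZipFieldValues (an illustration, not part of the commit; the exact element type of the zipped pairs, tuple or list, follows from zip/zip_longest):

    # hypothetical illustration: zip two list-valued fields into pairs
    from unitxt.operators import ZipFieldValues

    op = ZipFieldValues(fields=["a", "b"], to_field="ab")
    print(op.process({"a": [1, 2], "b": ["x", "y", "z"]}))
    # -> {"a": [1, 2], "b": ["x", "y", "z"], "ab": [(1, "x"), (2, "y")]}
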
@@ -643,7 +843,7 @@ class ZipFieldValues(StreamInstanceOperator):
 

 class IndexOf(StreamInstanceOperator):
-    """…
+    """For a given instance, finds the offset of value of field 'index_of', within the value of field 'search_in'."""

     search_in: str
     index_of: str
@@ -660,7 +860,7 @@ class IndexOf(StreamInstanceOperator):
 

 class TakeByField(StreamInstanceOperator):
-    """…
+    """From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""

     field: str
     index: str
@@ -681,11 +881,24 @@ class TakeByField(StreamInstanceOperator):
 

 class CopyFields(FieldOperator):
-    """Copies …
+    """Copies values from specified fields to specified fields.

-    Args:
+    Args (of parent class):
         field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
-        …
+        use_query (bool): Whether to use dpath for accessing fields. Defaults to False.
+
+    Examples:
+        An input instance {"a": 2, "b": 3}, when processed by
+        CopyFields(field_to_field={"a": "b"})
+        would yield {"a": 2, "b": 2}, and when processed by
+        CopyFields(field_to_field={"a": "c"}) would yield
+        {"a": 2, "b": 3, "c": 2}
+
+        with use_query=True, we can also copy inside the field:
+        CopyFields(field_to_field={"a/0": "a"}, use_query=True)
+        would process instance {"a": [1, 3]} into {"a": 1}
+
     """

     def process_value(self, value: Any) -> Any:
@@ -693,6 +906,8 @@ class CopyFields(FieldOperator):
 

 class AddID(StreamInstanceOperator):
+    """Stores a unique id value in the designated 'id_field_name' field of the given instance."""
+
     id_field_name: str = "id"

     def process(
@@ -706,22 +921,31 @@ class CastFields(StreamInstanceOperator):
     """Casts specified fields to specified types.

     Args:
-        …
-        defaults (Dict[str, object]): A dictionary mapping …
+        use_nested_query (bool): Whether to cast nested fields, expressed in dpath. Defaults to False.
+        fields (Dict[str, str]): A dictionary mapping field names to the names of the types to cast the fields to.
+            e.g: "int", "str", "float", "bool". Basic names of types
+        defaults (Dict[str, object]): A dictionary mapping field names to default values for cases of casting failure.
+        process_every_value (bool): If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.
+
+    Examples:
+        CastFields(
+            fields={"a/d": "float", "b": "int"},
+            failure_defaults={"a/d": 0.0, "b": 0},
+            process_every_value=True,
+            use_nested_query=True
+        )
+        would process the input instance: {"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}
+        into {"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}
+
     """

-    types = {
-        "int": int,
-        "float": float,
-        "str": str,
-        "bool": bool,
-    }
     fields: Dict[str, str] = field(default_factory=dict)
     failure_defaults: Dict[str, object] = field(default_factory=dict)
     use_nested_query: bool = False
-    …
+    process_every_value: bool = False
+
+    def prepare(self):
+        self.types = {"int": int, "float": float, "str": str, "bool": bool}

     def _cast_single(self, value, type, field):
         try:
@@ -734,14 +958,17 @@ class CastFields(StreamInstanceOperator):
                 return self.failure_defaults[field]

     def _cast_multiple(self, values, type, field):
-        …
+        return [self._cast_single(value, type, field) for value in values]

     def process(
         self, instance: Dict[str, Any], stream_name: Optional[str] = None
     ) -> Dict[str, Any]:
         for field_name, type in self.fields.items():
             value = dict_get(instance, field_name, use_dpath=self.use_nested_query)
-            if self.…
+            if self.process_every_value:
+                assert isinstance(
+                    value, list
+                ), f"'process_every_value' can be set to True only for fields that contain lists, whereas in instance {instance}, the contents of field '{field_name}' is of type '{type(value)}'"
                 casted_value = self._cast_multiple(value, type, field_name)
             else:
                 casted_value = self._cast_single(value, type, field_name)
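
A usage sketch of CastFields mirroring the docstring example above (an illustration, not part of the commit):

    # hypothetical illustration: cast nested list members, with per-field fallbacks
    from unitxt.operators import CastFields

    op = CastFields(
        fields={"a/d": "float", "b": "int"},
        failure_defaults={"a/d": 0.0, "b": 0},
        process_every_value=True,
        use_nested_query=True,
    )
    print(op.process({"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}))
    # -> {"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}
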
@@ -751,29 +978,46 @@ class CastFields(StreamInstanceOperator):
         return instance


-def …
-    …
-        for key, value in instance.items():
-            instance[key] = recursive_divide(value, divisor, strict=strict)
-    elif isinstance(instance, list):
-        for i, value in enumerate(instance):
-            instance[i] = recursive_divide(value, divisor, strict=strict)
-    elif isinstance(instance, float):
-        instance /= divisor
-    elif strict:
-        raise ValueError(f"Cannot divide instance of type {type(instance)}")
-    return instance


-class DivideAllFieldsBy(StreamInstanceOperator):
+class DivideAllFieldsBy(StreamInstanceOperator):
+    """Recursively reach down to all fields that are float, and divide each by 'divisor'.
+
+    The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and
+    the leaves are either 'float' and then divided, or other basic type, in which case, a ValueError is raised
+    if input flag 'strict' is True, or -- left alone, if 'strict' is False.
+
+    Args:
+        divisor (float) the value to divide by
+        strict (bool) whether to raise an error upon visiting a leaf that is not float. Defaults to False.
+
+    Example:
+        when instance {"a": 10.0, "b": [2.0, 4.0, 7.0], "c": 5} is processed by operator:
+        operator = DivideAllFieldsBy(divisor=2.0)
+        the output is: {"a": 5.0, "b": [1.0, 2.0, 3.5], "c": 5}
+        If the operator were defined with strict=True, through:
+        operator = DivideAllFieldsBy(divisor=2.0, strict=True),
+        the processing of the above instance would raise a ValueError, for the integer at "c".
+    """
+
     divisor: float = 1.0
     strict: bool = False
-    …
+
+    def _recursive_divide(self, instance, divisor):
+        if isinstance(instance, dict):
+            for key, value in instance.items():
+                instance[key] = self._recursive_divide(value, divisor)
+        elif isinstance(instance, list):
+            for i, value in enumerate(instance):
+                instance[i] = self._recursive_divide(value, divisor)
+        elif isinstance(instance, float):
+            instance /= divisor
+        elif self.strict:
+            raise ValueError(f"Cannot divide instance of type {type(instance)}")
+        return instance

     def process(
         self, instance: Dict[str, Any], stream_name: Optional[str] = None
     ) -> Dict[str, Any]:
-        return …
+        return self._recursive_divide(instance, self.divisor)


 class ArtifactFetcherMixin:
@@ -797,13 +1041,21 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
     """Applies value operators to each instance in a stream based on specified fields.

     Args:
-        …
+        inputs_fields (List[str]): list of field names, the values in which are to be processed
+        fields_to_treat_as_list (List[str]): sublist of input_fields, each member of this sublist is supposed to contain
+            a list of values, each of which is to be processed.
+        operators_field (str): name of the field that contains the list of names of the operators to be applied,
+            one after the other, for the processing.
         default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.
-    """

-    …
+    Example:
+        when instance {"a": 111, "b": 2, "c": ["processors.to_string", "processors.first_character"]} is processed by operator:
+        operator = ApplyOperatorsField(inputs_fields=["a"], operators_field="c", default_operators=["add"]),
+        the resulting instance is: {"a": "1", "b": 2, "c": ["processors.to_string", "processors.first_character"]}
+
+    """
+
+    inputs_fields: List[str]
     operators_field: str
     default_operators: List[str] = None
     fields_to_treat_as_list: List[str] = NonPositionalField(default_factory=list)
@@ -815,7 +1067,7 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
         if operator_names is None:
             assert (
                 self.default_operators is not None
-            ), f"No operators found in {self.…
+            ), f"No operators found in field '{self.operators_field}', and no default operators provided."
             operator_names = self.default_operators

         if isinstance(operator_names, str):
@@ -828,35 +1080,155 @@ class ApplyOperatorsField(StreamInstanceOperator, ArtifactFetcherMixin):
             if field_name in self.fields_to_treat_as_list:
                 instance[field_name] = [operator.process(v) for v in value]
             else:
-                instance[field_name] = operator.process( …
+                instance[field_name] = operator.process(value)

         return instance


-class …
-    """Filters a stream, yielding only instances …
+class FilterByCondition(SingleStreamOperator):
+    """Filters a stream, yielding only instances for which the required values follow the required condition operator.
+
+    Raises an error if a required key is missing.

     Args:
-        …
+        values (Dict[str, Any]): Values that instances must match using the condition to be included in the output.
+        condition: the name of the desired condition operator between the key and the value in values ("gt", "ge", "lt", "le", "ne", "eq")
+        error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.
+
+    Examples:
+        FilterByCondition(values = {"a":4}, condition = "gt") will yield only instances where "a">4
+        FilterByCondition(values = {"a":4}, condition = "le") will yield only instances where "a"<=4
+        FilterByCondition(values = {"a":[4,8]}, condition = "in") will yield only instances where "a" is 4 or 8
+        FilterByCondition(values = {"a":[4,8]}, condition = "not in") will yield only instances where "a" is different from 4 or 8
+
     """

-    …
+    values: Dict[str, Any]
+    condition: str
+    condition_to_func = {
+        "gt": operator.gt,
+        "ge": operator.ge,
+        "lt": operator.lt,
+        "le": operator.le,
+        "eq": operator.eq,
+        "ne": operator.ne,
+        "in": None,  # Handled as special case
+        "not in": None,  # Handled as special case
+    }
+    error_on_filtered_all: bool = True

     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
+        yielded = False
         for instance in stream:
-            …
+            if self._is_required(instance):
+                yielded = True
+                yield instance
+
+        if not yielded and self.error_on_filtered_all:
+            raise RuntimeError(
+                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended set error_on_filtered_all=False"
+            )
+
+    def verify(self):
+        if self.condition not in self.condition_to_func:
+            raise ValueError(
+                f"Unsupported condition operator '{self.condition}', supported {list(self.condition_to_func.keys())}"
+            )
+
+        for key, value in self.values.items():
+            if self.condition in ["in", "not in"] and not isinstance(value, list):
+                raise ValueError(
+                    f"The filter for key ('{key}') in FilterByCondition with condition '{self.condition}' must be list but is not : '{value}'"
+                )
+        return super().verify()
+
+    def _is_required(self, instance: dict) -> bool:
+        for key, value in self.values.items():
+            if key not in instance:
                 raise ValueError(
-                    f"…
+                    f"Required filter field ('{key}') in FilterByCondition is not found in {instance}"
                 )
-            if instance[key] …
-            …
-            yield instance
+            if self.condition == "in":
+                if instance[key] not in value:
+                    return False
+            elif self.condition == "not in":
+                if instance[key] in value:
+                    return False
+            …


-class …
+class ExtractFieldValues(MultiStreamOperator):
     field: str
     stream_name: str
     overall_top_frequency_percent: Optional[int] = 100
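
A usage sketch of the new FilterByCondition (an illustration, not part of the commit; it assumes MultiStream.from_iterables from unitxt.stream, and that a stream operator is callable on a MultiStream as described in the module docstring):

    # hypothetical illustration: keep only instances where "a" > 4
    from unitxt.operators import FilterByCondition
    from unitxt.stream import MultiStream

    ms = MultiStream.from_iterables({"train": [{"a": 3}, {"a": 5}]})
    filtered = FilterByCondition(values={"a": 4}, condition="gt")(ms)
    print(list(filtered["train"]))  # -> [{"a": 5}]
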
@@ -877,21 +1249,21 @@ class ExtractFieldValues(MultiStreamOperator):

     Examples:

-    …
+    …
     field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
     every instance in all streams.

-    …
+    …
     in case that field 'labels' contains a list of values (and not a single value) - track the occurrences of all the possible
     value members in these lists, and report the most frequent values.
     if process_every_value=False, track the most frequent whole lists, and report those (as a list of lists) in field
     'to_field' of each instance of all streams.

-    …
+    …
     extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
     and stores them in field 'classes' of each instance of all streams.

-    …
+    …
     extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
     Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
     """
@@ -952,41 +1324,18 @@ class ExtractFieldValues(MultiStreamOperator):
             [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
             for ele in values_and_counts
         ]
-        for name in multi_stream:
-            for instance in multi_stream[name]:
-                instance[self.to_field] = values_to_keep
-        return multi_stream
+        …


-class FilterByListsOfValues(SingleStreamOperator):
-    """Filters a stream, yielding only instances whose field values are included in the specified value lists.
-
-    Args:
-        required_values (Dict[str, List]): For each field, the list of values that instances should match to be included in the output.
-    """
-
-    required_values: Dict[str, List]

     def verify(self):
         super().verify()
-        for key, value in self.required_values.items():
-            if not isinstance(value, list):
-                raise ValueError(
-                    f"The filter for key ('{key}') in FilterByListsOfValues is not a list but '{value}'"
-                )

-    def …
-        …
-        for key, value in self.required_values.items():
-            if key not in instance:
-                raise ValueError(
-                    f"Required filter field ('{key}') in FilterByListsOfValues is not found in {instance}"
-                )
-            if instance[key] not in value:
-                filter = True
-        if not filter:
-            yield instance


 class Intersect(FieldOperator):
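
The deleted FilterByListsOfValues is subsumed by the FilterByCondition operator introduced earlier in this diff; a minimal equivalence sketch (an illustration, not part of the commit):

    # hypothetical illustration: the old
    #     FilterByListsOfValues(required_values={"label": ["pos", "neg"]})
    # can be expressed as
    FilterByCondition(values={"label": ["pos", "neg"]}, condition="in")
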
@@ -1011,6 +1360,7 @@ class Intersect(FieldOperator):
         )

     def process_value(self, value: Any) -> Any:
+        …
         if not isinstance(value, list):
             raise ValueError(f"The value in field is not a list but '{value}'")
         return [e for e in value if e in self.allowed_values]
@@ -1020,7 +1370,7 @@ class RemoveValues(FieldOperator):
     """Removes elements in a field, which must be a list, using a given list of unallowed.

     Args:
-        unallowed_values (list) - …
+        …
     """

     unallowed_values: List[Any]
@@ -1089,8 +1439,8 @@ class SplitByValue(MultiStreamOperator):
             stream_unique_values = uniques[stream_name]
             for unique_values in stream_unique_values:
                 filtering_values = dict(zip(self.fields, unique_values))
-                filtered_streams = …
-                    …
+                filtered_streams = …
                 )._process_single_stream(stream)
                 filtered_stream_name = (
                     stream_name + "_" + nested_tuple_to_string(unique_values)
@@ -1112,7 +1462,7 @@ class ApplyStreamOperatorsField(SingleStreamOperator, ArtifactFetcherMixin):
     reversed: bool = False

     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
-        first_instance = stream.…
+        first_instance = stream.…

         operators = first_instance.get(self.field, [])
         if isinstance(operators, str):
@@ -1146,7 +1496,7 @@ class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         from .metrics import Metric, MetricPipeline, MetricWithConfidenceInterval

-        first_instance = stream.…
+        first_instance = stream.…

         metric_names = first_instance.get(self.metric_field, [])
         if not metric_names:
@@ -1182,27 +1532,6 @@ class ApplyMetric(SingleStreamOperator, ArtifactFetcherMixin):
         yield from stream


-class AddFieldNamePrefix(StreamInstanceOperator):
-    """Adds a prefix to each field name in each instance of a stream.
-
-    Args:
-        prefix_dict (Dict[str, str]): A dictionary mapping stream names to prefixes.
-    """
-
-    prefix_dict: Dict[str, str]
-
-    def prepare(self):
-        return super().prepare()
-
-    def process(
-        self, instance: Dict[str, Any], stream_name: Optional[str] = None
-    ) -> Dict[str, Any]:
-        return {
-            self.prefix_dict[stream_name] + key: value
-            for key, value in instance.items()
-        }
-
-
 class MergeStreams(MultiStreamOperator):
     """Merges multiple streams into a single stream.

@@ -1238,20 +1567,39 @@ class MergeStreams(MultiStreamOperator):
 class Shuffle(PagedStreamOperator):
     """Shuffles the order of instances in each page of a stream.

-    Args:
+    …
+
+    Args:
         page_size (int): The size of each page in the stream. Defaults to 1000.
     """

     def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
-        …
+        …
         yield from page


 class EncodeLabels(StreamInstanceOperator):
-    """Encode …
+    """Encode …

     Args:
         fields (List[str]): The fields to encode together.
+        …
     """

     fields: List[str]
@@ -1279,7 +1627,23 @@ class EncodeLabels(StreamInstanceOperator):
 

 class StreamRefiner(SingleStreamOperator):
+    …
     max_instances: int = None

     def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
         if self.max_instances is not None:
@@ -1291,13 +1655,23 @@ class StreamRefiner(SingleStreamOperator):
 class DeterministicBalancer(StreamRefiner):
     """A class used to balance streams deterministically.

     Attributes:
-        fields (List[str]): A list of field names to be used in …
-        …
+        …

     Usage:
-        balancer = DeterministicBalancer(fields=["field1", "field2"], …
+        …
         balanced_stream = balancer.process(stream)
     """

     fields: List[str]
@@ -1334,7 +1708,23 @@ class DeterministicBalancer(StreamRefiner):
 

 class LengthBalancer(DeterministicBalancer):
+    …
     segments_boundaries: List[int]

     def signature(self, instance):
         total_len = 0
| 1 |
+
"""This section describes unitxt operators.
|
| 2 |
+
|
| 3 |
+
Operators: Building Blocks of Unitxt Processing Pipelines
|
| 4 |
+
==============================================================
|
| 5 |
+
|
| 6 |
+
Within the Unitxt framework, operators serve as the foundational elements used to assemble processing pipelines.
|
| 7 |
+
Each operator is designed to perform specific manipulations on dictionary structures within a stream.
|
| 8 |
+
These operators are callable entities that receive a MultiStream as input.
|
| 9 |
+
The output is a MultiStream, augmented with the operator's manipulations, which are then systematically applied to each instance in the stream when pulled.
|
| 10 |
+
|
| 11 |
+
Creating Custom Operators
|
| 12 |
+
-------------------------------
|
| 13 |
+
To enhance the functionality of Unitxt, users are encouraged to develop custom operators.
|
| 14 |
+
This can be achieved by inheriting from any of the existing operators listed below or from one of the fundamental :class:`base operators<unitxt.operator>`.
|
| 15 |
+
The primary task in any operator development is to implement the `process` function, which defines the unique manipulations the operator will perform.
|
| 16 |
+
|
| 17 |
+
General or Specelized Operators
|
| 18 |
+
--------------------------------
|
| 19 |
+
Some operators are specielized in specific task such as:
|
| 20 |
+
|
| 21 |
+
- :class:`loaders<unitxt.loaders>` for loading data.
|
| 22 |
+
- :class:`splitters<unitxt.splitters>` for fixing data splits.
|
| 23 |
+
|
| 24 |
+
Other specelized operators are used by unitxt internally:
|
| 25 |
+
|
| 26 |
+
- :class:`templates<unitxt.templates>` for verbalizing data examples.
|
| 27 |
+
- :class:`formats<unitxt.formats>` for preparing data for models.
|
| 28 |
+
|
| 29 |
+
The rest of this section is dedicated for general operators.
|
| 30 |
+
|
| 31 |
+
General Operaotrs List:
|
| 32 |
+
------------------------
|
| 33 |
+
"""
|
| 34 |
import collections
|
| 35 |
import importlib
|
| 36 |
+
import operator
|
| 37 |
+
import os
|
| 38 |
import uuid
|
| 39 |
from abc import abstractmethod
|
| 40 |
from collections import Counter
|
| 41 |
from copy import deepcopy
|
| 42 |
from dataclasses import field
|
| 43 |
from itertools import zip_longest
|
| 44 |
+
from random import Random
|
| 45 |
from typing import (
|
| 46 |
Any,
|
| 47 |
Callable,
|
|
|
|
| 68 |
StreamInstanceOperator,
|
| 69 |
StreamSource,
|
| 70 |
)
|
| 71 |
+
from .random_utils import new_random_generator
|
| 72 |
from .stream import Stream
|
| 73 |
from .text_utils import nested_tuple_to_string
|
| 74 |
+
from .type_utils import isoftype
|
| 75 |
from .utils import flatten_dict
|
| 76 |
|
| 77 |
|
| 78 |
class FromIterables(StreamInitializerOperator):
|
| 79 |
+
"""Creates a MultiStream from a dict of named iterables.
|
| 80 |
+
|
| 81 |
+
Example:
|
| 82 |
+
operator = FromIterables()
|
| 83 |
+
ms = operator.process(iterables)
|
| 84 |
|
|
|
|
|
|
|
| 85 |
"""
|
| 86 |
|
| 87 |
def process(self, iterables: Dict[str, Iterable]) -> MultiStream:
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
class IterableSource(StreamSource):
|
| 92 |
+
"""Creates a MultiStream from a dict of named iterables.
|
| 93 |
+
|
| 94 |
+
It is a callable.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
iterables (Dict[str, Iterable]): A dictionary mapping stream names to iterables.
|
| 98 |
+
|
| 99 |
+
Example:
|
| 100 |
+
operator = IterableSource(input_dict)
|
| 101 |
+
ms = operator()
|
| 102 |
+
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
iterables: Dict[str, Iterable]
|
| 106 |
|
| 107 |
def __call__(self) -> MultiStream:
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
class MapInstanceValues(StreamInstanceOperator):
|
| 112 |
+
"""A class used to map instance values into other values.
|
| 113 |
|
| 114 |
This class is a type of StreamInstanceOperator,
|
| 115 |
it maps values of instances in a stream using predefined mappers.
|
|
|
|
| 139 |
To ensure that all values of field 'a' are mapped in every instance, use strict=True.
|
| 140 |
Input instance {"a":"3", "b": 2} will raise an exception per the above call,
|
| 141 |
because "3" is not a key in the mapper of "a".
|
| 142 |
+
|
| 143 |
+
MapInstanceValues(mappers={"a": {str([1,2,3,4]): 'All', str([]): 'None'}}, strict=True)
|
| 144 |
+
replaces a list [1,2,3,4] with the string 'All' and an empty list by string 'None'.
|
| 145 |
+
Note that mapped values are defined by their string representation, so mapped values
|
| 146 |
+
must be converted to strings.
|
| 147 |
"""
|
| 148 |
|
| 149 |
mappers: Dict[str, Dict[str, str]]
|
|
|
|
| 172 |
raise ValueError(
|
| 173 |
f"'process_every_field' == True is allowed only when all fields which have mappers, i.e., {list(self.mappers.keys())} are lists. Instace = {instance}"
|
| 174 |
)
|
| 175 |
+
if isinstance(value, list) and self.process_every_value:
|
| 176 |
+
for i, val in enumerate(value):
|
| 177 |
+
value[i] = self.get_mapped_value(instance, key, mapper, val)
|
| 178 |
+
else:
|
| 179 |
+
value = self.get_mapped_value(instance, key, mapper, value)
|
| 180 |
+
dict_set(
|
| 181 |
+
instance,
|
| 182 |
+
key,
|
| 183 |
+
value,
|
| 184 |
+
use_dpath=self.use_query,
|
| 185 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
return instance
|
| 188 |
|
| 189 |
+
def get_mapped_value(self, instance, key, mapper, val):
|
| 190 |
+
val_as_str = str(val) # make sure the value is a string
|
| 191 |
+
if self.strict and (val_as_str not in mapper):
|
| 192 |
+
raise KeyError(
|
| 193 |
+
f"value '{val}' in instance '{instance}' is not found in mapper '{mapper}', associated with field '{key}'."
|
| 194 |
+
)
|
| 195 |
+
# By default deep copy the value in mapper to avoid shared modifications
|
| 196 |
+
if val_as_str in mapper:
|
| 197 |
+
return deepcopy(mapper[val_as_str])
|
| 198 |
+
return val
|
| 199 |
+
|
| 200 |
|
| 201 |
class FlattenInstances(StreamInstanceOperator):
|
| 202 |
"""Flattens each instance in a stream, making nested dictionary entries into top-level entries.
|
|
|
|
| 236 |
# Add a 'classes' field on a given list, prevent modification of original list
|
| 237 |
# from changing the instance.
|
| 238 |
AddFields(fields={"classes": alist}), use_deepcopy=True)
|
| 239 |
+
# if now alist is modified, still the instances remain intact.
|
| 240 |
"""
|
| 241 |
|
| 242 |
fields: Dict[str, object]
|
|
|
|
| 259 |
|
| 260 |
|
| 261 |
class RemoveFields(StreamInstanceOperator):
|
| 262 |
+
"""Remove specified fields from each instance in a stream.
|
| 263 |
|
| 264 |
Args:
|
| 265 |
fields (List[str]): The fields to remove from each instance.
|
|
|
|
| 276 |
|
| 277 |
|
| 278 |
class FieldOperator(StreamInstanceOperator):
|
| 279 |
+
"""A general stream instance operator that processes the values of a field (or multiple ones).
|
| 280 |
|
| 281 |
Args:
|
| 282 |
+
field (Optional[str]): The field to process, if only a single one is passed. Defaults to None
|
| 283 |
+
to_field (Optional[str]): Field name to save result into, if only one field is processed, if None is passed the
|
| 284 |
+
operation would happen in-place and its result would replace the value of "field". Defaults to None
|
| 285 |
+
field_to_field (Optional[Union[List[List[str]], Dict[str, str]]]): Mapping from names of fields to process,
|
| 286 |
+
to names of fields to save the results into. Inner List, if used, should be of length 2.
|
| 287 |
+
A field is processed by feeding its value into method 'process_value' and storing the result in to_field that
|
| 288 |
+
is mapped to the field.
|
| 289 |
+
When the type of argument 'field_to_field' is List, the order by which the fields are processed is their order
|
| 290 |
+
in the (outer) List. But when the type of argument 'field_to_field' is Dict, there is no uniquely determined
|
| 291 |
+
order. The end result might depend on that order if either (1) two different fields are mapped to the same
|
| 292 |
+
to_field, or (2) a field shows both as a key and as a value in different mappings.
|
| 293 |
+
The operator throws an AssertionError in either of these cases.
|
| 294 |
+
field_to_field defaults to None
|
| 295 |
process_every_value (bool): Processes the values in a list instead of the list as a value, similar to *var. Defaults to False
|
| 296 |
use_query (bool): Whether to use dpath style queries. Defaults to False.
|
| 297 |
+
|
| 298 |
+
Note: if 'field' and 'to_field' (or both members of a pair in 'field_to_field') are equal (or share a common
|
| 299 |
+
prefix if 'use_query'=True), then the result of the operation is saved within 'field'
|
| 300 |
"""
|
| 301 |
|
| 302 |
field: Optional[str] = None
|
| 303 |
to_field: Optional[str] = None
|
| 304 |
+
field_to_field: Optional[Union[List[List[str]], Dict[str, str]]] = None
|
| 305 |
process_every_value: bool = False
|
| 306 |
use_query: bool = False
|
| 307 |
get_default: Any = None
|
|
|
|
| 318 |
), f"Can not apply operator to create both on {self.to_field} and on the mapping from fields to fields {self.field_to_field}"
|
| 319 |
assert (
|
| 320 |
self.field is None or self.field_to_field is None
|
| 321 |
+
), f"Can not apply operator both on {self.field} and on the from fields in the mapping {self.field_to_field}"
|
| 322 |
+
assert self._field_to_field, f"the from and to fields must be defined or implied from the other inputs got: {self._field_to_field}"
|
| 323 |
assert (
|
| 324 |
+
len(self._field_to_field) > 0
|
| 325 |
+
), f"'input argument 'field_to_field' should convey at least one field to process. Got {self.field_to_field}"
|
| 326 |
+
# self._field_to_field is built explicitly by pairs, or copied from argument 'field_to_field'
|
| 327 |
+
if self.field_to_field is None:
|
| 328 |
+
return
|
| 329 |
+
# for backward compatibility also allow list of tupples of two strings
|
| 330 |
+
if isoftype(self.field_to_field, List[List[str]]) or isoftype(
|
| 331 |
+
self.field_to_field, List[Tuple[str, str]]
|
| 332 |
+
):
|
| 333 |
+
for pair in self._field_to_field:
|
| 334 |
+
assert (
|
| 335 |
+
len(pair) == 2
|
| 336 |
+
), f"when 'field_to_field' is defined as a list of lists, the inner lists should all be of length 2. {self.field_to_field}"
|
| 337 |
+
# order of field processing is uniquely determined by the input field_to_field when a list
|
| 338 |
+
return
|
| 339 |
+
if isoftype(self.field_to_field, Dict[str, str]):
|
| 340 |
+
if len(self.field_to_field) < 2:
|
| 341 |
+
return
|
| 342 |
+
for ff, tt in self.field_to_field.items():
|
| 343 |
+
for f, t in self.field_to_field.items():
|
| 344 |
+
if f == ff:
|
| 345 |
+
continue
|
| 346 |
+
assert (
|
| 347 |
+
t != ff
|
| 348 |
+
), f"In input argument 'field_to_field': {self.field_to_field}, field {f} is mapped to field {t}, while the latter is mapped to {tt}. Whether {f} or {t} is processed first might impact end result."
|
| 349 |
+
assert (
|
| 350 |
+
tt != t
|
| 351 |
+
), f"In input argument 'field_to_field': {self.field_to_field}, two different fields: {ff} and {f} are mapped to field {tt}. Whether {ff} or {f} is processed last might impact end result."
|
| 352 |
+
return
|
| 353 |
+
raise ValueError(
|
| 354 |
+
"Input argument 'field_to_field': {self.field_to_field} is neither of type List{List[str]] nor of type Dict[str, str]."
|
| 355 |
+
)
|
| 356 |
|
| 357 |
@abstractmethod
|
| 358 |
def process_value(self, value: Any) -> Any:
|
| 359 |
pass
|
| 360 |
|
| 361 |
def prepare(self):
|
| 362 |
+
super().prepare()
|
| 363 |
+
|
| 364 |
+
# prepare is invoked before verify, hence must make some checks here, before the changes done here
|
| 365 |
+
assert (
|
| 366 |
+
(self.field is None) != (self.field_to_field is None)
|
| 367 |
+
), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
|
| 368 |
+
assert (
|
| 369 |
+
self.to_field is None or self.field_to_field is None
|
| 370 |
+
), f"Can not apply operator to create both {self.to_field} and the to fields in the mapping {self.field_to_field}"
|
| 371 |
+
|
| 372 |
if self.field_to_field is None:
|
| 373 |
+
self._field_to_field = [
|
| 374 |
+
(self.field, self.to_field if self.to_field is not None else self.field)
|
| 375 |
+
]
|
| 376 |
else:
|
| 377 |
+
self._field_to_field = (
|
| 378 |
+
list(self.field_to_field.items())
|
| 379 |
+
if isinstance(self.field_to_field, dict)
|
| 380 |
+
else self.field_to_field
|
| 381 |
+
)
|
| 382 |
|
| 383 |
def process(
|
| 384 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
|
|
|
| 405 |
raise ValueError(
|
| 406 |
f"Failed to process '{from_field}' from {instance} due to : {e}"
|
| 407 |
) from e
|
| 408 |
+
if is_subpath(from_field, to_field) or is_subpath(to_field, from_field):
|
| 409 |
dict_delete(instance, from_field)
|
| 410 |
dict_set(
|
| 411 |
instance,
|
|
|
|
| 418 |
|
| 419 |
|
| 420 |
class RenameFields(FieldOperator):
|
| 421 |
+
"""Renames fields.
|
| 422 |
+
|
| 423 |
+
Move value from one field to another, potentially, if 'use_query'=True, from one branch into another.
|
| 424 |
+
Remove the from field, potentially part of it in case of use_query.
|
| 425 |
+
|
| 426 |
+
Examples:
|
| 427 |
+
RenameFields(field_to_field={"b": "c"})
|
| 428 |
+
will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": 2}, {"a": 2, "c": 3}]
|
| 429 |
+
|
| 430 |
+
RenameFields(field_to_field={"b": "c/d"}, use_query=True)
|
| 431 |
+
will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "c": {"d": 2}}, {"a": 2, "c": {"d": 3}}]
|
| 432 |
+
|
| 433 |
+
RenameFields(field_to_field={"b": "b/d"}, use_query=True)
|
| 434 |
+
will change inputs [{"a": 1, "b": 2}, {"a": 2, "b": 3}] to [{"a": 1, "b": {"d": 2}}, {"a": 2, "b": {"d": 3}}]
|
| 435 |
+
|
| 436 |
+
RenameFields(field_to_field={"b/c/e": "b/d"}, use_query=True)
|
| 437 |
+
will change inputs [{"a": 1, "b": {"c": {"e": 2, "f": 20}}}] to [{"a": 1, "b": {"c": {"f": 20}, "d": 2}}]
|
| 438 |
+
|
| 439 |
+
"""
|
| 440 |
|
| 441 |
def process_value(self, value: Any) -> Any:
|
| 442 |
return value
|
|
|
|
| 445 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
| 446 |
) -> Dict[str, Any]:
|
| 447 |
res = super().process(instance=instance, stream_name=stream_name)
|
| 448 |
+
for from_field, to_field in self._field_to_field:
|
| 449 |
+
if (not is_subpath(from_field, to_field)) and (
|
| 450 |
+
not is_subpath(to_field, from_field)
|
| 451 |
+
):
|
| 452 |
+
dict_delete(res, from_field)
|
| 453 |
+
if self.use_query:
|
| 454 |
+
from_field_components = list(
|
| 455 |
+
os.path.normpath(from_field).split(os.path.sep)
|
| 456 |
+
)
|
| 457 |
+
while len(from_field_components) > 1:
|
| 458 |
+
from_field_components.pop()
|
| 459 |
+
parent = dict_get(res, os.path.sep.join(from_field_components))
|
| 460 |
+
if isinstance(parent, dict) and not parent:
|
| 461 |
+
dict_delete(res, os.path.sep.join(from_field_components))
|
| 462 |
+
else:
|
| 463 |
+
break
|
| 464 |
+
|
| 465 |
return res
|
| 466 |
|
| 467 |
|
| 468 |
class AddConstant(FieldOperator):
|
| 469 |
+
"""Adds a constant, being argument 'add', to the processed value.
|
| 470 |
|
| 471 |
Args:
|
| 472 |
+
add: the constant to add.
|
| 473 |
"""
|
| 474 |
|
| 475 |
add: Any
|
|
|
|
| 535 |
default="",
|
| 536 |
not_exist_ok=False,
|
| 537 |
)
|
| 538 |
+
except ValueError as e:
|
| 539 |
raise TypeError(f"Failed to get {field_name} from {instance}") from e
|
| 540 |
|
| 541 |
+
try:
|
| 542 |
+
new_value = self.process_value(old_value)
|
| 543 |
+
except Exception as e:
|
| 544 |
+
raise RuntimeError(
|
| 545 |
+
f"Error augmenting value '{old_value}' from '{field_name}' in instance: {instance}"
|
| 546 |
+
) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
dict_set(instance, field_name, new_value, use_dpath=True, not_exist_ok=True)
|
| 548 |
return instance
|
| 549 |
|
|
|
|
| 568 |
words = re.split(r"(\s+)", value)
|
| 569 |
new_value = ""
|
| 570 |
|
| 571 |
+
random_generator = new_random_generator(sub_seed=value)
|
| 572 |
for word in words:
|
| 573 |
if word.isspace():
|
| 574 |
+
new_value += random_generator.choice(
|
| 575 |
["\n", "\t", " "]
|
| 576 |
+
) * random_generator.randint(1, 3)
|
| 577 |
else:
|
| 578 |
new_value += word
|
| 579 |
return new_value
|
| 580 |
|
| 581 |
|
| 582 |
+
class AugmentPrefixSuffix(Augmentor):
|
| 583 |
+
r"""Augments the input by prepending and appending to it a randomly selected (typically, whitespace) patterns.
|
| 584 |
|
| 585 |
Args:
|
| 586 |
+
prefixes, suffixes (list or dict) : the potential (typically, whitespace) patterns to select from.
|
| 587 |
The dictionary version allows to specify relative weights of the different patterns.
|
| 588 |
+
prefix_len, suffix_len (positive int) : The added prefix or suffix will be of length
|
| 589 |
+
prefix_len of suffix_len, respectively, repetitions of the randomly selected patterns.
|
| 590 |
+
remove_existing_whitespaces : allows to first clean any existing leading and trailing whitespaces.
|
| 591 |
+
The strings made of repetitions of the selected pattern(s) are then prepended and/or appended to the potentially
|
| 592 |
+
trimmed input.
|
| 593 |
+
If only one of prefixes/suffixes is needed, set the other to None.
|
| 594 |
|
| 595 |
Examples:
|
| 596 |
+
To prepend the input with a prefix made of 4 '\n'-s or '\t'-s, employ
|
| 597 |
+
AugmentPrefixSuffix(augment_model_input=True, prefixes=['\n','\t'], prefix_len=4, suffixes = None)
|
| 598 |
+
To append the input with a suffix made of 3 '\n'-s or '\t'-s, with triple '\n' suffixes
|
| 599 |
+
being preferred over triple '\t', at 2:1 ratio, employ
|
| 600 |
+
AugmentPrefixSuffix(augment_model_input=True, suffixes={'\n':2,'\t':1}, suffix_len=3, prefixes = None)
|
| 601 |
+
which will append '\n'-s twice as often as '\t'-s.
|
| 602 |
|
| 603 |
"""
|
| 604 |
|
| 605 |
+
prefixes: Optional[Union[List[str], Dict[str, int]]] = {
|
| 606 |
+
" ": 20,
|
| 607 |
+
"\\t": 10,
|
| 608 |
+
"\\n": 40,
|
| 609 |
+
"": 30,
|
| 610 |
+
}
|
| 611 |
+
prefix_len: Optional[int] = 3
|
| 612 |
+
suffixes: Optional[Union[List[str], Dict[str, int]]] = {
|
| 613 |
+
" ": 20,
|
| 614 |
+
"\\t": 10,
|
| 615 |
+
"\\n": 40,
|
| 616 |
+
"": 30,
|
| 617 |
+
}
|
| 618 |
+
suffix_len: Optional[int] = 3
|
| 619 |
+
remove_existing_whitespaces: Optional[bool] = False
|
| 620 |
|
| 621 |
def verify(self):
|
| 622 |
assert (
|
| 623 |
+
self.prefixes or self.suffixes
|
| 624 |
+
), "At least one of prefixes/suffixes should be not None."
|
| 625 |
+
for arg, arg_name in zip(
|
| 626 |
+
[self.prefixes, self.suffixes], ["prefixes", "suffixes"]
|
| 627 |
+
):
|
| 628 |
+
assert (
|
| 629 |
+
arg is None or isoftype(arg, List[str]) or isoftype(arg, Dict[str, int])
|
| 630 |
+
), f"Argument {arg_name} should be either None or a list of strings or a dictionary str->int. {arg} is none of the above."
|
| 631 |
+
assert (
|
| 632 |
+
self.prefix_len > 0
|
| 633 |
+
), f"prefix_len must be positive, got {self.prefix_len}"
|
| 634 |
+
assert (
|
| 635 |
+
self.suffix_len > 0
|
| 636 |
+
), f"suffix_len must be positive, got {self.suffix_len}"
|
| 637 |
+
super().verify()
|
|
|
|
| 638 |
|
| 639 |
+
def _calculate_distributions(self, prefs_or_suffs):
|
| 640 |
+
if prefs_or_suffs is None:
|
| 641 |
+
return None, None
|
| 642 |
+
patterns = (
|
| 643 |
+
prefs_or_suffs
|
| 644 |
+
if isinstance(prefs_or_suffs, list)
|
| 645 |
+
else [k for k, v in prefs_or_suffs.items()]
|
| 646 |
)
|
| 647 |
total_weight = (
|
| 648 |
+
len(patterns)
|
| 649 |
+
if isinstance(prefs_or_suffs, list)
|
| 650 |
+
else sum([v for k, v in prefs_or_suffs.items()])
|
| 651 |
)
|
| 652 |
+
weights = (
|
| 653 |
+
[1.0 / total_weight] * len(patterns)
|
| 654 |
+
if isinstance(prefs_or_suffs, list)
|
| 655 |
+
else [float(prefs_or_suffs[p]) / total_weight for p in patterns]
|
| 656 |
)
|
| 657 |
+
return patterns, weights
|
| 658 |
+
|
| 659 |
+
def prepare(self):
|
| 660 |
+
# Being an artifact, prepare is invoked before verify. Here we need verify before the actions
|
| 661 |
+
self.verify()
|
| 662 |
+
self._prefix_pattern_distribution = {"length": self.prefix_len}
|
| 663 |
+
self._suffix_pattern_distribution = {"length": self.suffix_len}
|
| 664 |
+
|
| 665 |
+
(
|
| 666 |
+
self._prefix_pattern_distribution["patterns"],
|
| 667 |
+
self._prefix_pattern_distribution["weights"],
|
| 668 |
+
) = self._calculate_distributions(self.prefixes)
|
| 669 |
+
(
|
| 670 |
+
self._suffix_pattern_distribution["patterns"],
|
| 671 |
+
self._suffix_pattern_distribution["weights"],
|
| 672 |
+
) = self._calculate_distributions(self.suffixes)
|
| 673 |
+
super().prepare()
|
| 674 |
+
|
| 675 |
+
def _get_random_pattern(
|
| 676 |
+
self, pattern_distribution, random_generator: Random
|
| 677 |
+
) -> str:
|
| 678 |
+
string_to_add = ""
|
| 679 |
+
if pattern_distribution["patterns"]:
|
| 680 |
+
string_to_add = "".join(
|
| 681 |
+
random_generator.choices(
|
| 682 |
+
pattern_distribution["patterns"],
|
| 683 |
+
pattern_distribution["weights"],
|
| 684 |
+
k=pattern_distribution["length"],
|
| 685 |
+
)
|
| 686 |
+
)
|
| 687 |
+
return string_to_add
|
| 688 |
|
| 689 |
def process_value(self, value: Any) -> Any:
|
| 690 |
assert value is not None, "input value should not be None"
|
| 691 |
new_value = str(value)
|
| 692 |
+
if self.remove_existing_whitespaces:
|
| 693 |
+
new_value = new_value.strip()
|
| 694 |
+
random_generator = new_random_generator(sub_seed=value)
|
| 695 |
+
prefix = self._get_random_pattern(
|
| 696 |
+
self._prefix_pattern_distribution, random_generator
|
| 697 |
+
)
|
| 698 |
+
suffix = self._get_random_pattern(
|
| 699 |
+
self._suffix_pattern_distribution, random_generator
|
| 700 |
+
)
|
| 701 |
+
return prefix + new_value + suffix
|
| 702 |
|
| 703 |
|
| 704 |
class ShuffleFieldValues(FieldOperator):
|
| 705 |
+
"""Shuffles a list of values found in a field."""
|
| 706 |
|
| 707 |
def process_value(self, value: Any) -> Any:
|
| 708 |
res = list(value)
|
| 709 |
+
random_generator = new_random_generator(sub_seed=res)
|
| 710 |
+
random_generator.shuffle(res)
|
| 711 |
return res
|
| 712 |
|
| 713 |
|
|
|
|
| 812 |
|
| 813 |
|
| 814 |
class ZipFieldValues(StreamInstanceOperator):
|
| 815 |
+
"""Zips values of multiple fields in a given instance, similar to list(zip(*fields)).
|
| 816 |
+
|
| 817 |
+
The value in each of the specified 'fields' is assumed to be a list. The lists from all 'fields'
|
| 818 |
+
are zipped, and stored into 'to_field'.
|
| 819 |
|
| 820 |
+
If 'longest'=False, the length of the zipped result is determined by the shortest input value.
|
| 821 |
+
If 'longest'=False, the length of the zipped result is determined by the longest input, padding shorter
|
| 822 |
+
inputs with None -s.
|
| 823 |
+
|
| 824 |
+
"""
|
| 825 |
+
|
| 826 |
+
fields: List[str]
|
| 827 |
to_field: str
|
| 828 |
longest: bool = False
|
| 829 |
use_query: bool = False
|
|
|
|
| 843 |
|
| 844 |
|
| 845 |
class IndexOf(StreamInstanceOperator):
|
| 846 |
+
"""For a given instance, finds the offset of value of field 'index_of', within the value of field 'search_in'."""
|
| 847 |
|
| 848 |
search_in: str
|
| 849 |
index_of: str
|
|
|
|
| 860 |
|
| 861 |
|
| 862 |
class TakeByField(StreamInstanceOperator):
|
| 863 |
+
"""From field 'field' of a given instance, select the member indexed by field 'index', and store to field 'to_field'."""
|
| 864 |
|
| 865 |
field: str
|
| 866 |
index: str
|
|
|
|
| 881 |
|
| 882 |
|
| 883 |
class CopyFields(FieldOperator):
|
| 884 |
+
"""Copies values from specified fields to specified fields.
|
| 885 |
|
| 886 |
+
Args (of parent class):
|
| 887 |
field_to_field (Union[List[List], Dict[str, str]]): A list of lists, where each sublist contains the source field and the destination field, or a dictionary mapping source fields to destination fields.
|
| 888 |
+
use_query (bool): Whether to use dpath for accessing fields. Defaults to False.
|
| 889 |
+
|
| 890 |
+
Examples:
|
| 891 |
+
An input instance {"a": 2, "b": 3}, when processed by
|
| 892 |
+
CopyField(field_to_field={"a": "b"}
|
| 893 |
+
would yield {"a": 2, "b": 2}, and when processed by
|
| 894 |
+
CopyField(field_to_field={"a": "c"} would yield
|
| 895 |
+
{"a": 2, "b": 3, "c": 2}
|
| 896 |
+
|
| 897 |
+
with use_query=True, we can also copy inside the field:
|
| 898 |
+
CopyFields(field_to_field={"a/0": "a"}, use_query=True)
|
| 899 |
+
would process instance {"a": [1, 3]} into {"a": 1}
|
| 900 |
+
|
| 901 |
+
|
| 902 |
"""
|
| 903 |
|
| 904 |
def process_value(self, value: Any) -> Any:
|
|
|
|
| 906 |
|
| 907 |
|
| 908 |
class AddID(StreamInstanceOperator):
|
| 909 |
+
"""Stores a unique id value in the designated 'id_field_name' field of the given instance."""
|
| 910 |
+
|
| 911 |
id_field_name: str = "id"
|
| 912 |
|
| 913 |
def process(
|
|
|
|
| 921 |
"""Casts specified fields to specified types.
|
| 922 |
|
| 923 |
Args:
|
| 924 |
+
use_nested_query (bool): Whether to cast nested fields, expressed in dpath. Defaults to False.
|
| 925 |
+
fields (Dict[str, str]): A dictionary mapping field names to the names of the types to cast the fields to.
|
| 926 |
+
e.g: "int", "str", "float", "bool". Basic names of types
|
| 927 |
+
defaults (Dict[str, object]): A dictionary mapping field names to default values for cases of casting failure.
|
| 928 |
+
process_every_value (bool): If true, all fields involved must contain lists, and each value in the list is then casted. Defaults to False.
|
| 929 |
+
|
| 930 |
+
Examples:
|
| 931 |
+
CastFields(
|
| 932 |
+
fields={"a/d": "float", "b": "int"},
|
| 933 |
+
failure_defaults={"a/d": 0.0, "b": 0},
|
| 934 |
+
process_every_value=True,
|
| 935 |
+
use_nested_query=True
|
| 936 |
+
)
|
| 937 |
+
would process the input instance: {"a": {"d": ["half", "0.6", 1, 12]}, "b": ["2"]}
|
| 938 |
+
into {"a": {"d": [0.0, 0.6, 1.0, 12.0]}, "b": [2]}
|
| 939 |
+
|
| 940 |
"""
|
| 941 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 942 |
fields: Dict[str, str] = field(default_factory=dict)
|
| 943 |
failure_defaults: Dict[str, object] = field(default_factory=dict)
|
| 944 |
use_nested_query: bool = False
|
| 945 |
+
process_every_value: bool = False
|
| 946 |
+
|
| 947 |
+
def prepare(self):
|
| 948 |
+
self.types = {"int": int, "float": float, "str": str, "bool": bool}
|
| 949 |
|
| 950 |
def _cast_single(self, value, type, field):
|
| 951 |
try:
|
|
|
|
| 958 |
return self.failure_defaults[field]
|
| 959 |
|
| 960 |
def _cast_multiple(self, values, type, field):
|
| 961 |
+
return [self._cast_single(value, type, field) for value in values]
|
| 962 |
|
| 963 |
def process(
|
| 964 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
| 965 |
) -> Dict[str, Any]:
|
| 966 |
for field_name, type in self.fields.items():
|
| 967 |
value = dict_get(instance, field_name, use_dpath=self.use_nested_query)
|
| 968 |
+
if self.process_every_value:
|
| 969 |
+
assert isinstance(
|
| 970 |
+
value, list
|
| 971 |
+
), f"'process_every_value' can be set to True only for fields that contain lists, whereas in instance {instance}, the contents of field '{field_name}' is of type '{type(value)}'"
|
| 972 |
casted_value = self._cast_multiple(value, type, field_name)
|
| 973 |
else:
|
| 974 |
casted_value = self._cast_single(value, type, field_name)
|
|
|
|
| 978 |
return instance
|
| 979 |
|
| 980 |
|
| 981 |
+
class DivideAllFieldsBy(StreamInstanceOperator):
|
| 982 |
+
"""Recursively reach down to all fields that are float, and divide each by 'divisor'.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 983 |
|
| 984 |
+
The given instance is viewed as a tree whose internal nodes are dictionaries and lists, and
|
| 985 |
+
the leaves are either 'float' and then divided, or other basic type, in which case, a ValueError is raised
|
| 986 |
+
if input flag 'strict' is True, or -- left alone, if 'strict' is False.
|
| 987 |
+
|
| 988 |
+
Args:
|
| 989 |
+
divisor (float) the value to divide by
|
| 990 |
+
strict (bool) whether to raise an error upon visiting a leaf that is not float. Defaults to False.
|
| 991 |
+
|
| 992 |
+
Example:
|
| 993 |
+
when instance {"a": 10.0, "b": [2.0, 4.0, 7.0], "c": 5} is processed by operator:
|
| 994 |
+
operator = DivideAllFieldsBy(divisor=2.0)
|
| 995 |
+
the output is: {"a": 5.0, "b": [1.0, 2.0, 3.5], "c": 5}
|
| 996 |
+
If the operator were defined with strict=True, through:
|
| 997 |
+
operator = DivideAllFieldsBy(divisor=2.0, strict=True),
|
| 998 |
+
the processing of the above instance would raise a ValueError, for the integer at "c".
|
| 999 |
+
"""
|
| 1000 |
|
|
|
|
| 1001 |
divisor: float = 1.0
|
| 1002 |
strict: bool = False
|
| 1003 |
+
|
| 1004 |
+
def _recursive_divide(self, instance, divisor):
|
| 1005 |
+
if isinstance(instance, dict):
|
| 1006 |
+
for key, value in instance.items():
|
| 1007 |
+
instance[key] = self._recursive_divide(value, divisor)
|
| 1008 |
+
elif isinstance(instance, list):
|
| 1009 |
+
for i, value in enumerate(instance):
|
| 1010 |
+
instance[i] = self._recursive_divide(value, divisor)
|
| 1011 |
+
elif isinstance(instance, float):
|
| 1012 |
+
instance /= divisor
|
| 1013 |
+
elif self.strict:
|
| 1014 |
+
raise ValueError(f"Cannot divide instance of type {type(instance)}")
|
| 1015 |
+
return instance
|
| 1016 |
|
| 1017 |
def process(
|
| 1018 |
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
| 1019 |
) -> Dict[str, Any]:
|
| 1020 |
+
return self._recursive_divide(instance, self.divisor)
|
| 1021 |
|
| 1022 |
|
| 1023 |
class ArtifactFetcherMixin:
|
|
|
|
| 1041 |
"""Applies value operators to each instance in a stream based on specified fields.
    Args:
        inputs_fields (List[str]): list of field names, the values in which are to be processed
        fields_to_treat_as_list (List[str]): sublist of inputs_fields; each member of this sublist is expected to contain
            a list of values, each of which is to be processed.
        operators_field (str): name of the field that contains the list of names of the operators to be applied,
            one after the other, for the processing.
        default_operators (List[str]): A list of default operators to be used if no operators are found in the instance.

    Example:
        when instance {"a": 111, "b": 2, "c": ["processors.to_string", "processors.first_character"]} is processed by operator:
        operator = ApplyOperatorsField(inputs_fields=["a"], operators_field="c", default_operators=["add"]),
        the resulting instance is: {"a": "1", "b": 2, "c": ["processors.to_string", "processors.first_character"]}
    """

    inputs_fields: List[str]
    operators_field: str
    default_operators: List[str] = None
    fields_to_treat_as_list: List[str] = NonPositionalField(default_factory=list)

        if operator_names is None:
            assert (
                self.default_operators is not None
            ), f"No operators found in field '{self.operators_field}', and no default operators provided."
            operator_names = self.default_operators

        if isinstance(operator_names, str):

            if field_name in self.fields_to_treat_as_list:
                instance[field_name] = [operator.process(v) for v in value]
            else:
                instance[field_name] = operator.process(value)

        return instance
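

# A minimal usage sketch mirroring the docstring example above; it assumes the catalog
# artifacts "processors.to_string" and "processors.first_character" are fetchable, so it
# is shown as a comment only:
#
#   op = ApplyOperatorsField(inputs_fields=["a"], operators_field="c")
#   op.process({"a": 111, "b": 2, "c": ["processors.to_string", "processors.first_character"]})
#   # -> {"a": "1", "b": 2, "c": ["processors.to_string", "processors.first_character"]}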


class FilterByCondition(SingleStreamOperator):
    """Filters a stream, yielding only instances whose values satisfy the required condition operator.

    Raises an error if a required key is missing.

    Args:
        values (Dict[str, Any]): Values that instances must match, using the condition, to be included in the output.
        condition: the name of the desired condition operator between the key and the value in values ("gt", "ge", "lt", "le", "ne", "eq", "in", "not in")
        error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
        FilterByCondition(values={"a": 4}, condition="gt") will yield only instances where "a" > 4
        FilterByCondition(values={"a": 4}, condition="le") will yield only instances where "a" <= 4
        FilterByCondition(values={"a": [4, 8]}, condition="in") will yield only instances where "a" is 4 or 8
        FilterByCondition(values={"a": [4, 8]}, condition="not in") will yield only instances where "a" is neither 4 nor 8
    """

    values: Dict[str, Any]
    condition: str
    condition_to_func = {
        "gt": operator.gt,
        "ge": operator.ge,
        "lt": operator.lt,
        "le": operator.le,
        "eq": operator.eq,
        "ne": operator.ne,
        "in": None,  # Handled as special case
        "not in": None,  # Handled as special case
    }
    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if self._is_required(instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended, set error_on_filtered_all=False"
            )

    def verify(self):
        if self.condition not in self.condition_to_func:
            raise ValueError(
                f"Unsupported condition operator '{self.condition}', supported: {list(self.condition_to_func.keys())}"
            )

        for key, value in self.values.items():
            if self.condition in ["in", "not in"] and not isinstance(value, list):
                raise ValueError(
                    f"The filter for key ('{key}') in FilterByCondition with condition '{self.condition}' must be a list but is not: '{value}'"
                )
        return super().verify()

    def _is_required(self, instance: dict) -> bool:
        for key, value in self.values.items():
            if key not in instance:
                raise ValueError(
                    f"Required filter field ('{key}') in FilterByCondition is not found in {instance}"
                )
            if self.condition == "in":
                if instance[key] not in value:
                    return False
            elif self.condition == "not in":
                if instance[key] in value:
                    return False
            else:
                func = self.condition_to_func[self.condition]
                if func is None:
                    raise ValueError(
                        f"Function not defined for condition '{self.condition}'"
                    )
                if not func(instance[key], value):
                    return False
        return True
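

# A minimal usage sketch (a plain list of dicts stands in for a Stream, which suffices
# for iteration here): keep only instances where "a" > 4.
def _example_filter_by_condition():
    op = FilterByCondition(values={"a": 4}, condition="gt")
    kept = list(op.process([{"a": 3}, {"a": 5}, {"a": 7}]))
    assert kept == [{"a": 5}, {"a": 7}]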


class FilterByQuery(SingleStreamOperator):
    """Filters a stream, yielding only instances which fulfill a condition specified as a string to be evaluated by Python's eval().

    Raises an error if a field participating in the specified condition is missing from the instance.

    Args:
        query (str): a condition over fields of the instance, to be processed by Python's eval()
        error_on_filtered_all (bool, optional): If True, raises an error if all instances are filtered out. Defaults to True.

    Examples:
        FilterByQuery(query="a > 4") will yield only instances where "a" > 4
        FilterByQuery(query="a <= 4 and b > 5") will yield only instances where the value of field "a" does not exceed 4 and the value of field "b" is greater than 5
        FilterByQuery(query="a in [4, 8]") will yield only instances where "a" is 4 or 8
        FilterByQuery(query="a not in [4, 8]") will yield only instances where "a" is neither 4 nor 8
    """

    query: str
    error_on_filtered_all: bool = True

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        yielded = False
        for instance in stream:
            if eval(self.query, None, instance):
                yielded = True
                yield instance

        if not yielded and self.error_on_filtered_all:
            raise RuntimeError(
                f"{self.__class__.__name__} filtered out every instance in stream '{stream_name}'. If this is intended, set error_on_filtered_all=False"
            )
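

# A minimal usage sketch: the query string is eval-uated with the instance as the local
# namespace, so field names can be used as variables.
def _example_filter_by_query():
    op = FilterByQuery(query="a <= 4 and b > 5")
    kept = list(op.process([{"a": 3, "b": 9}, {"a": 5, "b": 9}, {"a": 3, "b": 1}]))
    assert kept == [{"a": 3, "b": 9}]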


class ExecuteQuery(StreamInstanceOperator):
    """Computes an expression (query), expressed as a string to be evaluated, over the instance's fields, and stores the result in field to_field.

    Raises an error if a field mentioned in the query is missing from the instance.

    Args:
        query (str): an expression to be evaluated over the fields of the instance
        to_field (str): the field where the result is to be stored

    Examples:
        When instance {"a": 2, "b": 3} is processed by operator
        ExecuteQuery(query="a+b", to_field="c")
        the result is {"a": 2, "b": 3, "c": 5}

        When instance {"a": "hello", "b": "world"} is processed by operator
        ExecuteQuery(query="a+' '+b", to_field="c")
        the result is {"a": "hello", "b": "world", "c": "hello world"}
    """

    query: str
    to_field: str

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        instance[self.to_field] = eval(self.query, None, instance)
        return instance
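

# A minimal usage sketch mirroring the docstring examples: the expression is eval-uated
# over the instance's fields and the result is stored in to_field.
def _example_execute_query():
    op = ExecuteQuery(query="a + ' ' + b", to_field="c")
    result = op.process({"a": "hello", "b": "world"})
    assert result == {"a": "hello", "b": "world", "c": "hello world"}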


class ExtractMostCommonFieldValues(MultiStreamOperator):
    field: str
    stream_name: str
    overall_top_frequency_percent: Optional[int] = 100

    Examples:

        ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes") - extracts all the unique values of
        field 'label', sorts them by decreasing frequency, and stores the resulting list in field 'classes' of each and
        every instance in all streams.

        ExtractMostCommonFieldValues(stream_name="train", field="labels", to_field="classes", process_every_value=True) -
        in case field 'labels' contains a list of values (and not a single value) - tracks the occurrences of all the possible
        value members in these lists, and reports the most frequent values.
        If process_every_value=False, tracks the most frequent whole lists, and reports those (as a list of lists) in field
        'to_field' of each instance of all streams.

        ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", overall_top_frequency_percent=80) -
        extracts the most frequent possible values of field 'label' that together cover at least 80% of the instances of stream_name,
        and stores them in field 'classes' of each instance of all streams.

        ExtractMostCommonFieldValues(stream_name="train", field="label", to_field="classes", min_frequency_percent=5) -
        extracts all possible values of field 'label' that cover, each, at least 5% of the instances.
        Stores these values, sorted by decreasing order of frequency, in field 'classes' of each instance in all streams.
    """
            [*ele[0]] if isinstance(ele[0], tuple) else ele[0]
            for ele in values_and_counts
        ]

        addmostcommons = AddFields(fields={self.to_field: values_to_keep})
        return addmostcommons(multi_stream)


class ExtractFieldValues(ExtractMostCommonFieldValues):
    def verify(self):
        super().verify()

    def prepare(self):
        self.overall_top_frequency_percent = 100
        self.min_frequency_percent = 0
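

# A minimal usage sketch (assumes MultiStream.from_iterables, as used elsewhere in this
# module): "b" appears twice and "a" once in the 'train' stream, so every instance in
# every stream gets classes == ["b", "a"].
def _example_extract_field_values():
    ms = MultiStream.from_iterables(
        {"train": [{"label": "b"}, {"label": "a"}, {"label": "b"}]}
    )
    out = ExtractFieldValues(stream_name="train", field="label", to_field="classes")(ms)
    assert all(instance["classes"] == ["b", "a"] for instance in out["train"])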


class Intersect(FieldOperator):

            )

    def process_value(self, value: Any) -> Any:
        super().process_value(value)
        if not isinstance(value, list):
            raise ValueError(f"The value in the field is not a list but '{value}'")
        return [e for e in value if e in self.allowed_values]
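

# A minimal usage sketch (field plumbing omitted; process_value is applied directly):
# only elements present in allowed_values survive, in their original order.
def _example_intersect():
    op = Intersect(field="colors", allowed_values=["red", "blue"])
    assert op.process_value(["red", "green", "blue"]) == ["red", "blue"]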

    """Removes elements in a field, which must be a list, using a given list of unallowed values.

    Args:
        unallowed_values (list): values to be removed.
    """

    unallowed_values: List[Any]

            stream_unique_values = uniques[stream_name]
            for unique_values in stream_unique_values:
                filtering_values = dict(zip(self.fields, unique_values))
                filtered_streams = FilterByCondition(
                    values=filtering_values, condition="eq"
                )._process_single_stream(stream)
                filtered_stream_name = (
                    stream_name + "_" + nested_tuple_to_string(unique_values)

    reversed: bool = False

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        first_instance = stream.peek()

        operators = first_instance.get(self.field, [])
        if isinstance(operators, str):

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        from .metrics import Metric, MetricPipeline, MetricWithConfidenceInterval

        first_instance = stream.peek()

        metric_names = first_instance.get(self.metric_field, [])
        if not metric_names:
        yield from stream


class MergeStreams(MultiStreamOperator):
    """Merges multiple streams into a single stream.


class Shuffle(PagedStreamOperator):
    """Shuffles the order of instances in each page of a stream.

    Args (of superclass):
        page_size (int): The size of each page in the stream. Defaults to 1000.
    """

    random_generator: Random = None

    def before_process_multi_stream(self):
        super().before_process_multi_stream()
        self.random_generator = new_random_generator(sub_seed="shuffle")

    def process(self, page: List[Dict], stream_name: Optional[str] = None) -> Generator:
        self.random_generator.shuffle(page)
        yield from page
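

# A minimal usage sketch: before_process_multi_stream (normally invoked by the
# framework) seeds the operator's random generator; process then shuffles each page
# in place.
def _example_shuffle():
    op = Shuffle(page_size=1000)
    op.before_process_multi_stream()
    page = [{"i": i} for i in range(5)]
    shuffled = list(op.process(page))
    assert sorted(shuffled, key=lambda d: d["i"]) == [{"i": i} for i in range(5)]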


class EncodeLabels(StreamInstanceOperator):
    """Encodes each value encountered in any field in 'fields' into the integers 0, 1, ...

    Encoding is determined by a str->int map that is built on the fly, as different values are
    first encountered in the stream, either as list members or as values in single-value fields.

    Args:
        fields (List[str]): The fields to encode together.

    Example: applying
        EncodeLabels(fields=["a", "b/*"])
        on input stream = [{"a": "red", "b": ["red", "blue"], "c": "bread"},
        {"a": "blue", "b": ["green"], "c": "water"}] will yield the
        output stream = [{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]

    Note: dpath is applied here, and hence fields that are lists should be included in
        input 'fields' with the suffix "/*", as in the above example.
    """

    fields: List[str]


class StreamRefiner(SingleStreamOperator):
    """Discards from the input stream all instances beyond the leading 'max_instances' instances.

    Thereby, if the input stream consists of no more than 'max_instances' instances, the resulting stream is the whole of the
    input stream. And if the input stream consists of more than 'max_instances' instances, the resulting stream only consists
    of the leading 'max_instances' of the input stream.

    Args:
        max_instances (int)
        apply_to_streams (optional, list(str)): names of streams to refine.

    Examples:
        when input = [{"a": 1},{"a": 2},{"a": 3},{"a": 4},{"a": 5},{"a": 6}] is fed into
        StreamRefiner(max_instances=4)
        the resulting stream is [{"a": 1},{"a": 2},{"a": 3},{"a": 4}]
    """

    max_instances: int = None
    apply_to_streams: Optional[List[str]] = None

    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
        if self.max_instances is not None:


class DeterministicBalancer(StreamRefiner):
    """A class used to balance streams deterministically.

    For each instance, a signature is constructed from the values of the instance in the specified input 'fields'.
    By discarding instances from the input stream, DeterministicBalancer maintains an equal number of instances for all signatures.
    When input 'max_instances' is also specified, DeterministicBalancer maintains a total instance count not exceeding
    'max_instances'. The total number of discarded instances is as few as possible.

    Attributes:
        fields (List[str]): A list of field names to be used in producing the instance's signature.
        max_instances (Optional, int)

    Usage:
        balancer = DeterministicBalancer(fields=["field1", "field2"], max_instances=200)
        balanced_stream = balancer.process(stream)

    Example:
        When input [{"a": 1, "b": 1},{"a": 1, "b": 2},{"a": 2},{"a": 3},{"a": 4}] is fed into
        DeterministicBalancer(fields=["a"])
        the resulting stream will be: [{"a": 1, "b": 1},{"a": 2},{"a": 3},{"a": 4}]
    """

    fields: List[str]


class LengthBalancer(DeterministicBalancer):
    """Balances by a signature that reflects the total length of the fields' values, quantized into integer segments.

    Args:
        segments_boundaries (List[int]): distinct integers sorted in increasing order, that map a given total length
            into the index of the least of them that exceeds the total length. (If none exceeds it -- into one index
            beyond, namely, the length of segments_boundaries.)
        fields (Optional, List[str])

    Example:
        when input [{"a": [1, 3], "b": 0, "id": 0}, {"a": [1, 3], "b": 0, "id": 1}, {"a": [], "b": "a", "id": 2}] is fed into
        LengthBalancer(fields=["a"], segments_boundaries=[1])
        input instances will be counted and balanced against two categories: empty total length (less than 1), and non-empty.
    """

    segments_boundaries: List[int]
    fields: Optional[List[str]]

    def signature(self, instance):
        total_len = 0