Upload operators.py with huggingface_hub
Browse files

operators.py  +47 -21  CHANGED

@@ -35,7 +35,6 @@ General Operaotrs List:
 import collections
 import copy
 import operator
-import os
 import uuid
 import zipfile
 from abc import abstractmethod

@@ -418,8 +417,6 @@ class InstanceFieldOperator(StreamInstanceOperator):
                 raise ValueError(
                     f"Failed to process '{from_field}' from {instance} due to : {e}"
                 ) from e
-            if is_subpath(from_field, to_field) or is_subpath(to_field, from_field):
-                dict_delete(instance, from_field)
             dict_set(
                 instance,
                 to_field,

@@ -471,18 +468,7 @@ class RenameFields(FieldOperator):
         if (not is_subpath(from_field, to_field)) and (
             not is_subpath(to_field, from_field)
         ):
-            dict_delete(res, from_field)
-            if self.use_query:
-                from_field_components = list(
-                    os.path.normpath(from_field).split(os.path.sep)
-                )
-                while len(from_field_components) > 1:
-                    from_field_components.pop()
-                    parent = dict_get(res, os.path.sep.join(from_field_components))
-                    if isinstance(parent, dict) and not parent:
-                        dict_delete(res, os.path.sep.join(from_field_components))
-                    else:
-                        break
+            dict_delete(res, from_field, remove_empty_ancestors=True)

         return res

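In RenameFields, the hand-rolled clean-up that walked up the path of from_field and deleted any parent dictionaries left empty is replaced by a single dict_delete call with a new remove_empty_ancestors flag. A minimal sketch of what that flag presumably does, using a plain nested dict and "/"-separated paths rather than the library's own dict_delete:

    def dict_delete_sketch(d, path, remove_empty_ancestors=False):
        """Delete d[path] for a "/"-separated path; optionally prune emptied parents."""
        parts = path.split("/")
        parents = []  # (container, key) pairs from the root down to the leaf's parent
        node = d
        for key in parts[:-1]:
            parents.append((node, key))
            node = node[key]
        del node[parts[-1]]
        if remove_empty_ancestors:
            # walk back up, removing dicts that the deletion left empty
            for container, key in reversed(parents):
                if isinstance(container[key], dict) and not container[key]:
                    del container[key]
                else:
                    break

    res = {"a": {"b": {"c": 1}}, "x": 2}
    dict_delete_sketch(res, "a/b/c", remove_empty_ancestors=True)
    # res == {"x": 2}: "a/b" and then "a" became empty and were pruned as well
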
@@ -1480,10 +1466,6 @@ class RemoveValues(FieldOperator):
 
     def verify(self):
         super().verify()
-        if self.process_every_value:
-            raise ValueError(
-                "'process_every_value=True' is not supported in RemoveValues operator"
-            )
 
         if not isinstance(self.unallowed_values, list):
             raise ValueError(

@@ -1712,7 +1694,7 @@ class EncodeLabels(StreamInstanceOperator):
     {"a": "blue", "b": ["green"], "c":"water"}] will yield the
     output stream = [{'a': 0, 'b': [0, 1], 'c': 'bread'}, {'a': 1, 'b': [2], 'c': 'water'}]
 
-    Note:
+    Note: qpath is applied here, and hence, fields that are lists, should be included in
     input 'fields' with the appendix "/*" as in the above example.
 
     """

@@ -1728,14 +1710,21 @@ class EncodeLabels(StreamInstanceOperator):
     ) -> Dict[str, Any]:
         for field_name in self.fields:
             values = dict_get(instance, field_name, use_dpath=True)
+            values_was_a_list = isinstance(values, list)
             if not isinstance(values, list):
                 values = [values]
             for value in values:
                 if value not in self.encoder:
                     self.encoder[value] = len(self.encoder)
             new_values = [self.encoder[value] for value in values]
+            if not values_was_a_list:
+                new_values = new_values[0]
             dict_set(
-                instance,
+                instance,
+                field_name,
+                new_values,
+                use_dpath=True,
+                set_multiple="*" in field_name,
             )
 
         return instance

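With this change EncodeLabels tracks whether a field originally held a scalar and, if so, writes the encoded value back as a plain int instead of a one-element list; the dict_set call also spells out its arguments explicitly. A standalone sketch of the resulting behavior (the encoder dictionary is shared across fields, as in the operator; nested-path handling via dict_get/dict_set and the "/*" suffix is omitted here):

    encoder = {}

    def encode_field(instance, field):
        values = instance[field]
        values_was_a_list = isinstance(values, list)
        if not values_was_a_list:
            values = [values]
        for value in values:
            if value not in encoder:
                encoder[value] = len(encoder)
        new_values = [encoder[value] for value in values]
        # scalar fields now get a scalar code back, not a one-element list
        instance[field] = new_values if values_was_a_list else new_values[0]

    instance = {"a": "red", "b": ["red", "green"]}
    encode_field(instance, "a")
    encode_field(instance, "b")
    print(instance)  # {'a': 0, 'b': [0, 1]}
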
@@ -1904,3 +1893,40 @@ class ExtractZipFile(SideEffectOperator):
     def process(self):
         with zipfile.ZipFile(self.zip_file) as zf:
             zf.extractall(self.target_dir)
+
+
+class DuplicateInstances(SingleStreamOperator):
+    """Operator which duplicates each instance in stream a given number of times.
+
+    Attributes:
+        num_duplications (int): How many times each instance should be duplicated (1 means no duplication).
+        duplication_index_field (Optional[str]):
+            If given, then additional field with specified name is added to each duplicated instance,
+            which contains id of a given duplication. Defaults to None, so no field is added.
+    """
+
+    num_duplications: int
+    duplication_index_field: Optional[str] = None
+
+    def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
+        for instance in stream:
+            for idx in range(self.num_duplications):
+                duplicate = deepcopy(instance)
+                if self.duplication_index_field:
+                    duplicate.update({self.duplication_index_field: idx})
+                yield duplicate
+
+    def verify(self):
+        if not isinstance(self.num_duplications, int) or self.num_duplications < 1:
+            raise ValueError(
+                f"num_duplications must be an integer equal to or greater than 1. "
+                f"Got: {self.num_duplications}."
+            )
+
+        if self.duplication_index_field is not None and not isinstance(
+            self.duplication_index_field, str
+        ):
+            raise ValueError(
+                f"If given, duplication_index_field must be a string. "
+                f"Got: {self.duplication_index_field}"
+            )
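
The new DuplicateInstances operator simply repeats every instance num_duplications times in order, optionally stamping each copy with its duplication index (num_duplications=1 leaves the stream unchanged). A quick sketch of the expected output, using a plain list in place of the library's Stream type:

    from copy import deepcopy

    def duplicate_instances(stream, num_duplications, duplication_index_field=None):
        # mirrors DuplicateInstances.process(): each instance is deep-copied
        # num_duplications times, and the copy's index is optionally recorded
        for instance in stream:
            for idx in range(num_duplications):
                duplicate = deepcopy(instance)
                if duplication_index_field:
                    duplicate[duplication_index_field] = idx
                yield duplicate

    stream = [{"text": "hello"}, {"text": "world"}]
    print(list(duplicate_instances(stream, 2, "dup_id")))
    # [{'text': 'hello', 'dup_id': 0}, {'text': 'hello', 'dup_id': 1},
    #  {'text': 'world', 'dup_id': 0}, {'text': 'world', 'dup_id': 1}]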