Upload stream.py with huggingface_hub
Browse files
    	
        stream.py
    CHANGED
    
    | @@ -1,4 +1,3 @@ | |
| 1 | 
            -
            from copy import deepcopy
         | 
| 2 | 
             
            from typing import Dict, Iterable
         | 
| 3 |  | 
| 4 | 
             
            from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
         | 
| @@ -31,11 +30,11 @@ class Stream(Dataclass): | |
| 31 | 
             
                    """
         | 
| 32 | 
             
                    if self.caching:
         | 
| 33 | 
             
                        return Dataset.from_generator
         | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 |  | 
| 40 | 
             
                def _get_stream(self):
         | 
| 41 | 
             
                    """Private method to get the stream based on the initiator function.
         | 
| @@ -102,12 +101,20 @@ class MultiStream(dict): | |
| 102 |  | 
| 103 | 
             
                def to_dataset(self) -> DatasetDict:
         | 
| 104 | 
             
                    return DatasetDict(
         | 
| 105 | 
            -
                        { | 
|  | |
|  | |
|  | |
| 106 | 
             
                    )
         | 
| 107 |  | 
| 108 | 
             
                def to_iterable_dataset(self) -> IterableDatasetDict:
         | 
| 109 | 
             
                    return IterableDatasetDict(
         | 
| 110 | 
            -
                        { | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 111 | 
             
                    )
         | 
| 112 |  | 
| 113 | 
             
                def __setitem__(self, key, value):
         | 
| @@ -116,17 +123,19 @@ class MultiStream(dict): | |
| 116 | 
             
                    super().__setitem__(key, value)
         | 
| 117 |  | 
| 118 | 
             
                @classmethod
         | 
| 119 | 
            -
                def from_generators( | 
|  | |
|  | |
| 120 | 
             
                    """Creates a MultiStream from a dictionary of ReusableGenerators.
         | 
| 121 |  | 
| 122 | 
             
                    Args:
         | 
| 123 | 
             
                        generators (Dict[str, ReusableGenerator]): A dictionary of ReusableGenerators.
         | 
| 124 | 
             
                        caching (bool, optional): Whether the data should be cached or not. Defaults to False.
         | 
|  | |
| 125 |  | 
| 126 | 
             
                    Returns:
         | 
| 127 | 
             
                        MultiStream: A MultiStream object.
         | 
| 128 | 
             
                    """
         | 
| 129 | 
            -
             | 
| 130 | 
             
                    assert all(isinstance(v, ReusableGenerator) for v in generators.values())
         | 
| 131 | 
             
                    return cls(
         | 
| 132 | 
             
                        {
         | 
| @@ -141,17 +150,19 @@ class MultiStream(dict): | |
| 141 | 
             
                    )
         | 
| 142 |  | 
| 143 | 
             
                @classmethod
         | 
| 144 | 
            -
                def from_iterables( | 
|  | |
|  | |
| 145 | 
             
                    """Creates a MultiStream from a dictionary of iterables.
         | 
| 146 |  | 
| 147 | 
             
                    Args:
         | 
| 148 | 
             
                        iterables (Dict[str, Iterable]): A dictionary of iterables.
         | 
| 149 | 
             
                        caching (bool, optional): Whether the data should be cached or not. Defaults to False.
         | 
|  | |
| 150 |  | 
| 151 | 
             
                    Returns:
         | 
| 152 | 
             
                        MultiStream: A MultiStream object.
         | 
| 153 | 
             
                    """
         | 
| 154 | 
            -
             | 
| 155 | 
             
                    return cls(
         | 
| 156 | 
             
                        {
         | 
| 157 | 
             
                            key: Stream(
         | 
|  | |
|  | |
| 1 | 
             
            from typing import Dict, Iterable
         | 
| 2 |  | 
| 3 | 
             
            from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
         | 
|  | |
| 30 | 
             
                    """
         | 
| 31 | 
             
                    if self.caching:
         | 
| 32 | 
             
                        return Dataset.from_generator
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    if self.copying:
         | 
| 35 | 
            +
                        return CopyingReusableGenerator
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                    return ReusableGenerator
         | 
| 38 |  | 
| 39 | 
             
                def _get_stream(self):
         | 
| 40 | 
             
                    """Private method to get the stream based on the initiator function.
         | 
|  | |
| 101 |  | 
| 102 | 
             
                def to_dataset(self) -> DatasetDict:
         | 
| 103 | 
             
                    return DatasetDict(
         | 
| 104 | 
            +
                        {
         | 
| 105 | 
            +
                            key: Dataset.from_generator(self.get_generator, gen_kwargs={"key": key})
         | 
| 106 | 
            +
                            for key in self.keys()
         | 
| 107 | 
            +
                        }
         | 
| 108 | 
             
                    )
         | 
| 109 |  | 
| 110 | 
             
                def to_iterable_dataset(self) -> IterableDatasetDict:
         | 
| 111 | 
             
                    return IterableDatasetDict(
         | 
| 112 | 
            +
                        {
         | 
| 113 | 
            +
                            key: IterableDataset.from_generator(
         | 
| 114 | 
            +
                                self.get_generator, gen_kwargs={"key": key}
         | 
| 115 | 
            +
                            )
         | 
| 116 | 
            +
                            for key in self.keys()
         | 
| 117 | 
            +
                        }
         | 
| 118 | 
             
                    )
         | 
| 119 |  | 
| 120 | 
             
                def __setitem__(self, key, value):
         | 
|  | |
| 123 | 
             
                    super().__setitem__(key, value)
         | 
| 124 |  | 
| 125 | 
             
                @classmethod
         | 
| 126 | 
            +
                def from_generators(
         | 
| 127 | 
            +
                    cls, generators: Dict[str, ReusableGenerator], caching=False, copying=False
         | 
| 128 | 
            +
                ):
         | 
| 129 | 
             
                    """Creates a MultiStream from a dictionary of ReusableGenerators.
         | 
| 130 |  | 
| 131 | 
             
                    Args:
         | 
| 132 | 
             
                        generators (Dict[str, ReusableGenerator]): A dictionary of ReusableGenerators.
         | 
| 133 | 
             
                        caching (bool, optional): Whether the data should be cached or not. Defaults to False.
         | 
| 134 | 
            +
                        copying (bool, optional): Whether the data should be copied or not. Defaults to False.
         | 
| 135 |  | 
| 136 | 
             
                    Returns:
         | 
| 137 | 
             
                        MultiStream: A MultiStream object.
         | 
| 138 | 
             
                    """
         | 
|  | |
| 139 | 
             
                    assert all(isinstance(v, ReusableGenerator) for v in generators.values())
         | 
| 140 | 
             
                    return cls(
         | 
| 141 | 
             
                        {
         | 
|  | |
| 150 | 
             
                    )
         | 
| 151 |  | 
| 152 | 
             
                @classmethod
         | 
| 153 | 
            +
                def from_iterables(
         | 
| 154 | 
            +
                    cls, iterables: Dict[str, Iterable], caching=False, copying=False
         | 
| 155 | 
            +
                ):
         | 
| 156 | 
             
                    """Creates a MultiStream from a dictionary of iterables.
         | 
| 157 |  | 
| 158 | 
             
                    Args:
         | 
| 159 | 
             
                        iterables (Dict[str, Iterable]): A dictionary of iterables.
         | 
| 160 | 
             
                        caching (bool, optional): Whether the data should be cached or not. Defaults to False.
         | 
| 161 | 
            +
                        copying (bool, optional): Whether the data should be copied or not. Defaults to False.
         | 
| 162 |  | 
| 163 | 
             
                    Returns:
         | 
| 164 | 
             
                        MultiStream: A MultiStream object.
         | 
| 165 | 
             
                    """
         | 
|  | |
| 166 | 
             
                    return cls(
         | 
| 167 | 
             
                        {
         | 
| 168 | 
             
                            key: Stream(
         | 

