Upload split_utils.py with huggingface_hub
Browse files- split_utils.py +27 -26
split_utils.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
| 1 |
import itertools
|
| 2 |
-
import logging
|
| 3 |
import re
|
| 4 |
from typing import Dict
|
| 5 |
|
| 6 |
from .generator_utils import ReusableGenerator
|
| 7 |
-
from .
|
|
|
|
| 8 |
from .stream import Stream
|
| 9 |
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def parse_random_mix_string(input_str):
|
| 12 |
"""Parses a string of format "source1[percentage1%]+source2[value2]+..." and returns a dictionary.
|
|
@@ -179,7 +181,7 @@ def build_stream_routing(mapping):
|
|
| 179 |
}
|
| 180 |
}
|
| 181 |
stream_mapping = build_stream_mapping(mapping)
|
| 182 |
-
|
| 183 |
# Output: {'my_old_stream1': (['my_new_stream', 'my_new_stream2'], [0.6, 0.4]),
|
| 184 |
# 'my_old_stream2': (['my_new_stream', 'my_new_stream2'], [0.2, 0.8])}
|
| 185 |
"""
|
|
@@ -230,14 +232,14 @@ def random_mix_generator(
|
|
| 230 |
):
|
| 231 |
for old_stream_name in new_stream_sources:
|
| 232 |
optinal_streams, weights = stream_routing[old_stream_name]
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
|
| 242 |
|
| 243 |
def random_mix_streams(input_streams, mapping):
|
|
@@ -276,30 +278,29 @@ def random_mix_streams(input_streams, mapping):
|
|
| 276 |
}
|
| 277 |
new_streams = create_streams(input_streams, mapping)
|
| 278 |
for new_stream_name, new_stream in new_streams.items():
|
| 279 |
-
|
| 280 |
for _, item in zip(range(10), new_stream):
|
| 281 |
-
|
| 282 |
"""
|
| 283 |
new_streams = {}
|
| 284 |
|
| 285 |
# Build stream routing
|
| 286 |
stream_routing = build_stream_routing(mapping)
|
| 287 |
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
)
|
| 300 |
|
| 301 |
return new_streams
|
| 302 |
|
| 303 |
|
| 304 |
if __name__ == "__main__":
|
| 305 |
-
|
|
|
|
| 1 |
import itertools
|
|
|
|
| 2 |
import re
|
| 3 |
from typing import Dict
|
| 4 |
|
| 5 |
from .generator_utils import ReusableGenerator
|
| 6 |
+
from .logging_utils import get_logger
|
| 7 |
+
from .random_utils import new_random_generator
|
| 8 |
from .stream import Stream
|
| 9 |
|
| 10 |
+
logger = get_logger()
|
| 11 |
+
|
| 12 |
|
| 13 |
def parse_random_mix_string(input_str):
|
| 14 |
"""Parses a string of format "source1[percentage1%]+source2[value2]+..." and returns a dictionary.
|
|
|
|
| 181 |
}
|
| 182 |
}
|
| 183 |
stream_mapping = build_stream_mapping(mapping)
|
| 184 |
+
logger.info(stream_mapping)
|
| 185 |
# Output: {'my_old_stream1': (['my_new_stream', 'my_new_stream2'], [0.6, 0.4]),
|
| 186 |
# 'my_old_stream2': (['my_new_stream', 'my_new_stream2'], [0.2, 0.8])}
|
| 187 |
"""
|
|
|
|
| 232 |
):
|
| 233 |
for old_stream_name in new_stream_sources:
|
| 234 |
optinal_streams, weights = stream_routing[old_stream_name]
|
| 235 |
+
random_generator = new_random_generator(sub_seed=old_stream_name)
|
| 236 |
+
assert (
|
| 237 |
+
old_stream_name in input_streams
|
| 238 |
+
), f"'{old_stream_name}' split not found. Possibles options: {input_streams.keys()}"
|
| 239 |
+
for item in input_streams[old_stream_name]:
|
| 240 |
+
choice = random_generator.choices(optinal_streams, weights=weights, k=1)[0]
|
| 241 |
+
if choice == new_stream_name:
|
| 242 |
+
yield item
|
| 243 |
|
| 244 |
|
| 245 |
def random_mix_streams(input_streams, mapping):
|
|
|
|
| 278 |
}
|
| 279 |
new_streams = create_streams(input_streams, mapping)
|
| 280 |
for new_stream_name, new_stream in new_streams.items():
|
| 281 |
+
logger.info(f"{new_stream_name}:")
|
| 282 |
for _, item in zip(range(10), new_stream):
|
| 283 |
+
logger.info(item)
|
| 284 |
"""
|
| 285 |
new_streams = {}
|
| 286 |
|
| 287 |
# Build stream routing
|
| 288 |
stream_routing = build_stream_routing(mapping)
|
| 289 |
|
| 290 |
+
# Create new stream generators
|
| 291 |
+
for new_stream_name, new_stream_sources in mapping.items():
|
| 292 |
+
new_streams[new_stream_name] = ReusableGenerator(
|
| 293 |
+
random_mix_generator,
|
| 294 |
+
gen_kwargs={
|
| 295 |
+
"new_stream_name": new_stream_name,
|
| 296 |
+
"new_stream_sources": new_stream_sources,
|
| 297 |
+
"stream_routing": stream_routing,
|
| 298 |
+
"input_streams": input_streams,
|
| 299 |
+
},
|
| 300 |
+
)
|
|
|
|
| 301 |
|
| 302 |
return new_streams
|
| 303 |
|
| 304 |
|
| 305 |
if __name__ == "__main__":
|
| 306 |
+
logger.info(parse_random_mix_string("dale[90%]+oren[0.7]+mike"))
|