| import contextlib | |
| import random as python_random | |
| import string | |
| import threading | |
# Nominal seed reported by get_seed() until set_seed() is called.
__default_seed__ = 42


class _ThreadLocalState(threading.local):
    """Per-thread seed bookkeeping and random generator.

    Attributes assigned to a plain ``threading.local()`` at import time exist
    only in the importing thread; every other thread would hit
    ``AttributeError`` on first access. ``threading.local`` calls a subclass's
    ``__init__`` the first time the object is used in each thread, so every
    thread starts with the same defaults.
    """

    def __init__(self):
        super().__init__()
        # Seed value recorded for this thread.
        self.seed = __default_seed__
        # NOTE: the generator is created unseeded (system entropy), exactly as
        # before; it only becomes reproducible after set_seed() is called.
        self.random = python_random.Random()


_thread_local = _ThreadLocalState()

# Public alias kept for backward compatibility: this is the generator that
# belongs to the thread that imported the module.
random = _thread_local.random
def set_seed(seed):
    """Reseed the thread-local generator and record ``seed`` as the current seed.

    The generator is reseeded before the value is recorded, so if seeding
    raises, the previously recorded seed is left untouched.
    """
    local_state = _thread_local
    local_state.random.seed(seed)
    local_state.seed = seed
def get_seed():
    """Return the seed currently recorded in the thread-local state."""
    current_seed = _thread_local.seed
    return current_seed
def get_random_string(length):
    """Return a string of ``length`` characters drawn from ASCII letters.

    Characters are picked one at a time with ``random.choice`` using the
    module-level ``random`` object, so the result follows that generator's
    current state. A ``length`` of zero (or less) yields an empty string.
    """
    alphabet = string.ascii_letters
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
@contextlib.contextmanager
def nested_seed(sub_seed=None):
    """Temporarily reseed the thread-local generator with a derived seed.

    The derived seed is ``"<current seed>/<sub_seed>"``. On exit, including
    when the body raises, the previously recorded seed and the exact
    generator state from before entry are restored.

    Args:
        sub_seed: Value appended to the current seed. Any object is accepted
            and converted with ``str()``. When ``None``, a 10-letter random
            string from ``get_random_string`` is used instead.

    Yields:
        The thread-local ``random.Random`` instance, seeded with the derived
        seed.
    """
    # Snapshot first: the random draws made below for a default sub-seed are
    # rolled back together with everything the body consumes.
    saved_state = _thread_local.random.getstate()
    old_global_seed = get_seed()
    # Compare against None explicitly so falsy sub-seeds such as 0 or ""
    # are honoured instead of being silently replaced by a random string.
    if sub_seed is None:
        sub_seed = get_random_string(10)
    # Formatting (rather than "+") lets non-string sub-seeds work too.
    new_global_seed = f"{old_global_seed}/{sub_seed}"
    set_seed(new_global_seed)
    try:
        yield _thread_local.random
    finally:
        # set_seed restores the recorded seed value; setstate then puts the
        # generator back exactly where it was before entering the block.
        set_seed(old_global_seed)
        _thread_local.random.setstate(saved_state)