File size: 862 Bytes
9aa735a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch
import numpy as np
import random
def worker_init_fn(wid):
    """Seed torch, NumPy and ``random`` for one worker.

    A root ``np.random.SeedSequence`` is built from ``torch.initial_seed()``
    and ``wid``. Three children are then spawned from it, in a fixed order:
    torch, then NumPy, then ``random``. Each library therefore gets its own
    reproducible seed.

    Presumably this is passed as a ``torch.utils.data.DataLoader``
    ``worker_init_fn``, with ``wid`` being the worker id; verify against
    callers.

    Args:
        wid: Worker identifier mixed into the root entropy.
    """
    root = np.random.SeedSequence([torch.initial_seed(), wid])

    # The spawn order determines which child feeds which library, so keep
    # it stable: torch first, NumPy second, ``random`` third.
    torch_seed = spawn_get(root, 2, dtype=int)
    numpy_seed = spawn_get(root, 2, dtype=np.ndarray)
    python_seed = spawn_get(root, 2, dtype=int)

    torch.random.manual_seed(torch_seed)
    np.random.seed(numpy_seed)
    random.seed(python_seed)
def spawn_get(seedseq, n_entropy, dtype):
    """Spawn one child of *seedseq* and return its generated state.

    Every call spawns a new child, because ``SeedSequence.spawn`` advances
    the parent's spawn counter. Repeated calls on the same parent therefore
    return different values.

    Args:
        seedseq: Parent ``np.random.SeedSequence``.
        n_entropy: Number of 32-bit words to generate from the child.
        dtype: Selects the return form:
            - ``np.ndarray``: return the raw ``uint32`` words.
            - ``int``: pack the words little-endian into one non-negative
              Python int. Word ``i`` occupies bits ``32*i`` to
              ``32*i + 31``.

    Returns:
        A ``uint32`` array or a Python int, depending on ``dtype``.

    Raises:
        ValueError: If ``dtype`` is neither ``np.ndarray`` nor ``int``.
    """
    child = seedseq.spawn(1)[0]
    state = child.generate_state(n_entropy, dtype=np.uint32)
    if dtype == np.ndarray:
        return state
    if dtype == int:
        # Convert each word to a Python int *before* scaling it. Multiplying
        # an np.uint32 by a large Python int raises OverflowError under
        # NumPy 2 (NEP 50) and can silently wrap on older NumPy.
        return sum(int(word) << (32 * i) for i, word in enumerate(state))
    raise ValueError(f'not a valid dtype "{dtype}"')
|