# wordle-solver / a3c/train.py
# Last change: c412087 "Fix code style with black and isort" (santit96)
import os
import queue
import random

import numpy as np
import torch
import torch.multiprocessing as mp

from .net import Net
from .shared_adam import SharedAdam
from .worker import Worker
def _set_seed(seed: int = 100) -> None:
    """Seed the global random number generators used during training.

    Seeds NumPy's global generator, Python's ``random`` module and PyTorch.
    When CUDA is available it also seeds the current CUDA device and
    switches cuDNN to deterministic, non-benchmarking mode. Finally it
    exports ``PYTHONHASHSEED``, which only influences Python interpreters
    started after this call (e.g. spawned worker processes).

    Args:
        seed: Value fed to every generator. Defaults to 100.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # cuDNN needs two extra switches before it behaves deterministically.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Fixed hash seed, inherited by any interpreter launched from here on.
    os.environ["PYTHONHASHSEED"] = str(seed)
def train(
    env,
    max_ep,
    model_checkpoint_dir,
    gamma=0.0,
    seed=100,
    pretrained_model_path=None,
    save=False,
    min_reward=9.9,
    every_n_save=100,
    n_workers=None,
):
    """Train a shared global network with parallel A3C-style workers.

    A global ``Net`` (optionally initialised from ``pretrained_model_path``)
    and a ``SharedAdam`` optimizer are placed in shared memory. Then
    ``n_workers`` ``Worker`` instances are started. The rewards they push
    onto a shared queue are collected until the first ``None`` sentinel
    arrives, or until every worker has exited. All workers are then joined
    while the queue keeps being drained.

    Args:
        env: Environment exposing ``observation_space.shape``,
            ``action_space.n``, ``words`` and ``unwrapped``. It is passed
            unchanged to every worker.
        max_ep: Forwarded to every worker (presumably the episode budget).
        model_checkpoint_dir: Directory for checkpoints; created if missing.
        gamma: Forwarded to every worker (presumably the discount factor).
        seed: Seed for the Python, NumPy and PyTorch generators.
        pretrained_model_path: Optional state-dict file. It is loaded into
            the global network and also forwarded to every worker.
        save: If true, the final global network is saved into
            ``model_checkpoint_dir``. Also forwarded to every worker.
        min_reward: Forwarded to every worker.
        every_n_save: Forwarded to every worker.
        n_workers: Number of workers to start; defaults to
            ``mp.cpu_count()``.

    Returns:
        Tuple ``(global_ep, win_ep, gnet, res)``:
        - ``global_ep`` and ``win_ep``: the shared ``mp.Value`` integers
          handed to the workers (read them via ``.value``).
        - ``gnet``: the global network.
        - ``res``: the list of non-``None`` items (episode rewards) received
          from the workers.

    Raises:
        ValueError: If ``n_workers`` is less than 1.
    """
    if n_workers is None:
        n_workers = mp.cpu_count()
    if n_workers < 1:
        raise ValueError(f"n_workers must be >= 1, got {n_workers}")

    # Ask OpenMP for one thread per process. This likely only affects
    # processes/libraries that initialise OpenMP after this assignment.
    os.environ["OMP_NUM_THREADS"] = "1"
    # exist_ok avoids the race of a separate exists() check.
    os.makedirs(model_checkpoint_dir, exist_ok=True)

    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    word_width = len(words_list[0])

    # Set global seeds for randoms
    _set_seed(seed)

    gnet = Net(n_s, n_a, words_list, word_width)  # global network
    if pretrained_model_path:
        # map_location="cpu" keeps GPU-saved checkpoints loadable on CPU-only
        # hosts; load_state_dict copies the tensors into gnet either way.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        gnet.load_state_dict(state_dict)
    gnet.share_memory()  # share the global parameters in multiprocessing
    opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))  # global optimizer

    global_ep = mp.Value("i", 0)
    global_ep_r = mp.Value("d", 0.0)
    res_queue = mp.Queue()
    win_ep = mp.Value("i", 0)

    # parallel training
    workers = [
        Worker(
            max_ep,
            gnet,
            opt,
            global_ep,
            global_ep_r,
            res_queue,
            i,
            env,
            n_s,
            n_a,
            words_list,
            word_width,
            win_ep,
            model_checkpoint_dir,
            gamma,
            pretrained_model_path,
            save,
            min_reward,
            every_n_save,
        )
        for i in range(n_workers)
    ]
    for worker in workers:
        worker.start()

    res = _collect_results(workers, res_queue)  # episode rewards to plot
    _join_workers(workers, res_queue, res)

    if save:
        torch.save(
            gnet.state_dict(),
            os.path.join(model_checkpoint_dir, f"model_{_checkpoint_env_id(env)}.pth"),
        )
    return global_ep, win_ep, gnet, res


def _collect_results(workers, res_queue, poll_timeout=1.0):
    """Read items from ``res_queue`` until a ``None`` sentinel or until all workers are dead.

    A plain blocking ``get()`` would hang forever if every worker exited
    (for example, by crashing) without sending the sentinel. Polling with a
    timeout lets us notice that case.

    NOTE(review): assumes ``Worker`` follows the ``multiprocessing.Process``
    interface (``start``/``join``/``is_alive``).
    """
    res = []
    while True:
        try:
            item = res_queue.get(timeout=poll_timeout)
        except queue.Empty:
            if not any(worker.is_alive() for worker in workers):
                break  # nobody left to send the sentinel
            continue
        if item is None:
            break
        res.append(item)
    return res


def _join_workers(workers, res_queue, res, poll_timeout=0.1):
    """Join every worker while draining ``res_queue`` into ``res``.

    A process that has put items on a multiprocessing queue does not
    terminate until its buffered items are flushed to the underlying pipe.
    Joining it without consuming the queue can therefore deadlock. Late
    items are appended to ``res``; extra ``None`` sentinels are dropped.
    """
    for worker in workers:
        worker.join(timeout=poll_timeout)
        while worker.is_alive():
            _drain_nowait(res_queue, res)
            worker.join(timeout=poll_timeout)
    _drain_nowait(res_queue, res)


def _drain_nowait(res_queue, res):
    """Move every item currently readable from ``res_queue`` into ``res``, skipping ``None``."""
    while True:
        try:
            item = res_queue.get_nowait()
        except queue.Empty:
            return
        if item is not None:
            res.append(item)


def _checkpoint_env_id(env):
    """Return the registered id of ``env`` or, when it has no spec, its class name."""
    unwrapped = env.unwrapped
    spec = getattr(unwrapped, "spec", None)
    return spec.id if spec is not None else type(unwrapped).__name__