# wordle-solver/a3c/worker.py
"""
Worker class implementation of the a3c discrete algorithm
"""
import os
import numpy as np
import torch
import torch.multiprocessing as mp
from torch import nn
from .net import Net
from .utils import v_wrap
class Worker(mp.Process):
def __init__(
self,
max_ep,
gnet,
opt,
global_ep,
global_ep_r,
res_queue,
name,
env,
N_S,
N_A,
words_list,
word_width,
winning_ep,
model_checkpoint_dir,
gamma=0.0,
pretrained_model_path=None,
save=False,
min_reward=9.9,
every_n_save=100,
):
super(Worker, self).__init__()
self.max_ep = max_ep
self.name = "w%02i" % name
self.g_ep = global_ep
self.g_ep_r = global_ep_r
self.res_queue = res_queue
self.winning_ep = winning_ep
self.gnet, self.opt = gnet, opt
self.word_list = words_list
# local network
self.lnet = Net(N_S, N_A, words_list, word_width)
if pretrained_model_path:
self.lnet.load_state_dict(torch.load(pretrained_model_path))
self.env = env.unwrapped
self.gamma = gamma
self.model_checkpoint_dir = model_checkpoint_dir
self.save = save
self.min_reward = min_reward
self.every_n_save = every_n_save

    def run(self):
while self.g_ep.value < self.max_ep:
s = self.env.reset()
buffer_s, buffer_a, buffer_r = [], [], []
ep_r = 0.0
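            # step the environment until the episode terminates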
while True:
a = self.lnet.choose_action(v_wrap(s[None, :]))
s_, r, done, _ = self.env.step(a)
ep_r += r
buffer_a.append(a)
buffer_s.append(s)
buffer_r.append(r)
if done: # update global and assign to local net
# sync
self.push_and_pull(done, s_, buffer_s, buffer_a, buffer_r)
goal_word = self.word_list[self.env.goal_word]
self.record(ep_r, goal_word, self.word_list[a], len(buffer_a))
self.save_model()
buffer_s, buffer_a, buffer_r = [], [], []
break
s = s_
        self.res_queue.put(None)  # sentinel: tell the parent process this worker is done

    def push_and_pull(self, done, s_, bs, ba, br):
if done:
v_s_ = 0.0 # terminal
else:
v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
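        # walk the reward buffer backwards, bootstrapping from v(s_),
        # to build the discounted value targets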
buffer_v_target = []
for r in br[::-1]: # reverse buffer r
v_s_ = r + self.gamma * v_s_
buffer_v_target.append(v_s_)
buffer_v_target.reverse()
loss = self.lnet.loss_func(
v_wrap(np.vstack(bs)),
v_wrap(np.array(ba), dtype=np.int64)
if ba[0].dtype == np.int64
else v_wrap(np.vstack(ba)),
v_wrap(np.array(buffer_v_target)[:, None]),
)
# calculate local gradients and push local parameters to global
self.opt.zero_grad()
loss.backward()
for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
gp._grad = lp.grad
self.opt.step()
# pull global parameters
self.lnet.load_state_dict(self.gnet.state_dict())

    def save_model(self):
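        # checkpoint the global network at the configured episode interval,
        # but only once the running reward clears the threshold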
if (
self.save
and self.g_ep_r.value >= self.min_reward
and self.g_ep.value % self.every_n_save == 0
):
torch.save(
self.gnet.state_dict(),
os.path.join(self.model_checkpoint_dir, f"model_{self.g_ep.value}.pth"),
)

    def record(self, ep_r, goal_word, action, action_number):
with self.g_ep.get_lock():
self.g_ep.value += 1
with self.g_ep_r.get_lock():
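            # maintain an exponential moving average of the episode reward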
if self.g_ep_r.value == 0.0:
self.g_ep_r.value = ep_r
else:
self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
self.res_queue.put(self.g_ep_r.value)
        if goal_word == action:
            with self.winning_ep.get_lock():
                self.winning_ep.value += 1
if self.g_ep.value % 100 == 0:
            print(
                self.name,
                "Ep:",
                self.g_ep.value,
                "| Ep_r: %.0f" % self.g_ep_r.value,
                "| Goal:",
                goal_word,
                "| Action:",
                action,
                "| Actions:",
                action_number,
            )