""" | |
Worker class implementation of the a3c discrete algorithm | |
""" | |
import os

import numpy as np
import torch
import torch.multiprocessing as mp
from torch import nn

from .net import Net
from .utils import v_wrap


class Worker(mp.Process):
    def __init__(
        self,
        max_ep,
        gnet,
        opt,
        global_ep,
        global_ep_r,
        res_queue,
        name,
        env,
        N_S,
        N_A,
        words_list,
        word_width,
        winning_ep,
        model_checkpoint_dir,
        gamma=0.0,
        pretrained_model_path=None,
        save=False,
        min_reward=9.9,
        every_n_save=100,
    ):
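        """
        One A3C worker process.

        gnet / opt: the shared global network and optimizer.
        global_ep / global_ep_r / winning_ep: shared counters for the episode
            count, the running episode reward, and the number of episodes in
            which the goal word was guessed.
        res_queue: receives the running reward after each episode; a trailing
            None marks worker completion.
        gamma: discount factor used for the value targets.
        save / min_reward / every_n_save: control periodic checkpointing of
            the global network to model_checkpoint_dir.
        """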
        super(Worker, self).__init__()
        self.max_ep = max_ep
        self.name = "w%02i" % name
        # shared (global) counters and result queue
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        self.winning_ep = winning_ep
        # global (shared) network and optimizer
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        # local network
        self.lnet = Net(N_S, N_A, words_list, word_width)
        if pretrained_model_path:
            self.lnet.load_state_dict(torch.load(pretrained_model_path))
        self.env = env.unwrapped
        self.gamma = gamma
        self.model_checkpoint_dir = model_checkpoint_dir
        self.save = save
        self.min_reward = min_reward
        self.every_n_save = every_n_save

    def run(self):
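        """Play episodes until the shared episode counter reaches max_ep,
        updating the global network at the end of every episode."""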
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.0
            while True:
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)
                if done:  # update the global net and sync the local net
                    self.push_and_pull(done, s_, buffer_s, buffer_a, buffer_r)
                    goal_word = self.word_list[self.env.goal_word]
                    self.record(ep_r, goal_word, self.word_list[a], len(buffer_a))
                    self.save_model()
                    buffer_s, buffer_a, buffer_r = [], [], []
                    break
                s = s_
        # signal to the consumer of res_queue that this worker is done
        self.res_queue.put(None)

    def push_and_pull(self, done, s_, bs, ba, br):
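        """Compute discounted value targets for the buffered transitions,
        back-propagate the loss on the local network, copy the local gradients
        onto the global network, step the shared optimizer, and pull the
        updated global parameters back into the local network."""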
        if done:
            v_s_ = 0.0  # terminal state has zero value
        else:
            v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]

        # discounted value targets, computed backwards over the episode
        buffer_v_target = []
        for r in br[::-1]:  # reverse buffer r
            v_s_ = r + self.gamma * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()

        loss = self.lnet.loss_func(
            v_wrap(np.vstack(bs)),
            v_wrap(np.array(ba), dtype=np.int64)
            if ba[0].dtype == np.int64
            else v_wrap(np.vstack(ba)),
            v_wrap(np.array(buffer_v_target)[:, None]),
        )

        # calculate local gradients and push them to the global network
        self.opt.zero_grad()
        loss.backward()
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
            gp._grad = lp.grad
        self.opt.step()

        # pull global parameters
        self.lnet.load_state_dict(self.gnet.state_dict())

    def save_model(self):
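        """Checkpoint the global network every `every_n_save` episodes once the
        running reward has reached `min_reward`, if saving is enabled."""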
        if (
            self.save
            and self.g_ep_r.value >= self.min_reward
            and self.g_ep.value % self.every_n_save == 0
        ):
            torch.save(
                self.gnet.state_dict(),
                os.path.join(self.model_checkpoint_dir, f"model_{self.g_ep.value}.pth"),
            )

    def record(self, ep_r, goal_word, action, action_number):
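        """Update the shared episode counters, push the running reward to
        res_queue, and periodically print progress."""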
        with self.g_ep.get_lock():
            self.g_ep.value += 1
        with self.g_ep_r.get_lock():
            # exponential moving average of the episode reward
            if self.g_ep_r.value == 0.0:
                self.g_ep_r.value = ep_r
            else:
                self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
        self.res_queue.put(self.g_ep_r.value)
        if goal_word == action:
            self.winning_ep.value += 1
        if self.g_ep.value % 100 == 0:
            print(
                self.name,
                "Ep:",
                self.g_ep.value,
                "| Ep_r: %.0f" % self.g_ep_r.value,
                "| Goal :",
                goal_word,
                "| Action: ",
                action,
                "| Actions: ",
                action_number,
            )
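

# A minimal usage sketch (illustrative only): names such as `SharedAdam` and
# `make_env` are assumptions and are not provided by this module.
#
#     gnet = Net(N_S, N_A, words_list, word_width)
#     gnet.share_memory()                         # share weights across processes
#     opt = SharedAdam(gnet.parameters())         # an optimizer with shared state (assumed)
#     global_ep, global_ep_r = mp.Value("i", 0), mp.Value("d", 0.0)
#     winning_ep, res_queue = mp.Value("i", 0), mp.Queue()
#     workers = [
#         Worker(max_ep, gnet, opt, global_ep, global_ep_r, res_queue, i, make_env(),
#                N_S, N_A, words_list, word_width, winning_ep, "checkpoints", gamma=0.99)
#         for i in range(mp.cpu_count())
#     ]
#     [w.start() for w in workers]
#     [w.join() for w in workers]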