"""
Worker class implementation of the a3c discrete algorithm
"""
import os

import numpy as np
import torch
import torch.multiprocessing as mp
from torch import nn

from .net import Net
from .utils import v_wrap


class Worker(mp.Process):
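    """A3C worker process.

    Each worker interacts with its own environment using a local copy of the
    network (`lnet`) and periodically synchronises with the shared global
    network (`gnet`) through the shared optimizer (`opt`).
    """
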
    def __init__(
        self,
        max_ep,
        gnet,
        opt,
        global_ep,
        global_ep_r,
        res_queue,
        name,
        env,
        N_S,
        N_A,
        words_list,
        word_width,
        winning_ep,
        model_checkpoint_dir,
        gamma=0.0,
        pretrained_model_path=None,
        save=False,
        min_reward=9.9,
        every_n_save=100,
    ):
        super(Worker, self).__init__()
        self.max_ep = max_ep
        self.name = "w%02i" % name
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        self.winning_ep = winning_ep
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        # local network
        self.lnet = Net(N_S, N_A, words_list, word_width)
        if pretrained_model_path:
            self.lnet.load_state_dict(torch.load(pretrained_model_path))
        self.env = env.unwrapped
        self.gamma = gamma
        self.model_checkpoint_dir = model_checkpoint_dir
        self.save = save
        self.min_reward = min_reward
        self.every_n_save = every_n_save

    def run(self):
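        """Play episodes until the shared episode counter reaches `max_ep`.

        After each finished episode the worker pushes its gradients to the
        global network, records the result, and optionally checkpoints the
        global model. When the counter reaches `max_ep` it signals completion
        by putting None on `res_queue`.
        """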
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.0
            while True:
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)
                if done:  # update global and assign to local net
                    # sync
                    self.push_and_pull(done, s_, buffer_s, buffer_a, buffer_r)
                    goal_word = self.word_list[self.env.goal_word]
                    self.record(ep_r, goal_word, self.word_list[a], len(buffer_a))
                    self.save_model()
                    buffer_s, buffer_a, buffer_r = [], [], []
                    break
                s = s_
        self.res_queue.put(None)

    def push_and_pull(self, done, s_, bs, ba, br):
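        """Compute discounted value targets for the episode buffers, push the
        local gradients into the global network via the shared optimizer, and
        pull the updated global weights back into the local network."""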
        if done:
            v_s_ = 0.0  # terminal
        else:
            v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]

        buffer_v_target = []
        for r in br[::-1]:  # reverse buffer r
            v_s_ = r + self.gamma * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()

        loss = self.lnet.loss_func(
            v_wrap(np.vstack(bs)),
            v_wrap(np.array(ba), dtype=np.int64)
            if ba[0].dtype == np.int64
            else v_wrap(np.vstack(ba)),
            v_wrap(np.array(buffer_v_target)[:, None]),
        )

        # calculate local gradients and push local parameters to global
        self.opt.zero_grad()
        loss.backward()
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
            gp._grad = lp.grad
        self.opt.step()

        # pull global parameters
        self.lnet.load_state_dict(self.gnet.state_dict())

    def save_model(self):
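        """Checkpoint the global network when saving is enabled, the running
        reward is at least `min_reward`, and the global episode count is a
        multiple of `every_n_save`."""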
        if (
            self.save
            and self.g_ep_r.value >= self.min_reward
            and self.g_ep.value % self.every_n_save == 0
        ):
            torch.save(
                self.gnet.state_dict(),
                os.path.join(self.model_checkpoint_dir, f"model_{self.g_ep.value}.pth"),
            )

    def record(self, ep_r, goal_word, action, action_number):
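        """Update the shared episode counter and running reward, publish the
        running reward on `res_queue`, count winning episodes, and print
        progress every 100 episodes."""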
        with self.g_ep.get_lock():
            self.g_ep.value += 1
        with self.g_ep_r.get_lock():
            if self.g_ep_r.value == 0.0:
                self.g_ep_r.value = ep_r
            else:
                self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
        self.res_queue.put(self.g_ep_r.value)
        if goal_word == action:
            # guard the shared win counter with its lock, as is done for the
            # episode counters above
            with self.winning_ep.get_lock():
                self.winning_ep.value += 1
        if self.g_ep.value % 100 == 0:
            print(
                self.name,
                "Ep:",
                self.g_ep.value,
                "| Ep_r: %.0f" % self.g_ep_r.value,
                "| Goal :",
                goal_word,
                "| Action: ",
                action,
                "| Actions: ",
                action_number,
            )
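
# Illustrative usage sketch (not part of this module): a training script would
# typically build a shared global Net, move it to shared memory, and spawn one
# Worker per CPU core. The names `SharedAdam`, `make_env`, `words`,
# `WORD_WIDTH`, `N_S`, and `N_A` below are assumptions for illustration only.
#
#     gnet = Net(N_S, N_A, words, WORD_WIDTH)
#     gnet.share_memory()
#     opt = SharedAdam(gnet.parameters(), lr=1e-4)
#     global_ep = mp.Value("i", 0)
#     global_ep_r = mp.Value("d", 0.0)
#     win_ep = mp.Value("i", 0)
#     res_queue = mp.Queue()
#     workers = [
#         Worker(
#             max_ep=10_000, gnet=gnet, opt=opt, global_ep=global_ep,
#             global_ep_r=global_ep_r, res_queue=res_queue, name=i,
#             env=make_env(), N_S=N_S, N_A=N_A, words_list=words,
#             word_width=WORD_WIDTH, winning_ep=win_ep,
#             model_checkpoint_dir="checkpoints",
#         )
#         for i in range(mp.cpu_count())
#     ]
#     for w in workers:
#         w.start()
#     for w in workers:
#         w.join()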