"""
Worker class implementation of the a3c discrete algorithm
"""
import os

import numpy as np
import torch
import torch.multiprocessing as mp
from torch import nn

from .net import Net
from .utils import v_wrap


class Worker(mp.Process):
    def __init__(
        self,
        max_ep,
        gnet,
        opt,
        global_ep,
        global_ep_r,
        res_queue,
        name,
        env,
        N_S,
        N_A,
        words_list,
        word_width,
        winning_ep,
        model_checkpoint_dir,
        gamma=0.0,
        pretrained_model_path=None,
        save=False,
        min_reward=9.9,
        every_n_save=100,
    ):
        super().__init__()
        self.max_ep = max_ep
        self.name = "w%02i" % name
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        self.winning_ep = winning_ep
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        # local network
        self.lnet = Net(N_S, N_A, words_list, word_width)
        if pretrained_model_path:
            self.lnet.load_state_dict(torch.load(pretrained_model_path))
        self.env = env.unwrapped
        self.gamma = gamma
        self.model_checkpoint_dir = model_checkpoint_dir
        self.save = save
        self.min_reward = min_reward
        self.every_n_save = every_n_save

    def run(self):
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.0
            while True:
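                # roll out one episode: sample an action from the local policy,
                # step the environment, and store the transition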
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)

                if done:  # update global and assign to local net
                    # sync
                    self.push_and_pull(done, s_, buffer_s, buffer_a, buffer_r)
                    goal_word = self.word_list[self.env.goal_word]
                    self.record(ep_r, goal_word, self.word_list[a], len(buffer_a))
                    self.save_model()
                    buffer_s, buffer_a, buffer_r = [], [], []
                    break
                s = s_
        self.res_queue.put(None)

    def push_and_pull(self, done, s_, bs, ba, br):
        if done:
            v_s_ = 0.0  # terminal
        else:
            v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]

        buffer_v_target = []
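        # compute discounted return targets backwards from the bootstrap value:
        # target_t = r_t + gamma * target_{t+1}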
        for r in br[::-1]:  # reverse buffer r
            v_s_ = r + self.gamma * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()

        loss = self.lnet.loss_func(
            v_wrap(np.vstack(bs)),
            v_wrap(np.array(ba), dtype=np.int64)
            if ba[0].dtype == np.int64
            else v_wrap(np.vstack(ba)),
            v_wrap(np.array(buffer_v_target)[:, None]),
        )

        # calculate local gradients and push local parameters to global
        self.opt.zero_grad()
        loss.backward()
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
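            # hand the worker's gradients to the shared global parameters so the
            # shared optimizer step below updates the global network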
            gp._grad = lp.grad
        self.opt.step()

        # pull global parameters
        self.lnet.load_state_dict(self.gnet.state_dict())

    def save_model(self):
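        # checkpoint the global network every `every_n_save` episodes once the
        # running episode reward has reached `min_reward`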
        if (
            self.save
            and self.g_ep_r.value >= self.min_reward
            and self.g_ep.value % self.every_n_save == 0
        ):
            torch.save(
                self.gnet.state_dict(),
                os.path.join(self.model_checkpoint_dir, f"model_{self.g_ep.value}.pth"),
            )

    def record(self, ep_r, goal_word, action, action_number):
        with self.g_ep.get_lock():
            self.g_ep.value += 1
        with self.g_ep_r.get_lock():
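            # track an exponential moving average of the episode reward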
            if self.g_ep_r.value == 0.0:
                self.g_ep_r.value = ep_r
            else:
                self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
        self.res_queue.put(self.g_ep_r.value)
        if goal_word == action:
            with self.winning_ep.get_lock():
                self.winning_ep.value += 1
            if self.g_ep.value % 100 == 0:
                print(
                    self.name,
                    "Ep:",
                    self.g_ep.value,
                    "| Ep_r: %.0f" % self.g_ep_r.value,
                    "| Goal :",
                    goal_word,
                    "| Action: ",
                    action,
                    "| Actions: ",
                    action_number,
                )
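

# Example of wiring up workers (a minimal sketch, not part of this module):
# `SharedAdam`, `make_env`, and the hyperparameter values below are assumptions
# for illustration only; the actual training script may differ.
#
#   gnet = Net(N_S, N_A, words_list, word_width)
#   gnet.share_memory()  # share weights across worker processes
#   opt = SharedAdam(gnet.parameters(), lr=1e-4)  # hypothetical shared optimizer
#   global_ep = mp.Value("i", 0)
#   global_ep_r = mp.Value("d", 0.0)
#   winning_ep = mp.Value("i", 0)
#   res_queue = mp.Queue()
#   workers = [
#       Worker(
#           max_ep=20000, gnet=gnet, opt=opt, global_ep=global_ep,
#           global_ep_r=global_ep_r, res_queue=res_queue, name=i,
#           env=make_env(), N_S=N_S, N_A=N_A, words_list=words_list,
#           word_width=word_width, winning_ep=winning_ep,
#           model_checkpoint_dir="checkpoints", gamma=0.99,
#       )
#       for i in range(mp.cpu_count())
#   ]
#   for w in workers:
#       w.start()
#   for w in workers:
#       w.join()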