File size: 13,989 Bytes
eaf2e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
import os
import json
import random
import time
import torch
import importlib
import numpy as np
from math import ceil
from torch import nn
from analysis.tests import evaluate_generator, evaluate_gen_log
from src.drl.ac_agents import SAC
from src.drl.ac_models import SoftActor, SoftDoubleClipCriticQ
from src.drl.nets import GaussianMLP, ObsActMLP
from src.drl.rep_mem import ReplayMem
from src.drl.trainer import AsyncOffpolicyTrainer
from src.env.environments import AsyncOlGenEnv
from src.env.logger import GenResLogger, AsyncCsvLogger, AsyncStdLogger
from src.gan.gankits import sample_latvec, get_decoder
from src.gan.gans import nz
from src.olgen.ol_generator import VecOnlineGenerator
from src.olgen.olg_policy import EnsembleGenPolicy
from src.smb.asyncsimlt import AsycSimltPool
from src.utils.filesys import getpath, auto_dire
from src.utils.misc import record_time


################## Borrowed from https://github.com/jjccero/DvD_TD3 ##################
class TS:
    """Thompson-sampling bandit that picks one value out of a fixed list of arms.

    Each arm keeps Beta(alpha, beta) pseudo-counts. A truthy reward increments the
    chosen arm's ``alpha``; a falsy one increments its ``beta``. For the first
    ``random_choice`` updates the arm is drawn uniformly at random. After that, the
    arm with the largest Beta posterior sample is chosen.
    """

    def __init__(self, arms=None, random_choice=6):
        if arms is None:
            arms = [0.0, 0.5]
        self.arms = arms
        self.arm_num = len(arms)
        # Beta pseudo-counts (successes / failures), one entry per arm.
        self.alpha = np.ones(self.arm_num, dtype=int)
        self.beta = np.ones(self.arm_num, dtype=int)
        self.arm = 0
        self.choices = 0
        self.random_choice = random_choice

    @property
    def value(self):
        """The arm value currently selected."""
        return self.arms[self.arm]

    def update(self, reward):
        """Record a binary outcome (by truthiness) for the currently selected arm."""
        self.choices += 1
        counts = self.alpha if reward else self.beta
        counts[self.arm] += 1

    def sample(self):
        """Select a new arm and return its value."""
        if self.choices >= self.random_choice:
            self.arm = np.argmax(np.random.beta(self.alpha, self.beta))
        else:
            # Not enough feedback yet: explore uniformly.
            self.arm = np.random.choice(self.arm_num)
        return self.value

    def clear(self):
        """Reset all pseudo-counts and the update counter."""
        self.alpha.fill(1)
        self.beta.fill(1)
        self.choices = 0

def l2rbf(m_actions):
    """Pairwise, variance-scaled squared distances between a list of tensors.

    The tensors in ``m_actions`` are stacked along a new leading axis (size m).
    Entry ``[i, j]`` of the returned (m, m) matrix is the mean over the last axis of
    ``(a_j - a_i)**2 / (2 * var)``. Here ``var`` is the per-feature variance across
    the m tensors; it is detached and offset by 1e-8 to avoid division by zero.
    """
    stacked = torch.stack(m_actions)
    # Broadcast (1, m, ...) against (m, 1, ...) so that diff[i, j] = stacked[j] - stacked[i].
    diff = stacked.unsqueeze(0) - stacked.unsqueeze(1)
    lengthscale = torch.var(stacked, dim=0).detach() + 1e-8
    scaled = torch.square(diff) / (2 * lengthscale)
    return scaled.mean(-1)

class LogDet(nn.Module):
    """Log-determinant of a regularised RBF kernel matrix built from embeddings.

    The kernel is ``K = exp(-l2rbf(embeddings))``. It is blended with the identity,
    ``beta * K + (1 - beta) * I``, before a Cholesky factorisation. The log-determinant
    is then twice the sum of the log of the factor's diagonal.
    """

    def __init__(self, beta=0.99):
        super().__init__()
        # Weight of the kernel; the remaining (1 - beta) goes to the identity jitter.
        self.beta = beta

    def forward(self, embeddings):
        kernel = torch.exp(-l2rbf(embeddings))
        identity = torch.eye(len(embeddings), device=kernel.device)
        regularised = self.beta * kernel + (1 - self.beta) * identity
        chol = torch.linalg.cholesky(regularised)
        # log det(A) = 2 * sum(log(diag(L))) for A = L @ L^T.
        return 2 * torch.log(torch.diag(chol)).sum()
#######################################################################################


class PESACAgent:
    """Ensemble agent wrapping several SAC sub-agents (``subs``).

    Outside test mode, decisions come from ``subs[i``], where ``i`` is a round-robin
    index advanced by ``next``. In test mode, a sub is drawn uniformly at random for
    each decision.
    """

    def __init__(self, subs, device='cpu'):
        self.subs = subs
        self.device = device
        self.to(device)
        # Round-robin index of the sub that acts while not in test mode.
        self.i = 0
        # When True, make_decision samples a random sub instead of using subs[i].
        self.test = False

    def to(self, device):
        """Move every sub to ``device`` and remember it for observation tensors."""
        for member in self.subs:
            member.to(device)
        self.device = device

    def update(self, obs, acts, rews, ops):
        """Forward one batch of transitions to each sub's own ``update``."""
        for member in self.subs:
            member.update(obs, acts, rews, ops)

    def next(self):
        """Advance the round-robin index, wrapping after the last sub."""
        self.i = (self.i + 1) % len(self.subs)

    def make_decision(self, obs, **kwargs):
        """Return the selected sub's action for ``obs`` as a squeezed NumPy array.

        Extra keyword arguments are forwarded to ``actor.forward`` together with
        ``grad=False``.
        """
        if not self.test:
            chosen = self.subs[self.i]
        else:
            chosen = self.subs[random.randrange(0, len(self.subs))]
        obs_tensor = torch.tensor(obs, dtype=torch.float, device=self.device)
        action, _ = chosen.actor.forward(obs_tensor, grad=False, **kwargs)
        return action.squeeze().cpu().numpy()


class DvDAgent(PESACAgent):
    """Ensemble of SAC sub-agents trained with a DvD (Diversity via Determinants) term.

    On top of each sub's own actor/critic update, every actor receives the gradient of
    ``-div_coe * det(K)``. ``K`` is a kernel matrix over the subs' behavioural
    embeddings (see ``div_loss``). The coefficient ``div_coe`` is picked by a
    Thompson-sampling bandit in ``adapt_div_coe``.
    """

    def __init__(self, subs, phi_batch=20, device='cpu'):
        super().__init__(subs, device)
        # Diversity coefficient (lambda); starts at 0 and is re-sampled by adapt_div_coe.
        self.div_coe = 0.0
        # How many leading observations of a batch are used to build each embedding.
        self.phi_batch = phi_batch
        self.log_det = LogDet()
        # Bandit over candidate lambda values (TS defaults to arms [0.0, 0.5]).
        self.bandit = TS()

    def div_loss(self, obs):
        """Return the ensemble diversity ``det(K)`` on the first ``phi_batch`` rows of ``obs``.

        Each sub's embedding is its actor's first output on those observations,
        flattened. Larger values mean a more diverse ensemble, so a loss that is
        minimised must use the negation of this value.
        """
        o = obs[:self.phi_batch]
        embeddings = [sub.actor.forward(o)[0].flatten() for sub in self.subs]
        return self.log_det(embeddings).exp()

    def update(self, obs, acts, rews, ops):
        """Perform one gradient update of every sub on a sampled batch.

        The diversity gradient is accumulated into all actors first. Each actor then
        adds its policy gradient (``backward_policy`` is passed ``1 - div_coe``,
        presumably as a weight on the policy loss) and its alpha gradient, and takes a
        step. Finally each critic is updated with its MSE loss and its target net is
        refreshed.
        """
        for sub in self.subs:
            sub.actor.zero_grads()
        if self.div_coe != 0:
            # DvD maximises J + lambda * det(K). This loss is minimised by grad_step
            # (assumed to be a descent step, as it is for the critic's MSE loss below),
            # so the diversity term needs a minus sign: `+lambda * det(K)` would pull
            # the policies towards each other. With lambda == 0 the term contributes
            # nothing, so the kernel/Cholesky computation is skipped entirely.
            ldiv = -self.div_coe * self.div_loss(obs)
            ldiv.backward()
        for sub in self.subs:
            sub.actor.backward_policy(sub.critic, obs, (1 - self.div_coe))
            sub.actor.backward_alpha(obs)
            sub.actor.grad_step()
            # Clear any critic gradients accumulated during the actor's backward pass.
            sub.critic.zero_grads()
            sub.critic.backward_mse(sub.actor, obs, acts, rews, ops)
            sub.critic.grad_step()
            sub.critic.update_tarnet()

    def adapt_div_coe(self, delta):
        """Reward the bandit if the evaluation return improved (``delta > 0``), then resample lambda."""
        self.bandit.update(delta > 0)
        self.div_coe = self.bandit.sample()
        print('Lambda: %.4g' % self.div_coe)


class DvDTrainer(AsyncOffpolicyTrainer):
    """Asynchronous off-policy trainer for ``DvDAgent``.

    Every ``eval_itv`` environment steps the agent is evaluated in test mode over
    ``eval_num`` episodes. The change in mean evaluation return is fed to
    ``agent.adapt_div_coe`` to pick the next diversity coefficient.
    """

    def __init__(self, rep_mem:ReplayMem=None, update_per=1, batch=256, eval_itv=20000, eval_num=50):
        # A non-positive interval would trigger an evaluation on every step, and zero
        # evaluation episodes would leave no returns to average, so reject both early.
        if eval_itv <= 0:
            raise ValueError(f'eval_itv must be a positive number of steps, got {eval_itv}')
        if eval_num <= 0:
            raise ValueError(f'eval_num must be a positive number of episodes, got {eval_num}')
        super().__init__(rep_mem, update_per, batch)
        self.eval_logger = None
        self.mean_return = 0
        self.env = None
        self.agent = None
        self.eval_itv = eval_itv
        self.eval_num = eval_num

    def train(self, env:AsyncOlGenEnv, agent: DvDAgent, budget, path, check_points=None):
        """Train ``agent`` on ``env`` for ``budget`` steps and save each sub's actor under ``path``.

        Final policies go to ``{path}/policy{i}.pth``. If ``check_points`` (step counts)
        is given, intermediate policies are also saved as ``{path}/policy{i}_{step}.pth``.
        """
        self._reset()
        self.env = env
        self.agent = agent
        check_points = set(check_points) if check_points else set()

        o = self.env.reset()
        for logger in self.loggers:
            if logger.__class__.__name__ == 'GenResLogger':
                self.agent.test = True
                logger.on_episode(self.env, agent, 0)
                self.agent.test = False
        eval_horizon = self.eval_itv
        # __eval resets the env and runs test episodes, so continue from the observation it
        # returns rather than the stale one from before the evaluation.
        o = self.__eval()
        while self.steps < budget:
            a = agent.make_decision(o)
            o, done = env.step(a)
            self.steps += 1
            if done:
                model_credits = ceil(1.25 * env.eplen / self.update_per)
                self._update(model_credits, env, agent)
                agent.next()
            if self.steps in check_points:
                self._save_sub_policies(path, f'_{self.steps}')
            if self.steps >= eval_horizon:
                o = self.__eval()
                eval_horizon += self.eval_itv

        self._update(0, env, agent, close=True)
        self._save_sub_policies(path)

    def _save_sub_policies(self, path, suffix=''):
        """Save every sub's actor network to ``{path}/policy{i}{suffix}.pth``."""
        for i, sub in enumerate(self.agent.subs):
            torch.save(sub.actor.net, getpath(f'{path}/policy{i}{suffix}.pth'))

    def __eval(self):
        """Evaluate the agent in test mode and adapt its diversity coefficient.

        Pending training transitions (after step 0) are flushed into the replay memory
        first. Returns the observation the env is at when the evaluation finishes.
        """
        if self.steps > 0:
            transitions, _ = self.env.rollout(wait=True)
            self.rep_mem.add_transitions(transitions)
            self.num_trans += len(transitions)
        self.agent.test = True
        finished = 0
        o = self.env.reset()
        rewss = []
        while finished < self.eval_num:
            a = self.agent.make_decision(o)
            o, done = self.env.step(a)
            if done:
                rewss += self.env.rollout()[1]
                finished += 1
        rewss += self.env.rollout(wait=True)[1]
        # One row per episode, one column per reward term; an episode's return is its row sum.
        returns = np.array([list(rews.values()) for rews in rewss])
        mean_return = float(np.sum(returns, axis=1).mean())
        if self.steps > 0:
            self.agent.adapt_div_coe(mean_return - self.mean_return)
        self.mean_return = mean_return
        self.agent.test = False
        return o

    def _update(self, model_credits, env, agent, close=False):
        """Collect finished transitions, run gradient updates and notify loggers.

        At most ``model_credits`` updates are performed unless ``close`` is True. In that
        case the env is closed and all outstanding updates are run.
        """
        transitions, rewss = env.close() if close else env.rollout()
        self.rep_mem.add_transitions(transitions)
        self.num_trans += len(transitions)
        if len(self.rep_mem) > self.batch:
            t = 1
            while self.num_trans >= (self.num_updates + 1) * self.update_per:
                agent.update(*self.rep_mem.sample(self.batch))
                self.num_updates += 1
                if not close and t == model_credits:
                    break
                t += 1
        # The packed log info does not depend on the logger, so build it once.
        loginfo = self._pack_loginfo(rewss) if self.loggers else None
        for logger in self.loggers:
            if logger.__class__.__name__ == 'GenResLogger':
                agent.test = True
                logger.on_episode(env, agent, self.steps)
                agent.test = False
            else:
                logger.on_episode(**loginfo, close=close)

    def _reset(self):
        """Reset counters, timers and env/agent references before a training run."""
        self.mean_return = 0
        self.start_time = time.time()
        self.steps = 0
        self.num_trans = 0
        self.num_updates = 0
        self.env = None
        self.agent = None


####################### Command Line Configuration #######################


def set_DvDSAC_parser(parser):
    """Register the DvD-SAC command-line options on ``parser`` (an argparse-style parser).

    Only ``parser.add_argument`` is called; nothing is returned. The default target
    entropy is ``-nz``, the GAN latent size.
    """
    parser.add_argument('--n_workers', type=int, default=20, help='Number of max_parallel processes in the environment.')
    parser.add_argument('--queuesize', type=int, default=25, help='Size of waiting queue of the environment.')
    parser.add_argument('--eplen', type=int, default=50, help='Episode length of the environment.')
    parser.add_argument('--budget', type=int, default=int(1e6), help='Total time steps of training.')
    parser.add_argument('--gamma', type=float, default=0.9, help='RL parameter')
    parser.add_argument('--tar_entropy', type=float, default=-nz, help='SAC parameter, target entropy')
    parser.add_argument('--tau', type=float, default=0.02, help='SAC parameter, target net smooth coefficient')
    parser.add_argument('--update_per', type=int, default=2, help='Do one update (with one batch) per how many collected transitions')
    parser.add_argument('--batch', type=int, default=256, help='Batch size for one update')
    parser.add_argument('--mem_size', type=int, default=int(1e6), help='Size of replay memory')
    parser.add_argument('--gpuid', type=int, default=0, help='ID of GPU to train the policy. CPU will be used if gpuid < 0')
    parser.add_argument('--rfunc', type=str, default='default', help='Name of the reward function in src/env/rfuncs.py')
    parser.add_argument('--path', type=str, default='', help='Path relative to \'/training_data\' to save the training logs. If not specified, a new folder named after --name will be created.')
    parser.add_argument('--actor_hiddens', type=int, nargs='+', default=[256, 256], help='List of number of units in each hidden layer of actor net')
    parser.add_argument('--critic_hiddens', type=int, nargs='+', default=[256, 256], help='List of number of units in each hidden layer of critic net')
    parser.add_argument('--gen_period', type=int, default=20000, help='Period of saving level generation results')
    parser.add_argument('--periodic_gen_num', type=int, default=200, help='Number of levels to be generated for each evaluation')
    parser.add_argument('--redirect', action='store_true', help='If add this, redirect STD log to log.txt')
    parser.add_argument(
        '--check_points', type=int, nargs='+',
        help='check points to save policy, specified by the number of time steps.'
    )
    parser.add_argument('--name', type=str, default='DvDSAC', help='Name of this algorithm.')
    parser.add_argument('--m', type=int, default=5, help='Number of ensemble members (independent SAC actor-critic pairs)')
    parser.add_argument('--eval_itv', type=int, default=20000, help='Period of evaluating policy and adapting diversity loss coefficient')
    parser.add_argument('--eval_num', type=int, default=50, help='Number of episodes per evaluation')


def train_DvDSAC(args):
    """Build the environment, loggers and DvD-SAC ensemble from ``args`` and train it.

    Writes ``run_configuration.txt``, ``cfgs.json`` and ``nn_architecture.txt`` into the
    run folder, then runs ``DvDTrainer.train`` wrapped by ``record_time``. Training is
    skipped if the run folder already holds the saved policies of a finished trial.
    """
    def _construct_agent(_args, _path, _device, _obs_dim, _act_dim):
        # One independent SAC actor/critic pair per ensemble member.
        subs = []
        for i in range(_args.m):
            actor = SoftActor(
                lambda: GaussianMLP(_obs_dim, _act_dim, _args.actor_hiddens), tar_ent=_args.tar_entropy
            )
            critic = SoftDoubleClipCriticQ(
                lambda: ObsActMLP(_obs_dim, _act_dim, _args.critic_hiddens), gamma=_args.gamma, tau=_args.tau
            )
            subs.append(SAC(actor, critic, _device))
        with open(f'{_path}/nn_architecture.txt', 'w') as f:
            f.writelines([
                '-' * 24 + 'Actor' + '-' * 24 + '\n', subs[0].actor.get_nn_arch_str(),
                '-' * 24 + 'Critic-Q' + '-' * 24 + '\n', subs[0].critic.get_nn_arch_str()
            ])
        return DvDAgent(subs, device=_device)

    if not args.path:
        path = auto_dire('training_data', args.name)
    else:
        path = getpath('training_data', args.path)
        os.makedirs(path, exist_ok=True)
    # DvDTrainer saves one file per ensemble member (policy0.pth ... policy{m-1}.pth), so a
    # single `policy.pth` never appears for this algorithm. Check the per-member files
    # (keeping the previous single-file check as well).
    finished = os.path.exists(f'{path}/policy.pth') or (
        args.m > 0 and all(os.path.exists(f'{path}/policy{i}.pth') for i in range(args.m))
    )
    if finished:
        print(f'Trainning at <{path}> is skipped as there has a finished trial already.')
        return
    device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'

    evalpool = AsycSimltPool(args.n_workers, args.queuesize, args.rfunc, verbose=False)
    rfunc = getattr(importlib.import_module('src.env.rfuncs'), args.rfunc)()
    env = AsyncOlGenEnv(rfunc.get_n(), get_decoder('models/decoder.pth'), evalpool, args.eplen, device=device)
    loggers = [
        AsyncCsvLogger(f'{path}/log.csv', rfunc),
        AsyncStdLogger(rfunc, 2000, f'{path}/log.txt' if args.redirect else '')
    ]
    if args.periodic_gen_num > 0:
        loggers.append(GenResLogger(path, args.periodic_gen_num, args.gen_period))
    with open(path + '/run_configuration.txt', 'w') as f:
        f.write(time.strftime('%Y-%m-%d %H:%M') + '\n')
        f.write(f'---------{args.name}---------\n')
        args_strlines = [
            f'{key}={val}\n' for key, val in vars(args).items()
            if key not in {'name', 'rfunc', 'path', 'entry'}
        ]
        f.writelines(args_strlines)
        f.write('-' * 50 + '\n')
        f.write(str(rfunc))
    N = rfunc.get_n()
    with open(f'{path}/cfgs.json', 'w') as f:
        data = {'N': N, 'gamma': args.gamma, 'h': args.eplen, 'rfunc': args.rfunc, 'm': args.m}
        json.dump(data, f)
    obs_dim, act_dim = env.histlen * nz, nz

    agent = _construct_agent(args, path, device, obs_dim, act_dim)

    agent.to(device)
    trainer = DvDTrainer(
        ReplayMem(args.mem_size, device=device), update_per=args.update_per, batch=args.batch,
        eval_itv=args.eval_itv, eval_num=args.eval_num
    )
    trainer.set_loggers(*loggers)
    _, timecost = record_time(trainer.train)(env, agent, args.budget, path, check_points=args.check_points)