import os
import json
import random
import time
import importlib

import numpy as np
import torch
from math import ceil
from torch import nn

from analysis.tests import evaluate_generator, evaluate_gen_log
from src.drl.ac_agents import SAC
from src.drl.ac_models import SoftActor, SoftDoubleClipCriticQ
from src.drl.nets import GaussianMLP, ObsActMLP
from src.drl.rep_mem import ReplayMem
from src.drl.trainer import AsyncOffpolicyTrainer
from src.env.environments import AsyncOlGenEnv
from src.env.logger import GenResLogger, AsyncCsvLogger, AsyncStdLogger
from src.gan.gankits import sample_latvec, get_decoder
from src.gan.gans import nz
from src.olgen.ol_generator import VecOnlineGenerator
from src.olgen.olg_policy import EnsembleGenPolicy
from src.smb.asyncsimlt import AsycSimltPool
from src.utils.filesys import getpath, auto_dire
from src.utils.misc import record_time


################## Borrowed from https://github.com/jjccero/DvD_TD3 ##################
class TS:
    """Thompson-sampling bandit used to adapt the diversity loss coefficient."""
    def __init__(self, arms=None, random_choice=6):
        self.arms = [0.0, 0.5] if arms is None else arms
        self.arm_num = len(self.arms)
        self.alpha = np.ones(self.arm_num, dtype=int)
        self.beta = np.ones(self.arm_num, dtype=int)
        self.arm = 0
        self.choices = 0
        self.random_choice = random_choice

    @property
    def value(self):
        return self.arms[self.arm]

    def update(self, reward):
        self.choices += 1
        if reward:
            self.alpha[self.arm] += 1
        else:
            self.beta[self.arm] += 1

    def sample(self):
        # Explore uniformly for the first few choices, then sample from the Beta posteriors.
        if self.choices < self.random_choice:
            self.arm = np.random.choice(self.arm_num)
        else:
            self.arm = np.argmax(np.random.beta(self.alpha, self.beta))
        return self.value

    def clear(self):
        self.alpha[:] = 1
        self.beta[:] = 1
        self.choices = 0


def l2rbf(m_actions):
    # Pairwise squared distances between policy embeddings, scaled by their variance (RBF length scale).
    actions = torch.stack(m_actions)
    x1 = actions.unsqueeze(0).repeat_interleave(actions.shape[0], 0)
    x2 = actions.unsqueeze(1).repeat_interleave(actions.shape[0], 1)
    d2 = torch.square(x1 - x2)
    l2 = torch.var(actions, dim=0).detach() + 1e-8
    return (d2 / (2 * l2)).mean(-1)


class LogDet(nn.Module):
    """Log-determinant of the RBF kernel matrix of policy embeddings (the DvD diversity measure)."""
    def __init__(self, beta=0.99):
        super(LogDet, self).__init__()
        self.beta = beta

    def forward(self, embeddings):
        d = l2rbf(embeddings)
        K = (-d).exp()
        # Shrink towards the identity to keep the kernel matrix positive definite.
        K_ = self.beta * K + (1 - self.beta) * torch.eye(len(embeddings), device=K.device)
        L = torch.linalg.cholesky(K_)
        log_det = 2 * torch.log(torch.diag(L)).sum()
        return log_det
#######################################################################################


class PESACAgent:
    """Population ensemble of SAC sub-agents; one sub-agent acts per episode."""
    def __init__(self, subs, device='cpu'):
        self.subs = subs
        self.device = device
        self.to(device)
        self.i = 0
        self.test = False

    def to(self, device):
        for sub in self.subs:
            sub.to(device)
        self.device = device

    def update(self, obs, acts, rews, ops):
        for sub in self.subs:
            sub.update(obs, acts, rews, ops)

    def next(self):
        self.i += 1
        self.i %= len(self.subs)

    def make_decision(self, obs, **kwargs):
        # In test mode, pick a random sub-agent; otherwise use the current one.
        if self.test:
            sub = self.subs[random.randrange(0, len(self.subs))]
        else:
            sub = self.subs[self.i]
        a, _ = sub.actor.forward(
            torch.tensor(obs, dtype=torch.float, device=self.device),
            grad=False, **kwargs
        )
        return a.squeeze().cpu().numpy()


class DvDAgent(PESACAgent):
    """PESACAgent trained with an additional DvD diversity term over the sub-agents' policies."""
    def __init__(self, subs, phi_batch=20, device='cpu'):
        super().__init__(subs, device)
        self.div_coe = 0.0
        self.phi_batch = phi_batch
        self.log_det = LogDet()
        self.bandit = TS()

    def div_loss(self, obs):
        o = obs[:self.phi_batch]
        embeddings = [sub.actor.forward(o)[0].flatten() for sub in self.subs]
        return self.log_det(embeddings).exp()

    def update(self, obs, acts, rews, ops):
        for sub in self.subs:
            sub.actor.zero_grads()
        # Accumulate the shared diversity gradient before each sub-agent's own policy gradient.
        ldiv = self.div_coe * self.div_loss(obs)
        ldiv.backward()
        for sub in self.subs:
            sub.actor.backward_policy(sub.critic, obs, (1 - self.div_coe))
            sub.actor.backward_alpha(obs)
            sub.actor.grad_step()
            sub.critic.zero_grads()
            sub.critic.backward_mse(sub.actor, obs, acts, rews, ops)
            sub.critic.grad_step()
            sub.critic.update_tarnet()

    def adapt_div_coe(self, delta):
        # Treat an improvement in mean return as a bandit reward, then resample the coefficient.
        self.bandit.update(delta > 0)
        self.div_coe = self.bandit.sample()
        print('Lambda: %.4g' % self.div_coe)


class DvDTrainer(AsyncOffpolicyTrainer):
    """Asynchronous off-policy trainer that periodically evaluates the ensemble to adapt the diversity coefficient."""
    def __init__(self, rep_mem: ReplayMem = None, update_per=1, batch=256, eval_itv=20000, eval_num=50):
        super().__init__(rep_mem, update_per, batch)
        self.eval_logger = None
        self.mean_return = 0
        self.env = None
        self.agent = None
        self.eval_itv = eval_itv
        self.eval_num = eval_num

    def train(self, env: AsyncOlGenEnv, agent: DvDAgent, budget, path, check_points=None):
        self._reset()
        self.env = env
        self.agent = agent
        o = self.env.reset()
        for logger in self.loggers:
            if logger.__class__.__name__ == 'GenResLogger':
                self.agent.test = True
                logger.on_episode(self.env, agent, 0)
                self.agent.test = False
        eval_horizon = self.eval_itv
        self.__eval()
        while self.steps < budget:
            a = agent.make_decision(o)
            o, done = env.step(a)
            self.steps += 1
            if done:
                model_credits = ceil(1.25 * env.eplen / self.update_per)
                self._update(model_credits, env, agent)
                agent.next()
            if self.steps >= eval_horizon:
                self.__eval()
                eval_horizon += self.eval_itv
        self._update(0, env, agent, close=True)
        for i, sub in enumerate(self.agent.subs):
            torch.save(sub.actor.net, getpath(f'{path}/policy{i}.pth'))

    def __eval(self):
        # Flush pending rollouts into the replay memory before switching to test mode.
        if self.steps > 0:
            transitions, rewss = self.env.rollout(wait=True)
            self.rep_mem.add_transitions(transitions)
            self.num_trans += len(transitions)
        self.agent.test = True
        it = 0
        o = self.env.reset()
        rewss = []
        while it < self.eval_num:
            a = self.agent.make_decision(o)
            o, done = self.env.step(a)
            if done:
                rewss += self.env.rollout()[1]
                it += 1
        rewss += self.env.rollout(wait=True)[1]
        rewss = np.array([[v for v in rews.values()] for rews in rewss])
        mean_return = float(np.sum(rewss, axis=1).mean())
        # Feed the change in mean return back to the bandit that tunes the diversity coefficient.
        if self.steps > 0:
            self.agent.adapt_div_coe(mean_return - self.mean_return)
        self.mean_return = mean_return
        self.agent.test = False

    def _update(self, model_credits, env, agent, close=False):
        transitions, rewss = env.close() if close else env.rollout()
        self.rep_mem.add_transitions(transitions)
        self.num_trans += len(transitions)
        if len(self.rep_mem) > self.batch:
            t = 1
            while self.num_trans >= (self.num_updates + 1) * self.update_per:
                agent.update(*self.rep_mem.sample(self.batch))
                self.num_updates += 1
                if not close and t == model_credits:
                    break
                t += 1
        for logger in self.loggers:
            loginfo = self._pack_loginfo(rewss)
            if logger.__class__.__name__ == 'GenResLogger':
                agent.test = True
                logger.on_episode(env, agent, self.steps)
                agent.test = False
            else:
                logger.on_episode(**loginfo, close=close)

    def _reset(self):
        self.mean_return = 0
        self.start_time = time.time()
        self.steps = 0
        self.num_trans = 0
        self.num_updates = 0
        self.env = None
        self.agent = None


####################### Command Line Configuration #######################
def set_DvDSAC_parser(parser):
    parser.add_argument('--n_workers', type=int, default=20, help='Maximum number of parallel simulation processes in the environment.')
    parser.add_argument('--queuesize', type=int, default=25, help='Size of the waiting queue of the environment.')
    parser.add_argument('--eplen', type=int, default=50, help='Episode length of the environment.')
    parser.add_argument('--budget', type=int, default=int(1e6), help='Total time steps of training.')
    parser.add_argument('--gamma', type=float, default=0.9, help='RL parameter, discount factor.')
    parser.add_argument('--tar_entropy', type=float, default=-nz, help='SAC parameter, target entropy.')
    parser.add_argument('--tau', type=float, default=0.02, help='SAC parameter, target-net smoothing coefficient.')
    parser.add_argument('--update_per', type=int, default=2, help='Perform one update (with one batch) per this many collected transitions.')
    parser.add_argument('--batch', type=int, default=256, help='Batch size of one update.')
    parser.add_argument('--mem_size', type=int, default=int(1e6), help='Size of the replay memory.')
    parser.add_argument('--gpuid', type=int, default=0, help='ID of the GPU to train the policy on. CPU is used if gpuid < 0.')
    parser.add_argument('--rfunc', type=str, default='default', help='Name of the reward function in src/env/rfuncs.py.')
    parser.add_argument('--path', type=str, default='', help='Path relative to \'/training_data\' to save the training logs. If not specified, a new folder named DvDSAC{id} will be created.')
    parser.add_argument('--actor_hiddens', type=int, nargs='+', default=[256, 256], help='Number of units in each hidden layer of the actor net.')
    parser.add_argument('--critic_hiddens', type=int, nargs='+', default=[256, 256], help='Number of units in each hidden layer of the critic net.')
    parser.add_argument('--gen_period', type=int, default=20000, help='Period of saving level generation results.')
    parser.add_argument('--periodic_gen_num', type=int, default=200, help='Number of levels to generate for each evaluation.')
    parser.add_argument('--redirect', action='store_true', help='If set, redirect the STD log to log.txt.')
    parser.add_argument(
        '--check_points', type=int, nargs='+',
        help='Check points at which to save the policy, specified by the number of time steps.'
    )
    parser.add_argument('--name', type=str, default='DvDSAC', help='Name of this algorithm.')
    parser.add_argument('--m', type=int, default=5, help='Number of SAC sub-policies in the ensemble.')
    parser.add_argument('--eval_itv', type=int, default=20000, help='Period of evaluating the policy and adapting the diversity loss coefficient.')
    parser.add_argument('--eval_num', type=int, default=50, help='Number of evaluation episodes.')


def train_DvDSAC(args):
    def _construct_agent(_args, _path, _device, _obs_dim, _act_dim):
        subs = []
        for i in range(_args.m):
            actor = SoftActor(
                lambda: GaussianMLP(_obs_dim, _act_dim, _args.actor_hiddens),
                tar_ent=_args.tar_entropy
            )
            critic = SoftDoubleClipCriticQ(
                lambda: ObsActMLP(_obs_dim, _act_dim, _args.critic_hiddens),
                gamma=_args.gamma, tau=_args.tau
            )
            subs.append(SAC(actor, critic, _device))
        with open(f'{_path}/nn_architecture.txt', 'w') as f:
            f.writelines([
                '-' * 24 + 'Actor' + '-' * 24 + '\n', subs[0].actor.get_nn_arch_str(),
                '-' * 24 + 'Critic-Q' + '-' * 24 + '\n', subs[0].critic.get_nn_arch_str()
            ])
        return DvDAgent(subs, device=_device)

    if not args.path:
        path = auto_dire('training_data', args.name)
    else:
        path = getpath('training_data', args.path)
        os.makedirs(path, exist_ok=True)
    if os.path.exists(f'{path}/policy.pth'):
        print(f'Training at <{path}> is skipped, as a finished trial already exists there.')
        return
    device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'
    evalpool = AsycSimltPool(args.n_workers, args.queuesize, args.rfunc, verbose=False)
    rfunc = importlib.import_module('src.env.rfuncs').__getattribute__(f'{args.rfunc}')()
    env = AsyncOlGenEnv(rfunc.get_n(), get_decoder('models/decoder.pth'), evalpool, args.eplen, device=device)
    loggers = [
        AsyncCsvLogger(f'{path}/log.csv', rfunc),
        AsyncStdLogger(rfunc, 2000, f'{path}/log.txt' if args.redirect else '')
    ]
    if args.periodic_gen_num > 0:
        loggers.append(GenResLogger(path, args.periodic_gen_num, args.gen_period))
    # Record the run configuration for reproducibility.
    with open(path + '/run_configuration.txt', 'w') as f:
        f.write(time.strftime('%Y-%m-%d %H:%M') + '\n')
        f.write(f'---------{args.name}---------\n')
        args_strlines = [
            f'{key}={val}\n' for key, val in vars(args).items()
            if key not in {'name', 'rfunc', 'path', 'entry'}
        ]
        f.writelines(args_strlines)
        f.write('-' * 50 + '\n')
        f.write(str(rfunc))
    N = rfunc.get_n()
    with open(f'{path}/cfgs.json', 'w') as f:
        data = {'N': N, 'gamma': args.gamma, 'h': args.eplen, 'rfunc': args.rfunc, 'm': args.m}
        json.dump(data, f)
    obs_dim, act_dim = env.histlen * nz, nz
    agent = _construct_agent(args, path, device, obs_dim, act_dim)
    agent.to(device)
    trainer = DvDTrainer(
        ReplayMem(args.mem_size, device=device), update_per=args.update_per, batch=args.batch,
        eval_itv=args.eval_itv, eval_num=args.eval_num
    )
    trainer.set_loggers(*loggers)
    _, timecost = record_time(trainer.train)(env, agent, args.budget, path, check_points=args.check_points)
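

# --- Usage sketch (not part of the original module) -----------------------------------
# A minimal sketch of how set_DvDSAC_parser and train_DvDSAC defined above could be wired
# together as a standalone entry point, assuming the repository does not already route
# this through its own launcher. The __main__ guard and the parser description string are
# illustrative assumptions, not the project's confirmed entry point.
if __name__ == '__main__':
    import argparse

    _parser = argparse.ArgumentParser(description='Train a DvD-SAC ensemble of online level generators.')
    set_DvDSAC_parser(_parser)
    _cli_args = _parser.parse_args()
    train_DvDSAC(_cli_args)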