import os
import json
import time
import importlib
import torch
from src.env.logger import *
from src.drl.ac_agents import *
from src.drl.rep_mem import ReplayMem
from src.utils.misc import record_time
from src.utils.filesys import auto_dire, getpath  # getpath is used below; assumed to live alongside auto_dire
from src.env.environments import AsyncOlGenEnv
from src.drl.trainer import AsyncOffpolicyTrainer
from src.drl.pmoe import PMOESoftActor
from analysis.tests import *
from src.gan.gankits import get_decoder
def set_common_args(parser):
    parser.add_argument('--n_workers', type=int, default=20, help='Maximum number of parallel worker processes in the environment.')
    parser.add_argument('--queuesize', type=int, default=25, help="Size of the environment's waiting queue.")
    parser.add_argument('--eplen', type=int, default=50, help='Episode length of the environment.')
    parser.add_argument('--budget', type=int, default=int(1e6), help='Total time steps of training.')
    parser.add_argument('--gamma', type=float, default=0.9, help='RL discount factor.')
    parser.add_argument('--tar_entropy', type=float, default=-nz, help='SAC parameter: target entropy.')
    parser.add_argument('--tau', type=float, default=0.02, help='SAC parameter: smoothing coefficient of the target networks.')
    parser.add_argument('--update_per', type=int, default=2, help='Perform one update (with one batch) per this many collected transitions.')
    parser.add_argument('--batch', type=int, default=256, help='Batch size for one update.')
    parser.add_argument('--mem_size', type=int, default=int(1e6), help='Size of the replay memory.')
    parser.add_argument('--gpuid', type=int, default=0, help='ID of the GPU to train the policy on. The CPU is used if gpuid < 0.')
    parser.add_argument('--rfunc', type=str, default='default', help='Name of the reward function in src/env/rfuncs.py.')
    parser.add_argument('--path', type=str, default='', help='Path relative to \'/training_data\' in which to save the training logs. If not specified, a new folder named SAC{id} is created.')
    parser.add_argument('--actor_hiddens', type=int, nargs='+', default=[256, 256], help='Number of units in each hidden layer of the actor net.')
    parser.add_argument('--critic_hiddens', type=int, nargs='+', default=[256, 256], help='Number of units in each hidden layer of the critic net.')
    parser.add_argument('--gen_period', type=int, default=20000, help='Period (in time steps) of saving level-generation results.')
    parser.add_argument('--periodic_gen_num', type=int, default=200, help='Number of levels to generate for each periodic evaluation.')
    parser.add_argument('--redirect', action='store_true', help='If set, redirect the standard log to log.txt.')
    parser.add_argument(
        '--check_points', type=int, nargs='+',
        help='Time steps at which to save policy checkpoints.'
    )
def drl_train(foo):
"""
DRL Train, foo是被调用的函数, 如train_AsyncSAC.
"""
def __inner(args):
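        # Resolve the output directory: auto-create a numbered folder unless a path is given.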
        if not args.path:
            path = auto_dire('training_data', args.name)
        else:
            path = getpath('training_data', args.path)
            os.makedirs(path, exist_ok=True)
        if os.path.exists(f'{path}/policy.pth'):
            print(f'Training at <{path}> is skipped, as a finished trial already exists there.')
            return
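        # Pick the training device; fall back to the CPU if no GPU is requested or available.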
        device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'
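        # Asynchronous simulation pool that evaluates generated levels in parallel worker processes.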
        evalpool = AsycSimltPool(args.n_workers, args.queuesize, args.rfunc, verbose=False)
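        # Look up the reward function by name in src/env/rfuncs.py and instantiate it.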
        rfunc = getattr(importlib.import_module('src.env.rfuncs'), args.rfunc)()
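        # Online level-generation environment driven by the pretrained decoder network.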
        env = AsyncOlGenEnv(rfunc.get_n(), get_decoder('models/decoder.pth'), evalpool, args.eplen, device=device)
        loggers = [
            AsyncCsvLogger(f'{path}/log.csv', rfunc),
            AsyncStdLogger(rfunc, 2000, f'{path}/log.txt' if args.redirect else '')
        ]
        if args.periodic_gen_num > 0:
            loggers.append(GenResLogger(path, args.periodic_gen_num, args.gen_period))
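        # Record the run configuration: timestamp, algorithm name, and all CLI arguments.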
        with open(path + '/run_configuration.txt', 'w') as f:
            f.write(time.strftime('%Y-%m-%d %H:%M') + '\n')
            f.write(f'---------{args.name}---------\n')
            args_strlines = [
                f'{key}={val}\n' for key, val in vars(args).items()
                if key not in {'name', 'rfunc', 'path', 'entry'}
            ]
            f.writelines(args_strlines)
            f.write('-' * 50 + '\n')
            f.write(str(rfunc))
        N = rfunc.get_n()
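        # Persist key hyper-parameters to cfgs.json (presumably read by evaluation scripts).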
        with open(f'{path}/cfgs.json', 'w') as f:
            data = {'N': N, 'gamma': args.gamma, 'h': args.eplen, 'rfunc': args.rfunc}
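            # Note: the ensemble-specific fields are recorded only for the 'MESAC' name.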
            if args.name == 'MESAC':
                data.update({'m': args.m, 'lambda': args.lbd, 'me_type': args.me_type})
            json.dump(data, f)
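        # The observation stacks the last `histlen` latent vectors; actions are nz-dimensional latent codes.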
        obs_dim, act_dim = env.histlen * nz, nz
        # Build the algorithm-specific agent (an ActCrtAgent) via the wrapped function.
        agent = foo(args, path, device, obs_dim, act_dim)
        agent.to(device)
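        # Off-policy trainer: collects transitions asynchronously and updates the agent from replay memory.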
        trainer = AsyncOffpolicyTrainer(
            ReplayMem(args.mem_size, device=device), update_per=args.update_per, batch=args.batch
        )
        trainer.set_loggers(*loggers)
        _, timecost = record_time(trainer.train)(env, agent, args.budget, path, check_points=args.check_points)
    return __inner
############### AsyncSAC ###############
def set_AsyncSAC_parser(parser):
    set_common_args(parser)
    parser.add_argument('--name', type=str, default='AsyncSAC', help='Name of this algorithm.')
# Same SAC training procedure, but with asynchronous environment interaction.
@drl_train
def train_AsyncSAC(args, path, device, obs_dim, act_dim):
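    # Soft actor with a Gaussian MLP policy; tar_ent is the SAC target entropy.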
    actor = SoftActor(
        lambda: GaussianMLP(obs_dim, act_dim, args.actor_hiddens), tar_ent=args.tar_entropy
    )
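    # Double clipped Q critic with discount gamma and target-network smoothing tau.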
    critic = SoftDoubleClipCriticQ(
        lambda: ObsActMLP(obs_dim, act_dim, args.critic_hiddens), gamma=args.gamma, tau=args.tau
    )
    with open(f'{path}/nn_architecture.txt', 'w') as f:
        f.writelines([
            '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
            '-' * 24 + 'Critic-Q' + '-' * 24 + '\n', critic.get_nn_arch_str()
        ])
    return SAC(actor, critic, device)
############## NCESAC ##############
def set_NCESAC_parser(parser):
    set_common_args(parser)
    parser.add_argument('--name', type=str, default='NCESAC', help='Name of this algorithm.')
    parser.add_argument('--lbd', type=float, default=0.2, help='Weight of the mutual-exclusion regularisation.')
    parser.add_argument('--m', type=int, default=2, help='Number of ensemble heads in the actor.')
    parser.add_argument('--me_type', type=str, default='clip', choices=['log', 'clip', 'logclip'], help='Type of mutual-exclusion regularisation.')
    parser.add_argument('--actor_net_type', type=str, default='mlp', choices=['mlp', 'conv'], help="Type of the actor's neural network.")
@drl_train
def train_NCESAC(args, path, device, obs_dim, act_dim):
    me_reg, actor_nn_constructor = None, None
    # Select the mutual-exclusion regulariser according to me_type.
    if args.me_type == 'log':
        me_reg = LogWassersteinExclusion(args.lbd)
    elif args.me_type == 'clip':
        me_reg = ClipExclusion(args.lbd)
    elif args.me_type == 'logclip':
        me_reg = LogClipExclusion(args.lbd)
    # Select the actor network constructor according to actor_net_type.
    if args.actor_net_type == 'conv':
        actor_nn_constructor = lambda: EsmbGaussianConv(
            obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m
        )
    elif args.actor_net_type == 'mlp':
        actor_nn_constructor = lambda: EsmbGaussianMLP(
            obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m
        )
    # Build the mixture soft actor with the mutual-exclusion regulariser.
    actor = MERegMixSoftActor(actor_nn_constructor, me_reg, tar_ent=args.tar_entropy)
    # Build the critics: a Q critic plus an auxiliary critic for the regularised objective.
    critic = MERegSoftDoubleClipCriticQ(
        lambda: ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
        gamma=args.gamma, tau=args.tau
    )
    critic_U = MERegDoubleClipCriticW(
        lambda: ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
        gamma=args.gamma, tau=args.tau
    )
    # Save a description of the neural network architectures.
    with open(f'{path}/nn_architecture.txt', 'w') as f:
        f.writelines([
            '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
            '-' * 24 + 'Critic-Q' + '-' * 24 + '\n', critic.get_nn_arch_str(),
            '-' * 24 + 'Critic-U' + '-' * 24 + '\n', critic_U.get_nn_arch_str()
        ])
    return MESAC(actor, critic, critic_U, device)
############## PMOESAC ##############
def set_PMOESAC_parser(parser):
    set_common_args(parser)
    parser.add_argument('--name', type=str, default='PMOESAC', help='Name of this algorithm.')
    parser.add_argument('--m', type=int, default=5, help='Number of ensemble heads in the actor.')
@drl_train
def train_PMOESAC(args, path, device, obs_dim, act_dim):
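    # PMOE actor: a mixture-of-experts Gaussian policy with m ensemble heads.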
    actor = PMOESoftActor(
        lambda: EsmbGaussianMLP(obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m),
        tar_ent=args.tar_entropy
    )
    critic = SoftDoubleClipCriticQ(
        lambda: ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
        gamma=args.gamma, tau=args.tau
    )
    with open(f'{path}/nn_architecture.txt', 'w') as f:
        f.writelines([
            '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
            '-' * 24 + 'Critic-Q' + '-' * 24 + '\n', critic.get_nn_arch_str()
        ])
    return SAC(actor, critic, device)
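

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (an illustration, not confirmed by this module):
# each set_*_parser / train_* pair above is presumably registered as an
# argparse sub-command whose `entry` default points at the decorated training
# function; this would also explain why 'entry' is excluded from the dumped
# run configuration. The sub-command names here are assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse
    main_parser = argparse.ArgumentParser()
    subparsers = main_parser.add_subparsers(dest='algorithm', required=True)
    for algo, (set_algo_parser, algo_entry) in {
        'AsyncSAC': (set_AsyncSAC_parser, train_AsyncSAC),
        'NCESAC': (set_NCESAC_parser, train_NCESAC),
        'PMOESAC': (set_PMOESAC_parser, train_PMOESAC),
    }.items():
        sub_parser = subparsers.add_parser(algo)
        set_algo_parser(sub_parser)
        sub_parser.set_defaults(entry=algo_entry)
    cli_args = main_parser.parse_args()
    cli_args.entry(cli_args)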