import os
import time
import json
import importlib

import torch

from src.env.logger import *
from src.drl.ac_agents import *
from src.drl.rep_mem import ReplayMem
from src.utils.misc import record_time
from src.utils.filesys import auto_dire
from src.env.environments import AsyncOlGenEnv
from src.drl.trainer import AsyncOffpolicyTrainer
from src.drl.pmoe import PMOESoftActor
from analysis.tests import *
from src.gan.gankits import get_decoder


def set_common_args(parser):
    parser.add_argument('--n_workers', type=int, default=20, help='Maximum number of parallel processes in the environment.')
    parser.add_argument('--queuesize', type=int, default=25, help='Size of the environment\'s waiting queue.')
    parser.add_argument('--eplen', type=int, default=50, help='Episode length of the environment.')
    parser.add_argument('--budget', type=int, default=int(1e6), help='Total number of training time steps.')
    parser.add_argument('--gamma', type=float, default=0.9, help='RL parameter, discount factor')
    parser.add_argument('--tar_entropy', type=float, default=-nz, help='SAC parameter, target entropy')
    parser.add_argument('--tau', type=float, default=0.02, help='SAC parameter, target net smoothing coefficient')
    parser.add_argument('--update_per', type=int, default=2, help='Perform one update (with one batch) per this many collected transitions')
    parser.add_argument('--batch', type=int, default=256, help='Batch size for one update')
    parser.add_argument('--mem_size', type=int, default=int(1e6), help='Size of replay memory')
    parser.add_argument('--gpuid', type=int, default=0, help='ID of the GPU to train the policy on. CPU will be used if gpuid < 0')
    parser.add_argument('--rfunc', type=str, default='default', help='Name of the reward function in src/env/rfuncs.py')
    parser.add_argument('--path', type=str, default='', help='Path relative to \'/training_data\' in which to save the training logs. If not specified, a new folder named SAC{id} will be created.')
    parser.add_argument('--actor_hiddens', type=int, nargs='+', default=[256, 256], help='Number of units in each hidden layer of the actor net')
    parser.add_argument('--critic_hiddens', type=int, nargs='+', default=[256, 256], help='Number of units in each hidden layer of the critic net')
    parser.add_argument('--gen_period', type=int, default=20000, help='Period of saving level generation results')
    parser.add_argument('--periodic_gen_num', type=int, default=200, help='Number of levels to generate for each periodic evaluation')
    parser.add_argument('--redirect', action='store_true', help='If set, redirect the STD log to log.txt')
    parser.add_argument(
        '--check_points', type=int, nargs='+',
        help='Check points at which to save the policy, specified by numbers of time steps.'
    )

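# Hypothetical example of how these common flags might appear on the command
# line (the entry script name and sub-command are assumptions for illustration,
# not taken from this file):
#
#   python train.py AsyncSAC --n_workers 20 --eplen 50 --budget 1000000 \
#       --gamma 0.9 --batch 256 --gpuid 0 --rfunc default --redirect
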
def drl_train(foo):
    """
    DRL training decorator. `foo` is the wrapped agent-builder function, e.g. train_AsyncSAC.
    """
    def __inner(args):
        if not args.path:
            path = auto_dire('training_data', args.name)
        else:
            path = getpath('training_data', args.path)
            os.makedirs(path, exist_ok=True)
        if os.path.exists(f'{path}/policy.pth'):
            print(f'Training at <{path}> is skipped as a finished trial already exists there.')
            return
        device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'

        evalpool = AsycSimltPool(args.n_workers, args.queuesize, args.rfunc, verbose=False)
        rfunc = getattr(importlib.import_module('src.env.rfuncs'), args.rfunc)()
        env = AsyncOlGenEnv(rfunc.get_n(), get_decoder('models/decoder.pth'), evalpool, args.eplen, device=device)
        loggers = [
            AsyncCsvLogger(f'{path}/log.csv', rfunc),
            AsyncStdLogger(rfunc, 2000, f'{path}/log.txt' if args.redirect else '')
        ]
        if args.periodic_gen_num > 0:
            loggers.append(GenResLogger(path, args.periodic_gen_num, args.gen_period))
        with open(path + '/run_configuration.txt', 'w') as f:
            f.write(time.strftime('%Y-%m-%d %H:%M') + '\n')
            f.write(f'---------{args.name}---------\n')
            args_strlines = [
                f'{key}={val}\n' for key, val in vars(args).items()
                if key not in {'name', 'rfunc', 'path', 'entry'}
            ]
            f.writelines(args_strlines)
            f.write('-' * 50 + '\n')
            f.write(str(rfunc))
        N = rfunc.get_n()
        with open(f'{path}/cfgs.json', 'w') as f:
            data = {'N': N, 'gamma': args.gamma, 'h': args.eplen, 'rfunc': args.rfunc}
            if args.name == 'MESAC':
                data.update({'m': args.m, 'lambda': args.lbd, 'me_type': args.me_type})
            json.dump(data, f)
        obs_dim, act_dim = env.histlen * nz, nz

        # Build the agent with foo; the returned object is an ActCrtAgent
        agent = foo(args, path, device, obs_dim, act_dim)

        agent.to(device)
        trainer = AsyncOffpolicyTrainer(
            ReplayMem(args.mem_size, device=device), update_per=args.update_per, batch=args.batch
        )
        trainer.set_loggers(*loggers)
        _, timecost = record_time(trainer.train)(env, agent, args.budget, path, check_points=args.check_points)
    return __inner

############### AsyncSAC ###############
def set_AsyncSAC_parser(parser):
    set_common_args(parser)
    parser.add_argument('--name', type=str, default='AsyncSAC', help='Name of this algorithm.')

# The same SAC training, but asynchronous
@drl_train
def train_AsyncSAC(args, path, device, obs_dim, act_dim):
    actor = SoftActor(
        lambda: GaussianMLP(obs_dim, act_dim, args.actor_hiddens), tar_ent=args.tar_entropy
    )
    critic = SoftDoubleClipCriticQ(
        lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens), gamma=args.gamma, tau=args.tau
    )
    with open(f'{path}/nn_architecture.txt', 'w') as f:
        f.writelines([
            '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
            '-' * 24 + 'Critic-Q' + '-' * 24 + '\n', critic.get_nn_arch_str()
        ])
    return SAC(actor, critic, device)

############## NCESAC ##############
def set_NCESAC_parser(parser):
    set_common_args(parser)
    parser.add_argument('--name', type=str, default='NCESAC', help='Name of this algorithm.')
    parser.add_argument('--lbd', type=float, default=0.2, help='Weight of the mutual exclusion regularisation')
    parser.add_argument('--m', type=int, default=2, help='Number of ensemble heads in the actor')
    parser.add_argument('--me_type', type=str, default='clip', choices=['log', 'clip', 'logclip'], help='Type of mutual exclusion regularisation')
    parser.add_argument('--actor_net_type', type=str, default='mlp', choices=['mlp', 'conv'], help='Type of actor\'s NN')

@drl_train
def train_NCESAC(args, path, device, obs_dim, act_dim):
    me_reg, actor_nn_constructor = None, None
    
    # Initialise the chosen mutual-exclusion regulariser
    if args.me_type == 'log':
        me_reg = LogWassersteinExclusion(args.lbd)
    elif args.me_type == 'clip':
        me_reg = ClipExclusion(args.lbd)
    elif args.me_type == 'logclip':
        me_reg = LogClipExclusion(args.lbd)
        
    # Initialise the actor network constructor
    if args.actor_net_type == 'conv':
        actor_nn_constructor = lambda: EsmbGaussianConv(
            obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m
        )
    elif args.actor_net_type == 'mlp':
        actor_nn_constructor = lambda: EsmbGaussianMLP(
            obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m
        )
        
    # Initialise the actor
    actor = MERegMixSoftActor(actor_nn_constructor, me_reg, tar_ent=args.tar_entropy)
    
    # Initialise the critics
    critic = MERegSoftDoubleClipCriticQ(
        lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
        gamma=args.gamma, tau=args.tau
    )
    critic_U = MERegDoubleClipCriticW(
        lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
        gamma=args.gamma, tau=args.tau
    )
    
    # Save the neural network architectures
    with open(f'{path}/nn_architecture.txt', 'w') as f:
        f.writelines([
            '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
            '-' * 24 + 'Critic-Q' + '-' * 24 + '\n', critic.get_nn_arch_str(),
            '-' * 24 + 'Critic-U' + '-' * 24 + '\n', critic_U.get_nn_arch_str()
        ])
    return MESAC(actor, critic, critic_U, device)


############## PMOESAC ##############
def set_PMOESAC_parser(parser):
    set_common_args(parser)
    parser.add_argument('--name', type=str, default='PMOESAC', help='Name of this algorithm.')
    parser.add_argument('--m', type=int, default=5, help='Number of ensemble heads in the actor')

@drl_train
def train_PMOESAC(args, path, device, obs_dim, act_dim):
    actor = PMOESoftActor(
        lambda: EsmbGaussianMLP(obs_dim, act_dim, args.actor_hiddens, args.actor_hiddens, args.m),
        tar_ent=args.tar_entropy
    )
    critic = SoftDoubleClipCriticQ(
        lambda : ObsActMLP(obs_dim, act_dim, args.critic_hiddens),
        gamma=args.gamma, tau=args.tau
    )
    with open(f'{path}/nn_architecture.txt', 'w') as f:
        f.writelines([
            '-' * 24 + 'Actor' + '-' * 24 + '\n', actor.get_nn_arch_str(),
            '-' * 24 + 'Critic-Q' + '-' * 24 + '\n', critic.get_nn_arch_str()
        ])
    return SAC(actor, critic, device)
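

# Minimal sketch of a command-line entry point, assuming the parser setters and
# training functions above are meant to be wired up as argparse sub-commands
# (the sub-command names and this __main__ layout are assumptions for
# illustration; the real project may register these entries elsewhere).
if __name__ == '__main__':
    import argparse

    main_parser = argparse.ArgumentParser()
    subparsers = main_parser.add_subparsers(dest='entry', required=True)
    entries = {
        'AsyncSAC': (set_AsyncSAC_parser, train_AsyncSAC),
        'NCESAC': (set_NCESAC_parser, train_NCESAC),
        'PMOESAC': (set_PMOESAC_parser, train_PMOESAC),
    }
    for entry_name, (set_parser, _) in entries.items():
        # Each set_*_parser attaches the shared flags plus its algorithm-specific ones.
        set_parser(subparsers.add_parser(entry_name))
    parsed = main_parser.parse_args()
    # The decorated train_* functions each take the parsed args and run a full training trial.
    entries[parsed.entry][1](parsed)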