from typing import List, Dict, Any, Tuple, Union
from collections import namedtuple
import torch
from torch import nn
from copy import deepcopy
from ding.torch_utils import Adam, to_device
from ding.rl_utils import get_train_sample
from ding.utils import POLICY_REGISTRY, deep_merge_dicts
from ding.utils.data import default_collate, default_decollate
from ding.policy import Policy
from ding.model import model_wrap
from ding.policy.common_utils import default_preprocess_learn
from .utils import imagine, compute_target, compute_actor_loss, RewardEMA, tensorstats


@POLICY_REGISTRY.register('dreamer')
class DREAMERPolicy(Policy):
    r"""
    Overview:
        Policy class of the Dreamer algorithm. The actor and critic are trained
        on imagined rollouts produced by a learned world model.
    """
    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='dreamer',
        # (bool) Whether to use cuda for network and loss computation.
        cuda=False,
        # (int) Number of training samples (randomly collected) in replay buffer when training starts.
        random_collect_size=5000,
        # (bool) Whether to need policy-specific data in preprocess transition.
        transition_with_policy_data=False,
        # (int) Horizon (number of latent steps) of each imagined rollout.
        imag_horizon=15,
        learn=dict(
            # (float) Lambda for TD-lambda return.
            lambda_=0.95,
            # (float) Max norm of gradients.
            grad_clip=100,
            # (float) Learning rate of both the actor and the critic optimizer.
            learning_rate=3e-5,
            # (int) Number of sequences per training batch.
            batch_size=16,
            # (int) Length of each training sequence.
            batch_length=64,
            # (bool) Whether to sample actions (rather than take the mode) during imagination.
            imag_sample=True,
            # (bool) Whether to regularize the critic towards a slowly-updated copy of itself.
            slow_value_target=True,
            # (int) Update the slow critic every ``slow_target_update`` updates.
            slow_target_update=1,
            # (float) Fraction of the online critic mixed into the slow critic per update.
            slow_target_fraction=0.02,
            # (float) Discount factor.
            discount=0.997,
            # (bool) Whether to scale returns by an EMA of their percentile range.
            reward_EMA=True,
            # (float) Coefficient of the actor entropy bonus.
            actor_entropy=3e-4,
            # (float) Coefficient of the latent-state entropy bonus.
            actor_state_entropy=0.0,
            # (float) Coefficient of the value regularization term.
            value_decay=0.0,
        ),
    )
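
    # A minimal usage sketch (``my_dreamer_vac_model`` is hypothetical): the
    # class-level defaults above are merged with user overrides, e.g. via
    # ``deep_merge_dicts`` imported at the top of this file:
    #
    #     from easydict import EasyDict
    #     cfg = EasyDict(deep_merge_dicts(DREAMERPolicy.config, dict(cuda=True)))
    #     policy = DREAMERPolicy(cfg, model=my_dreamer_vac_model)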

    def default_model(self) -> Tuple[str, List[str]]:
        return 'dreamervac', ['ding.model.template.vac']

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Init the optimizer, algorithm config, main and target models.
        """
        # Algorithm config
        self._lambda = self._cfg.learn.lambda_
        self._grad_clip = self._cfg.learn.grad_clip
        self._critic = self._model.critic
        self._actor = self._model.actor

        if self._cfg.learn.slow_value_target:
            self._slow_value = deepcopy(self._critic)
            self._updates = 0

        # Optimizer
        self._optimizer_value = Adam(
            self._critic.parameters(),
            lr=self._cfg.learn.learning_rate,
        )
        self._optimizer_actor = Adam(
            self._actor.parameters(),
            lr=self._cfg.learn.learning_rate,
        )

        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._forward_learn_cnt = 0

        if self._cfg.learn.reward_EMA:
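            # RewardEMA tracks an exponential moving average of the 5th and
            # 95th percentiles of the returns (logged as EMA_005 / EMA_095)
            # and is used to scale advantages in the actor loss.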
            self.reward_ema = RewardEMA(device=self._device)

    def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]:
        # log dict
        log_vars = {}
        self._learn_model.train()
        self._update_slow_target()
        self._actor.requires_grad_(requires_grad=True)
        # start is a dict of {stoch, deter, logit}
        if self._cuda:
            start = to_device(start, self._device)
        # train self._actor
        imag_feat, imag_state, imag_action = imagine(
            self._cfg.learn, world_model, start, self._actor, self._cfg.imag_horizon
        )
        reward = world_model.heads["reward"](world_model.dynamics.get_feat(imag_state)).mode()
        actor_ent = self._actor(imag_feat).entropy()
        state_ent = world_model.dynamics.get_dist(imag_state).entropy()
        # The target is not scaled here; ``slow`` indicates whether the slow
        # critic is used when computing the lambda-return.
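        # For reference, compute_target builds the standard TD(lambda) return
        #     R_t = r_t + gamma_t * ((1 - lambda) * V(s_{t+1}) + lambda * R_{t+1}),
        # bootstrapped with the critic value at the imagination horizon.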
        target, weights, base = compute_target(
            self._cfg.learn, world_model, self._critic, imag_feat, imag_state, reward, actor_ent, state_ent
        )
        actor_loss, mets = compute_actor_loss(
            self._cfg.learn,
            self._actor,
            self.reward_ema,
            imag_feat,
            imag_state,
            imag_action,
            target,
            actor_ent,
            state_ent,
            weights,
            base,
        )
        log_vars.update(mets)
        value_input = imag_feat

        self._actor.requires_grad_(requires_grad=False)
        self._critic.requires_grad_(requires_grad=True)
        value = self._critic(value_input[:-1].detach())
        # TODO
        # target = torch.stack(target, dim=1)
        # (time, batch, 1), (time, batch, 1) -> (time, batch)
        value_loss = -value.log_prob(target.detach())
        if self._cfg.learn.slow_value_target:
            # Only query the slow critic when it exists; it is created in
            # ``_init_learn`` solely when ``slow_value_target`` is enabled.
            slow_target = self._slow_value(value_input[:-1].detach())
            value_loss = value_loss - value.log_prob(slow_target.mode().detach())
        if self._cfg.learn.value_decay:
            value_loss += self._cfg.learn.value_decay * value.mode()
        # (time, batch, 1), (time, batch, 1) -> (1,)
        value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])
        self._critic.requires_grad_(requires_grad=False)

        log_vars.update(tensorstats(value.mode(), "value"))
        log_vars.update(tensorstats(target, "target"))
        log_vars.update(tensorstats(reward, "imag_reward"))
        log_vars.update(tensorstats(imag_action, "imag_action"))
        log_vars["actor_ent"] = torch.mean(actor_ent).detach().cpu().numpy().item()
        # ====================
        # actor-critic update
        # ====================
        self._model.requires_grad_(requires_grad=True)
        world_model.requires_grad_(requires_grad=True)

        loss_dict = {
            'critic_loss': value_loss,
            'actor_loss': actor_loss,
        }

        norm_dict = self._update(loss_dict)

        self._model.requires_grad_(requires_grad=False)
        world_model.requires_grad_(requires_grad=False)
        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        return {
            **log_vars,
            **norm_dict,
            **loss_dict,
        }

    def _update(self, loss_dict):
        # update actor
        self._optimizer_actor.zero_grad()
        loss_dict['actor_loss'].backward()
        actor_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
        self._optimizer_actor.step()
        # update critic
        self._optimizer_value.zero_grad()
        loss_dict['critic_loss'].backward()
        critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
        self._optimizer_value.step()
        return {'actor_grad_norm': actor_norm, 'critic_grad_norm': critic_norm}

    def _update_slow_target(self):
        if self._cfg.learn.slow_value_target:
            if self._updates % self._cfg.learn.slow_target_update == 0:
                mix = self._cfg.learn.slow_target_fraction
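                # Worked example: with the defaults slow_target_fraction=0.02
                # and slow_target_update=1, every call mixes 2% of the online
                # critic into the slow copy, i.e. d <- 0.02 * s + 0.98 * d,
                # an exponential moving average with an effective horizon of
                # roughly 1 / 0.02 = 50 updates.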
                for s, d in zip(self._critic.parameters(), self._slow_value.parameters()):
                    d.data = mix * s.data + (1 - mix) * d.data
            self._updates += 1

    def _state_dict_learn(self) -> Dict[str, Any]:
        ret = {
            'model': self._learn_model.state_dict(),
            'optimizer_value': self._optimizer_value.state_dict(),
            'optimizer_actor': self._optimizer_actor.state_dict(),
        }
        return ret

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        self._learn_model.load_state_dict(state_dict['model'])
        self._optimizer_value.load_state_dict(state_dict['optimizer_value'])
        self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])

    def _init_collect(self) -> None:
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(self._model, wrapper_name='base')
        self._collect_model.reset()

    def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=None) -> dict:
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()

        if state is None:
            batch_size = len(data_id)
            latent = world_model.dynamics.initial(batch_size)  # {logit, stoch, deter}
            action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device)
        else:
            # ``state`` is a per-env list of (latent, action) tuples.
            latent = to_device(default_collate(list(zip(*state))[0]), self._device)
            action = to_device(default_collate(list(zip(*state))[1]), self._device)
            if len(action.shape) == 1:
                action = action.unsqueeze(-1)
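            # At episode boundaries, zero the recurrent latent and the last
            # action of the affected envs so the world model restarts from
            # its initial state.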
            if reset.any():
                mask = 1 - reset
                for key in latent.keys():
                    for i in range(latent[key].shape[0]):
                        latent[key][i] *= mask[i]
                for i in range(len(action)):
                    action[i] *= mask[i]

        # Center observations (e.g. pixels in [0, 1]) around zero.
        data = data - 0.5
        embed = world_model.encoder(data)
        latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
        feat = world_model.dynamics.get_feat(latent)
        actor = self._actor(feat)
        action = actor.sample()
        logprob = actor.log_prob(action)
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()
        state = (latent, action)
        output = {"action": action, "logprob": logprob, "state": state}
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step).
        Returns:
            - transition (:obj:`dict`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'action': model_output['action'],
            # TODO(zp): random collect only provides 'action', so 'logprob' is omitted for now.
            # 'logprob': model_output['logprob'],
            'reward': timestep.reward,
            'discount': timestep.info['discount'],
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
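        # get_train_sample cuts the collected trajectory into samples of
        # length ``self._unroll_len`` (DI-engine's standard sample slicing).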
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict:
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()

        if state is None:
            batch_size = len(data_id)
            latent = world_model.dynamics.initial(batch_size)  # {logit, stoch, deter}
            action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device)
        else:
            # ``state`` is a per-env list of (latent, action) tuples.
            latent = to_device(default_collate(list(zip(*state))[0]), self._device)
            action = to_device(default_collate(list(zip(*state))[1]), self._device)
            if len(action.shape) == 1:
                action = action.unsqueeze(-1)
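            # As in collect mode, zero the recurrent state of envs that have
            # just been reset.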
            if reset.any():
                mask = 1 - reset
                for key in latent.keys():
                    for i in range(latent[key].shape[0]):
                        latent[key][i] *= mask[i]
                for i in range(len(action)):
                    action[i] *= mask[i]

        # Center observations (e.g. pixels in [0, 1]) around zero.
        data = data - 0.5
        embed = world_model.encoder(data)
        latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
        feat = world_model.dynamics.get_feat(latent)
        actor = self._actor(feat)
        # Evaluation acts deterministically: take the mode instead of sampling.
        action = actor.mode()
        logprob = actor.log_prob(action)
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()
        state = (latent, action)
        output = {"action": action, "logprob": logprob, "state": state}
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return the names of the variables that are monitored during training.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        return [
            'normed_target_mean', 'normed_target_std', 'normed_target_min', 'normed_target_max', 'EMA_005', 'EMA_095',
            'actor_entropy', 'actor_state_entropy', 'value_mean', 'value_std', 'value_min', 'value_max', 'target_mean',
            'target_std', 'target_min', 'target_max', 'imag_reward_mean', 'imag_reward_std', 'imag_reward_min',
            'imag_reward_max', 'imag_action_mean', 'imag_action_std', 'imag_action_min', 'imag_action_max', 'actor_ent',
            'actor_loss', 'critic_loss', 'actor_grad_norm', 'critic_grad_norm'
        ]
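

# A minimal interaction sketch (the ``world_model`` training call below is
# hypothetical and depends on the surrounding pipeline): a model-based loop
# first updates the world model on replayed sequences, then trains the
# actor-critic on imagined rollouts started from the resulting posterior
# states:
#
#     post = world_model.train(batch)  # hypothetical world-model update
#     log_vars = policy._forward_learn(post, world_model, envstep)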