Spaces:
Sleeping
Sleeping
| from typing import List, Dict, Any, Tuple, Union, Optional | |
| from collections import namedtuple | |
| import torch | |
| import copy | |
| from ding.torch_utils import RMSprop, to_device | |
| from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ | |
| v_nstep_td_data, v_nstep_td_error, get_nstep_return_data | |
| from ding.model import model_wrap | |
| from ding.utils import POLICY_REGISTRY | |
| from ding.utils.data import timestep_collate, default_collate, default_decollate | |
| from .qmix import QMIXPolicy | |
| class MADQNPolicy(QMIXPolicy): | |
| config = dict( | |
| # (str) RL policy register name (refer to function "POLICY_REGISTRY"). | |
| type='madqn', | |
| # (bool) Whether to use cuda for network. | |
| cuda=True, | |
| # (bool) Whether the RL algorithm is on-policy or off-policy. | |
| on_policy=False, | |
| # (bool) Whether use priority(priority sample, IS weight, update priority) | |
| priority=False, | |
| # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. | |
| priority_IS_weight=False, | |
| nstep=3, | |
| learn=dict( | |
| update_per_collect=20, | |
| batch_size=32, | |
| learning_rate=0.0005, | |
| clip_value=100, | |
| # ============================================================== | |
| # The following configs is algorithm-specific | |
| # ============================================================== | |
| # (float) Target network update momentum parameter. | |
| # in [0, 1]. | |
| target_update_theta=0.008, | |
| # (float) The discount factor for future rewards, | |
| # in [0, 1]. | |
| discount_factor=0.99, | |
| # (bool) Whether to use double DQN mechanism(target q for surpassing over estimation) | |
| double_q=False, | |
| weight_decay=1e-5, | |
| ), | |
| collect=dict( | |
| # (int) Only one of [n_sample, n_episode] shoule be set | |
| n_episode=32, | |
| # (int) Cut trajectories into pieces with length "unroll_len", the length of timesteps | |
| # in each forward when training. In qmix, it is greater than 1 because there is RNN. | |
| unroll_len=10, | |
| ), | |
| eval=dict(), | |
| other=dict( | |
| eps=dict( | |
| # (str) Type of epsilon decay | |
| type='exp', | |
| # (float) Start value for epsilon decay, in [0, 1]. | |
| # 0 means not use epsilon decay. | |
| start=1, | |
| # (float) Start value for epsilon decay, in [0, 1]. | |
| end=0.05, | |
| # (int) Decay length(env step) | |
| decay=50000, | |
| ), | |
| replay_buffer=dict( | |
| replay_buffer_size=5000, | |
| # (int) The maximum reuse times of each data | |
| max_reuse=1e+9, | |
| max_staleness=1e+9, | |
| ), | |
| ), | |
| ) | |
| def default_model(self) -> Tuple[str, List[str]]: | |
| """ | |
| Overview: | |
| Return this algorithm default model setting for demonstration. | |
| Returns: | |
| - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names | |
| """ | |
| return 'madqn', ['ding.model.template.madqn'] | |
| def _init_learn(self) -> None: | |
| self._priority = self._cfg.priority | |
| self._priority_IS_weight = self._cfg.priority_IS_weight | |
| assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in QMIX" | |
| self._optimizer_current = RMSprop( | |
| params=self._model.current.parameters(), | |
| lr=self._cfg.learn.learning_rate, | |
| alpha=0.99, | |
| eps=0.00001, | |
| weight_decay=self._cfg.learn.weight_decay | |
| ) | |
| self._optimizer_cooperation = RMSprop( | |
| params=self._model.cooperation.parameters(), | |
| lr=self._cfg.learn.learning_rate, | |
| alpha=0.99, | |
| eps=0.00001, | |
| weight_decay=self._cfg.learn.weight_decay | |
| ) | |
| self._gamma = self._cfg.learn.discount_factor | |
| self._nstep = self._cfg.nstep | |
| self._target_model = copy.deepcopy(self._model) | |
| self._target_model = model_wrap( | |
| self._target_model, | |
| wrapper_name='target', | |
| update_type='momentum', | |
| update_kwargs={'theta': self._cfg.learn.target_update_theta} | |
| ) | |
| self._target_model = model_wrap( | |
| self._target_model, | |
| wrapper_name='hidden_state', | |
| state_num=self._cfg.learn.batch_size, | |
| init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)] | |
| ) | |
| self._learn_model = model_wrap( | |
| self._model, | |
| wrapper_name='hidden_state', | |
| state_num=self._cfg.learn.batch_size, | |
| init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)] | |
| ) | |
| self._learn_model.reset() | |
| self._target_model.reset() | |
| def _data_preprocess_learn(self, data: List[Any]) -> dict: | |
| r""" | |
| Overview: | |
| Preprocess the data to fit the required data format for learning | |
| Arguments: | |
| - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function | |
| Returns: | |
| - data (:obj:`Dict[str, Any]`): the processed data, from \ | |
| [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])} | |
| """ | |
| # data preprocess | |
| data = timestep_collate(data) | |
| if self._cuda: | |
| data = to_device(data, self._device) | |
| data['weight'] = data.get('weight', None) | |
| data['done'] = data['done'].float() | |
| return data | |
| def _forward_learn(self, data: dict) -> Dict[str, Any]: | |
| r""" | |
| Overview: | |
| Forward and backward function of learn mode. | |
| Arguments: | |
| - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \ | |
| np.ndarray or dict/list combinations. | |
| Returns: | |
| - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \ | |
| recorded in text log and tensorboard, values are python scalar or a list of scalars. | |
| ArgumentsKeys: | |
| - necessary: ``obs``, ``next_obs``, ``action``, ``reward``, ``weight``, ``prev_state``, ``done`` | |
| ReturnsKeys: | |
| - necessary: ``cur_lr``, ``total_loss`` | |
| - cur_lr (:obj:`float`): Current learning rate | |
| - total_loss (:obj:`float`): The calculated loss | |
| """ | |
| data = self._data_preprocess_learn(data) | |
| # ==================== | |
| # Q-mix forward | |
| # ==================== | |
| self._learn_model.train() | |
| self._target_model.train() | |
| # for hidden_state plugin, we need to reset the main model and target model | |
| self._learn_model.reset(state=data['prev_state'][0]) | |
| self._target_model.reset(state=data['prev_state'][0]) | |
| inputs = {'obs': data['obs'], 'action': data['action']} | |
| total_q = self._learn_model.forward(inputs, single_step=False)['total_q'] | |
| if self._cfg.learn.double_q: | |
| next_inputs = {'obs': data['next_obs']} | |
| self._learn_model.reset(state=data['prev_state'][1]) | |
| logit_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach() | |
| next_inputs = {'obs': data['next_obs'], 'action': logit_detach.argmax(dim=-1)} | |
| else: | |
| next_inputs = {'obs': data['next_obs']} | |
| with torch.no_grad(): | |
| target_total_q = self._target_model.forward(next_inputs, cooperation=True, single_step=False)['total_q'] | |
| if self._nstep == 1: | |
| v_data = v_1step_td_data(total_q, target_total_q, data['reward'], data['done'], data['weight']) | |
| loss, td_error_per_sample = v_1step_td_error(v_data, self._gamma) | |
| # for visualization | |
| with torch.no_grad(): | |
| if data['done'] is not None: | |
| target_v = self._gamma * (1 - data['done']) * target_total_q + data['reward'] | |
| else: | |
| target_v = self._gamma * target_total_q + data['reward'] | |
| else: | |
| data['reward'] = data['reward'].permute(0, 2, 1).contiguous() | |
| loss = [] | |
| td_error_per_sample = [] | |
| for t in range(self._cfg.collect.unroll_len): | |
| v_data = v_nstep_td_data( | |
| total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], self._gamma | |
| ) | |
| # calculate v_nstep_td critic_loss | |
| loss_i, td_error_per_sample_i = v_nstep_td_error(v_data, self._gamma, self._nstep) | |
| loss.append(loss_i) | |
| td_error_per_sample.append(td_error_per_sample_i) | |
| loss = sum(loss) / (len(loss) + 1e-8) | |
| td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8) | |
| self._optimizer_current.zero_grad() | |
| loss.backward() | |
| grad_norm = torch.nn.utils.clip_grad_norm_(self._model.current.parameters(), self._cfg.learn.clip_value) | |
| self._optimizer_current.step() | |
| # cooperation | |
| self._learn_model.reset(state=data['prev_state'][0]) | |
| self._target_model.reset(state=data['prev_state'][0]) | |
| cooperation_total_q = self._learn_model.forward(inputs, cooperation=True, single_step=False)['total_q'] | |
| next_inputs = {'obs': data['next_obs']} | |
| with torch.no_grad(): | |
| cooperation_target_total_q = self._target_model.forward( | |
| next_inputs, cooperation=True, single_step=False | |
| )['total_q'] | |
| if self._nstep == 1: | |
| v_data = v_1step_td_data( | |
| cooperation_total_q, cooperation_target_total_q, data['reward'], data['done'], data['weight'] | |
| ) | |
| cooperation_loss, _ = v_1step_td_error(v_data, self._gamma) | |
| else: | |
| cooperation_loss_all = [] | |
| for t in range(self._cfg.collect.unroll_len): | |
| v_data = v_nstep_td_data( | |
| cooperation_total_q[t], cooperation_target_total_q[t], data['reward'][t], data['done'][t], | |
| data['weight'], self._gamma | |
| ) | |
| cooperation_loss, _ = v_nstep_td_error(v_data, self._gamma, self._nstep) | |
| cooperation_loss_all.append(cooperation_loss) | |
| cooperation_loss = sum(cooperation_loss_all) / (len(cooperation_loss_all) + 1e-8) | |
| self._optimizer_cooperation.zero_grad() | |
| cooperation_loss.backward() | |
| cooperation_grad_norm = torch.nn.utils.clip_grad_norm_( | |
| self._model.cooperation.parameters(), self._cfg.learn.clip_value | |
| ) | |
| self._optimizer_cooperation.step() | |
| # ============= | |
| # after update | |
| # ============= | |
| self._target_model.update(self._learn_model.state_dict()) | |
| return { | |
| 'cur_lr': self._optimizer_current.defaults['lr'], | |
| 'total_loss': loss.item(), | |
| 'total_q': total_q.mean().item() / self._cfg.model.agent_num, | |
| 'target_total_q': target_total_q.mean().item() / self._cfg.model.agent_num, | |
| 'grad_norm': grad_norm, | |
| 'cooperation_grad_norm': cooperation_grad_norm, | |
| 'cooperation_loss': cooperation_loss.item(), | |
| } | |
| def _reset_learn(self, data_id: Optional[List[int]] = None) -> None: | |
| r""" | |
| Overview: | |
| Reset learn model to the state indicated by data_id | |
| Arguments: | |
| - data_id (:obj:`Optional[List[int]]`): The id that store the state and we will reset\ | |
| the model state to the state indicated by data_id | |
| """ | |
| self._learn_model.reset(data_id=data_id) | |
| def _state_dict_learn(self) -> Dict[str, Any]: | |
| r""" | |
| Overview: | |
| Return the state_dict of learn mode, usually including model and optimizer. | |
| Returns: | |
| - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. | |
| """ | |
| return { | |
| 'model': self._learn_model.state_dict(), | |
| 'target_model': self._target_model.state_dict(), | |
| 'optimizer_current': self._optimizer_current.state_dict(), | |
| 'optimizer_cooperation': self._optimizer_cooperation.state_dict(), | |
| } | |
| def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: | |
| """ | |
| Overview: | |
| Load the state_dict variable into policy learn mode. | |
| Arguments: | |
| - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. | |
| .. tip:: | |
| If you want to only load some parts of model, you can simply set the ``strict`` argument in \ | |
| load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \ | |
| complicated operation. | |
| """ | |
| self._learn_model.load_state_dict(state_dict['model']) | |
| self._target_model.load_state_dict(state_dict['target_model']) | |
| self._optimizer_current.load_state_dict(state_dict['optimizer_current']) | |
| self._optimizer_cooperation.load_state_dict(state_dict['optimizer_cooperation']) | |
| def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: | |
| r""" | |
| Overview: | |
| Generate dict type transition data from inputs. | |
| Arguments: | |
| - obs (:obj:`Any`): Env observation | |
| - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state'] | |
| - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\ | |
| (here 'obs' indicates obs after env step). | |
| Returns: | |
| - transition (:obj:`dict`): Dict type transition data, including 'obs', 'next_obs', 'prev_state',\ | |
| 'action', 'reward', 'done' | |
| """ | |
| transition = { | |
| 'obs': obs, | |
| 'next_obs': timestep.obs, | |
| 'prev_state': model_output['prev_state'], | |
| 'action': model_output['action'], | |
| 'reward': timestep.reward, | |
| 'done': timestep.done, | |
| } | |
| return transition | |
| def _get_train_sample(self, data: list) -> Union[None, List[Any]]: | |
| r""" | |
| Overview: | |
| Get the train sample from trajectory. | |
| Arguments: | |
| - data (:obj:`list`): The trajectory's cache | |
| Returns: | |
| - samples (:obj:`dict`): The training samples generated | |
| """ | |
| if self._cfg.nstep == 1: | |
| return get_train_sample(data, self._unroll_len) | |
| else: | |
| data = get_nstep_return_data(data, self._nstep, gamma=self._gamma) | |
| return get_train_sample(data, self._unroll_len) | |
| def _monitor_vars_learn(self) -> List[str]: | |
| r""" | |
| Overview: | |
| Return variables' name if variables are to used in monitor. | |
| Returns: | |
| - vars (:obj:`List[str]`): Variables' name list. | |
| """ | |
| return [ | |
| 'cur_lr', 'total_loss', 'total_q', 'target_total_q', 'grad_norm', 'target_reward_total_q', | |
| 'cooperation_grad_norm', 'cooperation_loss' | |
| ] | |