| from typing import Union, Dict, Optional, List | |
| from easydict import EasyDict | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from ding.utils import SequenceType, squeeze, MODEL_REGISTRY | |
| from ..common import RegressionHead, ReparameterizationHead | |
| from .vae import VanillaVAE | |
class BCQ(nn.Module):
    """
    Overview:
        Model of BCQ (Batch-Constrained deep Q-learning).
        Off-Policy Deep Reinforcement Learning without Exploration.
        https://arxiv.org/abs/1812.02900
    Interface:
        ``forward``, ``compute_actor``, ``compute_critic``, ``compute_vae``, ``compute_eval``
    Property:
        ``mode``
    """

    mode = ['compute_actor', 'compute_critic', 'compute_vae', 'compute_eval']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            actor_head_hidden_size: Optional[List] = None,
            critic_head_hidden_size: Optional[List] = None,
            activation: Optional[nn.Module] = nn.ReLU(),
            vae_hidden_dims: Optional[List] = None,
            phi: float = 0.05
    ) -> None:
        """
        Overview:
            Initialize the BCQ networks: twin Q critics, a perturbation actor and a VAE \
            used to sample in-distribution candidate actions.
        Arguments:
            - obs_shape (:obj:`int`): the dimension of observation state
            - action_shape (:obj:`int`): the dimension of action shape
            - actor_head_hidden_size (:obj:`list`): hidden sizes of the actor MLP, defaults to [400, 300]
            - critic_head_hidden_size (:obj:`list`): hidden sizes of each critic MLP, defaults to [400, 300]
            - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU().
            - vae_hidden_dims (:obj:`list`): hidden sizes of the VAE, defaults to [750, 750]
            - phi (:obj:`float`): maximal scale of the actor's perturbation (actions are in [-1, 1])
        """
        super(BCQ, self).__init__()
        # Resolve list defaults here instead of using mutable default arguments.
        if actor_head_hidden_size is None:
            actor_head_hidden_size = [400, 300]
        if critic_head_hidden_size is None:
            critic_head_hidden_size = [400, 300]
        if vae_hidden_dims is None:
            vae_hidden_dims = [750, 750]
        obs_shape: int = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.action_shape = action_shape
        self.input_size = obs_shape
        self.phi = phi
        # Both the critics and the actor consume the concatenation (obs, action).
        critic_input_size = self.input_size + action_shape
        # Twin critics (clipped double-Q), each mapping (obs, action) -> scalar Q.
        self.critic = nn.ModuleList(
            [self._build_mlp(critic_input_size, critic_head_hidden_size, 1, activation) for _ in range(2)]
        )
        # Bug fix: the actor must output one perturbation per action dimension.
        # The original emitted a single scalar (nn.Linear(d, 1)) that was silently
        # broadcast over all action dims in ``compute_actor``.
        self.actor = self._build_mlp(critic_input_size, actor_head_hidden_size, action_shape, activation)
        self.vae = VanillaVAE(action_shape, obs_shape, action_shape * 2, vae_hidden_dims)

    @staticmethod
    def _build_mlp(input_size: int, hidden_sizes: List, output_size: int, activation: nn.Module) -> nn.Sequential:
        """
        Overview:
            Build a plain MLP: ``Linear -> activation`` for each hidden size, then a final ``Linear``.
            Note: the same ``activation`` module instance is re-used at every position, which is
            fine for stateless activations such as ReLU.
        """
        layers = []
        in_dim = input_size
        for hidden_dim in hidden_sizes:
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(activation)
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, output_size))
        return nn.Sequential(*layers)

    def forward(self, inputs: Dict[str, torch.Tensor], mode: str) -> Dict[str, torch.Tensor]:
        """
        Overview:
            The unique execution (forward) method of BCQ method, and one can indicate different modes to implement \
            different computation graph, including ``compute_actor`` and ``compute_critic`` in BCQ.
        Mode compute_actor:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including action tensor.
        Mode compute_critic:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including q_value tensor.
        Mode compute_vae:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
                    (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
                    ``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \
                    ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Mode compute_eval:
            Arguments:
                - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
            Returns:
                - output (:obj:`Dict`): Output dict data, including action tensor.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model(inputs, mode='compute_actor')
            >>> outputs = model(inputs, mode='compute_critic')
            >>> outputs = model(inputs, mode='compute_vae')
            >>> outputs = model(inputs, mode='compute_eval')

        .. note::
            For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively.
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Use critic network to compute q value.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keyword ``q_value``: a list with one \
                tensor per critic (two critics in total).
        Shapes:
            - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
            - outputs (:obj:`Dict`): :math:`(B, N)`.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_critic(inputs)
        """
        obs, action = inputs['obs'], inputs['action']
        if len(action.shape) == 1:  # (B, ) -> (B, 1)
            action = action.unsqueeze(1)
        x = torch.cat([obs, action], dim=-1)
        # Bug fix: squeeze only the trailing singleton dim. A bare ``.squeeze()`` also
        # collapsed the batch dim when batch size == 1, which broke ``compute_eval``.
        x = [m(x).squeeze(-1) for m in self.critic]
        return {'q_value': x}

    def compute_actor(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Overview:
            Perturb a given (sampled) action: the actor outputs a bounded residual of scale \
            ``phi`` which is added to the input action and clamped to the valid range [-1, 1].
        Arguments:
            - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``action`` (:obj:`torch.Tensor`).
        Shapes:
            - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
            - outputs (:obj:`Dict`): :math:`(B, N)`.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_actor(inputs)
        """
        obs_action = torch.cat([inputs['obs'], inputs['action']], -1)
        # tanh bounds the raw output to (-1, 1); phi scales it to (-phi, phi).
        perturbation = self.phi * torch.tanh(self.actor(obs_action))
        action = (perturbation + inputs['action']).clamp(-1, 1)
        return {'action': action}

    def compute_vae(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            Use vae network to compute action.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` (:obj:`torch.Tensor`), \
                ``prediction_residual`` (:obj:`torch.Tensor`), ``input`` (:obj:`torch.Tensor`), \
                ``mu`` (:obj:`torch.Tensor`), ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Shapes:
            - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
            - outputs (:obj:`Dict`): :math:`(B, N)`.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_vae(inputs)
        """
        return self.vae.forward(inputs)

    def compute_eval(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Overview:
            BCQ evaluation: sample 100 candidate actions per observation from the VAE, \
            perturb them with the actor, and select the candidate with the highest Q value.
        Arguments:
            - inputs (:obj:`Dict`): Input dict data, including a batched obs tensor.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``action`` (:obj:`torch.Tensor`).
        Shapes:
            - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
            - outputs (:obj:`Dict`): :math:`(B, N)`.
        Examples:
            >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
            >>> model = BCQ(32, 6)
            >>> outputs = model.compute_eval(inputs)
        """
        obs = inputs['obs']
        # 100 candidate actions per observation, as in the BCQ reference implementation.
        obs_rep = obs.clone().unsqueeze(0).repeat_interleave(100, dim=0)
        # Latent samples are clamped to [-0.5, 0.5] before decoding (BCQ convention).
        z = torch.randn((obs_rep.shape[0], obs_rep.shape[1], self.action_shape * 2)).to(obs.device).clamp(-0.5, 0.5)
        sample_action = self.vae.decode_with_obs(z, obs_rep)['reconstruction_action']
        action = self.compute_actor({'obs': obs_rep, 'action': sample_action})['action']
        # Rank candidates with the first critic only.
        q = self.compute_critic({'obs': obs_rep, 'action': action})['q_value'][0]
        idx = q.argmax(dim=0).unsqueeze(0).unsqueeze(-1)
        idx = idx.repeat_interleave(action.shape[-1], dim=-1)
        # Bug fix: squeeze only the candidate dim, so a batch of size 1 keeps its batch dim.
        action = action.gather(0, idx).squeeze(0)
        return {'action': action}