| """Credit: Note the following vae model is modified from https://github.com/AntixK/PyTorch-VAE""" | |
| import torch | |
| from torch.nn import functional as F | |
| from torch import nn | |
| from abc import abstractmethod | |
| from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple, Optional | |
| from ding.utils.type_helper import Tensor | |


class VanillaVAE(nn.Module):
    """
    Overview:
        Implementation of Vanilla variational autoencoder for action reconstruction.
    Interfaces:
        ``__init__``, ``encode``, ``decode``, ``decode_with_obs``, ``reparameterize``, \
        ``forward``, ``loss_function``.
    """

    def __init__(
            self,
            action_shape: int,
            obs_shape: int,
            latent_size: int,
            hidden_dims: List = [256, 256],
            **kwargs
    ) -> None:
        super(VanillaVAE, self).__init__()
        self.action_shape = action_shape
        self.obs_shape = obs_shape
        self.latent_size = latent_size
        self.hidden_dims = hidden_dims

        # Build Encoder
        self.encode_action_head = nn.Sequential(nn.Linear(self.action_shape, hidden_dims[0]), nn.ReLU())
        self.encode_obs_head = nn.Sequential(nn.Linear(self.obs_shape, hidden_dims[0]), nn.ReLU())
        self.encode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[1]), nn.ReLU())
        self.encode_mu_head = nn.Linear(hidden_dims[1], latent_size)
        self.encode_logvar_head = nn.Linear(hidden_dims[1], latent_size)

        # Build Decoder
        self.decode_action_head = nn.Sequential(nn.Linear(latent_size, hidden_dims[-1]), nn.ReLU())
        self.decode_common = nn.Sequential(nn.Linear(hidden_dims[-1], hidden_dims[-2]), nn.ReLU())
        # TODO(pu): tanh
        self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[-2], self.action_shape), nn.Tanh())
        # residual prediction
        self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[-2], hidden_dims[-2]), nn.ReLU())
        self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[-2], self.obs_shape)

        self.obs_encoding = None

    def encode(self, input: Dict[str, Tensor]) -> Dict[str, Any]:
        """
        Overview:
            Encodes the input by passing through the encoder network and returns the latent codes.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords ``obs`` (:obj:`torch.Tensor`) and \
                ``action`` (:obj:`torch.Tensor`), representing the observation and agent's action respectively.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`), \
                ``log_var`` (:obj:`torch.Tensor`) and ``obs_encoding`` (:obj:`torch.Tensor`) \
                representing latent codes.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
            - action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``.
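        Examples:
            >>> # A minimal usage sketch; ``action_shape=4``, ``obs_shape=8`` and ``latent_size=6``
            >>> # below are illustrative assumptions, not fixed values.
            >>> vae = VanillaVAE(action_shape=4, obs_shape=8, latent_size=6)
            >>> encoding = vae.encode({'action': torch.randn(32, 4), 'obs': torch.randn(32, 8)})
            >>> assert encoding['mu'].shape == (32, 6) and encoding['log_var'].shape == (32, 6)
            >>> assert encoding['obs_encoding'].shape == (32, 256)  # H equals hidden_dims[0] by default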
| """ | |
| action_encoding = self.encode_action_head(input['action']) | |
| obs_encoding = self.encode_obs_head(input['obs']) | |
| # obs_encoding = self.condition_obs(input['obs']) # TODO(pu): using a different network | |
| input = obs_encoding * action_encoding # TODO(pu): what about add, cat? | |
| result = self.encode_common(input) | |
| # Split the result into mu and var components | |
| # of the latent Gaussian distribution | |
| mu = self.encode_mu_head(result) | |
| log_var = self.encode_logvar_head(result) | |
| return {'mu': mu, 'log_var': log_var, 'obs_encoding': obs_encoding} | |

    def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Maps the given latent action and observation encoding onto the original action space.
        Arguments:
            - z (:obj:`torch.Tensor`): the sampled latent action.
            - obs_encoding (:obj:`torch.Tensor`): observation encoding.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``reconstruction_action`` and ``predition_residual``.
        ReturnsKeys:
            - reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by the decoder.
            - predition_residual (:obj:`torch.Tensor`): the predicted observation residual.
        Shapes:
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``.
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``.
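        Examples:
            >>> # A minimal usage sketch; the constructor arguments and shapes below are illustrative assumptions.
            >>> vae = VanillaVAE(action_shape=4, obs_shape=8, latent_size=6)
            >>> outputs = vae.decode(torch.randn(32, 6), obs_encoding=torch.randn(32, 256))
            >>> assert outputs['reconstruction_action'].shape == (32, 4)
            >>> assert outputs['predition_residual'].shape == (32, 8)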
| """ | |
| action_decoding = self.decode_action_head(torch.tanh(z)) # NOTE: tanh, here z is not bounded | |
| action_obs_decoding = action_decoding * obs_encoding | |
| action_obs_decoding_tmp = self.decode_common(action_obs_decoding) | |
| reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp) | |
| predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp) | |
| predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp) | |
| return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual} | |

    def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Maps the given latent action and observation onto the original action space, \
            using ``self.encode_obs_head(obs)`` to compute the observation encoding.
        Arguments:
            - z (:obj:`torch.Tensor`): the sampled latent action.
            - obs (:obj:`torch.Tensor`): observation.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``reconstruction_action`` and ``predition_residual``.
        ReturnsKeys:
            - reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by the VAE.
            - predition_residual (:obj:`torch.Tensor`): the observation residual predicted by the VAE.
        Shapes:
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``.
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``obs_shape``.
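        Examples:
            >>> # A minimal usage sketch; the constructor arguments and shapes below are illustrative assumptions.
            >>> # Here z is assumed to be already bounded in [-1, 1], e.g. produced by a tanh-squashed policy.
            >>> vae = VanillaVAE(action_shape=4, obs_shape=8, latent_size=6)
            >>> outputs = vae.decode_with_obs(torch.rand(32, 6) * 2 - 1, torch.randn(32, 8))
            >>> assert outputs['reconstruction_action'].shape == (32, 4)
            >>> assert outputs['predition_residual'].shape == (32, 8)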
| """ | |
| obs_encoding = self.encode_obs_head(obs) | |
| # TODO(pu): here z is already bounded, z is produced by td3 policy, it has been operated by tanh | |
| action_decoding = self.decode_action_head(z) | |
| action_obs_decoding = action_decoding * obs_encoding | |
| action_obs_decoding_tmp = self.decode_common(action_obs_decoding) | |
| reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp) | |
| predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp) | |
| predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp) | |
| return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual} | |

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Overview:
            Reparameterization trick to sample from :math:`N(mu, var)` by sampling from :math:`N(0, 1)`.
        Arguments:
            - mu (:obj:`torch.Tensor`): mean of the latent Gaussian.
            - logvar (:obj:`torch.Tensor`): log-variance of the latent Gaussian.
        Returns:
            - z (:obj:`torch.Tensor`): the sampled latent variable.
        Shapes:
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``.
            - logvar (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``.
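        Examples:
            >>> # A minimal usage sketch; the shapes below are illustrative assumptions.
            >>> vae = VanillaVAE(action_shape=4, obs_shape=8, latent_size=6)
            >>> z = vae.reparameterize(mu=torch.zeros(32, 6), logvar=torch.zeros(32, 6))
            >>> assert z.shape == (32, 6)  # with logvar=0, z = mu + eps where eps ~ N(0, 1)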
| """ | |
| std = torch.exp(0.5 * logvar) | |
| eps = torch.randn_like(std) | |
| return eps * std + mu | |

    def forward(self, input: Dict[str, Tensor], **kwargs) -> dict:
        """
        Overview:
            Encode the input, reparameterize ``mu`` and ``log_var`` into a latent ``z``, \
            then decode ``z`` together with ``obs_encoding``.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords ``obs`` (:obj:`torch.Tensor`) \
                and ``action`` (:obj:`torch.Tensor`), representing the observation \
                and agent's action respectively.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
                (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
                ``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \
                ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is \
                ``observation dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``.
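        Examples:
            >>> # A minimal usage sketch; ``action_shape=4``, ``obs_shape=8`` and ``latent_size=6``
            >>> # below are illustrative assumptions, not fixed values.
            >>> vae = VanillaVAE(action_shape=4, obs_shape=8, latent_size=6)
            >>> outputs = vae({'action': torch.randn(32, 4), 'obs': torch.randn(32, 8)})
            >>> assert outputs['recons_action'].shape == (32, 4)
            >>> assert outputs['prediction_residual'].shape == (32, 8)
            >>> assert outputs['z'].shape == (32, 6)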
| """ | |
| encode_output = self.encode(input) | |
| z = self.reparameterize(encode_output['mu'], encode_output['log_var']) | |
| decode_output = self.decode(z, encode_output['obs_encoding']) | |
| return { | |
| 'recons_action': decode_output['reconstruction_action'], | |
| 'prediction_residual': decode_output['predition_residual'], | |
| 'input': input, | |
| 'mu': encode_output['mu'], | |
| 'log_var': encode_output['log_var'], | |
| 'z': z | |
| } | |

    def loss_function(self, args: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
        """
        Overview:
            Computes the VAE loss function.
        Arguments:
            - args (:obj:`Dict[str, Tensor]`): Dict containing keywords ``recons_action``, ``prediction_residual``, \
                ``original_action``, ``mu``, ``log_var`` and ``true_residual``.
            - kwargs (:obj:`Dict`): Dict containing keywords ``kld_weight`` and ``predict_weight``.
        Returns:
            - outputs (:obj:`Dict[str, Tensor]`): Dict containing different loss results, including ``loss``, \
                ``reconstruction_loss``, ``kld_loss`` and ``predict_loss``.
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is \
                ``observation dim``.
            - original_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - true_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
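        Examples:
            >>> # A minimal usage sketch; the shapes, loss weights and the definition of the residual
            >>> # (``next_obs - obs``) below are illustrative assumptions.
            >>> vae = VanillaVAE(action_shape=4, obs_shape=8, latent_size=6)
            >>> obs, next_obs = torch.randn(32, 8), torch.randn(32, 8)
            >>> outputs = vae({'action': torch.randn(32, 4), 'obs': obs})
            >>> loss_args = {
            ...     'recons_action': outputs['recons_action'],
            ...     'prediction_residual': outputs['prediction_residual'],
            ...     'original_action': outputs['input']['action'],
            ...     'mu': outputs['mu'],
            ...     'log_var': outputs['log_var'],
            ...     'true_residual': next_obs - obs,
            ... }
            >>> losses = vae.loss_function(loss_args, kld_weight=0.01, predict_weight=0.01)
            >>> assert set(losses.keys()) == {'loss', 'reconstruction_loss', 'kld_loss', 'predict_loss'}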
| """ | |
| recons_action = args['recons_action'] | |
| prediction_residual = args['prediction_residual'] | |
| original_action = args['original_action'] | |
| mu = args['mu'] | |
| log_var = args['log_var'] | |
| true_residual = args['true_residual'] | |
| kld_weight = kwargs['kld_weight'] | |
| predict_weight = kwargs['predict_weight'] | |
| recons_loss = F.mse_loss(recons_action, original_action) | |
| kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0) | |
| predict_loss = F.mse_loss(prediction_residual, true_residual) | |
| loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss | |
| return {'loss': loss, 'reconstruction_loss': recons_loss, 'kld_loss': kld_loss, 'predict_loss': predict_loss} | |
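

if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the original model definition.
    # The shapes (action dim 4, obs dim 8, latent size 6, batch size 32), the loss weights and the
    # use of ``next_obs - obs`` as the prediction target are illustrative assumptions.
    vae = VanillaVAE(action_shape=4, obs_shape=8, latent_size=6)
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

    obs, next_obs = torch.randn(32, 8), torch.randn(32, 8)
    action = torch.rand(32, 4) * 2 - 1  # actions assumed to lie in [-1, 1], matching the Tanh output head

    # Forward pass and loss computation.
    outputs = vae({'action': action, 'obs': obs})
    losses = vae.loss_function(
        {
            'recons_action': outputs['recons_action'],
            'prediction_residual': outputs['prediction_residual'],
            'original_action': action,
            'mu': outputs['mu'],
            'log_var': outputs['log_var'],
            'true_residual': next_obs - obs,
        },
        kld_weight=0.01,
        predict_weight=0.01,
    )

    # One gradient step on the combined loss.
    optimizer.zero_grad()
    losses['loss'].backward()
    optimizer.step()
    print({k: v.item() for k, v in losses.items()})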