import torch
from torch import nn
from torch.nn import functional as F

from agent.helpers import init_weights


class VAE(nn.Module):
    """Conditional VAE: encodes (action, state) pairs into a latent and
    decodes (latent, state) back into an action."""

    def __init__(self, state_dim, action_dim, device, hidden_size=256) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.action_dim = action_dim

        # The encoder sees the action concatenated with the conditioning state.
        input_dim = state_dim + action_dim

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, hidden_size),
            nn.Mish(),
        )

        # The latent dimension equals hidden_size.
        self.fc_mu = nn.Linear(hidden_size, hidden_size)
        self.fc_var = nn.Linear(hidden_size, hidden_size)

        # The decoder is conditioned on the state as well.
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size + state_dim, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, hidden_size),
            nn.Mish(),
        )

        self.final_layer = nn.Sequential(nn.Linear(hidden_size, action_dim))

        self.apply(init_weights)

        self.device = device

    def encode(self, action, state):
        x = torch.cat([action, state], dim=-1)
        result = self.encoder(x)
        # No-op for 2-D (batch, features) inputs; kept for parity with the
        # original code.
        result = torch.flatten(result, start_dim=1)

        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return mu, log_var

    def decode(self, z, state):
        x = torch.cat([z, state], dim=-1)
        result = self.decoder(x)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: z = mu + std * eps with eps ~ N(0, I).

        A single z per input gives the usual one-sample Monte Carlo estimate
        of the expectation in the loss, which is standard practice.

        :param mu: (Tensor) mean of the latent Gaussian
        :param logvar: (Tensor) log-variance of the latent Gaussian
        :return: (Tensor) sampled latent z
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def loss(self, action, state):
        mu, log_var = self.encode(action, state)
        z = self.reparameterize(mu, log_var)
        recons = self.decode(z, state)

        kld_weight = 0.1
        recons_loss = F.mse_loss(recons, action)

        # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims
        # and averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return loss

    def forward(self, state, eval=False):
        batch_size = state.shape[0]
        shape = (batch_size, self.hidden_size)

        # At evaluation time decode the prior mean (z = 0) for a deterministic
        # action; otherwise sample z ~ N(0, I) from the prior.
        if eval:
            z = torch.zeros(shape, device=self.device)
        else:
            z = torch.randn(shape, device=self.device)
        samples = self.decode(z, state)

        return samples.clamp(-1., 1.)
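
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): the
# dimensions, batch size, and learning rate below are assumptions chosen for
# a quick smoke test, not values taken from the repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    state_dim, action_dim, batch_size = 17, 6, 32  # hypothetical sizes
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vae = VAE(state_dim, action_dim, device).to(device)
    optimizer = torch.optim.Adam(vae.parameters(), lr=3e-4)  # assumed lr

    # Fake batch: actions in [-1, 1] to match the clamp in forward().
    state = torch.randn(batch_size, state_dim, device=device)
    action = torch.rand(batch_size, action_dim, device=device) * 2.0 - 1.0

    # One gradient step on the KL-weighted VAE loss.
    loss = vae.loss(action, state)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Sampling: eval=True decodes the zero latent (deterministic per state).
    sampled = vae(state, eval=True)
    print(loss.item(), sampled.shape)  # -> scalar loss, (batch_size, action_dim)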