# NOTE(review): removed non-Python paste residue ("Spaces:" / "Runtime error"
# status lines) that made this module unparseable.
import functools
from typing import List, Sequence

import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.linen.initializers import constant, orthogonal

from kinetix.models.action_spaces import HybridActionDistribution, MultiDiscreteActionDistribution
class ScannedRNN(nn.Module):
    """A GRU cell scanned over the leading (time) axis of its inputs.

    The ``nn.scan`` transform broadcasts the parameters across timesteps and
    maps ``__call__`` over axis 0 of ``x``, so a single invocation consumes a
    whole ``(time, batch, features)`` sequence while threading the carry.
    """

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Apply one GRU step.

        Args:
            carry: recurrent state of shape (batch, 256).
            x: tuple ``(ins, resets)`` — inputs of shape (batch, features) and
                a boolean episode-boundary mask of shape (batch,).

        Returns:
            ``(new_carry, outputs)`` from the GRU cell.
        """
        rnn_state = carry
        ins, resets = x
        # Zero the recurrent state wherever an episode ended so no
        # information leaks across episode boundaries.
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(ins.shape[0], 256),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(features=256)(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(batch_size, hidden_size=256):
        """Return a zeroed GRU carry of shape (batch_size, hidden_size)."""
        # Use a dummy key since the default state init fn is just zeros.
        # ``features`` must match ``hidden_size`` so the carry shape honors
        # the requested size (previously hard-coded to 256).
        cell = nn.GRUCell(features=hidden_size)
        return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))
class GeneralActorCriticRNN(nn.Module):
    """Shared actor-critic head: takes a pre-computed embedding and produces
    a policy distribution and a value estimate, optionally through a GRU.

    Attributes:
        action_dim: size of the (flat) action logit/mean output.
        fc_layer_depth: number of hidden Dense layers in each tower.
        fc_layer_width: width of those hidden layers.
        action_mode: "discrete", "continuous", "multi_discrete", or anything
            else for the hybrid discrete+continuous head.
        hybrid_action_continuous_dim: continuous dimension of the hybrid head.
        multi_discrete_number_of_dims_per_distribution: per-distribution sizes
            for the multi-discrete head.
        add_generator_embedding: unsupported here; raises if set.
        recurrent: when True, pass the embedding through a ScannedRNN first.
    """

    action_dim: Sequence[int]
    fc_layer_depth: int
    fc_layer_width: int
    action_mode: str  # "continuous" or "discrete" or "hybrid"
    hybrid_action_continuous_dim: int
    multi_discrete_number_of_dims_per_distribution: List[int]
    add_generator_embedding: bool = False
    generator_embedding_number_of_timesteps: int = 10
    recurrent: bool = False

    # Given an embedding, return the action/values, since this is shared across all models.
    @nn.compact
    def __call__(self, hidden, obs, embedding, dones, activation):
        """Return ``(hidden, pi, value)`` for the given embedding.

        Raises:
            NotImplementedError: if ``add_generator_embedding`` is set.
        """
        if self.add_generator_embedding:
            raise NotImplementedError()

        if self.recurrent:
            rnn_in = (embedding, dones)
            hidden, embedding = ScannedRNN()(hidden, rnn_in)

        actor_mean = embedding
        critic = embedding
        # Kept so the hybrid head can branch off the last hidden activation
        # rather than the final logit layer.
        actor_mean_last = embedding
        for _ in range(self.fc_layer_depth):
            actor_mean = nn.Dense(
                self.fc_layer_width,
                kernel_init=orthogonal(np.sqrt(2)),
                bias_init=constant(0.0),
            )(actor_mean)
            actor_mean = activation(actor_mean)

            critic = nn.Dense(
                self.fc_layer_width,
                kernel_init=orthogonal(np.sqrt(2)),
                bias_init=constant(0.0),
            )(critic)
            critic = activation(critic)

            actor_mean_last = actor_mean

        # Small init (0.01) on the output layer keeps the initial policy
        # close to uniform / zero-mean.
        actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)

        if self.action_mode == "discrete":
            pi = distrax.Categorical(logits=actor_mean)
        elif self.action_mode == "continuous":
            # State-independent log-std, one per action dimension.
            actor_log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
            pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_log_std))
        elif self.action_mode == "multi_discrete":
            pi = MultiDiscreteActionDistribution(actor_mean, self.multi_discrete_number_of_dims_per_distribution)
        else:
            # Hybrid head: discrete logits from actor_mean plus a Gaussian
            # over the continuous dims, both fed from the last hidden layer.
            actor_mean_continuous = nn.Dense(
                self.hybrid_action_continuous_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
            )(actor_mean_last)
            actor_mean_sigma = jnp.exp(
                nn.Dense(self.hybrid_action_continuous_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(
                    actor_mean_last
                )
            )
            pi = HybridActionDistribution(actor_mean, actor_mean_continuous, actor_mean_sigma)

        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)

        return hidden, pi, jnp.squeeze(critic, axis=-1)
class ActorCriticPixelsRNN(nn.Module):
    """Pixel-observation actor-critic: a small CNN encoder feeding the shared
    GeneralActorCriticRNN head.

    Expects observations with ``.image`` (time, batch, H, W, C) and
    ``.global_info`` fields — assumed from the field access below; confirm
    against the caller.
    """

    action_dim: Sequence[int]
    fc_layer_depth: int
    fc_layer_width: int
    action_mode: str
    hybrid_action_continuous_dim: int
    multi_discrete_number_of_dims_per_distribution: List[int]
    activation: str
    add_generator_embedding: bool = False
    generator_embedding_number_of_timesteps: int = 10
    recurrent: bool = True

    @nn.compact
    def __call__(self, hidden, x, **kwargs):
        """Encode pixels and return ``(hidden, pi, value)``."""
        # Only "relu" is special-cased; anything else falls back to tanh.
        activation = nn.relu if self.activation == "relu" else nn.tanh

        og_obs, dones = x
        if self.add_generator_embedding:
            obs = og_obs.obs
        else:
            obs = og_obs

        image = obs.image
        global_info = obs.global_info

        # Two-layer conv encoder (always ReLU, independent of `activation`).
        x = nn.Conv(features=16, kernel_size=(8, 8), strides=(4, 4))(image)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(4, 4), strides=(2, 2))(x)
        x = nn.relu(x)

        # Flatten spatial dims, keeping (time, batch) leading axes.
        embedding = x.reshape(x.shape[0], x.shape[1], -1)
        embedding = jnp.concatenate([embedding, global_info], axis=-1)

        return GeneralActorCriticRNN(
            action_dim=self.action_dim,
            fc_layer_depth=self.fc_layer_depth,
            fc_layer_width=self.fc_layer_width,
            action_mode=self.action_mode,
            hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
            multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
            add_generator_embedding=self.add_generator_embedding,
            generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
            recurrent=self.recurrent,
        )(hidden, og_obs, embedding, dones, activation)

    @staticmethod
    def initialize_carry(batch_size, hidden_size=256):
        """Return a zeroed recurrent carry (delegates to ScannedRNN)."""
        return ScannedRNN.initialize_carry(batch_size, hidden_size)
class ActorCriticSymbolicRNN(nn.Module):
    """Symbolic (flat-vector) actor-critic: one Dense encoder layer feeding
    the shared GeneralActorCriticRNN head."""

    action_dim: Sequence[int]
    fc_layer_width: int
    action_mode: str
    hybrid_action_continuous_dim: int
    multi_discrete_number_of_dims_per_distribution: List[int]
    fc_layer_depth: int
    activation: str
    add_generator_embedding: bool = False
    generator_embedding_number_of_timesteps: int = 10
    recurrent: bool = True

    @nn.compact
    def __call__(self, hidden, x):
        """Encode the symbolic observation and return ``(hidden, pi, value)``."""
        # Only "relu" is special-cased; anything else falls back to tanh.
        activation = nn.relu if self.activation == "relu" else nn.tanh

        og_obs, dones = x
        if self.add_generator_embedding:
            obs = og_obs.obs
        else:
            obs = og_obs

        # Single Dense encoder (always ReLU, independent of `activation`).
        embedding = nn.Dense(
            self.fc_layer_width,
            kernel_init=orthogonal(np.sqrt(2)),
            bias_init=constant(0.0),
        )(obs)
        embedding = nn.relu(embedding)

        return GeneralActorCriticRNN(
            action_dim=self.action_dim,
            fc_layer_depth=self.fc_layer_depth,
            fc_layer_width=self.fc_layer_width,
            action_mode=self.action_mode,
            hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
            multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
            add_generator_embedding=self.add_generator_embedding,
            generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
            recurrent=self.recurrent,
        )(hidden, og_obs, embedding, dones, activation)

    @staticmethod
    def initialize_carry(batch_size, hidden_size=256):
        """Return a zeroed recurrent carry (delegates to ScannedRNN)."""
        return ScannedRNN.initialize_carry(batch_size, hidden_size)