# kinetix/models/actor_critic.py
# (Hugging Face page residue: repo "kinet-test", uploaded by tree3po — "Upload 46 files", commit 581eeac)
import functools
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax.linen.initializers import constant, orthogonal
from typing import List, Sequence
import distrax
from kinetix.models.action_spaces import HybridActionDistribution, MultiDiscreteActionDistribution
class ScannedRNN(nn.Module):
    """GRU core scanned over the leading (time) axis of its inputs.

    The carry is reset to zeros at episode boundaaries (wherever ``resets`` is
    True) before each step, so hidden state never leaks across episodes.
    """

    # Width of the GRU hidden state. Kept as an attribute (default 256, the
    # value previously hard-coded) so the module stays backward compatible.
    hidden_size: int = 256

    @functools.partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Applies the module.

        Args:
            carry: GRU hidden state of shape (batch, hidden_size).
            x: tuple ``(ins, resets)`` — per-step inputs of shape
               (batch, features) and a boolean done-mask of shape (batch,).

        Returns:
            ``(new_carry, output)`` from the GRU cell.
        """
        rnn_state = carry
        ins, resets = x
        # Zero the hidden state wherever an episode just ended.
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(ins.shape[0], self.hidden_size),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(features=self.hidden_size)(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(batch_size, hidden_size=256):
        """Return a zero carry of shape (batch_size, hidden_size).

        Uses a dummy PRNG key since the default carry init fn is just zeros.
        """
        # BUG FIX: the cell's feature width was hard-coded to 256, so the
        # ``hidden_size`` argument was silently ignored for any other value.
        cell = nn.GRUCell(features=hidden_size)
        return cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, hidden_size))
class GeneralActorCriticRNN(nn.Module):
    """Shared actor-critic head: maps a feature embedding to a policy
    distribution and a value estimate, optionally passing the embedding
    through a GRU first.

    Used by all observation-specific models (pixels/symbolic), which build
    the embedding and delegate here.
    """

    action_dim: Sequence[int]  # width of the actor's final logits/means layer
    fc_layer_depth: int  # number of hidden Dense layers in each trunk
    fc_layer_width: int  # width of those hidden layers
    # NOTE(review): the comment below predates "multi_discrete", which the
    # code also handles; any other value falls through to the hybrid branch.
    action_mode: str  # "continuous" or "discrete" or "hybrid"
    hybrid_action_continuous_dim: int  # size of the continuous part in hybrid mode
    multi_discrete_number_of_dims_per_distribution: List[int]
    add_generator_embedding: bool = False  # unsupported here — raises if True
    generator_embedding_number_of_timesteps: int = 10
    recurrent: bool = False  # if True, run the embedding through a ScannedRNN

    # Given an embedding, return the action/values, since this is shared across all models.
    @nn.compact
    def __call__(self, hidden, obs, embedding, dones, activation):
        """Return ``(hidden, pi, value)``.

        Args:
            hidden: RNN carry (only updated when ``recurrent`` is True).
            obs: raw observation — NOTE(review): currently unused in this body;
                callers pass it anyway, presumably for future generator-embedding
                support.
            embedding: (time, batch, features) features from the torso.
            dones: episode-boundary mask, used to reset the RNN state.
            activation: hidden-layer nonlinearity (e.g. nn.relu / nn.tanh).
        """
        if self.add_generator_embedding:
            raise NotImplementedError()
        if self.recurrent:
            rnn_in = (embedding, dones)
            hidden, embedding = ScannedRNN()(hidden, rnn_in)
        actor_mean = embedding
        critic = embedding
        # Track the actor trunk's last hidden layer: the hybrid branch attaches
        # its continuous heads here rather than to the discrete logits.
        actor_mean_last = embedding
        # Actor and critic trunks are built interleaved on purpose — in
        # nn.compact, submodule (parameter) names follow creation order, so
        # reordering these Dense calls would break checkpoint compatibility.
        for _ in range(self.fc_layer_depth):
            actor_mean = nn.Dense(
                self.fc_layer_width,
                kernel_init=orthogonal(np.sqrt(2)),
                bias_init=constant(0.0),
            )(actor_mean)
            actor_mean = activation(actor_mean)
            critic = nn.Dense(
                self.fc_layer_width,
                kernel_init=orthogonal(np.sqrt(2)),
                bias_init=constant(0.0),
            )(critic)
            critic = activation(critic)
            actor_mean_last = actor_mean
        # Small-gain (0.01) init keeps the initial policy near-uniform.
        actor_mean = nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)
        if self.action_mode == "discrete":
            pi = distrax.Categorical(logits=actor_mean)
        elif self.action_mode == "continuous":
            # State-independent log-std, one per action dimension.
            actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
            pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd))
        elif self.action_mode == "multi_discrete":
            # Logits are split into one Categorical per sub-action internally.
            pi = MultiDiscreteActionDistribution(actor_mean, self.multi_discrete_number_of_dims_per_distribution)
        else:
            # Hybrid: discrete logits from above plus continuous mean/sigma
            # heads branching off the trunk's last hidden layer.
            actor_mean_continuous = nn.Dense(
                self.hybrid_action_continuous_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
            )(actor_mean_last)
            # exp() keeps sigma strictly positive.
            actor_mean_sigma = jnp.exp(
                nn.Dense(self.hybrid_action_continuous_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(
                    actor_mean_last
                )
            )
            pi = HybridActionDistribution(actor_mean, actor_mean_continuous, actor_mean_sigma)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)
        return hidden, pi, jnp.squeeze(critic, axis=-1)
class ActorCriticPixelsRNN(nn.Module):
    """Actor-critic for pixel observations: a small CNN torso feeding the
    shared ``GeneralActorCriticRNN`` head, recurrent by default."""

    action_dim: Sequence[int]
    fc_layer_depth: int
    fc_layer_width: int
    action_mode: str
    hybrid_action_continuous_dim: int
    multi_discrete_number_of_dims_per_distribution: List[int]
    activation: str  # name of the trunk nonlinearity: "relu", else tanh
    add_generator_embedding: bool = False
    generator_embedding_number_of_timesteps: int = 10
    recurrent: bool = True

    @nn.compact
    def __call__(self, hidden, x, **kwargs):
        # Resolve the configured nonlinearity for the shared head's trunk.
        act_fn = nn.relu if self.activation == "relu" else nn.tanh

        og_obs, dones = x
        obs = og_obs.obs if self.add_generator_embedding else og_obs

        # CNN torso over the image (its activations are fixed to ReLU,
        # independent of the configured trunk activation).
        feats = nn.Conv(features=16, kernel_size=(8, 8), strides=(4, 4))(obs.image)
        feats = nn.relu(feats)
        feats = nn.Conv(features=32, kernel_size=(4, 4), strides=(2, 2))(feats)
        feats = nn.relu(feats)

        # Flatten spatial/channel dims per (time, batch), then append the
        # non-image global features.
        flattened = feats.reshape(feats.shape[0], feats.shape[1], -1)
        embedding = jnp.concatenate([flattened, obs.global_info], axis=-1)

        head = GeneralActorCriticRNN(
            action_dim=self.action_dim,
            fc_layer_depth=self.fc_layer_depth,
            fc_layer_width=self.fc_layer_width,
            action_mode=self.action_mode,
            hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
            multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
            add_generator_embedding=self.add_generator_embedding,
            generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
            recurrent=self.recurrent,
        )
        return head(hidden, og_obs, embedding, dones, act_fn)

    @staticmethod
    def initialize_carry(batch_size, hidden_size=256):
        """Zero RNN carry, delegated to the scanned GRU core."""
        return ScannedRNN.initialize_carry(batch_size, hidden_size)
class ActorCriticSymbolicRNN(nn.Module):
    """Actor-critic for flat symbolic observations: one Dense+ReLU embedding
    layer feeding the shared ``GeneralActorCriticRNN`` head, recurrent by
    default."""

    action_dim: Sequence[int]
    fc_layer_width: int
    action_mode: str
    hybrid_action_continuous_dim: int
    multi_discrete_number_of_dims_per_distribution: List[int]
    fc_layer_depth: int
    activation: str  # name of the trunk nonlinearity: "relu", else tanh
    add_generator_embedding: bool = False
    generator_embedding_number_of_timesteps: int = 10
    recurrent: bool = True

    @nn.compact
    def __call__(self, hidden, x):
        # Resolve the configured nonlinearity for the shared head's trunk.
        act_fn = nn.relu if self.activation == "relu" else nn.tanh

        og_obs, dones = x
        obs = og_obs.obs if self.add_generator_embedding else og_obs

        # Single projection of the symbolic observation into the embedding
        # space (always ReLU here, regardless of the trunk activation).
        projection = nn.Dense(
            self.fc_layer_width,
            kernel_init=orthogonal(np.sqrt(2)),
            bias_init=constant(0.0),
        )
        embedding = nn.relu(projection(obs))

        head = GeneralActorCriticRNN(
            action_dim=self.action_dim,
            fc_layer_depth=self.fc_layer_depth,
            fc_layer_width=self.fc_layer_width,
            action_mode=self.action_mode,
            hybrid_action_continuous_dim=self.hybrid_action_continuous_dim,
            multi_discrete_number_of_dims_per_distribution=self.multi_discrete_number_of_dims_per_distribution,
            add_generator_embedding=self.add_generator_embedding,
            generator_embedding_number_of_timesteps=self.generator_embedding_number_of_timesteps,
            recurrent=self.recurrent,
        )
        return head(hidden, og_obs, embedding, dones, act_fn)

    @staticmethod
    def initialize_carry(batch_size, hidden_size=256):
        """Zero RNN carry, delegated to the scanned GRU core."""
        return ScannedRNN.initialize_carry(batch_size, hidden_size)