Spaces:
Runtime error
Runtime error
from kinetix.models.actor_critic import ( | |
ActorCriticPixelsRNN, | |
ActorCriticSymbolicRNN, | |
) | |
from kinetix.models.transformer_model import ActorCriticTransformer | |
def make_network_from_config(env, env_params, config, network_kws=None):
    """Build the actor-critic network described by ``config`` for ``env``.

    The action mode and the network family are selected from substrings of
    ``config["env_name"]``:

    * action mode: ``"MultiDiscrete"`` -> ``"multi_discrete"``,
      ``"Discrete"`` -> ``"discrete"``, ``"Continuous"`` -> ``"continuous"``,
      ``"Hybrid"`` -> ``"hybrid"`` (checked in that order);
    * network: ``"Entity"`` -> ``ActorCriticTransformer``; otherwise
      ``"Pixels"`` -> ``ActorCriticPixelsRNN`` and ``"Symbolic"`` or
      ``"Blind"`` -> ``ActorCriticSymbolicRNN``.

    Args:
        env: Object exposing ``action_space(env_params)``. In continuous mode
            the action dimension is ``action_space.shape[0]``; in every other
            mode it is ``action_space.n``.
        env_params: Forwarded unchanged to ``env.action_space``.
        config: Mapping that must contain ``"env_name"``, ``"fc_layer_width"``,
            ``"fc_layer_depth"`` and ``"activation"``. Unless
            ``network_kws`` already provides
            ``"multi_discrete_number_of_dims_per_distribution"``, it must also
            contain ``config["static_env_params"]["num_motor_bindings"]`` and
            ``["num_thruster_bindings"]``. Entity envs additionally need
            ``"num_heads"``, ``"transformer_depth"``, ``"transformer_size"``,
            ``"transformer_encoder_size"``, ``"aggregate_mode"`` and
            ``"full_attention_mask"``. ``"recurrent_model"`` is optional
            (defaults to ``True``).
        network_kws: Optional extra keyword arguments forwarded to the network
            constructor. The caller's mapping is never modified; defaults are
            filled into a copy. Any ``"recurrent"`` entry is overridden by
            ``config.get("recurrent_model", True)``.

    Returns:
        The constructed network object.

    Raises:
        ValueError: If the action mode or the network type cannot be inferred
            from ``config["env_name"]``.
    """
    # Work on a private copy: the previous ``{}`` default was shared between
    # calls, so keys written below (e.g. a stale action dim) leaked into
    # later calls and a caller-supplied dict was mutated in place.
    network_kws = {} if network_kws is None else dict(network_kws)

    env_name = config["env_name"]
    # Order matters: "MultiDiscrete" also contains the substring "Discrete".
    if "MultiDiscrete" in env_name:
        action_mode = "multi_discrete"
    elif "Discrete" in env_name:
        action_mode = "discrete"
    elif "Continuous" in env_name:
        action_mode = "continuous"
    elif "Hybrid" in env_name:
        action_mode = "hybrid"
    else:
        raise ValueError(f"Unknown action mode for {env_name}")

    action_space = env.action_space(env_params)
    action_dim = action_space.shape[0] if action_mode == "continuous" else action_space.n

    network_kws.setdefault("hybrid_action_continuous_dim", action_dim)
    if "multi_discrete_number_of_dims_per_distribution" not in network_kws:
        static_env_params = config["static_env_params"]
        # One entry of 3 per motor binding followed by one entry of 2 per
        # thruster binding. NOTE(review): presumably the number of discrete
        # choices per actuator; confirm against the env's action space.
        network_kws["multi_discrete_number_of_dims_per_distribution"] = (
            [3] * static_env_params["num_motor_bindings"]
            + [2] * static_env_params["num_thruster_bindings"]
        )
    network_kws["recurrent"] = config.get("recurrent_model", True)

    if "Entity" in env_name:
        return ActorCriticTransformer(
            action_dim=action_dim,
            fc_layer_width=config["fc_layer_width"],
            fc_layer_depth=config["fc_layer_depth"],
            action_mode=action_mode,
            num_heads=config["num_heads"],
            transformer_depth=config["transformer_depth"],
            transformer_size=config["transformer_size"],
            transformer_encoder_size=config["transformer_encoder_size"],
            aggregate_mode=config["aggregate_mode"],
            full_attention_mask=config["full_attention_mask"],
            activation=config["activation"],
            **network_kws,
        )

    if "Pixels" in env_name:
        cls_to_use = ActorCriticPixelsRNN
    elif "Symbolic" in env_name or "Blind" in env_name:
        cls_to_use = ActorCriticSymbolicRNN
    else:
        # Previously this fell through to an UnboundLocalError on cls_to_use.
        raise ValueError(f"Unknown observation type for {env_name}")

    return cls_to_use(
        action_dim,
        fc_layer_width=config["fc_layer_width"],
        fc_layer_depth=config["fc_layer_depth"],
        activation=config["activation"],
        action_mode=action_mode,
        **network_kws,
    )