# kinet-test / Kinetix / examples / example_premade_level_replay.py
# tree3po's picture
# Upload 190 files
# e0f25ed verified
import jax
import jax.numpy as jnp
import jax.random
from jax2d.engine import PhysicsEngine
from matplotlib import pyplot as plt
from kinetix.environment.env import make_kinetix_env_from_args
from kinetix.environment.env_state import StaticEnvParams, EnvParams
from kinetix.environment.ued.distributions import sample_kinetix_level
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.render.renderer_pixels import make_render_pixels
from kinetix.util.saving import load_from_json_file
def main():
    """Replay a premade Kinetix level for one step and show the resulting frame.

    Loads the level stored in ``worlds/l/grasp_easy.json``, builds an
    environment with pixel observations, continuous actions and the
    "replay" reset type, resets it to that level, applies a single action
    sampled from the action space, then renders the post-step state and
    displays it with matplotlib.
    """
    # The JSON file yields the level together with both parameter sets.
    level, static_env_params, env_params = load_from_json_file("worlds/l/grasp_easy.json")

    env = make_kinetix_env_from_args(
        obs_type="pixels",
        action_type="continuous",
        reset_type="replay",
        static_env_params=static_env_params,
    )

    # Fixed seed; a fresh subkey is split off before every call that takes a key.
    key = jax.random.PRNGKey(0)

    # Put the environment into the loaded level.
    key, reset_key = jax.random.split(key)
    _obs, state = env.reset_to_level(reset_key, level, env_params)

    # Sample one action and advance the environment by a single step.
    key, action_key = jax.random.split(key)
    action = env.action_space(env_params).sample(action_key)

    key, step_key = jax.random.split(key)
    _obs, state, _reward, _done, _info = env.step(step_key, state, action, env_params)

    # The renderer is given the state nested three ``.env_state`` levels
    # below the one returned by ``step`` (the environment is wrapped
    # several times).
    render = make_render_pixels(env_params, static_env_params)
    inner_state = state.env_state.env_state.env_state
    frame = render(inner_state)

    # Cast to 8-bit, swap the first two axes, then reverse the leading axis
    # before handing the image to matplotlib.
    # NOTE(review): presumably this converts the renderer's axis order and
    # origin to what ``imshow`` expects -- confirm against the renderer.
    image = frame.astype(jnp.uint8)
    image = image.transpose(1, 0, 2)
    plt.imshow(image[::-1])
    plt.show()
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()