File size: 1,570 Bytes
e0f25ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import jax
import jax.numpy as jnp
import jax.random
from jax2d.engine import PhysicsEngine
from matplotlib import pyplot as plt

from kinetix.environment.env import make_kinetix_env_from_args
from kinetix.environment.env_state import StaticEnvParams, EnvParams
from kinetix.environment.ued.distributions import sample_kinetix_level
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.render.renderer_pixels import make_render_pixels
from kinetix.util.saving import load_from_json_file


def main():
    # Load a premade level
    level, static_env_params, env_params = load_from_json_file("worlds/l/grasp_easy.json")

    # Create the environment
    env = make_kinetix_env_from_args(
        obs_type="pixels", action_type="continuous", reset_type="replay", static_env_params=static_env_params
    )

    # Reset the environment state to this level
    rng = jax.random.PRNGKey(0)
    rng, _rng = jax.random.split(rng)
    obs, env_state = env.reset_to_level(_rng, level, env_params)

    # Take a step in the environment
    rng, _rng = jax.random.split(rng)
    action = env.action_space(env_params).sample(_rng)
    rng, _rng = jax.random.split(rng)
    obs, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)

    # Render environment
    renderer = make_render_pixels(env_params, static_env_params)

    # There are a lot of wrappers
    pixels = renderer(env_state.env_state.env_state.env_state)

    plt.imshow(pixels.astype(jnp.uint8).transpose(1, 0, 2)[::-1])
    plt.show()


if __name__ == "__main__":
    main()