import os
os.system(f"cd {os.getcwd()}/Kinetix")
import itertools
import time
from timeit import default_timer as tmr

import optax
from PIL import Image
from flax.serialization import to_state_dict
from flax.training.train_state import TrainState

from matplotlib import pyplot as plt

from kinetix.environment.ued.distributions import sample_kinetix_level
from kinetix.models import make_network_from_config
from kinetix.models.actor_critic import ScannedRNN
from kinetix.render.renderer_symbolic_entity import make_render_entities

ss = tmr()

import jax

jax.config.update("jax_compilation_cache_dir", ".cache-location")

import hydra
from omegaconf import OmegaConf

from kinetix.environment.ued.mutators import (
    make_mutate_change_shape_rotation,
    mutate_add_connected_shape,
    mutate_add_shape,
    make_mutate_change_shape_size,
    mutate_change_shape_location,
    mutate_swap_role,
    mutate_remove_shape,
    mutate_remove_joint,
    mutate_toggle_fixture,
    mutate_add_thruster,
    mutate_remove_thruster,
)
from kinetix.environment.ued.ued import make_mutate_env, ALL_MUTATION_FNS
from kinetix.environment.ued.ued_state import UEDParams
from kinetix.environment.ued.util import rectangle_vertices
from kinetix.util.config import generate_params_from_config, normalise_config


import argparse
import os
import sys

sys.path.append("editor")
import tkinter
import tkinter.filedialog
from enum import Enum
from timeit import default_timer as tmr

import jax.numpy as jnp
import numpy as np
import pygame
import pygame_widgets
from pygame_widgets.slider import Slider
from pygame_widgets.textbox import TextBox
from pygame_widgets.toggle import Toggle

from jax2d.engine import (
    calc_inverse_inertia_circle,
    calc_inverse_inertia_polygon,
    calc_inverse_mass_circle,
    calc_inverse_mass_polygon,
    calculate_collision_matrix,
    recalculate_mass_and_inertia,
    recompute_global_joint_positions,
    select_shape,
)
from jax2d.maths import rmat
from jax2d.sim_state import RigidBody
from kinetix.environment.env import (
    create_empty_env,
    make_kinetix_env_from_name,
)
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams
from kinetix.environment.utils import permute_pcg_state
from kinetix.environment.wrappers import AutoResetWrapper
from kinetix.pcg.pcg import env_state_to_pcg_state, sample_pcg_state
from kinetix.pcg.pcg_state import PCGState
from kinetix.render.renderer_pixels import make_render_pixels
from kinetix.render.textures import (
    CIRCLE_TEXTURE_RGBA,
    EDIT_TEXTURE_RGBA,
    PLAY_TEXTURE_RGBA,
    RECT_TEXTURE_RGBA,
    RJOINT_TEXTURE_RGBA,
    SELECT_TEXTURE_RGBA,
    THRUSTER_TEXTURE_RGBA,
    TRIANGLE_TEXTURE_RGBA,
)
from kinetix.util.saving import (
    expand_pcg_state,
    export_env_state_to_json,
    get_pcg_state_from_json,
    load_from_json_file,
    load_pcg_state_pickle,
    load_world_state_pickle,
    save_pickle,
    load_params_from_wandb_artifact_path,
    load_train_state_from_wandb_artifact_path,
)
from kinetix.util.timing import time_function
from tkinter import Tk

#root = Tk()
#root.destroy()

ee = tmr()
print(f"Imported in {ee - ss} seconds")

editor = None
outer_timer = tmr()
EMPTY_ENV = False


class ObjectType(Enum):
    POLYGON = 0
    CIRCLE = 1
    JOINT = 2
    THRUSTER = 3


class EditMode(Enum):
    ADD_CIRCLE = 0
    ADD_RECTANGLE = 1
    ADD_JOINT = 2
    SELECT = 3
    ADD_TRIANGLE = 4
    ADD_THRUSTER = 5


TOTAL_DUMMY_STEPS_TO_SNAP = 0
SNAPPING_DIST = 0.1


def select_object(state: EnvState, type: int, index: int):
    if type is None:
        type = ObjectType.POLYGON
    li = {0: state.polygon, 1: state.circle, 2: state.joint, 3: state.thruster}[type.value]
    return jax.tree.map(lambda x: x[index], li)


def snap_to_center(shape: RigidBody, position: jnp.ndarray):
    if jnp.linalg.norm(shape.position - position) < SNAPPING_DIST:
        return shape.position
    return position


def snap_to_polygon_center_line(polygon: RigidBody, position: jnp.ndarray):
    # Snap to the center line
    r = rmat(polygon.rotation)
    x = jnp.matmul(r, position - polygon.position)
    if jnp.abs(x[0]) < SNAPPING_DIST:
        x = x.at[0].set(0.0)
    if jnp.abs(x[1]) < SNAPPING_DIST:
        x = x.at[1].set(0.0)
    x = jnp.matmul(r.transpose(), x)
    return x + polygon.position


def snap_to_circle_center_line(circle: RigidBody, position: jnp.ndarray):
    # Snap to the center line, i.e. on the edge of the circle, if the position is close enough to directly below the circle, etc., snap the position to that
    x = position - circle.position
    if jnp.linalg.norm(x) < SNAPPING_DIST:
        return circle.position
    angle = (jnp.arctan2(x[1], x[0]) + 2 * jnp.pi) % (2 * jnp.pi)

    for i in range(0, 8):
        if jnp.abs(angle - i * jnp.pi / 4) < jnp.radians(25):  # 25 degrees
            angle = i * jnp.pi / 4
            break
    x = jnp.array([jnp.cos(angle), jnp.sin(angle)]) * circle.radius
    return x + circle.position


def prompt_file(save=False):
    dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "worlds")
    """Create a Tk file dialog and cleanup when finished"""
    top = tkinter.Tk()
    top.withdraw()  # hide window
    if save:
        file_name = tkinter.filedialog.asksaveasfilename(parent=top, initialdir=dir)
    else:
        file_name = tkinter.filedialog.askopenfilename(parent=top, initialdir=dir)
    top.destroy()
    return file_name


def get_numeric_key_pressed(pygame_events, is_mod=False):
    for event in pygame_events:
        if not is_mod:
            if event.type == pygame.KEYDOWN and event.unicode.isdigit():
                return int(event.unicode)
        else:
            if event.type == pygame.KEYDOWN:
                pass
            if event.type == pygame.KEYDOWN and event.key in [pygame.K_0 + i for i in range(10)]:
                return int(event.key - pygame.K_0)
    return None


def new_env(static_env_params):
    return create_empty_env(static_env_params)


myrng = jax.random.PRNGKey(0)


def make_reset_function(static_env_params):
    def reset(rng):
        return env_state_to_pcg_state(create_empty_env(static_env_params))

    return reset


def new_pcg_env(static_env_params):
    global myrng
    if EMPTY_ENV:
        env_state = create_empty_env(static_env_params)
    else:
        return get_pcg_state_from_json("worlds/l/h0_angrybirds.json")
    return env_state_to_pcg_state(env_state)


class Editor:
    def __init__(self, env, env_params, config, upscale=1):
        self.env = env
        self.upscale = upscale
        self.env_params = env_params
        self.static_env_params = env.static_env_params
        self.ued_params = UEDParams()

        self.side_panel_width = env.static_env_params.screen_dim[0] // 3

        self.rng = jax.random.PRNGKey(0)
        self.config = config
        self.pcg_state = new_pcg_env(self.static_env_params)

        self.rng, _rng = jax.random.split(self.rng)
        self.play_state = sample_pcg_state(_rng, self.pcg_state, self.env_params, self.static_env_params)
        self.last_played_level = None

        self.pygame_events = []

        self.mutate_world = make_mutate_env(env.static_env_params, env_params, self.ued_params)

        self.num_triangle_clicks = 0
        self.triangle_order = jnp.array([0, 1, 2])

        # Init rendering
        pygame.init()
        pygame.key.set_repeat(250, 75)

        self.screen_surface = pygame.display.set_mode(
            tuple(
                (t + extra) * self.upscale
                for t, extra in zip(self.static_env_params.screen_dim, (self.side_panel_width, 0))
            ),
            display=0,
        )
        self.all_widgets = {}
        self._setup_side_panel()
        self.has_done_action = False

        self._setup_rendering(self.static_env_params, env_params)

        self._render_edit_overlay_fn = jax.jit(self._render_edit_overlay)
        self._step_fn = jax.jit(env.step)

        # Agent
        if self.config["agent_taking_actions"]:
            self._entity_renderer = jax.jit(make_render_entities(env_params, self.env.static_env_params))
            self.network = make_network_from_config(env, env_params, config)

            rng = jax.random.PRNGKey(0)
            dones = jnp.zeros((config["num_train_envs"]), dtype=jnp.bool_)
            rng, _rng = jax.random.split(rng)
            init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
            obsv = self._entity_renderer(self.play_state)
            obsv = jax.tree.map(lambda x: jnp.repeat(x[None, ...], repeats=config["num_train_envs"], axis=0), obsv)
            init_x = jax.tree.map(lambda x: x[None, ...], (obsv, dones))
            network_params = self.network.init(_rng, init_hstate, init_x)

            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adam(config["lr"], eps=1e-5),
            )
            self.train_state = TrainState.create(
                apply_fn=self.network.apply,
                params=network_params,
                tx=tx,
            )

            self.train_state = load_train_state_from_wandb_artifact_path(self.train_state, config["agent_wandb_path"])

            self.apply_fn = jax.jit(self.network.apply)

        # JIT Compile
        def _jit_step():
            rng = jax.random.PRNGKey(0)
            ans = self._step_fn(
                rng,
                self.env.reset_env_to_level(rng, self.play_state, self.env_params)[1],
                jnp.zeros(
                    env.static_env_params.num_motor_bindings + env.static_env_params.num_thruster_bindings, dtype=int
                ),
                self.env_params,
            )

        def _jit_render():
            self._render_fn(self.play_state)
            self._render_fn_edit(self.play_state)

        time_function(_jit_step, "_jit_step")
        time_function(_jit_render, "_jit_render")

        # self._step_fn(rng, self.play_state, 0, self.env_params)

        # Editing
        self.is_editing = True

        self.edit_shape_mode = EditMode.ADD_CIRCLE
        self.creating_shape = False
        self.create_shape_position = jnp.array([0.0, 0.0])
        self.creating_shape_index = 0
        self.selected_shape_index = -1
        self.selected_shape_type = ObjectType.POLYGON

        self.all_selected_shapes = []

        self.rng = jax.random.PRNGKey(0)
        time_function(self._jit, "self._jit")
        self._put_state_values_into_gui(self.pcg_state)

        self.mutate_change_shape_size = make_mutate_change_shape_size(self.env_params, self.static_env_params)
        self.mutate_change_shape_rotation = make_mutate_change_shape_rotation(self.env_params, self.static_env_params)

    def _setup_rendering(self, static_env_params: StaticEnvParams, env_params: EnvParams):
        def _make_render(should_do_edit_additions=False):
            def _render(env_state):
                side_panel = self._render_side_panel()
                render_pixels = make_render_pixels(params=env_params, static_params=static_env_params)
                pixels = render_pixels(
                    env_state,
                )
                pixels = jnp.concatenate([side_panel, pixels], axis=0)

                pixels = jnp.repeat(pixels, repeats=static_env_params.downscale * self.upscale, axis=0)
                pixels = jnp.repeat(pixels, repeats=static_env_params.downscale * self.upscale, axis=1)

                return pixels[:, ::-1, :]

            return _render

        def _make_screenshot_render():
            def _render(env_state):
                px_upscale = 4

                static_params = static_env_params.replace(screen_dim=(500 * px_upscale, 500 * px_upscale))
                ss_env_params = env_params.replace(
                    pixels_per_unit=100 * px_upscale,
                )

                render_pixels = make_render_pixels(
                    params=ss_env_params,
                    static_params=static_params,
                    render_rjoint_sectors=False,
                    pixel_upscale=2 * px_upscale,
                )
                pixels = render_pixels(
                    env_state,
                )

                return pixels[:, ::-1, :]

            return _render

        self._render_fn_edit = jax.jit(_make_render(True))
        self._render_fn = jax.jit(_make_render(False))
        self._render_fn_screenshot = jax.jit(_make_screenshot_render())

    def _jit(self):
        self._get_circles_on_mouse(self.pcg_state.env_state)
        self._get_polygons_on_mouse(self.pcg_state.env_state)
        self._get_revolute_joints_on_mouse(self.pcg_state.env_state)
        self._get_thrusters_on_mouse(self.pcg_state.env_state)
        self.pygame_events = list(pygame.event.get())
        self._handle_events(do_dummy=True)

        state = self.play_state
        for mutation_fn in ALL_MUTATION_FNS:
            mutation_fn(jax.random.PRNGKey(0), state, self.env_params, self.static_env_params, self.ued_params)

    def update(self, rng):
        # Update pygame events
        self.pygame_events = list(pygame.event.get())

        for event in self.pygame_events:
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_SPACE:
                    self.has_done_action = False
                    self.is_editing = not self.is_editing
                    if not self.is_editing:
                        self.pcg_state = self._discard_shape_being_created(self.pcg_state)
                        self.pcg_state = self._reset_select_shape(self.pcg_state)
                        self.pcg_state = self.pcg_state.replace(
                            env_state=self.pcg_state.env_state.replace(
                                collision_matrix=calculate_collision_matrix(
                                    self.static_env_params, self.pcg_state.env_state.joint
                                ),
                            ),
                            env_state_pcg_mask=self.pcg_state.env_state_pcg_mask.replace(
                                collision_matrix=jnp.zeros_like(self.pcg_state.env_state_pcg_mask.collision_matrix)
                            ),
                        )
                        self.rng, _rng = jax.random.split(self.rng)
                        self.play_state = sample_pcg_state(
                            _rng, self.pcg_state, self.env_params, self.static_env_params
                        )
                        self.last_played_level = self.play_state
                elif event.key == pygame.K_s and not self.is_editing:
                    self.take_screenshot()

        if self.is_editing:
            self.pcg_state = self.edit()
        else:
            rng, _rng = jax.random.split(rng)
            # action = []
            action = jnp.zeros(
                self.env.static_env_params.num_motor_bindings + self.env.static_env_params.num_thruster_bindings,
                dtype=jnp.int32,
            )

            keys = pygame.key.get_pressed()
            if keys[pygame.K_LEFT]:
                action = action.at[0].set(1)
            if keys[pygame.K_RIGHT]:
                action = action.at[0].set(2)
            if keys[pygame.K_UP]:
                action = action.at[1].set(1)
            if keys[pygame.K_DOWN]:
                action = action.at[1].set(2)

            if keys[pygame.K_1]:
                action = action.at[0 + self.env.static_env_params.num_motor_bindings].set(1)
            if keys[pygame.K_2]:
                action = action.at[1 + self.env.static_env_params.num_motor_bindings].set(1)
            if keys[pygame.K_3]:
                action = action.at[2 + self.env.static_env_params.num_motor_bindings].set(1)
            if keys[pygame.K_4]:
                action = action.at[3 + self.env.static_env_params.num_motor_bindings].set(1)
            # if self.has_done_action: action = action * 0
            self.has_done_action = self.has_done_action | (action != 0).any()

            if self.config["agent_taking_actions"]:
                obs = self._entity_renderer(self.play_state)
                obs = jax.tree.map(lambda x: x[None, ...], obs)

                last_done = jnp.zeros((1, 1), dtype=bool)
                ac_in = (jax.tree.map(lambda x: x[np.newaxis, :], obs), last_done[np.newaxis, :])
                hstate = ScannedRNN.initialize_carry(1)
                hstate, pi, value = self.apply_fn(self.train_state.params, hstate, ac_in)
                rng, _rng = jax.random.split(rng)
                action = pi.sample(seed=_rng)
                action = action[0, 0]

            _rng, __rng = jax.random.split(_rng)
            obs, self.play_state, reward, done, info = self._step_fn(
                _rng, self.env.reset_to_level(__rng, self.play_state, self.env_params)[1], action, self.env_params
            )
            if done:
                self.rng, _rng = jax.random.split(self.rng)
                self.play_state = sample_pcg_state(_rng, self.pcg_state, self.env_params, self.static_env_params)
        state_to_render = self.pcg_state.env_state if self.is_editing else self.play_state

        self.render(state_to_render)
        self._handle_events()

        # Update screen
        pygame.display.flip()

    def take_screenshot(self):
        print("screenshot!")

        pixels = self._render_fn_screenshot(self.play_state)
        mtime = round(time.time() * 1000)

        pixels = pixels.transpose((1, 0, 2))

        # Black border
        border_thickness = 5
        pixels = pixels.at[:, :border_thickness].set(0.0)
        pixels = pixels.at[:, -border_thickness:].set(0.0)
        pixels = pixels.at[:border_thickness, :].set(0.0)
        pixels = pixels.at[-border_thickness:, :].set(0.0)

        im = Image.fromarray(np.array(pixels.astype(jnp.uint8)))
        im.save(f"results/screenshot_{str(mtime)}.png")

    def _get_selected_shape_global_indices(self):
        def _idx(idx, type):
            if type == ObjectType.CIRCLE:
                return idx + self.static_env_params.num_polygons
            return idx

        indices_to_use = jnp.array([_idx(idx, type) for idx, type in self.all_selected_shapes])
        return indices_to_use

    # flag1
    def _handle_events(self, do_dummy=False):
        pygame_widgets.update(self.pygame_events)
        if do_dummy or self.selected_shape_index < 0:
            gravity_main = self.all_widgets[None]["sldGravity"].getValue()
            gravity_max = self.all_widgets[None]["sldMaxGravity"].getValue()
            gravity_main, gravity_max = min(gravity_main, gravity_max), max(gravity_main, gravity_max)
            gravity_pcg_mask = self.all_widgets[None]["pcgTglGravity"].getValue()

            def _set_single_global(state, gravity):
                return state.replace(
                    gravity=state.gravity.at[1].set(gravity),
                )

            env_state = _set_single_global(self.pcg_state.env_state, gravity_main)
            env_state_max = _set_single_global(self.pcg_state.env_state_max, gravity_max)
            env_state_pcg_mask = _set_single_global(self.pcg_state.env_state_pcg_mask, gravity_pcg_mask)
            if not do_dummy:
                self.pcg_state = self.pcg_state.replace(
                    env_state=env_state,
                    env_state_max=env_state_max,
                    env_state_pcg_mask=env_state_pcg_mask,
                )

        if self.edit_shape_mode == EditMode.SELECT or do_dummy:  # is on the hand.
            if do_dummy or len(self.all_selected_shapes) > 1:
                # this processes the tying together.
                indices_to_use = self._get_selected_shape_global_indices()
                if len(indices_to_use) > 1:
                    toggle_val = self.all_widgets["TIE_TOGETHER"]["tglTieTogether"].getValue()
                    idxs = itertools.product(indices_to_use, indices_to_use)
                    idxs_a = []
                    idxs_b = []
                    for i, j in idxs:
                        idxs_a.append(i)
                        idxs_b.append(j)
                    idxs = jnp.array(idxs_a), jnp.array(idxs_b)
                    if toggle_val:
                        self.pcg_state = self.pcg_state.replace(
                            tied_together=self.pcg_state.tied_together.at[idxs[0], idxs[1]].set(True)
                        )
                    else:
                        self.pcg_state = self.pcg_state.replace(
                            tied_together=self.pcg_state.tied_together.at[idxs[0], idxs[1]].set(False)
                        )
            if self.selected_shape_index < 0 and not do_dummy:
                return
            if do_dummy or self.selected_shape_type in [ObjectType.POLYGON, ObjectType.CIRCLE]:  # rigidbody
                shape_main = select_object(
                    self.pcg_state.env_state, self.selected_shape_type, self.selected_shape_index
                )
                shape_max = select_object(
                    self.pcg_state.env_state_max, self.selected_shape_type, self.selected_shape_index
                )

                parent_container_main = (
                    self.pcg_state.env_state.circle
                    if self.selected_shape_type == ObjectType.CIRCLE
                    else self.pcg_state.env_state.polygon
                )
                parent_container_max = (
                    self.pcg_state.env_state_max.circle
                    if self.selected_shape_type == ObjectType.CIRCLE
                    else self.pcg_state.env_state_max.polygon
                )
                parent_container_pcg_mask = (
                    self.pcg_state.env_state_pcg_mask.circle
                    if self.selected_shape_type == ObjectType.CIRCLE
                    else self.pcg_state.env_state_pcg_mask.polygon
                )

                new_density_main = self.all_widgets[self.selected_shape_type]["sldDensity"].getValue()
                new_density_max = self.all_widgets[self.selected_shape_type]["sldMaxDensity"].getValue()
                density_pcg_mask = self.all_widgets[self.selected_shape_type]["pcgTglDensity"].getValue()
                if density_pcg_mask:
                    new_density_main, new_density_max = min(new_density_main, new_density_max), max(
                        new_density_main, new_density_max
                    )
                fixated = self.all_widgets[self.selected_shape_type]["tglFixate"].getValue()

                fix_val = 0.0 if fixated else 1.0

                def _density_calcs(base, new_density):
                    inv_mass = jax.lax.select(
                        self.selected_shape_type == ObjectType.CIRCLE,
                        calc_inverse_mass_circle(base.radius, new_density),
                        calc_inverse_mass_polygon(base.vertices, base.n_vertices, self.static_env_params, new_density)[
                            0
                        ],
                    )
                    inv_inertia = jax.lax.select(
                        self.selected_shape_type == ObjectType.CIRCLE,
                        calc_inverse_inertia_circle(base.radius, new_density),
                        calc_inverse_inertia_polygon(
                            base.vertices, base.n_vertices, self.static_env_params, new_density
                        ),
                    )

                    return inv_mass, inv_inertia

                inv_mass_main, inv_inertia_main = _density_calcs(shape_main, new_density_main)
                inv_mass_max, inv_inertia_max = _density_calcs(shape_max, new_density_max)

                friction_main = self.all_widgets[self.selected_shape_type]["sldFriction"].getValue()
                friction_max = self.all_widgets[self.selected_shape_type]["sldMaxFriction"].getValue()
                friction_pcg_mask = self.all_widgets[self.selected_shape_type]["pcgTglFriction"].getValue()
                if friction_pcg_mask:
                    friction_main, friction_max = min(friction_main, friction_max), max(friction_main, friction_max)

                restitution = self.all_widgets[self.selected_shape_type]["sldRestitution"].getValue()

                position_main = jnp.array(
                    [
                        self.all_widgets[self.selected_shape_type]["sldPosition_X"].getValue(),
                        self.all_widgets[self.selected_shape_type]["sldPosition_Y"].getValue(),
                    ]
                )

                position_max = jnp.array(
                    [
                        self.all_widgets[self.selected_shape_type]["sldMaxPosition_X"].getValue(),
                        self.all_widgets[self.selected_shape_type]["sldMaxPosition_Y"].getValue(),
                    ]
                )
                position_pcg_mask = self.all_widgets[self.selected_shape_type]["pcgTglPosition_X"].getValue()
                if position_pcg_mask:
                    position_main, position_max = jnp.minimum(position_main, position_max), jnp.maximum(
                        position_main, position_max
                    )

                rotation_main = self.all_widgets[self.selected_shape_type]["sldRotation"].getValue()
                rotation_max = self.all_widgets[self.selected_shape_type]["sldMaxRotation"].getValue()
                rotation_pcg_mask = self.all_widgets[self.selected_shape_type]["pcgTglRotation"].getValue()
                if rotation_pcg_mask:
                    rotation_main, rotation_max = min(rotation_main, rotation_max), max(rotation_main, rotation_max)

                velocity_main = jnp.array(
                    [
                        self.all_widgets[self.selected_shape_type]["sldVelocity_X"].getValue(),
                        self.all_widgets[self.selected_shape_type]["sldVelocity_Y"].getValue(),
                    ]
                )

                velocity_max = jnp.array(
                    [
                        self.all_widgets[self.selected_shape_type]["sldMaxVelocity_X"].getValue(),
                        self.all_widgets[self.selected_shape_type]["sldMaxVelocity_Y"].getValue(),
                    ]
                )
                velocity_main, velocity_max = jnp.minimum(velocity_main, velocity_max), jnp.maximum(
                    velocity_main, velocity_max
                )
                velocity_pcg_mask = self.all_widgets[self.selected_shape_type]["pcgTglVelocity_X"].getValue()

                angular_velocity_main = self.all_widgets[self.selected_shape_type]["sldAngular_Velocity"].getValue()
                angular_velocity_max = self.all_widgets[self.selected_shape_type]["sldMaxAngular_Velocity"].getValue()
                angular_velocity_pcg_mask = self.all_widgets[self.selected_shape_type][
                    "pcgTglAngular_Velocity"
                ].getValue()
                if angular_velocity_pcg_mask:
                    angular_velocity_main, angular_velocity_max = min(angular_velocity_main, angular_velocity_max), max(
                        angular_velocity_main, angular_velocity_max
                    )

                # Circle stuff
                radius_main, radius_max, radius_pcg_mask = None, None, None
                if self.selected_shape_type == ObjectType.CIRCLE:
                    radius_main = self.all_widgets[self.selected_shape_type]["sldRadius"].getValue()
                    radius_max = self.all_widgets[self.selected_shape_type]["sldMaxRadius"].getValue()
                    radius_pcg_mask = self.all_widgets[self.selected_shape_type]["pcgTglRadius"].getValue()

                # Poly stuff
                vertices_main, vertices_max, vertices_pcg_mask = None, None, None
                if self.selected_shape_type == ObjectType.POLYGON:
                    # Triangle
                    new_size_main = self.all_widgets[self.selected_shape_type]["sldSize"].getValue()
                    new_size_max = self.all_widgets[self.selected_shape_type]["sldMaxSize"].getValue()

                    current_size = jnp.abs(self.pcg_state.env_state.polygon.vertices[self.selected_shape_index]).max()
                    scale_main = new_size_main / current_size
                    scale_max = new_size_max / current_size

                    vertices_main = self.pcg_state.env_state.polygon.vertices[self.selected_shape_index] * scale_main
                    vertices_max = self.pcg_state.env_state.polygon.vertices[self.selected_shape_index] * scale_max
                    vertices_pcg_mask = (
                        jnp.ones_like(vertices_main, dtype=bool)
                        * self.all_widgets[self.selected_shape_type]["pcgTglSize"].getValue()
                    )

                def _set_single_state_rbody(
                    state,
                    parent_container,
                    density,
                    inv_mass,
                    inv_inertia,
                    friction,
                    position,
                    radius,
                    rotation,
                    velocity,
                    angular_velocity,
                    vertices,
                ):
                    new = {
                        "friction": parent_container.friction.at[self.selected_shape_index].set(friction),
                        "collision_mode": parent_container.collision_mode.at[self.selected_shape_index].set(
                            int(self.all_widgets[self.selected_shape_type]["sldCollidability"].getValue())
                        ),
                        "inverse_mass": parent_container.inverse_mass.at[self.selected_shape_index].set(
                            inv_mass * fix_val
                        ),
                        "inverse_inertia": parent_container.inverse_inertia.at[self.selected_shape_index].set(
                            inv_inertia * fix_val
                        ),
                        "position": parent_container.position.at[self.selected_shape_index].set(position),
                        "rotation": parent_container.rotation.at[self.selected_shape_index].set(rotation),
                        "velocity": parent_container.velocity.at[self.selected_shape_index].set(velocity),
                        "angular_velocity": parent_container.angular_velocity.at[self.selected_shape_index].set(
                            angular_velocity
                        ),
                        "restitution": parent_container.restitution.at[self.selected_shape_index].set(restitution),
                    }

                    if self.selected_shape_type == ObjectType.CIRCLE:
                        state = state.replace(
                            circle=state.circle.replace(
                                **new,
                                radius=parent_container.radius.at[self.selected_shape_index].set(radius),
                            ),
                            circle_shape_roles=state.circle_shape_roles.at[self.selected_shape_index].set(
                                int(self.all_widgets[self.selected_shape_type]["sldRole"].getValue())
                            ),
                            circle_densities=state.circle_densities.at[self.selected_shape_index].set(density),
                        )
                    else:
                        state = state.replace(
                            polygon=state.polygon.replace(
                                **new,
                                vertices=parent_container.vertices.at[self.selected_shape_index].set(vertices),
                            ),
                            polygon_shape_roles=state.polygon_shape_roles.at[self.selected_shape_index].set(
                                int(self.all_widgets[self.selected_shape_type]["sldRole"].getValue())
                            ),
                            polygon_densities=state.polygon_densities.at[self.selected_shape_index].set(density),
                        )

                    return state

                position_delta = position_main - shape_main.position
                env_state = _set_single_state_rbody(
                    self.pcg_state.env_state,
                    parent_container_main,
                    new_density_main,
                    inv_mass_main,
                    inv_inertia_main,
                    friction_main,
                    position_main,
                    radius_main,
                    rotation_main,
                    velocity_main,
                    angular_velocity_main,
                    vertices_main,
                )
                env_state_max = _set_single_state_rbody(
                    self.pcg_state.env_state_max,
                    parent_container_max,
                    new_density_max,
                    inv_mass_max,
                    inv_inertia_max,
                    friction_max,
                    position_max,
                    radius_max,
                    rotation_max,
                    velocity_max,
                    angular_velocity_max,
                    vertices_max,
                )
                env_state_pcg_mask = _set_single_state_rbody(
                    self.pcg_state.env_state_pcg_mask,
                    parent_container_pcg_mask,
                    density_pcg_mask,
                    density_pcg_mask,
                    density_pcg_mask,
                    friction_pcg_mask,
                    position_pcg_mask,
                    radius_pcg_mask,
                    rotation_pcg_mask,
                    velocity_pcg_mask,
                    angular_velocity_pcg_mask,
                    vertices_pcg_mask,
                )

                # Now, we must also set all of the tied shape's positions

                correct_index = self.selected_shape_index + (
                    self.static_env_params.num_polygons if self.selected_shape_type == ObjectType.CIRCLE else 0
                )

                nonzero_indices = set(
                    jnp.nonzero(self.pcg_state.tied_together[correct_index].at[correct_index].set(False))[0].tolist()
                )
                for idx in nonzero_indices:
                    if idx < self.static_env_params.num_polygons:
                        env_state = env_state.replace(
                            polygon=env_state.polygon.replace(
                                position=env_state.polygon.position.at[idx].set(
                                    env_state.polygon.position[idx] + position_delta
                                )
                            )
                        )
                    else:
                        idx = idx - self.static_env_params.num_polygons
                        env_state = env_state.replace(
                            circle=env_state.circle.replace(
                                position=env_state.circle.position.at[idx].set(
                                    env_state.circle.position[idx] + position_delta
                                )
                            )
                        )

                if not do_dummy:
                    self.pcg_state = PCGState(
                        env_state=env_state,
                        env_state_max=env_state_max,
                        env_state_pcg_mask=env_state_pcg_mask,
                        tied_together=self.pcg_state.tied_together,
                    )

            if do_dummy or self.selected_shape_type == ObjectType.JOINT:  # rjoint
                speed_main = self.all_widgets[ObjectType.JOINT]["sldSpeed"].getValue()
                speed_max = self.all_widgets[ObjectType.JOINT]["sldMaxSpeed"].getValue()
                speed_pcg_mask = self.all_widgets[ObjectType.JOINT]["pcgTglSpeed"].getValue()
                if speed_pcg_mask:
                    speed_main, speed_max = min(speed_main, speed_max), max(speed_main, speed_max)

                power_main = self.all_widgets[ObjectType.JOINT]["sldPower"].getValue()
                power_max = self.all_widgets[ObjectType.JOINT]["sldMaxPower"].getValue()
                power_pcg_mask = self.all_widgets[ObjectType.JOINT]["pcgTglPower"].getValue()
                if power_pcg_mask:
                    power_main, power_max = min(power_main, power_max), max(power_main, power_max)

                motor_binding_val_min = int(self.all_widgets[ObjectType.JOINT]["sldColour"].getValue())
                motor_binding_val_max = int(self.all_widgets[ObjectType.JOINT]["sldMaxColour"].getValue())
                colour_pcg_mask = self.all_widgets[ObjectType.JOINT]["pcgTglColour"].getValue()
                if colour_pcg_mask:
                    motor_binding_val_min, motor_binding_val_max = min(
                        motor_binding_val_min, motor_binding_val_max
                    ), max(motor_binding_val_min, motor_binding_val_max)

                auto_motor_val = self.all_widgets[ObjectType.JOINT]["tglAutoMotor"].getValue()
                joint_limits_val = self.all_widgets[ObjectType.JOINT]["tglJointLimits"].getValue()
                is_fixed_val = self.all_widgets[ObjectType.JOINT]["tglIsFixedJoint"].getValue()
                is_motor_on = self.all_widgets[ObjectType.JOINT]["tglIsMotorOn"].getValue()

                min_rot_val = jnp.deg2rad(self.all_widgets[ObjectType.JOINT]["sldMin_Rotation"].getValue())
                max_rot_val = jnp.deg2rad(self.all_widgets[ObjectType.JOINT]["sldMax_Rotation"].getValue())
                # ensure the min is less than the max
                min_rot_val, max_rot_val = min(min_rot_val, max_rot_val), max(min_rot_val, max_rot_val)

                def _set_single_state_joint(state, speed, power, colour):
                    state = state.replace(
                        joint=state.joint.replace(
                            motor_speed=state.joint.motor_speed.at[self.selected_shape_index].set(speed),
                            motor_power=state.joint.motor_power.at[self.selected_shape_index].set(power),
                            motor_has_joint_limits=state.joint.motor_has_joint_limits.at[self.selected_shape_index].set(
                                joint_limits_val
                            ),
                            min_rotation=state.joint.min_rotation.at[self.selected_shape_index].set(min_rot_val),
                            max_rotation=state.joint.max_rotation.at[self.selected_shape_index].set(max_rot_val),
                            is_fixed_joint=state.joint.is_fixed_joint.at[self.selected_shape_index].set(is_fixed_val),
                            motor_on=state.joint.motor_on.at[self.selected_shape_index].set(is_motor_on),
                        ),
                        motor_bindings=state.motor_bindings.at[self.selected_shape_index].set(colour),
                        motor_auto=state.motor_auto.at[self.selected_shape_index].set(auto_motor_val),
                    )
                    return state

                env_state = _set_single_state_joint(
                    self.pcg_state.env_state,
                    speed_main,
                    power_main,
                    motor_binding_val_min,
                )

                env_state_max = _set_single_state_joint(
                    self.pcg_state.env_state_max,
                    speed_max,
                    power_max,
                    motor_binding_val_max,
                )

                env_state_pcg_mask = _set_single_state_joint(
                    self.pcg_state.env_state_pcg_mask, speed_pcg_mask, power_pcg_mask, colour_pcg_mask
                )
                if not do_dummy:
                    self.pcg_state = self.pcg_state.replace(
                        env_state=env_state,
                        env_state_max=env_state_max,
                        env_state_pcg_mask=env_state_pcg_mask,
                    )

            if do_dummy or self.selected_shape_type == ObjectType.THRUSTER:  # thruster
                power_main = self.all_widgets[ObjectType.THRUSTER]["sldPower"].getValue()
                power_max = self.all_widgets[ObjectType.THRUSTER]["sldMaxPower"].getValue()
                power_pcg_mask = self.all_widgets[ObjectType.THRUSTER]["pcgTglPower"].getValue()
                if power_pcg_mask:
                    power_main, power_max = min(power_main, power_max), max(power_main, power_max)

                def _set_single_state_thruster(state, power):
                    return state.replace(
                        thruster=state.thruster.replace(
                            power=state.thruster.power.at[self.selected_shape_index].set(power),
                        ),
                        thruster_bindings=state.thruster_bindings.at[self.selected_shape_index].set(
                            int(self.all_widgets[ObjectType.THRUSTER]["sldColour"].getValue())
                        ),
                    )

                env_state = _set_single_state_thruster(self.pcg_state.env_state, power_main)
                env_state_max = _set_single_state_thruster(self.pcg_state.env_state_max, power_max)
                env_state_pcg_mask = _set_single_state_thruster(self.pcg_state.env_state_pcg_mask, power_pcg_mask)
                if not do_dummy:
                    self.pcg_state = self.pcg_state.replace(
                        env_state=env_state,
                        env_state_max=env_state_max,
                        env_state_pcg_mask=env_state_pcg_mask,
                    )

    # flag2
    def _put_state_values_into_gui(self, pcg_state=None):
        def _set_toggle_val(toggle, val):
            if toggle.getValue() != val:
                toggle.toggle()

        def _enable_slider(slider):
            slider.enable()
            slider.colour = (200, 200, 200)
            slider.handleColour = (0, 0, 0)

        def _disable_slider(slider):
            slider.disable()
            slider.colour = (255, 0, 0)
            slider.handleColour = (255, 0, 0)

        if pcg_state is None:
            # state = self.edit_state
            raise ValueError

        def pcg_text(main_val, max_val, pcg_mask):
            if pcg_mask:
                return f"{main_val:.2f} - {max_val:.2f}"
            else:
                return f"{main_val:.2f}"

        # global ones
        gravity_pcg_mask = pcg_state.env_state_pcg_mask.gravity[1]
        self.all_widgets[None]["lblGravity"].setText(
            f"Gravity: {pcg_text(pcg_state.env_state.gravity[1], pcg_state.env_state_max.gravity[1], pcg_state.env_state_pcg_mask.gravity[1])}"
        )
        self.all_widgets[None]["sldGravity"].setValue(pcg_state.env_state.gravity[1])
        self.all_widgets[None]["sldMaxGravity"].setValue(pcg_state.env_state_max.gravity[1])
        if not gravity_pcg_mask:
            _disable_slider(self.all_widgets[None]["sldMaxGravity"])
        else:
            _enable_slider(self.all_widgets[None]["sldMaxGravity"])

        if self.edit_shape_mode != EditMode.SELECT or self.selected_shape_index < 0:
            return

        obj_main = select_object(pcg_state.env_state, self.selected_shape_type, self.selected_shape_index)
        obj_max = select_object(pcg_state.env_state_max, self.selected_shape_type, self.selected_shape_index)
        obj_pcg_mask = select_object(pcg_state.env_state_pcg_mask, self.selected_shape_type, self.selected_shape_index)

        if len(self.all_selected_shapes) > 1:
            indices_to_use = self._get_selected_shape_global_indices()
            are_all_tied = pcg_state.tied_together[indices_to_use.min(), indices_to_use].all()
            _set_toggle_val(self.all_widgets["TIE_TOGETHER"]["tglTieTogether"], are_all_tied)

        if self.selected_shape_type == ObjectType.JOINT:
            self.all_widgets[ObjectType.JOINT]["lblSpeed"].setText(
                f"Speed: {pcg_text(obj_main.motor_speed, obj_max.motor_speed, obj_pcg_mask.motor_speed)}"
            )
            self.all_widgets[ObjectType.JOINT]["sldSpeed"].setValue(obj_main.motor_speed)
            self.all_widgets[ObjectType.JOINT]["sldMaxSpeed"].setValue(obj_max.motor_speed)
            _set_toggle_val(self.all_widgets[ObjectType.JOINT]["pcgTglSpeed"], obj_pcg_mask.motor_speed)
            if obj_pcg_mask.motor_speed:
                _enable_slider(self.all_widgets[ObjectType.JOINT]["sldMaxSpeed"])
            else:
                _disable_slider(self.all_widgets[ObjectType.JOINT]["sldMaxSpeed"])

            self.all_widgets[ObjectType.JOINT]["lblPower"].setText(
                f"Power: {pcg_text(obj_main.motor_power, obj_max.motor_power, obj_pcg_mask.motor_power)}"
            )
            self.all_widgets[ObjectType.JOINT]["sldPower"].setValue(obj_main.motor_power)
            self.all_widgets[ObjectType.JOINT]["sldMaxPower"].setValue(obj_max.motor_power)
            _set_toggle_val(self.all_widgets[ObjectType.JOINT]["pcgTglPower"], obj_pcg_mask.motor_power)
            if obj_pcg_mask.motor_power:
                _enable_slider(self.all_widgets[ObjectType.JOINT]["sldMaxPower"])
            else:
                _disable_slider(self.all_widgets[ObjectType.JOINT]["sldMaxPower"])

            self.all_widgets[ObjectType.JOINT]["lblColour"].setText(
                f"Colour: {pcg_state.env_state.motor_bindings[self.selected_shape_index]}"
            )
            self.all_widgets[ObjectType.JOINT]["sldColour"].setValue(
                pcg_state.env_state.motor_bindings[self.selected_shape_index]
            )

            self.all_widgets[ObjectType.JOINT]["sldMaxColour"].setValue(
                pcg_state.env_state_max.motor_bindings[self.selected_shape_index]
            )

            colour_pcg_mask = pcg_state.env_state_pcg_mask.motor_bindings[self.selected_shape_index]
            if not colour_pcg_mask:
                _disable_slider(self.all_widgets[ObjectType.JOINT]["sldMaxColour"])
            else:
                _enable_slider(self.all_widgets[ObjectType.JOINT]["sldMaxColour"])

            self.all_widgets[ObjectType.JOINT]["lblJointLimits"].setText(
                f"Joint Limits: {obj_main.motor_has_joint_limits}"
            )
            widget_motor_has_joint_limits = self.all_widgets[ObjectType.JOINT]["tglJointLimits"].getValue()
            if obj_main.motor_has_joint_limits != widget_motor_has_joint_limits:  # Update the toggle
                self.all_widgets[ObjectType.JOINT]["tglJointLimits"].toggle()

            mini, maxi = jnp.rad2deg(obj_main.min_rotation), jnp.rad2deg(obj_main.max_rotation)
            self.all_widgets[ObjectType.JOINT]["lblMin_Rotation"].setText(f"Min Rot: {int(mini)}")
            self.all_widgets[ObjectType.JOINT]["sldMin_Rotation"].setValue(mini)
            self.all_widgets[ObjectType.JOINT]["lblMax_Rotation"].setText(f"Max Rot: {int(maxi)}")
            self.all_widgets[ObjectType.JOINT]["sldMax_Rotation"].setValue(maxi)

            if not obj_main.motor_has_joint_limits:
                for k in ["min_rotation", "max_rotation"]:
                    self.all_widgets[self.selected_shape_type][f"sld{k.title()}"].disable()
                    self.all_widgets[self.selected_shape_type][f"sld{k.title()}"].colour = (255, 0, 0)
                    self.all_widgets[self.selected_shape_type][f"sld{k.title()}"].handleColour = (255, 0, 0)
            else:
                for k in ["min_rotation", "max_rotation"]:
                    self.all_widgets[self.selected_shape_type][f"sld{k.title()}"].enable()
                    self.all_widgets[self.selected_shape_type][f"sld{k.title()}"].colour = (200, 200, 200)
                    self.all_widgets[self.selected_shape_type][f"sld{k.title()}"].handleColour = (0, 0, 0)

            self.all_widgets[ObjectType.JOINT]["lblAutoMotor"].setText(
                f"Auto: {pcg_state.env_state.motor_auto[self.selected_shape_index]}"
            )
            widget_is_auto_motor = self.all_widgets[ObjectType.JOINT]["tglAutoMotor"].getValue()
            if pcg_state.env_state.motor_auto[self.selected_shape_index] != widget_is_auto_motor:  # Update the toggle
                self.all_widgets[ObjectType.JOINT]["tglAutoMotor"].toggle()

            self.all_widgets[ObjectType.JOINT]["lblIsFixedJoint"].setText(f"Fixed: {obj_main.is_fixed_joint}")
            widget_is_motor_on = self.all_widgets[ObjectType.JOINT]["tglIsFixedJoint"].getValue()
            if obj_main.is_fixed_joint != widget_is_motor_on:  # Update the toggle
                self.all_widgets[ObjectType.JOINT]["tglIsFixedJoint"].toggle()

            self.all_widgets[ObjectType.JOINT]["lblIsMotorOn"].setText(f"Motor On: {obj_main.motor_on}")
            widget_is_motor_on = self.all_widgets[ObjectType.JOINT]["tglIsMotorOn"].getValue()
            if obj_main.motor_on != widget_is_motor_on:  # Update the toggle
                self.all_widgets[ObjectType.JOINT]["tglIsMotorOn"].toggle()

        elif self.selected_shape_type == ObjectType.THRUSTER:
            # thruster
            self.all_widgets[ObjectType.THRUSTER]["lblPower"].setText(
                f"Power: {pcg_text(obj_main.power, obj_max.power, obj_pcg_mask.power)}"
            )
            self.all_widgets[ObjectType.THRUSTER]["sldPower"].setValue(obj_main.power)
            self.all_widgets[ObjectType.THRUSTER]["sldMaxPower"].setValue(obj_max.power)
            _set_toggle_val(self.all_widgets[ObjectType.THRUSTER]["pcgTglPower"], obj_pcg_mask.power)
            if obj_pcg_mask.power:
                _enable_slider(self.all_widgets[ObjectType.THRUSTER]["sldMaxPower"])
            else:
                _disable_slider(self.all_widgets[ObjectType.THRUSTER]["sldMaxPower"])
            self.all_widgets[ObjectType.THRUSTER]["sldColour"].setValue(
                pcg_state.env_state.thruster_bindings[self.selected_shape_index]
            )
            self.all_widgets[ObjectType.THRUSTER]["lblColour"].setText(
                f"Colour: {pcg_state.env_state.thruster_bindings[self.selected_shape_index]}"
            )

        elif self.selected_shape_type in [ObjectType.POLYGON, ObjectType.CIRCLE]:
            # rigidbody

            # Position
            # We use the mask for position_x for the entire position vector
            self.all_widgets[self.selected_shape_type]["lblPosition_X"].setText(
                f"Position X: {pcg_text(obj_main.position[0], obj_max.position[0], obj_pcg_mask.position[0])}"
            )
            self.all_widgets[self.selected_shape_type]["sldPosition_X"].setValue(obj_main.position[0])
            self.all_widgets[self.selected_shape_type]["sldMaxPosition_X"].setValue(obj_max.position[0])
            _set_toggle_val(self.all_widgets[self.selected_shape_type]["pcgTglPosition_X"], obj_pcg_mask.position[0])

            self.all_widgets[self.selected_shape_type]["lblPosition_Y"].setText(
                f"Position Y: {pcg_text(obj_main.position[1], obj_max.position[1], obj_pcg_mask.position[0])}"
            )
            self.all_widgets[self.selected_shape_type]["sldPosition_Y"].setValue(obj_main.position[1])
            self.all_widgets[self.selected_shape_type]["sldMaxPosition_Y"].setValue(obj_max.position[1])

            if obj_pcg_mask.position[0]:
                _enable_slider(self.all_widgets[self.selected_shape_type]["sldMaxPosition_X"])
                _enable_slider(self.all_widgets[self.selected_shape_type]["sldMaxPosition_Y"])
            else:
                _disable_slider(self.all_widgets[self.selected_shape_type]["sldMaxPosition_X"])
                _disable_slider(self.all_widgets[self.selected_shape_type]["sldMaxPosition_Y"])

            # Velocity
            # We use the mask for velocity_x for the entire velocity vector
            self.all_widgets[self.selected_shape_type]["lblVelocity_X"].setText(
                f"Velocity X: {pcg_text(obj_main.velocity[0], obj_max.velocity[0], obj_pcg_mask.velocity[0])}"
            )
            self.all_widgets[self.selected_shape_type]["sldVelocity_X"].setValue(obj_main.velocity[0])
            self.all_widgets[self.selected_shape_type]["sldMaxVelocity_X"].setValue(obj_max.velocity[0])
            _set_toggle_val(self.all_widgets[self.selected_shape_type]["pcgTglVelocity_X"], obj_pcg_mask.velocity[0])

            self.all_widgets[self.selected_shape_type]["lblVelocity_Y"].setText(
                f"Velocity Y: {pcg_text(obj_main.velocity[1], obj_max.velocity[1], obj_pcg_mask.velocity[0])}"
            )
            self.all_widgets[self.selected_shape_type]["sldVelocity_Y"].setValue(obj_main.velocity[1])
            self.all_widgets[self.selected_shape_type]["sldMaxVelocity_Y"].setValue(obj_max.velocity[1])

            if obj_pcg_mask.velocity[0]:
                _enable_slider(self.all_widgets[self.selected_shape_type]["sldMaxVelocity_X"])
                _enable_slider(self.all_widgets[self.selected_shape_type]["sldMaxVelocity_Y"])
            else:
                _disable_slider(self.all_widgets[self.selected_shape_type]["sldMaxVelocity_X"])
                _disable_slider(self.all_widgets[self.selected_shape_type]["sldMaxVelocity_Y"])

            # Density
            is_fixated = obj_main.inverse_mass == 0

            def _calc_density(state):
                if self.selected_shape_type == ObjectType.POLYGON:
                    return state.polygon_densities[self.selected_shape_index]
                elif self.selected_shape_type == ObjectType.CIRCLE:
                    return state.circle_densities[self.selected_shape_index]
                else:
                    raise ValueError

            density_main = _calc_density(pcg_state.env_state)
            density_max = _calc_density(pcg_state.env_state_max)
            density_pcg_mask = obj_pcg_mask.inverse_mass

            self.all_widgets[self.selected_shape_type]["lblDensity"].setText(
                f"Density: {pcg_text(density_main, density_max, density_pcg_mask)}"
            )
            self.all_widgets[self.selected_shape_type]["sldDensity"].setValue(density_main)
            self.all_widgets[self.selected_shape_type]["sldMaxDensity"].setValue(density_max)
            _set_toggle_val(self.all_widgets[self.selected_shape_type]["pcgTglDensity"], density_pcg_mask)
            if is_fixated:
                _disable_slider(self.all_widgets[self.selected_shape_type]["sldDensity"])
            else:
                _enable_slider(self.all_widgets[self.selected_shape_type]["sldDensity"])

            if is_fixated or (not density_pcg_mask):
                _disable_slider(self.all_widgets[self.selected_shape_type]["sldMaxDensity"])
            else:
                _enable_slider(self.all_widgets[self.selected_shape_type]["sldMaxDensity"])

            # Friction
            self.all_widgets[self.selected_shape_type]["lblFriction"].setText(
                f"Friction: {pcg_text(obj_main.friction, obj_max.friction, obj_pcg_mask.friction)}"
            )
            self.all_widgets[self.selected_shape_type]["sldFriction"].setValue(obj_main.friction)
            self.all_widgets[self.selected_shape_type]["sldMaxFriction"].setValue(obj_max.friction)
            _set_toggle_val(self.all_widgets[self.selected_shape_type]["pcgTglFriction"], obj_pcg_mask.friction)

            if not obj_pcg_mask.friction:
                _disable_slider(self.all_widgets[self.selected_shape_type]["sldMaxFriction"])
            else:
                _enable_slider(self.all_widgets[self.selected_shape_type]["sldMaxFriction"])

            # Restitution
            self.all_widgets[self.selected_shape_type]["lblRestitution"].setText(
                f"Restitution: {obj_main.restitution:.2f}"
            )
            self.all_widgets[self.selected_shape_type]["sldRestitution"].setValue(obj_main.restitution)

            # Rotation
            self.all_widgets[self.selected_shape_type]["lblRotation"].setText(
                f"Rotation: {pcg_text(obj_main.rotation, obj_max.rotation, obj_pcg_mask.rotation)}"
            )
            self.all_widgets[self.selected_shape_type]["sldRotation"].setValue(obj_main.rotation)
            self.all_widgets[self.selected_shape_type]["sldMaxRotation"].setValue(obj_max.rotation)
            _set_toggle_val(self.all_widgets[self.selected_shape_type]["pcgTglRotation"], obj_pcg_mask.rotation)

            if not obj_pcg_mask.rotation:
                _disable_slider(self.all_widgets[self.selected_shape_type]["sldMaxRotation"])
            else:
                _enable_slider(self.all_widgets[self.selected_shape_type]["sldMaxRotation"])

            # Angular_Velocity
            self.all_widgets[self.selected_shape_type]["lblAngular_Velocity"].setText(
                f"Angular_Velocity: {pcg_text(obj_main.angular_velocity, obj_max.angular_velocity, obj_pcg_mask.angular_velocity)}"
            )
            self.all_widgets[self.selected_shape_type]["sldAngular_Velocity"].setValue(obj_main.angular_velocity)
            self.all_widgets[self.selected_shape_type]["sldMaxAngular_Velocity"].setValue(obj_max.angular_velocity)
            _set_toggle_val(
                self.all_widgets[self.selected_shape_type]["pcgTglAngular_Velocity"], obj_pcg_mask.angular_velocity
            )

            if not obj_pcg_mask.angular_velocity:
                _disable_slider(self.all_widgets[self.selected_shape_type]["sldMaxAngular_Velocity"])
            else:
                _enable_slider(self.all_widgets[self.selected_shape_type]["sldMaxAngular_Velocity"])

            # Collision mode
            self.all_widgets[self.selected_shape_type]["lblCollidability"].setText(
                f"Collidability: {obj_main.collision_mode}"
            )
            self.all_widgets[self.selected_shape_type]["sldCollidability"].setValue(obj_main.collision_mode)

            # Shape role
            if self.selected_shape_type == ObjectType.POLYGON:
                shape_role = pcg_state.env_state.polygon_shape_roles[self.selected_shape_index]
            else:
                shape_role = pcg_state.env_state.circle_shape_roles[self.selected_shape_index]

            self.all_widgets[self.selected_shape_type]["sldRole"].setValue(shape_role)
            self.all_widgets[self.selected_shape_type]["lblRole"].setText(f"Role: {shape_role}")

            # Fixate
            self.all_widgets[self.selected_shape_type]["lblFixate"].setText(f"Fixate: {is_fixated}")
            widget_is_fixed = self.all_widgets[self.selected_shape_type]["tglFixate"].getValue()
            if is_fixated != widget_is_fixed:  # Update the toggle
                self.all_widgets[self.selected_shape_type]["tglFixate"].toggle()

            # Radius
            if self.selected_shape_type == ObjectType.CIRCLE:
                self.all_widgets[self.selected_shape_type]["lblRadius"].setText(
                    f"Radius: {pcg_text(obj_main.radius, obj_max.radius, obj_pcg_mask.radius)}"
                )
                self.all_widgets[self.selected_shape_type]["sldRadius"].setValue(obj_main.radius)
                self.all_widgets[self.selected_shape_type]["sldMaxRadius"].setValue(obj_max.radius)
                _set_toggle_val(self.all_widgets[self.selected_shape_type]["pcgTglRadius"], obj_pcg_mask.radius)

                if not obj_pcg_mask.radius:
                    _disable_slider(self.all_widgets[self.selected_shape_type]["sldMaxRadius"])
                else:
                    _enable_slider(self.all_widgets[self.selected_shape_type]["sldMaxRadius"])

            elif self.selected_shape_type == ObjectType.POLYGON:
                size_main = jnp.abs(obj_main.vertices).max()
                size_max = jnp.abs(obj_max.vertices).max()
                size_pcg_mask = obj_pcg_mask.vertices.any()

                self.all_widgets[self.selected_shape_type]["lblSize"].setText(
                    f"Size: {pcg_text(size_main, size_max, size_pcg_mask)}"
                )
                self.all_widgets[self.selected_shape_type]["sldSize"].setValue(size_main)
                self.all_widgets[self.selected_shape_type]["sldMaxSize"].setValue(size_max)
                _set_toggle_val(self.all_widgets[self.selected_shape_type]["pcgTglSize"], size_pcg_mask)

                if not size_pcg_mask:
                    _disable_slider(self.all_widgets[self.selected_shape_type]["sldMaxSize"])
                else:
                    _enable_slider(self.all_widgets[self.selected_shape_type]["sldMaxSize"])

    def _render_side_panel(self):
        arr = jnp.ones((self.side_panel_width, self.static_env_params.screen_dim[1], 3)) * (
            jnp.array([135.0, 206.0, 235.0])[None, None] + 20
        )
        return arr

    def make_label_and_slider(
        self,
        start_y,
        label_text,
        slider_min=0.0,
        slider_max=1.0,
        slider_step=0.05,
        font_size=18,
        is_toggle=False,
        is_pcg=False,
        add_pcg_toggle=True,
    ):
        Wl = round(self.W * 0.7)
        pcg_lbl = None
        pcg_toggle = None
        widget_max = None

        label = TextBox(
            self.screen_surface,
            self.MARGIN,
            start_y,
            Wl,
            20,
            fontSize=font_size,
            margin=0,
            placeholderText=label_text,
            font=pygame.font.SysFont("sans-serif", font_size),
        )
        label.disable()  # Act as label instead of textbox
        if is_toggle:
            widget = Toggle(self.screen_surface, self.W // 2 - 20, start_y + 23, 20, 13)
        else:
            widget = Slider(
                self.screen_surface,
                self.MARGIN,
                start_y + 23,
                Wl,
                13,
                min=slider_min,
                max=slider_max,
                step=slider_step,
            )

            if is_pcg:
                widget_max = Slider(
                    self.screen_surface,
                    self.MARGIN,
                    start_y + 50,
                    Wl,
                    13,
                    min=slider_min,
                    max=slider_max,
                    step=slider_step,
                )

                pcg_lbl = TextBox(
                    self.screen_surface,
                    self.MARGIN + Wl + 30,
                    start_y,
                    60,
                    20,
                    fontSize=font_size,
                    margin=0,
                    placeholderText="PCG",
                    font=pygame.font.SysFont("sans-serif", font_size),
                )
                if add_pcg_toggle:
                    pcg_toggle = Toggle(self.screen_surface, Wl + 80, start_y + 23, 20, 13)

        return label, widget, (pcg_toggle, pcg_lbl, widget_max)

    def _setup_side_panel(self):

        W = self.W = int(self.side_panel_width * self.upscale * 0.8)
        MARGIN = self.MARGIN = (self.side_panel_width * self.upscale - W) // 2

        # global values
        G = {}
        gravity_label, gravity_slider, (pcg_toggle, pcg_lbl, slider_max) = self.make_label_and_slider(
            150, "Gravity", -20.0, 0.0, 0.2, is_pcg=True
        )
        G["lblGravity"] = gravity_label
        G["sldGravity"] = gravity_slider
        G["sldMaxGravity"] = slider_max
        G["pcgLblGravity"] = pcg_lbl
        G["pcgTglGravity"] = pcg_toggle

        # thruster values
        T = {}
        thruster_power_label, thruster_power_slider, (pcg_toggle, pcg_lbl, slider_max) = self.make_label_and_slider(
            150, "Power", slider_max=3.0, is_pcg=True
        )
        T["lblPower"] = thruster_power_label
        T["sldPower"] = thruster_power_slider
        T["sldMaxPower"] = slider_max
        T["pcgLblPower"] = pcg_lbl
        T["pcgTglPower"] = pcg_toggle
        thruster_colour_label, thruster_colour_slider, _ = self.make_label_and_slider(
            250, "Colour", 0, self.static_env_params.num_thruster_bindings - 1, 1
        )
        T["lblColour"] = thruster_colour_label
        T["sldColour"] = thruster_colour_slider

        # joints
        D = {}
        for i, (name, (mini, maxi, step), is_pcg) in enumerate(
            zip(
                ["speed", "power", "colour", "min_rotation", "max_rotation"],
                [
                    (-3, 3, 0.05),
                    (0, 3, 0.05),
                    (0, self.static_env_params.num_motor_bindings - 1, 1),
                    (-180, 180, 5),
                    (-180, 180, 5),
                ],
                [True, True, True, False, False],
            )
        ):
            label, slider, (pcg_toggle, pcg_lbl, slider_max) = self.make_label_and_slider(
                150 + 80 * i,
                f"Motor {name.title()}",
                slider_min=mini,
                slider_max=maxi,
                slider_step=step,
                is_pcg=is_pcg,
            )
            D["lbl" + name.title()] = label
            D["sld" + name.title()] = slider
            if is_pcg:
                D["sldMax" + name.title()] = slider_max
                D["pcgLbl" + name.title()] = pcg_lbl
                D["pcgTgl" + name.title()] = pcg_toggle

        label, toggle, _ = self.make_label_and_slider(150 + 80 * 5, "Joint Limits", is_toggle=True)
        D["lblJointLimits"] = label
        D["tglJointLimits"] = toggle

        label, toggle, _ = self.make_label_and_slider(150 + 80 * 6, "Auto", is_toggle=True)
        D["lblAutoMotor"] = label
        D["tglAutoMotor"] = toggle

        label, toggle, _ = self.make_label_and_slider(150 + 80 * 7, "Fixed", is_toggle=True)
        D["lblIsFixedJoint"] = label
        D["tglIsFixedJoint"] = toggle

        label, toggle, _ = self.make_label_and_slider(150 + 80 * 8, "Motor On", is_toggle=True)
        D["lblIsMotorOn"] = label
        D["tglIsMotorOn"] = toggle

        def _create_rigid_body_base_gui():
            D_rigid = {}
            # rigidbodies
            total_toggles = 0
            total_non_toggles = 0
            for i, (name, bounds, is_pcg, add_pcg_toggle) in enumerate(
                zip(
                    [
                        "position_x",
                        "position_y",
                        "rotation",
                        "velocity_x",
                        "velocity_y",
                        "angular_velocity",
                        "density",
                        "friction",
                        "restitution",
                        "collidability",
                        "role",
                    ],
                    [
                        (0, 5.0, 0.01),
                        (0, 5.0, 0.01),
                        (-2 * jnp.pi, 2 * jnp.pi, 0.01),
                        (-10.0, 10.0, 0.1),
                        (-10.0, 10.0, 0.1),
                        (-6, 6.0, 0.01),
                        (0.1, 5.0, 0.1),
                        (0.02, 1.0, 0.02),
                        (0.0, 0.8, 0.02),
                        (0, 2, 1),
                        (0, 3, 1),
                    ],
                    [True, True, True, True, True, True, True, True, False, False, False],
                    [True, False, True, True, False, True, True, True, False, False, False],
                )
            ):
                location = 50 + 80 * total_toggles + 40 * total_non_toggles
                label, slider, (pcg_toggle, pcg_lbl, slider_max) = self.make_label_and_slider(
                    location, name.title(), *bounds, is_pcg=is_pcg, add_pcg_toggle=add_pcg_toggle
                )
                total_toggles += is_pcg
                total_non_toggles += not is_pcg
                D_rigid["lbl" + name.title()] = label
                D_rigid["sld" + name.title()] = slider

                if is_pcg:
                    D_rigid["sldMax" + name.title()] = slider_max
                    if add_pcg_toggle:
                        D_rigid["pcgLbl" + name.title()] = pcg_lbl
                        D_rigid["pcgTgl" + name.title()] = pcg_toggle

            location = 50 + 80 * total_toggles + 40 * total_non_toggles
            # toggles:
            label, toggle, _ = self.make_label_and_slider(location, "Fixate", is_toggle=True)
            D_rigid["lblFixate"] = label
            D_rigid["tglFixate"] = toggle

            return D_rigid, location

        # Circle extras
        D_circle, location = _create_rigid_body_base_gui()
        label, slider, (pcg_toggle, pcg_lbl, slider_max) = self.make_label_and_slider(
            location + 40, "Radius", slider_min=0.1, slider_max=1.0, slider_step=0.02, is_pcg=True, add_pcg_toggle=True
        )
        D_circle["lblRadius"] = label
        D_circle["sldRadius"] = slider
        D_circle["sldMaxRadius"] = slider_max
        D_circle["pcgLblRadius"] = pcg_lbl
        D_circle["pcgTglRadius"] = pcg_toggle

        # Polygon extras
        D_poly, location = _create_rigid_body_base_gui()
        label, slider, (pcg_toggle, pcg_lbl, slider_max) = self.make_label_and_slider(
            location + 40, "Size", slider_min=0.1, slider_max=2.0, slider_step=0.02, is_pcg=True, add_pcg_toggle=True
        )
        D_poly["lblSize"] = label
        D_poly["sldSize"] = slider
        D_poly["sldMaxSize"] = slider_max
        D_poly["pcgLblSize"] = pcg_lbl
        D_poly["pcgTglSize"] = pcg_toggle

        label, toggle, _ = self.make_label_and_slider(150 + 80 * 5, "Tie Positions Together", is_toggle=True)
        TIE_TOGETHER = {
            "lblTieTogether": label,
            "tglTieTogether": toggle,
        }

        self.all_widgets = {
            ObjectType.THRUSTER: T,
            ObjectType.JOINT: D,
            "GENERAL": {
                "lblGeneral": TextBox(
                    self.screen_surface,
                    MARGIN,
                    10,
                    W,
                    30,
                    fontSize=20,
                    margin=0,
                    placeholderText="General",
                    font=pygame.font.SysFont("sans-serif", 35),
                ),
            },
            ObjectType.POLYGON: D_poly,
            ObjectType.CIRCLE: D_circle,
            None: G,
            "TIE_TOGETHER": TIE_TOGETHER,
        }

        self._hide_all_widgets()

    def _render_edit_overlay(self, pixels, is_editing, edit_shape_mode):
        is_editing_texture = jax.lax.select(is_editing, EDIT_TEXTURE_RGBA, PLAY_TEXTURE_RGBA)
        is_editing_texture = jnp.repeat(jnp.repeat(is_editing_texture, self.upscale, axis=0), self.upscale, axis=1)

        offset = self.side_panel_width * self.upscale
        w = 64 * self.upscale
        offset2 = int(w * 1.25)
        offset_y = 16 * self.upscale

        play_tex_with_background = (1 - is_editing_texture[:, :, 3:]) * pixels[
            offset + 0 : offset + w, 0:w
        ] + is_editing_texture[:, :, 3:] * is_editing_texture[:, :, :3]
        pixels = pixels.at[offset : offset + w, 0:w].set(play_tex_with_background)

        edit_shape_texture = jax.lax.switch(
            edit_shape_mode,
            [
                lambda: CIRCLE_TEXTURE_RGBA,
                lambda: RECT_TEXTURE_RGBA,
                lambda: RJOINT_TEXTURE_RGBA,
                lambda: SELECT_TEXTURE_RGBA,
                lambda: TRIANGLE_TEXTURE_RGBA,
                lambda: THRUSTER_TEXTURE_RGBA,
            ],
        )

        edit_shape_texture = jnp.repeat(jnp.repeat(edit_shape_texture, self.upscale, axis=0), self.upscale, axis=1)

        edit_shape_texture_alpha = edit_shape_texture[:, :, 3:] * is_editing

        w = 32 * self.upscale
        edit_shape_texture_with_background = (1 - edit_shape_texture_alpha) * pixels[
            offset + offset2 : offset + offset2 + w, offset_y : offset_y + w
        ] + edit_shape_texture_alpha * edit_shape_texture[:, :, :3]
        pixels = pixels.at[offset + offset2 : offset + offset2 + w, offset_y : offset_y + w].set(
            edit_shape_texture_with_background
        )

        return pixels

    def edit(self):
        self.rng, _rng = jax.random.split(self.rng)
        pcg_state = self.pcg_state

        left_click = False
        right_click = False

        keys = []
        keys_up_this_frame = set()
        for event in self.pygame_events:
            if event.type == pygame.KEYDOWN:
                if (
                    event.key == pygame.K_s
                    and (pygame.key.get_mods() & pygame.KMOD_CTRL)
                    and (pygame.key.get_mods() & pygame.KMOD_SHIFT)
                ):
                    filename = prompt_file(save=True)
                    if filename:
                        filename += ".level.pkl"
                        save_pickle(filename, self.last_played_level)
                        print(f"Saved last sampled level to {filename}")
                elif event.key == pygame.K_s and (pygame.key.get_mods() & pygame.KMOD_CTRL):
                    pcg_state = self._reset_select_shape(pcg_state)
                    filename = prompt_file(save=True)
                    if filename:
                        if filename.endswith(".json"):
                            export_env_state_to_json(
                                filename, pcg_state.env_state, self.static_env_params, self.env_params
                            )
                        elif not filename.endswith(".pcg.pkl"):
                            filename += ".pcg.pkl"
                            save_pickle(filename, pcg_state)
                        print(f"Saved PCG state to {filename}")
                elif event.key == pygame.K_o and (pygame.key.get_mods() & pygame.KMOD_CTRL):
                    filename = prompt_file(save=False)
                    if filename:
                        self._reset_select_shape(pcg_state)
                        if filename.endswith(".pcg.pkl"):
                            pcg_state = load_pcg_state_pickle(filename)
                            pcg_state = expand_pcg_state(pcg_state, self.static_env_params)
                            print(f"Loaded PCG state from {filename}")
                        elif filename.endswith(".level.pkl"):
                            env_state = load_world_state_pickle(filename)
                            pcg_state = env_state_to_pcg_state(env_state)
                            print(f"Converted level state to PCG state from {filename}")
                        elif filename.endswith(".json"):
                            env_state, new_static_env_params, new_env_params = load_from_json_file(filename)
                            self._update_params(new_static_env_params, new_env_params)
                            pcg_state = env_state_to_pcg_state(env_state)
                    self._reset_triangles()
                elif event.key == pygame.K_n and (pygame.key.get_mods() & pygame.KMOD_CTRL):
                    self._reset_select_shape(pcg_state)
                    pcg_state = new_pcg_env(self.static_env_params)
                    self._reset_triangles()
                else:
                    keys.append(event.key)
            elif event.type == pygame.KEYUP:
                keys_up_this_frame.add(event.key)
            if event.type == pygame.MOUSEBUTTONDOWN and self._get_mouse_position_world_space()[0] >= 0:
                if event.button == 1:
                    left_click = True
                if event.button == 3:
                    right_click = True

            if event.type == pygame.MOUSEWHEEL:
                pcg_state = self._handle_scroll_wheel(pcg_state, event.y)

        if self.selected_shape_index == -1:
            num = get_numeric_key_pressed(self.pygame_events)
            if num is not None:
                self.edit_shape_mode = EditMode(num % len(EditMode))

        # We have to do these checks outside the loop, otherwise they get triggered multiple times per key press.
        if pygame.key.get_mods() & pygame.KMOD_SHIFT:
            state = pcg_state.env_state
            if pygame.K_m in keys_up_this_frame:
                state = self.mutate_world(_rng, state, 1)
            if pygame.K_c in keys_up_this_frame:
                state, _ = mutate_add_connected_shape(
                    _rng, state, self.env_params, self.static_env_params, self.ued_params
                )
            elif pygame.K_s in keys_up_this_frame:
                state = mutate_add_shape(_rng, state, self.env_params, self.static_env_params, self.ued_params)
            elif pygame.K_p in keys_up_this_frame:
                state = mutate_swap_role(_rng, state, self.env_params, self.static_env_params, self.ued_params)
            elif pygame.K_r in keys_up_this_frame:
                state = mutate_remove_shape(_rng, state, self.env_params, self.static_env_params, self.ued_params)
            elif pygame.K_j in keys_up_this_frame:
                state = mutate_remove_joint(_rng, state, self.env_params, self.static_env_params, self.ued_params)
            elif pygame.K_t in keys_up_this_frame:
                state = mutate_toggle_fixture(_rng, state, self.env_params, self.static_env_params, self.ued_params)
            elif pygame.K_g in keys_up_this_frame:
                state = mutate_add_thruster(_rng, state, self.env_params, self.static_env_params, self.ued_params)
            elif pygame.K_l in keys_up_this_frame:
                state = mutate_remove_thruster(_rng, state, self.env_params, self.static_env_params, self.ued_params)
            elif pygame.K_b in keys_up_this_frame:
                state = self.mutate_change_shape_size(
                    _rng, state, self.env_params, self.static_env_params, self.ued_params
                )
            elif pygame.K_x in keys_up_this_frame:
                state = mutate_change_shape_location(
                    _rng, state, self.env_params, self.static_env_params, self.ued_params
                )
            elif pygame.K_k in keys_up_this_frame:
                state = self.mutate_change_shape_rotation(
                    _rng, state, self.env_params, self.static_env_params, self.ued_params
                )
            pcg_state = pcg_state.replace(env_state=state)

        if pygame.K_p in keys_up_this_frame:
            global myrng
            myrng, _rng = jax.random.split(myrng)
            # use the same rng
            pcg_state = permute_pcg_state(_rng, pcg_state, self.static_env_params)
        if self.edit_shape_mode == EditMode.SELECT:  # select a shape
            pcg_state = self._edit_select_shape(pcg_state, left_click, right_click, keys)
            self._put_state_values_into_gui(pcg_state)
            pcg_state = self._select_shape_keyboard_shortcuts(pcg_state, left_click, keys)
        else:
            pcg_state = self._reset_select_shape(pcg_state)  # don't highlight
            self._put_state_values_into_gui(pcg_state)
            self._show_correct_widgets(None)

        if self.edit_shape_mode != EditMode.ADD_TRIANGLE or not self.creating_shape:
            self.num_triangle_clicks = 0

        if self.edit_shape_mode == EditMode.ADD_CIRCLE:
            pcg_state = self._edit_circle(pcg_state, left_click, right_click)
        elif self.edit_shape_mode == EditMode.ADD_RECTANGLE:
            pcg_state = self._edit_rect(pcg_state, left_click, right_click)
        elif self.edit_shape_mode == EditMode.ADD_JOINT:
            pcg_state = self._edit_joint(pcg_state, left_click, right_click)
        elif self.edit_shape_mode == EditMode.ADD_TRIANGLE:
            pcg_state = self._edit_triangle(pcg_state, left_click, right_click)
        elif self.edit_shape_mode == EditMode.ADD_THRUSTER:
            pcg_state = self._edit_thruster(pcg_state, left_click, right_click)

        pcg_state = pcg_state.replace(
            env_state=recompute_global_joint_positions(
                pcg_state.env_state.replace(
                    collision_matrix=calculate_collision_matrix(self.static_env_params, pcg_state.env_state.joint),
                ),
                self.static_env_params,
            ),
            env_state_pcg_mask=pcg_state.env_state_pcg_mask.replace(
                collision_matrix=jnp.zeros_like(pcg_state.env_state_pcg_mask.collision_matrix)
            ),
        )
        return pcg_state

    def _update_params(self, new_static_env_params: StaticEnvParams, new_env_params: EnvParams):
        self.static_env_params = new_static_env_params.replace(
            frame_skip=self.config["frame_skip"], downscale=self.config["downscale"]
        )
        self.env_params = new_env_params
        env = make_kinetix_env_from_name("Kinetix-Entity-MultiDiscrete-v1", static_env_params=self.static_env_params)
        self.env = AutoResetWrapper(env, make_reset_function(self.static_env_params))
        self._setup_rendering(self.static_env_params, self.env_params)

    def _discard_shape_being_created(self, pcg_state):
        env_state = pcg_state.env_state
        if self.creating_shape:
            if self.edit_shape_mode == EditMode.ADD_CIRCLE:
                env_state = env_state.replace(
                    circle=env_state.circle.replace(
                        active=env_state.circle.active.at[self.creating_shape_index].set(False)
                    )
                )

            elif self.edit_shape_mode == EditMode.ADD_RECTANGLE:
                env_state = env_state.replace(
                    polygon=env_state.polygon.replace(
                        active=env_state.polygon.active.at[self.creating_shape_index].set(False)
                    )
                )

        self.creating_shape = False

        return pcg_state.replace(env_state=env_state)

    def _handle_scroll_wheel(self, pcg_state, y):
        if y == 0:
            return pcg_state

        state = self._discard_shape_being_created(pcg_state)

        self.edit_shape_mode = EditMode((self.edit_shape_mode.value + y) % len(EditMode))

        return state

    def _get_mouse_position_world_space(self):
        mouse_pos = pygame.mouse.get_pos()
        return (
            jnp.array(
                [
                    mouse_pos[0] / self.upscale - self.side_panel_width,
                    self.static_env_params.screen_dim[1] - mouse_pos[1] / self.upscale,
                ]
            )
            / self.env_params.pixels_per_unit
        )

    def _get_circles_on_mouse(self, state):
        mouse_pos = self._get_mouse_position_world_space()

        cis = []

        for ci in jnp.arange(self.static_env_params.num_circles)[::-1]:
            circle = jax.tree.map(lambda x: x[ci], state.circle)

            if not circle.active:
                continue

            dist = jnp.linalg.norm(mouse_pos - circle.position)
            if dist <= circle.radius:
                cis.append(ci)

        return cis

    def _get_revolute_joints_on_mouse(self, state: EnvState):
        mouse_pos = self._get_mouse_position_world_space()

        ris = []

        for ri in jnp.arange(self.static_env_params.num_joints)[::-1]:
            joint = jax.tree.map(lambda x: x[ri], state.joint)

            if not joint.active:
                continue

            dist = jnp.linalg.norm(mouse_pos - joint.global_position)
            if dist <= 10 / 100:  # arbitrary
                ris.append(ri)

        return ris

    def _get_thrusters_on_mouse(self, state: EnvState):
        mouse_pos = self._get_mouse_position_world_space()

        ris = []

        for ri in jnp.arange(self.static_env_params.num_thrusters)[::-1]:
            thruster = jax.tree.map(lambda x: x[ri], state.thruster)

            if not thruster.active:
                continue

            dist = jnp.linalg.norm(mouse_pos - thruster.global_position)
            if dist <= 16 / 100:  # arbitrary
                ris.append(ri)

        return ris

    def _get_joints_attached_to_shape(self, state, shape_index):
        r_a = jnp.arange(self.static_env_params.num_joints)[state.joint.a_index == shape_index]
        r_b = jnp.arange(self.static_env_params.num_joints)[state.joint.b_index == shape_index]

        t = jnp.arange(self.static_env_params.num_thrusters)[state.thruster.object_index == shape_index]
        return jnp.concatenate([r_a, r_b], axis=0), t

    def _edit_thruster(self, pcg_state: PCGState, left_click: bool, right_click: bool):
        if not self.creating_shape and (1 - pcg_state.env_state.thruster.active.astype(int)).sum() == 0:
            if not right_click:
                return pcg_state

        thruster_pos = self._get_mouse_position_world_space()
        idx = -1
        for ri in self._get_polygons_on_mouse(pcg_state.env_state):
            r = jax.tree.map(lambda x: x[ri], pcg_state.env_state.polygon)
            thruster_pos = snap_to_polygon_center_line(r, thruster_pos)
            relative_pos = jnp.matmul(rmat(r.rotation).transpose((1, 0)), thruster_pos - r.position)
            idx = ri
            break
        if idx == -1:
            for ci in self._get_circles_on_mouse(pcg_state.env_state):
                c = jax.tree.map(lambda x: x[ci], pcg_state.env_state.circle)
                thruster_pos = snap_to_center(c, thruster_pos)
                thruster_pos = snap_to_circle_center_line(c, thruster_pos)
                relative_pos = thruster_pos - c.position
                idx = ci + self.static_env_params.num_polygons
                break
        if left_click:
            if self.creating_shape:
                self.creating_shape = False
            else:
                if idx >= 0:
                    self.creating_shape = True
                    self.creating_shape_position = thruster_pos
                    self.creating_shape_index = jnp.argmin(pcg_state.env_state.thruster.active)
                    shape = select_shape(pcg_state.env_state, idx, self.static_env_params)

                    def _add_thruster_to_state(state):
                        state = state.replace(
                            thruster=state.thruster.replace(
                                object_index=state.thruster.object_index.at[self.creating_shape_index].set(idx),
                                relative_position=state.thruster.relative_position.at[self.creating_shape_index].set(
                                    relative_pos
                                ),
                                power=state.thruster.power.at[self.creating_shape_index].set(
                                    1.0 / jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass)
                                ),
                                active=state.thruster.active.at[self.creating_shape_index].set(True),
                                global_position=state.thruster.global_position.at[self.creating_shape_index].set(
                                    thruster_pos
                                ),
                                rotation=state.thruster.rotation.at[self.creating_shape_index].set(0.0),
                            ),
                            thruster_bindings=state.thruster_bindings.at[self.creating_shape_index].set(0),
                        )
                        return state

                    pcg_state = pcg_state.replace(
                        env_state=_add_thruster_to_state(pcg_state.env_state),
                        env_state_max=_add_thruster_to_state(pcg_state.env_state_max),
                    )
        elif right_click:
            for ti in self._get_thrusters_on_mouse(pcg_state.env_state):

                def _remove_thruster_from_state(state):
                    return state.replace(
                        thruster=state.thruster.replace(active=state.thruster.active.at[ti].set(False))
                    )

                return pcg_state.replace(
                    env_state=_remove_thruster_from_state(pcg_state.env_state),
                    env_state_max=_remove_thruster_from_state(pcg_state.env_state_max),
                )
        else:
            if self.creating_shape:
                curr_pos = self._get_mouse_position_world_space()

                normal = pcg_state.env_state.thruster.relative_position[self.creating_shape_index]
                angle = jnp.arctan2(normal[1], normal[0])

                relative_pos = curr_pos - self.creating_shape_position
                # rotation = jnp.arctan2(relative_pos[1], relative_pos[0])
                rotation = jnp.pi + jnp.arctan2(relative_pos[1], relative_pos[0]) + angle
                angle_round = jnp.round(rotation / (jnp.pi / 2))
                angle_norm = rotation / (jnp.pi / 2)
                if jnp.abs(angle_round - angle_norm) < 0.3:
                    rotation = angle_round * (jnp.pi / 2)

                def _update_thruster_rotation(state):
                    return state.replace(
                        thruster=state.thruster.replace(
                            rotation=state.thruster.rotation.at[self.creating_shape_index].set(rotation - angle),
                        )
                    )

                pcg_state = pcg_state.replace(
                    env_state=_update_thruster_rotation(pcg_state.env_state),
                    env_state_max=_update_thruster_rotation(pcg_state.env_state_max),
                )
            else:
                pass

        return pcg_state

    def _edit_circle(self, pcg_state: PCGState, left_click: bool, right_click: bool):
        if right_click:
            for ci in self._get_circles_on_mouse(pcg_state.env_state):
                attached_j, attached_t = self._get_joints_attached_to_shape(
                    pcg_state.env_state, ci + self.static_env_params.num_polygons
                )

                def _remove_circle_from_state(state):
                    return state.replace(
                        circle=state.circle.replace(active=state.circle.active.at[ci].set(False)),
                        joint=state.joint.replace(active=state.joint.active.at[attached_j].set(False)),
                        thruster=state.thruster.replace(active=state.thruster.active.at[attached_t].set(False)),
                    )

                env_state = _remove_circle_from_state(pcg_state.env_state)
                env_state_pcg_mask = _remove_circle_from_state(pcg_state.env_state_pcg_mask)
                env_state_max = _remove_circle_from_state(pcg_state.env_state_max)

                env_state = env_state.replace(
                    collision_matrix=calculate_collision_matrix(self.static_env_params, env_state.joint)
                )

                return PCGState(
                    env_state=env_state,
                    env_state_pcg_mask=env_state_pcg_mask,
                    env_state_max=env_state_max,
                    tied_together=pcg_state.tied_together,
                )

        if not self.creating_shape and (1 - pcg_state.env_state.circle.active.astype(int)).sum() == 0:
            return pcg_state

        radius = jnp.linalg.norm(self._get_mouse_position_world_space() - self.create_shape_position)
        radius = jnp.clip(radius, 5.0 / self.env_params.pixels_per_unit, self.static_env_params.max_shape_size / 2)

        def _add_circle(state, highlight):
            state = state.replace(
                circle=state.circle.replace(
                    position=state.circle.position.at[self.creating_shape_index].set(self.create_shape_position),
                    velocity=state.circle.velocity.at[self.creating_shape_index].set(jnp.array([0.0, 0.0])),
                    radius=state.circle.radius.at[self.creating_shape_index].set(radius),
                    inverse_mass=state.circle.inverse_mass.at[self.creating_shape_index].set(1.0),
                    inverse_inertia=state.circle.inverse_inertia.at[self.creating_shape_index].set(1.0),
                    active=state.circle.active.at[self.creating_shape_index].set(True),
                    collision_mode=state.circle.collision_mode.at[self.creating_shape_index].set(1),
                ),
                circle_shape_roles=state.circle_shape_roles.at[self.creating_shape_index].set(0),
                circle_highlighted=state.circle_highlighted.at[self.creating_shape_index].set(highlight),
                circle_densities=state.circle_densities.at[self.creating_shape_index].set(1.0),
            )
            return state

        if left_click:
            if self.creating_shape:
                env_state = _add_circle(pcg_state.env_state, False)
                env_state_max = _add_circle(pcg_state.env_state_max, False)

                env_state = recalculate_mass_and_inertia(
                    env_state,
                    self.static_env_params,
                    env_state.polygon_densities,
                    env_state.circle_densities,
                )
                env_state_max = recalculate_mass_and_inertia(
                    env_state_max,
                    self.static_env_params,
                    env_state.polygon_densities,
                    env_state.circle_densities,
                )

                pcg_state = pcg_state.replace(env_state=env_state, env_state_max=env_state_max)

                self.creating_shape = False
            else:
                self.creating_shape_index = jnp.argmin(pcg_state.env_state.circle.active)
                self.create_shape_position = self._get_mouse_position_world_space()
                self.creating_shape = True

        else:
            if self.creating_shape:
                env_state = _add_circle(pcg_state.env_state, True)
                env_state_max = _add_circle(pcg_state.env_state_max, True)

                pcg_state = pcg_state.replace(env_state=env_state, env_state_max=env_state_max)
            else:
                pass

        return pcg_state

    def _get_polygons_on_mouse(self, state, n_vertices=None):
        # n_vertices=None selects both triangles and quads
        ris = []

        mouse_pos = self._get_mouse_position_world_space()

        for ri in jnp.arange(self.static_env_params.num_polygons)[::-1]:
            polygon = jax.tree.map(lambda x: x[ri], state.polygon)

            if (not polygon.active) or ((n_vertices is not None) and polygon.n_vertices != n_vertices):
                continue

            mpos = rmat(-polygon.rotation) @ (mouse_pos - polygon.position)

            def _signed_line_distance(a, b, c):
                return (b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0])

            inside = True
            for fi in range(polygon.n_vertices):
                v1 = polygon.vertices[fi]
                v2 = polygon.vertices[(fi + 1) % polygon.n_vertices]
                if _signed_line_distance(mpos, v1, v2) > 0:
                    inside = False

            if inside:
                ris.append(ri)

        return ris

    def _edit_rect(self, pcg_state: PCGState, left_click: bool, right_click: bool):
        if right_click:
            for ri in self._get_polygons_on_mouse(pcg_state.env_state, n_vertices=4):
                attached_j, attached_t = self._get_joints_attached_to_shape(pcg_state.env_state, ri)

                def _remove_rect_from_state(state):
                    state = state.replace(
                        polygon=state.polygon.replace(
                            active=state.polygon.active.at[ri].set(False),
                            rotation=state.polygon.rotation.at[ri].set(0.0),
                        ),
                        joint=state.joint.replace(active=state.joint.active.at[attached_j].set(False)),
                        thruster=state.thruster.replace(active=state.thruster.active.at[attached_t].set(False)),
                    )

                    return state

                env_state = _remove_rect_from_state(pcg_state.env_state)
                env_state_max = _remove_rect_from_state(pcg_state.env_state_max)
                env_state_pcg_mask = _remove_rect_from_state(pcg_state.env_state_pcg_mask)

                env_state = env_state.replace(
                    collision_matrix=calculate_collision_matrix(self.static_env_params, env_state.joint)
                )

                return PCGState(
                    env_state=env_state,
                    env_state_max=env_state_max,
                    env_state_pcg_mask=env_state_pcg_mask,
                    tied_together=pcg_state.tied_together,
                )

        if not self.creating_shape and (1 - pcg_state.env_state.polygon.active.astype(int)).sum() == 0:
            return pcg_state

        diff = (self._get_mouse_position_world_space() - self.create_shape_position) / 2
        diff = jnp.clip(
            diff,
            -(self.static_env_params.max_shape_size / 2) / jnp.sqrt(2),
            (self.static_env_params.max_shape_size / 2) / jnp.sqrt(2),
        )
        half_dim = jnp.abs(diff)
        half_dim = jnp.clip(half_dim, a_min=5.0 / self.env_params.pixels_per_unit)
        vertices = rectangle_vertices(half_dim)

        def _add_rect_to_state(state, highlight):
            state = state.replace(
                polygon=state.polygon.replace(
                    position=state.polygon.position.at[self.creating_shape_index].set(
                        self.create_shape_position + diff
                    ),
                    velocity=state.polygon.velocity.at[self.creating_shape_index].set(jnp.array([0.0, 0.0])),
                    vertices=state.polygon.vertices.at[self.creating_shape_index].set(vertices),
                    inverse_mass=state.polygon.inverse_mass.at[self.creating_shape_index].set(1.0),
                    inverse_inertia=state.polygon.inverse_inertia.at[self.creating_shape_index].set(1.0),
                    active=state.polygon.active.at[self.creating_shape_index].set(True),
                    collision_mode=state.polygon.collision_mode.at[self.creating_shape_index].set(1),
                    n_vertices=state.polygon.n_vertices.at[self.creating_shape_index].set(4),
                ),
                polygon_shape_roles=state.polygon_shape_roles.at[self.creating_shape_index].set(0),
                polygon_highlighted=state.polygon_highlighted.at[self.creating_shape_index].set(highlight),
                polygon_densities=state.polygon_densities.at[self.creating_shape_index].set(1.0),
            )

            return state

        if left_click:
            if self.creating_shape:
                env_state = _add_rect_to_state(pcg_state.env_state, False)
                env_state_max = _add_rect_to_state(pcg_state.env_state_max, False)

                env_state = recalculate_mass_and_inertia(
                    env_state,
                    self.static_env_params,
                    env_state.polygon_densities,
                    env_state.circle_densities,
                )
                env_state_max = recalculate_mass_and_inertia(
                    env_state_max,
                    self.static_env_params,
                    env_state.polygon_densities,
                    env_state.circle_densities,
                )

                pcg_state = pcg_state.replace(env_state=env_state, env_state_max=env_state_max)

                self.creating_shape = False
            else:
                self.creating_shape_index = jnp.argmin(pcg_state.env_state.polygon.active)
                self.create_shape_position = self._get_mouse_position_world_space()
                self.creating_shape = True
        else:
            if self.creating_shape:
                env_state = _add_rect_to_state(pcg_state.env_state, True)
                env_state_max = _add_rect_to_state(pcg_state.env_state_max, True)

                pcg_state = pcg_state.replace(env_state=env_state, env_state_max=env_state_max)
            else:
                pass

        return pcg_state

    def _reset_triangles(self):
        self.triangle_order = jnp.array([0, 1, 2])
        self.num_triangle_clicks = 0
        self.creating_shape = False

    def _edit_triangle(self, pcg_state: PCGState, left_click: bool, right_click: bool):
        if right_click:
            self.num_triangle_clicks = 0
            for ri in self._get_polygons_on_mouse(pcg_state.env_state, n_vertices=3):
                attached_r, attached_f = self._get_joints_attached_to_shape(pcg_state.env_state, ri)

                def _remove_triangle_from_state(state):
                    state = state.replace(
                        polygon=state.polygon.replace(
                            active=state.polygon.active.at[ri].set(False),
                            rotation=state.polygon.rotation.at[ri].set(0.0),
                        ),
                        joint=state.joint.replace(active=state.joint.active.at[attached_r].set(False)),
                    )

                    return state

                env_state = _remove_triangle_from_state(pcg_state.env_state)
                env_state_max = _remove_triangle_from_state(pcg_state.env_state_max)
                env_state_pcg_mask = _remove_triangle_from_state(pcg_state.env_state_pcg_mask)

                env_state = env_state.replace(
                    collision_matrix=calculate_collision_matrix(self.static_env_params, env_state.joint)
                )

                pcg_state = PCGState(
                    env_state=env_state,
                    env_state_max=env_state_max,
                    env_state_pcg_mask=env_state_pcg_mask,
                    tied_together=pcg_state.tied_together,
                )

                return pcg_state

        if not self.creating_shape and (1 - pcg_state.env_state.polygon.active.astype(int)).sum() == 0:
            return pcg_state

        def get_correct_center_two_verts(verts):
            return (jnp.max(verts, axis=0) + jnp.min(verts, axis=0)) / 2

        def order_clockwise(verts, loose_ordering=False):
            # verts has shape (3, 2), order them clockwise.
            # https://stackoverflow.com/questions/51074984/sort-vertices-in-clockwise-order
            # Calculate centroid
            centroid = jnp.mean(verts, axis=0)
            # Calculate angles
            angles = jnp.round(jnp.arctan2(verts[:, 1] - centroid[1], verts[:, 0] - centroid[0]), 2)
            # Order vertices
            order = jnp.argsort(-angles, stable=True)
            if loose_ordering:
                order = jnp.arange(len(order))
            ans = verts[order]
            # order is of shape (2, ) or (3, ). I want it to always be of shape 3
            if len(order) < 3:
                order = jnp.concatenate([order, jnp.array([2])])
            return ans, order

        def do_triangle_n_click(pcg_state, how_many_clicks, is_on_a_click=False):
            n = how_many_clicks
            # if we must keep them clockwise all the time, then the one we edit / move around may have varying indices.
            current_index_to_change = self.triangle_order[n]
            sign = 1
            idxs = jnp.arange(n + 1)
            idxs_to_allow = idxs[~(idxs == current_index_to_change)]

            # Get the new vertex and clip its position
            new_tentative_vert = (
                self._get_mouse_position_world_space() - pcg_state.env_state.polygon.position[self.creating_shape_index]
            )
            new_tentative_vert = jnp.clip(
                new_tentative_vert,
                jnp.max(pcg_state.env_state.polygon.vertices[self.creating_shape_index, idxs_to_allow], axis=0)
                - self.static_env_params.max_shape_size * 0.8,
                jnp.min(pcg_state.env_state.polygon.vertices[self.creating_shape_index, idxs_to_allow], axis=0)
                + self.static_env_params.max_shape_size * 0.8,
            )
            new_verts = pcg_state.env_state.polygon.vertices.at[self.creating_shape_index, current_index_to_change].set(
                new_tentative_vert
            )
            new_center_two = get_correct_center_two_verts(new_verts[self.creating_shape_index, : n + 1])

            _, new_center_three = calc_inverse_mass_polygon(
                new_verts[self.creating_shape_index],
                3,
                self.static_env_params,
                1.0,
            )
            new_center = jax.lax.select(n == 1, new_center_two, new_center_three)

            new_verts = new_verts.at[self.creating_shape_index].add(-sign * new_center)
            vvs = new_verts[self.creating_shape_index, : n + 1]

            ordered_vertices, new_permutation = order_clockwise(vvs, loose_ordering=not is_on_a_click)
            self.triangle_order = self.triangle_order[new_permutation]
            new_verts = new_verts.at[self.creating_shape_index, : n + 1].set(ordered_vertices)

            env_state = pcg_state.env_state.replace(
                polygon=pcg_state.env_state.polygon.replace(
                    vertices=new_verts,
                    position=pcg_state.env_state.polygon.position.at[self.creating_shape_index].add(sign * new_center),
                    n_vertices=pcg_state.env_state.polygon.n_vertices.at[self.creating_shape_index].set(n + 1),
                ),
            )
            env_state_max = pcg_state.env_state_max.replace(
                polygon=pcg_state.env_state_max.polygon.replace(
                    vertices=new_verts,
                    position=pcg_state.env_state_max.polygon.position.at[self.creating_shape_index].add(
                        sign * new_center
                    ),
                    n_vertices=pcg_state.env_state_max.polygon.n_vertices.at[self.creating_shape_index].set(n + 1),
                ),
            )

            pcg_state = pcg_state.replace(env_state=env_state, env_state_max=env_state_max)

            return pcg_state

        if left_click:
            if self.creating_shape:
                assert 3 > self.num_triangle_clicks > 0
                if self.num_triangle_clicks == 1:
                    pcg_state = do_triangle_n_click(pcg_state, 1, is_on_a_click=True)
                    self.num_triangle_clicks += 1
                else:  # this finishes it
                    pcg_state = do_triangle_n_click(pcg_state, 2, is_on_a_click=True)
                    self.creating_shape = False
                    self.num_triangle_clicks = 0
                    pcg_state = pcg_state.replace(
                        env_state=recalculate_mass_and_inertia(
                            pcg_state.env_state,
                            self.static_env_params,
                            pcg_state.env_state.polygon_densities,
                            pcg_state.env_state.circle_densities,
                        )
                    )

            else:
                self.triangle_order = jnp.array([0, 1, 2])
                self.creating_shape_index = jnp.argmin(pcg_state.env_state.polygon.active)
                self.create_shape_position = self._get_mouse_position_world_space()
                self.creating_shape = True
                self.num_triangle_clicks = 1
                vertices = jnp.zeros((self.static_env_params.max_polygon_vertices, 2), dtype=jnp.float32)

                def _add_triangle_to_state(state):
                    state = state.replace(
                        polygon=state.polygon.replace(
                            position=state.polygon.position.at[self.creating_shape_index].set(
                                self.create_shape_position
                            ),
                            velocity=state.polygon.velocity.at[self.creating_shape_index].set(jnp.array([0.0, 0.0])),
                            vertices=state.polygon.vertices.at[self.creating_shape_index].set(vertices),
                            inverse_mass=state.polygon.inverse_mass.at[self.creating_shape_index].set(1.0),
                            inverse_inertia=state.polygon.inverse_inertia.at[self.creating_shape_index].set(1.0),
                            active=state.polygon.active.at[self.creating_shape_index].set(True),
                            n_vertices=state.polygon.n_vertices.at[self.creating_shape_index].set(1),
                        ),
                        polygon_shape_roles=state.polygon_shape_roles.at[self.creating_shape_index].set(0),
                        polygon_highlighted=state.polygon_highlighted.at[self.creating_shape_index].set(False),
                        polygon_densities=state.polygon_densities.at[self.creating_shape_index].set(1.0),
                    )

                    return state

                pcg_state = pcg_state.replace(
                    env_state=_add_triangle_to_state(pcg_state.env_state),
                    env_state_max=_add_triangle_to_state(pcg_state.env_state_max),
                )

        elif self.creating_shape:
            assert 1 <= self.num_triangle_clicks <= 2
            pcg_state = do_triangle_n_click(
                pcg_state, self.num_triangle_clicks, is_on_a_click=self.num_triangle_clicks == 1
            )

        return pcg_state

    def _edit_joint(self, pcg_state: PCGState, left_click: bool, right_click: bool):
        if left_click and pcg_state.env_state.joint.active.all():
            return pcg_state

        if left_click:
            joint_index = jnp.argmin(pcg_state.env_state.joint.active)
            joint_position = self._get_mouse_position_world_space()
            # reverse them so that the joint order and rendering order remains the same.
            # We want the first shape to have a lower index than the second shape, with circles always having higher indices compared to rectangles.
            circles = self._get_circles_on_mouse(pcg_state.env_state)[::-1]
            rects = self._get_polygons_on_mouse(pcg_state.env_state)[::-1]

            if len(rects) + len(circles) >= 2:
                r1 = len(rects) >= 1
                r2 = len(rects) >= 2

                a_index = rects[0] if r1 else circles[0]  # + self.static_env_params.num_polygons
                b_index = rects[r1 * 1] if r2 else circles[1 - 1 * r1]  # + self.static_env_params.num_polygons

                a_shape = pcg_state.env_state.polygon if r1 else pcg_state.env_state.circle
                b_shape = pcg_state.env_state.polygon if r2 else pcg_state.env_state.circle

                a = jax.tree.map(lambda x: x[a_index], a_shape)
                b = jax.tree.map(lambda x: x[b_index], b_shape)

                a_index += (not r1) * self.static_env_params.num_polygons
                b_index += (not r2) * self.static_env_params.num_polygons

                joint_position = snap_to_center(a, joint_position)
                joint_position = snap_to_center(b, joint_position)

                a_relative_pos = jnp.matmul(rmat(a.rotation).transpose((1, 0)), joint_position - a.position)
                b_relative_pos = jnp.matmul(rmat(b.rotation).transpose((1, 0)), joint_position - b.position)

                def _add_joint_to_state(state):
                    state = state.replace(
                        joint=state.joint.replace(
                            a_index=state.joint.a_index.at[joint_index].set(a_index),
                            b_index=state.joint.b_index.at[joint_index].set(b_index),
                            a_relative_pos=state.joint.a_relative_pos.at[joint_index].set(a_relative_pos),
                            b_relative_pos=state.joint.b_relative_pos.at[joint_index].set(b_relative_pos),
                            active=state.joint.active.at[joint_index].set(True),
                            global_position=state.joint.global_position.at[joint_index].set(joint_position),
                            motor_on=state.joint.motor_on.at[joint_index].set(True),
                            motor_speed=state.joint.motor_speed.at[joint_index].set(1.0),
                            motor_power=state.joint.motor_power.at[joint_index].set(1.0),
                            rotation=state.joint.rotation.at[joint_index].set(b.rotation - a.rotation),
                        )
                    )

                    return state

                env_state = _add_joint_to_state(pcg_state.env_state)
                env_state_max = _add_joint_to_state(pcg_state.env_state_max)

                env_state = env_state.replace(
                    collision_matrix=calculate_collision_matrix(self.static_env_params, env_state.joint)
                )

                pcg_state = pcg_state.replace(env_state=env_state, env_state_max=env_state_max)

        return pcg_state

    def _reset_select_shape(self, pcg_state):
        pcg_state = pcg_state.replace(
            env_state=pcg_state.env_state.replace(
                polygon_highlighted=jnp.zeros_like(pcg_state.env_state.polygon_highlighted),
                circle_highlighted=jnp.zeros_like(pcg_state.env_state.circle_highlighted),
            )
        )
        self.selected_shape_index = -1
        self.selected_shape_type = ObjectType.POLYGON
        self._hide_all_widgets()
        return pcg_state

    def _hide_all_widgets(self):
        for widget in self.all_widgets.values():
            for w in widget.values():
                w.hide()

    def _show_correct_widgets(self, type: ObjectType | None, do_tie_ui: bool = False):
        for widget in self.all_widgets["GENERAL"].values():
            widget.show()
        if do_tie_ui:
            for widget in self.all_widgets["TIE_TOGETHER"].values():
                widget.show()

            n = len(self.all_selected_shapes)
            #  {[int(i) for (i, t) in self.all_selected_shapes]}
            self.all_widgets["GENERAL"]["lblGeneral"].setText(f"Selected {n} Objects")
            return

        for widget in self.all_widgets[type].values():
            widget.show()
        if type is None:
            self.all_widgets["GENERAL"]["lblGeneral"].setText(f"Global")
        else:
            self.all_widgets["GENERAL"]["lblGeneral"].setText(f"{type.name} (idx {self.selected_shape_index})")

    def _select_shape_keyboard_shortcuts(self, pcg_state: PCGState, left_click: bool, keys: list[int]):
        if left_click:
            return pcg_state
        if len(keys) != 0 and self.selected_shape_index != -1:
            s = 1.0
            ang_s = 0.1
            vel = jnp.array([0.0, 0.0])
            angular_vel = 0.0
            should_toggle_fixed = False
            should_toggle_collidable = False
            change_angle = 0

            def add_step(widget_name, direction, speed=10, overwrite_amount=None):
                widget = self.all_widgets[self.selected_shape_type][widget_name]
                val = widget.getValue()
                step = widget.step
                amount_to_add = overwrite_amount or step * direction * speed
                widget.setValue(jnp.clip(val + amount_to_add, widget.min, widget.max))

            if pygame.K_w in keys:
                add_step("sldPosition_Y", 1)
            if pygame.K_s in keys:
                add_step("sldPosition_Y", -1)
            if pygame.K_a in keys:
                add_step("sldPosition_X", -1)
            if pygame.K_d in keys:
                add_step("sldPosition_X", 1)

            if pygame.K_q in keys:
                add_step("sldRotation", 1)
            if pygame.K_e in keys:
                add_step("sldRotation", -1)

            if pygame.K_f in keys:
                self.all_widgets[self.selected_shape_type]["tglFixate"].toggle()

            if pygame.K_c in keys and not (pygame.key.get_mods() & pygame.KMOD_CTRL):
                widget = self.all_widgets[self.selected_shape_type]["sldCollidability"]
                curr_val = int(widget.getValue())
                widget.setValue((curr_val + 1) % (widget.max + 1))

            if pygame.K_r in keys and not (pygame.key.get_mods() & pygame.KMOD_CTRL):
                widget = self.all_widgets[self.selected_shape_type]["sldRole"]
                curr_val = int(widget.getValue())
                widget.setValue((curr_val + 1) % (widget.max + 1))

            if pygame.K_LEFTBRACKET in keys:
                add_step("sldRotation", 1, 10, jnp.pi / 4)

            if pygame.K_RIGHTBRACKET in keys:
                add_step("sldRotation", -1, 10, -jnp.pi / 4)

            if pygame.K_c in keys and (pygame.key.get_mods() & pygame.KMOD_CTRL):
                # copy
                if self.selected_shape_type == ObjectType.POLYGON:  # rect
                    if not self.pcg_state.env_state.polygon.active.all():
                        where_to_add = jnp.argmin(pcg_state.env_state.polygon.active)
                        if where_to_add < self.static_env_params.num_polygons:

                            def _copy_polygon(state, shift):
                                state = state.replace(
                                    polygon=jax.tree.map(
                                        lambda x: x.at[where_to_add].set(x[self.selected_shape_index]), state.polygon
                                    )
                                )
                                if shift:
                                    state = state.replace(
                                        polygon=state.polygon.replace(
                                            position=state.polygon.position.at[where_to_add].add(0.1),
                                        ),
                                        polygon_highlighted=state.polygon_highlighted.at[where_to_add].set(False),
                                    )
                                return state

                            pcg_state = pcg_state.replace(
                                env_state=_copy_polygon(pcg_state.env_state, shift=True),
                                env_state_max=_copy_polygon(pcg_state.env_state_max, shift=True),
                                env_state_pcg_mask=_copy_polygon(pcg_state.env_state_pcg_mask, shift=False),
                            )

                elif self.selected_shape_type == ObjectType.CIRCLE:  # circle
                    if not self.pcg_state.env_state.circle.active.all():
                        where_to_add = jnp.argmin(pcg_state.env_state.circle.active)
                        if where_to_add < self.static_env_params.num_circles:

                            def _copy_circle(state, shift=True):
                                state = state.replace(
                                    circle=jax.tree.map(
                                        lambda x: x.at[where_to_add].set(x[self.selected_shape_index]), state.circle
                                    )
                                )
                                if shift:
                                    state = state.replace(
                                        circle=state.circle.replace(
                                            position=state.circle.position.at[where_to_add].add(0.1),
                                        ),
                                        circle_highlighted=state.circle_highlighted.at[where_to_add].set(False),
                                    )
                                return state

                            pcg_state = pcg_state.replace(
                                env_state=_copy_circle(pcg_state.env_state),
                                env_state_max=_copy_circle(pcg_state.env_state_max),
                                env_state_pcg_mask=_copy_circle(pcg_state.env_state_pcg_mask, shift=False),
                            )

            if self.selected_shape_index >= 0:
                num = get_numeric_key_pressed(self.pygame_events)
                if num is not None:
                    if self.selected_shape_type in [ObjectType.CIRCLE, ObjectType.POLYGON]:
                        self.all_widgets[self.selected_shape_type]["sldRole"].setValue(num % 4)
                    elif self.selected_shape_type == ObjectType.JOINT:
                        self.all_widgets[self.selected_shape_type]["sldColour"].setValue(
                            num % self.static_env_params.num_motor_bindings
                        )
                    elif self.selected_shape_type == ObjectType.THRUSTER:
                        self.all_widgets[self.selected_shape_type]["sldColour"].setValue(
                            num % self.static_env_params.num_thruster_bindings
                        )

        return pcg_state

    def _edit_select_shape(self, pcg_state: PCGState, left_click: bool, right_click: bool, keys: list[int]):
        def _find_shape(pcg_state):
            found_shape = False
            selected_shape_index, selected_shape_type = -1, ObjectType.POLYGON
            for ri in self._get_revolute_joints_on_mouse(pcg_state.env_state):
                selected_shape_index = ri
                selected_shape_type = ObjectType.JOINT
                found_shape = True
                break
            if not found_shape:
                for ti in self._get_thrusters_on_mouse(pcg_state.env_state):
                    selected_shape_index = ti
                    selected_shape_type = ObjectType.THRUSTER
                    found_shape = True
                    break
            if not found_shape:
                for ri in self._get_polygons_on_mouse(pcg_state.env_state):
                    pcg_state = pcg_state.replace(
                        env_state=pcg_state.env_state.replace(
                            polygon_highlighted=pcg_state.env_state.polygon_highlighted.at[ri].set(True),
                        )
                    )
                    selected_shape_index = ri
                    selected_shape_type = ObjectType.POLYGON
                    found_shape = True
                    break
            if not found_shape:
                for ci in self._get_circles_on_mouse(pcg_state.env_state):
                    pcg_state = pcg_state.replace(
                        env_state=pcg_state.env_state.replace(
                            circle_highlighted=pcg_state.env_state.circle_highlighted.at[ci].set(True),
                        )
                    )
                    selected_shape_index = ci
                    selected_shape_type = ObjectType.CIRCLE
                    found_shape = True
                    break
            return selected_shape_index, selected_shape_type, found_shape, pcg_state
            if found_shape and self.selected_shape_type in self.all_widgets:
                self._show_correct_widgets(self.selected_shape_type)

        # if left and shift
        if left_click and (pygame.key.get_mods() & pygame.KMOD_SHIFT):
            # This is trying to select multiple things.
            idx, type, found, pcg_state = _find_shape(pcg_state)
            if found:
                t = (idx, type)
                if t in self.all_selected_shapes:
                    self.all_selected_shapes.remove(t)
                else:
                    self.all_selected_shapes.append(t)

            self._hide_all_widgets()
            self._show_correct_widgets(None, do_tie_ui=True)
        elif left_click:
            self.all_selected_shapes = []
            self._hide_all_widgets()
            pcg_state = self._reset_select_shape(pcg_state)
            self.selected_shape_index, self.selected_shape_type, found_shape, pcg_state = _find_shape(pcg_state)
            if found_shape:
                self.all_selected_shapes = [(self.selected_shape_index, self.selected_shape_type)]
                if self.selected_shape_type in self.all_widgets:
                    self._show_correct_widgets(self.selected_shape_type)
        if self.selected_shape_index < 0:
            self._show_correct_widgets(None)

        return pcg_state

    def render(self, env_state):
        # Clear
        self.screen_surface.fill((0, 0, 0))

        if self.is_editing:
            pixels = self._render_fn_edit(env_state)
        else:
            pixels = self._render_fn(env_state)
        pixels = self._render_edit_overlay_fn(pixels, self.is_editing, self.edit_shape_mode.value)

        surface = pygame.surfarray.make_surface(np.array(pixels))
        self.screen_surface.blit(surface, (0, 0))

    def is_quit_requested(self):
        for event in self.pygame_events:
            if event.type == pygame.QUIT:
                return True
        return False


@hydra.main(version_base=None, config_path="../configs", config_name="editor")
def main(config):
    config = normalise_config(OmegaConf.to_container(config), "EDITOR", editor_config=True)
    env_params, static_env_params = generate_params_from_config(config)
    static_env_params = static_env_params.replace(frame_skip=config["frame_skip"], downscale=config["downscale"])
    config["env_params"] = to_state_dict(env_params)
    config["static_env_params"] = to_state_dict(static_env_params)

    env = make_kinetix_env_from_name("Kinetix-Entity-MultiDiscrete-v1", static_env_params=static_env_params)
    env = AutoResetWrapper(env, make_reset_function(static_env_params))

    seed = config["seed"]
    print("seed", seed)
    rng = jax.random.PRNGKey(seed)
    outer_timer = tmr()
    editor = Editor(env, env_params, config, upscale=config["upscale"])
    time_e = tmr()
    print("Took {:2f}s to create editor".format(time_e - outer_timer))

    clock = pygame.time.Clock()
    while not editor.is_quit_requested():
        rng, _rng = jax.random.split(rng)
        editor.update(_rng)
        clock.tick(config["fps"])


if __name__ == "__main__":
    main()