
RubiksCube

RubiksCube (Environment)

A JAX implementation of the Rubik's Cube with a configurable cube size (by default, 3) and number of scrambles at reset.

  • observation: Observation

    • cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the colour index of the corresponding sticker.
    • step_count: jax array (int32) of shape (): the number of timesteps elapsed since the last environment reset.
  • action: multi-discrete array specifying the move to perform (face, depth, and direction).

  • reward: jax array (float) of shape (): by default, 1.0 if the cube is solved, otherwise 0.0.

  • episode termination: when the cube is solved or the time limit is reached.

  • state: State

    • cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the colour index of the corresponding sticker.
    • step_count: jax array (int32) of shape (): the number of timesteps elapsed since the last environment reset.
    • key: jax array (uint) of shape (2,) used to seed the sampling of the scramble on reset.
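To make the cube layout concrete, the following pure-Python sketch builds a solved cube as a (6, cube_size, cube_size) nested list of colour indices and checks whether every face shows a single colour. It assumes that face i of a solved cube carries colour index i; make_solved_cube and is_solved are hypothetical helpers written for illustration, not Jumanji functions, and plain lists stand in for the int8 JAX array.

```python
from typing import List

# Hypothetical alias: a cube is a (6, cube_size, cube_size) nested list of colour indices.
Cube = List[List[List[int]]]

NUM_FACES = 6


def make_solved_cube(cube_size: int = 3) -> Cube:
    """Build a solved cube, assuming face i is painted entirely with colour index i."""
    return [[[face] * cube_size for _ in range(cube_size)] for face in range(NUM_FACES)]


def is_solved(cube: Cube) -> bool:
    """A cube is solved when every sticker on each face matches that face's first sticker."""
    return all(
        sticker == face[0][0]
        for face in cube
        for row in face
        for sticker in row
    )


cube = make_solved_cube(3)
print(len(cube), len(cube[0]), len(cube[0][0]))  # 6 3 3
print(is_solved(cube))  # True
cube[0][0][0] = 1  # recolour one sticker (not a state reachable by legal moves)
print(is_solved(cube))  # False
```

Comparing each face against its own first sticker, rather than against a fixed colour, keeps the check independent of which colour ends up on which face.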
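import jax
from jumanji.environments import RubiksCube
env = RubiksCube()  # defaults: 3x3x3 cube, 100 scrambles on reset, time limit of 200
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)  # sample a scrambled cube
env.render(state)
action = env.action_spec.generate_value()  # a valid placeholder action
state, timestep = jax.jit(env.step)(state, action)  # apply the move
env.render(state)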

observation_spec: jumanji.specs.Spec[jumanji.environments.logic.rubiks_cube.types.Observation] (cached property, writable)

Specifications of the observation of the RubiksCube environment.

Returns:

  • Spec: contains the specifications for all the Observation fields:
    • cube: BoundedArray (jnp.int8) of shape (num_faces, cube_size, cube_size).
    • step_count: BoundedArray (jnp.int32) of shape ().

action_spec: MultiDiscreteArray (cached property, writable)

Returns the action spec. An action consists of 3 elements: the face to turn (6 options), the depth of the turn (cube_size // 2 options), and the direction (3 options).

Returns:

  • action_spec: MultiDiscreteArray object.
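The sizes in the action spec multiply out to the number of distinct moves: for the default 3x3x3 cube there are 6 * (3 // 2) * 3 = 18. The sketch below computes this and, for agents that prefer a single categorical output, maps a flat index back to a (face, depth, direction) triple. The environment itself takes the triple directly; action_dims, num_actions and decode are hypothetical helpers, and the row-major ordering is an assumption of this sketch, not something Jumanji prescribes.

```python
from typing import Tuple

NUM_FACES = 6
NUM_DIRECTIONS = 3  # per the action spec; the meaning of each index is not assumed here


def action_dims(cube_size: int = 3) -> Tuple[int, int, int]:
    """Per-component sizes of the multi-discrete action: (face, depth, direction)."""
    return NUM_FACES, cube_size // 2, NUM_DIRECTIONS


def num_actions(cube_size: int = 3) -> int:
    """Total number of distinct (face, depth, direction) combinations."""
    faces, depths, directions = action_dims(cube_size)
    return faces * depths * directions


def decode(flat_index: int, cube_size: int = 3) -> Tuple[int, int, int]:
    """Hypothetical helper: map a flat index to (face, depth, direction), row-major."""
    faces, depths, directions = action_dims(cube_size)
    if not 0 <= flat_index < faces * depths * directions:
        raise ValueError("flat_index out of range")
    face, rest = divmod(flat_index, depths * directions)
    depth, direction = divmod(rest, directions)
    return face, depth, direction


print(num_actions(3))  # 18
print(decode(17, 3))  # (5, 0, 2)
```

Because only cube_size // 2 depths are counted, a 4x4x4 and a 5x5x5 cube both have 36 actions under this spec.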

__init__(self, generator: Optional[jumanji.environments.logic.rubiks_cube.generator.Generator] = None, time_limit: int = 200, reward_fn: Optional[jumanji.environments.logic.rubiks_cube.reward.RewardFn] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.logic.rubiks_cube.types.State]] = None) (special)

Instantiate a RubiksCube environment.

Parameters:

  • generator (Optional[jumanji.environments.logic.rubiks_cube.generator.Generator], default None): generator used to create problem instances on environment reset. Implemented options are [ScramblingGenerator]. Defaults to ScramblingGenerator with 100 scrambles on reset. The generator has a cube_size attribute, the number of cubies along an edge, which defaults to 3.
  • time_limit (int, default 200): the number of steps allowed before an episode terminates.
  • reward_fn (Optional[jumanji.environments.logic.rubiks_cube.reward.RewardFn], default None): RewardFn whose __call__ method computes the reward given the new state. Implemented options are [SparseRewardFn]. Defaults to SparseRewardFn, which gives a reward of 1.0 if the cube is solved and 0.0 otherwise.
  • viewer (Optional[jumanji.viewer.Viewer[jumanji.environments.logic.rubiks_cube.types.State]], default None): viewer supporting rendering and animation methods. Implemented options are [RubiksCubeViewer]. Defaults to RubiksCubeViewer.
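The default generator is described as performing 100 scrambles on reset. Conceptually, that means sampling 100 random moves from the action space and applying them to a solved cube. The sketch below covers only the sampling half: sample_scramble is a hypothetical name, not Jumanji's ScramblingGenerator, and Python's random module stands in for a JAX PRNG key, so that the same seed reproduces the same scramble.

```python
import random
from typing import List, Tuple

# (face, depth, direction), matching the order of the action spec components
Action = Tuple[int, int, int]

NUM_FACES = 6
NUM_DIRECTIONS = 3


def sample_scramble(seed: int, num_scrambles: int = 100, cube_size: int = 3) -> List[Action]:
    """Hypothetical helper: draw `num_scrambles` uniformly random moves.

    random.Random(seed) stands in for a JAX PRNG key. A real scrambling
    generator would also apply these moves to a solved cube, which this
    sketch leaves out.
    """
    rng = random.Random(seed)
    num_depths = cube_size // 2
    return [
        (rng.randrange(NUM_FACES), rng.randrange(num_depths), rng.randrange(NUM_DIRECTIONS))
        for _ in range(num_scrambles)
    ]


scramble = sample_scramble(seed=0)
print(len(scramble))  # 100
print(scramble == sample_scramble(seed=0))  # True: same seed, same scramble
```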

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.logic.rubiks_cube.types.State, jumanji.types.TimeStep[jumanji.environments.logic.rubiks_cube.types.Observation]]

Resets the environment.

Parameters:

  • key (PRNGKeyArray, required): random key used to sample the scramble.

Returns:

  • state: State corresponding to the new state of the environment.
  • timestep: TimeStep corresponding to the first timestep returned by the environment.

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.logic.rubiks_cube.types.State, jumanji.types.TimeStep[jumanji.environments.logic.rubiks_cube.types.Observation]]

Run one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object containing the current state of the environment.
  • action (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], required): array of shape (3,) indicating the face to move, the depth of the move, and the amount to move by.

Returns:

  • next_state: State corresponding to the next state of the environment.
  • next_timestep: TimeStep corresponding to the timestep returned by the environment.
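Leaving the face rotation itself aside, the rest of a step follows from the rules described for this environment: the step counter advances, the default sparse reward is 1.0 only for a solved cube, and the episode ends once the cube is solved or time_limit steps have elapsed. The sketch below illustrates this in plain Python; SketchState and finish_step are hypothetical names, and unlike the JAX environment it uses Python control flow and does not distinguish termination from truncation.

```python
from dataclasses import dataclass
from typing import List, Tuple

Cube = List[List[List[int]]]


@dataclass
class SketchState:
    """Hypothetical stand-in for State, keeping only the cube and step_count."""
    cube: Cube
    step_count: int


def is_solved(cube: Cube) -> bool:
    # Solved when every sticker on each face matches that face's first sticker.
    return all(s == face[0][0] for face in cube for row in face for s in row)


def finish_step(
    state: SketchState, next_cube: Cube, time_limit: int = 200
) -> Tuple[SketchState, float, bool]:
    """Bookkeeping after a move has produced `next_cube`.

    Returns the new state, a sparse reward (1.0 if solved, else 0.0) and
    whether the episode is over (solved, or time_limit steps elapsed).
    """
    next_state = SketchState(cube=next_cube, step_count=state.step_count + 1)
    solved = is_solved(next_cube)
    reward = 1.0 if solved else 0.0
    done = solved or next_state.step_count >= time_limit
    return next_state, reward, done
```

For example, with an unsolved cube and step_count 199, finish_step returns a reward of 0.0 and done set to True, because the 200th step reaches the default time limit.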


Last update: 2024-11-01