RubiksCube
RubiksCube (Environment)
A JAX implementation of the Rubik's Cube with a configurable cube size (by default, 3) and number of scrambles at reset.
- observation: `Observation`
    - cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble.
    - step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset.
- action: multi discrete array containing the move to perform (face, depth, and direction).
- reward: jax array (float) of shape (): by default, 1.0 if cube is solved, otherwise 0.0.
- episode termination: if either the cube is solved or a time limit is reached.
- state: `State`
    - cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble.
    - step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset.
    - key: jax array (uint) of shape (2,) used for seeding the sampling for scrambling on reset.
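The reward and termination rules listed above can be illustrated with a small dependency-free sketch. Plain nested Python lists stand in for the JAX arrays here, and the helper names (`make_solved_cube`, `is_solved`, `sparse_reward`, `is_terminal`) are hypothetical, not part of the library. The encoding of a solved cube (face `i` filled with colour index `i`) and the `>=` comparison against the time limit are assumptions made for illustration.

```python
from typing import List

# 6 faces, each a cube_size x cube_size grid of colour indices.
Cube = List[List[List[int]]]


def make_solved_cube(cube_size: int = 3) -> Cube:
    """Build a solved cube: face i is filled with colour index i (assumed encoding)."""
    return [[[face] * cube_size for _ in range(cube_size)] for face in range(6)]


def is_solved(cube: Cube) -> bool:
    """True when every sticker on each face matches that face's first sticker."""
    return all(cell == face[0][0] for face in cube for row in face for cell in row)


def sparse_reward(cube: Cube) -> float:
    """Default-style reward: 1.0 if the cube is solved, otherwise 0.0."""
    return 1.0 if is_solved(cube) else 0.0


def is_terminal(cube: Cube, step_count: int, time_limit: int = 200) -> bool:
    """The episode ends when the cube is solved or the time limit is reached."""
    return is_solved(cube) or step_count >= time_limit
```

In the environment itself these quantities are JAX arrays, so the same checks would be expressed with array operations rather than Python loops.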
observation_spec: jumanji.specs.Spec[jumanji.environments.logic.rubiks_cube.types.Observation]
cached property, writable
Specifications of the observation of the `RubiksCube` environment.
Returns:
Type | Description |
---|---|
jumanji.specs.Spec[jumanji.environments.logic.rubiks_cube.types.Observation] | Spec containing all the specifications for all the `Observation` fields. |
action_spec: MultiDiscreteArray
cached property, writable
Returns the action spec. An action is composed of 3 elements: the face (6 possible values), the depth (cube_size//2 possible values), and the direction (3 possible values).
Returns:
Type | Description |
---|---|
MultiDiscreteArray | The action spec. |
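To make the size of the action space concrete, here is a minimal sketch that enumerates every (face, depth, direction) triple for a given cube size. It uses only the Python standard library, and the function names `action_num_values` and `all_actions` are hypothetical helpers, not part of the library.

```python
import itertools
from typing import List, Tuple

NUM_FACES = 6
NUM_DIRECTIONS = 3


def action_num_values(cube_size: int = 3) -> List[int]:
    """Number of choices per action element: [faces, depths, directions]."""
    return [NUM_FACES, cube_size // 2, NUM_DIRECTIONS]


def all_actions(cube_size: int = 3) -> List[Tuple[int, ...]]:
    """Enumerate every (face, depth, direction) triple for this cube size."""
    ranges = [range(n) for n in action_num_values(cube_size)]
    return list(itertools.product(*ranges))
```

For the default cube size of 3 there is a single depth per face, so the sketch yields 6 * 1 * 3 = 18 distinct moves; for a cube size of 4 it yields 6 * 2 * 3 = 36.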
__init__(self, generator: Optional[jumanji.environments.logic.rubiks_cube.generator.Generator] = None, time_limit: int = 200, reward_fn: Optional[jumanji.environments.logic.rubiks_cube.reward.RewardFn] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.logic.rubiks_cube.types.State]] = None)
special
Instantiate a `RubiksCube` environment.
Parameters:
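Name | Type | Description | Default |
---|---|---|---|
generator | Optional[jumanji.environments.logic.rubiks_cube.generator.Generator] | generates the scrambled cube at reset; `None` selects the default generator. | None |
time_limit | int | the number of steps allowed before an episode terminates. | 200 |
reward_fn | Optional[jumanji.environments.logic.rubiks_cube.reward.RewardFn] | computes the reward; `None` selects the default, which gives 1.0 if the cube is solved and 0.0 otherwise. | None |
viewer | Optional[jumanji.viewer.Viewer[jumanji.environments.logic.rubiks_cube.types.State]] | renders the environment state; `None` selects the default viewer. | None |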
reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.logic.rubiks_cube.types.State, jumanji.types.TimeStep[jumanji.environments.logic.rubiks_cube.types.Observation]]
Resets the environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKeyArray | needed for scramble. | required |
Returns:
Type | Description |
---|---|
Tuple[State, TimeStep[Observation]] | The initial state of the environment and the first timestep. |
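The role of the key in `reset` can be illustrated with a rough, dependency-free sketch. It assumes that a scramble is a sequence of uniformly sampled (face, depth, direction) moves, which this page does not spell out. Python's `random.Random` stands in for a JAX PRNG key, and `sample_scramble` with its `seed` and `num_scrambles` parameters is a hypothetical helper, not part of the library.

```python
import random
from typing import List, Tuple


def sample_scramble(
    seed: int, num_scrambles: int, cube_size: int = 3
) -> List[Tuple[int, int, int]]:
    """Sample a reproducible sequence of (face, depth, direction) moves.

    Assumes cube_size >= 2 so that at least one depth is available.
    """
    rng = random.Random(seed)  # stand-in for a JAX PRNG key
    return [
        (rng.randrange(6), rng.randrange(cube_size // 2), rng.randrange(3))
        for _ in range(num_scrambles)
    ]
```

The point of the sketch is reproducibility: calling it twice with the same seed returns the same move sequence, which mirrors how passing the same key to `reset` is expected to give the same scramble.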
step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.logic.rubiks_cube.types.State, jumanji.types.TimeStep[jumanji.environments.logic.rubiks_cube.types.Observation]]
Run one timestep of the environment's dynamics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | State | the current state of the environment. | required |
action | Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] | the move to perform (face, depth, and direction). | required |
Returns:
Type | Description |
---|---|
Tuple[State, TimeStep[Observation]] | The next state of the environment and the timestep produced by the move. |