
RubiksCube

Bases: Environment[State, MultiDiscreteArray, Observation]

A JAX implementation of the Rubik's Cube with a configurable cube size (by default, 3) and number of scrambles at reset.

  • observation: Observation

    • cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble.
    • step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset.
  • action: multi discrete array containing the move to perform (face, depth, and direction).

  • reward: jax array (float) of shape (): by default, 1.0 if cube is solved, otherwise 0.0.

  • episode termination: if either the cube is solved or a time limit is reached.

  • state: State

    • cube: jax array (int8) of shape (6, cube_size, cube_size): each cell contains the index of the corresponding colour of the sticker in the scramble.
    • step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset.
    • key: jax array (uint) of shape (2,) used for seeding the sampling for scrambling on reset.

import jax
from jumanji.environments import RubiksCube

env = RubiksCube()
key = jax.random.PRNGKey(0)
# Reset scrambles the cube and returns the initial state and timestep.
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# Generate a valid action from the action spec, then take one step.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
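
The default reward described above (1.0 if the cube is solved, otherwise 0.0) can be illustrated without JAX. The following is a minimal pure-Python sketch, not the library code: the names make_solved_cube, all_faces_uniform and sparse_reward_sketch are hypothetical, the cube is a nested list rather than a jax array, and filling face i with colour index i is an assumed convention.

```python
from typing import List

Face = List[List[int]]  # one face: cube_size rows of cube_size colour indices
Cube = List[Face]       # six faces


def make_solved_cube(cube_size: int = 3) -> Cube:
    """Build a cube whose face i is filled with colour index i (assumed convention)."""
    return [[[face] * cube_size for _ in range(cube_size)] for face in range(6)]


def all_faces_uniform(cube: Cube) -> bool:
    """A cube counts as solved when every face shows a single colour."""
    return all(len({sticker for row in face for sticker in row}) == 1 for face in cube)


def sparse_reward_sketch(cube: Cube) -> float:
    """Return 1.0 for a solved cube and 0.0 otherwise."""
    return 1.0 if all_faces_uniform(cube) else 0.0


solved = make_solved_cube()
unsolved = make_solved_cube()
# Overwrite one sticker for illustration only; this is not a reachable cube state.
unsolved[0][0][0] = 1
print(sparse_reward_sketch(solved), sparse_reward_sketch(unsolved))  # 1.0 0.0
```

Because the reward is sparse, an agent receives no learning signal until it actually solves the cube, which is why the number of scrambles on reset matters for difficulty.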

Instantiate a RubiksCube environment.

Parameters:

  • generator (Optional[Generator], default None): Generator used to generate problem instances on environment reset. The only implemented option is ScramblingGenerator, which is also used when None is passed, with 100 scrambles on reset. The generator has an attribute cube_size, the number of cubies along an edge, which defaults to 3.
  • time_limit (int, default 200): the number of steps allowed before an episode terminates. Must be positive.
  • reward_fn (Optional[RewardFn], default None): RewardFn whose __call__ method computes the reward given the new state. The only implemented option is SparseRewardFn, which is also used when None is passed; it gives a reward of 1.0 if the cube is solved and 0.0 otherwise.
  • viewer (Optional[Viewer[State]], default None): Viewer to support rendering and animation methods. The only implemented option is RubiksCubeViewer, which is also used when None is passed.

Source code in jumanji/environments/logic/rubiks_cube/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    time_limit: int = 200,
    reward_fn: Optional[RewardFn] = None,
    viewer: Optional[Viewer[State]] = None,
):
    """Instantiate a `RubiksCube` environment.

    Args:
        generator: `Generator` used to generate problem instances on environment reset.
            Implemented options are [`ScramblingGenerator`]. Defaults to `ScramblingGenerator`,
            with 100 scrambles on reset.
            The generator will contain an attribute `cube_size`, corresponding to the number of
            cubies to an edge, and defaulting to 3.
        time_limit: the number of steps allowed before an episode terminates. Defaults to 200.
        reward_fn: `RewardFn` whose `__call__` method computes the reward given the new state.
            Implemented options are [`SparseRewardFn`]. Defaults to `SparseRewardFn`, giving a
            reward of 1.0 if the cube is solved or otherwise 0.0.
        viewer: `Viewer` to support rendering and animation methods.
            Implemented options are [`RubiksCubeViewer`]. Defaults to `RubiksCubeViewer`.
    """
    if time_limit <= 0:
        raise ValueError(
            f"The time_limit must be positive, but received time_limit={time_limit}"
        )
    self.time_limit = time_limit
    self.reward_function = reward_fn or SparseRewardFn()
    self.generator = generator or ScramblingGenerator(
        cube_size=3,
        num_scrambles_on_reset=100,
    )
    super().__init__()
    self._viewer = viewer or RubiksCubeViewer(
        sticker_colors=DEFAULT_STICKER_COLORS, cube_size=self.generator.cube_size
    )
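
Since reward_fn accepts any RewardFn whose __call__ computes a reward from the new state, a denser signal than SparseRewardFn is conceivable. The following is a hypothetical pure-Python sketch of one such signal, not something the library provides: dense_reward_sketch is an invented name, it takes a nested-list cube rather than a State, and it assumes an odd cube_size so that each face has a centre sticker. A real RewardFn would need to accept a State and return a JAX array.

```python
from typing import List


def dense_reward_sketch(cube: List[List[List[int]]]) -> float:
    """Fraction of stickers that match the centre sticker of their own face."""
    cube_size = len(cube[0])
    if cube_size % 2 == 0:
        raise ValueError("this sketch needs an odd cube_size so each face has a centre")
    centre = cube_size // 2
    matching = sum(
        sticker == face[centre][centre]
        for face in cube
        for row in face
        for sticker in row
    )
    return matching / (6 * cube_size * cube_size)


# A solved 3x3x3 cube, assuming face i is filled with colour index i.
solved = [[[face] * 3 for _ in range(3)] for face in range(6)]
print(dense_reward_sketch(solved))  # 1.0
```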

action_spec: specs.MultiDiscreteArray cached property

Returns the action spec. An action has 3 elements: the face to turn (6 options), the depth of the turn (cube_size//2 options), and the direction (3 options).

Returns:

  • action_spec (MultiDiscreteArray): MultiDiscreteArray object.
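
To make the size of the action space concrete, here is a minimal pure-Python sketch that computes the three dimensions and maps a (face, depth, direction) triple to a single index and back. The names action_dims, flatten and unflatten are hypothetical, and the row-major ordering is an assumption; the library's own flatten_action may order the components differently.

```python
from typing import Tuple

NUM_FACES = 6
NUM_DIRECTIONS = 3


def action_dims(cube_size: int) -> Tuple[int, int, int]:
    """Number of options per action element: (faces, depths, directions)."""
    return (NUM_FACES, cube_size // 2, NUM_DIRECTIONS)


def flatten(action: Tuple[int, int, int], cube_size: int) -> int:
    """Map (face, depth, direction) to one index, row-major (assumed ordering)."""
    face, depth, direction = action
    num_depths = cube_size // 2
    return (face * num_depths + depth) * NUM_DIRECTIONS + direction


def unflatten(index: int, cube_size: int) -> Tuple[int, int, int]:
    """Inverse of flatten."""
    num_depths = cube_size // 2
    face_depth, direction = divmod(index, NUM_DIRECTIONS)
    face, depth = divmod(face_depth, num_depths)
    return (face, depth, direction)


print(action_dims(3), flatten((5, 0, 2), 3))  # (6, 1, 3) 17
```

For the default cube_size of 3 this gives 6 * 1 * 3 = 18 distinct moves; a cube_size of 4 gives 6 * 2 * 3 = 36.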

observation_spec: specs.Spec[Observation] cached property

Specifications of the observation of the RubiksCube environment.

Returns:

  • Spec[Observation]: Spec containing the specifications for all the Observation fields:

    • cube: BoundedArray (jnp.int8) of shape (num_faces, cube_size, cube_size).
    • step_count: BoundedArray (jnp.int32) of shape ().

animate(states, interval=200, save_path=None)

Creates an animated gif of the cube based on the sequence of states.

Parameters:

  • states (Sequence[State], required): a list of State objects representing the sequence of game states.
  • interval (int, default 200): the delay between frames in milliseconds.
  • save_path (Optional[str], default None): the path where the animation file should be saved. If it is None, the animation will not be saved.

Returns:

  • FuncAnimation: the animation object that was created.

Source code in jumanji/environments/logic/rubiks_cube/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the cube based on the sequence of states.

    Args:
        states: a list of `State` objects representing the sequence of game states.
        interval: the delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states=states, interval=interval, save_path=save_path)

close()

Perform any necessary cleanup.

Environments will automatically call close() on themselves when garbage collected or when the program exits.

Source code in jumanji/environments/logic/rubiks_cube/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state)

Renders the current state of the cube.

Parameters:

  • state (State, required): the current state to be rendered.

Source code in jumanji/environments/logic/rubiks_cube/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Renders the current state of the cube.

    Args:
        state: the current state to be rendered.
    """
    return self._viewer.render(state=state)

reset(key)

Resets the environment.

Parameters:

  • key (PRNGKey, required): random key used to scramble the cube.

Returns:

  • state (State): State corresponding to the new state of the environment.
  • timestep (TimeStep[Observation]): TimeStep corresponding to the first timestep returned by the environment.

Source code in jumanji/environments/logic/rubiks_cube/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: needed for scramble.

    Returns:
        state: `State` corresponding to the new state of the environment.
        timestep: `TimeStep` corresponding to the first timestep returned by the
            environment.
    """
    state = self.generator(key)
    observation = self._state_to_observation(state=state)
    timestep = restart(observation=observation)
    return state, timestep

step(state, action)

Run one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Array, required): Array of shape (3,) indicating the face to move, depth of the move, and the amount to move by.

Returns:

  • next_state (State): State corresponding to the next state of the environment.
  • next_timestep (TimeStep[Observation]): TimeStep corresponding to the timestep returned by the environment.

Source code in jumanji/environments/logic/rubiks_cube/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    Args:
        state: `State` object containing the dynamics of the environment.
        action: `Array` of shape (3,) indicating the face to move, depth of the move, and the
            amount to move by.

    Returns:
        next_state: `State` corresponding to the next state of the environment.
        next_timestep: `TimeStep` corresponding to the timestep returned by the environment.
    """
    flattened_action = flatten_action(
        unflattened_action=action, cube_size=self.generator.cube_size
    )
    cube = rotate_cube(
        cube=state.cube,
        flattened_action=flattened_action,
    )
    step_count = state.step_count + 1
    next_state = State(
        cube=cube,
        step_count=step_count,
        key=state.key,
    )
    reward = self.reward_function(state=next_state)
    solved = is_solved(cube)
    done = (step_count >= self.time_limit) | solved
    next_observation = self._state_to_observation(state=next_state)
    next_timestep = jax.lax.cond(
        done,
        termination,
        transition,
        reward,
        next_observation,
    )
    return next_state, next_timestep
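
The termination rule used in step (the episode ends when the step count reaches time_limit or the cube is solved) can be traced without JAX. The following is a minimal pure-Python sketch, not the library code: classify_step and rollout_kinds are hypothetical names, and the strings stand in for the termination and transition timestep constructors.

```python
from typing import List


def classify_step(step_count: int, time_limit: int, solved: bool) -> str:
    """Name the kind of timestep produced after a move.

    step_count is the count after the move, matching the order used in step.
    """
    done = step_count >= time_limit or solved
    return "termination" if done else "transition"


def rollout_kinds(time_limit: int) -> List[str]:
    """Timestep kinds for an episode in which the cube is never solved."""
    kinds = []
    step_count = 0
    while True:
        step_count += 1  # incremented before the time-limit check
        kind = classify_step(step_count, time_limit, solved=False)
        kinds.append(kind)
        if kind == "termination":
            return kinds


print(rollout_kinds(3))  # ['transition', 'transition', 'termination']
```

Because the counter is incremented before the check, an episode with time_limit=200 ends on the 200th step at the latest.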