
SlidingTilePuzzle

Bases: Environment[State, DiscreteArray, Observation]

Environment for the Sliding Tile Puzzle problem.

The problem is a combinatorial task in which the goal is to slide the empty tile around the board until all tiles are arranged in order. For more details, see https://en.wikipedia.org/wiki/Sliding_puzzle.

  • observation: Observation

    • puzzle: jax array (int32) of shape (N, N), representing the current state of the puzzle.
    • empty_tile_position: Tuple of int32, representing the position of the empty tile.
    • action_mask: jax array (bool) of shape (4,), indicating which actions are valid in the current state of the environment.
    • step_count: jax array (int32) of shape (), number of steps taken in the current episode.
  • action: int32, representing the direction in which to move the empty tile (up, down, left, right).

  • reward: float, a dense reward is provided based on the arrangement of the tiles. It equals the negative sum of the boolean difference between the current state of the puzzle and the goal state (correctly arranged tiles). Each incorrectly placed tile contributes -1 to the reward.

  • episode termination: if the puzzle is solved, or if the episode reaches time_limit (500 steps by default).

  • state: State

    • puzzle: jax array (int32) of shape (N, N), representing the current state of the puzzle.
    • empty_tile_position: Tuple of int32, representing the position of the empty tile.
    • step_count: jax array (int32) of shape (), number of steps taken in the current episode.
    • key: jax array (uint32) of shape (2,), random key used to generate random numbers at each step and for auto-reset.
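
The dense reward described above can be illustrated with a small sketch. The `dense_reward` helper here is hypothetical, for illustration only, and is not the library's `DenseRewardFn`:

```python
# Sketch of the dense reward: each tile that differs from the goal
# arrangement contributes -1, so a solved puzzle receives a reward of 0.

def dense_reward(puzzle, solved_puzzle):
    """Negative count of cells that differ from the goal arrangement."""
    return -sum(
        tile != goal
        for row, goal_row in zip(puzzle, solved_puzzle)
        for tile, goal in zip(row, goal_row)
    )

solved = [[1, 2], [3, 0]]     # 0 marks the empty tile
scrambled = [[1, 2], [0, 3]]  # two cells out of place
print(dense_reward(scrambled, solved))  # -2
print(dense_reward(solved, solved))     # 0
```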

Instantiate a SlidingTilePuzzle environment.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `generator` | `Optional[Generator]` | Callable to instantiate environment instances. Defaults to `RandomWalkGenerator`, which generates shuffled puzzles with a size of 5x5. | `None` |
| `reward_fn` | `Optional[RewardFn]` | `RewardFn` whose `__call__` method computes the reward of an environment transition, based on the current state, the chosen action and the next state. Implemented options are [`DenseRewardFn`, `SparseRewardFn`]. Defaults to `DenseRewardFn`. | `None` |
| `time_limit` | `int` | Maximum number of steps before the episode is terminated. | `500` |
| `viewer` | `Optional[Viewer[State]]` | Environment viewer for rendering. | `None` |
Source code in jumanji/environments/logic/sliding_tile_puzzle/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    reward_fn: Optional[RewardFn] = None,
    time_limit: int = 500,
    viewer: Optional[Viewer[State]] = None,
) -> None:
    """Instantiate a `SlidingTilePuzzle` environment.

    Args:
        generator: callable to instantiate environment instances.
            Defaults to `RandomWalkGenerator` which generates shuffled puzzles with
            a size of 5x5.
        reward_fn: RewardFn whose `__call__` method computes the reward of an environment
            transition. The function must compute the reward based on the current state,
            the chosen action and the next state.
            Implemented options are [`DenseRewardFn`, `SparseRewardFn`].
            Defaults to `DenseRewardFn`.
        time_limit: maximum number of steps before the episode is terminated, default to 500.
        viewer: environment viewer for rendering.
    """
    self.generator = generator or RandomWalkGenerator(grid_size=5, num_random_moves=200)
    self.reward_fn = reward_fn or DenseRewardFn()
    self.time_limit = time_limit
    super().__init__()

    # Create viewer used for rendering
    self._env_viewer = viewer or SlidingTilePuzzleViewer(name="SlidingTilePuzzle")
    self.solved_puzzle = self.generator.make_solved_puzzle()
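
The default `RandomWalkGenerator` shuffles the board by applying random legal moves starting from the solved configuration, which guarantees the resulting puzzle is solvable. A minimal sketch of that idea follows; the function name and the convention of placing the empty tile (0) in the last cell are assumptions for illustration, not the library's API:

```python
import random

# Directions the empty tile can move: up, down, left, right.
MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]

def random_walk_shuffle(n, num_moves, seed=0):
    """Shuffle an n x n board with random legal moves from the solved state."""
    rng = random.Random(seed)
    # Solved board: tiles 1..n*n-1 in order, empty tile (0) in the last cell.
    puzzle = [[(r * n + c + 1) % (n * n) for c in range(n)] for r in range(n)]
    empty = (n - 1, n - 1)
    for _ in range(num_moves):
        dr, dc = rng.choice(MOVES)
        r, c = empty[0] + dr, empty[1] + dc
        if 0 <= r < n and 0 <= c < n:
            # Slide the neighbouring tile into the empty cell.
            puzzle[empty[0]][empty[1]], puzzle[r][c] = puzzle[r][c], 0
            empty = (r, c)
    return puzzle, empty
```

Because only legal moves are applied, every board produced this way can be solved by reversing the walk, unlike a uniformly random permutation of tiles (half of which are unsolvable).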

action_spec: specs.DiscreteArray (cached property)

Returns the action spec.

observation_spec: specs.Spec[Observation] (cached property)

Returns the observation spec.

animate(states, interval=200, save_path=None)

Creates an animated gif of the puzzle board based on the sequence of game states.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `states` | `Sequence[State]` | A list of `State` objects representing the sequence of game states. | required |
| `interval` | `int` | The delay between frames in milliseconds. | `200` |
| `save_path` | `Optional[str]` | The path where the animation file should be saved. If it is `None`, the plot will not be stored. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `FuncAnimation` | The animation object that was created. |

Source code in jumanji/environments/logic/sliding_tile_puzzle/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> animation.FuncAnimation:
    """Creates an animated gif of the puzzle board based on the sequence of game states.

    Args:
        states: is a list of `State` objects representing the sequence of game states.
        interval: the delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None,
            the plot will not be stored.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._env_viewer.animate(states=states, interval=interval, save_path=save_path)

close()

Perform any necessary cleanup.

Environments will automatically `close()` themselves when garbage collected or when the program exits.

Source code in jumanji/environments/logic/sliding_tile_puzzle/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._env_viewer.close()

render(state)

Renders the current state of the puzzle board.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state` | `State` | The current game state to be rendered. | required |
Source code in jumanji/environments/logic/sliding_tile_puzzle/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Renders the current state of the puzzle board.

    Args:
        state: is the current game state to be rendered.
    """
    return self._env_viewer.render(state=state)

reset(key)

Resets the environment to an initial state.

Source code in jumanji/environments/logic/sliding_tile_puzzle/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment to an initial state."""
    key, subkey = jax.random.split(key)
    state = self.generator(subkey)
    action_mask = self._get_valid_actions(state.empty_tile_position)
    obs = Observation(
        puzzle=state.puzzle,
        empty_tile_position=state.empty_tile_position,
        action_mask=action_mask,
        step_count=state.step_count,
    )
    timestep = restart(observation=obs, extras=self._get_extras(state))
    return state, timestep
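
The action mask returned in the observation can be derived purely from the empty tile's position. A sketch, under the assumption that actions are ordered (up, down, left, right); the helper name is illustrative, not the library's internal `_get_valid_actions`:

```python
# Derive the (4,)-shaped boolean action mask from the empty tile's
# position on an n x n board: a move is valid iff it stays on the board.

def get_valid_actions(empty_tile_position, n):
    row, col = empty_tile_position
    return [
        row > 0,      # up: the empty tile can move one row up
        row < n - 1,  # down
        col > 0,      # left
        col < n - 1,  # right
    ]

print(get_valid_actions((0, 0), 5))  # [False, True, False, True]
```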

step(state, action)

Updates the environment state after the agent takes an action.

Source code in jumanji/environments/logic/sliding_tile_puzzle/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Updates the environment state after the agent takes an action."""
    (updated_puzzle, updated_empty_tile_position) = self._move_empty_tile(
        state.puzzle, state.empty_tile_position, action
    )
    # Check if the puzzle is solved
    done = jnp.array_equal(updated_puzzle, self.solved_puzzle)

    # Update the action mask
    action_mask = self._get_valid_actions(updated_empty_tile_position)

    next_state = State(
        puzzle=updated_puzzle,
        empty_tile_position=updated_empty_tile_position,
        key=state.key,
        step_count=state.step_count + 1,
    )
    obs = Observation(
        puzzle=updated_puzzle,
        empty_tile_position=updated_empty_tile_position,
        action_mask=action_mask,
        step_count=next_state.step_count,
    )

    reward = self.reward_fn(state, action, next_state, self.solved_puzzle)
    extras = self._get_extras(next_state)

    timestep = jax.lax.cond(
        done | (next_state.step_count >= self.time_limit),
        lambda: termination(
            reward=reward,
            observation=obs,
            extras=extras,
        ),
        lambda: transition(
            reward=reward,
            observation=obs,
            extras=extras,
        ),
    )

    return next_state, timestep
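
The tile swap performed inside `step` can be sketched as follows. This is a hypothetical stand-in for the library's `_move_empty_tile`, with the (up, down, left, right) action ordering and the leave-the-board-unchanged behaviour on an off-board move both assumed for illustration:

```python
# Move the empty tile one step in the chosen direction by swapping it
# with the neighbouring tile; an off-board move leaves the state as-is
# (in the environment, such moves are normally excluded by the action mask).

DELTAS = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right (assumed order)

def move_empty_tile(puzzle, empty, action):
    n = len(puzzle)
    dr, dc = DELTAS[action]
    r, c = empty[0] + dr, empty[1] + dc
    if not (0 <= r < n and 0 <= c < n):
        return puzzle, empty  # invalid move: state unchanged
    new = [row[:] for row in puzzle]
    new[empty[0]][empty[1]], new[r][c] = new[r][c], new[empty[0]][empty[1]]
    return new, (r, c)

board, empty = move_empty_tile([[1, 2], [3, 0]], (1, 1), 0)  # move empty tile up
print(board)  # [[1, 0], [3, 2]]
```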