Sokoban

Bases: Environment[State, DiscreteArray, Observation]

A JAX implementation of the 'Sokoban' game from DeepMind.

  • observation: Observation

    • grid: jax array (uint8) of shape (num_rows, num_cols, 2) Array that includes information about the agent, boxes, and targets in the game.
    • step_count: jax array (int32) of shape () current number of steps in the episode.
  • action: jax array (int32) of shape () [0,1,2,3] -> [Up, Right, Down, Left].

  • reward: jax array (float) of shape () A reward of 1.0 is given for each box placed on a target, -1.0 for each box removed from a target, and -0.1 for each timestep. An additional 10.0 is awarded when all boxes are on targets.

  • episode termination:

    • if the time limit is reached.
    • if all boxes are on targets.
  • state: State

    • key: jax array (uint32) of shape (2,) used for auto-reset
    • fixed_grid: jax array (uint8) of shape (num_rows, num_cols) array indicating the walls and targets in the level.
    • variable_grid: jax array (uint8) of shape (num_rows, num_cols) array indicating the current location of the agent and boxes.
    • agent_location: jax array (int32) of shape (2,) the agent's current location.
    • step_count: jax array (int32) of shape () current number of steps in the episode.
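
The reward rule described above can be sketched in plain Python. This is a minimal illustration of the description, not the library's reward implementation; the name `dense_reward` and its arguments are hypothetical, and the number of boxes is passed in explicitly rather than read from the environment.

```python
def dense_reward(on_target_before: int, on_target_after: int, num_boxes: int) -> float:
    """Reward for a single step (hypothetical helper following the text above)."""
    step_penalty = -0.1                               # charged on every timestep
    box_delta = on_target_after - on_target_before    # +1 per box placed, -1 per box removed
    bonus = 10.0 if on_target_after == num_boxes else 0.0
    return step_penalty + 1.0 * box_delta + bonus


# A step that moves no box on or off a target only costs the time penalty.
idle = dense_reward(1, 1, num_boxes=4)
# Placing the last of four boxes earns the box reward plus the completion bonus.
final = dense_reward(3, 4, num_boxes=4)
```

Under this reading, pushing a box off a target costs 1.1 in total for that step (the -1.0 box penalty plus the -0.1 time penalty).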

import jax

from jumanji.environments import Sokoban
from jumanji.environments.routing.sokoban.generator import (
    HuggingFaceDeepMindGenerator,
)

env_train = Sokoban(
    generator=HuggingFaceDeepMindGenerator(
        dataset_name="unfiltered-train",
        proportion_of_files=1,
    )
)

env_test = Sokoban(
    generator=HuggingFaceDeepMindGenerator(
        dataset_name="unfiltered-test",
        proportion_of_files=1,
    )
)

# Train...
key_train = jax.random.PRNGKey(0)
# Reset to get an initial state, then render it.
state, timestep = jax.jit(env_train.reset)(key_train)
env_train.render(state)
# Take one step with a valid action and render the result.
action = env_train.action_spec.generate_value()
state, timestep = jax.jit(env_train.step)(state, action)
env_train.render(state)

Instantiates a Sokoban environment with a specific generator, reward function, viewer, and time limit.

Parameters:

Name Type Description Default
generator Optional[Generator]

Generator whose __call__ instantiates an environment instance (an initial state). Implemented options are [ToyGenerator, DeepMindGenerator, and HuggingFaceDeepMindGenerator]. Defaults to HuggingFaceDeepMindGenerator with dataset_name="unfiltered-train", proportion_of_files=1.

None
reward_fn Optional[RewardFn]

RewardFn used to compute the reward at each step. If not provided, defaults to DenseReward.

None
viewer Optional[Viewer]

Viewer object used to render the environment. If not provided, defaults to BoxViewer.

None
time_limit int

Maximum number of steps per episode.

120
Source code in jumanji/environments/routing/sokoban/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    reward_fn: Optional[RewardFn] = None,
    viewer: Optional[Viewer] = None,
    time_limit: int = 120,
) -> None:
    """
    Instantiates a `Sokoban` environment with a specific generator,
    reward function, viewer, and time limit.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment
         instance (an initial state). Implemented options are [`ToyGenerator`,
         `DeepMindGenerator`, and `HuggingFaceDeepMindGenerator`].
         Defaults to `HuggingFaceDeepMindGenerator` with
         `dataset_name="unfiltered-train", proportion_of_files=1`.
        reward_fn: `RewardFn` used to compute the reward at each step.
         If not provided, defaults to `DenseReward`.
        viewer: `Viewer` object used to render the environment.
         If not provided, defaults to `BoxViewer`.
        time_limit: int, max steps for the environment, defaults to 120.
    """

    self.num_rows = GRID_SIZE
    self.num_cols = GRID_SIZE
    self.shape = (self.num_rows, self.num_cols)
    self.time_limit = time_limit

    super().__init__()

    self.generator = generator or HuggingFaceDeepMindGenerator(
        "unfiltered-train",
        proportion_of_files=1,
    )

    self._viewer = viewer or BoxViewer(
        name="Sokoban",
        grid_combine=self.grid_combine,
    )
    self.reward_fn = reward_fn or DenseReward()

action_spec: specs.DiscreteArray cached property

Returns the action specification for the Sokoban environment. There are 4 actions: [0,1,2,3] -> [Up, Right, Down, Left].

Returns:

Type Description
DiscreteArray

specs.DiscreteArray: Discrete action specifications.
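
To illustrate the action encoding, the sketch below maps each action index to a (row, column) offset. The offsets and the names `MOVE_OFFSETS` and `apply_action` are assumptions made for this example (row 0 is taken to be the top of the grid); they are not values quoted from the library.

```python
ACTION_NAMES = ["Up", "Right", "Down", "Left"]  # order stated in the spec above

# Assumed (row, col) offsets, with row indices increasing downward.
MOVE_OFFSETS = [(-1, 0), (0, 1), (1, 0), (0, -1)]


def apply_action(location, action):
    """Return the square reached from `location` by taking `action`."""
    d_row, d_col = MOVE_OFFSETS[action]
    return (location[0] + d_row, location[1] + d_col)


start = (5, 5)
neighbours = {ACTION_NAMES[a]: apply_action(start, a) for a in range(4)}
```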

observation_spec: specs.Spec[Observation] cached property

Returns the specifications of the observation of the Sokoban environment.

Returns:

Type Description
Spec[Observation]

specs.Spec[Observation]: The specifications of the observations.

__repr__()

Returns a printable representation of the Sokoban environment.

Returns:

Name Type Description
str str

A string representation of the Sokoban environment.

Source code in jumanji/environments/routing/sokoban/env.py
def __repr__(self) -> str:
    """
    Returns a printable representation of the Sokoban environment.

    Returns:
        str: A string representation of the Sokoban environment.
    """
    return "\n".join(
        [
            "Sokoban environment:",
            f" - num_rows: {self.num_rows}",
            f" - num_cols: {self.num_cols}",
            f" - time_limit: {self.time_limit}",
            f" - generator: {self.generator}",
        ]
    )

animate(states, interval=200, save_path=None)

Creates an animated gif of the Sokoban environment based on the sequence of states.

Parameters:

Name Type Description Default
states Sequence[State]

Sequence of State objects.

required
interval int

The interval between frames in the animation.

200
save_path Optional[str]

The path where to save the animation. If not provided, the animation is not saved.

None

Returns:

Name Type Description
animation FuncAnimation

'matplotlib.animation.FuncAnimation'.

Source code in jumanji/environments/routing/sokoban/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """
    Creates an animated gif of the Sokoban environment based on the
    sequence of states.

    Args:
        states: Sequence of `State` objects.
        interval: int, the interval between frames in the animation.
         Defaults to 200.
        save_path: str, the path where to save the animation. If not
         provided, the animation is not saved.

    Returns:
        animation: 'matplotlib.animation.FuncAnimation'.
    """
    return self._viewer.animate(states, interval, save_path)

check_space(grid, location, value)

Checks if a specific location in the grid contains a given value.

Parameters:

Name Type Description Default
grid Array

Array (uint8) shape (num_rows, num_cols) The grid to check.

required
location Array

Tuple size 2 of Array (int32) shape () containing the x and y coordinates of the location to check in the grid.

required
value int

The value to look for.

required

Returns:

Name Type Description
present Array

Array (bool) shape () indicating whether the location in the grid contains the given value or not.

Source code in jumanji/environments/routing/sokoban/env.py
def check_space(
    self,
    grid: chex.Array,
    location: chex.Array,
    value: int,
) -> chex.Array:
    """
    Checks if a specific location in the grid contains a given value.

    Args:
        grid: Array (uint8) shape (num_rows, num_cols) The grid to check.
        location: Tuple size 2 of Array (int32) shape () containing the x
        and y coordinates of the location to check in the grid.
        value: int The value to look for.

    Returns:
        present: Array (bool) shape () indicating whether the location
        in the grid contains the given value or not.
    """

    return grid[tuple(location)] == value

detect_noop_action(variable_grid, fixed_grid, action, agent_location)

Masks actions that have no effect on the variable grid to -1 (no-op). An action is masked if the destination square is a wall or lies outside the grid, or if the destination holds a box that cannot be pushed into the square beyond it.

Parameters:

Name Type Description Default
variable_grid Array

Array (uint8) shape (num_rows, num_cols).

required
fixed_grid Array

Array (uint8) shape (num_rows, num_cols).

required
action Array

Array (int32) shape () The action to check.

required
agent_location Array

Array (int32) shape (2,) The agent's current location.

required

Returns:

Name Type Description
updated_action Array

Array (int32) shape () The updated action after detecting noop action.

Source code in jumanji/environments/routing/sokoban/env.py
def detect_noop_action(
    self,
    variable_grid: chex.Array,
    fixed_grid: chex.Array,
    action: chex.Array,
    agent_location: chex.Array,
) -> chex.Array:
    """
    Masks actions that have no effect on the variable grid to -1 (no-op).
    An action is masked if the destination square is a wall or lies
    outside the grid, or if the destination holds a box that cannot be
    pushed into the square beyond it.

    Args:
        variable_grid: Array (uint8) shape (num_rows, num_cols).
        fixed_grid: Array (uint8) shape (num_rows, num_cols).
        action: Array (int32) shape () The action to check.
        agent_location: Array (int32) shape (2,) The agent's current location.

    Returns:
        updated_action: Array (int32) shape () The updated action after
        detecting noop action.
    """

    new_location = agent_location + MOVES[action].squeeze()

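    # Despite its name, this flag is True when the destination is a wall
    # or lies outside the grid, i.e. when the move must become a no-op.
    valid_destination = self.check_space(fixed_grid, new_location, WALL) | ~self.in_grid(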
        new_location
    )

    updated_action = jax.lax.select(
        valid_destination,
        jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32),
        jax.lax.select(
            self.check_space(variable_grid, new_location, BOX),
            self.update_box_push_action(
                fixed_grid,
                variable_grid,
                new_location,
                action,
            ),
            action,
        ),
    )

    return updated_action
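
The masking logic can be followed more easily in a plain-Python sketch that uses nested lists and ordinary `if` statements instead of `jax.lax.select`. The cell encodings and move offsets below are assumptions chosen for the example, and `detect_noop` is a hypothetical stand-in that folds the box-push check into the same function.

```python
NOOP = -1                                    # masked action value, as described above
EMPTY, WALL, AGENT, BOX = 0, 1, 3, 4         # assumed cell encodings
MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1)]   # assumed Up, Right, Down, Left offsets


def in_grid(loc, size):
    return 0 <= loc[0] < size and 0 <= loc[1] < size


def detect_noop(variable_grid, fixed_grid, action, agent_location):
    size = len(fixed_grid)
    d_row, d_col = MOVES[action]
    dest = (agent_location[0] + d_row, agent_location[1] + d_col)
    # Walking off the grid or into a wall changes nothing.
    if not in_grid(dest, size) or fixed_grid[dest[0]][dest[1]] == WALL:
        return NOOP
    if variable_grid[dest[0]][dest[1]] == BOX:
        # A pushed box moves one square further in the same direction.
        beyond = (dest[0] + d_row, dest[1] + d_col)
        blocked = (
            not in_grid(beyond, size)
            or variable_grid[beyond[0]][beyond[1]] == BOX
            or fixed_grid[beyond[0]][beyond[1]] == WALL
        )
        if blocked:
            return NOOP
    return action


# 5x5 level: a wall border around a 3x3 room, agent at (2, 1), box at (2, 2).
fixed = [[WALL] * 5] + [[WALL, EMPTY, EMPTY, EMPTY, WALL] for _ in range(3)] + [[WALL] * 5]
variable = [[EMPTY] * 5 for _ in range(5)]
variable[2][1] = AGENT
variable[2][2] = BOX

push_right = detect_noop(variable, fixed, 1, (2, 1))  # box can slide to (2, 3)
walk_left = detect_noop(variable, fixed, 3, (2, 1))   # (2, 0) is a wall
walk_up = detect_noop(variable, fixed, 0, (2, 1))     # (1, 1) is free
```

Unlike this sketch, the JAX version computes both branches of each `jax.lax.select` and then picks one, which is why it is written as nested selects rather than early returns.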

grid_combine(variable_grid, fixed_grid)

Combines the variable grid and fixed grid into one single grid representation of the current Sokoban state required for visual representation of the Sokoban state. Takes care of two possible overlaps of fixed and variable entries (an agent on a target or a box on a target), introducing two additional encodings.

Parameters:

Name Type Description Default
variable_grid Array

Array (uint8) of shape (num_rows, num_cols).

required
fixed_grid Array

Array (uint8) of shape (num_rows, num_cols).

required

Returns:

Name Type Description
full_grid Array

Array (uint8) of shape (num_rows, num_cols).

Source code in jumanji/environments/routing/sokoban/env.py
def grid_combine(self, variable_grid: chex.Array, fixed_grid: chex.Array) -> chex.Array:
    """
    Combines the variable grid and fixed grid into one single grid
    representation of the current Sokoban state required for visual
    representation of the Sokoban state. Takes care of two possible
    overlaps of fixed and variable entries (an agent on a target or a box
    on a target), introducing two additional encodings.

    Args:
        variable_grid: Array (uint8) of shape (num_rows, num_cols).
        fixed_grid: Array (uint8) of shape (num_rows, num_cols).

    Returns:
        full_grid: Array (uint8) of shape (num_rows, num_cols).
    """

    mask_target_agent = jnp.logical_and(
        fixed_grid == TARGET,
        variable_grid == AGENT,
    )

    mask_target_box = jnp.logical_and(
        fixed_grid == TARGET,
        variable_grid == BOX,
    )

    single_grid = jnp.where(
        mask_target_agent,
        TARGET_AGENT,
        jnp.where(
            mask_target_box,
            TARGET_BOX,
            jnp.maximum(variable_grid, fixed_grid),
        ),
    ).astype(jnp.uint8)

    return single_grid
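
As a rough illustration of the overlay rule, the sketch below applies the same three cases cell by cell on nested lists. The numeric encodings are assumptions for this example only; the approach relies on the empty code being the smallest value so that `max` keeps whichever layer is occupied.

```python
EMPTY, WALL, TARGET, AGENT, BOX = 0, 1, 2, 3, 4   # assumed encodings
TARGET_AGENT, TARGET_BOX = 5, 6                    # assumed overlap encodings


def combine(variable_grid, fixed_grid):
    combined = []
    for var_row, fix_row in zip(variable_grid, fixed_grid):
        row = []
        for var, fix in zip(var_row, fix_row):
            if fix == TARGET and var == AGENT:
                row.append(TARGET_AGENT)
            elif fix == TARGET and var == BOX:
                row.append(TARGET_BOX)
            else:
                # In a valid state at most one layer is occupied here,
                # so max() keeps that entry (EMPTY is the smallest code).
                row.append(max(var, fix))
        combined.append(row)
    return combined


fixed_row = [[WALL, TARGET, TARGET, EMPTY]]
variable_row = [[EMPTY, AGENT, BOX, BOX]]
merged = combine(variable_row, fixed_row)
```

Here `merged` holds one code per square: the wall, the agent standing on a target, a box on a target, and a box on plain floor.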

in_grid(coordinates)

Checks if given coordinates are within the grid size.

Parameters:

Name Type Description Default
coordinates Array

Array (int32) shape (2,) The coordinates to check.

required

Returns:

Name Type Description
in_grid Array

Array (bool) shape () Boolean indicating whether the coordinates are within the grid.

Source code in jumanji/environments/routing/sokoban/env.py
def in_grid(self, coordinates: chex.Array) -> chex.Array:
    """
    Checks if given coordinates are within the grid size.

    Args:
        coordinates: Array (int32) shape (2,) The coordinates to check.

    Returns:
        in_grid: Array (bool) shape () Boolean indicating whether the
        coordinates are within the grid.
    """
    return jnp.all((0 <= coordinates) & (coordinates < GRID_SIZE))

level_complete(state)

Checks if the Sokoban level is complete.

Parameters:

Name Type Description Default
state State

State object representing the current state of the environment.

required

Returns:

Name Type Description
complete Array

Boolean indicating whether the level is complete or not.

Source code in jumanji/environments/routing/sokoban/env.py
def level_complete(self, state: State) -> chex.Array:
    """
    Checks if the Sokoban level is complete.

    Args:
        state: `State` object representing the current state of the environment.

    Returns:
        complete: Boolean indicating whether the level is complete
        or not.
    """
    return self.reward_fn.count_targets(state) == N_BOXES
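
The completion check compares a count against the number of boxes. The source of `count_targets` is not shown on this page, so the sketch below assumes it counts boxes that sit on target squares; the encodings and the helper names are hypothetical.

```python
TARGET, AGENT, BOX = 2, 3, 4  # assumed cell encodings


def count_boxes_on_targets(variable_grid, fixed_grid):
    """Count squares where a box (variable layer) overlaps a target (fixed layer)."""
    return sum(
        1
        for var_row, fix_row in zip(variable_grid, fixed_grid)
        for var, fix in zip(var_row, fix_row)
        if var == BOX and fix == TARGET
    )


def level_complete(variable_grid, fixed_grid, num_boxes):
    return count_boxes_on_targets(variable_grid, fixed_grid) == num_boxes


fixed = [[TARGET, TARGET], [0, 0]]
unsolved = [[BOX, 0], [BOX, AGENT]]   # one of two boxes is on a target
solved = [[BOX, BOX], [0, AGENT]]     # both boxes are on targets
```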

move_agent(variable_grid, action, current_location)

Executes the movement of the agent specified by the action and executes the movement of a box if present at the destination.

Parameters:

Name Type Description Default
variable_grid Array

Array (uint8) shape (num_rows, num_cols)

required
action Array

Array (int32) shape () The action to take.

required
current_location Array

Array (int32) shape (2,)

required

Returns:

Name Type Description
next_variable_grid Array

Array (uint8) shape (num_rows, num_cols)

next_location Array

Array (int32) shape (2,)

Source code in jumanji/environments/routing/sokoban/env.py
def move_agent(
    self,
    variable_grid: chex.Array,
    action: chex.Array,
    current_location: chex.Array,
) -> Tuple[chex.Array, chex.Array]:
    """
    Executes the movement of the agent specified by the action and
    executes the movement of a box if present at the destination.

    Args:
        variable_grid: Array (uint8) shape (num_rows, num_cols)
        action: Array (int32) shape () The action to take.
        current_location: Array (int32) shape (2,)

    Returns:
        next_variable_grid: Array (uint8) shape (num_rows, num_cols)
        next_location: Array (int32) shape (2,)
    """

    next_location = current_location + MOVES[action]
    box_location = next_location + MOVES[action]

    # remove agent from current location
    next_variable_grid = variable_grid.at[tuple(current_location)].set(EMPTY)

    # either move agent or move agent and box

    next_variable_grid = jax.lax.select(
        self.check_space(variable_grid, next_location, BOX),
        next_variable_grid.at[tuple(next_location)].set(AGENT).at[tuple(box_location)].set(BOX),
        next_variable_grid.at[tuple(next_location)].set(AGENT),
    )

    return next_variable_grid, next_location
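
A plain-Python sketch of the same update on nested lists may help when reading the `.at[...].set(...)` chain. It assumes the action has already been validated (it is not a no-op), and the encodings and offsets are assumptions for this example.

```python
EMPTY, AGENT, BOX = 0, 3, 4                  # assumed cell encodings
MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1)]   # assumed Up, Right, Down, Left offsets


def move_agent(variable_grid, action, current_location):
    d_row, d_col = MOVES[action]
    row, col = current_location
    next_location = (row + d_row, col + d_col)
    box_location = (row + 2 * d_row, col + 2 * d_col)
    # Decide on the push using the grid as it was before the move.
    pushing = variable_grid[next_location[0]][next_location[1]] == BOX

    next_grid = [list(r) for r in variable_grid]   # copy, leave the input untouched
    next_grid[row][col] = EMPTY                    # remove agent from current square
    next_grid[next_location[0]][next_location[1]] = AGENT
    if pushing:
        next_grid[box_location[0]][box_location[1]] = BOX
    return next_grid, next_location


before = [[AGENT, BOX, EMPTY]]
after, location = move_agent(before, 1, (0, 0))  # push the box one square right
```

Copying the grid before writing mirrors the functional style of the JAX code, where `.at[...].set(...)` returns a new array instead of mutating the old one.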

render(state)

Renders the current state of Sokoban.

Parameters:

Name Type Description Default
state State

State object, the current state to be rendered.

required
Source code in jumanji/environments/routing/sokoban/env.py
def render(self, state: State) -> None:
    """
    Renders the current state of Sokoban.

    Args:
        state: `State` object, the current state to be rendered.
    """

    self._viewer.render(state=state)

reset(key)

Resets the environment by calling the instance generator for a new instance.

Parameters:

Name Type Description Default
key PRNGKey

random key used to sample new Sokoban problem.

required

Returns:

Name Type Description
state State

State object corresponding to the new state of the environment after a reset.

timestep TimeStep[Observation]

TimeStep object corresponding to the first timestep returned by the environment after a reset.

Source code in jumanji/environments/routing/sokoban/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """
    Resets the environment by calling the instance generator for a
    new instance.

    Args:
        key: random key used to sample new Sokoban problem.

    Returns:
        state: `State` object corresponding to the new state of the
        environment after a reset.
        timestep: `TimeStep` object corresponding to the first timestep
        returned by the environment after a reset.
    """

    generator_key, key = jax.random.split(key)

    state = self.generator(generator_key)

    timestep = restart(
        self._state_to_observation(state),
        extras=self._get_extras(state),
    )

    return state, timestep

step(state, action)

Executes one timestep of the environment's dynamics.

Parameters:

Name Type Description Default
state State

State object representing the current state of the environment.

required
action Array

Array (int32) of shape (). - 0: move up. - 1: move right. - 2: move down. - 3: move left.

required

Returns:

Type Description
State

Next state of the environment.

TimeStep[Observation]

Timestep to be observed.

Source code in jumanji/environments/routing/sokoban/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """
    Executes one timestep of the environment's dynamics.

    Args:
        state: 'State' object representing the current state of the
        environment.
        action: Array (int32) of shape ().
            - 0: move up.
            - 1: move right.
            - 2: move down.
            - 3: move left.

    Returns:
        state, timestep: next state of the environment and timestep to be
        observed.
    """

    # switch to noop if action will have no impact on variable grid
    action = self.detect_noop_action(
        state.variable_grid, state.fixed_grid, action, state.agent_location
    )

    next_variable_grid, next_agent_location = jax.lax.cond(
        jnp.all(action == NOOP),
        lambda: (state.variable_grid, state.agent_location),
        lambda: self.move_agent(state.variable_grid, action, state.agent_location),
    )

    next_state = State(
        key=state.key,
        fixed_grid=state.fixed_grid,
        variable_grid=next_variable_grid,
        agent_location=next_agent_location,
        step_count=state.step_count + 1,
    )

    target_reached = self.level_complete(next_state)
    time_limit_exceeded = next_state.step_count >= self.time_limit

    done = jnp.logical_or(target_reached, time_limit_exceeded)

    reward = jnp.asarray(self.reward_fn(state, action, next_state), float)

    observation = self._state_to_observation(next_state)

    extras = self._get_extras(next_state)

    timestep = jax.lax.cond(
        done,
        lambda: termination(
            reward=reward,
            observation=observation,
            extras=extras,
        ),
        lambda: transition(
            reward=reward,
            observation=observation,
            extras=extras,
        ),
    )

    return next_state, timestep
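
The termination rule in `step` increments the step counter first and then checks both conditions on the new state. The small sketch below mimics only that ordering; `episode_length` and `solved_at` are hypothetical names used for illustration.

```python
def episode_length(time_limit, solved_at=None):
    """Number of steps taken before `done` becomes True.

    `solved_at` is the (hypothetical) step on which the level gets solved,
    or None if it is never solved.
    """
    step_count = 0
    while True:
        step_count += 1  # the counter is incremented before the checks
        target_reached = solved_at is not None and step_count == solved_at
        time_limit_exceeded = step_count >= time_limit
        if target_reached or time_limit_exceeded:
            return step_count


unsolved_length = episode_length(time_limit=120)
solved_length = episode_length(time_limit=120, solved_at=30)
```

With the default `time_limit` of 120, an unsolved episode therefore ends on its 120th step.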

update_box_push_action(fixed_grid, variable_grid, new_location, action)

Masks actions to -1 if pushing the box is not a valid move, that is, if the box would be pushed out of the grid or into a square that holds a wall or another box.

Parameters:

Name Type Description Default
fixed_grid Array

Array (uint8) shape (num_rows, num_cols) The fixed grid.

required
variable_grid Array

Array (uint8) shape (num_rows, num_cols) The variable grid.

required
new_location Array

Array (int32) shape (2,) The new location of the agent.

required
action Array

Array (int32) shape () The action to be executed.

required

Returns:

Name Type Description
updated_action Array

Array (int32) shape () The updated action after checking if pushing the box is a valid move.

Source code in jumanji/environments/routing/sokoban/env.py
def update_box_push_action(
    self,
    fixed_grid: chex.Array,
    variable_grid: chex.Array,
    new_location: chex.Array,
    action: chex.Array,
) -> chex.Array:
    """
    Masks actions to -1 if pushing the box is not a valid move, that is,
    if the box would be pushed out of the grid or into a square that
    holds a wall or another box.

    Args:
        fixed_grid: Array (uint8) shape (num_rows, num_cols) The fixed grid.
        variable_grid: Array (uint8) shape (num_rows, num_cols) The
        variable grid.
        new_location: Array (int32) shape (2,) The new location of the agent.
        action: Array (int32) shape () The action to be executed.

    Returns:
        updated_action: Array (int32) shape () The updated action after
        checking if pushing the box is a valid move.
    """

    return jax.lax.select(
        self.check_space(
            variable_grid,
            new_location + MOVES[action].squeeze(),
            BOX,
        )
        | ~self.in_grid(new_location + MOVES[action].squeeze()),
        jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32),
        jax.lax.select(
            self.check_space(
                fixed_grid,
                new_location + MOVES[action].squeeze(),
                WALL,
            ),
            jnp.full(shape=(), fill_value=NOOP, dtype=jnp.int32),
            action,
        ),
    )