
Sokoban

Sokoban (Environment)

A JAX implementation of DeepMind's version of the 'Sokoban' game.

  • observation: Observation

    • grid: jax array (uint8) of shape (num_rows, num_cols, 2) Array that includes information about the agent, boxes, and targets in the game.
    • step_count: jax array (int32) of shape () current number of steps in the episode.
  • action: jax array (int32) of shape () [0,1,2,3] -> [Up, Right, Down, Left].

  • reward: jax array (float) of shape (). A reward of +1.0 is given for each box pushed onto a target, -1.0 for each box pushed off a target, and -0.1 for every timestep. A bonus of +10 is awarded when all boxes are on targets.

  • episode termination:

    • if the time limit is reached.
    • if all boxes are on targets.
  • state: State

    • key: jax array (uint32) of shape (2,) used for auto-reset
    • fixed_grid: jax array (uint8) of shape (num_rows, num_cols) array indicating the walls and targets in the level.
    • variable_grid: jax array (uint8) of shape (num_rows, num_cols) array indicating the current location of the agent and boxes.
    • agent_location: jax array (int32) of shape (2,) the agent's current location.
    • step_count: jax array (int32) of shape () current number of steps in the episode.
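The reward and termination rules listed above can be written as two small pure JAX functions. This is a minimal sketch, not Jumanji's own RewardFn: the names `dense_reward` and `is_done` are hypothetical, the sketch assumes the number of boxes on targets before and after the step is already known, and it assumes the -0.1 step penalty also applies on the step that solves the level.

```python
import jax.numpy as jnp

# Values taken from the reward description above.
STEP_PENALTY = -0.1
BOX_REWARD = 1.0
LEVEL_COMPLETE_BONUS = 10.0


def dense_reward(prev_boxes_on_target, next_boxes_on_target, num_boxes):
    """Hypothetical per-step reward: -0.1 per step, +1 / -1 for a box pushed
    onto / off a target, and +10 once every box is on a target."""
    # A single push moves at most one box, so delta is -1, 0 or +1.
    delta = next_boxes_on_target - prev_boxes_on_target
    solved = next_boxes_on_target == num_boxes
    return (
        STEP_PENALTY
        + BOX_REWARD * delta
        + jnp.where(solved, LEVEL_COMPLETE_BONUS, 0.0)
    )


def is_done(step_count, next_boxes_on_target, num_boxes, time_limit=120):
    """Hypothetical termination check: time limit reached or level solved."""
    return jnp.logical_or(step_count >= time_limit, next_boxes_on_target == num_boxes)
```

Both functions use only `jnp` operations, so they can be called on traced values inside `jax.jit` or `jax.vmap`.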

```python
import jax

from jumanji.environments import Sokoban
from jumanji.environments.routing.sokoban.generator import (
    HuggingFaceDeepMindGenerator,
)

# Training environment, sampling levels from the "unfiltered-train" split.
env_train = Sokoban(
    generator=HuggingFaceDeepMindGenerator(
        dataset_name="unfiltered-train",
        proportion_of_files=1,
    )
)

# Evaluation environment, sampling levels from the "unfiltered-test" split.
env_test = Sokoban(
    generator=HuggingFaceDeepMindGenerator(
        dataset_name="unfiltered-test",
        proportion_of_files=1,
    )
)

# Train...
key_train = jax.random.PRNGKey(0)
state, timestep = jax.jit(env_train.reset)(key_train)
env_train.render(state)
action = env_train.action_spec.generate_value()
state, timestep = jax.jit(env_train.step)(state, action)
env_train.render(state)
```
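Beyond the single reset/step shown above, episodes are often rolled out inside `jax.lax.scan` so the whole loop can be compiled. The sketch below assumes only that the step function has the `(state, action) -> (state, timestep)` shape documented for `step`; `random_rollout` and the toy step function are hypothetical. It does not stop at episode termination: steps after the episode ends are still applied.

```python
import jax
import jax.numpy as jnp


def random_rollout(step_fn, state, key, num_steps: int, num_actions: int = 4):
    """Apply `num_steps` uniformly random actions with `step_fn` via lax.scan.

    `step_fn(state, action) -> (state, timestep)`; `state` can be any pytree.
    Returns the final state, the sampled actions and the stacked timesteps.
    """
    keys = jax.random.split(key, num_steps)

    def body(carry_state, step_key):
        # Sample one action id in [0, num_actions) for this step.
        action = jax.random.randint(step_key, (), 0, num_actions, dtype=jnp.int32)
        next_state, timestep = step_fn(carry_state, action)
        return next_state, (action, timestep)

    final_state, (actions, timesteps) = jax.lax.scan(body, state, keys)
    return final_state, actions, timesteps


# Toy usage: the "state" is a step counter and the "timestep" echoes the action.
def toy_step(count, action):
    return count + 1, action


final_count, actions, timesteps = random_rollout(
    toy_step, jnp.int32(0), jax.random.PRNGKey(0), num_steps=5
)
```

With the environment from the example, passing `env_train.step` as `step_fn` should fit this shape; to compile the rollout, fix `step_fn` and `num_steps` (for instance with `functools.partial`) before wrapping it in `jax.jit`, since `num_steps` must be a static Python int.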

observation_spec: jumanji.specs.Spec[jumanji.environments.routing.sokoban.types.Observation] cached property writable

Returns the specifications of the observation of the Sokoban environment.

Returns:

  • specs.Spec[Observation]: the specifications of the observations.
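The state stores the level as two (num_rows, num_cols) arrays, fixed_grid and variable_grid, while the observation's grid has a trailing axis of size 2. A plausible way to relate them is to stack the two layers along a new last axis. This is a sketch under that assumption: the channel order (variable layer first, fixed layer second) and the helper name `make_observation_grid` are hypothetical.

```python
import jax.numpy as jnp


def make_observation_grid(variable_grid, fixed_grid):
    """Stack two (num_rows, num_cols) layers into one (num_rows, num_cols, 2) grid.

    Assumed channel order: [..., 0] = agent/boxes, [..., 1] = walls/targets.
    """
    return jnp.stack([variable_grid, fixed_grid], axis=-1).astype(jnp.uint8)
```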

action_spec: DiscreteArray cached property writable

Returns the action specification for the Sokoban environment. There are 4 actions: [0,1,2,3] -> [Up, Right, Down, Left].

Returns:

  • specs.DiscreteArray: discrete action specifications.
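The mapping [0,1,2,3] -> [Up, Right, Down, Left] can be expressed as a table of (row, col) offsets indexed by the action id. This sketch assumes row 0 is the top row of the grid, so "Up" decreases the row index; `MOVES` and `target_location` are hypothetical names, not part of the documented API.

```python
import jax.numpy as jnp

# (row, col) offset for each action id: 0 Up, 1 Right, 2 Down, 3 Left.
# Assumes row 0 is the top row, so "Up" decreases the row index.
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=jnp.int32)


def target_location(agent_location, action):
    """Cell the agent would try to enter when taking `action` (hypothetical helper)."""
    return agent_location + MOVES[action]


start = jnp.array([3, 3], dtype=jnp.int32)
print([target_location(start, a).tolist() for a in range(4)])
# prints [[2, 3], [3, 4], [4, 3], [3, 2]]
```

Indexing a constant table keeps the lookup branch-free, which suits `jax.jit`.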

__init__(self, generator: Optional[jumanji.environments.routing.sokoban.generator.Generator] = None, reward_fn: Optional[jumanji.environments.routing.sokoban.reward.RewardFn] = None, viewer: Optional[jumanji.viewer.Viewer] = None, time_limit: int = 120) -> None special

Instantiates a Sokoban environment with a specific generator, reward function, time limit, and viewer.

Parameters:

  • generator (Optional[jumanji.environments.routing.sokoban.generator.Generator], default None): generator whose __call__ instantiates an environment instance (an initial state). Implemented options are ToyGenerator, DeepMindGenerator, and HuggingFaceDeepMindGenerator. If None, defaults to HuggingFaceDeepMindGenerator with dataset_name="unfiltered-train" and proportion_of_files=1.
  • reward_fn (Optional[jumanji.environments.routing.sokoban.reward.RewardFn], default None): reward function used to compute the reward at each step.
  • time_limit (int, default 120): maximum number of steps in an episode.
  • viewer (Optional[jumanji.viewer.Viewer], default None): Viewer object used to render the environment.

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.sokoban.types.State, jumanji.types.TimeStep[jumanji.environments.routing.sokoban.types.Observation]]

Resets the environment by calling the instance generator for a new instance.

Parameters:

  • key (PRNGKeyArray, required): random key used to sample a new Sokoban problem.

Returns:

  • state: State object corresponding to the new state of the environment after a reset.
  • timestep: TimeStep object corresponding to the first timestep returned by the environment after a reset.

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.routing.sokoban.types.State, jumanji.types.TimeStep[jumanji.environments.routing.sokoban.types.Observation]]

Executes one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object representing the current state of the environment.
  • action (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], required): array (int32) of shape () giving the action to take: 0: move up, 1: move right, 2: move down, 3: move left.

Returns:

  • state, timestep: the next state of the environment and the timestep to be observed.
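The step description says only that one timestep of the dynamics is executed. Below is a minimal sketch of the standard Sokoban move rule in JAX: the agent walks into a free cell, or pushes a single box one cell further if the cell beyond the box is free. The cell codes (`WALL`, `EMPTY`, `BOX`, `AGENT`), the function name `move_agent`, and the assumption that every level is enclosed by walls are hypothetical; Jumanji's actual grid encoding and implementation may differ, and this sketch neither increments step_count nor computes rewards.

```python
import jax.numpy as jnp

# Hypothetical cell codes for this sketch only; Jumanji's constants may differ.
WALL = 1   # fixed_grid value for a wall
EMPTY = 0  # variable_grid value for an empty cell
BOX = 1    # variable_grid value for a box
AGENT = 2  # variable_grid value for the agent

# (row, col) offsets for actions 0 Up, 1 Right, 2 Down, 3 Left.
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=jnp.int32)


def move_agent(fixed_grid, variable_grid, agent_location, action):
    """One Sokoban move: walk into a free cell, or push one box if the cell
    beyond it is free. Assumes the level is enclosed by walls, so boxes never
    sit on the border and `beyond` only matters when it is in bounds."""
    delta = MOVES[action]
    nxt = agent_location + delta
    beyond = nxt + delta

    next_is_wall = fixed_grid[nxt[0], nxt[1]] == WALL
    next_is_box = variable_grid[nxt[0], nxt[1]] == BOX
    beyond_is_free = (fixed_grid[beyond[0], beyond[1]] != WALL) & (
        variable_grid[beyond[0], beyond[1]] != BOX
    )

    can_walk = ~next_is_wall & ~next_is_box
    can_push = next_is_box & beyond_is_free
    can_move = can_walk | can_push

    # Move the box first; the update is kept only if a push happens.
    grid = jnp.where(
        can_push, variable_grid.at[beyond[0], beyond[1]].set(BOX), variable_grid
    )
    # Then move the agent; the update is kept only if the move is legal.
    moved = (
        grid.at[agent_location[0], agent_location[1]].set(EMPTY)
        .at[nxt[0], nxt[1]].set(AGENT)
    )
    grid = jnp.where(can_move, moved, grid)
    location = jnp.where(can_move, nxt, agent_location)
    return grid, location
```

Every branch is expressed with `jnp.where` rather than Python `if`, so the function stays traceable under `jax.jit` and `jax.vmap`.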

render(self, state: State) -> None

Renders the current state of Sokoban.

Parameters:

  • state (State, required): State object holding the current state to render.

Last update: 2024-11-01