Maze (Environment)

A JAX implementation of a 2D Maze. The goal is to navigate the maze to find the target position.

  • observation:

    • agent_position: current 2D Position of agent.
    • target_position: 2D Position of target cell.
    • walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells.
    • action_mask: array (bool) of shape (4,) defining the available actions in the current position.
    • step_count: jax array (int32) of shape (), the step number of the episode.
  • action: jax array (int32) of shape () specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken.

  • reward: jax array (float32) of shape (): 1 if the target is reached, 0 otherwise.

  • episode termination (if any):

    • agent reaches the target position.
    • the time_limit is reached.
  • state: State:

    • agent_position: current 2D Position of agent.
    • target_position: 2D Position of target cell.
    • walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells.
    • action_mask: array (bool) of shape (4,) defining the available actions in the current position.
    • step_count: jax array (int32) of shape (), the step number of the episode.
    • key: random key (uint) of shape (2,).
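import jax
from jumanji.environments import Maze

env = Maze()
key = jax.random.PRNGKey(0)
# Reset to obtain the initial state and the first timestep.
state, timestep = jax.jit(env.reset)(key)
# Generate a placeholder action that conforms to the action spec.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)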
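
The action semantics described above (four moves, a no-op when the move is blocked, and a 4-entry action mask) can be illustrated without JAX. The following is a minimal plain-Python sketch, not Jumanji's implementation. The helper names `MOVES`, `is_free`, `action_mask`, and `move` are hypothetical, and the sketch assumes that positions are (row, col) pairs, that row 0 is the top row so Up decreases the row index, and that a move leaving the grid is treated as blocked.

```python
# Plain-Python illustration of the documented action semantics.
# Assumptions (not taken from the Jumanji source): positions are (row, col),
# row 0 is the top row, and leaving the grid counts as a blocked move.

# Action index -> (row delta, col delta) for [Up, Right, Down, Left].
MOVES = [(-1, 0), (0, 1), (1, 0), (0, -1)]


def is_free(walls, pos):
    """Return True if pos is inside the grid and is not a wall."""
    row, col = pos
    in_bounds = 0 <= row < len(walls) and 0 <= col < len(walls[0])
    return in_bounds and not walls[row][col]


def action_mask(walls, pos):
    """Boolean list of length 4: which actions would actually move the agent."""
    return [is_free(walls, (pos[0] + dr, pos[1] + dc)) for dr, dc in MOVES]


def move(walls, pos, action):
    """Apply an action; a blocked action leaves the position unchanged."""
    dr, dc = MOVES[action]
    new_pos = (pos[0] + dr, pos[1] + dc)
    return new_pos if is_free(walls, new_pos) else pos


# 3x3 toy maze; True marks a wall (same convention as the `walls` field).
walls = [
    [False, True, False],
    [False, False, False],
    [True, False, False],
]
start = (0, 0)
mask = action_mask(walls, start)  # only Down is available from this corner
blocked = move(walls, start, 1)   # Right hits a wall, so the agent stays put
moved = move(walls, start, 2)     # Down is free, so the agent moves to (1, 0)
```

In the environment itself the mask is provided in the observation, so an agent can consult `action_mask` rather than recompute it from `walls`.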

observation_spec: jumanji.specs.Spec[jumanji.environments.routing.maze.types.Observation] cached property writable

Specifications of the observation of the Maze environment.


Type Description
Spec[Observation]

Spec for the `Observation` whose fields are:
  • agent_position: tree of BoundedArray (int32) of shape ().
  • target_position: tree of BoundedArray (int32) of shape ().
  • walls: BoundedArray (bool) of shape (num_rows, num_cols).
  • step_count: Array (int32) of shape ().
  • action_mask: BoundedArray (bool) of shape (4,).

action_spec: DiscreteArray cached property writable

Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left].


Type Description
DiscreteArray

discrete action space with 4 values.
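
A common way to combine a 4-value discrete action space with the `action_mask` from the observation is to pick the best-scoring action among the valid ones. The sketch below is plain Python and makes no claim about how any particular agent does this; `masked_argmax` and the example scores are hypothetical.

```python
import math


def masked_argmax(scores, mask):
    """Return the index of the highest-scoring action allowed by mask."""
    if not any(mask):
        raise ValueError("no valid action available")
    # Invalid actions get -inf so they can never win the comparison.
    masked = [s if ok else -math.inf for s, ok in zip(scores, mask)]
    return max(range(len(masked)), key=masked.__getitem__)


# Scores for [Up, Right, Down, Left]; Up scores highest but is masked out.
scores = [2.5, 0.1, 1.7, -0.3]
mask = [False, True, True, False]
best = masked_argmax(scores, mask)
```

Masking before selection avoids spending steps on no-op moves into walls, which matters because every step counts toward the time limit.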

__init__(self, generator: Optional[jumanji.environments.routing.maze.generator.Generator] = None, time_limit: Optional[int] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.maze.types.State]] = None) -> None special

Instantiates a Maze environment.


Name Type Description Default
generator Optional[jumanji.environments.routing.maze.generator.Generator]

Generator whose __call__ instantiates an environment instance. Implemented options are [ToyGenerator, RandomGenerator]. Defaults to RandomGenerator with num_rows=10 and num_cols=10.

time_limit Optional[int]

the time_limit of an episode, i.e. the maximum number of environment steps before the episode terminates. By default, time_limit = num_rows * num_cols.

viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.maze.types.State]]

Viewer used for rendering. Defaults to MazeEnvViewer with "human" render mode.


reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.maze.types.State, jumanji.types.TimeStep[jumanji.environments.routing.maze.types.Observation]]

Resets the environment by calling the instance generator for a new instance.


Name Type Description Default
key PRNGKeyArray

random key used to reset the environment since it is stochastic.



Type Description
Tuple[State, TimeStep[Observation]]

state: State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding to the first timestep returned by the environment after a reset.

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.routing.maze.types.State, jumanji.types.TimeStep[jumanji.environments.routing.maze.types.Observation]]

Run one timestep of the environment's dynamics.

If an action is invalid, the agent does not move; an invalid action does not by itself terminate the episode.


Name Type Description Default
state State

State object containing the dynamics of the environment.

action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]

(int32) specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken.



Type Description
Tuple[State, TimeStep[Observation]]

state: the next state of the environment. timestep: the next timestep to be observed.
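
The reward and termination rules stated at the top of this page (reward 1 on reaching the target, 0 otherwise; the episode ends at the target or at the time limit) can be summarised in a few lines. This is a plain-Python sketch under stated assumptions, not the library's code: `reward_and_done` is a hypothetical helper, `step_count` is assumed to be the count after the current step, and the time limit is assumed to trigger once `step_count` reaches `time_limit`.

```python
def reward_and_done(agent_pos, target_pos, step_count, time_limit):
    """Return (reward, done) for the position reached after a step.

    Assumption: step_count already includes the step just taken, and the
    episode times out once step_count >= time_limit.
    """
    reached = agent_pos == target_pos
    timed_out = step_count >= time_limit
    reward = 1.0 if reached else 0.0
    return reward, reached or timed_out


time_limit = 9  # documented default for a 3x3 maze: num_rows * num_cols
target = (2, 2)

mid_episode = reward_and_done((1, 1), target, step_count=3, time_limit=time_limit)
at_target = reward_and_done((2, 2), target, step_count=4, time_limit=time_limit)
out_of_time = reward_and_done((1, 1), target, step_count=9, time_limit=time_limit)
```

Because the reward is sparse, an episode that ends by timing out returns a total reward of 0.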

Last update: 2024-03-29