
Maze

Maze (Environment)

A JAX implementation of a 2D Maze. The goal is to navigate the maze to find the target position.

  • observation:

    • agent_position: current 2D Position of agent.
    • target_position: 2D Position of target cell.
    • walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells.
    • action_mask: array (bool) of shape (4,) defining the available actions in the current position.
    • step_count: jax array (int32) of shape (), the step number of the episode.
  • action: jax array (int32) of shape () specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken (i.e. a wall blocks the move), the action is a no-op and the agent stays in place.

  • reward: jax array (float32) of shape (): 1 if the target is reached, 0 otherwise.

  • episode termination (if any):

    • agent reaches the target position.
    • the time_limit is reached.
  • state: State:

    • agent_position: current 2D Position of agent.
    • target_position: 2D Position of target cell.
    • walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells.
    • action_mask: array (bool) of shape (4,) defining the available actions in the current position.
    • step_count: jax array (int32) of shape (), the step number of the episode.
    • key: random key (uint) of shape (2,).
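The action mask listed above can be derived from the walls and the agent's position. The sketch below is a minimal illustration, not Jumanji's implementation. It assumes row 0 is the top of the grid (so Up decreases the row index) and that moves leaving the grid are invalid. `compute_action_mask` and `MOVES` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Row/column offsets for actions [0, 1, 2, 3] = [Up, Right, Down, Left].
# Assumption: row 0 is the top row, so "Up" decreases the row index.
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=jnp.int32)


def compute_action_mask(walls: jax.Array, row: jax.Array, col: jax.Array) -> jax.Array:
    """Hypothetical helper: True where a move stays on the grid and does not enter a wall."""
    num_rows, num_cols = walls.shape
    new_rows = row + MOVES[:, 0]
    new_cols = col + MOVES[:, 1]
    in_bounds = (new_rows >= 0) & (new_rows < num_rows) & (new_cols >= 0) & (new_cols < num_cols)
    # Clip before indexing so off-grid moves never wrap around; `in_bounds` masks them out.
    safe_rows = jnp.clip(new_rows, 0, num_rows - 1)
    safe_cols = jnp.clip(new_cols, 0, num_cols - 1)
    return in_bounds & ~walls[safe_rows, safe_cols]


# Shapes depend only on `walls.shape`, so the function can be jit-compiled.
compute_action_mask_jit = jax.jit(compute_action_mask)

walls = jnp.array([[False, True, False],
                   [False, False, False],
                   [True, False, False]])
# From (1, 1): Up is blocked by the wall at (0, 1); Right, Down and Left are open.
mask = compute_action_mask_jit(walls, jnp.int32(1), jnp.int32(1))
```

Because the whole mask is computed with fixed-shape array operations, it stays compatible with `jax.jit` and `jax.vmap`.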
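```python
import jax
from jumanji.environments import Maze

env = Maze()
key = jax.random.PRNGKey(0)
# reset and step are pure functions, so they can be jit-compiled.
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# Generate a placeholder action that conforms to the action spec.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
```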

observation_spec: jumanji.specs.Spec[jumanji.environments.routing.maze.types.Observation] (cached property, writable)

Specifications of the observation of the Maze environment.

Returns:

Spec for the `Observation` whose fields are:
  • agent_position: tree of BoundedArray (int32) of shape ().
  • target_position: tree of BoundedArray (int32) of shape ().
  • walls: BoundedArray (bool) of shape (num_rows, num_cols).
  • step_count: Array (int32) of shape ().
  • action_mask: BoundedArray (bool) of shape (4,).
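To make the nested structure concrete, here is a sketch of a value with the same fields, shapes and dtypes that the spec lists. `PositionSketch`, `ObservationSketch` and `dummy_observation` are hypothetical stand-ins, not Jumanji classes. The (row, col) layout of a position is an assumption suggested by the "tree of BoundedArray (int32) of shape ()" wording.

```python
from typing import NamedTuple

import jax.numpy as jnp


class PositionSketch(NamedTuple):
    """Hypothetical 2D position: a small pytree of two int32 scalars."""
    row: jnp.ndarray
    col: jnp.ndarray


class ObservationSketch(NamedTuple):
    """Hypothetical mirror of the fields listed in the observation spec."""
    agent_position: PositionSketch
    target_position: PositionSketch
    walls: jnp.ndarray        # bool, shape (num_rows, num_cols)
    step_count: jnp.ndarray   # int32, shape ()
    action_mask: jnp.ndarray  # bool, shape (4,)


def dummy_observation(num_rows: int = 10, num_cols: int = 10) -> ObservationSketch:
    """Build a wall-free observation whose shapes and dtypes follow the spec above."""
    return ObservationSketch(
        agent_position=PositionSketch(jnp.int32(0), jnp.int32(0)),
        target_position=PositionSketch(jnp.int32(num_rows - 1), jnp.int32(num_cols - 1)),
        walls=jnp.zeros((num_rows, num_cols), dtype=bool),
        step_count=jnp.int32(0),
        action_mask=jnp.ones((4,), dtype=bool),
    )


obs = dummy_observation()
```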

action_spec: DiscreteArray (cached property, writable)

Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left].

Returns:

  • action_spec: discrete action space with 4 values.
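Since every observation carries an action_mask over these 4 actions, a random policy can restrict itself to valid moves. The following sketch samples uniformly among unmasked actions with `jax.random.categorical`. `sample_valid_action` is a hypothetical helper, and it assumes at least one action is valid.

```python
import jax
import jax.numpy as jnp


def sample_valid_action(key: jax.Array, action_mask: jax.Array) -> jax.Array:
    """Hypothetical helper: sample an int32 action in [0, 3] among entries where the mask is True."""
    # Masked-out actions get a -inf logit, so they have zero probability.
    logits = jnp.where(action_mask, 0.0, -jnp.inf)
    return jax.random.categorical(key, logits).astype(jnp.int32)


# Usage: draw 100 actions for a cell where only Right (1) and Left (3) are open.
mask = jnp.array([False, True, False, True])
keys = jax.random.split(jax.random.PRNGKey(0), 100)
actions = jax.vmap(sample_valid_action, in_axes=(0, None))(keys, mask)
```

Using -inf logits rather than filtering the valid indices keeps every array shape static, which is what `jax.jit` and `jax.vmap` require.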

__init__(self, generator: Optional[jumanji.environments.routing.maze.generator.Generator] = None, time_limit: Optional[int] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.maze.types.State]] = None) -> None

Instantiates a Maze environment.

Parameters:

  • generator (Optional[jumanji.environments.routing.maze.generator.Generator], default None): generator whose __call__ instantiates an environment instance. Implemented options are [ToyGenerator, RandomGenerator]. Defaults to RandomGenerator with num_rows=10 and num_cols=10.
  • time_limit (Optional[int], default None): the time limit of an episode, i.e. the maximum number of environment steps before the episode terminates. By default, time_limit = num_rows * num_cols.
  • viewer (Optional[jumanji.viewer.Viewer[jumanji.environments.routing.maze.types.State]], default None): viewer used for rendering. Defaults to MazeEnvViewer with "human" render mode.
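The default episode budget is simple arithmetic: with the default 10 by 10 RandomGenerator, time_limit defaults to 10 * 10 = 100 steps. The sketch below encodes that rule. `resolve_time_limit` and `time_limit_reached` are hypothetical helpers, and the `>=` comparison is an assumption about how the limit is checked.

```python
from typing import Optional


def resolve_time_limit(num_rows: int, num_cols: int, time_limit: Optional[int] = None) -> int:
    """Hypothetical helper: the documented default is num_rows * num_cols steps."""
    if time_limit is None:
        return num_rows * num_cols
    return time_limit


def time_limit_reached(step_count: int, time_limit: int) -> bool:
    """Hypothetical helper; assumption: the episode ends once step_count reaches time_limit."""
    return step_count >= time_limit
```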

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.maze.types.State, jumanji.types.TimeStep[jumanji.environments.routing.maze.types.Observation]]

Resets the environment by calling the instance generator for a new instance.

Parameters:

  • key (PRNGKeyArray, required): random key used to reset the environment, since resets are stochastic.

Returns:

  • state: State object corresponding to the new state of the environment after a reset.
  • timestep: TimeStep object corresponding to the first timestep returned by the environment after a reset.
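Because a reset is driven entirely by the key, the same key reproduces the same instance. As an illustration only (this is not how ToyGenerator or RandomGenerator build mazes), the sketch below draws two distinct empty cells for the agent and the target from a given wall grid. `sample_agent_and_target` is a hypothetical helper, and it assumes the grid has at least two empty cells.

```python
import jax
import jax.numpy as jnp


def sample_agent_and_target(key: jax.Array, walls: jax.Array):
    """Hypothetical helper: pick two distinct empty cells as (row, col) pairs."""
    num_cols = walls.shape[1]
    empty = ~walls.ravel()
    # Uniform probability over empty cells, zero probability on walls.
    probs = empty.astype(jnp.float32) / empty.sum()
    cells = jax.random.choice(key, walls.size, shape=(2,), replace=False, p=probs)
    rows, cols = jnp.divmod(cells, num_cols)
    agent = (rows[0], cols[0])
    target = (rows[1], cols[1])
    return agent, target


# Only (0, 0) and (2, 2) are empty, so they must be the two cells drawn.
walls = jnp.ones((3, 3), dtype=bool).at[0, 0].set(False).at[2, 2].set(False)
agent, target = sample_agent_and_target(jax.random.PRNGKey(0), walls)
```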

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.routing.maze.types.State, jumanji.types.TimeStep[jumanji.environments.routing.maze.types.Observation]]

Run one timestep of the environment's dynamics.

If an action is invalid, the agent does not move; an invalid action does not terminate the episode.

Parameters:

  • state (State, required): State object containing the current state of the environment.
  • action (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], required): int32 value specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken (i.e. a wall blocks the move), it is treated as a no-op.

Returns:

  • state: the next state of the environment.
  • timestep: the next timestep to be observed.
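The step semantics described in this section (no-op on invalid actions, reward 1 on reaching the target, termination on the target or the time limit) can be sketched as a pure function. This is a simplified model under stated assumptions, not Jumanji's implementation: positions are (row, col) int32 arrays, row 0 is the top row, and the time limit is checked with `>=` on the incremented step count. `step_sketch` and `MOVES` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Offsets for actions [0, 1, 2, 3] = [Up, Right, Down, Left].
# Assumption: row 0 is the top row, so "Up" decreases the row index.
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=jnp.int32)


def step_sketch(position, target, action_mask, action, step_count, time_limit):
    """Hypothetical one-step transition of a simplified maze.

    position, target: int32 arrays of shape (2,) holding (row, col).
    action_mask: bool array of shape (4,) for the current position.
    Returns (next_position, reward, done, next_step_count).
    """
    # An invalid (masked-out) action is a no-op: the offset becomes (0, 0).
    delta = jnp.where(action_mask[action], MOVES[action], 0)
    next_position = position + delta
    reached = jnp.all(next_position == target)
    # Reward is 1.0 only on reaching the target, 0.0 otherwise.
    reward = jnp.where(reached, 1.0, 0.0).astype(jnp.float32)
    next_step_count = step_count + 1
    # Assumption: the episode ends on the target or once the step budget is used up.
    done = reached | (next_step_count >= time_limit)
    return next_position, reward, done, next_step_count


# Only jnp operations are used, so the transition can be jit-compiled.
step_sketch_jit = jax.jit(step_sketch)
```

Selecting the offset with `jnp.where` instead of a Python `if` keeps the no-op branch traceable, so the same function works under `jax.jit` and `jax.vmap`.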


Last update: 2024-11-01