Maze

Bases: Environment[State, DiscreteArray, Observation]

A JAX implementation of a 2D Maze. The goal is to navigate the maze to find the target position.

  • observation:

    • agent_position: current 2D Position of agent.
    • target_position: 2D Position of target cell.
    • walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells.
    • action_mask: array (bool) of shape (4,) defining the available actions in the current position.
    • step_count: jax array (int32) of shape (): step number of the episode.
  • action: jax array (int32) of shape () specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken.

  • reward: jax array (float32) of shape (): 1 if the target is reached, 0 otherwise.

  • episode termination (when any of the following occurs):

    • agent reaches the target position.
    • the time_limit is reached.
  • state: State:

    • agent_position: current 2D Position of agent.
    • target_position: 2D Position of target cell.
    • walls: jax array (bool) of shape (num_rows, num_cols) whose values are True where walls are and False for empty cells.
    • action_mask: array (bool) of shape (4,) defining the available actions in the current position.
    • step_count: jax array (int32) of shape (): step number of the episode.
    • key: random key (uint) of shape (2,).
import jax

from jumanji.environments import Maze
env = Maze()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
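
Building on the snippet above, a short rollout can sample uniformly among the actions the mask currently allows. This is a minimal illustrative sketch, not part of the documented API:

import jax
import jax.numpy as jnp

from jumanji.environments import Maze

env = Maze()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

step_fn = jax.jit(env.step)
for _ in range(10):
    key, action_key = jax.random.split(key)
    mask = timestep.observation.action_mask
    # Sample uniformly among the valid actions (True entries of the mask).
    action = jax.random.choice(action_key, jnp.arange(4), p=mask / mask.sum())
    state, timestep = step_fn(state, action)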

Instantiates a Maze environment.

Parameters:

  • generator (Optional[Generator], default None): Generator whose __call__ instantiates an environment instance. Implemented options are [ToyGenerator, RandomGenerator]. Defaults to RandomGenerator with num_rows=10 and num_cols=10.
  • time_limit (Optional[int], default None): the time_limit of an episode, i.e. the maximum number of environment steps before the episode terminates. By default, time_limit = num_rows * num_cols.
  • viewer (Optional[Viewer[State]], default None): Viewer used for rendering. Defaults to MazeEnvViewer with "human" render mode.
Source code in jumanji/environments/routing/maze/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    time_limit: Optional[int] = None,
    viewer: Optional[Viewer[State]] = None,
) -> None:
    """Instantiates a `Maze` environment.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment instance.
            Implemented options are [`ToyGenerator`, `RandomGenerator`].
            Defaults to `RandomGenerator` with `num_rows=10` and `num_cols=10`.
        time_limit: the time_limit of an episode, i.e. the maximum number of environment steps
            before the episode terminates. By default, `time_limit = num_rows * num_cols`.
        viewer: `Viewer` used for rendering. Defaults to `MazeEnvViewer` with "human" render
            mode.
    """
    self.generator = generator or RandomGenerator(num_rows=10, num_cols=10)
    self.num_rows = self.generator.num_rows
    self.num_cols = self.generator.num_cols
    super().__init__()
    self.shape = (self.num_rows, self.num_cols)
    self.time_limit = time_limit or self.num_rows * self.num_cols

    # Create viewer used for rendering
    self._viewer = viewer or MazeEnvViewer("Maze", render_mode="human")
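
For example, a larger random maze with a tighter step budget could be configured as below. This sketch assumes RandomGenerator is importable from the generator module next to the source file above:

from jumanji.environments import Maze
from jumanji.environments.routing.maze.generator import RandomGenerator

# A 20x20 random maze with a 100-step limit instead of the default 20 * 20 = 400.
env = Maze(generator=RandomGenerator(num_rows=20, num_cols=20), time_limit=100)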

action_spec: specs.DiscreteArray cached property #

Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left].

Returns:

  • action_spec (DiscreteArray): discrete action space with 4 values.
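
A brief sketch of how the spec might be used, reusing the env from the example at the top of this page; num_values and generate_value are standard spec attributes:

spec = env.action_spec
num_actions = spec.num_values        # 4
dummy_action = spec.generate_value() # a valid placeholder action (0, i.e. Up)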

observation_spec: specs.Spec[Observation] cached property #

Specifications of the observation of the Maze environment.

Returns:

  • Spec[Observation]: spec for the Observation, whose fields are:
    • agent_position: tree of BoundedArray (int32) of shape ().
    • target_position: tree of BoundedArray (int32) of shape ().
    • walls: BoundedArray (bool) of shape (num_rows, num_cols).
    • step_count: Array (int32) of shape ().
    • action_mask: BoundedArray (bool) of shape (4,).
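
For instance, a dummy observation matching this spec can be generated and its fields inspected; a short sketch assuming the env from the example at the top of this page:

obs = env.observation_spec.generate_value()
obs.walls.shape        # (num_rows, num_cols), i.e. (10, 10) with the default generator
obs.action_mask.shape  # (4,)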

animate(states, interval=200, save_path=None) #

Creates an animated gif of the Maze environment based on the sequence of states.

Parameters:

  • states (Sequence[State], required): sequence of environment states corresponding to consecutive timesteps.
  • interval (int, default 200): delay between frames in milliseconds.
  • save_path (Optional[str], default None): the path where the animation file should be saved. If it is None, the plot will not be saved.

Returns:

  • animation.FuncAnimation: the animation object that was created.

Source code in jumanji/environments/routing/maze/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the `Maze` environment based on the sequence of states.

    Args:
        states: sequence of environment states corresponding to consecutive timesteps.
        interval: delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states, interval, save_path)
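
A brief usage sketch (illustrative only): collect the states visited during a short random rollout, then save them as a gif. The file name is arbitrary:

import jax

from jumanji.environments import Maze

env = Maze()
key = jax.random.PRNGKey(0)
state, _ = jax.jit(env.reset)(key)

states = [state]
for _ in range(10):
    key, action_key = jax.random.split(key)
    # Random action; invalid ones are turned into no-ops by the environment.
    action = jax.random.randint(action_key, (), 0, 4)
    state, _ = jax.jit(env.step)(state, action)
    states.append(state)

env.animate(states, interval=200, save_path="maze.gif")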

close() #

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/routing/maze/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state) #

Render the given state of the environment.

Parameters:

  • state (State, required): State object containing the current environment state.
Source code in jumanji/environments/routing/maze/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Render the given state of the environment.

    Args:
        state: `State` object containing the current environment state.
    """
    return self._viewer.render(state)

reset(key) #

Resets the environment by calling the instance generator for a new instance.

Parameters:

  • key (PRNGKey, required): random key used to reset the environment since it is stochastic.

Returns:

  • state (State): State object corresponding to the new state of the environment after a reset.
  • timestep (TimeStep[Observation]): TimeStep object corresponding to the first timestep returned by the environment after a reset.

Source code in jumanji/environments/routing/maze/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment by calling the instance generator for a new instance.

    Args:
        key: random key used to reset the environment since it is stochastic.

    Returns:
        state: `State` object corresponding to the new state of the environment after a reset.
        timestep: `TimeStep` object corresponding to the first timestep returned by the environment
            after a reset.
    """

    state = self.generator(key)

    # Create the action mask and update the state
    state.action_mask = self._compute_action_mask(state.walls, state.agent_position)

    # Generate the observation from the environment state.
    observation = self._observation_from_state(state)

    # Return a restart timestep whose step type is FIRST.
    timestep = restart(observation)

    return state, timestep
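
Because reset is a pure function of the key, a batch of independent maze instances can be created with standard JAX vectorisation. A sketch, not part of the documented API:

import jax

from jumanji.environments import Maze

env = Maze()
keys = jax.random.split(jax.random.PRNGKey(0), 8)
# One state and one timestep per key, stacked along a leading batch axis.
batched_state, batched_timestep = jax.vmap(env.reset)(keys)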

step(state, action) #

Run one timestep of the environment's dynamics.

If an action is invalid, the agent does not move, i.e. the episode does not automatically terminate.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Array, required): (int32) specifying which action to take: [0,1,2,3] correspond to [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken.

Returns:

  • state (State): the next state of the environment.
  • timestep (TimeStep[Observation]): the next timestep to be observed.

Source code in jumanji/environments/routing/maze/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """
    Run one timestep of the environment's dynamics.

    If an action is invalid, the agent does not move, i.e. the episode does not
    automatically terminate.

    Args:
        state: State object containing the dynamics of the environment.
        action: (int32) specifying which action to take: [0,1,2,3] correspond to
            [Up, Right, Down, Left]. If an invalid action is taken, i.e. there is a wall
            blocking the action, then no action (no-op) is taken.

    Returns:
        state: the next state of the environment.
        timestep: the next timestep to be observed.
    """
    # If the chosen action is invalid, i.e. blocked by a wall, overwrite it to no-op.
    action = jax.lax.select(state.action_mask[action], action, 4)

    # Take the action in the environment:  up, right, down, or left
    # Remember the walls coordinates: (0,0) is top left.
    agent_position = jax.lax.switch(
        action,
        [
            lambda position: Position(position.row - 1, position.col),  # Up
            lambda position: Position(position.row, position.col + 1),  # Right
            lambda position: Position(position.row + 1, position.col),  # Down
            lambda position: Position(position.row, position.col - 1),  # Left
            lambda position: position,  # No-op
        ],
        state.agent_position,
    )

    # Generate action mask to keep in the state for the next step and
    # to provide to the agent in the observation.
    action_mask = self._compute_action_mask(state.walls, agent_position)

    # Build the state.
    state = State(
        agent_position=agent_position,
        target_position=state.target_position,
        walls=state.walls,
        action_mask=action_mask,
        key=state.key,
        step_count=state.step_count + 1,
    )
    # Generate the observation from the environment state.
    observation = self._observation_from_state(state)

    # Check if the episode terminates (i.e. done is True).
    no_actions_available = ~jnp.any(action_mask)
    target_reached = state.agent_position == state.target_position
    time_limit_exceeded = state.step_count >= self.time_limit

    done = no_actions_available | target_reached | time_limit_exceeded

    # Compute the reward.
    reward = jnp.array(state.agent_position == state.target_position, float)

    # Return either a MID or a LAST timestep depending on done.
    timestep = jax.lax.cond(
        done,
        termination,
        transition,
        reward,
        observation,
    )
    return state, timestep
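
The no-op behaviour can be checked directly: stepping with an action whose mask entry is False leaves the agent in place. An illustrative sketch:

import jax
import jax.numpy as jnp

from jumanji.environments import Maze

env = Maze()
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))

mask = timestep.observation.action_mask
blocked = jnp.argmin(mask)  # index of the first False entry (0 if every action happens to be valid)
next_state, _ = jax.jit(env.step)(state, blocked)
# If `blocked` was indeed invalid, next_state.agent_position equals state.agent_position.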