PacMan

Bases: Environment[State, DiscreteArray, Observation]

A JAX implementation of the 'PacMan' game where a single agent must navigate a maze to collect pellets and avoid 4 heuristic agents. The game takes place on a 31x28 grid where the player can move in 4 directions (left, right, up, down) and collect pellets to gain points. The goal is to collect all of the pellets on the board without colliding with one of the heuristic agents. When using the AsciiGenerator, the environment will always generate the same maze as long as the same ASCII diagram is in use.

  • observation: Observation

    • player_locations: current 2D position of the agent.
    • grid: jax array (int) of the in-game maze with walls.
    • ghost_locations: jax array (int) of ghost positions.
    • power_up_locations: jax array (int) of power-pellet locations.
    • pellet_locations: jax array (int) of pellet locations.
    • action_mask: jax array (bool) defining the currently valid actions.
    • score: (int32) of total points acquired.
  • action: jax array (int) of shape () specifying which action to take: [0,1,2,3,4] corresponding to [Up, Right, Down, Left, No-op]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken.

  • reward: jax array (float32) of shape (): 10 per pellet collected, 20 for a power pellet and 200 for each unique ghost eaten.

  • episode termination (if any):

    • agent has collected all pellets.
    • agent killed by ghost.
    • timer has elapsed.
  • state: State:

    • key: jax array (uint32) of shape (2,).
    • grid: jax array (int) of shape (31,28) of the in-game maze with walls.
    • pellets: int tracking the number of remaining pellets.
    • frightened_state_time: jax array (int) of shape () tracking the number of steps remaining in the scatter state.
    • pellet_locations: jax array (int) of pellet locations of shape (316,2).
    • power_up_locations: jax array (int) of power-pellet locations of shape (4,2).
    • player_locations: current 2D position of the agent.
    • ghost_locations: jax array (int) of ghost positions of shape (4,2).
    • initial_player_locations: starting 2D position of the agent.
    • initial_ghost_positions: jax array (int) of ghost positions of shape (4,2).
    • ghost_init_targets: jax array (int) of ghost positions, used to direct ghosts on respawn.
    • old_ghost_locations: jax array (int) of shape (4,2) of ghost positions from the last step, used to prevent ghost backtracking.
    • ghost_init_steps: jax array (int) of shape (4,2) of the number of initial ghost steps, used to determine per-ghost initialisation.
    • ghost_actions: jax array (int) of shape (4,).
    • last_direction: int tracking the last direction of the player.
    • dead: bool used to track player death.
    • visited_index: jax array (int) of visited locations of shape (320,2), used to prevent pellets from being scored more than once.
    • ghost_starts: jax array (int) of shape (4,2), used to reset ghost positions if eaten.
    • scatter_targets: jax array (int) of shape (4,2) of target locations for ghosts when scatter behavior is active.
    • step_count: (int32) of total steps taken from reset until the current timestep.
    • ghost_eaten: jax array (bool) of shape (4,) tracking whether each ghost has been eaten before.
    • score: (int32) of total points acquired.
import jax
from jumanji.environments import PacMan

env = PacMan()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)

Instantiates a PacMan environment.

Parameters:

Name Type Description Default
generator Optional[Generator]

Generator whose __call__ instantiates an environment instance. Implemented options are [AsciiGenerator].

None
time_limit Optional[int]

the time_limit of an episode, i.e. the maximum number of environment steps before the episode terminates. By default, set to 1000.

None
viewer Optional[Viewer[State]]

Viewer used for rendering. Defaults to PacManViewer.

None
Source code in jumanji/environments/routing/pac_man/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    viewer: Optional[Viewer[State]] = None,
    time_limit: Optional[int] = None,
) -> None:
    """Instantiates a `PacMan` environment.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment instance.
            Implemented options are [`AsciiGenerator`].
        time_limit: the time_limit of an episode, i.e. the maximum number of environment steps
            before the episode terminates. By default, set to 1000.
        viewer: `Viewer` used for rendering. Defaults to `PacManViewer`.
    """

    self.generator = generator or AsciiGenerator(DEFAULT_MAZE)
    self.x_size = self.generator.x_size
    self.y_size = self.generator.y_size
    self.pellet_spaces = self.generator.pellet_spaces
    super().__init__()
    self._viewer = viewer or PacManViewer("Pacman", render_mode="human")
    self.time_limit = time_limit or 1000
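
A minimal instantiation sketch with a custom time limit; the import path for AsciiGenerator and DEFAULT_MAZE is an assumption inferred from the source above:

from jumanji.environments import PacMan
from jumanji.environments.routing.pac_man.generator import AsciiGenerator, DEFAULT_MAZE

# Explicit generator and a shorter episode limit than the 1000-step default.
env = PacMan(generator=AsciiGenerator(DEFAULT_MAZE), time_limit=500)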

action_spec: specs.DiscreteArray cached property #

Returns the action spec.

5 actions: [0,1,2,3,4] -> [Up, Right, Down, Left, No-op].

Returns:

Name Type Description
action_spec DiscreteArray

a specs.DiscreteArray spec object.
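
A minimal sketch of querying the spec, relying only on the DiscreteArray interface documented above:

spec = env.action_spec
assert spec.num_values == 5     # the 5 discrete actions listed above
action = spec.generate_value()  # a valid placeholder action (0)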

observation_spec: specs.Spec[Observation] cached property #

Specifications of the observation of the PacMan environment.

Returns:

Type Description
Spec[Observation]

Spec containing all the specifications for all the Observation fields:

  • player_locations: tree of BoundedArray (int32) of shape ().
  • grid: BoundedArray (int) of the in-game maze with walls.
  • ghost_locations: jax array (int) of ghost positions.
  • power_up_locations: jax array (int) of power-pellet locations.
  • pellet_locations: jax array (int) of pellet locations.
  • action_mask: jax array (bool) defining the currently valid actions.
  • frightened_state_time: int counting the time remaining in scatter mode.
  • score: (int32) of the total score obtained by the player.
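
A minimal sketch of checking a live observation against this spec, assuming the standard validate method on jumanji.specs.Spec:

import jax

state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
# Raises an exception if the observation does not conform to the spec.
env.observation_spec.validate(timestep.observation)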

animate(states, interval=200, save_path=None) #

Creates an animated gif of the PacMan environment based on the sequence of states.

Parameters:

Name Type Description Default
states Sequence[State]

sequence of environment states corresponding to consecutive timesteps.

required
interval int

delay between frames in milliseconds, default to 200.

200
save_path Optional[str]

the path where the animation file should be saved. If it is None, the plot will not be saved.

None

Returns:

Type Description
FuncAnimation

animation.FuncAnimation: the animation object that was created.

Source code in jumanji/environments/routing/pac_man/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the `Maze` environment based on the sequence of states.

    Args:
        states: sequence of environment states corresponding to consecutive timesteps.
        interval: delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states, interval, save_path)
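
A usage sketch; the fixed placeholder action is illustrative only, not a real policy:

import jax

state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
states = [state]
step_fn = jax.jit(env.step)
for _ in range(50):
    action = env.action_spec.generate_value()  # placeholder policy
    state, timestep = step_fn(state, action)
    states.append(state)
env.animate(states, interval=200, save_path="pac_man.gif")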

check_power_up(state) #

Check if the player is on a power-up location and update the power-up locations array accordingly.

Parameters:

Name Type Description Default
state State

`State` object corresponding to the new state of the environment.

required

Returns:

Name Type Description
power_up_locations Array

locations of the remaining power-ups

eat Numeric

a bool indicating if the player can eat the ghosts

reward Numeric

the reward gained from collecting power-ups

Source code in jumanji/environments/routing/pac_man/env.py
def check_power_up(self, state: State) -> Tuple[chex.Array, chex.Numeric, chex.Numeric]:
    """
    Check if the player is on a power-up location and update the power-up
    locations array accordingly.

    Args:
        state: `State` object corresponding to the new state of the environment

    Returns:
        power_up_locations: locations of the remaining power-ups
        eat: a bool indicating if the player can eat the ghosts
        reward: the reward gained from collecting power-ups
    """

    power_up_locations = jnp.array(state.power_up_locations)

    player_space = state.player_locations
    player_loc_x = player_space.x
    player_loc_y = player_space.y
    player_loc = jnp.array([player_loc_y, player_loc_x])

    # Check if player and power_up position are shared
    on_powerup = (player_loc == power_up_locations).all(axis=-1).any()
    eat = on_powerup.astype(int)
    mask = (player_loc == power_up_locations).all(axis=-1)
    invert_mask = ~mask
    invert_mask = invert_mask.reshape(4, 1)

    # Mask out collected power-ups
    power_up_locations = power_up_locations * invert_mask

    # Assign reward for power-up
    reward = eat * 50.0

    return power_up_locations, eat, reward
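
The masking pattern above is worth spelling out: comparing the player's (y, x) coordinate against the (4, 2) array of power-up coordinates broadcasts to a (4, 2) boolean array, and .all(axis=-1) marks the rows that match exactly. check_rewards below applies the same pattern to pellets. A standalone sketch with made-up coordinates:

import jax.numpy as jnp

power_up_locations = jnp.array([[1, 1], [1, 26], [29, 1], [29, 26]])  # example values
player_loc = jnp.array([1, 26])

mask = (player_loc == power_up_locations).all(axis=-1)  # [False, True, False, False]
on_powerup = mask.any()                                 # True
# Zero out the collected power-up, as in the code above.
remaining = power_up_locations * ~mask.reshape(4, 1)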

check_rewards(state) #

Process the state of the game to compute rewards, updated pellet spaces, and remaining number of pellets.

Parameters:

Name Type Description Default
state State

`State` object corresponding to the current state of the environment.

required

Returns:

Name Type Description
rewards int

an integer representing the reward earned by the player in the current state

pellet_spaces Array

a 2D jax array with the locations of all remaining pellets

num_pellets int

an integer counting the remaining pellets on the map.

Source code in jumanji/environments/routing/pac_man/env.py
def check_rewards(self, state: State) -> Tuple[int, chex.Array, int]:
    """
    Process the state of the game to compute rewards, updated pellet spaces, and remaining
    number of pellets.

    Args:
        state: `State` object corresponding to the current state of the environment

    Returns:
        rewards: the reward earned by the player in the current state
        pellet_spaces: a 2D jax array with the locations of all remaining pellets
        num_pellets: an integer counting the remaining pellets on the map.
    """

    # Get the locations of the pellets and the player
    pellet_spaces = jnp.array(state.pellet_locations)
    player_space = state.player_locations
    ps = jnp.array([player_space.y, player_space.x])

    # Get the number of pellets on the map
    num_pellets = state.pellets

    # Check if player has eaten a pellet in this step
    ate_pellet = jnp.any(jnp.all(ps == pellet_spaces, axis=-1))

    # Reduce number of pellets on map if eaten, add reward and remove eaten pellet
    num_pellets -= ate_pellet.astype(int)
    rewards = ate_pellet * 10.0
    mask = jnp.logical_not(jnp.all(ps == pellet_spaces, axis=-1))
    pellet_spaces = pellet_spaces * mask[..., None]

    return rewards, pellet_spaces, num_pellets

check_wall_collisions(state, new_player_pos) #

Check if the new player position collides with a wall.

Parameters:

Name Type Description Default
state State

`State` object corresponding to the new state of the environment.

required
new_player_pos Position

the position of the player after the last action.

required

Returns:

Name Type Description
collision Any

the resulting player position: the proposed position if the target cell is open, otherwise the player's current position.

Source code in jumanji/environments/routing/pac_man/env.py
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
def check_wall_collisions(self, state: State, new_player_pos: Position) -> Any:
    """
    Check if the new player position collides with a wall.

    Args:
        state: `State` object corresponding to the new state of the environment.
        new_player_pos: the position of the player after the last action.

    Returns:
        collision: the resulting player position: the proposed position if the target
            cell is open, otherwise the player's current position.
    """

    grid = state.grid
    location_value = grid[new_player_pos.x, new_player_pos.y]

    collision = jax.lax.cond(
        location_value == 1,
        lambda x: new_player_pos,
        lambda x: state.player_locations,
        0,
    )
    return collision

close() #

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/routing/pac_man/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

player_step(state, action, steps=1) #

Compute the new position of the player based on the given state and action.

Parameters:

Name Type Description Default
state State

`State` object corresponding to the new state of the environment.

required
action int

an integer between 0 and 4 representing the player's chosen action.

required
steps int

how many steps ahead of current position to search.

1

Returns:

Name Type Description
new_pos Position

a Position object representing the new position of the player.

Source code in jumanji/environments/routing/pac_man/env.py
def player_step(self, state: State, action: int, steps: int = 1) -> Position:
    """
    Compute the new position of the player based on the given state and action.

    Args:
        state: `State` object corresponding to the new state of the environment.
        action: an integer between 0 and 4 representing the player's chosen action.
        steps: how many steps ahead of current position to search.

    Returns:
        new_pos: a `Position` object representing the new position of the player.
    """

    position = state.player_locations

    move_left = lambda position: (position.y, position.x - steps)
    move_up = lambda position: (position.y - steps, position.x)
    move_right = lambda position: (position.y, position.x + steps)
    move_down = lambda position: (position.y + steps, position.x)
    no_op = lambda position: (position.y, position.x)

    new_pos_row, new_pos_col = jax.lax.switch(
        action, [move_left, move_up, move_right, move_down, no_op], position
    )

    new_pos = Position(x=new_pos_col % self.x_size, y=new_pos_row % self.y_size)
    return new_pos
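
Note the modulo in the final line: coordinates wrap around the grid, so stepping off one edge re-enters from the opposite side (the tunnel rows of the classic maze). A sketch of the arithmetic, assuming the default 28-column maze:

x_size = 28
new_x = (0 - 1) % x_size  # moving left from x = 0 wraps to x = 27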

render(state) #

Render the given state of the environment.

Parameters:

Name Type Description Default
state State

state object containing the current environment state.

required
Source code in jumanji/environments/routing/pac_man/env.py
def render(self, state: State) -> Any:
    """Render the given state of the environment.

    Args:
        state: `state` object containing the current environment state.
    """
    return self._viewer.render(state)

reset(key) #

Resets the environment by calling the instance generator for a new instance.

Parameters:

Name Type Description Default
key PRNGKey

A PRNGKey to use for random number generation.

required

Returns:

Name Type Description
state State

State object corresponding to the new state of the environment after a reset.

timestep TimeStep[Observation]

TimeStep object corresponding to the first timestep returned by the environment after a reset.

Source code in jumanji/environments/routing/pac_man/env.py
def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment by calling the instance generator for a new instance.

    Args:
        key: A PRNGKey to use for random number generation.

    Returns:
        state: `State` object corresponding to the new state of the environment after a reset.
        timestep: `TimeStep` object corresponding to the first timestep returned by the environment
            after a reset.
    """

    state = self.generator(key)

    # Generate observation
    obs = self._observation_from_state(state)

    # Return a restart timestep whose step type is FIRST.
    timestep = restart(observation=obs)

    return state, timestep

step(state, action) #

Run one timestep of the environment's dynamics.

If an action is invalid, the agent does not move, i.e. the episode does not automatically terminate.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environment.

required
action Array

(int32) specifying which action to take: [0,1,2,3,4] correspond to [Up, Right, Down, Left, No-op]. If an invalid action is taken, i.e. there is a wall blocking the action, then no action (no-op) is taken.

required

Returns:

Name Type Description
state State

the new state of the environment.

TimeStep[Observation]

the next timestep to be observed.

Source code in jumanji/environments/routing/pac_man/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    If an action is invalid, the agent does not move, i.e. the episode does not
    automatically terminate.

    Args:
        state: State object containing the dynamics of the environment.
        action: (int32) specifying which action to take: [0,1,2,3,4] correspond to
            [Up, Right, Down, Left, No-op]. If an invalid action is taken, i.e. there is a wall
            blocking the action, then no action (no-op) is taken.

    Returns:
        state: the new state of the environment.
        the next timestep to be observed.
    """

    # Collect updated state based on environment dynamics
    updated_state, collision_rewards = self._update_state(state, action)

    # Create next_state from updated state
    next_state = updated_state.replace(step_count=state.step_count + 1)  # type: ignore

    # Check if episode terminates
    num_pellets = next_state.pellets
    time_limit_exceeded = next_state.step_count >= self.time_limit
    all_pellets_found = num_pellets == 0
    dead = next_state.dead == 1
    done = time_limit_exceeded | dead | all_pellets_found

    reward = jnp.asarray(collision_rewards)
    # Generate observation from the state
    observation = self._observation_from_state(next_state)

    # Return either a MID or a LAST timestep depending on done.
    timestep = jax.lax.cond(
        done,
        termination,
        transition,
        reward,
        observation,
    )

    return next_state, timestep
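
Putting reset and step together, a random-rollout sketch (the uniform-random policy is illustrative only):

import jax

key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
step_fn = jax.jit(env.step)
while not timestep.last():
    key, subkey = jax.random.split(key)
    # Invalid moves resolve to no-ops inside the environment, as documented above.
    action = jax.random.randint(subkey, (), 0, env.action_spec.num_values)
    state, timestep = step_fn(state, action)
print(state.score)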