
Cleaner (Environment) #

A JAX implementation of the 'Cleaner' game where multiple agents have to clean all tiles of a maze.

  • observation: Observation

    • grid: jax array (int8) of shape (num_rows, num_cols) contains the state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall.
    • agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board.
    • action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed.
    • step_count: (int32) the number of steps since the beginning of the episode.
  • action: jax array (int32) of shape (num_agents,) the action for each agent: (0: up, 1: right, 2: down, 3: left)

  • reward: jax array (float) of shape (): +1 for each tile cleaned during the step, minus a configurable penalty (0.5 by default) applied at every timestep.

  • episode termination:

    • All tiles are clean.
    • The number of steps is greater than the limit.
    • An invalid action is selected for any of the agents.
  • state: State

    • grid: jax array (int8) of shape (num_rows, num_cols) contains the current state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall.
    • agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board.
    • action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed.
    • step_count: jax array (int32) of shape () the number of steps since the beginning of the episode.
    • key: jax array (uint) of shape (2,) jax random generation key. Ignored since the environment is deterministic.
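The grid encoding above can be sketched in pure JAX: the "all tiles are clean" termination condition reduces to checking that no tile still holds the value 0. The 3x3 board below is a made-up example for illustration, not one produced by the environment.

```python
import jax.numpy as jnp

# Hypothetical 3x3 board using the encoding above:
# 0 = dirty, 1 = clean, 2 = wall.
grid = jnp.array(
    [[1, 0, 2],
     [1, 0, 2],
     [1, 1, 0]],
    dtype=jnp.int8,
)

# The "all clean" termination condition: no tile is still dirty (0).
all_clean = jnp.all(grid != 0)
num_dirty = jnp.sum(grid == 0)
```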
import jax

from jumanji.environments import Cleaner

env = Cleaner()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)

observation_spec: jumanji.specs.Spec[jumanji.environments.routing.cleaner.types.Observation] cached property writable #

Specification of the observation of the Cleaner environment.

Returns:

Type Description
Spec for the `Observation`, consisting of the fields:
  • grid: BoundedArray (int8) of shape (num_rows, num_cols). Values are between 0 and 2 (inclusive).
  • agents_locations: BoundedArray (int32) of shape (num_agents, 2). The maximum value for the first column is num_rows, and for the second column, num_cols.
  • action_mask: BoundedArray (bool) of shape (num_agents, 4).
  • step_count: BoundedArray (int32) of shape ().

action_spec: MultiDiscreteArray cached property writable #

Specification of the action for the Cleaner environment.

Returns:

Type Description
action_spec

a specs.MultiDiscreteArray spec.
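Since each agent picks one of four discrete moves, a random joint action matching this spec can be sampled directly with JAX. Here num_agents = 3 is a placeholder matching the default configuration; in practice the shape comes from env.action_spec.

```python
import jax

num_agents = 3  # placeholder; matches the default Cleaner configuration
key = jax.random.PRNGKey(42)

# Sample one of the four moves (0: up, 1: right, 2: down, 3: left)
# independently for each agent.
actions = jax.random.randint(key, shape=(num_agents,), minval=0, maxval=4)
```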

__init__(self, generator: Optional[jumanji.environments.routing.cleaner.generator.Generator] = None, time_limit: Optional[int] = None, penalty_per_timestep: float = 0.5, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.cleaner.types.State]] = None) -> None special #

Instantiates a Cleaner environment.

Parameters:

Name Type Description Default
time_limit Optional[int]

max number of steps in an episode. Defaults to num_rows * num_cols.

None
generator Optional[jumanji.environments.routing.cleaner.generator.Generator]

Generator whose __call__ instantiates an environment instance. Implemented options are [RandomGenerator]. Defaults to RandomGenerator with num_rows=10, num_cols=10 and num_agents=3.

None
viewer Optional[jumanji.viewer.Viewer[jumanji.environments.routing.cleaner.types.State]]

Viewer used for rendering. Defaults to CleanerViewer with "human" render mode.

None
penalty_per_timestep float

penalty subtracted from the reward at each timestep.

0.5

__repr__(self) -> str special #

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.cleaner.types.State, jumanji.types.TimeStep[jumanji.environments.routing.cleaner.types.Observation]] #

Reset the environment to its initial state.

All tiles except the upper-left one are dirty, and the agents start in the upper-left corner of the grid.

Parameters:

Name Type Description Default
key PRNGKeyArray

random key used to reset the environment.

required

Returns:

Type Description
state

State object corresponding to the new state of the environment after a reset. timestep: TimeStep object corresponding to the first timestep returned by the environment after a reset.

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.routing.cleaner.types.State, jumanji.types.TimeStep[jumanji.environments.routing.cleaner.types.Observation]] #

Run one timestep of the environment's dynamics.

If an action is invalid, the corresponding agent does not move and the episode terminates.

Parameters:

Name Type Description Default
state State

current environment state.

required
action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]

Jax array of shape (num_agents,). Each agent moves one step in the specified direction (0: up, 1: right, 2: down, 3: left).

required

Returns:

Type Description
state

State object corresponding to the next state of the environment. timestep: TimeStep object corresponding to the timestep returned by the environment.
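The direction encoding can be pictured as adding a row/column delta to each agent's location. This is an illustrative pure-JAX sketch of that convention, not the library's internal implementation; wall collisions, invalid-action handling, and the grid update are omitted.

```python
import jax.numpy as jnp

# Row/column deltas matching the action encoding:
# index 0: up, 1: right, 2: down, 3: left.
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=jnp.int32)

# Two hypothetical agents at (0, 0) and (2, 3).
agents_locations = jnp.array([[0, 0], [2, 3]], dtype=jnp.int32)
action = jnp.array([2, 3], dtype=jnp.int32)  # agent 0: down, agent 1: left

new_locations = agents_locations + MOVES[action]
```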

render(self, state: State) -> Optional[numpy.ndarray[Any, numpy.dtype[+ScalarType]]] #

Render the given state of the environment.

Parameters:

Name Type Description Default
state State

State object containing the current environment state.

required

animate(self, states: Sequence[jumanji.environments.routing.cleaner.types.State], interval: int = 200, save_path: Optional[str] = None) -> FuncAnimation #

Creates an animated gif of the Cleaner environment based on the sequence of states.

Parameters:

Name Type Description Default
states Sequence[jumanji.environments.routing.cleaner.types.State]

sequence of environment states corresponding to consecutive timesteps.

required
interval int

delay between frames in milliseconds. Defaults to 200.

200
save_path Optional[str]

the path where the animation file should be saved. If it is None, the plot will not be saved.

None

Returns:

Type Description
animation.FuncAnimation

the animation object that was created.

close(self) -> None #

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.


Last update: 2024-03-21