
Cleaner

Bases: Environment[State, MultiDiscreteArray, Observation]

A JAX implementation of the 'Cleaner' game where multiple agents have to clean all tiles of a maze.

  • observation: Observation

    • grid: jax array (int8) of shape (num_rows, num_cols) contains the state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall.
    • agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board.
    • action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed.
    • step_count: (int32) the number of steps since the beginning of the episode.
  • action: jax array (int32) of shape (num_agents,) the action for each agent: (0: up, 1: right, 2: down, 3: left)

  • reward: jax array (float) of shape (): +1 every time a tile is cleaned, minus a configurable penalty (0.5 by default) at each timestep.

  • episode termination:

    • All tiles are clean.
    • The number of steps is greater than the limit.
    • An invalid action is selected for any of the agents.
  • state: State

    • grid: jax array (int8) of shape (num_rows, num_cols) contains the current state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall.
    • agents_locations: jax array (int32) of shape (num_agents, 2) contains the location of each agent on the board.
    • action_mask: jax array (bool) of shape (num_agents, 4) indicates for each agent if each of the four actions (up, right, down, left) is allowed.
    • step_count: jax array (int32) of shape () the number of steps since the beginning of the episode.
    • key: jax array (uint) of shape (2,) jax random generation key. Ignored since the environment is deterministic.
import jax
from jumanji.environments import Cleaner
env = Cleaner()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
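The reward described above (+1 per newly cleaned tile, minus the per-timestep penalty) can be sketched in plain NumPy. `compute_reward` is a hypothetical helper for illustration, not part of the Jumanji API:

```python
import numpy as np

# Tile encoding from the observation spec above
DIRTY, CLEAN, WALL = 0, 1, 2

def compute_reward(prev_grid: np.ndarray, grid: np.ndarray, penalty: float = 0.5) -> float:
    """+1 for every tile that flipped from dirty to clean, minus the per-timestep penalty."""
    newly_cleaned = int(np.sum((prev_grid == DIRTY) & (grid == CLEAN)))
    return newly_cleaned - penalty

prev_grid = np.array([[1, 0], [0, 2]], dtype=np.int8)
grid = np.array([[1, 1], [0, 2]], dtype=np.int8)  # one tile was cleaned this step
compute_reward(prev_grid, grid)  # 1 - 0.5 = 0.5
```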

Instantiates a Cleaner environment.

Parameters:

  • generator (Optional[Generator], default None): Generator whose __call__ instantiates an environment instance. Implemented options are [RandomGenerator]. Defaults to RandomGenerator with num_rows=10, num_cols=10 and num_agents=3.
  • time_limit (Optional[int], default None): maximum number of steps in an episode. Defaults to num_rows * num_cols.
  • penalty_per_timestep (float, default 0.5): the penalty subtracted from the reward at each timestep.
  • viewer (Optional[Viewer[State]], default None): Viewer used for rendering. Defaults to CleanerViewer with "human" render mode.
Source code in jumanji/environments/routing/cleaner/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    time_limit: Optional[int] = None,
    penalty_per_timestep: float = 0.5,
    viewer: Optional[Viewer[State]] = None,
) -> None:
    """Instantiates a `Cleaner` environment.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment instance.
            Implemented options are [`RandomGenerator`]. Defaults to `RandomGenerator` with
            `num_rows=10`, `num_cols=10` and `num_agents=3`.
        time_limit: max number of steps in an episode. Defaults to `num_rows * num_cols`.
        penalty_per_timestep: the penalty returned at each timestep in the reward.
        viewer: `Viewer` used for rendering. Defaults to `CleanerViewer` with "human" render
            mode.
    """
    self.generator = generator or RandomGenerator(num_rows=10, num_cols=10, num_agents=3)
    self.num_agents = self.generator.num_agents
    self.num_rows = self.generator.num_rows
    self.num_cols = self.generator.num_cols
    self.grid_shape = (self.num_rows, self.num_cols)
    self.time_limit = time_limit or (self.num_rows * self.num_cols)
    super().__init__()
    self.penalty_per_timestep = penalty_per_timestep

    # Create viewer used for rendering
    self._viewer = viewer or CleanerViewer("Cleaner", render_mode="human")

action_spec: specs.MultiDiscreteArray cached property #

Specification of the action for the Cleaner environment.

Returns:

  • action_spec (MultiDiscreteArray): a specs.MultiDiscreteArray spec.
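A MultiDiscreteArray spec means each agent independently picks one of the four moves. The masked-sampling helper below is a hypothetical NumPy sketch (not a Jumanji function) of drawing a random valid action per agent from an action_mask of shape (num_agents, 4):

```python
import numpy as np

def sample_valid_actions(action_mask: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """For each agent, pick a uniformly random action among those the mask allows."""
    num_agents, _ = action_mask.shape
    actions = np.empty(num_agents, dtype=np.int32)
    for agent in range(num_agents):
        valid = np.flatnonzero(action_mask[agent])  # indices of allowed actions
        actions[agent] = rng.choice(valid)
    return actions

mask = np.array([[True, False, True, False],   # agent 0 may go up or down
                 [False, False, False, True]])  # agent 1 may only go left
actions = sample_valid_actions(mask, np.random.default_rng(0))
# actions[0] is 0 or 2; actions[1] is always 3
```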

observation_spec: specs.Spec[Observation] cached property #

Specification of the observation of the Cleaner environment.

Returns:

  • Spec[Observation]: spec for the Observation, consisting of the fields:
    • grid: BoundedArray (int8) of shape (num_rows, num_cols). Values are between 0 and 2 (inclusive).
    • agents_locations: BoundedArray (int32) of shape (num_agents, 2). Maximum value for the first column is num_rows, and for the second is num_cols.
    • action_mask: BoundedArray (bool) of shape (num_agents, 4).
    • step_count: BoundedArray (int32) of shape ().
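To make the bounds above concrete, here is a hypothetical checker (not part of Jumanji; NumPy arrays stand in for JAX arrays) that verifies an observation against the shapes and value ranges the spec describes:

```python
import numpy as np

def matches_observation_spec(grid, agents_locations, action_mask, step_count,
                             num_rows, num_cols, num_agents) -> bool:
    """Check a NumPy stand-in for an Observation against the bounds listed above."""
    checks = [
        grid.shape == (num_rows, num_cols),
        0 <= grid.min() and grid.max() <= 2,            # 0: dirty, 1: clean, 2: wall
        agents_locations.shape == (num_agents, 2),
        bool((agents_locations[:, 0] <= num_rows).all()),
        bool((agents_locations[:, 1] <= num_cols).all()),
        action_mask.shape == (num_agents, 4) and action_mask.dtype == bool,
        np.ndim(step_count) == 0,                       # scalar, shape ()
    ]
    return all(checks)

grid = np.zeros((3, 3), dtype=np.int8)
grid[0, 0] = 1  # start tile is clean
matches_observation_spec(grid, np.zeros((2, 2), dtype=np.int32),
                         np.ones((2, 4), dtype=bool), np.int32(0), 3, 3, 2)  # True
```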

animate(states, interval=200, save_path=None) #

Creates an animated gif of the Cleaner environment based on the sequence of states.

Parameters:

  • states (Sequence[State], required): sequence of environment states corresponding to consecutive timesteps.
  • interval (int, default 200): delay between frames in milliseconds.
  • save_path (Optional[str], default None): the path where the animation file should be saved. If None, the plot will not be saved.

Returns:

  • animation.FuncAnimation: the animation object that was created.

Source code in jumanji/environments/routing/cleaner/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the `Cleaner` environment based on the sequence of states.

    Args:
        states: sequence of environment states corresponding to consecutive timesteps.
        interval: delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states, interval, save_path)

close() #

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/routing/cleaner/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state) #

Render the given state of the environment.

Parameters:

  • state (State, required): State object containing the current environment state.
Source code in jumanji/environments/routing/cleaner/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Render the given state of the environment.

    Args:
        state: `State` object containing the current environment state.
    """
    return self._viewer.render(state)

reset(key) #

Reset the environment to its initial state.

All the tiles except upper left are dirty, and the agents start in the upper left corner of the grid.

Parameters:

  • key (PRNGKey, required): random key used to reset the environment.

Returns:

  • state (State): State object corresponding to the new state of the environment after a reset.
  • timestep (TimeStep[Observation]): TimeStep object corresponding to the first timestep returned by the environment after a reset.

Source code in jumanji/environments/routing/cleaner/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Reset the environment to its initial state.

    All the tiles except upper left are dirty, and the agents start in the upper left
    corner of the grid.

    Args:
        key: random key used to reset the environment.

    Returns:
        state: `State` object corresponding to the new state of the environment after a reset.
        timestep: `TimeStep` object corresponding to the first timestep returned by the
            environment after a reset.
    """
    # Agents start in upper left corner
    agents_locations = jnp.zeros((self.num_agents, 2), int)

    state = self.generator(key)

    # Create the action mask and update the state
    state.action_mask = self._compute_action_mask(state.grid, agents_locations)

    observation = self._observation_from_state(state)

    extras = self._compute_extras(state)
    timestep = restart(observation, extras)

    return state, timestep
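The `_compute_action_mask` helper called above is internal to Jumanji. A plain-NumPy sketch of the likely rule — a move is valid iff the target cell is inside the grid and not a wall — could look like this (`compute_action_mask` and `MOVES` are illustrative names, not Jumanji API):

```python
import numpy as np

WALL = 2
# (row, col) offsets for (up, right, down, left), matching the action order above
MOVES = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])

def compute_action_mask(grid: np.ndarray, agents_locations: np.ndarray) -> np.ndarray:
    """An action is allowed iff the target cell is inside the grid and not a wall."""
    num_rows, num_cols = grid.shape
    mask = np.zeros((len(agents_locations), 4), dtype=bool)
    for agent, (row, col) in enumerate(agents_locations):
        for action, (dr, dc) in enumerate(MOVES):
            r, c = row + dr, col + dc
            if 0 <= r < num_rows and 0 <= c < num_cols and grid[r, c] != WALL:
                mask[agent, action] = True
    return mask

grid = np.array([[0, 0], [2, 0]], dtype=np.int8)  # wall at (1, 0)
compute_action_mask(grid, np.array([[0, 0]]))
# agent at (0, 0): only "right" is valid -> [[False, True, False, False]]
```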

step(state, action) #

Run one timestep of the environment's dynamics.

If an action is invalid, the corresponding agent does not move and the episode terminates.

Parameters:

  • state (State, required): current environment state.
  • action (Array, required): jax array of shape (num_agents,); each agent moves one step in the specified direction (0: up, 1: right, 2: down, 3: left).

Returns:

  • state (State): State object corresponding to the next state of the environment.
  • timestep (TimeStep[Observation]): TimeStep object corresponding to the timestep returned by the environment.

Source code in jumanji/environments/routing/cleaner/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    If an action is invalid, the corresponding agent does not move and
    the episode terminates.

    Args:
        state: current environment state.
        action: Jax array of shape (num_agents,). Each agent moves one step in
            the specified direction (0: up, 1: right, 2: down, 3: left).

    Returns:
        state: `State` object corresponding to the next state of the environment.
        timestep: `TimeStep` object corresponding to the timestep returned by the environment.
    """
    is_action_valid = self._is_action_valid(action, state.action_mask)

    agents_locations = self._update_agents_locations(
        state.agents_locations, action, is_action_valid
    )

    grid = self._clean_tiles_containing_agents(state.grid, agents_locations)

    prev_state = state

    state = State(
        agents_locations=agents_locations,
        grid=grid,
        action_mask=self._compute_action_mask(grid, agents_locations),
        step_count=state.step_count + 1,
        key=state.key,
    )

    reward = self._compute_reward(prev_state, state)

    observation = self._observation_from_state(state)

    done = self._should_terminate(state, is_action_valid)

    extras = self._compute_extras(state)
    # Return either a MID or a LAST timestep depending on done.
    timestep = jax.lax.cond(
        done,
        lambda reward, observation, extras: termination(
            reward=reward,
            observation=observation,
            extras=extras,
        ),
        lambda reward, observation, extras: transition(
            reward=reward,
            observation=observation,
            extras=extras,
        ),
        reward,
        observation,
        extras,
    )

    return state, timestep
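The `done` flag computed by `_should_terminate` above combines the three episode-termination conditions listed at the top of the page. A hypothetical NumPy sketch (the `>=` comparison against the time limit is an assumption; Jumanji's internal implementation may differ):

```python
import numpy as np

DIRTY = 0

def should_terminate(grid, step_count, time_limit, all_actions_valid) -> bool:
    """Episode ends when every tile is clean, the time limit is reached,
    or any agent picked an invalid action."""
    all_clean = not bool(np.any(grid == DIRTY))
    return all_clean or step_count >= time_limit or not all_actions_valid

should_terminate(np.array([[1, 2], [1, 1]]), 3, 100, True)  # True: no dirty tile left
```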