
Level-Based Foraging

Bases: Environment[State, MultiDiscreteArray, Observation]

An implementation of the Level-Based Foraging environment where agents need to cooperate to collect food and split the reward.

Original implementation: https://github.com/semitable/lb-foraging

  • observation: Observation

    • agent_views: Depending on the observer passed to __init__, it can be a GridObserver or a VectorObserver.
      • GridObserver: Returns an agent's view with a shape of (num_agents, 3, 2 * fov + 1, 2 * fov + 1).
      • VectorObserver: Returns an agent's view with a shape of (num_agents, 3 * (num_food + num_agents)).
    • action_mask: JAX array (bool) of shape (num_agents, 6) indicating for each agent which of the six actions (no-op, up, down, left, right, load) are allowed.
    • step_count: int32, the number of steps since the beginning of the episode.
  • action: JAX array (int32) of shape (num_agents,). The valid actions for each agent are (0: noop, 1: up, 2: down, 3: left, 4: right, 5: load).

  • reward: JAX array (float) of shape (num_agents,). When one or more agents load a food item, the food's level is awarded to those agents, weighted by each agent's level (see the worked sketch after this list). The reward is then normalized so that, once all food items have been collected, the rewards sum to one.

  • Episode Termination:

    • All food items have been eaten.
    • The number of steps reaches the time limit (in which case the episode is truncated rather than terminated).
  • state: State

    • agents: Stacked Pytree of Agent objects of length num_agents.
      • Agent:
        • id: JAX array (int32) of shape ().
        • position: JAX array (int32) of shape (2,).
        • level: JAX array (int32) of shape ().
        • loading: JAX array (bool) of shape ().
    • food_items: Stacked Pytree of Food objects of length num_food.
      • Food:
        • id: JAX array (int32) of shape ().
        • position: JAX array (int32) of shape (2,).
        • level: JAX array (int32) of shape ().
        • eaten: JAX array (bool) of shape ().
    • step_count: JAX array (int32) of shape (), the number of steps since the beginning of the episode.
    • key: JAX array (uint) of shape (2,), the JAX random generation key. Ignored since the environment is deterministic.
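
As a hand-worked sketch of the reward normalization (illustrative values, not taken from the library): suppose two agents with levels 1 and 2 both load a level-3 food, and that food is the only one on the grid.

import jax.numpy as jnp

# Hypothetical values for illustration only.
adj_loading_agents_levels = jnp.array([1.0, 2.0])  # levels of the two loading agents
food_level = 3.0                                    # level of the food being eaten
total_food_level = 3.0                              # sum of all food levels on the grid

# Per-agent reward, weighted by agent level and normalized as described above:
reward = adj_loading_agents_levels * food_level / (
    jnp.sum(adj_loading_agents_levels) * total_food_level
)
# reward == [1/3, 2/3]; since this is the only food, the rewards sum to 1.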

Example:

import jax

from jumanji.environments import LevelBasedForaging
env = LevelBasedForaging()
key = jax.random.key(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)

Initialization Args:

  • generator: A Generator object that generates the initial state of the environment. Defaults to a RandomGenerator with the following parameters:
    • grid_size: 8
    • fov: 8 (full observation of the grid)
    • num_agents: 2
    • num_food: 2
    • max_agent_level: 2
    • force_coop: True
  • time_limit: The maximum number of steps in an episode. Defaults to 100.
  • grid_observation: If True, the observer generates a grid observation (default is False).
  • normalize_reward: If True, normalizes the reward (default is True).
  • penalty: The penalty value (default is 0.0).
  • viewer: Viewer to render the environment. Defaults to LevelBasedForagingViewer.
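
For instance, a custom instance might be built as below (a minimal sketch: the RandomGenerator import path is assumed from the source layout shown on this page, and only the generator arguments used in __init__ below are passed).

import jax

from jumanji.environments import LevelBasedForaging
# Assumed import path, based on jumanji/environments/routing/lbf/env.py shown below.
from jumanji.environments.routing.lbf.generator import RandomGenerator

# A larger, partially observable instance that uses grid observations.
env = LevelBasedForaging(
    generator=RandomGenerator(grid_size=12, fov=3, num_agents=4, num_food=4, force_coop=False),
    time_limit=200,
    grid_observation=True,
    normalize_reward=True,
    penalty=0.0,
)
state, timestep = jax.jit(env.reset)(jax.random.key(0))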

Source code in jumanji/environments/routing/lbf/env.py
def __init__(
    self,
    generator: Optional[RandomGenerator] = None,
    viewer: Optional[Viewer[State]] = None,
    time_limit: int = 100,
    grid_observation: bool = False,
    normalize_reward: bool = True,
    penalty: float = 0.0,
) -> None:
    self._generator = generator or RandomGenerator(
        grid_size=8,
        fov=8,
        num_agents=2,
        num_food=2,
        force_coop=True,
    )
    self.time_limit = time_limit
    self.grid_size: int = self._generator.grid_size
    self.num_agents: int = self._generator.num_agents
    self.num_food: int = self._generator.num_food
    self.fov = self._generator.fov
    self.normalize_reward = normalize_reward
    self.penalty = penalty

    self._observer: Union[VectorObserver, GridObserver]
    if not grid_observation:
        self._observer = VectorObserver(
            fov=self.fov,
            grid_size=self.grid_size,
            num_agents=self.num_agents,
            num_food=self.num_food,
        )
    else:
        self._observer = GridObserver(
            fov=self.fov,
            grid_size=self.grid_size,
            num_agents=self.num_agents,
            num_food=self.num_food,
        )

    super().__init__()

    # create viewer for rendering environment
    self._viewer = viewer or LevelBasedForagingViewer(self.grid_size, "LevelBasedForaging")

action_spec: specs.MultiDiscreteArray cached property #

Returns the action spec for the Level Based Foraging environment.

Returns:

  • specs.MultiDiscreteArray: Action spec for the environment with shape (num_agents,).
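
For example, the spec can be used to build joint actions (a minimal sketch, assuming an env instance like the one constructed in the example above):

import jax

spec = env.action_spec                # one value in {0, ..., 5} per agent
noop_actions = spec.generate_value()  # all zeros, i.e. no-op for every agent
# A uniformly random joint action, one of the 6 actions per agent:
random_actions = jax.random.randint(jax.random.key(1), (env.num_agents,), minval=0, maxval=6)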

discount_spec: specs.BoundedArray cached property #

Describes the discount returned by the environment.

Returns:

  • discount_spec: a specs.BoundedArray spec.

observation_spec: specs.Spec[Observation] cached property #

Specifications of the observation of the environment.

The spec's shape depends on the observer passed to __init__.

The GridObserver returns an agent's view with a shape of (num_agents, 3, 2 * fov + 1, 2 * fov + 1). The VectorObserver returns an agent's view with a shape of (num_agents, 3 * num_food + 3 * num_agents). See a more detailed description of the observations in the docs of GridObserver and VectorObserver.

Returns:

  • specs.Spec[Observation]: Spec for the Observation with fields agent_views, action_mask, and step_count.
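
A quick sanity check against the spec might look like the following (a minimal sketch, assuming the spec exposes a dm_env-style validate method and an env instance as in the example above):

import jax

state, timestep = jax.jit(env.reset)(jax.random.key(0))
# validate is assumed here; it raises if shapes or dtypes do not match the spec.
env.observation_spec.validate(timestep.observation)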

reward_spec: specs.Array cached property #

Returns the reward specification for the LevelBasedForaging environment.

Since this is a multi-agent environment each agent gets its own reward.

Returns:

  • specs.Array: Reward specification of shape (num_agents,) for the environment.

animate(states, interval=200, save_path=None) #

Creates an animation from a sequence of states.

Parameters:

  • states (Sequence[State]): Sequence of State corresponding to subsequent timesteps. Required.
  • interval (int): Delay between frames in milliseconds. Defaults to 200.
  • save_path (Optional[str]): The path where the animation file should be saved. Defaults to None.

Returns:

  • matplotlib.animation.FuncAnimation: Animation object that can be saved as a GIF, MP4, or rendered with HTML.

Source code in jumanji/environments/routing/lbf/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animation from a sequence of states.

    Args:
        states (Sequence[State]): Sequence of `State` corresponding to subsequent timesteps.
        interval (int): Delay between frames in milliseconds, default to 200.
        save_path (Optional[str]): The path where the animation file should be saved.

    Returns:
        matplotlib.animation.FuncAnimation: Animation object that can be saved as a GIF, MP4,
        or rendered with HTML.
    """
    return self._viewer.animate(states=states, interval=interval, save_path=save_path)
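
A typical use is to collect the states visited during a short rollout and animate them (an illustrative sketch, assuming an env instance as in the example above; the random actions stand in for a real policy):

import jax

key = jax.random.key(0)
state, timestep = jax.jit(env.reset)(key)
states = [state]
for step_key in jax.random.split(key, 10):
    actions = jax.random.randint(step_key, (env.num_agents,), minval=0, maxval=6)
    state, timestep = jax.jit(env.step)(state, actions)
    states.append(state)

animation = env.animate(states, interval=200, save_path="lbf_rollout.gif")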

close() #

Perform any necessary cleanup.

Source code in jumanji/environments/routing/lbf/env.py
def close(self) -> None:
    """Perform any necessary cleanup."""
    self._viewer.close()

get_reward(food_items, adj_loading_agents_levels, eaten_this_step) #

Returns a reward for all agents given all food items.

Parameters:

  • food_items (Food): All the food items in the environment. Required.
  • adj_loading_agents_levels (Array): The level of all agents adjacent to all foods. Required.
  • eaten_this_step (Array): Whether the food was eaten or not (this step). Required.
Source code in jumanji/environments/routing/lbf/env.py
def get_reward(
    self,
    food_items: Food,
    adj_loading_agents_levels: chex.Array,
    eaten_this_step: chex.Array,
) -> chex.Array:
    """Returns a reward for all agents given all food items.

    Args:
        food_items (Food): All the food items in the environment.
        adj_loading_agents_levels (chex.Array): The level of all agents adjacent to all foods.
        eaten_this_step (chex.Array): Whether the food was eaten or not (this step).
    """

    def get_reward_per_food(
        food: Food,
        adj_loading_agents_levels: chex.Array,
        eaten_this_step: chex.Array,
    ) -> chex.Array:
        """Returns the reward for all agents given a single food."""

        # If the food has already been eaten or is not loaded, the sum will be equal to 0
        sum_agents_levels = jnp.sum(adj_loading_agents_levels)

        # Penalize agents for not being able to cooperate and eat food
        penalty = jnp.where(
            (sum_agents_levels != 0) & (sum_agents_levels < food.level),
            self.penalty,
            0,
        )

        # Zero out all agents if food was not eaten and add penalty
        reward = (adj_loading_agents_levels * eaten_this_step * food.level) - penalty

        # jnp.nan_to_num: Used in the case where no agents are adjacent to the food
        normalizer = sum_agents_levels * total_food_level
        reward = jnp.where(self.normalize_reward, jnp.nan_to_num(reward / normalizer), reward)

        return reward

    # Get reward per food for all food items,
    # then sum it on the agent dimension to get reward per agent.
    total_food_level = jnp.sum(food_items.level)
    reward_per_food = jax.vmap(get_reward_per_food, in_axes=(0, 0, 0))(
        food_items, adj_loading_agents_levels, eaten_this_step
    )
    return jnp.sum(reward_per_food, axis=0)

render(state) #

Renders the current state of the LevelBasedForaging environment.

Parameters:

  • state (State): The current environment state to be rendered. Required.

Returns:

  • Optional[NDArray]: Rendered environment state.

Source code in jumanji/environments/routing/lbf/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Renders the current state of the `LevelBasedForaging` environment.

    Args:
        state (State): The current environment state to be rendered.

    Returns:
        Optional[NDArray]: Rendered environment state.
    """
    return self._viewer.render(state)

reset(key) #

Resets the environment.

Parameters:

  • key (PRNGKey): Used to randomly generate the new State. Required.

Returns:

  • Tuple[State, TimeStep]: State object corresponding to the new initial state of the environment and TimeStep object corresponding to the initial timestep.

Source code in jumanji/environments/routing/lbf/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep]:
    """Resets the environment.

    Args:
        key (chex.PRNGKey): Used to randomly generate the new `State`.

    Returns:
        Tuple[State, TimeStep]: `State` object corresponding to the new initial state
        of the environment and `TimeStep` object corresponding to the initial timestep.
    """
    state = self._generator(key)
    observation = self._observer.state_to_observation(state)
    timestep = restart(observation, shape=self.num_agents)
    timestep.extras = self._get_extra_info(state, timestep)

    return state, timestep

step(state, actions) #

Simulate one step of the environment.

Parameters:

  • state (State): State containing the dynamics of the environment. Required.
  • actions (Array): Array containing the actions to take for each agent. Required.

Returns:

  • Tuple[State, TimeStep]: State object corresponding to the next state and TimeStep object corresponding to the timestep returned by the environment.

Source code in jumanji/environments/routing/lbf/env.py
def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep]:
    """Simulate one step of the environment.

    Args:
        state (State): State  containing the dynamics of the environment.
        actions (chex.Array): Array containing the actions to take for each agent.

    Returns:
        Tuple[State, TimeStep]: `State` object corresponding to the next state and
        `TimeStep` object corresponding the timestep returned by the environment.
    """
    # Move agents, fix collisions that may happen and set loading status.
    moved_agents = utils.update_agent_positions(
        state.agents, actions, state.food_items, self.grid_size
    )

    # Eat the food
    food_items, eaten_this_step, adj_loading_agents_levels = jax.vmap(
        utils.eat_food, (None, 0)
    )(moved_agents, state.food_items)

    reward = self.get_reward(food_items, adj_loading_agents_levels, eaten_this_step)

    state = State(
        agents=moved_agents,
        food_items=food_items,
        step_count=state.step_count + 1,
        key=state.key,
    )
    observation = self._observer.state_to_observation(state)

    # First condition is termination, second is truncation.
    terminate = jnp.all(state.food_items.eaten)
    truncate = state.step_count >= self.time_limit

    timestep = jax.lax.switch(
        terminate + 2 * truncate,
        [
            # !terminate !trunc
            lambda rew, obs: transition(reward=rew, observation=obs, shape=self.num_agents),
            # terminate !truncate
            lambda rew, obs: termination(reward=rew, observation=obs, shape=self.num_agents),
            # !terminate truncate
            lambda rew, obs: truncation(reward=rew, observation=obs, shape=self.num_agents),
            # terminate truncate
            lambda rew, obs: termination(reward=rew, observation=obs, shape=self.num_agents),
        ],
        reward,
        observation,
    )
    timestep.extras = self._get_extra_info(state, timestep)

    return state, timestep
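
Putting reset and step together, a whole rollout can be compiled with jax.lax.scan (a minimal sketch with a uniformly random policy; it keeps stepping past episode end, which in practice would usually be handled with an auto-reset wrapper):

import jax
import jax.numpy as jnp

from jumanji.environments import LevelBasedForaging

env = LevelBasedForaging()

def rollout(key: jax.Array, num_steps: int = 100) -> jax.Array:
    """Roll out a random policy and return the per-agent sum of rewards."""
    keys = jax.random.split(key, num_steps + 1)
    state, _ = env.reset(keys[0])

    def step_fn(state, step_key):
        # One of the 6 actions (no-op, up, down, left, right, load) per agent.
        actions = jax.random.randint(step_key, (env.num_agents,), minval=0, maxval=6)
        next_state, timestep = env.step(state, actions)
        return next_state, timestep.reward

    _, rewards = jax.lax.scan(step_fn, state, keys[1:])  # rewards: (num_steps, num_agents)
    return jnp.sum(rewards, axis=0)

per_agent_returns = jax.jit(rollout)(jax.random.key(0))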