
Level-Based Foraging

LevelBasedForaging (Environment) #

An implementation of the Level-Based Foraging environment where agents need to cooperate to collect food and split the reward.

Original implementation: https://github.com/semitable/lb-foraging

  • observation: Observation

    • agent_views: Depending on the observer passed to __init__, it can be a GridObserver or a VectorObserver.
      • GridObserver: Returns an agent's view with a shape of (num_agents, 3, 2 * fov + 1, 2 * fov + 1).
      • VectorObserver: Returns an agent's view with a shape of (num_agents, 3 * (num_food + num_agents)).
    • action_mask: JAX array (bool) of shape (num_agents, 6) indicating, for each agent, which of the six actions (no-op, up, down, left, right, load) are allowed.
    • step_count: int32, the number of steps since the beginning of the episode.
  • action: JAX array (int32) of shape (num_agents,). The valid actions for each agent are (0: noop, 1: up, 2: down, 3: left, 4: right, 5: load).

  • reward: JAX array (float) of shape (num_agents,). When one or more agents load a food item, the food's level is split among the loading agents as reward, weighted by each agent's level. The reward is then normalized so that, if all food items are collected, the rewards over the episode sum to one.

  • Episode Termination:

    • All food items have been eaten.
    • The number of steps is greater than the limit.
  • state: State

    • agents: Stacked Pytree of Agent objects of length num_agents.
      • Agent:
        • id: JAX array (int32) of shape ().
        • position: JAX array (int32) of shape (2,).
        • level: JAX array (int32) of shape ().
        • loading: JAX array (bool) of shape ().
    • food_items: Stacked Pytree of Food objects of length num_food.
      • Food:
        • id: JAX array (int32) of shape ().
        • position: JAX array (int32) of shape (2,).
        • level: JAX array (int32) of shape ().
        • eaten: JAX array (bool) of shape ().
    • step_count: JAX array (int32) of shape (), the number of steps since the beginning of the episode.
    • key: JAX array (uint) of shape (2,), the JAX random generation key. Ignored since the environment is deterministic.
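The reward rule above can be sketched as follows. This is an illustrative reading, not the environment's actual implementation, and `split_food_reward` is a hypothetical helper; NumPy is used for clarity:

```python
import numpy as np

def split_food_reward(food_level, loading_agent_levels, total_food_level):
    # Hypothetical helper (not part of the environment API): each loading
    # agent's share is proportional to its level relative to the other
    # agents loading the same food item.
    shares = loading_agent_levels / loading_agent_levels.sum()
    rewards = food_level * shares
    # Normalizing by the total food level makes the per-step rewards sum
    # to one across the episode once every food item has been collected.
    return rewards / total_food_level

# A level-4 food item loaded jointly by a level-1 and a level-3 agent,
# in an episode whose food levels sum to 6.
rewards = split_food_reward(4, np.array([1.0, 3.0]), total_food_level=6)
```

Here the level-3 agent receives three times the reward of the level-1 agent, and the two rewards together contribute 4/6 of the episode's total reward of one.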

Examples:

import jax
from jumanji.environments import LevelBasedForaging

env = LevelBasedForaging()
key = jax.random.key(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)

Initialization Args:

  • generator: A Generator object that generates the initial state of the environment. Defaults to a RandomGenerator with the following parameters:
    • grid_size: 8
    • fov: 8 (full observation of the grid)
    • num_agents: 2
    • num_food: 2
    • max_agent_level: 2
    • force_coop: True
  • time_limit: The maximum number of steps in an episode. Defaults to 200.
  • grid_observation: If True, the observer generates a grid observation (default is False).
  • normalize_reward: If True, normalizes the reward (default is True).
  • penalty: The penalty value (default is 0.0).
  • viewer: Viewer to render the environment. Defaults to LevelBasedForagingViewer.
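Putting those defaults together, constructing the environment with an explicit generator might look like the sketch below. The keyword arguments are taken from the parameter names listed above and the __init__ signature further down; treat them as assumptions rather than a verified call for any particular Jumanji version:

```python
from jumanji.environments import LevelBasedForaging
# Import path inferred from the __init__ signature below; may vary by version.
from jumanji.environments.routing.lbf.generator import RandomGenerator

# Generator parameters mirror the documented RandomGenerator defaults.
generator = RandomGenerator(
    grid_size=8,
    fov=8,
    num_agents=2,
    num_food=2,
    max_agent_level=2,
    force_coop=True,
)
env = LevelBasedForaging(
    generator=generator,
    grid_observation=False,
    normalize_reward=True,
    penalty=0.0,
)
```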

observation_spec: jumanji.specs.Spec[jumanji.environments.routing.lbf.types.Observation] cached property writable #

Specifications of the observation of the environment.

The spec's shape depends on the observer passed to __init__.

The GridObserver returns an agent's view with a shape of (num_agents, 3, 2 * fov + 1, 2 * fov + 1). The VectorObserver returns an agent's view with a shape of (num_agents, 3 * num_food + 3 * num_agents). See a more detailed description of the observations in the docs of GridObserver and VectorObserver.
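As a quick sanity check, the shape arithmetic for the default configuration (num_agents=2, num_food=2, fov=8, per the defaults listed above) works out as follows:

```python
# Default configuration from the initialization args above.
num_agents, num_food, fov = 2, 2, 8

# GridObserver: one (2 * fov + 1)-sized square view with 3 channels per agent.
grid_view_shape = (num_agents, 3, 2 * fov + 1, 2 * fov + 1)

# VectorObserver: 3 values per entity (food items and agents) per agent.
vector_view_shape = (num_agents, 3 * num_food + 3 * num_agents)
```

With fov=8 the grid view is (2, 3, 17, 17) and the vector view is (2, 12).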

Returns:

Type Description
specs.Spec[Observation]

Spec for the Observation with fields grid, action_mask, and step_count.

action_spec: MultiDiscreteArray cached property writable #

Returns the action spec for the Level Based Foraging environment.

Returns:

Type Description
specs.MultiDiscreteArray

Action spec for the environment with shape (num_agents,).
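This spec pairs naturally with the action_mask field of the observation: a common pattern is to sample only among valid actions. A minimal NumPy sketch (the mask values below are made up for illustration and are not produced by the environment):

```python
import numpy as np

# Hypothetical mask for two agents over the six actions
# (noop, up, down, left, right, load); True marks an allowed action.
action_mask = np.array([
    [True, True, False, True, False, True],
    [True, False, True, True, True, False],
])

def sample_valid_actions(mask, rng):
    # Turn each row of the boolean mask into a uniform distribution
    # over the allowed actions, then sample one action per agent.
    probs = mask / mask.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(mask.shape[-1], p=p) for p in probs])

actions = sample_valid_actions(action_mask, np.random.default_rng(0))
```

The result has shape (num_agents,), matching the action spec, and every sampled action is allowed by the mask.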

__init__(self, generator: Optional[jumanji.environments.routing.lbf.generator.RandomGenerator] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.lbf.types.State]] = None, time_limit: int = 100, grid_observation: bool = False, normalize_reward: bool = True, penalty: float = 0.0) -> None special #

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.lbf.types.State, jumanji.types.TimeStep] #

Resets the environment.

Parameters:

Name Type Description Default
key chex.PRNGKey

Used to randomly generate the new State.

required

Returns:

Type Description
Tuple[State, TimeStep]

State object corresponding to the new initial state of the environment and TimeStep object corresponding to the initial timestep.

step(self, state: State, actions: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.routing.lbf.types.State, jumanji.types.TimeStep] #

Simulate one step of the environment.

Parameters:

Name Type Description Default
state State

State containing the dynamics of the environment.

required
actions chex.Array

Array containing the actions to take for each agent.

required

Returns:

Type Description
Tuple[State, TimeStep]

State object corresponding to the next state and TimeStep object corresponding to the timestep returned by the environment.

render(self, state: State) -> Optional[numpy.ndarray[Any, numpy.dtype[+ScalarType]]] #

Renders the current state of the LevelBasedForaging environment.

Parameters:

Name Type Description Default
state State

The current environment state to be rendered.

required

Returns:

Type Description
Optional[NDArray]

Rendered environment state.


Last update: 2024-11-01