Level-Based Foraging
LevelBasedForaging (Environment)
An implementation of the Level-Based Foraging environment where agents need to cooperate to collect food and split the reward.
Original implementation: https://github.com/semitable/lb-foraging
- observation: Observation
    - agent_views: Depending on the observer passed to __init__, it can be a GridObserver or a VectorObserver.
        - GridObserver: Returns an agent's view with a shape of (num_agents, 3, 2 * fov + 1, 2 * fov + 1).
        - VectorObserver: Returns an agent's view with a shape of (num_agents, 3 * (num_food + num_agents)).
    - action_mask: JAX array (bool) of shape (num_agents, 6) indicating, for each agent, which of the six actions (no-op, up, down, left, right, load) are allowed.
    - step_count: int32, the number of steps since the beginning of the episode.
- action: JAX array (int32) of shape (num_agents,). The valid actions for each agent are (0: noop, 1: up, 2: down, 3: left, 4: right, 5: load).
- reward: JAX array (float) of shape (num_agents,). When one or more agents load food, the food level is rewarded to the agents, weighted by the level of each agent. The reward is then normalized so that, at the end, the sum of the rewards (if all food items have been picked up) is one.
- Episode Termination:
    - All food items have been eaten.
    - The number of steps is greater than the limit.
- state: State
    - agents: Stacked Pytree of Agent objects of length num_agents.
        - Agent:
            - id: JAX array (int32) of shape ().
            - position: JAX array (int32) of shape (2,).
            - level: JAX array (int32) of shape ().
            - loading: JAX array (bool) of shape ().
    - food_items: Stacked Pytree of Food objects of length num_food.
        - Food:
            - id: JAX array (int32) of shape ().
            - position: JAX array (int32) of shape (2,).
            - level: JAX array (int32) of shape ().
            - eaten: JAX array (bool) of shape ().
    - step_count: JAX array (int32) of shape (), the number of steps since the beginning of the episode.
    - key: JAX array (uint) of shape (2,), JAX random generation key. Ignored since the environment is deterministic.
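The reward rule described above can be made concrete with a small sketch. This is one reading of the description, not the environment's actual code: the function name `share_food_reward` and its arguments are hypothetical, and it uses plain Python rather than JAX arrays. Each loading agent receives a share of the food level proportional to its own level, and the result is divided by the total level of all food in the episode.

```python
from typing import Dict


def share_food_reward(
    food_level: int,
    loader_levels: Dict[int, int],
    total_food_level: int,
) -> Dict[int, float]:
    """Split one food item's level among the agents that loaded it.

    Hypothetical helper: `loader_levels` maps agent id -> agent level
    for the agents taking part in the load.
    """
    level_sum = sum(loader_levels.values())
    rewards = {}
    for agent_id, level in loader_levels.items():
        # Share of the food level, weighted by this agent's level.
        share = food_level * level / level_sum
        # Normalize by the total level of all food in the episode.
        rewards[agent_id] = share / total_food_level
    return rewards


# Two food items with levels 3 and 1, so the total food level is 4.
first = share_food_reward(3, {0: 1, 1: 2}, total_food_level=4)
second = share_food_reward(1, {0: 1}, total_food_level=4)
episode_total = sum(first.values()) + sum(second.values())
```

Under these assumptions, once both food items have been loaded the rewards handed out over the episode add up to one, which matches the normalization described above.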
Initialization Args:
- generator: A Generator object that generates the initial state of the environment. Defaults to a RandomGenerator with the following parameters:
    - grid_size: 8
    - fov: 8 (full observation of the grid)
    - num_agents: 2
    - num_food: 2
    - max_agent_level: 2
    - force_coop: True
- time_limit: The maximum number of steps in an episode. Defaults to 100.
- grid_observation: If True, the observer generates a grid observation (default is False).
- normalize_reward: If True, normalizes the reward (default is True).
- penalty: The penalty value (default is 0.0).
- viewer: Viewer to render the environment. Defaults to LevelBasedForagingViewer.
observation_spec: jumanji.specs.Spec[jumanji.environments.routing.lbf.types.Observation]
cached, property, writable
Specifications of the observation of the environment.
The spec's shape depends on the observer passed to __init__.
The GridObserver returns an agent's view with a shape of (num_agents, 3, 2 * fov + 1, 2 * fov + 1).
The VectorObserver returns an agent's view with a shape of (num_agents, 3 * num_food + 3 * num_agents).
See a more detailed description of the observations in the docs of GridObserver and VectorObserver.
Returns:
Type | Description |
---|---|
specs.Spec[Observation] | Spec for the observation. |
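To make the two shape formulas concrete, the sketch below evaluates them in plain Python. The helper names `grid_view_shape` and `vector_view_shape` are hypothetical and are not part of the library; they only restate the formulas given in this section.

```python
from typing import Tuple


def grid_view_shape(num_agents: int, fov: int) -> Tuple[int, int, int, int]:
    """Shape of the agents' views for a grid observation (hypothetical helper)."""
    side = 2 * fov + 1  # the view extends fov cells on each side of the agent
    return (num_agents, 3, side, side)


def vector_view_shape(num_agents: int, num_food: int) -> Tuple[int, int]:
    """Shape of the agents' views for a vector observation (hypothetical helper)."""
    # Three values per food item and three per agent.
    return (num_agents, 3 * (num_food + num_agents))


# Default generator parameters listed above: 2 agents, 2 food items, fov of 8.
grid_shape = grid_view_shape(num_agents=2, fov=8)          # (2, 3, 17, 17)
vector_shape = vector_view_shape(num_agents=2, num_food=2)  # (2, 12)
```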
action_spec: MultiDiscreteArray
cached, property, writable
Returns the action spec for the Level-Based Foraging environment.
Returns:
Type | Description |
---|---|
specs.MultiDiscreteArray | Action spec for the environment with shape (num_agents,). |
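As an illustration of how the six action indices relate to one agent's row of the action mask, here is a minimal sketch in plain Python. The `Action` enum and the `allowed_actions` helper are hypothetical names introduced for this example; only the index-to-action mapping (0: noop, 1: up, 2: down, 3: left, 4: right, 5: load) comes from the documentation.

```python
from enum import IntEnum
from typing import List, Sequence


class Action(IntEnum):
    """Action indices as documented for this environment."""

    NOOP = 0
    UP = 1
    DOWN = 2
    LEFT = 3
    RIGHT = 4
    LOAD = 5


def allowed_actions(mask_row: Sequence[bool]) -> List[Action]:
    """Return the actions marked True in one agent's mask row (hypothetical helper)."""
    if len(mask_row) != len(Action):
        raise ValueError("expected one mask entry per action")
    return [action for action in Action if mask_row[action]]


# Example mask row for a single agent: up and right are blocked.
example_mask = [True, False, True, True, False, True]
legal = allowed_actions(example_mask)
```

In this example `legal` contains NOOP, DOWN, LEFT and LOAD. With the real environment the mask row would come from the observation's action_mask field rather than a Python list.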
__init__(self, generator: Optional[jumanji.environments.routing.lbf.generator.RandomGenerator] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.lbf.types.State]] = None, time_limit: int = 100, grid_observation: bool = False, normalize_reward: bool = True, penalty: float = 0.0) -> None
special
reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.lbf.types.State, jumanji.types.TimeStep]
Resets the environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | chex.PRNGKey | Used to randomly generate the new state. | required |
Returns:
Type | Description |
---|---|
Tuple[State, TimeStep] | The new initial state and the corresponding first timestep. |
step(self, state: State, actions: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.routing.lbf.types.State, jumanji.types.TimeStep]
Simulate one step of the environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | State | State containing the dynamics of the environment. | required |
actions | chex.Array | Array containing the actions to take for each agent. | required |
Returns:
Type | Description |
---|---|
Tuple[State, TimeStep] | The next state and the corresponding timestep. |
render(self, state: State) -> Optional[numpy.ndarray[Any, numpy.dtype[+ScalarType]]]
Renders the current state of the LevelBasedForaging environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | State | The current environment state to be rendered. | required |
Returns:
Type | Description |
---|---|
Optional[NDArray] | Rendered environment state. |