# Connector (Environment)

The `Connector` environment is a gridworld problem in which multiple pairs of points (sets) must be connected without overlapping the paths taken by any other set. This is achieved by allowing certain points to move to an adjacent cell at each step. However, each time a point moves it leaves an impassable trail behind it. The goal is to connect all sets.
- observation: `Observation`
    - action_mask: jax array (bool) of shape (num_agents, 5).
    - step_count: jax array (int32) of shape (), the current episode step.
    - grid: jax array (int32) of shape (grid_size, grid_size). With 2 agents you might have a grid like this:

            4 0 1
            5 0 1
            6 3 2

      which means agent 1 has moved from the top right of the grid down and is currently in the bottom-right corner, aiming for the middle cell of the bottom row. Agent 2 started in the top left and has moved down once towards its target in the bottom left.
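The numbering in the grid above admits a compact reading (an inference from this example, not an official API): for agent `i` (0-indexed), the values `3*i + 1`, `3*i + 2` and `3*i + 3` mark its trail, current position and target respectively, with `0` for empty cells. A minimal decoding sketch under that assumption:

```python
from typing import Optional, Tuple

def decode_cell(value: int) -> Optional[Tuple[int, str]]:
    """Return (agent_index, role) for a non-empty cell, or None if empty.

    Assumes the encoding inferred above: agent i uses values
    3*i + 1 (path), 3*i + 2 (position), 3*i + 3 (target).
    """
    if value == 0:
        return None
    agent, role = divmod(value - 1, 3)
    return agent, ("path", "position", "target")[role]

grid = [
    [4, 0, 1],
    [5, 0, 1],
    [6, 3, 2],
]
# Agent 0's head is the cell holding value 2 -> bottom-right corner.
assert decode_cell(grid[2][2]) == (0, "position")
# Agent 1 started top-left (trail value 4) and now sits just below it (value 5).
assert decode_cell(grid[0][0]) == (1, "path")
assert decode_cell(grid[1][0]) == (1, "position")
```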
- action: jax array (int32) of shape (num_agents,):
    - can take the values [0, 1, 2, 3, 4], which correspond to [No Op, Up, Right, Down, Left].
    - each value in the array corresponds to one agent's action.
- reward: jax array (float) of shape ():
    - dense: a reward of 1 is given for each successful connection on that step. Additionally, each pair of points that has not yet connected receives a penalty reward of -0.03.
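The dense reward above can be sketched as a simple bookkeeping function. This is a hedged illustration of the arithmetic only; the actual `RewardFn` operates on environment state, not on pre-computed counts:

```python
def dense_reward(newly_connected: int, still_unconnected: int) -> float:
    """Dense reward as described above: +1 per pair connected on this step,
    -0.03 per pair that remains unconnected."""
    return 1.0 * newly_connected - 0.03 * still_unconnected

# Two pairs connect this step while one pair remains open: 2 - 0.03 = 1.97.
assert abs(dense_reward(2, 1) - 1.97) < 1e-9
```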
- episode termination:
    - all agents either cannot move (no available actions) or have connected to their target.
    - the time limit is reached.
- state: `State`
    - key: jax PRNG key used to randomly spawn agents and targets.
    - grid: jax array (int32) of shape (grid_size, grid_size) giving the observation.
    - step_count: jax array (int32) of shape (), the number of steps elapsed in the current episode.
observation_spec: jumanji.specs.Spec[jumanji.environments.routing.connector.types.Observation] (cached property)

Specifications of the observation of the `Connector` environment.

Returns:

Type | Description
---|---
`Spec[Observation]` | Spec for the `Observation` whose fields are described above.
action_spec: MultiDiscreteArray (cached property)

Returns the action spec for the `Connector` environment.

5 actions: [0, 1, 2, 3, 4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with a multi-dimensional action space, it expects an array of actions of shape (num_agents,).

Returns:

Type | Description
---|---
`MultiDiscreteArray` | The action spec for the Connector environment.
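To make the expected shape concrete, the check below mimics what a multi-discrete spec validates: `num_agents` independent choices, each in `[0, 5)`. It is a stand-in for illustration, not jumanji's actual `MultiDiscreteArray`:

```python
import numpy as np

def is_valid_action(action: np.ndarray, num_agents: int, num_actions: int = 5) -> bool:
    """Stand-in validity check: one integer per agent, each in [0, num_actions)."""
    return action.shape == (num_agents,) and bool(
        ((0 <= action) & (action < num_actions)).all()
    )

# Agent 0 takes No Op (0), agent 1 moves Left (4): a valid joint action.
assert is_valid_action(np.array([0, 4]), num_agents=2)
# 5 is out of range, and a scalar has the wrong shape.
assert not is_valid_action(np.array([0, 5]), num_agents=2)
assert not is_valid_action(np.array([1, 2, 3]), num_agents=2)
```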
__init__(self, generator: Optional[jumanji.environments.routing.connector.generator.Generator] = None, reward_fn: Optional[jumanji.environments.routing.connector.reward.RewardFn] = None, time_limit: int = 50, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.connector.types.State]] = None) -> None

Create the `Connector` environment.

Parameters:

Name | Type | Description | Default
---|---|---|---
generator | Optional[jumanji.environments.routing.connector.generator.Generator] | | None
reward_fn | Optional[jumanji.environments.routing.connector.reward.RewardFn] | class of type `RewardFn`. | None
time_limit | int | the number of steps allowed before an episode terminates. Defaults to 50. | 50
viewer | Optional[jumanji.viewer.Viewer[jumanji.environments.routing.connector.types.State]] | | None
reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.connector.types.State, jumanji.types.TimeStep[jumanji.environments.routing.connector.types.Observation]]

Resets the environment.

Parameters:

Name | Type | Description | Default
---|---|---|---
key | PRNGKeyArray | used to randomly generate the connector grid. | required

Returns:

Type | Description
---|---
`Tuple[State, TimeStep[Observation]]` | the first state and timestep of the new episode.
step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.routing.connector.types.State, jumanji.types.TimeStep[jumanji.environments.routing.connector.types.Observation]]

Perform an environment step.

Parameters:

Name | Type | Description | Default
---|---|---|---
state | State | State object containing the dynamics of the environment. | required
action | Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] | Array containing the action to take for each agent: 0 = no op, 1 = move up, 2 = move right, 3 = move down, 4 = move left. | required

Returns:

Type | Description
---|---
`Tuple[State, TimeStep[Observation]]` | the next state and timestep resulting from the given actions.
render(self, state: State) -> Optional[numpy.ndarray[Any, numpy.dtype[+ScalarType]]]

Render the given state of the environment.

Parameters:

Name | Type | Description | Default
---|---|---|---
state | State | the environment state to render. | required