
Connector

Connector (Environment)

The Connector environment is a gridworld problem in which multiple pairs of points (sets) must be connected without overlapping the path taken by any other set. In each set, one point (the agent) may move to an adjacent cell at each step, while the other (its target) stays fixed. However, every move leaves an impassable trail behind the agent. The goal is to connect all sets.

  • observation: Observation

    • action_mask: jax array (bool) of shape (num_agents, 5).
    • step_count: jax array (int32) of shape (), the current episode step.
    • grid: jax array (int32) of shape (grid_size, grid_size).
      • with 2 agents you might have a grid like this:

          4 0 1
          5 0 1
          6 3 2

        Here 0 marks an empty cell; agent 1's trail, current position and target are encoded as 1, 2 and 3, and agent 2's as 4, 5 and 6. Agent 1 started in the top right of the grid, has moved down into the bottom right corner and is aiming for the bottom middle cell. Agent 2 started in the top left and has moved down once towards its target in the bottom left.
  • action: jax array (int32) of shape (num_agents,):

    • can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left].
    • each value in the array corresponds to an agent's action.
  • reward: jax array (float) of shape ():

    • dense: the reward is 1 for each connection made on that step. Additionally, every pair of points that has not yet connected incurs a penalty of -0.03.
  • episode termination:

    • all agents either can't move (no available actions) or have connected to their target.
    • the time limit is reached.
  • state: State:

    • key: jax PRNG key used to randomly spawn agents and targets.
    • grid: jax array (int32) of shape (grid_size, grid_size) giving the observation.
    • step_count: jax array (int32) of shape () number of steps elapsed in the current episode.
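The dense reward and the termination rules above can be sketched in plain NumPy (the real environment is written in JAX). This is a minimal illustration, not Jumanji's `DenseRewardFn`: the names `dense_reward` and `is_done` are hypothetical, the sketch assumes the -0.03 penalty is counted over pairs still unconnected after the step, and it assumes an agent "can't move" when none of its four move actions is allowed by its row of the action mask.

```python
import numpy as np

CONNECTED_REWARD = 1.0    # per pair that connects on this step
TIMESTEP_PENALTY = -0.03  # per pair that is still unconnected


def dense_reward(connected_before: np.ndarray, connected_after: np.ndarray) -> float:
    """Hypothetical dense reward for one step.

    Both inputs are bool arrays of shape (num_agents,) saying whether each
    agent had reached its target before / after the step.
    """
    newly_connected = connected_after & ~connected_before
    # Assumption: the penalty applies to pairs still unconnected after the step.
    still_unconnected = ~connected_after
    return float(CONNECTED_REWARD * newly_connected.sum()
                 + TIMESTEP_PENALTY * still_unconnected.sum())


def is_done(connected: np.ndarray, action_mask: np.ndarray,
            step_count: int, time_limit: int = 50) -> bool:
    """Hypothetical termination check.

    Assumption: an agent "can't move" when none of actions 1-4
    (Up, Right, Down, Left) is allowed in its row of the
    (num_agents, 5) action mask.
    """
    can_move = action_mask[:, 1:].any(axis=1)
    every_agent_finished = bool(np.all(connected | ~can_move))
    return every_agent_finished or step_count >= time_limit
```

For example, with three agents where one connects on this step and one is still unconnected, `dense_reward` returns 1 - 0.03 = 0.97.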
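import jax
from jumanji.environments import Connector

env = Connector()
key = jax.random.PRNGKey(0)
# reset and step are pure functions, so both can be jit-compiled.
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# Draw a placeholder action of shape (num_agents,) from the action spec.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)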

observation_spec: jumanji.specs.Spec[jumanji.environments.routing.connector.types.Observation] (cached property)

Specifications of the observation of the Connector environment.

Returns:

Spec for the `Observation`, whose fields are:
  • grid: BoundedArray (int32) of shape (grid_size, grid_size).
  • action_mask: BoundedArray (bool) of shape (num_agents, 5).
  • step_count: BoundedArray (int32) of shape ().
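To make the `grid` field concrete, the sketch below decodes the 3×3 two-agent example from the observation description. The encoding it uses (0 for an empty cell; agent k, counted from 1, uses 3k-2 for its trail, 3k-1 for its position and 3k for its target) is inferred from that example and is an assumption for larger numbers of agents. `agent_codes` and `decode_grid` are hypothetical helpers written with NumPy for readability.

```python
import numpy as np


def agent_codes(k: int) -> tuple:
    """Assumed (trail, position, target) codes for agent k (1-indexed)."""
    return 3 * k - 2, 3 * k - 1, 3 * k


def _cells(grid: np.ndarray, value: int) -> list:
    """All (row, col) cells holding `value`, in row-major order."""
    return [tuple(int(x) for x in cell) for cell in np.argwhere(grid == value)]


def decode_grid(grid: np.ndarray, num_agents: int) -> dict:
    """Hypothetical decoder: agent id -> its trail, position and target cells."""
    grid = np.asarray(grid)
    decoded = {}
    for k in range(1, num_agents + 1):
        trail, position, target = agent_codes(k)
        pos_cells = _cells(grid, position)
        target_cells = _cells(grid, target)
        decoded[k] = {
            "trail": _cells(grid, trail),
            # None if the code is absent from the grid.
            "position": pos_cells[0] if pos_cells else None,
            "target": target_cells[0] if target_cells else None,
        }
    return decoded


example = np.array([[4, 0, 1],
                    [5, 0, 1],
                    [6, 3, 2]])
print(decode_grid(example, num_agents=2))
```

For the example, agent 1 decodes to position (2, 2), target (2, 1) and trail [(0, 2), (1, 2)], and agent 2 to position (1, 0), target (2, 0) and trail [(0, 0)], matching the prose description.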

action_spec: MultiDiscreteArray (cached property)

Returns the action spec for the Connector environment.

5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with a multi-dimensional action space, it expects an array of actions of shape (num_agents,).

Returns:

action_spec: MultiDiscreteArray of shape (num_agents,).
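The action encoding [No Op, Up, Right, Down, Left] pairs with the (num_agents, 5) `action_mask` in the observation: column i of an agent's row says whether action i is currently allowed. The sketch below builds such a mask for the two-agent example grid under stated assumptions: row 0 is the top of the grid (consistent with the observation example), No Op is always allowed, and a move is allowed only if the destination lies inside the grid and is either empty or the agent's own target. `MOVES` and `sketch_action_mask` are hypothetical names, not part of Jumanji's API.

```python
import numpy as np

# Action id -> (row delta, col delta) for [No Op, Up, Right, Down, Left].
# Row 0 is the top of the grid, so "Up" decreases the row index.
MOVES = np.array([[0, 0], [-1, 0], [0, 1], [1, 0], [0, -1]])


def sketch_action_mask(grid, positions, targets) -> np.ndarray:
    """Hypothetical (num_agents, 5) bool mask of allowed actions.

    positions / targets: one (row, col) pair per agent.
    Assumptions: No Op is always allowed; a move is allowed when the
    destination is inside the grid and is empty (0) or the agent's own target.
    """
    grid = np.asarray(grid)
    size = grid.shape[0]
    mask = np.zeros((len(positions), 5), dtype=bool)
    mask[:, 0] = True  # No Op
    for i, (pos, target) in enumerate(zip(positions, targets)):
        for action in range(1, 5):
            row, col = (int(v) for v in np.asarray(pos) + MOVES[action])
            inside = 0 <= row < size and 0 <= col < size
            # Short-circuiting on `inside` avoids indexing outside the grid.
            mask[i, action] = inside and (
                grid[row, col] == 0 or (row, col) == tuple(target)
            )
    return mask


grid = np.array([[4, 0, 1],
                 [5, 0, 1],
                 [6, 3, 2]])
print(sketch_action_mask(grid, positions=[(2, 2), (1, 0)], targets=[(2, 1), (2, 0)]))
```

Under these assumptions, agent 1 at (2, 2) can only move Left onto its target, while agent 2 at (1, 0) can move Right into the empty centre cell or Down onto its target.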

__init__(self, generator: Optional[jumanji.environments.routing.connector.generator.Generator] = None, reward_fn: Optional[jumanji.environments.routing.connector.reward.RewardFn] = None, time_limit: int = 50, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.connector.types.State]] = None) -> None

Create the Connector environment.

Parameters:

  • generator (Optional[jumanji.environments.routing.connector.generator.Generator], default None): generator whose __call__ instantiates an environment instance. Implemented options are [UniformRandomGenerator, RandomWalkGenerator]. Defaults to RandomWalkGenerator with grid_size=10 and num_agents=10.
  • reward_fn (Optional[jumanji.environments.routing.connector.reward.RewardFn], default None): object of type RewardFn whose __call__ is used as the reward function. Implemented options are [DenseRewardFn]. Defaults to DenseRewardFn.
  • time_limit (int, default 50): the number of steps allowed before an episode terminates.
  • viewer (Optional[jumanji.viewer.Viewer[jumanji.environments.routing.connector.types.State]], default None): viewer used for rendering. Defaults to ConnectorViewer with "human" render mode.

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.connector.types.State, jumanji.types.TimeStep[jumanji.environments.routing.connector.types.Observation]]

Resets the environment.

Parameters:

  • key (PRNGKeyArray, required): used to randomly generate the connector grid.

Returns:

  • state: State object corresponding to the new state of the environment.
  • timestep: TimeStep object corresponding to the initial environment timestep.

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.routing.connector.types.State, jumanji.types.TimeStep[jumanji.environments.routing.connector.types.Observation]]

Perform an environment step.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], required): array containing the action for each agent:
    • 0: no op
    • 1: move up
    • 2: move right
    • 3: move down
    • 4: move left

Returns:

  • state: State object corresponding to the next state of the environment.
  • timestep: TimeStep object corresponding to the timestep returned by the environment.
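Because `step` takes one action per agent, a simple random baseline picks, for each agent, one of the actions its row of the observation's `action_mask` allows. A minimal NumPy sketch, assuming the mask is read from `timestep.observation.action_mask`; `sample_valid_actions` is a hypothetical helper, and an agent with no allowed action falls back to 0 (No Op).

```python
import numpy as np


def sample_valid_actions(action_mask: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Pick one allowed action per agent uniformly at random.

    action_mask: bool array of shape (num_agents, 5).
    Returns an int32 array of shape (num_agents,), one action per agent.
    """
    action_mask = np.asarray(action_mask, dtype=bool)
    actions = np.zeros(action_mask.shape[0], dtype=np.int32)  # 0 = No Op fallback
    for i, row in enumerate(action_mask):
        allowed = np.flatnonzero(row)  # indices of allowed actions
        if allowed.size:
            actions[i] = rng.choice(allowed)
    return actions


mask = np.array([[True, False, False, False, True],
                 [True, False, True, True, False]])
rng = np.random.default_rng(0)
print(sample_valid_actions(mask, rng))
```

The result has shape (num_agents,) and can be passed as the `action` argument of `step`. Inside a jit-compiled rollout one would instead keep everything in JAX, for example by calling `jax.random.categorical` on logits set to `-jnp.inf` wherever the mask is `False`.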

render(self, state: State) -> Optional[numpy.ndarray[Any, numpy.dtype[+ScalarType]]]

Render the given state of the environment.

Parameters:

  • state (State, required): State object containing the current environment state.

Last update: 2024-11-01