
Connector

Bases: Environment[State, MultiDiscreteArray, Observation]

The Connector environment is a gridworld problem where multiple pairs of points (sets) must be connected without overlapping the paths taken by any other set. This is achieved by allowing certain points to move to an adjacent cell at each step. However, each time a point moves it leaves an impassable trail behind it. The goal is to connect all sets.

  • observation - Observation

    • action mask: jax array (bool) of shape (num_agents, 5).
    • step_count: jax array (int32) of shape () the current episode step.
    • grid: jax array (int32) of shape (grid_size, grid_size)
      • with 2 agents you might have a grid like this:

            4 0 1
            5 0 1
            6 3 2

        which means agent 1 has moved down from the top right of the grid, is currently in the bottom-right corner, and is aiming for the middle cell of the bottom row. Agent 2 started in the top left and moved down once towards its target in the bottom left.
  • action: jax array (int32) of shape (num_agents,):

    • can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left].
    • each value in the array corresponds to an agent's action.
  • reward: jax array (float) of shape (num_agents,):

    • dense: for each agent, the reward is 1 for a successful connection on that step. Additionally, each pair of points that has not yet connected receives a penalty reward of -0.03.
  • episode termination:

    • all agents either can't move (no available actions) or have connected to their target.
    • the time limit is reached.
  • state: State:

    • key: jax PRNG key used to randomly spawn agents and targets.
    • grid: jax array (int32) of shape (grid_size, grid_size) giving the observation.
    • step_count: jax array (int32) of shape () number of steps elapsed in the current episode.
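The grid encoding can be decoded with a short sketch. This assumes the convention implied by the example above (for agent i, 0-indexed: path cells are 3*i + 1, the current position is 3*i + 2, and the target is 3*i + 3, with 0 marking an empty cell) — an illustration, not the library's own helper:

```python
import numpy as np

# The 3x3 grid from the 2-agent example above (0 = empty cell).
grid = np.array([[4, 0, 1],
                 [5, 0, 1],
                 [6, 3, 2]])

def decode(grid: np.ndarray, agent_id: int):
    """Return the (row, col) position and target of one agent,
    under the assumed encoding path=3i+1, position=3i+2, target=3i+3."""
    pos = tuple(np.argwhere(grid == 3 * agent_id + 2)[0])
    target = tuple(np.argwhere(grid == 3 * agent_id + 3)[0])
    return pos, target

pos0, target0 = decode(grid, 0)  # "agent 1" in the prose above
pos1, target1 = decode(grid, 1)  # "agent 2" in the prose above
```

Decoding the example grid recovers the situation described in the prose: agent 1 sits in the bottom-right corner aiming for the middle of the bottom row, while agent 2 is one step below its start, aiming for the bottom left.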
```python
import jax

from jumanji.environments import Connector

env = Connector()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
```

Create the Connector environment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| generator | Optional[Generator] | Generator whose __call__ instantiates an environment instance. Implemented options are [UniformRandomGenerator, RandomWalkGenerator]. Defaults to RandomWalkGenerator with grid_size=10 and num_agents=10. | None |
| reward_fn | Optional[RewardFn] | Class of type RewardFn, whose __call__ is used as a reward function. Implemented options are [DenseRewardFn]. Defaults to DenseRewardFn. | None |
| time_limit | int | The number of steps allowed before an episode terminates. Defaults to 50. | 50 |
| viewer | Optional[Viewer[State]] | Viewer used for rendering. Defaults to ConnectorViewer with "human" render mode. | None |
Source code in jumanji/environments/routing/connector/env.py
```python
def __init__(
    self,
    generator: Optional[Generator] = None,
    reward_fn: Optional[RewardFn] = None,
    time_limit: int = 50,
    viewer: Optional[Viewer[State]] = None,
) -> None:
    """Create the `Connector` environment.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment instance.
            Implemented options are [`UniformRandomGenerator`, `RandomWalkGenerator`].
            Defaults to `RandomWalkGenerator` with `grid_size=10` and `num_agents=10`.
        reward_fn: class of type `RewardFn`, whose `__call__` is used as a reward function.
            Implemented options are [`DenseRewardFn`]. Defaults to `DenseRewardFn`.
        time_limit: the number of steps allowed before an episode terminates. Defaults to 50.
        viewer: `Viewer` used for rendering. Defaults to `ConnectorViewer` with "human" render
            mode.
    """
    self._generator = generator or RandomWalkGenerator(grid_size=10, num_agents=10)
    self._reward_fn = reward_fn or DenseRewardFn()
    self.time_limit = time_limit
    self.num_agents = self._generator.num_agents
    self.grid_size = self._generator.grid_size
    super().__init__()
    self._agent_ids = jnp.arange(self.num_agents)
    self._viewer = viewer or ConnectorViewer("Connector", self.num_agents, render_mode="human")
```
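The dense reward selected via the `reward_fn` argument can be sketched as follows. This is an illustrative stand-in for the behaviour described above (+1 on the step a pair connects, -0.03 per step while unconnected), not the actual `DenseRewardFn` implementation:

```python
import numpy as np

def dense_reward(prev_connected, connected,
                 connection_reward=1.0, timestep_reward=-0.03):
    """Sketch of a per-agent dense reward: +1.0 on the step an agent's
    pair connects; -0.03 on every step while it remains unconnected.
    (Illustrative values taken from the description above.)"""
    newly_connected = connected & ~prev_connected
    return (connection_reward * newly_connected
            + timestep_reward * ~connected).astype(np.float32)

# Agent 0 connects this step, agent 1 is still unconnected.
rewards = dense_reward(np.array([False, False]), np.array([True, False]))
```

Note that an already-connected agent receives neither the bonus nor the penalty on subsequent steps, which matches the "per successful connection on that step" wording.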

action_spec: specs.MultiDiscreteArray cached property #

Returns the action spec for the Connector environment.

5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with a multi-dimensional action space, it expects an array of actions of shape (num_agents,).

Returns:

| Name | Type | Description |
|---|---|---|
| action_spec | MultiDiscreteArray | MultiDiscreteArray of shape (num_agents,). |
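The mapping from action indices to grid moves can be made concrete with a small sketch. The (row, col) deltas below are an assumption consistent with the [No Op, Up, Right, Down, Left] ordering above, not code from the library:

```python
import numpy as np

# Assumed (row, col) deltas for the 5 actions described above.
MOVES = np.array([[ 0,  0],   # 0: no-op
                  [-1,  0],   # 1: up
                  [ 0,  1],   # 2: right
                  [ 1,  0],   # 3: down
                  [ 0, -1]])  # 4: left

# One (row, col) position per agent, and one action per agent,
# matching the (num_agents,) action shape the spec expects.
positions = np.array([[2, 2], [1, 0]])
actions = np.array([4, 3])  # agent 0 moves left, agent 1 moves down

new_positions = positions + MOVES[actions]
```

Indexing `MOVES` with the whole action array applies every agent's move in one vectorised step, which is the same shape of computation a batched JAX environment performs.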

discount_spec: specs.BoundedArray cached property #

Returns: discount per agent.

observation_spec: specs.Spec[Observation] cached property #

Specifications of the observation of the Connector environment.

Returns:

| Type | Description |
|---|---|
| Spec[Observation] | Spec for the Observation whose fields are:<br>• grid: BoundedArray (int32) of shape (grid_size, grid_size).<br>• action_mask: BoundedArray (bool) of shape (num_agents, 5).<br>• step_count: BoundedArray (int32) of shape (). |

reward_spec: specs.Array cached property #

Returns: a reward per agent.

animate(states, interval=200, save_path=None) #

Create an animation from a sequence of states.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| states | Sequence[State] | Sequence of State corresponding to subsequent timesteps. | required |
| interval | int | Delay between frames in milliseconds. Defaults to 200. | 200 |
| save_path | Optional[str] | The path where the animation file should be saved. If it is None, the plot will not be saved. | None |

Returns:

| Type | Description |
|---|---|
| FuncAnimation | Animation that can be exported to GIF or MP4, or rendered with HTML. |

Source code in jumanji/environments/routing/connector/env.py
```python
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Create an animation from a sequence of states.

    Args:
        states: sequence of `State` corresponding to subsequent timesteps.
        interval: delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation that can export to gif, mp4, or render with HTML.
    """
    grids = [state.grid for state in states]
    return self._viewer.animate(grids, interval, save_path)
```

close() #

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/routing/connector/env.py
```python
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()
```

render(state) #

Render the given state of the environment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the current environment state. | required |
Source code in jumanji/environments/routing/connector/env.py
```python
def render(self, state: State) -> Optional[NDArray]:
    """Render the given state of the environment.

    Args:
        state: `State` object containing the current environment state.
    """
    return self._viewer.render(state.grid)
```

reset(key) #

Resets the environment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKey | Used to randomly generate the connector grid. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the new state of the environment. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the initial environment timestep. |

Source code in jumanji/environments/routing/connector/env.py
```python
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: used to randomly generate the connector grid.

    Returns:
        state: `State` object corresponding to the new state of the environment.
        timestep: `TimeStep` object corresponding to the initial environment timestep.
    """
    state = self._generator(key)

    action_mask = jax.vmap(self._get_action_mask, (0, None))(state.agents, state.grid)
    observation = Observation(
        grid=state.grid,
        action_mask=action_mask,
        step_count=state.step_count,
    )
    extras = self._get_extras(state)
    timestep = restart(observation=observation, extras=extras, shape=(self.num_agents,))
    return state, timestep
```

step(state, action) #

Perform an environment step.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the dynamics of the environment. | required |
| action | Array | Array containing the action for each agent: 0 = no op, 1 = move up, 2 = move right, 3 = move down, 4 = move left. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the next state of the environment. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the timestep returned by the environment. |

Source code in jumanji/environments/routing/connector/env.py
```python
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Perform an environment step.

    Args:
        state: State object containing the dynamics of the environment.
        action: Array containing the actions to take for each agent.
            - 0 no op
            - 1 move up
            - 2 move right
            - 3 move down
            - 4 move left

    Returns:
        state: `State` object corresponding to the next state of the environment.
        timestep: `TimeStep` object corresponding to the timestep returned by the environment.
    """
    agents, grid = self._step_agents(state, action)
    new_state = State(grid=grid, step_count=state.step_count + 1, agents=agents, key=state.key)

    # Construct timestep: get reward, legal actions and done
    reward = self._reward_fn(state, action, new_state)
    action_mask = jax.vmap(self._get_action_mask, (0, None))(agents, grid)
    observation = Observation(
        grid=grid, action_mask=action_mask, step_count=new_state.step_count
    )

    done = jax.vmap(connected_or_blocked)(agents, action_mask)
    discount = (1 - done).astype(float)
    extras = self._get_extras(new_state)
    timestep = jax.lax.cond(
        jnp.all(done) | (new_state.step_count >= self.time_limit),
        lambda: termination(
            reward=reward,
            observation=observation,
            extras=extras,
            shape=(self.num_agents,),
        ),
        lambda: transition(
            reward=reward,
            observation=observation,
            extras=extras,
            discount=discount,
            shape=(self.num_agents,),
        ),
    )

    return new_state, timestep
```
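The termination and discount logic at the end of `step` can be condensed into a small numpy sketch. This mirrors the condition in the source above (episode ends when every agent is connected or blocked, or the time limit is reached), with `time_limit=50` taken from the constructor default:

```python
import numpy as np

def episode_done(done_per_agent: np.ndarray, step_count: int,
                 time_limit: int = 50) -> bool:
    """Episode terminates when all agents are done (connected or blocked)
    or the step count reaches the time limit, as in `step` above."""
    return bool(np.all(done_per_agent) or step_count >= time_limit)

# Per-agent discount: 0.0 once an agent is done, 1.0 while it is active,
# matching `discount = (1 - done).astype(float)` in the source.
done = np.array([True, False])
discount = (1 - done).astype(float)
```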