Connector
Bases: Environment[State, MultiDiscreteArray, Observation]
The Connector environment is a gridworld problem where multiple pairs of points (sets)
must be connected without overlapping the paths taken by any other set. This is achieved
by allowing certain points to move to an adjacent cell at each step. However, each time a
point moves it leaves an impassable trail behind it. The goal is to connect all sets.
- observation: Observation
    - action mask: jax array (bool) of shape (num_agents, 5).
    - step_count: jax array (int32) of shape (), the current episode step.
    - grid: jax array (int32) of shape (grid_size, grid_size).
        - with 2 agents you might have a grid like this:
              4 0 1
              5 0 1
              6 3 2
          which means agent 1 has moved from the top right of the grid down, is currently in the bottom right corner, and is aiming to get to the middle bottom cell. Agent 2 started in the top left and moved down once towards its target in the bottom left.
- action: jax array (int32) of shape (num_agents,):
    - can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left].
    - each value in the array corresponds to an agent's action.
- reward: jax array (float) of shape (num_agents,):
    - dense: for each agent the reward is 1 for each successful connection on that step. Additionally, each pair of points that has not connected receives a penalty reward of -0.03.
- episode termination:
    - all agents either can't move (no available actions) or have connected to their target.
    - the time limit is reached.
- state: State
    - key: jax PRNG key used to randomly spawn agents and targets.
    - grid: jax array (int32) of shape (grid_size, grid_size) giving the observation.
    - step_count: jax array (int32) of shape (), number of steps elapsed in the current episode.
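The example grid suggests that each agent owns three consecutive cell values: one for its trail, one for its current position, and one for its target. The sketch below decodes the example under that inferred assumption (0 is empty, and zero-based agent i uses 1 + 3i, 2 + 3i and 3 + 3i). It is plain Python for illustration rather than the library's code, and `decode_cell` is a hypothetical helper.

```python
# Assumed encoding, inferred from the example grid (not confirmed by the source):
# 0 = empty; zero-based agent i uses 1 + 3*i (trail), 2 + 3*i (position), 3 + 3*i (target).
KINDS = ("trail", "position", "target")


def decode_cell(value):
    """Return (agent_index, kind) for a cell value, or None for an empty cell."""
    if value == 0:
        return None
    agent, offset = divmod(value - 1, 3)
    return agent, KINDS[offset]


grid = [
    [4, 0, 1],
    [5, 0, 1],
    [6, 3, 2],
]

decoded = [[decode_cell(value) for value in row] for row in grid]
for row in decoded:
    print(row)
```

Under this reading, "agent 1" in the text is agent index 0 (values 1 to 3) and "agent 2" is agent index 1 (values 4 to 6).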
Create the Connector environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
generator | Optional[Generator] | | None |
reward_fn | Optional[RewardFn] | class of type | None |
time_limit | int | the number of steps allowed before an episode terminates. Defaults to 50. | 50 |
viewer | Optional[Viewer[State]] | | None |
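The time_limit parameter feeds into the termination rule given in the class description: an episode ends when every agent is either connected or unable to move, or when the time limit is reached. A minimal sketch of that rule follows. The function and argument names are hypothetical, and treating the limit as reached once step_count >= time_limit is an assumption.

```python
def episode_done(connected, can_move, step_count, time_limit=50):
    """Sketch of the termination rule; every name here is hypothetical.

    connected[i]: agent i has reached its target.
    can_move[i]: agent i has at least one available move.
    """
    # Each agent must be finished: either connected or stuck.
    all_finished = all(c or not m for c, m in zip(connected, can_move))
    # Assumption: the limit counts as reached once step_count >= time_limit.
    timed_out = step_count >= time_limit
    return all_finished or timed_out


# Agent 0 is connected, agent 1 can still move: the episode continues.
print(episode_done([True, False], [False, True], step_count=3))
# Agent 1 becomes stuck: the episode ends even though it never connected.
print(episode_done([True, False], [False, False], step_count=3))
```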
Source code in jumanji/environments/routing/connector/env.py
action_spec: specs.MultiDiscreteArray (cached property)
Returns the action spec for the Connector environment.
5 actions: [0,1,2,3,4] -> [No Op, Up, Right, Down, Left]. Since this is an environment with a multi-dimensional action space, it expects an array of actions of shape (num_agents,).
Returns:
Name | Type | Description |
---|---|---|
action_spec | MultiDiscreteArray | |
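To make the action encoding concrete, the sketch below maps each action id to a grid displacement and applies one action per agent. It assumes positions are (row, col) pairs with row 0 at the top, so that "Up" decreases the row index. `MOVES` and `apply_action` are hypothetical names, not part of the library.

```python
# Action ids in the documented order: [No Op, Up, Right, Down, Left].
ACTION_NAMES = ["No Op", "Up", "Right", "Down", "Left"]
# Assumption: (row, col) coordinates with row 0 at the top of the grid.
MOVES = [(0, 0), (-1, 0), (0, 1), (1, 0), (0, -1)]


def apply_action(position, action):
    """Return the cell an agent would reach by taking `action` from `position`."""
    row, col = position
    d_row, d_col = MOVES[action]
    return (row + d_row, col + d_col)


# One action per agent, mirroring an action array of shape (num_agents,).
positions = [(2, 2), (1, 0)]
actions = [4, 3]  # agent 0 moves left, agent 1 moves down
new_positions = [apply_action(p, a) for p, a in zip(positions, actions)]
print(new_positions)
```

This only computes the destination cell; bounds and collision checks are left out here.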
discount_spec: specs.BoundedArray (cached property)
Returns: discount per agent.
observation_spec: specs.Spec[Observation] (cached property)
Specifications of the observation of the Connector environment.
Returns:
Type | Description |
---|---|
Spec[Observation] | Spec for the |
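The observation's action mask has shape (num_agents, 5), one boolean per action for each agent. The sketch below shows one way such a mask could be derived from the example grid in the class description. It rests on several assumptions that the text does not confirm: the 1 + 3i / 2 + 3i / 3 + 3i cell encoding inferred from that example, a move being legal only when it stays in bounds and lands on an empty cell or on the agent's own target, and No Op always being allowed. `action_mask` is a hypothetical helper.

```python
MOVES = [(0, 0), (-1, 0), (0, 1), (1, 0), (0, -1)]  # No Op, Up, Right, Down, Left


def action_mask(grid, agent):
    """Hypothetical mask for one agent: five booleans, one per action."""
    size = len(grid)
    # Assumed encoding: agent i uses 2 + 3*i for its position, 3 + 3*i for its target.
    position_value, target_value = 2 + 3 * agent, 3 + 3 * agent
    row, col = next(
        (r, c)
        for r in range(size)
        for c in range(size)
        if grid[r][c] == position_value
    )
    mask = [True]  # Assumption: No Op is always allowed.
    for d_row, d_col in MOVES[1:]:
        r, c = row + d_row, col + d_col
        in_bounds = 0 <= r < size and 0 <= c < size
        # Trails and other agents' cells are treated as impassable.
        mask.append(in_bounds and grid[r][c] in (0, target_value))
    return mask


grid = [
    [4, 0, 1],
    [5, 0, 1],
    [6, 3, 2],
]
masks = [action_mask(grid, agent) for agent in range(2)]  # num_agents rows of 5
for mask in masks:
    print(mask)
```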
reward_spec: specs.Array (cached property)
Returns: a reward per agent.
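As an illustration of the per-agent reward, the sketch below follows the dense scheme from the class description: 1 for a connection made on the current step and -0.03 for an unconnected pair. The text does not say whether an agent that connects on this step also pays the penalty. The sketch assumes it does, by applying the penalty to every agent that was unconnected at the start of the step. `dense_reward` and its arguments are hypothetical names.

```python
def dense_reward(connected_before, connected_after,
                 connect_reward=1.0, step_penalty=-0.03):
    """Return one reward per agent for a single step (illustrative only)."""
    rewards = []
    for before, after in zip(connected_before, connected_after):
        reward = 0.0
        if after and not before:
            reward += connect_reward  # connection made on this step
        if not before:
            # Assumption: the penalty applies to agents unconnected at step start.
            reward += step_penalty
        rewards.append(reward)
    return rewards


# Agent 0 connects now, agent 1 stays unconnected, agent 2 was already connected.
rewards = dense_reward([False, False, True], [True, False, True])
print(rewards)
```

Under the alternative reading, where only agents still unconnected after the step are penalised, the second condition would test `not after` instead.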
animate(states, interval=200, save_path=None)
Create an animation from a sequence of states.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
states | Sequence[State] | sequence of | required |
interval | int | delay between frames in milliseconds. Defaults to 200. | 200 |
save_path | Optional[str] | the path where the animation file should be saved. If it is None, the plot will not be saved. | None |
Returns:
Type | Description |
---|---|
FuncAnimation | animation that can export to gif, mp4, or render with HTML. |
Source code in jumanji/environments/routing/connector/env.py
close()
Perform any necessary cleanup.
Environments will automatically close() themselves when garbage collected or when the program exits.
Source code in jumanji/environments/routing/connector/env.py
render(state)
Render the given state of the environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | State | | required |
Source code in jumanji/environments/routing/connector/env.py
reset(key)
Resets the environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKey | used to randomly generate the connector grid. | required |
Returns:
Name | Type | Description |
---|---|---|
state | State | |
timestep | TimeStep[Observation] | |
Source code in jumanji/environments/routing/connector/env.py
step(state, action)
Perform an environment step.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | State | State object containing the dynamics of the environment. | required |
action | Array | Array containing the actions to take for each agent: 0 no op, 1 move up, 2 move right, 3 move down, 4 move left. | required |
Returns:
Name | Type | Description |
---|---|---|
state | State | |
timestep | TimeStep[Observation] | |
Source code in jumanji/environments/routing/connector/env.py
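The class description says that a point leaves an impassable trail behind it each time it moves. The sketch below illustrates that idea for one agent at a time on a plain Python grid. It is not the library's step function. It assumes the 1 + 3i / 2 + 3i / 3 + 3i cell encoding inferred from the example grid, ignores illegal moves, and does not model how simultaneous moves by several agents are resolved. `move_agent` is a hypothetical helper.

```python
MOVES = [(0, 0), (-1, 0), (0, 1), (1, 0), (0, -1)]  # No Op, Up, Right, Down, Left


def move_agent(grid, agent, action):
    """Return (new_grid, connected) after one agent acts; illegal moves are ignored."""
    size = len(grid)
    # Assumed encoding for zero-based agent i: trail, position, target.
    trail, position, target = 1 + 3 * agent, 2 + 3 * agent, 3 + 3 * agent
    row, col = next(
        (r, c) for r in range(size) for c in range(size) if grid[r][c] == position
    )
    d_row, d_col = MOVES[action]
    r, c = row + d_row, col + d_col
    legal = (
        action != 0
        and 0 <= r < size
        and 0 <= c < size
        and grid[r][c] in (0, target)
    )
    if not legal:
        return grid, False
    new_grid = [list(line) for line in grid]
    connected = grid[r][c] == target
    new_grid[row][col] = trail  # the vacated cell becomes an impassable trail
    new_grid[r][c] = position   # the agent now occupies the new cell
    return new_grid, connected


grid = [
    [4, 0, 1],
    [5, 0, 1],
    [6, 3, 2],
]
grid, connected_0 = move_agent(grid, agent=0, action=4)  # left, onto its target
grid, connected_1 = move_agent(grid, agent=1, action=3)  # down, onto its target
for line in grid:
    print(line)
```

Overwriting the target cell with the position value once an agent connects is a simplification made here; the text does not say how a connected agent is represented on the grid.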