MMST
Bases: Environment[State, MultiDiscreteArray, Observation]
The MMST (Multi Minimum Spanning Tree) environment consists of a random connected graph with groups of nodes (nodes of the same type) that need to be connected. The goal of the environment is to connect all nodes of the same type together without two agents using the same utility node (utility nodes are nodes that do not belong to any group).
Note: routing problems are randomly generated and may not be solvable!
Requirements: the nodes to connect (num_agents x num_nodes_per_agent) should make up at most 80% of all nodes, so that at least 20% of the graph remains available to build paths between the nodes we want to connect. An exception will be raised if num_agents x num_nodes_per_agent is greater than 0.8 x num_nodes.
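The node-count requirement can be expressed as a one-line check. This is a minimal sketch, assuming the rule is "num_agents x num_nodes_per_agent must not exceed 0.8 x num_nodes"; check_num_nodes is a hypothetical helper, not part of the Jumanji API, and the library performs its own validation.

```python
def check_num_nodes(num_nodes: int, num_agents: int, num_nodes_per_agent: int) -> None:
    """Raise if the nodes to connect exceed 80% of the graph (hypothetical helper)."""
    num_nodes_to_connect = num_agents * num_nodes_per_agent
    # Assumed rule: keep at least 20% of the nodes free to route paths through.
    if num_nodes_to_connect > 0.8 * num_nodes:
        raise ValueError(
            f"{num_nodes_to_connect} nodes to connect is more than 80% of "
            f"{num_nodes} nodes; increase num_nodes."
        )


check_num_nodes(36, 3, 4)  # 12 <= 28.8, passes
```

With 10 nodes, 3 agents and 3 nodes per agent, 9 nodes would need connecting, which is more than 0.8 x 10 = 8, so the check raises.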
observation: Observation
- node_types: jax array (int) of shape (num_nodes,): the component type of each node (-1 represents utility nodes).
- adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph.
- positions: jax array (int) of shape (num_agents,): the index of the last visited node.
- step_count: jax array (int) of shape (): integer to keep track of the number of steps.
- action_mask: jax array (bool) of shape (num_agents, num_nodes): binary mask (False/True <--> invalid/valid action).
reward: float
action: jax array (int) of shape (num_agents,) with values in [0, 1, ..., num_nodes - 1]: each agent selects the next node to which it wants to connect.
state: State
- node_type: jax array (int) of shape (num_nodes,): the component type of each node (-1 represents utility nodes).
- adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph.
- connected_nodes: jax array (int) of shape (num_agents, time_limit): the nodes visited by each agent; each visited node is recorded only once.
- connected_nodes_index: jax array (int) of shape (num_agents, num_nodes).
- position_index: jax array (int) of shape (num_agents,).
- node_edges: jax array (int) of shape (num_agents, num_nodes, num_nodes).
- positions: jax array (int) of shape (num_agents,): the index of the last visited node.
- action_mask: jax array (bool) of shape (num_agents, num_nodes): binary mask (False/True <--> invalid/valid action).
- finished_agents: jax array (bool) of shape (num_agents,).
- nodes_to_connect: jax array (int) of shape (num_agents, num_nodes_per_agent).
- step_count: step counter.
- time_limit: the number of steps allowed before an episode terminates.
- key: PRNG key for random sampling.
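The observation's action_mask and the per-agent action array fit together as shown in this minimal sketch, which assumes a uniform random policy over valid nodes. sample_valid_action is a hypothetical helper, not part of Jumanji, and the mask below is made-up example data.

```python
import jax
import jax.numpy as jnp


def sample_valid_action(key: jax.Array, action_mask: jax.Array) -> jax.Array:
    """Sample one valid node index per agent from a (num_agents, num_nodes) mask.

    Assumes each agent has at least one valid node in its row.
    """
    # Invalid nodes get -inf logits so they are never sampled; valid nodes are equally likely.
    logits = jnp.where(action_mask, 0.0, -jnp.inf)
    return jax.random.categorical(key, logits, axis=-1)  # shape (num_agents,)


# Hypothetical mask for 2 agents over 4 nodes.
action_mask = jnp.array([[False, True, True, False],
                         [True, False, False, False]])
action = sample_valid_action(jax.random.PRNGKey(0), action_mask)
```

The result has one entry per agent, matching the action shape (num_agents,); agent 1 has a single valid node, so it always selects node 0.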
Constants definitions:
Nodes
- INVALID_NODE = -1: used to check whether an agent selects an invalid node. A node may be invalid if it has no edge with the current node or if it is a utility node already selected by another agent.
- UTILITY_NODE = -1: utility node (belongs to no agent).
- EMPTY_NODE = -1: used for padding. state.connected_nodes stores the path (all the nodes) visited by an agent, so its size equals the step limit. This constant is used to initialise the array because 0 is a valid node index (the first node).
- DUMMY_NODE = -10: used for tie-breaking if multiple agents select the same node.
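The padding role of EMPTY_NODE can be illustrated with a minimal sketch of a per-agent path buffer. The write counter path_length and the helper record_visit are hypothetical and only illustrate why -1, not 0, is used as padding; they are not the environment's internal update.

```python
import jax.numpy as jnp

EMPTY_NODE = -1  # padding value from the constants above

num_agents, time_limit = 2, 5
# Paths start fully padded; 0 cannot serve as padding because it is a real node index.
connected_nodes = jnp.full((num_agents, time_limit), EMPTY_NODE, dtype=jnp.int32)
# Hypothetical per-agent counter of how many nodes have been recorded so far.
path_length = jnp.zeros((num_agents,), dtype=jnp.int32)


def record_visit(connected_nodes, path_length, agent, node):
    """Append `node` to `agent`'s path (hypothetical helper, no bounds checks)."""
    connected_nodes = connected_nodes.at[agent, path_length[agent]].set(node)
    path_length = path_length.at[agent].add(1)
    return connected_nodes, path_length


connected_nodes, path_length = record_visit(connected_nodes, path_length, agent=0, node=0)
connected_nodes, path_length = record_visit(connected_nodes, path_length, agent=0, node=3)
# Row 0 now reads [0, 3, -1, -1, -1]: the visit to node 0 is distinguishable from padding.
```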
Edges
- EMPTY_EDGE = -1: used for masking the edges array. state.node_edges is the graph's adjacency matrix, but instead of 0s and 1s it stores node indices, i.e. A_ij = j if nodes i and j share an edge and A_ij = -1 otherwise. Edges are also masked when a utility node is selected by an agent, to make that node inaccessible to other agents.
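The node-index encoding of node_edges and the masking of a claimed utility node can be sketched for a 3-node triangle graph. This assumes that masking means overwriting that node's column with EMPTY_EDGE in every other agent's copy; mask_node_for_others is a hypothetical helper, and the library's exact masking may differ.

```python
import jax.numpy as jnp

EMPTY_EDGE = -1

adj_matrix = jnp.array([[False, True, True],
                        [True, False, True],
                        [True, True, False]])
num_nodes = adj_matrix.shape[0]
num_agents = 2

# A_ij = j where an edge exists, EMPTY_EDGE (-1) otherwise.
edges = jnp.where(adj_matrix, jnp.arange(num_nodes)[None, :], EMPTY_EDGE)
# One copy per agent: shape (num_agents, num_nodes, num_nodes).
node_edges = jnp.broadcast_to(edges, (num_agents, num_nodes, num_nodes))


def mask_node_for_others(node_edges, agent, node):
    """Hide `node` from every agent except `agent` (hypothetical helper)."""
    others = jnp.arange(node_edges.shape[0]) != agent  # (num_agents,)
    # Blank out column `node` (the edges leading into it) for the other agents.
    column = node_edges[:, :, node]
    masked = jnp.where(others[:, None], EMPTY_EDGE, column)
    return node_edges.at[:, :, node].set(masked)


# Agent 0 claims utility node 2, so agent 1 can no longer reach it.
node_edges = mask_node_for_others(node_edges, agent=0, node=2)
```

Agent 0 keeps its copy [[-1, 1, 2], [0, -1, 2], [0, 1, -1]], while agent 1 now sees [[-1, 1, -1], [0, -1, -1], [0, 1, -1]].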
Actions encoding
- INVALID_CHOICE = -1
- INVALID_TIE_BREAK = -2
- INVALID_ALREADY_TRAVERSED = -3
Create the MMST environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
generator | Optional[Generator] | | None |
reward_fn | Optional[RewardFn] | class of type RewardFn. | None |
time_limit | int | the number of steps allowed before an episode terminates. Defaults to 70. | 70 |
viewer | Optional[Viewer[State]] | | None |
Source code in jumanji/environments/routing/mmst/env.py
action_spec: specs.MultiDiscreteArray (cached property)
Returns the action spec.
Returns:
Name | Type | Description |
---|---|---|
action_spec | MultiDiscreteArray | a MultiDiscreteArray spec. |
observation_spec: specs.Spec[Observation] (cached property)
Returns the observation spec.
Returns:
Type | Description |
---|---|
Spec[Observation] | Spec for the Observation. |
animate(states, interval=200, save_path=None)
Calls the environment renderer to animate a sequence of states.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
states | Sequence[State] | List of states to animate. | required |
interval | int | Time between frames in milliseconds, defaults to 200. | 200 |
save_path | Optional[str] | Optional path to save the animation. | None |
Source code in jumanji/environments/routing/mmst/env.py
get_finished_agents(state)
Get the done flags for each agent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node_types | | the environment state node_types. | required |
connected_nodes | | the agent-specific view of connected nodes. | required |
Returns: Array: array of boolean flags of shape (num_agents,).
Source code in jumanji/environments/routing/mmst/env.py
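One way to compute per-agent done flags is sketched below. It assumes that an agent is done once every node in its row of state.nodes_to_connect appears in its row of state.connected_nodes; this differs from the documented parameters above, and finished_flags is a hypothetical function, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def finished_flags(nodes_to_connect: jax.Array, connected_nodes: jax.Array) -> jax.Array:
    """Per-agent done flags: True once every target node appears in the agent's path.

    nodes_to_connect: (num_agents, num_nodes_per_agent)
    connected_nodes: (num_agents, time_limit), padded with EMPTY_NODE (-1)
    """
    def agent_done(targets, path):
        # Padding (-1) never matches a target, since targets are real node indices >= 0.
        return jnp.all(jnp.isin(targets, path))

    return jax.vmap(agent_done)(nodes_to_connect, connected_nodes)  # (num_agents,) bool


nodes_to_connect = jnp.array([[0, 4], [2, 5]])
connected_nodes = jnp.array([[0, 3, 4, -1], [2, -1, -1, -1]])
flags = finished_flags(nodes_to_connect, connected_nodes)  # [True, False]
```

Agent 0 has visited both of its targets (0 and 4), while agent 1 has not yet reached node 5.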
render(state)
Render the environment for a given state.
Returns:
Type | Description |
---|---|
Array | Array of RGB pixel values of shape (width, height, rgb). |
Source code in jumanji/environments/routing/mmst/env.py
reset(key)
Resets the environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | PRNGKey | used to randomly generate the problem and the different start nodes. | required |
Returns:
Name | Type | Description |
---|---|---|
state | State | State object corresponding to the new state of the environment. |
timestep | TimeStep[Observation] | TimeStep object corresponding to the first timestep returned by the environment. |
Source code in jumanji/environments/routing/mmst/env.py
step(state, action)
Run one timestep of the environment's dynamics.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | State | State object containing the dynamics of the environment. | required |
action | Array | Array containing the index of the next node to visit. | required |
Returns:
Type | Description |
---|---|
Tuple[State, TimeStep[Observation]] | state, timestep: Tuple[State, TimeStep] containing the next state of the environment, as well as the timestep to be observed. |
Source code in jumanji/environments/routing/mmst/env.py