
MMST

Bases: Environment[State, MultiDiscreteArray, Observation]

The MMST (Multi Minimum Spanning Tree) environment consists of a random connected graph with groups of nodes (of the same node type) that need to be connected. The goal of the environment is to connect all nodes of the same type together, without different agents using the same utility nodes (nodes that do not belong to any group).

Note: routing problems are randomly generated and may not be solvable!

Requirements: the total number of nodes should be at least 20% more than the number of nodes to connect, so that enough utility nodes remain to build a path through all the nodes each agent must connect. An exception is raised at construction time if this requirement is not met.
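As a quick sanity check of this requirement, here is the arithmetic for the default configuration (a minimal sketch; the exact exception condition lives in the generator):

num_agents, num_nodes_per_agent, num_nodes = 3, 4, 36  # default configuration
nodes_to_connect = num_agents * num_nodes_per_agent    # 12 nodes to connect
# "At least 20% more" means num_nodes >= 1.2 * nodes_to_connect, i.e. 36 >= 14.4,
# which the default configuration easily satisfies.
assert num_nodes >= 1.2 * nodes_to_connect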

  • observation: Observation

    • node_types: jax array (int) of shape (num_nodes,): the component type of each node (-1 represents utility nodes).
    • adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph.
    • positions: jax array (int) of shape (num_agents,): the index of the last visited node.
    • step_count: jax array (int) of shape (): integer to keep track of the number of steps.
    • action_mask: jax array (bool) of shape (num_agents, num_nodes): binary mask (False/True <--> invalid/valid action).
  • reward: float

  • action: jax array (int) of shape (num_agents,) with values in [0, 1, ..., num_nodes-1]: each agent selects the next node it wants to connect to.

  • state: State

    • node_types: jax array (int) of shape (num_nodes,): the component type of each node (-1 represents utility nodes).
    • adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph.
    • connected_nodes: jax array (int) of shape (num_agents, time_limit): the nodes visited by each agent; each node visit is counted only once.
    • connected_nodes_index: jax array (int) of shape (num_agents, num_nodes).
    • position_index: jax array (int) of shape (num_agents,).
    • node_edges: jax array (int) of shape (num_agents, num_nodes, num_nodes): per-agent encoding of the graph's edges (see EMPTY_EDGE under constants definitions).
    • positions: jax array (int) of shape (num_agents,). the index of the last visited node.
    • action_mask: jax array (bool) of shape (num_agents, num_nodes): binary mask (False/True <--> invalid/valid action).
    • finished_agents: jax array (bool) of shape (num_agents,).
    • nodes_to_connect: jax array (int) of shape (num_agents, num_nodes_per_agent).
    • step_count: integer to keep track of the number of steps.
    • time_limit: the number of steps allowed before an episode terminates.
    • key: PRNG key used for random sampling.
  • constants definitions:

    • Nodes

      • INVALID_NODE = -1: used to check if an agent selects an invalid node. A node may be invalid if it has no edge to the current node or if it is a utility node already selected by another agent.
      • UTILITY_NODE = -1: utility node (belongs to no agent).
      • EMPTY_NODE = -1: used for padding. state.connected_nodes stores the path (all the nodes) visited by an agent, so it has size equal to the step limit. We use this constant to initialise the array, since 0 is a valid node index (it represents the first node).
      • DUMMY_NODE = -10: used for tie-breaking if multiple agents select the same node.
    • Edges

      • EMPTY_EDGE = -1: used for masking the edges array. state.node_edges encodes the graph's adjacency, but instead of 0s and 1s it stores node values, i.e. A_ij = j if an edge exists and A_ij = -1 otherwise (see the sketch after this list). Edges are also masked when a utility node is selected by an agent, so that it becomes inaccessible to other agents.
    • Actions encoding

      • INVALID_CHOICE = -1
      • INVALID_TIE_BREAK = -2
      • INVALID_ALREADY_TRAVERSED = -3
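To make the node_edges encoding concrete, here is a minimal sketch on a made-up 3-node graph (not taken from the environment code):

import jax.numpy as jnp

# Hypothetical path graph with edges 0-1 and 1-2.
adj_matrix = jnp.array(
    [[0, 1, 0],
     [1, 0, 1],
     [0, 1, 0]],
    dtype=bool,
)
# Encode each edge by the neighbour's index, EMPTY_EDGE (-1) otherwise.
node_edges = jnp.where(adj_matrix, jnp.arange(3)[None, :], -1)
# node_edges:
# [[-1  1 -1]
#  [ 0 -1  2]
#  [-1  1 -1]]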
A minimal usage example:

import jax

from jumanji.environments import MMST
env = MMST()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
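Rather than the constant dummy action above, one can sample a random valid node per agent from the observation's action_mask. The helper below is a sketch (random_valid_action is not part of the environment's API):

import jax
import jax.numpy as jnp

def random_valid_action(key, action_mask):
    """Sample one valid node per agent from a (num_agents, num_nodes) bool mask."""
    logits = jnp.where(action_mask, 0.0, -jnp.inf)  # invalid actions get -inf
    keys = jax.random.split(key, action_mask.shape[0])
    return jax.vmap(jax.random.categorical)(keys, logits)

key, action_key = jax.random.split(key)
action = random_valid_action(action_key, timestep.observation.action_mask)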

Create the MMST environment.

Parameters:

  • generator (Optional[Generator], default None): Generator whose __call__ instantiates an environment instance. Implemented options are [SplitRandomGenerator]. Defaults to SplitRandomGenerator(num_nodes=36, num_edges=72, max_degree=5, num_agents=3, num_nodes_per_agent=4, max_step=time_limit).
  • reward_fn (Optional[RewardFn], default None): class of type RewardFn, whose __call__ is used as the reward function. Implemented options are [DenseRewardFn]. Defaults to DenseRewardFn(reward_values=(10.0, -1.0, -1.0)).
  • time_limit (int, default 70): the number of steps allowed before an episode terminates.
  • viewer (Optional[Viewer[State]], default None): Viewer used for rendering. Defaults to MMSTViewer.
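A non-default configuration can be built by passing a generator explicitly. The sketch below assumes SplitRandomGenerator is importable from the environment's generator module (import path inferred from the source path shown below):

from jumanji.environments.routing.mmst.generator import SplitRandomGenerator

env = MMST(
    generator=SplitRandomGenerator(
        num_nodes=50,
        num_edges=100,
        max_degree=5,
        num_agents=2,
        num_nodes_per_agent=5,
        max_step=100,
    ),
    time_limit=100,
)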
Source code in jumanji/environments/routing/mmst/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    reward_fn: Optional[RewardFn] = None,
    time_limit: int = 70,
    viewer: Optional[Viewer[State]] = None,
):
    """Create the `MMST` environment.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment instance.
            Implemented options are [`SplitRandomGenerator`].
            Defaults to `SplitRandomGenerator(num_nodes=36, num_edges=72, max_degree=5,
            num_agents=3, num_nodes_per_agent=4, max_step=time_limit)`.
        reward_fn: class of type `RewardFn`, whose `__call__` is used as a reward function.
            Implemented options are [`DenseRewardFn`].
            Defaults to `DenseRewardFn(reward_values=(10.0, -1.0, -1.0))`.
        time_limit: the number of steps allowed before an episode terminates. Defaults to 70.
        viewer: `Viewer` used for rendering. Defaults to `MMSTViewer`
    """

    self._generator = generator or SplitRandomGenerator(
        num_nodes=36,
        num_edges=72,
        max_degree=5,
        num_agents=3,
        num_nodes_per_agent=4,
        max_step=time_limit,
    )

    self.num_agents = self._generator.num_agents
    self.num_nodes = self._generator.num_nodes
    self.num_nodes_per_agent = self._generator.num_nodes_per_agent

    self._reward_fn = reward_fn or DenseRewardFn(reward_values=(10.0, -1.0, -1.0))

    self._env_viewer = viewer or MMSTViewer(num_agents=self.num_agents)
    self.time_limit = time_limit
    super().__init__()

action_spec: specs.MultiDiscreteArray cached property #

Returns the action spec.

Returns:

  • action_spec (MultiDiscreteArray): a specs.MultiDiscreteArray spec.
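The spec can be inspected and exercised directly; generate_value and validate follow the usual dm_env-style spec interface (a sketch):

spec = env.action_spec
dummy_action = spec.generate_value()  # minimal valid action, shape (num_agents,)
spec.validate(dummy_action)           # raises if the action is out of bounds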

observation_spec: specs.Spec[Observation] cached property #

Returns the observation spec.

Returns:

  • Spec[Observation]: Spec for the Observation whose fields are:

    • node_types: BoundedArray (int32) of shape (num_nodes,).
    • adj_matrix: BoundedArray (int) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph.
    • positions: BoundedArray (int32) of shape (num_agents,). Current node position of each agent.
    • action_mask: BoundedArray (bool) of shape (num_agents, num_nodes). Represents the valid actions in the current state.

animate(states, interval=200, save_path=None) #

Calls the environment renderer to animate a sequence of states.

Parameters:

  • states (Sequence[State], required): List of states to animate.
  • interval (int, default 200): Time between frames in milliseconds.
  • save_path (Optional[str], default None): Optional path to save the animation.
Source code in jumanji/environments/routing/mmst/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Calls the environment renderer to animate a sequence of states.

    Args:
        states: List of states to animate.
        interval: Time between frames in milliseconds, defaults to 200.
        save_path: Optional path to save the animation.
    """
    return self._env_viewer.animate(states, interval, save_path)
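A typical pattern is to collect states from a short rollout and animate them afterwards (a sketch reusing the reset/step API shown above; the output file name is arbitrary):

states = []
state, timestep = jax.jit(env.reset)(key)
states.append(state)
for _ in range(20):
    action = env.action_spec.generate_value()
    state, timestep = jax.jit(env.step)(state, action)
    states.append(state)

animation = env.animate(states, interval=200, save_path="mmst_rollout.gif")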

get_finished_agents(state) #

Get the done flags for each agent.

Parameters:

  • state (State, required): the environment state, containing each agent's nodes_to_connect and connected_nodes.

Returns:

  • Array: array of boolean flags of shape (num_agents,).

Source code in jumanji/environments/routing/mmst/env.py
def get_finished_agents(self, state: State) -> chex.Array:
    """Get the done flags for each agent.

    Args:
        state: the environment state, containing each agent's `nodes_to_connect`
            and `connected_nodes`.

    Returns:
        Array: array of boolean flags of shape (num_agents,).
    """

    def done_fun(nodes: chex.Array, connected_nodes: chex.Array, n_comps: int) -> jnp.bool_:
        connects = jnp.isin(nodes, connected_nodes)
        return jnp.sum(connects) == n_comps

    finished_agents = jax.vmap(done_fun, in_axes=(0, 0, None))(
        state.nodes_to_connect,
        state.connected_nodes,
        self.num_nodes_per_agent,
    )

    return finished_agents
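To illustrate done_fun on concrete values (a made-up example; -1 is the EMPTY_NODE padding documented above):

import jax.numpy as jnp

nodes_to_connect = jnp.array([3, 7, 11, 15])        # one agent's target nodes
connected_nodes = jnp.array([3, 5, 7, -1, -1, -1])  # path so far, padded with -1
connects = jnp.isin(nodes_to_connect, connected_nodes)  # [True, True, False, False]
done = jnp.sum(connects) == 4  # False: only 2 of the 4 targets are connected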

render(state) #

Render the environment for a given state.

Returns:

  • Array: array of rgb pixel values of shape (width, height, rgb).

Source code in jumanji/environments/routing/mmst/env.py
def render(self, state: State) -> chex.Array:
    """Render the environment for a given state.

    Returns:
        Array of rgb pixel values in the shape (width, height, rgb).
    """
    return self._env_viewer.render(state)

reset(key) #

Resets the environment.

Parameters:

  • key (PRNGKey, required): used to randomly generate the problem and the different start nodes.

Returns:

  • state (State): State object corresponding to the new state of the environment.
  • timestep (TimeStep[Observation]): TimeStep object corresponding to the first timestep returned by the environment.
Source code in jumanji/environments/routing/mmst/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: used to randomly generate the problem and the different start nodes.

    Returns:
         state: State object corresponding to the new state of the environment.
         timestep: TimeStep object corresponding to the first timestep returned by the
            environment.
    """

    key, problem_key = jax.random.split(key)
    state = self._generator(problem_key)
    extras = self._get_extras(state)
    timestep = restart(observation=self._state_to_observation(state), extras=extras)
    return state, timestep
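Because reset is a pure function of the key, it can be vectorised over a batch of keys with jax.vmap (a sketch; the batch size of 8 is arbitrary):

keys = jax.random.split(jax.random.PRNGKey(0), 8)
states, timesteps = jax.vmap(env.reset)(keys)  # a batch of initial states/timesteps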

step(state, action) #

Run one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Array, required): Array containing the index of the next node to visit for each agent.

Returns:

  • state, timestep (Tuple[State, TimeStep[Observation]]): the next state of the environment, as well as the timestep to be observed.

Source code in jumanji/environments/routing/mmst/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    Args:
        state: State object containing the dynamics of the environment.
        action: Array containing the index of the next node to visit.

    Returns:
        state, timestep: Tuple[State, TimeStep] containing the next state of the
           environment, as well as the timestep to be observed.
    """

    def step_agent_fn(
        connected_nodes: chex.Array,
        conn_index: chex.Array,
        action: chex.Array,
        node: int,
        indices: chex.Array,
        agent_id: int,
    ) -> Tuple[chex.Array, ...]:
        is_invalid_choice = jnp.any(action == INVALID_CHOICE) | jnp.any(
            action == INVALID_TIE_BREAK
        )
        is_valid = (~is_invalid_choice) & (node != INVALID_NODE)
        connected_nodes, conn_index, new_node, indices = jax.lax.cond(
            is_valid,
            self._update_conected_nodes,
            lambda *_: (
                connected_nodes,
                conn_index,
                state.positions[agent_id],
                indices,
            ),
            connected_nodes,
            conn_index,
            node,
            indices,
        )

        return connected_nodes, conn_index, new_node, indices

    key, step_key = jax.random.split(state.key)
    action, next_nodes = self._trim_duplicated_invalid_actions(state, action, step_key)

    connected_nodes = jnp.zeros_like(state.connected_nodes)
    connected_nodes_index = jnp.zeros_like(state.connected_nodes_index)
    agents_pos = jnp.zeros_like(state.positions)
    position_index = jnp.zeros_like(state.position_index)

    for agent in range(self.num_agents):
        conn_nodes_i, conn_nodes_id, pos_i, pos_ind = step_agent_fn(
            state.connected_nodes[agent],
            state.connected_nodes_index[agent],
            action[agent],
            next_nodes[agent],
            state.position_index[agent],
            agent,
        )

        connected_nodes = connected_nodes.at[agent].set(conn_nodes_i)
        connected_nodes_index = connected_nodes_index.at[agent].set(conn_nodes_id)
        agents_pos = agents_pos.at[agent].set(pos_i)
        position_index = position_index.at[agent].set(pos_ind)

    active_node_edges = update_active_edges(
        self.num_agents, state.node_edges, agents_pos, state.node_types
    )

    state = State(
        node_types=state.node_types,
        adj_matrix=state.adj_matrix,
        nodes_to_connect=state.nodes_to_connect,
        connected_nodes=connected_nodes,
        connected_nodes_index=connected_nodes_index,
        position_index=position_index,
        positions=agents_pos,
        node_edges=active_node_edges,
        action_mask=make_action_mask(
            self.num_agents,
            self.num_nodes,
            active_node_edges,
            agents_pos,
            state.finished_agents,
        ),
        finished_agents=state.finished_agents,  # Not updated yet.
        step_count=state.step_count,
        key=key,
    )

    state, timestep = self._state_to_timestep(state, action)
    return state, timestep
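Putting reset and step together, a whole episode can be run under jit with jax.lax.scan. The sketch below reuses the hypothetical random_valid_action helper from the usage example near the top of this page:

import jax

def rollout(env, key, num_steps):
    state, timestep = env.reset(key)

    def body(carry, _):
        state, key = carry
        key, action_key = jax.random.split(key)
        action = random_valid_action(action_key, state.action_mask)
        state, timestep = env.step(state, action)
        return (state, key), timestep.reward

    (final_state, _), rewards = jax.lax.scan(body, (state, key), None, length=num_steps)
    return final_state, rewards

rollout_fn = jax.jit(lambda k: rollout(env, k, num_steps=70))
final_state, rewards = rollout_fn(jax.random.PRNGKey(0))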