MMST

MMST (Environment)

The MMST (Multi Minimum Spanning Tree) environment consists of a random connected graph containing groups of nodes (nodes of the same type) that need to be connected. The goal is to connect all nodes of the same type to each other without two groups using the same utility node (a node that does not belong to any group).

Note: routing problems are randomly generated and may not be solvable!

Requirements: The total number of nodes should be at least 20% more than the number of nodes to connect, so that enough nodes remain to build a path through all the nodes to connect. An exception is raised if the number of nodes is not greater than (0.8 x num_agents x num_nodes_per_agent).
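
As a rough illustration of the exception condition quoted above, the check can be sketched in plain Python. The helper name check_node_budget is hypothetical and not part of the library, and the sketch follows the stated condition literally (num_nodes must exceed 0.8 x num_agents x num_nodes_per_agent); the library's own validation may differ.

```python
def check_node_budget(num_nodes: int, num_agents: int, num_nodes_per_agent: int) -> float:
    """Hypothetical helper mirroring the exception condition stated in the text.

    Returns the threshold that num_nodes must exceed, or raises ValueError.
    """
    nodes_to_connect = num_agents * num_nodes_per_agent
    threshold = 0.8 * nodes_to_connect
    if not num_nodes > threshold:
        raise ValueError(
            f"num_nodes={num_nodes} must be greater than "
            f"0.8 * {num_agents} * {num_nodes_per_agent} = {threshold:.1f}"
        )
    return threshold


# Sizes taken from the default generator described in the constructor docs:
# 36 nodes, 3 agents, 4 nodes per agent.
threshold = check_node_budget(num_nodes=36, num_agents=3, num_nodes_per_agent=4)
print(f"36 nodes pass the check; the threshold is {threshold:.1f}")
```

With the default sizes there are 12 nodes to connect out of 36, so the check passes with a wide margin.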

  • observation: Observation

    • node_types: jax array (int) of shape (num_nodes): the component type of each node (-1 represents utility nodes).
    • adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph.
    • positions: jax array (int) of shape (num_agents,): the index of the last visited node.
    • step_count: jax array (int) of shape (): integer to keep track of the number of steps.
    • action_mask: jax array (bool) of shape (num_agents, num_nodes): binary mask (False/True <--> invalid/valid action).
  • reward: float

  • action: jax array (int) of shape (num_agents,) with values in [0, 1, ..., num_nodes - 1]. Each agent selects the next node to which it wants to connect.

  • state: State

    • node_types: jax array (int) of shape (num_nodes,): the component type of each node (-1 represents utility nodes).
    • adj_matrix: jax array (bool) of shape (num_nodes, num_nodes): adjacency matrix of the graph.
    • connected_nodes: jax array (int) of shape (num_agents, time_limit): the nodes visited by each agent. Each node visit is counted only once.
    • connected_nodes_index: jax array (int) of shape (num_agents, num_nodes).
    • position_index: jax array (int) of shape (num_agents,).
    • node_edges: jax array (int) of shape (num_agents, num_nodes, num_nodes).
    • positions: jax array (int) of shape (num_agents,): the index of the last visited node.
    • action_mask: jax array (bool) of shape (num_agents, num_nodes): binary mask (False/True <--> invalid/valid action).
    • finished_agents: jax array (bool) of shape (num_agents,).
    • nodes_to_connect: jax array (int) of shape (num_agents, num_nodes_per_agent).
    • step_count: step counter.
    • time_limit: the number of steps allowed before an episode terminates.
    • key: PRNG key for random sample.
  • constants definitions:

    • Nodes

      • INVALID_NODE = -1: used to check if an agent selects an invalid node. A node may be invalid if it has no edge with the current node or if it is a utility node already selected by another agent.
      • UTILITY_NODE = -1: utility node (belongs to no agent).
      • EMPTY_NODE = -1: used for padding. state.connected_nodes stores the path (all the nodes) visited by an agent. Hence it has size equal to the step limit. We use this constant to initialise this array since 0 represents the first node.
      • DUMMY_NODE = -10: used for tie-breaking if multiple agents select the same node.
    • Edges

      • EMPTY_EDGE = -1: used for masking the edges array. state.node_edges is the graph's adjacency matrix, but instead of 0s and 1s it stores node values, i.e. A_ij = j or A_ij = -1. Edges are also masked when a utility node is selected by an agent, to make that node inaccessible to other agents.
    • Actions encoding

      • INVALID_CHOICE = -1
      • INVALID_TIE_BREAK = -2
      • INVALID_ALREADY_TRAVERSED = -3
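
The EMPTY_EDGE encoding described above can be illustrated with a small plain-Python sketch. It is not the library's implementation: the helper names encode_node_edges and claim_utility_node are hypothetical, and the masking rule (blanking the claimed node's column in every other agent's copy) is an assumption based on the description of state.node_edges and its shape (num_agents, num_nodes, num_nodes).

```python
EMPTY_EDGE = -1  # value quoted in the constants list


def encode_node_edges(adj_matrix):
    """Turn a 0/1 (or bool) adjacency matrix into the A_ij = j / -1 encoding."""
    n = len(adj_matrix)
    return [
        [j if adj_matrix[i][j] else EMPTY_EDGE for j in range(n)]
        for i in range(n)
    ]


def claim_utility_node(node_edges, owner, node):
    """Hide `node` from every agent except `owner` (assumed masking rule).

    node_edges holds one encoded matrix per agent; a new structure is returned.
    """
    masked = []
    for agent, agent_edges in enumerate(node_edges):
        if agent == owner:
            masked.append([row[:] for row in agent_edges])
        else:
            masked.append(
                [[EMPTY_EDGE if j == node else v for j, v in enumerate(row)]
                 for row in agent_edges]
            )
    return masked


# A 4-node path graph 0 - 1 - 2 - 3, shared by two agents.
adj = [
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
]
node_edges = [encode_node_edges(adj) for _ in range(2)]
node_edges = claim_utility_node(node_edges, owner=0, node=2)
print(node_edges[0][1])  # agent 0 still sees the edge 1 -> 2
print(node_edges[1][1])  # agent 1 no longer does
```

Storing the target node index rather than a 1 means that a row of the matrix doubles as the list of reachable nodes, with -1 marking entries that cannot be chosen.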
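
import jax
from jumanji.environments import MMST

# Create the environment with its default generator and reward function.
env = MMST()
key = jax.random.PRNGKey(0)
# Reset to sample a new problem instance, then render it.
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# Take one step with a placeholder action generated from the action spec.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)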
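
The snippet above steps with a placeholder action taken from the action spec, which does not consult the action mask. As a minimal sketch of how a mask of shape (num_agents, num_nodes) could be used to pick one valid node per agent, here is a plain-Python version operating on nested lists. The helper sample_valid_actions is hypothetical, and falling back to node 0 when an agent has no valid action is an assumption made only to keep the sketch total.

```python
import random


def sample_valid_actions(action_mask, rng):
    """Pick one valid node index per agent from a boolean mask.

    action_mask: list of rows, one per agent; True marks a valid node.
    rng: a random.Random instance, so that sampling is reproducible.
    """
    actions = []
    for row in action_mask:
        valid_nodes = [node for node, is_valid in enumerate(row) if is_valid]
        # Assumption: fall back to node 0 when no action is valid.
        actions.append(rng.choice(valid_nodes) if valid_nodes else 0)
    return actions


# Toy mask for 3 agents and 3 nodes.
mask = [
    [False, True, False],   # agent 0 can only move to node 1
    [True, False, True],    # agent 1 can move to node 0 or node 2
    [False, False, False],  # agent 2 has no valid move
]
actions = sample_valid_actions(mask, random.Random(0))
print(actions)
```

The returned list has one entry per agent, matching the (num_agents,) shape expected for the action.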

action_spec: MultiDiscreteArray (cached property, writable)

Returns the action spec.

Returns:

  • action_spec: a specs.MultiDiscreteArray spec.

observation_spec: jumanji.specs.Spec[jumanji.environments.routing.mmst.types.Observation] (cached property, writable)

Returns the observation spec.

Returns:

Spec for the `Observation` whose fields are:
  • node_types: BoundedArray (int32) of shape (num_nodes,).
  • adj_matrix: BoundedArray (int) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph.
  • positions: BoundedArray (int32) of shape (num_agents,). Current node position of each agent.
  • action_mask: BoundedArray (bool) of shape (num_agents, num_nodes). Represents the valid actions in the current state.

__init__(self, generator: Optional[jumanji.environments.routing.mmst.generator.Generator] = None, reward_fn: Optional[jumanji.environments.routing.mmst.reward.RewardFn] = None, time_limit: int = 70, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.mmst.types.State]] = None) (special)

Create the MMST environment.

Parameters:

  • generator (Optional[jumanji.environments.routing.mmst.generator.Generator], default None): Generator whose __call__ instantiates an environment instance. Implemented options are [SplitRandomGenerator]. Defaults to SplitRandomGenerator(num_nodes=36, num_edges=72, max_degree=5, num_agents=3, num_nodes_per_agent=4, max_step=time_limit).
  • reward_fn (Optional[jumanji.environments.routing.mmst.reward.RewardFn], default None): class of type RewardFn, whose __call__ is used as a reward function. Implemented options are [DenseRewardFn]. Defaults to DenseRewardFn(reward_values=(10.0, -1.0, -1.0)).
  • time_limit (int, default 70): the number of steps allowed before an episode terminates.
  • viewer (Optional[jumanji.viewer.Viewer[jumanji.environments.routing.mmst.types.State]], default None): Viewer used for rendering. Defaults to MMSTViewer.

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.mmst.types.State, jumanji.types.TimeStep[jumanji.environments.routing.mmst.types.Observation]]

Resets the environment.

Parameters:

  • key (PRNGKeyArray, required): used to randomly generate the problem and the different start nodes.

Returns:

  • state: State object corresponding to the new state of the environment.
  • timestep: TimeStep object corresponding to the first timestep returned by the environment.

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.routing.mmst.types.State, jumanji.types.TimeStep[jumanji.environments.routing.mmst.types.Observation]]

Run one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], required): Array containing the index of the next node to visit.

Returns:

  • state, timestep: Tuple[State, TimeStep] containing the next state of the environment, as well as the timestep to be observed.

render(self, state: State) -> Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]

Render the environment for a given state.

Returns:

  • Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]: Array of rgb pixel values in the shape (width, height, rgb).


Last update: 2024-11-01