GraphColoring

GraphColoring (Environment)

Environment for the GraphColoring problem. The problem is a combinatorial optimization task where the goal is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used.
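The constraint and the objective described above can be illustrated with a small dependency-free sketch. It is written in plain Python rather than JAX, and the helper names `is_valid_coloring` and `num_colors_used` are hypothetical illustrations, not part of the environment's API.

```python
from typing import Sequence


def is_valid_coloring(adj_matrix: Sequence[Sequence[bool]], colors: Sequence[int]) -> bool:
    """Return True if no edge joins two vertices that share a color."""
    n = len(colors)
    for u in range(n):
        for v in range(u + 1, n):
            if adj_matrix[u][v] and colors[u] == colors[v]:
                return False
    return True


def num_colors_used(colors: Sequence[int]) -> int:
    """Number of distinct colors in a complete assignment (the quantity to minimize)."""
    return len(set(colors))


# A triangle (vertices 0, 1, 2) with a pendant vertex 3 attached to vertex 2.
adj = [
    [False, True, True, False],
    [True, False, True, False],
    [True, True, False, True],
    [False, False, True, False],
]
good = [0, 1, 2, 0]  # every edge joins two different colors
bad = [0, 0, 1, 2]   # vertices 0 and 1 are adjacent and share color 0

good_is_valid = is_valid_coloring(adj, good)
bad_is_valid = is_valid_coloring(adj, bad)
good_count = num_colors_used(good)
```

Here `good` is a valid coloring that uses three colors, which is the fewest possible for this graph because the triangle alone needs three.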

  • observation: Observation

    • adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph.
    • colors: jax array (int32) of shape (num_nodes,), representing the current color assignments for the vertices.
    • action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment.
    • current_node_index: integer representing the current node being colored.
  • action: int, the color to be assigned to the current node (0 to num_nodes - 1)

  • reward: float, a sparse reward is provided at the end of the episode. Equals the negative of the number of unique colors used to color all vertices in the graph. If an invalid action is taken, the reward is the negative of the total number of colors.

  • episode termination:

    • if all nodes have been assigned a color or if an invalid action is taken.
  • state: State

    • adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph.
    • colors: jax array (int32) of shape (num_nodes,), color assigned to each node, -1 if not assigned.
    • current_node_index: jax array (int) with shape (), index of the current node.
    • action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment.
    • key: jax array (uint32) of shape (2,), random key used to generate random numbers at each step and for auto-reset.
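As a rough illustration of what `action_mask` encodes, the sketch below marks a color as valid when no already-colored neighbor of the current node uses it. This rule is an assumption inferred from the problem statement, the function name `valid_color_mask` is hypothetical, and the code is plain Python rather than the environment's JAX implementation.

```python
from typing import List, Sequence

UNASSIGNED = -1  # matches the "-1 if not assigned" convention of `colors`


def valid_color_mask(
    adj_matrix: Sequence[Sequence[bool]],
    colors: Sequence[int],
    node: int,
    num_colors: int,
) -> List[bool]:
    """mask[c] is True when no colored neighbor of `node` already uses color c."""
    mask = [True] * num_colors
    for neighbor, connected in enumerate(adj_matrix[node]):
        color = colors[neighbor]
        if connected and color != UNASSIGNED:
            mask[color] = False
    return mask


# Path graph 0 - 1 - 2, with nodes 0 and 2 already colored.
adj = [
    [False, True, False],
    [True, False, True],
    [False, True, False],
]
colors = [0, UNASSIGNED, 1]
mask = valid_color_mask(adj, colors, node=1, num_colors=3)
```

In this example node 1 touches a node of color 0 and a node of color 1, so only color 2 remains valid.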
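import jax
from jumanji.environments import GraphColoring

env = GraphColoring()
key = jax.random.PRNGKey(0)
# Reset the environment to obtain the initial state and first timestep.
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# Generate a placeholder action that conforms to the action spec, then take one step.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)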

observation_spec: jumanji.specs.Spec[jumanji.environments.logic.graph_coloring.types.Observation] cached property writable

Returns the observation spec.

Returns:

Type Description
Spec for the `Observation` whose fields are
  • adj_matrix: BoundedArray (bool) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph.
  • action_mask: BoundedArray (bool) of shape (num_nodes,). Represents the valid actions in the current state.
  • colors: BoundedArray (int32) of shape (num_nodes,). Represents the colors assigned to each node.
  • current_node_index: BoundedArray (int32) of shape (). Represents the index of the current node.

action_spec: DiscreteArray cached property writable

Specification of the action for the GraphColoring environment.

Returns:

Type Description
action_spec

specs.DiscreteArray object

__init__(self, generator: Optional[jumanji.environments.logic.graph_coloring.generator.Generator] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.logic.graph_coloring.types.State]] = None) special

Instantiate a GraphColoring environment.

Parameters:

Name Type Description Default
generator Optional[jumanji.environments.logic.graph_coloring.generator.Generator]

callable to instantiate environment instances. Defaults to RandomGenerator, which generates graphs with num_nodes equal to 20 and edge_probability equal to 0.8.

None
viewer Optional[jumanji.viewer.Viewer[jumanji.environments.logic.graph_coloring.types.State]]

environment viewer for rendering. Defaults to GraphColoringViewer.

None

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]]

Resets the environment to an initial state.

Returns:

Type Description
Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]]

The initial state and timestep.

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]]

Updates the environment state after the agent takes an action.

Specifically, this function allows the agent to choose a color for the current node (based on the action taken) in a graph coloring problem. It then updates the state of the environment based on the color chosen and calculates the reward based on the validity of the action and the completion of the coloring task.
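The transition described above might look roughly like the following plain-Python sketch. It is not the library's JAX implementation. It assumes that nodes are colored in index order, that there are as many available colors as nodes, and that an action is invalid when an adjacent node already has the chosen color. The name `step_sketch` and its return layout are hypothetical.

```python
from typing import List, Sequence, Tuple

UNASSIGNED = -1


def step_sketch(
    adj_matrix: Sequence[Sequence[bool]],
    colors: Sequence[int],
    node: int,
    action: int,
) -> Tuple[List[int], int, float, bool]:
    """Return (new_colors, next_node, reward, done) after coloring `node` with `action`."""
    num_colors = len(colors)  # assumption: one available color per node
    invalid = any(
        connected and colors[neighbor] == action
        for neighbor, connected in enumerate(adj_matrix[node])
    )
    if invalid:
        # Invalid action: the episode ends with the worst-case penalty.
        return list(colors), node, -float(num_colors), True
    new_colors = list(colors)
    new_colors[node] = action
    done = UNASSIGNED not in new_colors
    # Sparse reward: zero until every node is colored, then minus the colors used.
    reward = -float(len(set(new_colors))) if done else 0.0
    next_node = (node + 1) % len(colors)  # assumption: index-order traversal
    return new_colors, next_node, reward, done


# Path graph 0 - 1 - 2, colored with the actions 0, 1, 0.
adj = [
    [False, True, False],
    [True, False, True],
    [False, True, False],
]
colors, node = [UNASSIGNED] * 3, 0
for action in (0, 1, 0):
    colors, node, reward, done = step_sketch(adj, colors, node, action)

# Reusing color 0 on node 1 right after node 0 took it is an invalid action.
_, _, invalid_reward, invalid_done = step_sketch(adj, [0, UNASSIGNED, UNASSIGNED], 1, 0)
```

On this path graph the valid episode ends using two colors, while the invalid action ends the episode immediately with a penalty of minus three, the total number of colors.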

Parameters:

Name Type Description Default
state State

the current state of the environment.

required
action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]

the action taken by the agent.

required

Returns:

Type Description
state

the new state of the environment.

timestep

the next timestep.


Last update: 2024-11-01