GraphColoring
GraphColoring (Environment)
Environment for the GraphColoring problem. The problem is a combinatorial optimization task where the goal is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used.
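To make the objective concrete, the following pure-Python sketch checks whether an assignment is a proper coloring and counts the colors it uses. It does not call the Jumanji API; the helper names `is_proper_coloring` and `num_colors_used` and the example graph are illustrative assumptions.

```python
from typing import List, Sequence


def is_proper_coloring(adj_matrix: Sequence[Sequence[bool]], colors: Sequence[int]) -> bool:
    """Return True if no two adjacent vertices share a color."""
    n = len(colors)
    return all(
        not (adj_matrix[i][j] and colors[i] == colors[j])
        for i in range(n)
        for j in range(i + 1, n)
    )


def num_colors_used(colors: Sequence[int]) -> int:
    """Number of distinct colors in the assignment (the quantity to minimize)."""
    return len(set(colors))


# Hypothetical example graph: a triangle (0, 1, 2) with a pendant vertex 3 attached to 2.
adj_matrix: List[List[bool]] = [
    [False, True, True, False],
    [True, False, True, False],
    [True, True, False, True],
    [False, False, True, False],
]

proper = [0, 1, 2, 0]    # adjacent vertices always differ
improper = [0, 0, 1, 2]  # vertices 0 and 1 are adjacent and share color 0

print(is_proper_coloring(adj_matrix, proper), num_colors_used(proper))
print(is_proper_coloring(adj_matrix, improper))
```

A triangle forces at least three colors, so the proper assignment above is also optimal for this graph.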
- observation: Observation
    - adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph.
    - colors: jax array (int32) of shape (num_nodes,), representing the current color assignments for the vertices.
    - action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment.
    - current_node_index: integer, index of the node currently being colored.
- action: int, the color to be assigned to the current node (0 to num_nodes - 1).
- reward: float, a sparse reward given at the end of the episode, equal to the negative of the number of unique colors used to color all vertices in the graph. If an invalid action is taken, the reward is the negative of the total number of colors.
- episode termination:
    - if all nodes have been assigned a color or if an invalid action is taken.
- state: State
    - adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph.
    - colors: jax array (int32) of shape (num_nodes,), color assigned to each node, -1 if not assigned.
    - current_node_index: jax array (int) of shape (), index of the current node.
    - action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment.
    - key: jax array (uint32) of shape (2,), random key used to generate random numbers at each step and for auto-reset.
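As an illustration of how an `action_mask` of this shape could be derived, the sketch below marks a color as invalid when an already-colored neighbor of the current node uses it. This masking rule is an assumption, and the code is plain Python rather than JAX; `compute_action_mask` is a hypothetical helper, not part of Jumanji.

```python
from typing import List, Sequence

UNASSIGNED = -1  # follows the "-1 if not assigned" convention of `colors`


def compute_action_mask(
    adj_matrix: Sequence[Sequence[bool]],
    colors: Sequence[int],
    current_node_index: int,
    num_colors: int,
) -> List[bool]:
    """Return one bool per color: True if the color is not used by a colored neighbor."""
    neighbor_colors = {
        colors[j]
        for j, adjacent in enumerate(adj_matrix[current_node_index])
        if adjacent and colors[j] != UNASSIGNED
    }
    return [c not in neighbor_colors for c in range(num_colors)]


# Hypothetical path graph 0 - 1 - 2, with node 0 already colored 0.
adj_matrix = [
    [False, True, False],
    [True, False, True],
    [False, True, False],
]
colors = [0, UNASSIGNED, UNASSIGNED]

mask = compute_action_mask(adj_matrix, colors, current_node_index=1, num_colors=3)
print(mask)  # [False, True, True]
```

Node 1 neighbors node 0, so color 0 is masked out; node 2 is still unassigned and imposes no constraint.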
observation_spec: jumanji.specs.Spec[jumanji.environments.logic.graph_coloring.types.Observation]
cached property, writable
Returns the observation spec.
Returns:
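Type | Description |
---|---|
jumanji.specs.Spec[jumanji.environments.logic.graph_coloring.types.Observation] | Spec for the `Observation`, whose fields are `adj_matrix`, `colors`, `action_mask` and `current_node_index`. |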
action_spec: DiscreteArray
cached property, writable
Specification of the action for the `GraphColoring` environment.
Returns:
Type | Description |
---|---|
DiscreteArray | action_spec: a `specs.DiscreteArray` object. |
__init__(self, generator: Optional[jumanji.environments.logic.graph_coloring.generator.Generator] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.logic.graph_coloring.types.State]] = None)
special
Instantiate a `GraphColoring` environment.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
generator | Optional[jumanji.environments.logic.graph_coloring.generator.Generator] | Callable to instantiate environment instances. If `None`, a default generator is used. | None |
viewer | Optional[jumanji.viewer.Viewer[jumanji.environments.logic.graph_coloring.types.State]] | Environment viewer for rendering. If `None`, a default viewer is used. | None |
reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]]
Resets the environment to an initial state.
Returns:
Type | Description |
---|---|
Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]] | The initial state and timestep. |
step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]]
Updates the environment state after the agent takes an action.
Specifically, the action assigns a color to the current node. The function then updates the environment state with the chosen color and computes the reward from the validity of the action and whether the coloring is complete.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state | State | The current state of the environment. | required |
action | Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] | The action taken by the agent. | required |
Returns:
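Type | Description |
---|---|
Tuple[jumanji.environments.logic.graph_coloring.types.State, jumanji.types.TimeStep[jumanji.environments.logic.graph_coloring.types.Observation]] | state: the new state of the environment. timestep: the next timestep. |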
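The step semantics described above can be sketched with a toy, pure-Python stand-in. This is not Jumanji's implementation: `toy_step`, `valid_actions` and `greedy_episode` are hypothetical names, and the sketch assumes that nodes are colored in index order, that the number of colors equals the number of nodes, and that intermediate rewards are zero.

```python
from typing import List, Sequence, Tuple

UNASSIGNED = -1


def valid_actions(adj: Sequence[Sequence[bool]], colors: Sequence[int], node: int) -> List[bool]:
    """Assumed masking rule: a color is valid if no colored neighbor of `node` uses it."""
    used = {colors[j] for j, adjacent in enumerate(adj[node]) if adjacent and colors[j] != UNASSIGNED}
    return [c not in used for c in range(len(colors))]


def toy_step(
    adj: Sequence[Sequence[bool]], colors: List[int], node: int, action: int
) -> Tuple[List[int], int, float, bool]:
    """Return (colors, next node, reward, done) after coloring `node` with `action`."""
    num_colors = len(colors)
    if not valid_actions(adj, colors, node)[action]:
        # Invalid action: the episode ends with the negative total number of colors.
        return colors, node, -float(num_colors), True
    new_colors = list(colors)
    new_colors[node] = action
    next_node = node + 1
    done = next_node == num_colors
    # Sparse reward: only the final step reports the negative number of unique colors.
    reward = -float(len(set(new_colors))) if done else 0.0
    return new_colors, next_node, reward, done


def greedy_episode(adj: Sequence[Sequence[bool]]) -> Tuple[List[int], float]:
    """Color a non-empty graph by always picking the lowest valid color."""
    colors = [UNASSIGNED] * len(adj)
    node, reward, done = 0, 0.0, False
    while not done:
        action = valid_actions(adj, colors, node).index(True)
        colors, node, reward, done = toy_step(adj, colors, node, action)
    return colors, reward


# Hypothetical graph: a triangle (0, 1, 2) with a pendant vertex 3 attached to 2.
adj = [
    [False, True, True, False],
    [True, False, True, False],
    [True, True, False, True],
    [False, False, True, False],
]
final_colors, final_reward = greedy_episode(adj)
print(final_colors, final_reward)  # [0, 1, 2, 0] -3.0
```

Giving node 1 the color already held by its neighbor node 0, as in `toy_step(adj, [0, -1, -1, -1], 1, 0)`, ends the toy episode immediately with reward -4.0.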