GraphColoring

Bases: Environment[State, DiscreteArray, Observation]

Environment for the GraphColoring problem. The problem is a combinatorial optimization task where the goal is to assign a color to each vertex of a graph in such a way that no two adjacent vertices share the same color. The problem is usually formulated as minimizing the number of colors used.
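To make the constraint concrete, here is a small sketch (not part of the environment's API; the helper name is illustrative) that checks whether a full color assignment is proper, given a boolean adjacency matrix:

import jax.numpy as jnp

def is_proper_coloring(adj_matrix, colors):
    # Two nodes conflict if they are adjacent and share a color.
    same_color = colors[:, None] == colors[None, :]
    conflicts = adj_matrix & same_color
    return ~jnp.any(conflicts)

adj = jnp.array([[False, True], [True, False]])  # a single edge between nodes 0 and 1
print(is_proper_coloring(adj, jnp.array([0, 1])))  # True: endpoints differ
print(is_proper_coloring(adj, jnp.array([0, 0])))  # False: endpoints clash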

  • observation: Observation

    • adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph.
    • colors: jax array (int32) of shape (num_nodes,), representing the current color assignments for the vertices.
    • action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment.
    • current_node_index: integer representing the current node being colored.
  • action: int, the color to be assigned to the current node (from 0 to num_nodes - 1).

  • reward: float, a sparse reward is provided at the end of the episode. It equals the negative of the number of unique colors used to color all vertices in the graph. If an invalid action is taken, the reward is the negative of the total number of colors (see the sketch after this list).

  • episode termination:

    • if all nodes have been assigned a color or if an invalid action is taken.
  • state: State

    • adj_matrix: jax array (bool) of shape (num_nodes, num_nodes), representing the adjacency matrix of the graph.
    • colors: jax array (int32) of shape (num_nodes,), color assigned to each node, -1 if not assigned.
    • current_node_index: jax array (int) with shape (), index of the current node.
    • action_mask: jax array (bool) of shape (num_colors,), indicating which actions are valid in the current state of the environment.
    • key: jax array (uint32) of shape (2,), random key used to generate random numbers at each step and for auto-reset.
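The sparse reward can be reproduced outside the environment. A minimal sketch mirroring the computation in step below (the helper name is illustrative; colors are -1 until assigned):

import jax.numpy as jnp

def episode_reward(colors, num_nodes):
    # Count the distinct colors actually used, ignoring the -1 "unassigned" fill.
    unique_colors = jnp.unique(colors, size=num_nodes, fill_value=-1)
    return -jnp.count_nonzero(unique_colors >= 0).astype(float)

print(episode_reward(jnp.array([0, 1, 0, 2]), num_nodes=4))  # -3.0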
A quick example of interacting with the environment:

import jax
from jumanji.environments import GraphColoring

env = GraphColoring()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
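
Beyond a single transition, a whole episode can be rolled out by stepping until the timestep reports termination. A sketch with a uniformly random policy over the action mask (illustrative, not part of the library):

import jax
import jax.numpy as jnp
from jumanji.environments import GraphColoring

env = GraphColoring()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

step_fn = jax.jit(env.step)
while not timestep.last():
    # Sample uniformly among the colors allowed by the action mask.
    key, subkey = jax.random.split(key)
    logits = jnp.where(timestep.observation.action_mask, 0.0, -jnp.inf)
    action = jax.random.categorical(subkey, logits)
    state, timestep = step_fn(state, action)

print(timestep.reward)  # negative of the number of colors used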

Instantiate a GraphColoring environment.

Parameters:

  • generator (Optional[Generator], default None): callable to instantiate environment instances. Defaults to RandomGenerator, which generates graphs with 20 num_nodes and edge_probability equal to 0.8.
  • viewer (Optional[Viewer[State]], default None): environment viewer for rendering. Defaults to GraphColoringViewer.
Source code in jumanji/environments/logic/graph_coloring/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    viewer: Optional[Viewer[State]] = None,
):
    """Instantiate a `GraphColoring` environment.

    Args:
        generator: callable to instantiate environment instances.
            Defaults to `RandomGenerator` which generates graphs with
            20 `num_nodes` and `edge_probability` equal to 0.8.
        viewer: environment viewer for rendering.
            Defaults to `GraphColoringViewer`.
    """
    self.generator = generator or RandomGenerator(num_nodes=20, edge_probability=0.8)
    self.num_nodes = self.generator.num_nodes
    super().__init__()

    # Create viewer used for rendering
    self._env_viewer = viewer or GraphColoringViewer(name="GraphColoring")
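
For example, a smaller and sparser instance can be configured through the generator. A brief sketch, assuming RandomGenerator is importable from the environment's generator module:

from jumanji.environments import GraphColoring
# Assumed import path; adjust if the package layout differs.
from jumanji.environments.logic.graph_coloring.generator import RandomGenerator

env = GraphColoring(generator=RandomGenerator(num_nodes=10, edge_probability=0.5))
print(env.num_nodes)  # 10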

action_spec: specs.DiscreteArray cached property #

Specification of the action for the GraphColoring environment.

Returns:

  • action_spec (DiscreteArray): a specs.DiscreteArray object.
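
A brief usage sketch, assuming the spec exposes the usual generate_value and validate helpers:

from jumanji.environments import GraphColoring

env = GraphColoring()
spec = env.action_spec
action = spec.generate_value()  # a placeholder action of the right shape and dtype
spec.validate(action)           # raises if the value does not fit the spec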

observation_spec: specs.Spec[Observation] cached property #

Returns the observation spec.

Returns:

  • Spec[Observation]: spec for the Observation whose fields are:
    • adj_matrix: BoundedArray (bool) of shape (num_nodes, num_nodes). Represents the adjacency matrix of the graph.
    • action_mask: BoundedArray (bool) of shape (num_nodes,). Represents the valid actions in the current state.
    • colors: BoundedArray (int32) of shape (num_nodes,). Represents the colors assigned to each node.
    • current_node_index: BoundedArray (int32) of shape (). Represents the index of the current node.

animate(states, interval=200, save_path=None) #

Creates an animated gif of the GraphColoring environment based on a sequence of states.

Parameters:

  • states (Sequence[State], required): a list of State objects representing the sequence of game states.
  • interval (int, default 200): the delay between frames in milliseconds.
  • save_path (Optional[str], default None): the path where the animation file should be saved. If it is None, the plot will not be stored.

Returns:

  • animation.FuncAnimation: the animation object that was created.

Source code in jumanji/environments/logic/graph_coloring/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> animation.FuncAnimation:
    """Creates an animated gif of the `GraphColoring` environment based on a sequence of states.

    Args:
        states: is a list of `State` objects representing the sequence of game states.
        interval: the delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be stored.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._env_viewer.animate(states=states, interval=interval, save_path=save_path)
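
A short usage sketch, assuming states were collected during a rollout (the file name is illustrative):

import jax
from jumanji.environments import GraphColoring

env = GraphColoring()
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))

states = [state]
for _ in range(5):
    action = env.action_spec.generate_value()
    state, timestep = jax.jit(env.step)(state, action)
    states.append(state)

env.animate(states, interval=200, save_path="graph_coloring.gif")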

close() #

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/logic/graph_coloring/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._env_viewer.close()

render(state) #

Renders the current state of the GraphColoring environment.

Parameters:

  • state (State, required): the current game state to be rendered.
Source code in jumanji/environments/logic/graph_coloring/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Renders the current state of the `GraphColoring` environment.

    Args:
        state: is the current game state to be rendered.
    """
    return self._env_viewer.render(state=state)

reset(key) #

Resets the environment to an initial state.

Parameters:

  • key (PRNGKey, required): random key used to sample a new graph instance.

Returns:

  • Tuple[State, TimeStep[Observation]]: the initial state and timestep.

Source code in jumanji/environments/logic/graph_coloring/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment to an initial state.

    Returns:
        The initial state and timestep.
    """
    colors = jnp.full(self.num_nodes, -1, dtype=jnp.int32)
    key, subkey = jax.random.split(key)
    adj_matrix = self.generator(subkey)

    action_mask = jnp.ones(self.num_nodes, dtype=bool)
    current_node_index = jnp.array(0, jnp.int32)
    state = State(
        adj_matrix=adj_matrix,
        colors=colors,
        current_node_index=current_node_index,
        action_mask=action_mask,
        key=key,
    )
    obs = Observation(
        adj_matrix=adj_matrix,
        colors=colors,
        action_mask=action_mask,
        current_node_index=current_node_index,
    )
    timestep = restart(observation=obs)

    return state, timestep
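
Because reset is a pure function of the key, many instances can be initialized in parallel. A hedged sketch using jax.vmap:

import jax
from jumanji.environments import GraphColoring

env = GraphColoring()
keys = jax.random.split(jax.random.PRNGKey(0), num=8)
# One state and timestep per key, batched along the leading axis.
states, timesteps = jax.vmap(env.reset)(keys)
print(states.colors.shape)  # (8, 20) with the default 20-node generator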

step(state, action) #

Updates the environment state after the agent takes an action.

Specifically, this function allows the agent to choose a color for the current node (based on the action taken) in a graph coloring problem. It then updates the state of the environment based on the color chosen and calculates the reward based on the validity of the action and the completion of the coloring task.

Parameters:

  • state (State, required): the current state of the environment.
  • action (Array, required): the action taken by the agent.

Returns:

  • state (State): the new state of the environment.
  • timestep (TimeStep[Observation]): the next timestep.

Source code in jumanji/environments/logic/graph_coloring/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Updates the environment state after the agent takes an action.

    Specifically, this function allows the agent to choose
    a color for the current node (based on the action taken)
    in a graph coloring problem.
    It then updates the state of the environment based on
    the color chosen and calculates the reward based on
    the validity of the action and the completion of the coloring task.

    Args:
        state: the current state of the environment.
        action: the action taken by the agent.

    Returns:
        state: the new state of the environment.
        timestep: the next timestep.
    """
    # Get the valid actions for the current state.
    valid_actions = state.action_mask

    # Check if the chosen action is invalid (not in valid_actions).
    invalid_action_taken = jnp.logical_not(valid_actions[action])

    # Update the colors array with the chosen action.
    colors = state.colors.at[state.current_node_index].set(action)

    # Determine if all nodes have been assigned a color
    all_nodes_colored = jnp.all(colors >= 0)

    # Calculate the reward
    unique_colors_used = jnp.unique(colors, size=self.num_nodes, fill_value=-1)
    num_unique_colors = jnp.count_nonzero(unique_colors_used >= 0)
    reward = jnp.where(all_nodes_colored, -num_unique_colors, 0.0)

    # Apply the maximum penalty when an invalid action is taken and terminate the episode
    reward = jnp.where(invalid_action_taken, -self.num_nodes, reward)
    done = jnp.logical_or(all_nodes_colored, invalid_action_taken)

    # Update the current node index
    next_node_index = (state.current_node_index + 1) % self.num_nodes

    next_action_mask = self._get_valid_actions(next_node_index, state.adj_matrix, state.colors)

    next_state = State(
        adj_matrix=state.adj_matrix,
        colors=colors,
        current_node_index=next_node_index,
        action_mask=next_action_mask,
        key=state.key,
    )
    obs = Observation(
        adj_matrix=state.adj_matrix,
        colors=colors,
        action_mask=next_state.action_mask,
        current_node_index=next_node_index,
    )
    timestep = lax.cond(
        done,
        termination,
        transition,
        reward,
        obs,
    )
    return next_state, timestep
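
The action mask makes it easy to avoid the invalid-action penalty. A minimal sketch of a greedy policy (illustrative, not part of the library) that always picks the lowest valid color:

import jax
import jax.numpy as jnp
from jumanji.environments import GraphColoring

env = GraphColoring()
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))

step_fn = jax.jit(env.step)
while not timestep.last():
    # argmax over booleans returns the first True, i.e. the smallest valid color.
    action = jnp.argmax(timestep.observation.action_mask)
    state, timestep = step_fn(state, action)

print(timestep.reward)  # -(number of colors used by this greedy strategy)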