
CVRP

Bases: Environment[State, DiscreteArray, Observation]

Capacitated Vehicle Routing Problem (CVRP) environment as described in [1].

  • observation: Observation

    • coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot.
    • demands: jax array (float) of shape (num_nodes + 1,) the demand associated with each node and the depot (0.0 for the depot).
    • unvisited_nodes: jax array (bool) of shape (num_nodes + 1,) indicates nodes that remain to be visited.
    • position: jax array (int32) of shape () the index of the last visited node.
    • trajectory: jax array (int32) of shape (2 * num_nodes,) array of node indices defining the route (set to DEPOT_IDX if not filled yet).
    • capacity: jax array (float) of shape () the current capacity of the vehicle.
    • action_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> invalid/valid action).
  • action: jax array (int32) of shape () [0, ..., num_nodes] -> node to visit. 0 corresponds to visiting the depot.

  • reward: jax array (float) of shape (), could be either:

    • dense: the negative distance between the current node and the chosen next node to go to. For the last node, it also includes the distance to the depot to complete the tour.
    • sparse: the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive nodes.

    In both cases, the reward is a large negative penalty of -2 * num_nodes * sqrt(2) if the action is invalid, e.g. if a previously visited node other than the depot is selected again (see the sketch after the usage example below).
  • episode termination:

    • if no action can be performed, i.e. all nodes have been visited.
    • if an invalid action is taken, i.e. a previously visited city other than the depot is chosen.
  • state: State

    • coordinates: jax array (float) of shape (num_nodes + 1, 2) the coordinates of each node and the depot.
    • demands: jax array (int32) of shape (num_nodes + 1,) the demand associated with each node and the depot (0 for the depot).
    • position: jax array (int32) the index of the last visited node.
    • capacity: jax array (int32) the current capacity of the vehicle.
    • visited_mask: jax array (bool) of shape (num_nodes + 1,) binary mask (False/True <--> not visited/visited).
    • trajectory: jax array (int32) of shape (2 * num_nodes,) identifiers of the nodes that have been visited (set to DEPOT_IDX if not filled yet).
    • num_visits: int32 number of actions that have been taken (i.e., unique visits).

[1] Toth P., Vigo D. (2014). "Vehicle routing: problems, methods, and applications".

import jax

from jumanji.environments import CVRP

env = CVRP()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
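
As a sanity check on the shapes and the penalty described above, the following sketch (assuming the default generator with 20 nodes) inspects the observation returned by reset and computes the invalid-action penalty:

import jax
import jax.numpy as jnp

from jumanji.environments import CVRP

env = CVRP()  # default generator: 20 nodes + depot
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
obs = timestep.observation

print(obs.coordinates.shape)  # (21, 2): num_nodes + 1 rows, depot included
print(obs.demands.shape)      # (21,)
print(obs.trajectory.shape)   # (40,) == 2 * num_nodes
print(obs.action_mask.shape)  # (21,)

# Invalid-action penalty from the reward description above.
penalty = -2 * env.num_nodes * jnp.sqrt(2.0)  # ≈ -56.57 for num_nodes=20
print(penalty)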

Instantiates a CVRP environment.

Parameters:

  • generator (Optional[Generator], default None): Generator whose __call__ instantiates an environment instance. The default option is UniformGenerator, which randomly generates CVRP instances with 20 cities sampled from a uniform distribution, a maximum vehicle capacity of 30, and a maximum city demand of 10.
  • reward_fn (Optional[RewardFn], default None): RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [DenseReward, SparseReward]. Defaults to DenseReward.
  • viewer (Optional[Viewer[State]], default None): Viewer used for rendering. Defaults to CVRPViewer with "human" render mode.
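
Both defaults can be overridden at construction time. The snippet below is a minimal sketch; the import paths for UniformGenerator and SparseReward are assumed to follow the jumanji/environments/routing/cvrp/ package layout.

from jumanji.environments import CVRP
# Import paths below are assumptions based on the package layout.
from jumanji.environments.routing.cvrp.generator import UniformGenerator
from jumanji.environments.routing.cvrp.reward import SparseReward

env = CVRP(
    generator=UniformGenerator(num_nodes=50, max_capacity=60, max_demand=10),
    reward_fn=SparseReward(),
)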
Source code in jumanji/environments/routing/cvrp/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    reward_fn: Optional[RewardFn] = None,
    viewer: Optional[Viewer[State]] = None,
):
    """Instantiates a `CVRP` environment.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment instance.
            The default option is 'UniformGenerator' which randomly generates
            CVRP instances with 20 cities sampled from a uniform distribution,
            a maximum vehicle capacity of 30, and a maximum city demand of 10.
        reward_fn: `RewardFn` whose `__call__` method computes the reward of an environment
            transition. The function must compute the reward based on the current state,
            the chosen action, the next state and whether the action is valid.
            Implemented options are [`DenseReward`, `SparseReward`]. Defaults to `DenseReward`.
        viewer: `Viewer` used for rendering. Defaults to `CVRPViewer` with "human" render mode.
    """

    self.generator = generator or UniformGenerator(
        num_nodes=20,
        max_capacity=30,
        max_demand=10,
    )
    self.num_nodes = self.generator.num_nodes
    super().__init__()
    self.max_capacity = self.generator.max_capacity
    self.max_demand = self.generator.max_demand
    if self.max_capacity < self.max_demand:
        raise ValueError(
            f"The demand associated with each node must be lower than the maximum capacity, "
            f"hence the maximum capacity must be >= {self.max_demand}."
        )
    self.reward_fn = reward_fn or DenseReward()
    self._viewer = viewer or CVRPViewer(
        name="CVRP",
        num_cities=self.num_nodes,
        render_mode="human",
    )
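
Note the validation at the end of the constructor: if the generator's maximum demand exceeds the maximum capacity, some nodes could never be served, so construction fails fast. A sketch (same assumed import path as above):

from jumanji.environments import CVRP
from jumanji.environments.routing.cvrp.generator import UniformGenerator

try:
    # max_demand > max_capacity: a node with demand 10 could never be served.
    CVRP(generator=UniformGenerator(num_nodes=20, max_capacity=5, max_demand=10))
except ValueError as e:
    print(e)  # "... the maximum capacity must be >= 10."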

action_spec: specs.DiscreteArray (cached property)

Returns the action spec.

Returns:

  • action_spec: a specs.DiscreteArray spec.
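
A minimal sketch of using the spec, assuming the default 20-node generator:

from jumanji.environments import CVRP

env = CVRP()
spec = env.action_spec
print(spec.num_values)  # 21 == num_nodes + 1 (node indices, depot at index 0)
action = spec.generate_value()  # dummy valid value, 0 for a DiscreteArray
print(action)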

observation_spec: specs.Spec[Observation] (cached property)

Returns the observation spec.

Returns:

Spec for the Observation whose fields are:

  • coordinates: BoundedArray (float) of shape (num_nodes + 1, 2).
  • demands: BoundedArray (float) of shape (num_nodes + 1,).
  • unvisited_nodes: BoundedArray (bool) of shape (num_nodes + 1,).
  • position: DiscreteArray (num_values = num_nodes + 1) of shape ().
  • trajectory: BoundedArray (int32) of shape (2 * num_nodes,).
  • capacity: BoundedArray (float) of shape ().
  • action_mask: BoundedArray (bool) of shape (num_nodes + 1,).
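
The spec can be used to validate a concrete observation against the shapes and dtypes above, as in this sketch:

import jax

from jumanji.environments import CVRP

env = CVRP()
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
# Raises if the observation does not match the spec's shapes/dtypes.
env.observation_spec.validate(timestep.observation)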

animate(states, interval=200, save_path=None)

Creates an animated gif of the CVRP environment based on the sequence of states.

Parameters:

  • states (Sequence[State], required): sequence of environment states corresponding to consecutive timesteps.
  • interval (int, default 200): delay between frames in milliseconds.
  • save_path (Optional[str], default None): the path where the animation file should be saved. If it is None, the plot will not be saved.

Returns:

  • animation.FuncAnimation: the animation object that was created.

Source code in jumanji/environments/routing/cvrp/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the CVRP environment based on the sequence of states.

    Args:
        states: sequence of environment states corresponding to consecutive timesteps.
        interval: delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states, interval, save_path)
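
A typical pattern is to collect the states of a rollout and then animate them. The sketch below samples uniformly among valid actions via the action mask; the random policy is a placeholder.

import jax
import jax.numpy as jnp

from jumanji.environments import CVRP

env = CVRP()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
step_fn = jax.jit(env.step)

states = [state]
while not timestep.last():
    key, subkey = jax.random.split(key)
    mask = timestep.observation.action_mask
    # Sample a random valid action (placeholder policy).
    action = jax.random.choice(subkey, jnp.arange(mask.shape[0]), p=mask / mask.sum())
    state, timestep = step_fn(state, action)
    states.append(state)

env.animate(states, interval=200, save_path="cvrp_rollout.gif")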

close()

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/routing/cvrp/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state)

Render the given state of the environment. This rendering shows the layout of the tour so far with the cities as circles, and the depot as a square.

Parameters:

  • state (State, required): environment state to render.

Returns:

  • rgb_array (Optional[ArrayNumpy]): the RGB image of the state as an array.

Source code in jumanji/environments/routing/cvrp/env.py
def render(self, state: State) -> Optional[chex.ArrayNumpy]:
    """Render the given state of the environment. This rendering shows the layout of the tour so
     far with the cities as circles, and the depot as a square.

    Args:
        state: environment state to render.

    Returns:
        rgb_array: the RGB image of the state as an array.
    """
    return self._viewer.render(state)

reset(key)

Resets the environment.

Parameters:

  • key (PRNGKey, required): used to randomly generate the coordinates.

Returns:

  • state (State): State object corresponding to the new state of the environment.
  • timestep (TimeStep[Observation]): TimeStep object corresponding to the first timestep returned by the environment.

Source code in jumanji/environments/routing/cvrp/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: used to randomly generate the coordinates.

    Returns:
         state: `State` object corresponding to the new state of the environment.
         timestep: `TimeStep` object corresponding to the first timestep returned by the
            environment.
    """
    state = self.generator(key)
    timestep = restart(observation=self._state_to_observation(state))
    return state, timestep
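
A short sketch of reset, assuming the starting position is the depot (DEPOT_IDX == 0):

import jax

from jumanji.environments import CVRP

env = CVRP()
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(42))
print(timestep.first())  # True: restart(...) always produces a FIRST timestep
print(state.position)    # 0: the vehicle starts at the depot (assumed DEPOT_IDX)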

step(state, action)

Run one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Numeric, required): jax array (int32) of shape () containing the index of the next node to visit.

Returns:

  • state, timestep (Tuple[State, TimeStep[Observation]]): the next state of the environment and the timestep to be observed.

Source code in jumanji/environments/routing/cvrp/env.py
def step(self, state: State, action: chex.Numeric) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    Args:
        state: `State` object containing the dynamics of the environment.
        action: jax array (int32) of shape () containing the index of the next node to visit.

    Returns:
        state, timestep: next state of the environment and timestep to be observed.
    """
    node_demand = state.demands[action]
    node_is_visited = state.visited_mask[action]
    is_valid = ~node_is_visited & (state.capacity >= node_demand)

    next_state = jax.lax.cond(
        is_valid,
        self._update_state,
        lambda *_: state,
        state,
        action,
    )

    reward = self.reward_fn(state, action, next_state, is_valid)
    observation = self._state_to_observation(next_state)

    # Terminate if all nodes have been visited or the action is invalid.
    is_done = next_state.visited_mask.all() | ~is_valid

    timestep = jax.lax.cond(
        is_done,
        termination,
        transition,
        reward,
        observation,
    )
    return next_state, timestep
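
As the source above shows, an action is valid only if the node is unvisited and its demand fits within the remaining capacity; an invalid action leaves the state unchanged and terminates the episode. A sketch that triggers the invalid branch by revisiting a node:

import jax

from jumanji.environments import CVRP

env = CVRP()
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
step_fn = jax.jit(env.step)

state, timestep = step_fn(state, 1)  # visit node 1: valid
state, timestep = step_fn(state, 1)  # revisit node 1: invalid
print(timestep.last())  # True: the episode terminated
print(timestep.reward)  # large negative penalty (-2 * num_nodes * sqrt(2))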