TSP

Bases: Environment[State, DiscreteArray, Observation]

Traveling Salesman Problem (TSP) environment as described in [1].

  • observation: Observation

    • coordinates: jax array (float) of shape (num_cities, 2) the coordinates of each city.
    • position: jax array (int32) of shape () the index corresponding to the last visited city.
    • trajectory: jax array (int32) of shape (num_cities,) array of city indices defining the route (-1 --> not filled yet).
    • action_mask: jax array (bool) of shape (num_cities,) binary mask (False/True <--> illegal/legal <--> cannot be visited/can be visited).
  • action: jax array (int32) of shape () [0, ..., num_cities - 1] -> city to visit.

  • reward: jax array (float) of shape (), could be either:

    • dense: the negative distance between the current city and the chosen next city. It is 0 for the first chosen city, and for the last city it also includes the distance back to the initial city to complete the tour.
    • sparse: the negative tour length at the end of the episode. The tour length is defined as the sum of the distances between consecutive cities, computed by starting at the first city and ending there after visiting all the cities (see the tour-length sketch below).

  In both cases, the reward is a large negative penalty of -num_cities * sqrt(2) if the action is invalid, i.e. a previously selected city is selected again.
  • episode termination:

    • if no action can be performed, i.e. all cities have been visited.
    • if an invalid action is taken, i.e. an already visited city is chosen.
  • state: State

    • coordinates: jax array (float) of shape (num_cities, 2) the coordinates of each city.
    • position: int32 the identifier (index) of the last visited city.
    • visited_mask: jax array (bool) of shape (num_cities,) binary mask (False/True <--> not visited/visited).
    • trajectory: jax array (int32) of shape (num_cities,) the identifiers of the cities that have been visited (-1 means that no city has been visited yet at that time in the sequence).
    • num_visited: int32 number of cities that have been visited.

[1] Kwon Y., Choo J., Kim B., Yoon I., Min S., Gwon Y. (2020). "POMO: Policy Optimization with Multiple Optima for Reinforcement Learning".
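
The sparse reward above amounts to the negative length of the closed tour. A minimal sketch of that computation; `tour_length` is a hypothetical helper for illustration, not the library's `RewardFn` implementation:

import jax.numpy as jnp

def tour_length(coordinates: jnp.ndarray, trajectory: jnp.ndarray) -> jnp.ndarray:
    """Sum of Euclidean distances between consecutive cities, returning to the start."""
    tour = coordinates[trajectory]          # coordinates in visit order, shape (num_cities, 2)
    nxt = jnp.roll(tour, shift=-1, axis=0)  # the following city, wrapping back to the first
    return jnp.linalg.norm(nxt - tour, axis=-1).sum()

# sparse reward at episode end: reward = -tour_length(state.coordinates, state.trajectory)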

import jax

from jumanji.environments import TSP
env = TSP()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)

Instantiates a TSP environment.

Parameters:

Name Type Description Default
generator Optional[Generator]

Generator whose __call__ instantiates an environment instance. The default option is 'UniformGenerator' which randomly generates TSP instances with 20 cities sampled from a uniform distribution.

None
reward_fn Optional[RewardFn]

RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action and the next state. Implemented options are [DenseReward, SparseReward]. Defaults to DenseReward.

None
viewer Optional[Viewer[State]]

Viewer used for rendering. Defaults to TSPViewer with "human" render mode.

None
Source code in jumanji/environments/routing/tsp/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    reward_fn: Optional[RewardFn] = None,
    viewer: Optional[Viewer[State]] = None,
):
    """Instantiates a `TSP` environment.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment instance.
            The default option is 'UniformGenerator' which randomly generates
            TSP instances with 20 cities sampled from a uniform distribution.
        reward_fn: RewardFn whose `__call__` method computes the reward of an environment
            transition. The function must compute the reward based on the current state,
            the chosen action and the next state.
            Implemented options are [`DenseReward`, `SparseReward`]. Defaults to `DenseReward`.
        viewer: `Viewer` used for rendering. Defaults to `TSPViewer` with "human" render mode.
    """

    self.generator = generator or UniformGenerator(
        num_cities=20,
    )
    self.num_cities = self.generator.num_cities
    super().__init__()
    self.reward_fn = reward_fn or DenseReward()
    self._viewer = viewer or TSPViewer(name="TSP", render_mode="human")
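
For example, a 50-city instance with sparse rewards could be constructed as follows. The import paths are an assumption based on the source layout shown above (jumanji/environments/routing/tsp/); verify them against your installed version:

from jumanji.environments import TSP
from jumanji.environments.routing.tsp.generator import UniformGenerator
from jumanji.environments.routing.tsp.reward import SparseReward

env = TSP(
    generator=UniformGenerator(num_cities=50),  # instances with 50 cities
    reward_fn=SparseReward(),                   # tour length paid once, at episode end
)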

action_spec: specs.DiscreteArray cached property

Returns the action spec.

Returns:

Name Type Description
action_spec DiscreteArray

a specs.DiscreteArray spec.

observation_spec: specs.Spec[Observation] cached property

Returns the observation spec.

Returns:

Type Description
Spec[Observation]

Spec for the Observation whose fields are:

  • coordinates: BoundedArray (float) of shape (num_cities, 2).
  • position: DiscreteArray (num_values = num_cities) of shape ().
  • trajectory: BoundedArray (int32) of shape (num_cities,).
  • action_mask: BoundedArray (bool) of shape (num_cities,).

animate(states, interval=200, save_path=None)

Creates an animated gif of the TSP environment based on the sequence of states.

Parameters:

Name Type Description Default
states Sequence[State]

sequence of environment states corresponding to consecutive timesteps.

required
interval int

delay between frames in milliseconds, defaults to 200.

200
save_path Optional[str]

the path where the animation file should be saved. If it is None, the plot will not be saved.

None

Returns:

Type Description
FuncAnimation

animation.FuncAnimation: the animation object that was created.

Source code in jumanji/environments/routing/tsp/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the `TSP` environment based on the sequence of states.

    Args:
        states: sequence of environment states corresponding to consecutive timesteps.
        interval: delay between frames in milliseconds, defaults to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states, interval, save_path)
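
For example, the states of a greedy rollout can be recorded and animated. The action choice here (first legal city from the action mask) and the save path are illustrative, not prescribed by the library:

import jax
import jax.numpy as jnp

from jumanji.environments import TSP

env = TSP()
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
step_fn = jax.jit(env.step)

states = [state]
while not timestep.last():
    # pick the first city still allowed by the action mask
    action = jnp.argmax(timestep.observation.action_mask)
    state, timestep = step_fn(state, action)
    states.append(state)

env.animate(states, interval=200, save_path="tsp_rollout.gif")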

close()

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/routing/tsp/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state)

Render the given state of the environment. This rendering shows the layout of the cities and the tour so far.

Parameters:

Name Type Description Default
state State

current environment state.

required

Returns:

Name Type Description
rgb_array Optional[NDArray]

the RGB image of the state as an array.

Source code in jumanji/environments/routing/tsp/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Render the given state of the environment. This rendering shows the layout of the cities
    and the tour so far.

    Args:
        state: current environment state.

    Returns:
        rgb_array: the RGB image of the state as an array.
    """
    return self._viewer.render(state)

reset(key)

Resets the environment.

Parameters:

Name Type Description Default
key PRNGKey

used to randomly generate the coordinates.

required

Returns:

Name Type Description
state State

State object corresponding to the new state of the environment.

timestep TimeStep[Observation]

TimeStep object corresponding to the first timestep returned by the environment.

Source code in jumanji/environments/routing/tsp/env.py
def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: used to randomly generate the coordinates.

    Returns:
        state: State object corresponding to the new state of the environment.
        timestep: TimeStep object corresponding to the first timestep returned
            by the environment.
    """
    state = self.generator(key)
    timestep = restart(observation=self._state_to_observation(state))
    return state, timestep

step(state, action)

Run one timestep of the environment's dynamics.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environment.

required
action Numeric

Array containing the index of the next position to visit.

required

Returns:

Name Type Description
state State

the next state of the environment.

timestep TimeStep[Observation]

the timestep to be observed.

Source code in jumanji/environments/routing/tsp/env.py
def step(self, state: State, action: chex.Numeric) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    Args:
        state: `State` object containing the dynamics of the environment.
        action: `Array` containing the index of the next position to visit.

    Returns:
        state: the next state of the environment.
        timestep: the timestep to be observed.
    """
    is_valid = ~state.visited_mask[action]
    next_state = jax.lax.cond(
        is_valid,
        self._update_state,
        lambda *_: state,
        state,
        action,
    )

    reward = self.reward_fn(state, action, next_state, is_valid)
    observation = self._state_to_observation(next_state)

    # Terminate if all cities have been visited or the action is invalid
    is_done = (next_state.num_visited == self.num_cities) | ~is_valid
    timestep = jax.lax.cond(
        is_done,
        termination,
        transition,
        reward,
        observation,
    )
    return next_state, timestep
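
Because of the is_valid check above, choosing an already-visited city ends the episode immediately with the invalid-action penalty. A quick illustration, continuing the usage example from the top of the page:

import jax

from jumanji.environments import TSP

env = TSP()
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))

action = env.action_spec.generate_value()           # city 0
state, timestep = jax.jit(env.step)(state, action)  # first visit: valid transition
state, timestep = jax.jit(env.step)(state, action)  # city 0 again: invalid

assert timestep.last()  # episode terminated; reward is the -num_cities * sqrt(2) penalty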