
TSP (Environment) #

Traveling Salesman Problem (TSP) environment as described in [1].

  • observation: Observation

    • coordinates: jax array (float) of shape (num_cities, 2), the coordinates of each city.
    • position: jax array (int32) of shape (), the index of the last visited city.
    • trajectory: jax array (int32) of shape (num_cities,), array of city indices defining the route (-1 --> not filled yet).
    • action_mask: jax array (bool) of shape (num_cities,), binary mask (False/True <--> illegal/legal, i.e. cannot be visited/can be visited).
  • action: jax array (int32) of shape (), in [0, ..., num_cities - 1], the index of the city to visit.

  • reward: jax array (float) of shape (), could be either:

    • dense: the negative distance between the current city and the chosen next city. It is 0 for the first chosen city, and for the last city it also includes the distance back to the initial city to complete the tour.
    • sparse: the negative tour length at the end of the episode. The tour length is the sum of the distances between consecutive cities, starting and ending at the first city after visiting all the cities.

    In both cases, the reward is a large negative penalty of -num_cities * sqrt(2) if the action is invalid, i.e. a previously visited city is selected again. A sketch of how these quantities can be computed follows this list.
  • episode termination:

    • if no action can be performed, i.e. all cities have been visited.
    • if an invalid action is taken, i.e. an already visited city is chosen.
  • state: State

    • coordinates: jax array (float) of shape (num_cities, 2), the coordinates of each city.
    • position: int32, the identifier (index) of the last visited city.
    • visited_mask: jax array (bool) of shape (num_cities,), binary mask (False/True <--> not visited/visited).
    • trajectory: jax array (int32) of shape (num_cities,), the identifiers of the cities that have been visited (-1 means no city has been visited yet at that point in the sequence).
    • num_visited: int32, the number of cities that have been visited.
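
A minimal sketch (not Jumanji's internal implementation) of how the tour length, the sparse reward, and the invalid-action penalty described above can be computed from the coordinates and a completed trajectory:

import jax.numpy as jnp

def tour_length(coordinates: jnp.ndarray, trajectory: jnp.ndarray) -> jnp.ndarray:
    # Distances between consecutive cities along the route, wrapping
    # back from the last city to the first to close the tour.
    ordered = coordinates[trajectory]
    diffs = ordered - jnp.roll(ordered, shift=-1, axis=0)
    return jnp.linalg.norm(diffs, axis=-1).sum()

coordinates = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
trajectory = jnp.array([0, 1, 2, 3])
sparse_reward = -tour_length(coordinates, trajectory)    # -4.0 for this unit square
invalid_penalty = -coordinates.shape[0] * jnp.sqrt(2.0)  # -num_cities * sqrt(2)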

[1] Kwon Y., Choo J., Kim B., Yoon I., Min S., Gwon Y. (2020). "POMO: Policy Optimization with Multiple Optima for Reinforcement Learning". Advances in Neural Information Processing Systems 33 (NeurIPS 2020).

import jax
from jumanji.environments import TSP

env = TSP()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)  # generate a new TSP instance
env.render(state)
action = env.action_spec.generate_value()  # a dummy action (city index 0)
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
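
Building on the snippet above, a full episode can be rolled out by sampling uniformly among the cities that are still legal according to observation.action_mask (a usage sketch, not part of the library's examples):

import jax
import jax.numpy as jnp
from jumanji.environments import TSP

env = TSP()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
step_fn = jax.jit(env.step)
while not timestep.last():
    key, subkey = jax.random.split(key)
    # -inf logits on visited cities => uniform sampling over legal ones.
    logits = jnp.where(timestep.observation.action_mask, 0.0, -jnp.inf)
    action = jax.random.categorical(subkey, logits)
    state, timestep = step_fn(state, action)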

observation_spec: jumanji.specs.Spec[jumanji.environments.routing.tsp.types.Observation] cached property writable #

Returns the observation spec.

Returns:

Spec for the `Observation` whose fields are:

  • coordinates: BoundedArray (float) of shape (num_cities, 2).
  • position: DiscreteArray (num_values = num_cities) of shape ().
  • trajectory: BoundedArray (int32) of shape (num_cities,).
  • action_mask: BoundedArray (bool) of shape (num_cities,).
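
A quick way to check these shapes (a usage sketch; it assumes the nested spec supports generate_value the way the action spec in the quick-start example does, and the default 20-city generator):

from jumanji.environments import TSP

env = TSP()
obs = env.observation_spec.generate_value()  # dummy Observation matching the spec
print(obs.coordinates.shape)  # (20, 2)
print(obs.action_mask.shape)  # (20,)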

action_spec: DiscreteArray cached property writable #

Returns the action spec.

Returns:

action_spec: a specs.DiscreteArray spec with num_values = num_cities.
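
Since the spec's num_values equals the number of cities, an arbitrary (not necessarily legal) action can be drawn directly from it (a usage sketch):

import jax
from jumanji.environments import TSP

env = TSP()
action = jax.random.randint(
    jax.random.PRNGKey(0), shape=(), minval=0, maxval=env.action_spec.num_values
)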

__init__(self, generator: Optional[jumanji.environments.routing.tsp.generator.Generator] = None, reward_fn: Optional[jumanji.environments.routing.tsp.reward.RewardFn] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.tsp.types.State]] = None) special #

Instantiates a TSP environment.

Parameters:

  • generator (Optional[jumanji.environments.routing.tsp.generator.Generator], default None): Generator whose __call__ instantiates an environment instance. The default option is UniformGenerator, which randomly generates TSP instances with 20 cities sampled from a uniform distribution.
  • reward_fn (Optional[jumanji.environments.routing.tsp.reward.RewardFn], default None): RewardFn whose __call__ method computes the reward of an environment transition based on the current state, the chosen action and the next state. Implemented options are [DenseReward, SparseReward]. Defaults to DenseReward.
  • viewer (Optional[jumanji.viewer.Viewer[jumanji.environments.routing.tsp.types.State]], default None): Viewer used for rendering. Defaults to TSPViewer with "human" render mode.
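
A sketch of customizing the environment with a larger instance size and the sparse reward (module paths are taken from the signature above; the num_cities argument of UniformGenerator is an assumption based on its description):

from jumanji.environments import TSP
from jumanji.environments.routing.tsp.generator import UniformGenerator
from jumanji.environments.routing.tsp.reward import SparseReward

env = TSP(
    generator=UniformGenerator(num_cities=50),  # 50 cities instead of the default 20
    reward_fn=SparseReward(),                   # negative tour length at episode end only
)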

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.tsp.types.State, jumanji.types.TimeStep[jumanji.environments.routing.tsp.types.Observation]] #

Resets the environment.

Parameters:

  • key (PRNGKeyArray, required): used to randomly generate the coordinates.

Returns:

  • state: State object corresponding to the new state of the environment.
  • timestep: TimeStep object corresponding to the first timestep returned by the environment.
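
Since reset is a pure function of the random key, it can be vmapped to create a batch of independent instances (a usage sketch):

import jax
from jumanji.environments import TSP

env = TSP()
keys = jax.random.split(jax.random.PRNGKey(0), 8)
states, timesteps = jax.vmap(env.reset)(keys)  # batch of 8 TSP instances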

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]) -> Tuple[jumanji.environments.routing.tsp.types.State, jumanji.types.TimeStep[jumanji.environments.routing.tsp.types.Observation]] #

Run one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int], required): Array containing the index of the next position to visit.

Returns:

  • state: the next state of the environment.
  • timestep: the timestep to be observed.
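
Because an episode lasts exactly num_cities steps (one city is chosen per step until all are visited), a whole rollout can be unrolled with jax.lax.scan; the masked random policy below is illustrative only (a sketch, not the library's API):

import jax
import jax.numpy as jnp
from jumanji.environments import TSP

env = TSP()
state, timestep = env.reset(jax.random.PRNGKey(0))
num_cities = timestep.observation.action_mask.shape[0]

def step_fn(carry, key):
    state, timestep = carry
    logits = jnp.where(timestep.observation.action_mask, 0.0, -jnp.inf)
    action = jax.random.categorical(key, logits)  # uniform over legal cities
    state, timestep = env.step(state, action)
    return (state, timestep), timestep.reward

keys = jax.random.split(jax.random.PRNGKey(1), num_cities)
(_, last_timestep), rewards = jax.lax.scan(step_fn, (state, timestep), keys)
# With the default dense reward, rewards.sum() is the negative tour length.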


Last update: 2024-11-01