TSP (Environment)
Traveling Salesman Problem (TSP) environment as described in [1].
- observation: `Observation`
    - coordinates: jax array (float) of shape (num_cities, 2), the coordinates of each city.
    - position: jax array (int32) of shape (), the index of the last visited city.
    - trajectory: jax array (int32) of shape (num_cities,), array of city indices defining the route (-1 --> not filled yet).
    - action_mask: jax array (bool) of shape (num_cities,), binary mask (False/True <--> illegal/legal, i.e. the city cannot/can be visited).
- action: jax array (int32) of shape () in [0, ..., num_cities - 1], the index of the city to visit.
- reward: jax array (float) of shape (), which can be either:
    - dense: the negative distance between the current city and the chosen next city. It is 0 for the first chosen city, and for the last city it also includes the distance back to the initial city to complete the tour.
    - sparse: the negative tour length at the end of the episode. The tour length is the sum of the distances between consecutive cities, computed by starting at the first city and ending there, after visiting all the cities.
    In both cases, the reward is a large negative penalty of `-num_cities * sqrt(2)` if the action is invalid, i.e. a previously visited city is selected again (see the sketch below).
- episode termination:
    - if no action can be performed, i.e. all cities have been visited.
    - if an invalid action is taken, i.e. an already visited city is chosen.
- state: `State`
    - coordinates: jax array (float) of shape (num_cities, 2), the coordinates of each city.
    - position: int32, the index of the last visited city.
    - visited_mask: jax array (bool) of shape (num_cities,), binary mask (False/True <--> not visited/visited).
    - trajectory: jax array (int32) of shape (num_cities,), the indices of the cities that have been visited (-1 means no city has been visited yet at that point in the sequence).
    - num_visited: int32, the number of cities that have been visited.
[1] Kwon, Y., Choo, J., Kim, B., Yoon, I., Min, S., & Gwon, Y. (2020). "POMO: Policy Optimization with Multiple Optima for Reinforcement Learning".
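For concreteness, the dense reward of a single transition is the negative Euclidean distance between the current and chosen cities, with the penalty applied when the chosen city was already visited. A minimal sketch (not the library's actual `RewardFn` implementation, and omitting the first-step and tour-closing cases described above; it assumes coordinates lie in the unit square, where `sqrt(2)` bounds the distance between any two cities):

```python
import jax.numpy as jnp

def dense_reward(coordinates, position, action, is_valid):
    # Negative distance from the current city to the chosen one.
    distance = jnp.linalg.norm(coordinates[action] - coordinates[position])
    # Invalid actions incur -num_cities * sqrt(2): with cities in the unit
    # square, no valid tour can accumulate a more negative return.
    num_cities = coordinates.shape[0]
    penalty = -num_cities * jnp.sqrt(2.0)
    return jnp.where(is_valid, -distance, penalty)
```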
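A minimal usage sketch, assuming `TSP` is exported from `jumanji.environments`:

```python
import jax
from jumanji.environments import TSP

env = TSP()
key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)

# Take one step: visit the first city that is still legal
# according to the action mask.
action = timestep.observation.action_mask.argmax()
state, timestep = env.step(state, action)
```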
observation_spec: jumanji.specs.Spec[jumanji.environments.routing.tsp.types.Observation] (cached property)

Returns the observation spec.

Returns:

Type | Description
---|---
Spec[Observation] | Spec for the `Observation` whose fields are `coordinates`, `position`, `trajectory`, and `action_mask`.
action_spec: DiscreteArray (cached property)

Returns the action spec.

Returns:

Type | Description
---|---
DiscreteArray | action_spec: a `DiscreteArray` spec with `num_cities` possible values.
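`action_spec.generate_value()` yields a placeholder action, but it may be illegal once cities have been visited. To sample uniformly among legal actions instead, the observation's `action_mask` can be normalised into a categorical distribution; a hypothetical helper, not part of the library:

```python
import jax

def random_legal_action(key, action_mask):
    # Uniform distribution over the cities that can still be visited.
    probs = action_mask / action_mask.sum()
    return jax.random.choice(key, action_mask.shape[0], p=probs)
```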
__init__(self, generator: Optional[jumanji.environments.routing.tsp.generator.Generator] = None, reward_fn: Optional[jumanji.environments.routing.tsp.reward.RewardFn] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.tsp.types.State]] = None)

Instantiates a TSP environment.

Parameters:

Name | Type | Description | Default
---|---|---|---
generator | Optional[jumanji.environments.routing.tsp.generator.Generator] | Generator used to sample the city coordinates at reset time; if None, a default generator is used. | None
reward_fn | Optional[jumanji.environments.routing.tsp.reward.RewardFn] | RewardFn whose `__call__` method computes the reward of an environment transition; the dense and sparse rewards described above are the available options. If None, a default reward function is used. | None
viewer | Optional[jumanji.viewer.Viewer[jumanji.environments.routing.tsp.types.State]] | Viewer used for rendering; if None, a default viewer is used. | None
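All three arguments default to None, in which case the environment constructs default components internally. A sketch of overriding the reward function, assuming a `SparseReward` class is exposed in `jumanji.environments.routing.tsp.reward`:

```python
from jumanji.environments import TSP
from jumanji.environments.routing.tsp.reward import SparseReward

# Receive the negative tour length only at episode end,
# instead of a per-step dense reward.
env = TSP(reward_fn=SparseReward())
```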
reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.tsp.types.State, jumanji.types.TimeStep[jumanji.environments.routing.tsp.types.Observation]]

Resets the environment.

Parameters:

Name | Type | Description | Default
---|---|---|---
key | PRNGKeyArray | Random key used to randomly generate the coordinates. | required

Returns:

Type | Description
---|---
state | State object corresponding to the new state of the environment.
timestep | TimeStep object corresponding to the first timestep returned by the environment.
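Because reset is a pure function of the key, it composes with JAX transformations; for instance, it can be vmapped over a batch of keys to reset many independent instances at once (a sketch, assuming the environment supports batching via `jax.vmap` as Jumanji environments generally do):

```python
import jax
from jumanji.environments import TSP

env = TSP()
keys = jax.random.split(jax.random.PRNGKey(0), 8)
# Batch of 8 independent problem instances.
states, timesteps = jax.vmap(env.reset)(keys)
```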
step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]) -> Tuple[jumanji.environments.routing.tsp.types.State, jumanji.types.TimeStep[jumanji.environments.routing.tsp.types.Observation]]

Run one timestep of the environment's dynamics.

Parameters:

Name | Type | Description | Default
---|---|---|---
state | State | State object containing the current state of the environment. | required
action | Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] | Index of the next city to visit. | required

Returns:

Type | Description
---|---
state | the next state of the environment.
timestep | the timestep to be observed.
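Putting reset and step together, a full episode can be rolled out with a host-side loop around a jitted step; a sketch using the random-legal-action idea from above:

```python
import jax
from jumanji.environments import TSP

env = TSP()
step = jax.jit(env.step)

def rollout(key):
    key, reset_key = jax.random.split(key)
    state, timestep = env.reset(reset_key)
    total_reward = 0.0
    while not timestep.last():
        key, action_key = jax.random.split(key)
        mask = timestep.observation.action_mask
        # Sample uniformly among the cities that can still be visited.
        action = jax.random.choice(action_key, mask.shape[0], p=mask / mask.sum())
        state, timestep = step(state, action)
        total_reward += timestep.reward
    # Under the dense reward, this equals the negative tour length.
    return total_reward

print(rollout(jax.random.PRNGKey(42)))
```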