
Knapsack

Knapsack (Environment)

Knapsack environment as described in [1].

  • observation: Observation

    • weights: jax array (float) of shape (num_items,) the weights of the items.
    • values: jax array (float) of shape (num_items,) the values of the items.
    • packed_items: jax array (bool) of shape (num_items,) binary mask denoting which items are already packed into the knapsack.
    • action_mask: jax array (bool) of shape (num_items,) binary mask denoting which items can be packed into the knapsack.
  • action: jax array (int32) of shape (), with values in [0, ..., num_items - 1]: the index of the item to pack.

  • reward: jax array (float) of shape (), could be either:

    • dense: the value of the item to pack at the current timestep.
    • sparse: the sum of the values of the items packed in the bag at the end of the episode.

    In both cases, the reward is 0 if the action is invalid, i.e. the chosen item is already packed or has a weight larger than the remaining bag capacity.
  • episode termination:

    • if no action can be performed, i.e. all items are packed or each remaining item's weight is larger than the remaining bag capacity.
    • if an invalid action is taken, i.e. the chosen item is already packed or has a weight larger than the remaining bag capacity.
  • state: State

    • weights: jax array (float) of shape (num_items,) the weights of the items.
    • values: jax array (float) of shape (num_items,) the values of the items.
    • packed_items: jax array (bool) of shape (num_items,) binary mask denoting which items are already packed into the knapsack.
    • remaining_budget: jax array (float) of shape (), the budget currently remaining.

[1] https://arxiv.org/abs/2010.16011
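The reward and termination rules above can be sketched in plain Python. This is a simplified, dependency-free model and not the library's JAX implementation: the function names `action_mask` and `step_sketch`, the list-based state, and the `dense` flag are assumptions made for illustration.

```python
from typing import List, Tuple


def action_mask(weights: List[float], packed: List[bool], budget: float) -> List[bool]:
    """An item can be packed if it is not packed yet and fits in the remaining budget."""
    return [(not p) and w <= budget for w, p in zip(weights, packed)]


def step_sketch(
    weights: List[float],
    values: List[float],
    packed: List[bool],
    budget: float,
    action: int,
    dense: bool = True,
) -> Tuple[List[bool], float, float, bool]:
    """Return (packed, remaining_budget, reward, done) after choosing `action`."""
    is_valid = action_mask(weights, packed, budget)[action]
    if not is_valid:
        # Invalid action: the reward is 0 and the episode terminates.
        return packed, budget, 0.0, True

    # Mark the chosen item as packed and spend its weight.
    packed = [p or i == action for i, p in enumerate(packed)]
    budget -= weights[action]
    # The episode ends when no remaining item can be packed.
    done = not any(action_mask(weights, packed, budget))

    if dense:
        reward = values[action]
    else:
        # Sparse: total packed value, given only on the final step.
        reward = sum(v for v, p in zip(values, packed) if p) if done else 0.0
    return packed, budget, reward, done
```

With weights `[2.0, 3.0, 4.0]`, values `[1.0, 2.0, 3.0]` and a budget of `5.0`, packing item 0 and then item 1 uses the whole budget; under this sketch the dense rewards are `1.0` then `2.0`, while the sparse reward is `0.0` then `3.0`.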

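import jax

from jumanji.environments import Knapsack

env = Knapsack()
key = jax.random.PRNGKey(0)
# Reset the environment to sample a new problem instance.
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# Generate a placeholder action that conforms to the action spec, then take one step.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)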
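The example above takes a spec-generated action. As a hedged illustration of how the observation fields could drive action selection instead, the sketch below implements a simple greedy heuristic in plain Python. It is not part of the library; the function name `greedy_action` and the use of Python sequences in place of jax arrays are assumptions.

```python
from typing import Optional, Sequence


def greedy_action(
    weights: Sequence[float],
    values: Sequence[float],
    action_mask: Sequence[bool],
) -> Optional[int]:
    """Pick the packable item with the best value-to-weight ratio, or None if none is packable."""
    candidates = [i for i, ok in enumerate(action_mask) if ok]
    if not candidates:
        return None
    # Guard against zero weights with a small epsilon (an assumption of this sketch).
    return max(candidates, key=lambda i: values[i] / max(weights[i], 1e-9))
```

With the environment, the three arguments would correspond to the `weights`, `values`, and `action_mask` fields of the observation described earlier, and the chosen index would be passed to `step` as the action.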

observation_spec: jumanji.specs.Spec[jumanji.environments.packing.knapsack.types.Observation] cached property writable

Returns the observation spec.

Returns:

Spec for each field in the Observation:
  • weights: BoundedArray (float) of shape (num_items,).
  • values: BoundedArray (float) of shape (num_items,).
  • packed_items: BoundedArray (bool) of shape (num_items,).
  • action_mask: BoundedArray (bool) of shape (num_items,).

action_spec: DiscreteArray cached property writable

Returns the action spec.

Returns:

  • action_spec: a specs.DiscreteArray spec.

__init__(self, generator: Optional[jumanji.environments.packing.knapsack.generator.Generator] = None, reward_fn: Optional[jumanji.environments.packing.knapsack.reward.RewardFn] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.packing.knapsack.types.State]] = None) special

Instantiates a Knapsack environment.

Parameters:

  • generator (Optional[jumanji.environments.packing.knapsack.generator.Generator], default None): Generator whose __call__ instantiates an environment instance. The default option is RandomGenerator, which samples Knapsack instances with 50 items and a total budget of 12.5.
  • reward_fn (Optional[jumanji.environments.packing.knapsack.reward.RewardFn], default None): RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [DenseReward, SparseReward]. Defaults to DenseReward.
  • viewer (Optional[jumanji.viewer.Viewer[jumanji.environments.packing.knapsack.types.State]], default None): Viewer used for rendering. Defaults to KnapsackViewer with "human" render mode.
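To make the role of the `generator` parameter concrete, here is a minimal plain-Python sketch of a generator that samples one Knapsack instance per call. It is not the library's `RandomGenerator`: the names `ToyState` and `ToyGenerator` are hypothetical, an integer seed stands in for a JAX PRNG key, and drawing weights and values uniformly from [0, 1) is an assumption. Only the defaults of 50 items and a total budget of 12.5 come from the description above.

```python
import random
from dataclasses import dataclass
from typing import List


@dataclass
class ToyState:
    """Plain-Python stand-in for the State fields listed on this page."""

    weights: List[float]
    values: List[float]
    packed_items: List[bool]
    remaining_budget: float


class ToyGenerator:
    """Hypothetical generator: samples one Knapsack instance per call."""

    def __init__(self, num_items: int = 50, total_budget: float = 12.5) -> None:
        self.num_items = num_items
        self.total_budget = total_budget

    def __call__(self, seed: int) -> ToyState:
        rng = random.Random(seed)  # stands in for a JAX PRNG key
        # Assumption: weights and values are drawn uniformly from [0, 1).
        weights = [rng.random() for _ in range(self.num_items)]
        values = [rng.random() for _ in range(self.num_items)]
        # A fresh instance starts with nothing packed and the full budget available.
        return ToyState(weights, values, [False] * self.num_items, self.total_budget)
```

Calling `ToyGenerator()(0)` twice returns identical instances, which mirrors how a fixed key makes `reset` reproducible.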

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.packing.knapsack.types.State, jumanji.types.TimeStep[jumanji.environments.packing.knapsack.types.Observation]]

Resets the environment.

Parameters:

  • key (PRNGKeyArray, required): used to randomly generate the weights and values of the items.

Returns:

  • state: the new state of the environment.
  • timestep: the first timestep returned by the environment.

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]) -> Tuple[jumanji.environments.packing.knapsack.types.State, jumanji.types.TimeStep[jumanji.environments.packing.knapsack.types.Observation]]

Run one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int], required): index of the next item to take.

Returns:

  • state: next state of the environment.
  • timestep: the timestep to be observed.


Last update: 2024-11-01