
Knapsack

Bases: Environment[State, DiscreteArray, Observation]

Knapsack environment as described in [1].

  • observation: Observation

    • weights: jax array (float) of shape (num_items,): the weights of the items.
    • values: jax array (float) of shape (num_items,): the values of the items.
    • packed_items: jax array (bool) of shape (num_items,): binary mask denoting which items are already packed into the knapsack.
    • action_mask: jax array (bool) of shape (num_items,): binary mask denoting which items can still be packed into the knapsack (see the sketch below).
  • action: jax array (int32) of shape () in [0, ..., num_items - 1]: the index of the item to pack.

  • reward: jax array (float) of shape (), could be either:

    • dense: the value of the item packed at the current timestep.
    • sparse: the sum of the values of the items packed in the bag at the end of the episode.

    In both cases, the reward is 0 if the action is invalid, i.e. an item that was previously selected is selected again or the item's weight exceeds the remaining budget.
  • episode termination:

    • if no action can be performed, i.e. all items are packed or every remaining item's weight is larger than the remaining budget.
    • if an invalid action is taken, i.e. the chosen item is already packed or its weight is larger than the remaining budget.
  • state: State

    • weights: jax array (float) of shape (num_items,): the weights of the items.
    • values: jax array (float) of shape (num_items,): the values of the items.
    • packed_items: jax array (bool) of shape (num_items,): binary mask denoting which items are already packed into the knapsack.
    • remaining_budget: jax array (float): the budget currently remaining.

[1] https://arxiv.org/abs/2010.16011
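
The action mask follows directly from these rules: an item can still be packed only if it is not already in the bag and its weight fits within the remaining budget. A minimal sketch of that computation (not the library's implementation; the field names match the State above):

import chex

def compute_action_mask(
    weights: chex.Array,           # (num_items,) float
    packed_items: chex.Array,      # (num_items,) bool
    remaining_budget: chex.Array,  # () float
) -> chex.Array:
    # An item is packable if it is not in the bag yet and its
    # weight does not exceed the remaining budget.
    return ~packed_items & (weights <= remaining_budget)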

import jax

from jumanji.environments import Knapsack
env = Knapsack()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
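
The generated action above is always item 0. A slightly longer rollout that samples a random action among those the mask allows is often a more useful starting point. A sketch built from the API on this page (timestep.last() is assumed to flag the final step, as in dm_env-style timesteps):

import jax
from jumanji.environments import Knapsack

env = Knapsack()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
step_fn = jax.jit(env.step)
while not timestep.last():
    # Sample uniformly among the items the action mask still allows.
    key, subkey = jax.random.split(key)
    mask = timestep.observation.action_mask
    action = jax.random.choice(subkey, env.num_items, p=mask / mask.sum())
    state, timestep = step_fn(state, action)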

Instantiates a Knapsack environment.

Parameters:

  • generator (Optional[Generator], default: None): Generator whose __call__ instantiates an environment instance. The default option is RandomGenerator, which samples Knapsack instances with 50 items and a total budget of 12.5.
  • reward_fn (Optional[RewardFn], default: None): RewardFn whose __call__ method computes the reward of an environment transition. The function must compute the reward based on the current state, the chosen action, the next state and whether the action is valid. Implemented options are [DenseReward, SparseReward]. Defaults to DenseReward.
  • viewer (Optional[Viewer[State]], default: None): Viewer used for rendering. Defaults to KnapsackViewer with "human" render mode.
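
For example, the environment can be built with the sparse reward and smaller instances. A sketch; the import paths for RandomGenerator and SparseReward are an assumption based on the package layout, only the class names are confirmed by the source below:

from jumanji.environments import Knapsack
from jumanji.environments.packing.knapsack.generator import RandomGenerator
from jumanji.environments.packing.knapsack.reward import SparseReward

# 20-item instances, rewarded only at the end of the episode.
env = Knapsack(
    generator=RandomGenerator(num_items=20, total_budget=5.0),
    reward_fn=SparseReward(),
)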
Source code in jumanji/environments/packing/knapsack/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    reward_fn: Optional[RewardFn] = None,
    viewer: Optional[Viewer[State]] = None,
):
    """Instantiates a `Knapsack` environment.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment instance.
            The default option is 'RandomGenerator' which samples Knapsack instances
            with 50 items and a total budget of 12.5.
        reward_fn: `RewardFn` whose `__call__` method computes the reward of an environment
            transition. The function must compute the reward based on the current state,
            the chosen action, the next state and whether the action is valid.
            Implemented options are [`DenseReward`, `SparseReward`]. Defaults to `DenseReward`.
        viewer: `Viewer` used for rendering. Defaults to `KnapsackViewer` with "human" render
            mode.
    """

    self.generator = generator or RandomGenerator(
        num_items=50,
        total_budget=12.5,
    )
    self.num_items = self.generator.num_items
    super().__init__()
    self.total_budget = self.generator.total_budget
    self.reward_fn = reward_fn or DenseReward()
    self._viewer = viewer or KnapsackViewer(
        name="Knapsack",
        render_mode="human",
        total_budget=self.total_budget,
    )

action_spec: specs.DiscreteArray cached property

Returns the action spec.

Returns:

  • action_spec (DiscreteArray): a specs.DiscreteArray spec.

observation_spec: specs.Spec[Observation] cached property

Returns the observation spec.

Returns:

  • Spec[Observation]: spec for each field in the Observation:

    • weights: BoundedArray (float) of shape (num_items,).
    • values: BoundedArray (float) of shape (num_items,).
    • packed_items: BoundedArray (bool) of shape (num_items,).
    • action_mask: BoundedArray (bool) of shape (num_items,).
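
These specs can be used to build and sanity-check dummy data without touching the environment. A short sketch; generate_value appears in the usage example above, while validate is assumed to behave as in dm_env-style specs:

from jumanji.environments import Knapsack

env = Knapsack()
obs = env.observation_spec.generate_value()  # dummy Observation with the right shapes and dtypes
env.observation_spec.validate(obs)           # raises if obs does not conform to the spec
num_actions = env.action_spec.num_values     # number of items, 50 by default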

animate(states, interval=200, save_path=None)

Creates an animated gif of the Knapsack environment based on the sequence of states.

Parameters:

  • states (Sequence[State], required): sequence of environment states corresponding to consecutive timesteps.
  • interval (int, default: 200): delay between frames in milliseconds.
  • save_path (Optional[str], default: None): the path where the animation file should be saved. If it is None, the plot will not be saved.

Returns:

  • animation.FuncAnimation: the animation object that was created.

Source code in jumanji/environments/packing/knapsack/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the `Knapsack` environment based on the sequence of states.

    Args:
        states: sequence of environment states corresponding to consecutive timesteps.
        interval: delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states, interval, save_path)
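
A typical use is to record the states of a rollout and write them out as a GIF. A sketch; the greedy "first allowed item" policy is purely illustrative, and timestep.last() is assumed to flag the final step:

import jax
import jax.numpy as jnp
from jumanji.environments import Knapsack

env = Knapsack()
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
states = [state]
step_fn = jax.jit(env.step)
while not timestep.last():
    # Take the first item the action mask still allows.
    action = jnp.argmax(timestep.observation.action_mask)
    state, timestep = step_fn(state, action)
    states.append(state)
env.animate(states, interval=200, save_path="knapsack.gif")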

close()

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/packing/knapsack/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state)

Render the environment state, displaying which items have been picked so far, their value, and the remaining budget.

Parameters:

  • state (State, required): the environment state to be rendered.
Source code in jumanji/environments/packing/knapsack/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Render the environment state, displaying which items have been picked so far,
    their value, and the remaining budget.

    Args:
        state: the environment state to be rendered.
    """
    return self._viewer.render(state)

reset(key)

Resets the environment.

Parameters:

  • key (PRNGKey, required): used to randomly generate the weights and values of the items.

Returns:

  • state (State): the new state of the environment.
  • timestep (TimeStep[Observation]): the first timestep returned by the environment.

Source code in jumanji/environments/packing/knapsack/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: used to randomly generate the weights and values of the items.

    Returns:
        state: the new state of the environment.
        timestep: the first timestep returned by the environment.
    """
    state = self.generator(key)
    timestep = restart(observation=self._state_to_observation(state))
    return state, timestep
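
Since reset is a pure function of the key, it composes with standard JAX transformations; jax.vmap, for instance, builds a batch of independent instances. A sketch, assuming reset is traceable as jumanji environments are designed to be:

import jax
from jumanji.environments import Knapsack

env = Knapsack()
keys = jax.random.split(jax.random.PRNGKey(0), 8)
# A batch of 8 states and first timesteps, one per key.
states, timesteps = jax.vmap(env.reset)(keys)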

step(state, action)

Run one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Numeric, required): index of the next item to take.

Returns:

  • state (State): next state of the environment.
  • timestep (TimeStep[Observation]): the timestep to be observed.

Source code in jumanji/environments/packing/knapsack/env.py
def step(self, state: State, action: chex.Numeric) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    Args:
        state: State object containing the dynamics of the environment.
        action: index of next item to take.

    Returns:
        state: next state of the environment.
        timestep: the timestep to be observed.
    """
    item_fits = state.remaining_budget >= state.weights[action]
    item_not_packed = ~state.packed_items[action]
    is_valid = item_fits & item_not_packed
    next_state = jax.lax.cond(
        is_valid,
        self._update_state,
        lambda *_: state,
        state,
        action,
    )

    observation = self._state_to_observation(next_state)

    no_items_available = ~jnp.any(observation.action_mask)
    is_done = no_items_available | ~is_valid

    reward = self.reward_fn(state, action, next_state, is_valid, is_done)

    timestep = jax.lax.cond(
        is_done,
        termination,
        transition,
        reward,
        observation,
    )

    return next_state, timestep
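
The validity logic above can be observed directly by packing the same item twice: the second step is invalid, which ends the episode with zero reward. A sketch, assuming item 0 fits the freshly reset budget (which holds for the default generator):

import jax
from jumanji.environments import Knapsack

env = Knapsack()
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
step_fn = jax.jit(env.step)
action = env.action_spec.generate_value()  # item 0
state, timestep = step_fn(state, action)   # packs item 0
state, timestep = step_fn(state, action)   # item 0 again: invalid
assert timestep.last()                     # the episode has terminated
assert timestep.reward == 0                # invalid actions are never rewarded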