
BinPack

Bases: Environment[State, MultiDiscreteArray, Observation]

The 3D bin packing problem, in which a set of items has to be placed in a 3D container with the goal of maximizing its volume utilization. This environment only supports 1 bin, meaning it is equivalent to the 3D-knapsack problem. We use the Empty Maximal Space (EMS) formulation of this problem. An EMS is a 3D-rectangular space that lives inside the container and has the following properties:

  • It does not intersect any items, and it is not fully included in any other EMS.
  • It is defined by two 3D points, hence 6 coordinates (x1, x2, y1, y2, z1, z2), the first point being its bottom-left corner and the second its top-right corner.
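As a rough illustration of the EMS definition above, the sketch below stores a box as the six coordinates (x1, x2, y1, y2, z1, z2) and checks the "does not intersect any items" property. The `Box` class and `item_fits` helper are hypothetical names introduced here for illustration; they are not the library's own types.

```python
from typing import NamedTuple, Tuple


class Box(NamedTuple):
    """Axis-aligned box stored as (x1, x2, y1, y2, z1, z2).

    Hypothetical stand-in for illustration, not the library's own class.
    """

    x1: int
    x2: int
    y1: int
    y2: int
    z1: int
    z2: int

    def volume(self) -> int:
        return (self.x2 - self.x1) * (self.y2 - self.y1) * (self.z2 - self.z1)

    def intersects(self, other: "Box") -> bool:
        # Two boxes overlap only if their extents overlap strictly on every axis.
        return (
            self.x1 < other.x2 and other.x1 < self.x2
            and self.y1 < other.y2 and other.y1 < self.y2
            and self.z1 < other.z2 and other.z1 < self.z2
        )


def item_fits(item_dims: Tuple[int, int, int], ems: Box) -> bool:
    """Check that an item of size (x_len, y_len, z_len) fits inside the EMS."""
    x_len, y_len, z_len = item_dims
    return (
        x_len <= ems.x2 - ems.x1
        and y_len <= ems.y2 - ems.y1
        and z_len <= ems.z2 - ems.z1
    )


# A 10 x 10 x 10 container with one item occupying the slab 0 <= x <= 4.
placed_item = Box(0, 4, 0, 10, 0, 10)
# The remaining free slab is an EMS: it touches the item but does not overlap it.
ems = Box(4, 10, 0, 10, 0, 10)
```

Here `ems.intersects(placed_item)` is false because the two boxes only share a face, and an item of size (6, 10, 10) would fill the EMS exactly.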

  • observation: Observation

    • ems: EMS tree of jax arrays (float if normalize_dimensions else int32) each of shape (obs_num_ems,), coordinates of all EMSs at the current timestep.
    • ems_mask: jax array (bool) of shape (obs_num_ems,) indicates the EMSs that are valid.
    • items: Item tree of jax arrays (float if normalize_dimensions else int32) each of shape (max_num_items,), characteristics of all items for this instance.
    • items_mask: jax array (bool) of shape (max_num_items,) indicates the items that are valid.
    • items_placed: jax array (bool) of shape (max_num_items,) indicates the items that have been placed so far.
    • action_mask: jax array (bool) of shape (obs_num_ems, max_num_items) mask of the joint action space: True if the action (ems_id, item_id) is valid.
  • action: MultiDiscreteArray (int32) of shape (2,), i.e. a pair (ems_id, item_id) taking (obs_num_ems, max_num_items) possible values respectively.

    • ems_id: int between 0 and obs_num_ems - 1 (included).
    • item_id: int between 0 and max_num_items - 1 (included).
  • reward: jax array (float) of shape (), could be either:

    • dense: increase in volume utilization of the container due to packing the chosen item.
    • sparse: volume utilization of the container at the end of the episode.
  • episode termination:

    • if no action can be performed, i.e. no items fit in any EMSs, or all items have been packed.
    • if an invalid action is taken, i.e. an item that does not fit in an EMS or one that is already packed.
  • state: State

    • container: space defined by 2 points, i.e. 6 coordinates.
    • ems: empty maximal spaces (EMSs) in the container, each defined by 2 points (6 coordinates).
    • ems_mask: array of booleans that indicate the EMSs that are valid.
    • items: defined by 3 attributes (x, y, z).
    • items_mask: array of booleans that indicate the items that can be packed.
    • items_placed: array of booleans that indicate the items that have been placed so far.
    • items_location: locations of items in the container, defined by 3 coordinates (x, y, z).
    • action_mask: array of booleans that indicate the valid actions, i.e. EMSs and items that can be chosen.
    • sorted_ems_indexes: EMS indexes that are sorted by decreasing volume order.
    • key: random key used for auto-reset.
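The relation between the dense and sparse reward options listed above can be sketched in plain Python. This is a minimal illustration that assumes every action is valid and that the sparse reward is zero before the final step; the helper names (`box_volume`, `dense_rewards`, `sparse_rewards`) are hypothetical and not part of the library.

```python
from typing import List, Sequence, Tuple

Dims = Tuple[int, int, int]


def box_volume(dims: Dims) -> int:
    x_len, y_len, z_len = dims
    return x_len * y_len * z_len


def dense_rewards(packed_items: Sequence[Dims], container: Dims) -> List[float]:
    """One reward per step: the increase in volume utilization from that item."""
    container_volume = box_volume(container)
    return [box_volume(item) / container_volume for item in packed_items]


def sparse_rewards(packed_items: Sequence[Dims], container: Dims) -> List[float]:
    """Zero at every step except the last, which gets the final utilization.

    Assumption of this sketch: intermediate sparse rewards are exactly zero.
    """
    rewards = [0.0] * len(packed_items)
    if rewards:
        total = sum(box_volume(item) for item in packed_items)
        rewards[-1] = total / box_volume(container)
    return rewards


container = (10, 10, 10)
packed_items = [(5, 10, 10), (5, 5, 10)]  # items in the order they were packed
dense = dense_rewards(packed_items, container)
sparse = sparse_rewards(packed_items, container)
```

Both lists sum to the same episode return, the final volume utilization, which is why the two reward functions differ only in how the signal is distributed over time.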
import jax
from jumanji.environments import BinPack
env = BinPack()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
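The example above takes a default value from the action spec, which is not guaranteed to be a valid action. An agent would normally consult the `action_mask` of the observation instead. The sketch below uses NumPy on a hand-made mask to show one way of picking a valid (ems_id, item_id) pair; `first_valid_action` is a hypothetical helper, not a library function.

```python
import numpy as np


def first_valid_action(action_mask: np.ndarray) -> np.ndarray:
    """Return the first (ems_id, item_id) pair whose mask entry is True.

    The caller should check `action_mask.any()` first: with an all-False mask
    this sketch falls back to (0, 0), which would be an invalid action.
    """
    # argmax over a boolean array gives the flat index of the first True.
    flat_index = np.argmax(action_mask)
    ems_id, item_id = np.unravel_index(flat_index, action_mask.shape)
    return np.array([ems_id, item_id], dtype=np.int32)


# Hand-made mask with obs_num_ems=3 and max_num_items=4.
action_mask = np.zeros((3, 4), dtype=bool)
action_mask[1, 2] = True
action_mask[2, 0] = True
action = first_valid_action(action_mask)
```

The same indexing pattern, `action_mask[tuple(action)]`, is what the `step` method uses to decide whether an action is valid.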

Instantiates a BinPack environment.

Parameters:

Name Type Description Default
generator Optional[Generator]

Generator whose __call__ instantiates an environment instance. Implemented options are [RandomGenerator, ToyGenerator, CSVGenerator]. Defaults to a RandomGenerator that generates up to 20 items and can handle 40 EMSs.

None
obs_num_ems int

Number of EMSs (possible spaces in which to place an item) to show to the agent. If obs_num_ems is smaller than generator.max_num_ems, only the obs_num_ems largest EMSs (in terms of volume) are returned in the observation. A suitable value depends heavily on the number of items (given by the instance generator). Defaults to 40 observable EMSs.

40
reward_fn Optional[RewardFn]

Function that computes the reward based on the current state, the chosen action, the next state, whether the transition is valid and whether it is terminal. Implemented options are [DenseReward, SparseReward]. In each case, the total return at the end of an episode is the volume utilization of the container. Defaults to DenseReward.

None
normalize_dimensions bool

If True, the observation is normalized (float) along each dimension into a unit cubic container. If False, the observation is returned in millimeters, i.e. integers (for both items and EMSs). Defaults to True.

True
debug bool

If True, adds an invalid_ems_from_env field to timestep.extras that flags whether the environment created an invalid EMS, which should not happen. Computing this metric slows down the environment. Defaults to False.

False
viewer Optional[Viewer[State]]

Viewer used for rendering. Defaults to BinPackViewer with "human" render mode.

None
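The obs_num_ems parameter described above keeps only the largest EMSs by volume. A minimal NumPy sketch of that selection follows; the function name and the choice to push invalid EMSs to the end by giving them a volume of -1 are assumptions of this illustration, not the library's implementation.

```python
import numpy as np


def observed_ems_indexes(
    ems_volumes: np.ndarray, ems_mask: np.ndarray, obs_num_ems: int
) -> np.ndarray:
    """Indexes of the `obs_num_ems` largest valid EMSs, by decreasing volume."""
    # Assumption of this sketch: invalid EMSs are ranked last via a volume of -1.
    masked_volumes = np.where(ems_mask, ems_volumes, -1)
    sorted_indexes = np.argsort(-masked_volumes, kind="stable")
    return sorted_indexes[:obs_num_ems]


ems_volumes = np.array([100, 600, 50, 300])
ems_mask = np.array([True, True, False, True])
sorted_ems_indexes = observed_ems_indexes(ems_volumes, ems_mask, obs_num_ems=2)

# An observed index chosen by the agent maps back to the underlying EMS index.
obs_ems_id = 1
ems_id = sorted_ems_indexes[obs_ems_id]
```

The last two lines mirror how `step` translates the agent's observed EMS index into an index of the state through `state.sorted_ems_indexes`.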
Source code in jumanji/environments/packing/bin_pack/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    obs_num_ems: int = 40,
    reward_fn: Optional[RewardFn] = None,
    normalize_dimensions: bool = True,
    debug: bool = False,
    viewer: Optional[Viewer[State]] = None,
):
    """Instantiates a `BinPack` environment.

    Args:
        generator: `Generator` whose `__call__` instantiates an environment
            instance. Implemented options are [`RandomGenerator`, `ToyGenerator`,
            `CSVGenerator`]. Defaults to `RandomGenerator` that generates up to 20 items maximum
            and that can handle 40 EMSs.
        obs_num_ems: number of EMSs (possible spaces in which to place an item) to show to the
            agent. If `obs_num_ems` is smaller than `generator.max_num_ems`, the first
            `obs_num_ems` largest EMSs (in terms of volume) will be returned in the observation.
            The good number heavily depends on the number of items (given by the instance
            generator). Default to 40 EMSs observable.
        reward_fn: compute the reward based on the current state, the chosen action, the next
            state, whether the transition is valid and if it is terminal. Implemented options
            are [`DenseReward`, `SparseReward`]. In each case, the total return at the end of
            an episode is the volume utilization of the container. Defaults to `DenseReward`.
        normalize_dimensions: if True, the observation is normalized (float) along each
            dimension into a unit cubic container. If False, the observation is returned in
            millimeters, i.e. integers (for both items and EMSs). Default to True.
        debug: if True, will add to timestep.extras an `invalid_ems_from_env` field that checks
            if an invalid EMS was created by the environment, which should not happen. Computing
            this metric slows down the environment. Default to False.
        viewer: `Viewer` used for rendering. Defaults to `BinPackViewer` with "human" render
            mode.
    """
    self.generator = generator or RandomGenerator(
        max_num_items=20,
        max_num_ems=40,
        split_num_same_items=2,
    )
    self.obs_num_ems = obs_num_ems
    self.reward_fn = reward_fn or DenseReward()
    self.normalize_dimensions = normalize_dimensions
    super().__init__()
    self._viewer = viewer or BinPackViewer("BinPack", render_mode="human")
    self.debug = debug
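The effect of normalize_dimensions can be pictured with the NumPy sketch below. It assumes that normalization divides each coordinate by the container's length along the same axis and that the container starts at the origin; both points, as well as the `normalize_box` helper, are assumptions of this illustration rather than confirmed details of the library.

```python
import numpy as np


def normalize_box(box: np.ndarray, container: np.ndarray) -> np.ndarray:
    """Scale (x1, x2, y1, y2, z1, z2) into a unit cube.

    Assumes the container's bottom-left corner is at the origin.
    """
    # Container lengths along x, y and z.
    lengths = container[1::2] - container[0::2]
    # Each axis length applies to both coordinates of that axis.
    divisors = np.repeat(lengths, 2)
    return box.astype(np.float32) / divisors


# Dimensions in millimeters, as integers.
container = np.array([0, 1000, 0, 500, 0, 250], dtype=np.int32)
ems = np.array([0, 500, 0, 500, 125, 250], dtype=np.int32)
normalized_ems = normalize_box(ems, container)
```

After normalization every coordinate lies in [0, 1] regardless of the container's real size, so observations from containers of different sizes share a common scale.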

action_spec: specs.MultiDiscreteArray cached property

Specifications of the action expected by the BinPack environment.

Returns:

Type Description
MultiDiscreteArray

MultiDiscreteArray (int32) of shape (2,), i.e. a pair (ems_id, item_id):

  • ems_id: int between 0 and obs_num_ems - 1 (included).
  • item_id: int between 0 and max_num_items - 1 (included).

observation_spec: specs.Spec[Observation] cached property

Specifications of the observation of the BinPack environment.

Returns:

Type Description
Spec[Observation]

Spec for the Observation whose fields are:

  • ems:
    • if normalize_dimensions: tree of BoundedArray (float) of shape (obs_num_ems,).
    • else: tree of BoundedArray (int32) of shape (obs_num_ems,).
  • ems_mask: BoundedArray (bool) of shape (obs_num_ems,).
  • items:
    • if normalize_dimensions: tree of BoundedArray (float) of shape (max_num_items,).
    • else: tree of BoundedArray (int32) of shape (max_num_items,).
  • items_mask: BoundedArray (bool) of shape (max_num_items,).
  • items_placed: BoundedArray (bool) of shape (max_num_items,).
  • action_mask: BoundedArray (bool) of shape (obs_num_ems, max_num_items).

animate(states, interval=200, save_path=None)

Creates an animated gif of the BinPack environment based on the sequence of states.

Parameters:

Name Type Description Default
states Sequence[State]

sequence of environment states corresponding to consecutive timesteps.

required
interval int

delay between frames in milliseconds, defaults to 200.

200
save_path Optional[str]

the path where the animation file should be saved. If it is None, the plot will not be saved.

None

Returns:

Type Description
FuncAnimation

animation.FuncAnimation: the animation object that was created.

Source code in jumanji/environments/packing/bin_pack/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the `BinPack` environment based on the sequence of states.

    Args:
        states: sequence of environment states corresponding to consecutive timesteps.
        interval: delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states, interval, save_path)

close()

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/packing/bin_pack/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state)

Render the given state of the environment.

Parameters:

Name Type Description Default
state State

State object containing the current dynamics of the environment.

required
Source code in jumanji/environments/packing/bin_pack/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Render the given state of the environment.

    Args:
        state: State object containing the current dynamics of the environment.
    """
    return self._viewer.render(state)

reset(key)

Resets the environment by calling the instance generator for a new instance.

Parameters:

Name Type Description Default
key PRNGKey

random key used to reset the environment.

required

Returns:

Name Type Description
state State

State object corresponding to the new state of the environment after a reset.

timestep TimeStep[Observation]

TimeStep object corresponding to the first timestep returned by the environment after a reset. Also contains the following metrics in the extras field:

  • volume_utilization: utilization (in [0, 1]) of the container.
  • packed_items: number of items that are packed in the container.
  • ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container.
  • active_ems: number of active EMSs in the current instance.
  • invalid_action: True if the action that was just taken was invalid.
  • invalid_ems_from_env (optional): True if the environment produced an EMS that was invalid. Only available in debug mode.

Source code in jumanji/environments/packing/bin_pack/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment by calling the instance generator for a new instance.

    Args:
        key: random key used to reset the environment.

    Returns:
        state: `State` object corresponding to the new state of the environment after a reset.
        timestep: `TimeStep` object corresponding the first timestep returned by the environment
            after a reset. Also contains the following metrics in the `extras` field:
            - volume_utilization: utilization (in [0, 1]) of the container.
            - packed_items: number of items that are packed in the container.
            - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container.
            - active_ems: number of active EMSs in the current instance.
            - invalid_action: True if the action that was just taken was invalid.
            - invalid_ems_from_env (optional): True if the environment produced an EMS that was
                invalid. Only available in debug mode.
    """
    # Generate a new instance.
    state = self.generator(key)

    # Make the observation.
    state, observation, extras = self._make_observation_and_extras(state)

    extras.update(invalid_action=jnp.array(False))
    if self.debug:
        extras.update(invalid_ems_from_env=jnp.array(False))
    timestep = restart(observation, extras)

    return state, timestep

step(state, action)

Run one timestep of the environment's dynamics. If the action is invalid, the state is not updated, i.e. the action is not taken, and the episode terminates.

Parameters:

Name Type Description Default
state State

State object containing the data of the current instance.

required
action Array

jax array (int32) of shape (2,): (ems_id, item_id). This means placing the given item at the location of the given EMS. If the action is not valid, the flag invalid_action will be set to True in timestep.extras and the episode terminates.

required

Returns:

Name Type Description
state State

State object corresponding to the next state of the environment.

timestep TimeStep[Observation]

TimeStep object corresponding to the timestep returned by the environment. Also contains the following metrics in the extras field:

  • volume_utilization: utilization (in [0, 1]) of the container.
  • packed_items: number of items that are packed in the container.
  • ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container.
  • active_ems: number of active EMSs in the current instance.
  • invalid_action: True if the action that was just taken was invalid.
  • invalid_ems_from_env (optional): True if the environment produced an EMS that was invalid. Only available in debug mode.

Source code in jumanji/environments/packing/bin_pack/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics. If the action is invalid, the state
    is not updated, i.e. the action is not taken, and the episode terminates.

    Args:
        state: `State` object containing the data of the current instance.
        action: jax array (int32) of shape (2,): (ems_id, item_id). This means placing the given
            item at the location of the given EMS. If the action is not valid, the flag
            `invalid_action` will be set to True in `timestep.extras` and the episode
            terminates.

    Returns:
        state: `State` object corresponding to the next state of the environment.
        timestep: `TimeStep` object corresponding to the timestep returned by the environment.
            Also contains metrics in the `extras` field:
            - volume_utilization: utilization (in [0, 1]) of the container.
            - packed_items: number of items that are packed in the container.
            - ratio_packed_items: ratio (in [0, 1]) of items that are packed in the container.
            - active_ems: number of EMSs in the current instance.
            - invalid_action: True if the action that was just taken was invalid.
            - invalid_ems_from_env (optional): True if the environment produced an EMS that was
                invalid. Only available in debug mode.
    """
    action_is_valid = state.action_mask[tuple(action)]  # type: ignore

    obs_ems_id, item_id = action
    ems_id = state.sorted_ems_indexes[obs_ems_id]

    # Pack the item if the provided action is valid.
    next_state = jax.lax.cond(
        action_is_valid,
        lambda s: self._pack_item(s, ems_id, item_id),
        lambda s: s,
        state,
    )

    # Make the observation.
    next_state, observation, extras = self._make_observation_and_extras(next_state)

    done = ~jnp.any(next_state.action_mask) | ~action_is_valid
    reward = self.reward_fn(state, action, next_state, action_is_valid, done)

    extras.update(invalid_action=~action_is_valid)
    if self.debug:
        ems_are_all_valid = self._ems_are_all_valid(next_state)
        extras.update(invalid_ems_from_env=~ems_are_all_valid)

    timestep = jax.lax.cond(
        done,
        lambda: termination(
            reward=reward,
            observation=observation,
            extras=extras,
        ),
        lambda: transition(
            reward=reward,
            observation=observation,
            extras=extras,
        ),
    )

    return next_state, timestep
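The termination rule used by `step` can be isolated in a small NumPy sketch. The two helpers below are hypothetical and only mirror the boolean logic of the method above: an episode ends when the action is invalid or when no valid action remains in the next state.

```python
import numpy as np


def action_is_valid(action_mask: np.ndarray, action) -> bool:
    """Look up the (ems_id, item_id) pair in the joint action mask."""
    return bool(action_mask[tuple(action)])


def episode_is_done(valid: bool, next_action_mask: np.ndarray) -> bool:
    """Done if nothing can be packed anymore or if the action was invalid."""
    return (not next_action_mask.any()) or (not valid)


action_mask = np.array([[True, False], [False, True]])

# Valid action, and another valid action remains afterwards: episode continues.
mask_after_packing = np.array([[False, False], [False, True]])
continues = episode_is_done(action_is_valid(action_mask, (0, 0)), mask_after_packing)

# Invalid action: the state (and hence the mask) is unchanged, episode terminates.
invalid_ends = episode_is_done(action_is_valid(action_mask, (0, 1)), action_mask)

# Valid action that leaves no valid action: episode terminates.
empty_mask = np.zeros((2, 2), dtype=bool)
exhausted_ends = episode_is_done(action_is_valid(action_mask, (1, 1)), empty_mask)
```

Keeping the state unchanged on an invalid action, rather than raising an error, fits JAX's functional style, where `step` has to remain traceable under `jax.jit`.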