
FlatPack

FlatPack (Environment)

The FlatPack environment, with a configurable number of row and column blocks. The agent's goal is to completely fill an empty grid by placing all available blocks. It can be thought of as a discrete 2D version of the BinPack environment.

  • observation: Observation

    • grid: jax array (int) of shape (num_rows, num_cols) with the current state of the grid.
    • blocks: jax array (int) of shape (num_blocks, 3, 3) with the blocks to be placed on the grid. Here each block is a 2D array with shape (3, 3).
    • action_mask: jax array (bool) of shape (num_blocks, 4, num_rows-2, num_cols-2) marking which blocks can be placed where on the grid. The mask covers every rotation and every possible placement location of each block.
  • action: jax array (int32) of shape (4,), a multi-discrete array describing the move to perform (block to place, number of rotations, row coordinate, column coordinate).

  • reward: jax array (float) of shape (), computed in one of three ways:

    • cell dense: the number of non-zero cells in the placed block, normalised by the total number of cells in the grid. This is a value in the range [0, 1], so the agent optimises for the largest filled area on the grid.
    • block dense: each placed block receives a reward of 1/num_blocks. This is a value in the range [0, 1], so the agent optimises for the largest number of blocks placed on the grid.
    • sparse: 1 if the grid is completely filled, otherwise 0 at each timestep.
  • episode termination:

    • if all blocks have been placed on the board.
    • if the agent has taken num_blocks steps in the environment.
  • state: State

    • num_blocks: jax array (int32) of shape () with the number of blocks in the environment.
    • blocks: jax array (int32) of shape (num_blocks, 3, 3) with the blocks to be placed on the grid. Here each block is a 2D array with shape (3, 3).
    • action_mask: jax array (bool) of shape (num_blocks, 4, num_rows-2, num_cols-2) marking which blocks can be placed where on the grid. The mask covers every rotation and every possible placement location of each block.
    • placed_blocks: jax array (bool) of shape (num_blocks,) showing which blocks have been placed on the grid.
    • grid: jax array (int32) of shape (num_rows, num_cols) with the current state of the grid.
    • step_count: jax array (int32) of shape () with the number of steps taken in the environment.
    • key: jax array of shape (2,) with the random key used for board generation.
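The three reward variants and the termination rule listed above can be sketched in plain NumPy as follows. This is a minimal illustration, not Jumanji's implementation: the function names are hypothetical, and the sketch assumes that a block's footprint is given by its non-zero cells and that the grid counts as full when every cell is non-zero.

```python
import numpy as np


def cell_dense_reward(placed_block: np.ndarray, grid: np.ndarray) -> float:
    # Hypothetical: non-zero cells of the placed block, normalised by grid area.
    return float(np.count_nonzero(placed_block)) / grid.size


def block_dense_reward(num_blocks: int) -> float:
    # Hypothetical: every placed block earns 1 / num_blocks.
    return 1.0 / num_blocks


def sparse_reward(grid: np.ndarray) -> float:
    # Hypothetical: 1 only once every grid cell is occupied, otherwise 0.
    return float(np.all(grid != 0))


def is_done(placed_blocks: np.ndarray, step_count: int, num_blocks: int) -> bool:
    # The episode ends when all blocks are placed or num_blocks steps have elapsed.
    return bool(np.all(placed_blocks) or step_count >= num_blocks)


# Toy 5x5 grid with one L-shaped block (4 cells) placed in the top-left corner.
block = np.array([[1, 0, 0],
                  [1, 0, 0],
                  [1, 1, 0]])
grid = np.zeros((5, 5), dtype=int)
grid[:3, :3] += block

print(cell_dense_reward(block, grid))    # 0.16 (4 / 25)
print(block_dense_reward(num_blocks=4))  # 0.25
print(sparse_reward(grid))               # 0.0
```

Both dense variants sum to at most 1 over an episode, which is why they stay in the [0, 1] range stated above.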
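import jax
from jumanji.environments import FlatPack

env = FlatPack()
key = jax.random.PRNGKey(0)
# Generate a new instance and render the initial grid.
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# Take a placeholder action generated from the action spec and render the result.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)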
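The example above steps with a placeholder action from action_spec.generate_value(), which need not be a legal placement. One way to pick a legal action instead is to take any True entry of the observation's action_mask (shape (num_blocks, 4, num_rows-2, num_cols-2)) and unravel its flat index into the four action components. The sketch below uses NumPy and a hand-built mask; with the environment you would pass timestep.observation.action_mask. The helper name first_valid_action is hypothetical.

```python
import numpy as np


def first_valid_action(action_mask: np.ndarray) -> np.ndarray:
    """Return (block, rotation, row, col) for the first True entry of the mask.

    Assumes action_mask has shape (num_blocks, 4, num_rows - 2, num_cols - 2)
    and contains at least one True entry.
    """
    flat_index = int(np.argmax(action_mask.ravel()))  # index of the first True
    return np.array(np.unravel_index(flat_index, action_mask.shape), dtype=np.int32)


# Toy mask: 2 blocks on a 5x5 grid, so 3x3 possible placement positions.
mask = np.zeros((2, 4, 3, 3), dtype=bool)
mask[1, 2, 0, 1] = True  # block 1, two rotations, row 0, column 1
mask[1, 3, 2, 2] = True

print(first_valid_action(mask))  # [1 2 0 1]
```

jax.numpy also provides argmax and unravel_index, so the same two calls should work inside a jitted policy.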

observation_spec: jumanji.specs.Spec[jumanji.environments.packing.flat_pack.types.Observation] cached property writable

Returns the observation spec of the environment.

Returns:

Spec for each field in the observation:
  • grid: BoundedArray (int) of shape (num_rows, num_cols).
  • blocks: BoundedArray (int) of shape (num_blocks, 3, 3).
  • action_mask: BoundedArray (bool) of shape (num_blocks, 4, num_rows-2, num_cols-2).
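The action_mask above holds one boolean per (block, rotation, row, column) combination. A plausible way to compute it is sketched below in NumPy, assuming a placement is legal when the block has not been placed yet and its rotated non-zero cells land only on empty grid cells. The helper name and the use of np.rot90 (counter-clockwise quarter turns) for rotation are assumptions, not Jumanji's code; the explicit loops are for readability, and a JAX version would vectorise them.

```python
import numpy as np


def compute_action_mask(grid: np.ndarray, blocks: np.ndarray,
                        placed_blocks: np.ndarray) -> np.ndarray:
    """Hypothetical mask of shape (num_blocks, 4, num_rows - 2, num_cols - 2)."""
    num_rows, num_cols = grid.shape
    num_blocks = blocks.shape[0]
    mask = np.zeros((num_blocks, 4, num_rows - 2, num_cols - 2), dtype=bool)
    occupied = grid != 0
    for b in range(num_blocks):
        if placed_blocks[b]:
            continue  # a block can only be placed once
        for k in range(4):
            footprint = np.rot90(blocks[b], k) != 0
            for r in range(num_rows - 2):
                for c in range(num_cols - 2):
                    window = occupied[r:r + 3, c:c + 3]
                    # Legal if no cell of the block overlaps an occupied cell.
                    mask[b, k, r, c] = not np.any(footprint & window)
    return mask


grid = np.zeros((4, 4), dtype=int)
grid[0, 0] = 1  # one cell already occupied
blocks = np.array([
    [[1, 1, 0], [0, 0, 0], [0, 0, 0]],  # domino in the top row
    [[2, 2, 2], [2, 2, 2], [2, 2, 2]],  # full 3x3 square
])
mask = compute_action_mask(grid, blocks, placed_blocks=np.array([False, True]))
print(mask.shape)        # (2, 4, 2, 2)
print(mask[0, 0, 0, 0])  # False: the domino would cover grid[0, 0]
print(mask[1].any())     # False: block 1 is already placed
```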

action_spec: MultiDiscreteArray cached property writable

Specifications of the action expected by the FlatPack environment.

Returns:

MultiDiscreteArray (int32) of shape (4,) whose num_values are (num_blocks, num_rotations, num_rows-2, num_cols-2):

  • num_blocks: int between 0 and num_blocks - 1 (inclusive).
  • num_rotations: int between 0 and 3 (inclusive).
  • max_row_position: int between 0 and num_rows - 3 (inclusive).
  • max_col_position: int between 0 and num_cols - 3 (inclusive).
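To make the four action components concrete, the sketch below applies one action to a grid: it checks the bounds listed above, selects the block, rotates it, and writes its non-zero cells into the 3x3 window whose top-left corner is (row, col). It assumes rotations are counter-clockwise quarter turns via np.rot90 and that a block's cells keep their own non-zero values on the grid; both are assumptions about the environment, and place_block is a hypothetical name.

```python
import numpy as np


def place_block(grid: np.ndarray, blocks: np.ndarray, action) -> np.ndarray:
    """Hypothetical: return a new grid with the chosen block placed."""
    block_idx, num_rotations, row, col = (int(a) for a in action)
    num_rows, num_cols = grid.shape
    # Bounds from the action spec: row in [0, num_rows - 3], col in [0, num_cols - 3].
    if not (0 <= block_idx < blocks.shape[0] and 0 <= num_rotations <= 3
            and 0 <= row <= num_rows - 3 and 0 <= col <= num_cols - 3):
        raise ValueError(f"action {(block_idx, num_rotations, row, col)} is out of bounds")
    rotated = np.rot90(blocks[block_idx], num_rotations)
    new_grid = grid.copy()
    window = new_grid[row:row + 3, col:col + 3]  # a view into new_grid
    window[rotated != 0] = rotated[rotated != 0]
    return new_grid


grid = np.zeros((5, 5), dtype=int)
blocks = np.array([[[1, 1, 1],
                    [0, 0, 0],
                    [0, 0, 0]]])  # a single horizontal bar
# Place block 0, rotated once, with its 3x3 window starting at row 2, column 1.
print(place_block(grid, blocks, np.array([0, 1, 2, 1], dtype=np.int32)))
```

Returning a copy rather than mutating the input mirrors the functional style of a JAX environment, where state updates produce new arrays.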

__init__(self, generator: Optional[jumanji.environments.packing.flat_pack.generator.InstanceGenerator] = None, reward_fn: Optional[jumanji.environments.packing.flat_pack.reward.RewardFn] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.packing.flat_pack.types.State]] = None) special

Initializes the FlatPack environment.

Parameters:

  • generator (Optional[jumanji.environments.packing.flat_pack.generator.InstanceGenerator], default None): instance generator for the environment. If None, RandomFlatPackGenerator with a grid of 5 blocks per row and column is used.
  • reward_fn (Optional[jumanji.environments.packing.flat_pack.reward.RewardFn], default None): reward function for the environment. If None, CellDenseReward is used.
  • viewer (Optional[jumanji.viewer.Viewer[jumanji.environments.packing.flat_pack.types.State]], default None): viewer for rendering the environment.

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.packing.flat_pack.types.State, jumanji.types.TimeStep[jumanji.environments.packing.flat_pack.types.Observation]]

Resets the environment.

Parameters:

  • key (PRNGKeyArray, required): PRNG key for generating a new instance.

Returns:

Tuple[jumanji.environments.packing.flat_pack.types.State, jumanji.types.TimeStep[jumanji.environments.packing.flat_pack.types.Observation]]: a tuple of the initial environment state and a time step.

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.packing.flat_pack.types.State, jumanji.types.TimeStep[jumanji.environments.packing.flat_pack.types.Observation]]

Steps the environment.

Parameters:

  • state (State, required): current state of the environment.
  • action (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], required): action to take, a 4-element (block, rotations, row, column) array as described by action_spec.

Returns:

Tuple[jumanji.environments.packing.flat_pack.types.State, jumanji.types.TimeStep[jumanji.environments.packing.flat_pack.types.Observation]]: a tuple of the next environment state and a time step.


Last update: 2024-11-01