FlatPack
FlatPack (Environment)
The FlatPack environment supports a configurable number of row and column blocks. The agent's goal is to completely fill an empty grid by placing all of the available blocks. It can be thought of as a discrete 2D version of the BinPack environment.
observation: Observation
- grid: jax array (int) of shape (num_rows, num_cols) with the current state of the grid.
- blocks: jax array (int) of shape (num_blocks, 3, 3) with the blocks to be placed on the grid; each block is a 2D array of shape (3, 3).
- action_mask: jax array (bool) showing which blocks can be placed where on the grid. The mask covers every possible rotation and placement location of each block on the grid.
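To make the action mask concrete, here is a minimal sketch of how such a mask could be computed, with NumPy standing in for jax.numpy. It assumes that 0 marks an empty grid cell, that rotations follow np.rot90 (counter-clockwise), that a placement is allowed only if the block has not been placed yet and its non-zero cells land on empty cells, and that the mask is indexed like the action (block, rotation, row, column). The function compute_action_mask is hypothetical, not the library's API.

```python
import numpy as np


def compute_action_mask(grid, blocks, placed_blocks):
    """Hypothetical sketch: mask[b, r, i, j] is True when block b, rotated r
    times with np.rot90, can be put with its top-left corner at (i, j).

    Assumptions: 0 marks an empty cell, and a block may be placed only once.
    """
    num_rows, num_cols = grid.shape
    num_blocks = blocks.shape[0]
    mask = np.zeros((num_blocks, 4, num_rows - 2, num_cols - 2), dtype=bool)
    occupied = grid != 0
    for b in range(num_blocks):
        if placed_blocks[b]:
            continue  # already-placed blocks keep an all-False slice
        for r in range(4):
            cells = np.rot90(blocks[b], k=r) != 0  # filled cells of the rotated block
            for i in range(num_rows - 2):
                for j in range(num_cols - 2):
                    # valid if no filled block cell overlaps a filled grid cell
                    mask[b, r, i, j] = not np.any(occupied[i:i + 3, j:j + 3] & cells)
    return mask
```

The explicit loops favour readability; a JAX version would more likely vectorise them (for example with jax.vmap) so the computation can be jitted.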
action: jax array (int32) of shape (4,), a multi-discrete array encoding the move to perform as (block to place, number of rotations, row coordinate, column coordinate).
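One way to read the four action components is sketched below in NumPy. The helper apply_action is hypothetical. It assumes that the block is rotated with np.rot90 (counter-clockwise), that the (row, column) pair is the top-left corner of the 3x3 block, and that filled cells are written as block index + 1 so that 0 can keep meaning "empty". None of these conventions are confirmed by the description above.

```python
import numpy as np


def apply_action(grid, blocks, action):
    """Hypothetical helper: place blocks[block] rotated `rotation` times
    (np.rot90, counter-clockwise) with its top-left corner at (row, col)."""
    block, rotation, row, col = (int(x) for x in action)
    piece = np.rot90(blocks[block], k=rotation)
    new_grid = grid.copy()  # return a new array instead of mutating, as JAX would
    window = new_grid[row:row + 3, col:col + 3]  # a view into new_grid
    window[piece != 0] = block + 1  # assumed encoding: 0 = empty, b + 1 = block b
    return new_grid
```

The helper does not check for overlaps or already-placed blocks. In this sketch, that check is left to the action mask.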
reward: jax array (float) of shape (), given by one of:
- cell dense: the number of non-zero cells in the placed block, normalised by the total number of cells in the grid. This is a value in the range [0, 1], so the agent optimises for the largest filled area of the grid.
- block dense: each placed block receives a reward of 1./num_blocks. This is a value in the range [0, 1], so the agent optimises for the largest number of blocks placed on the grid.
- sparse: 1 if the grid is completely filled, otherwise 0, evaluated at each timestep.
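The three reward variants can be written as small functions. This NumPy sketch follows the definitions above. The function names are hypothetical, and the sketch assumes that non-zero grid cells are filled and that the placement being rewarded was valid.

```python
import numpy as np


def cell_dense_reward(placed_block, num_rows, num_cols):
    # non-zero cells of the placed block, normalised by the grid area
    return np.count_nonzero(placed_block) / (num_rows * num_cols)


def block_dense_reward(num_blocks):
    # every placed block is worth the same share of the total
    return 1.0 / num_blocks


def sparse_reward(grid):
    # 1 only once every grid cell is filled (assumes 0 marks an empty cell)
    return 1.0 if np.all(grid != 0) else 0.0
```

If the blocks exactly tile the grid, the cell-dense rewards over an episode sum to 1. The block-dense rewards also sum to 1 once all num_blocks blocks are placed.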
episode termination:
- all blocks have been placed on the board, or
- the agent has taken num_blocks steps in the environment.
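The two termination conditions reduce to a single boolean check. Here is a minimal sketch, with is_done as a hypothetical name and NumPy standing in for jax.numpy:

```python
import numpy as np


def is_done(placed_blocks, step_count, num_blocks):
    """Episode ends once every block is placed or num_blocks steps have been taken."""
    all_placed = bool(np.all(placed_blocks))
    out_of_steps = step_count >= num_blocks
    # Under jax.jit, Python's `or` cannot act on traced arrays;
    # a JAX version would use jnp.all(...) | (step_count >= num_blocks).
    return all_placed or out_of_steps
```

The step limit guarantees that an episode lasts at most num_blocks steps, even if not every action places a block.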
state: State
- num_blocks: jax array (int32) of shape () with the number of blocks in the environment.
- blocks: jax array (int32) of shape (num_blocks, 3, 3) with the blocks to be placed on the grid; each block is a 2D array of shape (3, 3).
- action_mask: jax array (bool) showing which blocks can be placed where on the grid. The mask covers every possible rotation and placement location of each block on the grid.
- placed_blocks: jax array (bool) of shape (num_blocks,) showing which blocks have been placed on the grid.
- grid: jax array (int32) of shape (num_rows, num_cols) with the current state of the grid.
- step_count: jax array (int32) of shape () with the number of steps taken in the environment.
- key: jax array of shape (2,) with the random key used for board generation.
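For illustration, the state fields above can be mirrored by an immutable NamedTuple. This is a hypothetical NumPy stand-in: StateSketch and initial_state are not library names. It assumes that an empty grid is all zeros and that the action mask is indexed like the action, with shape (num_blocks, 4, num_rows - 2, num_cols - 2) and every placement valid on an empty grid.

```python
from typing import NamedTuple

import numpy as np


class StateSketch(NamedTuple):
    num_blocks: np.int32
    blocks: np.ndarray
    action_mask: np.ndarray
    placed_blocks: np.ndarray
    grid: np.ndarray
    step_count: np.int32
    key: np.ndarray


def initial_state(blocks, num_rows, num_cols, key):
    num_blocks = blocks.shape[0]
    return StateSketch(
        num_blocks=np.int32(num_blocks),
        blocks=blocks.astype(np.int32),
        # on an empty grid every block, rotation and in-bounds corner is valid
        action_mask=np.ones((num_blocks, 4, num_rows - 2, num_cols - 2), dtype=bool),
        placed_blocks=np.zeros(num_blocks, dtype=bool),  # nothing placed yet
        grid=np.zeros((num_rows, num_cols), dtype=np.int32),  # 0 = empty (assumed)
        step_count=np.int32(0),
        key=key,
    )
```

Updates produce a new tuple with _replace (for example state._replace(step_count=state.step_count + 1)) rather than mutating fields. This matches the way JAX code threads immutable state through each step.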
observation_spec: jumanji.specs.Spec[jumanji.environments.packing.flat_pack.types.Observation] (cached property, writable)

Returns the observation spec of the environment.

Returns:

Type | Description |
---|---|
Spec[Observation] | Spec for each field in the observation. |
action_spec: MultiDiscreteArray (cached property, writable)

Specifications of the action expected by the FlatPack environment.

Returns:

Type | Description |
---|---|
MultiDiscreteArray | MultiDiscreteArray (int32) of shape (4,) whose components take num_blocks, num_rotations, num_rows-2 and num_cols-2 values respectively: block index, an int between 0 and num_blocks - 1 (inclusive); number of rotations, an int between 0 and 3 (inclusive); row position, an int between 0 and num_rows - 3 (inclusive); column position, an int between 0 and num_cols - 3 (inclusive). |
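The bounds above can be reproduced with a small NumPy helper, for example to check a hand-written action before passing it to step. The helper names are hypothetical. The sketch only restates the per-component ranges listed in the table and does not use the library's spec object.

```python
import numpy as np


def action_num_values(num_blocks, num_rows, num_cols):
    # number of values per component: block, rotation, row, column
    return np.array([num_blocks, 4, num_rows - 2, num_cols - 2], dtype=np.int32)


def is_within_bounds(action, num_values):
    action = np.asarray(action)
    # each component must lie in [0, num_values - 1]
    return action.shape == num_values.shape and bool(
        np.all((action >= 0) & (action < num_values))
    )
```

For a 6 x 7 grid with 5 blocks this gives [5, 4, 4, 5] values per component, or 5 * 4 * 4 * 5 = 400 possible actions. The action mask then marks which of these are currently valid.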
__init__(self, generator: Optional[jumanji.environments.packing.flat_pack.generator.InstanceGenerator] = None, reward_fn: Optional[jumanji.environments.packing.flat_pack.reward.RewardFn] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.packing.flat_pack.types.State]] = None)

Initializes the FlatPack environment.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
generator | Optional[jumanji.environments.packing.flat_pack.generator.InstanceGenerator] | Instance generator for the environment, defaults to | None |
reward_fn | Optional[jumanji.environments.packing.flat_pack.reward.RewardFn] | Reward function for the environment, defaults to | None |
viewer | Optional[jumanji.viewer.Viewer[jumanji.environments.packing.flat_pack.types.State]] | Viewer for rendering the environment. | None |
reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.packing.flat_pack.types.State, jumanji.types.TimeStep[jumanji.environments.packing.flat_pack.types.Observation]]

Resets the environment.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
key | PRNGKeyArray | PRNG key for generating a new instance. | required |

Returns:

Type | Description |
---|---|
Tuple[jumanji.environments.packing.flat_pack.types.State, jumanji.types.TimeStep[jumanji.environments.packing.flat_pack.types.Observation]] | A tuple of the initial environment state and a time step. |
step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.packing.flat_pack.types.State, jumanji.types.TimeStep[jumanji.environments.packing.flat_pack.types.Observation]]

Steps the environment.

Parameters:

Name | Type | Description | Default |
---|---|---|---|
state | State | Current state of the environment. | required |
action | Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] | Action to take. | required |

Returns:

Type | Description |
---|---|
Tuple[jumanji.environments.packing.flat_pack.types.State, jumanji.types.TimeStep[jumanji.environments.packing.flat_pack.types.Observation]] | A tuple of the next environment state and a time step. |
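As the signatures of reset and step suggest, the environment is functional: the caller threads the returned state into the next step call. The loop below shows that pattern against ToyFlatPack, a hypothetical, heavily simplified NumPy stand-in whose reset and step mirror those signatures. It is not the library class. It assumes the cell-dense reward, 0 as the empty-cell value, and np.rot90 rotations. It also assumes that an invalid action leaves the grid unchanged and earns 0. The real time step carries more fields than the reward and last flag used here.

```python
from typing import NamedTuple, Tuple

import numpy as np


class ToyState(NamedTuple):
    grid: np.ndarray
    placed: np.ndarray
    step_count: int


class ToyTimeStep(NamedTuple):
    reward: float
    last: bool


class ToyFlatPack:
    """Hypothetical stand-in with the same reset/step calling pattern."""

    def __init__(self, blocks, num_rows, num_cols):
        self.blocks = blocks
        self.num_rows, self.num_cols = num_rows, num_cols

    def reset(self, key) -> Tuple[ToyState, ToyTimeStep]:
        # key is unused here: the toy always starts from the same fixed blocks
        grid = np.zeros((self.num_rows, self.num_cols), dtype=np.int32)
        placed = np.zeros(len(self.blocks), dtype=bool)
        return ToyState(grid, placed, 0), ToyTimeStep(0.0, False)

    def step(self, state: ToyState, action) -> Tuple[ToyState, ToyTimeStep]:
        b, rot, i, j = (int(x) for x in action)
        piece = np.rot90(self.blocks[b], k=rot) != 0
        grid, placed = state.grid.copy(), state.placed.copy()  # never mutate the input state
        window = grid[i:i + 3, j:j + 3]  # a view into the copied grid
        reward = 0.0
        if not placed[b] and not np.any(window[piece]):
            window[piece] = b + 1
            placed[b] = True
            reward = float(piece.sum()) / grid.size  # cell-dense reward
        new_state = ToyState(grid, placed, state.step_count + 1)
        last = bool(placed.all()) or new_state.step_count >= len(self.blocks)
        return new_state, ToyTimeStep(reward, last)


def rollout(env, key, actions):
    state, timestep = env.reset(key)
    total = 0.0
    for action in actions:
        state, timestep = env.step(state, action)  # thread the returned state
        total += timestep.reward
        if timestep.last:
            break
    return state, total
```

With the real environment, the same loop would pass a key such as jax.random.PRNGKey(0) to reset and pick actions that the observation's action mask marks as valid.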