
FlatPack

Bases: Environment[State, MultiDiscreteArray, Observation]

The FlatPack environment, with a configurable number of row and column blocks. The agent's goal is to completely fill an empty grid by placing all available blocks. It can be thought of as a discrete 2D version of the BinPack environment.

  • observation: Observation

    • grid: jax array (int) of shape (num_rows, num_cols) with the current state of the grid.
    • blocks: jax array (int) of shape (num_blocks, 3, 3) with the blocks to be placed on the grid. Here each block is a 2D array with shape (3, 3).
    • action_mask: jax array (bool) of shape (num_blocks, 4, num_rows-2, num_cols-2) showing which blocks can be placed where on the grid. The mask covers every rotation and every placement location of each block.
  • action: jax array (int32) of shape (4,), a multi-discrete array describing the move to perform (block to place, number of rotations, row coordinate, column coordinate).

  • reward: jax array (float) of shape (), could be either:

    • cell dense: the number of non-zero cells in the placed block, normalised by the total number of cells in the grid. This is a value in the range [0, 1], so the agent optimises for the largest area filled on the grid.
    • block dense: each placed block receives a reward of 1/num_blocks, a value in the range [0, 1], so the agent optimises for the largest number of blocks placed on the grid.
    • sparse: 1 if the grid is completely filled, otherwise 0 at each timestep.
  • episode termination:

    • if all blocks have been placed on the board.
    • if the agent has taken num_blocks steps in the environment.
  • state: State

    • num_blocks: jax array (int32) of shape () with the number of blocks in the environment.
    • blocks: jax array (int32) of shape (num_blocks, 3, 3) with the blocks to be placed on the grid. Here each block is a 2D array with shape (3, 3).
    • action_mask: jax array (bool) of shape (num_blocks, 4, num_rows-2, num_cols-2) showing which blocks can be placed where on the grid. The mask covers every rotation and every placement location of each block.
    • placed_blocks: jax array (bool) of shape (num_blocks,) showing which blocks have been placed on the grid.
    • grid: jax array (int32) of shape (num_rows, num_cols) with the current state of the grid.
    • step_count: jax array (int32) of shape () with the number of steps taken in the environment.
    • key: jax array of shape (2,) with the random key used for board generation.
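The three reward options listed above can be sketched as plain functions of a single transition. This is a NumPy illustration, not the library's `CellDenseReward` or the other reward classes; the function names are hypothetical, and returning 0 for an illegal action is an assumption (the real reward functions do receive an `action_is_legal` flag, as the call inside `step` shows).

```python
import numpy as np


def cell_dense_reward(grid_block: np.ndarray, num_rows: int, num_cols: int,
                      action_is_legal: bool) -> float:
    """Fraction of grid cells covered by the newly placed block.

    Assumption: an illegal action earns 0.
    """
    if not action_is_legal:
        return 0.0
    return float(np.count_nonzero(grid_block)) / (num_rows * num_cols)


def block_dense_reward(num_blocks: int, action_is_legal: bool) -> float:
    """Each legally placed block is worth 1 / num_blocks (assumed 0 otherwise)."""
    return 1.0 / num_blocks if action_is_legal else 0.0


def sparse_reward(next_grid: np.ndarray) -> float:
    """1 only once every cell of the grid is filled, otherwise 0."""
    return 1.0 if np.all(next_grid != 0) else 0.0


# Example: a block with 7 filled cells placed on a 5x5 grid.
grid_block = np.zeros((5, 5), dtype=np.int32)
grid_block[0:3, 0:3] = np.array([[1, 1, 0], [1, 1, 1], [0, 1, 1]])
print(cell_dense_reward(grid_block, 5, 5, True))  # 7 / 25
```

Cell-dense rewards favour large blocks early, while block-dense rewards treat every block equally regardless of its area.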
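import jax
from jumanji.environments import FlatPack

env = FlatPack()  # default: 5 blocks per row and column, CellDenseReward
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# generate_value() returns a placeholder action from the spec; it is not guaranteed to be legal.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)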
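Because the placeholder action from `generate_value()` may be illegal (an illegal action leaves the grid unchanged but still uses up a step), a random agent would normally sample from the action mask instead. A minimal host-side NumPy sketch, assuming the mask has the documented shape (num_blocks, 4, num_rows-2, num_cols-2); `sample_legal_action` is a hypothetical helper, not part of the library.

```python
import numpy as np


def sample_legal_action(action_mask: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Uniformly sample one legal (block_idx, rotation, row_idx, col_idx) action.

    action_mask: boolean array of shape (num_blocks, 4, num_rows - 2, num_cols - 2).
    """
    legal = np.argwhere(action_mask)  # (num_legal, 4) array of index tuples
    if legal.shape[0] == 0:
        raise ValueError("no legal action available")
    choice = rng.integers(legal.shape[0])
    return legal[choice].astype(np.int32)


# Usage with a real environment (not run here):
#   mask = np.asarray(timestep.observation.action_mask)
#   action = sample_legal_action(mask, np.random.default_rng(0))
```

`np.argwhere` produces a data-dependent shape, so this version suits ordinary Python loops. Inside `jax.jit`, one common alternative is to mask flattened logits with `jnp.where(mask, 0.0, -jnp.inf)`, sample with `jax.random.categorical`, and recover the four indices with `jnp.unravel_index`.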

Initializes the FlatPack environment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| generator | Optional[InstanceGenerator] | Instance generator for the environment. Defaults to RandomFlatPackGenerator with a grid of 5 blocks per row and per column. | None |
| reward_fn | Optional[RewardFn] | Reward function for the environment. Defaults to CellDenseReward. | None |
| viewer | Optional[Viewer[State]] | Viewer for rendering the environment. Defaults to FlatPackViewer. | None |
Source code in jumanji/environments/packing/flat_pack/env.py
def __init__(
    self,
    generator: Optional[InstanceGenerator] = None,
    reward_fn: Optional[RewardFn] = None,
    viewer: Optional[Viewer[State]] = None,
):
    """Initializes the FlatPack environment.

    Args:
        generator: Instance generator for the environment, defaults to `RandomFlatPackGenerator`
            with a grid of 5 blocks per row and column.
        reward_fn: Reward function for the environment, defaults to `CellDenseReward`.
        viewer: Viewer for rendering the environment.
    """

    default_generator = RandomFlatPackGenerator(
        num_row_blocks=5,
        num_col_blocks=5,
    )

    self.generator = generator or default_generator
    self.num_row_blocks = self.generator.num_row_blocks
    self.num_col_blocks = self.generator.num_col_blocks
    self.num_blocks = self.num_row_blocks * self.num_col_blocks
    self.num_rows, self.num_cols = (
        compute_grid_dim(self.num_row_blocks),
        compute_grid_dim(self.num_col_blocks),
    )
    self.reward_fn = reward_fn or CellDenseReward()
    self.viewer = viewer or FlatPackViewer("FlatPack", self.num_blocks, render_mode="human")
    super().__init__()
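`__init__` derives the grid size from the number of blocks per row and column through `compute_grid_dim`, whose body is not shown on this page. The sketch below is a hypothetical re-derivation, assuming neighbouring 3x3 blocks overlap by one row or column so that n blocks span 3n - (n - 1) = 2n + 1 cells; it then lists the array shapes implied by the specs on this page. `grid_dim` and `flat_pack_dims` are illustrative names, not library functions.

```python
def grid_dim(num_blocks: int) -> int:
    # Assumption: adjacent 3x3 blocks share one row/column,
    # so n blocks span 3 * n - (n - 1) = 2 * n + 1 cells.
    return 3 * num_blocks - (num_blocks - 1)


def flat_pack_dims(num_row_blocks: int, num_col_blocks: int) -> dict:
    """Shapes implied by the observation and action specs for a given block layout."""
    num_rows = grid_dim(num_row_blocks)
    num_cols = grid_dim(num_col_blocks)
    num_blocks = num_row_blocks * num_col_blocks
    return {
        "num_blocks": num_blocks,
        "grid_shape": (num_rows, num_cols),
        "blocks_shape": (num_blocks, 3, 3),
        # A 3x3 block's top-left corner can sit at rows 0..num_rows-3 (likewise for columns).
        "action_mask_shape": (num_blocks, 4, num_rows - 2, num_cols - 2),
    }


print(flat_pack_dims(5, 5))
```

Under that assumption, the default 5x5 configuration gives 25 blocks, an 11x11 grid, and an action mask of shape (25, 4, 9, 9).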

action_spec: specs.MultiDiscreteArray (cached property)

Specifications of the action expected by the FlatPack environment.

Returns:

| Type | Description |
|---|---|
| MultiDiscreteArray | MultiDiscreteArray (int32) of shape (4,), with per-entry cardinalities (num_blocks, num_rotations, num_rows-2, num_cols-2). |

  • block_idx: int between 0 and num_blocks - 1 (inclusive).
  • rotation: int between 0 and 3 (inclusive).
  • row_idx: int between 0 and num_rows - 3 (inclusive).
  • col_idx: int between 0 and num_cols - 3 (inclusive).
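The action is a (4,) vector rather than a single integer. If a policy outputs one categorical distribution over all num_blocks * 4 * (num_rows-2) * (num_cols-2) combinations (a design choice, not something the environment requires), the flat index has to be converted into this multi-discrete form. A minimal NumPy sketch using C-order indexing; `action_dims`, `flat_to_action`, and `action_to_flat` are hypothetical helpers.

```python
import numpy as np


def action_dims(num_blocks: int, num_rows: int, num_cols: int) -> tuple:
    # Per-entry cardinalities: block index, rotation, row position, column position.
    return (num_blocks, 4, num_rows - 2, num_cols - 2)


def flat_to_action(flat_idx: int, dims: tuple) -> np.ndarray:
    """Decode a flat categorical index into a (4,) int32 action (C order)."""
    return np.array(np.unravel_index(flat_idx, dims), dtype=np.int32)


def action_to_flat(action: np.ndarray, dims: tuple) -> int:
    """Encode a (4,) action as a flat index; raises ValueError if any entry is out of bounds."""
    return int(np.ravel_multi_index(tuple(int(a) for a in action), dims))


dims = action_dims(num_blocks=25, num_rows=11, num_cols=11)
action = flat_to_action(1234, dims)
assert action_to_flat(action, dims) == 1234
```

Inside jitted JAX code, `jnp.unravel_index` plays the same role as `np.unravel_index`.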

observation_spec: specs.Spec[Observation] (cached property)

Returns the observation spec of the environment.

Returns:

| Type | Description |
|---|---|
| Spec[Observation] | Spec for each field in the observation. |

  • grid: BoundedArray (int) of shape (num_rows, num_cols).
  • blocks: BoundedArray (int) of shape (num_blocks, 3, 3).
  • action_mask: BoundedArray (bool) of shape (num_blocks, 4, num_rows-2, num_cols-2).

animate(states, interval=200, save_path=None)

Create an animation from a sequence of states.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| states | Sequence[State] | Sequence of State objects corresponding to subsequent timesteps. | required |
| interval | int | Delay between frames in milliseconds. | 200 |
| save_path | Optional[str] | Path where the animation file should be saved. If None, the animation is not saved. | None |

Returns:

| Type | Description |
|---|---|
| FuncAnimation | Animation that can be exported to GIF or MP4, or rendered as HTML. |

Source code in jumanji/environments/packing/flat_pack/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Create an animation from a sequence of states.

    Args:
        states: sequence of `State` corresponding to subsequent timesteps.
        interval: delay between frames in milliseconds, defaults to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation that can export to gif, mp4, or render with HTML.
    """

    return self.viewer.animate(states, interval, save_path)

close()

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/packing/flat_pack/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically `close()` themselves when
    garbage collected or when the program exits.
    """

    self.viewer.close()

render(state)

Render a given state of the environment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the current environment state. | required |
Source code in jumanji/environments/packing/flat_pack/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Render a given state of the environment.

    Args:
        state: `State` object containing the current environment state.
    """

    return self.viewer.render(state)

reset(key)

Resets the environment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKey | PRNG key for generating a new instance. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[State, TimeStep[Observation]] | A tuple of the initial environment state and a time step. |

Source code in jumanji/environments/packing/flat_pack/env.py
def reset(
    self,
    key: chex.PRNGKey,
) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: PRNG key for generating a new instance.

    Returns:
        a tuple of the initial environment state and a time step.
    """

    grid_state = self.generator(key)

    obs = self._observation_from_state(grid_state)
    timestep = restart(observation=obs)

    return grid_state, timestep

step(state, action)

Steps the environment.

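Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | Current state of the environment. | required |
| action | Array | Action to take: an int32 array of shape (4,) holding (block_idx, rotation, row_idx, col_idx). | required |

Returns:

| Type | Description |
|---|---|
| Tuple[State, TimeStep[Observation]] | A tuple of the next environment state and a time step. |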

Source code in jumanji/environments/packing/flat_pack/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Steps the environment.

    Args:
        state: current state of the environment.
        action: action to take.

    Returns:
        a tuple of the next environment state and a time step.
    """

    # Unpack and use actions
    block_idx, rotation, row_idx, col_idx = action

    chosen_block = state.blocks[block_idx]

    # Rotate chosen block
    chosen_block = rotate_block(chosen_block, rotation)

    grid_block = self._expand_block_to_grid(chosen_block, row_idx, col_idx)

    action_is_legal = state.action_mask[block_idx, rotation, row_idx, col_idx]

    # If the action is legal create a new grid and update the placed blocks array
    new_grid = jax.lax.cond(
        action_is_legal,
        lambda: state.grid + grid_block,
        lambda: state.grid,
    )
    placed_blocks = jax.lax.cond(
        action_is_legal,
        lambda: state.placed_blocks.at[block_idx].set(True),
        lambda: state.placed_blocks,
    )

    new_action_mask = self._make_action_mask(new_grid, state.blocks, placed_blocks)

    next_state = State(
        grid=new_grid,
        blocks=state.blocks,
        action_mask=new_action_mask,
        num_blocks=state.num_blocks,
        key=state.key,
        step_count=state.step_count + 1,
        placed_blocks=placed_blocks,
    )

    done = self._is_done(next_state)
    next_obs = self._observation_from_state(next_state)
    reward = self.reward_fn(state, grid_block, next_state, action_is_legal, done)

    timestep = jax.lax.cond(
        done,
        termination,
        transition,
        reward,
        next_obs,
    )

    return next_state, timestep
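The core of `step` rotates the chosen block, places it on an otherwise empty grid, adds it to the current grid only if the action mask allows it, and counts the step either way. The NumPy sketch below mirrors that control flow for a single, unbatched state. It assumes `rotate_block` applies one 90-degree counter-clockwise `np.rot90` turn per rotation and that `_expand_block_to_grid` writes the 3x3 block with its top-left corner at (row_idx, col_idx); neither helper's body is shown on this page, and the functions below are hypothetical stand-ins. The `done` rule follows the termination conditions listed at the top of the page, with `>=` as an assumed comparison; the new action mask and the reward are omitted.

```python
import numpy as np


def rotate_block(block: np.ndarray, rotation: int) -> np.ndarray:
    # Assumption: each rotation step is a 90-degree counter-clockwise turn.
    return np.rot90(block, k=rotation)


def expand_block_to_grid(block: np.ndarray, row_idx: int, col_idx: int,
                         num_rows: int, num_cols: int) -> np.ndarray:
    # Assumption: (row_idx, col_idx) is the top-left corner of the 3x3 block.
    grid_block = np.zeros((num_rows, num_cols), dtype=block.dtype)
    grid_block[row_idx:row_idx + 3, col_idx:col_idx + 3] = block
    return grid_block


def step_sketch(grid, blocks, placed_blocks, action_mask, step_count, action):
    """Unbatched NumPy mirror of the placement logic in step (no reward, no new mask)."""
    block_idx, rotation, row_idx, col_idx = (int(a) for a in action)
    chosen = rotate_block(blocks[block_idx], rotation)
    grid_block = expand_block_to_grid(chosen, row_idx, col_idx, *grid.shape)
    legal = bool(action_mask[block_idx, rotation, row_idx, col_idx])

    # Illegal actions leave grid and placed_blocks unchanged (the two lax.cond calls).
    new_grid = grid + grid_block if legal else grid
    new_placed = placed_blocks.copy()
    if legal:
        new_placed[block_idx] = True

    # The step counter advances whether or not the action was legal.
    new_step_count = step_count + 1
    # Termination: every block placed, or num_blocks steps taken (">=" is an assumption).
    done = bool(new_placed.all()) or new_step_count >= blocks.shape[0]
    return new_grid, new_placed, new_step_count, done
```

The library version expresses the two branches with `jax.lax.cond` so that the whole step stays traceable under `jax.jit`; the Python `if` above is only valid outside JAX tracing.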