
Tetris

Bases: Environment[State, MultiDiscreteArray, Observation]

RL environment for the game of Tetris, played on a grid where the player places tetrominoes. The environment has the following characteristics:

  • observation: Observation
    • grid: jax array (int32) of shape (num_rows, num_cols) representing the current state of the grid.
    • tetromino: jax array (int32) of shape (4, 4) representing the current tetromino sampled from the tetromino list.
    • action_mask: jax array (bool) of shape (4, num_cols). For each tetromino there are 4 rotations, each one corresponds to a line in the action_mask. Mask of the joint action space: True if the action (x_position and rotation degree) is feasible for the current tetromino and grid state.
  • action: multi discrete array of shape (2,)

    • rotation_index: The degree index determines the rotation of the tetromino: 0 corresponds to 0 degrees, 1 corresponds to 90 degrees, 2 corresponds to 180 degrees, and 3 corresponds to 270 degrees.
    • x_position: int between 0 and num_cols - 1 (included).
  • reward: The reward is 0 if no lines were cleared by the action, and a convex function of the number of cleared lines otherwise.

  • episode termination: if the tetromino cannot be placed anymore (i.e., it hits the top of the grid).
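To make the convex reward concrete, here is a minimal sketch with a hypothetical reward table. The actual values are defined by `REWARD_LIST` inside the library; the numbers below are an assumption for illustration only:

```python
import jax.numpy as jnp

# Hypothetical convex reward table, indexed by the number of cleared lines
# (0 to 4). These values are NOT taken from jumanji's REWARD_LIST; they are
# an assumption chosen to illustrate the convexity property.
HYPOTHETICAL_REWARD_LIST = jnp.array([0.0, 40.0, 100.0, 300.0, 1200.0])

def line_clear_reward(num_cleared_lines: int) -> jnp.ndarray:
    # 0 cleared lines -> 0 reward; otherwise a super-linear (convex) payoff,
    # so clearing 4 lines at once pays more than 4 separate single clears.
    return HYPOTHETICAL_REWARD_LIST[num_cleared_lines]
```

Convexity is what makes multi-line clears (e.g. a "Tetris") worth more than the equivalent number of single-line clears.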

import jax

from jumanji.environments import Tetris
env = Tetris()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
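Instead of `generate_value()`, the `action_mask` from the observation can be used to sample only feasible actions. A minimal sketch of masked sampling over the joint (rotation, x_position) space; the helper below is not part of Jumanji:

```python
import jax
import jax.numpy as jnp

def sample_valid_action(key: jax.Array, action_mask: jax.Array) -> jax.Array:
    """Sample a (rotation_index, x_position) pair uniformly among feasible actions."""
    num_rotations, num_cols = action_mask.shape
    flat_mask = action_mask.reshape(-1)
    # Uniform probabilities over the True entries of the mask.
    probs = flat_mask / flat_mask.sum()
    flat_index = jax.random.choice(key, num_rotations * num_cols, p=probs)
    # Unflatten back to the (rotation_index, x_position) pair.
    return jnp.array([flat_index // num_cols, flat_index % num_cols])
```

Because the mask has shape `(4, num_cols)`, flattening and unflattening with `num_cols` recovers the two sub-actions of the multi-discrete space.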

Instantiates a Tetris environment.

Parameters:

  • num_rows (int): number of rows of the 2D grid. Defaults to 10.
  • num_cols (int): number of columns of the 2D grid. Defaults to 10.
  • time_limit (int): number of environment steps before the episode ends. Defaults to 400.
  • viewer (Optional[Viewer[State]]): Viewer used for rendering. Defaults to None, in which case a TetrisViewer is used.
Source code in jumanji/environments/packing/tetris/env.py
def __init__(
    self,
    num_rows: int = 10,
    num_cols: int = 10,
    time_limit: int = 400,
    viewer: Optional[Viewer[State]] = None,
) -> None:
    """Instantiates a `Tetris` environment.

    Args:
        num_rows: number of rows of the 2D grid. Defaults to 10.
        num_cols: number of columns of the 2D grid. Defaults to 10.
        time_limit: time_limit of an episode, i.e. number of environment steps before
            the episode ends. Defaults to 400.
        viewer: `Viewer` used for rendering. Defaults to `TetrisViewer`.
    """
    if num_rows < 4:
        raise ValueError(f"The `num_rows` must be >= 4, but got num_rows={num_rows}")
    if num_cols < 4:
        raise ValueError(f"The `num_cols` must be >= 4, but got num_cols={num_cols}")
    self.num_rows = num_rows
    self.num_cols = num_cols
    self.padded_num_rows = num_rows + 3
    self.padded_num_cols = num_cols + 3
    self.TETROMINOES_LIST = jnp.array(TETROMINOES_LIST, jnp.int32)
    self.reward_list = jnp.array(REWARD_LIST, float)
    self.time_limit = time_limit
    super().__init__()
    self._viewer = viewer or TetrisViewer(
        num_rows=self.num_rows,
        num_cols=self.num_cols,
    )

action_spec: specs.MultiDiscreteArray cached property #

Returns the action spec. An action consists of two pieces of information: the amount of rotation (number of 90-degree rotations) and the x-position of the leftmost part of the tetromino.

Returns:

  • MultiDiscreteArray: the action spec, a specs.MultiDiscreteArray object.
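For intuition, the bounds that such a spec encodes can be reproduced with plain arrays. A sketch assuming 4 rotations and a 10-column grid; the real spec object is constructed by the library:

```python
import jax.numpy as jnp

# Assumed bounds of the joint action space: [num_rotations, num_cols].
# 4 rotations and 10 columns mirror the defaults described above.
NUM_VALUES = jnp.array([4, 10], jnp.int32)

def is_within_spec(action: jnp.ndarray) -> jnp.ndarray:
    # A multi-discrete action is valid when 0 <= action[i] < NUM_VALUES[i]
    # for every sub-action.
    return jnp.all((action >= 0) & (action < NUM_VALUES))
```

Note that being within the spec's bounds is necessary but not sufficient: feasibility for the current grid and tetromino is determined by `action_mask`.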

observation_spec: specs.Spec[Observation] cached property #

Specifications of the observation of the Tetris environment.

Returns:

  • Spec[Observation]: spec containing the specifications for all the Observation fields:
    • grid: BoundedArray (jnp.int32) of shape (num_rows, num_cols).
    • tetromino: BoundedArray (bool) of shape (4, 4).
    • action_mask: BoundedArray (bool) of shape (NUM_ROTATIONS, num_cols).
    • step_count: DiscreteArray (num_values = time_limit) of shape ().

animate(states, interval=100, save_path=None) #

Create an animation from a sequence of states.

Parameters:

  • states: sequence of State corresponding to subsequent timesteps.
  • interval: delay between frames in milliseconds. Defaults to 100.
  • save_path: path where the animation file should be saved. If None, the animation is not saved.

Returns:

  • animation that can be exported to gif or mp4, or rendered with HTML.

Source code in jumanji/environments/packing/tetris/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 100,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Create an animation from a sequence of states.
    Args:
        states: sequence of `State` corresponding to subsequent timesteps.
        interval: delay between frames in milliseconds, default to 100.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.
    Returns:
        animation that can export to gif, mp4, or render with HTML.
    """

    return self._viewer.animate(states, interval, save_path)
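Under the hood the viewer returns a `matplotlib.animation.FuncAnimation`. A self-contained sketch of the same idea, animating a sequence of toy grids rather than real `State` objects; this is not the `TetrisViewer` implementation:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np

# Stand-in for grids extracted from subsequent environment states.
grids = [np.random.randint(0, 2, size=(10, 10)) for _ in range(5)]

fig, ax = plt.subplots()
image = ax.imshow(grids[0])

def update(frame: int):
    # Swap in the grid for this frame; blitting only needs the changed artists.
    image.set_data(grids[frame])
    return (image,)

anim = animation.FuncAnimation(fig, update, frames=len(grids), interval=100)
# anim.save("tetris.gif")  # optional; requires the pillow writer
```

The `interval=100` argument maps directly onto the `interval` parameter of `animate` above.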

render(state) #

Render the given state of the environment.

Parameters:

  • state: State object containing the current environment state.

Source code in jumanji/environments/packing/tetris/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Render the given state of the environment.
    Args:
        state: `State` object containing the current environment state.
    """
    return self._viewer.render(state)

reset(key) #

Resets the environment.

Parameters:

  • key (PRNGKey): needed for generating new tetrominoes. Required.

Returns:

  • state (State): State corresponding to the new state of the environment.
  • timestep (TimeStep[Observation]): TimeStep corresponding to the first timestep returned by the environment.

Source code in jumanji/environments/packing/tetris/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: needed for generating new tetrominoes.

    Returns:
        state: `State` corresponding to the new state of the environment,
        timestep: `TimeStep` corresponding to the first timestep returned by the
            environment.
    """
    grid_padded = jnp.zeros(shape=(self.padded_num_rows, self.padded_num_cols), dtype=jnp.int32)
    tetromino, tetromino_index = utils.sample_tetromino_list(key, self.TETROMINOES_LIST)

    action_mask = self._calculate_action_mask(grid_padded, tetromino_index)
    state = State(
        grid_padded=grid_padded,
        grid_padded_old=grid_padded,
        tetromino_index=tetromino_index,
        old_tetromino_rotated=tetromino,
        new_tetromino=tetromino,
        x_position=jnp.array(0, jnp.int32),
        y_position=jnp.array(0, jnp.int32),
        action_mask=action_mask,
        full_lines=jnp.full((self.num_rows + 3), False),
        score=jnp.array(0, float),
        reward=jnp.array(0, float),
        key=key,
        is_reset=True,
        step_count=jnp.array(0, jnp.int32),
    )

    observation = Observation(
        grid=grid_padded[: self.num_rows, : self.num_cols],
        tetromino=tetromino,
        action_mask=action_mask,
        step_count=jnp.array(0, jnp.int32),
    )
    timestep = restart(observation=observation)
    return state, timestep

step(state, action) #

Run one timestep of the environment's dynamics.

Parameters:

  • state (State): State object containing the dynamics of the environment. Required.
  • action (chex.Array): array containing the rotation_index and x_position of the tetromino. Required.

Returns:

  • next_state (State): State corresponding to the next state of the environment.
  • next_timestep (TimeStep[Observation]): TimeStep corresponding to the timestep returned by the environment.

Source code in jumanji/environments/packing/tetris/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    Args:
        state: `State` object containing the dynamics of the environment.
        action: `chex.Array` containing the rotation_index and x_position of the tetromino.

    Returns:
        next_state: `State` corresponding to the next state of the environment,
        next_timestep: `TimeStep` corresponding to the timestep returned by the environment.
    """
    rotation_index, x_position = action
    tetromino_index = state.tetromino_index
    key, sample_key = jax.random.split(state.key)
    tetromino = self._rotate(rotation_index, tetromino_index)
    # Place the tetromino in the selected place
    grid_padded, y_position = utils.place_tetromino(state.grid_padded, tetromino, x_position)
    # A line is full when it doesn't contain any 0.
    full_lines = jnp.all(grid_padded[:, : self.num_cols] != 0, axis=1)
    nbr_full_lines = sum(full_lines)
    grid_padded = utils.clean_lines(grid_padded, full_lines)
    # Generate new tetromino
    new_tetromino, tetromino_index = utils.sample_tetromino_list(
        sample_key, self.TETROMINOES_LIST
    )
    grid_padded_cliped = jnp.clip(grid_padded, a_max=1)
    action_mask = self._calculate_action_mask(grid_padded_cliped, tetromino_index)
    # The maximum should be bigger than 0.
    # In case the grid is empty the color should be set 0.
    color = jnp.array([1, grid_padded.max()])
    colored_tetromino = tetromino * jnp.max(color)
    is_valid = state.action_mask[tuple(action)]
    reward = self.reward_list[nbr_full_lines] * is_valid
    step_count = state.step_count + 1
    next_state = State(
        grid_padded=grid_padded,
        grid_padded_old=state.grid_padded,
        tetromino_index=tetromino_index,
        old_tetromino_rotated=colored_tetromino,
        new_tetromino=new_tetromino,
        x_position=x_position,
        y_position=y_position,
        action_mask=action_mask,
        full_lines=full_lines,
        score=state.score + reward,
        reward=reward,
        key=key,
        is_reset=False,
        step_count=step_count,
    )
    next_observation = Observation(
        grid=grid_padded_cliped[: self.num_rows, : self.num_cols],
        tetromino=new_tetromino,
        action_mask=action_mask,
        step_count=jnp.array(0, jnp.int32),
    )

    tetris_completed = ~jnp.any(action_mask)
    done = tetris_completed | ~is_valid | (step_count >= self.time_limit)

    next_timestep = jax.lax.cond(
        done,
        termination,
        transition,
        reward,
        next_observation,
    )
    return next_state, next_timestep
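The branch at the end of `step` selects between a terminal and a mid-episode timestep with `jax.lax.cond`. The same pattern in isolation, with a toy dict standing in for Jumanji's `TimeStep` (a sketch, not the library's `termination`/`transition` factories):

```python
import jax
import jax.numpy as jnp

def make_timestep(done: jnp.ndarray, reward: jnp.ndarray) -> dict:
    # Both branches must return the same pytree structure; they differ only
    # in the discount: 0 for a terminal timestep, 1 for a transition.
    return jax.lax.cond(
        done,
        lambda r: {"reward": r, "discount": jnp.array(0.0)},  # termination
        lambda r: {"reward": r, "discount": jnp.array(1.0)},  # transition
        reward,
    )
```

Using `lax.cond` rather than a Python `if` keeps the branch traceable, which is what allows `env.step` to be wrapped in `jax.jit` as in the usage example at the top of this page.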