
Game2048

Bases: Environment[State, DiscreteArray, Observation]

Environment for the game 2048. The game is played on a board of size board_size x board_size (4x4 by default) on which the player can move the tiles up, down, left, or right. The goal is to repeatedly combine tiles of the same value to create a tile with twice that value; the game is considered won once a tile with the value 2048 is created.

  • observation: Observation

    • board: jax array (int32) of shape (board_size, board_size), the current state of the board. An empty tile is represented by zero, whereas a non-empty tile stores the exponent of 2, e.g. 1, 2, 3, 4, ... (corresponding to tile values 2, 4, 8, 16, ...); see the snippet after this list.
    • action_mask: jax array (bool) of shape (4,) indicates which actions are valid in the current state of the environment.
  • action: jax array (int32) of shape (). Takes values in [0, 1, 2, 3], representing the actions up, right, down, and left, respectively.

  • reward: jax array (float) of shape (). The reward is 0 except when the player combines tiles to create a new tile with twice the value. In this case, the reward is the value of the new tile.

  • episode termination:

    • if no more valid moves exist (this can happen when the board is full).
  • state: State

    • board: same as observation.
    • step_count: jax array (int32) of shape (), the number of time steps in the episode so far.
    • action_mask: same as observation.
    • score: jax array (float) of shape (), the cumulative reward accrued in the episode so far (initialized to 0 in reset and incremented by the reward in step).
    • key: jax array (uint32) of shape (2,) random key used to generate random numbers at each step and for auto-reset.
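
Because tiles are stored as exponents, recovering the face values shown in the classic game is a one-liner. A minimal sketch, using a hypothetical hand-written board rather than one produced by the environment:

import jax.numpy as jnp

# Exponent encoding: 0 marks an empty tile, k marks a tile of value 2**k.
board = jnp.array(
    [[0, 1, 0, 2],
     [0, 0, 1, 0],
     [0, 0, 0, 3],
     [0, 0, 0, 0]],
    dtype=jnp.int32,
)

# Convert exponents to face values, keeping empty tiles at 0.
tile_values = jnp.where(board > 0, 2**board, 0)
# [[0 2 0 4]
#  [0 0 2 0]
#  [0 0 0 8]
#  [0 0 0 0]]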
import jax

from jumanji.environments import Game2048
env = Game2048()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
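
The action_mask in the observation can be used to sample among valid moves only. A short sketch of a random-policy loop continuing from the example above (step_fn and the masked categorical sampling are illustrative, not part of the API):

import jax.numpy as jnp

step_fn = jax.jit(env.step)
for _ in range(10):
    key, action_key = jax.random.split(key)
    # Give -inf logit to invalid actions so they are never sampled.
    logits = jnp.where(timestep.observation.action_mask, 0.0, -jnp.inf)
    action = jax.random.categorical(action_key, logits)
    state, timestep = step_fn(state, action)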

Initialize the 2048 game.

Parameters:

• board_size (int): size of the board. Defaults to 4.
• viewer (Optional[Viewer[State]]): Viewer used for rendering. Defaults to Game2048Viewer.

Source code in jumanji/environments/logic/game_2048/env.py
def __init__(self, board_size: int = 4, viewer: Optional[Viewer[State]] = None) -> None:
    """Initialize the 2048 game.

    Args:
        board_size: size of the board. Defaults to 4.
        viewer: `Viewer` used for rendering. Defaults to `Game2048Viewer`.
    """
    self.board_size = board_size
    super().__init__()

    # Create viewer used for rendering
    self._viewer = viewer or Game2048Viewer("2048", board_size)
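
For example, a larger variant can be requested at construction time; this is a sketch, the 4x4 default matches the classic game:

from jumanji.environments import Game2048

env = Game2048(board_size=5)  # 5x5 board instead of the default 4x4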

action_spec: specs.DiscreteArray cached property

Returns the action spec.

4 actions: [0, 1, 2, 3] -> [Up, Right, Down, Left].

Returns:

• action_spec (DiscreteArray): DiscreteArray spec object.
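
A brief sketch of querying the spec; num_values and generate_value follow the standard Jumanji DiscreteArray interface:

spec = env.action_spec
assert spec.num_values == 4  # Up, Right, Down, Left
action = spec.generate_value()  # a default valid action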

observation_spec: specs.Spec[Observation] cached property

Specifications of the observation of the Game2048 environment.

Returns:

• Spec[Observation]: Spec containing all the specifications for all the Observation fields:
    • board: Array (jnp.int32) of shape (board_size, board_size).
    • action_mask: BoundedArray (bool) of shape (4,).
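
As a sketch, the observation spec can be used to check an observation returned by the environment; validate is assumed to follow the usual Jumanji spec interface:

obs_spec = env.observation_spec
# Raises an exception if shapes or dtypes do not match the spec.
obs_spec.validate(timestep.observation)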

__repr__()

String representation of the environment.

Returns:

• str: the string representation of the environment.

Source code in jumanji/environments/logic/game_2048/env.py
def __repr__(self) -> str:
    """String representation of the environment.

    Returns:
        str: the string representation of the environment.
    """
    return f"2048 Game(board_size={self.board_size})"

animate(states, interval=200, save_path=None)

Creates an animated gif of the 2048 game board based on the sequence of game states.

Parameters:

• states (Sequence[State]): a list of State objects representing the sequence of game states. Required.
• interval (int): the delay between frames in milliseconds. Defaults to 200.
• save_path (Optional[str]): the path where the animation file should be saved. If it is None, the plot will not be stored. Defaults to None.

Returns:

• animation.FuncAnimation: the animation object that was created.

Source code in jumanji/environments/logic/game_2048/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> animation.FuncAnimation:
    """Creates an animated gif of the 2048 game board based on the sequence of game states.

    Args:
        states: a list of `State` objects representing the sequence of game states.
        interval: the delay between frames in milliseconds. Defaults to 200.
        save_path: the path where the animation file should be saved. If it
            is None, the plot will not be stored.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states=states, interval=interval, save_path=save_path)
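
A sketch of collecting states during a short random rollout and saving them as a gif; the file name game2048.gif is arbitrary, and the variables continue from the usage example at the top:

step_fn = jax.jit(env.step)
states = [state]
for _ in range(20):
    key, action_key = jax.random.split(key)
    action = jax.random.choice(action_key, jnp.arange(4))
    state, timestep = step_fn(state, action)
    states.append(state)

env.animate(states, interval=200, save_path="game2048.gif")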

close()

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/logic/game_2048/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state)

Renders the current state of the game board.

Parameters:

• state (State): the current game state to be rendered. Required.
Source code in jumanji/environments/logic/game_2048/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Renders the current state of the game board.

    Args:
        state: the current game state to be rendered.
    """
    return self._viewer.render(state=state)

reset(key)

Resets the environment.

Parameters:

• key (PRNGKey): random number generator key. Required.

Returns:

• state (State): the new state of the environment.
• timestep (TimeStep[Observation]): the first timestep returned by the environment.

Source code in jumanji/environments/logic/game_2048/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: random number generator key.

    Returns:
        state: the new state of the environment.
        timestep: the first timestep returned by the environment.
    """

    key, board_key = jax.random.split(key)
    board = self._generate_board(board_key)
    action_mask = self._get_action_mask(board)

    obs = Observation(board=board, action_mask=action_mask)

    state = State(
        board=board,
        step_count=jnp.array(0, jnp.int32),
        action_mask=action_mask,
        key=key,
        score=jnp.array(0, float),
    )

    highest_tile = 2 ** jnp.max(board)
    timestep = restart(observation=obs, extras={"highest_tile": highest_tile})

    return state, timestep
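
Since reset is a pure function of the key, it can be vectorized with jax.vmap to create a batch of independent boards; a minimal sketch:

keys = jax.random.split(jax.random.PRNGKey(0), 8)
states, timesteps = jax.vmap(env.reset)(keys)
# states.board has shape (8, board_size, board_size)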

step(state, action)

Updates the environment state after the agent takes an action.

Parameters:

• state (State): the current state of the environment. Required.
• action (Array): the action taken by the agent. Required.

Returns:

• state (State): the new state of the environment.
• timestep (TimeStep[Observation]): the next timestep.

Source code in jumanji/environments/logic/game_2048/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Updates the environment state after the agent takes an action.

    Args:
        state: the current state of the environment.
        action: the action taken by the agent.

    Returns:
        state: the new state of the environment.
        timestep: the next timestep.
    """
    # Take the action in the environment: Up, Right, Down, Left.
    updated_board, reward = move(state.board, action)

    # Generate new key.
    random_cell_key, new_state_key = jax.random.split(state.key)

    # Update the state of the board by adding a new random cell.
    updated_board = jax.lax.cond(
        state.action_mask[action],
        self._add_random_cell,
        lambda board, key: board,
        updated_board,
        random_cell_key,
    )

    # Generate action mask to keep in the state for the next step and
    # to provide to the agent in the observation.
    action_mask = self._get_action_mask(board=updated_board)

    # Build the state.
    state = State(
        board=updated_board,
        action_mask=action_mask,
        step_count=state.step_count + 1,
        key=new_state_key,
        score=state.score + reward,
    )

    # Generate the observation from the environment state.
    observation = Observation(
        board=updated_board,
        action_mask=action_mask,
    )

    # Check if the episode terminates (i.e. there are no legal actions).
    done = ~jnp.any(action_mask)

    # Return either a MID or a LAST timestep depending on done.
    highest_tile = 2 ** jnp.max(updated_board)
    extras = {"highest_tile": highest_tile}
    timestep = jax.lax.cond(
        done,
        lambda: termination(
            reward=reward,
            observation=observation,
            extras=extras,
        ),
        lambda: transition(
            reward=reward,
            observation=observation,
            extras=extras,
        ),
    )

    return state, timestep
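
Because step is functionally pure, an entire rollout can be compiled with jax.lax.scan. A sketch with a trivial fixed policy; this ignores episode termination, which an auto-reset wrapper would normally handle:

import jax
import jax.numpy as jnp

def rollout_step(state, _):
    # Always move Up; an invalid move leaves the board unchanged (assumption
    # based on the action-mask condition in step above).
    action = jnp.array(0, jnp.int32)
    state, timestep = env.step(state, action)
    return state, timestep.reward

state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
final_state, rewards = jax.lax.scan(rollout_step, state, xs=None, length=100)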