Minesweeper

Bases: Environment[State, MultiDiscreteArray, Observation]

A JAX implementation of the minesweeper game.

  • observation: Observation

    • board: jax array (int32) of shape (num_rows, num_cols): each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares.
    • action_mask: jax array (bool) of shape (num_rows, num_cols): indicates which actions are valid (not yet explored squares).
    • num_mines: jax array (int32) of shape (), indicates the number of mines to locate.
    • step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset.
  • action: multi discrete array containing the square to explore (row and col).

  • reward: jax array (float32): Configurable function of state and action. By default: 1 for every timestep where a valid action is chosen that doesn't reveal a mine, and 0 for revealing a mine or selecting an already revealed square (both of which terminate the episode).

  • episode termination: Configurable function of state, next_state, and action. By default: Stop the episode if a mine is explored, an invalid action is selected (exploring an already explored square), or the board is solved.

  • state: State

    • board: jax array (int32) of shape (num_rows, num_cols): each cell contains -1 if not yet explored, or otherwise the number of mines in the 8 adjacent squares.
    • step_count: jax array (int32) of shape (): specifies how many timesteps have elapsed since environment reset.
    • flat_mine_locations: jax array (int32) of shape (num_rows * num_cols,): indicates the (flat) locations of all the mines on the board. Will be of length num_mines.
    • key: jax array (uint32) of shape (2,) used for seeding the sampling of mine placement on reset.
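
The encoding described above can be illustrated without JAX. The following is a minimal plain-Python sketch, not the library's code: the names `UNEXPLORED`, `action_mask_from_board` and `flat_to_row_col` are hypothetical, and the row-major layout of `flat_mine_locations` (flat index = row * num_cols + col) is an assumption.

```python
UNEXPLORED = -1  # value the docs give for a cell that has not been explored yet


def action_mask_from_board(board):
    """Return True for every cell that is still unexplored (a valid action)."""
    return [[cell == UNEXPLORED for cell in row] for row in board]


def flat_to_row_col(flat_index, num_cols):
    """Convert a flat location to (row, col), assuming row-major order."""
    return divmod(flat_index, num_cols)


# A 2x3 board where only the top-left square has been explored (1 adjacent mine).
board = [
    [1, UNEXPLORED, UNEXPLORED],
    [UNEXPLORED, UNEXPLORED, UNEXPLORED],
]
mask = action_mask_from_board(board)

# A single mine stored at flat index 4 sits at row 1, column 1 on a 3-column board.
mine_cells = [flat_to_row_col(i, num_cols=3) for i in [4]]
```

Here `mask` is False only for the explored top-left square, which matches the description of `action_mask` as marking the squares that can still be explored.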
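import jax

from jumanji.environments import Minesweeper

# Create the environment with its default 10x10 board and 10 mines.
env = Minesweeper()
key = jax.random.PRNGKey(0)

# Reset to get the initial state and the first timestep.
state, timestep = jax.jit(env.reset)(key)
env.render(state)

# Generate a placeholder action from the action spec and take one step.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)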

Instantiate a Minesweeper environment.

Parameters:

  • generator (Optional[Generator], default: None): Generator to generate problem instances on environment reset. Implemented options are [UniformSamplingGenerator]. Defaults to UniformSamplingGenerator. The generator will have attributes:
    • num_rows: number of rows, i.e. height of the board. Defaults to 10.
    • num_cols: number of columns, i.e. width of the board. Defaults to 10.
    • num_mines: number of mines generated. Defaults to 10.
  • reward_function (Optional[RewardFn], default: None): RewardFn whose __call__ method computes the reward of an environment transition based on the given current state and selected action. Implemented options are [DefaultRewardFn]. Defaults to DefaultRewardFn, giving a reward of 1.0 for revealing an empty square, 0.0 for revealing a mine, and 0.0 for an invalid action (selecting an already revealed square).
  • done_function (Optional[DoneFn], default: None): DoneFn whose __call__ method computes the done signal given the current state, action taken, and next state. Implemented options are [DefaultDoneFn]. Defaults to DefaultDoneFn, ending the episode on solving the board, revealing a mine, or picking an invalid action.
  • viewer (Optional[Viewer[State]], default: None): Viewer to support rendering and animation methods. Implemented options are [MinesweeperViewer]. Defaults to MinesweeperViewer.

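
The default reward and termination rules described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's `DefaultRewardFn` or `DefaultDoneFn`: the functions `default_reward` and `default_done` are hypothetical, mines are passed as a set of (row, col) pairs, and "solved" is assumed to mean that the only unexplored squares left are the mines. The keyword names mirror the arguments passed to `DefaultRewardFn` in the constructor.

```python
def default_reward(
    board,
    mine_cells,
    action,
    revealed_empty_square_reward=1.0,
    revealed_mine_reward=0.0,
    invalid_action_reward=0.0,
):
    """Reward for exploring `action` = (row, col); -1 on the board means unexplored."""
    row, col = action
    if board[row][col] != -1:
        return invalid_action_reward  # square was already explored
    if (row, col) in mine_cells:
        return revealed_mine_reward
    return revealed_empty_square_reward


def default_done(board, mine_cells, action):
    """Episode ends on an invalid action, a revealed mine, or a solved board."""
    row, col = action
    invalid = board[row][col] != -1
    hit_mine = (row, col) in mine_cells
    # Unexplored squares left once a valid move has revealed the chosen square.
    unexplored = sum(cell == -1 for r in board for cell in r)
    remaining = unexplored - (0 if invalid else 1)
    solved = remaining == len(mine_cells)  # assumption: only mines stay hidden
    return invalid or hit_mine or solved


# 2x2 board with one mine at (0, 0); square (0, 1) is already explored.
mines = {(0, 0)}
board = [[-1, 1], [-1, -1]]
safe_reward = default_reward(board, mines, (1, 0))
done_after_safe_move = default_done(board, mines, (1, 0))
done_after_mine = default_done(board, mines, (0, 0))
```

Passing different values for the three reward keywords shows how a custom reward scheme (for example, a negative reward for invalid actions) would change the signal without touching the termination rule.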
Source code in jumanji/environments/logic/minesweeper/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    reward_function: Optional[RewardFn] = None,
    done_function: Optional[DoneFn] = None,
    viewer: Optional[Viewer[State]] = None,
):
    """Instantiate a `Minesweeper` environment.

    Args:
        generator: `Generator` to generate problem instances on environment reset.
            Implemented options are [`UniformSamplingGenerator`]. Defaults to
            `UniformSamplingGenerator`.
            The generator will have attributes:
                - num_rows: number of rows, i.e. height of the board. Defaults to 10.
                - num_cols: number of columns, i.e. width of the board. Defaults to 10.
                - num_mines: number of mines generated. Defaults to 10.
        reward_function: `RewardFn` whose `__call__` method computes the reward of an
            environment transition based on the given current state and selected action.
            Implemented options are [`DefaultRewardFn`]. Defaults to `DefaultRewardFn`, giving
            a reward of 1.0 for revealing an empty square, 0.0 for revealing a mine, and
            0.0 for an invalid action (selecting an already revealed square).
        done_function: `DoneFn` whose `__call__` method computes the done signal given the
            current state, action taken, and next state.
            Implemented options are [`DefaultDoneFn`]. Defaults to `DefaultDoneFn`, ending the
            episode on solving the board, revealing a mine, or picking an invalid action.
        viewer: `Viewer` to support rendering and animation methods.
            Implemented options are [`MinesweeperViewer`]. Defaults to `MinesweeperViewer`.
    """
    self.reward_function = reward_function or DefaultRewardFn(
        revealed_empty_square_reward=1.0,
        revealed_mine_reward=0.0,
        invalid_action_reward=0.0,
    )
    self.done_function = done_function or DefaultDoneFn()
    self.generator = generator or UniformSamplingGenerator(
        num_rows=10, num_cols=10, num_mines=10
    )
    self.num_rows = self.generator.num_rows
    self.num_cols = self.generator.num_cols
    self.num_mines = self.generator.num_mines
    super().__init__()
    self._viewer = viewer or MinesweeperViewer(num_rows=self.num_rows, num_cols=self.num_cols)

action_spec: specs.MultiDiscreteArray cached property

Returns the action spec. An action consists of the row and column indices of the square to be explored.

Returns:

  • action_spec (MultiDiscreteArray): specs.MultiDiscreteArray object.

observation_spec: specs.Spec[Observation] cached property

Specifications of the observation of the Minesweeper environment.

Returns:

  • Spec[Observation]: Spec for the Observation whose fields are:
    • board: BoundedArray (int32) of shape (num_rows, num_cols).
    • action_mask: BoundedArray (bool) of shape (num_rows, num_cols).
    • num_mines: BoundedArray (int32) of shape ().
    • step_count: BoundedArray (int32) of shape ().

animate(states, interval=200, save_path=None)

Creates an animated gif of the board based on the sequence of states.

Parameters:

  • states (Sequence[State], required): a list of State objects representing the sequence of states.
  • interval (int, default: 200): the delay between frames in milliseconds.
  • save_path (Optional[str], default: None): the path where the animation file should be saved. If it is None, the animation will not be saved.

Returns:

  • matplotlib.animation.FuncAnimation: the animation object that was created.

Source code in jumanji/environments/logic/minesweeper/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the board based on the sequence of states.

    Args:
        states: a list of `State` objects representing the sequence of states.
        interval: the delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states=states, interval=interval, save_path=save_path)

close()

Perform any necessary cleanup. Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/logic/minesweeper/env.py
def close(self) -> None:
    """Perform any necessary cleanup.
    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state)

Renders the current state of the board.

Parameters:

  • state (State, required): the current state to be rendered.

Source code in jumanji/environments/logic/minesweeper/env.py
def render(self, state: State) -> Optional[NDArray]:
    """Renders the current state of the board.

    Args:
        state: the current state to be rendered.
    """
    return self._viewer.render(state=state)

reset(key)

Resets the environment.

Parameters:

  • key (PRNGKey, required): random key used for placing mines.

Returns:

  • state (State): State corresponding to the new state of the environment.
  • timestep (TimeStep[Observation]): TimeStep corresponding to the first timestep returned by the environment.

Source code in jumanji/environments/logic/minesweeper/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: needed for placing mines.

    Returns:
        state: `State` corresponding to the new state of the environment,
        timestep: `TimeStep` corresponding to the first timestep returned by the
            environment.
    """
    state = self.generator(key)
    observation = self._state_to_observation(state=state)
    timestep = restart(observation=observation)
    return state, timestep
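
What the generator produces on reset can be sketched in plain Python. This is a rough illustration, not the library's generator: `sample_initial_state` is a hypothetical helper, Python's `random.Random` stands in for the JAX PRNG key, and it is an assumption that mines are drawn uniformly without replacement and that the initial step count is 0.

```python
import random


def sample_initial_state(seed, num_rows=10, num_cols=10, num_mines=10):
    """Build a fresh, fully unexplored board with uniformly placed mines."""
    rng = random.Random(seed)  # stands in for the JAX PRNG key
    # Draw distinct flat indices so that no two mines share a square.
    flat_mine_locations = rng.sample(range(num_rows * num_cols), num_mines)
    board = [[-1] * num_cols for _ in range(num_rows)]  # -1 marks unexplored
    return {
        "board": board,
        "step_count": 0,
        "flat_mine_locations": flat_mine_locations,
    }


state = sample_initial_state(seed=0)
```

As with a JAX key, the same seed always yields the same mine placement, which is what makes resets reproducible.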

step(state, action)

Run one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Array, required): Array containing the row and column of the square to be explored.

Returns:

  • next_state (State): State corresponding to the next state of the environment.
  • next_timestep (TimeStep[Observation]): TimeStep corresponding to the timestep returned by the environment.

Source code in jumanji/environments/logic/minesweeper/env.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    Args:
        state: `State` object containing the dynamics of the environment.
        action: `Array` containing the row and column of the square to be explored.

    Returns:
        next_state: `State` corresponding to the next state of the environment,
        next_timestep: `TimeStep` corresponding to the timestep returned by the environment.
    """
    board = state.board.at[tuple(action)].set(count_adjacent_mines(state=state, action=action))
    step_count = state.step_count + 1
    next_state = State(
        board=board,
        step_count=step_count,
        key=state.key,
        flat_mine_locations=state.flat_mine_locations,
    )
    reward = self.reward_function(state, action)
    done = self.done_function(state, next_state, action)
    next_observation = self._state_to_observation(state=next_state)
    next_timestep = jax.lax.cond(
        done,
        termination,
        transition,
        reward,
        next_observation,
    )
    return next_state, next_timestep
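
The board update performed in `step` can be mirrored in plain Python to make the logic easier to follow. This is a simplified sketch, not the library's implementation: both functions below are hypothetical stand-ins (the library's own `count_adjacent_mines` takes a `State`), mines are given as a set of (row, col) pairs, and the reward and done computations are left out.

```python
def count_adjacent_mines(mine_cells, action):
    """Number of mines in the (up to) 8 squares surrounding `action`."""
    row, col = action
    return sum(
        (row + d_row, col + d_col) in mine_cells
        for d_row in (-1, 0, 1)
        for d_col in (-1, 0, 1)
        if (d_row, d_col) != (0, 0)  # skip the explored square itself
    )


def step(board, step_count, mine_cells, action):
    """Reveal `action` and return the next board and step count."""
    next_board = [list(row) for row in board]  # copy instead of mutating the input
    row, col = action
    next_board[row][col] = count_adjacent_mines(mine_cells, action)
    return next_board, step_count + 1


# 3x3 board with mines in two opposite corners; explore the centre square.
mines = {(0, 0), (2, 2)}
board = [[-1] * 3 for _ in range(3)]
next_board, next_step_count = step(board, 0, mines, (1, 1))
```

Copying the board before writing to it plays the same role as the `.at[...].set(...)` update in the source: the previous state is left unchanged and a new one is returned.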