Sudoku

Bases: Environment[State, MultiDiscreteArray, Observation]

A JAX implementation of the sudoku game.

  • observation: Observation

    • board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8.
    • action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid.
  • action: multi-discrete array of three integers: the row index, the column index, and the digit to write in that cell.

  • reward: jax array (float32): 1 at the end of the episode if the board is valid, 0 otherwise.

  • state: State

    • board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8.

    • action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid (empty cells and valid digits).

    • key: jax array (int32) of shape (2,) used to seed the initial sudoku configuration.
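The sparse reward described above can be illustrated with a short JAX sketch. It assumes that a "valid" board means a completely filled grid in which every row, column and 3x3 box contains each digit 0-8 exactly once. The helper names `is_solved` and `sparse_reward` are hypothetical and not part of the Jumanji API, and the environment's own reward function may be implemented differently.

```python
import jax.numpy as jnp


def is_solved(board: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: True if each row, column and 3x3 box holds 0-8 exactly once."""
    # One-hot encode the digits; empty cells (-1) match no digit and stay all-zero.
    one_hot = (board[:, :, None] == jnp.arange(9)).astype(jnp.int32)  # (row, col, digit)
    rows_ok = jnp.all(one_hot.sum(axis=1) == 1)  # digit counts per row
    cols_ok = jnp.all(one_hot.sum(axis=0) == 1)  # digit counts per column
    # Split rows and columns into 3 groups of 3, then count digits per box.
    boxes = one_hot.reshape(3, 3, 3, 3, 9).sum(axis=(1, 3))  # (box_row, box_col, digit)
    boxes_ok = jnp.all(boxes == 1)
    return rows_ok & cols_ok & boxes_ok


def sparse_reward(board: jnp.ndarray, done) -> jnp.ndarray:
    """Hypothetical helper: 1.0 only when the episode is over and the board is solved."""
    solved_at_end = jnp.logical_and(done, is_solved(board))
    return jnp.where(solved_at_end, 1.0, 0.0).astype(jnp.float32)
```

Because the check only uses array reductions, it can be wrapped in `jax.jit` or `jax.vmap` like the rest of the environment.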

import jax
from jumanji.environments import Sudoku

env = Sudoku()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# Generate a placeholder action that conforms to the action spec.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
Source code in jumanji/environments/logic/sudoku/env.py
def __init__(
    self,
    generator: Optional[Generator] = None,
    reward_fn: Optional[RewardFn] = None,
    viewer: Optional[Viewer[State]] = None,
):
    super().__init__()
    if generator is None:
        file_path = os.path.dirname(os.path.abspath(__file__))
        database_file = DATABASES["mixed"]
        database = jnp.load(os.path.join(file_path, "data", database_file))

    self._generator = generator or DatabaseGenerator(database=database)
    self._reward_fn = reward_fn or SparseRewardFn()
    self._viewer = viewer or SudokuViewer()

action_spec: specs.MultiDiscreteArray cached property

Returns the action spec. An action is composed of 3 integers: the row index, the column index and the value to be placed in the cell.

Returns:

  • action_spec (MultiDiscreteArray): MultiDiscreteArray object.
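As an illustration of how the three-integer action relates to the action mask, the sketch below samples a valid action uniformly at random. It assumes the mask of shape (9,9,9) is indexed as `[row, column, digit]`, matching the order of the action components described above; `sample_valid_action` is a hypothetical helper, not a Jumanji function.

```python
import jax
import jax.numpy as jnp


def sample_valid_action(key: jnp.ndarray, action_mask: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: sample (row, column, digit) uniformly from the True mask entries."""
    flat_mask = action_mask.reshape(-1).astype(jnp.float32)
    # Assumes at least one entry is True; otherwise the probabilities are undefined.
    probs = flat_mask / flat_mask.sum()
    flat_index = jax.random.choice(key, flat_mask.shape[0], p=probs)
    # Convert the flat index back into (row, column, digit) coordinates.
    row, col, digit = jnp.unravel_index(flat_index, action_mask.shape)
    return jnp.stack([row, col, digit]).astype(jnp.int32)
```

Under these assumptions, the result could be passed to `env.step` in place of the placeholder produced by `generate_value()`, with the mask read from the observation returned by the previous step.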

observation_spec: specs.Spec[Observation] cached property

Returns the observation spec containing the board and action_mask arrays.

Returns:

  • Spec[Observation]: Spec containing all the specifications for all the Observation fields:

    • board: BoundedArray (jnp.int8) of shape (9,9).
    • action_mask: BoundedArray (bool) of shape (9,9,9).
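To make the relationship between the two observation fields concrete, here is a minimal sketch of how an action mask could be derived from a board. It assumes the usual sudoku rule that a digit is allowed in an empty cell only if it does not already appear in the same row, column or 3x3 box, and that the mask is indexed as `[row, column, digit]`. The function name `compute_action_mask` is hypothetical, and the environment may compute its mask differently.

```python
import jax.numpy as jnp


def compute_action_mask(board: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: mask[r, c, d] is True when cell (r, c) is empty and
    digit d is not yet present in row r, column c, or the enclosing 3x3 box."""
    # Empty cells (-1) match no digit, so their one-hot vectors are all False.
    one_hot = board[:, :, None] == jnp.arange(9)  # (row, col, digit)
    in_row = one_hot.any(axis=1)  # (row, digit)
    in_col = one_hot.any(axis=0)  # (col, digit)
    in_box = one_hot.reshape(3, 3, 3, 3, 9).any(axis=(1, 3))  # (box_row, box_col, digit)
    # Expand each box's digit set back onto the nine cells it covers.
    in_box = jnp.repeat(jnp.repeat(in_box, 3, axis=0), 3, axis=1)  # (row, col, digit)
    used = in_row[:, None, :] | in_col[None, :, :] | in_box
    empty = (board == -1)[:, :, None]
    return empty & ~used
```

The result has shape (9,9,9) and dtype bool, which matches the `action_mask` field listed above.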

animate(states, interval=200, save_path=None)

Creates an animated gif of the board based on the sequence of states.

Parameters:

  • states (Sequence[State], required): a list of State objects representing the sequence of states.
  • interval (int, default 200): the delay between frames in milliseconds.
  • save_path (Optional[str], default None): the path where the animation file should be saved. If it is None, the plot will not be saved.

Returns:

  • FuncAnimation: the animation object that was created.

Source code in jumanji/environments/logic/sudoku/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Creates an animated gif of the board based on the sequence of states.

    Args:
        states: a list of `State` objects representing the sequence of states.
        interval: the delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        animation.FuncAnimation: the animation object that was created.
    """
    return self._viewer.animate(states=states, interval=interval, save_path=save_path)

render(state)

Renders the current state of the sudoku.

Parameters:

  • state (State, required): the current state to be rendered.

Source code in jumanji/environments/logic/sudoku/env.py
def render(self, state: State) -> Any:
    """Renders the current state of the sudoku.

    Args:
        state: the current state to be rendered.
    """
    return self._viewer.render(state=state)