Sudoku

Sudoku (Environment)

A JAX implementation of the sudoku game.

  • observation: Observation

    • board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8.
    • action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid.
  • action: multi-discrete array of three integers: the row index and column index of the cell to fill, and the digit to write there.

  • reward: jax array (float32): 1 at the end of the episode if the board is valid, 0 otherwise.

  • state: State

    • board: jax array (int32) of shape (9,9): empty cells are represented by -1, and filled cells are represented by 0-8.

    • action_mask: jax array (bool) of shape (9,9,9): indicates which actions are valid (empty cells and valid digits).

    • key: jax array (int32) of shape (2,) used for seeding initial sudoku configuration.
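The board encoding and action mask described above can be illustrated without the environment itself. The following is a minimal sketch in plain JAX, not Jumanji's own implementation. `compute_action_mask` and `is_solved` are hypothetical helper names, the mask axes are assumed to be ordered (row, column, digit), and "valid" in the reward description is read here as "every row, column and 3x3 box contains each digit exactly once".

```python
import jax.numpy as jnp

BOARD_WIDTH = 9  # cells per row/column, and number of digits (0-8)
BOX_WIDTH = 3    # side length of one box


def _one_hot(board):
    # one_hot[r, c, d] is True when cell (r, c) holds digit d.
    # Empty cells (-1) match no digit, so their slice is all False.
    return board[..., None] == jnp.arange(BOARD_WIDTH)


def _per_box(one_hot):
    # Split rows and columns into (box index, index inside box):
    # result axes are (box_row, row_in_box, box_col, col_in_box, digit).
    return one_hot.reshape(BOX_WIDTH, BOX_WIDTH, BOX_WIDTH, BOX_WIDTH, BOARD_WIDTH)


def compute_action_mask(board):
    """Hypothetical helper: (9, 9) int board -> (9, 9, 9) bool mask.

    mask[r, c, d] is True when cell (r, c) is empty and digit d does not
    already appear in row r, column c, or the box containing (r, c).
    """
    one_hot = _one_hot(board)
    row_used = one_hot.any(axis=1, keepdims=True)  # (9, 1, 9)
    col_used = one_hot.any(axis=0, keepdims=True)  # (1, 9, 9)
    box_used = _per_box(one_hot).any(axis=(1, 3))  # (3, 3, 9)
    # Expand the per-box result back to one entry per cell.
    box_used = jnp.repeat(jnp.repeat(box_used, BOX_WIDTH, axis=0), BOX_WIDTH, axis=1)
    empty = (board == -1)[..., None]               # (9, 9, 1)
    return empty & ~row_used & ~col_used & ~box_used


def is_solved(board):
    """Hypothetical helper: True when every row, column and box holds each digit once."""
    counts = _one_hot(board).astype(jnp.int32)
    rows_ok = (counts.sum(axis=1) == 1).all()
    cols_ok = (counts.sum(axis=0) == 1).all()
    boxes_ok = (_per_box(counts).sum(axis=(1, 3)) == 1).all()
    return rows_ok & cols_ok & boxes_ok


# Usage: an otherwise empty board with digit 4 written in the top-left cell.
board = jnp.full((BOARD_WIDTH, BOARD_WIDTH), -1, dtype=jnp.int32).at[0, 0].set(4)
mask = compute_action_mask(board)
```

Both helpers are pure array functions, so they should compose with `jax.jit` and `jax.vmap` in the same way as the environment's `reset` and `step`.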

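import jax
from jumanji.environments import Sudoku

env = Sudoku()
key = jax.random.PRNGKey(0)
# Reset the environment to an initial puzzle.
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# generate_value() returns a placeholder action that conforms to the action spec;
# it may not be a valid move for this particular board.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)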

observation_spec: jumanji.specs.Spec[jumanji.environments.logic.sudoku.types.Observation] cached property writable

Returns the observation spec containing the board and action_mask arrays.

Returns:

Type Description
Spec containing the specifications for all the `Observation` fields:
  • board: BoundedArray (jnp.int8) of shape (9,9).
  • action_mask: BoundedArray (bool) of shape (9,9,9).

action_spec: MultiDiscreteArray cached property writable

Returns the action spec. An action is composed of 3 integers: the row index, the column index and the value to be placed in the cell.

Returns:

Type Description
action_spec

MultiDiscreteArray object.
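Since an action is a (row, column, value) triple and the observation carries a (9,9,9) action mask, one way to obtain a legal action is to sample an index among the mask's True entries and unravel it. This is a minimal sketch under stated assumptions: `sample_valid_action` is a hypothetical helper, not part of Jumanji, and the mask is assumed to be indexed as `mask[row, column, digit]` with at least one True entry.

```python
import jax
import jax.numpy as jnp


def sample_valid_action(key, action_mask):
    """Hypothetical helper: uniformly sample [row, col, digit] where the mask is True.

    Assumes action_mask has shape (9, 9, 9), is indexed as
    mask[row, col, digit], and contains at least one True entry.
    """
    # Valid entries get logit 0, invalid ones -inf, so only valid
    # entries can be drawn and all of them are equally likely.
    logits = jnp.where(action_mask.reshape(-1), 0.0, -jnp.inf)
    flat_index = jax.random.categorical(key, logits)
    # Convert the flat index back into (row, col, digit) coordinates.
    row, col, digit = jnp.unravel_index(flat_index, action_mask.shape)
    return jnp.stack([row, col, digit]).astype(jnp.int32)


# Usage with a toy mask in which exactly two moves are allowed.
toy_mask = jnp.zeros((9, 9, 9), dtype=bool)
toy_mask = toy_mask.at[2, 5, 7].set(True).at[0, 1, 3].set(True)
action = sample_valid_action(jax.random.PRNGKey(0), toy_mask)
```

If the mask in `timestep.observation.action_mask` follows this layout, the returned array could be passed to `env.step` instead of the placeholder produced by `generate_value()`.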

__init__(self, generator: Optional[jumanji.environments.logic.sudoku.generator.Generator] = None, reward_fn: Optional[jumanji.environments.logic.sudoku.reward.RewardFn] = None, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.logic.sudoku.types.State]] = None) special

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.logic.sudoku.types.State, jumanji.types.TimeStep[jumanji.environments.logic.sudoku.types.Observation]]

Resets the environment to an initial state.

Parameters:

Name Type Description Default
key PRNGKeyArray

random key used to reset the environment.

required

Returns:

Type Description
state

State object corresponding to the new state of the environment.

timestep

TimeStep object corresponding to the first timestep returned by the environment.

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]) -> Tuple[jumanji.environments.logic.sudoku.types.State, jumanji.types.TimeStep[jumanji.environments.logic.sudoku.types.Observation]]

Run one timestep of the environment's dynamics.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environment.

required
action Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]

Array containing the action to take.

required

Returns:

Type Description
state

State object corresponding to the next state of the environment.

timestep

TimeStep object corresponding to the timestep returned by the environment.


Last update: 2024-11-01