Snake

Snake (Environment)

A JAX implementation of the 'Snake' game.

  • observation: Observation

    • grid: jax array (float) of shape (num_rows, num_cols, 5), feature maps that encode the fruit, the snake's head, its body and its tail.
      • body: 2D map with 1. where a body cell is present, else 0.
      • head: 2D map with 1. where the snake's head is located, else 0.
      • tail: 2D map with 1. where the snake's tail is located, else 0.
      • fruit: 2D map with 1. where the fruit is located, else 0.
      • norm_body_state: 2D map with a float between 0. and 1. for each body cell, decreasing from head to tail.
    • step_count: jax array (int32) of shape () current number of steps in the episode.
    • action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position.
  • action: jax array (int32) of shape () taking values in [0,1,2,3] -> [Up, Right, Down, Left].

  • reward: jax array (float) of shape () 1.0 if a fruit is eaten, otherwise 0.0.

  • episode termination:

    • if no action can be performed, i.e. the snake is surrounded.
    • if the time limit is reached.
    • if an invalid action is taken, i.e. the snake exits the grid or bumps into itself.
  • state: State

    • body: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's body cells.
    • body_state: jax array (int32) of shape (num_rows, num_cols) array ordering the snake's body cells, in decreasing order from head to tail.
    • head_position: Position (int32) of shape () position of the snake's head on the 2D grid.
    • tail: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's tail.
    • fruit_position: Position (int32) of shape () position of the fruit on the 2D grid.
    • length: jax array (int32) of shape () current length of the snake.
    • step_count: jax array (int32) of shape () current number of steps in the episode.
    • action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position.
    • key: jax array (uint32) of shape (2,) random key used to sample a new fruit when one is eaten and used for auto-reset.
import jax
from jumanji.environments import Snake

env = Snake()
key = jax.random.PRNGKey(0)
# Reset to an initial state, then render it.
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# Generate an action that conforms to the action spec and take one step.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
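
The five feature maps of the `grid` observation described above could be assembled from the state fields roughly as follows. This is a minimal sketch that depends only on JAX, not on Jumanji: `build_grid` is a hypothetical helper, the channel order simply follows the order in which the maps are listed, and the normalisation of `body_state` by its maximum is an assumption rather than the library's confirmed behaviour.

```python
import jax.numpy as jnp


def build_grid(body, tail, body_state, head_position, fruit_position):
    """Stack five (num_rows, num_cols) maps into a (num_rows, num_cols, 5) grid.

    head_position and fruit_position are (row, col) tuples.
    """
    head = jnp.zeros_like(body).at[head_position].set(True)
    fruit = jnp.zeros_like(body).at[fruit_position].set(True)
    # Assumption: divide by the largest ordering value so the head maps to 1.
    norm_body_state = body_state / jnp.maximum(1, body_state.max())
    # Assumption: channels follow the order listed in the documentation.
    maps = [body, head, tail, fruit, norm_body_state]
    return jnp.stack([m.astype(jnp.float32) for m in maps], axis=-1)


# A length-3 snake on a 4x4 grid: tail at (1, 0), middle at (1, 1), head at (1, 2).
# Assumption: body_state counts up from 1 at the tail to `length` at the head.
body_state = jnp.zeros((4, 4), jnp.int32).at[1, 0].set(1).at[1, 1].set(2).at[1, 2].set(3)
body = body_state > 0
tail = body_state == 1
grid = build_grid(body, tail, body_state, (1, 2), (3, 3))
```

Here `grid[..., 1]` is the head map and `grid[..., 4]` holds the normalised ordering, which is largest at the head and smallest at the tail.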

observation_spec: jumanji.specs.Spec[jumanji.environments.routing.snake.types.Observation] (cached property, writable)

Returns the observation spec.

Returns:

Spec for the `Observation`, whose fields are:
  • grid: BoundedArray (float) of shape (num_rows, num_cols, 5).
  • step_count: DiscreteArray (num_values = time_limit) of shape ().
  • action_mask: BoundedArray (bool) of shape (4,).

action_spec: DiscreteArray (cached property, writable)

Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left].

Returns:

  • action_spec: a specs.DiscreteArray spec.
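
To illustrate the action encoding, the sketch below maps each action index to a (row, column) displacement and samples uniformly among the actions allowed by an `action_mask`. It uses only JAX. `MOVES` and `sample_valid_action` are hypothetical names, and the convention that row 0 is the top of the grid (so Up decreases the row index) is an assumption.

```python
import jax
import jax.numpy as jnp

# Row/column displacement per action, in the documented order
# [Up, Right, Down, Left]. Assumption: row 0 is the top row, so Up is -1.
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=jnp.int32)


def sample_valid_action(key, action_mask):
    """Draw uniformly among the actions whose mask entry is True."""
    # Masked actions get a logit of -inf and are therefore never selected.
    logits = jnp.where(action_mask, 0.0, -jnp.inf)
    return jax.random.categorical(key, logits).astype(jnp.int32)


action_mask = jnp.array([False, True, False, True])  # only Right and Left allowed
keys = jax.random.split(jax.random.PRNGKey(0), 32)
actions = jax.vmap(sample_valid_action, in_axes=(0, None))(keys, action_mask)
next_head = jnp.array([5, 5]) + MOVES[actions[0]]
```

If every entry of the mask is False, the sampled index is meaningless, so a caller would need to handle that case separately.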

__init__(self, num_rows: int = 12, num_cols: int = 12, time_limit: int = 4000, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.snake.types.State]] = None)

Instantiates a Snake environment.

Parameters:

  • num_rows (int): number of rows of the 2D grid. Defaults to 12.
  • num_cols (int): number of columns of the 2D grid. Defaults to 12.
  • time_limit (int): time limit of an episode, i.e. number of environment steps before the episode ends. Defaults to 4000.
  • viewer (Optional[jumanji.viewer.Viewer[jumanji.environments.routing.snake.types.State]]): Viewer used for rendering. Defaults to None, in which case a SnakeViewer is used.

reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.snake.types.State, jumanji.types.TimeStep[jumanji.environments.routing.snake.types.Observation]]

Resets the environment.

Parameters:

  • key (PRNGKeyArray, required): random key used to sample the snake and fruit positions.

Returns:

  • state: State object corresponding to the new state of the environment.
  • timestep: TimeStep object corresponding to the first timestep returned by the environment.
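
As an illustration of how a single key can drive the sampling of the snake and fruit positions, the sketch below draws two distinct grid cells from one key and batches the sampling over several keys with `jax.vmap`. It is a stand-in written with JAX only: `sample_head_and_fruit` is a hypothetical helper, and the actual sampling scheme used by `reset` may differ.

```python
import jax
import jax.numpy as jnp


def sample_head_and_fruit(key, num_rows=12, num_cols=12):
    """Pick two distinct cells: one for the snake's head, one for the fruit."""
    # Sampling flat indices without replacement guarantees distinct cells.
    flat = jax.random.choice(key, num_rows * num_cols, shape=(2,), replace=False)
    rows, cols = jnp.divmod(flat, num_cols)
    # Result layout: [[head_row, head_col], [fruit_row, fruit_col]].
    return jnp.stack([rows, cols], axis=-1)


# One key per environment instance gives independent starting positions.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
positions = jax.vmap(sample_head_and_fruit)(keys)  # shape (8, 2, 2)
```

Because the function is pure, calling it twice with the same key yields the same positions, which is what makes a reset reproducible.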

step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]) -> Tuple[jumanji.environments.routing.snake.types.State, jumanji.types.TimeStep[jumanji.environments.routing.snake.types.Observation]]

Runs one timestep of the environment's dynamics.

Parameters:

  • state (State, required): State object containing the dynamics of the environment.
  • action (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int], required): Array containing the action to take:
    • 0: move up.
    • 1: move to the right.
    • 2: move down.
    • 3: move to the left.

Returns:

  • state, timestep: next state of the environment and timestep to be observed.
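
The documented reward and termination rules can be sketched for a single move as follows. This is a simplified illustration in plain JAX, not the environment's actual `step`: `step_outcome` and `MOVES` are hypothetical names, the time-limit comparison is an assumption, the "snake is surrounded" condition is omitted, and a move onto any body cell (including the tail) is treated as a collision.

```python
import jax.numpy as jnp

# Assumption: row 0 is the top row; order is [Up, Right, Down, Left].
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=jnp.int32)


def step_outcome(body, head_position, fruit_position, action, step_count, time_limit=4000):
    """Return (reward, done) for one move, following the documented rules."""
    num_rows, num_cols = body.shape
    new_head = head_position + MOVES[action]
    in_bounds = jnp.all((new_head >= 0) & (new_head < jnp.array([num_rows, num_cols])))
    # Clip before indexing so an out-of-grid move never reads outside the array.
    row, col = jnp.clip(new_head, 0, jnp.array([num_rows - 1, num_cols - 1]))
    hits_body = body[row, col]
    invalid = ~in_bounds | hits_body
    ate_fruit = in_bounds & jnp.all(new_head == fruit_position)
    # Reward is 1.0 only when a fruit is eaten by a valid move.
    reward = jnp.where(ate_fruit & ~invalid, 1.0, 0.0)
    # Assumption: the episode ends once step_count + 1 reaches time_limit.
    done = invalid | (step_count + 1 >= time_limit)
    return reward, done


# Snake occupying (1, 1) and (1, 2), head at (1, 2), fruit at (1, 3).
body = jnp.zeros((4, 4), bool).at[1, 1].set(True).at[1, 2].set(True)
head, fruit = jnp.array([1, 2]), jnp.array([1, 3])
reward_right, done_right = step_outcome(body, head, fruit, 1, step_count=0)  # moves onto the fruit
reward_left, done_left = step_outcome(body, head, fruit, 3, step_count=0)    # moves into the body
```

Moving right reaches the fruit and keeps the episode running, while moving left collides with the body and ends it with no reward.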


Last update: 2024-11-01