Snake
Snake (Environment)
A JAX implementation of the 'Snake' game.
- observation: `Observation`
    - grid: jax array (float) of shape (num_rows, num_cols, 5)
        feature maps that include information about the fruit, the snake head, its body and
        tail.
        - body: 2D map with 1. where a body cell is present, else 0.
        - head: 2D map with 1. where the snake's head is located, else 0.
        - tail: 2D map with 1. where the snake's tail is located, else 0.
        - fruit: 2D map with 1. where the fruit is located, else 0.
        - norm_body_state: 2D map with a float between 0. and 1. for each body cell, decreasing from head to tail.
    - step_count: jax array (int32) of shape (), current number of steps in the episode.
    - action_mask: jax array (bool) of shape (4,) specifying which directions the snake can move in from its current position.
- action: jax array (int32) of shape (), with values [0,1,2,3] -> [Up, Right, Down, Left].
- reward: jax array (float) of shape (), 1.0 if a fruit is eaten, otherwise 0.0.
- episode termination:
    - if no action can be performed, i.e. the snake is surrounded.
    - if the time limit is reached.
    - if an invalid action is taken, i.e. the snake exits the grid or bumps into itself.
- state: `State`
    - body: jax array (bool) of shape (num_rows, num_cols) indicating the snake's body cells.
    - body_state: jax array (int32) of shape (num_rows, num_cols) ordering the snake's body cells, in decreasing order from head to tail.
    - head_position: `Position` (int32) of shape (), position of the snake's head on the 2D grid.
    - tail: jax array (bool) of shape (num_rows, num_cols) indicating the snake's tail.
    - fruit_position: `Position` (int32) of shape (), position of the fruit on the 2D grid.
    - length: jax array (int32) of shape (), current length of the snake.
    - step_count: jax array (int32) of shape (), current number of steps in the episode.
    - action_mask: jax array (bool) of shape (4,) specifying which directions the snake can move in from its current position.
    - key: jax array (uint32) of shape (2,), random key used to sample a new fruit when one is eaten and used for auto-reset.
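The five feature maps of `grid` can be pictured as being derived from the state fields listed above. The following is a minimal plain-Python sketch, not the library's JAX code. The function `build_grid` and its `tail_position` argument are hypothetical, and it assumes that the head counts as a body cell and that `norm_body_state` is `body_state` divided by its largest value.

```python
# Illustrative sketch only: nested lists stand in for the JAX arrays described above.
# Assumption: norm_body_state is body_state divided by its largest value.


def build_grid(body_state, head_position, tail_position, fruit_position):
    """Return a (num_rows, num_cols, 5) nested list of floats.

    Channel order follows the list above: body, head, tail, fruit, norm_body_state.
    body_state holds 0 for empty cells and a positive rank for body cells,
    largest at the head and smallest at the tail.
    """
    max_rank = max(max(row) for row in body_state) or 1  # avoid dividing by zero
    grid = []
    for r, row in enumerate(body_state):
        grid_row = []
        for c, rank in enumerate(row):
            grid_row.append([
                1.0 if rank > 0 else 0.0,                    # body
                1.0 if (r, c) == head_position else 0.0,     # head
                1.0 if (r, c) == tail_position else 0.0,     # tail
                1.0 if (r, c) == fruit_position else 0.0,    # fruit
                rank / max_rank,                             # norm_body_state
            ])
        grid.append(grid_row)
    return grid


# A length-3 snake on a 3x4 board: head at (1, 2), tail at (1, 0), fruit at (0, 3).
body_state = [
    [0, 0, 0, 0],
    [1, 2, 3, 0],
    [0, 0, 0, 0],
]
grid = build_grid(body_state, head_position=(1, 2), tail_position=(1, 0), fruit_position=(0, 3))
print(grid[1][2])  # channels at the head cell
print(grid[0][3])  # channels at the fruit cell
```

In this toy board the head cell carries the largest normalized value and the tail cell the smallest non-zero one, which matches the "decreasing from head to tail" description.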
observation_spec: jumanji.specs.Spec[jumanji.environments.routing.snake.types.Observation]
cached, writable property
Returns the observation spec.
Returns:
| Type | Description |
|---|---|
| Spec[Observation] | Spec for the `Observation`, whose fields are `grid`, `step_count` and `action_mask`. |
action_spec: DiscreteArray
cached, writable property
Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left].
Returns:
| Type | Description |
|---|---|
| DiscreteArray | action_spec: a `DiscreteArray` spec with 4 values. |
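To make the action encoding concrete, here is a small plain-Python sketch of how the four action indices could map to grid moves and how an `action_mask` of shape (4,) could be computed. The names `MOVES` and `action_mask` below are hypothetical helpers, not part of the Jumanji API. The sketch assumes row 0 is the top of the grid, and it treats every body cell as blocked, which may be stricter than the actual environment.

```python
# Illustrative sketch only; names are hypothetical, not part of the Jumanji API.
# Assumption: row 0 is the top of the grid, so "Up" decreases the row index.
MOVES = {
    0: (-1, 0),  # Up
    1: (0, 1),   # Right
    2: (1, 0),   # Down
    3: (0, -1),  # Left
}


def action_mask(head, body, num_rows, num_cols):
    """Return 4 booleans: True where a move keeps the head on the grid and off the body."""
    mask = []
    for action in range(4):
        d_row, d_col = MOVES[action]
        row, col = head[0] + d_row, head[1] + d_col
        in_grid = 0 <= row < num_rows and 0 <= col < num_cols
        mask.append(in_grid and (row, col) not in body)
    return mask


# Head in the top-left corner with one body cell to its right.
mask = action_mask(head=(0, 0), body={(0, 0), (0, 1)}, num_rows=12, num_cols=12)
print(mask)
```

With the head in the corner, only moving down stays on the grid and clear of the body, so only index 2 of the mask is true.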
__init__(self, num_rows: int = 12, num_cols: int = 12, time_limit: int = 4000, viewer: Optional[jumanji.viewer.Viewer[jumanji.environments.routing.snake.types.State]] = None)
Instantiates a `Snake` environment.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_rows | int | number of rows of the 2D grid. Defaults to 12. | 12 |
| num_cols | int | number of columns of the 2D grid. Defaults to 12. | 12 |
| time_limit | int | time limit of an episode, i.e. number of environment steps before the episode ends. Defaults to 4000. | 4000 |
| viewer | Optional[jumanji.viewer.Viewer[jumanji.environments.routing.snake.types.State]] | | None |
reset(self, key: PRNGKeyArray) -> Tuple[jumanji.environments.routing.snake.types.State, jumanji.types.TimeStep[jumanji.environments.routing.snake.types.Observation]]
Resets the environment.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | random key used to sample the snake and fruit positions. | required |
Returns:
| Type | Description |
|---|---|
| state, timestep | initial state of the environment and first timestep to be observed. |
step(self, state: State, action: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]) -> Tuple[jumanji.environments.routing.snake.types.State, jumanji.types.TimeStep[jumanji.environments.routing.snake.types.Observation]]
Run one timestep of the environment's dynamics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | State | | required |
| action | Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int] | Array containing the action to take: 0 = move up, 1 = move to the right, 2 = move down, 3 = move to the left. | required |
Returns:
| Type | Description |
|---|---|
| state, timestep | next state of the environment and timestep to be observed. |
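The reward and termination rules described on this page can be sketched in plain Python as follows. This is a simplified stand-in rather than the Jumanji implementation: the `step` function below and its deque-based snake are hypothetical, sampling a new fruit and the "snake is surrounded" check are left out, and the time limit is assumed to trigger once the step count reaches `time_limit`.

```python
# Illustrative sketch only; a simplified stand-in, not the Jumanji implementation.
from collections import deque

MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}  # Up, Right, Down, Left


def step(snake, fruit, step_count, action, num_rows=12, num_cols=12, time_limit=4000):
    """Advance one step.

    snake is a deque of (row, col) cells with the head at index 0.
    Returns (snake, reward, done, step_count). Sampling a new fruit is left out.
    """
    d_row, d_col = MOVES[action]
    head = (snake[0][0] + d_row, snake[0][1] + d_col)
    step_count += 1

    ate_fruit = head == fruit
    new_snake = deque(snake)
    if not ate_fruit:
        new_snake.pop()  # the tail moves forward unless the snake grows

    out_of_grid = not (0 <= head[0] < num_rows and 0 <= head[1] < num_cols)
    hit_body = head in new_snake
    new_snake.appendleft(head)

    reward = 1.0 if ate_fruit else 0.0
    # Assumption: the episode ends once step_count reaches time_limit.
    done = out_of_grid or hit_body or step_count >= time_limit
    return new_snake, reward, done, step_count


# Moving right onto the fruit grows the snake and yields a reward of 1.0.
snake = deque([(0, 1), (0, 0)])
snake, reward, done, count = step(snake, fruit=(0, 2), step_count=0, action=1)
print(reward, done, list(snake))

# Moving up from the top row leaves the grid and ends the episode.
_, reward_out, done_out, _ = step(snake, fruit=(5, 5), step_count=count, action=0)
print(reward_out, done_out)
```

Keeping the snake as an ordered sequence makes the head-to-tail ordering explicit. The environment instead stores this ordering in the `body_state` array so that the state has a fixed shape.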