Snake

Bases: Environment[State, DiscreteArray, Observation]

A JAX implementation of the 'Snake' game.

  • observation: Observation

    • grid: jax array (float) of shape (num_rows, num_cols, 5) feature maps that include information about the fruit, the snake head, its body and tail.
      • body: 2D map with 1. where a body cell is present, else 0.
      • head: 2D map with 1. where the snake's head is located, else 0.
      • tail: 2D map with 1. where the snake's tail is located, else 0.
      • fruit: 2D map with 1. where the fruit is located, else 0.
      • norm_body_state: 2D map with a float between 0. and 1. for each body cell, in decreasing order from head to tail.
    • step_count: jax array (int32) of shape () current number of steps in the episode.
    • action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position.
  • action: jax array (int32) of shape () [0,1,2,3] -> [Up, Right, Down, Left].

  • reward: jax array (float) of shape () 1.0 if a fruit is eaten, otherwise 0.0.

  • episode termination:

    • if no action can be performed, i.e. the snake is surrounded.
    • if the time limit is reached.
    • if an invalid action is taken, i.e. the snake exits the grid or bumps into itself.
  • state: State

    • body: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's body cells.
    • body_state: jax array (int32) of shape (num_rows, num_cols) array ordering the snake's body cells, in decreasing order from head to tail.
    • head_position: Position (int32) of shape () position of the snake's head on the 2D grid.
    • tail: jax array (bool) of shape (num_rows, num_cols) array indicating the snake's tail.
    • fruit_position: Position (int32) of shape () position of the fruit on the 2D grid.
    • length: jax array (int32) of shape () current length of the snake.
    • step_count: jax array (int32) of shape () current number of steps in the episode.
    • action_mask: jax array (bool) of shape (4,) array specifying which directions the snake can move in from its current position.
    • key: jax array (uint32) of shape (2,) random key used to sample a new fruit when one is eaten and used for auto-reset.
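import jax
from jumanji.environments import Snake

env = Snake()
key = jax.random.PRNGKey(0)
# Reset returns the initial state and the first timestep.
state, timestep = jax.jit(env.reset)(key)
env.render(state)
# Generate a placeholder action from the spec and step the environment once.
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)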
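
The five feature maps of the observation grid can be derived from the state fields. The sketch below is not the library's `_state_to_observation`, which this page does not show. `build_grid` is a hypothetical helper, and dividing `body_state` by `length` is an assumed normalisation chosen only because it yields floats between 0 and 1 that decrease from head to tail.

```python
import jax.numpy as jnp


def build_grid(body_state, head, fruit, length):
    """Stack five (num_rows, num_cols) maps into a (num_rows, num_cols, 5) grid.

    Hypothetical helper: channel order follows the list in the class docstring
    (body, head, tail, fruit, norm_body_state).
    """
    body = body_state > 0                       # every occupied cell
    tail = body_state == 1                      # the cell that is vacated next
    head_map = jnp.zeros_like(body).at[head].set(True)
    fruit_map = jnp.zeros_like(body).at[fruit].set(True)
    norm_body_state = body_state / length       # assumed normalisation
    maps = [body, head_map, tail, fruit_map, norm_body_state]
    return jnp.stack([m.astype(jnp.float32) for m in maps], axis=-1)


# A snake of length 3 lying on the middle row of a 3x3 grid, head on the right.
example_body_state = jnp.array([[0, 0, 0], [1, 2, 3], [0, 0, 0]], jnp.int32)
example_grid = build_grid(example_body_state, head=(1, 2), fruit=(0, 0), length=3)
print(example_grid.shape)  # (3, 3, 5)
```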

Instantiates a Snake environment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_rows | int | number of rows of the 2D grid. | 12 |
| num_cols | int | number of columns of the 2D grid. | 12 |
| time_limit | int | time limit of an episode, i.e. number of environment steps before the episode ends. | 4000 |
| viewer | Optional[Viewer[State]] | Viewer used for rendering. Defaults to SnakeViewer. | None |

Source code in jumanji/environments/routing/snake/env.py
def __init__(
    self,
    num_rows: int = 12,
    num_cols: int = 12,
    time_limit: int = 4000,
    viewer: Optional[Viewer[State]] = None,
):
    """Instantiates a `Snake` environment.

    Args:
        num_rows: number of rows of the 2D grid. Defaults to 12.
        num_cols: number of columns of the 2D grid. Defaults to 12.
        time_limit: time_limit of an episode, i.e. number of environment steps before
            the episode ends. Defaults to 4000.
        viewer: `Viewer` used for rendering. Defaults to `SnakeViewer`.
    """
    self.num_rows = num_rows
    self.num_cols = num_cols
    self.board_shape = (num_rows, num_cols)
    self.time_limit = time_limit
    super().__init__()
    self._viewer = viewer or SnakeViewer()

action_spec: specs.DiscreteArray cached property

Returns the action spec. 4 actions: [0,1,2,3] -> [Up, Right, Down, Left].

Returns:

| Name | Type | Description |
|---|---|---|
| action_spec | DiscreteArray | a specs.DiscreteArray spec. |
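
To make the action encoding concrete, here is a minimal sketch of how the four action indices could map to grid offsets and how a mask of legal moves could be computed. The helpers `move_head` and `legal_moves` are hypothetical and are not the library's `_update_head_position` or `_get_action_mask`. The sketch assumes that row 0 is the top of the grid and treats every body cell as blocked, which may differ from the library's rule.

```python
import jax.numpy as jnp

# (row, col) offsets for actions [0, 1, 2, 3] -> [Up, Right, Down, Left].
# Assumes row 0 is the top row, so "Up" decreases the row index.
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]])


def move_head(head, action):
    """Return the head coordinates after taking `action` (hypothetical helper)."""
    return head + MOVES[action]


def legal_moves(head, body):
    """Boolean array of shape (4,): True where the move stays on the grid
    and lands on a cell that is not part of the snake (hypothetical helper)."""
    num_rows, num_cols = body.shape
    candidates = head + MOVES  # shape (4, 2)
    rows, cols = candidates[:, 0], candidates[:, 1]
    in_bounds = (rows >= 0) & (rows < num_rows) & (cols >= 0) & (cols < num_cols)
    # Clip before indexing so out-of-grid candidates do not read arbitrary cells.
    free = ~body[jnp.clip(rows, 0, num_rows - 1), jnp.clip(cols, 0, num_cols - 1)]
    return in_bounds & free


# Snake of length 1 in the top-left corner of a 3x3 grid.
corner_body = jnp.zeros((3, 3), bool).at[0, 0].set(True)
print(legal_moves(jnp.array([0, 0]), corner_body))  # [False  True  True False]
```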

observation_spec: specs.Spec[Observation] cached property

Returns the observation spec.

Returns:

| Type | Description |
|---|---|
| Spec[Observation] | Spec for the Observation, whose fields are listed below. |

  • grid: BoundedArray (float) of shape (num_rows, num_cols, 5).
  • step_count: DiscreteArray (num_values = time_limit) of shape ().
  • action_mask: BoundedArray (bool) of shape (4,).

animate(states, interval=200, save_path=None)

Create an animation from a sequence of states.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| states | Sequence[State] | sequence of State corresponding to subsequent timesteps. | required |
| interval | int | delay between frames in milliseconds. | 200 |
| save_path | Optional[str] | the path where the animation file should be saved. If it is None, the plot will not be saved. | None |

Returns:

| Type | Description |
|---|---|
| FuncAnimation | Animation object that can be saved as a GIF, MP4, or rendered with HTML. |

Source code in jumanji/environments/routing/snake/env.py
def animate(
    self,
    states: Sequence[State],
    interval: int = 200,
    save_path: Optional[str] = None,
) -> matplotlib.animation.FuncAnimation:
    """Create an animation from a sequence of states.

    Args:
        states: sequence of `State` corresponding to subsequent timesteps.
        interval: delay between frames in milliseconds, default to 200.
        save_path: the path where the animation file should be saved. If it is None, the plot
            will not be saved.

    Returns:
        Animation object that can be saved as a GIF, MP4, or rendered with HTML.
    """
    return self._viewer.animate(states, interval, save_path)

close()

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/environments/routing/snake/env.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    self._viewer.close()

render(state)

Render frames of the environment for a given state using matplotlib.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the current dynamics of the environment. | required |

Source code in jumanji/environments/routing/snake/env.py
def render(self, state: State) -> None:
    """Render frames of the environment for a given state using matplotlib.

    Args:
        state: State object containing the current dynamics of the environment.
    """
    self._viewer.render(state)

reset(key)

Resets the environment.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKey | random key used to sample the snake and fruit positions. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the new state of the environment. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the first timestep returned by the environment. |

Source code in jumanji/environments/routing/snake/env.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment.

    Args:
        key: random key used to sample the snake and fruit positions.

    Returns:
         state: `State` object corresponding to the new state of the environment.
         timestep: `TimeStep` object corresponding to the first timestep returned by the
            environment.
    """
    key, snake_key, fruit_key = jax.random.split(key, 3)
    # Sample Snake's head position.
    head_coordinates = jax.random.randint(
        snake_key,
        shape=(2,),
        minval=jnp.zeros(2, int),
        maxval=jnp.array(self.board_shape),
    )
    head_position = Position(*tuple(head_coordinates))

    body = jnp.zeros(self.board_shape, bool).at[tuple(head_position)].set(True)
    tail = body
    body_state = body.astype(jnp.int32)
    fruit_position = self._sample_fruit_coord(body, fruit_key)
    state = State(
        key=key,
        body=body,
        body_state=body_state,
        head_position=head_position,
        tail=tail,
        fruit_position=fruit_position,
        length=jnp.array(1, jnp.int32),
        step_count=jnp.array(0, jnp.int32),
        action_mask=self._get_action_mask(head_position, body_state),
    )
    timestep = restart(observation=self._state_to_observation(state))
    return state, timestep
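
The `reset` code above relies on `self._sample_fruit_coord(body, fruit_key)`, whose body is not shown on this page. As a rough illustration only, one way to draw a uniformly random cell that the snake does not occupy is sketched below. `sample_free_cell` is a hypothetical stand-in, it returns plain row and column indices rather than a `Position`, and the library's actual sampling strategy may differ.

```python
import jax
import jax.numpy as jnp


def sample_free_cell(body, key):
    """Sample (row, col) uniformly among cells where `body` is False.

    Hypothetical stand-in for `_sample_fruit_coord`. Assumes at least one
    cell is free; otherwise the probabilities would be undefined.
    """
    num_rows, num_cols = body.shape
    free = ~body.ravel()
    probs = free / free.sum()  # zero probability on occupied cells
    flat_index = jax.random.choice(key, num_rows * num_cols, p=probs)
    row, col = jnp.divmod(flat_index, num_cols)
    return row, col


# With a single free cell, the sample is forced onto that cell.
almost_full = jnp.ones((4, 5), bool).at[2, 3].set(False)
row, col = sample_free_cell(almost_full, jax.random.PRNGKey(0))
print(int(row), int(col))  # 2 3
```

Working on the flattened grid keeps the function compatible with `jax.jit`, since the output shape does not depend on how many cells are free.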

step(state, action)

Run one timestep of the environment's dynamics.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the dynamics of the environment. | required |
| action | Numeric | Array containing the action to take: 0 = move up, 1 = move to the right, 2 = move down, 3 = move to the left. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[State, TimeStep[Observation]] | state, timestep: next state of the environment and timestep to be observed. |

Source code in jumanji/environments/routing/snake/env.py
def step(self, state: State, action: chex.Numeric) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    Args:
        state: `State` object containing the dynamics of the environment.
        action: Array containing the action to take:
            - 0: move up.
            - 1: move to the right.
            - 2: move down.
            - 3: move to the left.

    Returns:
        state, timestep: next state of the environment and timestep to be observed.
    """
    is_valid = state.action_mask[action]
    key, fruit_key = jax.random.split(state.key)

    head_position = self._update_head_position(state.head_position, action)

    fruit_eaten = head_position == state.fruit_position

    length = state.length + fruit_eaten

    body_state_without_head = jax.lax.select(
        fruit_eaten,
        state.body_state,
        jnp.clip(state.body_state - 1, 0),
    )
    body_state = body_state_without_head.at[tuple(head_position)].set(length)

    body = body_state > 0

    tail = body_state == 1

    fruit_position = jax.lax.cond(
        fruit_eaten,
        self._sample_fruit_coord,
        lambda *_: state.fruit_position,
        body,
        fruit_key,
    )
    step_count = state.step_count + 1
    next_state = State(
        key=key,
        body=body,
        body_state=body_state,
        head_position=head_position,
        tail=tail,
        fruit_position=fruit_position,
        length=length,
        step_count=state.step_count + 1,
        action_mask=self._get_action_mask(head_position, body_state),
    )

    snake_completed = jnp.all(body)
    done = ~is_valid | snake_completed | (step_count >= self.time_limit)

    reward = jnp.asarray(fruit_eaten, float)
    observation = self._state_to_observation(next_state)

    timestep = jax.lax.cond(
        done,
        termination,
        transition,
        reward,
        observation,
    )
    return next_state, timestep
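
The `body_state` update in `step` is easier to follow on a small grid. The sketch below isolates that part of the logic in a standalone function, `advance_body_state`, which is a hypothetical name introduced here for illustration. It mirrors the lines above, except that it uses `jnp.maximum(..., 0)` in place of `jnp.clip(..., 0)`, which has the same effect of keeping empty cells at zero.

```python
import jax
import jax.numpy as jnp


def advance_body_state(body_state, length, new_head, fruit_eaten):
    """Move the snake one cell, growing it if a fruit was eaten.

    Each body cell stores its rank: `length` at the head down to 1 at the tail.
    """
    length = length + fruit_eaten
    # Without a fruit, every rank drops by one, so the old tail (rank 1) becomes 0.
    # With a fruit, ranks are kept and the snake grows by one cell.
    shifted = jax.lax.select(
        fruit_eaten,
        body_state,
        jnp.maximum(body_state - 1, 0),
    )
    return shifted.at[new_head].set(length), length


start = jnp.array([[0, 0, 0], [1, 2, 3], [0, 0, 0]], jnp.int32)
length = jnp.array(3, jnp.int32)

moved, _ = advance_body_state(start, length, (0, 2), jnp.array(False))
print(moved)  # [[0 0 3] [0 1 2] [0 0 0]]: the tail cell at (1, 0) is vacated

grown, _ = advance_body_state(start, length, (0, 2), jnp.array(True))
print(grown)  # [[0 0 4] [1 2 3] [0 0 0]]: the tail stays and the head has rank 4
```

In both cases `body_state > 0` gives the body map and `body_state == 1` gives the tail map, as in `step`.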