Wrappers

AutoResetWrapper(env, next_obs_in_extras=False) #

Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]

Automatically resets environments that are done. Once the terminal state is reached, the state and observation are reset: the terminal TimeStep is returned with the reset observation while keeping StepType.LAST, and its reward, discount, and extras are those of the transition into the terminal state.

NOTE: when next_obs_in_extras is True, the observation from the terminal TimeStep is stored in timestep.extras["next_obs"].

WARNING: do not jax.vmap the wrapped environment (e.g. do not use it with the VmapWrapper), as this leads to inefficient computation: both the step and reset functions would be processed every time step is called. Use the VmapAutoResetWrapper instead.

Wrap an environment to automatically reset it when the episode terminates.

Parameters:

- env (Environment[State, ActionSpec, Observation]): the environment to wrap. Required.
- next_obs_in_extras (bool): whether to store the next observation in the extras of the terminal timestep. This is useful for e.g. truncation. Defaults to False.
Source code in jumanji/wrappers.py
def __init__(
    self,
    env: Environment[State, ActionSpec, Observation],
    next_obs_in_extras: bool = False,
):
    """Wrap an environment to automatically reset it when the episode terminates.

    Args:
        env: the environment to wrap.
        next_obs_in_extras: whether to store the next observation in the extras of the
            terminal timestep. This is useful for e.g. truncation.
    """
    super().__init__(env)
    self.next_obs_in_extras = next_obs_in_extras
    if next_obs_in_extras:
        self._maybe_add_obs_to_extras = add_obs_to_extras
    else:
        self._maybe_add_obs_to_extras = lambda timestep: timestep  # no-op
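
A minimal usage sketch is shown below. It assumes that "Snake-v1" is a registered environment id in your Jumanji version and uses the action spec's generate_value() as a stand-in for a real policy; both are illustrative assumptions rather than part of this API.

import jax
import jumanji
from jumanji.wrappers import AutoResetWrapper

# Illustrative sketch: any registered Jumanji environment id could be used here.
env = AutoResetWrapper(jumanji.make("Snake-v1"), next_obs_in_extras=True)

key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
step_fn = jax.jit(env.step)

for _ in range(10):
    action = env.action_spec.generate_value()  # placeholder for a learned policy
    state, timestep = step_fn(state, action)
    # On a terminal step, timestep.observation already holds the reset observation,
    # while the true terminal observation is kept in timestep.extras["next_obs"].

Because the environment auto-resets, the loop never needs to call reset again after the initial call.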

step(state, action) #

Step the environment, with automatic resetting if the episode terminates.

Source code in jumanji/wrappers.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Step the environment, with automatic resetting if the episode terminates."""
    state, timestep = self._env.step(state, action)

    # Overwrite the state and timestep appropriately if the episode terminates.
    state, timestep = jax.lax.cond(
        timestep.last(),
        self._auto_reset,
        lambda s, t: (s, self._maybe_add_obs_to_extras(t)),
        state,
        timestep,
    )

    return state, timestep

JumanjiToDMEnvWrapper(env, key=None) #

Bases: Environment, Generic[State, ActionSpec, Observation]

A wrapper that converts Environment to dm_env.Environment.

Create the wrapped environment.

Parameters:

- env (Environment[State, ActionSpec, Observation]): Environment to wrap to a dm_env.Environment. Required.
- key (Optional[PRNGKey]): optional key to initialize the Environment with. Defaults to None.
Source code in jumanji/wrappers.py
def __init__(
    self,
    env: Environment[State, ActionSpec, Observation],
    key: Optional[chex.PRNGKey] = None,
):
    """Create the wrapped environment.

    Args:
        env: `Environment` to wrap to a `dm_env.Environment`.
        key: optional key to initialize the `Environment` with.
    """
    self._env = env
    if key is None:
        self._key = jax.random.PRNGKey(0)
    else:
        self._key = key
    self._state: Any
    self._jitted_reset: Callable[[chex.PRNGKey], Tuple[State, TimeStep]] = jax.jit(
        self._env.reset
    )
    self._jitted_step: Callable[[State, chex.Array], Tuple[State, TimeStep]] = jax.jit(
        self._env.step
    )
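
As a rough usage sketch (assuming "Snake-v1" is a registered Jumanji environment id), the wrapped environment is driven through the standard dm_env interface, here with the action spec's generate_value() standing in for a policy:

import jax
import jumanji
from jumanji.wrappers import JumanjiToDMEnvWrapper

# Illustrative sketch: expose a registered Jumanji environment as a dm_env.Environment.
env = JumanjiToDMEnvWrapper(jumanji.make("Snake-v1"), key=jax.random.PRNGKey(0))

timestep = env.reset()
for _ in range(5):
    action = env.action_spec().generate_value()  # placeholder for a real policy
    timestep = env.step(action)
    print(timestep.step_type, timestep.reward)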

action_spec() #

Returns the dm_env action spec.

Source code in jumanji/wrappers.py
def action_spec(self) -> dm_env.specs.Array:
    """Returns the dm_env action spec."""
    return specs.jumanji_specs_to_dm_env_specs(self._env.action_spec)

observation_spec() #

Returns the dm_env observation spec.

Source code in jumanji/wrappers.py
def observation_spec(self) -> dm_env.specs.Array:
    """Returns the dm_env observation spec."""
    return specs.jumanji_specs_to_dm_env_specs(self._env.observation_spec)

reset() #

Starts a new sequence and returns the first TimeStep of this sequence.

Returns:

- TimeStep: a TimeStep namedtuple containing:
  - step_type: a StepType of FIRST.
  - reward: None, indicating the reward is undefined.
  - discount: None, indicating the discount is undefined.
  - observation: a NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the specification returned by observation_spec.

Source code in jumanji/wrappers.py
def reset(self) -> dm_env.TimeStep:
    """Starts a new sequence and returns the first `TimeStep` of this sequence.

    Returns:
        A `TimeStep` namedtuple containing:
            - step_type: A `StepType` of `FIRST`.
            - reward: `None`, indicating the reward is undefined.
            - discount: `None`, indicating the discount is undefined.
            - observation: A NumPy array, or a nested dict, list or tuple of arrays.
                Scalar values that can be cast to NumPy arrays (e.g. Python floats)
                are also valid in place of a scalar array. Must conform to the
                specification returned by `observation_spec`.
    """
    reset_key, self._key = jax.random.split(self._key)
    self._state, timestep = self._jitted_reset(reset_key)
    return dm_env.restart(observation=timestep.observation)

step(action) #

Updates the environment according to the action and returns a TimeStep.

If the environment returned a TimeStep with StepType.LAST at the previous step, this call to step will start a new sequence and action will be ignored.

This method will also start a new sequence if called after the environment has been constructed and reset has not been called. Again, in this case action will be ignored.

Parameters:

- action (chex.ArrayNumpy): a NumPy array, or a nested dict, list or tuple of arrays corresponding to action_spec. Required.

Returns:

- TimeStep: a TimeStep namedtuple containing:
  - step_type: a StepType value.
  - reward: reward at this timestep, or None if step_type is StepType.FIRST. Must conform to the specification returned by reward_spec.
  - discount: a discount in the range [0, 1], or None if step_type is StepType.FIRST. Must conform to the specification returned by discount_spec.
  - observation: a NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the specification returned by observation_spec.

Source code in jumanji/wrappers.py
def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep:
    """Updates the environment according to the action and returns a `TimeStep`.

    If the environment returned a `TimeStep` with `StepType.LAST` at the
    previous step, this call to `step` will start a new sequence and `action`
    will be ignored.

    This method will also start a new sequence if called after the environment
    has been constructed and `reset` has not been called. Again, in this case
    `action` will be ignored.

    Args:
        action: A NumPy array, or a nested dict, list or tuple of arrays
            corresponding to `action_spec`.

    Returns:
        A `TimeStep` namedtuple containing:
            - step_type: A `StepType` value.
            - reward: Reward at this timestep, or None if step_type is
                `StepType.FIRST`. Must conform to the specification returned by
                `reward_spec`.
            - discount: A discount in the range [0, 1], or None if step_type is
                `StepType.FIRST`. Must conform to the specification returned by
                `discount_spec`.
            - observation: A NumPy array, or a nested dict, list or tuple of arrays.
                Scalar values that can be cast to NumPy arrays (e.g. Python floats)
                are also valid in place of a scalar array. Must conform to the
                specification returned by `observation_spec`.
    """
    self._state, timestep = self._jitted_step(self._state, action)
    return dm_env.TimeStep(
        step_type=timestep.step_type,
        reward=timestep.reward,
        discount=timestep.discount,
        observation=timestep.observation,
    )

JumanjiToGymWrapper(env, seed=0, backend=None) #

Bases: Env, Generic[State, ActionSpec, Observation]

A wrapper that converts a Jumanji Environment to one that follows the gym.Env API.

Create the Gym environment.

Parameters:

- env (Environment[State, ActionSpec, Observation]): Environment to wrap to a gym.Env. Required.
- seed (int): the seed that is used to initialize the environment's PRNG. Defaults to 0.
- backend (Optional[str]): the XLA backend. Defaults to None.
Source code in jumanji/wrappers.py
def __init__(
    self,
    env: Environment[State, ActionSpec, Observation],
    seed: int = 0,
    backend: Optional[str] = None,
):
    """Create the Gym environment.

    Args:
        env: `Environment` to wrap to a `gym.Env`.
        seed: the seed that is used to initialize the environment's PRNG.
        backend: the XLA backend.
    """
    self._env = env
    self.metadata: Dict[str, str] = {}
    self._key = jax.random.PRNGKey(seed)
    self.backend = backend
    self._state = None
    self.observation_space = specs.jumanji_specs_to_gym_spaces(self._env.observation_spec)
    self.action_space = specs.jumanji_specs_to_gym_spaces(self._env.action_spec)

    def reset(key: chex.PRNGKey) -> Tuple[State, Observation, Optional[Dict]]:
        """Reset function of a Jumanji environment to be jitted."""
        state, timestep = self._env.reset(key)
        return state, timestep.observation, timestep.extras

    self._reset = jax.jit(reset, backend=self.backend)

    def step(
        state: State, action: chex.Array
    ) -> Tuple[State, Observation, chex.Array, chex.Array, chex.Array, Optional[Any]]:
        """Step function of a Jumanji environment to be jitted."""
        state, timestep = self._env.step(state, action)
        term = timestep.discount.astype(bool)
        trunc = timestep.last().astype(bool)
        return state, timestep.observation, timestep.reward, term, trunc, timestep.extras

    self._step = jax.jit(step, backend=self.backend)
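
A rough usage sketch of the resulting gym.Env, again assuming "Snake-v1" is a registered Jumanji environment id and sampling random actions from the converted action space:

import jumanji
from jumanji.wrappers import JumanjiToGymWrapper

# Illustrative sketch: drive a Jumanji environment through the gym.Env API.
env = JumanjiToGymWrapper(jumanji.make("Snake-v1"), seed=0)

obs, info = env.reset()
for _ in range(100):
    action = env.action_space.sample()  # random action from the converted gym space
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()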

close() #

Closes the environment, important for rendering where pygame is imported.

Source code in jumanji/wrappers.py
def close(self) -> None:
    """Closes the environment, important for rendering where pygame is imported."""
    self._env.close()

render(mode='human') #

Renders the environment.

Parameters:

- mode (str): not used, since Jumanji does not currently support render modes. Defaults to 'human'.
Source code in jumanji/wrappers.py
def render(self, mode: str = "human") -> Any:
    """Renders the environment.

    Args:
        mode: currently not used since Jumanji does not currently support modes.
    """
    del mode
    if self._state is None:
        raise ValueError("Cannot render when _state is None.")
    return self._env.render(self._state)

reset(*, seed=None, options=None) #

Resets the environment to an initial state by starting a new sequence and returns the first Observation of this sequence.

Returns:

- obs (GymObservation): an element of the environment's observation_space.
- info (optional): contains supplementary information such as metrics.

Source code in jumanji/wrappers.py
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> Tuple[GymObservation, Dict[str, Any]]:
    """Resets the environment to an initial state by starting a new sequence
    and returns the first `Observation` of this sequence.

    Returns:
        obs: an element of the environment's observation_space.
        info (optional): contains supplementary information such as metrics.
    """
    if seed is not None:
        self.seed(seed)
    key, self._key = jax.random.split(self._key)
    self._state, obs, extras = self._reset(key)

    # Convert the observation to a numpy array or a nested dict thereof
    obs = jumanji_to_gym_obs(obs)

    return obs, jax.device_get(extras)

seed(seed=0) #

Function which sets the seed for the environment's random number generator(s).

Parameters:

- seed (int): the seed value for the random number generator(s). Defaults to 0.
Source code in jumanji/wrappers.py
def seed(self, seed: int = 0) -> None:
    """Function which sets the seed for the environment's random number generator(s).

    Args:
        seed: the seed value for the random number generator(s).
    """
    self._key = jax.random.PRNGKey(seed)

step(action) #

Updates the environment according to the action and returns an Observation.

Parameters:

- action (chex.ArrayNumpy): a NumPy array representing the action provided by the agent. Required.

Returns:

- observation (GymObservation): an element of the environment's observation_space.
- reward (float): the amount of reward returned as a result of taking the action.
- terminated (bool): whether a terminal state is reached.
- truncated (bool): whether the episode was truncated (see the source below).
- info (Dict[str, Any]): contains supplementary information such as metrics.

Source code in jumanji/wrappers.py
def step(
    self, action: chex.ArrayNumpy
) -> Tuple[GymObservation, float, bool, bool, Dict[str, Any]]:
    """Updates the environment according to the action and returns an `Observation`.

    Args:
        action: A NumPy array representing the action provided by the agent.

    Returns:
        observation: an element of the environment's observation_space.
        reward: the amount of reward returned as a result of taking the action.
        terminated: whether a terminal state is reached.
        info: contains supplementary information such as metrics.
    """

    action_jax = jnp.asarray(action)  # Convert input numpy array to JAX array
    self._state, obs, reward, term, trunc, extras = self._step(self._state, action_jax)

    # Convert to get the correct signature
    obs = jumanji_to_gym_obs(obs)
    reward = float(reward)
    terminated = bool(term)
    truncated = bool(trunc)
    info = jax.device_get(extras)

    return obs, reward, terminated, truncated, info

MultiToSingleWrapper(env, reward_aggregator=jnp.sum, discount_aggregator=jnp.max) #

Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]

A wrapper that converts a multi-agent Environment to a single-agent Environment.

Create the wrapped environment.

Parameters:

- env (Environment[State, ActionSpec, Observation]): the multi-agent Environment to wrap. Required.
- reward_aggregator (Callable): a function to aggregate all agents' rewards into a single scalar value, e.g. sum. Defaults to jnp.sum.
- discount_aggregator (Callable): a function to aggregate all agents' discounts into a single scalar value, e.g. max. Defaults to jnp.max.
Source code in jumanji/wrappers.py
def __init__(
    self,
    env: Environment[State, ActionSpec, Observation],
    reward_aggregator: Callable = jnp.sum,
    discount_aggregator: Callable = jnp.max,
):
    """Create the wrapped environment.

    Args:
        env: `Environment` to wrap to a `dm_env.Environment`.
        reward_aggregator: a function to aggregate all agents rewards into a single scalar
            value, e.g. sum.
        discount_aggregator: a function to aggregate all agents discounts into a single
            scalar value, e.g. max.
    """
    super().__init__(env)
    self._reward_aggregator = reward_aggregator
    self._discount_aggregator = discount_aggregator
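
A short sketch of how the aggregators might be configured, assuming "RobotWarehouse-v0" is a registered multi-agent environment id in your Jumanji version:

import jax
import jax.numpy as jnp
import jumanji
from jumanji.wrappers import MultiToSingleWrapper

# Illustrative sketch: sum the per-agent rewards, keep the episode alive while any agent is.
env = MultiToSingleWrapper(
    jumanji.make("RobotWarehouse-v0"),
    reward_aggregator=jnp.sum,
    discount_aggregator=jnp.max,
)

state, timestep = env.reset(jax.random.PRNGKey(0))
# timestep.reward and timestep.discount are now scalars rather than per-agent arrays.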

reset(key) #

Resets the environment to an initial state.

Parameters:

- key (PRNGKey): random key used to reset the environment. Required.

Returns:

- state (State): State object corresponding to the new state of the environment.
- timestep (TimeStep[Observation]): TimeStep object corresponding to the first timestep returned by the environment.

Source code in jumanji/wrappers.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment to an initial state.

    Args:
        key: random key used to reset the environment.

    Returns:
        state: State object corresponding to the new state of the environment,
        timestep: TimeStep object corresponding the first timestep returned by the environment,
    """
    state, timestep = self._env.reset(key)
    timestep = self._aggregate_timestep(timestep)
    return state, timestep

step(state, action) #

Run one timestep of the environment's dynamics.

The rewards are aggregated into a single value based on the given reward aggregator. The discount value is set to the largest discount of all the agents. This essentially means that if any single agent is alive, the discount value won't be zero.

Parameters:

- state (State): State object containing the dynamics of the environment. Required.
- action (Array): Array containing the action to take. Required.

Returns:

- state (State): State object corresponding to the next state of the environment.
- timestep (TimeStep[Observation]): TimeStep object corresponding to the timestep returned by the environment.

Source code in jumanji/wrappers.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    The rewards are aggregated into a single value based on the given reward aggregator.
    The discount value is set to the largest discount of all the agents. This
    essentially means that if any single agent is alive, the discount value won't be zero.

    Args:
        state: State object containing the dynamics of the environment.
        action: Array containing the action to take.

    Returns:
        state: State object corresponding to the next state of the environment,
        timestep: TimeStep object corresponding the timestep returned by the environment,
    """
    state, timestep = self._env.step(state, action)
    timestep = self._aggregate_timestep(timestep)
    return state, timestep

VmapAutoResetWrapper(env, next_obs_in_extras=False) #

Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]

Efficient combination of VmapWrapper and AutoResetWrapper, to be used as a replacement for stacking both wrappers. env = VmapAutoResetWrapper(env) is equivalent to env = VmapWrapper(AutoResetWrapper(env)) but is more efficient: it parallelizes the homogeneous computation and does not run branches of the computational graph that are not needed (the heterogeneous computation).

- Homogeneous computation: call the step function on all environments in the batch.
- Heterogeneous computation: conditional auto-reset (call the reset function for those environments in the batch that have terminated).

NOTE: when next_obs_in_extras is True, the observation from the terminal TimeStep is stored in timestep.extras["next_obs"].

Wrap an environment to vmap it and automatically reset it when the episode terminates.

Parameters:

- env (Environment[State, ActionSpec, Observation]): the environment to wrap. Required.
- next_obs_in_extras (bool): whether to store the next observation in the extras of the terminal timestep. This is useful for e.g. truncation. Defaults to False.
Source code in jumanji/wrappers.py
def __init__(
    self,
    env: Environment[State, ActionSpec, Observation],
    next_obs_in_extras: bool = False,
):
    """Wrap an environment to vmap it and automatically reset it when the episode terminates.

    Args:
        env: the environment to wrap.
        next_obs_in_extras: whether to store the next observation in the extras of the
            terminal timestep. This is useful for e.g. truncation.
    """
    super().__init__(env)
    self.next_obs_in_extras = next_obs_in_extras
    if next_obs_in_extras:
        self._maybe_add_obs_to_extras = add_obs_to_extras
    else:
        self._maybe_add_obs_to_extras = lambda timestep: timestep  # no-op
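
A short batched rollout sketch, assuming "Snake-v1" is a registered environment id; the action spec's generate_value() is tiled across the batch as a placeholder for a real batched policy:

import jax
import jax.numpy as jnp
import jumanji
from jumanji.wrappers import VmapAutoResetWrapper

# Illustrative sketch: run 8 auto-resetting environments in parallel.
env = VmapAutoResetWrapper(jumanji.make("Snake-v1"))

num_envs = 8
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)  # one key per environment
state, timestep = jax.jit(env.reset)(keys)

actions = jnp.stack([env.action_spec.generate_value()] * num_envs)  # batched placeholder actions
state, timestep = jax.jit(env.step)(state, actions)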

render(state) #

Render the first environment state of the given batch. The remaining elements of the batched state are ignored.

Parameters:

- state (State): State object containing the current dynamics of the environment. Required.
Source code in jumanji/wrappers.py
def render(self, state: State) -> Any:
    """Render the first environment state of the given batch.
    The remaining elements of the batched state are ignored.

    Args:
        state: State object containing the current dynamics of the environment.
    """
    state_0 = tree_utils.tree_slice(state, 0)
    return super().render(state_0)

reset(key) #

Resets a batch of environments to initial states.

The first dimension of the key will dictate the number of concurrent environments.

To obtain a key with the right first dimension, you may call jax.random.split on key with the parameter num representing the number of concurrent environments.

Parameters:

- key (PRNGKey): random keys used to reset the environments, where the first dimension is the number of desired environments. Required.

Returns:

- state (State): State object corresponding to the new state of the environments.
- timestep (TimeStep[Observation]): TimeStep object corresponding to the first timesteps returned by the environments.

Source code in jumanji/wrappers.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets a batch of environments to initial states.

    The first dimension of the key will dictate the number of concurrent environments.

    To obtain a key with the right first dimension, you may call `jax.random.split` on key
    with the parameter `num` representing the number of concurrent environments.

    Args:
        key: random keys used to reset the environments where the first dimension is the number
            of desired environments.

    Returns:
        state: `State` object corresponding to the new state of the environments,
        timestep: `TimeStep` object corresponding the first timesteps returned by the
            environments,
    """
    state, timestep = jax.vmap(self._env.reset)(key)
    timestep = self._maybe_add_obs_to_extras(timestep)
    return state, timestep

step(state, action) #

Run one timestep of all environments' dynamics. It automatically resets environment(s) in which episodes have terminated.

The first dimension of the state will dictate the number of concurrent environments.

See VmapAutoResetWrapper.reset for more details on how to get a state of concurrent environments.

Parameters:

- state (State): State object containing the dynamics of the environments. Required.
- action (Array): Array containing the actions to take. Required.

Returns:

- state (State): State object corresponding to the next states of the environments.
- timestep (TimeStep[Observation]): TimeStep object corresponding to the timesteps returned by the environments.

Source code in jumanji/wrappers.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of all environments' dynamics. It automatically resets environment(s)
    in which episodes have terminated.

    The first dimension of the state will dictate the number of concurrent environments.

    See `VmapAutoResetWrapper.reset` for more details on how to get a state of concurrent
    environments.

    Args:
        state: `State` object containing the dynamics of the environments.
        action: `Array` containing the actions to take.

    Returns:
        state: `State` object corresponding to the next states of the environments.
        timestep: `TimeStep` object corresponding the timesteps returned by the environments.
    """
    # Vmap homogeneous computation (parallelizable).
    state, timestep = jax.vmap(self._env.step)(state, action)
    # Map heterogeneous computation (non-parallelizable).
    state, timestep = jax.lax.map(lambda args: self._maybe_reset(*args), (state, timestep))
    return state, timestep

VmapWrapper(env) #

Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]

Vectorized JAX env. Please note that all methods that return arrays do not include a batch dimension, because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include (see the sketch after the constructor source below):

- observation_spec
- action_spec
- reward_spec
- discount_spec

Source code in jumanji/wrappers.py
def __init__(self, env: Environment[State, ActionSpec, Observation]):
    self._env = env
    super().__init__()
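
The sketch below illustrates the note above about specs, assuming "Snake-v1" is a registered environment id: the reset and step outputs are batched, but the specs still describe a single environment.

import jax
import jumanji
from jumanji.wrappers import VmapWrapper

# Illustrative sketch: batch an environment without automatic resetting.
env = VmapWrapper(jumanji.make("Snake-v1"))

keys = jax.random.split(jax.random.PRNGKey(0), 4)  # 4 parallel environments
state, timestep = jax.jit(env.reset)(keys)

# timestep.observation carries a leading batch axis of size 4,
# whereas env.observation_spec has no batch dimension at all.
print(env.observation_spec)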

render(state) #

Render the first environment state of the given batch. The remaining elements of the batched state are ignored.

Parameters:

- state (State): State object containing the current dynamics of the environment. Required.
Source code in jumanji/wrappers.py
def render(self, state: State) -> Any:
    """Render the first environment state of the given batch.
    The remaining elements of the batched state are ignored.

    Args:
        state: State object containing the current dynamics of the environment.
    """
    state_0 = tree_utils.tree_slice(state, 0)
    return super().render(state_0)

reset(key) #

Resets the environment to an initial state.

The first dimension of the key will dictate the number of concurrent environments.

To obtain a key with the right first dimension, you may call jax.random.split on key with the parameter num representing the number of concurrent environments.

Parameters:

- key (PRNGKey): random keys used to reset the environments, where the first dimension is the number of desired environments. Required.

Returns:

- state (State): State object corresponding to the new state of the environments.
- timestep (TimeStep[Observation]): TimeStep object corresponding to the first timesteps returned by the environments.

Source code in jumanji/wrappers.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment to an initial state.

    The first dimension of the key will dictate the number of concurrent environments.

    To obtain a key with the right first dimension, you may call `jax.random.split` on key
    with the parameter `num` representing the number of concurrent environments.

    Args:
        key: random keys used to reset the environments where the first dimension is the number
            of desired environments.

    Returns:
        state: State object corresponding to the new state of the environments,
        timestep: TimeStep object corresponding the first timesteps returned by the
            environments,
    """
    state, timestep = jax.vmap(self._env.reset)(key)
    return state, timestep

step(state, action) #

Run one timestep of the environment's dynamics.

The first dimension of the state will dictate the number of concurrent environments.

See VmapWrapper.reset for more details on how to get a state of concurrent environments.

Parameters:

- state (State): State object containing the dynamics of the environments. Required.
- action (Array): Array containing the actions to take. Required.

Returns:

- state (State): State object corresponding to the next states of the environments.
- timestep (TimeStep[Observation]): TimeStep object corresponding to the timesteps returned by the environments.

Source code in jumanji/wrappers.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    The first dimension of the state will dictate the number of concurrent environments.

    See `VmapWrapper.reset` for more details on how to get a state of concurrent
    environments.

    Args:
        state: State object containing the dynamics of the environments.
        action: Array containing the actions to take.

    Returns:
        state: State object corresponding to the next states of the environments,
        timestep: TimeStep object corresponding the timesteps returned by the environments,
    """
    state, timestep = jax.vmap(self._env.step)(state, action)
    return state, timestep

Wrapper(env) #

Bases: Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]

Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72

Source code in jumanji/wrappers.py
def __init__(self, env: Environment[State, ActionSpec, Observation]):
    self._env = env
    super().__init__()
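
As an illustration of the intended extension pattern (a sketch, not part of the library), a custom wrapper can override individual methods while delegating everything else to the wrapped environment; the hypothetical RewardScalingWrapper below rescales rewards:

from jumanji.env import Environment
from jumanji.wrappers import Wrapper


class RewardScalingWrapper(Wrapper):
    """Hypothetical example wrapper that multiplies every reward by a constant factor."""

    def __init__(self, env: Environment, scale: float = 0.1):
        super().__init__(env)
        self._scale = scale

    def step(self, state, action):
        # Delegate to the wrapped environment, then rescale the reward.
        state, timestep = self._env.step(state, action)
        return state, timestep.replace(reward=timestep.reward * self._scale)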

action_spec: ActionSpec cached property #

Returns the action spec.

discount_spec: specs.BoundedArray cached property #

Returns the discount spec.

observation_spec: specs.Spec[Observation] cached property #

Returns the observation spec.

reward_spec: specs.Array cached property #

Returns the reward spec.

unwrapped: Environment[State, ActionSpec, Observation] property #

Returns the wrapped env.

close() #

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/wrappers.py
def close(self) -> None:
    """Perform any necessary cleanup.

    Environments will automatically :meth:`close()` themselves when
    garbage collected or when the program exits.
    """
    return self._env.close()

render(state) #

Compute render frames during initialisation of the environment.

Parameters:

- state (State): State object containing the dynamics of the environment. Required.
Source code in jumanji/wrappers.py
def render(self, state: State) -> Any:
    """Compute render frames during initialisation of the environment.

    Args:
        state: State object containing the dynamics of the environment.
    """
    return self._env.render(state)

reset(key) #

Resets the environment to an initial state.

Parameters:

- key (PRNGKey): random key used to reset the environment. Required.

Returns:

- state (State): State object corresponding to the new state of the environment.
- timestep (TimeStep[Observation]): TimeStep object corresponding to the first timestep returned by the environment.

Source code in jumanji/wrappers.py
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Resets the environment to an initial state.

    Args:
        key: random key used to reset the environment.

    Returns:
        state: State object corresponding to the new state of the environment,
        timestep: TimeStep object corresponding the first timestep returned by the environment,
    """
    return self._env.reset(key)

step(state, action) #

Run one timestep of the environment's dynamics.

Parameters:

- state (State): State object containing the dynamics of the environment. Required.
- action (Array): Array containing the action to take. Required.

Returns:

- state (State): State object corresponding to the next state of the environment.
- timestep (TimeStep[Observation]): TimeStep object corresponding to the timestep returned by the environment.

Source code in jumanji/wrappers.py
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Run one timestep of the environment's dynamics.

    Args:
        state: State object containing the dynamics of the environment.
        action: Array containing the action to take.

    Returns:
        state: State object corresponding to the next state of the environment,
        timestep: TimeStep object corresponding the timestep returned by the environment,
    """
    return self._env.step(state, action)

add_obs_to_extras(timestep) #

Place the observation in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]. Used when auto-resetting to store the observation from the terminal TimeStep (useful for e.g. truncation).

Parameters:

- timestep (TimeStep[Observation]): TimeStep object containing the timestep returned by the environment. Required.

Returns:

- TimeStep[Observation]: the timestep with its observation placed in timestep.extras["next_obs"].

Source code in jumanji/wrappers.py
def add_obs_to_extras(timestep: TimeStep[Observation]) -> TimeStep[Observation]:
    """Place the observation in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS].
    Used when auto-resetting to store the observation from the terminal TimeStep (useful for
    e.g. truncation).

    Args:
        timestep: TimeStep object containing the timestep returned by the environment.

    Returns:
        timestep where the observation is placed in timestep.extras["next_obs"].
    """
    extras = timestep.extras
    extras[NEXT_OBS_KEY_IN_EXTRAS] = timestep.observation
    return timestep.replace(extras=extras)  # type: ignore
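
A tiny self-contained sketch, assuming jumanji.types.restart accepts an extras argument in your version of the library:

import jax.numpy as jnp

from jumanji.types import restart
from jumanji.wrappers import add_obs_to_extras

# Illustrative sketch: build a bare first TimeStep and copy its observation into extras.
timestep = restart(observation=jnp.zeros(3), extras={})
timestep = add_obs_to_extras(timestep)
print(timestep.extras["next_obs"])  # the same array as timestep.observation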

jumanji_to_gym_obs(observation) #

Convert a Jumanji observation into a gym observation.

Parameters:

- observation (Observation): JAX pytree with (possibly nested) containers that either have the __dict__ or _asdict methods implemented. Required.

Returns:

- GymObservation: a NumPy array or nested dictionary of NumPy arrays.

Source code in jumanji/wrappers.py
def jumanji_to_gym_obs(observation: Observation) -> GymObservation:
    """Convert a Jumanji observation into a gym observation.

    Args:
        observation: JAX pytree with (possibly nested) containers that
            either have the `__dict__` or `_asdict` methods implemented.

    Returns:
        Numpy array or nested dictionary of numpy arrays.
    """
    if isinstance(observation, jnp.ndarray):
        return np.asarray(observation)
    elif hasattr(observation, "__dict__"):
        # Applies to various containers including `chex.dataclass`
        return {key: jumanji_to_gym_obs(value) for key, value in vars(observation).items()}
    elif hasattr(observation, "_asdict"):
        # Applies to `NamedTuple` container.
        return {
            key: jumanji_to_gym_obs(value)
            for key, value in observation._asdict().items()  # type: ignore
        }
    else:
        raise NotImplementedError(
            "Conversion only implemented for JAX pytrees with (possibly nested) containers "
            "that either have the `__dict__` or `_asdict` methods implemented."
        )
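
For instance, a NamedTuple observation (a hypothetical container defined only for this sketch) is converted into a plain dictionary of NumPy arrays:

from typing import NamedTuple

import jax.numpy as jnp

from jumanji.wrappers import jumanji_to_gym_obs


class DummyObservation(NamedTuple):
    """Hypothetical observation container used only for this example."""

    grid: jnp.ndarray
    step_count: jnp.ndarray


obs = DummyObservation(grid=jnp.zeros((3, 3)), step_count=jnp.array(0))
gym_obs = jumanji_to_gym_obs(obs)  # {"grid": np.ndarray, "step_count": np.ndarray}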