Skip to content

Wrappers

AutoResetWrapper(env, next_obs_in_extras=False) #

Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]

Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep are set to the reset observation and StepType.LAST, respectively. The reward, discount, and extras are retrieved from the transition to the terminal state. NOTE: When next_obs_in_extras is True, the observation from the terminal TimeStep is stored in timestep.extras["next_obs"]. WARNING: do not jax.vmap the wrapped environment (e.g. do not use with the VmapWrapper), which would lead to inefficient computation due to both the step and reset functions being processed each time step is called. Please use the VmapAutoResetWrapper instead.

Wrap an environment to automatically reset it when the episode terminates.

Parameters:

Name Type Description Default
env Environment[State, ActionSpec, Observation]

the environment to wrap.

required
next_obs_in_extras bool

whether to store the next observation in the extras of the terminal timestep. This is useful for e.g. truncation.

False
Source code in jumanji/wrappers.py
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
def __init__(
    self,
    env: Environment[State, ActionSpec, Observation],
    next_obs_in_extras: bool = False,
):
    """Wrap `env` so that it is reset automatically once an episode terminates.

    Args:
        env: the environment to wrap.
        next_obs_in_extras: if True, timesteps are passed through `add_obs_to_extras`
            so that the next observation is kept in the extras of the terminal
            timestep (useful e.g. for truncation). Otherwise timesteps pass through
            unchanged.
    """
    super().__init__(env)
    self.next_obs_in_extras = next_obs_in_extras

    def _identity(timestep):
        # No-op used when next observations are not requested.
        return timestep

    self._maybe_add_obs_to_extras = add_obs_to_extras if next_obs_in_extras else _identity

step(state, action) #

Step the environment, with automatic resetting if the episode terminates.

Source code in jumanji/wrappers.py
457
458
459
460
461
462
463
464
465
466
467
468
469
470
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Advance the wrapped environment by one step, resetting it if the episode ended.

    Args:
        state: current environment state.
        action: action to apply.

    Returns:
        A `(state, timestep)` pair. When the step ends the episode (`timestep.last()`),
        both come from `self._auto_reset`; otherwise the state is kept and the timestep
        is passed through `self._maybe_add_obs_to_extras`.
    """
    next_state, next_timestep = self._env.step(state, action)

    def _keep_episode(s, t):
        return s, self._maybe_add_obs_to_extras(t)

    # `jax.lax.cond` selects the branch from a traced predicate, keeping this jittable.
    next_state, next_timestep = jax.lax.cond(
        next_timestep.last(),
        self._auto_reset,
        _keep_episode,
        next_state,
        next_timestep,
    )
    return next_state, next_timestep

JumanjiToDMEnvWrapper(env, key=None) #

Bases: Environment, Generic[State, ActionSpec, Observation]

A wrapper that converts Environment to dm_env.Environment.

Create the wrapped environment.

Parameters:

Name Type Description Default
env Environment[State, ActionSpec, Observation]

Environment to wrap to a dm_env.Environment.

required
key Optional[PRNGKey]

optional key to initialize the Environment with.

None
Source code in jumanji/wrappers.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def __init__(
    self,
    env: Environment[State, ActionSpec, Observation],
    key: Optional[chex.PRNGKey] = None,
):
    """Wrap a Jumanji `Environment` so it can be used through the `dm_env` API.

    Args:
        env: the Jumanji `Environment` to expose as a `dm_env.Environment`.
        key: optional PRNG key used to seed resets; `jax.random.PRNGKey(0)` is used
            when it is omitted.
    """
    self._env = env
    self._key = jax.random.PRNGKey(0) if key is None else key
    # Only declared here: the attribute is first assigned by `reset`.
    self._state: Any
    # Wrap the environment's functions with `jax.jit` once so compiled versions are
    # reused across calls.
    self._jitted_reset: Callable[[chex.PRNGKey], Tuple[State, TimeStep]] = jax.jit(
        env.reset
    )
    self._jitted_step: Callable[[State, chex.Array], Tuple[State, TimeStep]] = jax.jit(
        env.step
    )

action_spec() #

Returns the dm_env action spec.

Source code in jumanji/wrappers.py
212
213
214
def action_spec(self) -> dm_env.specs.Array:
    """Return the wrapped environment's action spec converted to a dm_env spec."""
    jumanji_spec = self._env.action_spec
    return specs.jumanji_specs_to_dm_env_specs(jumanji_spec)

observation_spec() #

Returns the dm_env observation spec.

Source code in jumanji/wrappers.py
208
209
210
def observation_spec(self) -> dm_env.specs.Array:
    """Return the wrapped environment's observation spec converted to a dm_env spec."""
    jumanji_spec = self._env.observation_spec
    return specs.jumanji_specs_to_dm_env_specs(jumanji_spec)

reset() #

Starts a new sequence and returns the first TimeStep of this sequence.

Returns:

Type Description
TimeStep

A TimeStep namedtuple containing: - step_type: A StepType of FIRST. - reward: None, indicating the reward is undefined. - discount: None, indicating the discount is undefined. - observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the specification returned by observation_spec.

Source code in jumanji/wrappers.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def reset(self) -> dm_env.TimeStep:
    """Start a new sequence and return its first `TimeStep`.

    A fresh subkey is split off the wrapper's PRNG key and passed to the jitted reset
    function. The resulting environment state is stored for subsequent `step` calls.

    Returns:
        A `TimeStep` built with `dm_env.restart`, containing:
            - step_type: a `StepType` of `FIRST`.
            - reward: `None`, indicating the reward is undefined.
            - discount: `None`, indicating the discount is undefined.
            - observation: the observation of the first Jumanji timestep. It must
                conform to the specification returned by `observation_spec`.
    """
    reset_key, next_key = jax.random.split(self._key)
    self._key = next_key
    self._state, first_timestep = self._jitted_reset(reset_key)
    return dm_env.restart(observation=first_timestep.observation)

step(action) #

Updates the environment according to the action and returns a TimeStep.

Note: unlike the generic dm_env contract, this wrapper does not start a new sequence after a TimeStep with StepType.LAST; the wrapped environment is stepped with action regardless. Call reset to begin a new episode.

This method will also start a new sequence if called after the environment has been constructed and reset has not been called. Again, in this case action will be ignored.

Parameters:

Name Type Description Default
action ArrayNumpy

A NumPy array, or a nested dict, list or tuple of arrays corresponding to action_spec.

required

Returns:

Type Description
TimeStep

A TimeStep namedtuple containing: - step_type: A StepType value. - reward: Reward at this timestep, or None if step_type is StepType.FIRST. Must conform to the specification returned by reward_spec. - discount: A discount in the range [0, 1], or None if step_type is StepType.FIRST. Must conform to the specification returned by discount_spec. - observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the specification returned by observation_spec.

Source code in jumanji/wrappers.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep:
    """Updates the environment according to the action and returns a `TimeStep`.

    If `reset` has not been called since the environment was constructed, this call
    starts a new sequence instead: the environment is reset, `action` is ignored, and
    the `FIRST` timestep produced by `reset` is returned.

    Note that, unlike the generic `dm_env` contract, this wrapper does not start a new
    sequence after a `TimeStep` with `StepType.LAST`. The wrapped environment is
    stepped with `action` regardless, so call `reset` explicitly to begin a new
    episode.

    Args:
        action: A NumPy array, or a nested dict, list or tuple of arrays
            corresponding to `action_spec`.

    Returns:
        A `TimeStep` namedtuple containing:
            - step_type: the step type of this timestep.
            - reward: the reward reported by the wrapped environment, or `None`
                when this call performed the initial reset.
            - discount: the discount reported by the wrapped environment, or
                `None` when this call performed the initial reset.
            - observation: A NumPy array, or a nested dict, list or tuple of arrays.
                It must conform to the specification returned by `observation_spec`.
    """
    # `__init__` only declares `self._state` (no assignment); `reset` is what assigns
    # it. Its absence therefore means no sequence has been started yet.
    if not hasattr(self, "_state"):
        return self.reset()
    self._state, timestep = self._jitted_step(self._state, action)
    return dm_env.TimeStep(
        step_type=timestep.step_type,
        reward=timestep.reward,
        discount=timestep.discount,
        observation=timestep.observation,
    )

JumanjiToGymWrapper(env, seed=0, backend=None) #

Bases: Env, Generic[State, ActionSpec, Observation]

A wrapper that converts a Jumanji Environment to one that follows the gym.Env API.

Create the Gym environment.

Parameters:

Name Type Description Default
env Environment[State, ActionSpec, Observation]

Environment to wrap to a gym.Env.

required
seed int

the seed that is used to initialize the environment's PRNG.

0
backend Optional[str]

the XLA backend.

None
Source code in jumanji/wrappers.py
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
def __init__(
    self,
    env: Environment[State, ActionSpec, Observation],
    seed: int = 0,
    backend: Optional[str] = None,
):
    """Create the Gym environment.

    Args:
        env: `Environment` to wrap to a `gym.Env`.
        seed: the seed that is used to initialize the environment's PRNG.
        backend: the XLA backend passed to `jax.jit` for the reset and step functions.
    """
    self._env = env
    self.metadata: Dict[str, str] = {}
    self._key = jax.random.PRNGKey(seed)
    self.backend = backend
    # `None` until `reset` is called.
    self._state = None
    self.observation_space = specs.jumanji_specs_to_gym_spaces(self._env.observation_spec)
    self.action_space = specs.jumanji_specs_to_gym_spaces(self._env.action_spec)

    def reset(key: chex.PRNGKey) -> Tuple[State, Observation, Optional[Dict]]:
        """Reset function of a Jumanji environment to be jitted."""
        state, timestep = self._env.reset(key)
        return state, timestep.observation, timestep.extras

    self._reset = jax.jit(reset, backend=self.backend)

    def step(
        state: State, action: chex.Array
    ) -> Tuple[State, Observation, chex.Array, chex.Array, chex.Array, Optional[Any]]:
        """Step function of a Jumanji environment to be jitted.

        Returns `(state, observation, reward, terminated, truncated, extras)`:
        - `terminated` marks a true terminal state (zero discount).
        - `truncated` marks an episode that ended without reaching one (last step
          with a non-zero discount).
        The two flags are therefore never set at the same time.
        """
        state, timestep = self._env.step(state, action)
        term = ~timestep.discount.astype(bool)
        # Gymnasium's `truncated` means the episode was cut short outside the MDP
        # (e.g. by a time limit). Masking with `~term` keeps genuine terminations
        # from also being reported as truncations.
        trunc = timestep.last().astype(bool) & ~term
        return state, timestep.observation, timestep.reward, term, trunc, timestep.extras

    self._step = jax.jit(step, backend=self.backend)

close() #

Closes the environment, important for rendering where pygame is imported.

Source code in jumanji/wrappers.py
713
714
715
def close(self) -> None:
    """Close the wrapped Jumanji environment (important when rendering imports pygame).

    Whatever the wrapped environment's `close` returns is discarded.
    """
    wrapped = self._env
    wrapped.close()

render(mode='human') #

Renders the environment.

Parameters:

Name Type Description Default
mode str

currently not used since Jumanji does not currently support modes.

'human'
Source code in jumanji/wrappers.py
702
703
704
705
706
707
708
709
710
711
def render(self, mode: str = "human") -> Any:
    """Render the current state of the wrapped environment.

    Args:
        mode: accepted for Gym API compatibility but ignored, since Jumanji does not
            currently support render modes.

    Raises:
        ValueError: if no state is stored yet (e.g. `reset` has not been called).
    """
    del mode  # Jumanji renderers take no mode argument.
    current_state = self._state
    if current_state is None:
        raise ValueError("Cannot render when _state is None.")
    return self._env.render(current_state)

reset(*, seed=None, options=None) #

Resets the environment to an initial state by starting a new sequence and returns the first Observation of this sequence.

Returns:

Name Type Description
obs GymObservation

an element of the environment's observation_space.

info optional

contains supplementary information such as metrics.

Source code in jumanji/wrappers.py
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
def reset(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> Tuple[GymObservation, Dict[str, Any]]:
    """Start a new sequence and return its first observation together with info.

    Args:
        seed: if given, the wrapper's PRNG key is re-seeded via `self.seed` first.
        options: accepted for Gym API compatibility; not used.

    Returns:
        obs: an element of the environment's observation_space.
        info: the extras of the first timestep, fetched with `jax.device_get`.
    """
    if seed is not None:
        self.seed(seed)
    reset_key, self._key = jax.random.split(self._key)
    self._state, raw_obs, extras = self._reset(reset_key)
    # Convert the observation to a NumPy array (or a nested dict of them) for Gym.
    return jumanji_to_gym_obs(raw_obs), jax.device_get(extras)

seed(seed=0) #

Function which sets the seed for the environment's random number generator(s).

Parameters:

Name Type Description Default
seed int

the seed value for the random number generator(s).

0
Source code in jumanji/wrappers.py
694
695
696
697
698
699
700
def seed(self, seed: int = 0) -> None:
    """Replace the wrapper's PRNG key with one derived from `seed`.

    Args:
        seed: integer seed passed to `jax.random.PRNGKey`.
    """
    fresh_key = jax.random.PRNGKey(seed)
    self._key = fresh_key

step(action) #

Updates the environment according to the action and returns an Observation.

Parameters:

Name Type Description Default
action ArrayNumpy

A NumPy array representing the action provided by the agent.

required

Returns:

Name Type Description
observation GymObservation

an element of the environment's observation_space.

reward float

the amount of reward returned as a result of taking the action.

terminated bool

whether a terminal state is reached.

truncated bool

whether the episode was truncated.

info Dict[str, Any]

contains supplementary information such as metrics.

Source code in jumanji/wrappers.py
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
def step(
    self, action: chex.ArrayNumpy
) -> Tuple[GymObservation, float, bool, bool, Dict[str, Any]]:
    """Updates the environment according to the action and returns an `Observation`.

    Args:
        action: A NumPy array representing the action provided by the agent.

    Returns:
        observation: an element of the environment's observation_space.
        reward: the amount of reward returned as a result of taking the action.
        terminated: whether a terminal state is reached.
        truncated: whether the episode was truncated (see the jitted step function
            built in `__init__`).
        info: contains supplementary information such as metrics.

    Raises:
        ValueError: if there is no state to step from (e.g. `reset` was not called).
    """
    # Like `render`, fail with a clear message instead of passing `None` to the
    # jitted step function.
    if self._state is None:
        raise ValueError("Cannot step when _state is None.")

    action_jax = jnp.asarray(action)  # Convert input numpy array to JAX array
    self._state, obs, reward, term, trunc, extras = self._step(self._state, action_jax)

    # Convert to plain Python / NumPy values to match the Gym signature.
    obs = jumanji_to_gym_obs(obs)
    reward = float(reward)
    terminated = bool(term)
    truncated = bool(trunc)
    info = jax.device_get(extras)

    return obs, reward, terminated, truncated, info

MultiToSingleWrapper(env, reward_aggregator=jnp.sum, discount_aggregator=jnp.max) #

Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]

A wrapper that converts a multi-agent Environment to a single-agent Environment.

Create the wrapped environment.

Parameters:

Name Type Description Default
env Environment[State, ActionSpec, Observation]

The multi-agent Environment to wrap.

required
reward_aggregator Callable

a function to aggregate all agents rewards into a single scalar value, e.g. sum.

sum
discount_aggregator Callable

a function to aggregate all agents discounts into a single scalar value, e.g. max.

max
Source code in jumanji/wrappers.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
def __init__(
    self,
    env: Environment[State, ActionSpec, Observation],
    reward_aggregator: Callable = jnp.sum,
    discount_aggregator: Callable = jnp.max,
):
    """Wrap a multi-agent environment so that it exposes a single-agent interface.

    Args:
        env: the multi-agent `Environment` to wrap.
        reward_aggregator: function reducing all agents' rewards to a single scalar
            (default: `jnp.sum`).
        discount_aggregator: function reducing all agents' discounts to a single
            scalar (default: `jnp.max`).
    """
    super().__init__(env)
    self._reward_aggregator, self._discount_aggregator = (
        reward_aggregator,
        discount_aggregator,
    )

discount_spec cached property #

Scalar discount spec matching the aggregated output.

reward_spec cached property #

Scalar reward spec matching the aggregated output.

reset(key) #

Resets the environment to an initial state.

Parameters:

Name Type Description Default
key PRNGKey

random key used to reset the environment.

required

Returns:

Name Type Description
state State

State object corresponding to the new state of the environment,

timestep TimeStep[Observation]

TimeStep object corresponding to the first timestep returned by the environment.

Source code in jumanji/wrappers.py
264
265
266
267
268
269
270
271
272
273
274
275
276
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Reset the wrapped multi-agent environment and aggregate its first timestep.

    Args:
        key: random key forwarded to the wrapped environment's `reset`.

    Returns:
        state: State object corresponding to the new state of the environment.
        timestep: the first timestep, after passing through `_aggregate_timestep`.
    """
    initial_state, raw_timestep = self._env.reset(key)
    return initial_state, self._aggregate_timestep(raw_timestep)

step(state, action) #

Run one timestep of the environment's dynamics.

The rewards are aggregated into a single value using the given reward aggregator, and the discounts using the given discount aggregator. With the default discount aggregator (max), the discount is the largest discount of all the agents, so it is non-zero as long as any single agent is still alive.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environment.

required
action Array

Array containing the action to take.

required

Returns:

Name Type Description
state State

State object corresponding to the next state of the environment,

timestep TimeStep[Observation]

TimeStep object corresponding to the timestep returned by the environment.

Source code in jumanji/wrappers.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Step the wrapped multi-agent environment and aggregate the resulting timestep.

    Rewards are combined with the configured reward aggregator and discounts with the
    configured discount aggregator. With the default (max), the discount stays
    non-zero while any single agent is still alive.

    Args:
        state: State object containing the dynamics of the environment.
        action: Array containing the action to take.

    Returns:
        state: State object corresponding to the next state of the environment.
        timestep: the resulting timestep, after passing through `_aggregate_timestep`.
    """
    next_state, raw_timestep = self._env.step(state, action)
    return next_state, self._aggregate_timestep(raw_timestep)

VmapAutoResetWrapper(env, next_obs_in_extras=False) #

Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]

Efficient combination of VmapWrapper and AutoResetWrapper, to be used as a replacement for the combination of both wrappers. env = VmapAutoResetWrapper(env) is equivalent to env = VmapWrapper(AutoResetWrapper(env)) but is more efficient, as it parallelizes homogeneous computation and does not run branches of the computational graph that are not needed (heterogeneous computation). - Homogeneous computation: call the step function on all environments in the batch. - Heterogeneous computation: conditional auto-reset (call the reset function only for the environments within the batch that have terminated). NOTE: When next_obs_in_extras is True, the observation from the terminal TimeStep is stored in timestep.extras["next_obs"].

Wrap an environment to vmap it and automatically reset it when the episode terminates.

Parameters:

Name Type Description Default
env Environment[State, ActionSpec, Observation]

the environment to wrap.

required
next_obs_in_extras bool

whether to store the next observation in the extras of the terminal timestep. This is useful for e.g. truncation.

False
Source code in jumanji/wrappers.py
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
def __init__(
    self,
    env: Environment[State, ActionSpec, Observation],
    next_obs_in_extras: bool = False,
):
    """Wrap `env` so that it is vmapped and reset automatically when episodes end.

    Args:
        env: the environment to wrap.
        next_obs_in_extras: if True, timesteps are passed through `add_obs_to_extras`
            so that the next observation is kept in the extras of the terminal
            timestep (useful e.g. for truncation). Otherwise timesteps pass through
            unchanged.
    """
    super().__init__(env)
    self.next_obs_in_extras = next_obs_in_extras
    self._maybe_add_obs_to_extras = (
        add_obs_to_extras if next_obs_in_extras else (lambda timestep: timestep)
    )

render(state) #

Render the first environment state of the given batch. The remaining elements of the batched state are ignored.

Parameters:

Name Type Description Default
state State

State object containing the current dynamics of the environment.

required
Source code in jumanji/wrappers.py
591
592
593
594
595
596
597
598
599
def render(self, state: State) -> Any:
    """Render only the first environment of a batched state.

    Element 0 is selected with `tree_utils.tree_slice` and rendered by the parent
    class; the rest of the batch is ignored.

    Args:
        state: batched State object of the vectorized environments.
    """
    return super().render(tree_utils.tree_slice(state, 0))

reset(key) #

Resets a batch of environments to initial states.

The first dimension of the key will dictate the number of concurrent environments.

To obtain a key with the right first dimension, you may call jax.random.split on key with the parameter num representing the number of concurrent environments.

Parameters:

Name Type Description Default
key PRNGKey

random keys used to reset the environments where the first dimension is the number of desired environments.

required

Returns:

Name Type Description
state State

State object corresponding to the new state of the environments,

timestep TimeStep[Observation]

TimeStep object corresponding to the first timesteps returned by the environments.

Source code in jumanji/wrappers.py
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Reset a batch of environments, one per key.

    The leading dimension of `key` sets the number of concurrent environments. Such a
    batch can be built with `jax.random.split(key, num=num_envs)`.

    Args:
        key: stacked random keys, one per environment to reset.

    Returns:
        state: the batched initial states.
        timestep: the batched first timesteps, after `_maybe_add_obs_to_extras`.
    """
    batched_reset = jax.vmap(self._env.reset)
    state, timestep = batched_reset(key)
    return state, self._maybe_add_obs_to_extras(timestep)

step(state, action) #

Run one timestep of all environments' dynamics. It automatically resets environment(s) in which episodes have terminated.

The first dimension of the state will dictate the number of concurrent environments.

See VmapAutoResetWrapper.reset for more details on how to get a state of concurrent environments.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environments.

required
action Array

Array containing the actions to take.

required

Returns:

Name Type Description
state State

State object corresponding to the next states of the environments.

timestep TimeStep[Observation]

TimeStep object corresponding to the timesteps returned by the environments.

Source code in jumanji/wrappers.py
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Step every environment in the batch, auto-resetting those whose episode ended.

    The leading dimension of `state` sets the number of concurrent environments (see
    `VmapAutoResetWrapper.reset` for how to build such a state).

    Args:
        state: batched State object of the environments.
        action: batched actions, one per environment.

    Returns:
        state: the batched next states.
        timestep: the batched timesteps.
    """
    # Homogeneous part: all environments step together through `jax.vmap`.
    stepped_state, stepped_timestep = jax.vmap(self._env.step)(state, action)

    def _reset_if_done(pair):
        env_state, env_timestep = pair
        return self._maybe_reset(env_state, env_timestep)

    # Heterogeneous part: `jax.lax.map` visits the environments one at a time, so
    # `_maybe_reset` can branch independently for each of them.
    new_state, new_timestep = jax.lax.map(_reset_if_done, (stepped_state, stepped_timestep))
    return new_state, new_timestep

VmapWrapper(env) #

Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]

Vectorized Jax env. Please note that all methods that return arrays do not return a batch dimension because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include: - observation_spec - action_spec - reward_spec - discount_spec

Source code in jumanji/wrappers.py
39
40
41
def __init__(self, env: Environment[State, ActionSpec, Observation]):
    """Store the environment to wrap, then run the next initializer in the MRO.

    Args:
        env: the environment this wrapper forwards to.
    """
    # NOTE(review): `self._env` is set before `super().__init__()`, presumably so that
    # anything the base initializer touches can already reach the wrapped env;
    # confirm before reordering.
    self._env = env
    super().__init__()

render(state) #

Render the first environment state of the given batch. The remaining elements of the batched state are ignored.

Parameters:

Name Type Description Default
state State

State object containing the current dynamics of the environment.

required
Source code in jumanji/wrappers.py
365
366
367
368
369
370
371
372
373
def render(self, state: State) -> Any:
    """Render the environment at index 0 of a batched state.

    Only the element selected by `tree_utils.tree_slice(state, 0)` is passed to the
    parent class's `render`; the remaining batch elements are ignored.

    Args:
        state: batched State object of the vectorized environments.
    """
    leading_state = tree_utils.tree_slice(state, 0)
    rendered = super().render(leading_state)
    return rendered

reset(key) #

Resets the environment to an initial state.

The first dimension of the key will dictate the number of concurrent environments.

To obtain a key with the right first dimension, you may call jax.random.split on key with the parameter num representing the number of concurrent environments.

Parameters:

Name Type Description Default
key PRNGKey

random keys used to reset the environments where the first dimension is the number of desired environments.

required

Returns:

Name Type Description
state State

State object corresponding to the new state of the environments,

timestep TimeStep[Observation]

TimeStep object corresponding to the first timesteps returned by the environments.

Source code in jumanji/wrappers.py
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Reset a batch of environments by vmapping the wrapped `reset` over `key`.

    The leading dimension of `key` sets the number of concurrent environments. Such a
    batch can be built with `jax.random.split(key, num=num_envs)`.

    Args:
        key: stacked random keys, one per environment to reset.

    Returns:
        state: the batched initial states.
        timestep: the batched first timesteps.
    """
    batched_reset = jax.vmap(self._env.reset)
    initial_states, first_timesteps = batched_reset(key)
    return initial_states, first_timesteps

step(state, action) #

Run one timestep of the environment's dynamics.

The first dimension of the state will dictate the number of concurrent environments.

See VmapWrapper.reset for more details on how to get a state of concurrent environments.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environments.

required
action Array

Array containing the actions to take.

required

Returns:

Name Type Description
state State

State object corresponding to the next states of the environments,

timestep TimeStep[Observation]

TimeStep object corresponding to the timesteps returned by the environments.

Source code in jumanji/wrappers.py
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Step a batch of environments by vmapping the wrapped `step`.

    The leading dimension of `state` sets the number of concurrent environments (see
    `VmapWrapper.reset` for how to build such a state).

    Args:
        state: batched State object of the environments.
        action: batched actions, one per environment.

    Returns:
        state: the batched next states.
        timestep: the batched timesteps.
    """
    batched_step = jax.vmap(self._env.step)
    next_states, timesteps = batched_step(state, action)
    return next_states, timesteps

Wrapper(env) #

Bases: Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]

Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72

Source code in jumanji/wrappers.py
39
40
41
def __init__(self, env: Environment[State, ActionSpec, Observation]):
    """Remember the wrapped environment and initialize the base `Environment`.

    Args:
        env: the environment whose behavior this wrapper forwards to or transforms.
    """
    # NOTE(review): assignment kept ahead of `super().__init__()`; presumably the base
    # initializer may rely on `self._env` being set — confirm before reordering.
    self._env = env
    super().__init__()

action_spec cached property #

Returns the action spec.

discount_spec cached property #

Returns the discount spec.

observation_spec cached property #

Returns the observation spec.

reward_spec cached property #

Returns the reward spec.

unwrapped property #

Returns the wrapped env.

close() #

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

Source code in jumanji/wrappers.py
109
110
111
112
113
114
115
def close(self) -> None:
    """Perform any necessary cleanup by closing the wrapped environment.

    Environments will automatically `close()` themselves when garbage collected or
    when the program exits. The wrapped environment's return value is passed through.
    """
    inner = self._env
    return inner.close()

render(state) #

Render the given state by delegating to the wrapped environment's render method.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environment.

required
Source code in jumanji/wrappers.py
101
102
103
104
105
106
107
def render(self, state: State) -> Any:
    """Render ``state`` by delegating to the wrapped environment.

    Args:
        state: State object containing the dynamics of the environment.

    Returns:
        Whatever the wrapped environment's ``render`` returns for ``state``.
    """
    inner_env = self._env
    frames = inner_env.render(state)
    return frames

reset(key) #

Resets the environment to an initial state.

Parameters:

Name Type Description Default
key PRNGKey

random key used to reset the environment.

required

Returns:

Name Type Description
state State

State object corresponding to the new state of the environment.

timestep TimeStep[Observation]

TimeStep object corresponding to the first timestep returned by the environment.

Source code in jumanji/wrappers.py
56
57
58
59
60
61
62
63
64
65
66
def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
    """Reset the wrapped environment to an initial state.

    Args:
        key: random key, passed as-is to the wrapped environment's ``reset``.

    Returns:
        The ``(state, timestep)`` pair produced by the wrapped environment's
        ``reset``, returned unchanged.
    """
    inner_env = self._env
    initial = inner_env.reset(key)
    return initial

step(state, action) #

Run one timestep of the environment's dynamics.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environment.

required
action Array

Array containing the action to take.

required

Returns:

Name Type Description
state State

State object corresponding to the next state of the environment.

timestep TimeStep[Observation]

TimeStep object corresponding to the timestep returned by the environment.

Source code in jumanji/wrappers.py
68
69
70
71
72
73
74
75
76
77
78
79
def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
    """Advance the wrapped environment by one timestep.

    Args:
        state: State object containing the dynamics of the environment.
        action: Array containing the action to take.

    Returns:
        The ``(state, timestep)`` pair produced by the wrapped environment's
        ``step``, returned unchanged.
    """
    inner_env = self._env
    transition = inner_env.step(state, action)
    return transition

add_obs_to_extras(timestep) #

Place the observation in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]. Used when auto-resetting to store the observation from the terminal TimeStep (useful for e.g. truncation).

Parameters:

Name Type Description Default
timestep TimeStep[Observation]

TimeStep object containing the timestep returned by the environment.

required

Returns:

Type Description
TimeStep[Observation]

timestep where the observation is placed in timestep.extras["next_obs"].

Source code in jumanji/wrappers.py
379
380
381
382
383
384
385
386
387
388
389
390
391
392
def add_obs_to_extras(timestep: TimeStep[Observation]) -> TimeStep[Observation]:
    """Place the observation in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS].
    Used when auto-resetting to store the observation from the terminal TimeStep (useful for
    e.g. truncation).

    The input ``timestep`` is not modified: its extras mapping is shallow-copied before the
    observation is added, so the caller's TimeStep (and anything else holding a reference to
    the same extras dict) does not silently gain the extra key.

    Args:
        timestep: TimeStep object containing the timestep returned by the environment.
            If ``timestep.extras`` is ``None`` it is treated as an empty dict.

    Returns:
        A copy of ``timestep`` whose extras hold every original entry plus the observation
        under ``timestep.extras["next_obs"]``.
    """
    import copy

    # Copy instead of mutating in place so the function has no side effects on its input;
    # `copy.copy` keeps the original mapping type, so the pytree node type is unchanged.
    extras = {} if timestep.extras is None else copy.copy(timestep.extras)
    extras[NEXT_OBS_KEY_IN_EXTRAS] = timestep.observation
    return timestep.replace(extras=extras)  # type: ignore

jumanji_to_gym_obs(observation) #

Convert a Jumanji observation into a gym observation.

Parameters:

Name Type Description Default
observation Observation

JAX pytree with (possibly nested) containers that implement either the `__dict__` attribute or the `_asdict` method.

required

Returns:

Type Description
GymObservation

Numpy array or nested dictionary of numpy arrays.

Source code in jumanji/wrappers.py
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
def jumanji_to_gym_obs(observation: Observation) -> GymObservation:
    """Convert a Jumanji observation into a gym observation.

    Array leaves (JAX arrays, NumPy arrays and NumPy scalars) become NumPy arrays;
    containers are converted recursively into (nested) dictionaries.

    Args:
        observation: JAX pytree with (possibly nested) containers that either have the
            `__dict__` attribute or the `_asdict` method, or are mappings such as
            plain `dict`s.

    Returns:
        Numpy array or nested dictionary of numpy arrays.

    Raises:
        NotImplementedError: if `observation` (or any nested value) is neither a
            supported array type nor a supported container.
    """
    from collections.abc import Mapping

    if isinstance(observation, (jnp.ndarray, np.ndarray, np.generic)):
        # NumPy arrays/scalars are not `jnp.ndarray` instances and have no `__dict__`,
        # so without listing them here they would fall through to the error branch.
        return np.asarray(observation)
    elif hasattr(observation, "__dict__"):
        # Applies to various containers including `chex.dataclass`
        return {key: jumanji_to_gym_obs(value) for key, value in vars(observation).items()}
    elif hasattr(observation, "_asdict"):
        # Applies to `NamedTuple` container.
        return {
            key: jumanji_to_gym_obs(value)
            for key, value in observation._asdict().items()  # type: ignore
        }
    elif isinstance(observation, Mapping):
        # Plain dicts (and other mappings) have neither `__dict__` nor `_asdict`.
        # Checked last so types matching an earlier branch keep their previous behaviour.
        return {key: jumanji_to_gym_obs(value) for key, value in observation.items()}
    else:
        raise NotImplementedError(
            "Conversion only implemented for JAX pytrees with (possibly nested) containers "
            "that either have the `__dict__` or `_asdict` methods implemented."
        )