Wrappers

wrappers #

Wrapper (Environment, Generic) #

Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72

unwrapped: Environment[State, ActionSpec, Observation] property readonly #

Returns the wrapped env.

observation_spec: specs.Spec[Observation] cached property writable #

Returns the observation spec.

action_spec: ActionSpec cached property writable #

Returns the action spec.

reward_spec: specs.Array cached property writable #

Returns the reward spec.

discount_spec: specs.BoundedArray cached property writable #

Returns the discount spec.

__init__(self, env: Environment[State, ActionSpec, Observation]) special #

reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]] #

Resets the environment to an initial state.

Parameters:

Name Type Description Default
key chex.PRNGKey

random key used to reset the environment.

required

Returns:

Type Description
state

State object corresponding to the new state of the environment.

timestep

TimeStep object corresponding to the first timestep returned by the environment.

step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]] #

Run one timestep of the environment's dynamics.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environment.

required
action chex.Array

Array containing the action to take.

required

Returns:

Type Description
state

State object corresponding to the next state of the environment.

timestep

TimeStep object corresponding to the timestep returned by the environment.

render(self, state: State) -> Any #

Compute render frames during initialisation of the environment.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environment.

required

close(self) -> None #

Perform any necessary cleanup.

Environments will automatically close() themselves when garbage collected or when the program exits.

__enter__(self) -> Wrapper special #

__exit__(self, *args: Any) -> None special #
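
The delegation pattern behind Wrapper can be sketched in plain Python. This is a minimal illustration, not Jumanji's implementation: `CounterEnv`, `ToyWrapper`, and `RewardScaleWrapper` are hypothetical names, and the toy `step` returns a bare `(state, reward)` pair instead of a `TimeStep`.

```python
class CounterEnv:
    """Hypothetical functional environment: the state is an integer counter."""

    def reset(self, key):
        # The key is unused in this toy; a real environment would use it for randomness.
        return 0, 0.0

    def step(self, state, action):
        next_state = state + action
        reward = float(action)
        return next_state, reward


class ToyWrapper:
    """Forwards every call to the wrapped environment unless a subclass overrides it."""

    def __init__(self, env):
        self._env = env

    @property
    def unwrapped(self):
        # Follow the chain of wrappers down to the innermost environment.
        return getattr(self._env, "unwrapped", self._env)

    def reset(self, key):
        return self._env.reset(key)

    def step(self, state, action):
        return self._env.step(state, action)


class RewardScaleWrapper(ToyWrapper):
    """Example transformation: multiply every reward by a constant factor."""

    def __init__(self, env, scale):
        super().__init__(env)
        self._scale = scale

    def step(self, state, action):
        next_state, reward = self._env.step(state, action)
        return next_state, reward * self._scale


base_env = CounterEnv()
env = RewardScaleWrapper(ToyWrapper(base_env), scale=10.0)
state, _ = env.reset(key=None)
state, reward = env.step(state, action=2)
```

Because each wrapper only overrides the methods it changes, wrappers can be stacked in any order while `unwrapped` still reaches the innermost environment.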

JumanjiToDMEnvWrapper (Environment, Generic) #

A wrapper that converts Environment to dm_env.Environment.

unwrapped: Environment[State, ActionSpec, Observation] property readonly #

__init__(self, env: Environment[State, ActionSpec, Observation], key: Optional[chex.PRNGKey] = None) special #

Create the wrapped environment.

Parameters:

Name Type Description Default
env Environment[State, ActionSpec, Observation]

Environment to wrap to a dm_env.Environment.

required
key Optional[chex.PRNGKey]

optional key to initialize the Environment with.

None

reset(self) -> dm_env.TimeStep #

Starts a new sequence and returns the first TimeStep of this sequence.

Returns:

Type Description
A `TimeStep` namedtuple containing
  • step_type: A StepType of FIRST.
  • reward: None, indicating the reward is undefined.
  • discount: None, indicating the discount is undefined.
  • observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the specification returned by observation_spec.

step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep #

Updates the environment according to the action and returns a TimeStep.

If the environment returned a TimeStep with StepType.LAST at the previous step, this call to step will start a new sequence and action will be ignored.

This method will also start a new sequence if called after the environment has been constructed and reset has not been called. Again, in this case action will be ignored.

Parameters:

Name Type Description Default
action chex.ArrayNumpy

A NumPy array, or a nested dict, list or tuple of arrays corresponding to action_spec.

required

Returns:

Type Description
A `TimeStep` namedtuple containing
  • step_type: A StepType value.
  • reward: Reward at this timestep, or None if step_type is StepType.FIRST. Must conform to the specification returned by reward_spec.
  • discount: A discount in the range [0, 1], or None if step_type is StepType.FIRST. Must conform to the specification returned by discount_spec.
  • observation: A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array. Must conform to the specification returned by observation_spec.
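
The restart behaviour described for step (a new sequence starts after a LAST timestep, or when reset has never been called) can be sketched as follows. This is a plain-Python illustration under stated assumptions, not the wrapper's actual code: `CountdownEnv` and `StatefulAdapter` are hypothetical, and the string step types stand in for `dm_env.StepType`.

```python
class CountdownEnv:
    """Hypothetical functional environment that terminates after `horizon` steps."""

    def __init__(self, horizon=2):
        self.horizon = horizon

    def reset(self):
        state = 0
        return state, {"step_type": "FIRST", "reward": None, "observation": state}

    def step(self, state, action):
        state = state + 1
        step_type = "LAST" if state >= self.horizon else "MID"
        return state, {"step_type": step_type, "reward": float(action), "observation": state}


class StatefulAdapter:
    """Keeps the state internally and restarts the sequence when required."""

    def __init__(self, env):
        self._env = env
        self._state = None
        self._needs_reset = True  # no sequence has been started yet

    def reset(self):
        self._state, timestep = self._env.reset()
        self._needs_reset = False
        return timestep

    def step(self, action):
        if self._needs_reset:
            # The action is ignored: a new sequence starts instead.
            return self.reset()
        self._state, timestep = self._env.step(self._state, action)
        self._needs_reset = timestep["step_type"] == "LAST"
        return timestep


adapter = StatefulAdapter(CountdownEnv(horizon=2))
step_types = [adapter.step(action=1)["step_type"] for _ in range(4)]
```

Here the first call to `step` happens before any `reset`, so it yields a FIRST timestep, and the call that follows the LAST timestep starts a new sequence in the same way.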

observation_spec(self) -> dm_env.specs.Array #

Returns the dm_env observation spec.

action_spec(self) -> dm_env.specs.Array #

Returns the dm_env action spec.

MultiToSingleWrapper (Wrapper, Generic) #

A wrapper that converts a multi-agent Environment to a single-agent Environment.

__init__(self, env: Environment[State, ActionSpec, Observation], reward_aggregator: Callable = <function sum>, discount_aggregator: Callable = <function amax>) special #

Create the wrapped environment.

Parameters:

Name Type Description Default
env Environment[State, ActionSpec, Observation]

Multi-agent Environment to wrap into a single-agent Environment.

required
reward_aggregator Callable

a function to aggregate all agents' rewards into a single scalar value, e.g. sum.

<function sum>
discount_aggregator Callable

a function to aggregate all agents' discounts into a single scalar value, e.g. max.

<function amax>

reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]] #

Resets the environment to an initial state.

Parameters:

Name Type Description Default
key chex.PRNGKey

random key used to reset the environment.

required

Returns:

Type Description
state

State object corresponding to the new state of the environment.

timestep

TimeStep object corresponding to the first timestep returned by the environment.

step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]] #

Run one timestep of the environment's dynamics.

The rewards are aggregated into a single value using the given reward aggregator, and the discounts using the given discount aggregator. With the default max aggregator, the discount is the largest discount across all agents, so it is non-zero as long as at least one agent is still alive.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environment.

required
action chex.Array

Array containing the action to take.

required

Returns:

Type Description
state

State object corresponding to the next state of the environment.

timestep

TimeStep object corresponding to the timestep returned by the environment.
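
The aggregation step can be illustrated with a small worked example. It is only a sketch: the per-agent values are made up, `aggregate` is a hypothetical helper, and Python's built-in `sum` and `max` stand in for whatever array reductions are passed as aggregators.

```python
# Hypothetical per-agent values for a three-agent environment.
rewards = [1.0, 0.5, 0.0]
discounts = [0.0, 1.0, 0.0]  # only the second agent is still alive


def aggregate(values, aggregator):
    """Collapse per-agent values into a single scalar."""
    return aggregator(values)


team_reward = aggregate(rewards, sum)      # 1.0 + 0.5 + 0.0
team_discount = aggregate(discounts, max)  # non-zero because one agent is alive

# Once every agent has terminated, the aggregated discount drops to zero.
final_discount = aggregate([0.0, 0.0, 0.0], max)
```

Choosing `max` for discounts keeps the episode going until the last agent terminates; a `min` aggregator would instead end it as soon as any agent does.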

VmapWrapper (Wrapper, Generic) #

Vectorized JAX env.

Please note that methods returning array specs do not include a batch dimension, because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include:
  • observation_spec
  • action_spec
  • reward_spec
  • discount_spec

reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]] #

Resets the environment to an initial state.

The first dimension of the key will dictate the number of concurrent environments.

To obtain a key with the right first dimension, you may call jax.random.split on key with the parameter num representing the number of concurrent environments.

Parameters:

Name Type Description Default
key chex.PRNGKey

random keys used to reset the environments where the first dimension is the number of desired environments.

required

Returns:

Type Description
state

State object corresponding to the new state of the environments.

timestep

TimeStep object corresponding to the first timesteps returned by the environments.
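
Building a batch of keys with jax.random.split, as described above, might look like the following sketch. It does not use VmapWrapper itself: `toy_reset` is a hypothetical single-environment reset, and `jax.vmap` is applied by hand to show how the leading dimension of the keys determines the number of environments.

```python
import jax

num_envs = 4
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, num=num_envs)  # one key per environment


def toy_reset(single_key):
    """Hypothetical single-environment reset: samples a 2D initial position."""
    return jax.random.uniform(single_key, shape=(2,))


# Mapping the reset over the leading axis of `keys` yields a batched state.
batched_state = jax.vmap(toy_reset)(keys)
```

The resulting `batched_state` has a leading axis of size `num_envs`, which is the shape a batched step would then expect.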

step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]] #

Run one timestep of the environment's dynamics.

The first dimension of the state will dictate the number of concurrent environments.

See VmapWrapper.reset for more details on how to get a state of concurrent environments.

Parameters:

Name Type Description Default
state State

State object containing the dynamics of the environments.

required
action chex.Array

Array containing the actions to take.

required

Returns:

Type Description
state

State object corresponding to the next states of the environments.

timestep

TimeStep object corresponding to the timesteps returned by the environments.

render(self, state: State) -> Any #

Render the first environment state of the given batch. The remaining elements of the batched state are ignored.

Parameters:

Name Type Description Default
state State

State object containing the current dynamics of the environment.

required
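
Selecting the first environment from a batched state can be sketched with a pytree map. This is an assumption about how such a selection could be done, not a description of the wrapper's internals, and the dictionary state below is hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical batched state for 3 environments, stored as a pytree of arrays.
batched_state = {
    "position": jnp.arange(6).reshape(3, 2),
    "step_count": jnp.array([5, 7, 9]),
}

# Index every leaf along its leading (batch) axis to keep only the first environment.
first_state = jax.tree_util.tree_map(lambda leaf: leaf[0], batched_state)
```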

AutoResetWrapper (Wrapper, Generic) #

Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep are set to the reset observation and StepType.LAST, respectively. The reward, discount, and extras are retrieved from the transition to the terminal state.

NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"] (see the next_obs_in_extras parameter).

WARNING: do not jax.vmap the wrapped environment (e.g. do not use it with the VmapWrapper). Doing so leads to inefficient computation, because both the step and reset functions are processed each time step is called. Please use the VmapAutoResetWrapper instead.

__init__(self, env: Environment[State, ActionSpec, Observation], next_obs_in_extras: bool = False) special #

Wrap an environment to automatically reset it when the episode terminates.

Parameters:

Name Type Description Default
env Environment[State, ActionSpec, Observation]

the environment to wrap.

required
next_obs_in_extras bool

whether to store the next observation in the extras of the terminal timestep. This is useful for e.g. truncation.

False

reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]] #

Resets the environment to an initial state.

Parameters:

Name Type Description Default
key chex.PRNGKey

random key used to reset the environment.

required

Returns:

Type Description
state

State object corresponding to the new state of the environment.

timestep

TimeStep object corresponding to the first timestep returned by the environment.

step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]] #

Step the environment, with automatic resetting if the episode terminates.
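
The auto-reset logic described for this wrapper can be sketched in plain Python. This is an illustration under stated assumptions rather than the real implementation: `CountdownEnv` and `auto_reset_step` are hypothetical, timesteps are plain dictionaries, the reset key is passed explicitly, and the terminal observation is always stored in the extras.

```python
class CountdownEnv:
    """Hypothetical environment whose episodes last `horizon` steps."""

    def __init__(self, horizon=2):
        self.horizon = horizon

    def reset(self, key):
        state = 0
        timestep = {"step_type": "FIRST", "reward": 0.0, "discount": 1.0,
                    "observation": state, "extras": {}}
        return state, timestep

    def step(self, state, action):
        state = state + 1
        last = state >= self.horizon
        timestep = {"step_type": "LAST" if last else "MID", "reward": 1.0,
                    "discount": 0.0 if last else 1.0,
                    "observation": state, "extras": {}}
        return state, timestep


def auto_reset_step(env, state, action, key):
    """Step once; on termination, swap in the reset state and observation."""
    state, timestep = env.step(state, action)
    if timestep["step_type"] == "LAST":
        terminal_obs = timestep["observation"]
        state, reset_timestep = env.reset(key)
        # Keep the reward, discount and step_type of the terminal transition,
        # but expose the observation of the freshly reset environment.
        timestep = dict(timestep, observation=reset_timestep["observation"],
                        extras={**timestep["extras"], "next_obs": terminal_obs})
    return state, timestep


env = CountdownEnv(horizon=2)
state, _ = env.reset(key=None)
state, first = auto_reset_step(env, state, action=0, key=None)
state, second = auto_reset_step(env, state, action=0, key=None)
```

After the second call the returned state is already the reset state, while the timestep still reports the terminal reward and discount, and the true terminal observation remains available under `extras["next_obs"]`.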

JumanjiToGymWrapper (Env, Generic) #

A wrapper that converts a Jumanji Environment to one that follows the gym.Env API.

unwrapped: Environment[State, ActionSpec, Observation] property readonly #

Returns the base non-wrapped environment.

Returns:

Type Description
Env

The base non-wrapped gym.Env instance

__init__(self, env: Environment[State, ActionSpec, Observation], seed: int = 0, backend: Optional[str] = None) special #

Create the Gym environment.

Parameters:

Name Type Description Default
env Environment[State, ActionSpec, Observation]

Environment to wrap to a gym.Env.

required
seed int

the seed that is used to initialize the environment's PRNG.

0
backend Optional[str]

the XLA backend.

None

reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None) -> Union[GymObservation, Tuple[GymObservation, Optional[Any]]] #

Resets the environment to an initial state by starting a new sequence and returns the first Observation of this sequence.

Returns:

Type Description
obs

an element of the environment's observation_space.

info (optional)

contains supplementary information such as metrics.

step(self, action: chex.ArrayNumpy) -> Tuple[GymObservation, float, bool, Optional[Any]] #

Updates the environment according to the action and returns an Observation.

Parameters:

Name Type Description Default
action chex.ArrayNumpy

A NumPy array representing the action provided by the agent.

required

Returns:

Type Description
observation

an element of the environment's observation_space.

reward

the amount of reward returned as a result of taking the action.

terminated

whether a terminal state is reached.

info

contains supplementary information such as metrics.

seed(self, seed: int = 0) -> None #

Sets the seed for the environment's random number generator(s).

Parameters:

Name Type Description Default
seed int

the seed value for the random number generator(s).

0

render(self, mode: str = 'human') -> Any #

Renders the environment.

Parameters:

Name Type Description Default
mode str

not currently used, since Jumanji does not support render modes.

'human'

close(self) -> None #

Closes the environment. This is important for rendering, where pygame is imported.

jumanji_to_gym_obs(observation: Observation) -> GymObservation #

Convert a Jumanji observation into a gym observation.

Parameters:

Name Type Description Default
observation Observation

JAX pytree with (possibly nested) containers that have either a __dict__ attribute or an _asdict method.

required

Returns:

Type Description
GymObservation

NumPy array or nested dictionary of NumPy arrays.
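
The kind of recursive conversion described here can be sketched as follows. This is not the library's implementation: `Position`, `ToyObservation`, and `to_nested_dict` are hypothetical, and leaves are returned unchanged rather than converted to NumPy arrays so that the example needs only the standard library.

```python
from dataclasses import dataclass
from typing import Any, NamedTuple


class Position(NamedTuple):
    row: int
    col: int


@dataclass
class ToyObservation:
    position: Position
    action_mask: list


def to_nested_dict(value: Any) -> Any:
    """Recursively turn containers into plain dictionaries; leaves are returned unchanged."""
    if hasattr(value, "_asdict"):  # e.g. NamedTuple instances
        items = value._asdict()
    elif hasattr(value, "__dict__"):  # e.g. dataclass instances
        items = vars(value)
    else:
        return value
    return {name: to_nested_dict(item) for name, item in items.items()}


obs = ToyObservation(position=Position(row=1, col=2), action_mask=[True, False])
converted = to_nested_dict(obs)
```

Checking `_asdict` before `__dict__` matters in this sketch, since named tuples expose their fields through `_asdict` rather than through an instance dictionary.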


Last update: 2024-11-01