# Wrappers

## Wrapper (Environment, Generic)

Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72
### `unwrapped: Environment[State, ActionSpec, Observation]` *(property, read-only)*

Returns the wrapped env.
### `observation_spec: specs.Spec[Observation]` *(cached property, writable)*

Returns the observation spec.
### `action_spec: ActionSpec` *(cached property, writable)*

Returns the action spec.
### `reward_spec: specs.Array` *(cached property, writable)*

Returns the reward spec.
### `discount_spec: specs.BoundedArray` *(cached property, writable)*

Returns the discount spec.
### `__init__(self, env: Environment[State, ActionSpec, Observation])` *(special)*
### `reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]`

Resets the environment to an initial state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | `chex.PRNGKey` | random key used to reset the environment. | required |
Returns:

| Name | Description |
|---|---|
| state | `State` object corresponding to the new state of the environment. |
| timestep | `TimeStep` object corresponding to the first timestep returned by the environment. |
### `step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]`

Run one timestep of the environment's dynamics.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | `State` | `State` object containing the dynamics of the environment. | required |
| action | `chex.Array` | Array containing the action to take. | required |
Returns:

| Name | Description |
|---|---|
| state | `State` object corresponding to the next state of the environment. |
| timestep | `TimeStep` object corresponding to the timestep returned by the environment. |
### `render(self, state: State) -> Any`

Compute render frames during initialisation of the environment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | `State` | `State` object containing the dynamics of the environment. | required |
### `close(self) -> None`

Perform any necessary cleanup. Environments will automatically `close()` themselves when garbage collected or when the program exits.
### `__enter__(self) -> Wrapper` *(special)*

### `__exit__(self, *args: Any) -> None` *(special)*
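
For orientation, here is a minimal sketch of a custom wrapper built on this interface, overriding only `step`. It assumes the wrapped environment is stored as `self._env` and that `TimeStep` supports a dataclass-style `replace`; the `RewardScalingWrapper` name and the `Snake-v1` environment ID are illustrative.

```python
import chex
import jax
import jumanji
from jumanji.wrappers import Wrapper


class RewardScalingWrapper(Wrapper):
    """Hypothetical wrapper that multiplies every reward by a constant factor."""

    def __init__(self, env, scale: float = 0.1):
        super().__init__(env)
        self._scale = scale

    def step(self, state, action: chex.Array):
        # Delegate to the wrapped environment, then transform the reward.
        # Assumes the wrapped env is stored as `self._env` and that TimeStep
        # is a dataclass exposing `replace`.
        state, timestep = self._env.step(state, action)
        return state, timestep.replace(reward=timestep.reward * self._scale)


env = RewardScalingWrapper(jumanji.make("Snake-v1"), scale=0.1)
with env:  # `__enter__`/`__exit__` allow use as a context manager
    state, timestep = env.reset(jax.random.PRNGKey(0))
```
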
## JumanjiToDMEnvWrapper (Environment, Generic)

A wrapper that converts a Jumanji `Environment` to a `dm_env.Environment`.
### `unwrapped: Environment[State, ActionSpec, Observation]` *(property, read-only)*
### `__init__(self, env: Environment[State, ActionSpec, Observation], key: Optional[chex.PRNGKey] = None)` *(special)*

Create the wrapped environment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | `Environment[State, ActionSpec, Observation]` | the environment to wrap. | required |
| key | `Optional[chex.PRNGKey]` | optional key to initialize the environment with. | None |
### `reset(self) -> dm_env.TimeStep`

Starts a new sequence and returns the first `TimeStep` of this sequence.

Returns:

| Type | Description |
|---|---|
| `dm_env.TimeStep` | A `TimeStep` namedtuple containing the first step's step type, reward, discount, and observation. |
### `step(self, action: chex.ArrayNumpy) -> dm_env.TimeStep`

Updates the environment according to the action and returns a `TimeStep`.

If the environment returned a `TimeStep` with `StepType.LAST` at the previous step, this call to `step` will start a new sequence and `action` will be ignored.

This method will also start a new sequence if called after the environment has been constructed and `reset` has not been called. Again, in this case `action` will be ignored.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| action | `chex.ArrayNumpy` | A NumPy array, or a nested dict, list or tuple of arrays corresponding to `action_spec()`. | required |

Returns:

| Type | Description |
|---|---|
| `dm_env.TimeStep` | A `TimeStep` namedtuple containing the step type, reward, discount, and observation for this step. |
### `observation_spec(self) -> dm_env.specs.Array`

Returns the dm_env observation spec.

### `action_spec(self) -> dm_env.specs.Array`

Returns the dm_env action spec.
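
A usage sketch: converting a Jumanji environment to the `dm_env` API and taking one step. The `Snake-v1` ID is illustrative, and the dummy action relies on `dm_env` specs providing `generate_value()`.

```python
import jax
import jumanji
from jumanji.wrappers import JumanjiToDMEnvWrapper

# Expose a Jumanji environment through the dm_env.Environment API.
env = JumanjiToDMEnvWrapper(jumanji.make("Snake-v1"), key=jax.random.PRNGKey(0))

timestep = env.reset()                       # dm_env.TimeStep with StepType.FIRST
action = env.action_spec().generate_value()  # dummy action conforming to the spec
timestep = env.step(action)
```
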
## MultiToSingleWrapper (Wrapper, Generic)

A wrapper that converts a multi-agent Environment to a single-agent Environment.
### `__init__(self, env: Environment[State, ActionSpec, Observation], reward_aggregator: Callable = jnp.sum, discount_aggregator: Callable = jnp.max)` *(special)*

Create the wrapped environment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | `Environment[State, ActionSpec, Observation]` | the environment to wrap. | required |
| reward_aggregator | `Callable` | a function to aggregate all agents' rewards into a single scalar value, e.g. sum. | `jnp.sum` |
| discount_aggregator | `Callable` | a function to aggregate all agents' discounts into a single scalar value, e.g. max. | `jnp.max` |
### `reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]`

Resets the environment to an initial state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | `chex.PRNGKey` | random key used to reset the environment. | required |
Returns:

| Name | Description |
|---|---|
| state | `State` object corresponding to the new state of the environment. |
| timestep | `TimeStep` object corresponding to the first timestep returned by the environment. |
### `step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]`

Run one timestep of the environment's dynamics.

The rewards are aggregated into a single value based on the given reward aggregator. The discount value is set to the largest discount of all the agents. This essentially means that if any single agent is alive, the discount value won't be zero.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | `State` | `State` object containing the dynamics of the environment. | required |
| action | `chex.Array` | Array containing the action to take. | required |
Returns:

| Name | Description |
|---|---|
| state | `State` object corresponding to the next state of the environment. |
| timestep | `TimeStep` object corresponding to the timestep returned by the environment. |
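
A usage sketch, under the assumption that the defaults are `jnp.sum` and `jnp.max`; the multi-agent environment ID below is illustrative.

```python
import jax
import jax.numpy as jnp
import jumanji
from jumanji.wrappers import MultiToSingleWrapper

env = MultiToSingleWrapper(
    jumanji.make("RobotWarehouse-v0"),  # illustrative multi-agent env ID
    reward_aggregator=jnp.sum,    # total team reward
    discount_aggregator=jnp.max,  # episode continues while any agent is alive
)
state, timestep = env.reset(jax.random.PRNGKey(0))  # timestep.reward is a single value
```
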
## VmapWrapper (Wrapper, Generic)

Vectorized JAX env. Please note that all methods that return arrays do not return a batch dimension, because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include:

- `observation_spec`
- `action_spec`
- `reward_spec`
- `discount_spec`
### `reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]`

Resets the environment to an initial state.

The first dimension of the key will dictate the number of concurrent environments. To obtain a key with the right first dimension, you may call `jax.random.split` on `key` with the parameter `num` representing the number of concurrent environments.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | `chex.PRNGKey` | random keys used to reset the environments, where the first dimension is the number of desired environments. | required |
Returns:

| Name | Description |
|---|---|
| state | `State` object corresponding to the new states of the environments. |
| timestep | `TimeStep` object corresponding to the first timesteps returned by the environments. |
### `step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]`

Run one timestep of the environment's dynamics.

The first dimension of the state will dictate the number of concurrent environments. See `VmapWrapper.reset` for more details on how to get a state of concurrent environments.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | `State` | `State` object containing the dynamics of the environments. | required |
| action | `chex.Array` | Array containing the actions to take. | required |
Returns:

| Name | Description |
|---|---|
| state | `State` object corresponding to the next states of the environments. |
| timestep | `TimeStep` object corresponding to the timesteps returned by the environments. |
### `render(self, state: State) -> Any`

Render the first environment state of the given batch. The remaining elements of the batched state are ignored.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | `State` | `State` object containing the current dynamics of the environment. | required |
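
A sketch of a batched rollout. It assumes array-valued actions and that the spec's `generate_value()` produces a single, unbatched action; the env ID is illustrative.

```python
import jax
import jax.numpy as jnp
import jumanji
from jumanji.wrappers import VmapWrapper

env = VmapWrapper(jumanji.make("Snake-v1"))  # env ID is illustrative

# The leading dimension of the key batch sets the number of parallel environments.
num_envs = 8
keys = jax.random.split(jax.random.PRNGKey(0), num=num_envs)
state, timestep = env.reset(keys)  # state and timestep have a leading dim of 8

# Build a batch of dummy actions and step all environments in lockstep.
action = env.action_spec.generate_value()  # single, unbatched action
actions = jnp.stack([action] * num_envs)
state, timestep = env.step(state, actions)
```
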
## AutoResetWrapper (Wrapper, Generic)

Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep are reset to the reset observation and StepType.LAST, respectively. The reward, discount, and extras are retrieved from the transition to the terminal state.

NOTE: The observation from the terminal TimeStep is stored in `timestep.extras["next_obs"]`.

WARNING: do not `jax.vmap` the wrapped environment (e.g. do not use it with the `VmapWrapper`), as this would lead to inefficient computation due to both the `step` and `reset` functions being processed each time `step` is called. Please use the `VmapAutoResetWrapper` instead.
### `__init__(self, env: Environment[State, ActionSpec, Observation], next_obs_in_extras: bool = False)` *(special)*

Wrap an environment to automatically reset it when the episode terminates.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | `Environment[State, ActionSpec, Observation]` | the environment to wrap. | required |
| next_obs_in_extras | `bool` | whether to store the next observation in the extras of the terminal timestep. This is useful for e.g. truncation. | False |
### `reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]`

Resets the environment to an initial state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | `chex.PRNGKey` | random key used to reset the environment. | required |
Returns:

| Name | Description |
|---|---|
| state | `State` object corresponding to the new state of the environment. |
| timestep | `TimeStep` object corresponding to the first timestep returned by the environment. |
### `step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]`

Step the environment, with automatic resetting if the episode terminates.
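
A sketch of an auto-resetting rollout; the env ID and the dummy-action helper are illustrative.

```python
import jax
import jumanji
from jumanji.wrappers import AutoResetWrapper

# Keep the true terminal observation around, e.g. for bootstrapping on truncation.
env = AutoResetWrapper(jumanji.make("Snake-v1"), next_obs_in_extras=True)

state, timestep = env.reset(jax.random.PRNGKey(0))
for _ in range(100):
    action = env.action_spec.generate_value()  # dummy action for illustration
    state, timestep = env.step(state, action)
    # When an episode ends, the state resets automatically; the real terminal
    # observation remains available in timestep.extras["next_obs"].
```
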
## JumanjiToGymWrapper (Env, Generic)

A wrapper that converts a Jumanji `Environment` to one that follows the `gym.Env` API.
### `unwrapped: Environment[State, ActionSpec, Observation]` *(property, read-only)*

Returns the base non-wrapped environment.
Returns:

| Type | Description |
|---|---|
| `Env` | The base non-wrapped `gym.Env` instance. |
### `__init__(self, env: Environment[State, ActionSpec, Observation], seed: int = 0, backend: Optional[str] = None)` *(special)*

Create the Gym environment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | `Environment[State, ActionSpec, Observation]` | the environment to wrap. | required |
| seed | `int` | the seed that is used to initialize the environment's PRNG. | 0 |
| backend | `Optional[str]` | the XLA backend. | None |
### `reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None) -> Union[GymObservation, Tuple[GymObservation, Optional[Any]]]`

Resets the environment to an initial state by starting a new sequence and returns the first `Observation` of this sequence.
Returns:

| Name | Description |
|---|---|
| obs | an element of the environment's observation_space. |
| info (optional) | contains supplementary information such as metrics. |
### `step(self, action: chex.ArrayNumpy) -> Tuple[GymObservation, float, bool, Optional[Any]]`

Updates the environment according to the action and returns an `Observation`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| action | `chex.ArrayNumpy` | A NumPy array representing the action provided by the agent. | required |
Returns:

| Name | Description |
|---|---|
| observation | an element of the environment's observation_space. |
| reward | the amount of reward returned as a result of taking the action. |
| terminated | whether a terminal state is reached. |
| info | contains supplementary information such as metrics. |
### `seed(self, seed: int = 0) -> None`

Function which sets the seed for the environment's random number generator(s).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| seed | `int` | the seed value for the random number generator(s). | 0 |
### `render(self, mode: str = 'human') -> Any`

Renders the environment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| mode | `str` | not used, since Jumanji does not currently support render modes. | 'human' |
### `close(self) -> None`

Closes the environment. This is important for rendering, where pygame is imported.
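
A sketch of the classic gym loop using this wrapper; it assumes the wrapper exposes a gym-style `action_space`, and the env ID is illustrative.

```python
import jumanji
from jumanji.wrappers import JumanjiToGymWrapper

# Expose a Jumanji environment through the classic (4-tuple) gym.Env API.
env = JumanjiToGymWrapper(jumanji.make("Snake-v1"), seed=0)

obs = env.reset()
for _ in range(10):
    action = env.action_space.sample()  # assumes a gym-style action_space is exposed
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()
env.close()
```
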
## `jumanji_to_gym_obs(observation: Observation) -> GymObservation`

Convert a Jumanji observation into a gym observation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| observation | `Observation` | JAX pytree with (possibly nested) containers that either have the `__dict__` or `_asdict` methods. | required |
Returns:

| Type | Description |
|---|---|
| `GymObservation` | NumPy array or nested dictionary of NumPy arrays. |
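
A sketch of converting a timestep's observation; the env ID is illustrative.

```python
import jax
import jumanji
from jumanji.wrappers import jumanji_to_gym_obs

env = jumanji.make("Snake-v1")  # env ID is illustrative
state, timestep = env.reset(jax.random.PRNGKey(0))

# Convert the (possibly nested) JAX observation into NumPy containers
# that gym tooling can consume.
gym_obs = jumanji_to_gym_obs(timestep.observation)
```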