# Wrappers
## AutoResetWrapper(env, next_obs_in_extras=False)
Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]
Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation and step_type of the terminal TimeStep are set to the reset observation and StepType.LAST, respectively. The reward, discount, and extras are retrieved from the transition to the terminal state.

NOTE: when next_obs_in_extras is True, the observation from the terminal TimeStep is stored in timestep.extras["next_obs"].

WARNING: do not jax.vmap the wrapped environment (e.g. do not use with the VmapWrapper), as this would lead to inefficient computation: both the step and reset functions would be processed every time step is called. Use the VmapAutoResetWrapper instead.
Wrap an environment to automatically reset it when the episode terminates.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | Environment[State, ActionSpec, Observation] | The environment to wrap. | required |
| next_obs_in_extras | bool | Whether to store the next observation in the extras of the terminal timestep. This is useful for e.g. truncation. | False |
### step(state, action)
Step the environment, with automatic resetting if the episode terminates.
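For illustration, a minimal usage sketch (not part of the library docs; it assumes a registered environment id such as "Snake-v1" and uses the spec's generate_value() as a placeholder action):

```python
import jax
import jumanji
from jumanji.wrappers import AutoResetWrapper

# Wrap a single (un-vmapped) environment so episodes restart automatically.
env = AutoResetWrapper(jumanji.make("Snake-v1"), next_obs_in_extras=True)

key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)

step_fn = jax.jit(env.step)
for _ in range(10):
    action = env.action_spec.generate_value()  # placeholder action
    state, timestep = step_fn(state, action)
    # On a terminal step, timestep.observation is already the reset
    # observation; the true final observation is kept in the extras
    # because next_obs_in_extras=True above.
    final_obs = timestep.extras["next_obs"]
```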
## JumanjiToDMEnvWrapper(env, key=None)
Bases: Environment, Generic[State, ActionSpec, Observation]
A wrapper that converts a Jumanji Environment to a dm_env.Environment.
Create the wrapped environment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | Environment[State, ActionSpec, Observation] | The environment to wrap. | required |
| key | Optional[PRNGKey] | Optional key to initialize the environment with. | None |
### action_spec()
Returns the dm_env action spec.
### observation_spec()
Returns the dm_env observation spec.
### reset()
Starts a new sequence and returns the first TimeStep of this sequence.
Returns:

| Type | Description |
|---|---|
| TimeStep | A TimeStep corresponding to the first step of the new sequence. |
### step(action)
Updates the environment according to the action and returns a TimeStep.
If the environment returned a TimeStep with StepType.LAST at the
previous step, this call to step will start a new sequence and action
will be ignored.
This method will also start a new sequence if called after the environment
has been constructed and reset has not been called. Again, in this case
action will be ignored.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| action | ArrayNumpy | A NumPy array, or a nested dict, list or tuple of arrays corresponding to action_spec(). | required |
Returns:

| Type | Description |
|---|---|
| TimeStep | A TimeStep describing the result of the step. |
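A minimal conversion sketch (illustrative only; it assumes the "Snake-v1" registered id, and the calls on the wrapped object follow the standard dm_env API):

```python
import jumanji
from jumanji.wrappers import JumanjiToDMEnvWrapper

# Drive a Jumanji environment through the dm_env interface.
dm_environment = JumanjiToDMEnvWrapper(jumanji.make("Snake-v1"))

timestep = dm_environment.reset()
while not timestep.last():
    action = dm_environment.action_spec().generate_value()  # placeholder action
    timestep = dm_environment.step(action)
```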
## JumanjiToGymWrapper(env, seed=0, backend=None)
Bases: Env, Generic[State, ActionSpec, Observation]
A wrapper that converts a Jumanji Environment to one that follows the gym.Env API.
Create the Gym environment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | Environment[State, ActionSpec, Observation] | The environment to wrap. | required |
| seed | int | The seed that is used to initialize the environment's PRNG. | 0 |
| backend | Optional[str] | The XLA backend. | None |
### close()
Closes the environment; this is important for rendering, where pygame is imported.
### render(mode='human')
Renders the environment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| mode | str | Not used, as Jumanji does not currently support render modes. | 'human' |
### reset(*, seed=None, options=None)
Resets the environment to an initial state by starting a new sequence
and returns the first Observation of this sequence.
Returns:

| Name | Type | Description |
|---|---|---|
| obs | GymObservation | An element of the environment's observation_space. |
| info | optional | Contains supplementary information such as metrics. |
### seed(seed=0)
Function which sets the seed for the environment's random number generator(s).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| seed | int | The seed value for the random number generator(s). | 0 |
### step(action)
Updates the environment according to the action and returns an Observation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| action | ArrayNumpy | A NumPy array representing the action provided by the agent. | required |
Returns:

| Name | Type | Description |
|---|---|---|
| observation | GymObservation | An element of the environment's observation_space. |
| reward | float | The amount of reward returned as a result of taking the action. |
| terminated | bool | Whether a terminal state is reached. |
| info | optional | Contains supplementary information such as metrics. |
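A usage sketch following the reset and step signatures documented above (illustrative only; the action_space attribute is assumed to be populated by the wrapper, as the gym.Env API requires, and the unpacking follows the Returns table above):

```python
import jumanji
from jumanji.wrappers import JumanjiToGymWrapper

# Expose a Jumanji environment through the gym.Env interface.
env = JumanjiToGymWrapper(jumanji.make("Snake-v1"), seed=0)

obs, info = env.reset()
obs, reward, terminated, info = env.step(env.action_space.sample())
env.close()
```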
## MultiToSingleWrapper(env, reward_aggregator=jnp.sum, discount_aggregator=jnp.max)
Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]
A wrapper that converts a multi-agent Environment to a single-agent Environment.
Create the wrapped environment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | Environment[State, ActionSpec, Observation] | The multi-agent environment to wrap. | required |
| reward_aggregator | Callable | A function to aggregate all agents' rewards into a single scalar value, e.g. sum. | jnp.sum |
| discount_aggregator | Callable | A function to aggregate all agents' discounts into a single scalar value, e.g. max. | jnp.max |
### discount_spec (cached property)

Scalar discount spec matching the aggregated output.

### reward_spec (cached property)

Scalar reward spec matching the aggregated output.
### reset(key)
Resets the environment to an initial state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKey | Random key used to reset the environment. | required |
Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the new state of the environment. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the first timestep returned by the environment. |
### step(state, action)
Run one timestep of the environment's dynamics.
The rewards are aggregated into a single value by the given reward aggregator, and the discounts by the given discount aggregator. With the default max aggregator, the discount is the largest discount of all the agents, which essentially means that the discount won't be zero as long as any single agent is alive.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the dynamics of the environment. | required |
| action | Array | Array containing the action to take. | required |
Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the next state of the environment. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the timestep returned by the environment. |
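For illustration, a sketch with the default aggregators made explicit (it assumes a registered multi-agent id such as "RobotWarehouse-v0"):

```python
import jax
import jax.numpy as jnp
import jumanji
from jumanji.wrappers import MultiToSingleWrapper

# Sum the per-agent rewards and take the max of the per-agent discounts
# (these are the defaults), so the env looks single-agent downstream.
env = MultiToSingleWrapper(
    jumanji.make("RobotWarehouse-v0"),
    reward_aggregator=jnp.sum,
    discount_aggregator=jnp.max,
)

state, timestep = env.reset(jax.random.PRNGKey(0))
# timestep.reward is now a single scalar instead of one reward per agent.
```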
## VmapAutoResetWrapper(env, next_obs_in_extras=False)
Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]
Efficient combination of VmapWrapper and AutoResetWrapper, to be used as a replacement for the combination of both wrappers. env = VmapAutoResetWrapper(env) is equivalent to env = VmapWrapper(AutoResetWrapper(env)), but is more efficient, as it parallelizes homogeneous computation and does not run branches of the computational graph that are not needed (heterogeneous computation).

- Homogeneous computation: calling the step function on all environments in the batch.
- Heterogeneous computation: conditional auto-reset (calling the reset function for some environments within the batch because they have terminated).

NOTE: when next_obs_in_extras is True, the observation from the terminal TimeStep is stored in timestep.extras["next_obs"].
Wrap an environment to vmap it and automatically reset it when the episode terminates.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| env | Environment[State, ActionSpec, Observation] | The environment to wrap. | required |
| next_obs_in_extras | bool | Whether to store the next observation in the extras of the terminal timestep. This is useful for e.g. truncation. | False |
### render(state)
Render the first environment state of the given batch. The remaining elements of the batched state are ignored.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the current dynamics of the environment. | required |
### reset(key)
Resets a batch of environments to initial states.
The first dimension of the key will dictate the number of concurrent environments.
To obtain a key with the right first dimension, you may call jax.random.split on key
with the parameter num representing the number of concurrent environments.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKey | Random keys used to reset the environments, where the first dimension is the number of desired environments. | required |
Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the new states of the environments. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the first timesteps returned by the environments. |
### step(state, action)
Run one timestep of all environments' dynamics. It automatically resets environment(s) in which episodes have terminated.
The first dimension of the state will dictate the number of concurrent environments.
See VmapAutoResetWrapper.reset for more details on how to get a state of concurrent
environments.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the dynamics of the environments. | required |
| action | Array | Array containing the actions to take. | required |
Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the next states of the environments. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the timesteps returned by the environments. |
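A batched usage sketch (illustrative; "Snake-v1" and the tree-map action batching are assumptions, not part of the wrapper's API):

```python
import jax
import jax.numpy as jnp
import jumanji
from jumanji.wrappers import VmapAutoResetWrapper

env = VmapAutoResetWrapper(jumanji.make("Snake-v1"))

num_envs = 8
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)  # one key per env
state, timestep = env.reset(keys)  # batched state and timestep

# Build a batch of placeholder actions by stacking one unbatched action.
action = env.action_spec.generate_value()
actions = jax.tree_util.tree_map(lambda x: jnp.stack([x] * num_envs), action)

# Environments that terminate are reset automatically inside step.
state, timestep = jax.jit(env.step)(state, actions)
```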
## VmapWrapper(env)
Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]
Vectorized Jax env. Please note that all methods that return arrays do not return a batch dimension, because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include:

- observation_spec
- action_spec
- reward_spec
- discount_spec
### render(state)
Render the first environment state of the given batch. The remaining elements of the batched state are ignored.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the current dynamics of the environment. | required |
### reset(key)
Resets the environment to an initial state.
The first dimension of the key will dictate the number of concurrent environments.
To obtain a key with the right first dimension, you may call jax.random.split on key
with the parameter num representing the number of concurrent environments.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKey | Random keys used to reset the environments, where the first dimension is the number of desired environments. | required |
Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the new state of the environments. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the first timesteps returned by the environments. |
### step(state, action)
Run one timestep of the environment's dynamics.
The first dimension of the state will dictate the number of concurrent environments.
See VmapWrapper.reset for more details on how to get a state of concurrent
environments.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the dynamics of the environments. | required |
| action | Array | Array containing the actions to take. | required |
Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the next states of the environments. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the timesteps returned by the environments. |
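A minimal sketch (illustrative only) showing that specs carry no batch dimension and that, unlike VmapAutoResetWrapper, terminated episodes are not reset automatically:

```python
import jax
import jumanji
from jumanji.wrappers import VmapWrapper

env = VmapWrapper(jumanji.make("Snake-v1"))

# Specs are unbatched: the wrapper does not know the batch size.
print(env.observation_spec)

keys = jax.random.split(jax.random.PRNGKey(0), 4)  # 4 concurrent environments
state, timestep = env.reset(keys)
# No auto-reset here: the caller must reset terminated episodes
# (or use VmapAutoResetWrapper instead).
```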
## Wrapper(env)
Bases: Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]
Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72
### action_spec (cached property)

Returns the action spec.

### discount_spec (cached property)

Returns the discount spec.

### observation_spec (cached property)

Returns the observation spec.

### reward_spec (cached property)

Returns the reward spec.

### unwrapped (property)

Returns the wrapped env.
### close()

Perform any necessary cleanup. Environments will automatically close() themselves when garbage collected or when the program exits.
### render(state)
Compute render frames during initialisation of the environment.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the dynamics of the environment. | required |
### reset(key)
Resets the environment to an initial state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKey | Random key used to reset the environment. | required |
Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the new state of the environment. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the first timestep returned by the environment. |
### step(state, action)
Run one timestep of the environment's dynamics.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state | State | State object containing the dynamics of the environment. | required |
| action | Array | Array containing the action to take. | required |
Returns:

| Name | Type | Description |
|---|---|---|
| state | State | State object corresponding to the next state of the environment. |
| timestep | TimeStep[Observation] | TimeStep object corresponding to the timestep returned by the environment. |
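As a sketch of extending the base class (a hypothetical wrapper, not part of the library; it assumes TimeStep supports a dataclass-style replace, and delegates through super().step rather than touching wrapper internals):

```python
from jumanji.wrappers import Wrapper

class RewardScaleWrapper(Wrapper):
    """Hypothetical wrapper that rescales rewards, for illustration only."""

    def __init__(self, env, scale: float = 0.1):
        super().__init__(env)
        self._scale = scale

    def step(self, state, action):
        # Delegate to the wrapped environment, then transform the timestep.
        state, timestep = super().step(state, action)
        return state, timestep.replace(reward=timestep.reward * self._scale)
```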
## add_obs_to_extras(timestep)
Place the observation in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]. Used when auto-resetting to store the observation from the terminal TimeStep (useful for e.g. truncation).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| timestep | TimeStep[Observation] | TimeStep object containing the timestep returned by the environment. | required |
Returns:

| Type | Description |
|---|---|
| TimeStep[Observation] | The timestep with the observation placed in timestep.extras["next_obs"]. |
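Downstream, the stored observation can be read back from the extras, e.g. for bootstrapping value targets across truncated episodes (a sketch, assuming the timestep came through an auto-reset wrapper with next_obs_in_extras=True):

```python
# timestep.observation         -> reset observation of the next episode
# timestep.extras["next_obs"]  -> true final observation of the episode
final_obs = timestep.extras["next_obs"]
```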
## jumanji_to_gym_obs(observation)
Convert a Jumanji observation into a gym observation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| observation | Observation | JAX pytree with (possibly nested) containers that either have the __dict__ or _asdict methods. | required |
Returns:

| Type | Description |
|---|---|
| GymObservation | Numpy array or nested dictionary of numpy arrays. |
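A conversion sketch (illustrative; assumes the "Snake-v1" registered id):

```python
import jax
import jumanji
from jumanji.wrappers import jumanji_to_gym_obs

env = jumanji.make("Snake-v1")
state, timestep = env.reset(jax.random.PRNGKey(0))

# Convert the (possibly nested) JAX observation into numpy-based
# structures that gym tooling can consume.
gym_obs = jumanji_to_gym_obs(timestep.observation)
```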