Wrappers
AutoResetWrapper(env, next_obs_in_extras=False)
#
Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]
Automatically resets environments that are done. Once the terminal state is reached, the state, observation, and step_type are reset. The observation of the terminal TimeStep is replaced with the reset observation, while its step_type remains StepType.LAST. The reward, discount, and extras are retrieved from the transition to the terminal state.
NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"].
WARNING: do not jax.vmap the wrapped environment (e.g. do not use with the VmapWrapper), as this would lead to inefficient computation: both the step and reset functions would be processed on every call to step. Please use the VmapAutoResetWrapper instead.
Wrap an environment to automatically reset it when the episode terminates.
Parameters:

Name | Type | Description | Default
---|---|---|---
`env` | `Environment[State, ActionSpec, Observation]` | the environment to wrap. | required
`next_obs_in_extras` | `bool` | whether to store the next observation in the extras of the terminal timestep. This is useful for e.g. truncation. | `False`
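The auto-reset mechanism can be sketched with a toy, framework-free environment. This is illustrative only: `toy_reset`, `toy_step`, and `auto_reset_step` are hypothetical stand-ins, and only the `extras["next_obs"]` key reflects the wrapper's documented behaviour.

```python
# Illustrative sketch of the auto-reset logic (not Jumanji's actual code).
# A toy 3-step counting environment stands in for a real Jumanji env.

def toy_reset():
    state = 0
    timestep = {"obs": 0, "step_type": "FIRST", "extras": {}}
    return state, timestep

def toy_step(state, action):
    state = state + 1
    done = state >= 3
    timestep = {
        "obs": state,
        "step_type": "LAST" if done else "MID",
        "extras": {},
    }
    return state, timestep

def auto_reset_step(state, action, next_obs_in_extras=True):
    """Step, and reset automatically if the episode terminated."""
    state, timestep = toy_step(state, action)
    if timestep["step_type"] == "LAST":
        terminal_obs = timestep["obs"]
        state, reset_timestep = toy_reset()
        # The observation comes from the reset; step_type stays LAST,
        # and the terminal observation is kept in extras["next_obs"].
        timestep["obs"] = reset_timestep["obs"]
        if next_obs_in_extras:
            timestep["extras"]["next_obs"] = terminal_obs
    return state, timestep

state, timestep = toy_reset()
for _ in range(3):
    state, timestep = auto_reset_step(state, action=None)
print(timestep["step_type"], timestep["obs"], timestep["extras"])
# step_type is LAST, obs is the reset observation (0),
# and extras["next_obs"] holds the terminal observation (3).
```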
Source code in jumanji/wrappers.py
step(state, action)
#
Step the environment, with automatic resetting if the episode terminates.
Source code in jumanji/wrappers.py
JumanjiToDMEnvWrapper(env, key=None)
#
Bases: Environment, Generic[State, ActionSpec, Observation]
A wrapper that converts Environment to dm_env.Environment.
Create the wrapped environment.
Parameters:

Name | Type | Description | Default
---|---|---|---
`env` | `Environment[State, ActionSpec, Observation]` | the environment to wrap. | required
`key` | `Optional[PRNGKey]` | optional key to initialize the environment with. | `None`
Source code in jumanji/wrappers.py
action_spec()
#
Returns the dm_env action spec.
Source code in jumanji/wrappers.py
observation_spec()
#
Returns the dm_env observation spec.
Source code in jumanji/wrappers.py
reset()
#
Starts a new sequence and returns the first TimeStep of this sequence.
Returns:

Type | Description
---|---
`TimeStep` | the first `TimeStep` of the new sequence.
Source code in jumanji/wrappers.py
step(action)
#
Updates the environment according to the action and returns a TimeStep.
If the environment returned a TimeStep with StepType.LAST at the previous step, this call to step will start a new sequence and action will be ignored.
This method will also start a new sequence if called after the environment has been constructed and reset has not been called. Again, in this case action will be ignored.
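The implicit-reset behaviour described above can be sketched as a small state machine. This is illustrative only; `ImplicitResetAdapter` is a toy stand-in for the contract, not the wrapper's implementation.

```python
# Sketch of dm_env's implicit-reset contract (illustrative, not the
# wrapper's actual code).

class ImplicitResetAdapter:
    """Calls reset() automatically on the first step and after LAST."""

    def __init__(self):
        self._needs_reset = True
        self._t = 0

    def reset(self):
        self._needs_reset = False
        self._t = 0
        return ("FIRST", self._t)

    def step(self, action):
        if self._needs_reset:
            # Previous timestep was LAST (or reset was never called):
            # start a new sequence and ignore the action.
            return self.reset()
        self._t += 1
        if self._t >= 2:  # toy 2-step episodes
            self._needs_reset = True
            return ("LAST", self._t)
        return ("MID", self._t)

env = ImplicitResetAdapter()
print(env.step(0))  # ('FIRST', 0): action ignored, new sequence started
print(env.step(0))  # ('MID', 1)
print(env.step(0))  # ('LAST', 2)
print(env.step(0))  # ('FIRST', 0): auto-restarted after LAST
```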
Parameters:

Name | Type | Description | Default
---|---|---|---
`action` | `ArrayNumpy` | A NumPy array, or a nested dict, list or tuple of arrays corresponding to `action_spec()`. | required

Returns:

Type | Description
---|---
`TimeStep` | the `TimeStep` resulting from this step of the environment.
Source code in jumanji/wrappers.py
JumanjiToGymWrapper(env, seed=0, backend=None)
#
Bases: Env, Generic[State, ActionSpec, Observation]
A wrapper that converts a Jumanji Environment to one that follows the gym.Env API.
Create the Gym environment.
Parameters:

Name | Type | Description | Default
---|---|---|---
`env` | `Environment[State, ActionSpec, Observation]` | the environment to wrap. | required
`seed` | `int` | the seed that is used to initialize the environment's PRNG. | `0`
`backend` | `Optional[str]` | the XLA backend. | `None`
Source code in jumanji/wrappers.py
close()
#
Closes the environment, important for rendering where pygame is imported.
Source code in jumanji/wrappers.py
render(mode='human')
#
Renders the environment.
Parameters:

Name | Type | Description | Default
---|---|---|---
`mode` | `str` | currently not used since Jumanji does not support render modes. | `'human'`
Source code in jumanji/wrappers.py
reset(*, seed=None, options=None)
#
Resets the environment to an initial state by starting a new sequence and returns the first Observation of this sequence.
Returns:

Name | Type | Description
---|---|---
obs | `GymObservation` | an element of the environment's observation_space.
info | optional | contains supplementary information such as metrics.
Source code in jumanji/wrappers.py
seed(seed=0)
#
Sets the seed for the environment's random number generator(s).
Parameters:

Name | Type | Description | Default
---|---|---|---
`seed` | `int` | the seed value for the random number generator(s). | `0`
Source code in jumanji/wrappers.py
step(action)
#
Updates the environment according to the action and returns an Observation.
Parameters:

Name | Type | Description | Default
---|---|---|---
`action` | `ArrayNumpy` | A NumPy array representing the action provided by the agent. | required

Returns:

Name | Type | Description
---|---|---
observation | `GymObservation` | an element of the environment's observation_space.
reward | `float` | the amount of reward returned as a result of taking the action.
terminated | `bool` | whether a terminal state is reached.
info | `dict` | contains supplementary information such as metrics.
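Such a conversion can be sketched with plain dictionaries. This is an illustrative assumption, not the wrapper's exact code: the field names, and in particular the discount-based terminated/truncated split, are hypothetical conventions shown for intuition.

```python
# Illustrative sketch: mapping a dm_env-style timestep to the Gym step
# return convention. Field names and the discount rule are assumptions.

def timestep_to_gym(timestep):
    obs = timestep["observation"]
    reward = timestep["reward"]
    is_last = timestep["step_type"] == "LAST"
    # A common convention: a LAST step with discount 0 is a true terminal
    # state, while a LAST step with a nonzero discount is a truncation.
    terminated = is_last and timestep["discount"] == 0.0
    truncated = is_last and timestep["discount"] != 0.0
    info = timestep.get("extras", {})
    return obs, reward, terminated, truncated, info

ts = {"observation": [1.0], "reward": 1.0, "discount": 0.0,
      "step_type": "LAST", "extras": {"metric": 3}}
obs, reward, terminated, truncated, info = timestep_to_gym(ts)
# terminated is True, truncated is False, info carries the extras.
```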
Source code in jumanji/wrappers.py
MultiToSingleWrapper(env, reward_aggregator=jnp.sum, discount_aggregator=jnp.max)
#
Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]
A wrapper that converts a multi-agent Environment to a single-agent Environment.
Create the wrapped environment.
Parameters:

Name | Type | Description | Default
---|---|---|---
`env` | `Environment[State, ActionSpec, Observation]` | the environment to wrap. | required
`reward_aggregator` | `Callable` | a function to aggregate all agents' rewards into a single scalar value, e.g. sum. | `jnp.sum`
`discount_aggregator` | `Callable` | a function to aggregate all agents' discounts into a single scalar value, e.g. max. | `jnp.max`
Source code in jumanji/wrappers.py
reset(key)
#
Resets the environment to an initial state.
Parameters:

Name | Type | Description | Default
---|---|---|---
`key` | `PRNGKey` | random key used to reset the environment. | required

Returns:

Name | Type | Description
---|---|---
state | `State` | State object corresponding to the new state of the environment.
timestep | `TimeStep[Observation]` | TimeStep object corresponding to the first timestep returned by the environment.
Source code in jumanji/wrappers.py
step(state, action)
#
Run one timestep of the environment's dynamics.
The rewards are aggregated into a single value based on the given reward aggregator. The discount value is set to the largest discount of all the agents. This essentially means that if any single agent is alive, the discount value won't be zero.
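For instance, with the default aggregators (summed rewards, max discount), the aggregation behaves as follows; shown here with plain Python rather than `jnp` for clarity.

```python
# Sketch of the default aggregation: jnp.sum for rewards,
# jnp.max for discounts (plain Python equivalents shown).

agent_rewards = [1.0, -0.5, 2.0]
agent_discounts = [0.0, 1.0, 0.0]  # one agent is still alive

single_reward = sum(agent_rewards)      # 2.5
single_discount = max(agent_discounts)  # 1.0 -> the episode continues
```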
Parameters:

Name | Type | Description | Default
---|---|---|---
`state` | `State` | State object containing the dynamics of the environment. | required
`action` | `Array` | Array containing the action to take. | required

Returns:

Name | Type | Description
---|---|---
state | `State` | State object corresponding to the next state of the environment.
timestep | `TimeStep[Observation]` | TimeStep object corresponding to the timestep returned by the environment.
Source code in jumanji/wrappers.py
VmapAutoResetWrapper(env, next_obs_in_extras=False)
#
Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]
Efficient combination of VmapWrapper and AutoResetWrapper, to be used in place of composing the two.
env = VmapAutoResetWrapper(env) is equivalent to env = VmapWrapper(AutoResetWrapper(env)) but is more efficient: it parallelizes the homogeneous computation and does not run branches of the computational graph that are not needed (the heterogeneous computation).
- Homogeneous computation: call step function on all environments in the batch.
- Heterogeneous computation: conditional auto-reset (call reset function for some environments
within the batch because they have terminated).
NOTE: The observation from the terminal TimeStep is stored in timestep.extras["next_obs"].
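The heterogeneous part can be sketched in plain Python: only the environments whose episode terminated are reset. This is purely illustrative; the real wrapper expresses the same idea with JAX primitives rather than a Python loop.

```python
# Illustrative sketch of per-environment auto-reset within a batch
# (plain Python; the real wrapper uses JAX control flow instead).

def batched_auto_reset(states, dones, reset_fn):
    """Reset only the environments whose episode terminated."""
    return [reset_fn() if done else s for s, done in zip(states, dones)]

states = [5, 7, 2]
dones = [False, True, False]
new_states = batched_auto_reset(states, dones, reset_fn=lambda: 0)
# Only the terminated environment is reset: [5, 0, 2]
```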
Wrap an environment to vmap it and automatically reset it when the episode terminates.
Parameters:

Name | Type | Description | Default
---|---|---|---
`env` | `Environment[State, ActionSpec, Observation]` | the environment to wrap. | required
`next_obs_in_extras` | `bool` | whether to store the next observation in the extras of the terminal timestep. This is useful for e.g. truncation. | `False`
Source code in jumanji/wrappers.py
render(state)
#
Render the first environment state of the given batch. The remaining elements of the batched state are ignored.
Parameters:

Name | Type | Description | Default
---|---|---|---
`state` | `State` | State object containing the current dynamics of the environment. | required
Source code in jumanji/wrappers.py
reset(key)
#
Resets a batch of environments to initial states.
The first dimension of the key will dictate the number of concurrent environments. To obtain a key with the right first dimension, you may call jax.random.split on key with the parameter num representing the number of concurrent environments.
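For example, to prepare keys for 4 concurrent environments (the final `env.reset(keys)` call is shown as a comment and assumes a wrapped environment `env` is in scope):

```python
import jax

# Create one key per environment; the leading dimension of `keys`
# sets the number of concurrent environments (here, 4).
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, num=4)
# state, timestep = env.reset(keys)  # env being a VmapAutoResetWrapper
```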
Parameters:

Name | Type | Description | Default
---|---|---|---
`key` | `PRNGKey` | random keys used to reset the environments, where the first dimension is the number of desired environments. | required

Returns:

Name | Type | Description
---|---|---
state | `State` | State object corresponding to the new states of the environments.
timestep | `TimeStep[Observation]` | TimeStep object corresponding to the first timesteps returned by the environments.
Source code in jumanji/wrappers.py
step(state, action)
#
Run one timestep of all environments' dynamics. It automatically resets environment(s) in which episodes have terminated.
The first dimension of the state will dictate the number of concurrent environments. See VmapAutoResetWrapper.reset for more details on how to obtain a state for concurrent environments.
Parameters:

Name | Type | Description | Default
---|---|---|---
`state` | `State` | State object containing the dynamics of the environments. | required
`action` | `Array` | Array containing the actions to take. | required

Returns:

Name | Type | Description
---|---|---
state | `State` | State object corresponding to the next states of the environments.
timestep | `TimeStep[Observation]` | TimeStep object corresponding to the timesteps returned by the environments.
Source code in jumanji/wrappers.py
VmapWrapper(env)
#
Bases: Wrapper[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]
Vectorized JAX environment. Please note that all methods that return arrays do not include a batch dimension, because the batch size is not known to the VmapWrapper. Methods that omit the batch dimension include:
- observation_spec
- action_spec
- reward_spec
- discount_spec
Source code in jumanji/wrappers.py
render(state)
#
Render the first environment state of the given batch. The remaining elements of the batched state are ignored.
Parameters:

Name | Type | Description | Default
---|---|---|---
`state` | `State` | State object containing the current dynamics of the environment. | required
Source code in jumanji/wrappers.py
reset(key)
#
Resets the environment to an initial state.
The first dimension of the key will dictate the number of concurrent environments. To obtain a key with the right first dimension, you may call jax.random.split on key with the parameter num representing the number of concurrent environments.
Parameters:

Name | Type | Description | Default
---|---|---|---
`key` | `PRNGKey` | random keys used to reset the environments, where the first dimension is the number of desired environments. | required

Returns:

Name | Type | Description
---|---|---
state | `State` | State object corresponding to the new state of the environments.
timestep | `TimeStep[Observation]` | TimeStep object corresponding to the first timesteps returned by the environments.
Source code in jumanji/wrappers.py
step(state, action)
#
Run one timestep of the environment's dynamics.
The first dimension of the state will dictate the number of concurrent environments. See VmapWrapper.reset for more details on how to obtain a state for concurrent environments.
Parameters:

Name | Type | Description | Default
---|---|---|---
`state` | `State` | State object containing the dynamics of the environments. | required
`action` | `Array` | Array containing the actions to take. | required

Returns:

Name | Type | Description
---|---|---
state | `State` | State object corresponding to the next states of the environments.
timestep | `TimeStep[Observation]` | TimeStep object corresponding to the timesteps returned by the environments.
Source code in jumanji/wrappers.py
Wrapper(env)
#
Bases: Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]
Wraps the environment to allow modular transformations. Source: https://github.com/google/brax/blob/main/brax/envs/env.py#L72
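The delegation pattern behind this base class can be sketched as follows. This is a minimal illustrative wrapper, not Jumanji's exact implementation; `ToyWrapper` and `ToyEnv` are hypothetical names.

```python
# Minimal sketch of the wrapper/delegation pattern (illustrative).

class ToyWrapper:
    """Forwards everything to the wrapped env unless overridden."""

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Fall through to the wrapped environment for anything
        # this wrapper does not override.
        return getattr(self._env, name)

    @property
    def unwrapped(self):
        # Recurse until the innermost, unwrapped environment.
        return getattr(self._env, "unwrapped", self._env)

class ToyEnv:
    def step(self, state, action):
        return state + action

env = ToyWrapper(ToyWrapper(ToyEnv()))
env.step(1, 2)                     # delegated through both wrappers -> 3
isinstance(env.unwrapped, ToyEnv)  # True: unwrapping recurses to the base env
```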
Source code in jumanji/wrappers.py
action_spec: ActionSpec (cached property)
#
Returns the action spec.
discount_spec: specs.BoundedArray (cached property)
#
Returns the discount spec.
observation_spec: specs.Spec[Observation] (cached property)
#
Returns the observation spec.
reward_spec: specs.Array (cached property)
#
Returns the reward spec.
unwrapped: Environment[State, ActionSpec, Observation] (property)
#
Returns the wrapped env.
close()
#
Perform any necessary cleanup.
Environments will automatically close() themselves when garbage collected or when the program exits.
Source code in jumanji/wrappers.py
render(state)
#
Compute render frames during initialisation of the environment.
Parameters:

Name | Type | Description | Default
---|---|---|---
`state` | `State` | State object containing the dynamics of the environment. | required
Source code in jumanji/wrappers.py
reset(key)
#
Resets the environment to an initial state.
Parameters:

Name | Type | Description | Default
---|---|---|---
`key` | `PRNGKey` | random key used to reset the environment. | required

Returns:

Name | Type | Description
---|---|---
state | `State` | State object corresponding to the new state of the environment.
timestep | `TimeStep[Observation]` | TimeStep object corresponding to the first timestep returned by the environment.
Source code in jumanji/wrappers.py
step(state, action)
#
Run one timestep of the environment's dynamics.
Parameters:

Name | Type | Description | Default
---|---|---|---
`state` | `State` | State object containing the dynamics of the environment. | required
`action` | `Array` | Array containing the action to take. | required

Returns:

Name | Type | Description
---|---|---
state | `State` | State object corresponding to the next state of the environment.
timestep | `TimeStep[Observation]` | TimeStep object corresponding to the timestep returned by the environment.
Source code in jumanji/wrappers.py
add_obs_to_extras(timestep)
#
Place the observation in timestep.extras[NEXT_OBS_KEY_IN_EXTRAS]. Used when auto-resetting to store the observation from the terminal TimeStep (useful for e.g. truncation).
Parameters:

Name | Type | Description | Default
---|---|---|---
`timestep` | `TimeStep[Observation]` | TimeStep object containing the timestep returned by the environment. | required

Returns:

Type | Description
---|---
`TimeStep[Observation]` | timestep where the observation is placed in timestep.extras["next_obs"].
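With plain dictionaries standing in for the TimeStep dataclass, the behaviour is simply the following. This is an illustrative sketch; the real function operates on TimeStep objects rather than dicts.

```python
# Illustrative sketch of add_obs_to_extras with dicts in place of TimeStep.

NEXT_OBS_KEY_IN_EXTRAS = "next_obs"

def add_obs_to_extras(timestep):
    """Copy the timestep's observation into its extras."""
    extras = dict(timestep["extras"])
    extras[NEXT_OBS_KEY_IN_EXTRAS] = timestep["observation"]
    return {**timestep, "extras": extras}

ts = {"observation": [1, 2], "extras": {}}
out = add_obs_to_extras(ts)
# out["extras"]["next_obs"] == [1, 2]; the input dict is not mutated.
```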
Source code in jumanji/wrappers.py
jumanji_to_gym_obs(observation)
#
Convert a Jumanji observation into a gym observation.
Parameters:

Name | Type | Description | Default
---|---|---|---
`observation` | `Observation` | JAX pytree with (possibly nested) containers that either have the `__dict__` or `_asdict` method. | required

Returns:

Type | Description
---|---
`GymObservation` | Numpy array or nested dictionary of numpy arrays.
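The conversion can be sketched as a recursive walk over the pytree. This is illustrative only, shown with plain Python containers rather than JAX/NumPy arrays; `to_gym_obs` is a hypothetical simplified stand-in.

```python
# Illustrative sketch: recursively convert a (possibly nested) structured
# observation into plain dictionaries, leaving array-like leaves as-is.

from collections import namedtuple

def to_gym_obs(obs):
    if hasattr(obs, "_asdict"):          # namedtuple-like container
        obs = obs._asdict()
    elif hasattr(obs, "__dict__") and not isinstance(obs, dict):
        obs = vars(obs)                  # plain-object container
    if isinstance(obs, dict):
        return {k: to_gym_obs(v) for k, v in obs.items()}
    return obs                           # leaf (e.g. an array)

Obs = namedtuple("Obs", ["grid", "extra"])
Inner = namedtuple("Inner", ["mask"])
to_gym_obs(Obs(grid=[1, 2], extra=Inner(mask=[0])))
# -> {"grid": [1, 2], "extra": {"mask": [0]}}
```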
Source code in jumanji/wrappers.py