
Types

StepType #

Bases: int8

Defines the status of a TimeStep within a sequence.

FIRST = 0
MID = 1
LAST = 2
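
A minimal illustration, assuming StepType is importable from jumanji.types (the source file referenced below):

from jumanji.types import StepType

# StepType members behave like int8 values, so they can be compared directly.
assert StepType.FIRST == 0
assert StepType.MID == 1
assert StepType.LAST == 2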

TimeStep #

Bases: Generic[Observation]

Copied from dm_env.TimeStep with the goal of making it a Jax Type. The original dm_env.TimeStep is not a Jax type because inheriting a namedtuple is not treated as a valid Jax type (https://github.com/google/jax/issues/806).

A TimeStep contains the data emitted by an environment at each step of interaction. A TimeStep holds a step_type, an observation (typically a NumPy array or a dict or list of arrays), and an associated reward and discount.

The first TimeStep in a sequence will have StepType.FIRST. The final TimeStep will have StepType.LAST. All other TimeSteps in a sequence will have StepType.MID.

Attributes:

- step_type (StepType): a StepType enum value.
- reward (Array): a scalar, NumPy array, nested dict, list or tuple of rewards; or None if step_type is StepType.FIRST, i.e. at the start of a sequence.
- discount (Array): a scalar, NumPy array, nested dict, list or tuple of discount values in the range [0, 1], or None if step_type is StepType.FIRST, i.e. at the start of a sequence.
- observation (Observation): a NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array.
- extras (Dict): environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is an empty dictionary.
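
A minimal sketch of building a TimeStep by hand, assuming it is importable from jumanji.types; in practice the helper functions below (restart, transition, termination, truncation) are the usual way to construct one:

import jax.numpy as jnp
from jumanji.types import StepType, TimeStep

# A mid-episode step with a scalar reward and discount and a placeholder observation.
timestep = TimeStep(
    step_type=StepType.MID,
    reward=jnp.asarray(0.0, dtype=jnp.float32),
    discount=jnp.asarray(1.0, dtype=jnp.float32),
    observation=jnp.zeros((4,), dtype=jnp.float32),
    extras={},
)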

get_valid_dtype(dtype) #

Cast a dtype taking into account the user's type precision. For example, if 64-bit precision is not enabled, jnp.dtype(jnp.float_) still reports float64; passing the given dtype through jnp.empty yields the supported dtype, float32.

Parameters:

- dtype (Union[dtype, type], required): jax numpy dtype or string specifying the array dtype.

Returns:

- dtype: the dtype converted to the correct type precision.

Source code in jumanji/types.py
def get_valid_dtype(dtype: Union[jnp.dtype, type]) -> jnp.dtype:
    """Cast a dtype taking into account the user type precision. E.g., if 64 bit is not enabled,
    jnp.dtype(jnp.float_) is still float64. By passing the given dtype through `jnp.empty` we get
    the supported dtype of float32.

    Args:
        dtype: jax numpy dtype or string specifying the array dtype.

    Returns:
        dtype converted to the correct type precision.
    """
    return jnp.empty((), dtype).dtype  # type: ignore
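
A quick usage sketch, assuming the default JAX configuration in which 64-bit precision is disabled:

import jax.numpy as jnp
from jumanji.types import get_valid_dtype

# With jax_enable_x64 off (the default), a requested 64-bit float falls back to float32.
dtype = get_valid_dtype(float)
assert dtype == jnp.float32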

restart(observation, extras=None, shape=(), dtype=float) #

Returns a TimeStep with step_type set to StepType.FIRST.

Parameters:

- observation (Observation, required): array or tree of arrays.
- extras (Optional[Dict], default None): environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None.
- shape (Union[int, Sequence[int]], default ()): optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount.
- dtype (Union[dtype, type], default float): optional parameter to specify the data type of the rewards and discounts. Defaults to float.

Returns:

- TimeStep: a TimeStep identified as a reset.

Source code in jumanji/types.py
def restart(
    observation: Observation,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
    dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
    """Returns a `TimeStep` with `step_type` set to `StepType.FIRST`.

    Args:
        observation: array or tree of arrays.
        extras: environment metric(s) or information returned by the environment but
            not observed by the agent (hence not in the observation). For example, it
            could be whether an invalid action was taken. In most environments, extras
            is None.
        shape: optional parameter to specify the shape of the rewards and discounts.
            Allows multi-agent environment compatibility. Defaults to () for
            scalar reward and discount.
        dtype: Optional parameter to specify the data type of the rewards and discounts.
            Defaults to `float`.

    Returns:
        TimeStep identified as a reset.
    """
    extras = extras or {}
    return TimeStep(
        step_type=StepType.FIRST,
        reward=jnp.zeros(shape, dtype=dtype),
        discount=jnp.ones(shape, dtype=dtype),
        observation=observation,
        extras=extras,
    )
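
A short usage sketch with a placeholder observation, assuming the import path jumanji.types matches the source file above; the shape argument is only needed for per-agent rewards and discounts:

import jax.numpy as jnp
from jumanji.types import StepType, restart

observation = jnp.zeros((3,), dtype=jnp.float32)  # placeholder observation

# Single-agent reset: scalar zero reward and a discount of 1.
timestep = restart(observation)
assert timestep.step_type == StepType.FIRST
assert timestep.reward.shape == ()

# Multi-agent reset: one reward/discount entry per agent.
multi_agent_timestep = restart(observation, shape=(2,))
assert multi_agent_timestep.discount.shape == (2,)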

termination(reward, observation, extras=None, shape=(), dtype=float) #

Returns a TimeStep with step_type set to StepType.LAST.

Parameters:

- reward (Array, required): array.
- observation (Observation, required): array or tree of arrays.
- extras (Optional[Dict], default None): environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None.
- shape (Union[int, Sequence[int]], default ()): optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount.
- dtype (Union[dtype, type], default float): optional parameter to specify the data type of the discounts. Defaults to float.

Returns:

- TimeStep: a TimeStep identified as the termination of an episode.

Source code in jumanji/types.py
def termination(
    reward: Array,
    observation: Observation,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
    dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
    """Returns a `TimeStep` with `step_type` set to `StepType.LAST`.

    Args:
        reward: array.
        observation: array or tree of arrays.
        extras: environment metric(s) or information returned by the environment but
            not observed by the agent (hence not in the observation). For example, it
            could be whether an invalid action was taken. In most environments, extras
            is None.
        shape: optional parameter to specify the shape of the rewards and discounts.
            Allows multi-agent environment compatibility. Defaults to () for
            scalar reward and discount.
        dtype: Optional parameter to specify the data type of the discounts. Defaults
            to `float`.

    Returns:
        TimeStep identified as the termination of an episode.
    """
    extras = extras or {}
    return TimeStep(
        step_type=StepType.LAST,
        reward=reward,
        discount=jnp.zeros(shape, dtype=dtype),
        observation=observation,
        extras=extras,
    )
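
A short usage sketch with a placeholder observation; note that the discount is forced to zero at termination:

import jax.numpy as jnp
from jumanji.types import StepType, termination

observation = jnp.zeros((3,), dtype=jnp.float32)  # placeholder observation

timestep = termination(reward=jnp.asarray(1.0), observation=observation)
assert timestep.step_type == StepType.LAST
assert timestep.discount == 0.0  # terminal step: no bootstrapping from future values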

transition(reward, observation, discount=None, extras=None, shape=(), dtype=float) #

Returns a TimeStep with step_type set to StepType.MID.

Parameters:

- reward (Array, required): array.
- observation (Observation, required): array or tree of arrays.
- discount (Optional[Array], default None): array.
- extras (Optional[Dict], default None): environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None.
- shape (Union[int, Sequence[int]], default ()): optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount.
- dtype (Union[dtype, type], default float): optional parameter to specify the data type of the discounts. Defaults to float.

Returns:

- TimeStep: a TimeStep identified as a transition.

Source code in jumanji/types.py
def transition(
    reward: Array,
    observation: Observation,
    discount: Optional[Array] = None,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
    dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
    """Returns a `TimeStep` with `step_type` set to `StepType.MID`.

    Args:
        reward: array.
        observation: array or tree of arrays.
        discount: array.
        extras: environment metric(s) or information returned by the environment but
            not observed by the agent (hence not in the observation). For example, it
            could be whether an invalid action was taken. In most environments, extras
            is None.
        shape: optional parameter to specify the shape of the rewards and discounts.
            Allows multi-agent environment compatibility. Defaults to () for
            scalar reward and discount.
        dtype: Optional parameter to specify the data type of the discounts. Defaults
            to `float`.

    Returns:
        TimeStep identified as a transition.
    """
    discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
    extras = extras or {}
    return TimeStep(
        step_type=StepType.MID,
        reward=reward,
        discount=discount,
        observation=observation,
        extras=extras,
    )
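
A short usage sketch with a placeholder observation; the discount defaults to 1 unless overridden:

import jax.numpy as jnp
from jumanji.types import StepType, transition

observation = jnp.zeros((3,), dtype=jnp.float32)  # placeholder observation

timestep = transition(reward=jnp.asarray(0.5), observation=observation)
assert timestep.step_type == StepType.MID
assert timestep.discount == 1.0  # default discount for a regular transition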

truncation(reward, observation, discount=None, extras=None, shape=(), dtype=float) #

Returns a TimeStep with step_type set to StepType.LAST.

Parameters:

- reward (Array, required): array.
- observation (Observation, required): array or tree of arrays.
- discount (Optional[Array], default None): array.
- extras (Optional[Dict], default None): environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None.
- shape (Union[int, Sequence[int]], default ()): optional parameter to specify the shape of the rewards and discounts. Allows multi-agent environment compatibility. Defaults to () for scalar reward and discount.
- dtype (Union[dtype, type], default float): optional parameter to specify the data type of the discounts. Defaults to float.

Returns:

- TimeStep: a TimeStep identified as the truncation of an episode.

Source code in jumanji/types.py
def truncation(
    reward: Array,
    observation: Observation,
    discount: Optional[Array] = None,
    extras: Optional[Dict] = None,
    shape: Union[int, Sequence[int]] = (),
    dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
    """Returns a `TimeStep` with `step_type` set to `StepType.LAST`.

    Args:
        reward: array.
        observation: array or tree of arrays.
        discount: array.
        extras: environment metric(s) or information returned by the environment but
            not observed by the agent (hence not in the observation). For example, it
            could be whether an invalid action was taken. In most environments, extras
            is None.
        shape: optional parameter to specify the shape of the rewards and discounts.
            Allows multi-agent environment compatibility. Defaults to () for
            scalar reward and discount.
        dtype: Optional parameter to specify the data type of the discounts. Defaults
            to `float`.

    Returns:
        TimeStep identified as the truncation of an episode.
    """
    discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
    extras = extras or {}
    return TimeStep(
        step_type=StepType.LAST,
        reward=reward,
        discount=discount,
        observation=observation,
        extras=extras,
    )
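
A short usage sketch with a placeholder observation; unlike termination, the episode ends but the discount stays 1 by default, so value bootstrapping remains possible:

import jax.numpy as jnp
from jumanji.types import StepType, truncation

observation = jnp.zeros((3,), dtype=jnp.float32)  # placeholder observation

timestep = truncation(reward=jnp.asarray(0.5), observation=observation)
assert timestep.step_type == StepType.LAST  # ends the episode...
assert timestep.discount == 1.0             # ...but keeps a non-zero discount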