
Types

types #

StepType (int8) #

Defines the status of a TimeStep within a sequence.

FIRST: 0
MID: 1
LAST: 2
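
The three values can be illustrated with a minimal stand-in (an IntEnum sketch; the real StepType is an int8-based type in the library itself):

```python
from enum import IntEnum

# Minimal sketch mirroring the documented values; used here only to
# illustrate the FIRST/MID/LAST episode convention.
class StepType(IntEnum):
    FIRST = 0  # first step of an episode (reset)
    MID = 1    # any intermediate step
    LAST = 2   # final step of an episode

assert StepType.FIRST == 0 and StepType.MID == 1 and StepType.LAST == 2
```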

TimeStep (Generic, Mapping) dataclass #

Copied from dm_env.TimeStep with the goal of making it a Jax Type. The original dm_env.TimeStep is not a Jax type because inheriting a namedtuple is not treated as a valid Jax type (https://github.com/google/jax/issues/806).
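
The workaround referred to above can be sketched with JAX's generic registration mechanism (`register_pytree_node_class` is one standard way to make a custom container traversable by JAX transformations; the library may use a different helper):

```python
from dataclasses import dataclass

import jax

# A plain dataclass is not a pytree node, so jax.tree_util cannot
# traverse it. Registering flatten/unflatten hooks fixes that.
@jax.tree_util.register_pytree_node_class
@dataclass
class Pair:
    a: float
    b: float

    def tree_flatten(self):
        # children (leaves) first, then auxiliary static data
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

# tree_map now recurses into Pair like any other pytree.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, Pair(1.0, 2.0))
assert (doubled.a, doubled.b) == (2.0, 4.0)
```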

A TimeStep contains the data emitted by an environment at each step of interaction. A TimeStep holds a step_type, an observation (typically a NumPy array or a dict or list of arrays), and an associated reward and discount.

The first TimeStep in a sequence will have StepType.FIRST. The final TimeStep will have StepType.LAST. All other TimeSteps in a sequence will have StepType.MID.

Attributes:

    step_type (StepType): A StepType enum value.
    reward (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]): A scalar, NumPy array, nested dict, list or tuple of rewards; or None if step_type is StepType.FIRST, i.e. at the start of a sequence.
    discount (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]): A scalar, NumPy array, nested dict, list or tuple of discount values in the range [0, 1]; or None if step_type is StepType.FIRST, i.e. at the start of a sequence.
    observation (Observation): A NumPy array, or a nested dict, list or tuple of arrays. Scalar values that can be cast to NumPy arrays (e.g. Python floats) are also valid in place of a scalar array.
    extras (Dict): Environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is an empty dictionary.

step_type: StepType dataclass-field #

reward: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] dataclass-field #

discount: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] dataclass-field #

observation: ~Observation dataclass-field #

extras: Dict dataclass-field #

__eq__(self, other) special #

__init__(self, step_type: StepType, reward: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], discount: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], observation: ~Observation, extras: Dict = <factory>) -> None special #

__repr__(self) special #

__getitem__(self, x) special #

__len__(self) special #

__iter__(self) special #

first(self) -> Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] #

mid(self) -> Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] #

last(self) -> Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] #

from_tuple(args) #

to_tuple(self) #

replace(self, **kwargs) #

__getstate__(self) special #

__setstate__(self, state) special #
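
As a rough illustration of the attributes and the first/mid/last/replace helpers listed above, here is a minimal pure-Python stand-in (not the library's actual pytree-registered class; the real helpers return boolean arrays rather than Python bools):

```python
from dataclasses import dataclass, field, replace as dc_replace
from typing import Any, Dict

@dataclass
class TimeStep:
    step_type: int  # 0 = FIRST, 1 = MID, 2 = LAST
    reward: Any
    discount: Any
    observation: Any
    extras: Dict = field(default_factory=dict)

    def first(self):
        return self.step_type == 0

    def mid(self):
        return self.step_type == 1

    def last(self):
        return self.step_type == 2

    def replace(self, **kwargs):
        # Return a copy with the given fields overridden.
        return dc_replace(self, **kwargs)

ts = TimeStep(step_type=0, reward=0.0, discount=1.0, observation=[1.0, 2.0])
assert ts.first() and not ts.last()
assert ts.replace(step_type=2).last()
```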

restart(observation: ~Observation, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = ()) -> TimeStep #

Returns a TimeStep with step_type set to StepType.FIRST.

Parameters:

    observation (Observation): Array or tree of arrays. Required.
    extras (Optional[Dict]): Environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. Defaults to None.
    shape (Union[int, Sequence[int]]): Optional shape of the rewards and discounts, allowing multi-agent environment compatibility. Defaults to () for a scalar reward and discount.

Returns:

    TimeStep: A TimeStep identified as a reset.
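
A hedged sketch of the behaviour described above, with the TimeStep represented as a plain dict; zero reward and unit discount at the start of an episode are assumed conventions, as is broadcasting both to `shape`:

```python
import numpy as np

def restart(observation, extras=None, shape=()):
    # Build a FIRST step: reward and discount broadcast to `shape`
    # so multi-agent environments can use per-agent vectors.
    return {
        "step_type": 0,  # StepType.FIRST
        "reward": np.zeros(shape, dtype=np.float32),
        "discount": np.ones(shape, dtype=np.float32),
        "observation": observation,
        "extras": extras if extras is not None else {},
    }

ts = restart(np.zeros(4), shape=(3,))  # e.g. a 3-agent environment
assert ts["reward"].shape == (3,) and ts["discount"].shape == (3,)
```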

transition(reward: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], observation: ~Observation, discount: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] = None, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = ()) -> TimeStep #

Returns a TimeStep with step_type set to StepType.MID.

Parameters:

    reward (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]): Array. Required.
    observation (Observation): Array or tree of arrays. Required.
    discount (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]): Array. Defaults to None.
    extras (Optional[Dict]): Environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. Defaults to None.
    shape (Union[int, Sequence[int]]): Optional shape of the rewards and discounts, allowing multi-agent environment compatibility. Defaults to () for a scalar reward and discount.

Returns:

    TimeStep: A TimeStep identified as a transition.
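
A sketch of the factory above, again using a dict as the TimeStep stand-in; filling an omitted discount with ones of `shape` is an assumed convention matching the documented default of None:

```python
import numpy as np

def transition(reward, observation, discount=None, extras=None, shape=()):
    # A MID step carrying the given reward; the episode continues.
    if discount is None:
        discount = np.ones(shape, dtype=np.float32)
    return {
        "step_type": 1,  # StepType.MID
        "reward": reward,
        "discount": discount,
        "observation": observation,
        "extras": extras if extras is not None else {},
    }

ts = transition(reward=np.float32(1.0), observation=np.zeros(2))
assert ts["step_type"] == 1 and ts["discount"].shape == ()
```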

termination(reward: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], observation: ~Observation, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = ()) -> TimeStep #

Returns a TimeStep with step_type set to StepType.LAST.

Parameters:

    reward (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]): Array. Required.
    observation (Observation): Array or tree of arrays. Required.
    extras (Optional[Dict]): Environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. Defaults to None.
    shape (Union[int, Sequence[int]]): Optional shape of the rewards and discounts, allowing multi-agent environment compatibility. Defaults to () for a scalar reward and discount.

Returns:

    TimeStep: A TimeStep identified as the termination of an episode.

truncation(reward: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number], observation: ~Observation, discount: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number] = None, extras: Optional[Dict] = None, shape: Union[int, Sequence[int]] = ()) -> TimeStep #

Returns a TimeStep with step_type set to StepType.LAST.

Parameters:

    reward (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]): Array. Required.
    observation (Observation): Array or tree of arrays. Required.
    discount (Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number]): Array. Defaults to None.
    extras (Optional[Dict]): Environment metric(s) or information returned by the environment but not observed by the agent (hence not in the observation). For example, it could be whether an invalid action was taken. In most environments, extras is None. Defaults to None.
    shape (Union[int, Sequence[int]]): Optional shape of the rewards and discounts, allowing multi-agent environment compatibility. Defaults to () for a scalar reward and discount.

Returns:

    TimeStep: A TimeStep identified as the truncation of an episode.
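
Both factories produce a LAST step, and their signatures above show the practical difference: termination takes no discount, while truncation does. A sketch of the distinction, assuming the usual RL convention that termination zeroes the discount (the episode truly ends, so value bootstrapping stops) while truncation defaults it to ones (the episode was merely cut short, so bootstrapping remains valid):

```python
import numpy as np

def termination(reward, observation, extras=None, shape=()):
    # True end of episode: discount is forced to zero.
    return {"step_type": 2, "reward": reward,
            "discount": np.zeros(shape, dtype=np.float32),
            "observation": observation,
            "extras": extras if extras is not None else {}}

def truncation(reward, observation, discount=None, extras=None, shape=()):
    # Episode cut short (e.g. time limit): discount stays at one by default.
    if discount is None:
        discount = np.ones(shape, dtype=np.float32)
    return {"step_type": 2, "reward": reward, "discount": discount,
            "observation": observation,
            "extras": extras if extras is not None else {}}

obs, r = np.zeros(2), np.float32(1.0)
assert termination(r, obs)["discount"] == 0.0
assert truncation(r, obs)["discount"] == 1.0
```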

get_valid_dtype(dtype: Union[numpy.dtype, type]) -> dtype #

Cast a dtype to take the enabled precision into account. For example, when 64-bit precision is not enabled, jnp.dtype(jnp.float_) still reports float64; passing the given dtype through jnp.empty yields the dtype that is actually supported, float32.

Parameters:

    dtype (Union[numpy.dtype, type]): JAX NumPy dtype or string specifying the array dtype. Required.

Returns:

    dtype: The dtype converted to the correct type precision.
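
The trick described above can be sketched in a few lines (a hedged re-implementation, not the library's exact code):

```python
import numpy as np
import jax.numpy as jnp

def get_valid_dtype(dtype):
    # jnp.empty materialises an actual array, so its dtype reflects the
    # precision JAX has enabled; e.g. with x64 disabled, a requested
    # float64 collapses to float32.
    return jnp.empty((), dtype).dtype

# float32 is always supported, so it passes through unchanged.
assert get_valid_dtype(np.float32) == np.dtype("float32")
```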


Last update: 2024-11-01