
Trajectory Buffer

Bases: Generic[Experience, BufferState, BufferSample]

Pure functions defining the trajectory buffer. This buffer assumes that batches added to it are pytrees whose leaves share a shape prefix of (batch_size, trajectory_length). Consecutive batches are concatenated along the second axis (the time axis), so that during sampling whole trajectories can be drawn by slicing consecutive sequences along that axis.
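To make the layout concrete, the following is a minimal JAX sketch of the two operations described above: appending batches along the time axis and sampling by slicing consecutive timesteps. It assumes an unbounded buffer held as a plain pytree. The helpers `add_concat` and `sample_window` are hypothetical names for illustration, not part of this library's API, and the sketch ignores the capacity limits and state bookkeeping a real buffer needs.

```python
import jax
import jax.numpy as jnp


def add_concat(stored, batch):
    """Hypothetical helper: append `batch` to `stored` along the time axis.

    Both arguments are pytrees whose leaves share the shape prefix
    (batch_size, time, ...). Capacity limits are ignored in this sketch.
    """
    return jax.tree_util.tree_map(
        lambda s, b: jnp.concatenate([s, b], axis=1), stored, batch
    )


def sample_window(stored, key, sequence_length):
    """Hypothetical helper: slice one window of consecutive timesteps per row."""
    batch_size, time_len = jax.tree_util.tree_leaves(stored)[0].shape[:2]
    # One random start per row, chosen so the whole window fits in the data.
    starts = jax.random.randint(key, (batch_size,), 0, time_len - sequence_length + 1)

    def slice_leaf(x):
        # vmap over the batch axis so each row uses its own start index.
        return jax.vmap(
            lambda row, start: jax.lax.dynamic_slice_in_dim(
                row, start, sequence_length, axis=0
            )
        )(x, starts)

    return jax.tree_util.tree_map(slice_leaf, stored)


# Two batches, each with shape prefix (batch_size=2, trajectory_length=3).
first = {"obs": jnp.zeros((2, 3, 4)), "reward": jnp.zeros((2, 3))}
second = jax.tree_util.tree_map(jnp.ones_like, first)
stored = add_concat(first, second)
print(stored["obs"].shape)  # (2, 6, 4)
window = sample_window(stored, jax.random.PRNGKey(0), sequence_length=4)
print(window["obs"].shape)  # (2, 4, 4)
```

Because the second batch is appended after the first along axis 1, any window of consecutive indices along that axis is a contiguous piece of a trajectory, which is what makes sequence sampling a simple slice.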

Attributes:

init (Callable[[Experience], BufferState]): A pure function that initialises the buffer state from a single timestep (e.g. (s, a, r)).

add (Callable[[BufferState, Experience], BufferState]): A pure function that adds a new batch of experience to the buffer state.

sample (Callable[[BufferState, PRNGKey], BufferSample]): A pure function that samples a batch of data from the replay buffer; the sampled data has leading axes of shape (sample_batch_size, sample_sequence_length). Note that sample_batch_size and sample_sequence_length may differ from the batch size and sequence length of the data added to the state with the add function.

can_sample (Callable[[BufferState], Array]): Whether the buffer can be sampled from. This is determined by whether the number of trajectories added to the buffer state is greater than or equal to min_length.
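The four attributes can be illustrated with a toy, self-contained JAX implementation that bundles pure init, add, sample and can_sample functions in a NamedTuple. Everything in it (`make_toy_trajectory_buffer`, `ToyState`, `ToyTrajectoryBuffer`, and the `max_length` and `min_length` parameters) is hypothetical. It is a sketch under simplifying assumptions, not this library's implementation: storage has a fixed capacity with no wraparound, sample returns a plain pytree rather than a BufferSample, and can_sample counts timesteps written along the time axis.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    experience: Any  # pytree; leaves shaped (add_batch_size, max_length, ...)
    current_index: jnp.ndarray  # timesteps written so far along the time axis


class ToyTrajectoryBuffer(NamedTuple):
    init: Callable
    add: Callable
    sample: Callable
    can_sample: Callable


def make_toy_trajectory_buffer(add_batch_size, max_length, min_length,
                               sample_batch_size, sample_sequence_length):
    def init(timestep):
        # Allocate zeroed storage shaped like a single (unbatched) timestep.
        experience = jax.tree_util.tree_map(
            lambda x: jnp.zeros((add_batch_size, max_length) + jnp.shape(x),
                                jnp.asarray(x).dtype),
            timestep,
        )
        return ToyState(experience, jnp.array(0, dtype=jnp.int32))

    def add(state, batch):
        # Write the batch at current_index along axis 1 (the time axis).
        # No wraparound: the caller must not exceed max_length in this toy.
        t = jax.tree_util.tree_leaves(batch)[0].shape[1]
        experience = jax.tree_util.tree_map(
            lambda s, b: jax.lax.dynamic_update_slice_in_dim(
                s, b, state.current_index, axis=1),
            state.experience, batch,
        )
        return ToyState(experience, state.current_index + t)

    def sample(state, key):
        row_key, start_key = jax.random.split(key)
        # sample_batch_size can differ from add_batch_size: rows are drawn
        # with replacement from the stored batch rows.
        rows = jax.random.randint(row_key, (sample_batch_size,), 0, add_batch_size)
        # Start indices keep each window inside the filled region.
        starts = jax.random.randint(
            start_key, (sample_batch_size,), 0,
            state.current_index - sample_sequence_length + 1)

        def take(x):
            return jax.vmap(lambda r, s: jax.lax.dynamic_slice_in_dim(
                x[r], s, sample_sequence_length, axis=0))(rows, starts)

        return jax.tree_util.tree_map(take, state.experience)

    def can_sample(state):
        return state.current_index >= min_length

    return ToyTrajectoryBuffer(init, add, sample, can_sample)
```

Since each function returns a new state instead of mutating the old one, and the state is an ordinary pytree, this pattern should compose with JAX transformations such as jax.jit.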

See make_trajectory_buffer for how this container is instantiated.