Trajectory Buffer
Bases: Generic[Experience, BufferState, BufferSample]
Pure functions defining the trajectory buffer. This buffer assumes that batches added to it are a pytree with a shape prefix of (batch_size, trajectory_length). Consecutive batches are concatenated along the second axis (i.e. the time axis). This layout allows trajectories to be sampled by slicing consecutive sequences along the time axis.
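The time-axis layout described above can be sketched with NumPy, using a flat dict of arrays as a stand-in for a general pytree. The helpers `make_batch`, `concat_time` and `sample_sequences` are hypothetical and not part of the library. For clarity the sketch grows arrays by concatenation and makes no claim about how the library actually stores or samples data.

```python
import numpy as np

def make_batch(t0, batch_size=2, length=3):
    # Hypothetical toy batch with shape prefix (batch_size, trajectory_length).
    # obs[b, t] = 10 * b + t, so each row counts up along the time axis.
    t = np.arange(t0, t0 + length)
    b = np.arange(batch_size)[:, None]
    return {"obs": 10 * b + t, "reward": np.full((batch_size, length), float(t0))}

def concat_time(a, b):
    # Concatenate two batches leaf by leaf along the second (time) axis.
    return {k: np.concatenate([a[k], b[k]], axis=1) for k in a}

def sample_sequences(data, rng, sample_batch_size, seq_len):
    # Pick a random row and start time per sample, then slice seq_len
    # consecutive steps along the time axis.
    num_rows, num_steps = next(iter(data.values())).shape[:2]
    rows = rng.integers(0, num_rows, size=sample_batch_size)
    starts = rng.integers(0, num_steps - seq_len + 1, size=sample_batch_size)
    time_idx = starts[:, None] + np.arange(seq_len)  # (sample_batch_size, seq_len)
    return {k: v[rows[:, None], time_idx] for k, v in data.items()}

data = concat_time(make_batch(0), make_batch(3))
print(data["obs"])
# [[ 0  1  2  3  4  5]
#  [10 11 12 13 14 15]]
seqs = sample_sequences(data, np.random.default_rng(0), sample_batch_size=4, seq_len=3)
print(seqs["obs"].shape)  # (4, 3)
```

Because each row of the stored data is one continuous stream along the time axis, every sampled slice is a run of consecutive timesteps.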
Attributes:

| Name | Type | Description |
|---|---|---|
| `init` | `Callable[[Experience], BufferState]` | A pure function which may be used to initialise the buffer state using a single timestep (e.g. `(s, a, r)`). |
| `add` | `Callable[[BufferState, Experience], BufferState]` | A pure function for adding a new batch of experience to the buffer state. |
| `sample` | `Callable[[BufferState, PRNGKey], BufferSample]` | A pure function for sampling a batch of data from the replay buffer, with leading axes of size (sample batch size, sample sequence length). |
| `can_sample` | `Callable[[BufferState], Array]` | Whether the buffer can be sampled from, which is determined by whether the number of trajectories added to the buffer state is greater than or equal to the minimum length. |
See `make_trajectory_buffer` for how this container is instantiated.
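To illustrate how the four attributes fit together as a container of pure functions, here is a toy sketch. It is not the library's implementation: `make_toy_trajectory_buffer`, `ToyTrajectoryBuffer` and `ToyState` are hypothetical names, leaves are NumPy arrays in a flat dict, an integer seed stands in for a JAX `PRNGKey`, and there is no circular overwrite or sampling period.

```python
from dataclasses import dataclass
from typing import Callable, Dict, NamedTuple

import numpy as np

Experience = Dict[str, np.ndarray]  # toy pytree: a flat dict of arrays

class ToyState(NamedTuple):
    data: Experience  # each leaf has shape (add_batch_size, max_length, ...)
    length: int       # timesteps written so far along the time axis

@dataclass(frozen=True)
class ToyTrajectoryBuffer:
    init: Callable[[Experience], ToyState]
    add: Callable[[ToyState, Experience], ToyState]
    sample: Callable[[ToyState, int], Experience]
    can_sample: Callable[[ToyState], bool]

def make_toy_trajectory_buffer(add_batch_size, max_length, min_length,
                               sample_batch_size, sample_sequence_length):
    # Hypothetical constructor; min_length must cover one full sampled sequence.
    assert min_length >= sample_sequence_length

    def init(timestep):
        # Allocate storage shaped like a single timestep, with
        # (add_batch_size, max_length) prepended as leading axes.
        data = {k: np.zeros((add_batch_size, max_length) + np.shape(v),
                            dtype=np.asarray(v).dtype)
                for k, v in timestep.items()}
        return ToyState(data=data, length=0)

    def add(state, batch):
        # Leaves of batch have shape (add_batch_size, T, ...). Write them after
        # the existing timesteps into copies, leaving the input state untouched.
        t = next(iter(batch.values())).shape[1]
        if state.length + t > max_length:
            raise ValueError("toy buffer is full (no circular overwrite)")
        data = {}
        for k, v in state.data.items():
            new = v.copy()
            new[:, state.length:state.length + t] = batch[k]
            data[k] = new
        return ToyState(data=data, length=state.length + t)

    def can_sample(state):
        return state.length >= min_length

    def sample(state, key):
        # Same integer key gives the same sample, mimicking a PRNGKey.
        rng = np.random.default_rng(key)
        rows = rng.integers(0, add_batch_size, size=sample_batch_size)
        starts = rng.integers(0, state.length - sample_sequence_length + 1,
                              size=sample_batch_size)
        time_idx = starts[:, None] + np.arange(sample_sequence_length)
        return {k: v[rows[:, None], time_idx] for k, v in state.data.items()}

    return ToyTrajectoryBuffer(init=init, add=add, sample=sample,
                               can_sample=can_sample)

buf = make_toy_trajectory_buffer(add_batch_size=2, max_length=10, min_length=4,
                                 sample_batch_size=3, sample_sequence_length=2)
state = buf.init({"obs": np.zeros(2), "reward": 0.0})
batch = {"obs": np.ones((2, 3, 2)), "reward": np.arange(6.0).reshape(2, 3)}
state = buf.add(state, batch)
print(buf.can_sample(state))  # False (3 timesteps < min_length 4)
state = buf.add(state, batch)
print(buf.can_sample(state))  # True (6 timesteps >= 4)
print(buf.sample(state, key=0)["obs"].shape)  # (3, 2, 2)
```

Copying the arrays inside `add` keeps every function pure at the cost of a full copy per call; under JAX, a functional array update would play the same role.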