
Prioritised Flat Buffer

make_prioritised_flat_buffer(max_length, min_length, sample_batch_size, add_sequences=False, add_batch_size=None, priority_exponent=0.6, device='cpu')

Makes a prioritised trajectory buffer act as a prioritised flat buffer.

Parameters:

max_length (int): The maximum length of the buffer. Required.
min_length (int): The minimum length of the buffer. Required.
sample_batch_size (int): The batch size of the samples. Required.
add_sequences (Optional[bool]): Whether data is added to the buffer in sequences. If False, a single transition is added each time add is called. Defaults to False.
add_batch_size (Optional[int]): The batch size of the data added on each call, if adding in batches. If None, a single transition or a single sequence is added each time add is called. Defaults to None.
priority_exponent (float): The exponent used when calculating priorities. Defaults to 0.6.
device (str): The desired backend; more optimised functions are selected depending on the device. Defaults to 'cpu'.

Returns:

PrioritisedTrajectoryBuffer: The buffer.
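
For orientation, a minimal usage sketch follows. It assumes that make_prioritised_flat_buffer is exposed at the package level (flashbax.make_prioritised_flat_buffer) and that the returned buffer provides init and add as in the other Flashbax buffers; the transition structure below is purely illustrative.

import jax.numpy as jnp
import flashbax as fbx

# An illustrative transition pytree; any pytree of arrays works.
example_transition = {
    "obs": jnp.zeros((4,)),
    "action": jnp.zeros((), dtype=jnp.int32),
    "reward": jnp.zeros(()),
}

buffer = fbx.make_prioritised_flat_buffer(
    max_length=10_000,
    min_length=64,
    sample_batch_size=32,
    priority_exponent=0.6,
)

# Build the initial state from a single example transition.
state = buffer.init(example_transition)

# With add_sequences=False and add_batch_size=None, add takes one transition at a time.
state = buffer.add(state, example_transition)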

Source code in flashbax/buffers/prioritised_flat_buffer.py
def make_prioritised_flat_buffer(
    max_length: int,
    min_length: int,
    sample_batch_size: int,
    add_sequences: bool = False,
    add_batch_size: Optional[int] = None,
    priority_exponent: float = 0.6,
    device: str = "cpu",
) -> PrioritisedTrajectoryBuffer:
    """Makes a prioritised trajectory buffer act as a prioritised flat buffer.

    Args:
        max_length (int): The maximum length of the buffer.
        min_length (int): The minimum length of the buffer.
        sample_batch_size (int): The batch size of the samples.
        add_sequences (Optional[bool], optional): Whether data is being added in sequences
            to the buffer. If False, single transitions are being added each time add
            is called. Defaults to False.
        add_batch_size (Optional[int], optional): If adding data in batches, what is the
            batch size that is being added each time. If None, single transitions or single
            sequences are being added each time add is called. Defaults to None.
        priority_exponent (float, optional): The exponent to use when calculating priorities.
            Defaults to 0.6.
        device (str): Depending on desired backend - more optimised functions are selected.

    Returns:
        The buffer."""

    if add_batch_size is None:
        # add_batch_size being None implies that we are adding single transitions
        add_batch_size = 1
        add_batches = False
    else:
        add_batches = True

    validate_priority_exponent(priority_exponent)
    validate_flat_buffer_args(
        max_length=max_length,
        min_length=min_length,
        sample_batch_size=sample_batch_size,
        add_batch_size=add_batch_size,
    )
    if not validate_device(device):
        device = "cpu"

    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Setting max_size dynamically sets the `max_length_time_axis` to "
            f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`."
            "This allows one to control exactly how many transitions are stored in the buffer."
            "Note that this overrides the `max_length_time_axis` argument.",
        )

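        # Under the hood, a flat buffer is a trajectory buffer that stores and
        # samples length-2 sequences with period 1, i.e. pairs of consecutive
        # transitions.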
        buffer = make_prioritised_trajectory_buffer(
            max_length_time_axis=None,  # Unused because max_size is specified
            min_length_time_axis=min_length // add_batch_size + 1,
            add_batch_size=add_batch_size,
            sample_batch_size=sample_batch_size,
            sample_sequence_length=2,
            period=1,
            max_size=max_length,
            priority_exponent=priority_exponent,
            device=device,
        )

    add_fn = buffer.add

    if not add_batches:
        add_fn = add_dim_to_args(
            add_fn, axis=0, starting_arg_index=1, ending_arg_index=2
        )

    if not add_sequences:
        axis = 1 - int(not add_batches)  # 1 if add_batches else 0
        add_fn = add_dim_to_args(
            add_fn, axis=axis, starting_arg_index=1, ending_arg_index=2
        )

    def sample_fn(
        state: PrioritisedTrajectoryBufferState, rng_key: PRNGKey
    ) -> TransitionSample:
        """Samples a batch of transitions from the buffer."""
        sampled_batch = buffer.sample(state, rng_key)
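        # Each sampled item is a length-2 sequence along the time axis; split it
        # into the transition (first) and its successor (second).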
        first = jax.tree_util.tree_map(lambda x: x[:, 0], sampled_batch.experience)
        second = jax.tree_util.tree_map(lambda x: x[:, 1], sampled_batch.experience)
        return PrioritisedTransitionSample(
            experience=ExperiencePair(first=first, second=second),
            indices=sampled_batch.indices,
            priorities=sampled_batch.priorities,
        )

    return buffer.replace(add=add_fn, sample=sample_fn)  # type: ignore
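
Continuing the sketch above, sampling yields a PrioritisedTransitionSample whose experience pairs each sampled transition (first) with its successor (second), together with the sampled indices and priorities. The set_priorities call shown here follows the prioritised trajectory buffer API and is an assumption; substitute your own priority computation (e.g. absolute TD errors).

import jax
import jax.numpy as jnp

rng_key = jax.random.PRNGKey(0)

# Sampling requires at least min_length transitions in the buffer.
batch = buffer.sample(state, rng_key)
obs_t = batch.experience.first["obs"]     # shape (sample_batch_size, 4)
obs_tp1 = batch.experience.second["obs"]  # shape (sample_batch_size, 4)

# Assumed API: update the priorities of the sampled items, e.g. with |TD error|.
new_priorities = jnp.ones_like(batch.priorities)
state = buffer.set_priorities(state, batch.indices, new_priorities)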