TrajectoryAccumulator
Pure-Python helper for accumulating multi-timescale pytree samples on the
client side before calling client.send. Useful when a single rollout step
produces several arrays at different rates (e.g. one transition per env
step, plus a single episode-level statistic per episode).
Each timescale is either:
- Buffered — every leaf shares the same leading dim
N (N > 1);
that's the capacity, filled by N add() calls.
- Single-item — capacity 1; detected when every leaf is 1-d or has
at least one leaf that is 0-d.
add() replaces the whole leaf.
See the guide for the rationale and
worked examples.
TrajectoryAccumulator
Multi-timescale accumulator: fixed-size pytree buffer.
Per timescale, the example pytree determines how many add() calls
fit before the buffer is full:
-
Buffered timescale — every leaf shares the same leading dim N
(N > 1); the accumulator stores N per-add items into
stored[s:s+1] = incoming slot-by-slot. N becomes the
timescale's capacity.
-
Single-item timescale — the timescale holds one trailing piece of
context (e.g. a bootstrap step, an episode return) rather than a
buffer. Detected when at least one leaf is 0-d, or all leaves
have shape[0] == 1. Capacity is 1; add() replaces the
whole leaf, so non-0-d leaves may have any per-item shape (apart
from the optional leading 1).
Parameters:
| Name |
Type |
Description |
Default |
example
|
dict[str, Any]
|
Dict with timescale names as top-level keys. Each value is
a pytree whose leaves declare the per-timescale layout per the
rule above.
|
required
|
Source code in python/echo/trajectory_accumulator.py
| class TrajectoryAccumulator:
"""Multi-timescale accumulator: fixed-size pytree buffer.
Per timescale, the example pytree determines how many ``add()`` calls
fit before the buffer is full:
* **Buffered timescale** — every leaf shares the same leading dim ``N``
(``N > 1``); the accumulator stores ``N`` per-add items into
``stored[s:s+1] = incoming`` slot-by-slot. ``N`` becomes the
timescale's capacity.
* **Single-item timescale** — the timescale holds one trailing piece of
context (e.g. a bootstrap step, an episode return) rather than a
buffer. Detected when at least one leaf is 0-d, or all leaves
have ``shape[0] == 1``. Capacity is ``1``; ``add()`` replaces the
whole leaf, so non-0-d leaves may have any per-item shape (apart
from the optional leading 1).
Args:
example: Dict with timescale names as top-level keys. Each value is
a pytree whose leaves declare the per-timescale layout per the
rule above.
"""
def __init__(self, example: dict[str, Any]):
if not isinstance(example, dict):
raise TypeError("example must be a dict with timescale names as top-level keys")
self._counts: dict[str, int] = {}
self._single_item: dict[str, bool] = {}
for name, subtree in example.items():
leaves = optree.tree_leaves(subtree)
if not leaves:
raise ValueError(f"Timescale '{name}' has no array leaves")
# Single-item: any 0-d leaf OR every leaf with leading dim 1.
if any(leaf.ndim == 0 for leaf in leaves) or all(leaf.shape[0] == 1 for leaf in leaves):
self._counts[name] = 1
self._single_item[name] = True
else:
leading = [leaf.shape[0] for leaf in leaves]
if not all(s == leading[0] for s in leading):
raise ValueError(
f"All leaves in buffered timescale '{name}' must share the same "
f"leading dimension (got {leading}); make any leaf 0-d or "
f"all shape (1, ...) to mark the timescale single-item instead"
)
self._counts[name] = leading[0]
self._single_item[name] = False
self._tree: dict[str, Any] = {n: optree.tree_map(np.zeros_like, sub) for n, sub in example.items()}
self._slot: dict[str, int] = {name: 0 for name in example}
def add(self, name: str, data: Any) -> None:
"""Write a single-item pytree into the next slot for timescale *name*."""
if name not in self._counts:
raise KeyError(f"Unknown timescale '{name}'. Known: {list(self._counts)}")
s = self._slot[name]
if s >= self._counts[name]:
raise IndexError(f"Timescale '{name}' has {self._counts[name]} slots, but you tried to add at index {s}")
# Single-item: replace the whole leaf
# Buffered: write into the next slot of the leading dim.
key = Ellipsis if self._single_item[name] else slice(s, s + 1)
def _write_slot(stored, incoming):
stored[key] = incoming
return stored
optree.tree_map_(_write_slot, self._tree[name], data)
self._slot[name] += 1
def build(self) -> dict[str, Any]:
"""Return the filled pytree and reset slot counters.
The returned tree aliases internal buffers; callers must finish using
it (e.g. complete the synchronous send) before the next ``add()``.
"""
self._slot = {name: 0 for name in self._slot}
return self._tree
def reset(self) -> None:
"""Reset slot counters without sending (e.g. on episode abort)."""
self._slot = {name: 0 for name in self._slot}
|
add
add(name: str, data: Any) -> None
Write a single-item pytree into the next slot for timescale name.
Source code in python/echo/trajectory_accumulator.py
| def add(self, name: str, data: Any) -> None:
"""Write a single-item pytree into the next slot for timescale *name*."""
if name not in self._counts:
raise KeyError(f"Unknown timescale '{name}'. Known: {list(self._counts)}")
s = self._slot[name]
if s >= self._counts[name]:
raise IndexError(f"Timescale '{name}' has {self._counts[name]} slots, but you tried to add at index {s}")
# Single-item: replace the whole leaf
# Buffered: write into the next slot of the leading dim.
key = Ellipsis if self._single_item[name] else slice(s, s + 1)
def _write_slot(stored, incoming):
stored[key] = incoming
return stored
optree.tree_map_(_write_slot, self._tree[name], data)
self._slot[name] += 1
|
build
build() -> dict[str, Any]
Return the filled pytree and reset slot counters.
The returned tree aliases internal buffers; callers must finish using
it (e.g. complete the synchronous send) before the next add().
Source code in python/echo/trajectory_accumulator.py
| def build(self) -> dict[str, Any]:
"""Return the filled pytree and reset slot counters.
The returned tree aliases internal buffers; callers must finish using
it (e.g. complete the synchronous send) before the next ``add()``.
"""
self._slot = {name: 0 for name in self._slot}
return self._tree
|
reset
Reset slot counters without sending (e.g. on episode abort).
Source code in python/echo/trajectory_accumulator.py
| def reset(self) -> None:
"""Reset slot counters without sending (e.g. on episode abort)."""
self._slot = {name: 0 for name in self._slot}
|