Skip to content

TrajectoryAccumulator

Pure-Python helper for accumulating multi-timescale pytree samples on the client side before calling client.send. Useful when a single rollout step produces several arrays at different rates (e.g. one transition per env step, plus a single episode-level statistic per episode).

TrajectoryAccumulator

Multi-timescale accumulator: fixed-size pytree buffer with double-buffering.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `example` | `dict[str, Any]` | Dict with timescale names as top-level keys. The leading dimension of each leaf array is the number of `add()` calls expected before the buffer is ready to send. | *required* |
Source code in python/echo/trajectory_accumulator.py
class TrajectoryAccumulator:
    """Multi-timescale accumulator: fixed-size pytree buffer with double-buffering.

    Each top-level key of *example* names a timescale. A timescale owns a
    fixed number of slots, given by the shared leading dimension of its
    leaves. :meth:`add` fills those slots one at a time. :meth:`build` hands
    back the current buffer and switches writes to a second, separately
    allocated copy.

    Note:
        Only two buffers exist. The tree returned by one ``build()`` call is
        written to again by ``add()`` after the *following* ``build()``.
        Callers must finish with it (or copy it) before then. Neither
        ``build()`` nor ``reset()`` clears stored values: a slot that is not
        re-filled keeps whatever it last held (zeros on first use).

    Args:
        example: Dict with timescale names as top-level keys. The leading
            dimension of each leaf array is the number of ``add()`` calls
            expected before the buffer is ready to send. A 0-d (scalar) leaf
            counts as a single slot.

    Raises:
        TypeError: If *example* is not a dict.
        ValueError: If a timescale has no leaves, or its leaves disagree on
            the leading dimension.
    """

    def __init__(self, example: dict[str, Any]):
        if not isinstance(example, dict):
            raise TypeError("example must be a dict with timescale names as top-level keys")

        self._counts: dict[str, int] = {}
        for name, subtree in example.items():
            leaves = optree.tree_leaves(subtree)
            if not leaves:
                raise ValueError(f"Timescale '{name}' has no array leaves")
            leading = self._leading_dim(leaves[0])
            if any(self._leading_dim(leaf) != leading for leaf in leaves[1:]):
                raise ValueError(
                    f"All leaves in timescale '{name}' must share the same leading dimension"
                )
            self._counts[name] = leading

        # Two separately allocated zero-filled copies of the pytree, used for
        # double-buffering: add() writes into self._trees[self._active] and
        # build() flips the index.
        self._trees: list[dict[str, Any]] = [
            {n: optree.tree_map(np.zeros_like, sub) for n, sub in example.items()}
            for _ in range(2)
        ]
        self._active = 0
        # Next slot index to write, per timescale.
        self._slot: dict[str, int] = {name: 0 for name in example}

    @staticmethod
    def _leading_dim(leaf: Any) -> int:
        """Return how many slots *leaf* provides: its first-axis length, or 1 if 0-d."""
        # np.ndim/np.shape also accept plain Python scalars and other
        # array-likes that lack their own .ndim/.shape attributes.
        return np.shape(leaf)[0] if np.ndim(leaf) > 0 else 1

    @staticmethod
    def _write_slot(stored: np.ndarray, incoming: Any, slot: int) -> None:
        """Copy *incoming* into position *slot* of the buffer leaf *stored*, in place."""
        if stored.ndim == 0:
            # A 0-d buffer leaf has exactly one slot and cannot be sliced.
            stored[...] = incoming
        else:
            # Assigning into a length-1 slice lets NumPy broadcast both
            # (1, *rest)- and (*rest,)-shaped inputs.
            stored[slot:slot + 1] = incoming

    def add(self, name: str, data: Any) -> None:
        """Write a single-item pytree into the next slot for timescale *name*.

        Args:
            name: Timescale key given in the constructor's ``example``.
            data: Pytree expected to match the structure of ``example[name]``.
                Each leaf holds one item, with or without a leading axis of
                length 1.

        Raises:
            KeyError: If *name* is not a known timescale.
            IndexError: If every slot for *name* is already filled.
        """
        if name not in self._counts:
            raise KeyError(f"Unknown timescale '{name}'. Known: {list(self._counts)}")
        s = self._slot[name]
        if s >= self._counts[name]:
            raise IndexError(
                f"Timescale '{name}' is already full ({self._counts[name]} slots). "
                "Call reset() or build() before adding more."
            )

        optree.tree_map_(
            lambda stored, incoming: self._write_slot(stored, incoming, s),
            self._trees[self._active][name],
            data,
        )
        self._slot[name] += 1

    def build(self) -> dict[str, Any]:
        """Return the filled pytree and flip the active buffer.

        The returned dict is the internal buffer itself, not a copy. It is
        written to again once the next ``build()`` flips back to it. All slot
        counters are reset. Fill level is not checked, so slots that were not
        written this round still hold their earlier values.
        """
        tree = self._trees[self._active]
        self._active = 1 - self._active
        self._slot = {name: 0 for name in self._slot}
        return tree

    def reset(self) -> None:
        """Reset slot counters without sending (e.g. on episode abort).

        Buffered values are not cleared. Later ``add()`` calls overwrite them
        slot by slot.
        """
        self._slot = {name: 0 for name in self._slot}

add

add(name: str, data: Any) -> None

Write a single-item pytree into the next slot for timescale `name`.

Source code in python/echo/trajectory_accumulator.py
def add(self, name: str, data: Any) -> None:
    """Write a single-item pytree into the next slot for timescale *name*.

    Args:
        name: Timescale key given in the constructor's ``example``.
        data: Pytree expected to match the structure of ``example[name]``.
            Each leaf holds one item, with or without a leading axis of
            length 1.

    Raises:
        KeyError: If *name* is not a known timescale.
        IndexError: If every slot for *name* is already filled.
    """
    if name not in self._counts:
        raise KeyError(f"Unknown timescale '{name}'. Known: {list(self._counts)}")
    s = self._slot[name]
    if s >= self._counts[name]:
        raise IndexError(
            f"Timescale '{name}' is already full ({self._counts[name]} slots). "
            "Call reset() or build() before adding more."
        )

    def _write_slot(stored, incoming):
        if stored.ndim == 0:
            # A 0-d buffer leaf (scalar in ``example``) has exactly one slot
            # and cannot be sliced, so assign its whole contents.
            stored[...] = incoming
        else:
            stored[s:s + 1] = incoming
        return stored

    optree.tree_map_(_write_slot, self._trees[self._active][name], data)
    self._slot[name] += 1

build

build() -> dict[str, Any]

Return the filled pytree and flip the active buffer.

Source code in python/echo/trajectory_accumulator.py
def build(self) -> dict[str, Any]:
    """Hand back the buffer that was being filled and redirect writes to the other one.

    The result is the live internal buffer, not a copy. Every per-timescale
    slot counter goes back to zero.
    """
    finished = self._trees[self._active]
    self._slot = dict.fromkeys(self._slot, 0)
    self._active = 1 - self._active
    return finished

reset

reset() -> None

Reset slot counters without sending (e.g. on episode abort).

Source code in python/echo/trajectory_accumulator.py
def reset(self) -> None:
    """Rewind every timescale's slot counter to zero without emitting anything.

    Intended for cases like an aborted episode. Previously written values stay
    in the buffer until later writes overwrite them.
    """
    self._slot = dict.fromkeys(self._slot, 0)