Skip to content

Clients

echo ships one client, TcpClient, paired with the built-in TcpTransport. If you implement a custom transport (gRPC, RDMA, …) via the Rust Transport trait, you'll need to write a matching client to go with it.

TcpClient

TcpClient

TCP client that sends pytree samples to an echo Server.

Parameters:

Name Type Description Default
host str

Server hostname or IP

required
port int

Server TCP port

required
data_sample PyTree[ndarray]

Example pytree (used to validate against the server spec)

required
max_inflight_msgs int

Max unacknowledged messages before send() blocks

32
Source code in python/echo/tcp_client.py
class TcpClient:
    """TCP client that sends pytree samples to an echo Server.

    On construction the client connects, checks the server handshake (magic
    ``b"ECHO"``, wire version 1), reads the server's array spec and validates
    it against ``data_sample``. A daemon thread then consumes one-byte
    acknowledgements so that at most ``max_inflight_msgs`` messages are
    unacknowledged at any time.

    Args:
        host: Server hostname or IP
        port: Server TCP port
        data_sample: Example pytree (used to validate against the server spec)
        max_inflight_msgs: Max unacknowledged messages before send() blocks

    Raises:
        ValueError: ``max_inflight_msgs`` is less than 1, the handshake is
            not a supported echo handshake, or the server spec does not
            match ``data_sample``.
        TypeError: a leaf of ``data_sample`` is not a numpy array.
        OSError: the connection could not be established, or was dropped
            during the handshake.
    """

    def __init__(
        self,
        host: str,
        port: int,
        data_sample: optree.PyTree[np.ndarray],
        max_inflight_msgs: int = 32,
    ):
        # A zero-slot semaphore would make every send() block forever, so
        # reject it before any resource is acquired.
        if max_inflight_msgs < 1:
            raise ValueError(
                f"max_inflight_msgs must be >= 1, got {max_inflight_msgs}"
            )

        self._peer = f"{host}:{port}"
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

        # Anything below may fail while self._sock is open; the outer
        # try/except guarantees we close it if construction fails partway.
        try:
            self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            self._sock.connect((host, port))

            shapes, dtype_sizes = self._read_server_spec()

            leaves, _ = optree.tree_flatten(data_sample)
            if not all(isinstance(leaf, np.ndarray) for leaf in leaves):
                raise TypeError("All leaves of data_sample must be numpy arrays")
            expected_shapes = [list(leaf.shape) for leaf in leaves]
            expected_dtype_sizes = [leaf.dtype.itemsize for leaf in leaves]

            if shapes != expected_shapes or dtype_sizes != expected_dtype_sizes:
                raise ValueError(
                    f"Server spec mismatch: "
                    f"shapes={shapes} vs expected={expected_shapes}, "
                    f"dtype_sizes={dtype_sizes} vs expected={expected_dtype_sizes}"
                )

            # Total bytes of one message: every array's element count times
            # its item size.
            self._payload_size = sum(
                int(np.prod(s)) * d for s, d in zip(shapes, dtype_sizes)
            )

            # Semaphore bounds the number of unacknowledged messages; the
            # counter + condition let wait() block until everything is acked.
            self._inflight_msgs = threading.BoundedSemaphore(max_inflight_msgs)
            self._in_flight_count = 0
            self._in_flight_zero = threading.Condition(threading.Lock())
            self._closed = threading.Event()
            self._ack_thread = threading.Thread(target=self._consume_acks, daemon=True)
            self._ack_thread.start()
        except BaseException:
            self._sock.close()
            raise

    def _read_server_spec(self):
        """Validate the handshake and return ``(shapes, dtype_sizes)``.

        Reads the 4-byte magic, the 1-byte wire version, then the array
        count followed by ``(dtype_size, num_dims, dims...)`` per array, all
        as little-endian unsigned 32-bit integers.

        Raises:
            ValueError: wrong magic or unsupported wire version.
            ConnectionError: the peer closed the connection mid-handshake.
        """
        magic = self._recv_exact(4)
        if magic != b"ECHO":
            raise ValueError(
                f"Unexpected handshake magic {magic!r} from {self._peer}; "
                "is the peer an echo server?"
            )
        version = self._recv_exact(1)[0]
        if version != 1:
            raise ValueError(
                f"Unsupported echo wire version {version} from {self._peer} "
                "(this client supports 1)"
            )

        num_arrays = struct.unpack("<I", self._recv_exact(4))[0]
        shapes = []
        dtype_sizes = []
        for _ in range(num_arrays):
            dtype_size = struct.unpack("<I", self._recv_exact(4))[0]
            num_dims = struct.unpack("<I", self._recv_exact(4))[0]
            shape = list(struct.unpack(f"<{num_dims}I", self._recv_exact(num_dims * 4)))
            shapes.append(shape)
            dtype_sizes.append(dtype_size)
        return shapes, dtype_sizes

    def _recv_exact(self, n: int) -> bytes:
        """Read exactly ``n`` bytes from the socket.

        Raises:
            ConnectionError: the peer closed the connection before ``n``
                bytes arrived.
        """
        parts = []
        remaining = n
        while remaining > 0:
            chunk = self._sock.recv(remaining)
            if not chunk:
                raise ConnectionError("Connection closed during recv")
            parts.append(chunk)
            remaining -= len(chunk)
        return b"".join(parts)

    def _consume_acks(self):
        """Ack-thread body: free one in-flight slot per ``b"\\x01"`` byte.

        Runs until the client is closed or the peer disconnects. On exit,
        for any reason, the client is marked closed and all waiters are
        woken, because no further acknowledgement can ever arrive.
        """
        try:
            while not self._closed.is_set():
                ack = self._sock.recv(1)
                if not ack:
                    # Orderly shutdown by the peer (or by our own close()).
                    break
                if ack == b"\x01":
                    self._inflight_msgs.release()
                    with self._in_flight_zero:
                        self._in_flight_count -= 1
                        if self._in_flight_count == 0:
                            self._in_flight_zero.notify_all()
        except OSError as e:
            # An error after close() is just our own socket teardown.
            if not self._closed.is_set():
                _log.warning("ack thread %s exiting: %s", self._peer, e)
        finally:
            # Set the flag before notifying so woken waiters observe it.
            self._closed.set()
            with self._in_flight_zero:
                self._in_flight_zero.notify_all()

    def send(self, data: optree.PyTree[np.ndarray]) -> None:
        """Send a pytree sample to the server.

        Blocks while ``max_inflight_msgs`` messages are unacknowledged.

        Raises:
            ConnectionClosedError: client closed or peer disconnected,
                including while blocked waiting for an in-flight slot.
            ValueError: payload-size mismatch.
        """
        if self._closed.is_set():
            raise ConnectionClosedError(f"TcpClient for {self._peer} is closed")

        leaves, _ = optree.tree_flatten(data)
        payload = b"".join(leaf.tobytes() for leaf in leaves)

        if len(payload) != self._payload_size:
            raise ValueError(
                f"Payload size {len(payload)} != expected {self._payload_size}"
            )

        with self._in_flight_zero:
            self._in_flight_count += 1

        slot_acquired = False
        try:
            # Poll rather than block indefinitely: once the ack thread has
            # exited nobody will ever release a slot, so an unbounded
            # acquire() would hang forever on a dead connection.
            while not self._inflight_msgs.acquire(timeout=0.1):
                if self._closed.is_set():
                    raise ConnectionClosedError(
                        f"TcpClient for {self._peer} is closed"
                    )
            slot_acquired = True

            try:
                self._sock.sendall(payload)
            except OSError as e:
                self._closed.set()
                raise ConnectionClosedError(
                    f"TcpClient for {self._peer} connection lost"
                ) from e
        except ConnectionClosedError:
            # Undo the bookkeeping for the message that was never sent and
            # wake anyone blocked in wait().
            if slot_acquired:
                self._inflight_msgs.release()
            with self._in_flight_zero:
                self._in_flight_count -= 1
                self._in_flight_zero.notify_all()
            raise

    def wait(self) -> None:
        """Block until all in-flight messages have been acknowledged.

        Raises:
            ConnectionClosedError: the client was closed or the peer
                disconnected while messages were still unacknowledged.
        """
        with self._in_flight_zero:
            # Also wake on close: after that no ack can arrive, so waiting
            # for the counter alone would block forever.
            self._in_flight_zero.wait_for(
                lambda: self._in_flight_count == 0 or self._closed.is_set()
            )
            pending = self._in_flight_count
        if pending:
            raise ConnectionClosedError(
                f"TcpClient for {self._peer} closed with "
                f"{pending} unacknowledged message(s)"
            )

    def close(self) -> None:
        """Close the client. Idempotent."""
        # No early return on an already-set flag: the ack thread and a failed
        # send() set it too, and the socket must still be released then.
        # Every step below is safe to repeat.
        self._closed.set()
        try:
            self._sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # Already disconnected or already closed; nothing to shut down.
            pass
        self._sock.close()
        # Wake anything blocked on the in-flight condition.
        with self._in_flight_zero:
            self._in_flight_zero.notify_all()
        # Ack thread is daemon, but we still join briefly so callers can rely
        # on close() leaving no live threads behind in tests.
        self._ack_thread.join(timeout=1.0)

send

send(data: PyTree[ndarray]) -> None

Send a pytree sample to the server.

Raises:

Type Description
ConnectionClosedError

client closed or peer disconnected.

ValueError

payload-size mismatch.

Source code in python/echo/tcp_client.py
def send(self, data: optree.PyTree[np.ndarray]) -> None:
    """Send a pytree sample to the server.

    Blocks while the maximum number of messages is unacknowledged.

    Raises:
        ConnectionClosedError: client closed or peer disconnected,
            including while blocked waiting for an in-flight slot.
        ValueError: payload-size mismatch.
    """
    if self._closed.is_set():
        raise ConnectionClosedError(f"TcpClient for {self._peer} is closed")

    leaves, _ = optree.tree_flatten(data)
    payload = b"".join(leaf.tobytes() for leaf in leaves)

    if len(payload) != self._payload_size:
        raise ValueError(
            f"Payload size {len(payload)} != expected {self._payload_size}"
        )

    with self._in_flight_zero:
        self._in_flight_count += 1

    slot_acquired = False
    try:
        # Poll rather than block indefinitely: once the ack thread has
        # exited nobody will ever release a slot, so an unbounded
        # acquire() would hang forever on a dead connection.
        while not self._inflight_msgs.acquire(timeout=0.1):
            if self._closed.is_set():
                raise ConnectionClosedError(
                    f"TcpClient for {self._peer} is closed"
                )
        slot_acquired = True

        try:
            self._sock.sendall(payload)
        except OSError as e:
            self._closed.set()
            raise ConnectionClosedError(
                f"TcpClient for {self._peer} connection lost"
            ) from e
    except ConnectionClosedError:
        # Undo the bookkeeping for the message that was never sent and
        # wake anyone blocked on the in-flight condition.
        if slot_acquired:
            self._inflight_msgs.release()
        with self._in_flight_zero:
            self._in_flight_count -= 1
            self._in_flight_zero.notify_all()
        raise

wait

wait() -> None

Block until all in-flight messages have been acknowledged.

Source code in python/echo/tcp_client.py
def wait(self) -> None:
    """Block until all in-flight messages have been acknowledged.

    Raises:
        ConnectionClosedError: the client was closed or the peer
            disconnected while messages were still unacknowledged.
    """
    with self._in_flight_zero:
        # Also wake on close: after that no ack can arrive, so waiting
        # for the counter alone would block forever.
        self._in_flight_zero.wait_for(
            lambda: self._in_flight_count == 0 or self._closed.is_set()
        )
        pending = self._in_flight_count
    if pending:
        raise ConnectionClosedError(
            f"TcpClient for {self._peer} closed with "
            f"{pending} unacknowledged message(s)"
        )

close

close() -> None

Close the client. Idempotent.

Source code in python/echo/tcp_client.py
def close(self) -> None:
    """Close the client. Idempotent."""
    # No early return on an already-set flag: the ack thread and a failed
    # send() set it too, and the socket must still be released then.
    # Every step below is safe to repeat.
    self._closed.set()
    try:
        self._sock.shutdown(socket.SHUT_RDWR)
    except OSError:
        # Already disconnected or already closed; nothing to shut down.
        pass
    self._sock.close()
    # Wake anything blocked on the in-flight condition.
    with self._in_flight_zero:
        self._in_flight_zero.notify_all()
    # Ack thread is daemon, but we still join briefly so callers can rely
    # on close() leaving no live threads behind in tests.
    self._ack_thread.join(timeout=1.0)

ConnectionClosedError

Bases: ConnectionError

Raised on send() after a TcpClient has been closed or its peer has disconnected.

Source code in python/echo/tcp_client.py
class ConnectionClosedError(ConnectionError):
    """Signal that a TcpClient can no longer send.

    Raised by send() once the client has been closed locally or its peer
    has disconnected.
    """