Multi-host utilities

mlip.utils.multihost.create_device_mesh() → Mesh

Create a 1D device mesh for data parallelism. The function is cached so that the same mesh object is reused across calls.

Returns:

A Mesh with a single ‘devices’ axis containing all global devices.
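
As a rough illustration of the behaviour described above, a cached 1D mesh could be built as follows. This is a minimal sketch, not the library's implementation: the name `create_device_mesh_sketch` is hypothetical, and using `functools.cache` for the caching is an assumption.

```python
import functools

import jax
import numpy as np
from jax.sharding import Mesh


@functools.cache  # assumption: caching is done with functools.cache
def create_device_mesh_sketch() -> Mesh:
    """Hypothetical stand-in: 1D mesh over all global devices."""
    # jax.devices() lists the devices of every process, not only local ones.
    return Mesh(np.array(jax.devices()), axis_names=("devices",))


mesh = create_device_mesh_sketch()
# Repeated calls return the very same object thanks to the cache.
same_object = mesh is create_device_mesh_sketch()
```

Reusing one mesh object avoids building equivalent meshes in different parts of a training loop.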

mlip.utils.multihost.create_replicated_sharding(mesh: Mesh) → NamedSharding

Create a sharding for fully replicated data (parameters, optimizer state, EMA).

Parameters:

mesh – The device mesh to use for sharding.

Returns:

A NamedSharding that replicates data across all devices.
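
A replicated sharding of this kind can be sketched with an empty `PartitionSpec`, which names no mesh axis and therefore leaves every array dimension unsplit. The snippet assumes a 1D mesh with a `devices` axis as described for `create_device_mesh`; the variable names and the toy `params` dictionary are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed mesh layout: one axis named "devices" over all devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))

# An empty PartitionSpec shards no dimension, so data is replicated.
replicated = NamedSharding(mesh, PartitionSpec())

# Toy parameters, placed with a full copy on every device.
params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}
params = jax.device_put(params, replicated)

# Every addressable shard holds the whole array.
shard_shapes = [s.data.shape for s in params["w"].addressable_shards]
```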

mlip.utils.multihost.create_dp_sharding(mesh: Mesh) → NamedSharding

Create a sharding for data that is split across devices (keys, batches).

Parameters:

mesh – The device mesh to use for sharding.

Returns:

A NamedSharding that shards data along the first axis across devices.
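
Sharding along the first axis can be sketched by naming the mesh axis in the first position of a `PartitionSpec`. The example assumes the `devices` axis name from `create_device_mesh`; the batch shape is made up, and its leading dimension is chosen to be divisible by the number of devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))

# Split dimension 0 over the "devices" axis; other dimensions stay whole.
dp_sharding = NamedSharding(mesh, PartitionSpec("devices"))

# Toy batch with 2 rows per device and 3 features per row.
batch = jnp.arange(n_dev * 2 * 3).reshape(n_dev * 2, 3)
batch = jax.device_put(batch, dp_sharding)

# Each device now holds a (2, 3) slice of the batch.
shard_shapes = [s.data.shape for s in batch.addressable_shards]
```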

mlip.utils.multihost.sync_string(value: str | None) → str

Sync a string across all hosts.

This function must be called from all processes at the same time.

Parameters:

value – The string to sync. Pass the value to sync on exactly one process and None on all others.

Returns:

The synced string in all processes.
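
One way such a synchronisation could work is to encode the string into a fixed-size integer buffer and broadcast it with `jax.experimental.multihost_utils.broadcast_one_to_all`. This is a sketch under stated assumptions, not the library's implementation: `sync_string_sketch` and `MAX_LEN` are hypothetical, and the source is assumed to be process 0, which is the default of `broadcast_one_to_all`.

```python
from typing import Optional

import numpy as np
from jax.experimental import multihost_utils

MAX_LEN = 256  # hypothetical buffer size; all processes need equal shapes


def sync_string_sketch(value: Optional[str]) -> str:
    """Hypothetical stand-in: broadcast a string from process 0."""
    # Slot 0 stores the byte length, the remaining slots the UTF-8 bytes.
    buf = np.zeros(MAX_LEN + 1, dtype=np.int32)
    if value is not None:
        data = value.encode("utf-8")
        if len(data) > MAX_LEN:
            raise ValueError("string too long for the sync buffer")
        buf[0] = len(data)
        buf[1 : 1 + len(data)] = np.frombuffer(data, dtype=np.uint8)
    # Collective call: every process must reach this line.
    out = np.asarray(multihost_utils.broadcast_one_to_all(buf))
    n = int(out[0])
    return out[1 : 1 + n].astype(np.uint8).tobytes().decode("utf-8")


synced = sync_string_sketch("run-42")
```

Because the broadcast is a collective operation, a process that skips the call would leave the others waiting, which is why the function has to be called everywhere at once.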

mlip.utils.multihost.only_specific_processes(processes: int | List[int] = 0) → Callable

Decorator to execute a function only if the current process index matches the specified index/indices.

Parameters:

processes – The process index or a list of process indices for which the function should be executed. Defaults to 0.

Returns:

The wrapped function, which executes based on the process index.

Examples

@only_specific_processes()
def function_for_zero():
    print("Function executed for process 0!")

@only_specific_processes(1)
def function_for_one():
    print("Function executed for process 1!")

@only_specific_processes([2, 3])
def function_for_two_and_three():
    print("Function executed for process 2 or 3!")

# When executed in a multi-process JAX setup, functions will be
# executed/skipped based on process index.
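
For readers curious how a decorator of this kind can be written, the following is a minimal sketch. It is not the library's code: the name `only_specific_processes_sketch` is hypothetical, and returning `None` for skipped calls is an assumption.

```python
import functools
from typing import Callable, List, Union

import jax


def only_specific_processes_sketch(
    processes: Union[int, List[int]] = 0,
) -> Callable:
    """Hypothetical stand-in: run the function only on selected processes."""
    allowed = {processes} if isinstance(processes, int) else set(processes)

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # The process index is read at call time, not at decoration time.
            if jax.process_index() in allowed:
                return fn(*args, **kwargs)
            return None  # assumption: skipped calls return None

        return wrapper

    return decorator


@only_specific_processes_sketch()
def report() -> str:
    return "ran on process 0"


@only_specific_processes_sketch([1, 2])
def report_elsewhere() -> str:
    return "ran on process 1 or 2"
```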

mlip.utils.multihost.single_host_jax_and_orbax()

Context manager to mock JAX and Orbax functions.

Makes JAX look as though it is running on a single host if in a multi-host setting. Additionally, skips Orbax’s device sync checks (as they are not relevant in a “single-host” setting).

In a true single-host setting, this context manager does nothing.

Examples

print(jax.process_index())  # --> 0,...,num_hosts-1
print(jax.process_count())  # --> num_hosts
print(
    len(jax.devices()) == len(jax.local_devices())
)  # --> False in multi-host setting

with single_host_jax_and_orbax():
    print(jax.process_index())  # --> 0
    print(jax.process_count())  # --> 1
    print(len(jax.devices()) == len(jax.local_devices()))  # --> True
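
The JAX half of this behaviour can be approximated with `unittest.mock`. The sketch below is an assumption about how such a context manager might be built, not the library's implementation; `single_host_jax_sketch` is a hypothetical name, and the Orbax part is left out because the source does not specify which Orbax checks are skipped.

```python
import contextlib
from unittest import mock

import jax


@contextlib.contextmanager
def single_host_jax_sketch():
    """Hypothetical stand-in: make JAX report a single-host setup."""
    if jax.process_count() == 1:
        # True single-host setting: nothing to patch.
        yield
        return
    local = jax.local_devices()
    # Patch the module-level query functions for the duration of the block.
    with mock.patch.object(jax, "process_index", return_value=0), \
         mock.patch.object(jax, "process_count", return_value=1), \
         mock.patch.object(jax, "devices", return_value=local), \
         mock.patch.object(jax, "device_count", return_value=len(local)):
        yield


with single_host_jax_sketch():
    inside_count = jax.process_count()
    inside_devices_match = len(jax.devices()) == len(jax.local_devices())
```

Only code that looks these functions up through the `jax` module at call time sees the patched values.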

mlip.utils.multihost.assert_pytrees_match_across_hosts(tree: Any) → None

Assert that the provided PyTree matches across multiple JAX hosts/processes.

If there are multiple JAX processes, this function checks that the PyTree on the current host matches the PyTree on the host with process index 0. If there’s only one JAX process, the function returns immediately without any checks.

Parameters:

tree (Any) – The PyTree to check for consistency across hosts.

Raises:

ValueError – If the PyTrees do not match across the hosts; raised with a detailed path of the differences.
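
A check against process 0 could, for instance, broadcast that process's leaves and compare them locally. The sketch below illustrates the idea under several assumptions: both function names are hypothetical, `broadcast_one_to_all` is used as the transport, all leaves are numeric arrays, and the tree structure is already identical on every host (the broadcast itself requires this).

```python
from typing import Any, List

import jax
import numpy as np
from jax.experimental import multihost_utils


def mismatched_paths(local_tree: Any, reference_tree: Any) -> List[str]:
    """Return key paths of leaves that differ from the reference."""
    local, _ = jax.tree_util.tree_flatten_with_path(local_tree)
    reference = jax.tree_util.tree_leaves(reference_tree)
    return [
        jax.tree_util.keystr(path)
        for (path, x), y in zip(local, reference)
        if not np.array_equal(np.asarray(x), np.asarray(y))
    ]


def assert_match_across_hosts_sketch(tree: Any) -> None:
    """Hypothetical stand-in: compare the local tree with process 0."""
    if jax.process_count() == 1:
        return  # nothing to compare against
    # Collective call: every process receives the leaves of process 0.
    reference = multihost_utils.broadcast_one_to_all(tree)
    bad = mismatched_paths(tree, reference)
    if bad:
        raise ValueError("PyTree differs from process 0 at: " + ", ".join(bad))


tree = {"w": np.ones((2, 2)), "step": np.array(3)}
other = {"w": np.ones((2, 2)), "step": np.array(4)}
diff = mismatched_paths(tree, other)
```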

mlip.utils.multihost.assert_pytrees_match(a: Any, b: Any) → None

Assert that the two provided PyTrees have matching structures and values.

PyTrees are a way to flexibly handle nested data structures in the JAX library.

Parameters:
  • a (Any) – The first PyTree.

  • b (Any) – The second PyTree.

Raises:

ValueError – If the PyTrees do not match; raised with a detailed path of the differences.
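
To make the idea of matching structures and values concrete, here is a small sketch of such a comparison built on `jax.tree_util`. It is an illustration only: `assert_pytrees_match_sketch` is a hypothetical name, and the exact error message format of the real function is not given in the source.

```python
from typing import Any

import jax
import numpy as np


def assert_pytrees_match_sketch(a: Any, b: Any) -> None:
    """Hypothetical stand-in: compare structure first, then leaf values."""
    leaves_a, treedef_a = jax.tree_util.tree_flatten_with_path(a)
    leaves_b, treedef_b = jax.tree_util.tree_flatten_with_path(b)
    if treedef_a != treedef_b:
        raise ValueError(
            f"PyTree structures differ: {treedef_a} vs {treedef_b}"
        )
    # Collect the key path of every leaf whose values differ.
    diffs = [
        jax.tree_util.keystr(path)
        for (path, x), (_, y) in zip(leaves_a, leaves_b)
        if not np.array_equal(np.asarray(x), np.asarray(y))
    ]
    if diffs:
        raise ValueError("PyTrees differ at: " + ", ".join(diffs))


# A PyTree is any nesting of containers (dicts, lists, tuples) with leaves.
tree_a = {"layer": {"w": np.ones(3), "b": np.zeros(1)}, "step": 1}
tree_b = {"layer": {"w": np.ones(3), "b": np.ones(1)}, "step": 1}

assert_pytrees_match_sketch(tree_a, tree_a)  # passes silently
try:
    assert_pytrees_match_sketch(tree_a, tree_b)
    message = ""
except ValueError as err:
    message = str(err)
```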