Multi-host utilities
- mlip.utils.multihost.create_device_mesh() → Mesh
Create a 1D device mesh for data parallelism. The function is cached, so the same Mesh object is returned on every call.
- Returns:
A Mesh with a single ‘devices’ axis containing all global devices.
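The caching behavior described above can be sketched with `functools.lru_cache` around `jax.sharding.Mesh`. This is a minimal illustration, not the library's implementation. The helper name `make_1d_mesh` is hypothetical, and the sketch assumes the mesh is built from `jax.devices()` in its default order.

```python
import functools

import jax
import numpy as np
from jax.sharding import Mesh


@functools.lru_cache(maxsize=None)
def make_1d_mesh() -> Mesh:
    """Hypothetical sketch: one 'devices' axis spanning every global device."""
    # jax.devices() lists the devices of *all* processes, not only local ones.
    devices = np.array(jax.devices())
    return Mesh(devices, axis_names=("devices",))


mesh = make_1d_mesh()
print(mesh.shape)  # maps 'devices' to jax.device_count()
```

Because of the cache, repeated calls return the identical object, so shardings built from it in different places refer to the same mesh.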
- mlip.utils.multihost.create_replicated_sharding(mesh: Mesh) → NamedSharding
Create a sharding for fully replicated data (parameters, optimizer state, EMA).
- Parameters:
mesh – The device mesh to use for sharding.
- Returns:
A NamedSharding that replicates data across all devices.
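A replicated `NamedSharding` is conventionally built from an empty `PartitionSpec`, which partitions no array axis. The sketch below assumes this is roughly what `create_replicated_sharding` does (the source does not show its body) and builds its own mesh so that it stands alone.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("devices",))

# An empty PartitionSpec partitions no axis: every device holds the full array.
replicated = NamedSharding(mesh, PartitionSpec())

params = {"w": jnp.ones((4, 3)), "b": jnp.zeros((3,))}
# device_put applies one sharding to every leaf of the PyTree.
params = jax.device_put(params, replicated)
```

Every device ends up with a full copy of each leaf, which is what data-parallel training expects of parameters and optimizer state.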
- mlip.utils.multihost.create_dp_sharding(mesh: Mesh) → NamedSharding
Create a sharding for data split across devices (keys, batches).
- Parameters:
mesh – The device mesh to use for sharding.
- Returns:
A NamedSharding that shards data along the first axis across devices.
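Sharding along the first axis usually means a `PartitionSpec` that names the mesh axis for dimension 0. The sketch below assumes `create_dp_sharding` behaves like `NamedSharding(mesh, PartitionSpec("devices"))`. It is a single-process illustration, and it requires the leading dimension to be divisible by the number of devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("devices",))

# Split dimension 0 over the 'devices' mesh axis; other dimensions stay whole.
dp = NamedSharding(mesh, PartitionSpec("devices"))

n = jax.device_count()
# Two rows per device so the leading dimension divides evenly.
batch = jnp.arange(n * 2 * 5, dtype=jnp.float32).reshape(n * 2, 5)
batch = jax.device_put(batch, dp)
```

Each device now holds a `(2, 5)` slice. In a multi-process job each host usually holds only its own slice of the global batch. Recent JAX versions provide `jax.make_array_from_process_local_data` for assembling such an array, whereas `jax.device_put` as used here assumes the whole array is available locally.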
- mlip.utils.multihost.sync_string(value: str | None) → str
Sync a string across all hosts.
This function must be called by all processes at the same time.
- Parameters:
value – The string to sync. Pass the string on exactly one process and None on all others.
- Returns:
The synced string, identical on every process.
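One way to meet this contract (one process supplies the string, every process receives it) is two collective broadcasts with `jax.experimental.multihost_utils.broadcast_one_to_all`: first the byte length, then the bytes. This is a hypothetical sketch: `sync_string_sketch` is not the library function, and it assumes exactly one process passes a non-None value. The bytes are carried as `int32` values to keep them in a dtype that round-trips safely through the collective.

```python
from typing import Optional

import jax
import numpy as np
from jax.experimental import multihost_utils


def sync_string_sketch(value: Optional[str]) -> str:
    """Hypothetical sketch: share one process's string with every process.

    Every process must call this; exactly one passes a str, the rest pass None.
    """
    is_source = value is not None
    payload = value.encode("utf-8") if is_source else b""

    # Collective 1: agree on the payload length so all buffers have one shape.
    length = int(
        multihost_utils.broadcast_one_to_all(np.int32(len(payload)), is_source=is_source)
    )
    if length == 0:
        return ""  # every process sees the same length, so all return here together

    # Collective 2: send the UTF-8 bytes, carried as int32 values.
    buf = np.zeros(length, dtype=np.int32)
    if is_source:
        buf[:] = np.frombuffer(payload, dtype=np.uint8)
    received = multihost_utils.broadcast_one_to_all(buf, is_source=is_source)
    return np.asarray(received).astype(np.uint8).tobytes().decode("utf-8")
```

The length broadcast is what lets non-source processes, which only have None, allocate a buffer of the right shape before the second collective.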
- mlip.utils.multihost.only_specific_processes(processes: int | List[int] = 0) → Callable
Decorator to execute a function only if the current process index matches the specified index/indices.
- Parameters:
processes – The process index or a list of process indices for which the function should be executed. Defaults to 0.
- Returns:
The wrapped function, which executes based on the process index.
Examples
```python
from mlip.utils.multihost import only_specific_processes


@only_specific_processes()
def function_for_zero():
    print("Function executed for process 0!")


@only_specific_processes(1)
def function_for_one():
    print("Function executed for process 1!")


@only_specific_processes([2, 3])
def function_for_two_and_three():
    print("Function executed for process 2 or 3!")

# When executed in a multi-process JAX setup, functions will be
# executed/skipped based on process index.
```
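A decorator with this behavior can be sketched by checking `jax.process_index()` each time the wrapped function is called. The name `run_only_on` is hypothetical, and returning `None` on skipped processes is an assumption; the source does not say what the real decorator returns there.

```python
import functools
from typing import Any, Callable, List, Union

import jax


def run_only_on(processes: Union[int, List[int]] = 0) -> Callable:
    """Hypothetical sketch: run the decorated function only on chosen processes."""
    allowed = {processes} if isinstance(processes, int) else set(processes)

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Check the index at call time so the decorator can be applied at import.
            if jax.process_index() in allowed:
                return fn(*args, **kwargs)
            return None  # assumption: skipped processes just get None

        return wrapper

    return decorator


@run_only_on()  # defaults to process 0
def log_on_main(msg: str) -> str:
    print(msg)
    return msg


@run_only_on([2, 3])
def only_on_two_or_three() -> str:
    return "ran"
```

Avoid gating functions that run collectives (for example a jitted computation over a multi-host sharding) this way. Processes that skip the call never join the collective, and the others can hang waiting for them.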
- mlip.utils.multihost.single_host_jax_and_orbax()
Context manager to mock JAX and Orbax functions.
In a multi-host setting, makes JAX look as though it is running on a single host. It also skips Orbax’s device sync checks, which are not relevant in a “single-host” setting.
In a true single-host setting, this context manager does nothing.
Examples
```python
import jax
from mlip.utils.multihost import single_host_jax_and_orbax

print(jax.process_index())  # --> 0,...,num_hosts-1
print(jax.process_count())  # --> num_hosts
print(len(jax.devices()) == len(jax.local_devices()))  # --> False in multi-host setting

with single_host_jax_and_orbax():
    print(jax.process_index())  # --> 0
    print(jax.process_count())  # --> 1
    print(len(jax.devices()) == len(jax.local_devices()))  # --> True
```
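The JAX half of this context manager can be approximated with `unittest.mock.patch.object`. This is a hypothetical sketch (`fake_single_host_jax` is not the library function). It leaves Orbax untouched. Because it patches attributes on the `jax` module, it only affects code that calls `jax.process_index()`, `jax.process_count()`, and `jax.devices()` through that module, not JAX internals or modules that imported those names directly.

```python
import contextlib
from unittest import mock

import jax


@contextlib.contextmanager
def fake_single_host_jax():
    """Hypothetical sketch: make module-level jax.* queries report one host."""
    if jax.process_count() == 1:
        # A true single-host run: nothing to patch.
        yield
        return
    real_local_devices = jax.local_devices
    with contextlib.ExitStack() as stack:
        stack.enter_context(mock.patch.object(jax, "process_index", lambda backend=None: 0))
        stack.enter_context(mock.patch.object(jax, "process_count", lambda backend=None: 1))
        # Report only this host's devices as the "global" device list.
        stack.enter_context(
            mock.patch.object(
                jax, "devices", lambda backend=None: real_local_devices(backend=backend)
            )
        )
        yield
```

All patches are undone when the `with` block exits, including on exceptions, because `ExitStack` unwinds them.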
- mlip.utils.multihost.assert_pytrees_match_across_hosts(tree: Any) → None
Assert that the provided PyTree matches across multiple JAX hosts/processes.
If there are multiple JAX processes, this function checks that the PyTree on the current host matches the PyTree on the host with process index 0. If there’s only one JAX process, the function returns immediately without any checks.
- Parameters:
tree (Any) – The PyTree to check for consistency across hosts.
- Raises:
ValueError – If the PyTrees do not match across hosts. The error message gives a detailed path to each difference.
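A cross-host check like this can be sketched by broadcasting process 0's copy with `jax.experimental.multihost_utils.broadcast_one_to_all` and comparing leaf by leaf. The helper `check_matches_process_zero` is hypothetical and uses exact equality. It also assumes every host already has the same tree structure and leaf shapes, since the broadcast itself requires them to agree.

```python
from typing import Any

import jax
import numpy as np
from jax.experimental import multihost_utils


def check_matches_process_zero(tree: Any) -> None:
    """Hypothetical sketch: raise ValueError if this host's tree differs from process 0's."""
    if jax.process_count() == 1:
        return  # single process: nothing to compare against

    # Collective: every process receives process 0's leaves (the default source).
    reference = multihost_utils.broadcast_one_to_all(tree)

    local_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    reference_leaves = jax.tree_util.tree_leaves(reference)
    mismatched = [
        jax.tree_util.keystr(path)
        for (path, local), ref in zip(local_with_paths, reference_leaves)
        if not np.array_equal(np.asarray(local), np.asarray(ref))
    ]
    if mismatched:
        raise ValueError(
            f"Process {jax.process_index()} differs from process 0 at: {mismatched}"
        )
```

JAX's own `multihost_utils.assert_equal` performs a related check, but it raises `AssertionError` rather than `ValueError`.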
- mlip.utils.multihost.assert_pytrees_match(a: Any, b: Any) → None
Assert that the two provided PyTrees have matching structures and values.
PyTrees are JAX’s name for nested container structures (such as lists, tuples, and dicts) whose leaves are typically arrays.
- Parameters:
a (Any) – The first PyTree.
b (Any) – The second PyTree.
- Raises:
ValueError – If the PyTrees do not match. The error message gives a detailed path to each difference.
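The path-annotated error can be sketched with `jax.tree_util`: compare tree structures first, then flatten with key paths and compare leaves. `diff_pytrees` is a hypothetical name, and the exact-equality test is an assumption; the source does not say whether `assert_pytrees_match` uses a tolerance for floating-point leaves.

```python
from typing import Any

import jax
import numpy as np


def diff_pytrees(a: Any, b: Any) -> None:
    """Hypothetical sketch: raise ValueError naming every mismatching leaf path."""
    a_def = jax.tree_util.tree_structure(a)
    b_def = jax.tree_util.tree_structure(b)
    if a_def != b_def:
        raise ValueError(f"PyTree structures differ:\n{a_def}\nvs\n{b_def}")

    a_with_paths, _ = jax.tree_util.tree_flatten_with_path(a)
    b_leaves = jax.tree_util.tree_leaves(b)
    diffs = []
    for (path, x), y in zip(a_with_paths, b_leaves):
        # np.array_equal also returns False when the shapes differ.
        if not np.array_equal(np.asarray(x), np.asarray(y)):
            diffs.append(f"{jax.tree_util.keystr(path)}: {x!r} != {y!r}")
    if diffs:
        raise ValueError("PyTree leaves differ:\n" + "\n".join(diffs))


params = {"layer": {"w": np.ones(2), "b": np.zeros(2)}}
diff_pytrees(params, {"layer": {"w": np.ones(2), "b": np.zeros(2)}})  # no error
```

Checking the structure first keeps the leaf comparison meaningful: once the tree definitions are equal, the two flattened leaf lists line up position by position.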