Model training

To start a model training run, the following prerequisites are needed:

  • Loading and preprocessing a training and validation dataset, as described here.

  • Initializing a force field, as described here.

  • Setting up a loss (see below for details).

  • Setting up an optimizer (see below for details).

  • Creating an instance of TrainingLoopConfig (also accessible as TrainingLoop.Config).

  • Optionally, creating an instance of TrainingIOHandler (a default is used if none is provided; see below for details).

Once these objects are set up, we can create an instance of TrainingLoop and start the training run:

from mlip.training import TrainingLoop

# Prerequisites
train_set, validation_set, dataset_info = _get_dataset()  # placeholder
force_field = _get_force_field()  # placeholder
loss = _get_loss()  # placeholder
optimizer = _get_optimizer()  # placeholder
io_handler = _get_training_loop_io_handler()  # placeholder
train_config = TrainingLoop.Config(**config_kwargs)  # config_kwargs is a placeholder

# Create the TrainingLoop instance
training_loop = TrainingLoop(
    train_dataset=train_set,
    validation_dataset=validation_set,
    force_field=force_field,
    loss=loss,
    optimizer=optimizer,
    config=train_config,
    io_handler=io_handler,  # optional; a default IO handler is used if omitted
)

# Start the model training
training_loop.run()

The final TrainingState can be accessed after the run like this:

final_training_state = training_loop.training_state
final_params = final_training_state.params

Important: The final parameters are not necessarily the ones that performed best on the validation set. The best parameters can be accessed with training_loop.best_model.params, and training_loop.best_model is the ForceField instance that holds them. Besides the checkpointing API described further below, you can also save a trained force field with the function save_model_to_zip(). It writes a lightweight zip archive without any training state, which is suitable if you only want to use the model for inference or simulation tasks later.

Note that you can also evaluate the trained model on a test dataset after training by using the test() method of the TrainingLoop instance. Moreover, we support evaluating on multiple validation sets separately during training; see the section on multiple validation sets below for details.

In the following, we describe the prerequisites listed above in more detail.

Loss

All losses must be passed as instances of the Loss class. A Loss takes a list of LossTerm instances and a list of corresponding schedules (functions that map an epoch number to a weight for the loss term). Custom loss terms can easily be added by subclassing LossTerm, although this is usually not required (see the information on built-in losses below). A simple example of a custom loss term implementation is provided below.

import jax.numpy as jnp

from mlip.models.loss import Loss, LossTerm
from mlip.graph import Graph

class CustomLossTerm(LossTerm):
    """A simple custom MAE energy loss."""

    property_name = "energy"

    def __call__(self, pred_graph: Graph, ref_graph: Graph) -> float:
        """Outputs an MAE energy loss."""
        # Use jax.numpy (not NumPy) so that JAX can trace and differentiate the loss
        return jnp.mean(jnp.abs(pred_graph.globals.energy - ref_graph.globals.energy))

# Instantiate a Loss with this custom loss term and a constant weight of 1.0
loss = Loss([CustomLossTerm()], [lambda epoch_number: 1.0])

For convenience, we implement two losses, the mean squared error loss (MSELoss) and the Huber loss (HuberLoss). Both are subclasses of Loss and already include all loss terms available in this library (for energies, forces, stress, Hessians, atomic partial charges, total charge, and dipole moment). See the API reference for these classes for details.

Note

To predict non-default properties (e.g., Hessians or atomic partial charges), one must instantiate the ForceField with the correct required_properties argument such that it outputs these properties and can learn from the corresponding training labels.

For example, to train with an MSE loss that only includes energy and force matching objectives, initialize it like this:

import optax
from mlip.models.loss import MSELoss

# uses default weight schedules
loss = MSELoss()

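# uses a weight flip schedule: piecewise_constant_schedule multiplies the value
# by the given scale at the boundary, so at epoch 100 the energy weight goes
# from 1.0 to 25.0 and the forces weight goes from 25.0 to 1.0
energy_weight_schedule = optax.piecewise_constant_schedule(1.0, {100: 25.0})
forces_weight_schedule = optax.piecewise_constant_schedule(25.0, {100: 0.04})
loss = MSELoss(energy_weight_schedule, forces_weight_schedule)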
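Because the scales passed to optax.piecewise_constant_schedule act as multiplicative factors rather than target values, the flip above is easy to misread. The pure-Python sketch below mirrors that behavior for illustration only: piecewise_constant is a hypothetical helper, not Optax's implementation, and it applies a scale once epoch >= boundary. It evaluates both weights well before and after epoch 100.

```python
from collections.abc import Callable


def piecewise_constant(
    init_value: float, boundaries_and_scales: dict[int, float]
) -> Callable[[int], float]:
    """Hypothetical stand-in mimicking a piecewise-constant weight schedule."""

    def schedule(epoch: int) -> float:
        value = init_value
        for boundary, scale in sorted(boundaries_and_scales.items()):
            # Each scale multiplies the running value once its boundary is reached
            if epoch >= boundary:
                value *= scale
        return value

    return schedule


energy_weight = piecewise_constant(1.0, {100: 25.0})
forces_weight = piecewise_constant(25.0, {100: 0.04})

for epoch in (0, 99, 150):
    print(epoch, energy_weight(epoch), forces_weight(epoch))
```

For actual training, keep using the Optax schedules shown above; this helper only makes the arithmetic of the weight flip visible (1 and 25 before the boundary, 25 and 1 after it).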

Both built-in losses can also compute an extended set of metrics if the extended_metrics argument is set to True in the constructor (it defaults to False). See the implementation of the compute_eval_metrics() function (used inside the default losses) for details on the computed metrics.

Note that although the loss class provides these metrics averaged over a single input batch only, the training loop reweights them by the number of real (not dummy) graphs in each batch. As a result, the metrics logged during training are correctly averaged over the whole dataset.
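The reweighting described above can be sketched as follows. This is an illustration under the assumption that each batch reports one mean metric together with its number of real graphs; dataset_average is a hypothetical helper, not part of mlip, and the batch values are made up.

```python
def dataset_average(batch_means: list[float], num_real_graphs: list[int]) -> float:
    """Combine per-batch mean metrics into a dataset-level mean (hypothetical helper).

    Each batch mean is weighted by the number of real (not dummy) graphs in
    that batch, so a mostly padded final batch does not distort the result.
    """
    weighted_sum = sum(mean * n for mean, n in zip(batch_means, num_real_graphs))
    return weighted_sum / sum(num_real_graphs)


# Two full batches with 4 real graphs each and a last batch with only 1 real graph
batch_means = [0.2, 0.4, 1.0]
num_real_graphs = [4, 4, 1]
print(dataset_average(batch_means, num_real_graphs))
```

A plain mean of the three batch means would give about 0.533 instead of about 0.378, because the nearly empty last batch would count as much as the full ones.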

Optimizer

The optimizer provided to the TrainingLoop can be any Optax optimizer. However, this library also provides a specialized optimizer pipeline inspired by the PyTorch MACE implementation. It is configurable via an OptimizerConfig object whose defaults are sensible for training MLIP models. We also suggest consulting our white paper for recommendations on adapting these defaults to specific models. For instance, ViSNet and NequIP seem to be more prone to NaNs with the default learning rate and benefit from a smaller one, e.g., 1e-4.

The default MLIP optimizer can be set up like this:

from mlip.training import get_default_mlip_optimizer, OptimizerConfig

# with default config
optimizer = get_default_mlip_optimizer()

# with modified config
optimizer = get_default_mlip_optimizer(OptimizerConfig(**config_kwargs))

See the API reference for get_default_mlip_optimizer and OptimizerConfig for further details on the components of this MLIP optimizer and how it works internally.

IO handling and logging

Checkpointing

During training, the training state should be checkpointed and metrics should be logged. The TrainingIOHandler class manages both tasks. It comes with its own config, the TrainingIOHandlerConfig, which, like most other configs in the library, can also be accessed via TrainingIOHandler.Config. The IO handler uses Orbax Checkpointing to save and restore model checkpoints. Loading a trained model for simulations or other inference tasks also relies on these model checkpoints (see load_parameters_from_checkpoint()). The local checkpointing location is set in the config class, while checkpoints can be uploaded to remote storage locations by providing a data upload function:

import os
from mlip.training import TrainingIOHandler

io_config = TrainingIOHandler.Config(**config_kwargs)

def remote_storage_sync_fun(source: str | os.PathLike) -> None:
    """Makes sure local data in source is uploaded to remote storage"""
    pass  # placeholder

io_handler = TrainingIOHandler(io_config, remote_storage_sync_fun)
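One possible upload function is sketched below. It assumes, hypothetically, that the remote storage is reachable as a mounted directory (for example via a network or FUSE mount); for object stores you would call the provider's own SDK instead. The factory make_remote_storage_sync_fun is a hypothetical name, not part of mlip.

```python
from __future__ import annotations

import os
import shutil
from pathlib import Path


def make_remote_storage_sync_fun(remote_root: str | os.PathLike):
    """Build a sync function that mirrors local data into remote_root.

    Assumes remote storage is mounted as a local directory (hypothetical setup).
    """
    remote_root = Path(remote_root)

    def remote_storage_sync_fun(source: str | os.PathLike) -> None:
        """Makes sure local data in source is uploaded to remote storage."""
        source = Path(source)
        target = remote_root / source.name
        if source.is_dir():
            # Copy the whole tree, overwriting files that already exist remotely
            shutil.copytree(source, target, dirs_exist_ok=True)
        else:
            # Single file: make sure the parent exists, then copy with metadata
            target.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(source, target)

    return remote_storage_sync_fun
```

It would then be passed as io_handler = TrainingIOHandler(io_config, make_remote_storage_sync_fun("/mnt/remote-checkpoints")), where the mount path is a placeholder. Using dirs_exist_ok=True keeps repeated syncs of the same checkpoint directory from failing.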

Locally, after the training run has started, the checkpointing location contains a dataset_info.json file with the saved DatasetInfo object, and a model subdirectory with the model checkpoints, one for each epoch that produced a new best model as judged by the validation set loss. We recommend manually saving other metadata, such as the model config used, in this location as well.

Note that if the checkpointing directory provided in the TrainingIOHandlerConfig is a Path-like object, it is forwarded to the Orbax checkpointing code as is. Hence, any direct remote-storage checkpointing that Orbax makes available is also indirectly supported by this library. See the Orbax documentation for details.

We also support intra-epoch checkpointing, which is useful when training on preemptible cloud compute instances that require frequent checkpointing. For details, see the API reference of TrainingIOHandlerConfig.

Logging

For advanced logging, e.g., to an experiment tracking platform (such as MLflow), one can also attach custom logging functions to the IO handler:

from typing import Any

from mlip.training.training_io_handler import LogCategory

def train_logging_fun(
    category: LogCategory, to_log: dict[str, Any], epoch_number: int
) -> None:
    """Advanced logging function"""
    pass  # placeholder

io_handler.attach_logger(train_logging_fun)
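As a concrete, hypothetical example of such a logger, the sketch below appends one JSON record per call to a local file, which a separate process could forward to an experiment tracker. It assumes that category is an Enum member (falling back to str() otherwise) and that to_log maps metric names to values; numeric values are stored as floats and anything else as strings. The factory make_jsonl_logger is not part of mlip.

```python
from __future__ import annotations

import json
from pathlib import Path
from typing import Any


def _to_json_value(value: Any) -> Any:
    """Store numeric values (incl. NumPy/JAX scalars) as floats, anything else as text."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return str(value)


def make_jsonl_logger(log_path: str | Path):
    """Return a logging function that appends one JSON line per call."""
    log_path = Path(log_path)

    def train_logging_fun(category: Any, to_log: dict[str, Any], epoch_number: int) -> None:
        record = {
            # Enum members (such as LogCategory values) expose .name; fall back to str()
            "category": getattr(category, "name", str(category)),
            "epoch": epoch_number,
            "metrics": {key: _to_json_value(value) for key, value in to_log.items()},
        }
        with log_path.open("a") as f:
            f.write(json.dumps(record) + "\n")

    return train_logging_fun
```

It would be attached with io_handler.attach_logger(make_jsonl_logger("train_log.jsonl")), where the file name is a placeholder.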

See the documentation of LogCategory for more details on what type of data can be logged with such a logger during training. Furthermore, this library provides two built-in logging functions that can be attached to the IO handler: log_metrics_to_table(), which prints the training metrics to the console as a table (using Rich tables), and log_metrics_to_line(), which logs the metrics on a single line.

These logging functions automatically convert any MSE metrics to RMSE for easier interpretation. Internally, we only keep track of the MSE instead of the RMSE because the square root must be taken at the very end, after any averaging across batches or devices has happened. To do the same conversion in a custom logging function, you can use the helper function convert_mse_to_rmse_in_logs() that we provide for this task.
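To see why the square root must come last, consider the toy numbers below, invented purely for illustration: two batches with different numbers of force components. Taking the square root of the pooled MSE gives the true RMSE, whereas averaging per-batch RMSEs does not, whether or not the average is weighted by batch size.

```python
import math

# Toy data: sum of squared force errors and number of force components per batch
batch_sq_error_sums = [4.0, 1.0]
batch_counts = [1, 4]

# Correct: pool the squared errors over all batches first, then take the root
mse = sum(batch_sq_error_sums) / sum(batch_counts)
rmse_correct = math.sqrt(mse)

# Incorrect: take the root per batch, then average the per-batch RMSEs
batch_rmses = [math.sqrt(s / n) for s, n in zip(batch_sq_error_sums, batch_counts)]
rmse_unweighted = sum(batch_rmses) / len(batch_rmses)
rmse_weighted = sum(r * n for r, n in zip(batch_rmses, batch_counts)) / sum(batch_counts)

print(rmse_correct, rmse_unweighted, rmse_weighted)  # 1.0 1.25 0.8
```

Averaging the MSE itself is linear and therefore safe to split across batches and devices; only the final square root is nonlinear.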

Note that it is possible to omit the io_handler argument in the TrainingLoop class. In that case, a default IO handler is set up internally and used. This IO handler does not include checkpointing, but it does have the log_metrics_to_line() logging function attached by default.

Multi-host training

The mlip library supports multi-host (multi-node) training via JAX’s built-in distributed runtime. When training across multiple hosts, each host manages its own local devices and coordinates with other hosts through a designated coordinator process. The training loop automatically handles data-parallel sharding of batches across all global devices.

To run multi-host training, you need to initialize the JAX distributed runtime before any other JAX calls. This is done via jax.distributed.initialize(). Below, we present a minimal example:

import jax

# Call exactly one of the following options, once per process.

# Option 1: Automatic discovery (e.g. on SLURM clusters)
jax.distributed.initialize()

# Option 2: Explicit coordinator (e.g. on Kubernetes / custom setups)
# jax.distributed.initialize(
#     coordinator_address="host0:1234",
#     num_processes=4,
#     process_id=0,  # unique per process, from 0 to num_processes - 1
# )

# After initialization, jax.devices() returns all global devices
# and jax.local_devices() returns only this host's devices.
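With Option 2, every process must pass its own process_id while sharing the coordinator address and process count. A common pattern is to read these values from environment variables set by the launcher; the sketch below assumes the hypothetical variable names COORDINATOR_ADDRESS, NUM_PROCESSES, and PROCESS_ID, which your launcher would have to define, and distributed_init_kwargs is likewise a hypothetical helper.

```python
from __future__ import annotations

import os
from collections.abc import Mapping


def distributed_init_kwargs(env: Mapping[str, str] | None = None) -> dict:
    """Read explicit jax.distributed.initialize() arguments from environment variables.

    The variable names are hypothetical; each process needs a distinct PROCESS_ID.
    """
    env = os.environ if env is None else env
    num_processes = int(env["NUM_PROCESSES"])
    process_id = int(env["PROCESS_ID"])
    # Catch misconfigured launchers early instead of hanging at initialization
    if not 0 <= process_id < num_processes:
        raise ValueError(f"PROCESS_ID must be in [0, {num_processes}), got {process_id}")
    return {
        "coordinator_address": env["COORDINATOR_ADDRESS"],  # e.g. "host0:1234"
        "num_processes": num_processes,
        "process_id": process_id,
    }
```

Each process would then call jax.distributed.initialize(**distributed_init_kwargs()) before any other JAX call.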

Once JAX distributed is initialized, the rest of the training code (dataset creation, training loop, checkpointing) works the same as in the single-host case, and the library handles the data sharding internally.

For more details, see the JAX multi-process documentation and the Flax distributed training guide.

Multiple validation sets

Instead of passing a single validation set to the training loop, you can also pass a dictionary of different validation sets. This does not change the training process itself: the validation loss used to decide whether to checkpoint a model is computed as a weighted average over all given validation sets.

However, the evaluation metrics are reported per set, i.e., each metric name is prefixed with the name of its validation set. See the example below:

from mlip.training import TrainingLoop

validation_sets = {
    "organics": _get_organics_validation_set_placeholder(),
    "materials": _get_materials_validation_set_placeholder(),
}

training_loop = TrainingLoop(
    validation_dataset=validation_sets,
    **other_kwargs,
)

# Instead of metrics like "mse_f" for the MSE of forces, one will now get
# "organics_mse_f" and "materials_mse_f" separately
training_loop.run()
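The documentation above does not specify how the validation sets are weighted. The sketch below only illustrates the two documented behaviors, per-set metric prefixes and a single weighted validation loss for checkpointing, under the assumption, made purely for illustration, that each set is weighted by its number of graphs. The function combine_validation_metrics, the "loss" metric key, and all numbers are hypothetical and do not reflect mlip's actual code.

```python
from __future__ import annotations


def combine_validation_metrics(
    per_set_metrics: dict[str, dict[str, float]],
    per_set_num_graphs: dict[str, int],
) -> tuple[dict[str, float], float]:
    """Prefix metrics with their set name and compute a weighted overall loss.

    Assumption: each set's weight is its share of the total number of graphs.
    """
    prefixed = {
        f"{set_name}_{metric_name}": value
        for set_name, metrics in per_set_metrics.items()
        for metric_name, value in metrics.items()
    }
    total_graphs = sum(per_set_num_graphs.values())
    overall_loss = sum(
        per_set_metrics[set_name]["loss"] * num_graphs / total_graphs
        for set_name, num_graphs in per_set_num_graphs.items()
    )
    return prefixed, overall_loss


metrics, overall_loss = combine_validation_metrics(
    {"organics": {"loss": 0.1, "mse_f": 0.02}, "materials": {"loss": 0.4, "mse_f": 0.05}},
    {"organics": 300, "materials": 100},
)
print(metrics, overall_loss)
```

Under this assumption, the larger organics set dominates the overall loss (about 0.175 here), while each set's metrics are still reported separately under their prefixed names.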