Model training¶
To start a model training run, the following prerequisites must be in place:

- Loading and preprocessing a training and validation dataset, as described here.
- Initializing a force field, as described here.
- Setting up a loss (see below for details).
- Setting up an optimizer (see below for details).
- Creating an instance of TrainingLoopConfig (can also be accessed via TrainingLoop.Config).
- Optionally (has a default): Creating an instance of TrainingIOHandler (see below for details).
Once these objects are set up, we can create an instance of TrainingLoop and start the training run:
import jax

from mlip.training import TrainingLoop

# Prerequisites
train_set, validation_set, dataset_info = _get_dataset()  # placeholder
force_field = _get_force_field()  # placeholder
loss = _get_loss()  # placeholder
optimizer = _get_optimizer()  # placeholder
io_handler = _get_training_loop_io_handler()  # placeholder
config = TrainingLoop.Config(**config_kwargs)  # config_kwargs is a placeholder

# Create the TrainingLoop instance
training_loop = TrainingLoop(
    train_dataset=train_set,
    validation_dataset=validation_set,
    force_field=force_field,
    loss=loss,
    optimizer=optimizer,
    config=config,
    io_handler=io_handler,  # also has a default, does not need to be set
    should_parallelize=len(jax.devices()) > 1,  # has a default of False
)

# Start the model training
training_loop.run()
The final TrainingState can be accessed after the run like this:
final_training_state = training_loop.training_state
final_params = final_training_state.params
However, the final parameters are not always the ones with the best performance on the
validation set. Hence, you can also access the best parameters with training_loop.best_params,
or directly use training_loop.best_model to get the ForceField instance that holds them.
In addition to the checkpointing API described further below, a trained force field can also
be saved with the function save_model_to_zip() as a lightweight zip archive. This is useful
if you only want to use the model for inference tasks later, as the archive does not include
any training state. Note that it is also possible to run an evaluation on a test dataset
after training by using the test() method of the TrainingLoop instance.
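As a minimal sketch of these post-training steps (the import path of save_model_to_zip() and
the argument lists of save_model_to_zip() and test() are assumptions here; see the API
reference for the precise signatures):
from mlip.models.model_io import save_model_to_zip  # import path is an assumption

# Access the best parameters and the corresponding force field
best_params = training_loop.best_params
best_force_field = training_loop.best_model

# Save the best force field as a lightweight zip archive for later inference
# (call signature is an assumption)
save_model_to_zip(best_force_field, "trained_force_field.zip")

# Evaluate on a held-out test set after training
# (the dataset argument to test() is an assumption)
test_set = _get_test_dataset()  # placeholder
training_loop.test(test_set)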
In the following, we describe the prerequisites listed above in more detail.
Loss¶
All losses must be implemented as classes derived from Loss. We currently implement two
losses, the Mean-Squared-Error loss (MSELoss) and the Huber loss. Both are derived from a
base loss that computes errors for energies, forces, and stress, and weights them according
to a weighting schedule that can depend on the epoch number.
To use the MSE loss for training, simply initialize it like this:
import optax
from mlip.models.loss import MSELoss
# uses default weight schedules
loss = MSELoss()
# uses a weight flip schedule
energy_weight_schedule = optax.piecewise_constant_schedule(1.0, {100: 25.0})
forces_weight_schedule = optax.piecewise_constant_schedule(25.0, {100: 0.04})
loss = MSELoss(energy_weight_schedule, forces_weight_schedule)
For our two implemented losses, we also allow for the computation of more extended metrics
by setting the extended_metrics argument to True in the loss constructor. By default, it is False.
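For example (passing extended_metrics as a keyword argument; check the API reference for the
exact call signature):
# Enable extended metrics for the MSE loss
loss = MSELoss(extended_metrics=True)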
Optimizer¶
The optimizer provided to the TrainingLoop can be any Optax optimizer; however, this library
also provides a specialized optimizer pipeline inspired by this PyTorch MACE implementation.
It is configurable via an OptimizerConfig object that has sensible defaults set for training
MLIP models. This default MLIP optimizer can be set up like this:
from mlip.training import get_default_mlip_optimizer, OptimizerConfig
# with default config
optimizer = get_default_mlip_optimizer()
# with modified config
optimizer = get_default_mlip_optimizer(OptimizerConfig(**config_kwargs))
See the API reference for get_default_mlip_optimizer and OptimizerConfig for further details
on how this MLIP optimizer works internally.
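Since any Optax optimizer is accepted, a plain optimizer can also be passed to the
TrainingLoop directly, for example (the learning rate here is purely illustrative):
import optax

# Any Optax gradient transformation can serve as the optimizer, e.g. Adam
optimizer = optax.adam(learning_rate=1e-3)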
IO handling and logging¶
During training, we want to allow for checkpointing of the training state and logging of
metrics. The TrainingIOHandler class manages these tasks. It comes with its own config, the
TrainingIOHandlerConfig, which, like most other configs in the library, can be accessed via
TrainingIOHandler.Config. The IO handler uses Orbax Checkpointing to save and restore model
checkpoints. Loading a trained model for simulations or other inference tasks also relies on
these model checkpoints (see load_parameters_from_checkpoint()).
The local checkpointing location can be set in the config. Uploading these checkpoints to
remote storage locations can additionally be achieved via a user-provided data upload function:
import os
from mlip.training import TrainingIOHandler
io_config = TrainingIOHandler.Config(**config_kwargs)
def remote_storage_sync_fun(source: str | os.PathLike) -> None:
    """Makes sure local data in source is uploaded to remote storage."""
    pass  # placeholder

io_handler = TrainingIOHandler(io_config, remote_storage_sync_fun)
Locally, after the training run has started, the checkpointing location will contain a
dataset_info.json file with the saved DatasetInfo object, and a model subdirectory with all
the model checkpoints, one for each epoch that achieved the best validation loss up to that
point. It is recommended to also manually save other metadata in this location, such as the
applied model config.
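As a rough sketch, the contents of the checkpointing location may look like this (the naming
of the individual checkpoint entries inside model/ is an assumption and depends on the Orbax
checkpointer):
<local_checkpoint_dir>/
├── dataset_info.json   # saved DatasetInfo object
└── model/
    ├── <checkpoint for the epoch with the best validation loss so far>
    └── ...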
For advanced logging, e.g., to an experiment tracking platform (such as Neptune), one can also attach custom logging functions to the IO handler:
from typing import Any

from mlip.training.training_io_handler import LogCategory

def train_logging_fun(
    category: LogCategory, to_log: dict[str, Any], epoch_number: int
) -> None:
    """Advanced logging function."""
    pass  # placeholder

io_handler.attach_logger(train_logging_fun)
See the documentation of
LogCategory
for more details on what type of data can be logged with such a logger during training.
Furthermore, this library provides built-in logging functions that can be attached to the IO
handler: log_metrics_to_table(), which prints the training metrics to the console in a nicely
formatted table (using Rich tables), and log_metrics_to_line(), which logs the metrics in a
single line.
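For example, a built-in logging function can be attached in the same way as a custom one
(the exact import path is an assumption; see the API reference):
from mlip.training import log_metrics_to_table  # import path is an assumption

io_handler.attach_logger(log_metrics_to_table)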
Note that it is possible to omit the io_handler argument when constructing the TrainingLoop.
In that case, a default IO handler is set up internally and used. This default IO handler
does not include checkpointing, but it does have the log_metrics_to_line() logging function
attached by default.